From b48c628f8d66b157a36431711c97de6cb17852b6 Mon Sep 17 00:00:00 2001 From: Lancelot Blanchard Date: Mon, 22 Jun 2026 11:47:34 -0700 Subject: [PATCH] Add support for _expected_input_spec in MLX Lambda to match JAX implementation --- sequence_layers/mlx/simple.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sequence_layers/mlx/simple.py b/sequence_layers/mlx/simple.py index 938a4e1..bee403e 100644 --- a/sequence_layers/mlx/simple.py +++ b/sequence_layers/mlx/simple.py @@ -1406,7 +1406,7 @@ class Config(spec.Lambda.Config): fn: Callable sequence_input: bool = False mask_required: bool = True - # Accepted for JAX compatibility but ignored by MLX Lambda. + # Used to guide shape/dtype probing (e.g. for bitcasting operations). expected_input_spec: object = None expected_output_spec: object = None name: str | None = None @@ -1433,7 +1433,17 @@ def _probe_output(self, input_shape, input_dtype): if cache_key in self._cached_output_specs: return self._cached_output_specs[cache_key] try: - dummy_values = mx.zeros((1, 1) + tuple(input_shape), dtype=input_dtype) + probe_shape = tuple(input_shape) + probe_dtype = input_dtype + if self._expected_input_spec is not None: + probe_shape = tuple(self._expected_input_spec.shape) + try: + from sequence_layers.mlx.init_mapping import _to_mx_dtype + probe_dtype = _to_mx_dtype(self._expected_input_spec.dtype) + except Exception: + probe_dtype = self._expected_input_spec.dtype + + dummy_values = mx.zeros((1, 1) + probe_shape, dtype=probe_dtype) dummy_mask = mx.ones((1, 1), dtype=mx.bool_) assert self.config.fn is not None if self.config.sequence_input: