Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions sequence_layers/mlx/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ class Config(spec.Lambda.Config):
fn: Callable
sequence_input: bool = False
mask_required: bool = True
# Accepted for JAX compatibility but ignored by MLX Lambda.
# Used to guide shape/dtype probing (e.g. for bitcasting operations).
expected_input_spec: object = None
expected_output_spec: object = None
name: str | None = None
Expand All @@ -1433,7 +1433,17 @@ def _probe_output(self, input_shape, input_dtype):
if cache_key in self._cached_output_specs:
return self._cached_output_specs[cache_key]
try:
dummy_values = mx.zeros((1, 1) + tuple(input_shape), dtype=input_dtype)
probe_shape = tuple(input_shape)
probe_dtype = input_dtype
if self._expected_input_spec is not None:
probe_shape = tuple(self._expected_input_spec.shape)
try:
from sequence_layers.mlx.init_mapping import _to_mx_dtype
probe_dtype = _to_mx_dtype(self._expected_input_spec.dtype)
except Exception:
probe_dtype = self._expected_input_spec.dtype

dummy_values = mx.zeros((1, 1) + probe_shape, dtype=probe_dtype)
dummy_mask = mx.ones((1, 1), dtype=mx.bool_)
assert self.config.fn is not None
if self.config.sequence_input:
Expand Down