From 826cfba90c556e12267bf7498cc6a5913201a6d0 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo Date: Fri, 26 Jun 2026 12:16:29 +0300 Subject: [PATCH 1/3] Add MLX backend dispatch for Scan Op Port the JAX scan dispatch to MLX. As MLX has no native general-scan primitive, the inner fgraph is driven by a Python carry loop that is unrolled into the graph at trace time and then compiled by `mx.compile`. Because scalar values are not readable while MLX traces, a constant `n_steps` is used when available; otherwise the number of steps is inferred from the (full-sized) recurring buffers, falling back to the sequence length. Covers seqs, MIT-SOT, SIT-SOT, NIT-SOT, untraced SIT-SOT, MIT-MOT (scan gradients) and non-sequences, mirroring the JAX semantics of recreating the trace and prepending/truncating to the buffer size. Gradients over sequences reverse the trace and currently trip a separate MLX bug (an elementwise op fed by a negative-stride array is miscompiled under `mx.compile`); this is captured as a strict xfail under the full `mode="MLX"` and is addressed by a follow-up. Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/__init__.py | 1 + pytensor/link/mlx/dispatch/scan.py | 241 +++++++++++++++++++++++++ tests/link/mlx/test_scan.py | 210 +++++++++++++++++++++ 3 files changed, 452 insertions(+) create mode 100644 pytensor/link/mlx/dispatch/scan.py create mode 100644 tests/link/mlx/test_scan.py diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 59b0604856..ed71f95764 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -17,4 +17,5 @@ import pytensor.link.mlx.dispatch.pad import pytensor.link.mlx.dispatch.sort import pytensor.link.mlx.dispatch.linalg +import pytensor.link.mlx.dispatch.scan # isort: on diff --git a/pytensor/link/mlx/dispatch/scan.py b/pytensor/link/mlx/dispatch/scan.py new file mode 100644 index 0000000000..3ed159f8e6 --- /dev/null +++ b/pytensor/link/mlx/dispatch/scan.py @@ -0,0 +1,241 @@ +from itertools import chain + +import mlx.core as
mx + +from pytensor.compile.mode import MLX, get_mode +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.scan.op import Scan +from pytensor.tensor.basic import get_scalar_constant_value +from pytensor.tensor.exceptions import NotScalarConstantError + + +@mlx_funcify.register(Scan) +def mlx_funcify_Scan(op: Scan, node, **kwargs): + # Mirrors the JAX dispatch (`link/jax/dispatch/scan.py`): the loop recreates + # the concatenated trace from the per-step values and prepends the initial + # state / truncates as needed, instead of writing into the PyTensor buffers. + # MLX has no native general-scan primitive, so JAX's `lax.scan` is replaced + # by a Python carry loop that `mx.compile` unrolls. This needs a statically + # known number of steps, which is read from the (full-sized) recurring + # buffers since `scan_reduce_trace_prealloc` is excluded for MLX. + info = op.info + + if info.as_while: + raise NotImplementedError("While Scan cannot yet be converted to MLX") + + # NIT-SOT output lengths are runtime scalars under `mx.compile`; take the + # static output shape when known and fall back to ``n_steps`` otherwise. + nitsot_static_sizes = [ + out.type.shape[0] for out in op.outer_nitsot_outs(node.outputs) + ] + + # A constant ``n_steps`` is authoritative (and the only inference source for + # scans without recurring buffers or sequences, e.g. a pure NIT-SOT map). 
+ try: + static_n_steps = int(get_scalar_constant_value(node.inputs[0])) + except NotScalarConstantError: + static_n_steps = None + + rewriter = ( + get_mode(op.mode) + .including("mlx") + .excluding("numba", *MLX._optimizer.exclude) + .optimizer + ) + rewriter(op.fgraph) + scan_inner_func = mlx_funcify(op.fgraph, **kwargs) + + def scan(*outer_inputs): + outer_inputs = list(outer_inputs) + n_steps = _infer_n_steps(op, outer_inputs, nitsot_static_sizes, static_n_steps) + seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] + + # MIT-MOT have no "initial state"; the whole buffer is meaningful. + # MIT-SOT and SIT-SOT initial states are extracted from the buffers. + # The ``_init`` states are kept untouched (they prepend the final traces), + # while ``carry`` copies evolve through the loop. + mit_sot_init = [ + buff[: -min(tap)] + for buff, tap in zip( + op.outer_mitsot(outer_inputs), info.mit_sot_in_slices, strict=True + ) + ] + sit_sot_init = [buff[0] for buff in op.outer_sitsot(outer_inputs)] + + mit_mot = list(op.outer_mitmot(outer_inputs)) + mit_sot = list(mit_sot_init) + sit_sot = list(sit_sot_init) + untraced_sit_sot = list(op.outer_untraced_sit_sot(outer_inputs)) + non_seqs = op.outer_non_seqs(outer_inputs) + + n_traced = info.n_mit_sot + info.n_sit_sot + info.n_nit_sot + traces: list[list] = [[] for _ in range(n_traced)] + for i in range(n_steps): + inner_seqs = [seq[i] for seq in seqs] + mit_mot_flatten = list( + chain.from_iterable( + buffer[[i + tap for tap in taps]] + for buffer, taps in zip( + mit_mot, info.normalized_mit_mot_in_slices, strict=True + ) + ) + ) + mit_sot_flatten = list( + chain.from_iterable( + buffer[list(taps)] + for buffer, taps in zip( + mit_sot, info.mit_sot_in_slices, strict=True + ) + ) + ) + + inner_outs = list( + scan_inner_func( + *inner_seqs, + *mit_mot_flatten, + *mit_sot_flatten, + *sit_sot, + *untraced_sit_sot, + *non_seqs, + ) + ) + + new_mit_mot_vals = op.inner_mitmot_outs_grouped(inner_outs) + new_mit_sot_vals = 
op.inner_mitsot_outs(inner_outs) + new_sit_sot = op.inner_sitsot_outs(inner_outs) + new_nit_sot = op.inner_nitsot_outs(inner_outs) + new_untraced_sit_sot = op.inner_untraced_sit_sot_outs(inner_outs) + + # Write the new MIT-MOT values at the output-tap positions. + mit_mot = [ + _functional_set( + buffer, [i + tap for tap in taps], mx.stack(new_vals, axis=0) + ) + for buffer, new_vals, taps in zip( + mit_mot, + new_mit_mot_vals, + info.normalized_mit_mot_out_slices, + strict=True, + ) + ] + # Discard oldest MIT-SOT tap and append the newest value. + mit_sot = [ + mx.concatenate([buffer[1:], new_val[None, ...]], axis=0) + for buffer, new_val in zip(mit_sot, new_mit_sot_vals, strict=True) + ] + sit_sot = new_sit_sot + untraced_sit_sot = new_untraced_sit_sot + + step_traced = [*new_mit_sot_vals, *new_sit_sot, *new_nit_sot] + for trace, val in zip(traces, step_traced, strict=True): + trace.append(val) + + # Per-step shape of each traced output (for synthesizing empty traces + # when ``n_steps == 0``): MIT-SOT/SIT-SOT match their state shape. 
+ traced_trailing = ( + [tuple(init.shape[1:]) for init in mit_sot_init] + + [tuple(init.shape) for init in sit_sot_init] + + [() for _ in range(info.n_nit_sot)] + ) + stacked_traces = [ + mx.stack(trace, axis=0) if trace else mx.zeros((0, *trailing)) + for trace, trailing in zip(traces, traced_trailing, strict=True) + ] + + def get_partial_traces(traces): + """Prepend initial states and slice traces down to buffer sizes.""" + init_states = mit_sot_init + sit_sot_init + [None] * info.n_nit_sot + buffer_sizes = ( + [buff.shape[0] for buff in op.outer_mitsot(outer_inputs)] + + [buff.shape[0] for buff in op.outer_sitsot(outer_inputs)] + + [ + size if size is not None else n_steps + for size in nitsot_static_sizes + ] + ) + partial_traces = [] + for init_state, trace, buffer_size in zip( + init_states, traces, buffer_sizes, strict=True + ): + if init_state is not None: + if trace.shape[0] >= buffer_size: + # Trace at least as long as the buffer: keep the tail. + partial_trace = trace[-buffer_size:] + else: + # Trace shorter than the buffer: prepend (part of) init. + if init_state.ndim < trace.ndim: + init_state = init_state[None] + if ( + n_init_needed := buffer_size - trace.shape[0] + ) < init_state.shape[0]: + init_state = init_state[-n_init_needed:] + partial_trace = mx.concatenate([init_state, trace], axis=0) + else: + partial_trace = ( + trace[-buffer_size:] if trace.shape[0] > buffer_size else trace + ) + + assert partial_trace.shape[0] == buffer_size + partial_traces.append(partial_trace) + + return partial_traces + + scan_outs_final = [ + *mit_mot, + *get_partial_traces(stacked_traces), + *untraced_sit_sot, + ] + + if len(scan_outs_final) == 1: + return scan_outs_final[0] + return scan_outs_final + + return scan + + +def _infer_n_steps(op, outer_inputs, nitsot_static_sizes, static_n_steps): + """Derive the number of steps for the unrolled loop. + + Scalar input values are not readable while ``mx.compile`` traces, but array + shapes are concrete. 
A constant ``n_steps`` is used directly; otherwise the + count comes from a recurring buffer (which stays full-sized because the + trace-prealloc reduction is disabled for MLX, so each encodes ``n_steps`` + plus its initial taps) or a sequence. A non-constant ``n_steps`` with no + such buffer (e.g. a dynamic-length pure ``map``) is an MLX static-shape + limitation, like dynamic ``arange``. + """ + info = op.info + if static_n_steps is not None: + return static_n_steps + for buff in op.outer_sitsot(outer_inputs): + return buff.shape[0] - 1 + for buff, taps in zip( + op.outer_mitsot(outer_inputs), info.mit_sot_in_slices, strict=True + ): + return buff.shape[0] + min(taps) + for seq in op.outer_seqs(outer_inputs): + return seq.shape[0] + for buff, in_taps, out_taps in zip( + op.outer_mitmot(outer_inputs), + info.normalized_mit_mot_in_slices, + info.normalized_mit_mot_out_slices, + strict=True, + ): + return buff.shape[0] - (max(*in_taps, *out_taps) - min(*in_taps, *out_taps)) + for size in nitsot_static_sizes: + if size is not None: + return size + raise NotImplementedError( + "MLX Scan requires a statically known number of steps when there are no " + "recurring buffers or sequences to infer it from." + ) + + +def _functional_set(buffer, idx, vals): + """Return ``buffer`` with rows ``idx`` set to ``vals``. + + MLX has no ``.at[].set`` and in-place item assignment aliases buffers under + ``mx.compile``, so a functional scatter-add of the delta is used instead. 
+ """ + idx = mx.array(idx) + return buffer.at[idx].add(vals - buffer[idx]) diff --git a/tests/link/mlx/test_scan.py b/tests/link/mlx/test_scan.py new file mode 100644 index 0000000000..90293e1407 --- /dev/null +++ b/tests/link/mlx/test_scan.py @@ -0,0 +1,210 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.scan import until +from pytensor.scan.basic import scan +from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode + + +mx = pytest.importorskip("mlx.core") + + +@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)]) +def test_scan_sit_sot(view): + x0 = pt.scalar("x0", dtype="float32") + xs = scan( + lambda xtm1: xtm1 + 1, + outputs_info=[x0], + n_steps=10, + return_updates=False, + ) + if view: + xs = xs[view] + compare_mlx_and_py([x0], [xs], [np.float32(np.e)]) + + +@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) +def test_scan_mit_sot(view): + x0 = pt.vector("x0", dtype="float32", shape=(3,)) + xs = scan( + lambda xtm3, xtm1: xtm3 + xtm1 + 1, + outputs_info=[{"initial": x0, "taps": [-3, -1]}], + n_steps=10, + return_updates=False, + ) + if view: + xs = xs[view] + compare_mlx_and_py([x0], [xs], [np.full((3,), np.e, dtype="float32")]) + + +def test_scan_with_sequence_and_non_seq(): + # RNN-style recurrence over a sequence with a shared weight matrix. + xs = pt.matrix("xs", dtype="float32") + h0 = pt.vector("h0", dtype="float32", shape=(3,)) + W = pt.matrix("W", dtype="float32", shape=(3, 3)) + hs = scan( + lambda x_t, h_tm1, W: pt.tanh(x_t + h_tm1 @ W), + sequences=[xs], + outputs_info=[h0], + non_sequences=[W], + return_updates=False, + ) + rng = np.random.default_rng(0) + compare_mlx_and_py( + [xs, h0, W], + [hs], + [ + rng.standard_normal((5, 3)).astype("float32"), + np.zeros(3, dtype="float32"), + (0.1 * rng.standard_normal((3, 3))).astype("float32"), + ], + ) + + +def test_scan_multiple_outputs(): + # One recurring (sit_sot) and one mapped (nit_sot) output. 
+ s = pt.vector("s", dtype="float32") + + def step(s_t, acc): + return acc + s_t, s_t * s_t + + acc, sq = scan( + step, + sequences=[s], + outputs_info=[pt.zeros((), dtype="float32"), None], + return_updates=False, + ) + compare_mlx_and_py([s], [acc, sq], [np.arange(1, 5, dtype="float32")]) + + +def test_scan_nit_sot_only(): + # Pure tiling/map from a non-sequence with an explicit ``n_steps`` (no + # recurring buffer or sequence to infer the step count from). + w = pt.scalar("w", dtype="float32") + ys = scan( + lambda w: w * 2, + outputs_info=[None], + non_sequences=[w], + n_steps=5, + return_updates=False, + ) + compare_mlx_and_py([w], [ys], [np.float32(3.0)]) + + +def test_scan_multiple_recurring_states(): + # MIT-SOT (taps -2, -1) and SIT-SOT and a NIT-SOT map in one scan. + x0 = pt.vector("x0", dtype="float32", shape=(2,)) + + def step(xtm2, xtm1, stm1): + x_t = xtm2 + xtm1 + return x_t, stm1 + x_t, x_t * 2 + + xs, ss, ys = scan( + step, + outputs_info=[{"initial": x0, "taps": [-2, -1]}, pt.zeros((), "float32"), None], + n_steps=6, + return_updates=False, + ) + compare_mlx_and_py([x0], [xs, ss, ys], [np.array([1.0, 1.0], dtype="float32")]) + + +def test_scan_int_dtype_preserved(): + # Integer recurrence: dtype must be preserved (no float upcast). + x0 = pt.scalar("x0", dtype="int32") + xs = scan( + lambda xtm1: xtm1 + 1, + outputs_info=[x0], + n_steps=5, + return_updates=False, + ) + + def assert_int(mlx_res, py_res): + np.testing.assert_array_equal(mlx_res, py_res) + assert np.asarray(mlx_res).dtype == np.asarray(py_res).dtype == np.int32 + + compare_mlx_and_py([x0], [xs], [np.int32(0)], assert_fn=assert_int) + + +def test_scan_zero_steps(): + # Degenerate ``n_steps == 0``: matches the empty output of the reference. 
+ x0 = pt.scalar("x0", dtype="float32") + xs = scan( + lambda xtm1: xtm1 + 1, + outputs_info=[x0], + n_steps=0, + return_updates=False, + ) + compare_mlx_and_py([x0], [xs], [np.float32(3.0)]) + + +def test_scan_while_not_implemented(): + x0 = pt.scalar("x0", dtype="float32") + xs = scan( + lambda xtm1: (xtm1 + 1, until(xtm1 > 5)), + outputs_info=[x0], + n_steps=100, + return_updates=False, + ) + with pytest.raises(NotImplementedError): + from pytensor import function + + function([x0], xs, mode=mlx_mode) + + +def test_scan_grad_non_sequence(): + # Gradient w.r.t. a non-sequence through a pure recurrence (no input + # sequences to reverse), which exercises the MIT-MOT backward Scan. + w = pt.scalar("w", dtype="float32") + xs = scan( + lambda x_tm1, w: x_tm1 * w, + outputs_info=[pt.ones((), dtype="float32")], + non_sequences=[w], + n_steps=4, + return_updates=False, + ) + g = pt.grad(xs[-1], w) + compare_mlx_and_py([w], [g], [np.float32(2.0)]) + + +def _rnn_grad_over_sequence(): + xs = pt.matrix("xs", dtype="float32") + h0 = pt.vector("h0", dtype="float32", shape=(3,)) + W = pt.matrix("W", dtype="float32", shape=(3, 3)) + hs = scan( + lambda x_t, h_tm1, W: pt.tanh(x_t + h_tm1 @ W), + sequences=[xs], + outputs_info=[h0], + non_sequences=[W], + return_updates=False, + ) + gW = pt.grad((hs**2).sum(), W) + rng = np.random.default_rng(0) + test_inputs = [ + rng.standard_normal((5, 3)).astype("float32"), + np.zeros(3, dtype="float32"), + (0.1 * rng.standard_normal((3, 3))).astype("float32"), + ] + return [xs, h0, W], [gW], test_inputs + + +def test_scan_grad_over_sequence(): + # The backward Scan (MIT-MOT + reversed forward trace) is correct under the + # base MLX optimizer query. 
+ inputs, outputs, test_inputs = _rnn_grad_over_sequence() + compare_mlx_and_py(inputs, outputs, test_inputs) + + +@pytest.mark.xfail( + strict=True, + reason=( + "Under the full `mode='MLX'` (fast_run) the gradient over a sequence " + "reverses the trace, and MLX miscompiles an elementwise op fed by a " + "negative-stride array (`mx.compile(lambda x: 2.0 * x[::-1])` zeroes the " + "tail). Unrelated to the Scan dispatch; fixed by a follow-up that " + "materializes negative-stride Subtensor results in the MLX backend." + ), +) +def test_scan_grad_over_sequence_default_mode(): + inputs, outputs, test_inputs = _rnn_grad_over_sequence() + compare_mlx_and_py(inputs, outputs, test_inputs, mlx_mode="MLX") From e736a6a6323ba2ec5c93a432a0b623516bc86cda Mon Sep 17 00:00:00 2001 From: Carlos Trujillo Date: Fri, 26 Jun 2026 13:33:36 +0300 Subject: [PATCH 2/3] Note MLX Scan unroll is a workaround for missing loop primitive Clarify in the dispatch that the Python carry loop is unrolled into the graph at trace time (not by `mx.compile`) and that this is a workaround until MLX exposes a native scan/while primitive (ml-explore/mlx#1441): it needs a static step count and the graph grows as O(n_steps * inner_ops), where a real primitive would keep it O(inner_ops). Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/scan.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pytensor/link/mlx/dispatch/scan.py b/pytensor/link/mlx/dispatch/scan.py index 3ed159f8e6..a6ac459d1c 100644 --- a/pytensor/link/mlx/dispatch/scan.py +++ b/pytensor/link/mlx/dispatch/scan.py @@ -14,10 +14,16 @@ def mlx_funcify_Scan(op: Scan, node, **kwargs): # Mirrors the JAX dispatch (`link/jax/dispatch/scan.py`): the loop recreates # the concatenated trace from the per-step values and prepends the initial # state / truncates as needed, instead of writing into the PyTensor buffers. 
- # MLX has no native general-scan primitive, so JAX's `lax.scan` is replaced - # by a Python carry loop that `mx.compile` unrolls. This needs a statically - # known number of steps, which is read from the (full-sized) recurring - # buffers since `scan_reduce_trace_prealloc` is excluded for MLX. + # + # Workaround until MLX exposes a native loop primitive (ml-explore/mlx#1441, + # an `mx` equivalent of `jax.lax.scan`/`while_loop`): MLX has none today, so + # JAX's `lax.scan` is replaced by a plain Python carry loop that is unrolled + # into the graph at trace time, which `mx.compile` then compiles. The cost + # is structural -- it needs a statically known number of steps and the graph + # grows as O(n_steps * inner_ops) (a real primitive would keep it + # O(inner_ops)). When the primitive lands, lower to it instead of unrolling. + # `n_steps` is read from the (full-sized) recurring buffers since + # `scan_reduce_trace_prealloc` is excluded for MLX. info = op.info if info.as_while: From 85e3f6d3a56fe2ce49a519b45f0b19bb4d42d236 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo Date: Fri, 26 Jun 2026 12:20:27 +0300 Subject: [PATCH 3/3] Materialize negative-stride MLX Subtensor results MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `mx.compile` miscompiles an elementwise op fed by a negative-stride view (`mx.compile(lambda x: 2.0 * x[::-1])` zeroes the trailing entries; eager is correct). The MLX Subtensor dispatch now copies reversed slices into a contiguous array. This unblocks Scan gradients over sequences (which reverse the trace) under the full `mode="MLX"`, and resolves the existing strict xfail `test_mlx_IncSubtensor_negative_step_slice_grad` — whose failure was the same negative-stride read feeding the elementwise gradient term, not the IncSubtensor write it was attributed to (ml-explore/mlx#3716). 
Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/subtensor.py | 14 ++++++++++++- tests/link/mlx/test_scan.py | 13 +++--------- tests/link/mlx/test_subtensor.py | 28 ++++++++++++++++++------- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index f2e33f137e..38bce44e6f 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -1,5 +1,7 @@ from copy import deepcopy +import mlx.core as mx + from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -12,6 +14,11 @@ ) +def _has_negative_step(indices): + idxs = indices if isinstance(indices, tuple) else (indices,) + return any(isinstance(i, slice) and i.step is not None and i.step < 0 for i in idxs) + + @mlx_funcify.register(Subtensor) def mlx_funcify_Subtensor(op, node, **kwargs): def subtensor(x, *ilists): @@ -21,7 +28,12 @@ def subtensor(x, *ilists): if len(indices) == 1: indices = indices[0] - return x.__getitem__(indices) + res = x.__getitem__(indices) + # MLX miscompiles an elementwise op fed by a negative-stride view under + # `mx.compile`, so materialize reversed slices into a contiguous array. + if _has_negative_step(indices): + res = mx.contiguous(res) + return res return subtensor diff --git a/tests/link/mlx/test_scan.py b/tests/link/mlx/test_scan.py index 90293e1407..a67ac829ba 100644 --- a/tests/link/mlx/test_scan.py +++ b/tests/link/mlx/test_scan.py @@ -195,16 +195,9 @@ def test_scan_grad_over_sequence(): compare_mlx_and_py(inputs, outputs, test_inputs) -@pytest.mark.xfail( - strict=True, - reason=( - "Under the full `mode='MLX'` (fast_run) the gradient over a sequence " - "reverses the trace, and MLX miscompiles an elementwise op fed by a " - "negative-stride array (`mx.compile(lambda x: 2.0 * x[::-1])` zeroes the " - "tail). 
Unrelated to the Scan dispatch; fixed by a follow-up that " - "materializes negative-stride Subtensor results in the MLX backend." - ), -) def test_scan_grad_over_sequence_default_mode(): + # Under the full `mode="MLX"` the gradient reverses the trace; this used to + # trip the MLX negative-stride compile bug now handled in the Subtensor + # dispatch (see `test_mlx_negative_step_slice_elemwise`). inputs, outputs, test_inputs = _rnn_grad_over_sequence() compare_mlx_and_py(inputs, outputs, test_inputs, mlx_mode="MLX") diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 69535fa348..544524df13 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -265,18 +265,15 @@ def test_mlx_IncSubtensor_slice_grad(): compare_mlx_and_py([x_pt], [g], [x_np]) -@pytest.mark.xfail( - reason="Upstream mx.compile bug (ml-explore/mlx#3716): assigning an " - "elementwise expression to a negative-strided slice returns wrong values " - "under mx.compile (correct when eager / use_compile=False).", - strict=True, -) def test_mlx_IncSubtensor_negative_step_slice_grad(): + # The wrong result here (previously attributed to ml-explore/mlx#3716) was + # actually the negative-stride read feeding the elementwise gradient term, + # now materialized by the Subtensor dispatch. x_pt = pt.vector("x", dtype="float32") x_np = np.arange(6, dtype=np.float32) g = pt.grad((x_pt[::-1] ** 2).sum(), x_pt) assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) - compare_mlx_and_py([x_pt], [g], [x_np]) + compare_mlx_and_py([x_pt], [g], [x_np], mlx_mode="MLX") @pytest.mark.parametrize( @@ -343,3 +340,20 @@ def test_mlx_AdvancedIncSubtensor_ignore_duplicates(): assert out.owner.op.ignore_duplicates compare_mlx_and_py([x], [out], [np.zeros(3, dtype=np.float32)]) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_mlx_negative_step_slice_elemwise(axis): + """A negative-stride slice feeding an elementwise op must materialize. 
+ + Under the full ``mode="MLX"`` (``mx.compile``), an elementwise op fed by a + negative-stride view used to be miscompiled (trailing entries zeroed). The + Subtensor dispatch now copies reversed slices into a contiguous array. This + is what unblocks Scan gradients over sequences, which reverse the trace. + """ + x = pt.matrix("x", dtype="float32") + rev = x[::-1] if axis == 0 else x[:, ::-1] + out = 2.0 * rev + assert isinstance(rev.owner.op, pt_subtensor.Subtensor) + x_np = np.arange(15, dtype=np.float32).reshape(5, 3) + compare_mlx_and_py([x], [out], [x_np], mlx_mode="MLX")