From 826cfba90c556e12267bf7498cc6a5913201a6d0 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo Date: Fri, 26 Jun 2026 12:16:29 +0300 Subject: [PATCH 1/3] Add MLX backend dispatch for Scan Op Port the JAX scan dispatch to MLX. As MLX has no native general-scan primitive, the inner fgraph is driven by a Python carry loop that `mx.compile` unrolls. Because scalar values are not readable while MLX traces, the (full-sized) recurring buffers are used to infer the number of steps, falling back to a constant `n_steps` or the sequence length. Covers seqs, MIT-SOT, SIT-SOT, NIT-SOT, untraced SIT-SOT, MIT-MOT (scan gradients) and non-sequences, mirroring the JAX semantics of recreating the trace and prepending/truncating to the buffer size. Gradients over sequences reverse the trace and currently trip a separate MLX bug (an elementwise op fed by a negative-stride array is miscompiled under `mx.compile`); this is captured as a strict xfail under the full `mode="MLX"` and is addressed by a follow-up. Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/__init__.py | 1 + pytensor/link/mlx/dispatch/scan.py | 241 +++++++++++++++++++++++++ tests/link/mlx/test_scan.py | 210 +++++++++++++++++++++ 3 files changed, 452 insertions(+) create mode 100644 pytensor/link/mlx/dispatch/scan.py create mode 100644 tests/link/mlx/test_scan.py diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index 59b0604856..ed71f95764 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -17,4 +17,5 @@ import pytensor.link.mlx.dispatch.pad import pytensor.link.mlx.dispatch.sort import pytensor.link.mlx.dispatch.linalg +import pytensor.link.mlx.dispatch.scan # isort: on diff --git a/pytensor/link/mlx/dispatch/scan.py b/pytensor/link/mlx/dispatch/scan.py new file mode 100644 index 0000000000..3ed159f8e6 --- /dev/null +++ b/pytensor/link/mlx/dispatch/scan.py @@ -0,0 +1,241 @@ +from itertools import chain + +import mlx.core as mx + +from pytensor.compile.mode import MLX, get_mode +from pytensor.link.mlx.dispatch.basic import mlx_funcify +from pytensor.scan.op import Scan +from pytensor.tensor.basic import get_scalar_constant_value +from pytensor.tensor.exceptions import NotScalarConstantError + + +@mlx_funcify.register(Scan) +def mlx_funcify_Scan(op: Scan, node, **kwargs): + # Mirrors the JAX dispatch (`link/jax/dispatch/scan.py`): the loop recreates + # the concatenated trace from the per-step values and prepends the initial + # state / truncates as needed, instead of writing into the PyTensor buffers. + # MLX has no native general-scan primitive, so JAX's `lax.scan` is replaced + # by a Python carry loop that `mx.compile` unrolls. This needs a statically + # known number of steps, which is read from the (full-sized) recurring + # buffers since `scan_reduce_trace_prealloc` is excluded for MLX. + info = op.info + + if info.as_while: + raise NotImplementedError("While Scan cannot yet be converted to MLX") + + # NIT-SOT output lengths are runtime scalars under `mx.compile`; take the + # static output shape when known and fall back to ``n_steps`` otherwise. + nitsot_static_sizes = [ + out.type.shape[0] for out in op.outer_nitsot_outs(node.outputs) + ] + + # A constant ``n_steps`` is authoritative (and the only inference source for + # scans without recurring buffers or sequences, e.g. a pure NIT-SOT map). + try: + static_n_steps = int(get_scalar_constant_value(node.inputs[0])) + except NotScalarConstantError: + static_n_steps = None + + rewriter = ( + get_mode(op.mode) + .including("mlx") + .excluding("numba", *MLX._optimizer.exclude) + .optimizer + ) + rewriter(op.fgraph) + scan_inner_func = mlx_funcify(op.fgraph, **kwargs) + + def scan(*outer_inputs): + outer_inputs = list(outer_inputs) + n_steps = _infer_n_steps(op, outer_inputs, nitsot_static_sizes, static_n_steps) + seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] + + # MIT-MOT have no "initial state"; the whole buffer is meaningful. + # MIT-SOT and SIT-SOT initial states are extracted from the buffers. + # The ``_init`` states are kept untouched (they prepend the final traces), + # while ``carry`` copies evolve through the loop. + mit_sot_init = [ + buff[: -min(tap)] + for buff, tap in zip( + op.outer_mitsot(outer_inputs), info.mit_sot_in_slices, strict=True + ) + ] + sit_sot_init = [buff[0] for buff in op.outer_sitsot(outer_inputs)] + + mit_mot = list(op.outer_mitmot(outer_inputs)) + mit_sot = list(mit_sot_init) + sit_sot = list(sit_sot_init) + untraced_sit_sot = list(op.outer_untraced_sit_sot(outer_inputs)) + non_seqs = op.outer_non_seqs(outer_inputs) + + n_traced = info.n_mit_sot + info.n_sit_sot + info.n_nit_sot + traces: list[list] = [[] for _ in range(n_traced)] + for i in range(n_steps): + inner_seqs = [seq[i] for seq in seqs] + mit_mot_flatten = list( + chain.from_iterable( + buffer[[i + tap for tap in taps]] + for buffer, taps in zip( + mit_mot, info.normalized_mit_mot_in_slices, strict=True + ) + ) + ) + mit_sot_flatten = list( + chain.from_iterable( + buffer[list(taps)] + for buffer, taps in zip( + mit_sot, info.mit_sot_in_slices, strict=True + ) + ) + ) + + inner_outs = list( + scan_inner_func( + *inner_seqs, + *mit_mot_flatten, + *mit_sot_flatten, + *sit_sot, + *untraced_sit_sot, + *non_seqs, + ) + ) + + new_mit_mot_vals = op.inner_mitmot_outs_grouped(inner_outs) + new_mit_sot_vals = op.inner_mitsot_outs(inner_outs) + new_sit_sot = op.inner_sitsot_outs(inner_outs) + new_nit_sot = op.inner_nitsot_outs(inner_outs) + new_untraced_sit_sot = op.inner_untraced_sit_sot_outs(inner_outs) + + # Write the new MIT-MOT values at the output-tap positions. + mit_mot = [ + _functional_set( + buffer, [i + tap for tap in taps], mx.stack(new_vals, axis=0) + ) + for buffer, new_vals, taps in zip( + mit_mot, + new_mit_mot_vals, + info.normalized_mit_mot_out_slices, + strict=True, + ) + ] + # Discard oldest MIT-SOT tap and append the newest value. + mit_sot = [ + mx.concatenate([buffer[1:], new_val[None, ...]], axis=0) + for buffer, new_val in zip(mit_sot, new_mit_sot_vals, strict=True) + ] + sit_sot = new_sit_sot + untraced_sit_sot = new_untraced_sit_sot + + step_traced = [*new_mit_sot_vals, *new_sit_sot, *new_nit_sot] + for trace, val in zip(traces, step_traced, strict=True): + trace.append(val) + + # Per-step shape of each traced output (for synthesizing empty traces + # when ``n_steps == 0``): MIT-SOT/SIT-SOT match their state shape. + traced_trailing = ( + [tuple(init.shape[1:]) for init in mit_sot_init] + + [tuple(init.shape) for init in sit_sot_init] + + [() for _ in range(info.n_nit_sot)] + ) + stacked_traces = [ + mx.stack(trace, axis=0) if trace else mx.zeros((0, *trailing)) + for trace, trailing in zip(traces, traced_trailing, strict=True) + ] + + def get_partial_traces(traces): + """Prepend initial states and slice traces down to buffer sizes.""" + init_states = mit_sot_init + sit_sot_init + [None] * info.n_nit_sot + buffer_sizes = ( + [buff.shape[0] for buff in op.outer_mitsot(outer_inputs)] + + [buff.shape[0] for buff in op.outer_sitsot(outer_inputs)] + + [ + size if size is not None else n_steps + for size in nitsot_static_sizes + ] + ) + partial_traces = [] + for init_state, trace, buffer_size in zip( + init_states, traces, buffer_sizes, strict=True + ): + if init_state is not None: + if trace.shape[0] >= buffer_size: + # Trace at least as long as the buffer: keep the tail. + partial_trace = trace[-buffer_size:] + else: + # Trace shorter than the buffer: prepend (part of) init. + if init_state.ndim < trace.ndim: + init_state = init_state[None] + if ( + n_init_needed := buffer_size - trace.shape[0] + ) < init_state.shape[0]: + init_state = init_state[-n_init_needed:] + partial_trace = mx.concatenate([init_state, trace], axis=0) + else: + partial_trace = ( + trace[-buffer_size:] if trace.shape[0] > buffer_size else trace + ) + + assert partial_trace.shape[0] == buffer_size + partial_traces.append(partial_trace) + + return partial_traces + + scan_outs_final = [ + *mit_mot, + *get_partial_traces(stacked_traces), + *untraced_sit_sot, + ] + + if len(scan_outs_final) == 1: + return scan_outs_final[0] + return scan_outs_final + + return scan + + +def _infer_n_steps(op, outer_inputs, nitsot_static_sizes, static_n_steps): + """Derive the number of steps for the unrolled loop. + + Scalar input values are not readable while ``mx.compile`` traces, but array + shapes are concrete. A constant ``n_steps`` is used directly; otherwise the + count comes from a recurring buffer (which stays full-sized because the + trace-prealloc reduction is disabled for MLX, so each encodes ``n_steps`` + plus its initial taps) or a sequence. A non-constant ``n_steps`` with no + such buffer (e.g. a dynamic-length pure ``map``) is an MLX static-shape + limitation, like dynamic ``arange``. + """ + info = op.info + if static_n_steps is not None: + return static_n_steps + for buff in op.outer_sitsot(outer_inputs): + return buff.shape[0] - 1 + for buff, taps in zip( + op.outer_mitsot(outer_inputs), info.mit_sot_in_slices, strict=True + ): + return buff.shape[0] + min(taps) + for seq in op.outer_seqs(outer_inputs): + return seq.shape[0] + for buff, in_taps, out_taps in zip( + op.outer_mitmot(outer_inputs), + info.normalized_mit_mot_in_slices, + info.normalized_mit_mot_out_slices, + strict=True, + ): + return buff.shape[0] - (max(*in_taps, *out_taps) - min(*in_taps, *out_taps)) + for size in nitsot_static_sizes: + if size is not None: + return size + raise NotImplementedError( + "MLX Scan requires a statically known number of steps when there are no " + "recurring buffers or sequences to infer it from." + ) + + +def _functional_set(buffer, idx, vals): + """Return ``buffer`` with rows ``idx`` set to ``vals``. + + MLX has no ``.at[].set`` and in-place item assignment aliases buffers under + ``mx.compile``, so a functional scatter-add of the delta is used instead. + """ + idx = mx.array(idx) + return buffer.at[idx].add(vals - buffer[idx]) diff --git a/tests/link/mlx/test_scan.py b/tests/link/mlx/test_scan.py new file mode 100644 index 0000000000..90293e1407 --- /dev/null +++ b/tests/link/mlx/test_scan.py @@ -0,0 +1,210 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor.scan import until +from pytensor.scan.basic import scan +from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode + + +mx = pytest.importorskip("mlx.core") + + +@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)]) +def test_scan_sit_sot(view): + x0 = pt.scalar("x0", dtype="float32") + xs = scan( + lambda xtm1: xtm1 + 1, + outputs_info=[x0], + n_steps=10, + return_updates=False, + ) + if view: + xs = xs[view] + compare_mlx_and_py([x0], [xs], [np.float32(np.e)]) + + +@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) +def test_scan_mit_sot(view): + x0 = pt.vector("x0", dtype="float32", shape=(3,)) + xs = scan( + lambda xtm3, xtm1: xtm3 + xtm1 + 1, + outputs_info=[{"initial": x0, "taps": [-3, -1]}], + n_steps=10, + return_updates=False, + ) + if view: + xs = xs[view] + compare_mlx_and_py([x0], [xs], [np.full((3,), np.e, dtype="float32")]) + + +def test_scan_with_sequence_and_non_seq(): + # RNN-style recurrence over a sequence with a shared weight matrix. + xs = pt.matrix("xs", dtype="float32") + h0 = pt.vector("h0", dtype="float32", shape=(3,)) + W = pt.matrix("W", dtype="float32", shape=(3, 3)) + hs = scan( + lambda x_t, h_tm1, W: pt.tanh(x_t + h_tm1 @ W), + sequences=[xs], + outputs_info=[h0], + non_sequences=[W], + return_updates=False, + ) + rng = np.random.default_rng(0) + compare_mlx_and_py( + [xs, h0, W], + [hs], + [ + rng.standard_normal((5, 3)).astype("float32"), + np.zeros(3, dtype="float32"), + (0.1 * rng.standard_normal((3, 3))).astype("float32"), + ], + ) + + +def test_scan_multiple_outputs(): + # One recurring (sit_sot) and one mapped (nit_sot) output. + s = pt.vector("s", dtype="float32") + + def step(s_t, acc): + return acc + s_t, s_t * s_t + + acc, sq = scan( + step, + sequences=[s], + outputs_info=[pt.zeros((), dtype="float32"), None], + return_updates=False, + ) + compare_mlx_and_py([s], [acc, sq], [np.arange(1, 5, dtype="float32")]) + + +def test_scan_nit_sot_only(): + # Pure tiling/map from a non-sequence with an explicit ``n_steps`` (no + # recurring buffer or sequence to infer the step count from). + w = pt.scalar("w", dtype="float32") + ys = scan( + lambda w: w * 2, + outputs_info=[None], + non_sequences=[w], + n_steps=5, + return_updates=False, + ) + compare_mlx_and_py([w], [ys], [np.float32(3.0)]) + + +def test_scan_multiple_recurring_states(): + # MIT-SOT (taps -2, -1) and SIT-SOT and a NIT-SOT map in one scan. + x0 = pt.vector("x0", dtype="float32", shape=(2,)) + + def step(xtm2, xtm1, stm1): + x_t = xtm2 + xtm1 + return x_t, stm1 + x_t, x_t * 2 + + xs, ss, ys = scan( + step, + outputs_info=[{"initial": x0, "taps": [-2, -1]}, pt.zeros((), "float32"), None], + n_steps=6, + return_updates=False, + ) + compare_mlx_and_py([x0], [xs, ss, ys], [np.array([1.0, 1.0], dtype="float32")]) + + +def test_scan_int_dtype_preserved(): + # Integer recurrence: dtype must be preserved (no float upcast). + x0 = pt.scalar("x0", dtype="int32") + xs = scan( + lambda xtm1: xtm1 + 1, + outputs_info=[x0], + n_steps=5, + return_updates=False, + ) + + def assert_int(mlx_res, py_res): + np.testing.assert_array_equal(mlx_res, py_res) + assert np.asarray(mlx_res).dtype == np.asarray(py_res).dtype == np.int32 + + compare_mlx_and_py([x0], [xs], [np.int32(0)], assert_fn=assert_int) + + +def test_scan_zero_steps(): + # Degenerate ``n_steps == 0``: matches the empty output of the reference. + x0 = pt.scalar("x0", dtype="float32") + xs = scan( + lambda xtm1: xtm1 + 1, + outputs_info=[x0], + n_steps=0, + return_updates=False, + ) + compare_mlx_and_py([x0], [xs], [np.float32(3.0)]) + + +def test_scan_while_not_implemented(): + x0 = pt.scalar("x0", dtype="float32") + xs = scan( + lambda xtm1: (xtm1 + 1, until(xtm1 > 5)), + outputs_info=[x0], + n_steps=100, + return_updates=False, + ) + with pytest.raises(NotImplementedError): + from pytensor import function + + function([x0], xs, mode=mlx_mode) + + +def test_scan_grad_non_sequence(): + # Gradient w.r.t. a non-sequence through a pure recurrence (no input + # sequences to reverse), which exercises the MIT-MOT backward Scan. + w = pt.scalar("w", dtype="float32") + xs = scan( + lambda x_tm1, w: x_tm1 * w, + outputs_info=[pt.ones((), dtype="float32")], + non_sequences=[w], + n_steps=4, + return_updates=False, + ) + g = pt.grad(xs[-1], w) + compare_mlx_and_py([w], [g], [np.float32(2.0)]) + + +def _rnn_grad_over_sequence(): + xs = pt.matrix("xs", dtype="float32") + h0 = pt.vector("h0", dtype="float32", shape=(3,)) + W = pt.matrix("W", dtype="float32", shape=(3, 3)) + hs = scan( + lambda x_t, h_tm1, W: pt.tanh(x_t + h_tm1 @ W), + sequences=[xs], + outputs_info=[h0], + non_sequences=[W], + return_updates=False, + ) + gW = pt.grad((hs**2).sum(), W) + rng = np.random.default_rng(0) + test_inputs = [ + rng.standard_normal((5, 3)).astype("float32"), + np.zeros(3, dtype="float32"), + (0.1 * rng.standard_normal((3, 3))).astype("float32"), + ] + return [xs, h0, W], [gW], test_inputs + + +def test_scan_grad_over_sequence(): + # The backward Scan (MIT-MOT + reversed forward trace) is correct under the + # base MLX optimizer query. + inputs, outputs, test_inputs = _rnn_grad_over_sequence() + compare_mlx_and_py(inputs, outputs, test_inputs) + + +@pytest.mark.xfail( + strict=True, + reason=( + "Under the full `mode='MLX'` (fast_run) the gradient over a sequence " + "reverses the trace, and MLX miscompiles an elementwise op fed by a " + "negative-stride array (`mx.compile(lambda x: 2.0 * x[::-1])` zeroes the " + "tail). Unrelated to the Scan dispatch; fixed by a follow-up that " + "materializes negative-stride Subtensor results in the MLX backend." + ), +) +def test_scan_grad_over_sequence_default_mode(): + inputs, outputs, test_inputs = _rnn_grad_over_sequence() + compare_mlx_and_py(inputs, outputs, test_inputs, mlx_mode="MLX") From e736a6a6323ba2ec5c93a432a0b623516bc86cda Mon Sep 17 00:00:00 2001 From: Carlos Trujillo Date: Fri, 26 Jun 2026 13:33:36 +0300 Subject: [PATCH 2/3] Note MLX Scan unroll is a workaround for missing loop primitive Clarify in the dispatch that the Python carry loop is unrolled into the graph at trace time (not by `mx.compile`) and that this is a workaround until MLX exposes a native scan/while primitive (ml-explore/mlx#1441): it needs a static step count and the graph grows as O(n_steps * inner_ops), where a real primitive would keep it O(inner_ops). Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/scan.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pytensor/link/mlx/dispatch/scan.py b/pytensor/link/mlx/dispatch/scan.py index 3ed159f8e6..a6ac459d1c 100644 --- a/pytensor/link/mlx/dispatch/scan.py +++ b/pytensor/link/mlx/dispatch/scan.py @@ -14,10 +14,16 @@ def mlx_funcify_Scan(op: Scan, node, **kwargs): # Mirrors the JAX dispatch (`link/jax/dispatch/scan.py`): the loop recreates # the concatenated trace from the per-step values and prepends the initial # state / truncates as needed, instead of writing into the PyTensor buffers. - # MLX has no native general-scan primitive, so JAX's `lax.scan` is replaced - # by a Python carry loop that `mx.compile` unrolls. This needs a statically - # known number of steps, which is read from the (full-sized) recurring - # buffers since `scan_reduce_trace_prealloc` is excluded for MLX. + # + # Workaround until MLX exposes a native loop primitive (ml-explore/mlx#1441, + # an `mx` equivalent of `jax.lax.scan`/`while_loop`): MLX has none today, so + # JAX's `lax.scan` is replaced by a plain Python carry loop that is unrolled + # into the graph at trace time, which `mx.compile` then compiles. The cost + # is structural -- it needs a statically known number of steps and the graph + # grows as O(n_steps * inner_ops) (a real primitive would keep it + # O(inner_ops)). When the primitive lands, lower to it instead of unrolling. + # `n_steps` is read from the (full-sized) recurring buffers since + # `scan_reduce_trace_prealloc` is excluded for MLX. info = op.info if info.as_while: From cebac9d57eec7fece90f1aaa2dbc56e716eb7696 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo Date: Thu, 2 Jul 2026 18:13:05 +0300 Subject: [PATCH 3/3] Reuse ScanCompatibilityTests in MLX backend tests check_aliased_inner_outputs passes as-is (static and dynamic shapes). check_higher_order_derivative is correct with use_compile=False but is a strict xfail under mx.compile: the second-order backward Scan reads the forward trace through negative-stride Subtensor views, hitting the same MLX miscompilation already xfailed for first-order sequence grads. The shared check gains an rtol knob since MLX casts float64 to float32. check_grad_until_and_truncate_sequence_taps needs a while Scan, which the MLX dispatch (like JAX) does not support. Co-Authored-By: Claude Fable 5 --- tests/link/mlx/test_scan.py | 23 +++++++++++++++++++++++ tests/scan/test_basic.py | 8 ++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/tests/link/mlx/test_scan.py b/tests/link/mlx/test_scan.py index 90293e1407..e47fbca6f1 100644 --- a/tests/link/mlx/test_scan.py +++ b/tests/link/mlx/test_scan.py @@ -5,6 +5,7 @@ from pytensor.scan import until from pytensor.scan.basic import scan from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode +from tests.scan.test_basic import ScanCompatibilityTests mx = pytest.importorskip("mlx.core") @@ -208,3 +209,25 @@ def test_scan_grad_over_sequence(): def test_scan_grad_over_sequence_default_mode(): inputs, outputs, test_inputs = _rnn_grad_over_sequence() compare_mlx_and_py(inputs, outputs, test_inputs, mlx_mode="MLX") + + +@pytest.mark.xfail( + strict=True, + reason=( + "The second-order backward Scan reads the forward trace through " + "negative-stride Subtensor views, hitting the same MLX miscompilation " + "as test_scan_grad_over_sequence_default_mode (here even under the base " + "optimizer query). Correct with `MLXLinker(use_compile=False)`; fixed by " + "the follow-up that materializes negative-stride Subtensor results." + ), +) +def test_higher_order_derivatives(): + # rtol loosened because MLX casts the check's float64 to float32 + ScanCompatibilityTests.check_higher_order_derivative(mode="MLX", rtol=1e-6) + + +@pytest.mark.parametrize("static_shape", [True, False]) +def test_aliased_inner_outputs(static_shape): + ScanCompatibilityTests.check_aliased_inner_outputs( + static_shape=static_shape, mode="MLX" + ) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 1bbd183953..c21d659581 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -4039,7 +4039,7 @@ class ScanCompatibilityTests: """Collection of test of subtle required behaviors of Scan, that can be reused by different backends.""" @staticmethod - def check_higher_order_derivative(mode): + def check_higher_order_derivative(mode, rtol=1e-7): """This tests different mit-mot taps signs""" x = pt.dscalar("x") @@ -4058,9 +4058,9 @@ def check_higher_order_derivative(mode): fn = function([x], [r, g, gg, ggg], mode=mode) x_test = np.array(0.95, dtype=x.type.dtype) r_res, g_res, gg_res, _ggg_res = fn(x_test) - np.testing.assert_allclose(r_res, x_test**16) - np.testing.assert_allclose(g_res, 16 * x_test**15) - np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14) + np.testing.assert_allclose(r_res, x_test**16, rtol=rtol) + np.testing.assert_allclose(g_res, 16 * x_test**15, rtol=rtol) + np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14, rtol=rtol) # FIXME: All implementations of Scan seem to get this one wrong! # np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)