From 9f4da8d35d9d7c28af0e530a6c4dc3107dd96e6e Mon Sep 17 00:00:00 2001 From: Carlos Trujillo Date: Fri, 26 Jun 2026 13:55:38 +0300 Subject: [PATCH 1/2] Support dynamic shape-derived bounds in MLX ARange The MLX `arange` dispatcher rejected any non-constant bound at funcify time, so `pt.arange(x.shape[0])` failed for tensors with a dynamic (None) static shape, even though shape-derived bounds are concrete under `mx.compile`. This broke advanced-indexing/gather patterns such as `logp[pt.arange(targets.shape[0]), targets]`. Bake constant-foldable bounds and resolve the rest at runtime by converting the MLX scalar to a Python scalar, mirroring the JAX dispatch. Genuinely data-dependent bounds raise a clear NotImplementedError. Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/tensor_basic.py | 49 ++++++++++++++-------- tests/link/mlx/test_tensor_basic.py | 45 ++++++++++++++++++++ 2 files changed, 77 insertions(+), 17 deletions(-) diff --git a/pytensor/link/mlx/dispatch/tensor_basic.py b/pytensor/link/mlx/dispatch/tensor_basic.py index 3cdc47323f..81c19359d0 100644 --- a/pytensor/link/mlx/dispatch/tensor_basic.py +++ b/pytensor/link/mlx/dispatch/tensor_basic.py @@ -178,32 +178,47 @@ def alloc(x, *shape): return alloc -ARANGE_CONCRETE_VALUE_ERROR = ( - "MLX's arange requires all arguments (start, stop, step) to be concrete " - "Python int/float values, not symbolic variables. Unlike NumPy and JAX, " - "MLX does not accept array inputs for arange at all." - "\n\nAn example of a valid graph:" - "\n>>> import pytensor.tensor as pt" - "\n>>> pt.arange(1, 10, 2)" +ARANGE_DATA_DEPENDENT_ERROR = ( + "MLX cannot build arange with a data-dependent length: the bounds depend on " + "runtime array values, so the output shape is unknown at compile time. " + "Constant and shape-derived bounds (e.g. pt.arange(x.shape[0])) are supported." ) @mlx_funcify.register(ARange) def mlx_funcify_ARange(op, node, **kwargs): - # MLX's arange only accepts Python int/float, not arrays, - # so all arguments must be known at graph-construction time. - try: - start, stop, step = [ - get_scalar_constant_value(arg).item() for arg in node.inputs + dtype = convert_dtype_to_mlx(op.dtype) + # mx.arange only accepts Python int/float. Bake constant bounds, and resolve + # the rest at runtime: shape-derived bounds are concrete under mx.compile even + # when the static shape is unknown (mirrors the JAX dispatch). + static_args = [_arange_static_bound(arg) for arg in node.inputs] + + def arange(*args): + resolved = [ + static if static is not None else _arange_runtime_bound(runtime) + for static, runtime in zip(static_args, args, strict=True) ] + return mx.arange(*resolved, dtype=dtype) + + return arange + + +def _arange_static_bound(arg): + try: + return get_scalar_constant_value(arg).item() except NotScalarConstantError: - raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR) - dtype = convert_dtype_to_mlx(op.dtype) + return None - def arange(*_args): - return mx.arange(start, stop, step, dtype=dtype) - return arange +def _arange_runtime_bound(value): + try: + return value.item() if hasattr(value, "item") else value + except (ValueError, TypeError) as exc: + if "[eval] Attempting to eval an array during function transformations" in str( + exc + ): + raise NotImplementedError(ARANGE_DATA_DEPENDENT_ERROR) from exc + raise def _extract_static_dims(shape_inputs): diff --git a/tests/link/mlx/test_tensor_basic.py b/tests/link/mlx/test_tensor_basic.py index 8ab3e07542..6eeb6c2f70 100644 --- a/tests/link/mlx/test_tensor_basic.py +++ b/tests/link/mlx/test_tensor_basic.py @@ -164,3 +164,48 @@ def test_arange(): out = arange(1, 10, 2) compare_mlx_and_py([], [out], []) + + +def test_arange_dynamic_shape(): + # Shape-derived bounds are concrete under mx.compile even when the static + # shape is unknown, so a genuinely dynamic length must work (regression: this + # used to raise NotImplementedError because of an over-aggressive constant + # check). Exercises every position (start/stop/step) being shape-derived, an + # offset, and an empty result. + x = pt.vector("x") + y = pt.vector("y") + outs = [ + arange(x.shape[0]), # dynamic stop + arange(x.shape[0] + 2), # shape-derived expression + arange(x.shape[0], y.shape[0]), # dynamic start and stop + arange(0, y.shape[0], x.shape[0]), # dynamic step + arange(y.shape[0], x.shape[0]), # start > stop -> empty + ] + compare_mlx_and_py( + [x, y], + outs, + [np.zeros(3, dtype="float32"), np.zeros(7, dtype="float32")], + ) + + +def test_arange_dynamic_advanced_index(): + # The motivating case: a vectorized gather lowers to advanced indexing that + # internally builds arange(idx.shape[0]) with a runtime-dynamic length. + logp = pt.matrix("logp") + targets = pt.lvector("targets") + out = logp[arange(targets.shape[0]), targets] + compare_mlx_and_py( + [logp, targets], + [out], + [np.arange(12, dtype="float32").reshape(3, 4), np.array([0, 2, 3])], + ) + + +def test_arange_data_dependent_raises(): + # A genuinely data-dependent length has a runtime-only output shape, which MLX + # cannot compile. This must fail loudly rather than silently misbehave. + x = pt.vector("x") + out = arange(pt.sum(x > 0).astype("int64")) + fn = pytensor.function([x], out, mode=compile_mode) + with pytest.raises(NotImplementedError, match="data-dependent length"): + fn(np.array([1.0, -1.0, 1.0], dtype="float32")) From cbbf903ff505cb41d38c6d7e10d7dadd6ed16d06 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo Date: Fri, 26 Jun 2026 17:02:15 +0300 Subject: [PATCH 2/2] Detect data-dependent ARange bounds statically on MLX Address review: replace the fragile runtime match on MLX's eval error string with a static, funcify-time check that walks the bound's graph treating Shape ops as barriers. A bound is resolvable iff it derives only from input shapes and constants; genuinely data-dependent bounds now raise NotImplementedError up front, consistently across MLX modes. This is more general than the JAX dispatch (which only recognizes a bare Shape_i). Helpers moved above the dispatcher that uses them. Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/tensor_basic.py | 57 ++++++++++++++-------- tests/link/mlx/test_tensor_basic.py | 5 +- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/pytensor/link/mlx/dispatch/tensor_basic.py b/pytensor/link/mlx/dispatch/tensor_basic.py index 81c19359d0..84e208b0f3 100644 --- a/pytensor/link/mlx/dispatch/tensor_basic.py +++ b/pytensor/link/mlx/dispatch/tensor_basic.py @@ -1,6 +1,8 @@ import mlx.core as mx import numpy as np +from pytensor.graph.basic import Constant +from pytensor.graph.traversal import walk from pytensor.link.mlx.dispatch.basic import convert_dtype_to_mlx, mlx_funcify from pytensor.tensor import get_vector_length from pytensor.tensor.basic import ( @@ -17,6 +19,7 @@ get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.shape import Shape, Shape_i MLX_DYNAMIC_SHAPE_ERROR = ( @@ -185,12 +188,42 @@ def alloc(x, *shape): ) +def _arange_bound_is_static(var): + # A bound is concrete under mx.compile when its value derives only from input + # shapes and constants. Shape ops are barriers: only the shape, not the + # underlying data, is needed (more general than the JAX dispatch, which only + # recognizes a bare Shape_i). + def expand(v): + owner = v.owner + if owner is None or isinstance(owner.op, Shape | Shape_i): + return None + return owner.inputs + + return all( + v.owner is not None or isinstance(v, Constant) for v in walk([var], expand) + ) + + +def _arange_static_bound(arg): + try: + return get_scalar_constant_value(arg).item() + except NotScalarConstantError: + return None + + +def _arange_runtime_bound(value): + return value.item() if hasattr(value, "item") else value + + @mlx_funcify.register(ARange) def mlx_funcify_ARange(op, node, **kwargs): + # mx.arange only accepts Python int/float. Bake constant bounds and resolve + # shape-derived ones at runtime (concrete under mx.compile even when the + # static shape is unknown); reject genuinely data-dependent bounds up front. + if not all(_arange_bound_is_static(arg) for arg in node.inputs): + raise NotImplementedError(ARANGE_DATA_DEPENDENT_ERROR) + dtype = convert_dtype_to_mlx(op.dtype) - # mx.arange only accepts Python int/float. Bake constant bounds, and resolve - # the rest at runtime: shape-derived bounds are concrete under mx.compile even - # when the static shape is unknown (mirrors the JAX dispatch). static_args = [_arange_static_bound(arg) for arg in node.inputs] def arange(*args): @@ -203,24 +236,6 @@ def arange(*args): return arange -def _arange_static_bound(arg): - try: - return get_scalar_constant_value(arg).item() - except NotScalarConstantError: - return None - - -def _arange_runtime_bound(value): - try: - return value.item() if hasattr(value, "item") else value - except (ValueError, TypeError) as exc: - if "[eval] Attempting to eval an array during function transformations" in str( - exc - ): - raise NotImplementedError(ARANGE_DATA_DEPENDENT_ERROR) from exc - raise - - def _extract_static_dims(shape_inputs): static_dims = [] for dim in shape_inputs: diff --git a/tests/link/mlx/test_tensor_basic.py b/tests/link/mlx/test_tensor_basic.py index 6eeb6c2f70..3f6925b8ad 100644 --- a/tests/link/mlx/test_tensor_basic.py +++ b/tests/link/mlx/test_tensor_basic.py @@ -203,9 +203,8 @@ def test_arange_dynamic_advanced_index(): def test_arange_data_dependent_raises(): # A genuinely data-dependent length has a runtime-only output shape, which MLX - # cannot compile. This must fail loudly rather than silently misbehave. + # cannot compile. This must fail loudly (at compile time) rather than silently. x = pt.vector("x") out = arange(pt.sum(x > 0).astype("int64")) - fn = pytensor.function([x], out, mode=compile_mode) with pytest.raises(NotImplementedError, match="data-dependent length"): - fn(np.array([1.0, -1.0, 1.0], dtype="float32")) + pytensor.function([x], out, mode=compile_mode)