# Imports introduced together with this section; they belong with the
# module's top-of-file import block.
from pytensor.graph.basic import Constant
from pytensor.graph.traversal import walk
from pytensor.tensor.shape import Shape, Shape_i


# Raised when at least one arange bound depends on runtime array values.
ARANGE_DATA_DEPENDENT_ERROR = (
    "MLX cannot build arange with a data-dependent length: the bounds depend on "
    "runtime array values, so the output shape is unknown at compile time. "
    "Constant and shape-derived bounds (e.g. pt.arange(x.shape[0])) are supported."
)


def _arange_bound_is_static(var):
    # Decide whether `var` is built exclusively from constants and input
    # shapes, which is what makes a bound concrete under mx.compile.
    # Shape/Shape_i nodes end the traversal: only the shape of whatever feeds
    # them matters, never its data (broader than the JAX dispatch, which only
    # recognizes a bare Shape_i).

    def inputs_above_shape_ops(variable):
        # Returning None tells `walk` not to descend any further.
        apply_node = variable.owner
        if apply_node is None:
            return None
        if isinstance(apply_node.op, (Shape, Shape_i)):
            return None
        return apply_node.inputs

    for visited in walk([var], inputs_above_shape_ops):
        # A root variable (no owner) that is not a Constant is a runtime
        # input, so the bound depends on data.
        if visited.owner is None and not isinstance(visited, Constant):
            return False
    return True


def _arange_static_bound(arg):
    # Python scalar for a bound that folds to a scalar constant, else None.
    try:
        constant = get_scalar_constant_value(arg)
    except NotScalarConstantError:
        return None
    return constant.item()


def _arange_runtime_bound(value):
    # Unwrap a runtime bound through `.item()` when the object offers it;
    # anything else is passed through untouched.
    if hasattr(value, "item"):
        return value.item()
    return value


@mlx_funcify.register(ARange)
def mlx_funcify_ARange(op, node, **kwargs):
    # mx.arange only accepts Python int/float. Constant bounds are baked in
    # once here; shape-derived bounds are resolved on every call (concrete
    # under mx.compile even when the static shape is unknown). Bounds that are
    # genuinely data-dependent are rejected before any function is built.
    bounds = node.inputs
    if any(not _arange_bound_is_static(bound) for bound in bounds):
        raise NotImplementedError(ARANGE_DATA_DEPENDENT_ERROR)

    mlx_dtype = convert_dtype_to_mlx(op.dtype)
    # One entry per bound: the baked Python scalar, or None when the value
    # has to be read from the runtime argument instead.
    baked_bounds = [_arange_static_bound(bound) for bound in bounds]

    def arange(*args):
        bounds_now = []
        for baked, runtime in zip(baked_bounds, args, strict=True):
            if baked is None:
                bounds_now.append(_arange_runtime_bound(runtime))
            else:
                bounds_now.append(baked)
        return mx.arange(*bounds_now, dtype=mlx_dtype)

    return arange
def test_arange_dynamic_shape():
    # Bounds derived from input shapes are concrete under mx.compile even when
    # the static shape is unknown, so a truly dynamic length has to work
    # (regression: an over-aggressive constant check used to raise
    # NotImplementedError here). Covers a shape-derived value in every
    # position (start/stop/step), an offset, and an empty result.
    x = pt.vector("x")
    y = pt.vector("y")
    n_x = x.shape[0]
    n_y = y.shape[0]
    outs = [
        arange(n_x),  # dynamic stop
        arange(n_x + 2),  # shape-derived expression
        arange(n_x, n_y),  # dynamic start and stop
        arange(0, n_y, n_x),  # dynamic step
        arange(n_y, n_x),  # start > stop -> empty
    ]
    x_val = np.zeros(3, dtype="float32")
    y_val = np.zeros(7, dtype="float32")
    compare_mlx_and_py([x, y], outs, [x_val, y_val])


def test_arange_dynamic_advanced_index():
    # The motivating case: a vectorized gather lowers to advanced indexing
    # that internally builds arange(idx.shape[0]) with a runtime-dynamic
    # length.
    logp = pt.matrix("logp")
    targets = pt.lvector("targets")
    rows = arange(targets.shape[0])
    out = logp[rows, targets]
    logp_val = np.arange(12, dtype="float32").reshape(3, 4)
    targets_val = np.array([0, 2, 3])
    compare_mlx_and_py([logp, targets], [out], [logp_val, targets_val])


def test_arange_data_dependent_raises():
    # A length that depends on the data has an output shape known only at
    # runtime, which MLX cannot compile. It must fail loudly (at compile time)
    # rather than silently.
    x = pt.vector("x")
    n_positive = pt.sum(x > 0).astype("int64")
    out = arange(n_positive)
    with pytest.raises(NotImplementedError, match="data-dependent length"):
        pytensor.function([x], out, mode=compile_mode)