Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions pytensor/link/mlx/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import mlx.core as mx
import numpy as np

from pytensor.graph.basic import Constant
from pytensor.graph.traversal import walk
from pytensor.link.mlx.dispatch.basic import convert_dtype_to_mlx, mlx_funcify
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import (
Expand All @@ -17,6 +19,7 @@
get_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape, Shape_i


MLX_DYNAMIC_SHAPE_ERROR = (
Expand Down Expand Up @@ -178,30 +181,57 @@ def alloc(x, *shape):
return alloc


ARANGE_CONCRETE_VALUE_ERROR = (
"MLX's arange requires all arguments (start, stop, step) to be concrete "
"Python int/float values, not symbolic variables. Unlike NumPy and JAX, "
"MLX does not accept array inputs for arange at all."
"\n\nAn example of a valid graph:"
"\n>>> import pytensor.tensor as pt"
"\n>>> pt.arange(1, 10, 2)"
# Message for the NotImplementedError raised by the ARange dispatch when a bound
# depends on runtime array values rather than on constants or input shapes.
ARANGE_DATA_DEPENDENT_ERROR = (
"MLX cannot build arange with a data-dependent length: the bounds depend on "
"runtime array values, so the output shape is unknown at compile time. "
"Constant and shape-derived bounds (e.g. pt.arange(x.shape[0])) are supported."
)


@mlx_funcify.register(ARange)
def mlx_funcify_ARange(op, node, **kwargs):
# MLX's arange only accepts Python int/float, not arrays,
# so all arguments must be known at graph-construction time.
def _arange_bound_is_static(var):
    """Return True when ``var`` depends only on constants and input shapes.

    The graph upstream of ``var`` is traversed with ``walk``. Traversal does not
    continue past the output of a ``Shape`` / ``Shape_i`` op: only the shape of
    that op's input is needed, not its data (more general than the JAX dispatch,
    which only recognizes a bare ``Shape_i``).

    Any visited variable that has no owner and is not a ``Constant`` makes the
    bound non-static, so a bare root input such as ``pt.lscalar("end")`` is
    rejected.
    """

    def _upstream(variable):
        # Returning None tells the traversal there is nothing further to expand.
        apply_node = variable.owner
        if apply_node is not None and not isinstance(apply_node.op, Shape | Shape_i):
            return apply_node.inputs
        return None

    for visited in walk([var], _upstream):
        # An ownerless variable is acceptable only when it is a Constant.
        if visited.owner is None and not isinstance(visited, Constant):
            return False
    return True
Comment on lines +202 to +204

@ricardoV94 ricardoV94 Jul 2, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this treat a root input as static? For example: pt.arange(pt.lscalar("end"))



# Return the bound as a Python scalar when get_scalar_constant_value can resolve
# it, otherwise None (the NotScalarConstantError branch).
# NOTE(review): this span looks like a diff view with removed and added lines
# interleaved - the `start, stop, step = [...]` unpack and the
# `raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR)` appear to be the
# removed pre-change lines. Confirm against the real file.
def _arange_static_bound(arg):
try:
start, stop, step = [
get_scalar_constant_value(arg).item() for arg in node.inputs
]
return get_scalar_constant_value(arg).item()
except NotScalarConstantError:
raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR)
return None


def _arange_runtime_bound(value):
return value.item() if hasattr(value, "item") else value


@mlx_funcify.register(ARange)
def mlx_funcify_ARange(op, node, **kwargs):
# Build the callable that implements ARange for `node`. Raises
# NotImplementedError when any input fails _arange_bound_is_static.
# mx.arange only accepts Python int/float. Bake constant bounds and resolve
# shape-derived ones at runtime (concrete under mx.compile even when the
# static shape is unknown); reject genuinely data-dependent bounds up front.
if not all(_arange_bound_is_static(arg) for arg in node.inputs):
raise NotImplementedError(ARANGE_DATA_DEPENDENT_ERROR)

# Converted once here, outside the returned closure.
dtype = convert_dtype_to_mlx(op.dtype)
# One entry per input, as returned by _arange_static_bound (None when the
# bound could not be resolved to a constant).
static_args = [_arange_static_bound(arg) for arg in node.inputs]

# NOTE(review): the next two lines look like removed lines left over from the
# diff view - `start`, `stop` and `step` are not defined anywhere in this
# function, and `arange` is redefined immediately below. Confirm against the
# real file.
def arange(*_args):
return mx.arange(start, stop, step, dtype=dtype)
def arange(*args):
# Prefer the baked constant; otherwise convert the runtime argument to a
# Python scalar. strict=True makes a length mismatch between the two raise.
resolved = [
static if static is not None else _arange_runtime_bound(runtime)
for static, runtime in zip(static_args, args, strict=True)
]
return mx.arange(*resolved, dtype=dtype)

return arange

Expand Down
44 changes: 44 additions & 0 deletions tests/link/mlx/test_tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,47 @@ def test_arange():
out = arange(1, 10, 2)

compare_mlx_and_py([], [out], [])


def test_arange_dynamic_shape():
    """Shape-derived arange bounds must compile and match the Python backend.

    Shape-derived bounds are concrete under mx.compile even when the static
    shape is unknown, so a genuinely dynamic length must work (regression: this
    used to raise NotImplementedError because of an over-aggressive constant
    check). Covers a shape-derived start, stop and step, an offset expression,
    and an empty result.
    """
    x = pt.vector("x")
    y = pt.vector("y")

    dynamic_stop = arange(x.shape[0])
    shifted_stop = arange(x.shape[0] + 2)  # shape-derived expression
    dynamic_start_stop = arange(x.shape[0], y.shape[0])
    dynamic_step = arange(0, y.shape[0], x.shape[0])
    empty_range = arange(y.shape[0], x.shape[0])  # start > stop -> empty

    test_values = [np.zeros(3, dtype="float32"), np.zeros(7, dtype="float32")]
    compare_mlx_and_py(
        [x, y],
        [dynamic_stop, shifted_stop, dynamic_start_stop, dynamic_step, empty_range],
        test_values,
    )


def test_arange_dynamic_advanced_index():
    """Advanced indexing with an arange whose length comes from a runtime shape.

    The motivating case: a vectorized gather lowers to advanced indexing that
    internally builds arange(idx.shape[0]) with a runtime-dynamic length.
    """
    logp = pt.matrix("logp")
    targets = pt.lvector("targets")

    row_index = arange(targets.shape[0])
    gathered = logp[row_index, targets]

    logp_value = np.arange(12, dtype="float32").reshape(3, 4)
    targets_value = np.array([0, 2, 3])
    compare_mlx_and_py([logp, targets], [gathered], [logp_value, targets_value])


def test_arange_data_dependent_raises():
    """A data-dependent arange length must be rejected when compiling.

    A genuinely data-dependent length has a runtime-only output shape, which MLX
    cannot compile. This must fail loudly (at compile time) rather than silently.
    """
    x = pt.vector("x")
    positive_count = pt.sum(x > 0).astype("int64")
    out = arange(positive_count)

    with pytest.raises(NotImplementedError, match="data-dependent length"):
        pytensor.function([x], out, mode=compile_mode)
Loading