Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pytensor/tensor/rewriting/indexed_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,12 @@ def _extract_idx_axis_pairs(node, *, write=False):
pairs = []
for axis, entry in enumerate(op.idx_list):
if isinstance(entry, slice):
# The fused loop substitutes the full source array and iterates
# non-indexed axes wholesale, so it can only carry a full slice.
# A bounded/stepped basic slice would change the axis extent or
# offset, which it can't represent -- don't fuse.
if entry != slice(None):
return None
continue
idx = idx_vars[entry]
if not isinstance(idx, TensorVariable) or idx.type.dtype == "bool":
Expand Down
230 changes: 202 additions & 28 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytensor.assumptions.core import UNIQUE_INDICES, check_assumption
from pytensor.compile import optdb
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import (
WalkingGraphRewriter,
copy_stack_trace,
Expand All @@ -23,6 +24,7 @@
Alloc,
ARange,
Join,
Nonzero,
ScalarFromTensor,
TensorFromScalar,
alloc,
Expand Down Expand Up @@ -61,6 +63,7 @@
)
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.shape import (
Reshape,
Shape,
Shape_i,
shape_padleft,
Expand All @@ -76,6 +79,7 @@
IncSubtensor,
Subtensor,
_is_provably_non_negative,
_is_provably_positive,
_non_consecutive_adv_indexing,
advanced_inc_subtensor1,
advanced_subtensor1,
Expand Down Expand Up @@ -232,12 +236,169 @@ def _constant_has_unique_indices(idx) -> bool:
return result


def _has_unique_indices(fgraph, idx) -> bool:
    """Report whether the entries of ``idx`` are provably duplicate-free.

    Two proofs are accepted: ``idx`` is a constant whose entries are unique, or
    ``idx`` carries a user-declared ``unique_indices`` assumption that is looked
    up through ``fgraph``.
    """
    constant_is_unique = _constant_has_unique_indices(idx)
    # The assumption lookup is only consulted when the constant check fails.
    return constant_is_unique or check_assumption(fgraph, idx, UNIQUE_INDICES)
def _constant_int_or_none(var) -> int | None:
    """Return ``var`` as a Python ``int`` when it is a scalar constant.

    Returns ``None`` when ``get_scalar_constant_value`` cannot resolve ``var``
    to a scalar constant (it raises ``NotScalarConstantError``).
    """
    try:
        value = get_scalar_constant_value(var)
    except NotScalarConstantError:
        return None
    return int(value)


def _arange_provably_unique(start, stop, step) -> bool:
    """Decide whether ``arange(start, stop, step)`` hits each position at most once.

    The entries of an ``arange`` are always distinct *values*. They are also
    distinct *positions* provided they never wrap around zero, i.e. all entries
    share one sign: ``arange(-2, 2)`` aliases on a size-2 axis, whereas
    ``arange(2, 6)`` and ``arange(-6, -2)`` do not. The proof uses whichever of
    the three bounds are statically known.

    Parameters
    ----------
    start, stop, step
        The three inputs of the ``ARange`` node.

    Returns
    -------
    bool
        ``True`` only when single-signedness can be proved; ``False`` means
        "unknown", not "known to alias".
    """
    const_start = _constant_int_or_none(start)
    const_stop = _constant_int_or_none(stop)
    const_step = _constant_int_or_none(step)

    # Both endpoints >= 0: every entry lies between them, so all entries are
    # >= 0 whatever the step direction (covers symbolic ``arange(x.shape[0])``).
    start_non_negative = _is_provably_non_negative(start)
    if start_non_negative and _is_provably_non_negative(stop):
        return True

    # Once the direction is known, a single endpoint pins the sign.
    if _is_provably_positive(step):
        # Ascending: entries span ``[start, stop)``.
        if start_non_negative:
            return True  # min is ``start`` -> all >= 0
        if const_stop is not None and const_stop <= 0:
            return True  # max is ``< stop`` -> all <= -1
    elif const_step is not None and const_step < 0:
        # Descending: entries span ``(stop, start]``.
        if _is_provably_non_negative(stop):
            return True  # min is ``> stop`` -> all >= 0, e.g. arange(k, 5, -1)
        if const_start is not None and const_start < 0:
            return True  # max is ``start`` -> all <= -1, e.g. arange(-1, k, -1)

    # Fully constant case: derive the exact first/last entries without
    # materializing the range. The endpoint checks above treat the last entry
    # as merely "before ``stop``", so they miss ranges whose last entry stops
    # short of zero although ``stop`` lies beyond it, e.g. ``arange(6, -2, -2)``
    # is ``[6, 4, 2, 0]`` and ``arange(-5, 1, 3)`` is ``[-5, -2]``.
    any_unknown = any(c is None for c in (const_start, const_stop, const_step))
    if any_unknown or const_step == 0:
        return False

    span = const_stop - const_start
    length = max(0, -(-span // const_step))  # ceil(span / step), clipped at 0
    if length <= 1:
        # Empty or single-entry ranges cannot repeat a position.
        return True
    last_entry = const_start + (length - 1) * const_step
    lowest, highest = sorted((const_start, last_entry))
    return lowest >= 0 or highest < 0


def _index_provably_unique(idx, fgraph: FunctionGraph | None) -> bool:
    """Decide whether one index selects each position on its axis at most once.

    This duplicate-free reasoning is shared by accumulation gating (repeated
    positions would scatter-add) and by ``_index_provably_not_larger`` (a
    duplicate-free index cannot enlarge its axis). It deliberately leaves out
    the statically-smaller fallback, which bounds the size but does not rule
    out repeated positions.

    Parameters
    ----------
    idx
        A ``slice`` or an index variable.
    fgraph
        Handle to the ``AssumptionFeature``, so that a user-declared
        ``unique_indices`` assumption can prove uniqueness when the static
        value, ``arange`` shape, or view structure cannot. Pass ``None`` to
        skip that leg (no assumptions).

    Returns
    -------
    bool
        ``True`` when uniqueness is proved, ``False`` when it is unknown.
    """
    # Basic indexing (slices and scalars) never repeats a position.
    if isinstance(idx, slice) or idx.ndim == 0:
        return True

    proved = (
        # Statically broadcastable in every dim.
        all(idx.type.broadcastable)
        # Boolean index.
        or idx.type.dtype == "bool"
        # Constant whose entries are distinct.
        or _constant_has_unique_indices(idx)
        # Declared unique by the user.
        or check_assumption(fgraph, idx, UNIQUE_INDICES)
    )
    if proved:
        return True

    producer_op = idx.owner_op
    if isinstance(producer_op, ARange):
        return _arange_provably_unique(*idx.owner.inputs)
    if isinstance(producer_op, (Reshape, DimShuffle)):
        # Views that only reorder or insert size-1 dims keep the value
        # multiset, so the verdict is inherited from the viewed input.
        return _index_provably_unique(idx.owner.inputs[0], fgraph)
    return False


def _constants_jointly_unique(consts) -> bool:
"""Whether stacked constant indices have no duplicate coordinate tuples.

The stacked ``np.unique`` can be expensive on large indices, so the result
is cached on the first constant's tag. Uniqueness is a property of the whole
group, and a constant may belong to several groups (constants are shared
across the graph), so the cache is keyed by the group's identities rather
than a single flag.
"""
key = tuple(id(c) for c in consts)
cache = getattr(consts[0].tag, "jointly_unique_indices", None)
if cache is None:
cache = consts[0].tag.jointly_unique_indices = {}
if key not in cache:
datas = [np.asarray(c.data) for c in consts]
# A coordinate axis that mixes positive and negative values may alias
# (``0`` and ``-dim`` are the same position), so distinctness of the raw
# values no longer proves distinctness of the coordinates.
if any((data >= 0).any() and (data < 0).any() for data in datas):
cache[key] = False
else:
coords = np.broadcast_arrays(*datas)
stacked = np.stack([coord.ravel() for coord in coords])
cache[key] = bool(np.unique(stacked, axis=1).shape[1] == stacked.shape[1])
return bool(cache[key])


def _indices_jointly_unique(idxs, fgraph: FunctionGraph | None) -> bool:
    """Decide whether advanced indices yield no duplicate joint coordinate tuples.

    For accumulation (``inc``), and for bounding a gather by the size of the
    indexed axes, the requirement is that the broadcast coordinate tuples
    ``(idx0[k], idx1[k], ...)`` are all distinct. Three sufficient conditions
    are tried, from cheapest to most general:

    - every index is duplicate-free on its own axis, which makes the tuples
      distinct regardless of the other indices (sound under broadcasting, and
      the path that basic slice/scalar indexing trivially takes);
    - the indices are exactly the coordinates of a single ``Nonzero``, which are
      distinct by construction (e.g. symbolic ``tril_indices``);
    - the indices are all constants whose stacked coordinate tuples contain no
      duplicates (covers cases where no single axis is unique on its own).

    Parameters
    ----------
    idxs
        The advanced index variables to weigh together.
    fgraph
        Forwarded to ``_index_provably_unique`` so a user-declared
        ``unique_indices`` assumption can satisfy the per-axis leg; pass
        ``None`` to skip assumptions.
    """
    # Leg 1: per-axis uniqueness.
    if all(_index_provably_unique(idx, fgraph) for idx in idxs):
        return True

    # The joint legs need at least two indices to combine.
    if len(idxs) <= 1:
        return False

    # Leg 2: all indices are the outputs of one and the same ``Nonzero`` node.
    producers = {idx.owner for idx in idxs}
    if len(producers) == 1:
        (producer,) = producers
        if (
            producer is not None
            and isinstance(producer.op, Nonzero)
            and set(idxs) == set(producer.outputs)
        ):
            return True

    # Leg 3: all-constant indices, checked on their stacked values.
    if all(isinstance(idx, Constant) for idx in idxs):
        return _constants_jointly_unique(idxs)
    return False


def _advanced_indices_jointly_unique(indices, fgraph: FunctionGraph | None) -> bool:
    """Decide whether the advanced indices of an index tuple are jointly unique.

    Only advanced entries (``TensorVariable`` with ``ndim > 0``) of the
    reconstructed index tuple are weighed. Slices and scalar (basic) indices
    are dropped first, since basic indexing is trivially duplicate-free.

    Parameters
    ----------
    indices
        The reconstructed index tuple, mixing basic and advanced entries.
    fgraph
        Forwarded to ``_indices_jointly_unique`` for the ``unique_indices``
        assumption lookup; pass ``None`` to skip assumptions.
    """
    advanced = []
    for entry in indices:
        if not isinstance(entry, TensorVariable):
            continue  # slices and non-tensor entries are basic indexing
        if entry.type.ndim > 0:
            advanced.append(entry)
    return _indices_jointly_unique(advanced, fgraph)


def _constant_is_arange(idx) -> tuple[int, int, int] | None:
Expand Down Expand Up @@ -1105,6 +1266,9 @@ def local_set_to_inc_subtensor(fgraph, node):
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)

Only valid when ``ilist`` is duplicate-free: a dense set is last-wins while an
inc accumulates, so duplicate indices would over-count.

TODO FIXME: Why doesn't this apply to all `*IncSubtensor*` `Op`\s? If it
did this wouldn't need to also be included in the "specialize" pass.

Expand Down Expand Up @@ -1132,6 +1296,11 @@ def local_set_to_inc_subtensor(fgraph, node):
return
if subn.inputs[1] != node.inputs[2] or subn.inputs[0] != node.inputs[0]:
return
# set->inc is only valid when ilist is duplicate-free: ``set(x[ilist] + other)``
# is last-wins at repeated positions, while ``inc(other)`` would accumulate the
# contributions of every occurrence and over-count them.
if not _index_provably_unique(node.inputs[2], fgraph):
return
ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])

copy_stack_trace(node.outputs, ret)
Expand All @@ -1150,8 +1319,11 @@ def local_add_of_sparse_write(fgraph, node):
Adding it to another tensor is equivalent to incrementing in place, which
avoids materialising the dense sparse representation.

Also handles ``zeros[idx].inc(v)`` when ``idx`` is duplicate-free, since
with unique indices inc is semantically equivalent to set.
The ``zeros[idx].inc(v)`` form is rewritten unconditionally: inc applies the
same per-position delta whether the base is zeros (then added to ``x``) or
``x`` itself, so duplicate indices accumulate identically on both sides. Only
the ``zeros[idx].set(v)`` form needs duplicate-free indices, since a dense set
is last-wins and collapsing it to an inc would over-count repeats.
"""
for i, sparse_candidate in enumerate(node.inputs):
if not (
Expand All @@ -1174,11 +1346,17 @@ def local_add_of_sparse_write(fgraph, node):
):
continue

# An inc into zeros is only equivalent to a set when indices are
# duplicate-free. Basic (slice/scalar) indexing is always unique;
# advanced integer-array indices must be checked.
if not inner_op.set_instead_of_inc and not isinstance(inner_op, IncSubtensor):
if not all(_has_unique_indices(fgraph, idx) for idx in idx_vars):
# Only the set->inc conversion needs duplicate-free indices. An inc into
# zeros and the resulting inc into ``other`` apply the same per-position
# delta (accumulating any duplicates identically), so ``x + zeros[idx].inc(v)
# -> x[idx].inc(v)`` holds for any indices. A dense set, by contrast, is
# last-wins, so collapsing it to an inc would over-count repeated
# positions. Basic (slice/scalar) IncSubtensor is always unique; advanced
# integer-array set indices must be jointly duplicate-free, weighing only
# the advanced indices and not the flattened slice bounds.
if inner_op.set_instead_of_inc and not isinstance(inner_op, IncSubtensor):
indices = indices_from_subtensor(idx_vars, inner_op.idx_list)
if not _advanced_indices_jointly_unique(indices, fgraph):
continue

others = [node.inputs[j] for j in range(len(node.inputs)) if j != i]
Expand All @@ -1191,7 +1369,7 @@ def local_add_of_sparse_write(fgraph, node):
else:
new_op = inner_op
r = new_op(other, v, *idx_vars)
copy_stack_trace(node.outputs[0], r)
copy_stack_trace([node.outputs[0], sparse_candidate], r)
return [r]

return None
Expand Down Expand Up @@ -1966,9 +2144,8 @@ def local_read_of_write_same_indices(fgraph, node):
Applies when the outer read and inner write share identical index
variables (``is`` check) and the same ``idx_list``. The inc case
additionally requires duplicate-free indices: slices and scalar indices
are trivially unique, while integer-array indices must be constant with
no repeated entries (mixing positive and negative values counts as
potentially duplicated since they may alias).
are trivially unique, while advanced integer-array indices must be jointly
duplicate-free (see ``_indices_jointly_unique``).

Companion rewrites:

Expand Down Expand Up @@ -2005,13 +2182,11 @@ def local_read_of_write_same_indices(fgraph, node):
copy_stack_trace(out, r)
return [r]
else:
# Inc case: advanced integer-array indices must be duplicate-free;
# Inc case: advanced integer-array indices must be jointly duplicate-free;
# slices and scalar indices are trivially unique.
indices = indices_from_subtensor(outer_idx_vars, node.op.idx_list)
for idx in indices:
if isinstance(idx, TensorVariable) and idx.type.ndim > 0:
if not _has_unique_indices(fgraph, idx):
return None
if not _advanced_indices_jointly_unique(indices, fgraph):
return None

x_at_idx = x[tuple(indices)]
copy_stack_trace(out, x_at_idx)
Expand Down Expand Up @@ -2366,13 +2541,12 @@ def local_write_of_write_same_indices(fgraph, node):
new_val = b
use_set = True
elif inner_is_set:
# x[idx].set(a)[idx].inc(b) — needs unique indices.
# Basic indexing (slices/scalars) is always duplicate-free.
# For advanced indexing, per-axis uniqueness is conservative but
# sufficient: it guarantees no duplicates in the joint cross-product
# after broadcasting.
# x[idx].set(a)[idx].inc(b) — needs unique indices. Basic indexing
# (slices/scalars) is always duplicate-free; advanced indices must have
# duplicate-free joint coordinate tuples.
if not isinstance(node.op, IncSubtensor):
if not all(_has_unique_indices(fgraph, v) for v in outer_idx_vars):
indices = indices_from_subtensor(outer_idx_vars, node.op.idx_list)
if not _advanced_indices_jointly_unique(indices, fgraph):
return
new_val = a + b
if (
Expand Down
Loading
Loading