From 75ab3ed4e344f667ce20a9bce36beb4202443132 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 27 Jun 2026 14:33:31 +0200 Subject: [PATCH 1/3] Don't fuse indexed reads/writes through a non-empty basic slice --- pytensor/tensor/rewriting/indexed_elemwise.py | 6 +++ tests/link/numba/test_indexed_elemwise.py | 39 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index b3ff3828b3..0f8ac6d7d1 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -368,6 +368,12 @@ def _extract_idx_axis_pairs(node, *, write=False): pairs = [] for axis, entry in enumerate(op.idx_list): if isinstance(entry, slice): + # The fused loop substitutes the full source array and iterates + # non-indexed axes wholesale, so it can only carry a full slice. + # A bounded/stepped basic slice would change the axis extent or + # offset, which it can't represent -- don't fuse. + if entry != slice(None): + return None continue idx = idx_vars[entry] if not isinstance(idx, TensorVariable) or idx.type.dtype == "bool": diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index 22ffbe5927..386640c636 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -603,6 +603,45 @@ def test_runtime_broadcast_on_index_dim(self): with pytest.raises(Exception): fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(5)) + def test_no_fusion_with_bounded_basic_slice_read(self): + """Regression: a bounded basic slice on a non-indexed axis can't be + carried by the fused loop, which substitutes the full source array and + iterates non-indexed axes wholesale. Fusing ``x[1:4, idx]`` dropped the + ``1:4`` offset and iterated x's full axis 0 against y's, raising at + runtime. The slice must block fusion and results stay correct.""" + rng = np.random.default_rng(2202) + x = pt.matrix("x", shape=(6, 6)) + y = pt.matrix("y", shape=(3, 3)) + idx = pt.constant(np.array([0, 2, 4])) + + out = x[1:4, idx] + y + fn = function([x, y], out, mode=NUMBA_MODE, trust_input=True) + assert not any( + isinstance(n.op, IndexedElemwise) for n in fn.maker.fgraph.toposort() + ) + + ref = function([x, y], out, mode=NUMBA_NO_FUSION, trust_input=True) + xv, yv = rng.normal(size=(6, 6)), rng.normal(size=(3, 3)) + np.testing.assert_allclose(fn(xv, yv), ref(xv, yv), rtol=1e-10) + + def test_no_fusion_with_bounded_basic_slice_write(self): + """As above for an indexed write: a bounded basic slice on the write + target's non-indexed axis must block fusion.""" + rng = np.random.default_rng(2203) + t = pt.matrix("t", shape=(6, 6)) + y = pt.matrix("y", shape=(3, 3)) + idx = pt.constant(np.array([0, 2, 4])) + + out = t[1:4, idx].set(pt.exp(y)) + fn = function([t, y], out, mode=NUMBA_MODE, trust_input=True) + assert not any( + isinstance(n.op, IndexedElemwise) for n in fn.maker.fgraph.toposort() + ) + + ref = function([t, y], out, mode=NUMBA_NO_FUSION, trust_input=True) + tv, yv = rng.normal(size=(6, 6)), rng.normal(size=(3, 3)) + np.testing.assert_allclose(fn(tv.copy(), yv), ref(tv.copy(), yv), rtol=1e-10) + def test_loop_shape_regression(self): """ Regression test for https://github.com/pymc-devs/pytensor/issues/2201 From bd868daacfd667e2ad8ff9d84dff73bdd16f13f3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 27 Jun 2026 14:22:43 +0200 Subject: [PATCH 2/3] Fix unsound x[idx].set -> x[idx].inc rewrites with non-uinque indices --- pytensor/tensor/rewriting/subtensor.py | 45 ++++++-- tests/tensor/rewriting/test_subtensor.py | 129 ++++++++++++++++++++--- 2 files changed, 148 insertions(+), 26 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index b00bc342f3..c9f682f110 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -1105,6 +1105,9 @@ def local_set_to_inc_subtensor(fgraph, node): AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) -> AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False) + Only valid when ``ilist`` is duplicate-free: a dense set is last-wins while an + inc accumulates, so duplicate indices would over-count. + TODO FIXME: Why doesn't this apply to all `*IncSubtensor*` `Op`\s? If it did this wouldn't need to also be included in the "specialize" pass. @@ -1132,6 +1135,11 @@ def local_set_to_inc_subtensor(fgraph, node): return if subn.inputs[1] != node.inputs[2] or subn.inputs[0] != node.inputs[0]: return + # set->inc is only valid when ilist is duplicate-free: ``set(x[ilist] + other)`` + # is last-wins at repeated positions, while ``inc(other)`` would accumulate the + # contributions of every occurrence and over-count them. + if not _has_unique_indices(fgraph, node.inputs[2]): + return ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2]) copy_stack_trace(node.outputs, ret) @@ -1150,8 +1158,11 @@ def local_add_of_sparse_write(fgraph, node): Adding it to another tensor is equivalent to incrementing in place, which avoids materialising the dense sparse representation. - Also handles ``zeros[idx].inc(v)`` when ``idx`` is duplicate-free, since - with unique indices inc is semantically equivalent to set. + The ``zeros[idx].inc(v)`` form is rewritten unconditionally: inc applies the + same per-position delta whether the base is zeros (then added to ``x``) or + ``x`` itself, so duplicate indices accumulate identically on both sides. Only + the ``zeros[idx].set(v)`` form needs duplicate-free indices, since a dense set + is last-wins and collapsing it to an inc would over-count repeats. """ for i, sparse_candidate in enumerate(node.inputs): if not ( @@ -1174,11 +1185,21 @@ def local_add_of_sparse_write(fgraph, node): ): continue - # An inc into zeros is only equivalent to a set when indices are - # duplicate-free. Basic (slice/scalar) indexing is always unique; - # advanced integer-array indices must be checked. - if not inner_op.set_instead_of_inc and not isinstance(inner_op, IncSubtensor): - if not all(_has_unique_indices(fgraph, idx) for idx in idx_vars): + # Only the set->inc conversion needs duplicate-free indices. An inc into + # zeros and the resulting inc into ``other`` apply the same per-position + # delta (accumulating any duplicates identically), so ``x + zeros[idx].inc(v) + # -> x[idx].inc(v)`` holds for any indices. A dense set, by contrast, is + # last-wins, so collapsing it to an inc would over-count repeated + # positions. Basic (slice/scalar) IncSubtensor is always unique; advanced + # integer-array set indices must be checked, weighing only the advanced + # indices and not the flattened slice bounds. + if inner_op.set_instead_of_inc and not isinstance(inner_op, IncSubtensor): + adv_idxs = [ + idx + for idx in indices_from_subtensor(idx_vars, inner_op.idx_list) + if isinstance(idx, TensorVariable) and idx.type.ndim > 0 + ] + if not all(_has_unique_indices(fgraph, idx) for idx in adv_idxs): continue others = [node.inputs[j] for j in range(len(node.inputs)) if j != i] @@ -2370,9 +2391,15 @@ def local_write_of_write_same_indices(fgraph, node): # Basic indexing (slices/scalars) is always duplicate-free. # For advanced indexing, per-axis uniqueness is conservative but # sufficient: it guarantees no duplicates in the joint cross-product - # after broadcasting. + # after broadcasting. Weigh only the advanced indices, not the flattened + # slice bounds. if not isinstance(node.op, IncSubtensor): - if not all(_has_unique_indices(fgraph, v) for v in outer_idx_vars): + adv_idxs = [ + idx + for idx in indices_from_subtensor(outer_idx_vars, node.op.idx_list) + if isinstance(idx, TensorVariable) and idx.type.ndim > 0 + ] + if not all(_has_unique_indices(fgraph, idx) for idx in adv_idxs): return new_val = a + b if ( diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 1c1cbb7bb0..89f7582d5b 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -218,8 +218,10 @@ def test_local_add_of_sparse_write(): """``x + set(zeros, v, idx) -> inc(x, v, idx)``: avoid materialising the dense sparse representation when adding a sparse set into a base. - Also covers ``x + inc(zeros, v, idx)`` when ``idx`` is duplicate-free, - since then inc-into-zeros is equivalent to set-into-zeros. + The set form needs duplicate-free ``idx`` (a dense set is last-wins). The + inc form ``x + inc(zeros, v, idx)`` is rewritten unconditionally: inc applies + the same per-position delta on both sides, so duplicates accumulate + identically. """ sparse_rewriter = in2out(local_add_of_sparse_write, name="add_of_sparse_write") @@ -227,18 +229,30 @@ def test_local_add_of_sparse_write(): v = vector("v") idx = ivector("idx") - # set-into-zeros is always rewritten. - out = x + pt.zeros(x.shape)[idx].set(v) - expected = x[idx].inc(v) - rewritten = rewrite_graph(out) - utt.assert_equal_computations([rewritten], [expected], strict_dtype=False) + # set-into-zeros with a provably unique index is rewritten: a dense set into + # zeros equals a sparse inc only when each position is written exactly once. + cst = np.array([1, 3]) + out = x + pt.zeros(x.shape)[cst].set(v) + rewritten = rewrite_graph(out, include=[], custom_rewrite=sparse_rewriter) + utt.assert_equal_computations([rewritten], [x[cst].inc(v)], strict_dtype=False) - f = function([x, v, idx], out) - f_ref = function([x, v, idx], out, mode=Mode(linker="py", optimizer=None)) + f = function([x, v], out) + f_ref = function([x, v], out, mode=Mode(linker="py", optimizer=None)) dx = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=config.floatX) dv = np.array([10.0, 20.0], dtype=config.floatX) - didx = np.array([1, 3], dtype="int32") - np.testing.assert_allclose(f(dx, dv, didx), f_ref(dx, dv, didx)) + np.testing.assert_allclose(f(dx, dv), f_ref(dx, dv)) + + # set-into-zeros with a possibly-duplicated index is left alone: a dense set + # is last-wins, while a sparse inc would accumulate at repeated positions. + # ``assert_eval`` with a duplicated index pins the soundness independently of + # the graph check: a wrongly-collapsed inc would accumulate at the repeated + # position (index 1 -> 2 + 10 + 20 instead of 2 + 20). + out_set_unsafe = x + pt.zeros(x.shape)[idx].set(v) + result = utt.RewriteTester( + [x, v, idx], [out_set_unsafe], include=[], custom_rewrite=sparse_rewriter + ) + result.assert_graph(out_set_unsafe) + result.assert_eval(dx, dv, np.array([1, 1], dtype="int32")) # inc-into-zeros with unique constant indices is rewritten. out_inc = x + pt.zeros(x.shape)[np.array([1, 3])].inc(v) @@ -248,13 +262,16 @@ def test_local_add_of_sparse_write(): ) # inc-into-zeros with a non-constant (potentially duplicated) index is - # left alone. Run the rewrite in isolation so other simplifications - # don't obscure what happens. - out_unsafe = x + pt.zeros(x.shape)[idx].inc(v) - rewritten_unsafe = rewrite_graph( - out_unsafe, include=[], custom_rewrite=sparse_rewriter + # rewritten unconditionally: inc accumulates the same per-position delta + # whether the base is zeros (then added to x) or x itself. ``assert_eval`` + # with a duplicated index pins the soundness: index 1 -> 2 + 10 + 20 on both + # the original and the rewritten graph. + out_dup = x + pt.zeros(x.shape)[idx].inc(v) + result_dup = utt.RewriteTester( + [x, v, idx], [out_dup], include=[], custom_rewrite=sparse_rewriter ) - utt.assert_equal_computations([rewritten_unsafe], [out_unsafe]) + result_dup.assert_graph(x[idx].inc(v)) + result_dup.assert_eval(dx, dv, np.array([1, 1], dtype="int32")) # Basic (scalar) inc-into-zeros is trivially unique and should be rewritten. s = iscalar("s") @@ -266,6 +283,21 @@ def test_local_add_of_sparse_write(): [rewritten_basic], [x[s].inc(v[0])], strict_dtype=False ) + # A bounded slice flattens its (symbolic) bounds into the index variables; + # those must not be mistaken for advanced indices. With a leading slice and a + # unique advanced index the sparse write still collapses. + X = matrix("X") + w = matrix("w") + u = pt.constant(np.array([0, 2], dtype="int32")) + lo, hi = iscalar("lo"), iscalar("hi") + out_slice = X + pt.zeros(X.shape)[lo:hi, u].set(w) + rewritten_slice = rewrite_graph( + out_slice, include=[], custom_rewrite=sparse_rewriter + ) + utt.assert_equal_computations( + [rewritten_slice], [X[lo:hi, u].inc(w)], strict_dtype=False + ) + class TestLocalUselessSubtensor: x = matrix("x") @@ -1744,6 +1776,23 @@ def test_inc_of_set_zero_base_emits_inc(self): rewritten = rewrite_graph(out, include=("canonicalize", "specialize")) utt.assert_equal_computations([rewritten], [inc_subtensor(zeros[:stop], a + b)]) + def test_inc_of_set_advanced_with_slice_rewritten(self): + """A bounded slice flattens its (symbolic) bounds into the index + variables; those must not be mistaken for advanced indices and block the + uniqueness check. With a leading slice and a unique advanced index the + inc-of-set still collapses to ``x[lo:hi, idx].set(a + b)``.""" + x = tensor3("x", dtype="float64") + a = matrix("a", dtype="float64") + b = matrix("b", dtype="float64") + lo, hi = iscalar("lo"), iscalar("hi") + idx = pt.constant(np.array([0, 2], dtype="int32")) + + out = inc_subtensor(set_subtensor(x[lo:hi, idx], a)[lo:hi, idx], b) + rewritten = rewrite_graph(out, include=("canonicalize", "specialize")) + utt.assert_equal_computations( + [rewritten], [set_subtensor(x[lo:hi, idx], a + b)] + ) + def test_inc_of_set_advanced_non_unique_not_rewritten(self): """Inc-of-set requires unique indices; duplicate constant indices on advanced axes block the rewrite.""" @@ -2070,6 +2119,52 @@ def test_local_set_to_inc_subtensor(): assert check_stack_trace(f2, ops_to_check="all") +def test_local_set_to_inc_subtensor_duplicate_indices(): + """``set(x[idx] + other)`` collapses to ``inc(x, other, idx)`` only when + ``idx`` is duplicate-free: a dense set is last-wins while inc accumulates, so + with repeated indices the inc form over-counts. The rewrite must not fire on a + possibly-duplicated symbolic index, and a wrongly-collapsed inc would diverge + at the repeated position (index 1 -> 20 + 1 + 2 instead of 20 + 2).""" + v = vector("v") + other = vector("other") + idx = ivector("idx") + + out = set_subtensor(v[idx], v[idx] + other) + + mode = ( + get_default_mode() + .including( + "local_replace_AdvancedSubtensor", + "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1", + ) + .excluding("fuse_indexed_into_elemwise") + ) + f = function([v, other, idx], out, mode=mode) + + # The symbolic index is not provably unique, so the set survives. + assert all( + n.op.set_instead_of_inc + for n in f.maker.fgraph.toposort() + if isinstance(n.op, AdvancedIncSubtensor1) + ) + + # Soundness pinned against a no-rewrite reference at a duplicated index. + dv = np.array([10.0, 20.0, 30.0], dtype=v.dtype) + dother = np.array([1.0, 2.0], dtype=v.dtype) + didx = np.array([1, 1], dtype="int32") + f_ref = function([v, other, idx], out, mode=Mode(linker="py", optimizer=None)) + np.testing.assert_allclose(f(dv, dother, didx), f_ref(dv, dother, didx)) + + # A constant duplicated index is also left alone. + out_const = set_subtensor(v[[1, 1]], v[[1, 1]] + other) + f_const = function([v, other], out_const, mode=mode) + assert all( + n.op.set_instead_of_inc + for n in f_const.maker.fgraph.toposort() + if isinstance(n.op, AdvancedIncSubtensor1) + ) + + @pytest.mark.parametrize( "axis, slices_fn, expected_nodes", [ From a727559be2675477de39bfcb8c0933fe6326b985 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 27 Jun 2026 14:23:11 +0200 Subject: [PATCH 3/3] Subtensor rewrites: reason about advanced indices jointly when gating Generalize the duplicate-free reasoning used to gate the advanced-index write rewrites from a per-axis check to a joint one over the whole advanced index group, sharing the logic across subtensor.py and subtensor_lift.py. - _index_provably_unique: per-axis uniqueness, now also proving single-signed aranges and views (Reshape/DimShuffle) that preserve the value multiset. - _indices_jointly_unique: distinct joint coordinate tuples via all-axes uniqueness, a single Nonzero (e.g. tril_indices), or jointly-unique constants where no single axis is unique on its own. - _indices_provably_not_larger: bound a gather by the indexed axes' size using static shapes, per-axis bounds, or joint uniqueness. --- pytensor/tensor/rewriting/subtensor.py | 219 +++++++++++++--- pytensor/tensor/rewriting/subtensor_lift.py | 119 ++++++--- tests/tensor/rewriting/test_subtensor.py | 236 +++++++++++++++++- tests/tensor/rewriting/test_subtensor_lift.py | 45 ++++ 4 files changed, 534 insertions(+), 85 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index c9f682f110..9ba313e242 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -9,6 +9,7 @@ from pytensor.assumptions.core import UNIQUE_INDICES, check_assumption from pytensor.compile import optdb from pytensor.graph.basic import Constant, Variable +from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import ( WalkingGraphRewriter, copy_stack_trace, @@ -23,6 +24,7 @@ Alloc, ARange, Join, + Nonzero, ScalarFromTensor, TensorFromScalar, alloc, @@ -61,6 +63,7 @@ ) from pytensor.tensor.rewriting.blockwise import blockwise_of from pytensor.tensor.shape import ( + Reshape, Shape, Shape_i, shape_padleft, @@ -76,6 +79,7 @@ IncSubtensor, Subtensor, _is_provably_non_negative, + _is_provably_positive, _non_consecutive_adv_indexing, advanced_inc_subtensor1, advanced_subtensor1, @@ -232,12 +236,169 @@ def _constant_has_unique_indices(idx) -> bool: return result -def _has_unique_indices(fgraph, idx) -> bool: - """Whether ``idx``'s entries are provably duplicate-free: a constant with - unique entries, or a variable asserted ``unique_indices`` by the user.""" - return _constant_has_unique_indices(idx) or check_assumption( - fgraph, idx, UNIQUE_INDICES - ) +def _constant_int_or_none(var) -> int | None: + """The integer value of ``var`` if it is a scalar constant, else ``None``.""" + try: + return int(get_scalar_constant_value(var)) + except NotScalarConstantError: + return None + + +def _arange_provably_unique(start, stop, step) -> bool: + """Whether ``arange(start, stop, step)`` selects each position at most once. + + Its entries are always distinct values; they map to distinct positions as + long as they don't wrap around zero, i.e. they all share a sign (``arange(-2, + 2)`` aliases on a size-2 axis, but ``arange(2, 6)`` and ``arange(-6, -2)`` do + not). This is proved from whichever bounds are statically known. + """ + start_c, stop_c, step_c = (_constant_int_or_none(v) for v in (start, stop, step)) + + # Both endpoints non-negative -> every entry lies between them -> all >= 0, + # whatever the step direction (covers symbolic ``arange(x.shape[0])``). + if _is_provably_non_negative(start) and _is_provably_non_negative(stop): + return True + + # With a known direction only one endpoint binds each sign. Ascending entries + # span ``[start, stop)`` (min is ``start``, max ``< stop``); descending entries + # span ``(stop, start]`` (max is ``start``, min ``> stop``). + if _is_provably_positive(step): # ascending + if _is_provably_non_negative(start): # all >= 0 + return True + if stop_c is not None and stop_c <= 0: # all <= -1 + return True + elif step_c is not None and step_c < 0: # descending + if _is_provably_non_negative(stop): # all >= 0 (e.g. arange(k, 5, -1)) + return True + if start_c is not None and start_c < 0: # all <= -1 (e.g. arange(-1, k, -1)) + return True + + # Fully constant: compute the exact entry range without materializing it. The + # checks above only bound the last entry as past ``stop``, so they need ``stop`` + # on the right side of zero. This catches the rest -- ranges whose last entry + # overshoots ``stop`` across zero yet stays single-signed, e.g. + # ``arange(6, -2, -2)`` is ``[6, 4, 2, 0]`` and ``arange(-5, 1, 3)`` is ``[-5, -2]``. + if ( + start_c is not None + and stop_c is not None + and step_c is not None + and step_c != 0 + ): + n = max(0, -(-(stop_c - start_c) // step_c)) # length, via ceil division + if n <= 1: + return True + last_c = start_c + (n - 1) * step_c + return min(start_c, last_c) >= 0 or max(start_c, last_c) < 0 + return False + + +def _index_provably_unique(idx, fgraph: FunctionGraph | None) -> bool: + """Whether a single index selects each position on its own axis at most once. + + This is the duplicate-free reasoning shared by accumulation gating (where + repeated positions would scatter-add) and by ``_index_provably_not_larger`` + (a duplicate-free index can't enlarge its axis). It excludes only the + statically-smaller fallback, which bounds the size without ruling out + repeated positions. + + ``fgraph`` is the handle to the ``AssumptionFeature``: it lets a user-declared + ``unique_indices`` assumption prove uniqueness when the static value, ``arange`` + shape, or view structure can't. Pass ``None`` to skip that leg (no assumptions). + """ + if isinstance(idx, slice) or idx.ndim == 0: + return True + if all(idx.type.broadcastable): + return True + if idx.type.dtype == "bool": + return True + if _constant_has_unique_indices(idx): + return True + if check_assumption(fgraph, idx, UNIQUE_INDICES): + return True + if isinstance(idx.owner_op, ARange): + return _arange_provably_unique(*idx.owner.inputs) + if isinstance(idx.owner_op, Reshape | DimShuffle): + # Views that only reorder or insert size-1 dims keep the value multiset. + return _index_provably_unique(idx.owner.inputs[0], fgraph) + return False + + +def _constants_jointly_unique(consts) -> bool: + """Whether stacked constant indices have no duplicate coordinate tuples. + + The stacked ``np.unique`` can be expensive on large indices, so the result + is cached on the first constant's tag. Uniqueness is a property of the whole + group, and a constant may belong to several groups (constants are shared + across the graph), so the cache is keyed by the group's identities rather + than a single flag. + """ + key = tuple(id(c) for c in consts) + cache = getattr(consts[0].tag, "jointly_unique_indices", None) + if cache is None: + cache = consts[0].tag.jointly_unique_indices = {} + if key not in cache: + datas = [np.asarray(c.data) for c in consts] + # A coordinate axis that mixes positive and negative values may alias + # (``0`` and ``-dim`` are the same position), so distinctness of the raw + # values no longer proves distinctness of the coordinates. + if any((data >= 0).any() and (data < 0).any() for data in datas): + cache[key] = False + else: + coords = np.broadcast_arrays(*datas) + stacked = np.stack([coord.ravel() for coord in coords]) + cache[key] = bool(np.unique(stacked, axis=1).shape[1] == stacked.shape[1]) + return bool(cache[key]) + + +def _indices_jointly_unique(idxs, fgraph: FunctionGraph | None) -> bool: + """Whether advanced indices produce no duplicate joint coordinate tuples. + + For accumulation (``inc``), and for bounding a gather by the indexed axes' + size, what matters is that the broadcast coordinate tuples + ``(idx0[k], idx1[k], ...)`` are all distinct. Sufficient conditions, in + increasing generality: + + - every index is duplicate-free on its own axis, so the tuples are distinct + regardless of the others (sound under broadcasting, and the path basic + slice/scalar indexing trivially takes); + - the indices are all the coordinates of a single ``Nonzero``, distinct by + construction (e.g. symbolic ``tril_indices``); + - the indices are all constants whose stacked coordinate tuples have no + duplicates (catches cases where no single axis is unique on its own). + + ``fgraph`` is forwarded to ``_index_provably_unique`` so a user-declared + ``unique_indices`` assumption can satisfy the per-axis leg; pass ``None`` to + skip assumptions. + """ + if all(_index_provably_unique(idx, fgraph) for idx in idxs): + return True + if len(idxs) > 1: + owners = {idx.owner for idx in idxs} + if ( + len(owners) == 1 + and (owner := next(iter(owners))) is not None + and isinstance(owner.op, Nonzero) + and set(idxs) == set(owner.outputs) + ): + return True + if all(isinstance(idx, Constant) for idx in idxs): + return _constants_jointly_unique(idxs) + return False + + +def _advanced_indices_jointly_unique(indices, fgraph: FunctionGraph | None) -> bool: + """Whether the advanced (``ndim > 0`` tensor) indices in a reconstructed index + tuple have distinct joint coordinate tuples. + + Slice and scalar (basic) indices are ignored: basic indexing is trivially + duplicate-free, so only the advanced array indices are weighed. ``fgraph`` is + forwarded to ``_indices_jointly_unique`` for the ``unique_indices`` assumption + lookup (see there); pass ``None`` to skip assumptions. + """ + adv_idxs = [ + idx for idx in indices if isinstance(idx, TensorVariable) and idx.type.ndim > 0 + ] + return _indices_jointly_unique(adv_idxs, fgraph) def _constant_is_arange(idx) -> tuple[int, int, int] | None: @@ -1138,7 +1299,7 @@ def local_set_to_inc_subtensor(fgraph, node): # set->inc is only valid when ilist is duplicate-free: ``set(x[ilist] + other)`` # is last-wins at repeated positions, while ``inc(other)`` would accumulate the # contributions of every occurrence and over-count them. - if not _has_unique_indices(fgraph, node.inputs[2]): + if not _index_provably_unique(node.inputs[2], fgraph): return ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2]) @@ -1191,15 +1352,11 @@ def local_add_of_sparse_write(fgraph, node): # -> x[idx].inc(v)`` holds for any indices. A dense set, by contrast, is # last-wins, so collapsing it to an inc would over-count repeated # positions. Basic (slice/scalar) IncSubtensor is always unique; advanced - # integer-array set indices must be checked, weighing only the advanced - # indices and not the flattened slice bounds. + # integer-array set indices must be jointly duplicate-free, weighing only + # the advanced indices and not the flattened slice bounds. if inner_op.set_instead_of_inc and not isinstance(inner_op, IncSubtensor): - adv_idxs = [ - idx - for idx in indices_from_subtensor(idx_vars, inner_op.idx_list) - if isinstance(idx, TensorVariable) and idx.type.ndim > 0 - ] - if not all(_has_unique_indices(fgraph, idx) for idx in adv_idxs): + indices = indices_from_subtensor(idx_vars, inner_op.idx_list) + if not _advanced_indices_jointly_unique(indices, fgraph): continue others = [node.inputs[j] for j in range(len(node.inputs)) if j != i] @@ -1212,7 +1369,7 @@ def local_add_of_sparse_write(fgraph, node): else: new_op = inner_op r = new_op(other, v, *idx_vars) - copy_stack_trace(node.outputs[0], r) + copy_stack_trace([node.outputs[0], sparse_candidate], r) return [r] return None @@ -1987,9 +2144,8 @@ def local_read_of_write_same_indices(fgraph, node): Applies when the outer read and inner write share identical index variables (``is`` check) and the same ``idx_list``. The inc case additionally requires duplicate-free indices: slices and scalar indices - are trivially unique, while integer-array indices must be constant with - no repeated entries (mixing positive and negative values counts as - potentially duplicated since they may alias). + are trivially unique, while advanced integer-array indices must be jointly + duplicate-free (see ``_indices_jointly_unique``). Companion rewrites: @@ -2026,13 +2182,11 @@ def local_read_of_write_same_indices(fgraph, node): copy_stack_trace(out, r) return [r] else: - # Inc case: advanced integer-array indices must be duplicate-free; + # Inc case: advanced integer-array indices must be jointly duplicate-free; # slices and scalar indices are trivially unique. indices = indices_from_subtensor(outer_idx_vars, node.op.idx_list) - for idx in indices: - if isinstance(idx, TensorVariable) and idx.type.ndim > 0: - if not _has_unique_indices(fgraph, idx): - return None + if not _advanced_indices_jointly_unique(indices, fgraph): + return None x_at_idx = x[tuple(indices)] copy_stack_trace(out, x_at_idx) @@ -2387,19 +2541,12 @@ def local_write_of_write_same_indices(fgraph, node): new_val = b use_set = True elif inner_is_set: - # x[idx].set(a)[idx].inc(b) — needs unique indices. - # Basic indexing (slices/scalars) is always duplicate-free. - # For advanced indexing, per-axis uniqueness is conservative but - # sufficient: it guarantees no duplicates in the joint cross-product - # after broadcasting. Weigh only the advanced indices, not the flattened - # slice bounds. + # x[idx].set(a)[idx].inc(b) — needs unique indices. Basic indexing + # (slices/scalars) is always duplicate-free; advanced indices must have + # duplicate-free joint coordinate tuples. if not isinstance(node.op, IncSubtensor): - adv_idxs = [ - idx - for idx in indices_from_subtensor(outer_idx_vars, node.op.idx_list) - if isinstance(idx, TensorVariable) and idx.type.ndim > 0 - ] - if not all(_has_unique_indices(fgraph, idx) for idx in adv_idxs): + indices = indices_from_subtensor(outer_idx_vars, node.op.idx_list) + if not _advanced_indices_jointly_unique(indices, fgraph): return new_val = a + b if ( diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index d3a4bf5bf6..3c5888a4e1 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -4,7 +4,6 @@ import numpy as np from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple -from pytensor.assumptions.core import UNIQUE_INDICES, check_assumption from pytensor.compile import optdb from pytensor.graph import ( Constant, @@ -22,7 +21,6 @@ from pytensor.tensor.basic import ( Alloc, AllocDiag, - ARange, ExtractDiag, Eye, Join, @@ -49,7 +47,8 @@ ) from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.rewriting.subtensor import ( - _constant_has_unique_indices, + _index_provably_unique, + _indices_jointly_unique, local_adv_idx_to_diagonal, local_adv_idx_to_slice, local_advanced_read_of_write_constant_indices, @@ -215,33 +214,69 @@ def _lift_subtensor_non_axis( return None -def _index_provably_not_larger(idx, val_static_dim, fgraph=None) -> bool: +def _static_size_not_larger(idx, val_static_dim) -> bool: + # A purely static bound: the index selects no more elements than the axis holds. + # Reshape/DimShuffle preserve the element count, so follow them to whichever + # view in the chain has a fully-known static shape. + if val_static_dim is not None: + idx_static_shape = idx.type.shape + if not any(d is None for d in idx_static_shape) and ( + np.prod(idx_static_shape) <= val_static_dim + ): + return True + if isinstance(idx.owner_op, Reshape | DimShuffle): + return _static_size_not_larger(idx.owner.inputs[0], val_static_dim) + return False + + +def _index_provably_not_larger( + idx, val_static_dim, fgraph: FunctionGraph | None +) -> bool: # Per-axis check: an index that can't repeat a position can't enlarge that axis. # Does not account for cross-axis broadcast expansion from outer indexing. - if isinstance(idx, slice) or idx.ndim == 0: - return True - if all(idx.type.broadcastable): + # Try the cheap purely-static size bound first; only then the (potentially + # graph-walking) uniqueness reasoning. Both follow Reshape/DimShuffle views, + # which preserve the element count, so uniqueness is still checked just once. + # ``fgraph`` is forwarded to ``_index_provably_unique`` for the user-declared + # ``unique_indices`` assumption (None to skip it). + if _static_size_not_larger(idx, val_static_dim): return True - if idx.type.dtype == "bool": - return True - if _constant_has_unique_indices(idx): - return True - if check_assumption(fgraph, idx, UNIQUE_INDICES): - return True - if isinstance(idx.owner_op, ARange): + return _index_provably_unique(idx, fgraph) + + +def _indices_provably_not_larger(idxs_and_dims, fgraph: FunctionGraph | None) -> bool: + """Whether advanced-indexing some consecutive axes selects no more elements + than those axes already hold, so lifting a Subtensor through the indexing + can't increase computation. + + ``idxs_and_dims`` pairs each advanced index (``ndim > 0``) with the static + size of the axis it indexes. ``fgraph`` is forwarded to the uniqueness helpers + for the ``unique_indices`` assumption lookup (None to skip it). + """ + if not idxs_and_dims: return True - if isinstance(idx.owner_op, Reshape | DimShuffle): - # Views that don't add dimensions - if _index_provably_not_larger(idx.owner.inputs[0], val_static_dim, fgraph): - return True - # Fallback to static shape analysis - if val_static_dim is None: - return False - idx_static_shape = idx.type.shape - if any(d is None for d in idx_static_shape): - return False - return bool(np.prod(idx_static_shape) < val_static_dim) + idxs = [idx for idx, _ in idxs_and_dims] + dims = [dim for _, dim in idxs_and_dims] + idx_shapes = [idx.type.shape for idx in idxs] + + # With static shapes the result size is known exactly, so just compare it + # against the number of elements the indexed axes hold. + if all(d is not None for d in dims) and all( + None not in shape for shape in idx_shapes + ): + return bool(np.prod(np.broadcast_shapes(*idx_shapes)) <= np.prod(dims)) + + # Otherwise each index, on its own axis, may select no more than that axis + # holds (e.g. an arange or a statically-smaller index)... + if all(_index_provably_not_larger(idx, dim, fgraph) for idx, dim in idxs_and_dims): + return True + # ...or the indices are jointly duplicate-free, which on its own bounds the + # result by the axes' size even when the per-axis sizes are unknown. Only the + # joint-only conditions (single Nonzero, jointly-unique constants) can add + # anything here: the per-axis pass above already failed, so at least one index + # is not provably unique and the per-axis leg of the joint check is moot. + return _indices_jointly_unique(idxs, fgraph) @register_canonicalize @@ -345,17 +380,19 @@ def local_subtensor_of_batch_dims(fgraph, node): if _non_consecutive_adv_indexing(idx_tuple): return None - # Skip when lifting would expand a gather past a non-broadcast input's size. + # Skip when indexing each input would select more elements than it holds, + # making the lifted Elemwise do more work. The advanced indices are weighed + # together, over the consecutive axes they jointly index. for inp in elem.owner.inputs: - for axis, idx in enumerate(idx_tuple): - if axis >= inp.type.ndim: - break - if not isinstance(idx, TensorVariable) or idx.type.ndim == 0: - continue - if inp.type.broadcastable[axis]: - continue - if not _index_provably_not_larger(idx, inp.type.shape[axis], fgraph): - return None + adv_indices = [ + (idx, inp.type.shape[axis]) + for axis, idx in enumerate(idx_tuple[: inp.type.ndim]) + if isinstance(idx, TensorVariable) + and idx.type.ndim > 0 + and not inp.type.broadcastable[axis] + ] + if not _indices_provably_not_larger(adv_indices, fgraph): + return None batch_ndim = ( elem.owner.op.batch_ndim(elem.owner) @@ -742,11 +779,15 @@ def lift_subtensor_through_alloc(fgraph, node): # Indices on Alloc-added dims don't reach val; the rest line up with val's dims. val_indexer = indices[n_added_dims:] - dangerous_index_reaches_val = any( - not val.type.broadcastable[axis] - # Per-axis check; doesn't account for net effect across all axes. - and not _index_provably_not_larger(idx, val.type.shape[axis], fgraph) + val_adv_indices = [ + (idx, val.type.shape[axis]) for axis, idx in enumerate(val_indexer) + if isinstance(idx, TensorVariable) + and idx.type.ndim > 0 + and not val.type.broadcastable[axis] + ] + dangerous_index_reaches_val = not _indices_provably_not_larger( + val_adv_indices, fgraph ) # On broadcast val dims the index is neutralized (advanced indices dropped, diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 89f7582d5b..f03f9fcdb3 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -20,6 +20,7 @@ from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot, dot, exp, sqr from pytensor.tensor.rewriting.subtensor import ( + _index_provably_unique, _slice_to_arange, local_add_of_sparse_write, local_adv_idx_to_slice, @@ -283,10 +284,30 @@ def test_local_add_of_sparse_write(): [rewritten_basic], [x[s].inc(v[0])], strict_dtype=False ) + # set-into-zeros with jointly-unique advanced indices (neither axis unique on + # its own) is rewritten via the joint-uniqueness check on the set path. + X = matrix("X") + rows = pt.constant(np.array([0, 1, 1], dtype="int32")) + cols = pt.constant(np.array([0, 0, 1], dtype="int32")) + out_joint = X + pt.zeros(X.shape)[rows, cols].set(v) + rewritten_joint = rewrite_graph( + out_joint, include=[], custom_rewrite=sparse_rewriter + ) + utt.assert_equal_computations([rewritten_joint], [X[rows, cols].inc(v)]) + + # set-into-zeros with a jointly-duplicated advanced index is left alone (the + # (1, 1) coordinate repeats), since a dense set there is last-wins. + dup_rows = pt.constant(np.array([0, 1, 1], dtype="int32")) + dup_cols = pt.constant(np.array([0, 1, 1], dtype="int32")) + out_joint_dup = X + pt.zeros(X.shape)[dup_rows, dup_cols].set(v) + rewritten_joint_dup = rewrite_graph( + out_joint_dup, include=[], custom_rewrite=sparse_rewriter + ) + utt.assert_equal_computations([rewritten_joint_dup], [out_joint_dup]) + # A bounded slice flattens its (symbolic) bounds into the index variables; # those must not be mistaken for advanced indices. With a leading slice and a # unique advanced index the sparse write still collapses. - X = matrix("X") w = matrix("w") u = pt.constant(np.array([0, 2], dtype="int32")) lo, hi = iscalar("lo"), iscalar("hi") @@ -299,6 +320,53 @@ def test_local_add_of_sparse_write(): ) +class TestIndexProvablyUniqueArange: + """An ``arange`` index is duplicate-free when its entries don't wrap around + zero, i.e. they all share a sign. ``_index_provably_unique`` proves this for + non-negative ascending ranges (symbolic-friendly) and, with constant bounds, + for any single-signed range regardless of step direction.""" + + @pytest.mark.parametrize( + "make_arange, expected", + [ + # Non-negative, symbolic-friendly (proved without constant bounds). + (lambda k, n: pt.arange(k), True), + (lambda k, n: pt.arange(n), True), # n = shape, provably >= 0 + (lambda k, n: pt.arange(2, k), True), + (lambda k, n: pt.arange(n, 0, -1), True), # reverse range, both bounds >= 0 + # Descending into a non-negative stop: entries > stop >= 0, any start. + (lambda k, n: pt.arange(k, 0, -1), True), + (lambda k, n: pt.arange(k, 5, -1), True), + # Descending from a negative start: entries <= start < 0, any stop. + (lambda k, n: pt.arange(-1, k, -1), True), + # Constant single-signed ranges, either step direction. + (lambda k, n: pt.arange(2, 6), True), + (lambda k, n: pt.arange(-6, -2), True), # all negative + (lambda k, n: pt.arange(5, -1, -1), True), # descending, [5..0] + ( + lambda k, n: pt.arange(6, -2, -2), + True, + ), # descending, [6,4,2,0], overshoots stop + ( + lambda k, n: pt.arange(-5, 1, 3), + True, + ), # ascending, [-5,-2], overshoots stop + (lambda k, n: pt.arange(-1, -6, -1), True), # descending, all negative + # Straddling zero -> may wrap -> not provably unique. + (lambda k, n: pt.arange(-2, 2), False), + (lambda k, n: pt.arange(0, -5, -1), False), # 0 with negatives + # Sign not statically known. + (lambda k, n: pt.arange(5, k, -1), False), # unknown stop sign + (lambda k, n: pt.arange(k, 5), False), # unknown start sign + (lambda k, n: pt.arange(k, -5, -1), False), # unknown start, neg stop + ], + ) + def test_arange(self, make_arange, expected): + k = iscalar("k") + n = vector("v").shape[0] + assert _index_provably_unique(make_arange(k, n), None) is expected + + class TestLocalUselessSubtensor: x = matrix("x") s = ps.int32("s") @@ -1376,6 +1444,115 @@ def test_inc_unique_constant_idx(self): ) assert check_stack_trace(f, ops_to_check=(AdvancedSubtensor1, Elemwise)) + def test_inc_jointly_unique_constant_idx(self): + """Multiple advanced indices that are jointly (not per-axis) duplicate-free + are recognized via the joint check, so inc read-of-write simplifies even + though neither ``rows`` nor ``cols`` is unique on its own.""" + x = matrix(dtype="float64") + y = vector(dtype="float64") + rows = pt.constant(np.array([0, 1, 1], dtype="int32")) + cols = pt.constant(np.array([0, 0, 1], dtype="int32")) + + inc = inc_subtensor(x[rows, cols], y) + o = inc[rows, cols] + f = function([x, y], o, self.mode) + + dx = np.random.random((2, 2)) + dy = np.random.random((3,)) + expected = dx.copy() + np.add.at(expected, (np.array([0, 1, 1]), np.array([0, 0, 1])), dy) + np.testing.assert_allclose(expected[[0, 1, 1], [0, 0, 1]], f(dx, dy)) + topo = f.maker.fgraph.toposort() + assert not any( + isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo + ) + + def test_inc_tril_indices_nonzero(self): + """``tril_indices`` coordinates come from a single ``Nonzero``, distinct by + construction, so the inc read-of-write simplifies even though the indices + are symbolic and neither axis is unique on its own.""" + n = iscalar("n") + x = matrix(dtype="float64") + y = vector(dtype="float64") + rows, cols = pt.tril_indices(n) + + inc = inc_subtensor(x[rows, cols], y) + o = inc[rows, cols] + f = function([x, y, n], o, self.mode) + + dx = np.random.random((4, 4)) + tri = np.tril_indices(4) + dy = np.random.random((tri[0].size,)) + expected = dx.copy() + np.add.at(expected, tri, dy) + np.testing.assert_allclose(expected[tri], f(dx, dy, 4)) + topo = f.maker.fgraph.toposort() + assert not any( + isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo + ) + + def test_inc_symbolic_bool_mask(self): + """A boolean mask selects each position at most once, so it is duplicate-free + even when symbolic; the inc read-of-write simplifies to ``x[mask] + v``.""" + x = vector(dtype="float64") + v = vector(dtype="float64") + mask = vector("mask", dtype="bool") + + inc = inc_subtensor(x[mask], v) + o = inc[mask] + f = function([x, v, mask], o, self.mode) + + dx = np.arange(5.0) + dmask = np.array([1, 0, 1, 0, 1], dtype=bool) + dv = np.array([10.0, 20.0, 30.0]) + expected = dx.copy() + np.add.at(expected, np.where(dmask)[0], dv) + np.testing.assert_allclose(expected[dmask], f(dx, dv, dmask)) + topo = f.maker.fgraph.toposort() + assert not any( + isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo + ) + + def test_inc_symbolic_arange(self): + """An ``arange`` is strictly monotonic, so its entries are distinct even + when symbolic; the inc read-of-write simplifies to ``x[idx] + v``.""" + k = iscalar("k") + x = vector(dtype="float64") + v = vector(dtype="float64") + idx = pt.arange(k) + + inc = inc_subtensor(x[idx], v) + o = inc[idx] + f = function([x, v, k], o, self.mode) + + dx = np.arange(6.0) + dv = np.array([10.0, 20.0, 30.0, 40.0]) + expected = dx.copy() + np.add.at(expected, np.arange(4), dv) + np.testing.assert_allclose(expected[np.arange(4)], f(dx, dv, 4)) + topo = f.maker.fgraph.toposort() + assert not any( + isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo + ) + + def test_inc_mixed_sign_arange_not_rewritten(self): + """A negative-start arange may wrap around (``arange(-2, k)`` aliases + positions on a small axis), so it is not duplicate-free and inc must not + be rewritten.""" + k = iscalar("k") + x = vector(dtype="float64") + v = vector(dtype="float64") + idx = pt.arange(-2, k) + + inc = inc_subtensor(x[idx], v) + o = inc[idx] + f = function([x, v, k], o, self.mode) + + topo = f.maker.fgraph.toposort() + assert any( + isinstance(n.op, AdvancedIncSubtensor | AdvancedIncSubtensor1) for n in topo + ) + @pytest.mark.parametrize( "cidx_values, n_rows", [ @@ -1776,6 +1953,22 @@ def test_inc_of_set_zero_base_emits_inc(self): rewritten = rewrite_graph(out, include=("canonicalize", "specialize")) utt.assert_equal_computations([rewritten], [inc_subtensor(zeros[:stop], a + b)]) + def test_inc_of_set_advanced_jointly_unique_rewritten(self): + """Inc-of-set fires when advanced indices are jointly duplicate-free, even + though neither axis is unique on its own. ``tril_indices`` coordinates come + from a single ``Nonzero`` and so collapse to ``x[idx].set(a + b)``.""" + n = iscalar("n") + x = matrix("x", dtype="float64") + a = vector("a", dtype="float64") + b = vector("b", dtype="float64") + rows, cols = pt.tril_indices(n) + + out = inc_subtensor(set_subtensor(x[rows, cols], a)[rows, cols], b) + rewritten = rewrite_graph(out, include=("canonicalize", "specialize")) + utt.assert_equal_computations( + [rewritten], [set_subtensor(x[rows, cols], a + b)] + ) + def test_inc_of_set_advanced_with_slice_rewritten(self): """A bounded slice flattens its (symbolic) bounds into the index variables; those must not be mistaken for advanced indices and block the @@ -1793,6 +1986,26 @@ def test_inc_of_set_advanced_with_slice_rewritten(self): [rewritten], [set_subtensor(x[lo:hi, idx], a + b)] ) + def test_inc_of_set_advanced_jointly_unique_with_slice_rewritten(self): + """A bounded slice flattens its bounds into the index variables; those + must not be mistaken for advanced indices and block the joint check. With + a leading slice and jointly-unique advanced indices the inc-of-set still + collapses to ``x[lo:hi, rows, cols].set(a + b)``.""" + n = iscalar("n") + x = tensor3("x", dtype="float64") + a = matrix("a", dtype="float64") + b = matrix("b", dtype="float64") + lo, hi = iscalar("lo"), iscalar("hi") + rows, cols = pt.tril_indices(n) + + out = inc_subtensor( + set_subtensor(x[lo:hi, rows, cols], a)[lo:hi, rows, cols], b + ) + rewritten = rewrite_graph(out, include=("canonicalize", "specialize")) + utt.assert_equal_computations( + [rewritten], [set_subtensor(x[lo:hi, rows, cols], a + b)] + ) + def test_inc_of_set_advanced_non_unique_not_rewritten(self): """Inc-of-set requires unique indices; duplicate constant indices on advanced axes block the rewrite.""" @@ -2848,8 +3061,9 @@ def test_cholesky_unconstrain_grad(exp_before_materialize): packed = pt.vector("packed") if exp_before_materialize: - # We test the same optimized result regardless of whether - # the diagonals are updated before or after materialization + # Same ``L`` two ways: exponentiate the diagonal in the packed vector + # before scattering, or (else branch) scatter first and exponentiate the + # matrix diagonal. Equivalent, but optimize to different graphs under BlasOpt. packed_diag_indices = pt.arange(n + 1).cumsum()[1:] - 1 log_diag = packed[packed_diag_indices] packed_update = packed[packed_diag_indices].set(pt.exp(log_diag)) @@ -2873,7 +3087,6 @@ def test_cholesky_unconstrain_grad(exp_before_materialize): mode = get_default_mode().excluding("fuse_indexed_into_elemwise") f = function([packed], [loss, grad], mode=mode) - f.dprint(print_shape=True) idx_types = ( Subtensor, @@ -2885,13 +3098,16 @@ def test_cholesky_unconstrain_grad(exp_before_materialize): ExtractDiag, ) n_idx = sum(1 for n in f.maker.fgraph.toposort() if isinstance(n.op, idx_types)) - # The ``BlasOpt`` rewrites lower ``L @ L.T`` to ``Gemm``; the gradient then - # fuses the diagonal-gradient term into a ``Gemm`` operand, materializing one - # extra set-subtensor. A linker that cannot use them lists ``BlasOpt`` in - # ``incompatible_rewrites`` (e.g. the numba linker), keeping the plain ``Dot`` - # lowering with that term as a vector. Both lowerings are correct. + # Post-materialization, the log-det gradient is a diagonal matrix added to + # ``L@L.T`` at the matrix level; BlasOpt's GemmOptimizer fuses ``add(dot, C)`` + # into one Gemm, materializing that diagonal as one extra set-subtensor (7 ops). + # Pre-materialization keeps the term in the packed vector's index space, so + # there's no matrix-level add to fuse (6 ops). Without Gemm (BlasOpt in the + # linker's incompatible_rewrites, e.g. numba) the term never reaches the matrix + # and both collapse to 6. All lowerings are correct. blas_rewrites_run = "BlasOpt" not in f.maker.mode.linker.incompatible_rewrites - assert n_idx == (7 if blas_rewrites_run else 6) + expected_n_idx = 7 if (blas_rewrites_run and not exp_before_materialize) else 6 + assert n_idx == expected_n_idx x = np.array([1.0, 0.5, 2.0, 0.3, 0.1, 1.5]) # Expected values were computed once by running ``f(x)``. diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 3314ba6952..7ff5936fb2 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -235,6 +235,36 @@ def test_elemwise_adv_index_assumed_unique_lifts(self): ) result.assert_graph(x[idx] + y[idx]) + def test_elemwise_jointly_unique_adv_indices_lift(self): + """A group of adv indices that each repeat but pair up to distinct + coordinates (tril_indices) can't select more elements than the indexed + axes hold, so it lifts.""" + # Symbolic indices: the outputs of a single Nonzero. + n = pt.scalar("n", dtype="int64") + x = pt.matrix("x") + rows, cols = pt.tril_indices(n) + # ``local_add_canonizer`` simplifies the ``n - 0`` inside ``tril_indices``, + # which then lets the duplicate ``arange`` merge -- noise for this test. + result = RewriteTester( + [n, x], [pt.exp(x)[rows, cols]], exclude=("local_add_canonizer",) + ) + result.assert_graph(pt.exp(x[rows, cols])) + result.assert_eval(3, np.arange(9.0).reshape(3, 3)) + + # Constant indices, static array shape: proved through the exact size. + x = pt.matrix("x", shape=(5, 5)) + rows, cols = (pt.constant(i) for i in np.tril_indices(5)) + result = RewriteTester([x], [pt.exp(x)[rows, cols]]) + result.assert_graph(pt.exp(x[rows, cols])) + result.assert_eval(np.arange(25.0).reshape(5, 5)) + + # Constant indices, unknown array shape: proved through joint uniqueness. + x = pt.matrix("x") + rows, cols = (pt.constant(i) for i in np.tril_indices(5)) + result = RewriteTester([x], [pt.exp(x)[rows, cols]]) + result.assert_graph(pt.exp(x[rows, cols])) + result.assert_eval(np.arange(25.0).reshape(5, 5)) + def test_blockwise(self): class CoreTestOp(Op): itypes = [dvector, dvector] @@ -756,6 +786,21 @@ def test_const_idx_with_duplicates_bails(self): rewritten = rewrite_graph(out, **self.rewrite_kw) assert_equal_computations([rewritten], [out], strict_dtype=False) + def test_jointly_unique_adv_indices_lift(self): + """Indices that each repeat but pair up to distinct coordinates + (tril_indices) don't enlarge val, so the read lifts through Alloc.""" + val = pt.matrix("val", shape=(5, 5)) + rows, cols = (pt.constant(i) for i in np.tril_indices(5)) + + result = RewriteTester( + [val], + [pt.alloc(val, 5, 5)[rows, cols]], + include=("ShapeOpt", "canonicalize", "specialize"), + exclude=("local_replace_AdvancedSubtensor",), + ) + result.assert_graph(val[rows, cols], strict_dtype=False) + result.assert_eval(np.arange(25.0).reshape(5, 5)) + def test_negative_step_idx_to_slice(self): """Negative-step constant arange ``[7, 5, 3, 1]`` rewrites to ``x[7::-2]``.""" x = pt.vector("x", shape=(10,))