diff --git a/pytensor/__init__.py b/pytensor/__init__.py index bb8cbe2d3a..c64e73dced 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -33,6 +33,7 @@ from pytensor.scan.basic import scan from pytensor.scan.views import map from pytensor.compile.builders import OpFromGraph +import pytensor.compile.rewriting # register OpFromGraph rewrites in optdb from pytensor.link.jax.ops import wrap_jax from pytensor import _sparse_lazy # isort: on diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index cf10adadb0..7e9b22d439 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -7,11 +7,9 @@ from collections.abc import Callable, Sequence from copy import copy from functools import partial -from typing import cast -from pytensor.compile.maker import function +from pytensor.compile.io import In, Out from pytensor.compile.mode import get_mode -from pytensor.compile.rebuild import rebuild_collect_shared from pytensor.compile.sharedvalue import SharedVariable from pytensor.gradient import DisconnectedType, disconnected_type, grad, pushforward from pytensor.graph.basic import ( @@ -23,7 +21,7 @@ from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.null_type import NullType from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern -from pytensor.graph.replace import clone_replace, graph_replace +from pytensor.graph.replace import clone_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError from pytensor.tensor.shape import Shape_i @@ -54,70 +52,48 @@ def infer_shape(outs, inputs, input_shapes): if cached is not None: replacements[cached] = s - if replacements: - flat = [s for tup in output_shapes if tup is not None for s in tup] - flat_replaced = graph_replace(flat, replacements, strict=False) - result = [] - idx = 0 - for tup in output_shapes: - if tup is None: - result.append(None) - else: - 
result.append(tuple(flat_replaced[idx : idx + len(tup)])) - idx += len(tup) - return result - - return output_shapes + flat = [s for tup in output_shapes if tup is not None for s in tup] + flat_replaced = clone_replace(flat, replace=replacements) + result = [] + idx = 0 + for tup in output_shapes: + if tup is None: + result.append(None) + else: + result.append(tuple(flat_replaced[idx : idx + len(tup)])) + idx += len(tup) + return result def construct_nominal_fgraph( inputs: Sequence[Variable], outputs: Sequence[Variable] -) -> tuple[ - FunctionGraph, - Sequence[Variable], - dict[Variable, Variable], - dict[Variable, Variable], -]: - """Construct an inner-`FunctionGraph` with ordered nominal inputs.""" - implicit_shared_inputs = [] +) -> FunctionGraph: + """Construct an inner-`FunctionGraph` with ordered nominal inputs. + Raises ``MissingInputError`` if ``outputs`` implicitly depend on a variable + that is neither a `Constant` nor listed in ``inputs`` (including shared + variables, which must be passed explicitly). + """ dummy_inputs = [inp.type() for inp in inputs] - dummy_implicit_shared_inputs = [] for var in graph_inputs(outputs, inputs): - if var in inputs: + if var in inputs or isinstance(var, Constant): continue if isinstance(var, SharedVariable): - # We allow shared inputs to be added automatically to the graph - implicit_shared_inputs.append(var) - dummy_implicit_shared_inputs.append(var.type()) - elif not isinstance(var, Constant): - raise MissingInputError(f"NominalGraph is missing an input: {var}") - - replacements = dict( - zip( - inputs + implicit_shared_inputs, - dummy_inputs + dummy_implicit_shared_inputs, - strict=True, - ) - ) + raise MissingInputError( + f"Inner graph implicitly depends on shared variable {var}. " + "Provide it explicitly in the 'inputs' list." 
+ ) + raise MissingInputError(f"NominalGraph is missing an input: {var}") - new = rebuild_collect_shared( - cast(Sequence[Variable], outputs), - inputs=inputs + implicit_shared_inputs, - replace=replacements, - copy_inputs_over=False, - ) - ( - local_inputs, - local_outputs, - (_clone_d, update_d, update_expr, new_shared_inputs), - ) = new + replacements = dict(zip(inputs, dummy_inputs, strict=True)) + + # ``outputs`` must be mutable ``Apply`` graphs; a caller holding a frozen + # graph thaws it first (``FrozenFunctionGraph.unfreeze``). + local_inputs = dummy_inputs + local_outputs = clone_replace(outputs, replace=replacements) - assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs) + assert len(local_inputs) == len(inputs) assert len(local_outputs) == len(outputs) - assert not update_d - assert not update_expr - assert not new_shared_inputs fgraph = FunctionGraph(local_inputs, local_outputs, clone=False) @@ -135,7 +111,7 @@ def construct_nominal_fgraph( fgraph.clients.pop(inp, None) fgraph.add_input(nom_inp) - return fgraph, implicit_shared_inputs, update_d, update_expr + return fgraph class OpFromGraph(Op, HasInnerGraph): @@ -153,8 +129,8 @@ class OpFromGraph(Op, HasInnerGraph): Notes ----- - - Shared variables in the inner graph are supported. They are detected automatically and added - as implicit inputs. + - Shared variables used in the inner graph must be passed explicitly as inputs; implicit + capture raises ``MissingInputError``. - Unused inputs are supported (needed for gradient overrides). - Nested OpFromGraph is supported. - ``inline=True`` causes the Op's inner graph to be inlined during compilation, which gives @@ -163,7 +139,7 @@ class OpFromGraph(Op, HasInnerGraph): - Override callables should be pure functions (no side effects). They are called once at the first call to L_op/R_op and converted to OpFromGraph instances. 
They are also called once at construction time with dummy inputs to build a frozen representation for equality comparison. - - Two OpFromGraph instances with the same inner graph, overrides, shared variables, and settings + - Two OpFromGraph instances with the same inner graph, overrides, and settings are considered equal. This allows the MergeOptimizer to deduplicate identical OpFromGraph nodes. @@ -183,7 +159,7 @@ class OpFromGraph(Op, HasInnerGraph): e2 = op(x, y, z) + op(z, y, x) fn = function([x, y, z], [e2]) - With a shared variable: + With a shared variable (passed explicitly as an input): .. code-block:: python @@ -195,8 +171,8 @@ class OpFromGraph(Op, HasInnerGraph): x, y, z = pt.scalars("xyz") s = pytensor.shared(np.random.random((2, 2)).astype(config.floatX)) e = x + y * z + s - op = OpFromGraph([x, y, z], [e]) - e2 = op(x, y, z) + op(z, y, x) + op = OpFromGraph([x, y, z, s], [e]) + e2 = op(x, y, z, s) + op(z, y, x, s) fn = function([x, y, z], [e2]) Per-input L_op override: @@ -306,9 +282,8 @@ def __init__( If provided, used as the connection pattern for this Op. Each inner list has one bool per output, and the outer list has one entry per input. strict : bool, optional - If True, raises when any variables needed to compute the inner graph are not provided - as explicit inputs. Only relevant for graphs with shared variables. Default False. - Under ``strict=False``, implicit shared-variable capture is deprecated. + Ignored. All variables needed to compute the inner graph must always be + provided as explicit inputs; implicitly captured shared variables raise. name : str, optional A name for debugging purposes. 
**kwargs @@ -339,24 +314,26 @@ def __init__( self.is_inline = inline - self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph( - inputs, outputs - ) - self._frozen_fgraph = self.fgraph.freeze() - - if strict and self.shared_inputs: - raise ValueError( - "All variables needed to compute inner-graph must be provided as inputs under strict=True. " - f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}" - ) - elif self.shared_inputs: + inner_fgraph = construct_nominal_fgraph(inputs, outputs) + # The inner graph is stored immutable. The default freeze (no dedup) + # keeps distinct buffers for inplace ``destroy_map`` ops; structural + # folding would alias them. See ``FunctionGraph.freeze``. + self.fgraph = inner_fgraph.freeze() + + # `compile_kwargs` used to control how the inner graph was compiled. + # That is now the job of the `ofg_inner_graph` rewrite (which + # inherits the outer compilation), so they are deprecated AND ignored: + # the inner function is compiled with default settings (see `fn`). + # `on_unused_input` is exempt: tolerating unused inputs is now the + # default behavior, so passing it is a harmless no-op (not warned). + deprecated_kwargs = {k for k in kwargs if k != "on_unused_input"} + if deprecated_kwargs: warnings.warn( - "Implicit capture of shared variables is deprecated. " - "Please provide shared variables explicitly in the 'inputs' list.", - DeprecationWarning, - stacklevel=2, + "Passing `compile_kwargs` to `OpFromGraph` is deprecated and " + "now ignored: the inner graph inherits the outer compilation. 
" + f"Ignored: {sorted(deprecated_kwargs)}.", + FutureWarning, ) - self.kwargs = kwargs self.input_types = [inp.type for inp in inputs] self.output_types = [out.type for out in outputs] @@ -416,14 +393,14 @@ def _freeze_override_to_fgraph( ] if not connected: return pattern, None - return pattern, FunctionGraph(all_inputs, connected).freeze() + return pattern, FrozenFunctionGraph.from_io(all_inputs, connected) def _freeze_override(self, override, make_dummy_args): """Freeze one override (lop/grad/rop) into a FrozenFunctionGraph.""" if override is None: return None if isinstance(override, OpFromGraph): - return override._frozen_fgraph + return override.fgraph all_inputs, callable_args = make_dummy_args() @@ -485,16 +462,18 @@ def __eq__(self, other): if type(self) is not type(other): return False if ( - self._frozen_fgraph != other._frozen_fgraph + self.fgraph != other.fgraph or self.is_inline != other.is_inline or self.destroy_map != other.destroy_map - or len(self.shared_inputs) != len(other.shared_inputs) - or any( - a is not b - for a, b in zip(self.shared_inputs, other.shared_inputs, strict=True) - ) ): return False + # Identical override objects (e.g. a clone from ``clone_with_inner_graph``) + # are equal without freezing, which would invoke callable overrides. + if ( + self.pullback_overrides is other.pullback_overrides + and self.pushforward_overrides is other.pushforward_overrides + ): + return True # When freezing overrides, skip override comparison to break infinite # recursion for self-referential overrides (e.g. Sylvester L_op). 
# The fgraph comparison above is sufficient for cache correctness @@ -509,7 +488,7 @@ def __eq__(self, other): ) def __hash__(self): - return hash((type(self), self._frozen_fgraph, self.is_inline)) + return hash((type(self), self.fgraph, self.is_inline)) def __str__(self): name = self.__class__.__name__ if self.name is None else self.name @@ -568,8 +547,13 @@ def _build_and_cache_lop_op( except KeyError: pass - inner_inputs = self.inner_inputs - inner_outputs = self.inner_outputs + # Differentiate a thawed copy of the inner graph so ``grad`` walks + # mutable ``Apply`` nodes rather than the immutable ``FrozenApply`` nodes + # of ``self.fgraph`` (whose tuple inputs/outputs break Ops that + # concatenate them, e.g. ``Blockwise.pullback``). + unfrozen_fgraph = self.fgraph.unfreeze() + inner_inputs = list(unfrozen_fgraph.inputs) + inner_outputs = list(unfrozen_fgraph.outputs) nin = len(inner_inputs) nout = len(inner_outputs) pullback_overrides = self.pullback_overrides @@ -692,8 +676,10 @@ def _build_and_cache_rop_op(self): if self._rop_op_cache is not None: return self._rop_op_cache - inner_inputs = self.inner_inputs - inner_outputs = self.inner_outputs + # Thaw the inner graph before differentiating (see ``_build_and_cache_lop_op``). 
+ unfrozen_fgraph = self.fgraph.unfreeze() + inner_inputs = list(unfrozen_fgraph.inputs) + inner_outputs = list(unfrozen_fgraph.outputs) nout = len(inner_outputs) pushforward_overrides = self.pushforward_overrides @@ -777,89 +763,17 @@ def pushforward(self, inputs, outputs, eval_points): rop_op = self._build_and_cache_rop_op() return rop_op(*inputs, *eval_points, return_list=True) - def __call__(self, *inputs, **kwargs): - # The user interface doesn't expect the shared variable inputs of the - # inner-graph, but, since `Op.make_node` does (and `Op.__call__` - # dispatches to `Op.make_node`), we need to compensate here - num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs) - - if len(inputs) == num_expected_inps: - actual_inputs = inputs + tuple(self.shared_inputs) - return super().__call__(*actual_inputs, **kwargs) - elif len(inputs) == len(self.inner_inputs): - return super().__call__(*inputs, **kwargs) - else: - raise ValueError(f"Expected at least {num_expected_inps} input(s)") - def make_node(self, *inputs): # The `inputs` received here should correspond to the inputs in the # `Apply` nodes we produce below if len(inputs) != len(self.inner_inputs): raise ValueError(f"Expected {len(self.inner_inputs)} input(s)") - num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs) - non_shared_inputs = inputs[:num_expected_inps] - - non_shared_inputs = [ + inputs = [ inp_t.filter_variable(inp) - for inp, inp_t in zip(non_shared_inputs, self.input_types, strict=True) + for inp, inp_t in zip(inputs, self.input_types, strict=True) ] - - new_shared_inputs = inputs[num_expected_inps:] - inner_and_input_shareds = list( - zip(self.shared_inputs, new_shared_inputs, strict=True) - ) - - if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds): - # The shared variables are not equal to the original shared - # variables, so we construct a new `Op` that uses the new shared - # variables instead. 
- replace = dict( - zip( - self.inner_inputs[num_expected_inps:], - new_shared_inputs, - strict=True, - ) - ) - - # If the new shared variables are inconsistent with the inner-graph, - # such errors should arise in this step - new_inner_outputs = clone_replace( - self.inner_outputs, replace=replace, copy_inputs_over=True - ) - - # It's possible that the new shared variable inputs aren't actually - # shared variables. When they aren't we need to add them as new - # inputs. - unshared_inputs = [ - inp for inp in new_shared_inputs if not isinstance(inp, SharedVariable) - ] - new_inner_inputs = self.inner_inputs[:num_expected_inps] + unshared_inputs - - new_op = type(self)( - inputs=new_inner_inputs, - outputs=new_inner_outputs, - inline=self.is_inline, - pullback=self.pullback_overrides, - pushforward=self.pushforward_overrides, - connection_pattern=self._connection_pattern, - name=self.name, - destroy_map=self.destroy_map, - **self.kwargs, - ) - new_inputs = ( - list(non_shared_inputs) + unshared_inputs + new_op.shared_inputs - ) - else: - new_op = self - new_inputs = list(non_shared_inputs) + new_op.shared_inputs - - apply_node = Apply( - new_op, - new_inputs, - [type() for type in new_op.output_types], - ) - return apply_node + return Apply(self, inputs, [type() for type in self.output_types]) def connection_pattern(self, node): """ @@ -881,8 +795,13 @@ def infer_shape(self, node, shapes): from pytensor.tensor.rewriting.shape import ShapeFeature sf = ShapeFeature() - inner_inputs = self.inner_inputs - template = [sf.shape_tuple(o) for o in self.inner_outputs] + # Build the shape graph on a thawed copy: the fresh shape nodes must + # not be built on top of the frozen inner variables (freezing such a + # mixed graph is not supported). 
+ unfrozen_fgraph = self.fgraph.unfreeze() + inner_inputs = list(unfrozen_fgraph.inputs) + inner_outputs = list(unfrozen_fgraph.outputs) + template = [sf.shape_tuple(o) for o in inner_outputs] flat_shapes = [s for tup in template if tup is not None for s in tup] # Express the inner-output shapes as a frozen function of the inner @@ -927,25 +846,79 @@ def fn(self): if getattr(self, "_fn", None) is not None: return self._fn - kwargs = self.kwargs.copy() - mode = get_mode(kwargs.pop("mode", None)).excluding("symbolic_op_recognition") - self._fn = function(self.inner_inputs, self.inner_outputs, mode=mode, **kwargs) + # ``op.fgraph`` is already backend-optimized (inplace included): the + # ``ofg_inner_graph`` rewrite ran the backend optimizer on it during the + # outer compile. So we only need to link it. The linker forces + # ``minimum_compile`` back in via its ``required_rewrites``, and (for an + # inner graph) ``minimum_compile`` *is* that inner-graph rewrite -- so we + # exclude ``compile_inner_graph`` to stop it re-baking an already-baked + # graph. ``prepare_fgraph`` still inserts the boundary deepcopies; passing + # ``fgraph=`` avoids a re-clone. Unused inputs (e.g. rng, size) and + # internal-only inplace ops are expected and tolerated. + mode = ( + get_mode(None) + .clone(optimizer="minimum_compile") + .excluding("compile_inner_graph") + ) + unfrozen_fgraph = self.fgraph.unfreeze() + self._fn = mode.function_maker( + [In(inp, borrow=True) for inp in unfrozen_fgraph.inputs], + [Out(out, borrow=True) for out in unfrozen_fgraph.outputs], + mode, + fgraph=unfrozen_fgraph, + accept_inplace=True, + on_unused_input="ignore", + ).create() self._fn.trust_input = True return self._fn @property def inner_inputs(self): - return self.fgraph.inputs + # A list (not the frozen tuple) so callers that concatenate inner + # inputs/outputs keep list semantics. Read-only views of the immutable + # graph; manipulating them requires a fresh/unfrozen graph. 
+ return list(self.fgraph.inputs) @property def inner_outputs(self): - return self.fgraph.outputs + return list(self.fgraph.outputs) def clone(self): - res = copy(self) - res.fgraph = res.fgraph.clone(clone_inner_graphs=True) - return res + # The inner graph is immutable (a frozen ``FunctionGraph``), so there is + # nothing to deep-clone -- mirror ``Composite.clone``. + return self + + def clone_with_inner_graph(self, inner_fgraph) -> OpFromGraph: + """Return a copy of this op whose inner graph is ``inner_fgraph``. + + Used by the ``ofg_inner_graph`` rewrite to bake an already-optimized inner + graph into a NEW op without mutating ``self``. + + ``inner_fgraph`` is the rewrite's optimized graph and already carries the + ordering we must keep: when it has baked inplace ops its ``DestroyHandler`` + defines a destroy-aware toposort (every reader of a destroyed buffer runs + before the op that destroys it). We therefore assign it directly -- + ``freeze`` re-roots it on nominal inputs by position and bakes that order + into the frozen graph (and each node's ``topo_idx``). Routing it through + ``construct_nominal_fgraph`` instead would rebuild and re-toposort, losing + the destroy-aware order. Mirrors ``Scan.clone_with_inner_graph``. + """ + new = copy(self) + new._fn = None + new.fgraph = ( + inner_fgraph + if isinstance(inner_fgraph, FrozenFunctionGraph) + else inner_fgraph.freeze() + ) + new.input_types = [inp.type for inp in new.fgraph.inputs] + new.output_types = [out.type for out in new.fgraph.outputs] + # Drop caches tied to the previous inner graph. 
+ new._lop_op_cache = {} + new._rop_op_cache = None + new._frozen_lop = None + new._frozen_rop = None + return new def perform(self, node, inputs, outputs): variables = self.fn(*inputs) diff --git a/pytensor/compile/debug/debugmode.py b/pytensor/compile/debug/debugmode.py index 9a9617774c..135d8ecd83 100644 --- a/pytensor/compile/debug/debugmode.py +++ b/pytensor/compile/debug/debugmode.py @@ -23,6 +23,10 @@ from pytensor.compile.maker import FunctionMaker from pytensor.compile.mode import Mode, register_mode from pytensor.compile.ops import ViewOp +from pytensor.compile.rewriting import ( + destructive_rewrite_ofg_inner_graph, + rewrite_ofg_inner_graph, +) from pytensor.configdefaults import config from pytensor.graph.basic import Variable from pytensor.graph.destroyhandler import DestroyHandler @@ -35,7 +39,13 @@ from pytensor.link.c.op import COp from pytensor.link.utils import map_storage, raise_with_op from pytensor.printing import _debugprint +from pytensor.scan.rewriting.inner_graph import ( + cvm_rewrite_scan_inner_graph, + rewrite_scan_inner_graph, +) from pytensor.tensor import TensorType +from pytensor.tensor.optimize import rewrite_optimize_inner_graph +from pytensor.tensor.rewriting.optimize import c_rewrite_optimize_inner_graph from pytensor.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function @@ -1316,7 +1326,7 @@ def printstuff(self): # 2) it a has a .clone() method # 3) it has required_rewrites and incompatible_rewrites class attributes class _DummyLinker: - required_rewrites = () + required_rewrites = ("minimum_compile",) incompatible_rewrites = () # This is not a real linker anyway @@ -1324,6 +1334,13 @@ def clone(self, allow_gc=None): return self +# DebugMode links inner functions through the C/VM machinery (``Scan.fn`` / +# ``OpFromGraph.fn``), so bake inner graphs exactly as those linkers do. 
+rewrite_scan_inner_graph.register(_DummyLinker, cvm_rewrite_scan_inner_graph) +rewrite_ofg_inner_graph.register(_DummyLinker, destructive_rewrite_ofg_inner_graph) +rewrite_optimize_inner_graph.register(_DummyLinker, c_rewrite_optimize_inner_graph) + + class _Linker(LocalLinker): """ Special debugging linker. @@ -2018,7 +2035,12 @@ def __init__( fgraph.attach_feature(equivalence_tracker) fgraph.equivalence_tracker = equivalence_tracker - optimizer(fgraph) + # Expose the compile mode to inner-graph rewrites (mirrors ``FunctionMaker``) + fgraph._compile_mode = mode + try: + optimizer(fgraph) + finally: + del fgraph._compile_mode pytensor.compile.aliasing.insert_deepcopy( fgraph, inputs, list(chain(outputs, additional_outputs)) diff --git a/pytensor/compile/maker.py b/pytensor/compile/maker.py index 0073294012..7e9b103717 100644 --- a/pytensor/compile/maker.py +++ b/pytensor/compile/maker.py @@ -469,7 +469,14 @@ def prepare_fgraph( mode=mode, traceback__limit=config.traceback__compile_limit, ): - rewriter_profile = rewriter(fgraph) + # Expose the compile mode so inner-graph rewrites can recover + # the active linker's required/incompatible rewrites reliably + # (``config.mode`` is unreliable across nested compilations). 
+ fgraph._compile_mode = mode + try: + rewriter_profile = rewriter(fgraph) + finally: + del fgraph._compile_mode end_rewriter = time.perf_counter() rewrite_time = end_rewriter - start_rewriter diff --git a/pytensor/compile/rebuild.py b/pytensor/compile/rebuild.py index 29933a51e0..f47d010802 100644 --- a/pytensor/compile/rebuild.py +++ b/pytensor/compile/rebuild.py @@ -29,7 +29,7 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, + clone_inner_graphs=None, ) -> tuple[ list[Variable], Variable, @@ -51,7 +51,7 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, + clone_inner_graphs=None, ) -> tuple[ list[Variable], list[Variable], @@ -73,7 +73,7 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, + clone_inner_graphs=None, ) -> tuple[ list[Variable], Out, @@ -95,7 +95,7 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, + clone_inner_graphs=None, ) -> tuple[ list[Variable], list[Out], @@ -116,7 +116,7 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, + clone_inner_graphs=None, ) -> tuple[ list[Variable], list[Variable] | Variable | Out | list[Out], @@ -156,12 +156,13 @@ def rebuild_collect_shared( If False (default), perform them all. Else, perform automatic updates on all Variables that are neither in "updates" nor in "no_default_updates". - clone_inner_graphs : bool - If ``True``, clone `Op`\s that are subclasses of `HasInnerGraph` and their - inner-graphs. 
""" + from pytensor.graph.basic import _warn_deprecated_clone_inner_graph + + _warn_deprecated_clone_inner_graph(clone_inner_graphs, "clone_inner_graphs") + if isinstance(outputs, tuple): outputs = list(outputs) @@ -201,7 +202,6 @@ def clone_v_get_shared_updates(v, copy_inputs_over): owner, clone_d, strict=rebuild_strict, - clone_inner_graphs=clone_inner_graphs, ) clone_d.setdefault(var, var) continue @@ -481,7 +481,6 @@ def param_to_in(param, allow_downcast=None): rebuild_strict=rebuild_strict, copy_inputs_over=True, no_default_updates=no_default_updates, - clone_inner_graphs=True, ) input_variables, cloned_extended_outputs, other_stuff = output_vars clone_d, update_d, _update_expr, shared_inputs = other_stuff diff --git a/pytensor/compile/rewriting.py b/pytensor/compile/rewriting.py new file mode 100644 index 0000000000..7c251d6b3a --- /dev/null +++ b/pytensor/compile/rewriting.py @@ -0,0 +1,224 @@ +"""Backend inner-graph rewriting: the generic baking helper and the ``OpFromGraph`` registrations and inlining.""" + +from collections import defaultdict +from functools import singledispatch + +from pytensor.compile.aliasing import ( + add_supervisor_to_fgraph, + insert_deepcopy, +) +from pytensor.compile.builders import OpFromGraph +from pytensor.compile.io import In, Out +from pytensor.compile.mode import optdb +from pytensor.graph.basic import Apply, Variable +from pytensor.graph.fg import FrozenFunctionGraph +from pytensor.graph.rewriting.basic import ( + copy_stack_trace, + dfs_rewriter, + get_active_mode, + graph_rewriter, + node_rewriter, +) +from pytensor.link.basic import PerformLinker +from pytensor.link.c.basic import CLinker, OpWiseCLinker +from pytensor.link.jax.linker import JAXLinker +from pytensor.link.mlx.linker import MLXLinker +from pytensor.link.numba.linker import NumbaLinker +from pytensor.link.pytorch.linker import PytorchLinker +from pytensor.link.vm import VMLinker + + +def rewrite_inner_graph(fgraph, match, rewrite): + """Bake the inner 
graphs of the ``match``-ing nodes for the active backend. + + An inner-graph op is matched directly or as the ``core_op`` of a `Blockwise` + (so an `OpFromGraph`/`Scan`/`Minimize` wrapped in a `Blockwise` still gets its + inner graph optimized for the backend). Nodes are grouped by ``(inner op, core + input types)`` -- the inplace/aliasing contract a ``rewrite`` bakes depends + only on the (un-batched, core-level) buffer shapes, which those types capture + -- so each distinct inner graph is prepared once and shared. For each group + ``rewrite(linker, op, node, inner, mode=...)`` mutates the unfrozen ``inner`` + graph in place (optimize + features + boundary deepcopies), deriving its own + optimizer from ``mode`` -- ``mode.optimizer`` to bake inplace, or + ``mode.excluding("inplace").optimizer`` to leave the graph functional; the new + op (re-wrapped in its `Blockwise` if needed) then replaces the nodes. + """ + from pytensor.tensor.blockwise import Blockwise + + mode = get_active_mode(fgraph) + linker = mode.linker + + def unwrap(node): + """Return ``(inner_op, inner_node, rewrap)`` for a matching node, else ``None``. + + ``inner_node`` is the node whose input types the ``rewrite`` sees; for a + `Blockwise` it is the *core* (un-batched) node, so per-node shape logic + (e.g. `Scan`'s destroyability) reasons about the core buffers. ``rewrap`` + rebuilds the outer op from a new inner op. 
+ """ + op = node.op + if isinstance(op, Blockwise) and match(op.core_op): + core_node = op._create_dummy_core_node(node.inputs) + + def rewrap(new_core_op, op=op): + return type(op)( + new_core_op, + signature=op.signature, + name=op.name, + gufunc_spec=op.gufunc_spec, + destroy_map=op.destroy_map, + ) + + return op.core_op, core_node, rewrap + if match(op): + return op, node, lambda new_op: new_op + return None + + groups: dict = defaultdict(list) + node_meta: dict = {} + for node in fgraph.apply_nodes: + if (meta := unwrap(node)) is not None: + inner_op, inner_node, _ = meta + # Ops sharing a frozen inner graph but with different destroy/view maps + # bake differently (the maps decide which taps may be destroyed and which + # boundary deepcopies are needed), so they must group separately. + # The node input types are not redundant with the op's hash/eq either: + # they can be more specific than the op's nominal types (e.g. static + # shapes), and per-node contracts like Scan's destroyability depend on them. 
+ key = ( + inner_op, + tuple(i.type for i in inner_node.inputs), + tuple((o, tuple(v)) for o, v in sorted(inner_op.destroy_map.items())), + tuple((o, tuple(v)) for o, v in sorted(inner_op.view_map.items())), + ) + groups[key].append(node) + node_meta[node] = meta + if not groups: + return + + node_to_new_op: dict = {} + for nodes in groups.values(): + rep_node = nodes[0] + inner_op, inner_node, _ = node_meta[rep_node] + inner = inner_op.fgraph.unfreeze() + # Expose the compile mode to nested inner-graph rewrites (mirrors ``FunctionMaker``) + inner._compile_mode = mode + try: + rewrite(linker, inner_op, inner_node, inner, mode=mode) + finally: + del inner._compile_mode + new_inner_op = inner_op.clone_with_inner_graph(inner) + if new_inner_op != inner_op: + for node in nodes: + node_to_new_op[node] = node_meta[node][2](new_inner_op) + + if not node_to_new_op: + return + + for node in fgraph.toposort(): + new_op = node_to_new_op.get(node) + if new_op is not None: + new_node = new_op.make_node(*node.inputs) + fgraph.replace_all( + list(zip(node.outputs, new_node.outputs, strict=True)), + reason="rewrite_inner_graph", + ) + + +@singledispatch +def rewrite_ofg_inner_graph(linker, op, node, inner, *, mode): + """Rewrite an ``OpFromGraph`` inner graph (in place) for ``linker``'s backend.""" + raise NotImplementedError( + f"Linker {type(linker).__name__} has not registered an OpFromGraph " + "inner-graph rewrite" + ) + + +def _ofg_inner_optimizer(mode): + # Recognition rewrites fold a pattern into an inner-graph op (e.g. + # ``exp(x) / sum(exp(x))`` -> ``Softmax``, itself an ``OpFromGraph``). Running + # them on an ``OpFromGraph`` inner graph -- which may *be* that pattern -- + # would re-create the op inside itself and recurse without end. 
+ return mode.excluding("symbolic_op_recognition").optimizer + + +@rewrite_ofg_inner_graph.register(VMLinker) +@rewrite_ofg_inner_graph.register(PerformLinker) +@rewrite_ofg_inner_graph.register(CLinker) +@rewrite_ofg_inner_graph.register(OpWiseCLinker) +@rewrite_ofg_inner_graph.register(NumbaLinker) +def destructive_rewrite_ofg_inner_graph(linker, op, node, inner, *, mode): + # ``OpFromGraph`` must not mutate its inputs, so all are protected; inplace may + # still be baked between purely internal buffers. + input_specs = [In(x, borrow=True, mutable=False) for x in inner.inputs] + add_supervisor_to_fgraph(fgraph=inner, input_specs=input_specs, accept_inplace=True) + _ofg_inner_optimizer(mode).rewrite(inner) + # The op's outputs must not alias its inputs or each other (it declares no + # view_map, so the outer graph cannot see such aliases); deepcopies break any + # boundary alias the optimized graph ends up with. + output_specs = [Out(o, borrow=False) for o in inner.outputs] + insert_deepcopy(inner, wrapped_inputs=input_specs, wrapped_outputs=output_specs) + + +@rewrite_ofg_inner_graph.register(JAXLinker) +@rewrite_ofg_inner_graph.register(PytorchLinker) +@rewrite_ofg_inner_graph.register(MLXLinker) +def functional_rewrite_ofg_inner_graph(linker, op, node, inner, *, mode): + """Structurally optimize the inner graph for the functional JIT backends.""" + _ofg_inner_optimizer(mode).rewrite(inner) + + +@graph_rewriter +def ofg_inner_graph(fgraph): + # ``OpWithCoreShape`` is imported lazily: at module import time + # ``pytensor.tensor.random`` is only partially initialized. ``*WithCoreShape`` + # are leaf backend ops with dedicated dispatch; re-optimizing them would loop. 
+ from pytensor.tensor.random.op import OpWithCoreShape + + rewrite_inner_graph( + fgraph, + lambda op: isinstance(op, OpFromGraph) and not isinstance(op, OpWithCoreShape), + rewrite_ofg_inner_graph, + ) + + +optdb.register( + "ofg_inner_graph", + ofg_inner_graph, + "minimum_compile", + "compile_inner_graph", + position=49.6, +) + + +def inline_ofg_node(node: Apply) -> list[Variable]: + frozen_fg: FrozenFunctionGraph = node.op.fgraph + replacements = dict(zip(frozen_fg.inputs, node.inputs)) + inlined_outs = frozen_fg.bind(replacements) + copy_stack_trace(frozen_fg.outputs, inlined_outs) + return inlined_outs + + +@node_rewriter([OpFromGraph]) +def inline_ofg_expansion(fgraph, node): + """ + This optimization expands internal graph of OpFromGraph. + Only performed if node.op.is_inline == True + Doing so can improve optimization at the cost of compilation speed. + """ + op = node.op + if not op.is_inline: + return False + + return inline_ofg_node(node) + + +# We want to run this before the first merge optimizer +# and before the first scan optimizer. 
+optdb.register( + "inline_ofg_expansion", + dfs_rewriter(inline_ofg_expansion), + "fast_compile", + "fast_run", + position=-0.01, +) diff --git a/pytensor/d3viz/formatting.py b/pytensor/d3viz/formatting.py index 5c879258cf..8dced65a5d 100644 --- a/pytensor/d3viz/formatting.py +++ b/pytensor/d3viz/formatting.py @@ -11,7 +11,7 @@ import pytensor from pytensor.compile import builders from pytensor.compile.executor import Function -from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.basic import AbstractApply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.traversal import graph_inputs from pytensor.printing import _try_pydot_import @@ -127,7 +127,7 @@ def __call__(self, fct, graph=None): else: if isinstance(fct, Variable): fct = [fct] - elif isinstance(fct, Apply): + elif isinstance(fct, AbstractApply): fct = fct.outputs assert isinstance(fct, list | tuple) assert all(isinstance(v, Variable) for v in fct) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 92d7881234..5927a22470 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -15,7 +15,6 @@ Any, Generic, Optional, - Self, TypeVar, Union, cast, @@ -108,7 +107,89 @@ def dprint(self, **kwargs): return debugprint(self, **kwargs) -class Apply(Node, Generic[OpType]): # noqa: UP046 +def _warn_deprecated_clone_inner_graph(value, name="clone_inner_graph"): + """Warn if a caller still passes the removed ``clone_inner_graph(s)`` kwarg.""" + if value is not None: + warnings.warn( + f"`{name}` is deprecated and ignored: inner-graph `Op`s are immutable, " + "so cloning always shares them.", + FutureWarning, + stacklevel=3, + ) + + +class AbstractApply(Node): + r"""Common, immutability-agnostic base for `Apply` and `FrozenApply`. + + Never instantiated directly. 
It holds the read-only structural API shared by + the mutable `Apply` and the immutable, interned `FrozenApply`: the `op`, the + `inputs`/`outputs` sequences, and the queries derived from them. Mutation + and cloning live on `Apply` alone, so code that must reject frozen nodes can + test ``isinstance(x, Apply)`` while code that only reads structure can + accept `AbstractApply`. + """ + + op: "Op" + inputs: Sequence["Variable"] + outputs: Sequence["Variable"] + tag: Scratchpad + + def default_output(self): + """ + Returns the default output for this node. + + Returns + ------- + Variable instance + An element of self.outputs, typically self.outputs[0]. + + Notes + ----- + May raise AttributeError self.op.default_output is out of range, or if + there are multiple outputs and self.op.default_output does not exist. + + """ + do = getattr(self.op, "default_output", None) + if do is None: + if len(self.outputs) == 1: + return self.outputs[0] + else: + raise ValueError( + f"Multi-output Op {self.op} default_output not specified" + ) + return self.outputs[do] + + def __str__(self): + # FIXME: The called function is too complicated for this simple use case. + return op_as_string(self.inputs, self) + + def __repr__(self): + return str(self) + + def get_parents(self): + return list(self.inputs) + + @property + def out(self): + """An alias for `self.default_output`""" + return self.default_output() + + @property + def nin(self): + """The number of inputs.""" + return len(self.inputs) + + @property + def nout(self): + """The number of outputs.""" + return len(self.outputs) + + @property + def params_type(self): + return self.op.params_type + + +class Apply(AbstractApply, Generic[OpType]): # noqa: UP046 """A `Node` representing the application of an operation to inputs. 
Basically, an `Apply` instance is an object that represents the @@ -143,6 +224,8 @@ class Apply(Node, Generic[OpType]): # noqa: UP046 """ + op: OpType + def __init__( self, op: OpType, @@ -194,70 +277,28 @@ def __getstate__(self): d["tag"] = t return d - def default_output(self): - """ - Returns the default output for this node. - - Returns - ------- - Variable instance - An element of self.outputs, typically self.outputs[0]. - - Notes - ----- - May raise AttributeError self.op.default_output is out of range, or if - there are multiple outputs and self.op.default_output does not exist. - - """ - do = getattr(self.op, "default_output", None) - if do is None: - if len(self.outputs) == 1: - return self.outputs[0] - else: - raise ValueError( - f"Multi-output Op {self.op} default_output not specified" - ) - return self.outputs[do] - - def __str__(self): - # FIXME: The called function is too complicated for this simple use case. - return op_as_string(self.inputs, self) - - def __repr__(self): - return str(self) - - def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType]": + def clone(self, clone_inner_graph=None) -> "Apply[OpType]": r"""Clone this `Apply` instance. - Parameters - ---------- - clone_inner_graph - If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs. - Returns ------- A new `Apply` instance with new outputs. Notes ----- - Tags are copied from `self` to the returned instance. + Tags are copied from `self` to the returned instance. Inner-graph `Op`\s + are immutable, so the `Op` is shared rather than deep-cloned. 
""" - from pytensor.graph.op import HasInnerGraph - - new_op = self.op - - if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore - new_op = new_op.clone() # type: ignore - + _warn_deprecated_clone_inner_graph(clone_inner_graph) cp = self.__class__( - new_op, self.inputs, [output.clone() for output in self.outputs] + self.op, self.inputs, [output.clone() for output in self.outputs] ) cp.tag = copy(self.tag) return cp def clone_with_new_inputs( - self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False + self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=None ) -> "Apply[OpType]": r"""Duplicate this `Apply` instance in a new graph. @@ -274,8 +315,6 @@ def clone_with_new_inputs( ``self.outputs``. If ``False``, then there's no guarantee that the clone's outputs will have the same types as ``self.outputs``, and cloning may not even be possible (it depends on the `Op`). - clone_inner_graph : bool - If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs. Returns ------- @@ -283,8 +322,7 @@ def clone_with_new_inputs( An `Apply` instance with the same `Op` but different outputs. 
""" - from pytensor.graph.op import HasInnerGraph - + _warn_deprecated_clone_inner_graph(clone_inner_graph) assert isinstance(inputs, list | tuple) remake_node = False new_inputs: list[Variable] = list(inputs) @@ -310,40 +348,13 @@ def clone_with_new_inputs( remake_node = True if remake_node: - new_op = self.op - - if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore - new_op = new_op.clone() # type: ignore - - new_node = new_op.make_node(*new_inputs) + new_node = self.op.make_node(*new_inputs) new_node.tag = copy(self.tag).__update__(new_node.tag) else: - new_node = self.clone(clone_inner_graph=clone_inner_graph) + new_node = self.clone() new_node.inputs = new_inputs return new_node - def get_parents(self): - return list(self.inputs) - - @property - def out(self): - """An alias for `self.default_output`""" - return self.default_output() - - @property - def nin(self): - """The number of inputs.""" - return len(self.inputs) - - @property - def nout(self): - """The number of outputs.""" - return len(self.outputs) - - @property - def params_type(self): - return self.op.params_type - class Variable(Node, Generic[_TypeType, OptionalApplyType]): # noqa: UP046 r""" @@ -464,7 +475,7 @@ def __init__( self.owner = owner - if owner is not None and not isinstance(owner, Apply): + if owner is not None and not isinstance(owner, AbstractApply): raise TypeError("owner must be an Apply instance") if index is not None and not isinstance(index, int): @@ -817,17 +828,28 @@ def __reduce_ex__(protocol): return __reduce_ex__ -class FrozenApply(Apply): - """An immutable, globally-interned Apply node for frozen graphs. +class FrozenApply(AbstractApply): + """An immutable, globally-interned application node for frozen graphs. 
+ + It deliberately does *not* subclass `Apply`: its `inputs` / `outputs` are + tuples and it has no `clone` / `clone_with_new_inputs`, so it cannot be + mutated, rebuilt, or walked by the generic clone machinery -- a frozen node + leaking into such a path fails loudly. Code that manipulates a frozen graph + must thaw it explicitly first (``FrozenFunctionGraph.unfreeze`` / ``bind``). - ``inputs`` and ``outputs`` are tuples, so mutating them raises ``TypeError``. + Instances are interned on ``(op, inputs, output_types, topo_idx)``: + constructing one with a matching key returns the cached instance. Constant + inputs are keyed by ``signature()`` (so equal-valued Constants share a node). + Other inputs are already globally interned, so identity is enough; the key + stores their ``id()`` rather than the variables themselves, keeping strong + references out of the cache so chains of ``FrozenApply`` nodes collect in a + single GC pass. - Instances are interned on ``(op, inputs, output_types)``: constructing one - with a matching key returns the cached instance. Constant inputs are keyed - by ``signature()`` (so equal-valued Constants share a node). Other inputs - are already globally interned, so identity is enough; the key stores their - ``id()`` rather than the variables themselves, keeping strong references out - of the cache so chains of ``FrozenApply`` nodes collect in a single GC pass. + ``topo_idx`` is the node's position in a baked custom toposort, or ``-1`` + when no specific order is imposed (deduplicated graphs). Keying on it keeps + two structurally equal computations at different toposort positions as + distinct interned nodes, which is how a `FrozenFunctionGraph` pins a + destroy-aware order. ``output_types`` is in the key because frozen graphs root on ``NominalVariable`` inputs (index and type only). 
Nominalizing truncates the @@ -837,17 +859,20 @@ class FrozenApply(Apply): """ _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + topo_idx: int def __new__( cls, op: "Op", inputs: tuple[Variable, ...], output_types: tuple["Type", ...], + topo_idx: int = -1, ): cache_key = ( op, tuple(i.signature() if isinstance(i, Constant) else id(i) for i in inputs), output_types, + topo_idx, ) cached = cls._cache.get(cache_key) if cached is not None: @@ -855,8 +880,9 @@ def __new__( instance = object.__new__(cls) instance.op = op - instance.inputs = inputs # type: ignore[assignment] - instance.outputs = tuple( # type: ignore[assignment] + instance.topo_idx = topo_idx + instance.inputs = inputs + instance.outputs = tuple( t.variable_type(type=t, owner=instance, index=i) for i, t in enumerate(output_types) ) @@ -868,16 +894,20 @@ def __new__( cls._cache[cache_key] = instance return instance - def __init__(self, op, inputs, output_types): + def __init__(self, op, inputs, output_types, topo_idx=-1): # All initialization is done in __new__ pass - def clone(self, clone_inner_graph: bool = False) -> Self: - """Frozen nodes are immutable — cloning returns self.""" - return self - def __reduce__(self): - return (type(self), (self.op, self.inputs, tuple(o.type for o in self.outputs))) + return ( + type(self), + ( + self.op, + self.inputs, + tuple(o.type for o in self.outputs), + self.topo_idx, + ), + ) def clone( @@ -885,7 +915,7 @@ def clone( outputs: Sequence[Variable], copy_inputs: bool = True, copy_orphans: bool | None = None, - clone_inner_graphs: bool = False, + clone_inner_graphs=None, ) -> tuple[list[Variable], list[Variable]]: r"""Copies the sub-graph contained between inputs and outputs. @@ -901,9 +931,6 @@ def clone( When ``None``, use the `copy_inputs` value. When ``True``, new orphans nodes are created. When ``False``, original orphans nodes are reused in the new graph. 
- clone_inner_graphs : bool - If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs. - Returns ------- The inputs and outputs of that copy. @@ -916,6 +943,7 @@ def clone( conditional on the `copy_orphans` parameter. """ + _warn_deprecated_clone_inner_graph(clone_inner_graphs, "clone_inner_graphs") if copy_orphans is None: copy_orphans = copy_inputs equiv = clone_get_equiv( @@ -923,7 +951,6 @@ def clone( outputs, copy_inputs=copy_inputs, copy_orphans=copy_orphans, - clone_inner_graphs=clone_inner_graphs, ) return [cast(Variable, equiv[input]) for input in inputs], [ cast(Variable, equiv[output]) for output in outputs @@ -933,14 +960,10 @@ def clone( def clone_node_and_cache( node: Apply, clone_d: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]], - clone_inner_graphs=False, **kwargs, ) -> Apply | None: """Clone an `Apply` node and cache the results in `clone_d`. - This function handles `Op` clones that are generated by inner-graph - cloning. - Returns ------- ``None`` if all of `node`'s outputs are already in `clone_d`; otherwise, @@ -952,29 +975,12 @@ def clone_node_and_cache( # `clone_d`, then there's likely no need to clone it return None - # Use a cached `Op` clone when available - new_op: Op | None = cast(Optional["Op"], clone_d.get(node.op)) - cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs] - new_node = node.clone_with_new_inputs( - cloned_inputs, - # Only clone inner-graph `Op`s when there isn't a cached clone (and - # when `clone_inner_graphs` is enabled) - clone_inner_graph=clone_inner_graphs if new_op is None else False, - **kwargs, - ) - - if new_op: - # If we didn't clone the inner-graph `Op` above, because - # there was a cached version, set the cloned `Apply` to use - # the cached clone `Op` - new_node.op = new_op + new_node = node.clone_with_new_inputs(cloned_inputs, **kwargs) clone_d[node] = new_node - clone_d.setdefault(node.op, new_node.op) - for old_o, new_o in zip(node.outputs, 
new_node.outputs, strict=True): clone_d.setdefault(old_o, new_o) @@ -988,7 +994,7 @@ def clone_get_equiv( copy_orphans: bool = True, memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] | None = None, - clone_inner_graphs: bool = False, + clone_inner_graphs=None, **kwargs, ) -> dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]: r"""Clone the graph between `inputs` and `outputs` and return a map of the cloned objects. @@ -1018,14 +1024,13 @@ def clone_get_equiv( Optionally start with a partly-filled dictionary for the return value. If a dictionary is passed, this function will work in-place on that dictionary and return it. - clone_inner_graphs - If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs. kwargs Keywords passed to `Apply.clone_with_new_inputs`. """ from pytensor.graph.traversal import toposort + _warn_deprecated_clone_inner_graph(clone_inner_graphs, "clone_inner_graphs") if memo is None: memo = {} @@ -1049,9 +1054,7 @@ def clone_get_equiv( else: memo[input] = input - clone_node_and_cache( - apply, memo, clone_inner_graphs=clone_inner_graphs, **kwargs - ) + clone_node_and_cache(apply, memo, **kwargs) # finish up by cloning any remaining outputs (it can happen) for output in outputs: diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index b43c3cdd8b..7ab6df5477 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -5,10 +5,11 @@ from collections import defaultdict from collections.abc import Iterable, Sequence, Set from functools import partial -from typing import Any, Union, cast +from typing import Any, Literal, Union, cast, overload from pytensor.configdefaults import config from pytensor.graph.basic import ( + AbstractApply, Apply, AtomicVariable, Constant, @@ -23,6 +24,7 @@ from pytensor.graph.traversal import ( applys_between, graph_inputs, + io_toposort, toposort, toposort_with_orderings, vars_between, @@ -848,13 +850,12 @@ def check_integrity(self) -> None: def __repr__(self): return 
f"FunctionGraph({', '.join(graph_as_string(self.inputs, self.outputs))})" - def clone( - self, check_integrity=True, clone_inner_graphs: bool = False - ) -> "FunctionGraph": + def clone(self, check_integrity=True, clone_inner_graphs=None) -> "FunctionGraph": """Clone the graph.""" - return self.clone_get_equiv( - check_integrity, clone_inner_graphs=clone_inner_graphs - )[0] + from pytensor.graph.basic import _warn_deprecated_clone_inner_graph + + _warn_deprecated_clone_inner_graph(clone_inner_graphs, "clone_inner_graphs") + return self.clone_get_equiv(check_integrity)[0] def clone_get_equiv( self, check_integrity: bool = True, attach_feature: bool = True, **kwargs @@ -916,10 +917,10 @@ def __getstate__(self): d.pop("_execute_callbacks_times_dict", None) return d - def __contains__(self, item: Variable | Apply) -> bool: + def __contains__(self, item: Variable | AbstractApply) -> bool: if isinstance(item, Variable): return item in self.variables - elif isinstance(item, Apply): + elif isinstance(item, AbstractApply): return item in self.apply_nodes else: raise TypeError() @@ -937,8 +938,16 @@ def dprint(self, **kwargs): return debugprint(self, **kwargs) def freeze(self) -> "FrozenFunctionGraph": - """Return a frozen, hashable version of this FunctionGraph.""" - return FrozenFunctionGraph(self.inputs, self.outputs) + """Return a frozen, hashable version of this FunctionGraph. + + The frozen graph bakes ``self.toposort()``, so when ``self`` carries a + ``DestroyHandler`` (e.g. a graph that was just inplace-rewritten) the order + is destroy-aware -- a baked inplace op runs after every reader of the buffer + it destroys -- and a backend may funcify the frozen graph as-is. 
+ """ + return FrozenFunctionGraph.from_toposort( + self.inputs, self.outputs, self.toposort() + ) class FrozenFunctionGraph(AbstractFunctionGraph): @@ -949,8 +958,9 @@ class FrozenFunctionGraph(AbstractFunctionGraph): graphs share the same interned output objects, so equality reduces to identity comparison on the outputs tuple. - Use ``FunctionGraph.freeze()`` or ``FrozenFunctionGraph(inputs, outputs)`` - to create instances. + Use ``FunctionGraph.freeze()``, `from_io`, or `from_toposort` to freeze a + graph; the constructor itself only assembles already-frozen parts (it is + the path ``__reduce__`` round-trips through). .. code-block:: python @@ -967,9 +977,62 @@ class FrozenFunctionGraph(AbstractFunctionGraph): def __init__( self, + inputs: tuple[Variable, ...], + toposort: tuple[Apply, ...], + output_nodes: tuple[Apply, ...], + ): + self.inputs: tuple[Variable, ...] = inputs + self.outputs: tuple[Variable, ...] = tuple( + node.inputs[0] for node in output_nodes + ) + self.apply_nodes: frozenset[Apply] = frozenset(toposort) + self._toposort: tuple[Apply, ...] = toposort + self._output_nodes: tuple[Apply, ...] = output_nodes + self._variables: frozenset[Variable] | None = None + self._clients: dict[Variable, list[ClientType]] | None = None + + @classmethod + def from_io( + cls, inputs: Sequence[Variable], outputs: Sequence[Variable], - ): + dedup_nodes: bool = False, + ) -> "FrozenFunctionGraph": + """Freeze the graph between ``inputs`` and ``outputs`` in plain toposort order. + + By default structurally-identical nodes are kept as distinct interned + nodes (keyed by toposort position): folding them would alias distinct + buffers that downstream inplace/``destroy_map`` logic relies on, and two + structurally-identical graphs still compare equal because positions line + up. Pass ``dedup_nodes=True`` to fold such nodes onto a single interned + node (no order is baked); only safe for graphs free of inplace ops + (e.g. ``Composite``/``ScalarLoop``). 
+ """ + return cls._freeze(inputs, outputs, io_toposort(inputs, outputs), dedup_nodes) + + @classmethod + def from_toposort( + cls, + inputs: Sequence[Variable], + outputs: Sequence[Variable], + toposort: Sequence[Apply], + ) -> "FrozenFunctionGraph": + """Freeze the graph baking ``toposort`` as the node order. + + Each node's position becomes part of its interning key, so a custom + (e.g. destroy-aware) order survives interning and pickling and a backend + may funcify the frozen graph as-is. + """ + return cls._freeze(inputs, outputs, toposort, dedup_nodes=False) + + @classmethod + def _freeze( + cls, + inputs: Sequence[Variable], + outputs: Sequence[Variable], + toposort: Sequence[Apply], + dedup_nodes: bool, + ) -> "FrozenFunctionGraph": nominal_inputs = tuple( NominalVariable(i, inp.type) for i, inp in enumerate(inputs) ) @@ -990,11 +1053,16 @@ def _resolve_input(inp, memo=memo): "or produced by Apply nodes reachable from the inputs." ) - for node in toposort(outputs, blockers=inputs): + for node_idx, node in enumerate(toposort): new_inputs = tuple(_resolve_input(inp) for inp in node.inputs) output_types = tuple(out.type for out in node.outputs) - new_node = FrozenApply(node.op, new_inputs, output_types) - sorted_apply_nodes.append(new_node) + new_node = FrozenApply( + node.op, + new_inputs, + output_types, + topo_idx=-1 if dedup_nodes else node_idx, + ) + sorted_apply_nodes.append(new_node) # type: ignore[arg-type] memo.update(zip(node.outputs, new_node.outputs, strict=True)) @@ -1021,15 +1089,11 @@ def _resolve_input(inp, memo=memo): else: output_nodes.append(FrozenApply(Output(i), (resolved,), ())) - self.inputs: tuple[Variable, ...] = nominal_inputs - self.outputs: tuple[Variable, ...] = tuple( - node.inputs[0] for node in output_nodes + return cls( + nominal_inputs, + tuple(sorted_apply_nodes), + tuple(output_nodes), # type: ignore[arg-type] ) - self.apply_nodes: frozenset[Apply] = frozenset(sorted_apply_nodes) - self._toposort: tuple[Apply, ...] 
= tuple(sorted_apply_nodes) - self._output_nodes: tuple[Apply, ...] = tuple(output_nodes) - self._variables: frozenset[Variable] | None = None - self._clients: dict[Variable, list[ClientType]] | None = None @classmethod def from_structural_inputs( @@ -1058,12 +1122,23 @@ def from_structural_inputs( roots = [ v for v in graph_inputs([*inputs, *outputs]) if not isinstance(v, Constant) ] - interned = cls(roots, [*inputs, *outputs]) + # Structural matching requires deduplication: each intermediate input + # expression must intern onto the same node as its occurrences in the + # outputs so they can be rewired (position-keyed interning would keep + # them distinct). + interned = cls.from_io(roots, [*inputs, *outputs], dedup_nodes=True) n_inputs = len(inputs) - return cls(interned.outputs[:n_inputs], interned.outputs[n_inputs:]) + return cls.from_io( + interned.outputs[:n_inputs], + interned.outputs[n_inputs:], + dedup_nodes=True, + ) def __reduce__(self): - return FrozenFunctionGraph, (self.inputs, self.outputs) + # Plain reassembly: every part re-interns itself on unpickling + # (``FrozenApply``/``NominalVariable`` interning), restoring the + # canonical graph with its baked order and dedup keys. + return FrozenFunctionGraph, (self.inputs, self._toposort, self._output_nodes) def __hash__(self): return hash(self._output_nodes) @@ -1105,13 +1180,28 @@ def clients(self) -> dict[Variable, list[ClientType]]: # type: ignore[override] self._clients = clients return self._clients + @overload + def bind( + self, + replace: Sequence[Variable] | dict[Variable, Variable], + *, + return_memo: Literal[False] = ..., + ) -> list[Variable]: ... + @overload def bind( - self, replace: Sequence[Variable] | dict[Variable, Variable] - ) -> list[Variable]: + self, + replace: Sequence[Variable] | dict[Variable, Variable], + *, + return_memo: Literal[True], + ) -> tuple[list[Variable], dict[Any, Any]]: ... 
+ def bind(self, replace, *, return_memo: bool = False): """Return fresh outputs with root inputs substituted per *replace*. Constants are reused; any non-Constant input not in *replace* raises KeyError. + With ``return_memo=True``, also return the memo mapping each frozen + variable and Apply node to its rebuilt counterpart. """ + memo: dict[Any, Any] if isinstance(replace, dict): memo = replace.copy() else: @@ -1129,14 +1219,33 @@ def bind( [memo[i] for i in node.inputs], [o.type() for o in node.outputs], ) + memo[node] = new_node memo.update(zip(node.outputs, new_node.outputs)) - return [out if isinstance(out, Constant) else memo[out] for out in self.outputs] - - def unfreeze(self) -> "FunctionGraph": - """Return a mutable FunctionGraph with fresh mutable Apply nodes.""" + for out in self.outputs: + if isinstance(out, Constant): + memo.setdefault(out, out) + outputs = [memo[out] for out in self.outputs] + if return_memo: + return outputs, memo + return outputs + + @overload + def unfreeze(self, *, return_memo: Literal[False] = ...) -> "FunctionGraph": ... + @overload + def unfreeze( + self, *, return_memo: Literal[True] + ) -> tuple["FunctionGraph", dict[Any, Any]]: ... + def unfreeze(self, *, return_memo: bool = False): + """Return a mutable FunctionGraph with fresh mutable Apply nodes. + + With ``return_memo=True``, also return the memo mapping each frozen + variable and Apply node to its mutable counterpart. 
+ """ fresh_inputs = [inp.type() for inp in self.inputs] - return FunctionGraph( - fresh_inputs, - self.bind(dict(zip(self.inputs, fresh_inputs))), - clone=False, + outputs, memo = self.bind( + dict(zip(self.inputs, fresh_inputs)), return_memo=True ) + fgraph = FunctionGraph(fresh_inputs, outputs, clone=False) + if return_memo: + return fgraph, memo + return fgraph diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 001ccc8162..d4c2cbbf3f 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -131,6 +131,21 @@ def print_profile(cls, stream, prof, level=0): ) +def get_active_mode(fgraph): + """Return the compile `Mode` currently being used to rewrite `fgraph`. + + Set by `FunctionMaker` around the optimization pass. Used by inner-graph + rewrites to recover the active linker's required/incompatible rewrites. + Falls back to the default mode when called outside a compilation. + """ + from pytensor.compile.mode import get_mode + from pytensor.configdefaults import config + + if (active := getattr(fgraph, "_compile_mode", None)) is not None: + return get_mode(active) + return get_mode(config.mode) + + class NodeRewriter(Rewriter): """A `Rewriter` that is applied to an `Apply` node.""" diff --git a/pytensor/graph/traversal.py b/pytensor/graph/traversal.py index 50b3359ff8..eacf58a0f3 100644 --- a/pytensor/graph/traversal.py +++ b/pytensor/graph/traversal.py @@ -12,7 +12,7 @@ overload, ) -from pytensor.graph.basic import Apply, Constant, Node, Variable +from pytensor.graph.basic import AbstractApply, Apply, Constant, Node, Variable T = TypeVar("T", bound=Node) @@ -340,7 +340,7 @@ def apply_depends_on(apply: Apply, depends_on: Apply | Iterable[Apply]) -> bool: bool """ - if isinstance(depends_on, Apply): + if isinstance(depends_on, AbstractApply): depends_on = frozenset((depends_on,)) else: depends_on = frozenset(depends_on) @@ -683,7 +683,7 @@ def toposort_with_orderings( def compute_deps(obj, 
blocker_set=frozenset(blockers), orderings=orderings): if obj in blocker_set: return None - if isinstance(obj, Apply): + if isinstance(obj, AbstractApply): return [*obj.inputs, *orderings.get(obj, [])] else: if (apply := obj.owner) is not None: @@ -694,7 +694,7 @@ def compute_deps(obj, blocker_set=frozenset(blockers), orderings=orderings): # mypy doesn't like conditional functions with different signatures, # but passing the globals as optional is faster def compute_deps(obj, orderings=orderings): # type: ignore[misc] - if isinstance(obj, Apply): + if isinstance(obj, AbstractApply): return [*obj.inputs, *orderings.get(obj, [])] else: if (apply := obj.owner) is not None: @@ -706,7 +706,7 @@ def compute_deps(obj, orderings=orderings): # type: ignore[misc] apply for apply in walk_toposort(graphs, deps=compute_deps) # mypy doesn't understand that our generator will return both Apply and Variables - if isinstance(apply, Apply) # type: ignore + if isinstance(apply, AbstractApply) # type: ignore ) diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py index c1240bba31..51037e499b 100644 --- a/pytensor/link/jax/dispatch/basic.py +++ b/pytensor/link/jax/dispatch/basic.py @@ -7,7 +7,6 @@ import numpy as np from pytensor.compile.builders import OpFromGraph -from pytensor.compile.mode import JAX from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.configdefaults import config from pytensor.graph import Constant @@ -128,8 +127,6 @@ def type_cast(x): def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable: _ = kwargs.pop("storage_map", None) - # Apply inner rewrites - JAX.optimizer(ofg.fgraph) fgraph_fn = jax_funcify(ofg.fgraph, **kwargs) if len(ofg.fgraph.outputs) == 1: diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index c4c24f0000..defce26a96 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -4,7 +4,6 @@ import numpy as np 
from jax._src.lax.control_flow import scan as jax_scan -from pytensor.compile.mode import JAX, get_mode from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.scan.op import Scan @@ -27,14 +26,6 @@ def jax_funcify_Scan(op: Scan, node, **kwargs): if info.as_while: raise NotImplementedError("While Scan cannot yet be converted to JAX") - # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode) - rewriter = ( - get_mode(op.mode) - .including("jax") - .excluding("numba", *JAX._optimizer.exclude) - .optimizer - ) - rewriter(op.fgraph) scan_inner_func = jax_funcify(op.fgraph, **kwargs) def scan(*outer_inputs): diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index 71a2232f85..6fce1ed466 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -20,7 +20,8 @@ get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.shape import Shape_i +from pytensor.tensor.shape import Shape, Shape_i +from pytensor.tensor.subtensor import Subtensor ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange` to be constants. @@ -62,13 +63,18 @@ def jax_funcify_ARange(op, node, **kwargs): arange_args = node.inputs constant_args = [] for arg in arange_args: - if arg.owner and isinstance(arg.owner.op, Shape_i): - constant_args.append(None) - elif isinstance(arg, Constant): - constant_args.append(arg.value) - else: - # TODO: This might be failing without need (e.g., if arg = shape(x)[-1] + 1)! 
- raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR) + # Under JAX tracing an array's shape is concrete, so any element of it is a + # valid ``arange`` bound + match arg.owner_op_and_inputs: + case (Shape_i(), *_): + constant_args.append(None) + case (Subtensor(), shape_var, *_) if isinstance(shape_var.owner_op, Shape): + constant_args.append(None) + case _ if isinstance(arg, Constant): + constant_args.append(arg.value) + case _: + # TODO: This might be failing without need (e.g., if arg = shape(x)[-1] + 1)! + raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR) constant_start, constant_stop, constant_step = constant_args diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py index cda5d66937..4cabe672af 100644 --- a/pytensor/link/mlx/dispatch/basic.py +++ b/pytensor/link/mlx/dispatch/basic.py @@ -7,7 +7,6 @@ import numpy as np from pytensor.compile.builders import OpFromGraph -from pytensor.compile.mode import MLX from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.graph import Constant from pytensor.graph.fg import AbstractFunctionGraph @@ -222,7 +221,6 @@ def assert_fn(x, *inputs): def mlx_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs): _ = kwargs.pop("storage_map", None) - MLX.optimizer(ofg.fgraph) fgraph_fn = mlx_funcify(ofg.fgraph, squeeze_output=True, **kwargs) return fgraph_fn diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index 6c9db9afd6..9d662308cf 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -4,6 +4,7 @@ class MLXLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX.""" + required_rewrites = ("minimum_compile",) incompatible_rewrites = ( "cxx_only", "BlasOpt", diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 74f8dda91a..e8ea1f0abd 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ 
b/pytensor/link/numba/dispatch/compile_ops.py @@ -5,10 +5,7 @@ import numba import numpy as np -from pytensor.compile.aliasing import add_supervisor_to_fgraph, insert_deepcopy from pytensor.compile.builders import OpFromGraph -from pytensor.compile.io import In, Out -from pytensor.compile.mode import NUMBA from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.ifelse import IfElse from pytensor.link.numba.cache import compile_numba_function_src @@ -53,7 +50,6 @@ def string_deepcopy(x): def numba_funcify_OpFromGraph( op, node=None, - mode=NUMBA.excluding("symbolic_op_recognition"), ofg_memo=None, **kwargs, ): @@ -62,22 +58,11 @@ def numba_funcify_OpFromGraph( if ofg_memo is not None and op in ofg_memo: return ofg_memo[op] - # Apply inner rewrites - # TODO: Not sure this is the right place to do this, should we have a rewrite that - # explicitly triggers the optimization of the inner graphs of OpFromGraph? - # The C-code defers it to the make_thunk phase - fgraph = op.fgraph - input_specs = [In(x, borrow=True, mutable=False) for x in fgraph.inputs] - add_supervisor_to_fgraph( - fgraph=fgraph, - input_specs=input_specs, - accept_inplace=True, - ) - mode.optimizer(fgraph) - output_specs = [Out(o, borrow=False) for o in fgraph.outputs] - insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) + # The inner graph is already optimized and inplace/deepcopy-baked by the + # ``ofg_inner_graph`` rewrite (numba contract in ``pytensor.tensor.rewriting``), + # so we funcify ``op.fgraph`` directly. 
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key( - fgraph, + op.fgraph, squeeze_output=True, fgraph_name="numba_ofg", ofg_memo=ofg_memo, diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 0a7b5a0c09..d956525265 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -5,14 +5,7 @@ from numba import types from numba.extending import overload -from pytensor.compile.aliasing import ( - add_supervisor_to_fgraph, - alias_root, - insert_deepcopy, -) -from pytensor.compile.io import In, Out -from pytensor.compile.mode import NUMBA, get_mode -from pytensor.graph.features import NoOutputInplaceOnInput +from pytensor.compile.aliasing import alias_root from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( @@ -61,262 +54,24 @@ def range_arr(x): @register_funcify_and_cache_key(Scan) def numba_funcify_Scan(op: Scan, node, **kwargs): - """Generate a Numba implementation of a `Scan` loop. - - Memory-aliasing contract - ------------------------ - Scan defines a loop over an inner function with signature: - (*sequences[idx], *traced[idx], *untraced, *non_sequences) - -> (*traced_updates[idx], *untraced_updates) - - Traced variables are read from an indexed circular buffer at every iteration, - and the updates stored (copied) back to it immediately after. Untraced variables - are carried by reference, with each update becoming the next iteration's input. - - Scan is sometimes allowed to destroy/alias the outer traced and untraced variables, - but never sequences and non-sequences. Specifically, outer untraced variables can be - destroyed (destroy_map, opt in) or aliased (view_map, default). Traced variables can be - destroyed (destroy_map, opt in), but otherwise not alias (never in view_map). 
- Note that destroy permission implies alias permission, but not the other way around. - - Scan is not allowed to return outputs that alias each other, unless they were already - aliased from the outside, and it was itself allowed to alias/destroy them. This means - PyTensor already gauged it was safe to destroy/alias them. - - Scan has some freedom in how this outer contract is respected. If needed, it can - deepcopy the outer inputs once at the start, or make sure any aliased output - the inner function returns is properly copied before the final return. - - Internally, Scan also has total control over the boundary memory management of the - inner function: it grants the permissions to destroy or alias the inner loop inputs, - and whether the inner outputs may alias each other. This inner boundary is distinct - from the outer contract above, and it is Scan's responsibility to choose an inner - strategy that produces correct results while still respecting the outer contract. - - Memory-aliasing strategy - ------------------------ - Traced variables - ~~~~~~~~~~~~~~~~ - Traced variables are deep-copied once at the start if they are not in the destroy_map. - - Because every inner trace update is copied back to the buffer immediately, the inner function - is allowed to alias (but not destroy) the sequence reads, non_sequences, as well as the - traced and untraced inputs or updates, when producing the traced updates. - - A special case occurs when the indexed reads will be immediately overwritten by the updates - in the same loop iteration. For single output taps (mit-sot, sit-sot) this can only - happen when the circular buffer is truncated to its minimum legal length. - For mit-mot this can also happen without any buffer loop-around. - In either case, traced updates are not allowed to alias those traced reads, - as they may otherwise be corrupted if the reads are updated before they were copied to their own buffer. 
- - On the plus side, when this happens, the inner function is granted permission to destroy these - immediately-to-be discarded reads, as long as the returned updates do not themselves alias them. - - The alias-restriction and destroy-permission caused by the loop-around behavior are derived from the - buffer's static length: - * known large enough: no overwrite is possible, neither alias restriction nor destroy permission applies; - * length unknown: the loop-around overwrite can't be ruled out, alias restricted but not granted destroy permission; - * known minimal: the overwrite is certain, alias restricted but granted destroy permission. - - Untraced variables - ~~~~~~~~~~~~~~~~~~ - Untraced variables are deep-copied once at the start if they are not in destroy_map - and the inner function destroys them. - - Untraced updates are allowed to alias their own untraced inputs (which happens when n_steps=0) - or when the inner function update naturally alias the input (eg, o = i; o = i.T; o = i[::-1]). - - Because the last untraced updates are returned as is, the inner function is not allowed to - alias sequences, non_sequences, or other untraced inputs and outputs (violates the outer alias restriction). - Untraced updates are not allowed to alias traced reads (risks corruption by subsequent overwrittes), - but can alias traced updates, since the immediate copy to their buffer that follows, will break the alias. - - Because untraced inputs are immediately discarded (and protected from alias with other updates), - the inner function is always granted permission to destroy them. It can do so from any computation, - not only the one producing the matching untraced update. - - Controlling inner graph alias - ----------------------------- - PyTensor allows initial graphs to contain arbitrary (non-destructive) aliasing. - Alias at the boundary (output aliasing an input or another output) is controlled via - targeted deepcopies at the end (using the insert_deepcopy helper). 
- - In contrast, destruction is usually NOT allowed to be present in initial graphs. - Destructive alias at the boundary is controlled during rewrites, with the following features: - * Supervisor: Checks whether any protected input are destroyed - * NoOutputInplaceOnInput: Checks whether an output is destroying a non-protected input - (protected inputs are already covered by Supervisor) - Inside the boundary: - * DestroyHandler: Checks whether a consistent ordering exists for the destruction/view chains, - i.e., every read runs before its buffers' destruction and the chain has no cycle. - These features can veto (undo) any rewrite that would violate their spec. - They CANNOT fix violations that already existed in the initial graph. - - """ - # Apply inner rewrites - # TODO: Not sure this is the right place to do this, should we have a rewrite that - # explicitly triggers the optimization of the inner graphs of Scan? - # The C-code defers it to the make_thunk phase - rewriter = ( - get_mode(op.mode) - .including("numba") - .excluding(*NUMBA._optimizer.exclude) - .optimizer - ) - fgraph = op.fgraph - - # A traced read is "overwritten" when a same-iteration write reuses its physical buffer - # slot. With buffer length L, output tap o_out writes the slot read by input tap o_in iff - # (o_out - o_in) % L == 0. Since L is always >= reach (the minimal admissible length, the - # oldest lookback), only two gaps can trigger it: - # - gap 0 (o_out == o_in): the recurrence writes the very slot it just read (mit_mot - # accumulators: read g[k], write g[k] + delta back). Holds for ANY L, so the read is - # *certainly* overwritten this iteration. - # - gap == reach: a write lands just past the oldest read; in a buffer truncated to its - # minimum (L == reach) it wraps onto that just-discarded oldest read. 
So the read is - # overwritten only at L == reach: certain when the static length says so, merely possible - # when the length is statically unknown, impossible for any larger L. (A write landing - # further ahead would store onto a slot still to be read -- an invalid recurrence we - # need not consider, just as we don't consider L < reach.) - # A certainly-overwritten read is dead once consumed this iteration, so the inner function - # may destroy it in place. A possibly-overwritten read must not be destroyed (it may still - # be live at a larger length), but no output may alias it either, since a same-iteration - # overwrite would corrupt the alias before it is stored. - def find_overwritten_reads(grouped_inner, outers, in_slices_seq, out_slices_seq): - certain, possible = [], [] - for inner_vars, outer, in_slices, out_slices in zip( - grouped_inner, outers, in_slices_seq, out_slices_seq, strict=True - ): - reach = -min(0, min(in_slices)) # minimal admissible buffer length - in_offsets = [reach + t for t in in_slices] - out_offsets = [reach + t for t in out_slices] - static_len = outer.type.shape[0] - for v, o_in in zip(inner_vars, in_offsets, strict=True): - gaps = {o_out - o_in for o_out in out_offsets} - if 0 in gaps: - certain.append(v) - elif reach in gaps: - if static_len == reach: - certain.append(v) - elif static_len is None: - possible.append(v) - return certain, possible - - mit_mot_certain, mit_mot_possible = find_overwritten_reads( - op.inner_mitmot_grouped(fgraph.inputs), - op.outer_mitmot(node.inputs), - op.info.mit_mot_in_slices, - op.info.mit_mot_out_slices, - ) - mit_sot_certain, mit_sot_possible = find_overwritten_reads( - op.inner_mitsot_grouped(fgraph.inputs), - op.outer_mitsot(node.inputs), - op.info.mit_sot_in_slices, - [(0,)] * op.info.n_mit_sot, - ) - sit_sot_certain, sit_sot_possible = find_overwritten_reads( - [[v] for v in op.inner_sitsot(fgraph.inputs)], - op.outer_sitsot(node.inputs), - op.info.sit_sot_in_slices, - [(0,)] * 
op.info.n_sit_sot, - ) - - # Reads no output may alias (destructively or not): a same-iteration overwrite could corrupt - # the alias before it is copied to its own buffer. - potentially_overwritten_reads = [ - *mit_mot_certain, - *mit_mot_possible, - *mit_sot_certain, - *mit_sot_possible, - *sit_sot_certain, - *sit_sot_possible, - ] - - # Reads the inner function may destroy in place: the certainly-overwritten traced reads - # (dead once consumed this iteration) and every untraced input (always immediately discarded). - discarded = { - *mit_mot_certain, - *mit_sot_certain, - *sit_sot_certain, - *op.inner_untraced_sit_sot(fgraph.inputs), + """Generate a Numba implementation of a `Scan` loop.""" + # Outer untraced_sit_sot outputs whose inner input the baked step fn destroys; + # when not owned, the codegen copies the outer input on the first iteration. + destroyed_roots = { + alias_root(inner_node.inputs[pos]) + for inner_node in op.fgraph.apply_nodes + for positions in inner_node.op.destroy_map.values() + for pos in positions } - # Grant the inner function the right to alias (borrow=True) all inputs and to destroy - # (mutable=True) the reads known to be immediately discarded. - input_specs = [In(x, borrow=True, mutable=x in discarded) for x in fgraph.inputs] - add_supervisor_to_fgraph( - fgraph=fgraph, - input_specs=input_specs, - accept_inplace=True, - ) - - if potentially_overwritten_reads: - # Forbid any output from aliasing these reads. We could instead allow it and patch with a - # deepcopy at the end (as we do for non-destructive boundary alias), but that is wasteful: - # forbidding it makes the inner graph allocate a fresh buffer and write the result there. 
- in_pos = {v: i for i, v in enumerate(fgraph.inputs)} - fgraph.attach_feature( - NoOutputInplaceOnInput([in_pos[t] for t in potentially_overwritten_reads]) - ) - - # Rewrite graph - rewriter(fgraph) - - # Post-patch alias contract via targeted deepcopies - # Traced and untraced updates: copy an alias of a tap input the loop (may) overwrite - # this iteration. - # Untraced updates: keep as is if it is the FIRST viewer of a freshly produced buffer - # (including traced updates which will be copied immediately anyway), OR its own untraced input. - # Any other alias is broken with a deepcopy: Sequences reads, traced reads, non-sequences, - # other untraced inputs or updates. - # Note: We could squeeze some more memory reuse by delaying the breaking of aliasing between untraced variables - # by delaying the patched deepcopy until after the loop is over. This requires some care to handle alias transitions - # between untraced updates that can happen over multiple iterations, and protect against cross-iteration destruction - # that can corrupt such chains. 
- own_untraced_input = dict( - zip( - op.inner_untraced_sit_sot_outs(fgraph.outputs), - op.inner_untraced_sit_sot(fgraph.inputs), + untraced_inputs_destroyed_by_inner_function = { + outer_out_idx + for inner_inp, (outer_out_idx, _) in zip( + op.inner_untraced_sit_sot(op.fgraph.inputs), + op.outer_untraced_sit_sot_outs(node.outputs, with_idx=True), strict=True, ) - ) - untraced_outs = set(own_untraced_input) - seen_untraced_roots = set() - output_specs = [] - for update in fgraph.outputs: - root = alias_root(update) - if update in untraced_outs: - borrow = ( - ( - # freshly produced buffer - root.owner is not None - # or a self alias - or root is own_untraced_input[update] - ) - # and not an alias of another untraced update - and root not in seen_untraced_roots - ) - if borrow: - seen_untraced_roots.add(root) - else: - # traced update - borrow = root not in potentially_overwritten_reads - output_specs.append(Out(update, borrow=borrow)) - insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) - - # Collect a set of untraced slots the inner function destroys in place. - # These may demand an initial copy if the Scan is not granted permission to destroy them already. 
- untraced_inputs_destroyed_by_inner_function = set() - if hasattr(fgraph, "destroyers"): - untraced_inputs_destroyed_by_inner_function = { - outer_out_idx - for inner_in, (outer_out_idx, _) in zip( - op.inner_untraced_sit_sot(fgraph.inputs), - op.outer_untraced_sit_sot_outs(node.outputs, with_idx=True), - strict=True, - ) - if fgraph.destroyers(inner_in) - } + if inner_inp in destroyed_roots + } scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key( op.fgraph, fgraph_name="numba_scan", ofg_memo=kwargs.get("ofg_memo") diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 89a2b4ec4a..995d0a3bec 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -4,10 +4,7 @@ import numpy as np import torch -from pytensor import In -from pytensor.compile.aliasing import add_supervisor_to_fgraph from pytensor.compile.builders import OpFromGraph -from pytensor.compile.mode import PYTORCH from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.graph.basic import Constant from pytensor.graph.fg import AbstractFunctionGraph @@ -191,15 +188,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs): @pytorch_funcify.register(OpFromGraph) def pytorch_funcify_OpFromGraph(op, node, **kwargs): kwargs.pop("storage_map", None) - # Apply inner rewrites - PYTORCH.optimizer(op.fgraph) - fgraph = op.fgraph - add_supervisor_to_fgraph( - fgraph=fgraph, - input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs], - accept_inplace=True, - ) - PYTORCH.optimizer(fgraph) fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) return fgraph_fn diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 88fb8d7407..7e22830c4a 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -5,6 +5,7 @@ class PytorchLinker(JITLinker): """A `Linker` that compiles NumPy-based operations using torch.compile.""" + 
required_rewrites = ("minimum_compile",) incompatible_rewrites = ( "cxx_only", "BlasOpt", diff --git a/pytensor/printing.py b/pytensor/printing.py index cd9b9bab1f..fc09634f54 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -20,7 +20,7 @@ from pytensor.compile.executor import Function from pytensor.compile.io import In, Out from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.basic import AbstractApply, Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import HasInnerGraph, Op, StorageMapType from pytensor.graph.traversal import graph_inputs, toposort @@ -681,7 +681,7 @@ def _show_inner_graph(op): profile_list.append(None) storage_maps.append(None) topo_orders.append(None) - elif isinstance(obj, Apply): + elif isinstance(obj, AbstractApply): outputs_to_print.extend(obj.outputs) profile_list.extend(None for item in obj.outputs) storage_maps.extend(None for item in obj.outputs) @@ -853,22 +853,14 @@ def _show_inner_graph(op): continue else: printed_inner_graph_ops.add(ig_var.owner.op) - # This is a work-around to maintain backward compatibility - # (e.g. to only print inner graphs that have been compiled through - # a call to `Op.prepare_node`) - inner_fn = getattr(ig_var.owner.op, "_fn", None) - - if inner_fn: - # If the op was compiled, print the optimized version. - inner_inputs = inner_fn.maker.fgraph.inputs - inner_outputs = inner_fn.maker.fgraph.outputs + # ``Elemwise``/``Blockwise`` hold their inner graph on ``scalar_op`` + # (a ``Composite``/``ScalarLoop``); other ops expose it directly. 
+ if hasattr(ig_var.owner.op, "scalar_op"): + inner_inputs = ig_var.owner.op.scalar_op.inner_inputs + inner_outputs = ig_var.owner.op.scalar_op.inner_outputs else: - if hasattr(ig_var.owner.op, "scalar_op"): - inner_inputs = ig_var.owner.op.scalar_op.inner_inputs - inner_outputs = ig_var.owner.op.scalar_op.inner_outputs - else: - inner_inputs = ig_var.owner.op.inner_inputs - inner_outputs = ig_var.owner.op.inner_outputs + inner_inputs = ig_var.owner.op.inner_inputs + inner_outputs = ig_var.owner.op.inner_outputs outer_inputs = ig_var.owner.inputs @@ -1319,17 +1311,12 @@ def _assign_color(node_key) -> str: continue printed.add(ig_var.owner) - inner_fn = getattr(ig_var.owner.op, "_fn", None) - if inner_fn: - inner_inputs = inner_fn.maker.fgraph.inputs - inner_outputs = inner_fn.maker.fgraph.outputs + if hasattr(ig_var.owner.op, "scalar_op"): + inner_inputs = ig_var.owner.op.scalar_op.inner_inputs + inner_outputs = ig_var.owner.op.scalar_op.inner_outputs else: - if hasattr(ig_var.owner.op, "scalar_op"): - inner_inputs = ig_var.owner.op.scalar_op.inner_inputs - inner_outputs = ig_var.owner.op.scalar_op.inner_outputs - else: - inner_inputs = ig_var.owner.op.inner_inputs - inner_outputs = ig_var.owner.op.inner_outputs + inner_inputs = ig_var.owner.op.inner_inputs + inner_outputs = ig_var.owner.op.inner_outputs outer_inputs = ig_var.owner.inputs inner_to_outer: dict[Variable, Variable] | None @@ -2076,7 +2063,7 @@ def pydotprint( else: if isinstance(fct, Variable): fct = [fct] - elif isinstance(fct, Apply): + elif isinstance(fct, AbstractApply): fct = fct.outputs assert isinstance(fct, list | tuple) assert all(isinstance(v, Variable) for v in fct) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 08375ac0bc..ae03530fa6 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -26,6 +26,7 @@ from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FrozenFunctionGraph from pytensor.graph.op import 
HasInnerGraph, Op +from pytensor.graph.replace import clone_replace from pytensor.graph.traversal import applys_between from pytensor.graph.type import HasDataType, HasShape from pytensor.graph.utils import MetaObject, MethodNotDefined @@ -4204,6 +4205,7 @@ def __getstate__(self): rval = dict(self.__dict__) rval.pop("_c_code", None) rval.pop("_py_perform_fn", None) + rval.pop("_name", None) rval.pop("prepare_node_called", None) return rval @@ -4222,6 +4224,8 @@ class Composite(ScalarInnerGraphOp): """ + _name = None + def __init__( self, inputs, @@ -4229,12 +4233,13 @@ def __init__( name="Composite", ): self.name = name - self._name = None for i in inputs: assert i not in outputs # This isn't supported, use identity - self.fgraph = FrozenFunctionGraph(inputs, outputs) + # Composite inner graphs have no inplace ops, so structurally-identical + # nodes can be safely deduplicated. + self.fgraph = FrozenFunctionGraph.from_io(inputs, outputs, dedup_nodes=True) self._validate_inner_graph(self.fgraph) self.inputs = self.fgraph.inputs @@ -4271,21 +4276,21 @@ def make_node(self, *inputs): if self.inputs_type == tuple(i.type for i in inputs): return super().make_node(*inputs) else: - # Make a new op with the right input types. + # Make a new op whose inner graph is rebuilt on fresh inputs of the + # new types. The retype needs ``rebuild_strict=False`` (re-infers + # each node's output types), which in-place ``FunctionGraph`` + # replacements cannot do, so thaw first and rebuild the mutable copy. 
assert len(inputs) == self.nin - fg = self.fgraph - res = pytensor.compile.rebuild_collect_shared( - fg.outputs, - replace=dict(zip(fg.inputs, inputs, strict=True)), + unfrozen_fgraph = self.fgraph.unfreeze() + new_inner_inputs = [i.type() for i in inputs] + new_outputs = clone_replace( + unfrozen_fgraph.outputs, + replace=dict( + zip(unfrozen_fgraph.inputs, new_inner_inputs, strict=True) + ), rebuild_strict=False, ) - # After rebuild_collect_shared, the Variable in inputs - # are not necessarily in the graph represented by res. - # res[2][0] is a dict that map from the original variable to the - # cloned variable. - cloned_inputs = [res[2][0][i] for i in inputs] - node = Composite(cloned_inputs, res[1]).make_node(*inputs) - return node + return Composite(new_inner_inputs, new_outputs).make_node(*inputs) def perform(self, node, inputs, output_storage): outputs = self.py_perform_fn(*inputs) diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 1a4b30a008..79cbcd130a 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -1,9 +1,9 @@ from collections.abc import Sequence from itertools import chain -from pytensor.compile.rebuild import rebuild_collect_shared from pytensor.graph.basic import Constant, Variable, clone from pytensor.graph.fg import FrozenFunctionGraph +from pytensor.graph.replace import clone_replace from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar @@ -61,7 +61,9 @@ def __init__( self.is_while = until is not None - self.fgraph = FrozenFunctionGraph(inputs, outputs) + # ScalarLoop inner graphs have no inplace ops, so structurally-identical + # nodes can be safely deduplicated. 
+ self.fgraph = FrozenFunctionGraph.from_io(inputs, outputs, dedup_nodes=True) self._validate_inner_graph(self.fgraph) self.inputs = self.fgraph.inputs self.outputs = self.fgraph.outputs @@ -123,30 +125,34 @@ def make_node(self, n_steps, *inputs): if self.inputs_type == tuple(i.type for i in inputs): return super().make_node(n_steps, *inputs) else: - # Make a new op with the right input types. - fg = self.fgraph - res = rebuild_collect_shared( - fg.outputs, - replace=dict(zip(fg.inputs, inputs, strict=True)), + # Make a new op whose inner graph is rebuilt on fresh inputs of the + # new types. The retype needs ``rebuild_strict=False`` (re-infers + # each node's output types), which in-place ``FunctionGraph`` + # replacements cannot do, so thaw first and rebuild the mutable copy. + unfrozen_fgraph = self.fgraph.unfreeze() + new_inner_inputs = [i.type() for i in inputs] + new_outputs = clone_replace( + unfrozen_fgraph.outputs, + replace=dict( + zip(unfrozen_fgraph.inputs, new_inner_inputs, strict=True) + ), rebuild_strict=False, ) if self.is_while: - *cloned_update, cloned_until = res[1] + *new_update, new_until = new_outputs else: - cloned_update, cloned_until = res[1], None - cloned_inputs = [res[2][0][i] for i in inputs] - cloned_init = cloned_inputs[: len(cloned_update)] - cloned_constant = cloned_inputs[len(cloned_update) :] - # This will fail if the cloned init have a different dtype than the cloned_update + new_update, new_until = new_outputs, None + new_init = new_inner_inputs[: len(new_update)] + new_constant = new_inner_inputs[len(new_update) :] + # This will fail if the new init have a different dtype than the new update op = ScalarLoop( - init=cloned_init, - update=cloned_update, - constant=cloned_constant, - until=cloned_until, + init=new_init, + update=new_update, + constant=new_constant, + until=new_until, name=self.name, ) - node = op.make_node(n_steps, *inputs) - return node + return op.make_node(n_steps, *inputs) def perform(self, node, inputs, 
output_storage): n_steps, *inputs = inputs diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index b4ba5d6608..11c584a724 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -425,6 +425,15 @@ def f(x): Pass this to `pytensor.function` when compiling your function. """ + if mode is not None: + warnings.warn( + "The `mode` argument of `scan` is deprecated: the inner graph now " + "inherits the outer compilation (see the `optimize_inner_graphs` " + "rewrite). It is still honored for now.", + FutureWarning, + stacklevel=2, + ) + # General observation : this code is executed only once, at creation # of the computational graph, so we don't yet need to be smart about # anything (to speed things up) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 859978230b..f9b46ce045 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -75,12 +75,11 @@ Variable, ) from pytensor.graph.features import NoOutputFromInplace -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern from pytensor.graph.replace import clone_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.type import HasShape -from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.link.vm import VMLinker from pytensor.printing import op_debug_information from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new @@ -502,6 +501,112 @@ def outer_untraced_sit_sot_outs(self, list_outputs, with_idx=False): else: return res + def inner_destroyable_inputs(self, outer_inputs, inner_inputs): + """Inner inputs the step function may safely destroy in place. 
+ + Destroyability depends on the *outer* node's buffer shapes, so this is a + per-node property (two nodes sharing a `Scan` op but with different outer + buffers can differ): + + - A sit_sot tap whose outer buffer holds a single state (``shape[0] == 1``): + the buffer always discards the oldest state, so destroying it is safe. + - The oldest mit_sot tap when the outer buffer holds exactly the taps + (``shape[0] == abs(min(taps))``): same reasoning. + - Every untraced sit_sot is physically destroyable: after the first + iteration the input is the previous output (safe to destroy). On the first + iteration it aliases the outer buffer; each backend handles that per its + memory model -- numba always destroys and copies the first iteration, while + the C/VM rewrite only keeps the destroy when the Scan owns the buffer. + + ``mit_mot`` taps are not included here; the numba inner-graph rewrite grants + destroying their certainly-overwritten reads on top of this set, but the C + backend cannot. + """ + destroyable_sitsot = [ + inner_sitsot + for outer_sitsot, inner_sitsot in zip( + self.outer_sitsot(outer_inputs), + self.inner_sitsot(inner_inputs), + strict=True, + ) + if outer_sitsot.type.shape[0] == 1 + ] + destroyable_mitsot = [ + oldest_inner_mitsot + for outer_mitsot, oldest_inner_mitsot, taps in zip( + self.outer_mitsot(outer_inputs), + self.oldest_inner_mitsot(inner_inputs), + self.info.mit_sot_in_slices, + strict=True, + ) + if outer_mitsot.type.shape[0] == abs(min(taps)) + ] + destroyable_untraced_sit_sot = self.inner_untraced_sit_sot(inner_inputs) + return { + *destroyable_sitsot, + *destroyable_mitsot, + *destroyable_untraced_sit_sot, + } + + def _preallocated_mitmot_updates(self): + """Map inner-output index to inner-input index for mit_mot taps that are both. 
+ + With output preallocation these outputs are wrapped as updates that write + back (possibly in place) into the corresponding input buffer, so -- unlike + the other tap outputs -- they are *allowed* to be the result of an in-place + operation. `prepare_fgraph` uses this as the inner ``update_mapping``. + """ + info = self.info + updates = {} + input_idx = info.n_seqs + output_idx_base = 0 + for in_slices, out_slices in zip( + info.mit_mot_in_slices, info.mit_mot_out_slices, strict=True + ): + for inp_tap in in_slices: + if inp_tap in out_slices: + updates[output_idx_base + out_slices.index(inp_tap)] = input_idx + input_idx += 1 + output_idx_base += len(out_slices) + return updates + + def protected_inner_out_idxs(self, preallocated_mitmot_outs=None): + """Inner-output indices that must not be the result of an in-place op. + + These are the tap outputs (mit_mot / mit_sot / sit_sot / nit_sot) whose + buffers the VM reuses across iterations; a protected output computed by a + destroy-map node would alias a value still needed elsewhere. Preallocated + mit_mot updates are excluded -- they are *meant* to write back into their + input buffer. This is the protection installed as `NoOutputFromInplace` + both at link time (`prepare_fgraph`) and when baking inplace into the + frozen inner graph (`scan_inner_graph`), so the two agree. + """ + if preallocated_mitmot_outs is None: + preallocated_mitmot_outs = ( + self._preallocated_mitmot_updates() + if config.scan__allow_output_prealloc + else () + ) + info = self.info + n_taps = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot + prealloc = set(preallocated_mitmot_outs) + return tuple(i for i in range(n_taps) if i not in prealloc) + + def inner_owned_untraced_sit_sot(self, inner_inputs): + """Inner untraced sit_sot inputs the Scan owns (output index in ``destroy_map``). + + Ownership grants the right to destroy the outer initial buffer, so these may + be destroyed in place even on the first iteration. 
+ """ + untraced_start = self.n_tap_outs + self.info.n_nit_sot + return { + inner_untraced + for j, inner_untraced in enumerate( + self.inner_untraced_sit_sot(inner_inputs) + ) + if untraced_start + j in self.destroy_map + } + def inner_non_seqs(self, list_inputs): n_taps_upto_sit_sot = sum( len(x) @@ -761,6 +866,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): """ + fgraph: FrozenFunctionGraph + def __init__( self, inputs: list[Variable], @@ -842,12 +949,14 @@ def __init__( If ``True``, all the shared variables used in the inner-graph must be provided. """ - self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs) + # ``construct_nominal_fgraph`` raises ``MissingInputError`` if the inner + # graph implicitly depends on any non-input, non-constant variable. + inner_fgraph = construct_nominal_fgraph(inputs, outputs) - # The shared variables should have been removed, so, if there are - # any, it's because the user didn't specify an input. - if shared_inputs: - raise MissingInputError(f"Scan is missing inputs: {shared_inputs}") + # The inner graph is stored immutable. The default freeze (no dedup) + # keeps distinct buffers for inplace ``destroy_map`` ops; structural + # folding would alias them. See ``FunctionGraph.freeze``. + self.fgraph = inner_fgraph.freeze() self.info = info self.truncate_gradient = truncate_gradient @@ -945,17 +1054,13 @@ def tensorConstructor(shape, dtype): self.n_outer_inputs = info.n_outer_inputs self.n_outer_outputs = info.n_outer_outputs - if any(node.op.destroy_map for node in self.fgraph.apply_nodes): - raise InconsistencyError( - "Inner-graphs must not contain in-place operations." - ) - - self._frozen_fgraph = self.fgraph.freeze() - def __setstate__(self, d): self.__dict__.update(d) - if not hasattr(self, "_frozen_fgraph"): - self._frozen_fgraph = self.fgraph.freeze() + # Back-compat: older pickles stored a mutable inner ``fgraph`` (plus a + # separate ``_frozen_fgraph``). Collapse to the single frozen graph. 
+ if not isinstance(self.fgraph, FrozenFunctionGraph): + self.fgraph = self.fgraph.freeze() + self.__dict__.pop("_frozen_fgraph", None) # Ensure that the graph associated with the inner function is valid. self.validate_inner_graph() @@ -1334,7 +1439,7 @@ def __eq__(self, other): if self.allow_gc != other.allow_gc: return False - return self._frozen_fgraph == other._frozen_fgraph + return self.fgraph == other.fgraph def __str__(self): inplace = "none" @@ -1354,7 +1459,7 @@ def __hash__(self): return hash( ( type(self), - self._frozen_fgraph, + self.fgraph, self.info, self.profile, self.truncate_gradient, @@ -1376,59 +1481,45 @@ def prepare_fgraph(self, fgraph): # remove those outputs here just to compensate for an overly rigid # `Function` pipeline. update_mapping = {} - preallocated_mitmot_outs = [] if config.scan__allow_output_prealloc: - # Go through the mitmots. Whenever a mitmot has a tap both as an - # input and an output, wrap the input such that the corresponding - # output variable becomes an update to be performed on it, possibly - # inplace at the end of the functions's execution. + # Whenever a mitmot has a tap both as an input and an output, wrap the + # input such that the corresponding output variable becomes an update to + # be performed on it, possibly inplace at the end of the function's + # execution. 
wrapped_inputs = [In(x, borrow=False) for x in fgraph.inputs[: info.n_seqs]] - input_idx = info.n_seqs - for mitmot_idx in range(info.n_mit_mot): - for inp_tap in info.mit_mot_in_slices[mitmot_idx]: - if inp_tap in info.mit_mot_out_slices[mitmot_idx]: - inp = fgraph.inputs[input_idx] - - # Figure out the index of the corresponding output - output_idx = sum( - len(m) for m in info.mit_mot_out_slices[:mitmot_idx] - ) - output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap) - - preallocated_mitmot_outs.append(output_idx) - - wrapped_inp = In( - variable=inp, - update=fgraph.outputs[output_idx], - ) - update_mapping[output_idx] = input_idx - wrapped_inputs.append(wrapped_inp) - else: - wrapped_inputs.append( - In(fgraph.inputs[input_idx], borrow=False) - ) - input_idx += 1 + update_mapping = self._preallocated_mitmot_updates() + input_updates = { + input_idx: output_idx + for output_idx, input_idx in update_mapping.items() + } + mitmot_inps_end = info.n_seqs + sum(len(s) for s in info.mit_mot_in_slices) + for input_idx in range(info.n_seqs, mitmot_inps_end): + inp = fgraph.inputs[input_idx] + output_idx = input_updates.get(input_idx) + if output_idx is not None: + wrapped_inputs.append( + In(variable=inp, update=fgraph.outputs[output_idx]) + ) + else: + wrapped_inputs.append(In(inp, borrow=False)) # Wrap the inputs not associated to mitmots and wrap the remaining outputs. - # Untraced sit_sot inputs that are in the destroy_map are marked mutable. + # Untraced sit_sot inputs the Scan owns (in the destroy_map) are marked mutable. 
untraced_sit_sot_inner_inputs = set( self.inner_untraced_sit_sot(fgraph.inputs) ) - untraced_out_start = self.n_tap_outs + info.n_nit_sot - mutable_untraced_inner_inputs = { - self.inner_untraced_sit_sot(fgraph.inputs)[j] - for j in range(info.n_untraced_sit_sot) - if untraced_out_start + j in self.destroy_map - } + mutable_untraced_inner_inputs = self.inner_owned_untraced_sit_sot( + fgraph.inputs + ) wrapped_inputs += [ In( x, borrow=x in untraced_sit_sot_inner_inputs, mutable=x in mutable_untraced_inner_inputs, ) - for x in fgraph.inputs[input_idx:] + for x in fgraph.inputs[mitmot_inps_end:] ] wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs[:slices]] # Untraced sit_sot states are kept by reference across iterations, so @@ -1436,16 +1527,7 @@ def prepare_fgraph(self, fgraph): # lets insert_deepcopy break such aliasing). See issue #2252. wrapped_outputs += [Out(x, borrow=False) for x in fgraph.outputs[slices:]] - protected_outs = tuple( - i - for i in range( - info.n_mit_mot_outs - + info.n_mit_sot - + info.n_sit_sot - + info.n_nit_sot - ) - if i not in preallocated_mitmot_outs - ) + protected_outs = self.protected_inner_out_idxs(update_mapping) fgraph.attach_feature(NoOutputFromInplace(protected_outs)) else: @@ -1466,7 +1548,11 @@ def fn(self): if getattr(self, "_fn", None) is not None: return self._fn - wrapped_inputs, wrapped_outputs = self.prepare_fgraph(self.fgraph) + # Compile a throwaway copy of the (already math-optimized) inner graph. + # The canonical inner graph is immutable; linking setup (MIT-MOT update + # wrapping, supervisor) and any inplace happen on this transient. 
+ inner_fgraph = self.fgraph.unfreeze() + wrapped_inputs, wrapped_outputs = self.prepare_fgraph(inner_fgraph) profile = None if config.profile or ( @@ -1479,19 +1565,32 @@ def fn(self): elif self.profile: profile = self.profile - # Clone mode_instance, altering "allow_gc" for the linker, - # and adding a message if we profile + # ``inner_fgraph`` is already backend-optimized (inplace included): + # ``scan_inner_graph`` ran the backend optimizer on it during the outer + # compile. So we only need to link it (``prepare_fgraph`` still inserts the + # boundary deepcopies). The linker forces ``minimum_compile`` back in via + # its ``required_rewrites``, and (for an inner graph) ``minimum_compile`` + # *is* that inner-graph rewrite -- so we exclude ``compile_inner_graph`` to + # stop it re-baking an already-baked graph. Only the linker choice depends + # on ``self.mode``. mode = self.mode if mode in (None, "FAST_RUN"): - mode_instance = Mode("cvm", "fast_run") + mode_instance = Mode("cvm", "minimum_compile").excluding( + "compile_inner_graph" + ) elif mode == "FAST_COMPILE": mode_instance = Mode( - VMLinker(use_cloop=False, c_thunks=False), "fast_compile" - ) + VMLinker(use_cloop=False, c_thunks=False), "minimum_compile" + ).excluding("compile_inner_graph") else: - mode_instance = get_mode(mode).clone( - link_kwargs=dict(allow_gc=self.allow_gc), - message=f"{self.name or 'Scan'} sub profile", + mode_instance = ( + get_mode(mode) + .clone( + optimizer="minimum_compile", + link_kwargs=dict(allow_gc=self.allow_gc), + message=f"{self.name or 'Scan'} sub profile", + ) + .excluding("compile_inner_graph") ) # Scan python and cython perform relies on the VM being able to set updates for preallocated MIT-MOT, # which only the VMs produced by VMLinker do @@ -1505,26 +1604,52 @@ def fn(self): wrapped_inputs, wrapped_outputs, mode=mode_instance, - accept_inplace=False, + # The (already-optimized) inner graph may carry inplace ops baked in + # by scan_inner_graph; prepare_fgraph has 
already attached the + # DestroyHandler + Supervisor, so accept them here. + accept_inplace=True, profile=profile, on_unused_input="ignore", - fgraph=self.fgraph, + fgraph=inner_fgraph, ).create() return self._fn @property def inner_inputs(self): - return self.fgraph.inputs + # A list (not the frozen tuple) so the many ``inner_*`` slicing helpers + # and their callers keep list semantics. These are read-only views of the + # immutable graph; rewrites that rebuild a Scan must ``unfreeze`` first. + return list(self.fgraph.inputs) @property def inner_outputs(self): - return self.fgraph.outputs + return list(self.fgraph.outputs) def clone(self) -> "Scan": - res = copy(self) - res.fgraph = res.fgraph.clone(clone_inner_graphs=True) # type: ignore[attr-defined] - return res + # The inner graph is immutable (a frozen ``FunctionGraph``), so there is + # nothing to deep-clone -- mirror ``Composite.clone``. + return self + + def clone_with_inner_graph(self, inner_fgraph) -> "Scan": + """Return a copy of this `Scan` whose inner graph is ``inner_fgraph``. + + Used by the ``scan_inner_graph`` rewrite to bake an already-optimized inner + graph into a NEW immutable op without touching ``self``. All inner-graph- + derived state (``output_types``/``mintaps``/``view_map``/``mitmots_preallocated``) + comes from ``info`` + the output types, neither of which optimization changes, + so ``copy`` preserves it; only the inner graph and the compiled ``_fn`` are + swapped. ``inner_fgraph`` is frozen as-is (when it carries a ``DestroyHandler`` + the frozen toposort is destroy-aware), so no rebuild/re-toposort is needed. 
+ """ + clone = copy(self) + clone._fn = None + clone.fgraph = ( + inner_fgraph + if isinstance(inner_fgraph, FrozenFunctionGraph) + else inner_fgraph.freeze() + ) + return clone def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): """ @@ -2309,26 +2434,32 @@ def infer_shape(self, node, input_shapes): inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:] assert len(inner_ins_shapes) == len(self.inner_inputs) - # Non-sequences have a direct equivalent from self.inner_inputs in + # Build the shape graph on a thawed copy: the fresh shape nodes must not + # be built on top of the frozen inner variables. + unfrozen_fgraph = self.fgraph.unfreeze() + inner_inputs = list(unfrozen_fgraph.inputs) + inner_outputs = list(unfrozen_fgraph.outputs) + + # Non-sequences have a direct equivalent from the inner inputs in # node.inputs - inner_non_sequences = self.inner_inputs[len(seqs_shape) + len(outs_shape) :] + inner_non_sequences = inner_inputs[len(seqs_shape) + len(outs_shape) :] out_equivalent.update( zip(inner_non_sequences, node.inputs[offset:], strict=True) ) if info.as_while: - self_outs = self.inner_outputs[:-1] + self_outs = inner_outputs[:-1] else: - self_outs = self.inner_outputs + self_outs = inner_outputs outs_shape = infer_shape( - outs=self_outs, inputs=self.inner_inputs, input_shapes=inner_ins_shapes + outs=self_outs, inputs=inner_inputs, input_shapes=inner_ins_shapes ) # Will be used to check if outs_shape can be expressed without using - # variables in self.inner_inputs. + # variables in the inner inputs. # The shapes of node.inputs are valid. 
validator = Validator( valid=input_shapes, - invalid=self.inner_inputs, + invalid=inner_inputs, valid_equivalent=out_equivalent, ) @@ -2474,8 +2605,13 @@ def pullback(self, inputs, outs, dC_douts): if self.truncate_gradient != -1: grad_steps = minimum(grad_steps, self.truncate_gradient) - self_inputs = self.inner_inputs - self_outputs = self.inner_outputs + # Differentiate a thawed copy of the inner graph so ``grad`` walks + # mutable ``Apply`` nodes rather than the immutable ``FrozenApply`` nodes + # of ``self.fgraph`` (whose tuple inputs/outputs break Ops that + # concatenate them). + unfrozen_fgraph = self.fgraph.unfreeze() + self_inputs = list(unfrozen_fgraph.inputs) + self_outputs = list(unfrozen_fgraph.outputs) # differentiable inputs diff_inputs = ( self.inner_seqs(self_inputs) @@ -3231,12 +3367,14 @@ def compute_all_gradients(known_grads): def pushforward(self, inputs, outputs, eval_points): # Step 0. Prepare some shortcut variable info = self.info - self_inputs = self.inner_inputs + # Thaw the inner graph before differentiating (see ``pullback``). + unfrozen_fgraph = self.fgraph.unfreeze() + self_inputs = list(unfrozen_fgraph.inputs) + self_outputs = list(unfrozen_fgraph.outputs) rop_of_inputs = ( self_inputs[: info.n_seqs + self.n_tap_outs] + self_inputs[info.n_seqs + self.n_tap_outs + info.n_untraced_sit_sot :] ) - self_outputs = self.inner_outputs # Step 1. 
Compute the R_op of the inner function inner_eval_points = [safe_new(x, "_evalpoint") for x in rop_of_inputs] @@ -3511,22 +3649,12 @@ def _op_debug_information_Scan(op: Scan, node: Apply): extra_information = {} - inner_fn = getattr(op, "_fn", None) - - if inner_fn: - inner_inputs = inner_fn.maker.fgraph.inputs - inner_outputs = inner_fn.maker.fgraph.outputs - else: - inner_inputs = op.inner_inputs - inner_outputs = op.inner_outputs - scan_args = ScanArgs( node.inputs, node.outputs, - inner_inputs, - inner_outputs, + op.inner_inputs, + op.inner_outputs, node.op.info, - clone=False, ) for field_name in scan_args.field_names: diff --git a/pytensor/scan/rewriting/__init__.py b/pytensor/scan/rewriting/__init__.py index 19fab29ccf..b547cc996b 100644 --- a/pytensor/scan/rewriting/__init__.py +++ b/pytensor/scan/rewriting/__init__.py @@ -1,3 +1,8 @@ +# Register the per-linker ``rewrite_scan_inner_graph`` implementations consulted by +# the ``scan_inner_graph`` rewrite. Imported here (the linker classes are lightweight +# -- they do not pull in their runtimes) so they take effect whenever +# ``pytensor.scan`` is used. 
+import pytensor.scan.rewriting.inner_graph from pytensor.scan.rewriting.db import ( ScanEquilibriumGraphRewriter, scan_eqopt1, diff --git a/pytensor/scan/rewriting/db.py b/pytensor/scan/rewriting/db.py index 4155df8dd4..301e448fde 100644 --- a/pytensor/scan/rewriting/db.py +++ b/pytensor/scan/rewriting/db.py @@ -7,9 +7,15 @@ """ from pytensor.compile import optdb -from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter, dfs_rewriter +from pytensor.compile.rewriting import rewrite_inner_graph +from pytensor.graph.rewriting.basic import ( + EquilibriumGraphRewriter, + dfs_rewriter, + graph_rewriter, +) from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB from pytensor.scan.op import Scan +from pytensor.scan.rewriting.inner_graph import rewrite_scan_inner_graph from pytensor.scan.rewriting.inplace import ScanInplaceOptimizer from pytensor.scan.rewriting.io import ( scan_inline_invariant_constants, @@ -234,3 +240,27 @@ def apply(self, fgraph, start_from=None): "fast_run", "scan", ) + + +# Bake each ``Scan`` node's inner graph for the active backend (the per-linker +# ``rewrite_scan_inner_graph``; impls in ``pytensor.scan.rewriting.inner_graph``), +# producing new immutable ops. Tagged ``minimum_compile`` because a baked inner graph +# is required for correct linking (e.g. numba funcifies it as-is). +# +# Runs after the inplace passes (notably ``scan_make_inplace`` at 50.5): a tap may +# only be destroyed in place when the outer Scan owns its buffer, so the inner graph +# can only be baked correctly once the outer destroy/view permissions are final. 
+@graph_rewriter +def scan_inner_graph(fgraph): + rewrite_inner_graph( + fgraph, lambda op: isinstance(op, Scan), rewrite_scan_inner_graph + ) + + +optdb.register( + "scan_inner_graph", + scan_inner_graph, + "minimum_compile", + "compile_inner_graph", + position=100, +) diff --git a/pytensor/scan/rewriting/inner_graph.py b/pytensor/scan/rewriting/inner_graph.py new file mode 100644 index 0000000000..49dc0bb96a --- /dev/null +++ b/pytensor/scan/rewriting/inner_graph.py @@ -0,0 +1,319 @@ +"""Backend inner-graph rewrites for ``Scan``. + +``rewrite_scan_inner_graph`` bakes a node's inner graph for the active backend; +the ``scan_inner_graph`` pass in ``pytensor.scan.rewriting.db`` drives it through +``pytensor.compile.rewriting.rewrite_inner_graph``. +""" + +from functools import singledispatch + +from pytensor.compile.aliasing import ( + add_supervisor_to_fgraph, + alias_root, + insert_deepcopy, +) +from pytensor.compile.io import In, Out +from pytensor.compile.mode import get_mode +from pytensor.configdefaults import config +from pytensor.graph.features import NoOutputFromInplace, NoOutputInplaceOnInput +from pytensor.link.basic import PerformLinker +from pytensor.link.c.basic import CLinker, OpWiseCLinker +from pytensor.link.jax.linker import JAXLinker +from pytensor.link.mlx.linker import MLXLinker +from pytensor.link.numba.linker import NumbaLinker +from pytensor.link.pytorch.linker import PytorchLinker +from pytensor.link.vm import VMLinker + + +@singledispatch +def rewrite_scan_inner_graph(linker, op, node, inner, *, mode): + """Rewrite a ``Scan`` inner graph (in place) for ``linker``'s backend.""" + raise NotImplementedError( + f"Linker {type(linker).__name__} has not registered a Scan inner-graph rewrite" + ) + + +def scan_inner_optimizer(op, mode): + """Optimizer to run on a ``Scan`` inner graph. 
+ + ``mode.optimizer`` unless the op carries a (deprecated) custom ``mode``, in + which case that mode is combined with the active linker's required/incompatible + rewrites so backend must-have ops still apply. + """ + custom_mode = getattr(op, "mode", None) + if custom_mode is None: + return mode.optimizer + linker = mode.linker + return ( + get_mode(custom_mode) + .including(*linker.required_rewrites) + .excluding(*linker.incompatible_rewrites) + .optimizer + ) + + +@rewrite_scan_inner_graph.register(VMLinker) +@rewrite_scan_inner_graph.register(PerformLinker) +@rewrite_scan_inner_graph.register(CLinker) +@rewrite_scan_inner_graph.register(OpWiseCLinker) +def cvm_rewrite_scan_inner_graph(linker, op, node, inner, *, mode): + """Bake the tap inplace and protect the tap outputs ``Scan.prepare_fgraph`` does. + + Leaves the boundary deepcopies to ``Scan.fn`` (which re-optimizes the inner + graph at ``make_thunk``). + """ + # Grant aliasing/destruction of immediately-discarded taps; protect the same tap + # outputs ``Scan.prepare_fgraph`` does so the baked inplace never makes a + # protected output the result of a destroy-map node -- otherwise ``Scan.fn``, + # which re-attaches ``NoOutputFromInplace`` and re-optimizes, would reject it. + # ``prepare_fgraph`` only installs that protection with output preallocation. + destroyable = op.inner_destroyable_inputs(node.inputs, inner.inputs) + # Unlike numba, the C/VM backends don't copy an untraced state on the first + # iteration, so they may destroy it only when the Scan owns the outer buffer; + # otherwise the in-place step would corrupt the caller's input. Drop the unowned + # untraced states from the destroyable set. 
+ destroyable -= set(op.inner_untraced_sit_sot(inner.inputs)) - ( + op.inner_owned_untraced_sit_sot(inner.inputs) + ) + input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in inner.inputs] + add_supervisor_to_fgraph(fgraph=inner, input_specs=input_specs, accept_inplace=True) + if config.scan__allow_output_prealloc: + inner.attach_feature(NoOutputFromInplace(op.protected_inner_out_idxs())) + scan_inner_optimizer(op, mode).rewrite(inner) + + +@rewrite_scan_inner_graph.register(JAXLinker) +@rewrite_scan_inner_graph.register(PytorchLinker) +@rewrite_scan_inner_graph.register(MLXLinker) +def functional_rewrite_scan_inner_graph(linker, op, node, inner, *, mode): + """Structurally optimize the inner graph for the functional JIT backends.""" + scan_inner_optimizer(op, mode).rewrite(inner) + + +def find_overwritten_reads(op, outer_inputs, inner_inputs): + """Inner traced reads a same-iteration write may overwrite, split by certainty. + + A traced read is "overwritten" when a same-iteration write reuses its physical + buffer slot. With buffer length ``L``, output tap ``o_out`` writes the slot read + by input tap ``o_in`` iff ``(o_out - o_in) % L == 0``. Since ``L`` is always + ``>= reach`` (the minimal admissible length, the oldest lookback), only two gaps + can trigger it: + + - gap 0 (``o_out == o_in``): the recurrence writes the very slot it just read + (mit_mot accumulators: read ``g[k]``, write ``g[k] + delta`` back). Holds for + ANY ``L``, so the read is *certainly* overwritten this iteration. + - gap == reach: a write lands just past the oldest read; in a buffer truncated to + its minimum (``L == reach``) it wraps onto that just-discarded oldest read. So + the read is overwritten only at ``L == reach``: certain when the static length + says so, merely possible when the length is statically unknown, impossible for + any larger ``L``. 
(A write landing further ahead would store onto a slot still + to be read -- an invalid recurrence we need not consider, just as we don't + consider ``L < reach``.) + + A certainly-overwritten read is dead once consumed this iteration, so the inner + function may destroy it in place. A possibly-overwritten read must not be destroyed + (it may still be live at a larger length), but no output may alias it either, since + a same-iteration overwrite would corrupt the alias before it is stored. Depends on + the *outer* buffer's static length, hence a per-node property. + """ + + def find(grouped_inner, outers, in_slices_seq, out_slices_seq): + certain, possible = [], [] + for inner_vars, outer, in_slices, out_slices in zip( + grouped_inner, outers, in_slices_seq, out_slices_seq, strict=True + ): + reach = -min(0, min(in_slices)) # minimal admissible buffer length + in_offsets = [reach + t for t in in_slices] + out_offsets = [reach + t for t in out_slices] + static_len = outer.type.shape[0] + for v, o_in in zip(inner_vars, in_offsets, strict=True): + gaps = {o_out - o_in for o_out in out_offsets} + if 0 in gaps: + certain.append(v) + elif reach in gaps: + if static_len == reach: + certain.append(v) + elif static_len is None: + possible.append(v) + return certain, possible + + info = op.info + mit_mot_certain, mit_mot_possible = find( + op.inner_mitmot_grouped(inner_inputs), + op.outer_mitmot(outer_inputs), + info.mit_mot_in_slices, + info.mit_mot_out_slices, + ) + mit_sot_certain, mit_sot_possible = find( + op.inner_mitsot_grouped(inner_inputs), + op.outer_mitsot(outer_inputs), + info.mit_sot_in_slices, + [(0,)] * info.n_mit_sot, + ) + sit_sot_certain, sit_sot_possible = find( + [[v] for v in op.inner_sitsot(inner_inputs)], + op.outer_sitsot(outer_inputs), + info.sit_sot_in_slices, + [(0,)] * info.n_sit_sot, + ) + certain = {*mit_mot_certain, *mit_sot_certain, *sit_sot_certain} + possible = {*mit_mot_possible, *mit_sot_possible, *sit_sot_possible} + return certain, 
possible + + +@rewrite_scan_inner_graph.register(NumbaLinker) +def numba_rewrite_scan_inner_graph(linker, op, node, inner, *, mode): + """Bake inplace, alias bans and boundary deepcopies for the numba backend. + + Numba is the only backend that consumes inner *tap* inplace, so the memory + model that governs it lives here: the rewrite bakes the inplace, bans + destroying reads the loop overwrites, and inserts the boundary deepcopies, + leaving ``op.fgraph`` ready for ``numba_funcify_Scan`` to funcify with no + further graph work. + + Memory-aliasing contract + ------------------------ + Scan defines a loop over an inner function with signature: + (*sequences[idx], *traced[idx], *untraced, *non_sequences) + -> (*traced_updates[idx], *untraced_updates) + + Traced variables are read from an indexed circular buffer at every iteration, + and the updates stored (copied) back to it immediately after. Untraced variables + are carried by reference, with each update becoming the next iteration's input. + + Scan is sometimes allowed to destroy/alias the outer traced and untraced variables, + but never sequences and non-sequences. Specifically, outer untraced variables can be + destroyed (destroy_map, opt in) or aliased (view_map, default). Traced variables can be + destroyed (destroy_map, opt in), but otherwise not aliased (never in view_map). + Note that destroy permission implies alias permission, but not the other way around. + + Scan is not allowed to return outputs that alias each other, unless they were already + aliased from the outside, and it was itself allowed to alias/destroy them. This means + PyTensor already gauged it was safe to destroy/alias them. + + Scan has some freedom in how this outer contract is respected. If needed, it can + deepcopy the outer inputs once at the start, or make sure any aliased output + the inner function returns is properly copied before the final return. 
+ + Internally, Scan also has total control over the boundary memory management of the + inner function: it grants the permissions to destroy or alias the inner loop inputs, + and whether the inner outputs may alias each other. This inner boundary is distinct + from the outer contract above, and it is Scan's responsibility to choose an inner + strategy that produces correct results while still respecting the outer contract. + + Memory-aliasing strategy + ------------------------ + Traced variables + ~~~~~~~~~~~~~~~~ + Traced variables are deep-copied once at the start if they are not in the destroy_map. + + Because every inner trace update is copied back to the buffer immediately, the inner function + is allowed to alias (but not destroy) the sequence reads, non_sequences, as well as the + traced and untraced inputs or updates, when producing the traced updates. + + A special case occurs when the indexed reads will be immediately overwritten by the updates + in the same loop iteration. For single output taps (mit-sot, sit-sot) this can only + happen when the circular buffer is truncated to its minimum legal length. + For mit-mot this can also happen without any buffer loop-around. + In either case, traced updates are not allowed to alias those traced reads, + as they may otherwise be corrupted if the reads are updated before they were copied to their own buffer. + + On the plus side, when this happens, the inner function is granted permission to destroy these + immediately-to-be discarded reads, as long as the returned updates do not themselves alias them. 
+ + The alias-restriction and destroy-permission caused by the loop-around behavior are derived from the + buffer's static length: + * known large enough: no overwrite is possible, neither alias restriction nor destroy permission applies; + * length unknown: the loop-around overwrite can't be ruled out, alias restricted but not granted destroy permission; + * known minimal: the overwrite is certain, alias restricted but granted destroy permission. + + Untraced variables + ~~~~~~~~~~~~~~~~~~ + Untraced variables are deep-copied once at the start if they are not in destroy_map + and the inner function destroys them. + + Untraced updates are allowed to alias their own untraced inputs (which happens when n_steps=0) + or when the inner function update naturally aliases the input (eg, o = i; o = i.T; o = i[::-1]). + + Because the last untraced updates are returned as is, the inner function is not allowed to + alias sequences, non_sequences, or other untraced inputs and outputs (violates the outer alias restriction). + Untraced updates are not allowed to alias traced reads (risks corruption by subsequent overwrites), + but can alias traced updates, since the immediate copy to their buffer that follows will break the alias. + + Because untraced inputs are immediately discarded (and protected from alias with other updates), + the inner function is always granted permission to destroy them. It can do so from any computation, + not only the one producing the matching untraced update. + + Controlling inner graph alias + ----------------------------- + PyTensor allows initial graphs to contain arbitrary (non-destructive) aliasing. + Alias at the boundary (output aliasing an input or another output) is controlled via + targeted deepcopies at the end (the ``insert_deepcopy`` helper). + + In contrast, destruction is usually NOT allowed to be present in initial graphs. 
+ Destructive alias at the boundary is controlled here, with the following features: + * Supervisor: Checks whether any protected inputs are destroyed + * NoOutputInplaceOnInput: Checks whether an output is destroying a potentially + overwritten read (protected inputs are already covered by Supervisor) + Inside the boundary: + * DestroyHandler: Checks whether a consistent ordering exists for the destruction/view chains, + i.e., every read runs before its buffer's destruction and the chain has no cycle. + These features can veto (undo) any rewrite that would violate their spec. + They CANNOT fix violations that already existed in the initial graph. + """ + certain_overwritten, possible_overwritten = find_overwritten_reads( + op, node.inputs, inner.inputs + ) + + # Grant the inner function the right to alias (borrow=True) all inputs and to + # destroy (mutable=True) reads that are immediately discarded: the always-discarded + # reads from ``inner_destroyable_inputs`` (truncated sit_sot/mit_sot and every + # untraced input) plus the certainly-overwritten reads (dead once consumed this + # iteration). The certain set is what extends destruction to mit_mot, which the + # shared ``inner_destroyable_inputs`` (also used by the C backend) leaves out. + discarded = ( + op.inner_destroyable_inputs(node.inputs, inner.inputs) | certain_overwritten + ) + input_specs = [In(x, borrow=True, mutable=x in discarded) for x in inner.inputs] + add_supervisor_to_fgraph(fgraph=inner, input_specs=input_specs, accept_inplace=True) + + # No output may alias a (certainly or possibly) overwritten read: a same-iteration + # overwrite could corrupt the alias before it is stored. Ban that inplace so the + # inner graph allocates a fresh buffer instead; the deepcopy below still breaks any + # remaining non-destructive alias of an overwritten read. 
+ overwritten_reads = certain_overwritten | possible_overwritten + if overwritten_reads: + in_pos = {v: i for i, v in enumerate(inner.inputs)} + inner.attach_feature( + NoOutputInplaceOnInput([in_pos[v] for v in overwritten_reads]) + ) + + scan_inner_optimizer(op, mode).rewrite(inner) + + # Post-patch the alias contract with targeted deepcopies. A traced update may + # borrow unless it aliases an overwritten read. An untraced update may borrow if + # it is the FIRST viewer of a freshly produced buffer (or its own untraced + # input) and not an alias of another untraced update; any other alias (sequence + # reads, traced reads, non-sequences, other untraced inputs/updates) is copied. + own_untraced_input = dict( + zip( + op.inner_untraced_sit_sot_outs(inner.outputs), + op.inner_untraced_sit_sot(inner.inputs), + strict=True, + ) + ) + untraced_outs = set(own_untraced_input) + seen_untraced_roots = set() + output_specs = [] + for update in inner.outputs: + root = alias_root(update) + if update in untraced_outs: + borrow = ( + root.owner is not None or root is own_untraced_input[update] + ) and root not in seen_untraced_roots + if borrow: + seen_untraced_roots.add(root) + else: + borrow = root not in overwritten_reads + output_specs.append(Out(update, borrow=borrow)) + insert_deepcopy(inner, wrapped_inputs=input_specs, wrapped_outputs=output_specs) diff --git a/pytensor/scan/rewriting/inplace.py b/pytensor/scan/rewriting/inplace.py index 4839d4348f..b91a2a886a 100644 --- a/pytensor/scan/rewriting/inplace.py +++ b/pytensor/scan/rewriting/inplace.py @@ -6,6 +6,7 @@ have if it understood Scan's input categories. """ +from copy import copy from itertools import chain from pytensor.compile.ops import deep_copy_op @@ -85,7 +86,13 @@ def attempt_scan_inplace( inputs = ls_begin + ls + ls_end - new_op = op.clone() + # Shallow-copy the op so we can give it its own ``destroy_map`` without + # mutating the canonical op; the frozen inner graph is immutable and + # safely shared. 
(``op.clone()`` returns ``self`` for immutable ops.) + new_op = copy(op) + # The compiled inner function depends on ``destroy_map`` (it drives the + # untraced-copy and inplace setup), so drop any cached copy. + new_op._fn = None destroy_map = op.destroy_map.copy() for out_idx in output_indices: diff --git a/pytensor/scan/rewriting/merge.py b/pytensor/scan/rewriting/merge.py index e81a55dfd1..b7ab221f2d 100644 --- a/pytensor/scan/rewriting/merge.py +++ b/pytensor/scan/rewriting/merge.py @@ -10,7 +10,6 @@ from pytensor.graph.basic import NominalVariable, equal_computations from pytensor.graph.features import ReplaceValidate -from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import GraphRewriter, node_rewriter from pytensor.graph.rewriting.reachability import ( ancestor_bitsets, @@ -19,7 +18,7 @@ ) from pytensor.graph.traversal import ancestors from pytensor.scan.op import Scan, ScanInfo -from pytensor.scan.utils import ScanArgs, reconstruct_graph +from pytensor.scan.utils import ScanArgs from pytensor.tensor.basic import get_scalar_constant_value from pytensor.tensor.exceptions import NotScalarConstantError @@ -126,15 +125,12 @@ def merge(self, nodes): # add the condition, which was the one of nodes[0] inner_outs[0].append([condition]) - # Clone the inner graph of each node independently + # Thaw the frozen inner graph of each node independently; the category + # slices (all existing inner variables) map through the memo. 
for idx, nd in enumerate(nodes): - # concatenate all inner_ins and inner_outs of nd - flat_inner_ins = list(chain.from_iterable(inner_ins[idx])) - flat_inner_outs = list(chain.from_iterable(inner_outs[idx])) - # clone - flat_inner_ins, flat_inner_outs = reconstruct_graph( - flat_inner_ins, flat_inner_outs - ) + _, memo = nd.op.fgraph.unfreeze(return_memo=True) + flat_inner_ins = [memo[v] for v in chain.from_iterable(inner_ins[idx])] + flat_inner_outs = [memo[v] for v in chain.from_iterable(inner_outs[idx])] # split the new inner variables again in seq, mitmot, etc. new_inner_ins = [] count = 0 @@ -369,11 +365,12 @@ def scan_merge_inouts(fgraph, node): # Do a first pass to merge identical external inputs. # Equivalent inputs will be stored in inp_equiv, then a new # scan node created without duplicates. + unfrozen_fgraph = node.op.fgraph.unfreeze() a = ScanArgs( node.inputs, node.outputs, - node.op.inner_inputs, - node.op.inner_outputs, + unfrozen_fgraph.inputs, + unfrozen_fgraph.outputs, node.op.info, ) @@ -412,8 +409,10 @@ def scan_merge_inouts(fgraph, node): inner_inputs = a.inner_inputs outer_inputs = a.outer_inputs info = a.info - a_inner_outs = a.inner_outputs - inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv) + unfrozen_fgraph.replace_all(list(inp_equiv.items())) + # Read the outputs back from the graph: an output that was itself a + # replaced input is rewired there, not in ``a``'s stored slices. + inner_outputs = list(unfrozen_fgraph.outputs) new_op = Scan( inner_inputs, @@ -431,6 +430,7 @@ def scan_merge_inouts(fgraph, node): if not isinstance(outputs, list | tuple): outputs = [outputs] + # Read-only analysis below; slice the frozen inner graph directly. 
na = ScanArgs( outer_inputs, outputs, diff --git a/pytensor/scan/rewriting/push_out.py b/pytensor/scan/rewriting/push_out.py index 949dfd11a2..6f456e1abb 100644 --- a/pytensor/scan/rewriting/push_out.py +++ b/pytensor/scan/rewriting/push_out.py @@ -21,11 +21,10 @@ from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph, Output -from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import node_rewriter from pytensor.graph.type import HasShape from pytensor.scan.op import Scan -from pytensor.scan.utils import ScanArgs, reconstruct_graph, safe_new +from pytensor.scan.utils import ScanArgs, safe_new from pytensor.tensor.basic import get_scalar_constant_value from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot, dot @@ -176,8 +175,15 @@ def add_to_replace(y): nw_outer.append(repl_out) givens[to_repl] = repl_in - op_outs = clone_replace(node_outputs, replace=givens) - op_ins = node_inputs + nw_inner + # Thaw the frozen inner graph and rewire the pushed-out variables to + # their placeholders (imported as fresh inputs) on the mutable copy. + unfrozen_fgraph, memo = op.fgraph.unfreeze(return_memo=True) + unfrozen_fgraph.replace_all( + [(memo[to_repl], repl_in) for to_repl, repl_in in givens.items()], + import_missing=True, + ) + op_outs = list(unfrozen_fgraph.outputs) + op_ins = [memo[inp] for inp in node_inputs] + nw_inner new_info = dataclasses.replace( op.info, n_non_seqs=op.info.n_non_seqs + len(nw_outer) @@ -399,8 +405,15 @@ def add_to_replace(y): givens[to_repl] = repl_in - op_outs = clone_replace(node_outputs, replace=givens) - op_ins = nw_inner + node_inputs + # Thaw the frozen inner graph and rewire the pushed-out variables to + # their placeholders (imported as fresh inputs) on the mutable copy. 
+ unfrozen_fgraph, memo = op.fgraph.unfreeze(return_memo=True) + unfrozen_fgraph.replace_all( + [(memo[to_repl], repl_in) for to_repl, repl_in in givens.items()], + import_missing=True, + ) + op_outs = list(unfrozen_fgraph.outputs) + op_ins = nw_inner + [memo[inp] for inp in node_inputs] # Reconstruct node nw_info = dataclasses.replace(op.info, n_seqs=op.info.n_seqs + len(nw_inner)) @@ -560,6 +573,7 @@ def push_out_inner_vars( assert isinstance(new_scan_node.op, Scan) + # Only the outer bookkeeping is read; slice the frozen inner graph directly. new_scan_args = ScanArgs( new_scan_node.inputs, new_scan_node.outputs, @@ -604,10 +618,12 @@ def add_nitsot_outputs( assert isinstance(old_scan_node.op, Scan) - # Create the `Scan` `Op` from the `ScanArgs` + # Thaw the frozen inner graph to build the new `Scan` `Op`; the appended + # nitsot outputs are existing inner variables, covered by the memo. + _, memo = old_scan_node.op.fgraph.unfreeze(return_memo=True) new_scan_op = Scan( - new_scan_args.inner_inputs, - new_scan_args.inner_outputs, + [memo[v] for v in new_scan_args.inner_inputs], + [memo[v] for v in new_scan_args.inner_outputs], new_scan_args.info, mode=old_scan_node.op.mode, profile=old_scan_node.op.profile, @@ -700,7 +716,6 @@ def scan_push_out_add(fgraph, node): op.inner_inputs, op.inner_outputs, op.info, - clone=False, ) for nd in add_of_dot_nodes: @@ -878,9 +893,11 @@ def scan_push_out_dot1(fgraph, node): + inner_nitsot_outs + inner_untraced_sitsot_outs ) - new_inner_inps, new_inner_outs = reconstruct_graph( - _new_inner_inps, _new_inner_outs - ) + # Thaw the frozen inner graph; the category slices (all + # existing inner variables) map through the memo. 
+ _, memo = op.fgraph.unfreeze(return_memo=True) + new_inner_inps = [memo[v] for v in _new_inner_inps] + new_inner_outs = [memo[v] for v in _new_inner_outs] new_op = Scan( new_inner_inps, new_inner_outs, diff --git a/pytensor/scan/rewriting/trace.py b/pytensor/scan/rewriting/trace.py index 2c681340a8..0767d10599 100644 --- a/pytensor/scan/rewriting/trace.py +++ b/pytensor/scan/rewriting/trace.py @@ -922,9 +922,11 @@ def scan_sit_sot_to_untraced(fgraph, node): convertible_set = set(convertible) - # Gather current inner inputs/outputs by category - inner_inputs = list(op.inner_inputs) - inner_outputs = list(op.inner_outputs) + # Thaw the frozen inner graph; the reassembled category slices below are + # handed to a new ``Scan`` and must be mutable. + unfrozen_fgraph = op.fgraph.unfreeze() + inner_inputs = list(unfrozen_fgraph.inputs) + inner_outputs = list(unfrozen_fgraph.outputs) inner_sitsot_ins = op.inner_sitsot(inner_inputs) inner_sitsot_outs = op.inner_sitsot_outs(inner_outputs) diff --git a/pytensor/scan/rewriting/utils.py b/pytensor/scan/rewriting/utils.py index 288f55e607..9395a56d2e 100644 --- a/pytensor/scan/rewriting/utils.py +++ b/pytensor/scan/rewriting/utils.py @@ -3,7 +3,6 @@ from typing import cast from pytensor.graph.basic import Apply, Variable -from pytensor.graph.replace import clone_replace from pytensor.scan.op import Scan @@ -24,9 +23,9 @@ def _rebuild_scan_with_new_signature( Each ``drop_*`` argument is a set of indices into its category; the rebuilt op retains only the entries whose index is not listed. - ``inner_substitutions``, when provided, is applied via ``clone_replace`` - on the inner outputs before the rebuild -- use it to inline constants - or rewire duplicate inner inputs. + ``inner_substitutions``, when provided, is applied on the thawed inner + graph before the rebuild -- use it to inline constants or rewire + duplicate inner inputs. 
Returns a ``replacements`` dict: kept outer outputs map to their counterparts on the new op, dropped outputs carry no mapping (they @@ -57,12 +56,16 @@ def _rebuild_scan_with_new_signature( n_non_seqs=len(keep_non_seqs), ) - inner_seqs = op.inner_seqs(op.inner_inputs) - inner_mm_groups = op.inner_mitmot_grouped(op.inner_inputs) - inner_ms_groups = op.inner_mitsot_grouped(op.inner_inputs) - inner_ss = op.inner_sitsot(op.inner_inputs) - inner_us = op.inner_untraced_sit_sot(op.inner_inputs) - inner_non_seqs = op.inner_non_seqs(op.inner_inputs) + # Thaw the frozen inner graph; slices and substitutions map through the memo. + unfrozen_fgraph, memo = op.fgraph.unfreeze(return_memo=True) + inner_inputs = list(unfrozen_fgraph.inputs) + + inner_seqs = op.inner_seqs(inner_inputs) + inner_mm_groups = op.inner_mitmot_grouped(inner_inputs) + inner_ms_groups = op.inner_mitsot_grouped(inner_inputs) + inner_ss = op.inner_sitsot(inner_inputs) + inner_us = op.inner_untraced_sit_sot(inner_inputs) + inner_non_seqs = op.inner_non_seqs(inner_inputs) new_inner_inputs = ( [inner_seqs[k] for k in keep_seqs] @@ -73,9 +76,12 @@ def _rebuild_scan_with_new_signature( + [inner_non_seqs[k] for k in keep_non_seqs] ) - inner_outputs = op.inner_outputs if inner_substitutions: - inner_outputs = clone_replace(inner_outputs, replace=inner_substitutions) + unfrozen_fgraph.replace_all( + [(memo.get(k, k), memo.get(v, v)) for k, v in inner_substitutions.items()], + import_missing=True, + ) + inner_outputs = list(unfrozen_fgraph.outputs) inner_mm_out_groups = op.inner_mitmot_outs_grouped(inner_outputs) inner_ms_outs = op.inner_mitsot_outs(inner_outputs) inner_ss_outs = op.inner_sitsot_outs(inner_outputs) diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index a999979f12..c62bc8fcf9 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -2,6 +2,7 @@ import copy import logging +import warnings from collections import namedtuple from collections.abc import Callable, Sequence from 
itertools import chain @@ -369,23 +370,27 @@ def __init__( _inner_inputs: Sequence[Variable], _inner_outputs: Sequence[Variable], info: "ScanInfo", - clone: bool | None = True, + clone: bool | None = None, ): + if clone is not None: + warnings.warn( + "The `clone` argument of ScanArgs is deprecated and ignored: the " + "inner graph is sliced as-is. For a detached mutable copy of a " + "Scan's inner graph use `ScanArgs.from_node(node, clone=True)` or " + "thaw it with `node.op.fgraph.unfreeze()`.", + FutureWarning, + ) + self.n_steps = outer_inputs[0] self.as_while = info.as_while - if clone: - rval = reconstruct_graph(_inner_inputs, _inner_outputs, "") - else: - rval = (_inner_inputs, _inner_outputs) - if self.as_while: - self.cond = [rval[1][-1]] - inner_outputs = rval[1][:-1] + self.cond = [_inner_outputs[-1]] + inner_outputs = _inner_outputs[:-1] else: self.cond = [] - inner_outputs = rval[1] - inner_inputs = rval[0] + inner_outputs = _inner_outputs + inner_inputs = _inner_inputs p = 1 q = 0 @@ -491,13 +496,20 @@ def from_node(node, clone=False) -> "ScanArgs": if not isinstance(node.op, Scan): raise TypeError(f"{node} is not a Scan node") + if clone: + # Thaw the frozen inner graph into a fresh mutable copy. 
+ unfrozen_fgraph = node.op.fgraph.unfreeze() + inner_inputs = list(unfrozen_fgraph.inputs) + inner_outputs = list(unfrozen_fgraph.outputs) + else: + inner_inputs = node.op.inner_inputs + inner_outputs = node.op.inner_outputs return ScanArgs( node.inputs, node.outputs, - node.op.inner_inputs, - node.op.inner_outputs, + inner_inputs, + inner_outputs, node.op.info, - clone=clone, ) @property diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py index 130cbf9ba2..657c4d28ef 100644 --- a/pytensor/tensor/optimize.py +++ b/pytensor/tensor/optimize.py @@ -1,14 +1,16 @@ import logging from collections.abc import Sequence from copy import copy +from functools import singledispatch import numpy as np import pytensor.scalar as ps from pytensor.compile.maker import function +from pytensor.compile.mode import get_mode from pytensor.gradient import DisconnectedType, grad, jacobian from pytensor.graph.basic import Apply, Constant -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.null_type import NullType from pytensor.graph.op import ( ComputeMapType, @@ -163,15 +165,42 @@ def _depends_only_on_constants(var: Variable) -> bool: class ScipyWrapperOp(Op, HasInnerGraph): - """Shared logic for scipy optimization ops""" + """Shared logic for scipy optimization ops. + + The inner graph is held frozen (immutable) as ``self.fgraph``, so the + canonical op can never be mutated in place. Graph manipulation (``grad`` / + ``graph_replace`` / ``function``) runs on a fresh mutable copy obtained via + ``self.fgraph.unfreeze()`` -- ``graph_replace`` on frozen nodes would + otherwise leak them into the outer graph. + """ + + # Attribute names (besides the frozen inner graph) that distinguish two ops + # of the same type for eq/hash. Subclasses override. + _scipy_props: tuple[str, ...] 
= () def build_fn(self): """ This is overloaded because scipy converts scalar inputs to lists, changing the return type. The wrapper function logic is there to handle this. """ - outputs = self.inner_outputs - self._fn = fn = function(self.inner_inputs, outputs, trust_input=True) + fgraph = self.fgraph.unfreeze() + # ``optimize_inner_graph`` already baked this graph for the active backend + # (inplace included), exactly like ``OpFromGraph``. So we only link it -- + # ``minimum_compile`` excluding ``compile_inner_graph`` (the rewrite that + # already ran); ``prepare_fgraph`` still inserts the boundary deepcopies. + # ``accept_inplace`` admits the baked inplace ops; backend optimization can + # leave a declared input unused (e.g. a folded core-shape vector), hence + # ``on_unused_input="ignore"``. + self._fn = fn = function( + fgraph.inputs, + fgraph.outputs, + mode=get_mode(None) + .clone(optimizer="minimum_compile") + .excluding("compile_inner_graph"), + accept_inplace=True, + trust_input=True, + on_unused_input="ignore", + ) # Do this reassignment to see the compiled graph in the dprint # self.fgraph = fn.maker.fgraph @@ -192,22 +221,56 @@ def fn_wrapped(self): @property def inner_inputs(self): - return self.fgraph.inputs + # A list (not the frozen tuple) so callers that concatenate inner + # inputs/outputs keep list semantics. + return list(self.fgraph.inputs) @property def inner_outputs(self): - return self.fgraph.outputs + return list(self.fgraph.outputs) + + def _prop_values(self): + return tuple(getattr(self, name) for name in self._scipy_props) + + def __eq__(self, other): + if self is other: + return True + if type(self) is not type(other): + return False + return ( + self.fgraph == other.fgraph and self._prop_values() == other._prop_values() + ) + + def __hash__(self): + # ``optimizer_kwargs`` may hold nested dicts, so the props stay out of + # the hash; equal ops still hash equal on the type and inner graph. 
+ return hash((type(self), self.fgraph)) + + def clone_with_inner_graph(self, inner_fgraph): + """Return a copy of this op whose inner graph is ``inner_fgraph``. - def clone_with_new_fgraph(self, fgraph): + Used by the ``optimize_inner_graph`` rewrite to bake an + already-optimized inner graph into a NEW immutable op without touching + ``self``. ``inner_fgraph`` may be a mutable ``FunctionGraph`` (it is + frozen here) or an already-frozen graph. + """ clone_op = copy(self) clone_op._fn = None clone_op._fn_wrapped = None - clone_op.fgraph = fgraph + clone_op.fgraph = ( + inner_fgraph + if isinstance(inner_fgraph, FrozenFunctionGraph) + else inner_fgraph.freeze() + ) return clone_op + # Name used by the canonicalization rewrite that rebuilds the inner graph. + clone_with_new_fgraph = clone_with_inner_graph + def clone(self): - clone_fgraph = self.fgraph.clone(clone_inner_graphs=True) - return self.clone_with_new_fgraph(clone_fgraph) + # The inner graph is immutable (a frozen ``FunctionGraph``), so there is + # nothing to deep-clone -- mirror ``Composite``/``OpFromGraph``. + return self def prepare_node( self, @@ -229,17 +292,41 @@ def make_node(self, *inputs): return Apply(self, inputs, [self.inner_inputs[0].type(), ps.bool("success")]) +@singledispatch +def rewrite_optimize_inner_graph(linker, op, node, inner, *, mode): + """Rewrite an ``optimize`` op (Minimize/Root) inner graph for ``linker``. + + Each linker registers its own; see ``pytensor.tensor.rewriting.optimize``. 
+ """ + raise NotImplementedError( + f"Linker {type(linker).__name__} has not registered an optimize-op " + "inner-graph rewrite" + ) + + class ScipyScalarWrapperOp(ScipyWrapperOp): def build_fn(self): # We need to adjust the graph to work with what scipy will be passing into the inner function -- # always scalar array of float64 type - x, *args = self.inner_inputs + fgraph = self.fgraph.unfreeze() + x, *args = fgraph.inputs new_root_x = ps.float64(name="x_scalar") new_x = tensor_from_scalar(new_root_x.astype(x.type.dtype)) - new_outputs = graph_replace(self.inner_outputs, {x: new_x}) + new_outputs = graph_replace(fgraph.outputs, {x: new_x}) - self._fn = fn = function([new_root_x, *args], new_outputs, trust_input=True) + # See ``ScipyWrapperOp.build_fn``: the graph is already baked, so link + # it (the scipy boundary wrapping above links alongside the baked body). + self._fn = fn = function( + [new_root_x, *args], + new_outputs, + mode=get_mode(None) + .clone(optimizer="minimum_compile") + .excluding("compile_inner_graph"), + accept_inplace=True, + trust_input=True, + on_unused_input="ignore", + ) # Do this reassignment to see the compiled graph in the dprint # self.fgraph = fn.maker.fgraph @@ -270,9 +357,9 @@ def compute_implicit_gradients( Whether the optimization problem is a minimization problem. If False, it is assumed to be a root-finding problem. 
""" - fgraph = self.fgraph - inner_x, *inner_args = self.inner_inputs - inner_fx = self.inner_outputs[0] + fgraph = self.fgraph.unfreeze() + inner_x, *inner_args = fgraph.inputs + inner_fx = fgraph.outputs[0] if is_minimization: # The implicit function in minimization is grad(x, theta) == 0 @@ -323,14 +410,26 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): def build_fn(self): # We need to adjust the graph to work with what scipy will be passing into the inner function -- # always a vector array with size of at least 1 - x, *args = self.inner_inputs - if x.type.shape != (): + if self.inner_inputs[0].type.shape != (): return super().build_fn() + fgraph = self.fgraph.unfreeze() + x, *args = fgraph.inputs new_root_x = x[None].type() new_x = new_root_x.squeeze() - new_outputs = graph_replace(self.inner_outputs, {x: new_x}) - self._fn = fn = function([new_root_x, *args], new_outputs, trust_input=True) + new_outputs = graph_replace(fgraph.outputs, {x: new_x}) + # See ``ScipyWrapperOp.build_fn``: the graph is already baked, so link + # it (the scipy boundary wrapping above links alongside the baked body). + self._fn = fn = function( + [new_root_x, *args], + new_outputs, + mode=get_mode(None) + .clone(optimizer="minimum_compile") + .excluding("compile_inner_graph"), + accept_inplace=True, + trust_input=True, + on_unused_input="ignore", + ) # Do this reassignment to see the compiled graph in the dprint # self.fgraph = fn.maker.fgraph @@ -383,9 +482,9 @@ def compute_implicit_gradients( problem, where `f` is the objective function. In this case, we instead take `f` to be the gradient of the objective function, which *is* indeed zero at the minimum. 
""" - fgraph = self.fgraph - inner_x, *inner_args = self.inner_inputs - implicit_f = self.inner_outputs[0] + fgraph = self.fgraph.unfreeze() + inner_x, *inner_args = fgraph.inputs + implicit_f = fgraph.outputs[0] if is_minimization: # The implicit function in minimization is grad(x, theta) == 0 implicit_f = grad(implicit_f, inner_x) @@ -484,6 +583,7 @@ def _optimizer_connection_pattern(fgraph, is_minimization): An input may be connected to the objective but disconnected from its gradient (e.g. an additive constant), so the connection pattern must reflect the actual implicit function. """ + fgraph = fgraph.unfreeze() inner_x = fgraph.inputs[0] fx = fgraph.outputs[0] if is_minimization: @@ -494,6 +594,8 @@ def _optimizer_connection_pattern(fgraph, is_minimization): class MinimizeScalarOp(ScipyScalarWrapperOp): + _scipy_props = ("method", "optimizer_kwargs") + def __init__( self, x: TensorVariable, @@ -514,7 +616,7 @@ def __init__( raise ValueError( "The variable `x` must be an input to the computational graph of the objective function." 
) - self.fgraph = FunctionGraph([x, *args], [objective]) + self.fgraph = FrozenFunctionGraph.from_io([x, *args], [objective]) self.method = method self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {} @@ -611,6 +713,15 @@ def minimize_scalar( class MinimizeOp(ScipyVectorWrapperOp): + _scipy_props = ( + "method", + "jac", + "hess", + "hessp", + "use_vectorized_jac", + "optimizer_kwargs", + ) + def __init__( self, x: TensorVariable, @@ -651,6 +762,8 @@ def __init__( ) self.fgraph.add_output(hess_wrt_x) + self.fgraph = self.fgraph.freeze() + self.jac = jac self.hess = hess self.hessp = hessp @@ -813,6 +926,8 @@ def minimize( class RootScalarOp(ScipyScalarWrapperOp): + _scipy_props = ("method", "jac", "hess", "optimizer_kwargs") + def __init__( self, variables: TensorVariable, @@ -851,6 +966,8 @@ def __init__( f_double_prime = grad(self.fgraph.outputs[-1], self.fgraph.inputs[0]) self.fgraph.add_output(f_double_prime) + self.fgraph = self.fgraph.freeze() + self.method = method self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {} self.jac = jac @@ -965,9 +1082,10 @@ def root_scalar( class RootOp(ScipyVectorWrapperOp): - # These __props__ were wrong: they ignore the inner graph, - # making RootOps of different equations compare equal (and get merged) - # __props__ = ("method", "jac") + # eq/hash key on the (frozen) inner graph plus these props -- keying on the + # inner graph is what keeps RootOps of different equations distinct (an + # earlier ``__props__ = ("method", "jac")`` ignored it and merged them). 
+ _scipy_props = ("method", "jac", "use_vectorized_jac", "optimizer_kwargs") def __init__( self, @@ -1003,6 +1121,8 @@ def __init__( ) self.fgraph.add_output(atleast_2d(jac_wrt_x)) + self.fgraph = self.fgraph.freeze() + self.jac = jac self.method = method @@ -1020,8 +1140,9 @@ def __str__(self): return f"{self.__class__.__name__}({str_args})" def build_fn(self): - outputs = self.inner_outputs - variables, *args = self.inner_inputs + fgraph = self.fgraph.unfreeze() + variables, *args = fgraph.inputs + outputs = fgraph.outputs if variables.ndim > 0: new_root_variables = variables @@ -1036,8 +1157,17 @@ def build_fn(self): new_outputs = graph_replace(outputs, {variables: new_variables}) + # See ``ScipyWrapperOp.build_fn``: the graph is already baked, so link + # it (the scipy boundary wrapping above links alongside the baked body). self._fn = fn = function( - [new_root_variables, *args], new_outputs, trust_input=True + [new_root_variables, *args], + new_outputs, + mode=get_mode(None) + .clone(optimizer="minimum_compile") + .excluding("compile_inner_graph"), + accept_inplace=True, + trust_input=True, + on_unused_input="ignore", ) # Do this reassignment to see the compiled graph in the dprint diff --git a/pytensor/tensor/rewriting/einsum.py b/pytensor/tensor/rewriting/einsum.py index 5ee8acc92a..0ded6ef9ff 100644 --- a/pytensor/tensor/rewriting/einsum.py +++ b/pytensor/tensor/rewriting/einsum.py @@ -1,10 +1,10 @@ from typing import cast +from pytensor.compile.rewriting import inline_ofg_node from pytensor.graph import Apply, FunctionGraph, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace from pytensor.tensor.einsum import Einsum, einsum from pytensor.tensor.rewriting.basic import register_specialize -from pytensor.tensor.rewriting.ofg import inline_ofg_node from pytensor.tensor.variable import TensorVariable diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index b3ff3828b3..38d816d5cf 
100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -261,7 +261,7 @@ def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): # safe because reads don't destroy. Write targets always get their own # fresh inner input (see FuseIndexedElemwise) so a destroyed buffer is # never deduped onto a read source. - super().__init__(*args, on_unused_input="ignore", accept_inplace=True, **kwargs) + super().__init__(*args, on_unused_input="ignore", **kwargs) def __str__(self): for node in self.fgraph.apply_nodes: diff --git a/pytensor/tensor/rewriting/linalg/solvers.py b/pytensor/tensor/rewriting/linalg/solvers.py index 33a907b1da..f3439bd3c6 100644 --- a/pytensor/tensor/rewriting/linalg/solvers.py +++ b/pytensor/tensor/rewriting/linalg/solvers.py @@ -1,11 +1,11 @@ from collections.abc import Container -from copy import copy from pytensor import tensor as pt from pytensor.assumptions import DIAGONAL, ORTHOGONAL, check_assumption from pytensor.assumptions.positive_definite import POSITIVE_DEFINITE from pytensor.compile import optdb from pytensor.graph import Constant, graph_inputs +from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import ( copy_stack_trace, dfs_rewriter, @@ -564,12 +564,14 @@ def _scan_split_non_sequence_decomposition_and_solve( The LU decomposition step can then be pushed out of the inner loop by the `scan_pushout_non_sequences` rewrite. 
""" scan_op: Scan = node.op - non_sequences = set(scan_op.inner_non_seqs(scan_op.inner_inputs)) - new_scan_fgraph = scan_op.fgraph + frozen_fgraph = scan_op.fgraph + non_sequences = set(scan_op.inner_non_seqs(frozen_fgraph.inputs)) + new_scan_fgraph: FunctionGraph | None = None - changed = False while True: - for inner_node in new_scan_fgraph.toposort(): + for inner_node in ( + frozen_fgraph if new_scan_fgraph is None else new_scan_fgraph + ).toposort(): match (inner_node.op, *inner_node.inputs): case (Blockwise(Solve(assume_a=assume_a_var)), A, _b) if ( assume_a_var in allowed_assume_a @@ -578,12 +580,13 @@ def _scan_split_non_sequence_decomposition_and_solve( (isinstance(root_inp, Constant) or (root_inp in non_sequences)) for root_inp in graph_inputs([A]) ): - if new_scan_fgraph is scan_op.fgraph: - # Clone the first time to avoid mutating the original fgraph - new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv() # type: ignore[attr-defined] - non_sequences = { - equiv[non_seq] for non_seq in non_sequences - } + if new_scan_fgraph is None: + # Thaw the frozen graph into a mutable copy on the + # first match, carrying the tracked state over. 
+ new_scan_fgraph, equiv = frozen_fgraph.unfreeze( + return_memo=True + ) + non_sequences = {equiv[v] for v in non_sequences} inner_node = equiv[inner_node] replace_dict = _split_decomp_and_solve_steps( @@ -595,18 +598,15 @@ def _scan_split_non_sequence_decomposition_and_solve( assert ( isinstance(replace_dict, dict) and len(replace_dict) > 0 ), "Rewrite failed" - new_scan_fgraph.replace_all(replace_dict.items()) # type: ignore[attr-defined] - changed = True + new_scan_fgraph.replace_all(replace_dict.items()) break # Break to start over with a fresh toposort else: # no_break break # Nothing else changed - if not changed: + if new_scan_fgraph is None: return - # Return a new scan to indicate that a rewrite was done - new_scan_op = copy(scan_op) - new_scan_op.fgraph = new_scan_fgraph + new_scan_op = scan_op.clone_with_inner_graph(new_scan_fgraph) new_outs = new_scan_op.make_node(*node.inputs).outputs copy_stack_trace(node.outputs, new_outs) return new_outs diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py index 1cd914d343..5f8b865fee 100644 --- a/pytensor/tensor/rewriting/ofg.py +++ b/pytensor/tensor/rewriting/ofg.py @@ -1,46 +1,10 @@ -from pytensor.compile import optdb -from pytensor.compile.builders import OpFromGraph -from pytensor.graph import Apply, Variable, node_rewriter -from pytensor.graph.fg import FrozenFunctionGraph -from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter +from pytensor.compile.rewriting import inline_ofg_node +from pytensor.graph import node_rewriter from pytensor.tensor.basic import AllocDiag from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.special import XLog1PY, XLogY -def inline_ofg_node(node: Apply) -> list[Variable]: - frozen_fg: FrozenFunctionGraph = node.op._frozen_fgraph - replacements = dict(zip(frozen_fg.inputs, node.inputs)) - inlined_outs = frozen_fg.bind(replacements) - copy_stack_trace(frozen_fg.outputs, inlined_outs) - return inlined_outs 
- - -@node_rewriter([OpFromGraph]) -def inline_ofg_expansion(fgraph, node): - """ - This optimization expands internal graph of OpFromGraph. - Only performed if node.op.is_inline == True - Doing so can improve optimization at the cost of compilation speed. - """ - op = node.op - if not op.is_inline: - return False - - return inline_ofg_node(node) - - -# We want to run this before the first merge optimizer -# and before the first scan optimizer. -optdb.register( - "inline_ofg_expansion", - dfs_rewriter(inline_ofg_expansion), - "fast_compile", - "fast_run", - position=-0.01, -) - - @register_specialize("inline_ofg") @node_rewriter([AllocDiag, XLogY, XLog1PY]) def late_inline_OpFromGraph(fgraph, node): diff --git a/pytensor/tensor/rewriting/optimize.py b/pytensor/tensor/rewriting/optimize.py index 41de8666f8..d4b0510fe0 100644 --- a/pytensor/tensor/rewriting/optimize.py +++ b/pytensor/tensor/rewriting/optimize.py @@ -1,11 +1,64 @@ +from pytensor.compile import optdb +from pytensor.compile.aliasing import add_supervisor_to_fgraph +from pytensor.compile.io import In +from pytensor.compile.rewriting import rewrite_inner_graph from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph -from pytensor.graph.replace import clone_replace -from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.tensor.optimize import ScipyWrapperOp +from pytensor.graph.rewriting.basic import graph_rewriter, node_rewriter +from pytensor.link.basic import PerformLinker +from pytensor.link.c.basic import CLinker, OpWiseCLinker +from pytensor.link.jax.linker import JAXLinker +from pytensor.link.mlx.linker import MLXLinker +from pytensor.link.numba.linker import NumbaLinker +from pytensor.link.pytorch.linker import PytorchLinker +from pytensor.link.vm import VMLinker +from pytensor.tensor.optimize import ScipyWrapperOp, rewrite_optimize_inner_graph from pytensor.tensor.rewriting.basic import register_canonicalize 
+@rewrite_optimize_inner_graph.register(VMLinker) +@rewrite_optimize_inner_graph.register(PerformLinker) +@rewrite_optimize_inner_graph.register(CLinker) +@rewrite_optimize_inner_graph.register(OpWiseCLinker) +def c_rewrite_optimize_inner_graph(linker, op, node, inner, *, mode): + # Same contract as ``OpFromGraph``: inputs (the optimization variable + args) + # must not be mutated, so they are protected; inplace may still be baked + # between purely internal buffers. ``build_fn`` then only links this graph. + # The Supervisor is needed even with no mutable inputs: it is the feature that + # vetoes input-destroying inplace rewrites while ``mode.optimizer`` runs them. + input_specs = [In(x, borrow=True, mutable=False) for x in inner.inputs] + add_supervisor_to_fgraph(fgraph=inner, input_specs=input_specs, accept_inplace=True) + mode.optimizer.rewrite(inner) + + +@rewrite_optimize_inner_graph.register(NumbaLinker) +@rewrite_optimize_inner_graph.register(JAXLinker) +@rewrite_optimize_inner_graph.register(PytorchLinker) +@rewrite_optimize_inner_graph.register(MLXLinker) +def jit_rewrite_optimize_inner_graph(linker, op, node, inner, *, mode): + # JIT backends manage memory themselves, so leave the inner graph functional. + # (Unlike ``OpFromGraph`` under numba, no deepcopies are baked in either: a + # scipy op is never funcified -- it always perform-links via ``build_fn``, + # whose ``FunctionMaker`` pass inserts the boundary deepcopies.) 
+ mode.excluding("inplace").optimizer.rewrite(inner) + + +@graph_rewriter +def optimize_inner_graph(fgraph): + rewrite_inner_graph( + fgraph, lambda op: isinstance(op, ScipyWrapperOp), rewrite_optimize_inner_graph + ) + + +optdb.register( + "optimize_inner_graph", + optimize_inner_graph, + "minimum_compile", + "compile_inner_graph", + position=49.6, +) + + @register_canonicalize @node_rewriter([ScipyWrapperOp]) def remove_constants_and_duplicate_inputs_scipy(fgraph, node): @@ -19,7 +72,9 @@ def remove_constants_and_duplicate_inputs_scipy(fgraph, node): optimization variable x. """ op: ScipyWrapperOp = node.op - inner_x, *inner_args = op.inner_inputs + # Thaw the frozen inner graph; substitutions run on the mutable copy. + unfrozen_fgraph = op.fgraph.unfreeze() + inner_x, *inner_args = unfrozen_fgraph.inputs outer_x, *outer_args = list(node.inputs) givens = {} @@ -40,9 +95,9 @@ def remove_constants_and_duplicate_inputs_scipy(fgraph, node): if not givens: return None - new_inner_outputs = clone_replace(op.inner_outputs, replace=givens) + unfrozen_fgraph.replace_all(list(givens.items()), import_missing=True) new_inner_inputs = (inner_x, *new_inner_args) - new_fgraph = FunctionGraph(new_inner_inputs, new_inner_outputs, clone=False) + new_fgraph = FunctionGraph(new_inner_inputs, unfrozen_fgraph.outputs, clone=False) new_op = op.clone_with_new_fgraph(new_fgraph) new_outer_inputs = (outer_x, *new_outer_args) return new_op.make_node(*new_outer_inputs).outputs diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py index b3d6433658..c02c6a8fa7 100644 --- a/pytensor/xtensor/rewriting/utils.py +++ b/pytensor/xtensor/rewriting/utils.py @@ -2,10 +2,10 @@ from collections.abc import Sequence from pytensor.compile import optdb +from pytensor.compile.rewriting import inline_ofg_expansion from pytensor.graph.rewriting.basic import NodeRewriter, dfs_rewriter from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase from 
pytensor.tensor.basic import infer_shape_db -from pytensor.tensor.rewriting.ofg import inline_ofg_expansion from pytensor.tensor.variable import TensorVariable from pytensor.xtensor.type import XTensorVariable diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 87b8f87b3c..af5521672f 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -26,7 +26,6 @@ from pytensor.tensor.math import dot, exp, sigmoid from pytensor.tensor.math import round as pt_round from pytensor.tensor.math import sum as pt_sum -from pytensor.tensor.random.utils import RandomStream from pytensor.tensor.rewriting.shape import ShapeOptimizer from pytensor.tensor.shape import specify_shape from pytensor.tensor.type import ( @@ -45,7 +44,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): def test_valid_input(self): x, _y, _z = matrices("xyz") - with pytest.raises(ValueError, match=r"Expected at least.*"): + with pytest.raises(ValueError, match=r"Expected 1 input\(s\)"): OpFromGraph([x], [x])() with pytest.raises(ValueError, match=r"Expected 1 input\(s\)"): @@ -65,11 +64,9 @@ def test_clone(self): ofg = OpFromGraph([x], [2 * x]) - ofg_clone = ofg.clone() - - assert ofg_clone.fgraph is not ofg.fgraph - assert ofg_clone.fgraph.outputs != ofg.fgraph.outputs - assert equal_computations(ofg_clone.fgraph.outputs, ofg.fgraph.outputs) + # OpFromGraph is immutable (single frozen inner graph), so cloning + # returns self -- mirroring Composite. 
+ assert ofg.clone() is ofg @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] @@ -146,56 +143,6 @@ def test_grad_grad(self, cls_ofg): zv = np.ones((2, 2), dtype=config.floatX) * 5 np.testing.assert_array_almost_equal(6.0, fn(xv, yv, zv), 4) - @pytest.mark.parametrize( - "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] - ) - def test_shared(self, cls_ofg): - x, y, z = matrices("xyz") - s = shared(np.random.random((2, 2)).astype(config.floatX)) - e = x + y * z + s - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - op = cls_ofg([x, y, z], [e]) - # (1+3*5=array of 16) - (3+1*5=array of 8) - f = op(x, y, z) - op(y, z, x) - - fn = function([x, y, z], f) - xv = np.ones((2, 2), dtype=config.floatX) - yv = np.ones((2, 2), dtype=config.floatX) * 3 - zv = np.ones((2, 2), dtype=config.floatX) * 5 - # print function, function.__module__ - # print fn.maker.fgraph.toposort() - np.testing.assert_array_almost_equal(8.0, fn(xv, yv, zv), 4) - np.testing.assert_array_almost_equal(8.0, fn(xv, yv, zv), 4) - - @pytest.mark.parametrize( - "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] - ) - def test_shared_grad(self, cls_ofg): - x, y, z = matrices("xyz") - s = shared(np.random.random((2, 2)).astype(config.floatX)) - e = x + y * z + s - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - op = cls_ofg([x, y, z], [e]) - f = op(x, y, z) - f = f - grad(pt_sum(f), y) - fn = function([x, y, z], f) - xv = np.ones((2, 2), dtype=config.floatX) - yv = np.ones((2, 2), dtype=config.floatX) * 3 - zv = np.ones((2, 2), dtype=config.floatX) * 5 - np.testing.assert_array_almost_equal(11.0 + s.get_value(), fn(xv, yv, zv), 4) - - # grad again the shared variable - f = op(x, y, z) - f = f - grad(pt_sum(f), s) - fn = function([x, y, z], f) - np.testing.assert_array_almost_equal(15.0 + s.get_value(), fn(xv, yv, zv), 4) - 
@pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] ) @@ -417,27 +364,6 @@ def test_connection_pattern(self, cls_ofg): expect_result = [[True, False], [True, True], [False, True], [True, True]] assert results == expect_result - # Inner graph where some computation doesn't rely on explicit inputs - srng = RandomStream(seed=234) - rv_u = srng.uniform((2, 2)) - x, y = matrices("xy") - out1 = x + rv_u - out2 = y + 3 - out3 = 3 + rv_u - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - op3 = cls_ofg([x, y], [out1, out2, out3]) - - results = op3.connection_pattern(None) - expect_result = [ - [True, False, False], - [False, True, False], - [True, False, True], - ] - assert results == expect_result - def test_infer_shape(self): # test infer shape does not need to against inline case # since the Op is remove during optimization phase @@ -475,104 +401,33 @@ def test_infer_shape(self): assert opt_res.shape_feature.shape_tuple(x) is None assert opt_res.shape_feature.shape_tuple(z)[0].data == 2 - def test_make_node_shared(self): - """Make sure we can provide `OpFromGraph.make_node` new shared inputs and get a valid `OpFromGraph`.""" - - x = pt.scalar("x") - y = shared(1.0, name="y") - - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - test_ofg = OpFromGraph([x], [x + y], on_unused_input="ignore") - assert test_ofg.shared_inputs == [y] - - out = test_ofg(x) - - y_clone = y.clone() - assert y_clone != y - y_clone.name = "y_clone" - - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - out_new = test_ofg.make_node(*([*out.owner.inputs[:1], y_clone])).outputs[0] - - assert "on_unused_input" in out_new.owner.op.kwargs - assert out_new.owner.op.shared_inputs == [y_clone] - - out_fn = function([x], out_new) - assert np.array_equal(out_fn(1.0), 2.0) - - y_clone.set_value(2.0) 
- assert np.array_equal(out_fn(1.0), 3.0) - - # This should also work, because the containers are the same: - # y.set_value(1.0) - # assert np.array_equal(out_fn(1.0), 2.0) - - def test_shared_with_constant_input(self): - """Make sure that a constant input can be given to an `OpFromGraph` instance.""" - x = pt.scalar("x") - y = shared(1.0, name="y") - - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - test_ofg = OpFromGraph([x], [x + y]) - assert test_ofg.shared_inputs == [y] - - out = test_ofg(pt.as_tensor(1.0, dtype=config.floatX)) - - out_fn = function([], out) - assert np.array_equal(out_fn(), 2.0) - def test_missing_input(self): x = pt.lscalar("x") with pytest.raises(MissingInputError): OpFromGraph([], [x]) - def test_shared_to_nonshared_input(self): - """Make sure that shared variables can be replaced with non-shared variables.""" - x = pt.scalar("x") - y = shared(1.0, name="y") - - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - test_ofg = OpFromGraph([], [y]) - assert test_ofg.shared_inputs == [y] - - out_1_fn = function([], test_ofg()) - res_1 = out_1_fn() - - assert np.array_equal(res_1, 1.0) - - test_ofg_new = test_ofg.make_node(x) - assert test_ofg_new.op.shared_inputs == [] - - out_2_fn = function([x], test_ofg_new.outputs[0]) - res_2 = out_2_fn(np.array(1.0, dtype=config.floatX)) - - assert np.array_equal(res_2, 1.0) - def test_outputs_consistency(self): - """Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`.""" + """Compiling the inner function must not mutate `OpFromGraph.inner_outputs`.""" x = scalar("x") - op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN") + op = OpFromGraph([x], [x**2 / x]) # Confirm that the inner-graph is as expected assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) - # These outputs of the compiled `op.fgraph` should differ from the - # 
original, uncompiled `op.fgraph` outputs - fn = op.fn + # Optimizing a copy of the inner graph (here FAST_RUN, which rewrites + # ``x**2 / x`` to ``x``) must not leak back into the canonical, frozen + # inner graph. The canonical graph is immutable and is never handed to + # ``function`` directly; compile an ``unfreeze()``d mutable copy instead. + unfrozen = op.fgraph.unfreeze() + fn = function( + unfrozen.inputs, + unfrozen.outputs, + mode="FAST_RUN", + on_unused_input="ignore", + accept_inplace=True, + ) new_inputs = fn.maker.fgraph.inputs new_outputs = fn.maker.fgraph.outputs assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x]) @@ -580,6 +435,10 @@ def test_outputs_consistency(self): # The original `op.fgraph` outputs should stay the same, though assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) + # `op.fn` (compiled under the active mode) must likewise leave it intact. + op.fn + assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) + def test_explicit_input_from_constant(self): x = pt.dscalar("x") y = constant(1.0, dtype=x.type.dtype, name="y") @@ -595,13 +454,14 @@ def test_explicit_input_from_shared(self): x = pt.dscalar("x") y = shared(1.0, name="y") + # A shared variable may only be used if passed explicitly as an input. 
with pytest.raises( - ValueError, - match=r"The inner-graph implicitly depends on the following shared variables \[y\]", + MissingInputError, + match=r"implicitly depends on shared variable y", ): - OpFromGraph([x], [x + y], strict=True) + OpFromGraph([x], [x + y]) - test_ofg = OpFromGraph([x, y], [x + y], strict=True) + test_ofg = OpFromGraph([x, y], [x + y]) out = test_ofg(x, y) assert out.eval({x: 5}) == 6 @@ -611,16 +471,6 @@ def test_explicit_input_from_shared(self): out = test_ofg(y, y) assert out.eval() == 4 - def test_implicit_shared_inputs_deprecated(self): - x = pt.dscalar("x") - y = shared(1.0, name="y") - - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - OpFromGraph([x], [x + y]) - @pytest.mark.parametrize("use_custom_pullback", [False, True]) def test_pullback_disconnected_output_grad(self, use_custom_pullback): x, y = dscalars("x", "y") @@ -765,31 +615,6 @@ def test_equality_and_hashing(self): # OFG is hashable, and different OFGs have different hashes assert hash(op1) != hash(op_inline) - def test_equality_shared_variables(self): - x = scalar("x") - s = shared(np.array(1.0, dtype=config.floatX)) - - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - op1 = OpFromGraph([x], [x + s]) - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - op2 = OpFromGraph([x], [x + s]) - assert op1 == op2 - - # Same value, different shared object -> not equal - s2 = shared(np.array(1.0, dtype=config.floatX)) - with pytest.warns( - DeprecationWarning, - match="Implicit capture of shared variables is deprecated", - ): - op3 = OpFromGraph([x], [x + s2]) - assert op1 != op3 - def test_equality_callable_overrides(self): x, y = dscalars("x", "y") e = x + y @@ -921,6 +746,28 @@ def test_merge_identical_ofgs(self): np.testing.assert_allclose(r2, 4.0 + 5.0 * 4.0) +@pytest.mark.parametrize("linker", 
["cvm", "py"]) +def test_view_output_copied_at_boundary(linker): + # Regression test: an OpFromGraph output that aliases an input must be copied + # at the op boundary (OpFromGraph declares no view_map, so the outer graph + # cannot see the alias). Without the copy, a downstream op destroying the + # input in place corrupts the already-computed output. + x = pt.dvector("x") + op = OpFromGraph([x], [x[::-1]]) + + xin = pt.dvector("xin") + x2 = xin * 2 + view_out = op(x2) + destroyed = pt.inc_subtensor(x2[0], 100.0) + + fn = function( + [xin], [view_out, destroyed], mode=Mode(linker=linker, optimizer="fast_run") + ) + res_view, res_destroyed = fn(np.arange(1.0, 4.0)) + np.testing.assert_allclose(res_view, [6.0, 4.0, 2.0]) + np.testing.assert_allclose(res_destroyed, [102.0, 4.0, 6.0]) + + @config.change_flags(floatX="float64") def test_debugprint(): x, y, z = matrices("xyz") diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index ce2a5d5429..0da32429e2 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -8,6 +8,7 @@ from pytensor import tensor as pt from pytensor.compile.maker import UnusedInputError from pytensor.graph.basic import ( + AbstractApply, Apply, NominalVariable, Variable, @@ -186,13 +187,12 @@ def test_clone_inner_graph(self): o2.name = "o1" o2_node = o2.owner - o2_node_clone = o2_node.clone(clone_inner_graph=True) + o2_node_clone = o2_node.clone() + # Inner-graph Ops are immutable, so cloning a node shares the Op (and its + # inner graph) rather than deep-cloning it. 
assert o2_node_clone is not o2_node - assert o2_node_clone.op.fgraph is not o2_node.op.fgraph - assert equal_computations( - o2_node_clone.op.fgraph.outputs, o2_node.op.fgraph.outputs - ) + assert o2_node_clone.op is o2_node.op class TestEval: @@ -493,7 +493,9 @@ def test_interning_and_immutability(self): assert fa1 is not fa_diff_op assert fa1 is not fa_diff_order - assert isinstance(fa1, Apply) + # FrozenApply shares the read-only node API but is not a mutable Apply + assert isinstance(fa1, AbstractApply) + assert not isinstance(fa1, Apply) assert fa1.outputs[0].owner is fa1 assert fa1.outputs[0].index == 0 @@ -524,6 +526,6 @@ def test_constant_deduplication_via_frozen_fgraph(self): out1 = add(x, ScalarConstant(float64, 3.14)) out2 = add(x, ScalarConstant(float64, 3.14)) - ffg1 = FrozenFunctionGraph([x], [out1]) - ffg2 = FrozenFunctionGraph([x], [out2]) + ffg1 = FrozenFunctionGraph.from_io([x], [out1]) + ffg2 = FrozenFunctionGraph.from_io([x], [out2]) assert ffg1 == ffg2 diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 4dad31763e..81dfff8bb1 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -875,13 +875,13 @@ def test_orphan_non_constant_raises(self): orphan = MyVariable("orphan") out = op1(var1, orphan) with pytest.raises(ValueError, match=r"Orphan.*orphan"): - FrozenFunctionGraph([var1], [out]) + FrozenFunctionGraph.from_io([var1], [out]) def test_unmapped_output_raises(self): var1 = MyVariable("x") disconnected = MyVariable("disconnected") with pytest.raises(ValueError, match="could not be mapped"): - FrozenFunctionGraph([var1], [disconnected]) + FrozenFunctionGraph.from_io([var1], [disconnected]) def test_interned_constant_in_variables(self): """Regression test: all node inputs must appear in variables. 
@@ -898,7 +898,7 @@ def test_interned_constant_in_variables(self): # Populate FrozenApply cache: op_shared(NomVar_0, c1) x1 = MyVariable("x") c1 = MyConstant("c", data=42) - FrozenFunctionGraph([x1], [op_shared(x1, c1)]) + FrozenFunctionGraph.from_io([x1], [op_shared(x1, c1)]) # New graph with a fresh constant c2 (same value, different object). # op_unique: cache miss → FrozenApply stores c2 @@ -906,7 +906,7 @@ def test_interned_constant_in_variables(self): # Both c1 and c2 must be in variables. x2 = MyVariable("x") c2 = MyConstant("c", data=42) - fg = FrozenFunctionGraph([x2], [op_shared(x2, c2), op_unique(x2, c2)]) + fg = FrozenFunctionGraph.from_io([x2], [op_shared(x2, c2), op_unique(x2, c2)]) for node in fg.toposort(): for inp in node.inputs: @@ -918,8 +918,8 @@ def test_constant_output_equality(self): c2 = ScalarConstant(float64, 3.14) assert c1 is not c2 - ffg1 = FrozenFunctionGraph([], [c1]) - ffg2 = FrozenFunctionGraph([], [c2]) + ffg1 = FrozenFunctionGraph.from_io([], [c1]) + ffg2 = FrozenFunctionGraph.from_io([], [c2]) assert ffg1 == ffg2 assert hash(ffg1) == hash(ffg2) assert ffg1.outputs == ffg2.outputs @@ -973,8 +973,8 @@ def test_value_dependent_output_type_collision(self): # s1/s2 are MakeVector variables used as inputs, so freezing nominalizes them and # discards the [2,3]/[3,2] values that informed each reshape's output shape. - ffg1 = FrozenFunctionGraph([x, s1], [rs1]) - ffg2 = FrozenFunctionGraph([x, s2], [rs2]) + ffg1 = FrozenFunctionGraph.from_io([x, s1], [rs1]) + ffg2 = FrozenFunctionGraph.from_io([x, s2], [rs2]) # The original output types are part of the FrozenApply key, so the fgraphs aren't identical. 
# Alternative design: the two ffg get merged with a general output type (None, None) @@ -993,12 +993,19 @@ def test_bind_constant_output(self): """bind must handle constants that appear directly as outputs.""" x = float64("x") c = ScalarConstant(float64, 42.0) - ffg = FunctionGraph([x], [add(x, c), c]).freeze() + # c2 appears *only* as an output, never as a node input + c2 = ScalarConstant(float64, 7.0) + ffg = FunctionGraph([x], [add(x, c), c, c2]).freeze() y = float64("y") - bound = ffg.bind({ffg.inputs[0]: y}) - assert len(bound) == 2 + bound, memo = ffg.bind({ffg.inputs[0]: y}, return_memo=True) + assert len(bound) == 3 assert bound[1] is c + assert bound[2].data == 7.0 + # Callers index the memo with inner outputs directly (e.g. ScanMerge), + # so it must cover constant-only outputs too. + assert memo[ffg.outputs[1]] is bound[1] + assert memo[ffg.outputs[2]] is bound[2] def test_from_structural_inputs_only_root_inputs(self): """All inputs are roots: behaves like the plain constructor.""" diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 66b8791cfb..555e27b5cc 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -11,7 +11,9 @@ import pytensor import pytensor.tensor.basic as ptb +from pytensor.compile.mode import Mode from pytensor.configdefaults import config +from pytensor.link.jax.linker import JAXLinker from pytensor.tensor.type import matrix, scalar, vector from tests.link.jax.test_basic import compare_jax_and_py from tests.tensor.test_basic import check_alloc_runtime_broadcast @@ -68,6 +70,14 @@ def test_arange_of_shape(): compare_jax_and_py([x], [out], [np.zeros((5,))], jax_mode="JAX") +def test_arange_of_shape_minimum_compile(): + # The shape read stays `Subtensor(Shape(x))`, instead of `Shape_i(x)` + x = vector("x") + out = ptb.arange(1, x.shape[-1], 2) + minimum_jax_mode = Mode(linker=JAXLinker(), optimizer=None) + compare_jax_and_py([x], [out], [np.zeros((5,))], 
jax_mode=minimum_jax_mode) + + def test_arange_nonconcrete(): """JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values.""" diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py index b8b8f1b56b..e2a1ac08fa 100644 --- a/tests/link/numba/test_compile_ops.py +++ b/tests/link/numba/test_compile_ops.py @@ -13,6 +13,7 @@ from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky +from tests import unittest_tools as utt from tests.link.numba.test_basic import compare_numba_and_py @@ -235,13 +236,14 @@ def test_check_and_raise(): def test_ofg_with_inner_scan_rewrite(): - # Regression test where inner scan would be mutated when compiling outer OFG - ys = pt.tensor("ys", shape=(5, 3, 3)) + # Regression test where inner scan would be mutated when compiling outer OFG. + # The inner cholesky is *batched* (over the size-4 axis) so the Blockwise + # survives optimization and is wrapped as BlockwiseWithCoreShape for numba. + ys = pt.tensor("ys", shape=(5, 4, 3, 3)) xs = scan( lambda y: cholesky(y), sequences=[ys], return_updates=False, - mode=Mode(optimizer=None), ) xs_ofg = OpFromGraph([ys], [xs])(ys) fn = function([ys], xs_ofg, mode="NUMBA") @@ -265,6 +267,72 @@ def test_ofg_with_inner_scan_rewrite(): assert isinstance(cholesky_op.core_op, Cholesky) +def test_compiling_does_not_mutate_canonical_inner_graph(): + # Regression test: compiling an op with an inner graph must NOT mutate the + # canonical (shared) inner FunctionGraph. Backend specialization (e.g. the + # numba ``BlockwiseWithCoreShape`` wrapping) must happen on per-compilation + # copies, never on the op the user holds -- otherwise a second use of the + # same op (here: deriving something from it after an ``.eval()``) sees a + # corrupted inner graph. This is what tripped scan-based pymc CustomDists. 
+ ys = pt.tensor("ys", shape=(5, 4, 3, 3)) + xs = scan(lambda y: cholesky(y), sequences=[ys], return_updates=False) + ofg_out = OpFromGraph([ys], [xs])(ys) + + scan_op = ofg_out.owner.op.fgraph.outputs[0].owner.op + assert isinstance(scan_op, Scan) + + # Compile (and run) under numba: this used to mutate the inner graph above. + fn = function([ys], ofg_out, mode="NUMBA") + fn(np.eye(3)[None, None].repeat(5, 0).repeat(4, 1)) + + # The compiled function carries a different scan op whose inner cholesky was + # wrapped for the backend... + fn_scan_op = fn.maker.fgraph.outputs[0].owner.op.fgraph.outputs[0].owner.op + assert isinstance(fn_scan_op, Scan) + assert fn_scan_op is not scan_op + assert isinstance(fn_scan_op.fgraph.outputs[0].owner.op, BlockwiseWithCoreShape) + # ...while the canonical op still computes a bare batched cholesky. + y_t = pt.tensor("y_t", shape=(4, 3, 3)) + utt.assert_equal_computations( + list(scan_op.fgraph.outputs), + [cholesky(y_t)], + in_xs=list(scan_op.fgraph.inputs), + in_ys=[y_t], + ) + + +def test_blockwise_inner_graph_optimized_for_backend(): + # Regression test for https://github.com/pymc-devs/pytensor/issues/2028. + # An OpFromGraph wrapped in a Blockwise must still have its inner graph lowered + # for the backend -- otherwise numba falls back to object mode (here the inner + # cholesky stays a degenerate Blockwise instead of collapsing to a bare + # Cholesky). And, the original #2028 concern, that lowering must happen on a + # per-compilation copy and never mutate the canonical (shared) inner graph. 
+ yy = pt.matrix("yy") + core = OpFromGraph([yy], [cholesky(yy) + 1.0]) + xs = pt.tensor("xs", shape=(4, 3, 3)) + out = Blockwise(core, signature="(m,m)->(m,m)")(xs) + canonical = out.owner.op.core_op + + val = (np.eye(3)[None] * 2.0).repeat(4, 0).astype(config.floatX) + fn = function([xs], out, mode="NUMBA") + np.testing.assert_allclose(fn(val), np.linalg.cholesky(val) + 1.0) + + # The compiled op carries its own, backend-lowered inner graph: the degenerate + # Blockwise(Cholesky) collapsed to a bare Cholesky (so numba does not object-mode). + compiled = fn.maker.fgraph.outputs[0].owner.op.core_op + assert compiled is not canonical + assert not any(isinstance(n.op, Blockwise) for n in compiled.fgraph.apply_nodes) + # The canonical inner graph must be untouched by compilation. + yy2 = pt.matrix("yy2") + utt.assert_equal_computations( + list(canonical.fgraph.outputs), + [cholesky(yy2) + 1.0], + in_xs=list(canonical.fgraph.inputs), + in_ys=[yy2], + ) + + @pytest.mark.parametrize("as_view", [True, False]) def test_ifelse_single_output(as_view, single_out=True): x = pt.vector("x") diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index eaf10b9848..b577d65a78 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -439,13 +439,15 @@ def step(seq_t, x2, x1): if isinstance(node.op, Scan) and node.op.info.n_mit_mot ] assert mitmot_scan_ops, "expected the gradient to produce a mit-mot scan" - destroyed_mitmot_reads = [ - inp - for scan_op in mitmot_scan_ops - if hasattr(scan_op.fgraph, "destroyers") - for inp in scan_op.inner_mitmot(scan_op.fgraph.inputs) - if scan_op.fgraph.destroyers(inp) - ] + destroyed_mitmot_reads = [] + for scan_op in mitmot_scan_ops: + mitmot_reads = set(scan_op.inner_mitmot(scan_op.fgraph.inputs)) + for node in scan_op.fgraph.toposort(): + destroyed_mitmot_reads.extend( + node.inputs[idx] + for idx in chain.from_iterable(node.op.destroy_map.values()) + if node.inputs[idx] in mitmot_reads + ) assert 
destroyed_mitmot_reads, "expected a mit-mot read destroyed in place" diff --git a/tests/scalar/test_loop.py b/tests/scalar/test_loop.py index 0b72090a06..6cca6cbaf9 100644 --- a/tests/scalar/test_loop.py +++ b/tests/scalar/test_loop.py @@ -332,18 +332,7 @@ def test_identical_loops_share_inner_graph(): assert hash(op1) == hash(op2) assert op1.fgraph == op2.fgraph - # Two loops with the same structure but different outer inputs. - # MergeOptimizer can't collapse the Apply nodes (different inputs), - # but both should reference the same inner Op after merging. - n = int64("n") - a, b, c_val, d = float64("a"), float64("b"), float64("c_val"), float64("d") - y1 = op1(n, a, b) - y2 = op2(n, c_val, d) - - fn = function( - [n, a, b, c_val, d], [y1, y2], mode=Mode(optimizer="merge", linker="py") - ) - nodes = fn.maker.fgraph.toposort() - loop_nodes = [nd for nd in nodes if isinstance(nd.op, ScalarLoop)] - assert len(loop_nodes) == 2 - assert loop_nodes[0].op is loop_nodes[1].op + # Structurally identical inner graphs are globally interned via FrozenApply, + # so the two distinct op wrappers share the very same inner-graph nodes (the + # heavy state) at construction -- no compilation or canonicalization needed. 
+ assert op1.fgraph.outputs[0].owner is op2.fgraph.outputs[0].owner diff --git a/tests/scan/rewriting/test_inplace.py b/tests/scan/rewriting/test_inplace.py index 33e2f89173..190f6fea75 100644 --- a/tests/scan/rewriting/test_inplace.py +++ b/tests/scan/rewriting/test_inplace.py @@ -5,7 +5,7 @@ import pytensor.tensor as pt from pytensor import function, scan, shared from pytensor.compile.io import In -from pytensor.compile.mode import get_default_mode +from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config from pytensor.graph.basic import equal_computations from pytensor.graph.fg import FunctionGraph @@ -13,6 +13,7 @@ from pytensor.scan.op import Scan from pytensor.scan.rewriting import ScanInplaceOptimizer from pytensor.tensor.random.op import RandomVariableWithCoreShape +from pytensor.tensor.random.type import random_generator_type from pytensor.tensor.type import scalar, vector from tests import unittest_tools as utt from tests.scan.test_basic import asarrayX @@ -237,6 +238,36 @@ def test_inplace_untraced_sit_sot(self): # Evaluate and check non equality assert f() != f() + def test_untraced_sit_sot_unowned_not_inplaced(self): + # Under the C/VM backend, an untraced sit_sot whose outer buffer the Scan does + # not own must not be destroyed in place by the step fn -- otherwise the + # destruction would reach back to the caller's input. (An rng state is the only + # kind that stays untraced under C/VM; an array gets a length-2 buffer and stays + # a plain sit_sot.) Numba covers the same case by always destroying and copying + # the first iteration; here we pin the C/VM behavior, so force a cvm mode. 
+ cvm = Mode(linker="cvm", optimizer="fast_run") + rng = random_generator_type("rng") # a function input -> not owned by the Scan + x0 = pt.scalar("x0") + _, xs = scan( + fn=lambda r, x: pt.random.normal(x, rng=r).owner.outputs, + outputs_info=[rng, x0], + n_steps=5, + return_updates=False, + ) + f = function([rng, x0], xs[-1], mode=cvm) + [op] = [n.op for n in f.maker.fgraph.toposort() if isinstance(n.op, Scan)] + + assert op.info.n_untraced_sit_sot == 1 + untraced_start = op.n_tap_outs + op.info.n_nit_sot + # Not owned by the Scan, so the C/VM inner step does not destroy it in place. + assert untraced_start not in op.destroy_map + + # Correct numerics vs a no-inplace reference (fresh, equal rngs). + ref = function([rng, x0], xs[-1], mode=cvm.excluding("inplace")) + utt.assert_allclose( + f(np.random.default_rng(123), 0.5), ref(np.random.default_rng(123), 0.5) + ) + def test_inplace3(self): rng = np.random.default_rng(utt.fetch_seed()) diff --git a/tests/scan/rewriting/test_merge.py b/tests/scan/rewriting/test_merge.py index 384a58c415..5264d1881e 100644 --- a/tests/scan/rewriting/test_merge.py +++ b/tests/scan/rewriting/test_merge.py @@ -10,6 +10,7 @@ from pytensor.scan.op import Scan from pytensor.scan.rewriting import ScanMerge from pytensor.scan.utils import until +from pytensor.tensor import constant as pt_constant from pytensor.tensor import stack from pytensor.tensor.type import scalar, vector from tests import unittest_tools as utt @@ -73,6 +74,29 @@ def sum(s): f = function([x], [sx, sy], mode=self.mode) assert self.count_scans(f) == 2 + def test_constant_inner_output(self): + """Scans whose inner graph returns a bare constant output must still merge.""" + x = vector() + y = vector() + + def step(s): + return s + 1, pt_constant(np.asarray(3.0, dtype=config.floatX)) + + (sx, cx), _upx = scan(step, sequences=[x], n_steps=4) + (sy, cy), _upy = scan(step, sequences=[y], n_steps=4) + + with config.change_flags(on_opt_error="raise"): + f = function([x, y], [sx, 
cx, sy, cy], mode=self.mode) + assert self.count_scans(f) == 1 + + x_val = np.arange(4, dtype=config.floatX) + y_val = np.arange(10, 14, dtype=config.floatX) + res_sx, res_cx, res_sy, res_cy = f(x_val, y_val) + np.testing.assert_allclose(res_sx, x_val + 1) + np.testing.assert_allclose(res_sy, y_val + 1) + np.testing.assert_allclose(res_cx, np.full(4, 3.0)) + np.testing.assert_allclose(res_cy, np.full(4, 3.0)) + def test_three_scans(self): r""" This test checks a case where we have three `Scan`\s, two of them diff --git a/tests/scan/rewriting/test_push_out.py b/tests/scan/rewriting/test_push_out.py index 508abaa150..8aa55949c9 100644 --- a/tests/scan/rewriting/test_push_out.py +++ b/tests/scan/rewriting/test_push_out.py @@ -464,10 +464,10 @@ def test_OpFromGraph_shared(self): y = shared(1.0, name="y") - test_ofg = OpFromGraph([], [1 + y]) + test_ofg = OpFromGraph([y], [1 + y]) def inner_func(): - return test_ofg() + return test_ofg(y) out, out_updates = pytensor.scan(inner_func, n_steps=10) @@ -484,17 +484,17 @@ def inner_func(): def test_nested_OpFromGraph_shared(self): y = pytensor.shared(1.0, name="y") - test_ofg = OpFromGraph([], [y]) + test_ofg = OpFromGraph([y], [y]) def inner_func(x): - out = pytensor.scan(lambda: test_ofg(), n_steps=x, return_updates=False) + out = pytensor.scan(lambda: test_ofg(y), n_steps=x, return_updates=False) return out out = pytensor.scan( inner_func, sequences=[pt.arange(1, 2)], return_updates=False ) - _ = pytensor.function([], test_ofg()) + _ = pytensor.function([], test_ofg(y)) out_fn = pytensor.function([], out) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 1bbd183953..1f7a5fdc73 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -27,12 +27,12 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, disconnected_grad, grad, pushforward -from pytensor.graph.basic import Apply, Variable, 
equal_computations +from pytensor.graph.basic import Apply, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.replace import vectorize_graph from pytensor.graph.rewriting.basic import MergeOptimizer -from pytensor.graph.traversal import ancestors +from pytensor.graph.traversal import ancestors, apply_ancestors from pytensor.graph.utils import MissingInputError from pytensor.link.vm import VMLinker from pytensor.raise_op import assert_op @@ -299,11 +299,9 @@ def test_clone(self): scan_op = output.owner.op assert isinstance(scan_op, Scan) - scan_op_clone = scan_op.clone() - assert scan_op_clone is not scan_op - assert scan_op_clone.fgraph is not scan_op.fgraph - assert scan_op_clone.fgraph.outputs != scan_op.fgraph.outputs - assert equal_computations(scan_op_clone.fgraph.outputs, scan_op.fgraph.outputs) + # Scan ops are immutable (single frozen inner graph), so cloning returns + # self -- mirroring Composite. + assert scan_op.clone() is scan_op @pytest.mark.skipif( isinstance(get_default_mode(), DebugMode), @@ -814,6 +812,30 @@ def test_hash(self): assert scan1.owner.op == scan2.owner.op assert hash(scan1.owner.op) == hash(scan2.owner.op) + def test_hash_equality_after_inner_optimization(self): + """Regression test for #1601: the frozen inner graph keeps a `Scan` `Op`'s + equality in sync with its hash even after compilation optimizes it.""" + x0 = scalar("x0") + xs = scan(lambda x: x + 0, outputs_info=[x0], n_steps=5, return_updates=False) + ys = scan(lambda x: x * 1, outputs_info=[x0], n_steps=5, return_updates=False) + + # Before compilation the inner graphs differ (``x + 0`` vs ``x * 1``), so the + # ops -- and their hashes -- differ. + op1, op2 = ( + node.op for node in apply_ancestors([xs, ys]) if isinstance(node.op, Scan) + ) + assert op1 != op2 + assert hash(op1) != hash(op2) + + # Compilation optimizes both inner graphs to the identity; the ops must then + # be equal *and* hash equal. 
+ fn = function([x0], [xs, ys], mode=get_default_mode().excluding("scan")) + op1, op2 = ( + node.op for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan) + ) + assert op1 == op2 + assert hash(op1) == hash(op2) + def test_can_merge(self): """Make sure that equivalent `Scan` nodes can be merged.""" diff --git a/tests/tensor/rewriting/test_ofg.py b/tests/tensor/rewriting/test_ofg.py index 7427ecbdc8..047f4bb5e0 100644 --- a/tests/tensor/rewriting/test_ofg.py +++ b/tests/tensor/rewriting/test_ofg.py @@ -4,9 +4,9 @@ import pytensor.tensor as pt from pytensor import config from pytensor.compile.builders import OpFromGraph +from pytensor.compile.rewriting import inline_ofg_expansion from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import dfs_rewriter -from pytensor.tensor.rewriting.ofg import inline_ofg_expansion @pytest.mark.skipif( diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py index 1943d7a394..093cb06efe 100644 --- a/tests/tensor/test_einsum.py +++ b/tests/tensor/test_einsum.py @@ -37,7 +37,7 @@ def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None: if isinstance(node.op, HasInnerGraph): # InnerGraph Ops can be rewritten without modifying the original fgraph - if hasattr(node.op, "_fn"): + if getattr(node.op, "_fn", None) is not None: inner_fgraph = node.op._fn.maker.fgraph else: inner_fgraph = node.op.fgraph diff --git a/tests/test_printing.py b/tests/test_printing.py index 973c7c5888..55476ecb34 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -567,9 +567,10 @@ def add_one_composite(): for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): assert exp_line.strip() == res_line.strip() - # An Op that only appears nested inside other inner graphs: its nodes are - # only discovered while the parent bodies are printed, and the shared - # header must still list every node id + # An Op that only appears nested inside other inner graphs is still + 
# discovered and printed in the "Inner graphs" section. Here both A and B + # apply the same Relu to their (nominal) input, so global FrozenApply + # interning collapses it to a single shared node, printed once. i1 = dvector("i") a_op = OpFromGraph([i1], [relu_ofg()(i1) + 1], inline=False, name="A") i2 = dvector("i") @@ -594,16 +595,16 @@ def add_one_composite(): B{inline=False} [id D] ← Mul [id J] - ├─ Relu{inline=False} [id K] - │ └─ i0 [id G] - └─ ExpandDims{axis=0} [id L] - └─ 2 [id M] + ├─ Relu{inline=False} [id F] + │ └─ ··· + └─ ExpandDims{axis=0} [id K] + └─ 2 [id L] -Relu{inline=False} [id F, K] - ← Maximum [id N] +Relu{inline=False} [id F] + ← Maximum [id M] ├─ i0 [id G] - └─ ExpandDims{axis=0} [id O] - └─ 0 [id P] + └─ ExpandDims{axis=0} [id N] + └─ 0 [id O] """ for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True):