Make inner graph Ops immutable #2243

Merged
ricardoV94 merged 7 commits into pymc-devs:main from ricardoV94:real_frozen_graphs
Jul 3, 2026

@ricardoV94 ricardoV94 commented Jun 18, 2026

Member

Closes #2232
Closes #2028
Closes #2033
Closes #2194 by adding a method (should it be the default?) where there's no node deduplication, as the toposort index becomes part of the hash
Closes #1601
Closes #2035

No more "cloning" of inner graphs

@ricardoV94 ricardoV94 force-pushed the real_frozen_graphs branch 5 times, most recently from fe15202 to 726ee00 Compare June 20, 2026 15:59
@ricardoV94 ricardoV94 force-pushed the real_frozen_graphs branch 5 times, most recently from 67d4867 to f77b580 Compare June 28, 2026 22:41
Comment thread pytensor/compile/builders.py
Comment thread pytensor/compile/builders.py Outdated
@ricardoV94 ricardoV94 force-pushed the real_frozen_graphs branch 8 times, most recently from 7dbeea9 to c266735 Compare July 2, 2026 07:51
A graph compiled with only the minimum_compile rewrites keeps a shape
element read as Subtensor(Shape(x)) instead of canonicalizing it to
Shape_i(x). Both are concrete under JAX tracing, so accept either form
as an arange bound. Inner graphs will soon reach the backend this way.

Passing shared variables implicitly through an OpFromGraph inner graph was
soft-deprecated in pymc-devs#2047; make it a hard error. construct_nominal_fgraph now
raises MissingInputError for any non-input, non-constant dependency (including
shared variables), and the now-dead shared-input machinery (make_node/__call__/
__eq__ handling, the shared_inputs attribute) is removed from OpFromGraph.
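
The rule in this commit message (every root dependency must be an explicit input or a constant) can be pictured with a toy graph walker. This is only a sketch in plain Python, not the actual `construct_nominal_fgraph`; the `Var` class, its `is_shared` flag, the `check_inputs` function and the local `MissingInputError` class are all hypothetical stand-ins.

```python
from dataclasses import dataclass


class MissingInputError(Exception):
    """Local stand-in for the error type named in the commit message."""


@dataclass(frozen=True, eq=False)
class Var:
    """Hypothetical toy variable; not a PyTensor class."""

    name: str
    parents: tuple = ()        # variables this one is computed from
    is_constant: bool = False
    is_shared: bool = False    # marks a toy "shared variable"


def check_inputs(inputs, outputs):
    """Raise if any root dependency is neither an explicit input nor a constant."""
    allowed = {id(v) for v in inputs}
    seen, stack = set(), list(outputs)
    while stack:
        var = stack.pop()
        if id(var) in seen or id(var) in allowed:
            continue
        seen.add(id(var))
        if var.parents:
            stack.extend(var.parents)
        elif not var.is_constant:
            kind = "shared variable" if var.is_shared else "variable"
            raise MissingInputError(f"{kind} {var.name!r} is not an explicit input")


x = Var("x")
s = Var("s", is_shared=True)
two = Var("2", is_constant=True)
out = Var("out", parents=(x, s, two))

check_inputs([x, s], [out])       # accepted: s is passed explicitly
try:
    check_inputs([x], [out])      # s is only reached implicitly
    raised = False
except MissingInputError:
    raised = True
```

In this toy version the shared variable gets no special path: it is rejected for the same reason as any other undeclared root.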
@ricardoV94 ricardoV94 force-pushed the real_frozen_graphs branch 3 times, most recently from e5d09f3 to ca73d56 Compare July 2, 2026 13:08
@ricardoV94 ricardoV94 force-pushed the real_frozen_graphs branch 2 times, most recently from 8fe033c to d5907bc Compare July 2, 2026 14:03
Comment thread pytensor/graph/basic.py Outdated
Comment thread pytensor/compile/aliasing.py Outdated
Comment thread pytensor/compile/aliasing.py Outdated
Comment thread pytensor/tensor/rewriting/optimize.py
Comment thread pytensor/tensor/rewriting/optimize.py
Comment thread pytensor/tensor/rewriting/optimize.py

@jessegrabowski jessegrabowski left a comment

Member

We had a call to discuss this in depth. Looks good.

Comment thread pytensor/graph/basic.py
there are multiple outputs and self.op.default_output does not exist.

"""
do = getattr(self.op, "default_output", None)

Member

nit: call it default_output

Member Author

just so you know this was previous code, just moved around

@ricardoV94 ricardoV94 force-pushed the real_frozen_graphs branch from d5907bc to 9775d3a Compare July 2, 2026 20:45
Introduce FrozenFunctionGraph and FrozenApply: an immutable, hash-consable
representation of an inner graph, plus an AbstractApply base shared by the mutable
Apply and the immutable FrozenApply. FunctionGraph.freeze(dedup_nodes=...) builds one,
baking a destroy-aware toposort by forwarding the source graph's orderings (a
DestroyHandler, if present, makes the order safe for a backend to funcify as-is; a
regular graph stays the plain toposort at no cost). Apply.clone_with_new_inputs rebuilds
immutable nodes instead of mutating in place. Composite and ScalarLoop adopt the frozen
representation for their inner graphs. The clone_inner_graph(s) machinery is retained
here and removed once every inner-graph op is immutable.
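
To make "immutable, hash-consable" concrete, here is a minimal sketch of hash-consing in plain Python. It is not the `FrozenApply` implementation from this PR: the `FrozenNode` class and its `index` argument are hypothetical, with `index` standing in for the toposort index that, per the PR description, becomes part of the hash when nodes are not deduplicated.

```python
class FrozenNode:
    """Toy immutable node interned by (op, inputs, index)."""

    _cache = {}

    def __new__(cls, op, inputs, index=None):
        key = (op, tuple(inputs), index)
        try:
            # Structurally identical nodes share one object.
            return cls._cache[key]
        except KeyError:
            node = super().__new__(cls)
            # Bypass our own __setattr__ once, at construction time.
            object.__setattr__(node, "op", op)
            object.__setattr__(node, "inputs", tuple(inputs))
            object.__setattr__(node, "index", index)
            cls._cache[key] = node
            return node

    def __setattr__(self, name, value):
        raise AttributeError("FrozenNode is immutable")


# With deduplication: same op and inputs give the same object.
a = FrozenNode("exp", ("x",))
b = FrozenNode("exp", ("x",))

# Without deduplication: the position in the toposort is part of the key,
# so two otherwise identical nodes stay distinct.
c = FrozenNode("exp", ("x",), index=0)
d = FrozenNode("exp", ("x",), index=1)

try:
    a.op = "log"
    mutation_blocked = False
except AttributeError:
    mutation_blocked = True
```

Because equal nodes are the same object, the default identity-based `__eq__` and `__hash__` already behave structurally, which is what makes nested nodes cheap to use as dictionary keys.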
Scan and OpFromGraph carry a single frozen inner graph (op.fgraph) and are optimized
by a graph rewrite instead of at link time. Each inner-graph op has its own rewrite
that dispatches the inner-graph inplace contract on the linker via
functools.singledispatch with a raising base, so every backend registers an
implementation explicitly (scan/rewriting/{c,numba,jit}.py, tensor/rewriting/{numba,
ofg}.py). numba owns its Scan memory model (+ potentially_overwritten_reads) and the
boundary deepcopies in scan/rewriting/numba.py; numba funcify just funcifies op.fgraph
and reads codegen metadata off the op (no graph mutation). The shared orchestration
lives in compile/rewriting.py::rewrite_inner_graph, which groups nodes by inner op plus
input types, destroy_map and view_map (ops sharing a frozen graph but with different
inplace permissions must bake separately).
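
The "singledispatch with a raising base" pattern mentioned above looks roughly like this with the standard library alone. The linker classes and the `inner_graph_inplace_policy` function are hypothetical names for illustration, not the PyTensor ones, and the returned strings only paraphrase the behaviour described in this commit message.

```python
from functools import singledispatch


class Linker:
    """Hypothetical base class for backends."""


class CLinker(Linker):
    pass


class NumbaLinker(Linker):
    pass


class JAXLinker(Linker):
    pass


@singledispatch
def inner_graph_inplace_policy(linker):
    # Raising base: a backend without a registration fails loudly
    # instead of silently inheriting some default behaviour.
    raise NotImplementedError(
        f"no inner-graph rewrite registered for {type(linker).__name__}"
    )


@inner_graph_inplace_policy.register(CLinker)
def _(linker):
    return "destroy an untraced state only when the outer op owns it"


@inner_graph_inplace_policy.register(NumbaLinker)
def _(linker):
    return "always destroy untraced states, copy the first iteration"


c_policy = inner_graph_inplace_policy(CLinker())
numba_policy = inner_graph_inplace_policy(NumbaLinker())

try:
    inner_graph_inplace_policy(JAXLinker())  # nothing registered in this sketch
    unregistered_raises = False
except NotImplementedError:
    unregistered_raises = True
```

The raising base is what forces each backend to opt in explicitly: dispatch walks the class hierarchy, finds no match for an unregistered linker, and lands on the base function.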

The OpFromGraph rewrite registers at position 49.6; the Scan rewrite runs at 100,
after the inplace passes, so its inner graph is baked with the outer Scan's final
destroy/view permissions -- a tap may only be destroyed in place when the outer Scan
owns its buffer. numba always destroys untraced states and copies the first iteration;
the C/VM rewrite instead destroys an untraced state only when the Scan owns it.

The rewrite owns unfreeze -> optimize -> freeze (the optimized graph still carries a
DestroyHandler, so its freeze is destroy-aware) and the ops store the frozen graph
verbatim: Scan.clone_with_inner_graph is copy(self) + inner.freeze() (no rebuild);
OpFromGraph rebuilds via construct_nominal_fgraph and re-attaches a DestroyHandler
gated on inplace.

ScipyWrapperOp (Minimize / Root) carries a frozen inner graph and registers a
per-linker optimize_inner_graph rewrite (noinplace, via the shared
compile/rewriting.py::rewrite_inner_graph helper), like Scan and OpFromGraph.

Cloning an outer graph used to optionally deep-clone the inner graphs of
HasInnerGraph ops (clone_inner_graph(s)=True). Now that every inner-graph op
(Scan/OpFromGraph/Composite/ScipyWrapperOp) is immutable -- clone() returns self --
that deep-clone is always a no-op, so drop the branch throughout the clone machinery
(Apply.clone, Apply.clone_with_new_inputs, clone, clone_get_equiv,
FunctionGraph.clone, rebuild_collect_shared). The public clone_inner_graph(s) kwarg is
kept as a deprecated, ignored no-op that emits a FutureWarning; only the internal
clone_node_and_cache (and its op-clone caching) loses it outright.
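
A "deprecated, ignored no-op that emits a FutureWarning" can be sketched as follows. This is an assumption about the general shape, not the code in this PR: `clone_graph` is a hypothetical function, and the sentinel default is just one common way to tell "not passed" from "passed as False".

```python
import warnings

_UNSET = object()  # sentinel: distinguishes "not passed" from any real value


def clone_graph(outputs, clone_inner_graphs=_UNSET):
    """Toy clone that accepts, warns about, and then ignores the old kwarg."""
    if clone_inner_graphs is not _UNSET:
        warnings.warn(
            "clone_inner_graphs is deprecated and has no effect",
            FutureWarning,
            stacklevel=2,
        )
    return list(outputs)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    plain = clone_graph(["out"])                              # no warning
    legacy = clone_graph(["out"], clone_inner_graphs=True)    # warns, value ignored
```

Both calls return the same result; only the second one records a `FutureWarning`.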
@ricardoV94 ricardoV94 force-pushed the real_frozen_graphs branch from 9775d3a to ed6f73d Compare July 2, 2026 22:30
@ricardoV94 ricardoV94 merged commit 6b70d2c into pymc-devs:main Jul 3, 2026
66 checks passed
@ricardoV94 ricardoV94 deleted the real_frozen_graphs branch July 3, 2026 05:15