diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 7eb7150fef..e225698514 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -47,6 +47,25 @@ def _match_tangent_dtype(var: Variable, tangent: Variable) -> Variable: grad_time: float = 0.0 +# Hooks that rewrite the graph being differentiated before the gradient traversal. +# A subsystem whose Ops are not differentiable node-by-node (e.g. xtensor, which is +# only meaningful once lowered to tensor Ops) registers one to collapse such a region +# into a single differentiable node, so `grad` handles it as a unit. Each hook takes +# the differentiated outputs and the boundary variables (`wrt` and `consider_constant`, +# which must stay reachable) and returns rewritten outputs. +_grad_graph_rewriters: list[ + Callable[[list[Variable], list[Variable]], list[Variable]] +] = [] + + +def register_grad_graph_rewriter( + fn: Callable[[list[Variable], list[Variable]], list[Variable]], +) -> Callable[[list[Variable], list[Variable]], list[Variable]]: + """Register a hook applied to the graph being differentiated by `grad`.""" + _grad_graph_rewriters.append(fn) + return fn + + # TODO: Add `overload` variants def as_list_or_tuple[V: Variable | None]( use_list: bool, use_tuple: bool, outputs: V | Sequence[V] @@ -671,6 +690,27 @@ def grad( else: _wrt = list(wrt) + if _grad_graph_rewriters and (cost is not None or known_grads): + # Let a subsystem (e.g. xtensor) collapse a region it only differentiates as a + # whole into a single node before traversal. `wrt` and `consider_constant` are + # kept as graph inputs so grad can still reach (and stop at) them after collapse. 
+ boundaries = [*_wrt, *(consider_constant or ())] + roots = ([cost] if cost is not None else []) + list(known_grads or ()) + for rewrite in _grad_graph_rewriters: + new_roots = list(rewrite(roots, boundaries)) + if len(new_roots) != len(roots): + # Roots are matched back to `cost`/`known_grads` by position, so a + # hook that adds or drops outputs must fail loudly, not misassign. + raise ValueError( + f"Grad graph rewriter {rewrite} returned {len(new_roots)} outputs " + f"for {len(roots)} inputs" + ) + roots = new_roots + if cost is not None: + cost, *roots = roots + if known_grads: + known_grads = dict(zip(roots, known_grads.values(), strict=True)) + outputs = [] if cost is not None: outputs.append(cost) diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index 22de7642f1..dfaf74959f 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -1,3 +1,4 @@ +import pytensor.xtensor.gradient # registers the grad hook for xtensor regions import pytensor.xtensor.rewriting from pytensor.xtensor import linalg, math, random, signal from pytensor.xtensor.math import dot, where diff --git a/pytensor/xtensor/gradient.py b/pytensor/xtensor/gradient.py new file mode 100644 index 0000000000..3cd2cee40c --- /dev/null +++ b/pytensor/xtensor/gradient.py @@ -0,0 +1,208 @@ +"""Make xtensor graphs differentiable through ``pytensor.grad``. + +xtensor Ops carry no gradient of their own: they are meaningful only once lowered to +tensor Ops. Rather than differentiate them node-by-node, this grabs each xtensor region +-- the subgraph between the tensor<->xtensor conversion boundaries -- as a single +``OpFromGraph`` unit whose inner graph is the region lowered to tensor Ops. ``grad`` +then differentiates the unit as a whole through the ordinary tensor rules. The lowering +happens here, in a graph pass registered with ``grad``, never inside an Op's pullback. 
+""" + +from typing import cast + +from pytensor.compile.builders import OpFromGraph +from pytensor.gradient import register_grad_graph_rewriter +from pytensor.graph.basic import Constant, Variable +from pytensor.graph.replace import graph_replace +from pytensor.graph.rewriting.utils import rewrite_graph +from pytensor.graph.traversal import ancestors, toposort +from pytensor.tensor.type_other import SliceType +from pytensor.xtensor.basic import ( + TensorFromXTensor, + tensor_from_xtensor, + xtensor_from_tensor, +) +from pytensor.xtensor.random.type import ( + XRandomGeneratorType, + XRNGToRNG, + rng_to_xrng, + xrng_to_rng, +) +from pytensor.xtensor.type import XTensorType + + +_XTENSOR_TYPES = (XTensorType, XRandomGeneratorType) +# Ops that carry an xtensor value back into the tensor world (the region's outputs). +_EXIT_OPS = (TensorFromXTensor, XRNGToRNG) + + +def _is_xtensor(var: Variable) -> bool: + return isinstance(var.type, _XTENSOR_TYPES) + + +def _to_tensor_world(var: Variable) -> Variable: + """Convert an xtensor-world variable to its tensor-world equivalent.""" + if isinstance(var.type, XTensorType): + return cast(Variable, tensor_from_xtensor(var)) + if isinstance(var.type, XRandomGeneratorType): + return cast(Variable, xrng_to_rng(var)) + return var + + +def _from_tensor_world(var: Variable, dummy: Variable) -> Variable: + """Rebuild the xtensor-world equivalent of ``var`` from a tensor-world dummy.""" + if isinstance(var.type, XTensorType): + return cast(Variable, xtensor_from_tensor(dummy, dims=var.type.dims)) + if isinstance(var.type, XRandomGeneratorType): + return cast(Variable, rng_to_xrng(dummy)) + return dummy + + +def _collapse_region( + exit_var: Variable, boundaries: set[Variable], memo: dict[Variable, Variable] +) -> Variable: + """Build the lowered OpFromGraph unit replacing the xtensor region at ``exit_var``. 
+ + The unit's inner graph is built from the original region; its outer inputs are + resolved through ``memo`` so inner regions already collapsed in this pass are + consumed through their fresh replacements. + """ + # Walk the xtensor cone of exit_var up to its boundaries. Non-xtensor inputs (tensors + # entering via XTensorFromTensor, indices, rngs) and any boundary/leaf become unit + # inputs; constants stay inside the unit. `boundaries` are the wrt and consider_constant + # variables, kept as inputs so grad still sees (and can stop at) them after collapse. + inputs: list[Variable] = [] + seen: set[Variable] = set() + + def visit(var: Variable) -> None: + # Keep var inside the unit and recurse through its node's inputs. + if var in seen: + return + seen.add(var) + if var in boundaries or var.owner is None: + if not isinstance(var, Constant): + inputs.append(var) + return + for inp in var.owner.inputs: + visit_input(inp) + + def visit_input(inp: Variable) -> None: + if inp not in boundaries and ( + _is_xtensor(inp) + # A computed slice (e.g. MakeSlice) must also stay inside the unit: passed + # in as an opaque SliceType input the lowering could not pattern-match it. + # Its symbolic components become unit inputs instead; constants stay inside. + or (isinstance(inp.type, SliceType) and inp.owner is not None) + ): + visit(inp) + elif not isinstance(inp, Constant): + inputs.append(inp) + + for inp in exit_var.owner.inputs: + visit_input(inp) + inputs = list(dict.fromkeys(inputs)) + + # Keep the unit's boundaries tensor-typed: convert any xtensor-world input outside + # the unit and rebuild it inside. This keeps the unit a plain tensor->tensor op that + # lowers fully, so a repeated grad (e.g. second order wrt an xtensor) sees no residual + # conversion. Tensor inputs (the common case) are passed through unchanged. 
Only the + # outer inputs go through `memo`: boundary/leaf xtensor inputs are never rebuilt, so + # the inner graph (built from the original identities) is unaffected. + outer_inputs = [_to_tensor_world(memo.get(v, v)) for v in inputs] + dummies = [inp.type() for inp in outer_inputs] + inner = [_from_tensor_world(v, d) for v, d in zip(inputs, dummies)] + if inputs: + [region] = graph_replace([exit_var], dict(zip(inputs, inner)), strict=False) + [lowered] = cast( + list[Variable], + rewrite_graph([region], include=("lower_xtensor",), clone=False), + ) + else: + # Fully constant region: lower it in place, nothing to differentiate through. + [lowered] = cast( + list[Variable], + rewrite_graph([exit_var], include=("lower_xtensor",), clone=True), + ) + + if any( + var.owner is not None + and isinstance(var.owner.op, _EXIT_OPS) + and var.owner.inputs[0].owner is not None + for var in ancestors([lowered]) + ): + # lower_xtensor could not fully lower the region; wrapping it would leave an + # un-lowerable xtensor node inside the unit and recurse. Fail loudly instead. + raise NotImplementedError( + f"Cannot differentiate through xtensor region ending at {exit_var}: " + "lower_xtensor left an un-lowered conversion in it." + ) + + if not inputs: + return lowered + unit = OpFromGraph(dummies, [lowered], inline=True) + [new_exit] = unit(*outer_inputs, return_list=True) + return cast(Variable, new_exit) + + +@register_grad_graph_rewriter +def collapse_xtensor_grad_regions( + outputs: list[Variable], boundaries: list[Variable] +) -> list[Variable]: + """Collapse every xtensor region in ``outputs`` into a lowered OpFromGraph unit.""" + for out in outputs: + if _is_xtensor(out): + raise TypeError( + "Cannot differentiate an xtensor-typed variable directly: " + f"{out} (of type {out.type}) was passed as the cost or as a " + "known_grads key. Convert it to a tensor first, e.g. with " + "`cost.values`, and differentiate that instead." 
+ ) + outputs = list(outputs) + boundary_set = set(boundaries) + # wrt/consider_constant act by variable identity: an exit that a boundary IS, or + # that a boundary DEPENDS ON, must survive the collapse unchanged, or the rebuilt + # graph would no longer contain the boundary and grad could not reach (or stop at) + # it. Protect the boundaries' whole ancestor cone (`ancestors` is inclusive). The + # protected set is closed under ancestry, so nothing protected can ever depend on + # a collapsed exit -- protected variables are never rebuilt below. + protected = set(ancestors(boundaries)) + + def is_exit(var: Variable) -> bool: + # Only regions with a computed, non-protected xtensor value need collapsing. + # Converting a bare leaf/constant (e.g. TensorFromXTensor of an XTensorConstant) + # has nothing to lower away, and collapsing it would just re-wrap the same node + # and recurse forever. Their trivial boundary pullback handles them instead. + return ( + var.owner is not None + and isinstance(var.owner.op, _EXIT_OPS) + and var.owner.inputs[0].owner is not None + and var not in protected + and var.owner.inputs[0] not in protected + ) + + while True: + # One bottom-up pass over the graph: collapse every exit into its unit and + # re-clone each downstream node (once) onto the rebuilt variables. Toposort + # visits dependencies first, so an inner region is collapsed before any region + # or node consuming it -- `memo` carries the fresh identities forward. This + # keeps a chain of N regions at one pass instead of one graph rebuild each. 
+ memo: dict[Variable, Variable] = {} + for node in toposort(outputs): + if isinstance(node.op, _EXIT_OPS) and is_exit(node.outputs[0]): + exit_var = node.outputs[0] + memo[exit_var] = _collapse_region(exit_var, boundary_set, memo) + elif any(memo.get(inp, inp) is not inp for inp in node.inputs): + new_node = node.clone_with_new_inputs( + [memo.get(inp, inp) for inp in node.inputs] + ) + memo.update(zip(node.outputs, new_node.outputs)) + if not memo: + # `memo` only gains entries through an exit collapse (rebuilds require an + # already-changed input), so an empty memo means no exits were found. + return outputs + outputs = [memo.get(out, out) for out in outputs] + # Loop as insurance: collapsed units cannot contain further exits (the residual + # guard raises otherwise), so the next pass normally finds nothing and returns. + # (Differentiating a unit later re-enters `grad` and hence this hook on its + # pure-tensor inner graph — those separate, exit-free invocations return here + # immediately and are not extra passes of this loop.) 
diff --git a/tests/xtensor/test_grad.py b/tests/xtensor/test_grad.py new file mode 100644 index 0000000000..a8491571b1 --- /dev/null +++ b/tests/xtensor/test_grad.py @@ -0,0 +1,275 @@ +import pytest + + +pytest.importorskip("xarray") +pytestmark = pytest.mark.filterwarnings("error") + +import numpy as np + +import pytensor +import pytensor.tensor as pt +import pytensor.xtensor as px +from pytensor.gradient import pushforward +from pytensor.graph import rewrite_graph +from pytensor.xtensor.type import as_xtensor +from tests.unittest_tools import verify_grad + + +def grad_through_lowering(cost, wrt): + """Reference: lower the xtensor graph to tensor ops, then take the gradient.""" + cost = rewrite_graph(cost, include=("lower_xtensor",), clone=True) + return pt.grad(cost, wrt) + + +def _x(): + xt = pt.tensor("x", shape=(3, 4)) + return xt, as_xtensor(xt, dims=("a", "b")) + + +def _y(): + yt = pt.tensor("y", shape=(4, 2)) + return yt, as_xtensor(yt, dims=("b", "c")) + + +def build_cases(): + xt, x = _x() + yt, y = _y() + return [ + ("reduce_sum", (px.math.exp(x).sum("a") * 1.5).sum(), [xt]), + ("reduce_mean_std", (x.mean("a") + x.std("a")).sum(), [xt]), + ("reduce_max", (x.max("a") * 1.5).sum(), [xt]), + ("reduce_min", (x.min("a") * 1.5).sum(), [xt]), + ("cumsum", px.math.exp(x).cumsum("a").sum(), [xt]), + ("elemwise", (px.math.tanh(x) * px.math.sin(x)).sum(), [xt]), + ("transpose", (x.transpose("b", "a") ** 2).sum(), [xt]), + ("concat", px.concat([x, x + 1.0], dim="a").sum(), [xt]), + ("stack", px.math.exp(x).stack({"z": ("a", "b")}).sum(), [xt]), + ("rename", (x.rename({"a": "a2"}) ** 2).sum(), [xt]), + # Swapping names exercises Rename as a positional relabel (not a permutation). 
+ ("rename_swap", (x.rename({"a": "b", "b": "a"}).sum("a") ** 2).sum(), [xt]), + ("dot", (px.dot(x, y, dim="b") ** 2).sum(), [xt, yt]), + ] + + +@pytest.mark.parametrize( + "loss, wrt", + [pytest.param(loss, wrt, id=name) for name, loss, wrt in build_cases()], +) +def test_grad_matches_lowering(loss, wrt): + # pt.grad must work directly on the un-lowered xtensor graph and agree with the + # supported "lower first, then grad" path. + rng = np.random.default_rng(7) + test_vals = [rng.normal(size=w.type.shape).astype(w.type.dtype) for w in wrt] + g_direct = pt.grad(loss.values, wrt) + g_ref = grad_through_lowering(loss.values, wrt) + fn = pytensor.function(wrt, [*g_direct, *g_ref]) + out = fn(*test_vals) + n = len(wrt) + for direct, ref in zip(out[:n], out[n:]): + np.testing.assert_allclose(direct, ref) + + +def test_grad_repeated_input(): + # A repeated input must accumulate per-slot cotangents (no factor-of-N error). + xt = pt.vector("x", shape=(3,)) + x = as_xtensor(xt, dims=("a",)) + x_test = np.array([1.0, 2.0, 3.0]) + for power, loss in [(2, (x * x).sum()), (3, (x * x * x).sum())]: + g = pytensor.function([xt], pt.grad(loss.values, xt))(x_test) + np.testing.assert_allclose(g, power * x_test ** (power - 1)) + + +def test_grad_second_order(): + W = pytensor.shared(np.ones((3, 2)), name="W") + xt = pt.vector("x", shape=(3,)) + x = as_xtensor(xt, dims=("a",)) + y = px.dot(x, as_xtensor(W, dims=("a", "b")), dim="a") + loss = (y * y).sum() + g2 = pt.grad(pt.grad(loss.values, W).sum(), W) + g2_ref = pt.grad(grad_through_lowering(loss.values, W).sum(), W) + direct, ref = pytensor.function([xt], [g2, g2_ref])(np.arange(3.0)) + np.testing.assert_allclose(direct, ref) + + +def test_grad_through_indexing(): + # The index itself is non-differentiable (an integer xtensor) so it gets an + # undefined gradient, but the array input's gradient is still correct: a scatter + # of the cotangent into the indexed positions. 
+ xt = pt.tensor("x", shape=(3, 4)) + x = as_xtensor(xt, dims=("a", "b")) + loss = (x.isel(a=1) ** 2).sum() + grad = pt.grad(loss.values, xt) + x_test = np.arange(12.0).reshape(3, 4) + expected = np.zeros((3, 4)) + expected[1] = 2 * x_test[1] + np.testing.assert_allclose(pytensor.function([xt], grad)(x_test), expected) + + +def test_verify_grad(): + rng = np.random.default_rng(seed=420) + + def dot_loss(x, w): + xx = as_xtensor(x, dims=("a",)) + ww = as_xtensor(w, dims=("a", "b")) + return (px.dot(xx, ww, dim="a") ** 2).sum().values + + verify_grad(dot_loss, [rng.normal(size=(3,)), rng.normal(size=(3, 2))], rng=rng) + + +def test_grad_consider_constant(): + # consider_constant on a variable internal to the xtensor region must still stop the + # gradient there once the region is collapsed for differentiation. + xt = pt.tensor("x", shape=(3, 4)) + x = as_xtensor(xt, dims=("a", "b")) + inter = px.math.exp(x) + loss = ((inter * inter).sum() + (x * x).sum()).values + g = pytensor.function([xt], pt.grad(loss, xt, consider_constant=[inter])) + # exp(x) is held constant, so only the (x*x) path contributes: d/dx = 2x. + x_test = np.random.default_rng(0).normal(size=(3, 4)) + np.testing.assert_allclose(g(x_test), 2 * x_test) + + +def test_grad_second_order_xtensor_wrt(): + # Differentiating twice w.r.t. an xtensor variable: the collapsed unit is a plain + # tensor op, so the repeated grad lowers and differentiates it without a residual. + x = px.xtensor("x", dims=("a",), shape=(3,)) + loss = (x * x * x).sum() # d2/dx2 = 6x + g2 = pt.grad(pt.grad(loss.values, x).sum().values, x) + x_test = np.array([1.0, 2.0, 3.0]) + np.testing.assert_allclose(pytensor.function([x], g2.values)(x_test), 6 * x_test) + + +def test_grad_chained_regions(): + # Leaving and re-entering the xtensor world collapses into separate units and still + # matches the lower-then-grad reference (order of collapse must not form a cycle). 
+ xt = pt.tensor("x", shape=(3, 4)) + x = as_xtensor(xt, dims=("a", "b")) + reentered = as_xtensor(x.sum("a").values + 1.0, dims=("b",)) + loss = (reentered**2).sum().values + [direct] = pt.grad(loss, [xt]) + [ref] = grad_through_lowering(loss, [xt]) + d, r = pytensor.function([xt], [direct, ref])( + np.random.default_rng(0).normal(size=(3, 4)) + ) + np.testing.assert_allclose(d, r) + + +def test_grad_wrt_exit_variable(): + # wrt a region's tensor exit (`expr.values`): the boundary acts by identity, so the + # collapse must leave the exit (and everything it depends on) in place for grad to + # reach it; no xtensor op needs differentiating on the cost -> wrt path. + xt = pt.tensor("x", shape=(3, 4)) + x = as_xtensor(xt, dims=("a", "b")) + w = px.math.exp(x).values + g = pt.grad((w**2).sum(), w) + x_test = np.random.default_rng(1).normal(size=(3, 4)) + np.testing.assert_allclose(pytensor.function([xt], g)(x_test), 2 * np.exp(x_test)) + + +def test_grad_consider_constant_exit_variable(): + # consider_constant on a region's tensor exit must still stop the gradient there; + # rewriting the exit out of the graph silently dropped the stop. + xt = pt.tensor("x", shape=(3,)) + x = as_xtensor(xt, dims=("a",)) + w = px.math.exp(x).values + cost = (w * w).sum() + (xt**2).sum() + g = pt.grad(cost, xt, consider_constant=[w]) + x_test = np.array([0.5, 1.0, 1.5]) + np.testing.assert_allclose(pytensor.function([xt], g)(x_test), 2 * x_test) + + +def test_pushforward(): + # Forward mode double-pullbacks with a seed that references the exit; the exit is + # then a boundary ancestor and must survive the collapse (collapsing it silently + # disconnected the seed, returning 0). 
+ xt = pt.tensor("x", shape=(3,)) + x = as_xtensor(xt, dims=("a",)) + cost = (px.math.exp(x) * 2.0).sum().values + v = pt.tensor("v", shape=(3,)) + jvp = pushforward(cost, [xt], [v]) + out = pytensor.function([xt, v], jvp)(np.zeros(3), np.ones(3)) + np.testing.assert_allclose(out, 6.0) + + +def test_grad_unrelated_to_random_region(): + # A random draw elsewhere in the cost graph must not break grad wrt an unrelated + # variable: RNG-typed region inputs take rng conversions, not tensor conversions. + theta = pt.scalar("theta") + rng = px.random.shared_rng(seed=0) + _, draw = px.random.normal( + 0.0, 1.0, extra_dims={"a": 3}, rng=rng, return_next_rng=True + ) + cost = draw.sum().values * 0.0 + theta**2 + g = pt.grad(cost, theta) + np.testing.assert_allclose(g.eval({theta: 3.0}), 6.0) + + +def test_grad_through_random_region(): + # Reparameterized draw: grad flows through the deterministic use of the draw and + # matches the lower-then-grad reference, with a free xrng as function input. + rng = px.random.rng("rng") + at = pt.tensor("a", shape=(3,)) + a = as_xtensor(at, dims=("d",)) + _, eps = px.random.normal( + 0.0, 1.0, extra_dims={"d": 3}, rng=rng, return_next_rng=True + ) + loss = ((a * eps) ** 2).sum().values + g_direct = pt.grad(loss, at) + g_ref = grad_through_lowering(loss, at) + a_test = np.arange(1.0, 4.0) + d, r = pytensor.function([at, rng], [g_direct, g_ref])( + a_test, np.random.default_rng(3) + ) + np.testing.assert_allclose(d, r) + + +def test_grad_through_slice_indexing(): + # Slice components are structural (MakeSlice) and must stay inside the collapsed + # unit so the indexing lowering can pattern-match them, symbolic bounds included. 
+ at = pt.tensor("a", shape=(6, 5)) + xa = as_xtensor(at, dims=("i", "j")) + k = pt.iscalar("k") + loss = (xa.isel(i=slice(1, 4), j=slice(0, 2)) ** 2).sum().values + sym_loss = (xa.isel(i=slice(k, 4)) ** 2).sum().values + outs = [pt.grad(loss, at), *grad_through_lowering(loss, [at])] + outs += [pt.grad(sym_loss, at), *grad_through_lowering(sym_loss, [at])] + a_test = np.random.default_rng(2).normal(size=(6, 5)) + g, g_ref, gs, gs_ref = pytensor.function([at, k], outs)(a_test, 1) + np.testing.assert_allclose(g, g_ref) + np.testing.assert_allclose(gs, gs_ref) + + +def test_grad_xtensor_cost_raises(): + # A raw xtensor cost or known_grads key gets a clear error, not a crash deep in + # the grad internals. + xt = pt.tensor("x", shape=(3,)) + x = as_xtensor(xt, dims=("a",)) + with pytest.raises(TypeError, match="Convert it to a tensor first"): + pt.grad((x**2).sum(), xt) + with pytest.raises(TypeError, match="Convert it to a tensor first"): + pt.grad(None, xt, known_grads={px.math.exp(x): px.math.exp(x)}) + + +def test_grad_exit_diamond(): + # One collapsed exit feeding several consumers -- the cost directly, a re-entered + # region, and a node using it twice -- must be rebuilt once and accumulate all paths. + xt = pt.tensor("x", shape=(3,)) + x = as_xtensor(xt, dims=("a",)) + e = px.math.exp(x).sum("a").values + reentered = (as_xtensor(e * 2.0, dims=()) ** 2).sum().values + loss = e + reentered + e * e + [direct] = pt.grad(loss, [xt]) + [ref] = grad_through_lowering(loss, [xt]) + d, r = pytensor.function([xt], [direct, ref])(np.arange(3.0)) + np.testing.assert_allclose(d, r) + + +def test_grad_known_grads_exit_key(): + # A tensor exit used as a known_grads key is rewritten consistently with the cost + # graph, so the supplied cotangent lands on the collapsed unit's output. 
+ xt = pt.tensor("x", shape=(3,)) + x = as_xtensor(xt, dims=("a",)) + w = px.math.exp(x).sum("a").values + g = pt.grad(None, xt, known_grads={w: pt.constant(2.0)}) + x_test = np.arange(3.0) + np.testing.assert_allclose(pytensor.function([xt], g)(x_test), 2.0 * np.exp(x_test))