diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 8291663c27..bc38f4391e 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -21,7 +21,6 @@ from pytensor.tensor import TensorLike from pytensor.tensor.basic import ( alloc, - arange, as_tensor_variable, cast, concatenate, @@ -400,30 +399,36 @@ def _c_all(self, node, name, input_names, output_names, sub): return setup, alloc, loop, cast -class Max(NonZeroDimsCAReduce): - nfunc_spec = ("max", 1, 1) +class MaxAndMinCAReduce(NonZeroDimsCAReduce): + """Base class for the :class:`Max` and :class:`Min` reduction ``Op``\\s. - def __init__(self, axis): - super().__init__(ps.maximum, axis) + A maximum and a minimum reduction differ only in *which* element along the + reduced axes is selected; *how* derivatives propagate through that + selection is identical. In both cases the (weak) derivative routes through + the position(s) of the input that attain the reduced output. Keeping the + differentiation rules in a single place therefore avoids duplication and + guarantees that :class:`Max` and :class:`Min` stay consistent. + + Subclasses only need to bind the appropriate scalar ``Op`` (``maximum`` or + ``minimum``) in their ``__init__`` and set ``nfunc_spec``. + """ def clone(self, **kwargs): axis = kwargs.get("axis", self.axis) return type(self)(axis=axis) def pullback(self, inputs, outputs, output_grads): - # The strict sense mathematical gradient of the maximum function is - # not calculated here for it is not defined at every point where some - # coordinates are identical. However, since the latter set has null - # Lebesgue measure, the result may be interpreted as weak gradient. - - # @note: This function should work correctly for L{vector}s. - # (x, y), (gz, gw) - # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy - # gMax * dMax/dx + gArgMax * dArgMax/dx, - # gMax * dMax/daxis + gArgMax * dArgMax/daxis - # g_max has one less dimension than x, so you need to complete - # g_max to x's shape when axis=0 the broadcasting mechanism - # does it automatically + # The strict-sense mathematical gradient of a maximum/minimum reduction + # is not defined at points where the extremum is attained by more than + # one coordinate. However, since that set has null Lebesgue measure, the + # result below may be interpreted as a weak gradient: the cotangent is + # routed to *every* position that attains the extremum (i.e. where the + # input equals the reduced output). This rule is identical for `Max` and + # `Min`, which is why it lives on the shared base class. + # + # `out`/`g_out` have one fewer dimension than `x` along each reduced + # axis, so we re-insert those axes (`expand_dims`) to let broadcasting + # spread the cotangent back over `x`'s shape. [x] = inputs [out] = outputs [g_out] = output_grads @@ -436,54 +441,42 @@ def pullback(self, inputs, outputs, output_grads): g_x = eq(out_pad, x) * g_out_pad return (g_x,) - def pushforward(self, inputs, outputs, eval_points): + def pushforward(self, inputs, outputs, tangents): + # Forward-mode is the exact transpose of `pullback`: the output tangent + # is gathered from the input tangent at the extremum position(s) and + # summed over the reduced axes. For a unique extremum (the almost- + # everywhere case) this simply selects the tangent of the winning + # element; on the zero-measure tie set it sums their tangents, which is + # precisely what makes it the transpose of `pullback` (and keeps + # forward- and reverse-mode consistent). Working for any number of + # reduced axes and input dimensions, this also generalises the previous + # matrix-only implementation. [x] = inputs - if isinstance(eval_points[0].type, DisconnectedType): + [out] = outputs + [x_dot] = tangents + + if isinstance(x_dot.type, DisconnectedType): return [disconnected_type()] - axis = tuple(range(x.ndim) if self.axis is None else self.axis) - if isinstance(axis, int): - axis = [axis] - if len(axis) != 1: - raise NotImplementedError("R_op supported for max only for one axis!") - if axis[0] > 1: - raise NotImplementedError("R_op supported for max only when axis is 0 or 1") - if inputs[0].ndim != 2: - raise NotImplementedError( - "R_op supported for max only when input is a matrix" - ) - max_pos = Argmax(self.axis)(*inputs) - if self.axis[0] == 0: - return [eval_points[0][max_pos, arange(eval_points[0].shape[1])]] - else: - return [eval_points[0][arange(eval_points[0].shape[0]), max_pos]] + axis = tuple(range(x.ndim)) if self.axis is None else self.axis + out_pad = expand_dims(out, axis) -class Min(NonZeroDimsCAReduce): - nfunc_spec = ("min", 1, 1) + out_dot = (eq(out_pad, x) * x_dot).sum(axis=axis) + return [out_dot] - def __init__(self, axis): - super().__init__(ps.minimum, axis) - def clone(self, **kwargs): - axis = kwargs.get("axis", self.axis) - return type(self)(axis=axis) +class Max(MaxAndMinCAReduce): + nfunc_spec = ("max", 1, 1) - def pullback(self, inputs, outputs, output_grads): - # The strict sense mathematical gradient of the minimum function is - # not calculated here for it is not defined at every point where some - # coordinates are identical. However, since the latter set has null - # Lebesgue measure, the result may be interpreted as weak gradient. - [x] = inputs - [out] = outputs - [g_out] = output_grads + def __init__(self, axis): + super().__init__(ps.maximum, axis) - axis = tuple(range(x.ndim)) if self.axis is None else self.axis - out_pad = expand_dims(out, axis) - g_out_pad = expand_dims(g_out, axis) - # Set the grad to the correct position. - g_x = eq(out_pad, x) * g_out_pad - return (g_x,) +class Min(MaxAndMinCAReduce): + nfunc_spec = ("min", 1, 1) + + def __init__(self, axis): + super().__init__(ps.minimum, axis) def max(x, axis=None, keepdims=False): diff --git a/tests/test_rop.py b/tests/test_rop.py index a819daae8f..17602a75e4 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -25,7 +25,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.graph.replace import clone_replace -from pytensor.tensor.math import argmax, dot +from pytensor.tensor.math import Max, Min, argmax, dot from pytensor.tensor.math import max as pt_max from pytensor.tensor.type import matrix, vector from tests import unittest_tools as utt @@ -217,6 +217,59 @@ def test_max(self): pt_max(self.mx, axis=1), (self.mat_in_shape[0],) ) + def test_min(self): + # `Min` is the bare reduction Op. Unlike `pt.min`, which lowers to + # ``-max(-x)`` and never instantiates `Min`, this exercises the Op's + # own pushforward/pullback. It must behave exactly like `Max`. + self.check_mat_pushforward_pullback( + Min(axis=[0])(self.mx), (self.mat_in_shape[1],) + ) + self.check_mat_pushforward_pullback( + Min(axis=[1])(self.mx), (self.mat_in_shape[0],) + ) + + def test_max_min_pushforward_on_ties(self): + # With tied extrema, reverse-mode (`pullback`) routes the cotangent to + # *every* position attaining the extremum. For forward-mode to remain + # the exact transpose (adjoint) of reverse-mode, `pushforward` must + # *sum* the tied input tangents. The previous matrix-only, `Argmax`- + # based implementation instead picked a single winner and was therefore + # NOT the adjoint of its own pullback on ties. Random data is tie-free + # almost surely (so `test_max`/`test_min` would not catch a regression + # to single-winner semantics); this checks the tied case explicitly for + # both `Max` and `Min`. + mx = matrix("mx") + mv = matrix("mv") + # Columns 0 and 2 are constant -> 3-way ties along axis 0; column 1 has + # a unique max (row 1) and a unique min (row 0). + x_val = np.array( + [[1.0, 2.0, 1.0], [1.0, 5.0, 1.0], [1.0, 4.0, 1.0]], + dtype=config.floatX, + ) + v_val = np.array( + [[1.0, 10.0, 100.0], [2.0, 20.0, 200.0], [4.0, 30.0, 300.0]], + dtype=config.floatX, + ) + + for op_cls in (Max, Min): + for axis in ([0], [1]): + y = op_cls(axis=axis)(mx) + # The Op's own `pushforward` must equal the pushforward obtained + # by transposing the `pullback` (`use_op_pushforward=False`). + yv_op = pushforward(y, mx, mv, use_op_pushforward=True) + yv_ref = pushforward(y, mx, mv, use_op_pushforward=False) + f_op = function([mx, mv], yv_op, on_unused_input="ignore") + f_ref = function([mx, mv], yv_ref, on_unused_input="ignore") + np.testing.assert_allclose(f_op(x_val, v_val), f_ref(x_val, v_val)) + + # Pin the tie convention explicitly: along axis 0 the tied columns sum + # their tangents (1+2+4=7 and 100+200+300=600); the unique-extremum + # column selects the winning row (max -> row 1 = 20, min -> row 0 = 10). + f_max0 = function([mx, mv], pushforward(Max(axis=[0])(mx), mx, mv)) + f_min0 = function([mx, mv], pushforward(Min(axis=[0])(mx), mx, mv)) + np.testing.assert_allclose(f_max0(x_val, v_val), [7.0, 20.0, 600.0]) + np.testing.assert_allclose(f_min0(x_val, v_val), [7.0, 10.0, 600.0]) + def test_argmax(self): self.check_nondiff_pushforward(argmax(self.mx, axis=1), self.mx, self.mv)