Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 50 additions & 57 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import (
alloc,
arange,
as_tensor_variable,
cast,
concatenate,
Expand Down Expand Up @@ -400,30 +399,36 @@ def _c_all(self, node, name, input_names, output_names, sub):
return setup, alloc, loop, cast


class Max(NonZeroDimsCAReduce):
nfunc_spec = ("max", 1, 1)
class MaxAndMinCAReduce(NonZeroDimsCAReduce):
"""Base class for the :class:`Max` and :class:`Min` reduction ``Op``\\s.

def __init__(self, axis):
super().__init__(ps.maximum, axis)
A maximum and a minimum reduction differ only in *which* element along the
reduced axes is selected; *how* derivatives propagate through that
selection is identical. In both cases the (weak) derivative routes through
the position(s) of the input that attain the reduced output. Keeping the
differentiation rules in a single place therefore avoids duplication and
guarantees that :class:`Max` and :class:`Min` stay consistent.

Subclasses only need to bind the appropriate scalar ``Op`` (``maximum`` or
``minimum``) in their ``__init__`` and set ``nfunc_spec``.
"""

def clone(self, **kwargs):
axis = kwargs.get("axis", self.axis)
return type(self)(axis=axis)

def pullback(self, inputs, outputs, output_grads):
# The strict sense mathematical gradient of the maximum function is
# not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null
# Lebesgue measure, the result may be interpreted as weak gradient.

# @note: This function should work correctly for L{vector}s.
# (x, y), (gz, gw)
# gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
# gMax * dMax/dx + gArgMax * dArgMax/dx,
# gMax * dMax/daxis + gArgMax * dArgMax/daxis
# g_max has one less dimension than x, so you need to complete
# g_max to x's shape when axis=0 the broadcasting mechanism
# does it automatically
# The strict-sense mathematical gradient of a maximum/minimum reduction
# is not defined at points where the extremum is attained by more than
# one coordinate. However, since that set has null Lebesgue measure, the
# result below may be interpreted as a weak gradient: the cotangent is
# routed to *every* position that attains the extremum (i.e. where the
# input equals the reduced output). This rule is identical for `Max` and
# `Min`, which is why it lives on the shared base class.
#
# `out`/`g_out` have one fewer dimension than `x` along each reduced
# axis, so we re-insert those axes (`expand_dims`) to let broadcasting
# spread the cotangent back over `x`'s shape.
[x] = inputs
[out] = outputs
[g_out] = output_grads
Expand All @@ -436,54 +441,42 @@ def pullback(self, inputs, outputs, output_grads):
g_x = eq(out_pad, x) * g_out_pad
return (g_x,)

def pushforward(self, inputs, outputs, eval_points):
def pushforward(self, inputs, outputs, tangents):
# Forward-mode is the exact transpose of `pullback`: the output tangent
# is gathered from the input tangent at the extremum position(s) and
# summed over the reduced axes. For a unique extremum (the almost-
# everywhere case) this simply selects the tangent of the winning
# element; on the zero-measure tie set it sums their tangents, which is
# precisely what makes it the transpose of `pullback` (and keeps
# forward- and reverse-mode consistent). Working for any number of
# reduced axes and input dimensions, this also generalises the previous
# matrix-only implementation.
[x] = inputs
if isinstance(eval_points[0].type, DisconnectedType):
[out] = outputs
[x_dot] = tangents

if isinstance(x_dot.type, DisconnectedType):
return [disconnected_type()]
axis = tuple(range(x.ndim) if self.axis is None else self.axis)
if isinstance(axis, int):
axis = [axis]
if len(axis) != 1:
raise NotImplementedError("R_op supported for max only for one axis!")
if axis[0] > 1:
raise NotImplementedError("R_op supported for max only when axis is 0 or 1")
if inputs[0].ndim != 2:
raise NotImplementedError(
"R_op supported for max only when input is a matrix"
)
max_pos = Argmax(self.axis)(*inputs)
if self.axis[0] == 0:
return [eval_points[0][max_pos, arange(eval_points[0].shape[1])]]
else:
return [eval_points[0][arange(eval_points[0].shape[0]), max_pos]]

axis = tuple(range(x.ndim)) if self.axis is None else self.axis
out_pad = expand_dims(out, axis)

class Min(NonZeroDimsCAReduce):
nfunc_spec = ("min", 1, 1)
out_dot = (eq(out_pad, x) * x_dot).sum(axis=axis)
return [out_dot]

def __init__(self, axis):
super().__init__(ps.minimum, axis)

def clone(self, **kwargs):
axis = kwargs.get("axis", self.axis)
return type(self)(axis=axis)
class Max(MaxAndMinCAReduce):
nfunc_spec = ("max", 1, 1)

def pullback(self, inputs, outputs, output_grads):
# The strict sense mathematical gradient of the minimum function is
# not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null
# Lebesgue measure, the result may be interpreted as weak gradient.
[x] = inputs
[out] = outputs
[g_out] = output_grads
def __init__(self, axis):
super().__init__(ps.maximum, axis)

axis = tuple(range(x.ndim)) if self.axis is None else self.axis
out_pad = expand_dims(out, axis)
g_out_pad = expand_dims(g_out, axis)

# Set the grad to the correct position.
g_x = eq(out_pad, x) * g_out_pad
return (g_x,)
class Min(MaxAndMinCAReduce):
nfunc_spec = ("min", 1, 1)

def __init__(self, axis):
super().__init__(ps.minimum, axis)


def max(x, axis=None, keepdims=False):
Expand Down
55 changes: 54 additions & 1 deletion tests/test_rop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.tensor.math import argmax, dot
from pytensor.tensor.math import Max, Min, argmax, dot
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import matrix, vector
from tests import unittest_tools as utt
Expand Down Expand Up @@ -217,6 +217,59 @@ def test_max(self):
pt_max(self.mx, axis=1), (self.mat_in_shape[0],)
)

def test_min(self):
# `Min` is the bare reduction Op. Unlike `pt.min`, which lowers to
# ``-max(-x)`` and never instantiates `Min`, this exercises the Op's
# own pushforward/pullback. It must behave exactly like `Max`.
self.check_mat_pushforward_pullback(
Min(axis=[0])(self.mx), (self.mat_in_shape[1],)
)
self.check_mat_pushforward_pullback(
Min(axis=[1])(self.mx), (self.mat_in_shape[0],)
)

def test_max_min_pushforward_on_ties(self):
# With tied extrema, reverse-mode (`pullback`) routes the cotangent to
# *every* position attaining the extremum. For forward-mode to remain
# the exact transpose (adjoint) of reverse-mode, `pushforward` must
# *sum* the tied input tangents. The previous matrix-only, `Argmax`-
# based implementation instead picked a single winner and was therefore
# NOT the adjoint of its own pullback on ties. Random data is tie-free
# almost surely (so `test_max`/`test_min` would not catch a regression
# to single-winner semantics); this checks the tied case explicitly for
# both `Max` and `Min`.
mx = matrix("mx")
mv = matrix("mv")
# Columns 0 and 2 are constant -> 3-way ties along axis 0; column 1 has
# a unique max (row 1) and a unique min (row 0).
x_val = np.array(
[[1.0, 2.0, 1.0], [1.0, 5.0, 1.0], [1.0, 4.0, 1.0]],
dtype=config.floatX,
)
v_val = np.array(
[[1.0, 10.0, 100.0], [2.0, 20.0, 200.0], [4.0, 30.0, 300.0]],
dtype=config.floatX,
)

for op_cls in (Max, Min):
for axis in ([0], [1]):
y = op_cls(axis=axis)(mx)
# The Op's own `pushforward` must equal the pushforward obtained
# by transposing the `pullback` (`use_op_pushforward=False`).
yv_op = pushforward(y, mx, mv, use_op_pushforward=True)
yv_ref = pushforward(y, mx, mv, use_op_pushforward=False)
f_op = function([mx, mv], yv_op, on_unused_input="ignore")
f_ref = function([mx, mv], yv_ref, on_unused_input="ignore")
np.testing.assert_allclose(f_op(x_val, v_val), f_ref(x_val, v_val))

# Pin the tie convention explicitly: along axis 0 the tied columns sum
# their tangents (1+2+4=7 and 100+200+300=600); the unique-extremum
# column selects the winning row (max -> row 1 = 20, min -> row 0 = 10).
f_max0 = function([mx, mv], pushforward(Max(axis=[0])(mx), mx, mv))
f_min0 = function([mx, mv], pushforward(Min(axis=[0])(mx), mx, mv))
np.testing.assert_allclose(f_max0(x_val, v_val), [7.0, 20.0, 600.0])
np.testing.assert_allclose(f_min0(x_val, v_val), [7.0, 10.0, 600.0])

def test_argmax(self):
self.check_nondiff_pushforward(argmax(self.mx, axis=1), self.mx, self.mv)

Expand Down
Loading