
BUG: NaN gradient (VJP) from pt.log(pt.sigmoid(x)) for large x in JAX backend #2264

Description

@jdeast

Describe the issue:

Summary

pt.log(pt.sigmoid(x)) produces a correct forward value but a NaN gradient
(reverse-mode / VJP) when x is large and positive (x > ~709.78 in float64,
where exp(x) overflows; the examples below use x = 820). The forward-mode
directional derivative (JVP) is correct, and pure jax.nn.log_sigmoid gives
the correct gradient at the same value. The bug is in how PyTensor's JAX
backend compiles the Sigmoid + Log ops.

Root cause

jax.make_jaxpr shows that PyTensor compiles log(sigmoid(h)) as -softplus(-h),
with softplus expanded into a four-branch piecewise formula via nested
select_n (i.e. jnp.where) calls. One branch, intended for the range
18 ≤ -h < 33.3, computes:

q = exp(-j)   where j = neg(h)

When h = 820, j = -820 and -j = 820, so q = exp(820) = +inf.
This branch is never selected (because j = -820 satisfies j < -37,
selecting the first branch instead), so the forward pass is correct.

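However, in reverse mode JAX differentiates every select_n branch regardless
of the condition: each unselected branch receives a zero cotangent, which is
then multiplied by that branch's local derivatives. The VJP of exp at 820 is
g * exp(820) = 0 * inf = NaN (IEEE 754: 0 × ∞ is not-a-number). That NaN
propagates back to the input.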
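The 0 * inf step can be reproduced in isolation with jax.vjp. This is a minimal sketch, assuming float64 is enabled as in the reproducer: it pulls the zero cotangent that an unselected select_n branch receives back through exp at 820.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Primal: exp(820) overflows float64 to inf.
q, exp_vjp = jax.vjp(jnp.exp, jnp.array(820.0))

# An unselected select_n branch receives a cotangent of exactly 0.
(g_arg,) = exp_vjp(jnp.array(0.0))

print(q)      # inf
print(g_arg)  # nan, because exp's VJP computes cotangent * exp(820) = 0 * inf
```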

The full piecewise jaxpr, annotated for x = 0 (versions 2.31.7 and 3.0.7 produce the same):

j  = neg(820 + x)          # j = -820
k  = exp(j)                # branch 1: exp(-820) = 0          [selected]
n  = log1p(exp(j))         # branch 2: log1p(0) = 0
q  = exp(neg(j))           # branch 3: exp(820) = inf  ← overflows!
r  = j + q                 # branch 3: -820 + inf = inf
s  = select(j<33.3, r, j)  # selects r  (j<33.3 is True)
t  = select(j<18,   n, s)  # selects n  (j<18   is True)
u  = select(j<-37,  k, t)  # selects k  (j<-37  is True)
out = neg(u)               # -0.0  ✓

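# Backward: gradients of all branches are computed (p = neg(j) = 820, q = exp(p)):
g_q = 0                                            (q is not in the selected path)
g_p = g_q * exp(p) = 0 * exp(820) = 0 * inf = NaN  ← bug
g_j = (other branches) - g_p = NaN                 (summed at j, then passed back to x)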
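The structure in the jaxpr can be re-created in pure JAX to confirm that the nested selects alone trigger the bug, without PyTensor in the loop. The function names below are hypothetical and the branches simply mirror the jaxpr (thresholds -37, 18 and 33.3); this is a sketch, not PyTensor's source code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def piecewise_softplus(j):
    # Hypothetical mirror of the nested select_n chain in the jaxpr;
    # not PyTensor's actual source code.
    branch1 = jnp.exp(j)             # selected when j < -37
    branch2 = jnp.log1p(jnp.exp(j))  # selected when -37 <= j < 18
    branch3 = j + jnp.exp(-j)        # selected when 18 <= j < 33.3; inf for j < -709.78
    branch4 = j                      # selected when j >= 33.3
    s = jnp.where(j < 33.3, branch3, branch4)
    t = jnp.where(j < 18.0, branch2, s)
    return jnp.where(j < -37.0, branch1, t)


def log_sigmoid_piecewise(h):
    # log(sigmoid(h)) = -softplus(-h), matching j = neg(h) and out = neg(u)
    return -piecewise_softplus(-h)


grad_fn = jax.grad(log_sigmoid_piecewise)
for h in (700.0, 710.0, 820.0):
    x = jnp.array(h)
    print(h, float(log_sigmoid_piecewise(x)), float(grad_fn(x)))
# The gradient is finite at 700 and nan at 710 and 820.
```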

Why forward-mode (JVP) is unaffected

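JVP pushes tangents forward alongside primals. The unselected branch's
tangent is still computed and can itself be non-finite (the tangent of
q = exp(-j) is -t_j * exp(820), which is not finite), but the select_n
tangent rule is itself a select: it copies the selected branch's tangent and
discards the others without doing any arithmetic on them, so the inf never
reaches the output. The VJP rule runs in the opposite direction: select_n's
transpose sends an explicit zero cotangent into each unselected branch, and
the exp VJP then multiplies that zero by exp(820) = inf, giving
0 * inf = NaN. Because cotangents from all branches are summed at j, the NaN
reaches the input.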
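The asymmetry shows up on a minimal two-branch analogue (a hypothetical function that keeps only branch 1 and branch 3 of the jaxpr), again assuming float64. With a unit tangent, forward mode returns the correct derivative while reverse mode returns NaN at the same point:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def two_branch(j):
    # Hypothetical analogue: branch 1 is selected for j < -37, branch 3 otherwise.
    return jnp.where(j < -37.0, jnp.exp(j), j + jnp.exp(-j))


j0 = jnp.array(-820.0)

# Forward mode: the select just picks the selected branch's tangent (exp(-820) = 0);
# the unselected branch's tangent (1 - inf) is computed but discarded.
_, tangent_out = jax.jvp(two_branch, (j0,), (jnp.array(1.0),))

# Reverse mode: the unselected branch gets a zero cotangent, then 0 * exp(820).
cotangent_out = jax.grad(two_branch)(j0)

print(tangent_out)    # 0.0, the true derivative exp(-820) underflows to 0
print(cotangent_out)  # nan
```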

Why this is a PyTensor bug (not a fundamental JAX limitation)

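jax.nn.log_sigmoid handles the same input correctly because its derivative
comes from a numerically stable rule rather than from differentiating a
select over overflowing branches (jax.nn.log_sigmoid is -softplus(-x), and
jax.nn.softplus is built on jnp.logaddexp, whose derivative JAX defines
explicitly). PyTensor could fix this by dispatching:

Sigmoid  →  jax.nn.sigmoid     (wraps the lax.logistic primitive, stable derivative)
Softplus →  jax.nn.softplus    (built on jnp.logaddexp, stable derivative)

in its JAX Elemwise dispatch table, rather than inlining them as raw
arithmetic ops. The same NaN-in-unselected-branch problem can affect any
PyTensor piecewise formula lowered to select_n whose unselected branches
have a non-finite derivative at the evaluation point (for example, exp of
a large positive argument).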
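As a sketch of what the proposed dispatch would change (assuming Softplus maps to jax.nn.softplus, so that log(sigmoid(h)) becomes -jax.nn.softplus(-h); the function name is hypothetical), the gradient stays finite at the same inputs and agrees with jax.nn.log_sigmoid and with the analytic derivative sigmoid(-h):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def log_sigmoid_dispatched(h):
    # Hypothetical: -softplus(-h) with Softplus lowered to jax.nn.softplus
    return -jax.nn.softplus(-h)


for h in (0.0, 5.0, 710.0, 820.0):
    x = jnp.array(h)
    g = jax.grad(log_sigmoid_dispatched)(x)      # dispatched version
    g_ref = jax.grad(jax.nn.log_sigmoid)(x)      # JAX's own log_sigmoid
    print(h, float(g), float(g_ref), float(jax.nn.sigmoid(-x)))  # analytic: sigmoid(-h)
```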

Workaround (user-side)

Clip the argument before passing it to log(sigmoid(.)):

# exp(700) ~ 1e304, finite in float64; avoids overflow in all VJP branches
arg_safe = pt.minimum(arg, 700.0)
pt.log(pt.sigmoid(arg_safe))

This is numerically safe: for arg ≥ 700 the true value (≈ -exp(-arg)) and
gradient (≈ exp(-arg)) are both below 1e-304 in magnitude, so capping at 700
changes nothing observable while preventing the overflow.
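The clipping workaround can be checked on a pure-JAX mirror of the piecewise formula. This is a hypothetical stand-in for the PyTensor graph, with jnp.minimum playing the role of pt.minimum; after clipping, the gradient at 820 is finite:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def piecewise_softplus(j):
    # Hypothetical mirror of the four-branch select_n chain in the jaxpr
    return jnp.where(j < -37.0, jnp.exp(j),
                     jnp.where(j < 18.0, jnp.log1p(jnp.exp(j)),
                               jnp.where(j < 33.3, j + jnp.exp(-j), j)))


def log_sigmoid_clipped(h):
    h_safe = jnp.minimum(h, 700.0)  # same cap as pt.minimum(arg, 700.0)
    return -piecewise_softplus(-h_safe)


h = jnp.array(820.0)
print(float(log_sigmoid_clipped(h)))            # about -1e-304, effectively 0
print(float(jax.grad(log_sigmoid_clipped)(h)))  # finite (zero) instead of nan
```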

Expected behavior

jax.grad of any PyTensor function that is smooth and mathematically
well-defined at the evaluation point should return a finite gradient,
matching the result of jax.jvp (forward-mode) and finite differences.
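One way to turn this expectation into a check is to compare jax.grad against a central finite difference and require the gradient to be finite. The helper below is a hypothetical sketch (the step size and tolerance are arbitrary choices). It is shown here on jax.nn.log_sigmoid, which should pass at every point including 820:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def grad_matches_fd(f, x, eps=1e-6, atol=1e-6):
    """Hypothetical check: reverse-mode gradient is finite and matches a central difference."""
    x = jnp.asarray(x, dtype=jnp.float64)
    g = jax.grad(f)(x)
    fd = (f(x + eps) - f(x - eps)) / (2.0 * eps)
    return bool(jnp.isfinite(g)) and bool(jnp.abs(g - fd) <= atol)


for x0 in (-5.0, 0.0, 5.0, 820.0):
    print(x0, grad_matches_fd(jax.nn.log_sigmoid, x0))
```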

Note: this report was largely generated by Claude

Reproducible code example:

import jax, jax.numpy as jnp, pytensor, pytensor.tensor as pt
jax.config.update("jax_enable_x64", True)

x  = pt.dscalar("x")
fn = pytensor.function([x], pt.log(pt.sigmoid(820.0 + x)), mode="JAX")
jax_fn = fn.vm.jit_fn        # returns a tuple

def f(v):
    return jax_fn(v)[0]      # unwrap to scalar

v0 = jnp.array(0.0)
t0 = jnp.array(1.0)                              # unit tangent; a zero tangent makes the JVP trivially 0
print("forward :", float(f(v0)))                 # -0.0 (correct)
print("grad    :", float(jax.grad(f)(v0)))       # nan (should be 0.0)
print("jvp     :", jax.jvp(f, (v0,), (t0,))[1])  # 0.0 (correct)

# JAX's own implementation is correct:
print("jax ref :", float(jax.grad(jax.nn.log_sigmoid)(jnp.array(820.0))))  # 0.0 (correct)

Error message:

PyTensor version information:

Tested versions: 2.31.7, 3.0.7 (latest as of 2026-06-27)

Context for the issue:

Switching the NUTS sampler to NumPyro should be a one-line change, but it immediately crashed in my application. The user-side workaround is easy (clip inputs), so it's not urgent on my end, but the bug was not trivial to track down and is likely to trip up other users.

Labels: bug