Describe the issue:
Summary
pt.log(pt.sigmoid(x)) produces a correct forward value but a NaN gradient
(reverse-mode / VJP) when x is large and positive (beyond about 710 in float64,
where exp(x) overflows; the example below uses 820). The forward-mode
directional derivative (JVP) is correct, and pure jax.nn.log_sigmoid gives the
correct gradient at the same value. The bug is in how PyTensor's JAX backend
compiles the Sigmoid + Log ops.
Root cause
jax.make_jaxpr shows that PyTensor compiles log(sigmoid(h)) as -softplus(-h),
with softplus expanded into a four-branch piecewise function via nested
select_n (i.e. jnp.where) calls.
One branch, intended for the range 18 ≤ -h < 33.3, computes:
q = exp(-j) where j = neg(h)
When h = 820, j = -820 and -j = 820, so q = exp(820) = +inf.
This branch is never selected (j = -820 satisfies j < -37, so the first
branch is chosen instead), so the forward pass is correct.
However, reverse-mode differentiation runs the VJP of every select_n branch
regardless of the condition; unselected branches simply receive a zero
cotangent. The VJP of exp(820) is then g * exp(820) = 0 * inf = NaN
(IEEE 754: 0 × ∞ is not-a-number). Because every branch reads the same j,
their cotangents are summed there, and the NaN propagates back to the input.
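The failure can be reproduced without PyTensor. The sketch below is a hypothetical pure-JAX replica of the four-branch softplus, assuming the thresholds (-37, 18, 33.3) and branch formulas read off the make_jaxpr dump; it is not PyTensor's source, and the names piecewise_softplus and log_sigmoid_piecewise are illustrative only.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def piecewise_softplus(x):
    # Hypothetical replica of the four-branch softplus in the JAXpr
    # (thresholds and formulas assumed from the make_jaxpr output).
    return jnp.where(
        x < -37.0,
        jnp.exp(x),                                   # branch 1
        jnp.where(
            x < 18.0,
            jnp.log1p(jnp.exp(x)),                    # branch 2
            jnp.where(x < 33.3, x + jnp.exp(-x), x),  # branch 3, then identity
        ),
    )


def log_sigmoid_piecewise(h):
    # log(sigmoid(h)) == -softplus(-h), so j = -h in the JAXpr's notation.
    return -piecewise_softplus(-h)


h = jnp.array(820.0)
value = log_sigmoid_piecewise(h)             # forward pass picks branch 1: finite
grad = jax.grad(log_sigmoid_piecewise)(h)    # branch 3 overflows in the VJP: nan
reference = jax.grad(jax.nn.log_sigmoid)(h)  # JAX's own log_sigmoid: finite
print(value, grad, reference)
```

If the replica matches the compiled graph, it should show the same split as the reproducer: a correct value, a NaN reverse-mode gradient, and a finite reference gradient.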
The full piecewise JAXpr, annotated (both versions 2.31.7 and 3.0.7 produce the same):
```
j = neg(820 + x)          # j = -820
k = exp(j)                # branch 1: exp(-820) = 0       [selected]
n = log1p(exp(j))         # branch 2: log1p(0) = 0
p = neg(j)                # p = 820
q = exp(p)                # branch 3: exp(820) = inf      ← overflows!
r = j + q                 # branch 3: -820 + inf = inf
s = select(j<33.3, r, j)  # selects r (j<33.3 is True)
t = select(j<18, n, s)    # selects n (j<18 is True)
u = select(j<-37, k, t)   # selects k (j<-37 is True)
out = neg(u)              # -0.0 ✓

# Backward: the VJPs of all branches are computed
g_q = g_r = 0                         # q is not on the selected path
g_p = g_q * exp(p) = 0 * inf = NaN    ← bug
g_j = (other branches) - g_p = NaN    # cotangents are summed at the shared input j
```
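The single step marked "← bug" can be checked in isolation. This minimal sketch feeds the zero cotangent that an unselected select_n branch receives into the VJP of jnp.exp at p = 820; the variable names mirror the listing and are otherwise arbitrary.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Branch 3's exponential in isolation: p = -j = 820.
p = jnp.array(820.0)
q, exp_vjp = jax.vjp(jnp.exp, p)  # primal q = exp(820) overflows to inf

# select_n passes a zero cotangent to the unselected branch; the VJP of exp
# still multiplies it by exp(p), so the cotangent for p is 0 * inf.
(g_p,) = exp_vjp(jnp.zeros_like(q))
print(q, g_p)
```

Any p below the overflow point would give a finite q and therefore a harmless g_p = 0.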
Why forward-mode (JVP) is unaffected
JVP pushes tangents forward alongside primals, and the select_n tangent rule
simply picks the tangent of the selected branch. The unselected branch does
compute a non-finite tangent (the jvp of exp(z) is t_z * exp(z), and
exp(820) = inf), but select_n discards that value instead of multiplying it
into anything, so it never reaches the output. Reverse mode inverts the order:
select_n's VJP first hands the unselected branch a zero cotangent, the exp VJP
then multiplies it by inf, giving 0 * inf = NaN, and that NaN is summed into
the cotangent of the shared input.
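A side-by-side sketch of the two modes on the same hypothetical pure-JAX replica of the piecewise softplus (again an assumption reconstructed from the JAXpr, not PyTensor's code). It uses a unit tangent so the forward-mode result is not trivially zero.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def piecewise_softplus(x):
    # Hypothetical replica of the four-branch softplus from the JAXpr.
    return jnp.where(
        x < -37.0,
        jnp.exp(x),
        jnp.where(x < 18.0, jnp.log1p(jnp.exp(x)),
                  jnp.where(x < 33.3, x + jnp.exp(-x), x)),
    )


def f(v):
    # Mirrors the reproducer: log(sigmoid(820 + v)) == -softplus(-(820 + v)).
    return -piecewise_softplus(-(820.0 + v))


v0 = jnp.array(0.0)

# Forward mode: select_n picks the selected branch's tangent; branch 3's
# infinite tangent is computed but discarded, never multiplied by zero.
_, t_fwd = jax.jvp(f, (v0,), (jnp.array(1.0),))

# Reverse mode: branch 3 gets a zero cotangent, exp's VJP multiplies it by
# exp(820) = inf, and the resulting NaN is summed into the shared input.
g_rev = jax.grad(f)(v0)

print("jvp :", t_fwd)
print("grad:", g_rev)
```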
Why this is a PyTensor bug (not a fundamental JAX limitation)
jax.nn.log_sigmoid handles the same input correctly because, in current JAX
releases, it is assembled from numerically stable pieces: it is defined as
-softplus(-x), and jax.nn.softplus is built on jnp.logaddexp, which carries its
own stable derivative rule. PyTensor could fix this by dispatching:
Sigmoid → jax.nn.sigmoid (wraps the lax.logistic primitive)
Softplus → jax.nn.softplus (built on jnp.logaddexp, stable derivative)
in its JAX Elemwise dispatch table, rather than inlining them as raw
arithmetic ops. The same NaN-in-unselected-branch problem will affect
any PyTensor piecewise formula whose unselected branches evaluate exp of
a large positive argument.
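If PyTensor preferred to keep its own thresholds rather than call jax.nn.softplus, another common remedy is the "double where" trick from the JAX FAQ: each branch is fed a harmless dummy input outside its own range, so no unselected branch can produce inf. The sketch below is a hypothetical patch (safe_piecewise_softplus is not an existing PyTensor function), checked against jax.nn.log_sigmoid at points in each branch.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def safe_piecewise_softplus(x):
    # Hypothetical patch: same four branches and thresholds, but each branch
    # sees a dummy input (0.0) outside its range, so its VJP stays finite.
    in1 = x < -37.0
    in2 = (x >= -37.0) & (x < 18.0)
    in3 = (x >= 18.0) & (x < 33.3)
    x1 = jnp.where(in1, x, 0.0)
    x2 = jnp.where(in2, x, 0.0)
    x3 = jnp.where(in3, x, 0.0)
    return jnp.where(
        in1,
        jnp.exp(x1),
        jnp.where(in2, jnp.log1p(jnp.exp(x2)),
                  jnp.where(in3, x3 + jnp.exp(-x3), x)),
    )


def log_sigmoid_safe(h):
    # log(sigmoid(h)) == -softplus(-h)
    return -safe_piecewise_softplus(-h)


for hv in (820.0, -820.0, 0.0, -25.0):
    hh = jnp.array(hv)
    print(hv, jax.grad(log_sigmoid_safe)(hh), jax.grad(jax.nn.log_sigmoid)(hh))
```

The inner where calls give the dummy paths a zero cotangent, and because those paths now have finite derivatives, 0 * finite stays 0.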
Workaround (user-side)
Clip the argument before passing it to log(sigmoid(.)):
```python
# exp(700) ~ 1e304 is finite in float64, so for large positive arg no VJP branch overflows
arg_safe = pt.minimum(arg, 700.0)
pt.log(pt.sigmoid(arg_safe))
```
This is semantically correct: when arg >> 0 the function value is already
~0 and the gradient is already ~0, so capping at 700 changes nothing
observable while preventing the overflow.
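The workaround can be sanity-checked on the same hypothetical pure-JAX replica of the piecewise softplus (an assumption based on the JAXpr, not PyTensor's code), with jnp.minimum standing in for pt.minimum.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def piecewise_softplus(x):
    # Hypothetical replica of the four-branch softplus from the JAXpr.
    return jnp.where(
        x < -37.0,
        jnp.exp(x),
        jnp.where(x < 18.0, jnp.log1p(jnp.exp(x)),
                  jnp.where(x < 33.3, x + jnp.exp(-x), x)),
    )


def log_sigmoid_capped(arg):
    # JAX analogue of the workaround: cap at 700 so exp(-j) <= exp(700) ~ 1e304.
    arg_safe = jnp.minimum(arg, 700.0)
    return -piecewise_softplus(-arg_safe)


g_pos = jax.grad(log_sigmoid_capped)(jnp.array(820.0))   # finite
g_neg = jax.grad(log_sigmoid_capped)(jnp.array(-800.0))  # still nan, see note
print(g_pos, g_neg)
```

In this replica the cap only guards the positive tail: for arguments below about -710, branches 1 and 2 evaluate exp(-arg) and overflow in the same way. Clipping from below would not be a neutral fix there, since log(sigmoid(h)) ≈ h for very negative h, so that tail would need a different guard.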
Expected behavior
jax.grad of any PyTensor function that is smooth and mathematically
well-defined at the evaluation point should return a finite gradient,
matching the result of jax.jvp (forward-mode) and finite differences.
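The expected behavior can be phrased as a small check. The helper below is hypothetical (not part of PyTensor or JAX) and uses a central finite difference; its tolerances are arbitrary choices that suit smooth functions away from branch thresholds.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def check_scalar_grad(f, v, eps=1e-6, rtol=1e-5, atol=1e-8):
    """Hypothetical helper: compare reverse mode, forward mode and a central
    finite difference for a scalar -> scalar function f at the point v."""
    v = jnp.asarray(v, dtype=jnp.float64)
    g_rev = jax.grad(f)(v)                            # reverse mode (VJP)
    _, g_fwd = jax.jvp(f, (v,), (jnp.ones_like(v),))  # forward mode (JVP)
    g_fd = (f(v + eps) - f(v - eps)) / (2.0 * eps)    # central difference
    ok = (
        bool(jnp.isfinite(g_rev))
        and bool(jnp.allclose(g_rev, g_fwd, rtol=rtol, atol=atol))
        and bool(jnp.allclose(g_rev, g_fd, rtol=rtol, atol=atol))
    )
    return ok, (float(g_rev), float(g_fwd), float(g_fd))


print(check_scalar_grad(jax.nn.log_sigmoid, 820.0))
print(check_scalar_grad(jax.nn.log_sigmoid, 0.0))
```

Applied to f from the reproducer, check_scalar_grad(f, 0.0) would be expected to return ok=False while the bug is present.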
Note: this report was largely generated by Claude
Reproducible code example:
```python
import jax
import jax.numpy as jnp
import pytensor
import pytensor.tensor as pt

jax.config.update("jax_enable_x64", True)

x = pt.dscalar("x")
fn = pytensor.function([x], pt.log(pt.sigmoid(820.0 + x)), mode="JAX")
jax_fn = fn.vm.jit_fn  # returns a tuple of outputs

def f(v):
    return jax_fn(v)[0]  # unwrap to scalar

v0 = jnp.array(0.0)
print("forward :", float(f(v0)))                             # -0.0 (correct)
print("grad    :", float(jax.grad(f)(v0)))                   # nan (should be 0.0)
print("jvp     :", jax.jvp(f, (v0,), (jnp.array(1.0),))[1])  # 0.0 (correct), unit tangent
# JAX's own implementation is correct:
print("jax ref :", float(jax.grad(jax.nn.log_sigmoid)(jnp.array(820.0))))  # 0.0 (correct)
```
Error message:
PyTensor version information:
Tested versions: 2.31.7, 3.0.7 (latest as of 2026-06-27)
Context for the issue:
Switching the NUTS sampler to numpyro should be a one-line change, but it immediately crashed my application. The user-side workaround is easy (clip inputs), so it's not urgent from my end, but the cause was not trivial to track down and is likely to trip up other users.