Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ lines-after-imports = 2
"tests/compile/debug/test_monitormode.py" = ["T201"]
"scripts/run_mypy.py" = ["T201"]
"scripts/bump_numba_upper_bound.py" = ["T201"]
"scripts/check_numba_veclib.py" = ["T201"]
# Test modules of optional backends that use `pytest.importorskip`, skip "E402"
"tests/link/jax/**/test_*.py" = ["E402"]
"tests/link/numba/**/test_*.py" = ["E402"]
Expand Down
14 changes: 14 additions & 0 deletions pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,20 @@ def add_numba_configvars():
BoolParam(True),
in_c_key=False,
)
config.add(
"numba__veclib",
(
"Name of the vectorizing math library wired into LLVM, or '' (the default) "
"for none. Any non-empty name (e.g. 'libmvec', 'svml', 'amdlibm', or a "
"custom build) enables Numba lowerings (e.g. for expm1) that rely on a "
"vectorizable exp/log. The value is used verbatim to key the compile cache, "
"so artifacts built against different libraries are never reused; pick a "
"stable name per library. You must wire the library into LLVM yourself "
"before importing numba; verify it with scripts/check_numba_veclib.py."
),
StrParam(""),
in_c_key=False,
)


def _filter_base_compiledir(path: str | Path) -> Path:
Expand Down
1 change: 1 addition & 0 deletions pytensor/configparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class PyTensorConfigParser:
# add_numba_configvars
numba__fastmath: bool
numba__cache: bool
numba__veclib: str
# add_caching_dir_configvars
compiledir_format: str
base_compiledir: Path
Expand Down
9 changes: 8 additions & 1 deletion pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,11 +470,18 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non
return jitable_func, None
else:
op_name = jitable_func.__name__
# A vector math library is wired into LLVM globally (not encoded in the
# generated source), so a function compiled against one emits calls to that
# library's symbols; reusing it under a different (or no) library segfaults on
# the unresolved symbol. Fold the library into the key so each variant caches
# separately. Only when set, so the default leaves existing caches untouched.
veclib = config.numba__veclib
veclib_key = f"_veclib{veclib}" if veclib else ""
cached_func = compile_numba_function_src(
src=f"def {op_name}(*args): return jitable_func(*args)",
function_name=op_name,
global_env=globals() | {"jitable_func": jitable_func},
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}",
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}{veclib_key}",
)
return numba_njit(cached_func, cache=True), cache_key

Expand Down
186 changes: 173 additions & 13 deletions pytensor/link/numba/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from numba.core import types
from numba.core.extending import get_cython_function_address

from pytensor import config
from pytensor.graph.basic import Variable
from pytensor.link.numba.cache import _call_cached_ptr, compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
Expand All @@ -24,7 +25,9 @@
Cast,
Clip,
Composite,
Expm1,
Identity,
Log1p,
Mul,
Pow,
Reciprocal,
Expand All @@ -41,6 +44,21 @@ def scalar_op_cache_key(op, **extra_fields):
return sha256(str((type(op), tuple(extra_fields.items()))).encode()).hexdigest()


@numba_basic.numba_njit(fastmath=False, inline="always")
def _log1p_via_log(x):
    # Purpose: evaluate log1p(x) with `np.log` plus a cancellation correction rather than
    # calling `np.log1p` directly.
    #
    # log1p(x) = log(1 + x) * x / ((1 + x) - 1): the factor recovers the bits lost to
    # cancellation in (1 + x) near 0 while lowering to the vectorizable `log` instead of the
    # scalar-only `log1p`. Shared by Log1p and Softplus. `inline="always"` keeps the caller's
    # loop call-free for the vectorizer, so the caller must also be `fastmath=False`: `reassoc`
    # would simplify `(1+x)-1` back to `x`, collapsing log1p to 0 in the underflow tail.
    # `type(x)(1)` keeps the literal in x's dtype (a bare `1` is int64; `int64 + float32` ->
    # float64, doubling the vectorized width on float32).
    one = type(x)(1)
    # `u` is the rounded value of 1 + x.
    u = one + x
    # `um1` is what remains of x after that rounding; it is the divisor below.
    um1 = u - one
    # When `um1` is exactly 0 the division would be by zero, so x itself is returned.
    return x if um1 == 0 else np.log(u) * x / um1


@register_funcify_and_cache_key(ScalarOp)
def numba_funcify_ScalarOp(op, node, **kwargs):
if not hasattr(op, "nfunc_spec"):
Expand Down Expand Up @@ -221,6 +239,85 @@ def pow(x, y):
)


@register_funcify_and_cache_key(Log1p)
def numba_funcify_Log1p(op, node, **kwargs):
    """Return the Numba ``log1p`` implementation for ``node`` and its cache key.

    Two lowerings exist and the ``corrected`` field of the cache key keeps them apart:

    * ``corrected=True``: ``_log1p_via_log`` (``log`` plus a cancellation correction).
      Chosen when ``config.numba__veclib`` is non-empty, or when the output is float32
      and ``config.numba__fastmath`` is on.
    * ``corrected=False``: plain ``np.log1p``.

    Rationale carried over from the original comments (not re-measured here): the
    corrected form vectorizes under a vector library and is faster on float32 even
    without one at a cost of about 1 ulp, while on float64 without a library it is
    slower for small arguments, so scalar ``log1p`` is kept there.

    Both variants are compiled with ``fastmath=False``.
    """
    is_float32 = node.outputs[0].dtype == "float32"
    use_corrected = bool(config.numba__veclib) or (
        is_float32 and config.numba__fastmath
    )

    if use_corrected:

        @numba_basic.numba_njit(fastmath=False)
        def log1p(x):
            return _log1p_via_log(x)

        return log1p, scalar_op_cache_key(op, corrected=True, cache_version=5)

    # Default path: the scalar libm-style `log1p`.
    @numba_basic.numba_njit(fastmath=False)
    def log1p(x):
        return np.log1p(x)

    return log1p, scalar_op_cache_key(op, corrected=False, cache_version=5)


def _expm1_numba_src(output_dtype):
# expm1(x) = x + x**2 * p(x), p(x) = sum_j x**j / (j + 2)! for |x| < ln2, else exp(x) - 1
# (no cancellation past ln2). The polynomial removes the near-0 cancellation without the
# `log` the exact correction needs, leaving a single vectorizable `exp`; the caller must keep
# `fastmath` off so it is evaluated as written. Term count is sized to eps: the tail at
# |x| = ln2 rounds away, so 8 terms hold float32 to 1 ulp (7 give 3), float64 needs all 15
# (14 give 2).
n_terms = 8 if output_dtype == "float32" else 15
cast = "np.float32" if output_dtype == "float32" else ""

def lit(c):
# Every literal must carry the output dtype or a float32 input promotes to float64 (a
# bare `1` is int64; `int64 + float32` -> float64). Numba unifies both branch return
# types, so the `np.exp(x) - 1` branch below must go through `lit` too.
return f"{cast}({c!r})" if cast else repr(c)

coeffs = [1.0 / math.factorial(j + 2) for j in range(n_terms)]
poly = lit(coeffs[-1])
for c in reversed(coeffs[:-1]):
poly = f"({poly} * x + {lit(c)})"
return (
f"def expm1(x):\n"
f" if abs(x) < {math.log(2.0)!r}:\n"
f" return x + x * x * ({poly})\n"
f" return np.exp(x) - {lit(1.0)}\n"
)


@register_funcify_and_cache_key(Expm1)
def numba_funcify_Expm1(op, node, **kwargs):
    """Return the Numba ``expm1`` implementation for ``node`` and its cache key.

    Two lowerings exist and the ``poly`` field of the cache key keeps them apart:

    * ``poly=True``: source generated by ``_expm1_numba_src`` (polynomial near 0 plus a
      single ``exp``), compiled with ``fastmath=False``. Chosen when
      ``config.numba__veclib`` is non-empty, or when the output is float32 and
      ``config.numba__fastmath`` is on.
    * ``poly=False``: whatever ``numba_funcify_ScalarOp`` produces for this op; only its
      function is kept and the key is recomputed here.

    Rationale carried over from the original comments (not re-measured here): on float32
    the polynomial beats scalar ``expm1`` even without a vector library at a cost of
    about 0.5 ulp, while on float64 it only pays off once ``exp`` becomes a SIMD call.
    """
    dtype = node.outputs[0].dtype
    use_poly = bool(config.numba__veclib) or (
        dtype == "float32" and config.numba__fastmath
    )

    if use_poly:
        generated = compile_numba_function_src(
            _expm1_numba_src(dtype), "expm1", {"np": np}
        )
        jitted = numba_basic.numba_njit(generated, fastmath=False)
        return jitted, scalar_op_cache_key(op, poly=True, cache_version=4)

    # Fallback: the generic scalar lowering; its own cache key is discarded.
    generic_fn, _ = numba_funcify_ScalarOp(op, node, **kwargs)
    return generic_fn, scalar_op_cache_key(op, poly=False, cache_version=4)


@register_funcify_and_cache_key(Add)
def numba_funcify_Add(op, node, **kwargs):
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
Expand Down Expand Up @@ -302,30 +399,63 @@ def reciprocal(x):
return reciprocal, scalar_op_cache_key(op, cache_version=1)


@numba_basic.numba_njit(fastmath=False, inline="always")
def _sigmoid_via_exp(x):
    # Purpose: evaluate sigmoid(x) = 1 / (1 + exp(-x)) in x's own dtype, written with an
    # explicit zero-denominator guard.
    #
    # The `d == 0` guard is DEAD (`1 + exp(-x)` >= 1) but the division-guard select lets the
    # loop vectorizer pull the division through and replace `exp` with a vector call; a plain
    # `1 / (1 + exp(-x))` stays scalar (bare division blocks it, like `_log1p_via_log`'s
    # `um1 == 0` select). `fastmath` off so the select is not proved dead. The `1`/`0` carry
    # x's dtype via `type(x)(...)` (a bare `1` is int64; `int64 + float32` -> float64, halving
    # the SIMD width on float32).
    one = type(x)(1)
    # `d` is the sigmoid denominator.
    d = one + np.exp(-x)
    # Guarded division: returns 0 when `d` compares equal to 0, otherwise 1 / d.
    return type(x)(0) if d == type(x)(0) else one / d


@register_funcify_and_cache_key(Sigmoid)
def numba_funcify_Sigmoid(op, node, **kwargs):
inp_dtype = node.inputs[0].type.dtype
if inp_dtype.startswith("uint"):
upcast_uint_dtype = {
"uint8": np.float32, # numpy uses float16, but not Numba
"uint16": np.float32,
"uint32": np.float64,
"uint64": np.float64,
}[inp_dtype]
upcast_uint_dtype = {
"uint8": np.float32, # numpy uses float16, but not Numba
"uint16": np.float32,
"uint32": np.float64,
"uint64": np.float64,
}.get(inp_dtype)
# `_sigmoid_via_exp` vectorizes under a vector library (2.7x float64 / 5.8x float32), but its
# division-guard select is a no-op (the `d == 0` arm is dead), so it buys nothing scalar.
# Unlike Log1p/Expm1/Softplus it is not a precision-for-speed trade, so `numba__fastmath` is
# not the right gate -- only a wired library makes it pay. Gate solely on `numba__veclib`;
# else emit plain `1 / (1 + exp(-x))`.
vectorizable = bool(config.numba__veclib)

if upcast_uint_dtype is not None:
if vectorizable:

@numba_basic.numba_njit(fastmath=False)
def sigmoid(x):
return _sigmoid_via_exp(numba_basic.direct_cast(x, upcast_uint_dtype))

@numba_basic.numba_njit
else:

@numba_basic.numba_njit
def sigmoid(x):
# Can't negate uint
float_x = numba_basic.direct_cast(x, upcast_uint_dtype)
return 1 / (1 + np.exp(-float_x))

elif vectorizable:

@numba_basic.numba_njit(fastmath=False)
def sigmoid(x):
# Can't negate uint
float_x = numba_basic.direct_cast(x, upcast_uint_dtype)
return 1 / (1 + np.exp(-float_x))
return _sigmoid_via_exp(x)

else:

@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))

return sigmoid, scalar_op_cache_key(op, cache_version=1)
return sigmoid, scalar_op_cache_key(op, veclib=vectorizable, cache_version=3)


@register_funcify_and_cache_key(GammaLn)
Expand All @@ -339,6 +469,13 @@ def gammaln(x):

@register_funcify_and_cache_key(Log1mexp)
def numba_funcify_Log1mexp(op, node, **kwargs):
# Mächler (2012) two-branch form with scalar `log1p`. `_log1p_via_log` (trades the `log1p`
# libcall for `log` + a division) was reverted here: it is fast only when its argument is
# away from 0 -- glibc `log(1+a)` is slow for `1+a` near 1, while `log1p` has a fast
# small-argument path. This branch always feeds it small `a = -exp(x)`, the slow regime: a
# ~30% float64 regression for no accuracy gain. (Verified to be the argument range, not the
# `exp` or the branch: the corrected form alone flips from 1.8x faster on x in (-0.5, 5) to
# 0.44x slower as x -> 0.)
@numba_basic.numba_njit
def logp1mexp(x):
if x < np.log(0.5):
Expand Down Expand Up @@ -385,6 +522,29 @@ def numba_funcify_Softplus(op, node, **kwargs):
upcast_uint_dtype = None
out_dtype = np.dtype(node.outputs[0].type.dtype)

# Branchless `max(x, 0) + log1p(exp(-|x|))` is ~1.1-1.2x faster than the cascade for the
# common mixed-sign case and vectorizes under a vector library, but routes log1p through the
# corrected `log`, costing ~1 ulp. Use it under a vector library or `numba__fastmath` (the
# opt-in to trade that ulp for speed); else use the accurate Mächler cascade. The cache key's
# `branchless` flag keeps the two from aliasing.
if bool(config.numba__veclib) or config.numba__fastmath:
# Branch-free softplus: max(x, 0) + log1p(exp(-|x|)). Equivalent to the Mächler cascade
# but feeds exp/log1p an argument in (0, 1] (their cheap regime) without a per-element
# branch: ~1.5x faster on the common case, never overflows. log1p goes through
# `_log1p_via_log` for the vectorizable `log`/`exp` (plain `np.log1p` has no vector form).
# `type(x)(0)` keeps the literal in x's dtype (`max(float32, float64)` would return
# float64, doubling the width). fastmath off so the inlined `(1+x)-1` is not reassociated
# to `x` (collapsing log1p to 0 in the underflow tail); vectorization does not need it.
@numba_basic.numba_njit(fastmath=False)
def softplus(x):
if upcast_uint_dtype is not None:
# Can't negate uint; upcast once so the formula below is uniform.
x = numba_basic.direct_cast(x, upcast_uint_dtype)
value = max(x, type(x)(0)) + _log1p_via_log(np.exp(-abs(x)))
return numba_basic.direct_cast(value, out_dtype)

return softplus, scalar_op_cache_key(op, branchless=True, cache_version=5)

@numba_basic.numba_njit
def softplus(x):
if x < -37.0:
Expand All @@ -400,7 +560,7 @@ def softplus(x):
value = x
return numba_basic.direct_cast(value, out_dtype)

return softplus, scalar_op_cache_key(op, cache_version=1)
return softplus, scalar_op_cache_key(op, branchless=False, cache_version=5)


@register_funcify_and_cache_key(ScalarLoop)
Expand Down
Loading
Loading