diff --git a/pyproject.toml b/pyproject.toml index ea9d0728f8..ea1fa59cd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,6 +158,7 @@ lines-after-imports = 2 "tests/compile/debug/test_monitormode.py" = ["T201"] "scripts/run_mypy.py" = ["T201"] "scripts/bump_numba_upper_bound.py" = ["T201"] +"scripts/check_numba_veclib.py" = ["T201"] # Test modules of optional backends that use `pytest.importorskip`, skip "E402" "tests/link/jax/**/test_*.py" = ["E402"] "tests/link/numba/**/test_*.py" = ["E402"] diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index e5a761fd5f..cfacd2c9c1 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -1027,6 +1027,20 @@ def add_numba_configvars(): BoolParam(True), in_c_key=False, ) + config.add( + "numba__veclib", + ( + "Name of the vectorizing math library wired into LLVM, or '' (the default) " + "for none. Any non-empty name (e.g. 'libmvec', 'svml', 'amdlibm', or a " + "custom build) enables Numba lowerings (e.g. for expm1) that rely on a " + "vectorizable exp/log. The value is used verbatim to key the compile cache, " + "so artifacts built against different libraries are never reused; pick a " + "stable name per library. You must wire the library into LLVM yourself " + "before importing numba; verify it with scripts/check_numba_veclib.py." 
+ ), + StrParam(""), + in_c_key=False, + ) def _filter_base_compiledir(path: str | Path) -> Path: diff --git a/pytensor/configparser.py b/pytensor/configparser.py index 4d2f4b98b1..184f99859b 100644 --- a/pytensor/configparser.py +++ b/pytensor/configparser.py @@ -151,6 +151,7 @@ class PyTensorConfigParser: # add_numba_configvars numba__fastmath: bool numba__cache: bool + numba__veclib: str # add_caching_dir_configvars compiledir_format: str base_compiledir: Path diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 59dcc4c6c8..36c53e3277 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -470,11 +470,18 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non return jitable_func, None else: op_name = jitable_func.__name__ + # A vector math library is wired into LLVM globally (not encoded in the + # generated source), so a function compiled against one emits calls to that + # library's symbols; reusing it under a different (or no) library segfaults on + # the unresolved symbol. Fold the library into the key so each variant caches + # separately. Only when set, so the default leaves existing caches untouched. 
+ veclib = config.numba__veclib + veclib_key = f"_veclib{veclib}" if veclib else "" cached_func = compile_numba_function_src( src=f"def {op_name}(*args): return jitable_func(*args)", function_name=op_name, global_env=globals() | {"jitable_func": jitable_func}, - cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}", + cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}{veclib_key}", ) return numba_njit(cached_func, cache=True), cache_key diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index a0662a9c62..5d1f9dce58 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -6,6 +6,7 @@ from numba.core import types from numba.core.extending import get_cython_function_address +from pytensor import config from pytensor.graph.basic import Variable from pytensor.link.numba.cache import _call_cached_ptr, compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic @@ -24,7 +25,9 @@ Cast, Clip, Composite, + Expm1, Identity, + Log1p, Mul, Pow, Reciprocal, @@ -41,6 +44,21 @@ def scalar_op_cache_key(op, **extra_fields): return sha256(str((type(op), tuple(extra_fields.items()))).encode()).hexdigest() +@numba_basic.numba_njit(fastmath=False, inline="always") +def _log1p_via_log(x): + # log1p(x) = log(1 + x) * x / ((1 + x) - 1): the factor recovers the bits lost to + # cancellation in (1 + x) near 0 while lowering to the vectorizable `log` instead of the + # scalar-only `log1p`. Shared by Log1p and Softplus. `inline="always"` keeps the caller's + # loop call-free for the vectorizer, so the caller must also be `fastmath=False`: `reassoc` + # would simplify `(1+x)-1` back to `x`, collapsing log1p to 0 in the underflow tail. + # `type(x)(1)` keeps the literal in x's dtype (a bare `1` is int64; `int64 + float32` -> + # float64, halving the SIMD lane count on float32).
+ one = type(x)(1) + u = one + x + um1 = u - one + return x if um1 == 0 else np.log(u) * x / um1 + + @register_funcify_and_cache_key(ScalarOp) def numba_funcify_ScalarOp(op, node, **kwargs): if not hasattr(op, "nfunc_spec"): @@ -221,6 +239,85 @@ def pow(x, y): ) +@register_funcify_and_cache_key(Log1p) +def numba_funcify_Log1p(op, node, **kwargs): + out_dtype = node.outputs[0].dtype + # `_log1p_via_log` (corrected `log`) vectorizes under a vector library and, on float32, beats + # scalar `log1p` even without one, at ~1 ulp (glibc `log` near 1 is less accurate than + # `log1p`'s small-arg path). On float64 with no library it is also ~0.6x SLOWER in the + # small-arg regime log1p exists for, so there it only pays off vectorized. Use it under a + # vector library (both dtypes) or float32 + `numba__fastmath` (the opt-in to trade the ulp + # for a scalar win); else keep scalar `log1p`. The cache key's `corrected` flag keeps the two + # from aliasing. + corrected = bool(config.numba__veclib) or ( + out_dtype == "float32" and config.numba__fastmath + ) + if not corrected: + + @numba_basic.numba_njit(fastmath=False) + def log1p(x): + return np.log1p(x) + + return log1p, scalar_op_cache_key(op, corrected=False, cache_version=5) + + @numba_basic.numba_njit(fastmath=False) + def log1p(x): + return _log1p_via_log(x) + + return log1p, scalar_op_cache_key(op, corrected=True, cache_version=5) + + +def _expm1_numba_src(output_dtype): + # expm1(x) = x + x**2 * p(x), p(x) = sum_j x**j / (j + 2)! for |x| < ln2, else exp(x) - 1 + # (no cancellation past ln2). The polynomial removes the near-0 cancellation without the + # `log` the exact correction needs, leaving a single vectorizable `exp`; the caller must keep + # `fastmath` off so it is evaluated as written. Term count is sized to eps: the tail at + # |x| = ln2 rounds away, so 8 terms hold float32 to 1 ulp (7 give 3), float64 needs all 15 + # (14 give 2). 
+ n_terms = 8 if output_dtype == "float32" else 15 + cast = "np.float32" if output_dtype == "float32" else "" + + def lit(c): + # Every literal must carry the output dtype or a float32 input promotes to float64 (a + # bare `1` is int64; `int64 + float32` -> float64). Numba unifies both branch return + # types, so the `np.exp(x) - 1` branch below must go through `lit` too. + return f"{cast}({c!r})" if cast else repr(c) + + coeffs = [1.0 / math.factorial(j + 2) for j in range(n_terms)] + poly = lit(coeffs[-1]) + for c in reversed(coeffs[:-1]): + poly = f"({poly} * x + {lit(c)})" + return ( + f"def expm1(x):\n" + f" if abs(x) < {math.log(2.0)!r}:\n" + f" return x + x * x * ({poly})\n" + f" return np.exp(x) - {lit(1.0)}\n" + ) + + +@register_funcify_and_cache_key(Expm1) +def numba_funcify_Expm1(op, node, **kwargs): + out_dtype = node.outputs[0].dtype + # Polynomial-near-0 + single `exp`: removes expm1's cancellation while lowering to a + # vectorizable `exp`. On float32 the 8-term FMA poly plus one `exp` beat scalar `expm1` even + # with no library; on float64 it is slower in the near-0 regime unless `exp` becomes a SIMD + # call. Use the poly under a vector library (both dtypes) or float32 + `numba__fastmath` (the + # opt-in to trade ~0.5 ulp for the scalar win); else fall back to scalar `expm1`. The cache + # key's `poly` flag keeps the two from aliasing. 
+ poly = bool(config.numba__veclib) or ( + out_dtype == "float32" and config.numba__fastmath + ) + if not poly: + fn, _ = numba_funcify_ScalarOp(op, node, **kwargs) + return fn, scalar_op_cache_key(op, poly=False, cache_version=4) + + src = _expm1_numba_src(out_dtype) + expm1 = compile_numba_function_src(src, "expm1", {"np": np}) + return numba_basic.numba_njit(expm1, fastmath=False), scalar_op_cache_key( + op, poly=True, cache_version=4 + ) + + @register_funcify_and_cache_key(Add) def numba_funcify_Add(op, node, **kwargs): nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") @@ -302,22 +399,55 @@ def reciprocal(x): return reciprocal, scalar_op_cache_key(op, cache_version=1) +@numba_basic.numba_njit(fastmath=False, inline="always") +def _sigmoid_via_exp(x): + # sigmoid(x) = 1 / (1 + exp(-x)). The `d == 0` guard is DEAD (`1 + exp(-x)` >= 1) but the + # division-guard select lets the loop vectorizer pull the division through and replace `exp` + # with a vector call; a plain `1 / (1 + exp(-x))` stays scalar (bare division blocks it, like + # `_log1p_via_log`'s `um1 == 0` select). `fastmath` off so the select is not proved dead. The + # `1`/`0` carry x's dtype via `type(x)(...)` (a bare `1` is int64; `int64 + float32` -> + # float64, halving the SIMD width on float32). 
+ one = type(x)(1) + d = one + np.exp(-x) + return type(x)(0) if d == type(x)(0) else one / d + + @register_funcify_and_cache_key(Sigmoid) def numba_funcify_Sigmoid(op, node, **kwargs): inp_dtype = node.inputs[0].type.dtype - if inp_dtype.startswith("uint"): - upcast_uint_dtype = { - "uint8": np.float32, # numpy uses float16, but not Numba - "uint16": np.float32, - "uint32": np.float64, - "uint64": np.float64, - }[inp_dtype] + upcast_uint_dtype = { + "uint8": np.float32, # numpy uses float16, but not Numba + "uint16": np.float32, + "uint32": np.float64, + "uint64": np.float64, + }.get(inp_dtype) + # `_sigmoid_via_exp` vectorizes under a vector library (2.7x float64 / 5.8x float32), but its + # division-guard select is a no-op (the `d == 0` arm is dead), so it buys nothing scalar. + # Unlike Log1p/Expm1/Softplus it is not a precision-for-speed trade, so `numba__fastmath` is + # not the right gate -- only a wired library makes it pay. Gate solely on `numba__veclib`; + # else emit plain `1 / (1 + exp(-x))`. 
+ vectorizable = bool(config.numba__veclib) + + if upcast_uint_dtype is not None: + if vectorizable: + + @numba_basic.numba_njit(fastmath=False) + def sigmoid(x): + return _sigmoid_via_exp(numba_basic.direct_cast(x, upcast_uint_dtype)) - @numba_basic.numba_njit + else: + + @numba_basic.numba_njit + def sigmoid(x): + # Can't negate uint + float_x = numba_basic.direct_cast(x, upcast_uint_dtype) + return 1 / (1 + np.exp(-float_x)) + + elif vectorizable: + + @numba_basic.numba_njit(fastmath=False) def sigmoid(x): - # Can't negate uint - float_x = numba_basic.direct_cast(x, upcast_uint_dtype) - return 1 / (1 + np.exp(-float_x)) + return _sigmoid_via_exp(x) else: @@ -325,7 +455,7 @@ def sigmoid(x): def sigmoid(x): return 1 / (1 + np.exp(-x)) - return sigmoid, scalar_op_cache_key(op, cache_version=1) + return sigmoid, scalar_op_cache_key(op, veclib=vectorizable, cache_version=3) @register_funcify_and_cache_key(GammaLn) @@ -339,6 +469,13 @@ def gammaln(x): @register_funcify_and_cache_key(Log1mexp) def numba_funcify_Log1mexp(op, node, **kwargs): + # Mächler (2012) two-branch form with scalar `log1p`. `_log1p_via_log` (trades the `log1p` + # libcall for `log` + a division) was reverted here: it is fast only when its argument is + # away from 0 -- glibc `log(1+a)` is slow for `1+a` near 1, while `log1p` has a fast + # small-argument path. This branch always feeds it small `a = -exp(x)`, the slow regime: a + # ~30% float64 regression for no accuracy gain. (Verified to be the argument range, not the + # `exp` or the branch: the corrected form alone flips from 1.8x faster on x in (-0.5, 5) to + # 0.44x slower as x -> 0.) 
@numba_basic.numba_njit def logp1mexp(x): if x < np.log(0.5): @@ -385,6 +522,29 @@ def numba_funcify_Softplus(op, node, **kwargs): upcast_uint_dtype = None out_dtype = np.dtype(node.outputs[0].type.dtype) + # Branchless `max(x, 0) + log1p(exp(-|x|))` is ~1.1-1.2x faster than the cascade for the + # common mixed-sign case and vectorizes under a vector library, but routes log1p through the + # corrected `log`, costing ~1 ulp. Use it under a vector library or `numba__fastmath` (the + # opt-in to trade that ulp for speed); else use the accurate Mächler cascade. The cache key's + # `branchless` flag keeps the two from aliasing. + if bool(config.numba__veclib) or config.numba__fastmath: + # Branch-free softplus: max(x, 0) + log1p(exp(-|x|)). Equivalent to the Mächler cascade + # but feeds exp/log1p an argument in (0, 1] (their cheap regime) without a per-element + # branch: ~1.5x faster on the common case, never overflows. log1p goes through + # `_log1p_via_log` for the vectorizable `log`/`exp` (plain `np.log1p` has no vector form). + # `type(x)(0)` keeps the literal in x's dtype (`max(float32, float64)` would return + # float64, doubling the width). fastmath off so the inlined `(1+x)-1` is not reassociated + # to `x` (collapsing log1p to 0 in the underflow tail); vectorization does not need it. + @numba_basic.numba_njit(fastmath=False) + def softplus(x): + if upcast_uint_dtype is not None: + # Can't negate uint; upcast once so the formula below is uniform. 
+ x = numba_basic.direct_cast(x, upcast_uint_dtype) + value = max(x, type(x)(0)) + _log1p_via_log(np.exp(-abs(x))) + return numba_basic.direct_cast(value, out_dtype) + + return softplus, scalar_op_cache_key(op, branchless=True, cache_version=5) + @numba_basic.numba_njit def softplus(x): if x < -37.0: @@ -400,7 +560,7 @@ def softplus(x): value = x return numba_basic.direct_cast(value, out_dtype) - return softplus, scalar_op_cache_key(op, cache_version=1) + return softplus, scalar_op_cache_key(op, branchless=False, cache_version=5) @register_funcify_and_cache_key(ScalarLoop) diff --git a/scripts/check_numba_veclib.py b/scripts/check_numba_veclib.py new file mode 100644 index 0000000000..692c85cc86 --- /dev/null +++ b/scripts/check_numba_veclib.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python +"""Check whether Numba/LLVM vectorizes transcendental math calls in this environment. + +PyTensor's Numba backend lowers ``exp``/``log`` (and, when ``config.numba__veclib`` is +set, ``log1p``/``expm1``) to scalar libm calls that LLVM's loop vectorizer can replace +with SIMD calls into a *vector math library* -- glibc ``libmvec``, Intel SVML, or AMD +AMDLIBM. That substitution only happens when such a library is wired into LLVM. + +Run this to confirm whether your environment picks one up:: + + python scripts/check_numba_veclib.py + +If it reports ``VECTORIZED``, enable the SIMD ``log1p``/``expm1`` lowerings by setting +``pytensor.config.numba__veclib`` to the name of the library it found -- ``"libmvec"``, +``"svml"``, or ``"amdlibm"`` (or the ``PYTENSOR_FLAGS`` / ``.pytensorrc`` equivalent). +Otherwise keep the default (``""``): without a vector library those lowerings only add +work over the scalar libm calls. 
+ +Wiring up glibc ``libmvec`` (Linux/glibc) looks like this, *before* importing numba:: + + import llvmlite.binding as llvm + + llvm.set_option("", "-vector-library=LIBMVEC-X86") # "LIBMVEC" on LLVM >= 21 + llvm.load_library_permanently("libmvec.so.1") +""" + +import re + +import numba +import numpy as np + +import pytensor +import pytensor.tensor as pt + + +# Maps an assembly symbol prefix to (numba__veclib config value, human description) +# for the library that exports it. +VECLIB_SYMBOLS = { + "_ZGV": ("libmvec", "glibc libmvec / GNU vector ABI"), + "__svml_": ("svml", "Intel SVML"), + "amd_vr": ("amdlibm", "AMD AMDLIBM"), +} + + +def detect_vectorized_math() -> dict[str, tuple[str, list[str]]]: + """Compile an ``exp`` loop and report any vector-math symbols in its assembly. + + This mirrors how PyTensor's Numba ``Elemwise`` lowers a transcendental: a scalar + libm call inside a contiguous loop, which the loop vectorizer rewrites to a packed + call only when a vector library is available. + """ + + @numba.njit + def exp_loop(x): + out = np.empty_like(x) + for i in range(x.shape[0]): + out[i] = np.exp(x[i]) + return out + + exp_loop.compile((numba.float64[::1],)) + asm = exp_loop.inspect_asm(exp_loop.signatures[0]) + return { + cfg: (desc, sorted(set(re.findall(rf"{re.escape(prefix)}\w*", asm)))) + for prefix, (cfg, desc) in VECLIB_SYMBOLS.items() + if prefix in asm + } + + +def main() -> int: + # Sanity check that PyTensor's Numba backend itself works end to end.
+ x = pt.vector("x") + fn = pytensor.function([x], pt.exp(x), mode="NUMBA") + np.testing.assert_allclose(fn(np.linspace(-1, 1, 8)), np.exp(np.linspace(-1, 1, 8))) + + found = detect_vectorized_math() + # Quote the value so the default (empty string) is visible, matching the messages below. + print(f'current pytensor.config.numba__veclib = "{pytensor.config.numba__veclib}"\n') + + if found: + print("VECTORIZED: exp lowered to SIMD vector-math calls:") + for desc, syms in found.values(): + print(f" - {desc}: {', '.join(syms[:4])}") + cfg = next(iter(found)) + print("\nA vector math library is wired into LLVM. Enable the SIMD") + print( + f'log1p/expm1 lowerings with:\n pytensor.config.numba__veclib = "{cfg}"' + ) + return 0 + + print("SCALAR: exp stayed a scalar libm call -- no vector math library is wired") + print("into LLVM, so the log1p/expm1 SIMD lowerings would only add overhead.") + print('Keep pytensor.config.numba__veclib = "" (the default).\n') + print("To wire up glibc libmvec (Linux/glibc), before importing numba:") + print(" import llvmlite.binding as llvm") + print( + ' llvm.set_option("", "-vector-library=LIBMVEC-X86") # "LIBMVEC" on LLVM >= 21' + ) + print(' llvm.load_library_permanently("libmvec.so.1")') + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index fdd505fcb3..f96132210c 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -379,3 +379,162 @@ def test_loop_with_cython_wrapped_op(self): res = fn(x_test) expected_res = ps.psi(x).eval({x: x_test}) np.testing.assert_allclose(res, expected_res) + + +def _max_ulp_err(got, ref64, dtype): + """Max error in ulps of `dtype`, over finite points, vs a float64 reference. + + Uses the true spacing at each value, not a relative-error proxy. `np.spacing` is + signed (negative for negative args), so we take abs -- otherwise a `> 0` mask would + silently drop every point on the negative branch, which is the whole reason log1p / + expm1 exist.
A float64 reference resolves float32 accuracy exactly; for float64 it + measures agreement with numpy's own (well-tested) log1p / expm1. + """ + got = np.asarray(got).astype("float64") + target = np.asarray(ref64).astype( + dtype + ) # reference correctly rounded to target dtype + sp = np.abs(np.spacing(target)).astype("float64") + target = target.astype("float64") + finite = np.isfinite(got) & np.isfinite(target) & (sp > 0) + return (np.abs(got[finite] - target[finite]) / sp[finite]).max() + + +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_vectorizable_log1p(dtype): + """log1p lowered via a corrected log(1 + x) stays accurate over a wide range. + + The naive log(1 + x) collapses near 0 (the regime log1p exists for), so we sweep + densely and deep into the near-zero cancellation region in BOTH signs -- a narrow or + positive-only grid never stresses the correction and a broken one slips through. We + also check the domain edge: log1p(-1) = -inf, log1p(x < -1) = nan. The corrected form + is emitted under a vector library (both dtypes) or on float32 under numba__fastmath, so + we wire numba__veclib to exercise the corrected lowering for both dtypes; we do not wire + an actual library, since this checks the corrected form's accuracy, not its SIMD lowering. 
+ """ + with config.change_flags(numba__veclib="libmvec"): + x = pt.vector("x", dtype=dtype) + fn = function([x], pt.log1p(x), mode=numba_mode) + + neg = -np.logspace(-20, np.log10(0.999), 4000) # (-0.999, 0), deep near-zero + pos = np.logspace(-20, 2.5, 4000) # (0, ~316) + edge = np.array([-1.0, -1.5, -10.0]) # -inf, nan, nan + x_test = np.concatenate([neg, pos, edge]).astype(dtype) + + got = fn(x_test) + with np.errstate(invalid="ignore", divide="ignore"): # x <= -1 is intentional + ref = np.log1p(x_test.astype("float64")) + + assert got.dtype == np.dtype( + dtype + ) # output dtype; does NOT reveal an internal upcast + np.testing.assert_array_equal(np.isnan(got), np.isnan(ref)) # nan for x < -1 + np.testing.assert_array_equal(np.isinf(got), np.isinf(ref)) # -inf at x == -1 + assert _max_ulp_err(got, ref, dtype) < 4 # ~2 ulp worst case observed + + +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_vectorizable_expm1(dtype): + """expm1 lowered via a polynomial near 0 / exp(x) - 1 elsewhere stays accurate. + + The naive exp(x) - 1 collapses near 0, so we sweep densely deep near-zero in both + signs. The polynomial branch (|x| < ln2) carries no exp, so its accuracy is + independent of whatever vector exp is wired in -- we bound it tightly there, the + branch the rewrite exists for. Across the full domain (incl. the exp(x) - 1 branch, + which only inherits exp's own accuracy) we keep a looser bound, plus the overflow edge + where it must go to +inf exactly where numpy does. The polynomial is emitted under a + vector library (both dtypes) or on float32 under numba__fastmath, so we wire numba__veclib + to exercise the polynomial for both dtypes; we do not actually wire a library in here, since + this checks the polynomial's accuracy, not its SIMD lowering. 
+ """ + with config.change_flags(numba__veclib="libmvec"): + x = pt.vector("x", dtype=dtype) + fn = function([x], pt.expm1(x), mode=numba_mode) + + ln2 = np.log(2.0) + small = np.logspace(-20, np.log10(ln2), 4000) # |x| in (0, ln2): polynomial branch + mid = np.logspace(np.log10(ln2), 1, 2000) # |x| in (ln2, 10): exp(x) - 1 branch + overflow = np.array([1e3, 1e5]) # exp overflows -> +inf in both float32 and float64 + x_test = np.concatenate([-small, small, -mid, mid, overflow]).astype(dtype) + + got = fn(x_test) + with np.errstate(over="ignore"): # the overflow points are intentional + ref = np.expm1(x_test.astype("float64")) + ref_same_dtype = np.expm1(x_test) + + assert got.dtype == np.dtype( + dtype + ) # output dtype; does NOT reveal an internal upcast + np.testing.assert_array_equal(np.isinf(got), np.isinf(ref_same_dtype)) # +inf + poly = np.abs(x_test) < ln2 + assert _max_ulp_err(got[poly], ref[poly], dtype) < 4 # cancellation region: ~1 ulp + assert ( + _max_ulp_err(got, ref, dtype) < 16 + ) # full domain incl. vector exp's own error + + +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_vectorizable_sigmoid(dtype): + """sigmoid lowered via the division-guard select (1/(1+exp(-x)) with a dead `d == 0` + guard that lets the loop vectorize) stays accurate and matches the naive form. + + The guard never fires (1 + exp(-x) >= 1), so the result is identical to 1/(1+exp(-x)); we + sweep a wide range including the saturating tails (-> 0 and -> 1). The form is gated solely + on numba__veclib (it is ulp-neutral, so it only pays off vectorized), so we enable it to + exercise the vectorizable lowering for both dtypes; no library is wired, so this checks + accuracy, not the SIMD lowering. 
+ """ + with config.change_flags(numba__veclib="libmvec"): + x = pt.vector("x", dtype=dtype) + fn = function([x], pt.sigmoid(x), mode=numba_mode) + + # |x| < 88 keeps exp(-x) finite in float32; the 1e3 points saturate to exactly 0 and 1 + x_test = np.concatenate([np.linspace(-80, 80, 8000), np.array([-1e3, 1e3])]).astype( + dtype + ) + got = fn(x_test) + with np.errstate( + over="ignore" + ): # exp(-x) overflows to inf at the -1e3 point (-> 0) + ref = 1.0 / (1.0 + np.exp(-x_test.astype("float64"))) + + assert got.dtype == np.dtype( + dtype + ) # output dtype; does NOT reveal an internal upcast + np.testing.assert_array_equal( + got[-2:], np.array([0.0, 1.0], dtype=dtype) + ) # saturation + assert ( + _max_ulp_err(got, ref, dtype) < 4 + ) # ~2 ulp; guard is exact, scalar exp is accurate + + +@pytest.mark.parametrize("dtype", ["float64", "float32"]) +@pytest.mark.parametrize( + "op, lo, hi", + [ + (pt.log1p, -0.5, 5.0), + (pt.expm1, -5.0, 5.0), + (pt.log1mexp, -5.0, -0.01), + (pt.softplus, -30.0, 30.0), + (pt.sigmoid, -10.0, 10.0), + ], + ids=["log1p", "expm1", "log1mexp", "softplus", "sigmoid"], +) +def test_vectorizable_op_benchmark(op, lo, hi, dtype, benchmark): + """Throughput of the log1p / expm1 / log1mexp / softplus / sigmoid Numba lowerings. + + Runs under the default config (numba__fastmath on, no vector library), so it is + reproducible anywhere; comparing this branch against main (pytest-benchmark tracks results + across runs) shows the PR's speedup. It is largest for float32. The precision-for-speed + rewrites active here are float32 log1p (corrected log), float32 expm1 (polynomial) and + softplus (both dtypes), gated on fastmath; float64 log1p/expm1 and sigmoid (both dtypes) + keep main's form unless a vector library is wired, and log1mexp always does, so they show + ~parity here. A wired library widens every win further -- not exercised here.
+ """ + x = pt.vector("x", dtype=dtype) + fn = function([x], op(x), mode=numba_mode) + fn.trust_input = True + x_test = rng.uniform(lo, hi, 100_000).astype(dtype) + fn(x_test) # compile before timing + benchmark(fn, x_test)