From ac5c254eb454fe0923a4dbdcb2ac4b92edb96a29 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 10:59:21 +0200 Subject: [PATCH 1/4] Validate JAX multinomial probability values --- src/pyrecest/_backend/jax/random.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/pyrecest/_backend/jax/random.py b/src/pyrecest/_backend/jax/random.py index 155f87f2b..66eee7dd3 100644 --- a/src/pyrecest/_backend/jax/random.py +++ b/src/pyrecest/_backend/jax/random.py @@ -454,6 +454,18 @@ def multivariate_normal(mean, cov, size=None, *args, **kwargs): return set_state_return(has_state, state, res) +def _validate_multinomial_pvals(pvals): + if _contains_boolean_value(pvals): + raise TypeError("pvals must be real numeric, not boolean") + try: + pvals = _jnp.asarray(pvals) + except (TypeError, ValueError, RuntimeError) as exc: + raise TypeError("pvals must be real numeric") from exc + if pvals.dtype.kind not in "iuf": + raise TypeError("pvals must be real numeric") + return pvals.astype(_jnp.float32) + + def _multinomial(state, n, pvals, size=None): if not _looks_like_integer_dimension(n): raise TypeError("n must be a non-negative integer") @@ -463,7 +475,7 @@ def _multinomial(state, n, pvals, size=None): state, key = jax.random.split(state) sample_shape = _shape_from_size(size) - pvals = _jnp.asarray(pvals, dtype=_jnp.float32) + pvals = _validate_multinomial_pvals(pvals) if pvals.ndim != 1: raise ValueError("pvals must be 1-dimensional") if pvals.shape[0] == 0: From 02378c9b0184e12376e150754a7cb2849774fc41 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 11:00:07 +0200 Subject: [PATCH 2/4] Validate PyTorch multinomial probability values --- src/pyrecest/_backend/pytorch/random.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/pyrecest/_backend/pytorch/random.py b/src/pyrecest/_backend/pytorch/random.py index 487627d65..ca7d1c4cf 100644 --- a/src/pyrecest/_backend/pytorch/random.py +++ b/src/pyrecest/_backend/pytorch/random.py @@ -363,6 +363,18 @@ def _multinomial_sample_count(sample_shape): return _prod(sample_shape) if sample_shape else 1 +def _validate_multinomial_pvals(pvals, device): + if _contains_boolean_value(pvals): + raise TypeError("pvals must be real numeric, not boolean") + try: + pvals = _torch.as_tensor(pvals, device=device) + except (TypeError, ValueError, RuntimeError) as exc: + raise TypeError("pvals must be real numeric") from exc + if not _is_real_numeric_dtype(pvals.dtype): + raise TypeError("pvals must be real numeric") + return pvals.to(dtype=_torch.float32) + + def multinomial(n, pvals, size=None): if not _looks_like_integer_dimension(n): raise TypeError("n must be a non-negative integer") @@ -372,7 +384,7 @@ def multinomial(n, pvals, size=None): sample_shape = _shape_from_size(size) device = pvals.device if _torch.is_tensor(pvals) else None - pvals = _torch.as_tensor(pvals, dtype=_torch.float32, device=device) + pvals = _validate_multinomial_pvals(pvals, device) if pvals.ndim != 1: raise ValueError("pvals must be 1-dimensional") if pvals.numel() == 0: From 1234961e0858d9966f0c8b2593c66d14a6b03884 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 11:00:30 +0200 Subject: [PATCH 3/4] Add JAX multinomial pvals regression tests --- .../test_jax_multinomial_real_pvals.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/backend/test_jax_multinomial_real_pvals.py diff --git a/tests/backend/test_jax_multinomial_real_pvals.py b/tests/backend/test_jax_multinomial_real_pvals.py new file mode 100644 index 000000000..a70c5b625 --- /dev/null +++ b/tests/backend/test_jax_multinomial_real_pvals.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest + +jax = pytest.importorskip("jax") +import jax.numpy as jnp # noqa: E402 + +from pyrecest._backend.jax import random # noqa: E402 + + +@pytest.mark.parametrize( + "pvals", + [ + [True, False], + np.array([True, False]), + jnp.array([True, False]), + ], +) +def test_multinomial_rejects_boolean_pvals(pvals): + with pytest.raises(TypeError, match="real numeric"): + random.multinomial(1, pvals) + + +@pytest.mark.parametrize( + "pvals", + [ + [1.0 + 0.0j], + np.array([1.0 + 0.0j]), + jnp.array([1.0 + 0.0j]), + ], +) +def test_multinomial_rejects_complex_pvals(pvals): + with pytest.raises(TypeError, match="real numeric"): + random.multinomial(1, pvals) + + +def test_multinomial_accepts_real_pvals(): + random.seed(0) + + result = random.multinomial(2, [1.0, 0.0]) + + assert result.shape == (2,) + assert int(result.sum()) == 2 From a0bf7313246e4094a9a4011fcd26830f690357e6 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 11:00:45 +0200 Subject: [PATCH 4/4] Add PyTorch multinomial pvals regression tests --- .../test_pytorch_multinomial_real_pvals.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tests/backend/test_pytorch_multinomial_real_pvals.py diff --git a/tests/backend/test_pytorch_multinomial_real_pvals.py b/tests/backend/test_pytorch_multinomial_real_pvals.py new file mode 100644 index 000000000..1b6a76725 --- /dev/null +++ b/tests/backend/test_pytorch_multinomial_real_pvals.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest + +torch = pytest.importorskip("torch") + +from pyrecest._backend.pytorch import random # noqa: E402 + + +@pytest.mark.parametrize( + "pvals", + [ + [True, False], + np.array([True, False]), + torch.tensor([True, False]), + ], +) +def test_multinomial_rejects_boolean_pvals(pvals): + with pytest.raises(TypeError, match="real numeric"): + random.multinomial(1, pvals) + + +@pytest.mark.parametrize( + "pvals", + [ + [1.0 + 0.0j], + np.array([1.0 + 0.0j]), + torch.tensor([1.0 + 0.0j]), + ], +) +def test_multinomial_rejects_complex_pvals(pvals): + with pytest.raises(TypeError, match="real numeric"): + random.multinomial(1, pvals) + + +def test_multinomial_accepts_real_pvals(): + random.seed(0) + + result = random.multinomial(2, [1.0, 0.0]) + + assert result.shape == (2,) + assert int(result.sum()) == 2