From 2e92e0d22de31f50cabe06969a38f125f3b98f10 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 12:03:53 +0200 Subject: [PATCH 1/2] Validate NumPy multinomial sample count --- src/pyrecest/_backend/numpy/random.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/pyrecest/_backend/numpy/random.py b/src/pyrecest/_backend/numpy/random.py index 9bae51172..b35084e6d 100644 --- a/src/pyrecest/_backend/numpy/random.py +++ b/src/pyrecest/_backend/numpy/random.py @@ -46,6 +46,21 @@ def randint(low, high=None, size=None, dtype=int): return _np.random.randint(low, high=high, size=size, dtype=dtype) +def _validate_multinomial_sample_count(n): + if _contains_boolean_value(n): + raise TypeError("n must be a non-negative integer") + try: + n_array = _np.asarray(n) + except (TypeError, ValueError) as exc: + raise TypeError("n must be a non-negative integer") from exc + if n_array.shape != () or n_array.dtype.kind not in "iu": + raise TypeError("n must be a non-negative integer") + count = int(n_array.item()) + if count < 0: + raise ValueError("n must be non-negative") + return count + + def _validate_multinomial_pvals(pvals): if _contains_boolean_value(pvals): raise TypeError("pvals must be real numeric, not boolean") @@ -59,5 +74,6 @@ def _validate_multinomial_pvals(pvals): def multinomial(n, pvals, size=None): + n = _validate_multinomial_sample_count(n) pvals_array = _validate_multinomial_pvals(pvals) return _np.random.multinomial(n, pvals_array, size=size) From 9d5721e409cadfab726df095e29565529ad4f758 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 12:04:14 +0200 Subject: [PATCH 2/2] Add NumPy multinomial count validation tests --- tests/backend/test_numpy_random_backend.py | 31 ++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/backend/test_numpy_random_backend.py b/tests/backend/test_numpy_random_backend.py index a00c1672c..50da29082 100644 --- a/tests/backend/test_numpy_random_backend.py +++ b/tests/backend/test_numpy_random_backend.py @@ -70,6 +70,37 @@ def test_randint_accepts_integer_array_bounds(): assert np.all(samples < high) +@pytest.mark.parametrize( + "n", + [ + True, + False, + np.bool_(True), + np.array(True), + np.array([1]), + 1.5, + "1", + ], +) +def test_multinomial_rejects_non_integer_or_boolean_sample_counts(n): + with pytest.raises(TypeError, match="non-negative integer"): + random.multinomial(n, [1.0]) + + +def test_multinomial_rejects_negative_sample_counts(): + with pytest.raises(ValueError, match="non-negative"): + random.multinomial(-1, [1.0]) + + +def test_multinomial_accepts_integer_like_scalar_sample_counts(): + random.seed(0) + + samples = random.multinomial(np.array(2, dtype=np.int64), [0.25, 0.75], size=3) + + assert samples.shape == (3, 2) + assert np.all(samples.sum(axis=1) == 2) + + @pytest.mark.parametrize( ("low", "high"), [