diff --git a/src/pyrecest/_backend/numpy/random.py b/src/pyrecest/_backend/numpy/random.py index b35084e6d..ea24531f6 100644 --- a/src/pyrecest/_backend/numpy/random.py +++ b/src/pyrecest/_backend/numpy/random.py @@ -73,7 +73,14 @@ def _validate_multinomial_pvals(pvals): return pvals_array +def _validate_multinomial_size(size): + if size is not None and _contains_boolean_value(size): + raise TypeError("size must be None, an integer, or a sequence of integers") + return size + + def multinomial(n, pvals, size=None): n = _validate_multinomial_sample_count(n) pvals_array = _validate_multinomial_pvals(pvals) + size = _validate_multinomial_size(size) return _np.random.multinomial(n, pvals_array, size=size) diff --git a/tests/backend/test_numpy_random_backend.py b/tests/backend/test_numpy_random_backend.py index 50da29082..c834684c3 100644 --- a/tests/backend/test_numpy_random_backend.py +++ b/tests/backend/test_numpy_random_backend.py @@ -101,6 +101,32 @@ def test_multinomial_accepts_integer_like_scalar_sample_counts(): assert np.all(samples.sum(axis=1) == 2) +@pytest.mark.parametrize( + "bad_size", + [ + True, + False, + np.bool_(True), + (True,), + [np.bool_(False), 2], + np.array(True), + np.array([True, 2], dtype=object), + ], +) +def test_multinomial_rejects_boolean_size_arguments(bad_size): + with pytest.raises(TypeError, match="size"): + random.multinomial(2, [0.25, 0.75], size=bad_size) + + +def test_multinomial_accepts_integer_like_size_argument(): + random.seed(0) + + samples = random.multinomial(2, [0.25, 0.75], size=np.array(3, dtype=np.int64)) + + assert samples.shape == (3, 2) + assert np.all(samples.sum(axis=1) == 2) + + @pytest.mark.parametrize( ("low", "high"), [