diff --git a/src/pyrecest/_backend/autograd/random.py b/src/pyrecest/_backend/autograd/random.py index f1ada7402..261164cec 100644 --- a/src/pyrecest/_backend/autograd/random.py +++ b/src/pyrecest/_backend/autograd/random.py @@ -1,6 +1,22 @@ """Autograd based random backend.""" import autograd.numpy as _np -from autograd.numpy.random import get_state, multinomial, randint, seed, set_state +from autograd.numpy.random import get_state, randint, seed, set_state +from autograd.numpy.random import multinomial as _autograd_multinomial from .._shared_numpy.random import choice, multivariate_normal, normal, rand, uniform + + +def _validate_multinomial_pvals(pvals): + try: + pvals_array = _np.asarray(pvals) + except (TypeError, ValueError) as exc: + raise TypeError("pvals must be real numeric") from exc + if pvals_array.dtype.kind not in "iuf": + raise TypeError("pvals must be real numeric") + return pvals_array + + +def multinomial(n, pvals, size=None): + pvals_array = _validate_multinomial_pvals(pvals) + return _autograd_multinomial(n, pvals_array, size=size) diff --git a/tests/backend/test_autograd_random_backend.py b/tests/backend/test_autograd_random_backend.py new file mode 100644 index 000000000..79a399d2e --- /dev/null +++ b/tests/backend/test_autograd_random_backend.py @@ -0,0 +1,22 @@ +import numpy as np +import pytest + +pytest.importorskip("autograd") + +from pyrecest._backend.autograd import random # noqa: E402 + + +def test_multinomial_accepts_real_probability_values(): + draw = getattr(random, "multinomial") + sample = draw(4, [0.25, 0.75]) + + assert sample.shape == (2,) + assert int(np.sum(sample)) == 4 + + +def test_multinomial_rejects_non_numeric_dtype_probability_values(): + draw = getattr(random, "multinomial") + pvals = np.array([1, 0], dtype="?") + + with pytest.raises(TypeError, match="real numeric"): + draw(4, pvals)