Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/pyrecest/_backend/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,18 @@ def multivariate_normal(mean, cov, size=None, *args, **kwargs):
return set_state_return(has_state, state, res)


def _validate_multinomial_pvals(pvals):
if _contains_boolean_value(pvals):
raise TypeError("pvals must be real numeric, not boolean")
try:
pvals = _jnp.asarray(pvals)
except (TypeError, ValueError, RuntimeError) as exc:
raise TypeError("pvals must be real numeric") from exc
if pvals.dtype.kind not in "iuf":
raise TypeError("pvals must be real numeric")
return pvals.astype(_jnp.float32)


def _multinomial(state, n, pvals, size=None):
if not _looks_like_integer_dimension(n):
raise TypeError("n must be a non-negative integer")
Expand All @@ -463,7 +475,7 @@ def _multinomial(state, n, pvals, size=None):

state, key = jax.random.split(state)
sample_shape = _shape_from_size(size)
pvals = _jnp.asarray(pvals, dtype=_jnp.float32)
pvals = _validate_multinomial_pvals(pvals)
if pvals.ndim != 1:
raise ValueError("pvals must be 1-dimensional")
if pvals.shape[0] == 0:
Expand Down
14 changes: 13 additions & 1 deletion src/pyrecest/_backend/pytorch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,18 @@ def _multinomial_sample_count(sample_shape):
return _prod(sample_shape) if sample_shape else 1


def _validate_multinomial_pvals(pvals, device):
if _contains_boolean_value(pvals):
raise TypeError("pvals must be real numeric, not boolean")
try:
pvals = _torch.as_tensor(pvals, device=device)
except (TypeError, ValueError, RuntimeError) as exc:
raise TypeError("pvals must be real numeric") from exc
if not _is_real_numeric_dtype(pvals.dtype):
raise TypeError("pvals must be real numeric")
return pvals.to(dtype=_torch.float32)


def multinomial(n, pvals, size=None):
if not _looks_like_integer_dimension(n):
raise TypeError("n must be a non-negative integer")
Expand All @@ -372,7 +384,7 @@ def multinomial(n, pvals, size=None):

sample_shape = _shape_from_size(size)
device = pvals.device if _torch.is_tensor(pvals) else None
pvals = _torch.as_tensor(pvals, dtype=_torch.float32, device=device)
pvals = _validate_multinomial_pvals(pvals, device)
if pvals.ndim != 1:
raise ValueError("pvals must be 1-dimensional")
if pvals.numel() == 0:
Expand Down
42 changes: 42 additions & 0 deletions tests/backend/test_jax_multinomial_real_pvals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
import pytest

jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402

from pyrecest._backend.jax import random # noqa: E402


@pytest.mark.parametrize(
"pvals",
[
[True, False],
np.array([True, False]),
jnp.array([True, False]),
],
)
def test_multinomial_rejects_boolean_pvals(pvals):
with pytest.raises(TypeError, match="real numeric"):
random.multinomial(1, pvals)


@pytest.mark.parametrize(
"pvals",
[
[1.0 + 0.0j],
np.array([1.0 + 0.0j]),
jnp.array([1.0 + 0.0j]),
],
)
def test_multinomial_rejects_complex_pvals(pvals):
with pytest.raises(TypeError, match="real numeric"):
random.multinomial(1, pvals)


def test_multinomial_accepts_real_pvals():
random.seed(0)

result = random.multinomial(2, [1.0, 0.0])

assert result.shape == (2,)
assert int(result.sum()) == 2
41 changes: 41 additions & 0 deletions tests/backend/test_pytorch_multinomial_real_pvals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import pytest

torch = pytest.importorskip("torch")

from pyrecest._backend.pytorch import random # noqa: E402


@pytest.mark.parametrize(
"pvals",
[
[True, False],
np.array([True, False]),
torch.tensor([True, False]),
],
)
def test_multinomial_rejects_boolean_pvals(pvals):
with pytest.raises(TypeError, match="real numeric"):
random.multinomial(1, pvals)


@pytest.mark.parametrize(
"pvals",
[
[1.0 + 0.0j],
np.array([1.0 + 0.0j]),
torch.tensor([1.0 + 0.0j]),
],
)
def test_multinomial_rejects_complex_pvals(pvals):
with pytest.raises(TypeError, match="real numeric"):
random.multinomial(1, pvals)


def test_multinomial_accepts_real_pvals():
random.seed(0)

result = random.multinomial(2, [1.0, 0.0])

assert result.shape == (2,)
assert int(result.sum()) == 2
Loading