Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions src/pyrecest/_backend/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,16 @@ def _choice_bool(value, name):
raise TypeError(f"{name} must be a boolean")


def _maybe_preserve_choice_order(indices, *, replace, p, shuffle):
if replace or p is not None or shuffle:
return indices

index_array = _jnp.asarray(indices)
if index_array.ndim == 0:
return indices
return _jnp.sort(index_array.reshape(-1)).reshape(index_array.shape)


def _choice_population_size(a, kwargs):
population_size = _integer_population_size(a)
if population_size is not None:
Expand Down Expand Up @@ -398,11 +408,14 @@ def _validate_choice_probabilities(p, population_size):
return p / p_sum


def _choice(state, a, size=None, replace=True, p=None, *args, **kwargs):
def _choice(state, a, size=None, replace=True, p=None, shuffle=True, *args, **kwargs):
if args:
raise TypeError("choice() received unexpected positional arguments")
state, key = jax.random.split(state)
a = _jnp.asarray(a)
shape = _shape_from_size(size)
replace = _choice_bool(replace, "replace")
shuffle = _choice_bool(shuffle, "shuffle")
population_size = _choice_population_size(a, kwargs)
if population_size == 0:
if _shape_has_no_samples(shape):
Expand All @@ -412,26 +425,34 @@ def _choice(state, a, size=None, replace=True, p=None, *args, **kwargs):
raise ValueError("a must be a positive integer or a non-empty array")
if p is not None:
p = _validate_choice_probabilities(p, population_size)
res = jax.random.choice(
choice_kwargs = {name: value for name, value in kwargs.items() if name != "axis"}
indices = jax.random.choice(
key,
a,
*args,
population_size,
shape=shape,
replace=replace,
p=p,
**kwargs,
**choice_kwargs,
)
return state, res
indices = _maybe_preserve_choice_order(
indices,
replace=replace,
p=p,
shuffle=shuffle,
)
if a.ndim == 0:
return state, indices
return state, _jnp.take(a, indices, axis=kwargs.get("axis", 0))


def choice(a, size=None, replace=True, p=None, *args, **kwargs):
def choice(a, size=None, replace=True, p=None, shuffle=True, *args, **kwargs):
"""Draw samples using a NumPy-like ``choice`` contract."""
if "n" in kwargs:
if size is not None:
raise TypeError("Specify only one of 'size' or legacy 'n'.")
size = kwargs.pop("n")
state, has_state, kwargs = _get_state(**kwargs)
state, res = _choice(state, a, size, replace, p, *args, **kwargs)
state, res = _choice(state, a, size, replace, p, shuffle, *args, **kwargs)
return set_state_return(has_state, state, res)


Expand Down
30 changes: 30 additions & 0 deletions tests/backend/test_jax_random_choice_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import pytest

pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
from pyrecest._backend.jax import random # noqa: E402


def test_choice_without_replacement_shuffle_false_preserves_order():
values = jnp.array([10, 20, 30, 40, 50])
matrix = jnp.array([[10, 20, 30], [40, 50, 60]])

random.seed(0)
samples = random.choice(values, size=values.shape[0], replace=False, shuffle=False)
column_samples = random.choice(
matrix,
size=matrix.shape[1],
replace=False,
axis=1,
shuffle=False,
)

assert jnp.array_equal(samples, values)
assert jnp.array_equal(column_samples, matrix)


@pytest.mark.parametrize("bad_shuffle", ["False", "True", 1, 0, None, np.array(True)])
def test_choice_rejects_non_boolean_shuffle_flag(bad_shuffle):
with pytest.raises(TypeError, match="shuffle must be a boolean"):
random.choice(jnp.array([0, 1, 2]), size=2, shuffle=bad_shuffle)
Loading