From 04e9768abbed146d904e79e12bede621fd1c24d2 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 23:00:18 +0200 Subject: [PATCH 1/2] Support shuffle control in JAX random choice --- src/pyrecest/_backend/jax/random.py | 37 ++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/src/pyrecest/_backend/jax/random.py b/src/pyrecest/_backend/jax/random.py index 66eee7dd3..a7453bfee 100644 --- a/src/pyrecest/_backend/jax/random.py +++ b/src/pyrecest/_backend/jax/random.py @@ -363,6 +363,16 @@ def _choice_bool(value, name): raise TypeError(f"{name} must be a boolean") +def _maybe_preserve_choice_order(indices, *, replace, p, shuffle): + if replace or p is not None or shuffle: + return indices + + index_array = _jnp.asarray(indices) + if index_array.ndim == 0: + return indices + return _jnp.sort(index_array.reshape(-1)).reshape(index_array.shape) + + def _choice_population_size(a, kwargs): population_size = _integer_population_size(a) if population_size is not None: @@ -398,11 +408,14 @@ def _validate_choice_probabilities(p, population_size): return p / p_sum -def _choice(state, a, size=None, replace=True, p=None, *args, **kwargs): +def _choice(state, a, size=None, replace=True, p=None, shuffle=True, *args, **kwargs): + if args: + raise TypeError("choice() received unexpected positional arguments") state, key = jax.random.split(state) a = _jnp.asarray(a) shape = _shape_from_size(size) replace = _choice_bool(replace, "replace") + shuffle = _choice_bool(shuffle, "shuffle") population_size = _choice_population_size(a, kwargs) if population_size == 0: if _shape_has_no_samples(shape): @@ -412,26 +425,34 @@ def _choice(state, a, size=None, replace=True, p=None, *args, **kwargs): raise ValueError("a must be a positive integer or a non-empty array") if p is not None: p = _validate_choice_probabilities(p, population_size) - res = jax.random.choice( + choice_kwargs = {name: value for name, value in kwargs.items() if name != "axis"} + indices = jax.random.choice( key, - a, - *args, + population_size, shape=shape, replace=replace, p=p, - **kwargs, + **choice_kwargs, ) - return state, res + indices = _maybe_preserve_choice_order( + indices, + replace=replace, + p=p, + shuffle=shuffle, + ) + if a.ndim == 0: + return state, indices + return state, _jnp.take(a, indices, axis=kwargs.get("axis", 0)) -def choice(a, size=None, replace=True, p=None, *args, **kwargs): +def choice(a, size=None, replace=True, p=None, shuffle=True, *args, **kwargs): """Draw samples using a NumPy-like ``choice`` contract.""" if "n" in kwargs: if size is not None: raise TypeError("Specify only one of 'size' or legacy 'n'.") size = kwargs.pop("n") state, has_state, kwargs = _get_state(**kwargs) - state, res = _choice(state, a, size, replace, p, *args, **kwargs) + state, res = _choice(state, a, size, replace, p, shuffle, *args, **kwargs) return set_state_return(has_state, state, res) From f005cacc055ef081e9d4719a62b7ae870469ed69 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 23:01:07 +0200 Subject: [PATCH 2/2] Add JAX choice shuffle regression tests --- .../backend/test_jax_random_choice_shuffle.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/backend/test_jax_random_choice_shuffle.py diff --git a/tests/backend/test_jax_random_choice_shuffle.py b/tests/backend/test_jax_random_choice_shuffle.py new file mode 100644 index 000000000..8944a615b --- /dev/null +++ b/tests/backend/test_jax_random_choice_shuffle.py @@ -0,0 +1,30 @@ +import numpy as np +import pytest + +pytest.importorskip("jax") +import jax.numpy as jnp # noqa: E402 +from pyrecest._backend.jax import random # noqa: E402 + + +def test_choice_without_replacement_shuffle_false_preserves_order(): + values = jnp.array([10, 20, 30, 40, 50]) + matrix = jnp.array([[10, 20, 30], [40, 50, 60]]) + + random.seed(0) + samples = random.choice(values, size=values.shape[0], replace=False, shuffle=False) + column_samples = random.choice( + matrix, + size=matrix.shape[1], + replace=False, + axis=1, + shuffle=False, + ) + + assert jnp.array_equal(samples, values) + assert jnp.array_equal(column_samples, matrix) + + +@pytest.mark.parametrize("bad_shuffle", ["False", "True", 1, 0, None, np.array(True)]) +def test_choice_rejects_non_boolean_shuffle_flag(bad_shuffle): + with pytest.raises(TypeError, match="shuffle must be a boolean"): + random.choice(jnp.array([0, 1, 2]), size=2, shuffle=bad_shuffle)