diff --git a/src/pyrecest/_backend/pytorch/random.py b/src/pyrecest/_backend/pytorch/random.py index ca7d1c4cf..b04a5ab58 100644 --- a/src/pyrecest/_backend/pytorch/random.py +++ b/src/pyrecest/_backend/pytorch/random.py @@ -31,14 +31,40 @@ def _size_type_error(): return TypeError("size must be None, an integer, or a sequence of integers") +def _scalar_integer_dimension(value): + if isinstance(value, bool): + return None + if isinstance(value, _Integral): + return int(value) + if isinstance(value, _np.ndarray) and value.ndim == 0: + if _np.issubdtype(value.dtype, _np.integer) and not _np.issubdtype( + value.dtype, _np.bool_ + ): + return int(value.item()) + return None + if ( + _torch.is_tensor(value) + and value.ndim == 0 + and value.dtype in _INTEGER_DTYPES + ): + return int(value.item()) + return None + + def _looks_like_integer_dimension(value): - return isinstance(value, _Integral) and not isinstance(value, bool) + return _scalar_integer_dimension(value) is not None + + +def _is_zero_dimensional_array_like(value): + return (isinstance(value, _np.ndarray) and value.ndim == 0) or ( + _torch.is_tensor(value) and value.ndim == 0 + ) def _integer_dimension(value): - if not _looks_like_integer_dimension(value): + value = _scalar_integer_dimension(value) + if value is None: raise _size_type_error() - value = int(value) if value < 0: raise ValueError("size dimensions must be non-negative") return value @@ -49,7 +75,11 @@ def _shape_from_size(size): return () if _looks_like_integer_dimension(size): return (_integer_dimension(size),) - if isinstance(size, (str, bytes)) or not hasattr(size, "__iter__"): + if ( + isinstance(size, (str, bytes)) + or _is_zero_dimensional_array_like(size) + or not hasattr(size, "__iter__") + ): raise _size_type_error() return tuple(_integer_dimension(dim) for dim in size) @@ -82,9 +112,12 @@ def _validate_choice_probabilities(p, population_size, device): if _contains_boolean_value(p): raise TypeError("p must be real numeric, not boolean") try: - p = _torch.as_tensor(p, dtype=_torch.float32, device=device) + p = _torch.as_tensor(p, device=device) except (TypeError, ValueError, RuntimeError) as exc: raise TypeError("p must be real numeric") from exc + if not _is_real_numeric_dtype(p.dtype): + raise TypeError("p must be real numeric") + p = p.to(dtype=_torch.float32) if p.ndim != 1 or p.shape[0] != population_size: raise ValueError("p must be 1-dimensional with one entry per population item") diff --git a/tests/backend/test_pytorch_random_scalar_inputs.py b/tests/backend/test_pytorch_random_scalar_inputs.py new file mode 100644 index 000000000..391d16c4b --- /dev/null +++ b/tests/backend/test_pytorch_random_scalar_inputs.py @@ -0,0 +1,58 @@ +import numpy as np +import pytest + +torch = pytest.importorskip("torch") + +from pyrecest._backend.pytorch import random # noqa: E402 + + +def _size_aware_samplers(): + return ( + lambda size: random.rand(size=size), + lambda size: random.uniform(size=size), + lambda size: random.normal(size=size), + lambda size: random.randint(0, 3, size=size), + lambda size: random.choice(3, size=size), + lambda size: random.multivariate_normal([0.0], [[1.0]], size=size), + lambda size: random.multinomial(3, [0.25, 0.75], size=size), + ) + + +@pytest.mark.parametrize( + "scalar_size", + [np.array(3, dtype=np.int64), torch.tensor(3, dtype=torch.int64)], +) +def test_size_arguments_accept_zero_dimensional_integer_arrays_and_tensors(scalar_size): + random.seed(0) + + for sampler in _size_aware_samplers(): + sample = sampler(scalar_size) + + assert sample.shape[0] == 3 + + +@pytest.mark.parametrize( + "bad_size", + [ + np.array(True), + torch.tensor(True), + np.array(3.0), + torch.tensor(3.0), + ], +) +def test_size_arguments_reject_zero_dimensional_non_integer_arrays_and_tensors(bad_size): + for sampler in _size_aware_samplers(): + with pytest.raises(TypeError, match="size must"): + sampler(bad_size) + + +@pytest.mark.parametrize( + "probabilities", + [ + np.array([0.5 + 0.0j, 0.5 + 0.0j]), + torch.tensor([0.5 + 0.0j, 0.5 + 0.0j]), + ], +) +def test_choice_rejects_complex_probabilities(probabilities): + with pytest.raises(TypeError, match="real numeric"): + random.choice(2, size=1, p=probabilities)