Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions src/pyrecest/_backend/pytorch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,40 @@ def _size_type_error():
return TypeError("size must be None, an integer, or a sequence of integers")


def _scalar_integer_dimension(value):
if isinstance(value, bool):
return None
if isinstance(value, _Integral):
return int(value)
if isinstance(value, _np.ndarray) and value.ndim == 0:
if _np.issubdtype(value.dtype, _np.integer) and not _np.issubdtype(
value.dtype, _np.bool_
):
return int(value.item())
return None
if (
_torch.is_tensor(value)
and value.ndim == 0
and value.dtype in _INTEGER_DTYPES
):
return int(value.item())
return None


def _looks_like_integer_dimension(value):
return isinstance(value, _Integral) and not isinstance(value, bool)
return _scalar_integer_dimension(value) is not None


def _is_zero_dimensional_array_like(value):
return (isinstance(value, _np.ndarray) and value.ndim == 0) or (
_torch.is_tensor(value) and value.ndim == 0
)


def _integer_dimension(value):
if not _looks_like_integer_dimension(value):
value = _scalar_integer_dimension(value)
if value is None:
raise _size_type_error()
value = int(value)
if value < 0:
raise ValueError("size dimensions must be non-negative")
return value
Expand All @@ -49,7 +75,11 @@ def _shape_from_size(size):
return ()
if _looks_like_integer_dimension(size):
return (_integer_dimension(size),)
if isinstance(size, (str, bytes)) or not hasattr(size, "__iter__"):
if (
isinstance(size, (str, bytes))
or _is_zero_dimensional_array_like(size)
or not hasattr(size, "__iter__")
):
raise _size_type_error()
return tuple(_integer_dimension(dim) for dim in size)

Expand Down Expand Up @@ -82,9 +112,12 @@ def _validate_choice_probabilities(p, population_size, device):
if _contains_boolean_value(p):
raise TypeError("p must be real numeric, not boolean")
try:
p = _torch.as_tensor(p, dtype=_torch.float32, device=device)
p = _torch.as_tensor(p, device=device)
except (TypeError, ValueError, RuntimeError) as exc:
raise TypeError("p must be real numeric") from exc
if not _is_real_numeric_dtype(p.dtype):
raise TypeError("p must be real numeric")
p = p.to(dtype=_torch.float32)
if p.ndim != 1 or p.shape[0] != population_size:
raise ValueError("p must be 1-dimensional with one entry per population item")

Expand Down
58 changes: 58 additions & 0 deletions tests/backend/test_pytorch_random_scalar_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
import pytest

torch = pytest.importorskip("torch")

from pyrecest._backend.pytorch import random # noqa: E402


def _size_aware_samplers():
return (
lambda size: random.rand(size=size),
lambda size: random.uniform(size=size),
lambda size: random.normal(size=size),
lambda size: random.randint(0, 3, size=size),
lambda size: random.choice(3, size=size),
lambda size: random.multivariate_normal([0.0], [[1.0]], size=size),
lambda size: random.multinomial(3, [0.25, 0.75], size=size),
)


@pytest.mark.parametrize(
"scalar_size",
[np.array(3, dtype=np.int64), torch.tensor(3, dtype=torch.int64)],
)
def test_size_arguments_accept_zero_dimensional_integer_arrays_and_tensors(scalar_size):
random.seed(0)

for sampler in _size_aware_samplers():
sample = sampler(scalar_size)

assert sample.shape[0] == 3


@pytest.mark.parametrize(
"bad_size",
[
np.array(True),
torch.tensor(True),
np.array(3.0),
torch.tensor(3.0),
],
)
def test_size_arguments_reject_zero_dimensional_non_integer_arrays_and_tensors(bad_size):
for sampler in _size_aware_samplers():
with pytest.raises(TypeError, match="size must"):
sampler(bad_size)


@pytest.mark.parametrize(
"probabilities",
[
np.array([0.5 + 0.0j, 0.5 + 0.0j]),
torch.tensor([0.5 + 0.0j, 0.5 + 0.0j]),
],
)
def test_choice_rejects_complex_probabilities(probabilities):
with pytest.raises(TypeError, match="real numeric"):
random.choice(2, size=1, p=probabilities)
Loading