Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/pyrecest/distributions/so3_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

# pylint: disable=no-name-in-module,no-member
from math import isfinite
from numbers import Integral, Real
from numbers import Integral

import numpy as np
from pyrecest.backend import array, diag, matmul, reshape, sum, transpose

from .conversion import register_conversion
Expand Down Expand Up @@ -36,11 +37,26 @@ def _validate_particle_count(n_particles):


def _validate_covariance_regularization(covariance_regularization):
if isinstance(covariance_regularization, bool) or not isinstance(
covariance_regularization, Real
try:
regularization_array = np.asarray(covariance_regularization)
except (TypeError, ValueError) as exc:
raise ValueError(_COVARIANCE_REGULARIZATION_ERROR) from exc

if regularization_array.shape != () or regularization_array.dtype.kind in "bSU":
raise ValueError(_COVARIANCE_REGULARIZATION_ERROR)

regularization_scalar = regularization_array.item()
if isinstance(
regularization_scalar,
(bool, np.bool_, str, bytes, bytearray, np.str_, np.bytes_),
):
raise ValueError(_COVARIANCE_REGULARIZATION_ERROR)
covariance_regularization = float(covariance_regularization)

try:
covariance_regularization = float(regularization_scalar)
except (TypeError, ValueError, OverflowError) as exc:
raise ValueError(_COVARIANCE_REGULARIZATION_ERROR) from exc

if not isfinite(covariance_regularization) or covariance_regularization < 0.0:
raise ValueError(_COVARIANCE_REGULARIZATION_ERROR)
return covariance_regularization
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
import pytest

from pyrecest.distributions.so3_conversion import _validate_covariance_regularization


@pytest.mark.parametrize("value", [np.array(0.0), np.array(1e-6), np.float64(0.25)])
def test_covariance_regularization_accepts_numpy_scalar_values(value):
assert _validate_covariance_regularization(value) == pytest.approx(float(np.asarray(value)))


@pytest.mark.parametrize(
"value",
[np.array([0.0]), np.array(True), np.array(-1.0), np.array(np.inf), "0.1"],
)
def test_covariance_regularization_rejects_non_numeric_nonfinite_or_nonscalar_values(value):
with pytest.raises(ValueError, match="covariance_regularization"):
_validate_covariance_regularization(value)
Loading