diff --git a/src/pyrecest/distributions/so3_conversion.py b/src/pyrecest/distributions/so3_conversion.py index b99dbfdf1..e89d9210b 100644 --- a/src/pyrecest/distributions/so3_conversion.py +++ b/src/pyrecest/distributions/so3_conversion.py @@ -2,8 +2,9 @@ # pylint: disable=no-name-in-module,no-member from math import isfinite -from numbers import Integral, Real +from numbers import Integral +import numpy as np from pyrecest.backend import array, diag, matmul, reshape, sum, transpose from .conversion import register_conversion @@ -36,11 +37,26 @@ def _validate_particle_count(n_particles): def _validate_covariance_regularization(covariance_regularization): - if isinstance(covariance_regularization, bool) or not isinstance( - covariance_regularization, Real + try: + regularization_array = np.asarray(covariance_regularization) + except (TypeError, ValueError) as exc: + raise ValueError(_COVARIANCE_REGULARIZATION_ERROR) from exc + + if regularization_array.shape != () or regularization_array.dtype.kind in "bSU": + raise ValueError(_COVARIANCE_REGULARIZATION_ERROR) + + regularization_scalar = regularization_array.item() + if isinstance( + regularization_scalar, + (bool, np.bool_, str, bytes, bytearray, np.str_, np.bytes_), ): raise ValueError(_COVARIANCE_REGULARIZATION_ERROR) - covariance_regularization = float(covariance_regularization) + + try: + covariance_regularization = float(regularization_scalar) + except (TypeError, ValueError, OverflowError) as exc: + raise ValueError(_COVARIANCE_REGULARIZATION_ERROR) from exc + if not isfinite(covariance_regularization) or covariance_regularization < 0.0: raise ValueError(_COVARIANCE_REGULARIZATION_ERROR) return covariance_regularization diff --git a/tests/distributions/test_so3_covariance_regularization_scalar_arrays.py b/tests/distributions/test_so3_covariance_regularization_scalar_arrays.py new file mode 100644 index 000000000..3400e5525 --- /dev/null +++ b/tests/distributions/test_so3_covariance_regularization_scalar_arrays.py @@ -0,0 +1,18 @@ +import numpy as np +import pytest + +from pyrecest.distributions.so3_conversion import _validate_covariance_regularization + + +@pytest.mark.parametrize("value", [np.array(0.0), np.array(1e-6), np.float64(0.25)]) +def test_covariance_regularization_accepts_numpy_scalar_values(value): + assert _validate_covariance_regularization(value) == pytest.approx(float(np.asarray(value))) + + +@pytest.mark.parametrize( + "value", + [np.array([0.0]), np.array(True), np.array(-1.0), np.array(np.inf), "0.1"], +) +def test_covariance_regularization_rejects_non_numeric_nonfinite_or_nonscalar_values(value): + with pytest.raises(ValueError, match="covariance_regularization"): + _validate_covariance_regularization(value)