From 520c70cebb6602d564ebb93dcc1311f56f7c9478 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 22:24:06 +0200 Subject: [PATCH 1/2] Validate measurement reliability counts --- .../filters/measurement_reliability.py | 67 ++++++++++++++++--- 1 file changed, 57 insertions(+), 10 deletions(-) diff --git a/src/pyrecest/filters/measurement_reliability.py b/src/pyrecest/filters/measurement_reliability.py index 67bd9b335..663b0ef52 100644 --- a/src/pyrecest/filters/measurement_reliability.py +++ b/src/pyrecest/filters/measurement_reliability.py @@ -7,6 +7,8 @@ from dataclasses import dataclass from typing import Any +import numpy as np + from pyrecest.backend import all, array, isfinite, ones, reshape, stack, zeros @@ -26,6 +28,51 @@ class MeasurementReliabilitySelection: active_measurement_indices: list[int] +def _normalize_integer_count(value: Any, name: str, *, minimum: int, message: str) -> int: + try: + value_array = np.asarray(value) + except (TypeError, ValueError) as exc: + raise ValueError(message) from exc + if value_array.shape != () or value_array.dtype == np.bool_: + raise ValueError(message) + + scalar = value_array.item() + if isinstance(scalar, (bool, np.bool_)): + raise ValueError(message) + if isinstance(scalar, (int, np.integer)): + parsed = int(scalar) + elif ( + isinstance(scalar, (float, np.floating)) + and np.isfinite(scalar) + and float(scalar).is_integer() + ): + parsed = int(scalar) + else: + raise ValueError(message) + + if parsed < minimum: + raise ValueError(message) + return parsed + + +def _normalize_nonnegative_integer(value: Any, name: str) -> int: + return _normalize_integer_count( + value, + name, + minimum=0, + message=f"{name} must be a non-negative integer", + ) + + +def _normalize_positive_integer(value: Any, name: str) -> int: + return _normalize_integer_count( + value, + name, + minimum=1, + message=f"{name} must be a positive integer", + ) + + def _has_boolean_dtype(value) -> bool: dtype = getattr(value, "dtype", None) return dtype is not None and "bool" in str(dtype).lower() @@ -48,7 +95,10 @@ def _is_real_numeric_object_value(value) -> bool: if kind is not None: return kind in {"i", "u", "f"} dtype_name = str(dtype).lower() - if any(token in dtype_name for token in ("bool", "complex", "str", "string", "object")): + if any( + token in dtype_name + for token in ("bool", "complex", "str", "string", "object") + ): return False if "float" in dtype_name or "int" in dtype_name: return True @@ -106,8 +156,7 @@ def normalize_measurement_weights(measurement_weights, n_measurements: int): that the corresponding measurement should be skipped. """ - if n_measurements < 0: - raise ValueError("n_measurements must be non-negative") + n_measurements = _normalize_nonnegative_integer(n_measurements, "n_measurements") if measurement_weights is None: return ones(n_measurements) @@ -147,8 +196,7 @@ def normalize_active_measurement_mask( ) -> list[bool]: """Return one boolean active/inactive flag per measurement.""" - if n_measurements < 0: - raise ValueError("n_measurements must be non-negative") + n_measurements = _normalize_nonnegative_integer(n_measurements, "n_measurements") if active_measurement_mask is None: return [True] * n_measurements @@ -172,6 +220,7 @@ def normalize_measurement_reliability( ) -> MeasurementReliabilitySelection: """Normalize measurement weights and active-mask inputs together.""" + n_measurements = _normalize_nonnegative_integer(n_measurements, "n_measurements") weights = normalize_measurement_weights(measurement_weights, n_measurements) mask = normalize_active_measurement_mask(active_measurement_mask, n_measurements) active_indices = [ @@ -202,10 +251,8 @@ def normalize_measurement_noise_covariances( tracker classes reuse their own covariance validation conventions. """ - if n_measurements < 0: - raise ValueError("n_measurements must be non-negative") - if measurement_dim <= 0: - raise ValueError("measurement_dim must be positive") + n_measurements = _normalize_nonnegative_integer(n_measurements, "n_measurements") + measurement_dim = _normalize_positive_integer(measurement_dim, "measurement_dim") if n_measurements == 0: return zeros((0, measurement_dim, measurement_dim)) @@ -236,4 +283,4 @@ def normalize_measurement_noise_covariances( shared_noise = as_covariance_matrix(noise, measurement_dim, name) if n_measurements == 0: return zeros(empty_shape) - return stack([shared_noise for _ in range(n_measurements)]) \ No newline at end of file + return stack([shared_noise for _ in range(n_measurements)]) From fdf73d627fdd42311824f80231fb2cd5885d7c7f Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 27 Jun 2026 22:24:49 +0200 Subject: [PATCH 2/2] Test measurement reliability count validation --- ...easurement_reliability_count_validation.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/filters/test_measurement_reliability_count_validation.py diff --git a/tests/filters/test_measurement_reliability_count_validation.py b/tests/filters/test_measurement_reliability_count_validation.py new file mode 100644 index 000000000..8d1f88e6d --- /dev/null +++ b/tests/filters/test_measurement_reliability_count_validation.py @@ -0,0 +1,62 @@ +import unittest + +import pyrecest.backend +from pyrecest.backend import array, eye +from pyrecest.filters import ( + normalize_active_measurement_mask, + normalize_measurement_noise_covariances, + normalize_measurement_weights, +) + + +def _as_covariance_matrix(value, dim, name): + matrix = array(value) + if matrix.ndim == 0: + matrix = matrix * eye(dim) + if matrix.ndim == 1: + if matrix.shape[0] != dim: + raise ValueError(f"{name} vector must have length {dim}") + matrix = matrix * eye(dim) + if matrix.shape != (dim, dim): + raise ValueError(f"{name} must have shape ({dim}, {dim})") + return matrix + + +@unittest.skipIf( + pyrecest.backend.__backend_name__ != "numpy", + reason="count validation test uses NumPy backend array shape checks", +) +class TestMeasurementReliabilityCountValidation(unittest.TestCase): + def test_measurement_count_must_be_nonnegative_integer(self): + invalid_counts = (True, False, 1.5, "2", array([2]), -1) + + for n_measurements in invalid_counts: + with self.subTest(n_measurements=n_measurements): + with self.assertRaisesRegex(ValueError, "n_measurements"): + normalize_measurement_weights(None, n_measurements) + with self.assertRaisesRegex(ValueError, "n_measurements"): + normalize_active_measurement_mask(None, n_measurements) + with self.assertRaisesRegex(ValueError, "n_measurements"): + normalize_measurement_noise_covariances( + 0.5, + n_measurements, + 2, + as_covariance_matrix=_as_covariance_matrix, + ) + + def test_measurement_dim_must_be_positive_integer(self): + invalid_dims = (True, False, 0, -1, 1.5, "2", array([2])) + + for measurement_dim in invalid_dims: + with self.subTest(measurement_dim=measurement_dim): + with self.assertRaisesRegex(ValueError, "measurement_dim"): + normalize_measurement_noise_covariances( + 0.5, + 1, + measurement_dim, + as_covariance_matrix=_as_covariance_matrix, + ) + + +if __name__ == "__main__": + unittest.main()