Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 57 additions & 10 deletions src/pyrecest/filters/measurement_reliability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from dataclasses import dataclass
from typing import Any

import numpy as np

from pyrecest.backend import all, array, isfinite, ones, reshape, stack, zeros


Expand All @@ -26,6 +28,51 @@ class MeasurementReliabilitySelection:
active_measurement_indices: list[int]


def _normalize_integer_count(value: Any, name: str, *, minimum: int, message: str) -> int:
try:
value_array = np.asarray(value)
except (TypeError, ValueError) as exc:
raise ValueError(message) from exc
if value_array.shape != () or value_array.dtype == np.bool_:
raise ValueError(message)

scalar = value_array.item()
if isinstance(scalar, (bool, np.bool_)):
raise ValueError(message)
if isinstance(scalar, (int, np.integer)):
parsed = int(scalar)
elif (
isinstance(scalar, (float, np.floating))
and np.isfinite(scalar)
and float(scalar).is_integer()
):
parsed = int(scalar)
else:
raise ValueError(message)

if parsed < minimum:
raise ValueError(message)
return parsed


def _normalize_nonnegative_integer(value: Any, name: str) -> int:
return _normalize_integer_count(
value,
name,
minimum=0,
message=f"{name} must be a non-negative integer",
)


def _normalize_positive_integer(value: Any, name: str) -> int:
return _normalize_integer_count(
value,
name,
minimum=1,
message=f"{name} must be a positive integer",
)


def _has_boolean_dtype(value) -> bool:
dtype = getattr(value, "dtype", None)
return dtype is not None and "bool" in str(dtype).lower()
Expand All @@ -48,7 +95,10 @@ def _is_real_numeric_object_value(value) -> bool:
if kind is not None:
return kind in {"i", "u", "f"}
dtype_name = str(dtype).lower()
if any(token in dtype_name for token in ("bool", "complex", "str", "string", "object")):
if any(
token in dtype_name
for token in ("bool", "complex", "str", "string", "object")
):
return False
if "float" in dtype_name or "int" in dtype_name:
return True
Expand Down Expand Up @@ -106,8 +156,7 @@ def normalize_measurement_weights(measurement_weights, n_measurements: int):
that the corresponding measurement should be skipped.
"""

if n_measurements < 0:
raise ValueError("n_measurements must be non-negative")
n_measurements = _normalize_nonnegative_integer(n_measurements, "n_measurements")
if measurement_weights is None:
return ones(n_measurements)

Expand Down Expand Up @@ -147,8 +196,7 @@ def normalize_active_measurement_mask(
) -> list[bool]:
"""Return one boolean active/inactive flag per measurement."""

if n_measurements < 0:
raise ValueError("n_measurements must be non-negative")
n_measurements = _normalize_nonnegative_integer(n_measurements, "n_measurements")
if active_measurement_mask is None:
return [True] * n_measurements

Expand All @@ -172,6 +220,7 @@ def normalize_measurement_reliability(
) -> MeasurementReliabilitySelection:
"""Normalize measurement weights and active-mask inputs together."""

n_measurements = _normalize_nonnegative_integer(n_measurements, "n_measurements")
weights = normalize_measurement_weights(measurement_weights, n_measurements)
mask = normalize_active_measurement_mask(active_measurement_mask, n_measurements)
active_indices = [
Expand Down Expand Up @@ -202,10 +251,8 @@ def normalize_measurement_noise_covariances(
tracker classes reuse their own covariance validation conventions.
"""

if n_measurements < 0:
raise ValueError("n_measurements must be non-negative")
if measurement_dim <= 0:
raise ValueError("measurement_dim must be positive")
n_measurements = _normalize_nonnegative_integer(n_measurements, "n_measurements")
measurement_dim = _normalize_positive_integer(measurement_dim, "measurement_dim")

if n_measurements == 0:
return zeros((0, measurement_dim, measurement_dim))
Expand Down Expand Up @@ -236,4 +283,4 @@ def normalize_measurement_noise_covariances(
shared_noise = as_covariance_matrix(noise, measurement_dim, name)
if n_measurements == 0:
return zeros(empty_shape)
return stack([shared_noise for _ in range(n_measurements)])
return stack([shared_noise for _ in range(n_measurements)])
62 changes: 62 additions & 0 deletions tests/filters/test_measurement_reliability_count_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import unittest

import pyrecest.backend
from pyrecest.backend import array, eye
from pyrecest.filters import (
normalize_active_measurement_mask,
normalize_measurement_noise_covariances,
normalize_measurement_weights,
)


def _as_covariance_matrix(value, dim, name):
matrix = array(value)
if matrix.ndim == 0:
matrix = matrix * eye(dim)
if matrix.ndim == 1:
if matrix.shape[0] != dim:
raise ValueError(f"{name} vector must have length {dim}")
matrix = matrix * eye(dim)
if matrix.shape != (dim, dim):
raise ValueError(f"{name} must have shape ({dim}, {dim})")
return matrix


@unittest.skipIf(
pyrecest.backend.__backend_name__ != "numpy",
reason="count validation test uses NumPy backend array shape checks",
)
class TestMeasurementReliabilityCountValidation(unittest.TestCase):
def test_measurement_count_must_be_nonnegative_integer(self):
invalid_counts = (True, False, 1.5, "2", array([2]), -1)

for n_measurements in invalid_counts:
with self.subTest(n_measurements=n_measurements):
with self.assertRaisesRegex(ValueError, "n_measurements"):
normalize_measurement_weights(None, n_measurements)
with self.assertRaisesRegex(ValueError, "n_measurements"):
normalize_active_measurement_mask(None, n_measurements)
with self.assertRaisesRegex(ValueError, "n_measurements"):
normalize_measurement_noise_covariances(
0.5,
n_measurements,
2,
as_covariance_matrix=_as_covariance_matrix,
)

def test_measurement_dim_must_be_positive_integer(self):
invalid_dims = (True, False, 0, -1, 1.5, "2", array([2]))

for measurement_dim in invalid_dims:
with self.subTest(measurement_dim=measurement_dim):
with self.assertRaisesRegex(ValueError, "measurement_dim"):
normalize_measurement_noise_covariances(
0.5,
1,
measurement_dim,
as_covariance_matrix=_as_covariance_matrix,
)


if __name__ == "__main__":
unittest.main()
Loading