Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions src/pyrecest/experimental/dvs/event_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@

import numpy as np

_TEXT_SCALAR_TYPES = (str, bytes, bytearray, np.str_, np.bytes_)
_BOOL_SCALAR_TYPES = (bool, np.bool_)


def _as_finite_scalar(value: float, message: str) -> float:
value_array = np.asarray(value)
try:
value_array = np.asarray(value)
except (TypeError, ValueError) as exc:
raise ValueError(message) from exc
if value_array.shape != () or value_array.dtype == np.bool_:
raise ValueError(message)
scalar = value_array.item()
if isinstance(scalar, (bool, np.bool_, str, bytes, bytearray)):
if isinstance(scalar, (*_BOOL_SCALAR_TYPES, *_TEXT_SCALAR_TYPES)):
raise ValueError(message)
try:
result = float(scalar)
Expand All @@ -39,6 +45,31 @@ def _validate_nonnegative_finite(value: float, name: str) -> float:
return value


def _validate_integer_greater_than(value: int, name: str, lower_bound: int) -> int:
message = f"{name} must be greater than {lower_bound}"
try:
value_array = np.asarray(value)
except (TypeError, ValueError) as exc:
raise ValueError(message) from exc
if value_array.shape != () or value_array.dtype == np.bool_:
raise ValueError(message)
scalar = value_array.item()
if isinstance(scalar, (*_BOOL_SCALAR_TYPES, *_TEXT_SCALAR_TYPES)):
raise ValueError(message)
if isinstance(scalar, (int, np.integer)):
parsed = int(scalar)
elif isinstance(scalar, (float, np.floating)):
scalar_float = float(scalar)
if not np.isfinite(scalar_float) or not scalar_float.is_integer():
raise ValueError(message)
parsed = int(scalar_float)
else:
raise ValueError(message)
if parsed <= int(lower_bound):
raise ValueError(message)
return parsed


def _validate_finite_array(values: np.ndarray, name: str) -> None:
if np.any(~np.isfinite(values)):
raise ValueError(f"{name} must contain only finite values")
Expand Down Expand Up @@ -116,8 +147,12 @@ class PointProcessUpdateConfig:
max_state_update_norm: float = 5.0

def __post_init__(self) -> None:
if self.contour_samples <= 2:
raise ValueError("contour_samples must be greater than 2")
contour_samples = _validate_integer_greater_than(
self.contour_samples,
"contour_samples",
2,
)
object.__setattr__(self, "contour_samples", contour_samples)
_validate_positive_finite(self.finite_difference_eps, "finite_difference_eps")
_validate_nonnegative_finite(self.map_step_size, "map_step_size")
if self.max_map_iterations < 0:
Expand Down Expand Up @@ -368,9 +403,12 @@ def _resolve_scgp_likelihood_arguments(
else:
likelihood_config = config or EventLikelihoodConfig()
sample_count = 96 if contour_samples is None else contour_samples
if int(sample_count) <= 2:
raise ValueError("contour_samples must be greater than 2")
return likelihood_config, int(sample_count)
sample_count = _validate_integer_greater_than(
sample_count,
"contour_samples",
2,
)
return likelihood_config, sample_count


def _gaussian_contour_kernel(
Expand Down
25 changes: 25 additions & 0 deletions tests/experimental/test_dvs_event_likelihood_counts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import unittest

import numpy as np

from pyrecest.experimental.dvs.event_likelihood import PointProcessUpdateConfig


class TestDVSPointProcessCountValidation(unittest.TestCase):
def test_rejects_fractional_contour_samples(self):
with self.assertRaisesRegex(ValueError, "contour_samples"):
PointProcessUpdateConfig(contour_samples=3.5)

def test_rejects_text_contour_samples(self):
with self.assertRaisesRegex(ValueError, "contour_samples"):
PointProcessUpdateConfig(contour_samples="5")

def test_normalizes_integer_like_contour_samples(self):
config = PointProcessUpdateConfig(contour_samples=np.array(5.0))

self.assertEqual(config.contour_samples, 5)
self.assertIsInstance(config.contour_samples, int)


if __name__ == "__main__":
unittest.main()
Loading