Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion scripts/check_benchmark_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def _flatten_numbers(value: Any, *, path: str) -> list[float]:
raise TypeError(f"{path} must be a number or nested sequence of numbers")


def _finite_nonnegative_number(value: Any, *, path: str) -> tuple[float | None, str | None]:
def _finite_nonnegative_number(
value: Any, *, path: str
) -> tuple[float | None, str | None]:
if isinstance(value, bool):
return None, f"{path} must be numeric, not boolean"
if not isinstance(value, int | float):
Expand Down
7 changes: 6 additions & 1 deletion scripts/check_public_api_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def _load_backend_capabilities() -> dict[str, dict[str, str]]:


def _markdown_table_cell(value: object) -> str:
return str(value).replace("\r", " ").replace("\n", "<br>").replace(chr(124), chr(0xFF5C))
return (
str(value)
.replace("\r", " ")
.replace("\n", "<br>")
.replace(chr(124), chr(0xFF5C))
)


def validate_registry() -> list[str]:
Expand Down
1 change: 0 additions & 1 deletion src/pyrecest/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@

from pyrecest.cli import main


if __name__ == "__main__": # pragma: no cover
sys.exit(main())
4 changes: 4 additions & 0 deletions src/pyrecest/_backend/numpy/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from pyrecest._backend._dtype_utils import (
_dyn_update_dtype,
_modify_func_default_dtype,
)
from pyrecest._backend._dtype_utils import (
get_default_cdtype as _shared_get_default_cdtype,
)
from pyrecest._backend._dtype_utils import (
get_default_dtype as _shared_get_default_dtype,
)

Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/_backend/pytorch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
inv,
)
from torch.linalg import matrix_exp as expm
from torch.linalg import matrix_power
from torch.linalg import (
matrix_power,
)

from .._backend_config import np_atol as atol
from ..numpy import linalg as _gsnplinalg
Expand Down
10 changes: 4 additions & 6 deletions src/pyrecest/_backend/pytorch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ def _validate_choice_probabilities(p, population_size, device):
raise ValueError("p must be 1-dimensional with one entry per population item")

p_sum = p.sum()
if (
bool(_torch.any(p < 0))
or not bool(_torch.isfinite(p_sum))
or bool(p_sum <= 0)
):
if bool(_torch.any(p < 0)) or not bool(_torch.isfinite(p_sum)) or bool(p_sum <= 0):
raise ValueError("probabilities do not sum to a positive value")
return p / p_sum

Expand Down Expand Up @@ -209,7 +205,9 @@ def _normal_size(size):

def _broadcasted_parameter_shape(*parameters, message):
try:
return tuple(_torch.broadcast_shapes(*(parameter.shape for parameter in parameters)))
return tuple(
_torch.broadcast_shapes(*(parameter.shape for parameter in parameters))
)
except RuntimeError as exc:
raise ValueError(message) from exc

Expand Down
14 changes: 7 additions & 7 deletions src/pyrecest/calibration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
)
from .time_offset import (
TimeOffsetFitResult,
aggregate_time_offset_sweeps as _aggregate_time_offset_sweeps,
_aggregate_summary_metric,
_as_nonnegative_summary_count,
_as_summary_scalar,
_validate_error_metric,
)
from .time_offset import aggregate_time_offset_sweeps as _aggregate_time_offset_sweeps
from .time_offset import (
apply_time_offset,
fit_time_offset,
interpolate_reference_values,
Expand All @@ -24,12 +30,6 @@
time_offset_error_summary,
time_offset_sweep,
)
from .time_offset import (
_aggregate_summary_metric,
_as_nonnegative_summary_count,
_as_summary_scalar,
_validate_error_metric,
)


def aggregate_time_offset_sweeps(
Expand Down
1 change: 0 additions & 1 deletion src/pyrecest/calibration/time_offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy as np


_ERROR_METRIC_NAMES = frozenset({"max", "mean", "p95", "rmse", "std"})
_ERROR_METRIC_MESSAGE = "metric must be one of 'max', 'mean', 'p95', 'rmse', or 'std'"

Expand Down
3 changes: 1 addition & 2 deletions src/pyrecest/distributions/abstract_grid_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def __init__(
)
if dim is not None and actual_dim != dim:
raise ValueError(
f"Grid coordinates must have dimension {dim}, got "
f"{actual_dim}."
f"Grid coordinates must have dimension {dim}, got " f"{actual_dim}."
)
if grid is None or (grid.ndim > 1 and grid.shape[0] < grid.shape[1]):
warnings.warn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
# pylint: disable=no-name-in-module,no-member,redefined-builtin
from pyrecest.backend import empty, int32, int64, log, random, squeeze

_SCALAR_VALUE_ERROR = "Metropolis-Hastings scalar evaluations must return scalar values."
_SCALAR_VALUE_ERROR = (
"Metropolis-Hastings scalar evaluations must return scalar values."
)


def _shape_size(value) -> int | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def _pdf_via_sinc(self, xs, sinc_repetitions):
lower = int(floor(sinc_repetitions / 2) * grid_size)
upper = int(ceil(sinc_repetitions / 2) * grid_size)
repetitions = arange(-lower, upper)
sinc_vals = self._matlab_sinc((xs_eval / step_size)[:, None] - repetitions[None, :])
sinc_vals = self._matlab_sinc(
(xs_eval / step_size)[:, None] - repetitions[None, :]
)
grid_values = (
sqrt(self.grid_values) if self.enforce_pdf_nonnegative else self.grid_values
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,7 @@ def entropy(self):
# distribution on [0, 2*pi). The direct expression evaluates
# log1p(-exp(-log_beta)) and divides by 1 - exp(-log_beta), which
# suffers catastrophic cancellation for tiny log_beta.
return (
log(2.0 * pi)
- log_beta**2 / 24.0
+ log_beta**4 / 960.0
)
return log(2.0 * pi) - log_beta**2 / 24.0 + log_beta**4 / 960.0

# Use exp(-2*pi*lambda) to avoid overflowing exp(2*pi*lambda) for
# concentrated wrapped exponentials.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
AbstractBoundedNonPeriodicDistribution,
)


_ERROR_SCALAR_PDF_VALUE = (
"pdf must return one finite scalar value per integration point"
)
Expand Down
8 changes: 4 additions & 4 deletions src/pyrecest/evaluation/generate_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def generate_measurements(groundtruth, simulation_config):
y = _as_shapely_scalar(curr_groundtruth[..., 1], "groundtruth y")
curr_shape = translate(shape, xoff=x, yoff=y)
elif curr_groundtruth.shape[-1] == 3:
angle = _as_shapely_scalar(curr_groundtruth[..., 0], "groundtruth angle")
angle = _as_shapely_scalar(
curr_groundtruth[..., 0], "groundtruth angle"
)
x = _as_shapely_scalar(curr_groundtruth[..., 1], "groundtruth x")
y = _as_shapely_scalar(curr_groundtruth[..., 2], "groundtruth y")
curr_shape = rotate(
Expand All @@ -160,9 +162,7 @@ def generate_measurements(groundtruth, simulation_config):
"Currently only R^2 and SE(2) scenarios are supported."
)
if not isinstance(curr_shape, PolygonWithSampling):
curr_shape.__class__ = (
PolygonWithSampling # Preserve existing subclass swap to add sampling methods
)
curr_shape.__class__ = PolygonWithSampling # Preserve existing subclass swap to add sampling methods

if "n_meas_at_individual_time_step" in simulation_config:
if "intensity_lambda" in simulation_config:
Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/evaluation/pareto.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ def _coerce_numeric(value: Any) -> float:
if value_array.shape != () or value_array.dtype.kind in "bSUcMm":
return float("nan")
scalar = value_array.item()
if isinstance(scalar, (bool, np.bool_, complex, np.complexfloating)) or _is_text_scalar(scalar):
if isinstance(
scalar, (bool, np.bool_, complex, np.complexfloating)
) or _is_text_scalar(scalar):
return float("nan")
try:
return float(scalar)
Expand Down
5 changes: 4 additions & 1 deletion src/pyrecest/evidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ class EvidenceComputationMode:
metadata: dict[str, Any] | None = field(default_factory=dict)

def __post_init__(self) -> None:
if not isinstance(self.mode, str) or self.mode not in {"full_smoothing", "evidence_only"}:
if not isinstance(self.mode, str) or self.mode not in {
"full_smoothing",
"evidence_only",
}:
raise ValueError(f"unknown evidence computation mode {self.mode!r}")
return_smoothed = _coerce_bool_flag(self.return_smoothed, "return_smoothed")
terminal_posterior = _coerce_bool_flag(
Expand Down
32 changes: 24 additions & 8 deletions src/pyrecest/experimental/dvs/normal_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def _has_non_real_numeric_values(value: object, *, allow_bool: bool = False) ->
return array.dtype.kind in {"M", "m"}


def _as_finite_real_array(name: str, value: object, *, allow_bool: bool = False) -> np.ndarray:
def _as_finite_real_array(
name: str, value: object, *, allow_bool: bool = False
) -> np.ndarray:
message = f"{name} must contain finite real numeric values."
if _has_non_real_numeric_values(value, allow_bool=allow_bool):
raise ValueError(message)
Expand All @@ -70,7 +72,9 @@ def _as_finite_real_array(name: str, value: object, *, allow_bool: bool = False)
return array


def _as_finite_real_scalar(name: str, value: object, *, allow_bool: bool = False) -> float:
def _as_finite_real_scalar(
name: str, value: object, *, allow_bool: bool = False
) -> float:
message = f"{name} must be a finite real scalar."
try:
array = _as_finite_real_array(name, value, allow_bool=allow_bool)
Expand Down Expand Up @@ -172,8 +176,12 @@ def infer_polarity_contrast_sign(
return normalized

tolerance = _as_finite_nonnegative_scalar("zero_tolerance", zero_tolerance)
flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape((-1,))
polarities = _as_finite_real_array("event_polarities", event_polarities, allow_bool=True).reshape((-1,))
flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape(
(-1,)
)
polarities = _as_finite_real_array(
"event_polarities", event_polarities, allow_bool=True
).reshape((-1,))
if flows.shape != polarities.shape:
raise ValueError("event_polarities must have one value per signed normal flow")

Expand Down Expand Up @@ -234,7 +242,9 @@ def polarity_weight_for_signed_flow(
zero_tolerance=1e-12,
) -> float:
"""Return a multiplicative reliability weight from polarity consistency."""
mismatch_weight = _as_unit_interval_scalar("polarity_mismatch_weight", polarity_mismatch_weight)
mismatch_weight = _as_unit_interval_scalar(
"polarity_mismatch_weight", polarity_mismatch_weight
)
consistency = polarity_consistency_for_signed_flow(
signed_normal_flow_value,
event_polarity,
Expand All @@ -257,9 +267,15 @@ def polarity_weights_for_signed_flows(
zero_tolerance=1e-12,
) -> np.ndarray:
"""Return polarity reliability weights for a batch of signed-flow samples."""
mismatch_weight = _as_unit_interval_scalar("polarity_mismatch_weight", polarity_mismatch_weight)
flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape((-1,))
polarities = _as_finite_real_array("event_polarities", event_polarities, allow_bool=True).reshape((-1,))
mismatch_weight = _as_unit_interval_scalar(
"polarity_mismatch_weight", polarity_mismatch_weight
)
flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape(
(-1,)
)
polarities = _as_finite_real_array(
"event_polarities", event_polarities, allow_bool=True
).reshape((-1,))
resolved_sign = infer_polarity_contrast_sign(
flows,
polarities,
Expand Down
11 changes: 8 additions & 3 deletions src/pyrecest/filters/measurement_reliability.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ def _is_real_numeric_object_value(value) -> bool:
if kind is not None:
return kind in {"i", "u", "f"}
dtype_name = str(dtype).lower()
if any(token in dtype_name for token in ("bool", "complex", "str", "string", "object")):
if any(
token in dtype_name
for token in ("bool", "complex", "str", "string", "object")
):
return False
if "float" in dtype_name or "int" in dtype_name:
return True
Expand Down Expand Up @@ -81,7 +84,9 @@ def _has_real_numeric_dtype(value) -> bool:
if kind is not None:
return kind in {"i", "u", "f"}
dtype_name = str(dtype).lower()
if any(token in dtype_name for token in ("bool", "complex", "str", "string", "object")):
if any(
token in dtype_name for token in ("bool", "complex", "str", "string", "object")
):
return False
return "float" in dtype_name or "int" in dtype_name

Expand Down Expand Up @@ -236,4 +241,4 @@ def normalize_measurement_noise_covariances(
shared_noise = as_covariance_matrix(noise, measurement_dim, name)
if n_measurements == 0:
return zeros(empty_shape)
return stack([shared_noise for _ in range(n_measurements)])
return stack([shared_noise for _ in range(n_measurements)])
1 change: 0 additions & 1 deletion src/pyrecest/filters/online_time_offset_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import numpy as np


_UNSUPPORTED_NUMERIC_KINDS = {"b", "S", "U", "c", "M", "m"}
_UNSUPPORTED_SCALAR_TYPES = (
type(None),
Expand Down
1 change: 0 additions & 1 deletion src/pyrecest/filters/wrapped_normal_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .abstract_filter import AbstractFilter
from .manifold_mixins import CircularFilterMixin


_PROGRESSIVE_TAU_MESSAGE = "tau must be a positive finite scalar"


Expand Down
2 changes: 1 addition & 1 deletion src/pyrecest/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
and capability-oriented so filters can opt into the pieces they need.
"""

from ._sampleable_transition_validation import install_sampleable_transition_validation
from ._validated_motion_models import nearly_coordinated_turn_model
from .adapters import (
LinearMeasurementArguments,
Expand Down Expand Up @@ -36,7 +37,6 @@
SupportsTransitionDensity,
SupportsTransitionSampling,
)
from ._sampleable_transition_validation import install_sampleable_transition_validation

install_sampleable_transition_validation()

Expand Down
5 changes: 4 additions & 1 deletion src/pyrecest/models/motion_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def _is_text_bool_or_complex(value: Any) -> bool:

def _as_scalar_float(value: Any, name: str) -> float:
value_array = np.asarray(value)
if value_array.shape != () or value_array.dtype.kind in _REJECTED_NUMERIC_ARRAY_KINDS:
if (
value_array.shape != ()
or value_array.dtype.kind in _REJECTED_NUMERIC_ARRAY_KINDS
):
raise ValueError(f"{name} must be a scalar number")
scalar_value = value_array.item()
if _is_text_bool_or_complex(scalar_value):
Expand Down
5 changes: 4 additions & 1 deletion src/pyrecest/models/sensor_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def _state_vector(state):

def _as_scalar_float(value: Any, name: str) -> float:
value_array = np.asarray(value)
if value_array.shape != () or value_array.dtype.kind in _REJECTED_SCALAR_ARRAY_KINDS:
if (
value_array.shape != ()
or value_array.dtype.kind in _REJECTED_SCALAR_ARRAY_KINDS
):
raise ValueError(f"{name} must be a scalar number")
scalar_value = value_array.item()
if isinstance(scalar_value, _TEXT_OR_BOOL_SCALAR_TYPES) or isinstance(
Expand Down
4 changes: 1 addition & 3 deletions src/pyrecest/models/weak_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,7 @@ def _contains_non_real_numeric_values(value: Any) -> bool:
return True
if array.dtype.kind != "O":
return False
return any(
isinstance(item, _NON_REAL_NUMERIC_SCALAR_TYPES) for item in array.flat
)
return any(isinstance(item, _NON_REAL_NUMERIC_SCALAR_TYPES) for item in array.flat)


def _positive_int(value: int, name: str) -> int:
Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/tracking/innovation_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ def _as_finite_real_array(value: Any, name: str) -> np.ndarray:

if raw_values.dtype == np.bool_ or raw_values.dtype.kind in "USbcMm":
raise ValueError(message)
if _contains_values_of_type(value, _INVALID_REAL_NUMERIC_TYPES) or _contains_values_of_type(raw_values, _INVALID_REAL_NUMERIC_TYPES):
if _contains_values_of_type(
value, _INVALID_REAL_NUMERIC_TYPES
) or _contains_values_of_type(raw_values, _INVALID_REAL_NUMERIC_TYPES):
raise ValueError(message)

try:
Expand Down
Loading