diff --git a/scripts/check_benchmark_regression.py b/scripts/check_benchmark_regression.py
index a9b2cbbbf..4fce50269 100644
--- a/scripts/check_benchmark_regression.py
+++ b/scripts/check_benchmark_regression.py
@@ -55,7 +55,9 @@ def _flatten_numbers(value: Any, *, path: str) -> list[float]:
raise TypeError(f"{path} must be a number or nested sequence of numbers")
-def _finite_nonnegative_number(value: Any, *, path: str) -> tuple[float | None, str | None]:
+def _finite_nonnegative_number(
+ value: Any, *, path: str
+) -> tuple[float | None, str | None]:
if isinstance(value, bool):
return None, f"{path} must be numeric, not boolean"
if not isinstance(value, int | float):
diff --git a/scripts/check_public_api_registry.py b/scripts/check_public_api_registry.py
index 83428adeb..fdec151f9 100644
--- a/scripts/check_public_api_registry.py
+++ b/scripts/check_public_api_registry.py
@@ -41,7 +41,12 @@ def _load_backend_capabilities() -> dict[str, dict[str, str]]:
def _markdown_table_cell(value: object) -> str:
- return str(value).replace("\r", " ").replace("\n", "
").replace(chr(124), chr(0xFF5C))
+ return (
+ str(value)
+ .replace("\r", " ")
+ .replace("\n", "
")
+ .replace(chr(124), chr(0xFF5C))
+ )
def validate_registry() -> list[str]:
diff --git a/src/pyrecest/__main__.py b/src/pyrecest/__main__.py
index e72674850..17bfa0a9a 100644
--- a/src/pyrecest/__main__.py
+++ b/src/pyrecest/__main__.py
@@ -6,6 +6,5 @@
from pyrecest.cli import main
-
if __name__ == "__main__": # pragma: no cover
sys.exit(main())
diff --git a/src/pyrecest/_backend/numpy/_common.py b/src/pyrecest/_backend/numpy/_common.py
index 9cad54096..6ad2a6cc7 100644
--- a/src/pyrecest/_backend/numpy/_common.py
+++ b/src/pyrecest/_backend/numpy/_common.py
@@ -2,7 +2,11 @@
from pyrecest._backend._dtype_utils import (
_dyn_update_dtype,
_modify_func_default_dtype,
+)
+from pyrecest._backend._dtype_utils import (
get_default_cdtype as _shared_get_default_cdtype,
+)
+from pyrecest._backend._dtype_utils import (
get_default_dtype as _shared_get_default_dtype,
)
diff --git a/src/pyrecest/_backend/pytorch/linalg.py b/src/pyrecest/_backend/pytorch/linalg.py
index f7204b86f..9f5d5b386 100644
--- a/src/pyrecest/_backend/pytorch/linalg.py
+++ b/src/pyrecest/_backend/pytorch/linalg.py
@@ -14,7 +14,9 @@
inv,
)
from torch.linalg import matrix_exp as expm
-from torch.linalg import matrix_power
+from torch.linalg import (
+ matrix_power,
+)
from .._backend_config import np_atol as atol
from ..numpy import linalg as _gsnplinalg
diff --git a/src/pyrecest/_backend/pytorch/random.py b/src/pyrecest/_backend/pytorch/random.py
index 487627d65..eee404977 100644
--- a/src/pyrecest/_backend/pytorch/random.py
+++ b/src/pyrecest/_backend/pytorch/random.py
@@ -89,11 +89,7 @@ def _validate_choice_probabilities(p, population_size, device):
raise ValueError("p must be 1-dimensional with one entry per population item")
p_sum = p.sum()
- if (
- bool(_torch.any(p < 0))
- or not bool(_torch.isfinite(p_sum))
- or bool(p_sum <= 0)
- ):
+ if bool(_torch.any(p < 0)) or not bool(_torch.isfinite(p_sum)) or bool(p_sum <= 0):
raise ValueError("probabilities do not sum to a positive value")
return p / p_sum
@@ -209,7 +205,9 @@ def _normal_size(size):
def _broadcasted_parameter_shape(*parameters, message):
try:
- return tuple(_torch.broadcast_shapes(*(parameter.shape for parameter in parameters)))
+ return tuple(
+ _torch.broadcast_shapes(*(parameter.shape for parameter in parameters))
+ )
except RuntimeError as exc:
raise ValueError(message) from exc
diff --git a/src/pyrecest/calibration/__init__.py b/src/pyrecest/calibration/__init__.py
index 11723e6d6..d756cc40c 100644
--- a/src/pyrecest/calibration/__init__.py
+++ b/src/pyrecest/calibration/__init__.py
@@ -15,7 +15,13 @@
)
from .time_offset import (
TimeOffsetFitResult,
- aggregate_time_offset_sweeps as _aggregate_time_offset_sweeps,
+ _aggregate_summary_metric,
+ _as_nonnegative_summary_count,
+ _as_summary_scalar,
+ _validate_error_metric,
+)
+from .time_offset import aggregate_time_offset_sweeps as _aggregate_time_offset_sweeps
+from .time_offset import (
apply_time_offset,
fit_time_offset,
interpolate_reference_values,
@@ -24,12 +30,6 @@
time_offset_error_summary,
time_offset_sweep,
)
-from .time_offset import (
- _aggregate_summary_metric,
- _as_nonnegative_summary_count,
- _as_summary_scalar,
- _validate_error_metric,
-)
def aggregate_time_offset_sweeps(
diff --git a/src/pyrecest/calibration/time_offset.py b/src/pyrecest/calibration/time_offset.py
index dc8896374..5c98273b7 100644
--- a/src/pyrecest/calibration/time_offset.py
+++ b/src/pyrecest/calibration/time_offset.py
@@ -8,7 +8,6 @@
import numpy as np
-
_ERROR_METRIC_NAMES = frozenset({"max", "mean", "p95", "rmse", "std"})
_ERROR_METRIC_MESSAGE = "metric must be one of 'max', 'mean', 'p95', 'rmse', or 'std'"
diff --git a/src/pyrecest/distributions/abstract_grid_distribution.py b/src/pyrecest/distributions/abstract_grid_distribution.py
index 4ffb1c7dc..4568c9f00 100644
--- a/src/pyrecest/distributions/abstract_grid_distribution.py
+++ b/src/pyrecest/distributions/abstract_grid_distribution.py
@@ -43,8 +43,7 @@ def __init__(
)
if dim is not None and actual_dim != dim:
raise ValueError(
- f"Grid coordinates must have dimension {dim}, got "
- f"{actual_dim}."
+ f"Grid coordinates must have dimension {dim}, got " f"{actual_dim}."
)
if grid is None or (grid.ndim > 1 and grid.shape[0] < grid.shape[1]):
warnings.warn(
diff --git a/src/pyrecest/distributions/abstract_manifold_specific_distribution.py b/src/pyrecest/distributions/abstract_manifold_specific_distribution.py
index 08e02cbc3..f01f827ae 100644
--- a/src/pyrecest/distributions/abstract_manifold_specific_distribution.py
+++ b/src/pyrecest/distributions/abstract_manifold_specific_distribution.py
@@ -10,7 +10,9 @@
# pylint: disable=no-name-in-module,no-member,redefined-builtin
from pyrecest.backend import empty, int32, int64, log, random, squeeze
-_SCALAR_VALUE_ERROR = "Metropolis-Hastings scalar evaluations must return scalar values."
+_SCALAR_VALUE_ERROR = (
+ "Metropolis-Hastings scalar evaluations must return scalar values."
+)
def _shape_size(value) -> int | None:
diff --git a/src/pyrecest/distributions/circle/circular_grid_distribution.py b/src/pyrecest/distributions/circle/circular_grid_distribution.py
index 8ea39d3e0..3da169e61 100644
--- a/src/pyrecest/distributions/circle/circular_grid_distribution.py
+++ b/src/pyrecest/distributions/circle/circular_grid_distribution.py
@@ -80,7 +80,9 @@ def _pdf_via_sinc(self, xs, sinc_repetitions):
lower = int(floor(sinc_repetitions / 2) * grid_size)
upper = int(ceil(sinc_repetitions / 2) * grid_size)
repetitions = arange(-lower, upper)
- sinc_vals = self._matlab_sinc((xs_eval / step_size)[:, None] - repetitions[None, :])
+ sinc_vals = self._matlab_sinc(
+ (xs_eval / step_size)[:, None] - repetitions[None, :]
+ )
grid_values = (
sqrt(self.grid_values) if self.enforce_pdf_nonnegative else self.grid_values
)
diff --git a/src/pyrecest/distributions/circle/wrapped_exponential_distribution.py b/src/pyrecest/distributions/circle/wrapped_exponential_distribution.py
index 607881327..58e636267 100644
--- a/src/pyrecest/distributions/circle/wrapped_exponential_distribution.py
+++ b/src/pyrecest/distributions/circle/wrapped_exponential_distribution.py
@@ -83,11 +83,7 @@ def entropy(self):
# distribution on [0, 2*pi). The direct expression evaluates
# log1p(-exp(-log_beta)) and divides by 1 - exp(-log_beta), which
# suffers catastrophic cancellation for tiny log_beta.
- return (
- log(2.0 * pi)
- - log_beta**2 / 24.0
- + log_beta**4 / 960.0
- )
+ return log(2.0 * pi) - log_beta**2 / 24.0 + log_beta**4 / 960.0
# Use exp(-2*pi*lambda) to avoid overflowing exp(2*pi*lambda) for
# concentrated wrapped exponentials.
diff --git a/src/pyrecest/distributions/nonperiodic/abstract_hyperrectangular_distribution.py b/src/pyrecest/distributions/nonperiodic/abstract_hyperrectangular_distribution.py
index cac6e01c6..612f539cd 100644
--- a/src/pyrecest/distributions/nonperiodic/abstract_hyperrectangular_distribution.py
+++ b/src/pyrecest/distributions/nonperiodic/abstract_hyperrectangular_distribution.py
@@ -9,7 +9,6 @@
AbstractBoundedNonPeriodicDistribution,
)
-
_ERROR_SCALAR_PDF_VALUE = (
"pdf must return one finite scalar value per integration point"
)
diff --git a/src/pyrecest/evaluation/generate_measurements.py b/src/pyrecest/evaluation/generate_measurements.py
index 5ca2a13bd..09cca80de 100644
--- a/src/pyrecest/evaluation/generate_measurements.py
+++ b/src/pyrecest/evaluation/generate_measurements.py
@@ -146,7 +146,9 @@ def generate_measurements(groundtruth, simulation_config):
y = _as_shapely_scalar(curr_groundtruth[..., 1], "groundtruth y")
curr_shape = translate(shape, xoff=x, yoff=y)
elif curr_groundtruth.shape[-1] == 3:
- angle = _as_shapely_scalar(curr_groundtruth[..., 0], "groundtruth angle")
+ angle = _as_shapely_scalar(
+ curr_groundtruth[..., 0], "groundtruth angle"
+ )
x = _as_shapely_scalar(curr_groundtruth[..., 1], "groundtruth x")
y = _as_shapely_scalar(curr_groundtruth[..., 2], "groundtruth y")
curr_shape = rotate(
@@ -160,9 +162,7 @@ def generate_measurements(groundtruth, simulation_config):
"Currently only R^2 and SE(2) scenarios are supported."
)
if not isinstance(curr_shape, PolygonWithSampling):
- curr_shape.__class__ = (
- PolygonWithSampling # Preserve existing subclass swap to add sampling methods
- )
+ curr_shape.__class__ = PolygonWithSampling # Preserve existing subclass swap to add sampling methods
if "n_meas_at_individual_time_step" in simulation_config:
if "intensity_lambda" in simulation_config:
diff --git a/src/pyrecest/evaluation/pareto.py b/src/pyrecest/evaluation/pareto.py
index 680de95e3..6d793c7a2 100644
--- a/src/pyrecest/evaluation/pareto.py
+++ b/src/pyrecest/evaluation/pareto.py
@@ -361,7 +361,9 @@ def _coerce_numeric(value: Any) -> float:
if value_array.shape != () or value_array.dtype.kind in "bSUcMm":
return float("nan")
scalar = value_array.item()
- if isinstance(scalar, (bool, np.bool_, complex, np.complexfloating)) or _is_text_scalar(scalar):
+ if isinstance(
+ scalar, (bool, np.bool_, complex, np.complexfloating)
+ ) or _is_text_scalar(scalar):
return float("nan")
try:
return float(scalar)
diff --git a/src/pyrecest/evidence.py b/src/pyrecest/evidence.py
index 46b00a575..2435b5eb5 100644
--- a/src/pyrecest/evidence.py
+++ b/src/pyrecest/evidence.py
@@ -61,7 +61,10 @@ class EvidenceComputationMode:
metadata: dict[str, Any] | None = field(default_factory=dict)
def __post_init__(self) -> None:
- if not isinstance(self.mode, str) or self.mode not in {"full_smoothing", "evidence_only"}:
+ if not isinstance(self.mode, str) or self.mode not in {
+ "full_smoothing",
+ "evidence_only",
+ }:
raise ValueError(f"unknown evidence computation mode {self.mode!r}")
return_smoothed = _coerce_bool_flag(self.return_smoothed, "return_smoothed")
terminal_posterior = _coerce_bool_flag(
diff --git a/src/pyrecest/experimental/dvs/normal_flow.py b/src/pyrecest/experimental/dvs/normal_flow.py
index 0d58fb979..0441bf0a5 100644
--- a/src/pyrecest/experimental/dvs/normal_flow.py
+++ b/src/pyrecest/experimental/dvs/normal_flow.py
@@ -57,7 +57,9 @@ def _has_non_real_numeric_values(value: object, *, allow_bool: bool = False) ->
return array.dtype.kind in {"M", "m"}
-def _as_finite_real_array(name: str, value: object, *, allow_bool: bool = False) -> np.ndarray:
+def _as_finite_real_array(
+ name: str, value: object, *, allow_bool: bool = False
+) -> np.ndarray:
message = f"{name} must contain finite real numeric values."
if _has_non_real_numeric_values(value, allow_bool=allow_bool):
raise ValueError(message)
@@ -70,7 +72,9 @@ def _as_finite_real_array(name: str, value: object, *, allow_bool: bool = False)
return array
-def _as_finite_real_scalar(name: str, value: object, *, allow_bool: bool = False) -> float:
+def _as_finite_real_scalar(
+ name: str, value: object, *, allow_bool: bool = False
+) -> float:
message = f"{name} must be a finite real scalar."
try:
array = _as_finite_real_array(name, value, allow_bool=allow_bool)
@@ -172,8 +176,12 @@ def infer_polarity_contrast_sign(
return normalized
tolerance = _as_finite_nonnegative_scalar("zero_tolerance", zero_tolerance)
- flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape((-1,))
- polarities = _as_finite_real_array("event_polarities", event_polarities, allow_bool=True).reshape((-1,))
+ flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape(
+ (-1,)
+ )
+ polarities = _as_finite_real_array(
+ "event_polarities", event_polarities, allow_bool=True
+ ).reshape((-1,))
if flows.shape != polarities.shape:
raise ValueError("event_polarities must have one value per signed normal flow")
@@ -234,7 +242,9 @@ def polarity_weight_for_signed_flow(
zero_tolerance=1e-12,
) -> float:
"""Return a multiplicative reliability weight from polarity consistency."""
- mismatch_weight = _as_unit_interval_scalar("polarity_mismatch_weight", polarity_mismatch_weight)
+ mismatch_weight = _as_unit_interval_scalar(
+ "polarity_mismatch_weight", polarity_mismatch_weight
+ )
consistency = polarity_consistency_for_signed_flow(
signed_normal_flow_value,
event_polarity,
@@ -257,9 +267,15 @@ def polarity_weights_for_signed_flows(
zero_tolerance=1e-12,
) -> np.ndarray:
"""Return polarity reliability weights for a batch of signed-flow samples."""
- mismatch_weight = _as_unit_interval_scalar("polarity_mismatch_weight", polarity_mismatch_weight)
- flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape((-1,))
- polarities = _as_finite_real_array("event_polarities", event_polarities, allow_bool=True).reshape((-1,))
+ mismatch_weight = _as_unit_interval_scalar(
+ "polarity_mismatch_weight", polarity_mismatch_weight
+ )
+ flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape(
+ (-1,)
+ )
+ polarities = _as_finite_real_array(
+ "event_polarities", event_polarities, allow_bool=True
+ ).reshape((-1,))
resolved_sign = infer_polarity_contrast_sign(
flows,
polarities,
diff --git a/src/pyrecest/filters/measurement_reliability.py b/src/pyrecest/filters/measurement_reliability.py
index 67bd9b335..f356836f4 100644
--- a/src/pyrecest/filters/measurement_reliability.py
+++ b/src/pyrecest/filters/measurement_reliability.py
@@ -48,7 +48,10 @@ def _is_real_numeric_object_value(value) -> bool:
if kind is not None:
return kind in {"i", "u", "f"}
dtype_name = str(dtype).lower()
- if any(token in dtype_name for token in ("bool", "complex", "str", "string", "object")):
+ if any(
+ token in dtype_name
+ for token in ("bool", "complex", "str", "string", "object")
+ ):
return False
if "float" in dtype_name or "int" in dtype_name:
return True
@@ -81,7 +84,9 @@ def _has_real_numeric_dtype(value) -> bool:
if kind is not None:
return kind in {"i", "u", "f"}
dtype_name = str(dtype).lower()
- if any(token in dtype_name for token in ("bool", "complex", "str", "string", "object")):
+ if any(
+ token in dtype_name for token in ("bool", "complex", "str", "string", "object")
+ ):
return False
return "float" in dtype_name or "int" in dtype_name
@@ -236,4 +241,4 @@ def normalize_measurement_noise_covariances(
shared_noise = as_covariance_matrix(noise, measurement_dim, name)
if n_measurements == 0:
return zeros(empty_shape)
- return stack([shared_noise for _ in range(n_measurements)])
\ No newline at end of file
+ return stack([shared_noise for _ in range(n_measurements)])
diff --git a/src/pyrecest/filters/online_time_offset_estimator.py b/src/pyrecest/filters/online_time_offset_estimator.py
index ccc3916b9..c93838082 100644
--- a/src/pyrecest/filters/online_time_offset_estimator.py
+++ b/src/pyrecest/filters/online_time_offset_estimator.py
@@ -7,7 +7,6 @@
import numpy as np
-
_UNSUPPORTED_NUMERIC_KINDS = {"b", "S", "U", "c", "M", "m"}
_UNSUPPORTED_SCALAR_TYPES = (
type(None),
diff --git a/src/pyrecest/filters/wrapped_normal_filter.py b/src/pyrecest/filters/wrapped_normal_filter.py
index 927ea9088..be919b71b 100644
--- a/src/pyrecest/filters/wrapped_normal_filter.py
+++ b/src/pyrecest/filters/wrapped_normal_filter.py
@@ -8,7 +8,6 @@
from .abstract_filter import AbstractFilter
from .manifold_mixins import CircularFilterMixin
-
_PROGRESSIVE_TAU_MESSAGE = "tau must be a positive finite scalar"
diff --git a/src/pyrecest/models/__init__.py b/src/pyrecest/models/__init__.py
index 315b7e90a..2ff3499c0 100644
--- a/src/pyrecest/models/__init__.py
+++ b/src/pyrecest/models/__init__.py
@@ -4,6 +4,7 @@
and capability-oriented so filters can opt into the pieces they need.
"""
+from ._sampleable_transition_validation import install_sampleable_transition_validation
from ._validated_motion_models import nearly_coordinated_turn_model
from .adapters import (
LinearMeasurementArguments,
@@ -36,7 +37,6 @@
SupportsTransitionDensity,
SupportsTransitionSampling,
)
-from ._sampleable_transition_validation import install_sampleable_transition_validation
install_sampleable_transition_validation()
diff --git a/src/pyrecest/models/motion_models.py b/src/pyrecest/models/motion_models.py
index 740315e39..929959010 100644
--- a/src/pyrecest/models/motion_models.py
+++ b/src/pyrecest/models/motion_models.py
@@ -67,7 +67,10 @@ def _is_text_bool_or_complex(value: Any) -> bool:
def _as_scalar_float(value: Any, name: str) -> float:
value_array = np.asarray(value)
- if value_array.shape != () or value_array.dtype.kind in _REJECTED_NUMERIC_ARRAY_KINDS:
+ if (
+ value_array.shape != ()
+ or value_array.dtype.kind in _REJECTED_NUMERIC_ARRAY_KINDS
+ ):
raise ValueError(f"{name} must be a scalar number")
scalar_value = value_array.item()
if _is_text_bool_or_complex(scalar_value):
diff --git a/src/pyrecest/models/sensor_models.py b/src/pyrecest/models/sensor_models.py
index 9162f7469..59d5bf65e 100644
--- a/src/pyrecest/models/sensor_models.py
+++ b/src/pyrecest/models/sensor_models.py
@@ -47,7 +47,10 @@ def _state_vector(state):
def _as_scalar_float(value: Any, name: str) -> float:
value_array = np.asarray(value)
- if value_array.shape != () or value_array.dtype.kind in _REJECTED_SCALAR_ARRAY_KINDS:
+ if (
+ value_array.shape != ()
+ or value_array.dtype.kind in _REJECTED_SCALAR_ARRAY_KINDS
+ ):
raise ValueError(f"{name} must be a scalar number")
scalar_value = value_array.item()
if isinstance(scalar_value, _TEXT_OR_BOOL_SCALAR_TYPES) or isinstance(
diff --git a/src/pyrecest/models/weak_measurement.py b/src/pyrecest/models/weak_measurement.py
index 145a8b31f..f94771d9f 100644
--- a/src/pyrecest/models/weak_measurement.py
+++ b/src/pyrecest/models/weak_measurement.py
@@ -225,9 +225,7 @@ def _contains_non_real_numeric_values(value: Any) -> bool:
return True
if array.dtype.kind != "O":
return False
- return any(
- isinstance(item, _NON_REAL_NUMERIC_SCALAR_TYPES) for item in array.flat
- )
+ return any(isinstance(item, _NON_REAL_NUMERIC_SCALAR_TYPES) for item in array.flat)
def _positive_int(value: int, name: str) -> int:
diff --git a/src/pyrecest/tracking/innovation_diagnostics.py b/src/pyrecest/tracking/innovation_diagnostics.py
index defd54fb1..d4fb0c36b 100644
--- a/src/pyrecest/tracking/innovation_diagnostics.py
+++ b/src/pyrecest/tracking/innovation_diagnostics.py
@@ -369,7 +369,9 @@ def _as_finite_real_array(value: Any, name: str) -> np.ndarray:
if raw_values.dtype == np.bool_ or raw_values.dtype.kind in "USbcMm":
raise ValueError(message)
- if _contains_values_of_type(value, _INVALID_REAL_NUMERIC_TYPES) or _contains_values_of_type(raw_values, _INVALID_REAL_NUMERIC_TYPES):
+ if _contains_values_of_type(
+ value, _INVALID_REAL_NUMERIC_TYPES
+ ) or _contains_values_of_type(raw_values, _INVALID_REAL_NUMERIC_TYPES):
raise ValueError(message)
try:
diff --git a/src/pyrecest/utils/multisession_assignment_observation_costs.py b/src/pyrecest/utils/multisession_assignment_observation_costs.py
index 1c2c0dda3..4a9772b80 100644
--- a/src/pyrecest/utils/multisession_assignment_observation_costs.py
+++ b/src/pyrecest/utils/multisession_assignment_observation_costs.py
@@ -72,8 +72,7 @@ def _contains_non_real_cost(values: np.ndarray) -> bool:
return True
if values.dtype == object:
return any(
- isinstance(item, _INVALID_COST_SCALAR_TYPES)
- for item in values.reshape(-1)
+ isinstance(item, _INVALID_COST_SCALAR_TYPES) for item in values.reshape(-1)
)
return False
@@ -357,4 +356,4 @@ def _transform_pairwise_costs(
return transformed
-__all__ = ["solve_multisession_assignment_with_observation_costs"]
\ No newline at end of file
+__all__ = ["solve_multisession_assignment_with_observation_costs"]
diff --git a/src/pyrecest/utils/multisession_assignment_score.py b/src/pyrecest/utils/multisession_assignment_score.py
index 4c17e4af2..ab87602b9 100644
--- a/src/pyrecest/utils/multisession_assignment_score.py
+++ b/src/pyrecest/utils/multisession_assignment_score.py
@@ -34,7 +34,6 @@
solve_multisession_assignment,
)
-
_INVALID_SCORE_SCALAR_TYPES = (
type(None),
bool,
diff --git a/tests/backend/test_numpy_multinomial_real_pvals.py b/tests/backend/test_numpy_multinomial_real_pvals.py
index 16a1033db..45897d704 100644
--- a/tests/backend/test_numpy_multinomial_real_pvals.py
+++ b/tests/backend/test_numpy_multinomial_real_pvals.py
@@ -1,6 +1,5 @@
import numpy as np
import pytest
-
from pyrecest._backend.numpy import random
diff --git a/tests/backend_support/test_pytorch_rotation_stub_contract.py b/tests/backend_support/test_pytorch_rotation_stub_contract.py
index aed8e4ad7..af52dc650 100644
--- a/tests/backend_support/test_pytorch_rotation_stub_contract.py
+++ b/tests/backend_support/test_pytorch_rotation_stub_contract.py
@@ -15,9 +15,7 @@ def test_pytorch_rotation_stub_exposes_matrix_methods_with_backend_error():
from pyrecest._backend.pytorch.spatial import Rotation
with pytest.raises(RuntimeError, match="Rotation.from_matrix is not supported"):
- Rotation.from_matrix(
- [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
- )
+ Rotation.from_matrix([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
rotation = object.__new__(Rotation)
with pytest.raises(RuntimeError, match="Rotation.as_matrix is not supported"):
diff --git a/tests/calibration/test_bias_complex_input_validation.py b/tests/calibration/test_bias_complex_input_validation.py
index 476ea9579..77ea9456f 100644
--- a/tests/calibration/test_bias_complex_input_validation.py
+++ b/tests/calibration/test_bias_complex_input_validation.py
@@ -1,7 +1,6 @@
import unittest
import numpy as np
-
from pyrecest.calibration.bias import (
SensorBiasCorrectionModel,
make_bias_training_examples,
diff --git a/tests/calibration/test_time_offset_aggregate_validation.py b/tests/calibration/test_time_offset_aggregate_validation.py
index f3c8b16cf..139daad35 100644
--- a/tests/calibration/test_time_offset_aggregate_validation.py
+++ b/tests/calibration/test_time_offset_aggregate_validation.py
@@ -2,7 +2,6 @@
import numpy as np
import numpy.testing as npt
-
from pyrecest.calibration import aggregate_time_offset_sweeps
from pyrecest.calibration.time_offset import (
aggregate_time_offset_sweeps as aggregate_time_offset_sweeps_from_module,
diff --git a/tests/calibration/test_time_offset_metric_validation.py b/tests/calibration/test_time_offset_metric_validation.py
index e62ffd625..8da70fb87 100644
--- a/tests/calibration/test_time_offset_metric_validation.py
+++ b/tests/calibration/test_time_offset_metric_validation.py
@@ -2,7 +2,6 @@
import numpy as np
import numpy.testing as npt
-
from pyrecest.calibration.time_offset import fit_time_offset
diff --git a/tests/distributions/test_abstract_grid_distribution.py b/tests/distributions/test_abstract_grid_distribution.py
index 9044ca230..1180f4005 100644
--- a/tests/distributions/test_abstract_grid_distribution.py
+++ b/tests/distributions/test_abstract_grid_distribution.py
@@ -49,15 +49,11 @@ def test_constructor_rejects_grid_dimension_mismatch(self):
def test_constructor_rejects_one_dimensional_grid_for_multidimensional_space(self):
with self.assertRaisesRegex(ValueError, "dimension 2"):
- DummyGridDistribution(
- array([1.0, 1.0]), grid=array([0.0, 1.0]), dim=2
- )
+ DummyGridDistribution(array([1.0, 1.0]), grid=array([0.0, 1.0]), dim=2)
def test_constructor_rejects_higher_rank_grid_coordinates(self):
with self.assertRaisesRegex(ValueError, "one- or two-dimensional"):
- DummyGridDistribution(
- array([1.0, 1.0]), grid=array([[[0.0]], [[1.0]]])
- )
+ DummyGridDistribution(array([1.0, 1.0]), grid=array([[[0.0]], [[1.0]]]))
def test_integrate_rejects_custom_boundaries(self):
dist = DummyGridDistribution(array([1.0, 1.0]))
diff --git a/tests/distributions/test_circular_grid_scalar_sinc_pdf.py b/tests/distributions/test_circular_grid_scalar_sinc_pdf.py
index c4b16acfb..bac1b0e89 100644
--- a/tests/distributions/test_circular_grid_scalar_sinc_pdf.py
+++ b/tests/distributions/test_circular_grid_scalar_sinc_pdf.py
@@ -1,7 +1,9 @@
import unittest
from pyrecest.backend import array, pi
-from pyrecest.distributions.circle.circular_grid_distribution import CircularGridDistribution
+from pyrecest.distributions.circle.circular_grid_distribution import (
+ CircularGridDistribution,
+)
class CircularGridScalarSincPdfTest(unittest.TestCase):
diff --git a/tests/distributions/test_circular_mixture.py b/tests/distributions/test_circular_mixture.py
index b9329f299..c7071086e 100644
--- a/tests/distributions/test_circular_mixture.py
+++ b/tests/distributions/test_circular_mixture.py
@@ -2,7 +2,6 @@
import numpy as np
import numpy.testing as npt
-
import pyrecest.backend
# pylint: disable=no-name-in-module,no-member
diff --git a/tests/distributions/test_hyperrectangular_integration_validation.py b/tests/distributions/test_hyperrectangular_integration_validation.py
index 9e0236de3..3a96c2075 100644
--- a/tests/distributions/test_hyperrectangular_integration_validation.py
+++ b/tests/distributions/test_hyperrectangular_integration_validation.py
@@ -2,7 +2,6 @@
import numpy as np
import pytest
-
from pyrecest.distributions.nonperiodic.hyperrectangular_uniform_distribution import (
HyperrectangularUniformDistribution,
)
diff --git a/tests/distributions/test_linear_mixture.py b/tests/distributions/test_linear_mixture.py
index 22d8a9d0b..8963e7728 100644
--- a/tests/distributions/test_linear_mixture.py
+++ b/tests/distributions/test_linear_mixture.py
@@ -3,7 +3,6 @@
import numpy as np
import numpy.testing as npt
-
import pyrecest.backend
# pylint: disable=no-name-in-module,no-member
diff --git a/tests/evaluation/test_distance_registry_normalization.py b/tests/evaluation/test_distance_registry_normalization.py
index 08f5b689e..f96cf638a 100644
--- a/tests/evaluation/test_distance_registry_normalization.py
+++ b/tests/evaluation/test_distance_registry_normalization.py
@@ -14,5 +14,13 @@ def test_custom_distance_registry_strips_names_for_registration_and_lookup():
assert "unit-test-distance-trimmed" in available_distance_functions()
assert " unit-test-distance-trimmed " not in available_distance_functions()
- assert get_distance_function("unit-test-distance-trimmed")(array([0.0]), array([1.0])) == 13.0
- assert get_distance_function(" unit-test-distance-trimmed ")(array([0.0]), array([1.0])) == 13.0
+ assert (
+ get_distance_function("unit-test-distance-trimmed")(array([0.0]), array([1.0]))
+ == 13.0
+ )
+ assert (
+ get_distance_function(" unit-test-distance-trimmed ")(
+ array([0.0]), array([1.0])
+ )
+ == 13.0
+ )
diff --git a/tests/evaluation/test_pareto_objective_scalar_validation.py b/tests/evaluation/test_pareto_objective_scalar_validation.py
index 251ce28e1..a79261f86 100644
--- a/tests/evaluation/test_pareto_objective_scalar_validation.py
+++ b/tests/evaluation/test_pareto_objective_scalar_validation.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import pandas as pd
-
from pyrecest.evaluation import is_pareto_front, pareto_front_indices, record_dominates
diff --git a/tests/experimental/test_dvs_event_likelihood_counts.py b/tests/experimental/test_dvs_event_likelihood_counts.py
index 731af47dd..067223481 100644
--- a/tests/experimental/test_dvs_event_likelihood_counts.py
+++ b/tests/experimental/test_dvs_event_likelihood_counts.py
@@ -1,7 +1,6 @@
import unittest
import numpy as np
-
from pyrecest.experimental.dvs.event_likelihood import PointProcessUpdateConfig
diff --git a/tests/experimental/test_dvs_normal_flow.py b/tests/experimental/test_dvs_normal_flow.py
index dc965927a..613445b06 100644
--- a/tests/experimental/test_dvs_normal_flow.py
+++ b/tests/experimental/test_dvs_normal_flow.py
@@ -24,7 +24,9 @@ def test_signed_scalar_sign_rejects_invalid_values(bad_value):
signed_scalar_sign(bad_value)
-@pytest.mark.parametrize("bad_tolerance", [-1e-12, np.nan, np.inf, "1e-12", 1e-12 + 0.0j])
+@pytest.mark.parametrize(
+ "bad_tolerance", [-1e-12, np.nan, np.inf, "1e-12", 1e-12 + 0.0j]
+)
def test_signed_scalar_sign_rejects_invalid_zero_tolerance(bad_tolerance):
with pytest.raises(ValueError, match="zero_tolerance"):
signed_scalar_sign(1.0, zero_tolerance=bad_tolerance)
@@ -108,7 +110,9 @@ def test_polarity_consistency_and_weight_for_signed_flow():
@pytest.mark.parametrize("bad_weight", [np.nan, np.inf, "0.5", 0.5 + 0.0j, -0.1, 1.1])
def test_polarity_weight_rejects_invalid_mismatch_weight(bad_weight):
with pytest.raises(ValueError, match="polarity_mismatch_weight"):
- polarity_weight_for_signed_flow(1.0, 0.0, 1.0, polarity_mismatch_weight=bad_weight)
+ polarity_weight_for_signed_flow(
+ 1.0, 0.0, 1.0, polarity_mismatch_weight=bad_weight
+ )
def test_polarity_weights_for_signed_flows_resolves_batch_sign():
@@ -122,4 +126,6 @@ def test_polarity_weights_for_signed_flows_resolves_batch_sign():
def test_polarity_weights_for_signed_flows_rejects_invalid_mismatch_weight_when_disabled():
with pytest.raises(ValueError, match="polarity_mismatch_weight"):
- polarity_weights_for_signed_flows([1.0], [1.0], None, polarity_mismatch_weight=np.nan)
+ polarity_weights_for_signed_flows(
+ [1.0], [1.0], None, polarity_mismatch_weight=np.nan
+ )
diff --git a/tests/filters/test_abstract_extended_object_tracker.py b/tests/filters/test_abstract_extended_object_tracker.py
index 227854ce7..ece376f35 100644
--- a/tests/filters/test_abstract_extended_object_tracker.py
+++ b/tests/filters/test_abstract_extended_object_tracker.py
@@ -5,7 +5,9 @@
# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import array, eye
-from pyrecest.filters.abstract_extended_object_tracker import AbstractExtendedObjectTracker
+from pyrecest.filters.abstract_extended_object_tracker import (
+ AbstractExtendedObjectTracker,
+)
class _MinimalExtendedObjectTracker(AbstractExtendedObjectTracker):
diff --git a/tests/filters/test_gaussian_mixture_phd_filter_dimensions.py b/tests/filters/test_gaussian_mixture_phd_filter_dimensions.py
index ca3eda81a..77060b3e2 100644
--- a/tests/filters/test_gaussian_mixture_phd_filter_dimensions.py
+++ b/tests/filters/test_gaussian_mixture_phd_filter_dimensions.py
@@ -12,7 +12,9 @@
)
class TestGaussianMixturePHDFilterDimensions(unittest.TestCase):
def test_constructor_rejects_mismatched_birth_dimension(self):
- with self.assertRaisesRegex(ValueError, "Birth components must have dimension 2"):
+ with self.assertRaisesRegex(
+ ValueError, "Birth components must have dimension 2"
+ ):
GaussianMixturePHDFilter(
initial_components=[GaussianDistribution(array([0.0, 0.0]), eye(2))],
initial_weights=array([0.8]),
@@ -32,7 +34,9 @@ def test_set_birth_model_rejects_mismatched_dimension(self):
log_posterior_estimates=False,
)
- with self.assertRaisesRegex(ValueError, "Birth components must have dimension 2"):
+ with self.assertRaisesRegex(
+ ValueError, "Birth components must have dimension 2"
+ ):
tracker.set_birth_model(
[GaussianDistribution(array([5.0, 5.0, 0.0]), eye(3))],
array([0.2]),
@@ -46,7 +50,9 @@ def test_predict_linear_rejects_mismatched_temporary_birth_dimension(self):
log_posterior_estimates=False,
)
- with self.assertRaisesRegex(ValueError, "Birth components must have dimension 2"):
+ with self.assertRaisesRegex(
+ ValueError, "Birth components must have dimension 2"
+ ):
tracker.predict_linear(
eye(2),
0.1 * eye(2),
diff --git a/tests/filters/test_gnn_pairwise_cost_validation.py b/tests/filters/test_gnn_pairwise_cost_validation.py
index 2c58add7a..2d0ee463f 100644
--- a/tests/filters/test_gnn_pairwise_cost_validation.py
+++ b/tests/filters/test_gnn_pairwise_cost_validation.py
@@ -1,7 +1,6 @@
import unittest
import numpy as np
-
from pyrecest.filters.global_nearest_neighbor import GlobalNearestNeighbor
diff --git a/tests/filters/test_measurement_reliability_object_weights.py b/tests/filters/test_measurement_reliability_object_weights.py
index 366559fd4..df98d9b32 100644
--- a/tests/filters/test_measurement_reliability_object_weights.py
+++ b/tests/filters/test_measurement_reliability_object_weights.py
@@ -34,4 +34,4 @@ def test_object_weight_inputs_still_reject_non_real_values(self):
if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
+ unittest.main()
diff --git a/tests/filters/test_measurement_scoring.py b/tests/filters/test_measurement_scoring.py
index 083dbcf58..bfd63d276 100644
--- a/tests/filters/test_measurement_scoring.py
+++ b/tests/filters/test_measurement_scoring.py
@@ -1,5 +1,4 @@
import numpy as np
-
from pyrecest.filters.measurement_scoring import MeasurementScore
diff --git a/tests/filters/test_update_diagnostics_validation.py b/tests/filters/test_update_diagnostics_validation.py
index cb02e33bd..4bacc3234 100644
--- a/tests/filters/test_update_diagnostics_validation.py
+++ b/tests/filters/test_update_diagnostics_validation.py
@@ -30,7 +30,9 @@ def test_active_measurement_indices_must_fit_measurement_count():
ValueError,
match="active_measurement_indices.*measurement_count",
):
- MeasurementUpdateDiagnostics(active_measurement_indices=(0, 2), measurement_count=2)
+ MeasurementUpdateDiagnostics(
+ active_measurement_indices=(0, 2), measurement_count=2
+ )
def test_measurement_count_rejects_non_integer_values():
diff --git a/tests/models/test_linear_gaussian_models.py b/tests/models/test_linear_gaussian_models.py
index 37df6ae28..dc19c8369 100644
--- a/tests/models/test_linear_gaussian_models.py
+++ b/tests/models/test_linear_gaussian_models.py
@@ -70,7 +70,17 @@ def test_identity_models(self):
)
def test_identity_models_reject_invalid_dimensions(self):
- invalid_dims = (True, False, 0, -1, 1.5, float("inf"), "2", b"2", bytearray(b"2"))
+ invalid_dims = (
+ True,
+ False,
+ 0,
+ -1,
+ 1.5,
+ float("inf"),
+ "2",
+ b"2",
+ bytearray(b"2"),
+ )
for dim in invalid_dims:
with self.subTest(model="transition", dim=dim):
diff --git a/tests/models/test_sampleable_transition_vectorization_validation.py b/tests/models/test_sampleable_transition_vectorization_validation.py
index fe59cae73..1a57a1f8f 100644
--- a/tests/models/test_sampleable_transition_vectorization_validation.py
+++ b/tests/models/test_sampleable_transition_vectorization_validation.py
@@ -1,5 +1,4 @@
import pytest
-
from pyrecest.models.likelihood import SampleableTransitionModel
diff --git a/tests/models/test_weak_measurement_nonreal_inputs.py b/tests/models/test_weak_measurement_nonreal_inputs.py
index 0fa01717c..55e470b7b 100644
--- a/tests/models/test_weak_measurement_nonreal_inputs.py
+++ b/tests/models/test_weak_measurement_nonreal_inputs.py
@@ -4,7 +4,6 @@
import numpy as np
import pytest
-
from pyrecest.models import (
MaskedLinearMeasurementModel,
WeakDimensionMeasurementModel,
@@ -12,7 +11,6 @@
diagonal_measurement_covariance,
)
-
_NON_REAL_STD_CASES = (
[1.0 + 0.0j],
[np.complex128(1.0 + 0.0j)],
diff --git a/tests/sampling/test_support_points.py b/tests/sampling/test_support_points.py
index 6ccb4a79f..dbaf7da67 100644
--- a/tests/sampling/test_support_points.py
+++ b/tests/sampling/test_support_points.py
@@ -52,7 +52,9 @@ def test_support_points_from_axis_offsets_supports_batches() -> None:
def test_support_points_from_axis_offsets_can_omit_center() -> None:
- support = support_points_from_axis_offsets([0.0, 0.0], np.eye(2), include_center=False)
+ support = support_points_from_axis_offsets(
+ [0.0, 0.0], np.eye(2), include_center=False
+ )
assert support.shape == (4, 2)
assert np.allclose(support, [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
diff --git a/tests/test_backend_default_dtype_like_values.py b/tests/test_backend_default_dtype_like_values.py
index 0f31a41f1..d5727bb2c 100644
--- a/tests/test_backend_default_dtype_like_values.py
+++ b/tests/test_backend_default_dtype_like_values.py
@@ -1,6 +1,5 @@
import numpy as np
import pytest
-
from pyrecest._backend import numpy as numpy_backend
diff --git a/tests/test_backend_sum_contract.py b/tests/test_backend_sum_contract.py
index 1732317e2..f76c7e7fe 100644
--- a/tests/test_backend_sum_contract.py
+++ b/tests/test_backend_sum_contract.py
@@ -30,7 +30,9 @@ def test_convert_to_wider_dtype_matches_numpy_result_type_for_mixed_dtypes():
expected_dtype = np.result_type(np.dtype("int64"), np.dtype("float32"))
assert backend.to_numpy(first).dtype == expected_dtype
assert backend.to_numpy(second).dtype == expected_dtype
- np.testing.assert_allclose(backend.to_numpy(first), np.array([1], dtype=expected_dtype))
+ np.testing.assert_allclose(
+ backend.to_numpy(first), np.array([1], dtype=expected_dtype)
+ )
np.testing.assert_allclose(
backend.to_numpy(second), np.array([1.5], dtype=expected_dtype)
)
diff --git a/tests/test_backend_support_public.py b/tests/test_backend_support_public.py
index fb678ff48..cd3aeaddd 100644
--- a/tests/test_backend_support_public.py
+++ b/tests/test_backend_support_public.py
@@ -7,7 +7,6 @@
)
from pyrecest._backend.capabilities import BACKEND_SUPPORT_LEVELS
-
backend_support_module = import_module("pyrecest.backend_support")
diff --git a/tests/test_benchmark_regression_script.py b/tests/test_benchmark_regression_script.py
index 41c06189c..889514d4c 100644
--- a/tests/test_benchmark_regression_script.py
+++ b/tests/test_benchmark_regression_script.py
@@ -52,11 +52,16 @@ def test_benchmark_regression_rejects_boolean_elapsed(tmp_path: Path) -> None:
)
assert completed.returncode == 1
- assert "::error::linear_kalman.elapsed_seconds must be numeric, not boolean" in completed.stdout
+ assert (
+ "::error::linear_kalman.elapsed_seconds must be numeric, not boolean"
+ in completed.stdout
+ )
assert "Traceback" not in completed.stderr
-def test_benchmark_regression_rejects_text_elapsed_without_traceback(tmp_path: Path) -> None:
+def test_benchmark_regression_rejects_text_elapsed_without_traceback(
+ tmp_path: Path,
+) -> None:
completed = _run_checker(
tmp_path,
actual_entry={
@@ -74,7 +79,10 @@ def test_benchmark_regression_rejects_text_elapsed_without_traceback(tmp_path: P
)
assert completed.returncode == 1
- assert "::error::linear_kalman.elapsed_seconds must be numeric, got '0.01'" in completed.stdout
+ assert (
+ "::error::linear_kalman.elapsed_seconds must be numeric, got '0.01'"
+ in completed.stdout
+ )
assert "Traceback" not in completed.stderr
@@ -96,11 +104,16 @@ def test_benchmark_regression_rejects_text_iterations(tmp_path: Path) -> None:
)
assert completed.returncode == 1
- assert "::error::linear_kalman.iterations must be an integer, got '200'" in completed.stdout
+ assert (
+ "::error::linear_kalman.iterations must be an integer, got '200'"
+ in completed.stdout
+ )
assert "Traceback" not in completed.stderr
-def test_benchmark_regression_reports_invalid_final_estimate_without_traceback(tmp_path: Path) -> None:
+def test_benchmark_regression_reports_invalid_final_estimate_without_traceback(
+ tmp_path: Path,
+) -> None:
completed = _run_checker(
tmp_path,
actual_entry={
@@ -118,5 +131,8 @@ def test_benchmark_regression_reports_invalid_final_estimate_without_traceback(t
)
assert completed.returncode == 1
- assert "::error::linear_kalman.final_estimate[1] must be numeric, not boolean" in completed.stdout
+ assert (
+ "::error::linear_kalman.final_estimate[1] must be numeric, not boolean"
+ in completed.stdout
+ )
assert "Traceback" not in completed.stderr
diff --git a/tests/test_candidate_pruning_temporal_validation.py b/tests/test_candidate_pruning_temporal_validation.py
index bacdde405..9bd341044 100644
--- a/tests/test_candidate_pruning_temporal_validation.py
+++ b/tests/test_candidate_pruning_temporal_validation.py
@@ -1,7 +1,6 @@
import unittest
import numpy as np
-
from pyrecest.utils import (
CandidatePruningConfig,
candidate_mask_from_costs,
@@ -21,7 +20,9 @@ def test_cost_matrix_rejects_datetime_and_timedelta_entries(self):
for function in (candidate_mask_from_costs, prune_pairwise_cost_matrix):
for matrix in invalid_matrices:
with self.subTest(function=function.__name__, dtype=str(matrix.dtype)):
- with self.assertRaisesRegex(ValueError, "cost_matrix must be numeric"):
+ with self.assertRaisesRegex(
+ ValueError, "cost_matrix must be numeric"
+ ):
function(matrix)
def test_probability_matrix_rejects_datetime_and_timedelta_entries(self):
@@ -35,7 +36,9 @@ def test_probability_matrix_rejects_datetime_and_timedelta_entries(self):
for probabilities in invalid_probability_matrices:
with self.subTest(dtype=str(probabilities.dtype)):
- with self.assertRaisesRegex(ValueError, "probability_matrix must be numeric"):
+ with self.assertRaisesRegex(
+ ValueError, "probability_matrix must be numeric"
+ ):
candidate_mask_from_costs(
np.array([[1.0]]),
probability_matrix=probabilities,
diff --git a/tests/test_cli_expected_mapping.py b/tests/test_cli_expected_mapping.py
index 988b56d0b..60994621e 100644
--- a/tests/test_cli_expected_mapping.py
+++ b/tests/test_cli_expected_mapping.py
@@ -11,7 +11,9 @@ def test_expected_mapping_reports_nonnumeric_actual_instead_of_raising():
tolerance=1e-8,
)
- assert errors == ["metrics.rmse mismatch: expected finite numeric 1.0, got 'not-a-number'"]
+ assert errors == [
+ "metrics.rmse mismatch: expected finite numeric 1.0, got 'not-a-number'"
+ ]
def test_expected_mapping_compares_boolean_expected_values_exactly():
diff --git a/tests/test_dynamic_models.py b/tests/test_dynamic_models.py
index ff9aaa955..6566660c0 100644
--- a/tests/test_dynamic_models.py
+++ b/tests/test_dynamic_models.py
@@ -110,7 +110,10 @@ def test_process_noise_rejects_invalid_parameters(self):
{"spectral_density": True, "message": "spectral_density"},
{"spectral_density": "1.0", "message": "spectral_density"},
{"spectral_density": ["1.0", "2.0"], "message": "spectral_density"},
- {"spectral_density": np.array([True, False], dtype=object), "message": "spectral_density"},
+ {
+ "spectral_density": np.array([True, False], dtype=object),
+ "message": "spectral_density",
+ },
{
"spectral_density": np.array([1.0, np.nan]),
"message": "spectral_density",
diff --git a/tests/test_multisession_assignment_observation_cost_type_validation.py b/tests/test_multisession_assignment_observation_cost_type_validation.py
index d5c1ddeed..b056eee72 100644
--- a/tests/test_multisession_assignment_observation_cost_type_validation.py
+++ b/tests/test_multisession_assignment_observation_cost_type_validation.py
@@ -3,7 +3,6 @@
import unittest
import numpy as np
-
from pyrecest.backend import ( # pylint: disable=no-name-in-module
__backend_name__,
array,
diff --git a/tests/test_numpy_backend_dynamic_dtype.py b/tests/test_numpy_backend_dynamic_dtype.py
index 68a03ca9a..debe92071 100644
--- a/tests/test_numpy_backend_dynamic_dtype.py
+++ b/tests/test_numpy_backend_dynamic_dtype.py
@@ -1,6 +1,6 @@
import numpy as np
-
-from pyrecest._backend import _backend_config, numpy as numpy_backend
+from pyrecest._backend import _backend_config
+from pyrecest._backend import numpy as numpy_backend
def test_dynamic_dtype_preserves_explicit_positional_dtype_for_zeros():
@@ -28,6 +28,9 @@ def test_dynamic_dtype_uses_default_dtype_for_omitted_or_none_dtype():
_backend_config.DEFAULT_DTYPE = expected_dtype
assert numpy_backend.zeros((2,)).dtype == expected_dtype
assert numpy_backend.empty((2,), None).dtype == expected_dtype
- assert numpy_backend.linspace(0.0, 1.0, 3, True, False, None).dtype == expected_dtype
+ assert (
+ numpy_backend.linspace(0.0, 1.0, 3, True, False, None).dtype
+ == expected_dtype
+ )
finally:
_backend_config.DEFAULT_DTYPE = previous_dtype
diff --git a/tests/test_numpy_backend_zeros_dtype.py b/tests/test_numpy_backend_zeros_dtype.py
index 25019cc9d..2e756a793 100644
--- a/tests/test_numpy_backend_zeros_dtype.py
+++ b/tests/test_numpy_backend_zeros_dtype.py
@@ -2,7 +2,6 @@
import numpy as np
import pytest
-
from pyrecest._backend import numpy as numpy_backend
from tests.support.backend_runner import run_backend_code
diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py
index 2787a57ec..d9b989639 100644
--- a/tests/test_reproducibility.py
+++ b/tests/test_reproducibility.py
@@ -6,7 +6,13 @@
class ReproducibilityValidationTest(unittest.TestCase):
def test_normalize_seed_rejects_text_values(self):
- for value in ("1", np.array("1"), bytes([49]), bytearray([49]), np.bytes_(bytes([49]))):
+ for value in (
+ "1",
+ np.array("1"),
+ bytes([49]),
+ bytearray([49]),
+ np.bytes_(bytes([49])),
+ ):
with self.subTest(value=repr(value)):
with self.assertRaisesRegex(
ValueError,
diff --git a/tests/test_scenario_type_validation.py b/tests/test_scenario_type_validation.py
index 719a3f055..e79c167de 100644
--- a/tests/test_scenario_type_validation.py
+++ b/tests/test_scenario_type_validation.py
@@ -1,5 +1,4 @@
import pytest
-
from pyrecest import scenarios
@@ -27,7 +26,9 @@ def runner(_path):
try:
assert scenarios.register_scenario_runner(f" {name}\t", runner) is runner
assert name in scenarios.available_scenario_types()
- assert f" {name}\t" not in scenarios._SCENARIO_RUNNERS # pylint: disable=protected-access
+ assert (
+ f" {name}\t" not in scenarios._SCENARIO_RUNNERS
+ ) # pylint: disable=protected-access
with pytest.raises(ValueError, match="already registered"):
scenarios.register_scenario_runner(name, runner)
@@ -38,8 +39,7 @@ def runner(_path):
def test_run_scenario_rejects_unhashable_scenario_type(tmp_path):
scenario_path = tmp_path / "bad_scenario.toml"
scenario_path.write_text(
- "[scenario]\n"
- 'type = ["linear_gaussian"]\n',
+ "[scenario]\n" 'type = ["linear_gaussian"]\n',
encoding="utf-8",
)
diff --git a/tests/test_scenarios.py b/tests/test_scenarios.py
index 090117cc3..cd926a58a 100644
--- a/tests/test_scenarios.py
+++ b/tests/test_scenarios.py
@@ -2,7 +2,6 @@
from pathlib import Path
import pytest
-
from pyrecest.scenarios import load_scenario_config, run_scenario
SCENARIO = Path("scenarios/linear_gaussian_cv_1d/config.toml")
diff --git a/tests/test_track_edit_whatif_empty_split.py b/tests/test_track_edit_whatif_empty_split.py
index 1f4819fd6..9c6e9283d 100644
--- a/tests/test_track_edit_whatif_empty_split.py
+++ b/tests/test_track_edit_whatif_empty_split.py
@@ -1,6 +1,10 @@
"""Regression tests for empty split-track edits."""
-from pyrecest.utils.track_edit_whatif import TrackEdit, apply_track_edit, score_track_edit_delta
+from pyrecest.utils.track_edit_whatif import (
+ TrackEdit,
+ apply_track_edit,
+ score_track_edit_delta,
+)
def test_split_track_rejects_empty_track() -> None:
diff --git a/tests/test_track_edit_whatif_index_validation.py b/tests/test_track_edit_whatif_index_validation.py
index 169b4024a..037038ce4 100644
--- a/tests/test_track_edit_whatif_index_validation.py
+++ b/tests/test_track_edit_whatif_index_validation.py
@@ -1,12 +1,17 @@
"""Validation tests for track-edit what-if selector indices."""
import pytest
-
-from pyrecest.utils.track_edit_whatif import TrackEdit, apply_track_edit, score_track_edit_delta
+from pyrecest.utils.track_edit_whatif import (
+ TrackEdit,
+ apply_track_edit,
+ score_track_edit_delta,
+)
def test_add_link_rejects_invalid_link_fields() -> None:
- with pytest.raises(ValueError, match="target_observation must be a non-negative integer"):
+ with pytest.raises(
+ ValueError, match="target_observation must be a non-negative integer"
+ ):
apply_track_edit(
[[1, None]],
TrackEdit(
@@ -40,8 +45,14 @@ def test_score_rejects_invalid_session_selectors() -> None:
target_observation=2,
)
- with pytest.raises(ValueError, match="complete_session_indices must be a non-negative integer"):
- score_track_edit_delta([[1, 2]], [[1, 2]], edit, complete_session_indices=[0, 1.5])
+ with pytest.raises(
+ ValueError, match="complete_session_indices must be a non-negative integer"
+ ):
+ score_track_edit_delta(
+ [[1, 2]], [[1, 2]], edit, complete_session_indices=[0, 1.5]
+ )
- with pytest.raises(ValueError, match="session_pairs must be a non-negative integer"):
+ with pytest.raises(
+ ValueError, match="session_pairs must be a non-negative integer"
+ ):
score_track_edit_delta([[1, 2]], [[1, 2]], edit, session_pairs=[(bool(0), 1)])
diff --git a/tests/test_track_edit_whatif_metadata_validation.py b/tests/test_track_edit_whatif_metadata_validation.py
index face93ea7..0a4faabba 100644
--- a/tests/test_track_edit_whatif_metadata_validation.py
+++ b/tests/test_track_edit_whatif_metadata_validation.py
@@ -1,7 +1,6 @@
"""Regression tests for track-edit metadata selector validation."""
import pytest
-
from pyrecest.utils.track_edit_whatif import TrackEdit, apply_track_edit
@@ -15,7 +14,9 @@ def test_remove_link_rejects_fractional_occurrence_index() -> None:
metadata={"occurrence_index": 0.5},
)
- with pytest.raises(ValueError, match="metadata\['occurrence_index'\] must be an integer"):
+ with pytest.raises(
+ ValueError, match="metadata\['occurrence_index'\] must be an integer"
+ ):
apply_track_edit([[1, 2], [1, 2]], edit)
@@ -32,7 +33,10 @@ def test_swap_link_rejects_boolean_removed_observation() -> None:
},
)
- with pytest.raises(ValueError, match="metadata\['remove_source_observation'\] must be a non-negative integer"):
+ with pytest.raises(
+ ValueError,
+ match="metadata\['remove_source_observation'\] must be a non-negative integer",
+ ):
apply_track_edit([[1, 2]], edit)
@@ -43,5 +47,7 @@ def test_merge_tracks_rejects_text_other_track_index() -> None:
metadata={"other_track_index": "1"},
)
- with pytest.raises(ValueError, match="metadata\['other_track_index'\] must be an integer"):
+ with pytest.raises(
+ ValueError, match="metadata\['other_track_index'\] must be an integer"
+ ):
apply_track_edit([[1, None], [None, 2]], edit)
diff --git a/tests/test_track_evaluation_session_validation.py b/tests/test_track_evaluation_session_validation.py
index afa57f1f0..70b5793bd 100644
--- a/tests/test_track_evaluation_session_validation.py
+++ b/tests/test_track_evaluation_session_validation.py
@@ -3,7 +3,6 @@
import unittest
import numpy as np
-
from pyrecest.utils.track_evaluation import (
complete_track_set,
score_complete_tracks,
diff --git a/tests/tracking/test_innovation_diagnostics.py b/tests/tracking/test_innovation_diagnostics.py
index 1c2dae594..bf2011dab 100644
--- a/tests/tracking/test_innovation_diagnostics.py
+++ b/tests/tracking/test_innovation_diagnostics.py
@@ -89,7 +89,9 @@ def test_innovation_diagnostic_rejects_non_real_residual_values(bad_residual) ->
[[None]],
),
)
-def test_innovation_diagnostic_rejects_non_real_covariance_values(bad_covariance) -> None:
+def test_innovation_diagnostic_rejects_non_real_covariance_values(
+ bad_covariance,
+) -> None:
with pytest.raises(
ValueError,
match="innovation_covariance must contain finite real numeric values",
diff --git a/tests/utils/test_association_features_probability_validation.py b/tests/utils/test_association_features_probability_validation.py
index 63b21a4bf..2aa20d745 100644
--- a/tests/utils/test_association_features_probability_validation.py
+++ b/tests/utils/test_association_features_probability_validation.py
@@ -2,7 +2,6 @@
import numpy as np
import pytest
-
from pyrecest.utils.association_features import CalibratedPairwiseAssociationModel
@@ -46,7 +45,9 @@ def test_predict_match_probability_rejects_non_real_probability_values(probabili
feature_names=["distance"],
)
- with pytest.raises(ValueError, match="predicted probabilities must be real numeric"):
+ with pytest.raises(
+ ValueError, match="predicted probabilities must be real numeric"
+ ):
model.predict_match_probability(np.array([[0.0], [1.0]]))
@@ -56,7 +57,9 @@ def test_predict_proba_rejects_non_real_probabilities_before_class_selection():
feature_names=["distance"],
)
- with pytest.raises(ValueError, match="predicted probabilities must be real numeric"):
+ with pytest.raises(
+ ValueError, match="predicted probabilities must be real numeric"
+ ):
model.predict_match_probability(np.array([[0.0]]))
@@ -66,5 +69,7 @@ def test_pairwise_cost_model_rejects_temporal_cost_values():
feature_names=["distance"],
)
- with pytest.raises(ValueError, match="predicted pairwise costs must be real numeric"):
+ with pytest.raises(
+ ValueError, match="predicted pairwise costs must be real numeric"
+ ):
model.predict_match_probability(np.array([[0.0]]))