diff --git a/scripts/check_benchmark_regression.py b/scripts/check_benchmark_regression.py index a9b2cbbbf..4fce50269 100644 --- a/scripts/check_benchmark_regression.py +++ b/scripts/check_benchmark_regression.py @@ -55,7 +55,9 @@ def _flatten_numbers(value: Any, *, path: str) -> list[float]: raise TypeError(f"{path} must be a number or nested sequence of numbers") -def _finite_nonnegative_number(value: Any, *, path: str) -> tuple[float | None, str | None]: +def _finite_nonnegative_number( + value: Any, *, path: str +) -> tuple[float | None, str | None]: if isinstance(value, bool): return None, f"{path} must be numeric, not boolean" if not isinstance(value, int | float): diff --git a/scripts/check_public_api_registry.py b/scripts/check_public_api_registry.py index 83428adeb..fdec151f9 100644 --- a/scripts/check_public_api_registry.py +++ b/scripts/check_public_api_registry.py @@ -41,7 +41,12 @@ def _load_backend_capabilities() -> dict[str, dict[str, str]]: def _markdown_table_cell(value: object) -> str: - return str(value).replace("\r", " ").replace("\n", "
").replace(chr(124), chr(0xFF5C)) + return ( + str(value) + .replace("\r", " ") + .replace("\n", "
") + .replace(chr(124), chr(0xFF5C)) + ) def validate_registry() -> list[str]: diff --git a/src/pyrecest/__main__.py b/src/pyrecest/__main__.py index e72674850..17bfa0a9a 100644 --- a/src/pyrecest/__main__.py +++ b/src/pyrecest/__main__.py @@ -6,6 +6,5 @@ from pyrecest.cli import main - if __name__ == "__main__": # pragma: no cover sys.exit(main()) diff --git a/src/pyrecest/_backend/numpy/_common.py b/src/pyrecest/_backend/numpy/_common.py index 9cad54096..6ad2a6cc7 100644 --- a/src/pyrecest/_backend/numpy/_common.py +++ b/src/pyrecest/_backend/numpy/_common.py @@ -2,7 +2,11 @@ from pyrecest._backend._dtype_utils import ( _dyn_update_dtype, _modify_func_default_dtype, +) +from pyrecest._backend._dtype_utils import ( get_default_cdtype as _shared_get_default_cdtype, +) +from pyrecest._backend._dtype_utils import ( get_default_dtype as _shared_get_default_dtype, ) diff --git a/src/pyrecest/_backend/pytorch/linalg.py b/src/pyrecest/_backend/pytorch/linalg.py index f7204b86f..9f5d5b386 100644 --- a/src/pyrecest/_backend/pytorch/linalg.py +++ b/src/pyrecest/_backend/pytorch/linalg.py @@ -14,7 +14,9 @@ inv, ) from torch.linalg import matrix_exp as expm -from torch.linalg import matrix_power +from torch.linalg import ( + matrix_power, +) from .._backend_config import np_atol as atol from ..numpy import linalg as _gsnplinalg diff --git a/src/pyrecest/_backend/pytorch/random.py b/src/pyrecest/_backend/pytorch/random.py index 487627d65..eee404977 100644 --- a/src/pyrecest/_backend/pytorch/random.py +++ b/src/pyrecest/_backend/pytorch/random.py @@ -89,11 +89,7 @@ def _validate_choice_probabilities(p, population_size, device): raise ValueError("p must be 1-dimensional with one entry per population item") p_sum = p.sum() - if ( - bool(_torch.any(p < 0)) - or not bool(_torch.isfinite(p_sum)) - or bool(p_sum <= 0) - ): + if bool(_torch.any(p < 0)) or not bool(_torch.isfinite(p_sum)) or bool(p_sum <= 0): raise ValueError("probabilities do not sum to a positive value") return p / p_sum @@ -209,7 +205,9 @@ def _normal_size(size): def _broadcasted_parameter_shape(*parameters, message): try: - return tuple(_torch.broadcast_shapes(*(parameter.shape for parameter in parameters))) + return tuple( + _torch.broadcast_shapes(*(parameter.shape for parameter in parameters)) + ) except RuntimeError as exc: raise ValueError(message) from exc diff --git a/src/pyrecest/calibration/__init__.py b/src/pyrecest/calibration/__init__.py index 11723e6d6..d756cc40c 100644 --- a/src/pyrecest/calibration/__init__.py +++ b/src/pyrecest/calibration/__init__.py @@ -15,7 +15,13 @@ ) from .time_offset import ( TimeOffsetFitResult, - aggregate_time_offset_sweeps as _aggregate_time_offset_sweeps, + _aggregate_summary_metric, + _as_nonnegative_summary_count, + _as_summary_scalar, + _validate_error_metric, +) +from .time_offset import aggregate_time_offset_sweeps as _aggregate_time_offset_sweeps +from .time_offset import ( apply_time_offset, fit_time_offset, interpolate_reference_values, @@ -24,12 +30,6 @@ time_offset_error_summary, time_offset_sweep, ) -from .time_offset import ( - _aggregate_summary_metric, - _as_nonnegative_summary_count, - _as_summary_scalar, - _validate_error_metric, -) def aggregate_time_offset_sweeps( diff --git a/src/pyrecest/calibration/time_offset.py b/src/pyrecest/calibration/time_offset.py index dc8896374..5c98273b7 100644 --- a/src/pyrecest/calibration/time_offset.py +++ b/src/pyrecest/calibration/time_offset.py @@ -8,7 +8,6 @@ import numpy as np - _ERROR_METRIC_NAMES = frozenset({"max", "mean", "p95", "rmse", "std"}) _ERROR_METRIC_MESSAGE = "metric must be one of 'max', 'mean', 'p95', 'rmse', or 'std'" diff --git a/src/pyrecest/distributions/abstract_grid_distribution.py b/src/pyrecest/distributions/abstract_grid_distribution.py index 4ffb1c7dc..4568c9f00 100644 --- a/src/pyrecest/distributions/abstract_grid_distribution.py +++ b/src/pyrecest/distributions/abstract_grid_distribution.py @@ -43,8 +43,7 @@ def __init__( ) if dim is not None and actual_dim != dim: raise ValueError( - f"Grid coordinates must have dimension {dim}, got " - f"{actual_dim}." + f"Grid coordinates must have dimension {dim}, got " f"{actual_dim}." ) if grid is None or (grid.ndim > 1 and grid.shape[0] < grid.shape[1]): warnings.warn( diff --git a/src/pyrecest/distributions/abstract_manifold_specific_distribution.py b/src/pyrecest/distributions/abstract_manifold_specific_distribution.py index 08e02cbc3..f01f827ae 100644 --- a/src/pyrecest/distributions/abstract_manifold_specific_distribution.py +++ b/src/pyrecest/distributions/abstract_manifold_specific_distribution.py @@ -10,7 +10,9 @@ # pylint: disable=no-name-in-module,no-member,redefined-builtin from pyrecest.backend import empty, int32, int64, log, random, squeeze -_SCALAR_VALUE_ERROR = "Metropolis-Hastings scalar evaluations must return scalar values." +_SCALAR_VALUE_ERROR = ( + "Metropolis-Hastings scalar evaluations must return scalar values." +) def _shape_size(value) -> int | None: diff --git a/src/pyrecest/distributions/circle/circular_grid_distribution.py b/src/pyrecest/distributions/circle/circular_grid_distribution.py index 8ea39d3e0..3da169e61 100644 --- a/src/pyrecest/distributions/circle/circular_grid_distribution.py +++ b/src/pyrecest/distributions/circle/circular_grid_distribution.py @@ -80,7 +80,9 @@ def _pdf_via_sinc(self, xs, sinc_repetitions): lower = int(floor(sinc_repetitions / 2) * grid_size) upper = int(ceil(sinc_repetitions / 2) * grid_size) repetitions = arange(-lower, upper) - sinc_vals = self._matlab_sinc((xs_eval / step_size)[:, None] - repetitions[None, :]) + sinc_vals = self._matlab_sinc( + (xs_eval / step_size)[:, None] - repetitions[None, :] + ) grid_values = ( sqrt(self.grid_values) if self.enforce_pdf_nonnegative else self.grid_values ) diff --git a/src/pyrecest/distributions/circle/wrapped_exponential_distribution.py b/src/pyrecest/distributions/circle/wrapped_exponential_distribution.py index 607881327..58e636267 100644 --- a/src/pyrecest/distributions/circle/wrapped_exponential_distribution.py +++ b/src/pyrecest/distributions/circle/wrapped_exponential_distribution.py @@ -83,11 +83,7 @@ def entropy(self): # distribution on [0, 2*pi). The direct expression evaluates # log1p(-exp(-log_beta)) and divides by 1 - exp(-log_beta), which # suffers catastrophic cancellation for tiny log_beta. - return ( - log(2.0 * pi) - - log_beta**2 / 24.0 - + log_beta**4 / 960.0 - ) + return log(2.0 * pi) - log_beta**2 / 24.0 + log_beta**4 / 960.0 # Use exp(-2*pi*lambda) to avoid overflowing exp(2*pi*lambda) for # concentrated wrapped exponentials. diff --git a/src/pyrecest/distributions/nonperiodic/abstract_hyperrectangular_distribution.py b/src/pyrecest/distributions/nonperiodic/abstract_hyperrectangular_distribution.py index cac6e01c6..612f539cd 100644 --- a/src/pyrecest/distributions/nonperiodic/abstract_hyperrectangular_distribution.py +++ b/src/pyrecest/distributions/nonperiodic/abstract_hyperrectangular_distribution.py @@ -9,7 +9,6 @@ AbstractBoundedNonPeriodicDistribution, ) - _ERROR_SCALAR_PDF_VALUE = ( "pdf must return one finite scalar value per integration point" ) diff --git a/src/pyrecest/evaluation/generate_measurements.py b/src/pyrecest/evaluation/generate_measurements.py index 5ca2a13bd..09cca80de 100644 --- a/src/pyrecest/evaluation/generate_measurements.py +++ b/src/pyrecest/evaluation/generate_measurements.py @@ -146,7 +146,9 @@ def generate_measurements(groundtruth, simulation_config): y = _as_shapely_scalar(curr_groundtruth[..., 1], "groundtruth y") curr_shape = translate(shape, xoff=x, yoff=y) elif curr_groundtruth.shape[-1] == 3: - angle = _as_shapely_scalar(curr_groundtruth[..., 0], "groundtruth angle") + angle = _as_shapely_scalar( + curr_groundtruth[..., 0], "groundtruth angle" + ) x = _as_shapely_scalar(curr_groundtruth[..., 1], "groundtruth x") y = _as_shapely_scalar(curr_groundtruth[..., 2], "groundtruth y") curr_shape = rotate( @@ -160,9 +162,7 @@ def generate_measurements(groundtruth, simulation_config): "Currently only R^2 and SE(2) scenarios are supported." ) if not isinstance(curr_shape, PolygonWithSampling): - curr_shape.__class__ = ( - PolygonWithSampling # Preserve existing subclass swap to add sampling methods - ) + curr_shape.__class__ = PolygonWithSampling # Preserve existing subclass swap to add sampling methods if "n_meas_at_individual_time_step" in simulation_config: if "intensity_lambda" in simulation_config: diff --git a/src/pyrecest/evaluation/pareto.py b/src/pyrecest/evaluation/pareto.py index 680de95e3..6d793c7a2 100644 --- a/src/pyrecest/evaluation/pareto.py +++ b/src/pyrecest/evaluation/pareto.py @@ -361,7 +361,9 @@ def _coerce_numeric(value: Any) -> float: if value_array.shape != () or value_array.dtype.kind in "bSUcMm": return float("nan") scalar = value_array.item() - if isinstance(scalar, (bool, np.bool_, complex, np.complexfloating)) or _is_text_scalar(scalar): + if isinstance( + scalar, (bool, np.bool_, complex, np.complexfloating) + ) or _is_text_scalar(scalar): return float("nan") try: return float(scalar) diff --git a/src/pyrecest/evidence.py b/src/pyrecest/evidence.py index 46b00a575..2435b5eb5 100644 --- a/src/pyrecest/evidence.py +++ b/src/pyrecest/evidence.py @@ -61,7 +61,10 @@ class EvidenceComputationMode: metadata: dict[str, Any] | None = field(default_factory=dict) def __post_init__(self) -> None: - if not isinstance(self.mode, str) or self.mode not in {"full_smoothing", "evidence_only"}: + if not isinstance(self.mode, str) or self.mode not in { + "full_smoothing", + "evidence_only", + }: raise ValueError(f"unknown evidence computation mode {self.mode!r}") return_smoothed = _coerce_bool_flag(self.return_smoothed, "return_smoothed") terminal_posterior = _coerce_bool_flag( diff --git a/src/pyrecest/experimental/dvs/normal_flow.py b/src/pyrecest/experimental/dvs/normal_flow.py index 0d58fb979..0441bf0a5 100644 --- a/src/pyrecest/experimental/dvs/normal_flow.py +++ b/src/pyrecest/experimental/dvs/normal_flow.py @@ -57,7 +57,9 @@ def _has_non_real_numeric_values(value: object, *, allow_bool: bool = False) -> return array.dtype.kind in {"M", "m"} -def _as_finite_real_array(name: str, value: object, *, allow_bool: bool = False) -> np.ndarray: +def _as_finite_real_array( + name: str, value: object, *, allow_bool: bool = False +) -> np.ndarray: message = f"{name} must contain finite real numeric values." if _has_non_real_numeric_values(value, allow_bool=allow_bool): raise ValueError(message) @@ -70,7 +72,9 @@ def _as_finite_real_array(name: str, value: object, *, allow_bool: bool = False) return array -def _as_finite_real_scalar(name: str, value: object, *, allow_bool: bool = False) -> float: +def _as_finite_real_scalar( + name: str, value: object, *, allow_bool: bool = False +) -> float: message = f"{name} must be a finite real scalar." try: array = _as_finite_real_array(name, value, allow_bool=allow_bool) @@ -172,8 +176,12 @@ def infer_polarity_contrast_sign( return normalized tolerance = _as_finite_nonnegative_scalar("zero_tolerance", zero_tolerance) - flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape((-1,)) - polarities = _as_finite_real_array("event_polarities", event_polarities, allow_bool=True).reshape((-1,)) + flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape( + (-1,) + ) + polarities = _as_finite_real_array( + "event_polarities", event_polarities, allow_bool=True + ).reshape((-1,)) if flows.shape != polarities.shape: raise ValueError("event_polarities must have one value per signed normal flow") @@ -234,7 +242,9 @@ def polarity_weight_for_signed_flow( zero_tolerance=1e-12, ) -> float: """Return a multiplicative reliability weight from polarity consistency.""" - mismatch_weight = _as_unit_interval_scalar("polarity_mismatch_weight", polarity_mismatch_weight) + mismatch_weight = _as_unit_interval_scalar( + "polarity_mismatch_weight", polarity_mismatch_weight + ) consistency = polarity_consistency_for_signed_flow( signed_normal_flow_value, event_polarity, @@ -257,9 +267,15 @@ def polarity_weights_for_signed_flows( zero_tolerance=1e-12, ) -> np.ndarray: """Return polarity reliability weights for a batch of signed-flow samples.""" - mismatch_weight = _as_unit_interval_scalar("polarity_mismatch_weight", polarity_mismatch_weight) - flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape((-1,)) - polarities = _as_finite_real_array("event_polarities", event_polarities, allow_bool=True).reshape((-1,)) + mismatch_weight = _as_unit_interval_scalar( + "polarity_mismatch_weight", polarity_mismatch_weight + ) + flows = _as_finite_real_array("signed_normal_flows", signed_normal_flows).reshape( + (-1,) + ) + polarities = _as_finite_real_array( + "event_polarities", event_polarities, allow_bool=True + ).reshape((-1,)) resolved_sign = infer_polarity_contrast_sign( flows, polarities, diff --git a/src/pyrecest/filters/measurement_reliability.py b/src/pyrecest/filters/measurement_reliability.py index 67bd9b335..f356836f4 100644 --- a/src/pyrecest/filters/measurement_reliability.py +++ b/src/pyrecest/filters/measurement_reliability.py @@ -48,7 +48,10 @@ def _is_real_numeric_object_value(value) -> bool: if kind is not None: return kind in {"i", "u", "f"} dtype_name = str(dtype).lower() - if any(token in dtype_name for token in ("bool", "complex", "str", "string", "object")): + if any( + token in dtype_name + for token in ("bool", "complex", "str", "string", "object") + ): return False if "float" in dtype_name or "int" in dtype_name: return True @@ -81,7 +84,9 @@ def _has_real_numeric_dtype(value) -> bool: if kind is not None: return kind in {"i", "u", "f"} dtype_name = str(dtype).lower() - if any(token in dtype_name for token in ("bool", "complex", "str", "string", "object")): + if any( + token in dtype_name for token in ("bool", "complex", "str", "string", "object") + ): return False return "float" in dtype_name or "int" in dtype_name @@ -236,4 +241,4 @@ def normalize_measurement_noise_covariances( shared_noise = as_covariance_matrix(noise, measurement_dim, name) if n_measurements == 0: return zeros(empty_shape) - return stack([shared_noise for _ in range(n_measurements)]) \ No newline at end of file + return stack([shared_noise for _ in range(n_measurements)]) diff --git a/src/pyrecest/filters/online_time_offset_estimator.py b/src/pyrecest/filters/online_time_offset_estimator.py index ccc3916b9..c93838082 100644 --- a/src/pyrecest/filters/online_time_offset_estimator.py +++ b/src/pyrecest/filters/online_time_offset_estimator.py @@ -7,7 +7,6 @@ import numpy as np - _UNSUPPORTED_NUMERIC_KINDS = {"b", "S", "U", "c", "M", "m"} _UNSUPPORTED_SCALAR_TYPES = ( type(None), diff --git a/src/pyrecest/filters/wrapped_normal_filter.py b/src/pyrecest/filters/wrapped_normal_filter.py index 927ea9088..be919b71b 100644 --- a/src/pyrecest/filters/wrapped_normal_filter.py +++ b/src/pyrecest/filters/wrapped_normal_filter.py @@ -8,7 +8,6 @@ from .abstract_filter import AbstractFilter from .manifold_mixins import CircularFilterMixin - _PROGRESSIVE_TAU_MESSAGE = "tau must be a positive finite scalar" diff --git a/src/pyrecest/models/__init__.py b/src/pyrecest/models/__init__.py index 315b7e90a..2ff3499c0 100644 --- a/src/pyrecest/models/__init__.py +++ b/src/pyrecest/models/__init__.py @@ -4,6 +4,7 @@ and capability-oriented so filters can opt into the pieces they need. """ +from ._sampleable_transition_validation import install_sampleable_transition_validation from ._validated_motion_models import nearly_coordinated_turn_model from .adapters import ( LinearMeasurementArguments, @@ -36,7 +37,6 @@ SupportsTransitionDensity, SupportsTransitionSampling, ) -from ._sampleable_transition_validation import install_sampleable_transition_validation install_sampleable_transition_validation() diff --git a/src/pyrecest/models/motion_models.py b/src/pyrecest/models/motion_models.py index 740315e39..929959010 100644 --- a/src/pyrecest/models/motion_models.py +++ b/src/pyrecest/models/motion_models.py @@ -67,7 +67,10 @@ def _is_text_bool_or_complex(value: Any) -> bool: def _as_scalar_float(value: Any, name: str) -> float: value_array = np.asarray(value) - if value_array.shape != () or value_array.dtype.kind in _REJECTED_NUMERIC_ARRAY_KINDS: + if ( + value_array.shape != () + or value_array.dtype.kind in _REJECTED_NUMERIC_ARRAY_KINDS + ): raise ValueError(f"{name} must be a scalar number") scalar_value = value_array.item() if _is_text_bool_or_complex(scalar_value): diff --git a/src/pyrecest/models/sensor_models.py b/src/pyrecest/models/sensor_models.py index 9162f7469..59d5bf65e 100644 --- a/src/pyrecest/models/sensor_models.py +++ b/src/pyrecest/models/sensor_models.py @@ -47,7 +47,10 @@ def _state_vector(state): def _as_scalar_float(value: Any, name: str) -> float: value_array = np.asarray(value) - if value_array.shape != () or value_array.dtype.kind in _REJECTED_SCALAR_ARRAY_KINDS: + if ( + value_array.shape != () + or value_array.dtype.kind in _REJECTED_SCALAR_ARRAY_KINDS + ): raise ValueError(f"{name} must be a scalar number") scalar_value = value_array.item() if isinstance(scalar_value, _TEXT_OR_BOOL_SCALAR_TYPES) or isinstance( diff --git a/src/pyrecest/models/weak_measurement.py b/src/pyrecest/models/weak_measurement.py index 145a8b31f..f94771d9f 100644 --- a/src/pyrecest/models/weak_measurement.py +++ b/src/pyrecest/models/weak_measurement.py @@ -225,9 +225,7 @@ def _contains_non_real_numeric_values(value: Any) -> bool: return True if array.dtype.kind != "O": return False - return any( - isinstance(item, _NON_REAL_NUMERIC_SCALAR_TYPES) for item in array.flat - ) + return any(isinstance(item, _NON_REAL_NUMERIC_SCALAR_TYPES) for item in array.flat) def _positive_int(value: int, name: str) -> int: diff --git a/src/pyrecest/tracking/innovation_diagnostics.py b/src/pyrecest/tracking/innovation_diagnostics.py index defd54fb1..d4fb0c36b 100644 --- a/src/pyrecest/tracking/innovation_diagnostics.py +++ b/src/pyrecest/tracking/innovation_diagnostics.py @@ -369,7 +369,9 @@ def _as_finite_real_array(value: Any, name: str) -> np.ndarray: if raw_values.dtype == np.bool_ or raw_values.dtype.kind in "USbcMm": raise ValueError(message) - if _contains_values_of_type(value, _INVALID_REAL_NUMERIC_TYPES) or _contains_values_of_type(raw_values, _INVALID_REAL_NUMERIC_TYPES): + if _contains_values_of_type( + value, _INVALID_REAL_NUMERIC_TYPES + ) or _contains_values_of_type(raw_values, _INVALID_REAL_NUMERIC_TYPES): raise ValueError(message) try: diff --git a/src/pyrecest/utils/multisession_assignment_observation_costs.py b/src/pyrecest/utils/multisession_assignment_observation_costs.py index 1c2c0dda3..4a9772b80 100644 --- a/src/pyrecest/utils/multisession_assignment_observation_costs.py +++ b/src/pyrecest/utils/multisession_assignment_observation_costs.py @@ -72,8 +72,7 @@ def _contains_non_real_cost(values: np.ndarray) -> bool: return True if values.dtype == object: return any( - isinstance(item, _INVALID_COST_SCALAR_TYPES) - for item in values.reshape(-1) + isinstance(item, _INVALID_COST_SCALAR_TYPES) for item in values.reshape(-1) ) return False @@ -357,4 +356,4 @@ def _transform_pairwise_costs( return transformed -__all__ = ["solve_multisession_assignment_with_observation_costs"] \ No newline at end of file +__all__ = ["solve_multisession_assignment_with_observation_costs"] diff --git a/src/pyrecest/utils/multisession_assignment_score.py b/src/pyrecest/utils/multisession_assignment_score.py index 4c17e4af2..ab87602b9 100644 --- a/src/pyrecest/utils/multisession_assignment_score.py +++ b/src/pyrecest/utils/multisession_assignment_score.py @@ -34,7 +34,6 @@ solve_multisession_assignment, ) - _INVALID_SCORE_SCALAR_TYPES = ( type(None), bool, diff --git a/tests/backend/test_numpy_multinomial_real_pvals.py b/tests/backend/test_numpy_multinomial_real_pvals.py index 16a1033db..45897d704 100644 --- a/tests/backend/test_numpy_multinomial_real_pvals.py +++ b/tests/backend/test_numpy_multinomial_real_pvals.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest._backend.numpy import random diff --git a/tests/backend_support/test_pytorch_rotation_stub_contract.py b/tests/backend_support/test_pytorch_rotation_stub_contract.py index aed8e4ad7..af52dc650 100644 --- a/tests/backend_support/test_pytorch_rotation_stub_contract.py +++ b/tests/backend_support/test_pytorch_rotation_stub_contract.py @@ -15,9 +15,7 @@ def test_pytorch_rotation_stub_exposes_matrix_methods_with_backend_error(): from pyrecest._backend.pytorch.spatial import Rotation with pytest.raises(RuntimeError, match="Rotation.from_matrix is not supported"): - Rotation.from_matrix( - [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] - ) + Rotation.from_matrix([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) rotation = object.__new__(Rotation) with pytest.raises(RuntimeError, match="Rotation.as_matrix is not supported"): diff --git a/tests/calibration/test_bias_complex_input_validation.py b/tests/calibration/test_bias_complex_input_validation.py index 476ea9579..77ea9456f 100644 --- a/tests/calibration/test_bias_complex_input_validation.py +++ b/tests/calibration/test_bias_complex_input_validation.py @@ -1,7 +1,6 @@ import unittest import numpy as np - from pyrecest.calibration.bias import ( SensorBiasCorrectionModel, make_bias_training_examples, diff --git a/tests/calibration/test_time_offset_aggregate_validation.py b/tests/calibration/test_time_offset_aggregate_validation.py index f3c8b16cf..139daad35 100644 --- a/tests/calibration/test_time_offset_aggregate_validation.py +++ b/tests/calibration/test_time_offset_aggregate_validation.py @@ -2,7 +2,6 @@ import numpy as np import numpy.testing as npt - from pyrecest.calibration import aggregate_time_offset_sweeps from pyrecest.calibration.time_offset import ( aggregate_time_offset_sweeps as aggregate_time_offset_sweeps_from_module, diff --git a/tests/calibration/test_time_offset_metric_validation.py b/tests/calibration/test_time_offset_metric_validation.py index e62ffd625..8da70fb87 100644 --- a/tests/calibration/test_time_offset_metric_validation.py +++ b/tests/calibration/test_time_offset_metric_validation.py @@ -2,7 +2,6 @@ import numpy as np import numpy.testing as npt - from pyrecest.calibration.time_offset import fit_time_offset diff --git a/tests/distributions/test_abstract_grid_distribution.py b/tests/distributions/test_abstract_grid_distribution.py index 9044ca230..1180f4005 100644 --- a/tests/distributions/test_abstract_grid_distribution.py +++ b/tests/distributions/test_abstract_grid_distribution.py @@ -49,15 +49,11 @@ def test_constructor_rejects_grid_dimension_mismatch(self): def test_constructor_rejects_one_dimensional_grid_for_multidimensional_space(self): with self.assertRaisesRegex(ValueError, "dimension 2"): - DummyGridDistribution( - array([1.0, 1.0]), grid=array([0.0, 1.0]), dim=2 - ) + DummyGridDistribution(array([1.0, 1.0]), grid=array([0.0, 1.0]), dim=2) def test_constructor_rejects_higher_rank_grid_coordinates(self): with self.assertRaisesRegex(ValueError, "one- or two-dimensional"): - DummyGridDistribution( - array([1.0, 1.0]), grid=array([[[0.0]], [[1.0]]]) - ) + DummyGridDistribution(array([1.0, 1.0]), grid=array([[[0.0]], [[1.0]]])) def test_integrate_rejects_custom_boundaries(self): dist = DummyGridDistribution(array([1.0, 1.0])) diff --git a/tests/distributions/test_circular_grid_scalar_sinc_pdf.py b/tests/distributions/test_circular_grid_scalar_sinc_pdf.py index c4b16acfb..bac1b0e89 100644 --- a/tests/distributions/test_circular_grid_scalar_sinc_pdf.py +++ b/tests/distributions/test_circular_grid_scalar_sinc_pdf.py @@ -1,7 +1,9 @@ import unittest from pyrecest.backend import array, pi -from pyrecest.distributions.circle.circular_grid_distribution import CircularGridDistribution +from pyrecest.distributions.circle.circular_grid_distribution import ( + CircularGridDistribution, +) class CircularGridScalarSincPdfTest(unittest.TestCase): diff --git a/tests/distributions/test_circular_mixture.py b/tests/distributions/test_circular_mixture.py index b9329f299..c7071086e 100644 --- a/tests/distributions/test_circular_mixture.py +++ b/tests/distributions/test_circular_mixture.py @@ -2,7 +2,6 @@ import numpy as np import numpy.testing as npt - import pyrecest.backend # pylint: disable=no-name-in-module,no-member diff --git a/tests/distributions/test_hyperrectangular_integration_validation.py b/tests/distributions/test_hyperrectangular_integration_validation.py index 9e0236de3..3a96c2075 100644 --- a/tests/distributions/test_hyperrectangular_integration_validation.py +++ b/tests/distributions/test_hyperrectangular_integration_validation.py @@ -2,7 +2,6 @@ import numpy as np import pytest - from pyrecest.distributions.nonperiodic.hyperrectangular_uniform_distribution import ( HyperrectangularUniformDistribution, ) diff --git a/tests/distributions/test_linear_mixture.py b/tests/distributions/test_linear_mixture.py index 22d8a9d0b..8963e7728 100644 --- a/tests/distributions/test_linear_mixture.py +++ b/tests/distributions/test_linear_mixture.py @@ -3,7 +3,6 @@ import numpy as np import numpy.testing as npt - import pyrecest.backend # pylint: disable=no-name-in-module,no-member diff --git a/tests/evaluation/test_distance_registry_normalization.py b/tests/evaluation/test_distance_registry_normalization.py index 08f5b689e..f96cf638a 100644 --- a/tests/evaluation/test_distance_registry_normalization.py +++ b/tests/evaluation/test_distance_registry_normalization.py @@ -14,5 +14,13 @@ def test_custom_distance_registry_strips_names_for_registration_and_lookup(): assert "unit-test-distance-trimmed" in available_distance_functions() assert " unit-test-distance-trimmed " not in available_distance_functions() - assert get_distance_function("unit-test-distance-trimmed")(array([0.0]), array([1.0])) == 13.0 - assert get_distance_function(" unit-test-distance-trimmed ")(array([0.0]), array([1.0])) == 13.0 + assert ( + get_distance_function("unit-test-distance-trimmed")(array([0.0]), array([1.0])) + == 13.0 + ) + assert ( + get_distance_function(" unit-test-distance-trimmed ")( + array([0.0]), array([1.0]) + ) + == 13.0 + ) diff --git a/tests/evaluation/test_pareto_objective_scalar_validation.py b/tests/evaluation/test_pareto_objective_scalar_validation.py index 251ce28e1..a79261f86 100644 --- a/tests/evaluation/test_pareto_objective_scalar_validation.py +++ b/tests/evaluation/test_pareto_objective_scalar_validation.py @@ -1,7 +1,6 @@ from __future__ import annotations import pandas as pd - from pyrecest.evaluation import is_pareto_front, pareto_front_indices, record_dominates diff --git a/tests/experimental/test_dvs_event_likelihood_counts.py b/tests/experimental/test_dvs_event_likelihood_counts.py index 731af47dd..067223481 100644 --- a/tests/experimental/test_dvs_event_likelihood_counts.py +++ b/tests/experimental/test_dvs_event_likelihood_counts.py @@ -1,7 +1,6 @@ import unittest import numpy as np - from pyrecest.experimental.dvs.event_likelihood import PointProcessUpdateConfig diff --git a/tests/experimental/test_dvs_normal_flow.py b/tests/experimental/test_dvs_normal_flow.py index dc965927a..613445b06 100644 --- a/tests/experimental/test_dvs_normal_flow.py +++ b/tests/experimental/test_dvs_normal_flow.py @@ -24,7 +24,9 @@ def test_signed_scalar_sign_rejects_invalid_values(bad_value): signed_scalar_sign(bad_value) -@pytest.mark.parametrize("bad_tolerance", [-1e-12, np.nan, np.inf, "1e-12", 1e-12 + 0.0j]) +@pytest.mark.parametrize( + "bad_tolerance", [-1e-12, np.nan, np.inf, "1e-12", 1e-12 + 0.0j] +) def test_signed_scalar_sign_rejects_invalid_zero_tolerance(bad_tolerance): with pytest.raises(ValueError, match="zero_tolerance"): signed_scalar_sign(1.0, zero_tolerance=bad_tolerance) @@ -108,7 +110,9 @@ def test_polarity_consistency_and_weight_for_signed_flow(): @pytest.mark.parametrize("bad_weight", [np.nan, np.inf, "0.5", 0.5 + 0.0j, -0.1, 1.1]) def test_polarity_weight_rejects_invalid_mismatch_weight(bad_weight): with pytest.raises(ValueError, match="polarity_mismatch_weight"): - polarity_weight_for_signed_flow(1.0, 0.0, 1.0, polarity_mismatch_weight=bad_weight) + polarity_weight_for_signed_flow( + 1.0, 0.0, 1.0, polarity_mismatch_weight=bad_weight + ) def test_polarity_weights_for_signed_flows_resolves_batch_sign(): @@ -122,4 +126,6 @@ def test_polarity_weights_for_signed_flows_resolves_batch_sign(): def test_polarity_weights_for_signed_flows_rejects_invalid_mismatch_weight_when_disabled(): with pytest.raises(ValueError, match="polarity_mismatch_weight"): - polarity_weights_for_signed_flows([1.0], [1.0], None, polarity_mismatch_weight=np.nan) + polarity_weights_for_signed_flows( + [1.0], [1.0], None, polarity_mismatch_weight=np.nan + ) diff --git a/tests/filters/test_abstract_extended_object_tracker.py b/tests/filters/test_abstract_extended_object_tracker.py index 227854ce7..ece376f35 100644 --- a/tests/filters/test_abstract_extended_object_tracker.py +++ b/tests/filters/test_abstract_extended_object_tracker.py @@ -5,7 +5,9 @@ # pylint: disable=no-name-in-module,no-member from pyrecest.backend import array, eye -from pyrecest.filters.abstract_extended_object_tracker import AbstractExtendedObjectTracker +from pyrecest.filters.abstract_extended_object_tracker import ( + AbstractExtendedObjectTracker, +) class _MinimalExtendedObjectTracker(AbstractExtendedObjectTracker): diff --git a/tests/filters/test_gaussian_mixture_phd_filter_dimensions.py b/tests/filters/test_gaussian_mixture_phd_filter_dimensions.py index ca3eda81a..77060b3e2 100644 --- a/tests/filters/test_gaussian_mixture_phd_filter_dimensions.py +++ b/tests/filters/test_gaussian_mixture_phd_filter_dimensions.py @@ -12,7 +12,9 @@ ) class TestGaussianMixturePHDFilterDimensions(unittest.TestCase): def test_constructor_rejects_mismatched_birth_dimension(self): - with self.assertRaisesRegex(ValueError, "Birth components must have dimension 2"): + with self.assertRaisesRegex( + ValueError, "Birth components must have dimension 2" + ): GaussianMixturePHDFilter( initial_components=[GaussianDistribution(array([0.0, 0.0]), eye(2))], initial_weights=array([0.8]), @@ -32,7 +34,9 @@ def test_set_birth_model_rejects_mismatched_dimension(self): log_posterior_estimates=False, ) - with self.assertRaisesRegex(ValueError, "Birth components must have dimension 2"): + with self.assertRaisesRegex( + ValueError, "Birth components must have dimension 2" + ): tracker.set_birth_model( [GaussianDistribution(array([5.0, 5.0, 0.0]), eye(3))], array([0.2]), @@ -46,7 +50,9 @@ def test_predict_linear_rejects_mismatched_temporary_birth_dimension(self): log_posterior_estimates=False, ) - with self.assertRaisesRegex(ValueError, "Birth components must have dimension 2"): + with self.assertRaisesRegex( + ValueError, "Birth components must have dimension 2" + ): tracker.predict_linear( eye(2), 0.1 * eye(2), diff --git a/tests/filters/test_gnn_pairwise_cost_validation.py b/tests/filters/test_gnn_pairwise_cost_validation.py index 2c58add7a..2d0ee463f 100644 --- a/tests/filters/test_gnn_pairwise_cost_validation.py +++ b/tests/filters/test_gnn_pairwise_cost_validation.py @@ -1,7 +1,6 @@ import unittest import numpy as np - from pyrecest.filters.global_nearest_neighbor import GlobalNearestNeighbor diff --git a/tests/filters/test_measurement_reliability_object_weights.py b/tests/filters/test_measurement_reliability_object_weights.py index 366559fd4..df98d9b32 100644 --- a/tests/filters/test_measurement_reliability_object_weights.py +++ b/tests/filters/test_measurement_reliability_object_weights.py @@ -34,4 +34,4 @@ def test_object_weight_inputs_still_reject_non_real_values(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/filters/test_measurement_scoring.py b/tests/filters/test_measurement_scoring.py index 083dbcf58..bfd63d276 100644 --- a/tests/filters/test_measurement_scoring.py +++ b/tests/filters/test_measurement_scoring.py @@ -1,5 +1,4 @@ import numpy as np - from pyrecest.filters.measurement_scoring import MeasurementScore diff --git a/tests/filters/test_update_diagnostics_validation.py b/tests/filters/test_update_diagnostics_validation.py index cb02e33bd..4bacc3234 100644 --- a/tests/filters/test_update_diagnostics_validation.py +++ b/tests/filters/test_update_diagnostics_validation.py @@ -30,7 +30,9 @@ def test_active_measurement_indices_must_fit_measurement_count(): ValueError, match="active_measurement_indices.*measurement_count", ): - MeasurementUpdateDiagnostics(active_measurement_indices=(0, 2), measurement_count=2) + MeasurementUpdateDiagnostics( + active_measurement_indices=(0, 2), measurement_count=2 + ) def test_measurement_count_rejects_non_integer_values(): diff --git a/tests/models/test_linear_gaussian_models.py b/tests/models/test_linear_gaussian_models.py index 37df6ae28..dc19c8369 100644 --- a/tests/models/test_linear_gaussian_models.py +++ b/tests/models/test_linear_gaussian_models.py @@ -70,7 +70,17 @@ def test_identity_models(self): ) def test_identity_models_reject_invalid_dimensions(self): - invalid_dims = (True, False, 0, -1, 1.5, float("inf"), "2", b"2", bytearray(b"2")) + invalid_dims = ( + True, + False, + 0, + -1, + 1.5, + float("inf"), + "2", + b"2", + bytearray(b"2"), + ) for dim in invalid_dims: with self.subTest(model="transition", dim=dim): diff --git a/tests/models/test_sampleable_transition_vectorization_validation.py b/tests/models/test_sampleable_transition_vectorization_validation.py index fe59cae73..1a57a1f8f 100644 --- a/tests/models/test_sampleable_transition_vectorization_validation.py +++ b/tests/models/test_sampleable_transition_vectorization_validation.py @@ -1,5 +1,4 @@ import pytest - from pyrecest.models.likelihood import SampleableTransitionModel diff --git a/tests/models/test_weak_measurement_nonreal_inputs.py b/tests/models/test_weak_measurement_nonreal_inputs.py index 0fa01717c..55e470b7b 100644 --- a/tests/models/test_weak_measurement_nonreal_inputs.py +++ b/tests/models/test_weak_measurement_nonreal_inputs.py @@ -4,7 +4,6 @@ import numpy as np import pytest - from pyrecest.models import ( MaskedLinearMeasurementModel, WeakDimensionMeasurementModel, @@ -12,7 +11,6 @@ diagonal_measurement_covariance, ) - _NON_REAL_STD_CASES = ( [1.0 + 0.0j], [np.complex128(1.0 + 0.0j)], diff --git a/tests/sampling/test_support_points.py b/tests/sampling/test_support_points.py index 6ccb4a79f..dbaf7da67 100644 --- a/tests/sampling/test_support_points.py +++ b/tests/sampling/test_support_points.py @@ -52,7 +52,9 @@ def test_support_points_from_axis_offsets_supports_batches() -> None: def test_support_points_from_axis_offsets_can_omit_center() -> None: - support = support_points_from_axis_offsets([0.0, 0.0], np.eye(2), include_center=False) + support = support_points_from_axis_offsets( + [0.0, 0.0], np.eye(2), include_center=False + ) assert support.shape == (4, 2) assert np.allclose(support, [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]) diff --git a/tests/test_backend_default_dtype_like_values.py b/tests/test_backend_default_dtype_like_values.py index 0f31a41f1..d5727bb2c 100644 --- a/tests/test_backend_default_dtype_like_values.py +++ b/tests/test_backend_default_dtype_like_values.py @@ -1,6 +1,5 @@ import numpy as np import pytest - from pyrecest._backend import numpy as numpy_backend diff --git a/tests/test_backend_sum_contract.py b/tests/test_backend_sum_contract.py index 1732317e2..f76c7e7fe 100644 --- a/tests/test_backend_sum_contract.py +++ b/tests/test_backend_sum_contract.py @@ -30,7 +30,9 @@ def test_convert_to_wider_dtype_matches_numpy_result_type_for_mixed_dtypes(): expected_dtype = np.result_type(np.dtype("int64"), np.dtype("float32")) assert backend.to_numpy(first).dtype == expected_dtype assert backend.to_numpy(second).dtype == expected_dtype - np.testing.assert_allclose(backend.to_numpy(first), np.array([1], dtype=expected_dtype)) + np.testing.assert_allclose( + backend.to_numpy(first), np.array([1], dtype=expected_dtype) + ) np.testing.assert_allclose( backend.to_numpy(second), np.array([1.5], dtype=expected_dtype) ) diff --git a/tests/test_backend_support_public.py b/tests/test_backend_support_public.py index fb678ff48..cd3aeaddd 100644 --- a/tests/test_backend_support_public.py +++ b/tests/test_backend_support_public.py @@ -7,7 +7,6 @@ ) from pyrecest._backend.capabilities import BACKEND_SUPPORT_LEVELS - backend_support_module = import_module("pyrecest.backend_support") diff --git a/tests/test_benchmark_regression_script.py b/tests/test_benchmark_regression_script.py index 41c06189c..889514d4c 100644 --- a/tests/test_benchmark_regression_script.py +++ b/tests/test_benchmark_regression_script.py @@ -52,11 +52,16 @@ def test_benchmark_regression_rejects_boolean_elapsed(tmp_path: Path) -> None: ) assert completed.returncode == 1 - assert "::error::linear_kalman.elapsed_seconds must be numeric, not boolean" in completed.stdout + assert ( + "::error::linear_kalman.elapsed_seconds must be numeric, not boolean" + in completed.stdout + ) assert "Traceback" not in completed.stderr -def test_benchmark_regression_rejects_text_elapsed_without_traceback(tmp_path: Path) -> None: +def test_benchmark_regression_rejects_text_elapsed_without_traceback( + tmp_path: Path, +) -> None: completed = _run_checker( tmp_path, actual_entry={ @@ -74,7 +79,10 @@ def test_benchmark_regression_rejects_text_elapsed_without_traceback(tmp_path: P ) assert completed.returncode == 1 - assert "::error::linear_kalman.elapsed_seconds must be numeric, got '0.01'" in completed.stdout + assert ( + "::error::linear_kalman.elapsed_seconds must be numeric, got '0.01'" + in completed.stdout + ) assert "Traceback" not in completed.stderr @@ -96,11 +104,16 @@ def test_benchmark_regression_rejects_text_iterations(tmp_path: Path) -> None: ) assert completed.returncode == 1 - assert "::error::linear_kalman.iterations must be an integer, got '200'" in completed.stdout + assert ( + "::error::linear_kalman.iterations must be an integer, got '200'" + in completed.stdout + ) assert "Traceback" not in completed.stderr -def test_benchmark_regression_reports_invalid_final_estimate_without_traceback(tmp_path: Path) -> None: +def test_benchmark_regression_reports_invalid_final_estimate_without_traceback( + tmp_path: Path, +) -> None: completed = _run_checker( tmp_path, actual_entry={ @@ -118,5 +131,8 @@ def test_benchmark_regression_reports_invalid_final_estimate_without_traceback(t ) assert completed.returncode == 1 - assert "::error::linear_kalman.final_estimate[1] must be numeric, not boolean" in completed.stdout + assert ( + "::error::linear_kalman.final_estimate[1] must be numeric, not boolean" + in completed.stdout + ) assert "Traceback" not in completed.stderr diff --git a/tests/test_candidate_pruning_temporal_validation.py b/tests/test_candidate_pruning_temporal_validation.py index bacdde405..9bd341044 100644 --- a/tests/test_candidate_pruning_temporal_validation.py +++ b/tests/test_candidate_pruning_temporal_validation.py @@ -1,7 +1,6 @@ import unittest import numpy as np - from pyrecest.utils import ( CandidatePruningConfig, candidate_mask_from_costs, @@ -21,7 +20,9 @@ def test_cost_matrix_rejects_datetime_and_timedelta_entries(self): for function in (candidate_mask_from_costs, prune_pairwise_cost_matrix): for matrix in invalid_matrices: with self.subTest(function=function.__name__, dtype=str(matrix.dtype)): - with self.assertRaisesRegex(ValueError, "cost_matrix must be numeric"): + with self.assertRaisesRegex( + ValueError, "cost_matrix must be numeric" + ): function(matrix) def test_probability_matrix_rejects_datetime_and_timedelta_entries(self): @@ -35,7 +36,9 @@ def test_probability_matrix_rejects_datetime_and_timedelta_entries(self): for probabilities in invalid_probability_matrices: with self.subTest(dtype=str(probabilities.dtype)): - with self.assertRaisesRegex(ValueError, "probability_matrix must be numeric"): + with self.assertRaisesRegex( + ValueError, "probability_matrix must be numeric" + ): candidate_mask_from_costs( np.array([[1.0]]), probability_matrix=probabilities, diff --git a/tests/test_cli_expected_mapping.py b/tests/test_cli_expected_mapping.py index 988b56d0b..60994621e 100644 --- a/tests/test_cli_expected_mapping.py +++ b/tests/test_cli_expected_mapping.py @@ -11,7 +11,9 @@ def test_expected_mapping_reports_nonnumeric_actual_instead_of_raising(): tolerance=1e-8, ) - assert errors == ["metrics.rmse mismatch: expected finite numeric 1.0, got 'not-a-number'"] + assert errors == [ + "metrics.rmse mismatch: expected finite numeric 1.0, got 'not-a-number'" + ] def test_expected_mapping_compares_boolean_expected_values_exactly(): diff --git a/tests/test_dynamic_models.py b/tests/test_dynamic_models.py index ff9aaa955..6566660c0 100644 --- a/tests/test_dynamic_models.py +++ b/tests/test_dynamic_models.py @@ -110,7 +110,10 @@ def test_process_noise_rejects_invalid_parameters(self): {"spectral_density": True, "message": "spectral_density"}, {"spectral_density": "1.0", "message": "spectral_density"}, {"spectral_density": ["1.0", "2.0"], "message": "spectral_density"}, - {"spectral_density": np.array([True, False], dtype=object), "message": "spectral_density"}, + { + "spectral_density": np.array([True, False], dtype=object), + "message": "spectral_density", + }, { "spectral_density": np.array([1.0, np.nan]), "message": "spectral_density", diff --git a/tests/test_multisession_assignment_observation_cost_type_validation.py b/tests/test_multisession_assignment_observation_cost_type_validation.py index d5c1ddeed..b056eee72 100644 --- a/tests/test_multisession_assignment_observation_cost_type_validation.py +++ b/tests/test_multisession_assignment_observation_cost_type_validation.py @@ -3,7 +3,6 @@ import unittest import numpy as np - from pyrecest.backend import ( # pylint: disable=no-name-in-module __backend_name__, array, diff --git a/tests/test_numpy_backend_dynamic_dtype.py b/tests/test_numpy_backend_dynamic_dtype.py index 68a03ca9a..debe92071 100644 --- a/tests/test_numpy_backend_dynamic_dtype.py +++ b/tests/test_numpy_backend_dynamic_dtype.py @@ -1,6 +1,6 @@ import numpy as np - -from pyrecest._backend import _backend_config, numpy as numpy_backend +from pyrecest._backend import _backend_config +from pyrecest._backend import numpy as numpy_backend def test_dynamic_dtype_preserves_explicit_positional_dtype_for_zeros(): @@ -28,6 +28,9 @@ def test_dynamic_dtype_uses_default_dtype_for_omitted_or_none_dtype(): _backend_config.DEFAULT_DTYPE = expected_dtype assert numpy_backend.zeros((2,)).dtype == expected_dtype assert numpy_backend.empty((2,), None).dtype == expected_dtype - assert numpy_backend.linspace(0.0, 1.0, 3, True, False, None).dtype == expected_dtype + assert ( + numpy_backend.linspace(0.0, 1.0, 3, True, False, None).dtype + == expected_dtype + ) finally: _backend_config.DEFAULT_DTYPE = previous_dtype diff --git a/tests/test_numpy_backend_zeros_dtype.py b/tests/test_numpy_backend_zeros_dtype.py index 25019cc9d..2e756a793 100644 --- a/tests/test_numpy_backend_zeros_dtype.py +++ b/tests/test_numpy_backend_zeros_dtype.py @@ -2,7 +2,6 @@ import numpy as np import pytest - from pyrecest._backend import numpy as numpy_backend from tests.support.backend_runner import run_backend_code diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 2787a57ec..d9b989639 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -6,7 +6,13 @@ class ReproducibilityValidationTest(unittest.TestCase): def test_normalize_seed_rejects_text_values(self): - for value in ("1", np.array("1"), bytes([49]), bytearray([49]), np.bytes_(bytes([49]))): + for value in ( + "1", + np.array("1"), + bytes([49]), + bytearray([49]), + np.bytes_(bytes([49])), + ): with self.subTest(value=repr(value)): with self.assertRaisesRegex( ValueError, diff --git a/tests/test_scenario_type_validation.py b/tests/test_scenario_type_validation.py index 719a3f055..e79c167de 100644 --- a/tests/test_scenario_type_validation.py +++ b/tests/test_scenario_type_validation.py @@ -1,5 +1,4 @@ import pytest - from pyrecest import scenarios @@ -27,7 +26,9 @@ def runner(_path): try: assert scenarios.register_scenario_runner(f" {name}\t", runner) is runner assert name in scenarios.available_scenario_types() - assert f" {name}\t" not in scenarios._SCENARIO_RUNNERS # pylint: disable=protected-access + assert ( + f" {name}\t" not in scenarios._SCENARIO_RUNNERS + ) # pylint: disable=protected-access with pytest.raises(ValueError, match="already registered"): scenarios.register_scenario_runner(name, runner) @@ -38,8 +39,7 @@ def runner(_path): def test_run_scenario_rejects_unhashable_scenario_type(tmp_path): scenario_path = tmp_path / "bad_scenario.toml" scenario_path.write_text( - "[scenario]\n" - 'type = ["linear_gaussian"]\n', + "[scenario]\n" 'type = ["linear_gaussian"]\n', encoding="utf-8", ) diff --git a/tests/test_scenarios.py b/tests/test_scenarios.py index 090117cc3..cd926a58a 100644 --- a/tests/test_scenarios.py +++ b/tests/test_scenarios.py @@ -2,7 +2,6 @@ from pathlib import Path import pytest - from pyrecest.scenarios import load_scenario_config, run_scenario SCENARIO = Path("scenarios/linear_gaussian_cv_1d/config.toml") diff --git a/tests/test_track_edit_whatif_empty_split.py b/tests/test_track_edit_whatif_empty_split.py index 1f4819fd6..9c6e9283d 100644 --- a/tests/test_track_edit_whatif_empty_split.py +++ b/tests/test_track_edit_whatif_empty_split.py @@ -1,6 +1,10 @@ """Regression tests for empty split-track edits.""" -from pyrecest.utils.track_edit_whatif import TrackEdit, apply_track_edit, score_track_edit_delta +from pyrecest.utils.track_edit_whatif import ( + TrackEdit, + apply_track_edit, + score_track_edit_delta, +) def test_split_track_rejects_empty_track() -> None: diff --git a/tests/test_track_edit_whatif_index_validation.py b/tests/test_track_edit_whatif_index_validation.py index 169b4024a..037038ce4 100644 --- a/tests/test_track_edit_whatif_index_validation.py +++ b/tests/test_track_edit_whatif_index_validation.py @@ -1,12 +1,17 @@ """Validation tests for track-edit what-if selector indices.""" import pytest - -from pyrecest.utils.track_edit_whatif import TrackEdit, apply_track_edit, score_track_edit_delta +from pyrecest.utils.track_edit_whatif import ( + TrackEdit, + apply_track_edit, + score_track_edit_delta, +) def test_add_link_rejects_invalid_link_fields() -> None: - with pytest.raises(ValueError, match="target_observation must be a non-negative integer"): + with pytest.raises( + ValueError, match="target_observation must be a non-negative integer" + ): apply_track_edit( [[1, None]], TrackEdit( @@ -40,8 +45,14 @@ def test_score_rejects_invalid_session_selectors() -> None: target_observation=2, ) - with pytest.raises(ValueError, match="complete_session_indices must be a non-negative integer"): - score_track_edit_delta([[1, 2]], [[1, 2]], edit, complete_session_indices=[0, 1.5]) + with pytest.raises( + ValueError, match="complete_session_indices must be a non-negative integer" + ): + score_track_edit_delta( + [[1, 2]], [[1, 2]], edit, complete_session_indices=[0, 1.5] + ) - with pytest.raises(ValueError, match="session_pairs must be a non-negative integer"): + with pytest.raises( + ValueError, match="session_pairs must be a non-negative integer" + ): score_track_edit_delta([[1, 2]], [[1, 2]], edit, session_pairs=[(bool(0), 1)]) diff --git a/tests/test_track_edit_whatif_metadata_validation.py b/tests/test_track_edit_whatif_metadata_validation.py index face93ea7..0a4faabba 100644 --- a/tests/test_track_edit_whatif_metadata_validation.py +++ b/tests/test_track_edit_whatif_metadata_validation.py @@ -1,7 +1,6 @@ """Regression tests for track-edit metadata selector validation.""" import pytest - from pyrecest.utils.track_edit_whatif import TrackEdit, apply_track_edit @@ -15,7 +14,9 @@ def test_remove_link_rejects_fractional_occurrence_index() -> None: metadata={"occurrence_index": 0.5}, ) - with pytest.raises(ValueError, match="metadata\['occurrence_index'\] must be an integer"): + with pytest.raises( + ValueError, match="metadata\['occurrence_index'\] must be an integer" + ): apply_track_edit([[1, 2], [1, 2]], edit) @@ -32,7 +33,10 @@ def test_swap_link_rejects_boolean_removed_observation() -> None: }, ) - with pytest.raises(ValueError, match="metadata\['remove_source_observation'\] must be a non-negative integer"): + with pytest.raises( + ValueError, + match="metadata\['remove_source_observation'\] must be a non-negative integer", + ): apply_track_edit([[1, 2]], edit) @@ -43,5 +47,7 @@ def test_merge_tracks_rejects_text_other_track_index() -> None: metadata={"other_track_index": "1"}, ) - with pytest.raises(ValueError, match="metadata\['other_track_index'\] must be an integer"): + with pytest.raises( + ValueError, match="metadata\['other_track_index'\] must be an integer" + ): apply_track_edit([[1, None], [None, 2]], edit) diff --git a/tests/test_track_evaluation_session_validation.py b/tests/test_track_evaluation_session_validation.py index afa57f1f0..70b5793bd 100644 --- a/tests/test_track_evaluation_session_validation.py +++ b/tests/test_track_evaluation_session_validation.py @@ -3,7 +3,6 @@ import unittest import numpy as np - from pyrecest.utils.track_evaluation import ( complete_track_set, score_complete_tracks, diff --git a/tests/tracking/test_innovation_diagnostics.py b/tests/tracking/test_innovation_diagnostics.py index 1c2dae594..bf2011dab 100644 --- a/tests/tracking/test_innovation_diagnostics.py +++ b/tests/tracking/test_innovation_diagnostics.py @@ -89,7 +89,9 @@ def test_innovation_diagnostic_rejects_non_real_residual_values(bad_residual) -> [[None]], ), ) -def test_innovation_diagnostic_rejects_non_real_covariance_values(bad_covariance) -> None: +def test_innovation_diagnostic_rejects_non_real_covariance_values( + bad_covariance, +) -> None: with pytest.raises( ValueError, match="innovation_covariance must contain finite real numeric values", diff --git a/tests/utils/test_association_features_probability_validation.py b/tests/utils/test_association_features_probability_validation.py index 63b21a4bf..2aa20d745 100644 --- a/tests/utils/test_association_features_probability_validation.py +++ b/tests/utils/test_association_features_probability_validation.py @@ -2,7 +2,6 @@ import numpy as np import pytest - from pyrecest.utils.association_features import CalibratedPairwiseAssociationModel @@ -46,7 +45,9 @@ def test_predict_match_probability_rejects_non_real_probability_values(probabili feature_names=["distance"], ) - with pytest.raises(ValueError, match="predicted probabilities must be real numeric"): + with pytest.raises( + ValueError, match="predicted probabilities must be real numeric" + ): model.predict_match_probability(np.array([[0.0], [1.0]])) @@ -56,7 +57,9 @@ def test_predict_proba_rejects_non_real_probabilities_before_class_selection(): feature_names=["distance"], ) - with pytest.raises(ValueError, match="predicted probabilities must be real numeric"): + with pytest.raises( + ValueError, match="predicted probabilities must be real numeric" + ): model.predict_match_probability(np.array([[0.0]])) @@ -66,5 +69,7 @@ def test_pairwise_cost_model_rejects_temporal_cost_values(): feature_names=["distance"], ) - with pytest.raises(ValueError, match="predicted pairwise costs must be real numeric"): + with pytest.raises( + ValueError, match="predicted pairwise costs must be real numeric" + ): model.predict_match_probability(np.array([[0.0]]))