From 72c6aafcec106c43fd945e28eec1726ceab1f1d8 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sun, 17 May 2026 08:49:44 -0400 Subject: [PATCH] Reject NaN values in numeric inputs --- changelog.d/reject-nan-set-input.fixed.md | 1 + policyengine_core/holders/helpers.py | 3 +- policyengine_core/holders/holder.py | 34 ++++++++++-- tests/core/test_holders.py | 63 +++++++++++++++++++++++ 4 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 changelog.d/reject-nan-set-input.fixed.md diff --git a/changelog.d/reject-nan-set-input.fixed.md b/changelog.d/reject-nan-set-input.fixed.md new file mode 100644 index 00000000..95d67ae5 --- /dev/null +++ b/changelog.d/reject-nan-set-input.fixed.md @@ -0,0 +1 @@ +Raised a clear error when numeric simulation inputs contain NaN values. diff --git a/policyengine_core/holders/helpers.py b/policyengine_core/holders/helpers.py index 30e324c6..790f136d 100644 --- a/policyengine_core/holders/helpers.py +++ b/policyengine_core/holders/helpers.py @@ -30,7 +30,7 @@ def set_input_dispatch_by_period(holder: Holder, period: Period, array: ArrayLik To read more about ``set_input`` attributes, check the `documentation `_. """ - array = holder._to_array(array) + array = holder._to_array(array, validate_nan=True) period_size = period.size period_unit = period.unit @@ -70,6 +70,7 @@ def set_input_divide_by_period(holder: Holder, period: Period, array: ArrayLike) """ if not isinstance(array, numpy.ndarray): array = numpy.array(array) + array = holder._to_array(array, validate_nan=True) period_size = period.size period_unit = period.unit diff --git a/policyengine_core/holders/holder.py b/policyengine_core/holders/holder.py index fa2aac5e..f6be8d04 100644 --- a/policyengine_core/holders/holder.py +++ b/policyengine_core/holders/holder.py @@ -250,6 +250,7 @@ def set_input( return warnings.warn(warning_message, Warning) if self.variable.value_type in (float, int) and isinstance(array, str): array = tools.eval_expression(array) + self._raise_if_input_contains_nan(numpy.asarray(array)) simulation = getattr(self, "simulation", None) if simulation is not None: if not hasattr(simulation, "_user_input_keys"): @@ -263,12 +264,29 @@ def set_input( and period.unit != self.variable.definition_period ): return self.variable.set_input(self, period, array) - return self._set(period, array, branch_name) + return self._set(period, array, branch_name, validate_nan=True) finally: if simulation is not None: simulation._user_input_contexts.pop() - def _to_array(self, value: Any) -> ArrayLike: + def _raise_if_input_contains_nan(self, value: ArrayLike) -> None: + if self.variable.value_type not in (float, int): + return + value = numpy.asarray(value) + try: + if value.dtype.kind in ("O", "S", "U"): + value = value.astype(float) + contains_nan = numpy.isnan(value).any() + except (TypeError, ValueError): + return + if contains_nan: + raise ValueError( + 'Unable to set value for variable "{}", as the input contains NaN values.'.format( + self.variable.name, + ) + ) + + def _to_array(self, value: Any, validate_nan: bool = False) -> ArrayLike: if not isinstance(value, numpy.ndarray): value = numpy.asarray(value) if value.ndim == 0: @@ -284,6 +302,8 @@ def _to_array(self, value: Any) -> ArrayLike: self.population.entity.plural, ) ) + if validate_nan: + self._raise_if_input_contains_nan(value) if self.variable.value_type == Enum: original_value = value value = self.variable.possible_values.encode(value) @@ -301,16 +321,22 @@ def _to_array(self, value: Any) -> ArrayLike: value.dtype, ) ) + if validate_nan: + self._raise_if_input_contains_nan(value) return value def _set( - self, period: Period, value: ArrayLike, branch_name: str = "default" + self, + period: Period, + value: ArrayLike, + branch_name: str = "default", + validate_nan: bool = False, ) -> None: simulation = getattr(self, "simulation", None) user_input_contexts = getattr(simulation, "_user_input_contexts", None) if user_input_contexts and branch_name == "default": branch_name = user_input_contexts[-1] - value = self._to_array(value) + value = self._to_array(value, validate_nan=validate_nan) if self.variable.definition_period != periods.ETERNITY: if period is None: raise ValueError( diff --git a/tests/core/test_holders.py b/tests/core/test_holders.py index 94aebe09..59b77c47 100644 --- a/tests/core/test_holders.py +++ b/tests/core/test_holders.py @@ -216,3 +216,66 @@ def test_set_input_float_to_int(single): simulation.person.get_holder("age").set_input(period, age) result = simulation.calculate("age", period) assert result == numpy.asarray([50]) + + +def test__given_nan_float_array__then_set_input_raises_value_error(single): + simulation = single + + with pytest.raises(ValueError, match='variable "salary".*NaN'): + simulation.set_input("salary", period, numpy.asarray([numpy.nan])) + + +def test__given_nan_int_array__then_set_input_raises_value_error(single): + simulation = single + + with pytest.raises(ValueError, match='variable "age".*NaN'): + simulation.set_input("age", period, numpy.asarray([numpy.nan])) + + +def test__given_object_array_containing_nan__then_set_input_raises_value_error( + single, +): + simulation = single + age = numpy.asarray([numpy.nan], dtype=object) + + with pytest.raises(ValueError, match='variable "age".*NaN'): + simulation.set_input("age", period, age) + + +def test__given_nan_yearly_input__then_set_input_divide_by_period_raises_value_error( + single, +): + simulation = single + salary_holder = simulation.person.get_holder("salary") + + with pytest.raises(ValueError, match='variable "salary".*NaN'): + holders.set_input_divide_by_period( + salary_holder, + periods.period(2017), + numpy.asarray([numpy.nan]), + ) + + +def test__given_nan_period_dispatch_input__then_helper_raises_value_error( + single, +): + simulation = single + age_holder = simulation.person.get_holder("age") + + with pytest.raises(ValueError, match='variable "age".*NaN'): + holders.set_input_dispatch_by_period( + age_holder, + periods.period(2017), + numpy.asarray([numpy.nan]), + ) + + +def test__given_nan_cache_value__then_put_in_cache_keeps_internal_write_allowed( + single, +): + simulation = single + salary_holder = simulation.person.get_holder("salary") + + salary_holder.put_in_cache(numpy.asarray([numpy.nan]), period) + + assert numpy.isnan(salary_holder.get_array(period)).all()