From aef40939af710bb39ecba64c7919ef2d5ea95b08 Mon Sep 17 00:00:00 2001 From: Oliver Bristow Date: Fri, 8 May 2026 23:33:17 +0100 Subject: [PATCH] Use StrEnum and dataclass for solver result normalization --- modello.py | 120 ++++++++++++++++++++++++++++++++++++++++++++---- test_modello.py | 38 +++++++++++++++ 2 files changed, 149 insertions(+), 9 deletions(-) diff --git a/modello.py b/modello.py index 8e77728..f210688 100644 --- a/modello.py +++ b/modello.py @@ -1,8 +1,10 @@ #!/usr/bin/env python """Module for symbolic modeling of systems.""" import typing +from dataclasses import dataclass, field +from enum import StrEnum -from sympy import Basic, Dummy, Eq, solve +from sympy import Basic, Dummy, Eq, FiniteSet, linsolve, nonlinsolve, solve, solveset # more verbose path as mypy sees sympy.simplify as a module from sympy.simplify.simplify import simplify @@ -28,6 +30,40 @@ class BoundInstanceDummy(InstanceDummy): """Dummy associated with a Modello instance.""" +class SolverStrategy(StrEnum): + """Constraint solver strategy options.""" + + LEGACY = "legacy" + SET = "set" + + +class SolutionMode(StrEnum): + """How non-unique solutions are handled.""" + + STRICT = "strict" + PERMISSIVE = "permissive" + + +class SolutionKind(StrEnum): + """Normalized solver result kind.""" + + UNIQUE = "unique" + MULTIPLE = "multiple" + PARAMETRIC = "parametric" + + +@dataclass +class SolveResult: + """Normalized result for solver outputs.""" + + kind: SolutionKind + selected: typing.Dict[BoundInstanceDummy, Basic] = field(default_factory=dict) + solutions: typing.List[typing.Dict[BoundInstanceDummy, Basic]] = field( + default_factory=list + ) + sets: typing.Optional[typing.Dict[Basic, Basic]] = None + + class ModelloMetaNamespace(dict): """This is so that Modello class definitions implicitly define symbols.""" @@ -127,7 +163,17 @@ class Modello(ModelloSentinelClass, metaclass=ModelloMeta): ) _modello_class_constraints: typing.Dict[InstanceDummy, Basic] = {} - def __init__(self, name: str, **value_map: Basic) -> None: + def __init__( + self, + name: str, + *, + solver_strategy: SolverStrategy = SolverStrategy.LEGACY, + solution_mode: SolutionMode = SolutionMode.STRICT, + solution_selector: typing.Optional[ + typing.Callable[[typing.List[typing.Dict[BoundInstanceDummy, Basic]]], typing.Dict[BoundInstanceDummy, Basic]] + ] = None, + **value_map: Basic, + ) -> None: """Initialise a model instance and solve for each attribute.""" instance_dummies = { class_dummy: class_dummy.bound(name) @@ -157,13 +203,13 @@ def __init__(self, name: str, **value_map: Basic) -> None: # handy for debugging self._modello_constraints: typing.List[Eq] = constraints - if constraints: - solutions = solve(constraints, particular=True, dict=True) - if len(solutions) != 1: - raise ValueError("%s solutions" % len(solutions)) - solution = solutions[0] - else: - solution = {} + solve_result = self._solve_constraints( + constraints, + solver_strategy=solver_strategy, + solution_mode=solution_mode, + solution_selector=solution_selector, + ) + solution = solve_result.selected for attr, class_dummy in self._modello_namespace.dummies.items(): instance_dummy = instance_dummies[class_dummy] @@ -178,3 +224,59 @@ def __init__(self, name: str, **value_map: Basic) -> None: else: value = instance_dummy setattr(self, attr, value) + + def _solve_constraints( + self, + constraints: typing.List[Eq], + *, + solver_strategy: SolverStrategy, + solution_mode: SolutionMode, + solution_selector: typing.Optional[ + typing.Callable[[typing.List[typing.Dict[BoundInstanceDummy, Basic]]], typing.Dict[BoundInstanceDummy, Basic]] + ], + ) -> SolveResult: + if not constraints: + return SolveResult(kind=SolutionKind.UNIQUE, solutions=[{}]) + + solver_strategy = SolverStrategy(solver_strategy) + solution_mode = SolutionMode(solution_mode) + + if solver_strategy is SolverStrategy.LEGACY: + solutions = solve(constraints, particular=True, dict=True) + elif solver_strategy is SolverStrategy.SET: + symbols = sorted( + set().union(*(eq.free_symbols for eq in constraints)), key=lambda s: s.name + ) + exprs = [eq.lhs - eq.rhs for eq in constraints] + if all(expr.is_polynomial(*symbols) and expr.total_degree() <= 1 for expr in exprs): + set_solution = linsolve(exprs, symbols) + else: + set_solution = nonlinsolve(exprs, symbols) + + if isinstance(set_solution, FiniteSet): + solutions = [dict(zip(symbols, sol)) for sol in set_solution] + else: + # fallback to per-symbol solvesets for symbolic set outputs + solution_sets = { + symbol: solveset(exprs[0], symbol) for symbol in symbols + } + return SolveResult(kind=SolutionKind.PARAMETRIC, sets=solution_sets) + else: + raise ValueError("Unknown solver_strategy: %s" % solver_strategy) + + if len(solutions) == 1: + return SolveResult( + kind=SolutionKind.UNIQUE, selected=solutions[0], solutions=solutions + ) + if len(solutions) > 1: + if solution_selector is not None: + selected = solution_selector(solutions) + return SolveResult( + kind=SolutionKind.MULTIPLE, + selected=selected, + solutions=solutions, + ) + if solution_mode is SolutionMode.STRICT: + raise ValueError("%s solutions" % len(solutions)) + return SolveResult(kind=SolutionKind.MULTIPLE, solutions=solutions) + return SolveResult(kind=SolutionKind.PARAMETRIC, solutions=solutions) diff --git a/test_modello.py b/test_modello.py index 17fbf9f..52bed82 100644 --- a/test_modello.py +++ b/test_modello.py @@ -38,3 +38,41 @@ class ExampleC(ExampleA, ExampleB): assert ExampleC.conflicted == ExampleB.conflicted assert ExampleC.conflicted != ExampleA.conflicted + + +def test_solver_unique_solution(): + class Linear(Modello): + x = InstanceDummy("x") + y = InstanceDummy("y") + total = x + y + + instance = Linear("L", x=2, total=5) + assert instance.y == 3 + + +def test_solver_multiple_solution_permissive_and_selector(): + class Branches(Modello): + x = InstanceDummy("x") + y = x**2 + + permissive = Branches("B1", y=4, solution_mode="permissive") + assert isinstance(permissive.x, BoundInstanceDummy) + + selected = Branches( + "B2", + y=4, + solution_selector=lambda solutions: max( + solutions, key=lambda sol: list(sol.values())[0] + ), + ) + assert selected.x == 2 + + +def test_solver_underdetermined_set_strategy(): + class Under(Modello): + x = InstanceDummy("x") + y = x + + instance = Under("U", solver_strategy="set", solution_mode="permissive") + assert isinstance(instance.x, BoundInstanceDummy) + assert instance.y == instance.x