Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 111 additions & 9 deletions modello.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#!/usr/bin/env python
"""Module for symbolic modeling of systems."""
import typing
from dataclasses import dataclass, field
from enum import StrEnum

from sympy import Basic, Dummy, Eq, solve
from sympy import Basic, Dummy, Eq, FiniteSet, linsolve, nonlinsolve, solve, solveset
# more verbose path as mypy sees sympy.simplify as a module
from sympy.simplify.simplify import simplify

Expand All @@ -28,6 +30,40 @@ class BoundInstanceDummy(InstanceDummy):
"""Dummy associated with a Modello instance."""


class SolverStrategy(StrEnum):
"""Constraint solver strategy options."""

LEGACY = "legacy"
SET = "set"


class SolutionMode(StrEnum):
"""How non-unique solutions are handled."""

STRICT = "strict"
PERMISSIVE = "permissive"


class SolutionKind(StrEnum):
"""Normalized solver result kind."""

UNIQUE = "unique"
MULTIPLE = "multiple"
PARAMETRIC = "parametric"


@dataclass
class SolveResult:
"""Normalized result for solver outputs."""

kind: SolutionKind
selected: typing.Dict[BoundInstanceDummy, Basic] = field(default_factory=dict)
solutions: typing.List[typing.Dict[BoundInstanceDummy, Basic]] = field(
default_factory=list
)
sets: typing.Optional[typing.Dict[Basic, Basic]] = None


class ModelloMetaNamespace(dict):
"""This is so that Modello class definitions implicitly define symbols."""

Expand Down Expand Up @@ -127,7 +163,17 @@ class Modello(ModelloSentinelClass, metaclass=ModelloMeta):
)
_modello_class_constraints: typing.Dict[InstanceDummy, Basic] = {}

def __init__(self, name: str, **value_map: Basic) -> None:
def __init__(
self,
name: str,
*,
solver_strategy: SolverStrategy = SolverStrategy.LEGACY,
solution_mode: SolutionMode = SolutionMode.STRICT,
solution_selector: typing.Optional[
typing.Callable[[typing.List[typing.Dict[BoundInstanceDummy, Basic]]], typing.Dict[BoundInstanceDummy, Basic]]
] = None,
**value_map: Basic,
) -> None:
"""Initialise a model instance and solve for each attribute."""
instance_dummies = {
class_dummy: class_dummy.bound(name)
Expand Down Expand Up @@ -157,13 +203,13 @@ def __init__(self, name: str, **value_map: Basic) -> None:
# handy for debugging
self._modello_constraints: typing.List[Eq] = constraints

if constraints:
solutions = solve(constraints, particular=True, dict=True)
if len(solutions) != 1:
raise ValueError("%s solutions" % len(solutions))
solution = solutions[0]
else:
solution = {}
solve_result = self._solve_constraints(
constraints,
solver_strategy=solver_strategy,
solution_mode=solution_mode,
solution_selector=solution_selector,
)
solution = solve_result.selected

for attr, class_dummy in self._modello_namespace.dummies.items():
instance_dummy = instance_dummies[class_dummy]
Expand All @@ -178,3 +224,59 @@ def __init__(self, name: str, **value_map: Basic) -> None:
else:
value = instance_dummy
setattr(self, attr, value)

def _solve_constraints(
self,
constraints: typing.List[Eq],
*,
solver_strategy: SolverStrategy,
solution_mode: SolutionMode,
solution_selector: typing.Optional[
typing.Callable[[typing.List[typing.Dict[BoundInstanceDummy, Basic]]], typing.Dict[BoundInstanceDummy, Basic]]
],
) -> SolveResult:
if not constraints:
return SolveResult(kind=SolutionKind.UNIQUE, solutions=[{}])

solver_strategy = SolverStrategy(solver_strategy)
solution_mode = SolutionMode(solution_mode)

if solver_strategy is SolverStrategy.LEGACY:
solutions = solve(constraints, particular=True, dict=True)
elif solver_strategy is SolverStrategy.SET:
symbols = sorted(
set().union(*(eq.free_symbols for eq in constraints)), key=lambda s: s.name
)
exprs = [eq.lhs - eq.rhs for eq in constraints]
if all(expr.is_polynomial(*symbols) and expr.total_degree() <= 1 for expr in exprs):
set_solution = linsolve(exprs, symbols)
else:
set_solution = nonlinsolve(exprs, symbols)

if isinstance(set_solution, FiniteSet):
solutions = [dict(zip(symbols, sol)) for sol in set_solution]
else:
# fallback to per-symbol solvesets for symbolic set outputs
solution_sets = {
symbol: solveset(exprs[0], symbol) for symbol in symbols
}
return SolveResult(kind=SolutionKind.PARAMETRIC, sets=solution_sets)
else:
raise ValueError("Unknown solver_strategy: %s" % solver_strategy)

if len(solutions) == 1:
return SolveResult(
kind=SolutionKind.UNIQUE, selected=solutions[0], solutions=solutions
)
if len(solutions) > 1:
if solution_selector is not None:
selected = solution_selector(solutions)
return SolveResult(
kind=SolutionKind.MULTIPLE,
selected=selected,
solutions=solutions,
)
if solution_mode is SolutionMode.STRICT:
raise ValueError("%s solutions" % len(solutions))
return SolveResult(kind=SolutionKind.MULTIPLE, solutions=solutions)
return SolveResult(kind=SolutionKind.PARAMETRIC, solutions=solutions)
38 changes: 38 additions & 0 deletions test_modello.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,41 @@ class ExampleC(ExampleA, ExampleB):

assert ExampleC.conflicted == ExampleB.conflicted
assert ExampleC.conflicted != ExampleA.conflicted


def test_solver_unique_solution():
class Linear(Modello):
x = InstanceDummy("x")
y = InstanceDummy("y")
total = x + y

instance = Linear("L", x=2, total=5)
assert instance.y == 3


def test_solver_multiple_solution_permissive_and_selector():
class Branches(Modello):
x = InstanceDummy("x")
y = x**2

permissive = Branches("B1", y=4, solution_mode="permissive")
assert isinstance(permissive.x, BoundInstanceDummy)

selected = Branches(
"B2",
y=4,
solution_selector=lambda solutions: max(
solutions, key=lambda sol: list(sol.values())[0]
),
)
assert selected.x == 2


def test_solver_underdetermined_set_strategy():
class Under(Modello):
x = InstanceDummy("x")
y = x

instance = Under("U", solver_strategy="set", solution_mode="permissive")
assert isinstance(instance.x, BoundInstanceDummy)
assert instance.y == instance.x
Loading