Eager DAG-at-optimum intermediates: drain at simulate-time, serve to_dataframe from host#369
Open
hmgaudecker wants to merge 5 commits into
Open
Conversation
…iates Simulation now evaluates every additional target — each user function and constraint except `H` and the stochastic-weight helpers — at the realised optimal actions, once per regime-period, and stores the per-subject results as host `np.ndarray` on `PeriodRegimeSimulationData.intermediates`. The shared `_target_names_for_regime` helper drives both this evaluation and `available_targets`, so the names a result advertises are exactly the names its `intermediates` carry. `to_dataframe(additional_targets=...)` selects the requested columns straight from `intermediates` instead of recomputing the DAG at dataframe-build time. The recompute machinery (`additional_targets.py`) is gone; the surviving introspection and resolution helpers move to `targets.py`. `to_dataframe` therefore needs neither the compiled `Regime` objects nor `flat_params`, so `save()` drops `self._regimes` unconditionally before building the dataframe and the earlier per-call workarounds (a top-of-`save` drain, a conditional regime drop) are removed. `intermediates` round-trips through the orbax checkpoint, so a reloaded result serves `additional_targets` identically to a fresh one. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…fixes Adversarial review of the eager-DAG-intermediates refactor surfaced: - The save-time array-size census (`_walk_tree`) counted only `jax.Array` leaves, so the host-numpy `intermediates` orbax now persists were invisible — the reported byte total understated the serialised tree. Count `np.ndarray` leaves too; docstrings and the log label updated to match. - `test_raw_results_intermediates_exposes_dag_outputs_as_host_arrays` asserted only `isinstance(..., np.ndarray)`. Recompute utility from the realised actions and assert the value (atol 1e-6) and the (n_subjects,) shape. - `docs/explanations/architecture.md` still named the deleted `additional_targets.py` module and the removed `_compute_targets` helper; point both references at `targets.py` + eager eval in `simulate.py`. - `save()`'s docstring did not mention that `arrays/` persists every regime's full `intermediates` regardless of `df_additional_targets`; document the unconditional cost. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
A target that consumes a next-period state — e.g. a reporting function
`f(next_aime)` — could not be evaluated at the optimum: the eager DAG
pool held only user functions + constraints, so `next_aime` (a
state-transition output) was an unresolved input and the vmapped call
raised `ValueError: ... missing: {'next_aime'}`. The old code never hit
this because it computed only the targets the caller requested; the
eager path computes the full set every period, so any next-state
dependency now fires.
Add the regime's deterministic state-transition functions to the pool so
dags resolves the `next_<state>` chain (as solve already does).
Recomputing a deterministic transition from the realised optimal action
reproduces the value the simulation stores, so the intermediate stays
consistent with the forward path. Stochastic transitions are excluded —
their realised draw lives in the simulation loop's random key, not in
states + actions.
Regression test in test_chained_state_transitions.py: a reporting
function `2 * next_aime` whose column must equal `2 * (aime + work)`.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…-dag-intermediates
The eager target evaluation fuses every DAG target into one vmap over the full subject population. That fused program carries the subject axis through every intermediate, so its peak device workspace scales with the subject count — on a memory-tight GPU the single-pass evaluation can exceed device memory at production subject counts. Add `subject_batch_size`, threaded `Model.simulate` -> `simulate` -> `_simulate_regime_in_period` -> `_evaluate_dag_at_optimum`: evaluate the targets over subject chunks of that size, pulling each chunk to host before the next runs, so the peak scales with the chunk rather than the population. `None` (default) keeps the single-pass behaviour. Values are identical across batch sizes — verified by a parametrized test over even (4 -> 2+2) and uneven (4 -> 3+1) splits. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stacked on #368.
Why
On a 16 GiB V100,
SimulationResult.save()/to_dataframe()OOM'd atnp.asarraytime with 10–11 GiB allocations even though every leaf is a(n_subjects,)single-device array. Root cause:PeriodRegimeSimulationDatawas not registered as a JAX pytree, so the simulate-end
jax.block_until_ready(...)drain was a silent no-op — the lazy compute graphsfor
additional_targetsonly fired later, atnp.asarray, where they needed amulti-GiB XLA workspace.
What changed
PeriodRegimeSimulationDataas a pytree so the simulate-enddrain actually completes the device compute (907df02).
(regime, period),_evaluate_dag_at_optimumcomposes the full target DAG(user functions + constraints, minus
Handweight_*helpers), vmaps oversubjects, and pulls each output straight to host. Results live in
PeriodRegimeSimulationData.intermediatesas plainnp.ndarray, pinning nodevice memory across the loop.
to_dataframeservesadditional_targetsfromintermediates— no DAGrecomputation, no
Regimeobjects, noflat_paramsneeded at dataframe time.save()simplified: drops the in-memory regimes unconditionally beforeto_dataframe(intermediates are self-contained);arrays/persists theintermediates so a loaded result serves any
additional_targetsidenticallyto a fresh one.
additional_targets.pytotargets.py(_resolve_targets,_collect_all_available_targets,_target_names_for_regime).The public contract is unchanged:
to_dataframe(additional_targets=[...] | "all")behaves exactly as before; the default stays lean (states + actions).
Review follow-up (f9650f6)
Adversarial review surfaced: a save-time census that ignored host-numpy leaves
(now counted), a type-only intermediates test (now a concrete-value + shape
assertion), and stale
architecture.mdreferences to the deleted module (fixed).Verification
ty: clean ·prek run --all-files: clean🤖 Generated with Claude Code