
Eager DAG-at-optimum intermediates: drain at simulate-time, serve to_dataframe from host #369

Open: hmgaudecker wants to merge 5 commits into feat/discrete-only-sharding-orbax-savefrom from feat/eager-dag-intermediates

@hmgaudecker

Stacked on #368.

Why

On a 16 GiB V100, SimulationResult.save() / to_dataframe() OOM'd at
np.asarray time with 10–11 GiB allocations even though every leaf is a
(n_subjects,) single-device array. Root cause: PeriodRegimeSimulationData
was not registered as a JAX pytree, so the simulate-end
jax.block_until_ready(...) drain was a silent no-op — the lazy compute graphs
for additional_targets only fired later, at np.asarray, where they needed a
multi-GiB XLA workspace.
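
The root cause can be illustrated with a minimal sketch. The two dataclasses below are hypothetical stand-ins, not pylcm's `PeriodRegimeSimulationData`. The sketch assumes only that an unregistered class is treated by JAX as a single opaque leaf, so `jax.block_until_ready` never reaches the arrays inside it.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class UnregisteredData:
    """Hypothetical container that JAX has not been told about."""

    values: jax.Array


@dataclass
class RegisteredData:
    """Same layout, but registered as a pytree below."""

    values: jax.Array


# Tell JAX how to flatten and rebuild RegisteredData.
jax.tree_util.register_pytree_node(
    RegisteredData,
    lambda d: ((d.values,), None),  # (children, aux_data)
    lambda aux, children: RegisteredData(*children),
)

unregistered = UnregisteredData(values=jnp.arange(3.0))
registered = RegisteredData(values=jnp.arange(3.0))

# The unregistered object is one opaque leaf; the registered one
# exposes its array as a leaf.
opaque_leaves = jax.tree_util.tree_leaves(unregistered)
array_leaves = jax.tree_util.tree_leaves(registered)

# Neither call raises, but only the second waits on the array.
jax.block_until_ready(unregistered)
jax.block_until_ready(registered)
```

Because the first call returns without error, nothing signals that the drain was skipped, which matches the "silent no-op" described above.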

What changed

  • Register PeriodRegimeSimulationData as a pytree so the simulate-end
    drain actually completes the device compute (907df02).
  • Evaluate every additional target eagerly at the optimum — once per
    (regime, period), _evaluate_dag_at_optimum composes the full target DAG
    (user functions + constraints, minus H and weight_* helpers), vmaps over
    subjects, and pulls each output straight to host. Results live in
    PeriodRegimeSimulationData.intermediates as plain np.ndarray, pinning no
    device memory across the loop.
  • to_dataframe serves additional_targets from intermediates — no DAG
    recomputation, no Regime objects, no flat_params needed at dataframe time.
  • save() simplified: drops the in-memory regimes unconditionally before
    to_dataframe (intermediates are self-contained); arrays/ persists the
    intermediates so a loaded result serves any additional_targets identically
    to a fresh one.
  • Name resolution moved from the deleted additional_targets.py to
    targets.py (_resolve_targets, _collect_all_available_targets,
    _target_names_for_regime).

The public contract is unchanged: to_dataframe(additional_targets=[...] | "all")
behaves exactly as before; the default stays lean (states + actions).
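
A rough sketch of the eager pattern described in the list above: evaluate all targets for every subject with one `vmap`, then copy each output to host so no device buffer stays pinned. The function names and the two toy targets are assumptions for illustration only; they are not the actual `_evaluate_dag_at_optimum` implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def targets_for_one_subject(consumption, wealth):
    # Hypothetical per-subject targets; stand-ins for the composed DAG.
    utility = jnp.log(consumption) + 0.1 * wealth
    savings = wealth - consumption
    return {"utility": utility, "savings": savings}


def evaluate_at_optimum(consumption, wealth):
    """Evaluate all targets over the subject axis and pull them to host."""
    device_out = jax.vmap(targets_for_one_subject)(consumption, wealth)
    # np.asarray copies each result into host memory as a plain ndarray.
    return {name: np.asarray(value) for name, value in device_out.items()}


consumption = jnp.array([1.0, 2.0, 4.0])
wealth = jnp.array([10.0, 20.0, 30.0])
intermediates = evaluate_at_optimum(consumption, wealth)
```

Each entry of `intermediates` is then an `np.ndarray` of shape `(n_subjects,)`.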

Review follow-up (f9650f6)

Adversarial review surfaced: a save-time census that ignored host-numpy leaves
(now counted), a type-only intermediates test (now a concrete-value + shape
assertion), and stale architecture.md references to the deleted module (fixed).

Verification

  • pylcm suite: 1034 passed, 42 skipped
  • ty: clean · prek run --all-files: clean

🤖 Generated with Claude Code

hmgaudecker and others added 2 commits May 29, 2026 09:42
…iates

Simulation now evaluates every additional target — each user function and
constraint except `H` and the stochastic-weight helpers — at the realised
optimal actions, once per regime-period, and stores the per-subject results
as host `np.ndarray` on `PeriodRegimeSimulationData.intermediates`. The shared
`_target_names_for_regime` helper drives both this evaluation and
`available_targets`, so the names a result advertises are exactly the names
its `intermediates` carry.

`to_dataframe(additional_targets=...)` selects the requested columns straight
from `intermediates` instead of recomputing the DAG at dataframe-build time.
The recompute machinery (`additional_targets.py`) is gone; the surviving
introspection and resolution helpers move to `targets.py`. `to_dataframe`
therefore needs neither the compiled `Regime` objects nor `flat_params`, so
`save()` drops `self._regimes` unconditionally before building the dataframe
and the earlier per-call workarounds (a top-of-`save` drain, a conditional
regime drop) are removed.

`intermediates` round-trips through the orbax checkpoint, so a reloaded result
serves `additional_targets` identically to a fresh one.
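
Selecting columns from stored intermediates, rather than recomputing them, might look roughly like the sketch below. `select_target_columns` is a hypothetical helper, not the `_resolve_targets` function from `targets.py`; it only assumes that `intermediates` maps target names to per-subject host arrays.

```python
import numpy as np


def select_target_columns(intermediates, additional_targets):
    """Return the requested target columns from precomputed host arrays."""
    if additional_targets is None:
        return {}  # lean default: no extra columns
    if isinstance(additional_targets, str):
        if additional_targets != "all":
            raise ValueError(f"unknown selector: {additional_targets!r}")
        names = sorted(intermediates)
    else:
        names = list(additional_targets)
    missing = [name for name in names if name not in intermediates]
    if missing:
        raise ValueError(f"targets not available: {missing}")
    # Pure lookup: no DAG evaluation and no device access.
    return {name: intermediates[name] for name in names}


stored = {
    "utility": np.array([0.5, 1.5]),
    "savings": np.array([9.0, 18.0]),
}
requested = select_target_columns(stored, ["utility"])
everything = select_target_columns(stored, "all")
```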

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…fixes

Adversarial review of the eager-DAG-intermediates refactor surfaced:

- The save-time array-size census (`_walk_tree`) counted only
  `jax.Array` leaves, so the host-numpy `intermediates` that orbax now
  persists were invisible and the reported byte total understated the
  serialised tree. Count `np.ndarray` leaves too; docstrings and the
  log label updated to match.
- `test_raw_results_intermediates_exposes_dag_outputs_as_host_arrays`
  asserted only `isinstance(..., np.ndarray)`. Recompute utility from
  the realised actions and assert the value (atol 1e-6) and the
  (n_subjects,) shape.
- `docs/explanations/architecture.md` still named the deleted
  `additional_targets.py` module and the removed `_compute_targets`
  helper; point both references at `targets.py` + eager eval in
  `simulate.py`.
- `save()`'s docstring did not mention that `arrays/` persists every
  regime's full `intermediates` regardless of `df_additional_targets`;
  document the unconditional cost.
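
The census fix in the first item can be sketched as follows. `total_array_bytes` is a hypothetical stand-in for `_walk_tree`, assuming the census only needs to sum bytes over array leaves of either kind.

```python
import jax
import jax.numpy as jnp
import numpy as np


def total_array_bytes(tree):
    """Sum the bytes of all device and host array leaves in a pytree."""
    total = 0
    for leaf in jax.tree_util.tree_leaves(tree):
        # Counting only jax.Array here would miss host-side intermediates.
        if isinstance(leaf, (jax.Array, np.ndarray)):
            total += leaf.size * leaf.dtype.itemsize
    return total


tree = {
    "device_values": jnp.zeros(4, dtype=jnp.float32),  # 4 * 4 bytes
    "intermediates": {"utility": np.zeros(4, dtype=np.float64)},  # 4 * 8 bytes
}
n_bytes = total_array_bytes(tree)
```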

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
hmgaudecker and others added 3 commits May 29, 2026 10:45
A target that consumes a next-period state — e.g. a reporting function
`f(next_aime)` — could not be evaluated at the optimum: the eager DAG
pool held only user functions + constraints, so `next_aime` (a
state-transition output) was an unresolved input and the vmapped call
raised `ValueError: ... missing: {'next_aime'}`. The old code never hit
this because it computed only the targets the caller requested; the
eager path computes the full set every period, so any next-state
dependency now fires.

Add the regime's deterministic state-transition functions to the pool so
`dags` resolves the `next_<state>` chain (as solve already does).
Recomputing a deterministic transition from the realised optimal action
reproduces the value the simulation stores, so the intermediate stays
consistent with the forward path. Stochastic transitions are excluded —
their realised draw lives in the simulation loop's random key, not in
states + actions.

Regression test in test_chained_state_transitions.py: a reporting
function `2 * next_aime` whose column must equal `2 * (aime + work)`.
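
The failure and the fix can be reproduced with a toy resolver. This is not the `dags` API: `evaluate_with_pool` is a hypothetical helper that resolves arguments by name, and the transition `next_aime = aime + work` is assumed from the regression test described above.

```python
import inspect


def evaluate_with_pool(target, pool, inputs):
    """Evaluate `target`, resolving arguments from `inputs` or `pool`."""

    def resolve(name):
        if name in inputs:
            return inputs[name]
        if name not in pool:
            raise ValueError(f"missing: {name!r}")
        func = pool[name]
        arg_names = inspect.signature(func).parameters
        return func(**{arg: resolve(arg) for arg in arg_names})

    return resolve(target)


def next_aime(aime, work):
    # Deterministic state transition (assumed form).
    return aime + work


def reported_next_aime(next_aime):
    # Reporting function that consumes a next-period state.
    return 2 * next_aime


inputs = {"aime": 3.0, "work": 1.0}
pool_without_transition = {"reported_next_aime": reported_next_aime}
pool_with_transition = {**pool_without_transition, "next_aime": next_aime}

result = evaluate_with_pool("reported_next_aime", pool_with_transition, inputs)
```

With `pool_without_transition`, the same call raises `ValueError` because `next_aime` is neither an input nor a function in the pool.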

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The eager target evaluation fuses every DAG target into one vmap over the
full subject population. That fused program carries the subject axis through
every intermediate, so its peak device workspace scales with the subject
count — on a memory-tight GPU the single-pass evaluation can exceed device
memory at production subject counts.

Add `subject_batch_size`, threaded through `Model.simulate` -> `simulate` ->
`_simulate_regime_in_period` -> `_evaluate_dag_at_optimum`: evaluate the
targets over subject chunks of that size, pulling each chunk to host before
the next runs, so the peak scales with the chunk rather than the population.
`None` (default) keeps the single-pass behaviour. Values are identical across
batch sizes — verified by a parametrized test over even (4 -> 2+2) and uneven
(4 -> 3+1) splits.
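
A minimal sketch of the chunking idea, assuming a single toy target in place of the fused DAG. Only the `subject_batch_size` parameter name comes from this commit; `evaluate_in_batches` and `target` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def target(consumption, wealth):
    # Hypothetical per-subject target.
    return jnp.log(consumption) + 0.1 * wealth


def evaluate_in_batches(consumption, wealth, subject_batch_size=None):
    """Evaluate `target` over subject chunks, pulling each chunk to host."""
    n_subjects = consumption.shape[0]
    step = n_subjects if subject_batch_size is None else subject_batch_size
    chunks = []
    for start in range(0, n_subjects, step):
        stop = start + step  # slicing clips the last, possibly shorter chunk
        on_device = jax.vmap(target)(consumption[start:stop], wealth[start:stop])
        chunks.append(np.asarray(on_device))  # copy to host before next chunk
    return np.concatenate(chunks)


consumption = jnp.array([1.0, 2.0, 3.0, 4.0])
wealth = jnp.array([10.0, 20.0, 30.0, 40.0])

single_pass = evaluate_in_batches(consumption, wealth)
even_split = evaluate_in_batches(consumption, wealth, subject_batch_size=2)
uneven_split = evaluate_in_batches(consumption, wealth, subject_batch_size=3)
```

Only one chunk's outputs are live on the device at a time, so the peak workspace depends on the chunk size rather than on the full subject count.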

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>