
[NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821

Draft

ecnal-cienet wants to merge 5 commits into main from feat/nnx-native-maxengine

Conversation

@ecnal-cienet (Collaborator)

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
  4. ✅ NNX sharding diagnostics, bidirectional Linen↔NNX checkpoint conversion utilities, and post-training fixes. (PR #3652)
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling on NNX. No-op while pure_nnx=False stays default. (PR #3766)
  6. ✅ NNX-native DPO. Closes the only remaining hard NotImplementedError on the NNX path; pure_nnx=True + use_dpo=True is now supported.
  7. 🔄 [This PR] NNX-native MaxEngine inference. Drops the route-to-Linen path in maxengine.py; pure_nnx=True now drives a real NNX inference flow end-to-end (prefill, generate, KV cache).
  8. ❌ NNX-native LoRA + GRPO.
  9. ❌ NNX-aware QK-Clip + remaining checkpoint utilities (standalone_checkpointer, generate_param_only_checkpoint, layerwise_quantization, convert_gpt3_ckpt_from_paxml).
  10. ❌ Vocab tiling custom_vjp for NNX (perf optimization, not correctness).
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

PR5 audited maxengine.py and routed three call sites to the Linen path (state.params shape, state_mesh_shardings.params, and Linen-format inference checkpoint serving), because transformer_as_linen returns a Linen module regardless of pure_nnx. That preserved Linen serving but meant pure_nnx=True users silently got the Linen engine — the cited rationale was "the flag affects training, not inference serving."

This PR replaces those routes with a real NNX inference flow. When config.pure_nnx=True, the engine now builds an NNX Transformer via from_config(rngs=...), splits params/cache/rest with nnx.split(model, nnx.Param, nnx.Cache, ...), and in every JIT body merges the model concretely with nnx.merge(graphdef, params, cache, rest, copy=True) to run the forward pass. Linen is preserved byte-for-byte: every NNX edit is gated behind `if config.pure_nnx:`.

The diff is +651 / −34 across 3 files, with the only physical deletion being the "use Linen path regardless of pure_nnx" comment block at the top of MaxEngine.__init__.
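For orientation, here is a minimal sketch of the split/merge lifecycle the description refers to; `config`, `model`, `tokens`, and `positions` are stand-ins, and the real code spreads these steps across `__init__`, `load_params`, and the JIT bodies.

```python
from flax import nnx

if config.pure_nnx:
  # Construction: one graphdef plus three states, keyed by variable type.
  # The trailing `...` filter collects everything that is neither Param nor
  # Cache (RNG vars and other leftovers, i.e. the "rest" slot).
  graphdef, params, cache, rest = nnx.split(model, nnx.Param, nnx.Cache, ...)

  # Inside each JIT body: rebuild a concrete module and run the forward pass.
  # copy=True avoids reusing Variable objects across traces (TraceContextError).
  merged = nnx.merge(graphdef, params, cache, rest, copy=True)
  logits = merged(tokens, positions)  # forward-call signature is a stand-in
```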

Part 1: Engine construction — two abstract NNX models

| File | Change |
| --- | --- |
| `inference/maxengine/maxengine.py` (`__init__`) | Builds two abstract NNX `Transformer` instances on the NNX path: `self.model` with `model_mode=PREFILL` (batch=1, single padded prompt) and `self.model_ar` with `model_mode=AUTOREGRESSIVE` (batch=`micro_batch_size_to_train_on`, decode-state shape). One graphdef (`self.graphdef`) is cached from the PREFILL model; both abstract models contain initialized `nnx.Cache` vars (`KVCache.__init__` allocates both prefill and AR caches when `model_mode` is PREFILL or AUTOREGRESSIVE). |
| `inference/maxengine/maxengine.py` (`__init__`) | Uses `nnx.eval_shape` directly inside `nn_partitioning.axis_rules(config.logical_axis_rules)` rather than `create_nnx_abstract_model` — the latter wraps the trace in `jax.set_mesh(mesh)`, which makes Flax 0.12.6's `_to_variable` resolve logical axis names against the global mesh and raise on logical-only names like `"norm"`. Same reason `get_abstract_state_nnx` avoids `set_mesh`. |

The two-model split is required because NNX cache vars take their logical axis names (CACHE_BATCH vs CACHE_BATCH_PREFILL) from the model's construction model_mode. The decode_state cache must use AR-mode annotations so that bulk_insert's cache_batch substring lookup hits.
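A hedged sketch of the two-model construction follows, assuming `Transformer.from_config` and the mode constants behave as described above (exact signatures may differ):

```python
from flax import nnx
from flax.linen import partitioning as nn_partitioning

with nn_partitioning.axis_rules(config.logical_axis_rules):
  # nnx.eval_shape traces construction without allocating arrays; no
  # jax.set_mesh wrapper, which trips Flax 0.12.6 on logical-only axis
  # names like "norm".
  self.model = nnx.eval_shape(
      lambda: Transformer.from_config(config, model_mode=PREFILL, rngs=nnx.Rngs(0)))
  self.model_ar = nnx.eval_shape(
      lambda: Transformer.from_config(config, model_mode=AUTOREGRESSIVE, rngs=nnx.Rngs(0)))

# One graphdef is cached from the PREFILL model; the per-filter states are
# discarded here and re-derived concretely in load_params.
self.graphdef, _, _, _ = nnx.split(self.model, nnx.Param, nnx.Cache, ...)
```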

Part 2: NNX param loading

| File | Change |
| --- | --- |
| `inference/maxengine/maxengine.py` (`_load_params_nnx`) | Accepts user-provided NNX-shape params or loads via `from_pretrained`. For user-provided params, materializes a concrete model once via `_create_model_fn()` to capture a real `_nnx_rest_state` (RNG vars) for `nnx.merge`. For checkpoint loads, extracts `(graphdef, params, cache, rest)` from the loaded `nnx_model`. Refreshes `self.graphdef` from the concrete model so subsequent merges line up exactly. |
| `inference/maxengine/maxengine.py` (`_load_params_nnx`) | Builds `self.abstract_params` as `ShapeDtypeStruct`s of the param state, populates `self.prefill_kv_cache_annotations` and `self.kv_cache_annotations` (using AR mode for the latter so `cache_batch` appears in the per-leaf logical-axes tuple), and wraps both in `NamedSharding`. |
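A sketch of this bookkeeping, assuming the Part 6 helpers return plain `PartitionSpec` trees; the `*_shardings` attribute names below are illustrative, not the real ones:

```python
import jax
from jax.sharding import NamedSharding

# Shape/dtype skeleton of the NNX param state (leaves become ShapeDtypeStruct).
self.abstract_params = jax.tree.map(
    lambda v: jax.ShapeDtypeStruct(v.shape, v.dtype), params)

# PREFILL-mode annotations for the prefill cache, AR-mode annotations for the
# decode-state cache (so "cache_batch" shows up in the per-leaf axes tuple).
self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations_nnx(
    self.model, config, mesh)
self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations_nnx(
    self.model_ar, config, mesh)

# Wrap both PartitionSpec trees in NamedSharding for the JIT boundary.
self.prefill_kv_cache_shardings = jax.tree.map(
    lambda ps: NamedSharding(mesh, ps), self.prefill_kv_cache_annotations)
self.kv_cache_shardings = jax.tree.map(
    lambda ps: NamedSharding(mesh, ps), self.kv_cache_annotations)
```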

Part 3: NNX init_decode_state

| File | Change |
| --- | --- |
| `inference/maxengine/maxengine.py` (`_init_decode_state_nnx`) | Zero-initializes a pure-dict cache from the AR-mode abstract model (so the leading batch dim matches `generate`'s input shape of `(per_device_batch_size * mesh.size, 1)`). Builds `kv_cache_annotations_named` per leaf by reading the `nnx.Cache` metadata under whichever of `"out_sharding"`, `"sharding"`, or `"sharding_names"` is present (the name varies in Flax 0.12.6). Returns the same `decode_state` dict shape as the Linen path. |
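A hedged sketch of the metadata lookup; exactly how the axis names hang off an `nnx.Cache` leaf varies across Flax versions, so the attribute-style access below is an assumption:

```python
def _cache_logical_axes(cache_var):
  """Return the logical-axis names recorded on an nnx.Cache leaf, or None."""
  # Flax 0.12.6 metadata-name variants, tried in order.
  for key in ("out_sharding", "sharding", "sharding_names"):
    axes = getattr(cache_var, key, None)
    if axes is not None:
      return axes
  return None
```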

Part 4: Cache adapter — NNX state ↔ pure dict

NNX cache state is an nnx.State with nnx.Cache-wrapped leaves; tree paths include a .value accessor at the end. Linen's mutable cache is a plain nested dict with the same leaf names but no .value. The engine's bulk_insert / _insert_jit / insert_partial / _maybe_*_prefill_result_cache helpers walk the cache via tree_map_with_path and switch on path[-1].key (the cache var name like "cached_prefill_key").

| File | Change |
| --- | --- |
| `inference/maxengine/maxengine.py` (`_nnx_cache_state_template` / `_nnx_init_cache_dict` / `_nnx_run_model`) | Cache flows through the engine as a pure dict on both paths. NNX state is converted at the JIT boundary via `nnx.State.to_pure_dict()` (after the model run) and `nnx.replace_by_pure_dict(template, pure_dict)` (before `nnx.merge`). The cache plumbing helpers see the same shape on Linen and NNX, requiring zero modifications. |
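The adapter in one picture, a sketch assuming `_nnx_cache_state_template` returns an abstract `nnx.State` matching the cache tree (the function names are those cited in the table; the plumbing is simplified):

```python
from flax import nnx

def _nnx_run_model(self, params, cache_dict, *model_args, **model_kwargs):
  # Pure dict -> nnx.State, using an abstract template with the right tree shape.
  cache_state = self._nnx_cache_state_template()
  nnx.replace_by_pure_dict(cache_state, cache_dict)

  # Rebuild a concrete module; copy=True avoids Variable reuse across traces.
  model = nnx.merge(self.graphdef, params, cache_state, self._nnx_rest_state, copy=True)
  logits = model(*model_args, **model_kwargs)  # forward pass mutates the cache vars

  # nnx.State -> pure dict, so bulk_insert and friends see the Linen dict shape.
  return logits, nnx.state(model, nnx.Cache).to_pure_dict()
```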

Part 5: NNX prefill / generate JIT branches

| File | Change |
| --- | --- |
| `inference/maxengine/maxengine.py` (`_prefill_jit`, `_generate_jit`) | The NNX path replaces `self.model.apply(params \| {"cache": cache}, ..., mutable=["cache"])` with a call to `_nnx_run_model`, which does `nnx.merge(graphdef, params, cache, rest, copy=True)` → `model(...)` → `nnx.state(model, nnx.Cache).to_pure_dict()`. `copy=True` mirrors the `train.py` `diff_wrapper` workaround for `TraceContextError` from reused `Variable` objects. The Linen branch (the `else:`) is the prior code, unchanged. |
| `inference/maxengine/maxengine.py` (`_prefill_jit`) | When `existing_prefix` is not `None`, the NNX branch threads `existing_prefix.cache` (already a pure dict) directly into `_nnx_run_model`'s `cache_dict=` arg instead of the Linen `params \| {"cache": ...}` dict-merge — `params` is an `nnx.State`, not a dict. |
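Putting Part 4 and Part 5 together, the branch inside `_prefill_jit` looks roughly like this (argument plumbing simplified; the keyword names are illustrative, not the real ones):

```python
if self.config.pure_nnx:
  # existing_prefix.cache is already a pure dict; there is no
  # params | {"cache": ...} merge because params is an nnx.State, not a dict.
  cache_in = (existing_prefix.cache if existing_prefix is not None
              else self._nnx_init_cache_dict(PREFILL))
  flat_logits, new_cache = self._nnx_run_model(
      params, cache_in, padded_tokens, positions)
else:
  # Prior Linen code, byte-for-byte unchanged.
  flat_logits, new_vars = self.model.apply(
      params | {"cache": cache}, padded_tokens, positions,
      enable_dropout=False, mutable=["cache"])
  new_cache = new_vars["cache"]
```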

Part 6: KV-cache annotation helpers

| File | Change |
| --- | --- |
| `utils/maxtext_utils.py` | New `get_prefill_kv_cache_annotations_nnx(abstract_model, config, mesh)` and `get_kv_cache_annotations_nnx(abstract_model, config, mesh)` mirror the Linen helpers' return shape (a tree of `PartitionSpec`). Both delegate to `_nnx_cache_partition_specs`, which extracts `nnx.Cache` state via `nnx.split(model, nnx.Cache, ...)`, calls `get_nnx_named_sharding_with_scan_axis` inside `nn_partitioning.axis_rules(config.logical_axis_rules)` so logical axes (`layers`, `cache_batch`, `norm`, ...) resolve to physical mesh axes, then converts the resulting `nnx.State` to a pure-dict tree via `to_pure_dict()`. |
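A hedged sketch of the shared helper; `get_nnx_named_sharding_with_scan_axis` comes from the earlier sharding-utilities PR and its exact signature is assumed here:

```python
import jax
from flax import nnx
from flax.linen import partitioning as nn_partitioning

def _nnx_cache_partition_specs(abstract_model, config, mesh):
  # Pull out only the nnx.Cache leaves; graphdef and the rest are unused.
  _, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...)

  # Resolve logical axes ("layers", "cache_batch", "norm", ...) to physical
  # mesh axes under the config's logical axis rules.
  with nn_partitioning.axis_rules(config.logical_axis_rules):
    named = get_nnx_named_sharding_with_scan_axis(cache_state, mesh)

  # Return a pure-dict tree of per-leaf PartitionSpec, mirroring the Linen helpers.
  return jax.tree.map(lambda ns: ns.spec, named.to_pure_dict())
```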

Part 7: Carve-outs (raise NotImplementedError on the NNX path with PR pointers)

| Site | Linen workaround | Tracked in |
| --- | --- | --- |
| `quantize_params` (NNX + quantization) | `pure_nnx=False` | PR9 (NNX-aware checkpoints + quantization) |
| `load_params` when `checkpoint_is_quantized=True` | `pure_nnx=False` | PR9 |
| `load_single_adapter` / `apply_adapter` / `unapply_adapter` (NNX + LoRA) | `pure_nnx=False` | PR8 (NNX-native LoRA + GRPO) |
| `_prefill_multisampling_jit` (NNX) | `pure_nnx=False` | Follow-up |
| `prefill_concat` (NNX) | `pure_nnx=False` | Follow-up |
| `load_params` when `stack_prefill_result_cache=True` (NNX) | `pure_nnx=False` | Follow-up — with `scan_layers=True` the NNX cache leaves are already stacked on axis 0; the engine's manual-stack helper assumes an unstacked Linen tree shape |

These are deliberate scope cuts, not silent fallbacks. Each raises with a message pointing at the responsible follow-up PR. The Linen path remains the workaround for users who hit any of these sites.
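Each carve-out follows the same shape; a minimal sketch of one of them (the exact message wording and method signature are assumptions):

```python
def quantize_params(self, state, rng=None):
  if self.config.pure_nnx:
    raise NotImplementedError(
        "quantize_params is not yet supported with pure_nnx=True "
        "(tracked in PR9: NNX-aware checkpoints + quantization); "
        "run with pure_nnx=False as a workaround.")
  # ... prior Linen implementation, unchanged ...
```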

Deferred (intentionally out of scope)

  • AOT compilation on the NNX path (aot_compile, _compile_generate_and_get_layouts): functions are not gated on pure_nnx and are structure-agnostic, so they likely work as-is — but they are not exercised by PR7's tests. Recommended manual validation when a JetStream-style serving setup first hits this.
  • End-to-end gemma2-2b decode parity on a real checkpoint: the in-PR parity test (below) demonstrates equivalence on a 2-layer test config. A python3 -m MaxText.decode pure_nnx=True ... smoke run against gemma2-2b (with an NNX-format ckpt or a converted Linen one) is recommended manual validation before flipping defaults in PR11.

Tests

  • New unit tests (tests/unit/maxengine_test.py, 6 tests):
    • test_init_nnx: engine constructs and exposes graphdef + abstract Transformer.
    • test_basic_prefill_nnx: prefill produces the same dict shape as the Linen path; logits + every cache leaf are finite (catches silent NaN/inf from a bad nnx.merge or cache round-trip); per-layer cache leading axis equals num_decoder_layers (catches scan-axis misalignment).
    • test_basic_decode_nnx: prefill → insert → 4-step generate. Verifies next_pos advances by exactly 1 each step (catches off-by-one cache pointer bugs) and logits stay finite at every step.
    • test_linen_nnx_parity_prefill: strongest evidence — the same Linen-init weights are bridged into the NNX engine via linen_nnx_converter.py (convert_linen_to_nnx → _strip_value_wrappers → nnx.replace_by_pure_dict), then prefill is run on both engines with the same prompt + greedy sampling. Asserts logits match within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses bf16 compute) and the greedy first-token argmax is exactly the same (see the assertion sketch after the test summary below). A failure here means the NNX forward pass diverged from Linen on identical weights.
    • test_quantize_raises_for_nnx, test_lora_raises_for_nnx: assert the carve-outs fire with NotImplementedError, so users hitting them get a clear "tracked in PR8/PR9" message instead of an opaque crash.
  • Existing Linen tests untouched and still pass (test_basic_prefill, test_basic_decode, test_stack_and_unstack_prefill_cache): every NNX edit is gated on if config.pure_nnx: and pure_nnx=False is the default, so the Linen branch is byte-for-byte the prior code.
  • Lint: bash lint.sh — codespell + pylint + pyink all green.
  • Comprehensive verification doc: PR7_VERIFICATION.md — TL;DR table, full test-run output, design notes (two abstract models, cache adapter, carve-outs), per-test rationale, what the evidence does not cover (manual validation suggestions), and reproduction commands.
  • Test summary:
$ pytest tests/unit/maxengine_test.py -v
test_basic_decode             PASSED
test_basic_decode_nnx         PASSED
test_basic_prefill            PASSED
test_basic_prefill_nnx        PASSED
test_chunked_prefill          SKIPPED  (CPU-only; pre-existing skip)
test_init_nnx                 PASSED
test_linen_nnx_parity_prefill PASSED
test_lora_raises_for_nnx      PASSED
test_quantize_raises_for_nnx  PASSED
test_stack_and_unstack_prefill_cache PASSED
================= 9 passed, 1 skipped =================
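For reference, the parity check in test_linen_nnx_parity_prefill boils down to assertions of this shape (variable names illustrative; tolerances as stated above):

```python
import numpy as np

# Logits agree within bf16 tolerance on identical weights.
np.testing.assert_allclose(nnx_logits, linen_logits, rtol=0.05, atol=0.1)

# Greedy sampling picks exactly the same first token.
assert int(np.argmax(nnx_logits[0, -1])) == int(np.argmax(linen_logits[0, -1]))
```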

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
…raining fixes

Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)

Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the
  whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
  (callable() check before invoking learning_rate_fn)
- train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit
  raises conflicting outer_index error); refactored to jax.value_and_grad + explicit
  nnx.split/merge pattern; teacher inference moved outside value_and_grad
Bug fixes (run as no-op while pure_nnx=False stays default):
- nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing;
  call from ToLinen after nnx.update to fix "Cannot extract graph node from
  different trace level" when grad tracers leak into Variable._trace_state.
- gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout =
  linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError.
- normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept
  shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity.
- attentions.py / qwen3.py: callsites eps= -> epsilon=.
- moe.py: per_expert_scale block moved into the unfused-kernel else branch
  (was scaling wo even when fused_kernel was active).
- models.py: build MTP block as MultiTokenPredictionBlock(...) directly
  (drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole
  to NNXDecoder instead of unpacking 5 fields.
- gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until
  after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan
  carry); use nnx.merge(..., copy=True) to avoid Variable reuse.
- diloco.py: NNX-aware state handling — state.params -> state.model.filter
  (nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params
  helper for jax.lax.cond pytree-structure parity.
- train_compile.py: new _collect_nnx_activation_shardings helper (forward
  pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only
  traces __init__); NNX path now passes 2-arg shaped_train_args (no rng);
  diloco path patched to handle the 2-vs-3 length difference.
- muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as
  {"params": nnx.to_pure_dict(...)} for parity with Linen tree shape.
- nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu)
  retain tracers in Linen scope across re-traces. Skip jax.checkpoint and
  use a Python for-loop instead of jax.lax.scan when quantization is FP8.
  Makes FP8 quantization usable on the NNX path.
- train.py (pre-train train_step): return
  nnx.state(new_state, nnx.Not(nnx.Intermediate)) so sowed forward-pass
  artifacts (e.g. max_logits for QK-Clip) don't break leaf-count parity
  with state_mesh_shardings.
- llama2.py: pass parameter_memory_host_offload to pre_self_attention_layer
  _norm RMSNorm (was missing on this norm only).
- base.yml: add 4 pipeline-related logical_axis_rules —
  layers_outside_pipeline, layers_per_stage, num_activations,
  circular_repeats. Additive, no-op without use_nnx_pipeline=True.

NNX feature enablements (clear all 17 "Pure NNX support has not been
implemented yet" NotImplementedError sites by routing Linen-coupled
utilities to the Linen path; their on-disk format is Linen):
- layerwise_quantization.py (2 sites): operates on Linen-format checkpoints
  via DeepSeek*ToLinen layers.
- lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen
  tree shape; LoRA adapters on disk are Linen.
- standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses
  state.opt_state[0]._replace(mu=..., nu=...) — Linen-only.
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and
  _save_decode_checkpoint use state.params["params"]["decoder"] — Linen.
- convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree
  paths (.params['params'], .opt_state.mu['params']).
- maxengine.py (3 sites): inference engine uses state.params and serves
  Linen-format inference checkpoints.
- grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route
  to Linen with a clear log warning since NNX-format checkpoints will fail
  at restore time.

Vocab tiling on NNX (real implementation, not just routing):
- models.py: add Transformer.logits_from_hidden_states on the NNX
  Transformer class — wraps NNXDecoder.apply_output_head with the
  token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states.
- vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis
  via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per
  chunk. The NNX model carries its parameters internally so no explicit
  FSDP gather is needed (unlike the Linen gathered_params pattern). MVP
  uses default autograd; custom_vjp memory-savings optimization is a
  follow-up if backward memory becomes a concern.
- train.py (NNX loss_fn): replace the NotImplementedError with the call
  to vocab_tiling_nnx_loss using hidden_states from intermediates.
- pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1
  and enable_nnx validation guards (no longer needed).

DPO + NNX retained as NotImplementedError but with a much more informative
message (points users at pure_nnx=False workaround). Full implementation
is deferred — needs a new TrainState shape carrying both policy and
reference NNX models plus an NNX dpo_loss_fn.

Stats: 26 source files modified, +406 / -171 lines. Linen invariant
verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False;
Linen-path UTs unaffected (3 pre-existing failures on the parent branch
remain unchanged — sharding_compare_test::deepseek2-16b,
optimizers_test::test_model_integration_kimi-k2-1t,
diloco_test::two_slices x2). All "Pure NNX support has not been
implemented yet" NotImplementedError sites cleared (was 17, now 0).

Implements NNX-native DPO so that the pure_nnx=True training path no
longer raises NotImplementedError on use_dpo runs. The Linen DPO
overlay pattern (model.apply(params=..., reference_params=...)) does
not translate to NNX modules, which carry their parameters internally.
Instead the policy and reference models are held as separate
nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx
runs both forwards with stop_gradient on the reference logits.

TrainStateNNX:
- Add optional `reference_model: nnx.Module` field. apply_gradients
  continues to update only `self.model`, leaving `self.reference_model`
  bit-identical across steps.

dpo_utils.py:
- Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params,
  reference_model, is_train=True). Signature mirrors the Linen
  dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's
  dispatcher (dropout_rng / params slots are unused for NNX; carried
  for parity, and reference_model is passed as the single
  extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over
  the policy, no gradient flows to the reference model's nnx.Param
  leaves; the explicit jax.lax.stop_gradient on ref_logits is a
  belt-and-braces guard.
- Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include
  indexer_loss=0.0 and mtp_loss=0.0 in aux so the
  gradient_accumulation aux pytree shape matches the non-DPO loss_fn.

train.py:
- Drop the NotImplementedError in train_step's NNX branch. When
  use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as
  extra_dpo_args; otherwise use the regular loss_fn. eval_step gains
  the same dispatch.
- diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init
  block, so both the GA and non-GA NNX paths route DPO identically.
- Checkpoint-save _split_dpo_state stripping is now Linen-only;
  TrainStateNNX saves whole (reference_model included) — the step-0
  reload later overwrites reference_model from the step-0 checkpoint.

train_utils.py:
- NNX init_state_fn materializes a frozen reference_model alongside
  the policy when config.use_dpo. Both are constructed by
  _create_model_partial() with config.init_weights_seed, so they
  start identical (standard DPO practice) until the step-0 reload.
- Step-0 checkpoint reload: copy step0_state["model"] into
  state["reference_model"]. Linen path unchanged.

Tests:
- New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX
  reference_model init/hasattr semantics; apply_gradients leaves
  reference bit-identical; aux key set; identical policy/reference
  yields loss=log(2) and reward_accuracy=0.0 (strict > on equal
  logratios); dropout_rng/params slots are signature-compat only;
  nnx.value_and_grad(argnums=0) over the policy yields finite grads
  on policy params only.
- train_nnx_test.py: drop the two stale negative tests
  (vocab_tiling_raises_not_implemented,
  train_step_dpo_raises_for_nnx) — both features are now real.

Stats: 4 source files + 2 test files, +199/-22 source lines. Linen
DPO path behaviorally unchanged (only adds two harmless aux-dict
keys); NNX non-DPO path unchanged (all changes gated on
config.use_dpo).
…e.py)

PR5 audited maxengine.py and routed the inference path to the Linen
implementation regardless of pure_nnx, with a comment block explaining
that "the flag affects training, not inference serving." That kept the
Linen serving path unchanged but meant pure_nnx=True users silently got
the Linen engine. This change replaces the route with a real NNX flow:
when config.pure_nnx=True, the engine builds an NNX Transformer, splits
out (params, cache, rest) with nnx.split, and at every JIT body merges
the model concretely with nnx.merge to run the forward pass. Linen is
preserved byte-for-byte; every NNX edit is gated `if config.pure_nnx:`
and pure_nnx=False is still the default.

maxengine.py (__init__):
- Build two abstract NNX Transformers on the NNX path: self.model with
  model_mode=PREFILL (batch=1, single padded prompt) and self.model_ar
  with model_mode=AUTOREGRESSIVE (batch=micro_batch_size_to_train_on,
  decode_state shape). Both are needed because NNX cache vars inherit
  CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode,
  and bulk_insert searches for the substring "cache_batch" in the
  AR-mode logical-axes tuple. nnx.eval_shape is called directly inside
  nn_partitioning.axis_rules rather than through create_nnx_abstract_model
  to avoid the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only
  axes like "norm" (same reason get_abstract_state_nnx avoids set_mesh).
- Cache the graphdef from a 3-way nnx.split(Param, Cache, ...) so JIT
  bodies can pass (params, cache, rest) separately to nnx.merge. The
  rest slot (RNG vars etc.) is materialized concretely in load_params.

maxengine.py (cache adapter + _nnx_run_model):
- bulk_insert / _insert_jit / _maybe_*_prefill_result_cache walk the
  cache via tree_map_with_path and switch on path[-1].key (the cache
  variable name like "cached_prefill_key"). Linen mutable cache is a
  plain nested dict. NNX Cache state would expose a ".value" accessor
  at that position. Bridge via nnx.State.to_pure_dict() (after the
  model run) and nnx.replace_by_pure_dict (before nnx.merge), so the
  cache plumbing helpers see the same shape on both paths.
- Add _nnx_run_model: nnx.merge(graphdef, params, cache, rest, copy=True)
  -> model(...) -> nnx.state(model, nnx.Cache).to_pure_dict(). copy=True
  avoids reusing Variable objects across traces (TraceContextError),
  mirroring train.py's diff_wrapper workaround.
- Add _nnx_cache_state_template / _nnx_init_cache_dict helpers
  parametrised by mode so prefill (batch 1) and decode_state (batch N)
  pull from the right abstract model.

maxengine.py (load_params):
- New _load_params_nnx: accepts user-provided NNX-shape params or loads
  via from_pretrained. For user-provided params, materializes a concrete
  model once via _create_model_fn() to capture a real rest state for
  nnx.merge (wasteful but simple; the from_pretrained branch avoids
  this). Refreshes self.graphdef from the concrete model so subsequent
  merges line up exactly.
- Builds self.abstract_params, populates self.prefill_kv_cache_annotations
  and self.kv_cache_annotations (using model_ar for the latter so
  bulk_insert's substring lookup hits), wraps both into NamedSharding.
- pure_nnx + quantization, pure_nnx + LoRA, pure_nnx +
  stack_prefill_result_cache=True, pure_nnx + prefill_multisampling,
  and pure_nnx + prefill_concat raise NotImplementedError for now;
  the Linen path is the workaround. AOT compilation
  (aot_compile / _compile_generate_and_get_layouts) is not gated and
  may work as-is; not exercised by tests yet.

maxengine.py (init_decode_state, _prefill_jit, _generate_jit):
- _init_decode_state_nnx zero-initializes a pure-dict cache from
  model_ar (so the leading batch dim matches generate's input shape)
  and builds kv_cache_annotations_named per leaf by reading
  nnx.Cache.metadata. Tries "out_sharding", "sharding", and
  "sharding_names" because Flax 0.12.6 renamed these.
- _prefill_jit / _generate_jit add an `if config.pure_nnx:` branch
  that calls _nnx_run_model in place of self.model.apply with
  mutable=["cache"]. existing_prefix.cache is threaded as a pure-dict
  cache directly (no params|{"cache":...} dict-merge — params is an
  nnx.State, not a dict).

maxtext_utils.py:
- New get_prefill_kv_cache_annotations_nnx / get_kv_cache_annotations_nnx
  that mirror the Linen helpers' return shape (per-leaf PartitionSpec
  tree). Both delegate to _nnx_cache_partition_specs which extracts
  nnx.Cache state via nnx.split, calls
  get_nnx_named_sharding_with_scan_axis inside
  nn_partitioning.axis_rules so logical axes ("layers", "cache_batch",
  "norm", ...) resolve to physical mesh axes, and converts the result
  to a pure-dict tree.

tests/unit/maxengine_test.py:
- New tests: test_init_nnx, test_basic_prefill_nnx (with NaN/inf and
  per-layer cache shape checks), test_basic_decode_nnx (4-step generate
  with next_pos advancement check), test_quantize_raises_for_nnx,
  test_lora_raises_for_nnx.
- New test_linen_nnx_parity_prefill: bridges Linen-init params into
  the NNX engine via linen_nnx_converter (convert_linen_to_nnx ->
  _strip_value_wrappers -> nnx.replace_by_pure_dict) and asserts the
  NNX engine's prefill matches Linen on the same weights — logits
  within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses
  bf16 compute) and exact greedy first-token argmax.
- Existing Linen tests untouched.

Test summary: 9 passed, 1 skipped (test_chunked_prefill is a
pre-existing CPU-only skip). bash lint.sh: codespell + pylint + pyink
all green.
