[NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821

Draft

ecnal-cienet wants to merge 5 commits into main from

Conversation
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
…raining fixes

Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)

Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields)
- optimizers.py: guard adam_pax against a scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn)
- train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit raises a conflicting outer_index error); refactored to jax.value_and_grad + an explicit nnx.split/merge pattern; teacher inference moved outside value_and_grad
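The train_distill.py refactor follows the standard escape hatch for nested NNX transforms. A minimal sketch of the pattern, assuming standalone student/teacher modules and an illustrative `distill_loss` (not the actual train_distill.py signatures):

```python
import jax
from flax import nnx


def distill_loss(student_logits, teacher_logits):
  # Illustrative distillation loss: soft cross-entropy against the teacher.
  t = jax.nn.softmax(teacher_logits)
  return -(t * jax.nn.log_softmax(student_logits)).sum(axis=-1).mean()


def distill_step(student: nnx.Module, teacher: nnx.Module, batch):
  # Teacher inference stays outside value_and_grad: no gradients are
  # needed, and it keeps the teacher out of the grad trace entirely.
  teacher_logits = jax.lax.stop_gradient(teacher(batch["inputs"]))

  # Split the student so the loss closure is a pure function of arrays.
  # Using jax.value_and_grad over plain state avoids nesting
  # nnx.value_and_grad inside nnx.jit (the conflicting outer_index error).
  graphdef, params, rest = nnx.split(student, nnx.Param, ...)

  def loss_fn(params):
    model = nnx.merge(graphdef, params, rest)  # rebuild inside the trace
    return distill_loss(model(batch["inputs"]), teacher_logits)

  return jax.value_and_grad(loss_fn)(params)
```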
Bug fixes (run as no-op while pure_nnx=False stays default):
- nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing;
call from ToLinen after nnx.update to fix "Cannot extract graph node from
different trace level" when grad tracers leak into Variable._trace_state.
- gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout =
  linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError
  (see the sketch after this list).
- normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept
shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity.
- attentions.py / qwen3.py: callsites eps= -> epsilon=.
- moe.py: per_expert_scale block moved into the unfused-kernel else branch
(was scaling wo even when fused_kernel was active).
- models.py: build MTP block as MultiTokenPredictionBlock(...) directly
(drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole
to NNXDecoder instead of unpacking 5 fields.
- gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until
after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan
carry); use nnx.merge(..., copy=True) to avoid Variable reuse.
- diloco.py: NNX-aware state handling — state.params -> state.model.filter
(nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params
helper for jax.lax.cond pytree-structure parity.
- train_compile.py: new _collect_nnx_activation_shardings helper (forward
pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only
traces __init__); NNX path now passes 2-arg shaped_train_args (no rng);
diloco path patched to handle the 2-vs-3 length difference.
- muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as
{"params": nnx.to_pure_dict(...)} for parity with Linen tree shape.
- nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu)
retain tracers in Linen scope across re-traces. Skip jax.checkpoint and
use a Python for-loop instead of jax.lax.scan when quantization is FP8.
Makes FP8 quantization usable on the NNX path.
- train.py (pre-train train_step): return
  nnx.state(new_state, nnx.Not(nnx.Intermediate)) so sowed forward-pass
  artifacts (e.g. max_logits for QK-Clip) don't break leaf-count parity
  with state_mesh_shardings.
- llama2.py: pass parameter_memory_host_offload to the
  pre_self_attention_layer_norm RMSNorm (was missing on this norm only).
- base.yml: add 4 pipeline-related logical_axis_rules —
  layers_outside_pipeline, layers_per_stage, num_activations,
  circular_repeats. Additive, no-op without use_nnx_pipeline=True.
NNX feature enablements (clear all 17 "Pure NNX support has not been
implemented yet" NotImplementedError sites by routing Linen-coupled
utilities to the Linen path; their on-disk format is Linen):
- layerwise_quantization.py (2 sites): operates on Linen-format checkpoints
via DeepSeek*ToLinen layers.
- lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen
tree shape; LoRA adapters on disk are Linen.
- standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses
state.opt_state[0]._replace(mu=..., nu=...) — Linen-only.
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and
_save_decode_checkpoint use state.params["params"]["decoder"] — Linen.
- convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree
paths (.params['params'], .opt_state.mu['params']).
- maxengine.py (3 sites): inference engine uses state.params and serves
Linen-format inference checkpoints.
- grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route
to Linen with a clear log warning since NNX-format checkpoints will fail
at restore time.
Vocab tiling on NNX (real implementation, not just routing):
- models.py: add Transformer.logits_from_hidden_states on the NNX
Transformer class — wraps NNXDecoder.apply_output_head with the
token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states.
- vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis
  via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per
  chunk (see the sketch after this list). The NNX model carries its
  parameters internally, so no explicit FSDP gather is needed (unlike the
  Linen gathered_params pattern). The MVP uses default autograd; a
  custom_vjp memory-savings optimization is a follow-up if backward
  memory becomes a concern.
- train.py (NNX loss_fn): replace the NotImplementedError with the call
to vocab_tiling_nnx_loss using hidden_states from intermediates.
- pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1
and enable_nnx validation guards (no longer needed).
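For orientation, a schematic of the scan-chunked loss behind vocab_tiling_nnx_loss; this simplified sketch tiles the sequence axis and uses a plain cross-entropy, so the shapes and loss algebra are illustrative rather than the MaxText implementation:

```python
import jax
import jax.numpy as jnp


def tiled_loss(model, hidden_states, targets, num_tiles):
  # [B, S, E] -> [num_tiles, B, S // num_tiles, E] so lax.scan walks tiles
  # and the full [B, S, V] logits tensor is never materialized at once.
  b, s, e = hidden_states.shape
  hs = hidden_states.reshape(b, num_tiles, s // num_tiles, e).swapaxes(0, 1)
  tg = targets.reshape(b, num_tiles, s // num_tiles).swapaxes(0, 1)

  def body(total, tile):
    hs_tile, tg_tile = tile
    # The NNX model carries its params internally; no FSDP gather needed.
    logits = model.logits_from_hidden_states(hs_tile)
    logp = jax.nn.log_softmax(logits)
    nll = -jnp.take_along_axis(logp, tg_tile[..., None], axis=-1)
    return total + nll.sum(), None

  total, _ = jax.lax.scan(body, jnp.zeros((), jnp.float32), (hs, tg))
  return total / targets.size
```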
DPO + NNX retained as NotImplementedError but with a much more informative
message (points users at pure_nnx=False workaround). Full implementation
is deferred — needs a new TrainState shape carrying both policy and
reference NNX models plus an NNX dpo_loss_fn.
Stats: 26 source files modified, +406 / -171 lines. Linen invariant
verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False;
Linen-path UTs unaffected (3 pre-existing failures on the parent branch
remain unchanged — sharding_compare_test::deepseek2-16b,
optimizers_test::test_model_integration_kimi-k2-1t,
diloco_test::two_slices x2). All "Pure NNX support has not been
implemented yet" NotImplementedError sites cleared (was 17, now 0).
Implements NNX-native DPO so that the pure_nnx=True training path no
longer raises NotImplementedError on use_dpo runs. The Linen DPO overlay
pattern (model.apply(params=..., reference_params=...)) does not
translate to NNX modules, which carry their parameters internally.
Instead, the policy and reference models are held as separate nnx.Module
instances on TrainStateNNX, and a new dpo_loss_fn_nnx runs both forwards
with stop_gradient on the reference logits.

TrainStateNNX:
- Add optional `reference_model: nnx.Module` field. apply_gradients
  continues to update only `self.model`, leaving `self.reference_model`
  bit-identical across steps.

dpo_utils.py:
- Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params,
  reference_model, is_train=True). The signature mirrors the Linen
  dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's
  dispatcher (the dropout_rng / params slots are unused for NNX and
  carried for parity; reference_model is passed as the single
  extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over
  the policy, no gradient flows to the reference model's nnx.Param
  leaves; the explicit jax.lax.stop_gradient on ref_logits is a
  belt-and-braces guard.
- Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include
  indexer_loss=0.0 and mtp_loss=0.0 in aux so the gradient_accumulation
  aux pytree shape matches the non-DPO loss_fn.

train.py:
- Drop the NotImplementedError in train_step's NNX branch. When use_dpo,
  dispatch to dpo_loss_fn_nnx with state.reference_model as
  extra_dpo_args; otherwise use the regular loss_fn. eval_step gains the
  same dispatch.
- diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init
  block, so both the GA and non-GA NNX paths route DPO identically.
- Checkpoint-save _split_dpo_state stripping is now Linen-only;
  TrainStateNNX saves whole (reference_model included) — the step-0
  reload later overwrites reference_model from the step-0 checkpoint.

train_utils.py:
- The NNX init_state_fn materializes a frozen reference_model alongside
  the policy when config.use_dpo. Both are constructed by
  _create_model_partial() with config.init_weights_seed, so they start
  identical (standard DPO practice) until the step-0 reload.
- Step-0 checkpoint reload: copy step0_state["model"] into
  state["reference_model"]. Linen path unchanged.

Tests:
- New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX reference_model
  init/hasattr semantics; apply_gradients leaves the reference
  bit-identical; aux key set; identical policy/reference yields
  loss=log(2) and reward_accuracy=0.0 (strict > on equal logratios);
  dropout_rng/params slots are signature-compat only;
  nnx.value_and_grad(argnums=0) over the policy yields finite grads on
  policy params only.
- train_nnx_test.py: drop the two stale negative tests
  (vocab_tiling_raises_not_implemented, train_step_dpo_raises_for_nnx) —
  both features are now real.

Stats: 4 source files + 2 test files, +199/-22 source lines. The Linen
DPO path is behaviorally unchanged (it only adds two harmless aux-dict
keys); the NNX non-DPO path is unchanged (all changes are gated on
config.use_dpo).
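In outline, the NNX DPO dispatch described above looks like the sketch below; the sequence-scoring helper, beta, and the data layout are illustrative stand-ins, not the MaxText dpo_loss_fn_nnx signature:

```python
import jax
import jax.numpy as jnp
from flax import nnx


def seq_logprob(logits, ids):
  # Summed per-token log-probability of `ids` under `logits` ([B, S, V]).
  logp = jax.nn.log_softmax(logits)
  return jnp.take_along_axis(logp, ids[..., None], axis=-1).sum(axis=(-1, -2))


def dpo_loss_fn_nnx(policy_model, reference_model, data, beta=0.1):
  # Policy and reference each score the chosen and rejected sequences.
  pol_c = seq_logprob(policy_model(data["chosen"]), data["chosen"])
  pol_r = seq_logprob(policy_model(data["rejected"]), data["rejected"])
  # stop_gradient is the belt-and-braces guard; argnums=0 below already
  # keeps gradients off the reference model's nnx.Param leaves.
  ref_c = jax.lax.stop_gradient(seq_logprob(reference_model(data["chosen"]), data["chosen"]))
  ref_r = jax.lax.stop_gradient(seq_logprob(reference_model(data["rejected"]), data["rejected"]))
  # Standard DPO objective on the policy-vs-reference logratio difference.
  return -jnp.mean(jax.nn.log_sigmoid(beta * ((pol_c - pol_r) - (ref_c - ref_r))))


# Differentiate w.r.t. argument 0 (the policy) only.
grad_fn = nnx.value_and_grad(dpo_loss_fn_nnx, argnums=0)
```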
…e.py)
PR5 audited maxengine.py and routed the inference path to the Linen
implementation regardless of pure_nnx, with a comment block explaining
that "the flag affects training, not inference serving." That kept the
Linen serving path unchanged but meant pure_nnx=True users silently got
the Linen engine. This change replaces the route with a real NNX flow:
when config.pure_nnx=True, the engine builds an NNX Transformer, splits
out (params, cache, rest) with nnx.split, and at every JIT body merges
the model concretely with nnx.merge to run the forward pass. Linen is
preserved byte-for-byte; every NNX edit is gated `if config.pure_nnx:`
and pure_nnx=False is still the default.
maxengine.py (__init__):
- Build two abstract NNX Transformers on the NNX path: self.model with
model_mode=PREFILL (batch=1, single padded prompt) and self.model_ar
with model_mode=AUTOREGRESSIVE (batch=micro_batch_size_to_train_on,
decode_state shape). Both are needed because NNX cache vars inherit
CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode,
and bulk_insert searches for the substring "cache_batch" in the
AR-mode logical-axes tuple. nnx.eval_shape is called directly inside
nn_partitioning.axis_rules rather than through create_nnx_abstract_model
to avoid the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only
axes like "norm" (same reason get_abstract_state_nnx avoids set_mesh).
- Cache the graphdef from a 3-way nnx.split(Param, Cache, ...) so JIT
bodies can pass (params, cache, rest) separately to nnx.merge. The
rest slot (RNG vars etc.) is materialized concretely in load_params.
maxengine.py (cache adapter + _nnx_run_model):
- bulk_insert / _insert_jit / _maybe_*_prefill_result_cache walk the
cache via tree_map_with_path and switch on path[-1].key (the cache
variable name like "cached_prefill_key"). Linen mutable cache is a
plain nested dict. NNX Cache state would expose a ".value" accessor
at that position. Bridge via nnx.State.to_pure_dict() (after the
model run) and nnx.replace_by_pure_dict (before nnx.merge), so the
cache plumbing helpers see the same shape on both paths.
- Add _nnx_run_model: nnx.merge(graphdef, params, cache, rest, copy=True)
-> model(...) -> nnx.state(model, nnx.Cache).to_pure_dict(). copy=True
avoids reusing Variable objects across traces (TraceContextError),
mirroring train.py's diff_wrapper workaround.
- Add _nnx_cache_state_template / _nnx_init_cache_dict helpers
parametrised by mode so prefill (batch 1) and decode_state (batch N)
pull from the right abstract model.
maxengine.py (load_params):
- New _load_params_nnx: accepts user-provided NNX-shape params or loads
via from_pretrained. For user-provided params, materializes a concrete
model once via _create_model_fn() to capture a real rest state for
nnx.merge (wasteful but simple; the from_pretrained branch avoids
this). Refreshes self.graphdef from the concrete model so subsequent
merges line up exactly.
- Builds self.abstract_params, populates self.prefill_kv_cache_annotations
and self.kv_cache_annotations (using model_ar for the latter so
bulk_insert's substring lookup hits), wraps both into NamedSharding.
- pure_nnx + quantization, pure_nnx + LoRA, pure_nnx +
stack_prefill_result_cache=True, pure_nnx + prefill_multisampling,
and pure_nnx + prefill_concat raise NotImplementedError for now;
the Linen path is the workaround. AOT compilation
(aot_compile / _compile_generate_and_get_layouts) is not gated and
may work as-is; not exercised by tests yet.
maxengine.py (init_decode_state, _prefill_jit, _generate_jit):
- _init_decode_state_nnx zero-initializes a pure-dict cache from
model_ar (so the leading batch dim matches generate's input shape)
and builds kv_cache_annotations_named per leaf by reading
nnx.Cache.metadata. Tries "out_sharding", "sharding", and
"sharding_names" because Flax 0.12.6 renamed these.
- _prefill_jit / _generate_jit add an `if config.pure_nnx:` branch
that calls _nnx_run_model in place of self.model.apply with
mutable=["cache"]. existing_prefix.cache is threaded as a pure-dict
cache directly (no params|{"cache":...} dict-merge — params is an
nnx.State, not a dict).
maxtext_utils.py:
- New get_prefill_kv_cache_annotations_nnx / get_kv_cache_annotations_nnx
that mirror the Linen helpers' return shape (per-leaf PartitionSpec
tree). Both delegate to _nnx_cache_partition_specs which extracts
nnx.Cache state via nnx.split, calls
get_nnx_named_sharding_with_scan_axis inside
nn_partitioning.axis_rules so logical axes ("layers", "cache_batch",
"norm", ...) resolve to physical mesh axes, and converts the result
to a pure-dict tree.
tests/unit/maxengine_test.py:
- New tests: test_init_nnx, test_basic_prefill_nnx (with NaN/inf and
per-layer cache shape checks), test_basic_decode_nnx (4-step generate
with next_pos advancement check), test_quantize_raises_for_nnx,
test_lora_raises_for_nnx.
- New test_linen_nnx_parity_prefill: bridges Linen-init params into
the NNX engine via linen_nnx_converter (convert_linen_to_nnx ->
_strip_value_wrappers -> nnx.replace_by_pure_dict) and asserts the
NNX engine's prefill matches Linen on the same weights — logits
within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses
bf16 compute) and exact greedy first-token argmax.
- Existing Linen tests untouched.
Test summary: 9 passed, 1 skipped (test_chunked_prefill is a
pre-existing CPU-only skip). bash lint.sh: codespell + pylint + pyink
all green.
NNX Migration Route Map
- `pure_nnx` flag, `init_state_fn`, `TrainStateNNX`, NNX utils. Linen workflow unchanged. (PR #3427)
- `get_abstract_state_nnx`, `get_named_sharding_nnx`, `set_named_sharding_nnx`, `get_partition_spec_nnx`, `get_mesh_from_config`. (PR #3470)
- `TrainStateNNX`, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
- `pure_nnx=False` stays default. (PR #3766)
- `NotImplementedError` on the NNX path; `pure_nnx=True` + `use_dpo=True` is now supported.
- `maxengine.py`; `pure_nnx=True` now drives a real NNX inference flow end-to-end (prefill, generate, KV cache).
- (`standalone_checkpointer`, `generate_param_only_checkpoint`, `layerwise_quantization`, `convert_gpt3_ckpt_from_paxml`).
- `custom_vjp` for NNX (perf optimization, not correctness).
- True; regenerate sharding goldens; flip back integration-test `pure_nnx=False` annotations.

Description
PR5 audited `maxengine.py` and routed three call sites to the Linen path (the `state.params` shape, `state_mesh_shardings.params`, and Linen-format inference checkpoint serving), because `transformer_as_linen` returns a Linen module regardless of `pure_nnx`. That preserved Linen serving but meant `pure_nnx=True` users silently got the Linen engine — the cited rationale was "the flag affects training, not inference serving."

This PR replaces those routes with a real NNX inference flow. When `config.pure_nnx=True`, the engine now builds an NNX `Transformer` via `from_config(rngs=...)`, splits params/cache/rest with `nnx.split(model, nnx.Param, nnx.Cache, ...)`, and at every JIT body merges the model concretely with `nnx.merge(graphdef, params, cache, rest, copy=True)` to run the forward pass. Linen is byte-for-byte preserved: every NNX edit is gated `if config.pure_nnx:`.

The diff is +651 / −34 across 3 files, with the only physical deletion being the "use Linen path regardless of `pure_nnx`" comment block at the top of `MaxEngine.__init__`.
Part 1: Engine construction — two abstract NNX models

- `inference/maxengine/maxengine.py` (`__init__`): builds two abstract `Transformer` instances on the NNX path: `self.model` with `model_mode=PREFILL` (batch=1, single padded prompt) and `self.model_ar` with `model_mode=AUTOREGRESSIVE` (batch=`micro_batch_size_to_train_on`, decode-state shape). One graphdef (`self.graphdef`) is cached from the PREFILL model; both abstract models contain initialized `nnx.Cache` vars (KVCache `__init__` allocates both prefill and AR caches when `model_mode in {PREFILL, AUTOREGRESSIVE}`).
- `inference/maxengine/maxengine.py` (`__init__`): calls `nnx.eval_shape` directly inside `nn_partitioning.axis_rules(config.logical_axis_rules)` rather than `create_nnx_abstract_model` — the latter wraps the trace in `jax.set_mesh(mesh)`, which makes Flax 0.12.6's `_to_variable` resolve logical axis names against the global mesh and raise on logical-only names like `"norm"`. Same reason `get_abstract_state_nnx` avoids `set_mesh`.
CACHE_BATCHvsCACHE_BATCH_PREFILL) from the model's constructionmodel_mode. The decode_state cache must use AR-mode annotations so thatbulk_insert'scache_batchsubstring lookup hits.Part 2: NNX param loading
Part 2: NNX param loading

- `inference/maxengine/maxengine.py` (`_load_params_nnx`): accepts user-provided NNX-shape params or loads via `from_pretrained`. For user-provided params, materializes a concrete model once via `_create_model_fn()` to capture a real `_nnx_rest_state` (RNG vars) for `nnx.merge`. For checkpoint loads, extracts `(graphdef, params, cache, rest)` from the loaded `nnx_model`. Refreshes `self.graphdef` from the concrete model so subsequent merges line up exactly.
- `inference/maxengine/maxengine.py` (`_load_params_nnx`): builds `self.abstract_params` as `ShapeDtypeStruct`s of the param state, populates `self.prefill_kv_cache_annotations` and `self.kv_cache_annotations` (using AR-mode for the latter so `cache_batch` appears in the per-leaf logical-axes tuple), and wraps both in `NamedSharding`.
Part 3: NNX init_decode_state

- `inference/maxengine/maxengine.py` (`_init_decode_state_nnx`): zero-initializes a pure-dict cache from the AR-mode abstract model, so the leading batch dim matches `generate`'s input shape `(per_device_batch_size * mesh.size, 1)`. Builds `kv_cache_annotations_named` per leaf by reading `nnx.Cache.metadata["out_sharding" / "sharding" / "sharding_names"]` (Flax 0.12.6 renames vary). Returns the same `decode_state` dict shape as the Linen path.
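The metadata fallback can be pictured as below; treating the variable metadata as a plain dict is an assumption of this sketch:

```python
def _cache_sharding_names(cache_var_metadata: dict):
  # Flax 0.12.x renamed the sharding-annotation field; try each known key.
  for key in ("out_sharding", "sharding", "sharding_names"):
    if key in cache_var_metadata:
      return cache_var_metadata[key]
  return None
```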
Part 4: Cache adapter — NNX state ↔ pure dict

NNX cache state is an `nnx.State` with `nnx.Cache`-wrapped leaves; tree paths include a `.value` accessor at the end. Linen's mutable cache is a plain nested dict with the same leaf names but no `.value`. The engine's `bulk_insert` / `_insert_jit` / `insert_partial` / `_maybe_*_prefill_result_cache` helpers walk the cache via `tree_map_with_path` and switch on `path[-1].key` (the cache var name like `"cached_prefill_key"`).

- `inference/maxengine/maxengine.py` (`_nnx_cache_state_template` / `_nnx_init_cache_dict` / `_nnx_run_model`): bridge via `nnx.State.to_pure_dict()` (after the model run) and `nnx.replace_by_pure_dict(template, pure_dict)` (before `nnx.merge`). The cache plumbing helpers see the same shape on Linen and NNX, requiring zero modifications.
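The round-trip in isolation, as a small sketch (the helper names are this sketch's, not the engine's):

```python
from flax import nnx


def cache_out(model) -> dict:
  # After the model run: nnx.Cache state -> plain nested dict, so tree
  # paths match Linen's mutable-cache layout (no trailing .value).
  return nnx.state(model, nnx.Cache).to_pure_dict()


def cache_in(cache_template: nnx.State, cache_dict: dict) -> nnx.State:
  # Before nnx.merge: re-wrap the plain dict onto an abstract nnx.Cache
  # state template, restoring Variable metadata around each leaf.
  nnx.replace_by_pure_dict(cache_template, cache_dict)
  return cache_template
```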
Part 5: NNX prefill / generate JIT branches

- `inference/maxengine/maxengine.py` (`_prefill_jit`, `_generate_jit`): replaces `self.model.apply(params | {"cache": cache}, ..., mutable=["cache"])` with a call to `_nnx_run_model`, which does `nnx.merge(graphdef, params, cache, rest, copy=True)` → `model(...)` → `nnx.state(model, nnx.Cache).to_pure_dict()`. `copy=True` mirrors the train.py `diff_wrapper` workaround for `TraceContextError` from reused Variable objects. The Linen branch (the `else:`) is the prior code, unchanged.
- `inference/maxengine/maxengine.py` (`_prefill_jit`): when `existing_prefix is not None`, the NNX branch threads `existing_prefix.cache` (already a pure dict) directly into `_nnx_run_model`'s `cache_dict=` arg instead of the Linen `params | {"cache": ...}` dict-merge — `params` is an `nnx.State`, not a dict.
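Putting the pieces together, `_nnx_run_model` is roughly the following; the argument order is assumed, and `cache_in` refers to the Part 4 sketch:

```python
from flax import nnx


def _nnx_run_model(graphdef, params, cache_dict, rest, cache_template,
                   *model_args, **model_kwargs):
  cache = cache_in(cache_template, cache_dict)
  # copy=True yields fresh Variable objects per trace, avoiding the
  # TraceContextError from reusing Variables across JIT re-traces.
  model = nnx.merge(graphdef, params, cache, rest, copy=True)
  logits = model(*model_args, **model_kwargs)
  # Hand the mutated KV cache back to the adapters as a pure dict.
  return logits, nnx.state(model, nnx.Cache).to_pure_dict()
```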
Part 6: KV-cache annotation helpers

- `utils/maxtext_utils.py`: new `get_prefill_kv_cache_annotations_nnx(abstract_model, config, mesh)` and `get_kv_cache_annotations_nnx(abstract_model, config, mesh)` mirror the Linen helpers' return shape (a tree of `PartitionSpec`). Both delegate to `_nnx_cache_partition_specs`, which extracts `nnx.Cache` state via `nnx.split(model, nnx.Cache, ...)`, calls `get_nnx_named_sharding_with_scan_axis` inside `nn_partitioning.axis_rules(config.logical_axis_rules)` so logical axes (`layers`, `cache_batch`, `norm`, ...) resolve to physical mesh axes, then converts the resulting `nnx.State` to a pure-dict tree via `to_pure_dict()`.
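Schematically (`get_nnx_named_sharding_with_scan_axis` is named in the PR, but the signature used here is assumed):

```python
from flax import nnx
from flax.linen import partitioning as nn_partitioning


def _nnx_cache_partition_specs(abstract_model, config, mesh):
  # Pull out just the nnx.Cache leaves of the abstract model.
  _, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...)
  with nn_partitioning.axis_rules(config.logical_axis_rules):
    # Resolve logical names ("layers", "cache_batch", "norm", ...) to
    # physical mesh axes.
    shardings = get_nnx_named_sharding_with_scan_axis(cache_state, config, mesh)
  # Per-leaf PartitionSpec tree as a plain nested dict.
  return shardings.to_pure_dict()
```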
NotImplementedErroron the NNX path with PR pointers)quantize_params(NNX + quantization)pure_nnx=Falseload_paramswhencheckpoint_is_quantized=Truepure_nnx=Falseload_single_adapter/apply_adapter/unapply_adapter(NNX + LoRA)pure_nnx=False_prefill_multisampling_jit(NNX)pure_nnx=Falseprefill_concat(NNX)pure_nnx=Falseload_paramswhenstack_prefill_result_cache=True(NNX)pure_nnx=Falsescan_layers=Truethe NNX cache leaves are already stacked on axis 0; the engine's manual-stack helper assumes an unstacked Linen tree shapeThese are deliberate scope cuts, not silent fallbacks. Each raises with a message pointing at the responsible follow-up PR. The Linen path remains the workaround for users who hit any of these sites.
Deferred (intentionally out of scope)

- AOT compilation (`aot_compile`, `_compile_generate_and_get_layouts`): the functions are not gated on `pure_nnx` and are structure-agnostic, so they likely work as-is, but they are not exercised by PR7's tests. Recommended manual validation when a JetStream-style serving setup first hits this.
- `gemma2-2b` decode parity on a real checkpoint: the in-PR parity test (below) demonstrates equivalence on a 2-layer test config. A `python3 -m MaxText.decode pure_nnx=True ...` smoke run against `gemma2-2b` (with an NNX-format ckpt or a converted Linen one) is recommended manual validation before flipping defaults in PR11.
Tests

- `test_init_nnx`: engine constructs and exposes `graphdef` + abstract `Transformer`.
- `test_basic_prefill_nnx`: prefill produces the same dict shape as the Linen path; logits + every cache leaf are finite (catches silent NaN/inf from a bad `nnx.merge` or cache round-trip); per-layer cache leading axis equals `num_decoder_layers` (catches scan-axis misalignment).
- `test_basic_decode_nnx`: prefill → insert → 4-step generate. Verifies `next_pos` advances by exactly 1 each step (catches off-by-one cache pointer bugs) and logits stay finite at every step.
- `test_linen_nnx_parity_prefill`: strongest evidence — the same Linen-init weights are bridged into the NNX engine via linen_nnx_converter.py (`convert_linen_to_nnx` → `_strip_value_wrappers` → `nnx.replace_by_pure_dict`), then prefill is run on both engines with the same prompt + greedy. Asserts logits match within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses bf16 compute) and the greedy first-token argmax is exactly the same. A failure here means the NNX forward pass diverged from Linen on identical weights.
- `test_quantize_raises_for_nnx`, `test_lora_raises_for_nnx`: assert the carve-outs fire with `NotImplementedError`, so users hitting them get a clear "tracked in PR8/PR9" message instead of an opaque crash.
- Existing Linen tests (`test_basic_prefill`, `test_basic_decode`, `test_stack_and_unstack_prefill_cache`) are untouched: every NNX edit is gated on `if config.pure_nnx:` and `pure_nnx=False` is the default, so the Linen branch is byte-for-byte the prior code.
- `bash lint.sh`: codespell + pylint + pyink all green.
Checklist

Before submitting this PR, please make sure (put X in square brackets):

- [ ] … `gemini-review` label.