From a62df5908d3a4c1a345233938d45cfe2353ee900 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 24 Apr 2026 22:05:58 +0000 Subject: [PATCH 1/5] NNX: add TrainState, model creation utilities, and training loop support - Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch --- src/maxtext/common/checkpointing.py | 32 +- src/maxtext/input_pipeline/olmo_data.py | 1 - src/maxtext/layers/nnx_decoders.py | 3 +- src/maxtext/trainers/pre_train/train.py | 503 +++++++++++------- src/maxtext/utils/gradient_accumulation.py | 35 +- src/maxtext/utils/maxtext_utils.py | 236 ++++++-- src/maxtext/utils/model_creation_utils.py | 183 ++++--- src/maxtext/utils/muon_utils.py | 60 ++- src/maxtext/utils/sharding.py | 121 ++++- src/maxtext/utils/train_utils.py | 54 +- .../integration/setup_train_loop_nnx_test.py | 140 +++++ tests/unit/checkpointing_nnx_load_test.py | 106 ++++ tests/unit/gradient_accumulation_nnx_test.py | 159 ++++++ tests/unit/maxtext_utils_test.py | 263 ++++++++- tests/unit/muon_utils_test.py | 224 ++++++++ tests/unit/nnx_decoders_test.py | 73 ++- tests/unit/optimizers_test.py | 116 +++- tests/unit/sharding_nnx_test.py | 161 ++++++ tests/unit/train_nnx_test.py | 239 +++++++++ tests/unit/train_state_nnx_checkpoint_test.py | 412 ++++++++++++++ tests/unit/train_state_nnx_test.py | 90 ++++ tests/unit/train_utils_nnx_test.py | 149 ++++++ 22 files changed, 2985 insertions(+), 375 deletions(-) create mode 100644 tests/integration/setup_train_loop_nnx_test.py create mode 100644 tests/unit/checkpointing_nnx_load_test.py create mode 100644 tests/unit/gradient_accumulation_nnx_test.py create mode 100644 tests/unit/muon_utils_test.py create mode 100644 tests/unit/sharding_nnx_test.py create mode 100644 tests/unit/train_nnx_test.py create mode 100644 tests/unit/train_state_nnx_checkpoint_test.py create mode 100644 tests/unit/train_state_nnx_test.py create mode 100644 tests/unit/train_utils_nnx_test.py diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index dc01262e6c..ad7618868a 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -20,6 +20,7 @@ from absl import flags import datetime from etils import epath +from flax import nnx from flax.training import train_state import jax from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE @@ -571,7 +572,7 @@ def load_state_if_possible( load_parameters_from_path: str, load_full_state_from_path: str, checkpoint_storage_concurrent_gb: int, - abstract_unboxed_pre_state: train_state.TrainState, + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, enable_single_replica_ckpt_restoring: bool | None = False, dataset_type: str | None = "tfds", step: int = -1, # -1 means latest @@ -639,9 +640,14 @@ def map_to_pspec(data): ) ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX + restore_target = abstract_unboxed_pre_state + if isinstance(abstract_unboxed_pre_state, nnx.State): + restore_target = abstract_unboxed_pre_state.to_pure_dict() + + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) checkpoint_args = ocp.args.PyTreeRestore( - item=abstract_unboxed_pre_state, + item=restore_target, restore_args=restore_args, partial_restore=True, ) @@ -679,9 +685,14 @@ def map_to_pspec(data): return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) if load_parameters_from_path != "": + if isinstance(abstract_unboxed_pre_state, nnx.State): + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) + else: + params = abstract_unboxed_pre_state.params + restored_params = load_params_from_path( load_parameters_from_path, - abstract_unboxed_pre_state.params, + params, checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, @@ -773,7 +784,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step # Determine the effective step for saving a checkpoint. # If 'step' is not provided, this call is for a potential final checkpoint # and use the last completed step from the state. - actual_step = (int(state.step) - 1) if step is None else int(step) + if step is not None: + actual_step = int(step) + else: + if config.pure_nnx: + actual_step = int(state.optimizer.step) - 1 + else: + # Linen TrainState has .step attribute + actual_step = int(state.step) - 1 + + if config.pure_nnx: + # Convert nnx.State to dict. + state = state.to_pure_dict() # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. # This occurs if this function was called: diff --git a/src/maxtext/input_pipeline/olmo_data.py b/src/maxtext/input_pipeline/olmo_data.py index 4613c4eb71..82c258b78a 100644 --- a/src/maxtext/input_pipeline/olmo_data.py +++ b/src/maxtext/input_pipeline/olmo_data.py @@ -27,7 +27,6 @@ import bisect import dataclasses import hashlib -import io import json import os import struct diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 6932eed6c1..100d9c6817 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -35,6 +35,7 @@ MODEL_MODE_TRAIN, Config, DecoderBlockType, + MultimodalInput, ShardMode, ) from maxtext.inference import page_manager @@ -1059,10 +1060,10 @@ def __call__( previous_chunk=None, slot: None | int = None, page_state: None | page_manager.PageState = None, - multimodal_input: None | Any = None, kv_caches: list[jax.Array] | None = None, attention_metadata=None, deepstack_visual_embeds: None | list[jnp.ndarray] = None, + multimodal_input: None | MultimodalInput = None, ): cfg = self.config assert decoder_input_tokens.ndim == 2 # [batch, len] diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 1011563a7b..bd475deba4 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -35,8 +35,9 @@ import jax import jax.numpy as jnp +from jax.sharding import NamedSharding -from flax import linen as nn +from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig @@ -68,6 +69,7 @@ from maxtext.utils import maxtext_utils from maxtext.utils import qk_clip_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss @@ -92,11 +94,11 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr """loss_fn for both train and eval. Args: - model: A nn.Module + model: A nn.Module (Linen) or nnx.Module (NNX). config: Config of parameters data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout - params: Model params + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. + params: Model params (Linen); unused for NNX (params are part of the model). is_train: True for train_step and False for eval_step Returns: @@ -183,7 +185,7 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr xent_sum = jnp.sum(xent) total_z_loss = jnp.sum(z_loss) else: - # Flax NNX model + # Flax NNX model: forward pass, then pop Intermediates sown during it. logits = model( decoder_input_tokens=data["inputs"], decoder_positions=data["inputs_position"], @@ -194,7 +196,11 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr decoder_target_tokens=data["targets"], decoder_target_mask=data["targets_segmentation"], ) - intermediate_outputs = {} + intermediates = nnx.pop(model, nnx.Intermediate) + intermediate_outputs = intermediates.to_pure_dict() + + if config.num_vocab_tiling > 1: + raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") if (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. @@ -286,74 +292,111 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr return loss, aux -def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): - """ +def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng=None): + """Training step for both Linen and NNX models. Args: - model: A nn.Module - state: A pytree of the current state of the model - data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout + model: A nn.Module (Linen) or nnx.GraphDef of the TrainStateNNX (NNX). + config: Hyperparameters. + state_mesh_shardings: PyTree of PartitionSpecs for the train state. + params_shardings: PyTree of PartitionSpecs for model parameters, used for gradient accumulation. + state: Linen TrainState or NNX pure State. + data: Training data batch. + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. Returns: - new_state: Same format as state. + new_state: Updated Linen TrainState or NNX pure State. metrics: Dictionary of model metrics such as loss, training rate, etc. - rng2: A new rng key that can be used in future calls. - """ - reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = ( - [], - [], - [], - loss_fn, - ) - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn + # --- Per-path initialization --- + if isinstance(model, nn.Module): + reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = [], [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn + params = state.params + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args + else: + if config.use_dpo: + raise NotImplementedError("DPO for NNX modules has not been implemented.") + state = nnx.merge(model, state) # reconstruct TrainStateNNX + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] - params = state.params + # --- Gradient computation --- if config.gradient_accumulation_steps > 1: loss, aux, raw_grads = gradient_accumulation_loss_and_grad( - _loss_fn, + ga_fn, config, - model, - params, + ga_model, + ga_params, params_shardings, data, - dropout_rng, - extra_dpo_args, + ga_rng, + ga_dpo, ) else: - if config.optimizer_memory_host_offload: - if config.use_dpo: + if isinstance(model, nn.Module): + if config.optimizer_memory_host_offload and config.use_dpo: reference_params = jax.device_put( reference_params, max_utils.with_memory_kind(reference_params_sharding, "device"), ) extra_dpo_args = [reference_params] - if config.shard_optimizer_over_data: - params = jax.tree.map( - functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), - params, - params_shardings, + if config.shard_optimizer_over_data: + params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + params, + params_shardings, + ) + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + pure_params = params["params"] if sparsity_enabled else params + batch_stats = params.get("batch_stats", {}) + + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func( + model, + config, + data, + dropout_rng, + pure_params, + *extra_dpo_args, + sparsity_state=batch_stats, + is_train=True, ) - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - pure_params = params["params"] if sparsity_enabled else params - batch_stats = params.get("batch_stats", {}) + else: + model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + if config.parameter_memory_host_offload: + # Params are kept on host (pinned_host) in in_shardings. Move only Param + # variables to device before the forward/backward pass so that all dot_general + # operands share the same memory space (XLA on GPU requires this). + # Using params_shardings (Param-only) avoids Shardy rank mismatches that + # occur when applying PartitionSpec() (rank-0 in SDY) to rank-1 RNG key tensors. + device_param_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + params_shardings, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + curr_params = jax.device_put(curr_params, device_param_shardings) + nnx.update(state.model, curr_params) # ensure state.model has device params for optimizer update + if config.shard_optimizer_over_data: + curr_params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + curr_params, + params_shardings, + ) + nnx.update(state.model, curr_params) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - (loss, aux), raw_grads = grad_func( - model, - config, - data, - dropout_rng, - pure_params, - *extra_dpo_args, - sparsity_state=batch_stats, - is_train=True, - ) + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(model_graphdef, param, rest, copy=True) + loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) + return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, @@ -364,6 +407,8 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat raw_grads, max_utils.with_memory_kind(params_shardings, "device"), ) + + # Extract aux fields into locals intermediate_outputs = aux["intermediate_outputs"] xent_sum = aux["xent_sum"] total_weights = aux["total_weights"] @@ -373,67 +418,90 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) - if config.gradient_clipping_threshold > 0: - grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) - else: - grads = raw_grads - - if config.optimizer_memory_host_offload: - state = state.replace( - opt_state=jax.device_put( - state.opt_state, - jax.tree_util.tree_map( - lambda x: x.with_memory_kind(kind="device"), - state_mesh_shardings.opt_state, - ), - ) - ) - # Move all parameters to device before optimizer update - if config.parameter_memory_host_offload: - max_logging.log("\nMoving all parameters to device before optimizer update") - - def move(path, value): - max_logging.log(f"train.py: Moving f{path} to device") - return value.with_memory_kind(kind="device") - - state = state.replace( - params=jax.device_put( - state.params, - jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), - ) - ) - # Re-wrap grads to match state.params structure if it's a dict of collections - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - if sparsity_enabled: - full_grads = {"params": grads} - if sparsity_enabled and "batch_stats" in state.params: - batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {})) - full_grads["batch_stats"] = batch_stats_grads - full_grads = max_utils.unbox_logicallypartioned(full_grads) - else: - full_grads = grads - - if getattr(config, "skip_step_on_spikes", False): - grad_norm = max_utils.l2norm_pytree(grads) - # TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually. - updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm) - new_params = optax.apply_updates(state.params, updates) - - new_state = state.replace( - step=state.step + 1, - params=new_params, - opt_state=new_opt_state, - ) + if isinstance(model, nn.Module): + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + state = state.replace( + opt_state=jax.device_put( + state.opt_state, + jax.tree_util.tree_map( + lambda x: x.with_memory_kind(kind="device"), + state_mesh_shardings.opt_state, + ), + ) + ) + # Move all parameters to device before optimizer update + if config.parameter_memory_host_offload: + max_logging.log("\nMoving all parameters to device before optimizer update") + + def move(path, value): + max_logging.log(f"train.py: Moving f{path} to device") + return value.with_memory_kind(kind="device") + + state = state.replace( + params=jax.device_put( + state.params, + jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), + ) + ) + # Re-wrap grads to match state.params structure if it's a dict of collections + # (when weight_sparsity is enabled, params has both 'params' and 'batch_stats' keys). + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + if sparsity_enabled: + full_grads = {"params": grads} + if "batch_stats" in state.params: + batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {})) + full_grads["batch_stats"] = batch_stats_grads + full_grads = max_utils.unbox_logicallypartioned(full_grads) + else: + full_grads = grads + + if getattr(config, "skip_step_on_spikes", False): + grad_norm = max_utils.l2norm_pytree(grads) + # TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually. + updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm) + new_params = optax.apply_updates(state.params, updates) + + new_state = state.replace( + step=state.step + 1, + params=new_params, + opt_state=new_opt_state, + ) + else: + new_state = state.apply_gradients(grads=full_grads) + + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") + # Updates the shape to be aligned with state. + moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() + new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) else: - new_state = state.apply_gradients(grads=full_grads) + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + # state.optimizer is an NNX Optimizer module; state_mesh_shardings.optimizer + # is an NNX State. Use nnx.state() to get a compatible State for device_put. + device_opt_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + opt_state = nnx.state(state.optimizer) + new_opt_state = jax.device_put(opt_state, device_opt_shardings) + nnx.update(state.optimizer, new_opt_state) + state.apply_gradients(grads) + new_state = state - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Flax 'sow' returns a tuple, so we take the first element [0]. - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias + target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -447,8 +515,9 @@ def move(path, value): "learning/total_weights": total_weights, } if config.use_qk_clip: - # Apply QK-Clip - new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) + # Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX) + if isinstance(model, nn.Module): + new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) # Report max_logits metric global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs) @@ -458,7 +527,11 @@ def move(path, value): if not config.optimizer_memory_host_offload: scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads) scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) - scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + if isinstance(model, nn.Module): + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + else: + model_params = nnx.state(new_state.model, nnx.Param) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(model_params) if config.use_dpo: scalar_metrics["learning/dpo_loss"] = aux["dpo_loss"] scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] @@ -466,31 +539,34 @@ def move(path, value): "scalar": scalar_metrics, "scalars": {}, } - if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) - if config.use_dpo: - new_state = _merge_dpo_state(new_state, reference_params) - - return new_state, metrics + if isinstance(model, nn.Module): + if config.use_dpo: + new_state = _merge_dpo_state(new_state, reference_params) + return new_state, metrics + return nnx.state(new_state), metrics -def eval_step(model, config, state, data, dropout_rng): +def eval_step(model, config, state, data, dropout_rng=None): """eval_step no backprop and new state compared with train_step.""" + if isinstance(model, nn.Module): + reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn - reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - pure_params = state.params["params"] if sparsity_enabled else state.params - batch_stats = state.params.get("batch_stats", {}) + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + pure_params = state.params["params"] if sparsity_enabled else state.params + batch_stats = state.params.get("batch_stats", {}) - eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) - loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) + eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) + loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) + else: + state = nnx.merge(model, state) # reconstruct TrainStateNNX + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -518,7 +594,7 @@ def eval_step(model, config, state, data, dropout_rng): "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, } - if config.use_dpo: + if isinstance(model, nn.Module) and config.use_dpo: metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] return metrics @@ -540,32 +616,46 @@ def train_loop(config, recorder, state=None): state, ) = train_utils.setup_train_loop(config, recorder) - if config.use_dpo: - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if isinstance(model, nn.Module): + if config.use_dpo: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_dpo_state(state, reference_params) + state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + jit_model = model + else: + if config.use_dpo: + raise NotImplementedError("DPO is not supported for NNX models.") + jit_model, state = nnx.split(state) params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + jit_model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) + with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, - model, - mesh, - state, - state_mesh_shardings, - train_step, - eval_step, - eval_data_iterator, - params_shardings, - ) shaped_batch = maxtext_utils.get_shaped_batch(config) - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (state, shaped_batch, init_rng)) + elif config.shard_optimizer_over_data: + # NNX: reshard state so params match the data-sharded in_shardings (Zero-1 layout) + state = jax.device_put(state, state_mesh_shardings) + if isinstance(model, nn.Module): + lower_args = (state, shaped_batch, init_rng) + else: + lower_args = (state, shaped_batch) + maxtext_utils.maybe_dump_jaxpr(config, p_train_step, lower_args) if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled = p_train_step.lower(*lower_args).compile() compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) @@ -574,7 +664,11 @@ def train_loop(config, recorder, state=None): metric_logger_instance = metric_logger.MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger_instance.write_setup_info_to_tensorboard(state.params) + if isinstance(model, nn.Module): + setup_params = state.params + else: + _, setup_params, _ = nnx.split(state.model, nnx.Param, ...) + metric_logger_instance.write_setup_info_to_tensorboard(setup_params) _job_completed_gracefully = False try: @@ -584,59 +678,62 @@ def train_loop(config, recorder, state=None): with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) - # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + if isinstance(model, nn.Module): + # pylint: disable=not-callable + step_rng_args = (jax.jit(jax.random.fold_in)(init_rng, step),) + else: + step_rng_args = () with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - state, metrics = p_train_step(state, example_batch, nextrng) - - step_time_delta = datetime.datetime.now() - last_step_completion - last_step_completion = datetime.datetime.now() - - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) - - if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): - jax.block_until_ready(state) # Ensure compilation has finished. - gcs_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) - - if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: - assert eval_data_iterator - # Explicitly reset the eval iterator and counters before starting the eval loop - eval_data_iterator.reset() - metric_logger_instance.reset_eval_metrics() - - eval_step_count = 0 - # pylint: disable=not-callable - for eval_batch in eval_data_iterator: - # Shard input eval data - eval_batch = jax.device_put(eval_batch, sharding.get_input_data_sharding(config, mesh)) - if config.eval_steps > 0 and eval_step_count >= config.eval_steps: - break - with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step(state, eval_batch, nextrng) - metric_logger_instance.record_eval_metrics(step, metrics=eval_metrics) - max_logging.log(f"Completed eval step {eval_step_count}") - eval_step_count += 1 - metric_logger_instance.record_eval_metrics(step, eval_step_count=eval_step_count) - if metric_logger_instance.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: - prof.deactivate() - raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - - prof.maybe_deactivate_profiler(step, state) - - if step == start_step: - max_utils.print_mem_stats("After params initialized") - - metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta) + state, metrics = p_train_step(state, example_batch, *step_rng_args) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) + + if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): + jax.block_until_ready(state) # Ensure compilation has finished. + gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: + assert eval_data_iterator + # Explicitly reset the eval iterator and counters before starting the eval loop + eval_data_iterator.reset() + metric_logger_instance.reset_eval_metrics() + + eval_step_count = 0 + # pylint: disable=not-callable + for eval_batch in eval_data_iterator: + # Shard input eval data + eval_batch = jax.device_put(eval_batch, sharding.get_input_data_sharding(config, mesh)) + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, *step_rng_args) + metric_logger_instance.record_eval_metrics(step, metrics=eval_metrics) + max_logging.log(f"Completed eval step {eval_step_count}") + eval_step_count += 1 + metric_logger_instance.record_eval_metrics(step, eval_step_count=eval_step_count) + if metric_logger_instance.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: + prof.deactivate() + raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") + + prof.maybe_deactivate_profiler(step, state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") + + metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index 9bad1cfb35..e1699647c6 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp from jax.sharding import NamedSharding +from flax import nnx from maxtext.common.common_types import ShardMode from maxtext.utils.sharding import maybe_shard_with_name @@ -49,7 +50,8 @@ def gradient_accumulation_loss_and_grad( config: Model and training configuration object. Must contain `gradient_accumulation_steps` and `shard_optimizer_over_data`. model: The model module. - params: The model parameters (PyTree). + params: The model parameters (PyTree). This is only used for Linen. For NNX, + we can get the params from the model. params_shardings: The sharding constraints for the parameters (PyTree). data: A PyTree of batched data. The leading dimension is assumed to be the total batch size (microbatch_size * num_accumulations). @@ -67,12 +69,18 @@ def _maybe_shard_with_name(inputs, sharding_names): """Wrapper of maybe_shard_with_name with fixed shard_mode""" return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding) + is_nnx = isinstance(model, nnx.Module) + # For more efficient DP/ZeRO-1 + GA if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) else: ga_params_shardings = grad_shardings = params_shardings + + if is_nnx: + graphdef, params, rest = nnx.split(model, nnx.Param, ...) + # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints # so that all-gather is done once in the lower precision before the gradient accumulation loop if config.shard_optimizer_over_data: @@ -87,11 +95,27 @@ def convert_to_bf16(param): ga_params = params ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + if is_nnx: + grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True) + else: + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) def accumulate_gradient(acc_grad_and_loss, data): ga_params = acc_grad_and_loss["ga_params"] - (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True) + if is_nnx: + # Reconstruct the model using the fixed parameters (ga_params) + # and the advancing non-parameter state (RNGs) from the carry. + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) + _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) + acc_grad_and_loss["rest_state"] = next_rest_state + else: + rng = ( + jax.random.fold_in(dropout_rng, acc_grad_and_loss["total_weights"].astype(jnp.int32)) + if dropout_rng is not None + else None + ) + (_, aux), cur_batch_gradient = grad_func(model, config, data, rng, ga_params, *extra_dpo_args, is_train=True) acc_grad_and_loss["loss"] += aux["xent_sum"] + aux.get("dpo_loss", 0.0) acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] acc_grad_and_loss["indexer_loss"] += aux["indexer_loss"] @@ -119,6 +143,8 @@ def reshape_to_microbatch_accumulations(batch_arr): "mtp_loss": 0.0, "ga_params": ga_params, } + if is_nnx: + init_grad_and_loss["rest_state"] = rest grad_and_loss, aux = jax.lax.scan( accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps @@ -134,6 +160,9 @@ def reshape_to_microbatch_accumulations(batch_arr): raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr + if is_nnx: + nnx.update(model, grad_and_loss["rest_state"]) + return loss, aux, raw_grads diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 76e0852921..37a0710cbb 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -20,21 +20,20 @@ import os from typing import Sequence -from flax import linen as nn +from flax import nnx, linen as nn +from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules from flax.linen import partitioning as nn_partitioning -from flax.training import train_state +from flax.training.train_state import TrainState import numpy as np -from jax.experimental import mesh_utils -from jax.experimental.serialize_executable import deserialize_and_load -from jax.sharding import AxisType, Mesh - import jax import jax.numpy as jnp +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils +from jax.experimental.serialize_executable import deserialize_and_load import optax - import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager @@ -55,6 +54,7 @@ from maxtext.utils import max_utils from maxtext.utils import sharding from maxtext.utils import elastic_utils +from maxtext.utils import maxtext_utils_nnx OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" @@ -102,7 +102,10 @@ def get_functional_train_with_signature( """Get the shardings (both state and data) for `train_step`.""" functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) functional_train.__name__ = "train_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = (state_mesh_shardings, None) # State, metrics static_argnums = () # We partial out the static argnums of model and config donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. @@ -113,7 +116,10 @@ def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shar """Get the shardings (both state and data) for `eval_step`.""" functional_eval = functools.partial(eval_step, model, config) functional_eval.__name__ = "eval_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch (NNX: no rng) + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = None # metrics static_argnums = () # We partial out the static argnums of model, config donate_argnums = () # state will be kept instead of being donated in eval_step @@ -1232,15 +1238,15 @@ def _apply_update(path, param): return state.replace(params=new_params) -def init_decode_state(apply_fn, params) -> train_state.TrainState: +def init_decode_state(apply_fn, params) -> TrainState: """Init train state with null opt state for decode.""" - state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore + state = TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore return state def init_training_state(apply_fn, params, tx): """Init train state with null opt state for decode.""" - state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) + state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state @@ -1368,7 +1374,7 @@ def setup_initial_state( is_training: True to initialize training state, False for decode state Returns: - state: the initialized train state + train_state: the initialized train state. For NNX, this is a TrainStateNNX instance state_mesh_annotations: the mesh annotations for the train state """ @@ -1407,29 +1413,48 @@ def setup_initial_state( else: # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] + + # For NNX, convert the pure dict to nnx.State using the abstract state as template + if config.pure_nnx: + nnx.replace_by_pure_dict(unboxed_abstract_state, state) + state = unboxed_abstract_state else: init_state_partial = init_state_fn init_state_partial.__name__ = "initialize_state" - # pylint: disable=not-callable - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )() - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - if sparsity_enabled and raw_params: # If we loaded a partial state, we need to merge it. - - def _merge_params(p_raw, p_init): - if isinstance(p_raw, jax.ShapeDtypeStruct): - return p_init - return p_raw - - merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) - state = state.replace(params=merged_params) - elif raw_params: - state = state.replace(params=raw_params) - - state = max_utils.unbox_logicallypartioned(state) + if config.pure_nnx: + state = jax.jit( + lambda: nnx.state(init_state_partial()), # Get state only, mapping to out_sharding structure + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + else: + # pylint: disable=not-callable + state = jax.jit( + init_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + if raw_params: # If we loaded a partial state, we need to merge it. + if config.pure_nnx: + # raw_params should have the same sharding info as in the model + nnx.update(state.model, raw_params) + else: + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + if sparsity_enabled: + # Sparsity-init keeps freshly initialized params for any leaf still + # represented as an abstract ShapeDtypeStruct in raw_params (i.e. not + # actually restored), and uses the restored value otherwise. + def _merge_params(p_raw, p_init): + if isinstance(p_raw, jax.ShapeDtypeStruct): + return p_init + return p_raw + + merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) + state = state.replace(params=merged_params) + else: + state = state.replace(params=raw_params) + if not config.pure_nnx: + state = max_utils.unbox_logicallypartioned(state) return state, state_mesh_annotations, state_mesh_shardings, data_iterator @@ -1444,6 +1469,9 @@ def get_logical_annotations(config, mesh, init_state_fn): def get_abstract_state(config, mesh, init_state_fn, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" + if config.pure_nnx: + return get_abstract_state_nnx(config, mesh, init_state_fn, is_training) + init_state_partial = init_state_fn with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -1487,6 +1515,148 @@ def move(path, x): ) +def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: + """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. + + Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also + inserts the partition_name axis at the correct scan_axis position for parameters created by + _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a + 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. + + Args: + abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). + mesh: JAX physical mesh. + + Returns: + Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding. + """ + + def _make_named_sharding(v): + val = v.get_value() + if not hasattr(val, "shape"): + # Non-tensor value (e.g., optax MaskedNode for non-trainable params). Preserve + # as-is so the treedef matches abs_var_state in the downstream jax.tree.map. + return v + metadata = v.get_metadata() + out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") + if not out_sharding: + pspec = PartitionSpec() + else: + # Insert the scan axis for parameters created by _create_scanned_layers. + # _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the + # axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these. + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + # Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits + # param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode + # scan_axis=0 for non-Param types. stacked_rest non-Param variables have + # param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct. + scan_axis = metadata.get("param_scan_axis", 0) + out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + # Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames + # 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted + # the scan axis. Only insert if not already present. + if partition_name not in out_sharding: + out_sharding.insert(scan_axis, partition_name) + out_sharding = tuple(out_sharding) + # Convert logical axis names to physical mesh axes using current context rules. + context_rules = get_logical_axis_rules() + local_rules = metadata.get("sharding_rules", ()) + if context_rules or local_rules: + rules = composite_rules(context_rules, local_rules) + pspec = PartitionSpec(*from_sharding_rules(out_sharding, rules)) + else: + pspec = PartitionSpec(*out_sharding) + return v.replace(NamedSharding(mesh, pspec)) + + return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True): + """Calculates the abstract sharded state and memory placement for an NNX TrainState. + + This function performs an abstract trace of the NNX model and optimizer using + `nnx.get_abstract_model`. It resolves logical sharding annotations into physical + JAX shardings and applies memory placement optimizations such as optimizer + sharding and host memory offloading (pinning to CPU RAM). + + Args: + config: Configuration object containing sharding and offloading hyperparameters + (e.g., shard_optimizer_over_data, optimizer_memory_host_offload). + mesh: JAX physical mesh used to resolve logical axis names to physical devices. + nnx_init_trainstate_fn: A zero-argument factory function that produces a + TrainStateNNX instance during the abstract trace. + is_training: Boolean indicating if the state is for training. If True, + optimizer state is processed and memory offloading strategies are applied. + + Returns: + A tuple containing (abstract_sharded_state, None, state_mesh_shardings): + abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with + fully resolved physical sharding and memory_kind metadata. + state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec + objects corresponding to each parameter/variable. + state_mesh_shardings: An nnx.State tree consisting of the raw JAX + Sharding objects corresponding to each parameter/variable. + """ + assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given." + + with nn_partitioning.axis_rules(config.logical_axis_rules): + # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply + # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers + # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally + # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers, + # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor. + # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls + # var.shape for every variable when a global mesh is active, but masked optimizer + # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode() + # which has no .shape and would raise AttributeError. We handle sharding + # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not + # needed here. + abs_model = nnx.eval_shape(nnx_init_trainstate_fn) + _, abs_var_state = nnx.split(abs_model) + named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + + state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + if is_training and config.shard_optimizer_over_data: + # Add data to sharding for optimizer state + optimizer_sharding = jax.tree_util.tree_map_with_path( + functools.partial(sharding.add_data_to_sharding, mesh), + abstract_state.optimizer, + state_mesh_shardings.optimizer, + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.optimizer_memory_host_offload: + optimizer_sharding = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.parameter_memory_host_offload: + assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading." + _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...) + state_params = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_params, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + nnx.update(state_mesh_shardings, state_params) + + abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings) + state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + return ( + abstract_sharded_state, + state_mesh_annotations, + state_mesh_shardings, + ) + + def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: None | PageState = None): """Get a shaped abstraction of the state (including optimizer)""" diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index ab85894832..92ef7e1251 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -1,3 +1,17 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Copyright 2023–2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,11 +32,11 @@ import dataclasses import collections from collections.abc import Sequence +from typing import Callable, overload from functools import partial import os import subprocess import sys -from typing import overload from etils import epath from flax import nnx import flax.linen as nn @@ -516,34 +530,99 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng return model -def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key=None): - """Returns (_create_model_partial, abstract_model) for AOT compilation. +def get_nnx_create_model_fn(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None) -> Callable: - This does not shard parameters or load checkpoints. It only builds the - abstract shape/dtype structure needed by get_abstract_state and optimizer - construction (e.g. Muon). + def _create_model(): + rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) + return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - Args: - config: the configuration - mesh: the device mesh - model_mode: train or inference - rng_key: optional RNG key + return _create_model + + +def create_nnx_abstract_model( + config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None +) -> tuple[Callable, nnx.Module]: + """Creates an abstract NNX model. Returns: - (_create_model_partial, abstract_model) where _create_model_partial() creates - a concrete model instance and abstract_model is the eval_shape result. + A tuple containing (create_model_fn, abstract_model): + create_model_fn: A zero-argument callable that produces a new model instance. + abstract_model: The stateful NNX model instance in an abstract state. """ - def _create_model(rng_key=None): - rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) - return from_config(config, mesh=mesh, rngs=rngs, model_mode=model_mode) + with nn.logical_axis_rules(config.logical_axis_rules): + _create_model = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) + if mesh is None: + _tmp = nnx.eval_shape(_create_model) + mesh = _tmp.mesh + # Use nnx.eval_shape + our scan-axis-aware sharding helper instead of + # nnx.get_abstract_model, which uses get_var_pspec internally and ignores + # param_scan_axis / nnx.PARTITION_NAME metadata set by _create_scanned_layers, + # causing the stacked layers axis to be missing from the PartitionSpec. + with jax.set_mesh(mesh): + abs_model = nnx.eval_shape(_create_model) + graphdef, abs_var_state = nnx.split(abs_model) + named_sharding_state = maxtext_utils.get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + return _create_model, nnx.merge(graphdef, abstract_state) + + +def create_nnx_sharded_model_hybrid(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): + """Creates a sharded model for hybrid NNX modules containing Linen sub-modules. - _create_model_partial = partial(_create_model, rng_key=rng_key) + DEPRECATED: This function is a transitional utility for the Linen-to-NNX + migration. It should be removed once all model components are ported to + pure NNX modules. + + This function specifically handles the complexity of "mixed" state initialization, + where logical sharding annotations must be resolved for both NNX native + Parameters and legacy Linen variables wrapped via the NNX-Linen bridge. + It ensures that both systems correctly respect the provided mesh and + logical axis rules during the abstraction/sharding planning phase. + """ + _create_model_partial = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) with nn.logical_axis_rules(config.logical_axis_rules): abstract_model = nnx.eval_shape(_create_model_partial) + graphdef, abstract_state = nnx.split(abstract_model) + specs = nnx.get_partition_spec(abstract_state) + + if mesh is None: + mesh = abstract_model.mesh + + # JIT a function that creates the model state with proper sharding from the start. + # By providing out_shardings, we instruct JAX to produce sharded output directly, + # avoiding a large intermediate allocation on a single device. + with nn.logical_axis_rules(config.logical_axis_rules): + out_shardings = nn.logical_to_mesh_sharding(specs, mesh) + + @partial(jax.jit, out_shardings=out_shardings) + def create_sharded_state(): + # This will be JIT-compiled. JAX knows the output sharding and can + # initialize the parameters directly on the target devices in a sharded way. + model = _create_model_partial() + return nnx.state(model) - return _create_model_partial, abstract_model + with mesh: + # Create the model with sharded parameters. + with nn.logical_axis_rules(config.logical_axis_rules): + sharded_state = create_sharded_state() + model = nnx.merge(graphdef, sharded_state) + + # print weights sharding info under debug sharding mode + if config.debug_sharding: + max_utils.print_non_trivial_mesh_axis(model.mesh) + maxtext_utils.print_shardings_params( + params=sharded_state, + params_sharding=out_shardings, + mesh=model.mesh, + logical_annotations=specs, + ) + return model def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None = None, **extra_kwargs): @@ -728,60 +807,30 @@ def from_pretrained( ) config = pyconfig.HyperParameters(new_config) - def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): - rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) - return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - - _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key) - + if config.pure_nnx: + _create_model, abstract_model = create_nnx_abstract_model(config, mesh, devices, model_mode, rng_key) + model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model, mesh=mesh) + # TODO: print debug_sharding info + else: + model = create_nnx_sharded_model_hybrid(config, mesh, devices, model_mode, rng_key) + + # Compute logical-axis specs for downstream checkpoint alignment. + # The model-creation helpers above resolve specs internally for sharding, but + # the checkpoint-loading branch below needs the logical PartitionSpec tree + # (axis names like "kv_heads", "mlp_moe") for repeat/zero-pad dispatch in + # _align_checkpoint_to_model_shapes. nnx.eval_shape is cheap (abstract trace). + _create_model_for_specs = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) with nn.logical_axis_rules(config.logical_axis_rules): - abstract_model = nnx.eval_shape(_create_model_partial) - graphdef, abstract_state = nnx.split(abstract_model) - specs = nnx.get_partition_spec(abstract_state) - - if mesh is None: - mesh = abstract_model.mesh - - # Note for pure_nnx: - # Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and - # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen - # LogicallyPartitioned structure. - # In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned - # structure in the abstract state and we can get the sharded state with the following code: - # graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh) - # abstract_model = nnx.merge(graphdef, state) - # model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh) - # sharded_state = nnx.state(model) + _abs_model_for_specs = nnx.eval_shape(_create_model_for_specs) + _, _abs_state_for_specs = nnx.split(_abs_model_for_specs) + specs = nnx.get_partition_spec(_abs_state_for_specs) - # JIT a function that creates the model state with proper sharding from the start. - # By providing out_shardings, we instruct JAX to produce sharded output directly, - # avoiding a large intermediate allocation on a single device. - with nn.logical_axis_rules(config.logical_axis_rules): - out_shardings = nn.logical_to_mesh_sharding(specs, mesh) + sharded_state = nnx.state(model) - @partial(jax.jit, out_shardings=out_shardings) - def create_sharded_state(): - # This will be JIT-compiled. JAX knows the output sharding and can - # initialize the parameters directly on the target devices in a sharded way. - model = _create_model_partial() - return nnx.state(model) + if mesh is None: + mesh = model.mesh with mesh: - # Create the model with sharded parameters. - with nn.logical_axis_rules(config.logical_axis_rules): - sharded_state = create_sharded_state() - model = nnx.merge(graphdef, sharded_state) - - # print weights sharding info under debug sharding mode - if config.debug_sharding: - max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - params=sharded_state, - params_sharding=out_shardings, - mesh=model.mesh, - logical_annotations=specs, - ) - if config.load_parameters_path: try: ckptr = ocp.Checkpointer( diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 3ba60d7371..3bd2b186b1 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -24,25 +24,23 @@ python3 -m maxtext.utils.muon_utils qwen3-4b True """ - import os import sys from typing import Optional, Tuple import flax.linen as nn +from flax import nnx import jax from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_PKG_DIR from maxtext.layers import quantizations from maxtext.models import models -from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils, model_creation_utils from optax.contrib._muon import MuonDimensionNumbers as mdn -Transformer = models.transformer_as_linen - - def _is_path_contain_any(tuples, path): + """Checks if any element in 'tuples' is present in 'path'.""" return any(x in path for x in tuples) @@ -107,10 +105,25 @@ def get_transform_tree(tree, path=()): def get_muon_weight_dimension_numbers(model, config, verbose=False): """Extract muon dimension number from model structure.""" - # quickly get param structure without materialization - abstract_param = maxtext_utils.get_abstract_param(model, config) - # get muon dimension number from param - muon_weight_dimension_numbers = get_transform_tree(abstract_param) + + if isinstance(model, nnx.Module): + _, abstract_param, _ = nnx.split(model, nnx.Param, ...) + + def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): + # Convert jax.tree_util.KeyEntry path to Tuple[str, ...] + path_strings = tuple(p.key for p in path if isinstance(p, jax.tree_util.DictKey)) + return transform_logic(path_strings) + + # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. + # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) + + else: # Linen + # quickly get param structure without materialization + abstract_param = maxtext_utils.get_abstract_param(model, config) + # get muon dimension number from param + muon_weight_dimension_numbers = get_transform_tree(abstract_param) + if verbose: _print_structure_debug(abstract_param, muon_weight_dimension_numbers) return muon_weight_dimension_numbers @@ -118,19 +131,30 @@ def get_muon_weight_dimension_numbers(model, config, verbose=False): def _print_structure_debug(abstract_param, muon_weight_dimension_numbers): """Prints the model structure and the resulting Muon config.""" - # Access the shape from the inner ShapeDtypeStruct and names from the wrapper - # Return a new tree with the same structure containing only shapes/names + + def get_leaf_info(leaf): + # For linen: + # Access the shape from the inner ShapeDtypeStruct and names from the wrapper + # Return a new tree with the same structure containing only shapes/names + if isinstance(leaf, nn.LogicallyPartitioned): + return {"shape": leaf.value.shape, "names": leaf.names} + # For nnx: + # Only return the shape because it doesn't have a wrapper. + elif isinstance(leaf, jax.ShapeDtypeStruct): + return {"shape": leaf.shape} + return {"shape": "N/A"} + info_tree = jax.tree_util.tree_map( - lambda leaf: {"shape": leaf.value.shape, "names": leaf.names}, + get_leaf_info, abstract_param, - is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned), + is_leaf=lambda x: isinstance(x, (nn.LogicallyPartitioned, jax.ShapeDtypeStruct)), ) print(f"\n=== Model Structure ===\n{info_tree}") print(f"\n=== Muon Dimension Numbers ===\n{muon_weight_dimension_numbers}") print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -154,13 +178,17 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): f"model_name={model_name}", f"scan_layers={scan_layers}", "attention=dot_product", + f"pure_nnx={pure_nnx}", ] config = pyconfig.initialize(argv) # Setup model devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh=mesh, quant=quant) + if pure_nnx: + _, model = model_creation_utils.create_nnx_abstract_model(config, mesh) + else: + model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) return muon_weight_dimension_numbers @@ -172,4 +200,4 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): sys.exit(1) model_name_arg = sys.argv[1] scan_layers_arg = sys.argv[2].lower() == "true" - get_model_mdn(model_name_arg, scan_layers_arg, verbose=True) + get_model_mdn(model_name_arg, scan_layers_arg, verbose=True, pure_nnx=False) diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index d4bb64f016..4a500e2fe1 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -15,7 +15,7 @@ # pylint: disable=line-too-long, disable=bare-except, consider-using-generator """ Utils that are only interesting to MaxText and sharding related. """ -from flax import linen as nn +from flax import linen as nn, nnx from collections.abc import Iterable @@ -25,6 +25,7 @@ import optax +from maxtext.configs import pyconfig from maxtext.common.common_types import ShardMode from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -483,6 +484,8 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): - updated_state_mesh_shardings: State mesh shardings with updated params field (unchanged if shard_optimizer_over_data is False) """ + if config.pure_nnx: + return maybe_update_params_sharding_with_opt_nnx(config, state_mesh_shardings) prev_params_shardings = state_mesh_shardings.params if config.shard_optimizer_over_data: if isinstance(state_mesh_shardings.opt_state, optax.ScaleByAdamState): @@ -501,6 +504,122 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): return prev_params_shardings, state_mesh_shardings +def maybe_update_params_sharding_with_opt_nnx( + config: pyconfig.HyperParameters, state_mesh_shardings: nnx.State +) -> tuple[nnx.State, nnx.State]: + """ + NNX version of parameter sharding update. Updates parameter sharding configuration + when optimizer state sharding is enabled. + + When shard_optimizer_over_data is enabled (Zero-1 style sharding), this function + extracts the optimizer state shardings from the Adam optimizer's first moment (mu) + and merges them with the parameter shardings. This ensures parameter sharding is + consistent with how the optimizer state is distributed across the compute mesh. + + Args: + config: Configuration with shard_optimizer_over_data flag. + state_mesh_shardings: The sharding state for a TrainStateNNX container. + + Returns: + A tuple of (prev_params_shardings, updated_state_mesh_shardings): + - prev_params_shardings: Original parameter shardings before the update + - updated_state_mesh_shardings: State mesh shardings with updated params field + (unchanged if shard_optimizer_over_data is False)""" + # In TrainStateNNX, parameters are under 'model' + model_shardings = state_mesh_shardings.model + + def _extract_param_only(state): + """Recursively extract nnx.Param variables from an nnx.State into a nested plain dict. + + Constructs nnx.State({'key': nested_dict, ...}) which produces the same pytree + structure as nnx.split(model, nnx.Param, ...)[1], enabling jax.tree.map + to work correctly between ga_params (Param-only) and params_shardings. + """ + result = {} + for k, v in state.items(): + if isinstance(v, nnx.Param): + result[k] = v + elif isinstance(v, nnx.Variable): + pass # skip non-Param variables (RngKey, RngCount, OptVariable, etc.) + elif hasattr(v, "items"): + sub = _extract_param_only(v) + if sub: + result[k] = sub + return result + + # prev_params_shardings must match the pytree structure of ga_params from + # nnx.split(model, nnx.Param, ...) — Param variables only, no rngs. + prev_params_shardings = nnx.State(_extract_param_only(model_shardings)) + + if not config.shard_optimizer_over_data: + return prev_params_shardings, state_mesh_shardings + + sharded_fp32_params = None + # Check if the optimizer has any state at all (stateless optimizers like SGD omit this key) + if "opt_state" in state_mesh_shardings.optimizer: + # Access the optimizer branch to find the optax state + # state_mesh_shardings.optimizer contains the sharding for the nnx.Optimizer + opt_state = state_mesh_shardings.optimizer.opt_state + + def find_adam_mu(obj): + # 1. Direct hit on ScaleByAdamState (Linen path or unflattened NNX) + if isinstance(obj, optax.ScaleByAdamState): + return obj.mu + + # 2. Check for flattened ScaleByAdamState (nnx.State/dict) + # These nodes contain 'mu', 'nu', and 'count' as keys. + if hasattr(obj, "__getitem__") and "mu" in obj and "nu" in obj: + return obj["mu"] + + # 3. Recursive search through containers (nnx.State, dict, list, tuple) + values = None + if hasattr(obj, "values"): # Handles nnx.State and dict + values = obj.values() + elif isinstance(obj, (list, tuple)): + values = obj + + if values: + for v in values: + res = find_adam_mu(v) + if res is not None: + return res + return None + + sharded_fp32_params = find_adam_mu(opt_state) + if sharded_fp32_params is None: + actual_type = type(state_mesh_shardings.optimizer.get("opt_state", "None")) + raise NotImplementedError(f"Could not find Adam optimizer state in: {actual_type}") + + # Update model parameter sharding to match the mu (first moment) sharding. + # This ensures parameter sharding is consistent with the Zero-1 distributed layout. + # Build a path → new_PS lookup from sharded_fp32_params (mu), then update model_shardings + # at those paths while preserving rngs and any other non-Param variables. + mu_leaves_with_paths = list( + jax.tree_util.tree_leaves_with_path(sharded_fp32_params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + ) + mu_lookup = {path: mu_var.get_value() for path, mu_var in mu_leaves_with_paths} + + def _update_model_var(path, var): + if path in mu_lookup: + return var.replace(mu_lookup[path]) + return var + + new_model_shardings = jax.tree_util.tree_map_with_path( + _update_model_var, model_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable) + ) + # Use jax.tree_util.tree_map (identity) to create a new nnx.State via JAX's unflatten + # mechanism (not the nnx.State constructor). This is critical because: + # 1. nnx.State({...}) constructor recursively converts nested plain dicts to nnx.State, + # causing a pytree type mismatch with the actual state from nnx.split (which stores + # nested module states as plain dicts). JAX's unflatten preserves the original types. + # 2. copy.deepcopy fails because NamedSharding contains non-picklable jaxlib.Device objects. + # Direct __setattr__ assignment stores new_model_shardings as-is (no type conversion). + updated_state = jax.tree_util.tree_map(lambda x: x, state_mesh_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable)) + updated_state.model = new_model_shardings + + return prev_params_shardings, updated_state + + def logical_axis_rules_pp_act_as_dp(logical_rules): """Add stage as a physical axes before data for each rule, so stage acts just like data instead of PP. This is used when we want to pipeline only a subset of layers, and leave the rest like DP. diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 906a597728..ca90550630 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -15,12 +15,14 @@ # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting for training in MaxText.""" +import functools import os from functools import partial import jax -import functools +from flax import nnx from flax.linen import partitioning as nn_partitioning +from maxtext.layers import train_state_nnx from maxtext.common import checkpointing from maxtext.common.data_loader import create_dataloader from maxtext.common.goodput import GoodputEvent, maybe_record_goodput @@ -205,7 +207,7 @@ def setup_train_loop(config, recorder, devices=None): data_iterator: data_loader: rampup_manager: the class managing rampup batch sizes - state: the initialized train state + train_state: the initialized train state. For NNX, this is a TrainStateNNX instance """ # pylint: disable=import-outside-toplevel from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator @@ -213,16 +215,22 @@ def setup_train_loop(config, recorder, devices=None): with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): is_training = True init_rng = jax.random.PRNGKey(config.init_weights_seed) + mesh = maxtext_utils.get_mesh_from_config(config, devices) if config.pure_nnx: # Create abstract NNX model. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices) else: model = model_creation_utils.from_config(config, devices) - mesh = model.mesh learning_rate_schedule, tx = create_training_optimizer(config, model) + if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + # For NNX, the train state is wrapped in the TrainStateNNX module. + def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + + init_state_fn = create_train_state_fn else: init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng) checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn) @@ -266,6 +274,15 @@ def setup_train_loop(config, recorder, devices=None): state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) + if config.pure_nnx: + with nn_partitioning.axis_rules(config.logical_axis_rules): + # train_state is instance of TrainStateNNX + state_graphdef, _ = nnx.get_abstract_model(init_state_fn, mesh) + _, state_params, _ = nnx.split(state.model, nnx.Param, ...) + _, state_mesh_shardings_params, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) + else: + state_params = state.params + state_mesh_shardings_params = state_mesh_shardings.params if config.enable_diloco: with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): @@ -283,17 +300,24 @@ def setup_train_loop(config, recorder, devices=None): # TODO(aireenmei, hengtaoguo): support sharding in vit for multimodal if not config.using_pipeline_parallelism and not config.use_multimodal: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + sharding.assert_params_sufficiently_sharded(state_params, mesh, config.sharding_tolerance) # print weights sharding info under debug sharding mode if config.debug_sharding: - logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + if config.pure_nnx: + # TODO: Study how to get logical annotations of NNX module. Because of eager sharding, we + # probably already lost the logical partition info at this moment. + logical_annotations_params = None + else: + logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + logical_annotations_params = logical_annotations.params + max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params - ) + maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: + if config.pure_nnx: + raise NotImplementedError("DPO is not supported yet by NNX models.") abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" @@ -318,12 +342,18 @@ def setup_train_loop(config, recorder, devices=None): except FileNotFoundError: step0_restored = None if step0_restored is not None: + # TODO: For pure_nnx, the dpo state manipulation is different. reference_params = step0_restored["items"].params["params"] state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) + if config.pure_nnx: + train_state = nnx.merge(state_graphdef, state) + model = train_state.model + else: + train_state = state return ( init_rng, @@ -336,7 +366,7 @@ def setup_train_loop(config, recorder, devices=None): data_loader, rampup_manager, eval_data_iterator, - state, + train_state, ) diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py new file mode 100644 index 0000000000..d11f9658a7 --- /dev/null +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -0,0 +1,140 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for setup_train_loop with pure_nnx=True. + +setup_train_loop wires together create_nnx_abstract_model, the training optimizer, +the checkpoint manager, the data iterator, and finally nnx.split / nnx.merge to +return a fully-formed TrainStateNNX. This test exercises that wiring end-to-end +on a tiny synthetic config — the goal is to cover the integration glue that the +unit tests in tests/unit/train_utils_nnx_test.py cannot reach. +""" + +import os +import sys +import unittest + +import pytest + +import jax +from flax import nnx + +from maxtext.configs import pyconfig +from maxtext.layers import train_state_nnx +from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT +from maxtext.utils.train_utils import setup_train_loop +from tests.utils.test_helpers import get_test_config_path + + +def _tiny_nnx_pyconfig(**overrides): + """Build a tiny pyconfig suitable for a single-host setup_train_loop run.""" + init_kwargs = { + "run_name": "setup_train_loop_nnx_test", + "enable_checkpointing": False, + "dataset_type": "synthetic", + "model_name": "default", + "pure_nnx": True, + "per_device_batch_size": 1.0, + "base_emb_dim": 8, + "base_num_query_heads": 4, + "base_num_kv_heads": 4, + "base_mlp_dim": 32, + "base_num_decoder_layers": 2, + "head_dim": 128, + "max_target_length": 128, + "vocab_size": 256, + "steps": 1, + "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"), + "enable_goodput_recording": False, + "enable_checkpoint_cloud_logger": False, + "monitor_goodput": False, + } + init_kwargs.update(overrides) + return pyconfig.initialize([sys.argv[0], get_test_config_path()], **init_kwargs) + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +class SetupTrainLoopNNXIntegrationTest(unittest.TestCase): + """End-to-end check that setup_train_loop returns a usable TrainStateNNX.""" + + def test_pure_nnx_setup_returns_train_state_nnx(self): + config = _tiny_nnx_pyconfig() + + ( + init_rng, + checkpoint_manager, + state_mesh_shardings, + model, + mesh, + learning_rate_schedule, + data_iterator, + data_loader, + rampup_manager, + eval_data_iterator, + train_state, + ) = setup_train_loop(config, recorder=None) + + # The NNX path returns a fully-merged TrainStateNNX (lines 352-354 in train_utils.py). + self.assertIsInstance(train_state, train_state_nnx.TrainStateNNX) + # Optimizer.step starts at 0 for a fresh init. + self.assertEqual(int(train_state.optimizer.step.get_value()), 0) + # The returned model is train_state.model, an NNX module. + self.assertIsInstance(model, nnx.Module) + self.assertIs(model, train_state.model) + + # Sanity for sibling outputs: + self.assertIsNotNone(init_rng) + self.assertIsNotNone(mesh) + self.assertTrue(callable(learning_rate_schedule)) + # data_loader is mandatory; data_iterator may be wrapped/unwrapped. + self.assertIsNotNone(data_loader) + self.assertIsNotNone(data_iterator) + + # state_mesh_shardings (NNX) is an nnx.State and contains a 'model' branch. + self.assertIsInstance(state_mesh_shardings, nnx.State) + self.assertIn("model", state_mesh_shardings) + + # Cleanup: the rest are not asserted on but referenced so linters don't + # flag them as unused — they're part of the public return contract. + del checkpoint_manager, rampup_manager, eval_data_iterator + + def test_pure_nnx_setup_param_only_split_matches_model(self): + """nnx.split(state.model, nnx.Param, ...) must yield a non-empty Param tree + whose structure matches state_mesh_shardings.model after the same split.""" + config = _tiny_nnx_pyconfig() + *_, state_mesh_shardings, model, _, _, _, _, _, _, train_state = setup_train_loop(config, recorder=None) + + _, params, _ = nnx.split(train_state.model, nnx.Param, ...) + _, params_shardings, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) + + # Same key-set after nnx.split — this is what setup_train_loop relies on at + # train_utils.py:281-282 to pair state_params with state_mesh_shardings_params. + self.assertEqual(jax.tree_util.tree_structure(params), jax.tree_util.tree_structure(params_shardings)) + self.assertGreater(len(jax.tree.leaves(params)), 0) + + del model + + def test_pure_nnx_dpo_raises_not_implemented(self): + """The use_dpo branch (train_utils.py:319-320) must raise for NNX.""" + # use_dpo requires a few prerequisites; the simplest is to set the flag and + # let setup_train_loop reach the NotImplementedError check before the more + # involved DPO path runs. + config = _tiny_nnx_pyconfig(use_dpo=True, packing=False) + with self.assertRaises(NotImplementedError): + setup_train_loop(config, recorder=None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/checkpointing_nnx_load_test.py b/tests/unit/checkpointing_nnx_load_test.py new file mode 100644 index 0000000000..622f19323a --- /dev/null +++ b/tests/unit/checkpointing_nnx_load_test.py @@ -0,0 +1,106 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX branches of load_state_if_possible.""" + +import unittest +from unittest import mock + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.common import checkpointing +from maxtext.layers import train_state_nnx + + +class _Model(nnx.Module): + """Tiny single-linear NNX model for restore tests.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + +def _abstract_nnx_state(): + """Build an nnx.State from a TrainStateNNX — same shape that pre_train passes in.""" + model = _Model(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + return nnx.state(train_state_nnx.TrainStateNNX(model, optimizer)) + + +class TestLoadStateIfPossibleNNX(unittest.TestCase): + """Cover the NNX branches in load_state_if_possible.""" + + def test_load_parameters_from_path_splits_nnx_state_for_param_view(self): + """When abstract_unboxed_pre_state is an nnx.State, the function must call + nnx.split(model, nnx.Param, ...) to get the params and forward them to load_params_from_path.""" + abstract = _abstract_nnx_state() + sentinel_restored = {"linear": {"kernel": jnp.ones((2, 1)), "bias": jnp.zeros((1,))}} + + with mock.patch.object(checkpointing, "load_params_from_path", return_value=sentinel_restored) as m: + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="gs://does-not-exist/params", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=abstract, + ) + + self.assertIsNone(full) + self.assertIs(params, sentinel_restored) + m.assert_called_once() + forwarded_params = m.call_args[0][1] # second positional arg = abstract_unboxed_params + # The forwarded params come from nnx.split(..., nnx.Param, ...) — same key shape as the model. + leaves = jax.tree.leaves(forwarded_params) + self.assertEqual(len(leaves), 2) # linear.kernel + linear.bias + + def test_load_parameters_from_path_uses_state_params_for_linen(self): + """For Linen TrainState, the function must use state.params (not nnx.split).""" + fake_state = mock.Mock(spec=["params"]) + fake_state.params = {"layer": {"kernel": jnp.ones((2, 2))}} + sentinel = object() + + with mock.patch.object(checkpointing, "load_params_from_path", return_value=sentinel) as m: + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="gs://does-not-exist/params", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=fake_state, + ) + + self.assertIsNone(full) + self.assertIs(params, sentinel) + forwarded_params = m.call_args[0][1] + self.assertIs(forwarded_params, fake_state.params) + + def test_no_paths_returns_none_none(self): + """Sanity: with no checkpoint manager and no load paths, the function returns (None, None).""" + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=_abstract_nnx_state(), + ) + self.assertIsNone(full) + self.assertIsNone(params) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/gradient_accumulation_nnx_test.py b/tests/unit/gradient_accumulation_nnx_test.py new file mode 100644 index 0000000000..6353f02397 --- /dev/null +++ b/tests/unit/gradient_accumulation_nnx_test.py @@ -0,0 +1,159 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX branch of gradient_accumulation_loss_and_grad.""" + +import unittest +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from maxtext.common.common_types import ShardMode +from maxtext.utils import gradient_accumulation + + +@dataclass +class _Cfg: + gradient_accumulation_steps: int = 2 + shard_optimizer_over_data: bool = False + shard_mode: int = ShardMode.AUTO + ici_data_parallelism: int = 1 + debug_sharding: bool = False + + +class _TinyNNX(nnx.Module): + """Single linear layer NNX model.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +def _fake_loss_fn(model, config, data, dropout_rng, params, is_train=True): + """A loss_fn shaped like the production loss_fn but for a tiny linear model. + + Returns (loss, aux) where aux follows the schema gradient_accumulation_loss_and_grad + reads from: xent_sum / total_weights / moe_lb_loss / indexer_loss / mtp_loss. + """ + del config, dropout_rng, params, is_train + pred = model(data["inputs"]) + per_sample_loss = jnp.mean((pred - data["targets"]) ** 2, axis=-1) + xent_sum = jnp.sum(per_sample_loss) + total_weights = jnp.array(per_sample_loss.shape[0], dtype=jnp.float32) + aux = { + "xent_sum": xent_sum, + "total_weights": total_weights, + "moe_lb_loss": jnp.array(0.0), + "indexer_loss": jnp.array(0.0), + "mtp_loss": jnp.array(0.0), + } + return xent_sum / total_weights, aux + + +class TestGradientAccumulationNNX(unittest.TestCase): + """Cover the NNX path of gradient_accumulation_loss_and_grad.""" + + def setUp(self): + self.model = _TinyNNX(rngs=nnx.Rngs(0)) + self.cfg = _Cfg(gradient_accumulation_steps=2) + # 4 examples → 2 microbatches of 2 each + self.data = { + "inputs": jnp.arange(8.0).reshape(4, 2), + "targets": jnp.zeros((4, 1)), + } + + def _params_shardings(self): + """Build a per-leaf NamedSharding tree shaped like nnx.split(model, nnx.Param, ...)[1]. + + Uses a trivial single-device mesh so jax.lax.with_sharding_constraint accepts the + sharding without contradicting the actual device topology. + """ + _, params, _ = nnx.split(self.model, nnx.Param, ...) + mesh = Mesh( + np.array(jax.local_devices()[:1]).reshape( + 1, + ), + ("x",), + ) + ns = NamedSharding(mesh, PartitionSpec()) + return jax.tree.map(lambda _: ns, params) + + def test_nnx_path_runs_and_returns_grad_for_every_param(self): + """The NNX branch must call nnx.value_and_grad and return one gradient per Param.""" + loss, aux, raw_grads = gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, # NNX branch ignores params + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + self.assertTrue(jnp.isfinite(loss)) + self.assertIn("xent_sum", aux) + self.assertIn("total_weights", aux) + grad_leaves = jax.tree.leaves(raw_grads) + self.assertEqual(len(grad_leaves), 2) # linear.kernel + linear.bias + for g in grad_leaves: + self.assertTrue(jnp.all(jnp.isfinite(g))) + + def test_nnx_path_updates_model_rest_state_after_scan(self): + """After accumulation, nnx.update is called on the model with the rest_state from the scan. + + For a TinyNNX (no rngs/dropout), the rest tree is empty but the call path must still + succeed end-to-end without raising — covering the `if is_nnx: nnx.update(...)` branch. + """ + pre_kernel = self.model.linear.kernel.value.copy() + gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + # The kernel itself is a Param — gradient_accumulation_loss_and_grad does not apply + # gradients to params, so the value should be untouched. + self.assertTrue(jnp.allclose(self.model.linear.kernel.value, pre_kernel)) + + def test_nnx_with_shard_optimizer_over_data_casts_to_bf16(self): + """Zero-1 path must convert fp32 params to bf16 before the scan loop.""" + self.cfg.shard_optimizer_over_data = True + # Should not raise; just verify the function runs and returns sensible outputs. + loss, _, raw_grads = gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + self.assertTrue(jnp.isfinite(loss)) + for g in jax.tree.leaves(raw_grads): + self.assertTrue(jnp.all(jnp.isfinite(g))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 1d2f52b2fd..9d7e9749e9 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -15,11 +15,13 @@ """Tests for the common MaxText utilities""" import functools -from typing import Any, Sequence from collections.abc import Callable +from typing import Any, Sequence import unittest from unittest.mock import MagicMock, Mock, patch from dataclasses import dataclass, field +import numpy as np +import optax from flax import linen as nn from flax import nnx @@ -29,6 +31,7 @@ from jax import random, vmap import jax.numpy as jnp from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils from maxtext.configs import pyconfig from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN, ShardMode from maxtext.inference import inference_utils @@ -39,8 +42,7 @@ from maxtext.utils import sharding from maxtext.utils.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides -import numpy as np -import optax +from maxtext.utils import maxtext_utils_nnx Transformer = models.transformer_as_linen @@ -179,11 +181,7 @@ def setUp(self): "decoder": {"gate": {"bias": jnp.array([0.5, 0.5])}}, } self.state = train_state.TrainState( - step=0, - apply_fn=self.model.apply, - params=self.initial_params, - tx=None, - opt_state={}, + step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={} ) def test_update_mode_add(self): @@ -196,10 +194,10 @@ def test_update_mode_add(self): self.assertTrue(jnp.allclose(actual, expected)) # Other values are untouched - original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] + original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] # pylint: disable=unsubscriptable-object new_layer_0 = new_state.params["layers"]["layer_0"]["bias"] self.assertTrue(jnp.array_equal(original_layer_0, new_layer_0)) - original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] + original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] # pylint: disable=unsubscriptable-object new_layer_1 = new_state.params["layers"]["layer_1"]["bias"] self.assertTrue(jnp.array_equal(original_layer_1, new_layer_1)) @@ -264,7 +262,7 @@ def test_init_training_state(self): @nnx.register_variable_name("special_variables") -class SpecialVariables(nnx.Variable): +class SpecialVariables(nnx.Variable): # pylint: disable=abstract-method pass @@ -281,7 +279,7 @@ def __call__(self, x, y, encoder_images=None, nnx_method=None, model_mode=None): return x -class TrainState(train_state.TrainState): +class TrainState(train_state.TrainState): # pylint: disable=abstract-method other_variables: nnx.State @@ -993,49 +991,63 @@ def train_step(_model, _config, _state_shardings, _params_shardings, state, _bat return train_step + def _make_mock_config(self, pure_nnx=False): + cfg = MagicMock() + cfg.pure_nnx = pure_nnx + return cfg + def test_returns_five_tuple(self): step = self._make_mock_step() result = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(len(result), 5) def test_functional_train_has_correct_name(self): step = self._make_mock_step() fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(fn.__name__, "train_step") - def test_in_shardings_structure(self): + def test_linen_in_shardings_includes_rng(self): + """pure_nnx=False: in_shardings should be (state, batch, rng).""" step = self._make_mock_step() _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=False) ) - # (state, batch, rng) self.assertEqual(len(in_shardings), 3) self.assertIsNone(in_shardings[2]) # rng sharding is None + def test_nnx_in_shardings_excludes_rng(self): + """pure_nnx=True: in_shardings should be (state, batch) — no rng slot.""" + step = self._make_mock_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( + step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=True) + ) + self.assertEqual(len(in_shardings), 2) + def test_donate_argnums_is_zero(self): step = self._make_mock_step() _, _, _, _, donate_argnums = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(donate_argnums, 0) def test_functional_train_is_partial(self): """functional_train should partially apply model and config.""" received = {} + cfg = self._make_mock_config() def train_step(model, config, _state_shardings, _params_shardings, state, _batch, _rng=None): received["model"] = model received["config"] = config return state, {} - fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", "my_config") + fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", cfg) fn("state", "batch") self.assertEqual(received["model"], "my_model") - self.assertEqual(received["config"], "my_config") + self.assertEqual(received["config"], cfg) class TestGetFunctionalEvalWithSignature(unittest.TestCase): @@ -1047,26 +1059,51 @@ def eval_step(_model, _config, _state, _batch, _rng=None): return eval_step + def _make_mock_config(self, pure_nnx=False): + cfg = MagicMock() + cfg.pure_nnx = pure_nnx + return cfg + def test_returns_five_tuple(self): step = self._make_mock_eval_step() - result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config()) self.assertEqual(len(result), 5) def test_functional_eval_has_correct_name(self): step = self._make_mock_eval_step() - fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config()) self.assertEqual(fn.__name__, "eval_step") def test_out_shardings_is_none(self): step = self._make_mock_eval_step() - _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "ds", "ss", "model", self._make_mock_config() + ) self.assertIsNone(out_shardings) def test_donate_argnums_is_empty(self): step = self._make_mock_eval_step() - _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature( + step, "ds", "ss", "model", self._make_mock_config() + ) self.assertEqual(donate_argnums, ()) + def test_nnx_in_shardings_excludes_rng(self): + """pure_nnx=True: in_shardings should be (state, batch) — no rng slot.""" + step = self._make_mock_eval_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=True) + ) + self.assertEqual(len(in_shardings), 2) + + def test_linen_in_shardings_includes_rng(self): + """pure_nnx=False: in_shardings should be (state, batch, rng).""" + step = self._make_mock_eval_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=False) + ) + self.assertEqual(len(in_shardings), 3) + class TestGetShapedBatch(unittest.TestCase): """Tests for get_shaped_batch.""" @@ -1414,5 +1451,183 @@ def test_runs_without_logical_annotations(self): maxtext_utils.print_shardings_params(params, param_sharding, mesh=self.mesh, logical_annotations=None) +class TestNNXAbstractState(unittest.TestCase): + """Test the get_abstract_state_nnx func.""" + + @dataclass + class MockConfig: + init_weights_seed: int = 42 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + param_scan_axis: int = 0 + logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]]]) + + class MockTrainState(nnx.Module): + """Simulates a TrainState with params and optimizer state.""" + + def __init__(self, rngs: nnx.Rngs): + # Model parameters + device_num = len(jax.local_devices()) + self.params = nnx.Linear( + 2, 4, kernel_init=nnx.with_partitioning(nnx.initializers.ones, sharding=("model",)), rngs=rngs + ) + # Simulated optimizer state + self.optimizer = nnx.Variable(jnp.zeros((device_num,)), sharding=("model",)) + + def setUp(self): + # Create a real 1D mesh on local devices + devices = jax.local_devices() + self.mesh = Mesh(mesh_utils.create_device_mesh((len(devices), 1)), axis_names=("model", "data")) + self.config = self.MockConfig() + + def nnx_init_trainstate_wrapper(self): + """Wrapper to initialize the mock NNX model.""" + rngs = maxtext_utils_nnx.create_nnx_rngs(self.config) + return self.MockTrainState(rngs) + + def test_basic_abstraction(self): + """Verifies the basic return structure and partition spec extraction.""" + abstract_state, annotations, shardings = maxtext_utils.get_abstract_state_nnx( + self.config, self.mesh, self.nnx_init_trainstate_wrapper + ) + + # Check return types + self.assertIsInstance(abstract_state, nnx.State) + self.assertIsInstance(annotations, nnx.State) + self.assertIsInstance(shardings, nnx.State) + + # Verify PartitionSpec was extracted correctly from the mock model's annotations + # Path: params -> kernel -> spec + self.assertEqual( + annotations.params.kernel.get_value(), + PartitionSpec( + "model", + ), + ) + + def test_shard_optimizer_over_data(self): + """Verifies that 'data' is added to optimizer sharding using the real utility.""" + self.config.shard_optimizer_over_data = True + + _, annotations, _ = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Original Pspec for optimizer was PartitionSpec(None). + # add_data_to_sharding should find that dim 0 is compatible with mesh 'data' + # and update it to PartitionSpec(('data',)). + opt_spec = annotations.optimizer.get_value() + + # Verify 'data' is now in the spec + self.assertEqual(opt_spec, PartitionSpec(("data", "model"))) + + def test_optimizer_host_offload(self): + """Verifies that optimizer memory is moved to host when configured.""" + self.config.optimizer_memory_host_offload = True + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Optimizer state should be pinned to host + opt_sharding = shardings.optimizer.get_value() + self.assertEqual(opt_sharding.memory_kind, "pinned_host") + + # Params should still be on default memory (usually device) + param_sharding = shardings.params.kernel.get_value() + self.assertNotEqual(param_sharding.memory_kind, "pinned_host") + + def test_parameter_host_offload(self): + """Verifies that parameter memory is moved to host when configured.""" + self.config.parameter_memory_host_offload = True + self.config.param_scan_axis = 0 + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Parameters should be pinned to host + param_sharding = shardings.params.kernel.get_value() + self.assertEqual(param_sharding.memory_kind, "pinned_host") + + def test_invalid_init_fn(self): + """Ensures function raises error if no init function is provided.""" + with self.assertRaises(AssertionError): + maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, None) + + +class TestGetNnxNamedShardingWithScanAxis(unittest.TestCase): + """Unit tests for get_nnx_named_sharding_with_scan_axis covering every branch. + + The helper resolves a NamedSharding for each NNX Variable and — unlike + flax.nnx.spmd.get_var_pspec — also inserts the `nnx.PARTITION_NAME` axis at + `param_scan_axis` when scanned-layers metadata is present. + """ + + def setUp(self): + # Mesh needs to contain every axis name the tests reference in partition specs. + self.mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("fsdp", "layers")) + + def _build_state(self, **variables): + """Wrap a dict of {key: nnx.Variable} in an nnx.State for tree traversal.""" + return nnx.State(variables) + + def _run(self, state): + return maxtext_utils.get_nnx_named_sharding_with_scan_axis(state, self.mesh) + + def test_scan_axis_inserted_at_param_scan_axis(self): + """When PARTITION_NAME is present, the partition name is inserted at `param_scan_axis`.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((3, 4, 8)), + out_sharding=(None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # 'layers' must be inserted at position 1 (param_scan_axis=1). + self.assertEqual(result_sharding.spec, PartitionSpec(None, "layers", "fsdp")) + + def test_scan_axis_not_inserted_when_already_present(self): + """Guard against double-insertion when partition_name is already in out_sharding.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((2, 2, 2)), + out_sharding=("layers", None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'layers' must appear exactly once — the same PartitionSpec we started with. + self.assertEqual(result_sharding.spec, PartitionSpec("layers", None, "fsdp")) + + def test_masked_node_preserved_as_is(self): + """Values without a .shape attribute (e.g., optax.MaskedNode) are returned unchanged.""" + masked = nnx.Variable(optax.MaskedNode()) + state = self._build_state(masked=masked) + out = self._run(state) + # The leaf must be the original Variable, not a NamedSharding wrapper. + self.assertIs(out["masked"], masked) + + def test_empty_out_sharding_yields_empty_pspec(self): + """A Variable without any sharding metadata should resolve to PartitionSpec().""" + with jax.set_mesh(self.mesh): + # No out_sharding/sharding_names/sharding metadata → falsy → PartitionSpec() + v = nnx.Param(jnp.zeros((4,))) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + self.assertEqual(result_sharding.spec, PartitionSpec()) + + def test_string_out_sharding_is_wrapped_into_tuple(self): + """A single-string out_sharding value should still produce a valid PartitionSpec.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((4,)), + out_sharding="fsdp", + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # The single string 'fsdp' is turned into a list, and 'layers' is prepended. + self.assertEqual(result_sharding.spec, PartitionSpec("layers", "fsdp")) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/muon_utils_test.py b/tests/unit/muon_utils_test.py new file mode 100644 index 0000000000..9570257eee --- /dev/null +++ b/tests/unit/muon_utils_test.py @@ -0,0 +1,224 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for muon_utils.py.""" + +# pylint: disable=protected-access + +import io +import contextlib +import unittest +from unittest import mock + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax import nnx +from optax.contrib._muon import MuonDimensionNumbers as mdn + +from maxtext.utils import muon_utils + + +class TestIsPathContainAny(unittest.TestCase): + """Tests for _is_path_contain_any helper.""" + + def test_returns_true_when_any_element_in_path(self): + self.assertTrue(muon_utils._is_path_contain_any(("bias", "scale"), ("decoder", "bias"))) + + def test_returns_false_when_no_element_in_path(self): + self.assertFalse(muon_utils._is_path_contain_any(("bias", "scale"), ("decoder", "kernel"))) + + def test_empty_tuples_returns_false(self): + self.assertFalse(muon_utils._is_path_contain_any((), ("decoder", "kernel"))) + + +class TestTransformLogic(unittest.TestCase): + """Tests for transform_logic: covers every branch of the mapping.""" + + # --- 1. Exclusions --- + def test_scale_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "norm", "scale"))) + + def test_bias_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "dense", "bias"))) + + def test_embedding_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("token_embedder", "embedding"))) + + def test_logits_dense_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "logits_dense", "kernel"))) + + # --- 2.1 MoE --- + def test_moe_wi_0_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wi_0")), mdn((-2,), (-1,))) + + def test_moe_wi_1_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wi_1")), mdn((-2,), (-1,))) + + def test_moe_wo_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wo")), mdn((-2,), (-1,))) + + def test_moe_gate_falls_through_to_standard(self): + # 'gate' is inside MoeBlock_0 but not one of (wi_0, wi_1, wo) → standard. + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "gate", "kernel")), mdn((0,), (-1,))) + + # --- 2.2 Self-attention --- + def test_self_attention_out_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "out")), mdn((0, -2), (-1,))) + + def test_self_attention_query_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "query")), mdn((0,), (-2, -1))) + + def test_self_attention_key_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "key")), mdn((0,), (-2, -1))) + + def test_self_attention_value_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "value")), mdn((0,), (-2, -1))) + + def test_self_attention_wq_b_and_wkv_b(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wq_b")), mdn((0,), (-2, -1))) + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wkv_b")), mdn((0,), (-2, -1))) + + def test_self_attention_mla_wq_a_is_excluded_from_special(self): + # wq_a / wkv_a are MLA down-projections; they fall through the self_attention branch + # without matching anything, so the function returns the default standard mdn((0,), (-1,)). + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wq_a")), mdn((0,), (-1,))) + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wkv_a")), mdn((0,), (-1,))) + + # --- 3. Standard --- + def test_standard_weight(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "mlp", "kernel")), mdn((0,), (-1,))) + + +class TestGetTransformTree(unittest.TestCase): + """Tests for get_transform_tree: recursive dict walk that applies transform_logic.""" + + def test_nested_dict_is_walked(self): + tree = {"decoder": {"self_attention": {"out": 0}, "mlp": {"kernel": 0}}} + result = muon_utils.get_transform_tree(tree) + self.assertEqual(result["decoder"]["self_attention"]["out"], mdn((0, -2), (-1,))) + self.assertEqual(result["decoder"]["mlp"]["kernel"], mdn((0,), (-1,))) + + def test_excluded_leaves_become_none(self): + tree = {"decoder": {"norm": {"scale": 0}}} + self.assertIsNone(muon_utils.get_transform_tree(tree)["decoder"]["norm"]["scale"]) + + def test_non_dict_leaf_at_root_returns_transform(self): + # If the tree itself is a leaf, path=() and transform_logic returns the standard mdn. + self.assertEqual(muon_utils.get_transform_tree(0), mdn((0,), (-1,))) + + +class _MoeLikeNNXModel(nnx.Module): + """Small NNX model whose param paths exercise the NNX branch of get_muon_weight_dimension_numbers.""" + + def __init__(self, rngs): + # Names are chosen so transform_logic matches each of the three meaningful branches: + # - w_standard: default mdn + # - self_attention_out: attention-out mdn + # - scale: excluded (None) + self.w_standard = nnx.Param(jnp.ones((4, 8))) + self.self_attention_out = nnx.Param(jnp.ones((4, 8))) + self.scale = nnx.Param(jnp.ones((8,))) + + +class TestGetMuonWeightDimensionNumbersNNX(unittest.TestCase): + """Covers the NNX branch of get_muon_weight_dimension_numbers (isinstance(model, nnx.Module)).""" + + def setUp(self): + self.model = _MoeLikeNNXModel(rngs=nnx.Rngs(0)) + + def test_nnx_model_dispatches_to_tree_map_with_path(self): + """NNX branch should produce an nnx.State tree with transform_logic applied per leaf.""" + result = muon_utils.get_muon_weight_dimension_numbers(self.model, config=None) + + # Result is an nnx.State whose top-level keys mirror the model attributes. + self.assertIn("w_standard", result) + self.assertIn("self_attention_out", result) + self.assertIn("scale", result) + + # NNX Variables are walked by jax.tree_util.tree_map_with_path, so the returned + # tree replaces each Variable's value with transform_logic(path_strings). + # 'scale' matches the exclusion branch → value is None. + self.assertIsNone(result["scale"].get_value()) + # 'w_standard' does not trigger any special rule → standard mdn. + self.assertEqual(result["w_standard"].get_value(), mdn((0,), (-1,))) + + def test_nnx_verbose_path_executes_print_debug(self): + """verbose=True should also execute _print_structure_debug without raising.""" + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils.get_muon_weight_dimension_numbers(self.model, config=None, verbose=True) + self.assertIn("Model Structure", buf.getvalue()) + self.assertIn("Muon Dimension Numbers", buf.getvalue()) + + +class TestGetMuonWeightDimensionNumbersLinen(unittest.TestCase): + """Covers the Linen branch of get_muon_weight_dimension_numbers.""" + + def test_linen_branch_uses_get_abstract_param(self): + """Linen models dispatch to maxtext_utils.get_abstract_param + get_transform_tree.""" + # Build a Linen nn.Module so isinstance(model, nnx.Module) is False. + + class LinenStub(nn.Module): + + @nn.compact + def __call__(self, x): + return x + + model = LinenStub() + + # Mock the heavy get_abstract_param call with a pre-shaped dict that exercises + # both a standard weight path and an excluded path. + fake_abstract_param = { + "params": { + "self_attention": {"out": object()}, + "norm": {"scale": object()}, + }, + } + + with mock.patch.object(muon_utils.maxtext_utils, "get_abstract_param", return_value=fake_abstract_param): + result = muon_utils.get_muon_weight_dimension_numbers(model, config=mock.MagicMock()) + + self.assertEqual(result["params"]["self_attention"]["out"], mdn((0, -2), (-1,))) + self.assertIsNone(result["params"]["norm"]["scale"]) + + +class TestPrintStructureDebug(unittest.TestCase): + """Covers both branches of get_leaf_info inside _print_structure_debug.""" + + def test_handles_logically_partitioned_leaf(self): + """Linen leaves are nn.LogicallyPartitioned; the helper should return {shape, names}.""" + leaf = nn.LogicallyPartitioned(value=jax.ShapeDtypeStruct((4, 8), jnp.float32), names=("embed", "mlp")) + tree = {"params": {"kernel": leaf}} + + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils._print_structure_debug(tree, muon_weight_dimension_numbers={"params": {"kernel": mdn((0,), (-1,))}}) + out = buf.getvalue() + self.assertIn("(4, 8)", out) + self.assertIn("embed", out) + + def test_handles_shape_dtype_struct_leaf(self): + """NNX abstract leaves are ShapeDtypeStruct directly; the helper should return {shape}.""" + tree = {"kernel": jax.ShapeDtypeStruct((16, 32), jnp.float32)} + + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils._print_structure_debug(tree, muon_weight_dimension_numbers={"kernel": mdn((0,), (-1,))}) + out = buf.getvalue() + self.assertIn("(16, 32)", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/nnx_decoders_test.py b/tests/unit/nnx_decoders_test.py index acff8afe23..2525a181f1 100644 --- a/tests/unit/nnx_decoders_test.py +++ b/tests/unit/nnx_decoders_test.py @@ -31,7 +31,13 @@ from flax import nnx from jax.sharding import Mesh -from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, DecoderBlockType +from maxtext.common.common_types import ( + DECODING_ACTIVE_SEQUENCE_INDICATOR, + MODEL_MODE_PREFILL, + MODEL_MODE_TRAIN, + DecoderBlockType, + MultimodalInput, +) from maxtext.configs import pyconfig from maxtext.layers import linears from maxtext.layers.attentions import Attention @@ -573,6 +579,71 @@ def test_logits_are_finite(self): ) self.assertTrue(jnp.all(jnp.isfinite(logits))) + def test_multimodal_input_forwarded_to_apply_embedding(self): + """`multimodal_input` must reach `_apply_embedding` as the original struct. + + `NNXDecoder.__call__` takes a `MultimodalInput` struct and hands it to + `_apply_embedding`, which is the layer that actually unpacks the fields + and merges the embeddings. This test stubs `_apply_embedding` to capture + the forwarded struct without running the real embedding path (the test + config has `use_multimodal=False`). + """ + ids, segment_ids, positions = self._make_token_inputs() + + # Distinct sentinels so each field can be traced independently. + sentinel_img_emb = jnp.full((1, 1), 11.0) + sentinel_img_mask = jnp.full((1, 1), 22.0) + sentinel_aud_emb = jnp.full((1, 1), 33.0) + sentinel_aud_mask = jnp.full((1, 1), 44.0) + sentinel_bidir = jnp.full((1, 1), 55.0) + + mm_input = MultimodalInput( + image_embeddings=sentinel_img_emb, + image_masks=sentinel_img_mask, + audio_embeddings=sentinel_aud_emb, + audio_masks=sentinel_aud_mask, + bidirectional_mask=sentinel_bidir, + ) + + captured = {} + + def fake_apply_embedding( + _shared_embedding, + _ids, + _positions, + _deterministic, + _model_mode, + multimodal_input=None, + ): + captured["multimodal_input"] = multimodal_input + batch = self.cfg.global_batch_size_to_train_on + seq_len = self.cfg.max_target_length + emb_dim = self.cfg.emb_dim + return jnp.zeros((batch, seq_len, emb_dim), dtype=self.cfg.dtype) + + self.decoder._apply_embedding = fake_apply_embedding # pylint: disable=protected-access + try: + self.decoder( + self.shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + multimodal_input=mm_input, + ) + finally: + # NNX modules bind attributes statefully; remove the override to avoid leaking. + del self.decoder._apply_embedding # pylint: disable=protected-access + + forwarded = captured["multimodal_input"] + self.assertIsNotNone(forwarded) + self.assertTrue(jnp.array_equal(forwarded.image_embeddings, sentinel_img_emb)) + self.assertTrue(jnp.array_equal(forwarded.image_masks, sentinel_img_mask)) + self.assertTrue(jnp.array_equal(forwarded.audio_embeddings, sentinel_aud_emb)) + self.assertTrue(jnp.array_equal(forwarded.audio_masks, sentinel_aud_mask)) + self.assertTrue(jnp.array_equal(forwarded.bidirectional_mask, sentinel_bidir)) + def test_different_random_seeds_produce_different_logits(self): """Two randomly-initialised decoders should not produce identical logits.""" cfg = self.cfg diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index 44623f24f3..5194719ce2 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -15,19 +15,19 @@ """ Unit tests for all optimizers. """ import re import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock import jax import optax import jax.numpy as jnp import pytest from absl.testing import parameterized +from flax import nnx from optax.contrib import MuonDimensionNumbers as mdn from maxtext.configs import pyconfig from maxtext.optimizers import optimizers -from maxtext.utils import maxtext_utils -from maxtext.utils.muon_utils import get_model_mdn +from maxtext.utils import maxtext_utils, muon_utils from tests.utils.test_helpers import get_test_config_path from typing import NamedTuple @@ -49,6 +49,7 @@ DEEPSEEK2_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -57,6 +58,7 @@ }, **_DEEPSEEK2_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -73,8 +75,6 @@ }, **_DEEPSEEK2_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -99,6 +99,7 @@ DEEPSEEK3_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -107,6 +108,7 @@ }, **_DEEPSEEK3_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -123,8 +125,6 @@ }, **_DEEPSEEK3_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -243,7 +243,7 @@ def test_model_integration(self, model_name, expected_output): Initializes the specified MaxText model and asserts that the generated Muon dimension numbers match the hardcoded reference. """ - actual_output = get_model_mdn(model_name, scan_layers=True) + actual_output = muon_utils.get_model_mdn(model_name, scan_layers=True, pure_nnx=False) self.assertEqual(actual_output, expected_output) @@ -483,5 +483,105 @@ def test_no_skip_without_kwargs(self): self.assertEqual(opt_state["count"], 0) +class TestMuonLogic(unittest.TestCase): + """Tests the granular path transformation functions.""" + + def test_is_path_contain_any(self): + # pylint: disable=protected-access + self.assertTrue(muon_utils._is_path_contain_any(("a", "b"), ("x", "a", "z"))) + self.assertFalse(muon_utils._is_path_contain_any(("a", "b"), ("x", "y", "z"))) + + def test_transform_logic_exclusions(self): + self.assertIsNone(muon_utils.transform_logic(("layer_0", "bias"))) + self.assertIsNone(muon_utils.transform_logic(("layer_0", "scale"))) + self.assertIsNone(muon_utils.transform_logic(("embedding", "kernel"))) + + def test_transform_logic_moe(self): + path = ("layers_0", "MoeBlock_0", "wi_0") + result = muon_utils.transform_logic(path) + self.assertEqual(result.reduction_axis, (-2,)) + self.assertEqual(result.output_axis, (-1,)) + + def test_transform_logic_attention(self): + path_out = ("layers_0", "self_attention", "out", "kernel") + self.assertEqual(muon_utils.transform_logic(path_out), mdn((0, -2), (-1,))) + + path_q = ("layers_0", "self_attention", "query", "kernel") + self.assertEqual(muon_utils.transform_logic(path_q), mdn((0,), (-2, -1))) + + def test_get_transform_tree(self): + fake_tree = {"params": {"layer_0": {"kernel": "leaf", "bias": "leaf"}, "MoeBlock_0": {"wi_0": "leaf"}}} + result = muon_utils.get_transform_tree(fake_tree) + self.assertEqual(result["params"]["layer_0"]["kernel"], mdn((0,), (-1,))) + self.assertIsNone(result["params"]["layer_0"]["bias"]) + + def test_get_muon_weight_dimension_numbers_nnx(self): + """Verifies dimension extraction for stateful NNX modules.""" + + class MockNNXModel(nnx.Module): + """Mock NNX Module.""" + + def __init__(self, rngs: nnx.Rngs): + # 1. Standard layer + self.layer1 = nnx.Linear(2, 4, rngs=rngs) + + # 2. MoE specific naming to trigger transform logic. + # The logic expects "MoeBlock_0" AND "wi_0"/"wi_1"/"wo" in the path. + # We nest the linear layer to create the path: ('MoeBlock_0', 'wi_0', 'kernel') + self.MoeBlock_0 = nnx.Module() + self.MoeBlock_0.wi_0 = nnx.Linear(4, 2, rngs=rngs) + + # 3. Exclusion case (scaler/scale) + self.scale = nnx.Param(jnp.ones((1,))) + + # Use eval_shape to create an abstract version of the model. + model = nnx.eval_shape(lambda: MockNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + + # Extract dimension numbers using the NNX path in muon_utils + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Verify standard weight path: ('layer1', 'kernel') -> default (0,) + self.assertEqual(result.layer1.kernel.value, mdn((0,), (-1,))) + + # Verify MoE weight path: ('MoeBlock_0', 'wi_0', 'kernel') -> (-2,) + self.assertEqual(result.MoeBlock_0.wi_0.kernel.value, mdn((-2,), (-1,))) + + # Verify exclusion (scalar/scale) + self.assertIsNone(result.scale.value) + + def test_verbose_output_nnx(self): + """Covers lines 128 and 135-154: _print_structure_debug via verbose=True with NNX model.""" + + class SimpleNNXModel(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + + model = nnx.eval_shape(lambda: SimpleNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + muon_utils.get_muon_weight_dimension_numbers(model, config, verbose=True) + + def test_nnx_deepseek_attention_logic(self): + """Simulates a DeepSeek-like attention structure in NNX.""" + + class DeepSeekAttention(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.self_attention = nnx.Module() + self.self_attention.query = nnx.Linear(8, 8, rngs=rngs) + self.self_attention.out = nnx.Linear(8, 8, rngs=rngs) + + # Use eval_shape to create an abstract version of the model. + model = nnx.eval_shape(lambda: DeepSeekAttention(nnx.Rngs(0))) + config = MagicMock() + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Check attention query: [0] -> [-2, -1] + self.assertEqual(result.self_attention.query.kernel.value, mdn((0,), (-2, -1))) + # Check attention out: [0, -2] -> [-1] + self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/sharding_nnx_test.py b/tests/unit/sharding_nnx_test.py new file mode 100644 index 0000000000..3cda286c68 --- /dev/null +++ b/tests/unit/sharding_nnx_test.py @@ -0,0 +1,161 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-specific helpers in maxtext.utils.sharding.""" + +import unittest +from dataclasses import dataclass + +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from flax import nnx +import numpy as np +import optax + +from maxtext.layers import train_state_nnx +from maxtext.utils import sharding + + +@dataclass +class _Cfg: + pure_nnx: bool = True + shard_optimizer_over_data: bool = False + + +class _LinearNNX(nnx.Module): + """Tiny NNX model with a single Linear layer for sharding tests.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + + +def _build_state_mesh_shardings(model, tx): + """Build an nnx.State of NamedShardings mirroring the TrainStateNNX layout. + + This emulates what get_abstract_state_nnx returns: an nnx.State whose leaves + are nnx.Variable wrappers around NamedSharding objects. + """ + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + state_obj = train_state_nnx.TrainStateNNX(model, optimizer) + state = nnx.state(state_obj) + mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model")) + + def _to_sharding(var): + val = var.get_value() + if not hasattr(val, "shape") or val.ndim == 0: + pspec = PartitionSpec() + elif val.ndim == 1: + pspec = PartitionSpec("model") + else: + pspec = PartitionSpec("data", "model") + return var.replace(NamedSharding(mesh, pspec)) + + return jax.tree.map(_to_sharding, state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +class TestMaybeUpdateParamsShardingWithOptNNX(unittest.TestCase): + """Cover the NNX branches of maybe_update_params_sharding_with_opt.""" + + def setUp(self): + self.model = _LinearNNX(rngs=nnx.Rngs(0)) + + def test_dispatch_from_main_helper_when_pure_nnx(self): + """maybe_update_params_sharding_with_opt should dispatch to the NNX variant.""" + cfg = _Cfg(pure_nnx=True, shard_optimizer_over_data=False) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, updated = sharding.maybe_update_params_sharding_with_opt(cfg, state_mesh_shardings) + # prev is the param-only view (no rngs / non-Param nodes) + self.assertIsInstance(prev, nnx.State) + self.assertIn("linear", prev) + # updated is unchanged because shard_optimizer_over_data=False + self.assertIs(updated, state_mesh_shardings) + + def test_extract_param_only_skips_non_param_variables(self): + """prev_params_shardings must contain Params only — RngKey/RngCount/OptVariable filtered out.""" + cfg = _Cfg(shard_optimizer_over_data=False) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, _ = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + leaves = jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable)) + # Every surviving leaf is wrapped as an nnx.Param. + self.assertTrue(all(isinstance(leaf, nnx.Param) for leaf in leaves)) + # The model has linear.kernel and linear.bias — exactly two Param leaves. + self.assertEqual(len(leaves), 2) + + def test_returns_unchanged_when_shard_optimizer_over_data_false(self): + """When shard_optimizer_over_data=False, the second return value must be the input object.""" + cfg = _Cfg(shard_optimizer_over_data=False) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + self.assertIs(updated, state_mesh_shardings) + + def test_zero1_propagates_mu_sharding_to_model_params(self): + """Zero-1: model param shardings must be replaced with the optimizer mu shardings.""" + cfg = _Cfg(shard_optimizer_over_data=True) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + + # Mutate the optimizer mu leaves in place so the function picks up a distinct PartitionSpec. + mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model")) + target_pspec = PartitionSpec(("data", "model")) + new_mu_sharding = NamedSharding(mesh, target_pspec) + + # After _build_state_mesh_shardings, every leaf's .value is a NamedSharding (no .shape), + # so we just override every Variable leaf in mu in place. + # After _build_state_mesh_shardings, every leaf's value is a NamedSharding (no .shape), + # so we just override every Variable leaf in mu in place via set_value (modern API). + mu_state = state_mesh_shardings.optimizer.opt_state[0]["mu"] + for var in jax.tree.leaves(mu_state, is_leaf=lambda x: isinstance(x, nnx.Variable)): + if isinstance(var, nnx.Variable): + var.set_value(new_mu_sharding) + + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + + # All Param leaves under updated.model must now share the new mu sharding. + param_leaves = jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + param_leaves = [v for v in param_leaves if isinstance(v, nnx.Param)] + self.assertGreater(len(param_leaves), 0) + for leaf in param_leaves: + self.assertEqual(leaf.get_value().spec, target_pspec) + + def test_raises_when_no_adam_state_present(self): + """Stateless optimizers (e.g., SGD) have no mu — function must raise NotImplementedError.""" + cfg = _Cfg(shard_optimizer_over_data=True) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.sgd(1e-3)) + with self.assertRaises(NotImplementedError): + sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + + def test_chained_optimizer_recursion_finds_adam_mu(self): + """A nested optax.chain(clip, adam) wraps mu under multiple containers — recursion must find it.""" + cfg = _Cfg(shard_optimizer_over_data=True) + chained = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)) + state_mesh_shardings = _build_state_mesh_shardings(self.model, chained) + + # Should not raise; verify update happens (params replaced with mu shardings). + prev, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + self.assertIsInstance(prev, nnx.State) + self.assertIsInstance(updated, nnx.State) + # Same number of Param leaves before and after. + n_prev = len(jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable))) + n_after = len( + [ + v + for v in jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + if isinstance(v, nnx.Param) + ] + ) + self.assertEqual(n_prev, n_after) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py new file mode 100644 index 0000000000..3495b4c557 --- /dev/null +++ b/tests/unit/train_nnx_test.py @@ -0,0 +1,239 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX paths of loss_fn / train_step / eval_step in pre_train.train. + +These tests exercise the NNX branches without standing up a real Transformer or +data pipeline. We use a tiny NNX module that mimics the call signature the +production loss_fn uses (decoder_input_tokens, decoder_positions, ...). +""" + +import unittest +from dataclasses import dataclass + +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx +from maxtext.trainers.pre_train import train as pre_train + + +@dataclass +class _Cfg: + """Subset of HyperParameters used by loss_fn / train_step / eval_step.""" + + micro_batch_size_to_train_on: int = 2 + micro_batch_size_to_eval_on: int = 2 + vocab_size: int = 8 + z_loss_multiplier: float = 0.0 + enable_dropout: bool = False + use_multimodal: bool = False + use_indexer: bool = False + indexer_sparse_training: bool = False + indexer_loss_scaling_factor: float = 0.0 + num_vocab_tiling: int = 1 + num_experts: int = 1 + routed_bias: bool = False + routed_bias_update_rate: float = 0.0 + mtp_num_layers: int = 0 + mtp_eval_target_module: int = 0 + use_dpo: bool = False + use_qk_clip: bool = False + use_tunix_gradient_accumulation: bool = False + gradient_accumulation_steps: int = 1 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + gradient_clipping_threshold: float = 0.0 + grad_dtype: jnp.dtype = jnp.float32 + record_internal_nn_metrics: bool = False + skip_step_on_spikes: bool = False + shard_mode: int = 0 # ShardMode.AUTO + weight_sparsity_n: int = 0 + weight_sparsity_m: int = 0 + + +class _TinyDecoder(nnx.Module): + """Mimics NNXDecoder.__call__ enough for loss_fn to run end-to-end. + + Returns logits of shape [batch, seq_len, vocab_size]. Ignores all multimodal + / dropout / target arguments — they exist only to match the keyword signature. + """ + + def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs) + self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + encoder_images=None, + encoder_image_masks=None, + enable_dropout=False, + decoder_target_tokens=None, + decoder_target_mask=None, + ): + del decoder_positions, decoder_segment_ids, encoder_images, encoder_image_masks + del enable_dropout, decoder_target_tokens, decoder_target_mask + h = self.embed(decoder_input_tokens) + return self.proj(h) + + +def _make_data(batch=2, seq=4, vocab=8): + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + + +def _build_state(): + cfg = _Cfg() + model = _TinyDecoder(cfg.vocab_size, hidden=4, rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + return cfg, ts + + +class TestLossFnNNX(unittest.TestCase): + """Cover the NNX branch of loss_fn (lines 178-213).""" + + def test_returns_loss_and_full_aux_dict(self): + cfg, ts = _build_state() + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + self.assertTrue(jnp.isfinite(loss)) + # Aux schema relied on by train_step / eval_step / GA. + for key in ( + "intermediate_outputs", + "xent_sum", + "z_loss", + "total_weights", + "moe_lb_loss", + "indexer_loss", + "moe_bias_updates", + "mtp_loss", + ): + self.assertIn(key, aux) + # NNX intermediates are captured into a pure-dict snapshot, then logits attached. + self.assertIsInstance(aux["intermediate_outputs"], dict) + self.assertIn("logits", aux["intermediate_outputs"]) + + def test_eval_mode_truncates_to_eval_micro_batch(self): + cfg, ts = _build_state() + cfg.micro_batch_size_to_eval_on = 1 + data = _make_data(batch=2, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=False) + self.assertTrue(jnp.isfinite(loss)) + # eval truncated batch to 1 → total_weights = seq_len * 1 + self.assertEqual(int(aux["total_weights"]), data["targets_segmentation"].shape[1]) + + def test_indexer_dense_warmup_skips_xent(self): + cfg, ts = _build_state() + cfg.use_indexer = True + cfg.indexer_sparse_training = False + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + # When dense warm-up is active the loss_fn skips the main loss entirely. + self.assertEqual(float(aux["xent_sum"]), 0.0) + self.assertEqual(float(loss), 0.0) + + def test_vocab_tiling_raises_not_implemented(self): + cfg, ts = _build_state() + cfg.num_vocab_tiling = 4 + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + with self.assertRaises(NotImplementedError): + pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + + +class TestTrainStepNNX(unittest.TestCase): + """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" + + def test_train_step_returns_state_and_metrics(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + # NNX path returns nnx.State (via nnx.state(new_state)) and a metrics dict. + self.assertIsInstance(new_state, nnx.State) + self.assertIn("scalar", metrics) + self.assertIn("learning/loss", metrics["scalar"]) + self.assertIn("learning/grad_norm", metrics["scalar"]) + self.assertIn("learning/param_norm", metrics["scalar"]) + self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + + def test_train_step_dpo_raises_for_nnx(self): + cfg, ts = _build_state() + cfg.use_dpo = True + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + with self.assertRaises(NotImplementedError): + pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + + def test_train_step_increments_optimizer_step(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + pre_step = int(state_pure.optimizer.step.get_value()) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, _ = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertEqual(int(new_state.optimizer.step.get_value()), pre_step + 1) + + def test_train_step_with_gradient_clipping(self): + """The clipping branch (gradient_clipping_threshold > 0) must run without raising.""" + cfg, ts = _build_state() + cfg.gradient_clipping_threshold = 1.0 + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertIsInstance(new_state, nnx.State) + self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + + +class TestEvalStepNNX(unittest.TestCase): + """Cover the NNX branch of eval_step (lines 568-570).""" + + def test_eval_step_returns_metrics(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_eval_on, vocab=cfg.vocab_size) + metrics = pre_train.eval_step(state_graphdef, cfg, state_pure, data) + self.assertIn("scalar", metrics) + for key in ( + "evaluation/loss", + "evaluation/total_loss", + "evaluation/total_weights", + "evaluation/moe_lb_loss", + ): + self.assertIn(key, metrics["scalar"]) + # NNX path must NOT include DPO eval metric. + self.assertNotIn("evaluation/dpo_reward_accuracy", metrics["scalar"]) + self.assertTrue(jnp.isfinite(metrics["scalar"]["evaluation/loss"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_checkpoint_test.py b/tests/unit/train_state_nnx_checkpoint_test.py new file mode 100644 index 0000000000..0f7dc22d68 --- /dev/null +++ b/tests/unit/train_state_nnx_checkpoint_test.py @@ -0,0 +1,412 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TrainStateNNX checkpoint tests.""" + +import pathlib +import tempfile +import shutil +from types import SimpleNamespace +from unittest import mock + +import unittest +import jax +import jax.numpy as jnp +from flax import nnx, serialization +from flax import linen as nn +from flax.training import train_state +import optax +import orbax.checkpoint as ocp + +from maxtext.common import checkpointing +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """A simple model for checkpoint testing.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class LinenMockModel(nn.Module): + """The Linen equivalent of the MockModel.""" + + @nn.compact + def __call__(self, x): + # We name the layer 'linear' to match the attribute name in the NNX MockModel + return nn.Dense(features=1, name="linear")(x) + + +def _replicate_for_orbax(pytree): + """Give every array a replicated NamedSharding so Orbax can save in multi-host CI. + + Orbax refuses arrays with the default SingleDeviceSharding when + jax.process_count() > 1. Putting each leaf on a NamedSharding over the local + mesh works in both single- and multi-host environments without changing + values. + """ + mesh = jax.sharding.Mesh(jax.devices(), ("x",)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + return jax.tree.map(lambda x: jax.device_put(x, sharding) if isinstance(x, jax.Array) else x, pytree) + + +class TestTrainStateNNXCheckpoint(unittest.TestCase): + """Class to test NNX checkpoint.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + + # Setup a chained optimizer: Gradient Clipping -> Adam + # Note: optax.adam is also a chain (scale_by_adam + scale_by_learning_rate). + # This creates a nested state structure: (EmptyState, (ScaleByAdamState, EmptyState)) + self.tx = optax.chain( + optax.clip_by_global_norm(max_norm=1.0), + optax.adam(1e-3), + ) + + def test_checkpoint_structure(self): + """Ensures the state object contains both model and optimizer keys.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # We use .to_pure_dict() to simulate the format stored in a checkpoint. + # This converts nnx.Variable/State objects into raw arrays and dictionaries. + full_state = nnx.state(state).to_pure_dict() + + # 1. Verify Top-level Keys + self.assertIn("model", full_state) + self.assertIn("optimizer", full_state) + + # 2. Verify Optimizer Internal Structure + opt_inner_state = full_state["optimizer"]["opt_state"] + + # Because we used optax.chain(clip, adam), index 0 is clip, index 1 is adam. + # Since adam is also a chain, index 1 is itself a dictionary/tuple representation. + # Adam's momentum (mu/nu) is in the first element of its own sub-chain. + adam_component = opt_inner_state[1][0] + + self.assertIn("mu", adam_component, "Adam 'mu' buffer not found in pure dict state.") + self.assertIn("nu", adam_component, "Adam 'nu' buffer not found in pure dict state.") + + # In a pure dict, these are nested dictionaries containing arrays, not NNX objects. + self.assertIsInstance(adam_component["mu"], dict) + self.assertIsInstance(adam_component["nu"], dict) + + # To verify a specific leaf, we navigate the dictionary hierarchy: + self.assertIsInstance(adam_component["mu"]["linear"]["kernel"], jax.Array) + + def test_checkpoint_and_restore(self): + """Verifies that the full state can be captured and restored into a new instance.""" + # 1. Initialize original state and optimizer + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state_original = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # 2. Perform a training step to modify weights and optimizer buffers + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state_original.model) + state_original.apply_gradients(grads) + + # Capture state after one step + original_kernel_val = state_original.model.linear.kernel.value + original_step_val = state_original.optimizer.step.value + self.assertEqual(original_step_val, 1) + + # 3. Capture the "Checkpoint" as a pure dictionary + checkpoint_state = nnx.state(state_original).to_pure_dict() + + # 4. Initialize a fresh, different instance + new_rngs = nnx.Rngs(1) + new_model = MockModel(rngs=new_rngs) + new_optimizer = nnx.Optimizer(new_model, self.tx, wrt=nnx.Param) + state_restored = train_state_nnx.TrainStateNNX(new_model, new_optimizer) + + # Check differences before restoration + self.assertEqual(state_restored.optimizer.step.value, 0) + self.assertFalse(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # 5. Restore the state into the new instance. + # nnx.update supports updating from a pure dictionary. + nnx.update(state_restored, checkpoint_state) + + # 6. Verify restoration + # Check step counter + self.assertEqual(state_restored.optimizer.step.value, original_step_val) + # Check model weights + self.assertTrue(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # Check that it can still be trained after restoration + new_grads = nnx.grad(loss_fn)(state_restored.model) + state_restored.apply_gradients(new_grads) + self.assertEqual(state_restored.optimizer.step.value, 2) + + def test_restore_from_linen_state(self): + """Verifies a multi-stage migration: Linen CKPT -> Migrate -> NNX CKPT -> Restore.""" + # 1. Setup Linen TrainState (Simulating original training) + linen_model = LinenMockModel() + dummy_input = jnp.ones((1, 2)) + variables = linen_model.init(jax.random.key(42), dummy_input) + + state_linen = train_state.TrainState.create(apply_fn=linen_model.apply, params=variables["params"], tx=self.tx) + + # Perform a step to populate optimizer buffers + grads = jax.tree.map(jnp.ones_like, state_linen.params) + state_linen = state_linen.apply_gradients(grads=grads) + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save Legacy Linen Checkpoint --- + linen_ckpt_dir = temp_dir / "linen_ckpt" + mngr_linen = ocp.CheckpointManager( + linen_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr_linen.save(0, args=ocp.args.StandardSave(_replicate_for_orbax(state_linen))) + mngr_linen.wait_until_finished() + + # --- PHASE 2: Read Linen CKPT and Convert to NNX Structure --- + # Load it back without knowing the blueprint (reading as a pure PyTree) + restored_linen_obj = mngr_linen.restore(0) + + # Convert the restored object to a pure dictionary structure. + restored_linen_dict = serialization.to_state_dict(restored_linen_obj) + + # Helper to recursively convert string keys back to integers + # and filter out None values. + def recursive_clean(obj): + if isinstance(obj, dict): + return {int(k) if k.isdigit() else k: recursive_clean(v) for k, v in obj.items() if v is not None} + return obj + + # Converted dict - simple PyTree mapping, no NNX Module initialization needed here. + # This simulates a situation where the conversion logic is blueprint-agnostic. + linen_as_nnx_dict = { + "model": restored_linen_dict["params"], + "optimizer": { + "step": jnp.array(restored_linen_dict["step"]), + "opt_state": recursive_clean(restored_linen_dict["opt_state"]), + }, + } + + # --- PHASE 3: Save as Native NNX Checkpoint --- + nnx_ckpt_dir = temp_dir / "nnx_ckpt" + mngr_nnx = ocp.CheckpointManager( + nnx_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + # We save the raw dictionary directly to disk. + mngr_nnx.save(0, args=ocp.args.StandardSave(_replicate_for_orbax(linen_as_nnx_dict))) + mngr_nnx.wait_until_finished() + + # --- PHASE 4: Restore from NNX Checkpoint to target Model --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We now restore using the nnx.State as a blueprint. This ensures Orbax + # correctly maps the arrays on disk to the model's structural expectation. + blueprint = nnx.state(state_nnx).to_pure_dict() + restored_nnx_pytree = mngr_nnx.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + nnx.update(state_nnx, restored_nnx_pytree) + + # --- PHASE 5: Verification --- + # 1. Verify Step + self.assertEqual(state_nnx.optimizer.step.value, 1) + + # 2. Verify Weights + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, state_linen.params["linear"]["kernel"])) + + # 3. Verify Chained Optimizer State (Clip at index 0, Adam at index 1) + self.assertEqual(type(state_nnx.optimizer.opt_state[0]), type(state_linen.opt_state[0])) + + # state_linen.opt_state[1] is the Adam chain state. + # state_linen.opt_state[1][0] is the ScaleByAdamState containing 'mu'. + self.assertTrue( + jnp.allclose( + state_nnx.optimizer.opt_state[1][0].mu["linear"]["kernel"], + state_linen.opt_state[1][0].mu["linear"]["kernel"], + ) + ) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + def test_restore_from_checkpoint_model_params(self): + """Verifies that model parameters can be restored from model params only.""" + # 1. Setup mocked parameters manually (no Linen model needed for setup) + # This structure matches the path model.linear.kernel/bias in the NNX MockModel. + mock_params = {"linear": {"kernel": jnp.ones((2, 1)) * 9.0, "bias": jnp.zeros((1,))}} + + # Simplified checkpoint dictionary using hardcoded mocked params as requested + checkpoint_dict = { + "model": mock_params, + } + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save the partial checkpoint --- + mngr = ocp.CheckpointManager( + temp_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr.save(0, args=ocp.args.StandardSave(_replicate_for_orbax(checkpoint_dict))) + mngr.wait_until_finished() + + # --- PHASE 2: Restore into a full TrainStateNNX --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We use nnx.state to get a full blueprint as a reference. + full_nnx_pure_dict = nnx.state(state_nnx).to_pure_dict() + blueprint = {"model": full_nnx_pure_dict["model"]} + + # If we don't know if the checkpoint on disk has 'optimizer' or not, we simulate + # schema-agnostic restoration by calling restore without a blueprint. + # This avoids Orbax structural mismatch errors while allowing us to see the data. + restored_pytree = mngr.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + + # Use nnx.update to apply the restored data to the stateful NNX object. + # nnx.update is naturally partial: it will update 'model' from the restored dict + # and leave 'optimizer' untouched at its initialized value. + nnx.update(state_nnx, restored_pytree) + + # --- PHASE 3: Verification --- + # Check that weights were restored to the specific mock values + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, mock_params["linear"]["kernel"])) + # Step remains at its initialized value (0) because it was not in the checkpoint + self.assertEqual(state_nnx.optimizer.step.value, 0) + + # Verify that the optimizer state still exists in the object (initialized) + # even though it was not provided in the checkpoint. + # Adam's state is at index 1 of the chain, and it's a nested structure (tuple). + # We verify that index 0 (ScaleByAdamState) contains the 'mu' State container. + self.assertIsInstance(state_nnx.optimizer.opt_state[1][0].mu, nnx.State) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + +class TestMaybeSaveCheckpointStepAlignment(unittest.TestCase): + """Verify maybe_save_checkpoint's fallback step matches the last completed step. + + When the training loop's final save calls maybe_save_checkpoint without an + explicit `step`, it derives `actual_step` from the state: + - NNX: int(state.optimizer.step) - 1 + - Linen: int(state.step) - 1 + Both TrainStateNNX.apply_gradients (via nnx.Optimizer.update) and Linen + TrainState.apply_gradients increment the counter by 1 per call, so after N + gradient applications the counter is N and the "last completed step" is N-1. + """ + + N_STEPS = 5 + + def setUp(self): + self.tx = optax.adam(1e-3) + + def _build_nnx_state(self, num_steps): + """Build an nnx.State flattened from TrainStateNNX after num_steps gradient applications.""" + model = MockModel(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(model, optimizer) + + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + for _ in range(num_steps): + grads = nnx.grad(loss_fn)(state.model) + state.apply_gradients(grads) + # maybe_save_checkpoint is called with a flat nnx.State in the NNX path + # (train_step returns nnx.state(new_state)). + return nnx.state(state) + + def _build_linen_state(self, num_steps): + """Build a Linen TrainState after num_steps gradient applications.""" + model = LinenMockModel() + variables = model.init(jax.random.key(0), jnp.ones((1, 2))) + state = train_state.TrainState.create(apply_fn=model.apply, params=variables["params"], tx=self.tx) + grads = jax.tree.map(jnp.ones_like, state.params) + for _ in range(num_steps): + state = state.apply_gradients(grads=grads) + return state + + def _invoke_maybe_save(self, state, pure_nnx): + """Call maybe_save_checkpoint with save_checkpoint patched, return {step, state} captured.""" + # checkpoint_period=1 keeps force_ckpt_save False regardless of actual_step. + config = SimpleNamespace(pure_nnx=pure_nnx, checkpoint_period=1, async_checkpointing=False) + mgr = mock.MagicMock() + mgr.reached_preemption.return_value = False + + captured = {} + + def fake_save_checkpoint(_mgr, step, state_arg, *_args, **_kwargs): + captured["step"] = step + captured["state"] = state_arg + return False # no save happened => print_save_message is skipped + + with mock.patch.object(checkpointing, "save_checkpoint", side_effect=fake_save_checkpoint): + checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=None) + return captured + + def test_nnx_final_save_step_is_n_minus_1(self): + state = self._build_nnx_state(self.N_STEPS) + self.assertEqual(int(state.optimizer.step.value), self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=True) + self.assertEqual(captured["step"], self.N_STEPS - 1) + + def test_linen_final_save_step_is_n_minus_1(self): + state = self._build_linen_state(self.N_STEPS) + self.assertEqual(int(state.step), self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=False) + self.assertEqual(captured["step"], self.N_STEPS - 1) + + def test_nnx_and_linen_agree_on_actual_step(self): + """TrainStateNNX and Linen TrainState must yield the same fallback actual_step.""" + nnx_state = self._build_nnx_state(self.N_STEPS) + linen_state = self._build_linen_state(self.N_STEPS) + self.assertEqual( + self._invoke_maybe_save(nnx_state, pure_nnx=True)["step"], + self._invoke_maybe_save(linen_state, pure_nnx=False)["step"], + ) + + def test_nnx_state_is_converted_to_pure_dict_before_save(self): + """For pure_nnx=True, maybe_save_checkpoint must pass a plain dict to save_checkpoint, not an nnx.State.""" + state = self._build_nnx_state(self.N_STEPS) + self.assertIsInstance(state, nnx.State) # precondition: NNX train_step returns an nnx.State + + captured = self._invoke_maybe_save(state, pure_nnx=True) + + # save_checkpoint should have received a plain Python dict (the result of + # nnx.State.to_pure_dict()), not the original nnx.State. + self.assertIsInstance(captured["state"], dict) + self.assertNotIsInstance(captured["state"], nnx.State) + # Sanity: the converted dict still mirrors the TrainStateNNX structure. + self.assertIn("model", captured["state"]) + self.assertIn("optimizer", captured["state"]) + + def test_linen_state_is_passed_through_unchanged(self): + """For pure_nnx=False, maybe_save_checkpoint must pass the original TrainState object through.""" + state = self._build_linen_state(self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=False) + # Linen path must not invoke to_pure_dict(); state is forwarded as-is. + self.assertIs(captured["state"], state) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_test.py b/tests/unit/train_state_nnx_test.py new file mode 100644 index 0000000000..03db77ff63 --- /dev/null +++ b/tests/unit/train_state_nnx_test.py @@ -0,0 +1,90 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TrainStateNNX tests.""" + +import unittest +import jax.numpy as jnp +from flax import nnx +import optax + +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """Mocked NNX model""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class TestTrainStateNNX(unittest.TestCase): + """TrainStateNNX tests.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + self.tx = optax.adam(1e-3) + + def test_init_with_optimizer(self): + """Test init with iptimizer.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + self.assertEqual(state.model, self.model) + self.assertEqual(state.optimizer, optimizer) + # Access step directly from optimizer + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_optimizer(self): + """Test init without optimizer.""" + state = train_state_nnx.TrainStateNNX(self.model, None) + + self.assertEqual(state.model, self.model) + self.assertIsNone(state.optimizer) + + def test_apply_gradients_success(self): + """Test apply gradients can be called successfully.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # Create dummy gradients matching the model state structure + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state.model) + + # Apply gradients + state.apply_gradients(grads) + + # Verify step incremented (managed by nnx.Optimizer) + self.assertEqual(state.optimizer.step.value, 1) + + def test_apply_gradients_raises_runtime_error(self): + """Test apply gradients without a optimizer.""" + # Initialize without optimizer (inference mode) + state = train_state_nnx.TrainStateNNX(self.model, None) + + dummy_grads = {} + with self.assertRaises(RuntimeError) as cm: + state.apply_gradients(dummy_grads) + + self.assertIn("inference only", str(cm.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_utils_nnx_test.py b/tests/unit/train_utils_nnx_test.py new file mode 100644 index 0000000000..2ff7276fd9 --- /dev/null +++ b/tests/unit/train_utils_nnx_test.py @@ -0,0 +1,149 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-specific helpers / patterns in train_utils.setup_train_loop. + +setup_train_loop itself is integration territory (it touches data iterators, +checkpoint managers, and a real mesh), so we cover the NNX-only pieces that +have unit-testable contracts: + + 1. The create_train_state_fn closure pattern: builds nnx.Optimizer + TrainStateNNX + from a zero-arg model factory and a transform. + 2. nnx.split(state.model, nnx.Param, ...) returns Param-only state used to + compute state_params / state_mesh_shardings_params. + 3. nnx.merge(state_graphdef, state) reconstitutes a TrainStateNNX from the + pure-state form returned by setup_training_state. +""" + +import unittest +from functools import partial + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx + + +class _Model(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + +class TestCreateTrainStateFnClosure(unittest.TestCase): + """Exercise the closure pattern in setup_train_loop: + + def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + """ + + def test_returns_train_state_nnx_with_optimizer(self): + tx = optax.sgd(0.01) + + def _create_model(): + return _Model(rngs=nnx.Rngs(0)) + + def create_train_state_fn(): + model = _create_model() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + + state = create_train_state_fn() + self.assertIsInstance(state, train_state_nnx.TrainStateNNX) + self.assertIsInstance(state.optimizer, nnx.Optimizer) + self.assertEqual(int(state.optimizer.step.get_value()), 0) + + def test_two_invocations_produce_independent_states(self): + """The lambda must call the factory each time (otherwise checkpoint init/restore would alias).""" + tx = optax.sgd(0.01) + counter = {"n": 0} + + def _create_model(): + counter["n"] += 1 + return _Model(rngs=nnx.Rngs(counter["n"])) + + def create_train_state_fn(): + model = _create_model() + return train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, tx, wrt=nnx.Param)) + + s1 = create_train_state_fn() + s2 = create_train_state_fn() + self.assertEqual(counter["n"], 2) + self.assertIsNot(s1.model, s2.model) + + +class TestSetupTrainLoopNNXTreeOps(unittest.TestCase): + """Cover the nnx.split(state.model, nnx.Param, ...) and nnx.merge round-trip + patterns that setup_train_loop uses to derive Param-only views and rebuild + the full TrainStateNNX before returning.""" + + def setUp(self): + self.tx = optax.sgd(0.01) + self.model = _Model(rngs=nnx.Rngs(0)) + self.state = train_state_nnx.TrainStateNNX(self.model, nnx.Optimizer(self.model, self.tx, wrt=nnx.Param)) + + def test_nnx_split_yields_param_only_state(self): + """state_params used for assert_params_sufficiently_sharded must contain only nnx.Param leaves.""" + _, state_params, _ = nnx.split(self.state.model, nnx.Param, ...) + leaves = jax.tree.leaves(state_params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertIsInstance(leaf, nnx.Param) + + def test_nnx_merge_reconstructs_train_state_nnx(self): + """setup_train_loop ends with nnx.merge(state_graphdef, state) — verify that round-trips.""" + state_graphdef, state_pure = nnx.split(self.state) + train_state = nnx.merge(state_graphdef, state_pure) + self.assertIsInstance(train_state, train_state_nnx.TrainStateNNX) + # Same numeric values. + self.assertTrue(jnp.allclose(train_state.model.linear.kernel.value, self.state.model.linear.kernel.value)) + + +class TestInitStateFnIsCallable(unittest.TestCase): + """For the Linen path setup_train_loop builds init_state_fn = partial(...). + + The NNX path uses a closure instead — confirm both forms have the + zero-argument call contract create_checkpoint_manager / setup_training_state expect. + """ + + def test_nnx_init_state_fn_callable_with_no_args(self): + tx = optax.sgd(0.01) + + def _create_model(): + return _Model(rngs=nnx.Rngs(0)) + + def init_state_fn(): + model = _create_model() + return train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, tx, wrt=nnx.Param)) + + state = init_state_fn() # must not raise / require args + self.assertIsInstance(state, train_state_nnx.TrainStateNNX) + + def test_linen_init_state_fn_is_partial_callable_with_no_args(self): + """Sanity: the Linen-side `partial(init_initial_state, model, tx, config, is_training, init_rng)` form.""" + + def init_initial_state(model, tx, config, is_training, init_rng): + del model, tx, config, is_training, init_rng + return "linen-state" + + init_state_fn = partial(init_initial_state, "model", "tx", "config", True, "rng") + self.assertEqual(init_state_fn(), "linen-state") + + +if __name__ == "__main__": + unittest.main() From e641bea7fe9a3b5da72f0b2fb52a62d0c0cf5e99 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 31 Mar 2026 14:32:29 +0000 Subject: [PATCH 2/5] NNX: add sharding tools, Linen<->NNX checkpoint utilities, and post-training fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities: - modify print_shardings_params to support NNX (maxtext_utils.py) - add --pure_nnx flag to run_sharding_dump.py - add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py) - add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py) Part 2 — post-training bug fixes: - models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields) - optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn) - train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit raises conflicting outer_index error); refactored to jax.value_and_grad + explicit nnx.split/merge pattern; teacher inference moved outside value_and_grad --- .../compare_linen_nnx_checkpoint.py | 609 ++++++++++++ .../linen_nnx_converter.py | 581 ++++++++++++ src/maxtext/models/models.py | 6 +- src/maxtext/optimizers/optimizers.py | 4 +- .../post_train/distillation/train_distill.py | 78 +- .../trainers/post_train/rl/train_rl.py | 43 +- .../trainers/post_train/sft/train_sft.py | 70 +- src/maxtext/utils/maxtext_utils.py | 47 +- .../unit/distillation_scheduling_test.py | 44 +- .../post_training/unit/train_distill_test.py | 84 +- .../unit/compare_linen_nnx_checkpoint_test.py | 501 ++++++++++ tests/unit/linen_nnx_converter_test.py | 869 ++++++++++++++++++ tests/utils/run_sharding_dump.py | 9 +- 13 files changed, 2856 insertions(+), 89 deletions(-) create mode 100644 src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py create mode 100644 src/maxtext/checkpoint_conversion/linen_nnx_converter.py create mode 100644 tests/unit/compare_linen_nnx_checkpoint_test.py create mode 100644 tests/unit/linen_nnx_converter_test.py diff --git a/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py b/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py new file mode 100644 index 0000000000..7439ac36a0 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py @@ -0,0 +1,609 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compare checkpoint tree structures, shapes, and values. + +Supports comparing any combination of Linen and NNX checkpoints: +- Linen vs NNX (cross-format comparison) +- Linen vs Linen (same-format comparison) +- NNX vs NNX (same-format comparison) + +The script auto-detects the format of each checkpoint and applies the +appropriate normalization. Cross-format transformations (like layer axis +transposition) are only applied when comparing Linen vs NNX. + +Key differences between Linen and NNX checkpoints: +- Linen: params/params/decoder/layers/0/... (per-layer, double nested) +- NNX: model/decoder/layers/... (stacked layers, single nested, {value: array} wrappers) + +The script handles: +- Double 'params' nesting in Linen checkpoints +- 'model' key in NNX checkpoints (vs 'params' in Linen) +- {value: array} wrappers in NNX checkpoints +- Layer axis transposition (NNX stacks layers along axis 0, only for cross-format) +- RNG filtering (NNX has rngs, Linen doesn't) + +Usage: + # Compare Linen vs NNX (structure and shapes only) + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/linen_checkpoint/0/items" \ + --ckpt_path_2="gs://bucket/nnx_checkpoint/0/items" + + # Compare NNX vs NNX + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/nnx_checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/nnx_checkpoint_b/0/items" + + # Compare Linen vs Linen + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/linen_checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/linen_checkpoint_b/0/items" + + # Compare with value checking + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/checkpoint_b/0/items" \ + --compare_values --atol=1e-5 --rtol=1e-5 +""" + +import os +from typing import Any, Dict, Sequence + +# MUST set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import jax.numpy as jnp +from jax.tree_util import tree_flatten_with_path, keystr, tree_structure, tree_map_with_path +import numpy as np +from etils import epath +import orbax.checkpoint as ocp +from absl import app +from absl import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "ckpt_path_1", + None, + "Path to the first checkpoint items directory. Format is auto-detected.", +) +flags.DEFINE_string( + "ckpt_path_2", + None, + "Path to the second checkpoint items directory. Format is auto-detected.", +) +flags.DEFINE_boolean( + "verbose", + False, + "Print detailed per-parameter information.", +) +flags.DEFINE_boolean( + "transpose_nnx_layers", + False, + "Transpose NNX layer params from (layers, ...) to (...) for comparison. " + "NNX stacks layers along axis 0, while Linen stores per-layer params. " + "Only applied for cross-format (Linen vs NNX) comparisons.", +) +flags.DEFINE_string( + "compare_only", + "params", + "Which parts to compare: 'params' for params only, 'all' for full state.", +) +flags.DEFINE_boolean( + "ignore_rngs", + True, + "Ignore RNG-related paths in comparison (NNX has rngs, Linen doesn't).", +) +flags.DEFINE_boolean( + "compare_values", + False, + "Also compare parameter values (not just structure and shapes).", +) +flags.DEFINE_float( + "atol", + 1e-5, + "Absolute tolerance for value comparison.", +) +flags.DEFINE_float( + "rtol", + 1e-5, + "Relative tolerance for value comparison.", +) + + +def log(message: str) -> None: + """Log a message with prefix.""" + print(f"[compare_ckpt] {message}") + + +def is_rng_path(path: str) -> bool: + """Check if a path is RNG-related.""" + path_lower = path.lower() + return "rngs" in path_lower or "rng" in path_lower + + +def filter_rngs(tree: Dict[str, Any]) -> Dict[str, Any]: + """Filter out RNG-related keys from a tree.""" + if not isinstance(tree, dict): + return tree + + result = {} + for key, value in tree.items(): + # Skip RNG-related keys + if is_rng_path(key): + continue + # Recursively filter nested dicts + if isinstance(value, dict): + filtered = filter_rngs(value) + if filtered: # Only add if not empty after filtering + result[key] = filtered + else: + result[key] = value + return result + + +def detect_format(state: dict) -> str: + """Detects checkpoint format from state structure ('linen' or 'nnx'). + + Linen format: + - Top-level keys: ['params', 'opt_state', 'step'] + - params/params/decoder/... (double nested) + + NNX format: + - Top-level keys: ['model', 'optimizer'] (nnx.State style) + - model/decoder/... with {value: array} wrappers + """ + # Check for NNX nnx.State format (has 'model' key instead of 'params') + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Checkpoint does not contain 'params' or 'model' key. Found keys: {list(state.keys())}") + + params = state["params"] + + # Check for Linen's double 'params' nesting + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Check for NNX's flat structure (params/decoder/...) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + return "nnx" + + # Try to detect by looking for {value: array} wrappers (NNX style) + if _has_value_wrappers(params): + return "nnx" + + raise ValueError( + f"Could not detect checkpoint format. params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +def _has_value_wrappers(tree: Any) -> bool: + """Check if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, (np.ndarray, jnp.ndarray)): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {'value': array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, (np.ndarray, jnp.ndarray)): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _normalize_linen_params(params: dict) -> dict: + """Normalize Linen params by removing double 'params' nesting.""" + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return inner + return params + + +def _normalize_nnx_params(params: dict) -> dict: + """Normalize NNX params by stripping {value: array} wrappers.""" + return _strip_value_wrappers(params) + + +def load_checkpoint(checkpoint_path: str, metadata_only: bool = False) -> dict: + """Loads checkpoint from local or GCS path. + + If metadata_only=True, returns a pytree of ArrayMetadata (shape/dtype only) + without downloading any tensor data. This is fast and sufficient for + structure/shape comparison. + """ + log(f"Loading checkpoint from: {checkpoint_path}") + if metadata_only: + log(" Mode: metadata only (no tensor data downloaded)") + + checkpoint_dir = epath.Path(checkpoint_path) + + # Create checkpointer and get metadata + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + + try: + metadata = ckptr.metadata(checkpoint_dir) + + if metadata_only: + tree = metadata.item_metadata.tree + log(f" Loaded metadata keys: {list(tree.keys())}") + return tree + + # Create a mesh with all available devices for unsharded restoration + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + # Build restore args that restore arrays without original sharding + restore_args = jax.tree_util.tree_map( + lambda x: ocp.ArrayRestoreArgs(sharding=unsharded) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + state = ckptr.restore(checkpoint_dir, restore_args=restore_args) + except Exception as e: # pylint: disable=broad-exception-caught + if metadata_only: + log(f" Metadata loading failed: {e}") + raise + # Fallback to simple restore without sharding args + log(f" Falling back to simple restore: {e}") + checkpointer = ocp.PyTreeCheckpointer() + state = checkpointer.restore(checkpoint_path) + + if state is None: + raise ValueError(f"Failed to restore checkpoint from {checkpoint_path}") + + log(f" Loaded keys: {list(state.keys())}") + return state + + +def transform_nnx_params_for_comparison(nnx_params: Dict[str, Any]) -> Dict[str, Any]: + """Transform NNX params to match Linen structure for comparison. + + NNX stacks layer parameters along axis 0 (shape: [num_layers, ...]), + while Linen stores per-layer parameters (shape: [...]). + + This function transposes layer params from (layers, d1, d2, ...) to (d1, layers, d2, ...) + to align with how Linen params would look if stacked. + """ + + def _transform(path, leaf: jax.Array) -> jax.Array: + key_str = keystr(path) + + # Only transform arrays in 'layers' with ndim >= 2 + if "layers" in key_str and hasattr(leaf, "ndim") and leaf.ndim >= 2: + # Transpose from (layers, d1, d2, ...) to (d1, layers, d2, ...) + axes = (1, 0) + tuple(range(2, leaf.ndim)) + result = jnp.transpose(leaf, axes=axes) + if FLAGS.verbose: + log(f" TRANSPOSING: {key_str} shape {leaf.shape} -> {result.shape}") + return result + else: + return leaf + + log("Transforming NNX params (transposing layer dimensions)...") + return tree_map_with_path(_transform, nnx_params) + + +def get_tree_structure_info(tree: Dict[str, Any]) -> Dict[str, tuple]: + """Get structure info as dict of path -> (shape, dtype).""" + flat_with_path, _ = tree_flatten_with_path(tree) + return { + keystr(p): ( + getattr(leaf, "shape", "N/A"), + str(getattr(leaf, "dtype", type(leaf).__name__)), + ) + for p, leaf in flat_with_path + } + + +def print_structure_diff(params1: Dict, params2: Dict, name1: str = "Linen", name2: str = "NNX"): + """Print structural differences between two param trees.""" + info1 = get_tree_structure_info(params1) + info2 = get_tree_structure_info(params2) + keys1, keys2 = set(info1.keys()), set(info2.keys()) + + only_in_1 = sorted(keys1 - keys2) + only_in_2 = sorted(keys2 - keys1) + common = keys1 & keys2 + + if only_in_1: + print(f"\n--- Paths only in {name1} ({len(only_in_1)}) ---") + for k in only_in_1: + shape, dtype = info1[k] + print(f" - {k}: shape={shape}, dtype={dtype}") + + if only_in_2: + print(f"\n--- Paths only in {name2} ({len(only_in_2)}) ---") + for k in only_in_2: + shape, dtype = info2[k] + print(f" + {k}: shape={shape}, dtype={dtype}") + + # Check for shape/dtype mismatches in common paths + shape_mismatches = [] + dtype_mismatches = [] + for k in common: + shape1, dtype1 = info1[k] + shape2, dtype2 = info2[k] + if shape1 != shape2: + shape_mismatches.append((k, shape1, shape2)) + if dtype1 != dtype2: + dtype_mismatches.append((k, dtype1, dtype2)) + + if shape_mismatches: + print(f"\n--- Shape mismatches ({len(shape_mismatches)}) ---") + for k, s1, s2 in shape_mismatches: + print(f" {k}: {name1}={s1}, {name2}={s2}") + + if dtype_mismatches: + print(f"\n--- Dtype mismatches ({len(dtype_mismatches)}) ---") + for k, d1, d2 in dtype_mismatches: + print(f" {k}: {name1}={d1}, {name2}={d2}") + + return only_in_1, only_in_2, shape_mismatches, dtype_mismatches + + +def compare_params( + params1: Dict[str, Any], + params2: Dict[str, Any], + verbose: bool = False, + compare_values: bool = False, + atol: float = 1e-5, + rtol: float = 1e-5, + name1: str = "Ckpt1", + name2: str = "Ckpt2", +) -> bool: + """Compare two parameter trees for structure, shape, and optionally values. + + Returns True if tree structures, shapes, and (optionally) values match. + """ + # First check tree structure + if tree_structure(params1) != tree_structure(params2): + print("\n[✗] Tree structures differ.") + print_structure_diff(params1, params2, name1=name1, name2=name2) + return False + + print("\n[✓] Tree structures are the same.") + + all_match = True + num_params = 0 + shape_mismatches = [] + dtype_mismatches = [] + value_mismatches = [] + value_matches = 0 + + def _compare_leaf(path, x, y): + nonlocal all_match, num_params, shape_mismatches, dtype_mismatches, value_mismatches, value_matches + key_str = keystr(path) + num_params += 1 + + shape1 = getattr(x, "shape", "N/A") + shape2 = getattr(y, "shape", "N/A") + dtype1 = getattr(x, "dtype", type(x).__name__) + dtype2 = getattr(y, "dtype", type(y).__name__) + + # Check shape + shape_match = shape1 == shape2 + if not shape_match: + shape_mismatches.append((key_str, shape1, shape2)) + all_match = False + + # Check dtype + dtype_match = str(dtype1) == str(dtype2) + if not dtype_match: + dtype_mismatches.append((key_str, dtype1, dtype2)) + all_match = False + + # Check values if requested and shapes match + if compare_values and shape_match and hasattr(x, "shape") and hasattr(y, "shape"): + try: + x_arr = np.asarray(x) + y_arr = np.asarray(y) + is_close = bool(np.allclose(x_arr, y_arr, atol=atol, rtol=rtol)) + + if is_close: + value_matches += 1 + if verbose: + print(f" [✓] {key_str} | Shape: {shape1} | Values match") + else: + diff = np.abs(x_arr - y_arr) + mean_diff = float(np.mean(diff)) + max_diff = float(np.max(diff)) + value_mismatches.append((key_str, mean_diff, max_diff)) + all_match = False + if verbose: + print(f" [✗] {key_str} | Shape: {shape1} | Mean diff: {mean_diff:.2e}, Max diff: {max_diff:.2e}") + except Exception as e: # pylint: disable=broad-exception-caught + value_mismatches.append((key_str, f"Error: {e}", "")) + all_match = False + elif verbose and not compare_values: + print(f" {key_str} | Shape: {shape1} | Dtype: {dtype1}") + + tree_map_with_path(_compare_leaf, params1, params2) + + # Print summary + print("\n--- Summary ---") + print(f"Total parameters: {num_params}") + + if shape_mismatches: + print(f"\n[✗] Shape mismatches ({len(shape_mismatches)}):") + for key_str, s1, s2 in shape_mismatches: + print(f" {key_str}: {name1}={s1}, {name2}={s2}") + else: + print("[✓] All shapes match.") + + if dtype_mismatches: + print(f"\n[✗] Dtype mismatches ({len(dtype_mismatches)}):") + for key_str, d1, d2 in dtype_mismatches: + print(f" {key_str}: {name1}={d1}, {name2}={d2}") + else: + print("[✓] All dtypes match.") + + if compare_values: + if value_mismatches: + print(f"\n[✗] Value mismatches ({len(value_mismatches)}):") + for item in value_mismatches[:20]: # Show first 20 + if len(item) == 3: + key_str, mean_diff, max_diff = item + if isinstance(mean_diff, float): + print(f" {key_str}: mean_diff={mean_diff:.2e}, max_diff={max_diff:.2e}") + else: + print(f" {key_str}: {mean_diff}") + if len(value_mismatches) > 20: + print(f" ... and {len(value_mismatches) - 20} more (use --verbose to see all)") + else: + print(f"[✓] All values match (atol={atol}, rtol={rtol}).") + print(f" Values matching: {value_matches}/{num_params}") + + return all_match + + +def _extract_params(state: dict, fmt: str) -> dict: + """Extract params from a checkpoint state based on its detected format.""" + if fmt == "linen": + return state.get("params", {}) + else: + # NNX format: params are in 'model' key + return state.get("model", state.get("params", {})) + + +def _normalize_params(params: dict, fmt: str) -> dict: + """Normalize params based on detected format.""" + if fmt == "linen": + return _normalize_linen_params(params) + else: + return _normalize_nnx_params(params) + + +def main(argv: Sequence[str]): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + ckpt_path_1 = FLAGS.ckpt_path_1 + ckpt_path_2 = FLAGS.ckpt_path_2 + if not ckpt_path_1 or not ckpt_path_2: + raise app.UsageError("--ckpt_path_1 and --ckpt_path_2 are required.") + + print("=" * 80) + print("Checkpoint Comparator") + print("=" * 80) + + print(f"\nCheckpoint 1: {ckpt_path_1}") + print(f"Checkpoint 2: {ckpt_path_2}") + print(f"Transpose NNX layers: {FLAGS.transpose_nnx_layers}") + print(f"Ignore RNGs: {FLAGS.ignore_rngs}") + print(f"Compare values: {FLAGS.compare_values}") + if FLAGS.compare_values: + print(f" Tolerance: atol={FLAGS.atol}, rtol={FLAGS.rtol}") + + # Load checkpoints — use metadata-only when not comparing values to avoid + # downloading tensor data (which can be 100+ GiB and cause XPK timeouts). + metadata_only = not FLAGS.compare_values + print("\n" + "-" * 40) + state_1 = load_checkpoint(ckpt_path_1, metadata_only=metadata_only) + state_2 = load_checkpoint(ckpt_path_2, metadata_only=metadata_only) + + # Detect formats + format_1 = detect_format(state_1) + format_2 = detect_format(state_2) + log(f"Detected checkpoint 1 format: {format_1}") + log(f"Detected checkpoint 2 format: {format_2}") + + is_cross_format = format_1 != format_2 + name_1 = f"Ckpt1({format_1})" + name_2 = f"Ckpt2({format_2})" + + # Extract and normalize params + print("\n" + "-" * 40) + log("Normalizing parameters...") + + if FLAGS.compare_only == "params": + params_1 = _extract_params(state_1, format_1) + params_2 = _extract_params(state_2, format_2) + else: + params_1 = state_1 + params_2 = state_2 + + params_1 = _normalize_params(params_1, format_1) + log(f" Checkpoint 1 ({format_1}): normalized") + params_2 = _normalize_params(params_2, format_2) + log(f" Checkpoint 2 ({format_2}): normalized") + + # Filter out RNG paths if requested + if FLAGS.ignore_rngs: + print("\n" + "-" * 40) + log("Filtering out RNG-related paths...") + params_1 = filter_rngs(params_1) + params_2 = filter_rngs(params_2) + + # Transform NNX params for cross-format comparison (transpose layer dimensions) + # Only apply when comparing Linen vs NNX, not for same-format comparisons + if FLAGS.transpose_nnx_layers and is_cross_format: + print("\n" + "-" * 40) + if format_1 == "nnx": + params_1 = transform_nnx_params_for_comparison(params_1) + if format_2 == "nnx": + params_2 = transform_nnx_params_for_comparison(params_2) + + # Compare + print("\n" + "-" * 40) + log("Comparing parameters...") + + success = compare_params( + params_1, + params_2, + verbose=FLAGS.verbose, + compare_values=FLAGS.compare_values, + atol=FLAGS.atol, + rtol=FLAGS.rtol, + name1=name_1, + name2=name_2, + ) + + # Final verdict + print("\n" + "=" * 80) + if success: + print("CHECKPOINTS MATCH") + if FLAGS.compare_values: + print(" Tree structure, shapes, and values are identical!") + else: + print(" Tree structure and all shapes are identical!") + else: + print("CHECKPOINTS DIFFER") + print(" See details above for mismatches.") + print("=" * 80) + + return 0 if success else 1 + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxtext/checkpoint_conversion/linen_nnx_converter.py b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py new file mode 100644 index 0000000000..015d3b5a56 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py @@ -0,0 +1,581 @@ +# Copyright 2023-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bidirectional conversion between Linen and NNX checkpoint formats. + +Top-level key mapping: + Linen → NNX: + params/params/ → model/ (remove double-nesting, rename, add {value:} wrappers) + opt_state → optimizer/opt_state (remove 'params' level from mu/nu) + step → optimizer/step (move inside optimizer) + + NNX → Linen: + model/ → params/params/ (strip {value:} wrappers, add double-nesting) + optimizer/opt_state → opt_state (add 'params' level to mu/nu) + optimizer/step → step (move to top level) + +Layer structure (--scan_layers): + linen_to_nnx: + scan_layers=True (default): stack layers_N arrays → 'layers' tensor with layer dim at axis 1 + scan_layers=False: rename layers_N → integer-keyed 'layers/{N}' + + nnx_to_linen (auto-detected): + Stacked 'layers' tensor → unstack along axis 1 → layers_N per-layer arrays + Integer-keyed layers/{N} → rename to layers_N + +Usage: + python linen_nnx_converter.py \\ + --source_path="gs://bucket/checkpoint/0/items" \\ + --target_path="gs://bucket/converted/" \\ + --direction=auto +""" + +import argparse +import os +import re +import time +from typing import Any + +# MUST set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import numpy as np +from etils import epath +import orbax.checkpoint as ocp + + +def log(message: str) -> None: + print(f"[linen_nnx_converter] {message}") + + +# ── Format detection ─────────────────────────────────────────────────────────── + + +def detect_format(state: dict) -> str: + """Detects checkpoint format ('linen' or 'nnx') from top-level keys.""" + # NNX: uses 'model' as the top-level params key + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Cannot detect checkpoint format: no 'model' or 'params' key. " f"Found: {list(state.keys())}") + + params = state["params"] + + # Linen: double-nested params/params/decoder + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Old NNX format: params/decoder (single-nested with value wrappers) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + if _has_value_wrappers(params): + return "nnx" + + if "optimizer" in state: + return "nnx" + if "opt_state" in state: + return "linen" + + raise ValueError( + f"Could not detect checkpoint format. Keys: {list(state.keys())}, " + f"params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +# ── Value wrapper helpers ────────────────────────────────────────────────────── + + +def _has_value_wrappers(tree: Any) -> bool: + """Returns True if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {value: array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _add_value_wrappers(tree: Any) -> Any: + """Recursively wraps leaf arrays in {value: array} (NNX nnx.Param format).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return tree # Already wrapped + return {k: _add_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_add_value_wrappers(item) for item in tree) + elif hasattr(tree, "shape") or isinstance(tree, np.ndarray): + return {"value": tree} + else: + return tree + + +# ── Layer structure helpers ──────────────────────────────────────────────────── + + +def _stack_layers(decoder: dict) -> tuple[dict, bool]: + """Stacks per-layer parameters (layers_N) into a single 'layers' dict at axis 0. + + Returns (result_dict, was_stacked). + """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder, False + + sorted_indices = sorted(layer_indices.keys()) + num_layers = len(sorted_indices) + log(f" Found {num_layers} individual layers, stacking into 'layers'") + + def stack_arrays(layers_data: list) -> Any: + first = layers_data[0] + if hasattr(first, "shape") or isinstance(first, np.ndarray): + return np.stack([np.asarray(layers_data[i]) for i in range(len(layers_data))], axis=0) + elif isinstance(first, dict): + result = {} + for key in first.keys(): + child_data = [layers_data[i].get(key) for i in range(len(layers_data))] + if all(c is not None for c in child_data): + result[key] = stack_arrays(child_data) + return result + else: + return first + + layers_data = [layer_indices[i] for i in sorted_indices] + stacked = stack_arrays(layers_data) + + result = dict(other_keys) + result["layers"] = stacked + return result, True + + +def _rename_layers_to_integer_keys(decoder: dict) -> dict: + """Converts layers_N keys to integer-keyed dict under 'layers' (no stacking). + + Converts {layers_0: {...}, layers_1: {...}} → {layers: {'0': {...}, '1': {...}}}. + Used for scan_layers=False linen→nnx conversion (Pattern C). + """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder + + sorted_indices = sorted(layer_indices.keys()) + log(f" Found {len(sorted_indices)} individual layers, renaming to integer-keyed 'layers/N'") + result = dict(other_keys) + result["layers"] = {str(i): layer_indices[i] for i in sorted_indices} + return result + + +def _transpose_layers_axes(tree: Any, src_axis: int, dst_axis: int) -> Any: + """Transposes the layers dimension in arrays within a tree (src_axis ↔ dst_axis).""" + if src_axis == dst_axis: + return tree + if isinstance(tree, dict): + return {k: _transpose_layers_axes(v, src_axis, dst_axis) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_transpose_layers_axes(item, src_axis, dst_axis) for item in tree) + elif hasattr(tree, "shape") and len(tree.shape) >= 2: + axes = list(range(len(tree.shape))) + axes[src_axis], axes[dst_axis] = axes[dst_axis], axes[src_axis] + result = np.transpose(np.asarray(tree), axes=axes) + log(f" Transposed: {tree.shape} → {result.shape}") + return result + else: + return tree + + +def _detect_num_layers(tree: Any, scan_axis: int) -> int | None: + """Detects num_layers from the first array with ndim > scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, np.ndarray): + shape = getattr(tree, "shape", None) or np.asarray(tree).shape + if len(shape) > scan_axis: + return shape[scan_axis] + return None + if isinstance(tree, dict): + for v in tree.values(): + result = _detect_num_layers(v, scan_axis) + if result is not None: + return result + return None + + +def _unstack_single_layer(tree: Any, idx: int, scan_axis: int) -> Any: + """Extracts a single layer by indexing at scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, np.ndarray): + arr = np.asarray(tree) + if arr.ndim > scan_axis: + return np.take(arr, idx, axis=scan_axis) + return arr + if isinstance(tree, dict): + return {k: _unstack_single_layer(v, idx, scan_axis) for k, v in tree.items()} + if isinstance(tree, (list, tuple)): + return type(tree)(_unstack_single_layer(v, idx, scan_axis) for v in tree) + return tree + + +def _convert_layers_to_linen_format(decoder: dict) -> dict: + """Converts NNX 'layers' back to Linen's layers_N format (auto-detects NNX style). + + Handles: + - Stacked tensor (Pattern B): layers/ + → layers_0, layers_1, ... (unstack along axis 1) + - Integer-keyed (Pattern C): layers/0, layers/1, ... + → layers_0, layers_1, ... (rename) + """ + if "layers" not in decoder: + return decoder + + layers_val = decoder["layers"] + other_keys = {k: v for k, v in decoder.items() if k != "layers"} + + if not isinstance(layers_val, dict): + # Already a non-dict (shouldn't happen normally), keep as-is + return decoder + + # Pattern C: integer-keyed per-layer dict → rename + if all(k.isdigit() for k in layers_val.keys()): + result = dict(other_keys) + for idx_str, layer_data in sorted(layers_val.items(), key=lambda x: int(x[0])): + result[f"layers_{idx_str}"] = layer_data + log(f" Renamed integer-keyed layers/N → layers_N ({len(layers_val)} layers)") + return result + + # Pattern B: stacked tensor (layer dim at axis 1) → unstack + num_layers = _detect_num_layers(layers_val, scan_axis=1) + if num_layers is None: + log(" WARNING: Could not detect num_layers for unstacking, keeping 'layers' as-is") + result = dict(other_keys) + result["layers"] = layers_val + return result + + result = dict(other_keys) + for i in range(num_layers): + result[f"layers_{i}"] = _unstack_single_layer(layers_val, idx=i, scan_axis=1) + log(f" Unstacked scanned 'layers' → layers_N ({num_layers} layers at axis 1)") + return result + + +# ── Optimizer state helpers ──────────────────────────────────────────────────── + + +def _convert_opt_state_linen_to_nnx(opt_state: Any) -> Any: + """Removes 'params' nesting from mu/nu in linen opt_state. + + NNX optimizer state has plain arrays (no {value:} wrappers). + Linen opt_state mirrors the params structure (params/decoder/...), + so we remove the 'params' level to get decoder/... directly. + """ + if isinstance(opt_state, dict): + result = {} + for k, v in opt_state.items(): + if k == "params": + # Remove this level by merging its contents up + converted = _convert_opt_state_linen_to_nnx(v) + if isinstance(converted, dict): + result.update(converted) + else: + result[k] = converted + else: + result[k] = _convert_opt_state_linen_to_nnx(v) + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_linen_to_nnx(item) for item in opt_state) + else: + return opt_state # Plain array or scalar — no value wrapper for opt_state + + +def _convert_opt_state_nnx_to_linen(opt_state: Any, depth: int = 0) -> Any: + """Adds 'params' nesting to mu/nu, removes any stray {value:} wrappers. + + NNX optimizer mu/nu contains decoder/... directly. + Linen expects mu/params/decoder/... (one 'params' level mirroring the params structure). + """ + if isinstance(opt_state, dict): + # Strip any {value:} wrappers in opt_state (shouldn't be there but handle gracefully) + if set(opt_state.keys()) == {"value"}: + inner = opt_state["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + + result = {} + for k, v in opt_state.items(): + converted = _convert_opt_state_nnx_to_linen(v, depth + 1) + # Add one 'params' level after mu/nu (mirrors linen's params structure) + if k in ("mu", "nu") and isinstance(converted, dict): + result[k] = {"params": converted} + else: + result[k] = converted + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_nnx_to_linen(item, depth + 1) for item in opt_state) + else: + return opt_state + + +# ── Main conversion functions ────────────────────────────────────────────────── + + +def convert_linen_to_nnx(state: dict, scan_layers: bool = True) -> dict: + """Converts Linen checkpoint to NNX format. + + Args: + state: Linen checkpoint dict with keys ['params', 'opt_state', 'step']. + scan_layers: If True (default), stack per-layer arrays and insert layer + dim at axis 1 (for NNX with scan_layers=True). + If False, rename layers_N → integer-keyed layers/N + (for NNX with scan_layers=False). + """ + result = {} + + if "params" in state: + linen_params = state["params"] + # Remove double 'params' nesting: params/params/decoder → decoder + if isinstance(linen_params, dict) and "params" in linen_params: + nnx_params = linen_params["params"] + log(" params: Removed double 'params' nesting (params/params → model)") + else: + nnx_params = linen_params + log(" params: No double nesting found") + + stripped = _strip_value_wrappers(nnx_params) + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + if scan_layers: + stripped[component], was_stacked = _stack_layers(stripped[component]) + if was_stacked and "layers" in stripped[component]: + log(f" {component}/layers: Transposing stacked (layers, ...) → (..., layers, ...) at axis 1") + stripped[component]["layers"] = _transpose_layers_axes(stripped[component]["layers"], src_axis=0, dst_axis=1) + else: + stripped[component] = _rename_layers_to_integer_keys(stripped[component]) + + result["model"] = _add_value_wrappers(stripped) + log(" model: Saved with {value:} wrappers under 'model' key") + + # optimizer: move step inside, keep opt_state + optimizer_dict = {} + if "step" in state: + optimizer_dict["step"] = state["step"] + log(f" optimizer/step: Moved from top-level (step={state['step']})") + if "opt_state" in state: + optimizer_dict["opt_state"] = _convert_opt_state_linen_to_nnx(state["opt_state"]) + log(" optimizer/opt_state: Removed 'params' nesting from mu/nu") + if optimizer_dict: + result["optimizer"] = optimizer_dict + + return result + + +def convert_nnx_to_linen(state: dict) -> dict: + """Converts NNX checkpoint to Linen format. + + Reads from 'model'/'optimizer' keys (or falls back to old 'params'/'opt_state' format). + Layer structure is auto-detected (stacked vs integer-keyed). + """ + result = {} + + model_key = "model" if "model" in state else "params" + if model_key in state: + nnx_params = state[model_key] + stripped = _strip_value_wrappers(nnx_params) + log(f" {model_key}: Removed {{value:}} wrappers") + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + stripped[component] = _convert_layers_to_linen_format(stripped[component]) + + # Add double 'params' nesting: decoder → params/params/decoder + result["params"] = {"params": stripped} + log(" params: Added double 'params' nesting (model → params/params)") + + # optimizer: extract step and opt_state back to top level + if "optimizer" in state: + optimizer = state["optimizer"] + if "step" in optimizer: + result["step"] = optimizer["step"] + log(" step: Extracted from optimizer/step to top level") + if "opt_state" in optimizer: + result["opt_state"] = _convert_opt_state_nnx_to_linen(optimizer["opt_state"]) + log(" opt_state: Added 'params' nesting to mu/nu") + elif "opt_state" in state: + # Backward compat: old format with opt_state at top level + result["opt_state"] = _convert_opt_state_nnx_to_linen(state["opt_state"]) + log(" opt_state: Converted from top-level opt_state (old format)") + + if "step" in state and "step" not in result: + result["step"] = state["step"] + + return result + + +# ── Checkpoint I/O ───────────────────────────────────────────────────────────── + + +def load_checkpoint(checkpoint_path: str) -> dict: + """Loads checkpoint from local or GCS path.""" + log(f"Loading checkpoint from: {checkpoint_path}") + + checkpoint_dir = epath.Path(checkpoint_path) + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + metadata = ckptr.metadata(checkpoint_dir) + + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + restore_args = jax.tree_util.tree_map( + lambda x: ocp.ArrayRestoreArgs(sharding=unsharded) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + + state = ckptr.restore(checkpoint_dir, restore_args=restore_args) + log(f" Loaded keys: {list(state.keys())}") + return state + + +def save_checkpoint(state: dict, output_path: str) -> None: + """Saves checkpoint to local or GCS path.""" + log(f"Saving checkpoint to: {output_path}") + + output_dir = epath.Path(output_path) + output_dir.mkdir(exist_ok=True, parents=True) + + ckptr = ocp.PyTreeCheckpointer() + ckptr.save(output_dir, state, force=True) + log(" Checkpoint saved successfully") + + +# ── CLI ──────────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + description="Convert between Linen and NNX checkpoint formats.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--source_path", + type=str, + required=True, + help="Path to source checkpoint items directory (e.g. gs://bucket/ckpt/0/items).", + ) + parser.add_argument( + "--target_path", + type=str, + required=True, + help="Path to save converted checkpoint.", + ) + parser.add_argument( + "--direction", + type=str, + choices=["auto", "linen_to_nnx", "nnx_to_linen"], + default="auto", + help="Conversion direction. 'auto' detects from source format.", + ) + parser.add_argument( + "--scan_layers", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For linen_to_nnx only: if True (default), stack per-layer arrays into a " + "scanned 'layers' tensor with layer dim at axis 1 (for NNX with scan_layers=True). " + "If False, rename layers_N to integer-keyed layers/N without stacking " + "(for NNX with scan_layers=False)." + ), + ) + + args = parser.parse_args() + + print("=" * 80) + print("Linen <-> NNX Checkpoint Converter") + print("=" * 80) + + start_time = time.time() + + state = load_checkpoint(args.source_path) + + if args.direction == "auto": + source_format = detect_format(state) + target_format = "nnx" if source_format == "linen" else "linen" + log(f"Auto-detected: {source_format} → {target_format}") + else: + source_format = args.direction.split("_to_")[0] + target_format = args.direction.split("_to_")[1] + log(f"Using specified direction: {source_format} → {target_format}") + + log(f"Converting: {source_format} → {target_format}") + if source_format == "linen": + log(f"scan_layers={args.scan_layers}") + + if source_format == "linen" and target_format == "nnx": + converted_state = convert_linen_to_nnx(state, scan_layers=args.scan_layers) + elif source_format == "nnx" and target_format == "linen": + converted_state = convert_nnx_to_linen(state) + else: + raise ValueError(f"Invalid conversion: {source_format} → {target_format}") + + save_checkpoint(converted_state, args.target_path) + + elapsed = time.time() - start_time + print("\n" + "=" * 80) + print(f"Conversion complete in {elapsed:.2f} seconds") + print(f" Source: {args.source_path}") + print(f" Target: {args.target_path}") + print(f" Direction: {source_format} → {target_format}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 1b0d4b4cd3..c6ca234a47 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -509,7 +509,11 @@ def __call__( previous_chunk=previous_chunk, slot=slot, page_state=page_state, - multimodal_input=multimodal_input, + image_embeddings=multimodal_input.image_embeddings if multimodal_input is not None else None, + image_masks=multimodal_input.image_masks if multimodal_input is not None else None, + audio_embeddings=multimodal_input.audio_embeddings if multimodal_input is not None else None, + audio_masks=multimodal_input.audio_masks if multimodal_input is not None else None, + bidirectional_mask=multimodal_input.bidirectional_mask if multimodal_input is not None else None, kv_caches=kv_caches, attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 2ae7e5f8e5..9992d7674f 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu): else: updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params) - step_size = -1.0 * learning_rate_fn(count) + # learning_rate_fn may be a callable schedule or a scalar (e.g. when wrapped + # by optax.inject_hyperparams, it is passed as a pre-evaluated scalar). + step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn) # Finally, fold in step size. updates = jax.tree_util.tree_map(lambda x: step_size * x, updates) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 1a66a532fb..40b866d415 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -273,30 +273,45 @@ def wrt_filter(path, x): # Inherits _shard_optimizer from PeftTrainer. def _train_step(self, model, optimizer, inputs): - """Overrides the main JIT block to natively handle ModelBundle module.""" + """Overrides the main JIT block to natively handle ModelBundle module. + Uses jax.value_and_grad with explicit split/merge to avoid nesting + nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign + conflicting outer_index values and raises: + ValueError: The graph structure of a node added to cached_partial was + mutated inside the transformation. + """ batch = self.gen_model_input_fn(inputs) + student = model.student_model + teacher = model.teacher_model current_step = model.training_step[...] - def loss_wrapper(student, teacher, batch): - if "teacher_output" in batch: - teacher_output = batch["teacher_output"] - else: - teacher_output = self.strategy.teacher_forward_fn( - model=teacher, - input_tokens=batch["input_tokens"], - positions=batch["positions"], - attention_mask=batch.get("attention_mask"), - decoder_segment_ids=batch.get("decoder_segment_ids"), - decoder_target_tokens=batch.get("targets", None), - decoder_target_mask=batch.get("targets_segmentation", None), - cache=None, - ) + # Run teacher inference outside of value_and_grad. + # The teacher is frozen (stop_gradient), so its output is a constant + # from the perspective of the student gradient computation. + if "teacher_output" in batch: + teacher_output = batch["teacher_output"] + else: + teacher_output = self.strategy.teacher_forward_fn( + model=teacher, + input_tokens=batch["input_tokens"], + positions=batch["positions"], + attention_mask=batch.get("attention_mask"), + decoder_segment_ids=batch.get("decoder_segment_ids"), + decoder_target_tokens=batch.get("targets", None), + decoder_target_mask=batch.get("targets_segmentation", None), + cache=None, + ) + teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output) - teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output) + # Split student into differentiable params and non-differentiable rest. + # Capture graphdef outside of jax.value_and_grad for stable graph tracking. + student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...) + def loss_wrapper_pure(diff_params, rest): + local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True) student_output = self.strategy.student_forward_fn( - model=student, + model=local_student, input_tokens=batch["input_tokens"], positions=batch["positions"], attention_mask=batch.get("attention_mask"), @@ -305,29 +320,26 @@ def loss_wrapper(student, teacher, batch): decoder_target_mask=batch.get("targets_segmentation", None), cache=None, ) - # we should apply a mask for labels to disable segment-separator tokens labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None)) - return self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step) - - # Because student is the 0th argument, argnums=0 guarantees - # we only compute gradients for the student. - grad_fn = nnx.value_and_grad( - loss_wrapper, - argnums=nnx.DiffState(0, self.wrt_filter), - has_aux=True, - ) + loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step) + # Capture updated non-param state (e.g. RNG counters) from local_student. + _, _, new_rest = nnx.split(local_student, self.wrt_filter, ...) + return loss, (aux, new_rest) - out, grads = grad_fn(model.student_model, model.teacher_model, batch) + grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True) + (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) - model.training_step.set_value(current_step + 1) + # Propagate updated non-param state back to student. + nnx.update(student, new_rest) - tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) + optimizer.update(student, grads) - optimizer.update(model.student_model, grads) + model.training_step.set_value(current_step + 1) + tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) if tunix_expects_grad_norm: - return out[0], out[1], optax.global_norm(grads) - return out[0], out[1] + return loss, aux, optax.global_norm(grads) + return loss, aux def _eval_step(self, model, inputs): """Evaluation only needs the student.""" diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 0af37dc10f..5e15697127 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -55,6 +55,42 @@ import os import pathwaysutils +# JAX 0.9+ changed with_sharding_constraint to assert (not reshard) when all +# mesh axes are Explicit. tpu_inference still expects resharding semantics. +# Patch: try the original (works for Auto axes); on AssertionError (Explicit +# mesh) fall back to jax.sharding.reshard. +_orig_wsc = jax.lax.with_sharding_constraint + + +def _compat_wsc(x, shardings): + try: + return _orig_wsc(x, shardings) + except AssertionError: + return jax.sharding.reshard(x, shardings) + + +jax.lax.with_sharding_constraint = _compat_wsc + +# tpu_inference JaxEinsum defaults param_dtype=float32, so tpu_inference model weights +# initialize as float32. During weight sync, tunix._apply_dtype_cast then upcasts the +# incoming bfloat16 MaxText weights → float32 to match the target. This leaves v_proj +# as float32 while k_proj output appears bfloat16 (due to k_norm dtype promotion), +# causing a dtype mismatch in the ragged paged attention kernel. +# Fix: skip bfloat16→float32 upcasts during weight sync so synced weights stay bfloat16. +import jax.numpy as _jnp +import tunix.generate.utils as _tunix_utils + +_orig_apply_dtype_cast = _tunix_utils._apply_dtype_cast # pylint: disable=protected-access + + +def _no_bf16_to_f32_cast(val, tgt_dtype, src_key): + if hasattr(val, "dtype") and val.dtype == _jnp.bfloat16 and tgt_dtype == _jnp.float32: + return val # keep bfloat16; tpu_inference model dtype is bfloat16 despite float32 init + return _orig_apply_dtype_cast(val, tgt_dtype, src_key) + + +_tunix_utils._apply_dtype_cast = _no_bf16_to_f32_cast # pylint: disable=protected-access + from absl import app from absl import logging as absl_logging from etils import epath @@ -418,6 +454,8 @@ def create_rl_components( "hf_overrides": trainer_config.vllm_hf_overrides, "enable_expert_parallel": sampler_config.enable_expert_parallel, "enable_prefix_caching": True, # Enable prefix caching to speed up generation for long prompts + # Ensures vLLM model initializes with correct dtype (not float32 default) + "dtype": trainer_config.weight_dtype, }, rollout_vllm_sampling_kwargs={ "stop": trainer_config.stop_strings, @@ -563,7 +601,10 @@ def rl_train(argv: Sequence[str], kwargs: dict): max_train_steps = get_max_train_steps(trainer_config) # Create model tokenizer - model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path) + model_tokenizer = AutoTokenizer.from_pretrained( + trainer_config.tokenizer_path, + token=trainer_config.hf_access_token or None, + ) train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index c7c726cec9..a6c80d27dc 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -35,7 +35,7 @@ eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16 """ -from typing import Sequence +from typing import Any, Sequence from absl import app import os @@ -43,6 +43,7 @@ import optax import pathwaysutils +from flax import nnx from flax.linen import partitioning as nn_partitioning from orbax import checkpoint as ocp @@ -68,6 +69,70 @@ from maxtext.utils import model_creation_utils +class MaxTextPeftTrainer(peft_trainer.PeftTrainer): + """MaxText-specific PeftTrainer that avoids nested NNX transformations. + + Tunix's default PeftTrainer._train_step creates nnx.value_and_grad inside + nnx.jit. This nesting causes Flax NNX to assign conflicting outer_index + values to graph nodes, resulting in: + ValueError: The graph structure of a node added to cached_partial was + mutated inside the transformation. + + This subclass overrides create_train_step_fn to use jax.value_and_grad + with an explicit split/merge pattern (matching MaxText's pre-training NNX + train_step), which avoids the nested NNX transformation issue entirely. + """ + + def create_train_step_fn(self): + """Creates a train step using jax.value_and_grad with explicit NNX split/merge.""" + loss_fn_ref = self.loss_fn + has_aux = self._has_aux + gen_fn = self.gen_model_input_fn + is_lora_enabled = self._lora_enabled + wrt = nnx.LoRAParam if is_lora_enabled else nnx.Param + tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) + + # Capture the graphdef once outside of JIT so that split/merge inside + # jax.value_and_grad can use a stable (non-traced) structural descriptor. + graphdef, _, _ = nnx.split(self.model, wrt, ...) + + def train_step(model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any): + inputs = gen_fn(inputs) + + # Split model into differentiable params and non-differentiable rest. + # Using jax.value_and_grad (not nnx.value_and_grad) avoids nesting NNX + # transforms inside nnx.jit, which would corrupt outer_index tracking. + _, diff_params, rest = nnx.split(model, wrt, ...) + + def loss_wrapper(diff_params, rest, **inputs_kw): + local_model = nnx.merge(graphdef, diff_params, rest, copy=True) + out = loss_fn_ref(local_model, **inputs_kw) + # Capture updated non-param state (e.g. RNG counters) from local_model. + _, _, new_rest = nnx.split(local_model, wrt, ...) + if has_aux: + loss, aux = out + return loss, (aux, new_rest) + else: + return out, (None, new_rest) + + grad_fn = jax.value_and_grad(loss_wrapper, argnums=0, has_aux=True) + (out_val, (aux, new_rest)), grads = grad_fn(diff_params, rest, **inputs) + + # Propagate updated non-param state (RNG counters, etc.) back to model. + nnx.update(model, new_rest) + + # Apply optimizer update. grads has the same nnx.State(wrt) structure + # as diff_params, which is compatible with optimizer.update. + optimizer.update(model, grads) + + aux_out = aux if has_aux else None + if tunix_expects_grad_norm: + return out_val, aux_out, optax.global_norm(grads) + return out_val, aux_out + + return train_step + + def get_tunix_config(mt_config): """Gets the Tunix training configurations from the MaxText config. @@ -109,6 +174,7 @@ def get_tunix_config(mt_config): checkpointing_options=checkpointing_options, metrics_logging_options=metrics_logging_options, profiler_options=profiler_options, + data_sharding_axis=tuple(mt_config.data_sharding), ) @@ -162,7 +228,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None): data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) # Provide rules context so 'norm' is translated to mesh axes during maybe_restore with nn_partitioning.axis_rules(mt_config.logical_axis_rules): - trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config) + trainer = MaxTextPeftTrainer(model, optimizer, tunix_config) trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) trainer = use_maxtext_loss_function(trainer, mt_config) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 37a0710cbb..1638dc8869 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1902,26 +1902,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No """ Print state shardings comparing Logical Definition vs Physical Result. """ - if not hasattr(params, "params"): - params = {"params": params} - if not hasattr(params_sharding, "params"): - params_sharding = {"params": params_sharding} - if logical_annotations and not hasattr(logical_annotations, "params"): - logical_annotations = {"params": logical_annotations} + if not isinstance(params, nnx.State): + if not hasattr(params, "params"): + params = {"params": params} + if not hasattr(params_sharding, "params"): + params_sharding = {"params": params_sharding} + if logical_annotations and not hasattr(logical_annotations, "params"): + logical_annotations = {"params": logical_annotations} leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) - leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) - for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical): - path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) - shape = jax.typeof(leaf_val) - pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) - pspec_str = str(tuple(pspec)) - logical_str = str(leaf_logical_val) - - message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" - max_logging.info(message) + if logical_annotations is not None: + leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) + for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip( + leaves_params, leaves_sharding, leaves_logical + ): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + logical_str = str(leaf_logical_val) + + message = ( + f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" + ) + max_logging.info(message) + else: + for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + + message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}" + max_logging.info(message) print(flush=True) diff --git a/tests/post_training/unit/distillation_scheduling_test.py b/tests/post_training/unit/distillation_scheduling_test.py index 21e22839b4..24b9b6d721 100644 --- a/tests/post_training/unit/distillation_scheduling_test.py +++ b/tests/post_training/unit/distillation_scheduling_test.py @@ -412,9 +412,15 @@ def __call__(self, x): self.assertEqual(int(bundle.training_step[...]), 2) @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_increments_and_passes_step(self, mock_value_and_grad, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_train_step_increments_and_passes_step( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_global_norm + ): """_train_step passes pre-increment step to compute_loss and increments after.""" + del mock_merge, mock_update # pylint: disable=no-value-for-parameter trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() @@ -442,37 +448,54 @@ def test_train_step_increments_and_passes_step(self, mock_value_and_grad, mock_g # Simulate resume from step 5 model_bundle.training_step.set_value(jnp.array(5, dtype=jnp.int32)) - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + # nnx.split returns (graphdef, diff_params, rest); loss_wrapper_pure takes (diff_params, rest). + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # grad_fn returns ((loss, (aux, new_rest)), grads) + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) trainer._train_step(model_bundle, optimizer, mock.Mock()) # Step should have incremented to 6 self.assertEqual(int(model_bundle.training_step[...]), 6) - # Trigger loss_wrapper to verify step=5 was passed to compute_loss + # Trigger loss_wrapper_pure to verify step=5 was passed to compute_loss. + # Signature is (diff_params, rest). loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + loss_wrapper(mock_diff_params, mock_rest) call_kwargs = trainer.strategy.compute_loss.call_args self.assertIn("step", call_kwargs.kwargs) self.assertEqual(int(call_kwargs.kwargs["step"]), 5) @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_consecutive_train_steps_increment(self, mock_value_and_grad, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_consecutive_train_steps_increment( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_global_norm + ): """training_step increments 0→1→2→3 across consecutive _train_step calls.""" + del mock_merge, mock_update # pylint: disable=no-value-for-parameter trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() trainer.wrt_filter = lambda path, x: True # type: ignore + # Use a real DistillationForwardOutput so jax.tree.map(stop_gradient, ...) works. + fake_teacher_output = distillation_utils.DistillationForwardOutput( + logits=jnp.zeros((1, 2, 4)), out_projection_activations=None + ) mock_batch = { "input_tokens": mock.Mock(), "positions": mock.Mock(), "targets": mock.Mock(), - "teacher_output": mock.Mock(), + "teacher_output": fake_teacher_output, } trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch) @@ -480,7 +503,10 @@ def test_consecutive_train_steps_increment(self, mock_value_and_grad, mock_globa model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer = mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index 80b7cbfce7..ca57e13f7a 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -162,9 +162,12 @@ def test_prepare_inputs_logic(self): @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") def test_train_step_skips_teacher_forward_when_output_present( - self, mock_value_and_grad, mock_tree_map, mock_global_norm + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm ): """Verifies teacher forward is skipped when model_output is already in the batch.""" # 1. Initialize Trainer @@ -189,21 +192,28 @@ def test_train_step_skips_teacher_forward_when_output_present( model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) + mock_grad_fn = mock.Mock(return_value=((mock_loss, (mock_aux, mock.Mock())), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. Assertions trainer.strategy.teacher_forward_fn.assert_not_called() trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -215,9 +225,12 @@ def test_train_step_skips_teacher_forward_when_output_present( @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") def test_train_step_calls_teacher_forward_when_output_missing( - self, mock_value_and_grad, mock_tree_map, mock_global_norm + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm ): """Verifies teacher forward is called when model_output is missing from the batch.""" # 1. Initialize Trainer @@ -242,19 +255,27 @@ def test_train_step_calls_teacher_forward_when_output_missing( model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) + mock_grad_fn = mock.Mock(return_value=((mock_loss, (mock_aux, mock.Mock())), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn mock_gn = mock.Mock() mock_global_norm.return_value = mock_gn + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure train_step_out = trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. Assertions + # Teacher forward is called OUTSIDE value_and_grad in _train_step trainer.strategy.teacher_forward_fn.assert_called_once_with( model=teacher_model, input_tokens=mock_batch["input_tokens"], @@ -266,8 +287,9 @@ def test_train_step_calls_teacher_forward_when_output_missing( decoder_target_mask=None, ) + # Student forward is called INSIDE loss_wrapper_pure via nnx.merge'd local_student trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -291,8 +313,13 @@ def test_train_step_calls_teacher_forward_when_output_missing( @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_train_step_passes_targets_segmentation( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm + ): """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask.""" # 1. Initialize Trainer # pylint: disable=no-value-for-parameter @@ -317,22 +344,30 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. Assertions trainer.strategy.create_labels.assert_called_once_with( mock_batch["targets"], targets_segmentation=mock_targets_segmentation ) + # Student forward is called INSIDE loss_wrapper_pure via nnx.merge'd local_student trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -341,6 +376,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ decoder_target_mask=mock_targets_segmentation, cache=None, ) + # Teacher forward is called OUTSIDE value_and_grad in _train_step trainer.strategy.teacher_forward_fn.assert_called_once_with( model=teacher_model, input_tokens=mock_batch["input_tokens"], diff --git a/tests/unit/compare_linen_nnx_checkpoint_test.py b/tests/unit/compare_linen_nnx_checkpoint_test.py new file mode 100644 index 0000000000..d3d49e6a63 --- /dev/null +++ b/tests/unit/compare_linen_nnx_checkpoint_test.py @@ -0,0 +1,501 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for compare_linen_nnx_checkpoint utilities.""" + +import io +import unittest +from unittest.mock import patch +import numpy as np + +from absl import flags as absl_flags +from maxtext.checkpoint_conversion.compare_linen_nnx_checkpoint import ( + is_rng_path, + filter_rngs, + detect_format, + _has_value_wrappers, + _strip_value_wrappers, + _normalize_linen_params, + _normalize_nnx_params, + _extract_params, + _normalize_params, + get_tree_structure_info, + print_structure_diff, + compare_params, + transform_nnx_params_for_comparison, +) + + +def _arr(*shape): + """Helper: float32 array of given shape, values 0..prod(shape)-1.""" + return np.arange(int(np.prod(shape)), dtype=np.float32).reshape(shape) + + +def setUpModule(): + # Mark FLAGS as parsed so FLAGS.verbose etc. are accessible without a full + # app.run(). Required flags (ckpt_path_1/2) are not needed in unit tests. + absl_flags.FLAGS.mark_as_parsed() + + +# --------------------------------------------------------------------------- +# is_rng_path +# --------------------------------------------------------------------------- + + +class TestIsRngPath(unittest.TestCase): + """Tests for is_rng_path.""" + + def test_returns_true_for_rngs(self): + self.assertTrue(is_rng_path("model/decoder/rngs/dropout")) + + def test_returns_true_for_rng(self): + self.assertTrue(is_rng_path("model/rngs/params/key")) + + def test_returns_true_case_insensitive(self): + self.assertTrue(is_rng_path("model/RNGs/state")) + self.assertTrue(is_rng_path("model/RNG/state")) + + def test_returns_false_for_normal_path(self): + self.assertFalse(is_rng_path("model/decoder/layers/kernel")) + + def test_returns_false_for_empty_string(self): + self.assertFalse(is_rng_path("")) + + +# --------------------------------------------------------------------------- +# filter_rngs +# --------------------------------------------------------------------------- + + +class TestFilterRngs(unittest.TestCase): + """Tests for filter_rngs.""" + + def test_removes_top_level_rngs_key(self): + tree = {"model": {"kernel": _arr(4)}, "rngs": {"dropout": _arr(2)}} + result = filter_rngs(tree) + self.assertNotIn("rngs", result) + self.assertIn("model", result) + + def test_removes_nested_rngs_key(self): + tree = {"model": {"kernel": _arr(4), "rngs": {"key": _arr(2)}}} + result = filter_rngs(tree) + self.assertNotIn("rngs", result["model"]) + self.assertIn("kernel", result["model"]) + + def test_keeps_empty_parent_when_only_child_is_rng(self): + # After filtering, the parent dict becomes empty and is dropped. + tree = {"model": {"rngs": {"key": _arr(2)}}} + result = filter_rngs(tree) + self.assertNotIn("model", result) + + def test_passthrough_for_non_rng_tree(self): + tree = {"params": {"kernel": _arr(4), "bias": _arr(2)}} + result = filter_rngs(tree) + self.assertEqual(set(result.keys()), {"params"}) + + def test_passthrough_for_non_dict_input(self): + arr = _arr(4) + self.assertIs(filter_rngs(arr), arr) + + +# --------------------------------------------------------------------------- +# _has_value_wrappers +# --------------------------------------------------------------------------- + + +class TestHasValueWrappers(unittest.TestCase): + """Tests for _has_value_wrappers.""" + + def test_returns_true_for_direct_value_wrapper(self): + tree = {"value": _arr(3, 4)} + self.assertTrue(_has_value_wrappers(tree)) + + def test_returns_true_for_nested_wrapper(self): + tree = {"decoder": {"kernel": {"value": _arr(2, 2)}}} + self.assertTrue(_has_value_wrappers(tree)) + + def test_returns_false_for_plain_array(self): + self.assertFalse(_has_value_wrappers(_arr(3))) + + def test_returns_false_for_multi_key_dict(self): + tree = {"value": _arr(2), "extra": _arr(2)} + self.assertFalse(_has_value_wrappers(tree)) + + def test_returns_false_for_value_key_with_non_array(self): + tree = {"value": 42} + self.assertFalse(_has_value_wrappers(tree)) + + +# --------------------------------------------------------------------------- +# _strip_value_wrappers +# --------------------------------------------------------------------------- + + +class TestStripValueWrappers(unittest.TestCase): + """Tests for _strip_value_wrappers.""" + + def test_strips_direct_wrapper(self): + arr = _arr(3, 4) + result = _strip_value_wrappers({"value": arr}) + np.testing.assert_array_equal(result, arr) + + def test_strips_nested_wrappers(self): + arr = _arr(2, 2) + tree = {"decoder": {"kernel": {"value": arr}}} + result = _strip_value_wrappers(tree) + np.testing.assert_array_equal(result["decoder"]["kernel"], arr) + + def test_passthrough_plain_array(self): + arr = _arr(4) + self.assertIs(_strip_value_wrappers(arr), arr) + + def test_handles_list(self): + arr = _arr(2) + result = _strip_value_wrappers([{"value": arr}]) + np.testing.assert_array_equal(result[0], arr) + + def test_handles_tuple(self): + arr = _arr(2) + result = _strip_value_wrappers(({"value": arr},)) + np.testing.assert_array_equal(result[0], arr) + + def test_passthrough_non_array_scalar(self): + self.assertEqual(_strip_value_wrappers(42), 42) + + +# --------------------------------------------------------------------------- +# _normalize_linen_params +# --------------------------------------------------------------------------- + + +class TestNormalizeLinenParams(unittest.TestCase): + """Tests for _normalize_linen_params.""" + + def test_removes_double_nesting(self): + inner = {"decoder": {"layers": {}}} + params = {"params": inner} + result = _normalize_linen_params(params) + self.assertIs(result, inner) + + def test_removes_double_nesting_encoder(self): + inner = {"encoder": {"layers": {}}} + params = {"params": inner} + result = _normalize_linen_params(params) + self.assertIs(result, inner) + + def test_passthrough_when_no_double_nesting(self): + params = {"decoder": {"layers": {}}} + result = _normalize_linen_params(params) + self.assertIs(result, params) + + def test_passthrough_when_inner_has_no_decoder_encoder(self): + params = {"params": {"other_key": {}}} + result = _normalize_linen_params(params) + self.assertIs(result, params) + + +# --------------------------------------------------------------------------- +# _normalize_nnx_params +# --------------------------------------------------------------------------- + + +class TestNormalizeNnxParams(unittest.TestCase): + """Tests for _normalize_nnx_params.""" + + def test_strips_value_wrappers(self): + arr = _arr(2, 3) + params = {"decoder": {"kernel": {"value": arr}}} + result = _normalize_nnx_params(params) + np.testing.assert_array_equal(result["decoder"]["kernel"], arr) + + def test_passthrough_plain_tree(self): + arr = _arr(4) + params = {"decoder": {"kernel": arr}} + result = _normalize_nnx_params(params) + np.testing.assert_array_equal(result["decoder"]["kernel"], arr) + + +# --------------------------------------------------------------------------- +# detect_format +# --------------------------------------------------------------------------- + + +class TestDetectFormat(unittest.TestCase): + """Tests for detect_format.""" + + def test_detects_nnx_via_model_key(self): + state = {"model": {"decoder": {}}, "optimizer": {}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_double_nested_decoder(self): + state = {"params": {"params": {"decoder": {}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_linen_via_double_nested_encoder(self): + state = {"params": {"params": {"encoder": {}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_value_wrappers(self): + arr = _arr(2, 2) + state = {"params": {"decoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_raises_when_no_params_or_model_key(self): + with self.assertRaises(ValueError): + detect_format({"step": 0}) + + def test_raises_on_undetectable_format(self): + with self.assertRaises(ValueError): + detect_format({"params": {"unknown_key": {}}}) + + +# --------------------------------------------------------------------------- +# _extract_params +# --------------------------------------------------------------------------- + + +class TestExtractParams(unittest.TestCase): + """Tests for _extract_params.""" + + def test_extracts_linen_params(self): + params = {"params": {"decoder": {}}} + state = {"params": params, "opt_state": {}} + self.assertIs(_extract_params(state, "linen"), params) + + def test_extracts_nnx_params_from_model_key(self): + model = {"decoder": {}} + state = {"model": model, "optimizer": {}} + self.assertIs(_extract_params(state, "nnx"), model) + + def test_extracts_nnx_params_falls_back_to_params_key(self): + params = {"decoder": {}} + state = {"params": params} + self.assertIs(_extract_params(state, "nnx"), params) + + def test_returns_empty_dict_when_key_missing(self): + state = {"optimizer": {}} + result = _extract_params(state, "linen") + self.assertEqual(result, {}) + + +# --------------------------------------------------------------------------- +# _normalize_params +# --------------------------------------------------------------------------- + + +class TestNormalizeParams(unittest.TestCase): + """Tests for _normalize_params.""" + + def test_dispatches_to_linen(self): + inner = {"decoder": {}} + params = {"params": inner} + result = _normalize_params(params, "linen") + self.assertIs(result, inner) + + def test_dispatches_to_nnx(self): + arr = _arr(2, 2) + params = {"decoder": {"kernel": {"value": arr}}} + result = _normalize_params(params, "nnx") + np.testing.assert_array_equal(result["decoder"]["kernel"], arr) + + +# --------------------------------------------------------------------------- +# get_tree_structure_info +# --------------------------------------------------------------------------- + + +class TestGetTreeStructureInfo(unittest.TestCase): + """Tests for get_tree_structure_info.""" + + def test_returns_shape_and_dtype(self): + tree = {"kernel": _arr(3, 4), "bias": _arr(4)} + info = get_tree_structure_info(tree) + self.assertEqual(info["['kernel']"], ((3, 4), "float32")) + self.assertEqual(info["['bias']"], ((4,), "float32")) + + def test_handles_nested_tree(self): + tree = {"decoder": {"kernel": _arr(2, 2)}} + info = get_tree_structure_info(tree) + self.assertEqual(len(info), 1) + shapes = [v[0] for v in info.values()] + self.assertIn((2, 2), shapes) + + def test_handles_non_array_leaves(self): + tree = {"step": 5} + info = get_tree_structure_info(tree) + self.assertEqual(len(info), 1) + shape, _ = list(info.values())[0] + self.assertEqual(shape, "N/A") + + +# --------------------------------------------------------------------------- +# print_structure_diff +# --------------------------------------------------------------------------- + + +class TestPrintStructureDiff(unittest.TestCase): + """Tests for print_structure_diff.""" + + def _make_params(self, keys_and_shapes): + return {k: _arr(*s) for k, s in keys_and_shapes.items()} + + def test_returns_empty_tuples_when_identical(self): + params = self._make_params({"kernel": (4, 4), "bias": (4,)}) + with patch("sys.stdout", new_callable=io.StringIO): + only1, only2, shape_mm, dtype_mm = print_structure_diff(params, params) + self.assertEqual(only1, []) + self.assertEqual(only2, []) + self.assertEqual(shape_mm, []) + self.assertEqual(dtype_mm, []) + + def test_detects_key_only_in_first(self): + p1 = self._make_params({"kernel": (4, 4), "bias": (4,)}) + p2 = self._make_params({"kernel": (4, 4)}) + with patch("sys.stdout", new_callable=io.StringIO): + only1, only2, _, _ = print_structure_diff(p1, p2) + self.assertEqual(len(only1), 1) + self.assertEqual(only2, []) + + def test_detects_key_only_in_second(self): + p1 = self._make_params({"kernel": (4, 4)}) + p2 = self._make_params({"kernel": (4, 4), "bias": (4,)}) + with patch("sys.stdout", new_callable=io.StringIO): + only1, only2, _, _ = print_structure_diff(p1, p2) + self.assertEqual(only1, []) + self.assertEqual(len(only2), 1) + + def test_detects_shape_mismatch(self): + p1 = {"kernel": _arr(4, 4)} + p2 = {"kernel": _arr(4, 8)} + with patch("sys.stdout", new_callable=io.StringIO): + _, _, shape_mm, _ = print_structure_diff(p1, p2) + self.assertEqual(len(shape_mm), 1) + + def test_detects_dtype_mismatch(self): + p1 = {"kernel": np.zeros((4,), dtype=np.float32)} + p2 = {"kernel": np.zeros((4,), dtype=np.float16)} + with patch("sys.stdout", new_callable=io.StringIO): + _, _, _, dtype_mm = print_structure_diff(p1, p2) + self.assertEqual(len(dtype_mm), 1) + + +# --------------------------------------------------------------------------- +# compare_params +# --------------------------------------------------------------------------- + + +class TestCompareParams(unittest.TestCase): + """Tests for compare_params.""" + + def test_returns_true_for_identical_params(self): + params = {"kernel": _arr(4, 4), "bias": _arr(4)} + with patch("builtins.print"): + result = compare_params(params, params) + self.assertTrue(result) + + def test_returns_false_for_different_structures(self): + p1 = {"kernel": _arr(4, 4)} + p2 = {"kernel": _arr(4, 4), "bias": _arr(4)} + with patch("builtins.print"): + result = compare_params(p1, p2) + self.assertFalse(result) + + def test_returns_false_for_shape_mismatch(self): + p1 = {"kernel": _arr(4, 4)} + p2 = {"kernel": _arr(4, 8)} + with patch("builtins.print"): + result = compare_params(p1, p2) + self.assertFalse(result) + + def test_returns_false_for_dtype_mismatch(self): + p1 = {"kernel": np.zeros((4,), dtype=np.float32)} + p2 = {"kernel": np.zeros((4,), dtype=np.float16)} + with patch("builtins.print"): + result = compare_params(p1, p2) + self.assertFalse(result) + + def test_value_comparison_passes_when_equal(self): + arr = _arr(4) + with patch("builtins.print"): + result = compare_params({"w": arr}, {"w": arr.copy()}, compare_values=True) + self.assertTrue(result) + + def test_value_comparison_fails_when_different(self): + p1 = {"w": np.array([1.0, 2.0], dtype=np.float32)} + p2 = {"w": np.array([1.0, 9.0], dtype=np.float32)} + with patch("builtins.print"): + result = compare_params(p1, p2, compare_values=True, atol=1e-5, rtol=1e-5) + self.assertFalse(result) + + def test_value_comparison_passes_within_tolerance(self): + p1 = {"w": np.array([1.0], dtype=np.float32)} + p2 = {"w": np.array([1.0 + 1e-7], dtype=np.float32)} + with patch("builtins.print"): + result = compare_params(p1, p2, compare_values=True, atol=1e-5, rtol=1e-5) + self.assertTrue(result) + + def test_verbose_mode_does_not_raise(self): + params = {"kernel": _arr(2, 2)} + with patch("builtins.print"): + result = compare_params(params, params, verbose=True, compare_values=True) + self.assertTrue(result) + + def test_nested_params(self): + params = {"decoder": {"kernel": _arr(4, 4), "bias": _arr(4)}} + with patch("builtins.print"): + result = compare_params(params, params) + self.assertTrue(result) + + +# --------------------------------------------------------------------------- +# transform_nnx_params_for_comparison +# --------------------------------------------------------------------------- + + +class TestTransformNnxParamsForComparison(unittest.TestCase): + """Tests for transform_nnx_params_for_comparison.""" + + def test_transposes_layer_array(self): + # Shape (num_layers=3, d=4) -> (d=4, num_layers=3) + arr = _arr(3, 4) + tree = {"layers": {"kernel": arr}} + with patch("builtins.print"): + result = transform_nnx_params_for_comparison(tree) + self.assertEqual(result["layers"]["kernel"].shape, (4, 3)) + + def test_does_not_transpose_non_layer_array(self): + arr = _arr(3, 4) + tree = {"embedding": arr} + with patch("builtins.print"): + result = transform_nnx_params_for_comparison(tree) + self.assertEqual(result["embedding"].shape, (3, 4)) + + def test_does_not_transpose_1d_layer_array(self): + arr = _arr(4) + tree = {"layers": {"bias": arr}} + with patch("builtins.print"): + result = transform_nnx_params_for_comparison(tree) + self.assertEqual(result["layers"]["bias"].shape, (4,)) + + def test_transposes_higher_rank_layer_array(self): + # Shape (num_layers=2, d1=3, d2=5) -> (d1=3, num_layers=2, d2=5) + arr = _arr(2, 3, 5) + tree = {"layers": {"kernel": arr}} + with patch("builtins.print"): + result = transform_nnx_params_for_comparison(tree) + self.assertEqual(result["layers"]["kernel"].shape, (3, 2, 5)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/linen_nnx_converter_test.py b/tests/unit/linen_nnx_converter_test.py new file mode 100644 index 0000000000..808990f8cf --- /dev/null +++ b/tests/unit/linen_nnx_converter_test.py @@ -0,0 +1,869 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for linen_nnx_converter utilities.""" + +import unittest +import numpy as np +from unittest.mock import MagicMock, patch + +from maxtext.checkpoint_conversion.linen_nnx_converter import ( + detect_format, + _has_value_wrappers, + _strip_value_wrappers, + _add_value_wrappers, + _transpose_layers_axes, + _stack_layers, + convert_linen_to_nnx, + convert_nnx_to_linen, + _convert_opt_state_linen_to_nnx, + _convert_opt_state_nnx_to_linen, + load_checkpoint, + save_checkpoint, + main, +) + + +def _make_array(*shape): + """Helper to create a numpy array with given shape.""" + return np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + +class TestDetectFormat(unittest.TestCase): + """Tests for the detect_format function.""" + + def test_raises_when_no_params_key(self): + with self.assertRaises(ValueError): + detect_format({"step": 0}) + + def test_detects_nnx_format_via_model_key(self): + # NNX: top-level "model" key + state = {"model": {"decoder": {"layers": {}}}, "optimizer": {}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_format_double_nested(self): + state = {"params": {"params": {"decoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_format_single_nested_with_value_wrappers(self): + # Old NNX format: params/decoder with {value:} wrappers + arr = _make_array(2, 2) + state = {"params": {"decoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_encoder(self): + state = {"params": {"params": {"encoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_encoder_with_value_wrappers(self): + arr = _make_array(2, 2) + state = {"params": {"encoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_nnx_via_optimizer_key(self): + arr = _make_array(2, 2) + state = {"params": {"something": arr}, "optimizer": {"step": 0}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_opt_state(self): + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "opt_state": {"params": {"mu": {"decoder": {"kernel": arr}}}}, + } + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_optimizer_over_opt_state(self): + # "optimizer" key takes precedence for NNX detection + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "optimizer": {"step": 0, "opt_state": {}}, + } + self.assertEqual(detect_format(state), "nnx") + + def test_raises_on_undetectable_format(self): + state = {"params": {"some_unknown_key": 42}} + with self.assertRaises(ValueError): + detect_format(state) + + +class TestHasValueWrappers(unittest.TestCase): + """Tests for the _has_value_wrappers helper.""" + + def test_returns_true_for_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"value": arr})) + + def test_returns_true_for_nested_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"mu": {"value": arr}})) + + def test_returns_false_for_plain_array(self): + # A plain array is not a {"value": ...} wrapper dict + self.assertFalse(_has_value_wrappers(_make_array(2, 2))) + + def test_returns_false_for_multi_key_dict(self): + arr = _make_array(2, 2) + self.assertFalse(_has_value_wrappers({"value": arr, "extra": arr})) + + def test_returns_false_for_non_array_value(self): + self.assertFalse(_has_value_wrappers({"value": "string"})) + + +class TestStripValueWrappers(unittest.TestCase): + """Tests for the _strip_value_wrappers helper.""" + + def test_strips_single_wrapper(self): + arr = _make_array(3, 4) + result = _strip_value_wrappers({"value": arr}) + np.testing.assert_array_equal(result, arr) + + def test_strips_nested_wrappers(self): + arr = _make_array(2, 2) + wrapped = {"decoder": {"layers": {"kernel": {"value": arr}}}} + stripped = _strip_value_wrappers(wrapped) + np.testing.assert_array_equal(stripped["decoder"]["layers"]["kernel"], arr) + + def test_passes_through_plain_array(self): + arr = _make_array(2, 3) + result = _strip_value_wrappers(arr) + np.testing.assert_array_equal(result, arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _strip_value_wrappers([{"value": arr}]) + result_tuple = _strip_value_wrappers(({"value": arr},)) + np.testing.assert_array_equal(result_list[0], arr) + np.testing.assert_array_equal(result_tuple[0], arr) + + def test_passes_through_non_array_value(self): + # A dict with key "value" but scalar content should not be unwrapped + d = {"value": 42} + result = _strip_value_wrappers(d) + self.assertEqual(result, d) + + +class TestAddValueWrappers(unittest.TestCase): + """Tests for the _add_value_wrappers helper.""" + + def test_wraps_array(self): + arr = _make_array(3, 4) + result = _add_value_wrappers(arr) + self.assertIsInstance(result, dict) + self.assertIn("value", result) + np.testing.assert_array_equal(result["value"], arr) + + def test_wraps_nested_arrays(self): + arr = _make_array(2, 2) + nested = {"decoder": {"layers": {"kernel": arr}}} + wrapped = _add_value_wrappers(nested) + self.assertEqual(set(wrapped["decoder"]["layers"]["kernel"].keys()), {"value"}) + np.testing.assert_array_equal(wrapped["decoder"]["layers"]["kernel"]["value"], arr) + + def test_idempotent_on_already_wrapped(self): + arr = _make_array(2) + already_wrapped = {"value": arr} + result = _add_value_wrappers(already_wrapped) + # Should not double-wrap + self.assertEqual(set(result.keys()), {"value"}) + np.testing.assert_array_equal(result["value"], arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _add_value_wrappers([arr]) + result_tuple = _add_value_wrappers((arr,)) + self.assertEqual(set(result_list[0].keys()), {"value"}) + self.assertEqual(set(result_tuple[0].keys()), {"value"}) + + def test_passes_through_non_array_scalars(self): + result = _add_value_wrappers(42) + self.assertEqual(result, 42) + result_str = _add_value_wrappers("text") + self.assertEqual(result_str, "text") + + +class TestTransposeLayersAxes(unittest.TestCase): + """Tests for the _transpose_layers_axes helper.""" + + def test_noop_when_same_axis(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=0) + np.testing.assert_array_equal(result, arr) + + def test_transposes_axis_0_to_1(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=1) + self.assertEqual(result.shape, (2, 4, 3)) + + def test_transposes_axis_1_to_0(self): + arr = _make_array(2, 4, 3) + result = _transpose_layers_axes(arr, src_axis=1, dst_axis=0) + self.assertEqual(result.shape, (4, 2, 3)) + + def test_transposes_nested_dict(self): + arr = _make_array(4, 2, 3) + tree = {"decoder": {"layers": {"kernel": arr}}} + result = _transpose_layers_axes(tree, src_axis=0, dst_axis=1) + self.assertEqual(result["decoder"]["layers"]["kernel"].shape, (2, 4, 3)) + + def test_passes_through_1d_array(self): + arr = _make_array(5) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=1) + # 1D array has no axis 1, should be returned unchanged + np.testing.assert_array_equal(result, arr) + + def test_handles_list(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes([arr], src_axis=0, dst_axis=1) + self.assertIsInstance(result, list) + self.assertEqual(result[0].shape, (2, 4, 3)) + + def test_handles_tuple(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes((arr,), src_axis=0, dst_axis=1) + self.assertIsInstance(result, tuple) + self.assertEqual(result[0].shape, (2, 4, 3)) + + +class TestStackLayers(unittest.TestCase): + """Tests for the _stack_layers helper.""" + + def test_stacks_individual_layers(self): + arr0 = _make_array(3, 4) + arr1 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "layers_1": {"mlp": {"kernel": arr1}}, + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + stacked = result["layers"]["mlp"]["kernel"] + self.assertEqual(stacked.shape, (2, 3, 4)) + np.testing.assert_array_equal(stacked[0], arr0) + np.testing.assert_array_equal(stacked[1], arr1) + + def test_noop_when_no_layer_pattern(self): + arr = _make_array(3, 4) + decoder = {"layers": {"mlp": {"kernel": arr}}} + result, was_stacked = _stack_layers(decoder) + self.assertFalse(was_stacked) + self.assertIs(result, decoder) + + def test_preserves_non_layer_keys(self): + norm_weight = _make_array(4) + arr0 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "final_norm": {"scale": norm_weight}, + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("final_norm", result) + np.testing.assert_array_equal(result["final_norm"]["scale"], norm_weight) + + def test_stacks_three_layers(self): + arrays = [_make_array(2, 2) for _ in range(3)] + decoder = {f"layers_{i}": {"w": arrays[i]} for i in range(3)} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + stacked = result["layers"]["w"] + self.assertEqual(stacked.shape, (3, 2, 2)) + + def test_non_array_non_dict_leaf(self): + # Scalar leaf — stack_arrays returns first element + decoder = {"layers_0": {"count": 1}, "layers_1": {"count": 2}} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + + def test_with_missing_key_in_some_layers(self): + arr = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr, "bias": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, # no "bias" + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("kernel", result["layers"]["mlp"]) + + +class TestConvertLinenToNNX(unittest.TestCase): + """Tests for the convert_linen_to_nnx function.""" + + def _make_linen_state(self, add_opt_state=False): + """Creates a minimal Linen checkpoint structure.""" + arr = _make_array(2, 4, 3) + state = { + "step": 10, + "params": { + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": arr}}}, + "decoder_norm": {"scale": _make_array(4)}, + } + } + }, + } + if add_opt_state: + state["opt_state"] = {"params": {"mu": {"decoder": {"layers": {"kernel": arr}}}}} + return state + + def test_converts_step_under_optimizer(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 10) + + def test_step_not_at_top_level(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertNotIn("step", result) + + def test_params_stored_under_model_key(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertIn("model", result) + self.assertNotIn("params", result) + + def test_removes_double_nesting(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # model should have 'decoder' directly, not 'params.decoder' + self.assertIn("decoder", result["model"]) + self.assertNotIn("params", result["model"]) + + def test_adds_value_wrappers(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # Arrays should be wrapped in {"value": array} + kernel = result["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIsInstance(kernel, dict) + self.assertIn("value", kernel) + + def test_converts_opt_state_under_optimizer(self): + state = self._make_linen_state(add_opt_state=True) + result = convert_linen_to_nnx(state) + self.assertIn("opt_state", result["optimizer"]) + # Linen opt_state had nested 'params' level; it should be removed + self.assertNotIn("params", result["optimizer"]["opt_state"]) + + def test_no_step_produces_no_optimizer_step(self): + arr = _make_array(2, 4, 3) + state = {"params": {"params": {"decoder": {"layers": {"kernel": arr}}}}} + result = convert_linen_to_nnx(state) + self.assertNotIn("step", result) + self.assertIn("model", result) + + def test_no_double_nesting_still_converts(self): + # Linen state without double-nesting (unusual but handled) + arr = _make_array(2, 4) + state = {"params": {"decoder": {"layers": {"kernel": arr}}}} + result = convert_linen_to_nnx(state) + self.assertIn("decoder", result["model"]) + + def test_no_params_key_only_step(self): + state = {"step": 3} + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 3) + self.assertNotIn("model", result) + + def test_with_per_layer_params_stacked_and_transposed(self): + # Linen checkpoint with layers_0, layers_1 → stacked + transposed to axis 1 + arr = _make_array(3, 4) + state = { + "params": { + "params": { + "decoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + result = convert_linen_to_nnx(state) + stacked = result["model"]["decoder"]["layers"]["mlp"]["kernel"]["value"] + # Original (3, 4) stacked → (2, 3, 4), transposed to (3, 2, 4) + self.assertEqual(stacked.shape, (3, 2, 4)) + + +class TestConvertNNXToLinen(unittest.TestCase): + """Tests for the convert_nnx_to_linen function.""" + + def _make_nnx_state(self, add_opt_state=False): + """Creates an NNX checkpoint with 'model' and 'optimizer' keys. + + Uses 'attention' (not 'layers') as the sub-key so _convert_layers_to_linen_format + does not try to unstack the data. + """ + arr = _make_array(2, 4, 3) + state = { + "model": { + "decoder": { + "attention": {"wi": {"kernel": {"value": arr}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + "optimizer": {"step": 5}, + } + if add_opt_state: + state["optimizer"]["opt_state"] = { + "mu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "nu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + } + return state + + def test_converts_step(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 5) + + def test_adds_double_nesting(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertIn("params", result["params"]) + self.assertIn("decoder", result["params"]["params"]) + + def test_strips_value_wrappers(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + kernel = result["params"]["params"]["decoder"]["attention"]["wi"]["kernel"] + self.assertIsInstance(kernel, np.ndarray) + + def test_converts_opt_state(self): + state = self._make_nnx_state(add_opt_state=True) + result = convert_nnx_to_linen(state) + self.assertIn("opt_state", result) + # mu/nu should get a 'params' level added + self.assertIn("params", result["opt_state"]["mu"]) + self.assertIn("params", result["opt_state"]["nu"]) + + def test_backward_compat_params_key(self): + # Old NNX format: "params" instead of "model", top-level "step" + arr = _make_array(2, 4, 3) + state = { + "step": 5, + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": {"value": arr}}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + } + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 5) + self.assertIn("decoder", result["params"]["params"]) + + def test_no_step(self): + arr = _make_array(2, 4) + state = {"model": {"decoder": {"layers": {"kernel": {"value": arr}}}}} + result = convert_nnx_to_linen(state) + self.assertNotIn("step", result) + self.assertIn("params", result) + + +class TestRoundTrip(unittest.TestCase): + """Verifies that linen->nnx->linen round-trip preserves data.""" + + def test_linen_to_nnx_to_linen(self): + # Use "attention" (not "layers") so _convert_layers_to_linen_format + # does not try to unstack the dict as a stacked-layers tensor. + arr = _make_array(2, 4, 3) + linen_state = { + "step": 42, + "params": { + "params": { + "decoder": { + "attention": {"mlp": {"wi": {"kernel": arr}}}, + "norm": {"scale": _make_array(4)}, + } + } + }, + } + nnx_state = convert_linen_to_nnx(linen_state) + recovered_state = convert_nnx_to_linen(nnx_state) + + self.assertEqual(recovered_state["step"], 42) + recovered_kernel = recovered_state["params"]["params"]["decoder"]["attention"]["mlp"]["wi"]["kernel"] + np.testing.assert_array_equal(recovered_kernel, arr) + + def test_nnx_to_linen_to_nnx(self): + arr = _make_array(2, 4, 3) + nnx_state = { + "model": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": {"value": arr}}}}, + } + }, + "optimizer": {"step": 7}, + } + linen_state = convert_nnx_to_linen(nnx_state) + recovered_state = convert_linen_to_nnx(linen_state) + + self.assertEqual(recovered_state["optimizer"]["step"], 7) + recovered_kernel = recovered_state["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIn("value", recovered_kernel) + np.testing.assert_array_equal(recovered_kernel["value"], arr) + + +class TestConvertOptState(unittest.TestCase): + """Tests for the _convert_opt_state_linen_to_nnx and _convert_opt_state_nnx_to_linen helpers.""" + + def test_linen_to_nnx_removes_params_level(self): + arr = _make_array(3, 4) + opt_state = {"mu": {"params": {"decoder": {"kernel": arr}}}} + result = _convert_opt_state_linen_to_nnx(opt_state) + # 'params' key removed; decoder promoted + self.assertNotIn("params", result["mu"]) + self.assertIn("decoder", result["mu"]) + # Arrays are plain (no value wrappers in NNX opt_state) + np.testing.assert_array_equal(result["mu"]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_list_input(self): + arr = _make_array(2, 2) + opt_state = [{"decoder": {"kernel": arr}}, {"decoder": {"kernel": arr}}] + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIsInstance(result, list) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_tuple_input(self): + arr = _make_array(2, 2) + opt_state = ({"decoder": {"kernel": arr}},) + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIsInstance(result, tuple) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_non_array_non_dict(self): + # Scalars should be passed through unchanged + result = _convert_opt_state_linen_to_nnx(42) + self.assertEqual(result, 42) + + def test_linen_to_nnx_params_key_with_non_dict_value(self): + # When k == "params" but converted value is not a dict, store it as-is + opt_state = {"params": 99} + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIn("params", result) + self.assertEqual(result["params"], 99) + + def test_nnx_to_linen_adds_params_level_and_strips(self): + arr = _make_array(3, 4) + opt_state = { + "mu": {"decoder": {"kernel": {"value": arr}}}, + "nu": {"decoder": {"kernel": {"value": arr}}}, + } + result = _convert_opt_state_nnx_to_linen(opt_state) + # mu/nu should have 'params' nested inside + self.assertIn("params", result["mu"]) + self.assertIn("params", result["nu"]) + # Arrays unwrapped + kernel = result["mu"]["params"]["decoder"]["kernel"] + np.testing.assert_array_equal(kernel, arr) + + def test_nnx_to_linen_handles_list_input(self): + arr = _make_array(2, 2) + opt_state = [{"decoder": {"kernel": {"value": arr}}}] + result = _convert_opt_state_nnx_to_linen(opt_state) + self.assertIsInstance(result, list) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_nnx_to_linen_handles_tuple_input(self): + arr = _make_array(2, 2) + opt_state = ({"decoder": {"kernel": {"value": arr}}},) + result = _convert_opt_state_nnx_to_linen(opt_state) + self.assertIsInstance(result, tuple) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_nnx_to_linen_passes_through_scalars(self): + result = _convert_opt_state_nnx_to_linen("scalar_string") + self.assertEqual(result, "scalar_string") + + def test_nnx_to_linen_value_wrapper_with_non_array_inner(self): + # {"value": scalar} should NOT be unwrapped (only arrays get unwrapped) + d = {"value": 42} + result = _convert_opt_state_nnx_to_linen(d) + self.assertIn("value", result) + self.assertEqual(result["value"], 42) + + +class TestConvertLinenToNNXEncoder(unittest.TestCase): + """Tests encoder path in convert_linen_to_nnx.""" + + def test_converts_encoder_params(self): + arr = _make_array(2, 4, 3) + state = { + "params": { + "params": { + "encoder": { + "layers": {"mlp": {"wi": {"kernel": arr}}}, + } + } + } + } + result = convert_linen_to_nnx(state) + self.assertIn("encoder", result["model"]) + kernel = result["model"]["encoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIsInstance(kernel, dict) + self.assertIn("value", kernel) + + def test_converts_encoder_with_per_layer_stacking(self): + arr = _make_array(3, 4) + state = { + "params": { + "params": { + "encoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + result = convert_linen_to_nnx(state) + stacked = result["model"]["encoder"]["layers"]["mlp"]["kernel"]["value"] + # Stacked at axis 0 → (2, 3, 4), then transposed to (3, 2, 4) + self.assertEqual(stacked.shape, (3, 2, 4)) + + +class TestAdditionalEdgeCases(unittest.TestCase): + """Covers remaining edge cases.""" + + def test_detect_format_params_has_params_but_no_decoder_encoder(self): + # params["params"] exists but inner has no decoder/encoder -> falls through + # no optimizer/opt_state -> should raise + state = {"params": {"params": {"some_other_key": {}}}} + with self.assertRaises(ValueError): + detect_format(state) + + def test_detect_format_opt_state_returns_linen(self): + # Any state with "opt_state" (but no "model"/"optimizer") detects as linen + arr = _make_array(2) + state = { + "params": {"something": arr}, + "opt_state": {"mu": {"decoder": {"kernel": arr}}}, + } + self.assertEqual(detect_format(state), "linen") + + def test_add_value_wrappers_value_key_with_non_array(self): + # {"value": "text"} is not a wrapper (inner is not an array), recurse normally + d = {"value": "not_an_array"} + result = _add_value_wrappers(d) + self.assertEqual(result, {"value": "not_an_array"}) + + def test_convert_nnx_to_linen_no_step(self): + arr = _make_array(2, 4) + state = {"model": {"decoder": {"layers": {"kernel": {"value": arr}}}}} + result = convert_nnx_to_linen(state) + self.assertNotIn("step", result) + self.assertIn("params", result) + + def test_convert_nnx_to_linen_already_has_params_nesting(self): + arr = _make_array(2, 4) + state = {"params": {"params": {"decoder": {"layers": {"kernel": {"value": arr}}}}}} + result = convert_nnx_to_linen(state) + self.assertIn("params", result) + + def test_convert_nnx_to_linen_no_params_key(self): + state = {"optimizer": {"step": 8}} + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 8) + self.assertNotIn("params", result) + + +class TestLoadCheckpoint(unittest.TestCase): + """Tests for load_checkpoint with mocked orbax/epath.""" + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_load_checkpoint_calls_checkpointer_and_returns_state(self, mock_epath, mock_ocp): + arr = _make_array(2, 2) + expected_state = {"params": arr, "step": 0} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + + mock_metadata = MagicMock() + mock_metadata.item_metadata.tree = {"params": arr} + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = mock_metadata + mock_ckptr.restore.return_value = expected_state + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.ArrayRestoreArgs.return_value = MagicMock() + + result = load_checkpoint("/tmp/test_ckpt") + + mock_epath.Path.assert_called_once_with("/tmp/test_ckpt") + mock_ocp.Checkpointer.assert_called_once() + mock_ckptr.metadata.assert_called_once_with(mock_path) + mock_ckptr.restore.assert_called_once() + self.assertEqual(result, expected_state) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_load_checkpoint_with_empty_tree_metadata(self, mock_epath, mock_ocp): + expected_state = {"step": 5} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + + mock_metadata = MagicMock() + mock_metadata.item_metadata.tree = {} + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = mock_metadata + mock_ckptr.restore.return_value = expected_state + mock_ocp.Checkpointer.return_value = mock_ckptr + + result = load_checkpoint("/tmp/empty_ckpt") + + self.assertEqual(result["step"], 5) + + +class TestSaveCheckpoint(unittest.TestCase): + """Tests for save_checkpoint with mocked orbax/epath.""" + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_save_checkpoint_creates_dir_and_saves(self, mock_epath, mock_ocp): + state = {"params": _make_array(2, 2), "step": 1} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + + mock_ckptr = MagicMock() + mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr + + save_checkpoint(state, "/tmp/output") + + mock_epath.Path.assert_called_once_with("/tmp/output") + mock_path.mkdir.assert_called_once_with(exist_ok=True, parents=True) + mock_ocp.PyTreeCheckpointer.assert_called_once() + mock_ckptr.save.assert_called_once_with(mock_path, state, force=True) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_save_checkpoint_passes_state_unchanged(self, mock_epath, mock_ocp): + state = {"step": 99, "params": {"decoder": {}}} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + mock_ckptr = MagicMock() + mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr + + save_checkpoint(state, "/tmp/out2") + + call_args = mock_ckptr.save.call_args + self.assertIs(call_args[0][1], state) + + +class TestMain(unittest.TestCase): + """Tests for the main() CLI entry point.""" + + def _run_main(self, argv): + with patch("sys.argv", ["prog"] + argv): + main() + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_explicit_linen_to_nnx(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "step": 1, + "params": {"params": {"decoder": {"layers": {"kernel": arr}}}}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx"]) + mock_load.assert_called_once_with("/src") + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # NNX format: decoder at top level of model + self.assertIn("decoder", saved_state["model"]) + self.assertEqual(mock_save.call_args[0][1], "/dst") + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_explicit_nnx_to_linen(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "model": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "optimizer": {"step": 2}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=nnx_to_linen"]) + mock_load.assert_called_once_with("/src") + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # Linen format: double nesting + self.assertIn("params", saved_state["params"]) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_auto_detects_linen_converts_to_nnx(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "step": 3, + "params": {"params": {"decoder": {"layers": {"kernel": arr}}}}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"]) + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # Auto-detected linen → NNX format: model key + self.assertIn("decoder", saved_state["model"]) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_auto_detects_nnx_converts_to_linen(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "model": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "optimizer": {"step": 4}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"]) + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # Auto-detected nnx → Linen format + self.assertIn("params", saved_state["params"]) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_default_direction_is_auto(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "params": {"params": {"decoder": {"layers": {"kernel": arr}}}}, + } + # No --direction arg -> defaults to "auto" + self._run_main(["--source_path=/src", "--target_path=/dst"]) + mock_save.assert_called_once() + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_scan_layers_false(self, mock_load, mock_save): + arr = _make_array(3, 4) + mock_load.return_value = { + "params": { + "params": { + "decoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx", "--no-scan_layers"]) + saved_state = mock_save.call_args[0][0] + # With scan_layers=False: integer-keyed layers/N + layers = saved_state["model"]["decoder"]["layers"] + self.assertIsInstance(layers, dict) + self.assertTrue(all(k.isdigit() for k in layers.keys())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/run_sharding_dump.py b/tests/utils/run_sharding_dump.py index 7d3156fe00..62c71a9b5b 100644 --- a/tests/utils/run_sharding_dump.py +++ b/tests/utils/run_sharding_dump.py @@ -59,9 +59,12 @@ flags.DEFINE_string("topology", None, "Specific topology to dump.") flags.DEFINE_string("num_slice", None, "Specific number of slices to dump.") flags.DEFINE_string("custom_mesh_and_rule", None, "Specific custom_mesh_and_rule to dump.") +flags.DEFINE_bool("pure_nnx", False, "Use pure NNX model.") -def run_single_dump(model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple) -> None: +def run_single_dump( + model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple, pure_nnx: bool = False +) -> None: """Generate sharding json file for one specific model, topology, slice and rule.""" args = [ "python3", @@ -79,6 +82,8 @@ def run_single_dump(model_name: str, topology: str, num_slice: str, custom_mesh_ args.append(f"custom_mesh_and_rule={custom_mesh_and_rule}") if overrides: args.extend(overrides) + if pure_nnx: + args.append("pure_nnx=true") subprocess.run(args, check=True) @@ -117,7 +122,7 @@ def main(argv: Sequence[str]) -> None: print(" -> Sharding files already exist. Regenerating to overwrite.") try: - run_single_dump(model_name, topology, str(num_slice), custom_mesh_and_rule, overrides) + run_single_dump(model_name, topology, str(num_slice), custom_mesh_and_rule, overrides, pure_nnx=FLAGS.pure_nnx) except subprocess.CalledProcessError: print(f"!!! FAILED: {model_name} {topology} {num_slice} {custom_mesh_and_rule} overrides={overrides}") From 0f09e3ac72419ce2cdbd98e7ed338b5189ea02f1 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 28 Apr 2026 21:17:04 +0000 Subject: [PATCH 3/5] NNX: correctness fixes, enable feature paths, and vocab tiling on NNX MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes (run as no-op while pure_nnx=False stays default): - nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing; call from ToLinen after nnx.update to fix "Cannot extract graph node from different trace level" when grad tracers leak into Variable._trace_state. - gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout = linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError. - normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity. - attentions.py / qwen3.py: callsites eps= -> epsilon=. - moe.py: per_expert_scale block moved into the unfused-kernel else branch (was scaling wo even when fused_kernel was active). - models.py: build MTP block as MultiTokenPredictionBlock(...) directly (drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole to NNXDecoder instead of unpacking 5 fields. - gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan carry); use nnx.merge(..., copy=True) to avoid Variable reuse. - diloco.py: NNX-aware state handling — state.params -> state.model.filter (nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params helper for jax.lax.cond pytree-structure parity. - train_compile.py: new _collect_nnx_activation_shardings helper (forward pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only traces __init__); NNX path now passes 2-arg shaped_train_args (no rng); diloco path patched to handle the 2-vs-3 length difference. - muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as {"params": nnx.to_pure_dict(...)} for parity with Linen tree shape. - nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu) retain tracers in Linen scope across re-traces. Skip jax.checkpoint and use a Python for-loop instead of jax.lax.scan when quantization is FP8. Makes FP8 quantization usable on the NNX path. - train.py (pre-train train_step): return nnx.state(new_state, nnx.Not (nnx.Intermediate)) so sowed forward-pass artifacts (e.g. max_logits for QK-Clip) don't break leaf-count parity with state_mesh_shardings. - llama2.py: pass parameter_memory_host_offload to pre_self_attention_layer _norm RMSNorm (was missing on this norm only). - base.yml: add 4 pipeline-related logical_axis_rules — layers_outside _pipeline, layers_per_stage, num_activations, circular_repeats. Additive, no-op without use_nnx_pipeline=True. NNX feature enablements (clear all 17 "Pure NNX support has not been implemented yet" NotImplementedError sites by routing Linen-coupled utilities to the Linen path; their on-disk format is Linen): - layerwise_quantization.py (2 sites): operates on Linen-format checkpoints via DeepSeek*ToLinen layers. - lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen tree shape; LoRA adapters on disk are Linen. - standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses state.opt_state[0]._replace(mu=..., nu=...) — Linen-only. - generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and _save_decode_checkpoint use state.params["params"]["decoder"] — Linen. - convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree paths (.params['params'], .opt_state.mu['params']). - maxengine.py (3 sites): inference engine uses state.params and serves Linen-format inference checkpoints. - grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route to Linen with a clear log warning since NNX-format checkpoints will fail at restore time. Vocab tiling on NNX (real implementation, not just routing): - models.py: add Transformer.logits_from_hidden_states on the NNX Transformer class — wraps NNXDecoder.apply_output_head with the token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states. - vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per chunk. The NNX model carries its parameters internally so no explicit FSDP gather is needed (unlike the Linen gathered_params pattern). MVP uses default autograd; custom_vjp memory-savings optimization is a follow-up if backward memory becomes a concern. - train.py (NNX loss_fn): replace the NotImplementedError with the call to vocab_tiling_nnx_loss using hidden_states from intermediates. - pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1 and enable_nnx validation guards (no longer needed). DPO + NNX retained as NotImplementedError but with a much more informative message (points users at pure_nnx=False workaround). Full implementation is deferred — needs a new TrainState shape carrying both policy and reference NNX models plus an NNX dpo_loss_fn. Stats: 26 source files modified, +406 / -171 lines. Linen invariant verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False; Linen-path UTs unaffected (3 pre-existing failures on the parent branch remain unchanged — sharding_compare_test::deepseek2-16b, optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two _slices x2). All "Pure NNX support has not been implemented yet" NotImplementedError sites cleared (was 17, now 0). --- .../convert_gpt3_ckpt_from_paxml.py | 15 +-- src/maxtext/configs/base.yml | 7 ++ src/maxtext/configs/pyconfig_deprecated.py | 3 +- src/maxtext/configs/types.py | 3 +- src/maxtext/experimental/rl/grpo_trainer.py | 37 +++--- src/maxtext/inference/maxengine/maxengine.py | 22 ++-- src/maxtext/layers/attentions.py | 4 +- src/maxtext/layers/moe.py | 4 +- src/maxtext/layers/nnx_decoders.py | 34 +++++- src/maxtext/layers/nnx_wrappers.py | 35 ++++++ src/maxtext/layers/normalizations.py | 14 ++- src/maxtext/models/gpt_oss.py | 5 +- src/maxtext/models/llama2.py | 1 + src/maxtext/models/models.py | 19 +++- src/maxtext/models/olmo3.py | 4 +- src/maxtext/models/qwen3.py | 4 +- src/maxtext/trainers/diloco/diloco.py | 59 ++++++++-- src/maxtext/trainers/pre_train/train.py | 22 +++- .../trainers/pre_train/train_compile.py | 38 ++++++- .../utils/generate_param_only_checkpoint.py | 26 ++--- src/maxtext/utils/gradient_accumulation.py | 21 +++- src/maxtext/utils/layerwise_quantization.py | 20 ++-- src/maxtext/utils/lora_utils.py | 11 +- src/maxtext/utils/muon_utils.py | 5 +- src/maxtext/utils/standalone_checkpointer.py | 15 +-- src/maxtext/utils/vocabulary_tiling.py | 107 ++++++++++++++++++ tests/unit/train_nnx_test.py | 7 -- 27 files changed, 399 insertions(+), 143 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 9b5f0cfb21..d4d4c39290 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -87,11 +87,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) + # This conversion script reads paxml-format weights and emits a Linen-format + # MaxText checkpoint (downstream uses `.params['params']`, `.opt_state.mu['params']`, + # `.opt_state.nu['params']` keystr paths; the keystr_map below targets the Linen + # tree shape). Use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(cfg) - if cfg.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) tx = optimizers.get_optimizer(cfg, learning_rate_schedule) @@ -102,11 +103,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - if cfg.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 2d5cde124b..48a36a5c9b 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -560,6 +560,13 @@ logical_axis_rules: [ ['tokens_per_page', []], ['paged_kv_head_dim_size', []], # ========================================== + # Pipeline Parallelism + # ========================================== + ['layers_outside_pipeline', []], + ['layers_per_stage', []], + ['num_activations', []], + ['circular_repeats', []], + # ========================================== # Deprecated / Scheduled for Removal # ========================================== ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 406ba92523..c14d87cd4b 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) - def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool): + del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0: raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration - raise ValueError("We currently don't support vocab tiling on NNX module.") def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples): diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 5c980966e7..f6d5c4eddf 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2814,8 +2814,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 ): raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if self.num_vocab_tiling > 1 and self.enable_nnx: - raise ValueError("We currently don't support vocab tiling on NNX module.") + # Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py. if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: raise ValueError( diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 28eef21cb0..4244d199a8 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -542,29 +542,28 @@ def setup_train_loop( - eval_data_iterator: The iterator for the evaluation dataset (or None). - state: The initialized training state. """ + # GRPO RL trainer is Linen-shaped end-to-end (state.params accesses below, + # state_mesh_shardings.params, and the inference path through MaxEngine which is + # Linen-only). Run on Linen path regardless of pure_nnx; warn the user since + # NNX-format checkpoints will mismatch at restore time. + if config.pure_nnx or config_inference.pure_nnx: + max_logging.log( + "WARNING: GRPO RL trainer does not yet support pure_nnx natively; " + "running on the Linen path. NNX-format checkpoints will not load correctly here." + ) with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = mt.from_config(config, devices=training_devices) + model = mt.from_config(config, devices=training_devices) mesh = model.mesh max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - if config_inference.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - inference_model = mt.from_config(config_inference, devices=inference_devices) + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): @@ -573,14 +572,10 @@ def setup_train_loop( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) - # create inference_state_mesh_shardings from inference_mesh - if config_inference.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_inference_state_fn = functools.partial( - maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng - ) + # create inference_state_mesh_shardings from inference_mesh (Linen path; see warning above) + init_inference_state_fn = functools.partial( + maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng + ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( config_inference, inference_mesh, init_inference_state_fn, is_training=False )[2] diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 5bb0a87b5a..c00f475e8d 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -111,12 +111,12 @@ def __init__(self, config: Any, devices: Any | None = None): devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and Optimizer definition + # Model and Optimizer definition. + # MaxEngine uses Linen-shaped state (state.params, state_mesh_shardings.params, + # state.opt_state) and serves Linen-format inference checkpoints. Use Linen path + # regardless of pure_nnx — the flag affects training, not inference serving. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -232,11 +232,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( self.config, self._mesh, init_state_fn, False ) @@ -245,11 +241,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index e53de0973a..f2c337f330 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -525,14 +525,14 @@ def __init__( elif self.is_qwen3_next: self.query_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, ) self.key_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index e23c3eba9f..48d1f78108 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -2242,8 +2242,8 @@ def __call__( w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) - if self.per_expert_scale is not None: - wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] + if self.per_expert_scale is not None: + wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] if self.wi_0_sparsity_module is not None: _, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 100d9c6817..94538a7911 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -543,8 +543,16 @@ def pure_layer_fn(state_in, y_in): out = merged_layer(y_in, **kwargs) return out, nnx.state(merged_layer) - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - out, new_state = checkpointed_fn(state, y) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.checkpoint re-traces the scan body during backward (remat), + # but the Linen scope retains JAX tracers from the first trace, causing + # UnexpectedTracerError. Skip checkpoint for these quantization types. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + out, new_state = pure_layer_fn(state, y) + else: + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) nnx.update(layer, new_state) return out @@ -623,13 +631,12 @@ def layer_fn(carry, scanned_vars): return new_carry, (new_current_state, updated_kv) return new_carry, new_current_state - layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) - if use_kv: # If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan # because scanning requires stacking the kv_caches list, which creates a copy # and breaks the in-place memory updates required by vLLM's PagedAttention. # Therefore, we must unroll the loop statically when kv_caches is provided. + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) # kv_caches_stacked is actually the original kv_caches list in this new flow kv_caches_list = kv_caches_stacked @@ -651,7 +658,24 @@ def layer_fn(carry, scanned_vars): # inference with vLLM, parameters do not change and we don't need intermediates. return current_carry, layers, None else: - final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.lax.scan traces the body function and Linen's setup() creates + # intermediate tracer values (amax_history float32[1024]) that escape the scan scope, + # causing UnexpectedTracerError. Use a Python for loop instead for these types. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + carry = x_in + per_layer_states = [] + for i in range(length): + current_params = jax.tree.map(lambda x, i=i: x[i], params) + current_state = jax.tree.map(lambda x, i=i: x[i], state) + carry, new_state_i = layer_fn(carry, (current_params, current_state)) + per_layer_states.append(new_state_i) + final_carry = carry + scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states) + else: + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) + final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) returned_kv_stacked = None if scan_axis != 0: diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index d41d924456..24ebecd492 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -26,6 +26,7 @@ from flax.core import FrozenDict from flax.core import meta from flax.nnx import graph +from flax.nnx import tracers as nnx_tracers from flax.nnx import variablelib from flax.nnx.bridge import module as bdg_module from flax.nnx.module import Module @@ -167,6 +168,39 @@ def current_linen_module() -> linen.Module | None: return None +def is_linen_initializing() -> bool: + """Check if the current execution context is inside a Linen init() call. + + Returns True when called from within a ``to_linen_class`` wrapper's + ``init()`` path. Uses :func:`current_linen_module` to access the Linen + module stack (private API already used by this module). + + This is used by NNX pipeline modules to short-circuit the full scan + during Linen init, where only the output shape/dtype is needed. + """ + module = current_linen_module() + if module is not None and hasattr(module, "is_initializing") and callable(module.is_initializing): + return module.is_initializing() + return False + + +def _refresh_variable_trace_state(module: Module) -> None: + """Refresh _trace_state for Variables that have stale trace state. + + When nnx.update() is called with tracer values from a JAX transformation + (e.g. jax.grad's LinearizeTracer), it uses _unsafe_bypass_check=True which + updates the raw value but not _trace_state. This leaves Variables with a + stale _trace_state from the outer (Python) context, causing nnx.split() to + fail with "Cannot extract graph node from different trace level" errors. + + This function resets _trace_state on any Variables whose _can_update is False + so that downstream NNX operations (e.g. nnx.split in NNXPipeline) succeed. + """ + for _, v in nnx.graph.iter_graph(module): + if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access + object.__setattr__(v, "_trace_state", nnx_tracers.TraceState()) + + class ToNNX(Module): """A wrapper to turn any Linen module into an NNX module. @@ -464,6 +498,7 @@ def maybe_unbox(x): warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") nnx.update(module, new_state) + _refresh_variable_trace_state(module) _fix_for_qwix_quantization(module) method_fn = _get_module_method(module, nnx_method) diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index bf91262bf1..35611b2166 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> return y_flat.reshape(input_shape) -def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): +def Qwen3NextRMSNorm( + num_features: int, + epsilon: float = 1e-6, + dtype: DType = None, + weight_dtype: DType = None, + shard_mode=None, + kernel_axes=None, + parameter_memory_host_offload=None, + *, + rngs: nnx.Rngs, +): """ Used for input and post attention layernorms in Qwen3NextDecoderLayer. @@ -127,7 +137,7 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: return nnx.data( RMSNorm( num_features=num_features, - epsilon=eps, + epsilon=epsilon, dtype=dtype, weight_dtype=weight_dtype, scale_init=linen_initializers.zeros, diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index 9401d01d9f..5f4a2f3fb6 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -29,6 +29,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import moe from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations @@ -132,6 +133,8 @@ def __init__( rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + def __call__( self, inputs, @@ -189,7 +192,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 6a215c5dbe..244eed03bb 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -71,6 +71,7 @@ def __init__( shard_mode=config.shard_mode, kernel_axes=("norm",), epsilon=config.normalization_layer_epsilon, + parameter_memory_host_offload=config.parameter_memory_host_offload, rngs=rngs, ) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index c6ca234a47..5ba365b74b 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -398,6 +398,19 @@ def no_op(self, *args, **kwargs): """A no-op method to allow the model to be used in a lazy context.""" return + def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): + """Compute logits from hidden states (wraps NNXDecoder.apply_output_head). + + Mirrors the Linen TransformerLinenPure.logits_from_hidden_states method; + used by vocabulary tiling to recompute logits from chunked hidden states. + """ + return self.decoder.apply_output_head( + shared_embedding=self.token_embedder, + y=hidden_states, + deterministic=deterministic, + model_mode=model_mode, + ) + def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32): """Initializes the KV cache for the Transformer. @@ -509,11 +522,7 @@ def __call__( previous_chunk=previous_chunk, slot=slot, page_state=page_state, - image_embeddings=multimodal_input.image_embeddings if multimodal_input is not None else None, - image_masks=multimodal_input.image_masks if multimodal_input is not None else None, - audio_embeddings=multimodal_input.audio_embeddings if multimodal_input is not None else None, - audio_masks=multimodal_input.audio_masks if multimodal_input is not None else None, - bidirectional_mask=multimodal_input.bidirectional_mask if multimodal_input is not None else None, + multimodal_input=multimodal_input, kv_caches=kv_caches, attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py index a3a8b6997d..9d68d6a57d 100644 --- a/src/maxtext/models/olmo3.py +++ b/src/maxtext/models/olmo3.py @@ -30,6 +30,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations from maxtext.layers.attentions import Attention @@ -140,6 +141,7 @@ def __init__( model_mode=model_mode, rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) def __call__( self, @@ -200,7 +202,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bd65f04438..87cb4cc7ef 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -966,7 +966,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -991,7 +991,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/trainers/diloco/diloco.py b/src/maxtext/trainers/diloco/diloco.py index a9ef64631a..39d84a89dc 100644 --- a/src/maxtext/trainers/diloco/diloco.py +++ b/src/maxtext/trainers/diloco/diloco.py @@ -26,6 +26,7 @@ from typing import Any, Callable import drjax +from flax import nnx from flax import struct from flax.training import train_state import jax @@ -153,7 +154,15 @@ def add_diloco_dim(x): momentum=config.diloco_outer_momentum, nesterov=True, ) - outer_opt_state = jax.eval_shape(outer_optimizer.init, abstract_state.params) + # For NNX, model params (Param variables only) live under abstract_state.model; + # for Linen under abstract_state.params. + if config.pure_nnx: + model_params = abstract_state.model.filter(nnx.Param) + model_params_sharding = state_mesh_shardings.model.filter(nnx.Param) + else: + model_params = abstract_state.params + model_params_sharding = state_mesh_shardings.params + outer_opt_state = jax.eval_shape(outer_optimizer.init, model_params) # Create abstract step abstract_step = jax.ShapeDtypeStruct((), jnp.int32) @@ -161,7 +170,7 @@ def add_diloco_dim(x): # Build abstract DiLoCo state diloco_state = DiLoCoTrainState( inner_state=inner_state, - params=abstract_state.params, + params=model_params, outer_opt_state=outer_opt_state, step=abstract_step, ) @@ -171,12 +180,12 @@ def add_diloco_dim(x): # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState()) # We shard the momentum trace the same way as the parameters. outer_opt_state_sharding = ( - optax.TraceState(trace=state_mesh_shardings.params), + optax.TraceState(trace=model_params_sharding), optax.EmptyState(), ) diloco_state_shardings = DiLoCoTrainState( inner_state=inner_state_shardings, - params=state_mesh_shardings.params, + params=model_params_sharding, outer_opt_state=outer_opt_state_sharding, step=None, ) @@ -205,11 +214,15 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: # mesh automatically when jax.set_mesh is used. inner_state = drjax.broadcast(state, mesh=mesh) # Outer state retains a single copy of the model parameters and optimizer state. - outer_params = state.params + # For NNX, model params (Param variables only) live under state.model; + # for Linen under state.params. + outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params outer_opt_state = outer_optimizer.init(outer_params) outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state) + # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step. + step = state.optimizer.step if config.pure_nnx else state.step return ( - DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=state.step), + DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step), outer_opt_state_sharding, ) @@ -244,7 +257,11 @@ def synchronize(state): # Calculate the delta between the current replica's state and the global # state (since last synchronization). broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) - model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) + # For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params. + inner_model_params = ( + nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params + ) + model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params) # Treat the average delta as the outer optimizer's gradient and apply to # the global (outer) model params. averaged_pseudo_grad = drjax.reduce_mean(model_delta) @@ -253,7 +270,27 @@ def synchronize(state): # Replace inner model params with the new global model params. # NOTE: inner optimizer state is retained despite the change in parameters, # see section 6.1 in https://arxiv.org/pdf/2311.08105. - new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh) + if config.pure_nnx: + # For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state). + def replace_nnx_model_params(s, new_params): + non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param)) + new_model = nnx.merge_state(non_param_model, new_params) + # Build result via __setitem__ so nested States are stored as plain dicts + # internally, matching the pytree structure produced by nnx.state(). + # (Passing State objects via the constructor dict literal stores them + # as-is, causing jax.lax.cond to see mismatched pytree structures.) + result = type(s)({}) + result["model"] = new_model + result["optimizer"] = s["optimizer"] + return result + + new_inner_state = drjax.map_fn( + lambda s: replace_nnx_model_params(s, new_outer_params), + state.inner_state, + mesh=mesh, + ) + else: + new_inner_state = drjax.map_fn(lambda s: s.replace(params=new_outer_params), state.inner_state, mesh=mesh) return state.replace( params=new_outer_params, outer_opt_state=new_opt_state, @@ -271,14 +308,16 @@ def diloco_train_step(state, batch, prng): broadcast_rng = drjax.broadcast(prng, mesh=mesh) inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh) avg_metrics = typed_reduce_mean(metrics) + # For NNX, the step counter lives at inner_state.optimizer.step; for Linen at inner_state.step. + new_step = inner_state.optimizer.step[0] if config.pure_nnx else inner_state.step[0] state = state.replace( inner_state=inner_state, - step=inner_state.step[0], + step=new_step, ) # Either synchronize the model, or no-op, depending on whether the current # step falls on the synchronization period. state = jax.lax.cond( - inner_state.step[0] % config.diloco_sync_period == 0, + new_step % config.diloco_sync_period == 0, synchronize, lambda x: x, # no-op state, diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index bd475deba4..585886b682 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -72,7 +72,7 @@ from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad -from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss +from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss, vocab_tiling_nnx_loss _diag_modules = _cloud_diag() diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules @@ -200,9 +200,10 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr intermediate_outputs = intermediates.to_pure_dict() if config.num_vocab_tiling > 1: - raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") - - if (config.use_indexer and not config.indexer_sparse_training) and is_train: + hidden_state_key = ("decoder", "hidden_states") + hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] + xent_sum, total_z_loss = vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train) + elif (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. # The main model parameters are frozen and only the indexer is trained via KL divergence. xent_sum = 0.0 @@ -320,7 +321,12 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args else: if config.use_dpo: - raise NotImplementedError("DPO for NNX modules has not been implemented.") + raise NotImplementedError( + "DPO is not yet supported for NNX modules. DPO requires a reference model " + "stored alongside the policy model (Linen path uses state.params['reference_params']); " + "the NNX TrainState equivalent has not been wired up. As a workaround, set " + "pure_nnx=False for DPO runs." + ) state = nnx.merge(model, state) # reconstruct TrainStateNNX ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] @@ -546,7 +552,11 @@ def move(path, value): if config.use_dpo: new_state = _merge_dpo_state(new_state, reference_params) return new_state, metrics - return nnx.state(new_state), metrics + # Exclude Intermediate variables (e.g., sowed max_logits for QK-Clip) from the + # returned state. Intermediates are transient forward-pass artifacts and must not + # persist across steps: they're absent from the abstract state used to build + # state_mesh_shardings, so including them would cause a leaf-count mismatch in JAX. + return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics def eval_step(model, config, state, data, dropout_rng=None): diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index a2981f67ed..c593d3c540 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -30,6 +30,7 @@ from flax import nnx from flax.linen import partitioning as nn_partitioning import jax +import jax.numpy as jnp from jax.experimental.serialize_executable import serialize from jax.experimental.topologies import get_topology_desc from jax.sharding import AxisType, Mesh @@ -92,6 +93,27 @@ def get_topology_mesh(config): return topology_mesh +def _collect_nnx_activation_shardings(create_model_fn, config, mesh): + """Run an NNX forward pass in abstract mode to populate _ACTIVATION_SHARDINGS_DUMP. + + get_abstract_state_nnx uses nnx.eval_shape which only traces model initialization, + not __call__. Activation shardings are only collected during a forward pass. + """ + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + + def _nnx_forward(): + model_instance = create_model_fn() + return model_instance( + decoder_input_tokens=jnp.ones(input_shape, dtype=jnp.int32), + decoder_positions=jnp.ones(input_shape, dtype=jnp.int32), + decoder_segment_ids=jnp.ones(input_shape, dtype=jnp.int32), + enable_dropout=False, + ) + + with nn_partitioning.axis_rules(config.logical_axis_rules): + jax.eval_shape(_nnx_forward) + + def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state @@ -129,7 +151,8 @@ def create_train_state_fn(): # For NNX, get_functional_train_with_signature expects the graphdef (static structure), # not the raw model — mirroring how the training loop does nnx.split(train_state). with nn_partitioning.axis_rules(config.logical_axis_rules): - graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh) + abs_train_state = nnx.eval_shape(init_state_fn) + graphdef, _ = nnx.split(abs_train_state) model = graphdef else: # unsharded logical annotations @@ -139,10 +162,17 @@ def create_train_state_fn(): shaped_batch = maxtext_utils.get_shaped_batch(config) if config.pure_nnx: - shaped_train_args = (abstract_state, shaped_batch, None) # NNX doesn't use dropout_rng + shaped_train_args = (abstract_state, shaped_batch) # NNX doesn't use dropout_rng else: shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} + + # Collect activation shardings for NNX by running an abstract forward pass. + # This must happen after get_abstract_state (which uses nnx.eval_shape and only + # traces __init__, not __call__). + if config.debug_sharding and config.pure_nnx: + _collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh) + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -280,7 +310,9 @@ def main(argv: Sequence[str]) -> None: diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( config, abstract_state, state_mesh_shardings, topology_mesh ) - shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) + # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng. + shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None + shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg) # Wrap train_step with diloco train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 2fd14b87a2..0f997a6577 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -90,20 +90,17 @@ def slice_ith(input_layers): def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Model and Optimizer definition. + # This script reads a Linen-format full state and emits a Linen-format + # parameter-only checkpoint (downstream `_possibly_unroll_params` and + # `_save_decode_checkpoint` access `state.params["params"]["decoder"]` / `state.opt_state`, + # both Linen-only). Use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( None, config, mesh, checkpoint_manager, init_state_fn ) @@ -114,12 +111,11 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Model and Optimizer definition. + # LoRA adapters and downstream `_save_decode_checkpoint`/`_possibly_unroll_params` + # are Linen-shaped; use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index e1699647c6..cf84577dbd 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -71,10 +71,16 @@ def _maybe_shard_with_name(inputs, sharding_names): is_nnx = isinstance(model, nnx.Module) - # For more efficient DP/ZeRO-1 + GA - if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: - ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) - grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + # For more efficient DP/ZeRO-1 + GA. + # config.ici_data_parallelism may be -1 (auto-fill: resolved at mesh creation time, but + # the config field remains -1). Treat any value != 1 as "data parallelism is active". + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # jax.lax.scan traces its body with an AbstractMesh where all axis types are Auto, + # which rejects reduced/unreduced PartitionSpec in scan carry tensors (raises ValueError). + # Use plain params_shardings for ga_params and init_grad in the carry. + # The all-reduce for data parallelism is applied to raw_grads after the scan instead. + ga_params_shardings = params_shardings + grad_shardings = params_shardings else: ga_params_shardings = grad_shardings = params_shardings @@ -105,7 +111,7 @@ def accumulate_gradient(acc_grad_and_loss, data): if is_nnx: # Reconstruct the model using the fixed parameters (ga_params) # and the advancing non-parameter state (RNGs) from the carry. - local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True) (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) acc_grad_and_loss["rest_state"] = next_rest_state @@ -156,6 +162,11 @@ def reshape_to_microbatch_accumulations(batch_arr): + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps ) raw_grads = grad_and_loss["grad"] + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # Apply unreduced annotation after the scan to trigger all-reduce across data-parallel + # devices (reduced/unreduced cannot be used inside jax.lax.scan carry tensors). + unreduced_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, unreduced_shardings) raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, params_shardings) raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 29fa928656..a6c1c07f67 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -173,19 +173,15 @@ def __init__(self, config: Any, rng: PRNGKeyType): devices_array = maxtext_utils.create_device_mesh(config=config) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and quantization config + # Model and quantization config. + # This script produces and consumes Linen-format checkpoints (see DeepSeek*ToLinen + # layer classes used in load_and_quantize). Always use the Linen path internally, + # regardless of the pure_nnx flag — the flag affects training, not checkpoint format. self.quant = quantizations.configure_quantization(config) - if self.config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) + model = models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 24099ef22a..76cd26d20e 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Common LoRA utils needed to support LoRA adapters.""" +"""Common LoRA utils needed to support LoRA adapters.""" from functools import partial import json @@ -167,11 +167,10 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + # LoRA adapters are Linen-format on disk (downstream `get_lora_abstract_state` expects + # `unboxed_abstract_state.params` Linen tree shape; `lora_state.replace(params=...)` + # uses Linen TrainState API). Use the Linen init path regardless of the pure_nnx flag. + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) lora_config_path = lora_adapter_path + "adapter_config.json" diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 3bd2b186b1..049a084979 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -116,6 +116,7 @@ def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + # The result is an nnx.State with the same structure, where each Param's value holds the mdn result. muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) else: # Linen @@ -154,7 +155,7 @@ def get_leaf_info(leaf): print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=True): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -191,6 +192,8 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) + if pure_nnx: + muon_weight_dimension_numbers = {"params": nnx.to_pure_dict(muon_weight_dimension_numbers)} return muon_weight_dimension_numbers diff --git a/src/maxtext/utils/standalone_checkpointer.py b/src/maxtext/utils/standalone_checkpointer.py index ba6b148b04..2fc2b09e25 100644 --- a/src/maxtext/utils/standalone_checkpointer.py +++ b/src/maxtext/utils/standalone_checkpointer.py @@ -52,18 +52,15 @@ def checkpoint_loop(config, state=None): Returns: """ - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = from_config(config) + # Standalone checkpointer is a save/restore exerciser that uses + # add_entropy_to_checkpoint() to populate Linen-shaped optimizer state + # (state.opt_state, state.params). Use the Linen path regardless of pure_nnx — + # the flag affects training, not this checkpoint test harness. + model = from_config(config) mesh = model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) _, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index e7b155416c..6a61f9ed23 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -247,3 +247,110 @@ def _bwd_scan_body(grad_params_acc, chunk_data): ) return total_loss, total_z_loss + + +def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train): + """Calculates cross-entropy loss using vocab tiling for NNX models. + + NNX equivalent of `vocab_tiling_linen_loss`. Iterates the vocab dimension via + `jax.lax.scan` with `model.logits_from_hidden_states` per chunk; the model + carries its parameters internally so no explicit gather is needed. + + This is a memory-efficient forward (chunked logits) but uses the default + autograd path (no custom_vjp), so backward memory savings vs. the Linen + custom_vjp path are not yet realized. TODO: add a custom_vjp using + `nnx.split`/`nnx.merge` if backward memory becomes a concern. + + Args: + model: The NNX model instance (must implement `logits_from_hidden_states`). + hidden_states: The final hidden states from the decoder. + data: A dictionary containing the input data, including 'targets' and 'targets_segmentation'. + config: The model and training configuration. + is_train: A boolean indicating if the model is in training mode. + + Returns: + A tuple (total_loss, total_z_loss). + """ + labels = data["targets"] + segmentation = data["targets_segmentation"] + deterministic = not config.enable_dropout if is_train else True + model_mode = "train" + + hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length", "activation_embed"), + ) + label_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length"), + ) + reshaped_hidden_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + reshaped_data_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence"), + ) + chunked_hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + chunked_data_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence",), + ) + chunked_logits_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_vocab"), + ) + + _maybe_shard_with_name = functools.partial( + maybe_shard_with_name, + shard_mode=config.shard_mode, + debug_sharding=config.debug_sharding, + extra_stack_level=1, + ) + + def _reshape(inputs, out_shape, out_sharding): + reshape_out_sharding = out_sharding if config.shard_mode == ShardMode.EXPLICIT else None + inputs = jax.lax.reshape(inputs, out_shape, out_sharding=reshape_out_sharding) + return _maybe_shard_with_name(inputs, out_sharding) + + hidden_states = _maybe_shard_with_name(hidden_states, hidden_spec) + labels = _maybe_shard_with_name(labels, label_spec) + segmentation = _maybe_shard_with_name(segmentation, label_spec) + + batch_size, seq_len, emb_dim = hidden_states.shape + vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + + reshaped_hidden_states = _reshape( + hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec + ) + reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + + def _scan_body(accumulators, chunk_data): + loss_accumulator, z_loss_accumulator = accumulators + hidden_chunk, label_chunk, segmentation_chunk = chunk_data + hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) + label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) + segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) + + chunk_logits = model.logits_from_hidden_states(hidden_chunk, deterministic, model_mode) + chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec) + one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size) + chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits( + chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier + ) + + masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0)) + masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0)) + + return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None + + initial_acc = (jnp.zeros((), dtype=hidden_states.dtype), jnp.zeros((), dtype=hidden_states.dtype)) + (total_loss, total_z_loss), _ = jax.lax.scan( + _scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + ) + return total_loss, total_z_loss diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index 3495b4c557..f532820f86 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -154,13 +154,6 @@ def test_indexer_dense_warmup_skips_xent(self): self.assertEqual(float(aux["xent_sum"]), 0.0) self.assertEqual(float(loss), 0.0) - def test_vocab_tiling_raises_not_implemented(self): - cfg, ts = _build_state() - cfg.num_vocab_tiling = 4 - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) - class TestTrainStepNNX(unittest.TestCase): """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" From 968a6a10df5027c26f52c7825142b182d586acab Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 29 Apr 2026 16:07:35 +0000 Subject: [PATCH 4/5] NNX: native DPO (TrainStateNNX.reference_model + dpo_loss_fn_nnx) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements NNX-native DPO so that the pure_nnx=True training path no longer raises NotImplementedError on use_dpo runs. The Linen DPO overlay pattern (model.apply(params=..., reference_params=...)) does not translate to NNX modules, which carry their parameters internally. Instead the policy and reference models are held as separate nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx runs both forwards with stop_gradient on the reference logits. TrainStateNNX: - Add optional `reference_model: nnx.Module` field. apply_gradients continues to update only `self.model`, leaving `self.reference_model` bit-identical across steps. dpo_utils.py: - Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True). Signature mirrors the Linen dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's dispatcher (dropout_rng / params slots are unused for NNX; carried for parity, and reference_model is passed as the single extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over the policy, no gradient flows to the reference model's nnx.Param leaves; the explicit jax.lax.stop_gradient on ref_logits is a belt-and-braces guard. - Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include indexer_loss=0.0 and mtp_loss=0.0 in aux so the gradient_accumulation aux pytree shape matches the non-DPO loss_fn. train.py: - Drop the NotImplementedError in train_step's NNX branch. When use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as extra_dpo_args; otherwise use the regular loss_fn. eval_step gains the same dispatch. - diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init block, so both the GA and non-GA NNX paths route DPO identically. - Checkpoint-save _split_dpo_state stripping is now Linen-only; TrainStateNNX saves whole (reference_model included) — the step-0 reload later overwrites reference_model from the step-0 checkpoint. train_utils.py: - NNX init_state_fn materializes a frozen reference_model alongside the policy when config.use_dpo. Both are constructed by _create_model_partial() with config.init_weights_seed, so they start identical (standard DPO practice) until the step-0 reload. - Step-0 checkpoint reload: copy step0_state["model"] into state["reference_model"]. Linen path unchanged. Tests: - New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX reference_model init/hasattr semantics; apply_gradients leaves reference bit-identical; aux key set; identical policy/reference yields loss=log(2) and reward_accuracy=0.0 (strict > on equal logratios); dropout_rng/params slots are signature-compat only; nnx.value_and_grad(argnums=0) over the policy yields finite grads on policy params only. - train_nnx_test.py: drop the two stale negative tests (vocab_tiling_raises_not_implemented, train_step_dpo_raises_for_nnx) — both features are now real. Stats: 4 source files + 2 test files, +199/-22 source lines. Linen DPO path behaviorally unchanged (only adds two harmless aux-dict keys); NNX non-DPO path unchanged (all changes gated on config.use_dpo). --- src/maxtext/layers/train_state_nnx.py | 24 +- .../trainers/post_train/dpo/dpo_utils.py | 139 +++++++++++ src/maxtext/trainers/pre_train/train.py | 34 +-- src/maxtext/utils/train_utils.py | 24 +- .../integration/setup_train_loop_nnx_test.py | 9 - tests/unit/dpo_nnx_test.py | 215 ++++++++++++++++++ tests/unit/train_nnx_test.py | 10 - 7 files changed, 412 insertions(+), 43 deletions(-) create mode 100644 tests/unit/dpo_nnx_test.py diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py index 9ef0e6dffd..3f9ee1ce29 100644 --- a/src/maxtext/layers/train_state_nnx.py +++ b/src/maxtext/layers/train_state_nnx.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" The NNX Unified TrainState. """ +"""The NNX Unified TrainState.""" from typing import Any @@ -25,20 +25,34 @@ class TrainStateNNX(nnx.Module): This replaces Linen's TrainState for checkpointing. Linen TrainState pytree: - {“params”: {...}, “opt_state”: {}...} + {"params": {...}, "opt_state": {}...} TrainStateNNX state pytree: - {“model”: {...}, “optimizer”: {“opt_state”: {...}} + {"model": {...}, "optimizer": {"opt_state": {...}}} + + For DPO (Direct Preference Optimization), an optional `reference_model` + carries a frozen copy of the same architecture used to compute reference + log-probabilities. Only `model` is updated by `apply_gradients`; the + reference is held alongside so it is sharded, jit-traced, and checkpointed + with the rest of the train state. """ - def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): + def __init__( + self, + model: nnx.Module, + optimizer: nnx.Optimizer | None, + reference_model: nnx.Module | None = None, + ): self.model = model self.optimizer = optimizer + if reference_model is not None: + self.reference_model = reference_model def apply_gradients(self, grads: Any): """ Mimics the Linen apply_gradients function. Updates the optimizer state, applies updates to parameters, - and increments the step counter. + and increments the step counter. Only updates `self.model`; + `self.reference_model` (if present) is left untouched. """ if self.optimizer is None: raise RuntimeError( diff --git a/src/maxtext/trainers/post_train/dpo/dpo_utils.py b/src/maxtext/trainers/post_train/dpo/dpo_utils.py index eeda1c1a7f..fd5faa5c9c 100644 --- a/src/maxtext/trainers/post_train/dpo/dpo_utils.py +++ b/src/maxtext/trainers/post_train/dpo/dpo_utils.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp +from flax import nnx + from maxtext.utils import maxtext_utils @@ -148,6 +150,8 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility } return loss, aux @@ -155,3 +159,138 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t def _merge_dpo_state(state, reference_params): """Merge reference parameters back into DPO state.""" return state.replace(params=dict(state.params, reference_params=reference_params)) + + +# NNX DPO has no split/merge counterpart: the Linen path overlays +# `reference_params` inside `state.params`, so it must be peeled off and +# reattached around `apply_gradients`. The NNX path holds the reference as a +# sibling field `TrainStateNNX.reference_model`; `apply_gradients` already +# only touches `self.model`, so no split/merge is needed. + + +def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """NNX DPO loss_fn for both train and eval. + + Signature mirrors the Linen `dpo_loss_fn` so it slots into the same + dispatcher in `gradient_accumulation_loss_and_grad`: + `(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)` + + Differences from the Linen `dpo_loss_fn`: + * `policy_model` is an `nnx.Module` (carries its own params + RNG state). + * `dropout_rng` and `params` are unused for NNX (kept positional for + signature parity; NNX models manage these internally). + * The 6th arg (the `extra_dpo_args[0]`) is a frozen reference + `nnx.Module`, not a `reference_params` pytree. + * Reference forward is wrapped in `jax.lax.stop_gradient`; combined with + `nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows + to the reference's `nnx.Param` leaves. + + Args: + policy_model: Policy `nnx.Module` (the model being trained). + config: Config of parameters. + data: Batch of preference data with `chosen` / `rejected` fields. + dropout_rng: Unused for NNX (kept for signature parity with Linen). + params: Unused for NNX (kept for signature parity with Linen). + reference_model: Frozen reference `nnx.Module` for DPO logratio computation. + is_train: True for train_step and False for eval_step. + + Returns: + loss: DPO preference loss + MoE load balance loss (if applicable). + aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss, + total_weights, moe_lb_loss, reward_accuracy. + """ + del dropout_rng, params # unused for NNX + # decimate proportion of data when per_device_batch_size<1 + if is_train: + for k, v in data.items(): + data[k] = v[: config.micro_batch_size_to_train_on, :] + + # for DPO we don't support packed sequences (they shouldn't be present in the first place) + data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32) + data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32) + data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1) + data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1) + + # concatenated policy/reference forward pass + inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0) + inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0) + inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0) + + logits = policy_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=config.enable_dropout if is_train else False, + ) + intermediate_outputs = nnx.state(policy_model, nnx.Intermediate).to_pure_dict() + + ref_logits = reference_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=False, + ) + ref_logits = jax.lax.stop_gradient(ref_logits) + + # extract token ids, segmentation and logits for chosen and rejected sequences + chosen_ids = data["chosen"][..., 1:] + rejected_ids = data["rejected"][..., 1:] + chosen_segmentation = data["chosen_segmentation"][..., 1:] + rejected_segmentation = data["rejected_segmentation"][..., 1:] + n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab] + chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :] + chosen_ref_logits, rejected_ref_logits = ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :] + + # common subsequence and padding mask + common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S] + valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S] + + # compute logratios from the sequence-reduced observed token log-probability + chosen_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_logratios = chosen_logps - chosen_ref_logps # [B] + + rejected_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_logratios = rejected_logps - rejected_ref_logps # [B] + + # DPO loss from chosen and rejected logratios + LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, config.dpo_beta + logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B] + losses = ( # [B] + -jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING) + - jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING + ) + total_loss, total_weights = jnp.mean(losses), losses.shape[0] + loss = total_loss + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios) + aux = { + "intermediate_outputs": intermediate_outputs, + "xent_sum": 0.0, # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility + "dpo_loss": total_loss, # pure preference loss before MoE lb, analogous to lm_loss in pre-training + "total_weights": total_weights, + "moe_lb_loss": moe_lb_loss, + "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility + } + return loss, aux diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 585886b682..b997ecc7dc 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -61,7 +61,7 @@ from maxtext.common.gcloud_stub import vertex_tensorboard_modules from maxtext.common import metric_logger from maxtext.common.metric_logger import record_activation_metrics -from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn +from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn, dpo_loss_fn_nnx from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging @@ -320,15 +320,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat params = state.params ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args else: - if config.use_dpo: - raise NotImplementedError( - "DPO is not yet supported for NNX modules. DPO requires a reference model " - "stored alongside the policy model (Linen path uses state.params['reference_params']); " - "the NNX TrainState equivalent has not been wired up. As a workaround, set " - "pure_nnx=False for DPO runs." - ) state = nnx.merge(model, state) # reconstruct TrainStateNNX - ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] + if config.use_dpo: + # NNX DPO: reference_model is a sibling field on TrainStateNNX (set up by + # init_initial_state when config.use_dpo=True). dpo_loss_fn_nnx mirrors + # the Linen dpo_loss_fn signature, so it slots into the same dispatcher + # with reference_model passed as the single extra_dpo_args entry. + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = (dpo_loss_fn_nnx, state.model, None, None, [state.reference_model]) + else: + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] # --- Gradient computation --- if config.gradient_accumulation_steps > 1: @@ -394,9 +394,14 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ) nnx.update(state.model, curr_params) + # `ga_fn` and `ga_dpo` were set up earlier (loss_fn vs dpo_loss_fn_nnx; + # ga_dpo carries the frozen reference_model when use_dpo, else empty). + _nnx_loss_fn = ga_fn + _nnx_extra_dpo_args = ga_dpo + def diff_wrapper(param, rest, config, data): local_model = nnx.merge(model_graphdef, param, rest, copy=True) - loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + loss, aux = _nnx_loss_fn(local_model, config, data, None, None, *_nnx_extra_dpo_args, is_train=True) _, _, new_rest = nnx.split(local_model, nnx.Param, ...) return loss, (aux, new_rest) @@ -576,7 +581,10 @@ def eval_step(model, config, state, data, dropout_rng=None): loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) else: state = nnx.merge(model, state) # reconstruct TrainStateNNX - loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) + if config.use_dpo: + loss, aux = dpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + else: + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -702,7 +710,7 @@ def train_loop(config, recorder, state=None): step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): @@ -746,7 +754,7 @@ def train_loop(config, recorder, state=None): metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) if checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index ca90550630..80229b05be 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -225,10 +225,16 @@ def setup_train_loop(config, recorder, devices=None): if config.pure_nnx: # For NNX, the train state is wrapped in the TrainStateNNX module. + # When DPO is enabled, also materialize a frozen reference model alongside + # the policy. Both are constructed by `_create_model_partial()` (which uses + # `config.init_weights_seed`), so the reference starts identical to the + # policy — standard DPO practice. The reference is later overwritten by + # the step-0 checkpoint in `setup_post_setup_state` below. def create_train_state_fn(): model = _create_model_partial() optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) - return train_state_nnx.TrainStateNNX(model, optimizer) + reference_model = _create_model_partial() if config.use_dpo else None + return train_state_nnx.TrainStateNNX(model, optimizer, reference_model=reference_model) init_state_fn = create_train_state_fn else: @@ -316,8 +322,6 @@ def create_train_state_fn(): maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: - if config.pure_nnx: - raise NotImplementedError("DPO is not supported yet by NNX models.") abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" @@ -342,9 +346,17 @@ def create_train_state_fn(): except FileNotFoundError: step0_restored = None if step0_restored is not None: - # TODO: For pure_nnx, the dpo state manipulation is different. - reference_params = step0_restored["items"].params["params"] - state = _merge_dpo_state(state, reference_params) + if config.pure_nnx: + # step0_restored["items"] is the flat nnx.State of the step-0 TrainStateNNX + # (typically from a non-DPO pre-training run, so its top-level fields are + # `model` and `optimizer` — no `reference_model`). Copy its `model` substate + # into our current state's `reference_model` slot. + step0_state = step0_restored["items"] + step0_model_substate = step0_state["model"] if "model" in step0_state else step0_state + state["reference_model"] = step0_model_substate + else: + reference_params = step0_restored["items"].params["params"] + state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py index d11f9658a7..05a7fcffec 100644 --- a/tests/integration/setup_train_loop_nnx_test.py +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -126,15 +126,6 @@ def test_pure_nnx_setup_param_only_split_matches_model(self): del model - def test_pure_nnx_dpo_raises_not_implemented(self): - """The use_dpo branch (train_utils.py:319-320) must raise for NNX.""" - # use_dpo requires a few prerequisites; the simplest is to set the flag and - # let setup_train_loop reach the NotImplementedError check before the more - # involved DPO path runs. - config = _tiny_nnx_pyconfig(use_dpo=True, packing=False) - with self.assertRaises(NotImplementedError): - setup_train_loop(config, recorder=None) - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/dpo_nnx_test.py b/tests/unit/dpo_nnx_test.py new file mode 100644 index 0000000000..461c3cb2aa --- /dev/null +++ b/tests/unit/dpo_nnx_test.py @@ -0,0 +1,215 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NNX DPO unit tests. + +Covers the NNX-native DPO surface: + * `TrainStateNNX(model, optimizer, reference_model=...)` — reference model + sits alongside policy and is not touched by `apply_gradients`. + * `dpo_loss_fn_nnx(policy, config, data, None, None, reference, is_train)` — + aux structure, identical-model invariant (loss = log(2), reward_accuracy = 0.5). +""" + +import math +import types +import unittest + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx +from maxtext.trainers.post_train.dpo import dpo_utils + + +class _MockTransformer(nnx.Module): + """Tiny NNX transformer-shaped module for DPO tests. + + Accepts the same keyword args that `dpo_loss_fn_nnx` passes: + `decoder_input_tokens`, `decoder_positions`, `decoder_segment_ids`, + `enable_dropout`. Other args are tolerated via **kwargs. + """ + + def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs) + self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions=None, + decoder_segment_ids=None, + enable_dropout=False, + **kwargs, + ): + del decoder_positions, decoder_segment_ids, enable_dropout, kwargs + return self.proj(self.embed(decoder_input_tokens)) + + +def _make_dpo_config(**overrides): + """Build the minimal config surface that `dpo_loss_fn_nnx` reads.""" + base = { + "dpo_label_smoothing": 0.0, + "dpo_beta": 0.1, + "enable_dropout": False, + "num_experts": 1, + "micro_batch_size_to_train_on": 2, + } + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _make_dpo_batch(batch_size=2, seq_len=5): + """Build a tiny DPO-shaped batch. + + `chosen` and `rejected` share the first 2 tokens (common prefix is masked + out in the loss), differ at positions 2 and 3, and are padded at position 4. + """ + chosen = jnp.array([[1, 2, 3, 4, 0]] * batch_size, dtype=jnp.int32) + rejected = jnp.array([[1, 2, 5, 6, 0]] * batch_size, dtype=jnp.int32) + positions = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32), (batch_size, 1)) + segmentation = jnp.array([[1, 1, 1, 1, 0]] * batch_size, dtype=jnp.int32) + return { + "chosen": chosen, + "rejected": rejected, + "chosen_position": positions, + "rejected_position": positions, + "chosen_segmentation": segmentation, + "rejected_segmentation": segmentation, + } + + +class TestTrainStateNNXWithReferenceModel(unittest.TestCase): + """`TrainStateNNX(reference_model=...)` semantics.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(1)) + self.tx = optax.adam(1e-3) + + def test_init_with_reference(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + self.assertIs(state.model, self.policy) + self.assertIs(state.reference_model, self.reference) + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_reference_omits_attribute(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer) + self.assertFalse(hasattr(state, "reference_model")) + + def test_apply_gradients_does_not_touch_reference(self): + """Gradient update on policy must leave reference model bit-identical.""" + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + + ref_kernel_before = jnp.asarray(state.reference_model.proj.kernel.value).copy() + + def policy_loss(m): + return jnp.mean(m(jnp.array([[1, 2]])) ** 2) + + grads = nnx.grad(policy_loss)(state.model) + state.apply_gradients(grads) + + ref_kernel_after = jnp.asarray(state.reference_model.proj.kernel.value) + self.assertTrue(jnp.array_equal(ref_kernel_before, ref_kernel_after)) + + +class TestDPOLossFnNNX(unittest.TestCase): + """`dpo_loss_fn_nnx` numerical and structural sanity checks.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + # Reference initialized with the same seed to make policy and reference + # bit-identical at construction time. + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.config = _make_dpo_config() + self.data = _make_dpo_batch() + + def test_aux_has_expected_keys(self): + _, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + expected_keys = { + "intermediate_outputs", + "xent_sum", + "dpo_loss", + "total_weights", + "moe_lb_loss", + "reward_accuracy", + "indexer_loss", + "mtp_loss", + } + self.assertEqual(set(aux.keys()), expected_keys) + self.assertEqual(aux["xent_sum"], 0.0) + self.assertEqual(aux["moe_lb_loss"], 0.0) # num_experts=1 + self.assertEqual(aux["total_weights"], self.data["chosen"].shape[0]) + + def test_identical_policy_and_reference_yields_log2_loss(self): + """When policy == reference, all logratios are 0; with label_smoothing=0 + the per-example loss is `-log(sigmoid(0)) = log(2)`. `reward_accuracy` + uses strict `chosen > rejected`, so equal logratios score 0.0 (no example + is strictly preferred). + """ + loss, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + self.assertAlmostEqual(float(loss), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["dpo_loss"]), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["reward_accuracy"]), 0.0, places=4) + + def test_dropout_rng_and_params_args_are_unused(self): + """The 4th and 5th positional args are signature-compat slots for the + Linen dispatcher; passing arbitrary values must not affect the result. + """ + loss_a, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + loss_b, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, + self.config, + dict(self.data), + jax.random.PRNGKey(123), # dropout_rng — unused + {"params": "garbage"}, # params — unused + self.reference, + is_train=True, + ) + self.assertAlmostEqual(float(loss_a), float(loss_b), places=6) + + def test_value_and_grad_argnums0_only_diffs_policy(self): + """`nnx.value_and_grad(..., argnums=0)` over the policy should produce + finite grads on policy params and not require reference grads. + """ + + def _loss(policy_module): + loss, _ = dpo_utils.dpo_loss_fn_nnx( + policy_module, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + return loss + + grad_fn = nnx.value_and_grad(_loss, argnums=0) + loss, grads = grad_fn(self.policy) + self.assertTrue(jnp.isfinite(loss)) + # Grads is an nnx.State of the policy's nnx.Param leaves; check at least one + # leaf is finite and non-trivially shaped. + leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index f532820f86..4340d4e22a 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -174,16 +174,6 @@ def test_train_step_returns_state_and_metrics(self): self.assertIn("learning/param_norm", metrics["scalar"]) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) - def test_train_step_dpo_raises_for_nnx(self): - cfg, ts = _build_state() - cfg.use_dpo = True - state_graphdef, state_pure = nnx.split(ts) - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.train_step( - state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data - ) - def test_train_step_increments_optimizer_step(self): cfg, ts = _build_state() state_graphdef, state_pure = nnx.split(ts) From 4dc3ae2b9eee2ccb55c7b3a3640676b669001d27 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 5 May 2026 20:53:02 +0000 Subject: [PATCH 5/5] NNX: native MaxEngine inference (drop route-to-Linen path in maxengine.py) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR5 audited maxengine.py and routed the inference path to the Linen implementation regardless of pure_nnx, with a comment block explaining that "the flag affects training, not inference serving." That kept the Linen serving path unchanged but meant pure_nnx=True users silently got the Linen engine. This change replaces the route with a real NNX flow: when config.pure_nnx=True, the engine builds an NNX Transformer, splits out (params, cache, rest) with nnx.split, and at every JIT body merges the model concretely with nnx.merge to run the forward pass. Linen is preserved byte-for-byte; every NNX edit is gated `if config.pure_nnx:` and pure_nnx=False is still the default. maxengine.py (__init__): - Build two abstract NNX Transformers on the NNX path: self.model with model_mode=PREFILL (batch=1, single padded prompt) and self.model_ar with model_mode=AUTOREGRESSIVE (batch=micro_batch_size_to_train_on, decode_state shape). Both are needed because NNX cache vars inherit CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and bulk_insert searches for the substring "cache_batch" in the AR-mode logical-axes tuple. nnx.eval_shape is called directly inside nn_partitioning.axis_rules rather than through create_nnx_abstract_model to avoid the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm" (same reason get_abstract_state_nnx avoids set_mesh). - Cache the graphdef from a 3-way nnx.split(Param, Cache, ...) so JIT bodies can pass (params, cache, rest) separately to nnx.merge. The rest slot (RNG vars etc.) is materialized concretely in load_params. maxengine.py (cache adapter + _nnx_run_model): - bulk_insert / _insert_jit / _maybe_*_prefill_result_cache walk the cache via tree_map_with_path and switch on path[-1].key (the cache variable name like "cached_prefill_key"). Linen mutable cache is a plain nested dict. NNX Cache state would expose a ".value" accessor at that position. Bridge via nnx.State.to_pure_dict() (after the model run) and nnx.replace_by_pure_dict (before nnx.merge), so the cache plumbing helpers see the same shape on both paths. - Add _nnx_run_model: nnx.merge(graphdef, params, cache, rest, copy=True) -> model(...) -> nnx.state(model, nnx.Cache).to_pure_dict(). copy=True avoids reusing Variable objects across traces (TraceContextError), mirroring train.py's diff_wrapper workaround. - Add _nnx_cache_state_template / _nnx_init_cache_dict helpers parametrised by mode so prefill (batch 1) and decode_state (batch N) pull from the right abstract model. maxengine.py (load_params): - New _load_params_nnx: accepts user-provided NNX-shape params or loads via from_pretrained. For user-provided params, materializes a concrete model once via _create_model_fn() to capture a real rest state for nnx.merge (wasteful but simple; the from_pretrained branch avoids this). Refreshes self.graphdef from the concrete model so subsequent merges line up exactly. - Builds self.abstract_params, populates self.prefill_kv_cache_annotations and self.kv_cache_annotations (using model_ar for the latter so bulk_insert's substring lookup hits), wraps both into NamedSharding. - pure_nnx + quantization, pure_nnx + LoRA, pure_nnx + stack_prefill_result_cache=True, pure_nnx + prefill_multisampling, and pure_nnx + prefill_concat raise NotImplementedError for now; the Linen path is the workaround. AOT compilation (aot_compile / _compile_generate_and_get_layouts) is not gated and may work as-is; not exercised by tests yet. maxengine.py (init_decode_state, _prefill_jit, _generate_jit): - _init_decode_state_nnx zero-initializes a pure-dict cache from model_ar (so the leading batch dim matches generate's input shape) and builds kv_cache_annotations_named per leaf by reading nnx.Cache.metadata. Tries "out_sharding", "sharding", and "sharding_names" because Flax 0.12.6 renamed these. - _prefill_jit / _generate_jit add an `if config.pure_nnx:` branch that calls _nnx_run_model in place of self.model.apply with mutable=["cache"]. existing_prefix.cache is threaded as a pure-dict cache directly (no params|{"cache":...} dict-merge — params is an nnx.State, not a dict). maxtext_utils.py: - New get_prefill_kv_cache_annotations_nnx / get_kv_cache_annotations_nnx that mirror the Linen helpers' return shape (per-leaf PartitionSpec tree). Both delegate to _nnx_cache_partition_specs which extracts nnx.Cache state via nnx.split, calls get_nnx_named_sharding_with_scan_axis inside nn_partitioning.axis_rules so logical axes ("layers", "cache_batch", "norm", ...) resolve to physical mesh axes, and converts the result to a pure-dict tree. tests/unit/maxengine_test.py: - New tests: test_init_nnx, test_basic_prefill_nnx (with NaN/inf and per-layer cache shape checks), test_basic_decode_nnx (4-step generate with next_pos advancement check), test_quantize_raises_for_nnx, test_lora_raises_for_nnx. - New test_linen_nnx_parity_prefill: bridges Linen-init params into the NNX engine via linen_nnx_converter (convert_linen_to_nnx -> _strip_value_wrappers -> nnx.replace_by_pure_dict) and asserts the NNX engine's prefill matches Linen on the same weights — logits within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses bf16 compute) and exact greedy first-token argmax. - Existing Linen tests untouched. Test summary: 9 passed, 1 skipped (test_chunked_prefill is a pre-existing CPU-only skip). bash lint.sh: codespell + pylint + pyink all green. --- src/maxtext/inference/maxengine/maxengine.py | 317 +++++++++++++++++-- src/maxtext/utils/maxtext_utils.py | 24 ++ tests/unit/maxengine_test.py | 170 ++++++++++ 3 files changed, 478 insertions(+), 33 deletions(-) diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index c00f475e8d..5bd220f4e1 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -32,6 +32,7 @@ from jax.experimental.layout import DeviceLocalLayout as DLL # type: ignore from flax import linen as nn +from flax import nnx from flax import struct from flax.linen import partitioning as nn_partitioning import flax @@ -44,8 +45,10 @@ from maxtext.inference.page_manager import PageManager, PageState from maxtext.multimodal import processor as mm_processor from maxtext.utils import lora_utils +from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from maxtext.common.gcloud_stub import jetstream, is_decoupled from maxtext.common.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE @@ -112,11 +115,32 @@ def __init__(self, config: Any, devices: Any | None = None): self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) # Model and Optimizer definition. - # MaxEngine uses Linen-shaped state (state.params, state_mesh_shardings.params, - # state.opt_state) and serves Linen-format inference checkpoints. Use Linen path - # regardless of pure_nnx — the flag affects training, not inference serving. quant = quantizations.configure_quantization(config) - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + if config.pure_nnx: + # We need both PREFILL and AR abstract models because the cache vars inherit + # CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and + # bulk_insert searches for the substring "cache_batch" in the AR-mode names. + # Calling nnx.eval_shape directly (instead of create_nnx_abstract_model) avoids + # the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm". + _create_model = model_creation_utils.get_nnx_create_model_fn(config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL) + _create_model_ar = model_creation_utils.get_nnx_create_model_fn( + config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE + ) + with nn_partitioning.axis_rules(config.logical_axis_rules): + abstract_model = nnx.eval_shape(_create_model) + abstract_model_ar = nnx.eval_shape(_create_model_ar) + self.model = abstract_model + self.model_ar = abstract_model_ar + # 3-way split so JIT bodies can pass (params, cache, rest) separately to + # nnx.merge. `rest` (RNG state etc.) is materialized in load_params. + graphdef, _, _, _ = nnx.split(abstract_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._create_model_fn = _create_model + self._nnx_rest_state = None + else: + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.graphdef = None + self._create_model_fn = None self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -142,6 +166,65 @@ def print_stats(self, label: str): max_utils.print_mem_stats(label) max_utils.print_cpu_ram_stats(label) + # NNX cache adapter: bulk_insert / _insert_jit / _maybe_stack_* switch on + # path[-1].key (e.g. "cached_prefill_key"). NNX state would expose ".value" at + # that position, so we convert NNX state <-> plain dict at the JIT boundary + # via to_pure_dict / replace_by_pure_dict. The cache helpers stay unchanged. + + def _nnx_cache_state_template(self, mode: str = MODEL_MODE_PREFILL) -> Any: + """Empty nnx.State template for the model's nnx.Cache vars (PREFILL=batch 1, AR=batch N).""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + return cache_state + + def _nnx_init_cache_dict(self, mode: str = MODEL_MODE_PREFILL) -> dict: + """Zero-filled pure-dict cache matching the abstract NNX model.""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + cache_dict = cache_state.to_pure_dict() + return jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict) + + def _nnx_run_model( + self, + params, + cache_dict, + decoder_input_tokens, + decoder_positions, + *, + decoder_segment_ids=None, + enable_dropout=False, + model_mode, + previous_chunk=None, + true_length=None, + slot=None, + page_state=None, + encoder_images=None, + encoder_image_masks=None, + encoder_audios=None, + ): + """NNX equivalent of `model.apply(..., mutable=["cache"])`. Returns (logits, new_cache_dict).""" + cache_state = self._nnx_cache_state_template(mode=model_mode) + nnx.replace_by_pure_dict(cache_state, cache_dict) + # copy=True avoids reusing Variable objects across traces (TraceContextError), + # mirroring the workaround in train.py's diff_wrapper. + model = nnx.merge(self.graphdef, params, cache_state, self._nnx_rest_state, copy=True) + logits = model( + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + encoder_images=encoder_images, + encoder_image_masks=encoder_image_masks, + encoder_audios=encoder_audios, + enable_dropout=enable_dropout, + model_mode=model_mode, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_cache = nnx.state(model, nnx.Cache).to_pure_dict() + return logits, new_cache + def generate_aot( self, params: Params, decode_state: DecodeState, rng: PRNGKeyType | None = None ): # returns (new_decode_state, result_tokens) @@ -225,6 +308,9 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + return self._load_params_nnx(params=params, rng=rng) + if self.model.quant and self.config.checkpoint_is_quantized: print("Loading from the quantized checkpoint...") self.model.quant.quant_mode = quantizations.get_quant_mode("serve") @@ -284,11 +370,80 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar return params + def _load_params_nnx(self, params, rng): + """NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings.""" + if self.model.quant is not None: + raise NotImplementedError("pure_nnx + quantization not yet supported. Use pure_nnx=False.") + + if params: + print("Resharding given NNX params") + _, params_abs, _ = nnx.split(self.model, nnx.Param, ...) + target_shardings = jax.tree.map( + lambda x: x.sharding if hasattr(x, "sharding") else None, + params_abs, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + params_state = jax.device_put(params, target_shardings) + # Build a concrete model once to capture a real `rest` (RNG vars) for nnx.merge. + # Wasteful but simple — the from_pretrained branch below avoids this. + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + concrete_model = self._create_model_fn() + graphdef, _, _, rest_state = nnx.split(concrete_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del concrete_model + else: + max_logging.log("Loading NNX params via from_pretrained") + with self._mesh: + nnx_model = model_creation_utils.from_pretrained( + self.config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE + ) + # Refresh graphdef from the concrete loaded model so subsequent merges line up. + graphdef, params_state, _, rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del nnx_model + + self.abstract_params = jax.tree.map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + if isinstance(x, jax.Array) + else None, + params_state, + ) + + self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations_nnx( + self.model, self.config, self._mesh + ) + self.prefill_kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.prefill_kv_cache_annotations, + ) + if self.config.stack_prefill_result_cache: + # With scan_layers=True the NNX cache leaves are already stacked on axis 0, + # so the engine's manual-stack helper (which assumes an unstacked Linen tree) + # doesn't apply. Wiring this up cleanly is a Phase-2 follow-up. + raise NotImplementedError("pure_nnx + stack_prefill_result_cache=True not yet supported.") + # AR-mode abstract model so axis names use CACHE_BATCH (not CACHE_BATCH_PREFILL); + # bulk_insert / _insert_jit search for "cache_batch" in the per-leaf logical axes. + self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations_nnx(self.model_ar, self.config, self._mesh) + self.kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.kv_cache_annotations, + ) + # state_mesh_annotations is unused on the NNX path; callers reading it + # (e.g. set_engine_vars_from_base_engine) need to be NNX-aware first. + self.state_mesh_annotations = None + + self.print_stats("After load_params (NNX)") + return params_state + def load_single_adapter(self, adapter_path): """ Load Single adapter from adapter_path. Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`. """ + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + LoRA not yet supported. Use pure_nnx=False.") adapter_config_path = os.path.join(adapter_path, "adapter_config.json") adapter_weights_path = os.path.join(adapter_path, "0", "items") @@ -324,6 +479,8 @@ def quantize_params(self, state, rng: PRNGKeyType | None = None): """Forward pass to quantize decode params.""" if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + quantize_params not yet supported.") self.model.quant.quant_mode = quantizations.get_quant_mode("convert") @@ -478,7 +635,10 @@ def _prefill_jit( if existing_prefix is not None: if not self.use_chunked_prefill: raise ValueError("Using chunked prefill is needed for existing_prefix.") - input_params = params | {"cache": existing_prefix.cache} + # NNX threads existing_prefix.cache via the nnx_cache local below; only + # the Linen path merges cache into input_params (params is a dict there). + if not self.config.pure_nnx: + input_params = params | {"cache": existing_prefix.cache} start_position = existing_prefix.common_prefix_tokens.shape[0] # TODO(yuyanpeng): rename previous_chunk previous_chunk = jnp.expand_dims(existing_prefix.common_prefix_tokens, 0) @@ -510,24 +670,48 @@ def _prefill_jit( sequence_indicator = jnp.expand_dims(one_d_output, 0) rng, new_rng = jax.random.split(rng) - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - flat_logits, new_vars = self.model.apply( - input_params, - input_tokens, - positions, - encoder_images=images, - encoder_image_masks=image_masks, - encoder_audios=audio_values, - decoder_segment_ids=sequence_indicator, - enable_dropout=False, - model_mode=MODEL_MODE_PREFILL, - rngs={"params": new_rng}, - mutable=["cache"], - previous_chunk=previous_chunk, - true_length=true_length, - slot=slot, - page_state=page_state, + if self.config.pure_nnx: + # Prefill always operates on batch=1 (one padded prompt at a time). + nnx_cache = ( + existing_prefix.cache if existing_prefix is not None else self._nnx_init_cache_dict(mode=MODEL_MODE_PREFILL) ) + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_cache_dict = self._nnx_run_model( + params=input_params, + cache_dict=nnx_cache, + decoder_input_tokens=input_tokens, + decoder_positions=positions, + decoder_segment_ids=sequence_indicator, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_vars = self.model.apply( + input_params, + input_tokens, + positions, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + decoder_segment_ids=sequence_indicator, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + rngs={"params": new_rng}, + mutable=["cache"], + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) if return_prompt_logp: prompt_logp = inference_utils.prompt_logprobs_from_prefill(flat_logits, input_tokens, true_length) else: @@ -736,6 +920,9 @@ def _prefill_multisampling_jit( prefilling stage. The number of tokens is specified by num_samples. """ + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_multisampling not yet supported. Use pure_nnx=False.") + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0) @@ -861,6 +1048,9 @@ def prefill_concat( if existing_prefix: raise ValueError("We don't know what to do with existing_prefix") + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_concat not yet supported. Use pure_nnx=False.") + if rng is None: rng = jax.random.PRNGKey(0) input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] @@ -1030,17 +1220,30 @@ def _generate_jit( previous_token = decode_state["tokens"] rng, new_rng = jax.random.split(rng) # run one step generation - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - out_logits, new_vars = self.model.apply( - params | {"cache": decode_state["cache"]}, - previous_token, - decode_state["next_pos"], - enable_dropout=False, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - rngs={"params": new_rng}, - mutable=["cache"], - page_state=page_state, - ) + if self.config.pure_nnx: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_cache_dict = self._nnx_run_model( + params=params, + cache_dict=decode_state["cache"], + decoder_input_tokens=previous_token, + decoder_positions=decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_vars = self.model.apply( + params | {"cache": decode_state["cache"]}, + previous_token, + decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + rngs={"params": new_rng}, + mutable=["cache"], + page_state=page_state, + ) out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding) new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings) # sampling tokens @@ -1598,6 +1801,9 @@ def init_decode_state( if self.config.attention == "paged" and self.page_manager is not None: page_state = self.page_manager.get_initial_page_state() # pytype: disable=attribute-error + if self.config.pure_nnx: + return self._init_decode_state_nnx(rng=rng, page_state=page_state) + # pylint: disable=unused-argument def init(abstract_params, page_state): x = jnp.ones( @@ -1691,6 +1897,51 @@ def is_lp(k): zeroed = max_utils.unbox_logicallypartioned(init_state) return zeroed + def _init_decode_state_nnx(self, rng, page_state) -> DecodeState: + """NNX equivalent of init_decode_state. Returns a decode_state dict with a pure-dict cache.""" + del rng, page_state # cache shape comes from the abstract model + batch = int(self.config.per_device_batch_size * self.mesh.size) + vocab = self.config.vocab_size + + # AR-mode cache so the batch dim matches generate's input shape. + cache_dict_abs = self._nnx_init_cache_dict(mode=MODEL_MODE_AUTOREGRESSIVE) + + @functools.partial(jax.jit, out_shardings=(self.kv_cache_shardings,)) + def _init_cache(): + return (jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict_abs),) + + (cache,) = _init_cache() + + # Per-leaf logical axes for bulk_insert's "cache_batch" lookup. Use model_ar + # so segment_id leaves carry CACHE_BATCH (under PREFILL they'd carry + # CACHE_BATCH_PREFILL, which doesn't contain the "cache_batch" substring). + _, cache_state, _ = nnx.split(self.model_ar, nnx.Cache, ...) + + def _logical_axes_for(var): + # Flax 0.12.6 renamed "sharding" to "out_sharding"; older code may still + # use "sharding_names". Try all three. + meta = var.get_metadata() if hasattr(var, "get_metadata") else {} + out = meta.get("out_sharding") or meta.get("sharding") or meta.get("sharding_names") + if out is None: + return () + return (out,) if isinstance(out, str) else tuple(out) + + annotations_state = jax.tree.map( + _logical_axes_for, + cache_state, + is_leaf=lambda v: isinstance(v, nnx.Variable), + ) + self.kv_cache_annotations_named = annotations_state.to_pure_dict() + + return { + "logits": jnp.zeros((batch, 1, vocab), dtype=jnp.float32), + "cache": cache, + "next_pos": jnp.zeros((batch, 1), dtype=jnp.int32), + "generated_tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "token_logp": jnp.zeros((batch, 1), dtype=jnp.float32), + } + @property def max_concurrent_decodes(self) -> int: """Free slots.""" diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 1638dc8869..4b1f48e07e 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1722,6 +1722,30 @@ def init_kv_cache(model, config): return state_mesh_annotations +def _nnx_cache_partition_specs(abstract_model, config, mesh): + """Per-leaf PartitionSpec tree for the abstract model's nnx.Cache vars. + + Returned as a pure dict so the engine can wrap it in NamedSharding the same + way it does for the Linen helpers below. + """ + _, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...) + # get_nnx_named_sharding_with_scan_axis reads logical axis rules from the + # active flax partitioning context, so wrap. + with nn_partitioning.axis_rules(config.logical_axis_rules): + named_state = get_nnx_named_sharding_with_scan_axis(cache_state, mesh) + return jax.tree.map(lambda s: s.spec, named_state.to_pure_dict()) + + +def get_prefill_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_prefill_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + +def get_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + def save_quantized_checkpoint_if_configured(config, params): """Save quantized checkpoint if configured""" assert config.quantization, "quantization must be configured" diff --git a/tests/unit/maxengine_test.py b/tests/unit/maxengine_test.py index 944d34bfef..0148bccc3f 100644 --- a/tests/unit/maxengine_test.py +++ b/tests/unit/maxengine_test.py @@ -23,6 +23,8 @@ from jax.sharding import Mesh import numpy as np import pytest +from flax import nnx +from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL from maxtext.layers import quantizations @@ -30,7 +32,10 @@ pytest.importorskip("jetstream", reason="jetstream not installed") from maxtext.inference.maxengine import maxengine from maxtext.models import models +from maxtext.checkpoint_conversion import linen_nnx_converter +from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from tests.utils.test_helpers import get_test_config_path pytestmark = [pytest.mark.external_serving] @@ -162,6 +167,171 @@ def test_basic_decode(self): self.assertEqual(result_token.data.ndim, 2) self.assertEqual(result_token.data.shape[1], 3) + def _init_nnx_pyconfig(self, **kwargs): + """init_pyconfig with NNX flags on.""" + return self.init_pyconfig(pure_nnx=True, enable_nnx=True, pure_nnx_decoder=True, **kwargs) + + def _build_nnx_params(self, cfg, mesh): + """Materialize an NNX Transformer and return its nnx.Param state.""" + _create_model = model_creation_utils.get_nnx_create_model_fn(cfg, mesh=mesh, model_mode=MODEL_MODE_PREFILL) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + model = _create_model() + _, params_state, _ = nnx.split(model, nnx.Param, ...) + return params_state + + def test_init_nnx(self): + """NNX engine init exposes graphdef + abstract Transformer.""" + cfg = self._init_nnx_pyconfig() + engine = maxengine.MaxEngine(cfg, jax.devices()) + self.assertIsNotNone(engine.graphdef) + self.assertIsNotNone(engine.model) + self.assertEqual(type(engine.model).__name__, "Transformer") + + def test_basic_prefill_nnx(self): + """NNX prefill returns a Linen-shape result dict with finite values.""" + cfg = self._init_nnx_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + params_state = self._build_nnx_params(cfg, mesh) + + input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0]) + true_length = 4 + engine = maxengine.MaxEngine(cfg, jax.devices()) + params = engine.load_params(params=params_state) + prefill_result, first_token = engine.prefill(params=params, padded_tokens=input_tokens, true_length=true_length) + + self.assertEqual(prefill_result["generated_tokens"], jnp.array([0])) + self.assertEqual(prefill_result["tokens"].size, 1) + self.assertTrue(jnp.array_equal(first_token.data.size, 3)) + self.assertEqual(first_token.log_prob.shape, (1, 1)) + self.assertIn("cache", prefill_result) + self.assertIsInstance(prefill_result["cache"], dict) + # Catch silent NaN/inf from a bad nnx.merge or cache round-trip. + self.assertTrue(jnp.all(jnp.isfinite(prefill_result["logits"]))) + cache_leaves, _ = jax.tree.flatten(prefill_result["cache"]) + for leaf in cache_leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}") + # scan_layers=True (default in test config) ⇒ leading axis is num_decoder_layers. + for leaf in cache_leaves: + self.assertEqual(leaf.shape[0], cfg.num_decoder_layers, msg=f"layer-axis mismatch, got shape={leaf.shape}") + + def test_basic_decode_nnx(self): + """NNX prefill → insert → 4 generate steps. Verifies next_pos advances and logits stay finite.""" + cfg = self._init_nnx_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + params_state = self._build_nnx_params(cfg, mesh) + + input_tokens = jnp.array([1, 306, 5360, 304]) + engine = maxengine.MaxEngine(cfg, jax.devices()) + params = engine.load_params(params=params_state) + decode_state = engine.init_decode_state() + prefill_result, _ = engine.prefill(params=params, padded_tokens=input_tokens, true_length=4) + decode_state = engine.insert(prefill_result, decode_state, slot=0) + + # 4 steps is enough to catch off-by-one cache pointer bugs. + initial_next_pos = int(decode_state["next_pos"][0, 0]) + for step in range(4): + decode_state, result_token = engine.generate(params=params, decode_state=decode_state) + self.assertEqual(result_token.log_prob.ndim, 2) + self.assertEqual(result_token.log_prob.shape[1], 1) + self.assertEqual(result_token.data.ndim, 2) + self.assertEqual(result_token.data.shape[1], 3) + self.assertTrue(jnp.all(jnp.isfinite(decode_state["logits"]))) + self.assertEqual( + int(decode_state["next_pos"][0, 0]), + initial_next_pos + step + 1, + msg=f"next_pos didn't advance at step {step}", + ) + + def test_quantize_raises_for_nnx(self): + """pure_nnx + quantization raises NotImplementedError.""" + cfg = self._init_nnx_pyconfig(quantization="int8") + engine = maxengine.MaxEngine(cfg, jax.devices()) + with self.assertRaises(NotImplementedError): + engine.load_params(rng=self.rng) + + def test_lora_raises_for_nnx(self): + """pure_nnx + LoRA raises NotImplementedError.""" + cfg = self._init_nnx_pyconfig() + engine = maxengine.MaxEngine(cfg, jax.devices()) + with self.assertRaises(NotImplementedError): + engine.load_single_adapter("/nonexistent/adapter/path") + + def _linen_params_to_nnx_state(self, linen_params, abstract_nnx_model): + """Convert Linen params → NNX nnx.Param state via linen_nnx_converter so both engines share weights.""" + nnx_dict_wrapped = linen_nnx_converter.convert_linen_to_nnx({"params": linen_params}, scan_layers=True)["model"] + # pylint: disable=protected-access + nnx_pure = linen_nnx_converter._strip_value_wrappers(nnx_dict_wrapped) + _, params_state, _ = nnx.split(abstract_nnx_model, nnx.Param, ...) + nnx.replace_by_pure_dict(params_state, nnx_pure) + return params_state + + def test_linen_nnx_parity_prefill(self): + """Same weights → same prefill output across Linen and NNX engines. + + A failure here means the NNX forward pass diverges from Linen on identical + weights (cache plumbing, nnx.merge wiring, or Transformer.__call__). + """ + cfg_linen = self.init_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg_linen) + mesh = Mesh(devices_array, cfg_linen.mesh_axes) + + # Linen: init params, run prefill. + quant = quantizations.configure_quantization(cfg_linen) + linen_model = models.transformer_as_linen(config=cfg_linen, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + ids, decoder_segment_ids, decoder_positions = self.get_data() + linen_vars = linen_model.init( + {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + ) + # Linen.init wraps leaves in LogicallyPartitioned (which has a `.value` + # attribute); unbox so the converter's {value:} wrapper detector doesn't + # mistake them for already-wrapped NNX leaves. + linen_vars = max_utils.unbox_logicallypartioned(linen_vars) + + input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0]) + true_length = 4 + linen_engine = maxengine.MaxEngine(cfg_linen, jax.devices()) + linen_params = linen_engine.load_params(params=linen_vars) + linen_prefill, linen_first_token = linen_engine.prefill( + params=linen_params, padded_tokens=input_tokens, true_length=true_length + ) + + # NNX: bridge Linen weights, run prefill on the same prompt. + cfg_nnx = self._init_nnx_pyconfig() + nnx_engine = maxengine.MaxEngine(cfg_nnx, jax.devices()) + nnx_params_state = self._linen_params_to_nnx_state(linen_vars["params"], nnx_engine.model) + nnx_params = nnx_engine.load_params(params=nnx_params_state) + nnx_prefill, nnx_first_token = nnx_engine.prefill( + params=nnx_params, padded_tokens=input_tokens, true_length=true_length + ) + + # Tolerance is loose because the test config uses bf16 compute, where + # accumulation order between Linen-scan and NNX-scan drifts by ~0.05. + # Greedy match below is the behavioral check that actually matters. + linen_logits = np.asarray(linen_prefill["logits"]) + nnx_logits = np.asarray(nnx_prefill["logits"]) + self.assertEqual(linen_logits.shape, nnx_logits.shape) + np.testing.assert_allclose( + linen_logits, + nnx_logits, + rtol=0.05, + atol=0.1, + err_msg="Linen vs NNX prefill logits diverge beyond bf16 tolerance.", + ) + self.assertEqual( + int(linen_first_token.data[0, 0]), + int(nnx_first_token.data[0, 0]), + msg="Linen and NNX disagreed on greedy first token with identical weights.", + ) + linen_cache_leaves, _ = jax.tree.flatten(linen_prefill["cache"]) + nnx_cache_leaves, _ = jax.tree.flatten(nnx_prefill["cache"]) + self.assertEqual(len(linen_cache_leaves), len(nnx_cache_leaves)) + @pytest.mark.skip(reason="Can only pass on CPU.") def test_chunked_prefill(self): """Test identical result between chunked prefill with single and multiple chunked.