
Add checkpoint resharding script for faster loading#3801

Open
shuningjin wants to merge 1 commit into main from shuningjin-reshard

Conversation


@shuningjin shuningjin commented May 3, 2026

Description

This script re-shards a MaxText checkpoint on CPU. The goal is to pre-shard checkpoints (source) to accelerate loading on TPUs (target) by reducing re-sharding overhead.

FIXES: b/504714612

Introduction

Problem: In checkpoint conversion, we typically shard along the 0th dimension (usually the expert dimension for MoE). Consequently, loading is fast when the target sharding is EP (e.g., a few minutes), but noticeably slow for FSDP (e.g., an hour). This is a major bottleneck because FSDP is our most common use case.

Effectiveness: Our experiments show that pre-sharding a checkpoint to fsdp=16 reduces the loading time of DeepSeek-V3 from 60 minutes to 6 minutes on a v5p-128 cluster targeting fsdp=64. Furthermore, the solution scales efficiently to v7x 1k chips, maintaining a brief 10-minute load time.

Generalizability: While this was built to solve the FSDP loading bottleneck, the solution generalizes to pre-sharding checkpoints into other target sharding layouts.

Method

The Orbax checkpoint is streamed from storage directly into the target sharded layout on a simulated CPU mesh, and then saved to a new checkpoint.

Key operation trace: maxengine.load_params -> maxtext_utils.setup_decode_state -> checkpointing.load_params_from_path -> orbax.checkpoint.Checkpointer.restore
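The simulated-CPU-mesh idea can be sketched roughly as follows. This is a minimal, hypothetical illustration (not the PR's actual code; the device count, axis name, and array shapes are made up): force JAX to expose several virtual CPU devices, build a mesh over them, and place an array in the target sharded layout, which is what the script does to restored weights before re-saving them.

```python
# Minimal sketch (hypothetical, not the PR's script): create a simulated
# CPU mesh and place an array in the target sharded layout.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # simulate 4 CPU devices

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

jax.config.update("jax_platforms", "cpu")

# One mesh axis named "fsdp", spanning all simulated devices.
mesh = Mesh(np.array(jax.devices()).reshape(4), axis_names=("fsdp",))
target = NamedSharding(mesh, PartitionSpec("fsdp"))  # shard dim 0 across fsdp

# In the real script, arrays come out of the Orbax restore already in this
# layout; here we just device_put a dummy weight to show the result.
w = jax.device_put(np.zeros((8, 2), dtype=np.float32), target)
print(w.addressable_shards[0].data.shape)  # each device holds a (2, 2) shard
```

Because the mesh lives entirely on host CPUs, this works on a single machine with enough RAM, with no accelerators involved.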

User Guide

Full details are in the script's docstring.

Key Parameters:

  • --simulated_cpu_devices_count (defaults to 16). Examples:
    • Suitable for most cases: --simulated_cpu_devices_count=16 ici_fsdp_parallelism=16
    • More customization: --simulated_cpu_devices_count=32 ici_fsdp_parallelism=16 ici_expert_parallelism=2
  • weight_dtype: The dtype used to load and save the checkpoint. We highly recommend weight_dtype=bfloat16.
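As the examples above suggest, the simulated device count has to equal the product of the ici_*_parallelism values, or the CPU mesh cannot be formed. A quick sanity check (hypothetical helper, not part of the script):

```python
# Hypothetical sanity check: the simulated CPU mesh only fits when the
# device count equals the product of the ici parallelism values.
def mesh_fits(simulated_cpu_devices_count, **ici_parallelism):
  product = 1
  for value in ici_parallelism.values():
    product *= value
  return product == simulated_cpu_devices_count

print(mesh_fits(16, ici_fsdp_parallelism=16))                            # True
print(mesh_fits(32, ici_fsdp_parallelism=16, ici_expert_parallelism=2))  # True
print(mesh_fits(16, ici_fsdp_parallelism=8))                             # False
```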

Memory Requirements:

  • For X billion parameters, needs slightly over 2X GB RAM (each param takes 2 bytes with weight_dtype=bfloat16).
  • Note: We only hold one model copy in memory, as the re-sharding happens dynamically during the read operation. Additional buffer memory is needed mainly for the I/O streaming overhead, usually small compared to model weight.
  • Example: deepseek3 with MTP layers has 685B parameters, uses 1.37 TB for weights, and hits a peak RAM of ~1.45 TB (overhead is trivial relative to weight).
  • Example: deepseek2-16b, uses 32GB for weights, and hits a peak RAM of ~63 GB (overhead seems non-trivial, as the model size is small).
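The rule of thumb above reduces to simple arithmetic (illustrative only; the helper name is made up):

```python
# Back-of-the-envelope weight memory: X billion params * 2 bytes (bfloat16).
# 1e9 params * bytes_per_param / 1e9 bytes-per-GB cancels to a direct product.
def weight_memory_gb(params_billions, bytes_per_param=2):
  return params_billions * bytes_per_param

print(weight_memory_gb(685))  # deepseek3 + MTP: 1370 GB, i.e. ~1.37 TB
print(weight_memory_gb(16))   # deepseek2-16b: 32 GB
```

Peak RAM then adds the I/O streaming buffer on top of this, which the tests above measured at roughly 5-10% for the large model but proportionally more for the small one.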

Tests

deepseek3-671b with mtp

Full test details in b/504714612 (comment3, comment8)

  • pre-sharded with fsdp=16
    • conversion on CPU: time 134min, peak RAM 1486 GB.
    • The loading time is reduced to 6min (from 1hr), target sharding is fsdp=64 on v5p-128.
  • loading time on v7x with 1k chips is 10min

deepseek2-16b

Reshard:

# reshard CKPT1 to CKPT2 on CPU
python3 -m maxtext.checkpoint_conversion.reshard_checkpoint \
model_name=deepseek2-16b attention=dot_product mla_naive_kvcache=false \
scan_layers=True load_parameters_path=$CKPT1 \
base_output_directory=$CKPT2_DIR \
weight_dtype=bfloat16 \
checkpoint_storage_concurrent_gb=1024 checkpoint_storage_use_ocdbt=True checkpoint_storage_use_zarr3=True \
skip_jax_distributed_system=True ici_fsdp_parallelism=16 \
--simulated_cpu_devices_count=16

Inspect structure:

# CKPT1 (old)
ArrayMetadata :  name=params.params.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0,  directory=gs://ml-auto-solutions/output/unowned/maxtext_nightly_deepseek2-16b-v5p-8-2026-04-20-06-52-15/scanned/0/items,  shape=(64, 26, 2048, 1408),  sharding=NamedShardingMetadata(shape=[16], axis_names=['checkpoint_sharding_axis'], axis_types=(Auto,), partition_spec=('checkpoint_sharding_axis',)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), DeviceMetadata(id=1), DeviceMetadata(id=2), DeviceMetadata(id=3), DeviceMetadata(id=4), DeviceMetadata(id=5), DeviceMetadata(id=6), DeviceMetadata(id=7), DeviceMetadata(id=8), DeviceMetadata(id=9), DeviceMetadata(id=10), DeviceMetadata(id=11), DeviceMetadata(id=12), DeviceMetadata(id=13), DeviceMetadata(id=14), DeviceMetadata(id=15)]),  dtype=float16,  storage=StorageMetadata(chunk_shape=(4, 26, 2048, 1408), write_shape=(4, 26, 2048, 1408)),

# CKPT2 (new)
ArrayMetadata :  name=params.params.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0,  directory=gs://shuningjin-multipod-dev/conversion/ds2-fsdp-2026-05-03-10-58/0/items,  shape=(64, 26, 2048, 1408),  sharding=NamedShardingMetadata(shape=[ 1  1  1 16  1  1  1  1  1  1  1  1], axis_names=['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto), partition_spec=('expert', None, ['fsdp', 'tensor_transpose', 'context'], ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive'])) device_mesh=DeviceMetadataMesh(mesh=[[[[[[[[[[[[DeviceMetadata(id=0)]]]]]]]], [[[[[[[[DeviceMetadata(id=1)]]]]]]]], [[[[[[[[DeviceMetadata(id=2)]]]]]]]], [[[[[[[[DeviceMetadata(id=3)]]]]]]]], [[[[[[[[DeviceMetadata(id=4)]]]]]]]], [[[[[[[[DeviceMetadata(id=5)]]]]]]]], [[[[[[[[DeviceMetadata(id=6)]]]]]]]], [[[[[[[[DeviceMetadata(id=7)]]]]]]]], [[[[[[[[DeviceMetadata(id=8)]]]]]]]], [[[[[[[[DeviceMetadata(id=9)]]]]]]]], [[[[[[[[DeviceMetadata(id=10)]]]]]]]], [[[[[[[[DeviceMetadata(id=11)]]]]]]]], [[[[[[[[DeviceMetadata(id=12)]]]]]]]], [[[[[[[[DeviceMetadata(id=13)]]]]]]]], [[[[[[[[DeviceMetadata(id=14)]]]]]]]], [[[[[[[[DeviceMetadata(id=15)]]]]]]]]]]]]),  dtype=bfloat16,  storage=StorageMetadata(chunk_shape=(64, 26, 128, 1408), write_shape=(64, 26, 128, 1408)),

forward_pass_logit_checker, load with target sharding fsdp=16:

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@shuningjin shuningjin force-pushed the shuningjin-reshard branch from d549c06 to fa776ba Compare May 3, 2026 22:26
@shuningjin shuningjin changed the title reshard checkpoint Add checkpoint resharding script for faster loading May 3, 2026
@shuningjin shuningjin marked this pull request as ready for review May 3, 2026 23:22

github-actions Bot commented May 3, 2026

🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request introduces a new script reshard_checkpoint.py designed to re-shard MaxText checkpoints on CPU. This utility is highly effective for reducing checkpoint loading times on TPUs, as demonstrated by the significant performance gains reported for DeepSeek-V3. The PR also includes minor robustness improvements and bug fixes in llama_or_mistral_ckpt.py.

🔍 General Feedback

  • Performance: The reported 10x reduction in loading time (from 60 min to 6 min) for DeepSeek-V3 is a major improvement for large-scale model training and inference.
  • Initialization Timing: A key concern is the timing of JAX initialization in the new script. Setting environment variables like XLA_FLAGS after importing JAX-dependent modules may lead to them being ignored if the XLA backend has already been initialized.
  • Flexibility: Adding a way to specify or preserve the step_number would enhance the utility of the resharding script.

from maxtext.utils import max_utils, max_logging
from maxtext.common import checkpointing
from maxtext.checkpoint_conversion.utils.utils import print_peak_memory


🟡 The `step_number` is hardcoded to `0`. While this might be acceptable for a newly resharded checkpoint, it would be more flexible to allow users to specify the step number via a command-line argument, or attempt to preserve it from the source checkpoint if that information is available.
Suggested change
# Dummy configs for the checkpoint_manager
step_number = config.step if hasattr(config, 'step') else 0


@github-actions github-actions Bot left a comment


## 📋 Review Summary

Supplementing the previous review with the missed comment on JAX initialization timing. Overall, the PR is very valuable for optimizing large model checkpoints.

🔍 General Feedback

  • Initialization Timing: Setting XLA_FLAGS before JAX imports ensures the simulated CPU mesh is correctly established.

# Set JAX environment
jax.config.update("jax_platforms", "cpu")
# Simulate CPU devices as virtual mesh
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"

🟠 Setting `os.environ["XLA_FLAGS"]` after importing `jax` and other `maxtext` modules is risky. JAX or its dependencies might initialize the XLA backend upon import or during `jax.config.update`, making later changes to `XLA_FLAGS` ineffective. It is recommended to set these environment variables before any JAX-related imports.
Suggested change
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
import argparse
import os
import sys

if __name__ == "__main__":
  # Define local parser
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--simulated_cpu_devices_count",
      type=int,
      required=False,
      default=16,
      help="Number of simulated CPU devices for sharding the checkpoint",
  )
  # parse_known_args returns the namespace AND the list of remaining arguments
  local_args, remaining_args = parser.parse_known_args()
  # Set the JAX environment BEFORE any JAX imports if possible,
  # or at least before any JAX calls that might trigger backend initialization.
  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # TensorFlow logging verbosity (0 = show all)
  import jax
  from absl import app

  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  jax.config.update("jax_platforms", "cpu")
  # Reconstruct model_args (script name + the args MaxText needs);
  # `main` is defined earlier in the script.
  model_args = [sys.argv[0]] + remaining_args
  app.run(main, argv=model_args)
