## Description

### Short Description
Introduces a token compressor for DeepSeek-V4 Heavily Compressed Attention (HCA) and Compressed Sparse Attention (CSA) pre-training inside MaxText. It implements block-wise sequence compression and overlapping causal padding.
### Problem Solved & Context
DeepSeek-V4 introduces low-rank token compression inside its attention block to reduce the memory footprint of the Key-Value cache. This PR implements the stateless pre-training math for both disjoint Heavily Compressed Attention (HCA, compression ratio $m'=128$) and causally overlapping Compressed Sparse Attention (CSA, compression ratio $m=4$).
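For intuition, the disjoint (HCA) case can be sketched as reshaping the sequence into non-overlapping blocks and reducing each block to one token. The sketch below is illustrative only: the function name is hypothetical and a plain mean stands in for the PR's learned low-rank, softmax-weighted reduction — it shows the shape contract, not the real math.

```python
import numpy as np

def compress_disjoint(x: np.ndarray, ratio: int) -> np.ndarray:
    """Illustrative block-wise compression: (B, S, D) -> (B, S // ratio, D).

    A plain mean over each block stands in for the module's learned
    projection and softmax weighting; only the shapes match the real layer.
    """
    B, S, D = x.shape
    if S % ratio != 0:
        # Mirrors the suite's test_non_divisible_sequence_length_fails case.
        raise ValueError(f"sequence length {S} not divisible by ratio {ratio}")
    return x.reshape(B, S // ratio, ratio, D).mean(axis=2)

x = np.ones((2, 256, 8), dtype=np.float32)
y = compress_disjoint(x, 128)
print(y.shape)  # (2, 2, 8)
```

With $m'=128$, a 256-token sequence compresses to 2 tokens per batch element, which is the kind of shape contract the `test_hca_compression_shape` test exercises.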
### Technical Implementation Details

- **Isomorphic JAX Math:** Uses `jnp.concatenate` with a `-1e4` causal gating pad (to nullify prior tokens on window 0). It concatenates the prior `Ca` and current `Cb` along the window axis to form a combined window of size 8, gates with the APE bias, softmaxes, and reduces.
- **Pre-Training Stability Upcast:** Upcasts to float32 before adding `self.ape_bias` and executing the `jax.nn.softmax` exponentiation, then casts the weights back to bfloat16. This prevents precision truncation and underflow inside the softmax denominator under bfloat16 training.
- **Causal Gating Padding (-1e4):** Window 0 has no valid prior context, so its prior-token slots are filled with `-1e4` before the softmax, driving their attention weights to effectively zero.
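A minimal sketch of the window math described above, assuming toy mean-derived scores (the real module derives scores from learned projections); `windowed_softmax_reduce`, `ca`, and `cb` are illustrative names, not the PR's actual API:

```python
import jax
import jax.numpy as jnp

def windowed_softmax_reduce(ca, cb, ape_bias):
    """Illustrative CSA-style window reduction.

    ca, cb: (B, W, 4, D) prior and current window features.
    ape_bias: (8,) additive position bias over the combined window.
    """
    # Concatenate prior and current along the window axis: combined window of 8.
    combined = jnp.concatenate([ca, cb], axis=2)          # (B, W, 8, D)
    scores = combined.mean(axis=-1)                       # toy scores, (B, W, 8)
    # Window 0 has no prior context: gate its 4 prior slots with -1e4.
    scores = scores.at[:, 0, :4].set(-1e4)
    # Upcast to float32 for bias + softmax, then cast the weights back down.
    weights = jax.nn.softmax(
        scores.astype(jnp.float32) + ape_bias.astype(jnp.float32), axis=-1
    ).astype(combined.dtype)
    return jnp.einsum("bwk,bwkd->bwd", weights, combined)

ca = jnp.ones((2, 3, 4, 8), dtype=jnp.bfloat16)
cb = jnp.ones((2, 3, 4, 8), dtype=jnp.bfloat16)
out = windowed_softmax_reduce(ca, cb, jnp.zeros((8,)))
print(out.shape)  # (2, 3, 8)
```

The float32 upcast matters because the softmax denominator sums eight exponentials; in bfloat16 the small terms can underflow or truncate, which is exactly what the stability upcast avoids.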
### Tests
This change was verified using a newly added, fully self-contained test suite at `tests/unit/deepseek_v4_vs_reference_test.py`, which performs four categories of systems and mathematical validation, including:

- **NNX state splitting (`nnx.split`):** Verifies that the module's static structure (`compress_ratio`, `overlap`) cleanly separates from the dynamic parameter arrays under `nnx.split` and `nnx.merge` without leaking JAX tracers, ensuring compatibility with checkpoints and parallel training loops.
- **PyTorch reference equivalence (`modeling_deepseek_v4.py`):** Collects PyTorch intermediate tensors on the fly and compares them step by step against our JAX layers, asserting that projections, softmax weights, and pre-norm weighted sums match numerically to the 5th decimal place under randomized inputs.

**Command to Reproduce:**
To run the entire test suite on CPU, execute the following command:

```shell
JAX_PLATFORMS=cpu pytest -v tests/unit/deepseek_v4_vs_reference_test.py
```

**Verification Output:**

```
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_csa_compression_shape PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_csa_pytorch_equivalence PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_extreme_batch_sizes PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_hca_compression_shape PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_hca_pytorch_equivalence PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_nnx_state_splitting PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_non_divisible_sequence_length_fails PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_rng_reproducibility PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_varying_hyperparameters PASSED
======================== 9 passed, 40 warnings in 7.05s ========================
```
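For reference, "match numerically to the 5th decimal place" amounts to an absolute-tolerance check at `1e-5`; a minimal stand-alone illustration (the arrays below are placeholders, not the suite's real PyTorch/JAX tensors):

```python
import numpy as np

# Stand-ins for a PyTorch reference tensor and the JAX layer's output.
reference = np.array([0.12345, 0.67890, 1.0], dtype=np.float32)
candidate = reference + 5e-6  # deviation safely below the 1e-5 budget

# rtol=0 makes the 5th-decimal-place budget the only criterion.
np.testing.assert_allclose(candidate, reference, atol=1e-5, rtol=0)
print("within 1e-5")
```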
### Checklist

Before submitting this PR, please make sure (put an X in the square brackets):

- [ ] `gemini-review` label.