
feat: Add DeepSeek-V4 Token Compressor#3810

Draft
parambole wants to merge 1 commit into main from parambole/dsv4-compressor

Conversation


@parambole commented May 4, 2026

Description

Short Description

Introduces the token compressor designed for DeepSeek-V4 Heavily Compressed Attention (HCA) and Compressed Sparse Attention (CSA) pre-training inside MaxText. It implements block-wise sequence compression with overlapping causal padding.

Problem Solved & Context

DeepSeek-V4 introduces low-rank token compression inside its attention block to reduce the memory footprint of the Key-Value cache. This PR implements the stateless pre-training math for both disjoint Heavily Compressed Attention (HCA, compression ratio $m'=128$) and causal overlapping Compressed Sparse Attention (CSA, compression ratio $m=4$).
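To make the KV-cache savings concrete, here is an illustrative sketch of sequence-length compression at the stated HCA ratio of $m'=128$; the shapes and the mean-reduce are hypothetical stand-ins for the PR's learned projection-and-softmax reduce, not its actual API.

```python
import jax.numpy as jnp

# Hypothetical illustration: compressing the KV sequence length by the
# HCA ratio m'=128 shrinks the cache footprint by the same factor.
seq, ratio, dim = 4096, 128, 64
kv = jnp.ones((seq, dim))
# Stand-in reduce (mean) in place of the learned softmax-weighted reduce.
compressed = kv.reshape(seq // ratio, ratio, dim).mean(axis=1)
print(compressed.shape)  # (32, 64)
```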

Technical Implementation Details

Isomorphic JAX Math:

  • HCA: Groups the sequence into disjoint chunks of size 128, applies low-rank projections, adds Absolute Positional Embedding (APE) sequence bias, softmaxes along the window axis, and reduces.
  • CSA: Projects to 2 * compressed_dim, splits Ca and Cb along the feature axis (axis -1), and chunks both with ratio 4. Ca is shifted by one chunk using a JAX-friendly jnp.concatenate with a -1e4 causal gating pad (nullifying the missing prior chunk on window 0). The prior Ca and current Cb are then concatenated along the window axis to form a combined window of size 8, gated with the APE bias, softmaxed, and reduced.
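The chunk/softmax/reduce pattern above can be sketched as follows. This is a minimal illustration, not the PR's code: `hca_chunk_reduce` and `csa_shift_prior` are hypothetical names, and the gate logits here are a trivial stand-in for the low-rank projections.

```python
import jax
import jax.numpy as jnp

def hca_chunk_reduce(x, ratio=128):
    # x: [batch, seq, dim]; seq must be divisible by ratio (disjoint windows).
    b, s, d = x.shape
    chunks = x.reshape(b, s // ratio, ratio, d)
    logits = chunks.sum(axis=-1)                   # stand-in gate logits
    w = jax.nn.softmax(logits, axis=-1)            # softmax along window axis
    return jnp.einsum("bnw,bnwd->bnd", w, chunks)  # weighted reduce per chunk

def csa_shift_prior(ca, pad_value=-1e4):
    # ca: [batch, n_chunks, win, dim]; shift right by one chunk so each
    # window sees its predecessor; window 0 gets the causal gating pad.
    pad = jnp.full_like(ca[:, :1], pad_value)
    return jnp.concatenate([pad, ca[:, :-1]], axis=1)

x = jnp.ones((2, 256, 8))
out = hca_chunk_reduce(x)
print(out.shape)  # (2, 2, 8)
```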

Pre-Training Stability Upcast:

  • Explicitly upcasts the gate logits to float32 before adding the self.ape bias and executing the jax.nn.softmax exponentiation, and casts the weights back to bfloat16 afterward. This prevents precision truncation and underflow inside the softmax denominator under bfloat16 training.
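A minimal sketch of this upcast pattern, assuming a bfloat16 logits tensor and an APE bias term (the function and argument names are illustrative, not the layer's actual interface):

```python
import jax
import jax.numpy as jnp

def stable_gate(logits_bf16, ape_bias):
    # Upcast to float32 BEFORE the bias add and softmax exponentiation,
    # avoiding precision truncation / underflow in the softmax denominator.
    logits = logits_bf16.astype(jnp.float32) + ape_bias.astype(jnp.float32)
    w = jax.nn.softmax(logits, axis=-1)
    # Cast the attention weights back to bfloat16 for downstream matmuls.
    return w.astype(jnp.bfloat16)

logits = jnp.zeros((2, 4, 8), dtype=jnp.bfloat16)
bias = jnp.zeros((8,), dtype=jnp.float32)
w = stable_gate(logits, bias)
print(w.dtype)  # bfloat16
```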

Causal Gating Padding (-1e4):

  • Pads out-of-bounds historical tokens on window 0 with -1e4 instead of -inf. In FP32 softmax, $e^{-10000}$ underflows safely to exactly 0.0 (nullifying the pad), while completely shielding the JAX pre-training pass from NaN gradient collapses (-inf * 0 under XLA during backpropagation).
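The underflow behavior of the -1e4 pad can be checked directly; this small sketch shows that in FP32 the padded position contributes exactly zero weight while every quantity entering the graph stays finite:

```python
import jax
import jax.numpy as jnp

# In fp32, exp(-1e4) underflows cleanly to 0.0, so the padded position
# is fully nullified without introducing -inf into the computation.
logits = jnp.array([-1e4, 1.0, 2.0], dtype=jnp.float32)
w = jax.nn.softmax(logits)
print(float(w[0]))                    # 0.0 -- pad contributes nothing
print(bool(jnp.isfinite(w).all()))    # True -- no inf/NaN in the graph
```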

Tests

This change is verified by a newly added, fully self-contained test suite at tests/unit/deepseek_v4_vs_reference_test.py, which performs four categories of systems and mathematical validation:

  1. Static Shape & Edge Case Verification: Validates HCA and CSA output shapes across batch sizes (1 and 16) and varying embedding hyperparameters, and confirms that non-divisible sequence lengths raise errors.
  2. JAX NNX State Splitting (nnx.split): Verifies that the module's static structure (compress_ratio, overlap) separates cleanly from dynamic parameter arrays under nnx.split and nnx.merge without leaking JAX tracers, ensuring compatibility with checkpointing and parallel training loops.
  3. RNG Determinism: Asserts that parameter initialization is reproducible across parallel hosts given identical RNG seeds.
  4. Deep Intermediate Parity (Hugging Face modeling_deepseek_v4.py): Captures PyTorch intermediate tensors on the fly and compares them step by step against the JAX layers, asserting that projections, softmax weights, and pre-norm weighted sums agree numerically to the 5th decimal place under randomized inputs.
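The RNG determinism property (category 3) can be sketched in isolation with plain jax.random; this is an illustration of the invariant being tested, not the suite's actual test code:

```python
import jax

# Identical seeds must produce bit-identical parameter initializations,
# e.g. as two parallel hosts would observe.
key_a = jax.random.PRNGKey(42)
key_b = jax.random.PRNGKey(42)
params_a = jax.random.normal(key_a, (4, 8))
params_b = jax.random.normal(key_b, (4, 8))
print(bool((params_a == params_b).all()))  # True
```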

Command to Reproduce:
To run the entire test suite on CPU, execute the following command:

JAX_PLATFORMS=cpu pytest -v tests/unit/deepseek_v4_vs_reference_test.py
Verification Output:
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_csa_compression_shape PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_csa_pytorch_equivalence PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_extreme_batch_sizes PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_hca_compression_shape PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_hca_pytorch_equivalence PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_nnx_state_splitting PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_non_divisible_sequence_length_fails PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_rng_reproducibility PASSED
tests/unit/deepseek_v4_vs_reference_test.py::DeepseekV4VsReferenceTest::test_varying_hyperparameters PASSED
======================== 9 passed, 40 warnings in 7.05s ========================

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.
