[JAX] Calculate seqlens and offsets in O(T) space instead of O(T*T) space for THD sequences #2522
Conversation
/te-ci jax L0 L1 L2
Greptile Summary

This PR replaces the O(T²) mask-materialization path in JAX THD attention with a new O(T·max_segments_per_seq) seqlens/offsets calculation.

Confidence Score: 5/5. Safe to merge; the logic is correct and well-tested by existing attention tests, with only a minor docstring error. All findings are P2 (docstring wording). The core algorithmic logic for the causal, BRCM, and padding-only cases is correct, the fast-path guard correctly excludes BRCM, and existing L2 attention tests pass. No files require special attention beyond the minor docstring fix in transformer_engine/jax/attention.py.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["_segment_ids_pos_to_seqlens_offsets"] --> B{Fast path eligible?}
    B -- Yes --> C["_segment_ids_pos_to_seqlens_offsets_fast_causal_path\nO(T) per row"]
    B -- No --> E["_get_seqlens_offsets_thd\nO(T * max_segments_per_seq) per row"]
    E --> F{Mask type}
    F -- BRCM --> G["Scatter kv_key into per-segment columns\nmax-agg for Q-side, min-agg for KV-side"]
    F -- Causal --> H["Scatter kv_key into per-segment columns\nmin-agg for Q-side, max-agg for KV-side"]
    F -- Padding only --> I["one_hot presence check across segments"]
    G --> J["bincount for seqlens, argwhere for offsets"]
    H --> J
    I --> J
```
/te-ci jax L0 L1 L2
cyanguwa
left a comment
I feel the logic is a bit convoluted and expensive but if a customer is waiting on a hot fix, I've approved it. Please think about whether we can improve the logic or fuse these small kernels in the future. Thanks!
Thanks for the feedback! I agree. For THD non-CP the logic is pretty straightforward, as it needs only the corresponding […]. However, for CP it needs both […]. I will try to split the logic in the future.
[JAX] Calculate seqlens and offsets in O(T) space instead of O(T*T) space for THD sequences (#2522)

* Get seqlens and offsets in O(N) space instead of O(N*N) space
* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
* Re-enable fast causal path
* Fix: seqoffsets calculation for THD
* Clean up code. Add new comments. Fix unnecessary passing of seg pos to the seqoffsets calculation API
* Optimize and fix the slow O(T*T) path for seqlens and seqoffsets calculation for THD non-CP and CP p2p ring
  - Newer path is O(T*max_segments) per seq
  - Newer path works well with CP p2p ring
  - Fix BRCM cross attn by routing to new slow path rather than fast causal path
* Fix lint failure

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kshitij Janardan Lakhani <klakhani@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: JAX Toolbox <jax@nvidia.com>
Description
The current mechanism in TE JAX attention for calculating THD seqlens and offsets materializes the full mask and then uses it to derive the seqlens and seqoffsets. This is O(T²) in space and can cause OOM failures when running larger sequences. This PR moves to an O(T·max_segments_per_seq) approach, allowing larger sequences to be processed in memory, along with some perf advantages.
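The shape of the replacement can be sketched with plain NumPy (a minimal illustration of the O(T·max_segments_per_seq) idea, not the actual TE implementation; the function name and the 0-means-padding convention are assumptions here):

```python
import numpy as np

def seqlens_offsets_sketch(segment_ids, max_segments):
    """Derive per-segment seqlens and start offsets without a [T, T] mask.

    segment_ids: [T] int array; 1..max_segments label tokens, 0 marks padding.
    Scratch is the [T, max_segments] one-hot below, i.e. O(T * S), not O(T^2).
    """
    T = segment_ids.shape[0]
    seg_range = np.arange(1, max_segments + 1)
    one_hot = segment_ids[:, None] == seg_range[None, :]  # [T, S] presence
    seqlens = one_hot.sum(axis=0)                         # tokens per segment
    pos = np.arange(T)[:, None]
    # min-aggregate token positions per segment -> first-token offsets
    offsets = np.where(one_hot, pos, T).min(axis=0)       # T means "segment absent"
    return seqlens, offsets

ids = np.array([1, 1, 1, 2, 2, 0, 0, 3])
print(seqlens_offsets_sketch(ids, max_segments=3))
# seqlens -> [3 2 1], offsets -> [0 3 7]
```

The same one-hot trick underlies the padding-only branch in the flowchart above; the causal/BRCM branches additionally min/max-aggregate on the KV side.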
Benched on a standalone script: the current mask-based approach is O(T²) in
both FLOPS and intermediate memory, while the new approach is O(T · max_segments_per_seq).
At T=64k, the new approach needs ~1.3 MiB of scratch (vs the current ~64 MiB,
~50–800× less) and does ~3,000–5,000× fewer FLOPS per call. Trend holds across all
three mask families (top-left causal, bottom-right causal, padding-only) and
max_segments_per_seq from 16 to 256.
NOTE: These numbers come from a standalone microbenchmark intended to give a rough estimate of the scope of the impact; the end-to-end impact may be much smaller.
Fixes #2700
Type of change
Changes
Benchmarking:
TL;DR: The new approach matches or beats current latency in our standalone bench while cutting XLA temp memory and per-call FLOPs by ~10³–10⁴× at T=128K. Trends hold across all three mask families (top-left causal, bottom-right causal, padding-only) and across max_segments_per_seq from 16 to 256.
The bench compiles each helper (current and new) as a standalone jax.jit program on randomized [B, T] segment inputs (T up to 128k, ~16 segments/row with some intra-segment padding) and reads two signals off the compiled object:

- compiled.memory_analysis() (buffer-assignment plan: arg/out/temp/total bytes)
- compiled.cost_analysis() (HLO instruction counts)

Both are static, post-compile reads that do not execute the kernel.

Memory, work, and latency vs the prior O(T²) mask path
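The static reads in question look roughly like this (a hedged sketch of the bench harness, not the actual script; `seqlens_from_ids` is a hypothetical stand-in for the helper under test):

```python
import jax
import jax.numpy as jnp

def seqlens_from_ids(segment_ids, max_segments=16):
    # stand-in helper: O(T * max_segments) scratch via a one-hot presence check
    one_hot = segment_ids[:, None] == jnp.arange(1, max_segments + 1)[None, :]
    return one_hot.sum(axis=0)

ids = jnp.zeros((4096,), dtype=jnp.int32)
compiled = jax.jit(seqlens_from_ids).lower(ids).compile()

mem = compiled.memory_analysis()   # buffer plan: arg/out/temp/total bytes
cost = compiled.cost_analysis()    # HLO-level stats such as 'flops'
# Neither call runs the kernel; both are post-compile, static reads.
```

Note that memory_analysis() can return None on some backends, and the exact shape of cost_analysis()'s result has varied across JAX versions, so a real harness should guard both reads.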
B=1, max_segments_per_seq=16, mixed-padding layout. NEW is the approach in this PR; OLD is what we replace.
Memory and FLOPs are reported at T=128K, since XLA analyzes those statically without running the actual function. Latency, however, requires running the function, and on the arch I was using (GB200) 64k was the largest power of 2 that fit without OOM, so latency metrics are reported at that size.
I calculated these numbers for a sweep over 2k, 8k, 16k, 64k, and 128k, but am posting only the larger T values below.
Scaling with max_segments_per_seq (causal mask)
B=1, mixed-padding layout. NEW's cost is O(T · max_segments_per_seq); OLD is O(T²)
and independent of max_segments_per_seq. Sweep below shows that even at the high
end of typical max_seg values, the OLD/NEW gap stays in the 10³–10⁴× range.
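As a back-of-the-envelope check on that gap (idealized scratch-element counts only; real constants, dtypes, and XLA fusion change the absolute numbers):

```python
def mask_scratch_elems(T):
    # OLD: materialize the full [T, T] mask
    return T * T

def onehot_scratch_elems(T, S):
    # NEW: [T, max_segments_per_seq] per-segment scratch
    return T * S

T, S = 128 * 1024, 16
print(mask_scratch_elems(T) // onehot_scratch_elems(T, S))  # -> 8192, i.e. ~10^4x
```

The idealized ratio is simply T / max_segments_per_seq, which is why the gap shrinks, but stays large, as max_segments_per_seq grows.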
Testing:
Pipeline 49970285 passes for all tests except the two sets of tests below:
Failures seen:
Likely cause of the timeout: [JAX] Fix MNIST L2 jax test instability #2933 (TBD). As a sanity check I ran the L2 attn tests on GB200x4 and saw no failures, so this should not be a blocker.
Checklist: