
[JAX] Calculate seqlens and offsets in O(T) space instead of O(T*T) space for THD sequences #2522

Merged
KshitijLakhani merged 10 commits into NVIDIA:main from KshitijLakhani:klakhani/mem-optimize-seqlens-offsets-thd
May 1, 2026

Conversation

@KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented Dec 16, 2025

Description

The current mechanism in TE JAX attention for calculating THD seqlens and offsets materializes the full mask and then uses it to derive the seqlens and seqoffsets. That is O(T²) in space and can cause OOM failures when running larger sequences. This PR switches to an O(T*max_segments_per_seq) approach, allowing larger sequences to be processed in memory and bringing some perf advantages as well.

Benched on a standalone script: the current mask-based approach is O(T²) in
both FLOPS and intermediate memory, while the new approach is O(T · max_segments_per_seq).
At T=64k, the new approach needs ~1.3 MiB of scratch (vs the current ~64 MiB,
~50–800× less) and does ~3,000–5,000× fewer FLOPS per call. Trend holds across all
three mask families (top-left causal, bottom-right causal, padding-only) and
max_segments_per_seq from 16 to 256.

NOTE: These tests were run on a standalone microbench to give the user some estimate of the scope of the impact; the e2e impact may be much smaller.

Fixes #2700

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Calculate the seqlens and seqoffsets for THD layout without materializing a full mask.
  • This code path affects THD fused attn as well as THD CP P2P fused attn.
    • An alternative was also evaluated (via static code analysis and the standalone benchmark script): an O(T) approach for THD fused attn plus an O(T*max_segments_per_seq) approach for THD CP P2P. It did not benchmark enough better to justify duplicating and complicating the logic, so it was not chosen here (especially since there is already a separate fast causal path).
  • Route BRCM to the slow path, as that is the functionally correct path for it.
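The first bullet can be sketched roughly as follows. This is a NumPy illustration under assumed shapes, not the TE implementation: the old path materializes a T x T mask and reduces it, while the new path reads the per-segment lengths straight off segment_ids.

```python
import numpy as np

def seqlens_via_mask(segment_ids, max_seg):
    # Old path, O(T^2) scratch: materialize the full self-attention mask,
    # then reduce a row per segment to recover its length.
    seg = segment_ids
    mask = (seg[:, None] == seg[None, :]) & (seg[:, None] != 0)  # T x T
    lens = np.zeros(max_seg, dtype=np.int64)
    for s in range(1, max_seg + 1):
        tok = np.flatnonzero(seg == s)
        lens[s - 1] = mask[tok[0]].sum() if tok.size else 0
    return np.where(lens == 0, -1, lens), mask.size

def seqlens_via_bincount(segment_ids, max_seg):
    # New path, O(T) scratch: one bincount over segment ids; id 0 is padding.
    lens = np.bincount(segment_ids, minlength=max_seg + 1)[1:max_seg + 1]
    return np.where(lens == 0, -1, lens), segment_ids.size

seg = np.array([1, 1, 1, 2, 2, 0, 0, 0])   # two segments plus padding
old, old_elems = seqlens_via_mask(seg, max_seg=4)
new, new_elems = seqlens_via_bincount(seg, max_seg=4)
# Both give [3, 2, -1, -1]; scratch is T elements instead of T*T.
```

Unused segment slots come back as -1, matching the cuDNN convention the snippet later in this thread mentions.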

Benchmarking:

TL;DR: The new approach matches or beats current latency in our standalone bench while cutting XLA temp memory and per-call FLOPs by ~10³–10⁴× at T=128K. Trends hold across all three mask families (top-left causal, bottom-right causal, padding-only) and across max_segments_per_seq from 16 to 256.

The bench compiles each helper (current and new) as a standalone jax.jit program on randomized [B, T] segment inputs (T up to 128k, ~16 segments/row with some intra-segment padding) and reads three signals off the compiled object:

  • Memory and FLOPS/bytes come from compiled.memory_analysis() (buffer-assignment plan: arg/out/temp/total bytes) and compiled.cost_analysis() (HLO instruction counts) — both are static, post-compile reads that do not execute the kernel
  • Latency actually runs the compiled kernel (5x warmup + 20x timed, wall-clocked via time.perf_counter() with jax.block_until_ready() to sync the GPU)
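A minimal version of that harness might look like the sketch below. This is a hedged illustration of the measurement pattern described above: `helper` and the input shapes are placeholders, not the actual TE helpers.

```python
import time
import jax
import jax.numpy as jnp

def helper(segment_ids):
    # Placeholder workload standing in for the seqlens/offsets helper.
    return jnp.bincount(segment_ids.ravel(), length=17)

x = jnp.zeros((1, 1024), dtype=jnp.int32)
compiled = jax.jit(helper).lower(x).compile()

# Static, post-compile reads: no kernel execution happens here.
cost = compiled.cost_analysis()
if isinstance(cost, list):          # older JAX versions return a list of dicts
    cost = cost[0]
mem = compiled.memory_analysis()    # may be None on some backends

# Latency actually runs the kernel: warmup, then timed iterations.
for _ in range(5):
    compiled(x).block_until_ready()
times = []
for _ in range(20):
    t0 = time.perf_counter()
    compiled(x).block_until_ready()
    times.append(time.perf_counter() - t0)
median_us = sorted(times)[len(times) // 2] * 1e6
```

`block_until_ready()` is what makes the wall-clock numbers meaningful, since dispatch is otherwise asynchronous.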

Memory, work, and latency vs the prior O(T²) mask path

B=1, max_segments_per_seq=16, mixed-padding layout. NEW is the approach in this PR; OLD is what we replace.

Memory and FLOPs are reported at T=128K, since XLA analyzes those statically without running the actual function. Latency, however, requires running it, and on the arch I was using (GB200) 64k was the largest power of 2 that fit without OOM, so latency metrics are reported at that size.
I swept T over 2k, 8k, 16k, 64k, and 128k, but only the larger T values are posted below.

Memory @ T=128K (peak temp / total)
mask          OLD temp     NEW temp     OLD/NEW     OLD total    NEW total    OLD/NEW
padding       16.13 GiB    161.03 KiB    105,003x     16.13 GiB    673.32 KiB    25,113x
causal        32.25 GiB      1.03 MiB     31,939x     32.25 GiB      2.03 MiB    16,234x
bottom_right  32.13 GiB      2.66 MiB     12,379x     32.13 GiB      3.66 MiB     8,994x

Per-call work @ T=128K (cost_analysis; fusion-immune)
mask          OLD flops    NEW flops    OLD/NEW     OLD bytes    NEW bytes    OLD/NEW
padding       120.27 G     12.73 M       9,447x       34.90 GB      2.18 MB      16,018x
causal        171.81 G     27.56 M       6,234x       86.44 GB      9.13 MB       9,470x
bottom_right  171.82 G     35.72 M       4,810x       86.45 GB     18.15 MB       4,764x

Latency @ T=64K (median wall-clock)
mask          OLD μs        NEW μs      OLD/NEW
padding       46,526.29     214.69       217x
causal        57,549.62     314.46       183x
bottom_right  57,692.69     370.24       156x

Scaling with max_segments_per_seq (causal mask)

B=1, mixed-padding layout. NEW's cost is O(T · max_segments_per_seq); OLD is O(T²)
and independent of max_segments_per_seq. Sweep below shows that even at the high
end of typical max_seg values, the OLD/NEW gap stays in the 10³–10⁴× range.

Memory @ T=128K (peak temp / total)
max_seg   OLD temp     NEW temp    OLD/NEW    OLD total    NEW total    OLD/NEW
   16     32.25 GiB     1.03 MiB    31,939x     32.25 GiB     2.03 MiB     16,234x
   64     32.25 GiB     1.13 MiB    29,284x     32.25 GiB     2.13 MiB     15,514x
  256     32.25 GiB     1.50 MiB    21,976x     32.25 GiB     2.51 MiB     13,175x

Per-call work @ T=128K (cost_analysis; fusion-immune)
max_seg   OLD flops    NEW flops    OLD/NEW     OLD bytes    NEW bytes   OLD/NEW
   16     171.81 G      27.56 M       6,234x      86.44 GB      9.13 MB     9,470x
   64     171.81 G      59.02 M       2,911x      86.44 GB      9.33 MB     9,265x
  256     171.81 G     184.86 M         929x      86.44 GB     10.14 MB     8,523x

Latency @ T=64K (median wall-clock)
max_seg   OLD μs        NEW μs       OLD/NEW
   16     57,502.81      310.91       185x
   64     57,616.70      304.45       189x
  256     57,327.03      260.99       220x

Testing:

Pipeline 49970285 passes for all tests except the two sets of tests below:

Failures seen:

  • L0 Lint failure: fixed in 1e0380c
  • L2 failures are due to timeouts (but all attention tests within them pass), so I do not think this is a problem.
    Likely cause of the timeouts: [JAX] Fix MNIST L2 jax test instability #2933 (TBD). As a sanity check I ran the L2 attn tests on GB200x4 and saw no failures, so this should not be a blocker.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Dec 16, 2025
@KshitijLakhani KshitijLakhani force-pushed the klakhani/mem-optimize-seqlens-offsets-thd branch from 8a7da45 to 1e15b00 Compare March 18, 2026 22:31
@KshitijLakhani KshitijLakhani force-pushed the klakhani/mem-optimize-seqlens-offsets-thd branch from 1e15b00 to 7c891bd Compare April 16, 2026 20:42
KshitijLakhani and others added 3 commits April 24, 2026 21:46
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/mem-optimize-seqlens-offsets-thd branch from 7c891bd to 642a0d6 Compare April 24, 2026 21:47
KshitijLakhani and others added 3 commits April 27, 2026 23:14
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
…the seqoffsets calculation API

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

@KshitijLakhani KshitijLakhani marked this pull request as ready for review April 28, 2026 00:11
@KshitijLakhani KshitijLakhani changed the title [JAX] Calculate seqlens and offsets in O(N) space instead of O(N*N) space for THD sequences [JAX] Calculate seqlens and offsets in O(T) space instead of O(T*T) space for THD sequences Apr 28, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR replaces the O(T²) mask-materialization path in JAX THD attention with a new _get_seqlens_offsets_thd function that is O(T·max_segments_per_seq), using per-segment min/max aggregation via one-hot scatter rather than expanding the full Q×KV attention mask. It also correctly routes BRCM masks to this slow path by adding not attn_mask_type.is_bottom_right() to the fast-path condition, since BRCM semantics diverge from top-left causal in cross-attention (Q and KV on different CP ranks).

Confidence Score: 5/5

Safe to merge; logic is correct and well-tested by existing attention tests, with only a minor docstring error.

All findings are P2 (docstring wording). The core algorithmic logic for causal, BRCM, and padding-only cases is correct, the fast-path guard correctly excludes BRCM, and existing L2 attention tests pass.

No files require special attention beyond the minor docstring fix in transformer_engine/jax/attention.py.

Important Files Changed

Filename Overview
transformer_engine/jax/attention.py Replaces O(T²) mask-materialization slow path with a new O(T·max_segments_per_seq) aggregation in _get_seqlens_offsets_thd; adds not attn_mask_type.is_bottom_right() guard on the fast-path condition to correctly route BRCM to the slow path; the BRCM docstring has an inverted formula (>= min vs <= max), and _mask_to_seqlens_offset is now unreachable dead code.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["_segment_ids_pos_to_seqlens_offsets"] --> B{Fast path eligible?}
    B -- Yes --> C["_segment_ids_pos_to_seqlens_offsets_fast_causal_path\nO of T per row"]
    B -- No --> E["_get_seqlens_offsets_thd\nO of T times max_segments_per_seq per row"]
    E --> F{Mask type}
    F -- BRCM --> G["Scatter kv_key into per-segment columns\nmax-agg for Q-side, min-agg for KV-side"]
    F -- Causal --> H["Scatter kv_key into per-segment columns\nmin-agg for Q-side, max-agg for KV-side"]
    F -- Padding only --> I["one_hot presence check across segments"]
    G --> J["bincount for seqlens, argwhere for offsets"]
    H --> J
    I --> J
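The scatter-and-aggregate step in the flowchart can be sketched roughly like this. A NumPy illustration with assumed names and shapes (the real code operates on JAX arrays, per row and per mask type): each token is one-hot "scattered" into its segment's column, so the scratch is T x max_seg instead of T x T.

```python
import numpy as np

def per_segment_min_max(token_pos, segment_ids, max_seg):
    # One-hot scatter of each token into its segment's column: T x max_seg.
    one_hot = segment_ids[:, None] == np.arange(1, max_seg + 1)[None, :]
    pos = token_pos[:, None]
    big = np.iinfo(token_pos.dtype).max
    seg_min = np.where(one_hot, pos, big).min(axis=0)  # first token per segment
    seg_max = np.where(one_hot, pos, -1).max(axis=0)   # last token per segment
    return seg_min, seg_max

ids = np.array([1, 1, 2, 2, 2, 0])   # segment ids, 0 = padding
mn, mx = per_segment_min_max(np.arange(ids.size), ids, max_seg=3)
# Segment 1 spans tokens [0, 1], segment 2 spans [2, 4]; slot 3 is unused.
```

Min-aggregation recovers segment starts (hence offsets) and max-aggregation segment ends, which is why the causal and BRCM branches differ only in which side gets which reduction.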

@KshitijLakhani KshitijLakhani added 2.15.0 performance Performance issues labels Apr 28, 2026
Kshitij Janardan Lakhani and others added 2 commits April 30, 2026 17:48
…culation for THD non-cp and Cp p2p ring

    - Newer path is O(T*max_segments) per seq
    - Newer path works well with CP p2p ring

    Fix BRCM cross attn by routing to new slow path rather than fast causal path

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Collaborator

@cyanguwa cyanguwa left a comment


I feel the logic is a bit convoluted and expensive but if a customer is waiting on a hot fix, I've approved it. Please think about whether we can improve the logic or fuse these small kernels in the future. Thanks!

@KshitijLakhani
Collaborator Author

KshitijLakhani commented May 1, 2026

I feel the logic is a bit convoluted and expensive but if a customer is waiting on a hot fix, I've approved it. Please think about whether we can improve the logic or fuse these small kernels in the future. Thanks!

Thanks for the feedback! I agree.
The complexity comes from THD CP P2P sharing the same function as THD non-CP, with no quick, clean way to decouple the two.

For THD non-CP the logic is pretty straightforward, since it needs only the corresponding segment_ids (so the q seqlens need only the q segment_ids); this is shown in the snippet below.

However, CP needs both segment_ids and segment_pos, for both q and kv, to handle the rotating ring steps, which complicates the logic as it stands in this PR.

(The seqoffsets calculation is the same and trivial in both cases.)

I will try to split the logic in the future.

    # Step 1: mark which positions carry a real (non-padding) segment id.
    non_zero_mask = segment_ids != 0
    max_size = segment_ids.shape[-1]

    # Step 2: gather the indices of those valid positions into a dense prefix.
    #   - `non_zero_indices` pulls the i-th valid index to slot i; remaining
    #     slots are padded with -1.
    non_zero_indices = jax.vmap(
        lambda r: jnp.where(r, size=max_size, fill_value=-1)[0]
    )(non_zero_mask)

    # Step 3: use those indices (clipped to >= 0 to keep take_along_axis safe)
    # to pull the segment ids in compacted order. Where index was -1 we write
    # back 0 so bincount sees it as padding.
    clipped = jnp.clip(non_zero_indices, 0, None)
    gathered_ids = jnp.take_along_axis(segment_ids, clipped, axis=-1)
    valid_segment_ids = jnp.where(non_zero_indices >= 0, gathered_ids, 0)

    # Step 4: bincount counts how many tokens belong to each segment id.
    # bincount_raw[0] is "# of padding tokens"; [1:] is per-id length.
    bincount_raw = jax.vmap(
        lambda sp: jnp.bincount(sp, length=max_seg + 1)
    )(valid_segment_ids)
    seqlens_all = bincount_raw[..., 1:]

    # Step 5: express "segment slot k unused" as -1 (cuDNN convention).
    padded = jnp.where(seqlens_all == 0, -1, seqlens_all)
    return padded
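For completeness, the trivial seqoffsets step mentioned above amounts to an exclusive cumsum over the valid lengths. A NumPy sketch; the -1 tail for unused slots is assumed here to mirror the seqlens convention in the snippet, not taken from the TE code.

```python
import numpy as np

def seqlens_to_offsets(seqlens):
    # Treat -1 (unused slot) as length 0, then take an exclusive cumsum:
    # offsets[k] is where segment k starts in the packed THD token stream.
    valid = np.where(seqlens == -1, 0, seqlens)
    offsets = np.concatenate([[0], np.cumsum(valid)])
    n_valid = int((seqlens != -1).sum())
    offsets[n_valid + 1:] = -1   # mark slots past the last segment
    return offsets

offs = seqlens_to_offsets(np.array([3, 2, -1, -1]))
# Segments of length 3 and 2 start at tokens 0 and 3; total is 5.
```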

@KshitijLakhani KshitijLakhani merged commit 360779b into NVIDIA:main May 1, 2026
9 of 12 checks passed

KshitijLakhani added a commit that referenced this pull request May 1, 2026
…pace for THD sequences (#2522)

* Get seqlens and offsets in O(N) space instead of O(N*N) space

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Re enable fast causal path

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Fix: seqoffsets calculation for THD

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Clean up code. Add new comments. Fix unnecessary passing of seg pos to the seqoffsets calculation API

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Optimize and fix the slow O(T*T) path for seqlens and seqoffsets calculation for THD non-cp and Cp p2p ring
    - Newer path is O(T*max_segments) per seq
    - Newer path works well with CP p2p ring

    Fix BRCM cross attn by routing to new slow path rather than fast causal path

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix lint failure

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

---------

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kshitij  Janardan Lakhani <klakhani@login-ptyche02.ptyche.clusters.nvidia.com>
Co-authored-by: JAX Toolbox <jax@nvidia.com>
