From 1e855613d079951de6d77d7d9eb5be6e48d04bb2 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Tue, 16 Dec 2025 16:33:59 +0000 Subject: [PATCH 1/9] Get seqlens and offsets in O(N) space instead of O(N*N) space Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 112 ++++++++++++---------------- 1 file changed, 48 insertions(+), 64 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 29d0848381..fdd2d7c358 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -540,6 +540,44 @@ def run_length_fill(segment_ids) -> jnp.ndarray: run_length_segment_id_shape = jax.vmap(run_length_fill_flattened, in_axes=0)(segment_ids_flat) return run_length_segment_id_shape.reshape(orig_shape) +def _get_seqlens_thd(segment_ids, max_segments_per_seq): + # Create mask for non-zero seg ids and get the non-zero indices associated with the same + non_zero_mask = segment_ids != 0 + max_size = segment_ids.shape[-1] + non_zero_indices = jax.vmap( + lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] + )(non_zero_mask) + + # Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos + # Clip -1 to 0 for safe indexing + clipped_indices = jnp.clip(non_zero_indices, 0, None) + valid_segment_ids = jnp.where( + non_zero_indices >= 0, jnp.take_along_axis(segment_ids, clipped_indices, axis=-1), 0 + ) + seqlens_all = jax.vmap( + lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] + )(valid_segment_ids) + seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) + return seqlens_all_pad_neg + +def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): + segment_changes = jnp.concatenate( + [ + jnp.full( + (segment_pos.shape[0], 1), True, dtype=bool + ), # First valid element starts a segment + (segment_pos[..., 1:] != segment_pos[..., :-1] + 1), # Segment pos changed + ], + axis=-1, + ) + # Remove any padded region segment changes + segment_changes_masked = jnp.where(segment_ids != 0, segment_changes, False) + # Get the indices for segment changes (these are the offsets) + seq_offsets = jax.vmap( + lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + )(segment_changes_masked) + return seq_offsets + def _segment_ids_pos_to_seqlens_offsets( segment_ids_q, @@ -572,70 +610,16 @@ def _segment_ids_pos_to_seqlens_offsets( # Currently, this function is only exercised for THD qkv_layout. # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well - if (attn_mask_type.is_causal() and window_size is None) or ( - window_size == (-1, -1) and not attn_mask_type.is_bottom_right() - ): - return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( - segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq - ) - - # (1 = attend, 0 = masked) - segment_mask = make_attention_mask( - segment_ids_q, - segment_ids_kv, - jnp.equal, - ) - segment_mask_with_id = make_attention_mask( - segment_ids_q, - segment_ids_kv, - lambda x, y: jnp.equal(x, y) * x, - ) - # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied - attn_mask = segment_mask - if attn_mask_type.is_bottom_right(): - run_length_out_q = run_length_fill(segment_ids_q) - run_length_out_kv = run_length_fill(segment_ids_kv) - # Example for brcm: - # run_length_out_q: [3 3 3 0 4 4 4 4] - # segment_pos_q: [0 1 2 3 0 1 2 3] - # segment_ids_q: [1 1 1 0 2 2 2 2] - # run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10] - # segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9] - # segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2] - # brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] - # [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] - # [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]] - # attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] - # [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0] - # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]] - bottom_right_causal_mask = make_attention_mask( - run_length_out_q - segment_pos_q, - run_length_out_kv - segment_pos_kv, - jnp.less_equal, - ) - attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask) - elif attn_mask_type.is_causal(): - causal_mask = make_attention_mask( - segment_pos_q, - segment_pos_kv, - jnp.greater_equal, - ) - attn_mask = jnp.logical_and(segment_mask, causal_mask) - - attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) - q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( - attn_mask_with_id, max_segments_per_seq - ) + # if (attn_mask_type.is_causal() and window_size is None) or ( + # window_size == (-1, -1) and not attn_mask_type.is_bottom_right() + # ): + # return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( + # segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq + # ) + q_seqlen = _get_seqlens_thd(segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq) + kv_seqlen = _get_seqlens_thd(segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq) + q_offset = _get_seqoffsets_thd(segment_ids=segment_ids_q, segment_pos=segment_pos_q, max_segments_per_seq=max_segments_per_seq) + kv_offset = _get_seqoffsets_thd(segment_ids=segment_ids_kv, segment_pos=segment_pos_kv, max_segments_per_seq=max_segments_per_seq) return q_seqlen, kv_seqlen, q_offset, kv_offset From 11d2f82e1ebdc0b4c1205d604683a3b843feba23 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:43:20 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index fdd2d7c358..6f21a7d2d2 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -540,6 +540,7 @@ def run_length_fill(segment_ids) -> jnp.ndarray: run_length_segment_id_shape = jax.vmap(run_length_fill_flattened, in_axes=0)(segment_ids_flat) return run_length_segment_id_shape.reshape(orig_shape) + def _get_seqlens_thd(segment_ids, max_segments_per_seq): # Create mask for non-zero seg ids and get the non-zero indices associated with the same non_zero_mask = segment_ids != 0 @@ -556,10 +557,11 @@ def _get_seqlens_thd(segment_ids, max_segments_per_seq): ) seqlens_all = jax.vmap( lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] - )(valid_segment_ids) + )(valid_segment_ids) seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) return seqlens_all_pad_neg + def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): segment_changes = jnp.concatenate( [ @@ -577,7 +579,7 @@ def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] )(segment_changes_masked) return seq_offsets - + def _segment_ids_pos_to_seqlens_offsets( segment_ids_q, @@ -616,10 +618,22 @@ def _segment_ids_pos_to_seqlens_offsets( # return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( # segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq # ) - q_seqlen = _get_seqlens_thd(segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq) - kv_seqlen = _get_seqlens_thd(segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq) - q_offset = _get_seqoffsets_thd(segment_ids=segment_ids_q, segment_pos=segment_pos_q, max_segments_per_seq=max_segments_per_seq) - kv_offset = _get_seqoffsets_thd(segment_ids=segment_ids_kv, segment_pos=segment_pos_kv, max_segments_per_seq=max_segments_per_seq) + q_seqlen = _get_seqlens_thd( + segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq + ) + kv_seqlen = _get_seqlens_thd( + segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq + ) + q_offset = _get_seqoffsets_thd( + segment_ids=segment_ids_q, + segment_pos=segment_pos_q, + max_segments_per_seq=max_segments_per_seq, + ) + kv_offset = _get_seqoffsets_thd( + segment_ids=segment_ids_kv, + segment_pos=segment_pos_kv, + max_segments_per_seq=max_segments_per_seq, + ) return q_seqlen, kv_seqlen, q_offset, kv_offset From 642a0d6ccf54ccc01e8074b458e143cf02ac149e Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 18 Mar 2026 15:24:29 -0700 Subject: [PATCH 3/9] Re enable fast causal path Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 6f21a7d2d2..f298803cf7 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -612,12 +612,12 @@ def _segment_ids_pos_to_seqlens_offsets( # Currently, this function is only exercised for THD qkv_layout. # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well - # if (attn_mask_type.is_causal() and window_size is None) or ( - # window_size == (-1, -1) and not attn_mask_type.is_bottom_right() - # ): - # return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( - # segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq - # ) + if (attn_mask_type.is_causal() and window_size is None) or ( + window_size == (-1, -1) and not attn_mask_type.is_bottom_right() + ): + return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( + segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq + ) q_seqlen = _get_seqlens_thd( segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq ) From d83fa2afaeb8c3d547a71e8c998427a0b05d2860 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Mon, 27 Apr 2026 11:21:37 -0700 Subject: [PATCH 4/9] Fix: seqoffsets calculation for THD Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/attention.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index f298803cf7..a5aaa06567 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -563,12 +563,20 @@ def _get_seqlens_thd(segment_ids, max_segments_per_seq): def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): + # NOTE: we detect segment boundaries from segment_ids changes, not segment_pos gaps. + # Under Striped CP reorder (used by P2P ring attention for THD+BALANCED) segment_pos + # values become non-sequential within a single logical segment (e.g. [0,2,4,6] on + # rank 0, [1,3,5,7] on rank 1), which would make a pos-gap detector flag every step + # as a new segment. segment_ids, however, stay contiguous per rank under striping so + # id-change detection is both correct and reorder-invariant, matching the pre-O(N) + # mask-based path. + del segment_pos segment_changes = jnp.concatenate( [ jnp.full( - (segment_pos.shape[0], 1), True, dtype=bool + (segment_ids.shape[0], 1), True, dtype=bool ), # First valid element starts a segment - (segment_pos[..., 1:] != segment_pos[..., :-1] + 1), # Segment pos changed + (segment_ids[..., 1:] != segment_ids[..., :-1]), # Segment id changed ], axis=-1, ) From 03604ada57f13e1630f63539ca3cb27c6ba66858 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Mon, 27 Apr 2026 23:36:49 +0000 Subject: [PATCH 5/9] Clean up code. Add new comments. Fix unecessary pasing of seg pos to the seqoffsets calculation API Signed-off-by: Kshitij Lakhani --- transformer_engine/jax/attention.py | 39 +++++++++++++++++------------ 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index a5aaa06567..33072db51b 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -542,19 +542,29 @@ def run_length_fill(segment_ids) -> jnp.ndarray: def _get_seqlens_thd(segment_ids, max_segments_per_seq): - # Create mask for non-zero seg ids and get the non-zero indices associated with the same + """O(T) per-row segment-length computation for packed-THD layouts. + + Returns a [B, max_segments_per_seq] array whose k-th entry is + the length of the k-th segment in that row (or -1 if no k-th segment + exists). Valid segment ids are >= 1; id 0 is padding and is excluded from the counts. + """ + # Gather the indices of non-padding tokens per row into a dense prefix; + # slots past the last valid token are filled with -1. non_zero_mask = segment_ids != 0 max_size = segment_ids.shape[-1] non_zero_indices = jax.vmap( lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] )(non_zero_mask) - # Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos - # Clip -1 to 0 for safe indexing + # Materialise a padding-free view of segment_ids by gathering at + # non_zero_indices. Slots whose index was -1 are explicitly set + # to 0 so they end up in the id=0 bucket (that we drop below). clipped_indices = jnp.clip(non_zero_indices, 0, None) valid_segment_ids = jnp.where( non_zero_indices >= 0, jnp.take_along_axis(segment_ids, clipped_indices, axis=-1), 0 ) + # Per-row bincount of ids -> segment length, discarding the + # id=0 bucket (padding) and capping at max_segments_per_seq. seqlens_all = jax.vmap( lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] )(valid_segment_ids) @@ -562,15 +572,14 @@ def _get_seqlens_thd(segment_ids, max_segments_per_seq): return seqlens_all_pad_neg -def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): - # NOTE: we detect segment boundaries from segment_ids changes, not segment_pos gaps. - # Under Striped CP reorder (used by P2P ring attention for THD+BALANCED) segment_pos - # values become non-sequential within a single logical segment (e.g. [0,2,4,6] on - # rank 0, [1,3,5,7] on rank 1), which would make a pos-gap detector flag every step - # as a new segment. segment_ids, however, stay contiguous per rank under striping so - # id-change detection is both correct and reorder-invariant, matching the pre-O(N) - # mask-based path. - del segment_pos +def _get_seqoffsets_thd(segment_ids, max_segments_per_seq): + """O(T) per-row segment start-offset computation for packed-THD layouts. + + Returns a [B, max_segments_per_seq + 1] array whose k-th entry + is the starting index of the k-th segment in that row (or -1 if no k-th + segment exists). Boundaries are detected from segment_ids transitions + """ + # Detect segment boundaries from segment_ids changes segment_changes = jnp.concatenate( [ jnp.full( @@ -580,9 +589,9 @@ def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): ], axis=-1, ) - # Remove any padded region segment changes + # Remove any padded region segment changes (this also handles intra-segment padding correctly) segment_changes_masked = jnp.where(segment_ids != 0, segment_changes, False) - # Get the indices for segment changes (these are the offsets) + # Get the indices for segment changes (these are the start offsets) seq_offsets = jax.vmap( lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] )(segment_changes_masked) @@ -634,12 +643,10 @@ def _segment_ids_pos_to_seqlens_offsets( ) q_offset = _get_seqoffsets_thd( segment_ids=segment_ids_q, - segment_pos=segment_pos_q, max_segments_per_seq=max_segments_per_seq, ) kv_offset = _get_seqoffsets_thd( segment_ids=segment_ids_kv, - segment_pos=segment_pos_kv, max_segments_per_seq=max_segments_per_seq, ) return q_seqlen, kv_seqlen, q_offset, kv_offset From 66ddc786324dddb35b88309aa2a2ab6a7c22372a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Apr 2026 00:08:29 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 33072db51b..ee90bbbfc9 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -548,8 +548,8 @@ def _get_seqlens_thd(segment_ids, max_segments_per_seq): the length of the k-th segment in that row (or -1 if no k-th segment exists). Valid segment ids are >= 1; id 0 is padding and is excluded from the counts. """ - # Gather the indices of non-padding tokens per row into a dense prefix; - # slots past the last valid token are filled with -1. + # Gather the indices of non-padding tokens per row into a dense prefix; + # slots past the last valid token are filled with -1. non_zero_mask = segment_ids != 0 max_size = segment_ids.shape[-1] non_zero_indices = jax.vmap( From 423569a3bea28b3cba586cd1e0fc804d60d6cf5c Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Thu, 30 Apr 2026 17:48:55 -0700 Subject: [PATCH 7/9] Optimize and fix the slow O(T*T) path for seqlens and seqoffsets calculation for THD non-cp and Cp p2p ring - Newer path is O(T*max_segments) per seq - Newer path works well with CP p2p ring Fix BRCM cross attn by routing to new slow path rather than fast causal path Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/attention.py | 278 ++++++++++++++++++++-------- 1 file changed, 200 insertions(+), 78 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index ee90bbbfc9..0f60fc3f28 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -541,61 +541,146 @@ def run_length_fill(segment_ids) -> jnp.ndarray: return run_length_segment_id_shape.reshape(orig_shape) -def _get_seqlens_thd(segment_ids, max_segments_per_seq): - """O(T) per-row segment-length computation for packed-THD layouts. - - Returns a [B, max_segments_per_seq] array whose k-th entry is - the length of the k-th segment in that row (or -1 if no k-th segment - exists). Valid segment ids are >= 1; id 0 is padding and is excluded from the counts. +def _get_seqlens_offsets_thd( + segment_ids_q, + segment_ids_kv, + segment_pos_q, + segment_pos_kv, + attn_mask_type, + max_segments_per_seq, +): + """O(T * max_segments_per_seq) replacement for the older O(T^2) mask-based slow path. + Returns (q_seqlen, kv_seqlen, q_offset, kv_offset) values to match the reference older mask-based path: + segment_mask = make_attention_mask(q_ids, kv_ids, equal) + segment_mask_with_id = make_attention_mask(q_ids, kv_ids, equal * q_id) + attn_mask = segment_mask AND (causal_or_brcm_or_none) + attn_mask_with_id = where(attn_mask, segment_mask_with_id, 0) + row_ids = reduce_max(attn_mask_with_id, axis=kv) # [B, T_q] + col_ids = reduce_max(attn_mask_with_id, axis=q) # [B, T_kv] + seqlens/offsets = bincount(...) / find_offsets(...) + The two reductions are expressed equivalently as per-segment aggregates: + - causal: row_ids[q] = q_seg_id iff seg_pos_q[q] >= min(seg_pos_kv over same-seg KV) + - brcm: row_ids[q] = q_seg_id iff (run_len_q - seg_pos_q) >= + min(run_len_kv - seg_pos_kv over same-seg KV) + - padding: row_ids[q] = q_seg_id iff q_seg_id appears in KV + (and symmetrically for col_ids with max/<=). """ - # Gather the indices of non-padding tokens per row into a dense prefix; - # slots past the last valid token are filled with -1. - non_zero_mask = segment_ids != 0 - max_size = segment_ids.shape[-1] - non_zero_indices = jax.vmap( - lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] - )(non_zero_mask) - - # Materialise a padding-free view of segment_ids by gathering at - # non_zero_indices. Slots whose index was -1 are explicitly set - # to 0 so they end up in the id=0 bucket (that we drop below). - clipped_indices = jnp.clip(non_zero_indices, 0, None) - valid_segment_ids = jnp.where( - non_zero_indices >= 0, jnp.take_along_axis(segment_ids, clipped_indices, axis=-1), 0 - ) - # Per-row bincount of ids -> segment length, discarding the - # id=0 bucket (padding) and capping at max_segments_per_seq. - seqlens_all = jax.vmap( - lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] - )(valid_segment_ids) - seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) - return seqlens_all_pad_neg + # Example: For striping P2P causal attention (but this logic also applies for non-CP fused attn) + # pre-striping and sharding: segment_ids = [[1 1 1 1 2 2 2 2]], segment_pos = [[0 1 2 3 0 1 2 3]] + # post-striping and sharding (striped CP=2, Q from rank 0 × KV from rank 1, max_segments_per_seq=2): + # segment_ids_q = [1 1 2 2] segment_pos_q = [0 2 0 2] → q_key = [0 2 0 2] + # segment_ids_kv = [1 1 2 2] segment_pos_kv = [1 3 1 3] → kv_key = [1 3 1 3] + # Q-side — kv_agg[s] = min(kv_key over same-seg KV), fill = max_fill_val = 5 (assumed to be large enough): + # scatter (rows = kv tokens, cols = segs): + # [5 1 5 / 5 3 5 / 5 5 1 / 5 5 3] → reduce min → kv_agg = [5 1 1] + # q_ok = q_key >= kv_agg[seg_ids_q] = [0 2 0 2] >= [1 1 1 1] = [F T F T] + # KV-side — q_agg[s] = max(q_key over same-seg Q), fill = neg_fill_val = -1 (assumed to be small enough): + # scatter: [-1 0 -1 / -1 2 -1 / -1 -1 0 / -1 -1 2] → reduce max → q_agg = [-1 2 2] + # kv_ok = kv_key <= q_agg[seg_ids_kv] = [1 3 1 3] <= [2 2 2 2] = [T F T F] + # Outer combiner: + # row_ids = [0 1 0 2] col_ids = [1 0 2 0] + # q_seqlen = [1 1] kv_seqlen = [1 1] + # q_offset = [1 3 -1] kv_offset = [0 2 -1] + def _row_and_col_ids(): + if attn_mask_type.is_bottom_right(): + # BRCM: mask[q][kv] = (same seg) AND (q_key <= kv_key). + rl_q = run_length_fill(segment_ids_q) + rl_kv = run_length_fill(segment_ids_kv) + q_key = (rl_q - segment_pos_q).astype(jnp.int32) + kv_key = (rl_kv - segment_pos_kv).astype(jnp.int32) + + # Use large positive and negative values as fill values for the KV keys and Q keys respectively + max_fill_val = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32) + neg_fill_val = jnp.asarray(-1, dtype=jnp.int32) + # Creates a one-hot encoding mask of the KV segment ids (size [B, T_kv, max_segments_per_seq+1]) + # i.e. each row has only one True value, which is the segment id of the row. + kv_oh = jax.nn.one_hot(segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_) + # Mask the KV keys with the valid segment ids (size [B, T_kv, 1]) + kv_key_masked = jnp.where(segment_ids_kv != 0, kv_key, neg_fill_val)[..., None] + # Scatter each KV key (i.e. seg pos) into it's own segment column + kv_agg = jnp.where(kv_oh, kv_key_masked, neg_fill_val) + kv_agg = jnp.max(kv_agg, axis=-2) + # Define causal relationship: Q is attended iff q_key <= max(kv_key over same-seg KV) + q_has_match = q_key <= jnp.take_along_axis( + kv_agg, segment_ids_q.astype(jnp.int32), axis=-1 + ) + # Symmetric to the Q case, but with KV and Q swapped + q_oh = jax.nn.one_hot(segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_) + q_key_masked = jnp.where(segment_ids_q != 0, q_key, max_fill_val)[..., None] + q_agg = jnp.where(q_oh, q_key_masked, max_fill_val) + q_agg = jnp.min(q_agg, axis=-2) + # Define causal relationship: KV is attended iff kv_key >= min(q_key over same-seg Q) + kv_has_match = kv_key >= jnp.take_along_axis( + q_agg, segment_ids_kv.astype(jnp.int32), axis=-1 + ) + elif attn_mask_type.is_causal(): + # CM: mask[q][kv] = (same_seg) AND (q_pos >= kv_pos). + q_key = segment_pos_q.astype(jnp.int32) + kv_key = segment_pos_kv.astype(jnp.int32) + + # Use large positive and negative values as a fill value for the KV keys and Q keys respectively + max_fill_val = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32) + neg_fill_val = jnp.asarray(-1, dtype=jnp.int32) + + # Creates a one-hot encoding mask of the KV segment ids (size [B, T_kv, max_segments_per_seq+1]) + # i.e. each row has only one True value, which is the segment id of the row. + kv_oh = jax.nn.one_hot(segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_) + # Mask the KV keys with the valid segment ids (size [B, T_kv, 1]) + kv_key_masked = jnp.where(segment_ids_kv != 0, kv_key, max_fill_val)[..., None] + # Scatter each KV key (i.e. seg pos) into it's own segment column + kv_agg = jnp.where(kv_oh, kv_key_masked, max_fill_val) + kv_agg = jnp.min(kv_agg, axis=-2) + # Define causal relationship: Q is attended iff q_key >= min(kv_key over same-seg KV) + q_has_match = q_key >= jnp.take_along_axis( + kv_agg, segment_ids_q.astype(jnp.int32), axis=-1 + ) -def _get_seqoffsets_thd(segment_ids, max_segments_per_seq): - """O(T) per-row segment start-offset computation for packed-THD layouts. + # Symmetric to the Q case, but with KV and Q swapped + q_oh = jax.nn.one_hot(segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_) + q_key_masked = jnp.where(segment_ids_q != 0, q_key, neg_fill_val)[..., None] + q_agg = jnp.where(q_oh, q_key_masked, neg_fill_val) + q_agg = jnp.max(q_agg, axis=-2) + # Define causal relationship: KV is attended iff kv_key <= max(q_key over same-seg Q) + kv_has_match = kv_key <= jnp.take_along_axis( + q_agg, segment_ids_kv.astype(jnp.int32), axis=-1 + ) + else: + # Padding-only: row_ids[q] = q_seg_id iff q_seg_id is present in KV (and q not pad). + kv_seg_ids_present = jax.nn.one_hot( + segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_ + ).any(axis=-2) + q_seg_ids_present = jax.nn.one_hot( + segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_ + ).any(axis=-2) + q_has_match = jnp.take_along_axis( + kv_seg_ids_present, segment_ids_q.astype(jnp.int32), axis=-1 + ) & (segment_ids_q != 0) + kv_has_match = jnp.take_along_axis( + q_seg_ids_present, segment_ids_kv.astype(jnp.int32), axis=-1 + ) & (segment_ids_kv != 0) + + row_ids = jnp.where(q_has_match, segment_ids_q, 0).astype(jnp.int32) + col_ids = jnp.where(kv_has_match, segment_ids_kv, 0).astype(jnp.int32) + return row_ids, col_ids + + row_ids, col_ids = _row_and_col_ids() - Returns a [B, max_segments_per_seq + 1] array whose k-th entry - is the starting index of the k-th segment in that row (or -1 if no k-th - segment exists). Boundaries are detected from segment_ids transitions - """ - # Detect segment boundaries from segment_ids changes - segment_changes = jnp.concatenate( - [ - jnp.full( - (segment_ids.shape[0], 1), True, dtype=bool - ), # First valid element starts a segment - (segment_ids[..., 1:] != segment_ids[..., :-1]), # Segment id changed - ], - axis=-1, - ) - # Remove any padded region segment changes (this also handles intra-segment padding correctly) - segment_changes_masked = jnp.where(segment_ids != 0, segment_changes, False) - # Get the indices for segment changes (these are the start offsets) - seq_offsets = jax.vmap( - lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] - )(segment_changes_masked) - return seq_offsets + bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1)) + q_seqlen = bincount_vmap(row_ids)[..., 1:] + kv_seqlen = bincount_vmap(col_ids)[..., 1:] + + def _find_offsets(x): + same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) + first_column = x[..., :1] != 0 + boundaries = jnp.concatenate([first_column, same_as_previous], axis=-1) + return jax.vmap( + partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1) + )(boundaries).squeeze(-1) + + q_offset = _find_offsets(row_ids) + kv_offset = _find_offsets(col_ids) + return q_seqlen, kv_seqlen, q_offset, kv_offset def _segment_ids_pos_to_seqlens_offsets( @@ -607,9 +692,52 @@ def _segment_ids_pos_to_seqlens_offsets( window_size, max_segments_per_seq, ): + """Compute per-segment seqlens and start offsets(currently only used for THD) + Given segment-id and segment-position tensors for Q and KV, + returns the four metadata tensors cuDNN needed for variable-length attention: + q_seqlen : [..., max_segments_per_seq] # valid Q tokens per segment + kv_seqlen : [..., max_segments_per_seq] # valid KV tokens per segment + q_offset : [..., max_segments_per_seq + 1] # start index of each Q segment + kv_offset : [..., max_segments_per_seq + 1] # start index of each KV segment + + Args: + segment_ids_q: int32 [..., T_q] per-token segment id; 0 == padding + segment_ids_kv: int32 [..., T_kv] same convention as segment_ids_q + segment_pos_q: int32 [..., T_q] per-token position inside its segment + segment_pos_kv: int32 [..., T_kv] same convention as segment_pos_q + attn_mask_type: AttnMaskType. Selects the mask predicate used to decide + which positions are valid (top-left causal vs + bottom-right causal vs. padding-only) + window_size: Optional sliding-window tuple ``(left, right)`` or None + Used here only as a fast-path eligibility hint + max_segments_per_seq: maximum number of segments expected per row + Used to size the bincount / argwhere outputs + + Routing (only invoked for THD qkv_layout): + 1. Fast path -- ``_segment_ids_pos_to_seqlens_offsets_fast_causal_path``. + O(T) per row. Counts all segment tokens via bincount on + segment_ids and trims at most one token per segment at the + boundary. Used for: + - top-left CAUSAL / PADDING_CAUSAL with ``window_size is None`` + - SWA with ``window_size == (-1, -1)`` and not bottom-right + Bottom-right causal cross-attention is excluded: the boundary + trim leaves kv_seqlen short by one per active segment, which + shifts the BRCM bottom-right alignment by one KV per Q row. + + 2. Slow path -- ``_get_seqlens_offsets_thd``. + O(T * max_segments_per_seq) per row. Per-segment min/max + aggregation that is equivalent to the older O(T^2) + mask-based reference for top-left causal, bottom-right causal, + and padding-only masks. Required under ring attention where + ``segment_ids_q != segment_ids_kv`` in rotated steps. + + Returns: + Tuple ``(q_seqlen, kv_seqlen, q_offset, kv_offset)`` with shapes as + above. Inactive segment slots are filled with 0 in seqlens and -1 + in offsets. + """ # TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here. # Computing the full mask is expensive due to quadratic expansion of Q * KV masking. - # Assumptions for cudnn causal mask correctness. # 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0] # 2. No intra-segment padding, only inter-segment paddding allowed @@ -618,38 +746,32 @@ def _segment_ids_pos_to_seqlens_offsets( # 0 x x # 4 x x x x x # 8 x x x x x x x x - # # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to # examine only O(Q+KV) elements. - - # For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation - # using the segment ids and pos along with mask type (causal or brcm) is sufficient. - # It does not need to involve SW for this mask's creation - - # Currently, this function is only exercised for THD qkv_layout. - - # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well - if (attn_mask_type.is_causal() and window_size is None) or ( - window_size == (-1, -1) and not attn_mask_type.is_bottom_right() - ): + # The fast causal path encodes TOP-LEFT causal semantics via + # valid[q][kv] = (segment_pos_q >= segment_pos_kv) + # which is only equivalent to BRCM when s_q == s_kv (self-attention). For + # cross-attention (s_q != s_kv), BRCM diverges from top-left causal, so we + # must route bottom-right masks to the slow path. + + # Fast path: O(T) per row. + if ( + attn_mask_type.is_causal() + and not attn_mask_type.is_bottom_right() + and window_size is None + ) or (window_size == (-1, -1) and not attn_mask_type.is_bottom_right()): return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq ) - q_seqlen = _get_seqlens_thd( - segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq - ) - kv_seqlen = _get_seqlens_thd( - segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq - ) - q_offset = _get_seqoffsets_thd( - segment_ids=segment_ids_q, - max_segments_per_seq=max_segments_per_seq, - ) - kv_offset = _get_seqoffsets_thd( - segment_ids=segment_ids_kv, - max_segments_per_seq=max_segments_per_seq, + # Slow path: O(T * max_segments_per_seq) per row. + return _get_seqlens_offsets_thd( + segment_ids_q, + segment_ids_kv, + segment_pos_q, + segment_pos_kv, + attn_mask_type, + max_segments_per_seq, ) - return q_seqlen, kv_seqlen, q_offset, kv_offset def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type): From 4380b13aea13637531eb7bb71eba188e9f1964dc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 May 2026 00:51:35 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/attention.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 0f60fc3f28..13157a29a1 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -565,6 +565,7 @@ def _get_seqlens_offsets_thd( - padding: row_ids[q] = q_seg_id iff q_seg_id appears in KV (and symmetrically for col_ids with max/<=). """ + # Example: For striping P2P causal attention (but this logic also applies for non-CP fused attn) # pre-striping and sharding: segment_ids = [[1 1 1 1 2 2 2 2]], segment_pos = [[0 1 2 3 0 1 2 3]] # post-striping and sharding (striped CP=2, Q from rank 0 × KV from rank 1, max_segments_per_seq=2): @@ -674,9 +675,9 @@ def _find_offsets(x): same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) first_column = x[..., :1] != 0 boundaries = jnp.concatenate([first_column, same_as_previous], axis=-1) - return jax.vmap( - partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1) - )(boundaries).squeeze(-1) + return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))( + boundaries + ).squeeze(-1) q_offset = _find_offsets(row_ids) kv_offset = _find_offsets(col_ids) @@ -756,9 +757,7 @@ def _segment_ids_pos_to_seqlens_offsets( # Fast path: O(T) per row. if ( - attn_mask_type.is_causal() - and not attn_mask_type.is_bottom_right() - and window_size is None + attn_mask_type.is_causal() and not attn_mask_type.is_bottom_right() and window_size is None ) or (window_size == (-1, -1) and not attn_mask_type.is_bottom_right()): return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq From 1e0380c3389754611440e5f88db853b8c31aebda Mon Sep 17 00:00:00 2001 From: JAX Toolbox Date: Fri, 1 May 2026 11:46:59 -0700 Subject: [PATCH 9/9] Fix lint failure Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/jax/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 13157a29a1..f54a043fd2 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -11,7 +11,6 @@ from jax.ad_checkpoint import checkpoint_name import jax import jax.numpy as jnp -from flax.linen import make_attention_mask from transformer_engine_jax import NVTE_Bias_Type from transformer_engine_jax import NVTE_Mask_Type