Skip to content
Merged
273 changes: 203 additions & 70 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
from flax.linen import make_attention_mask

from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
Expand Down Expand Up @@ -541,6 +540,149 @@ def run_length_fill(segment_ids) -> jnp.ndarray:
return run_length_segment_id_shape.reshape(orig_shape)


def _get_seqlens_offsets_thd(
    segment_ids_q,
    segment_ids_kv,
    segment_pos_q,
    segment_pos_kv,
    attn_mask_type,
    max_segments_per_seq,
):
    """O(T * max_segments_per_seq) replacement for the older O(T^2) mask-based slow path.

    Produces the same ``(q_seqlen, kv_seqlen, q_offset, kv_offset)`` as building the
    full ``[T_q, T_kv]`` attention mask and reducing it, but using only per-segment
    min/max aggregates:

      - top-left causal: a Q token is active iff
            ``seg_pos_q >= min(seg_pos_kv over same-segment KV)``
      - bottom-right causal (BRCM): a Q token is active iff
            ``(run_len_q - seg_pos_q) <= max(run_len_kv - seg_pos_kv over same-segment KV)``
      - padding-only: a Q token is active iff its segment id appears on the KV side

    (and symmetrically for KV tokens, with the aggregate direction flipped).

    Example (striped P2P causal attention; same logic applies to non-CP fused attn):
      pre-striping: segment_ids = [[1 1 1 1 2 2 2 2]], segment_pos = [[0 1 2 3 0 1 2 3]]
      post-striping (striped CP=2, Q from rank 0 x KV from rank 1, max_segments_per_seq=2):
        segment_ids_q  = [1 1 2 2]  segment_pos_q  = [0 2 0 2] -> q_key  = [0 2 0 2]
        segment_ids_kv = [1 1 2 2]  segment_pos_kv = [1 3 1 3] -> kv_key = [1 3 1 3]
      Q side (min-aggregate of kv_key per segment, fill = 5 assumed large enough):
        scatter [5 1 5 / 5 3 5 / 5 5 1 / 5 5 3] -> min -> [5 1 1]
        q active  = [0 2 0 2] >= [1 1 1 1] = [F T F T]
      KV side (max-aggregate of q_key per segment, fill = -1):
        scatter [-1 0 -1 / -1 2 -1 / -1 -1 0 / -1 -1 2] -> max -> [-1 2 2]
        kv active = [1 3 1 3] <= [2 2 2 2] = [T F T F]
      row_ids = [0 1 0 2], col_ids = [1 0 2 0]
      q_seqlen = [1 1], kv_seqlen = [1 1], q_offset = [1 3 -1], kv_offset = [0 2 -1]
    """

    # Sentinels for aggregation: +inf-like fill for min-reduction, -1 for max-reduction.
    int32_max = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
    int32_neg = jnp.asarray(-1, dtype=jnp.int32)

    def _has_partner(own_ids, other_ids, own_key, other_key, use_max):
        # Returns a per-`own`-token boolean: True iff the token's key is matched by
        # the per-segment aggregate of the keys on the `other` side.
        #   use_max=True : aggregate with max (fill -1),   match when own_key <= agg
        #   use_max=False: aggregate with min (fill +inf), match when own_key >= agg
        fill = int32_neg if use_max else int32_max
        # One-hot scatter of the other side's keys into per-segment columns
        # (shape [..., T_other, max_segments_per_seq + 1]); each row has exactly one
        # True entry at its segment id. Padding tokens (id 0) contribute only `fill`.
        other_oh = jax.nn.one_hot(other_ids, max_segments_per_seq + 1, dtype=jnp.bool_)
        masked_key = jnp.where(other_ids != 0, other_key, fill)[..., None]
        scattered = jnp.where(other_oh, masked_key, fill)
        agg = (jnp.max if use_max else jnp.min)(scattered, axis=-2)
        # Gather each token's own-segment aggregate and compare.
        gathered = jnp.take_along_axis(agg, own_ids.astype(jnp.int32), axis=-1)
        return own_key <= gathered if use_max else own_key >= gathered

    def _row_and_col_ids():
        if attn_mask_type.is_bottom_right():
            # BRCM: mask[q][kv] = (same seg) AND (run_len_q - pos_q <= run_len_kv - pos_kv).
            q_key = (run_length_fill(segment_ids_q) - segment_pos_q).astype(jnp.int32)
            kv_key = (run_length_fill(segment_ids_kv) - segment_pos_kv).astype(jnp.int32)
            q_active = _has_partner(segment_ids_q, segment_ids_kv, q_key, kv_key, use_max=True)
            kv_active = _has_partner(segment_ids_kv, segment_ids_q, kv_key, q_key, use_max=False)
        elif attn_mask_type.is_causal():
            # Top-left causal: mask[q][kv] = (same seg) AND (pos_q >= pos_kv).
            q_key = segment_pos_q.astype(jnp.int32)
            kv_key = segment_pos_kv.astype(jnp.int32)
            q_active = _has_partner(segment_ids_q, segment_ids_kv, q_key, kv_key, use_max=False)
            kv_active = _has_partner(segment_ids_kv, segment_ids_q, kv_key, q_key, use_max=True)
        else:
            # Padding-only: a token is active iff its (nonzero) segment id occurs
            # anywhere on the other side.
            present_in_kv = jax.nn.one_hot(
                segment_ids_kv, max_segments_per_seq + 1, dtype=jnp.bool_
            ).any(axis=-2)
            present_in_q = jax.nn.one_hot(
                segment_ids_q, max_segments_per_seq + 1, dtype=jnp.bool_
            ).any(axis=-2)
            q_active = jnp.take_along_axis(
                present_in_kv, segment_ids_q.astype(jnp.int32), axis=-1
            ) & (segment_ids_q != 0)
            kv_active = jnp.take_along_axis(
                present_in_q, segment_ids_kv.astype(jnp.int32), axis=-1
            ) & (segment_ids_kv != 0)

        rows = jnp.where(q_active, segment_ids_q, 0).astype(jnp.int32)
        cols = jnp.where(kv_active, segment_ids_kv, 0).astype(jnp.int32)
        return rows, cols

    row_ids, col_ids = _row_and_col_ids()

    # Per-segment token counts; slot 0 (padding / masked-out tokens) is dropped.
    count_per_segment = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1))
    q_seqlen = count_per_segment(row_ids)[..., 1:]
    kv_seqlen = count_per_segment(col_ids)[..., 1:]

    def _find_offsets(ids):
        # A segment starts wherever a nonzero id differs from its predecessor,
        # or occupies the very first column. Unused slots are filled with -1.
        starts_after = jnp.logical_and(ids[..., 1:] != ids[..., :-1], ids[..., 1:] != 0)
        starts_first = ids[..., :1] != 0
        starts = jnp.concatenate([starts_first, starts_after], axis=-1)
        return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))(
            starts
        ).squeeze(-1)

    return q_seqlen, kv_seqlen, _find_offsets(row_ids), _find_offsets(col_ids)


def _segment_ids_pos_to_seqlens_offsets(
segment_ids_q,
segment_ids_kv,
Expand All @@ -550,9 +692,52 @@ def _segment_ids_pos_to_seqlens_offsets(
window_size,
max_segments_per_seq,
):
"""Compute per-segment seqlens and start offsets(currently only used for THD)
Given segment-id and segment-position tensors for Q and KV,
returns the four metadata tensors cuDNN needed for variable-length attention:
q_seqlen : [..., max_segments_per_seq] # valid Q tokens per segment
kv_seqlen : [..., max_segments_per_seq] # valid KV tokens per segment
q_offset : [..., max_segments_per_seq + 1] # start index of each Q segment
kv_offset : [..., max_segments_per_seq + 1] # start index of each KV segment

Args:
segment_ids_q: int32 [..., T_q] per-token segment id; 0 == padding
segment_ids_kv: int32 [..., T_kv] same convention as segment_ids_q
segment_pos_q: int32 [..., T_q] per-token position inside its segment
segment_pos_kv: int32 [..., T_kv] same convention as segment_pos_q
attn_mask_type: AttnMaskType. Selects the mask predicate used to decide
which positions are valid (top-left causal vs
bottom-right causal vs. padding-only)
window_size: Optional sliding-window tuple ``(left, right)`` or None
Used here only as a fast-path eligibility hint
max_segments_per_seq: maximum number of segments expected per row
Used to size the bincount / argwhere outputs

Routing (only invoked for THD qkv_layout):
1. Fast path -- ``_segment_ids_pos_to_seqlens_offsets_fast_causal_path``.
O(T) per row. Counts all segment tokens via bincount on
segment_ids and trims at most one token per segment at the
boundary. Used for:
- top-left CAUSAL / PADDING_CAUSAL with ``window_size is None``
- SWA with ``window_size == (-1, -1)`` and not bottom-right
Bottom-right causal cross-attention is excluded: the boundary
trim leaves kv_seqlen short by one per active segment, which
shifts the BRCM bottom-right alignment by one KV per Q row.

2. Slow path -- ``_get_seqlens_offsets_thd``.
O(T * max_segments_per_seq) per row. Per-segment min/max
aggregation that is equivalent to the older O(T^2)
mask-based reference for top-left causal, bottom-right causal,
and padding-only masks. Required under ring attention where
``segment_ids_q != segment_ids_kv`` in rotated steps.

Returns:
Tuple ``(q_seqlen, kv_seqlen, q_offset, kv_offset)`` with shapes as
above. Inactive segment slots are filled with 0 in seqlens and -1
in offsets.
"""
# TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
# Computing the full mask is expensive due to quadratic expansion of Q * KV masking.

# Assumptions for cudnn causal mask correctness.
# 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
# 2. No intra-segment padding, only inter-segment padding allowed
Expand All @@ -561,82 +746,30 @@ def _segment_ids_pos_to_seqlens_offsets(
# 0 x x
# 4 x x x x x
# 8 x x x x x x x x
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.

# For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation
# using the segment ids and pos along with mask type (causal or brcm) is sufficient.
# It does not need to involve SW for this mask's creation

# Currently, this function is only exercised for THD qkv_layout.

# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if (attn_mask_type.is_causal() and window_size is None) or (
window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
):
# The fast causal path encodes TOP-LEFT causal semantics via
# valid[q][kv] = (segment_pos_q >= segment_pos_kv)
# which is only equivalent to BRCM when s_q == s_kv (self-attention). For
# cross-attention (s_q != s_kv), BRCM diverges from top-left causal, so we
# must route bottom-right masks to the slow path.

# Fast path: O(T) per row.
if (
attn_mask_type.is_causal() and not attn_mask_type.is_bottom_right() and window_size is None
) or (window_size == (-1, -1) and not attn_mask_type.is_bottom_right()):
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)

# (1 = attend, 0 = masked)
segment_mask = make_attention_mask(
segment_ids_q,
segment_ids_kv,
jnp.equal,
)
segment_mask_with_id = make_attention_mask(
# Slow path: O(T * max_segments_per_seq) per row.
return _get_seqlens_offsets_thd(
segment_ids_q,
segment_ids_kv,
lambda x, y: jnp.equal(x, y) * x,
)
# TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied
attn_mask = segment_mask
if attn_mask_type.is_bottom_right():
run_length_out_q = run_length_fill(segment_ids_q)
run_length_out_kv = run_length_fill(segment_ids_kv)
# Example for brcm:
# run_length_out_q: [3 3 3 0 4 4 4 4]
# segment_pos_q: [0 1 2 3 0 1 2 3]
# segment_ids_q: [1 1 1 0 2 2 2 2]
# run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10]
# segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9]
# segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2]
# brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]
# [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]]
# attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
# [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]]
bottom_right_causal_mask = make_attention_mask(
run_length_out_q - segment_pos_q,
run_length_out_kv - segment_pos_kv,
jnp.less_equal,
)
attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask)
elif attn_mask_type.is_causal():
causal_mask = make_attention_mask(
segment_pos_q,
segment_pos_kv,
jnp.greater_equal,
)
attn_mask = jnp.logical_and(segment_mask, causal_mask)

attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
attn_mask_with_id, max_segments_per_seq
segment_pos_q,
segment_pos_kv,
attn_mask_type,
max_segments_per_seq,
)
return q_seqlen, kv_seqlen, q_offset, kv_offset


def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):
Expand Down
Loading