def _get_seqlens_offsets_thd(
    segment_ids_q,
    segment_ids_kv,
    segment_pos_q,
    segment_pos_kv,
    attn_mask_type,
    max_segments_per_seq,
):
    """Compute THD seqlens/offsets in O(T * max_segments_per_seq) without a Q x KV mask.

    Replacement for the older O(T^2) mask-based slow path. The values match that reference:
        segment_mask         = make_attention_mask(q_ids, kv_ids, equal)
        segment_mask_with_id = make_attention_mask(q_ids, kv_ids, equal * q_id)
        attn_mask            = segment_mask AND (causal_or_brcm_or_none)
        attn_mask_with_id    = where(attn_mask, segment_mask_with_id, 0)
        row_ids = reduce_max(attn_mask_with_id, axis=kv)   # [B, T_q]
        col_ids = reduce_max(attn_mask_with_id, axis=q)    # [B, T_kv]
        seqlens/offsets = bincount(...) / find_offsets(...)

    The two reductions are expressed as per-segment aggregates instead:
        - causal:  row_ids[q]  = q_seg_id  iff seg_pos_q[q] >= min(seg_pos_kv over same-seg KV)
                   col_ids[kv] = kv_seg_id iff seg_pos_kv[kv] <= max(seg_pos_q over same-seg Q)
        - brcm:    with key = run_len - seg_pos,
                   row_ids[q]  = q_seg_id  iff q_key[q] <= max(kv_key over same-seg KV)
                   col_ids[kv] = kv_seg_id iff kv_key[kv] >= min(q_key over same-seg Q)
        - padding: row_ids[q]  = q_seg_id  iff q_seg_id appears in KV (and symmetrically).

    Segments that are absent on the other side are represented by the int32 min / max
    sentinel, so the comparison is false for every key strictly inside the int32 range.
    (A fixed -1 sentinel is not safe for the brcm keys: ``run_len - seg_pos`` is not
    guaranteed to be > -1, e.g. when positions exceed the local run length.)

    Args:
        segment_ids_q: [B, T_q] per-token segment id; 0 marks padding.
        segment_ids_kv: [B, T_kv] same convention as ``segment_ids_q``.
        segment_pos_q: [B, T_q] per-token position inside its segment.
        segment_pos_kv: [B, T_kv] same convention as ``segment_pos_q``.
        attn_mask_type: object exposing ``is_bottom_right()`` and ``is_causal()``.
        max_segments_per_seq: maximum number of segments per row; sizes the outputs.

    Returns:
        Tuple ``(q_seqlen, kv_seqlen, q_offset, kv_offset)``. Seqlens have
        ``max_segments_per_seq`` entries per row (unused slots are 0); offsets have
        ``max_segments_per_seq + 1`` entries per row (unused slots are -1).
    """
    # Example: striping P2P causal attention (the logic also applies to non-CP fused attn)
    # pre-striping and sharding: segment_ids = [[1 1 1 1 2 2 2 2]], segment_pos = [[0 1 2 3 0 1 2 3]]
    # post-striping and sharding (striped CP=2, Q from rank 0 x KV from rank 1,
    # max_segments_per_seq=2):
    #   segment_ids_q  = [1 1 2 2]   segment_pos_q  = [0 2 0 2]  ->  q_key  = [0 2 0 2]
    #   segment_ids_kv = [1 1 2 2]   segment_pos_kv = [1 3 1 3]  ->  kv_key = [1 3 1 3]
    # Q side: kv_agg[s] = min(kv_key over same-seg KV), absent slots hold the max sentinel (MAX)
    #   scatter (rows = kv tokens, cols = segs):
    #     [MAX 1 MAX / MAX 3 MAX / MAX MAX 1 / MAX MAX 3]  -> reduce min -> kv_agg = [MAX 1 1]
    #   q_ok = q_key >= kv_agg[seg_ids_q] = [0 2 0 2] >= [1 1 1 1] = [F T F T]
    # KV side: q_agg[s] = max(q_key over same-seg Q), absent slots hold the min sentinel (MIN)
    #   scatter: [MIN 0 MIN / MIN 2 MIN / MIN MIN 0 / MIN MIN 2] -> reduce max -> q_agg = [MIN 2 2]
    #   kv_ok = kv_key <= q_agg[seg_ids_kv] = [1 3 1 3] <= [2 2 2 2] = [T F T F]
    # Outer combiner:
    #   row_ids  = [0 1 0 2]    col_ids   = [1 0 2 0]
    #   q_seqlen = [1 1]        kv_seqlen = [1 1]
    #   q_offset = [1 3 -1]     kv_offset = [0 2 -1]
    num_slots = max_segments_per_seq + 1  # slot 0 is the padding id
    int32_limits = jnp.iinfo(jnp.int32)
    low_fill = jnp.asarray(int32_limits.min, dtype=jnp.int32)
    high_fill = jnp.asarray(int32_limits.max, dtype=jnp.int32)
    seg_q = segment_ids_q.astype(jnp.int32)
    seg_kv = segment_ids_kv.astype(jnp.int32)

    def _segment_one_hot(segment_ids):
        # [B, T, num_slots]; each row has a single True at its segment id.
        return jax.nn.one_hot(segment_ids, num_slots, dtype=jnp.bool_)

    def _per_segment(keys, segment_ids, reduce_fn, fill):
        # Scatter each non-pad key into its own segment column, then reduce over tokens.
        # Pad tokens and untouched columns hold `fill`, the identity of `reduce_fn`.
        masked_keys = jnp.where(segment_ids != 0, keys, fill)[..., None]
        scattered = jnp.where(_segment_one_hot(segment_ids), masked_keys, fill)
        return reduce_fn(scattered, axis=-2)  # [B, num_slots]

    def _lookup(per_segment_values, segment_ids):
        # Gather the per-segment aggregate back onto every token of that segment.
        return jnp.take_along_axis(per_segment_values, segment_ids, axis=-1)

    if attn_mask_type.is_bottom_right():
        # BRCM: mask[q][kv] = (same seg) AND (q_key <= kv_key), key = run_length - seg_pos.
        q_key = (run_length_fill(segment_ids_q) - segment_pos_q).astype(jnp.int32)
        kv_key = (run_length_fill(segment_ids_kv) - segment_pos_kv).astype(jnp.int32)
        # Q is attended iff q_key <= max(kv_key over same-seg KV).
        q_has_match = q_key <= _lookup(_per_segment(kv_key, seg_kv, jnp.max, low_fill), seg_q)
        # KV is attended iff kv_key >= min(q_key over same-seg Q).
        kv_has_match = kv_key >= _lookup(_per_segment(q_key, seg_q, jnp.min, high_fill), seg_kv)
    elif attn_mask_type.is_causal():
        # CM: mask[q][kv] = (same seg) AND (q_pos >= kv_pos).
        q_key = segment_pos_q.astype(jnp.int32)
        kv_key = segment_pos_kv.astype(jnp.int32)
        # Q is attended iff q_key >= min(kv_key over same-seg KV).
        q_has_match = q_key >= _lookup(_per_segment(kv_key, seg_kv, jnp.min, high_fill), seg_q)
        # KV is attended iff kv_key <= max(q_key over same-seg Q).
        kv_has_match = kv_key <= _lookup(_per_segment(q_key, seg_q, jnp.max, low_fill), seg_kv)
    else:
        # Padding-only: a token is kept iff its (non-pad) segment id exists on the other side.
        kv_ids_present = _segment_one_hot(seg_kv).any(axis=-2)
        q_ids_present = _segment_one_hot(seg_q).any(axis=-2)
        q_has_match = _lookup(kv_ids_present, seg_q) & (seg_q != 0)
        kv_has_match = _lookup(q_ids_present, seg_kv) & (seg_kv != 0)

    # Pad tokens carry id 0 already, so a spurious "match" on them still yields 0.
    row_ids = jnp.where(q_has_match, seg_q, 0)
    col_ids = jnp.where(kv_has_match, seg_kv, 0)

    bincount_vmap = jax.vmap(partial(jnp.bincount, length=num_slots))
    q_seqlen = bincount_vmap(row_ids)[..., 1:]  # drop the padding bucket
    kv_seqlen = bincount_vmap(col_ids)[..., 1:]

    def _find_offsets(ids):
        # A segment starts wherever a non-pad id differs from the id just before it.
        starts_new_segment = jnp.logical_and(ids[..., 1:] != ids[..., :-1], ids[..., 1:] != 0)
        first_column = ids[..., :1] != 0
        boundaries = jnp.concatenate([first_column, starts_new_segment], axis=-1)
        # Fixed output size keeps the shape static; unused slots are filled with -1.
        return jax.vmap(partial(jnp.argwhere, size=num_slots, fill_value=-1))(
            boundaries
        ).squeeze(-1)

    q_offset = _find_offsets(row_ids)
    kv_offset = _find_offsets(col_ids)
    return q_seqlen, kv_seqlen, q_offset, kv_offset


# NOTE(review): segment_pos_q, segment_pos_kv and attn_mask_type sit in lines the diff
# context elides; their order here is inferred from the call sites below - confirm.
def _segment_ids_pos_to_seqlens_offsets(
    segment_ids_q,
    segment_ids_kv,
    segment_pos_q,
    segment_pos_kv,
    attn_mask_type,
    window_size,
    max_segments_per_seq,
):
    """Compute per-segment seqlens and start offsets (currently only used for THD).

    Given segment-id and segment-position tensors for Q and KV, returns the four
    metadata tensors cuDNN needs for variable-length attention:
        q_seqlen  : [..., max_segments_per_seq]      # valid Q tokens per segment
        kv_seqlen : [..., max_segments_per_seq]      # valid KV tokens per segment
        q_offset  : [..., max_segments_per_seq + 1]  # start index of each Q segment
        kv_offset : [..., max_segments_per_seq + 1]  # start index of each KV segment

    Args:
        segment_ids_q: int32 [..., T_q] per-token segment id; 0 == padding.
        segment_ids_kv: int32 [..., T_kv] same convention as ``segment_ids_q``.
        segment_pos_q: int32 [..., T_q] per-token position inside its segment.
        segment_pos_kv: int32 [..., T_kv] same convention as ``segment_pos_q``.
        attn_mask_type: AttnMaskType. Selects the mask predicate used to decide which
            positions are valid (top-left causal vs. bottom-right causal vs. padding-only).
        window_size: Optional sliding-window tuple ``(left, right)`` or None. Used here
            only as a fast-path eligibility hint.
        max_segments_per_seq: maximum number of segments expected per row. Used to size
            the bincount / argwhere outputs.

    Routing (only invoked for THD qkv_layout):
        1. Fast path -- ``_segment_ids_pos_to_seqlens_offsets_fast_causal_path``.
           O(T) per row. Counts all segment tokens via bincount on segment_ids and
           trims at most one token per segment at the boundary. Used for:
             - top-left CAUSAL / PADDING_CAUSAL with ``window_size is None``
             - SWA with ``window_size == (-1, -1)`` and not bottom-right
           Bottom-right causal cross-attention is excluded: the boundary trim leaves
           kv_seqlen short by one per active segment, which shifts the BRCM
           bottom-right alignment by one KV per Q row.
        2. Slow path -- ``_get_seqlens_offsets_thd``.
           O(T * max_segments_per_seq) per row. Per-segment min/max aggregation that is
           equivalent to the older O(T^2) mask-based reference for top-left causal,
           bottom-right causal, and padding-only masks. Required under ring attention
           where ``segment_ids_q != segment_ids_kv`` in rotated steps.

    Returns:
        Tuple ``(q_seqlen, kv_seqlen, q_offset, kv_offset)`` with shapes as above.
        Inactive segment slots are filled with 0 in seqlens and -1 in offsets.
    """
    # TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
    # Computing the full mask is expensive due to quadratic expansion of Q * KV masking.
    # Assumptions for cudnn causal mask correctness:
    #   1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
    #   2. No intra-segment padding, only inter-segment padding allowed
    # The fast path avoids expanding the mask to a Q * KV matrix and instead examines
    # only O(Q+KV) elements.
    # The fast causal path encodes TOP-LEFT causal semantics via
    #     valid[q][kv] = (segment_pos_q >= segment_pos_kv)
    # which is only equivalent to BRCM when s_q == s_kv (self-attention). For
    # cross-attention (s_q != s_kv), BRCM diverges from top-left causal, so bottom-right
    # masks must be routed to the slow path.
    plain_causal = attn_mask_type.is_causal() and window_size is None
    unbounded_window = window_size == (-1, -1)
    if not attn_mask_type.is_bottom_right() and (plain_causal or unbounded_window):
        # Fast path: O(T) per row.
        return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
            segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
        )

    # Slow path: O(T * max_segments_per_seq) per row.
    return _get_seqlens_offsets_thd(
        segment_ids_q,
        segment_ids_kv,
        segment_pos_q,
        segment_pos_kv,
        attn_mask_type,
        max_segments_per_seq,
    )