diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index c1ff46c75a..f9fdff10ff 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -1953,3 +1953,128 @@ def forward(
             return output[0].view(*output[0].shape[:-2], -1), output[1]
         # ...hd -> ...(hd)
         return output.view(*output.shape[:-2], -1)
+
+
+class _FA3TreeAttnFunc(torch.autograd.Function):
+    """Autograd wrapper for the FA3 tree attention kernel.
+
+    Forward calls ``flash_attn_interface.flash_attn_tree_func``; backward
+    calls ``flash_attn_tree_bwd_func``. Tree topology
+    (``cu_node_lens`` / ``node_parent``) and the output of
+    ``precompute_tree_metadata`` are saved into the autograd context so
+    both the forward and backward legs see a consistent view of the trie.
+    """
+
+    @staticmethod
+    def forward(ctx, q, k, v, cu_node_lens, node_parent, softmax_scale, precomputed):
+        from flash_attn_interface import flash_attn_tree_func
+
+        out, lse = flash_attn_tree_func(
+            q, k, v, cu_node_lens, node_parent,
+            softmax_scale=softmax_scale,
+            tree_metadata=precomputed,
+        )
+        ctx.save_for_backward(q, k, v, out, lse, cu_node_lens, node_parent)
+        ctx.softmax_scale = softmax_scale
+        ctx.precomputed = precomputed
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        from flash_attn_interface import flash_attn_tree_bwd_func
+
+        q, k, v, out, lse, cu_node_lens, node_parent = ctx.saved_tensors
+        dq, dk, dv = flash_attn_tree_bwd_func(
+            dout, q, k, v, out, lse,
+            cu_node_lens, node_parent,
+            softmax_scale=ctx.softmax_scale,
+            tree_metadata=ctx.precomputed,
+        )
+        return dq, dk, dv, None, None, None, None
+
+
+class TreeFlashAttention(torch.nn.Module):
+    """Tree-attention backend for packed THD sequences whose tokens form a trie.
+
+    Routes through the FA3 tree attention kernel from
+    ``flash_attn_interface``, which shares the common prefix of N sibling
+    sequences across attention. Requires a Hopper-class GPU and the FA3
+    wheel.
+
+    Inputs:
+    * ``query / key / value``: SBHD ``[S, B, H, D]`` or packed THD
+      ``[T, H, D]``; normalized internally to ``[B, S, H, D]`` with
+      ``B == 1`` for the tree kernel.
+    * ``cu_node_lens``: ``int32`` ``[num_nodes + 1]`` cumulative per-node
+      token count.
+    * ``node_parent``: ``int32`` ``[num_nodes]`` parent index per node
+      (``-1`` for roots).
+    * ``precomputed``: opaque ``dict`` from
+      ``flash_attn_interface.precompute_tree_metadata``, reused across
+      all transformer layers in a forward / backward pass.
+
+    ``cp_size`` must be 1; CP-aware tree attention is tracked as a
+    follow-up. Dispatch goes through ``get_attention_backend``:
+    ``DotProductAttention.forward`` sets ``AttentionParams.tree_attention``
+    whenever ``tree_metadata`` is passed, and backend selection then
+    short-circuits to this module. Consumers that already know they want
+    the tree path (e.g. Megatron's ``TETreeDotProductAttention``) can also
+    wire ``TreeFlashAttention`` directly.
+
+    Env switch: ``NVTE_TREE_ATTN=0`` disables tree dispatch in
+    ``get_attention_backend`` (kill-switch for debugging).
+ """ + + def __init__( + self, + softmax_scale: float, + attention_type: str = "self", + layer_number: int | None = None, + ) -> None: + super().__init__() + if attention_type != "self": + raise ValueError("TreeFlashAttention only supports self-attention.") + self.softmax_scale = softmax_scale + self.attention_type = attention_type + self.layer_number = 1 if layer_number is None else layer_number + + def forward( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + cu_node_lens: torch.Tensor, + node_parent: torch.Tensor, + precomputed, + ) -> torch.Tensor: + """Forward. + + Accepts either SBHD ``[S, B, H, D]`` (Megatron convention) or + THD ``[T, H, D]`` (Megatron packed-sequence convention where + the batch dim has been squeezed). Output shape matches input + convention: ``[S, B, H*D]`` for SBHD, ``[T, H*D]`` for THD. + """ + # Normalize to [B, T, H, D] for FA3 + if query_layer.dim() == 3: + # THD [T, H, D] -> [1, T, H, D] + _thd_input = True + q = query_layer.unsqueeze(0).contiguous() + k = key_layer.unsqueeze(0).contiguous() + v = value_layer.unsqueeze(0).contiguous() + else: + # SBHD [S, B, H, D] -> [B, S, H, D] + _thd_input = False + q = query_layer.permute(1, 0, 2, 3).contiguous() + k = key_layer.permute(1, 0, 2, 3).contiguous() + v = value_layer.permute(1, 0, 2, 3).contiguous() + + out = _FA3TreeAttnFunc.apply( + q, k, v, cu_node_lens, node_parent, self.softmax_scale, precomputed, + ) + + if _thd_input: + # [1, T, H, D] -> [T, H*D] + return out.squeeze(0).reshape(query_layer.shape[0], -1) + else: + # [B, S, H, D] -> [S, B, H*D] + s, b = query_layer.shape[0], query_layer.shape[1] + return out.permute(1, 0, 2, 3).contiguous().view(s, b, -1) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 1303160965..468d0d06d9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -59,6 +59,7 @@ UnfusedDotProductAttention, FusedAttention, FlashAttention, + TreeFlashAttention, ) @@ -73,6 +74,7 @@ "use_fused_attention": None, "fused_attention_backend": None, "use_unfused_attention": None, + "use_tree_attention": None, "backend_selection_requires_update": False, } @@ -470,6 +472,15 @@ def __init__( return_max_logit=self.return_max_logit, ) + # Tree-attention backend. Lightweight module that holds the softmax + # scale; the FA3 wheel is imported lazily inside its forward, so + # non-tree deployments pay zero import cost. + self.tree_attention = TreeFlashAttention( + softmax_scale, + attention_type=attention_type, + layer_number=layer_number, + ) + def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ Temporarily remove core_attention._extra_state as a missing key @@ -820,6 +831,7 @@ def forward( pad_between_seqs: Optional[bool] = None, fp8_output: Optional[bool] = False, num_splits: Optional[int] = 1, + tree_metadata: Optional[Any] = None, ) -> torch.Tensor: r""" Dot Product Attention Layer. 
@@ -1341,6 +1353,7 @@ def forward(
             return_max_logit=self.return_max_logit,
             cuda_graph=is_graph_capturing(),
             num_splits=num_splits,
+            tree_attention=tree_metadata is not None,
         )
         global _attention_backends
         if is_in_onnx_export_mode():
@@ -1349,6 +1362,7 @@ def forward(
             use_flash_attention = False
             use_fused_attention = False
             use_unfused_attention = True
+            use_tree_attention = False
         else:
             if (
                 _attention_backends["attention_params"] is None
@@ -1364,6 +1378,7 @@ def forward(
                     fused_attention_backend,
                     use_unfused_attention,
                     _,
+                    use_tree_attention,
                 ) = dpa_utils.get_attention_backend(attention_params)
                 # Set global _attention_backends var using return value
                 # from get_attention_backend()
@@ -1372,8 +1387,11 @@ def forward(
                 _attention_backends["use_fused_attention"] = use_fused_attention
                 _attention_backends["fused_attention_backend"] = fused_attention_backend
                 _attention_backends["use_unfused_attention"] = use_unfused_attention
+                _attention_backends["use_tree_attention"] = use_tree_attention
                 _attention_backends["backend_selection_requires_update"] = False
-                if use_flash_attention:
+                if use_tree_attention:
+                    self.logger.info("Running with TreeFlashAttention backend")
+                elif use_flash_attention:
                     self.logger.info(
                         "Running with FlashAttention backend (version %s)",
                         flash_attention_backend,
@@ -1391,9 +1409,15 @@ def forward(
                 use_fused_attention = _attention_backends["use_fused_attention"]
                 fused_attention_backend = _attention_backends["fused_attention_backend"]
                 use_unfused_attention = _attention_backends["use_unfused_attention"]
+                use_tree_attention = _attention_backends["use_tree_attention"]

         # raise exception if no backend is available
-        if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
+        if sum([
+            use_flash_attention,
+            use_fused_attention,
+            use_unfused_attention,
+            bool(use_tree_attention),
+        ]) == 0:
             raise ValueError(
                 "No dot product attention backend is available for the provided inputs. Please"
                 " run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for"
@@ -1407,6 +1431,16 @@ def forward(
             else None
         )

+        if use_tree_attention:
+            return self.tree_attention(
+                query_layer,
+                key_layer,
+                value_layer,
+                cu_node_lens=tree_metadata.cu_node_lens,
+                node_parent=tree_metadata.node_parent,
+                precomputed=tree_metadata.precomputed,
+            )
+
         if use_flash_attention:
             if core_attention_bias_type == "alibi":
                 alibi_slopes, _ = dpa_utils.get_alibi(
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 8c6b6afc90..9a3cf2173f 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -266,6 +266,11 @@ class AttentionParams:
     return_max_logit: bool = False
     cuda_graph: bool = False
     num_splits: int = 1
+    # Tree-attention dispatch flag. When True, get_attention_backend
+    # short-circuits to the TreeFlashAttention backend; the caller must also
+    # pass tree_metadata (carrying cu_node_lens / node_parent / precomputed)
+    # to DotProductAttention.forward at call time.
+    tree_attention: bool = False

     def __eq__(self, other):
         """
@@ -342,12 +347,50 @@ def get_attention_backend(
     return_max_logit = attention_params.return_max_logit
     cuda_graph = attention_params.cuda_graph
     num_splits = attention_params.num_splits
+    tree_attention = attention_params.tree_attention

     # Run config
     logger = logging.getLogger("DotProductAttention")
     logger.setLevel(AttentionLogging._log_level)
     if not logger.hasHandlers():
         logger.addHandler(AttentionLogging._stream_handler)
+
+    # Tree-attention early dispatch: when the caller advertises tree mode we
+    # bypass the normal Flash/Fused/Unfused selection and route directly to
+    # TreeFlashAttention. Tree requires `thd` layout, no CP (Stage 2+), and
+    # FA3 to be installed; the TreeFlashAttention backend itself rechecks
+    # runtime-specific preconditions (FA3 wheel importability, etc.).
+    if tree_attention:
+        tree_kill_switch = int(os.getenv("NVTE_TREE_ATTN", "1")) == 0
+        if tree_kill_switch:
+            logger.debug("NVTE_TREE_ATTN=0; tree_attention disabled")
+            return (
+                False, None,
+                False, None,
+                False,
+                [False, False, False],
+                False,
+            )
+        if not qkv_layout.startswith("thd"):
+            raise ValueError(
+                f"tree_attention requires qkv_layout starting with 'thd' "
+                f"(got '{qkv_layout}')"
+            )
+        if context_parallel:
+            raise ValueError(
+                "tree_attention does not yet support context parallelism. "
+                "CP-aware tree attention is tracked as Stage 2+ of the "
+                "tree-training migration plan."
+            )
+        logger.debug("Selected backend = TreeFlashAttention")
+        return (
+            False, None,
+            False, None,
+            False,
+            [False, False, False],
+            True,
+        )
+
     device_compute_capability = get_device_compute_capability()
     cudnn_version = get_cudnn_version()
     run_config = {
@@ -1166,6 +1209,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         fused_attention_backend,
         use_unfused_attention,
         available_backends,
+        False,  # use_tree_attention; tree mode uses the early return above
     )
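
Reviewer note (not part of the patch): the patch leaves the shape of ``tree_metadata`` up to the consumer; ``DotProductAttention.forward`` only reads ``.cu_node_lens``, ``.node_parent`` and ``.precomputed`` from it. The sketch below builds those tensors for a toy trie. ``TreeMetadata`` is a hypothetical container introduced only for illustration, and the call to ``precompute_tree_metadata`` is left as a commented placeholder because its signature is not shown in this patch.

# Reviewer sketch: assembling the trie topology that TreeFlashAttention expects.
from dataclasses import dataclass

import torch


@dataclass
class TreeMetadata:  # hypothetical holder; any object with these attributes works
    cu_node_lens: torch.Tensor  # int32 [num_nodes + 1], cumulative per-node token counts
    node_parent: torch.Tensor   # int32 [num_nodes], parent index per node, -1 for roots
    precomputed: object         # opaque output of flash_attn_interface.precompute_tree_metadata


# Toy trie: a 4-token shared prefix (node 0) with two 2-token sibling
# continuations (nodes 1 and 2), packed as 8 tokens in THD order.
node_lens = torch.tensor([4, 2, 2], dtype=torch.int32)
node_parent = torch.tensor([-1, 0, 0], dtype=torch.int32)
cu_node_lens = torch.zeros(node_lens.numel() + 1, dtype=torch.int32)
cu_node_lens[1:] = torch.cumsum(node_lens, dim=0)  # -> [0, 4, 6, 8]

# precomputed = flash_attn_interface.precompute_tree_metadata(...)  # FA3 fork; exact
#                                                                   # signature not shown here
# meta = TreeMetadata(cu_node_lens.cuda(), node_parent.cuda(), precomputed)
# out = dpa(q, k, v, qkv_format="thd", tree_metadata=meta)  # dpa: a DotProductAttention module

Because ``DotProductAttention.forward`` sets ``tree_attention=tree_metadata is not None`` on ``AttentionParams``, omitting ``tree_metadata`` leaves the ordinary Flash/Fused/Unfused selection untouched.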
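
A second sketch, also outside the patch, replays just the layout bookkeeping of ``TreeFlashAttention.forward`` with the kernel stubbed out by an identity, to make the output conventions from its docstring concrete (``[S, B, H*D]`` for SBHD input, ``[T, H*D]`` for THD input). It runs without the FA3 wheel.

# Shape bookkeeping only; the FA3 kernel is replaced by an identity.
import torch

h, d = 4, 64

# SBHD path: [S, B, H, D] -> kernel sees [B, S, H, D] -> module returns [S, B, H*D]
s, b = 8, 1
sbhd = torch.randn(s, b, h, d)
kernel_in = sbhd.permute(1, 0, 2, 3).contiguous()
kernel_out = kernel_in  # stand-in for the flash_attn_tree_func output
print(kernel_out.permute(1, 0, 2, 3).contiguous().view(s, b, -1).shape)  # torch.Size([8, 1, 256])

# THD path: [T, H, D] -> kernel sees [1, T, H, D] -> module returns [T, H*D]
t = 8
thd = torch.randn(t, h, d)
kernel_in = thd.unsqueeze(0).contiguous()
kernel_out = kernel_in
print(kernel_out.squeeze(0).reshape(t, -1).shape)  # torch.Size([8, 256])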