From 5a8ec958c08be11436eaaec7684a158355b68610 Mon Sep 17 00:00:00 2001
From: racky-scitix
Date: Wed, 15 Apr 2026 21:37:47 +0800
Subject: [PATCH 1/4] feat(attention): add TreeFlashAttention backend for
 tree-training
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds a TE-native TreeFlashAttention class + _FA3TreeAttnFunc autograd
wrapper that routes packed THD trie sequences through the FA3 tree
attention kernel (flash_attn_interface). The backend is self-contained —
it does not yet hook into get_attention_backend; consumers that already
know they want the tree path (e.g. slime-trainer's Megatron-side
TETreeDotProductAttention) can import and use it directly.

Follow-ups:
* extend AttentionParams with tree_attention + tree_metadata
* add use_tree_attention branch in get_attention_backend
* instantiate self.tree_attention in DotProductAttention.__init__ and
  dispatch from DotProductAttention.forward

For the initial tree-training native migration in scitix/slime-trainer,
the Megatron-side TETreeDotProductAttention calls _FA3TreeAttnFunc
directly. Once the dispatch plumbing above is in place, that path can be
slimmed to a thin TreeFlashAttention delegation.

cp_size == 1 enforced at the call-site layer (Megatron TETreeDPA).
NVTE_TREE_ATTN env switch reserved for kill-switch / debugging in future
TE-side dispatch work.
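
Illustrative usage, not part of this patch (tensor shapes and values are
made up, and the exact precompute_tree_metadata signature is an
assumption about the FA3 tree wheel):

    import torch
    from flash_attn_interface import precompute_tree_metadata
    from transformer_engine.pytorch.attention.dot_product_attention.backends import (
        TreeFlashAttention,
    )

    # One packed trie: node 0 is the shared prefix, nodes 1 and 2 are siblings.
    h, d = 16, 128
    cu_node_lens = torch.tensor([0, 8, 16, 24], dtype=torch.int32, device="cuda")
    node_parent = torch.tensor([-1, 0, 0], dtype=torch.int32, device="cuda")
    precomputed = precompute_tree_metadata(cu_node_lens, node_parent)  # signature assumed

    # forward() takes [S, B, H, D] with B == 1 (Megatron convention).
    q = torch.randn(24, 1, h, d, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    attn = TreeFlashAttention(softmax_scale=d**-0.5)
    out = attn(q, k, v, cu_node_lens, node_parent, precomputed)  # [24, 1, h * d]
    out.sum().backward()  # exercises _FA3TreeAttnFunc.backward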

Refs scitix/slime-trainer#220.
---
 .../dot_product_attention/backends.py | 110 ++++++++++++++++++
 1 file changed, 110 insertions(+)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index c1ff46c75a..7d17623ad9 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -1953,3 +1953,113 @@ def forward(
             return output[0].view(*output[0].shape[:-2], -1), output[1]
         # ...hd -> ...(hd)
         return output.view(*output.shape[:-2], -1)
+
+
+class _FA3TreeAttnFunc(torch.autograd.Function):
+    """Autograd wrapper for the FA3 tree attention kernel.
+
+    Forward calls ``flash_attn_interface.flash_attn_tree_func``; backward
+    calls ``flash_attn_tree_bwd_func``. Tree topology
+    (``cu_node_lens`` / ``node_parent``) and the output of
+    ``precompute_tree_metadata`` are saved into the autograd context so
+    both the forward and backward legs see a consistent view of the trie.
+    """
+
+    @staticmethod
+    def forward(ctx, q, k, v, cu_node_lens, node_parent, softmax_scale, precomputed):
+        from flash_attn_interface import flash_attn_tree_func
+
+        out, lse = flash_attn_tree_func(
+            q, k, v, cu_node_lens, node_parent,
+            softmax_scale=softmax_scale,
+            tree_metadata=precomputed,
+        )
+        ctx.save_for_backward(q, k, v, out, lse, cu_node_lens, node_parent)
+        ctx.softmax_scale = softmax_scale
+        ctx.precomputed = precomputed
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        from flash_attn_interface import flash_attn_tree_bwd_func
+
+        q, k, v, out, lse, cu_node_lens, node_parent = ctx.saved_tensors
+        dq, dk, dv = flash_attn_tree_bwd_func(
+            dout, q, k, v, out, lse,
+            cu_node_lens, node_parent,
+            softmax_scale=ctx.softmax_scale,
+            tree_metadata=ctx.precomputed,
+        )
+        return dq, dk, dv, None, None, None, None
+
+
+class TreeFlashAttention(torch.nn.Module):
+    """Tree-attention backend for packed THD sequences whose tokens form a trie.
+
+    Routes through the FA3 tree attention kernel from
+    ``flash_attn_interface``, which shares the common prefix of N sibling
+    sequences across attention. Requires a Hopper-class GPU and the FA3
+    wheel.
+
+    Inputs:
+      * ``query / key / value``: ``[B, S, H, D]`` with ``B == 1`` for the
+        tree path (``forward`` accepts Megatron's ``[S, B, H, D]`` and permutes).
+      * ``cu_node_lens``: ``int32`` ``[num_nodes + 1]`` cumulative per-node
+        token count.
+      * ``node_parent``: ``int32`` ``[num_nodes]`` parent index per node
+        (``-1`` for roots).
+      * ``precomputed``: opaque ``dict`` from
+        ``flash_attn_interface.precompute_tree_metadata`` — reused across
+        all transformer layers in a forward / backward pass.
+
+    ``cp_size`` must be 1; CP-aware tree attention is tracked as a
+    follow-up. Intentionally does not implement ``get_attention_backend``
+    dispatch — ``TreeFlashAttention`` is wired directly by consumers
+    (e.g. Megatron's ``TETreeDotProductAttention``) that already know they
+    want the tree path. Future work can extend ``get_attention_backend``
+    with a ``use_tree_attention`` branch driven by a new
+    ``AttentionParams`` field.
+
+    Env switch: ``NVTE_TREE_ATTN=0`` disables this backend for callers that
+    respect it (kill-switch for debugging).
+    """
+
+    def __init__(
+        self,
+        softmax_scale: float,
+        attention_type: str = "self",
+        layer_number: int | None = None,
+    ) -> None:
+        super().__init__()
+        if attention_type != "self":
+            raise ValueError("TreeFlashAttention only supports self-attention.")
+        self.softmax_scale = softmax_scale
+        self.attention_type = attention_type
+        self.layer_number = 1 if layer_number is None else layer_number
+
+    def forward(
+        self,
+        query_layer: torch.Tensor,
+        key_layer: torch.Tensor,
+        value_layer: torch.Tensor,
+        cu_node_lens: torch.Tensor,
+        node_parent: torch.Tensor,
+        precomputed,
+    ) -> torch.Tensor:
+        """Forward.
+
+        Inputs ``[S, B, H, D]`` (Megatron convention). Output
+        ``[S, B, H * D]`` to match the rest of the TE attention stack.
+        """
+        # [S, B, H, D] -> [B, S, H, D]; B == 1 for tree training so this
+        # permute is a no-copy reshape.
+        q = query_layer.permute(1, 0, 2, 3).contiguous()
+        k = key_layer.permute(1, 0, 2, 3).contiguous()
+        v = value_layer.permute(1, 0, 2, 3).contiguous()
+
+        out = _FA3TreeAttnFunc.apply(
+            q, k, v, cu_node_lens, node_parent, self.softmax_scale, precomputed,
+        )
+
+        s, b = query_layer.shape[0], query_layer.shape[1]
+        return out.permute(1, 0, 2, 3).contiguous().view(s, b, -1)

From deb33ef78adf1426bc7e036de00b386252b6af3c Mon Sep 17 00:00:00 2001
From: racky-scitix
Date: Wed, 15 Apr 2026 22:35:33 +0800
Subject: [PATCH 2/4] feat(attention): integrate TreeFlashAttention into
 DotProductAttention dispatch

Make TreeFlashAttention a first-class backend alongside
Flash/Fused/Unfused so consumers can request it through the normal
DotProductAttention forward() surface instead of reaching into the
backend class directly.

Changes:

- AttentionParams gains tree_attention: bool = False. When set, the
  caller is advertising that it wants the tree path and is passing the
  tree topology + precompute along with q/k/v.

- get_attention_backend short-circuits early when tree_attention=True:
  validates qkv_layout starts with 'thd', rejects context_parallel
  (Stage 2+ of the migration plan), honors NVTE_TREE_ATTN=0 as a kill
  switch, and returns a 7th tuple element use_tree_attention signalling
  the tree branch. The non-tree path now returns False in that slot so
  existing callers do not change their truthy checks.

- DotProductAttention.__init__ always instantiates self.tree_attention.
  The module itself only holds the softmax scale; the FA3 wheel is
  imported lazily inside its forward, so non-tree deployments pay no
  import cost.

- DotProductAttention.forward grows tree_cu_node_lens / tree_node_parent
  / tree_precomputed kwargs. Presence of all three sets tree_attention
  on the AttentionParams that feeds get_attention_backend, so the
  dispatch cache sees a distinct key and recomputes on the first tree
  call. The dispatch chain routes to self.tree_attention(...) ahead of
  the Flash/Fused/Unfused branches when use_tree_attention is True.

- _attention_backends cache gains a use_tree_attention slot.

This lets upstream wrappers (e.g. Megatron's TETreeDotProductAttention)
delegate to TE via super().forward(...) instead of carrying their own
FA3 autograd wrapper, which is also the prerequisite for adding CP to
the tree backend without touching callers.
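
Call-surface sketch (illustrative; the DotProductAttention constructor
arguments and cu_seqlens plumbing below are the pre-existing TE surface,
and the packed q/k/v and topology tensors are the placeholders from the
previous sketch):

    import transformer_engine.pytorch as te

    core_attn = te.DotProductAttention(
        num_attention_heads=16,
        kv_channels=128,
        attn_mask_type="padding_causal",
        qkv_format="thd",
    )
    out = core_attn(
        query_layer, key_layer, value_layer,   # packed THD tensors
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        tree_cu_node_lens=cu_node_lens,        # int32 [num_nodes + 1]
        tree_node_parent=node_parent,          # int32 [num_nodes], -1 for roots
        tree_precomputed=precomputed,          # from precompute_tree_metadata
    )
    # Passing all three tree_* kwargs flips AttentionParams.tree_attention and
    # dispatches to TreeFlashAttention; omitting them leaves the existing
    # Flash/Fused/Unfused selection untouched.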

Refs scitix/slime-trainer#220.
---
 .../dot_product_attention.py                 | 44 ++++++++++++++++++-
 .../attention/dot_product_attention/utils.py | 44 +++++++++++++++++++
 2 files changed, 86 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index 1303160965..533bb0e164 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -59,6 +59,7 @@
     UnfusedDotProductAttention,
     FusedAttention,
     FlashAttention,
+    TreeFlashAttention,
 )
 
 
@@ -73,6 +74,7 @@
     "use_fused_attention": None,
     "fused_attention_backend": None,
     "use_unfused_attention": None,
+    "use_tree_attention": None,
     "backend_selection_requires_update": False,
 }
 
@@ -470,6 +472,15 @@ def __init__(
             return_max_logit=self.return_max_logit,
         )
 
+        # Tree-attention backend. Lightweight module that holds the softmax
+        # scale; the FA3 wheel is imported lazily inside its forward, so
+        # non-tree deployments pay zero import cost.
+        self.tree_attention = TreeFlashAttention(
+            softmax_scale,
+            attention_type=attention_type,
+            layer_number=layer_number,
+        )
+
     def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
         """
         Temporarily remove core_attention._extra_state as a missing key
@@ -820,6 +831,9 @@ def forward(
         pad_between_seqs: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
         num_splits: Optional[int] = 1,
+        tree_cu_node_lens: Optional[torch.Tensor] = None,
+        tree_node_parent: Optional[torch.Tensor] = None,
+        tree_precomputed: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         r"""
         Dot Product Attention Layer.
@@ -1341,6 +1355,11 @@ def forward(
             return_max_logit=self.return_max_logit,
             cuda_graph=is_graph_capturing(),
             num_splits=num_splits,
+            tree_attention=(
+                tree_cu_node_lens is not None
+                and tree_node_parent is not None
+                and tree_precomputed is not None
+            ),
         )
         global _attention_backends
         if is_in_onnx_export_mode():
@@ -1349,6 +1368,7 @@
             use_flash_attention = False
             use_fused_attention = False
             use_unfused_attention = True
+            use_tree_attention = False
         else:
             if (
                 _attention_backends["attention_params"] is None
@@ -1364,6 +1384,7 @@
                     fused_attention_backend,
                     use_unfused_attention,
                     _,
+                    use_tree_attention,
                 ) = dpa_utils.get_attention_backend(attention_params)
                 # Set global _attention_backends var using return value
                 # from get_attention_backend()
@@ -1372,8 +1393,11 @@
                 _attention_backends["use_fused_attention"] = use_fused_attention
                 _attention_backends["fused_attention_backend"] = fused_attention_backend
                 _attention_backends["use_unfused_attention"] = use_unfused_attention
+                _attention_backends["use_tree_attention"] = use_tree_attention
                 _attention_backends["backend_selection_requires_update"] = False
-                if use_flash_attention:
+                if use_tree_attention:
+                    self.logger.info("Running with TreeFlashAttention backend")
+                elif use_flash_attention:
                     self.logger.info(
                         "Running with FlashAttention backend (version %s)",
                         flash_attention_backend,
@@ -1391,9 +1415,15 @@
             use_fused_attention = _attention_backends["use_fused_attention"]
             fused_attention_backend = _attention_backends["fused_attention_backend"]
             use_unfused_attention = _attention_backends["use_unfused_attention"]
+            use_tree_attention = _attention_backends["use_tree_attention"]
 
         # raise exception if no backend is available
-        if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
+        if sum([
+            use_flash_attention,
+            use_fused_attention,
+            use_unfused_attention,
+            bool(use_tree_attention),
+        ]) == 0:
             raise ValueError(
                 "No dot product attention backend is available for the provided inputs. Please"
                 " run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for"
@@ -1407,6 +1437,16 @@
                 else None
             )
 
+        if use_tree_attention:
+            return self.tree_attention(
+                query_layer,
+                key_layer,
+                value_layer,
+                cu_node_lens=tree_cu_node_lens,
+                node_parent=tree_node_parent,
+                precomputed=tree_precomputed,
+            )
+
         if use_flash_attention:
             if core_attention_bias_type == "alibi":
                 alibi_slopes, _ = dpa_utils.get_alibi(
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 8c6b6afc90..9a3cf2173f 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -266,6 +266,11 @@ class AttentionParams:
     return_max_logit: bool = False
     cuda_graph: bool = False
     num_splits: int = 1
+    # Tree-attention dispatch flag. When True, get_attention_backend short-
+    # circuits to the TreeFlashAttention backend; the caller must also pass
+    # tree_cu_node_lens / tree_node_parent / tree_precomputed to
+    # DotProductAttention.forward at call time.
+    tree_attention: bool = False
 
     def __eq__(self, other):
         """
@@ -342,12 +347,50 @@ def get_attention_backend(
     return_max_logit = attention_params.return_max_logit
     cuda_graph = attention_params.cuda_graph
     num_splits = attention_params.num_splits
+    tree_attention = attention_params.tree_attention
 
     # Run config
     logger = logging.getLogger("DotProductAttention")
     logger.setLevel(AttentionLogging._log_level)
     if not logger.hasHandlers():
         logger.addHandler(AttentionLogging._stream_handler)
+
+    # Tree-attention early dispatch: when the caller advertises tree mode we
+    # bypass the normal Flash/Fused/Unfused selection and route directly to
+    # TreeFlashAttention. Tree requires `thd` layout, no CP (Stage 2+), and
+    # FA3 to be installed; the TreeFlashAttention backend itself rechecks
+    # runtime-specific preconditions (FA3 wheel importability, etc).
+    if tree_attention:
+        tree_kill_switch = int(os.getenv("NVTE_TREE_ATTN", "1")) == 0
+        if tree_kill_switch:
+            logger.debug("NVTE_TREE_ATTN=0; tree_attention disabled")
+            return (
+                False, None,
+                False, None,
+                False,
+                [False, False, False],
+                False,
+            )
+        if not qkv_layout.startswith("thd"):
+            raise ValueError(
+                f"tree_attention requires qkv_layout starting with 'thd' "
+                f"(got '{qkv_layout}')"
+            )
+        if context_parallel:
+            raise ValueError(
+                "tree_attention does not yet support context parallelism. "
+                "CP-aware tree attention is tracked as Stage 2+ of the "
+                "tree-training migration plan."
+            )
+        logger.debug("Selected backend = TreeFlashAttention")
+        return (
+            False, None,
+            False, None,
+            False,
+            [False, False, False],
+            True,
+        )
+
     device_compute_capability = get_device_compute_capability()
     cudnn_version = get_cudnn_version()
     run_config = {
@@ -1166,6 +1209,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         fused_attention_backend,
         use_unfused_attention,
         available_backends,
+        False,  # use_tree_attention — tree path uses the early return above
     )

From ebfa49a6521e139a6d87dbad6e3f92df85f3c14e Mon Sep 17 00:00:00 2001
From: racky-scitix
Date: Thu, 16 Apr 2026 11:10:01 +0800
Subject: [PATCH 3/4] refactor(attention): accept tree_metadata directly
 instead of three kwargs

DotProductAttention.forward now takes a single tree_metadata=None kwarg
(the Megatron TreeMetadata object) instead of separate tree_cu_node_lens
/ tree_node_parent / tree_precomputed kwargs. Unpacking happens inside
TE at the dispatch site.

This lets Megatron's TEDotProductAttention pass tree_metadata through
the existing packed_seq_kwargs dict without any tree-specific code on
the Megatron side.
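
Any object exposing the three attributes TE unpacks (cu_node_lens,
node_parent, precomputed) satisfies the new kwarg; a minimal stand-in
(the real Megatron-side TreeMetadata may differ) looks like:

    from dataclasses import dataclass
    from typing import Any

    import torch

    @dataclass
    class TreeMetadata:
        cu_node_lens: torch.Tensor   # int32 [num_nodes + 1]
        node_parent: torch.Tensor    # int32 [num_nodes], -1 for roots
        precomputed: Any             # opaque output of precompute_tree_metadata

    # Continuing the earlier sketch: one kwarg instead of three.
    out = core_attn(
        query_layer, key_layer, value_layer,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        tree_metadata=TreeMetadata(cu_node_lens, node_parent, precomputed),
    )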

Refs scitix/slime-trainer#220.
---
 .../dot_product_attention.py | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index 533bb0e164..468d0d06d9 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -831,9 +831,7 @@ def forward(
         pad_between_seqs: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
         num_splits: Optional[int] = 1,
-        tree_cu_node_lens: Optional[torch.Tensor] = None,
-        tree_node_parent: Optional[torch.Tensor] = None,
-        tree_precomputed: Optional[Dict[str, Any]] = None,
+        tree_metadata: Optional[Any] = None,
     ) -> torch.Tensor:
         r"""
         Dot Product Attention Layer.
@@ -1355,11 +1353,7 @@ def forward(
             return_max_logit=self.return_max_logit,
             cuda_graph=is_graph_capturing(),
             num_splits=num_splits,
-            tree_attention=(
-                tree_cu_node_lens is not None
-                and tree_node_parent is not None
-                and tree_precomputed is not None
-            ),
+            tree_attention=tree_metadata is not None,
         )
         global _attention_backends
         if is_in_onnx_export_mode():
@@ -1442,9 +1436,9 @@
                 query_layer,
                 key_layer,
                 value_layer,
-                cu_node_lens=tree_cu_node_lens,
-                node_parent=tree_node_parent,
-                precomputed=tree_precomputed,
+                cu_node_lens=tree_metadata.cu_node_lens,
+                node_parent=tree_metadata.node_parent,
+                precomputed=tree_metadata.precomputed,
             )
 
         if use_flash_attention:

From b293fc4de62c4d2abc65c2b35f7f39b161b2d083 Mon Sep 17 00:00:00 2001
From: racky-scitix <269434192+racky-scitix@users.noreply.github.com>
Date: Wed, 22 Apr 2026 10:56:21 +0800
Subject: [PATCH 4/4] feat(attention): support THD input shape in
 TreeFlashAttention

Megatron's packed-sequence path hands attention [T, H, D] tensors (batch
dim squeezed) rather than [S, B, H, D]. Detect the 3D case, route it
through the FA3 kernel via an unsqueezed [1, T, H, D], and return
[T, H*D] so the caller's output shape matches its input convention.
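
Shape contract after this change, as a sketch reusing the illustrative
topology tensors and precompute from the earlier examples:

    attn = TreeFlashAttention(softmax_scale=128**-0.5)

    # SBHD path (unchanged): [S, B=1, H, D] in -> [S, B, H*D] out
    q4 = torch.randn(24, 1, 16, 128, dtype=torch.bfloat16, device="cuda")
    out4 = attn(q4, q4.clone(), q4.clone(), cu_node_lens, node_parent, precomputed)
    assert out4.shape == (24, 1, 16 * 128)

    # THD path (new): [T, H, D] in -> [T, H*D] out
    q3 = torch.randn(24, 16, 128, dtype=torch.bfloat16, device="cuda")
    out3 = attn(q3, q3.clone(), q3.clone(), cu_node_lens, node_parent, precomputed)
    assert out3.shape == (24, 16 * 128)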

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 .../dot_product_attention/backends.py | 33 ++++++++++++++-----
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 7d17623ad9..f9fdff10ff 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -2048,18 +2048,33 @@ def forward(
     ) -> torch.Tensor:
         """Forward.
 
-        Inputs ``[S, B, H, D]`` (Megatron convention). Output
-        ``[S, B, H * D]`` to match the rest of the TE attention stack.
+        Accepts either SBHD ``[S, B, H, D]`` (Megatron convention) or
+        THD ``[T, H, D]`` (Megatron packed-sequence convention where
+        the batch dim has been squeezed). Output shape matches input
+        convention: ``[S, B, H*D]`` for SBHD, ``[T, H*D]`` for THD.
         """
-        # [S, B, H, D] -> [B, S, H, D]; B == 1 for tree training so this
-        # permute is a no-copy reshape.
-        q = query_layer.permute(1, 0, 2, 3).contiguous()
-        k = key_layer.permute(1, 0, 2, 3).contiguous()
-        v = value_layer.permute(1, 0, 2, 3).contiguous()
+        # Normalize to [B, T, H, D] for FA3
+        if query_layer.dim() == 3:
+            # THD [T, H, D] -> [1, T, H, D]
+            _thd_input = True
+            q = query_layer.unsqueeze(0).contiguous()
+            k = key_layer.unsqueeze(0).contiguous()
+            v = value_layer.unsqueeze(0).contiguous()
+        else:
+            # SBHD [S, B, H, D] -> [B, S, H, D]
+            _thd_input = False
+            q = query_layer.permute(1, 0, 2, 3).contiguous()
+            k = key_layer.permute(1, 0, 2, 3).contiguous()
+            v = value_layer.permute(1, 0, 2, 3).contiguous()
 
         out = _FA3TreeAttnFunc.apply(
             q, k, v, cu_node_lens, node_parent, self.softmax_scale, precomputed,
         )
 
-        s, b = query_layer.shape[0], query_layer.shape[1]
-        return out.permute(1, 0, 2, 3).contiguous().view(s, b, -1)
+        if _thd_input:
+            # [1, T, H, D] -> [T, H*D]
+            return out.squeeze(0).reshape(query_layer.shape[0], -1)
+        else:
+            # [B, S, H, D] -> [S, B, H*D]
+            s, b = query_layer.shape[0], query_layer.shape[1]
+            return out.permute(1, 0, 2, 3).contiguous().view(s, b, -1)
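
Taken together, the series lets the Megatron-side wrapper stay a thin
delegation. A sketch under assumed names (the class below is
illustrative, not part of these patches):

    import transformer_engine.pytorch as te

    class TreeDotProductAttention(te.DotProductAttention):
        """Illustrative call-site wrapper; the real Megatron-side class differs."""

        def __init__(self, *args, cp_size: int = 1, **kwargs):
            # Tree attention is CP-unaware for now (Stage 2+), so the call site
            # enforces cp_size == 1 before delegating to TE.
            if cp_size != 1:
                raise ValueError("tree attention requires cp_size == 1")
            super().__init__(*args, **kwargs)

        def forward(self, *args, tree_metadata=None, **kwargs):
            # TE picks TreeFlashAttention whenever tree_metadata is not None;
            # otherwise the usual Flash/Fused/Unfused selection runs.
            return super().forward(*args, tree_metadata=tree_metadata, **kwargs)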