feat(attention): add TreeFlashAttention backend with dispatch integration #2
Open
racky-scitix wants to merge 4 commits into scitix-main
Conversation
Adds a TE-native TreeFlashAttention class + _FA3TreeAttnFunc autograd
wrapper that routes packed THD trie sequences through the FA3 tree
attention kernel (flash_attn_interface). The backend is self-contained
— it does not yet hook into get_attention_backend; consumers that
already know they want the tree path (e.g. slime-trainer's
Megatron-side TETreeDotProductAttention) can import and use it
directly.
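For illustration, a rough sketch of that direct-use path. The TreeFlashAttention class name comes from this PR; the import path, constructor argument, and tree-topology kwarg names below are assumptions, not the committed API.

```python
# Hypothetical direct-use sketch of the TreeFlashAttention backend.
# Import path, constructor arg, and tree kwargs are assumed for illustration.
import torch
from transformer_engine.pytorch.attention import TreeFlashAttention  # assumed path

heads, head_dim, total_tokens, num_nodes = 32, 128, 4096, 16

tree_attn = TreeFlashAttention(softmax_scale=head_dim ** -0.5)  # module only holds the scale

# Packed THD trie sequences: [total_tokens, heads, head_dim]
q = torch.randn(total_tokens, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Tree topology + precompute (names and shapes assumed): cumulative node token
# offsets and the parent index of each trie node (-1 for the root).
tokens_per_node = total_tokens // num_nodes
cu_node_lens = torch.arange(0, (num_nodes + 1) * tokens_per_node, tokens_per_node,
                            dtype=torch.int32, device="cuda")
node_parent = torch.arange(-1, num_nodes - 1, dtype=torch.int32, device="cuda")

out = tree_attn(q, k, v,
                cu_node_lens=cu_node_lens,
                node_parent=node_parent,
                precomputed=None)  # FA3 precompute blob, if any
```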
Follow-ups:
* extend AttentionParams with tree_attention + tree_metadata
* add use_tree_attention branch in get_attention_backend
* instantiate self.tree_attention in DotProductAttention.__init__
and dispatch from DotProductAttention.forward
For the initial tree-training native migration in scitix/slime-trainer,
the Megatron-side TETreeDotProductAttention calls _FA3TreeAttnFunc
directly. Once the dispatch plumbing above is in place, that path can
be slimmed to a thin TreeFlashAttention delegation.
cp_size == 1 enforced at the call-site layer (Megatron TETreeDPA).
NVTE_TREE_ATTN env switch reserved for kill-switch / debugging in
future TE-side dispatch work.
Refs scitix/slime-trainer#220.
feat(attention): integrate TreeFlashAttention into DotProductAttention dispatch

Make TreeFlashAttention a first-class backend alongside Flash/Fused/Unfused so consumers can request it through the normal DotProductAttention forward() surface instead of reaching into the backend class directly.

Changes:
- AttentionParams gains tree_attention: bool = False. When set, the caller is advertising that it wants the tree path and is passing the tree topology + precompute along with q/k/v.
- get_attention_backend short-circuits early when tree_attention=True: validates qkv_layout starts with 'thd', rejects context_parallel (Stage 2+ of the migration plan), honors NVTE_TREE_ATTN=0 as a kill switch, and returns a 7th tuple element use_tree_attention signalling the tree branch. The non-tree path now returns False in that slot so existing callers do not change their truthy checks.
- DotProductAttention.__init__ always instantiates self.tree_attention. The module itself only holds the softmax scale; the FA3 wheel is imported lazily inside its forward, so non-tree deployments pay no import cost.
- DotProductAttention.forward grows tree_cu_node_lens / tree_node_parent / tree_precomputed kwargs. Presence of all three sets tree_attention on the AttentionParams that feeds get_attention_backend, so the dispatch cache sees a distinct key and recomputes on the first tree call. The dispatch chain routes to self.tree_attention(...) ahead of the Flash/Fused/Unfused branches when use_tree_attention is True.
- _attention_backends cache gains a use_tree_attention slot.

This lets upstream wrappers (e.g. Megatron's TETreeDotProductAttention) delegate to TE via super().forward(...) instead of carrying their own FA3 autograd wrapper, which is also the prerequisite for adding CP to the tree backend without touching callers.

Refs scitix/slime-trainer#220.
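A simplified sketch of the dispatch shape this commit describes. Only tree_attention, use_tree_attention and NVTE_TREE_ATTN are taken from the commit message; the other return slots and the non-tree selection logic are placeholders.

```python
# Simplified sketch of the tree-aware backend selection; not the real TE code.
import os
from dataclasses import dataclass

@dataclass(frozen=True)  # frozen so it can serve as a dispatch-cache key
class AttentionParams:
    qkv_layout: str = "sbhd_sbhd_sbhd"
    context_parallel: bool = False
    tree_attention: bool = False  # new field added by this commit

def get_attention_backend(params: AttentionParams):
    use_flash = use_fused = use_unfused = False
    use_tree_attention = False
    if params.tree_attention:
        assert params.qkv_layout.startswith("thd"), "tree path needs packed THD layout"
        assert not params.context_parallel, "CP + tree attention is Stage 2+"
        # NVTE_TREE_ATTN=0 acts as a kill switch back to the normal backends.
        use_tree_attention = os.getenv("NVTE_TREE_ATTN", "1") != "0"
    if not use_tree_attention:
        use_flash = True  # placeholder for the existing Flash/Fused/Unfused selection
    # 7th slot carries the tree decision; non-tree callers simply see False here.
    return use_flash, use_fused, use_unfused, None, None, None, use_tree_attention
```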
refactor(attention): accept tree_metadata directly instead of three kwargs

DotProductAttention.forward now takes a single tree_metadata=None kwarg (the Megatron TreeMetadata object) instead of separate tree_cu_node_lens / tree_node_parent / tree_precomputed kwargs. Unpacking happens inside TE at the dispatch site. This lets Megatron's TEDotProductAttention pass tree_metadata through the existing packed_seq_kwargs dict without any tree-specific code on the Megatron side.

Refs scitix/slime-trainer#220.
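A hedged sketch of what the single-kwarg surface implies; the TreeMetadata field names below are assumptions, the real Megatron object may differ.

```python
# Sketch of the tree_metadata pass-through and in-TE unpacking.
from dataclasses import dataclass
from typing import Any, Optional
import torch

@dataclass
class TreeMetadata:
    cu_node_lens: torch.Tensor          # assumed field
    node_parent: torch.Tensor           # assumed field
    precomputed: Optional[Any] = None   # assumed field

def unpack_tree_metadata(tree_metadata: Optional[TreeMetadata]):
    """Presence of tree_metadata alone turns the tree path on; TE unpacks it
    at the dispatch site so callers never touch the individual pieces."""
    if tree_metadata is None:
        return False, None, None, None
    return (True,
            tree_metadata.cu_node_lens,
            tree_metadata.node_parent,
            tree_metadata.precomputed)
```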
feat(attention): support THD input shape in TreeFlashAttention

Megatron's packed-sequence path hands attention [T, H, D] tensors (batch dim squeezed) rather than [S, B, H, D]. Detect the 3D case, route it through the FA3 kernel via an unsqueezed [1, T, H, D], and return [T, H*D] so the caller's output shape matches its input convention.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
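A standalone sketch of the shape handling this commit describes; the helper names are assumed, not the actual code layout.

```python
# Shape-normalization sketch for the THD path described above.
import torch

def to_fa3_layout(x: torch.Tensor) -> tuple[torch.Tensor, bool]:
    """Accept Megatron's packed-seq [T, H, D] (batch squeezed) or a 4D layout."""
    if x.dim() == 3:                 # [T, H, D] -> [1, T, H, D] for the FA3 kernel
        return x.unsqueeze(0), True
    return x, False

def from_fa3_layout(out: torch.Tensor, was_thd: bool) -> torch.Tensor:
    """Return [T, H*D] when the input came in as [T, H, D], matching the caller."""
    if was_thd:
        out = out.squeeze(0)                     # [T, H, D]
        out = out.reshape(out.shape[0], -1)      # [T, H*D]
    return out
```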
Summary
Add a TreeFlashAttention backend to TransformerEngine for tree-structured RL rollouts (slime-trainer tree-native training). The branch is based on scitix-main (= NVIDIA upstream 769ed778), so the diff contains only the 4 tree-attention commits without upstream-sync noise.

Commits
- 5a8ec958 feat(attention): add TreeFlashAttention backend for tree-training
- deb33ef7 feat(attention): integrate TreeFlashAttention into DotProductAttention dispatch
- ebfa49a6 refactor(attention): accept tree_metadata directly instead of three kwargs
- b293fc4d feat(attention): support THD input shape in TreeFlashAttention

How it is used
Used from vendors/megatron-lm branch sirl-dev-tree-native via TETreeDotProductAttention, which delegates to TE's tree dispatch. The tree-training code lives under sirl/backends/megatron_utils/tree/ and is gated by --enable-tree-training.
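A hypothetical sketch of the consumer-side pass-through: tree_metadata rides along in the existing packed-sequence kwargs, so the wrapper needs no tree-specific plumbing. Names other than tree_metadata are illustrative, not Megatron's real API.

```python
# Hypothetical helper showing tree_metadata riding in the packed-seq kwargs.
def build_packed_seq_kwargs(packed_seq_params, tree_metadata=None):
    kwargs = {
        "qkv_format": "thd",                                   # packed trie layout
        "cu_seqlens_q": packed_seq_params.cu_seqlens_q,        # illustrative fields
        "cu_seqlens_kv": packed_seq_params.cu_seqlens_kv,
    }
    if tree_metadata is not None:
        kwargs["tree_metadata"] = tree_metadata                # flips tree dispatch on in TE
    return kwargs

# TETreeDotProductAttention can then delegate via super().forward(..., **kwargs)
# instead of carrying its own FA3 autograd wrapper.
```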
Verification

- rel_max = 1.96e-08, rel_mean = 2.32e-10.
- rel_err in the 1–3e-4 band (2.35e-04).
- actor_train speedup vs varlen: 1.07× (pg1) / 3.93× (pg4-intra) / 7.74× (pg8-intra).
- See docs/analysis/analysis_tree_attention_benchmark.md in slime-trainer.
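For reference, a minimal sketch of how relative-error metrics like the ones quoted above are typically computed; the exact definitions used in the benchmark may differ.

```python
# Assumed definitions of rel_max / rel_mean against a reference implementation.
import torch

def rel_err_stats(out: torch.Tensor, ref: torch.Tensor):
    denom = ref.abs().clamp_min(1e-12)          # avoid division by zero
    rel = (out - ref).abs() / denom
    return rel.max().item(), rel.mean().item()  # rel_max, rel_mean
```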
Notes

- The branch tip is b293fc4d, but the base was rebased from main to scitix-main so only the 4 tree commits show up in the diff.
- ebfa49a6 predates the last commit; rebuilding from b293fc4d is needed to pick up the THD-input-shape support.