
feat(attention): add TreeFlashAttention backend with dispatch integration #2

Open

racky-scitix wants to merge 4 commits into scitix-main from tree-training-native

Conversation

@racky-scitix
Collaborator

Summary

Add a TreeFlashAttention backend to TransformerEngine for tree-structured RL rollouts (slime-trainer tree-native training). The branch is based on scitix-main (= NVIDIA upstream 769ed778), so the diff contains only the 4 tree-attention commits, without upstream-sync noise.

Commits

  • 5a8ec958 feat(attention): add TreeFlashAttention backend for tree-training
  • deb33ef7 feat(attention): integrate TreeFlashAttention into DotProductAttention dispatch
  • ebfa49a6 refactor(attention): accept tree_metadata directly instead of three kwargs
  • b293fc4d feat(attention): support THD input shape in TreeFlashAttention

How it is used

  • Consumed by vendors/megatron-lm branch sirl-dev-tree-native via TETreeDotProductAttention, which delegates to TE's tree dispatch (see the sketch after this list).
  • slime-trainer side lives under sirl/backends/megatron_utils/tree/ and is gated by --enable-tree-training.
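
A minimal sketch of that delegation, assuming the tree_metadata kwarg this PR adds; the import path and TEDotProductAttention base are stand-ins for the actual Megatron-side classes, not verbatim API:

```python
# Hypothetical sketch: the import path and base class are assumptions.
from megatron.core.extensions.transformer_engine import TEDotProductAttention

class TETreeDotProductAttention(TEDotProductAttention):
    """Thin wrapper: TE's tree dispatch (this PR) does the tree-specific work."""

    def forward(self, query, key, value, tree_metadata=None, **kwargs):
        # tree_metadata rides through to TE's DotProductAttention.forward,
        # which unpacks it at the dispatch site (commit ebfa49a6).
        return super().forward(query, key, value, tree_metadata=tree_metadata, **kwargs)
```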

Verification

  • Forward agrees with the varlen baseline to floating-point noise: rel_max = 1.96e-08, rel_mean = 2.32e-10 (metric sketched after this list).
  • Backward rel_err sits in the 1e-4 to 3e-4 band (measured 2.35e-04).
  • actor_train speedup vs varlen: 1.07× (pg1) / 3.93× (pg4-intra) / 7.74× (pg8-intra).
  • Full tables: see docs/analysis/analysis_tree_attention_benchmark.md in slime-trainer.
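
The agreement numbers above follow the usual relative-error definition; a minimal sketch of the metric (tensor names and the epsilon guard are illustrative, not the benchmark's verbatim code):

```python
import torch

def rel_errors(out_tree: torch.Tensor, out_varlen: torch.Tensor):
    """Return (rel_max, rel_mean) of the tree output against the varlen baseline."""
    diff = (out_tree.float() - out_varlen.float()).abs()
    denom = out_varlen.float().abs().clamp_min(1e-12)  # guard divide-by-zero
    rel = diff / denom
    return rel.max().item(), rel.mean().item()
```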

Notes

MaoChouHJM and others added 4 commits on April 15, 2026.

5a8ec958 feat(attention): add TreeFlashAttention backend for tree-training

Adds a TE-native TreeFlashAttention class + _FA3TreeAttnFunc autograd
wrapper that routes packed THD trie sequences through the FA3 tree
attention kernel (flash_attn_interface). The backend is self-contained
— it does not yet hook into get_attention_backend; consumers that
already know they want the tree path (e.g. slime-trainer's
Megatron-side TETreeDotProductAttention) can import and use it
directly.

Follow-ups:
  * extend AttentionParams with tree_attention + tree_metadata
  * add use_tree_attention branch in get_attention_backend
  * instantiate self.tree_attention in DotProductAttention.__init__
    and dispatch from DotProductAttention.forward

For the initial tree-training native migration in scitix/slime-trainer,
the Megatron-side TETreeDotProductAttention calls _FA3TreeAttnFunc
directly. Once the dispatch plumbing above is in place, that path can
be slimmed to a thin TreeFlashAttention delegation.

cp_size == 1 is enforced at the call-site layer (Megatron TETreeDPA). The
NVTE_TREE_ATTN env switch is reserved as a kill switch / debugging aid for
future TE-side dispatch work.

Refs scitix/slime-trainer#220.
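
A minimal sketch of the autograd-wrapper shape described above; the FA3 entry points tree_attn_fwd / tree_attn_bwd and their signatures are assumptions standing in for whatever the patched flash_attn_interface wheel actually exposes:

```python
import torch

class _FA3TreeAttnFunc(torch.autograd.Function):
    """Routes packed THD trie sequences through the FA3 tree kernel (sketch)."""

    @staticmethod
    def forward(ctx, q, k, v, cu_node_lens, node_parent, precomputed, softmax_scale):
        import flash_attn_interface as fai  # lazy: non-tree runs never import FA3
        out, softmax_lse = fai.tree_attn_fwd(  # hypothetical entry point
            q, k, v, cu_node_lens, node_parent, precomputed,
            softmax_scale=softmax_scale,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_node_lens, node_parent)
        ctx.softmax_scale = softmax_scale
        return out

    @staticmethod
    def backward(ctx, dout):
        import flash_attn_interface as fai
        q, k, v, out, softmax_lse, cu_node_lens, node_parent = ctx.saved_tensors
        dq, dk, dv = fai.tree_attn_bwd(  # hypothetical entry point
            dout, q, k, v, out, softmax_lse, cu_node_lens, node_parent,
            softmax_scale=ctx.softmax_scale,
        )
        # One gradient per forward input; non-tensor args get None.
        return dq, dk, dv, None, None, None, None
```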
deb33ef7 feat(attention): integrate TreeFlashAttention into DotProductAttention dispatch

Make TreeFlashAttention a first-class backend alongside Flash/Fused/
Unfused so consumers can request it through the normal DotProductAttention
forward() surface instead of reaching into the backend class directly.

Changes:

- AttentionParams gains tree_attention: bool = False. When set, the
  caller is advertising that it wants the tree path and is passing the
  tree topology + precompute along with q/k/v.
- get_attention_backend short-circuits early when tree_attention=True:
  validates qkv_layout starts with 'thd', rejects context_parallel
  (Stage 2+ of the migration plan), honors NVTE_TREE_ATTN=0 as a kill
  switch, and returns a 7th tuple element use_tree_attention signalling
  the tree branch. The non-tree path now returns False in that slot so
  existing callers do not change their truthy checks.
- DotProductAttention.__init__ always instantiates self.tree_attention.
  The module itself only holds the softmax scale; the FA3 wheel is
  imported lazily inside its forward, so non-tree deployments pay no
  import cost.
- DotProductAttention.forward grows tree_cu_node_lens / tree_node_parent
  / tree_precomputed kwargs. Presence of all three sets tree_attention
  on the AttentionParams that feeds get_attention_backend, so the
  dispatch cache sees a distinct key and recomputes on the first tree
  call. The dispatch chain routes to self.tree_attention(...) ahead of
  the Flash/Fused/Unfused branches when use_tree_attention is True.
- _attention_backends cache gains a use_tree_attention slot.

This lets upstream wrappers (e.g. Megatron's TETreeDotProductAttention)
delegate to TE via super().forward(...) instead of carrying their own
FA3 autograd wrapper, which is also the prerequisite for adding CP to
the tree backend without touching callers.

Refs scitix/slime-trainer#220.
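
A minimal sketch of the short-circuit described in the list above; _select_flash_fused_unfused stands in for the existing selection logic, and the contents of the first six tuple slots are paraphrased rather than copied from TE:

```python
import os

def _select_flash_fused_unfused(params):
    """Stand-in for the existing Flash/Fused/Unfused selection (six slots)."""
    return (True, False, None, False, None, None)

def get_attention_backend(params):
    # Tree short-circuit, gated by the NVTE_TREE_ATTN kill switch.
    if getattr(params, "tree_attention", False) and os.getenv("NVTE_TREE_ATTN", "1") != "0":
        assert params.qkv_layout.startswith("thd"), \
            "tree attention requires a packed THD layout"
        assert not params.context_parallel, \
            "CP + tree attention is Stage 2+ of the migration plan"
        # New seventh slot: use_tree_attention = True; other backends off.
        return (False, False, None, False, None, None, True)
    # Non-tree path keeps the seventh slot False so existing truthy checks
    # on the first six elements are unchanged.
    return (*_select_flash_fused_unfused(params), False)
```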
ebfa49a6 refactor(attention): accept tree_metadata directly instead of three kwargs

DotProductAttention.forward now takes a single tree_metadata=None kwarg
(the Megatron TreeMetadata object) instead of separate tree_cu_node_lens
/ tree_node_parent / tree_precomputed kwargs. Unpacking happens inside
TE at the dispatch site. This lets Megatron's TEDotProductAttention pass
tree_metadata through the existing packed_seq_kwargs dict without any
tree-specific code on the Megatron side.

Refs scitix/slime-trainer#220.
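
A minimal sketch of the dispatch-site unpacking; the TreeMetadata attribute names (cu_node_lens, node_parent, precomputed) are guesses for illustration, not confirmed fields of the Megatron object:

```python
def _unpack_tree_metadata(tree_metadata):
    """Split the Megatron TreeMetadata object back into the three tensors
    the tree backend consumes (attribute names are assumptions)."""
    if tree_metadata is None:
        return None, None, None
    return (
        tree_metadata.cu_node_lens,  # cumulative node lengths (trie layout)
        tree_metadata.node_parent,   # parent index per trie node
        tree_metadata.precomputed,   # precomputed kernel metadata
    )

# Inside DotProductAttention.forward (sketch):
#   cu_node_lens, node_parent, precomputed = _unpack_tree_metadata(tree_metadata)
#   tree_attention = tree_metadata is not None  # feeds AttentionParams / cache key
```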
b293fc4d feat(attention): support THD input shape in TreeFlashAttention

Megatron's packed-sequence path hands attention [T, H, D] tensors (batch
dim squeezed) rather than [S, B, H, D]. Detect the 3D case, route it
through the FA3 kernel via an unsqueezed [1, T, H, D], and return
[T, H*D] so the caller's output shape matches its input convention.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
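
A minimal sketch of that shape handling; run_fa3_tree_kernel is a hypothetical stand-in for the actual FA3 call:

```python
import torch

def _forward_thd(q, k, v, run_fa3_tree_kernel):
    """Accept [T, H, D] (Megatron packed path) or pre-batched input (sketch)."""
    squeezed = q.dim() == 3
    if squeezed:
        # [T, H, D] -> [1, T, H, D] for the kernel's batched convention.
        q, k, v = (t.unsqueeze(0) for t in (q, k, v))
    out = run_fa3_tree_kernel(q, k, v)  # [1, T, H, D]
    if squeezed:
        _, t, h, d = out.shape
        out = out.squeeze(0).reshape(t, h * d)  # caller expects [T, H*D]
    return out
```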