125 changes: 125 additions & 0 deletions transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -1953,3 +1953,128 @@ def forward(
return output[0].view(*output[0].shape[:-2], -1), output[1]
# ...hd -> ...(hd)
return output.view(*output.shape[:-2], -1)


class _FA3TreeAttnFunc(torch.autograd.Function):
"""Autograd wrapper for the FA3 tree attention kernel.

Forward calls ``flash_attn_interface.flash_attn_tree_func``; backward
calls ``flash_attn_tree_bwd_func``. Tree topology
(``cu_node_lens`` / ``node_parent``) and the output of
``precompute_tree_metadata`` are saved into the autograd context so
both the forward and backward legs see a consistent view of the trie.
"""

@staticmethod
def forward(ctx, q, k, v, cu_node_lens, node_parent, softmax_scale, precomputed):
from flash_attn_interface import flash_attn_tree_func

out, lse = flash_attn_tree_func(
q, k, v, cu_node_lens, node_parent,
softmax_scale=softmax_scale,
tree_metadata=precomputed,
)
ctx.save_for_backward(q, k, v, out, lse, cu_node_lens, node_parent)
ctx.softmax_scale = softmax_scale
ctx.precomputed = precomputed
return out

@staticmethod
def backward(ctx, dout):
from flash_attn_interface import flash_attn_tree_bwd_func

q, k, v, out, lse, cu_node_lens, node_parent = ctx.saved_tensors
dq, dk, dv = flash_attn_tree_bwd_func(
dout, q, k, v, out, lse,
cu_node_lens, node_parent,
softmax_scale=ctx.softmax_scale,
tree_metadata=ctx.precomputed,
)
return dq, dk, dv, None, None, None, None


class TreeFlashAttention(torch.nn.Module):
"""Tree-attention backend for packed THD sequences whose tokens form a trie.

Routes through the FA3 tree attention kernel from
``flash_attn_interface``, which lets N sibling sequences share the
attention computation over their common prefix instead of repeating it
per sequence. Requires a Hopper-class GPU and the FA3 wheel.

Inputs:
* ``query / key / value``: ``[B, S, H, D]`` with ``B == 1`` for the
tree path.
* ``cu_node_lens``: ``int32`` ``[num_nodes + 1]`` cumulative per-node
token counts (prefix sum over node lengths).
* ``node_parent``: ``int32`` ``[num_nodes]`` parent index per node
(``-1`` for roots).
* ``precomputed``: opaque ``dict`` from
``flash_attn_interface.precompute_tree_metadata`` — reused across
all transformer layers in a forward / backward pass.

``cp_size`` must be 1; CP-aware tree attention is tracked as a
follow-up. Dispatch goes through ``get_attention_backend``: setting the
``tree_attention`` field on ``AttentionParams`` (done automatically when
``tree_metadata`` is passed to ``DotProductAttention.forward``)
short-circuits backend selection to this module. Consumers that already
know they want the tree path (e.g. Megatron's
``TETreeDotProductAttention``) can also instantiate
``TreeFlashAttention`` directly.

Env switch: ``NVTE_TREE_ATTN=0`` disables this backend for callers that
respect it (kill-switch for debugging).
"""

def __init__(
self,
softmax_scale: float,
attention_type: str = "self",
layer_number: int | None = None,
) -> None:
super().__init__()
if attention_type != "self":
raise ValueError("TreeFlashAttention only supports self-attention.")
self.softmax_scale = softmax_scale
self.attention_type = attention_type
self.layer_number = 1 if layer_number is None else layer_number

def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
cu_node_lens: torch.Tensor,
node_parent: torch.Tensor,
precomputed,
) -> torch.Tensor:
"""Forward.

Accepts either SBHD ``[S, B, H, D]`` (Megatron convention) or
THD ``[T, H, D]`` (Megatron packed-sequence convention where
the batch dim has been squeezed). Output shape matches input
convention: ``[S, B, H*D]`` for SBHD, ``[T, H*D]`` for THD.
"""
# Normalize to [B, T, H, D] for FA3
if query_layer.dim() == 3:
# THD [T, H, D] -> [1, T, H, D]
_thd_input = True
q = query_layer.unsqueeze(0).contiguous()
k = key_layer.unsqueeze(0).contiguous()
v = value_layer.unsqueeze(0).contiguous()
else:
# SBHD [S, B, H, D] -> [B, S, H, D]
_thd_input = False
q = query_layer.permute(1, 0, 2, 3).contiguous()
k = key_layer.permute(1, 0, 2, 3).contiguous()
v = value_layer.permute(1, 0, 2, 3).contiguous()

out = _FA3TreeAttnFunc.apply(
q, k, v, cu_node_lens, node_parent, self.softmax_scale, precomputed,
)

if _thd_input:
# [1, T, H, D] -> [T, H*D]
return out.squeeze(0).reshape(query_layer.shape[0], -1)
else:
# [B, S, H, D] -> [S, B, H*D]
s, b = query_layer.shape[0], query_layer.shape[1]
return out.permute(1, 0, 2, 3).contiguous().view(s, b, -1)
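
For reference, a minimal usage sketch of the backend above with a packed THD
input. The tensor shapes, the toy trie, and the exact signature of
``precompute_tree_metadata`` are illustrative assumptions; running it also
requires a Hopper-class GPU with the FA3 wheel installed.

import torch
import flash_attn_interface
from transformer_engine.pytorch.attention.dot_product_attention.backends import (
    TreeFlashAttention,
)

# Packed THD tensors: 8 tokens total, 16 heads, head_dim 64.
q = torch.randn(8, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Toy trie: a 4-token root shared by two 2-token children.
cu_node_lens = torch.tensor([0, 4, 6, 8], dtype=torch.int32, device="cuda")
node_parent = torch.tensor([-1, 0, 0], dtype=torch.int32, device="cuda")
# Assumed call signature; reuse the result across all layers of a step.
precomputed = flash_attn_interface.precompute_tree_metadata(cu_node_lens, node_parent)

attn = TreeFlashAttention(softmax_scale=64 ** -0.5)
out = attn(q, k, v, cu_node_lens, node_parent, precomputed)  # [T, H*D] == [8, 1024]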
(second changed file)
@@ -59,6 +59,7 @@
UnfusedDotProductAttention,
FusedAttention,
FlashAttention,
TreeFlashAttention,
)


@@ -73,6 +74,7 @@
"use_fused_attention": None,
"fused_attention_backend": None,
"use_unfused_attention": None,
"use_tree_attention": None,
"backend_selection_requires_update": False,
}

@@ -470,6 +472,15 @@ def __init__(
return_max_logit=self.return_max_logit,
)

# Tree-attention backend. Lightweight module that holds the softmax
# scale; the FA3 wheel is imported lazily inside its forward, so
# non-tree deployments pay zero import cost.
self.tree_attention = TreeFlashAttention(
softmax_scale,
attention_type=attention_type,
layer_number=layer_number,
)

def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
Temporarily remove core_attention._extra_state as a missing key
@@ -820,6 +831,7 @@ def forward(
pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
num_splits: Optional[int] = 1,
tree_metadata: Optional[Any] = None,
) -> torch.Tensor:
r"""
Dot Product Attention Layer.
@@ -1341,6 +1353,7 @@ def forward(
return_max_logit=self.return_max_logit,
cuda_graph=is_graph_capturing(),
num_splits=num_splits,
tree_attention=tree_metadata is not None,
)
global _attention_backends
if is_in_onnx_export_mode():
@@ -1349,6 +1362,7 @@
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = True
use_tree_attention = False
else:
if (
_attention_backends["attention_params"] is None
@@ -1364,6 +1378,7 @@
fused_attention_backend,
use_unfused_attention,
_,
use_tree_attention,
) = dpa_utils.get_attention_backend(attention_params)
# Set global _attention_backends var using return value
# from get_attention_backend()
@@ -1372,8 +1387,11 @@
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["use_tree_attention"] = use_tree_attention
_attention_backends["backend_selection_requires_update"] = False
if use_flash_attention:
if use_tree_attention:
self.logger.info("Running with TreeFlashAttention backend")
elif use_flash_attention:
self.logger.info(
"Running with FlashAttention backend (version %s)",
flash_attention_backend,
@@ -1391,9 +1409,15 @@
use_fused_attention = _attention_backends["use_fused_attention"]
fused_attention_backend = _attention_backends["fused_attention_backend"]
use_unfused_attention = _attention_backends["use_unfused_attention"]
use_tree_attention = _attention_backends["use_tree_attention"]

# raise exception if no backend is available
if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
if sum([
use_flash_attention,
use_fused_attention,
use_unfused_attention,
bool(use_tree_attention),
]) == 0:
raise ValueError(
"No dot product attention backend is available for the provided inputs. Please"
" run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for"
@@ -1407,6 +1431,16 @@
else None
)

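# tree_metadata can be any object exposing .cu_node_lens, .node_parent and
# .precomputed (see the dispatch below). A hypothetical caller-side container:
#     @dataclass
#     class TreeMetadata:
#         cu_node_lens: torch.Tensor  # int32 [num_nodes + 1]
#         node_parent: torch.Tensor   # int32 [num_nodes]
#         precomputed: dict           # precompute_tree_metadata() output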
if use_tree_attention:
return self.tree_attention(
query_layer,
key_layer,
value_layer,
cu_node_lens=tree_metadata.cu_node_lens,
node_parent=tree_metadata.node_parent,
precomputed=tree_metadata.precomputed,
)

if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
(third changed file)
@@ -266,6 +266,11 @@ class AttentionParams:
return_max_logit: bool = False
cuda_graph: bool = False
num_splits: int = 1
# Tree-attention dispatch flag. When True, get_attention_backend short-
# circuits to the TreeFlashAttention backend; the caller must also pass
# tree_metadata (carrying cu_node_lens / node_parent / precomputed) to
# DotProductAttention.forward at call time.
tree_attention: bool = False

def __eq__(self, other):
"""
@@ -342,12 +347,50 @@ def get_attention_backend(
return_max_logit = attention_params.return_max_logit
cuda_graph = attention_params.cuda_graph
num_splits = attention_params.num_splits
tree_attention = attention_params.tree_attention

# Run config
logger = logging.getLogger("DotProductAttention")
logger.setLevel(AttentionLogging._log_level)
if not logger.hasHandlers():
logger.addHandler(AttentionLogging._stream_handler)

# Tree-attention early dispatch: when the caller advertises tree mode, we
# bypass the normal Flash/Fused/Unfused selection and route directly to
# TreeFlashAttention. Tree mode requires a `thd` layout and no CP (CP-aware
# tree attention is planned for Stage 2+); runtime preconditions such as
# FA3 wheel importability are rechecked by the TreeFlashAttention backend
# itself.
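# Return-tuple layout for the early returns below matches the normal path:
#     (use_flash_attention, flash_attention_backend,
#      use_fused_attention, fused_attention_backend,
#      use_unfused_attention, available_backends, use_tree_attention)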
if tree_attention:
tree_kill_switch = int(os.getenv("NVTE_TREE_ATTN", "1")) == 0
if tree_kill_switch:
logger.debug("NVTE_TREE_ATTN=0; tree_attention disabled")
return (
False, None,
False, None,
False,
[False, False, False],
False,
)
if not qkv_layout.startswith("thd"):
raise ValueError(
f"tree_attention requires qkv_layout starting with 'thd' "
f"(got '{qkv_layout}')"
)
if context_parallel:
raise ValueError(
"tree_attention does not yet support context parallelism. "
"CP-aware tree attention is tracked as Stage 2+ of the "
"tree-training migration plan."
)
logger.debug("Selected backend = TreeFlashAttention")
return (
False, None,
False, None,
False,
[False, False, False],
True,
)

device_compute_capability = get_device_compute_capability()
cudnn_version = get_cudnn_version()
run_config = {
@@ -1166,6 +1209,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
fused_attention_backend,
use_unfused_attention,
available_backends,
False, # use_tree_attention — tree path uses the early return above
)

