Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,34 @@ ROCm TE provides the compile-time env NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAU
* 3 - standard asm, default;
* 4 - rta_asm.

AITER Native Split-K Forward (gfx942 only)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
On gfx942, the CK fused attention path can optionally dispatch the *forward* pass to AITER's
hand-written native split-K (Flash-Decoding) kernel, which splits the work along the key/value
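sequence dimension to keep the GPU busy when a single attention problem does not. This is
controlled by a runtime environment variable, which is read once when the attention backend
module is imported and must therefore be set before Transformer Engine is imported: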

* NVTE_FUSED_ATTN_SPLITKV - by default 0 (disabled). When set to 1, eligible CK FusedAttention
forward calls are routed to AITER's native split-K kernel, which picks the number of splits with
its built-in occupancy heuristic.

When to use it:

* The benefit comes from *under-subscribed, long-KV* shapes - small ``batch x num_heads`` with a
large ``seqlen_kv`` (e.g. long-context prefill or decode) - where the standard kernel leaves
compute units idle. Splitting the KV dimension across more workgroups fills the machine.
* For already-saturated shapes (large ``batch x num_heads``) there is little to gain; AITER's
heuristic typically declines to split, so leaving the flag on is low-risk but offers no benefit
there.
* It is forward-only and affects only the ``FusedAttention`` / ``DotProductAttention`` module path.
Training backward and the unfused path are unchanged.

The divert engages only when a call is eligible for the native kernel: gfx942, dense ``bshd`` /
``sbhd`` layout (no ``thd``/varlen), bf16, head dim 64, no bias/ALiBi/sliding-window/dropout/
attention-sink/FP8/context-parallel/KV-cache, and, for causal masking, ``seqlen_kv >= seqlen_q``.
Non-eligible calls fall back to the standard CK kernel unchanged. This requires an AITER build that
includes the native split-K kernel; if it is unavailable the flag is ignored with a warning.

Experimental Triton Kernels on ROCm
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Most CUDA kernels in Transformer Engine are hipified to run on ROCm. While the hipified CUDA kernels are functional, they are not necessarily optimal on ROCm.
Expand Down
59 changes: 58 additions & 1 deletion tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def reset_global_fp8_state():
def reset_attn_backend():
env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN",
"NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON",
"NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"])
"NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD",
"NVTE_FUSED_ATTN_SPLITKV"])
yield

# Define F16 data types to test
Expand Down Expand Up @@ -480,6 +481,62 @@ def test_dpa_num_splits(dtype, model_configs, model):
)


# ROCm: configs matching the exact gate that AITER's native split-K (Flash-Decoding)
# forward was wired in for -- dense bf16, head_dim 64, no_bias, no sliding window, no
# dropout, no_mask/causal. See _aiter_splitkv_eligible in backends.py.
# Consumed by test_dpa_splitkv, which parametrizes over the keys of this dict.
model_configs_splitkv = {
    # test: ModelConfig(b, sq, hq, dqk)
    # KV sequence (4096) is longer than the query sequence (2048) in this entry.
    "splitkv_1_0": ModelConfig(2, 2048, 16, 64, max_seqlen_kv=4096), # no_mask, bshd
    # Causal entry; max_seqlen_kv is not overridden here.
    "splitkv_1_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), # causal, sbhd
}


@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER split-K is a ROCm-only path.")
@pytest.mark.parametrize("dtype", param_types_lean)  # split-K forward is bf16 only
@pytest.mark.parametrize("model_configs", [model_configs_splitkv])
@pytest.mark.parametrize("model", model_configs_splitkv.keys())
def test_dpa_splitkv(dtype, model_configs, model):
    """Run DotProductAttention with AITER's native split-K forward switched on.

    The test forces the split-K interception on in the ``backends`` module,
    delegates the actual numerical comparison to ``test_dot_product_attention``,
    and always restores the module state and the environment afterwards.

    It is skipped when ``aiter.ops.mha`` cannot be imported or when its
    ``flash_attn_func`` does not accept a ``num_splits`` argument. Per the gate
    in ``backends``, the divert only fires on gfx942; on other devices the
    config still runs through the regular FusedAttention path.
    """
    import inspect
    from transformer_engine.pytorch.attention.dot_product_attention import backends

    # Skip early if the installed AITER cannot provide the split-K forward.
    try:
        from aiter.ops.mha import flash_attn_func as aiter_fwd
    except ImportError:
        pytest.skip("aiter.ops.mha split-K forward is unavailable.")
    accepted_params = inspect.signature(aiter_fwd).parameters
    if "num_splits" not in accepted_params:
        pytest.skip("aiter split-K forward predates num_splits (AITER PR #3581).")

    # backends reads NVTE_FUSED_ATTN_SPLITKV when it is imported, so setting the
    # variable now is not enough: patch the module globals directly and remember
    # the previous values so they can be put back.
    previous_state = (
        backends._aiter_splitkv_flash_attn_func,
        backends._use_aiter_splitkv,
    )
    os.environ["NVTE_FUSED_ATTN_SPLITKV"] = "1"
    backends._aiter_splitkv_flash_attn_func = aiter_fwd
    backends._use_aiter_splitkv = True
    try:
        test_dot_product_attention(
            dtype, model_configs, model, False, False, None, False, False
        )
    finally:
        # Undo the patching even if the delegated test fails.
        (
            backends._aiter_splitkv_flash_attn_func,
            backends._use_aiter_splitkv,
        ) = previous_state
        os.environ.pop("NVTE_FUSED_ATTN_SPLITKV", None)


model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from contextlib import nullcontext
from importlib.metadata import version as get_pkg_version
from importlib.metadata import PackageNotFoundError
import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
Expand Down Expand Up @@ -116,6 +117,87 @@
fa_utils.version = PkgVersion("2.7.1") #masqurade as FA 2.7.1
fa_utils.set_flash_attention_version()
attn_log.fa_logger.info("Using AITER Triton for FlashAttn.")

# ROCm: optional AITER native split-K (Flash-Decoding) forward.
# Opt in with NVTE_FUSED_ATTN_SPLITKV=1. The variable is evaluated once, here, at
# import time. When the feature is active, FusedAttention.forward hands eligible
# dense bf16 head-dim-64 calls to aiter.ops.mha.flash_attn_func instead of the CK
# FusedAttention backend.
#
# Both globals keep their "disabled" defaults unless every check below passes.
_aiter_splitkv_flash_attn_func = None
_use_aiter_splitkv = False
if IS_HIP_EXTENSION and os.getenv("NVTE_FUSED_ATTN_SPLITKV", "0") == "1":
    # NOTE(review): the variable name is under discussion -- a reviewer suggested
    # NVTE_FLASH_ATTN_AITER_SPLITKV to line up with NVTE_FLASH_ATTN_AITER.
    try:
        from aiter.ops.mha import flash_attn_func as _aiter_splitkv_flash_attn_func
    except ImportError:
        # A failed import leaves the function global at its None default.
        attn_log.fa_logger.warning(
            "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha is unavailable;"
            " split-K interception disabled."
        )
    else:
        # Only AITER builds that include the native split-K forward (AITER PR #3581)
        # accept `num_splits`. Check the signature up front so an older build is
        # rejected here rather than failing with a TypeError on the first call.
        if "num_splits" not in inspect.signature(_aiter_splitkv_flash_attn_func).parameters:
            _aiter_splitkv_flash_attn_func = None
            attn_log.fa_logger.warning(
                "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha.flash_attn_func has no"
                " num_splits arg (AITER predates PR #3581); split-K interception disabled."
            )
        else:
            _use_aiter_splitkv = True
            attn_log.fa_logger.info("AITER native split-K forward enabled for FusedAttention.")


def _aiter_splitkv_eligible(
    q,
    v,
    qkv_format,
    attn_mask_type,
    core_attention_bias_type,
    window_size,
    dropout_p,
    fp8,
    context_parallel,
    inference_params,
    softmax_offset,
    max_seqlen_q,
    max_seqlen_kv,
):
    """Return whether a FusedAttention call may be diverted to AITER's native split-K forward.

    This is a conservative pre-filter meant to mirror aiter.ops.mha's
    can_impl_fmha_native gate, so that calls the native kernel cannot serve are
    never diverted. Every condition below must hold; the first one that fails
    returns False.

    Args:
        q: Query tensor; must be bf16 with a last dimension of 64.
        v: Value tensor; must be bf16 with a last dimension of 64.
        qkv_format: Layout string; only "bshd" and "sbhd" are accepted.
        attn_mask_type: Mask type; only "no_mask" and "causal" are accepted.
        core_attention_bias_type: Must be "no_bias".
        window_size: Sliding-window spec; must be None or equal to (-1, -1).
        dropout_p: Dropout probability; must be 0.0.
        fp8: Truthy when FP8 attention is requested (not supported).
        context_parallel: Truthy when context parallelism is active (not supported).
        inference_params: Must be None (no KV cache).
        softmax_offset: Must be None (no attention sink).
        max_seqlen_q: Maximum query sequence length.
        max_seqlen_kv: Maximum key/value sequence length.

    Returns:
        True if the call can be routed to the split-K forward, False otherwise.

    NOTE(review): the key tensor is not inspected here; this assumes k matches q in
    dtype, head dim and head count (no GQA/MQA check) -- confirm, or pass k in.
    NOTE(review): open question from review -- which backward backends has this
    divert been validated with?
    """
    if not _use_aiter_splitkv:
        return False
    if get_device_compute_capability() != (9, 4):  # gfx942 only
        return False
    if q.dtype != torch.bfloat16 or v.dtype != torch.bfloat16:
        return False
    if q.shape[-1] != 64 or v.shape[-1] != 64:
        return False
    if qkv_format not in ("bshd", "sbhd"):  # dense only; thd/varlen unsupported
        return False
    # Exact match on the two supported mask types. This also rejects every
    # "padding" variant and any other causal flavour.
    if attn_mask_type not in ("no_mask", "causal"):
        return False
    if core_attention_bias_type != "no_bias":  # excludes both bias and alibi
        return False
    # NOTE(review): callers may normalise a causal window to (-1, 0); that form is
    # rejected here -- confirm whether it should be treated as "no sliding window".
    if window_size is not None and tuple(window_size) != (-1, -1):  # no sliding window
        return False
    if dropout_p != 0.0:
        return False
    if fp8 or context_parallel or inference_params is not None:
        return False
    if softmax_offset is not None:  # no learnable sink
        return False
    # The divert hands AITER only a boolean `causal` flag and does not forward the
    # caller's diagonal alignment. Top-left and bottom-right causal masks coincide
    # only when both sequence lengths are equal, so causal calls are diverted only
    # in that case; anything else stays on the regular path to keep results
    # identical to the non-intercepted forward.
    if attn_mask_type == "causal" and max_seqlen_q != max_seqlen_kv:
        return False
    return True
Comment on lines +172 to +198

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eligibility doesn't gate on num_heads_q == num_heads_kv (GQA / MQA). The CK FusedAttention path supports unequal head counts; if aiter.ops.mha.flash_attn_func's native split-K path doesn't, GQA shapes will hit the divert and then either error or fall back inside aiter (losing the perf benefit). Worth either adding a q.shape[-2] != k.shape[-2] -> False check (and taking k into the signature for that), or documenting in the README + the docstring that GQA is intentionally let through and relies on aiter's internal self-gate.

Same comment applies to dtype/head-dim: the signature takes q, v but assumes k matches — passing and checking k would close the assumption.



try:
if fa_utils.use_aiter_triton:
raise PackageNotFoundError # skip version check for aiter triton
Expand Down Expand Up @@ -1941,6 +2023,52 @@ def forward(
if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None:
cu_seqlens_kv_padded = cu_seqlens_kv

# ROCm: opt-in interception (NVTE_FUSED_ATTN_SPLITKV=1). For eligible dense bf16
# head-dim-64 calls, route the forward to AITER's native split-K kernel instead
# of the CK FusedAttention path. aiter.ops.mha.flash_attn_func is its own
# autograd Function (handles its own backward) and self-gates / falls back to CK
# internally, so non-eligible shapes are unaffected. num_splits=0 => AITER picks
# the split count heuristically.
if _aiter_splitkv_eligible(
query_layer,
value_layer,
qkv_format,
attn_mask_type,
core_attention_bias_type,
window_size,
self.attention_dropout if self.training else 0.0,
fp8,
context_parallel,
inference_params,
softmax_offset,
max_seqlen_q,
max_seqlen_kv,
):
Comment on lines +2032 to +2046

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: the interception sits after lines 1949–2024 which build cu_seqlens_q/kv, cu_seqlens_*_padded, page_table, etc. — none of which the aiter path uses. Moving the eligibility check + divert above that block (right after qkv_format is computed at line 1958) would avoid the wasted work on every eligible call. Not critical, but on the hot path it's free to fix.

q, k, v = query_layer, key_layer, value_layer
if qkv_format == "sbhd":
q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sbhd, this materializes a full contiguous [b, s, h, d] copy of all three tensors on every eligible call. That copy can offset (or exceed) the perceived split-K speedup on the very shapes the README highlights as the target (under-subscribed long-KV decode-style). Worth either:

  • excluding sbhd from eligibility for now (bshd is the layout where split-K most cleanly helps), or
  • calling out the extra copy in the README's "When to use it" so users don't measure regressions on sbhd and conclude the flag is broken.

# AITER's autograd Function asserts return_lse=True whenever it will build a
# backward graph -- it saves the LSE for the backward pass -- and in that case
# flash_attn_func returns an (out, softmax_lse) tuple. Mirror its own is_grad
# gate so the LSE is requested (and unpacked) exactly when grad is needed.
return_lse = torch.is_grad_enabled() and any(t.requires_grad for t in (q, k, v))
out = _aiter_splitkv_flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
return_lse=return_lse,
num_splits=0,
Comment on lines +2046 to +2063

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intercept path drops bottom_right_diagonal — the regular path threads it into FusedAttnFunc.apply (line 2163) and it affects causal semantics whenever seqlen_q != seqlen_kv. The eligibility check allows max_seqlen_kv > max_seqlen_q for attn_mask_type == "causal", but then calls aiter with just causal=True, which is typically top-left causal. If a caller relied on bottom-right semantics (e.g. cross-attention with attn_mask_type="causal" + bottom_right_diagonal=True), the intercepted forward will produce different outputs than the unintercepted path.

Two safe options: (a) gate eligibility on max_seqlen_q == max_seqlen_kv for causal so TL/BR coincide, or (b) also check bottom_right_diagonal in (None, False) and document the assumption. Either keeps the intercept semantically equivalent to the CK path.

)
if return_lse:
out = out[0]
if qkv_format == "sbhd":
out = out.transpose(0, 1)
# ...hd -> ...(hd), matching the FusedAttnFunc return convention below.
return out.reshape(*out.shape[:-2], -1)

use_FAv2_bwd = (
self.use_FAv2_bwd
and (core_attention_bias_type == "no_bias")
Expand Down
Loading