diff --git a/README.rst b/README.rst index 62e7d0738..fe74e0e30 100644 --- a/README.rst +++ b/README.rst @@ -259,6 +259,34 @@ ROCm TE provides the compile-time env NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAU * 3 - standard asm, default; * 4 - rta_asm. +AITER Native Split-K Forward (gfx942 only) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +On gfx942, the CK fused attention path can optionally dispatch the *forward* pass to AITER's +hand-written native split-K (Flash-Decoding) kernel, which splits the work along the key/value +sequence dimension to keep the GPU busy when a single attention problem does not. This is +controlled by a runtime environment variable: + +* NVTE_FUSED_ATTN_SPLITKV - by default 0 (disabled). When set to 1, eligible CK FusedAttention + forward calls are routed to AITER's native split-K kernel, which picks the number of splits with + its built-in occupancy heuristic. + +When to use it: + +* The benefit comes from *under-subscribed, long-KV* shapes - small ``batch x num_heads`` with a + large ``seqlen_kv`` (e.g. long-context prefill or decode) - where the standard kernel leaves + compute units idle. Splitting the KV dimension across more workgroups fills the machine. +* For already-saturated shapes (large ``batch x num_heads``) there is little to gain; AITER's + heuristic typically declines to split, so leaving the flag on is low-risk but offers no benefit + there. +* It is forward-only and affects only the ``FusedAttention`` / ``DotProductAttention`` module path. + Training backward and the unfused path are unchanged. + +The divert engages only when a call is eligible for the native kernel: gfx942, dense ``bshd`` / +``sbhd`` layout (no ``thd``/varlen), bf16, head dim 64, no bias/ALiBi/sliding-window/dropout/ +attention-sink/FP8/context-parallel/KV-cache, and, for causal masking, ``seqlen_kv >= seqlen_q``. +Non-eligible calls fall back to the standard CK kernel unchanged. This requires an AITER build that +includes the native split-K kernel; if it is unavailable the flag is ignored with a warning. + Experimental Triton Kernels on ROCm ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Most CUDA kernels in Transformer Engine are hipified to run on ROCm. While the hipifiled CUDA kernels are functional, they are not necessarily optimal on ROCm. diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index d1bf0d6c4..e4063472c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -111,7 +111,8 @@ def reset_global_fp8_state(): def reset_attn_backend(): env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", - "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"]) + "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD", + "NVTE_FUSED_ATTN_SPLITKV"]) yield # Define F16 data types to test @@ -480,6 +481,62 @@ def test_dpa_num_splits(dtype, model_configs, model): ) +# ROCm: configs matching the exact gate that AITER's native split-K (Flash-Decoding) +# forward was wired in for -- dense bf16, head_dim 64, no_bias, no sliding window, no +# dropout, no_mask/causal. See _aiter_splitkv_eligible in backends.py. +model_configs_splitkv = { + # test: ModelConfig(b, sq, hq, dqk) + "splitkv_1_0": ModelConfig(2, 2048, 16, 64, max_seqlen_kv=4096), # no_mask, bshd + "splitkv_1_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), # causal, sbhd +} + + +@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER split-K is a ROCm-only path.") +@pytest.mark.parametrize("dtype", param_types_lean) # split-K forward is bf16 only +@pytest.mark.parametrize("model_configs", [model_configs_splitkv]) +@pytest.mark.parametrize("model", model_configs_splitkv.keys()) +def test_dpa_splitkv(dtype, model_configs, model): + """Test DotProductAttention routed through AITER's native split-K forward. + + Enables NVTE_FUSED_ATTN_SPLITKV and exercises the dense bf16 head-dim-64 + no_mask/causal config the kernel was added for; the split-K result (fwd+bwd) + is validated against the unfused and CK backends. Interception only fires on + gfx942 -- elsewhere FusedAttention falls back to CK and the config still runs. + """ + import inspect + from transformer_engine.pytorch.attention.dot_product_attention import backends + + try: + from aiter.ops.mha import flash_attn_func as splitkv_func + except ImportError: + pytest.skip("aiter.ops.mha split-K forward is unavailable.") + if "num_splits" not in inspect.signature(splitkv_func).parameters: + pytest.skip("aiter split-K forward predates num_splits (AITER PR #3581).") + + # NVTE_FUSED_ATTN_SPLITKV is consumed at import time, so enable the interception + # directly for the duration of this test and restore the module state afterwards. + saved_func = backends._aiter_splitkv_flash_attn_func + saved_use = backends._use_aiter_splitkv + os.environ["NVTE_FUSED_ATTN_SPLITKV"] = "1" + backends._aiter_splitkv_flash_attn_func = splitkv_func + backends._use_aiter_splitkv = True + try: + test_dot_product_attention( + dtype, + model_configs, + model, + False, + False, + None, + False, + False, + ) + finally: + backends._aiter_splitkv_flash_attn_func = saved_func + backends._use_aiter_splitkv = saved_use + os.environ.pop("NVTE_FUSED_ATTN_SPLITKV", None) + + model_configs_softmax = { # test: ModelConfig(b, sq, hq, dqk) "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 57b619a68..acf08bfc1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -8,6 +8,7 @@ from contextlib import nullcontext from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError +import inspect import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -116,6 +117,87 @@ fa_utils.version = PkgVersion("2.7.1") #masqurade as FA 2.7.1 fa_utils.set_flash_attention_version() attn_log.fa_logger.info("Using AITER Triton for FlashAttn.") + +# ROCm: AITER native split-K (Flash-Decoding) forward, opt-in via NVTE_FUSED_ATTN_SPLITKV. +# When enabled, eligible dense bf16 head-dim-64 calls that would run on the CK +# FusedAttention backend are transparently routed to aiter.ops.mha.flash_attn_func +# (the native split-K forward) instead -- see FusedAttention.forward. +_aiter_splitkv_flash_attn_func = None +_use_aiter_splitkv = False +if IS_HIP_EXTENSION and os.getenv("NVTE_FUSED_ATTN_SPLITKV", "0") == "1": + try: + from aiter.ops.mha import flash_attn_func as _aiter_splitkv_flash_attn_func + except ImportError: + attn_log.fa_logger.warning( + "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha is unavailable;" + " split-K interception disabled." + ) + else: + # The native split-K forward (AITER PR #3581) exposes a `num_splits` arg on + # flash_attn_func. Older AITER builds lack it, so verify before enabling to + # avoid a TypeError at call time. + if "num_splits" in inspect.signature(_aiter_splitkv_flash_attn_func).parameters: + _use_aiter_splitkv = True + attn_log.fa_logger.info("AITER native split-K forward enabled for FusedAttention.") + else: + _aiter_splitkv_flash_attn_func = None + attn_log.fa_logger.warning( + "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha.flash_attn_func has no" + " num_splits arg (AITER predates PR #3581); split-K interception disabled." + ) + + +def _aiter_splitkv_eligible( + q, + v, + qkv_format, + attn_mask_type, + core_attention_bias_type, + window_size, + dropout_p, + fp8, + context_parallel, + inference_params, + softmax_offset, + max_seqlen_q, + max_seqlen_kv, +): + """Whether a FusedAttention call can be served by AITER's native split-K forward. + + Conservatively mirrors aiter.ops.mha's can_impl_fmha_native gate (gfx942, dense + bf16, head_dim 64, no bias/alibi/sliding-window/dropout/sink/fp8/varlen/context- + parallel/kvcache). aiter additionally self-gates and falls back to CK internally, + so this is a pre-filter to avoid diverting calls the native kernel cannot serve. + """ + if not _use_aiter_splitkv: + return False + if get_device_compute_capability() != (9, 4): # gfx942 only + return False + if q.dtype != torch.bfloat16 or v.dtype != torch.bfloat16: + return False + if q.shape[-1] != 64 or v.shape[-1] != 64: + return False + if qkv_format not in ("bshd", "sbhd"): # dense only; thd/varlen unsupported + return False + if "padding" in attn_mask_type: + return False + if attn_mask_type not in ("no_mask", "causal"): + return False + if core_attention_bias_type != "no_bias": # excludes both bias and alibi + return False + if window_size is not None and tuple(window_size) != (-1, -1): # no sliding window + return False + if dropout_p != 0.0: + return False + if fp8 or context_parallel or inference_params is not None: + return False + if softmax_offset is not None: # no learnable sink + return False + if "causal" in attn_mask_type and max_seqlen_kv < max_seqlen_q: + return False + return True + + try: if fa_utils.use_aiter_triton: raise PackageNotFoundError # skip version check for aiter triton @@ -1941,6 +2023,52 @@ def forward( if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None: cu_seqlens_kv_padded = cu_seqlens_kv + # ROCm: opt-in interception (NVTE_FUSED_ATTN_SPLITKV=1). For eligible dense bf16 + # head-dim-64 calls, route the forward to AITER's native split-K kernel instead + # of the CK FusedAttention path. aiter.ops.mha.flash_attn_func is its own + # autograd Function (handles its own backward) and self-gates / falls back to CK + # internally, so non-eligible shapes are unaffected. num_splits=0 => AITER picks + # the split count heuristically. + if _aiter_splitkv_eligible( + query_layer, + value_layer, + qkv_format, + attn_mask_type, + core_attention_bias_type, + window_size, + self.attention_dropout if self.training else 0.0, + fp8, + context_parallel, + inference_params, + softmax_offset, + max_seqlen_q, + max_seqlen_kv, + ): + q, k, v = query_layer, key_layer, value_layer + if qkv_format == "sbhd": + q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v)) + # AITER's autograd Function asserts return_lse=True whenever it will build a + # backward graph -- it saves the LSE for the backward pass -- and in that case + # flash_attn_func returns an (out, softmax_lse) tuple. Mirror its own is_grad + # gate so the LSE is requested (and unpacked) exactly when grad is needed. + return_lse = torch.is_grad_enabled() and any(t.requires_grad for t in (q, k, v)) + out = _aiter_splitkv_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + return_lse=return_lse, + num_splits=0, + ) + if return_lse: + out = out[0] + if qkv_format == "sbhd": + out = out.transpose(0, 1) + # ...hd -> ...(hd), matching the FusedAttnFunc return convention below. + return out.reshape(*out.shape[:-2], -1) + use_FAv2_bwd = ( self.use_FAv2_bwd and (core_attention_bias_type == "no_bias")