From ef88e16b02a2a835d0f742c2ddd54b34a922b09e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 15 Jun 2026 20:46:57 +0000 Subject: [PATCH 1/5] Wired split-kv through FlashAttention interface --- .../dot_product_attention/backends.py | 28 +++++++++++++++++++ .../attention/dot_product_attention/utils.py | 9 +++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 57b619a68..94012f88c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -91,6 +91,7 @@ _flash_attn_bwd = None _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None +aiter_flash_attn_func_splitkv = None # ROCm: AITER native split-K forward (aiter.ops.mha) if IS_HIP_EXTENSION and os.getenv("NVTE_FLASH_ATTN_AITER", "0") == "1": try: @@ -116,6 +117,19 @@ fa_utils.version = PkgVersion("2.7.1") #masqurade as FA 2.7.1 fa_utils.set_flash_attention_version() attn_log.fa_logger.info("Using AITER Triton for FlashAttn.") + # Optionally pull in AITER's native split-K forward (aiter.ops.mha). Kept + # separate from the Triton imports above so its absence never disables the + # Triton path; engaged only when num_splits != 1 (see FlashAttention.forward). + try: + from aiter.ops.mha import flash_attn_func as aiter_flash_attn_func_splitkv + except ImportError: + attn_log.fa_logger.warning( + "AITER native split-K forward (aiter.ops.mha) is unavailable;" + " num_splits will be ignored for the AITER path." + ) + else: + fa_utils.use_aiter_splitkv = True + attn_log.fa_logger.info("AITER native split-K forward is available.") try: if fa_utils.use_aiter_triton: raise PackageNotFoundError # skip version check for aiter triton @@ -1040,6 +1054,20 @@ def forward( 1 )[:batch_size] ) + if ( + fa_utils.use_aiter_splitkv + and num_splits != 1 + and inference_params is None + and not fp8 + and func is flash_attn_func + ): + # ROCm: route the dense forward to AITER's native split-K kernel. + # aiter.ops.mha.flash_attn_func self-gates (gfx942/D64/bf16) and + # otherwise falls back to the standard CK/ASM dispatch, so this is + # safe for any dense bf16 shape. Forward-only; backward unchanged. + # num_splits: 0 = AITER heuristic, >=2 = forced split count. + func = aiter_flash_attn_func_splitkv + fa_optional_forward_kwargs["num_splits"] = num_splits output = func( query_layer, key_layer, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 87992d294..175882d7d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -144,6 +144,7 @@ class FlashAttentionUtils: (5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py""" v3_warning_printed = False use_aiter_triton = False #ROCm + use_aiter_splitkv = False #ROCm: AITER native split-K forward (aiter.ops.mha) available @staticmethod def set_flash_attention_version(): @@ -530,7 +531,13 @@ def get_attention_backend( # Filter: num_splits if num_splits != 1: - if use_flash_attention_2 and FlashAttentionUtils.is_installed: + # ROCm: the AITER backend masquerades as FlashAttention 2 and routes num_splits + # through to its native split-K forward (aiter.ops.mha), so keep it enabled here. + if ( + use_flash_attention_2 + and FlashAttentionUtils.is_installed + and not FlashAttentionUtils.use_aiter_splitkv + ): logger.debug("Disabling FlashAttention 2 for num_splits") use_flash_attention_2 = False if use_fused_attention: From cb028454a2d3444799e4f36d6e9cfded2771f706 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 16 Jun 2026 16:00:32 +0000 Subject: [PATCH 2/5] Updated splitkv intercept to live in CK FA path --- .../dot_product_attention/backends.py | 144 ++++++++++++++---- .../attention/dot_product_attention/utils.py | 9 +- 2 files changed, 119 insertions(+), 34 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 94012f88c..bdc44d9a8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -8,6 +8,7 @@ from contextlib import nullcontext from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError +import inspect import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -91,7 +92,6 @@ _flash_attn_bwd = None _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None -aiter_flash_attn_func_splitkv = None # ROCm: AITER native split-K forward (aiter.ops.mha) if IS_HIP_EXTENSION and os.getenv("NVTE_FLASH_ATTN_AITER", "0") == "1": try: @@ -117,19 +117,87 @@ fa_utils.version = PkgVersion("2.7.1") #masqurade as FA 2.7.1 fa_utils.set_flash_attention_version() attn_log.fa_logger.info("Using AITER Triton for FlashAttn.") - # Optionally pull in AITER's native split-K forward (aiter.ops.mha). Kept - # separate from the Triton imports above so its absence never disables the - # Triton path; engaged only when num_splits != 1 (see FlashAttention.forward). - try: - from aiter.ops.mha import flash_attn_func as aiter_flash_attn_func_splitkv - except ImportError: + +# ROCm: AITER native split-K (Flash-Decoding) forward, opt-in via NVTE_FUSED_ATTN_SPLITKV. +# When enabled, eligible dense bf16 head-dim-64 calls that would run on the CK +# FusedAttention backend are transparently routed to aiter.ops.mha.flash_attn_func +# (the native split-K forward) instead -- see FusedAttention.forward. +_aiter_splitkv_flash_attn_func = None +_use_aiter_splitkv = False +if IS_HIP_EXTENSION and os.getenv("NVTE_FUSED_ATTN_SPLITKV", "0") == "1": + try: + from aiter.ops.mha import flash_attn_func as _aiter_splitkv_flash_attn_func + except ImportError: + attn_log.fa_logger.warning( + "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha is unavailable;" + " split-K interception disabled." + ) + else: + # The native split-K forward (AITER PR #3581) exposes a `num_splits` arg on + # flash_attn_func. Older AITER builds lack it, so verify before enabling to + # avoid a TypeError at call time. + if "num_splits" in inspect.signature(_aiter_splitkv_flash_attn_func).parameters: + _use_aiter_splitkv = True + attn_log.fa_logger.info("AITER native split-K forward enabled for FusedAttention.") + else: + _aiter_splitkv_flash_attn_func = None attn_log.fa_logger.warning( - "AITER native split-K forward (aiter.ops.mha) is unavailable;" - " num_splits will be ignored for the AITER path." + "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha.flash_attn_func has no" + " num_splits arg (AITER predates PR #3581); split-K interception disabled." ) - else: - fa_utils.use_aiter_splitkv = True - attn_log.fa_logger.info("AITER native split-K forward is available.") + + +def _aiter_splitkv_eligible( + q, + v, + qkv_format, + attn_mask_type, + core_attention_bias_type, + window_size, + dropout_p, + fp8, + context_parallel, + inference_params, + softmax_offset, + max_seqlen_q, + max_seqlen_kv, +): + """Whether a FusedAttention call can be served by AITER's native split-K forward. + + Conservatively mirrors aiter.ops.mha's can_impl_fmha_native gate (gfx942, dense + bf16, head_dim 64, no bias/alibi/sliding-window/dropout/sink/fp8/varlen/context- + parallel/kvcache). aiter additionally self-gates and falls back to CK internally, + so this is a pre-filter to avoid diverting calls the native kernel cannot serve. + """ + if not _use_aiter_splitkv: + return False + if get_device_compute_capability() != (9, 4): # gfx942 only + return False + if q.dtype != torch.bfloat16 or v.dtype != torch.bfloat16: + return False + if q.shape[-1] != 64 or v.shape[-1] != 64: + return False + if qkv_format not in ("bshd", "sbhd"): # dense only; thd/varlen unsupported + return False + if "padding" in attn_mask_type: + return False + if attn_mask_type not in ("no_mask", "causal"): + return False + if core_attention_bias_type != "no_bias": # excludes both bias and alibi + return False + if window_size is not None and tuple(window_size) != (-1, -1): # no sliding window + return False + if dropout_p != 0.0: + return False + if fp8 or context_parallel or inference_params is not None: + return False + if softmax_offset is not None: # no learnable sink + return False + if "causal" in attn_mask_type and max_seqlen_kv < max_seqlen_q: + return False + return True + + try: if fa_utils.use_aiter_triton: raise PackageNotFoundError # skip version check for aiter triton @@ -1054,20 +1122,6 @@ def forward( 1 )[:batch_size] ) - if ( - fa_utils.use_aiter_splitkv - and num_splits != 1 - and inference_params is None - and not fp8 - and func is flash_attn_func - ): - # ROCm: route the dense forward to AITER's native split-K kernel. - # aiter.ops.mha.flash_attn_func self-gates (gfx942/D64/bf16) and - # otherwise falls back to the standard CK/ASM dispatch, so this is - # safe for any dense bf16 shape. Forward-only; backward unchanged. - # num_splits: 0 = AITER heuristic, >=2 = forced split count. - func = aiter_flash_attn_func_splitkv - fa_optional_forward_kwargs["num_splits"] = num_splits output = func( query_layer, key_layer, @@ -1969,6 +2023,44 @@ def forward( if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None: cu_seqlens_kv_padded = cu_seqlens_kv + # ROCm: opt-in interception (NVTE_FUSED_ATTN_SPLITKV=1). For eligible dense bf16 + # head-dim-64 calls, route the forward to AITER's native split-K kernel instead + # of the CK FusedAttention path. aiter.ops.mha.flash_attn_func is its own + # autograd Function (handles its own backward) and self-gates / falls back to CK + # internally, so non-eligible shapes are unaffected. num_splits=0 => AITER picks + # the split count heuristically. + if _aiter_splitkv_eligible( + query_layer, + value_layer, + qkv_format, + attn_mask_type, + core_attention_bias_type, + window_size, + self.attention_dropout if self.training else 0.0, + fp8, + context_parallel, + inference_params, + softmax_offset, + max_seqlen_q, + max_seqlen_kv, + ): + q, k, v = query_layer, key_layer, value_layer + if qkv_format == "sbhd": + q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v)) + out = _aiter_splitkv_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + num_splits=0, + ) + if qkv_format == "sbhd": + out = out.transpose(0, 1) + # ...hd -> ...(hd), matching the FusedAttnFunc return convention below. + return out.reshape(*out.shape[:-2], -1) + use_FAv2_bwd = ( self.use_FAv2_bwd and (core_attention_bias_type == "no_bias") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 175882d7d..87992d294 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -144,7 +144,6 @@ class FlashAttentionUtils: (5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py""" v3_warning_printed = False use_aiter_triton = False #ROCm - use_aiter_splitkv = False #ROCm: AITER native split-K forward (aiter.ops.mha) available @staticmethod def set_flash_attention_version(): @@ -531,13 +530,7 @@ def get_attention_backend( # Filter: num_splits if num_splits != 1: - # ROCm: the AITER backend masquerades as FlashAttention 2 and routes num_splits - # through to its native split-K forward (aiter.ops.mha), so keep it enabled here. - if ( - use_flash_attention_2 - and FlashAttentionUtils.is_installed - and not FlashAttentionUtils.use_aiter_splitkv - ): + if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for num_splits") use_flash_attention_2 = False if use_fused_attention: From b4f8fcf6e103e669519749637d98b53f30dd71b3 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 16 Jun 2026 19:45:39 +0000 Subject: [PATCH 3/5] Added readme segment --- README.rst | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/README.rst b/README.rst index 62e7d0738..fe74e0e30 100644 --- a/README.rst +++ b/README.rst @@ -259,6 +259,34 @@ ROCm TE provides the compile-time env NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAU * 3 - standard asm, default; * 4 - rta_asm. +AITER Native Split-K Forward (gfx942 only) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +On gfx942, the CK fused attention path can optionally dispatch the *forward* pass to AITER's +hand-written native split-K (Flash-Decoding) kernel, which splits the work along the key/value +sequence dimension to keep the GPU busy when a single attention problem does not. This is +controlled by a runtime environment variable: + +* NVTE_FUSED_ATTN_SPLITKV - by default 0 (disabled). When set to 1, eligible CK FusedAttention + forward calls are routed to AITER's native split-K kernel, which picks the number of splits with + its built-in occupancy heuristic. + +When to use it: + +* The benefit comes from *under-subscribed, long-KV* shapes - small ``batch x num_heads`` with a + large ``seqlen_kv`` (e.g. long-context prefill or decode) - where the standard kernel leaves + compute units idle. Splitting the KV dimension across more workgroups fills the machine. +* For already-saturated shapes (large ``batch x num_heads``) there is little to gain; AITER's + heuristic typically declines to split, so leaving the flag on is low-risk but offers no benefit + there. +* It is forward-only and affects only the ``FusedAttention`` / ``DotProductAttention`` module path. + Training backward and the unfused path are unchanged. + +The divert engages only when a call is eligible for the native kernel: gfx942, dense ``bshd`` / +``sbhd`` layout (no ``thd``/varlen), bf16, head dim 64, no bias/ALiBi/sliding-window/dropout/ +attention-sink/FP8/context-parallel/KV-cache, and, for causal masking, ``seqlen_kv >= seqlen_q``. +Non-eligible calls fall back to the standard CK kernel unchanged. This requires an AITER build that +includes the native split-K kernel; if it is unavailable the flag is ignored with a warning. + Experimental Triton Kernels on ROCm ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Most CUDA kernels in Transformer Engine are hipified to run on ROCm. While the hipifiled CUDA kernels are functional, they are not necessarily optimal on ROCm. From 060aaa96a85fa1a2aa0687aae4c2d91e80d47b6d Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 18 Jun 2026 18:36:23 +0000 Subject: [PATCH 4/5] Added exclusive test --- tests/pytorch/attention/test_attention.py | 59 ++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index d1bf0d6c4..e4063472c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -111,7 +111,8 @@ def reset_global_fp8_state(): def reset_attn_backend(): env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", - "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"]) + "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD", + "NVTE_FUSED_ATTN_SPLITKV"]) yield # Define F16 data types to test @@ -480,6 +481,62 @@ def test_dpa_num_splits(dtype, model_configs, model): ) +# ROCm: configs matching the exact gate that AITER's native split-K (Flash-Decoding) +# forward was wired in for -- dense bf16, head_dim 64, no_bias, no sliding window, no +# dropout, no_mask/causal. See _aiter_splitkv_eligible in backends.py. +model_configs_splitkv = { + # test: ModelConfig(b, sq, hq, dqk) + "splitkv_1_0": ModelConfig(2, 2048, 16, 64, max_seqlen_kv=4096), # no_mask, bshd + "splitkv_1_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), # causal, sbhd +} + + +@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER split-K is a ROCm-only path.") +@pytest.mark.parametrize("dtype", param_types_lean) # split-K forward is bf16 only +@pytest.mark.parametrize("model_configs", [model_configs_splitkv]) +@pytest.mark.parametrize("model", model_configs_splitkv.keys()) +def test_dpa_splitkv(dtype, model_configs, model): + """Test DotProductAttention routed through AITER's native split-K forward. + + Enables NVTE_FUSED_ATTN_SPLITKV and exercises the dense bf16 head-dim-64 + no_mask/causal config the kernel was added for; the split-K result (fwd+bwd) + is validated against the unfused and CK backends. Interception only fires on + gfx942 -- elsewhere FusedAttention falls back to CK and the config still runs. + """ + import inspect + from transformer_engine.pytorch.attention.dot_product_attention import backends + + try: + from aiter.ops.mha import flash_attn_func as splitkv_func + except ImportError: + pytest.skip("aiter.ops.mha split-K forward is unavailable.") + if "num_splits" not in inspect.signature(splitkv_func).parameters: + pytest.skip("aiter split-K forward predates num_splits (AITER PR #3581).") + + # NVTE_FUSED_ATTN_SPLITKV is consumed at import time, so enable the interception + # directly for the duration of this test and restore the module state afterwards. + saved_func = backends._aiter_splitkv_flash_attn_func + saved_use = backends._use_aiter_splitkv + os.environ["NVTE_FUSED_ATTN_SPLITKV"] = "1" + backends._aiter_splitkv_flash_attn_func = splitkv_func + backends._use_aiter_splitkv = True + try: + test_dot_product_attention( + dtype, + model_configs, + model, + False, + False, + None, + False, + False, + ) + finally: + backends._aiter_splitkv_flash_attn_func = saved_func + backends._use_aiter_splitkv = saved_use + os.environ.pop("NVTE_FUSED_ATTN_SPLITKV", None) + + model_configs_softmax = { # test: ModelConfig(b, sq, hq, dqk) "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), From 0ce5c33a771365202cb5745b3fa6af170a5b2d12 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 18 Jun 2026 18:49:13 +0000 Subject: [PATCH 5/5] Added return_lse fix --- .../pytorch/attention/dot_product_attention/backends.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index bdc44d9a8..acf08bfc1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -2047,6 +2047,11 @@ def forward( q, k, v = query_layer, key_layer, value_layer if qkv_format == "sbhd": q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v)) + # AITER's autograd Function asserts return_lse=True whenever it will build a + # backward graph -- it saves the LSE for the backward pass -- and in that case + # flash_attn_func returns an (out, softmax_lse) tuple. Mirror its own is_grad + # gate so the LSE is requested (and unpacked) exactly when grad is needed. + return_lse = torch.is_grad_enabled() and any(t.requires_grad for t in (q, k, v)) out = _aiter_splitkv_flash_attn_func( q, k, @@ -2054,8 +2059,11 @@ def forward( dropout_p=0.0, softmax_scale=self.softmax_scale, causal="causal" in attn_mask_type, + return_lse=return_lse, num_splits=0, ) + if return_lse: + out = out[0] if qkv_format == "sbhd": out = out.transpose(0, 1) # ...hd -> ...(hd), matching the FusedAttnFunc return convention below.