-
Notifications
You must be signed in to change notification settings - Fork 32
Added support for AITER JIT native splitkv kernel #631
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
ef88e16
cb02845
b4f8fcf
060aaa9
0ce5c33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| from contextlib import nullcontext | ||
| from importlib.metadata import version as get_pkg_version | ||
| from importlib.metadata import PackageNotFoundError | ||
| import inspect | ||
| import os | ||
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||
| import warnings | ||
|
|
@@ -116,6 +117,87 @@ | |
| fa_utils.version = PkgVersion("2.7.1") #masqurade as FA 2.7.1 | ||
| fa_utils.set_flash_attention_version() | ||
| attn_log.fa_logger.info("Using AITER Triton for FlashAttn.") | ||
|
|
||
| # ROCm: AITER native split-K (Flash-Decoding) forward, opt-in via NVTE_FUSED_ATTN_SPLITKV. | ||
| # When enabled, eligible dense bf16 head-dim-64 calls that would run on the CK | ||
| # FusedAttention backend are transparently routed to aiter.ops.mha.flash_attn_func | ||
| # (the native split-K forward) instead -- see FusedAttention.forward. | ||
| _aiter_splitkv_flash_attn_func = None | ||
| _use_aiter_splitkv = False | ||
| if IS_HIP_EXTENSION and os.getenv("NVTE_FUSED_ATTN_SPLITKV", "0") == "1": | ||
| try: | ||
| from aiter.ops.mha import flash_attn_func as _aiter_splitkv_flash_attn_func | ||
| except ImportError: | ||
| attn_log.fa_logger.warning( | ||
| "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha is unavailable;" | ||
| " split-K interception disabled." | ||
| ) | ||
| else: | ||
| # The native split-K forward (AITER PR #3581) exposes a `num_splits` arg on | ||
| # flash_attn_func. Older AITER builds lack it, so verify before enabling to | ||
| # avoid a TypeError at call time. | ||
| if "num_splits" in inspect.signature(_aiter_splitkv_flash_attn_func).parameters: | ||
| _use_aiter_splitkv = True | ||
| attn_log.fa_logger.info("AITER native split-K forward enabled for FusedAttention.") | ||
| else: | ||
| _aiter_splitkv_flash_attn_func = None | ||
| attn_log.fa_logger.warning( | ||
| "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha.flash_attn_func has no" | ||
| " num_splits arg (AITER predates PR #3581); split-K interception disabled." | ||
| ) | ||
|
|
||
|
|
||
| def _aiter_splitkv_eligible( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which backward backends does it work with w/o issues? |
||
| q, | ||
| v, | ||
| qkv_format, | ||
| attn_mask_type, | ||
| core_attention_bias_type, | ||
| window_size, | ||
| dropout_p, | ||
| fp8, | ||
| context_parallel, | ||
| inference_params, | ||
| softmax_offset, | ||
| max_seqlen_q, | ||
| max_seqlen_kv, | ||
| ): | ||
| """Whether a FusedAttention call can be served by AITER's native split-K forward. | ||
|
|
||
| Conservatively mirrors aiter.ops.mha's can_impl_fmha_native gate (gfx942, dense | ||
| bf16, head_dim 64, no bias/alibi/sliding-window/dropout/sink/fp8/varlen/context- | ||
| parallel/kvcache). aiter additionally self-gates and falls back to CK internally, | ||
| so this is a pre-filter to avoid diverting calls the native kernel cannot serve. | ||
| """ | ||
| if not _use_aiter_splitkv: | ||
| return False | ||
| if get_device_compute_capability() != (9, 4): # gfx942 only | ||
| return False | ||
| if q.dtype != torch.bfloat16 or v.dtype != torch.bfloat16: | ||
| return False | ||
| if q.shape[-1] != 64 or v.shape[-1] != 64: | ||
| return False | ||
| if qkv_format not in ("bshd", "sbhd"): # dense only; thd/varlen unsupported | ||
| return False | ||
| if "padding" in attn_mask_type: | ||
| return False | ||
| if attn_mask_type not in ("no_mask", "causal"): | ||
| return False | ||
| if core_attention_bias_type != "no_bias": # excludes both bias and alibi | ||
| return False | ||
| if window_size is not None and tuple(window_size) != (-1, -1): # no sliding window | ||
| return False | ||
| if dropout_p != 0.0: | ||
| return False | ||
| if fp8 or context_parallel or inference_params is not None: | ||
| return False | ||
| if softmax_offset is not None: # no learnable sink | ||
| return False | ||
| if "causal" in attn_mask_type and max_seqlen_kv < max_seqlen_q: | ||
| return False | ||
| return True | ||
|
Comment on lines
+172
to
+198
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Eligibility doesn't gate on Same comment applies to dtype/head-dim: the signature takes |
||
|
|
||
|
|
||
| try: | ||
| if fa_utils.use_aiter_triton: | ||
| raise PackageNotFoundError # skip version check for aiter triton | ||
|
|
@@ -1941,6 +2023,52 @@ def forward( | |
| if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None: | ||
| cu_seqlens_kv_padded = cu_seqlens_kv | ||
|
|
||
| # ROCm: opt-in interception (NVTE_FUSED_ATTN_SPLITKV=1). For eligible dense bf16 | ||
| # head-dim-64 calls, route the forward to AITER's native split-K kernel instead | ||
| # of the CK FusedAttention path. aiter.ops.mha.flash_attn_func is its own | ||
| # autograd Function (handles its own backward) and self-gates / falls back to CK | ||
| # internally, so non-eligible shapes are unaffected. num_splits=0 => AITER picks | ||
| # the split count heuristically. | ||
| if _aiter_splitkv_eligible( | ||
| query_layer, | ||
| value_layer, | ||
| qkv_format, | ||
| attn_mask_type, | ||
| core_attention_bias_type, | ||
| window_size, | ||
| self.attention_dropout if self.training else 0.0, | ||
| fp8, | ||
| context_parallel, | ||
| inference_params, | ||
| softmax_offset, | ||
| max_seqlen_q, | ||
| max_seqlen_kv, | ||
| ): | ||
|
Comment on lines
+2032
to
+2046
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor: the interception sits after lines 1949–2024 which build |
||
| q, k, v = query_layer, key_layer, value_layer | ||
| if qkv_format == "sbhd": | ||
| q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For
|
||
| # AITER's autograd Function asserts return_lse=True whenever it will build a | ||
| # backward graph -- it saves the LSE for the backward pass -- and in that case | ||
| # flash_attn_func returns an (out, softmax_lse) tuple. Mirror its own is_grad | ||
| # gate so the LSE is requested (and unpacked) exactly when grad is needed. | ||
| return_lse = torch.is_grad_enabled() and any(t.requires_grad for t in (q, k, v)) | ||
| out = _aiter_splitkv_flash_attn_func( | ||
| q, | ||
| k, | ||
| v, | ||
| dropout_p=0.0, | ||
| softmax_scale=self.softmax_scale, | ||
| causal="causal" in attn_mask_type, | ||
| return_lse=return_lse, | ||
| num_splits=0, | ||
|
Comment on lines
+2046
to
+2063
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The intercept path drops Two safe options: (a) gate eligibility on |
||
| ) | ||
| if return_lse: | ||
| out = out[0] | ||
| if qkv_format == "sbhd": | ||
| out = out.transpose(0, 1) | ||
| # ...hd -> ...(hd), matching the FusedAttnFunc return convention below. | ||
| return out.reshape(*out.shape[:-2], -1) | ||
|
|
||
| use_FAv2_bwd = ( | ||
| self.use_FAv2_bwd | ||
| and (core_attention_bias_type == "no_bias") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it be NVTE_FLASH_ATTN_AITER_SPLITKV NVTE_FLASH_ATTN_SPLITKV to match NVTE_FLASH_ATTN_AITER?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this was an intercept on the FA dispatch, I figured we could keep it opaque to the user hence keep the FA style env variable -- if you think it should be transparent instead I'd be happy to change it to
NVTE_FLASH_ATTN_AITER_SPLITKV