[gfx1151] [flash_attn_triton_amd]: Tune for gfx1151 / ViT #3419
For the packed `[seq, heads, dim]` layout used by varlen prefill, the
head-axis stride equals `head_dim`. When `head_dim` is a multiple of 8
but not 16 (e.g. 72), Triton's integer-arg auto-specialization does not
attach `tt.divisibility = 8` to `stride_*h` (its threshold is 16), so
AxisInfo treats the K/V global load as 2-byte aligned and Coalesce
emits scalar `buffer_load_u16` instead of vectorized `buffer_load_b128`.
Add a `HEAD_STRIDE_ALIGNED_8` constexpr to `attn_fwd` and apply
`tl.multiple_of(off_h_{q,k} * stride_{q,k,v}h, 8)` to the head-axis
integer offset when the caller sets it. AddPtr propagates this through
to the load pointer, so AxisInfo computes a 16-byte alignment and the
load coalesces.
The wrapper checks `stride_*h % 8 == 0` against the actual runtime
strides (not against `head_dim`), so the hint stays sound for
non-contiguous Q/K/V views where `stride_*h != head_dim`. The constexpr
defaults to `False`, so external callers of `attn_fwd` are unaffected
unless they opt in.
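A minimal sketch of the wrapper-side check described above, assuming strides are expressed in elements. The helper name `head_stride_aligned_8` and the example strides are hypothetical and are not taken from the PR's code.

```python
def head_stride_aligned_8(stride_qh: int, stride_kh: int, stride_vh: int) -> bool:
    """Return True only if every head-axis stride (in elements) is a multiple of 8."""
    return all(s % 8 == 0 for s in (stride_qh, stride_kh, stride_vh))


# Packed [seq, heads, dim] layout: the head-axis stride equals head_dim.
head_dim = 72
contiguous = head_stride_aligned_8(head_dim, head_dim, head_dim)

# 72 is divisible by 8 but not by 16, so the 16-based auto-specialization
# does not apply and the explicit hint is what carries the information.
needs_explicit_hint = head_dim % 8 == 0 and head_dim % 16 != 0

# Hypothetical non-contiguous view whose Q head stride is not a multiple of 8:
# the check fails, so the hint must stay off.
strided_view = head_stride_aligned_8(76, 72, 72)
```

Checking the runtime strides rather than `head_dim` is what keeps the result correct for the third case, where the two differ.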
On gfx1151 with `head_dim=72, fp16`, the Qwen3-Omni ViT prefill shape
goes from 128 `buffer_load_u16` to 16 `buffer_load_b128` and median
kernel time drops from 3.04 ms to 2.71 ms (-10.9%). Correctness matches
torch SDPA to 1e-4 max abs error.
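The instruction counts and timings above can be cross-checked with simple arithmetic. This sketch assumes only what the mnemonics imply: a `buffer_load_u16` moves 2 bytes and a `buffer_load_b128` moves 16 bytes.

```python
BYTES_U16 = 2    # buffer_load_u16 moves one 16-bit element
BYTES_B128 = 16  # buffer_load_b128 moves 128 bits

scalar_loads = 128  # count reported before the hint
vector_loads = 16   # count reported after the hint

# Both instruction sequences should cover the same number of bytes.
bytes_scalar = scalar_loads * BYTES_U16
bytes_vector = vector_loads * BYTES_B128

# fp16 elements covered by one 128-bit load.
elements_per_vector_load = BYTES_B128 // BYTES_U16


def percent_change(before_ms: float, after_ms: float) -> float:
    """Relative change in percent; negative means faster."""
    return (after_ms - before_ms) / before_ms * 100.0


change = percent_change(3.04, 2.71)
```

Both sequences cover 256 bytes, so the 8x reduction in load instructions matches the 8 fp16 elements that fit in one 128-bit load.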
The default RDNA config for `attn_fwd` uses BLOCK_N=32 + waves_per_eu=6 to target 6-wave occupancy. On gfx1151 the kernel allocates 32 KB shared per workgroup (Q tile staged through LDS), which already caps occupancy at 4 WGs/WGP (= 2 waves/EU) regardless of the register budget. The `waves_per_eu=6` hint then makes the register allocator spill / shorten live ranges to fit waves that will never schedule, which degrades codegen: fewer s_delay_alu hints, shorter clauses, more scalar f32 multiplies in the softmax path.

Drop the hint to `waves_per_eu=2` (matching CK's high-VGPR low-occupancy choice for the same kernel on the same arch) and pair it with BLOCK_N=64 so each workgroup amortises loop-dispatch overhead over 2x the WMMA work per K-tile. The inner-loop mnemonic profile then collapses onto CK's: s_clause matches at 0.03/WMMA, s_waitcnt 1.19 vs 1.31, total instr/WMMA drops from 12.2 to 11.8.

On AITER's flash_attn_2.varlen_fwd at the Qwen3-Omni ViT prefill shape (B=1, S=3200, H=16, head_dim=72, fp16) on gfx1151, median kernel time goes from 2.47 ms to 2.35 ms (-4.9%), now 3% faster than the CK ck_tile FmhaFwdKernel reference at 2.42 ms. Min times are 2.28 vs CK 2.34 ms (-2.6%). Stacked on top of the previous head-stride divisibility-by-8 hint commit, total speedup vs the upstream baseline is 3.04 ms -> 2.35 ms (-22.7%).

Scoped to gfx1151 only; other RDNA targets keep the existing default pending validation.
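The arch-scoped selection described above could look roughly like the following. The `BLOCK_N` and `waves_per_eu` values come from this commit message; the dict layout and the function name `pick_rdna_prefill_config` are hypothetical and do not reflect how the PR actually structures its configs.

```python
# Values from the commit message; names and layout are illustrative only.
DEFAULT_RDNA_PREFILL = {"BLOCK_N": 32, "waves_per_eu": 6}
GFX1151_PREFILL = {"BLOCK_N": 64, "waves_per_eu": 2}


def pick_rdna_prefill_config(arch: str) -> dict:
    """Return the gfx1151 tuning for gfx1151 and the existing default otherwise."""
    if arch == "gfx1151":
        return dict(GFX1151_PREFILL)
    # Other RDNA targets keep the existing default pending validation.
    return dict(DEFAULT_RDNA_PREFILL)


gfx1151_cfg = pick_rdna_prefill_config("gfx1151")
gfx1100_cfg = pick_rdna_prefill_config("gfx1100")
```

Returning a copy keeps callers from mutating the shared defaults by accident.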
Pull request overview
This PR tunes the AMD Triton FlashAttention forward-prefill kernel for RDNA3.5 (gfx1151) and adds a compiler-alignment hint to improve vectorized global loads when head-axis strides are suitably aligned.
Changes:
- Add a gfx1151-specific RDNA prefill config (`BLOCK_N=64`, `waves_per_eu=2`) tuned for the Qwen3-Omni ViT prefill shape.
- Introduce `HEAD_STRIDE_ALIGNED_8` constexpr plumbing and a runtime stride check to safely apply `tl.multiple_of(..., 8)` on Q/K/V head-axis offsets.
The in-kernel comment claimed the head-axis byte offset is 16-byte aligned whenever stride_*h % 8 == 0. That only holds for 16-bit element types. The hint guarantees an 8-element multiple, whose byte alignment is element-size dependent (16 B for fp16/bf16, 8 B for fp8, 32 B for fp32); only the 16-bit case yields the buffer_load_b128 widening this PR targets. The tl.multiple_of(..., 8) hint stays sound for all dtypes. Comment-only change, no behavioral impact.
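The element-size dependence is easy to tabulate. A small sketch, assuming the usual element sizes for these dtypes (the table and function name are illustrative):

```python
# Element sizes in bytes for the dtypes mentioned above.
ELEMENT_BYTES = {"fp8": 1, "fp16": 2, "bf16": 2, "fp32": 4}

HINT_ELEMENTS = 8  # tl.multiple_of(..., 8) guarantees a multiple of 8 elements


def guaranteed_byte_alignment(dtype: str) -> int:
    """Byte alignment implied by an 8-element multiple for the given dtype."""
    return HINT_ELEMENTS * ELEMENT_BYTES[dtype]


alignments = {dtype: guaranteed_byte_alignment(dtype) for dtype in ELEMENT_BYTES}
```

Only the 2-byte types land on exactly 16 bytes, which is the case the 128-bit load widening relies on.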
brunomazzottiamd left a comment
Code looks good to me! Head-stride div-by-8 hint seems to be fine too, given the benchmarking performed in the scope of #3424.
However, we have a CI failure in Flash Attention - Triton / RDNA3 (1 GPU) job:
[aiter] import [module_aiter_core] under /aiter/aiter/jit/module_aiter_core.so
Traceback (most recent call last):
File "/flash-attention/benchmarks/benchmark_flash_attention.py", line 11, in <module>
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
File "/opt/venv/lib/python3.12/site-packages/flash_attn/__init__.py", line 8, in <module>
from flash_attn.flash_attn_interface import (
File "/opt/venv/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 21, in <module>
from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu
File "/aiter/aiter/__init__.py", line 92, in <module>
from .ops.gemm_op_a8w8 import * # noqa: F403,E402
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/aiter/aiter/ops/gemm_op_a8w8.py", line 22, in <module>
from ..ops.flydsl.utils import is_flydsl_available
File "/aiter/aiter/ops/flydsl/__init__.py", line 40, in <module>
from .gemm_kernels import flydsl_hgemm, flydsl_preshuffle_gemm_a8
File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 705, in <module>
_register_all_configs()
File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 701, in _register_all_configs
get_flydsl_splitk_hgemm_kernels(dtype, out_dtype)
File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 626, in get_flydsl_splitk_hgemm_kernels
config = _normalize_registry_config(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 525, in _normalize_registry_config
_validate_hgemm_tiling(
File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 383, in _validate_hgemm_tiling
if not selection_filter(m, n, k, config):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 337, in selection_filter
smem_cap = SMEM_CAPACITY_MAP[GPU_ARCH]
~~~~~~~~~~~~~~~~~^^^^^^^^^^
KeyError: 'gfx1100'
It seems an import is failing because of missing RDNA3 support, and it's FlyDSL related. I'm not sure how to proceed with this; as far as I know, FlyDSL isn't supported on RDNA3:
Verified Platforms:
AMD MI300X/MI308X (gfx942), AMD MI350/MI355X (gfx950), AMD MI450 (gfx1250), Radeon AI PRO R9700 (gfx1201)
It's a burden we have to deal with if we want to enhance RDNA support.
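The `KeyError` in the traceback comes from indexing a dict with an architecture it has no entry for. The sketch below illustrates that failure mode and one possible guard. It is not the fix adopted in FlyDSL or AITER: the map contents are placeholders rather than FlyDSL's real table, and the helper names are hypothetical.

```python
# Hypothetical stand-in for the map in the traceback; keys and byte values
# are placeholders for illustration only.
SMEM_CAPACITY_MAP = {"gfx942": 65536, "gfx950": 163840}


def smem_capacity_or_none(gpu_arch: str):
    """Look up the capacity without raising on unknown architectures."""
    return SMEM_CAPACITY_MAP.get(gpu_arch)


def arch_supported(gpu_arch: str) -> bool:
    """True if the architecture has an entry in the capacity map."""
    return smem_capacity_or_none(gpu_arch) is not None


# SMEM_CAPACITY_MAP["gfx1100"] would raise KeyError at import time;
# the guarded lookup lets the caller skip kernel registration instead.
rdna3_supported = arch_supported("gfx1100")
```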
I added RDNA3.5 support to FlyDSL in ROCm/FlyDSL#567, but I don't see how this PR causes that CI failure. The PR doesn't touch any files involved in the FlyDSL test.
@mgehre-amd, the Flash Attention Integration job doesn't trigger for every PR. It was triggered for this one because you have changed a file in […]. Unfortunately, we have to fix this issue before merging.
micmelesse left a comment
LGTM. Wait for #3683 to merge and rebase
brunomazzottiamd left a comment
LGTM! Let's make CI pass and then merge.
Summary
Two stacked Triton FMHA forward-prefill tuning commits, scoped to gfx1151
(Strix Halo / RDNA3.5), reducing kernel time by ~12% on the Qwen3-Omni
ViT prefill shape (B=1, S=3200, H=16, head_dim=72, fp16).
[triton-fa] hint head-stride div-by-8 for vectorized global load
Adds a `HEAD_STRIDE_ALIGNED_8` constexpr + `tl.multiple_of(..., 8)` so Triton AxisInfo computes 16-byte alignment for the K/V global load when `stride_*h % 8 == 0`. On the `head_dim=72` shape this switches the loop body from 128 `buffer_load_u16` to 16 `buffer_load_b128`. The wrapper checks the actual runtime stride (not `head_dim`), so non-contiguous Q/K/V views stay sound. The constexpr defaults to `False`, so external callers of `attn_fwd` are unaffected.
[triton-fa] gfx1151: BLOCK_N=64, waves_per_eu=2 for LDS-bound FMHA
Tunes the RDNA config for gfx1151: `BLOCK_N=64`, `waves_per_eu=2` (matching CK's `FmhaFwdKernel` choice for the same shape). Scoped to `gfx1151` only; other RDNA targets (`gfx1100`, `gfx1102`, ...) keep the existing default pending validation.
Benchmark: gfx1151, AITER `flash_attn_2.varlen_fwd`
Qwen3-Omni ViT prefill shape: B=1, S=3200, H=16, head_dim=72, fp16.
Triton builds measured:
- `triton-lang/triton` `main`, `d92727c2`
- Triton 3.6.0+rocm7.14.0a20260529 (matched to torch 2.11)
Both configs improve, with triton main yielding the larger absolute and
relative speedup — likely because the head-stride alignment hint
(commit 1) lands more cleanly through the newer Coalesce pass.
Methodology: `triton.testing.do_bench`, warmup=50, rep=300, 3 reps per build. Triton cache cleared between (baseline, patched) pairs. GPU clock not locked.
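One plausible way to aggregate "3 reps per build" is to take the median over the repeated measurements. The sketch below assumes that reading; `median_over_reps` is a hypothetical helper, and the measurement itself is injected as a callable so the aggregation can be exercised without a GPU.

```python
import statistics
from typing import Callable


def median_over_reps(bench: Callable[[], float], reps: int = 3) -> float:
    """Call `bench` `reps` times and return the median of the returned times (ms)."""
    return statistics.median(bench() for _ in range(reps))
```

With Triton installed, `bench` could be something like `lambda: triton.testing.do_bench(run_kernel, warmup=50, rep=300)`, where `run_kernel` is a placeholder for a zero-argument callable that launches the kernel under test.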
Reproducer
```
python op_tests/op_benchmarks/triton/bench_mha.py \
    -fn fwd_varlen -equal_seqlens \
    -b 1 -hq 16 -hk 16 -sq 3200 -sk 3200 -d 72 \
    --dtype fp16 -causal False \
    -impl dao_ai -metric time
```
Correctness
- vs `torch.nn.functional.scaled_dot_product_attention`: max abs error = 1.22e-4, mean = 4.5e-6 (within fp16 tolerance for S=3200 attention).
- `op_tests/triton_tests/attention/test_mha_dao_ai.py`: 5/5 PASSED (the `dao_ai` impl in `aiter.ops.triton.attention.mha` routes through the kernel modified here; the suite covers fwd causal/non-causal, GQA, varlen, and backward).
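For reference, the two error statistics quoted above can be computed as follows. This is a plain-Python sketch over flat sequences, with `error_stats` as a hypothetical helper name; in practice the comparison would be done on tensors.

```python
from typing import Sequence, Tuple


def error_stats(out: Sequence[float], ref: Sequence[float]) -> Tuple[float, float]:
    """Return (max abs error, mean abs error) between two equal-length sequences."""
    if len(out) != len(ref) or not out:
        raise ValueError("inputs must be non-empty and of equal length")
    diffs = [abs(a - b) for a, b in zip(out, ref)]
    return max(diffs), sum(diffs) / len(diffs)
```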