[gfx1151] [flash_attn_triton_amd]: Tune for gfx1151 / ViT #3419

Merged
mgehre-amd merged 7 commits into ROCm:main from mgehre-amd:matthias.triton-fa-gfx1151-blockn64
Jun 15, 2026

Conversation

@mgehre-amd

Contributor

Summary

Two stacked Triton FMHA forward-prefill tuning commits, scoped to gfx1151
(Strix Halo / RDNA3.5), reducing median kernel time by ~12% (Triton 3.6.0)
to ~21% (triton main) on the Qwen3-Omni ViT prefill shape (B=1, S=3200,
H=16, head_dim=72, fp16).

  1. [triton-fa] hint head-stride div-by-8 for vectorized global load
    Adds a HEAD_STRIDE_ALIGNED_8 constexpr + tl.multiple_of(..., 8) so
    Triton AxisInfo computes 16-byte alignment for the K/V global load
    when stride_*h % 8 == 0. On the head_dim=72 shape this switches
    the loop body from 128 buffer_load_u16 to 16 buffer_load_b128.
    The wrapper checks the actual runtime stride (not head_dim), so
    non-contiguous Q/K/V views stay sound. The constexpr defaults to
    False, so external callers of attn_fwd are unaffected.

  2. [triton-fa] gfx1151: BLOCK_N=64, waves_per_eu=2 for LDS-bound FMHA
    Tunes the RDNA config for gfx1151: BLOCK_N=64, waves_per_eu=2
    (matching CK's FmhaFwdKernel choice for the same shape). Scoped to
    gfx1151 only — other RDNA targets (gfx1100, gfx1102, ...) keep
    the existing default pending validation.

Benchmark — gfx1151, AITER flash_attn_2.varlen_fwd

Qwen3-Omni ViT prefill shape: B=1, S=3200, H=16, head_dim=72, fp16.

triton-lang/triton main, d92727c2

| build | p20 | p50 (median) | p80 |
|---|---|---|---|
| baseline (origin/main) | 2.96 ms | 3.04 ms | 3.13 ms |
| this PR (both commits) | 2.35 ms | 2.39 ms | 2.45 ms |
| Δ | −0.61 ms | −0.65 ms (−21.4 %) | −0.68 ms |

Triton 3.6.0+rocm7.14.0a20260529 (matched to torch 2.11)

| build | p20 | p50 (median) | p80 |
|---|---|---|---|
| baseline (origin/main) | 3.21 ms | 3.30 ms | 3.43 ms |
| this PR (both commits) | 2.88 ms | 2.91 ms | 2.99 ms |
| Δ | −0.33 ms | −0.39 ms (−11.8 %) | −0.44 ms |

Both configs improve, with triton main yielding the larger absolute and
relative speedup — likely because the head-stride alignment hint
(commit 1) lands more cleanly through the newer Coalesce pass.
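
The Δ percentages in the two tables can be re-derived from the p50 columns. The snippet below is only a sanity check on the quoted numbers; `relative_change_pct` is a hypothetical helper, not part of the benchmark script.

```python
def relative_change_pct(baseline_ms: float, patched_ms: float) -> float:
    """Relative change of patched vs. baseline, in percent, rounded to 0.1."""
    return round((patched_ms - baseline_ms) / baseline_ms * 100.0, 1)


# p50 values copied from the two benchmark tables in this PR description.
triton_main = relative_change_pct(3.04, 2.39)
triton_360 = relative_change_pct(3.30, 2.91)

print(triton_main)  # -21.4
print(triton_360)   # -11.8
```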

Methodology: triton.testing.do_bench, warmup=50, rep=300, 3 reps per
build. Triton cache cleared between (baseline, patched) pairs. GPU
clock not locked.

Reproducer

python op_tests/op_benchmarks/triton/bench_mha.py \
    -fn fwd_varlen -equal_seqlens \
    -b 1 -hq 16 -hk 16 -sq 3200 -sk 3200 -d 72 \
    --dtype fp16 -causal False \
    -impl dao_ai -metric time

Correctness

  • Single-shape sanity at the Qwen3-Omni shape above: max abs diff vs
    torch.nn.functional.scaled_dot_product_attention = 1.22e-4,
    mean = 4.5e-6 (within fp16 tolerance for S=3200 attention).
  • op_tests/triton_tests/attention/test_mha_dao_ai.py: 5/5 PASSED
    (the dao_ai impl in aiter.ops.triton.attention.mha routes through
    the kernel modified here; the suite covers fwd causal/non-causal,
    GQA, varlen, and backward).

For the packed `[seq, heads, dim]` layout used by varlen prefill, the
head-axis stride equals `head_dim`. When `head_dim` is a multiple of 8
but not 16 (e.g. 72), Triton's integer-arg auto-specialization does not
attach `tt.divisibility = 8` to `stride_*h` (its threshold is 16), so
AxisInfo treats the K/V global load as 2-byte aligned and Coalesce
emits scalar `buffer_load_u16` instead of vectorized `buffer_load_b128`.
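
The arithmetic behind this paragraph can be sketched in plain Python. This is a deliberately simplified model, not Triton's actual AxisInfo analysis: `specialized_divisibility` and `loads_per_thread` are hypothetical helpers, and the 256 bytes per thread is an assumption inferred from the load counts quoted in this PR (128 two-byte loads).

```python
FP16_BYTES = 2
MAX_VECTOR_BYTES = 16  # a 128-bit load such as buffer_load_b128 moves 16 bytes


def specialized_divisibility(int_arg: int) -> int:
    """Simplified model: an integer argument only gets a divisibility
    hint when it is a multiple of 16; otherwise nothing is known (1)."""
    return 16 if int_arg % 16 == 0 else 1


def loads_per_thread(bytes_per_thread: int, divisibility: int, elem_bytes: int) -> int:
    """Number of loads when the load width is limited by provable alignment."""
    vector_bytes = min(MAX_VECTOR_BYTES, divisibility * elem_bytes)
    return bytes_per_thread // vector_bytes


stride_h = 72  # head_dim for the packed [seq, heads, dim] layout
BYTES_PER_THREAD = 256  # assumption: 128 loads x 2 bytes, from the PR text

auto_hint = specialized_divisibility(stride_h)
print(auto_hint)                                                # 1
print(loads_per_thread(BYTES_PER_THREAD, auto_hint, FP16_BYTES))  # 128
print(loads_per_thread(BYTES_PER_THREAD, 8, FP16_BYTES))          # 16
```

With no hint the model falls back to element-sized (2-byte) loads; with an explicit multiple-of-8 hint the same bytes fit in sixteen 16-byte loads, matching the counts reported for this shape.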

Add a `HEAD_STRIDE_ALIGNED_8` constexpr to `attn_fwd` and apply
`tl.multiple_of(off_h_{q,k} * stride_{q,k,v}h, 8)` to the head-axis
integer offset when the caller sets it. AddPtr propagates this through
to the load pointer, so AxisInfo computes a 16-byte alignment and the
load coalesces.

The wrapper checks `stride_*h % 8 == 0` against the actual runtime
strides (not against `head_dim`), so the hint stays sound for
non-contiguous Q/K/V views where `stride_*h != head_dim`. The constexpr
defaults to `False`, so external callers of `attn_fwd` are unaffected
unless they opt in.
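
A minimal sketch of the wrapper-side check described above, assuming the strides are passed as plain integers. The helper names `head_stride_aligned_8` and `packed_strides` are hypothetical and do not appear in the PR; only the `% 8 == 0` condition on the runtime strides comes from the text.

```python
def head_stride_aligned_8(*head_strides: int) -> bool:
    """True only if every head-axis stride (in elements) is a multiple of 8."""
    return all(stride % 8 == 0 for stride in head_strides)


def packed_strides(num_heads: int, head_dim: int) -> tuple:
    """Element strides of a contiguous [seq, heads, dim] tensor."""
    return (num_heads * head_dim, head_dim, 1)


# Contiguous Qwen3-Omni ViT shape: head stride equals head_dim = 72.
_, q_stride_h, _ = packed_strides(num_heads=16, head_dim=72)
print(head_stride_aligned_8(q_stride_h, q_stride_h, q_stride_h))  # True

# Hypothetical non-contiguous view: head_dim is still 72, but the view was
# sliced out of a wider buffer, so the head stride is 76 elements.
print(head_stride_aligned_8(76, 76, 76))  # False
```

With PyTorch tensors in this layout, the inputs would presumably come from `q.stride(1)`, `k.stride(1)` and `v.stride(1)`. Checking the strides rather than `head_dim` is what keeps the second case from receiving an unsound hint.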

On gfx1151 with `head_dim=72, fp16`, the Qwen3-Omni ViT prefill shape
goes from 128 `buffer_load_u16` to 16 `buffer_load_b128` and median
kernel time drops from 3.04 ms to 2.71 ms (-10.9%). Correctness matches
torch SDPA to 1e-4 max abs error.

The default RDNA config for `attn_fwd` uses BLOCK_N=32 + waves_per_eu=6
to target 6-wave occupancy. On gfx1151 the kernel allocates 32 KB
shared per workgroup (Q tile staged through LDS), which already caps
occupancy at 4 WGs/WGP (= 2 waves/EU) regardless of the register
budget. The `waves_per_eu=6` hint then makes the register allocator
spill / shorten live ranges to fit waves that will never schedule,
which degrades codegen: fewer s_delay_alu hints, shorter clauses,
more scalar f32 multiplies in the softmax path.
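
The occupancy cap can be reproduced with a short calculation. The hardware constants below are assumptions rather than values stated in this PR: 128 KB of LDS and 4 SIMDs per WGP, and 2 waves per workgroup (the last one is inferred from the "4 WGs/WGP = 2 waves/EU" figure above).

```python
LDS_BYTES_PER_WGP = 128 * 1024  # assumption: LDS capacity of one WGP
SIMDS_PER_WGP = 4               # assumption: execution units (SIMDs) per WGP
WAVES_PER_WORKGROUP = 2         # assumption inferred from the PR's numbers


def lds_limited_occupancy(lds_bytes_per_workgroup: int) -> tuple:
    """Return (workgroups per WGP, waves per EU) when LDS is the only limit."""
    workgroups = LDS_BYTES_PER_WGP // lds_bytes_per_workgroup
    waves_per_eu = workgroups * WAVES_PER_WORKGROUP / SIMDS_PER_WGP
    return workgroups, waves_per_eu


print(lds_limited_occupancy(32 * 1024))  # (4, 2.0)
```

Under these assumptions a `waves_per_eu=6` target cannot be reached no matter how few registers each wave uses, which is the motivation given for lowering the hint.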

Drop the hint to `waves_per_eu=2` (matching CK's high-VGPR
low-occupancy choice for the same kernel on the same arch) and pair
it with BLOCK_N=64 so each workgroup amortises loop-dispatch overhead
over 2x the WMMA work per K-tile. The inner-loop mnemonic profile
then collapses onto CK's: s_clause matches at 0.03/WMMA, s_waitcnt
1.19 vs 1.31, total instr/WMMA drops from 12.2 to 11.8.

On AITER's flash_attn_2.varlen_fwd at the Qwen3-Omni ViT prefill shape
(B=1, S=3200, H=16, head_dim=72, fp16) on gfx1151, median kernel time
goes from 2.47 ms to 2.35 ms (-4.9%), now 3% faster than the CK
ck_tile FmhaFwdKernel reference at 2.42 ms. Min times are 2.28 vs CK
2.34 ms (-2.6%). Stacked on top of the previous head-stride
divisibility-by-8 hint commit, total speedup vs the upstream baseline
is 3.04 ms -> 2.35 ms (-22.7%).

Scoped to gfx1151 only; other RDNA targets keep the existing default
pending validation.
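
The arch scoping can be pictured as a small lookup. This is an illustrative sketch only: `select_rdna_prefill_config` and the dictionary layout are hypothetical, and fields the PR text does not mention (for example BLOCK_M or num_warps) are deliberately left out rather than guessed.

```python
# Values taken from the commit message; structure and names are assumptions.
DEFAULT_RDNA_CONFIG = {"BLOCK_N": 32, "waves_per_eu": 6}
GFX1151_CONFIG = {"BLOCK_N": 64, "waves_per_eu": 2}


def select_rdna_prefill_config(arch: str) -> dict:
    """Return the tuned config for gfx1151 and the existing default otherwise."""
    if arch == "gfx1151":
        return dict(GFX1151_CONFIG)
    return dict(DEFAULT_RDNA_CONFIG)


print(select_rdna_prefill_config("gfx1151"))  # {'BLOCK_N': 64, 'waves_per_eu': 2}
print(select_rdna_prefill_config("gfx1100"))  # {'BLOCK_N': 32, 'waves_per_eu': 6}
```
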
@github-actions

Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3419 --add-label <label>

@mgehre-amd mgehre-amd marked this pull request as ready for review June 2, 2026 09:58
@mgehre-amd mgehre-amd requested review from a team and Copilot June 2, 2026 09:58

Copilot AI left a comment

Contributor

Pull request overview

This PR tunes the AMD Triton FlashAttention forward-prefill kernel for RDNA3.5 (gfx1151) and adds a compiler-alignment hint to improve vectorized global loads when head-axis strides are suitably aligned.

Changes:

  • Add a gfx1151-specific RDNA prefill config (BLOCK_N=64, waves_per_eu=2) tuned for the Qwen3-Omni ViT prefill shape.
  • Introduce HEAD_STRIDE_ALIGNED_8 constexpr plumbing and a runtime stride check to safely apply tl.multiple_of(..., 8) on Q/K/V head-axis offsets.

Comment thread aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/fwd_prefill.py Outdated
The in-kernel comment claimed the head-axis byte offset is 16-byte
aligned whenever stride_*h % 8 == 0. That only holds for 16-bit element
types. The hint guarantees an 8-element multiple, whose byte alignment
is element-size dependent (16 B for fp16/bf16, 8 B for fp8, 32 B for
fp32); only the 16-bit case yields the buffer_load_b128 widening this
PR targets. The tl.multiple_of(..., 8) hint stays sound for all dtypes.

Comment-only change, no behavioral impact.
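
The element-size dependence described in this commit message is a one-line multiplication. The sketch below spells it out; `byte_alignment` and the `ELEMENT_BYTES` table are hypothetical names used for illustration.

```python
# Element sizes in bytes for the dtypes named in the commit message.
ELEMENT_BYTES = {"fp8": 1, "fp16": 2, "bf16": 2, "fp32": 4}


def byte_alignment(dtype: str, element_multiple: int = 8) -> int:
    """Byte alignment implied by an offset that is a multiple of
    `element_multiple` elements of the given dtype."""
    return element_multiple * ELEMENT_BYTES[dtype]


for name in ("fp8", "fp16", "bf16", "fp32"):
    print(name, byte_alignment(name))
# fp8 8
# fp16 16
# bf16 16
# fp32 32
```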

@brunomazzottiamd brunomazzottiamd left a comment

Contributor

Code looks good to me! Head-stride div-by-8 hint seems to be fine too, given the benchmarking performed in the scope of #3424.

However, we have a CI failure in Flash Attention - Triton / RDNA3 (1 GPU) job:

[aiter] import [module_aiter_core] under /aiter/aiter/jit/module_aiter_core.so
Traceback (most recent call last):
  File "/flash-attention/benchmarks/benchmark_flash_attention.py", line 11, in <module>
    from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
  File "/opt/venv/lib/python3.12/site-packages/flash_attn/__init__.py", line 8, in <module>
    from flash_attn.flash_attn_interface import (
  File "/opt/venv/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py", line 21, in <module>
    from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu
  File "/aiter/aiter/__init__.py", line 92, in <module>
    from .ops.gemm_op_a8w8 import *  # noqa: F403,E402
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/aiter/aiter/ops/gemm_op_a8w8.py", line 22, in <module>
    from ..ops.flydsl.utils import is_flydsl_available
  File "/aiter/aiter/ops/flydsl/__init__.py", line 40, in <module>
    from .gemm_kernels import flydsl_hgemm, flydsl_preshuffle_gemm_a8
  File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 705, in <module>
    _register_all_configs()
  File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 701, in _register_all_configs
    get_flydsl_splitk_hgemm_kernels(dtype, out_dtype)
  File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 626, in get_flydsl_splitk_hgemm_kernels
    config = _normalize_registry_config(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 525, in _normalize_registry_config
    _validate_hgemm_tiling(
  File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 383, in _validate_hgemm_tiling
    if not selection_filter(m, n, k, config):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/aiter/aiter/ops/flydsl/gemm_kernels.py", line 337, in selection_filter
    smem_cap = SMEM_CAPACITY_MAP[GPU_ARCH]
               ~~~~~~~~~~~~~~~~~^^^^^^^^^^
KeyError: 'gfx1100'

It seems some import is failing due to lack of RDNA3 support and it's FlyDSL related... I'm not sure how to proceed with this; as far as I know, FlyDSL isn't supported on RDNA3:

Verified Platforms:

AMD MI300X/MI308X (gfx942), AMD MI350/MI355X (gfx950), AMD MI450 (gfx1250), Radeon AI PRO R9700 (gfx1201)

It's a burden we have to deal with if we want to enhance RDNA support.

@mgehre-amd

Contributor Author

> It seems some import is failing due to lack of RDNA3 support and it's FlyDSL related... I'm not sure how to proceed with this; as far as I know, FlyDSL isn't supported on RDNA3:

I added RDNA3.5 support to FlyDSL in ROCm/FlyDSL#567, but I don't see how this PR causes that CI failure. The PR doesn't touch any files involved in the FlyDSL test.

@brunomazzottiamd

Contributor

> I added RDNA3.5 support to FlyDSL in ROCm/FlyDSL#567, but I don't see how this PR causes that CI failure. The PR doesn't touch any files involved in the FlyDSL test.

@mgehre-amd, the Flash Attention Integration job isn't triggered for every PR. It was triggered for this one because you changed a file in the aiter/ops/triton/_triton_kernels/flash_attn_triton_amd directory (https://github.com/ROCm/aiter/blob/main/.github/workflows/flash_attention_integration.yaml#L7). My guess is that something that broke the job was merged between its last successful run and this PR.

Unfortunately we have to fix this issue before merging.

@micmelesse micmelesse left a comment

Contributor

LGTM. Wait for #3683 to merge and rebase

@brunomazzottiamd brunomazzottiamd left a comment

Contributor

LGTM! Let's make CI pass and then merge.

@mgehre-amd mgehre-amd merged commit c9333f7 into ROCm:main Jun 15, 2026
88 of 89 checks passed

4 participants