
[FMHA] add flash_attn_interface L2 wrapper and cross-length Q/KV (seqlen_q != seqlen_kv)#704

Open
yanguahe wants to merge 5 commits into main from yanguahe/fmha-interface-and-crosslen

yanguahe (Contributor) commented Jun 18, 2026

Motivation

This PR adds two things on top of what is already in main (#683, #685):

  1. kernels/flash_attn_interface.py — a new process-level L2 API wrapper
    (flydsl_flash_attn_func) that provides @lru_cache build caching, unified
    shape / dtype / arch validation, automatic cross_seqlen detection, and
    split-K workspace management.

  2. Cross-length attention (seqlen_q != seqlen_kv) — the gfx950 DUALWAVE_SWP
    kernel now supports arbitrary Q and KV sequence lengths with a
    bottom-right-aligned causal mask, for both dense and packed-varlen
    (cu_seqlens) inputs.
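
As an illustration of the bottom-right-aligned causal mask named in item 2, here is a minimal pure-Python sketch. It assumes the usual convention in which the last query row lines up with the last key column, so query row i may attend to key column j when j <= i + (seqlen_kv - seqlen_q). The helper name `bottom_right_causal_mask` is hypothetical and is not part of this PR; the kernel applies the mask on the GPU, not like this.

```python
def bottom_right_causal_mask(seqlen_q: int, seqlen_kv: int) -> list[list[bool]]:
    """Return mask[i][j] == True where query row i may attend to key column j.

    Assumed convention: the mask's diagonal is anchored at the bottom-right
    corner of the (seqlen_q x seqlen_kv) score matrix.
    """
    offset = seqlen_kv - seqlen_q  # shift of the diagonal relative to top-left alignment
    return [[j <= i + offset for j in range(seqlen_kv)] for i in range(seqlen_q)]


# Print the mask for a short query block attending to a longer KV sequence.
for row in bottom_right_causal_mask(2, 5):
    print("".join("1" if keep else "0" for keep in row))
```

With seqlen_q == seqlen_kv the offset is zero and this reduces to the ordinary lower-triangular causal mask. When seqlen_q > seqlen_kv the first rows are fully masked under this convention; how the kernel treats such rows is not stated here.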

The accompanying test-harness update merges the four run_config* /
run_splitk_config functions into a single run_attn_config, shares the
reference computation across FlyDSL / aiter_ck / aiter_asm in --compare mode
(eliminating 2 out of 3 expensive reference calls), and exposes all kernel
options as CLI flags.
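
The shared-reference idea in --compare mode can be sketched as follows. This is not the actual run_attn_config code: `compare_backends`, the toy backends, and the tolerance are hypothetical stand-ins that only show the reference being computed once and reused for every backend.

```python
from typing import Callable, Dict, List, Tuple

Vector = List[float]


def compare_backends(
    inputs: Vector,
    reference_fn: Callable[[Vector], Vector],
    backends: Dict[str, Callable[[Vector], Vector]],
    tol: float,
) -> Dict[str, Tuple[float, bool]]:
    """Run every backend on the same inputs and check it against one shared reference."""
    ref = reference_fn(inputs)  # the expensive call, made once instead of once per backend
    results = {}
    for name, backend_fn in backends.items():
        out = backend_fn(inputs)
        max_err = max(abs(a - b) for a, b in zip(out, ref))
        results[name] = (max_err, max_err <= tol)
    return results


# Toy stand-ins: the "reference" doubles each value, the backends approximate that.
toy_backends = {
    "FlyDSL": lambda xs: [2.0 * x for x in xs],
    "aiter_ck": lambda xs: [2.0 * x + 1e-4 for x in xs],
    "aiter_asm": lambda xs: [2.0 * x - 1e-4 for x in xs],
}
report = compare_backends([0.5, 1.0, 1.5], lambda xs: [2.0 * x for x in xs], toy_backends, tol=1e-3)
for name, (max_err, ok) in report.items():
    print(name, "PASS" if ok else "FAIL", f"MaxErr={max_err:.2e}")
```

With three backends, this structure makes one reference call where a per-backend harness would make three.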

Technical Details

Changed files (4):

File                                    Change
kernels/flash_attn_interface.py         new (371-line L2 wrapper)
kernels/flash_attn_gfx950.py            seq_len_kv param + bottom-right causal + cross_seqlen flag
kernels/flash_attn_generic.py           thread cross_seqlen through dispatch; seq_len_kv in fallback guard
tests/kernels/test_flash_attn_fwd.py    unified run_attn_config; shared ref; CLI flags
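
To illustrate the @lru_cache build caching and automatic cross_seqlen detection attributed to the new wrapper, here is a minimal sketch. It does not reproduce kernels/flash_attn_interface.py: `build_kernel`, `attn_entry`, and the choice of cache key are assumptions made for illustration, and the "build" is a cheap string rather than a real kernel compile.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def build_kernel(head_dim: int, dtype: str, causal: bool, cross_seqlen: bool) -> str:
    """Hypothetical stand-in for an expensive kernel build.

    lru_cache keys on the (hashable) arguments, so a repeated configuration
    reuses the earlier result instead of rebuilding.
    """
    return f"kernel(hd={head_dim}, dtype={dtype}, causal={causal}, cross={cross_seqlen})"


def attn_entry(seqlen_q: int, seqlen_kv: int, head_dim: int, dtype: str, causal: bool) -> str:
    """Hypothetical entry point: derive cross_seqlen from the shapes, then fetch the build."""
    cross_seqlen = seqlen_q != seqlen_kv  # assumed meaning of automatic detection
    return build_kernel(head_dim, dtype, causal, cross_seqlen)


attn_entry(128, 128, 64, "bf16", True)   # first call: cache miss, triggers a build
attn_entry(256, 256, 64, "bf16", True)   # same build key: served from the cache
attn_entry(128, 256, 64, "bf16", True)   # cross-length: different key, second build
print(build_kernel.cache_info())
```

In this sketch the sequence lengths are deliberately kept out of the cache key, so only the cross-length distinction forces a separate build. Whether the real wrapper keys on sequence length is not stated in this PR.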

Test Plan

python tests/kernels/test_flash_attn_fwd.py --iters 100 --compare

Run on MI355X (gfx950) with HIP_VISIBLE_DEVICES=1.

Test Result

Measured on AMD Instinct MI355X (gfx950) with HIP_VISIBLE_DEVICES=1, for bf16
and fp16, using --compare against main across all DEFAULT_CONFIGS (27 shapes x
2 dtypes x 2 mask modes, causal and non-causal, = 108 matched configurations):

  • Average +1.1% FlyDSL TFLOPS vs main across all 108 matched configurations.
    Breakdown: bf16/causal +0.7%, bf16/nocausal +1.6%, fp16/causal +1.3%,
    fp16/nocausal +1.0%.
  • Peak gain +10.9% at B=2,S=1024,H=64,bf16/nocausal.
  • Largest regression -2.9% at B=16,S=8192,H=64,fp16/causal — within
    run-to-run noise for a roofline-pinned config.
  • All MaxErr within bf16/fp16 tolerance. black + ruff clean.
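
As a quick consistency check on the figures above, the configuration count and the overall average can be recomputed from the per-group numbers. This sketch assumes each dtype/mask group covers the same 27 shapes, so an unweighted mean of the four group averages is valid. The group figures are themselves rounded, so only approximate agreement with the reported +1.1% can be expected.

```python
from itertools import product

shapes = range(27)                 # placeholder indices for the 27 benchmark shapes
dtypes = ("bf16", "fp16")
mask_modes = ("causal", "nocausal")

configs = list(product(shapes, dtypes, mask_modes))
print("matched configurations:", len(configs))

# Per-group average TFLOPS change vs main, in percent, copied from the breakdown above.
group_gain_pct = {
    ("bf16", "causal"): 0.7,
    ("bf16", "nocausal"): 1.6,
    ("fp16", "causal"): 1.3,
    ("fp16", "nocausal"): 1.0,
}

# Equal group sizes (27 configurations each) make the plain mean the overall average.
overall_pct = sum(group_gain_pct.values()) / len(group_gain_pct)
print(f"overall average from rounded group figures: {overall_pct:.2f}%")
```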

Submission Checklist

  • black + ruff check clean (line-length 120).
  • Accuracy verified on MI355X (gfx950) — all DEFAULT_CONFIGS PASS.
  • Performance measured before/after on MI355X (bf16+fp16).
  • CI green.

yanguahe and others added 5 commits June 15, 2026 04:47
- Add flash_attn_dualwave_swp_gfx950_kernel with lazy-rescale, s_setprio
  stagger, split-K combine path, and buffer_store_dwordx4 O-store
- Support packed QKV varlen via cu_seqlens; arbitrary seq_len >= 1 on both
  dualwave and generic fallback paths with padding masks
- Update flash_attn_generic dispatch, seq_len guard, and varlen routing
- Extend test_flash_attn_fwd with split-K, varlen configs, OPUS/aiter compare

Ported from opus_align FMHA optimization work onto rocm/main base.

Co-authored-by: Cursor <cursoragent@cursor.com>

The generic flash_attn O-store used permlane32_swap and cvt_pk_bf16_f32
(both gfx950/CDNA4-only) unconditionally. On gfx942 (CDNA3) the gfx950
dualwave fast path is disabled and flash_attn falls back to the generic
kernel, so the backend hit "Cannot select intrinsic
llvm.amdgcn.permlane32.swap" and aborted (CI: test linux-flydsl-mi325-1).

Gate the 128-bit permlane-fused store behind gfx950; gfx942 falls back to a
per-lane dwordx2 store packed via .to(elem_dtype) (arch-correct bf16/f16
conversion, same column layout, still num_records-bounded for OOB rows).
Add FLYDSL_DISABLE_DUALWAVE_SWP / FLYDSL_GENERIC_OSTORE_SCALAR env hooks to
exercise the generic kernel and its gfx942 store path on gfx950 hardware.

Verified on gfx950 (MI355): the permlane and scalar O-store paths both give
MaxErr 3.91e-3 vs SDPA across H8/16/64, GQA, and partial-seqlen configs; the
default gfx950 dualwave path is unchanged (PASS, MaxErr 3.91e-3).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

- Shorten verbose comments in flash_attn_generic and flash_attn_gfx950
- Drop unused FLYDSL_GENERIC_OSTORE_SCALAR knob; gfx942 O-store fallback unchanged
- Extend run_benchmark DEFAULT_FLASH_ATTN_FUNC_SHAPES with causal/non-causal
  seq_len 1-65 configs for arbitrary-length coverage
- Keep run_benchmark Bandwidth parsing on the base-op first match

Co-authored-by: Cursor <cursoragent@cursor.com>

- Add flydsl_flash_attn_func L2 wrapper with cached builds, split-K workspace,
  and explicit varlen max_seqlen/cross_seqlen controls
- Support seqlen_q != seqlen_kv with bottom-right causal masking in dualwave
  kernel; gate extra v_s_1 mask behind cross_seqlen build flag
- Pass cross_seqlen through flash_attn_generic; reject diff-KV on fallback
- Unify test_flash_attn_fwd harness via run_attn_config (dense/varlen/split-K/
  cross-length) with varlen CSV export
- Remove unused arch variable in rmsnorm_kernel build helper

Co-authored-by: Cursor <cursoragent@cursor.com>

- Add kernels/flash_attn_interface.py flydsl_flash_attn_func L2 wrapper
- gfx950 DUALWAVE_SWP cross_seqlen flag and bottom-right causal masking
- Route seq_len_kv != seq_len through dualwave path; reject on generic fallback
- Refactor test_flash_attn_fwd.py with unified run_attn_config harness

Merged commit: 8461fe1

Co-authored-by: Cursor <cursoragent@cursor.com>

yanguahe changed the title from "[FMHA] gfx950 interface wrapper, cross-length Q/KV, split-K, varlen, and batch-aware routing" to "[FMHA] add flash_attn_interface L2 wrapper and cross-length Q/KV (seqlen_q != seqlen_kv)" on Jun 18, 2026