[FMHA] add flash_attn_interface L2 wrapper and cross-length Q/KV (seqlen_q != seqlen_kv)#704
Open
yanguahe wants to merge 5 commits into
Open
[FMHA] add flash_attn_interface L2 wrapper and cross-length Q/KV (seqlen_q != seqlen_kv)#704yanguahe wants to merge 5 commits into
yanguahe wants to merge 5 commits into
Conversation
- Add flash_attn_dualwave_swp_gfx950_kernel with lazy-rescale, s_setprio stagger, split-K combine path, and buffer_store_dwordx4 O-store - Support packed QKV varlen via cu_seqlens; arbitrary seq_len >= 1 on both dualwave and generic fallback paths with padding masks - Update flash_attn_generic dispatch, seq_len guard, and varlen routing - Extend test_flash_attn_fwd with split-K, varlen configs, OPUS/aiter compare Ported from opus_align FMHA optimization work onto rocm/main base. Co-authored-by: Cursor <cursoragent@cursor.com>
The generic flash_attn O-store used permlane32_swap and cvt_pk_bf16_f32 (both gfx950/CDNA4-only) unconditionally. On gfx942 (CDNA3) the gfx950 dualwave fast path is disabled and flash_attn falls back to the generic kernel, so the backend hit "Cannot select intrinsic llvm.amdgcn.permlane32.swap" and aborted (CI: test linux-flydsl-mi325-1). Gate the 128-bit permlane-fused store behind gfx950; gfx942 falls back to a per-lane dwordx2 store packed via .to(elem_dtype) (arch-correct bf16/f16 conversion, same column layout, still num_records-bounded for OOB rows). Add FLYDSL_DISABLE_DUALWAVE_SWP / FLYDSL_GENERIC_OSTORE_SCALAR env hooks to exercise the generic kernel and its gfx942 store path on gfx950 hardware. Verified on gfx950 (MI355): the permlane and scalar O-store paths both give MaxErr 3.91e-3 vs SDPA across H8/16/64, GQA, and partial-seqlen configs; the default gfx950 dualwave path is unchanged (PASS, MaxErr 3.91e-3). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- Shorten verbose comments in flash_attn_generic and flash_attn_gfx950 - Drop unused FLYDSL_GENERIC_OSTORE_SCALAR knob; gfx942 O-store fallback unchanged - Extend run_benchmark DEFAULT_FLASH_ATTN_FUNC_SHAPES with causal/non-causal seq_len 1-65 configs for arbitrary-length coverage - Keep run_benchmark Bandwidth parsing on the base-op first match Co-authored-by: Cursor <cursoragent@cursor.com>
- Add flydsl_flash_attn_func L2 wrapper with cached builds, split-K workspace, and explicit varlen max_seqlen/cross_seqlen controls - Support seqlen_q != seqlen_kv with bottom-right causal masking in dualwave kernel; gate extra v_s_1 mask behind cross_seqlen build flag - Pass cross_seqlen through flash_attn_generic; reject diff-KV on fallback - Unify test_flash_attn_fwd harness via run_attn_config (dense/varlen/split-K/ cross-length) with varlen CSV export - Remove unused arch variable in rmsnorm_kernel build helper Co-authored-by: Cursor <cursoragent@cursor.com>
- Add kernels/flash_attn_interface.py flydsl_flash_attn_func L2 wrapper - gfx950 DUALWAVE_SWP cross_seqlen flag and bottom-right causal masking - Route seq_len_kv != seq_len through dualwave path; reject on generic fallback - Refactor test_flash_attn_fwd.py with unified run_attn_config harness Merged commit: 8461fe1 Co-authored-by: Cursor <cursoragent@cursor.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
This PR adds two things on top of what is already in
main(#683, #685):kernels/flash_attn_interface.py— a new process-level L2 API wrapper(
flydsl_flash_attn_func) that provides@lru_cachebuild caching, unifiedshape / dtype / arch validation, automatic
cross_seqlendetection, andsplit-K workspace management.
Cross-length attention (
seqlen_q != seqlen_kv) — the gfx950 DUALWAVE_SWPkernel now supports arbitrary Q and KV sequence lengths with a
bottom-right-aligned causal mask, for both dense and packed-varlen
(cu_seqlens) inputs.
The accompanying test-harness update merges the four
run_config*/run_splitk_configfunctions into a singlerun_attn_config, shares thereference computation across FlyDSL / aiter_ck / aiter_asm in
--comparemode(eliminating 2 out of 3 expensive reference calls), and exposes all kernel
options as CLI flags.
Technical Details
Changed files (4):
kernels/flash_attn_interface.pykernels/flash_attn_gfx950.pyseq_len_kvparam + bottom-right causal +cross_seqlenflagkernels/flash_attn_generic.pycross_seqlenthrough dispatch;seq_len_kvin fallback guardtests/kernels/test_flash_attn_fwd.pyrun_attn_config; shared ref; CLI flagsTest Plan
on MI355X (gfx950),
HIP_VISIBLE_DEVICES=1.Test Result
Measured on AMD Instinct MI355X (gfx950), bf16+fp16,
--comparevsmainacross all DEFAULT_CONFIGS (27 shapes x 2 dtype x causal + non-causal = 108
matched configurations),
HIP_VISIBLE_DEVICES=1:mainacross all 108 matched configurations.Breakdown: bf16/causal +0.7%, bf16/nocausal +1.6%, fp16/causal +1.3%,
fp16/nocausal +1.0%.
run-to-run noise for a roofline-pinned config.
MaxErrwithin bf16/fp16 tolerance.black+ruffclean.Submission Checklist
black+ruff checkclean (line-length 120).