[Small_M_GEMM_GroupGEMM_MXFP8] Decode small-M MX-FP8 GEMM and GroupGEMM kernels for gfx950#3783
Open
JohnQinAMD wants to merge 3 commits into
Pull request overview
Adds gfx950-targeted HIP decode-regime kernels and dispatch for OCP MX-FP8 (e4m3 payload + e8m0 1x32 K-block scales) covering dense small-M GEMV/MFMA and MoE grouped GEMM, with tuned-shape routing via a shipped CSV and accompanying test/bench tooling.
Changes:
- Introduces new HIP kernels for dense GEMV and MFMA crossover, plus MoE grouped GEMM (gfx950 guarded).
- Adds Python wrappers with envelope/tuned-CSV-based dispatch returning None to fall back to Triton when unsupported/unpreferred.
- Adds tuning/benchmark scripts, a tuned CSV, and pytest correctness + guard/negative tests; wires new JIT modules into optCompilerConfig.json.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| op_tests/tune_smallm_mxfp8.py | Dev autotune sweep to generate the per-(K,N,M) tuned CSV; CI skip gating. |
| op_tests/test_smallm_gemm_mxfp8.py | Pytest correctness + negative/guard tests for dense and MoE paths. |
| op_tests/bench_smallm_mxfp8.py | Dev microbenchmark comparing HIP vs Triton baseline; CI skip gating. |
| csrc/smallm_gemm_mxfp8/smallm_mxfp8_moe.cu | gfx950 MoE grouped GEMM HIP kernel + host launcher/validation. |
| csrc/smallm_gemm_mxfp8/smallm_mxfp8_dense.cu | gfx950 dense GEMV HIP kernel + host launcher/validation. |
| csrc/smallm_gemm_mxfp8/smallm_mxfp8_dense_mfma.cu | gfx950 MFMA crossover HIP kernel (split-K + M-tiling) + host launcher. |
| aiter/ops/smallm_gemm_mxfp8.py | Python JIT bindings and envelope/tuned-config dispatch wrappers for dense + MoE. |
| aiter/jit/optCompilerConfig.json | Registers new JIT modules and HIP compile flags. |
| aiter/configs/smallm_mxfp8_tuned.csv | Shipped tuned per-(K,N,M) kernel/config selections including use_hip routing. |
Hand-written HIP kernels for the small-M (decode) regime of OCP MX-FP8
(fp8 e4m3 data + e8m0 1x32 K-block scales) on AMD CDNA4 (gfx950 / MI355X),
filling a real gap: aiter's CK/ASM microscaling is MXFP4-only and its fp8
CK/ASM is 128-block, not 1x32 MX -- there is no MXFP8-1x32 GEMM for the
small-M decode regime.
Kernels (each a standalone @compile_ops JIT module):
* smallm_mxfp8_gemv dense GEMV, M in {1,2,4,8,16}; packed-VFMA,
register-resident e8m0 scales via ldexpf
* smallm_mxfp8_mfma dense MFMA crossover, M in {8,16,32,64};
mfma_scale_f32_16x16x128_f8f6f4, split-K + M-tiling
* smallm_mxfp8_moe_grouped_gemm MoE grouped GEMM (sorted_token_ids layout),
per-(N,K) tunable BLOCK_N in {8,16}
gfx950-only: device path under #if defined(__gfx950__) (inert -> builds on
every arch) + host TORCH_CHECK(get_gpu_arch()=="gfx950"). Byte/offset math is
64-bit; the raw-buffer voffset is a signed int32, so a host guard rejects
tensors whose byte size exceeds the 2GB (2^31) offset limit before launch (a
GPU fault is uncatchable). Host validation: K%32 (dense); sorted_token_ids%block_m,
mul_weight_by dtype/device/contiguity/size, and out contiguity/size (MoE); wrappers
force contiguous payload+scale tensors.
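
The pre-launch offset guard described above can be sketched in plain Python. This is an illustration only, not the PR's host code: the helper names `weight_nbytes` and `fits_int32_voffset` are hypothetical, and the expert-weight shapes in the usage note are inferred from the sizes quoted elsewhere in this PR (fp8 payload, 1 byte per element).

```python
# Minimal sketch of a host-side guard for a signed int32 byte offset.
# Helper names are hypothetical; they do not come from the PR.

INT32_OFFSET_LIMIT = 2**31  # a signed int32 voffset can address bytes 0 .. 2^31 - 1


def weight_nbytes(num_experts: int, n: int, k: int, bytes_per_elem: int = 1) -> int:
    """Byte size of an [E, N, K] weight tensor (fp8 payload: 1 byte per element)."""
    return num_experts * n * k * bytes_per_elem


def fits_int32_voffset(nbytes: int) -> bool:
    """True when every byte of the tensor is reachable through an int32 offset.

    A tensor of exactly 2^31 bytes still fits (largest offset is 2^31 - 1),
    so only sizes that exceed the limit are rejected.
    """
    return nbytes <= INT32_OFFSET_LIMIT
```

With the assumed shapes, `weight_nbytes(256, 1536, 6144)` is 2,415,919,104 bytes (about 2.4 GB) and fails the check, while `weight_nbytes(256, 6144, 768)` is 1,207,959,552 bytes (about 1.2 GB) and passes, which matches the no-EP gemm1 and gemm2 outcomes stated in this PR.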
Dense dispatch is envelope-based, not allowlist-based: mxfp8_gemv engages HIP
for any in-envelope shape (GEMV M<=16, MFMA M in {8,16,32,64}, K%32/128, bf16,
gfx950), so it generalizes across TP degrees and models without per-shape
hand-tuning. The autotuned per-(K,N,M) CSV (op_tests/tune_smallm_mxfp8.py
sweeps n_sub/k_splits, HIP vs vanilla Triton dot_scaled, >3% margin) supplies
the best config + an explicit use_hip flag; absent a tuned entry an in-envelope
shape still engages (hand table or default), except untuned M in {32,64} which
fall back to Triton (default config can lose there -- measured). Tuned per-GPU
shapes ship for TP1/TP2/TP4/TP8 (qkv / gate_up / o_proj / down).
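
The dense dispatch order described above (envelope check, then tuned entry, then default) can be sketched as follows. This is a hedged illustration, not the wrapper in aiter/ops/smallm_gemm_mxfp8.py: the function names, the dict-based tuned table, and the config fields are hypothetical, and reading "K%32/128" as K%32 for GEMV and K%128 for MFMA is an assumption.

```python
# Sketch of envelope-based dispatch with a tuned-table override.
# All names and the table layout are hypothetical.
from typing import Dict, Optional, Tuple

MFMA_M = {8, 16, 32, 64}
NEEDS_TUNED_ENTRY = {32, 64}  # untuned shapes at these M fall back to Triton


def in_envelope(m: int, k: int, arch: str = "gfx950", out_dtype: str = "bf16") -> bool:
    """True when some HIP kernel supports the shape (assumed envelope rules)."""
    if arch != "gfx950" or out_dtype != "bf16":
        return False
    gemv_ok = 1 <= m <= 16 and k % 32 == 0     # assumption: GEMV needs K % 32 == 0
    mfma_ok = m in MFMA_M and k % 128 == 0     # assumption: MFMA needs K % 128 == 0
    return gemv_ok or mfma_ok


def select_dense_config(
    k: int,
    n: int,
    m: int,
    tuned: Dict[Tuple[int, int, int], dict],
    arch: str = "gfx950",
    out_dtype: str = "bf16",
) -> Optional[dict]:
    """Return a HIP config, or None to signal a fallback to Triton."""
    if not in_envelope(m, k, arch, out_dtype):
        return None
    entry = tuned.get((k, n, m))
    if entry is not None:
        # A tuned entry is authoritative, including an explicit use_hip=0.
        return entry if entry["use_hip"] else None
    if m in NEEDS_TUNED_ENTRY:
        return None
    return {"use_hip": 1, "source": "default"}
```

Returning None rather than raising keeps the fallback decision in the caller, which is how the wrappers in this PR are described as handing control back to Triton.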
MoE selection is keyed on (N,K,a_div,has_weight) -- E-agnostic so it engages
under any expert-parallel degree (M3 has 256 experts: E=128/64/32 per GPU at
EP2/EP4/EP8) -- with an E-aware M_routed bound (the win-envelope widens with E:
gemm1 deep-K wins to 16 at any E and to 64 once E>=128; gemm2 shallow-K to 8).
no-EP gemm1 (E=256, 2.4GB) exceeds the 2GB offset limit and falls back to
Triton; no-EP gemm2 (1.2GB) runs on HIP.
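
The MoE selection rule above (an E-agnostic key plus an E-aware M_routed bound) can be sketched as below. This is an illustration under stated assumptions, not the PR's table: the function names are hypothetical, the gemm1 key reuses the (768,6144,4,False) entry quoted in this PR and assumes the deep-K bound applies to it, and the gemm2 key's a_div/has_weight values are made up for the example.

```python
# Sketch of MoE kernel selection: the key ignores E, the bound depends on E.
# Function names and the gemm2 key are hypothetical.
from typing import Callable, Dict, Tuple

Key = Tuple[int, int, int, bool]  # (N, K, a_div, has_weight)


def gemm1_bound(num_experts: int) -> int:
    """Deep-K gemm1: wins to M_routed <= 16 at any E, <= 64 once E >= 128."""
    return 64 if num_experts >= 128 else 16


def gemm2_bound(num_experts: int) -> int:
    """Shallow-K gemm2: wins to M_routed <= 8 at any E."""
    return 8


ENVELOPE: Dict[Key, Callable[[int], int]] = {
    (768, 6144, 4, False): gemm1_bound,  # key quoted in the PR; bound assumed
    (6144, 768, 1, True): gemm2_bound,   # hypothetical key, for illustration
}


def moe_uses_hip(n: int, k: int, a_div: int, has_weight: bool,
                 num_experts: int, m_routed: int) -> bool:
    """True when the HIP grouped GEMM should engage; False means Triton."""
    bound = ENVELOPE.get((n, k, a_div, has_weight))
    if bound is None:
        return False
    return m_routed <= bound(num_experts)
```

Because E appears only in the bound and never in the lookup key, the same entry matches at EP2, EP4, and EP8, which is the property the paragraph above relies on.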
Measured on MI355X (gfx950): HIP beats vanilla Triton on dense decode cells
1.4-4x (a few M=64 cells tie/lose -> use_hip=0, routed to Triton); MoE wins
~2-3x at small M_routed (BLOCK_N=16 adds ~8-13% on gemm2). op_tests correctness
vs a PyTorch reference over the same e4m3/e8m0 bits: 118 passed, 3 skipped (the
use_hip=0 cells), plus negative tests for the host-validation guards and the
>2GB fallback.
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: John Qin <yanyuan.qin@amd.com>
The MoE allowlist was TP4-shaped (gemm1 N=1536=2*intermediate/4); at TP8 the per-GPU gate_up width halves to N=768 (=2*intermediate/8), so the kernel fell back to Triton. Add the autotuned (768,6144,4,False) entry (HIP wins 1.15-4.1x across the decode M_routed range on MI355X). gemm2 @TP8 is N=6144,K=384 -- K=384 fails the K%1024-or-768 preflight and it is fp32-out, so it stays on Triton.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: John Qin <yanyuan.qin@amd.com>
HIP kernels for the small-M (decode) regime of OCP MX-FP8 (fp8 e4m3 data + e8m0 1x32 K-block scales) on AMD CDNA4 (gfx950 / MI355X), filling a gap: aiter's CK/ASM microscaling is MXFP4-only and its fp8 CK/ASM is 128-block, not 1x32 MX — there is no MXFP8-1x32 GEMM for the small-M decode regime.
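
For readers unfamiliar with the data layout, the following pure-Python sketch decodes one row of MX-FP8 data: each fp8 e4m3 payload byte is multiplied by the power-of-two e8m0 scale shared by its 32-element K-block. It is a reference-style illustration, not the kernels' code path; the function names are hypothetical, and the bit layouts follow the OCP e4m3 and e8m0 definitions (e4m3 bias 7 with no infinities, e8m0 bias 127 with 0xFF as NaN).

```python
# Reference-style decode of MX-FP8: e4m3 payload with one e8m0 scale per 32 values.
# Function names are hypothetical; they are not part of aiter.
import math
from typing import List, Sequence


def decode_e4m3(byte: int) -> float:
    """Decode one OCP fp8 e4m3 value (1 sign, 4 exponent, 3 mantissa bits, bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan                              # the only NaN encoding; no inf
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6        # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def decode_e8m0(byte: int) -> float:
    """Decode one e8m0 scale: a pure power of two, 2^(byte - 127)."""
    if byte == 0xFF:
        return math.nan
    return math.ldexp(1.0, byte - 127)


def dequant_row(payload: Sequence[int], scales: Sequence[int], block: int = 32) -> List[float]:
    """Apply one scale per `block` consecutive K elements (the 1x32 layout)."""
    assert len(payload) == len(scales) * block
    return [decode_e4m3(b) * decode_e8m0(scales[i // block]) for i, b in enumerate(payload)]
```

Since every scale is an exact power of two, applying it only shifts the exponent, which is why an `ldexp`-style operation is enough to apply it.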
Kernels (each a standalone @compile_ops JIT module):
- smallm_mxfp8_gemv: dense GEMV, M in {1,2,4,8,16}; packed-VFMA, register-resident e8m0 scales via ldexpf
- smallm_mxfp8_mfma: dense MFMA crossover, M in {8,16,32,64}; mfma_scale_f32_16x16x128_f8f6f4, split-K + M-tiling
- smallm_mxfp8_moe_grouped_gemm: MoE grouped GEMM (sorted_token_ids layout)

gfx950-only: device path under #if defined(__gfx950__) + host TORCH_CHECK(get_gpu_arch()=="gfx950"). 64-bit byte/offset math; the raw-buffer voffset is a signed int32, so a host guard rejects tensors >2 GB (the 2³¹ offset limit) before launch.

Selection is envelope-based, not allowlist-based, so it generalizes across TP degrees and models without per-shape hand-tuning: mxfp8_gemv engages HIP for any shape the kernels support (GEMV M∈{1..16}; MFMA M∈{8,16,32,64}; K a multiple of 32/128; bf16 out; gfx950). The autotuned per-(K,N,M) CSV is consulted first for the best config + an explicit use_hip flag (HIP wins by >3%, else Triton; the o_proj M=64 cell is pinned to HIP at parity). Absent a tuned entry, an in-envelope shape still engages with the hand-tuned _MFMA_CFG or a default config, except untuned M∈{32,64}, which require a tuned entry to engage (default config can lose to Triton there; measured). Tuned per-GPU shapes shipped for TP1, TP2, TP4, and TP8 (the CSV records use_hip=0 for the few M=64 cells where Triton wins, so they route correctly).

Performance (MI355X / gfx950)
op_tests/bench_smallm_mxfp8.py: HIP GEMV/MFMA vs the stock Triton dot_scaled path it replaces (BLOCK 64/128/128, w8; the unmodified vLLM/sglang fallback). us/call, 200 iters after 30 warmup. HIP wins every dense cell 1.4–3.6×, with one near-parity cell (o_proj M=64, see note).

Absolute (HIP vs Triton us), representative cells:

| Cell | HIP (us) | Triton (us) |
|---|---|---|
| qkv M=1 | 7.20 | 26.21 |
| gate_up M=1 | 7.82 | 26.13 |
| o_proj M=1 | 7.65 | 16.05 |
| qkv M=64 | 21.97 | 37.80 |

* o_proj M=64 is a measured tie (run-to-run 0.97×–1.05×), pinned to HIP for a uniform path; not claimed as a speedup.
MoE grouped GEMM
Selection is keyed on (N, K, a_div, has_weight), E-agnostic for engagement, so the kernel fires under any expert-parallel degree (M3 has 256 experts → E=128/64/32 per GPU at EP2/EP4/EP8). An earlier E=128-keyed allowlist silently failed to match under EP, disengaging the kernel in the actual deployment.

The M_routed bound is E-aware because the win-envelope widens with E (more experts → more routing padding → Triton wastes more): gemm1 (deep K=6144) wins to M_routed ≤ 16 at any E and to ≤ 64 once E ≥ 128 (EP2); gemm2 (shallow K=768) to ≤ 8 at any E. That is ~2–3× vs Triton at the decode operating point.

BLOCK_N is tuned per-(N,K) ({8,16}): gemm2 uses 16 for ~8–13% on top; gemm1 stays 8. (Bigger BLOCK_N does not extend the envelope: the ceiling is routing padding, not tiling.)

The raw-buffer voffset is signed int32, so weights >2 GB are rejected pre-launch: no-EP gemm1 (E=256, 2.4 GB) → Triton; no-EP gemm2 (1.2 GB) → HIP.

Test methodology
op_tests/test_smallm_gemm_mxfp8.py: 118 passed, 3 skipped on MI355X.
- Tolerance: < 5e-2 (≈ cosine 0.999).
- Coverage: dense mxfp8_gemv over M ∈ {1,2,4,8,16,32,64} × 15 per-GPU shapes (TP1/2/4/8) + an untuned in-envelope shape (envelope-dispatch generalization); MoE grouped_gemm_mxfp8 over the routed configs incl. EP2/EP4/EP8 (E=128/64/32) cases and a >2 GB → Triton fallback case. The 3 skips are the M=64 use_hip=0 cells (Triton wins → routed there).
- Negative tests: pytest.raises cases asserting the host/wrapper validation fires: K % 32 != 0 (dense), M > 16 (GEMV wrapper), a non-float32 mul_weight_by (MoE), out-of-envelope M (returns None), and untuned-large-M fallback.
- requires_gfx950 marker; the kernels are gfx950-only by host guard.

Built locally with hipcc (all 3 JIT modules compile clean) and run on MI355X; black + ruff clean.