[Small_M_GEMM_GroupGEMM_MXFP8] Decode small-M MX-FP8 GEMM and GroupGEMM kernels for gfx950#3783

Open
JohnQinAMD wants to merge 3 commits into main from pr/m3-mxfp8-decode-kernels


@JohnQinAMD

HIP kernels for the small-M (decode) regime of OCP MX-FP8 (fp8 e4m3 data + e8m0 1x32 K-block scales) on AMD CDNA4 (gfx950 / MI355X), filling a gap: aiter's CK/ASM microscaling is MXFP4-only and its fp8 CK/ASM is 128-block, not 1x32 MX — there is no MXFP8-1x32 GEMM for the small-M decode regime.

Kernels (each a standalone @compile_ops JIT module):

  • smallm_mxfp8_gemv — dense GEMV, M in {1,2,4,8,16}; packed-VFMA, register-resident e8m0 scales via ldexpf
  • smallm_mxfp8_mfma — dense MFMA crossover, M in {8,16,32,64}; mfma_scale_f32_16x16x128_f8f6f4, split-K + M-tiling
  • smallm_mxfp8_moe_grouped_gemm — MoE grouped GEMM (sorted_token_ids layout)
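
To make the data format concrete, here is a minimal pure-Python sketch of MX-FP8 dequantization: each e4m3 payload byte is decoded and multiplied by the e8m0 scale of its 1x32 K-block. Because an e8m0 scale is a bare power of two, applying it is a single `ldexp`, which is presumably why the GEMV kernel can keep scales in registers and apply them via ldexpf. The sketch assumes the OCP MX encodings (e4m3 with bias 7 and a single NaN pattern, e8m0 with bias 127 and 0xFF as NaN); the function names are illustrative and are not taken from this PR.

```python
import math

def decode_e4m3(byte: int) -> float:
    """Decode one fp8 e4m3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan                      # the only NaN pattern; no infinities
    if exp == 0:
        return sign * math.ldexp(man / 8.0, -6)        # subnormal
    return sign * math.ldexp(1.0 + man / 8.0, exp - 7)

def decode_e8m0(byte: int) -> float:
    """Decode one e8m0 scale byte: an exponent-only power of two, bias 127."""
    if byte == 0xFF:
        return math.nan
    return math.ldexp(1.0, byte - 127)

def dequant_row(data: bytes, scales: bytes, block: int = 32) -> list:
    """Dequantize one row along K: one e8m0 scale per `block` e4m3 elements."""
    assert len(data) == len(scales) * block
    return [decode_e4m3(b) * decode_e8m0(scales[i // block])
            for i, b in enumerate(data)]
```

For example, a block of 32 bytes all equal to 0x38 (1.0 in e4m3) with scale byte 128 dequantizes to 32 values of 2.0.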

gfx950-only: device path under #if defined(__gfx950__) + host TORCH_CHECK(get_gpu_arch()=="gfx950"). 64-bit byte/offset math; the raw-buffer voffset is a signed int32, so a host guard rejects tensors >2 GB (the 2³¹ offset limit) before launch.
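
The offset guard can be sketched in a few lines of Python. This is not the PR's host code: the names are hypothetical, and treating a tensor of exactly 2^31 bytes as acceptable (its largest byte offset is 2^31 - 1) is an assumption. The two example shapes are chosen so that the sizes match the 2.4 GB and 1.2 GB weights quoted in the MoE section.

```python
INT32_VOFFSET_LIMIT = 2 ** 31  # bytes; a signed int32 offset tops out at 2**31 - 1

def fp8_weight_bytes(num_experts: int, n: int, k: int) -> int:
    """Byte size of an [E, N, K] fp8 weight tensor (1 byte per element)."""
    return num_experts * n * k

def fits_raw_buffer(num_bytes: int) -> bool:
    """True if every byte offset into the tensor fits in a signed int32."""
    return num_bytes <= INT32_VOFFSET_LIMIT

# Illustrative shapes (assumed): E=256 with (N=1536, K=6144) and (N=6144, K=768).
gemm1_bytes = fp8_weight_bytes(256, 1536, 6144)   # 2_415_919_104 bytes, about 2.4 GB
gemm2_bytes = fp8_weight_bytes(256, 6144, 768)    # 1_207_959_552 bytes, about 1.2 GB
```

With these sizes `fits_raw_buffer(gemm1_bytes)` is false, so the call would be routed to Triton, while `fits_raw_buffer(gemm2_bytes)` is true and the HIP kernel can launch.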

Selection is envelope-based, not allowlist-based, so it generalizes across TP degrees and models without per-shape hand-tuning: mxfp8_gemv engages HIP for any shape the kernels support (GEMV M∈{1..16}; MFMA M∈{8,16,32,64}; K a multiple of 32/128; bf16 out; gfx950). The autotuned per-(K,N,M) CSV is consulted first; it supplies the best config plus an explicit use_hip flag (HIP is chosen when it wins by >3%, otherwise Triton; the o_proj M=64 cell is pinned to HIP at parity). Absent a tuned entry, an in-envelope shape still engages with the hand-tuned _MFMA_CFG or a default config. The exception is untuned M∈{32,64}, which requires a tuned entry to engage, because the default config was measured to lose to Triton there. Tuned per-GPU shapes are shipped for TP1, TP2, TP4, and TP8; the CSV records use_hip=0 for the few M=64 cells where Triton wins, so they route correctly.
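
The routing order described above can be sketched as follows. This is a simplified illustration rather than the wrapper in aiter/ops/smallm_gemm_mxfp8.py: the function names, the dict standing in for the tuned CSV, and the reading of "K a multiple of 32/128" as 32 for GEMV and 128 for MFMA are all assumptions.

```python
GEMV_M_MAX = 16
MFMA_M = {8, 16, 32, 64}
NEEDS_TUNED_ENTRY_M = {32, 64}   # the default config can lose to Triton here

def in_envelope(m: int, k: int, out_dtype: str, arch: str) -> bool:
    """Shape/arch check only; no per-shape allowlist is involved."""
    if arch != "gfx950" or out_dtype != "bf16":
        return False
    gemv_ok = 1 <= m <= GEMV_M_MAX and k % 32 == 0    # assumed GEMV constraint
    mfma_ok = m in MFMA_M and k % 128 == 0            # assumed MFMA constraint
    return gemv_ok or mfma_ok

def choose_backend(m, k, n, tuned, out_dtype="bf16", arch="gfx950") -> str:
    """Return "hip" or "triton" for one (K, N, M) decode shape."""
    if not in_envelope(m, k, out_dtype, arch):
        return "triton"
    entry = tuned.get((k, n, m))          # tuned table is consulted first
    if entry is not None:
        return "hip" if entry["use_hip"] else "triton"
    if m in NEEDS_TUNED_ENTRY_M:          # untuned large M stays on Triton
        return "triton"
    return "hip"                          # untuned but in-envelope: default config
```

For example, with an empty tuned table `choose_backend(1, 6144, 2304, {})` selects HIP, while `choose_backend(32, 6144, 2304, {})` falls back to Triton until a tuned entry for that cell exists.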

Performance (MI355X / gfx950)

op_tests/bench_smallm_mxfp8.py compares HIP GEMV/MFMA against the stock Triton dot_scaled path it replaces (BLOCK 64/128/128, w8; the unmodified vLLM/sglang fallback). Timings are us/call, 200 iterations after 30 warmup; the table reports the HIP speedup over Triton. HIP wins every dense cell by 1.4–3.6×, except one near-parity cell (o_proj M=64, see note):

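| proj | K | N | M=1 | M=2 | M=4 | M=8 | M=16 | M=32 | M=64 |
|---|---|---|---|---|---|---|---|---|---|
| qkv | 6144 | 2304 | 3.64× | 3.48× | 2.88× | 2.49× | 2.52× | 2.20× | 1.72× |
| o_proj | 2048 | 6144 | 2.10× | 2.10× | 2.09× | 2.09× | 1.99× | 1.63× | ~1.0×* |
| gate_up | 6144 | 1536 | 3.34× | 3.42× | 3.43× | 2.36× | 2.45× | 2.74× | 2.41× |
| mlp_down | 1536 | 6144 | 2.06× | 2.08× | 1.91× | 2.13× | 2.12× | 2.12× | 1.42× |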

Absolute (HIP vs Triton us), representative cells: qkv M=1 7.20 vs 26.21; gate_up M=1 7.82 vs 26.13; o_proj M=1 7.65 vs 16.05; qkv M=64 21.97 vs 37.80.

* o_proj M=64 is a measured tie (run-to-run 0.97×–1.05×), pinned to HIP for a uniform path; not claimed as a speedup.
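
As a cross-check, the speedups in the table follow directly from the absolute timings quoted above (Triton us divided by HIP us). The short sketch below recomputes them and applies the >3% margin used to set use_hip; the helper names are illustrative, and the pin of the tied o_proj M=64 cell to HIP is a manual override that this rule does not model.

```python
HIP_WIN_MARGIN = 1.03   # HIP must beat Triton by more than 3% to be selected

# (HIP us, Triton us) for the representative cells quoted above.
CELLS = {
    ("qkv", 1): (7.20, 26.21),
    ("gate_up", 1): (7.82, 26.13),
    ("o_proj", 1): (7.65, 16.05),
    ("qkv", 64): (21.97, 37.80),
}

def speedup(hip_us: float, triton_us: float) -> float:
    """Speedup of HIP over Triton; values above 1.0 mean HIP is faster."""
    return triton_us / hip_us

def use_hip(hip_us: float, triton_us: float, margin: float = HIP_WIN_MARGIN) -> bool:
    """Margin rule for the use_hip flag; ties and small wins go to Triton."""
    return speedup(hip_us, triton_us) > margin

SPEEDUPS = {cell: round(speedup(*times), 2) for cell, times in CELLS.items()}
```

`SPEEDUPS` reproduces the table entries for these four cells (3.64, 3.34, 2.10, 1.72).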

MoE grouped GEMM

Selection is keyed on (N, K, a_div, has_weight) — E-agnostic for engagement — so the kernel fires under any expert-parallel degree (M3 has 256 experts → E=128/64/32 per GPU at EP2/EP4/EP8). An earlier E=128-keyed allowlist silently failed to match under EP, disengaging the kernel in the actual deployment. The M_routed bound is E-aware because the win-envelope widens with E (more experts → more routing padding → Triton wastes more): gemm1 (deep K=6144) wins to M_routed≤16 at any E and to ≤64 once E≥128 (EP2); gemm2 (shallow K=768) to ≤8 at any E — ~2–3× vs Triton at the decode operating point. BLOCK_N is tuned per-(N,K) ({8,16}): gemm2 uses 16 for ~8–13% on top; gemm1 stays 8. (Bigger BLOCK_N does not extend the envelope — the ceiling is routing padding, not tiling.) The raw-buffer voffset is signed int32, so weights >2 GB are rejected pre-launch: no-EP gemm1 (E=256, 2.4 GB) → Triton; no-EP gemm2 (1.2 GB) → HIP.
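
The E-aware bound can be sketched as below. For brevity the sketch keys only on K (6144 for gemm1, 768 for gemm2), whereas the PR keys engagement on (N, K, a_div, has_weight); the function names are hypothetical. The point it illustrates is that E never decides whether a shape is recognized, only how far the M_routed envelope extends.

```python
def m_routed_bound(k: int, experts_per_gpu: int) -> int:
    """Largest M_routed for which HIP is preferred; 0 means never engage."""
    if k == 6144:                       # gemm1, deep K
        return 64 if experts_per_gpu >= 128 else 16
    if k == 768:                        # gemm2, shallow K
        return 8
    return 0                            # shape not covered by this sketch

def moe_uses_hip(k: int, experts_per_gpu: int, m_routed: int) -> bool:
    """E widens the envelope but is not part of the engagement key."""
    return 1 <= m_routed <= m_routed_bound(k, experts_per_gpu)
```

Under this reading, gemm1 at EP8 (E=32) engages up to M_routed=16, gemm1 at EP2 (E=128) engages up to 64, and gemm2 engages up to 8 regardless of E.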

Test methodology

op_tests/test_smallm_gemm_mxfp8.py: 118 passed, 3 skipped on MI355X.

  • Correctness: each op is compared against a PyTorch reference that dequantizes the same e4m3 data + e8m0 1x32 scales the kernel reads, so the only difference is fp8 matrix-core accumulation vs a bf16 matmul. Pass bar: relative error < 5e-2 (≈ cosine 0.999). Coverage: dense mxfp8_gemv over M ∈ {1,2,4,8,16,32,64} × 15 per-GPU shapes (TP1/2/4/8) + an untuned in-envelope shape (envelope-dispatch generalization); MoE grouped_gemm_mxfp8 over the routed configs incl. EP2/EP4/EP8 (E=128/64/32) cases and a >2 GB → Triton fallback case. The 3 skips are the M=64 use_hip=0 cells (Triton wins → routed there).
  • Guard / negative tests: pytest.raises cases asserting the host/wrapper validation fires — K % 32 != 0 (dense), M > 16 (GEMV wrapper), a non-float32 mul_weight_by (MoE), out-of-envelope M (returns None), and untuned-large-M fallback.
  • Off-gfx950: the whole module skips cleanly via a requires_gfx950 marker; the kernels are gfx950-only by host guard.
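
The pass bar can be illustrated with a small sketch. The description does not say which norm the relative error uses; the version below assumes the L2 norm, under which a relative error r implies cosine similarity of at least sqrt(1 - r^2), about 0.99875 for r = 5e-2, consistent with the "≈ cosine 0.999" remark. The function names are illustrative.

```python
import math

REL_ERR_BAR = 5e-2

def relative_error(out, ref) -> float:
    """L2 relative error ||out - ref|| / ||ref|| (the norm is an assumption)."""
    num = math.sqrt(sum((o - r) ** 2 for o, r in zip(out, ref)))
    den = math.sqrt(sum(r * r for r in ref))
    return num / den

def cosine(a, b) -> float:
    """Cosine similarity between two flat vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def passes(out, ref, bar: float = REL_ERR_BAR) -> bool:
    return relative_error(out, ref) < bar

# Lower bound on cosine similarity implied by the relative-error bar.
COSINE_FLOOR = math.sqrt(1.0 - REL_ERR_BAR ** 2)
```

A kernel output that is off by 1% per element, such as [1.01, 1.98, 3.03, 3.96] against a reference of [1, 2, 3, 4], has a relative error of 0.01 and passes comfortably.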

Built locally with hipcc (all 3 JIT modules compile clean) and run on MI355X; black + ruff clean.

JohnQinAMD requested review from a team and Copilot on June 17, 2026, 17:52
@github-actions

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3783 --add-label <label>

Copilot AI left a comment

Pull request overview

Adds gfx950-targeted HIP decode-regime kernels and dispatch for OCP MX-FP8 (e4m3 payload + e8m0 1x32 K-block scales) covering dense small-M GEMV/MFMA and MoE grouped GEMM, with tuned-shape routing via a shipped CSV and accompanying test/bench tooling.

Changes:

  • Introduces new HIP kernels for dense GEMV and MFMA crossover, plus MoE grouped GEMM (gfx950 guarded).
  • Adds Python wrappers with envelope/tuned-CSV-based dispatch returning None to fall back to Triton when unsupported/unpreferred.
  • Adds tuning/benchmark scripts, a tuned CSV, and pytest correctness + guard/negative tests; wires new JIT modules into optCompilerConfig.json.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.

Show a summary per file
| File | Description |
|---|---|
| op_tests/tune_smallm_mxfp8.py | Dev autotune sweep to generate the per-(K,N,M) tuned CSV; CI skip gating. |
| op_tests/test_smallm_gemm_mxfp8.py | Pytest correctness + negative/guard tests for dense and MoE paths. |
| op_tests/bench_smallm_mxfp8.py | Dev microbenchmark comparing HIP vs Triton baseline; CI skip gating. |
| csrc/smallm_gemm_mxfp8/smallm_mxfp8_moe.cu | gfx950 MoE grouped GEMM HIP kernel + host launcher/validation. |
| csrc/smallm_gemm_mxfp8/smallm_mxfp8_dense.cu | gfx950 dense GEMV HIP kernel + host launcher/validation. |
| csrc/smallm_gemm_mxfp8/smallm_mxfp8_dense_mfma.cu | gfx950 MFMA crossover HIP kernel (split-K + M-tiling) + host launcher. |
| aiter/ops/smallm_gemm_mxfp8.py | Python JIT bindings and envelope/tuned-config dispatch wrappers for dense + MoE. |
| aiter/jit/optCompilerConfig.json | Registers new JIT modules and HIP compile flags. |
| aiter/configs/smallm_mxfp8_tuned.csv | Shipped tuned per-(K,N,M) kernel/config selections including use_hip routing. |

Hand-written HIP kernels for the small-M (decode) regime of OCP MX-FP8
(fp8 e4m3 data + e8m0 1x32 K-block scales) on AMD CDNA4 (gfx950 / MI355X),
filling a real gap: aiter's CK/ASM microscaling is MXFP4-only and its fp8
CK/ASM is 128-block, not 1x32 MX -- there is no MXFP8-1x32 GEMM for the
small-M decode regime.

Kernels (each a standalone @compile_ops JIT module):
  * smallm_mxfp8_gemv      dense GEMV, M in {1,2,4,8,16}; packed-VFMA,
                           register-resident e8m0 scales via ldexpf
  * smallm_mxfp8_mfma      dense MFMA crossover, M in {8,16,32,64};
                           mfma_scale_f32_16x16x128_f8f6f4, split-K + M-tiling
  * smallm_mxfp8_moe_grouped_gemm   MoE grouped GEMM (sorted_token_ids layout),
                           per-(N,K) tunable BLOCK_N in {8,16}

gfx950-only: device path under #if defined(__gfx950__) (inert -> builds on
every arch) + host TORCH_CHECK(get_gpu_arch()=="gfx950"). Byte/offset math is
64-bit; the raw-buffer voffset is a signed int32, so a host guard rejects
tensors whose byte size exceeds the 2GB (2^31) offset limit before launch (a
GPU fault is uncatchable). Host validation: K%32 (dense); sorted_token_ids %
block_m, mul_weight_by dtype/device/contiguity/size, and out contiguity/size
(MoE); wrappers force contiguous payload+scale tensors.

Dense dispatch is envelope-based, not allowlist-based: mxfp8_gemv engages HIP
for any in-envelope shape (GEMV M<=16, MFMA M in {8,16,32,64}, K%32/128, bf16,
gfx950), so it generalizes across TP degrees and models without per-shape
hand-tuning. The autotuned per-(K,N,M) CSV (op_tests/tune_smallm_mxfp8.py
sweeps n_sub/k_splits, HIP vs vanilla Triton dot_scaled, >3% margin) supplies
the best config + an explicit use_hip flag; absent a tuned entry an in-envelope
shape still engages (hand table or default), except untuned M in {32,64} which
fall back to Triton (default config can lose there -- measured). Tuned per-GPU
shapes ship for TP1/TP2/TP4/TP8 (qkv / gate_up / o_proj / down).

MoE selection is keyed on (N,K,a_div,has_weight) -- E-agnostic so it engages
under any expert-parallel degree (M3 has 256 experts: E=128/64/32 per GPU at
EP2/EP4/EP8) -- with an E-aware M_routed bound (the win-envelope widens with E:
gemm1 deep-K wins to 16 at any E and to 64 once E>=128; gemm2 shallow-K to 8).
no-EP gemm1 (E=256, 2.4GB) exceeds the 2GB offset limit and falls back to
Triton; no-EP gemm2 (1.2GB) runs on HIP.

Measured on MI355X (gfx950): HIP beats vanilla Triton on dense decode cells
1.4-4x (a few M=64 cells tie/lose -> use_hip=0, routed to Triton); MoE wins
~2-3x at small M_routed (BLOCK_N=16 adds ~8-13% on gemm2). op_tests correctness
vs a PyTorch reference over the same e4m3/e8m0 bits: 118 passed, 3 skipped (the
use_hip=0 cells), plus negative tests for the host-validation guards and the
>2GB fallback.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: John Qin <yanyuan.qin@amd.com>

The MoE allowlist was TP4-shaped (gemm1 N=1536=2*intermediate/4); at TP8 the
per-GPU gate_up width halves to N=768 (=2*intermediate/8), so the kernel fell
back to Triton. Add the autotuned (768,6144,4,False) entry (HIP wins 1.15-4.1x
across the decode M_routed range on MI355X). gemm2 @TP8 is N=6144,K=384 -- K=384
fails the K%1024-or-768 preflight and it is fp32-out, so it stays on Triton.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: John Qin <yanyuan.qin@amd.com>