Skip to content

Add MXFP8 grouped GEMM (E4M3) for routed-expert MoE training#31

Draft
ysa2215 wants to merge 12 commits into
mainfrom
yue/mxfp8-grouped-gemm
Draft

Add MXFP8 grouped GEMM (E4M3) for routed-expert MoE training#31
ysa2215 wants to merge 12 commits into
mainfrom
yue/mxfp8-grouped-gemm

Conversation

@ysa2215

@ysa2215 ysa2215 commented Jun 18, 2026

Copy link
Copy Markdown
Collaborator

Description:

Why

GPT-OSS-style MoE training needs a grouped GEMM to run on MXFP8 — without it the MXFP8 dispatch path errored out on routed-expert layers, restricting MXFP8 to Linear targets. This is the minimum viable kernel to get MoE training running.

V1 uses E4M3 for all operands across fwd/dgrad/wgrad. A single format keeps the kernel free of dtype dispatch and the autograd free of a separate grad_output quantization path — simplest to implement and validate. The known cost is gradient
numerical robustness (E4M3's narrow dynamic range can underflow the long tail / overflow spikes of grad_output); the toy-MoE training test validates whether V1 is adequate in range.

The mixed-format recipe (grad_output in E5M2 — the TE / Meta / TorchAO industry consensus) is reserved for V2. To make that upgrade kernel-free, the API already exposes independent fwd_format / bwd_grad_format params (both defaulting to
E4M3) and the kernel carries a pre-written E5M2 branch; V2 just flips bwd_grad_format to E5M2. See MXFP8_GROUPED_GEMM_PLAN.md` §0 for the rationale.

What changed

Adds an MXFP8 contiguous grouped GEMM kernel (forward + dgrad + wgrad) and wires it into the MXFP8 dispatch path so routed-expert MoE layers can run on MXFP8.

  • New kernel package alto/kernels/mxfp8/mxfp8_grouped_gemm/
    • cg_forward.py — Triton persistent forward grouped GEMM with super-grouping schedule and contiguous index routing, adapted from the MXFP4 scaffold with all K-packing removed (MXFP8 is one element per byte).
    • cg_backward.py — full backward (dgrad + wgrad) wrapped in an MXFP8GroupedGEMM(autograd.Function).
    • functional.py — user-facing mxfp8_grouped_gemm(...) plus the dispatch-layer entry _quantize_then_mxfp8_scaled_grouped_mm(...), which mirrors the MXFP4
      _quantize_then_mxfp_scaled_grouped_mm contract (1-D cumulative offs,
      padded activation buffers, weights kept in [E, K, N] dispatch convention to
      avoid a transpose copy).
    • autotune.py, __init__.py.
  • Dispatch wiring (alto/kernels/dispatch/tensor.py): the MXFP8 weight wrapper's _grouped_mm path previously raised NotImplementedError; it now routes 2d-activation × 3d-weight + offsets calls to the new kernel, with guards limiting
    V1 to mxfp8_e4m3 and rejecting Hadamard/DGE/bias.
  • Tests (tests/unittest/mxfp8/): grouped GEMM correctness with SNR gates, use_2dblock/use_dot_scaled coverage, cross-format comparison, and a toy-MoE end-to-end training sanity check (+ loss-curve plotting).
  • Docs: MXFP8_GROUPED_GEMM_PLAN.md design/plan doc; m355 test results; README entry for MXFP8 (linear + grouped GEMM).

Test plan

  • tests/unittest/mxfp8/test_mxfp8_grouped_gemm.py (SNR gates pass)
  • tests/unittest/mxfp8/test_e2e_moe.py toy-MoE training sanity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant