Add MXFP8 grouped GEMM (E4M3) for routed-expert MoE training#31
Draft
ysa2215 wants to merge 12 commits into
Draft
Add MXFP8 grouped GEMM (E4M3) for routed-expert MoE training#31ysa2215 wants to merge 12 commits into
ysa2215 wants to merge 12 commits into
Conversation
… design and feature doc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description:
Why
GPT-OSS-style MoE training needs a grouped GEMM to run on MXFP8 — without it the MXFP8 dispatch path errored out on routed-expert layers, restricting MXFP8 to Linear targets. This is the minimum viable kernel to get MoE training running.
V1 uses E4M3 for all operands across fwd/dgrad/wgrad. A single format keeps the kernel free of dtype dispatch and the autograd free of a separate grad_output quantization path — simplest to implement and validate. The known cost is gradient
numerical robustness (E4M3's narrow dynamic range can underflow the long tail / overflow spikes of grad_output); the toy-MoE training test validates whether V1 is adequate in range.
The mixed-format recipe (grad_output in E5M2 — the TE / Meta / TorchAO industry consensus) is reserved for V2. To make that upgrade kernel-free, the API already exposes independent
fwd_format/bwd_grad_formatparams (both defaulting toE4M3) and the kernel carries a pre-written E5M2 branch; V2 just flips
bwd_grad_formatto E5M2. See MXFP8_GROUPED_GEMM_PLAN.md` §0 for the rationale.What changed
Adds an MXFP8 contiguous grouped GEMM kernel (forward + dgrad + wgrad) and wires it into the MXFP8 dispatch path so routed-expert MoE layers can run on MXFP8.
alto/kernels/mxfp8/mxfp8_grouped_gemm/cg_forward.py— Triton persistent forward grouped GEMM with super-grouping schedule and contiguous index routing, adapted from the MXFP4 scaffold with all K-packing removed (MXFP8 is one element per byte).cg_backward.py— full backward (dgrad + wgrad) wrapped in anMXFP8GroupedGEMM(autograd.Function).functional.py— user-facingmxfp8_grouped_gemm(...)plus the dispatch-layer entry_quantize_then_mxfp8_scaled_grouped_mm(...), which mirrors the MXFP4_quantize_then_mxfp_scaled_grouped_mmcontract (1-D cumulativeoffs,padded activation buffers, weights kept in
[E, K, N]dispatch convention toavoid a transpose copy).
autotune.py,__init__.py.alto/kernels/dispatch/tensor.py): the MXFP8 weight wrapper's_grouped_mmpath previously raisedNotImplementedError; it now routes 2d-activation × 3d-weight + offsets calls to the new kernel, with guards limitingV1 to
mxfp8_e4m3and rejecting Hadamard/DGE/bias.tests/unittest/mxfp8/): grouped GEMM correctness with SNR gates,use_2dblock/use_dot_scaledcoverage, cross-format comparison, and a toy-MoE end-to-end training sanity check (+ loss-curve plotting).MXFP8_GROUPED_GEMM_PLAN.mddesign/plan doc; m355 test results; README entry for MXFP8 (linear + grouped GEMM).Test plan
tests/unittest/mxfp8/test_mxfp8_grouped_gemm.py(SNR gates pass)tests/unittest/mxfp8/test_e2e_moe.pytoy-MoE training sanity