[FlyDSL] Fix MoE 2-stage bf16 weight buffer overflow for weights >4GiB#3812
Open
yueliu14 wants to merge 1 commit into
Open
[FlyDSL] Fix MoE 2-stage bf16 weight buffer overflow for weights >4GiB#3812yueliu14 wants to merge 1 commit into
yueliu14 wants to merge 1 commit into
Conversation
moe_gemm_2stage addresses the entire [E, 2*inter, model_dim] weight through a single buffer resource, and AMD buffer_load uses a 32-bit byte offset. Once the weight tensor exceeds 4 GiB the per-expert offset wraps and the kernel reads the wrong weights, silently corrupting the result. Hit by MiniMax-M3 bf16 MoE (E=128, inter=3072, model_dim=6144 -> w13 = 9 GiB): cos_dist vs reference ~0.5. Missed previously because existing tests use inter=384 (well under 4 GiB). Fix: for the f16/bf16 path, rebase the weight buffer resource to a per-expert 64-bit base address (create_buffer_resource_from_addr + extract_base_index) in both stage1 and stage2, so the in-kernel 32-bit offset only spans one expert (<4 GiB). Exact because expert_off_idx is a multiple of 16 (the preshuffle row group). fp8/fp4/int4 (packed dtypes) keep the original whole-tensor path byte-for-byte unchanged -- they need their own byte-offset validation. Adds op_tests/test_flydsl_moe_large_inter.py: sweeps inter across the 4 GiB boundary (384..3072) for fused (k_batch=1) and split-K (k_batch>=2) paths. All pass post-fix (cos ~6e-6..2e-5); inter>=1408 fails pre-fix (~0.5). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: yueliu14 <yue.liu4@amd.com>
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The FlyDSL 2-stage MoE GEMM (
aiter/ops/flydsl/kernels/moe_gemm_2stage.py) silentlyproduces wrong results when the expert-weight tensor exceeds 4 GiB. The kernel
addresses the whole
[E, 2*inter, model_dim]weight through a single AMD bufferresource whose
buffer_loadoffset is 32-bit (bytes); once the tensor crosses4 GiB the per-expert offset wraps and the kernel reads the wrong weights. No error is
raised — the output is just corrupt.
This is hit by MiniMax-M3 bf16 MoE (
E=128, inter=3072, model_dim=6144→w13 = 9 GiB), where the MoE output hascos_dist ≈ 0.5versus an fp32 reference.The fix rebases the weight buffer to a per-expert 64-bit base address for the
f16/bf16path, keeping the in-kernel 32-bit offset within a single expert. Otherdtypes are left byte-for-byte unchanged.
Root cause
In both
compile_moe_gemm1andcompile_moe_gemm2the weight resource is createdonce over the entire tensor:
row_gateis turned into a per-tile element/byte offset and passed tobuffer_load,which casts the offset to i32 bytes. The reachable byte offset is
~E * 2*inter * model_dim * elem_bytes= the full weight size. When that is> 2^32(4 GiB), the high-expert offsets overflow and wrap.
Exact boundary (evidence)
With
E=128, model_dim=6144, bf16, totalw13 = E*2*inter*model_dim*2bytes:cos_dist ~6e-6)cos_dist ~3e-2)cos_dist ~0.11)cos_dist ~0.5)The break is exactly at the 4 GiB / 2^32-byte line. It was never caught because the
existing FlyDSL MoE tests use
inter=384(~1.1 GiB), well under the limit.Fix
For the
f16/bf16path, rebase the weight buffer resource to this expert's 64-bitbase address, so the in-kernel 32-bit offset only ever spans one expert
(
2*inter*model_dim*elem_bytes, e.g. 75 MB at M3 ≪ 4 GiB):Why it is exact:
expert_off_idxis a multiple of 16 (the preshuffle row group),so
base += expert_off_idx * K * elem_bytescombined withrow = col_gproduces thesame element addresses as the original
row = expert_off_idx + col_gagainst awhole-tensor base — it is purely a 64-bit re-association of the same offset, with no
change to the tile/preshuffle math.
k_inis the contraction dim of each stage(
model_dimfor stage1,interfor stage2), so the formulaexpert_off_idx * k_in * w_elem_bytesis the per-expert byte stride in both.create_buffer_resource_from_addr/extract_base_indexare existingbuffer_opshelpers (already used elsewhere, e.g. for
arg_out).Scope & safety
const_expr(is_f16_or_bf16)compile-time branch, so:fp8 / int8 / int4 / int4_bf16(which also route throughcompile_moe_gemm{1,2})take the original
elsepath and generate byte-for-byte identical kernels.fp4usescompile_mixed_moe_gemm{1,2}and is not touched at all.bf16, fp16, fp8, int8, int4, int4_bf16all still compile through theedited functions.
Tests
Adds
op_tests/test_flydsl_moe_large_inter.py: a self-contained (aiter + torch only)kernel-level test that sweeps
interacross the 4 GiB boundary for both stage1 modes(fused
k_batch=1, split-Kk_batch>=2), at decode and prefill token counts, andasserts
cos_dist < 0.01vs an fp32 reference.Before vs after the fix (
E=128, model_dim=6144, bf16, MI300X;cos_distvs fp32 ref):Takeaways:
inter=1280(3.75 GiB) passes before and after;inter=1408(4.12 GiB) is the first to break before the fix — the transition landsprecisely on the 4 GiB / 2^32-byte line.
< 4 GiBshape (inter <= 1280) is identical before/after.> 4 GiBshapes fixed. Everyinter >= 1408case goes from broken(
cos_dist ~0.03 .. 0.55) to correct (cos_dist ~1e-5), for both fused and split-Kstage1 modes and at decode and prefill token counts.
inter=3072~1.5 ms, prefill
inter=3072~5.4 ms per call, before and after.Notes / follow-ups
fp8/fp4weights are packed, so the same overflow can affect them at large shapes;fixing those needs their own byte-offset validation and is intentionally out of
scope here (left on the original path).
(vLLM PR [ROCm][Perf] Optional FlyDSL BF16 MoE for the MXFP8-emulation path on MiniMax-M3 vllm-project/vllm#46123), which depends on this kernel fix.