[FlyDSL] Fix MoE 2-stage bf16 weight buffer overflow for weights >4GiB #3812

Open
yueliu14 wants to merge 1 commit into ROCm:main from yueliu14:flydsl-moe-4gib-fix

yueliu14 commented Jun 19, 2026

Summary

The FlyDSL 2-stage MoE GEMM (aiter/ops/flydsl/kernels/moe_gemm_2stage.py) silently
produces wrong results when the expert-weight tensor exceeds 4 GiB. The kernel
addresses the whole [E, 2*inter, model_dim] weight through a single AMD buffer
resource whose buffer_load offset is 32-bit (bytes); once the tensor crosses
4 GiB the per-expert offset wraps and the kernel reads the wrong weights. No error is
raised — the output is just corrupt.

This is hit by MiniMax-M3 bf16 MoE (E=128, inter=3072, model_dim=6144 -> w13 = 9 GiB),
where the MoE output has cos_dist ≈ 0.5 versus an fp32 reference.

The fix rebases the weight buffer to a per-expert 64-bit base address for the
f16/bf16 path, keeping the in-kernel 32-bit offset within a single expert. Other
dtypes are left byte-for-byte unchanged.

Requires flydsl >= 0.1.8 — the fix uses the buffer_ops.create_buffer_resource_from_addr
/ extract_base_index primitives, which are only available from flydsl 0.1.8 onward.
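
A caller that wants to fail early on an older install could compare version strings before compiling. The sketch below is only illustrative: `parse_version` and `meets_minimum` are hypothetical helpers, not part of flydsl or aiter, and how the installed version string is obtained is left to the caller.

```python
def parse_version(text):
    """Parse the leading numeric 'X.Y.Z' components of a version string.

    Hypothetical helper: stops at the first component without leading digits,
    so a suffix such as 'rc1' is ignored rather than ordered.
    """
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum="0.1.8"):
    """True if `installed` is at least `minimum` (numeric components only)."""
    return parse_version(installed) >= parse_version(minimum)


if __name__ == "__main__":
    for candidate in ("0.1.7", "0.1.8", "0.1.10"):
        print(candidate, meets_minimum(candidate))
```

Comparing integer tuples rather than raw strings matters here because a plain string comparison would rank "0.1.10" below "0.1.8".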

Root cause

In both compile_moe_gemm1 and compile_moe_gemm2 the weight resource is created
once over the entire tensor:

w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False)
...
expert_off_idx = expert_idx * (2 * inter_dim)   # stage1 (rows); stage2 uses model_dim
row_gate = expert_off_idx + col_g               # N-row index into the whole tensor

row_gate is turned into a per-tile element/byte offset and passed to buffer_load,
which casts the offset to i32 bytes. The largest reachable byte offset is about
E * 2*inter * model_dim * elem_bytes, i.e. the full weight size. When that exceeds
2^32 bytes (4 GiB), the offsets for high-numbered experts overflow and wrap.
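
The wrap can be reproduced with plain integer arithmetic. The sketch below is a simplified model, not the kernel's code: it assumes a row-major layout (the real weights are preshuffled) and assumes the 32-bit offset behaves as an unsigned value taken modulo 2^32. `row_byte_offset` and `wrap_u32` are names made up for this illustration.

```python
MASK32 = (1 << 32) - 1


def row_byte_offset(expert_idx, col_g, inter, model_dim, elem_bytes):
    """Byte offset of row `col_g` of expert `expert_idx` from the tensor base.

    Mirrors the stage1 indexing quoted above, under a row-major assumption.
    """
    expert_off_idx = expert_idx * (2 * inter)  # first row of this expert
    row_gate = expert_off_idx + col_g          # row index into the whole tensor
    return row_gate * model_dim * elem_bytes


def wrap_u32(offset):
    """What survives when a wider byte offset is truncated to 32 bits (assumed unsigned)."""
    return offset & MASK32


# MiniMax-M3 shape from this PR: inter=3072, model_dim=6144, bf16 (2 bytes).
inter, model_dim, elem_bytes = 3072, 6144, 2
true_off = row_byte_offset(100, 0, inter, model_dim, elem_bytes)
wrapped = wrap_u32(true_off)

if __name__ == "__main__":
    print("true offset   :", true_off)
    print("wrapped offset:", wrapped)
    print("lost bytes    :", true_off - wrapped)
```

For an expert near the start of the tensor the two values agree, which is why small shapes never expose the problem.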

Exact boundary (evidence)

With E=128, model_dim=6144, bf16, total w13 = E*2*inter*model_dim*2 bytes:

| inter | w13 size | result (before fix) |
| --- | --- | --- |
| 1280 | 3.75 GiB | correct (cos_dist ~6e-6) |
| 1408 | 4.12 GiB | broken (cos_dist ~3e-2) |
| 1536 | 4.50 GiB | broken (cos_dist ~0.11) |
| 3072 (M3) | 9.00 GiB | broken (cos_dist ~0.5) |

The break falls between 3.75 GiB and 4.12 GiB, which matches the 4 GiB / 2^32-byte
limit. It was never caught because the existing FlyDSL MoE tests use inter=384
(~1.1 GiB), well under the limit.
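
The sizes in the table follow directly from the size formula and can be recomputed as a quick sanity check. This is a small sketch; `w13_bytes` and `max_safe_inter` are hypothetical helper names, not functions from aiter.

```python
GIB = 1 << 30
LIMIT = 1 << 32  # bytes addressable through a 32-bit byte offset


def w13_bytes(num_experts, inter, model_dim, elem_bytes):
    """Total size of the [E, 2*inter, model_dim] weight in bytes."""
    return num_experts * 2 * inter * model_dim * elem_bytes


def max_safe_inter(num_experts, model_dim, elem_bytes):
    """Largest inter for which the whole tensor still fits within 2^32 bytes."""
    bytes_per_inter_unit = num_experts * 2 * model_dim * elem_bytes
    return LIMIT // bytes_per_inter_unit


if __name__ == "__main__":
    for inter in (384, 1280, 1408, 1536, 3072):
        size = w13_bytes(128, inter, 6144, 2)
        print(f"inter={inter:5d}  w13={size / GIB:6.3f} GiB  over_limit={size > LIMIT}")
    print("largest inter under the limit:", max_safe_inter(128, 6144, 2))
```

Note that a sweep with steps of 128 cannot resolve the boundary more finely than the two neighbouring test points, so the computed limit is the arithmetic bound rather than a measured one.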

Fix

For the f16/bf16 path, rebase the weight buffer resource to this expert's 64-bit
base address, so the in-kernel 32-bit offset only ever spans one expert
(2*inter*model_dim*elem_bytes, e.g. 75 MB at M3 ≪ 4 GiB):

if const_expr(is_f16_or_bf16):
    _w_base_addr   = buffer_ops.extract_base_index(arg_w)
    _w_expert_byte = expert_off_idx * k_in * arith.index(int(w_elem_bytes))
    w_rsrc = buffer_ops.create_buffer_resource_from_addr(
        arith.index_cast(T.i64, _w_base_addr + _w_expert_byte))
    w_row_off = arith.index(0)
else:
    w_row_off = expert_off_idx
...
row_gate = w_row_off + col_g     # stage1   (row_w = w_row_off + col_g in stage2)

Why it is exact: expert_off_idx is a multiple of 16 (the preshuffle row group),
so base += expert_off_idx * K * elem_bytes combined with row = col_g produces the
same element addresses as the original row = expert_off_idx + col_g against a
whole-tensor base — it is purely a 64-bit re-association of the same offset, with no
change to the tile/preshuffle math. k_in is the contraction dim of each stage
(model_dim for stage1, inter for stage2), so the formula
expert_off_idx * k_in * w_elem_bytes is this expert's byte offset from the tensor
base in both.
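
The re-association argument can be checked numerically. The sketch below assumes a row-major layout and ignores the preshuffle, so it only illustrates the address arithmetic; the base address and all function names are hypothetical.

```python
LIMIT = 1 << 32


def whole_tensor_addr(base, expert_off_idx, col_g, k_in, elem_bytes):
    """Original scheme: one base, row index spans every expert."""
    row = expert_off_idx + col_g
    return base + row * k_in * elem_bytes


def per_expert_addr(base, expert_off_idx, col_g, k_in, elem_bytes):
    """Rebased scheme: 64-bit per-expert base plus an offset inside the expert."""
    expert_base = base + expert_off_idx * k_in * elem_bytes
    in_expert_off = col_g * k_in * elem_bytes
    return expert_base, in_expert_off


def check_equivalence(num_experts, rows_per_expert, k_in, elem_bytes, base):
    """True if both schemes hit the same address and the in-expert offset fits in 32 bits."""
    sample_rows = (0, 16, rows_per_expert // 2, rows_per_expert - 1)
    for expert_idx in range(num_experts):
        expert_off_idx = expert_idx * rows_per_expert
        for col_g in sample_rows:
            expected = whole_tensor_addr(base, expert_off_idx, col_g, k_in, elem_bytes)
            expert_base, off = per_expert_addr(base, expert_off_idx, col_g, k_in, elem_bytes)
            if expert_base + off != expected or off >= LIMIT:
                return False
    return True


if __name__ == "__main__":
    base = 0x7F0000000000  # made-up device address
    # stage1: rows = 2*inter, k_in = model_dim; stage2: rows = model_dim, k_in = inter
    print("stage1:", check_equivalence(128, 2 * 3072, 6144, 2, base))
    print("stage2:", check_equivalence(128, 6144, 3072, 2, base))
```

The only quantity that changes width is the per-expert base, which is carried as a 64-bit address; the offset added inside the kernel stays below one expert's size.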

create_buffer_resource_from_addr / extract_base_index are existing buffer_ops
helpers (already used elsewhere, e.g. for arg_out).

Scope & safety

  • The change is behind a const_expr(is_f16_or_bf16) compile-time branch, so:
    • fp8 / int8 / int4 / int4_bf16 (which also route through compile_moe_gemm{1,2})
      take the original else path and generate byte-for-byte identical kernels.
    • fp4 uses compile_mixed_moe_gemm{1,2} and is not touched at all.
  • Verified that bf16, fp16, fp8, int8, int4, int4_bf16 all still compile through the
    edited functions.
  • No API/signature changes; no new dependencies.

Tests

Adds op_tests/test_flydsl_moe_large_inter.py: a self-contained (aiter + torch only)
kernel-level test that sweeps inter across the 4 GiB boundary for both stage1 modes
(fused k_batch=1, split-K k_batch>=2), at decode and prefill token counts, and
asserts cos_dist < 0.01 vs an fp32 reference.
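
The PR does not spell out how cos_dist is computed. A common definition, assumed here, is one minus the cosine similarity of the flattened outputs; the sketch below uses plain Python lists in place of tensors and is not the test file's actual code.

```python
import math


def cos_dist(a, b):
    """1 - cosine similarity of two equal-length vectors (assumed definition)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def passes(output, reference, threshold=0.01):
    """Apply the threshold quoted in the test description."""
    return cos_dist(output, reference) < threshold


if __name__ == "__main__":
    reference = [1.0, 2.0, 3.0, 4.0]
    close = [1.0, 2.0, 3.0, 4.001]
    unrelated = [4.0, -3.0, 2.0, -1.0]
    print("close    :", cos_dist(close, reference), passes(close, reference))
    print("unrelated:", cos_dist(unrelated, reference), passes(unrelated, reference))
```

Under this definition a value near 0 means the outputs point in the same direction, while values around 0.5 indicate output that is largely uncorrelated with the reference.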

Before vs after the fix (E=128, model_dim=6144, bf16, MI300X; cos_dist vs fp32 ref):

| case | inter | w13 | cos_dist before | cos_dist after |
| --- | --- | --- | --- | --- |
| decode fused | 384 | 1.12 GiB | 6.1e-06 ✅ | 6.1e-06 ✅ |
| decode fused | 768 | 2.25 GiB | 6.3e-06 ✅ | 6.3e-06 ✅ |
| decode fused | 1280 | 3.75 GiB | 6.2e-06 ✅ | 6.1e-06 ✅ |
| decode fused | 1408 | 4.12 GiB | 1.3e-01 ❌ | 6.2e-06 ✅ |
| decode fused | 1536 | 4.50 GiB | 4.5e-02 ❌ | 6.3e-06 ✅ |
| decode fused | 2048 | 6.00 GiB | 3.1e-01 ❌ | 6.4e-06 ✅ |
| decode fused | 3072 | 9.00 GiB | 3.1e-01 ❌ | 6.4e-06 ✅ |
| decode split-K | 384 | 1.12 GiB | 2.1e-05 ✅ | 2.2e-05 ✅ |
| decode split-K | 3072 | 9.00 GiB | 3.1e-01 ❌ | 2.0e-05 ✅ |
| prefill fused | 1408 | 4.12 GiB | 3.0e-02 ❌ | 6.2e-06 ✅ |
| prefill fused | 3072 | 9.00 GiB | 5.5e-01 ❌ | 6.3e-06 ✅ |
| prefill split-K | 3072 | 9.00 GiB | 5.5e-01 ❌ | 1.3e-05 ✅ |

Takeaways:

  • Boundary is at 4 GiB. inter=1280 (3.75 GiB) passes before and after;
    inter=1408 (4.12 GiB) is the first to break before the fix, so the transition
    brackets the 4 GiB / 2^32-byte line.
  • No regression. Every < 4 GiB shape (inter <= 1280) matches before/after to within
    the last reported digit.
  • All > 4 GiB shapes fixed. Every inter >= 1408 case goes from broken
    (cos_dist ~0.03 .. 0.55) to correct (cos_dist ~1e-5), for both fused and split-K
    stage1 modes and at decode and prefill token counts.
  • Latency unchanged. Same math, just 64-bit-addressed — e.g. decode inter=3072
    ~1.5 ms, prefill inter=3072 ~5.4 ms per call, before and after.

pytest op_tests/test_flydsl_moe_large_inter.py
# or: python op_tests/test_flydsl_moe_large_inter.py

Notes / follow-ups

moe_gemm_2stage addresses the entire [E, 2*inter, model_dim] weight through a
single buffer resource, and AMD buffer_load uses a 32-bit byte offset. Once the
weight tensor exceeds 4 GiB the per-expert offset wraps and the kernel reads the
wrong weights, silently corrupting the result. Hit by MiniMax-M3 bf16 MoE
(E=128, inter=3072, model_dim=6144 -> w13 = 9 GiB): cos_dist vs reference ~0.5.
Missed previously because existing tests use inter=384 (well under 4 GiB).

Fix: for the f16/bf16 path, rebase the weight buffer resource to a per-expert
64-bit base address (create_buffer_resource_from_addr + extract_base_index) in
both stage1 and stage2, so the in-kernel 32-bit offset only spans one expert
(<4 GiB). Exact because expert_off_idx is a multiple of 16 (the preshuffle row
group). fp8/fp4/int4 (packed dtypes) keep the original whole-tensor path
byte-for-byte unchanged -- they need their own byte-offset validation.

Adds op_tests/test_flydsl_moe_large_inter.py: sweeps inter across the 4 GiB
boundary (384..3072) for fused (k_batch=1) and split-K (k_batch>=2) paths.
All pass post-fix (cos ~6e-6..2e-5); inter>=1408 fails pre-fix (cos ~0.03..0.55).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: yueliu14 <yue.liu4@amd.com>
@github-actions

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3812 --add-label <label>
