Skip to content

Port/aakbarza/flydsl blockmoe fusion#3810

Open
amirakb89 wants to merge 21 commits into
ROCm:mainfrom
amirakb89:port/aakbarza/flydsl_blockmoe-fusion
Open

Port/aakbarza/flydsl blockmoe fusion#3810
amirakb89 wants to merge 21 commits into
ROCm:mainfrom
amirakb89:port/aakbarza/flydsl_blockmoe-fusion

Conversation

@amirakb89

Copy link
Copy Markdown

Title

[Kernel][Feature] FlyDSL FP8 blockscale MoE (DeepSeek g1u1) with SiLU + GeLU

Summary

Ports the upstream FlyDSL FP8/FP8 + per-1x128 FP32 blockscale MoE GEMM into aiter as a two-stage path wired into the MoE dispatcher, giving a third backend alongside CK and asm for DeepSeek-R1-class workloads on gfx950.

Features

  • FlyDSL fp8 blockscale MoE two-stage path — vendored upstream kernel + aiter adapter matching the dispatcher contract (A/B = fp8 e4m3, C ∈ {bf16, fp16}, per-1x128 fp32 scale, preshuffled B, g1u1).
  • Split-K (k_batch>1) for the fp8/fp8 blockscale path, AOT-precompilable — adds a small-M decode option.
  • SiLU + GeLU activations in stage1 (GeLU is exact-erf, CK-parity formula).
  • bf16 cshuffle epilog — bf16 output now takes the same fast store path as f16.
  • Per-M dispatcher routing with a seeded DSR1 tuner CSV; FlyDSL preferred where competitive, CK/asm fallback elsewhere.
  • Tests + benchmark — unit tests (17 pass, 1 gated perf) and a standalone HIP benchmark with DSR1/Mixtral/Qwen3/GPT-OSS presets.

Performance

On DeepSeek-R1-0528 (gfx950, TP=8) the per-M FlyDSL+fallback config is at parity with CK/asm across the mid band (M≈16–1024) and nets a small, accuracy-neutral E2E serving gain. The tuned asm 1-stage kernel still owns the large-M prefill tail (M≥2048).

Table 1 — End-to-end serving (ATOM, DeepSeek-R1-0528, gfx950, TP=8, conc=128, 1280 prompts) on Mi355 node

FlyDSL column = mean of two back-to-back same-kernel runs (range −1.6% to +8.8% output tok/s; spread is TTFT/queue variance at conc=128, not the kernel). FlyDSL routes only at M=1,2,4,8 (split-K), CK/asm fallback elsewhere.

Metric baseline (CK/asm) FlyDSL split-K (mean) Δ
Output tok/s 4961.69 5139.92 +3.6%
Total tok/s 9923.38 10279.83 +3.6%
Request throughput (req/s) 4.85 5.02 +3.6%
Mean TTFT (ms) 2084.6 1610.9 −22.7%
Mean TPOT (ms) 23.77 23.41 −1.5%
P99 TPOT (ms) 25.61 24.54 −4.2%

Accuracy-neutral: gsm8k 1.00 flexible / 0.98 strict (50-sample); full-set A/B held at 0.955.

Table 2 — Per-M stage1+stage2 latency, best-tuned (DSR1 shape, lower=better) on Mi355 node

model_dim=7168, inter_dim=256, E=257, topk=9, fp8 e4m3 per-1x128, gfx950. All three backends tuned per-M; ratio = FlyDSL / best(CK, asm).

M FlyDSL (μs) CK 2-stage (μs) asm 1tg (μs) best CK/asm FlyDSL ratio
1 12.2 † 35.0 36.7 35.0 (CK) 0.35× †
2 13.6 † 38.0 38.2 38.0 (CK) 0.36× †
4 21.3 † 43.9 39.8 39.8 (asm) 0.54× †
8 34.6 † 68.6 77.1 68.6 (CK) 0.51× †
16 102.9 102.1 113.6 102.1 (CK) 1.01×
32 168.3 182.9 172.1 172.1 (asm) 0.98×
64 219.9 225.1 230.6 225.1 (CK) 0.98×
128 235.0 241.9 247.4 241.9 (CK) 0.97×
256 283.3 282.7 308.3 282.7 (CK) 1.00×
512 306.0 302.0 335.3 302.0 (CK) 1.01×
1024 389.2 398.4 385.2 385.2 (asm) 1.01×
2048 586.6 637.3 467.7 467.7 (asm) 1.25×
4096 957.0 1042.5 683.9 683.9 (asm) 1.40×

† M≤8 are not publishable performance claims — pure-kernel profiler timing mis-counts split-K atomic-accumulate sub-kernels at small M (non-monotonic values); correctness-checked small-M timing/tuning is a follow-up. M=16 uses the cosine-gated e2e value (≈ parity).

Not yet covered (follow-ups)

The adapter raises NotImplementedError for CK features not yet ported: SwiGLU-step (clipped) activation, padded shapes (M/N/K padding), fused bias / multi-D elementwise, compute-optimized pipelines (CK v2–v5), and pipeline scheduler choice. Also open: reliable small-M split-K timing + tuning, and CSV-driven multi-M routing in aiter/fused_moe.py for non-DSR1 shapes.

Testing

  • Unit tests (17 pass, 1 perf test gated behind AITER_RUN_PERF=1)
  • Performance benchmark (op_tests/op_benchmarks/hip/bench_flydsl_moe_blockscale.py)
  • Tested on MI355X (gfx950 / MI355)

@amirakb89 amirakb89 requested review from a team and Copilot June 19, 2026 04:29
@github-actions

Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3810 --add-label <label>

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Ports FlyDSL’s FP8/FP8 + per-1x128 FP32 blockscale 2-stage MoE GEMM into aiter, wiring it into the MoE dispatcher as an additional backend and adding functional/perf tests plus related plumbing (split-K, activations, scale handling).

Changes:

  • Add vendored FlyDSL blockscale MoE stage1/stage2 kernel implementation plus adapter/dispatcher integration (including split-K and SiLU/GeLU).
  • Extend FlyDSL MoE kernel registry/argument plumbing to support FP8 blockscale configs and raw-pointer (fx.Pointer) launch ABI.
  • Add end-to-end FlyDSL blockscale MoE tests and update existing FlyDSL test(s); add an env var to disable FlyDSL HGEMM.

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
op_tests/flydsl_tests/test_silu_and_mul_fq.py Updates test to pass raw-pointer adaptors for kernels that take fx.Pointer args.
op_tests/flydsl_tests/test_flydsl_blockscale_moe.py New end-to-end correctness + split-K scaffolding + gated perf test for FP8 blockscale MoE path.
aiter/tuned_gemm.py Adds AITER_DISABLE_FLYDSL_HGEMM env var gate for FlyDSL HGEMM selection.
aiter/ops/flydsl/moe_kernels.py Adds FP8 blockscale kernel catalog + pointer/memref adaptor split; updates stage1/2 wrappers and args packing.
aiter/ops/flydsl/kernels/silu_and_mul_fq.py Adds explicit FP8 saturation clamp prior to conversion to avoid NaN propagation.
aiter/ops/flydsl/kernels/blockscale_moe_gemm_2stage.py New vendored upstream FlyDSL FP8 blockscale 2-stage MoE kernel with aiter-specific split-K and epilogue additions.
aiter/fused_moe.py Adjusts per-1x128 quant scale transposition logic to include FlyDSL stage1 wrapper routing.
aiter/aot/flydsl/moe.py Updates AOT precompile plumbing for new stage1 out_scale arg slot, but currently has ABI mismatches (see comments).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread aiter/aot/flydsl/moe.py
Comment on lines 40 to 43
from aiter.ops.flydsl.moe_kernels import (
_get_compiled_silu_fused,
_ptr_view_safe,
_as_memref,
_run_compiled,
Comment thread aiter/aot/flydsl/moe.py Outdated
Comment on lines +609 to +615
_as_memref(tmp_out.view(-1, inter_dim * 2)),
_as_memref(out.view(-1).view(torch.uint8)),
_as_memref(out_scale_sorted_flat),
_as_memref(sorted_token_ids),
_as_memref(num_valid_ids),
_as_memref(sorted_token_ids.view(-1)),
_as_memref(torch.empty(0, device=dev, dtype=torch.float32)),
Comment thread aiter/aot/flydsl/moe.py
Comment on lines 556 to 560
_k_in,
_grid_y,
stream=0,
out_scale_sorted=_s1_scale_arg,
)
Comment on lines 856 to 859
blocks,
stream=None,
use_ptr=False,
):

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5cf516b486

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread aiter/ops/flydsl/moe_kernels.py Outdated

_SUFFIX_RE = re.compile(r"(?P<fp4>_fp4)?(?P<fp8>_fp8)?(?:_sbm(?P<sbm>\d+))?$")
_SUFFIX_RE = re.compile(
r"(?P<blk>_blkscale)?(?P<fp4>_fp4)?(?P<fp8>_fp8)?(?:_sbm(?P<sbm>\d+))?$"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve the blockscale suffix when parsing fp8 variants

When a fused-quant blockscale name uses the default variant, e.g. flydsl_moe1_afp8_wfp8_bf16_t..._blkscale_fp8, this regex consumes the whole _blkscale_fp8 tail. get_flydsl_kernel_params() then looks up the base name without _blkscale, but the blockscale registry keys are registered with _blkscale, so the lookup returns None and _flydsl_stage1_wrapper rejects an otherwise valid fp8-output blockscale kernel name. The parser should strip only the quant/sbm suffix while keeping _blkscale as part of the registered base name.

Useful? React with 👍 / 👎.

a_dtype=a_dtype,
b_dtype=b_dtype,
out_dtype=out_dtype,
act=act,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Forward Gelu activations to the blockscale kernel

For Gelu MoE configs that select this new fp8/fp8 FlyDSL blockscale path, the high-level wrapper still maps every non-Swiglu activation to act='silu' before reaching this call, even though the blockscale compiler accepts act='gelu'. That means an ActivationType.Gelu model configured with these FlyDSL kernel names will compile/run the SiLU gate instead and return numerically wrong outputs; preserve Gelu when deriving the act argument before forwarding it here.

Useful? React with 👍 / 👎.

@amirakb89 amirakb89 force-pushed the port/aakbarza/flydsl_blockmoe-fusion branch from 5cf516b to 33afdea Compare June 19, 2026 04:36
Amir Akbarzadeh and others added 20 commits June 19, 2026 15:46
Verbatim snapshot of ROCm/FlyDSL kernels/moe_blockscale_2stage.py
(PRs ROCm#164 / ROCm#252 / ROCm#306). Only change: rewrite two `from kernels.X`
imports to relative `.X` for the aiter package layout. All referenced
helpers already exist in aiter's mfma_preshuffle_pipeline.py and
mfma_epilogues.py.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Replaces the prior 439-line TODO scaffold with a thin adapter that
forwards the wide aiter dispatcher signature (25 kwargs) to the
upstream entry points. Tier-A passthrough (out_dtype, scale_block_k,
waves_per_eu, accumulate); Tier-B no-op accept of DSR1 defaults
(silu, no bias, no pad, no swiglu_limit); Tier-C raises a clear
NotImplementedError for gelu / bias / pads / split-K / persistent /
xcd_swizzle / b_nt.

Dispatcher (moe_kernels.py) keeps the same a_dtype='fp8',b_dtype='fp8'
routing line and registers the canonical FlyDSL kernel-name set for
the FP8 blockscale path so AOT precompile and tuner CSVs can resolve
them.

DSR1 (M=8192) via the dispatcher: FlyDSL 2219us vs CK 2533us (1.14x).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Adds one tuned row (M=8192,N=4096,K=7168, top_k=8, exp=256) so the
AOT precompile harness in aot/flydsl/moe.py can discover the kernel.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
12 adapter tests (signature, validation, Tier-A/B/C error paths),
stage1+stage2 functional smoke at small shapes, and a gated DSR1-scale
perf test that benchmarks the dispatcher path vs CK 2-stage and the
ASM fused fallback.

Local run on MI355: 12/12 adapter tests pass, DSR1 perf gate passes
(FlyDSL 2219us vs CK 2533us, 1.14x).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The prior fixture quantized activations twice (per-token then per-block)
which left the kernel and the FP32 reference with mismatched scales.
Small-shape outputs were Inf/NaN in both paths; isclose(Inf,Inf)
masked the failure as err_ratio=0.0001, and stage2 was downgraded to
informational.

Switch _prepare_data to a single per_group_quant_hip pass: x_bq + the
[nblk_k,tokens] and [tokens,nblk_k] scale layouts the kernel and ref
each need. Replace the informational-only assertions with strict
tolerances (stage1 err_ratio <= 0.02, stage2 <= 0.05, no non-finite).

Local on MI355:
  tiny  (64t/dim=1024/idim=256/E=8/topk=2): s1 err=0.000, s2 err=0.000
  small (256t/dim=2048/idim=512/E=16/topk=4): s1 err=0.000, s2 err=0.000
  DSR1 perf gate unchanged: FlyDSL 2235.9us vs CK 2539.3us (1.14x).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The cshuffle epilog was gated to f16 output only by a hard raise in
compile_moe_blockscale_gemm1. Route the LDS frag type and the per-element
trunc through out_mlir() so bf16 can take the same fast store path that
f16 already uses, and forward frag_elem_type to mfma_epilog.

Required for DSR1 (bf16 out) to actually exercise the cshuffle path; the
prior code would have raised ValueError as soon as the adapter forwarded
out_dtype="bf16" with cshuffle enabled (the default).

Adds test_dispatcher_routes_fp8_fp8_bf16_compiles to lock in the new
path; every other test in the file uses out_dtype='f16' and would not
have caught a regression in the bf16 store.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Plumb act through compile_moe_blockscale_gemm1 and select silu vs gelu
at compile time in the stage1 epilogue. GeLU uses the exact erf formula
(0.5*x*(1+erf(x/sqrt(2)))) for CK parity, implemented inline via the
Abramowitz & Stegun 7.1.26 polynomial (rocdl.exp2 + rocdl.rcp fast
path + math.absf/copysign).

Adapter Tier-C check relaxed to accept act in {silu, gelu}. Default is
unchanged (silu), so existing callers see no behavior change.

Test ref _torch_stage1_ref takes act and uses F.gelu(approximate='none')
to match the kernel; the e2e correctness test now runs both activations
at tiny/small shapes within the same tolerances.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Standalone bench for the FP8/FP8 + per-1x128 FP32 blockscale MoE path,
modeled on op_tests/op_benchmarks/hip/ conventions. Reuses the launch
helpers from the test file (imported via importlib since flydsl_tests
has no __init__.py) so timing matches the gated pytest within ~2%.

Defaults to the DSR1 shape (M sweep over [1, 8, 64, 256, 1024, 4096,
8192]) but exposes presets for dsr1-tp4/tp1, mixtral-8x7b/22b,
qwen3-235b, and gpt-oss-120b. Optional --compare ck,asm columns reuse
ck_moe_stage{1,2}_fwd and aiter.fmoe_fp8_blockscale_g1u1; --compare ""
runs FlyDSL-only. CSV output to logs/flydsl_moe_blockscale_<act>.csv
for offline pivoting.

For non-DSR1 presets, the bench logs (does not assert) when the
adapter raises NotImplementedError so --preset all never aborts
halfway, and notes in the header that only DSR1 has a tuned CSV row
today.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Run black==26.3.0 and ruff==0.15.7 over the new blockscale adapter,
test, and benchmark files to satisfy AITER's pre-commit hooks
(CONTRIBUTE.md).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Parallelize the K-reduction across k_batch workgroups so the M*N CTA grid
no longer under-fills the GPU at low M (DSR1 decode). When k_batch > 1
the kernel expands its z-grid, each CTA processes K/k_batch of the K
dimension, and the epilogue atomic-adds unactivated gate/up partials
into a zeroed tmp_out buffer (interleave gui_layout); the existing
silu_and_mul kernel fuses activation + reduction after the GEMM.

Adapter lifts the Tier-C k_batch != 1 gate and validates the K-slice is
divisible by scale_block_k=128, by tile_k, and that total tiles per
split is even >= 2 (the K-loop unrolls in pairs of tile_k tiles).

Upstream kernel:
- launcher grid (gx, gy, k_batch); kernel reads bz, k_offset_base
- K main-loop and tail bounds shifted by k_offset_base; scale-tensor
  base inherits the shift via existing k_base/scale_block_k math
- out_rsrc byte count doubled (2*inter_dim) for split-K tmp_out
- CShuffle epilogue forced off for split-K
- new _atomic_add_scalar_via_pk helper uses raw_ptr_buffer_atomic_fadd
  on a half2/bf16x2 pack (AMD has no scalar fp16/bf16 atomic) lowering
  to buffer_atomic_pk_add_f16/bf16 on gfx950
- doweight_stage1 + k_batch > 1 raises NotImplementedError until the
  dispatcher forwards sorted_weights to silu_and_mul_fq

Tests:
- test_splitk_valid_config_compiles[1,2,4]: compile-path guards
- test_splitk_invalid_kslice_raises: validator coverage
- test_blockscale_splitk_stage1_e2e: M in {8,64,256}, k_batch in {2,4},
  err_ratio = 0.0000 vs FP32 ref

Measured speedup at DSR1 decode shape (M=8, dim=7168, inter=256):
  k_batch=1: 52.0us  k_batch=2: 37.6us (1.38x)  k_batch=4: 23.3us (2.23x)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds pick_k_batch_for_blockscale_stage1() heuristic and wires it through
flydsl_moe_stage1 via k_batch="auto". The selector estimates grid
occupancy from (M*topk, 2*inter_dim) tiles and picks k_batch to
multiply the CTA count when the base grid is too small to saturate
the GPU. Decode-1 (blocks_m == 1) is special-cased to k_batch=4 to
match measured DSR1 TP=8 optima (1.5-1.65x vs CK at M=1).

Selector is k-slice-aware: validates model_dim % k_batch, scale_block_k
and tile_k divisibility, and falls back to the next-smallest valid
k_batch (down to 1) if the desired choice is invalid for the shape.

Tests pin the measured-best k_batch for 12 DSR1 TP=8 stage1 shapes
(M in {1,8,16,32,64,128,256,1024,8192} x inter_dim in {256,512}) plus
a fall-back case for an invalid-kslice shape.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ow blockscale

Phase B's per-32 e8m0 sorted-tile scale layout matched silu_and_mul_fq but
NOT the fp8/fp8 stage2 GEMM, which expects a2_scale as f32 per-row
[n_blocks_k, tokens*topk] (the format per_group_quant_hip(transpose_scale=True)
produces). End-to-end piping fused_stage1 → flydsl_moe_stage2 produced
inf/nan because stage2 reinterpreted e8m0 bytes as f32 exponents.

Switch the fused fp8 epilog to emit the per-row f32 format directly, making
the fused chain a drop-in replacement for the unfused
(bf16 → per_group_quant_hip → stage2) chain.

Kernel (_blockscale_moe_gemm_2stage_upstream.py):
  * Restricted to tile_n=128 (one per-128 block per CTA in N).
  * Allocate tile_m × 4-wave × f32 LDS scratch for cross-wave amax.
  * Two-pass epilog: (1) per-wave intra-wave shuffle_xor amax + LDS publish
    by lane_mod_16==0; (2) barrier; all threads read 4 slots → per-128 amax;
    quant_scale = 240/amax, clamp ±240, cvt_pk_fp8 → token-major fp8 bytes;
    designated lane (wave_mod_4==0, lane_mod_16==0) writes f32 amax/240 to
    per-row scale buffer at by * tokens*topk + t*topk + s.
  * Drop sorted-tile address math + e8m0 headroom=8 bias (now amax/240, CK
    pattern matching per_group_quant_hip).

Wrapper (blockscale_moe_gemm_2stage.py):
  * Validate tile_n==128 and inter_dim%128==0 when out_dtype='fp8'.

Dispatcher (moe_kernels.py):
  * For fused-fp8 GEMM path, allocate f32 buffer sized n_blocks_k * tokens*topk
    (no padding) and return shape (n_blocks_k, tokens*topk). Legacy uint8
    sorted-tile path (gui_sk / splitk_fp4 post-quant kernels) unchanged.

Tests:
  * ab_fused_fp8_e2e.py — production-equivalent stage1+stage2 A/B
    (fused vs unfused vs torch ref).
  * ab_fused_fp8_correctness.py — stage1-only fused vs unfused fp8 emit
    with byte-match check.
  * ab_splitk_correctness.py — kb=2 vs kb=4 split-K self-consistency vs
    FP32 partials ref.
  * bench_before_after_fusion.py — pre/post fusion + CK comparison.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
On gfx950 the fp8 e4m3fn max is 448, not 240 (the gfx942 e4m3fnuz value).
The fused stage1 epilog was clamping/scaling to ±240, which left half the
fp8 dynamic range unused and caused ~42% e2e error vs the unfused chain
(scales 1.87x = 448/240 too large, fp8 values one exponent step too small).

After this fix, e2e A/B vs the unfused chain agrees within bf16 noise
(fu/un_err% ≈ 2-5% across M ∈ {1,8,16,32,64,256}).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
* test_flydsl_blockscale_moe.py: pass a placeholder uint8 buffer for the
  new arg_out_scale parameter that compile_moe_blockscale_gemm1 now takes
  unconditionally (consumed only on the fp8 epilog path).
* bench_vs_ck_e2e.py: FlyDSL fused fp8 (stage1+stage2) vs CK two-stage,
  reports per-stage and total us across M.
* gemm_kernels.py: flyc.from_c_void_p was removed upstream; switch to
  flyc.from_dlpack for both real and fake tensors.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
cvt_pk_fp8_f32 does not saturate; out-of-range inputs become fp8 NaN
(0x80) and propagate through stage2. Add ±240 v_med3 clamp before the
intrinsic (matches the CK pattern in vec_convert.h:46-59).

Needed for the split-K (k_batch>1) MoE stage1 path, which still uses
the bf16 GEMM + silu_and_mul_fq chain (Phase B-v2 fuses kb=1 only).

Note: 240 is the gfx942 e4m3fnuz max; on gfx950 (e4m3fn, max=448) this
under-utilizes dynamic range. Same gfx950 fix as 5decccc should be
applied here in a follow-up; the clamp itself is correctness-critical
on both arches.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Upstream FlyDSL blockscale stage1/stage2 kernels expect activation scales
in the transposed [nblk_k, tokens] / [nblk_k_w2, tokens*topk] layout — the
same layout asm_stage1 needs. fused_moe_2stages only flipped
transpose_scale=True for asm_stage1, so the FlyDSL path silently received
the default [tokens, nblk_k] layout. No MLIR/runtime error — just garbage
numerics.

Production DSR1 TP=8 gsm8k LIMIT=50 with this branch:
  unfused bf16:  0.02 → 0.98
  fused fp8:     0.62 → 0.96

Extend both gates (a1_scale at ~1635, a2_scale fallthrough at ~1760) to
also match _flydsl_stage1_wrapper. The fused-fp8 path is unaffected by the
a2_scale gate because stage1 writes its own scale buffer in the correct
layout (tuple-return branch at ~1701).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Allow disabling FlyDSL for regular A16W16 hgemm dispatch while keeping it
enabled for blockscale MoE. Useful when isolating FlyDSL's MoE behavior in
end-to-end runs without changing tuning CSVs, and required when the FlyDSL
wheel's splitk_hgemm path is incompatible with the current memref-style
buffer descriptors used elsewhere.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Squash of post-review fixes: pass raw pointers (from_c_void_p) to all
fx.Pointer/ptrtoint kernels — _ptr_view_safe in gemm_kernels (fixes the
gfx950 preshuffle a8w8 crash) and the bf16/int4 + fp4 MoE arg builders;
correct the stage1 out_scale_sorted slot and use a null FakeTensor ptr
for AOT trace; arch-correct FP8 max (240 gfx942 / 448 gfx950) in
silu_and_mul_fq; wfp4 scale layout; merge vendored upstream blockscale
kernel; drop diagnostic A/B scripts and fix CI lint.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
@amirakb89 amirakb89 force-pushed the port/aakbarza/flydsl_blockmoe-fusion branch from 33afdea to 0b840ae Compare June 19, 2026 16:27
@amirakb89 amirakb89 force-pushed the port/aakbarza/flydsl_blockmoe-fusion branch from 0b840ae to 0b8bead Compare June 19, 2026 22:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants