WIP: ep_dispatch_combine: full DSV4 decode single-layer chain #772
zhangqi-chen wants to merge 5 commits into
Conversation
…8, T=16) Bump dispatch/combine/local_expert kernel constants and main.py to mirror the moe_expert decode config so the example can later swap in the real expert. x_norm uses (d % 16) to stay within BF16's exact-integer range. Unify BF16 casts on round-to-nearest-even (TCVT CAST_RINT in the kernels matches torch's .to(bfloat16) in the golden) so routed_y stays bit-exact.
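As a sanity check on that exact-integer claim: BF16 keeps only the top 16 bits of an fp32 value (7 stored mantissa bits), so integers up to 256 survive the cast exactly, which covers every value of (d % 16). A hypothetical host-side check, not part of the example:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// True iff `f` is exactly representable in BF16, i.e. the value survives the
// fp32 -> bf16 cast unchanged (the low 16 bits of the fp32 pattern are zero).
static bool bf16_exact(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return (bits & 0xFFFFu) == 0u;
}

// Every value produced by (d % 16) is a small integer, hence exact in BF16.
static bool mod16_values_bf16_exact() {
    for (int d = 0; d < 1024; ++d) {
        if (!bf16_exact(static_cast<float>(d % 16))) return false;
    }
    return true;
}
```

257 is the first positive integer BF16 cannot represent, which is why keeping x_norm's entries tiny sidesteps rounding entirely.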
Swap the elementwise local_expert kernel for the production DeepSeek-V4 decode MoE block: 17 PyPTO-generated incore kernels (4 AIC matmuls + 13 AIV) copied from the moe_expert_test JIT build, wired by the moe_expert orchestration transplanted into ep_dispatch_combine_orch.cpp (task ids +1, recv_expert_count read from the host-known tensor(4) instead of a kernel output). dispatch stays func_id 0, moe_expert is 1..17, combine is 18. main.py: x_norm now small random BF16 (keeps recv_x bit-exact through dispatch while keeping MoE matmul outputs in a sane magnitude range); generates the six INT8 weight banks like moe_expert.py::build_tensor_specs (shared across ranks); golden = dispatch replay -> golden_moe_expert (ported from models/deepseek/v4/moe_expert.py) -> combine reduce; emits recv_y + sh; sets block_dim=24 / aicpu_thread_num=4; compiles the 19 kernels with a ThreadPool so ccec runs in parallel. ST is restricted to a2a3 hardware — the D=MOE_INTER=4096 INT8 matmuls are too slow under simulation.
Add an 18-kernel moe_router stage (hc_pre + RMSNorm + learned-score gate +
top-k + weight normalize, PyPTO-generated incore kernels) before dispatch. The
chip-produced x_norm and indices now flow through dispatch → moe_expert →
combine instead of being host-precomputed.
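A rough host-side sketch of the router tail (top-k over the gate scores, then weight normalization), under the assumption that "weight normalize" means dividing the selected scores by their sum; the function name and layout are illustrative, and the kernel's tiling is omitted:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

// Select the k highest-scoring experts and normalize their gate weights to
// sum to 1. `scores` holds one gate score per expert for a single token.
static void topk_normalize(const std::vector<float>& scores, int k,
                           std::vector<int>* idx, std::vector<float>* w) {
    idx->resize(scores.size());
    std::iota(idx->begin(), idx->end(), 0);
    // Partially sort indices so the k largest scores come first.
    std::partial_sort(idx->begin(), idx->begin() + k, idx->end(),
                      [&](int a, int b) { return scores[a] > scores[b]; });
    idx->resize(k);
    float sum = 0.f;
    for (int i : *idx) sum += scores[i];
    w->clear();
    for (int i : *idx) w->push_back(scores[i] / sum);
}
```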
Hardware fixes uncovered along the way:
* dispatch.cpp: EP routing policy swapped to dst = (my_rank + k) % N — the
router was JIT'd with EP_WORLD_SIZE=1 so its `indices` are local IDs in
[0, L); the rank component is now layered on at dispatch time.
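The routing fix in isolation: since `indices` only carries local expert ids, dispatch adds the rank offset itself. A minimal model (function name hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// With EP_WORLD_SIZE=1 at JIT time, the router's indices are local ids in
// [0, L); the destination rank is layered on at dispatch time as
// dst = (my_rank + k) % N, so each rank fans out across all N peers.
static int32_t dispatch_dst_rank(int32_t my_rank, int32_t k, int32_t world_size) {
    return (my_rank + k) % world_size;
}
```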
* write_route_outputs.cpp: PyPTO emits raw scalar GM stores with no tail
sync, so indices/weights stay in the writer core's L1 D-cache when the
next task starts. dispatch then scalar-reads stale (mostly-zero) data
~60% of runs on hardware (and 0% with profiling on, because the extra
task-boundary overhead lets the cache drain). Add
`pipe_barrier(PIPE_ALL) + dcci(..., CACHELINE_OUT) + dsb` at the end of
the generated body so writes land in L2/HBM before dispatch picks them
up.
* ep_dispatch_combine_orch.cpp: write_route_outputs declares its outputs
against ext_indices/ext_weights instead of the flat reshape views; only
matters cosmetically now that the cache flush is in place, but keeps the
L3 runtime's tensormap dep edge correctly tied to the original tensor.
BF16-cast unification on round-to-nearest-even (matches torch
`.to(bfloat16)` in the golden): mix_x.cpp / exp_recv_y_write.cpp /
sh_write.cpp flip their fp32→bf16 TCVT from CAST_ROUND to CAST_RINT, so the
pipeline stays bit-comparable to the host golden through the same
rounding mode the dispatch/combine pair already uses.
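A host-side model of the round-to-nearest-even cast (the CAST_RINT / torch `.to(bfloat16)` behavior) may make the tie-breaking concrete. This is a sketch of the rounding rule, not the kernel code; the bit trick mishandles NaN, which is fine for illustration:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// fp32 -> bf16 with round-to-nearest-even, returned as the 16-bit pattern.
static uint16_t f32_to_bf16_rne(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    uint32_t lsb = (bits >> 16) & 1u;  // LSB of the part we keep
    bits += 0x7FFFu + lsb;             // round half to even, then truncate
    return static_cast<uint16_t>(bits >> 16);
}

static float bf16_to_f32(uint16_t h) {
    uint32_t bits = static_cast<uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}
```

The half-way cases show why the mode matters for bit-exactness: 257 and 259 are both exactly between two BF16 values, and RNE sends one down and the other up.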
main.py: ports the moe_router golden (hc_pre + RMSNorm + learned-score
gate + top-k) from models/deepseek/v4/{hc_pre,moe_router}.py; wires the 9
router host inputs + chip-produced x_norm/indices/weights/post_ffn/comb_ffn
through the merged orch; exposes EP_SWIMLANE=1 to dump per-chip
l2_perf_records.json + func_names.json (kernel names map) under
outputs/swimlane*/; loosens recv_y/sh/routed_y tolerances to absorb BF16
input noise compounded through the INT8 matmul chain.
Adds the final post-MoE residual add from the single-layer spec (model.py:644-645 in deepseek_v4_decode_single_layer.md) so the chip output is now the BF16 ``ffn_out`` ready for ``hc_post``. The kernel reads routed_y (FP32) + sh (BF16) per token, adds in FP32, casts back to BF16 with CAST_RINT to match the host golden's ``.to(bfloat16)``. UB layout note: the AIV UB is 192 KiB so the usual 64 KB slot pitch (0x0 / 0x10000 / 0x20000 / 0x30000) would put the fourth tile at the very end, which hangs the kernel at TSTORE; the 4th tile is moved to 0x28000. main.py: adds the ``ffn_add`` child callable (func_id 37), wires ``ffn_out_outs`` through the orch, and verifies it via a small ``_verify_ffn_out`` against ``(routed_y_golden + sh_golden).to(bfloat16)``.
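The UB-layout note in plain arithmetic (constants mirror the paragraph above; the 32 KiB tile bound is an assumption for illustration):

```cpp
#include <cassert>
#include <cstdint>

// The AIV unified buffer is 192 KiB = 0x30000 bytes. With the usual 64 KiB
// slot pitch the four tiles sit at 0x0 / 0x10000 / 0x20000 / 0x30000, so the
// fourth slot starts exactly at the end of UB and any TSTORE into it is out
// of bounds. Relocating the fourth tile to 0x28000 keeps it inside UB for
// tiles up to 0x8000 (32 KiB) bytes.
constexpr uint32_t kUbBytes  = 192u * 1024u;  // 0x30000
constexpr uint32_t kPitch    = 0x10000u;      // usual 64 KiB slot pitch
constexpr uint32_t kTile3Bad = 3u * kPitch;   // 0x30000: at the very end of UB
constexpr uint32_t kTile3Fix = 0x28000u;      // relocated fourth tile
```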
Closes the single-layer pipeline by chaining hc_post after ffn_add. The chip-produced ``y [B, S, HC_MULT, D] BF16`` is what the next Transformer block would consume as ``x_hc``: hc_post(x=ffn_out, residual=x_hc, post=post_ffn, comb=comb_ffn) -> y
* kernels/aiv/hc_post.cpp — copied verbatim from the moe_router_test build output's hc_post jit (func_id 0 there, 38 in the merged orch).
* ep_dispatch_combine_orch.cpp — transplanted the generated hc_post body with task id 0->38, bound to the existing chip tensors (x = ffn_out; residual = router's host input x_hc; post = router-produced post_ffn; comb = router-produced comb_ffn). New OUTPUT_EXISTING slot tensor(37) = y; scratch shifts to tensor(38); expected_arg_count 40 -> 41.
* main.py — adds the ``hc_post`` child callable; allocates per-chip ``y_outs``; ports ``golden_hc_post`` from ``models/deepseek/v4/hc_post.py`` and wires ``_verify_hc_post``, which feeds the host-side ``ffn_out_golden`` through it.
Pipeline now: router -> dispatch -> moe_expert -> combine -> ffn_add -> hc_post. Observed diffs: y (hc_post) max|diff| ~ 4.7e-2, within atol=1e-1.
Code Review
This pull request significantly expands the Expert Parallel (EP) dispatch and combine example by integrating full moe_router and moe_expert stages, replacing the previous placeholder implementation. It introduces a comprehensive set of new AIC and AIV kernels generated by the PyPTO compiler, updates the orchestration logic to chain these stages, and scales dimensions to a production-like decode configuration. The review feedback identifies a recurring issue across multiple kernel entry points where int64_t variables are passed to functions expecting int32_t arguments, posing a risk of data loss and narrowing conversion warnings.
int64_t n0_inline72__idx_v0 = n0_inline72__idx_v0_conv.val;

// Forward to ptoas-generated function
exp_gate_up_matmul(recv_x_tile_i8_inline75__rv_v2, expert_w1__ssa_v0, expert_w3__ssa_v0, ret0__out, ret1__out, local_i_inline67__idx_v0, n0_inline72__idx_v0);
There is a potential data loss due to type mismatch. The variables local_i_inline67__idx_v0 and n0_inline72__idx_v0 are of type int64_t, but the function exp_gate_up_matmul expects int32_t for its 6th and 7th arguments. This could lead to truncation and unexpected behavior if the values exceed the int32_t range. It's safer to explicitly cast them to int32_t to ensure consistency and avoid narrowing conversion warnings.
exp_gate_up_matmul(recv_x_tile_i8_inline75__rv_v2, expert_w1__ssa_v0, expert_w3__ssa_v0, ret0__out, ret1__out, static_cast<int32_t>(local_i_inline67__idx_v0), static_cast<int32_t>(n0_inline72__idx_v0));

References
- Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.
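A small self-contained illustration of the hazard this family of comments flags (names here are hypothetical, not the kernel symbols): the truncation is silent once the value leaves int32_t range, and the explicit cast changes nothing at runtime but documents the narrowing.

```cpp
#include <cassert>
#include <cstdint>

// A callee with a narrow parameter, standing in for exp_gate_up_matmul etc.
static int32_t takes_i32(int32_t v) { return v; }

// Guard one could use before narrowing when the value's range is not proven.
static bool fits_i32(int64_t v) {
    return v >= INT32_MIN && v <= INT32_MAX;
}
```

In the generated kernels the loop indices are provably small, so the static_cast is about consistency and warning hygiene rather than a live bug.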
int64_t d0_inline49__idx_v0 = d0_inline49__idx_v0_conv.val;

// Forward to ptoas-generated function
exp_w2_matmul(h_tile_i8_inline92__rv_v2, expert_w2__ssa_v0, ret0__out, local_i_inline67__idx_v0, d0_inline49__idx_v0);
There is a potential data loss due to type mismatch. The variables local_i_inline67__idx_v0 and d0_inline49__idx_v0 are of type int64_t, but the function exp_w2_matmul expects int32_t for its 4th and 5th arguments. This could lead to truncation and unexpected behavior if the values exceed the int32_t range. It's safer to explicitly cast them to int32_t to ensure consistency and avoid narrowing conversion warnings.
exp_w2_matmul(h_tile_i8_inline92__rv_v2, expert_w2__ssa_v0, ret0__out, static_cast<int32_t>(local_i_inline67__idx_v0), static_cast<int32_t>(d0_inline49__idx_v0));
int64_t n0_inline113__idx_v0 = n0_inline113__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_gate_up_matmul(x_local_i8_inline43__rv_v2, shared_w1__ssa_v0, shared_w3__ssa_v0, ret0__out, ret1__out, n0_inline113__idx_v0);
There is a potential data loss due to type mismatch. The variable n0_inline113__idx_v0 is of type int64_t, but the function sh_gate_up_matmul expects int32_t for its 6th argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.
sh_gate_up_matmul(x_local_i8_inline43__rv_v2, shared_w1__ssa_v0, shared_w3__ssa_v0, ret0__out, ret1__out, static_cast<int32_t>(n0_inline113__idx_v0));
int64_t d0_inline64__idx_v0 = d0_inline64__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_w2_matmul(sh_tile_i8_inline126__rv_v2, shared_w2__ssa_v0, ret0__out, d0_inline64__idx_v0);
There is a potential data loss due to type mismatch. The variable d0_inline64__idx_v0 is of type int64_t, but the function sh_w2_matmul expects int32_t for its 4th argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.
sh_w2_matmul(sh_tile_i8_inline126__rv_v2, shared_w2__ssa_v0, ret0__out, static_cast<int32_t>(d0_inline64__idx_v0));
int64_t k0_inline59__ssa_v0 = k0_inline59__ssa_v0_conv.val;

// Forward to ptoas-generated function
cast_x(x_flat_inline36__ssa_v0, x_flat_fp32_inline49__iter_v1, k0_inline59__ssa_v0);
There is a potential data loss due to type mismatch. The variable k0_inline59__ssa_v0 is of type int64_t, but the function cast_x expects int32_t for its 3rd argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.
cast_x(x_flat_inline36__ssa_v0, x_flat_fp32_inline49__iter_v1, static_cast<int32_t>(k0_inline59__ssa_v0));
int64_t t0_inline47__ssa_v0 = t0_inline47__ssa_v0_conv.val;

// Forward to ptoas-generated function
recv_x_q(recv_x__ssa_v0, recv_x_tile_i8_inline75__ssa_v0, ret0__out, local_i_inline67__idx_v0, t0_inline47__ssa_v0);
There is a potential data loss due to type mismatch. The variables local_i_inline67__idx_v0 and t0_inline47__ssa_v0 are of type int64_t, but the function recv_x_q expects int32_t for its 4th and 5th arguments. This could lead to truncation and unexpected behavior if the values exceed the int32_t range. It's safer to explicitly cast them to int32_t to ensure consistency and avoid narrowing conversion warnings.
recv_x_q(recv_x__ssa_v0, recv_x_tile_i8_inline75__ssa_v0, ret0__out, static_cast<int32_t>(local_i_inline67__idx_v0), static_cast<int32_t>(t0_inline47__ssa_v0));
int64_t n0_inline113__idx_v0 = n0_inline113__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_gate_up_dequant(shared_w1_scale__ssa_v0, shared_w3_scale__ssa_v0, sh_gate_acc_inline95__rv_v2, sh_up_acc_inline118__rv_v2, x_local_scale_dq_inline32__ssa_v0, ret0__out, ret1__out, n0_inline113__idx_v0);
There is a potential data loss due to type mismatch. The variable n0_inline113__idx_v0 is of type int64_t, but the function sh_gate_up_dequant expects int32_t for its 8th argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.
sh_gate_up_dequant(shared_w1_scale__ssa_v0, shared_w3_scale__ssa_v0, sh_gate_acc_inline95__rv_v2, sh_up_acc_inline118__rv_v2, x_local_scale_dq_inline32__ssa_v0, ret0__out, ret1__out, static_cast<int32_t>(n0_inline113__idx_v0));
int64_t d0_inline64__idx_v0 = d0_inline64__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_w2_dequant(shared_w2_scale__ssa_v0, sh_y_acc_inline4__rv_v2, sh_tile_scale_dq_inline99__ssa_v0, ret0__out, d0_inline64__idx_v0);
There is a potential data loss due to type mismatch. The variable d0_inline64__idx_v0 is of type int64_t, but the function sh_w2_dequant expects int32_t for its 5th argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.
sh_w2_dequant(shared_w2_scale__ssa_v0, sh_y_acc_inline4__rv_v2, sh_tile_scale_dq_inline99__ssa_v0, ret0__out, static_cast<int32_t>(d0_inline64__idx_v0));
int64_t d0_inline64__idx_v0 = d0_inline64__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_write(sh_y_v1_inline114__ssa_v0, sh__iter_v1, d0_inline64__idx_v0);
There is a potential data loss due to type mismatch. The variable d0_inline64__idx_v0 is of type int64_t, but the function sh_write expects int32_t for its 3rd argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.
sh_write(sh_y_v1_inline114__ssa_v0, sh__iter_v1, static_cast<int32_t>(d0_inline64__idx_v0));
int64_t t_inline7__co_idx_v0 = t_inline7__co_idx_v0_conv.val;

// Forward to ptoas-generated function
write_post(post_pad_flat_inline73__ssa_v0, post_flat_inline51__ssa_v0, t_inline7__co_idx_v0);
There is a potential data loss due to type mismatch. The variable t_inline7__co_idx_v0 is of type int64_t, but the function write_post expects int32_t for its 3rd argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.
write_post(post_pad_flat_inline73__ssa_v0, post_flat_inline51__ssa_v0, static_cast<int32_t>(t_inline7__co_idx_v0));
Summary
Extends examples/workers/l3/ep_dispatch_combine from the small 3-kernel demo into the full DeepSeek-V4 decode single-layer pipeline.

5 commits, each adds one stage / fix:
* Unify BF16 casts on round-to-nearest-even (CAST_RINT).
* Replace the local_expert placeholder with moe_expert — 17 PyPTO-generated incore kernels (4 AIC matmul + 13 AIV) wired by the transplanted moe_expert orchestration. Adds host fixtures for the 6 INT8 expert weight banks + golden ported from models/deepseek/v4/moe_expert.py.
* moe_router — 18 router kernels (hc_pre + RMSNorm + learned-score gate + top-k + weight normalize). Includes a hardware fix for a GM cache-coherency race: write_route_outputs uses scalar GM stores with no tail sync, so the downstream dispatch saw stale (mostly-zero) indices on ~60% of hardware runs (profiling mode masked it). Added a dcci(..., CACHELINE_OUT) + dsb writer-side flush.
* ffn_add — ffn_out = routed_y + sh (model.py:644-645).
* hc_post — produces the next-layer x_hc.

Fixed-seed fixture; the golden chains the host implementations of each stage; tolerances reflect compounded BF16 + INT8-amax noise through the chain.
Testing
task-submit --device auto --device-num 2 … on a2a3 hardware — all verification stages pass (router / recv / expert / combine / ffn_add / hc_post).
EP_SWIMLANE=1: per-chip l2_perf_records.json + the kernel-name map (func_names.json) dumped under outputs/swimlane-*/.
ST restricted to a2a3 hardware.
Marked draft while iterating on the example before mainline review.