
WIP: ep_dispatch_combine: full DSV4 decode single-layer chain#772

Draft
zhangqi-chen wants to merge 5 commits into hw-native-sys:main from zhangqi-chen:dsv4_moe_demo

Conversation

@zhangqi-chen
Contributor

Summary

Extends examples/workers/l3/ep_dispatch_combine from the small 3-kernel
demo into the full DeepSeek-V4 decode single-layer pipeline:

moe_router → dispatch → moe_expert → combine → ffn_add → hc_post

Five commits, each adding one stage or fix:

  1. Scale dims to the production decode config (D=4096, L=8, T=16, R=32);
    unify BF16 casts on round-to-nearest-even (CAST_RINT).
  2. Replace local_expert placeholder with moe_expert — 17
    PyPTO-generated incore kernels (4 AIC matmul + 13 AIV) wired by the
    transplanted moe_expert orchestration. Adds host fixtures for the 6
    INT8 expert weight banks + golden ported from
    models/deepseek/v4/moe_expert.py.
  3. Prepend moe_router — 18 router kernels (hc_pre + RMSNorm +
    learned-score gate + top-k + weight normalize). Includes a hardware fix
    for a GM cache-coherency race: write_route_outputs uses scalar GM
    stores with no tail sync, so the downstream dispatch saw stale
    (mostly-zero) indices on ~60% of hardware runs (profiling-mode masked
    it). Added dcci(..., CACHELINE_OUT) + dsb writer-side flush.
  4. Append ffn_add: ffn_out = routed_y + sh (model.py:644-645).
  5. Append hc_post — produces the next-layer x_hc.

Fixed-seed fixture; golden chains the host implementations of each stage;
tolerances reflect compounded BF16 + INT8-amax noise through the chain.

Testing

  • task-submit --device auto --device-num 2 … on a2a3 hardware — all
    verification stages pass (router / recv / expert / combine / ffn_add /
    hc_post).
  • L2 swimlane profiling validated via EP_SWIMLANE=1; per-chip
    l2_perf_records.json + kernel-name map dumped under
    outputs/swimlane-*/.
  • Simulation (a2a3sim) — not viable at these dims; ST restricted to
    hardware.

Marked draft while iterating on the example before mainline review.

…8, T=16)

Bump dispatch/combine/local_expert kernel constants and main.py to mirror the
moe_expert decode config so the example can later swap in the real expert.
x_norm uses (d % 16) to stay within BF16's exact-integer range. Unify BF16
casts on round-to-nearest-even (TCVT CAST_RINT in the kernels matches torch's
.to(bfloat16) in the golden) so routed_y stays bit-exact.
Swap the elementwise local_expert kernel for the production DeepSeek-V4 decode
MoE block: 17 PyPTO-generated incore kernels (4 AIC matmuls + 13 AIV) copied
from the moe_expert_test JIT build, wired by the moe_expert orchestration
transplanted into ep_dispatch_combine_orch.cpp (task ids +1, recv_expert_count
read from the host-known tensor(4) instead of a kernel output). dispatch stays
func_id 0, moe_expert is 1..17, combine is 18.

main.py: x_norm now small random BF16 (keeps recv_x bit-exact through dispatch
while keeping MoE matmul outputs in a sane magnitude range); generates the six
INT8 weight banks like moe_expert.py::build_tensor_specs (shared across ranks);
golden = dispatch replay -> golden_moe_expert (ported from
models/deepseek/v4/moe_expert.py) -> combine reduce; emits recv_y + sh; sets
block_dim=24 / aicpu_thread_num=4; compiles the 19 kernels with a ThreadPool so
ccec runs in parallel.
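The parallel-compile step can be sketched with the stdlib ThreadPoolExecutor (a sketch only; `build_one` below is a stand-in for the actual ccec invocation in main.py, not the real call):

```python
from concurrent.futures import ThreadPoolExecutor

def compile_all(sources, build_one, workers=8):
    """Run one compile job per kernel source in parallel,
    preserving input order in the returned list."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(build_one, sources))

# build_one here just fakes a compile; the real version shells out to ccec
objs = compile_all([f"k{i}.cpp" for i in range(19)],
                   lambda s: s.replace(".cpp", ".o"))
```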

ST is restricted to a2a3 hardware — the D=MOE_INTER=4096 INT8 matmuls are too
slow under simulation.
Add an 18-kernel moe_router stage (hc_pre + RMSNorm + learned-score gate +
top-k + weight normalize, PyPTO-generated incore kernels) before dispatch. The
chip-produced x_norm and indices now flow through dispatch → moe_expert →
combine instead of being host-precomputed.
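The top-k + weight-normalize tail of the router golden can be sketched per token as follows (a pure-Python sketch of the selection logic only; the normalize-to-sum-1 convention is an assumption of this sketch, not taken from moe_router.py):

```python
def topk_normalize(scores, k):
    """Per-token: pick the k largest gate scores and renormalize
    the selected weights so they sum to 1."""
    order = sorted(range(len(scores)), key=lambda e: scores[e], reverse=True)
    indices = order[:k]
    total = sum(scores[e] for e in indices)
    weights = [scores[e] / total for e in indices]
    return indices, weights

# one token gated over L=8 local experts, top-2
idx, w = topk_normalize([0.1, 0.9, 0.3, 0.05, 0.7, 0.2, 0.4, 0.6], k=2)
```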

Hardware fixes uncovered along the way:

* dispatch.cpp: EP routing policy swapped to dst = (my_rank + k) % N — the
  router was JIT'd with EP_WORLD_SIZE=1 so its `indices` are local IDs in
  [0, L); the rank component is now layered on at dispatch time.

* write_route_outputs.cpp: PyPTO emits raw scalar GM stores with no tail
  sync, so indices/weights stay in the writer core's L1 D-cache when the
  next task starts. dispatch then scalar-reads stale (mostly-zero) data
  ~60% of runs on hardware (and 0% with profiling on, because the extra
  task-boundary overhead lets the cache drain). Add
  `pipe_barrier(PIPE_ALL) + dcci(..., CACHELINE_OUT) + dsb` at the end of
  the generated body so writes land in L2/HBM before dispatch picks them
  up.

* ep_dispatch_combine_orch.cpp: write_route_outputs declares its outputs
  against ext_indices/ext_weights instead of the flat reshape views; only
  matters cosmetically now that the cache flush is in place, but keeps the
  L3 runtime's tensormap dep edge correctly tied to the original tensor.
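The EP routing policy from the dispatch.cpp fix above can be sketched as (a minimal sketch; the helper name is hypothetical, and `k` is the top-k slot exactly as in the commit's `dst = (my_rank + k) % N` formula):

```python
def ep_dst_rank(my_rank, k, world_size):
    """Layer the rank component onto a local expert pick at dispatch
    time: the router was JIT'd with EP_WORLD_SIZE=1, so its indices
    are local IDs in [0, L) with no rank information."""
    return (my_rank + k) % world_size

# rank 3 of N=4 sends its k=0..3 picks to ranks 3, 0, 1, 2
dsts = [ep_dst_rank(3, k, 4) for k in range(4)]
```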

BF16-cast unification on round-to-nearest-even (matches torch
`.to(bfloat16)` in the golden): mix_x.cpp / exp_recv_y_write.cpp /
sh_write.cpp flip their fp32→bf16 TCVT from CAST_ROUND to CAST_RINT, so the
pipeline stays bit-comparable to the host golden through the same
rounding mode the dispatch/combine pair already uses.
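The round-to-nearest-even fp32→bf16 cast that CAST_RINT and torch's `.to(bfloat16)` share can be reproduced bit-exactly in pure Python (a sketch; NaN/Inf handling is omitted):

```python
import struct

def f32_to_bf16_rne(x):
    """fp32 -> bf16 with round-to-nearest-even, returned widened
    back to fp32.  bf16 keeps the top 16 bits of the fp32 pattern;
    adding 0x7FFF plus the kept mantissa's low bit implements RNE."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    keep_lsb = (bits >> 16) & 1                       # ties go to even
    bits = (bits + 0x7FFF + keep_lsb) & 0xFFFFFFFF    # rounding bias
    return struct.unpack("<f", struct.pack("<I", (bits >> 16) << 16))[0]
```

Tie cases make the difference to truncation visible: 1.00390625 is exactly halfway between bf16 neighbors and rounds down to the even mantissa, while 1.01171875 is halfway with an odd mantissa and rounds up.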

main.py: ports the moe_router golden (hc_pre + RMSNorm + learned-score
gate + top-k) from models/deepseek/v4/{hc_pre,moe_router}.py; wires the 9
router host inputs + chip-produced x_norm/indices/weights/post_ffn/comb_ffn
through the merged orch; exposes EP_SWIMLANE=1 to dump per-chip
l2_perf_records.json + func_names.json (kernel names map) under
outputs/swimlane*/; loosens recv_y/sh/routed_y tolerances to absorb BF16
input noise compounded through the INT8 matmul chain.
Adds the final post-MoE residual add from the single-layer spec (model.py:644-645
in deepseek_v4_decode_single_layer.md) so the chip output is now the BF16
``ffn_out`` ready for ``hc_post``. The kernel reads routed_y (FP32) + sh (BF16)
per token, adds in FP32, casts back to BF16 with CAST_RINT to match the host
golden's ``.to(bfloat16)``.

UB layout note: the AIV UB is 192 KiB, so the usual 64 KiB slot pitch (0x0 /
0x10000 / 0x20000 / 0x30000) would put the fourth tile at offset 0x30000 =
192 KiB, right at the end of UB, which hangs the kernel at TSTORE; the 4th
tile is moved down to 0x28000.
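The offset arithmetic behind that note can be checked in a few lines (a sketch; the 32 KiB tile size is an assumption for illustration, not stated in the commit):

```python
UB_BYTES = 192 * 1024                 # AIV unified buffer size

def tile_fits(offset, tile_bytes):
    """A tile must end at or before the end of UB."""
    return offset + tile_bytes <= UB_BYTES

tile = 0x8000                          # hypothetical 32 KiB tile
default_pitch = [0x0, 0x10000, 0x20000, 0x30000]   # usual 64 KiB slots
ok_default = [tile_fits(off, tile) for off in default_pitch]
ok_fixed = tile_fits(0x28000, tile)    # the relocated 4th slot
```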

main.py: adds the ``ffn_add`` child callable (func_id 37), wires
``ffn_out_outs`` through the orch, and verifies it via a small
``_verify_ffn_out`` against ``(routed_y_golden + sh_golden).to(bfloat16)``.
Closes the single-layer pipeline by chaining hc_post after ffn_add. The
chip-produced ``y [B, S, HC_MULT, D] BF16`` is what the next Transformer
block would consume as ``x_hc``:

  hc_post(x=ffn_out, residual=x_hc, post=post_ffn, comb=comb_ffn) -> y

* kernels/aiv/hc_post.cpp — copied verbatim from the moe_router_test build
  output's hc_post jit (func_id 0 there, 38 in the merged orch).
* ep_dispatch_combine_orch.cpp — transplanted the generated hc_post body
  with task id 0->38, bound to the existing chip tensors (x = ffn_out;
  residual = router's host input x_hc; post = router-produced post_ffn;
  comb = router-produced comb_ffn). New OUTPUT_EXISTING slot tensor(37) = y;
  scratch shifts to tensor(38); expected_arg_count 40 -> 41.
* main.py — adds the ``hc_post`` child callable; allocates per-chip
  ``y_outs``; ports ``golden_hc_post`` from
  ``models/deepseek/v4/hc_post.py`` and wires ``_verify_hc_post`` that
  feeds the host-side ``ffn_out_golden`` through it.

Pipeline now: router -> dispatch -> moe_expert -> combine -> ffn_add -> hc_post.
Observed diffs: y (hc_post) max|diff| ~ 4.7e-2, within atol=1e-1.

@gemini-code-assist (bot) left a comment

Code Review

This pull request significantly expands the Expert Parallel (EP) dispatch and combine example by integrating full moe_router and moe_expert stages, replacing the previous placeholder implementation. It introduces a comprehensive set of new AIC and AIV kernels generated by the PyPTO compiler, updates the orchestration logic to chain these stages, and scales dimensions to a production-like decode configuration. The review feedback identifies a recurring issue across multiple kernel entry points where int64_t variables are passed to functions expecting int32_t arguments, posing a risk of data loss and narrowing conversion warnings.

int64_t n0_inline72__idx_v0 = n0_inline72__idx_v0_conv.val;

// Forward to ptoas-generated function
exp_gate_up_matmul(recv_x_tile_i8_inline75__rv_v2, expert_w1__ssa_v0, expert_w3__ssa_v0, ret0__out, ret1__out, local_i_inline67__idx_v0, n0_inline72__idx_v0);

high

There is a potential data loss due to type mismatch. The variables local_i_inline67__idx_v0 and n0_inline72__idx_v0 are of type int64_t, but the function exp_gate_up_matmul expects int32_t for its 6th and 7th arguments. This could lead to truncation and unexpected behavior if the values exceed the int32_t range. It's safer to explicitly cast them to int32_t to ensure consistency and avoid narrowing conversion warnings.

    exp_gate_up_matmul(recv_x_tile_i8_inline75__rv_v2, expert_w1__ssa_v0, expert_w3__ssa_v0, ret0__out, ret1__out, static_cast<int32_t>(local_i_inline67__idx_v0), static_cast<int32_t>(n0_inline72__idx_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.
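For illustration, the silent wrap-around this comment (and the near-identical ones below) warns about can be reproduced in isolation; pure Python via ctypes stands in here for the C++ int64_t-to-int32_t narrowing, and the names are hypothetical:

```python
import ctypes

def to_i32(v):
    """What the callee sees when an int64_t is passed where an
    int32_t parameter is declared: truncation to the low 32 bits,
    interpreted as two's complement."""
    return ctypes.c_int32(v).value

in_range = to_i32(1234)             # in-range indices survive the narrowing
wrapped = to_i32((1 << 32) + 5)     # out-of-range values silently wrap
```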

int64_t d0_inline49__idx_v0 = d0_inline49__idx_v0_conv.val;

// Forward to ptoas-generated function
exp_w2_matmul(h_tile_i8_inline92__rv_v2, expert_w2__ssa_v0, ret0__out, local_i_inline67__idx_v0, d0_inline49__idx_v0);

high

There is a potential data loss due to type mismatch. The variables local_i_inline67__idx_v0 and d0_inline49__idx_v0 are of type int64_t, but the function exp_w2_matmul expects int32_t for its 4th and 5th arguments. This could lead to truncation and unexpected behavior if the values exceed the int32_t range. It's safer to explicitly cast them to int32_t to ensure consistency and avoid narrowing conversion warnings.

    exp_w2_matmul(h_tile_i8_inline92__rv_v2, expert_w2__ssa_v0, ret0__out, static_cast<int32_t>(local_i_inline67__idx_v0), static_cast<int32_t>(d0_inline49__idx_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.

int64_t n0_inline113__idx_v0 = n0_inline113__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_gate_up_matmul(x_local_i8_inline43__rv_v2, shared_w1__ssa_v0, shared_w3__ssa_v0, ret0__out, ret1__out, n0_inline113__idx_v0);

high

There is a potential data loss due to type mismatch. The variable n0_inline113__idx_v0 is of type int64_t, but the function sh_gate_up_matmul expects int32_t for its 6th argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.

    sh_gate_up_matmul(x_local_i8_inline43__rv_v2, shared_w1__ssa_v0, shared_w3__ssa_v0, ret0__out, ret1__out, static_cast<int32_t>(n0_inline113__idx_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.

int64_t d0_inline64__idx_v0 = d0_inline64__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_w2_matmul(sh_tile_i8_inline126__rv_v2, shared_w2__ssa_v0, ret0__out, d0_inline64__idx_v0);

high

There is a potential data loss due to type mismatch. The variable d0_inline64__idx_v0 is of type int64_t, but the function sh_w2_matmul expects int32_t for its 4th argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.

    sh_w2_matmul(sh_tile_i8_inline126__rv_v2, shared_w2__ssa_v0, ret0__out, static_cast<int32_t>(d0_inline64__idx_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.

int64_t k0_inline59__ssa_v0 = k0_inline59__ssa_v0_conv.val;

// Forward to ptoas-generated function
cast_x(x_flat_inline36__ssa_v0, x_flat_fp32_inline49__iter_v1, k0_inline59__ssa_v0);

high

There is a potential data loss due to type mismatch. The variable k0_inline59__ssa_v0 is of type int64_t, but the function cast_x expects int32_t for its 3rd argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.

    cast_x(x_flat_inline36__ssa_v0, x_flat_fp32_inline49__iter_v1, static_cast<int32_t>(k0_inline59__ssa_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.

int64_t t0_inline47__ssa_v0 = t0_inline47__ssa_v0_conv.val;

// Forward to ptoas-generated function
recv_x_q(recv_x__ssa_v0, recv_x_tile_i8_inline75__ssa_v0, ret0__out, local_i_inline67__idx_v0, t0_inline47__ssa_v0);

high

There is a potential data loss due to type mismatch. The variables local_i_inline67__idx_v0 and t0_inline47__ssa_v0 are of type int64_t, but the function recv_x_q expects int32_t for its 4th and 5th arguments. This could lead to truncation and unexpected behavior if the values exceed the int32_t range. It's safer to explicitly cast them to int32_t to ensure consistency and avoid narrowing conversion warnings.

    recv_x_q(recv_x__ssa_v0, recv_x_tile_i8_inline75__ssa_v0, ret0__out, static_cast<int32_t>(local_i_inline67__idx_v0), static_cast<int32_t>(t0_inline47__ssa_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.

int64_t n0_inline113__idx_v0 = n0_inline113__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_gate_up_dequant(shared_w1_scale__ssa_v0, shared_w3_scale__ssa_v0, sh_gate_acc_inline95__rv_v2, sh_up_acc_inline118__rv_v2, x_local_scale_dq_inline32__ssa_v0, ret0__out, ret1__out, n0_inline113__idx_v0);

high

There is a potential data loss due to type mismatch. The variable n0_inline113__idx_v0 is of type int64_t, but the function sh_gate_up_dequant expects int32_t for its 8th argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.

    sh_gate_up_dequant(shared_w1_scale__ssa_v0, shared_w3_scale__ssa_v0, sh_gate_acc_inline95__rv_v2, sh_up_acc_inline118__rv_v2, x_local_scale_dq_inline32__ssa_v0, ret0__out, ret1__out, static_cast<int32_t>(n0_inline113__idx_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.

int64_t d0_inline64__idx_v0 = d0_inline64__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_w2_dequant(shared_w2_scale__ssa_v0, sh_y_acc_inline4__rv_v2, sh_tile_scale_dq_inline99__ssa_v0, ret0__out, d0_inline64__idx_v0);

high

There is a potential data loss due to type mismatch. The variable d0_inline64__idx_v0 is of type int64_t, but the function sh_w2_dequant expects int32_t for its 5th argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.

    sh_w2_dequant(shared_w2_scale__ssa_v0, sh_y_acc_inline4__rv_v2, sh_tile_scale_dq_inline99__ssa_v0, ret0__out, static_cast<int32_t>(d0_inline64__idx_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.

int64_t d0_inline64__idx_v0 = d0_inline64__idx_v0_conv.val;

// Forward to ptoas-generated function
sh_write(sh_y_v1_inline114__ssa_v0, sh__iter_v1, d0_inline64__idx_v0);

high

There is a potential data loss due to type mismatch. The variable d0_inline64__idx_v0 is of type int64_t, but the function sh_write expects int32_t for its 3rd argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.

    sh_write(sh_y_v1_inline114__ssa_v0, sh__iter_v1, static_cast<int32_t>(d0_inline64__idx_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.

int64_t t_inline7__co_idx_v0 = t_inline7__co_idx_v0_conv.val;

// Forward to ptoas-generated function
write_post(post_pad_flat_inline73__ssa_v0, post_flat_inline51__ssa_v0, t_inline7__co_idx_v0);

high

There is a potential data loss due to type mismatch. The variable t_inline7__co_idx_v0 is of type int64_t, but the function write_post expects int32_t for its 3rd argument. This could lead to truncation and unexpected behavior if the value exceeds the int32_t range. It's safer to explicitly cast it to int32_t to ensure consistency and avoid narrowing conversion warnings.

    write_post(post_pad_flat_inline73__ssa_v0, post_flat_inline51__ssa_v0, static_cast<int32_t>(t_inline7__co_idx_v0));
References
  1. Be consistent in type casting when assigning size_t values (like producers.size()) to int32_t variables (like s.fanin_count) across similar methods to avoid compiler warnings about narrowing conversions.
