
Fix/topk decode dispatch seqlen #3773

Open
chuanbowang2026 wants to merge 5 commits into ROCm:main from chuanbowang2026:fix/topk-decode-dispatch-seqlen

@chuanbowang2026

Summary

top_k_per_row_decode used stride0 to choose between the multi-block (mb) and one-block (ob) paths. In DSV4 CSA_INDEX decode, stride0 is the preallocated CUDAGraph width (262144), while the effective row length is much smaller (~8192). This made dispatch choose mb unnecessarily.

This PR adds an optional max_seqlen hint to top_k_per_row_decode. When provided, dispatch uses max_seqlen for path selection while the actual top-k range remains controlled by seqLens. The default -1 preserves existing behavior.

Prefill is unchanged.
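
The dispatch rule described above can be sketched as follows. This is a minimal illustration, not the kernel's actual host code: the function name `choose_decode_path` and the threshold `ob_max_len` are hypothetical, since the PR text does not state the real cutoff between the one-block and multi-block paths.

```python
def choose_decode_path(stride0: int, max_seqlen: int = -1,
                       ob_max_len: int = 16384) -> str:
    """Pick the one-block ("ob") or multi-block ("mb") path.

    ob_max_len is a hypothetical cutoff used only for illustration.
    """
    # With the default -1, fall back to stride0 (the pre-PR behavior).
    dispatch_len = max_seqlen if max_seqlen > 0 else stride0
    return "ob" if dispatch_len <= ob_max_len else "mb"


# Preallocated width alone selects the multi-block path.
print(choose_decode_path(stride0=262144))
# A hint near the effective row length selects the one-block path.
print(choose_decode_path(stride0=262144, max_seqlen=8192))
```

The hint only affects which path is chosen; in the PR, the range actually scanned per row is still governed by seqLens.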

Benchmark

Case: N=8, stride0=262144, effective_len=8192, K=1024

| GPU | Before | After | Speedup |
| --- | --- | --- | --- |
| MI355X | 30.740 us | 10.838 us | 2.84x |
| MI308X | 36.74 us | 20.24 us | 1.82x |

Both paths produce identical top-k index sets.
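
One way to check this kind of claim is to compare the per-row indices as sets, because two top-k kernels may return the same indices in a different order. The helper below is a hypothetical sketch in plain Python lists, not the test used in this PR.

```python
from typing import Sequence


def same_topk_index_sets(idx_a: Sequence[Sequence[int]],
                         idx_b: Sequence[Sequence[int]]) -> bool:
    """Return True if every row holds the same indices, ignoring order."""
    if len(idx_a) != len(idx_b):
        return False
    # Sorting removes ordering differences while still catching duplicates.
    return all(sorted(a) == sorted(b) for a, b in zip(idx_a, idx_b))


mb_out = [[5, 1, 9], [2, 0, 7]]   # made-up indices for illustration
ob_out = [[9, 5, 1], [0, 7, 2]]
print(same_topk_index_sets(mb_out, ob_out))
```

If the logits contain exact ties at the k-th value, two correct kernels can legitimately pick different indices, so a stricter test would compare the selected values as well.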

Caller Usage

max_committed = int(n_committed_csa_per_seq_cpu[:bs].max())
top_k_per_row_decode(..., k=topk, max_seqlen=max_committed)

In DSV4 CSA_INDEX decode, the logits tensor is allocated with a fixed
stride0=max_model_len_idx (262144) for CUDAGraph compatibility,
but the actual data is only ~2048 elements. The dispatch used stride0
to choose the mb vs ob path, causing unnecessary multi-block overhead.

Add max_seqlen parameter to top_k_per_row_decode. When provided,
dispatch uses it instead of stride0. Default -1 preserves backward
compatibility.

Benchmark (MI308X, N=8, stride0=262144, effective_len=2048, K=1024):
- Before: 33us (mb path)
- After:  21us (ob path), 36% faster
- Correctness: verified, same index set output

Caller (ATOM) passes CPU-side max(n_committed_per_seq) as hint,
zero GPU sync overhead.
chuanbowang2026 requested a review from a team on June 17, 2026 07:30
@github-actions

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3773 --add-label <label>

@chuanbowang2026

Note: this PR adds the max_seqlen parameter with a default of -1, so it is backward compatible: there is no behavior change without a caller-side update.

To activate the fix, ATOM needs a small caller-side change in _score_topk_decode (deepseek_v4.py):

max_committed = int(attn_metadata.n_committed_csa_per_seq_cpu[:bs].max())
top_k_per_row_decode(..., k=topk, max_seqlen=max_committed)

Will submit the ATOM PR after this one merges.

chuanbowang2026 and others added 2 commits June 17, 2026 23:51
CUDAGraph replays freeze all CPU-side arguments, so the previous
max_seqlen hint approach cannot work at runtime. Switch decode to
always use the one-block kernel unconditionally — this avoids the
mb persistent kernel's cross-block barrier deadlock risk when other
streams occupy CUs, and gives better latency for typical decode
batch sizes.

Prefill dispatch (mb/ob heuristic) is unchanged.
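
The reason the hint cannot work under graph replay can be illustrated with a toy model in plain Python. This is not the CUDA or HIP graph API; `capture`, `choose_path`, and `OB_MAX_LEN` are hypothetical names, and the threshold value is made up. The point is only that host-side dispatch runs once, at capture time, so the chosen path is baked into every replay.

```python
from typing import Callable

OB_MAX_LEN = 16384  # hypothetical cutoff, for illustration only


def choose_path(dispatch_len: int) -> str:
    return "ob" if dispatch_len <= OB_MAX_LEN else "mb"


def capture(max_seqlen_hint: int) -> Callable[[], str]:
    """Toy stand-in for graph capture: dispatch is evaluated once here."""
    path = choose_path(max_seqlen_hint)

    def replay() -> str:
        # Replays reuse the recorded launch; the hint is never re-read.
        return path

    return replay


def decode_path_after_fix() -> str:
    """Model of the revised commit: decode always uses the one-block kernel."""
    return "ob"


replay = capture(max_seqlen_hint=262144)  # hint seen at capture time
print(replay())                 # path stays fixed for all later replays
print(decode_path_after_fix())  # no dependence on any CPU-side hint
```

Making the decode path unconditional removes the dependence on a value that replays cannot update.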