Fix/topk decode dispatch seqlen#3773
Open
chuanbowang2026 wants to merge 5 commits into
Open
Conversation
In DSV4 CSA_INDEX decode, logits tensor is allocated with fixed stride0=max_model_len_idx (262144) for CUDAGraph compatibility, but actual data is only ~2048 elements. The dispatch used stride0 to choose mb vs ob path, causing unnecessary multi-block overhead. Add max_seqlen parameter to top_k_per_row_decode. When provided, dispatch uses it instead of stride0. Default -1 preserves backward compatibility. Benchmark (MI308X, N=8, stride0=262144, effective_len=2048, K=1024): - Before: 33us (mb path) - After: 21us (ob path), 36% faster - Correctness: verified, same index set output Caller (ATOM) passes CPU-side max(n_committed_per_seq) as hint, zero GPU sync overhead.
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Contributor
Author
|
Note: this PR adds the To activate the fix, ATOM needs a one-line change in max_committed = int(attn_metadata.n_committed_csa_per_seq_cpu[:bs].max())
top_k_per_row_decode(..., k=topk, max_seqlen=max_committed)Will submit the ATOM PR after this one merges. |
CUDAGraph replays freeze all CPU-side arguments, so the previous max_seqlen hint approach cannot work at runtime. Switch decode to always use the one-block kernel unconditionally — this avoids the mb persistent kernel's cross-block barrier deadlock risk when other streams occupy CUs, and gives better latency for typical decode batch sizes. Prefill dispatch (mb/ob heuristic) is unchanged.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
top_k_per_row_decodeusedstride0to choose between the multi-block (mb) and one-block (ob) paths. In DSV4 CSA_INDEX decode,stride0is the preallocated CUDAGraph width (262144), while the effective row length is much smaller (~8192). This made dispatch choose mb unnecessarily.This PR adds an optional
max_seqlenhint totop_k_per_row_decode. When provided, dispatch usesmax_seqlenfor path selection while the actual top-k range remains controlled byseqLens. The default-1preserves existing behavior.Prefill is unchanged.
Benchmark
Case:
N=8, stride0=262144, effective_len=8192, K=1024Both paths produce identical top-k index sets.
Caller Usage