Add prepare metadata kernel #6

Open
simveit wants to merge 3 commits into MoonshotAI:master from simveit:feature/metadata

simveit commented May 26, 2026

Summary

This PR is based on work by team CUDA_ERROR_UNKNOWN
(@gau-nernst, @yue-zhang-2025, @mayankagarwals, @zcnrex) from the FlashInfer
competition.

This PR adds a metadata prepare kernel for variable-length batches. When
enabled, it precomputes chunk offsets and chunk-to-sequence indices, which
avoids the repeated scans over cu_seqlens in _flash_kda_fwd_prepare and the
tile-base scans in the recurrence kernel. By default, metadata is enabled when
the number of sequences N is at least kVarlenMetadataAutoMinSequences
(currently 32).
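
To make the idea concrete, here is a minimal host-side Python sketch of what such metadata could look like. It is not the CUDA kernel from this PR: the function names, the chunk size of 64, and the list-based layout are assumptions for illustration only. Only the threshold of 32 and the `default`/`on`/`off` modes come from the description and the benchmark flag.

```python
CHUNK = 64  # hypothetical chunk size; the PR text does not state the real value
AUTO_MIN_SEQUENCES = 32  # value quoted for kVarlenMetadataAutoMinSequences


def prepare_metadata(cu_seqlens, chunk=CHUNK):
    """Return (chunk_offsets, chunk_to_seq) for a packed varlen batch.

    chunk_offsets[i] is the global index of the first chunk of sequence i
    (length N + 1); chunk_to_seq[c] is the sequence that owns global chunk c.
    """
    chunk_offsets = [0]
    chunk_to_seq = []
    for seq, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        n_chunks = -(-(end - start) // chunk)  # ceiling division
        chunk_offsets.append(chunk_offsets[-1] + n_chunks)
        chunk_to_seq.extend([seq] * n_chunks)
    return chunk_offsets, chunk_to_seq


def find_seq_by_scan(cu_seqlens, chunk_id, chunk=CHUNK):
    """Baseline without metadata: walk cu_seqlens until the chunk is found."""
    seen = 0
    for seq, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        n_chunks = -(-(end - start) // chunk)
        if chunk_id < seen + n_chunks:
            return seq
        seen += n_chunks
    raise IndexError(chunk_id)


def use_metadata(num_sequences, mode="default"):
    """Mirror the on/off/default switch described in the summary."""
    if mode == "on":
        return True
    if mode == "off":
        return False
    return num_sequences >= AUTO_MIN_SEQUENCES


# Three sequences of length 100, 64 and 136 tokens.
cu_seqlens = [0, 100, 164, 300]
chunk_offsets, chunk_to_seq = prepare_metadata(cu_seqlens)
print(chunk_offsets)  # [0, 2, 3, 6]
print(chunk_to_seq)   # [0, 0, 1, 2, 2, 2]
```

With the table in hand, a lookup is a single index into `chunk_to_seq`, whereas the scan visits up to N entries per lookup, which is why the cost of the scan grows with the number of sequences.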

Benchmark

See below for performance numbers on B200.

  • Device: Blackwell / B200
  • Benchmark settings: T=8192, H=96, D=128, warmup=30, iters=200, repeats=5
  • fla_chunk_kda configuration: use_gate_in_kernel=True, use_qk_l2norm_in_kernel=True, use_beta_sigmoid_in_kernel=True, lower_bound=-5, transpose_state_layout=True
  • fla_chunk_gated_delta_rule configuration: scalar per-head gate g of shape (1, T, H), use_qk_l2norm_in_kernel=True, transpose_state_layout=True

Command:

uv run --python .venv/bin/python python benchmarks/bench_fwd.py --mode all --use-varlen-metadata default
uv run --python .venv/bin/python python benchmarks/bench_fwd.py --mode all --use-varlen-metadata off
uv run --python .venv/bin/python python benchmarks/bench_fwd.py --mode all --use-varlen-metadata on

T=8192, H=96, D=128, use_varlen_metadata=default

| Case | flash_kda mean (ms) | fla_chunk_kda mean (ms) | Speedup vs chunk_kda | fla_chunk_gdn mean (ms) | Speedup vs gdn |
| --- | --- | --- | --- | --- | --- |
| Fixed | 1.0626 | 2.4014 | 2.26x | 1.3465 | 1.27x |
| Varlen, seq_lens=[1300, 547, 2048, 963, 271, 3063] | 0.9225 | 2.4587 | 2.67x | 1.3939 | 1.51x |
| Varlen, seq_lens=1024 x 8 | 0.7767 | 2.4289 | 3.13x | 1.3541 | 1.74x |
| Varlen, seq_lens=512 x 16 | 0.8090 | 2.4476 | 3.03x | 1.3637 | 1.69x |
| Varlen, seq_lens=256 x 32 | 0.8450 | 2.5152 | 2.98x | 1.4259 | 1.69x |
| Varlen, seq_lens=64 x 128 | 1.2422 | 2.9343 | 2.36x | 1.8281 | 1.47x |
| Varlen, seq_lens=32 x 256 | 1.7671 | 4.1664 | 2.36x | 3.0362 | 1.72x |
| Varlen, seq_lens=16 x 512 | 2.8234 | 6.7707 | 2.40x | 5.5277 | 1.96x |
| Varlen, seq_lens=4096 + 8 x 512 | 4.0787 | 7.2827 | 1.79x | 5.9525 | 1.46x |
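
The Case column appears to use a "length x count" notation, so `16 x 512` would mean 512 sequences of 16 tokens. This reading is an inference from T=8192 and from how the timings scale, not something the PR states. Under that assumption, the sketch below builds `cu_seqlens` for a few cases and checks that each one packs exactly T tokens; the helper name `cu_seqlens_from` is hypothetical.

```python
from itertools import accumulate

T = 8192  # total packed tokens, from the benchmark settings


def cu_seqlens_from(seq_lens):
    """Cumulative sequence boundaries: [0, l0, l0 + l1, ...]."""
    return [0, *accumulate(seq_lens)]


# Assumed reading of the Case column: "length x count".
cases = {
    "[1300, 547, 2048, 963, 271, 3063]": [1300, 547, 2048, 963, 271, 3063],
    "1024 x 8": [1024] * 8,
    "256 x 32": [256] * 32,
    "16 x 512": [16] * 512,
    "4096 + 8 x 512": [4096] + [8] * 512,
}

num_sequences = {}
for name, seq_lens in cases.items():
    cu = cu_seqlens_from(seq_lens)
    assert cu[-1] == T, name  # every case packs exactly T tokens
    num_sequences[name] = len(cu) - 1
    print(f"{name}: N={num_sequences[name]}")
```

Under this reading, `256 x 32` is the first uniform case with N >= 32, which matches where the default and off tables start to diverge.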

T=8192, H=96, D=128, use_varlen_metadata=off

| Case | flash_kda mean (ms) | fla_chunk_kda mean (ms) | Speedup vs chunk_kda | fla_chunk_gdn mean (ms) | Speedup vs gdn |
| --- | --- | --- | --- | --- | --- |
| Fixed | 1.0626 | 2.4015 | 2.26x | 1.3467 | 1.27x |
| Varlen, seq_lens=[1300, 547, 2048, 963, 271, 3063] | 0.9226 | 2.4585 | 2.66x | 1.3938 | 1.51x |
| Varlen, seq_lens=1024 x 8 | 0.7779 | 2.4288 | 3.12x | 1.3540 | 1.74x |
| Varlen, seq_lens=512 x 16 | 0.8071 | 2.4475 | 3.03x | 1.3638 | 1.69x |
| Varlen, seq_lens=256 x 32 | 0.8791 | 2.5153 | 2.86x | 1.4257 | 1.62x |
| Varlen, seq_lens=64 x 128 | 1.8059 | 2.9345 | 1.62x | 1.8282 | 1.01x |
| Varlen, seq_lens=32 x 256 | 3.5842 | 4.1658 | 1.16x | 3.0362 | 0.85x |
| Varlen, seq_lens=16 x 512 | 8.7285 | 6.7704 | 0.78x | 5.5276 | 0.63x |
| Varlen, seq_lens=4096 + 8 x 512 | 8.3310 | 7.2918 | 0.88x | 5.9640 | 0.72x |

T=8192, H=96, D=128, use_varlen_metadata=on

| Case | flash_kda mean (ms) | fla_chunk_kda mean (ms) | Speedup vs chunk_kda | fla_chunk_gdn mean (ms) | Speedup vs gdn |
| --- | --- | --- | --- | --- | --- |
| Fixed | 1.0656 | 2.4013 | 2.25x | 1.3467 | 1.26x |
| Varlen, seq_lens=[1300, 547, 2048, 963, 271, 3063] | 0.9786 | 2.4584 | 2.51x | 1.3940 | 1.42x |
| Varlen, seq_lens=1024 x 8 | 0.8114 | 2.4283 | 2.99x | 1.3540 | 1.67x |
| Varlen, seq_lens=512 x 16 | 0.8236 | 2.4475 | 2.97x | 1.3639 | 1.66x |
| Varlen, seq_lens=256 x 32 | 0.8462 | 2.5149 | 2.97x | 1.4257 | 1.68x |
| Varlen, seq_lens=64 x 128 | 1.2428 | 2.9342 | 2.36x | 1.8280 | 1.47x |
| Varlen, seq_lens=32 x 256 | 1.7678 | 4.1657 | 2.36x | 3.0364 | 1.72x |
| Varlen, seq_lens=16 x 512 | 2.8234 | 6.7702 | 2.40x | 5.5276 | 1.96x |
| Varlen, seq_lens=4096 + 8 x 512 | 4.0868 | 7.2916 | 1.78x | 5.9635 | 1.46x |
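
The effect of the metadata kernel itself can be read off by dividing the flash_kda mean with metadata off by the mean with metadata on. The short script below does that arithmetic for the uniform varlen cases, using only numbers copied from the off and on tables; the dictionary layout is just for illustration.

```python
# flash_kda mean (ms) per case: (metadata off, metadata on), copied from the tables.
flash_kda_ms = {
    "1024 x 8": (0.7779, 0.8114),
    "512 x 16": (0.8071, 0.8236),
    "256 x 32": (0.8791, 0.8462),
    "64 x 128": (1.8059, 1.2428),
    "32 x 256": (3.5842, 1.7678),
    "16 x 512": (8.7285, 2.8234),
}

# A ratio above 1 means enabling metadata made flash_kda faster.
gain = {case: off / on for case, (off, on) in flash_kda_ms.items()}

for case, ratio in gain.items():
    print(f"{case:>9}: {ratio:.2f}x")
```

In these measurements the ratio is slightly below 1 for the 8 and 16 sequence cases and above 1 from 32 sequences upward, rising to roughly 3x at 512 sequences. That crossover is consistent with the default threshold of 32.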

simveit changed the title from "Feature/metadata" to "Add prepare metadata kernel" May 26, 2026