[ROCm] Fused 4-bit SIMT GEMM#1979
Open
sstamenk wants to merge 4 commits into
Open
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds a fused 4-bit SIMT dequant+GEMM path for ROCm/RDNA and wires it into gemm_4bit dispatch for small inference batch sizes. The kernel avoids the dequant+BLAS fallback overhead by decoding NF4/FP4 weights inside the GEMM, reusing decoded values across batch rows, and using RDNA-friendly primitives such as DPP reductions and LDS centroid lookup.
The implementation is validated across gfx1100, gfx1201, and gfx1151 using a 40-shape LLM sweep over bf16, fp16, and fp32, comparing SIMT against dequant+BLAS, the baseline gemv implementation, and the optimized gemv from #1920. The results show SIMT consistently wins in the small-M inference regime, with architecture-specific crossover behavior. This PR does not include dispatch heuristics that take into account the architecture, shape, data type and other data but instead opts for a single value per architecture that is determined from the 40 shape median crossover values. Future work will add more intelligent heuristics that will cover the WMMA/MFMA kernels as well.
Across the 40-shape RDNA sweep, the fused SIMT kernel delivers up to 44.8x speedup over dequant+BLAS (gfx1100, fp32, small GQA k/v, M=1), up to 9.9x over the baseline gemv (gfx1201, bf16, qwen3-235b-moe expert down, M=64), and up to 4.5x over the optimized #1920 gemv (gfx1100, bf16, small GQA k/v, M=64). The win window also extends well beyond single-token GEMV: median SIMT/dequant crossovers reach M≈622 on gfx1151 fp32, with some fp32 shapes remaining SIMT-faster through M=2048.
Technical details
Testing plan
Run the full cold-cache sweep over the 40 canonical LLM shapes for bf16, fp16, and fp32, comparing:
The sweep is run across the tested RDNA targets (gfx1100, gfx1201, gfx1151, gfx942) and summarized with per-shape grids plus median-over-shapes aggregate plots.
Run the same gemm_4bit correctness and smoke-performance checks on NVIDIA hardware to ensure the shared C++/Python dispatch changes do not regress the existing CUDA path or alter the CUDA MMA/SIMT selection behavior.
Run the relevant unit tests for gemm_4bit, gemv_4bit, quantization layouts, and fallback correctness to verify outputs match dequantized reference results across supported dtypes, quantization modes, nested/non-nested absmax, bias handling, and dispatch fallback cases.
Testing results
Median results over the 40 comparing different kernel performance
TODO
Nvidia regression testing
Correctness validation
Known limitations
RDNA memory-camping effects can dominate specific N values. The SIMT kernel is bandwidth-bound at small M, so some output dimensions alias poorly onto the memory-channel/partition layout and create large fixed stalls that are mostly independent of work size. This is visible as isolated outlier shapes in the per-shape charts: for example, gfx1201 shows the strongest camp behavior on power-of-two N values such as the small square / N=2048 case, while gfx1100 is more sensitive around multiples of 3072. Strix Halo (gfx1151) shows both families in the resonance sweep.
Example RDNA memory-camping outlier: on gfx1201, the N=2048 small square shape aliases poorly and forces the SIMT kernel into a fixed ~27 µs latency floor extending up to M = 8.