
[ROCm] Fused 4-bit SIMT GEMM #1979

Open
sstamenk wants to merge 4 commits into bitsandbytes-foundation:main from sstamenk:rdna_simt


@sstamenk sstamenk commented Jun 19, 2026

Contributor

Summary

This PR adds a fused 4-bit SIMT dequant+GEMM path for ROCm/RDNA and wires it into gemm_4bit dispatch for small inference batch sizes. The kernel avoids the dequant+BLAS fallback overhead by decoding NF4/FP4 weights inside the GEMM, reusing decoded values across batch rows, and using RDNA-friendly primitives such as DPP reductions and LDS centroid lookup.

The implementation is validated across gfx1100, gfx1201, and gfx1151 using a 40-shape LLM sweep over bf16, fp16, and fp32, comparing SIMT against dequant+BLAS, the baseline gemv implementation, and the optimized gemv from #1920. The results show that SIMT consistently wins in the small-M inference regime, with architecture-specific crossover behavior. This PR does not include dispatch heuristics that account for architecture, shape, data type, or other inputs; instead it uses a single crossover threshold per architecture, derived from the median crossover over the 40 shapes. Future work will add more refined heuristics that also cover the WMMA/MFMA kernels.
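As a rough illustration of how a single per-architecture threshold could be derived from per-shape crossovers, here is a minimal sketch. The function names, the timing layout, and all numbers are hypothetical and synthetic; they are not measurements or code from this PR.

```python
from statistics import median

def crossover_m(timings):
    """Largest M for which the fused kernel still beats the fallback.

    `timings` maps M -> (fused_us, fallback_us). Returns 0 if fused never wins.
    """
    best = 0
    for m in sorted(timings):
        fused_us, fallback_us = timings[m]
        if fused_us >= fallback_us:
            break  # stop at the first M where the fallback is at least as fast
        best = m
    return best

def arch_threshold(per_shape_timings):
    """Median of the per-shape crossovers, used as one threshold per architecture."""
    return int(median(crossover_m(t) for t in per_shape_timings.values()))

def use_fused_kernel(m, threshold):
    """Route to the fused kernel only inside the calibrated small-M range."""
    return 1 <= m <= threshold

# Synthetic timings (microseconds) for three made-up shapes.
SYNTHETIC = {
    "shape_a": {1: (10, 40), 16: (12, 41), 64: (30, 45), 256: (90, 60)},
    "shape_b": {1: (8, 30), 16: (20, 31), 64: (50, 35), 256: (150, 50)},
    "shape_c": {1: (9, 50), 16: (11, 50), 64: (25, 52), 256: (50, 58)},
}

threshold = arch_threshold(SYNTHETIC)
print(threshold, use_fused_kernel(32, threshold), use_fused_kernel(128, threshold))
```

A median is robust to the isolated outlier shapes described under Known limitations, which is one plausible reason to prefer it over a mean here.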

Across the 40-shape RDNA sweep, the fused SIMT kernel delivers up to 44.8x speedup over dequant+BLAS (gfx1100, fp32, small GQA k/v, M=1), up to 9.9x over the baseline gemv (gfx1201, bf16, qwen3-235b-moe expert down, M=64), and up to 4.5x over the optimized #1920 gemv (gfx1100, bf16, small GQA k/v, M=64). The win window also extends well beyond single-token GEMV: median SIMT/dequant crossovers reach M≈622 on gfx1151 fp32, with some fp32 shapes remaining SIMT-faster through M=2048.

Technical details

  • HIP/RDNA enablement for cgemm_4bit_*: the existing fused SIMT GEMM is now compiled and callable on ROCm. The shared code uses bnb_bfloat16 / bnb_bfloat162 and bnb_stream_t aliases so the same C API works for both CUDA and HIP.
  • Native bf16 dot-product path: on HIP, the bf16 path uses AMD’s v_dot2_f32_bf16 instruction to accumulate bf16 pair dot products into fp32. The dequant scale is applied once per chunk after the dot product, which avoids the slower emulated bf16 pair multiply on RDNA.
  • LDS centroid lookup on HIP: the NF4/FP4 centroid table is staged in LDS for the HIP decode paths. bf16 uses a uint16_t LUT for the VDOT2 path; fp16 uses the LDS LUT on RDNA; fp32 uses the LDS LUT on gfx11 and gfx12.
  • DPP warp reduction: HIP uses AMD DPP row-shift operations for the warp reduction instead of the CUDA-style shuffle-down tree. This is cheaper for the RDNA 32-lane wave path.
  • Register prefetch of B and absmax: the HIP path prefetches the next packed 4-bit weight chunk and corresponding scale into registers while accumulating the current K group, improving latency hiding in the bandwidth-bound small-M regime.
  • ROCm dispatch integration: the Python dispatch now routes RDNA gfx11/gfx12 small-M GEMMs through the fused custom kernel and falls back to dequant+F.linear outside the calibrated range. The hook is structured so future ROCm WMMA/MFMA kernels can be added without reworking the high-level dispatch.
  • Raw-pointer safety fix: the custom kernel wrapper now calls A.contiguous() before passing A.data_ptr() into C. The fused kernel assumes row-major contiguous A; without this, transposed/sliced/view inputs could silently produce incorrect results.
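The arithmetic behind the fused path (decode through a lookup table, reuse decoded values across batch rows, apply the scale once per block after the dot product) can be sketched in plain Python. This is an illustrative model only, not the kernel: the tiny block size, the high-nibble-first packing order, and the rounded NF4 table are assumptions, and the function names are hypothetical.

```python
# Approximate NF4 code values, rounded to 4 decimals (assumption; check the
# library source for the exact table).
NF4 = [-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
       0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0]
BLOCKSIZE = 4  # tiny for illustration; real quantization blocks are larger

def unpack(packed):
    """Split each byte into two 4-bit codes, high nibble first (assumed order)."""
    codes = []
    for byte in packed:
        codes.append(byte >> 4)
        codes.append(byte & 0x0F)
    return codes

def fused_column(a_rows, packed_w, absmax):
    """out[m] = sum_k a_rows[m][k] * w[k] for one output column of W."""
    codes = unpack(packed_w)
    out = [0.0] * len(a_rows)
    for b, scale in enumerate(absmax):
        lo = b * BLOCKSIZE
        # Decode the block once via the lookup table ...
        centroids = [NF4[c] for c in codes[lo:lo + BLOCKSIZE]]
        # ... then reuse the decoded values for every row of A.
        for m, row in enumerate(a_rows):
            partial = sum(x * w for x, w in zip(row[lo:lo + BLOCKSIZE], centroids))
            out[m] += scale * partial  # scale applied once per block
    return out

def reference_column(a_rows, packed_w, absmax):
    """Dequantize the whole column first, then take plain dot products."""
    codes = unpack(packed_w)
    w = [NF4[c] * absmax[i // BLOCKSIZE] for i, c in enumerate(codes)]
    return [sum(x * wi for x, wi in zip(row, w)) for row in a_rows]

packed = bytes([0x0F, 0x7F, 0xF0, 0x70])  # codes 0,15,7,15,15,0,7,0
absmax = [2.0, 0.5]
a = [[1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 1, 1, 1, 1, 1, 1]]
print(fused_column(a, packed, absmax))  # [5.5, 1.5]
```

Factoring the scale out of the inner loop is what lets the real kernel feed unscaled centroids to a pair dot-product instruction and multiply once per chunk.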

Testing plan

  • 40-shape performance sweep across RDNA architectures
    Run the full cold-cache sweep over the 40 canonical LLM shapes for bf16, fp16, and fp32, comparing:
    • gemv (baseline): the current gemv implementation, slightly modified into a tiled kernel that supports M > 1
    • gemv (#1920, "[ROCm] Optimize kgemm_4bit_inference_naive for ROCm, use it for batch sizes other than 1"): optimized multi-row gemv
    • SIMT: fused 4-bit dequant + GEMM kernel
    • dequant+BLAS: current fallback baseline
    The sweep is run across the tested targets (gfx1100, gfx1201, gfx1151, and gfx942; the last is a CDNA target rather than RDNA) and summarized with per-shape grids plus median-over-shapes aggregate plots.
  • NVIDIA regression testing
    Run the same gemm_4bit correctness and smoke-performance checks on NVIDIA hardware to ensure the shared C++/Python dispatch changes do not regress the existing CUDA path or alter the CUDA MMA/SIMT selection behavior.
  • Unit test / correctness validation
    Run the relevant unit tests for gemm_4bit, gemv_4bit, quantization layouts, and fallback correctness to verify outputs match dequantized reference results across supported dtypes, quantization modes, nested/non-nested absmax, bias handling, and dispatch fallback cases.

Testing results

Median results over the 40 shapes, comparing kernel performance

[Figures: median_gfx1201, median_gfx1201, median_gfx1151]
  • gfx942
    TODO

NVIDIA regression testing

  • No performance regressions observed

Correctness validation

  • All tests passed

Known limitations

  • RDNA memory-camping effects can dominate specific N values. The SIMT kernel is bandwidth-bound at small M, so some output dimensions alias poorly onto the memory-channel/partition layout and create large fixed stalls that are mostly independent of work size. This is visible as isolated outlier shapes in the per-shape charts: for example, gfx1201 shows the strongest camping behavior on power-of-two N values such as the small square / N=2048 case, while gfx1100 is more sensitive around multiples of 3072. Strix Halo (gfx1151) shows both families in the resonance sweep.

  • Example RDNA memory-camping outlier: on gfx1201, the N=2048 small square shape aliases poorly and forces the SIMT kernel into a fixed ~27 µs latency floor extending up to M = 8.

[Figure: small_square]

  • Latency shape sweep showing spikes at specific sizes

[Figure: resonance_nsweep_M1]
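To make the aliasing intuition concrete, here is a toy model of how a row stride can concentrate accesses on a few memory channels. It is only a sketch: the channel count, the interleave granularity, and the simple modulo mapping are hypothetical and do not describe the actual RDNA address mapping or how N translates into a byte stride in this kernel.

```python
def channels_touched(stride_bytes, rows, num_channels=16, interleave_bytes=256):
    """Count distinct channels hit by the first byte of each of `rows` rows.

    Toy mapping: channel = (address // interleave_bytes) % num_channels.
    """
    return len({
        (r * stride_bytes // interleave_bytes) % num_channels
        for r in range(rows)
    })

# Hypothetical strides: a power of two, a multiple of 3072, and a padded stride.
for stride in (4096, 3072, 4352):
    print(stride, channels_touched(stride, rows=64))
```

In this model the power-of-two stride lands every row on one channel, the 3072-byte stride cycles through four, and the padded stride spreads across all sixteen, which mirrors the kind of size-dependent stall described above without claiming to reproduce it.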

@sstamenk sstamenk marked this pull request as ready for review June 19, 2026 13:51
@matthewdouglas matthewdouglas added this to the v0.50.0 milestone Jun 19, 2026
@github-actions


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.



2 participants