From 33cffa2005303f2b8e2975e121743186a6ed55e7 Mon Sep 17 00:00:00 2001
From: Iridescent-gcrace <781461778@qq.com>
Date: Sat, 9 May 2026 18:32:19 +0800
Subject: [PATCH] Fix 2bit bit decoding residual path

---
 bit_decode/bit_decode_interface.py            | 62 ++++++++++---------
 csrc/bit_decode/decode_api.cpp                |  4 +-
 csrc/bit_decode/src/flash_fwd_kernel.h        |  4 +-
 .../flash_fwd_split_hdim128_fp16_sm80_2bit.cu |  4 +-
 .../flash_qpack_hdim128_fp16_sm80_2bit.cu     | 16 ++---
 evaluation/example.py                         |  5 +-
 6 files changed, 51 insertions(+), 44 deletions(-)

diff --git a/bit_decode/bit_decode_interface.py b/bit_decode/bit_decode_interface.py
index b23bc6d..7b1248a 100644
--- a/bit_decode/bit_decode_interface.py
+++ b/bit_decode/bit_decode_interface.py
@@ -32,15 +32,17 @@ def kvcache_pack_int(k_cache: torch.Tensor, k_pack: torch.Tensor, k_params: torc
                                           quant_mode,
                                           group_size
                                           )
-    # else:
-    #     bit_decode_cuda.kvcache_pack_int2(K_unpad, k_pack, k_params,
-    #                                       V_unpad, v_pack, v_params,
-    #                                       opt_block_table,
-    #                                       cu_seqlens_k,
-    #                                       seqlen_k,
-    #                                       quant_mode,
-    #                                       group_size
-    #                                       )
+    elif num_bits == 2:
+        bit_decode_cuda.kvcache_pack_int2(K_unpad, k_pack, k_params,
+                                          V_unpad, v_pack, v_params,
+                                          opt_block_table,
+                                          cu_seqlens_k,
+                                          seqlen_k,
+                                          quant_mode,
+                                          group_size
+                                          )
+    else:
+        raise ValueError(f"Unsupported num_bits={num_bits}; expected 2 or 4")
 
 def fwd_kvcache_int(q: torch.Tensor,
                     k_pack: torch.Tensor, k_params: torch.Tensor,
@@ -78,26 +80,28 @@ def fwd_kvcache_int(q: torch.Tensor,
             True, # Added
             0 # Added
             )
-    # else:
-    #     out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int2(
-    #         q,
-    #         k_pack, k_params,
-    #         v_pack, v_params,
-    #         opt_k_new, opt_v_new, opt_seqlens_k,
-    #         k_pack_new, k_params_new, v_pack_new, v_params_new,
-    #         opt_block_table,
-    #         softmax_scale,
-    #         quant_mode,
-    #         group_size,
-    #         residual_block_size,
-    #         new_lens,
-    #         False, # Added
-    #         -1, # Added
-    #         -1, # Added
-    #         0.0, # Added
-    #         True, # Added
-    #         0 # Added
-    #         )
+    elif num_bits == 2:
+        out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int2(
+            q,
+            k_pack, k_params,
+            v_pack, v_params,
+            opt_k_new, opt_v_new, opt_seqlens_k,
+            k_pack_new, k_params_new, v_pack_new, v_params_new,
+            opt_block_table,
+            softmax_scale,
+            quant_mode,
+            group_size,
+            residual_block_size,
+            new_lens,
+            False, # Added
+            -1, # Added
+            -1, # Added
+            0.0, # Added
+            True, # Added
+            0 # Added
+            )
+    else:
+        raise ValueError(f"Unsupported num_bits={num_bits}; expected 2 or 4")
 
     return out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new
 
diff --git a/csrc/bit_decode/decode_api.cpp b/csrc/bit_decode/decode_api.cpp
index 0a9c388..302d854 100644
--- a/csrc/bit_decode/decode_api.cpp
+++ b/csrc/bit_decode/decode_api.cpp
@@ -687,8 +687,8 @@ void kvcache_qpack(const at::Tensor &k, at::Tensor &k_pack, at::Tensor &k_params
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "BitDecoding";
-    // m.def("kvcache_pack_i2", &kvcache_qpack<2>, "Forward pass, kvcache quantization and packing (2-bit)");
+    m.def("kvcache_pack_int2", &kvcache_qpack<2>, "Forward pass, kvcache quantization and packing (2-bit)");
     m.def("kvcache_pack_int4", &kvcache_qpack<4>, "Forward pass, kvcache quantization and packing (4-bit)");
-    // m.def("fwd_kvcache_i2", &mha_fwd_kvcache<2>, "Forward pass, with 2-bit KV-cache");
+    m.def("fwd_kvcache_int2", &mha_fwd_kvcache<2>, "Forward pass, with 2-bit KV-cache");
     m.def("fwd_kvcache_int4", &mha_fwd_kvcache<4>, "Forward pass, with 4-bit KV-cache");
 }
diff --git a/csrc/bit_decode/src/flash_fwd_kernel.h b/csrc/bit_decode/src/flash_fwd_kernel.h
index 226cfc6..9074920 100644
--- a/csrc/bit_decode/src/flash_fwd_kernel.h
+++ b/csrc/bit_decode/src/flash_fwd_kernel.h
@@ -70,7 +70,7 @@ inline __device__ void compute_attn_1rowblock_residualkv(const Params &params, c
     using ElementKVPack = typename Kernel_traits::ElementKVPack;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
-    using SharedStorage = typename Kernel_traits::SharedStorage;
+    using SharedStorage = typename Kernel_traits::SharedStorage_residual;
 
     // Shared memory.
     extern __shared__ char smem_[];
@@ -1881,4 +1881,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     }
 }
 
-} // namespace flash
\ No newline at end of file
+} // namespace flash
diff --git a/csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu b/csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu
index 53705ca..75c6047 100644
--- a/csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu
+++ b/csrc/bit_decode/src/genfile/flash_fwd_split_hdim128_fp16_sm80_2bit.cu
@@ -4,6 +4,6 @@
 
 #include "../flash_fwd_launch_template.h"
 
-// template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 // template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
-// template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
diff --git a/csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu b/csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu
index 40ca15c..b85f7c8 100644
--- a/csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu
+++ b/csrc/bit_decode/src/genfile/flash_qpack_hdim128_fp16_sm80_2bit.cu
@@ -4,15 +4,15 @@
 
 #include "../flash_fwd_launch_template.h"
 
-// template<>
-// void run_kvcache_qpack_(Flash_fwd_params &params, cudaStream_t stream) {
-//     run_kvcache_qpack_hdim128(params, stream);
-// }
+template<>
+void run_kvcache_qpack_(Flash_fwd_params &params, cudaStream_t stream) {
+    run_kvcache_qpack_hdim128(params, stream);
+}
 // template<>
 // void run_kvcache_qpack_(Flash_fwd_params &params, cudaStream_t stream) {
 //     run_kvcache_qpack_hdim128(params, stream);
 // }
-// template<>
-// void run_kvcache_qpack_(Flash_fwd_params &params, cudaStream_t stream) {
-//     run_kvcache_qpack_hdim128(params, stream);
-// }
\ No newline at end of file
+template<>
+void run_kvcache_qpack_(Flash_fwd_params &params, cudaStream_t stream) {
+    run_kvcache_qpack_hdim128(params, stream);
+}
diff --git a/evaluation/example.py b/evaluation/example.py
index 8a3672d..59152bb 100644
--- a/evaluation/example.py
+++ b/evaluation/example.py
@@ -23,7 +23,7 @@ def main():
     parser.add_argument('--max_length', type=int, default=131072, help='Maximum length of the input sequence')
     parser.add_argument('--num_bits', type=int, default=4, help='Number of bits for quantization')
     parser.add_argument('--quant_mode', type=str, default='k-channel', help='Quantization mode')
-    parser.add_argument('--group_size', type=int, default=128, help='Group size for quantization')
+    parser.add_argument('--group_size', type=int, default=None, help='Group size for quantization')
     parser.add_argument('--attn_backend', type=str, default='flash_attention_2', help='Attention implementation')
     args = parser.parse_args()
 
@@ -31,6 +31,9 @@ def main():
     random.seed(0)
     torch.manual_seed(0)
 
+    if args.group_size is None:
+        args.group_size = 32 if args.num_bits == 2 else 128
+
     if "Llama" in args.model_path:
         config = LlamaConfig.from_pretrained(args.model_path)
     elif "Qwen" in args.model_path:
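
Not part of the patch: a minimal smoke check one could run after rebuilding the extension, to confirm that the 2-bit entry points registered in decode_api.cpp are exported alongside the existing 4-bit ones. It assumes only names that appear in the diff (the bit_decode_cuda module used by bit_decode/bit_decode_interface.py and the four binding names); whether the extension is importable directly as bit_decode_cuda depends on how it is built.

    # Hypothetical reviewer smoke check; assumes the extension is importable as
    # bit_decode_cuda (the name used in bit_decode/bit_decode_interface.py).
    import bit_decode_cuda

    for name in ("kvcache_pack_int2", "fwd_kvcache_int2",
                 "kvcache_pack_int4", "fwd_kvcache_int4"):
        # These are exactly the bindings declared in PYBIND11_MODULE above.
        assert hasattr(bit_decode_cuda, name), f"missing binding: {name}"
    print("2-bit and 4-bit bindings are exposed")

With the evaluation/example.py change, running `python evaluation/example.py --num_bits 2` now defaults group_size to 32 (and keeps 128 for 4-bit) unless --group_size is passed explicitly.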