62 changes: 33 additions & 29 deletions bit_decode/bit_decode_interface.py
@@ -32,15 +32,17 @@ def kvcache_pack_int(k_cache: torch.Tensor, k_pack: torch.Tensor, k_params: torc
quant_mode,
group_size
)
# else:
# bit_decode_cuda.kvcache_pack_int2(K_unpad, k_pack, k_params,
# V_unpad, v_pack, v_params,
# opt_block_table,
# cu_seqlens_k,
# seqlen_k,
# quant_mode,
# group_size
# )
elif num_bits == 2:
bit_decode_cuda.kvcache_pack_int2(K_unpad, k_pack, k_params,
V_unpad, v_pack, v_params,
opt_block_table,
cu_seqlens_k,
seqlen_k,
quant_mode,
group_size
)
else:
raise ValueError(f"Unsupported num_bits={num_bits}; expected 2 or 4")

def fwd_kvcache_int(q: torch.Tensor,
k_pack: torch.Tensor, k_params: torch.Tensor,
@@ -78,26 +80,28 @@ def fwd_kvcache_int(q: torch.Tensor,
True, # Added
0 # Added
)
# else:
# out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int2(
# q,
# k_pack, k_params,
# v_pack, v_params,
# opt_k_new, opt_v_new, opt_seqlens_k,
# k_pack_new, k_params_new, v_pack_new, v_params_new,
# opt_block_table,
# softmax_scale,
# quant_mode,
# group_size,
# residual_block_size,
# new_lens,
# False, # Added
# -1, # Added
# -1, # Added
# 0.0, # Added
# True, # Added
# 0 # Added
# )
elif num_bits == 2:
out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new = bit_decode_cuda.fwd_kvcache_int2(
q,
k_pack, k_params,
v_pack, v_params,
opt_k_new, opt_v_new, opt_seqlens_k,
k_pack_new, k_params_new, v_pack_new, v_params_new,
opt_block_table,
softmax_scale,
quant_mode,
group_size,
residual_block_size,
new_lens,
False, # Added
-1, # Added
-1, # Added
0.0, # Added
True, # Added
0 # Added
)
else:
raise ValueError(f"Unsupported num_bits={num_bits}; expected 2 or 4")


return out_bit, k_pack_new, k_params_new, v_pack_new, v_params_new
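
For reference, a minimal sketch of the num_bits dispatch this file now performs (function names are taken from the diff; the helper name _select_pack_kernel and the lazy import are illustrative assumptions, not part of the PR):

    def _select_pack_kernel(num_bits: int):
        # Hypothetical helper: pick the packing kernel by quantization width.
        # Only 2-bit and 4-bit kernels are compiled, so anything else is rejected.
        import bit_decode_cuda  # assumed extension module name, as used above
        if num_bits == 4:
            return bit_decode_cuda.kvcache_pack_int4
        elif num_bits == 2:
            return bit_decode_cuda.kvcache_pack_int2
        raise ValueError(f"Unsupported num_bits={num_bits}; expected 2 or 4")
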
4 changes: 2 additions & 2 deletions csrc/bit_decode/decode_api.cpp
@@ -687,8 +687,8 @@ void kvcache_qpack(const at::Tensor &k, at::Tensor &k_pack, at::Tensor &k_params

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "BitDecoding";
// m.def("kvcache_pack_i2", &kvcache_qpack<2>, "Forward pass, kvcache quantization and packing (2-bit)");
m.def("kvcache_pack_int2", &kvcache_qpack<2>, "Forward pass, kvcache quantization and packing (2-bit)");
m.def("kvcache_pack_int4", &kvcache_qpack<4>, "Forward pass, kvcache quantization and packing (4-bit)");
// m.def("fwd_kvcache_i2", &mha_fwd_kvcache<2>, "Forward pass, with 2-bit KV-cache");
m.def("fwd_kvcache_int2", &mha_fwd_kvcache<2>, "Forward pass, with 2-bit KV-cache");
m.def("fwd_kvcache_int4", &mha_fwd_kvcache<4>, "Forward pass, with 4-bit KV-cache");
}
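
A quick way to confirm the renamed 2-bit bindings are exported alongside the 4-bit ones (a sketch, assuming the extension imports as bit_decode_cuda):

    import bit_decode_cuda

    # The PR exposes the 2-bit entry points under *_int2, matching the *_int4 names.
    for name in ("kvcache_pack_int2", "kvcache_pack_int4",
                 "fwd_kvcache_int2", "fwd_kvcache_int4"):
        assert hasattr(bit_decode_cuda, name), f"missing binding: {name}"
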
4 changes: 2 additions & 2 deletions csrc/bit_decode/src/flash_fwd_kernel.h
@@ -70,7 +70,7 @@ inline __device__ void compute_attn_1rowblock_residualkv(const Params &params, c
using ElementKVPack = typename Kernel_traits::ElementKVPack;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
using SharedStorage = typename Kernel_traits::SharedStorage;
using SharedStorage = typename Kernel_traits::SharedStorage_residual;

// Shared memory.
extern __shared__ char smem_[];
@@ -1881,4 +1881,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
}
}

} // namespace flash
} // namespace flash
@@ -4,6 +4,6 @@

#include "../flash_fwd_launch_template.h"

// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream);
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream);
// template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream);
@@ -4,15 +4,15 @@

#include "../flash_fwd_launch_template.h"

// template<>
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 128>(params, stream);
// }
template<>
void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 128>(params, stream);
}
// template<>
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 64>(Flash_fwd_params &params, cudaStream_t stream) {
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 64>(params, stream);
// }
// template<>
// void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream) {
// run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 32>(params, stream);
// }
template<>
void run_kvcache_qpack_<cutlass::half_t, 128, 1, 2, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_kvcache_qpack_hdim128<cutlass::half_t, 1, 2, 32>(params, stream);
}
5 changes: 4 additions & 1 deletion evaluation/example.py
@@ -23,14 +23,17 @@ def main():
parser.add_argument('--max_length', type=int, default=131072, help='Maximum length of the input sequence')
parser.add_argument('--num_bits', type=int, default=4, help='Number of bits for quantization')
parser.add_argument('--quant_mode', type=str, default='k-channel', help='Quantization mode')
parser.add_argument('--group_size', type=int, default=128, help='Group size for quantization')
parser.add_argument('--group_size', type=int, default=None, help='Group size for quantization')
parser.add_argument('--attn_backend', type=str, default='flash_attention_2', help='Attention implementation')
args = parser.parse_args()

# For reproducibility
random.seed(0)
torch.manual_seed(0)

if args.group_size is None:
args.group_size = 32 if args.num_bits == 2 else 128

if "Llama" in args.model_path:
config = LlamaConfig.from_pretrained(args.model_path)
elif "Qwen" in args.model_path:
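
Finally, a small standalone illustration of the group-size default that evaluation/example.py now applies (the function resolve_group_size is hypothetical; only the 32-for-2-bit / 128-otherwise rule comes from the diff):

    from typing import Optional

    def resolve_group_size(num_bits: int, group_size: Optional[int]) -> int:
        # Hypothetical helper mirroring the new default: 2-bit quantization
        # uses group_size 32, otherwise the previous default of 128 is kept.
        if group_size is None:
            return 32 if num_bits == 2 else 128
        return group_size

    assert resolve_group_size(2, None) == 32
    assert resolve_group_size(4, None) == 128
    assert resolve_group_size(4, 64) == 64
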