diff --git a/src/hpc/quantized.rs b/src/hpc/quantized.rs index b215588a..cbf0f272 100644 --- a/src/hpc/quantized.rs +++ b/src/hpc/quantized.rs @@ -463,6 +463,112 @@ pub fn dequantize_i2_to_f32(packed: &[u8], params: &QuantParams, n: usize) -> Ve out } +// ── Q4_0 (GGUF-compatible block quantization) ───────────────────── + +/// Q4_0 block size (number of f32 elements per block). +pub const Q4_0_BLOCK_SIZE: usize = 32; + +/// Number of bytes used to pack one Q4_0 block (32 nibbles = 16 bytes). +pub const Q4_0_BYTES_PER_BLOCK: usize = Q4_0_BLOCK_SIZE / 2; + +/// Quantize f32 to Q4_0 (GGUF-compatible) per-block 4-bit quantization. +/// +/// Each block of [`Q4_0_BLOCK_SIZE`] (32) f32 elements is encoded as +/// [`Q4_0_BYTES_PER_BLOCK`] (16) packed bytes plus one f32 scale `d`. +/// +/// Encoding (matches `llama.cpp` / GGUF reference): +/// - `d = x_max / -8.0`, where `x_max` is the block element with the largest +/// magnitude, sign kept (so `d` is positive when that element is negative; +/// `d = 0.0` for an all-zero block) +/// - `q_i = clamp(round(x_i / d) + 8, 0, 15)` (unsigned nibble in `0..=15`) +/// - Within a block, element `j` (`0 <= j < 16`) and element `j + 16` +/// share byte `j`: low nibble holds element `j`, high nibble holds +/// element `j + 16`. (This is the GGUF interleaved layout, NOT the +/// simple "two consecutive elements per byte" layout used by +/// [`quantize_f32_to_i4`].) +/// +/// Returns `(packed_bytes, scales)` where `scales.len() == data.len() / 32` +/// and `packed_bytes.len() == scales.len() * 16`. +/// +/// # Panics +/// +/// Panics if `data.len()` is not a multiple of [`Q4_0_BLOCK_SIZE`]. 
+pub fn quantize_f32_to_q4_0(data: &[f32]) -> (Vec, Vec) { + assert!( + data.len() % Q4_0_BLOCK_SIZE == 0, + "Q4_0 requires data.len() to be a multiple of {}", + Q4_0_BLOCK_SIZE + ); + + let n_blocks = data.len() / Q4_0_BLOCK_SIZE; + let mut packed = vec![0u8; n_blocks * Q4_0_BYTES_PER_BLOCK]; + let mut scales = Vec::with_capacity(n_blocks); + + for b in 0..n_blocks { + let block = &data[b * Q4_0_BLOCK_SIZE..(b + 1) * Q4_0_BLOCK_SIZE]; + + // Find signed max-abs (preserving sign of the extreme element). + let mut amax = 0.0f32; + let mut max_signed = 0.0f32; + for &x in block { + let ax = x.abs(); + if ax > amax { + amax = ax; + max_signed = x; + } + } + + // d = max_signed / -8 ; if all-zero block, d = 0 and all q = 8. + let d = if amax > 0.0 { max_signed / -8.0 } else { 0.0 }; + let id = if d != 0.0 { 1.0 / d } else { 0.0 }; + scales.push(d); + + let byte_off = b * Q4_0_BYTES_PER_BLOCK; + for j in 0..Q4_0_BYTES_PER_BLOCK { + let lo = ((block[j] * id).round() + 8.5).floor().clamp(0.0, 15.0) as u8; + let hi = ((block[j + Q4_0_BYTES_PER_BLOCK] * id).round() + 8.5) + .floor() + .clamp(0.0, 15.0) as u8; + packed[byte_off + j] = (lo & 0x0F) | ((hi & 0x0F) << 4); + } + } + + (packed, scales) +} + +/// Dequantize Q4_0 (GGUF-compatible) packed bytes back to f32. +/// +/// Inverse of [`quantize_f32_to_q4_0`]. `packed.len()` must equal +/// `scales.len() * 16` and the result has length `scales.len() * 32`. +/// +/// # Panics +/// +/// Panics if `packed.len() != scales.len() * Q4_0_BYTES_PER_BLOCK`. 
+pub fn dequantize_q4_0_to_f32(packed: &[u8], scales: &[f32]) -> Vec { + assert_eq!( + packed.len(), + scales.len() * Q4_0_BYTES_PER_BLOCK, + "Q4_0 packed length must equal scales.len() * {}", + Q4_0_BYTES_PER_BLOCK + ); + + let n_blocks = scales.len(); + let mut out = vec![0.0f32; n_blocks * Q4_0_BLOCK_SIZE]; + + for b in 0..n_blocks { + let d = scales[b]; + let byte_off = b * Q4_0_BYTES_PER_BLOCK; + let elem_off = b * Q4_0_BLOCK_SIZE; + for j in 0..Q4_0_BYTES_PER_BLOCK { + let byte = packed[byte_off + j]; + let lo = (byte & 0x0F) as i32 - 8; + let hi = ((byte >> 4) & 0x0F) as i32 - 8; + out[elem_off + j] = lo as f32 * d; + out[elem_off + j + Q4_0_BYTES_PER_BLOCK] = hi as f32 * d; + } + } + + out +} + #[cfg(test)] mod tests { use super::*; @@ -567,4 +673,109 @@ mod tests { // = 0b01_00_11_01 = 0x4D assert_eq!(packed[0], 0b01_00_11_01); } + + #[test] + fn test_i4_boundary_values() { + // With abs_max=7 -> scale=1.0; the i4 grid maps directly: + // input -7 -> q=-7 -> dequant -7.0 + // input 0 -> q= 0 -> dequant 0.0 + // input 7 -> q= 7 -> dequant 7.0 + let data = vec![-7.0f32, -3.0, 0.0, 3.0, 7.0]; + let (packed, params) = quantize_f32_to_i4(&data); + assert!((params.scale - 1.0).abs() < 1e-6); + let recovered = dequantize_i4_to_f32(&packed, ¶ms, data.len()); + assert_eq!(recovered, vec![-7.0, -3.0, 0.0, 3.0, 7.0]); + + // Negative-end clamp: with abs_max=8 -> scale=8/7, the value -8 + // maps to q=-7 (since -8 / (8/7) = -7), dequantizing to -8.0 + // exactly. The 8 grid cell hits q=7 -> dequant 8.0 exactly. 
+ let data2 = vec![-8.0f32, 0.0, 8.0]; + let (packed2, params2) = quantize_f32_to_i4(&data2); + let s = params2.scale; + assert!((s - 8.0 / 7.0).abs() < 1e-6); + let rec2 = dequantize_i4_to_f32(&packed2, ¶ms2, data2.len()); + assert!((rec2[0] - -8.0).abs() < 1e-4); + assert_eq!(rec2[1], 0.0); + assert!((rec2[2] - 8.0).abs() < 1e-4); + } + + #[test] + fn test_q4_0_roundtrip_single_block() { + let mut data = Vec::with_capacity(Q4_0_BLOCK_SIZE); + for i in 0..Q4_0_BLOCK_SIZE { + data.push((i as f32) - 16.0); // values in [-16, 15] + } + let (packed, scales) = quantize_f32_to_q4_0(&data); + assert_eq!(packed.len(), Q4_0_BYTES_PER_BLOCK); + assert_eq!(scales.len(), 1); + let recovered = dequantize_q4_0_to_f32(&packed, &scales); + assert_eq!(recovered.len(), data.len()); + // Max abs in block is 16. With 4-bit signed grid (16 levels), + // expected error <= |d| ≈ 16/8 = 2.0. + let max_abs = 16.0f32; + let tol = max_abs / 8.0 + 1e-4; + for (i, (orig, rec)) in data.iter().zip(recovered.iter()).enumerate() { + assert!((orig - rec).abs() <= tol, "q4_0 roundtrip[{i}]: {orig} vs {rec} (tol {tol})"); + } + } + + #[test] + fn test_q4_0_roundtrip_multi_block() { + // 3 blocks (96 elements), monotonically varying values. 
+ let n = 3 * Q4_0_BLOCK_SIZE; + let data: Vec = (0..n).map(|i| ((i as f32) - 48.0) * 0.25).collect(); + let (packed, scales) = quantize_f32_to_q4_0(&data); + assert_eq!(scales.len(), 3); + assert_eq!(packed.len(), 3 * Q4_0_BYTES_PER_BLOCK); + let recovered = dequantize_q4_0_to_f32(&packed, &scales); + assert_eq!(recovered.len(), n); + for (b, &d) in scales.iter().enumerate() { + let tol = d.abs() + 1e-4; + for j in 0..Q4_0_BLOCK_SIZE { + let i = b * Q4_0_BLOCK_SIZE + j; + assert!( + (data[i] - recovered[i]).abs() <= tol, + "q4_0 multi[{i}] block={b}: {} vs {} tol={tol}", + data[i], + recovered[i] + ); + } + } + } + + #[test] + fn test_q4_0_zero_block() { + let data = vec![0.0f32; Q4_0_BLOCK_SIZE]; + let (packed, scales) = quantize_f32_to_q4_0(&data); + assert_eq!(scales[0], 0.0); + let recovered = dequantize_q4_0_to_f32(&packed, &scales); + for v in recovered { + assert_eq!(v, 0.0); + } + } + + #[test] + fn test_q4_0_packing_layout_interleaved() { + // Verify GGUF interleaved layout: byte j carries element j (low) + // and element j + 16 (high), within one 32-element block. + let mut data = vec![0.0f32; Q4_0_BLOCK_SIZE]; + // Set element 0 to extreme negative so q=15 (since d<0, x/d>0), + // and leave element 16 at 0 so its q=8. + data[0] = -1.0; + data[16] = 0.0; + let (packed, scales) = quantize_f32_to_q4_0(&data); + // d = (-1.0) / -8 = 0.125 ; q[0] = round(-1/0.125)+8 = -8+8 = 0 + // q[16] = 0 + 8 = 8 + assert!(scales[0] > 0.0); + // byte 0: low nibble = q[0] = 0, high nibble = q[16] = 8 + assert_eq!(packed[0] & 0x0F, 0); + assert_eq!((packed[0] >> 4) & 0x0F, 8); + } + + #[test] + #[should_panic] + fn test_q4_0_requires_block_aligned() { + let data = vec![1.0f32; Q4_0_BLOCK_SIZE - 1]; + let _ = quantize_f32_to_q4_0(&data); + } }