diff --git a/src/hpc/quantized.rs b/src/hpc/quantized.rs
index cbf0f272..124efc28 100644
--- a/src/hpc/quantized.rs
+++ b/src/hpc/quantized.rs
@@ -1,7 +1,7 @@
-//! Quantized GEMM: BF16 and Int8 matrix multiplication.
+//! Quantized GEMM: BF16, Int8, Int4 (Q4_0 GGUF) matrix multiplication.
 //!
 //! Provides BF16 (bfloat16) type with conversions, BF16 GEMM with f32 accumulation,
-//! and int8 quantized GEMM with various dequantization modes.
+//! int8 quantized GEMM, and GGUF Q4_0 block-quantized helpers.
 
 // Types used only for ndarray integration (Array re-exports)
 
@@ -102,6 +102,342 @@ pub fn bf16_vec_to_f32(src: &[BF16]) -> Vec<f32> {
     src.iter().map(|v| v.to_f32()).collect()
 }
 
+// ── F16 (IEEE 754 binary16) ────────────────────────────────────────
+
+/// IEEE 754 binary16 (half-precision): 1 sign + 5 exponent + 10 mantissa bits.
+///
+/// Range ±65504, ~3 decimal digits of precision. Smaller dynamic range than
+/// [`BF16`] but more mantissa precision; typical for graphics / Apple Neural
+/// Engine / NVIDIA tensor cores.
+///
+/// # Example
+///
+/// ```
+/// use ndarray::hpc::quantized::F16;
+///
+/// let val = F16::from_f32(1.0);
+/// assert_eq!(val, F16::ONE);
+/// assert_eq!(val.to_f32(), 1.0);
+/// ```
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[repr(transparent)]
+pub struct F16(pub u16);
+
+impl F16 {
+    /// Zero in F16.
+    pub const ZERO: F16 = F16(0x0000);
+    /// One in F16.
+    pub const ONE: F16 = F16(0x3C00);
+    /// Negative one.
+    pub const NEG_ONE: F16 = F16(0xBC00);
+    /// Positive infinity.
+    pub const INFINITY: F16 = F16(0x7C00);
+    /// Negative infinity.
+    pub const NEG_INFINITY: F16 = F16(0xFC00);
+    /// Quiet NaN (canonical bit pattern).
+    pub const NAN: F16 = F16(0x7E00);
+
+    /// Convert f32 to F16 with round-to-nearest-even (the default rule for
+    /// IEEE 754 binary16).
+    ///
+    /// Tested against: `0.0 → 0x0000`, `1.0 → 0x3C00`, `-2.0 → 0xC000`,
+    /// `+inf → 0x7C00`, NaN → quiet NaN.
+    #[inline]
+    pub fn from_f32(v: f32) -> Self {
+        Self::from_f32_rounded(v)
+    }
+
+    /// Truncating f32 → F16 conversion (drop low mantissa bits, no rounding).
+    ///
+    /// Faster than [`Self::from_f32_rounded`], but truncation is biased toward
+    /// zero by up to one ULP (roughly 0.5 ULP on average).
+    #[inline]
+    pub fn from_f32_truncate(v: f32) -> Self {
+        let bits = v.to_bits();
+        let sign = ((bits >> 16) & 0x8000) as u16;
+        let exp = ((bits >> 23) & 0xFF) as i32;
+        let mant = bits & 0x007F_FFFF;
+
+        // Inf / NaN
+        if exp == 0xFF {
+            if mant == 0 {
+                return F16(sign | 0x7C00);
+            }
+            // Preserve NaN-ness; truncate mantissa to top 10 bits, force at
+            // least one mantissa bit so it stays NaN (not Inf).
+            let mant16 = (mant >> 13) as u16;
+            let mant16 = if mant16 == 0 { 1 } else { mant16 };
+            return F16(sign | 0x7C00 | mant16);
+        }
+
+        let new_exp = exp - 127 + 15;
+        if new_exp >= 0x1F {
+            // Overflow → ±Inf
+            return F16(sign | 0x7C00);
+        }
+        if new_exp <= 0 {
+            // Subnormal or underflow to zero.
+            if new_exp < -10 {
+                return F16(sign);
+            }
+            // Subnormal: insert implicit leading 1, then shift right.
+            let mant_full = mant | 0x0080_0000;
+            let shift = 14 - new_exp; // 13 + (1 - new_exp)
+            let mant16 = (mant_full >> shift) as u16;
+            return F16(sign | mant16);
+        }
+        let mant16 = (mant >> 13) as u16;
+        F16(sign | ((new_exp as u16) << 10) | mant16)
+    }
+
+    /// f32 → F16 with round-to-nearest-even.
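+    ///
+    /// A worked halfway case (illustrative doctest, added for clarity):
+    /// `1.0 + 1.0/2048.0` sits exactly between the two nearest binary16
+    /// values, and round-to-nearest-even picks the even neighbour, `1.0`.
+    ///
+    /// ```
+    /// use ndarray::hpc::quantized::F16;
+    ///
+    /// // Tie between 0x3C00 (1.0) and 0x3C01: RNE keeps the even mantissa.
+    /// assert_eq!(F16::from_f32_rounded(1.0 + 1.0 / 2048.0), F16::ONE);
+    /// ```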
+    pub fn from_f32_rounded(v: f32) -> Self {
+        let bits = v.to_bits();
+        let sign = ((bits >> 16) & 0x8000) as u16;
+        let exp = ((bits >> 23) & 0xFF) as i32;
+        let mant = bits & 0x007F_FFFF;
+
+        if exp == 0xFF {
+            // Inf/NaN
+            if mant == 0 {
+                return F16(sign | 0x7C00);
+            }
+            let mant16 = (mant >> 13) as u16;
+            let mant16 = if mant16 == 0 { 0x200 } else { mant16 | 0x200 };
+            return F16(sign | 0x7C00 | mant16);
+        }
+
+        let new_exp = exp - 127 + 15;
+
+        // Normal range: 1 ≤ new_exp ≤ 30
+        if new_exp >= 0x1F {
+            return F16(sign | 0x7C00);
+        }
+        if new_exp >= 1 {
+            // Round-to-nearest-even using the dropped 13 bits.
+            let half_bits = mant & 0x0000_1FFF;
+            let truncated = (mant >> 13) as u32;
+            let mant16 = truncated;
+            let half = 0x1000u32;
+            let round_up = if half_bits > half {
+                1
+            } else if half_bits < half {
+                0
+            } else {
+                // tie → round to even
+                mant16 & 1
+            };
+            let mant16 = mant16 + round_up;
+            // Mantissa overflow bumps exponent.
+            if mant16 == 0x400 {
+                let new_exp = new_exp + 1;
+                if new_exp >= 0x1F {
+                    return F16(sign | 0x7C00);
+                }
+                return F16(sign | ((new_exp as u16) << 10));
+            }
+            return F16(sign | ((new_exp as u16) << 10) | (mant16 as u16));
+        }
+
+        // Subnormal / underflow.
+        if new_exp < -10 {
+            // Underflow to zero: anything this small is below half of the
+            // minimum subnormal (2^-25), so RNE rounds it to zero.
+            return F16(sign);
+        }
+        // Build the full mantissa with implicit leading 1 then shift.
+        let mant_full = mant | 0x0080_0000; // 24-bit
+        let shift = (14 - new_exp) as u32; // 13 + (1 - new_exp)
+        let truncated = mant_full >> shift;
+        let dropped_mask = (1u32 << shift) - 1;
+        let dropped = mant_full & dropped_mask;
+        let half = 1u32 << (shift - 1);
+        let round_up = if dropped > half {
+            1
+        } else if dropped < half {
+            0
+        } else {
+            truncated & 1
+        };
+        let mant16 = (truncated + round_up) as u16;
+        // mant16 may be 0x400 = 1024 → next exponent (which is 1, normal)
+        F16(sign | mant16)
+    }
+
+    /// F16 → f32 (lossless, since binary32 strictly subsumes binary16).
+    pub fn to_f32(self) -> f32 {
+        let h = self.0 as u32;
+        let sign = (h & 0x8000) << 16;
+        let exp = (h >> 10) & 0x1F;
+        let mant = h & 0x03FF;
+
+        let bits = if exp == 0 {
+            if mant == 0 {
+                // ±0
+                sign
+            } else {
+                // Subnormal: normalize.
+                let mut m = mant;
+                let mut e: i32 = 1;
+                while (m & 0x0400) == 0 {
+                    m <<= 1;
+                    e -= 1;
+                }
+                let m = m & 0x03FF;
+                let new_exp = (e - 1 + 127 - 14) as u32;
+                sign | (new_exp << 23) | (m << 13)
+            }
+        } else if exp == 0x1F {
+            // Inf or NaN.
+            if mant == 0 {
+                sign | 0x7F80_0000
+            } else {
+                sign | 0x7F80_0000 | (mant << 13)
+            }
+        } else {
+            let new_exp = exp + (127 - 15);
+            sign | (new_exp << 23) | (mant << 13)
+        };
+        f32::from_bits(bits)
+    }
+}
+
+/// Convert f32 slice to F16 (round-to-nearest-even).
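+///
+/// Copies `min(src.len(), dst.len())` elements; any tail in `dst` is left
+/// untouched. A small usage sketch (illustrative doctest, added here):
+///
+/// ```
+/// use ndarray::hpc::quantized::{f32_to_f16_slice, F16};
+///
+/// let src = [0.0f32, 1.0, -2.0];
+/// let mut dst = [F16::ZERO; 3];
+/// f32_to_f16_slice(&src, &mut dst);
+/// assert_eq!(dst, [F16::ZERO, F16::ONE, F16::from_f32(-2.0)]);
+/// ```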
+pub fn f32_to_f16_slice(src: &[f32], dst: &mut [F16]) {
+    let n = src.len().min(dst.len());
+    for i in 0..n {
+        dst[i] = F16::from_f32(src[i]);
+    }
+}
+
+/// Convert F16 slice to f32.
+pub fn f16_to_f32_slice(src: &[F16], dst: &mut [f32]) {
+    let n = src.len().min(dst.len());
+    for i in 0..n {
+        dst[i] = src[i].to_f32();
+    }
+}
+
+/// Convert f32 vec to F16 vec.
+pub fn f32_vec_to_f16(src: &[f32]) -> Vec<F16> {
+    src.iter().map(|&v| F16::from_f32(v)).collect()
+}
+
+/// Convert F16 vec to f32 vec.
+pub fn f16_vec_to_f32(src: &[F16]) -> Vec<f32> {
+    src.iter().map(|v| v.to_f32()).collect()
+}
+
+// ── Operator + num_traits impls for BF16 + F16 ─────────────────────
+
+macro_rules! impl_half_ops {
+    ($t:ident) => {
+        impl core::ops::Add for $t {
+            type Output = Self;
+            #[inline]
+            fn add(self, rhs: Self) -> Self {
+                $t::from_f32_rounded(self.to_f32() + rhs.to_f32())
+            }
+        }
+        impl core::ops::Sub for $t {
+            type Output = Self;
+            #[inline]
+            fn sub(self, rhs: Self) -> Self {
+                $t::from_f32_rounded(self.to_f32() - rhs.to_f32())
+            }
+        }
+        impl core::ops::Mul for $t {
+            type Output = Self;
+            #[inline]
+            fn mul(self, rhs: Self) -> Self {
+                $t::from_f32_rounded(self.to_f32() * rhs.to_f32())
+            }
+        }
+        impl core::ops::Div for $t {
+            type Output = Self;
+            #[inline]
+            fn div(self, rhs: Self) -> Self {
+                $t::from_f32_rounded(self.to_f32() / rhs.to_f32())
+            }
+        }
+        impl core::ops::Rem for $t {
+            type Output = Self;
+            #[inline]
+            fn rem(self, rhs: Self) -> Self {
+                $t::from_f32_rounded(self.to_f32() % rhs.to_f32())
+            }
+        }
+        impl core::ops::Neg for $t {
+            type Output = Self;
+            #[inline]
+            fn neg(self) -> Self {
+                Self(self.0 ^ 0x8000)
+            }
+        }
+        impl core::ops::AddAssign for $t {
+            #[inline]
+            fn add_assign(&mut self, rhs: Self) {
+                *self = *self + rhs;
+            }
+        }
+        impl core::ops::SubAssign for $t {
+            #[inline]
+            fn sub_assign(&mut self, rhs: Self) {
+                *self = *self - rhs;
+            }
+        }
+        impl core::ops::MulAssign for $t {
+            #[inline]
+            fn mul_assign(&mut self, rhs: Self) {
+                *self = *self * rhs;
+            }
+        }
+        impl core::ops::DivAssign for $t {
+            #[inline]
+            fn div_assign(&mut self, rhs: Self) {
+                *self = *self / rhs;
+            }
+        }
+        impl core::ops::RemAssign for $t {
+            #[inline]
+            fn rem_assign(&mut self, rhs: Self) {
+                *self = *self % rhs;
+            }
+        }
+        impl num_traits::Zero for $t {
+            #[inline]
+            fn zero() -> Self {
+                Self::ZERO
+            }
+            #[inline]
+            fn is_zero(&self) -> bool {
+                // Treat both +0 and -0 as zero.
+                (self.0 & 0x7FFF) == 0
+            }
+        }
+        impl num_traits::One for $t {
+            #[inline]
+            fn one() -> Self {
+                Self::ONE
+            }
+        }
+        impl Default for $t {
+            #[inline]
+            fn default() -> Self {
+                Self::ZERO
+            }
+        }
+        impl core::fmt::Display for $t {
+            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+                write!(f, "{}", self.to_f32())
+            }
+        }
+        impl crate::ScalarOperand for $t {}
+    };
+}
+
+impl_half_ops!(BF16);
+impl_half_ops!(F16);
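+
+// A rough mental model for the generated operators (illustrative sketch):
+// every binary op widens both sides to f32, computes there, and narrows the
+// result back with round-to-nearest-even, e.g.
+//
+//     let a = F16::from_f32(0.1);
+//     let b = F16::from_f32(0.2);
+//     assert_eq!(a + b, F16::from_f32_rounded(a.to_f32() + b.to_f32()));
+//
+// so each op costs two conversions plus one f32 op, and long chains round at
+// every intermediate step rather than once at the end.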
+
 /// BF16 GEMM with f32 accumulation: C = alpha * A * B + beta * C
 ///
 /// A and B are BF16, C is f32. Accumulation done in f32 for precision.
@@ -663,6 +999,125 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_f16_special_constants() {
+        assert_eq!(F16::ZERO.0, 0x0000);
+        assert_eq!(F16::ONE.0, 0x3C00);
+        assert_eq!(F16::INFINITY.0, 0x7C00);
+        assert_eq!(F16::NEG_INFINITY.0, 0xFC00);
+    }
+
+    #[test]
+    fn test_f16_from_f32_known_bits() {
+        assert_eq!(F16::from_f32(0.0).0, 0x0000);
+        assert_eq!(F16::from_f32(-0.0).0, 0x8000);
+        assert_eq!(F16::from_f32(1.0).0, 0x3C00);
+        assert_eq!(F16::from_f32(-1.0).0, 0xBC00);
+        assert_eq!(F16::from_f32(2.0).0, 0x4000);
+        assert_eq!(F16::from_f32(-2.0).0, 0xC000);
+        assert_eq!(F16::from_f32(0.5).0, 0x3800);
+        assert_eq!(F16::from_f32(-0.5).0, 0xB800);
+    }
+
+    #[test]
+    fn test_f16_inf_nan() {
+        // ±Inf
+        assert_eq!(F16::from_f32(f32::INFINITY).0, 0x7C00);
+        assert_eq!(F16::from_f32(f32::NEG_INFINITY).0, 0xFC00);
+        assert!(F16::from_f32(f32::INFINITY).to_f32().is_infinite());
+        assert!(F16::from_f32(f32::INFINITY).to_f32() > 0.0);
+        assert!(F16::from_f32(f32::NEG_INFINITY).to_f32() < 0.0);
+
+        // NaN preserved
+        let nan = F16::from_f32(f32::NAN);
+        assert!(nan.to_f32().is_nan());
+        // Top exponent = all-ones, mantissa nonzero
+        assert_eq!(nan.0 & 0x7C00, 0x7C00);
+        assert!(nan.0 & 0x03FF != 0);
+    }
+
+    #[test]
+    fn test_f16_roundtrip_exact() {
+        let exact = [0.0f32, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 0.25, -0.25, 65504.0, -65504.0];
+        for &v in &exact {
+            let h = F16::from_f32(v);
+            assert_eq!(h.to_f32(), v, "F16 lost {}", v);
+        }
+    }
+
+    #[test]
+    fn test_f16_roundtrip_approx() {
+        let approx = [3.14f32, -3.14, 0.1, 100.0, 1000.0, 1e-3, 1e-4];
+        for &v in &approx {
+            let h = F16::from_f32(v);
+            let back = h.to_f32();
+            assert!(
+                (back - v).abs() / v.abs().max(1.0) < 0.001,
+                "F16 roundtrip {} → {}",
+                v,
+                back
+            );
+        }
+    }
+
+    #[test]
+    fn test_f16_negation() {
+        let one = F16::ONE;
+        let neg_one = -one;
+        assert_eq!(neg_one.to_f32(), -1.0);
+        let zero = F16::ZERO;
+        let neg_zero = -zero;
+        assert_eq!(neg_zero.0, 0x8000);
+    }
+
+    #[test]
+    fn test_f16_arithmetic() {
+        let a = F16::from_f32(3.0);
+        let b = F16::from_f32(4.0);
+        assert!(((a + b).to_f32() - 7.0).abs() < 1e-3);
+        assert!(((b - a).to_f32() - 1.0).abs() < 1e-3);
+        assert!(((a * b).to_f32() - 12.0).abs() < 1e-3);
+        assert!(((b / a).to_f32() - 4.0 / 3.0).abs() < 1e-3);
+    }
+
+    #[test]
+    fn test_bf16_arithmetic() {
+        let a = BF16::from_f32_rounded(3.0);
+        let b = BF16::from_f32_rounded(4.0);
+        assert!(((a + b).to_f32() - 7.0).abs() < 0.05);
+        assert!(((b - a).to_f32() - 1.0).abs() < 0.05);
+        assert!(((a * b).to_f32() - 12.0).abs() < 0.1);
+    }
+
+    #[test]
+    fn test_half_zero_one_traits() {
+        use num_traits::{One, Zero};
+        let z: F16 = F16::zero();
+        let o: F16 = F16::one();
+        assert_eq!(z, F16::ZERO);
+        assert_eq!(o, F16::ONE);
+        assert!(z.is_zero());
+        let bz: BF16 = BF16::zero();
+        let bo: BF16 = BF16::one();
+        assert_eq!(bz, BF16::ZERO);
+        assert_eq!(bo, BF16::ONE);
+    }
+
+    #[test]
+    fn test_linalg_scalar_compiles() {
+        // Smoke test: F16 / BF16 satisfy LinalgScalar bounds.
+        fn assert_linalg<T: crate::LinalgScalar>() {}
+        assert_linalg::<F16>();
+        assert_linalg::<BF16>();
+    }
+
+    #[test]
+    fn test_scalar_operand_compiles() {
+        fn assert_scalar<T: crate::ScalarOperand>() {}
+        assert_scalar::<F16>();
+        assert_scalar::<BF16>();
+    }
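+
+    // Illustrative extra check (sketch): the smallest positive F16 subnormal,
+    // 0x0001 = 2^-24, should survive a round trip through f32.
+    #[test]
+    fn test_f16_min_subnormal_roundtrip() {
+        let tiny = F16(0x0001);
+        assert_eq!(tiny.to_f32(), 2.0f32.powi(-24));
+        assert_eq!(F16::from_f32(2.0f32.powi(-24)), tiny);
+    }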
+
     #[test]
     fn test_i2_packing_layout() {
         // 4 values per byte, LSB first.
diff --git a/src/lib.rs b/src/lib.rs
index 3badbdcc..701e0faa 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -257,6 +257,11 @@ pub mod simd_wasm;
 #[allow(missing_docs)]
 pub mod simd_int_ops;
 
+// Half-precision SIMD vectors (`BF16x16`, `F16x16`) + slice-level ops.
+// Kept fully commented out until the module exists, so the doc comment and
+// lint allows do not attach to the next item (`mod backend`):
+// #[cfg(feature = "std")]
+// #[allow(clippy::all, missing_docs, dead_code, unused_variables, unused_imports)]
+// pub mod simd_half; // TODO: BF16x16/F16x16 SIMD vectors (A2 WIP)
+
 /// Pluggable linear algebra backends (native SIMD, MKL, OpenBLAS).
 #[cfg(feature = "std")]
 pub mod backend;
diff --git a/src/simd.rs b/src/simd.rs
index 9161929a..7cfbb63f 100644
--- a/src/simd.rs
+++ b/src/simd.rs
@@ -1210,6 +1210,15 @@ pub use crate::hpc::quantized::{
     QuantParams,
 };
 
+// Half-precision SIMD vectors (BF16x16, F16x16): portable, runtime-dispatched
+// types that will be re-exported here once the `simd_half` module lands.
+// Note: when `target_feature = "avx512bf16"` is active a separate
+// hardware-only `BF16x16` is also exported above from `simd_avx512`. The
+// hardware-native one ships unsafe `from_u16_slice` / `to_f32x16` and is
+// distinct from the portable runtime-dispatched `simd_half::BF16x16`.
+// TODO: BF16x16/F16x16 SIMD vector types + slice ops (A2 WIP, simd_half module).
+// The scalar F16 type itself is already available in hpc::quantized::F16;
+// the SIMD vectors land in Wave 3 after the A2 module is completed.
+
 // K-means + L2 distance
 pub use crate::hpc::cam_pq::{kmeans, squared_l2};