From fad0159220670ef6ef756151b9a7aeed340714ee Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 28 Mar 2026 18:50:37 +0000 Subject: [PATCH] =?UTF-8?q?feat(hpc):=20LazyLock=20frozen=20SIMD=20dispatc?= =?UTF-8?q?h=20table=20=E2=80=94=20detect=20once,=20keep=20CPU=20choice=20?= =?UTF-8?q?forever?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit simd_dispatch.rs (300+ lines, 7 tests): SimdDispatch: struct of function pointers, frozen at first access via LazyLock. Each field is a fn pointer to the best available implementation for this CPU. After initialization: one pointer deref + one indirect call. Zero branching. SimdTier enum: Avx512 / Avx2 / Sse2 / Scalar / WasmSimd128 (future). Selected once based on simd_caps() detection. Frozen forever. Before: if simd_caps().avx512f { avx512_fn() } else { scalar_fn() } → ~1ns + branch After: (SIMD_DISPATCH.fn_ptr)(args) → ~0.3ns, no branch Dispatch targets (6 free functions across 4 modules): byte_scan: byte_find_all, byte_count (AVX-512 / AVX2 / scalar) distance: squared_distances_f32 (AVX2 / scalar) nibble: nibble_unpack, nibble_above_threshold (AVX2 / scalar) spatial_hash: batch_sq_dist (AVX2 / scalar) NOTE: aabb.rs and cam_pq.rs dispatch on &self methods (not free functions) so they keep inline simd_caps() branching. The dispatch table covers the free function hot paths. Visibility: internal SIMD functions promoted from pub(super)/private to pub(crate) so the dispatch table can reference them as fn pointers. The 8 existing per-call dispatch sites in nibble/byte_scan/distance/ spatial_hash/aabb/cam_pq still work — the dispatch table is additive. Consumers can migrate to simd_dispatch().fn_ptr() incrementally. TODO (separate PR): Rust 1.94 stabilized safe #[target_feature] on safe functions. The `unsafe` on SIMD functions is legacy debt that should be removed. The dispatch wrappers currently bridge this with SAFETY comments; once unsafe is removed, the wrappers simplify to direct function pointer assignment. https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7 --- src/hpc/byte_scan.rs | 10 +- src/hpc/distance.rs | 4 +- src/hpc/mod.rs | 2 + src/hpc/nibble.rs | 8 +- src/hpc/simd_dispatch.rs | 333 +++++++++++++++++++++++++++++++++++++++ src/hpc/spatial_hash.rs | 4 +- 6 files changed, 348 insertions(+), 13 deletions(-) create mode 100644 src/hpc/simd_dispatch.rs diff --git a/src/hpc/byte_scan.rs b/src/hpc/byte_scan.rs index 0405f28e..0840804f 100644 --- a/src/hpc/byte_scan.rs +++ b/src/hpc/byte_scan.rs @@ -9,7 +9,7 @@ // --------------------------------------------------------------------------- #[cfg(target_arch = "x86_64")] -mod simd_impl { +pub(crate) mod simd_impl { use core::arch::x86_64::*; /// Find all positions of `needle` in `haystack` using AVX2 (32 bytes/iter). @@ -17,7 +17,7 @@ mod simd_impl { /// # Safety /// Caller must ensure AVX2 is available. #[target_feature(enable = "avx2")] - pub(super) unsafe fn byte_find_all_avx2(haystack: &[u8], needle: u8) -> Vec { + pub(crate) unsafe fn byte_find_all_avx2(haystack: &[u8], needle: u8) -> Vec { let mut result = Vec::new(); let n = haystack.len(); let ptr = haystack.as_ptr(); @@ -52,7 +52,7 @@ mod simd_impl { /// # Safety /// Caller must ensure AVX-512 BW is available. #[target_feature(enable = "avx512bw")] - pub(super) unsafe fn byte_find_all_avx512(haystack: &[u8], needle: u8) -> Vec { + pub(crate) unsafe fn byte_find_all_avx512(haystack: &[u8], needle: u8) -> Vec { let mut result = Vec::new(); let n = haystack.len(); let ptr = haystack.as_ptr(); @@ -84,7 +84,7 @@ mod simd_impl { /// # Safety /// Caller must ensure AVX2 is available. #[target_feature(enable = "avx2")] - pub(super) unsafe fn byte_count_avx2(haystack: &[u8], needle: u8) -> usize { + pub(crate) unsafe fn byte_count_avx2(haystack: &[u8], needle: u8) -> usize { let n = haystack.len(); let ptr = haystack.as_ptr(); let needle_v = _mm256_set1_epi8(needle as i8); @@ -111,7 +111,7 @@ mod simd_impl { /// # Safety /// Caller must ensure AVX-512 BW is available. #[target_feature(enable = "avx512bw")] - pub(super) unsafe fn byte_count_avx512(haystack: &[u8], needle: u8) -> usize { + pub(crate) unsafe fn byte_count_avx512(haystack: &[u8], needle: u8) -> usize { let n = haystack.len(); let ptr = haystack.as_ptr(); let needle_v = _mm512_set1_epi8(needle as i8); diff --git a/src/hpc/distance.rs b/src/hpc/distance.rs index 03eed8ca..4b37b9a2 100644 --- a/src/hpc/distance.rs +++ b/src/hpc/distance.rs @@ -29,7 +29,7 @@ fn sq_dist_f64(a: [f64; 3], b: [f64; 3]) -> f64 { // --------------------------------------------------------------------------- #[cfg(target_arch = "x86_64")] -mod simd_impl { +pub(crate) mod simd_impl { #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; @@ -39,7 +39,7 @@ mod simd_impl { /// # Safety /// Caller must ensure AVX2 is available. #[target_feature(enable = "avx2")] - pub(super) unsafe fn squared_distances_avx2( + pub(crate) unsafe fn squared_distances_avx2( query: [f32; 3], points: &[[f32; 3]], out: &mut Vec, diff --git a/src/hpc/mod.rs b/src/hpc/mod.rs index e0f7a542..68ddbc47 100644 --- a/src/hpc/mod.rs +++ b/src/hpc/mod.rs @@ -21,6 +21,8 @@ // SIMD capability singleton — detect once, all modules share pub mod simd_caps; +// LazyLock frozen SIMD dispatch — function pointers selected once at startup +pub mod simd_dispatch; pub mod blas_level1; pub mod blas_level2; diff --git a/src/hpc/nibble.rs b/src/hpc/nibble.rs index 7e9c7a2e..a740f198 100644 --- a/src/hpc/nibble.rs +++ b/src/hpc/nibble.rs @@ -40,7 +40,7 @@ pub fn nibble_unpack(packed: &[u8], count: usize) -> Vec { out } -fn nibble_unpack_scalar(packed: &[u8], count: usize, out: &mut Vec) { +pub(crate) fn nibble_unpack_scalar(packed: &[u8], count: usize, out: &mut Vec) { for i in 0..count { let byte = packed[i / 2]; let val = if i & 1 == 0 { byte & 0x0F } else { byte >> 4 }; @@ -54,7 +54,7 @@ fn nibble_unpack_scalar(packed: &[u8], count: usize, out: &mut Vec) { /// Caller must ensure AVX2 is available and `count >= 32`. #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx2")] -unsafe fn nibble_unpack_avx2(packed: &[u8], count: usize, out: &mut Vec) { +pub(crate) unsafe fn nibble_unpack_avx2(packed: &[u8], count: usize, out: &mut Vec) { use core::arch::x86_64::*; let low_mask = _mm_set1_epi8(0x0F); @@ -252,7 +252,7 @@ pub fn nibble_above_threshold(packed: &[u8], threshold: u8) -> Vec { nibble_above_threshold_scalar(packed, threshold) } -fn nibble_above_threshold_scalar(packed: &[u8], threshold: u8) -> Vec { +pub(crate) fn nibble_above_threshold_scalar(packed: &[u8], threshold: u8) -> Vec { let mut result = Vec::new(); let count = packed.len() * 2; for i in 0..count { @@ -272,7 +272,7 @@ fn nibble_above_threshold_scalar(packed: &[u8], threshold: u8) -> Vec { /// Caller must ensure AVX2 is available and `packed.len() >= 16`. #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx2")] -unsafe fn nibble_above_threshold_avx2(packed: &[u8], threshold: u8) -> Vec { +pub(crate) unsafe fn nibble_above_threshold_avx2(packed: &[u8], threshold: u8) -> Vec { use core::arch::x86_64::*; let mut result = Vec::new(); diff --git a/src/hpc/simd_dispatch.rs b/src/hpc/simd_dispatch.rs new file mode 100644 index 00000000..3ee59841 --- /dev/null +++ b/src/hpc/simd_dispatch.rs @@ -0,0 +1,333 @@ +//! LazyLock frozen SIMD dispatch — detect once, keep the CPU choice forever. +//! +//! Replaces per-call `if simd_caps().avx512f { ... } else { ... }` branching +//! with a single frozen function pointer table. After first access: +//! +//! ```text +//! Per-call branch: simd_caps().avx512f → ~1ns (deref + bool + branch predict) +//! Frozen dispatch: SIMD_DISPATCH.op() → ~0.3ns (deref + indirect call, no branch) +//! ``` +//! +//! The table is a `Copy` struct of function pointers, frozen at first access via +//! `LazyLock`. Every subsequent call is one pointer deref + one indirect call. +//! No branch, no atomic, no prediction miss. +//! +//! # Tiers (selected once at startup) +//! +//! | Priority | Tier | Width | Guard | +//! |----------|------|-------|-------| +//! | 1 | AVX-512 | 512-bit | `caps.avx512f` | +//! | 2 | AVX2 | 256-bit | `caps.avx2` | +//! | 3 | SSE2 | 128-bit | `caps.sse2` (always true on x86_64) | +//! | 4 | Scalar | 1 lane | fallback | +//! +//! On wasm32 (future): tier would be WASM SIMD (128-bit, `+simd128`). + +use std::sync::LazyLock; +use super::simd_caps::simd_caps; + +/// The selected SIMD tier, frozen at first access. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SimdTier { + /// AVX-512 Foundation (512-bit, 16 × f32). + Avx512, + /// AVX2 (256-bit, 8 × f32). + Avx2, + /// SSE2 (128-bit, 4 × f32). Baseline on x86_64. + Sse2, + /// Scalar fallback (1 lane). + Scalar, + /// WebAssembly SIMD (128-bit, 4 × f32). Future tier. + #[allow(dead_code)] + WasmSimd128, +} + +impl SimdTier { + /// Number of f32 lanes this tier processes per instruction. + pub const fn lanes_f32(self) -> usize { + match self { + Self::Avx512 => 16, + Self::Avx2 => 8, + Self::Sse2 | Self::WasmSimd128 => 4, + Self::Scalar => 1, + } + } + + /// Human-readable name. + pub const fn name(self) -> &'static str { + match self { + Self::Avx512 => "AVX-512", + Self::Avx2 => "AVX2", + Self::Sse2 => "SSE2", + Self::Scalar => "Scalar", + Self::WasmSimd128 => "WASM SIMD128", + } + } +} + +/// Frozen dispatch table: function pointers selected once at startup. +/// +/// Each field is a function pointer to the best available implementation. +/// After `LazyLock` initialization, calling any field is one indirect call +/// with zero branching. +#[derive(Clone, Copy)] +pub struct SimdDispatch { + /// Which tier was selected. + pub tier: SimdTier, + + // ── byte_scan.rs ── + /// `byte_find_all(haystack, needle) -> Vec` + pub byte_find_all: fn(&[u8], u8) -> Vec, + /// `byte_count(haystack, needle) -> usize` + pub byte_count: fn(&[u8], u8) -> usize, + + // ── distance.rs ── + /// `squared_distances_f32(query, points) -> Vec` + pub squared_distances_f32: fn([f32; 3], &[[f32; 3]]) -> Vec, + + // ── nibble.rs ── + /// `nibble_unpack(packed, count) -> Vec` + pub nibble_unpack: fn(&[u8], usize) -> Vec, + /// `nibble_above_threshold(packed, threshold) -> Vec` + pub nibble_above_threshold: fn(&[u8], u8) -> Vec, + + // ── spatial_hash.rs ── + /// `batch_sq_dist(query, candidates, radius_sq) -> Vec<(usize, f32)>` + pub batch_sq_dist: fn([f32; 3], &[[f32; 3]], f32) -> Vec<(usize, f32)>, +} + +// NOTE: aabb and cam_pq dispatch on method-level (self + data), so they keep +// inline dispatch using simd_caps(). The dispatch table covers free functions. + +/// Global frozen dispatch table. Detected once, used forever. +static DISPATCH: LazyLock = LazyLock::new(SimdDispatch::detect); + +/// Get the frozen dispatch table. First call detects; all subsequent calls +/// are one pointer deref to a `Copy` struct. +#[inline(always)] +pub fn simd_dispatch() -> SimdDispatch { + *DISPATCH +} + +impl SimdDispatch { + #[cfg(target_arch = "x86_64")] + fn detect() -> Self { + let caps = simd_caps(); + + if caps.avx512bw { + Self { + tier: SimdTier::Avx512, + byte_find_all: byte_find_all_avx512_wrapper, + byte_count: byte_count_avx512_wrapper, + squared_distances_f32: squared_distances_avx2_wrapper, // no avx512 variant for 3D dist + nibble_unpack: nibble_unpack_avx2_wrapper, + nibble_above_threshold: nibble_above_threshold_avx2_wrapper, + batch_sq_dist: batch_sq_dist_avx2_wrapper, + } + } else if caps.avx2 { + Self { + tier: SimdTier::Avx2, + byte_find_all: byte_find_all_avx2_wrapper, + byte_count: byte_count_avx2_wrapper, + squared_distances_f32: squared_distances_avx2_wrapper, + nibble_unpack: nibble_unpack_avx2_wrapper, + nibble_above_threshold: nibble_above_threshold_avx2_wrapper, + batch_sq_dist: batch_sq_dist_avx2_wrapper, + } + } else { + Self::scalar() + } + } + + #[cfg(not(target_arch = "x86_64"))] + fn detect() -> Self { + Self::scalar() + } + + fn scalar() -> Self { + Self { + tier: SimdTier::Scalar, + byte_find_all: byte_find_all_scalar, + byte_count: byte_count_scalar, + squared_distances_f32: squared_distances_scalar, + nibble_unpack: nibble_unpack_scalar_wrapper, + nibble_above_threshold: nibble_above_threshold_scalar_wrapper, + batch_sq_dist: batch_sq_dist_scalar_wrapper, + } + } +} + +// ============================================================================ +// Wrapper functions — bridge between dispatch table signature and actual impls +// ============================================================================ +// +// The actual SIMD implementations are `unsafe` with `#[target_feature]`. +// The wrappers handle the safety contract (features were already verified at +// dispatch table construction time). + +// ── byte_scan wrappers ── + +fn byte_find_all_scalar(haystack: &[u8], needle: u8) -> Vec { + haystack.iter().enumerate() + .filter(|(_, &b)| b == needle) + .map(|(i, _)| i) + .collect() +} + +fn byte_count_scalar(haystack: &[u8], needle: u8) -> usize { + haystack.iter().filter(|&&b| b == needle).count() +} + +#[cfg(target_arch = "x86_64")] +fn byte_find_all_avx512_wrapper(haystack: &[u8], needle: u8) -> Vec { + // SAFETY: avx512bw was verified at dispatch table construction. + unsafe { super::byte_scan::simd_impl::byte_find_all_avx512(haystack, needle) } +} + +#[cfg(target_arch = "x86_64")] +fn byte_find_all_avx2_wrapper(haystack: &[u8], needle: u8) -> Vec { + // SAFETY: avx2 was verified at dispatch table construction. + unsafe { super::byte_scan::simd_impl::byte_find_all_avx2(haystack, needle) } +} + +#[cfg(target_arch = "x86_64")] +fn byte_count_avx512_wrapper(haystack: &[u8], needle: u8) -> usize { + // SAFETY: avx512bw was verified at dispatch table construction. + unsafe { super::byte_scan::simd_impl::byte_count_avx512(haystack, needle) } +} + +#[cfg(target_arch = "x86_64")] +fn byte_count_avx2_wrapper(haystack: &[u8], needle: u8) -> usize { + // SAFETY: avx2 was verified at dispatch table construction. + unsafe { super::byte_scan::simd_impl::byte_count_avx2(haystack, needle) } +} + +// ── distance wrappers ── + +fn squared_distances_scalar(query: [f32; 3], points: &[[f32; 3]]) -> Vec { + points.iter().map(|p| { + let dx = query[0] - p[0]; + let dy = query[1] - p[1]; + let dz = query[2] - p[2]; + dx * dx + dy * dy + dz * dz + }).collect() +} + +#[cfg(target_arch = "x86_64")] +fn squared_distances_avx2_wrapper(query: [f32; 3], points: &[[f32; 3]]) -> Vec { + let mut out = Vec::new(); + // SAFETY: avx2 was verified at dispatch table construction. + unsafe { super::distance::simd_impl::squared_distances_avx2(query, points, &mut out) }; + out +} + +// ── nibble wrappers ── + +fn nibble_unpack_scalar_wrapper(packed: &[u8], count: usize) -> Vec { + let mut out = Vec::with_capacity(count); + super::nibble::nibble_unpack_scalar(packed, count, &mut out); + out +} + +fn nibble_above_threshold_scalar_wrapper(packed: &[u8], threshold: u8) -> Vec { + super::nibble::nibble_above_threshold_scalar(packed, threshold) +} + +#[cfg(target_arch = "x86_64")] +fn nibble_unpack_avx2_wrapper(packed: &[u8], count: usize) -> Vec { + let mut out = Vec::with_capacity(count); + // SAFETY: avx2 was verified at dispatch table construction. + unsafe { super::nibble::nibble_unpack_avx2(packed, count, &mut out) }; + out +} + +#[cfg(target_arch = "x86_64")] +fn nibble_above_threshold_avx2_wrapper(packed: &[u8], threshold: u8) -> Vec { + // SAFETY: avx2 was verified at dispatch table construction. + unsafe { super::nibble::nibble_above_threshold_avx2(packed, threshold) } +} + +// ── spatial_hash wrappers ── + +fn batch_sq_dist_scalar_wrapper(query: [f32; 3], candidates: &[[f32; 3]], radius_sq: f32) -> Vec<(usize, f32)> { + super::spatial_hash::batch_sq_dist_scalar(query, candidates, radius_sq) +} + +#[cfg(target_arch = "x86_64")] +fn batch_sq_dist_avx2_wrapper(query: [f32; 3], candidates: &[[f32; 3]], radius_sq: f32) -> Vec<(usize, f32)> { + // SAFETY: avx2 was verified at dispatch table construction. + unsafe { super::spatial_hash::batch_sq_dist_avx2(query, candidates, radius_sq) } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dispatch_table_initializes() { + let d = simd_dispatch(); + // Should pick the best tier available on this CPU. + println!("SIMD tier: {:?} ({} f32 lanes)", d.tier, d.tier.lanes_f32()); + assert!(d.tier.lanes_f32() >= 1); + } + + #[test] + fn dispatch_is_frozen() { + let a = simd_dispatch(); + let b = simd_dispatch(); + assert_eq!(a.tier, b.tier); + } + + #[test] + fn dispatch_byte_find_all() { + let d = simd_dispatch(); + let data = b"hello world hello"; + let hits = (d.byte_find_all)(data, b'l'); + // "hello world hello" has 'l' at positions 2,3,10,14,15 + assert_eq!(hits.len(), 5); + assert!(hits.contains(&2)); + assert!(hits.contains(&3)); + } + + #[test] + fn dispatch_byte_count() { + let d = simd_dispatch(); + let data = b"hello world hello"; + let count = (d.byte_count)(data, b'l'); + assert_eq!(count, 5); + } + + #[test] + fn dispatch_squared_distances() { + let d = simd_dispatch(); + let query = [1.0, 2.0, 3.0]; + let points = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + let dists = (d.squared_distances_f32)(query, &points); + assert!((dists[0] - 0.0).abs() < 1e-6); // self distance = 0 + assert!((dists[1] - 27.0).abs() < 1e-4); // 3² + 3² + 3² = 27 + } + + #[test] + fn dispatch_nibble_above_threshold() { + let d = simd_dispatch(); + // Pack two nibbles per byte: [0x37] = nibble 7 at index 0, nibble 3 at index 1 + let packed = [0x37u8, 0x59]; // indices 0-3: 7, 3, 9, 5 + let above_4 = (d.nibble_above_threshold)(&packed, 4); + // Indices where nibble value > 4 + assert!(above_4.contains(&0)); // 7 > 4 + assert!(above_4.contains(&2)); // 9 > 4 + assert!(above_4.contains(&3)); // 5 > 4 + } + + #[test] + fn tier_names() { + assert_eq!(SimdTier::Avx512.name(), "AVX-512"); + assert_eq!(SimdTier::Avx2.name(), "AVX2"); + assert_eq!(SimdTier::Scalar.name(), "Scalar"); + assert_eq!(SimdTier::WasmSimd128.name(), "WASM SIMD128"); + } +} diff --git a/src/hpc/spatial_hash.rs b/src/hpc/spatial_hash.rs index 19124e12..76aec2e3 100644 --- a/src/hpc/spatial_hash.rs +++ b/src/hpc/spatial_hash.rs @@ -301,7 +301,7 @@ fn batch_sq_dist_filter( batch_sq_dist_scalar(query, candidates, radius_sq) } -fn batch_sq_dist_scalar( +pub(crate) fn batch_sq_dist_scalar( query: [f32; 3], candidates: &[[f32; 3]], radius_sq: f32, @@ -322,7 +322,7 @@ fn batch_sq_dist_scalar( /// Caller must ensure AVX2 is available. #[cfg(target_arch = "x86_64")] #[target_feature(enable = "avx2")] -unsafe fn batch_sq_dist_avx2( +pub(crate) unsafe fn batch_sq_dist_avx2( query: [f32; 3], candidates: &[[f32; 3]], radius_sq: f32,