Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/hpc/aabb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,15 @@ fn sq_dist_point_aabb(point: [f32; 3], aabb: &Aabb) -> f32 {
pub fn aabb_intersect_batch(query: &Aabb, candidates: &[Aabb]) -> Vec<bool> {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") && candidates.len() >= 16 {
// SAFETY: avx512f detected, enough candidates for batch processing.
let caps = super::simd_caps::simd_caps();
if caps.avx512f && candidates.len() >= 16 {
// SAFETY: avx512f detected via simd_caps singleton.
unsafe {
return aabb_intersect_batch_avx512(query, candidates);
}
}
if is_x86_feature_detected!("sse4.1") {
// SAFETY: sse4.1 detected, slice access within bounds.
if caps.sse41 {
// SAFETY: sse4.1 detected via simd_caps singleton.
unsafe {
return aabb_intersect_batch_sse41(query, candidates);
}
Expand Down Expand Up @@ -294,8 +295,8 @@ unsafe fn aabb_intersect_batch_sse41(query: &Aabb, candidates: &[Aabb]) -> Vec<b
pub fn ray_aabb_slab_test_batch(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") && aabbs.len() >= 16 {
// SAFETY: avx512f detected, enough AABBs for batch processing.
if super::simd_caps::simd_caps().avx512f && aabbs.len() >= 16 {
// SAFETY: avx512f detected via simd_caps singleton.
unsafe {
return ray_aabb_slab_test_avx512(ray, aabbs);
}
Expand Down Expand Up @@ -455,8 +456,8 @@ unsafe fn ray_aabb_slab_test_avx512(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Ve
pub fn aabb_expand_batch(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("sse2") {
// SAFETY: sse2 detected, operating on mutable slice in-bounds.
if super::simd_caps::simd_caps().sse2 {
// SAFETY: sse2 detected via simd_caps singleton.
unsafe {
aabb_expand_batch_sse2(aabbs, dx, dy, dz);
return;
Expand Down
15 changes: 9 additions & 6 deletions src/hpc/bitwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,16 @@ pub fn hamming_top_k_raw(
fn dispatch_hamming(a: &[u8], b: &[u8]) -> u64 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512bw") {
let caps = super::simd_caps::simd_caps();
if caps.has_avx512_bw_popcnt() {
// SAFETY: checked VPOPCNTDQ + BW
return unsafe { crate::backend::kernels_avx512::hamming_distance(a, b) };
}
if is_x86_feature_detected!("avx512bw") {
if caps.avx512bw {
// SAFETY: checked AVX-512 BW — uses 512-bit vpshufb (64B/iter)
return unsafe { hamming_avx512bw(a, b) };
}
if is_x86_feature_detected!("avx2") {
if caps.avx2 {
// SAFETY: checked AVX2 — uses 256-bit vpshufb (32B/iter)
return unsafe { hamming_avx2(a, b) };
}
Expand All @@ -277,11 +278,12 @@ fn dispatch_hamming(a: &[u8], b: &[u8]) -> u64 {
fn dispatch_popcount(a: &[u8]) -> u64 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512vpopcntdq") {
let caps = super::simd_caps::simd_caps();
if caps.avx512vpopcntdq {
// SAFETY: checked VPOPCNTDQ
return unsafe { crate::backend::kernels_avx512::popcount(a) };
}
if is_x86_feature_detected!("avx512bw") {
if caps.avx512bw {
// SAFETY: checked AVX-512 BW — uses 512-bit vpshufb
return unsafe { popcount_avx512bw(a) };
}
Expand All @@ -292,7 +294,8 @@ fn dispatch_popcount(a: &[u8]) -> u64 {
fn dispatch_hamming_batch(query: &[u8], database: &[u8], num_rows: usize, row_bytes: usize) -> Vec<u64> {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512bw") {
let caps = super::simd_caps::simd_caps();
if caps.has_avx512_bw_popcnt() {
// SAFETY: checked VPOPCNTDQ + BW
return unsafe { crate::backend::kernels_avx512::hamming_batch(query, database, num_rows, row_bytes) };
}
Expand Down
10 changes: 6 additions & 4 deletions src/hpc/byte_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,12 @@ mod simd_impl {
pub fn byte_find_all(haystack: &[u8], needle: u8) -> Vec<usize> {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512bw") {
let caps = super::simd_caps::simd_caps();
if caps.avx512bw {
// SAFETY: feature detected above.
return unsafe { simd_impl::byte_find_all_avx512(haystack, needle) };
}
if is_x86_feature_detected!("avx2") {
if caps.avx2 {
// SAFETY: feature detected above.
return unsafe { simd_impl::byte_find_all_avx2(haystack, needle) };
}
Expand Down Expand Up @@ -180,11 +181,12 @@ pub fn u16_find_all(haystack: &[u8], pattern: u16) -> Vec<usize> {
pub fn byte_count(haystack: &[u8], needle: u8) -> usize {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512bw") {
let caps = super::simd_caps::simd_caps();
if caps.avx512bw {
// SAFETY: feature detected above.
return unsafe { simd_impl::byte_count_avx512(haystack, needle) };
}
if is_x86_feature_detected!("avx2") {
if caps.avx2 {
// SAFETY: feature detected above.
return unsafe { simd_impl::byte_count_avx2(haystack, needle) };
}
Expand Down
2 changes: 1 addition & 1 deletion src/hpc/cam_pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ impl DistanceTables {
pub fn distance_batch(&self, cams: &[CamFingerprint]) -> Vec<f32> {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
if super::simd_caps::simd_caps().avx512f {
return unsafe { self.distance_batch_avx512(cams) };
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/hpc/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ mod simd_impl {
pub fn squared_distances_f32(query: [f32; 3], points: &[[f32; 3]]) -> Vec<f32> {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
if super::simd_caps::simd_caps().avx2 {
let mut out = Vec::new();
// SAFETY: feature detected above.
unsafe { simd_impl::squared_distances_avx2(query, points, &mut out) };
Expand Down
3 changes: 3 additions & 0 deletions src/hpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
//! - FFT (forward, inverse, real-to-complex)
//! - VML (vectorized math library)

// SIMD capability singleton — detect once, all modules share
pub mod simd_caps;

pub mod blas_level1;
pub mod blas_level2;
pub mod blas_level3;
Expand Down
9 changes: 5 additions & 4 deletions src/hpc/nibble.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub fn nibble_unpack(packed: &[u8], count: usize) -> Vec<u8> {

#[cfg(target_arch = "x86_64")]
{
if count >= 32 && is_x86_feature_detected!("avx2") {
if count >= 32 && super::simd_caps::simd_caps().avx2 {
// SAFETY: avx2 detected, packed buffer large enough.
unsafe {
nibble_unpack_avx2(packed, count, &mut out);
Expand Down Expand Up @@ -136,14 +136,15 @@ pub fn nibble_sub_clamp(packed: &mut [u8], delta: u8) {

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512bw") {
let caps = super::simd_caps::simd_caps();
if caps.avx512bw {
// SAFETY: avx512bw detected, slice is mutable and valid.
unsafe {
nibble_sub_clamp_avx512(packed, delta);
return;
}
}
if is_x86_feature_detected!("avx2") {
if caps.avx2 {
// SAFETY: avx2 detected, slice is mutable and valid.
unsafe {
nibble_sub_clamp_avx2(packed, delta);
Expand Down Expand Up @@ -242,7 +243,7 @@ unsafe fn nibble_sub_clamp_avx512(packed: &mut [u8], delta: u8) {
pub fn nibble_above_threshold(packed: &[u8], threshold: u8) -> Vec<usize> {
#[cfg(target_arch = "x86_64")]
{
if packed.len() >= 16 && is_x86_feature_detected!("avx2") {
if packed.len() >= 16 && super::simd_caps::simd_caps().avx2 {
// SAFETY: avx2 detected, packed buffer large enough.
return unsafe { nibble_above_threshold_avx2(packed, threshold) };
}
Expand Down
11 changes: 7 additions & 4 deletions src/hpc/palette_codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,12 @@ impl PackedPaletteArray {
pub fn unpack_indices_simd(packed: &[u64], bits_per_index: usize, count: usize) -> Vec<u8> {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") && count >= 16 {
let caps = super::simd_caps::simd_caps();
if caps.avx512f && count >= 16 {
// SAFETY: avx512f detected, count >= 16 ensures enough data.
return unsafe { unpack_generic_avx512(packed, bits_per_index, count) };
}
if bits_per_index == 4 && count >= 16 && is_x86_feature_detected!("avx2") {
if bits_per_index == 4 && count >= 16 && caps.avx2 {
return unsafe { unpack_4bit_avx2(packed, count) };
}
}
Expand All @@ -281,7 +282,8 @@ pub fn unpack_indices_simd(packed: &[u64], bits_per_index: usize, count: usize)
pub fn pack_indices_simd(indices: &[u8], bits_per_index: usize) -> Vec<u64> {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") && indices.len() >= 16 {
let caps = super::simd_caps::simd_caps();
if caps.avx512f && indices.len() >= 16 {
// SAFETY: avx512f detected, enough indices for SIMD processing.
return unsafe { pack_generic_avx512(indices, bits_per_index) };
}
Expand Down Expand Up @@ -415,7 +417,8 @@ pub fn bedrock_reorder_xzy(states: &[u16]) -> Vec<u16> {

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
let caps = super::simd_caps::simd_caps();
if caps.avx512f {
// SAFETY: avx512f detected, states.len() == 4096 guaranteed by assert.
return unsafe { bedrock_reorder_xzy_avx512(states) };
}
Expand Down
11 changes: 7 additions & 4 deletions src/hpc/property_mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,15 @@ impl PropertyMask {

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
let caps = super::simd_caps::simd_caps();
if caps.avx512f {
// SAFETY: avx512f detected, pointers are within slice bounds.
unsafe {
self.test_section_avx512(states, &mut result);
return result;
}
}
if is_x86_feature_detected!("avx2") {
if caps.avx2 {
// SAFETY: we checked avx2 at runtime, pointers are within slice bounds.
unsafe {
self.test_section_avx2(states, &mut result);
Expand All @@ -120,7 +121,8 @@ impl PropertyMask {
pub fn count_section(&self, states: &[u64]) -> u32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512f") {
let caps = super::simd_caps::simd_caps();
if caps.avx512vpopcntdq && caps.avx512f {
// SAFETY: feature detected above.
return unsafe { self.count_section_avx512(states) };
}
Expand Down Expand Up @@ -329,7 +331,8 @@ pub fn count_section_multi(masks: &[PropertyMask], states: &[u64]) -> MultiMaskR

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") && states.len() >= 8 {
let caps = super::simd_caps::simd_caps();
if caps.avx512f && states.len() >= 8 {
// SAFETY: avx512f detected above, states.len() >= 8 guaranteed.
unsafe {
return count_section_multi_avx512(masks, states);
Expand Down
131 changes: 131 additions & 0 deletions src/hpc/simd_caps.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
//! SIMD capability singleton — detect once, dispatch forever.
//!
//! Replaces per-call `is_x86_feature_detected!` (a hidden atomic load of the
//! standard library's feature cache each time) with a single `LazyLock<SimdCaps>`
//! populated at first access. Every HPC module calls `simd_caps()`, which checks
//! that the `LazyLock` is initialized and then copies out a frozen `Copy` struct.
//!
//! Indicative per-call cost (approximate; measure on the target CPU before
//! relying on it):
//!
//! ```text
//! is_x86_feature_detected!("avx512f")   →  ~3ns  (atomic load + branch)
//! simd_caps().avx512f                   →  ~1ns  (LazyLock deref + bool read)
//! ```

use std::sync::LazyLock;

/// Detected SIMD capabilities, frozen at first access.
///
/// A small `Copy` value type: eight `bool` flags of one byte each (8 bytes in
/// total), intended to be passed around by value.
///
/// `Default` yields a value with every flag `false`, i.e. "no SIMD extension
/// available". `PartialEq`/`Eq`/`Hash` allow whole-struct comparison and use as
/// a map key instead of comparing flag by flag.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SimdCaps {
    /// AVX2 (256-bit integer/FP SIMD).
    pub avx2: bool,
    /// AVX-512 Foundation (512-bit).
    pub avx512f: bool,
    /// AVX-512 Byte/Word operations.
    pub avx512bw: bool,
    /// AVX-512 Vector Length extensions.
    pub avx512vl: bool,
    /// AVX-512 VPOPCNTDQ (hardware popcount on 512-bit).
    pub avx512vpopcntdq: bool,
    /// SSE 4.1.
    pub sse41: bool,
    /// SSE2 (baseline on x86_64, but explicit for clarity).
    pub sse2: bool,
    /// FMA (fused multiply-add).
    pub fma: bool,
}

/// Global singleton — populated by `SimdCaps::detect` the first time it is
/// dereferenced, then reused for the lifetime of the process.
static CAPS: LazyLock<SimdCaps> = LazyLock::new(SimdCaps::detect);

/// Returns the detected SIMD capabilities by value.
///
/// The first call runs detection. Every later call takes `LazyLock`'s fast
/// path: it checks that the cell is already initialized and then copies the
/// cached struct out. That initialization check is still performed on each
/// call, so this is cheap but not entirely free of synchronization.
///
/// Callers that test several flags should bind the result once
/// (`let caps = simd_caps();`) and read the fields from the local copy.
#[inline(always)]
#[must_use]
pub fn simd_caps() -> SimdCaps {
    *CAPS
}

impl SimdCaps {
    /// Probes the CPU and records which SIMD extensions it reports.
    ///
    /// On `x86_64` every flag is taken from `is_x86_feature_detected!`; on any
    /// other target architecture all flags are `false`.
    fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            let sse2 = is_x86_feature_detected!("sse2");
            let sse41 = is_x86_feature_detected!("sse4.1");
            let avx2 = is_x86_feature_detected!("avx2");
            let fma = is_x86_feature_detected!("fma");
            let avx512f = is_x86_feature_detected!("avx512f");
            let avx512bw = is_x86_feature_detected!("avx512bw");
            let avx512vl = is_x86_feature_detected!("avx512vl");
            let avx512vpopcntdq = is_x86_feature_detected!("avx512vpopcntdq");

            Self {
                avx2,
                avx512f,
                avx512bw,
                avx512vl,
                avx512vpopcntdq,
                sse41,
                sse2,
                fma,
            }
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            // No x86 SIMD extension exists off x86_64, so report none of them.
            const ABSENT: bool = false;

            Self {
                avx2: ABSENT,
                avx512f: ABSENT,
                avx512bw: ABSENT,
                avx512vl: ABSENT,
                avx512vpopcntdq: ABSENT,
                sse41: ABSENT,
                sse2: ABSENT,
                fma: ABSENT,
            }
        }
    }

    /// Reports whether both AVX-512 Foundation and VPOPCNTDQ were detected.
    #[inline(always)]
    pub fn has_avx512_popcnt(self) -> bool {
        let Self {
            avx512f,
            avx512vpopcntdq,
            ..
        } = self;
        avx512f && avx512vpopcntdq
    }

    /// Reports whether both AVX-512 BW and VPOPCNTDQ were detected.
    #[inline(always)]
    pub fn has_avx512_bw_popcnt(self) -> bool {
        let Self {
            avx512bw,
            avx512vpopcntdq,
            ..
        } = self;
        avx512bw && avx512vpopcntdq
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Flattens every capability flag into an array so that whole-value
    /// comparisons cover all eight fields.
    fn flags(caps: SimdCaps) -> [bool; 8] {
        [
            caps.avx2,
            caps.avx512f,
            caps.avx512bw,
            caps.avx512vl,
            caps.avx512vpopcntdq,
            caps.sse41,
            caps.sse2,
            caps.fma,
        ]
    }

    #[test]
    fn detect_does_not_panic() {
        // On any platform, simd_caps() should succeed and every field be readable.
        let _ = flags(simd_caps());
    }

    #[test]
    fn simd_caps_is_copy() {
        let a = simd_caps();
        let b = a; // Copy
        let c = a; // Still valid
        assert_eq!(flags(a), flags(b));
        assert_eq!(flags(b), flags(c));
    }

    #[test]
    fn simd_caps_deterministic() {
        // Compare all fields, not just a subset, across two accesses.
        assert_eq!(flags(simd_caps()), flags(simd_caps()));
    }

    /// The cached flags must agree with a fresh runtime query of each feature.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn matches_runtime_detection() {
        let caps = simd_caps();
        assert_eq!(caps.avx2, is_x86_feature_detected!("avx2"));
        assert_eq!(caps.avx512f, is_x86_feature_detected!("avx512f"));
        assert_eq!(caps.avx512bw, is_x86_feature_detected!("avx512bw"));
        assert_eq!(caps.avx512vl, is_x86_feature_detected!("avx512vl"));
        assert_eq!(
            caps.avx512vpopcntdq,
            is_x86_feature_detected!("avx512vpopcntdq")
        );
        assert_eq!(caps.sse41, is_x86_feature_detected!("sse4.1"));
        assert_eq!(caps.sse2, is_x86_feature_detected!("sse2"));
        assert_eq!(caps.fma, is_x86_feature_detected!("fma"));
    }

    /// Off x86_64 the fallback `detect` reports no capability at all.
    #[cfg(not(target_arch = "x86_64"))]
    #[test]
    fn non_x86_64_reports_nothing() {
        assert_eq!(flags(simd_caps()), [false; 8]);
    }

    #[test]
    fn convenience_methods() {
        let caps = simd_caps();
        // Each helper must equal the conjunction of the two flags it names.
        assert_eq!(
            caps.has_avx512_popcnt(),
            caps.avx512f && caps.avx512vpopcntdq
        );
        assert_eq!(
            caps.has_avx512_bw_popcnt(),
            caps.avx512bw && caps.avx512vpopcntdq
        );
    }
}
2 changes: 1 addition & 1 deletion src/hpc/spatial_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ fn batch_sq_dist_filter(
) -> Vec<(usize, f32)> {
#[cfg(target_arch = "x86_64")]
{
if candidates.len() >= 8 && is_x86_feature_detected!("avx2") {
if candidates.len() >= 8 && super::simd_caps::simd_caps().avx2 {
// SAFETY: avx2 detected, enough candidates for SIMD.
return unsafe { batch_sq_dist_avx2(query, candidates, radius_sq) };
}
Expand Down
Loading