Skip to content

Commit 116a9f8

Browse files
authored
Merge pull request #34 from AdaWorldAPI/claude/continue-session-0mAVa
refactor(hpc): add simd_caps LazyLock singleton — detect once, dispat…
2 parents 056a1c9 + 935db00 commit 116a9f8

11 files changed

Lines changed: 180 additions & 33 deletions

File tree

src/hpc/aabb.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,15 @@ fn sq_dist_point_aabb(point: [f32; 3], aabb: &Aabb) -> f32 {
144144
pub fn aabb_intersect_batch(query: &Aabb, candidates: &[Aabb]) -> Vec<bool> {
145145
#[cfg(target_arch = "x86_64")]
146146
{
147-
if is_x86_feature_detected!("avx512f") && candidates.len() >= 16 {
148-
// SAFETY: avx512f detected, enough candidates for batch processing.
147+
let caps = super::simd_caps::simd_caps();
148+
if caps.avx512f && candidates.len() >= 16 {
149+
// SAFETY: avx512f detected via simd_caps singleton.
149150
unsafe {
150151
return aabb_intersect_batch_avx512(query, candidates);
151152
}
152153
}
153-
if is_x86_feature_detected!("sse4.1") {
154-
// SAFETY: sse4.1 detected, slice access within bounds.
154+
if caps.sse41 {
155+
// SAFETY: sse4.1 detected via simd_caps singleton.
155156
unsafe {
156157
return aabb_intersect_batch_sse41(query, candidates);
157158
}
@@ -294,8 +295,8 @@ unsafe fn aabb_intersect_batch_sse41(query: &Aabb, candidates: &[Aabb]) -> Vec<b
294295
pub fn ray_aabb_slab_test_batch(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>) {
295296
#[cfg(target_arch = "x86_64")]
296297
{
297-
if is_x86_feature_detected!("avx512f") && aabbs.len() >= 16 {
298-
// SAFETY: avx512f detected, enough AABBs for batch processing.
298+
if super::simd_caps::simd_caps().avx512f && aabbs.len() >= 16 {
299+
// SAFETY: avx512f detected via simd_caps singleton.
299300
unsafe {
300301
return ray_aabb_slab_test_avx512(ray, aabbs);
301302
}
@@ -455,8 +456,8 @@ unsafe fn ray_aabb_slab_test_avx512(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Ve
455456
pub fn aabb_expand_batch(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) {
456457
#[cfg(target_arch = "x86_64")]
457458
{
458-
if is_x86_feature_detected!("sse2") {
459-
// SAFETY: sse2 detected, operating on mutable slice in-bounds.
459+
if super::simd_caps::simd_caps().sse2 {
460+
// SAFETY: sse2 detected via simd_caps singleton.
460461
unsafe {
461462
aabb_expand_batch_sse2(aabbs, dx, dy, dz);
462463
return;

src/hpc/bitwise.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -258,15 +258,16 @@ pub fn hamming_top_k_raw(
258258
fn dispatch_hamming(a: &[u8], b: &[u8]) -> u64 {
259259
#[cfg(target_arch = "x86_64")]
260260
{
261-
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512bw") {
261+
let caps = super::simd_caps::simd_caps();
262+
if caps.has_avx512_bw_popcnt() {
262263
// SAFETY: checked VPOPCNTDQ + BW
263264
return unsafe { crate::backend::kernels_avx512::hamming_distance(a, b) };
264265
}
265-
if is_x86_feature_detected!("avx512bw") {
266+
if caps.avx512bw {
266267
// SAFETY: checked AVX-512 BW — uses 512-bit vpshufb (64B/iter)
267268
return unsafe { hamming_avx512bw(a, b) };
268269
}
269-
if is_x86_feature_detected!("avx2") {
270+
if caps.avx2 {
270271
// SAFETY: checked AVX2 — uses 256-bit vpshufb (32B/iter)
271272
return unsafe { hamming_avx2(a, b) };
272273
}
@@ -277,11 +278,12 @@ fn dispatch_hamming(a: &[u8], b: &[u8]) -> u64 {
277278
fn dispatch_popcount(a: &[u8]) -> u64 {
278279
#[cfg(target_arch = "x86_64")]
279280
{
280-
if is_x86_feature_detected!("avx512vpopcntdq") {
281+
let caps = super::simd_caps::simd_caps();
282+
if caps.avx512vpopcntdq {
281283
// SAFETY: checked VPOPCNTDQ
282284
return unsafe { crate::backend::kernels_avx512::popcount(a) };
283285
}
284-
if is_x86_feature_detected!("avx512bw") {
286+
if caps.avx512bw {
285287
// SAFETY: checked AVX-512 BW — uses 512-bit vpshufb
286288
return unsafe { popcount_avx512bw(a) };
287289
}
@@ -292,7 +294,8 @@ fn dispatch_popcount(a: &[u8]) -> u64 {
292294
fn dispatch_hamming_batch(query: &[u8], database: &[u8], num_rows: usize, row_bytes: usize) -> Vec<u64> {
293295
#[cfg(target_arch = "x86_64")]
294296
{
295-
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512bw") {
297+
let caps = super::simd_caps::simd_caps();
298+
if caps.has_avx512_bw_popcnt() {
296299
// SAFETY: checked VPOPCNTDQ + BW
297300
return unsafe { crate::backend::kernels_avx512::hamming_batch(query, database, num_rows, row_bytes) };
298301
}

src/hpc/byte_scan.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,12 @@ mod simd_impl {
142142
pub fn byte_find_all(haystack: &[u8], needle: u8) -> Vec<usize> {
143143
#[cfg(target_arch = "x86_64")]
144144
{
145-
if is_x86_feature_detected!("avx512bw") {
145+
let caps = super::simd_caps::simd_caps();
146+
if caps.avx512bw {
146147
// SAFETY: feature detected above.
147148
return unsafe { simd_impl::byte_find_all_avx512(haystack, needle) };
148149
}
149-
if is_x86_feature_detected!("avx2") {
150+
if caps.avx2 {
150151
// SAFETY: feature detected above.
151152
return unsafe { simd_impl::byte_find_all_avx2(haystack, needle) };
152153
}
@@ -180,11 +181,12 @@ pub fn u16_find_all(haystack: &[u8], pattern: u16) -> Vec<usize> {
180181
pub fn byte_count(haystack: &[u8], needle: u8) -> usize {
181182
#[cfg(target_arch = "x86_64")]
182183
{
183-
if is_x86_feature_detected!("avx512bw") {
184+
let caps = super::simd_caps::simd_caps();
185+
if caps.avx512bw {
184186
// SAFETY: feature detected above.
185187
return unsafe { simd_impl::byte_count_avx512(haystack, needle) };
186188
}
187-
if is_x86_feature_detected!("avx2") {
189+
if caps.avx2 {
188190
// SAFETY: feature detected above.
189191
return unsafe { simd_impl::byte_count_avx2(haystack, needle) };
190192
}

src/hpc/cam_pq.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ impl DistanceTables {
199199
pub fn distance_batch(&self, cams: &[CamFingerprint]) -> Vec<f32> {
200200
#[cfg(target_arch = "x86_64")]
201201
{
202-
if is_x86_feature_detected!("avx512f") {
202+
if super::simd_caps::simd_caps().avx512f {
203203
return unsafe { self.distance_batch_avx512(cams) };
204204
}
205205
}

src/hpc/distance.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ mod simd_impl {
109109
pub fn squared_distances_f32(query: [f32; 3], points: &[[f32; 3]]) -> Vec<f32> {
110110
#[cfg(target_arch = "x86_64")]
111111
{
112-
if is_x86_feature_detected!("avx2") {
112+
if super::simd_caps::simd_caps().avx2 {
113113
let mut out = Vec::new();
114114
// SAFETY: feature detected above.
115115
unsafe { simd_impl::squared_distances_avx2(query, points, &mut out) };

src/hpc/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
//! - FFT (forward, inverse, real-to-complex)
2020
//! - VML (vectorized math library)
2121
22+
// SIMD capability singleton — detect once, all modules share
23+
pub mod simd_caps;
24+
2225
pub mod blas_level1;
2326
pub mod blas_level2;
2427
pub mod blas_level3;

src/hpc/nibble.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub fn nibble_unpack(packed: &[u8], count: usize) -> Vec<u8> {
2727

2828
#[cfg(target_arch = "x86_64")]
2929
{
30-
if count >= 32 && is_x86_feature_detected!("avx2") {
30+
if count >= 32 && super::simd_caps::simd_caps().avx2 {
3131
// SAFETY: avx2 detected, packed buffer large enough.
3232
unsafe {
3333
nibble_unpack_avx2(packed, count, &mut out);
@@ -136,14 +136,15 @@ pub fn nibble_sub_clamp(packed: &mut [u8], delta: u8) {
136136

137137
#[cfg(target_arch = "x86_64")]
138138
{
139-
if is_x86_feature_detected!("avx512bw") {
139+
let caps = super::simd_caps::simd_caps();
140+
if caps.avx512bw {
140141
// SAFETY: avx512bw detected, slice is mutable and valid.
141142
unsafe {
142143
nibble_sub_clamp_avx512(packed, delta);
143144
return;
144145
}
145146
}
146-
if is_x86_feature_detected!("avx2") {
147+
if caps.avx2 {
147148
// SAFETY: avx2 detected, slice is mutable and valid.
148149
unsafe {
149150
nibble_sub_clamp_avx2(packed, delta);
@@ -242,7 +243,7 @@ unsafe fn nibble_sub_clamp_avx512(packed: &mut [u8], delta: u8) {
242243
pub fn nibble_above_threshold(packed: &[u8], threshold: u8) -> Vec<usize> {
243244
#[cfg(target_arch = "x86_64")]
244245
{
245-
if packed.len() >= 16 && is_x86_feature_detected!("avx2") {
246+
if packed.len() >= 16 && super::simd_caps::simd_caps().avx2 {
246247
// SAFETY: avx2 detected, packed buffer large enough.
247248
return unsafe { nibble_above_threshold_avx2(packed, threshold) };
248249
}

src/hpc/palette_codec.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,12 @@ impl PackedPaletteArray {
265265
pub fn unpack_indices_simd(packed: &[u64], bits_per_index: usize, count: usize) -> Vec<u8> {
266266
#[cfg(target_arch = "x86_64")]
267267
{
268-
if is_x86_feature_detected!("avx512f") && count >= 16 {
268+
let caps = super::simd_caps::simd_caps();
269+
if caps.avx512f && count >= 16 {
269270
// SAFETY: avx512f detected, count >= 16 ensures enough data.
270271
return unsafe { unpack_generic_avx512(packed, bits_per_index, count) };
271272
}
272-
if bits_per_index == 4 && count >= 16 && is_x86_feature_detected!("avx2") {
273+
if bits_per_index == 4 && count >= 16 && caps.avx2 {
273274
return unsafe { unpack_4bit_avx2(packed, count) };
274275
}
275276
}
@@ -281,7 +282,8 @@ pub fn unpack_indices_simd(packed: &[u64], bits_per_index: usize, count: usize)
281282
pub fn pack_indices_simd(indices: &[u8], bits_per_index: usize) -> Vec<u64> {
282283
#[cfg(target_arch = "x86_64")]
283284
{
284-
if is_x86_feature_detected!("avx512f") && indices.len() >= 16 {
285+
let caps = super::simd_caps::simd_caps();
286+
if caps.avx512f && indices.len() >= 16 {
285287
// SAFETY: avx512f detected, enough indices for SIMD processing.
286288
return unsafe { pack_generic_avx512(indices, bits_per_index) };
287289
}
@@ -415,7 +417,8 @@ pub fn bedrock_reorder_xzy(states: &[u16]) -> Vec<u16> {
415417

416418
#[cfg(target_arch = "x86_64")]
417419
{
418-
if is_x86_feature_detected!("avx512f") {
420+
let caps = super::simd_caps::simd_caps();
421+
if caps.avx512f {
419422
// SAFETY: avx512f detected, states.len() == 4096 guaranteed by assert.
420423
return unsafe { bedrock_reorder_xzy_avx512(states) };
421424
}

src/hpc/property_mask.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,15 @@ impl PropertyMask {
9696

9797
#[cfg(target_arch = "x86_64")]
9898
{
99-
if is_x86_feature_detected!("avx512f") {
99+
let caps = super::simd_caps::simd_caps();
100+
if caps.avx512f {
100101
// SAFETY: avx512f detected, pointers are within slice bounds.
101102
unsafe {
102103
self.test_section_avx512(states, &mut result);
103104
return result;
104105
}
105106
}
106-
if is_x86_feature_detected!("avx2") {
107+
if caps.avx2 {
107108
// SAFETY: we checked avx2 at runtime, pointers are within slice bounds.
108109
unsafe {
109110
self.test_section_avx2(states, &mut result);
@@ -120,7 +121,8 @@ impl PropertyMask {
120121
pub fn count_section(&self, states: &[u64]) -> u32 {
121122
#[cfg(target_arch = "x86_64")]
122123
{
123-
if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512f") {
124+
let caps = super::simd_caps::simd_caps();
125+
if caps.avx512vpopcntdq && caps.avx512f {
124126
// SAFETY: feature detected above.
125127
return unsafe { self.count_section_avx512(states) };
126128
}
@@ -329,7 +331,8 @@ pub fn count_section_multi(masks: &[PropertyMask], states: &[u64]) -> MultiMaskR
329331

330332
#[cfg(target_arch = "x86_64")]
331333
{
332-
if is_x86_feature_detected!("avx512f") && states.len() >= 8 {
334+
let caps = super::simd_caps::simd_caps();
335+
if caps.avx512f && states.len() >= 8 {
333336
// SAFETY: avx512f detected above, states.len() >= 8 guaranteed.
334337
unsafe {
335338
return count_section_multi_avx512(masks, states);

src/hpc/simd_caps.rs

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
//! SIMD capability singleton — detect once, dispatch forever.
2+
//!
3+
//! Replaces per-call `is_x86_feature_detected!` (hidden `AtomicU8` load each time)
4+
//! with a single `LazyLock<SimdCaps>` detected at first access. Every HPC module
5+
//! calls `simd_caps()` which is one pointer deref to a frozen `Copy` struct.
6+
//!
7+
//! ```text
8+
//! is_x86_feature_detected!("avx512f") → ~3ns (atomic load + branch)
9+
//! simd_caps().avx512f → ~1ns (LazyLock deref + bool read)
10+
//! ```
11+
12+
use std::sync::LazyLock;
13+
14+
/// Detected SIMD capabilities, frozen at first access.
15+
///
16+
/// This is a `Copy` type: 8 bools packed into 8 bytes. Passed by value,
17+
/// lives in registers after the first `LazyLock` deref.
18+
#[derive(Debug, Clone, Copy)]
19+
pub struct SimdCaps {
20+
/// AVX2 (256-bit integer/FP SIMD).
21+
pub avx2: bool,
22+
/// AVX-512 Foundation (512-bit).
23+
pub avx512f: bool,
24+
/// AVX-512 Byte/Word operations.
25+
pub avx512bw: bool,
26+
/// AVX-512 Vector Length extensions.
27+
pub avx512vl: bool,
28+
/// AVX-512 VPOPCNTDQ (hardware popcount on 512-bit).
29+
pub avx512vpopcntdq: bool,
30+
/// SSE 4.1.
31+
pub sse41: bool,
32+
/// SSE2 (baseline on x86_64, but explicit for clarity).
33+
pub sse2: bool,
34+
/// FMA (fused multiply-add).
35+
pub fma: bool,
36+
}
37+
38+
/// Global singleton — detected once at first access via `LazyLock`.
39+
static CAPS: LazyLock<SimdCaps> = LazyLock::new(SimdCaps::detect);
40+
41+
/// Get the detected SIMD capabilities. First call detects; all subsequent
42+
/// calls are a single pointer deref with no atomic operations.
43+
#[inline(always)]
44+
pub fn simd_caps() -> SimdCaps {
45+
*CAPS
46+
}
47+
48+
impl SimdCaps {
49+
/// Detect CPU capabilities at runtime.
50+
#[cfg(target_arch = "x86_64")]
51+
fn detect() -> Self {
52+
Self {
53+
avx2: is_x86_feature_detected!("avx2"),
54+
avx512f: is_x86_feature_detected!("avx512f"),
55+
avx512bw: is_x86_feature_detected!("avx512bw"),
56+
avx512vl: is_x86_feature_detected!("avx512vl"),
57+
avx512vpopcntdq: is_x86_feature_detected!("avx512vpopcntdq"),
58+
sse41: is_x86_feature_detected!("sse4.1"),
59+
sse2: is_x86_feature_detected!("sse2"),
60+
fma: is_x86_feature_detected!("fma"),
61+
}
62+
}
63+
64+
/// Non-x86: all false.
65+
#[cfg(not(target_arch = "x86_64"))]
66+
fn detect() -> Self {
67+
Self {
68+
avx2: false,
69+
avx512f: false,
70+
avx512bw: false,
71+
avx512vl: false,
72+
avx512vpopcntdq: false,
73+
sse41: false,
74+
sse2: false,
75+
fma: false,
76+
}
77+
}
78+
79+
/// True if AVX-512 Foundation + VPOPCNTDQ are both available.
80+
#[inline(always)]
81+
pub fn has_avx512_popcnt(self) -> bool {
82+
self.avx512f && self.avx512vpopcntdq
83+
}
84+
85+
/// True if AVX-512 BW + VPOPCNTDQ are both available.
86+
#[inline(always)]
87+
pub fn has_avx512_bw_popcnt(self) -> bool {
88+
self.avx512bw && self.avx512vpopcntdq
89+
}
90+
}
91+
92+
#[cfg(test)]
93+
mod tests {
94+
use super::*;
95+
96+
#[test]
97+
fn detect_does_not_panic() {
98+
let caps = simd_caps();
99+
// On any platform, simd_caps() should succeed.
100+
let _ = caps.avx2;
101+
let _ = caps.avx512f;
102+
}
103+
104+
#[test]
105+
fn simd_caps_is_copy() {
106+
let a = simd_caps();
107+
let b = a; // Copy
108+
let c = a; // Still valid
109+
assert_eq!(a.avx2, b.avx2);
110+
assert_eq!(b.avx512f, c.avx512f);
111+
}
112+
113+
#[test]
114+
fn simd_caps_deterministic() {
115+
let a = simd_caps();
116+
let b = simd_caps();
117+
assert_eq!(a.avx2, b.avx2);
118+
assert_eq!(a.avx512f, b.avx512f);
119+
assert_eq!(a.avx512bw, b.avx512bw);
120+
assert_eq!(a.avx512vpopcntdq, b.avx512vpopcntdq);
121+
assert_eq!(a.sse41, b.sse41);
122+
}
123+
124+
#[test]
125+
fn convenience_methods() {
126+
let caps = simd_caps();
127+
// Just verify these don't panic and return consistent values.
128+
let _ = caps.has_avx512_popcnt();
129+
let _ = caps.has_avx512_bw_popcnt();
130+
}
131+
}

0 commit comments

Comments
 (0)