diff --git a/.cargo/config.toml b/.cargo/config.toml
index 359df1db..92467f26 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -1,4 +1,4 @@
 [build]
-# x86-64-v4 = AVX-512 baseline. All BF16 SIMD paths use native __m512d.
-# Rust 1.94 stable. No nightly.
-rustflags = ["-C", "target-cpu=x86-64-v4"]
+# No global target-cpu. Each kernel uses #[target_feature(enable = "avx512f")]
+# per-function, with LazyLock runtime detection. One binary, all ISAs.
+# Railway (AVX-512) and GitHub CI (AVX2) use the same binary.
diff --git a/crates/p64/src/lib.rs b/crates/p64/src/lib.rs
index 9004e016..270e4db7 100644
--- a/crates/p64/src/lib.rs
+++ b/crates/p64/src/lib.rs
@@ -168,6 +168,267 @@ fn spread_32_to_64(val: u32) -> u64 {
     out
 }
 
+// ============================================================================
+// Multi-versioned attend kernel: AVX-512 → AVX2 → scalar.
+// ============================================================================
+
+/// Return type for attend kernel: (best_idx, distance, scores, fires).
+type AttendFn = unsafe fn(&[u64; 64], u64, u8) -> (u8, u8, [u8; 64], u64);
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn attend_avx512(rows: &[u64; 64], query: u64, gamma: u8) -> (u8, u8, [u8; 64], u64) {
+    use std::arch::x86_64::*;
+    let mut best_idx = 0u8;
+    let mut best_score = 0u8;
+    let mut scores = [0u8; 64];
+    let mut fires = 0u64;
+
+    let q = _mm512_set1_epi64(query as i64);
+    // Process 8 rows per chunk, 8 chunks = 64 rows
+    for chunk in 0..8 {
+        let base = chunk * 8;
+        // SAFETY: rows is [u64; 64], base..base+8 is in bounds, Palette64 is 64-byte aligned.
+        let r = _mm512_loadu_si512(rows[base..].as_ptr() as *const __m512i);
+        let anded = _mm512_and_si512(r, q);
+        // Extract 8 u64s and scalar popcount (no VPOPCNTDQ dependency)
+        let vals: [u64; 8] = std::mem::transmute(anded);
+        for j in 0..8 {
+            let score = vals[j].count_ones() as u8;
+            let idx = base + j;
+            scores[idx] = score;
+            if score > best_score {
+                best_score = score;
+                best_idx = idx as u8;
+            }
+            if score >= gamma {
+                fires |= 1u64 << idx;
+            }
+        }
+    }
+    (best_idx, 64 - best_score, scores, fires)
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn attend_avx2(rows: &[u64; 64], query: u64, gamma: u8) -> (u8, u8, [u8; 64], u64) {
+    use std::arch::x86_64::*;
+    let mut best_idx = 0u8;
+    let mut best_score = 0u8;
+    let mut scores = [0u8; 64];
+    let mut fires = 0u64;
+
+    let q = _mm256_set1_epi64x(query as i64);
+    // Process 4 rows per chunk, 16 chunks = 64 rows
+    for chunk in 0..16 {
+        let base = chunk * 4;
+        // SAFETY: rows is [u64; 64], base..base+4 is in bounds.
+        let r = _mm256_loadu_si256(rows[base..].as_ptr() as *const __m256i);
+        let anded = _mm256_and_si256(r, q);
+        let vals: [u64; 4] = std::mem::transmute(anded);
+        for j in 0..4 {
+            let score = vals[j].count_ones() as u8;
+            let idx = base + j;
+            scores[idx] = score;
+            if score > best_score {
+                best_score = score;
+                best_idx = idx as u8;
+            }
+            if score >= gamma {
+                fires |= 1u64 << idx;
+            }
+        }
+    }
+    (best_idx, 64 - best_score, scores, fires)
+}
+
+fn attend_scalar(rows: &[u64; 64], query: u64, gamma: u8) -> (u8, u8, [u8; 64], u64) {
+    let mut best_idx = 0u8;
+    let mut best_score = 0u8;
+    let mut scores = [0u8; 64];
+    let mut fires = 0u64;
+    for i in 0..64 {
+        let score = (query & rows[i]).count_ones() as u8;
+        scores[i] = score;
+        if score > best_score {
+            best_score = score;
+            best_idx = i as u8;
+        }
+        if score >= gamma {
+            fires |= 1u64 << i;
+        }
+    }
+    (best_idx, 64 - best_score, scores, fires)
+}
+
+static ATTEND_KERNEL: std::sync::LazyLock<AttendFn> = std::sync::LazyLock::new(|| {
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx512f") {
+            return attend_avx512 as AttendFn;
+        }
+        if is_x86_feature_detected!("avx2") {
+            return attend_avx2 as AttendFn;
+        }
+    }
+    attend_scalar as AttendFn
+});
+
+// ============================================================================
+// Multi-versioned nearest_k kernel: AVX-512 → AVX2 → scalar.
+// ============================================================================
+
+/// Compute all 64 Hamming distances in one pass.
+type NearestKFn = unsafe fn(&[u64; 64], u64) -> [u8; 64];
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn nearest_k_avx512(rows: &[u64; 64], query: u64) -> [u8; 64] {
+    use std::arch::x86_64::*;
+    let mut dists = [0u8; 64];
+    let q = _mm512_set1_epi64(query as i64);
+    for chunk in 0..8 {
+        let base = chunk * 8;
+        // SAFETY: rows is [u64; 64], base..base+8 is in bounds.
+        let r = _mm512_loadu_si512(rows[base..].as_ptr() as *const __m512i);
+        let xored = _mm512_xor_si512(r, q);
+        let vals: [u64; 8] = std::mem::transmute(xored);
+        for j in 0..8 {
+            dists[base + j] = vals[j].count_ones() as u8;
+        }
+    }
+    dists
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn nearest_k_avx2(rows: &[u64; 64], query: u64) -> [u8; 64] {
+    use std::arch::x86_64::*;
+    let mut dists = [0u8; 64];
+    let q = _mm256_set1_epi64x(query as i64);
+    for chunk in 0..16 {
+        let base = chunk * 4;
+        // SAFETY: rows is [u64; 64], base..base+4 is in bounds.
+        let r = _mm256_loadu_si256(rows[base..].as_ptr() as *const __m256i);
+        let xored = _mm256_xor_si256(r, q);
+        let vals: [u64; 4] = std::mem::transmute(xored);
+        for j in 0..4 {
+            dists[base + j] = vals[j].count_ones() as u8;
+        }
+    }
+    dists
+}
+
+fn nearest_k_scalar(rows: &[u64; 64], query: u64) -> [u8; 64] {
+    let mut dists = [0u8; 64];
+    for i in 0..64 {
+        dists[i] = (query ^ rows[i]).count_ones() as u8;
+    }
+    dists
+}
+
+static NEAREST_K_KERNEL: std::sync::LazyLock<NearestKFn> = std::sync::LazyLock::new(|| {
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx512f") {
+            return nearest_k_avx512 as NearestKFn;
+        }
+        if is_x86_feature_detected!("avx2") {
+            return nearest_k_avx2 as NearestKFn;
+        }
+    }
+    nearest_k_scalar as NearestKFn
+});
+
+// ============================================================================
+// Multi-versioned moe_gate kernel: AVX-512 → AVX2 → scalar.
+// ============================================================================
+
+/// Return type: (active_mask, strength[8], combined).
+type MoeGateFn = unsafe fn(&[u64; 8], u64, u8) -> (u8, [u8; 8], u64);
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn moe_gate_avx512(planes: &[u64; 8], query: u64, threshold: u8) -> (u8, [u8; 8], u64) {
+    use std::arch::x86_64::*;
+    // Load all 8 planes into one zmm register, AND with broadcast query
+    // SAFETY: planes is [u64; 8] = 64 bytes, fits in one zmm.
+    let p = _mm512_loadu_si512(planes.as_ptr() as *const __m512i);
+    let q = _mm512_set1_epi64(query as i64);
+    let anded = _mm512_and_si512(p, q);
+    let vals: [u64; 8] = std::mem::transmute(anded);
+
+    let mut active = 0u8;
+    let mut strength = [0u8; 8];
+    let mut combined = 0u64;
+    for i in 0..8 {
+        let score = vals[i].count_ones() as u8;
+        strength[i] = score;
+        if score >= threshold {
+            active |= 1 << i;
+            combined |= planes[i];
+        }
+    }
+    (active, strength, combined)
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn moe_gate_avx2(planes: &[u64; 8], query: u64, threshold: u8) -> (u8, [u8; 8], u64) {
+    use std::arch::x86_64::*;
+    let q = _mm256_set1_epi64x(query as i64);
+    let mut active = 0u8;
+    let mut strength = [0u8; 8];
+    let mut combined = 0u64;
+
+    // Process 4 planes at a time, 2 chunks = 8 planes
+    for chunk in 0..2 {
+        let base = chunk * 4;
+        // SAFETY: planes is [u64; 8], base..base+4 is in bounds.
+        let p = _mm256_loadu_si256(planes[base..].as_ptr() as *const __m256i);
+        let anded = _mm256_and_si256(p, q);
+        let vals: [u64; 4] = std::mem::transmute(anded);
+        for j in 0..4 {
+            let score = vals[j].count_ones() as u8;
+            let idx = base + j;
+            strength[idx] = score;
+            if score >= threshold {
+                active |= 1 << idx;
+                combined |= planes[idx];
+            }
+        }
+    }
+    (active, strength, combined)
+}
+
+fn moe_gate_scalar(planes: &[u64; 8], query: u64, threshold: u8) -> (u8, [u8; 8], u64) {
+    let mut active = 0u8;
+    let mut strength = [0u8; 8];
+    let mut combined = 0u64;
+    for i in 0..8 {
+        let score = (query & planes[i]).count_ones() as u8;
+        strength[i] = score;
+        if score >= threshold {
+            active |= 1 << i;
+            combined |= planes[i];
+        }
+    }
+    (active, strength, combined)
+}
+
+static MOE_GATE_KERNEL: std::sync::LazyLock<MoeGateFn> = std::sync::LazyLock::new(|| {
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx512f") {
+            return moe_gate_avx512 as MoeGateFn;
+        }
+        if is_x86_feature_detected!("avx2") {
+            return moe_gate_avx2 as MoeGateFn;
+        }
+    }
+    moe_gate_scalar as MoeGateFn
+});
+
 // ============================================================================
 // BNN Attention
 // ============================================================================
@@ -183,30 +444,16 @@ impl Palette64 {
     /// Score = popcount(query AND row[i]).
     /// Higher score = more bits in common = better match.
     /// Gamma threshold: rows below this score don't "fire."
+    ///
+    /// Runtime dispatch via LazyLock: AVX-512 → AVX2 → scalar.
     #[inline]
     pub fn attend(&self, query: u64, gamma: u8) -> AttentionResult {
-        let mut scores = [0u8; 64];
-        let mut best_idx = 0u8;
-        let mut best_score = 0u8;
-        let mut fires = 0u64;
-
-        for i in 0..64 {
-            let score = (query & self.rows[i]).count_ones() as u8;
-            scores[i] = score;
-
-            if score > best_score {
-                best_score = score;
-                best_idx = i as u8;
-            }
-
-            if score >= gamma {
-                fires |= 1u64 << i;
-            }
-        }
-
+        // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
+        let (best_idx, distance, scores, fires) =
+            unsafe { ATTEND_KERNEL(&self.rows, query, gamma) };
         AttentionResult {
             best_idx,
-            distance: 64 - best_score,
+            distance,
             scores,
             fires,
         }
@@ -228,16 +475,15 @@ impl Palette64 {
     /// Palette lookup: find the K nearest rows by Hamming distance.
     ///
     /// Returns (row_index, hamming_distance) sorted ascending.
+    ///
+    /// Runtime dispatch via LazyLock: AVX-512 → AVX2 → scalar.
     pub fn nearest_k(&self, query: u64, k: usize) -> Vec<(u8, u8)> {
-        let mut dists: Vec<(u8, u8)> = (0..64)
-            .map(|i| {
-                let dist = (query ^ self.rows[i]).count_ones() as u8;
-                (i as u8, dist)
-            })
-            .collect();
-        dists.sort_by_key(|&(_, d)| d);
-        dists.truncate(k);
-        dists
+        // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
+        let dists = unsafe { NEAREST_K_KERNEL(&self.rows, query) };
+        let mut pairs: Vec<(u8, u8)> = (0..64u8).map(|i| (i, dists[i as usize])).collect();
+        pairs.sort_by_key(|&(_, d)| d);
+        pairs.truncate(k);
+        pairs
     }
 
     /// Row density: popcount of each row. Sparse rows = abstract; dense = concrete.
@@ -281,22 +527,13 @@ impl HeelPlanes {
     ///
     /// Each HEEL plane is an expert. The query's match against each expert
    /// determines which experts activate and with what strength.
+    ///
+    /// Runtime dispatch via LazyLock: AVX-512 → AVX2 → scalar.
     #[inline]
     pub fn moe_gate(&self, query: u64, threshold: u8) -> MoeGate {
-        let mut active = 0u8;
-        let mut strength = [0u8; 8];
-        let mut combined = 0u64;
-
-        for i in 0..8 {
-            let score = (query & self.planes[i]).count_ones() as u8;
-            strength[i] = score;
-
-            if score >= threshold {
-                active |= 1 << i;
-                combined |= self.planes[i];
-            }
-        }
-
+        // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
+        let (active, strength, combined) =
+            unsafe { MOE_GATE_KERNEL(&self.planes, query, threshold) };
         MoeGate {
             active,
             strength,
@@ -1217,10 +1454,15 @@ pub mod sparse256 {
     /// - Skip entire 32-element HEEL blocks if the 8×8 super-block is empty
     /// - Skip 4-element TWIG blocks if the palette bit is 0
     /// - Only compute exact distance for active palette entries
-    pub fn hhtl_cascade_search(
+    ///
+    /// `score_fn`: callback that computes the actual score for a (row, col) pair.
+    /// This is where the LEAF level lives — LanceDB vector search, DistanceMatrix
+    /// lookup, or BF16 dot product. The cascade doesn't know or care which.
+    pub fn hhtl_cascade_search<F: Fn(u8, u8) -> f32>(
         palette: &Palette64,
         query_row: u8,
         scores: &mut [f32; 256],
+        score_fn: F,
     ) -> usize {
         let heel_row = query_row / 32;
         let hip_row = (query_row / 4) % 8;
@@ -1241,14 +1483,11 @@ pub mod sparse256 {
                 let block_col = bits.trailing_zeros() as usize;
                 bits &= bits - 1;
 
-                // This block is active — compute 4 scores
                 let base_col = block_col * 4;
                 for k in 0..4 {
                     let col = base_col + k;
                     if col < 256 {
-                        // Placeholder: actual score computation goes here
-                        // In production: ZeckF8 distance or BF16 dot product
-                        scores[col] = 1.0;
+                        scores[col] = score_fn(query_row, col as u8);
                         computed += 1;
                     }
                 }
@@ -1627,7 +1866,8 @@ mod tests {
         }
 
         let mut scores = [0.0f32; 256];
-        let computed = hhtl_cascade_search(&palette, 0, &mut scores);
+        let score_fn = |row: u8, col: u8| -> f32 { 1.0 - (row as f32 - col as f32).abs() / 256.0 };
+        let computed = hhtl_cascade_search(&palette, 0, &mut scores, &score_fn);
 
         eprintln!("HHTL cascade: computed {} of 256 scores", computed);
 
@@ -1635,7 +1875,7 @@ mod tests {
         assert_eq!(computed, 4, "Should only compute active block entries");
 
         // Row 128 → block_row = 128/32*8 + (128/4)%8 = 4*8 + 0 = 32
-        let computed2 = hhtl_cascade_search(&palette, 128, &mut scores);
+        let computed2 = hhtl_cascade_search(&palette, 128, &mut scores, &score_fn);
         assert_eq!(computed2, 4);
     }
 
diff --git a/src/hpc/bgz17_bridge.rs b/src/hpc/bgz17_bridge.rs
index 628a95dc..245ed3d6 100644
--- a/src/hpc/bgz17_bridge.rs
+++ b/src/hpc/bgz17_bridge.rs
@@ -34,6 +34,427 @@ pub struct Base17 {
     pub dims: [i16; BASE_DIM],
 }
 
+// ============================================================================
+// Multi-versioned L1 kernel: AVX-512 → AVX2 → scalar. One binary, all ISAs.
+// ============================================================================
+
+type L1Fn = unsafe fn(&[i16; 17], &[i16; 17]) -> u32;
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn l1_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
+    use std::arch::x86_64::*;
+    // Load 16 i16 → 16 i32 via sign-extension
+    let va = _mm512_cvtepi16_epi32(_mm256_loadu_si256(a.as_ptr() as *const __m256i));
+    let vb = _mm512_cvtepi16_epi32(_mm256_loadu_si256(b.as_ptr() as *const __m256i));
+    let diff = _mm512_sub_epi32(va, vb);
+    let abs_diff = _mm512_abs_epi32(diff);
+    let sum16 = _mm512_reduce_add_epi32(abs_diff) as u32;
+    // 17th dim scalar
+    let d16 = (a[16] as i32 - b[16] as i32).unsigned_abs();
+    sum16 + d16
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn l1_avx2(a: &[i16; 17], b: &[i16; 17]) -> u32 {
+    use std::arch::x86_64::*;
+    // Process 8 dims at a time (2 passes of 8 = 16, + 1 scalar)
+    let va0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a.as_ptr() as *const __m128i));
+    let vb0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b.as_ptr() as *const __m128i));
+    let diff0 = _mm256_sub_epi32(va0, vb0);
+    let abs0 = _mm256_abs_epi32(diff0);
+
+    let va1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a[8..].as_ptr() as *const __m128i));
+    let vb1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b[8..].as_ptr() as *const __m128i));
+    let diff1 = _mm256_sub_epi32(va1, vb1);
+    let abs1 = _mm256_abs_epi32(diff1);
+
+    let sum = _mm256_add_epi32(abs0, abs1);
+    // Horizontal sum of 8 i32
+    let hi128 = _mm256_extracti128_si256(sum, 1);
+    let lo128 = _mm256_castsi256_si128(sum);
+    let sum128 = _mm_add_epi32(lo128, hi128);
+    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
+    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
+    let sum16 = _mm_extract_epi32(sum32, 0) as u32;
+    // 17th dim scalar
+    let d16 = (a[16] as i32 - b[16] as i32).unsigned_abs();
+    sum16 + d16
+}
+
+fn l1_scalar(a: &[i16; 17], b: &[i16; 17]) -> u32 {
+    let mut d = 0u32;
+    for i in 0..17 {
+        d += (a[i] as i32 - b[i] as i32).unsigned_abs();
+    }
+    d
+}
+
+static L1_KERNEL: std::sync::LazyLock<L1Fn> = std::sync::LazyLock::new(|| {
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx512f") {
+            return l1_avx512 as L1Fn;
+        }
+        if is_x86_feature_detected!("avx2") {
+            return l1_avx2 as L1Fn;
+        }
+    }
+    l1_scalar as L1Fn
+});
+
+// ============================================================================
+// Multi-versioned L1-weighted kernel: AVX-512 → AVX2 → scalar.
+// ============================================================================
+
+type L1WeightedFn = unsafe fn(&[i16; 17], &[i16; 17]) -> u32;
+
+// Matches l1_weighted_scalar: dim 0 = 20, dims 1..=6 = 3, dims 7..=15 = 1 (dim 16 handled as scalar).
+const WEIGHT_VEC: [i32; 16] = [20, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1];
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn l1_weighted_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
+    use std::arch::x86_64::*;
+    let va = _mm512_cvtepi16_epi32(_mm256_loadu_si256(a.as_ptr() as *const __m256i));
+    let vb = _mm512_cvtepi16_epi32(_mm256_loadu_si256(b.as_ptr() as *const __m256i));
+    let diff = _mm512_sub_epi32(va, vb);
+    let abs_diff = _mm512_abs_epi32(diff);
+    let vw = _mm512_loadu_si512(WEIGHT_VEC.as_ptr() as *const __m512i);
+    let weighted = _mm512_mullo_epi32(abs_diff, vw);
+    let sum16 = _mm512_reduce_add_epi32(weighted) as u32;
+    // 17th dim: weight = 1
+    let d16 = (a[16] as i32 - b[16] as i32).unsigned_abs();
+    sum16 + d16
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn l1_weighted_avx2(a: &[i16; 17], b: &[i16; 17]) -> u32 {
+    use std::arch::x86_64::*;
+    // First 8 dims
+    let va0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a.as_ptr() as *const __m128i));
+    let vb0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b.as_ptr() as *const __m128i));
+    let diff0 = _mm256_sub_epi32(va0, vb0);
+    let abs0 = _mm256_abs_epi32(diff0);
+    let vw0 = _mm256_loadu_si256(WEIGHT_VEC.as_ptr() as *const __m256i);
+    let w0 = _mm256_mullo_epi32(abs0, vw0);
+
+    // Dims 8..16
+    let va1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a[8..].as_ptr() as *const __m128i));
+    let vb1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b[8..].as_ptr() as *const __m128i));
+    let diff1 = _mm256_sub_epi32(va1, vb1);
+    let abs1 = _mm256_abs_epi32(diff1);
+    let vw1 = _mm256_loadu_si256(WEIGHT_VEC[8..].as_ptr() as *const __m256i);
+    let w1 = _mm256_mullo_epi32(abs1, vw1);
+
+    let sum = _mm256_add_epi32(w0, w1);
+    // Horizontal sum
+    let hi128 = _mm256_extracti128_si256(sum, 1);
+    let lo128 = _mm256_castsi256_si128(sum);
+    let sum128 = _mm_add_epi32(lo128, hi128);
+    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
+    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
+    let s = _mm_extract_epi32(sum32, 0) as u32;
+    // 17th dim: weight = 1
+    let d16 = (a[16] as i32 - b[16] as i32).unsigned_abs();
+    s + d16
+}
+
+fn l1_weighted_scalar(a: &[i16; 17], b: &[i16; 17]) -> u32 {
+    let mut d = 0u32;
+    for i in 0..17 {
+        let diff = (a[i] as i32 - b[i] as i32).unsigned_abs();
+        let weight = if i == 0 { 20 } else if i < 7 { 3 } else { 1 };
+        d += diff * weight;
+    }
+    d
+}
+
+static L1_WEIGHTED_KERNEL: std::sync::LazyLock<L1WeightedFn> = std::sync::LazyLock::new(|| {
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx512f") {
+            return l1_weighted_avx512 as L1WeightedFn;
+        }
+        if is_x86_feature_detected!("avx2") {
+            return l1_weighted_avx2 as L1WeightedFn;
+        }
+    }
+    l1_weighted_scalar as L1WeightedFn
+});
+
+// ============================================================================
+// Multi-versioned sign_agreement kernel: AVX-512 → AVX2 → scalar.
+// ============================================================================
+
+type SignAgreementFn = unsafe fn(&[i16; 17], &[i16; 17]) -> u32;
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn sign_agreement_avx512(a: &[i16; 17], b: &[i16; 17]) -> u32 {
+    use std::arch::x86_64::*;
+    let va = _mm512_cvtepi16_epi32(_mm256_loadu_si256(a.as_ptr() as *const __m256i));
+    let vb = _mm512_cvtepi16_epi32(_mm256_loadu_si256(b.as_ptr() as *const __m256i));
+    // XOR: same sign → non-negative, different sign → negative
+    let xor = _mm512_xor_si512(va, vb);
+    // Compare >= 0: mask bit set where same sign
+    let zero = _mm512_setzero_si512();
+    let mask = _mm512_cmpge_epi32_mask(xor, zero);
+    let count16 = mask.count_ones();
+    // 17th dim
+    let same17 = if (a[16] >= 0) == (b[16] >= 0) { 1u32 } else { 0u32 };
+    count16 + same17
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn sign_agreement_avx2(a: &[i16; 17], b: &[i16; 17]) -> u32 {
+    use std::arch::x86_64::*;
+    // First 8 dims
+    let va0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a.as_ptr() as *const __m128i));
+    let vb0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b.as_ptr() as *const __m128i));
+    let xor0 = _mm256_xor_si256(va0, vb0);
+    let zero = _mm256_setzero_si256();
+    let neg0 = _mm256_cmpgt_epi32(zero, xor0); // -1 where xor < 0
+    // movemask_ps on the reinterpreted float gives 8 bits, one per 32-bit lane
+    let mask0 = _mm256_movemask_ps(_mm256_castsi256_ps(neg0)) as u32;
+    let same0 = 8 - mask0.count_ones();
+
+    // Dims 8..16
+    let va1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a[8..].as_ptr() as *const __m128i));
+    let vb1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b[8..].as_ptr() as *const __m128i));
+    let xor1 = _mm256_xor_si256(va1, vb1);
+    let neg1 = _mm256_cmpgt_epi32(zero, xor1);
+    let mask1 = _mm256_movemask_ps(_mm256_castsi256_ps(neg1)) as u32;
+    let same1 = 8 - mask1.count_ones();
+
+    // 17th dim
+    let same17 = if (a[16] >= 0) == (b[16] >= 0) { 1u32 } else { 0u32 };
+    same0 + same1 + same17
+}
+
+fn sign_agreement_scalar(a: &[i16; 17], b: &[i16; 17]) -> u32 {
+    let mut count = 0u32;
+    for i in 0..17 {
+        if (a[i] >= 0) == (b[i] >= 0) {
+            count += 1;
+        }
+    }
+    count
+}
+
+static SIGN_AGREEMENT_KERNEL: std::sync::LazyLock<SignAgreementFn> =
+    std::sync::LazyLock::new(|| {
+        #[cfg(target_arch = "x86_64")]
+        {
+            if is_x86_feature_detected!("avx512f") {
+                return sign_agreement_avx512 as SignAgreementFn;
+            }
+            if is_x86_feature_detected!("avx2") {
+                return sign_agreement_avx2 as SignAgreementFn;
+            }
+        }
+        sign_agreement_scalar as SignAgreementFn
+    });
+
+// ============================================================================
+// Multi-versioned xor_bind kernel: AVX-512 → AVX2 → scalar.
+// ============================================================================
+
+type XorBindFn = unsafe fn(&[i16; 17], &[i16; 17]) -> [i16; 17];
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn xor_bind_avx512(a: &[i16; 17], b: &[i16; 17]) -> [i16; 17] {
+    use std::arch::x86_64::*;
+    // Load 16 i16 as i32, XOR, store back as i16
+    let va = _mm512_cvtepi16_epi32(_mm256_loadu_si256(a.as_ptr() as *const __m256i));
+    let vb = _mm512_cvtepi16_epi32(_mm256_loadu_si256(b.as_ptr() as *const __m256i));
+    let xored = _mm512_xor_si512(va, vb);
+    // Convert back to i16: truncate i32 -> i16 via pmovdw
+    let packed = _mm512_cvtepi32_epi16(xored);
+    let mut dims = [0i16; 17];
+    _mm256_storeu_si256(dims.as_mut_ptr() as *mut __m256i, packed);
+    dims[16] = (a[16] as u16 ^ b[16] as u16) as i16;
+    dims
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn xor_bind_avx2(a: &[i16; 17], b: &[i16; 17]) -> [i16; 17] {
+    use std::arch::x86_64::*;
+    // First 8 dims: load as i32, XOR, narrow back
+    let va0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a.as_ptr() as *const __m128i));
+    let vb0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b.as_ptr() as *const __m128i));
+    let xor0 = _mm256_xor_si256(va0, vb0);
+
+    // Dims 8..16
+    let va1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(a[8..].as_ptr() as *const __m128i));
+    let vb1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(b[8..].as_ptr() as *const __m128i));
+    let xor1 = _mm256_xor_si256(va1, vb1);
+
+    // Extract results back to i16
+    let mut dims = [0i16; 17];
+    // We need the low 16 bits of each i32 lane. _mm256_packs_epi32 would
+    // saturate out-of-range lanes, so extract manually and truncate instead.
+    let arr0: [i32; 8] = core::mem::transmute(xor0);
+    let arr1: [i32; 8] = core::mem::transmute(xor1);
+    for i in 0..8 {
+        dims[i] = arr0[i] as i16;
+    }
+    for i in 0..8 {
+        dims[8 + i] = arr1[i] as i16;
+    }
+    dims[16] = (a[16] as u16 ^ b[16] as u16) as i16;
+    dims
+}
+
+fn xor_bind_scalar(a: &[i16; 17], b: &[i16; 17]) -> [i16; 17] {
+    let mut dims = [0i16; 17];
+    for i in 0..17 {
+        dims[i] = (a[i] as u16 ^ b[i] as u16) as i16;
+    }
+    dims
+}
+
+static XOR_BIND_KERNEL: std::sync::LazyLock<XorBindFn> = std::sync::LazyLock::new(|| {
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx512f") {
+            return xor_bind_avx512 as XorBindFn;
+        }
+        if is_x86_feature_detected!("avx2") {
+            return xor_bind_avx2 as XorBindFn;
+        }
+    }
+    xor_bind_scalar as XorBindFn
+});
+
+// ============================================================================
+// Multi-versioned inject_noise kernel: AVX-512 → AVX2 → scalar.
+// ============================================================================
+
+type InjectNoiseFn = unsafe fn(&[i16; 17], i16, u64) -> [i16; 17];
+
+/// Deterministic PRNG step (PCG-like LCG).
+#[inline(always)]
+fn prng_step(state: &mut u64) {
+    *state = state
+        .wrapping_mul(6364136223846793005)
+        .wrapping_add(1442695040888963407);
+}
+
+/// Compute noise value from PRNG state.
+#[inline(always)]
+fn noise_from_state(state: u64, scale: i16) -> i16 {
+    ((state >> 33) as i16).wrapping_mul(scale) >> 15
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn inject_noise_avx512(dims: &[i16; 17], scale: i16, seed: u64) -> [i16; 17] {
+    use std::arch::x86_64::*;
+    // Generate 16 noise values via PRNG
+    let mut state = seed;
+    let mut noise_vals = [0i32; 16];
+    for i in 0..16 {
+        prng_step(&mut state);
+        noise_vals[i] = noise_from_state(state, scale) as i32;
+    }
+    // Load dims as i32
+    let vd = _mm512_cvtepi16_epi32(_mm256_loadu_si256(dims.as_ptr() as *const __m256i));
+    let vn = _mm512_loadu_si512(noise_vals.as_ptr() as *const __m512i);
+    // Saturating add: add then clamp to i16 range
+    let sum = _mm512_add_epi32(vd, vn);
+    let lo = _mm512_set1_epi32(-32768);
+    let hi = _mm512_set1_epi32(32767);
+    let clamped = _mm512_max_epi32(_mm512_min_epi32(sum, hi), lo);
+    let packed = _mm512_cvtepi32_epi16(clamped);
+    let mut result = [0i16; 17];
+    _mm256_storeu_si256(result.as_mut_ptr() as *mut __m256i, packed);
+    // 17th dim
+    prng_step(&mut state);
+    let n16 = noise_from_state(state, scale);
+    result[16] = dims[16].saturating_add(n16);
+    result
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn inject_noise_avx2(dims: &[i16; 17], scale: i16, seed: u64) -> [i16; 17] {
+    use std::arch::x86_64::*;
+    let mut state = seed;
+    // First 8 dims
+    let mut noise0 = [0i32; 8];
+    for i in 0..8 {
+        prng_step(&mut state);
+        noise0[i] = noise_from_state(state, scale) as i32;
+    }
+    let vd0 = _mm256_cvtepi16_epi32(_mm_loadu_si128(dims.as_ptr() as *const __m128i));
+    let vn0 = _mm256_loadu_si256(noise0.as_ptr() as *const __m256i);
+    let sum0 = _mm256_add_epi32(vd0, vn0);
+
+    // Dims 8..16
+    let mut noise1 = [0i32; 8];
+    for i in 0..8 {
+        prng_step(&mut state);
+        noise1[i] = noise_from_state(state, scale) as i32;
+    }
+    let vd1 = _mm256_cvtepi16_epi32(_mm_loadu_si128(dims[8..].as_ptr() as *const __m128i));
+    let vn1 = _mm256_loadu_si256(noise1.as_ptr() as *const __m256i);
+    let sum1 = _mm256_add_epi32(vd1, vn1);
+
+    // Clamp and extract
+    let lo = _mm256_set1_epi32(-32768);
+    let hi = _mm256_set1_epi32(32767);
+    let c0 = _mm256_max_epi32(_mm256_min_epi32(sum0, hi), lo);
+    let c1 = _mm256_max_epi32(_mm256_min_epi32(sum1, hi), lo);
+
+    let arr0: [i32; 8] = core::mem::transmute(c0);
+    let arr1: [i32; 8] = core::mem::transmute(c1);
+    let mut result = [0i16; 17];
+    for i in 0..8 {
+        result[i] = arr0[i] as i16;
+    }
+    for i in 0..8 {
+        result[8 + i] = arr1[i] as i16;
+    }
+    // 17th dim
+    prng_step(&mut state);
+    let n16 = noise_from_state(state, scale);
+    result[16] = dims[16].saturating_add(n16);
+    result
+}
+
+fn inject_noise_scalar(dims: &[i16; 17], scale: i16, seed: u64) -> [i16; 17] {
+    let mut result = [0i16; 17];
+    result.copy_from_slice(dims);
+    let mut state = seed;
+    for d in 0..17 {
+        prng_step(&mut state);
+        let noise = noise_from_state(state, scale);
+        result[d] = result[d].saturating_add(noise);
+    }
+    result
+}
+
+static INJECT_NOISE_KERNEL: std::sync::LazyLock<InjectNoiseFn> =
+    std::sync::LazyLock::new(|| {
+        #[cfg(target_arch = "x86_64")]
+        {
+            if is_x86_feature_detected!("avx512f") {
+                return inject_noise_avx512 as InjectNoiseFn;
+            }
+            if is_x86_feature_detected!("avx2") {
+                return inject_noise_avx2 as InjectNoiseFn;
+            }
+        }
+        inject_noise_scalar as InjectNoiseFn
+    });
+
 /// SPO triple of Base17 patterns. 102 bytes.
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub struct SpoBase17 {
@@ -89,14 +510,14 @@ impl Base17 {
         Base17 { dims: [0i16; BASE_DIM] }
     }
 
-    /// L1 (Manhattan) distance.
+    /// L1 (Manhattan) distance — multi-versioned kernel.
+    ///
+    /// Runtime dispatch via LazyLock: AVX-512 → AVX2 → scalar.
+    /// One binary serves all ISAs.
     #[inline]
     pub fn l1(&self, other: &Base17) -> u32 {
-        let mut d = 0u32;
-        for i in 0..BASE_DIM {
-            d += (self.dims[i] as i32 - other.dims[i] as i32).unsigned_abs();
-        }
-        d
+        // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
+        unsafe { L1_KERNEL(&self.dims, &other.dims) }
     }
 
     /// PCDVQ-informed L1: weight sign dimension 20x over mantissa.
@@ -105,40 +526,32 @@ impl Base17 {
     /// quantization than magnitude. BF16 decomposition maps to polar:
     /// dim 0 = sign (direction), dims 1-6 = exponent (magnitude scale),
     /// dims 7-16 = mantissa (fine detail).
+    /// PCDVQ-weighted L1 via SIMD: sign=20x, magnitude=3x, detail=1x.
+    ///
+    /// Runtime dispatch via LazyLock: AVX-512 -> AVX2 -> scalar.
     #[inline]
     pub fn l1_weighted(&self, other: &Base17) -> u32 {
-        let mut d = 0u32;
-        for i in 0..BASE_DIM {
-            let diff = (self.dims[i] as i32 - other.dims[i] as i32).unsigned_abs();
-            let weight = if i == 0 { 20 } else if i < 7 { 3 } else { 1 };
-            d += diff * weight;
-        }
-        d
+        // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
+        unsafe { L1_WEIGHTED_KERNEL(&self.dims, &other.dims) }
     }
 
-    /// Sign-bit agreement (out of 17).
+    /// Sign-bit agreement (out of 17) — multi-versioned kernel.
+    ///
+    /// Runtime dispatch via LazyLock: AVX-512 -> AVX2 -> scalar.
     #[inline]
     pub fn sign_agreement(&self, other: &Base17) -> u32 {
-        let mut a = 0u32;
-        for i in 0..BASE_DIM {
-            if (self.dims[i] >= 0) == (other.dims[i] >= 0) {
-                a += 1;
-            }
-        }
-        a
+        // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
+        unsafe { SIGN_AGREEMENT_KERNEL(&self.dims, &other.dims) }
     }
 
     /// XOR bind: path composition in hyperdimensional space.
-    ///
-    /// Bitwise XOR on each i16 dimension (reinterpreted as u16).
     /// Self-inverse: `a.xor_bind(&b).xor_bind(&b) == a`.
-    /// Identity: `a.xor_bind(&Base17::zero()) == a`.
+    ///
+    /// Runtime dispatch via LazyLock: AVX-512 -> AVX2 -> scalar.
     #[inline]
     pub fn xor_bind(&self, other: &Base17) -> Base17 {
-        let mut dims = [0i16; BASE_DIM];
-        for i in 0..BASE_DIM {
-            dims[i] = (self.dims[i] as u16 ^ other.dims[i] as u16) as i16;
-        }
+        // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
+        let dims = unsafe { XOR_BIND_KERNEL(&self.dims, &other.dims) };
         Base17 { dims }
     }
 
@@ -179,17 +592,13 @@ impl Base17 {
     /// #6 Thought Randomization — calibrated noise injection on Base17.
     /// Flip dims with magnitude proportional to coefficient of variation.
     /// Science: Kirkpatrick et al. (1983), Rahimi & Recht (2007).
+    ///
+    /// Runtime dispatch via LazyLock: AVX-512 -> AVX2 -> scalar.
     pub fn inject_noise(&self, cv: f32, seed: u64) -> Base17 {
-        let mut noisy = self.clone();
-        // Simple deterministic PRNG from seed
-        let mut state = seed;
         let scale = (cv * 32767.0).min(32767.0) as i16;
-        for d in 0..17 {
-            state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
-            let noise = ((state >> 33) as i16).wrapping_mul(scale) >> 15;
-            noisy.dims[d] = noisy.dims[d].saturating_add(noise);
-        }
-        noisy
+        // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
+        let dims = unsafe { INJECT_NOISE_KERNEL(&self.dims, scale, seed) };
+        Base17 { dims }
     }
 
     /// Serialize to 34 bytes (little-endian).
diff --git a/src/hpc/palette_distance.rs b/src/hpc/palette_distance.rs
index 34dea154..6dfe4b9a 100644
--- a/src/hpc/palette_distance.rs
+++ b/src/hpc/palette_distance.rs
@@ -12,6 +12,105 @@ use super::bgz17_bridge::{Base17, PaletteEdge, SpoBase17};
 const MAX_PALETTE_SIZE: usize = 256;
 const BASE_DIM: usize = 17;
 
+// ============================================================================
+// Multi-versioned nearest kernel: AVX-512 → AVX2 → scalar.
+// The inner l1() is already SIMD-dispatched; this unrolls the outer loop.
+// ============================================================================
+
+type NearestFn = unsafe fn(&[Base17], &Base17) -> u8;
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn nearest_avx512(entries: &[Base17], query: &Base17) -> u8 {
+    let mut best_idx = 0u8;
+    let mut best_dist = u32::MAX;
+    // 4-way unroll for better branch prediction and ILP
+    let chunks = entries.len() / 4;
+    for c in 0..chunks {
+        let base = c * 4;
+        let d0 = query.l1(&entries[base]);
+        let d1 = query.l1(&entries[base + 1]);
+        let d2 = query.l1(&entries[base + 2]);
+        let d3 = query.l1(&entries[base + 3]);
+        // Find min of 4
+        let (mut min_d, mut min_i) = (d0, 0usize);
+        if d1 < min_d { min_d = d1; min_i = 1; }
+        if d2 < min_d { min_d = d2; min_i = 2; }
+        if d3 < min_d { min_d = d3; min_i = 3; }
+        if min_d < best_dist {
+            best_dist = min_d;
+            best_idx = (base + min_i) as u8;
+        }
+    }
+    // Remainder
+    for i in (chunks * 4)..entries.len() {
+        let d = query.l1(&entries[i]);
+        if d < best_dist {
+            best_dist = d;
+            best_idx = i as u8;
+        }
+    }
+    best_idx
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn nearest_avx2(entries: &[Base17], query: &Base17) -> u8 {
+    let mut best_idx = 0u8;
+    let mut best_dist = u32::MAX;
+    // 4-way unroll (l1 already dispatches to AVX2 internally)
+    let chunks = entries.len() / 4;
+    for c in 0..chunks {
+        let base = c * 4;
+        let d0 = query.l1(&entries[base]);
+        let d1 = query.l1(&entries[base + 1]);
+        let d2 = query.l1(&entries[base + 2]);
+        let d3 = query.l1(&entries[base + 3]);
+        let (mut min_d, mut min_i) = (d0, 0usize);
+        if d1 < min_d { min_d = d1; min_i = 1; }
+        if d2 < min_d { min_d = d2; min_i = 2; }
+        if d3 < min_d { min_d = d3; min_i = 3; }
+        if min_d < best_dist {
+            best_dist = min_d;
+            best_idx = (base + min_i) as u8;
+        }
+    }
+    for i in (chunks * 4)..entries.len() {
+        let d = query.l1(&entries[i]);
+        if d < best_dist {
+            best_dist = d;
+            best_idx = i as u8;
+        }
+    }
+    best_idx
+}
+
+fn nearest_scalar(entries: &[Base17], query: &Base17) -> u8 {
+    let mut best_idx = 0u8;
+    let mut best_dist = u32::MAX;
+    for (i, entry) in entries.iter().enumerate() {
+        let d = query.l1(entry);
+        if d < best_dist {
+            best_dist = d;
+            best_idx = i as u8;
+        }
+    }
+    best_idx
+}
+
+static NEAREST_KERNEL: std::sync::LazyLock<NearestFn> = std::sync::LazyLock::new(|| {
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx512f") {
+            return nearest_avx512 as NearestFn;
+        }
+        if is_x86_feature_detected!("avx2") {
+            return nearest_avx2 as NearestFn;
+        }
+    }
+    nearest_scalar as NearestFn
+});
+
 /// A palette codebook: up to 256 archetypal Base17 patterns.
 #[derive(Clone, Debug)]
 pub struct Palette {
@@ -57,17 +156,12 @@ impl Palette {
     }
 
     /// Find the nearest palette entry to a given base pattern. Returns index.
+    ///
+    /// Runtime dispatch via LazyLock: AVX-512 → AVX2 → scalar.
+    /// Inner l1() is already SIMD; outer loop is 4-way unrolled.
     pub fn nearest(&self, query: &Base17) -> u8 {
-        let mut best_idx = 0u8;
-        let mut best_dist = u32::MAX;
-        for (i, entry) in self.entries.iter().enumerate() {
-            let d = query.l1(entry);
-            if d < best_dist {
-                best_dist = d;
-                best_idx = i as u8;
-            }
-        }
-        best_idx
+        // SAFETY: LazyLock guarantees the selected kernel matches CPU features.
+        unsafe { NEAREST_KERNEL(&self.entries, query) }
     }
 
     /// Encode an SpoBase17 edge to palette indices.
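
Illustrative sketch (not part of the diff above): every kernel in this change follows the same shape, detect the widest supported ISA once, cache a function pointer in a LazyLock, and call through it. The names below (PopcountFn, popcount_avx2, POPCOUNT_KERNEL) are hypothetical and exist only to show the pattern in a self-contained form; the explicit deref when calling is used here to keep the sketch unambiguous.

use std::sync::LazyLock;

type PopcountFn = unsafe fn(&[u64; 8]) -> u32;

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn popcount_avx2(words: &[u64; 8]) -> u32 {
    // #[target_feature] makes this fn unsafe to call; the selector below
    // only ever returns it when AVX2 has been detected at runtime.
    words.iter().map(|w| w.count_ones()).sum()
}

fn popcount_scalar(words: &[u64; 8]) -> u32 {
    words.iter().map(|w| w.count_ones()).sum()
}

static POPCOUNT_KERNEL: LazyLock<PopcountFn> = LazyLock::new(|| {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return popcount_avx2 as PopcountFn;
        }
    }
    popcount_scalar as PopcountFn
});

fn main() {
    let words = [0xFFu64; 8];
    // SAFETY: the cached kernel matches the detected CPU features.
    let ones = unsafe { (*POPCOUNT_KERNEL)(&words) };
    assert_eq!(ones, 64);
}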