diff --git a/src/hpc/aabb.rs b/src/hpc/aabb.rs new file mode 100644 index 00000000..100e9f84 --- /dev/null +++ b/src/hpc/aabb.rs @@ -0,0 +1,461 @@ +//! Axis-aligned bounding box batch operations. +//! +//! Provides SIMD-accelerated batch intersection, expansion, and distance +//! queries for entity collision detection. + +/// Axis-aligned bounding box stored as 6 `f32` values. +/// +/// # Examples +/// +/// ``` +/// use ndarray::hpc::aabb::Aabb; +/// +/// let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); +/// let b = Aabb::new([0.5, 0.5, 0.5], [1.5, 1.5, 1.5]); +/// assert!(a.intersects(&b)); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq)] +#[repr(C)] +pub struct Aabb { + pub min: [f32; 3], + pub max: [f32; 3], +} + +impl Aabb { + /// Create a new AABB from min and max corners. + #[inline] + pub fn new(min: [f32; 3], max: [f32; 3]) -> Self { + Self { min, max } + } + + /// Test if this AABB intersects another (inclusive on boundaries). + #[inline] + pub fn intersects(&self, other: &Aabb) -> bool { + self.min[0] <= other.max[0] + && self.max[0] >= other.min[0] + && self.min[1] <= other.max[1] + && self.max[1] >= other.min[1] + && self.min[2] <= other.max[2] + && self.max[2] >= other.min[2] + } + + /// Expand the AABB by `(dx, dy, dz)` in both directions per axis. + #[inline] + pub fn expand(&self, dx: f32, dy: f32, dz: f32) -> Self { + Self { + min: [self.min[0] - dx, self.min[1] - dy, self.min[2] - dz], + max: [self.max[0] + dx, self.max[1] + dy, self.max[2] + dz], + } + } + + /// Test if a point is inside (or on the boundary of) this AABB. + #[inline] + pub fn contains_point(&self, point: [f32; 3]) -> bool { + point[0] >= self.min[0] + && point[0] <= self.max[0] + && point[1] >= self.min[1] + && point[1] <= self.max[1] + && point[2] >= self.min[2] + && point[2] <= self.max[2] + } + + /// Volume of the AABB. Returns 0 if any dimension is degenerate. + #[inline] + pub fn volume(&self) -> f32 { + let dx = (self.max[0] - self.min[0]).max(0.0); + let dy = (self.max[1] - self.min[1]).max(0.0); + let dz = (self.max[2] - self.min[2]).max(0.0); + dx * dy * dz + } + + /// Center point of the AABB. + #[inline] + pub fn center(&self) -> [f32; 3] { + [ + (self.min[0] + self.max[0]) * 0.5, + (self.min[1] + self.max[1]) * 0.5, + (self.min[2] + self.max[2]) * 0.5, + ] + } +} + +/// Squared distance from a point to the nearest point on an AABB. +#[inline] +fn sq_dist_point_aabb(point: [f32; 3], aabb: &Aabb) -> f32 { + let mut dist_sq = 0.0f32; + for axis in 0..3 { + let v = point[axis]; + if v < aabb.min[axis] { + let d = aabb.min[axis] - v; + dist_sq += d * d; + } else if v > aabb.max[axis] { + let d = v - aabb.max[axis]; + dist_sq += d * d; + } + } + dist_sq +} + +/// Test one AABB against N candidates. Returns a `Vec` indicating +/// which candidates intersect the query. +pub fn aabb_intersect_batch(query: &Aabb, candidates: &[Aabb]) -> Vec { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("sse4.1") { + // SAFETY: sse4.1 detected, slice access within bounds. + unsafe { + return aabb_intersect_batch_sse41(query, candidates); + } + } + } + + aabb_intersect_batch_scalar(query, candidates) +} + +fn aabb_intersect_batch_scalar(query: &Aabb, candidates: &[Aabb]) -> Vec { + candidates.iter().map(|c| query.intersects(c)).collect() +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "sse4.1")] +unsafe fn aabb_intersect_batch_sse41(query: &Aabb, candidates: &[Aabb]) -> Vec { + use core::arch::x86_64::*; + + // Load query min/max into SSE registers (only need xyz, ignore w). 
+ let q_min = _mm_set_ps(0.0, query.min[2], query.min[1], query.min[0]); + let q_max = _mm_set_ps(f32::MAX, query.max[2], query.max[1], query.max[0]); + + let mut result = Vec::with_capacity(candidates.len()); + for c in candidates { + let c_min = _mm_set_ps(0.0, c.min[2], c.min[1], c.min[0]); + let c_max = _mm_set_ps(f32::MAX, c.max[2], c.max[1], c.max[0]); + + // q.min <= c.max AND q.max >= c.min (per component) + let le = _mm_cmple_ps(q_min, c_max); // q_min[i] <= c_max[i] + let ge = _mm_cmpge_ps(q_max, c_min); // q_max[i] >= c_min[i] + let both = _mm_and_ps(le, ge); + // All 4 lanes must be true (lane 3 is always true due to sentinel values). + let mask = _mm_movemask_ps(both); + result.push(mask == 0xF); + } + result +} + +/// Expand all AABBs in-place by `(dx, dy, dz)` in both directions per axis. +pub fn aabb_expand_batch(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("sse2") { + // SAFETY: sse2 detected, operating on mutable slice in-bounds. + unsafe { + aabb_expand_batch_sse2(aabbs, dx, dy, dz); + return; + } + } + } + + aabb_expand_batch_scalar(aabbs, dx, dy, dz); +} + +fn aabb_expand_batch_scalar(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) { + for a in aabbs.iter_mut() { + a.min[0] -= dx; + a.min[1] -= dy; + a.min[2] -= dz; + a.max[0] += dx; + a.max[1] += dy; + a.max[2] += dz; + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "sse2")] +unsafe fn aabb_expand_batch_sse2(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) { + use core::arch::x86_64::*; + + let delta_min = _mm_set_ps(0.0, dz, dy, dx); + let delta_max = _mm_set_ps(0.0, dz, dy, dx); + + for a in aabbs.iter_mut() { + let min_v = _mm_set_ps(0.0, a.min[2], a.min[1], a.min[0]); + let max_v = _mm_set_ps(0.0, a.max[2], a.max[1], a.max[0]); + + let new_min = _mm_sub_ps(min_v, delta_min); + let new_max = _mm_add_ps(max_v, delta_max); + + // Store back. We cannot use _mm_storeu_ps directly into [f32;3], + // so extract components. + let mut min_arr = [0.0f32; 4]; + let mut max_arr = [0.0f32; 4]; + _mm_storeu_ps(min_arr.as_mut_ptr(), new_min); + _mm_storeu_ps(max_arr.as_mut_ptr(), new_max); + + a.min = [min_arr[0], min_arr[1], min_arr[2]]; + a.max = [max_arr[0], max_arr[1], max_arr[2]]; + } +} + +/// Squared distance from a point to the nearest point on each AABB. +pub fn aabb_squared_distance_batch(point: [f32; 3], aabbs: &[Aabb]) -> Vec { + aabbs.iter().map(|a| sq_dist_point_aabb(point, a)).collect() +} + +/// Filter AABBs by maximum squared distance from a point. Returns indices +/// of AABBs whose nearest point is within `max_sq_dist` of `point`. 
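+///
+/// # Examples
+///
+/// Minimal usage sketch (module path as in the `Aabb` doctest above):
+///
+/// ```
+/// use ndarray::hpc::aabb::{Aabb, aabb_filter_by_distance};
+///
+/// let aabbs = vec![
+///     Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),    // contains the query point
+///     Aabb::new([9.0, 9.0, 9.0], [10.0, 10.0, 10.0]), // ~216 squared units away
+/// ];
+/// let near = aabb_filter_by_distance([0.5, 0.5, 0.5], &aabbs, 4.0);
+/// assert_eq!(near, vec![0]);
+/// ```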
+pub fn aabb_filter_by_distance( + point: [f32; 3], + aabbs: &[Aabb], + max_sq_dist: f32, +) -> Vec { + let distances = aabb_squared_distance_batch(point, aabbs); + distances + .iter() + .enumerate() + .filter(|(_, &d)| d <= max_sq_dist) + .map(|(i, _)| i) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn approx_eq(a: f32, b: f32) -> bool { + (a - b).abs() < 1e-5 + } + + // ---------- Aabb unit tests ---------- + + #[test] + fn test_intersects_overlap() { + let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]); + let b = Aabb::new([1.0, 1.0, 1.0], [3.0, 3.0, 3.0]); + assert!(a.intersects(&b)); + assert!(b.intersects(&a)); + } + + #[test] + fn test_intersects_touching() { + let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); + let b = Aabb::new([1.0, 0.0, 0.0], [2.0, 1.0, 1.0]); + assert!(a.intersects(&b)); // boundary inclusive + } + + #[test] + fn test_no_intersect() { + let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); + let b = Aabb::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0]); + assert!(!a.intersects(&b)); + } + + #[test] + fn test_no_intersect_single_axis() { + let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); + let b = Aabb::new([0.0, 0.0, 1.5], [1.0, 1.0, 2.5]); // only z separates + assert!(!a.intersects(&b)); + } + + #[test] + fn test_contains_point() { + let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); + assert!(a.contains_point([0.5, 0.5, 0.5])); + assert!(a.contains_point([0.0, 0.0, 0.0])); // boundary + assert!(a.contains_point([1.0, 1.0, 1.0])); // boundary + assert!(!a.contains_point([1.5, 0.5, 0.5])); + } + + #[test] + fn test_expand() { + let a = Aabb::new([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]); + let expanded = a.expand(0.5, 1.0, 1.5); + assert!(approx_eq(expanded.min[0], 0.5)); + assert!(approx_eq(expanded.min[1], 1.0)); + assert!(approx_eq(expanded.min[2], 1.5)); + assert!(approx_eq(expanded.max[0], 4.5)); + assert!(approx_eq(expanded.max[1], 6.0)); + assert!(approx_eq(expanded.max[2], 7.5)); + } + + #[test] + fn test_volume() { + let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 3.0, 4.0]); + assert!(approx_eq(a.volume(), 24.0)); + } + + #[test] + fn test_volume_degenerate() { + let a = Aabb::new([0.0, 0.0, 0.0], [0.0, 3.0, 4.0]); + assert!(approx_eq(a.volume(), 0.0)); + } + + #[test] + fn test_center() { + let a = Aabb::new([1.0, 2.0, 3.0], [5.0, 6.0, 7.0]); + let c = a.center(); + assert!(approx_eq(c[0], 3.0)); + assert!(approx_eq(c[1], 4.0)); + assert!(approx_eq(c[2], 5.0)); + } + + // ---------- Batch tests ---------- + + #[test] + fn test_intersect_batch() { + let query = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); + let candidates = vec![ + Aabb::new([0.5, 0.5, 0.5], [1.5, 1.5, 1.5]), // yes + Aabb::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0]), // no + Aabb::new([-1.0, -1.0, -1.0], [0.5, 0.5, 0.5]), // yes + Aabb::new([1.0, 1.0, 1.0], [2.0, 2.0, 2.0]), // yes (touching) + ]; + let results = aabb_intersect_batch(&query, &candidates); + assert_eq!(results, vec![true, false, true, true]); + } + + #[test] + fn test_intersect_batch_empty() { + let query = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); + let results = aabb_intersect_batch(&query, &[]); + assert!(results.is_empty()); + } + + #[test] + fn test_intersect_batch_scalar_parity() { + let query = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); + let candidates: Vec = (0..100) + .map(|i| { + let f = i as f32 * 0.1; + Aabb::new([f - 0.5, f - 0.5, f - 0.5], [f + 0.5, f + 0.5, f + 0.5]) + }) + .collect(); + + let batch = aabb_intersect_batch(&query, &candidates); + let scalar: Vec = candidates.iter().map(|c| 
query.intersects(c)).collect(); + assert_eq!(batch, scalar); + } + + #[test] + fn test_expand_batch() { + let mut aabbs = vec![ + Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), + Aabb::new([5.0, 5.0, 5.0], [6.0, 6.0, 6.0]), + ]; + aabb_expand_batch(&mut aabbs, 0.5, 1.0, 1.5); + + assert!(approx_eq(aabbs[0].min[0], -0.5)); + assert!(approx_eq(aabbs[0].max[2], 2.5)); + assert!(approx_eq(aabbs[1].min[1], 4.0)); + assert!(approx_eq(aabbs[1].max[0], 6.5)); + } + + #[test] + fn test_expand_batch_scalar_parity() { + let base: Vec = (0..50) + .map(|i| { + let f = i as f32; + Aabb::new([f, f, f], [f + 1.0, f + 2.0, f + 3.0]) + }) + .collect(); + + let mut batch = base.clone(); + aabb_expand_batch(&mut batch, 0.25, 0.5, 0.75); + + for (i, orig) in base.iter().enumerate() { + let expected = orig.expand(0.25, 0.5, 0.75); + for axis in 0..3 { + assert!( + approx_eq(batch[i].min[axis], expected.min[axis]), + "min mismatch at [{},{}]", + i, + axis + ); + assert!( + approx_eq(batch[i].max[axis], expected.max[axis]), + "max mismatch at [{},{}]", + i, + axis + ); + } + } + } + + // ---------- Distance tests ---------- + + #[test] + fn test_squared_distance_inside() { + let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]); + let d = sq_dist_point_aabb([1.0, 1.0, 1.0], &a); + assert!(approx_eq(d, 0.0)); + } + + #[test] + fn test_squared_distance_outside() { + let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); + // Point is 1 unit away on x-axis + let d = sq_dist_point_aabb([2.0, 0.5, 0.5], &a); + assert!(approx_eq(d, 1.0)); + } + + #[test] + fn test_squared_distance_corner() { + let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]); + // Point at (2,2,2): distance to corner (1,1,1) = sqrt(3), sq=3 + let d = sq_dist_point_aabb([2.0, 2.0, 2.0], &a); + assert!(approx_eq(d, 3.0)); + } + + #[test] + fn test_squared_distance_batch() { + let aabbs = vec![ + Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), + Aabb::new([10.0, 10.0, 10.0], [11.0, 11.0, 11.0]), + ]; + let dists = aabb_squared_distance_batch([0.5, 0.5, 0.5], &aabbs); + assert!(approx_eq(dists[0], 0.0)); // inside + assert!(dists[1] > 200.0); // far away + } + + #[test] + fn test_filter_by_distance() { + let aabbs = vec![ + Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), // 0: dist=0 + Aabb::new([2.0, 0.0, 0.0], [3.0, 1.0, 1.0]), // 1: nearest pt (2,0.5,0.5), dist=1.5, sq=2.25 + Aabb::new([10.0, 10.0, 10.0], [11.0, 11.0, 11.0]),// 2: far + ]; + let indices = aabb_filter_by_distance([0.5, 0.5, 0.5], &aabbs, 5.0); + assert_eq!(indices, vec![0, 1]); + } + + #[test] + fn test_filter_by_distance_none() { + let aabbs = vec![ + Aabb::new([100.0, 100.0, 100.0], [101.0, 101.0, 101.0]), + ]; + let indices = aabb_filter_by_distance([0.0, 0.0, 0.0], &aabbs, 1.0); + assert!(indices.is_empty()); + } + + #[test] + fn test_filter_by_distance_all() { + let aabbs = vec![ + Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), + Aabb::new([0.5, 0.5, 0.5], [1.5, 1.5, 1.5]), + ]; + let indices = aabb_filter_by_distance([0.7, 0.7, 0.7], &aabbs, 100.0); + assert_eq!(indices, vec![0, 1]); + } + + #[test] + fn test_self_intersection() { + let a = Aabb::new([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]); + assert!(a.intersects(&a)); + } + + #[test] + fn test_zero_volume_aabb_intersects() { + let a = Aabb::new([1.0, 1.0, 1.0], [1.0, 1.0, 1.0]); // point + let b = Aabb::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]); + assert!(a.intersects(&b)); + assert!(b.intersects(&a)); + } +} diff --git a/src/hpc/arrow_bridge.rs b/src/hpc/arrow_bridge.rs index d3922537..363c4e55 100644 --- a/src/hpc/arrow_bridge.rs +++ 
b/src/hpc/arrow_bridge.rs @@ -598,6 +598,209 @@ impl BindNodeV2 { } } +// ============================================================================ +// Per-Row Types: ThreePlaneRowBuffer, SoakingRowBuffer, BindNodeV2Row +// ============================================================================ + +/// Three-plane fingerprint row buffer: holds S/P/O binary fingerprints for +/// a single row, suitable for zero-copy Arrow interop. +/// +/// Total: 3 x 2048 = 6144 bytes per row. +#[derive(Debug, Clone)] +pub struct ThreePlaneRowBuffer { + /// Subject binary fingerprint (2048 bytes). + pub s_binary: Vec, + /// Predicate binary fingerprint (2048 bytes). + pub p_binary: Vec, + /// Object binary fingerprint (2048 bytes). + pub o_binary: Vec, +} + +impl ThreePlaneRowBuffer { + /// Create a zeroed three-plane row buffer. + pub fn new() -> Self { + Self { + s_binary: vec![0u8; PLANE_BINARY_BYTES], + p_binary: vec![0u8; PLANE_BINARY_BYTES], + o_binary: vec![0u8; PLANE_BINARY_BYTES], + } + } + + /// Create from three Plane references (copies their cached bit patterns). + pub fn from_planes(s: &mut Plane, p: &mut Plane, o: &mut Plane) -> Self { + s.ensure_cache(); + p.ensure_cache(); + o.ensure_cache(); + Self { + s_binary: s.bits_bytes_ref().to_vec(), + p_binary: p.bits_bytes_ref().to_vec(), + o_binary: o.bits_bytes_ref().to_vec(), + } + } + + /// Compute S XOR P XOR O composite fingerprint. + pub fn xor_spo(&self) -> Vec { + let mut result = vec![0u8; PLANE_BINARY_BYTES]; + for i in 0..PLANE_BINARY_BYTES { + result[i] = self.s_binary[i] ^ self.p_binary[i] ^ self.o_binary[i]; + } + result + } + + /// Per-plane Hamming distance to another row buffer. + /// + /// Returns `(subject_dist, predicate_dist, object_dist)`. + pub fn hamming_distance(&self, other: &ThreePlaneRowBuffer) -> (u64, u64, u64) { + let ds = hamming_distance_raw(&self.s_binary, &other.s_binary); + let dp = hamming_distance_raw(&self.p_binary, &other.p_binary); + let do_ = hamming_distance_raw(&self.o_binary, &other.o_binary); + (ds, dp, do_) + } + + /// Total byte size of this row buffer (always 3 * PLANE_BINARY_BYTES). + pub fn total_bytes(&self) -> usize { + 3 * PLANE_BINARY_BYTES + } +} + +impl Default for ThreePlaneRowBuffer { + fn default() -> Self { + Self::new() + } +} + +/// Soaking row buffer: nullable i8 accumulator for a single plane of a single row. +/// +/// When `data` is `Some`, the buffer is active (Form state). +/// When `data` is `None`, the buffer has been crystallized or nulled. +#[derive(Debug, Clone)] +pub struct SoakingRowBuffer { + /// Soaking data (None = nulled/crystallized). + pub data: Option>, + /// Dimension count. + pub dims: usize, +} + +impl SoakingRowBuffer { + /// Create a new active soaking row buffer, zeroed. + pub fn new(dims: usize) -> Self { + Self { + data: Some(vec![0i8; dims]), + dims, + } + } + + /// Crystallize: convert soaking (int8) to binary fingerprint via sign(). + /// + /// Consumes the soaking data and returns a binary vector. + /// After crystallization, the buffer is nulled. + pub fn crystallize(&mut self) -> Vec { + let soaking = match self.data.take() { + Some(d) => d, + None => return vec![0u8; (self.dims + 7) / 8], + }; + let n_bytes = (soaking.len() + 7) / 8; + let mut bits = vec![0u8; n_bytes]; + for (i, &val) in soaking.iter().enumerate() { + if val > 0 { + bits[i / 8] |= 1 << (i % 8); + } + } + bits + } + + /// Returns `true` when soaking is still active (not nulled/crystallized). 
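+    ///
+    /// Illustrative lifecycle sketch (marked `ignore`; assumes the row types
+    /// are reachable through the crate's `hpc::arrow_bridge` module):
+    ///
+    /// ```ignore
+    /// let mut buf = SoakingRowBuffer::new(16);
+    /// assert!(buf.is_active());        // fresh buffer holds Some(vec![0i8; 16])
+    /// let _bits = buf.crystallize();   // sign-fold consumes the soaking data
+    /// assert!(!buf.is_active());       // buffer is nulled afterwards
+    /// ```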
+ pub fn is_active(&self) -> bool { + self.data.is_some() + } + + /// Null out the soaking data (transition to inactive). + pub fn null_out(&mut self) { + self.data = None; + } +} + +/// Complete bind_nodes_v2 row type combining fingerprints, soaking, gate, +/// and NARS truth values. +/// +/// This is a streamlined per-row type that pairs `ThreePlaneRowBuffer` with +/// per-plane `SoakingRowBuffer`s for the full Lance schema. +#[derive(Debug, Clone)] +pub struct BindNodeV2Row { + /// Three-plane binary fingerprints. + pub fingerprints: ThreePlaneRowBuffer, + /// Subject soaking accumulator. + pub s_soaking: SoakingRowBuffer, + /// Predicate soaking accumulator. + pub p_soaking: SoakingRowBuffer, + /// Object soaking accumulator. + pub o_soaking: SoakingRowBuffer, + /// Gate lifecycle state. + pub gate: GateState, + /// NARS frequency (u16 fixed-point). + pub nars_frequency: u16, + /// NARS confidence (u16 fixed-point). + pub nars_confidence: u16, +} + +impl BindNodeV2Row { + /// Create a new row in Form state with active soaking. + pub fn new(dims: usize) -> Self { + Self { + fingerprints: ThreePlaneRowBuffer::new(), + s_soaking: SoakingRowBuffer::new(dims), + p_soaking: SoakingRowBuffer::new(dims), + o_soaking: SoakingRowBuffer::new(dims), + gate: GateState::Form, + nars_frequency: 32768, + nars_confidence: 0, + } + } + + /// Crystallize all three soaking buffers, folding sign bits into fingerprints. + /// Transitions from Form to Flow. + pub fn crystallize(&mut self) { + if self.gate != GateState::Form { + return; + } + // Fold each soaking into its binary plane + if let Some(ref soaking) = self.s_soaking.data { + fold_sign_into_binary(&mut self.fingerprints.s_binary, soaking); + } + if let Some(ref soaking) = self.p_soaking.data { + fold_sign_into_binary(&mut self.fingerprints.p_binary, soaking); + } + if let Some(ref soaking) = self.o_soaking.data { + fold_sign_into_binary(&mut self.fingerprints.o_binary, soaking); + } + self.s_soaking.null_out(); + self.p_soaking.null_out(); + self.o_soaking.null_out(); + self.gate = GateState::Flow; + } + + /// Freeze: transition from Flow to Freeze. + pub fn freeze(&mut self) { + if self.gate == GateState::Flow { + self.gate = GateState::Freeze; + } + } +} + +/// Fold soaking sign bits into a binary fingerprint (shared helper). 
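+///
+/// For example, a soaking slice of `[3, -2, 0, 7]` sets bits `1, 0, 0, 1` in the
+/// low nibble of `binary[0]`: only strictly positive values set a bit, while zero
+/// and negative values clear it.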
+fn fold_sign_into_binary(binary: &mut [u8], soaking: &[i8]) { + let bit_count = (binary.len() * 8).min(soaking.len()); + for i in 0..bit_count { + let byte_idx = i / 8; + let bit_idx = i % 8; + if soaking[i] > 0 { + binary[byte_idx] |= 1 << bit_idx; + } else { + binary[byte_idx] &= !(1 << bit_idx); + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -982,4 +1185,164 @@ mod tests { assert_eq!(slice.len(), 3 * BINARY_BYTES); assert_eq!(slice[BINARY_BYTES], 0xAB); } + + // ================================================================ + // ThreePlaneRowBuffer tests + // ================================================================ + + #[test] + fn three_plane_row_buffer_new() { + let buf = ThreePlaneRowBuffer::new(); + assert_eq!(buf.s_binary.len(), PLANE_BINARY_BYTES); + assert_eq!(buf.p_binary.len(), PLANE_BINARY_BYTES); + assert_eq!(buf.o_binary.len(), PLANE_BINARY_BYTES); + assert_eq!(buf.total_bytes(), 3 * PLANE_BINARY_BYTES); + } + + #[test] + fn three_plane_row_buffer_default() { + let buf = ThreePlaneRowBuffer::default(); + assert_eq!(buf.total_bytes(), 6144); + } + + #[test] + fn three_plane_row_buffer_from_planes() { + let (mut s, mut p, mut o) = make_test_planes(); + let buf = ThreePlaneRowBuffer::from_planes(&mut s, &mut p, &mut o); + assert_eq!(buf.s_binary.len(), PLANE_BINARY_BYTES); + // Non-trivial planes should produce non-zero binaries + assert!(buf.s_binary.iter().any(|&b| b != 0)); + } + + #[test] + fn three_plane_row_buffer_xor_spo() { + let mut buf = ThreePlaneRowBuffer::new(); + buf.s_binary[0] = 0xFF; + buf.p_binary[0] = 0x0F; + buf.o_binary[0] = 0xAA; + let xor = buf.xor_spo(); + assert_eq!(xor[0], 0xFF ^ 0x0F ^ 0xAA); + assert_eq!(xor[1], 0); // rest is zero + } + + #[test] + fn three_plane_row_buffer_hamming_self_zero() { + let (mut s, mut p, mut o) = make_test_planes(); + let buf = ThreePlaneRowBuffer::from_planes(&mut s, &mut p, &mut o); + let (ds, dp, do_) = buf.hamming_distance(&buf); + assert_eq!(ds, 0); + assert_eq!(dp, 0); + assert_eq!(do_, 0); + } + + #[test] + fn three_plane_row_buffer_hamming_different() { + let mut buf1 = ThreePlaneRowBuffer::new(); + let mut buf2 = ThreePlaneRowBuffer::new(); + buf1.s_binary.fill(0xFF); + buf2.s_binary.fill(0x00); + let (ds, dp, _) = buf1.hamming_distance(&buf2); + assert_eq!(ds, PLANE_BINARY_BYTES as u64 * 8); + assert_eq!(dp, 0); // both zero + } + + // ================================================================ + // SoakingRowBuffer tests + // ================================================================ + + #[test] + fn soaking_row_buffer_new() { + let buf = SoakingRowBuffer::new(100); + assert!(buf.is_active()); + assert_eq!(buf.dims, 100); + assert_eq!(buf.data.as_ref().unwrap().len(), 100); + } + + #[test] + fn soaking_row_buffer_crystallize() { + let mut buf = SoakingRowBuffer::new(8); + buf.data.as_mut().unwrap().copy_from_slice(&[1, -1, 1, -1, 1, -1, 1, -1]); + let bits = buf.crystallize(); + assert_eq!(bits[0], 0b01010101); + assert!(!buf.is_active()); // should be nulled after crystallize + } + + #[test] + fn soaking_row_buffer_crystallize_inactive() { + let mut buf = SoakingRowBuffer::new(16); + buf.null_out(); + assert!(!buf.is_active()); + let bits = buf.crystallize(); + // Should return zeroed bits when inactive + assert!(bits.iter().all(|&b| b == 0)); + } + + #[test] + fn soaking_row_buffer_null_out() { + let mut buf = SoakingRowBuffer::new(10); + assert!(buf.is_active()); + buf.null_out(); + assert!(!buf.is_active()); + } + + // 
================================================================ + // BindNodeV2Row tests + // ================================================================ + + #[test] + fn bind_node_v2_row_new() { + let row = BindNodeV2Row::new(100); + assert_eq!(row.gate, GateState::Form); + assert!(row.s_soaking.is_active()); + assert!(row.p_soaking.is_active()); + assert!(row.o_soaking.is_active()); + assert_eq!(row.nars_frequency, 32768); + assert_eq!(row.nars_confidence, 0); + assert_eq!(row.fingerprints.total_bytes(), 6144); + } + + #[test] + fn bind_node_v2_row_crystallize() { + let mut row = BindNodeV2Row::new(16); + // Put some data in soaking + row.s_soaking.data.as_mut().unwrap().fill(1); + row.crystallize(); + assert_eq!(row.gate, GateState::Flow); + assert!(!row.s_soaking.is_active()); + assert!(!row.p_soaking.is_active()); + assert!(!row.o_soaking.is_active()); + } + + #[test] + fn bind_node_v2_row_lifecycle() { + let mut row = BindNodeV2Row::new(16); + assert_eq!(row.gate, GateState::Form); + + // Cannot freeze from Form + row.freeze(); + assert_eq!(row.gate, GateState::Form); + + // Crystallize: Form -> Flow + row.crystallize(); + assert_eq!(row.gate, GateState::Flow); + + // Double crystallize is no-op + row.crystallize(); + assert_eq!(row.gate, GateState::Flow); + + // Freeze: Flow -> Freeze + row.freeze(); + assert_eq!(row.gate, GateState::Freeze); + } + + #[test] + fn bind_node_v2_row_crystallize_folds_sign() { + let mut row = BindNodeV2Row::new(16); + // Set subject soaking to all positive + row.s_soaking.data.as_mut().unwrap().fill(5); + row.crystallize(); + // First 2 bytes of s_binary should have bits set (16 bits = 2 bytes) + assert_eq!(row.fingerprints.s_binary[0], 0xFF); + assert_eq!(row.fingerprints.s_binary[1], 0xFF); + } } diff --git a/src/hpc/bitwise.rs b/src/hpc/bitwise.rs index 21eeae12..63faa386 100644 --- a/src/hpc/bitwise.rs +++ b/src/hpc/bitwise.rs @@ -306,6 +306,29 @@ fn dispatch_hamming_batch(query: &[u8], database: &[u8], num_rows: usize, row_by .collect() } +/// Count set bits across an array of u64 words. +/// More efficient than reinterpreting as bytes — works on native u64s directly. +pub fn popcount_batch_u64(words: &[u64]) -> u64 { + // Use POPCNT instruction if available, else scalar + words.iter().map(|w| w.count_ones() as u64).sum() +} + +/// Per-word popcount: returns count of set bits in each u64. +pub fn popcount_per_word(words: &[u64]) -> Vec { + words.iter().map(|w| w.count_ones()).collect() +} + +/// Batch AND + popcount: for each word, compute (word & mask).count_ones(). +/// Used for "count blocks matching a property mask in each palette group." +pub fn masked_popcount_batch(words: &[u64], mask: u64) -> Vec { + words.iter().map(|w| (w & mask).count_ones()).collect() +} + +/// Total masked popcount across all words. 
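+///
+/// # Examples
+///
+/// Sketch of the expected arithmetic (assuming `bitwise` is public under
+/// `ndarray::hpc`, which this diff does not show):
+///
+/// ```ignore
+/// use ndarray::hpc::bitwise::masked_popcount_total;
+///
+/// // 0xFF & 0xFF contributes 8 set bits; 0xFF00 & 0xFF contributes none.
+/// assert_eq!(masked_popcount_total(&[0xFF, 0xFF00], 0xFF), 8);
+/// ```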
+pub fn masked_popcount_total(words: &[u64], mask: u64) -> u64 { + words.iter().map(|w| (w & mask).count_ones() as u64).sum() +} + impl BitwiseOps for ArrayBase where S: Data { @@ -638,6 +661,27 @@ mod tests { } } + #[test] + fn test_popcount_batch_u64() { + let words = [0xFFFFFFFFFFFFFFFFu64, 0, 0x0F0F0F0F0F0F0F0F]; + assert_eq!(super::popcount_batch_u64(&words), 64 + 0 + 32); + } + + #[test] + fn test_popcount_per_word() { + let words = [0xFFu64, 0xFFFF, 0]; + let counts = super::popcount_per_word(&words); + assert_eq!(counts, vec![8, 16, 0]); + } + + #[test] + fn test_masked_popcount() { + let words = [0xFFu64, 0xFF00, 0xFFFF]; + let mask = 0xFF; + assert_eq!(super::masked_popcount_batch(&words, mask), vec![8, 0, 8]); + assert_eq!(super::masked_popcount_total(&words, mask), 16); + } + /// Edge: identical vectors → distance 0 at all tiers. #[cfg(target_arch = "x86_64")] #[test] diff --git a/src/hpc/byte_scan.rs b/src/hpc/byte_scan.rs new file mode 100644 index 00000000..f8d27b07 --- /dev/null +++ b/src/hpc/byte_scan.rs @@ -0,0 +1,224 @@ +//! Byte pattern scanning for NBT tag detection. +//! +//! SIMD-accelerated search for byte values and short patterns in contiguous +//! buffers. All functions operate on borrowed `&[u8]` slices with zero copies. +//! Scalar fallback is provided for non-x86 targets. + +// --------------------------------------------------------------------------- +// SIMD (x86_64 SSE2 / AVX2) internals +// --------------------------------------------------------------------------- + +#[cfg(target_arch = "x86_64")] +mod simd_impl { + use core::arch::x86_64::*; + + /// Find all positions of `needle` in `haystack` using AVX2 (32 bytes/iter). + /// + /// # Safety + /// Caller must ensure AVX2 is available. + #[target_feature(enable = "avx2")] + pub(super) unsafe fn byte_find_all_avx2(haystack: &[u8], needle: u8) -> Vec { + let mut result = Vec::new(); + let n = haystack.len(); + let ptr = haystack.as_ptr(); + let needle_v = _mm256_set1_epi8(needle as i8); + + let mut i = 0usize; + while i + 32 <= n { + let data = _mm256_loadu_si256(ptr.add(i) as *const __m256i); + let cmp = _mm256_cmpeq_epi8(data, needle_v); + let mut mask = _mm256_movemask_epi8(cmp) as u32; + while mask != 0 { + let bit = mask.trailing_zeros() as usize; + result.push(i + bit); + mask &= mask - 1; // clear lowest set bit + } + i += 32; + } + // Scalar tail + for j in i..n { + if *ptr.add(j) == needle { + result.push(j); + } + } + result + } + + /// Count occurrences of `needle` using AVX2. + /// + /// # Safety + /// Caller must ensure AVX2 is available. + #[target_feature(enable = "avx2")] + pub(super) unsafe fn byte_count_avx2(haystack: &[u8], needle: u8) -> usize { + let n = haystack.len(); + let ptr = haystack.as_ptr(); + let needle_v = _mm256_set1_epi8(needle as i8); + let mut total = 0usize; + + let mut i = 0usize; + while i + 32 <= n { + let data = _mm256_loadu_si256(ptr.add(i) as *const __m256i); + let cmp = _mm256_cmpeq_epi8(data, needle_v); + let mask = _mm256_movemask_epi8(cmp) as u32; + total += mask.count_ones() as usize; + i += 32; + } + for j in i..n { + if *ptr.add(j) == needle { + total += 1; + } + } + total + } +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Find all occurrences of a byte value. Returns indices. 
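+///
+/// # Examples
+///
+/// Minimal sketch (module path from the `pub mod byte_scan` addition in
+/// `src/hpc/mod.rs` later in this patch):
+///
+/// ```
+/// use ndarray::hpc::byte_scan::byte_find_all;
+///
+/// let buf = [0x0A, 0x00, 0x0A, 0xFF, 0x0A];
+/// assert_eq!(byte_find_all(&buf, 0x0A), vec![0, 2, 4]);
+/// ```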
+pub fn byte_find_all(haystack: &[u8], needle: u8) -> Vec { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") { + // SAFETY: feature detected above. + return unsafe { simd_impl::byte_find_all_avx2(haystack, needle) }; + } + } + // Scalar fallback + haystack + .iter() + .enumerate() + .filter_map(|(i, &b)| if b == needle { Some(i) } else { None }) + .collect() +} + +/// Find all occurrences of a 2-byte pattern (big-endian u16). Returns indices +/// of the first byte of each match. +pub fn u16_find_all(haystack: &[u8], pattern: u16) -> Vec { + let hi = (pattern >> 8) as u8; + let lo = (pattern & 0xFF) as u8; + if haystack.len() < 2 { + return Vec::new(); + } + let mut result = Vec::new(); + for i in 0..haystack.len() - 1 { + if haystack[i] == hi && haystack[i + 1] == lo { + result.push(i); + } + } + result +} + +/// Count occurrences of a byte value. +pub fn byte_count(haystack: &[u8], needle: u8) -> usize { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") { + // SAFETY: feature detected above. + return unsafe { simd_impl::byte_count_avx2(haystack, needle) }; + } + } + // Scalar fallback + haystack.iter().filter(|&&b| b == needle).count() +} + +/// Find first occurrence of a byte value. Returns index or `None`. +pub fn byte_find_first(haystack: &[u8], needle: u8) -> Option { + // memchr-style: the compiler will auto-vectorise this well, + // but we also have a fast-path via the find-all SIMD path. + haystack.iter().position(|&b| b == needle) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn naive_byte_find_all(haystack: &[u8], needle: u8) -> Vec { + haystack + .iter() + .enumerate() + .filter_map(|(i, &b)| if b == needle { Some(i) } else { None }) + .collect() + } + + fn naive_byte_count(haystack: &[u8], needle: u8) -> usize { + haystack.iter().filter(|&&b| b == needle).count() + } + + #[test] + fn test_byte_find_all_matches_naive() { + // Use a buffer that exercises both SIMD and scalar tail. 
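+        // 200 bytes = six full 32-byte AVX2 blocks plus an 8-byte scalar tail.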
+ let buf: Vec = (0..200).map(|i| (i % 7) as u8).collect(); + for needle in 0..7u8 { + assert_eq!( + byte_find_all(&buf, needle), + naive_byte_find_all(&buf, needle), + "mismatch for needle {needle}" + ); + } + } + + #[test] + fn test_byte_count_matches_naive() { + let buf: Vec = (0..200).map(|i| (i % 7) as u8).collect(); + for needle in 0..7u8 { + assert_eq!( + byte_count(&buf, needle), + naive_byte_count(&buf, needle), + "mismatch for needle {needle}" + ); + } + } + + #[test] + fn test_u16_find_all() { + let buf = [0x00, 0x0A, 0x0B, 0x0A, 0x0B, 0xFF]; + let result = u16_find_all(&buf, 0x0A0B); + assert_eq!(result, vec![1, 3]); + } + + #[test] + fn test_u16_find_all_at_boundary() { + let buf = [0xAB, 0xCD]; + assert_eq!(u16_find_all(&buf, 0xABCD), vec![0]); + } + + #[test] + fn test_byte_find_first_found() { + let buf = [1, 2, 3, 4, 5]; + assert_eq!(byte_find_first(&buf, 3), Some(2)); + } + + #[test] + fn test_byte_find_first_not_found() { + let buf = [1, 2, 3, 4, 5]; + assert_eq!(byte_find_first(&buf, 99), None); + } + + #[test] + fn test_empty_haystack() { + let empty: &[u8] = &[]; + assert!(byte_find_all(empty, 0).is_empty()); + assert_eq!(byte_count(empty, 0), 0); + assert_eq!(byte_find_first(empty, 0), None); + assert!(u16_find_all(empty, 0x0000).is_empty()); + } + + #[test] + fn test_single_byte_haystack() { + assert_eq!(byte_find_all(&[42], 42), vec![0]); + assert_eq!(byte_find_all(&[42], 0), Vec::::new()); + assert!(u16_find_all(&[42], 0x2A00).is_empty()); + } + + #[test] + fn test_u16_not_found() { + let buf = [0x00, 0x01, 0x02, 0x03]; + assert!(u16_find_all(&buf, 0xFFFF).is_empty()); + } +} diff --git a/src/hpc/clam.rs b/src/hpc/clam.rs index 7cb20e00..b5adb2b0 100644 --- a/src/hpc/clam.rs +++ b/src/hpc/clam.rs @@ -1248,6 +1248,100 @@ pub fn compress(data: &[u8], vec_len: usize, tree: &ClamTree) -> CompressedTree use super::cascade::{Band, Cascade, RankedHit}; +impl ClamTree { + /// rho-nearest-neighbour via CLAM tree -> feed into Cascade verification. + /// + /// Uses triangle inequality pruning to find candidates within distance rho, + /// then returns candidate indices and their Hamming distances from query, + /// sorted by distance ascending. + /// + /// Unlike the standalone `rho_nn` function, this method directly returns + /// a simple `Vec<(usize, u64)>` suitable for piping into cascade verification + /// via `clam_cascade_search`. 
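+    ///
+    /// Illustrative sketch, mirroring the unit test below (marked `ignore`
+    /// because the `clam` module's public path is not shown here):
+    ///
+    /// ```ignore
+    /// let vec_len = 32;
+    /// let data = make_test_data(50, vec_len);            // test helper
+    /// let tree = ClamTree::build(&data, vec_len, 3);
+    ///
+    /// // Candidates within Hamming distance 16 of the first vector,
+    /// // returned sorted by distance ascending.
+    /// let hits = tree.rho_nn_candidates(&data, vec_len, &data[..vec_len], 16);
+    /// assert!(hits.windows(2).all(|w| w[0].1 <= w[1].1));
+    /// ```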
+ pub fn rho_nn_candidates( + &self, + data: &[u8], + vec_len: usize, + query: &[u8], + rho: u64, + ) -> Vec<(usize, u64)> { + if self.nodes.is_empty() { + return Vec::new(); + } + + let mut candidates = Vec::new(); + let mut stack = vec![0usize]; // start at root + + while let Some(node_idx) = stack.pop() { + let node = &self.nodes[node_idx]; + let center = &data[self.reordered[node.center_idx] * vec_len..][..vec_len]; + let dist_to_center = (self.distance_fn)(query, center); + + // Triangle inequality: closest possible point in cluster + if node.delta_minus(dist_to_center) > rho { + continue; // entire cluster is too far + } + + if node.is_leaf() { + // Check all points in this leaf + for i in node.offset..node.offset + node.cardinality { + let idx = self.reordered[i]; + let point = &data[idx * vec_len..][..vec_len]; + let d = (self.distance_fn)(query, point); + if d <= rho { + candidates.push((idx, d)); + } + } + } else { + if let Some(left) = node.left { + stack.push(left); + } + if let Some(right) = node.right { + stack.push(right); + } + } + } + + candidates.sort_unstable_by_key(|&(_, d)| d); + candidates + } +} + +/// Bridge: CLAM tree pruning -> Cascade verification. +/// +/// Phase 1 of the CLAM+QualiaCAM pipeline: +/// 1. CLAM finds candidates within rho using triangle inequality +/// 2. Cascade verifies and classifies candidates into quality bands +/// +/// Returns only non-Reject hits, limited to `top_k`, sorted by Hamming distance. +pub fn clam_cascade_search( + tree: &ClamTree, + cascade: &Cascade, + data: &[u8], + vec_len: usize, + query: &[u8], + rho: u64, + top_k: usize, +) -> Vec { + let candidates = tree.rho_nn_candidates(data, vec_len, query, rho); + + // Convert to RankedHit with band classification + let mut hits: Vec = candidates + .into_iter() + .map(|(index, hamming)| RankedHit { + index, + hamming, + precise: 0.0, + band: cascade.expose(hamming as u32), + }) + .collect(); + + // Keep only non-Reject hits, limit to top_k + hits.retain(|h| h.band != Band::Reject); + hits.truncate(top_k); + hits +} + /// Result of CLAM→Cascade bridged search, combining CLAM's tight candidates /// with cascade verification and banding. 
#[derive(Debug, Clone)] @@ -2590,4 +2684,85 @@ mod tests { assert!(s.score >= 0.0 && s.score <= 1.0); } } + + // ── rho_nn_candidates tests ────────────────────────────────── + + #[test] + fn test_rho_nn_candidates() { + let vec_len = 32; + let n = 50; + let mut data = vec![0u8; n * vec_len]; + // Make distinct vectors + for i in 0..n { + data[i * vec_len + (i % vec_len)] = 0xFF; + } + let tree = ClamTree::build(&data, vec_len, 3); + let query = &data[0..vec_len]; + let candidates = tree.rho_nn_candidates(&data, vec_len, query, 16); + // Should find at least the exact match + assert!(candidates.iter().any(|&(_, d)| d == 0)); + // Results should be sorted by distance ascending + for w in candidates.windows(2) { + assert!(w[0].1 <= w[1].1); + } + } + + #[test] + fn test_rho_nn_candidates_empty_tree() { + let tree = ClamTree::build(&[], 32, 3); + let query = vec![0u8; 32]; + let candidates = tree.rho_nn_candidates(&[], 32, &query, 100); + assert!(candidates.is_empty()); + } + + #[test] + fn test_rho_nn_candidates_tight_rho() { + let vec_len = 32; + let n = 50; + let data = make_test_data(n, vec_len); + let tree = ClamTree::build(&data, vec_len, 3); + let query = &data[0..vec_len]; + // rho=0 should only find exact matches + let candidates = tree.rho_nn_candidates(&data, vec_len, query, 0); + for &(_, d) in &candidates { + assert_eq!(d, 0); + } + } + + // ── clam_cascade_search tests ──────────────────────────────── + + #[test] + fn test_clam_cascade_search() { + let vec_len = 32; + let n = 50; + let mut data = vec![0u8; n * vec_len]; + for i in 0..n { + data[i * vec_len + (i % vec_len)] = 0xFF; + } + let tree = ClamTree::build(&data, vec_len, 3); + let cascade = Cascade::from_threshold(vec_len as u64 * 4, vec_len); + let query = &data[0..vec_len]; + let hits = clam_cascade_search( + &tree, &cascade, &data, vec_len, query, vec_len as u64 * 8, 10, + ); + assert!(!hits.is_empty()); + // No Reject band hits should survive + for h in &hits { + assert_ne!(h.band, Band::Reject); + } + } + + #[test] + fn test_clam_cascade_search_respects_top_k() { + let vec_len = 32; + let n = 100; + let data = make_test_data(n, vec_len); + let tree = ClamTree::build(&data, vec_len, 3); + let cascade = Cascade::from_threshold(vec_len as u64 * 8, vec_len); + let query = &data[0..vec_len]; + let hits = clam_cascade_search( + &tree, &cascade, &data, vec_len, query, u64::MAX, 5, + ); + assert!(hits.len() <= 5); + } } diff --git a/src/hpc/crystal_encoder.rs b/src/hpc/crystal_encoder.rs index d0f711fb..f6979370 100644 --- a/src/hpc/crystal_encoder.rs +++ b/src/hpc/crystal_encoder.rs @@ -193,6 +193,128 @@ impl CrystalEncoder { fn flip_weight(&mut self, index: usize) { self.projection[index] = -self.projection[index]; } + + /// Encode an embedding, absorb it into a node under the given role, and + /// return both the fingerprint and the mutated node. + /// + /// Convenience wrapper that chains `encode_embedding` and `absorb_into_node`. 
+ /// + /// # Example + /// + /// ``` + /// use ndarray::hpc::crystal_encoder::{CrystalEncoder, Role}; + /// use ndarray::hpc::node::Node; + /// + /// let enc = CrystalEncoder::new(4, 42); + /// let mut node = Node::new(); + /// let fp = enc.encode_and_absorb(&[1.0, -0.5, 0.3, 0.8], &mut node, Role::Subject); + /// assert!(node.s.encounters() > 0); + /// assert!(!fp.is_zero()); + /// ``` + pub fn encode_and_absorb( + &self, + embedding: &[f32], + node: &mut Node, + role: Role, + ) -> Fingerprint { + let fp = self.encode_embedding(embedding); + Self::absorb_into_node(&fp, node, role); + fp + } + + /// Search a database of nodes for the most similar to a query node. + /// + /// Computes SPO Hamming distance between the query node and each database + /// node, returning `(index, distance)` pairs sorted by distance ascending, + /// limited to `top_k` results. + /// + /// # Example + /// + /// ``` + /// use ndarray::hpc::crystal_encoder::{CrystalEncoder, Role}; + /// use ndarray::hpc::node::Node; + /// + /// let enc = CrystalEncoder::new(4, 42); + /// let mut query = Node::new(); + /// enc.encode_and_absorb(&[1.0, 0.0, 0.0, 0.0], &mut query, Role::Subject); + /// enc.encode_and_absorb(&[0.0, 1.0, 0.0, 0.0], &mut query, Role::Predicate); + /// enc.encode_and_absorb(&[0.0, 0.0, 1.0, 0.0], &mut query, Role::Object); + /// + /// let mut db = vec![Node::random(1), Node::random(2), query.clone()]; + /// let results = CrystalEncoder::search_similar(&mut query, &mut db, 2); + /// assert!(!results.is_empty()); + /// assert_eq!(results[0].1, 0); // exact match should have distance 0 + /// ``` + pub fn search_similar( + query: &mut Node, + database: &mut [Node], + top_k: usize, + ) -> Vec<(usize, u32)> { + let mut results: Vec<(usize, u32)> = database + .iter_mut() + .enumerate() + .map(|(i, db_node)| { + let d = spo_hamming(query, db_node); + (i, d) + }) + .collect(); + + results.sort_unstable_by_key(|&(_, d)| d); + results.truncate(top_k); + results + } +} + +/// Full pipeline: encode three embeddings (S/P/O) into a node, then search +/// a database for the closest matches. +/// +/// This wires the complete flow: projection -> node absorption -> search. +/// +/// # Arguments +/// * `encoder` — the CrystalEncoder with a loaded projection matrix. +/// * `subject_emb` — dense float embedding for the Subject plane. +/// * `predicate_emb` — dense float embedding for the Predicate plane. +/// * `object_emb` — dense float embedding for the Object plane. +/// * `database` — mutable slice of database nodes to search against. +/// * `top_k` — maximum number of results to return. +/// +/// # Returns +/// A tuple of `(query_node, results)` where results are `(index, distance)` pairs. 
+/// +/// # Example +/// +/// ``` +/// use ndarray::hpc::crystal_encoder::{CrystalEncoder, pipeline_encode_search}; +/// use ndarray::hpc::node::Node; +/// +/// let enc = CrystalEncoder::new(4, 42); +/// let mut db = vec![Node::random(1), Node::random(2)]; +/// let (query, results) = pipeline_encode_search( +/// &enc, +/// &[1.0, 0.0, 0.0, 0.0], +/// &[0.0, 1.0, 0.0, 0.0], +/// &[0.0, 0.0, 1.0, 0.0], +/// &mut db, +/// 5, +/// ); +/// assert!(query.s.encounters() > 0); +/// assert!(!results.is_empty()); +/// ``` +pub fn pipeline_encode_search( + encoder: &CrystalEncoder, + subject_emb: &[f32], + predicate_emb: &[f32], + object_emb: &[f32], + database: &mut [Node], + top_k: usize, +) -> (Node, Vec<(usize, u32)>) { + let mut query_node = Node::new(); + encoder.encode_and_absorb(subject_emb, &mut query_node, Role::Subject); + encoder.encode_and_absorb(predicate_emb, &mut query_node, Role::Predicate); + encoder.encode_and_absorb(object_emb, &mut query_node, Role::Object); + + let results = CrystalEncoder::search_similar(&mut query_node, database, top_k); + (query_node, results) } // ============================================================================ @@ -845,6 +967,85 @@ mod tests { // -- Integration test: full pipeline ------------------------------------ + // -- Pipeline wiring tests ----------------------------------------------- + + #[test] + fn encode_and_absorb_returns_fingerprint() { + let enc = CrystalEncoder::new(4, 42); + let mut node = Node::new(); + let fp = enc.encode_and_absorb(&[1.0, -0.5, 0.3, 0.8], &mut node, Role::Subject); + assert!(!fp.is_zero()); + assert_eq!(node.s.encounters(), 1); + assert_eq!(node.p.encounters(), 0); + } + + #[test] + fn search_similar_finds_exact_match() { + let enc = CrystalEncoder::new(4, 42); + let emb = [1.0f32, -0.5, 0.3, 0.8]; + + let mut query = Node::new(); + enc.encode_and_absorb(&emb, &mut query, Role::Subject); + enc.encode_and_absorb(&emb, &mut query, Role::Predicate); + enc.encode_and_absorb(&emb, &mut query, Role::Object); + + // Clone query into database + let mut db = vec![Node::random(1), query.clone(), Node::random(2)]; + let results = CrystalEncoder::search_similar(&mut query, &mut db, 5); + assert!(!results.is_empty()); + // Best match should be the clone at index 1 with distance 0 + assert_eq!(results[0].0, 1); + assert_eq!(results[0].1, 0); + } + + #[test] + fn search_similar_respects_top_k() { + let mut query = Node::random(42); + let mut db: Vec = (0..10).map(|i| Node::random(i + 100)).collect(); + let results = CrystalEncoder::search_similar(&mut query, &mut db, 3); + assert!(results.len() <= 3); + } + + #[test] + fn pipeline_encode_search_basic() { + let enc = CrystalEncoder::new(4, 42); + let mut db = vec![Node::random(1), Node::random(2), Node::random(3)]; + let (query, results) = pipeline_encode_search( + &enc, + &[1.0, 0.0, 0.0, 0.0], + &[0.0, 1.0, 0.0, 0.0], + &[0.0, 0.0, 1.0, 0.0], + &mut db, + 5, + ); + assert_eq!(query.s.encounters(), 1); + assert_eq!(query.p.encounters(), 1); + assert_eq!(query.o.encounters(), 1); + assert!(!results.is_empty()); + // Results should be sorted by distance + for w in results.windows(2) { + assert!(w[0].1 <= w[1].1); + } + } + + #[test] + fn pipeline_encode_search_empty_db() { + let enc = CrystalEncoder::new(4, 42); + let mut db: Vec = vec![]; + let (query, results) = pipeline_encode_search( + &enc, + &[1.0, 0.0, 0.0, 0.0], + &[0.0, 1.0, 0.0, 0.0], + &[0.0, 0.0, 1.0, 0.0], + &mut db, + 5, + ); + assert_eq!(query.s.encounters(), 1); + assert!(results.is_empty()); + } + + // -- 
Phase 1 full integration test (original) -------------------------- + #[test] fn full_pipeline_encode_absorb_measure() { let enc = CrystalEncoder::new(16, 42); diff --git a/src/hpc/distance.rs b/src/hpc/distance.rs new file mode 100644 index 00000000..4060fa84 --- /dev/null +++ b/src/hpc/distance.rs @@ -0,0 +1,361 @@ +//! Batch distance computations for spatial queries. +//! +//! SIMD-accelerated squared-distance, radius filtering, and K-nearest-neighbor +//! searches over contiguous point slices. All operations work on borrowed slices +//! with no internal copies. Scalar fallback is provided for non-x86 targets. + +// --------------------------------------------------------------------------- +// Scalar helpers +// --------------------------------------------------------------------------- + +#[inline] +fn sq_dist_f32(a: [f32; 3], b: [f32; 3]) -> f32 { + let dx = a[0] - b[0]; + let dy = a[1] - b[1]; + let dz = a[2] - b[2]; + dx * dx + dy * dy + dz * dz +} + +#[inline] +fn sq_dist_f64(a: [f64; 3], b: [f64; 3]) -> f64 { + let dx = a[0] - b[0]; + let dy = a[1] - b[1]; + let dz = a[2] - b[2]; + dx * dx + dy * dy + dz * dz +} + +// --------------------------------------------------------------------------- +// SIMD (x86_64 AVX2) internals +// --------------------------------------------------------------------------- + +#[cfg(target_arch = "x86_64")] +mod simd_impl { + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; + + /// Compute squared distances for 8 points at a time using AVX2. + /// `query` components are broadcast; `points` is read in SOA-style chunks. + /// + /// # Safety + /// Caller must ensure AVX2 is available. + #[target_feature(enable = "avx2")] + pub(super) unsafe fn squared_distances_avx2( + query: [f32; 3], + points: &[[f32; 3]], + out: &mut Vec, + ) { + let n = points.len(); + out.clear(); + out.reserve(n); + + let qx = _mm256_set1_ps(query[0]); + let qy = _mm256_set1_ps(query[1]); + let qz = _mm256_set1_ps(query[2]); + + let ptr = points.as_ptr() as *const f32; + // Each point is 3 floats => stride 3 + let mut i = 0usize; + // Process 8 points at a time + while i + 8 <= n { + // Gather x, y, z for 8 points (scalar gather — AVX2 gather is slow + // on many microarchitectures for non-contiguous strides). + let mut xs = [0f32; 8]; + let mut ys = [0f32; 8]; + let mut zs = [0f32; 8]; + for j in 0..8 { + let base = (i + j) * 3; + xs[j] = *ptr.add(base); + ys[j] = *ptr.add(base + 1); + zs[j] = *ptr.add(base + 2); + } + + let vx = _mm256_loadu_ps(xs.as_ptr()); + let vy = _mm256_loadu_ps(ys.as_ptr()); + let vz = _mm256_loadu_ps(zs.as_ptr()); + + let dx = _mm256_sub_ps(qx, vx); + let dy = _mm256_sub_ps(qy, vy); + let dz = _mm256_sub_ps(qz, vz); + + // dx*dx + dy*dy + dz*dz (FMA where available) + let mut acc = _mm256_mul_ps(dx, dx); + acc = _mm256_add_ps(acc, _mm256_mul_ps(dy, dy)); + acc = _mm256_add_ps(acc, _mm256_mul_ps(dz, dz)); + + let mut tmp = [0f32; 8]; + _mm256_storeu_ps(tmp.as_mut_ptr(), acc); + out.extend_from_slice(&tmp); + + i += 8; + } + + // Scalar tail + for j in i..n { + let dx = query[0] - points[j][0]; + let dy = query[1] - points[j][1]; + let dz = query[2] - points[j][2]; + out.push(dx * dx + dy * dy + dz * dz); + } + } +} + +// --------------------------------------------------------------------------- +// Public API — f32 +// --------------------------------------------------------------------------- + +/// Squared distance from one point to N points (f32). +/// +/// Returns a `Vec` of length `points.len()`. 
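+///
+/// # Examples
+///
+/// Minimal sketch (module path from the `pub mod distance` addition in
+/// `src/hpc/mod.rs`):
+///
+/// ```
+/// use ndarray::hpc::distance::squared_distances_f32;
+///
+/// let points = [[1.0_f32, 0.0, 0.0], [0.0, 2.0, 0.0]];
+/// let d = squared_distances_f32([0.0, 0.0, 0.0], &points);
+/// assert!((d[0] - 1.0).abs() < 1e-6 && (d[1] - 4.0).abs() < 1e-6);
+/// ```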
+pub fn squared_distances_f32(query: [f32; 3], points: &[[f32; 3]]) -> Vec { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") { + let mut out = Vec::new(); + // SAFETY: feature detected above. + unsafe { simd_impl::squared_distances_avx2(query, points, &mut out) }; + return out; + } + } + // Scalar fallback + points.iter().map(|p| sq_dist_f32(query, *p)).collect() +} + +/// Filter points by max squared distance. Returns indices of survivors. +pub fn filter_by_radius_sq( + query: [f32; 3], + points: &[[f32; 3]], + radius_sq: f32, +) -> Vec { + let dists = squared_distances_f32(query, points); + dists + .iter() + .enumerate() + .filter_map(|(i, &d)| if d <= radius_sq { Some(i) } else { None }) + .collect() +} + +/// Find K nearest points (f32). Returns `(indices, squared_distances)` sorted +/// ascending by distance. +pub fn knn_f32( + query: [f32; 3], + points: &[[f32; 3]], + k: usize, +) -> (Vec, Vec) { + let dists = squared_distances_f32(query, points); + let mut indexed: Vec<(usize, f32)> = dists.into_iter().enumerate().collect(); + indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)); + let take = k.min(indexed.len()); + let indices: Vec = indexed[..take].iter().map(|&(i, _)| i).collect(); + let sq_dists: Vec = indexed[..take].iter().map(|&(_, d)| d).collect(); + (indices, sq_dists) +} + +// --------------------------------------------------------------------------- +// Public API — f64 +// --------------------------------------------------------------------------- + +/// Squared distance from one point to N points (f64). +/// +/// Uses scalar path (AVX2 f64 lanes are only 4-wide so the gain is marginal +/// for AOS-3 data). +pub fn squared_distances_f64(query: [f64; 3], points: &[[f64; 3]]) -> Vec { + points.iter().map(|p| sq_dist_f64(query, *p)).collect() +} + +/// Filter f64 points by squared-distance radius. Returns survivor indices. +pub fn filter_by_radius_sq_f64( + query: [f64; 3], + points: &[[f64; 3]], + radius_sq: f64, +) -> Vec { + let dists = squared_distances_f64(query, points); + dists + .iter() + .enumerate() + .filter_map(|(i, &d)| if d <= radius_sq { Some(i) } else { None }) + .collect() +} + +/// Find K nearest points (f64). Returns `(indices, squared_distances)` sorted +/// ascending by distance. 
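+///
+/// # Examples
+///
+/// Same call shape as `knn_f32`, sketched with exact f64 arithmetic:
+///
+/// ```
+/// use ndarray::hpc::distance::knn_f64;
+///
+/// let pts = [[3.0_f64, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
+/// let (idx, d2) = knn_f64([0.0, 0.0, 0.0], &pts, 2);
+/// assert_eq!(idx, vec![1, 2]);    // nearest two points
+/// assert_eq!(d2, vec![1.0, 4.0]); // squared distances, ascending
+/// ```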
+pub fn knn_f64( + query: [f64; 3], + points: &[[f64; 3]], + k: usize, +) -> (Vec, Vec) { + let dists = squared_distances_f64(query, points); + let mut indexed: Vec<(usize, f64)> = dists.into_iter().enumerate().collect(); + indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)); + let take = k.min(indexed.len()); + let indices: Vec = indexed[..take].iter().map(|&(i, _)| i).collect(); + let sq_dists: Vec = indexed[..take].iter().map(|&(_, d)| d).collect(); + (indices, sq_dists) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn approx_eq_f32(a: f32, b: f32) -> bool { + (a - b).abs() < 1e-5 + } + + fn approx_eq_f64(a: f64, b: f64) -> bool { + (a - b).abs() < 1e-10 + } + + // -- scalar parity -- + + #[test] + fn test_squared_distances_f32_matches_scalar() { + let query = [1.0f32, 2.0, 3.0]; + let points: Vec<[f32; 3]> = (0..33) + .map(|i| { + let v = i as f32; + [v, v + 1.0, v + 2.0] + }) + .collect(); + let result = squared_distances_f32(query, &points); + assert_eq!(result.len(), points.len()); + for (i, &d) in result.iter().enumerate() { + let expected = sq_dist_f32(query, points[i]); + assert!( + approx_eq_f32(d, expected), + "mismatch at {i}: {d} vs {expected}" + ); + } + } + + #[test] + fn test_squared_distances_f64_matches_scalar() { + let query = [1.0f64, 2.0, 3.0]; + let points: Vec<[f64; 3]> = (0..33) + .map(|i| { + let v = i as f64; + [v, v + 1.0, v + 2.0] + }) + .collect(); + let result = squared_distances_f64(query, &points); + for (i, &d) in result.iter().enumerate() { + let expected = sq_dist_f64(query, points[i]); + assert!( + approx_eq_f64(d, expected), + "mismatch at {i}: {d} vs {expected}" + ); + } + } + + // -- filter -- + + #[test] + fn test_filter_by_radius_sq() { + let query = [0.0f32, 0.0, 0.0]; + let points = vec![[1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.5, 0.0, 0.0]]; + let result = filter_by_radius_sq(query, &points, 1.0); + // Point 0: dist=1.0, pass; Point 1: dist=4.0, fail; Point 2: dist=0.25, pass + assert_eq!(result, vec![0, 2]); + } + + #[test] + fn test_filter_by_radius_sq_f64() { + let query = [0.0f64, 0.0, 0.0]; + let points = vec![[1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.5, 0.0, 0.0]]; + let result = filter_by_radius_sq_f64(query, &points, 1.0); + assert_eq!(result, vec![0, 2]); + } + + #[test] + fn test_filter_vs_brute_force_f32() { + let query = [5.0f32, 5.0, 5.0]; + let points: Vec<[f32; 3]> = (0..100) + .map(|i| { + let v = i as f32 * 0.3; + [v, v, v] + }) + .collect(); + let radius_sq = 10.0f32; + let result = filter_by_radius_sq(query, &points, radius_sq); + let brute: Vec = points + .iter() + .enumerate() + .filter(|(_, p)| sq_dist_f32(query, **p) <= radius_sq) + .map(|(i, _)| i) + .collect(); + assert_eq!(result, brute); + } + + // -- knn -- + + #[test] + fn test_knn_f32() { + let query = [0.0f32, 0.0, 0.0]; + let points = vec![ + [3.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [2.0, 0.0, 0.0], + [0.5, 0.0, 0.0], + ]; + let (idx, dist) = knn_f32(query, &points, 2); + assert_eq!(idx, vec![3, 1]); // 0.25, 1.0 + assert!(approx_eq_f32(dist[0], 0.25)); + assert!(approx_eq_f32(dist[1], 1.0)); + } + + #[test] + fn test_knn_f64() { + let query = [0.0f64, 0.0, 0.0]; + let points = vec![ + [3.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [2.0, 0.0, 0.0], + [0.5, 0.0, 0.0], + ]; + let (idx, dist) = knn_f64(query, &points, 2); + assert_eq!(idx, vec![3, 1]); + 
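+        // nearest two: 0.5 away (sq 0.25) and 1.0 away (sq 1.0)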
assert!(approx_eq_f64(dist[0], 0.25)); + assert!(approx_eq_f64(dist[1], 1.0)); + } + + #[test] + fn test_knn_k_larger_than_n() { + let query = [0.0f32, 0.0, 0.0]; + let points = vec![[1.0, 0.0, 0.0]]; + let (idx, dist) = knn_f32(query, &points, 10); + assert_eq!(idx.len(), 1); + assert_eq!(dist.len(), 1); + } + + // -- edge cases -- + + #[test] + fn test_empty_points() { + let query = [0.0f32, 0.0, 0.0]; + let empty: &[[f32; 3]] = &[]; + assert!(squared_distances_f32(query, empty).is_empty()); + assert!(filter_by_radius_sq(query, empty, 1.0).is_empty()); + let (idx, dist) = knn_f32(query, empty, 5); + assert!(idx.is_empty()); + assert!(dist.is_empty()); + } + + #[test] + fn test_single_point() { + let query = [0.0f32, 0.0, 0.0]; + let points = vec![[1.0, 1.0, 1.0]]; + let result = squared_distances_f32(query, &points); + assert_eq!(result.len(), 1); + assert!(approx_eq_f32(result[0], 3.0)); + } + + #[test] + fn test_zero_distance() { + let query = [5.0f32, 10.0, 15.0]; + let points = vec![query]; + let result = squared_distances_f32(query, &points); + assert!(approx_eq_f32(result[0], 0.0)); + } +} diff --git a/src/hpc/jitson/mod.rs b/src/hpc/jitson/mod.rs index 96d08c61..668e0683 100644 --- a/src/hpc/jitson/mod.rs +++ b/src/hpc/jitson/mod.rs @@ -29,6 +29,7 @@ pub mod template; pub mod precompile; pub mod scan_config; pub mod packed; +pub mod noise; // Re-exports: parser layer pub use parser::{parse_json, JsonValue, ParseError}; @@ -48,3 +49,6 @@ pub use scan_config::{ ScanConfig, ScanResult, SimdKernelRegistry, DefaultKernelRegistry, scan_hamming, jit_symbol_table, }; + +// Re-exports: noise parameters +pub use noise::{NoiseParams, GRAD3, simple_noise_3d}; diff --git a/src/hpc/jitson/noise.rs b/src/hpc/jitson/noise.rs new file mode 100644 index 00000000..61a1c207 --- /dev/null +++ b/src/hpc/jitson/noise.rs @@ -0,0 +1,102 @@ +//! Noise parameter types for JIT compilation of noise functions. +//! +//! Per-octave frequency and amplitude scales are baked into compiled +//! functions as immediate values, avoiding per-sample parameter loads. +//! These types are always available (no Cranelift dependency). + +/// Noise octave parameters — compiled as JIT immediates. +/// +/// Per-octave frequency and amplitude scales are baked into the compiled +/// function as immediate values, avoiding per-sample parameter loads. +#[derive(Debug, Clone)] +pub struct NoiseParams { + /// Per-octave: (frequency_scale, amplitude_scale) + pub octaves: Vec<(f64, f64)>, + /// Lacunarity: frequency multiplier per octave + pub lacunarity: f64, + /// Persistence: amplitude multiplier per octave + pub persistence: f64, +} + +impl NoiseParams { + /// Create standard Perlin noise parameters. + pub fn perlin(num_octaves: usize, lacunarity: f64, persistence: f64) -> Self { + let mut octaves = Vec::with_capacity(num_octaves); + let mut freq = 1.0; + let mut amp = 1.0; + for _ in 0..num_octaves { + octaves.push((freq, amp)); + freq *= lacunarity; + amp *= persistence; + } + Self { octaves, lacunarity, persistence } + } + + /// Number of octaves. + pub fn num_octaves(&self) -> usize { + self.octaves.len() + } + + /// Total amplitude sum (for normalization). + pub fn amplitude_sum(&self) -> f64 { + self.octaves.iter().map(|(_, a)| a.abs()).sum() + } + + /// Evaluate noise at a point using scalar octave accumulation. + /// This is the reference implementation that JIT-compiled code must match. 
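+    ///
+    /// # Examples
+    ///
+    /// Sketch using the re-exports added to `jitson/mod.rs` in this patch
+    /// (assumes `jitson` is itself public under `ndarray::hpc`):
+    ///
+    /// ```
+    /// use ndarray::hpc::jitson::{NoiseParams, simple_noise_3d};
+    ///
+    /// let params = NoiseParams::perlin(4, 2.0, 0.5);
+    /// let v = params.evaluate_reference(1.0, 2.0, 3.0, simple_noise_3d);
+    /// // Bounded by the octave amplitude sum because the base noise is in [-1, 1].
+    /// assert!(v.abs() <= params.amplitude_sum());
+    /// ```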
+ pub fn evaluate_reference(&self, x: f64, y: f64, z: f64, base_noise: fn(f64, f64, f64) -> f64) -> f64 { + let mut value = 0.0; + for &(freq, amp) in &self.octaves { + value += amp * base_noise(x * freq, y * freq, z * freq); + } + value + } +} + +/// Gradient vectors for 3D Perlin noise (12 edges of a cube). +pub const GRAD3: [[f64; 3]; 12] = [ + [1.0, 1.0, 0.0], [-1.0, 1.0, 0.0], [1.0, -1.0, 0.0], [-1.0, -1.0, 0.0], + [1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [1.0, 0.0, -1.0], [-1.0, 0.0, -1.0], + [0.0, 1.0, 1.0], [0.0, -1.0, 1.0], [0.0, 1.0, -1.0], [0.0, -1.0, -1.0], +]; + +/// Simple hash-based 3D noise (deterministic, not cryptographic). +pub fn simple_noise_3d(x: f64, y: f64, z: f64) -> f64 { + // Simple value noise for testing + let ix = x.floor() as i64; + let iy = y.floor() as i64; + let iz = z.floor() as i64; + let hash = (ix.wrapping_mul(73856093) ^ iy.wrapping_mul(19349663) ^ iz.wrapping_mul(83492791)) as u64; + // Map to [-1, 1] + (hash % 1000) as f64 / 500.0 - 1.0 +} + +#[cfg(test)] +mod noise_tests { + use super::*; + + #[test] + fn test_noise_params_perlin() { + let params = NoiseParams::perlin(4, 2.0, 0.5); + assert_eq!(params.num_octaves(), 4); + assert!((params.octaves[0].0 - 1.0).abs() < 1e-10); + assert!((params.octaves[1].0 - 2.0).abs() < 1e-10); + assert!((params.octaves[1].1 - 0.5).abs() < 1e-10); + } + + #[test] + fn test_noise_evaluate_deterministic() { + let params = NoiseParams::perlin(4, 2.0, 0.5); + let v1 = params.evaluate_reference(1.0, 2.0, 3.0, simple_noise_3d); + let v2 = params.evaluate_reference(1.0, 2.0, 3.0, simple_noise_3d); + assert_eq!(v1, v2); + } + + #[test] + fn test_amplitude_sum() { + let params = NoiseParams::perlin(4, 2.0, 0.5); + let sum = params.amplitude_sum(); + // 1.0 + 0.5 + 0.25 + 0.125 = 1.875 + assert!((sum - 1.875).abs() < 1e-10); + } +} diff --git a/src/hpc/mod.rs b/src/hpc/mod.rs index c04c0578..cadab298 100644 --- a/src/hpc/mod.rs +++ b/src/hpc/mod.rs @@ -129,10 +129,23 @@ pub mod parallel_search; #[allow(missing_docs)] pub mod zeck; +// SIMD-accelerated spatial / byte-scan / hash utilities +pub mod distance; +pub mod byte_scan; +pub mod spatial_hash; + // Variable-width palette index codec (Minecraft-style bit packing) #[allow(missing_docs)] pub mod palette_codec; +// SIMD-accelerated HPC modules (block properties, nibble light data, AABB collision) +#[allow(missing_docs)] +pub mod property_mask; +#[allow(missing_docs)] +pub mod nibble; +#[allow(missing_docs)] +pub mod aabb; + // Holographic phase-space operations (ported from rustynum-holo) #[allow(missing_docs)] #[allow(clippy::needless_range_loop)] diff --git a/src/hpc/nibble.rs b/src/hpc/nibble.rs new file mode 100644 index 00000000..f0c57f49 --- /dev/null +++ b/src/hpc/nibble.rs @@ -0,0 +1,308 @@ +//! Nibble batch operations for 4-bit packed data (light levels). +//! +//! Light levels in Minecraft are stored as 4-bit nibbles packed two per byte +//! (low nibble = even index, high nibble = odd index). This module provides +//! SIMD-accelerated batch operations on packed nibble arrays. + +/// Unpack 4-bit nibbles from a packed byte array into full `u8` values (0-15). +/// +/// Each byte in `packed` holds two nibbles: the low nibble at even index, +/// the high nibble at the subsequent odd index. Exactly `count` values +/// are returned. +/// +/// # Panics +/// Panics if `packed.len() < (count + 1) / 2`. 
+/// +/// # Examples +/// +/// ``` +/// use ndarray::hpc::nibble::nibble_unpack; +/// let packed = &[0x3A]; // low=0xA, high=0x3 +/// assert_eq!(nibble_unpack(packed, 2), vec![0xA, 0x3]); +/// ``` +pub fn nibble_unpack(packed: &[u8], count: usize) -> Vec { + assert!(packed.len() >= (count + 1) / 2, "packed buffer too small"); + + let mut out = Vec::with_capacity(count); + + nibble_unpack_scalar(packed, count, &mut out); + out +} + +fn nibble_unpack_scalar(packed: &[u8], count: usize, out: &mut Vec) { + for i in 0..count { + let byte = packed[i / 2]; + let val = if i & 1 == 0 { byte & 0x0F } else { byte >> 4 }; + out.push(val); + } +} + +/// Pack `u8` values (each 0-15) into 4-bit nibble pairs. +/// +/// Values are clamped to 0-15. The resulting byte count is `(values.len() + 1) / 2`. +/// +/// # Examples +/// +/// ``` +/// use ndarray::hpc::nibble::nibble_pack; +/// let packed = nibble_pack(&[0xA, 0x3]); +/// assert_eq!(packed, vec![0x3A]); +/// ``` +pub fn nibble_pack(values: &[u8]) -> Vec { + let out_len = (values.len() + 1) / 2; + let mut out = vec![0u8; out_len]; + + for (i, &v) in values.iter().enumerate() { + let clamped = v & 0x0F; + let byte_idx = i / 2; + if i & 1 == 0 { + out[byte_idx] |= clamped; + } else { + out[byte_idx] |= clamped << 4; + } + } + out +} + +/// Batch subtract with clamp: every nibble in `packed` has `delta` subtracted, +/// clamping to 0. Used for light propagation BFS decay. +/// +/// Operates in-place on the packed representation. +pub fn nibble_sub_clamp(packed: &mut [u8], delta: u8) { + if delta == 0 { + return; + } + if delta >= 15 { + for b in packed.iter_mut() { + *b = 0; + } + return; + } + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") { + // SAFETY: avx2 detected, slice is mutable and valid. + unsafe { + nibble_sub_clamp_avx2(packed, delta); + return; + } + } + } + + nibble_sub_clamp_scalar(packed, delta); +} + +fn nibble_sub_clamp_scalar(packed: &mut [u8], delta: u8) { + for b in packed.iter_mut() { + let lo = (*b & 0x0F).saturating_sub(delta); + let hi = ((*b >> 4) & 0x0F).saturating_sub(delta); + *b = lo | (hi << 4); + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn nibble_sub_clamp_avx2(packed: &mut [u8], delta: u8) { + use core::arch::x86_64::*; + + let mask_lo = _mm256_set1_epi8(0x0F); + let mask_hi = _mm256_set1_epi8(0xF0u8 as i8); + let delta_v = _mm256_set1_epi8(delta as i8); + // delta shifted into high nibble position for direct subtraction + let delta_hi = _mm256_set1_epi8((delta << 4) as i8); + let chunks = packed.len() / 32; + + for c in 0..chunks { + let ptr = packed.as_mut_ptr().add(c * 32); + let data = _mm256_loadu_si256(ptr as *const __m256i); + + // Extract low nibbles, subtract with saturation + let lo = _mm256_and_si256(data, mask_lo); + let lo_sub = _mm256_subs_epu8(lo, delta_v); + + // Extract high nibbles (keep in high position), subtract with saturation + let hi = _mm256_and_si256(data, mask_hi); + let hi_sub = _mm256_subs_epu8(hi, delta_hi); + + // Combine: low nibbles are already clean (0-15), high nibbles already in position + let result = _mm256_or_si256( + _mm256_and_si256(lo_sub, mask_lo), + _mm256_and_si256(hi_sub, mask_hi), + ); + + _mm256_storeu_si256(ptr as *mut __m256i, result); + } + + // Scalar tail + nibble_sub_clamp_scalar(&mut packed[chunks * 32..], delta); +} + +/// Find all nibble indices with value strictly above `threshold`. Returns sorted indices. 
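+///
+/// # Examples
+///
+/// ```
+/// use ndarray::hpc::nibble::{nibble_pack, nibble_above_threshold};
+///
+/// let packed = nibble_pack(&[0, 7, 12]);
+/// // Packing pads to a whole byte, so nibble index 3 exists but holds 0.
+/// assert_eq!(nibble_above_threshold(&packed, 6), vec![1, 2]);
+/// ```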
+pub fn nibble_above_threshold(packed: &[u8], threshold: u8) -> Vec { + let mut result = Vec::new(); + let count = packed.len() * 2; + for i in 0..count { + if nibble_get(packed, i) > threshold { + result.push(i); + } + } + result +} + +/// Get a single nibble value at the given index. +/// +/// # Panics +/// Panics if `index / 2 >= packed.len()`. +#[inline] +pub fn nibble_get(packed: &[u8], index: usize) -> u8 { + let byte = packed[index / 2]; + if index & 1 == 0 { + byte & 0x0F + } else { + byte >> 4 + } +} + +/// Set a single nibble value at the given index. Value is clamped to 0-15. +/// +/// # Panics +/// Panics if `index / 2 >= packed.len()`. +#[inline] +pub fn nibble_set(packed: &mut [u8], index: usize, value: u8) { + let clamped = value & 0x0F; + let byte_idx = index / 2; + if index & 1 == 0 { + packed[byte_idx] = (packed[byte_idx] & 0xF0) | clamped; + } else { + packed[byte_idx] = (packed[byte_idx] & 0x0F) | (clamped << 4); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_roundtrip_pack_unpack() { + let original: Vec = (0..16).collect(); + let packed = nibble_pack(&original); + let unpacked = nibble_unpack(&packed, original.len()); + assert_eq!(unpacked, original); + } + + #[test] + fn test_roundtrip_odd_count() { + let original = vec![1, 5, 9]; + let packed = nibble_pack(&original); + let unpacked = nibble_unpack(&packed, 3); + assert_eq!(unpacked, original); + } + + #[test] + fn test_roundtrip_large() { + let original: Vec = (0..4096).map(|i| (i % 16) as u8).collect(); + let packed = nibble_pack(&original); + let unpacked = nibble_unpack(&packed, original.len()); + assert_eq!(unpacked, original); + } + + #[test] + fn test_get_set() { + let mut packed = vec![0u8; 4]; // 8 nibbles + for i in 0..8 { + nibble_set(&mut packed, i, (i as u8) % 16); + } + for i in 0..8 { + assert_eq!(nibble_get(&packed, i), (i as u8) % 16); + } + } + + #[test] + fn test_sub_clamp_basic() { + // Two bytes: nibbles [5, 3, 10, 1] + let mut packed = nibble_pack(&[5, 3, 10, 1]); + nibble_sub_clamp(&mut packed, 3); + let vals = nibble_unpack(&packed, 4); + assert_eq!(vals, vec![2, 0, 7, 0]); + } + + #[test] + fn test_sub_clamp_zero_delta() { + let mut packed = nibble_pack(&[5, 3, 10, 1]); + let original = packed.clone(); + nibble_sub_clamp(&mut packed, 0); + assert_eq!(packed, original); + } + + #[test] + fn test_sub_clamp_large_delta() { + let mut packed = nibble_pack(&[15, 15, 15, 15]); + nibble_sub_clamp(&mut packed, 15); + let vals = nibble_unpack(&packed, 4); + assert_eq!(vals, vec![0, 0, 0, 0]); + } + + #[test] + fn test_sub_clamp_large() { + let original: Vec = (0..256).map(|i| (i % 16) as u8).collect(); + let mut packed = nibble_pack(&original); + nibble_sub_clamp(&mut packed, 4); + let result = nibble_unpack(&packed, original.len()); + for (i, (&orig, &res)) in original.iter().zip(result.iter()).enumerate() { + assert_eq!(res, orig.saturating_sub(4), "mismatch at nibble {}", i); + } + } + + #[test] + fn test_above_threshold() { + let packed = nibble_pack(&[0, 5, 3, 15, 7, 1, 14, 8]); + let above_5 = nibble_above_threshold(&packed, 5); + // Indices with value > 5: index 3 (15), 4 (7), 6 (14), 7 (8) + assert_eq!(above_5, vec![3, 4, 6, 7]); + } + + #[test] + fn test_above_threshold_none() { + let packed = nibble_pack(&[0, 1, 2, 3]); + assert!(nibble_above_threshold(&packed, 15).is_empty()); + } + + #[test] + fn test_above_threshold_all() { + let packed = nibble_pack(&[15, 15, 15, 15]); + let above_0 = nibble_above_threshold(&packed, 0); + assert_eq!(above_0, vec![0, 1, 2, 3]); + 
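+        // Strictly-above semantics: a threshold equal to the maximum value matches nothing.
+        assert!(nibble_above_threshold(&packed, 15).is_empty());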
} + + #[test] + fn test_clamping_on_pack() { + // Values above 15 should be clamped + let packed = nibble_pack(&[0xFF, 0x1A]); + let unpacked = nibble_unpack(&packed, 2); + assert_eq!(unpacked[0], 0x0F); + assert_eq!(unpacked[1], 0x0A); + } + + #[test] + fn test_empty() { + let packed = nibble_pack(&[]); + assert!(packed.is_empty()); + let unpacked = nibble_unpack(&packed, 0); + assert!(unpacked.is_empty()); + } + + #[test] + fn test_single_nibble() { + let packed = nibble_pack(&[7]); + assert_eq!(packed.len(), 1); + let unpacked = nibble_unpack(&packed, 1); + assert_eq!(unpacked, vec![7]); + } + + #[test] + #[should_panic(expected = "packed buffer too small")] + fn test_unpack_too_small() { + nibble_unpack(&[0x00], 4); // 1 byte can hold 2 nibbles, not 4 + } +} diff --git a/src/hpc/palette_codec.rs b/src/hpc/palette_codec.rs index 313f9a13..5de88c14 100644 --- a/src/hpc/palette_codec.rs +++ b/src/hpc/palette_codec.rs @@ -249,6 +249,92 @@ impl PackedPaletteArray { } } +/// SIMD-accelerated palette unpacking. +/// Falls back to scalar `unpack_indices` on non-AVX2 targets. +/// +/// # Example +/// +/// ``` +/// use ndarray::hpc::palette_codec::{pack_indices, unpack_indices_simd}; +/// +/// let indices: Vec = (0..64).map(|i| i % 16).collect(); +/// let packed = pack_indices(&indices, 4); +/// let recovered = unpack_indices_simd(&packed, 4, 64); +/// assert_eq!(indices, recovered); +/// ``` +pub fn unpack_indices_simd(packed: &[u64], bits_per_index: usize, count: usize) -> Vec { + #[cfg(target_arch = "x86_64")] + { + if bits_per_index == 4 && count >= 16 && is_x86_feature_detected!("avx2") { + return unsafe { unpack_4bit_avx2(packed, count) }; + } + } + unpack_indices(packed, bits_per_index, count) +} + +/// SIMD-accelerated palette packing. +/// Falls back to scalar `pack_indices` on non-AVX2 targets. +pub fn pack_indices_simd(indices: &[u8], bits_per_index: usize) -> Vec { + // Currently delegates to scalar — SIMD packing is less critical than unpacking + // because packing happens at chunk save time (cold path). + pack_indices(indices, bits_per_index) +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn unpack_4bit_avx2(packed: &[u64], count: usize) -> Vec { + use core::arch::x86_64::*; + + let mut result = Vec::with_capacity(count); + let bytes = bytemuck_cast_u64_to_u8(packed); + let low_mask = _mm256_set1_epi8(0x0F); + let mut i = 0; + + // Process 32 bytes at a time → 64 nibbles + while i + 32 <= bytes.len() && result.len() + 64 <= count { + let data = _mm256_loadu_si256(bytes.as_ptr().add(i) as *const __m256i); + + // Extract low nibbles + let lo = _mm256_and_si256(data, low_mask); + // Extract high nibbles + let hi = _mm256_and_si256(_mm256_srli_epi16(data, 4), low_mask); + + // Interleave: we need to output low nibble, then high nibble for each byte + let interleaved_lo = _mm256_unpacklo_epi8(lo, hi); + let interleaved_hi = _mm256_unpackhi_epi8(lo, hi); + + // Store + let mut buf = [0u8; 64]; + _mm256_storeu_si256(buf.as_mut_ptr() as *mut __m256i, interleaved_lo); + _mm256_storeu_si256(buf.as_mut_ptr().add(32) as *mut __m256i, interleaved_hi); + + let remaining = count - result.len(); + let take = remaining.min(64); + result.extend_from_slice(&buf[..take]); + i += 32; + } + + // Handle remainder with scalar + let scalar_start = result.len(); + if scalar_start < count { + let remainder = unpack_indices(packed, 4, count); + result.extend_from_slice(&remainder[scalar_start..]); + } + + result +} + +/// Reinterpret &[u64] as &[u8] (little-endian safe). 
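+/// The only caller here is the AVX2 unpack path, which implies `x86_64` and
+/// therefore little-endian word layout, so nibble order within each word is preserved.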
+fn bytemuck_cast_u64_to_u8(words: &[u64]) -> &[u8] { + // SAFETY: u64 and u8 have compatible layouts on little-endian + unsafe { + core::slice::from_raw_parts( + words.as_ptr() as *const u8, + words.len() * 8, + ) + } +} + #[cfg(test)] mod tests { use super::*; @@ -433,4 +519,30 @@ mod tests { } } } + + #[test] + fn test_unpack_simd_4bit_matches_scalar() { + let indices: Vec = (0..4096).map(|i| (i % 16) as u8).collect(); + let packed = pack_indices(&indices, 4); + let scalar = unpack_indices(&packed, 4, 4096); + let simd = unpack_indices_simd(&packed, 4, 4096); + assert_eq!(scalar, simd, "SIMD 4-bit unpack must match scalar"); + } + + #[test] + fn test_unpack_simd_non_4bit_fallback() { + let indices: Vec = (0..100).map(|i| (i % 8) as u8).collect(); + let packed = pack_indices(&indices, 3); + let scalar = unpack_indices(&packed, 3, 100); + let simd = unpack_indices_simd(&packed, 3, 100); + assert_eq!(scalar, simd, "non-4bit should fall back to scalar"); + } + + #[test] + fn test_pack_simd_roundtrip() { + let indices: Vec = (0..1000).map(|i| (i % 16) as u8).collect(); + let packed = pack_indices_simd(&indices, 4); + let recovered = unpack_indices_simd(&packed, 4, 1000); + assert_eq!(indices, recovered); + } } diff --git a/src/hpc/property_mask.rs b/src/hpc/property_mask.rs new file mode 100644 index 00000000..6cafaa0b --- /dev/null +++ b/src/hpc/property_mask.rs @@ -0,0 +1,337 @@ +//! Block property mask — compiled bitset queries. +//! +//! Compiles boolean property queries (e.g. "waterlogged AND facing_north AND NOT open") +//! into bitmask operations that test a single block state in O(1). +//! +//! With AVX-512 VPTERNLOGD: tests 3 conditions in 1 cycle. + +/// A compiled property query on block state bits. +/// +/// Tests multiple boolean properties in a single operation: +/// `(block_state & and_mask) == and_expect && (block_state & andn_mask) == 0` +/// +/// # Examples +/// +/// ``` +/// use ndarray::hpc::property_mask::PropertyMask; +/// +/// let mask = PropertyMask::new() +/// .require_bit(0) // bit 0 must be set +/// .forbid_bit(3); // bit 3 must NOT be set +/// +/// assert!(mask.test(0b0001)); // bit 0 set, bit 3 clear +/// assert!(!mask.test(0b1001)); // bit 3 is set — forbidden +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PropertyMask { + /// Bits to test (AND) + and_mask: u64, + /// Expected result after AND + and_expect: u64, + /// Bits that must NOT be set + andn_mask: u64, +} + +impl PropertyMask { + /// Create a new empty mask that matches all block states. + pub fn new() -> Self { + Self { + and_mask: 0, + and_expect: 0, + andn_mask: 0, + } + } + + /// Require that `bit` is set in the block state. + /// + /// # Panics + /// Panics if `bit >= 64`. + pub fn require_bit(mut self, bit: usize) -> Self { + assert!(bit < 64, "bit index out of range"); + let b = 1u64 << bit; + self.and_mask |= b; + self.and_expect |= b; + self + } + + /// Require that a multi-bit field at `offset` with `width` bits equals `value`. + /// + /// # Panics + /// Panics if the field exceeds 64 bits or `value` does not fit in `width` bits. + pub fn require_value(mut self, offset: usize, width: usize, value: u64) -> Self { + assert!(width > 0 && offset + width <= 64, "field out of range"); + let field_mask = ((1u64 << width) - 1) << offset; + assert!(value < (1u64 << width), "value does not fit in width"); + self.and_mask |= field_mask; + self.and_expect = (self.and_expect & !field_mask) | (value << offset); + self + } + + /// Forbid `bit` from being set in the block state. 
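+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use ndarray::hpc::property_mask::PropertyMask;
+    ///
+    /// let m = PropertyMask::new().forbid_bit(1);
+    /// assert!(m.test(0b0001));
+    /// assert!(!m.test(0b0011));
+    /// ```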
+ /// + /// # Panics + /// Panics if `bit >= 64`. + pub fn forbid_bit(mut self, bit: usize) -> Self { + assert!(bit < 64, "bit index out of range"); + self.andn_mask |= 1u64 << bit; + self + } + + /// Test a single block state against the compiled mask. + #[inline(always)] + pub fn test(&self, block_state: u64) -> bool { + (block_state & self.and_mask) == self.and_expect + && (block_state & self.andn_mask) == 0 + } + + /// Batch test up to 4096 block states (one chunk section). + /// Returns a `Vec` where each bit indicates whether the + /// corresponding state matched. + /// + /// The returned vector has `ceil(states.len() / 64)` entries. + pub fn test_section(&self, states: &[u64]) -> Vec { + let n = states.len(); + let result_len = (n + 63) / 64; + let mut result = vec![0u64; result_len]; + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") { + // SAFETY: we checked avx2 at runtime, pointers are within slice bounds. + unsafe { + self.test_section_avx2(states, &mut result); + return result; + } + } + } + + self.test_section_scalar(states, &mut result); + result + } + + /// Count the number of matching block states in the slice. + pub fn count_section(&self, states: &[u64]) -> u32 { + let bits = self.test_section(states); + let full_words = states.len() / 64; + let remainder = states.len() % 64; + let mut count = 0u32; + for &w in &bits[..full_words] { + count += w.count_ones(); + } + if remainder > 0 { + // Mask off bits beyond the actual state count. + let last = bits[full_words] & ((1u64 << remainder) - 1); + count += last.count_ones(); + } + count + } + + // ---------- scalar fallback ---------- + + fn test_section_scalar(&self, states: &[u64], result: &mut [u64]) { + for (i, &state) in states.iter().enumerate() { + if self.test(state) { + result[i / 64] |= 1u64 << (i % 64); + } + } + } + + // ---------- AVX2 path ---------- + + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2")] + unsafe fn test_section_avx2(&self, states: &[u64], result: &mut [u64]) { + // AVX2 processes 4 u64s at a time via 256-bit registers. + use core::arch::x86_64::*; + + let and_mask_v = _mm256_set1_epi64x(self.and_mask as i64); + let and_expect_v = _mm256_set1_epi64x(self.and_expect as i64); + let andn_mask_v = _mm256_set1_epi64x(self.andn_mask as i64); + let zero = _mm256_setzero_si256(); + + let chunks = states.len() / 4; + for c in 0..chunks { + let base = c * 4; + let vals = _mm256_loadu_si256(states.as_ptr().add(base) as *const __m256i); + + // (vals & and_mask) == and_expect + let anded = _mm256_and_si256(vals, and_mask_v); + let eq_and = _mm256_cmpeq_epi64(anded, and_expect_v); + + // (vals & andn_mask) == 0 + let andned = _mm256_and_si256(vals, andn_mask_v); + let eq_andn = _mm256_cmpeq_epi64(andned, zero); + + // both conditions + let both = _mm256_and_si256(eq_and, eq_andn); + + // Extract per-lane results (each lane is all-1s or all-0s). + // _mm256_movemask_epi8 gives 32 bits; lanes are 8 bytes each. + let mask32 = _mm256_movemask_epi8(both) as u32; + // Lane k matched if bytes [k*8..(k+1)*8] are all 0xFF → bits set. 
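+            // cmpeq_epi64 yields all-ones or all-zeros per 64-bit lane, so a lane's
+            // 8 bits in the byte mask are either 0xFF or 0x00, never a partial pattern.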
+ for lane in 0..4usize { + let byte_mask = (mask32 >> (lane * 8)) & 0xFF; + if byte_mask == 0xFF { + let idx = base + lane; + result[idx / 64] |= 1u64 << (idx % 64); + } + } + } + + // Scalar tail + for i in (chunks * 4)..states.len() { + if self.test(states[i]) { + result[i / 64] |= 1u64 << (i % 64); + } + } + } +} + +impl Default for PropertyMask { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_single_require_bit() { + let m = PropertyMask::new().require_bit(0); + assert!(m.test(0b0001)); + assert!(m.test(0b1111)); + assert!(!m.test(0b0000)); + assert!(!m.test(0b1110)); + } + + #[test] + fn test_single_forbid_bit() { + let m = PropertyMask::new().forbid_bit(2); + assert!(m.test(0b0001)); + assert!(!m.test(0b0100)); + assert!(!m.test(0b0111)); + } + + #[test] + fn test_require_and_forbid() { + let m = PropertyMask::new().require_bit(0).forbid_bit(3); + assert!(m.test(0b0001)); + assert!(!m.test(0b1001)); // bit 3 forbidden + assert!(!m.test(0b0000)); // bit 0 required + } + + #[test] + fn test_require_value() { + // bits [2..4] must equal 2 (binary 10) + let m = PropertyMask::new().require_value(2, 2, 2); + assert!(m.test(0b1000)); // field = 10 => 2 + assert!(!m.test(0b0100)); // field = 01 => 1 + assert!(!m.test(0b1100)); // field = 11 => 3 + assert!(m.test(0b11111_1000)); // field still 10 + } + + #[test] + fn test_empty_mask_matches_everything() { + let m = PropertyMask::new(); + assert!(m.test(0)); + assert!(m.test(u64::MAX)); + assert!(m.test(0xDEADBEEF)); + } + + #[test] + fn test_batch_section() { + let m = PropertyMask::new().require_bit(0); + let states: Vec = (0..128).collect(); + let bits = m.test_section(&states); + // Every odd-indexed value in 0..128 has bit 0 set. + for i in 0..128 { + let matched = (bits[i / 64] >> (i % 64)) & 1 == 1; + assert_eq!(matched, i & 1 == 1, "mismatch at index {}", i); + } + } + + #[test] + fn test_batch_section_non_multiple() { + let m = PropertyMask::new().forbid_bit(0); + // 7 states: 0,1,2,3,4,5,6 + let states: Vec = (0..7).collect(); + let bits = m.test_section(&states); + assert_eq!(bits.len(), 1); + for i in 0..7 { + let matched = (bits[0] >> i) & 1 == 1; + assert_eq!(matched, i % 2 == 0, "mismatch at index {}", i); + } + } + + #[test] + fn test_count_section() { + let m = PropertyMask::new().require_bit(0); + let states: Vec = (0..100).collect(); + let count = m.count_section(&states); + // Numbers 1,3,5,...,99 → 50 + assert_eq!(count, 50); + } + + #[test] + fn test_count_empty() { + let m = PropertyMask::new(); + let states: Vec = (0..256).collect(); + assert_eq!(m.count_section(&states), 256); + } + + #[test] + fn test_builder_chain() { + let m = PropertyMask::new() + .require_bit(0) + .require_bit(1) + .forbid_bit(4) + .require_value(8, 4, 0xA); + + let state = 0b0000_1010_0000_0011u64; // bits 0,1 set; field [8..12]=0xA; bit 4 clear + assert!(m.test(state)); + + let bad_bit4 = state | (1 << 4); + assert!(!m.test(bad_bit4)); + } + + #[test] + fn test_scalar_parity_with_batch() { + // Ensure scalar single-test agrees with batch for a complex mask. 
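+        // 512 states cover 128 full AVX2 chunks, so on AVX2 hardware any
+        // lane-extraction bug in the batch path surfaces here.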
+ let m = PropertyMask::new() + .require_bit(5) + .forbid_bit(10) + .require_value(16, 3, 5); + + let states: Vec = (0..512u64).map(|i| i.wrapping_mul(0x123456789)).collect(); + let batch = m.test_section(&states); + for (i, &s) in states.iter().enumerate() { + let from_batch = (batch[i / 64] >> (i % 64)) & 1 == 1; + assert_eq!(from_batch, m.test(s), "parity mismatch at index {}", i); + } + } + + #[test] + #[should_panic(expected = "bit index out of range")] + fn test_require_bit_oob() { + PropertyMask::new().require_bit(64); + } + + #[test] + #[should_panic(expected = "bit index out of range")] + fn test_forbid_bit_oob() { + PropertyMask::new().forbid_bit(64); + } + + #[test] + #[should_panic(expected = "field out of range")] + fn test_require_value_oob() { + PropertyMask::new().require_value(60, 8, 0); + } + + #[test] + fn test_default_is_new() { + assert_eq!(PropertyMask::default(), PropertyMask::new()); + } +} diff --git a/src/hpc/spatial_hash.rs b/src/hpc/spatial_hash.rs new file mode 100644 index 00000000..973e2309 --- /dev/null +++ b/src/hpc/spatial_hash.rs @@ -0,0 +1,448 @@ +//! 3D spatial hash grid for efficient proximity queries. +//! +//! Entities are hashed into axis-aligned cells by position. Radius and KNN +//! queries only visit cells that overlap the search volume, giving O(1) +//! amortised lookup for uniformly distributed entities. +//! +//! The grid itself stores only `(cell_key -> [entity_id])`. Actual positions +//! are passed in by reference at query time so the caller keeps ownership. + +use std::collections::HashMap; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +#[inline] +fn sq_dist_f32(a: [f32; 3], b: [f32; 3]) -> f32 { + let dx = a[0] - b[0]; + let dy = a[1] - b[1]; + let dz = a[2] - b[2]; + dx * dx + dy * dy + dz * dz +} + +// --------------------------------------------------------------------------- +// SpatialHash +// --------------------------------------------------------------------------- + +/// 3D spatial hash. Entities are hashed into cells by position. +pub struct SpatialHash { + cell_size: f32, + inv_cell_size: f32, + grid: HashMap<(i32, i32, i32), Vec>, +} + +impl SpatialHash { + /// Create a new spatial hash with the given cell size. + /// + /// # Panics + /// Panics if `cell_size` is not positive and finite. + pub fn new(cell_size: f32) -> Self { + assert!(cell_size > 0.0 && cell_size.is_finite(), "cell_size must be positive and finite"); + Self { + cell_size, + inv_cell_size: 1.0 / cell_size, + grid: HashMap::new(), + } + } + + /// Insert an entity at the given position. + pub fn insert(&mut self, id: u32, x: f32, y: f32, z: f32) { + let key = self.cell_key(x, y, z); + self.grid.entry(key).or_default().push(id); + } + + /// Remove an entity from its cell. Returns `true` if found and removed. + pub fn remove(&mut self, id: u32, x: f32, y: f32, z: f32) -> bool { + let key = self.cell_key(x, y, z); + if let Some(cell) = self.grid.get_mut(&key) { + if let Some(pos) = cell.iter().position(|&eid| eid == id) { + cell.swap_remove(pos); + if cell.is_empty() { + self.grid.remove(&key); + } + return true; + } + } + false + } + + /// Move an entity from `old` position to `new` position. 
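+    ///
+    /// When both positions hash to the same cell this is a no-op; otherwise the
+    /// entity is removed from the old cell and re-inserted into the new one.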
+ pub fn update(&mut self, id: u32, old: [f32; 3], new: [f32; 3]) { + let old_key = self.cell_key(old[0], old[1], old[2]); + let new_key = self.cell_key(new[0], new[1], new[2]); + if old_key != new_key { + self.remove(id, old[0], old[1], old[2]); + self.insert(id, new[0], new[1], new[2]); + } + } + + /// Remove all entities from the grid. + pub fn clear(&mut self) { + self.grid.clear(); + } + + /// Total number of entity entries across all cells. + pub fn len(&self) -> usize { + self.grid.values().map(|v| v.len()).sum() + } + + /// Whether the grid contains zero entities. + pub fn is_empty(&self) -> bool { + self.grid.is_empty() + } + + /// Find all entities within `radius` of `(x, y, z)`. + /// + /// `positions` maps entity id to its `[x, y, z]`. Only entities present + /// in `positions` are considered. Returns `(entity_id, squared_distance)` + /// sorted ascending by distance. + pub fn query_radius( + &self, + x: f32, + y: f32, + z: f32, + radius: f32, + positions: &HashMap, + ) -> Vec<(u32, f32)> { + let radius_sq = radius * radius; + let query = [x, y, z]; + + // Determine cell range to search. + let min_cx = ((x - radius) * self.inv_cell_size).floor() as i32; + let max_cx = ((x + radius) * self.inv_cell_size).floor() as i32; + let min_cy = ((y - radius) * self.inv_cell_size).floor() as i32; + let max_cy = ((y + radius) * self.inv_cell_size).floor() as i32; + let min_cz = ((z - radius) * self.inv_cell_size).floor() as i32; + let max_cz = ((z + radius) * self.inv_cell_size).floor() as i32; + + let mut results: Vec<(u32, f32)> = Vec::new(); + + for cx in min_cx..=max_cx { + for cy in min_cy..=max_cy { + for cz in min_cz..=max_cz { + if let Some(cell) = self.grid.get(&(cx, cy, cz)) { + for &eid in cell { + if let Some(&pos) = positions.get(&eid) { + let d2 = sq_dist_f32(query, pos); + if d2 <= radius_sq { + results.push((eid, d2)); + } + } + } + } + } + } + } + + results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)); + results + } + + /// Find K nearest entities to `(x, y, z)`. + /// + /// Uses expanding-ring search: starts at the cell containing the query + /// point and expands outward until at least K candidates are found, then + /// refines. Returns `(entity_id, squared_distance)` sorted ascending. + pub fn query_knn( + &self, + x: f32, + y: f32, + z: f32, + k: usize, + positions: &HashMap, + ) -> Vec<(u32, f32)> { + if k == 0 { + return Vec::new(); + } + + let query = [x, y, z]; + + // Expand ring from 0 until we have enough candidates. + let mut candidates: Vec<(u32, f32)> = Vec::new(); + let (cx, cy, cz) = self.cell_key(x, y, z); + + let mut ring = 0i32; + let max_ring = 64; // safety cap + + loop { + // Collect candidates from all cells in this ring shell. + for dx in -ring..=ring { + for dy in -ring..=ring { + for dz in -ring..=ring { + // Only visit cells on the shell (at least one coord at + // the ring boundary) to avoid re-visiting interior. + if dx.abs() != ring && dy.abs() != ring && dz.abs() != ring { + continue; + } + let key = (cx + dx, cy + dy, cz + dz); + if let Some(cell) = self.grid.get(&key) { + for &eid in cell { + if let Some(&pos) = positions.get(&eid) { + let d2 = sq_dist_f32(query, pos); + candidates.push((eid, d2)); + } + } + } + } + } + } + + if ring >= max_ring { + break; + } + + // If we have at least k candidates, check whether the k-th best + // is closer than the nearest possible point in the next ring. + // If so, no further ring can improve the result. 
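+            // The (ring * cell_size) bound below is conservative: the query point may
+            // sit anywhere inside its own cell, so shell ring + 1 can be no closer than that.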
+ if candidates.len() >= k { + candidates.sort_by(|a, b| { + a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal) + }); + let worst = candidates[k - 1].1; + // The nearest point in ring+1 is at least (ring * cell_size) away. + let next_ring_min = (ring as f32) * self.cell_size; + if worst <= next_ring_min * next_ring_min { + break; + } + } + + ring += 1; + } + + candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(core::cmp::Ordering::Equal)); + candidates.truncate(k); + candidates + } + + /// Compute the cell key for a world-space coordinate. + fn cell_key(&self, x: f32, y: f32, z: f32) -> (i32, i32, i32) { + ( + (x * self.inv_cell_size).floor() as i32, + (y * self.inv_cell_size).floor() as i32, + (z * self.inv_cell_size).floor() as i32, + ) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_positions(pts: &[(u32, [f32; 3])]) -> HashMap { + pts.iter().copied().collect() + } + + // -- insert / remove -- + + #[test] + fn test_insert_and_len() { + let mut sh = SpatialHash::new(10.0); + assert!(sh.is_empty()); + sh.insert(0, 1.0, 2.0, 3.0); + sh.insert(1, 11.0, 2.0, 3.0); + assert_eq!(sh.len(), 2); + assert!(!sh.is_empty()); + } + + #[test] + fn test_remove() { + let mut sh = SpatialHash::new(10.0); + sh.insert(0, 1.0, 2.0, 3.0); + assert!(sh.remove(0, 1.0, 2.0, 3.0)); + assert!(sh.is_empty()); + } + + #[test] + fn test_remove_not_found() { + let mut sh = SpatialHash::new(10.0); + sh.insert(0, 1.0, 2.0, 3.0); + assert!(!sh.remove(99, 1.0, 2.0, 3.0)); + assert!(!sh.remove(0, 999.0, 999.0, 999.0)); + assert_eq!(sh.len(), 1); + } + + // -- update -- + + #[test] + fn test_update_same_cell() { + let mut sh = SpatialHash::new(10.0); + sh.insert(0, 1.0, 1.0, 1.0); + sh.update(0, [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]); // same cell + assert_eq!(sh.len(), 1); + } + + #[test] + fn test_update_different_cell() { + let mut sh = SpatialHash::new(10.0); + sh.insert(0, 1.0, 1.0, 1.0); + sh.update(0, [1.0, 1.0, 1.0], [100.0, 100.0, 100.0]); // different cell + assert_eq!(sh.len(), 1); + // Old cell should be empty now + assert!(sh.remove(0, 100.0, 100.0, 100.0)); + assert!(sh.is_empty()); + } + + // -- clear -- + + #[test] + fn test_clear() { + let mut sh = SpatialHash::new(10.0); + for i in 0..50 { + sh.insert(i, i as f32, 0.0, 0.0); + } + assert_eq!(sh.len(), 50); + sh.clear(); + assert!(sh.is_empty()); + } + + // -- radius query -- + + #[test] + fn test_query_radius_basic() { + let mut sh = SpatialHash::new(10.0); + let pts = vec![ + (0u32, [0.0f32, 0.0, 0.0]), + (1, [5.0, 0.0, 0.0]), + (2, [20.0, 0.0, 0.0]), + (3, [100.0, 0.0, 0.0]), + ]; + for &(id, pos) in &pts { + sh.insert(id, pos[0], pos[1], pos[2]); + } + let positions = make_positions(&pts); + let result = sh.query_radius(0.0, 0.0, 0.0, 10.0, &positions); + let ids: Vec = result.iter().map(|&(id, _)| id).collect(); + assert!(ids.contains(&0)); + assert!(ids.contains(&1)); + assert!(!ids.contains(&2)); // dist=20 > radius=10 + assert!(!ids.contains(&3)); + } + + #[test] + fn test_query_radius_sorted_by_distance() { + let mut sh = SpatialHash::new(5.0); + let pts = vec![ + (0u32, [10.0f32, 0.0, 0.0]), + (1, [3.0, 0.0, 0.0]), + (2, [1.0, 0.0, 0.0]), + ]; + for &(id, pos) in &pts { + sh.insert(id, pos[0], pos[1], pos[2]); + } + let positions = make_positions(&pts); + let result = sh.query_radius(0.0, 0.0, 0.0, 20.0, &positions); + // Should be sorted: id=2 (d=1), 
id=1 (d=9), id=0 (d=100) + assert_eq!(result[0].0, 2); + assert_eq!(result[1].0, 1); + assert_eq!(result[2].0, 0); + } + + #[test] + fn test_query_radius_empty() { + let sh = SpatialHash::new(10.0); + let positions: HashMap = HashMap::new(); + let result = sh.query_radius(0.0, 0.0, 0.0, 100.0, &positions); + assert!(result.is_empty()); + } + + // -- knn -- + + #[test] + fn test_knn_basic() { + let mut sh = SpatialHash::new(10.0); + let pts = vec![ + (0u32, [30.0f32, 0.0, 0.0]), + (1, [10.0, 0.0, 0.0]), + (2, [20.0, 0.0, 0.0]), + (3, [5.0, 0.0, 0.0]), + ]; + for &(id, pos) in &pts { + sh.insert(id, pos[0], pos[1], pos[2]); + } + let positions = make_positions(&pts); + let result = sh.query_knn(0.0, 0.0, 0.0, 2, &positions); + assert_eq!(result.len(), 2); + assert_eq!(result[0].0, 3); // dist=25 + assert_eq!(result[1].0, 1); // dist=100 + } + + #[test] + fn test_knn_vs_brute_force() { + let mut sh = SpatialHash::new(5.0); + let pts: Vec<(u32, [f32; 3])> = (0..50) + .map(|i| { + let v = i as f32 * 2.0; + (i as u32, [v, v * 0.5, v * 0.3]) + }) + .collect(); + for &(id, pos) in &pts { + sh.insert(id, pos[0], pos[1], pos[2]); + } + let positions = make_positions(&pts); + let k = 5; + let result = sh.query_knn(10.0, 5.0, 3.0, k, &positions); + + // Brute-force reference + let mut brute: Vec<(u32, f32)> = pts + .iter() + .map(|&(id, pos)| (id, sq_dist_f32([10.0, 5.0, 3.0], pos))) + .collect(); + brute.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + brute.truncate(k); + + assert_eq!(result.len(), brute.len()); + // Compare distances (not IDs — ties may break differently) + for (r, b) in result.iter().zip(brute.iter()) { + assert!( + (r.1 - b.1).abs() < 1e-3, + "knn dist mismatch: spatial_hash=({},{:.2}) brute=({},{:.2})", + r.0, r.1, b.0, b.1 + ); + } + } + + #[test] + fn test_knn_k_zero() { + let sh = SpatialHash::new(10.0); + let positions: HashMap = HashMap::new(); + let result = sh.query_knn(0.0, 0.0, 0.0, 0, &positions); + assert!(result.is_empty()); + } + + #[test] + fn test_knn_k_larger_than_count() { + let mut sh = SpatialHash::new(10.0); + sh.insert(0, 1.0, 2.0, 3.0); + let positions = make_positions(&[(0, [1.0, 2.0, 3.0])]); + let result = sh.query_knn(0.0, 0.0, 0.0, 100, &positions); + assert_eq!(result.len(), 1); + } + + // -- negative coordinates -- + + #[test] + fn test_negative_coordinates() { + let mut sh = SpatialHash::new(10.0); + let pts = vec![ + (0u32, [-5.0f32, -5.0, -5.0]), + (1, [5.0, 5.0, 5.0]), + ]; + for &(id, pos) in &pts { + sh.insert(id, pos[0], pos[1], pos[2]); + } + let positions = make_positions(&pts); + let result = sh.query_radius(0.0, 0.0, 0.0, 20.0, &positions); + assert_eq!(result.len(), 2); + } + + #[test] + #[should_panic] + fn test_invalid_cell_size() { + SpatialHash::new(0.0); + } +}