diff --git a/src/hpc/aabb.rs b/src/hpc/aabb.rs index 0dda3e1b..a5d742f5 100644 --- a/src/hpc/aabb.rs +++ b/src/hpc/aabb.rs @@ -292,6 +292,15 @@ unsafe fn aabb_intersect_batch_sse41(query: &Aabb, candidates: &[Aabb]) -> Vec (Vec, Vec) { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx512f") && aabbs.len() >= 16 { + // SAFETY: avx512f detected, enough AABBs for batch processing. + unsafe { + return ray_aabb_slab_test_avx512(ray, aabbs); + } + } + } ray_aabb_slab_test_scalar(ray, aabbs) } @@ -320,6 +329,128 @@ fn ray_aabb_slab_test_scalar(ray: &Ray, aabbs: &[Aabb]) -> (Vec, Vec) (hits, t_values) } +/// AVX-512 batch ray-AABB slab test: processes 16 AABBs per iteration. +/// +/// Broadcasts ray origin and inv_dir per axis, gathers candidate min/max +/// coords into SoA arrays, computes slab intervals with `_mm512_min_ps` / +/// `_mm512_max_ps`, and combines masks with `_mm512_cmp_ps_mask`. +/// +/// # Safety +/// Caller must ensure AVX-512F is available. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +unsafe fn ray_aabb_slab_test_avx512(ray: &Ray, aabbs: &[Aabb]) -> (Vec, Vec) { + use core::arch::x86_64::*; + + let mut hits = Vec::with_capacity(aabbs.len()); + let mut t_values = Vec::with_capacity(aabbs.len()); + + // Broadcast ray origin and inv_dir per axis + let orig_x = _mm512_set1_ps(ray.origin[0]); + let orig_y = _mm512_set1_ps(ray.origin[1]); + let orig_z = _mm512_set1_ps(ray.origin[2]); + let inv_x = _mm512_set1_ps(ray.inv_dir[0]); + let inv_y = _mm512_set1_ps(ray.inv_dir[1]); + let inv_z = _mm512_set1_ps(ray.inv_dir[2]); + let zero = _mm512_set1_ps(0.0); + + // Process 16 AABBs at a time + let chunks = aabbs.len() / 16; + for c in 0..chunks { + let base = c * 16; + + // Gather min/max coords for 16 AABBs into SoA arrays + let mut a_min_x = [0.0f32; 16]; + let mut a_max_x = [0.0f32; 16]; + let mut a_min_y = [0.0f32; 16]; + let mut a_max_y = [0.0f32; 16]; + let mut a_min_z = [0.0f32; 16]; + let mut a_max_z = 
[0.0f32; 16]; + + for i in 0..16 { + let aabb = &aabbs[base + i]; + a_min_x[i] = aabb.min[0]; + a_max_x[i] = aabb.max[0]; + a_min_y[i] = aabb.min[1]; + a_max_y[i] = aabb.max[1]; + a_min_z[i] = aabb.min[2]; + a_max_z[i] = aabb.max[2]; + } + + // SAFETY: arrays are 16-element, avx512f checked by caller. + let v_min_x = _mm512_loadu_ps(a_min_x.as_ptr()); + let v_max_x = _mm512_loadu_ps(a_max_x.as_ptr()); + let v_min_y = _mm512_loadu_ps(a_min_y.as_ptr()); + let v_max_y = _mm512_loadu_ps(a_max_y.as_ptr()); + let v_min_z = _mm512_loadu_ps(a_min_z.as_ptr()); + let v_max_z = _mm512_loadu_ps(a_max_z.as_ptr()); + + // X axis: t1 = (min - origin) * inv_dir, t2 = (max - origin) * inv_dir + let t1_x = _mm512_mul_ps(_mm512_sub_ps(v_min_x, orig_x), inv_x); + let t2_x = _mm512_mul_ps(_mm512_sub_ps(v_max_x, orig_x), inv_x); + let t_near_x = _mm512_min_ps(t1_x, t2_x); + let t_far_x = _mm512_max_ps(t1_x, t2_x); + + // Y axis + let t1_y = _mm512_mul_ps(_mm512_sub_ps(v_min_y, orig_y), inv_y); + let t2_y = _mm512_mul_ps(_mm512_sub_ps(v_max_y, orig_y), inv_y); + let t_near_y = _mm512_min_ps(t1_y, t2_y); + let t_far_y = _mm512_max_ps(t1_y, t2_y); + + // Z axis + let t1_z = _mm512_mul_ps(_mm512_sub_ps(v_min_z, orig_z), inv_z); + let t2_z = _mm512_mul_ps(_mm512_sub_ps(v_max_z, orig_z), inv_z); + let t_near_z = _mm512_min_ps(t1_z, t2_z); + let t_far_z = _mm512_max_ps(t1_z, t2_z); + + // t_enter = max(t_near_x, t_near_y, t_near_z) + let t_enter = _mm512_max_ps(_mm512_max_ps(t_near_x, t_near_y), t_near_z); + // t_exit = min(t_far_x, t_far_y, t_far_z) + let t_exit = _mm512_min_ps(_mm512_min_ps(t_far_x, t_far_y), t_far_z); + + // hit = t_enter <= t_exit AND t_exit >= 0 + // _CMP_LE_OQ = 18, _CMP_GE_OQ = 29 (ordered, quiet) + let m_le = _mm512_cmp_ps_mask::<{ _CMP_LE_OQ }>(t_enter, t_exit); + let m_ge = _mm512_cmp_ps_mask::<{ _CMP_GE_OQ }>(t_exit, zero); + let hit_mask = m_le & m_ge; + + // Clamp t_enter to 0 for origins inside box + let t_enter_clamped = _mm512_max_ps(t_enter, zero); + + // 
SAFETY: 16-element array matches __m512 lane count. + let mut t_arr = [0.0f32; 16]; + _mm512_storeu_ps(t_arr.as_mut_ptr(), t_enter_clamped); + + for i in 0..16 { + let hit = (hit_mask >> i) & 1 != 0; + hits.push(hit); + t_values.push(if hit { t_arr[i] } else { f32::MAX }); + } + } + + // Scalar tail for remainder + for i in (chunks * 16)..aabbs.len() { + let aabb = &aabbs[i]; + let mut t_enter = f32::NEG_INFINITY; + let mut t_exit = f32::INFINITY; + + for axis in 0..3 { + let t1 = (aabb.min[axis] - ray.origin[axis]) * ray.inv_dir[axis]; + let t2 = (aabb.max[axis] - ray.origin[axis]) * ray.inv_dir[axis]; + let t_near = t1.min(t2); + let t_far = t1.max(t2); + t_enter = t_enter.max(t_near); + t_exit = t_exit.min(t_far); + } + + let hit = t_enter <= t_exit && t_exit >= 0.0; + hits.push(hit); + t_values.push(if hit { t_enter.max(0.0) } else { f32::MAX }); + } + + (hits, t_values) +} + /// Expand all AABBs in-place by `(dx, dy, dz)` in both directions per axis. pub fn aabb_expand_batch(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) { #[cfg(target_arch = "x86_64")] @@ -722,4 +853,28 @@ mod tests { assert!(approx_eq(ray.inv_dir[0], 0.5)); assert!(ray.inv_dir[1].is_infinite()); } + + // ---------- AVX-512 ray-AABB parity ---------- + + #[test] + fn test_ray_aabb_avx512_parity() { + // 100 AABBs to exercise AVX-512 + tail + let ray = Ray::new([0.0, 0.5, 0.5], [1.0, 0.0, 0.0]); + let aabbs: Vec = (0..100) + .map(|i| { + let f = i as f32; + Aabb::new([f, 0.0, 0.0], [f + 1.0, 1.0, 1.0]) + }) + .collect(); + let (hits_batch, ts_batch) = ray_aabb_slab_test_batch(&ray, &aabbs); + let (hits_scalar, ts_scalar) = ray_aabb_slab_test_scalar(&ray, &aabbs); + assert_eq!(hits_batch, hits_scalar, "ray AVX-512 hit parity"); + for i in 0..100 { + assert!( + approx_eq(ts_batch[i], ts_scalar[i]), + "ray AVX-512 t parity at {i}: {} vs {}", + ts_batch[i], ts_scalar[i] + ); + } + } } diff --git a/src/hpc/byte_scan.rs b/src/hpc/byte_scan.rs index 38753310..1c38e7a6 100644 --- 
a/src/hpc/byte_scan.rs +++ b/src/hpc/byte_scan.rs @@ -200,6 +200,163 @@ pub fn byte_find_first(haystack: &[u8], needle: u8) -> Option { haystack.iter().position(|&b| b == needle) } +// --------------------------------------------------------------------------- +// NBT schema-aware scanning +// --------------------------------------------------------------------------- + +/// NBT tag type identifiers (matching Minecraft NBT format). +/// +/// Used by the schema scanner to identify tag boundaries in raw NBT data. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum NbtTagId { + /// TAG_End (0) — marks the end of a compound tag. + End = 0, + /// TAG_Byte (1) — a single signed byte. + Byte = 1, + /// TAG_Short (2) — a signed 16-bit integer. + Short = 2, + /// TAG_Int (3) — a signed 32-bit integer. + Int = 3, + /// TAG_Long (4) — a signed 64-bit integer. + Long = 4, + /// TAG_Float (5) — an IEEE 754 single-precision float. + Float = 5, + /// TAG_Double (6) — an IEEE 754 double-precision float. + Double = 6, + /// TAG_Byte_Array (7) — a length-prefixed array of bytes. + ByteArray = 7, + /// TAG_String (8) — a length-prefixed UTF-8 string. + String = 8, + /// TAG_List (9) — a typed list of tags. + List = 9, + /// TAG_Compound (10) — a set of named tags. + Compound = 10, + /// TAG_Int_Array (11) — a length-prefixed array of 32-bit integers. + IntArray = 11, + /// TAG_Long_Array (12) — a length-prefixed array of 64-bit integers. + LongArray = 12, +} + +/// A schema entry describing a named NBT tag to locate. +/// +/// The scanner searches for the tag name bytes preceded by the tag type byte +/// and a 2-byte big-endian name length. +#[derive(Debug, Clone)] +pub struct NbtSchemaEntry { + /// Expected tag type. + pub tag_id: NbtTagId, + /// Tag name bytes (UTF-8). + pub name: Vec, +} + +impl NbtSchemaEntry { + /// Create a schema entry for a named compound tag. 
+ pub fn compound(name: &str) -> Self { + Self { tag_id: NbtTagId::Compound, name: name.as_bytes().to_vec() } + } + + /// Create a schema entry for a named list tag. + pub fn list(name: &str) -> Self { + Self { tag_id: NbtTagId::List, name: name.as_bytes().to_vec() } + } + + /// Create a schema entry for any tag type with given name. + pub fn new(tag_id: NbtTagId, name: &str) -> Self { + Self { tag_id, name: name.as_bytes().to_vec() } + } +} + +/// A match from schema scanning: the byte offset where this tag's payload begins. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NbtSchemaMatch { + /// Index of the schema entry that matched. + pub schema_index: usize, + /// Byte offset of the tag type byte in the buffer. + pub tag_offset: usize, + /// Byte offset where the tag's payload begins (after type + name_len + name). + pub payload_offset: usize, +} + +/// Scan a raw NBT byte buffer for multiple named tags simultaneously. +/// +/// For each schema entry, searches for the pattern: +/// `[tag_id_byte] [name_len_hi] [name_len_lo] [name_bytes...]` +/// +/// Returns all matches found, sorted by offset. +/// +/// # Strategy +/// +/// 1. Use SIMD `byte_find_all` to locate all occurrences of each unique tag_id byte +/// 2. At each candidate position, verify the name length and name bytes match +/// 3. Record payload offset (position + 1 + 2 + name_len) +/// +/// This avoids linear scanning of the entire buffer for each tag. +pub fn nbt_schema_scan(data: &[u8], schema: &[NbtSchemaEntry]) -> Vec { + let mut matches = Vec::new(); + + // Group schema entries by tag_id to avoid redundant SIMD scans. + // Collect unique tag_id bytes and the schema indices that use each. 
+ let mut tag_groups: Vec<(u8, Vec)> = Vec::new(); + for (si, entry) in schema.iter().enumerate() { + let tid = entry.tag_id as u8; + if let Some(group) = tag_groups.iter_mut().find(|(t, _)| *t == tid) { + group.1.push(si); + } else { + tag_groups.push((tid, vec![si])); + } + } + + for (tid_byte, schema_indices) in &tag_groups { + // SIMD-accelerated scan for this tag type byte. + let candidates = byte_find_all(data, *tid_byte); + + for &pos in &candidates { + // Need at least 3 bytes (tag_id + 2-byte name_len) after pos. + if pos + 3 > data.len() { + continue; + } + + // Read big-endian u16 name length. + let name_len = u16::from_be_bytes([data[pos + 1], data[pos + 2]]) as usize; + + // Check bounds for the full name. + if pos + 3 + name_len > data.len() { + continue; + } + + let name_slice = &data[pos + 3..pos + 3 + name_len]; + + // Check against every schema entry for this tag_id. + for &si in schema_indices { + let entry = &schema[si]; + if entry.name.len() == name_len && name_slice == entry.name.as_slice() { + matches.push(NbtSchemaMatch { + schema_index: si, + tag_offset: pos, + payload_offset: pos + 3 + name_len, + }); + } + } + } + } + + // Sort by tag_offset for deterministic output order. + matches.sort_by_key(|m| m.tag_offset); + matches +} + +/// Scan multiple NBT buffers against the same schema. +/// +/// Returns per-buffer match vectors. Useful for batch region loading +/// where 1024 chunk NBT blobs are processed together. 
+pub fn nbt_schema_scan_batch( + buffers: &[&[u8]], + schema: &[NbtSchemaEntry], +) -> Vec> { + buffers.iter().map(|buf| nbt_schema_scan(buf, schema)).collect() +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -325,4 +482,81 @@ mod tests { let buf = vec![0xABu8; 64]; assert_eq!(byte_count(&buf, 0xAB), 64); } + + #[test] + fn test_nbt_schema_scan_basic() { + // Manually craft an NBT-like buffer with a Compound tag named "Entities" + // Format: tag_id(1) + name_len(2 BE) + name(N) + payload... + let mut data = Vec::new(); + // Tag: Compound "Entities" + data.push(10); // Compound tag id + data.extend_from_slice(&(8u16).to_be_bytes()); // name length + data.extend_from_slice(b"Entities"); // name + data.extend_from_slice(&[0; 10]); // some payload + + let schema = vec![NbtSchemaEntry::compound("Entities")]; + let matches = nbt_schema_scan(&data, &schema); + assert_eq!(matches.len(), 1); + assert_eq!(matches[0].schema_index, 0); + assert_eq!(matches[0].tag_offset, 0); + assert_eq!(matches[0].payload_offset, 11); // 1 + 2 + 8 + } + + #[test] + fn test_nbt_schema_scan_multiple_tags() { + let mut data = Vec::new(); + // Compound "Entities" + data.push(10); + data.extend_from_slice(&(8u16).to_be_bytes()); + data.extend_from_slice(b"Entities"); + data.extend_from_slice(&[0; 5]); + // List "BlockEntities" + let offset2 = data.len(); + data.push(9); // List + data.extend_from_slice(&(13u16).to_be_bytes()); + data.extend_from_slice(b"BlockEntities"); + data.extend_from_slice(&[0; 5]); + + let schema = vec![ + NbtSchemaEntry::compound("Entities"), + NbtSchemaEntry::list("BlockEntities"), + ]; + let matches = nbt_schema_scan(&data, &schema); + assert_eq!(matches.len(), 2); + assert_eq!(matches[0].tag_offset, 0); + assert_eq!(matches[1].tag_offset, offset2); + } + + #[test] + fn test_nbt_schema_scan_no_match() { + let data = vec![0u8; 100]; + let schema = 
vec![NbtSchemaEntry::compound("Entities")]; + let matches = nbt_schema_scan(&data, &schema); + assert!(matches.is_empty()); + } + + #[test] + fn test_nbt_schema_scan_batch() { + let buf1 = { + let mut d = Vec::new(); + d.push(10); + d.extend_from_slice(&(4u16).to_be_bytes()); + d.extend_from_slice(b"Test"); + d + }; + let buf2 = vec![0u8; 20]; // no match + + let schema = vec![NbtSchemaEntry::compound("Test")]; + let results = nbt_schema_scan_batch(&[&buf1, &buf2], &schema); + assert_eq!(results.len(), 2); + assert_eq!(results[0].len(), 1); + assert!(results[1].is_empty()); + } + + #[test] + fn test_nbt_tag_id_values() { + assert_eq!(NbtTagId::End as u8, 0); + assert_eq!(NbtTagId::Compound as u8, 10); + assert_eq!(NbtTagId::LongArray as u8, 12); + } } diff --git a/src/hpc/jitson_cranelift/engine.rs b/src/hpc/jitson_cranelift/engine.rs index 8633ac73..2782f864 100644 --- a/src/hpc/jitson_cranelift/engine.rs +++ b/src/hpc/jitson_cranelift/engine.rs @@ -32,6 +32,7 @@ use cranelift_module::{FuncId, Linkage, Module}; use super::detect::CpuCaps; use super::ir::{JitError, ScanParams}; +use super::noise_jit::CachedNoiseKernel; use super::scan_jit::ScanKernel; /// Builder for creating a JIT engine with registered external functions. @@ -103,7 +104,12 @@ impl JitEngineBuilder { let cache = LazyLock::new(empty_cache as fn() -> KernelCache); LazyLock::force(&cache); - Ok(JitEngine { module, caps, cache }) + Ok(JitEngine { + module, + caps, + cache, + noise_cache: HashMap::new(), + }) } } @@ -147,13 +153,16 @@ unsafe impl Sync for KernelCache {} pub struct JitEngine { /// Cranelift JIT module — owns the compiled code pages. /// Only accessed during BUILD phase (&mut self). - module: JITModule, + pub(crate) module: JITModule, /// CPU capabilities detected at engine creation. pub caps: CpuCaps, /// Kernel cache. Mutable during BUILD (via get_mut), frozen during RUN. cache: LazyLock, + + /// Noise kernel cache. Mutable during BUILD, read-only during RUN. 
+ pub(crate) noise_cache: HashMap, } // SAFETY: JITModule's compiled code pages are immutable after finalization. diff --git a/src/hpc/jitson_cranelift/mod.rs b/src/hpc/jitson_cranelift/mod.rs index b327c126..033c207f 100644 --- a/src/hpc/jitson_cranelift/mod.rs +++ b/src/hpc/jitson_cranelift/mod.rs @@ -17,8 +17,10 @@ pub mod ir; pub mod detect; pub mod engine; pub mod scan_jit; +pub mod noise_jit; pub use ir::*; pub use detect::CpuCaps; pub use engine::{JitEngine, JitEngineBuilder}; pub use scan_jit::ScanKernel; +pub use noise_jit::{NoiseKernel, NoiseKernelParams}; diff --git a/src/hpc/jitson_cranelift/noise_jit.rs b/src/hpc/jitson_cranelift/noise_jit.rs new file mode 100644 index 00000000..6254b98e --- /dev/null +++ b/src/hpc/jitson_cranelift/noise_jit.rs @@ -0,0 +1,533 @@ +//! Noise function JIT specialization. +//! +//! Generates a native multi-octave noise evaluation function where: +//! - `frequencies[i]` -> F64 immediate operand in FMUL (no memory fetch) +//! - `amplitudes[i]` -> F64 immediate operand in FMA (no memory fetch) +//! - `normalization` -> F64 immediate final scale (no memory fetch) +//! - `num_octaves` -> unrolled loop (straight-line code, no branches) +//! +//! The base noise function is external, called via `FuncRef` +//! (registered via `JitEngineBuilder::register_fn()`). + +use cranelift_codegen::ir::types; +use cranelift_codegen::ir::{AbiParam, Function, InstBuilder, Signature, UserFuncName}; +use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Variable}; +use cranelift_jit::JITModule; +use cranelift_module::{Linkage, Module}; + +use super::ir::JitError; + +/// IR parameters for noise kernel compilation. 
+/// +/// Each field maps to a specific instruction encoding: +/// - `num_octaves` -> unrolled loop bound (no branch, no counter) +/// - `frequencies[i]` -> F64 immediate operand in FMUL per octave +/// - `amplitudes[i]` -> F64 immediate operand in FMA per octave +/// - `normalization` -> F64 immediate in final FMUL +/// +/// Cannot be `Copy` because `frequencies` and `amplitudes` are `Vec`. +#[derive(Debug, Clone)] +pub struct NoiseKernelParams { + /// Number of octaves — baked as loop unroll count (no branch). + pub num_octaves: u32, + + /// Per-octave frequency scale (baked as F64 immediates). + /// Length must equal `num_octaves`. + pub frequencies: Vec, + + /// Per-octave amplitude scale (baked as F64 immediates). + /// Length must equal `num_octaves`. + pub amplitudes: Vec, + + /// Final normalization scale factor (baked as F64 immediate). + pub normalization: f64, +} + +/// A compiled noise kernel — holds the native function pointer +/// and the params that generated it (for introspection). +pub struct NoiseKernel { + /// The compiled noise function. + /// Signature: `fn(x: f64, y: f64, z: f64) -> f64` + fn_ptr: *const u8, + + /// Parameters that were baked into this kernel. + pub params: NoiseKernelParams, +} + +// SAFETY: The compiled code is immutable and thread-safe. +// Function pointers point to finalized Cranelift code pages +// which are never modified after compilation. +unsafe impl Send for NoiseKernel {} +// SAFETY: The compiled code is immutable and thread-safe. +// No mutable state is accessed through shared references. +unsafe impl Sync for NoiseKernel {} + +impl NoiseKernel { + /// Wrap a raw function pointer as a `NoiseKernel`. + pub(crate) fn from_raw(ptr: *const u8, params: NoiseKernelParams) -> Self { + Self { + fn_ptr: ptr, + params, + } + } + + /// Evaluate the compiled noise function at the given coordinates. 
+ /// + /// # Safety + /// + /// - `self.fn_ptr` must point to a valid Cranelift-compiled function + /// with the signature `fn(f64, f64, f64) -> f64`. + /// - The base noise function registered during compilation must still + /// be valid (not unloaded or freed). + pub unsafe fn evaluate(&self, x: f64, y: f64, z: f64) -> f64 { + // SAFETY: caller guarantees fn_ptr validity; fn_ptr was compiled + // by Cranelift with the matching signature (f64, f64, f64) -> f64. + let func: unsafe extern "C" fn(f64, f64, f64) -> f64 = + std::mem::transmute(self.fn_ptr); + func(x, y, z) + } + + /// Get the raw function pointer (for benchmarking/introspection). + pub fn as_fn_ptr(&self) -> *const u8 { + self.fn_ptr + } +} + +/// Build a `NoiseKernelParams` from a `CompiledNoiseConfig`. +/// +/// Maps the config's precomputed per-octave arrays directly into +/// the IR parameter struct for Cranelift code generation. +/// +/// # Examples +/// +/// ```ignore +/// use ndarray::hpc::jitson::noise::{NoiseParams, CompiledNoiseConfig}; +/// use ndarray::hpc::jitson_cranelift::noise_jit::from_compiled_config; +/// +/// let params = NoiseParams::perlin(4, 2.0, 0.5); +/// let config = CompiledNoiseConfig::from_params(¶ms, 42); +/// let kernel_params = from_compiled_config(&config); +/// assert_eq!(kernel_params.num_octaves, 4); +/// ``` +pub fn from_compiled_config( + config: &super::super::jitson::noise::CompiledNoiseConfig, +) -> NoiseKernelParams { + NoiseKernelParams { + num_octaves: config.frequencies.len() as u32, + frequencies: config.frequencies.clone(), + amplitudes: config.amplitudes.clone(), + normalization: config.normalization, + } +} + +/// Build the Cranelift IR for a multi-octave noise function with baked-in parameters. +/// +/// Generates a function with signature `fn(x: f64, y: f64, z: f64) -> f64` +/// that evaluates multi-octave noise by calling an external base noise function. 
+/// +/// The octave loop is fully unrolled — each octave becomes straight-line code +/// with F64 immediates for frequency and amplitude. No branches in the hot path. +/// +/// Generated pseudo-code: +/// ```text +/// fn noise(x: f64, y: f64, z: f64) -> f64: +/// value = 0.0 +/// // Octave 0 (unrolled): +/// value += AMP_0 * base_noise(x * FREQ_0, y * FREQ_0, z * FREQ_0) +/// // Octave 1 (unrolled): +/// value += AMP_1 * base_noise(x * FREQ_1, y * FREQ_1, z * FREQ_1) +/// // ... (one block per octave, no loop) +/// value *= NORMALIZATION +/// return value +/// ``` +pub fn build_noise_ir( + func: &mut Function, + params: &NoiseKernelParams, + base_noise_ref: cranelift_codegen::ir::FuncRef, +) -> Result<(), JitError> { + // Validate params + let n = params.num_octaves as usize; + if params.frequencies.len() != n { + return Err(JitError::InvalidParams(format!( + "frequencies length {} != num_octaves {}", + params.frequencies.len(), + n + ))); + } + if params.amplitudes.len() != n { + return Err(JitError::InvalidParams(format!( + "amplitudes length {} != num_octaves {}", + params.amplitudes.len(), + n + ))); + } + + let mut fbc = FunctionBuilderContext::new(); + let mut builder = FunctionBuilder::new(func, &mut fbc); + + // Variable for accumulating noise value across octaves. 
+ let v_value = Variable::from_u32(0); + builder.declare_var(v_value, types::F64); + + // Entry block — function signature: fn(x: f64, y: f64, z: f64) -> f64 + let entry = builder.create_block(); + builder.append_block_params_for_function_params(entry); + builder.switch_to_block(entry); + builder.seal_block(entry); + + // Get function parameters + let x = builder.block_params(entry)[0]; + let y = builder.block_params(entry)[1]; + let z = builder.block_params(entry)[2]; + + // Initialize accumulator: value = 0.0 + let zero_f64 = builder.ins().f64const(0.0); + builder.def_var(v_value, zero_f64); + + // Unrolled octave loop — each iteration is straight-line code with + // frequency/amplitude baked as F64 immediates. + for i in 0..n { + let freq_imm = builder.ins().f64const(params.frequencies[i]); + let amp_imm = builder.ins().f64const(params.amplitudes[i]); + + // Scaled coordinates: sx = x * freq, sy = y * freq, sz = z * freq + let sx = builder.ins().fmul(x, freq_imm); + let sy = builder.ins().fmul(y, freq_imm); + let sz = builder.ins().fmul(z, freq_imm); + + // CALL base_noise(sx, sy, sz) + let call = builder.ins().call(base_noise_ref, &[sx, sy, sz]); + let noise_val = builder.inst_results(call)[0]; + + // value += amp * noise_val + // FMA: fma(amp, noise_val, accum) = amp * noise_val + accum + let accum = builder.use_var(v_value); + let new_accum = builder.ins().fma(amp_imm, noise_val, accum); + builder.def_var(v_value, new_accum); + } + + // Final normalization: value *= normalization + let normalization_imm = builder.ins().f64const(params.normalization); + let accum = builder.use_var(v_value); + let result = builder.ins().fmul(accum, normalization_imm); + + builder.ins().return_(&[result]); + + builder.finalize(); + Ok(()) +} + +/// Noise function signature: `fn(x: f64, y: f64, z: f64) -> f64` +fn noise_signature(module: &JITModule) -> Signature { + let mut sig = module.make_signature(); + sig.params.push(AbiParam::new(types::F64)); // x + 
sig.params.push(AbiParam::new(types::F64)); // y + sig.params.push(AbiParam::new(types::F64)); // z + sig.returns.push(AbiParam::new(types::F64)); // result + sig +} + +/// Base noise function signature (external): `fn(f64, f64, f64) -> f64` +fn base_noise_signature(module: &JITModule) -> Signature { + let mut sig = module.make_signature(); + sig.params.push(AbiParam::new(types::F64)); // x + sig.params.push(AbiParam::new(types::F64)); // y + sig.params.push(AbiParam::new(types::F64)); // z + sig.returns.push(AbiParam::new(types::F64)); // result + sig +} + +/// Hash noise kernel params + base noise name for cache lookup. +fn noise_params_hash(params: &NoiseKernelParams, base_noise_name: &str) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + params.num_octaves.hash(&mut hasher); + for f in ¶ms.frequencies { + f.to_bits().hash(&mut hasher); + } + for a in ¶ms.amplitudes { + a.to_bits().hash(&mut hasher); + } + params.normalization.to_bits().hash(&mut hasher); + base_noise_name.hash(&mut hasher); + hasher.finish() +} + +/// Cached noise kernel entry (internal to engine). +pub(crate) struct CachedNoiseKernel { + /// Compiled function pointer. + fn_ptr: *const u8, + /// Parameters baked into this kernel. + params: NoiseKernelParams, +} + +impl super::engine::JitEngine { + /// Compile a noise kernel and add it to the cache. + /// + /// The `base_noise_name` must be a symbol registered via + /// `JitEngineBuilder::register_fn()` with signature `fn(f64, f64, f64) -> f64`. + /// + /// Only works during BUILD phase (before sharing via `Arc`). + /// + /// Returns a cache hash that can be used with `get_noise()`. + pub fn compile_noise( + &mut self, + params: NoiseKernelParams, + base_noise_name: &str, + ) -> Result { + let cache_key = noise_params_hash(¶ms, base_noise_name); + + // Already compiled? Return existing hash. 
+ if self.noise_cache.contains_key(&cache_key) { + return Ok(cache_key); + } + + // Declare the noise function + let func_name = format!("noise_{cache_key:x}"); + let sig = noise_signature(&self.module); + + let func_id = self + .module + .declare_function(&func_name, Linkage::Local, &sig) + .map_err(|e| JitError::Module(e.to_string()))?; + + // Declare the base noise function as an import + let base_noise_sig = base_noise_signature(&self.module); + let base_noise_id = self + .module + .declare_function(base_noise_name, Linkage::Import, &base_noise_sig) + .map_err(|e| JitError::Module(e.to_string()))?; + + let mut ctx = self.module.make_context(); + ctx.func.signature = sig; + ctx.func.name = UserFuncName::user(0, func_id.as_u32()); + + // Get a FuncRef for the base noise function + let base_noise_ref = self + .module + .declare_func_in_func(base_noise_id, &mut ctx.func); + + // Generate the noise IR + build_noise_ir(&mut ctx.func, ¶ms, base_noise_ref)?; + + // Compile + self.module + .define_function(func_id, &mut ctx) + .map_err(|e| JitError::Codegen(e.to_string()))?; + + self.module.clear_context(&mut ctx); + self.module + .finalize_definitions() + .map_err(|e| JitError::Codegen(format!("{e:?}")))?; + + let code_ptr = self.module.get_finalized_function(func_id); + + self.noise_cache.insert( + cache_key, + CachedNoiseKernel { + fn_ptr: code_ptr, + params: params.clone(), + }, + ); + + Ok(cache_key) + } + + /// Look up a compiled noise kernel by hash. Zero-cost after freeze. + /// Returns `None` if the kernel wasn't compiled during BUILD. 
+ pub fn get_noise(&self, hash: u64) -> Option { + self.noise_cache + .get(&hash) + .map(|k| NoiseKernel::from_raw(k.fn_ptr, k.params.clone())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::hpc::jitson::noise::{CompiledNoiseConfig, NoiseParams}; + + #[test] + fn test_noise_kernel_params_from_config() { + let noise_params = NoiseParams::perlin(4, 2.0, 0.5); + let config = CompiledNoiseConfig::from_params(&noise_params, 42); + + let kernel_params = from_compiled_config(&config); + + assert_eq!(kernel_params.num_octaves, 4); + assert_eq!(kernel_params.frequencies.len(), 4); + assert_eq!(kernel_params.amplitudes.len(), 4); + + // Verify frequencies roundtrip + for i in 0..4 { + assert!( + (kernel_params.frequencies[i] - config.frequencies[i]).abs() < 1e-10, + "frequency mismatch at octave {i}" + ); + } + + // Verify amplitudes roundtrip + for i in 0..4 { + assert!( + (kernel_params.amplitudes[i] - config.amplitudes[i]).abs() < 1e-10, + "amplitude mismatch at octave {i}" + ); + } + + // Verify normalization roundtrip + assert!( + (kernel_params.normalization - config.normalization).abs() < 1e-10, + "normalization mismatch" + ); + } + + #[test] + fn test_build_noise_ir_compiles() { + use cranelift_codegen::ir::{AbiParam, UserFuncName}; + use cranelift_codegen::isa::CallConv; + use cranelift_codegen::settings; + + let params = NoiseKernelParams { + num_octaves: 3, + frequencies: vec![1.0, 2.0, 4.0], + amplitudes: vec![1.0, 0.5, 0.25], + normalization: 1.0 / 1.75, + }; + + // Build a minimal Function with the correct signature + let call_conv = CallConv::SystemV; + + let mut sig = cranelift_codegen::ir::Signature::new(call_conv); + sig.params.push(AbiParam::new(types::F64)); // x + sig.params.push(AbiParam::new(types::F64)); // y + sig.params.push(AbiParam::new(types::F64)); // z + sig.returns.push(AbiParam::new(types::F64)); // result + + let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig); + + // Declare the external base noise 
+        // function signature
+        let mut base_sig = cranelift_codegen::ir::Signature::new(call_conv);
+        base_sig.params.push(AbiParam::new(types::F64));
+        base_sig.params.push(AbiParam::new(types::F64));
+        base_sig.params.push(AbiParam::new(types::F64));
+        base_sig.returns.push(AbiParam::new(types::F64));
+
+        let base_noise_ref = func.import_function(cranelift_codegen::ir::ExtFuncData {
+            name: cranelift_codegen::ir::ExternalName::user(0, 1),
+            signature: func.import_signature(base_sig),
+            colocated: false,
+        });
+
+        // Build the IR — should not error
+        let result = build_noise_ir(&mut func, &params, base_noise_ref);
+        assert!(result.is_ok(), "build_noise_ir failed: {result:?}");
+    }
+
+    #[test]
+    fn test_noise_kernel_params_clone() {
+        let params = NoiseKernelParams {
+            num_octaves: 4,
+            frequencies: vec![1.0, 2.0, 4.0, 8.0],
+            amplitudes: vec![1.0, 0.5, 0.25, 0.125],
+            normalization: 1.0 / 1.875,
+        };
+
+        let cloned = params.clone();
+        assert_eq!(cloned.num_octaves, params.num_octaves);
+        assert_eq!(cloned.frequencies.len(), params.frequencies.len());
+        assert_eq!(cloned.amplitudes.len(), params.amplitudes.len());
+        assert!((cloned.normalization - params.normalization).abs() < 1e-10);
+
+        for i in 0..4 {
+            assert!((cloned.frequencies[i] - params.frequencies[i]).abs() < 1e-10);
+            assert!((cloned.amplitudes[i] - params.amplitudes[i]).abs() < 1e-10);
+        }
+    }
+
+    #[test]
+    fn test_noise_kernel_send_sync() {
+        /// Compile-time assertion that `T` implements Send + Sync.
+        fn assert_send_sync<T: Send + Sync>() {}
+        assert_send_sync::<NoiseKernel>();
+    }
+
+    #[test]
+    fn test_build_noise_ir_rejects_mismatched_frequencies() {
+        use cranelift_codegen::ir::{AbiParam, UserFuncName};
+        use cranelift_codegen::isa::CallConv;
+
+        let params = NoiseKernelParams {
+            num_octaves: 3,
+            frequencies: vec![1.0, 2.0], // only 2, but num_octaves says 3
+            amplitudes: vec![1.0, 0.5, 0.25],
+            normalization: 1.0,
+        };
+
+        let call_conv = CallConv::SystemV;
+        let mut sig = cranelift_codegen::ir::Signature::new(call_conv);
+        sig.params.push(AbiParam::new(types::F64));
+        sig.params.push(AbiParam::new(types::F64));
+        sig.params.push(AbiParam::new(types::F64));
+        sig.returns.push(AbiParam::new(types::F64));
+
+        let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig);
+
+        let mut base_sig = cranelift_codegen::ir::Signature::new(call_conv);
+        base_sig.params.push(AbiParam::new(types::F64));
+        base_sig.params.push(AbiParam::new(types::F64));
+        base_sig.params.push(AbiParam::new(types::F64));
+        base_sig.returns.push(AbiParam::new(types::F64));
+
+        let base_noise_ref = func.import_function(cranelift_codegen::ir::ExtFuncData {
+            name: cranelift_codegen::ir::ExternalName::user(0, 1),
+            signature: func.import_signature(base_sig),
+            colocated: false,
+        });
+
+        let result = build_noise_ir(&mut func, &params, base_noise_ref);
+        assert!(
+            result.is_err(),
+            "should reject mismatched num_octaves vs frequencies"
+        );
+    }
+
+    #[test]
+    fn test_build_noise_ir_rejects_mismatched_amplitudes() {
+        use cranelift_codegen::ir::{AbiParam, UserFuncName};
+        use cranelift_codegen::isa::CallConv;
+
+        let params = NoiseKernelParams {
+            num_octaves: 2,
+            frequencies: vec![1.0, 2.0],
+            amplitudes: vec![1.0], // only 1, but num_octaves says 2
+            normalization: 1.0,
+        };
+
+        let call_conv = CallConv::SystemV;
+        let mut sig = cranelift_codegen::ir::Signature::new(call_conv);
+        sig.params.push(AbiParam::new(types::F64));
+        sig.params.push(AbiParam::new(types::F64));
+        sig.params.push(AbiParam::new(types::F64));
+        sig.returns.push(AbiParam::new(types::F64));
+
+        let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig);
+
+        let mut base_sig = cranelift_codegen::ir::Signature::new(call_conv);
+        base_sig.params.push(AbiParam::new(types::F64));
+        base_sig.params.push(AbiParam::new(types::F64));
+        base_sig.params.push(AbiParam::new(types::F64));
+        base_sig.returns.push(AbiParam::new(types::F64));
+
+        let base_noise_ref = func.import_function(cranelift_codegen::ir::ExtFuncData {
+            name: cranelift_codegen::ir::ExternalName::user(0, 1),
+            signature: func.import_signature(base_sig),
+            colocated: false,
+        });
+
+        let result = build_noise_ir(&mut func, &params, base_noise_ref);
+        assert!(
+            result.is_err(),
+            "should reject mismatched num_octaves vs amplitudes"
+        );
+    }
+}
diff --git a/src/hpc/palette_codec.rs b/src/hpc/palette_codec.rs
index 6eb3fd4c..e20cad3f 100644
--- a/src/hpc/palette_codec.rs
+++ b/src/hpc/palette_codec.rs
@@ -402,6 +402,117 @@ fn bytemuck_cast_u64_to_u8(words: &[u64]) -> &[u8] {
     }
 }
 
+/// Reorder 4096 block states from Java Y-major ordering (y*256+z*16+x)
+/// to Bedrock XZY ordering (x*256+z*16+y).
+///
+/// Bedrock uses a different coordinate convention than Java edition.
+/// This function handles the permutation without intermediate allocation.
+///
+/// # Panics
+/// Panics if `states.len() != 4096`.
+pub fn bedrock_reorder_xzy(states: &[u16]) -> Vec<u16> {
+    assert!(states.len() == 4096, "expected 4096 block states, got {}", states.len());
+
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx512f") {
+            // SAFETY: avx512f detected, states.len() == 4096 guaranteed by assert.
+            return unsafe { bedrock_reorder_xzy_avx512(states) };
+        }
+    }
+
+    let mut out = vec![0u16; 4096];
+    for y in 0..16 {
+        for z in 0..16 {
+            for x in 0..16 {
+                out[x * 256 + z * 16 + y] = states[y * 256 + z * 16 + x];
+            }
+        }
+    }
+    out
+}
+
+/// Reorder 4096 block states from Bedrock XZY ordering (x*256+z*16+y)
+/// back to Java Y-major ordering (y*256+z*16+x).
+///
+/// # Panics
+/// Panics if `states.len() != 4096`.
+pub fn bedrock_reorder_xzy_inverse(states: &[u16]) -> Vec<u16> {
+    assert!(states.len() == 4096, "expected 4096 block states, got {}", states.len());
+
+    let mut out = vec![0u16; 4096];
+    for x in 0..16 {
+        for z in 0..16 {
+            for y in 0..16 {
+                out[y * 256 + z * 16 + x] = states[x * 256 + z * 16 + y];
+            }
+        }
+    }
+    out
+}
+
+/// Reorder Java Y-major block states to Bedrock XZY and pack into bit-packed format.
+///
+/// Combines `bedrock_reorder_xzy` with `pack_indices` for efficient serialization.
+/// The palette maps u16 block state IDs to u8 palette indices.
+///
+/// Returns `None` if any block state ID is not in the palette.
+///
+/// # Panics
+/// Panics if `states.len() != 4096` or `bits_per_index` is 0 or > 8.
+///
+/// # Example
+///
+/// ```
+/// use ndarray::hpc::palette_codec::bedrock_pack_section;
+/// use std::collections::HashMap;
+///
+/// let states = vec![0u16; 4096];
+/// let mut palette = HashMap::new();
+/// palette.insert(0u16, 0u8);
+/// let packed = bedrock_pack_section(&states, &palette, 1);
+/// assert!(packed.is_some());
+/// ```
+pub fn bedrock_pack_section(
+    states: &[u16],
+    palette: &std::collections::HashMap<u16, u8>,
+    bits_per_index: usize,
+) -> Option<Vec<u8>> {
+    let reordered = bedrock_reorder_xzy(states);
+    let mut indices = Vec::with_capacity(4096);
+    for &state in &reordered {
+        let idx = palette.get(&state)?;
+        indices.push(*idx);
+    }
+    Some(pack_indices(&indices, bits_per_index))
+}
+
+/// AVX-512 accelerated reorder from Java Y-major to Bedrock XZY ordering.
+///
+/// Uses the same permutation logic as the scalar path but is marked with
+/// `target_feature(enable = "avx512f")` for future SIMD gather/scatter
+/// optimization.
+///
+/// # Safety
+/// Caller must ensure AVX-512F is available and `states.len() == 4096`.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn bedrock_reorder_xzy_avx512(states: &[u16]) -> Vec<u16> {
+    // Scalar implementation with correct permutation logic.
+    // AVX-512 gather/scatter for u16 requires widening to u32 which adds
+    // complexity; the scalar loop over 4096 elements is already fast due to
+    // the target_feature enabling wider instruction scheduling.
+    let mut out = vec![0u16; 4096];
+    for y in 0..16 {
+        for z in 0..16 {
+            for x in 0..16 {
+                out[x * 256 + z * 16 + y] = *states.get_unchecked(y * 256 + z * 16 + x);
+            }
+        }
+    }
+    out
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -647,4 +758,85 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn test_bedrock_reorder_roundtrip() {
+        // Create a pattern where every position has a unique value
+        let states: Vec<u16> = (0..4096).map(|i| i as u16).collect();
+        let reordered = bedrock_reorder_xzy(&states);
+        let recovered = bedrock_reorder_xzy_inverse(&reordered);
+        assert_eq!(states, recovered, "reorder then inverse must be identity");
+    }
+
+    #[test]
+    fn test_bedrock_reorder_specific() {
+        let mut states = vec![0u16; 4096];
+
+        // Place known values at specific Java-order positions:
+        // Java index = y*256 + z*16 + x
+        // Bedrock index = x*256 + z*16 + y
+
+        // (x=0, y=0, z=0) → Java idx 0, Bedrock idx 0
+        states[0] = 100;
+        // (x=1, y=0, z=0) → Java idx 1, Bedrock idx 256
+        states[1] = 200;
+        // (x=0, y=1, z=0) → Java idx 256, Bedrock idx 1
+        states[256] = 300;
+        // (x=3, y=5, z=7) → Java idx 5*256+7*16+3 = 1395, Bedrock idx 3*256+7*16+5 = 885
+        states[1395] = 400;
+        // (x=15, y=15, z=15) → Java idx 15*256+15*16+15 = 4095, Bedrock idx 4095
+        states[4095] = 500;
+
+        let reordered = bedrock_reorder_xzy(&states);
+
+        assert_eq!(reordered[0], 100, "(0,0,0) should map to 0");
+        assert_eq!(reordered[256], 200, "(1,0,0) should map to 256");
+        assert_eq!(reordered[1], 300, "(0,1,0) should map to 1");
+        assert_eq!(reordered[885], 400, "(3,5,7) should map to 885");
+        assert_eq!(reordered[4095], 500, "(15,15,15) should map to 4095");
+    }
+
+    #[test]
+    fn test_bedrock_pack_section() {
+        use std::collections::HashMap;
+
+        // Create states with a small palette
+        let mut states = vec![0u16; 4096];
+        for i in 0..4096 {
+            states[i] = (i % 4) as u16;
+        }
+
+        let mut palette = HashMap::new();
+        palette.insert(0u16, 0u8);
+        palette.insert(1u16, 1u8);
+        palette.insert(2u16, 2u8);
+        palette.insert(3u16, 3u8);
+
+        let bits = bits_for_palette_size(4); // 2 bits
+        let packed = bedrock_pack_section(&states, &palette, bits)
+            .expect("all states should be in palette");
+
+        // Verify by unpacking and inverse-reordering
+        let unpacked = unpack_indices(&packed, bits, 4096);
+        let bedrock_states: Vec<u16> = unpacked.iter().map(|&idx| {
+            // Reverse palette lookup: idx → state
+            *palette.iter().find(|(_, &v)| v == idx).unwrap().0
+        }).collect();
+        let java_states = bedrock_reorder_xzy_inverse(&bedrock_states);
+        assert_eq!(states, java_states, "pack then unpack+inverse must recover original");
+    }
+
+    #[test]
+    fn test_bedrock_pack_section_missing_palette_entry() {
+        use std::collections::HashMap;
+
+        let mut states = vec![0u16; 4096];
+        states[0] = 99; // Not in palette
+
+        let mut palette = HashMap::new();
+        palette.insert(0u16, 0u8);
+
+        let result = bedrock_pack_section(&states, &palette, 1);
+        assert!(result.is_none(), "should return None for missing palette entry");
+    }
 }
diff --git a/src/hpc/property_mask.rs b/src/hpc/property_mask.rs
index 4fd341db..c0c4388f 100644
--- a/src/hpc/property_mask.rs
+++ b/src/hpc/property_mask.rs
@@ -294,6 +294,113 @@ impl Default for PropertyMask {
 }
 
+/// Result of multi-mask counting: per-mask match counts from a single pass.
+///
+/// Enables "count crops AND count liquids AND count redstone" in one scan,
+/// avoiding redundant iteration over 4096 block states.
+#[derive(Debug, Clone)]
+pub struct MultiMaskResult {
+    /// Per-mask match counts, in the same order as the input masks.
+    pub counts: Vec<u32>,
+}
+
+/// Count matches for multiple masks in a single pass over the data.
+///
+/// More efficient than calling `count_section()` N times because:
+/// - Single pass over the state array (one cache line read per state)
+/// - Each state is loaded once and tested against all masks
+///
+/// # Examples
+///
+/// ```
+/// use ndarray::hpc::property_mask::{PropertyMask, count_section_multi};
+///
+/// let crops = PropertyMask::new().require_bit(0);
+/// let liquids = PropertyMask::new().require_bit(1);
+/// let redstone = PropertyMask::new().require_bit(2);
+/// let states: Vec<u64> = (0..100).collect();
+/// let result = count_section_multi(&[crops, liquids, redstone], &states);
+/// assert_eq!(result.counts.len(), 3);
+/// ```
+pub fn count_section_multi(masks: &[PropertyMask], states: &[u64]) -> MultiMaskResult {
+    if masks.is_empty() {
+        return MultiMaskResult { counts: vec![] };
+    }
+
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx512f") && states.len() >= 8 {
+            // SAFETY: avx512f detected above, states.len() >= 8 guaranteed.
+            unsafe {
+                return count_section_multi_avx512(masks, states);
+            }
+        }
+    }
+
+    // scalar fallback
+    count_section_multi_scalar(masks, states)
+}
+
+/// Scalar fallback for multi-mask counting.
+fn count_section_multi_scalar(masks: &[PropertyMask], states: &[u64]) -> MultiMaskResult {
+    let mut counts = vec![0u32; masks.len()];
+    for &state in states {
+        for (m_idx, mask) in masks.iter().enumerate() {
+            if mask.test(state) {
+                counts[m_idx] += 1;
+            }
+        }
+    }
+    MultiMaskResult { counts }
+}
+
+/// AVX-512 multi-mask counting: process 8 states at a time per mask.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512f")]
+unsafe fn count_section_multi_avx512(masks: &[PropertyMask], states: &[u64]) -> MultiMaskResult {
+    use core::arch::x86_64::*;
+
+    let mut counts = vec![0u32; masks.len()];
+    let zero = _mm512_setzero_si512();
+    let chunks = states.len() / 8;
+
+    for c in 0..chunks {
+        let base = c * 8;
+        // SAFETY: base + 8 <= states.len(), avx512f checked by caller.
+        let vals = _mm512_loadu_si512(states.as_ptr().add(base) as *const __m512i);
+
+        for (m_idx, mask) in masks.iter().enumerate() {
+            let and_mask_v = _mm512_set1_epi64(mask.and_mask as i64);
+            let and_expect_v = _mm512_set1_epi64(mask.and_expect as i64);
+            let andn_mask_v = _mm512_set1_epi64(mask.andn_mask as i64);
+
+            // (vals & and_mask) == and_expect
+            let anded = _mm512_and_si512(vals, and_mask_v);
+            let eq_and = _mm512_cmpeq_epi64_mask(anded, and_expect_v);
+
+            // (vals & andn_mask) == 0
+            let andned = _mm512_and_si512(vals, andn_mask_v);
+            let eq_andn = _mm512_cmpeq_epi64_mask(andned, zero);
+
+            // Both conditions: AND the two kmasks
+            let both = eq_and & eq_andn;
+            counts[m_idx] += (both as u32).count_ones();
+        }
+    }
+
+    // Scalar tail
+    for i in (chunks * 8)..states.len() {
+        let state = states[i];
+        for (m_idx, mask) in masks.iter().enumerate() {
+            if mask.test(state) {
+                counts[m_idx] += 1;
+            }
+        }
+    }
+
+    MultiMaskResult { counts }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -460,4 +567,59 @@ mod tests {
         let expected = states.iter().filter(|&&s| m.test(s)).count() as u32;
         assert_eq!(count, expected);
     }
+
+    #[test]
+    fn test_count_multi_basic() {
+        let crops = PropertyMask::new().require_bit(0);
+        let liquids = PropertyMask::new().require_bit(1);
+        let redstone = PropertyMask::new().require_bit(2).forbid_bit(5);
+
+        let states: Vec<u64> = (0..256).collect();
+        let result = count_section_multi(&[crops, liquids, redstone], &states);
+
+        assert_eq!(result.counts.len(), 3);
+        assert_eq!(result.counts[0], crops.count_section(&states));
+        assert_eq!(result.counts[1], liquids.count_section(&states));
+        assert_eq!(result.counts[2], redstone.count_section(&states));
+    }
+
+    #[test]
+    fn test_count_multi_empty_masks() {
+        let states: Vec<u64> = (0..100).collect();
+        let result = count_section_multi(&[], &states);
+        assert!(result.counts.is_empty());
+    }
+
+    #[test]
+    fn test_count_multi_single() {
+        let m = PropertyMask::new().require_bit(3).forbid_bit(7);
+        let states: Vec<u64> = (0..200).collect();
+        let result = count_section_multi(&[m], &states);
+        assert_eq!(result.counts.len(), 1);
+        assert_eq!(result.counts[0], m.count_section(&states));
+    }
+
+    #[test]
+    fn test_count_multi_avx512_parity() {
+        let masks = [
+            PropertyMask::new().require_bit(0),
+            PropertyMask::new().require_bit(1).forbid_bit(4),
+            PropertyMask::new().require_value(8, 4, 0xA),
+            PropertyMask::new().forbid_bit(3).forbid_bit(6),
+            PropertyMask::new().require_bit(2).require_bit(5),
+        ];
+
+        let states: Vec<u64> = (0..1024u64).map(|i| i.wrapping_mul(0xABCDEF01)).collect();
+        let result = count_section_multi(&masks, &states);
+
+        assert_eq!(result.counts.len(), masks.len());
+        for (m_idx, mask) in masks.iter().enumerate() {
+            let expected = mask.count_section(&states);
+            assert_eq!(
+                result.counts[m_idx], expected,
+                "multi-mask parity mismatch for mask index {}",
+                m_idx
+            );
+        }
+    }
+}