68 changes: 30 additions & 38 deletions crates/p64/src/lib.rs
@@ -178,21 +178,19 @@ type AttendFn = unsafe fn(&[u64; 64], u64, u8) -> (u8, u8, [u8; 64], u64);
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "avx512f")]
 unsafe fn attend_avx512(rows: &[u64; 64], query: u64, gamma: u8) -> (u8, u8, [u8; 64], u64) {
-    use std::arch::x86_64::*;
     let mut best_idx = 0u8;
     let mut best_score = 0u8;
     let mut scores = [0u8; 64];
     let mut fires = 0u64;
 
-    let q = _mm512_set1_epi64(query as i64);
     // Process 8 rows per chunk, 8 chunks = 64 rows
+    // (scalar array ops — LLVM auto-vectorizes with target-cpu=x86-64-v4)
     for chunk in 0..8 {
         let base = chunk * 8;
-        // SAFETY: rows is [u64; 64], base..base+8 is in bounds, Palette64 is 64-byte aligned.
-        let r = _mm512_loadu_si512(rows[base..].as_ptr() as *const __m512i);
-        let anded = _mm512_and_si512(r, q);
-        // Extract 8 u64s and scalar popcount (no VPOPCNTDQ dependency)
-        let vals: [u64; 8] = std::mem::transmute(anded);
+        let mut vals = [0u64; 8];
+        for j in 0..8 {
+            vals[j] = rows[base + j] & query;
+        }
         for j in 0..8 {
             let score = vals[j].count_ones() as u8;
             let idx = base + j;
@@ -212,20 +210,19 @@ unsafe fn attend_avx512(rows: &[u64; 64], query: u64, gamma: u8) -> (u8, u8, [u8
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "avx2")]
 unsafe fn attend_avx2(rows: &[u64; 64], query: u64, gamma: u8) -> (u8, u8, [u8; 64], u64) {
-    use std::arch::x86_64::*;
     let mut best_idx = 0u8;
     let mut best_score = 0u8;
     let mut scores = [0u8; 64];
     let mut fires = 0u64;
 
-    let q = _mm256_set1_epi64x(query as i64);
     // Process 4 rows per chunk, 16 chunks = 64 rows
+    // (scalar array ops — LLVM auto-vectorizes with target-cpu=x86-64-v4)
     for chunk in 0..16 {
         let base = chunk * 4;
-        // SAFETY: rows is [u64; 64], base..base+4 is in bounds.
-        let r = _mm256_loadu_si256(rows[base..].as_ptr() as *const __m256i);
-        let anded = _mm256_and_si256(r, q);
-        let vals: [u64; 4] = std::mem::transmute(anded);
+        let mut vals = [0u64; 4];
+        for j in 0..4 {
+            vals[j] = rows[base + j] & query;
+        }
         for j in 0..4 {
             let score = vals[j].count_ones() as u8;
             let idx = base + j;
@@ -284,15 +281,14 @@ type NearestKFn = unsafe fn(&[u64; 64], u64) -> [u8; 64];
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "avx512f")]
 unsafe fn nearest_k_avx512(rows: &[u64; 64], query: u64) -> [u8; 64] {
-    use std::arch::x86_64::*;
     let mut dists = [0u8; 64];
-    let q = _mm512_set1_epi64(query as i64);
+    // Scalar array ops — LLVM auto-vectorizes with target-cpu=x86-64-v4
     for chunk in 0..8 {
         let base = chunk * 8;
-        // SAFETY: rows is [u64; 64], base..base+8 is in bounds.
-        let r = _mm512_loadu_si512(rows[base..].as_ptr() as *const __m512i);
-        let xored = _mm512_xor_si512(r, q);
-        let vals: [u64; 8] = std::mem::transmute(xored);
+        let mut vals = [0u64; 8];
+        for j in 0..8 {
+            vals[j] = rows[base + j] ^ query;
+        }
         for j in 0..8 {
             dists[base + j] = vals[j].count_ones() as u8;
         }
@@ -303,15 +299,14 @@ unsafe fn nearest_k_avx512(rows: &[u64; 64], query: u64) -> [u8; 64] {
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "avx2")]
 unsafe fn nearest_k_avx2(rows: &[u64; 64], query: u64) -> [u8; 64] {
-    use std::arch::x86_64::*;
     let mut dists = [0u8; 64];
-    let q = _mm256_set1_epi64x(query as i64);
+    // Scalar array ops — LLVM auto-vectorizes with target-cpu=x86-64-v4
     for chunk in 0..16 {
         let base = chunk * 4;
-        // SAFETY: rows is [u64; 64], base..base+4 is in bounds.
-        let r = _mm256_loadu_si256(rows[base..].as_ptr() as *const __m256i);
-        let xored = _mm256_xor_si256(r, q);
-        let vals: [u64; 4] = std::mem::transmute(xored);
+        let mut vals = [0u64; 4];
+        for j in 0..4 {
+            vals[j] = rows[base + j] ^ query;
+        }
         for j in 0..4 {
             dists[base + j] = vals[j].count_ones() as u8;
         }
@@ -350,13 +345,11 @@ type MoeGateFn = unsafe fn(&[u64; 8], u64, u8) -> (u8, [u8; 8], u64);
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "avx512f")]
 unsafe fn moe_gate_avx512(planes: &[u64; 8], query: u64, threshold: u8) -> (u8, [u8; 8], u64) {
-    use std::arch::x86_64::*;
-    // Load all 8 planes into one zmm register, AND with broadcast query
-    // SAFETY: planes is [u64; 8] = 64 bytes, fits in one zmm.
-    let p = _mm512_loadu_si512(planes.as_ptr() as *const __m512i);
-    let q = _mm512_set1_epi64(query as i64);
-    let anded = _mm512_and_si512(p, q);
-    let vals: [u64; 8] = std::mem::transmute(anded);
+    // Scalar array ops — LLVM auto-vectorizes with target-cpu=x86-64-v4
+    let mut vals = [0u64; 8];
+    for i in 0..8 {
+        vals[i] = planes[i] & query;
+    }
 
     let mut active = 0u8;
     let mut strength = [0u8; 8];
@@ -375,19 +368,18 @@ unsafe fn moe_gate_avx512(planes: &[u64; 8], query: u64, threshold: u8) -> (u8,
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "avx2")]
 unsafe fn moe_gate_avx2(planes: &[u64; 8], query: u64, threshold: u8) -> (u8, [u8; 8], u64) {
-    use std::arch::x86_64::*;
-    let q = _mm256_set1_epi64x(query as i64);
+    // Scalar array ops — LLVM auto-vectorizes with target-cpu=x86-64-v4
     let mut active = 0u8;
     let mut strength = [0u8; 8];
     let mut combined = 0u64;
 
     // Process 4 planes at a time, 2 chunks = 8 planes
     for chunk in 0..2 {
         let base = chunk * 4;
-        // SAFETY: planes is [u64; 8], base..base+4 is in bounds.
-        let p = _mm256_loadu_si256(planes[base..].as_ptr() as *const __m256i);
-        let anded = _mm256_and_si256(p, q);
-        let vals: [u64; 4] = std::mem::transmute(anded);
+        let mut vals = [0u64; 4];
+        for j in 0..4 {
+            vals[j] = planes[base + j] & query;
+        }
         for j in 0..4 {
             let score = vals[j].count_ones() as u8;
             let idx = base + j;
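Note on `crates/p64/src/lib.rs`: every hunk above follows the same pattern. The raw `std::arch` intrinsics are dropped in favor of plain scalar loops over fixed-size arrays, relying on the build's `target-cpu=x86-64-v4` setting (as the new comments state) for LLVM to auto-vectorize. As a minimal sketch of the kernel shape involved, assuming nothing beyond core Rust (the function name here is illustrative, not the crate's API):

```rust
/// Sketch of the AND + popcount kernel the diff switches to.
/// Compiled with `-C target-cpu=x86-64-v4`, LLVM can vectorize the
/// bitwise AND across lanes; x86-64-v4 does not include VPOPCNTDQ,
/// so the per-element count_ones() lowers to scalar POPCNT either way.
fn and_popcount_rows(rows: &[u64; 64], query: u64) -> [u8; 64] {
    let mut scores = [0u8; 64];
    for i in 0..64 {
        // AND keeps the bits shared by row and query; popcount scores the overlap.
        scores[i] = (rows[i] & query).count_ones() as u8;
    }
    scores
}
```

This also removes the `std::mem::transmute` on `__m512i` values, so the only `unsafe` left in these functions is the `#[target_feature]` calling convention itself.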
172 changes: 72 additions & 100 deletions src/hpc/aabb.rs
@@ -168,16 +168,16 @@ fn aabb_intersect_batch_scalar(query: &Aabb, candidates: &[Aabb]) -> Vec<bool> {
 
 /// AVX-512 batch AABB intersection: tests 16 candidates per axis comparison.
 ///
-/// Broadcasts query min/max per axis, gathers candidate coords into __m512,
-/// compares all 16 at once using `_mm512_cmp_ps_mask`, ANDs the 6 comparison
+/// Broadcasts query min/max per axis, gathers candidate coords into F32x16,
+/// compares all 16 at once using `simd_le` / `simd_ge`, ANDs the 6 comparison
 /// masks.
 ///
 /// # Safety
 /// Caller must ensure AVX-512F is available.
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "avx512f")]
 unsafe fn aabb_intersect_batch_avx512(query: &Aabb, candidates: &[Aabb]) -> Vec<bool> {
-    use core::arch::x86_64::*;
+    use crate::simd::{F32x16, F32Mask16};
 
     let mut result = Vec::with_capacity(candidates.len());
 
@@ -203,32 +203,30 @@ unsafe fn aabb_intersect_batch_avx512(query: &Aabb, candidates: &[Aabb]) -> Vec<
         c_max_z[i] = cand.max[2];
     }
 
-    // SAFETY: arrays are 16-element, avx512f checked by caller.
-    let v_c_min_x = _mm512_loadu_ps(c_min_x.as_ptr());
-    let v_c_max_x = _mm512_loadu_ps(c_max_x.as_ptr());
-    let v_c_min_y = _mm512_loadu_ps(c_min_y.as_ptr());
-    let v_c_max_y = _mm512_loadu_ps(c_max_y.as_ptr());
-    let v_c_min_z = _mm512_loadu_ps(c_min_z.as_ptr());
-    let v_c_max_z = _mm512_loadu_ps(c_max_z.as_ptr());
+    let v_c_min_x = F32x16::from_array(c_min_x);
+    let v_c_max_x = F32x16::from_array(c_max_x);
+    let v_c_min_y = F32x16::from_array(c_min_y);
+    let v_c_max_y = F32x16::from_array(c_max_y);
+    let v_c_min_z = F32x16::from_array(c_min_z);
+    let v_c_max_z = F32x16::from_array(c_max_z);
 
     // Broadcast query bounds
-    let q_min_x = _mm512_set1_ps(query.min[0]);
-    let q_max_x = _mm512_set1_ps(query.max[0]);
-    let q_min_y = _mm512_set1_ps(query.min[1]);
-    let q_max_y = _mm512_set1_ps(query.max[1]);
-    let q_min_z = _mm512_set1_ps(query.min[2]);
-    let q_max_z = _mm512_set1_ps(query.max[2]);
+    let q_min_x = F32x16::splat(query.min[0]);
+    let q_max_x = F32x16::splat(query.max[0]);
+    let q_min_y = F32x16::splat(query.min[1]);
+    let q_max_y = F32x16::splat(query.max[1]);
+    let q_min_z = F32x16::splat(query.min[2]);
+    let q_max_z = F32x16::splat(query.max[2]);
 
     // 6 intersection conditions: q.min[i] <= c.max[i] && q.max[i] >= c.min[i]
-    // _CMP_LE_OQ = 18, _CMP_GE_OQ = 29 (ordered, quiet)
-    let m1 = _mm512_cmp_ps_mask::<{ _CMP_LE_OQ }>(q_min_x, v_c_max_x);
-    let m2 = _mm512_cmp_ps_mask::<{ _CMP_GE_OQ }>(q_max_x, v_c_min_x);
-    let m3 = _mm512_cmp_ps_mask::<{ _CMP_LE_OQ }>(q_min_y, v_c_max_y);
-    let m4 = _mm512_cmp_ps_mask::<{ _CMP_GE_OQ }>(q_max_y, v_c_min_y);
-    let m5 = _mm512_cmp_ps_mask::<{ _CMP_LE_OQ }>(q_min_z, v_c_max_z);
-    let m6 = _mm512_cmp_ps_mask::<{ _CMP_GE_OQ }>(q_max_z, v_c_min_z);
+    let m1 = q_min_x.simd_le(v_c_max_x);
+    let m2 = q_max_x.simd_ge(v_c_min_x);
+    let m3 = q_min_y.simd_le(v_c_max_y);
+    let m4 = q_max_y.simd_ge(v_c_min_y);
+    let m5 = q_min_z.simd_le(v_c_max_z);
+    let m6 = q_max_z.simd_ge(v_c_min_z);
 
-    let all = m1 & m2 & m3 & m4 & m5 & m6;
+    let all = m1.0 & m2.0 & m3.0 & m4.0 & m5.0 & m6.0;
 
     for i in 0..16 {
         result.push((all >> i) & 1 != 0);
@@ -246,24 +244,16 @@ unsafe fn aabb_intersect_batch_avx512(query: &Aabb, candidates: &[Aabb]) -> Vec<
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "sse4.1")]
 unsafe fn aabb_intersect_batch_sse41(query: &Aabb, candidates: &[Aabb]) -> Vec<bool> {
-    use core::arch::x86_64::*;
-
-    // Load query min/max into SSE registers (only need xyz, ignore w).
-    let q_min = _mm_set_ps(0.0, query.min[2], query.min[1], query.min[0]);
-    let q_max = _mm_set_ps(f32::MAX, query.max[2], query.max[1], query.max[0]);
-
+    // Scalar per-candidate test — LLVM auto-vectorizes with target-cpu=x86-64-v4
     let mut result = Vec::with_capacity(candidates.len());
     for c in candidates {
-        let c_min = _mm_set_ps(0.0, c.min[2], c.min[1], c.min[0]);
-        let c_max = _mm_set_ps(f32::MAX, c.max[2], c.max[1], c.max[0]);
-
-        // q.min <= c.max AND q.max >= c.min (per component)
-        let le = _mm_cmple_ps(q_min, c_max); // q_min[i] <= c_max[i]
-        let ge = _mm_cmpge_ps(q_max, c_min); // q_max[i] >= c_min[i]
-        let both = _mm_and_ps(le, ge);
-        // All 4 lanes must be true (lane 3 is always true due to sentinel values).
-        let mask = _mm_movemask_ps(both);
-        result.push(mask == 0xF);
+        let hit = query.min[0] <= c.max[0]
+            && query.max[0] >= c.min[0]
+            && query.min[1] <= c.max[1]
+            && query.max[1] >= c.min[1]
+            && query.min[2] <= c.max[2]
+            && query.max[2] >= c.min[2];
+        result.push(hit);
     }
     result
 }
@@ -333,27 +323,27 @@ fn ray_aabb_slab_test_scalar(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>)
 /// AVX-512 batch ray-AABB slab test: processes 16 AABBs per iteration.
 ///
 /// Broadcasts ray origin and inv_dir per axis, gathers candidate min/max
-/// coords into SoA arrays, computes slab intervals with `_mm512_min_ps` /
-/// `_mm512_max_ps`, and combines masks with `_mm512_cmp_ps_mask`.
+/// coords into SoA arrays, computes slab intervals with `simd_min` /
+/// `simd_max`, and combines masks with `simd_le` / `simd_ge`.
 ///
 /// # Safety
 /// Caller must ensure AVX-512F is available.
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "avx512f")]
 unsafe fn ray_aabb_slab_test_avx512(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>) {
-    use core::arch::x86_64::*;
+    use crate::simd::F32x16;
 
     let mut hits = Vec::with_capacity(aabbs.len());
     let mut t_values = Vec::with_capacity(aabbs.len());
 
     // Broadcast ray origin and inv_dir per axis
-    let orig_x = _mm512_set1_ps(ray.origin[0]);
-    let orig_y = _mm512_set1_ps(ray.origin[1]);
-    let orig_z = _mm512_set1_ps(ray.origin[2]);
-    let inv_x = _mm512_set1_ps(ray.inv_dir[0]);
-    let inv_y = _mm512_set1_ps(ray.inv_dir[1]);
-    let inv_z = _mm512_set1_ps(ray.inv_dir[2]);
-    let zero = _mm512_set1_ps(0.0);
+    let orig_x = F32x16::splat(ray.origin[0]);
+    let orig_y = F32x16::splat(ray.origin[1]);
+    let orig_z = F32x16::splat(ray.origin[2]);
+    let inv_x = F32x16::splat(ray.inv_dir[0]);
+    let inv_y = F32x16::splat(ray.inv_dir[1]);
+    let inv_z = F32x16::splat(ray.inv_dir[2]);
+    let zero = F32x16::splat(0.0);
 
     // Process 16 AABBs at a time
     let chunks = aabbs.len() / 16;
@@ -378,49 +368,44 @@ unsafe fn ray_aabb_slab_test_avx512(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Ve
             a_max_z[i] = aabb.max[2];
         }
 
-        // SAFETY: arrays are 16-element, avx512f checked by caller.
-        let v_min_x = _mm512_loadu_ps(a_min_x.as_ptr());
-        let v_max_x = _mm512_loadu_ps(a_max_x.as_ptr());
-        let v_min_y = _mm512_loadu_ps(a_min_y.as_ptr());
-        let v_max_y = _mm512_loadu_ps(a_max_y.as_ptr());
-        let v_min_z = _mm512_loadu_ps(a_min_z.as_ptr());
-        let v_max_z = _mm512_loadu_ps(a_max_z.as_ptr());
+        let v_min_x = F32x16::from_array(a_min_x);
+        let v_max_x = F32x16::from_array(a_max_x);
+        let v_min_y = F32x16::from_array(a_min_y);
+        let v_max_y = F32x16::from_array(a_max_y);
+        let v_min_z = F32x16::from_array(a_min_z);
+        let v_max_z = F32x16::from_array(a_max_z);
 
         // X axis: t1 = (min - origin) * inv_dir, t2 = (max - origin) * inv_dir
-        let t1_x = _mm512_mul_ps(_mm512_sub_ps(v_min_x, orig_x), inv_x);
-        let t2_x = _mm512_mul_ps(_mm512_sub_ps(v_max_x, orig_x), inv_x);
-        let t_near_x = _mm512_min_ps(t1_x, t2_x);
-        let t_far_x = _mm512_max_ps(t1_x, t2_x);
+        let t1_x = (v_min_x - orig_x) * inv_x;
+        let t2_x = (v_max_x - orig_x) * inv_x;
+        let t_near_x = t1_x.simd_min(t2_x);
+        let t_far_x = t1_x.simd_max(t2_x);
 
         // Y axis
-        let t1_y = _mm512_mul_ps(_mm512_sub_ps(v_min_y, orig_y), inv_y);
-        let t2_y = _mm512_mul_ps(_mm512_sub_ps(v_max_y, orig_y), inv_y);
-        let t_near_y = _mm512_min_ps(t1_y, t2_y);
-        let t_far_y = _mm512_max_ps(t1_y, t2_y);
+        let t1_y = (v_min_y - orig_y) * inv_y;
+        let t2_y = (v_max_y - orig_y) * inv_y;
+        let t_near_y = t1_y.simd_min(t2_y);
+        let t_far_y = t1_y.simd_max(t2_y);
 
         // Z axis
-        let t1_z = _mm512_mul_ps(_mm512_sub_ps(v_min_z, orig_z), inv_z);
-        let t2_z = _mm512_mul_ps(_mm512_sub_ps(v_max_z, orig_z), inv_z);
-        let t_near_z = _mm512_min_ps(t1_z, t2_z);
-        let t_far_z = _mm512_max_ps(t1_z, t2_z);
+        let t1_z = (v_min_z - orig_z) * inv_z;
+        let t2_z = (v_max_z - orig_z) * inv_z;
+        let t_near_z = t1_z.simd_min(t2_z);
+        let t_far_z = t1_z.simd_max(t2_z);
 
         // t_enter = max(t_near_x, t_near_y, t_near_z)
-        let t_enter = _mm512_max_ps(_mm512_max_ps(t_near_x, t_near_y), t_near_z);
+        let t_enter = t_near_x.simd_max(t_near_y).simd_max(t_near_z);
         // t_exit = min(t_far_x, t_far_y, t_far_z)
-        let t_exit = _mm512_min_ps(_mm512_min_ps(t_far_x, t_far_y), t_far_z);
+        let t_exit = t_far_x.simd_min(t_far_y).simd_min(t_far_z);
 
         // hit = t_enter <= t_exit AND t_exit >= 0
-        // _CMP_LE_OQ = 18, _CMP_GE_OQ = 29 (ordered, quiet)
-        let m_le = _mm512_cmp_ps_mask::<{ _CMP_LE_OQ }>(t_enter, t_exit);
-        let m_ge = _mm512_cmp_ps_mask::<{ _CMP_GE_OQ }>(t_exit, zero);
-        let hit_mask = m_le & m_ge;
+        let m_le = t_enter.simd_le(t_exit);
+        let m_ge = t_exit.simd_ge(zero);
+        let hit_mask = m_le.0 & m_ge.0;
 
         // Clamp t_enter to 0 for origins inside box
-        let t_enter_clamped = _mm512_max_ps(t_enter, zero);
-
-        // SAFETY: 16-element array matches __m512 lane count.
-        let mut t_arr = [0.0f32; 16];
-        _mm512_storeu_ps(t_arr.as_mut_ptr(), t_enter_clamped);
+        let t_enter_clamped = t_enter.simd_max(zero);
+        let t_arr = t_enter_clamped.to_array();
 
         for i in 0..16 {
             let hit = (hit_mask >> i) & 1 != 0;
@@ -482,27 +467,14 @@ fn aabb_expand_batch_scalar(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) {
 #[cfg(target_arch = "x86_64")]
 #[target_feature(enable = "sse2")]
 unsafe fn aabb_expand_batch_sse2(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) {
-    use core::arch::x86_64::*;
-
-    let delta_min = _mm_set_ps(0.0, dz, dy, dx);
-    let delta_max = _mm_set_ps(0.0, dz, dy, dx);
-
+    // Scalar per-AABB expand — LLVM auto-vectorizes with target-cpu=x86-64-v4
     for a in aabbs.iter_mut() {
-        let min_v = _mm_set_ps(0.0, a.min[2], a.min[1], a.min[0]);
-        let max_v = _mm_set_ps(0.0, a.max[2], a.max[1], a.max[0]);
-
-        let new_min = _mm_sub_ps(min_v, delta_min);
-        let new_max = _mm_add_ps(max_v, delta_max);
-
-        // Store back. We cannot use _mm_storeu_ps directly into [f32;3],
-        // so extract components.
-        let mut min_arr = [0.0f32; 4];
-        let mut max_arr = [0.0f32; 4];
-        _mm_storeu_ps(min_arr.as_mut_ptr(), new_min);
-        _mm_storeu_ps(max_arr.as_mut_ptr(), new_max);
-
-        a.min = [min_arr[0], min_arr[1], min_arr[2]];
-        a.max = [max_arr[0], max_arr[1], max_arr[2]];
+        a.min[0] -= dx;
+        a.min[1] -= dy;
+        a.min[2] -= dz;
+        a.max[0] += dx;
+        a.max[1] += dy;
+        a.max[2] += dz;
     }
 }
 
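Note on `src/hpc/aabb.rs`: the AVX-512 paths now route through `crate::simd::F32x16` / `F32Mask16`, whose definitions are not part of this diff. Inferring only from how the hunks use them (`from_array`, `splat`, `to_array`, `simd_le`/`simd_ge` returning a mask whose `.0` field is read as a per-lane bitmask, `simd_min`/`simd_max`, and the `-`/`*` operators), a minimal scalar sketch that would type-check against this code might look like the following. Everything here is an assumption about the wrapper; the real type presumably lowers to zmm registers under `#[target_feature(enable = "avx512f")]` rather than scalar loops.

```rust
// Hypothetical stand-in for crate::simd, inferred from call sites in
// this diff; NOT the actual implementation.
#[derive(Copy, Clone)]
pub struct F32x16(pub [f32; 16]);

// One bit per lane, lane i in bit i; the diff reads it via `.0`.
#[derive(Copy, Clone)]
pub struct F32Mask16(pub u16);

impl F32x16 {
    pub fn splat(v: f32) -> Self { Self([v; 16]) }
    pub fn from_array(a: [f32; 16]) -> Self { Self(a) }
    pub fn to_array(self) -> [f32; 16] { self.0 }

    pub fn simd_le(self, rhs: Self) -> F32Mask16 {
        let mut m = 0u16;
        for i in 0..16 { m |= ((self.0[i] <= rhs.0[i]) as u16) << i; }
        F32Mask16(m)
    }
    pub fn simd_ge(self, rhs: Self) -> F32Mask16 {
        let mut m = 0u16;
        for i in 0..16 { m |= ((self.0[i] >= rhs.0[i]) as u16) << i; }
        F32Mask16(m)
    }
    pub fn simd_min(self, rhs: Self) -> Self {
        let mut out = self.0;
        for i in 0..16 { out[i] = out[i].min(rhs.0[i]); }
        Self(out)
    }
    pub fn simd_max(self, rhs: Self) -> Self {
        let mut out = self.0;
        for i in 0..16 { out[i] = out[i].max(rhs.0[i]); }
        Self(out)
    }
}

// Lane-wise Sub/Mul so `(v_min_x - orig_x) * inv_x` in the slab test compiles.
impl core::ops::Sub for F32x16 {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self {
        let mut out = self.0;
        for i in 0..16 { out[i] -= rhs.0[i]; }
        Self(out)
    }
}
impl core::ops::Mul for F32x16 {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        let mut out = self.0;
        for i in 0..16 { out[i] *= rhs.0[i]; }
        Self(out)
    }
}
```

Under these assumed definitions the slab-test hunk reads lane-wise: `t_enter.simd_le(t_exit)` and `t_exit.simd_ge(zero)` each yield a 16-bit mask, and `m_le.0 & m_ge.0` ANDs them into the final `hit_mask`, mirroring what `_mm512_cmp_ps_mask` followed by `&` did before.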