Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions src/hpc/aabb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,15 @@ unsafe fn aabb_intersect_batch_sse41(query: &Aabb, candidates: &[Aabb]) -> Vec<b
/// assert!(!hits[1]);
/// ```
pub fn ray_aabb_slab_test_batch(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") && aabbs.len() >= 16 {
// SAFETY: avx512f detected, enough AABBs for batch processing.
unsafe {
return ray_aabb_slab_test_avx512(ray, aabbs);
}
}
}
ray_aabb_slab_test_scalar(ray, aabbs)
}

Expand Down Expand Up @@ -320,6 +329,128 @@ fn ray_aabb_slab_test_scalar(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>)
(hits, t_values)
}

/// AVX-512 batch ray-AABB slab test: processes 16 AABBs per iteration.
///
/// Broadcasts ray origin and inv_dir per axis, gathers candidate min/max
/// coords into SoA arrays, computes slab intervals with `_mm512_min_ps` /
/// `_mm512_max_ps`, and combines masks with `_mm512_cmp_ps_mask`.
///
/// # Safety
/// Caller must ensure AVX-512F is available.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn ray_aabb_slab_test_avx512(ray: &Ray, aabbs: &[Aabb]) -> (Vec<bool>, Vec<f32>) {
use core::arch::x86_64::*;

let mut hits = Vec::with_capacity(aabbs.len());
let mut t_values = Vec::with_capacity(aabbs.len());

// Broadcast ray origin and inv_dir per axis
let orig_x = _mm512_set1_ps(ray.origin[0]);
let orig_y = _mm512_set1_ps(ray.origin[1]);
let orig_z = _mm512_set1_ps(ray.origin[2]);
let inv_x = _mm512_set1_ps(ray.inv_dir[0]);
let inv_y = _mm512_set1_ps(ray.inv_dir[1]);
let inv_z = _mm512_set1_ps(ray.inv_dir[2]);
let zero = _mm512_set1_ps(0.0);

// Process 16 AABBs at a time
let chunks = aabbs.len() / 16;
for c in 0..chunks {
let base = c * 16;

// Gather min/max coords for 16 AABBs into SoA arrays
let mut a_min_x = [0.0f32; 16];
let mut a_max_x = [0.0f32; 16];
let mut a_min_y = [0.0f32; 16];
let mut a_max_y = [0.0f32; 16];
let mut a_min_z = [0.0f32; 16];
let mut a_max_z = [0.0f32; 16];

for i in 0..16 {
let aabb = &aabbs[base + i];
a_min_x[i] = aabb.min[0];
a_max_x[i] = aabb.max[0];
a_min_y[i] = aabb.min[1];
a_max_y[i] = aabb.max[1];
a_min_z[i] = aabb.min[2];
a_max_z[i] = aabb.max[2];
}

// SAFETY: arrays are 16-element, avx512f checked by caller.
let v_min_x = _mm512_loadu_ps(a_min_x.as_ptr());
let v_max_x = _mm512_loadu_ps(a_max_x.as_ptr());
let v_min_y = _mm512_loadu_ps(a_min_y.as_ptr());
let v_max_y = _mm512_loadu_ps(a_max_y.as_ptr());
let v_min_z = _mm512_loadu_ps(a_min_z.as_ptr());
let v_max_z = _mm512_loadu_ps(a_max_z.as_ptr());

// X axis: t1 = (min - origin) * inv_dir, t2 = (max - origin) * inv_dir
let t1_x = _mm512_mul_ps(_mm512_sub_ps(v_min_x, orig_x), inv_x);
let t2_x = _mm512_mul_ps(_mm512_sub_ps(v_max_x, orig_x), inv_x);
let t_near_x = _mm512_min_ps(t1_x, t2_x);
let t_far_x = _mm512_max_ps(t1_x, t2_x);

// Y axis
let t1_y = _mm512_mul_ps(_mm512_sub_ps(v_min_y, orig_y), inv_y);
let t2_y = _mm512_mul_ps(_mm512_sub_ps(v_max_y, orig_y), inv_y);
let t_near_y = _mm512_min_ps(t1_y, t2_y);
let t_far_y = _mm512_max_ps(t1_y, t2_y);

// Z axis
let t1_z = _mm512_mul_ps(_mm512_sub_ps(v_min_z, orig_z), inv_z);
let t2_z = _mm512_mul_ps(_mm512_sub_ps(v_max_z, orig_z), inv_z);
let t_near_z = _mm512_min_ps(t1_z, t2_z);
let t_far_z = _mm512_max_ps(t1_z, t2_z);

// t_enter = max(t_near_x, t_near_y, t_near_z)
let t_enter = _mm512_max_ps(_mm512_max_ps(t_near_x, t_near_y), t_near_z);
// t_exit = min(t_far_x, t_far_y, t_far_z)
let t_exit = _mm512_min_ps(_mm512_min_ps(t_far_x, t_far_y), t_far_z);

// hit = t_enter <= t_exit AND t_exit >= 0
// _CMP_LE_OQ = 18, _CMP_GE_OQ = 29 (ordered, quiet)
let m_le = _mm512_cmp_ps_mask::<{ _CMP_LE_OQ }>(t_enter, t_exit);
let m_ge = _mm512_cmp_ps_mask::<{ _CMP_GE_OQ }>(t_exit, zero);
let hit_mask = m_le & m_ge;

// Clamp t_enter to 0 for origins inside box
let t_enter_clamped = _mm512_max_ps(t_enter, zero);

// SAFETY: 16-element array matches __m512 lane count.
let mut t_arr = [0.0f32; 16];
_mm512_storeu_ps(t_arr.as_mut_ptr(), t_enter_clamped);

for i in 0..16 {
let hit = (hit_mask >> i) & 1 != 0;
hits.push(hit);
t_values.push(if hit { t_arr[i] } else { f32::MAX });
}
}

// Scalar tail for remainder
for i in (chunks * 16)..aabbs.len() {
let aabb = &aabbs[i];
let mut t_enter = f32::NEG_INFINITY;
let mut t_exit = f32::INFINITY;

for axis in 0..3 {
let t1 = (aabb.min[axis] - ray.origin[axis]) * ray.inv_dir[axis];
let t2 = (aabb.max[axis] - ray.origin[axis]) * ray.inv_dir[axis];
let t_near = t1.min(t2);
let t_far = t1.max(t2);
t_enter = t_enter.max(t_near);
t_exit = t_exit.min(t_far);
}

let hit = t_enter <= t_exit && t_exit >= 0.0;
hits.push(hit);
t_values.push(if hit { t_enter.max(0.0) } else { f32::MAX });
}

(hits, t_values)
}

/// Expand all AABBs in-place by `(dx, dy, dz)` in both directions per axis.
pub fn aabb_expand_batch(aabbs: &mut [Aabb], dx: f32, dy: f32, dz: f32) {
#[cfg(target_arch = "x86_64")]
Expand Down Expand Up @@ -722,4 +853,28 @@ mod tests {
assert!(approx_eq(ray.inv_dir[0], 0.5));
assert!(ray.inv_dir[1].is_infinite());
}

// ---------- AVX-512 ray-AABB parity ----------

#[test]
fn test_ray_aabb_avx512_parity() {
// 100 AABBs to exercise AVX-512 + tail
let ray = Ray::new([0.0, 0.5, 0.5], [1.0, 0.0, 0.0]);
let aabbs: Vec<Aabb> = (0..100)
.map(|i| {
let f = i as f32;
Aabb::new([f, 0.0, 0.0], [f + 1.0, 1.0, 1.0])
})
.collect();
let (hits_batch, ts_batch) = ray_aabb_slab_test_batch(&ray, &aabbs);
let (hits_scalar, ts_scalar) = ray_aabb_slab_test_scalar(&ray, &aabbs);
assert_eq!(hits_batch, hits_scalar, "ray AVX-512 hit parity");
for i in 0..100 {
assert!(
approx_eq(ts_batch[i], ts_scalar[i]),
"ray AVX-512 t parity at {i}: {} vs {}",
ts_batch[i], ts_scalar[i]
);
}
}
}
Loading
Loading