diff --git a/src/simd.rs b/src/simd.rs index ccd35aa0..5f37eb4c 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -265,6 +265,14 @@ pub use crate::simd_avx2::{ I32x16, I64x8, I8x64, U16x32, U32x16, U64x8, U8x64, }; +// U8x32 — native AVX2 byte width (one __m256i = 32 bytes). Available on +// both AVX-512 and AVX2 builds: it's the natural width for byte-level +// AVX2 ops, and on AVX-512 builds it's the half-register companion to +// U8x64. Lives in simd_avx2.rs (single source of truth) and is re-exported +// from both tier branches. +#[cfg(target_arch = "x86_64")] +pub use crate::simd_avx2::{u8x32, U8x32}; + // ============================================================================ // Non-x86: scalar fallback types with identical API // ============================================================================ diff --git a/src/simd_avx2.rs b/src/simd_avx2.rs index 0f7799fe..5e164eab 100644 --- a/src/simd_avx2.rs +++ b/src/simd_avx2.rs @@ -1728,6 +1728,521 @@ impl I64x8 { } } +// ═══════════════════════════════════════════════════════════════════ +// U8x32 — native AVX2 byte vector (one __m256i = 32 bytes). +// +// The AVX2-tier "byte width" for the polyfill. AVX-512's U8x64 lives in +// simd_avx512.rs and maps to one __m512i; on AVX2 the equivalent 64-byte +// shape (the U8x64 macro above) is implemented as a scalar [u8; 64] +// fallback because AVX2's natural byte width is 32, not 64. Consumers +// that want REAL AVX2 SIMD speedup over scalar should chunk their data +// in 32-byte windows and use U8x32. +// +// Requires AVX2 at compile time (project baseline is x86-64-v3, so this +// holds on every supported build). Calling these methods on a baseline +// x86_64 build (no AVX2) would SIGILL — same constraint as the rest of +// the file's `_mm256_*` users (e.g. the AVX2 popcount at line ~357). +// ═══════════════════════════════════════════════════════════════════ + +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +/// 32-byte unsigned-integer SIMD vector mapping to one AVX2 `__m256i`. +/// +/// API mirrors `simd_avx512::U8x64` so consumer code can pick the natural +/// byte width for its loop (32 on AVX2, 64 on AVX-512) and rely on the +/// same method set. The polyfill in `simd.rs` re-exports both. +#[cfg(target_arch = "x86_64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U8x32(pub __m256i); + +#[cfg(target_arch = "x86_64")] +impl U8x32 { + /// Number of u8 lanes (32 = one AVX2 ymm register). + pub const LANES: usize = 32; + + // ── Constructors ──────────────────────────────────────────────── + + /// Broadcast a single byte to all 32 lanes. + #[inline(always)] + pub fn splat(v: u8) -> Self { + // SAFETY: AVX2 is the project baseline (x86-64-v3); calling + // `_mm256_set1_epi8` requires AVX, which AVX2 implies. + Self(unsafe { _mm256_set1_epi8(v as i8) }) + } + + /// Unaligned load 32 bytes from a slice. Panics if `s.len() < 32`. + #[inline(always)] + pub fn from_slice(s: &[u8]) -> Self { + assert!(s.len() >= 32, "U8x32::from_slice needs ≥32 bytes"); + // SAFETY: bounds checked above; loadu allows unaligned src. + Self(unsafe { _mm256_loadu_si256(s.as_ptr() as *const __m256i) }) + } + + /// Load 32 bytes from a fixed-size array. + #[inline(always)] + pub fn from_array(arr: [u8; 32]) -> Self { + // SAFETY: `arr` is exactly 32 bytes contiguous; loadu allows any align. + Self(unsafe { _mm256_loadu_si256(arr.as_ptr() as *const __m256i) }) + } + + /// Store all 32 bytes to a `[u8; 32]` array. 
+    #[inline(always)]
+    pub fn to_array(self) -> [u8; 32] {
+        let mut out = [0u8; 32];
+        // SAFETY: `out` is exactly 32 bytes contiguous; storeu allows any align.
+        unsafe { _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, self.0) };
+        out
+    }
+
+    /// Copy all 32 bytes into a mutable slice. Panics if `s.len() < 32`.
+    #[inline(always)]
+    pub fn copy_to_slice(self, s: &mut [u8]) {
+        assert!(s.len() >= 32, "U8x32::copy_to_slice needs ≥32 bytes");
+        // SAFETY: bounds checked above; storeu allows unaligned dst.
+        unsafe { _mm256_storeu_si256(s.as_mut_ptr() as *mut __m256i, self.0) };
+    }
+
+    // ── Reductions ──────────────────────────────────────────────────
+
+    /// Sum of all 32 bytes, modulo 2^8. Wraps on overflow.
+    #[inline(always)]
+    pub fn reduce_sum(self) -> u8 {
+        let arr = self.to_array();
+        arr.iter().fold(0u8, |acc, &b| acc.wrapping_add(b))
+    }
+
+    /// Unsigned minimum across all 32 lanes.
+    #[inline(always)]
+    pub fn reduce_min(self) -> u8 {
+        let arr = self.to_array();
+        *arr.iter().min().unwrap()
+    }
+
+    /// Unsigned maximum across all 32 lanes.
+    #[inline(always)]
+    pub fn reduce_max(self) -> u8 {
+        let arr = self.to_array();
+        *arr.iter().max().unwrap()
+    }
+
+    /// Sum-of-absolute-differences against zero ⇒ one horizontal byte
+    /// sum per 64-bit lane, then combined. Returns the total as u64
+    /// (does NOT wrap at 2^8). Useful for counting set bits in
+    /// popcount-style masks.
+    #[inline(always)]
+    pub fn sum_bytes_u64(self) -> u64 {
+        // SAFETY: AVX2 baseline.
+        let sums = unsafe { _mm256_sad_epu8(self.0, _mm256_setzero_si256()) };
+        // sad_epu8 leaves each partial sum in the low 16 bits of its 64-bit
+        // lane (zero-extended). Pull them out and add — small N, scalar is fine.
+        let mut tmp = [0u64; 4];
+        unsafe { _mm256_storeu_si256(tmp.as_mut_ptr() as *mut __m256i, sums) };
+        tmp[0] + tmp[1] + tmp[2] + tmp[3]
+    }
+
+    // ── Min / max (lane-wise) ───────────────────────────────────────
+
+    /// Lane-wise unsigned min.
+    #[inline(always)]
+    pub fn simd_min(self, other: Self) -> Self {
+        // SAFETY: AVX2 baseline.
+        Self(unsafe { _mm256_min_epu8(self.0, other.0) })
+    }
+
+    /// Lane-wise unsigned max.
+    #[inline(always)]
+    pub fn simd_max(self, other: Self) -> Self {
+        // SAFETY: AVX2 baseline.
+        Self(unsafe { _mm256_max_epu8(self.0, other.0) })
+    }
+
+    // ── Comparison → bitmask ────────────────────────────────────────
+
+    /// Per-lane equality. Returns a 32-bit mask: bit `i` set iff
+    /// `self[i] == other[i]`. (Matches the shape of `U8x64::cmpeq_mask`
+    /// at the natural AVX2 width.)
+    #[inline(always)]
+    pub fn cmpeq_mask(self, other: Self) -> u32 {
+        // SAFETY: AVX2 baseline.
+        let eq = unsafe { _mm256_cmpeq_epi8(self.0, other.0) };
+        // movemask_epi8 extracts the MSB of each byte. After cmpeq, each
+        // lane is 0xFF (match) or 0x00 (mismatch); MSB matches what we want.
+        unsafe { _mm256_movemask_epi8(eq) as u32 }
+    }
+
+    /// Per-lane unsigned greater-than. Returns a 32-bit mask.
+    /// AVX2 only has signed `_mm256_cmpgt_epi8`, so we XOR both
+    /// operands with `0x80` to convert unsigned ↔ signed (preserves
+    /// ordering for unsigned compare).
+    #[inline(always)]
+    pub fn cmpgt_mask(self, other: Self) -> u32 {
+        // SAFETY: AVX2 baseline.
+        unsafe {
+            let bias = _mm256_set1_epi8(i8::MIN); // 0x80
+            let a_s = _mm256_xor_si256(self.0, bias);
+            let b_s = _mm256_xor_si256(other.0, bias);
+            let gt = _mm256_cmpgt_epi8(a_s, b_s);
+            _mm256_movemask_epi8(gt) as u32
+        }
+    }
+
+    /// Extract MSB of each lane as a 32-bit mask (matches
+    /// `U8x64::movemask` at AVX2 width).
+    #[inline(always)]
+    pub fn movemask(self) -> u32 {
+        // SAFETY: AVX2 baseline.
+        unsafe { _mm256_movemask_epi8(self.0) as u32 }
+    }
+
+    // ── Saturating arithmetic ───────────────────────────────────────
+
+    /// Per-lane saturating unsigned add: `min(a + b, 255)`.
+    #[inline(always)]
+    pub fn saturating_add(self, other: Self) -> Self {
+        // SAFETY: AVX2 baseline.
+        Self(unsafe { _mm256_adds_epu8(self.0, other.0) })
+    }
+
+    /// Per-lane saturating unsigned sub: `max(a - b, 0)`.
+    #[inline(always)]
+    pub fn saturating_sub(self, other: Self) -> Self {
+        // SAFETY: AVX2 baseline.
+        Self(unsafe { _mm256_subs_epu8(self.0, other.0) })
+    }
+
+    /// Per-lane unsigned rounded average: `(a + b + 1) >> 1`.
+    #[inline(always)]
+    pub fn pairwise_avg(self, other: Self) -> Self {
+        // SAFETY: AVX2 baseline.
+        Self(unsafe { _mm256_avg_epu8(self.0, other.0) })
+    }
+
+    // ── 16-bit-lane shifts (used by nibble pack/unpack) ─────────────
+
+    /// Right shift each 16-bit lane by `imm` bits. (AVX2 has no native
+    /// 8-bit shift; 16-bit shift + mask is the standard idiom.)
+    #[inline(always)]
+    pub fn shr_epi16(self, imm: u32) -> Self {
+        // SAFETY: AVX2 baseline. `imm` is an arbitrary count; we use the
+        // vector-count form to avoid the const-generic constraint.
+        Self(unsafe { _mm256_srl_epi16(self.0, _mm_cvtsi32_si128(imm as i32)) })
+    }
+
+    /// Left shift each 16-bit lane by `imm` bits.
+    #[inline(always)]
+    pub fn shl_epi16(self, imm: u32) -> Self {
+        // SAFETY: AVX2 baseline. Vector-count form (see shr_epi16).
+        Self(unsafe { _mm256_sll_epi16(self.0, _mm_cvtsi32_si128(imm as i32)) })
+    }
+
+    // ── Lane shuffles ───────────────────────────────────────────────
+
+    /// Within-128-bit-lane byte shuffle. The low 4 bits of `idx[i]`
+    /// select the source byte within the SAME 128-bit half; high-bit
+    /// set in `idx[i]` zeroes the output lane. Matches `_mm256_shuffle_epi8`.
+    /// (Cross-lane permute is NOT available in pure AVX2 — use
+    /// `permute_bytes` for that, which falls back to scalar.)
+    #[inline(always)]
+    pub fn shuffle_bytes(self, idx: Self) -> Self {
+        // SAFETY: AVX2 baseline.
+        Self(unsafe { _mm256_shuffle_epi8(self.0, idx.0) })
+    }
+
+    /// Cross-lane byte permute (full 32-byte). AVX2 has no native
+    /// cross-lane byte permute, so this falls back to scalar — the
+    /// same fallback `simd_avx512::U8x64::permute_bytes` takes on
+    /// AVX-512F hosts without VBMI.
+    ///
+    /// `idx[i]` selects the source byte at position `idx[i] & 0x1F`.
+    #[inline(always)]
+    pub fn permute_bytes(self, idx: Self) -> Self {
+        let src = self.to_array();
+        let idx_arr = idx.to_array();
+        let mut out = [0u8; 32];
+        for i in 0..32 {
+            out[i] = src[(idx_arr[i] & 0x1F) as usize];
+        }
+        Self::from_array(out)
+    }
+
+    /// Interleave low 8 bytes of each 128-bit half (`_mm256_unpacklo_epi8`).
+    /// Output: `[a0,b0, a1,b1, ..., a7,b7]` within each 128-bit half.
+    #[inline(always)]
+    pub fn unpack_lo_epi8(self, other: Self) -> Self {
+        // SAFETY: AVX2 baseline.
+        Self(unsafe { _mm256_unpacklo_epi8(self.0, other.0) })
+    }
+
+    /// Interleave high 8 bytes of each 128-bit half (`_mm256_unpackhi_epi8`).
+    #[inline(always)]
+    pub fn unpack_hi_epi8(self, other: Self) -> Self {
+        // SAFETY: AVX2 baseline.
+ Self(unsafe { _mm256_unpackhi_epi8(self.0, other.0) }) + } + + // ── Conditional move via bit mask ─────────────────────────────── + + /// Select `a` where mask bit is set, else `b`. The mask is a + /// `U8x32` whose lane MSB acts as the boolean (matches + /// `_mm256_blendv_epi8` semantics — different from the + /// 64-bit-bitmask shape of `U8x64::mask_blend`). + #[inline(always)] + pub fn mask_blend(mask: Self, a: Self, b: Self) -> Self { + // SAFETY: AVX2 baseline. + Self(unsafe { _mm256_blendv_epi8(b.0, a.0, mask.0) }) + } + + // ── Pre-computed LUT used by nibble popcount ─────────────────── + + /// Returns a `U8x32` populated with the nibble-popcount lookup + /// table replicated across both 128-bit halves. `shuffle_bytes` + /// with this LUT computes `popcount(nibble)` for 32 bytes in + /// parallel. + #[inline(always)] + pub fn nibble_popcount_lut() -> Self { + // Index i ∈ [0,15] → number of set bits in i. + Self::from_array([ + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + ]) + } +} + +// Bitwise + arithmetic operator impls so consumers can use natural +// `a + b`, `a & b`, etc. without method chaining. Match the U8x64 shape. + +#[cfg(target_arch = "x86_64")] +impl core::ops::BitAnd for U8x32 { + type Output = Self; + #[inline(always)] + fn bitand(self, rhs: Self) -> Self { + // SAFETY: AVX2 baseline. + Self(unsafe { _mm256_and_si256(self.0, rhs.0) }) + } +} + +#[cfg(target_arch = "x86_64")] +impl core::ops::BitOr for U8x32 { + type Output = Self; + #[inline(always)] + fn bitor(self, rhs: Self) -> Self { + // SAFETY: AVX2 baseline. + Self(unsafe { _mm256_or_si256(self.0, rhs.0) }) + } +} + +#[cfg(target_arch = "x86_64")] +impl core::ops::BitXor for U8x32 { + type Output = Self; + #[inline(always)] + fn bitxor(self, rhs: Self) -> Self { + // SAFETY: AVX2 baseline. + Self(unsafe { _mm256_xor_si256(self.0, rhs.0) }) + } +} + +#[cfg(target_arch = "x86_64")] +impl core::ops::Add for U8x32 { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + // SAFETY: AVX2 baseline. WRAPS — use saturating_add for clamp. + Self(unsafe { _mm256_add_epi8(self.0, rhs.0) }) + } +} + +#[cfg(target_arch = "x86_64")] +impl core::ops::Sub for U8x32 { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + // SAFETY: AVX2 baseline. WRAPS — use saturating_sub for clamp. 
+        Self(unsafe { _mm256_sub_epi8(self.0, rhs.0) })
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+impl core::fmt::Debug for U8x32 {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        write!(f, "U8x32({:?})", self.to_array())
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+impl Default for U8x32 {
+    #[inline(always)]
+    fn default() -> Self {
+        Self::splat(0)
+    }
+}
+
+// ═══════════════════════════════════════════════════════════════════
+// U8x32 tests
+// ═══════════════════════════════════════════════════════════════════
+
+#[cfg(all(test, target_arch = "x86_64"))]
+mod u8x32_tests {
+    use super::U8x32;
+
+    #[test]
+    fn splat_and_to_array() {
+        let v = U8x32::splat(42);
+        assert_eq!(v.to_array(), [42u8; 32]);
+    }
+
+    #[test]
+    fn from_array_roundtrip() {
+        let arr: [u8; 32] = core::array::from_fn(|i| i as u8);
+        let v = U8x32::from_array(arr);
+        assert_eq!(v.to_array(), arr);
+    }
+
+    #[test]
+    fn from_slice_to_slice() {
+        let src: Vec<u8> = (0..40).map(|i| i as u8).collect();
+        let v = U8x32::from_slice(&src);
+        let mut dst = vec![0u8; 32];
+        v.copy_to_slice(&mut dst);
+        assert_eq!(&dst[..], &src[..32]);
+    }
+
+    #[test]
+    fn reduce_sum_wraps() {
+        // 32 × 8 = 256 → wraps to 0 in u8.
+        let v = U8x32::splat(8);
+        assert_eq!(v.reduce_sum(), 0);
+    }
+
+    #[test]
+    fn sum_bytes_u64_does_not_wrap() {
+        // 32 × 100 = 3200 — beyond u8, but sum_bytes_u64 returns full u64.
+        let v = U8x32::splat(100);
+        assert_eq!(v.sum_bytes_u64(), 3200);
+    }
+
+    #[test]
+    fn reduce_min_max() {
+        let arr: [u8; 32] = core::array::from_fn(|i| (i * 7 + 3) as u8);
+        let v = U8x32::from_array(arr);
+        assert_eq!(v.reduce_min(), *arr.iter().min().unwrap());
+        assert_eq!(v.reduce_max(), *arr.iter().max().unwrap());
+    }
+
+    #[test]
+    fn simd_min_max_unsigned() {
+        let a = U8x32::from_array(core::array::from_fn(|i| i as u8));
+        let b = U8x32::from_array(core::array::from_fn(|i| (31 - i) as u8));
+        let lo = a.simd_min(b).to_array();
+        let hi = a.simd_max(b).to_array();
+        for i in 0..32 {
+            assert_eq!(lo[i], (i as u8).min((31 - i) as u8));
+            assert_eq!(hi[i], (i as u8).max((31 - i) as u8));
+        }
+    }
+
+    #[test]
+    fn cmpeq_mask_matches_scalar() {
+        let a: [u8; 32] = core::array::from_fn(|i| (i & 7) as u8);
+        let b: [u8; 32] = core::array::from_fn(|i| (i & 6) as u8);
+        let va = U8x32::from_array(a);
+        let vb = U8x32::from_array(b);
+        let m = va.cmpeq_mask(vb);
+        for i in 0..32 {
+            let bit = (m >> i) & 1 == 1;
+            assert_eq!(bit, a[i] == b[i], "lane {} disagrees", i);
+        }
+    }
+
+    #[test]
+    fn cmpgt_mask_matches_scalar_unsigned() {
+        // High bytes (>= 128) must compare as unsigned, NOT signed.
+        let a: [u8; 32] = core::array::from_fn(|i| (i * 9) as u8);
+        let b: [u8; 32] = core::array::from_fn(|i| 200u8.wrapping_sub(i as u8));
+        let va = U8x32::from_array(a);
+        let vb = U8x32::from_array(b);
+        let m = va.cmpgt_mask(vb);
+        for i in 0..32 {
+            let bit = (m >> i) & 1 == 1;
+            assert_eq!(bit, a[i] > b[i], "lane {} disagrees (a={} b={})", i, a[i], b[i]);
+        }
+    }
+
+    #[test]
+    fn saturating_add_clamps() {
+        let a = U8x32::splat(200);
+        let b = U8x32::splat(100);
+        assert_eq!(a.saturating_add(b).to_array(), [255u8; 32]);
+    }
+
+    #[test]
+    fn saturating_sub_clamps() {
+        let a = U8x32::splat(10);
+        let b = U8x32::splat(50);
+        assert_eq!(a.saturating_sub(b).to_array(), [0u8; 32]);
+    }
+
+    #[test]
+    fn pairwise_avg_rounds_up() {
+        let a = U8x32::splat(7);
+        let b = U8x32::splat(8);
+        // (7 + 8 + 1) >> 1 = 8
+        assert_eq!(a.pairwise_avg(b).to_array(), [8u8; 32]);
+    }
+
+    #[test]
+    fn shr_epi16_extracts_nibble() {
+        // Pack 0xAB in low byte of each 16-bit pair, shift right 4 → 0x0A.
+        let mut arr = [0u8; 32];
+        for i in (0..32).step_by(2) {
+            arr[i] = 0xAB;
+        }
+        let shifted = U8x32::from_array(arr).shr_epi16(4).to_array();
+        for i in (0..32).step_by(2) {
+            assert_eq!(shifted[i], 0x0A);
+            assert_eq!(shifted[i + 1], 0x00);
+        }
+    }
+
+    #[test]
+    fn permute_bytes_reverse() {
+        let src: [u8; 32] = core::array::from_fn(|i| i as u8);
+        let idx: [u8; 32] = core::array::from_fn(|i| (31 - i) as u8);
+        let out = U8x32::from_array(src)
+            .permute_bytes(U8x32::from_array(idx))
+            .to_array();
+        for i in 0..32 {
+            assert_eq!(out[i], src[31 - i]);
+        }
+    }
+
+    #[test]
+    fn mask_blend_selects_per_msb() {
+        let a = U8x32::splat(0xAA);
+        let b = U8x32::splat(0x55);
+        // Mask with MSB set on every other lane → selects a/b alternating.
+        let mask_arr: [u8; 32] = core::array::from_fn(|i| if i % 2 == 0 { 0x80 } else { 0x00 });
+        let mask = U8x32::from_array(mask_arr);
+        let out = U8x32::mask_blend(mask, a, b).to_array();
+        for i in 0..32 {
+            assert_eq!(out[i], if i % 2 == 0 { 0xAA } else { 0x55 });
+        }
+    }
+
+    #[test]
+    fn nibble_popcount_lut_via_shuffle() {
+        let lut = U8x32::nibble_popcount_lut();
+        // For each possible nibble value 0..15, shuffle should produce
+        // its popcount.
+        let idx: [u8; 32] = core::array::from_fn(|i| (i & 0x0F) as u8);
+        let counts = lut.shuffle_bytes(U8x32::from_array(idx)).to_array();
+        for i in 0..32 {
+            let n = (i & 0x0F) as u32;
+            assert_eq!(counts[i] as u32, n.count_ones(), "lane {}", i);
+        }
+    }
+}
+
 /// Lowercase aliases (std::simd convention).
 #[allow(non_camel_case_types)]
 pub type f32x16 = F32x16;
@@ -1735,6 +2250,9 @@ pub type f32x16 = F32x16;
 pub type f64x8 = F64x8;
 #[allow(non_camel_case_types)]
 pub type u8x64 = U8x64;
+#[cfg(target_arch = "x86_64")]
+#[allow(non_camel_case_types)]
+pub type u8x32 = U8x32;
 #[allow(non_camel_case_types)]
 pub type i32x16 = I32x16;
 #[allow(non_camel_case_types)]
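
Reviewer note (not part of the patch): a minimal sketch of the intended consumption pattern, chunking the input in 32-byte windows as the header comment recommends, wired through the new API (nibble_popcount_lut + shuffle_bytes + shr_epi16 for per-byte popcount, then sum_bytes_u64 for the horizontal fold). `popcount_bytes` is a hypothetical caller name, not anything this diff exports.

    // Hypothetical consumer; uses only methods introduced in this diff.
    fn popcount_bytes(data: &[u8]) -> u64 {
        let lut = U8x32::nibble_popcount_lut();
        let low_mask = U8x32::splat(0x0F);
        let mut total = 0u64;
        let mut chunks = data.chunks_exact(32);
        for chunk in &mut chunks {
            let v = U8x32::from_slice(chunk);
            // popcount(byte) = popcount(low nibble) + popcount(high nibble).
            let lo = lut.shuffle_bytes(v & low_mask);
            let hi = lut.shuffle_bytes(v.shr_epi16(4) & low_mask);
            // Per-lane counts are at most 8, so the wrapping `+` cannot
            // overflow; sum_bytes_u64 then folds 32 lanes without wrapping.
            total += (lo + hi).sum_bytes_u64();
        }
        // Scalar tail for the final < 32 bytes.
        total + chunks.remainder().iter().map(|b| b.count_ones() as u64).sum::<u64>()
    }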
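
A second sketch exercising the unsigned-compare path (the 0x80 bias trick inside cmpgt_mask): scan for the first byte strictly greater than a threshold via the returned bitmask. `first_above` is likewise a hypothetical name.

    // Hypothetical consumer: index of the first byte > threshold, if any.
    fn first_above(data: &[u8], threshold: u8) -> Option<usize> {
        let t = U8x32::splat(threshold);
        for (i, chunk) in data.chunks_exact(32).enumerate() {
            let m = U8x32::from_slice(chunk).cmpgt_mask(t);
            if m != 0 {
                // Mask bit i corresponds to lane i, so trailing_zeros
                // names the first matching lane.
                return Some(i * 32 + m.trailing_zeros() as usize);
            }
        }
        // Scalar tail for the final < 32 bytes.
        let tail_start = data.len() - data.len() % 32;
        data[tail_start..]
            .iter()
            .position(|&b| b > threshold)
            .map(|p| tail_start + p)
    }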