From 75b45e932112efa2d56eb9377c94c1f43e766377 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 3 Apr 2026 16:41:41 +0000 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20BF16=20SIMD=20polyfill=20=E2=80=94?= =?UTF-8?q?=20as=5Fchunks::<16>()=20+=20runtime=20avx512bf16=20detection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BF16↔f32 batch conversion via stable Rust 1.94: 1. Runtime detect avx512bf16 + avx512vl 2. as_chunks::<16>() → _mm512_cvtpbh_ps (16 BF16 → 16 f32) 3. as_chunks::<8>() remainder → _mm256_cvtpbh_ps (8 BF16 → 8 f32) 4. Scalar tail → f32::from_bits((bits as u32) << 16) No LazyLock — slice chunking handles batch widths. No nightly — as_chunks is stable since 1.94. Reference: https://doc.rust-lang.org/beta/src/core/stdarch/crates/core_arch/src/x86/avx512bf16.rs.html Types: BF16x16 (__m256bh), BF16x8 (__m128bh) — available when target_feature avx512bf16 is enabled at compile time. Functions (always available, scalar fallback built in): bf16_to_f32_batch(input: &[u16], output: &mut [f32]) f32_to_bf16_batch(input: &[f32], output: &mut [u16]) bf16_to_f32_scalar(bits: u16) → f32 f32_to_bf16_scalar(v: f32) → u16 3 tests passing. https://claude.ai/code/session_01ChLvBfpJS8dQhHxRD4pYNp --- src/simd.rs | 29 +++++ src/simd_avx512.rs | 260 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 289 insertions(+) diff --git a/src/simd.rs b/src/simd.rs index c16621b7..89ee2130 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -22,6 +22,11 @@ static TIER: LazyLock = LazyLock::new(|| { #[inline(always)] fn tier() -> Tier { *TIER } +// BF16 tier detection happens inline in bf16_to_f32_batch() via +// is_x86_feature_detected!("avx512bf16") — no LazyLock needed. +// The check is cheap (reads a cached cpuid result) and the batch +// function uses as_chunks::<16>() + as_chunks::<8>() for SIMD widths. 
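+//
+// Illustrative round trip (hypothetical consumer code, not part of this patch;
+// the values are chosen to be exactly representable in BF16):
+//
+//     let weights = [1.0f32, 0.5, -2.0, 0.125];
+//     let mut packed = [0u16; 4];
+//     f32_to_bf16_batch(&weights, &mut packed);
+//     let mut unpacked = [0.0f32; 4];
+//     bf16_to_f32_batch(&packed, &mut unpacked);
+//     assert_eq!(unpacked, weights);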
+ // ============================================================================ // x86_64: re-export based on tier // ============================================================================ @@ -41,6 +46,16 @@ pub use crate::simd_avx512::{ f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8, }; +// BF16 types + batch conversion (always available — scalar fallback built in) +#[cfg(target_arch = "x86_64")] +pub use crate::simd_avx512::{ + bf16_to_f32_scalar, f32_to_bf16_scalar, + bf16_to_f32_batch, f32_to_bf16_batch, +}; +// BF16 SIMD types only available when avx512bf16 is enabled at compile time +#[cfg(all(target_arch = "x86_64", target_feature = "avx512bf16"))] +pub use crate::simd_avx512::{BF16x16, BF16x8}; + #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))] pub use crate::simd_avx512::{F32x8, F64x4, f32x8, f64x4}; @@ -697,6 +712,20 @@ pub use scalar::{ f32x8, f64x4, }; +// Scalar BF16 conversion — always available on all platforms +#[cfg(not(target_arch = "x86_64"))] +pub fn bf16_to_f32_scalar(bits: u16) -> f32 { f32::from_bits((bits as u32) << 16) } +#[cfg(not(target_arch = "x86_64"))] +pub fn f32_to_bf16_scalar(v: f32) -> u16 { (v.to_bits() >> 16) as u16 } +#[cfg(not(target_arch = "x86_64"))] +pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) { + for (i, &b) in input.iter().enumerate() { if i < output.len() { output[i] = bf16_to_f32_scalar(b); } } +} +#[cfg(not(target_arch = "x86_64"))] +pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) { + for (i, &v) in input.iter().enumerate() { if i < output.len() { output[i] = f32_to_bf16_scalar(v); } } +} + // ============================================================================ // SIMD math functions — ndarray additions (not in std::simd) // ============================================================================ diff --git a/src/simd_avx512.rs b/src/simd_avx512.rs index e3cb8f44..0ba48f4c 100644 --- a/src/simd_avx512.rs +++ b/src/simd_avx512.rs @@ -1457,3 +1457,263 @@ pub type f32x8 = F32x8; #[allow(non_camel_case_types)] pub type f64x4 = F64x4; +// ============================================================================ +// BF16 conversion wrappers — AVX-512 BF16 hardware instructions +// ============================================================================ +// +// Reference: https://doc.rust-lang.org/beta/src/core/stdarch/crates/core_arch/src/x86/avx512bf16.rs.html +// +// Hardware instructions (requires avx512bf16 + avx512vl): +// _mm512_cvtpbh_ps: 16 BF16 → 16 f32 (__m256bh → __m512) +// _mm256_cvtpbh_ps: 8 BF16 → 8 f32 (__m128bh → __m256) +// _mm_cvtpbh_ps: 4 BF16 → 4 f32 (__m128bh → __m128) +// _mm_cvtsbh_ss: 1 BF16 → 1 f32 (scalar) +// +// _mm512_cvtneps_pbh: 16 f32 → 16 BF16 (__m512 → __m256bh) +// _mm256_cvtneps_pbh: 8 f32 → 8 BF16 (__m256 → __m128bh) +// _mm_cvtness_sbh: 1 f32 → 1 BF16 (scalar) +// +// These are NOT available on all AVX-512 CPUs — requires the BF16 extension. +// The scalar fallback (shift left 16) works everywhere. + +/// BF16x16: 16 BF16 values packed in __m256bh. Converts to/from F32x16. +/// +/// Primary use: bulk BF16→f32 hydration from GGUF source files. +/// One `vcvtneebf162ps` instruction converts 16 BF16 → 16 f32. +#[cfg(target_arch = "x86_64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct BF16x16(pub __m256bh); + +#[cfg(target_arch = "x86_64")] +impl BF16x16 { + pub const LANES: usize = 16; + + /// Load 16 BF16 values from a u16 slice. + /// + /// SAFETY: Requires avx512bf16 at call site. 
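+    /// This is typically established by a prior `is_x86_feature_detected!("avx512bf16")` check.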
+ /// Caller must ensure slice has >= 16 elements. + #[inline] + #[target_feature(enable = "avx512bf16")] + pub unsafe fn from_u16_slice(s: &[u16]) -> Self { + assert!(s.len() >= 16); + // __m256bh is 256 bits = 16 × u16. Load as __m256i then transmute. + let raw = _mm256_loadu_si256(s.as_ptr() as *const __m256i); + Self(core::mem::transmute(raw)) + } + + /// Convert 16 BF16 → 16 f32 via hardware instruction. + /// + /// SAFETY: Requires avx512bf16 + avx512f at call site. + /// Uses `vcvtneebf162ps` — one instruction, one cycle. + #[inline] + #[target_feature(enable = "avx512bf16,avx512f")] + pub unsafe fn to_f32x16(self) -> F32x16 { + F32x16(_mm512_cvtpbh_ps(self.0)) + } +} + +/// BF16x8: 8 BF16 values packed in __m128bh. Converts to/from F32x8. +#[cfg(target_arch = "x86_64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct BF16x8(pub __m128bh); + +#[cfg(target_arch = "x86_64")] +impl BF16x8 { + pub const LANES: usize = 8; + + /// Load 8 BF16 values from a u16 slice. + #[inline] + #[target_feature(enable = "avx512bf16")] + pub unsafe fn from_u16_slice(s: &[u16]) -> Self { + assert!(s.len() >= 8); + let raw = _mm_loadu_si128(s.as_ptr() as *const __m128i); + Self(core::mem::transmute(raw)) + } + + /// Convert 8 BF16 → 8 f32 via hardware instruction. + #[inline] + #[target_feature(enable = "avx512bf16,avx512vl")] + pub unsafe fn to_f32x8(self) -> F32x8 { + F32x8(_mm256_cvtpbh_ps(self.0)) + } +} + +/// F32x16 → BF16x16 conversion (16 f32 → 16 BF16). +#[cfg(target_arch = "x86_64")] +impl F32x16 { + /// Convert 16 f32 → 16 BF16 via hardware instruction. + #[inline] + #[target_feature(enable = "avx512bf16,avx512f")] + pub unsafe fn to_bf16x16(self) -> BF16x16 { + BF16x16(_mm512_cvtneps_pbh(self.0)) + } +} + +/// F32x8 → BF16x8 conversion (8 f32 → 8 BF16). +#[cfg(target_arch = "x86_64")] +impl F32x8 { + /// Convert 8 f32 → 8 BF16 via hardware instruction. + #[inline] + #[target_feature(enable = "avx512bf16,avx512vl")] + pub unsafe fn to_bf16x8(self) -> BF16x8 { + BF16x8(_mm256_cvtneps_pbh(self.0)) + } +} + +// ── Scalar BF16 conversion (always available, no target_feature needed) ── + +/// Scalar BF16 → f32: bit shift, one instruction, lossless. +/// Works on ALL platforms — this is the fallback when avx512bf16 is not available. +#[inline] +pub fn bf16_to_f32_scalar(bits: u16) -> f32 { + f32::from_bits((bits as u32) << 16) +} + +/// Scalar f32 → BF16: truncate mantissa (lossy, 1 ULP). +#[inline] +pub fn f32_to_bf16_scalar(v: f32) -> u16 { + (v.to_bits() >> 16) as u16 +} + +/// Batch BF16 → f32 conversion: runtime feature detection + `as_chunks::()`. +/// +/// Uses stable Rust 1.94 `slice::as_chunks` for SIMD batch widths: +/// 1. Runtime detect avx512bf16 + avx512vl +/// 2. Process 16-wide chunks via `_mm512_cvtpbh_ps` +/// 3. Process 8-wide remainder via `_mm256_cvtpbh_ps` +/// 4. Finish scalar tail via bit shift +/// +/// No LazyLock, no nightly. Just `as_chunks::<16>()` + `as_chunks::<8>()`. 
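+///
+/// Illustrative example (hex literals are BF16 bit patterns; `output` may be
+/// longer than `input`, but never shorter):
+///
+/// ```ignore
+/// let bits = [0x3F80u16, 0x4000, 0x4040]; // 1.0, 2.0, 3.0 as BF16
+/// let mut out = [0.0f32; 4];              // longer than the input is fine
+/// bf16_to_f32_batch(&bits, &mut out);
+/// assert_eq!(&out[..3], &[1.0, 2.0, 3.0]);
+/// ```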
+pub fn bf16_to_f32_batch(input: &[u16], output: &mut [f32]) { + assert!(output.len() >= input.len(), "output must be >= input length"); + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("avx512bf16") + && is_x86_feature_detected!("avx512vl") + { + // SAFETY: feature detection confirmed avx512bf16 + avx512vl + unsafe { convert_bf16_to_f32_avx512bf16(input, output); } + return; + } + } + + // Scalar fallback (all platforms, all CPUs) + for (src, dst) in input.iter().copied().zip(output.iter_mut()) { + *dst = bf16_to_f32_scalar(src); + } +} + +/// Batch f32 → BF16 conversion: same pattern. +pub fn f32_to_bf16_batch(input: &[f32], output: &mut [u16]) { + assert!(output.len() >= input.len(), "output must be >= input length"); + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("avx512bf16") + && is_x86_feature_detected!("avx512vl") + { + unsafe { convert_f32_to_bf16_avx512bf16(input, output); } + return; + } + } + + for (src, dst) in input.iter().copied().zip(output.iter_mut()) { + *dst = f32_to_bf16_scalar(src); + } +} + +/// AVX-512 BF16 path: as_chunks::<16>() → as_chunks::<8>() → scalar tail. +/// +/// Reference: https://doc.rust-lang.org/beta/src/core/stdarch/crates/core_arch/src/x86/avx512bf16.rs.html +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512bf16,avx512vl")] +unsafe fn convert_bf16_to_f32_avx512bf16(input: &[u16], output: &mut [f32]) { + // 16-wide chunks + let (chunks16, rem16) = input.as_chunks::<16>(); + let (out16, out_rem16) = output[..input.len()].as_chunks_mut::<16>(); + + for (src, dst) in chunks16.iter().zip(out16.iter_mut()) { + // SAFETY: [u16; 16] = 256 bits = __m256bh + let v_bf16: __m256bh = core::mem::transmute(*src); + let v_f32: __m512 = _mm512_cvtpbh_ps(v_bf16); + *dst = core::mem::transmute(v_f32); + } + + // 8-wide remainder chunks + let (chunks8, rem8) = rem16.as_chunks::<8>(); + let (out8, out_rem8) = out_rem16.as_chunks_mut::<8>(); + + for (src, dst) in chunks8.iter().zip(out8.iter_mut()) { + let v_bf16: __m128bh = core::mem::transmute(*src); + let v_f32: __m256 = _mm256_cvtpbh_ps(v_bf16); + *dst = core::mem::transmute(v_f32); + } + + // Scalar tail (0-7 remaining values) + for (src, dst) in rem8.iter().copied().zip(out_rem8.iter_mut()) { + *dst = f32::from_bits((src as u32) << 16); + } +} + +/// AVX-512 BF16 path for f32 → BF16. 
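+///
+/// Note: `_mm512_cvtneps_pbh` rounds to nearest-even while the scalar tail
+/// below truncates, so SIMD lanes and tail lanes of the same call can differ
+/// by 1 ULP for the same value (see the tolerance in `bf16_tests::batch_f32_to_bf16`).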
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512bf16,avx512vl")] +unsafe fn convert_f32_to_bf16_avx512bf16(input: &[f32], output: &mut [u16]) { + let (chunks16, rem16) = input.as_chunks::<16>(); + let (out16, out_rem16) = output[..input.len()].as_chunks_mut::<16>(); + + for (src, dst) in chunks16.iter().zip(out16.iter_mut()) { + let v_f32: __m512 = core::mem::transmute(*src); + let v_bf16: __m256bh = _mm512_cvtneps_pbh(v_f32); + *dst = core::mem::transmute(v_bf16); + } + + // Scalar remainder (f32→BF16 has no 8-wide instruction worth using) + for (src, dst) in rem16.iter().copied().zip(out_rem16.iter_mut()) { + *dst = (src.to_bits() >> 16) as u16; + } +} + +#[cfg(test)] +mod bf16_tests { + use super::*; + + #[test] + fn scalar_roundtrip() { + for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.5, 100.0, 0.001, -0.001] { + let bf16 = f32_to_bf16_scalar(v); + let back = bf16_to_f32_scalar(bf16); + let err = (v - back).abs() / v.abs().max(1e-6); + assert!(err < 0.02, "roundtrip error for {}: {} → {} → {}, err={:.4}", v, v, bf16, back, err); + } + } + + #[test] + fn batch_conversion_matches_scalar() { + let input: Vec = (0..100).map(|i| f32_to_bf16_scalar(i as f32 * 0.1 - 5.0)).collect(); + let mut batch_output = vec![0.0f32; 100]; + bf16_to_f32_batch(&input, &mut batch_output); + + for (i, &bf16) in input.iter().enumerate() { + let scalar = bf16_to_f32_scalar(bf16); + assert_eq!(batch_output[i], scalar, "mismatch at index {}", i); + } + } + + #[test] + fn batch_f32_to_bf16() { + let input: Vec = (0..50).map(|i| i as f32 * 0.3 - 7.5).collect(); + let mut output = vec![0u16; 50]; + f32_to_bf16_batch(&input, &mut output); + + for (i, &v) in input.iter().enumerate() { + let expected = f32_to_bf16_scalar(v); + // Allow ±1 ULP: hardware uses round-to-nearest-even, scalar uses truncation + let diff = (output[i] as i32 - expected as i32).unsigned_abs(); + assert!(diff <= 1, "mismatch at index {}: {} → {} vs {}, diff={}", i, v, output[i], expected, diff); + } + } +} From 1b06969670851bedbb969a62aa9951fb25c078a2 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 3 Apr 2026 17:15:10 +0000 Subject: [PATCH 2/3] feat: PREFERRED_F64/F32/U64/I16_LANES compile-time constants for array_windows dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Compile-time const (not LazyLock) — resolved by #[cfg(target_feature)]: AVX-512: F64=8, F32=16, U64=8, I16=32 AVX2: F64=4, F32=8, U64=4, I16=16 Scalar: same as AVX2 Enables consumers to use array_windows::<{PREFERRED_F64_LANES}>() for native-width SIMD processing without runtime branching. https://claude.ai/code/session_01ChLvBfpJS8dQhHxRD4pYNp --- src/simd.rs | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/simd.rs b/src/simd.rs index 89ee2130..10a81335 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -27,6 +27,59 @@ fn tier() -> Tier { *TIER } // The check is cheap (reads a cached cpuid result) and the batch // function uses as_chunks::<16>() + as_chunks::<8>() for SIMD widths. 
+// ============================================================================ +// Preferred SIMD lane widths — compile-time constants for array_windows +// ============================================================================ +// +// Consumer code uses these to select array_windows size at compile time: +// +// for window in data.array_windows::<{crate::simd::PREFERRED_F64_LANES}>() { +// let v = F64x8::from_array(*window); // AVX-512: native 8-wide +// // or +// let v = F64x4::from_array(*window); // AVX2: native 4-wide +// } +// +// generic_const_exprs is nightly, so consumers must #[cfg] branch on window size. +// These constants document the preferred width per tier. + +/// Preferred f64 SIMD width (elements per register). +/// AVX-512: 8 lanes (__m512d). AVX2/scalar: 4 lanes (__m256d). +#[cfg(target_feature = "avx512f")] +pub const PREFERRED_F64_LANES: usize = 8; +#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))] +pub const PREFERRED_F64_LANES: usize = 4; +#[cfg(not(target_arch = "x86_64"))] +pub const PREFERRED_F64_LANES: usize = 4; // scalar fallback: same as AVX2 shape + +/// Preferred f32 SIMD width. +/// AVX-512: 16 lanes (__m512). AVX2/scalar: 8 lanes (__m256). +#[cfg(target_feature = "avx512f")] +pub const PREFERRED_F32_LANES: usize = 16; +#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))] +pub const PREFERRED_F32_LANES: usize = 8; +#[cfg(not(target_arch = "x86_64"))] +pub const PREFERRED_F32_LANES: usize = 8; + +/// Preferred u64 SIMD width. +/// AVX-512: 8 lanes. AVX2/scalar: 4 lanes. +#[cfg(target_feature = "avx512f")] +pub const PREFERRED_U64_LANES: usize = 8; +#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))] +pub const PREFERRED_U64_LANES: usize = 4; +#[cfg(not(target_arch = "x86_64"))] +pub const PREFERRED_U64_LANES: usize = 4; + +/// Preferred i16 SIMD width (for Base17 L1 on i16[17]). +/// AVX-512: 32 lanes (__m512i via epi16). AVX2: 16 lanes (__m256i). +/// Base17 has 17 dims — AVX-512 covers 32 (load 17 + 15 padding), +/// AVX2 covers 16 + 1 scalar. +#[cfg(target_feature = "avx512f")] +pub const PREFERRED_I16_LANES: usize = 32; +#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))] +pub const PREFERRED_I16_LANES: usize = 16; +#[cfg(not(target_arch = "x86_64"))] +pub const PREFERRED_I16_LANES: usize = 16; + // ============================================================================ // x86_64: re-export based on tier // ============================================================================ From bad2a55160c954001b400f8e3bf5fa211866e2c4 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 3 Apr 2026 17:20:00 +0000 Subject: [PATCH 3/3] feat: U8x64 byte-level ops for palette codec, nibble, byte scan (Pumpkin/SD) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added to all three tiers (AVX-512 / AVX2 / scalar): cmpeq_mask(other) → u64 — byte-wise equality, returns bitmask shr_epi16(imm) → Self — shift right 16-bit lanes (nibble extract) saturating_sub(other) — max(a-b, 0) per byte (delta subtraction) unpack_lo_epi8(other) — interleave low bytes (nibble interleave) unpack_hi_epi8(other) — interleave high bytes These operations are used by: palette_codec.rs — Minecraft-style variable-width bit packing nibble.rs — 4-bit light level packing (Pumpkin) byte_scan.rs — NBT format byte scanning (future) stable_diffusion/ — VAE latent palette encoding via GGUF All three are currently using raw _mm256_/_mm512_ intrinsics. 
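For illustration, a rewired byte scan might look roughly like the sketch below
(the `scan_byte` helper and the `splat` constructor are assumptions for the
sketch, not part of this patch):

    /// Bit i of the returned mask is set when block[i] == needle.
    fn scan_byte(block: U8x64, needle: u8) -> u64 {
        block.cmpeq_mask(U8x64::splat(needle))
    }
    // Matches can then be walked with trailing_zeros():
    // while mask != 0 { let i = mask.trailing_zeros() as usize; mask &= mask - 1; /* handle byte i */ }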
Next step: rewire them to use crate::simd::U8x64 instead. https://claude.ai/code/session_01ChLvBfpJS8dQhHxRD4pYNp --- src/simd.rs | 37 +++++++++++++++++++++--- src/simd_avx2.rs | 70 ++++++++++++++++++++++++++++++++++++++++++++++ src/simd_avx512.rs | 47 +++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 4 deletions(-) diff --git a/src/simd.rs b/src/simd.rs index 10a81335..a6654b74 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -713,7 +713,7 @@ mod scalar { fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } - // U8x64 extra methods + // U8x64 extra methods — byte-level operations for palette codec, nibble, byte scan impl U8x64 { #[inline(always)] pub fn reduce_min(self) -> u8 { *self.0.iter().min().unwrap_or(&0) } @@ -721,14 +721,43 @@ mod scalar { pub fn reduce_max(self) -> u8 { *self.0.iter().max().unwrap_or(&0) } #[inline(always)] pub fn simd_min(self, other: Self) -> Self { + let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].min(other.0[i]); } Self(out) + } + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].max(other.0[i]); } Self(out) + } + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u64 { + let mut mask = 0u64; + for i in 0..64 { if self.0[i] == other.0[i] { mask |= 1u64 << i; } } + mask + } + #[inline(always)] + pub fn shr_epi16(self, imm: u32) -> Self { let mut out = [0u8; 64]; - for i in 0..64 { out[i] = self.0[i].min(other.0[i]); } + for i in (0..64).step_by(2) { + let val = u16::from_le_bytes([self.0[i], self.0[i + 1]]); + let shifted = val >> imm; + let bytes = shifted.to_le_bytes(); + out[i] = bytes[0]; out[i + 1] = bytes[1]; + } Self(out) } #[inline(always)] - pub fn simd_max(self, other: Self) -> Self { + pub fn saturating_sub(self, other: Self) -> Self { + let mut out = [0u8; 64]; for i in 0..64 { out[i] = self.0[i].saturating_sub(other.0[i]); } Self(out) + } + #[inline(always)] + pub fn unpack_lo_epi8(self, other: Self) -> Self { + let mut out = [0u8; 64]; + for lane in 0..4 { let b = lane * 16; for i in 0..8 { out[b+i*2] = self.0[b+i]; out[b+i*2+1] = other.0[b+i]; } } + Self(out) + } + #[inline(always)] + pub fn unpack_hi_epi8(self, other: Self) -> Self { let mut out = [0u8; 64]; - for i in 0..64 { out[i] = self.0[i].max(other.0[i]); } + for lane in 0..4 { let b = lane * 16; for i in 0..8 { out[b+i*2] = self.0[b+8+i]; out[b+i*2+1] = other.0[b+8+i]; } } Self(out) } } diff --git a/src/simd_avx2.rs b/src/simd_avx2.rs index fba91afc..b8f9ad84 100644 --- a/src/simd_avx2.rs +++ b/src/simd_avx2.rs @@ -761,6 +761,76 @@ macro_rules! avx2_int_type { } avx2_int_type!(U8x64, u8, 64, 0u8); + +// ── U8x64 byte-level operations (scalar fallback for AVX2 tier) ────────── +// These match the AVX-512 U8x64 methods in simd_avx512.rs. +impl U8x64 { + /// Byte-wise equality mask: bit i set if self[i] == other[i]. + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u64 { + let mut mask = 0u64; + for i in 0..64 { if self.0[i] == other.0[i] { mask |= 1u64 << i; } } + mask + } + + /// Shift right each 16-bit lane by imm bits (operates on pairs of u8 as u16). + #[inline(always)] + pub fn shr_epi16(self, imm: u32) -> Self { + let mut out = [0u8; 64]; + for i in (0..64).step_by(2) { + let val = u16::from_le_bytes([self.0[i], self.0[i + 1]]); + let shifted = val >> imm; + let bytes = shifted.to_le_bytes(); + out[i] = bytes[0]; + out[i + 1] = bytes[1]; + } + Self(out) + } + + /// Saturating unsigned subtraction: max(a - b, 0) per byte. 
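+    /// Matches `_mm256_subs_epu8` / `_mm512_subs_epu8` semantics, emulated per byte in this tier.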
+ #[inline(always)] + pub fn saturating_sub(self, other: Self) -> Self { + let mut out = [0u8; 64]; + for i in 0..64 { out[i] = self.0[i].saturating_sub(other.0[i]); } + Self(out) + } + + /// Interleave low bytes within each 128-bit lane. + #[inline(always)] + pub fn unpack_lo_epi8(self, other: Self) -> Self { + let mut out = [0u8; 64]; + // Operates per 16-byte lane (4 lanes in 512-bit) + for lane in 0..4 { + let base = lane * 16; + for i in 0..8 { + out[base + i * 2] = self.0[base + i]; + out[base + i * 2 + 1] = other.0[base + i]; + } + } + Self(out) + } + + /// Interleave high bytes within each 128-bit lane. + #[inline(always)] + pub fn unpack_hi_epi8(self, other: Self) -> Self { + let mut out = [0u8; 64]; + for lane in 0..4 { + let base = lane * 16; + for i in 0..8 { + out[base + i * 2] = self.0[base + 8 + i]; + out[base + i * 2 + 1] = other.0[base + 8 + i]; + } + } + Self(out) + } + + /// Reduce min/max (not in macro). + #[inline(always)] pub fn reduce_min(self) -> u8 { *self.0.iter().min().unwrap() } + #[inline(always)] pub fn reduce_max(self) -> u8 { *self.0.iter().max().unwrap() } + #[inline(always)] pub fn simd_min(self, other: Self) -> Self { let mut o = [0u8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) } + #[inline(always)] pub fn simd_max(self, other: Self) -> Self { let mut o = [0u8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) } +} + avx2_int_type!(I32x16, i32, 16, 0i32); avx2_int_type!(I64x8, i64, 8, 0i64); avx2_int_type!(U32x16, u32, 16, 0u32); diff --git a/src/simd_avx512.rs b/src/simd_avx512.rs index 0ba48f4c..ad249d3d 100644 --- a/src/simd_avx512.rs +++ b/src/simd_avx512.rs @@ -576,6 +576,53 @@ impl U8x64 { pub fn simd_max(self, other: Self) -> Self { Self(unsafe { _mm512_max_epu8(self.0, other.0) }) } + + // ── Byte-level operations for palette codec, nibble, byte scan ────── + // Reference: Pumpkin/Minecraft-derived modules (palette_codec.rs, + // nibble.rs, byte_scan.rs) use these for 4-bit packing and scanning. + + /// Byte-wise equality comparison. Returns 64-bit mask: bit i set if a[i] == b[i]. + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u64 { + unsafe { _mm512_cmpeq_epi8_mask(self.0, other.0) } + } + + /// Shift right each 16-bit lane by immediate bits (for nibble extraction). + /// Note: operates on 16-bit lanes, not 8-bit — matches _mm512_srli_epi16. + #[inline(always)] + pub fn shr_epi16(self, imm: u32) -> Self { + // _mm512_srli_epi16 shifts each 16-bit lane right + // Use match for const immediate (intrinsic requires const) + Self(unsafe { match imm { + 1 => _mm512_srli_epi16(self.0, 1), + 2 => _mm512_srli_epi16(self.0, 2), + 3 => _mm512_srli_epi16(self.0, 3), + 4 => _mm512_srli_epi16(self.0, 4), + 5 => _mm512_srli_epi16(self.0, 5), + 6 => _mm512_srli_epi16(self.0, 6), + 7 => _mm512_srli_epi16(self.0, 7), + 8 => _mm512_srli_epi16(self.0, 8), + _ => _mm512_setzero_si512(), + }}) + } + + /// Saturating unsigned subtraction: max(a - b, 0) per byte. + #[inline(always)] + pub fn saturating_sub(self, other: Self) -> Self { + Self(unsafe { _mm512_subs_epu8(self.0, other.0) }) + } + + /// Interleave low bytes: [a0,b0,a1,b1,...] from lower halves. + #[inline(always)] + pub fn unpack_lo_epi8(self, other: Self) -> Self { + Self(unsafe { _mm512_unpacklo_epi8(self.0, other.0) }) + } + + /// Interleave high bytes: [a8,b8,a9,b9,...] from upper halves. 
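+    /// Like `vpunpckhbw`, the interleave happens independently within each
+    /// 128-bit lane (mirroring the per-16-byte-lane loop in the AVX2-tier
+    /// fallback), not across the full 512-bit register.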
+ #[inline(always)] + pub fn unpack_hi_epi8(self, other: Self) -> Self { + Self(unsafe { _mm512_unpackhi_epi8(self.0, other.0) }) + } } // u8 add/sub use AVX-512BW instructions