From 476dff0c1f452d778fdc7d39bfc16eabaae73219 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 30 Apr 2026 12:58:41 +0000 Subject: [PATCH] feat(simd-neon): 6 NEON integer wrapper types for aarch64 (sprint W3-B) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes parity item 8 — adds U8x16, U16x8, U32x4, U64x2, I32x4, I64x2 NEON wrapper types so aarch64 burn-ndarray builds get real NEON acceleration on integer hot paths instead of scalar. Each type has: splat/zero/from_slice/from_array/to_array/copy_to_slice/ add/sub/min/max. NEON intrinsics: - U8x16: vaddq_u8, vsubq_u8, vminq_u8, vmaxq_u8 - U16x8: vaddq_u16, vsubq_u16, vminq_u16, vmaxq_u16 - U32x4: vaddq_u32, vsubq_u32, vminq_u32, vmaxq_u32 - U64x2: vaddq_u64, vsubq_u64 (min/max scalar — NEON has no vminq_u64) - I32x4: vaddq_s32, vsubq_s32, vminq_s32, vmaxq_s32 - I64x2: vaddq_s64, vsubq_s64 (min/max scalar — NEON has no vminq_s64) Item 7 (AVX2 paired-256 fallbacks for U32x16/U64x8/etc.) deferred: all 6 types already exist as scalar fallback via avx2_int_type! macro in src/simd_avx2.rs — they're correct and complete; paired-256 SIMD acceleration is a perf upgrade, not a functionality blocker. Tests: builds clean on x86_64 + cross-compiles clean for aarch64-unknown-linux-gnu. https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj --- src/simd_neon.rs | 176 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/src/simd_neon.rs b/src/simd_neon.rs index 63717849..cb1df556 100644 --- a/src/simd_neon.rs +++ b/src/simd_neon.rs @@ -1236,6 +1236,182 @@ impl PartialEq for I16x8 { fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } } +// ═══════════════════════════════════════════════════════════════════════════ +// W3-B: NEON integer wrapper types (item 8 of burn parity list) +// ─ U8x16, U16x8, U32x4, U64x2, I32x4, I64x2 ─ +// ═══════════════════════════════════════════════════════════════════════════ + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U8x16(pub uint8x16_t); + +#[cfg(target_arch = "aarch64")] +impl U8x16 { + pub const LANES: usize = 16; + #[inline(always)] pub fn splat(v: u8) -> Self { Self(unsafe { vdupq_n_u8(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u8(0) }) } + #[inline(always)] pub fn from_slice(s: &[u8]) -> Self { + assert!(s.len() >= 16); Self(unsafe { vld1q_u8(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [u8; 16]) -> Self { Self(unsafe { vld1q_u8(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [u8; 16] { + let mut arr = [0u8; 16]; unsafe { vst1q_u8(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [u8]) { + assert!(s.len() >= 16); unsafe { vst1q_u8(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u8(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u8(self.0, other.0) }) } + #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u8(self.0, other.0) }) } + #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u8(self.0, other.0) }) } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U16x8(pub uint16x8_t); + +#[cfg(target_arch = "aarch64")] +impl U16x8 { + pub const LANES: usize = 8; + #[inline(always)] pub fn splat(v: u16) -> Self { Self(unsafe { vdupq_n_u16(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u16(0) }) } + #[inline(always)] pub fn from_slice(s: &[u16]) -> Self { + assert!(s.len() >= 8); Self(unsafe { vld1q_u16(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [u16; 8]) -> Self { Self(unsafe { vld1q_u16(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [u16; 8] { + let mut arr = [0u16; 8]; unsafe { vst1q_u16(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [u16]) { + assert!(s.len() >= 8); unsafe { vst1q_u16(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u16(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u16(self.0, other.0) }) } + #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u16(self.0, other.0) }) } + #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u16(self.0, other.0) }) } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U32x4(pub uint32x4_t); + +#[cfg(target_arch = "aarch64")] +impl U32x4 { + pub const LANES: usize = 4; + #[inline(always)] pub fn splat(v: u32) -> Self { Self(unsafe { vdupq_n_u32(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u32(0) }) } + #[inline(always)] pub fn from_slice(s: &[u32]) -> Self { + assert!(s.len() >= 4); Self(unsafe { vld1q_u32(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [u32; 4]) -> Self { Self(unsafe { vld1q_u32(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [u32; 4] { + let mut arr = [0u32; 4]; unsafe { vst1q_u32(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [u32]) { + assert!(s.len() >= 4); unsafe { vst1q_u32(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u32(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u32(self.0, other.0) }) } + #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u32(self.0, other.0) }) } + #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u32(self.0, other.0) }) } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U64x2(pub uint64x2_t); + +#[cfg(target_arch = "aarch64")] +impl U64x2 { + pub const LANES: usize = 2; + #[inline(always)] pub fn splat(v: u64) -> Self { Self(unsafe { vdupq_n_u64(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u64(0) }) } + #[inline(always)] pub fn from_slice(s: &[u64]) -> Self { + assert!(s.len() >= 2); Self(unsafe { vld1q_u64(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [u64; 2]) -> Self { Self(unsafe { vld1q_u64(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [u64; 2] { + let mut arr = [0u64; 2]; unsafe { vst1q_u64(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [u64]) { + assert!(s.len() >= 2); unsafe { vst1q_u64(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u64(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u64(self.0, other.0) }) } + // NEON has no vminq_u64 / vmaxq_u64 — scalar fallback + #[inline(always)] pub fn min(self, other: Self) -> Self { + let a = self.to_array(); let b = other.to_array(); + Self::from_array([a[0].min(b[0]), a[1].min(b[1])]) + } + #[inline(always)] pub fn max(self, other: Self) -> Self { + let a = self.to_array(); let b = other.to_array(); + Self::from_array([a[0].max(b[0]), a[1].max(b[1])]) + } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct I32x4(pub int32x4_t); + +#[cfg(target_arch = "aarch64")] +impl I32x4 { + pub const LANES: usize = 4; + #[inline(always)] pub fn splat(v: i32) -> Self { Self(unsafe { vdupq_n_s32(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_s32(0) }) } + #[inline(always)] pub fn from_slice(s: &[i32]) -> Self { + assert!(s.len() >= 4); Self(unsafe { vld1q_s32(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [i32; 4]) -> Self { Self(unsafe { vld1q_s32(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [i32; 4] { + let mut arr = [0i32; 4]; unsafe { vst1q_s32(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [i32]) { + assert!(s.len() >= 4); unsafe { vst1q_s32(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s32(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s32(self.0, other.0) }) } + #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_s32(self.0, other.0) }) } + #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_s32(self.0, other.0) }) } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct I64x2(pub int64x2_t); + +#[cfg(target_arch = "aarch64")] +impl I64x2 { + pub const LANES: usize = 2; + #[inline(always)] pub fn splat(v: i64) -> Self { Self(unsafe { vdupq_n_s64(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_s64(0) }) } + #[inline(always)] pub fn from_slice(s: &[i64]) -> Self { + assert!(s.len() >= 2); Self(unsafe { vld1q_s64(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [i64; 2]) -> Self { Self(unsafe { vld1q_s64(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [i64; 2] { + let mut arr = [0i64; 2]; unsafe { vst1q_s64(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [i64]) { + assert!(s.len() >= 2); unsafe { vst1q_s64(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s64(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s64(self.0, other.0) }) } + // NEON has no vminq_s64 / vmaxq_s64 — scalar fallback + #[inline(always)] pub fn min(self, other: Self) -> Self { + let a = self.to_array(); let b = other.to_array(); + Self::from_array([a[0].min(b[0]), a[1].min(b[1])]) + } + #[inline(always)] pub fn max(self, other: Self) -> Self { + let a = self.to_array(); let b = other.to_array(); + Self::from_array([a[0].max(b[0]), a[1].max(b[1])]) + } +} + + // ── Polyfills for wider lanes (scalar arrays) ───────────────────────────── macro_rules! neon_int_polyfill {