diff --git a/src/simd_neon.rs b/src/simd_neon.rs index 63717849..cb1df556 100644 --- a/src/simd_neon.rs +++ b/src/simd_neon.rs @@ -1236,6 +1236,182 @@ impl PartialEq for I16x8 { fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } } +// ═══════════════════════════════════════════════════════════════════════════ +// W3-B: NEON integer wrapper types (item 8 of burn parity list) +// ─ U8x16, U16x8, U32x4, U64x2, I32x4, I64x2 ─ +// ═══════════════════════════════════════════════════════════════════════════ + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U8x16(pub uint8x16_t); + +#[cfg(target_arch = "aarch64")] +impl U8x16 { + pub const LANES: usize = 16; + #[inline(always)] pub fn splat(v: u8) -> Self { Self(unsafe { vdupq_n_u8(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u8(0) }) } + #[inline(always)] pub fn from_slice(s: &[u8]) -> Self { + assert!(s.len() >= 16); Self(unsafe { vld1q_u8(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [u8; 16]) -> Self { Self(unsafe { vld1q_u8(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [u8; 16] { + let mut arr = [0u8; 16]; unsafe { vst1q_u8(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [u8]) { + assert!(s.len() >= 16); unsafe { vst1q_u8(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u8(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u8(self.0, other.0) }) } + #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u8(self.0, other.0) }) } + #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u8(self.0, other.0) }) } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U16x8(pub uint16x8_t); + +#[cfg(target_arch = "aarch64")] +impl U16x8 { + pub const LANES: usize = 8; + #[inline(always)] pub fn splat(v: u16) -> Self { Self(unsafe { vdupq_n_u16(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u16(0) }) } + #[inline(always)] pub fn from_slice(s: &[u16]) -> Self { + assert!(s.len() >= 8); Self(unsafe { vld1q_u16(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [u16; 8]) -> Self { Self(unsafe { vld1q_u16(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [u16; 8] { + let mut arr = [0u16; 8]; unsafe { vst1q_u16(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [u16]) { + assert!(s.len() >= 8); unsafe { vst1q_u16(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u16(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u16(self.0, other.0) }) } + #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u16(self.0, other.0) }) } + #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u16(self.0, other.0) }) } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U32x4(pub uint32x4_t); + +#[cfg(target_arch = "aarch64")] +impl U32x4 { + pub const LANES: usize = 4; + #[inline(always)] pub fn splat(v: u32) -> Self { Self(unsafe { vdupq_n_u32(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u32(0) }) } + #[inline(always)] pub fn from_slice(s: &[u32]) -> Self { + assert!(s.len() >= 4); Self(unsafe { vld1q_u32(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [u32; 4]) -> Self { Self(unsafe { vld1q_u32(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [u32; 4] { + let mut arr = [0u32; 4]; unsafe { vst1q_u32(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [u32]) { + assert!(s.len() >= 4); unsafe { vst1q_u32(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u32(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u32(self.0, other.0) }) } + #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u32(self.0, other.0) }) } + #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u32(self.0, other.0) }) } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U64x2(pub uint64x2_t); + +#[cfg(target_arch = "aarch64")] +impl U64x2 { + pub const LANES: usize = 2; + #[inline(always)] pub fn splat(v: u64) -> Self { Self(unsafe { vdupq_n_u64(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_u64(0) }) } + #[inline(always)] pub fn from_slice(s: &[u64]) -> Self { + assert!(s.len() >= 2); Self(unsafe { vld1q_u64(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [u64; 2]) -> Self { Self(unsafe { vld1q_u64(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [u64; 2] { + let mut arr = [0u64; 2]; unsafe { vst1q_u64(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [u64]) { + assert!(s.len() >= 2); unsafe { vst1q_u64(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u64(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u64(self.0, other.0) }) } + // NEON has no vminq_u64 / vmaxq_u64 — scalar fallback + #[inline(always)] pub fn min(self, other: Self) -> Self { + let a = self.to_array(); let b = other.to_array(); + Self::from_array([a[0].min(b[0]), a[1].min(b[1])]) + } + #[inline(always)] pub fn max(self, other: Self) -> Self { + let a = self.to_array(); let b = other.to_array(); + Self::from_array([a[0].max(b[0]), a[1].max(b[1])]) + } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct I32x4(pub int32x4_t); + +#[cfg(target_arch = "aarch64")] +impl I32x4 { + pub const LANES: usize = 4; + #[inline(always)] pub fn splat(v: i32) -> Self { Self(unsafe { vdupq_n_s32(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_s32(0) }) } + #[inline(always)] pub fn from_slice(s: &[i32]) -> Self { + assert!(s.len() >= 4); Self(unsafe { vld1q_s32(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [i32; 4]) -> Self { Self(unsafe { vld1q_s32(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [i32; 4] { + let mut arr = [0i32; 4]; unsafe { vst1q_s32(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [i32]) { + assert!(s.len() >= 4); unsafe { vst1q_s32(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s32(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s32(self.0, other.0) }) } + #[inline(always)] pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_s32(self.0, other.0) }) } + #[inline(always)] pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_s32(self.0, other.0) }) } +} + +#[cfg(target_arch = "aarch64")] +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct I64x2(pub int64x2_t); + +#[cfg(target_arch = "aarch64")] +impl I64x2 { + pub const LANES: usize = 2; + #[inline(always)] pub fn splat(v: i64) -> Self { Self(unsafe { vdupq_n_s64(v) }) } + #[inline(always)] pub fn zero() -> Self { Self(unsafe { vdupq_n_s64(0) }) } + #[inline(always)] pub fn from_slice(s: &[i64]) -> Self { + assert!(s.len() >= 2); Self(unsafe { vld1q_s64(s.as_ptr()) }) + } + #[inline(always)] pub fn from_array(arr: [i64; 2]) -> Self { Self(unsafe { vld1q_s64(arr.as_ptr()) }) } + #[inline(always)] pub fn to_array(self) -> [i64; 2] { + let mut arr = [0i64; 2]; unsafe { vst1q_s64(arr.as_mut_ptr(), self.0) }; arr + } + #[inline(always)] pub fn copy_to_slice(self, s: &mut [i64]) { + assert!(s.len() >= 2); unsafe { vst1q_s64(s.as_mut_ptr(), self.0) }; + } + #[inline(always)] pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s64(self.0, other.0) }) } + #[inline(always)] pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s64(self.0, other.0) }) } + // NEON has no vminq_s64 / vmaxq_s64 — scalar fallback + #[inline(always)] pub fn min(self, other: Self) -> Self { + let a = self.to_array(); let b = other.to_array(); + Self::from_array([a[0].min(b[0]), a[1].min(b[1])]) + } + #[inline(always)] pub fn max(self, other: Self) -> Self { + let a = self.to_array(); let b = other.to_array(); + Self::from_array([a[0].max(b[0]), a[1].max(b[1])]) + } +} + + // ── Polyfills for wider lanes (scalar arrays) ───────────────────────────── macro_rules! neon_int_polyfill {