Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions src/simd_neon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,182 @@ impl PartialEq for I16x8 {
fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() }
}

// ═══════════════════════════════════════════════════════════════════════════
// W3-B: NEON integer wrapper types (item 8 of burn parity list)
// ─ U8x16, U16x8, U32x4, U64x2, I32x4, I64x2 ─
// ═══════════════════════════════════════════════════════════════════════════

/// 16 lanes of `u8` in a 128-bit NEON register (`uint8x16_t`).
#[cfg(target_arch = "aarch64")]
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct U8x16(pub uint8x16_t);

#[cfg(target_arch = "aarch64")]
impl U8x16 {
    // SAFETY (all `unsafe` blocks below): NEON/ASIMD is a mandatory aarch64
    // feature, so these intrinsics are always available; pointer-based
    // loads/stores justify their bounds at each call site.

    /// Number of lanes.
    pub const LANES: usize = 16;

    /// Broadcast `v` into all 16 lanes.
    #[inline(always)]
    pub fn splat(v: u8) -> Self { Self(unsafe { vdupq_n_u8(v) }) }

    /// All lanes zero.
    #[inline(always)]
    pub fn zero() -> Self { Self(unsafe { vdupq_n_u8(0) }) }

    /// Load the first 16 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 16`.
    #[inline(always)]
    pub fn from_slice(s: &[u8]) -> Self {
        assert!(s.len() >= 16);
        // SAFETY: the assert guarantees 16 readable elements at `s.as_ptr()`.
        Self(unsafe { vld1q_u8(s.as_ptr()) })
    }

    /// Build a vector from an array by value.
    #[inline(always)]
    pub fn from_array(arr: [u8; 16]) -> Self {
        // SAFETY: `arr` holds exactly 16 elements, so the 128-bit load is in bounds.
        Self(unsafe { vld1q_u8(arr.as_ptr()) })
    }

    /// Store all lanes into an array.
    #[inline(always)]
    pub fn to_array(self) -> [u8; 16] {
        let mut arr = [0u8; 16];
        // SAFETY: `arr` holds exactly 16 elements, so the 128-bit store is in bounds.
        unsafe { vst1q_u8(arr.as_mut_ptr(), self.0) };
        arr
    }

    /// Store all lanes into the first 16 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 16`.
    #[inline(always)]
    pub fn copy_to_slice(self, s: &mut [u8]) {
        assert!(s.len() >= 16);
        // SAFETY: the assert guarantees 16 writable elements at `s.as_mut_ptr()`.
        unsafe { vst1q_u8(s.as_mut_ptr(), self.0) };
    }

    /// Lane-wise wrapping addition.
    #[inline(always)]
    pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u8(self.0, other.0) }) }

    /// Lane-wise wrapping subtraction.
    #[inline(always)]
    pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u8(self.0, other.0) }) }

    /// Lane-wise minimum.
    #[inline(always)]
    pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u8(self.0, other.0) }) }

    /// Lane-wise maximum.
    #[inline(always)]
    pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u8(self.0, other.0) }) }
}

// Lane-wise equality, mirroring the `PartialEq` impl the other NEON wrappers
// (e.g. `I16x8`) already provide.
#[cfg(target_arch = "aarch64")]
impl PartialEq for U8x16 {
    fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() }
}

/// 8 lanes of `u16` in a 128-bit NEON register (`uint16x8_t`).
#[cfg(target_arch = "aarch64")]
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct U16x8(pub uint16x8_t);

#[cfg(target_arch = "aarch64")]
impl U16x8 {
    // SAFETY (all `unsafe` blocks below): NEON/ASIMD is a mandatory aarch64
    // feature, so these intrinsics are always available; pointer-based
    // loads/stores justify their bounds at each call site.

    /// Number of lanes.
    pub const LANES: usize = 8;

    /// Broadcast `v` into all 8 lanes.
    #[inline(always)]
    pub fn splat(v: u16) -> Self { Self(unsafe { vdupq_n_u16(v) }) }

    /// All lanes zero.
    #[inline(always)]
    pub fn zero() -> Self { Self(unsafe { vdupq_n_u16(0) }) }

    /// Load the first 8 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 8`.
    #[inline(always)]
    pub fn from_slice(s: &[u16]) -> Self {
        assert!(s.len() >= 8);
        // SAFETY: the assert guarantees 8 readable elements at `s.as_ptr()`.
        Self(unsafe { vld1q_u16(s.as_ptr()) })
    }

    /// Build a vector from an array by value.
    #[inline(always)]
    pub fn from_array(arr: [u16; 8]) -> Self {
        // SAFETY: `arr` holds exactly 8 elements, so the 128-bit load is in bounds.
        Self(unsafe { vld1q_u16(arr.as_ptr()) })
    }

    /// Store all lanes into an array.
    #[inline(always)]
    pub fn to_array(self) -> [u16; 8] {
        let mut arr = [0u16; 8];
        // SAFETY: `arr` holds exactly 8 elements, so the 128-bit store is in bounds.
        unsafe { vst1q_u16(arr.as_mut_ptr(), self.0) };
        arr
    }

    /// Store all lanes into the first 8 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 8`.
    #[inline(always)]
    pub fn copy_to_slice(self, s: &mut [u16]) {
        assert!(s.len() >= 8);
        // SAFETY: the assert guarantees 8 writable elements at `s.as_mut_ptr()`.
        unsafe { vst1q_u16(s.as_mut_ptr(), self.0) };
    }

    /// Lane-wise wrapping addition.
    #[inline(always)]
    pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u16(self.0, other.0) }) }

    /// Lane-wise wrapping subtraction.
    #[inline(always)]
    pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u16(self.0, other.0) }) }

    /// Lane-wise minimum.
    #[inline(always)]
    pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u16(self.0, other.0) }) }

    /// Lane-wise maximum.
    #[inline(always)]
    pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u16(self.0, other.0) }) }
}

// Lane-wise equality, mirroring the `PartialEq` impl the other NEON wrappers
// (e.g. `I16x8`) already provide.
#[cfg(target_arch = "aarch64")]
impl PartialEq for U16x8 {
    fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() }
}

/// 4 lanes of `u32` in a 128-bit NEON register (`uint32x4_t`).
#[cfg(target_arch = "aarch64")]
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct U32x4(pub uint32x4_t);

#[cfg(target_arch = "aarch64")]
impl U32x4 {
    // SAFETY (all `unsafe` blocks below): NEON/ASIMD is a mandatory aarch64
    // feature, so these intrinsics are always available; pointer-based
    // loads/stores justify their bounds at each call site.

    /// Number of lanes.
    pub const LANES: usize = 4;

    /// Broadcast `v` into all 4 lanes.
    #[inline(always)]
    pub fn splat(v: u32) -> Self { Self(unsafe { vdupq_n_u32(v) }) }

    /// All lanes zero.
    #[inline(always)]
    pub fn zero() -> Self { Self(unsafe { vdupq_n_u32(0) }) }

    /// Load the first 4 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 4`.
    #[inline(always)]
    pub fn from_slice(s: &[u32]) -> Self {
        assert!(s.len() >= 4);
        // SAFETY: the assert guarantees 4 readable elements at `s.as_ptr()`.
        Self(unsafe { vld1q_u32(s.as_ptr()) })
    }

    /// Build a vector from an array by value.
    #[inline(always)]
    pub fn from_array(arr: [u32; 4]) -> Self {
        // SAFETY: `arr` holds exactly 4 elements, so the 128-bit load is in bounds.
        Self(unsafe { vld1q_u32(arr.as_ptr()) })
    }

    /// Store all lanes into an array.
    #[inline(always)]
    pub fn to_array(self) -> [u32; 4] {
        let mut arr = [0u32; 4];
        // SAFETY: `arr` holds exactly 4 elements, so the 128-bit store is in bounds.
        unsafe { vst1q_u32(arr.as_mut_ptr(), self.0) };
        arr
    }

    /// Store all lanes into the first 4 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 4`.
    #[inline(always)]
    pub fn copy_to_slice(self, s: &mut [u32]) {
        assert!(s.len() >= 4);
        // SAFETY: the assert guarantees 4 writable elements at `s.as_mut_ptr()`.
        unsafe { vst1q_u32(s.as_mut_ptr(), self.0) };
    }

    /// Lane-wise wrapping addition.
    #[inline(always)]
    pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u32(self.0, other.0) }) }

    /// Lane-wise wrapping subtraction.
    #[inline(always)]
    pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u32(self.0, other.0) }) }

    /// Lane-wise minimum.
    #[inline(always)]
    pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_u32(self.0, other.0) }) }

    /// Lane-wise maximum.
    #[inline(always)]
    pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_u32(self.0, other.0) }) }
}

// Lane-wise equality, mirroring the `PartialEq` impl the other NEON wrappers
// (e.g. `I16x8`) already provide.
#[cfg(target_arch = "aarch64")]
impl PartialEq for U32x4 {
    fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() }
}

/// 2 lanes of `u64` in a 128-bit NEON register (`uint64x2_t`).
#[cfg(target_arch = "aarch64")]
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct U64x2(pub uint64x2_t);

#[cfg(target_arch = "aarch64")]
impl U64x2 {
    // SAFETY (all `unsafe` blocks below): NEON/ASIMD is a mandatory aarch64
    // feature, so these intrinsics are always available; pointer-based
    // loads/stores justify their bounds at each call site.

    /// Number of lanes.
    pub const LANES: usize = 2;

    /// Broadcast `v` into both lanes.
    #[inline(always)]
    pub fn splat(v: u64) -> Self { Self(unsafe { vdupq_n_u64(v) }) }

    /// Both lanes zero.
    #[inline(always)]
    pub fn zero() -> Self { Self(unsafe { vdupq_n_u64(0) }) }

    /// Load the first 2 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 2`.
    #[inline(always)]
    pub fn from_slice(s: &[u64]) -> Self {
        assert!(s.len() >= 2);
        // SAFETY: the assert guarantees 2 readable elements at `s.as_ptr()`.
        Self(unsafe { vld1q_u64(s.as_ptr()) })
    }

    /// Build a vector from an array by value.
    #[inline(always)]
    pub fn from_array(arr: [u64; 2]) -> Self {
        // SAFETY: `arr` holds exactly 2 elements, so the 128-bit load is in bounds.
        Self(unsafe { vld1q_u64(arr.as_ptr()) })
    }

    /// Store both lanes into an array.
    #[inline(always)]
    pub fn to_array(self) -> [u64; 2] {
        let mut arr = [0u64; 2];
        // SAFETY: `arr` holds exactly 2 elements, so the 128-bit store is in bounds.
        unsafe { vst1q_u64(arr.as_mut_ptr(), self.0) };
        arr
    }

    /// Store both lanes into the first 2 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 2`.
    #[inline(always)]
    pub fn copy_to_slice(self, s: &mut [u64]) {
        assert!(s.len() >= 2);
        // SAFETY: the assert guarantees 2 writable elements at `s.as_mut_ptr()`.
        unsafe { vst1q_u64(s.as_mut_ptr(), self.0) };
    }

    /// Lane-wise wrapping addition.
    #[inline(always)]
    pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_u64(self.0, other.0) }) }

    /// Lane-wise wrapping subtraction.
    #[inline(always)]
    pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_u64(self.0, other.0) }) }

    /// Lane-wise minimum.
    ///
    /// NEON has no `vminq_u64`; emulate it with an unsigned compare plus
    /// bit-select so the values never leave vector registers.
    #[inline(always)]
    pub fn min(self, other: Self) -> Self {
        unsafe {
            // All-ones in lanes where self > other.
            let gt = vcgtq_u64(self.0, other.0);
            // Where self > other take `other`, else take `self`.
            Self(vbslq_u64(gt, other.0, self.0))
        }
    }

    /// Lane-wise maximum (emulated — NEON has no `vmaxq_u64`).
    #[inline(always)]
    pub fn max(self, other: Self) -> Self {
        unsafe {
            let gt = vcgtq_u64(self.0, other.0);
            // Where self > other take `self`, else take `other`.
            Self(vbslq_u64(gt, self.0, other.0))
        }
    }
}

// Lane-wise equality, mirroring the `PartialEq` impl the other NEON wrappers
// (e.g. `I16x8`) already provide.
#[cfg(target_arch = "aarch64")]
impl PartialEq for U64x2 {
    fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() }
}

/// 4 lanes of `i32` in a 128-bit NEON register (`int32x4_t`).
#[cfg(target_arch = "aarch64")]
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct I32x4(pub int32x4_t);

#[cfg(target_arch = "aarch64")]
impl I32x4 {
    // SAFETY (all `unsafe` blocks below): NEON/ASIMD is a mandatory aarch64
    // feature, so these intrinsics are always available; pointer-based
    // loads/stores justify their bounds at each call site.

    /// Number of lanes.
    pub const LANES: usize = 4;

    /// Broadcast `v` into all 4 lanes.
    #[inline(always)]
    pub fn splat(v: i32) -> Self { Self(unsafe { vdupq_n_s32(v) }) }

    /// All lanes zero.
    #[inline(always)]
    pub fn zero() -> Self { Self(unsafe { vdupq_n_s32(0) }) }

    /// Load the first 4 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 4`.
    #[inline(always)]
    pub fn from_slice(s: &[i32]) -> Self {
        assert!(s.len() >= 4);
        // SAFETY: the assert guarantees 4 readable elements at `s.as_ptr()`.
        Self(unsafe { vld1q_s32(s.as_ptr()) })
    }

    /// Build a vector from an array by value.
    #[inline(always)]
    pub fn from_array(arr: [i32; 4]) -> Self {
        // SAFETY: `arr` holds exactly 4 elements, so the 128-bit load is in bounds.
        Self(unsafe { vld1q_s32(arr.as_ptr()) })
    }

    /// Store all lanes into an array.
    #[inline(always)]
    pub fn to_array(self) -> [i32; 4] {
        let mut arr = [0i32; 4];
        // SAFETY: `arr` holds exactly 4 elements, so the 128-bit store is in bounds.
        unsafe { vst1q_s32(arr.as_mut_ptr(), self.0) };
        arr
    }

    /// Store all lanes into the first 4 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 4`.
    #[inline(always)]
    pub fn copy_to_slice(self, s: &mut [i32]) {
        assert!(s.len() >= 4);
        // SAFETY: the assert guarantees 4 writable elements at `s.as_mut_ptr()`.
        unsafe { vst1q_s32(s.as_mut_ptr(), self.0) };
    }

    /// Lane-wise wrapping addition.
    #[inline(always)]
    pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s32(self.0, other.0) }) }

    /// Lane-wise wrapping subtraction.
    #[inline(always)]
    pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s32(self.0, other.0) }) }

    /// Lane-wise minimum.
    #[inline(always)]
    pub fn min(self, other: Self) -> Self { Self(unsafe { vminq_s32(self.0, other.0) }) }

    /// Lane-wise maximum.
    #[inline(always)]
    pub fn max(self, other: Self) -> Self { Self(unsafe { vmaxq_s32(self.0, other.0) }) }
}

// Lane-wise equality, mirroring the `PartialEq` impl the other NEON wrappers
// (e.g. `I16x8`) already provide.
#[cfg(target_arch = "aarch64")]
impl PartialEq for I32x4 {
    fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() }
}

/// 2 lanes of `i64` in a 128-bit NEON register (`int64x2_t`).
#[cfg(target_arch = "aarch64")]
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct I64x2(pub int64x2_t);

#[cfg(target_arch = "aarch64")]
impl I64x2 {
    // SAFETY (all `unsafe` blocks below): NEON/ASIMD is a mandatory aarch64
    // feature, so these intrinsics are always available; pointer-based
    // loads/stores justify their bounds at each call site.

    /// Number of lanes.
    pub const LANES: usize = 2;

    /// Broadcast `v` into both lanes.
    #[inline(always)]
    pub fn splat(v: i64) -> Self { Self(unsafe { vdupq_n_s64(v) }) }

    /// Both lanes zero.
    #[inline(always)]
    pub fn zero() -> Self { Self(unsafe { vdupq_n_s64(0) }) }

    /// Load the first 2 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 2`.
    #[inline(always)]
    pub fn from_slice(s: &[i64]) -> Self {
        assert!(s.len() >= 2);
        // SAFETY: the assert guarantees 2 readable elements at `s.as_ptr()`.
        Self(unsafe { vld1q_s64(s.as_ptr()) })
    }

    /// Build a vector from an array by value.
    #[inline(always)]
    pub fn from_array(arr: [i64; 2]) -> Self {
        // SAFETY: `arr` holds exactly 2 elements, so the 128-bit load is in bounds.
        Self(unsafe { vld1q_s64(arr.as_ptr()) })
    }

    /// Store both lanes into an array.
    #[inline(always)]
    pub fn to_array(self) -> [i64; 2] {
        let mut arr = [0i64; 2];
        // SAFETY: `arr` holds exactly 2 elements, so the 128-bit store is in bounds.
        unsafe { vst1q_s64(arr.as_mut_ptr(), self.0) };
        arr
    }

    /// Store both lanes into the first 2 elements of `s`.
    ///
    /// # Panics
    /// Panics if `s.len() < 2`.
    #[inline(always)]
    pub fn copy_to_slice(self, s: &mut [i64]) {
        assert!(s.len() >= 2);
        // SAFETY: the assert guarantees 2 writable elements at `s.as_mut_ptr()`.
        unsafe { vst1q_s64(s.as_mut_ptr(), self.0) };
    }

    /// Lane-wise wrapping addition.
    #[inline(always)]
    pub fn add(self, other: Self) -> Self { Self(unsafe { vaddq_s64(self.0, other.0) }) }

    /// Lane-wise wrapping subtraction.
    #[inline(always)]
    pub fn sub(self, other: Self) -> Self { Self(unsafe { vsubq_s64(self.0, other.0) }) }

    /// Lane-wise minimum.
    ///
    /// NEON has no `vminq_s64`; emulate it with a signed compare plus
    /// bit-select so the values never leave vector registers.
    #[inline(always)]
    pub fn min(self, other: Self) -> Self {
        unsafe {
            // All-ones in lanes where self > other.
            let gt = vcgtq_s64(self.0, other.0);
            // Where self > other take `other`, else take `self`.
            Self(vbslq_s64(gt, other.0, self.0))
        }
    }

    /// Lane-wise maximum (emulated — NEON has no `vmaxq_s64`).
    #[inline(always)]
    pub fn max(self, other: Self) -> Self {
        unsafe {
            let gt = vcgtq_s64(self.0, other.0);
            // Where self > other take `self`, else take `other`.
            Self(vbslq_s64(gt, self.0, other.0))
        }
    }
}

// Lane-wise equality, mirroring the `PartialEq` impl the other NEON wrappers
// (e.g. `I16x8`) already provide.
#[cfg(target_arch = "aarch64")]
impl PartialEq for I64x2 {
    fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() }
}


// ── Polyfills for wider lanes (scalar arrays) ─────────────────────────────

macro_rules! neon_int_polyfill {
Expand Down
Loading