diff --git a/src/simd.rs b/src/simd.rs index 45207f0d..48e9ca1f 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -237,7 +237,7 @@ pub use crate::simd_avx2::{ // ============================================================================ #[cfg(not(target_arch = "x86_64"))] -mod scalar { +pub(crate) mod scalar { use core::fmt; use core::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, @@ -1014,7 +1014,25 @@ mod scalar { #[allow(non_camel_case_types)] pub type f64x4 = F64x4; } -#[cfg(not(target_arch = "x86_64"))] +// aarch64: F32x16/F64x8 come from the real NEON paired-load implementation +// in simd_neon::aarch64_simd (verified 2026-04-30, agent A7 — burn parity item 9). +// Integer + 256-bit float types still come from the scalar fallback; they're +// not on the critical path for f32 BLAS-1 / VML kernels. +#[cfg(target_arch = "aarch64")] +pub use crate::simd_neon::aarch64_simd::{ + F32x16, F64x8, F32Mask16, F64Mask8, + f32x16, f64x8, +}; +#[cfg(target_arch = "aarch64")] +pub use scalar::{ + U8x64, I32x16, I64x8, U16x32, U32x16, U64x8, + F32x8, F64x4, + u8x64, i32x16, i64x8, u32x16, u64x8, + f32x8, f64x4, +}; + +// Other non-x86 targets (wasm, riscv, etc.): full scalar fallback. +#[cfg(all(not(target_arch = "x86_64"), not(target_arch = "aarch64")))] pub use scalar::{ F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8, F32x8, F64x4, diff --git a/src/simd_neon.rs b/src/simd_neon.rs index 45a5d665..555ac850 100644 --- a/src/simd_neon.rs +++ b/src/simd_neon.rs @@ -446,6 +446,647 @@ pub fn f32_to_f16_batch(input: &[f32], output: &mut [u16]) { } } +// ═══════════════════════════════════════════════════════════════════════════ +// NEON-backed F32x16 / F64x8 — paired loads, NOT scalar fallback +// ═══════════════════════════════════════════════════════════════════════════ +// +// Burn parity item 9 (verified 2026-04-30, agent A7): on aarch64, `F32x16` +// previously dispatched to `simd::scalar` mod (element-wise [f32;16] loop). +// This module provides a real NEON implementation backed by 4× `float32x4_t` +// for `F32x16` and 4× `float64x2_t` for `F64x8`. Hot-path ops (add, sub, mul, +// div, mul_add via `vfmaq_f32`/`vfmaq_f64`, splat, vld1q_*, vst1q_*) compile +// to a single NEON instruction per pair. `simd.rs` re-exports these for +// `target_arch = "aarch64"` ahead of the scalar fallback module. +// +// API matches `simd_avx2::F32x16` (the "dual-tuple" pattern). Methods that +// don't have a direct NEON counterpart (comparisons, reduce_min/max, +// to_bits/from_bits, cast_i32) round-trip through `to_array` — same shape +// as the AVX2 polyfill, so consumer code on aarch64 gets the same +// correctness with vectorized arithmetic kernels. + +#[cfg(target_arch = "aarch64")] +pub mod aarch64_simd { + use super::*; + use core::fmt; + use core::ops::{ + Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign, + }; + + // Integer types come from the scalar fallback in simd.rs — they aren't on + // the perf-critical f32 BLAS-1 / VML path that this module accelerates. + pub use crate::simd::scalar::{ + I32x16, U32x16, U64x8, + }; + + /// 16×f32 backed by 4× NEON `float32x4_t` registers (paired loads). 
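+ ///
+ /// Minimal usage sketch (illustrative only; `xs` and `ys` are hypothetical
+ /// input buffers, not taken from any existing kernel):
+ ///
+ /// ```ignore
+ /// let xs = [1.0f32; 16];
+ /// let ys = [2.0f32; 16];
+ /// // fused multiply-add per lane (xs * ys + 1.0), then a horizontal sum
+ /// let acc = F32x16::from_slice(&xs).mul_add(F32x16::from_slice(&ys), F32x16::splat(1.0));
+ /// assert_eq!(acc.reduce_sum(), 48.0); // (1.0 * 2.0 + 1.0) * 16 lanes
+ /// ```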
+ #[derive(Copy, Clone)] + #[repr(align(64))] + pub struct F32x16(pub [float32x4_t; 4]); + + impl F32x16 { + pub const LANES: usize = 16; + + #[inline(always)] + pub fn splat(v: f32) -> Self { + unsafe { + let s = vdupq_n_f32(v); + Self([s, s, s, s]) + } + } + + #[inline(always)] + pub fn from_slice(s: &[f32]) -> Self { + assert!(s.len() >= 16); + unsafe { + let p = s.as_ptr(); + Self([ + vld1q_f32(p), + vld1q_f32(p.add(4)), + vld1q_f32(p.add(8)), + vld1q_f32(p.add(12)), + ]) + } + } + + #[inline(always)] + pub fn from_array(a: [f32; 16]) -> Self { Self::from_slice(&a) } + + #[inline(always)] + pub fn to_array(self) -> [f32; 16] { + let mut out = [0.0f32; 16]; + self.copy_to_slice(&mut out); + out + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [f32]) { + assert!(s.len() >= 16); + unsafe { + let p = s.as_mut_ptr(); + vst1q_f32(p, self.0[0]); + vst1q_f32(p.add(4), self.0[1]); + vst1q_f32(p.add(8), self.0[2]); + vst1q_f32(p.add(12), self.0[3]); + } + } + + #[inline(always)] + pub fn reduce_sum(self) -> f32 { + unsafe { + let s01 = vaddq_f32(self.0[0], self.0[1]); + let s23 = vaddq_f32(self.0[2], self.0[3]); + vaddvq_f32(vaddq_f32(s01, s23)) + } + } + + #[inline(always)] + pub fn reduce_min(self) -> f32 { + self.to_array().iter().copied().fold(f32::INFINITY, f32::min) + } + + #[inline(always)] + pub fn reduce_max(self) -> f32 { + self.to_array().iter().copied().fold(f32::NEG_INFINITY, f32::max) + } + + #[inline(always)] + pub fn abs(self) -> Self { + unsafe { + Self([ + vabsq_f32(self.0[0]), vabsq_f32(self.0[1]), + vabsq_f32(self.0[2]), vabsq_f32(self.0[3]), + ]) + } + } + + #[inline(always)] + pub fn sqrt(self) -> Self { + unsafe { + Self([ + vsqrtq_f32(self.0[0]), vsqrtq_f32(self.0[1]), + vsqrtq_f32(self.0[2]), vsqrtq_f32(self.0[3]), + ]) + } + } + + #[inline(always)] + pub fn round(self) -> Self { + unsafe { + Self([ + vrndnq_f32(self.0[0]), vrndnq_f32(self.0[1]), + vrndnq_f32(self.0[2]), vrndnq_f32(self.0[3]), + ]) + } + } + + #[inline(always)] + pub fn floor(self) -> Self { + unsafe { + Self([ + vrndmq_f32(self.0[0]), vrndmq_f32(self.0[1]), + vrndmq_f32(self.0[2]), vrndmq_f32(self.0[3]), + ]) + } + } + + #[inline(always)] + pub fn mul_add(self, b: Self, c: Self) -> Self { + unsafe { + Self([ + vfmaq_f32(c.0[0], self.0[0], b.0[0]), + vfmaq_f32(c.0[1], self.0[1], b.0[1]), + vfmaq_f32(c.0[2], self.0[2], b.0[2]), + vfmaq_f32(c.0[3], self.0[3], b.0[3]), + ]) + } + } + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + unsafe { + Self([ + vminq_f32(self.0[0], other.0[0]), + vminq_f32(self.0[1], other.0[1]), + vminq_f32(self.0[2], other.0[2]), + vminq_f32(self.0[3], other.0[3]), + ]) + } + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + unsafe { + Self([ + vmaxq_f32(self.0[0], other.0[0]), + vmaxq_f32(self.0[1], other.0[1]), + vmaxq_f32(self.0[2], other.0[2]), + vmaxq_f32(self.0[3], other.0[3]), + ]) + } + } + + #[inline(always)] + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { self.simd_max(lo).simd_min(hi) } + + #[inline(always)] + pub fn simd_lt(self, other: Self) -> F32Mask16 { + let a = self.to_array(); let b = other.to_array(); + let mut bits: u16 = 0; + for i in 0..16 { if a[i] < b[i] { bits |= 1 << i; } } + F32Mask16(bits) + } + #[inline(always)] + pub fn simd_le(self, other: Self) -> F32Mask16 { + let a = self.to_array(); let b = other.to_array(); + let mut bits: u16 = 0; + for i in 0..16 { if a[i] <= b[i] { bits |= 1 << i; } } + F32Mask16(bits) + } + #[inline(always)] pub fn simd_gt(self, other: Self) -> F32Mask16 { 
other.simd_lt(self) } + #[inline(always)] pub fn simd_ge(self, other: Self) -> F32Mask16 { other.simd_le(self) } + #[inline(always)] + pub fn simd_eq(self, other: Self) -> F32Mask16 { + let a = self.to_array(); let b = other.to_array(); + let mut bits: u16 = 0; + for i in 0..16 { if a[i] == b[i] { bits |= 1 << i; } } + F32Mask16(bits) + } + #[inline(always)] + pub fn simd_ne(self, other: Self) -> F32Mask16 { + let a = self.to_array(); let b = other.to_array(); + let mut bits: u16 = 0; + for i in 0..16 { if a[i] != b[i] { bits |= 1 << i; } } + F32Mask16(bits) + } + + #[inline(always)] + pub fn to_bits(self) -> U32x16 { + let a = self.to_array(); + let mut o = [0u32; 16]; for i in 0..16 { o[i] = a[i].to_bits(); } U32x16(o) + } + #[inline(always)] + pub fn from_bits(bits: U32x16) -> Self { + let mut o = [0.0f32; 16]; for i in 0..16 { o[i] = f32::from_bits(bits.0[i]); } + Self::from_array(o) + } + #[inline(always)] + pub fn cast_i32(self) -> I32x16 { + let a = self.to_array(); + let mut o = [0i32; 16]; for i in 0..16 { o[i] = a[i] as i32; } I32x16(o) + } + } + + impl Add for F32x16 { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + unsafe { + Self([ + vaddq_f32(self.0[0], rhs.0[0]), vaddq_f32(self.0[1], rhs.0[1]), + vaddq_f32(self.0[2], rhs.0[2]), vaddq_f32(self.0[3], rhs.0[3]), + ]) + } + } + } + impl Sub for F32x16 { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + unsafe { + Self([ + vsubq_f32(self.0[0], rhs.0[0]), vsubq_f32(self.0[1], rhs.0[1]), + vsubq_f32(self.0[2], rhs.0[2]), vsubq_f32(self.0[3], rhs.0[3]), + ]) + } + } + } + impl Mul for F32x16 { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + unsafe { + Self([ + vmulq_f32(self.0[0], rhs.0[0]), vmulq_f32(self.0[1], rhs.0[1]), + vmulq_f32(self.0[2], rhs.0[2]), vmulq_f32(self.0[3], rhs.0[3]), + ]) + } + } + } + impl Div for F32x16 { + type Output = Self; + #[inline(always)] + fn div(self, rhs: Self) -> Self { + unsafe { + Self([ + vdivq_f32(self.0[0], rhs.0[0]), vdivq_f32(self.0[1], rhs.0[1]), + vdivq_f32(self.0[2], rhs.0[2]), vdivq_f32(self.0[3], rhs.0[3]), + ]) + } + } + } + impl AddAssign for F32x16 { #[inline(always)] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } + impl SubAssign for F32x16 { #[inline(always)] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } + impl MulAssign for F32x16 { #[inline(always)] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } + impl DivAssign for F32x16 { #[inline(always)] fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } } + impl Neg for F32x16 { + type Output = Self; + #[inline(always)] + fn neg(self) -> Self { + unsafe { + Self([ + vnegq_f32(self.0[0]), vnegq_f32(self.0[1]), + vnegq_f32(self.0[2]), vnegq_f32(self.0[3]), + ]) + } + } + } + impl fmt::Debug for F32x16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "F32x16({:?})", self.to_array()) + } + } + impl PartialEq for F32x16 { + fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + } + impl Default for F32x16 { fn default() -> Self { Self::splat(0.0) } } + + #[derive(Copy, Clone, Debug)] + pub struct F32Mask16(pub u16); + impl F32Mask16 { + #[inline(always)] + pub fn select(self, true_val: F32x16, false_val: F32x16) -> F32x16 { + let t = true_val.to_array(); let f = false_val.to_array(); + let mut o = [0.0f32; 16]; + for i in 0..16 { o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; } + F32x16::from_array(o) + } + } + + /// 8×f64 backed by 4× 
NEON `float64x2_t` registers (paired loads). + #[derive(Copy, Clone)] + #[repr(align(64))] + pub struct F64x8(pub [float64x2_t; 4]); + + impl F64x8 { + pub const LANES: usize = 8; + + #[inline(always)] + pub fn splat(v: f64) -> Self { + unsafe { + let s = vdupq_n_f64(v); + Self([s, s, s, s]) + } + } + + #[inline(always)] + pub fn from_slice(s: &[f64]) -> Self { + assert!(s.len() >= 8); + unsafe { + let p = s.as_ptr(); + Self([ + vld1q_f64(p), + vld1q_f64(p.add(2)), + vld1q_f64(p.add(4)), + vld1q_f64(p.add(6)), + ]) + } + } + + #[inline(always)] + pub fn from_array(a: [f64; 8]) -> Self { Self::from_slice(&a) } + + #[inline(always)] + pub fn to_array(self) -> [f64; 8] { + let mut out = [0.0f64; 8]; + self.copy_to_slice(&mut out); + out + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [f64]) { + assert!(s.len() >= 8); + unsafe { + let p = s.as_mut_ptr(); + vst1q_f64(p, self.0[0]); + vst1q_f64(p.add(2), self.0[1]); + vst1q_f64(p.add(4), self.0[2]); + vst1q_f64(p.add(6), self.0[3]); + } + } + + #[inline(always)] + pub fn reduce_sum(self) -> f64 { + unsafe { + let s01 = vaddq_f64(self.0[0], self.0[1]); + let s23 = vaddq_f64(self.0[2], self.0[3]); + vaddvq_f64(vaddq_f64(s01, s23)) + } + } + + #[inline(always)] + pub fn reduce_min(self) -> f64 { + self.to_array().iter().copied().fold(f64::INFINITY, f64::min) + } + + #[inline(always)] + pub fn reduce_max(self) -> f64 { + self.to_array().iter().copied().fold(f64::NEG_INFINITY, f64::max) + } + + #[inline(always)] + pub fn abs(self) -> Self { + unsafe { + Self([ + vabsq_f64(self.0[0]), vabsq_f64(self.0[1]), + vabsq_f64(self.0[2]), vabsq_f64(self.0[3]), + ]) + } + } + + #[inline(always)] + pub fn sqrt(self) -> Self { + unsafe { + Self([ + vsqrtq_f64(self.0[0]), vsqrtq_f64(self.0[1]), + vsqrtq_f64(self.0[2]), vsqrtq_f64(self.0[3]), + ]) + } + } + + #[inline(always)] + pub fn round(self) -> Self { + unsafe { + Self([ + vrndnq_f64(self.0[0]), vrndnq_f64(self.0[1]), + vrndnq_f64(self.0[2]), vrndnq_f64(self.0[3]), + ]) + } + } + + #[inline(always)] + pub fn floor(self) -> Self { + unsafe { + Self([ + vrndmq_f64(self.0[0]), vrndmq_f64(self.0[1]), + vrndmq_f64(self.0[2]), vrndmq_f64(self.0[3]), + ]) + } + } + + #[inline(always)] + pub fn mul_add(self, b: Self, c: Self) -> Self { + unsafe { + Self([ + vfmaq_f64(c.0[0], self.0[0], b.0[0]), + vfmaq_f64(c.0[1], self.0[1], b.0[1]), + vfmaq_f64(c.0[2], self.0[2], b.0[2]), + vfmaq_f64(c.0[3], self.0[3], b.0[3]), + ]) + } + } + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + unsafe { + Self([ + vminq_f64(self.0[0], other.0[0]), + vminq_f64(self.0[1], other.0[1]), + vminq_f64(self.0[2], other.0[2]), + vminq_f64(self.0[3], other.0[3]), + ]) + } + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + unsafe { + Self([ + vmaxq_f64(self.0[0], other.0[0]), + vmaxq_f64(self.0[1], other.0[1]), + vmaxq_f64(self.0[2], other.0[2]), + vmaxq_f64(self.0[3], other.0[3]), + ]) + } + } + + #[inline(always)] + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { self.simd_max(lo).simd_min(hi) } + + #[inline(always)] + pub fn simd_ge(self, other: Self) -> F64Mask8 { + let a = self.to_array(); let b = other.to_array(); + let mut bits: u8 = 0; for i in 0..8 { if a[i] >= b[i] { bits |= 1 << i; } } + F64Mask8(bits) + } + #[inline(always)] + pub fn simd_le(self, other: Self) -> F64Mask8 { + let a = self.to_array(); let b = other.to_array(); + let mut bits: u8 = 0; for i in 0..8 { if a[i] <= b[i] { bits |= 1 << i; } } + F64Mask8(bits) + } + + #[inline(always)] + pub fn to_bits(self) -> 
U64x8 { + let a = self.to_array(); + let mut o = [0u64; 8]; for i in 0..8 { o[i] = a[i].to_bits(); } U64x8(o) + } + #[inline(always)] + pub fn from_bits(bits: U64x8) -> Self { + let mut o = [0.0f64; 8]; for i in 0..8 { o[i] = f64::from_bits(bits.0[i]); } + Self::from_array(o) + } + } + + impl Add for F64x8 { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + unsafe { + Self([ + vaddq_f64(self.0[0], rhs.0[0]), vaddq_f64(self.0[1], rhs.0[1]), + vaddq_f64(self.0[2], rhs.0[2]), vaddq_f64(self.0[3], rhs.0[3]), + ]) + } + } + } + impl Sub for F64x8 { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + unsafe { + Self([ + vsubq_f64(self.0[0], rhs.0[0]), vsubq_f64(self.0[1], rhs.0[1]), + vsubq_f64(self.0[2], rhs.0[2]), vsubq_f64(self.0[3], rhs.0[3]), + ]) + } + } + } + impl Mul for F64x8 { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + unsafe { + Self([ + vmulq_f64(self.0[0], rhs.0[0]), vmulq_f64(self.0[1], rhs.0[1]), + vmulq_f64(self.0[2], rhs.0[2]), vmulq_f64(self.0[3], rhs.0[3]), + ]) + } + } + } + impl Div for F64x8 { + type Output = Self; + #[inline(always)] + fn div(self, rhs: Self) -> Self { + unsafe { + Self([ + vdivq_f64(self.0[0], rhs.0[0]), vdivq_f64(self.0[1], rhs.0[1]), + vdivq_f64(self.0[2], rhs.0[2]), vdivq_f64(self.0[3], rhs.0[3]), + ]) + } + } + } + impl AddAssign for F64x8 { #[inline(always)] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } + impl SubAssign for F64x8 { #[inline(always)] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } + impl MulAssign for F64x8 { #[inline(always)] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } + impl DivAssign for F64x8 { #[inline(always)] fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } } + impl Neg for F64x8 { + type Output = Self; + #[inline(always)] + fn neg(self) -> Self { + unsafe { + Self([ + vnegq_f64(self.0[0]), vnegq_f64(self.0[1]), + vnegq_f64(self.0[2]), vnegq_f64(self.0[3]), + ]) + } + } + } + impl fmt::Debug for F64x8 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "F64x8({:?})", self.to_array()) + } + } + impl PartialEq for F64x8 { + fn eq(&self, other: &Self) -> bool { self.to_array() == other.to_array() } + } + impl Default for F64x8 { fn default() -> Self { Self::splat(0.0) } } + + #[derive(Copy, Clone, Debug)] + pub struct F64Mask8(pub u8); + impl F64Mask8 { + #[inline(always)] + pub fn select(self, true_val: F64x8, false_val: F64x8) -> F64x8 { + let t = true_val.to_array(); let f = false_val.to_array(); + let mut o = [0.0f64; 8]; + for i in 0..8 { o[i] = if (self.0 >> i) & 1 == 1 { t[i] } else { f[i] }; } + F64x8::from_array(o) + } + } + + // Lowercase aliases (consumer-API parity) + #[allow(non_camel_case_types)] pub type f32x16 = F32x16; + #[allow(non_camel_case_types)] pub type f64x8 = F64x8; +} + +#[cfg(all(target_arch = "aarch64", test))] +mod neon_pair_tests { + use super::aarch64_simd::*; + + #[test] + fn f32x16_neon_load_add_store() { + let a: [f32; 16] = core::array::from_fn(|i| i as f32); + let b: [f32; 16] = core::array::from_fn(|i| (i * 10) as f32); + let va = F32x16::from_slice(&a); + let vb = F32x16::from_slice(&b); + let vc = va + vb; + let mut out = [0.0f32; 16]; + vc.copy_to_slice(&mut out); + for i in 0..16 { + assert_eq!(out[i], (i + i * 10) as f32); + } + } + + #[test] + fn f32x16_neon_mul_add() { + let a = F32x16::splat(2.0); + let b = F32x16::splat(3.0); + let c = F32x16::splat(1.0); + let r = a.mul_add(b, c).to_array(); + for &v in &r 
{ assert_eq!(v, 7.0); } + } + + #[test] + fn f32x16_neon_reduce_sum() { + let v = F32x16::from_array(core::array::from_fn(|i| (i + 1) as f32)); + // sum 1..=16 = 136 + assert_eq!(v.reduce_sum(), 136.0); + } + + #[test] + fn f64x8_neon_load_add_store() { + let a: [f64; 8] = core::array::from_fn(|i| i as f64); + let b: [f64; 8] = core::array::from_fn(|i| (i * 10) as f64); + let va = F64x8::from_slice(&a); + let vb = F64x8::from_slice(&b); + let vc = va + vb; + let mut out = [0.0f64; 8]; + vc.copy_to_slice(&mut out); + for i in 0..8 { + assert_eq!(out[i], (i + i * 10) as f64); + } + } + + #[test] + fn f64x8_neon_mul_add_reduce() { + let a = F64x8::splat(2.0); + let b = F64x8::splat(3.0); + let c = F64x8::splat(1.0); + let r = a.mul_add(b, c); + // 8 lanes × 7.0 = 56.0 + assert_eq!(r.reduce_sum(), 56.0); + } +} + // ═══════════════════════════════════════════════════════════════════════════ // Tests (run on x86 as compile-check, actual NEON tests need aarch64) // ═══════════════════════════════════════════════════════════════════════════