5 changes: 5 additions & 0 deletions src/lib.rs
@@ -252,6 +252,11 @@ pub mod simd_neon;
#[allow(clippy::all, missing_docs, dead_code, unused_variables, unused_imports)]
pub mod simd_wasm;

/// Slice-level integer SIMD ops (i8/i16) — `add_i8`, `dot_i8`, `min_i8`, …
#[cfg(feature = "std")]
#[allow(missing_docs)]
pub mod simd_int_ops;

/// Pluggable linear algebra backends (native SIMD, MKL, OpenBLAS).
#[cfg(feature = "std")]
pub mod backend;
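The new `simd_int_ops` module is only declared here; its doc comment names slice-level `add_i8`, `dot_i8`, and `min_i8` but the actual signatures are not part of this diff. As a rough illustration of what such slice-level semantics look like, here is a standalone sketch with assumed, stand-in signatures (not the crate's real functions):

```rust
// Scalar reference semantics for the kind of slice-level i8 ops the new
// `simd_int_ops` module advertises. Stand-ins only; real signatures are
// not visible in this diff.
fn add_i8_ref(a: &[i8], b: &[i8], out: &mut [i8]) {
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        *o = x.wrapping_add(y); // lane-wise wrapping add, matching the SIMD types
    }
}

fn dot_i8_ref(a: &[i8], b: &[i8]) -> i32 {
    // Accumulate in i32 so many i8*i8 lane products cannot overflow.
    a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum()
}

fn min_i8_ref(a: &[i8], b: &[i8], out: &mut [i8]) {
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        *o = x.min(y);
    }
}

fn main() {
    let a = [1i8, -2, 127, 0];
    let b = [1i8, 5, 1, -3];
    let mut out = [0i8; 4];
    add_i8_ref(&a, &b, &mut out);
    assert_eq!(out, [2, 3, -128, -3]); // 127 + 1 wraps to -128
    assert_eq!(dot_i8_ref(&a, &b), 118);
    min_i8_ref(&a, &b, &mut out);
    assert_eq!(out, [1, -2, 1, -3]);
}
```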
72 changes: 69 additions & 3 deletions src/simd.rs
@@ -190,12 +190,14 @@ pub const PREFERRED_I16_LANES: usize = 16;

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
pub use crate::simd_avx512::{
// 256-bit (AVX2 baseline, __m256/__m256d)
F32x8, F64x4, f32x8, f64x4,
// 256-bit (AVX2 baseline, __m256/__m256d/__m256i)
F32x8, F64x4, I8x32, I16x16, f32x8, f64x4, i8x32, i16x16,
// 512-bit (native AVX-512, __m512/__m512d/__m512i)
F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
I8x64, I16x32,
F32Mask16, F64Mask8,
f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8,
i8x64, i16x32,
};

// BF16 types + batch conversion (always available — scalar fallback built in)
Expand Down Expand Up @@ -223,13 +225,15 @@ pub use crate::simd_avx512::{
pub use crate::simd_avx512::{BF16x16, BF16x8};

#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
pub use crate::simd_avx512::{F32x8, F64x4, f32x8, f64x4};
pub use crate::simd_avx512::{F32x8, F64x4, I8x32, I16x16, f32x8, f64x4, i8x32, i16x16};

#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
pub use crate::simd_avx2::{
F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
I8x64, I16x32,
F32Mask16, F64Mask8,
f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8,
i8x64, i16x32,
};

// ============================================================================
@@ -630,6 +634,62 @@ pub(crate) mod scalar {
impl_int_type!(U32x16, u32, 16, 0u32);
impl_int_type!(U64x8, u64, 8, 0u64);

// I8/I16 SIMD types (scalar fallback)
impl_int_type!(I8x64, i8, 64, 0i8);
impl_int_type!(I8x32, i8, 32, 0i8);
impl_int_type!(I16x32, i16, 32, 0i16);
impl_int_type!(I16x16, i16, 16, 0i16);

// I8x64 / I8x32 / I16x32 / I16x16 — AVX-512BW-style methods (scalar shape)
impl I8x64 {
#[inline(always)] pub fn zero() -> Self { Self([0i8; 64]) }
#[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
#[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
#[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
#[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
#[inline(always)] pub fn cmp_gt(self, other: Self) -> u64 {
let mut m: u64 = 0;
for i in 0..64 { if self.0[i] > other.0[i] { m |= 1u64 << i; } }
m
}
}
impl I8x32 {
#[inline(always)] pub fn zero() -> Self { Self([0i8; 32]) }
#[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
#[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
#[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
#[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i8; 32]; for i in 0..32 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
#[inline(always)] pub fn cmp_gt(self, other: Self) -> u32 {
let mut m: u32 = 0;
for i in 0..32 { if self.0[i] > other.0[i] { m |= 1u32 << i; } }
m
}
}
impl I16x32 {
#[inline(always)] pub fn zero() -> Self { Self([0i16; 32]) }
#[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
#[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
#[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
#[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
#[inline(always)] pub fn cmp_gt(self, other: Self) -> u32 {
let mut m: u32 = 0;
for i in 0..32 { if self.0[i] > other.0[i] { m |= 1u32 << i; } }
m
}
}
impl I16x16 {
#[inline(always)] pub fn zero() -> Self { Self([0i16; 16]) }
#[inline(always)] pub fn add(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
#[inline(always)] pub fn sub(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
#[inline(always)] pub fn min(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
#[inline(always)] pub fn max(self, other: Self) -> Self { let mut o = [0i16; 16]; for i in 0..16 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
#[inline(always)] pub fn cmp_gt(self, other: Self) -> u16 {
let mut m: u16 = 0;
for i in 0..16 { if self.0[i] > other.0[i] { m |= 1u16 << i; } }
m
}
}

// Extra methods for U16x32 (widen/narrow, shift, multiply)
impl U16x32 {
#[inline(always)]
@@ -1012,6 +1072,10 @@ pub(crate) mod scalar {
#[allow(non_camel_case_types)] pub type u64x8 = U64x8;
#[allow(non_camel_case_types)] pub type f32x8 = F32x8;
#[allow(non_camel_case_types)] pub type f64x4 = F64x4;
#[allow(non_camel_case_types)] pub type i8x64 = I8x64;
#[allow(non_camel_case_types)] pub type i8x32 = I8x32;
#[allow(non_camel_case_types)] pub type i16x32 = I16x32;
#[allow(non_camel_case_types)] pub type i16x16 = I16x16;
}

// aarch64: F32x16/F64x8 come from the real NEON paired-load implementation
@@ -1036,9 +1100,11 @@ pub use scalar::{
pub use scalar::{
F32x16, F64x8, U8x64, I32x16, I64x8, U16x32, U32x16, U64x8,
F32x8, F64x4,
I8x64, I8x32, I16x32, I16x16,
F32Mask16, F64Mask8,
f32x16, f64x8, u8x64, i32x16, i64x8, u32x16, u64x8,
f32x8, f64x4,
i8x64, i8x32, i16x32, i16x16,
};

// Scalar BF16 conversion — always available on all platforms
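The `cmp_gt` methods added in the scalar fallback return a per-lane bitmask: bit i of the result is set when lane i of `self` is greater than lane i of `other`. A small standalone check of that convention, mirroring the scalar code above rather than calling the crate itself:

```rust
// Standalone illustration of the lane-mask convention used by the new
// `cmp_gt` methods: lane i sets bit i of the returned integer mask.
fn cmp_gt_mask_i8x32(a: &[i8; 32], b: &[i8; 32]) -> u32 {
    let mut m: u32 = 0;
    for i in 0..32 {
        if a[i] > b[i] {
            m |= 1u32 << i;
        }
    }
    m
}

fn main() {
    let mut a = [0i8; 32];
    let b = [0i8; 32];
    a[0] = 1;  // lane 0 greater  -> bit 0 set
    a[31] = 5; // lane 31 greater -> bit 31 set
    let m = cmp_gt_mask_i8x32(&a, &b);
    assert_eq!(m, (1u32 << 0) | (1u32 << 31));
    println!("{m:#010x}"); // prints 0x80000001
}
```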
49 changes: 49 additions & 0 deletions src/simd_avx2.rs
@@ -8,6 +8,10 @@

use crate::simd_avx512::{f32x8, f64x4};

// AVX2-native I8x32 / I16x16 live in simd_avx512.rs (256-bit __m256i types).
// Re-export so consumers see a unified `crate::simd_avx2::I8x32` symbol.
pub use crate::simd_avx512::{I8x32, I16x16, i8x32, i16x16};

// ============================================================================
// AVX2 lane counts (half of AVX-512)
// ============================================================================
@@ -772,6 +776,47 @@ macro_rules! avx2_int_type {
}

avx2_int_type!(U8x64, u8, 64, 0u8);
avx2_int_type!(I8x64, i8, 64, 0i8);
avx2_int_type!(I16x32, i16, 32, 0i16);

// I8x64 / I16x32: AVX2 scalar polyfill — methods matching the AVX-512BW API
impl I8x64 {
#[inline(always)]
pub fn zero() -> Self { Self([0i8; 64]) }
#[inline(always)]
pub fn add(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
#[inline(always)]
pub fn sub(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
#[inline(always)]
pub fn min(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
#[inline(always)]
pub fn max(self, other: Self) -> Self { let mut o = [0i8; 64]; for i in 0..64 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
#[inline(always)]
pub fn cmp_gt(self, other: Self) -> u64 {
let mut m: u64 = 0;
for i in 0..64 { if self.0[i] > other.0[i] { m |= 1u64 << i; } }
m
}
}

impl I16x32 {
#[inline(always)]
pub fn zero() -> Self { Self([0i16; 32]) }
#[inline(always)]
pub fn add(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_add(other.0[i]); } Self(o) }
#[inline(always)]
pub fn sub(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].wrapping_sub(other.0[i]); } Self(o) }
#[inline(always)]
pub fn min(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].min(other.0[i]); } Self(o) }
#[inline(always)]
pub fn max(self, other: Self) -> Self { let mut o = [0i16; 32]; for i in 0..32 { o[i] = self.0[i].max(other.0[i]); } Self(o) }
#[inline(always)]
pub fn cmp_gt(self, other: Self) -> u32 {
let mut m: u32 = 0;
for i in 0..32 { if self.0[i] > other.0[i] { m |= 1u32 << i; } }
m
}
}

// ── U8x64 byte-level operations (scalar fallback for AVX2 tier) ──────────
// These match the AVX-512 U8x64 methods in simd_avx512.rs.
@@ -1007,6 +1052,10 @@ pub type i64x8 = I64x8;
pub type u32x16 = U32x16;
#[allow(non_camel_case_types)]
pub type u64x8 = U64x8;
#[allow(non_camel_case_types)]
pub type i8x64 = I8x64;
#[allow(non_camel_case_types)]
pub type i16x32 = I16x32;

#[cfg(test)]
mod tests {
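The re-export comment at the top of simd_avx2.rs is what keeps downstream imports stable: the concrete 256-bit types live in simd_avx512.rs, and simd_avx2.rs simply re-exports them. A minimal sketch of that pattern with illustrative stand-in modules (not the crate's real items):

```rust
// Sketch of the re-export pattern described in simd_avx2.rs: one module
// defines the concrete type, the other re-exports it so consumers can use
// either path. Names here are stand-ins for illustration only.
mod simd_avx512 {
    #[derive(Clone, Copy, Debug)]
    pub struct I8x32(pub [i8; 32]); // concrete 256-bit-wide type lives here
}

mod simd_avx2 {
    // Re-export so consumers can write `simd_avx2::I8x32` uniformly.
    pub use super::simd_avx512::I8x32;
}

fn main() {
    // Either path names the same type.
    let a: simd_avx2::I8x32 = simd_avx512::I8x32([0i8; 32]);
    println!("{:?}", &a.0[..4]);
}
```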