From ec33c21f4b826b06a6f501d8f90d433c2196d5de Mon Sep 17 00:00:00 2001
From: Claude
Date: Sun, 29 Mar 2026 09:06:12 +0000
Subject: [PATCH] refactor(burn): symlink overlay — 15K upstream lines → 35
 symlinks + 2 files
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replace 11,775 lines of copied upstream burn-ndarray source with:

- Git submodule: crates/burn/upstream/ (pinned at ed72d2b)
- 35 symlinks: crates/burn/src/*.rs → upstream/crates/burn-ndarray/src/
- 2 real files: ops/tensor.rs (our SIMD wires), ops/activation.rs (fused sigmoid)

Our total owned code: 846 lines (801 tensor + 45 activation).
Upstream code: 0 lines tracked (lives in submodule).
The repo shrinks by ~11K lines while keeping the same functionality.

Build: cargo check --manifest-path crates/burn/Cargo.toml
Test:  cargo test --manifest-path crates/burn/Cargo.toml (30 tests pass)

https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
---
 .gitmodules                                 |    3 +
 crates/burn/src/backend.rs                  |  223 +--
 crates/burn/src/element.rs                  |  208 +--
 crates/burn/src/lib.rs                      |   30 +-
 crates/burn/src/ops/adaptive_avgpool.rs     |  104 +-
 crates/burn/src/ops/avgpool.rs              |  173 +--
 crates/burn/src/ops/base.rs                 | 1449 +------------------
 crates/burn/src/ops/bool_tensor.rs          |  242 +---
 crates/burn/src/ops/conv.rs                 |  575 +-------
 crates/burn/src/ops/deform_conv.rs          |  663 +--------
 crates/burn/src/ops/grid_sample.rs          |  215 +--
 crates/burn/src/ops/int_tensor.rs           |  510 +------
 crates/burn/src/ops/interpolate.rs          |  398 +----
 crates/burn/src/ops/macros.rs               |  108 +-
 crates/burn/src/ops/matmul.rs               |  363 +----
 crates/burn/src/ops/maxpool.rs              |  248 +---
 crates/burn/src/ops/mod.rs                  |   25 +-
 crates/burn/src/ops/module.rs               |  382 +----
 crates/burn/src/ops/padding.rs              |   73 +-
 crates/burn/src/ops/qtensor.rs              |  354 +----
 crates/burn/src/ops/quantization.rs         |  219 +--
 crates/burn/src/ops/simd/avgpool.rs         |  444 +-----
 crates/burn/src/ops/simd/base.rs            |  116 +-
 crates/burn/src/ops/simd/binary.rs          |  300 +---
 crates/burn/src/ops/simd/binary_elemwise.rs |  420 +-----
 crates/burn/src/ops/simd/cmp.rs             |  375 +----
 crates/burn/src/ops/simd/conv.rs            |  495 +------
 crates/burn/src/ops/simd/maxpool.rs         |  395 +----
 crates/burn/src/ops/simd/mod.rs             |   11 +-
 crates/burn/src/ops/simd/unary.rs           |  235 +--
 crates/burn/src/ops/transaction.rs          |   14 +-
 crates/burn/src/parallel.rs                 |   77 +-
 crates/burn/src/rand.rs                     |   37 +-
 crates/burn/src/sharing.rs                  |   20 +-
 crates/burn/src/storage.rs                  |  507 +------
 crates/burn/src/tensor.rs                   |  956 +-----------
 crates/burn/upstream                        |    1 +
 37 files changed, 39 insertions(+), 10929 deletions(-)
 create mode 100644 .gitmodules
 mode change 100644 => 120000 crates/burn/src/backend.rs
 mode change 100644 => 120000 crates/burn/src/element.rs
 mode change 100644 => 120000 crates/burn/src/lib.rs
 mode change 100644 => 120000 crates/burn/src/ops/adaptive_avgpool.rs
 mode change 100644 => 120000 crates/burn/src/ops/avgpool.rs
 mode change 100644 => 120000 crates/burn/src/ops/base.rs
 mode change 100644 => 120000 crates/burn/src/ops/bool_tensor.rs
 mode change 100644 => 120000 crates/burn/src/ops/conv.rs
 mode change 100644 => 120000 crates/burn/src/ops/deform_conv.rs
 mode change 100644 => 120000 crates/burn/src/ops/grid_sample.rs
 mode change 100644 => 120000 crates/burn/src/ops/int_tensor.rs
 mode change 100644 => 120000 crates/burn/src/ops/interpolate.rs
 mode change 100644 => 120000 crates/burn/src/ops/macros.rs
 mode change 100644 => 120000 crates/burn/src/ops/matmul.rs
 mode change 100644 => 120000 crates/burn/src/ops/maxpool.rs
 mode change 100644 => 120000 crates/burn/src/ops/mod.rs
 mode change 100644 => 120000 crates/burn/src/ops/module.rs
 mode change 100644 => 120000 crates/burn/src/ops/padding.rs
 mode change 100644 => 120000 crates/burn/src/ops/qtensor.rs
 mode change 100644 => 120000 crates/burn/src/ops/quantization.rs
 mode change 100644 => 120000 crates/burn/src/ops/simd/avgpool.rs
 mode change 100644 => 120000 crates/burn/src/ops/simd/base.rs
 mode change 100644 => 120000 crates/burn/src/ops/simd/binary.rs
 mode change 100644 => 120000 crates/burn/src/ops/simd/binary_elemwise.rs
 mode change 100644 => 120000 crates/burn/src/ops/simd/cmp.rs
 mode change 100644 => 120000 crates/burn/src/ops/simd/conv.rs
 mode change 100644 => 120000 crates/burn/src/ops/simd/maxpool.rs
 mode change 100644 => 120000 crates/burn/src/ops/simd/mod.rs
 mode change 100644 => 120000 crates/burn/src/ops/simd/unary.rs
 mode change 100644 => 120000 crates/burn/src/ops/transaction.rs
 mode change 100644 => 120000 crates/burn/src/parallel.rs
 mode change 100644 => 120000 crates/burn/src/rand.rs
 mode change 100644 => 120000 crates/burn/src/sharing.rs
 mode change 100644 => 120000 crates/burn/src/storage.rs
 mode change 100644 => 120000 crates/burn/src/tensor.rs
 create mode 160000 crates/burn/upstream

diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..5ead8fc9
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "crates/burn/upstream"]
+	path = crates/burn/upstream
+	url = https://github.com/tracel-ai/burn.git
diff --git a/crates/burn/src/backend.rs b/crates/burn/src/backend.rs
deleted file mode 100644
index 6a27a9fd..00000000
--- a/crates/burn/src/backend.rs
+++ /dev/null
@@ -1,222 +0,0 @@
-use crate::rand::NdArrayRng;
-use crate::{NdArrayQTensor, NdArrayTensor};
-use crate::{
-    SharedArray,
-    element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
-};
-use alloc::string::String;
-use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue};
-use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
-use burn_backend::{Backend, DType, DeviceId, DeviceOps};
-use burn_ir::{BackendIr, HandleKind, TensorHandle};
-use burn_std::BoolStore;
-use burn_std::stub::Mutex;
-use core::marker::PhantomData;
-use rand::SeedableRng;
-
-pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);
-
-/// The device type for the ndarray backend.
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
-pub enum NdArrayDevice {
-    /// The CPU device.
-    #[default]
-    Cpu,
-}
-
-impl DeviceOps for NdArrayDevice {}
-
-impl burn_backend::Device for NdArrayDevice {
-    fn from_id(_device_id: DeviceId) -> Self {
-        Self::Cpu
-    }
-
-    fn to_id(&self) -> DeviceId {
-        DeviceId {
-            type_id: 0,
-            index_id: 0,
-        }
-    }
-}
-
-/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
-///
-/// This backend is compatible with CPUs and can be compiled for almost any platform, including
-/// `wasm`, `arm`, and `x86`.
-#[derive(Clone, Copy, Default, Debug)] -pub struct NdArray -where - NdArrayTensor: From>, - NdArrayTensor: From>, -{ - _e: PhantomData, - _i: PhantomData, - _q: PhantomData, -} - -impl Backend for NdArray -where - NdArrayTensor: From>, - NdArrayTensor: From>, -{ - type Device = NdArrayDevice; - - type FloatTensorPrimitive = NdArrayTensor; - type FloatElem = E; - - type IntTensorPrimitive = NdArrayTensor; - type IntElem = I; - - type BoolTensorPrimitive = NdArrayTensor; - type BoolElem = bool; - - type QuantizedTensorPrimitive = NdArrayQTensor; - - fn ad_enabled(_device: &Self::Device) -> bool { - false - } - - fn name(_device: &Self::Device) -> String { - String::from("ndarray") - } - - fn seed(_device: &Self::Device, seed: u64) { - let rng = NdArrayRng::seed_from_u64(seed); - let mut seed = SEED.lock().unwrap(); - *seed = Some(rng); - } - - fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet { - match dtype { - DType::F64 - | DType::F32 - | DType::Flex32 - | DType::I64 - | DType::I32 - | DType::I16 - | DType::I8 - | DType::U64 - | DType::U32 - | DType::U16 - | DType::U8 - | DType::Bool(BoolStore::Native) => burn_backend::DTypeUsage::general(), - DType::F16 | DType::BF16 | DType::Bool(_) => burn_backend::DTypeUsageSet::empty(), - DType::QFloat(scheme) => { - match scheme { - QuantScheme { - level: QuantLevel::Tensor | QuantLevel::Block(_), - mode: QuantMode::Symmetric, - #[cfg(not(feature = "export_tests"))] - value: QuantValue::Q8F | QuantValue::Q8S, - // For tests, "native" sub-byte quant serves as a reference for value equality. - // Values are stored as i8 regardless. - #[cfg(feature = "export_tests")] - value: - QuantValue::Q8F - | QuantValue::Q8S - | QuantValue::Q4F - | QuantValue::Q4S - | QuantValue::Q2F - | QuantValue::Q2S, - store: QuantStore::Native, - .. 
- } => burn_backend::DTypeUsage::general(), - _scheme => burn_backend::DTypeUsageSet::empty(), - } - } - } - } - - fn device_count(_: u16) -> usize { - 1 - } -} - -impl BackendIr for NdArray -where - NdArrayTensor: From>, - NdArrayTensor: From>, -{ - type Handle = HandleKind; - - fn float_tensor(handle: TensorHandle) -> FloatTensor { - match handle.handle { - HandleKind::Float(handle) => handle, - _ => panic!("Expected float handle, got {}", handle.handle.name()), - } - } - - fn int_tensor(handle: TensorHandle) -> IntTensor { - match handle.handle { - HandleKind::Int(handle) => handle, - _ => panic!("Expected int handle, got {}", handle.handle.name()), - } - } - - fn bool_tensor(handle: TensorHandle) -> BoolTensor { - match handle.handle { - HandleKind::Bool(handle) => handle, - _ => panic!("Expected bool handle, got {}", handle.handle.name()), - } - } - - fn quantized_tensor(handle: TensorHandle) -> QuantizedTensor { - match handle.handle { - HandleKind::Quantized(handle) => handle, - _ => panic!("Expected quantized handle, got {}", handle.handle.name()), - } - } - - fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle { - HandleKind::Float(tensor) - } - - fn int_tensor_handle(tensor: IntTensor) -> Self::Handle { - HandleKind::Int(tensor) - } - - fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle { - HandleKind::Bool(tensor) - } - - fn quantized_tensor_handle(tensor: QuantizedTensor) -> Self::Handle { - HandleKind::Quantized(tensor) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use burn_backend::QTensorPrimitive; - - #[test] - fn should_support_dtypes() { - type B = NdArray; - let device = Default::default(); - - assert!(B::supports_dtype(&device, DType::F64)); - assert!(B::supports_dtype(&device, DType::F32)); - assert!(B::supports_dtype(&device, DType::Flex32)); - assert!(B::supports_dtype(&device, DType::I64)); - assert!(B::supports_dtype(&device, DType::I32)); - assert!(B::supports_dtype(&device, DType::I16)); - assert!(B::supports_dtype(&device, DType::I8)); - assert!(B::supports_dtype(&device, DType::U64)); - assert!(B::supports_dtype(&device, DType::U32)); - assert!(B::supports_dtype(&device, DType::U16)); - assert!(B::supports_dtype(&device, DType::U8)); - assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native))); - assert!(B::supports_dtype( - &device, - DType::QFloat(NdArrayQTensor::default_scheme()) - )); - - assert!(!B::supports_dtype(&device, DType::F16)); - assert!(!B::supports_dtype(&device, DType::BF16)); - // QuantStore::U32 not supported - assert!(!B::supports_dtype( - &device, - DType::QFloat(QuantScheme::default()) - )); - } -} diff --git a/crates/burn/src/backend.rs b/crates/burn/src/backend.rs new file mode 120000 index 00000000..dc799977 --- /dev/null +++ b/crates/burn/src/backend.rs @@ -0,0 +1 @@ +../upstream/crates/burn-ndarray/src/backend.rs \ No newline at end of file diff --git a/crates/burn/src/element.rs b/crates/burn/src/element.rs deleted file mode 100644 index 8485352e..00000000 --- a/crates/burn/src/element.rs +++ /dev/null @@ -1,207 +0,0 @@ -use burn_backend::Element; -use num_traits::Signed; - -#[cfg(not(feature = "std"))] -#[allow(unused_imports)] -use num_traits::Float; - -use num_traits::Pow; - -use libm::{log1p, log1pf}; - -/// A float element for ndarray backend. -pub trait FloatNdArrayElement: NdArrayElement + Signed + core::cmp::PartialOrd -where - Self: Sized, -{ -} - -/// An int element for ndarray backend. 
-pub trait IntNdArrayElement: NdArrayElement + core::cmp::PartialOrd {} - -/// A general element for ndarray backend. -pub trait NdArrayElement: - Element - + ndarray::LinalgScalar - + ndarray::ScalarOperand - + ExpElement - + AddAssignElement - + num_traits::FromPrimitive - + core::ops::AddAssign - + core::cmp::PartialEq - + core::ops::Rem -{ -} - -/// A element for ndarray backend that supports exp ops. -pub trait ExpElement { - /// Exponent - fn exp_elem(self) -> Self; - /// Log - fn log_elem(self) -> Self; - /// Log1p - fn log1p_elem(self) -> Self; - /// Powf - fn powf_elem(self, value: f32) -> Self; - /// Powi - fn powi_elem(self, value: i32) -> Self; - /// Sqrt - fn sqrt_elem(self) -> Self; - /// Abs - fn abs_elem(self) -> Self; -} - -/// The addition assignment operator implemented for ndarray elements. -pub trait AddAssignElement { - /// Performs the addition assignment operation. - /// - /// For `bool`, this corresponds to logical OR assignment. - fn add_assign(&mut self, rhs: Rhs); -} - -impl AddAssignElement for E { - fn add_assign(&mut self, rhs: Self) { - *self += rhs; - } -} - -impl AddAssignElement for bool { - fn add_assign(&mut self, rhs: Self) { - *self = *self || rhs; // logical OR for bool - } -} - -/// A quantized element for the ndarray backend. -pub trait QuantElement: NdArrayElement {} - -impl QuantElement for i8 {} - -impl FloatNdArrayElement for f64 {} -impl FloatNdArrayElement for f32 {} - -impl IntNdArrayElement for i64 {} -impl IntNdArrayElement for i32 {} -impl IntNdArrayElement for i16 {} -impl IntNdArrayElement for i8 {} - -impl IntNdArrayElement for u64 {} -impl IntNdArrayElement for u32 {} -impl IntNdArrayElement for u16 {} -impl IntNdArrayElement for u8 {} - -macro_rules! make_float { - ( - $ty:ty, - $log1p:expr - ) => { - impl NdArrayElement for $ty {} - - #[allow(clippy::cast_abs_to_unsigned)] - impl ExpElement for $ty { - #[inline(always)] - fn exp_elem(self) -> Self { - self.exp() - } - - #[inline(always)] - fn log_elem(self) -> Self { - self.ln() - } - - #[inline(always)] - fn log1p_elem(self) -> Self { - $log1p(self) - } - - #[inline(always)] - fn powf_elem(self, value: f32) -> Self { - self.pow(value) - } - - #[inline(always)] - fn powi_elem(self, value: i32) -> Self { - #[cfg(feature = "std")] - let val = self.powi(value); - - #[cfg(not(feature = "std"))] - let val = Self::powf_elem(self, value as f32); - - val - } - - #[inline(always)] - fn sqrt_elem(self) -> Self { - self.sqrt() - } - - #[inline(always)] - fn abs_elem(self) -> Self { - self.abs() - } - } - }; -} -macro_rules! 
make_int { - ( - $ty:ty, - $abs:expr - ) => { - impl NdArrayElement for $ty {} - - #[allow(clippy::cast_abs_to_unsigned)] - impl ExpElement for $ty { - #[inline(always)] - fn exp_elem(self) -> Self { - (self as f32).exp() as $ty - } - - #[inline(always)] - fn log_elem(self) -> Self { - (self as f32).ln() as $ty - } - - #[inline(always)] - fn log1p_elem(self) -> Self { - log1pf(self as f32) as $ty - } - - #[inline(always)] - fn powf_elem(self, value: f32) -> Self { - (self as f32).pow(value) as $ty - } - - #[inline(always)] - fn powi_elem(self, value: i32) -> Self { - #[cfg(feature = "std")] - let val = f32::powi(self as f32, value) as $ty; - - #[cfg(not(feature = "std"))] - let val = Self::powf_elem(self, value as f32); - - val - } - - #[inline(always)] - fn sqrt_elem(self) -> Self { - (self as f32).sqrt() as $ty - } - - #[inline(always)] - fn abs_elem(self) -> Self { - $abs(self) - } - } - }; -} - -make_float!(f64, log1p); -make_float!(f32, log1pf); - -make_int!(i64, i64::wrapping_abs); -make_int!(i32, i32::wrapping_abs); -make_int!(i16, i16::wrapping_abs); -make_int!(i8, i8::wrapping_abs); -make_int!(u64, |x| x); -make_int!(u32, |x| x); -make_int!(u16, |x| x); -make_int!(u8, |x| x); diff --git a/crates/burn/src/element.rs b/crates/burn/src/element.rs new file mode 120000 index 00000000..c010ea9e --- /dev/null +++ b/crates/burn/src/element.rs @@ -0,0 +1 @@ +../upstream/crates/burn-ndarray/src/element.rs \ No newline at end of file diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs deleted file mode 100644 index 34a46255..00000000 --- a/crates/burn/src/lib.rs +++ /dev/null @@ -1,29 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_std)] -#![warn(missing_docs)] -#![cfg_attr(docsrs, feature(doc_cfg))] - -//! Burn ndarray backend. - -#[cfg(any( - feature = "blas-netlib", - feature = "blas-openblas", - feature = "blas-openblas-system", -))] -extern crate blas_src; - -mod backend; -mod element; -mod ops; -mod parallel; -mod rand; -mod sharing; -mod storage; -mod tensor; - -pub use backend::*; -pub use element::*; -pub(crate) use sharing::*; -pub(crate) use storage::*; -pub use tensor::*; - -extern crate alloc; diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs new file mode 120000 index 00000000..1af3555e --- /dev/null +++ b/crates/burn/src/lib.rs @@ -0,0 +1 @@ +../upstream/crates/burn-ndarray/src/lib.rs \ No newline at end of file diff --git a/crates/burn/src/ops/adaptive_avgpool.rs b/crates/burn/src/ops/adaptive_avgpool.rs deleted file mode 100644 index baaee09f..00000000 --- a/crates/burn/src/ops/adaptive_avgpool.rs +++ /dev/null @@ -1,103 +0,0 @@ -use crate::{ - SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, -}; -use burn_backend::ElementConversion; -use ndarray::Array4; - -#[cfg(not(feature = "std"))] -#[allow(unused_imports)] -use num_traits::Float; - -pub(crate) fn adaptive_avg_pool2d( - x: SharedArray, - output_size: [usize; 2], -) -> SharedArray { - let [batch_size, channels, input_height, input_width] = x.shape().try_into().unwrap(); - - let mut output = Array4::from_elem( - (batch_size, channels, output_size[0], output_size[1]), - 0.elem(), - ); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - for h in 0..output_size[0] { - for w in 0..output_size[1] { - let ih_start = start_index(h, output_size[0], input_height); - let ih_end = 
end_index(h, output_size[0], input_height); - let iw_start = start_index(w, output_size[1], input_width); - let iw_end = end_index(w, output_size[1], input_width); - - let mut sum_val: E = 0.elem(); - - for ih in ih_start..ih_end { - for iw in iw_start..iw_end { - sum_val += x[[b, c, ih, iw]]; - } - } - - let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); - output[[b, c, h, w]] = sum_val / count.elem(); - } - } - }) - }); - - output.into_dyn().into_shared() -} - -pub(crate) fn adaptive_avg_pool2d_backward( - x: SharedArray, - grad: SharedArray, -) -> SharedArray { - let [_, _, input_height, input_width] = x.shape().try_into().unwrap(); - let [batch_size, channels, output_height, output_width] = grad.shape().try_into().unwrap(); - - let mut output_grad = - Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem()); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output_grad = unsafe_shared_out.get(); - for oh in 0..output_height { - for ow in 0..output_width { - let ih_start = start_index(oh, output_height, input_height); - let ih_end = end_index(oh, output_height, input_height); - - let iw_start = start_index(ow, output_width, input_width); - let iw_end = end_index(ow, output_width, input_width); - - let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem(); - - for ih in ih_start..ih_end { - for iw in iw_start..iw_end { - output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] / count.elem(); - } - } - } - } - }) - }); - - output_grad.into_dyn().into_shared() -} - -fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { - ((output_size_index as f32 * input_size as f32) / output_size as f32).floor() as usize -} - -fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { - let index = - (((output_size_index + 1) as f32 * input_size as f32) / output_size as f32).ceil() as usize; - - usize::min(index, input_size) -} diff --git a/crates/burn/src/ops/adaptive_avgpool.rs b/crates/burn/src/ops/adaptive_avgpool.rs new file mode 120000 index 00000000..a984e87c --- /dev/null +++ b/crates/burn/src/ops/adaptive_avgpool.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/adaptive_avgpool.rs \ No newline at end of file diff --git a/crates/burn/src/ops/avgpool.rs b/crates/burn/src/ops/avgpool.rs deleted file mode 100644 index 4d015dd9..00000000 --- a/crates/burn/src/ops/avgpool.rs +++ /dev/null @@ -1,172 +0,0 @@ -use crate::{ - SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, -}; - -use burn_backend::ElementConversion; -use burn_backend::ops::conv::calculate_pool_output_size; -use ndarray::Array4; - -pub(crate) fn avg_pool2d( - x: SharedArray, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ceil_mode: bool, -) -> SharedArray { - let [kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); - - let out_height = calculate_pool_output_size( - kernel_height, - stride_height, - padding_height, - 1, - x_height, - ceil_mode, - ); - let out_width = calculate_pool_output_size( - kernel_width, - stride_width, - padding_width, - 1, - x_width, - ceil_mode, - ); - - // 
Padded input bounds (for count_include_pad calculation) - let padded_height = x_height + 2 * padding_height; - let padded_width = x_width + 2 * padding_width; - - let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem()); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - - for oh in 0..out_height { - for ow in 0..out_width { - let mut sum_val: E = 0.elem(); - let mut valid_count = 0usize; - let mut padded_count = 0usize; - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw; - - // Check if within padded bounds (excludes ceil_mode extensions) - if ih < padded_height && iw < padded_width { - padded_count += 1; - - // Check if within valid (non-padding) input bounds - if ih >= padding_height - && ih < x_height + padding_height - && iw >= padding_width - && iw < x_width + padding_width - { - let ih_valid = ih - padding_height; - let iw_valid = iw - padding_width; - sum_val += x[[b, c, ih_valid, iw_valid]]; - valid_count += 1; - } - } - } - } - - // count_include_pad: count positions within padded bounds (not ceil_mode extensions) - // !count_include_pad: count only valid (non-padding) positions - let count: E = if count_include_pad { - (padded_count as i32).elem() - } else { - (valid_count as i32).elem() - }; - - output[[b, c, oh, ow]] = sum_val / count; - } - } - }) - }); - - output.into_dyn().into_shared() -} - -pub(crate) fn avg_pool2d_backward( - x: SharedArray, - grad: SharedArray, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - _ceil_mode: bool, -) -> SharedArray { - let [kernel_height, kernel_width] = kernel_size; - let [stride_height, stride_width] = stride; - let [padding_height, padding_width] = padding; - let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); - let [_batch_size, _channels, out_height, out_width] = grad.shape().try_into().unwrap(); - - // Padded input bounds (for count_include_pad calculation) - let padded_height = x_height + 2 * padding_height; - let padded_width = x_width + 2 * padding_width; - - let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem()); - let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output_grad = unsafe_shared_grad.get(); - - for oh in 0..out_height { - for ow in 0..out_width { - let ih_start_kernel = oh * stride_height; - let iw_start_kernel = ow * stride_width; - - let ih_end_kernel = ih_start_kernel + kernel_height; - let iw_end_kernel = iw_start_kernel + kernel_width; - - // Clip to valid input bounds (for gradient distribution) - let ih_start = usize::max(ih_start_kernel, padding_height); - let iw_start = usize::max(iw_start_kernel, padding_width); - let ih_end = usize::min(ih_end_kernel, x_height + padding_height); - let iw_end = usize::min(iw_end_kernel, x_width + padding_width); - - // Calculate count based on count_include_pad - let count = if count_include_pad { - // Count positions within padded bounds (not ceil_mode extensions) - let ih_start_padded = ih_start_kernel; - let iw_start_padded = iw_start_kernel; - let ih_end_padded = usize::min(ih_end_kernel, padded_height); 
- let iw_end_padded = usize::min(iw_end_kernel, padded_width); - (ih_end_padded - ih_start_padded) * (iw_end_padded - iw_start_padded) - } else { - // Count only valid (non-padding) positions - (ih_end - ih_start) * (iw_end - iw_start) - }; - - for ih in ih_start..ih_end { - for iw in iw_start..iw_end { - let ih = ih - padding_height; - let iw = iw - padding_width; - - output_grad[[b, c, ih, iw]] += - grad[[b, c, oh, ow]] / (count as i32).elem(); - } - } - } - } - }) - }); - - output_grad.into_dyn().into_shared() -} diff --git a/crates/burn/src/ops/avgpool.rs b/crates/burn/src/ops/avgpool.rs new file mode 120000 index 00000000..90d28c7d --- /dev/null +++ b/crates/burn/src/ops/avgpool.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/avgpool.rs \ No newline at end of file diff --git a/crates/burn/src/ops/base.rs b/crates/burn/src/ops/base.rs deleted file mode 100644 index 5d2ce429..00000000 --- a/crates/burn/src/ops/base.rs +++ /dev/null @@ -1,1448 +0,0 @@ -use alloc::{vec, vec::Vec}; -use burn_backend::element::{Element, ElementConversion}; -#[cfg(feature = "simd")] -use burn_backend::{DType, quantization::QuantValue}; -use core::fmt::Debug; -use core::marker::PhantomData; -use ndarray::IntoDimension; -use ndarray::SliceInfo; -use ndarray::Zip; -use ndarray::s; -use ndarray::{Array2, ArrayD}; -use num_traits::Signed; -#[cfg(feature = "simd")] -use paste::paste; - -#[cfg(not(feature = "std"))] -#[allow(unused_imports)] -use num_traits::Float; - -#[cfg(feature = "simd")] -use crate::ops::simd::{ - binary::try_binary_simd, - binary_elemwise::{ - VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecClamp, VecDiv, VecMax, VecMin, VecMul, VecSub, - try_binary_scalar_simd, - }, - cmp::{ - VecEquals, VecGreater, VecGreaterEq, VecLower, VecLowerEq, try_cmp_scalar_simd, - try_cmp_simd, - }, - unary::{RecipVec, VecAbs, VecBitNot, try_unary_simd}, -}; -use crate::reshape; -use crate::{ - IntNdArrayElement, ShapeOps, - ops::macros::{ - cummax_dim, cummin_dim, cumprod_dim, cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim, - }, -}; -use crate::{SharedArray, element::NdArrayElement}; -use burn_backend::ops::unfold::calculate_unfold_shape; -use burn_backend::{Shape, Slice}; -use ndarray::ArrayView; -use ndarray::Axis; -use ndarray::Dim; -use ndarray::IxDyn; -use ndarray::SliceInfoElem; - -pub struct NdArrayOps { - e: PhantomData, -} - -pub(crate) struct NdArrayMathOps { - e: PhantomData, -} - -impl NdArrayOps -where - E: Copy + Debug + Element + crate::AddAssignElement, -{ - pub fn slice(tensor: ArrayView, slices: &[Slice]) -> SharedArray { - let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims()); - tensor.slice_move(slices.as_slice()).to_shared() - } - - pub fn slice_assign( - tensor: SharedArray, - slices: &[Slice], - value: SharedArray, - ) -> SharedArray { - let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims()); - let mut array = tensor.into_owned(); - array.slice_mut(slices.as_slice()).assign(&value); - array.into_shared() - } - - pub fn mask_where( - tensor: SharedArray, - mask: SharedArray, - source: SharedArray, - ) -> SharedArray { - let tensor = tensor.broadcast(mask.dim()).unwrap(); - let source = source.broadcast(mask.dim()).unwrap(); - Zip::from(&tensor) - .and(&mask) - .and(&source) - .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x }) - .into_shared() - } - - pub fn mask_fill(tensor: SharedArray, mask: SharedArray, value: E) -> SharedArray { - // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique - 
let mut output = tensor.into_owned(); - let broadcast_mask = mask.broadcast(output.dim()).unwrap(); - Zip::from(&mut output) - .and(&broadcast_mask) - .for_each(|out, &mask_val| { - if mask_val { - *out = value; - } - }); - output.into_shared() - } - - pub fn gather( - dim: usize, - mut tensor: SharedArray, - mut indices: SharedArray, - ) -> SharedArray { - let ndims = tensor.shape().num_dims(); - if dim != ndims - 1 { - tensor.swap_axes(ndims - 1, dim); - indices.swap_axes(ndims - 1, dim); - } - let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape().into_shape()); - let (size_tensor, size_index) = (shape_tensor[ndims - 1], shape_indices[ndims - 1]); - let batch_size = Self::gather_batch_size(shape_tensor, &shape_indices); - - let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])); - let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])); - let mut output = Array2::from_elem((batch_size, size_index), 0.elem::()); - - for b in 0..batch_size { - let indices = indices.slice(s!(b, ..)); - for (i, index) in indices.iter().enumerate() { - output[[b, i]] = tensor[[b, index.elem::() as usize]]; - } - } - - let mut output = NdArrayOps::reshape(output.into_shared().into_dyn(), shape_indices); - - if dim != ndims - 1 { - output.swap_axes(ndims - 1, dim); - } - - output - } - - pub fn scatter( - dim: usize, - mut tensor: SharedArray, - mut indices: SharedArray, - mut value: SharedArray, - ) -> SharedArray { - let ndims = tensor.shape().num_dims(); - if dim != ndims - 1 { - tensor.swap_axes(ndims - 1, dim); - indices.swap_axes(ndims - 1, dim); - value.swap_axes(ndims - 1, dim); - } - - let (shape_tensor, shape_indices, shape_value) = - (tensor.shape().into_shape(), indices.shape(), value.shape()); - let (size_tensor, size_index, size_value) = ( - shape_tensor[ndims - 1], - shape_indices[ndims - 1], - shape_value[ndims - 1], - ); - let batch_size = Self::gather_batch_size(&shape_tensor, shape_indices); - - if shape_value != shape_indices { - panic!( - "Invalid dimension: the shape of the index tensor should be the same as the value \ - tensor: Index {:?} value {:?}", - shape_indices, shape_value - ); - } - - let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])); - let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])); - let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])); - - for b in 0..batch_size { - let indices = indices.slice(s!(b, ..)); - - for (i, index) in indices.iter().enumerate() { - let index = index.elem::() as usize; - tensor[[b, index]].add_assign(value[[b, i]]); - } - } - - let mut output = NdArrayOps::reshape(tensor.into_shared().into_dyn(), shape_tensor); - if dim != ndims - 1 { - output.swap_axes(ndims - 1, dim); - } - output - } - - fn gather_batch_size(shape_tensor: &[usize], shape_indices: &[usize]) -> usize { - let ndims = shape_tensor.num_dims(); - let mut batch_size = 1; - - for i in 0..ndims - 1 { - if shape_tensor[i] != shape_indices[i] { - panic!( - "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \ - {:?}", - shape_tensor, shape_indices - ); - } - batch_size *= shape_indices[i]; - } - - batch_size - } - - pub fn reshape(tensor: SharedArray, shape: Shape) -> SharedArray { - reshape!( - ty E, - shape shape, - array tensor, - d shape.num_dims() - ) - } - - pub(crate) fn concatenate( - arrays: &[ndarray::ArrayView], - dim: usize, - ) -> SharedArray { - let array = ndarray::concatenate(Axis(dim), arrays) - 
.unwrap() - .into_shared(); - - // Transform column-major layout into row-major (standard) layout. (fix #1053) - // Get shape first (via reference), then pass ownership to avoid clone - let shape = array.shape().into_shape(); - Self::reshape(array, shape) - } - - pub fn cat(tensors: Vec>, dim: usize) -> SharedArray { - let arrays: Vec<_> = tensors.iter().map(|t| t.view()).collect(); - Self::concatenate(&arrays, dim) - } - - #[allow(clippy::wrong_self_convention)] - fn to_slice_args_with_steps( - burn_slices: &[burn_backend::Slice], - ndims: usize, - ) -> Vec { - let mut slices = vec![SliceInfoElem::NewAxis; ndims]; - - for i in 0..ndims { - slices[i] = if i < burn_slices.len() { - let slice = &burn_slices[i]; - - // Check for empty range (would result in no elements) - if let Some(end) = slice.end - && slice.start == end - { - SliceInfoElem::Slice { - start: 0, - end: Some(0), - step: 1, - } - } else { - // Pass slice parameters directly to ndarray - // ndarray handles both positive and negative steps correctly: - // - Positive step: iterates forward from start - // - Negative step: iterates backward from the last element in range - SliceInfoElem::Slice { - start: slice.start, - end: slice.end, - step: slice.step, - } - } - } else { - // Dimension not specified in slices - use full range - SliceInfoElem::Slice { - start: 0, - end: None, - step: 1, - } - } - } - - slices - } - - pub fn swap_dims(mut tensor: SharedArray, dim1: usize, dim2: usize) -> SharedArray { - tensor.swap_axes(dim1, dim2); - - tensor - } - - pub fn permute(tensor: SharedArray, axes: &[usize]) -> SharedArray { - tensor.permuted_axes(axes.into_dimension()) - } - - /// Broadcasts the tensor to the given shape - pub(crate) fn expand(tensor: SharedArray, shape: Shape) -> SharedArray { - tensor - .broadcast(shape.into_dimension()) - .expect("The shapes should be broadcastable") - // need to convert view to owned array because NdArrayTensor expects owned array - // and try_into_owned_nocopy() panics for broadcasted arrays (zero strides) - .into_owned() - .into_shared() - } - - pub fn flip(tensor: SharedArray, axes: &[usize]) -> SharedArray { - let slice_items: Vec<_> = (0..tensor.shape().num_dims()) - .map(|i| { - if axes.contains(&i) { - SliceInfoElem::Slice { - start: 0, - end: None, - step: -1, - } - } else { - SliceInfoElem::Slice { - start: 0, - end: None, - step: 1, - } - } - }) - .collect(); - let slice_info = - SliceInfo::, IxDyn, IxDyn>::try_from(slice_items).unwrap(); - tensor.slice(slice_info).into_owned().into_shared() - } - - /// Unfold windows along a dimension. - /// - /// # Warning - /// - /// This is a copy impl; `ndarray` doesn't expose the layout machinery - /// necessary to build the stride view. - /// - /// Returns a copy of the tensor with all complete windows of size `size` in dimension `dim`; - /// where windows are advanced by `step` at each index. - /// - /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. - /// - /// # Arguments - /// - /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]`` - /// * `dim` - the dimension to unfold. - /// * `size` - the size of each unfolded window. - /// * `step` - the step between each window. - /// - /// # Returns - /// - /// A tensor view with shape ``[pre=..., windows, post=..., size]``. 
- #[allow(unused)] - pub(crate) fn unfold( - tensor: SharedArray, - dim: usize, - size: usize, - step: usize, - ) -> SharedArray { - let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step); - let windows = result_shape[dim]; - - let mut slices = vec![Slice::new(0, None, 1); tensor.shape().len()]; - let new_axis = slices.len(); - - let mut stack = Vec::with_capacity(windows); - for widx in 0..windows { - let start = widx * step; - let end = start + size; - slices[dim] = Slice::new(start as isize, Some(end as isize), 1); - - let mut window_slice = - tensor.slice(Self::to_slice_args_with_steps(&slices, slices.len()).as_slice()); - window_slice.insert_axis_inplace(Axis(new_axis)); - window_slice.swap_axes(dim, new_axis); - - stack.push(window_slice); - } - Self::concatenate(&stack, dim) - } -} - -#[cfg(feature = "simd")] -macro_rules! dispatch_binary_simd { - (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ - paste! { - let simd = match $elem::dtype() { - $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* - _ => Err(($lhs, $rhs)), - }; - match simd { - Ok(out) => return out, - Err(args) => args, - } - } - }}; - ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ - paste! { - let simd = match $elem::dtype() { - $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* - DType::QFloat(strategy) => match strategy.value { - QuantValue::Q8F | QuantValue::Q8S => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), - _ => Err(($lhs, $rhs)), - }, - _ => Err(($lhs, $rhs)), - }; - match simd { - Ok(out) => return out, - Err(args) => args, - } - } - }}; -} - -#[cfg(not(feature = "simd"))] -macro_rules! dispatch_binary_simd { - (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }}; - ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }}; -} - -#[cfg(feature = "simd")] -macro_rules! dispatch_binary_scalar_simd { - (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ - paste! { - let simd = match $elem::dtype() { - $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* - _ => Err($lhs), - }; - match simd { - Ok(out) => return out, - Err(args) => args, - } - } - }}; - ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ - paste! { - let simd = match $elem::dtype() { - $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* - DType::QFloat(strategy) => match strategy.value { - QuantValue::Q8F | QuantValue::Q8S => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), - QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs) - }, - _ => Err($lhs), - }; - match simd { - Ok(out) => return out, - Err(args) => args, - } - } - }}; -} - -#[cfg(not(feature = "simd"))] -macro_rules! dispatch_binary_scalar_simd { - (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }}; - ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }}; -} - -#[cfg(feature = "simd")] -macro_rules! dispatch_cmp_simd { - ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ - paste! 
{ - let simd = match $elem::dtype() { - $(DType::[<$ty:upper>] => try_cmp_simd::<$elem, $ty, $op>($lhs, $rhs),)* - DType::QFloat(strategy) => match strategy.value { - QuantValue::Q8F | QuantValue::Q8S => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs), - QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err(($lhs, $rhs)) - }, - _ => Err(($lhs, $rhs)), - }; - match simd { - Ok(out) => return out, - Err(args) => args, - } - } - }}; -} - -#[cfg(not(feature = "simd"))] -macro_rules! dispatch_cmp_simd { - ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }}; -} - -#[cfg(feature = "simd")] -macro_rules! dispatch_cmp_scalar_simd { - ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ - paste! { - let simd = match $elem::dtype() { - $(DType::[<$ty:upper>] => try_cmp_scalar_simd::<$elem, $ty, $op>($lhs, $rhs),)* - DType::QFloat(strategy) => match strategy.value { - QuantValue::Q8F | QuantValue::Q8S => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs), - QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs) - }, - _ => Err($lhs), - }; - match simd { - Ok(out) => return out, - Err(args) => args, - } - } - }}; -} - -#[cfg(not(feature = "simd"))] -macro_rules! dispatch_cmp_scalar_simd { - ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }}; -} - -#[cfg(feature = "simd")] -macro_rules! dispatch_unary_simd { - ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ - paste! { - let simd = match $elem::dtype() { - $(DType::[<$ty:upper>] => try_unary_simd::<$elem, $elem, $ty, $ty, $op>($lhs),)* - _ => Err($lhs), - }; - match simd { - Ok(out) => return out, - Err(args) => args, - } - } - }}; -} - -#[cfg(not(feature = "simd"))] -macro_rules! 
dispatch_unary_simd { - ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ $lhs }}; -} - -// Helper function to broadcast two tensors to a common shape for comparison operations -// Returns broadcasted views that can be safely zipped -fn broadcast_for_comparison<'a, E: Copy, S1, S2>( - lhs: &'a ndarray::ArrayBase, - rhs: &'a ndarray::ArrayBase, -) -> ( - ndarray::ArrayView<'a, E, ndarray::IxDyn>, - ndarray::ArrayView<'a, E, ndarray::IxDyn>, -) -where - S1: ndarray::Data, - S2: ndarray::Data, -{ - // Get shapes - let lhs_shape = lhs.shape(); - let rhs_shape = rhs.shape(); - - // Compute broadcast shape using ndarray's broadcast compatibility rules - let ndims = lhs_shape.len().max(rhs_shape.len()); - let mut broadcast_shape = vec![1; ndims]; - - for i in 0..ndims { - let lhs_dim = if i < lhs_shape.len() { - lhs_shape[lhs_shape.len() - 1 - i] - } else { - 1 - }; - let rhs_dim = if i < rhs_shape.len() { - rhs_shape[rhs_shape.len() - 1 - i] - } else { - 1 - }; - - if lhs_dim == rhs_dim { - broadcast_shape[ndims - 1 - i] = lhs_dim; - } else if lhs_dim == 1 { - broadcast_shape[ndims - 1 - i] = rhs_dim; - } else if rhs_dim == 1 { - broadcast_shape[ndims - 1 - i] = lhs_dim; - } else { - panic!( - "Incompatible shapes for broadcasting: {:?} and {:?}", - lhs_shape, rhs_shape - ); - } - } - - // Create IxDyn from broadcast shape - let broadcast_dim = ndarray::IxDyn(&broadcast_shape); - - // Broadcast both arrays - let lhs_broadcast = lhs - .broadcast(broadcast_dim.clone()) - .expect("Failed to broadcast lhs"); - let rhs_broadcast = rhs - .broadcast(broadcast_dim) - .expect("Failed to broadcast rhs"); - - (lhs_broadcast, rhs_broadcast) -} - -impl NdArrayMathOps -where - E: Copy + NdArrayElement, -{ - pub fn add(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = dispatch_binary_simd!( - E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 - ); - - let array = &lhs + &rhs; - array.into_shared() - } - - pub fn add_scalar(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_binary_scalar_simd!( - E, - VecAdd, - lhs, - rhs.elem(), - u8, - i8, - u16, - i16, - u32, - i32, - f32, - u64, - i64, - f64 - ); - - let array = lhs + rhs; - array.into_shared() - } - - pub fn sub(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = dispatch_binary_simd!( - E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 - ); - - let array = lhs - rhs; - array.into_shared() - } - - pub fn sub_scalar(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_binary_scalar_simd!( - E, - VecSub, - lhs, - rhs.elem(), - u8, - i8, - u16, - i16, - u32, - i32, - f32, - u64, - i64, - f64 - ); - - let array = lhs - rhs; - array.into_shared() - } - - pub fn mul(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = - dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64); - - let array = lhs * rhs; - array.into_shared() - } - - pub fn mul_scalar(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_binary_scalar_simd!( - noq, - E, - VecMul, - lhs, - rhs.elem(), - u16, - i16, - u32, - i32, - f32, - f64 - ); - - let array = lhs * rhs; - array.into_shared() - } - - pub fn div(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64); - - let array = lhs / rhs; - array.into_shared() - } - - pub fn div_scalar(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64); - - 
let array = lhs / rhs; - array.into_shared() - } - - pub fn remainder(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique - let mut out = lhs.into_owned(); - Zip::from(&mut out).and(&rhs).for_each(|out_elem, &b| { - // out_elem holds lhs value; read it before overwriting with remainder - let a_f = (*out_elem).to_f64(); - let b_f = b.to_f64(); - let r = a_f - b_f * (a_f / b_f).floor(); - *out_elem = r.elem(); - }); - out.into_shared() - } - - pub fn remainder_scalar(lhs: SharedArray, rhs: E) -> SharedArray - where - E: core::ops::Rem, - { - let array = lhs.mapv(|x| ((x % rhs) + rhs) % rhs); - array.into_shared() - } - - pub fn recip(tensor: SharedArray) -> SharedArray { - let tensor = dispatch_unary_simd!(E, RecipVec, tensor, f32); - - let array = tensor.map(|x| 1.elem::() / *x); - array.into_shared() - } - - /// Sum all elements - zero-copy for borrowed storage. - pub fn sum_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { - let sum = view.sum(); - ArrayD::from_elem(IxDyn(&[1]), sum).into_shared() - } - - /// Mean of all elements - zero-copy for borrowed storage. - pub fn mean_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { - let mean = view.mean().unwrap(); - ArrayD::from_elem(IxDyn(&[1]), mean).into_shared() - } - - /// Product of all elements - zero-copy for borrowed storage. - pub fn prod_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { - let prod = view.iter().fold(E::one(), |acc, &x| acc * x); - ArrayD::from_elem(IxDyn(&[1]), prod).into_shared() - } - - pub fn mean_dim(tensor: SharedArray, dim: usize) -> SharedArray { - let ndims = tensor.shape().num_dims(); - match ndims { - d if (1..=6).contains(&d) => keepdim!(dim, tensor, mean), - _ => panic!("Dim not supported {ndims}"), - } - } - - pub fn sum_dim(tensor: SharedArray, dim: usize) -> SharedArray { - let ndims = tensor.shape().num_dims(); - match ndims { - d if (1..=6).contains(&d) => keepdim!(dim, tensor, sum), - _ => panic!("Dim not supported {ndims}"), - } - } - - pub fn prod_dim(tensor: SharedArray, dim: usize) -> SharedArray { - let ndims = tensor.shape().num_dims(); - match ndims { - d if (1..=6).contains(&d) => keepdim!(dim, tensor, prod), - _ => panic!("Dim not supported {ndims}"), - } - } - - pub fn cumsum(tensor: SharedArray, dim: usize) -> SharedArray { - cumsum_dim(tensor, dim) - } - - pub fn cumprod(tensor: SharedArray, dim: usize) -> SharedArray { - cumprod_dim(tensor, dim) - } - - pub fn select( - tensor: SharedArray, - dim: usize, - indices: SharedArray, - ) -> SharedArray { - let array = tensor.select( - Axis(dim), - &indices - .into_iter() - .map(|i| i.elem::() as usize) - .collect::>(), - ); - - array.into_shared() - } - - pub fn select_assign( - tensor: SharedArray, - dim: usize, - indices: SharedArray, - value: SharedArray, - ) -> SharedArray { - let mut output_array = tensor.into_owned(); - - for (index_value, index) in indices.into_iter().enumerate() { - let mut view = output_array.index_axis_mut(Axis(dim), index.elem::() as usize); - let value = value.index_axis(Axis(dim), index_value); - - view.zip_mut_with(&value, |a, b| *a += *b); - } - - output_array.into_shared() - } - - pub(crate) fn elementwise_op( - lhs: SharedArray, - rhs: SharedArray, - var_name: impl FnMut(&E, &OtherE) -> E, - ) -> SharedArray { - let lhs = lhs.broadcast(rhs.dim()).unwrap_or(lhs.view()); - let rhs = rhs.broadcast(lhs.dim()).unwrap_or(rhs.view()); - - Zip::from(lhs).and(rhs).map_collect(var_name).into_shared() - } - - pub(crate) 
fn elementwise_op_scalar( - lhs: SharedArray, - var_name: impl FnMut(E) -> E, - ) -> SharedArray { - lhs.mapv(var_name).into_shared() - } - - pub(crate) fn abs(tensor: SharedArray) -> SharedArray { - let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64); - - tensor.mapv_into(|a| a.abs_elem()).into_shared() - } - - pub(crate) fn equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = dispatch_cmp_simd!( - E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 - ); - - // Use the helper to broadcast both arrays to a common shape - let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); - // Now we can safely zip and compare - Zip::from(&lhs_broadcast) - .and(&rhs_broadcast) - .map_collect(|&lhs, &rhs| lhs == rhs) - .into_shared() - } - - pub(crate) fn equal_elem(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_cmp_scalar_simd!( - E, - VecEquals, - lhs, - rhs.elem(), - u8, - i8, - u16, - i16, - u32, - f32, - i32, - u64, - i64, - f64 - ); - - lhs.mapv(|a| a == rhs).into_shared() - } - - pub(crate) fn sign_op(tensor: SharedArray) -> SharedArray - where - E: Signed, - { - let zero = 0.elem(); - let one = 1.elem::(); - - tensor - .mapv(|x| { - if x == zero { - zero - } else { - match x.is_positive() { - true => one, - false => -one, - } - } - }) - .into_shared() - } -} - -impl NdArrayMathOps -where - E: Copy + NdArrayElement + PartialOrd, -{ - /// Max of all elements - zero-copy for borrowed storage. - pub fn max_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { - let max = view - .iter() - .copied() - .reduce(|a, b| if a > b { a } else { b }) - .expect("Cannot compute max of empty tensor"); - ArrayD::from_elem(IxDyn(&[1]), max).into_shared() - } - - /// Min of all elements - zero-copy for borrowed storage. - pub fn min_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { - let min = view - .iter() - .copied() - .reduce(|a, b| if a < b { a } else { b }) - .expect("Cannot compute min of empty tensor"); - ArrayD::from_elem(IxDyn(&[1]), min).into_shared() - } - - /// Argmax along dimension - zero-copy for borrowed storage. - pub fn argmax_view( - view: ArrayView<'_, E, IxDyn>, - dim: usize, - ) -> SharedArray { - arg_view(view, dim, CmpType::Max) - } - - /// Argmin along dimension - zero-copy for borrowed storage. 
- pub fn argmin_view( - view: ArrayView<'_, E, IxDyn>, - dim: usize, - ) -> SharedArray { - arg_view(view, dim, CmpType::Min) - } - - pub fn cummin(tensor: SharedArray, dim: usize) -> SharedArray { - cummin_dim(tensor, dim) - } - - pub fn cummax(tensor: SharedArray, dim: usize) -> SharedArray { - cummax_dim(tensor, dim) - } - - pub fn argmax( - tensor: SharedArray, - dim: usize, - ) -> SharedArray { - arg(tensor, dim, CmpType::Max) - } - - pub fn argmin( - tensor: SharedArray, - dim: usize, - ) -> SharedArray { - arg(tensor, dim, CmpType::Min) - } - - pub fn clamp_min(tensor: SharedArray, min: E) -> SharedArray { - let mut tensor = dispatch_binary_scalar_simd!( - E, - VecMax, - tensor, - min.elem(), - u8, - i8, - u16, - i16, - u32, - i32, - f32, - u64, - i64, - f64 - ); - - tensor.mapv_inplace(|x| match x < min { - true => min, - false => x, - }); - - tensor - } - - pub fn clamp_max(tensor: SharedArray, max: E) -> SharedArray { - let mut tensor = dispatch_binary_scalar_simd!( - E, - VecMin, - tensor, - max.elem(), - u8, - i8, - u16, - i16, - u32, - i32, - f32, - u64, - i64, - f64 - ); - - tensor.mapv_inplace(|x| match x > max { - true => max, - false => x, - }); - - tensor - } - - pub fn clamp(tensor: SharedArray, min: E, max: E) -> SharedArray { - let mut tensor = dispatch_binary_scalar_simd!( - E, - VecClamp, - tensor, - (min.elem(), max.elem()), - u8, - i8, - u16, - i16, - u32, - i32, - f32, - u64, - i64, - f64 - ); - - tensor.mapv_inplace(|x| match x < min { - true => min, - false => match x > max { - true => max, - false => x, - }, - }); - - tensor - } - - pub(crate) fn greater(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = dispatch_cmp_simd!( - E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 - ); - - // Use the helper to broadcast both arrays to a common shape - let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); - // Now we can safely zip and compare - Zip::from(&lhs_broadcast) - .and(&rhs_broadcast) - .map_collect(|&lhs, &rhs| lhs > rhs) - .into_shared() - } - - pub(crate) fn greater_elem(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_cmp_scalar_simd!( - E, - VecGreater, - lhs, - rhs.elem(), - u8, - i8, - u16, - i16, - u32, - f32, - i32, - u64, - i64, - f64 - ); - - lhs.mapv(|a| a > rhs).into_shared() - } - - pub(crate) fn greater_equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = dispatch_cmp_simd!( - E, - VecGreaterEq, - lhs, - rhs, - u8, - i8, - u16, - i16, - u32, - f32, - i32, - u64, - i64, - f64 - ); - - // Use the helper to broadcast both arrays to a common shape - let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); - // Now we can safely zip and compare - Zip::from(&lhs_broadcast) - .and(&rhs_broadcast) - .map_collect(|&lhs, &rhs| lhs >= rhs) - .into_shared() - } - - pub(crate) fn greater_equal_elem(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_cmp_scalar_simd!( - E, - VecGreaterEq, - lhs, - rhs.elem(), - u8, - i8, - u16, - i16, - u32, - f32, - i32, - u64, - i64, - f64 - ); - - lhs.mapv(|a| a >= rhs).into_shared() - } - - pub(crate) fn lower_equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = dispatch_cmp_simd!( - E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 - ); - - // Use the helper to broadcast both arrays to a common shape - let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); - // Now we can safely zip and compare - Zip::from(&lhs_broadcast) 
- .and(&rhs_broadcast) - .map_collect(|&lhs, &rhs| lhs <= rhs) - .into_shared() - } - - pub(crate) fn lower_equal_elem(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_cmp_scalar_simd!( - E, - VecLowerEq, - lhs, - rhs.elem(), - u8, - i8, - u16, - i16, - u32, - f32, - i32, - u64, - i64, - f64 - ); - - lhs.mapv(|a| a <= rhs).into_shared() - } - - pub(crate) fn lower(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = dispatch_cmp_simd!( - E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 - ); - - // Use the helper to broadcast both arrays to a common shape - let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); - - // Now we can safely zip and compare - Zip::from(&lhs_broadcast) - .and(&rhs_broadcast) - .map_collect(|&lhs, &rhs| lhs < rhs) - .into_shared() - } - - pub(crate) fn lower_elem(lhs: SharedArray, rhs: E) -> SharedArray { - let lhs = dispatch_cmp_scalar_simd!( - E, - VecLower, - lhs, - rhs.elem(), - u8, - i8, - u16, - i16, - u32, - f32, - i32, - u64, - i64, - f64 - ); - - lhs.mapv(|a| a < rhs).into_shared() - } -} - -pub struct NdArrayBitOps(PhantomData); - -impl NdArrayBitOps { - pub(crate) fn bitand(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = - dispatch_binary_simd!(I, VecBitAnd, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64); - - NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { - (a.elem::() & (b.elem::())).elem() - }) - } - - pub(crate) fn bitand_scalar(lhs: SharedArray, rhs: I) -> SharedArray { - let lhs = dispatch_binary_scalar_simd!( - I, - VecBitAnd, - lhs, - rhs.elem(), - i8, - u8, - i16, - u16, - i32, - u32, - i64, - u64 - ); - - NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { - (a.elem::() & rhs.elem::()).elem() - }) - } - - pub(crate) fn bitor(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = - dispatch_binary_simd!(I, VecBitOr, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64); - - NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { - (a.elem::() | (b.elem::())).elem() - }) - } - - pub(crate) fn bitor_scalar(lhs: SharedArray, rhs: I) -> SharedArray { - let lhs = dispatch_binary_scalar_simd!( - I, - VecBitOr, - lhs, - rhs.elem(), - i8, - u8, - i16, - u16, - i32, - u32, - i64, - u64 - ); - - NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { - (a.elem::() | rhs.elem::()).elem() - }) - } - - pub(crate) fn bitxor(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - let (lhs, rhs) = - dispatch_binary_simd!(I, VecBitXor, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64); - - NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { - (a.elem::() ^ (b.elem::())).elem() - }) - } - - pub(crate) fn bitxor_scalar(lhs: SharedArray, rhs: I) -> SharedArray { - let lhs = dispatch_binary_scalar_simd!( - I, - VecBitXor, - lhs, - rhs.elem(), - i8, - u8, - i16, - u16, - i32, - u32, - i64, - u64 - ); - - NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { - (a.elem::() ^ rhs.elem::()).elem() - }) - } - - pub(crate) fn bitnot(tensor: SharedArray) -> SharedArray { - let tensor = - dispatch_unary_simd!(I, VecBitNot, tensor, i8, u8, i16, u16, i32, u32, i64, u64); - - NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::()).elem()) - } -} - -pub struct NdArrayBoolOps; - -// Rust booleans are either `00000000` or `00000001`, so bitwise and/or is fine, but bitwise not would -// produce invalid values. 
-impl NdArrayBoolOps { - pub(crate) fn equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - #[cfg(feature = "simd")] - let (lhs, rhs) = match try_cmp_simd::(lhs, rhs) { - Ok(out) => return out, - Err(args) => args, - }; - - // Use the helper to broadcast both arrays to a common shape - let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); - // Now we can safely zip and compare - Zip::from(&lhs_broadcast) - .and(&rhs_broadcast) - .map_collect(|&lhs, &rhs| lhs == rhs) - .into_shared() - } - - pub(crate) fn equal_elem(lhs: SharedArray, rhs: bool) -> SharedArray { - #[cfg(feature = "simd")] - let lhs = match try_cmp_scalar_simd::(lhs, rhs.elem()) { - Ok(out) => return out, - Err(args) => args, - }; - - lhs.mapv(|a| a == rhs).into_shared() - } - - pub(crate) fn and(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - #[cfg(feature = "simd")] - let (lhs, rhs) = match try_binary_simd::(lhs, rhs) { - Ok(out) => return out, - Err(args) => args, - }; - - // Use the helper to broadcast both arrays to a common shape - let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); - // Now we can safely zip and compare - Zip::from(&lhs_broadcast) - .and(&rhs_broadcast) - .map_collect(|&lhs, &rhs| lhs && rhs) - .into_shared() - } - - pub(crate) fn or(lhs: SharedArray, rhs: SharedArray) -> SharedArray { - #[cfg(feature = "simd")] - let (lhs, rhs) = match try_binary_simd::(lhs, rhs) { - Ok(out) => return out, - Err(args) => args, - }; - - // Use the helper to broadcast both arrays to a common shape - let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); - // Now we can safely zip and compare - Zip::from(&lhs_broadcast) - .and(&rhs_broadcast) - .map_collect(|&lhs, &rhs| lhs || rhs) - .into_shared() - } - - /// Any element is true - zero-copy for borrowed storage. - pub fn any_view(view: ArrayView<'_, bool, IxDyn>) -> bool { - view.iter().any(|&x| x) - } - - /// All elements are true - zero-copy for borrowed storage. - pub fn all_view(view: ArrayView<'_, bool, IxDyn>) -> bool { - view.iter().all(|&x| x) - } -} - -enum CmpType { - Min, - Max, -} - -fn arg( - tensor: SharedArray, - dim: usize, - cmp: CmpType, -) -> SharedArray { - arg_view(tensor.view(), dim, cmp) -} - -/// View-based argmax/argmin - zero-copy for borrowed storage. -fn arg_view( - view: ArrayView<'_, E, IxDyn>, - dim: usize, - cmp: CmpType, -) -> SharedArray { - let mut reshape = view.shape().to_vec(); - reshape[dim] = 1; - - let output = view.map_axis(Axis(dim), |arr| { - // Find the min/max value in the array, and return its index. 
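// The fold below seeds the accumulator with `(arr[0], 0)` and compares with strict
// `<` / `>`, so the first occurrence wins on ties; the winning position is widened
// to i64 and then converted into the requested integer element type.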
- let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| { - let cmp = match cmp { - CmpType::Min => e < &acc.0, - CmpType::Max => e > &acc.0, - }; - - if cmp { (*e, idx) } else { acc } - }); - - (idx as i64).elem() - }); - - let output = output.to_shape(Dim(reshape.as_slice())).unwrap(); - - output.into_shared() -} - -#[cfg(test)] -mod tests { - use burn_backend::TensorData; - - use crate::NdArrayTensor; - - use super::*; - - #[test] - fn should_generate_row_major_layout_for_cat() { - let expected_shape: &[usize] = &[4, 6, 2]; - let expected_strides: &[isize] = &[12, 2, 1]; - let NdArrayTensor::I32(expected_storage) = NdArrayTensor::from_data(TensorData::from([ - [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]], - [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]], - [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]], - [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]], - ])) else { - panic!() - }; - let expected_array = expected_storage.into_shared(); - - let NdArrayTensor::I32(tensor_storage) = NdArrayTensor::from_data(TensorData::from([ - [1, 2, 3, 4, 5, 6], - [7, 8, 9, 10, 11, 12], - [13, 14, 15, 16, 17, 18], - [19, 20, 21, 22, 23, 24], - ])) else { - panic!() - }; - let tensor = tensor_storage.into_shared(); - - // unsqueeze dim on the outermost axis - let array = NdArrayOps::reshape(tensor, Shape::from([4, 6, 1])); - let NdArrayTensor::I32(zeros_storage) = - NdArrayTensor::from_data(TensorData::zeros::([4, 6, 1])) - else { - panic!() - }; - let zeros = zeros_storage.into_shared(); - // make `ndarray` concatenates array on the outermost axis - let array = NdArrayOps::cat([array, zeros].to_vec(), 2); - - assert!(array.is_standard_layout()); - assert_eq!(array.shape(), expected_shape); - assert_eq!(array.strides(), expected_strides); - assert_eq!( - array.into_iter().collect::>(), - expected_array.into_iter().collect::>(), - ); - } -} diff --git a/crates/burn/src/ops/base.rs b/crates/burn/src/ops/base.rs new file mode 120000 index 00000000..7c06d3bf --- /dev/null +++ b/crates/burn/src/ops/base.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/base.rs \ No newline at end of file diff --git a/crates/burn/src/ops/bool_tensor.rs b/crates/burn/src/ops/bool_tensor.rs deleted file mode 100644 index 1d1f26d3..00000000 --- a/crates/burn/src/ops/bool_tensor.rs +++ /dev/null @@ -1,241 +0,0 @@ -// Language -use alloc::vec; -use alloc::vec::Vec; -use burn_backend::Scalar; -use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor}; -use burn_backend::{ - backend::ExecutionError, - ops::BoolTensorOps, - tensor::{BoolTensor, IntTensor}, -}; -use burn_std::{BoolDType, FloatDType, IntDType}; -use ndarray::IntoDimension; - -// Current crate -use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}; -use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor}; -use crate::{ - NdArrayDevice, SharedArray, execute_with_float_out_dtype, execute_with_int_out_dtype, slice, -}; - -// Workspace crates -use burn_backend::{Shape, TensorData, backend::Backend}; - -use super::{NdArrayBoolOps, NdArrayOps}; - -impl BoolTensorOps - for NdArray -where - NdArrayTensor: From>, - NdArrayTensor: From>, -{ - fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { - if !data.dtype.is_bool() { - unimplemented!("Unsupported dtype for `bool_from_data`") - } - NdArrayTensor::from_data(data) - } - - async fn bool_into_data(tensor: NdArrayTensor) -> Result { - Ok(tensor.into_data()) - } - - fn bool_to_device(tensor: 
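// Single-device backend: `NdArrayDevice::Cpu` is the only device, so "moving" a
// tensor between devices is a no-op and the tensor is returned unchanged.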
NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor { - tensor - } - - fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { - NdArrayOps::reshape(tensor.bool(), shape).into() - } - - fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor { - slice!(tensor, slices) - } - - fn bool_into_int(tensor: NdArrayTensor, out_dtype: IntDType) -> NdArrayTensor { - // Use mapv directly instead of collecting to Vec and going through TensorData - execute_with_int_out_dtype!( - out_dtype, - I, - tensor.bool().mapv(|b| b.elem::()).into_shared().into() - ) - } - - fn bool_device(_tensor: &NdArrayTensor) -> as Backend>::Device { - NdArrayDevice::Cpu - } - - fn bool_empty( - shape: Shape, - _device: & as Backend>::Device, - dtype: BoolDType, - ) -> NdArrayTensor { - Self::bool_zeros(shape, _device, dtype) - } - - fn bool_zeros( - shape: Shape, - _device: & as Backend>::Device, - _dtype: BoolDType, - ) -> NdArrayTensor { - let values = vec![false; shape.num_elements()]; - NdArrayTensor::from_data(TensorData::new(values, shape)) - } - - fn bool_ones( - shape: Shape, - _device: & as Backend>::Device, - _dtype: BoolDType, - ) -> NdArrayTensor { - let values = vec![true; shape.num_elements()]; - NdArrayTensor::from_data(TensorData::new(values, shape)) - } - - fn bool_slice_assign( - tensor: NdArrayTensor, - slices: &[burn_backend::Slice], - value: NdArrayTensor, - ) -> NdArrayTensor { - NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into() - } - - fn bool_cat(tensors: Vec, dim: usize) -> NdArrayTensor { - NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into() - } - - fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into() - } - - fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor { - tensor.bool().mapv(|a| !a).into_shared().into() - } - - fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into() - } - - fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into() - } - - fn bool_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor { - execute_with_float_out_dtype!( - out_dtype, - E, - tensor.bool().mapv(|b| b.elem::()).into_shared().into() - ) - } - - fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor { - NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into() - } - - fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { - tensor.bool().permuted_axes(axes.into_dimension()).into() - } - - fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { - NdArrayOps::expand(tensor.bool(), shape).into() - } - - fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!(indices, I, |indices: SharedArray| -> NdArrayTensor { - let tensor_bool = tensor.bool(); - let indices_vec: Vec = indices - .into_iter() - .map(|i| i.elem::() as usize) - .collect(); - - let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec); - selected.into_shared().into() - }) - } - - fn bool_select_or( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - execute_with_int_dtype!(indices, I, |indices: SharedArray| -> NdArrayTensor { - let mut output_array = tensor.bool().into_owned(); - let value_bool = value.bool(); - - for (index_value, index) in 
indices.into_iter().enumerate() { - let index_usize = index.elem::() as usize; - let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize); - let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value); - // For boolean tensors, select_assign should use logical OR operation - view.zip_mut_with(&value_slice, |a, b| *a = *a || *b); - } - output_array.into_shared().into() - }) - } - - fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { - NdArrayOps::flip(tensor.bool(), axes).into() - } - - fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor { - NdArrayOps::unfold(tensor.bool(), dim, size, step).into() - } - - fn bool_mask_where( - tensor: BoolTensor, - mask: BoolTensor, - value: BoolTensor, - ) -> BoolTensor { - NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into() - } - - fn bool_mask_fill( - tensor: BoolTensor, - mask: BoolTensor, - value: Scalar, - ) -> BoolTensor { - NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into() - } - - fn bool_gather( - dim: usize, - tensor: BoolTensor, - indices: IntTensor, - ) -> BoolTensor { - execute_with_int_dtype!(indices, |indices| NdArrayOps::gather( - dim, - tensor.bool(), - indices - )) - } - - fn bool_scatter_or( - dim: usize, - tensor: BoolTensor, - indices: IntTensor, - value: BoolTensor, - ) -> BoolTensor { - execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter( - dim, - tensor.bool(), - indices, - value.bool() - )) - } - - fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { - NdArrayBoolOps::equal_elem(lhs.bool(), rhs.elem()).into() - } - - fn bool_any(tensor: BoolTensor) -> BoolTensor { - // Use view() for zero-copy on borrowed storage with short-circuit evaluation - let result = NdArrayBoolOps::any_view(tensor.bool().view()); - NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1]))) - } - - fn bool_all(tensor: BoolTensor) -> BoolTensor { - // Use view() for zero-copy on borrowed storage with short-circuit evaluation - let result = NdArrayBoolOps::all_view(tensor.bool().view()); - NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1]))) - } -} diff --git a/crates/burn/src/ops/bool_tensor.rs b/crates/burn/src/ops/bool_tensor.rs new file mode 120000 index 00000000..8c09bc90 --- /dev/null +++ b/crates/burn/src/ops/bool_tensor.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/bool_tensor.rs \ No newline at end of file diff --git a/crates/burn/src/ops/conv.rs b/crates/burn/src/ops/conv.rs deleted file mode 100644 index 5fb2cad5..00000000 --- a/crates/burn/src/ops/conv.rs +++ /dev/null @@ -1,574 +0,0 @@ -use burn_backend::{ - ElementConversion, - ops::{ - ConvOptions, ConvTransposeOptions, - conv::{calculate_conv_output_size, calculate_conv_transpose_output_size}, - }, -}; -use ndarray::{ - Array3, Array4, Array5, ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3, Axis, Dim, s, -}; - -use crate::{ - NdArrayElement, SharedArray, iter_par, iter_range_par, - ops::padding::{apply_padding_4d, apply_padding_5d}, - run_par, - sharing::UnsafeSharedRef, - tensor::NdArrayTensor, -}; - -#[inline(always)] -fn conv2d_mad_inner( - mut output: ArrayViewMut2, - x: ArrayView2, - k: E, - k_xy: (usize, usize), - out_xy: (usize, usize), - stride: (usize, usize), - dilation: (usize, usize), -) { - let (kh, kw) = k_xy; - let (out_width, out_height) = out_xy; - let (stride_width, stride_height) = stride; - let (dilation_width, dilation_height) = dilation; - - for oh in 0..out_height 
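// Sketch of the bounds-check-elision idiom this function relies on: taking a
// fixed-length subslice up front makes the length visible to the optimizer, so the
// indexing in the hot loop needs no per-iteration checks:
//     let or = &mut row[0..out_width];
//     for ow in 0..out_width { or[ow] += ir[iw] * k; } // provably in bounds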
{ - // Construct a sub-slice view of the input row. - // This is done upfront so that rustc does not have to emit bounds checks - // in the hot loop below. - let ir = x - .row(oh * stride_height + kh * dilation_height) - .to_slice() - .unwrap(); - - // Ditto. Construct a sub-slice view of the output row, and explicitly specify - // the bounds upfront as 0..out_width so that rustc can make the assumption - // that all accesses are in-bounds in the below loop. - let mut or = output.row_mut(oh); - let or = &mut or.as_slice_mut().unwrap()[0..out_width]; - - #[allow(clippy::needless_range_loop)] - for ow in 0..out_width { - let iw = ow * stride_width + kw * dilation_width; - or[ow] += ir[iw] * k; - } - } -} - -#[inline(always)] -fn conv3d_mad_inner( - mut output: ArrayViewMut3, - x: ArrayView3, - k: E, - k_xyz: (usize, usize, usize), - out_xyz: (usize, usize, usize), - stride: (usize, usize, usize), - dilation: (usize, usize, usize), -) { - let (kd, kh, kw) = k_xyz; - let (out_width, out_height, out_depth) = out_xyz; - let (stride_width, stride_height, stride_depth) = stride; - let (dilation_width, dilation_height, dilation_depth) = dilation; - - for od in 0..out_depth { - let id = od * stride_depth + kd * dilation_depth; - - for oh in 0..out_height { - let ih = oh * stride_height + kh * dilation_height; - - // Construct a sub-slice view of the input row. - // This is done upfront so that rustc does not have to emit bounds checks - // in the hot loop below. - let ir = x.slice(s![id, ih, ..]).to_slice().unwrap(); - - // Ditto. Construct a sub-slice view of the output row, and explicitly specify - // the bounds upfront as 0..out_width so that rustc can make the assumption - // that all accesses are in-bounds in the below loop. - let or = &mut output - .slice_mut(s![od, oh, 0..out_width]) - .into_slice() - .unwrap()[0..out_width]; - - #[allow(clippy::needless_range_loop)] - for ow in 0..out_width { - let iw = ow * stride_width + kw * dilation_width; - or[ow] += ir[iw] * k; - } - } - } -} - -pub(crate) fn conv2d( - x: SharedArray, - weight: SharedArray, - bias: Option>, - options: ConvOptions<2>, -) -> SharedArray -where - NdArrayTensor: From>, -{ - let [dilation_height, dilation_width] = options.dilation; - let [padding_height, padding_width] = options.padding; - let [stride_height, stride_width] = options.stride; - let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap(); - let [out_channels, in_channels, kernel_height, kernel_width] = - weight.shape().try_into().unwrap(); - let channels_per_group = out_channels / options.groups; - - let out_height = calculate_conv_output_size( - kernel_height, - stride_height, - padding_height, - dilation_height, - in_height, - ); - let out_width = calculate_conv_output_size( - kernel_width, - stride_width, - padding_width, - dilation_width, - in_width, - ); - - let x = apply_padding_4d::(x, options.padding, 0i32.elem()); - - // Convert inputs from dynamic indexes to static to improve perf. 
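// The conversions below go to ndarray's fixed 4-D index type (`Ix4`); with a
// statically known rank, indexing like `x[[b, ic, ih, iw]]` compiles to fixed-stride
// arithmetic instead of a dynamic-rank loop, which is the perf win noted above.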
- let x = x.into_dimensionality::().unwrap(); - let weights = weight.into_dimensionality::().unwrap(); - - let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width])); - - run_par!(|| { - iter_par!(output.axis_iter_mut(Axis(0))) - .enumerate() - .for_each( - #[inline(never)] - |(k, mut output)| { - let b = k / out_channels; - let oc = k % out_channels; - let g = oc / channels_per_group; - - for ic in (in_channels * g)..(in_channels * (g + 1)) { - let weight_ic = ic - (g * in_channels); - - let x = x.slice(s![b, ic, .., ..]); - let k = weights.slice(s![oc, weight_ic, .., ..]); - - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let k = k[[kh, kw]]; - - // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization - // in the case that the stride/dilation is 1. - #[allow(clippy::if_same_then_else)] - if (1, 1, 1, 1) - == ( - stride_width, - stride_height, - dilation_width, - dilation_height, - ) - { - conv2d_mad_inner( - output.view_mut(), - x.view(), - k, - (kh, kw), - (out_width, out_height), - (stride_width, stride_height), - (dilation_width, dilation_height), - ); - } else { - conv2d_mad_inner( - output.view_mut(), - x.view(), - k, - (kh, kw), - (out_width, out_height), - (stride_width, stride_height), - (dilation_width, dilation_height), - ); - } - } - } - } - - if let Some(bias) = &bias { - let bias = bias[oc]; - - for oh in 0..out_height { - // Get a mutable slice reference to the row we're looping over. - // We explicitly define the bounds to 0..out_width so that rustc can make - // the assumption that all accesses are in-bounds. - let mut or = output.row_mut(oh); - let or = &mut or.as_slice_mut().unwrap()[0..out_width]; - - #[allow(clippy::needless_range_loop)] - for ow in 0..out_width { - or[ow] += bias; - } - } - } - }, - ); - }); - - output - .to_shape([batch_size, out_channels, out_height, out_width]) - .unwrap() - .into_dyn() - .into_shared() -} - -pub(crate) fn conv_transpose2d( - x: SharedArray, - weight: SharedArray, - bias: Option>, - options: ConvTransposeOptions<2>, -) -> SharedArray { - let [dilation_height, dilation_width] = options.dilation; - let [padding_height, padding_width] = options.padding; - let [stride_height, stride_width] = options.stride; - let [out_padding_height, out_padding_width] = options.padding_out; - let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap(); - let [in_channels, out_channels, kernel_height, kernel_width] = - weight.shape().try_into().unwrap(); - - let out_height = calculate_conv_transpose_output_size( - kernel_height, - stride_height, - padding_height, - out_padding_height, - dilation_height, - in_height, - ); - let out_width = calculate_conv_transpose_output_size( - kernel_width, - stride_width, - padding_width, - out_padding_width, - dilation_width, - in_width, - ); - - let x = x; - let mut output = Array4::zeros(Dim([ - batch_size, - out_channels * options.groups, - out_height, - out_width, - ])); - - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { - let b = k / (out_channels * options.groups); - let oc = k % out_channels; - let g = (k / out_channels) % options.groups; - - let output = unsafe_shared_out.get(); - - let oc_out = oc + (out_channels * g); - let ic_start = g * (in_channels / options.groups); - let ic_end = ic_start + in_channels / options.groups; - - for ic in ic_start..ic_end { - for ih in 
0..in_height { - for iw in 0..in_width { - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let oh = ih * stride_height + kh * dilation_height; - let ow = iw * stride_width + kw * dilation_width; - - if oh >= out_height + padding_height - || ow >= out_width + padding_width - || oh < padding_height - || ow < padding_width - { - continue; - } - - let oh = oh - padding_height; - let ow = ow - padding_width; - - output[[b, oc_out, oh, ow]] += - x[[b, ic, ih, iw]] * weight[[ic, oc, kh, kw]]; - } - } - } - } - } - - if let Some(bias) = &bias { - for oh in 0..out_height { - for ow in 0..out_width { - output[[b, oc_out, oh, ow]] += bias[oc_out]; - } - } - } - }); - }); - - output.into_dyn().into_shared() -} - -pub(crate) fn conv3d( - x: SharedArray, - weight: SharedArray, - bias: Option>, - options: ConvOptions<3>, -) -> SharedArray -where - NdArrayTensor: From>, -{ - let [dilation_depth, dilation_height, dilation_width] = options.dilation; - let [padding_depth, padding_height, padding_width] = options.padding; - let [stride_depth, stride_height, stride_width] = options.stride; - let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap(); - let [ - out_channels, - in_channels, - kernel_depth, - kernel_height, - kernel_width, - ] = weight.shape().try_into().unwrap(); - let out_c_per_group = out_channels / options.groups; - - let out_depth = calculate_conv_output_size( - kernel_depth, - stride_depth, - padding_depth, - dilation_depth, - in_depth, - ); - let out_height = calculate_conv_output_size( - kernel_height, - stride_height, - padding_height, - dilation_height, - in_height, - ); - let out_width = calculate_conv_output_size( - kernel_width, - stride_width, - padding_width, - dilation_width, - in_width, - ); - - let x = apply_padding_5d::(x, options.padding, 0i32.elem()); - - // Convert inputs from dynamic indexes to static to improve perf. - let x = x.into_dimensionality::().unwrap(); - let weights = weight.into_dimensionality::().unwrap(); - - let mut output = Array4::zeros(Dim([ - batch_size * out_channels, - out_depth, - out_height, - out_width, - ])); - - run_par!(|| { - iter_par!(output.axis_iter_mut(Axis(0))) - .enumerate() - .for_each( - #[inline(never)] - |(k, mut output)| { - let b = k / out_channels; - let oc = k % out_channels; - let g = oc / out_c_per_group; - - for ic in (in_channels * g)..(in_channels * (g + 1)) { - let weight_ic = ic - (g * in_channels); - - let x = x.slice(s![b, ic, .., .., ..]); - let k = weights.slice(s![oc, weight_ic, .., .., ..]); - - for kd in 0..kernel_depth { - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let k = k[[kd, kh, kw]]; - - // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization - // in the case that the stride/dilation is 1. 
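// Concretely: in the all-ones branch, inlining lets rustc constant-fold the inner
// index `ow * stride_width + kw * dilation_width` down to the contiguous access
// `ir[ow + kw]`, so the loop auto-vectorizes; the other, textually identical call
// keeps the general strided form.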
- #[allow(clippy::if_same_then_else)] - if (1, 1, 1, 1, 1, 1) - == ( - stride_width, - stride_height, - stride_depth, - dilation_width, - dilation_height, - dilation_depth, - ) - { - conv3d_mad_inner( - output.view_mut(), - x.view(), - k, - (kd, kh, kw), - (out_width, out_height, out_depth), - (stride_width, stride_height, stride_depth), - (dilation_width, dilation_height, dilation_depth), - ); - } else { - conv3d_mad_inner( - output.view_mut(), - x.view(), - k, - (kd, kh, kw), - (out_width, out_height, out_depth), - (stride_width, stride_height, stride_depth), - (dilation_width, dilation_height, dilation_depth), - ); - } - } - } - } - } - - if let Some(bias) = &bias { - let bias = bias[oc]; - - // Get a mutable iterator to the row we're looping over. - let orows = output.rows_mut(); - for mut or in orows { - // We explicitly define the bounds to 0..out_width so that rustc can make - // the assumption that all accesses are in-bounds. - let or = &mut or.as_slice_mut().unwrap()[0..out_width]; - - #[allow(clippy::needless_range_loop)] - for ow in 0..out_width { - or[ow] += bias; - } - } - } - }, - ); - }); - - output - .to_shape([batch_size, out_channels, out_depth, out_height, out_width]) - .unwrap() - .into_dyn() - .into_shared() -} - -pub(crate) fn conv_transpose3d( - x: SharedArray, - weight: SharedArray, - bias: Option>, - options: ConvTransposeOptions<3>, -) -> SharedArray { - let [dilation_depth, dilation_height, dilation_width] = options.dilation; - let [padding_depth, padding_height, padding_width] = options.padding; - let [stride_depth, stride_height, stride_width] = options.stride; - let [out_padding_depth, out_padding_height, out_padding_width] = options.padding_out; - let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap(); - let [ - in_channels, - out_channels, - kernel_depth, - kernel_height, - kernel_width, - ] = weight.shape().try_into().unwrap(); - - let out_depth = calculate_conv_transpose_output_size( - kernel_depth, - stride_depth, - padding_depth, - out_padding_depth, - dilation_depth, - in_depth, - ); - let out_height = calculate_conv_transpose_output_size( - kernel_height, - stride_height, - padding_height, - out_padding_height, - dilation_height, - in_height, - ); - let out_width = calculate_conv_transpose_output_size( - kernel_width, - stride_width, - padding_width, - out_padding_width, - dilation_width, - in_width, - ); - - let x = x; - let mut output = Array5::zeros(Dim([ - batch_size, - out_channels * options.groups, - out_depth, - out_height, - out_width, - ])); - - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { - let b = k / (out_channels * options.groups); - let oc = k % out_channels; - let g = (k / out_channels) % options.groups; - - let output = unsafe_shared_out.get(); - - let oc_out = oc + (out_channels * g); - let ic_start = g * (in_channels / options.groups); - let ic_end = ic_start + in_channels / options.groups; - - for ic in ic_start..ic_end { - for id in 0..in_depth { - for ih in 0..in_height { - for iw in 0..in_width { - for kd in 0..kernel_depth { - for kh in 0..kernel_height { - for kw in 0..kernel_width { - let od = id * stride_depth + kd * dilation_depth; - let oh = ih * stride_height + kh * dilation_height; - let ow = iw * stride_width + kw * dilation_width; - - if od >= out_depth + padding_depth - || oh >= out_height + padding_height - || ow >= out_width + padding_width - || od < 
padding_depth - || oh < padding_height - || ow < padding_width - { - continue; - } - - let od = od - padding_depth; - let oh = oh - padding_height; - let ow = ow - padding_width; - - output[[b, oc_out, od, oh, ow]] += - x[[b, ic, id, ih, iw]] * weight[[ic, oc, kd, kh, kw]]; - } - } - } - } - } - } - } - - if let Some(bias) = &bias { - for od in 0..out_depth { - for oh in 0..out_height { - for ow in 0..out_width { - output[[b, oc_out, od, oh, ow]] += bias[oc_out]; - } - } - } - } - }); - }); - - output.into_dyn().into_shared() -} diff --git a/crates/burn/src/ops/conv.rs b/crates/burn/src/ops/conv.rs new file mode 120000 index 00000000..5e87cdb5 --- /dev/null +++ b/crates/burn/src/ops/conv.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/conv.rs \ No newline at end of file diff --git a/crates/burn/src/ops/deform_conv.rs b/crates/burn/src/ops/deform_conv.rs deleted file mode 100644 index 390010b9..00000000 --- a/crates/burn/src/ops/deform_conv.rs +++ /dev/null @@ -1,662 +0,0 @@ -use burn_backend::ops::{DeformConvOptions, conv::calculate_conv_output_size}; -use core::ops::AddAssign; -use ndarray::{ - Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, Ix4, - Zip, s, -}; - -#[cfg(not(feature = "std"))] -#[allow(unused_imports)] -use num_traits::Float; - -use crate::{FloatNdArrayElement, NdArrayTensor, ShapeOps, SharedArray, iter_par, run_par}; - -use super::matmul::matmul; - -#[inline(always)] -#[allow(clippy::too_many_arguments)] -fn deform_im2col_kernel( - out_y: usize, - out_x: usize, - input: ArrayView2, - offset: ArrayView3, - mask: Option>, - mut columns: ArrayViewMut2, - args: DeformConvOptions<2>, - (kernel_h, kernel_w): (usize, usize), -) { - // position shape: [in_channels, batch_size, out_h, out_w] - // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] - - let (height, width) = input.dim(); - - for kernel_y in 0..kernel_h { - for kernel_x in 0..kernel_w { - let mask_value = mask - .map(|it| it[[kernel_y, kernel_x]]) - .unwrap_or_else(|| F::from_elem(1.0)); - - let offset = offset.slice(s![kernel_y, kernel_x, ..]); - let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) - - F::from_elem(args.padding[0]) - + offset[0]; - let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) - - F::from_elem(args.padding[1]) - + offset[1]; - - let interpolated = bilinear_interpolate(input, height, width, y, x); - - columns[[kernel_y, kernel_x]] = mask_value * interpolated; - } - } -} - -fn bilinear_interpolate( - input: ArrayView2, - height: usize, - width: usize, - y: F, - x: F, -) -> F { - // To simplify code - let y = y.to_f32(); - let x = x.to_f32(); - - let mut result = F::from_elem(0.0); - if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x { - let y_low = f32::floor(y); - let x_low = f32::floor(x); - let y_high = (y_low + 1.) as usize; - let x_high = (x_low + 1.) as usize; - - let zero = F::from_elem(0.0); - let v1: F = if y_low >= 0. && x_low >= 0. { - input[[y_low as usize, x_low as usize]] - } else { - zero - }; - let v2: F = if y_low >= 0. && x_high < width { - input[[y_low as usize, x_high]] - } else { - zero - }; - let v3: F = if y_high < height && x_low >= 0. 
{ - input[[y_high, x_low as usize]] - } else { - zero - }; - let v4: F = if y_high < height && x_high < width { - input[[y_high, x_high]] - } else { - zero - }; - - let l_y = y - y_low; - let l_x = x - x_low; - let h_y = 1.0 - l_y; - let h_x = 1.0 - l_x; - - let w1 = F::from_elem(h_y * h_x); - let w2 = F::from_elem(h_y * l_x); - let w3 = F::from_elem(l_y * h_x); - let w4 = F::from_elem(l_y * l_x); - - result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; - } - result -} - -pub(crate) fn deform_conv2d( - input: SharedArray, - offset: SharedArray, - weight: SharedArray, - mask: Option>, - bias: Option>, - args: DeformConvOptions<2>, -) -> SharedArray -where - NdArrayTensor: From>, -{ - let [batch_size, _, in_height, in_width] = input.shape().dims(); - let [out_channels, _, kernel_h, kernel_w] = weight.shape().dims(); - let groups = args.weight_groups; - - let weight = weight.as_standard_layout(); - - let out_h = calculate_conv_output_size( - kernel_h, - args.stride[0], - args.padding[0], - args.dilation[0], - in_height, - ); - let out_w = calculate_conv_output_size( - kernel_w, - args.stride[1], - args.padding[1], - args.dilation[1], - in_width, - ); - let out_dims = (out_h, out_w); - - let input = input.into_dimensionality::().unwrap(); - let offset = offset.into_dimensionality::().unwrap(); - let mask = mask.as_ref().map(|it| { - it.to_shape(( - batch_size, - args.offset_groups, - kernel_h, - kernel_w, - out_h, - out_w, - )) - .unwrap() - }); - - let columns = deform_im2col( - input.view(), - offset.view(), - mask.as_ref().map(|it| it.view()), - args, - out_dims, - (kernel_h, kernel_w), - ); - - let (col_size_0, col_size_1) = columns.dim(); - let col_size_0 = col_size_0 / groups; - let out_c_per_group = out_channels / groups; - - let weight = weight - .to_shape((groups, out_c_per_group, col_size_0)) - .unwrap(); - let columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap(); - let out = matmul( - weight.to_owned().into_dyn().into_shared(), - columns.to_owned().into_dyn().into_shared(), - ); - - let mut out = out - .into_shape_with_order((out_channels, batch_size, out_h, out_w)) - .unwrap(); - out.swap_axes(0, 1); - - if let Some(bias) = bias { - let bias = bias.to_shape((1, out_channels, 1, 1)).unwrap(); - out.add_assign(&bias); - } - - out.into_dyn().into_shared() -} - -pub(crate) fn deform_im2col( - input: ArrayView4, - offset: ArrayView4, - mask: Option>, - args: DeformConvOptions<2>, - out_dims: (usize, usize), - kernel_dims: (usize, usize), -) -> Array2 { - let (batch_size, in_channels, _, _) = input.dim(); - let (kernel_h, kernel_w) = kernel_dims; - let (out_h, out_w) = out_dims; - let channels_per_offset_group = in_channels / args.offset_groups; - - let mut columns = Array4::zeros(Dim([ - in_channels, - kernel_h, - kernel_w, - batch_size * out_h * out_w, - ])); - - let groups = args.offset_groups; - - run_par!(|| { - iter_par!(columns.axis_iter_mut(Axis(3))) - .enumerate() - .for_each(|(index, mut columns)| { - let out_x = index % out_w; - let out_y = (index / out_w) % out_h; - let batch = (index / (out_w * out_h)) % batch_size; - let offset = offset.slice(s![batch, .., out_y, out_x]); - let offset = offset.to_shape((groups, kernel_h, kernel_w, 2)).unwrap(); - let mask = mask - .as_ref() - .map(|it| it.slice(s![batch, .., .., .., out_y, out_x])); - columns - .axis_iter_mut(Axis(0)) - .enumerate() - .for_each(|(in_channel, mut columns)| { - let group_index = in_channel / channels_per_offset_group; - deform_im2col_kernel( - out_y, - out_x, - input.slice(s![batch, in_channel, 
.., ..]), - offset.slice(s![group_index, .., .., ..]), - mask.as_ref().map(|it| it.slice(s![group_index, .., ..])), - columns.view_mut(), - args.clone(), - kernel_dims, - ); - }); - }); - }); - - columns - // Columns is created here, so we know it's contiguous - .into_shape_with_order(( - in_channels * kernel_h * kernel_w, - batch_size * out_h * out_w, - )) - .unwrap() -} - -pub mod backward { - #[cfg(target_has_atomic = "32")] - use core::sync::atomic::Ordering; - - use atomic_float::AtomicF32; - use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4}; - - use super::*; - - pub(crate) type DeformConv2dBackward = ( - SharedArray, - SharedArray, - SharedArray, - Option>, - Option>, - ); - - /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. - pub(crate) fn deform_conv2d_backward( - input: SharedArray, - offset: SharedArray, - weight: SharedArray, - mask: Option>, - bias: Option>, - out_grad: SharedArray, - args: DeformConvOptions<2>, - ) -> DeformConv2dBackward { - let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims(); - let [_, _, kernel_h, kernel_w] = weight.shape().dims(); - let groups = args.weight_groups; - let out_c_per_group = out_channels / groups; - let col_shape_1 = batch_size * out_h * out_w; - let mut out_grad = out_grad.into_dimensionality::().unwrap(); - - let gradient_bias = bias.map(|_| { - let out_grad = out_grad - .clone() - .sum_axis(Axis(0)) - .sum_axis(Axis(1)) - .sum_axis(Axis(1)); - - out_grad.into_dyn().into_shared() - }); - - out_grad.swap_axes(0, 1); - let out_grad = out_grad - .to_shape((groups, out_c_per_group, col_shape_1)) - .unwrap(); - - let input = input.into_dimensionality::().unwrap(); - let offset = offset.into_dimensionality::().unwrap(); - let mask = mask.map(|it| { - it.into_shape_with_order(( - batch_size, - args.offset_groups, - kernel_h, - kernel_w, - out_h, - out_w, - )) - .unwrap() - }); - - let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs( - input.view(), - weight, - offset.view(), - mask.as_ref().map(|it| it.view()), - out_grad.view(), - &args, - (kernel_h, kernel_w), - ); - - let weight_grad = compute_weight_grad( - input.view(), - offset.view(), - mask.as_ref().map(|it| it.view()), - out_grad.view(), - args, - (kernel_h, kernel_w), - (out_h, out_w), - ); - - ( - input_gradient, - offset_gradient, - weight_grad, - mask_gradient, - gradient_bias, - ) - } - - fn compute_weight_grad( - input: ArrayView4, - offset: ArrayView4, - mask: Option>, - out_grad: ArrayView3, - options: DeformConvOptions<2>, - kernel_dims: (usize, usize), - out_dims: (usize, usize), - ) -> SharedArray { - let in_channels = input.dim().1; - let (groups, out_c_per_group, _) = out_grad.dim(); - let (kernel_h, kernel_w) = kernel_dims; - - let in_c_per_group = in_channels / groups; - - let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims); - let (col_size_0, col_size_1) = columns.dim(); - let col_size_0 = col_size_0 / groups; - - let mut columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap(); - columns.swap_axes(1, 2); - - let grad_weight = matmul( - out_grad.to_owned().into_dyn().into_shared(), - columns.to_owned().into_dyn().into_shared(), - ); - - let grad_weight = grad_weight - .into_shape_with_order((out_c_per_group * groups, in_c_per_group, kernel_h, kernel_w)) - .unwrap(); - grad_weight.into_dyn().into_shared() - } - - type InputGradients = (SharedArray, SharedArray, Option>); - - fn backward_gradient_inputs( - 
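// Hedged reading of the gradient computation below: the weights are reshaped to
// (groups, out_c_per_group, col_shape_0) and transposed, then matmul'd with
// `out_grad` to recover the im2col-space `columns`; those columns feed both the
// offset/mask gradients and the col2im-style input gradient.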
image: ArrayView4, - weight: SharedArray, - offset: ArrayView4, - mask: Option>, - out_grad: ArrayView3, - args: &DeformConvOptions<2>, - kernel_dims: (usize, usize), - ) -> InputGradients { - let input_shape = image.dim(); - let in_channels = input_shape.1; - let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape().dims(); - let (batch_size, _, out_h, out_w) = offset.dim(); - - let groups = args.weight_groups; - let out_c_per_group = out_channels / groups; - - let col_shape_0 = in_c_per_group * kernel_h * kernel_w; - - let mut weight = weight - .to_shape((groups, out_c_per_group, col_shape_0)) - .unwrap(); - weight.swap_axes(1, 2); - let columns = matmul( - weight.to_owned().into_dyn().into_shared(), - out_grad.to_owned().into_dyn().into_shared(), - ); - - let columns = columns - .to_shape((in_channels, kernel_h, kernel_w, batch_size, out_h, out_w)) - .unwrap(); - - let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient( - columns.view(), - image.view(), - offset, - mask, - args, - kernel_dims, - ); - - let input_gradient = - compute_input_grad(columns.view(), offset, mask, args, kernel_dims, input_shape); - - (input_gradient, offset_gradient, mask_gradient) - } - - fn compute_offset_and_mask_gradient( - columns: ArrayView6, - image: ArrayView4, - offset: ArrayView4, - mask: Option>, - args: &DeformConvOptions<2>, - kernel_dims: (usize, usize), - ) -> (SharedArray, Option>) { - let (kernel_h, kernel_w) = kernel_dims; - let (_, in_channels, height, width) = image.dim(); - let (batch_size, offset_channels, out_h, out_w) = offset.dim(); - let offs_groups = args.offset_groups; - let channels_per_offset_group = in_channels / args.offset_groups; - - let mut grad_offset = Array5::zeros(( - offs_groups, - kernel_h, - kernel_w, - 2, - batch_size * out_h * out_w, - )); - let mut grad_mask = - Array4::zeros((offs_groups, kernel_h, kernel_w, batch_size * out_h * out_w)); - - grad_mask - .axis_iter_mut(Axis(3)) - .zip(grad_offset.axis_iter_mut(Axis(4))) - .enumerate() - .for_each(|(index, (mut grad_mask, mut grad_offset))| { - let out_x = index % out_w; - let out_y = (index / out_w) % out_h; - let batch = index / (out_w * out_h); - let offset = offset.slice(s![batch, .., out_y, out_x]); - let offset = offset - .to_shape((offs_groups, kernel_h, kernel_w, 2)) - .unwrap(); - let mask: Option> = mask - .as_ref() - .map(|mask| mask.slice(s![batch, .., .., .., out_y, out_x])); - let columns = columns.slice(s![.., .., .., batch, out_y, out_x]); - let image = image.slice(s![batch, .., .., ..]); - - for ((group, kernel_y, kernel_x), grad_mask) in grad_mask.indexed_iter_mut() { - let grad_mask: &mut F = grad_mask; - let mut grad_offset = grad_offset.slice_mut(s![group, kernel_y, kernel_x, ..]); - let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); - let mask = mask.map(|it| it[[group, kernel_y, kernel_x]]); - let columns = columns.slice(s![.., kernel_y, kernel_x]); - let group_offset = group * channels_per_offset_group; - let image = image.slice(s![group_offset.., .., ..]); - let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) - - F::from_elem(args.padding[0]) - + offset[0]; - let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) - - F::from_elem(args.padding[1]) - + offset[1]; - for (i, grad_offset) in grad_offset.iter_mut().enumerate() { - let is_y_direction = i % 2 == 0; - let use_mask = mask.is_some(); - - for channel in 0..channels_per_offset_group { - let mask = mask.unwrap_or_else(|| F::one()); - let image = 
image.index_axis(Axis(0), channel); - let weight = - get_coordinate_weight(image, height, width, y, x, is_y_direction); - *grad_offset += mask * weight * columns[channel]; - if use_mask && is_y_direction { - *grad_mask += columns[channel] - * bilinear_interpolate(image, height, width, y, x); - } - } - } - } - }); - - let mask_gradient = mask.map(|_| { - let mut grad_mask = grad_mask - .into_shape_with_order((offset_channels / 2, batch_size, out_h, out_w)) - .unwrap(); - grad_mask.swap_axes(0, 1); - grad_mask.into_dyn().into_shared() - }); - let mut grad_offset = grad_offset - .into_shape_with_order((offset_channels, batch_size, out_h, out_w)) - .unwrap(); - grad_offset.swap_axes(0, 1); - let offset_gradient = grad_offset.into_dyn().into_shared(); - (offset_gradient, mask_gradient) - } - - fn get_coordinate_weight( - input: ArrayView2, - height: usize, - width: usize, - y: F, - x: F, - is_y_direction: bool, - ) -> F { - let y = y.to_f32(); - let x = x.to_f32(); - - let y_low = f32::floor(y); - let x_low = f32::floor(x); - let y_high = y_low + 1.; - let x_high = x_low + 1.; - - let valid_y_low = y_low >= 0. && y_low < height as f32; - let valid_y_high = y_high >= 0. && y_high < height as f32; - let valid_x_low = x_low >= 0. && x_low < width as f32; - let valid_x_high = x_high >= 0. && x_high < width as f32; - - let bottom_left = if valid_y_low && valid_x_low { - input[[y_low as usize, x_low as usize]] - } else { - F::zero() - }; - let bottom_right = if valid_y_low && valid_x_high { - input[[y_low as usize, x_high as usize]] - } else { - F::zero() - }; - let top_left = if valid_y_high && valid_x_low { - input[[y_high as usize, x_low as usize]] - } else { - F::zero() - }; - let top_right = if valid_y_high && valid_x_high { - input[[y_high as usize, x_high as usize]] - } else { - F::zero() - }; - - if is_y_direction { - let delta_x = F::from_elem(x - x_low); - delta_x * (top_right - bottom_right) + (F::one() - delta_x) * (top_left - bottom_left) - } else { - let delta_y = F::from_elem(y - y_low); - delta_y * (top_right - top_left) + (F::one() - delta_y) * (bottom_right - bottom_left) - } - } - - fn compute_input_grad( - columns: ArrayView6, - offset: ArrayView4, - mask: Option>, - args: &DeformConvOptions<2>, - kernel_dims: (usize, usize), - input_shape: (usize, usize, usize, usize), - ) -> SharedArray { - let (batch_size, in_channels, height, width) = input_shape; - let (kernel_h, kernel_w) = kernel_dims; - let offs_groups = args.offset_groups; - let channels_per_offset_group = in_channels / offs_groups; - - let grad_in = - Array4::from_shape_simple_fn((batch_size, in_channels, height, width), || { - AtomicF32::new(0.0) - }); - - let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| { - let group = in_channel / channels_per_offset_group; - let offset = offset.slice(s![batch, .., out_y, out_x]); - let offset = offset - .to_shape((offs_groups, kernel_h, kernel_w, 2)) - .unwrap(); - let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); - let offset = [offset[0], offset[1]]; - let mask = mask - .as_ref() - .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32()); - let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) - - F::from_elem(args.padding[0]) - + offset[0]; - let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) - - F::from_elem(args.padding[1]) - + offset[1]; - let grad_in = grad_in.slice(s![batch, in_channel, .., ..]); - deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), 
grad_in); - }; - - // `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise - #[cfg(feature = "multi-threads")] - run_par!(|| { - iter_par!(Zip::indexed(columns)) - .for_each(|(args0, args1)| compute_for_each(args0, args1)) - }); - - #[cfg(not(feature = "multi-threads"))] - run_par!(|| { iter_par!(Zip::indexed(columns)).for_each(&compute_for_each) }); - - let grad_in: Array1 = grad_in - .into_iter() - .map(|it| F::from_elem(it.into_inner())) - .collect(); - let grad_in = grad_in - .into_shape_with_order((batch_size, in_channels, height, width)) - .unwrap(); - grad_in.into_dyn().into_shared() - } - - fn deform_col2img_kernel( - y: f32, - x: f32, - mask: Option, - col: f32, - grad_input: ArrayView2, - ) { - let (height, width) = grad_input.dim(); - let mask_value = mask.unwrap_or(1.0); - - for dy in -1..=1 { - for dx in -1..=1 { - let yp = f32::floor(y) + dy as f32; - let xp = f32::floor(x) + dx as f32; - - if yp >= 0.0 - && yp < height as f32 - && xp >= 0.0 - && xp < width as f32 - && f32::abs(y - yp) < 1.0 - && f32::abs(x - xp) < 1.0 - { - let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp)); - - #[cfg_attr(not(target_has_atomic = "32"), allow(unused))] - let value = mask_value * weight * col; - - #[cfg(target_has_atomic = "32")] - grad_input[[yp as usize, xp as usize]].fetch_add(value, Ordering::AcqRel); - #[cfg(not(target_has_atomic = "32"))] - panic!("Can't use deformable convolution backwards pass without atomics"); - } - } - } - } -} diff --git a/crates/burn/src/ops/deform_conv.rs b/crates/burn/src/ops/deform_conv.rs new file mode 120000 index 00000000..fdc04daf --- /dev/null +++ b/crates/burn/src/ops/deform_conv.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/deform_conv.rs \ No newline at end of file diff --git a/crates/burn/src/ops/grid_sample.rs b/crates/burn/src/ops/grid_sample.rs deleted file mode 100644 index 256c2fd8..00000000 --- a/crates/burn/src/ops/grid_sample.rs +++ /dev/null @@ -1,214 +0,0 @@ -use burn_backend::ElementConversion; -use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode}; -#[cfg(not(feature = "std"))] -#[allow(unused_imports)] -use num_traits::Float; - -use ndarray::Array4; - -use crate::SharedArray; -use crate::{FloatNdArrayElement, UnsafeSharedRef, iter_range_par, run_par}; - -/// Sample a tensor using grid-based sampling. -/// -/// # Arguments -/// -/// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in) -/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1]. 
-/// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right -/// * `options` - Grid sampling options (mode, padding_mode, align_corners) -/// -/// # Returns -/// -/// A tensor with shape (N, C, H_out, W_out) -pub(crate) fn grid_sample_2d( - tensor: SharedArray, - grid: SharedArray, - options: GridSampleOptions, -) -> SharedArray { - match options.mode { - InterpolateMode::Bilinear => (), - _ => todo!( - "grid_sample_2d with {:?} mode is not implemented", - options.mode - ), - } - - let tensor = tensor.into_dimensionality::().unwrap(); - let grid = grid.into_dimensionality::().unwrap(); - - let (batch_size, channels, height_in, width_in) = tensor.dim(); - let (b, height_out, width_out, d) = grid.dim(); - assert!(batch_size == b); - assert!(2 == d); - - let mut output = Array4::zeros((batch_size, channels, height_out, width_out)); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - let sample_count = batch_size * channels * height_out * width_out; - let strides = ( - channels * height_out * width_out, - height_out * width_out, - width_out, - ); - - let align = options.align_corners; - let pad_mode = options.padding_mode; - - run_par!(|| { - iter_range_par!(0, sample_count).for_each(|id| { - let (b, c, y, x) = ( - id / strides.0, - id % strides.0 / strides.1, - id % strides.1 / strides.2, - id % strides.2, - ); - - let sample_x = grid[(b, y, x, 0)].elem::(); - let sample_y = grid[(b, y, x, 1)].elem::(); - - // Convert normalized grid coordinates [-1, 1] to pixel coordinates - let (px, py) = if align { - // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2 - // Maps -1 to 0 and 1 to width - 1 - let px = (sample_x + 1.0) * ((width_in - 1) as f64) / 2.0; - let py = (sample_y + 1.0) * ((height_in - 1) as f64) / 2.0; - (px, py) - } else { - // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5 - // Maps -1 to -0.5 and 1 to width - 0.5 - let px = (sample_x + 1.0) * (width_in as f64) / 2.0 - 0.5; - let py = (sample_y + 1.0) * (height_in as f64) / 2.0 - 0.5; - (px, py) - }; - - // Bilinear interpolation with the specified padding mode - let val = - bilinear_interpolate(&tensor, b, c, px, py, width_in, height_in, pad_mode, align); - - unsafe { - let output = unsafe_shared_out.get(); - output[(b, c, y, x)] = val.elem(); - } - }); - }); - - output.into_dyn().into_shared() -} - -/// Bilinear interpolation at a point with configurable padding mode. 
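// Worked example of the normalized-to-pixel mapping used by the caller, for W = 4:
//   align_corners = true:  x_px = (x + 1) * (W - 1) / 2   maps -1 -> 0.0 and 1 -> 3.0
//   align_corners = false: x_px = (x + 1) * W / 2 - 0.5   maps -1 -> -0.5 and 1 -> 3.5
// The four bilinear weights computed below always sum to 1:
//   (1 - fx)(1 - fy) + (1 - fx)fy + fx(1 - fy) + fx fy = 1.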
-#[allow(clippy::too_many_arguments)] -fn bilinear_interpolate( - source: &ndarray::ArrayBase>, - b: usize, - c: usize, - x: f64, - y: f64, - width: usize, - height: usize, - padding_mode: GridSamplePaddingMode, - align_corners: bool, -) -> f64 -where - E: FloatNdArrayElement, - S: ndarray::Data, -{ - // Handle inf/nan coordinates - if !x.is_finite() || !y.is_finite() { - return match padding_mode { - GridSamplePaddingMode::Zeros => 0.0, - GridSamplePaddingMode::Border => { - // Clamp to center of image for inf/nan - let cx = ((width - 1) as f64 / 2.0).clamp(0.0, (width - 1) as f64); - let cy = ((height - 1) as f64 / 2.0).clamp(0.0, (height - 1) as f64); - source[(b, c, cy as usize, cx as usize)].elem::() - } - GridSamplePaddingMode::Reflection => 0.0, // Simplified: treat as zeros for inf/nan - }; - } - - // Apply padding mode to get actual sampling coordinates - let (x, y) = match padding_mode { - GridSamplePaddingMode::Border => { - // Clamp coordinates to valid range [0, size-1] - let x = x.clamp(0.0, (width - 1) as f64); - let y = y.clamp(0.0, (height - 1) as f64); - (x, y) - } - GridSamplePaddingMode::Reflection => { - // Reflect coordinates at boundaries - let x = reflect_coordinate(x, width, align_corners); - let y = reflect_coordinate(y, height, align_corners); - (x, y) - } - GridSamplePaddingMode::Zeros => (x, y), // Keep as-is, handle out-of-bounds in read - }; - - // Get the four corner indices - let x0 = x.floor() as i64; - let y0 = y.floor() as i64; - let x1 = x0.saturating_add(1); - let y1 = y0.saturating_add(1); - - // Compute interpolation weights (fractional part) - let x_frac = x - x.floor(); - let y_frac = y - y.floor(); - - // Helper to read a value based on padding mode - let read_value = |xi: i64, yi: i64| -> f64 { - match padding_mode { - GridSamplePaddingMode::Zeros => { - // Return 0 for out-of-bounds - if xi >= 0 && xi < width as i64 && yi >= 0 && yi < height as i64 { - source[(b, c, yi as usize, xi as usize)].elem::() - } else { - 0.0 - } - } - GridSamplePaddingMode::Border | GridSamplePaddingMode::Reflection => { - // Coordinates should already be in valid range after clamping/reflection - let xi = xi.clamp(0, (width - 1) as i64) as usize; - let yi = yi.clamp(0, (height - 1) as i64) as usize; - source[(b, c, yi, xi)].elem::() - } - } - }; - - // Read the four corners - let v00 = read_value(x0, y0); - let v01 = read_value(x0, y1); - let v10 = read_value(x1, y0); - let v11 = read_value(x1, y1); - - // Bilinear interpolation weights - let w00 = (1.0 - x_frac) * (1.0 - y_frac); - let w01 = (1.0 - x_frac) * y_frac; - let w10 = x_frac * (1.0 - y_frac); - let w11 = x_frac * y_frac; - - v00 * w00 + v01 * w01 + v10 * w10 + v11 * w11 -} - -/// Reflect a coordinate at the boundaries using a triangle wave pattern. 
-/// -/// For align_corners=true: reflects within [0, size-1] -/// For align_corners=false: reflects within [-0.5, size-0.5] -fn reflect_coordinate(coord: f64, size: usize, align_corners: bool) -> f64 { - let size_f = size as f64; - let (min_val, max_val) = if align_corners { - (0.0, size_f - 1.0) - } else { - (-0.5, size_f - 0.5) - }; - - let span = max_val - min_val; - if span <= 0.0 { - return min_val; - } - - // Triangle wave formula: span - |((x mod 2*span) - span)| - let period = 2.0 * span; - let x = (coord - min_val).abs(); - let x_mod = x - (x / period).floor() * period; - span - (x_mod - span).abs() + min_val -} diff --git a/crates/burn/src/ops/grid_sample.rs b/crates/burn/src/ops/grid_sample.rs new file mode 120000 index 00000000..45b8d55c --- /dev/null +++ b/crates/burn/src/ops/grid_sample.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/grid_sample.rs \ No newline at end of file diff --git a/crates/burn/src/ops/int_tensor.rs b/crates/burn/src/ops/int_tensor.rs deleted file mode 100644 index 02710cdc..00000000 --- a/crates/burn/src/ops/int_tensor.rs +++ /dev/null @@ -1,509 +0,0 @@ -// Language -use crate::rand::get_seeded_rng; -use alloc::vec::Vec; -use burn_backend::backend::ExecutionError; -use burn_backend::ops::IntTensorOps; -use burn_backend::tensor::{FloatTensor, IntTensor}; -use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata}; - -use burn_backend::ElementConversion; -use burn_std::{BoolDType, FloatDType}; - -// Current crate -use crate::{ExpElement, NdArrayDevice, SEED, execute_with_int_out_dtype, slice}; -use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor}; -use crate::{SharedArray, element::QuantElement}; -use crate::{cat_with_dtype, execute_with_float_out_dtype}; -use crate::{element::FloatNdArrayElement, ops::matmul::matmul}; -use crate::{element::IntNdArrayElement, execute_with_int_dtype}; - -// Workspace crates -use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps}; -use burn_backend::{DType, Shape, TensorData, backend::Backend}; - -impl IntTensorOps - for NdArray -where - NdArrayTensor: From>, - NdArrayTensor: From>, -{ - fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { - if data.dtype.is_int() || data.dtype.is_uint() { - NdArrayTensor::from_data(data) - } else { - unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype) - } - } - - async fn int_into_data(tensor: NdArrayTensor) -> Result { - Ok(tensor.into_data()) - } - - fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor { - tensor - } - - fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape)) - } - - fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor { - slice!(tensor, slices) - } - - fn int_device(_tensor: &NdArrayTensor) -> as Backend>::Device { - NdArrayDevice::Cpu - } - - fn int_empty( - shape: Shape, - device: & as Backend>::Device, - dtype: IntDType, - ) -> NdArrayTensor { - Self::int_zeros(shape, device, dtype) - } - - fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - execute_with_int_dtype!((lhs, rhs), matmul) - } - - fn int_mask_where( - tensor: NdArrayTensor, - mask: NdArrayTensor, - source: NdArrayTensor, - ) -> NdArrayTensor { - execute_with_int_dtype!((tensor, source), |tensor, source| { - NdArrayOps::mask_where(tensor, mask.bool(), source) - }) - } - - fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> 
NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill( - array, - mask.bool(), - value.elem() - )) - } - - fn int_slice_assign( - tensor: NdArrayTensor, - slices: &[burn_backend::Slice], - value: NdArrayTensor, - ) -> NdArrayTensor { - execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign( - tensor, slices, value - )) - } - - fn int_cat(tensors: Vec, dim: usize) -> NdArrayTensor { - cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8]) - } - - fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal) - } - - fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem())) - } - - fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater) - } - - fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem())) - } - - fn int_greater_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - _out_dtype: BoolDType, - ) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal) - } - - fn int_greater_equal_elem( - lhs: NdArrayTensor, - rhs: Scalar, - _out_dtype: BoolDType, - ) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem( - array, - rhs.elem() - )) - } - - fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower) - } - - fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem())) - } - - fn int_lower_equal( - lhs: NdArrayTensor, - rhs: NdArrayTensor, - _out_dtype: BoolDType, - ) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal) - } - - fn int_lower_equal_elem( - lhs: NdArrayTensor, - rhs: Scalar, - _out_dtype: BoolDType, - ) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem( - array, - rhs.elem() - )) - } - - fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add) - } - - fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem())) - } - - fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub) - } - - fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem())) - } - - fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul) - } - - fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem())) - } - - fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div) - } - - fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| 
NdArrayMathOps::div_scalar(array, rhs.elem())) - } - - fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder) - } - - fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar( - array, - rhs.elem() - )) - } - - fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor { - // Use view() for zero-copy on borrowed storage - execute_with_int_dtype!(tensor, E, |array: SharedArray| NdArrayMathOps::sum_view( - array.view() - )) - } - - fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim)) - } - - fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor { - // Use view() for zero-copy on borrowed storage - execute_with_int_dtype!( - tensor, - E, - |array: SharedArray| NdArrayMathOps::prod_view(array.view()) - ) - } - - fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim)) - } - - fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor { - // Use view() for zero-copy on borrowed storage - execute_with_int_dtype!( - tensor, - E, - |array: SharedArray| NdArrayMathOps::mean_view(array.view()) - ) - } - - fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim)) - } - - fn int_max(tensor: NdArrayTensor) -> NdArrayTensor { - // Use view() for zero-copy on borrowed storage - execute_with_int_dtype!(tensor, E, |array: SharedArray| NdArrayMathOps::max_view( - array.view() - )) - } - - fn int_min(tensor: NdArrayTensor) -> NdArrayTensor { - // Use view() for zero-copy on borrowed storage - execute_with_int_dtype!(tensor, E, |array: SharedArray| NdArrayMathOps::min_view( - array.view() - )) - } - - fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim)) - } - - fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim)) - } - - fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim)) - } - - fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim)) - } - - fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor { - execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather( - dim, array, idx_array - )) - }) - } - - fn int_scatter_add( - dim: usize, - tensor: NdArrayTensor, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> NdArrayTensor { - execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor { - execute_with_int_dtype!(indices, |idx_array| NdArrayOps::::scatter( - dim, tensor, idx_array, value - )) - }) - } - - fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor { - execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select( - array, dim, idx_array - )) - }) - } - - fn int_select_add( - tensor: NdArrayTensor, - dim: usize, - indices: NdArrayTensor, - value: NdArrayTensor, - ) -> 
NdArrayTensor { - execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor { - execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::::select_assign( - tensor, dim, idx_array, value - )) - }) - } - fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - // Use view() for zero-copy on borrowed storage - execute_with_int_dtype!(tensor, E, |array: SharedArray| { - NdArrayMathOps::argmax_view::(array.view(), dim) - }) - } - - fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { - // Use view() for zero-copy on borrowed storage - execute_with_int_dtype!(tensor, E, |array: SharedArray| { - NdArrayMathOps::argmin_view::(array.view(), dim) - }) - } - - fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem())) - } - - fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem())) - } - - fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp( - array, - min.elem(), - max.elem() - )) - } - - fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor { - match tensor.dtype() { - DType::I64 | DType::I32 | DType::I16 | DType::I8 => { - execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [ - I64 => i64, I32 => i32, I16 => i16, I8 => i8 - ]) - } - // Already unsigned - DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor, - other => panic!("Unsupported dtype: {other:?}"), - } - } - - fn int_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor { - execute_with_float_out_dtype!(out_dtype, F, { - execute_with_int_dtype!(tensor, IntElem, |array: SharedArray| { - array.mapv(|a: IntElem| a.elem::()).into_shared() - }) - }) - } - - fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2)) - } - - fn int_random( - shape: Shape, - distribution: Distribution, - device: &NdArrayDevice, - dtype: IntDType, - ) -> NdArrayTensor { - let mut seed = SEED.lock().unwrap(); - let mut rng = seed.take().unwrap_or_else(get_seeded_rng); - - let effective_distribution = if distribution == Distribution::Default { - Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant - } else { - distribution - }; - - let tensor = execute_with_int_out_dtype!( - dtype, - I, - Self::int_from_data( - TensorData::random::(shape, effective_distribution, &mut rng), - device, - ) - ); - *seed = Some(rng); - tensor - } - - fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op( - lhs, - rhs, - |a: &I, b: &I| { (a.elem::().pow(b.elem::())).elem() } - )) - } - - fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes)) - } - - fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes)) - } - - fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor { - match tensor.dtype() { - DType::I64 | DType::I32 | DType::I16 | DType::I8 => { - execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [ - I64 => i64, I32 => i32, I16 => i16, I8 => i8 - ]) - } - DType::U64 | DType::U32 | DType::U16 | DType::U8 => { - 
Self::int_greater_elem(tensor, 0.into(), BoolDType::Native) - } - other => panic!("Unsupported dtype: {other:?}"), - } - } - - fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { - execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape)) - } - - fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand) - } - - fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem())) - } - - fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor) - } - - fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem())) - } - - fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor) - } - - fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem())) - } - - fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot) - } - - fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| { - NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { - (a.elem::() << (b.elem::())).elem() - }) - }) - } - - fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, I, |array| { - NdArrayMathOps::elementwise_op_scalar(array, |a: I| { - (a.elem::() << rhs.elem::()).elem() - }) - }) - } - - fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { - execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| { - NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { - (a.elem::() >> (b.elem::())).elem() - }) - }) - } - - fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { - execute_with_int_dtype!(lhs, I, |array| { - NdArrayMathOps::elementwise_op_scalar(array, |a: I| { - (a.elem::() >> rhs.elem::()).elem() - }) - }) - } - - fn int_cast(tensor: IntTensor, dtype: IntDType) -> IntTensor { - execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into())) - } - - fn int_unfold( - tensor: IntTensor, - dim: usize, - size: usize, - step: usize, - ) -> IntTensor { - execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step)) - } - - fn int_powi_scalar_impl(lhs: IntTensor, rhs: Scalar) -> IntTensor { - execute_with_int_dtype!(lhs, I, |array| { - NdArrayMathOps::elementwise_op_scalar(array, |a: I| a.powi_elem(rhs.elem())) - }) - } -} diff --git a/crates/burn/src/ops/int_tensor.rs b/crates/burn/src/ops/int_tensor.rs new file mode 120000 index 00000000..8ebe8b96 --- /dev/null +++ b/crates/burn/src/ops/int_tensor.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/int_tensor.rs \ No newline at end of file diff --git a/crates/burn/src/ops/interpolate.rs b/crates/burn/src/ops/interpolate.rs deleted file mode 100644 index af9d50d1..00000000 --- a/crates/burn/src/ops/interpolate.rs +++ /dev/null @@ -1,397 +0,0 @@ -use burn_backend::ElementConversion; -use ndarray::{Array4, ArrayBase, DataOwned}; -#[cfg(not(feature = "std"))] -#[allow(unused_imports)] -use num_traits::Float; - -use crate::{FloatNdArrayElement, 
ShapeOps, SharedArray, UnsafeSharedRef, iter_range_par, run_par}; - -pub(crate) fn nearest_interpolate( - x: SharedArray, - output_size: [usize; 2], -) -> SharedArray { - let x = x.into_dimensionality::().unwrap(); - - let (batch_size, channels, in_height, in_width) = x.dim(); - let [out_height, out_width] = output_size; - - let y_ratio = (in_height as f64) / (out_height as f64); - let x_ratio = (in_width as f64) / (out_width as f64); - - let out_element_num = batch_size * channels * out_height * out_width; - let strides = ( - channels * out_height * out_width, - out_height * out_width, - out_width, - ); - - let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, out_element_num).for_each(|id| { - let (b, c, h, w) = ( - id / strides.0, - id % strides.0 / strides.1, - id % strides.1 / strides.2, - id % strides.2, - ); - - let y_in = (y_ratio * h as f64).floor() as usize; - let x_in = (x_ratio * w as f64).floor() as usize; - - unsafe { - let output = unsafe_shared_out.get(); - output[(b, c, h, w)] = x[(b, c, y_in, x_in)]; - } - }); - }); - - output.into_dyn().into_shared() -} - -pub(crate) fn nearest_interpolate_backward( - x: SharedArray, - grad: SharedArray, - output_size: [usize; 2], -) -> SharedArray { - let [batch_size, channels, input_height, input_width] = x.shape().dims(); - let [output_height, output_width] = output_size; - - let mut output_grad = - Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem()); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output_grad = unsafe_shared_out.get(); - - for oh in 0..output_height { - for ow in 0..output_width { - let ih = start_index(oh, output_height, input_height); - let iw = start_index(ow, output_width, input_width); - - output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] - } - } - }) - }); - - output_grad.into_dyn().into_shared() -} - -fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { - ((output_size_index as f32 * input_size as f32) / output_size as f32).floor() as usize -} - -// clamp ceil(frac) to stay within bounds in case of floating-point imprecision -pub(crate) fn ceil_clamp(frac: f64, max: usize) -> f64 { - frac.ceil().min(max as f64) -} - -pub(crate) fn bilinear_interpolate( - x: SharedArray, - output_size: [usize; 2], - align_corners: bool, -) -> SharedArray { - let x = x.into_dimensionality::().unwrap(); - - let (batch_size, channels, in_height, in_width) = x.dim(); - let [out_height, out_width] = output_size; - - let out_element_num = batch_size * channels * out_height * out_width; - let strides = ( - channels * out_height * out_width, - out_height * out_width, - out_width, - ); - - let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, out_element_num).for_each(|id| { - let (b, c, h, w) = ( - id / strides.0, - id % strides.0 / strides.1, - id % strides.1 / strides.2, - id % strides.2, - ); - - let (y_frac, x_frac) = if align_corners { - let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64); - let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64); - (y_ratio * h as f64, x_ratio * w as f64) - } else { - let 
y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5; - let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5; - ( - y_frac.clamp(0.0, (in_height - 1) as f64), - x_frac.clamp(0.0, (in_width - 1) as f64), - ) - }; - let val = - bilinear_interpolate_single(&x, b, c, x_frac, y_frac, in_width - 1, in_height - 1); - - unsafe { - let output = unsafe_shared_out.get(); - output[(b, c, h, w)] = val.elem(); - } - }); - }); - - output.into_dyn().into_shared() -} - -pub(crate) fn bicubic_interpolate( - x: SharedArray, - output_size: [usize; 2], - align_corners: bool, -) -> SharedArray { - fn cubic_interp1d(x0: f64, x1: f64, x2: f64, x3: f64, t: f64) -> f64 { - fn cubic_convolution1(x: f64, a: f64) -> f64 { - ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0 - } - - fn cubic_convolution2(x: f64, a: f64) -> f64 { - ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a - } - - let coeffs = [ - cubic_convolution2(t + 1.0, -0.75), - cubic_convolution1(t, -0.75), - cubic_convolution1(1.0 - t, -0.75), - cubic_convolution2(2.0 - t, -0.75), - ]; - - x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3] - } - - let x = x.into_dimensionality::().unwrap(); - - let (batch_size, channels, in_height, in_width) = x.dim(); - let [out_height, out_width] = output_size; - - let out_element_num = batch_size * channels * out_height * out_width; - let strides = ( - channels * out_height * out_width, - out_height * out_width, - out_width, - ); - - let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, out_element_num).for_each(|id| { - let (b, c, h, w) = ( - id / strides.0, - id % strides.0 / strides.1, - id % strides.1 / strides.2, - id % strides.2, - ); - - let (y_frac, x_frac) = if align_corners { - let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64); - let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64); - (y_ratio * h as f64, x_ratio * w as f64) - } else { - let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5; - let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5; - (y_frac, x_frac) - }; - let y0 = y_frac.floor(); - let yw = y_frac - y0; - let y_in = y0 as isize; - - let x0 = x_frac.floor(); - let xw = x_frac - x0; - let x_in = x0 as isize; - - let max_h = (in_height - 1) as isize; - let max_w = (in_width - 1) as isize; - - let ys_in = [ - (y_in - 1).clamp(0, max_h) as usize, - y_in.clamp(0, max_h) as usize, - (y_in + 1).clamp(0, max_h) as usize, - (y_in + 2).clamp(0, max_h) as usize, - ]; - - let xs_in = [ - (x_in - 1).clamp(0, max_w) as usize, - x_in.clamp(0, max_w) as usize, - (x_in + 1).clamp(0, max_w) as usize, - (x_in + 2).clamp(0, max_w) as usize, - ]; - - let coefficients = ys_in.map(|y| { - cubic_interp1d( - x[(b, c, y, xs_in[0])].elem(), - x[(b, c, y, xs_in[1])].elem(), - x[(b, c, y, xs_in[2])].elem(), - x[(b, c, y, xs_in[3])].elem(), - xw, - ) - }); - - let result = cubic_interp1d( - coefficients[0], - coefficients[1], - coefficients[2], - coefficients[3], - yw, - ) - .elem(); - - unsafe { - let output = unsafe_shared_out.get(); - output[(b, c, h, w)] = result; - } - }); - }); - - output.into_dyn().into_shared() -} - -pub(crate) fn lanczos3_interpolate( - x: SharedArray, - output_size: [usize; 2], - align_corners: bool, -) -> SharedArray { - fn lanczos3_weight(x: f64) -> f64 { - if x == 0.0 { - return 1.0; - } - let abs_x = x.abs(); - if 
abs_x >= 3.0 { - return 0.0; - } - let pi = core::f64::consts::PI; - let pi_x = pi * x; - let pi_x_over_3 = pi_x / 3.0; - (pi_x.sin() * pi_x_over_3.sin()) / (pi_x * pi_x_over_3) - } - - let x = x.into_dimensionality::().unwrap(); - - let (batch_size, channels, in_height, in_width) = x.dim(); - let [out_height, out_width] = output_size; - - let out_element_num = batch_size * channels * out_height * out_width; - let strides = ( - channels * out_height * out_width, - out_height * out_width, - out_width, - ); - - let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, out_element_num).for_each(|id| { - let (b, c, h, w) = ( - id / strides.0, - id % strides.0 / strides.1, - id % strides.1 / strides.2, - id % strides.2, - ); - - let (y_frac, x_frac) = if align_corners { - let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64); - let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64); - (y_ratio * h as f64, x_ratio * w as f64) - } else { - let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5; - let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5; - (y_frac, x_frac) - }; - - let y0 = y_frac.floor(); - let x0 = x_frac.floor(); - let max_h = (in_height - 1) as isize; - let max_w = (in_width - 1) as isize; - - // 6x6 separable Lanczos3 filter (skip out-of-bounds positions) - let mut result = 0.0; - let mut weight_sum = 0.0; - for ky in -2..=3 { - let yi = y0 as isize + ky; - if yi < 0 || yi > max_h { - continue; - } - let y_idx = yi as usize; - let wy = lanczos3_weight(y_frac - (y0 + ky as f64)); - for kx in -2..=3 { - let xi = x0 as isize + kx; - if xi < 0 || xi > max_w { - continue; - } - let x_idx = xi as usize; - let wx = lanczos3_weight(x_frac - (x0 + kx as f64)); - let w = wy * wx; - let pixel: f64 = x[(b, c, y_idx, x_idx)].elem(); - result += pixel * w; - weight_sum += w; - } - } - if weight_sum != 0.0 { - result /= weight_sum; - } - - unsafe { - let output = unsafe_shared_out.get(); - output[(b, c, h, w)] = result.elem(); - } - }); - }); - - output.into_dyn().into_shared() -} - -/// Sample an element of the source array with bilinear interpolation -/// -/// * `source` - The tensor to read from. 
Has shape (batch_size, channels, height, width) -/// * `b` - The batch to read from -/// * `c` - The channel to read from -/// * `x` - The x position to read in the array -/// * `y` - The y position to read in the array -/// * `x_max` - The max x position (inclusive) -/// * `y_max` - The max y position (inclusive) -/// -/// # Returns -/// -/// The interpolated value read from the array -pub(crate) fn bilinear_interpolate_single( - source: &ArrayBase>, - b: usize, - c: usize, - x: f64, - y: f64, - x_max: usize, - y_max: usize, -) -> f64 -where - E: FloatNdArrayElement, - S: DataOwned, -{ - let y0 = y.floor(); - let y1 = ceil_clamp(y, y_max); - let yw = y - y0; - - let x0 = x.floor(); - let x1 = ceil_clamp(x, x_max); - let xw = x - x0; - - let (x0, x1, y0, y1) = (x0 as usize, x1 as usize, y0 as usize, y1 as usize); - - let p_a = source[(b, c, y0, x0)].elem::() * (1.0 - xw) * (1.0 - yw); - let p_b = source[(b, c, y0, x1)].elem::() * xw * (1.0 - yw); - let p_c = source[(b, c, y1, x0)].elem::() * (1.0 - xw) * yw; - let p_d = source[(b, c, y1, x1)].elem::() * xw * yw; - - p_a + p_b + p_c + p_d -} diff --git a/crates/burn/src/ops/interpolate.rs b/crates/burn/src/ops/interpolate.rs new file mode 120000 index 00000000..25bb1da5 --- /dev/null +++ b/crates/burn/src/ops/interpolate.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/interpolate.rs \ No newline at end of file diff --git a/crates/burn/src/ops/macros.rs b/crates/burn/src/ops/macros.rs deleted file mode 100644 index b3ac4f94..00000000 --- a/crates/burn/src/ops/macros.rs +++ /dev/null @@ -1,107 +0,0 @@ -macro_rules! keepdim { - ( - $dim:expr, - $self:expr, - mean - ) => {{ - // Get shape first (via reference), then pass ownership to avoid clone - let mut shape = $self.shape().into_shape(); - shape[$dim] = 1; - let tensor: SharedArray = mean_dim($self, $dim); - NdArrayOps::reshape(tensor, shape) - }}; - ( - $dim:expr, - $self:expr, - sum - ) => {{ - // Get shape first (via reference), then pass ownership to avoid clone - let mut shape = $self.shape().into_shape(); - shape[$dim] = 1; - let tensor: SharedArray = sum_dim($self, $dim); - NdArrayOps::reshape(tensor, shape) - }}; - ( - $dim:expr, - $self:expr, - prod - ) => {{ - // Get shape first (via reference), then pass ownership to avoid clone - let mut shape = $self.shape().into_shape(); - shape[$dim] = 1; - let tensor: SharedArray = prod_dim($self, $dim); - NdArrayOps::reshape(tensor, shape) - }}; -} - -use burn_backend::ElementConversion; -pub(crate) use keepdim; -use ndarray::{Axis, Zip}; - -use crate::{SharedArray, element::NdArrayElement}; - -pub(crate) fn mean_dim(tensor: SharedArray, dim: usize) -> SharedArray { - tensor.mean_axis(Axis(dim)).unwrap().into_shared() -} - -pub(crate) fn sum_dim(tensor: SharedArray, dim: usize) -> SharedArray { - tensor.sum_axis(Axis(dim)).into_shared() -} - -pub(crate) fn prod_dim(tensor: SharedArray, dim: usize) -> SharedArray { - tensor - .fold_axis(Axis(dim), 1.elem::(), |acc, &x| acc.mul(x.elem())) - .into_shared() -} - -/// Generic cumulative operation function with closure-based operation. 
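// A minimal sketch (illustrative, not part of the upstream file) of the same
// scan in one dimension, assuming `f32` elements:
//
//     fn cumulative_1d(mut v: Vec<f32>, op: impl Fn(&mut f32, &f32)) -> Vec<f32> {
//         for i in 1..v.len() {
//             let prev = v[i - 1];
//             op(&mut v[i], &prev);
//         }
//         v
//     }
//
// With `op = |c, p| *c += *p`, `vec![1.0, 2.0, 3.0]` becomes `[1.0, 3.0, 6.0]`,
// which is what `cumsum_dim` computes along the chosen axis; the function below
// generalizes this to an arbitrary axis of an n-dimensional array.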
-pub(crate) fn cumulative_with_op(tensor: SharedArray, dim: usize, op: F) -> SharedArray -where - E: NdArrayElement, - F: Fn(&mut E, &E), -{ - let axis = Axis(dim); - let shape = tensor.shape().to_vec(); - // Use into_owned() instead of to_owned() - only copies if shared, avoids copy if unique - let mut result = tensor.into_owned(); - let dim_size = shape[dim]; - - for i in 1..dim_size { - let prev = result.index_axis(axis, i - 1).to_owned(); - let mut current = result.index_axis_mut(axis, i); - Zip::from(&mut current).and(&prev).for_each(&op); - } - - result.into_shared() -} - -// Define all cumulative operation functions using the generic function -pub(crate) fn cumsum_dim(tensor: SharedArray, dim: usize) -> SharedArray { - cumulative_with_op(tensor, dim, |c, &p| *c = c.add(p.elem())) -} - -pub(crate) fn cumprod_dim(tensor: SharedArray, dim: usize) -> SharedArray { - cumulative_with_op(tensor, dim, |c, &p| *c = c.mul(p.elem())) -} - -pub(crate) fn cummin_dim>( - tensor: SharedArray, - dim: usize, -) -> SharedArray { - cumulative_with_op(tensor, dim, |c, &p| { - if p < *c { - *c = p; - } - }) -} - -pub(crate) fn cummax_dim>( - tensor: SharedArray, - dim: usize, -) -> SharedArray { - cumulative_with_op(tensor, dim, |c, &p| { - if p > *c { - *c = p; - } - }) -} diff --git a/crates/burn/src/ops/macros.rs b/crates/burn/src/ops/macros.rs new file mode 120000 index 00000000..5ffa1c12 --- /dev/null +++ b/crates/burn/src/ops/macros.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/macros.rs \ No newline at end of file diff --git a/crates/burn/src/ops/matmul.rs b/crates/burn/src/ops/matmul.rs deleted file mode 100644 index 3fb7b467..00000000 --- a/crates/burn/src/ops/matmul.rs +++ /dev/null @@ -1,362 +0,0 @@ -use crate::UnsafeSharedRef; -use crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par}; - -use alloc::{vec, vec::Vec}; -use burn_backend::ElementConversion; -use burn_backend::Shape; -use ndarray::{IxDyn, s}; - -pub(crate) fn matmul( - lhs: SharedArray, - rhs: SharedArray, -) -> SharedArray { - let shape_lhs = lhs.shape(); - let shape_rhs = rhs.shape(); - let ndims = shape_lhs.num_dims(); - let m = shape_lhs[ndims - 2]; // # of left rows - let k = shape_rhs[ndims - 2]; // # of left cols and right rows - let n = shape_rhs[ndims - 1]; // # of right cols - - let (out_shape, strides_lhs, strides_rhs, strides_out) = output_shape(shape_lhs, shape_rhs); - let l_mat_size = m * k; // size of matrix component of left array - let r_mat_size = k * n; // size of matrix component of right array - let out_mat_size = m * n; // size of matrix component of output array - - let num_l_batches = shape_lhs.num_elements() / l_mat_size; - let num_r_batches = shape_rhs.num_elements() / r_mat_size; - let num_out_batches = out_shape.num_elements() / out_mat_size; - - let lhs_array = NdArrayOps::reshape(lhs, Shape::new([num_l_batches, m, k])); - let rhs_array = NdArrayOps::reshape(rhs, Shape::new([num_r_batches, k, n])); - - let alpha: E = 1.0.elem(); - let beta: E = 0.0.elem(); - - let out = run_par!(|| { - let mut out_array = ndarray::Array3::::zeros((num_out_batches, m, n)); - let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array); - - iter_range_par!(0, num_out_batches).for_each(|out_batch| { - // Here, we: - // 1. Un-flatten the output batch into a component-based batch index. - // 2. Use the strides for left and right batch indices to convert it to a flattened - // batch for left and right. 
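// Worked example (values match the rank-4 broadcast case in the tests at the
// bottom of this file): for lhs [1, 4, 5, 3] and rhs [8, 1, 3, 7], the batch
// strides come out as out = [4, 1], lhs = [0, 1], rhs = [1, 0]. For
// out_batch = 6:
//
//     unflatten(6)    with out strides [4, 1] -> [6 / 4, 2 / 1] = [1, 2]
//     flatten([1, 2]) with lhs strides [0, 1] -> 0 * 1 + 1 * 2  = 2
//     flatten([1, 2]) with rhs strides [1, 0] -> 1 * 1 + 0 * 2  = 1
//
// so output batch 6 multiplies lhs batch 2 by rhs batch 1; a stride of 0 is
// how a broadcast dimension keeps re-reading the same batch.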
-            let out_index = strides_out.unflatten(out_batch);
-            let l_batch = strides_lhs.flatten(&out_index);
-            let r_batch = strides_rhs.flatten(&out_index);
-
-            let lhs_slice = lhs_array.slice(s!(l_batch, .., ..));
-            let rhs_slice = rhs_array.slice(s!(r_batch, .., ..));
-
-            unsafe {
-                let mut out_slice = unsafe_shared_out_array
-                    .get()
-                    .slice_mut(s!(out_batch, .., ..));
-
-                ndarray::linalg::general_mat_mul(
-                    alpha,
-                    &lhs_slice,
-                    &rhs_slice,
-                    beta,
-                    &mut out_slice,
-                )
-            }
-        });
-
-        out_array.into_shared().into_dyn()
-    });
-
-    NdArrayOps::reshape(out, out_shape)
-}
-
-#[derive(Debug, PartialEq)]
-struct Strides {
-    strides: Vec<usize>,
-}
-impl Strides {
-    fn new(strides: Vec<usize>) -> Self {
-        Strides { strides }
-    }
-
-    fn unflatten(&self, linear_index: usize) -> Vec<usize> {
-        let mut coord = Vec::with_capacity(self.strides.len());
-        let mut rem = linear_index;
-        for stride in self.strides.iter() {
-            coord.push(rem / stride);
-            rem %= stride;
-        }
-        coord
-    }
-
-    fn flatten(&self, index: &Vec<usize>) -> usize {
-        assert_eq!(self.strides.len(), index.len());
-        self.strides
-            .iter()
-            .zip(index)
-            .map(|(stride, index)| stride * index)
-            .sum()
-    }
-}
-
-/// Compute the (broadcasted) output shape of matrix multiplication, along with strides for
-/// the non-matrix dimensions of all arrays.
-///
-/// # Arguments
-/// * `lsh`: Shape of the first (left-hand) matrix multiplication argument.
-/// * `rsh`: Shape of the second (right-hand) matrix multiplication argument.
-///
-/// # Panics
-/// * If the shapes have fewer than 2 dimensions.
-/// * If the matrix multiplication dimensions (last 2) are incompatible.
-/// * If any other dimension is not the same for both tensors, or equal to 1. (Any dimension where
-///   one dim is equal to 1 is broadcast.)
-fn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) {
-    let ndims = lsh.num_dims();
-    if ndims < 2 {
-        panic!("Matrix multiplication requires an array with at least 2 dimensions.");
-    }
-
-    // Fetch matrix dimensions and check compatibility.
-    let l_rows = lsh[ndims - 2];
-    let l_cols = lsh[ndims - 1];
-    let r_rows = rsh[ndims - 2];
-    let r_cols = rsh[ndims - 1];
-    if l_cols != r_rows {
-        panic!("Dimensions are incompatible for matrix multiplication.");
-    }
-    // Set matrix dimensions of the output shape.
-    let mut osh = vec![0; ndims];
-    osh[ndims - 2] = l_rows;
-    osh[ndims - 1] = r_cols;
-
-    // Set other array dimensions, broadcasting as necessary.
-    // Compute the strides inline.
-    let mut cur_l_stride: usize = 1;
-    let mut cur_r_stride: usize = 1;
-    let mut cur_o_stride: usize = 1;
-    let mut l_strides = Vec::with_capacity(ndims - 2);
-    let mut r_strides = Vec::with_capacity(ndims - 2);
-    let mut o_strides = Vec::with_capacity(ndims - 2);
-    for i in (0..ndims - 2).rev() {
-        let l_dim = lsh[i];
-        let r_dim = rsh[i];
-
-        // Compatible dimensions are:
-        // 1. Both dimensions are equal.
-        // 2. One of the dimensions is equal to 1.
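// Worked example (same shapes as the rank-3 broadcast tests below): for
// lsh = [1, 5, 3] and rsh = [4, 3, 7], the single batch dimension pairs
// l_dim = 1 with r_dim = 4, so o_dim = 4 and the left operand records a
// stride of 0, meaning every output batch reuses lhs batch 0.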
-        let o_dim: usize;
-        if l_dim == r_dim {
-            o_dim = l_dim; // both dimensions are equal
-            l_strides.push(cur_l_stride);
-            r_strides.push(cur_r_stride);
-        } else if l_dim == 1 {
-            o_dim = r_dim; // broadcast the left
-            l_strides.push(0);
-            r_strides.push(cur_r_stride);
-        } else if r_dim == 1 {
-            o_dim = l_dim; // broadcast the right
-            l_strides.push(cur_l_stride);
-            r_strides.push(0);
-        } else {
-            panic!("Dimensions differ and cannot be broadcasted.");
-        }
-        osh[i] = o_dim;
-        o_strides.push(cur_o_stride);
-        cur_o_stride *= o_dim;
-
-        cur_l_stride *= l_dim;
-        cur_r_stride *= r_dim;
-    }
-    l_strides.reverse();
-    r_strides.reverse();
-    o_strides.reverse();
-
-    (
-        Shape::from(osh),
-        Strides::new(l_strides),
-        Strides::new(r_strides),
-        Strides::new(o_strides),
-    )
-}
-
-pub(crate) fn cross<E: NdArrayElement>(
-    lhs: SharedArray<E>,
-    rhs: SharedArray<E>,
-    dim: usize,
-) -> SharedArray<E> {
-    let shape_lhs = lhs.shape();
-    let shape_rhs = rhs.shape();
-    let ndims = shape_lhs.num_dims();
-
-    // Broadcast the shapes except along dim
-    let mut broadcast_shape = vec![0; ndims];
-    for i in 0..ndims {
-        if i == dim {
-            broadcast_shape[i] = shape_lhs[i]; // already checked to be 3
-        } else {
-            let l = shape_lhs[i];
-            let r = shape_rhs[i];
-            if l == r {
-                broadcast_shape[i] = l;
-            } else if l == 1 {
-                broadcast_shape[i] = r;
-            } else if r == 1 {
-                broadcast_shape[i] = l;
-            } else {
-                panic!("Tensors are not broadcastable along dimension {}", i);
-            }
-        }
-    }
-
-    // Broadcast lhs and rhs
-    let lhs_broadcast = if shape_lhs == broadcast_shape.as_slice() {
-        lhs
-    } else {
-        NdArrayOps::expand(lhs, Shape::from(broadcast_shape.clone()))
-    };
-    let rhs_broadcast = if shape_rhs == broadcast_shape.as_slice() {
-        rhs
-    } else {
-        NdArrayOps::expand(rhs, Shape::from(broadcast_shape.clone()))
-    };
-
-    // Now, move dim to the last dimension
-    let mut perm = (0..ndims).collect::<Vec<usize>>();
-    perm.remove(dim);
-    perm.push(dim);
-
-    let lhs_permuted = NdArrayOps::permute(lhs_broadcast, &perm);
-    let rhs_permuted = NdArrayOps::permute(rhs_broadcast, &perm);
-
-    // Reshape to (*, 3)
-    let total_elements = lhs_permuted.shape().num_elements();
-    let batch_size = total_elements / 3;
-    let lhs_reshaped = NdArrayOps::reshape(lhs_permuted, Shape::new([batch_size, 3]));
-    let rhs_reshaped = NdArrayOps::reshape(rhs_permuted, Shape::new([batch_size, 3]));
-
-    // Compute cross product
-    let mut result = ndarray::ArrayD::<E>::zeros(IxDyn(&[batch_size, 3]));
-    for i in 0..batch_size {
-        let a1 = lhs_reshaped[IxDyn(&[i, 0])];
-        let a2 = lhs_reshaped[IxDyn(&[i, 1])];
-        let a3 = lhs_reshaped[IxDyn(&[i, 2])];
-        let b1 = rhs_reshaped[IxDyn(&[i, 0])];
-        let b2 = rhs_reshaped[IxDyn(&[i, 1])];
-        let b3 = rhs_reshaped[IxDyn(&[i, 2])];
-        result[IxDyn(&[i, 0])] = a2.mul(b3).sub(a3.mul(b2));
-        result[IxDyn(&[i, 1])] = a3.mul(b1).sub(a1.mul(b3));
-        result[IxDyn(&[i, 2])] = a1.mul(b2).sub(a2.mul(b1));
-    }
-
-    let result_shared = result.into_shared();
-
-    // Reshape back to the broadcast shape with dim at the end
-    let mut result_shape = broadcast_shape;
-    result_shape.remove(dim);
-    result_shape.push(3);
-    let result_reshaped = NdArrayOps::reshape(result_shared, Shape::from(result_shape));
-
-    // Permute back
-    let mut inv_perm = vec![0; ndims];
-    for (i, &p) in perm.iter().enumerate() {
-        inv_perm[p] = i;
-    }
-    NdArrayOps::permute(result_reshaped, &inv_perm)
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    impl Strides {
-        fn empty() -> Self {
-            Strides {
-                strides: Vec::with_capacity(0),
-            }
-        }
-    }
-
-    #[test]
-    fn test_output_shape() {
-        // plain matrix multiply
-        
assert_eq!( - output_shape(&[5, 3], &[3, 7]), - ( - Shape::from([5, 7]), - Strides::empty(), - Strides::empty(), - Strides::empty() - ) - ); - // matrix multiply with one extra stack dimension - assert_eq!( - output_shape(&[4, 5, 3], &[4, 3, 7]), - ( - Shape::from([4, 5, 7]), - Strides::new(vec![1]), - Strides::new(vec![1]), - Strides::new(vec![1]) - ) - ); - // rank 3, broadcast left - assert_eq!( - output_shape(&[1, 5, 3], &[4, 3, 7]), - ( - Shape::from([4, 5, 7]), - Strides::new(vec![0]), - Strides::new(vec![1]), - Strides::new(vec![1]) - ) - ); - // rank 3, broadcast right - assert_eq!( - output_shape(&[4, 5, 3], &[1, 3, 7]), - ( - Shape::from([4, 5, 7]), - Strides::new(vec![1]), - Strides::new(vec![0]), - Strides::new(vec![1]) - ) - ); - // rank 4, multi broadcast - assert_eq!( - output_shape(&[1, 4, 5, 3], &[8, 1, 3, 7]), - ( - Shape::from([8, 4, 5, 7]), - Strides::new(vec![0, 1]), - Strides::new(vec![1, 0]), - Strides::new(vec![4, 1]) - ) - ); - // rank 5, multi-broadcast - assert_eq!( - output_shape(&[1, 3, 4, 5, 3], &[8, 3, 1, 3, 7]), - ( - Shape::from([8, 3, 4, 5, 7]), - Strides::new(vec![0, 4, 1]), - Strides::new(vec![3, 1, 0]), - Strides::new(vec![12, 4, 1]) - ) - ) - } - - #[test] - #[should_panic( - expected = "Matrix multiplication requires an array with at least 2 dimensions." - )] - fn test_output_shape_too_small() { - output_shape(&[4], &[4]); - } - - #[test] - #[should_panic(expected = "Dimensions are incompatible for matrix multiplication.")] - fn test_output_shape_bad_matrix_dims() { - output_shape(&[5, 3], &[4, 7]); - } - - #[test] - #[should_panic(expected = "Dimensions differ and cannot be broadcasted.")] - fn test_output_shape_non_broadcast() { - output_shape(&[4, 5, 3], &[2, 3, 7]); - } -} diff --git a/crates/burn/src/ops/matmul.rs b/crates/burn/src/ops/matmul.rs new file mode 120000 index 00000000..44ce5ad9 --- /dev/null +++ b/crates/burn/src/ops/matmul.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/matmul.rs \ No newline at end of file diff --git a/crates/burn/src/ops/maxpool.rs b/crates/burn/src/ops/maxpool.rs deleted file mode 100644 index 2a162cf9..00000000 --- a/crates/burn/src/ops/maxpool.rs +++ /dev/null @@ -1,247 +0,0 @@ -use crate::{ - ShapeOps, SharedArray, - element::{FloatNdArrayElement, IntNdArrayElement}, - iter_range_par, - ops::padding::apply_padding_4d, - run_par, - sharing::UnsafeSharedRef, -}; - -use burn_backend::ElementConversion; -use burn_backend::ops::conv::calculate_pool_output_size; -use ndarray::Array4; - -pub(crate) fn max_pool2d( - x: SharedArray, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ceil_mode: bool, -) -> SharedArray { - let [kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - let [batch_size, channels, x_height, x_width] = x.shape().dims(); - let inf = (-f32::INFINITY).elem::(); - - let out_height = calculate_pool_output_size( - kernel_height, - stride_height, - padding_height, - dilation_height, - x_height, - ceil_mode, - ); - let out_width = calculate_pool_output_size( - kernel_width, - stride_width, - padding_width, - dilation_width, - x_width, - ceil_mode, - ); - - // Calculate extra padding needed for ceil_mode - // The maximum input position accessed is: (out_size - 1) * stride + (kernel_size - 1) * dilation - // This must be < input_size + 2 * total_padding - let max_ih = - (out_height.saturating_sub(1)) * 
stride_height + (kernel_height - 1) * dilation_height; - let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width; - let padded_height = x_height + 2 * padding_height; - let padded_width = x_width + 2 * padding_width; - let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1)); - let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1)); - let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w]; - - let x = apply_padding_4d::(x, total_padding, inf); - - // Offset to account for extra padding (extra_pad is added on both sides by apply_padding_4d) - let offset_h = extra_pad_h; - let offset_w = extra_pad_w; - - let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - - for oh in 0..out_height { - for ow in 0..out_width { - let mut max_val = inf; - - for kh in 0..kernel_height { - let ih = offset_h + oh * stride_height + kh * dilation_height; - - for kw in 0..kernel_width { - let iw = offset_w + ow * stride_width + kw * dilation_width; - - let val = x[[b, c, ih, iw]]; - - if val > max_val { - max_val = val; - } - } - } - - output[[b, c, oh, ow]] = max_val; - } - } - }) - }); - - output.into_dyn().into_shared() -} - -pub(crate) fn max_pool2d_with_indices( - x: SharedArray, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ceil_mode: bool, -) -> (SharedArray, SharedArray) { - let [kernel_height, kernel_width] = kernel_size; - let [padding_height, padding_width] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - let [batch_size, channels, x_height, x_width] = x.shape().dims(); - let inf = (-f32::INFINITY).elem::(); - - let out_height = calculate_pool_output_size( - kernel_height, - stride_height, - padding_height, - dilation_height, - x_height, - ceil_mode, - ); - let out_width = calculate_pool_output_size( - kernel_width, - stride_width, - padding_width, - dilation_width, - x_width, - ceil_mode, - ); - - // Calculate extra padding needed for ceil_mode - let max_ih = - (out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height; - let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width; - let padded_height = x_height + 2 * padding_height; - let padded_width = x_width + 2 * padding_width; - let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1)); - let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1)); - let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w]; - - let x = apply_padding_4d::(x, total_padding, inf); - - // Offset to account for extra padding - let offset_h = extra_pad_h; - let offset_w = extra_pad_w; - - let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); - let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); - - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - let indices = 
unsafe_shared_indices.get(); - - for oh in 0..out_height { - for ow in 0..out_width { - let mut max_val = inf; - let mut index = 0; - - for kh in 0..kernel_height { - let ih = offset_h + oh * stride_height + kh * dilation_height; - - for kw in 0..kernel_width { - let iw = offset_w + ow * stride_width + kw * dilation_width; - let val = x[[b, c, ih, iw]]; - - if val > max_val { - max_val = val; - - // Calculate index in original (unpadded) input - let ih_orig = ih as i64 - (total_padding[0]) as i64; - let iw_orig = iw as i64 - (total_padding[1]) as i64; - - // Clamp to valid range for index calculation - let ih_clamped = ih_orig.max(0).min(x_height as i64 - 1); - let iw_clamped = iw_orig.max(0).min(x_width as i64 - 1); - - index = ih_clamped * x_width as i64 + iw_clamped; - } - } - } - - output[[b, c, oh, ow]] = max_val; - indices[[b, c, oh, ow]] = index.elem(); - } - } - }) - }); - - let output = output.into_dyn().into_shared(); - let indices = indices.into_dyn().into_shared(); - - (output, indices) -} - -#[allow(clippy::too_many_arguments)] -pub(crate) fn max_pool2d_backward( - x: SharedArray, - _kernel_size: [usize; 2], - _stride: [usize; 2], - _padding: [usize; 2], - _dilation: [usize; 2], - _ceil_mode: bool, - output_grad: SharedArray, - indices: SharedArray, -) -> SharedArray { - let [_batch_size, _channels, height, width] = output_grad.shape().dims(); - let [batch_size, channels, height_x, width_x] = x.shape().dims(); - - let output_grad = output_grad; - let indices = indices; - - let mut output = Array4::zeros((batch_size, channels, height_x, width_x)); - - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - run_par!(|| { - iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { - let b = k / channels; - let c = k % channels; - - let output = unsafe_shared_out.get(); - - for h in 0..height { - for w in 0..width { - let index = indices[[b, c, h, w]].elem::(); - let grad = output_grad[[b, c, h, w]]; - - let index_h = index as usize / width_x; - let index_w = index as usize % width_x; - - output[[b, c, index_h, index_w]] += grad; - } - } - }); - }); - - output.into_dyn().into_shared() -} diff --git a/crates/burn/src/ops/maxpool.rs b/crates/burn/src/ops/maxpool.rs new file mode 120000 index 00000000..9764cf27 --- /dev/null +++ b/crates/burn/src/ops/maxpool.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/maxpool.rs \ No newline at end of file diff --git a/crates/burn/src/ops/mod.rs b/crates/burn/src/ops/mod.rs deleted file mode 100644 index f4f215ec..00000000 --- a/crates/burn/src/ops/mod.rs +++ /dev/null @@ -1,24 +0,0 @@ -mod activation; -mod base; -mod bool_tensor; -mod int_tensor; -mod module; -mod qtensor; -#[cfg(feature = "simd")] -mod simd; -mod tensor; -mod transaction; - -pub(crate) mod adaptive_avgpool; -pub(crate) mod avgpool; -pub(crate) mod conv; -pub(crate) mod deform_conv; -pub(crate) mod grid_sample; -pub(crate) mod interpolate; -pub(crate) mod macros; -pub(crate) mod matmul; -pub(crate) mod maxpool; -pub(crate) mod padding; -pub(crate) mod quantization; - -pub(crate) use base::*; diff --git a/crates/burn/src/ops/mod.rs b/crates/burn/src/ops/mod.rs new file mode 120000 index 00000000..e839b5bd --- /dev/null +++ b/crates/burn/src/ops/mod.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/mod.rs \ No newline at end of file diff --git a/crates/burn/src/ops/module.rs b/crates/burn/src/ops/module.rs deleted file mode 100644 index a7d7e27a..00000000 --- a/crates/burn/src/ops/module.rs +++ /dev/null @@ -1,381 +0,0 @@ -use super::{ - 
adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward}, - avgpool::{avg_pool2d, avg_pool2d_backward}, - conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d}, - deform_conv::{backward::deform_conv2d_backward, deform_conv2d}, - interpolate::{ - bicubic_interpolate, bilinear_interpolate, lanczos3_interpolate, nearest_interpolate, - }, - maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices}, -}; -#[cfg(feature = "simd")] -use crate::ops::simd::{ - avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd, -}; -use crate::{ - NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype, - tensor::NdArrayTensor, -}; -use crate::{ - element::{IntNdArrayElement, QuantElement}, - ops::interpolate::nearest_interpolate_backward, -}; -use burn_backend::{ - ElementConversion, TensorMetadata, - ops::{attention::attention_fallback, *}, - tensor::FloatTensor, -}; - -macro_rules! module_op { - // Module op with inputs (inp), optional (opt) and arguments (args). - // Converts NdArrayStorage to SharedArray for compatibility with existing operations. - (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{ - #[allow(unused_parens, unreachable_patterns)] - match ($($x),+) { - ($(NdArrayTensor::F32($x)),+) => { - type $element = f32; - $op( - $($x.into_shared()),+ - $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))* - ) - } - ($(NdArrayTensor::F64($x)),+) => { - type $element = f64; - $op( - $($x.into_shared()),+ - $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))* - ) - } - _ => panic!("Data type mismatch"), - } - }}; -} - -impl ModuleOps - for NdArray -where - NdArrayTensor: From>, - NdArrayTensor: From>, -{ - fn conv2d( - x: NdArrayTensor, - weight: NdArrayTensor, - bias: Option, - options: ConvOptions<2>, - ) -> NdArrayTensor { - module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| { - #[cfg(feature = "simd")] - let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) { - Ok(out) => return out.into(), - Err(args) => args, - }; - conv2d::(x, weight, bias, options).into() - }) - } - - fn deform_conv2d( - x: FloatTensor, - offset: FloatTensor, - weight: FloatTensor, - mask: Option>, - bias: Option>, - options: DeformConvOptions<2>, - ) -> FloatTensor { - module_op!( - inp(x, offset, weight), - opt(mask, bias), - E, - |x, offset, weight, mask, bias| deform_conv2d::( - x, offset, weight, mask, bias, options - ) - .into() - ) - } - - fn deform_conv2d_backward( - x: FloatTensor, - offset: FloatTensor, - weight: FloatTensor, - mask: Option>, - bias: Option>, - output_grad: FloatTensor, - options: DeformConvOptions<2>, - ) -> DeformConv2dBackward { - module_op!( - inp(x, offset, weight, output_grad), - opt(mask, bias), - E, - |x, offset, weight, output_grad, mask, bias| { - let (x, offset, weight, mask, bias) = deform_conv2d_backward::( - x, - offset, - weight, - mask, - bias, - output_grad, - options, - ); - DeformConv2dBackward::new( - x.into(), - offset.into(), - weight.into(), - mask.map(|m| m.into()), - bias.map(|b| b.into()), - ) - } - ) - } - - fn conv_transpose2d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<2>, - ) -> FloatTensor { - module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| { - conv_transpose2d::(x, weight, bias, options).into() - }) - } - - fn avg_pool2d( - x: FloatTensor, - 
kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ceil_mode: bool, - ) -> FloatTensor { - module_op!(inp(x), opt(), E, |x| { - #[cfg(feature = "simd")] - let x = match if ceil_mode { - // SIMD path doesn't support ceil_mode yet, skip it - Err(x) - } else { - try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad) - } { - Ok(out) => return out.into(), - Err(x) => x, - }; - avg_pool2d::( - x, - kernel_size, - stride, - padding, - count_include_pad, - ceil_mode, - ) - .into() - }) - } - - fn avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - count_include_pad: bool, - ceil_mode: bool, - ) -> FloatTensor { - module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::( - x, - grad, - kernel_size, - stride, - padding, - count_include_pad, - ceil_mode - ) - .into()) - } - - fn max_pool2d( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ceil_mode: bool, - ) -> FloatTensor { - module_op!(inp(x), opt(), E, |x| { - #[cfg(feature = "simd")] - let x = match if ceil_mode { - // SIMD path doesn't support ceil_mode yet, skip it - Err(x) - } else { - try_max_pool2d_simd(x, kernel_size, stride, padding, dilation) - } { - Ok(out) => return out.into(), - Err(x) => x, - }; - max_pool2d::(x, kernel_size, stride, padding, dilation, ceil_mode).into() - }) - } - - fn max_pool2d_with_indices( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ceil_mode: bool, - ) -> MaxPool2dWithIndices> { - module_op!(inp(x), opt(), E, |x| { - let (output, indices) = max_pool2d_with_indices::( - x, - kernel_size, - stride, - padding, - dilation, - ceil_mode, - ); - MaxPool2dWithIndices::new(output.into(), indices.into()) - }) - } - - fn max_pool2d_with_indices_backward( - x: FloatTensor, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ceil_mode: bool, - output_grad: FloatTensor, - indices: NdArrayTensor, - ) -> MaxPool2dBackward> { - execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray| { - // Convert indices from runtime dtype to the expected I type - // (pool indices are bounded by tensor dimensions, so conversion is safe) - let indices: SharedArray = idx_s.mapv(|x| x.elem()).into_shared(); - module_op!(inp(x, output_grad), opt(), E, |x, output_grad| { - let output = max_pool2d_backward::( - x, - kernel_size, - stride, - padding, - dilation, - ceil_mode, - output_grad, - indices, - ); - MaxPool2dBackward::new(output.into()) - }) - }) - } - - fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { - module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::( - x, - output_size - ) - .into()) - } - - fn adaptive_avg_pool2d_backward( - x: FloatTensor, - grad: FloatTensor, - ) -> FloatTensor { - module_op!(inp(x, grad), opt(), E, |x, grad| { - adaptive_avg_pool2d_backward::(x, grad).into() - }) - } - - fn interpolate( - x: FloatTensor, - output_size: [usize; 2], - options: InterpolateOptions, - ) -> FloatTensor { - match options.mode { - InterpolateMode::Nearest => { - module_op!(inp(x), opt(), E, |x| nearest_interpolate::( - x, - output_size - ) - .into()) - } - InterpolateMode::Bilinear => { - let align_corners = options.align_corners; - module_op!(inp(x), opt(), E, |x| bilinear_interpolate::( - x, - output_size, - align_corners - ) - .into()) - } - 
InterpolateMode::Bicubic => { - let align_corners = options.align_corners; - module_op!(inp(x), opt(), E, |x| bicubic_interpolate::( - x, - output_size, - align_corners - ) - .into()) - } - InterpolateMode::Lanczos3 => { - let align_corners = options.align_corners; - module_op!(inp(x), opt(), E, |x| lanczos3_interpolate::( - x, - output_size, - align_corners - ) - .into()) - } - } - } - - fn interpolate_backward( - x: FloatTensor, - grad: FloatTensor, - output_size: [usize; 2], - options: InterpolateOptions, - ) -> FloatTensor { - match options.mode { - InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| { - nearest_interpolate_backward::(x, grad, output_size).into() - }), - InterpolateMode::Bilinear => { - panic!("bilinear interpolation backward is not supported for ndarray backend") - } - InterpolateMode::Bicubic => { - panic!("bicubic interpolation backward is not supported for ndarray backend") - } - InterpolateMode::Lanczos3 => { - panic!("lanczos3 interpolation backward is not supported for ndarray backend") - } - } - } - - fn conv3d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvOptions<3>, - ) -> FloatTensor { - module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::( - x, weight, bias, options - ) - .into()) - } - - fn conv_transpose3d( - x: FloatTensor, - weight: FloatTensor, - bias: Option>, - options: ConvTransposeOptions<3>, - ) -> FloatTensor { - module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| { - conv_transpose3d::(x, weight, bias, options).into() - }) - } - - fn attention( - query: FloatTensor, - key: FloatTensor, - value: FloatTensor, - mask: Option>, - attn_bias: Option>, - options: AttentionModuleOptions, - ) -> FloatTensor { - attention_fallback::(query, key, value, mask, attn_bias, options) - } -} diff --git a/crates/burn/src/ops/module.rs b/crates/burn/src/ops/module.rs new file mode 120000 index 00000000..fd8616f8 --- /dev/null +++ b/crates/burn/src/ops/module.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/module.rs \ No newline at end of file diff --git a/crates/burn/src/ops/padding.rs b/crates/burn/src/ops/padding.rs deleted file mode 100644 index d9c6fd3a..00000000 --- a/crates/burn/src/ops/padding.rs +++ /dev/null @@ -1,72 +0,0 @@ -use crate::{NdArrayElement, SharedArray}; -use ndarray::{Array4, Array5}; - -use super::NdArrayOps; - -pub(crate) fn apply_padding_4d( - x: SharedArray, - padding: [usize; 2], - elem: E, -) -> SharedArray { - let [batch_size, input_channels, height, width] = x.shape().try_into().unwrap(); - let [padding_height, padding_width] = padding; - let padded_height = height + 2 * padding_height; - let padded_width = width + 2 * padding_width; - - let x_new = Array4::from_elem( - (batch_size, input_channels, padded_height, padded_width), - elem, - ); - let mut x_new = x_new.into_shared().into_dyn(); - - x_new = NdArrayOps::slice_assign( - x_new, - &[ - burn_backend::Slice::from(0..batch_size), - burn_backend::Slice::from(0..input_channels), - burn_backend::Slice::from(padding_height..height + padding_height), - burn_backend::Slice::from(padding_width..width + padding_width), - ], - x, - ); - - x_new -} - -pub(crate) fn apply_padding_5d( - x: SharedArray, - padding: [usize; 3], - elem: E, -) -> SharedArray { - let [batch_size, input_channels, depth, height, width] = x.shape().try_into().unwrap(); - let [padding_depth, padding_height, padding_width] = padding; - let padded_depth = depth + 2 * padding_depth; - let padded_height = height + 2 * padding_height; - let 
padded_width = width + 2 * padding_width; - - let x_new = Array5::from_elem( - ( - batch_size, - input_channels, - padded_depth, - padded_height, - padded_width, - ), - elem, - ); - let mut x_new = x_new.into_shared().into_dyn(); - - x_new = NdArrayOps::slice_assign( - x_new, - &[ - burn_backend::Slice::from(0..batch_size), - burn_backend::Slice::from(0..input_channels), - burn_backend::Slice::from(padding_depth..depth + padding_depth), - burn_backend::Slice::from(padding_height..height + padding_height), - burn_backend::Slice::from(padding_width..width + padding_width), - ], - x, - ); - - x_new -} diff --git a/crates/burn/src/ops/padding.rs b/crates/burn/src/ops/padding.rs new file mode 120000 index 00000000..7a6c0a1d --- /dev/null +++ b/crates/burn/src/ops/padding.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/padding.rs \ No newline at end of file diff --git a/crates/burn/src/ops/qtensor.rs b/crates/burn/src/ops/qtensor.rs deleted file mode 100644 index a7210fc8..00000000 --- a/crates/burn/src/ops/qtensor.rs +++ /dev/null @@ -1,353 +0,0 @@ -use alloc::{vec, vec::Vec}; - -use burn_backend::{ - DType, ExecutionError, Shape, TensorData, TensorMetadata, - ops::QTensorOps, - quantization::{ - QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue, - QuantizationParametersPrimitive, QuantizedBytes, - }, - tensor::{FloatTensor, IntTensor, QuantizedTensor}, -}; -use burn_std::{FloatDType, IntDType}; - -use crate::{ - FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray, - element::{IntNdArrayElement, QuantElement}, - execute_with_dtype, execute_with_int_dtype, execute_with_int_out_dtype, - execute_with_numeric_dtype, slice, -}; - -use super::quantization::{QuantizationStrategy, SymmetricQuantization}; -use super::{NdArrayMathOps, NdArrayOps}; - -impl QTensorOps - for NdArray -where - NdArrayTensor: From>, - NdArrayTensor: From>, -{ - fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor { - match data.dtype { - DType::QFloat(scheme) => { - let shape = data.shape.clone(); - let num_elements = data.num_elements(); - let q_bytes = QuantizedBytes { - bytes: data.into_bytes(), - scheme, - num_elements, - }; - - match scheme { - QuantScheme { - level: QuantLevel::Tensor | QuantLevel::Block(_), - mode: QuantMode::Symmetric, - value: QuantValue::Q8F | QuantValue::Q8S, - .. - } => { - // We can load QuantStore::U32 w/ QuantizedBytes impl - let (values, qparams) = q_bytes.into_vec_i8(); - let data = TensorData::new(values, shape); - // Overwrite storage - let scheme = scheme.with_store(QuantStore::Native); - - let qparams = qparams - .scales - .into_iter() - .map(|scales| QParams { scales }) - .collect(); - - NdArrayQTensor { - qtensor: NdArrayTensor::from_data(data), - scheme, - qparams, - } - } - QuantScheme { - value: - QuantValue::Q4F - | QuantValue::Q4S - | QuantValue::Q2F - | QuantValue::Q2S - | QuantValue::E2M1 - | QuantValue::E4M3 - | QuantValue::E5M2, - .. - } => unimplemented!("from_data not supported for scheme {scheme:?}"), - } - } - _ => panic!( - "Invalid dtype (expected DType::QFloat, got {:?})", - data.dtype - ), - } - } - - fn quantize( - tensor: FloatTensor, - scheme: &QuantScheme, - qparams: QuantizationParametersPrimitive, - ) -> QuantizedTensor { - let shape = tensor.shape(); - let data_f = tensor.into_data(); - let scales = qparams.scales.into_data().convert::(); - - // Implement with ndarray instead of QuantizationStrategy? 
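// A minimal sketch of the per-tensor symmetric path below, assuming the usual
// int8 symmetric mapping q = round(x / scale) clamped to the i8 range, with
// dequantization x ~= q * scale (the helper name is illustrative, not an API
// of this crate):
//
//     fn quantize_symmetric_q8(values: &[f32], scale: f32) -> Vec<i8> {
//         values
//             .iter()
//             .map(|&x| (x / scale).round().clamp(-128.0, 127.0) as i8)
//             .collect()
//     }
//
// The SymmetricQuantization strategy built below carries the same `scale`, so
// quantize and dequantize stay consistent.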
- let (data, qparams) = match scheme { - QuantScheme { - level: QuantLevel::Tensor, - mode: QuantMode::Symmetric, - #[cfg(not(feature = "export_tests"))] - value: QuantValue::Q8F | QuantValue::Q8S, - // For tests, "native" sub-byte quant serves as a reference for value equality. - // Values are stored as i8 regardless. - #[cfg(feature = "export_tests")] - value: - QuantValue::Q8F - | QuantValue::Q8S - | QuantValue::Q4F - | QuantValue::Q4S - | QuantValue::Q2F - | QuantValue::Q2S, - store: QuantStore::Native, - .. - } => { - let scales = scales.iter().next().unwrap(); - let strategy = QuantizationStrategy::PerTensorSymmetric( - SymmetricQuantization::init(scales, scheme.value), - ); - let values = strategy.quantize(data_f.as_slice().unwrap()); - ( - TensorData::quantized(values, shape.clone(), *scheme, &[scales]), - vec![QParams { scales }], - ) - } - QuantScheme { - level: QuantLevel::Block(block_size), - mode: QuantMode::Symmetric, - #[cfg(not(feature = "export_tests"))] - value: QuantValue::Q8F | QuantValue::Q8S, - #[cfg(feature = "export_tests")] - value: - QuantValue::Q8F - | QuantValue::Q8S - | QuantValue::Q4F - | QuantValue::Q4S - | QuantValue::Q2F - | QuantValue::Q2S, - store: QuantStore::Native, - .. - } => { - let scales = scales.as_slice().unwrap(); - let (strategy, qparams) = scales - .iter() - .map(|&s| { - ( - SymmetricQuantization::init(s, scheme.value), - QParams { scales: s }, - ) - }) - .unzip(); - let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size); - let values = strategy.quantize(data_f.as_slice().unwrap()); - ( - TensorData::quantized(values, shape.clone(), *scheme, scales), - qparams, - ) - } - scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"), - }; - - let num_elements = data.num_elements(); - let q_bytes = QuantizedBytes { - bytes: data.into_bytes(), - scheme: *scheme, - num_elements, - }; - let (values, _) = q_bytes.into_vec_i8(); - let data = TensorData::new(values, shape).convert::(); - - NdArrayQTensor { - qtensor: NdArrayTensor::from_data(data), - scheme: *scheme, - qparams, - } - } - - fn dequantize(tensor: QuantizedTensor, dtype: FloatDType) -> FloatTensor { - let strategy = tensor.strategy(); - let scheme = tensor.scheme; - let shape = tensor.shape(); - let data = match tensor.qtensor { - NdArrayTensor::I8(storage) => { - let data = storage.into_shared().into_iter().collect(); - dequantize(data, shape, scheme, &strategy, dtype.into()) - } - _ => unreachable!(), - }; - NdArrayTensor::from_data(data) - } - - fn q_device(_tensor: &QuantizedTensor) -> NdArrayDevice { - NdArrayDevice::Cpu - } - - fn q_to_device( - tensor: QuantizedTensor, - _device: &NdArrayDevice, - ) -> QuantizedTensor { - tensor - } - - fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { - NdArrayQTensor { - qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { - NdArrayOps::reshape(array, shape) - }), - scheme: tensor.scheme, - qparams: tensor.qparams, - } - } - - async fn q_into_data(tensor: QuantizedTensor) -> Result { - let shape = tensor.qtensor.shape(); - let scales = tensor.qparams.iter().map(|q| q.scales).collect::>(); - Ok(execute_with_numeric_dtype!( - tensor.qtensor, - E, - |array: SharedArray| { - let values = array.into_iter().collect(); - TensorData::quantized(values, shape, tensor.scheme, &scales) - } - )) - } - - fn q_swap_dims( - tensor: QuantizedTensor, - dim1: usize, - dim2: usize, - ) -> QuantizedTensor { - NdArrayQTensor { - qtensor: execute_with_dtype!(tensor.qtensor, E, |array: 
SharedArray| { - NdArrayOps::swap_dims(array, dim1, dim2) - }), - scheme: tensor.scheme, - qparams: tensor.qparams, - } - } - - fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { - NdArrayQTensor { - qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { - NdArrayOps::permute(array, axes) - }), - scheme: tensor.scheme, - qparams: tensor.qparams, - } - } - - fn q_flip(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { - NdArrayQTensor { - qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { - NdArrayOps::flip(array, axes) - }), - scheme: tensor.scheme, - qparams: tensor.qparams, - } - } - - fn q_gather( - dim: usize, - tensor: QuantizedTensor, - indices: IntTensor, - ) -> QuantizedTensor { - let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray< - IntElem, - >| - -> NdArrayTensor { - execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { - NdArrayOps::gather(dim, array, idx_array) - }) - }); - NdArrayQTensor { - qtensor, - scheme: tensor.scheme, - qparams: tensor.qparams, - } - } - - fn q_select( - tensor: QuantizedTensor, - dim: usize, - indices: IntTensor, - ) -> QuantizedTensor { - let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray< - IntElem, - >| - -> NdArrayTensor { - execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { - NdArrayMathOps::select(array, dim, idx_array) - }) - }); - NdArrayQTensor { - qtensor, - scheme: tensor.scheme, - qparams: tensor.qparams, - } - } - - fn q_slice( - tensor: QuantizedTensor, - slices: &[burn_backend::Slice], - ) -> QuantizedTensor { - NdArrayQTensor { - qtensor: slice!(tensor.qtensor, slices), - scheme: tensor.scheme, - qparams: tensor.qparams, - } - } - - fn q_argmax(tensor: QuantizedTensor, dim: usize, out_dtype: IntDType) -> IntTensor { - execute_with_int_out_dtype!(out_dtype, I, { - execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { - NdArrayMathOps::argmax::(array, dim) - }) - }) - } - - fn q_argmin(tensor: QuantizedTensor, dim: usize, out_dtype: IntDType) -> IntTensor { - execute_with_int_out_dtype!(out_dtype, I, { - execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { - NdArrayMathOps::argmin::(array, dim) - }) - }) - } - - fn q_expand(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { - NdArrayQTensor { - qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { - NdArrayOps::expand(array, shape) - }), - scheme: tensor.scheme, - qparams: tensor.qparams, - } - } -} - -fn dequantize( - data: Vec, - shape: Shape, - scheme: QuantScheme, - strategy: &QuantizationStrategy, - dtype: DType, -) -> TensorData { - let qparams = match strategy { - QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale], - QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => { - quant.iter().map(|q| q.scale).collect() - } - }; - let q_bytes = QuantizedBytes::new(data, scheme, &qparams); - let (values, _qparams) = q_bytes.into_vec_i8(); - TensorData::new(strategy.dequantize(&values), shape).convert_dtype(dtype) -} diff --git a/crates/burn/src/ops/qtensor.rs b/crates/burn/src/ops/qtensor.rs new file mode 120000 index 00000000..d3d98cc7 --- /dev/null +++ b/crates/burn/src/ops/qtensor.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/qtensor.rs \ No newline at end of file diff --git a/crates/burn/src/ops/quantization.rs b/crates/burn/src/ops/quantization.rs deleted file mode 100644 index adaf1b16..00000000 --- 
a/crates/burn/src/ops/quantization.rs +++ /dev/null @@ -1,218 +0,0 @@ -use alloc::vec::Vec; -use num_traits::{Float, PrimInt}; - -use burn_backend::quantization::{BlockSize, QuantValue}; - -// NOTE: this mainly serves as a simple reference implementation. -// The de/quantization ops should be refactored to use ndarray. - -/// Quantization strategy. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum QuantizationStrategy { - /// Per-tensor symmetric quantization. - PerTensorSymmetric(SymmetricQuantization), - /// Per-block symmetric quantization. - PerBlockSymmetric(Vec>, BlockSize), -} - -impl QuantizationStrategy { - /// Quantize the values to a lower precision data type. - pub fn quantize(&self, values: &[f32]) -> Vec { - match self { - QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.quantize(values), - QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => { - let block_elems = block_size.num_elements(); - let num_blocks = strategy.len(); - let numel = values.len(); - assert_eq!( - numel / block_elems, - num_blocks, - "Invalid per-block quantization with num blocks {num_blocks} and {numel} values" - ); - values - .chunks(block_elems) - .enumerate() - .flat_map(|(block_id, block)| strategy[block_id].quantize(block)) - .collect() - } - } - } - - /// Dequantize the values to a higher precision data type. - pub fn dequantize(&self, values: &[i8]) -> Vec { - match self { - QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.dequantize(values), - QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => { - let block_elems = block_size.num_elements(); - let num_blocks = strategy.len(); - let numel = values.len(); - assert_eq!( - numel / block_elems, - num_blocks, - "Invalid per-block quantization with block size {block_elems}, num blocks {num_blocks} and {numel} values" - ); - values - .chunks(block_elems) - .enumerate() - .flat_map(|(block_id, block)| strategy[block_id].dequantize(block)) - .collect() - } - } - } -} - -/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision -/// data type `Q` and vice-versa. -pub trait Quantization { - /// Returns the quantization range `[a, b]`. - fn range(&self) -> (E, E); - /// Convert the values to a lower precision data type. - fn quantize(&self, values: &[E]) -> Vec; - /// Convert a single value to a lower precision data type. - fn quantize_one(&self, value: E) -> Q; - /// Convert the values back to a higher precision data type. - fn dequantize(&self, values: &[Q]) -> Vec; - /// Convert a single value back to a higher precision data type. - fn dequantize_one(&self, value: Q) -> E; -} - -fn valid_scale(mut scale: E) -> E { - // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the - // scale to 0.1 to avoid division by zero. - if scale.eq(&E::zero()) { - scale = E::from(0.1).unwrap(); - } - scale -} - -/// Symmetric quantization scheme. -#[derive(Debug, Clone, Copy)] -pub struct SymmetricQuantization { - /// The scaling factor. - pub scale: E, - // The quantization value data type. - value: QuantValue, -} - -impl SymmetricQuantization { - /// Initialize a symmetric quantization scheme with the given parameters. - pub fn init(scale: E, value: QuantValue) -> Self { - Self { - scale: valid_scale(scale), - value, - } - } - - #[allow(dead_code)] - /// Create a new quantization scheme for an input range `[alpha, beta]`. 
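// Worked example of the scale computed below, assuming Q8S with range [-127, 127]:
// for the input range [-1.8, 0.5], alpha = max(|-1.8|, |0.5|) = 1.8, so
// scale = (1.8 + 1.8) / (127 - (-127)) = 3.6 / 254 ≈ 0.0141732; then
// quantize_one(-1.8) = round(-1.8 / 0.0141732) = -127 and
// dequantize_one(-71) = -71 * 0.0141732 ≈ -1.0062993, matching the expected values in
// test_int8_symmetric_quantization further down.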
- fn new(alpha: E, beta: E, value: QuantValue) -> Self { - let (a, b) = value.range(); - let a = E::from(a).unwrap(); - let b = E::from(b).unwrap(); - - // Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range - let alpha = alpha.abs().max(beta.abs()); - let scale = valid_scale((alpha + alpha) / (b - a)); - Self { scale, value } - } -} - -impl Quantization for SymmetricQuantization { - fn quantize(&self, values: &[E]) -> Vec { - values.iter().map(|x| self.quantize_one(*x)).collect() - } - - fn dequantize(&self, values: &[Q]) -> Vec { - values.iter().map(|x_q| self.dequantize_one(*x_q)).collect() - } - - fn quantize_one(&self, value: E) -> Q { - let (a, b) = self.range(); - - // x_q = clamp(round(x / scale), a, b) - Q::from(value.div(self.scale).round().clamp(a, b)).unwrap() - } - - fn dequantize_one(&self, value: Q) -> E { - // x = scale * x_q - self.scale * E::from(value).unwrap() - } - - fn range(&self) -> (E, E) { - let (a, b) = self.value.range(); - let a = E::from(a).unwrap(); - let b = E::from(b).unwrap(); - (a, b) - } -} - -impl PartialEq for SymmetricQuantization { - fn eq(&self, other: &Self) -> bool { - self.scale == other.scale - } -} - -impl Eq for SymmetricQuantization {} - -#[cfg(test)] -mod tests { - use burn_backend::TensorData; - - use super::*; - use alloc::vec; - - #[test] - fn test_int8_symmetric_quantization() { - let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5]; - let expected_q = vec![-127, -71, 0, 35]; - let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063]; - - let symmetric = SymmetricQuantization::::new(-1.8, 0.5, QuantValue::Q8S); - - let q: Vec = symmetric.quantize(&x); - assert_eq!(q, expected_q); - - let d = symmetric.dequantize(&expected_q); - - assert_eq!(d, expected_d); - } - - #[test] - fn test_int8_symmetric_quantization_per_block() { - let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5]; - let expected_q = vec![-127, -71, 0, 35, -127, -71, 0, 35]; - let expected_d = vec![ - -1.8, -1.0062993, 0.0, 0.496063, -1.8, -1.0062993, 0.0, 0.496063, - ]; - - let symmetric = SymmetricQuantization::::new(-1.8, 0.5, QuantValue::Q8S); - let strategy = QuantizationStrategy::PerBlockSymmetric( - vec![symmetric, symmetric], - BlockSize::new([4]), - ); - - let q: Vec = strategy.quantize(&x); - assert_eq!(q, expected_q); - - let d = symmetric.dequantize(&expected_q); - - assert_eq!(d, expected_d); - } - - #[test] - fn should_support_dequantize() { - let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization { - scale: 0.1, - value: QuantValue::Q8S, - }); - - let output = strategy.dequantize(&[-127i8, -77, -26, 25, 76, 127]); - - let output = TensorData::new(output, [2, 3]); - - output.assert_approx_eq::( - &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]), - Default::default(), - ); - } -} diff --git a/crates/burn/src/ops/quantization.rs b/crates/burn/src/ops/quantization.rs new file mode 120000 index 00000000..4eb1748f --- /dev/null +++ b/crates/burn/src/ops/quantization.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/quantization.rs \ No newline at end of file diff --git a/crates/burn/src/ops/simd/avgpool.rs b/crates/burn/src/ops/simd/avgpool.rs deleted file mode 100644 index 41d5ba61..00000000 --- a/crates/burn/src/ops/simd/avgpool.rs +++ /dev/null @@ -1,443 +0,0 @@ -use core::{marker::PhantomData, mem::transmute}; - -use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef}; - -use burn_backend::DType; -use burn_backend::{Element, ElementConversion}; -use 
bytemuck::Zeroable; -use macerator::{Simd, VAdd, VDiv}; -use ndarray::{Array4, s}; -use nhwc::avg_pool_nhwc; - -use super::should_use_simd; - -#[macerator::with_simd] -fn is_accelerated(_x: PhantomData) -> bool { - ::is_accelerated::() && ::is_accelerated::() -} - -pub(crate) fn try_avg_pool2d_simd( - x: SharedArray, - ksize: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - with_pad: bool, -) -> Result, SharedArray> { - // Strides must be unit, dilation isn't supported, rows must be contiguous - if x.strides()[1] != 1 || !should_use_simd(x.shape()[1]) { - return Err(x); - } - - match E::dtype() { - DType::F64 if is_accelerated::(PhantomData) => Ok(cast(avg_pool_nhwc::( - cast(x), - ksize, - stride, - padding, - with_pad, - ))), - DType::F32 if is_accelerated::(PhantomData) => Ok(cast(avg_pool_nhwc::( - cast(x), - ksize, - stride, - padding, - with_pad, - ))), - _ => Err(x), - } -} - -fn cast(tensor: SharedArray) -> SharedArray { - unsafe { transmute::, SharedArray>(tensor) } -} - -mod nhwc { - use itertools::Itertools; - use macerator::{Simd, Vector, vload_unaligned, vstore_unaligned}; - use ndarray::{ArrayView3, ArrayViewMut3}; - use seq_macro::seq; - - use crate::ops::simd::lanes; - - use super::*; - - // Until you can use associated constants as array size, we need to hardcode this. - // The most common config (x86-v3) has 16 registers, so use half of them for accumulators. - const BLOCK_REGISTERS: usize = 8; - - pub(crate) fn avg_pool_nhwc( - x: SharedArray, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - with_pad: bool, - ) -> SharedArray { - let [kernel_height, kernel_width] = kernel_size; - let [pad_h, pad_w] = padding; - let [stride_height, stride_width] = stride; - let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); - let lanes = lanes::(); - - let ch_block = lanes * BLOCK_REGISTERS; - - let out_height = ((x_height + 2 * pad_h - (kernel_height - 1) - 1) / stride_height) + 1; - let out_width = ((x_width + 2 * pad_w - (kernel_width - 1) - 1) / stride_width) + 1; - - let mut output = unsafe { - Array4::::uninit((batch_size, out_height, out_width, channels)).assume_init() - }; - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - let x = x.view(); - let x = x.permuted_axes(vec![0, 2, 3, 1]); - - // Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`. - // An exclusive loop will always have `lanes * blocking factor` elements in bounds. - let blocks = channels / ch_block; - let blocks_end = blocks * ch_block; - // Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An - // exclusive loop will always have `lanes` elements in bounds. - let simd_end = channels / lanes * lanes; - let num_simd_unblocked = (simd_end - blocks_end) / lanes; - let remainder = channels - simd_end; - - run_par!(|| { - // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. 
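// Worked example of the three-way channel split computed above (illustrative numbers,
// assuming 8 f32 lanes as on AVX2): with channels = 83, ch_block = 8 * 8 = 64, so
// blocks = 1 and blocks_end = 64 channels go to the unrolled `loop_blocked`;
// simd_end = 83 / 8 * 8 = 80 gives num_simd_unblocked = (80 - 64) / 8 = 2 vector
// columns for `loop_unblocked`; the final remainder = 83 - 80 = 3 channels fall
// through to `loop_scalar`.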
- iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe { - let block = k % blocks; - let b = k / blocks; - - let output = unsafe_shared_out.get(); - - let x = x.slice(s![b, .., .., ..]); - let out = output.slice_mut(s![b, .., .., ..]); - - loop_blocked(x, out, kernel_size, stride, padding, with_pad, block); - }); - // SAFETY: See `loop_unblocked` - iter_range_par!(0, batch_size * num_simd_unblocked).for_each(|k| unsafe { - let ch = (k % num_simd_unblocked) * lanes + blocks_end; - let b = k / num_simd_unblocked; - - let output = unsafe_shared_out.get(); - - let x = x.slice(s![b, .., .., ..]); - let out = output.slice_mut(s![b, .., .., ..]); - - loop_unblocked(x, out, kernel_size, stride, padding, with_pad, ch); - }); - // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. - iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe { - let ch = (k % remainder) + simd_end; - let b = k / remainder; - - let output = unsafe_shared_out.get(); - - let x = x.slice(s![b, .., .., ..]); - let out = output.slice_mut(s![b, .., .., ..]); - - loop_scalar(x, out, kernel_size, stride, padding, with_pad, ch); - }); - }); - - output = output.permuted_axes([0, 3, 1, 2]); - - output.into_dyn().into_shared() - } - - /// Execute the blocked (unrolled) portion of the pool. - #[allow( - clippy::too_many_arguments, - clippy::erasing_op, - clippy::identity_op, - unused_mut - )] - #[macerator::with_simd] - fn loop_blocked<'a, S: Simd, E: Element + VAdd + VDiv>( - x: ArrayView3<'a, E>, - mut out: ArrayViewMut3<'a, E>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - with_pad: bool, - block: usize, - ) where - 'a: 'a, - { - let [kernel_height, kernel_width] = kernel_size; - let [pad_h, pad_w] = padding; - let [stride_height, stride_width] = stride; - - let (x_height, x_width, _) = x.dim(); - let (out_height, out_width, _) = out.dim(); - let lanes = E::lanes::(); - - let ch_block = lanes * BLOCK_REGISTERS; - - // If pixels are more than `padding` from the edges, the in pixel cannot be out of bounds - for oh in pad_h..out_height.saturating_sub(pad_h) { - for ow in pad_w..out_width.saturating_sub(pad_w) { - seq!(N in 0..8 { - let mut sum~N: Vector = Zeroable::zeroed(); - }); - let ch = block * ch_block; - let ch_end = ch + ch_block; - let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw - pad_w; - let x = x.slice(s![ih, iw, ch..ch_end]); - - seq!(N in 0..8 { - // SAFETY: - // Load a full vector from x[N * lanes]. This is bounds checked by the - // slice above. - sum~N += unsafe { vload_unaligned(&x[N * lanes]) }; - }); - } - } - - let count = kernel_height * kernel_width; - let count = (count as u64).elem::(); - let count_v = count.splat(); - seq!(N in 0..8 { - let s~N = sum~N / count_v; - // SAFETY: - // Store a full vector to out[N * lanes]. This is bounds checked by the - // slice above. 
- unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; - }); - } - } - - // Border pixels need bounds checks - if (pad_h, pad_w) != (0, 0) { - let v_borders = (0..pad_h) - .chain(out_height.saturating_sub(pad_h)..out_height) - .cartesian_product(0..out_width); - let h_borders = (0..out_height) - .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); - - for (oh, ow) in v_borders.chain(h_borders) { - seq!(N in 0..8 { - let mut sum~N: Vector = Zeroable::zeroed(); - }); - let mut count: usize = 0; - let ch = block * ch_block; - let ch_end = ch + ch_block; - let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh; - if ih < pad_h || ih >= x_height + pad_h { - continue; - } - let ih = ih - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw; - if iw < pad_w || iw >= x_width + pad_w { - continue; - } - let iw = iw - pad_w; - count += 1; - - let x = x.slice(s![ih, iw, ch..ch_end]); - - seq!(N in 0..8 { - // SAFETY: - // Load a full vector from x[N * lanes]. This is bounds checked by the - // slice above. - sum~N += unsafe { vload_unaligned(&x[N * lanes]) }; - }); - } - } - - if with_pad { - count = kernel_height * kernel_width; - } - - let count = (count as u64).elem::(); - let count_v = count.splat(); - seq!(N in 0..8 { - let s~N = sum~N / count_v; - // SAFETY: - // Store a full vector to out[N * lanes]. This is bounds checked by the - // slice above. - unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; - }); - } - } - } - - /// Execute the unblocked (not unrolled) portion of the pool. - /// - /// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`. - #[allow(clippy::too_many_arguments, unused_mut)] - #[macerator::with_simd] - unsafe fn loop_unblocked<'a, S: Simd, E: Element + VAdd + VDiv>( - x: ArrayView3<'a, E>, - mut out: ArrayViewMut3<'a, E>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - with_pad: bool, - ch: usize, - ) where - 'a: 'a, - { - let [kernel_height, kernel_width] = kernel_size; - let [pad_h, pad_w] = padding; - let [stride_height, stride_width] = stride; - - let (x_height, x_width, _) = x.dim(); - let (out_height, out_width, _) = out.dim(); - - // If pixels are not within padding range, bounds checks are always true - for oh in pad_h..out_height - pad_h { - for ow in pad_w..out_width - pad_w { - let mut sum: Vector = Zeroable::zeroed(); - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw - pad_w; - // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` - let s0 = unsafe { vload_unaligned(&x[[ih, iw, ch]]) }; - sum += s0; - } - } - - let count = kernel_height * kernel_width; - let count: E = (count as u64).elem(); - let count_v = count.splat(); - let s0 = sum / count_v; - // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. 
- unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) }; - } - } - - // Border pixels need bounds checks - if (pad_h, pad_w) != (0, 0) { - let v_borders = (0..pad_h) - .chain(out_height.saturating_sub(pad_h)..out_height) - .cartesian_product(0..out_width); - let h_borders = (0..out_height) - .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); - - for (oh, ow) in v_borders.chain(h_borders) { - let mut sum: Vector = Zeroable::zeroed(); - let mut count: usize = 0; - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh; - if ih < pad_h || ih >= x_height + pad_h { - continue; - } - let ih = ih - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw; - if iw < pad_w || iw >= x_width + pad_w { - continue; - } - let iw = iw - pad_w; - count += 1; - - // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` - sum += unsafe { vload_unaligned(&x[[ih, iw, ch]]) }; - } - } - - if with_pad { - count = kernel_height * kernel_width; - } - - let count = (count as u64).elem::(); - let count_v = count.splat(); - let s0 = sum / count_v; - // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. - unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) }; - } - } - } - - /// Execute scalar portion of the pooling - #[allow(clippy::too_many_arguments)] - fn loop_scalar( - x: ArrayView3<'_, E>, - mut out: ArrayViewMut3<'_, E>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - with_pad: bool, - ch: usize, - ) { - let [kernel_height, kernel_width] = kernel_size; - let [pad_h, pad_w] = padding; - let [stride_height, stride_width] = stride; - - let (x_height, x_width, _) = x.dim(); - let (out_height, out_width, _) = out.dim(); - - // If pixels are not within padding range, bounds checks are always true - for oh in pad_h..out_height.saturating_sub(pad_h) { - for ow in pad_w..out_width.saturating_sub(pad_w) { - let mut sum: E = Zeroable::zeroed(); - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw - pad_w; - sum = sum + x[[ih, iw, ch]]; - } - } - - let count = (kernel_height * kernel_width) as u64; - out[[oh, ow, ch]] = sum / count.elem(); - } - } - - // Border pixels need bounds checks - if (pad_h, pad_w) != (0, 0) { - let v_borders = (0..pad_h) - .chain(out_height.saturating_sub(pad_h)..out_height) - .cartesian_product(0..out_width); - let h_borders = (0..out_height) - .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); - - for (oh, ow) in v_borders.chain(h_borders) { - let mut sum: E = Zeroable::zeroed(); - let mut count: usize = 0; - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh; - if ih < pad_h || ih >= x_height + pad_h { - continue; - } - let ih = ih - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw; - if iw < pad_w || iw >= x_width + pad_w { - continue; - } - let iw = iw - pad_w; - count += 1; - sum = sum + x[[ih, iw, ch]]; - } - } - - if with_pad { - count = kernel_height * kernel_width; - } - - out[[oh, ow, ch]] = sum / (count as u64).elem(); - } - } - } -} diff --git a/crates/burn/src/ops/simd/avgpool.rs b/crates/burn/src/ops/simd/avgpool.rs new file mode 120000 index 00000000..715f8e3a --- /dev/null +++ b/crates/burn/src/ops/simd/avgpool.rs @@ -0,0 +1 @@ +../../../upstream/crates/burn-ndarray/src/ops/simd/avgpool.rs \ No newline at end of file diff --git a/crates/burn/src/ops/simd/base.rs 
b/crates/burn/src/ops/simd/base.rs deleted file mode 100644 index 005316f7..00000000 --- a/crates/burn/src/ops/simd/base.rs +++ /dev/null @@ -1,115 +0,0 @@ -use core::{marker::PhantomData, mem::MaybeUninit}; - -use macerator::{Arch, Scalar, Simd}; -use ndarray::{ArcArray, ArrayD, IxDyn, ShapeBuilder}; - -/// Whether SIMD instructions are worth using -#[cfg(all( - any( - target_arch = "x86", - target_arch = "x86_64", - target_arch = "aarch64", - target_arch = "wasm32", - target_arch = "loongarch64" - ), - not(test) -))] -pub fn should_use_simd(len: usize) -> bool { - len >= 32 -} - -/// Whether SIMD instructions are worth using -#[cfg(all( - not(any( - target_arch = "x86", - target_arch = "x86_64", - target_arch = "aarch64", - target_arch = "wasm32", - target_arch = "loongarch64" - )), - not(test) -))] -pub fn should_use_simd(_len: usize) -> bool { - false -} - -#[cfg(test)] -pub fn should_use_simd(_len: usize) -> bool { - true -} - -pub(crate) fn lanes() -> usize { - #[allow(non_camel_case_types)] - struct lanes<__T0>(__T0); - - impl ::macerator::WithSimd for lanes> { - type Output = usize; - #[inline(always)] - fn with_simd<__S: ::macerator::Simd>(self) -> ::Output { - let Self(__ty) = self; - #[allow(unused_unsafe)] - unsafe { - lanes_simd::<__S, E>(__ty) - } - } - } - (Arch::new()).dispatch(lanes(PhantomData::)) -} - -fn lanes_simd(_ty: PhantomData) -> usize { - E::lanes::() -} - -pub(crate) fn uninit_array_like(reference: &ArcArray) -> ArrayD { - let shape = reference.raw_dim(); - let strides = reference.strides(); - let strides = strides.iter().map(|it| *it as usize).collect::>(); - let shape_strides = shape.strides(IxDyn(&strides)); - let size = reference.len(); - let mut out_data: Vec> = Vec::with_capacity(size); - unsafe { out_data.set_len(size) }; - unsafe { ArrayD::from_shape_vec_unchecked(shape_strides, out_data).assume_init() } -} - -pub trait MinMax { - fn min(self, other: Self) -> Self; - fn max(self, other: Self) -> Self; -} - -macro_rules! 
impl_minmax { - ($ty: ty) => { - impl MinMax for $ty { - fn min(self, other: Self) -> Self { - Ord::min(self, other) - } - fn max(self, other: Self) -> Self { - Ord::max(self, other) - } - } - }; - ($($ty: ty),*) => { - $(impl_minmax!($ty);)* - } -} - -impl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64); - -impl MinMax for f32 { - fn min(self, other: Self) -> Self { - self.min(other) - } - - fn max(self, other: Self) -> Self { - self.max(other) - } -} - -impl MinMax for f64 { - fn min(self, other: Self) -> Self { - self.min(other) - } - - fn max(self, other: Self) -> Self { - self.max(other) - } -} diff --git a/crates/burn/src/ops/simd/base.rs b/crates/burn/src/ops/simd/base.rs new file mode 120000 index 00000000..80337bda --- /dev/null +++ b/crates/burn/src/ops/simd/base.rs @@ -0,0 +1 @@ +../../../upstream/crates/burn-ndarray/src/ops/simd/base.rs \ No newline at end of file diff --git a/crates/burn/src/ops/simd/binary.rs b/crates/burn/src/ops/simd/binary.rs deleted file mode 100644 index dae3ed57..00000000 --- a/crates/burn/src/ops/simd/binary.rs +++ /dev/null @@ -1,299 +0,0 @@ -use core::{marker::PhantomData, slice}; - -use burn_backend::Element; -use macerator::{ - Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload_unaligned, - vstore_unaligned, -}; -use ndarray::ArrayD; -use seq_macro::seq; - -use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like}; - -use super::{ - MinMax, - binary_elemwise::{ - VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecDiv, VecMax, VecMin, VecMul, VecSub, - }, - should_use_simd, -}; - -pub trait SimdBinop { - fn apply_vec(lhs: Vector, rhs: Vector) -> Vector; - fn apply(lhs: T, rhs: T) -> Out; - fn is_accelerated() -> bool; -} - -impl SimdBinop for VecAdd { - fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { - lhs + rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs + rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl SimdBinop for VecDiv { - fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { - lhs / rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs / rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl SimdBinop for VecMul { - fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { - lhs * rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs * rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl SimdBinop for VecSub { - fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { - lhs - rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs - rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl SimdBinop for VecMin { - fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { - lhs.min(rhs) - } - - fn apply(lhs: T, rhs: T) -> T { - MinMax::min(lhs, rhs) - } - - fn is_accelerated() -> bool { - ::is_min_max_accelerated::() - } -} - -impl SimdBinop for VecMax { - fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { - lhs.max(rhs) - } - - fn apply(lhs: T, rhs: T) -> T { - MinMax::max(lhs, rhs) - } - - fn is_accelerated() -> bool { - ::is_min_max_accelerated::() - } -} - -impl SimdBinop for VecBitAnd { - fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { - lhs & rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs.bitand(rhs) - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl SimdBinop for VecBitOr { - fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { - lhs | rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs.bitor(rhs) - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl SimdBinop for VecBitXor { - fn 
apply_vec(lhs: Vector, rhs: Vector) -> Vector { - lhs ^ rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs.bitxor(rhs) - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -#[macerator::with_simd] -fn is_accelerated>( - _x: PhantomData<(T, Out, Op)>, -) -> bool { - Op::is_accelerated::() -} - -#[allow(clippy::result_large_err)] -pub fn try_binary_simd< - E: Element, - EOut: Element, - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: SimdBinop, ->( - lhs: SharedArray, - rhs: SharedArray, -) -> Result, (SharedArray, SharedArray)> { - let lhs_len = lhs.len(); - let rhs_len = rhs.len(); - if !should_use_simd(lhs_len.max(rhs_len)) - || !lhs.is_standard_layout() - || !rhs.is_standard_layout() - || lhs.shape() != rhs.shape() - || !is_accelerated::(PhantomData) - { - return Err((lhs, rhs)); - } - // Used to assert traits based on the dynamic `DType`. - let lhs = unsafe { core::mem::transmute::, SharedArray>(lhs) }; - let rhs = unsafe { core::mem::transmute::, SharedArray>(rhs) }; - let out = binary_simd_same::(lhs, rhs); - - // Used to assert traits based on the dynamic `DType`. - let out = unsafe { core::mem::transmute::, SharedArray>(out) }; - Ok(out) -} - -fn binary_simd_same< - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: SimdBinop, ->( - lhs: SharedArray, - rhs: SharedArray, -) -> SharedArray { - let out = if lhs.is_unique() { - let mut buf = lhs.into_owned(); - let lhs = buf.as_slice_mut().unwrap(); - let rhs = rhs.as_slice().unwrap(); - let out = - unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(lhs)) }; - binary(lhs, rhs, out, PhantomData::); - unsafe { core::mem::transmute::, ArrayD>(buf) } - } else if rhs.is_unique() { - let mut buf = rhs.into_owned(); - let lhs = lhs.as_slice().unwrap(); - let rhs = buf.as_slice_mut().unwrap(); - let out = - unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(rhs)) }; - binary(lhs, rhs, out, PhantomData::); - unsafe { core::mem::transmute::, ArrayD>(buf) } - } else { - let mut out = uninit_array_like(&lhs); - let lhs = lhs.as_slice().unwrap(); - let rhs = rhs.as_slice().unwrap(); - let out_slice = out.as_slice_mut().unwrap(); - binary(lhs, rhs, out_slice, PhantomData::); - out - }; - out.into_shared() -} - -#[allow(clippy::erasing_op, clippy::identity_op)] -#[macerator::with_simd] -fn binary< - 'a, - S: Simd, - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: SimdBinop, ->( - lhs: &'a [T], - rhs: &'a [T], - out: &'a mut [Out], - _op: PhantomData, -) where - 'a: 'a, -{ - let lanes = T::lanes::(); - let mut chunks_lhs = lhs.chunks_exact(8 * lanes); - let mut chunks_rhs = rhs.chunks_exact(8 * lanes); - let mut chunks_out = out.chunks_exact_mut(8 * lanes); - while let Some(((lhs, rhs), out)) = chunks_lhs - .next() - .zip(chunks_rhs.next()) - .zip(chunks_out.next()) - { - seq!(N in 0..8 { - // Load one full vector from `lhs`. - // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - let lhs~N = unsafe { vload_unaligned::(&lhs[N * lanes]) }; - // Load one full vector from `rhs`. - // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) }; - let s~N = Op::apply_vec(lhs~N, rhs~N); - // Store one full vector to `out`. 
- // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; - }); - } - let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes); - let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes); - let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); - while let Some(((lhs, rhs), out)) = chunks_lhs - .next() - .zip(chunks_rhs.next()) - .zip(chunks_out.next()) - { - // Load one full vector from `lhs`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - let lhs0 = unsafe { vload_unaligned::(lhs.as_ptr()) }; - // Load one full vector from `rhs`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) }; - let s0 = Op::apply_vec(lhs0, rhs0); - // Store one full vector to `out`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - unsafe { vstore_unaligned(out.as_mut_ptr(), s0) }; - } - - for ((lhs, rhs), out) in chunks_lhs - .remainder() - .iter() - .zip(chunks_rhs.remainder()) - .zip(chunks_out.into_remainder()) - { - *out = Op::apply(*lhs, *rhs) - } -} - -/// Unsafely alias a slice to use as an inline argument -fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] { - let ptr = slice.as_mut_ptr(); - let len = slice.len(); - unsafe { slice::from_raw_parts_mut(ptr, len) } -} diff --git a/crates/burn/src/ops/simd/binary.rs b/crates/burn/src/ops/simd/binary.rs new file mode 120000 index 00000000..a4a8a79a --- /dev/null +++ b/crates/burn/src/ops/simd/binary.rs @@ -0,0 +1 @@ +../../../upstream/crates/burn-ndarray/src/ops/simd/binary.rs \ No newline at end of file diff --git a/crates/burn/src/ops/simd/binary_elemwise.rs b/crates/burn/src/ops/simd/binary_elemwise.rs deleted file mode 100644 index 7534da53..00000000 --- a/crates/burn/src/ops/simd/binary_elemwise.rs +++ /dev/null @@ -1,419 +0,0 @@ -use core::marker::PhantomData; - -use bytemuck::cast; -use macerator::{ - Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload, - vload_unaligned, vstore, vstore_unaligned, -}; -use ndarray::ArrayD; -use seq_macro::seq; - -use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like}; - -use super::{MinMax, should_use_simd}; - -pub trait ScalarSimdBinop { - type Rhs: Copy; - type RhsVec: Copy; - fn splat(rhs: Self::Rhs) -> Self::RhsVec; - fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector; - fn apply(lhs: T, rhs: Self::Rhs) -> Out; - fn is_accelerated() -> bool; -} - -pub struct VecAdd; -pub struct VecDiv; -pub struct VecMul; -pub struct VecSub; -pub struct VecMin; -pub struct VecMax; -pub struct VecClamp; -pub struct VecBitAnd; -pub struct VecBitOr; -pub struct VecBitXor; - -impl ScalarSimdBinop for VecAdd { - type Rhs = T; - type RhsVec = Vector; - - fn splat(rhs: Self::Rhs) -> Self::RhsVec { - rhs.splat() - } - - fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { - lhs + rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs + rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl ScalarSimdBinop for VecDiv { - type Rhs = T; - type RhsVec = Vector; - - fn splat(rhs: Self::Rhs) -> Self::RhsVec { - rhs.splat() - } - - fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { - lhs / rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs / rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl ScalarSimdBinop for VecMul { - type Rhs = T; - type RhsVec = Vector; - - fn splat(rhs: Self::Rhs) -> Self::RhsVec { - rhs.splat() - } - - fn 
apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { - lhs * rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs * rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl ScalarSimdBinop for VecSub { - type Rhs = T; - type RhsVec = Vector; - - fn splat(rhs: Self::Rhs) -> Self::RhsVec { - rhs.splat() - } - - fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { - lhs - rhs - } - - fn apply(lhs: T, rhs: T) -> T { - lhs - rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl ScalarSimdBinop for VecMin { - type Rhs = T; - type RhsVec = Vector; - - fn splat(rhs: Self::Rhs) -> Self::RhsVec { - rhs.splat() - } - - fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { - lhs.min(rhs) - } - - fn apply(lhs: T, rhs: T) -> T { - lhs.min(rhs) - } - - fn is_accelerated() -> bool { - ::is_min_max_accelerated::() - } -} - -impl ScalarSimdBinop for VecMax { - type Rhs = T; - type RhsVec = Vector; - - fn splat(rhs: Self::Rhs) -> Self::RhsVec { - rhs.splat() - } - - fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { - lhs.max(rhs) - } - - fn apply(lhs: T, rhs: T) -> T { - lhs.max(rhs) - } - - fn is_accelerated() -> bool { - ::is_min_max_accelerated::() - } -} - -impl ScalarSimdBinop for VecClamp { - type Rhs = (T, T); - type RhsVec = (Vector, Vector); - - fn splat((min, max): Self::Rhs) -> Self::RhsVec { - (min.splat(), max.splat()) - } - - fn apply_vec(lhs: Vector, (min, max): Self::RhsVec) -> Vector { - lhs.min(max).max(min) - } - - fn apply(lhs: T, (min, max): Self::Rhs) -> T { - lhs.min(max).max(min) - } - - fn is_accelerated() -> bool { - ::is_min_max_accelerated::() - } -} - -impl ScalarSimdBinop for VecBitAnd { - type Rhs = T; - type RhsVec = Vector; - - fn splat(rhs: Self::Rhs) -> Self::RhsVec { - rhs.splat() - } - - fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { - lhs & rhs - } - - fn apply(lhs: T, rhs: Self::Rhs) -> T { - lhs & rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl ScalarSimdBinop for VecBitOr { - type Rhs = T; - type RhsVec = Vector; - - fn splat(rhs: Self::Rhs) -> Self::RhsVec { - rhs.splat() - } - - fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { - lhs | rhs - } - - fn apply(lhs: T, rhs: Self::Rhs) -> T { - lhs | rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -impl ScalarSimdBinop for VecBitXor { - type Rhs = T; - type RhsVec = Vector; - - fn splat(rhs: Self::Rhs) -> Self::RhsVec { - rhs.splat() - } - - fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { - lhs ^ rhs - } - - fn apply(lhs: T, rhs: Self::Rhs) -> T { - lhs ^ rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -#[macerator::with_simd] -fn is_accelerated>( - _x: PhantomData<(T, Out, Op)>, -) -> bool { - Op::is_accelerated::() -} - -pub fn try_binary_scalar_simd< - E: NdArrayElement, - EOut: NdArrayElement, - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: ScalarSimdBinop, ->( - input: SharedArray, - elem: Op::Rhs, -) -> Result, SharedArray> { - if !should_use_simd(input.len()) - || input.as_slice_memory_order().is_none() - || !is_accelerated::(PhantomData) - { - return Err(input); - } - // Used to assert traits based on the dynamic `DType`. 
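// The in-place branch below is worth spelling out. A minimal sketch of the same gate,
// with an illustrative helper name (`sole_owner` models `input.is_unique()`):
use core::mem::{align_of, size_of};

fn can_run_in_place<T, Out>(sole_owner: bool) -> bool {
    // e.g. f32 -> f32 qualifies; f32 -> i8 would fall back to the owned path because
    // the element sizes differ and the buffer could not be reinterpreted safely.
    size_of::<T>() == size_of::<Out>() && align_of::<T>() >= align_of::<Out>() && sole_owner
}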
- let input = unsafe { core::mem::transmute::, SharedArray>(input) }; - let out = if size_of::() == size_of::() - && align_of::() >= align_of::() - && input.is_unique() - { - unsafe { binary_scalar_simd_inplace::(input, elem) } - } else { - binary_scalar_simd_owned::(input, elem) - }; - // Used to assert traits based on the dynamic `DType`. - let out = unsafe { core::mem::transmute::, SharedArray>(out) }; - Ok(out) -} - -/// Execute operation in place on an owned tensor -/// SAFETY: -/// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. -unsafe fn binary_scalar_simd_inplace< - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: ScalarSimdBinop, ->( - input: SharedArray, - elem: Op::Rhs, -) -> SharedArray { - let mut buffer = input.into_owned(); - let slice = buffer.as_slice_memory_order_mut().unwrap(); - unsafe { binary_scalar_slice_inplace::(slice, elem, PhantomData) }; - // Buffer has the same elem size and is filled with the operation output, so this is safe - let out = unsafe { core::mem::transmute::, ArrayD>(buffer) }; - out.into_shared() -} - -/// Create a new copy of the tensor as the output -fn binary_scalar_simd_owned< - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: ScalarSimdBinop, ->( - input: SharedArray, - elem: Op::Rhs, -) -> SharedArray { - let mut out = uninit_array_like(&input); - let input = input.as_slice_memory_order().unwrap(); - let out_slice = out.as_slice_memory_order_mut().unwrap(); - binary_scalar_slice::(input, out_slice, elem, PhantomData); - out.into_shared() -} - -#[inline(always)] -#[allow(clippy::erasing_op, clippy::identity_op)] -#[macerator::with_simd] -fn binary_scalar_slice< - 'a, - S: Simd, - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: ScalarSimdBinop, ->( - input: &'a [T], - out: &'a mut [Out], - rhs: Op::Rhs, - _op: PhantomData, -) where - 'a: 'a, -{ - let lanes = T::lanes::(); - let mut chunks_input = input.chunks_exact(8 * lanes); - let mut chunks_out = out.chunks_exact_mut(8 * lanes); - let rhs_vec = Op::splat::(rhs); - while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { - seq!(N in 0..8 { - // Load one full vector from `input`. - // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - let s~N = unsafe { vload_unaligned(&input[N * lanes]) }; - let s~N = Op::apply_vec(s~N, rhs_vec); - // Store one full vector to `out`. - // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; - }); - } - let mut chunks_input = chunks_input.remainder().chunks_exact(lanes); - let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); - while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { - // Load one full vector from `input`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - let s0 = unsafe { vload_unaligned(input.as_ptr()) }; - let s0 = Op::apply_vec(s0, rhs_vec); - // Store one full vector to `out`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - unsafe { vstore_unaligned(out.as_mut_ptr(), s0) }; - } - - for (input, out) in chunks_input - .remainder() - .iter() - .zip(chunks_out.into_remainder()) - { - *out = Op::apply(*input, rhs) - } -} - -/// Execute operation in line. -/// SAFETY: -/// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. 
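// For context before the in-place kernel below: a dependency-free sketch of its
// head/main/tail structure. The real code aligns to the SIMD vector type, which has
// stricter alignment; this portable stand-in uses [f32; 8] (same alignment as f32), so
// `head` stays empty and `tail` is simply len % 8 elements. The function name is
// illustrative.
fn scale_in_place_sketch(buf: &mut [f32], k: f32) {
    // SAFETY: [f32; 8] has the same layout as eight consecutive f32 values, so viewing
    // the middle of the slice as fixed-size chunks is sound.
    let (head, main, tail) = unsafe { buf.align_to_mut::<[f32; 8]>() };
    for x in head.iter_mut().chain(tail) {
        *x *= k; // scalar fallback on the edges
    }
    for chunk in main {
        for x in chunk {
            *x *= k; // the deleted kernel replaces this inner loop with one vload/vstore
        }
    }
}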
-#[inline(always)] -#[macerator::with_simd] -unsafe fn binary_scalar_slice_inplace< - 'a, - S: Simd, - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: ScalarSimdBinop, ->( - buf: &'a mut [T], - rhs: Op::Rhs, - _op: PhantomData<(Out, Op)>, -) where - 'a: 'a, -{ - let (head, main, tail) = unsafe { buf.align_to_mut::>() }; - for elem in head.iter_mut().chain(tail) { - *elem = cast(Op::apply(*elem, rhs)); - } - let mut chunks = main.chunks_exact_mut(8); - let rhs = Op::splat::(rhs); - for elem in chunks.by_ref() { - seq!(N in 0..8 { - // Load a full vector from the aligned portion of the buffer. - // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is - // always a full vector in bounds. - let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; - let s~N = Op::apply_vec(s~N, rhs); - // Store a full vector at the same position as the input. Cast is safe because `Out` is - // size and align compatible - unsafe { vstore_unaligned(&mut elem[N] as *mut _ as *mut Out, s~N) }; - }); - } - - for elem in chunks.into_remainder() { - // Load a full vector from the aligned portion of the buffer. - // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is - // always a full vector in bounds. - let s0 = unsafe { vload(elem as *const _ as *const T) }; - - let s0 = Op::apply_vec(s0, rhs); - // Store a full vector at the same position as the input. Cast is safe because `Out` is - // size and align compatible - unsafe { vstore(elem as *mut _ as *mut Out, s0) }; - } -} diff --git a/crates/burn/src/ops/simd/binary_elemwise.rs b/crates/burn/src/ops/simd/binary_elemwise.rs new file mode 120000 index 00000000..f4807064 --- /dev/null +++ b/crates/burn/src/ops/simd/binary_elemwise.rs @@ -0,0 +1 @@ +../../../upstream/crates/burn-ndarray/src/ops/simd/binary_elemwise.rs \ No newline at end of file diff --git a/crates/burn/src/ops/simd/cmp.rs b/crates/burn/src/ops/simd/cmp.rs deleted file mode 100644 index c9f8c0ea..00000000 --- a/crates/burn/src/ops/simd/cmp.rs +++ /dev/null @@ -1,374 +0,0 @@ -use core::{marker::PhantomData, slice}; - -use burn_backend::Element; -use macerator::{Mask, Scalar, Simd, VEq, VOrd, Vector, vload_unaligned}; -use ndarray::ArrayD; -use seq_macro::seq; - -use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like}; - -use super::should_use_simd; - -pub trait SimdCmpOp { - fn apply_vec(lhs: Vector, rhs: Vector) -> Mask; - fn apply(lhs: T, rhs: T) -> bool; - fn is_accelerated() -> bool; -} - -pub struct VecEquals; - -impl SimdCmpOp for VecEquals { - fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { - lhs.eq(rhs) - } - - fn apply(lhs: T, rhs: T) -> bool { - lhs == rhs - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -pub struct VecGreater; - -impl SimdCmpOp for VecGreater { - fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { - lhs.gt(rhs) - } - - fn apply(lhs: T, rhs: T) -> bool { - lhs > rhs - } - - fn is_accelerated() -> bool { - ::is_cmp_accelerated::() - } -} - -pub struct VecGreaterEq; - -impl SimdCmpOp for VecGreaterEq { - fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { - lhs.ge(rhs) - } - - fn apply(lhs: T, rhs: T) -> bool { - lhs >= rhs - } - - fn is_accelerated() -> bool { - ::is_cmp_accelerated::() - } -} - -pub struct VecLowerEq; - -impl SimdCmpOp for VecLowerEq { - fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { - lhs.le(rhs) - } - - fn apply(lhs: T, rhs: T) -> bool { - lhs <= rhs - } - - fn is_accelerated() -> bool { - ::is_cmp_accelerated::() - } -} - -pub struct 
VecLower; - -impl SimdCmpOp for VecLower { - fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { - lhs.lt(rhs) - } - - fn apply(lhs: T, rhs: T) -> bool { - lhs < rhs - } - - fn is_accelerated() -> bool { - ::is_cmp_accelerated::() - } -} - -#[macerator::with_simd] -fn is_accelerated>(_x: PhantomData<(T, Op)>) -> bool { - Op::is_accelerated::() -} - -#[allow(clippy::result_large_err)] -pub fn try_cmp_simd>( - lhs: SharedArray, - rhs: SharedArray, -) -> Result, (SharedArray, SharedArray)> { - let lhs_len = lhs.len(); - let rhs_len = rhs.len(); - if !should_use_simd(lhs_len.max(rhs_len)) - || !lhs.is_standard_layout() - || !rhs.is_standard_layout() - || lhs.shape() != rhs.shape() - || !is_accelerated::(PhantomData) - { - return Err((lhs, rhs)); - } - // Used to assert traits based on the dynamic `DType`. - let lhs = unsafe { core::mem::transmute::, SharedArray>(lhs) }; - let rhs = unsafe { core::mem::transmute::, SharedArray>(rhs) }; - let out = cmp_simd_same::(lhs, rhs); - - Ok(out) -} - -fn cmp_simd_same>( - lhs: SharedArray, - rhs: SharedArray, -) -> SharedArray { - let out = if lhs.is_unique() && size_of::() == size_of::() { - let mut buf = lhs.into_owned(); - let lhs = buf.as_slice_mut().unwrap(); - let rhs = rhs.as_slice().unwrap(); - let out = - unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(lhs)) }; - cmp(lhs, rhs, out, PhantomData::); - unsafe { core::mem::transmute::, ArrayD>(buf) } - } else if rhs.is_unique() && size_of::() == size_of::() { - let mut buf = rhs.into_owned(); - let lhs = lhs.as_slice().unwrap(); - let rhs = buf.as_slice_mut().unwrap(); - let out = - unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(rhs)) }; - cmp(lhs, rhs, out, PhantomData::); - unsafe { core::mem::transmute::, ArrayD>(buf) } - } else { - let mut out = uninit_array_like(&lhs); - let lhs = lhs.as_slice().unwrap(); - let rhs = rhs.as_slice().unwrap(); - let out_slice = out.as_slice_mut().unwrap(); - cmp(lhs, rhs, out_slice, PhantomData::); - out - }; - out.into_shared() -} - -#[allow(clippy::erasing_op, clippy::identity_op)] -#[macerator::with_simd] -fn cmp<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( - lhs: &'a [T], - rhs: &'a [T], - out: &'a mut [bool], - _op: PhantomData, -) where - 'a: 'a, -{ - let lanes = T::lanes::(); - let mut chunks_lhs = lhs.chunks_exact(8 * lanes); - let mut chunks_rhs = rhs.chunks_exact(8 * lanes); - let mut chunks_out = out.chunks_exact_mut(8 * lanes); - while let Some(((lhs, rhs), out)) = chunks_lhs - .next() - .zip(chunks_rhs.next()) - .zip(chunks_out.next()) - { - seq!(N in 0..8 { - // Load one full vector from `lhs`. - // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - let lhs~N = unsafe { vload_unaligned::(&lhs[N * lanes]) }; - // Load one full vector from `rhs`. - // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) }; - let s~N = Op::apply_vec(lhs~N, rhs~N); - // Store one full vector to `out`. - // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) }; - }); - } - let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes); - let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes); - let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); - while let Some(((lhs, rhs), out)) = chunks_lhs - .next() - .zip(chunks_rhs.next()) - .zip(chunks_out.next()) - { - // Load one full vector from `lhs`. 
- // SAFETY: Guaranteed to be in bounds because `len == lanes` - let lhs0 = unsafe { vload_unaligned::(lhs.as_ptr()) }; - // Load one full vector from `rhs`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) }; - let s0 = Op::apply_vec(lhs0, rhs0); - // Store one full vector to `out`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) }; - } - - for ((lhs, rhs), out) in chunks_lhs - .remainder() - .iter() - .zip(chunks_rhs.remainder()) - .zip(chunks_out.into_remainder()) - { - *out = Op::apply(*lhs, *rhs) - } -} - -/// Unsafely alias a slice to use as an inline argument -fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] { - let ptr = slice.as_mut_ptr(); - let len = slice.len(); - unsafe { slice::from_raw_parts_mut(ptr, len) } -} - -pub use elemwise::try_cmp_scalar_simd; - -mod elemwise { - use bytemuck::cast; - use macerator::vload; - - use super::*; - - pub fn try_cmp_scalar_simd>( - input: SharedArray, - elem: T, - ) -> Result, SharedArray> { - if !should_use_simd(input.len()) - || input.as_slice_memory_order().is_none() - || !is_accelerated::(PhantomData) - { - return Err(input); - } - // Used to assert traits based on the dynamic `DType`. - let input = unsafe { core::mem::transmute::, SharedArray>(input) }; - let out = if size_of::() == size_of::() - && align_of::() >= align_of::() - && input.is_unique() - { - unsafe { cmp_scalar_simd_inplace::(input, elem) } - } else { - cmp_scalar_simd_owned::(input, elem) - }; - Ok(out) - } - - /// Execute operation in place on an owned tensor - /// SAFETY: - /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. - unsafe fn cmp_scalar_simd_inplace>( - input: SharedArray, - elem: T, - ) -> SharedArray { - let mut buffer = input.into_owned(); - let slice = buffer.as_slice_memory_order_mut().unwrap(); - unsafe { cmp_scalar_slice_inplace::(slice, elem, PhantomData) }; - // Buffer has the same elem size and is filled with the operation output, so this is safe - let out = unsafe { core::mem::transmute::, ArrayD>(buffer) }; - out.into_shared() - } - - /// Create a new copy of the tensor as the output - fn cmp_scalar_simd_owned>( - input: SharedArray, - elem: T, - ) -> SharedArray { - let mut out = uninit_array_like(&input); - let input = input.as_slice_memory_order().unwrap(); - let out_slice = out.as_slice_memory_order_mut().unwrap(); - cmp_scalar_slice::(input, out_slice, elem, PhantomData); - out.into_shared() - } - - #[inline(always)] - #[allow(clippy::erasing_op, clippy::identity_op)] - #[macerator::with_simd] - fn cmp_scalar_slice<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( - input: &'a [T], - out: &'a mut [bool], - rhs: T, - _op: PhantomData, - ) where - 'a: 'a, - { - let lanes = T::lanes::(); - let mut chunks_input = input.chunks_exact(8 * lanes); - let mut chunks_out = out.chunks_exact_mut(8 * lanes); - let rhs_vec = rhs.splat::(); - while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { - seq!(N in 0..8 { - // Load one full vector from `input`. - // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - let s~N = unsafe { vload_unaligned(&input[N * lanes]) }; - let s~N = Op::apply_vec(s~N, rhs_vec); - // Store one full vector to `out`. 
- // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) }; - }); - } - let mut chunks_input = chunks_input.remainder().chunks_exact(lanes); - let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); - while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { - // Load one full vector from `input`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - let s0 = unsafe { vload_unaligned(input.as_ptr()) }; - let s0 = Op::apply_vec(s0, rhs_vec); - // Store one full vector to `out`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) }; - } - - for (input, out) in chunks_input - .remainder() - .iter() - .zip(chunks_out.into_remainder()) - { - *out = Op::apply(*input, rhs) - } - } - - /// Execute operation in line. - /// SAFETY: - /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. - #[inline(always)] - #[macerator::with_simd] - unsafe fn cmp_scalar_slice_inplace<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( - buf: &'a mut [T], - rhs: T, - _op: PhantomData, - ) where - 'a: 'a, - { - let (head, main, tail) = unsafe { buf.align_to_mut::>() }; - for elem in head.iter_mut().chain(tail) { - *elem = cast(Op::apply(*elem, rhs)); - } - let mut chunks = main.chunks_exact_mut(8); - let rhs = rhs.splat::(); - for elem in chunks.by_ref() { - seq!(N in 0..8 { - // Load a full vector from the aligned portion of the buffer. - // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is - // always a full vector in bounds. - let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; - let s~N = Op::apply_vec(s~N, rhs); - // Store a full vector at the same position as the input. Cast is safe because `Out` is - // size and align compatible - unsafe { T::mask_store_as_bool(&mut elem[N] as *mut _ as *mut bool, s~N) }; - }); - } - - for elem in chunks.into_remainder() { - // Load a full vector from the aligned portion of the buffer. - // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is - // always a full vector in bounds. - let s0 = unsafe { vload(elem as *const _ as *const T) }; - - let s0 = Op::apply_vec(s0, rhs); - // Store a full vector at the same position as the input. 
Cast is safe because `Out` is - // size and align compatible - unsafe { T::mask_store_as_bool(elem as *mut _ as *mut bool, s0) }; - } - } -} diff --git a/crates/burn/src/ops/simd/cmp.rs b/crates/burn/src/ops/simd/cmp.rs new file mode 120000 index 00000000..0b4d7850 --- /dev/null +++ b/crates/burn/src/ops/simd/cmp.rs @@ -0,0 +1 @@ +../../../upstream/crates/burn-ndarray/src/ops/simd/cmp.rs \ No newline at end of file diff --git a/crates/burn/src/ops/simd/conv.rs b/crates/burn/src/ops/simd/conv.rs deleted file mode 100644 index 5bbd4633..00000000 --- a/crates/burn/src/ops/simd/conv.rs +++ /dev/null @@ -1,494 +0,0 @@ -use core::{marker::PhantomData, mem::transmute}; - -use burn_backend::{ - DType, Element, - ops::{ConvOptions, conv::calculate_conv_output_size}, -}; -use bytemuck::Zeroable; -use macerator::{Simd, VMulAdd, Vector, vload_unaligned, vstore_unaligned}; -use ndarray::{ - ArcArray1, Array4, ArrayView3, ArrayView4, ArrayViewMut2, ArrayViewMut3, Dim, Ix1, Ix4, s, -}; -use seq_macro::seq; - -use crate::{FloatNdArrayElement, SharedArray, UnsafeSharedRef, iter_range_par, run_par}; - -type Args = (SharedArray, SharedArray, Option>); - -#[allow(clippy::result_large_err)] -pub fn try_conv2d_simd( - x: SharedArray, - weight: SharedArray, - bias: Option>, - options: ConvOptions<2>, -) -> Result, Args> { - match E::dtype() { - DType::F64 => conv2d::(x, weight, bias, options, PhantomData), - DType::F32 => conv2d::(x, weight, bias, options, PhantomData), - DType::I64 => conv2d::(x, weight, bias, options, PhantomData), - DType::I32 => conv2d::(x, weight, bias, options, PhantomData), - DType::I16 => conv2d::(x, weight, bias, options, PhantomData), - DType::U64 => conv2d::(x, weight, bias, options, PhantomData), - DType::U32 => conv2d::(x, weight, bias, options, PhantomData), - DType::U16 => conv2d::(x, weight, bias, options, PhantomData), - _ => Err((x, weight, bias)), - } -} - -fn cast(tensor: SharedArray) -> SharedArray { - unsafe { transmute::, SharedArray>(tensor) } -} - -/// Out-channel last SIMD accelerated direct convolution. Loop order and register blocking based on -/// E. Georganas, S. Avancha, K. Banerjee, D. Kalamkar, G. Henry, H. Pabst, A. Heinecke (2018). -/// Anatomy Of High-Performance Deep Learning Convolutions On SIMD Architectures. -/// SC '18, Article 6, pp. 1-12. arXiv:1808.05567. . 
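Both the convolution and pooling kernels in this patch size their outputs with the same floor-division arithmetic (the maxpool kernel later in this diff spells it out inline). A minimal standalone sketch of that formula, with `conv_out_size` as a hypothetical stand-in for `calculate_conv_output_size`:

/// Output size of one spatial dimension, assuming the standard
/// floor-division formula used throughout this diff.
fn conv_out_size(kernel: usize, stride: usize, pad: usize, dilation: usize, input: usize) -> usize {
    (input + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // A 3x3 kernel with stride 1, padding 1, dilation 1 preserves the size.
    assert_eq!(conv_out_size(3, 1, 1, 1, 224), 224);
    // Stride 2 halves it, rounding down.
    assert_eq!(conv_out_size(3, 2, 1, 1, 224), 112);
}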
-#[allow(clippy::result_large_err)] -fn conv2d( - x: SharedArray, - weight: SharedArray, - bias: Option>, - options: ConvOptions<2>, - _ty: PhantomData, -) -> Result, Args> { - let [out_channels, _, k_height, k_width] = weight.shape().try_into().unwrap(); - let channels_per_group = out_channels / options.groups; - - #[macerator::with_simd] - fn precheck(_ty: PhantomData) -> (usize, bool) { - (E::lanes::(), E::is_accelerated::()) - } - - let (lanes, accelerated) = precheck::(PhantomData); - - if !accelerated || !channels_per_group.is_multiple_of(lanes) { - return Err((x, weight, bias)); - } - - let x = cast::<_, E>(x); - let weight = cast::<_, E>(weight); - let bias = bias.map(|bias| cast::<_, E>(bias)); - - let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap(); - let [dilate_h, dilate_w] = options.dilation; - let [stride_h, stride_w] = options.stride; - let [pad_h, pad_w] = options.padding; - let padded = options.padding != [0, 0]; - let strided = options.stride != [1, 1] || options.dilation != [1, 1]; - let grouped = options.groups != 1; - - let out_height = calculate_conv_output_size(k_height, stride_h, pad_h, dilate_h, in_height); - let out_width = calculate_conv_output_size(k_width, stride_w, pad_w, dilate_w, in_width); - - let x = x.into_dimensionality::().unwrap(); - let weights = weight.into_dimensionality::().unwrap(); - let weights = weights.permuted_axes([1, 2, 3, 0]); - let weights = weights.as_standard_layout(); - let bias = bias.map(|bias| bias.into_dimensionality::().unwrap()); - // floor division means `(oc_blocks - 1) * lanes` can never be greater than `out_channels - lanes`. - let oc_blocks = out_channels / lanes; - - let mut out = unsafe { - Array4::::uninit(Dim([batch_size, out_height, out_width, out_channels])).assume_init() - }; - let unsafe_shared_out = UnsafeSharedRef::new(&mut out); - - run_par!(|| { - // SAFETY: Slices are guaranteed to be non-overlapping, so having an unsafe shared reference - // is safe. `oc_blocks * lanes` must be `<= out_channels` to satisfy safety of inner function. - iter_range_par!(0, batch_size * oc_blocks).for_each(|k| unsafe { - let b = k / oc_blocks; - let ob = k % oc_blocks; - let x = x.slice(s![b, .., .., ..]); - let out = unsafe_shared_out.get(); - let mut out = out.slice_mut(s![b, .., .., ..]); - let w = weights.view(); - - match (padded, strided, grouped) { - (true, true, true) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) - } - (true, false, true) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) - } - (false, true, true) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) - } - (false, false, true) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) - } - (true, true, false) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) - } - (true, false, false) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) - } - (false, true, false) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) - } - (false, false, false) => { - conv2d_launch::(x, w, &bias, &mut out, &options, ob) - } - } - }); - }); - - let output = out.permuted_axes([0, 3, 1, 2]); - Ok(cast(output.into_dyn().into_shared())) -} - -/// Size of register blocks, we need to hardcode this because Rust and the `seq` macro don't support -/// using associated constants as constant parameters. 8 works for all semi-modern CPUs but might -/// not be perfectly optimized for AVX-512 capable CPUs (which probably should use 16). 
-/// This should always be conservative, since oversizing it will cause register spills and that's -/// **much** worse than the performance lost with lower values. -const REGISTER_BLOCK: usize = 8; -inner_with_register_blocking_size!(8); - -/// Run a loop of conv2d. -/// # SAFETY -/// See `conv2d_inner_nopad`, `conv2d_inner_nopad_nostride`, `conv2d_remainder`. -/// Required preconditions: `ob * simd_lanes` must be `<= out_channels - simd_lanes`, `weights` and -/// `out` must have unit stride for the out channels. -#[inline(always)] -#[macerator::with_simd] -unsafe fn conv2d_launch< - 'a, - S: Simd, - E: VMulAdd, - const PAD: bool, - const STRIDE: bool, - const GROUPS: bool, ->( - x: ArrayView3<'a, E>, - weights: ArrayView4<'a, E>, - bias: &'a Option>, - out: &'a mut ArrayViewMut3<'a, E>, - options: &'a ConvOptions<2>, - ob: usize, -) where - 'a: 'a, -{ - let (in_channels, k_height, k_width, out_channels) = weights.dim(); - let (out_height, out_width, _) = out.dim(); - let channels_per_group = out_channels / options.groups; - let lanes = E::lanes::(); - - let [mut pad_h, mut pad_w] = options.padding; - let [stride_h, stride_w] = options.stride; - let [dilate_h, dilate_w] = options.dilation; - - // Trick compiler into inlining 0 to padding - if !PAD { - pad_h = 0; - pad_w = 0; - } - - let oc_b = channels_per_group.min(lanes); - let ow_b = REGISTER_BLOCK; - - let ow_start = pad_w; - let ow_width = out_width.saturating_sub(2 * pad_w); - let oh_start = pad_h; - let oh_end = out_height.saturating_sub(pad_h); - - let ow_blocks = ow_width / ow_b; - let oc = ob * oc_b; - let group = oc / channels_per_group; - let mut ic_off = group * in_channels; - if !GROUPS { - ic_off = 0; - } - - unsafe { - let bias = if let Some(bias) = &bias { - vload_unaligned::(&bias[oc]) - } else { - Zeroable::zeroed() - }; - - for oh in oh_start..oh_end { - let mut out = out.slice_mut(s![oh, .., ..]); - for ow_block in 0..ow_blocks { - let ow = ow_block * ow_b + ow_start; - - #[allow(clippy::if_same_then_else)] - if STRIDE { - conv2d_inner_nopad( - &x, &weights, &mut out, bias, oh, ow, oc, ic_off, stride_h, stride_w, - dilate_h, dilate_w, k_height, k_width, pad_h, pad_w, - ); - } else { - conv2d_inner_nopad_nostride( - &x, &weights, &mut out, bias, oh, ow, oc, ic_off, k_height, k_width, pad_h, - pad_w, - ); - } - } - } - conv2d_remainder( - x, - weights, - out, - bias, - oc, - ic_off, - ow_blocks * ow_b, - stride_h, - stride_w, - dilate_h, - dilate_w, - pad_h, - pad_w, - k_height, - k_width, - ); - } -} - -/// Execute the non-unrolled and/or padded portion of the convolution. This has more checks and is -/// much slower, so we want to minimize the amount of pixels that need to be processed by this -/// -/// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector -/// is in bounds. Weights and `out` must be channels last (with `stride == 1`). 
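The eight-way `match` in `conv2d` above promotes runtime flags to const generics so that each specialization folds its dead branches away, which is what the "trick compiler into inlining 0 to padding" comment refers to. A minimal sketch of that pattern, under illustrative names rather than the kernel's real signature:

// Each (PAD, STRIDE) pair monomorphizes into its own body, so when
// `PAD == false` the compiler constant-folds `pad` to zero and removes
// the index math that depends on it.
fn input_index<const PAD: bool, const STRIDE: bool>(mut pad: usize, stride: usize, o: usize) -> usize {
    if !PAD {
        pad = 0; // folded away in the unpadded specializations
    }
    let step = if STRIDE { stride } else { 1 }; // unit-stride specialization
    o * step + pad
}

fn dispatch(pad: usize, stride: usize, o: usize) -> usize {
    // Lift the runtime flags into const parameters once, at the boundary.
    match (pad != 0, stride != 1) {
        (true, true) => input_index::<true, true>(pad, stride, o),
        (true, false) => input_index::<true, false>(pad, stride, o),
        (false, true) => input_index::<false, true>(pad, stride, o),
        (false, false) => input_index::<false, false>(pad, stride, o),
    }
}

fn main() {
    assert_eq!(dispatch(1, 2, 3), 7); // padded, strided
    assert_eq!(dispatch(0, 1, 3), 3); // both flags folded to the fast path
}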
-#[allow(clippy::too_many_arguments)] -#[inline(always)] -unsafe fn conv2d_remainder( - x: ArrayView3, - weights: ArrayView4, - out: &mut ArrayViewMut3, - bias: Vector, - oc: usize, - ic_off: usize, - owb_end: usize, - stride_h: usize, - stride_w: usize, - dilate_h: usize, - dilate_w: usize, - pad_h: usize, - pad_w: usize, - k_height: usize, - k_width: usize, -) { - let in_channels = weights.shape()[0]; - let (_, in_height, in_width) = x.dim(); - let (out_height, out_width, _) = out.dim(); - let oh_start = pad_h; - let oh_end = out_height.saturating_sub(pad_h); - let ow_start = pad_w; - - let height1 = in_height + pad_h; - let width1 = in_width + pad_w; - - for oh in (0..oh_start).chain(oh_end..out_height) { - for ow in 0..out_width { - let mut acc = bias; - - for ic in 0..in_channels { - for kh in 0..k_height { - let ih = oh * stride_h + kh * dilate_h; - if (ih < pad_h) | (ih >= height1) { - continue; - } - let ih = ih - pad_h; - - for kw in 0..k_width { - let iw = ow * stride_w + kw * dilate_w; - if (iw < pad_w) | (iw >= width1) { - continue; - } - let iw = iw - pad_w; - - // Load a full vector from the weights. This is guaranteed to be in bounds - // as long as `oc <= out_channels - simd_lanes` and out channels are last. - // We need to ensure the weights are reshaped appropriately. - let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; - - // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the - // compiler can't prove this. We can't use `as_slice` with fixed bounds - // because we want to support arbitrary input layouts. So an unchecked load - // is used. - let i0 = unsafe { x.uget([ic, ih, iw]) }.splat::(); - acc = i0.mul_add(f0, acc); - } - } - } - - // Store a full vector from the output. This is guaranteed to be in bounds - // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with - // channels last, so this always holds. - unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) }; - } - } - for ow in (0..ow_start).chain(owb_end..out_width) { - for oh in 0..out_height { - let mut acc = bias; - - for ic in 0..in_channels { - for kh in 0..k_height { - let ih = oh * stride_h + kh * dilate_h; - if (ih < pad_h) | (ih >= height1) { - continue; - } - let ih = ih - pad_h; - - for kw in 0..k_width { - let iw = ow * stride_w + kw * dilate_w; - if (iw < pad_w) | (iw >= width1) { - continue; - } - let iw = iw - pad_w; - - // Load a full vector from the weights. This is guaranteed to be in bounds - // as long as `oc <= out_channels - simd_lanes` and out channels are last. - // We need to ensure the weights are reshaped appropriately. - let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; - - // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the - // compiler can't prove this. We can't use `as_slice` with fixed bounds - // because we want to support arbitrary input layouts. So an unchecked load - // is used. - let i0 = unsafe { x.uget([ic_off + ic, ih, iw]) }.splat::(); - acc = i0.mul_add(f0, acc); - } - } - } - - // Store a full vector from the output. This is guaranteed to be in bounds - // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with - // channels last, so this always holds. - unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) }; - } - } -} - -macro_rules! inner_with_register_blocking_size { - ($rb: literal) => { - /// Execute the unrolled and unpadded portion of the convolution. 
Any pixel that is more than - /// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is - /// guaranteed to always be in bounds (because of the way out size is calculated). - /// - /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector - /// is in bounds. Weights and `out` must be channels last (with `stride == 1`). - #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)] - #[inline(always)] - unsafe fn conv2d_inner_nopad( - x: &ArrayView3, - weights: &ArrayView4, - out: &mut ArrayViewMut2, - bias: Vector, - oh: usize, - ow: usize, - oc: usize, - ic_off: usize, - stride_h: usize, - stride_w: usize, - dilate_h: usize, - dilate_w: usize, - k_height: usize, - k_width: usize, - pad_h: usize, - pad_w: usize, - ) { - let in_channels = weights.shape()[0]; - - seq!(N in 0..$rb { - let mut acc~N = bias; - }); - - for ic in 0..in_channels { - for kh in 0..k_height { - let ih = oh * stride_h + kh * dilate_h - pad_h; - - for kw in 0..k_width { - // Load a full vector from the weights. This is guaranteed to be in bounds - // as long as `oc <= out_channels - simd_lanes` and out channels are last. - // We need to ensure the weights are reshaped appropriately. - let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; - let iw = ow * stride_w + kw * dilate_w - pad_w; - - seq!(N in 0..$rb { - // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the - // compiler can't prove this. We can't use `as_slice` with fixed bounds - // because we want to support arbitrary input layouts. So an unchecked load - // is used. - let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N * stride_w]) }.splat::(); - }); - seq!(N in 0..$rb { - acc~N = i~N.mul_add(f0, acc~N); - }); - } - } - } - - seq!(N in 0..$rb { - // Store a full vector from the output. This is guaranteed to be in bounds - // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with - // channels last, so this always holds. - unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) }; - }); - } - - /// Execute the unrolled and unpadded portion of the convolution. Any pixel that is more than - /// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is - /// guaranteed to always be in bounds (because of the way out size is calculated). - /// - /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector - /// is in bounds. Weights and `out` must be channels last (with `stride == 1`). - #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)] - #[inline(always)] - unsafe fn conv2d_inner_nopad_nostride( - x: &ArrayView3, - weights: &ArrayView4, - out: &mut ArrayViewMut2, - bias: Vector, - oh: usize, - ow: usize, - oc: usize, - ic_off: usize, - k_height: usize, - k_width: usize, - pad_h: usize, - pad_w: usize, - ) { - let in_channels = weights.shape()[0]; - - seq!(N in 0..$rb { - let mut acc~N = bias; - }); - - for ic in 0..in_channels { - for kh in 0..k_height { - let ih = oh + kh - pad_h; - - for kw in 0..k_width { - // Load a full vector from the weights. This is guaranteed to be in bounds - // as long as `oc <= out_channels - simd_lanes` and out channels are last. - // We need to ensure the weights are reshaped appropriately. 
- let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; - let iw = ow + kw - pad_w; - - seq!(N in 0..$rb { - // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the - // compiler can't prove this. We can't use `as_slice` with fixed bounds - // because we want to support arbitrary input layouts. So an unchecked load - // is used. - let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N]) }.splat::(); - }); - seq!(N in 0..$rb { - acc~N = i~N.mul_add(f0, acc~N); - }); - } - } - } - - seq!(N in 0..$rb { - // Store a full vector from the output. This is guaranteed to be in bounds - // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with - // channels last, so this always holds. - unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) }; - }); - } - }; -} -pub(crate) use inner_with_register_blocking_size; diff --git a/crates/burn/src/ops/simd/conv.rs b/crates/burn/src/ops/simd/conv.rs new file mode 120000 index 00000000..022a7edd --- /dev/null +++ b/crates/burn/src/ops/simd/conv.rs @@ -0,0 +1 @@ +../../../upstream/crates/burn-ndarray/src/ops/simd/conv.rs \ No newline at end of file diff --git a/crates/burn/src/ops/simd/maxpool.rs b/crates/burn/src/ops/simd/maxpool.rs deleted file mode 100644 index 279af69b..00000000 --- a/crates/burn/src/ops/simd/maxpool.rs +++ /dev/null @@ -1,394 +0,0 @@ -use core::{marker::PhantomData, mem::transmute}; - -use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef}; - -use burn_backend::{BoolStore, DType, Element, quantization::QuantValue}; -use macerator::{Simd, VOrd}; -use ndarray::{Array4, s}; -use nhwc::max_pool2d_nhwc; - -use super::{MinMax, should_use_simd}; - -#[macerator::with_simd] -fn is_accelerated_impl(_x: PhantomData) -> bool { - ::is_min_max_accelerated::() -} - -fn is_accelerated() -> bool { - is_accelerated_impl::(PhantomData) -} - -macro_rules! 
launch_kernel { - ($ty: ty, $func: ident, $x: expr, $($arg: expr),*) => { - match <$ty as Element>::dtype() { - DType::F64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::F32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::I64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::I32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::I16 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::I8 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::U64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::U32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::U16 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::U8 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::Bool(BoolStore::Native) if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::QFloat(scheme) => match scheme.value { - QuantValue::Q8F | QuantValue::Q8S if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - _ => Err($x) - }, - _ => Err($x), - } - }; -} - -pub(crate) fn try_max_pool2d_simd( - x: SharedArray, - ksize: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], -) -> Result, SharedArray> { - let [_, c, _, _] = x.shape().try_into().unwrap(); - if !should_use_simd(c) || x.strides()[1] != 1 { - return Err(x); - } - - launch_kernel!(E, max_pool2d_nhwc, x, ksize, stride, padding, dilation) -} - -fn cast(tensor: SharedArray) -> SharedArray { - unsafe { transmute::, SharedArray>(tensor) } -} - -mod nhwc { - use itertools::Itertools; - use macerator::{Simd, vload_unaligned, vstore_unaligned}; - use ndarray::{ArrayView3, ArrayViewMut3, Ix4}; - use seq_macro::seq; - - use crate::ops::simd::lanes; - - use super::*; - - // Until you can use associated constants as array size, we need to hardcode this. - // The most common config (x86-v3) has 16 registers, so use half of them for accumulators. - const BLOCK_REGISTERS: usize = 8; - - pub(crate) fn max_pool2d_nhwc( - x: SharedArray, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ) -> SharedArray { - let [kernel_height, kernel_width] = kernel_size; - let [pad_h, pad_w] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); - let lanes = lanes::(); - - let ch_block = lanes * BLOCK_REGISTERS; - - let out_height = ((x_height + 2 * pad_h - dilation_height * (kernel_height - 1) - 1) - / stride_height) - + 1; - let out_width = - ((x_width + 2 * pad_w - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1; - - let mut output = unsafe { - Array4::::uninit((batch_size, out_height, out_width, channels)).assume_init() - }; - let unsafe_shared_out = UnsafeSharedRef::new(&mut output); - - let x = x.into_dimensionality::().unwrap(); - let x = x.view(); - let x = x.permuted_axes([0, 2, 3, 1]); - - // Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`. - // An exclusive loop will always have `lanes * blocking factor` elements in bounds. - let blocks = channels / ch_block; - let blocks_end = blocks * ch_block; - // Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An - // exclusive loop will always have `lanes` elements in bounds. 
- let simd_end = channels / lanes * lanes; - let simd_unblocked = (simd_end - blocks_end) / lanes; - let remainder = channels - simd_end; - - run_par!(|| { - // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. - iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe { - let block = k % blocks; - let b = k / blocks; - - let output = unsafe_shared_out.get(); - let x = x.slice(s![b, .., .., ..]); - let out = output.slice_mut(s![b, .., .., ..]); - loop_blocked(x, out, kernel_size, stride, padding, dilation, block); - }); - // SAFETY: See `loop_unblocked` - iter_range_par!(0, batch_size * simd_unblocked).for_each(|k| unsafe { - let ch = (k % simd_unblocked) * lanes + blocks_end; - let b = k / simd_unblocked; - - let output = unsafe_shared_out.get(); - let x = x.slice(s![b, .., .., ..]); - let out = output.slice_mut(s![b, .., .., ..]); - loop_unblocked(x, out, kernel_size, stride, padding, dilation, ch); - }); - // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. - iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe { - let ch = (k % remainder) + simd_end; - let b = k / remainder; - - let output = unsafe_shared_out.get(); - let x = x.slice(s![b, .., .., ..]); - let out = output.slice_mut(s![b, .., .., ..]); - loop_scalar(x, out, kernel_size, stride, padding, dilation, ch); - }); - }); - - output = output.permuted_axes([0, 3, 1, 2]); - - output.into_dyn().into_shared() - } - - /// Execute the blocked (unrolled) portion of the pool. - #[allow( - clippy::too_many_arguments, - clippy::erasing_op, - clippy::identity_op, - unused_mut - )] - #[inline(always)] - #[macerator::with_simd] - fn loop_blocked<'a, S: Simd, E: Element + VOrd + MinMax>( - x: ArrayView3<'a, E>, - mut out: ArrayViewMut3<'a, E>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - block: usize, - ) where - 'a: 'a, - { - let [kernel_height, kernel_width] = kernel_size; - let [pad_h, pad_w] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - - let (x_height, x_width, _) = x.dim(); - let (out_height, out_width, _) = out.dim(); - let lanes = E::lanes::(); - let ch_block = lanes * BLOCK_REGISTERS; - - let min = E::MIN.splat::(); - // If outside padding area, kernels are guaranteed to be in bounds - for oh in pad_h..out_height.saturating_sub(pad_h) { - for ow in pad_w..out_width.saturating_sub(pad_w) { - seq!(N in 0..8 { - let mut acc~N = min; - }); - let ch = block * ch_block; - let ch_end = ch + ch_block; - let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh * dilation_height - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw * dilation_width - pad_w; - let x = x.slice(s![ih, iw, ch..ch_end]); - - seq!(N in 0..8 { - // SAFETY: - // Load a full vector from x[N * lanes]. This is bounds checked by the - // slice above. - acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) }); - }); - } - } - - seq!(N in 0..8 { - // SAFETY: - // Store a full vector to out[N * lanes]. This is bounds checked by the - // slice above. 
- unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) }; - }); - } - } - - // Border pixels need bounds checks - if (pad_h, pad_w) != (0, 0) { - let v_borders = (0..pad_h) - .chain(out_height.saturating_sub(pad_h)..out_height) - .cartesian_product(0..out_width); - let h_borders = (0..out_height) - .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); - - for (oh, ow) in v_borders.chain(h_borders) { - seq!(N in 0..8 { - let mut acc~N = min; - }); - let ch = block * ch_block; - let ch_end = ch + ch_block; - let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh * dilation_height; - if ih < pad_h || ih >= x_height + pad_h { - continue; - } - let ih = ih - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw * dilation_width; - if iw < pad_w || iw >= x_width + pad_w { - continue; - } - let iw = iw - pad_w; - - let x = x.slice(s![ih, iw, ch..ch_end]); - - seq!(N in 0..8 { - // SAFETY: - // Load a full vector from x[N * lanes]. This is bounds checked by the - // slice above. - acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) }); - }); - } - } - - seq!(N in 0..8 { - // SAFETY: - // Store a full vector to out[N * lanes]. This is bounds checked by the - // slice above. - unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) }; - }); - } - } - } - - /// Execute the unblocked (not unrolled) portion of the pool. - /// - /// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`. - #[allow(clippy::too_many_arguments, unused_mut)] - #[inline(always)] - #[macerator::with_simd] - unsafe fn loop_unblocked<'a, S: Simd, E: Element + VOrd + MinMax>( - x: ArrayView3<'a, E>, - mut out: ArrayViewMut3<'a, E>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ch: usize, - ) where - 'a: 'a, - { - let [kernel_height, kernel_width] = kernel_size; - let [pad_h, pad_w] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - - let (x_height, x_width, _) = x.dim(); - let (out_height, out_width, _) = out.dim(); - - for oh in pad_h..out_height.saturating_sub(pad_h) { - for ow in pad_w..out_width.saturating_sub(pad_w) { - let mut acc = E::MIN.splat::(); - let out = &mut out[[oh, ow, ch]]; - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh * dilation_height - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw * dilation_width - pad_w; - // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` - acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) }); - } - } - // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. 
- unsafe { vstore_unaligned(out, acc) }; - } - } - - // Border pixels need bounds checks - if (pad_h, pad_w) != (0, 0) { - let v_borders = (0..pad_h) - .chain(out_height.saturating_sub(pad_h)..out_height) - .cartesian_product(0..out_width); - let h_borders = (0..out_height) - .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); - - for (oh, ow) in v_borders.chain(h_borders) { - let mut acc = E::MIN.splat::(); - let out = &mut out[[oh, ow, ch]]; - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh * dilation_height; - if ih < pad_h || ih >= x_height + pad_h { - continue; - } - let ih = ih - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw * dilation_width; - if iw < pad_w || iw >= x_width + pad_w { - continue; - } - let iw = iw - pad_w; - // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` - acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) }); - } - } - // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. - unsafe { vstore_unaligned(out, acc) }; - } - } - } - - fn loop_scalar( - x: ArrayView3<'_, E>, - mut out: ArrayViewMut3<'_, E>, - kernel_size: [usize; 2], - stride: [usize; 2], - padding: [usize; 2], - dilation: [usize; 2], - ch: usize, - ) { - let [kernel_height, kernel_width] = kernel_size; - let [pad_h, pad_w] = padding; - let [stride_height, stride_width] = stride; - let [dilation_height, dilation_width] = dilation; - - let (x_height, x_width, _) = x.dim(); - let (out_height, out_width, _) = out.dim(); - - for oh in 0..out_height { - for ow in 0..out_width { - let mut acc = E::MIN; - - for kh in 0..kernel_height { - let ih = oh * stride_height + kh * dilation_height; - if ih < pad_h || ih >= x_height + pad_h { - continue; - } - let ih = ih - pad_h; - - for kw in 0..kernel_width { - let iw = ow * stride_width + kw * dilation_width; - if iw < pad_w || iw >= x_width + pad_w { - continue; - } - let iw = iw - pad_w; - acc = acc.max(x[[ih, iw, ch]]); - } - } - - out[[oh, ow, ch]] = acc; - } - } - } -} diff --git a/crates/burn/src/ops/simd/maxpool.rs b/crates/burn/src/ops/simd/maxpool.rs new file mode 120000 index 00000000..ede08b96 --- /dev/null +++ b/crates/burn/src/ops/simd/maxpool.rs @@ -0,0 +1 @@ +../../../upstream/crates/burn-ndarray/src/ops/simd/maxpool.rs \ No newline at end of file diff --git a/crates/burn/src/ops/simd/mod.rs b/crates/burn/src/ops/simd/mod.rs deleted file mode 100644 index 2032f30c..00000000 --- a/crates/burn/src/ops/simd/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -pub(crate) mod avgpool; -mod base; -pub(crate) mod binary; -pub(crate) mod binary_elemwise; -pub(crate) mod cmp; -pub(crate) mod conv; -pub(crate) mod maxpool; -pub(crate) mod unary; - -pub use base::*; diff --git a/crates/burn/src/ops/simd/mod.rs b/crates/burn/src/ops/simd/mod.rs new file mode 120000 index 00000000..b2a2eda6 --- /dev/null +++ b/crates/burn/src/ops/simd/mod.rs @@ -0,0 +1 @@ +../../../upstream/crates/burn-ndarray/src/ops/simd/mod.rs \ No newline at end of file diff --git a/crates/burn/src/ops/simd/unary.rs b/crates/burn/src/ops/simd/unary.rs deleted file mode 100644 index 68d26267..00000000 --- a/crates/burn/src/ops/simd/unary.rs +++ /dev/null @@ -1,234 +0,0 @@ -use core::marker::PhantomData; - -use bytemuck::cast; -use macerator::{ - Scalar, Simd, VAbs, VBitNot, VRecip, Vector, vload, vload_unaligned, vstore, vstore_unaligned, -}; -use ndarray::ArrayD; -use num_traits::Signed; -use seq_macro::seq; - -use crate::{NdArrayElement, SharedArray}; - -use 
super::should_use_simd; - -pub trait SimdUnop { - fn apply_vec(input: Vector) -> Vector; - fn apply(input: T) -> Out; - fn is_accelerated() -> bool; -} - -pub struct RecipVec; - -impl SimdUnop for RecipVec { - fn apply_vec(input: Vector) -> Vector { - input.recip() - } - - fn apply(input: f32) -> f32 { - input.recip() - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -pub struct VecAbs; - -impl SimdUnop for VecAbs { - fn apply_vec(input: Vector) -> Vector { - input.abs() - } - - fn apply(input: T) -> T { - input.abs() - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -pub struct VecBitNot; - -impl SimdUnop for VecBitNot { - fn apply_vec(input: Vector) -> Vector { - !input - } - - fn apply(input: T) -> T { - input.not() - } - - fn is_accelerated() -> bool { - ::is_accelerated::() - } -} - -#[macerator::with_simd] -fn is_accelerated>( - _x: PhantomData<(T, Out, Op)>, -) -> bool { - Op::is_accelerated::() -} - -pub fn try_unary_simd< - E: NdArrayElement, - EOut: NdArrayElement, - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: SimdUnop, ->( - input: SharedArray, -) -> Result, SharedArray> { - if !should_use_simd(input.len()) - || input.as_slice_memory_order().is_none() - || !is_accelerated::(PhantomData) - { - return Err(input); - } - // Used to assert traits based on the dynamic `DType`. - let input = unsafe { core::mem::transmute::, SharedArray>(input) }; - let out = if size_of::() == size_of::() - && align_of::() >= align_of::() - && input.is_unique() - { - unsafe { unary_scalar_simd_inplace::(input) } - } else { - unary_scalar_simd_owned::(input) - }; - // Used to assert traits based on the dynamic `DType`. - let out = unsafe { core::mem::transmute::, SharedArray>(out) }; - Ok(out) -} - -/// Execute operation in line. -/// SAFETY: -/// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. -unsafe fn unary_scalar_simd_inplace< - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: SimdUnop, ->( - input: SharedArray, -) -> SharedArray { - let mut buffer = input.into_owned(); - let slice = buffer.as_slice_memory_order_mut().unwrap(); - // This is only called when in and out have the same size, so it's safe - unsafe { unary_slice_inplace::(slice, PhantomData) }; - // Buffer has the same elem size and is filled with the operation output, so this is safe - let out = unsafe { core::mem::transmute::, ArrayD>(buffer) }; - out.into_shared() -} - -fn unary_scalar_simd_owned< - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: SimdUnop, ->( - input: SharedArray, -) -> SharedArray { - let mut out = unsafe { ArrayD::uninit(input.shape()).assume_init() }; - let input = input.as_slice_memory_order().unwrap(); - let out_slice = out.as_slice_memory_order_mut().unwrap(); - unary_slice::(input, out_slice, PhantomData); - out.into_shared() -} - -#[allow(clippy::erasing_op, clippy::identity_op)] -#[macerator::with_simd] -fn unary_slice< - 'a, - S: Simd, - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: SimdUnop, ->( - input: &'a [T], - out: &'a mut [Out], - _op: PhantomData, -) where - 'a: 'a, -{ - let lanes = T::lanes::(); - let mut chunks_input = input.chunks_exact(8 * lanes); - let mut chunks_out = out.chunks_exact_mut(8 * lanes); - while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { - seq!(N in 0..8 { - // Load one full vector from `input`. 
- // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - let s~N = unsafe { vload_unaligned(&input[N * lanes]) }; - let s~N = Op::apply_vec::(s~N); - // Store one full vector to `out`. - // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` - unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; - }); - } - let mut chunks_input = chunks_input.remainder().chunks_exact(lanes); - let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); - while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { - // Load one full vector from `input`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - let s0 = unsafe { vload_unaligned(input.as_ptr()) }; - let s0 = Op::apply_vec::(s0); - // Store one full vector to `out`. - // SAFETY: Guaranteed to be in bounds because `len == lanes` - unsafe { vstore_unaligned(out.as_mut_ptr(), s0) }; - } - - for (input, out) in chunks_input - .remainder() - .iter() - .zip(chunks_out.into_remainder()) - { - *out = Op::apply(*input) - } -} - -/// Execute operation in line. -/// SAFETY: -/// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. -#[macerator::with_simd] -unsafe fn unary_slice_inplace< - 'a, - S: Simd, - T: NdArrayElement + Scalar, - Out: NdArrayElement + Scalar, - Op: SimdUnop, ->( - buf: &'a mut [T], - _op: PhantomData<(Out, Op)>, -) where - 'a: 'a, -{ - let (head, main, tail) = unsafe { buf.align_to_mut::>() }; - for elem in head.iter_mut().chain(tail) { - *elem = cast(Op::apply(*elem)); - } - let mut chunks = main.chunks_exact_mut(8); - for elem in chunks.by_ref() { - seq!(N in 0..8 { - // Load a full vector from the aligned portion of the buffer. - // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is - // always a full vector in bounds. - let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; - let s~N = Op::apply_vec::(s~N); - // Store a full vector at the same position as the input. Cast is safe because `Out` is - // size and align compatible - unsafe { vstore(&mut elem[N] as *mut _ as *mut Out, s~N) }; - }); - } - - for elem in chunks.into_remainder() { - // Load a full vector from the aligned portion of the buffer. - // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is - // always a full vector in bounds. - let s0 = unsafe { vload(elem as *const _ as *const T) }; - - let s0 = Op::apply_vec::(s0); - // Store a full vector at the same position as the input. 
Cast is safe because `Out` is - // size and align compatible - unsafe { vstore(elem as *mut _ as *mut Out, s0) }; - } -} diff --git a/crates/burn/src/ops/simd/unary.rs b/crates/burn/src/ops/simd/unary.rs new file mode 120000 index 00000000..9050fcfc --- /dev/null +++ b/crates/burn/src/ops/simd/unary.rs @@ -0,0 +1 @@ +../../../upstream/crates/burn-ndarray/src/ops/simd/unary.rs \ No newline at end of file diff --git a/crates/burn/src/ops/transaction.rs b/crates/burn/src/ops/transaction.rs deleted file mode 100644 index b308c0f0..00000000 --- a/crates/burn/src/ops/transaction.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::{ - FloatNdArrayElement, NdArray, NdArrayTensor, SharedArray, - element::{IntNdArrayElement, QuantElement}, -}; -use burn_backend::ops::TransactionOps; - -impl TransactionOps - for NdArray -where - NdArrayTensor: From>, - NdArrayTensor: From>, -{ -} diff --git a/crates/burn/src/ops/transaction.rs b/crates/burn/src/ops/transaction.rs new file mode 120000 index 00000000..56826d01 --- /dev/null +++ b/crates/burn/src/ops/transaction.rs @@ -0,0 +1 @@ +../../upstream/crates/burn-ndarray/src/ops/transaction.rs \ No newline at end of file diff --git a/crates/burn/src/parallel.rs b/crates/burn/src/parallel.rs deleted file mode 100644 index a6657619..00000000 --- a/crates/burn/src/parallel.rs +++ /dev/null @@ -1,76 +0,0 @@ -/// Macro for running a function in parallel. -#[cfg(feature = "multi-threads")] -#[macro_export(local_inner_macros)] -macro_rules! run_par { - ( - $func:expr - ) => {{ - use rayon::prelude::*; - - #[allow(clippy::redundant_closure_call)] - rayon::scope(|_| $func()) - }}; -} - -/// Macro for running a function in parallel. -#[cfg(not(feature = "multi-threads"))] -#[macro_export(local_inner_macros)] -macro_rules! run_par { - ( - $func:expr - ) => {{ $func() }}; -} - -/// Macro for iterating in parallel. -#[cfg(not(feature = "multi-threads"))] -#[macro_export(local_inner_macros)] -macro_rules! iter_par { - ( - $iter:expr - ) => {{ $iter }}; -} - -/// Macro for iterating in parallel. -#[cfg(feature = "multi-threads")] -#[macro_export(local_inner_macros)] -macro_rules! iter_par { - ( - $iter:expr - ) => {{ $iter.into_par_iter() }}; -} - -/// Macro for iterating in parallel. -#[cfg(feature = "multi-threads")] -#[macro_export(local_inner_macros)] -macro_rules! iter_slice_par { - ( - $slice:expr - ) => {{ $slice.into_par_iter() }}; -} - -/// Macro for iterating in parallel. -#[cfg(not(feature = "multi-threads"))] -#[macro_export(local_inner_macros)] -macro_rules! iter_slice_par { - ( - $slice:expr - ) => {{ $slice.iter() }}; -} - -/// Macro for iterating over a range in parallel. -#[cfg(feature = "multi-threads")] -#[macro_export(local_inner_macros)] -macro_rules! iter_range_par { - ( - $start:expr, $end:expr - ) => {{ ($start..$end).into_par_iter() }}; -} - -/// Macro for iterating over a range in parallel. -#[cfg(not(feature = "multi-threads"))] -#[macro_export(local_inner_macros)] -macro_rules! iter_range_par { - ( - $start:expr, $end:expr - ) => {{ ($start..$end) }}; -} diff --git a/crates/burn/src/parallel.rs b/crates/burn/src/parallel.rs new file mode 120000 index 00000000..0f21dc75 --- /dev/null +++ b/crates/burn/src/parallel.rs @@ -0,0 +1 @@ +../upstream/crates/burn-ndarray/src/parallel.rs \ No newline at end of file diff --git a/crates/burn/src/rand.rs b/crates/burn/src/rand.rs deleted file mode 100644 index 94b9bcda..00000000 --- a/crates/burn/src/rand.rs +++ /dev/null @@ -1,36 +0,0 @@ -//! 
Random number generation utilities for burn-ndarray - -#[cfg(not(feature = "std"))] -use rand::rngs::SmallRng; -#[cfg(feature = "std")] -use rand::rngs::StdRng; - -/// Type alias for the RNG used by burn-ndarray -#[cfg(feature = "std")] -pub type NdArrayRng = StdRng; -#[cfg(not(feature = "std"))] -pub type NdArrayRng = SmallRng; - -#[cfg(not(feature = "std"))] -use rand::SeedableRng; - -/// Get a seeded random number generator -/// -/// For std builds, uses OS entropy. -/// For no_std builds, uses a compile-time random seed. -#[cfg(feature = "std")] -pub fn get_seeded_rng() -> NdArrayRng { - // Use the standard implementation from burn-std - burn_std::rand::get_seeded_rng() -} - -/// Get a seeded random number generator -/// -/// For std builds, uses OS entropy. -/// For no_std builds, uses a compile-time random seed. -#[cfg(not(feature = "std"))] -pub fn get_seeded_rng() -> NdArrayRng { - // Use compile-time random seed for no_std - const SEED: u64 = const_random::const_random!(u64); - SmallRng::seed_from_u64(SEED) -} diff --git a/crates/burn/src/rand.rs b/crates/burn/src/rand.rs new file mode 120000 index 00000000..16c2e6e9 --- /dev/null +++ b/crates/burn/src/rand.rs @@ -0,0 +1 @@ +../upstream/crates/burn-ndarray/src/rand.rs \ No newline at end of file diff --git a/crates/burn/src/sharing.rs b/crates/burn/src/sharing.rs deleted file mode 100644 index 75d51421..00000000 --- a/crates/burn/src/sharing.rs +++ /dev/null @@ -1,19 +0,0 @@ -use core::cell::UnsafeCell; - -/// Similar to `SyncUnsafeCell` see [Rust issues](https://github.com/rust-lang/rust/issues/95439). -pub(crate) struct UnsafeSharedRef<'a, T> { - cell: UnsafeCell<&'a mut T>, -} - -unsafe impl Sync for UnsafeSharedRef<'_, T> {} - -impl<'a, T> UnsafeSharedRef<'a, T> { - pub fn new(data: &'a mut T) -> Self { - Self { - cell: UnsafeCell::new(data), - } - } - pub unsafe fn get(&self) -> &'a mut T { - unsafe { core::ptr::read(self.cell.get()) } - } -} diff --git a/crates/burn/src/sharing.rs b/crates/burn/src/sharing.rs new file mode 120000 index 00000000..a128359e --- /dev/null +++ b/crates/burn/src/sharing.rs @@ -0,0 +1 @@ +../upstream/crates/burn-ndarray/src/sharing.rs \ No newline at end of file diff --git a/crates/burn/src/storage.rs b/crates/burn/src/storage.rs deleted file mode 100644 index 7eeca47f..00000000 --- a/crates/burn/src/storage.rs +++ /dev/null @@ -1,506 +0,0 @@ -//! Copy-on-write storage for zero-copy tensor loading. -//! -//! This module provides `NdArrayStorage`, which enables true zero-copy loading -//! from burnpack files. When data is borrowed from external memory (like mmap'd files -//! or static data), it remains zero-copy until a mutating operation is performed, -//! at which point it's copied (copy-on-write semantics). -//! -//! This integrates with ndarray's existing COW patterns - operations that check -//! `is_unique()` will see borrowed data as non-unique, triggering the allocation path. - -use burn_backend::Element; -use burn_std::{Bytes, Shape}; -use core::mem; -use ndarray::{ArcArray, ArrayView, IxDyn}; - -/// Storage that supports both owned data and borrowed (zero-copy) data. -/// -/// # Copy-on-Write Semantics -/// -/// - **Borrowed**: Data from external source (burnpack, mmap, static). -/// Reports `is_unique() == false` to trigger copy on mutation. -/// - **Owned**: Standard `ArcArray` with built-in COW via Arc refcount. 
-/// -/// # Example -/// -/// ```ignore -/// // Zero-copy load -/// let storage = NdArrayStorage::from_borrowed(bytes, shape); -/// storage.is_unique(); // false - will copy on mutation -/// -/// // Read operations use view() - zero-copy -/// let view = storage.view(); -/// -/// // Mutation converts to owned -/// let owned = storage.into_owned(); // Copies here -/// ``` -#[derive(Debug)] -pub enum NdArrayStorage { - /// Borrowed from external source (e.g., burnpack zero-copy load). - /// Keeps `Bytes` alive to ensure the referenced memory is valid. - Borrowed { - /// Source bytes - keeps external memory alive via reference counting - bytes: Bytes, - /// Shape of the tensor - shape: Shape, - }, - - /// Standard owned storage with ArcArray COW semantics. - Owned(ArcArray), -} - -impl Clone for NdArrayStorage { - fn clone(&self) -> Self { - match self { - // For borrowed data, clone the Bytes (cheap Arc clone) and shape - Self::Borrowed { bytes, shape } => Self::Borrowed { - bytes: bytes.clone(), - shape: shape.clone(), - }, - // For owned data, clone the ArcArray (cheap Arc clone) - Self::Owned(arr) => Self::Owned(arr.clone()), - } - } -} - -impl NdArrayStorage { - /// Create borrowed storage from external bytes. - /// - /// Returns the bytes and shape back on failure (misaligned or too small), - /// enabling zero-copy even for native allocations by avoiding defensive cloning. - /// - /// # Requirements - /// - /// The caller must ensure that: - /// - The `Bytes` contain valid data for the element type `E` - /// - The data is contiguous in row-major (C) order matching the provided shape - /// - /// These requirements are upheld when loading from `TensorData` (burnpack, etc.) - /// which always stores data contiguously in row-major order. - pub fn from_borrowed(bytes: Bytes, shape: impl Into) -> Result { - let shape = shape.into(); - // Validate alignment - let ptr = bytes.as_ptr(); - if !(ptr as usize).is_multiple_of(mem::align_of::()) { - return Err((bytes, shape)); - } - - // Validate size (using checked arithmetic to prevent overflow) - let num_elements = match shape - .iter() - .try_fold(1usize, |acc, &dim| acc.checked_mul(dim)) - { - Some(n) => n, - None => return Err((bytes, shape)), - }; - let expected_size = match num_elements.checked_mul(mem::size_of::()) { - Some(s) => s, - None => return Err((bytes, shape)), - }; - if bytes.len() < expected_size { - return Err((bytes, shape)); - } - - Ok(Self::Borrowed { bytes, shape }) - } - - /// Create owned storage from an ArcArray. - #[inline] - pub fn from_owned(array: ArcArray) -> Self { - Self::Owned(array) - } - - /// Returns whether this storage is uniquely owned and can be mutated in-place. - /// - /// - **Borrowed**: Always returns `false` to trigger copy-on-write. - /// - **Owned**: Delegates to `ArcArray::is_unique()`. - /// - /// This integrates with existing SIMD code patterns like: - /// ```ignore - /// if tensor.is_unique() { - /// // mutate in place - /// } else { - /// // allocate new - /// } - /// ``` - #[inline] - pub fn is_unique(&self) -> bool { - match self { - Self::Borrowed { .. } => false, // Force copy path - Self::Owned(arr) => arr.is_unique(), - } - } - - /// Get a read-only view of the data. - /// - /// This is zero-copy for both borrowed and owned variants. 
- #[inline] - pub fn view(&self) -> ArrayView<'_, E, IxDyn> { - match self { - Self::Borrowed { bytes, shape } => { - let ptr = bytes.as_ptr() as *const E; - let dim = IxDyn(shape); - // SAFETY: - // - `bytes` is kept alive for the lifetime of `self` - // - Alignment was validated in `from_borrowed` - // - Size was validated in `from_borrowed` - unsafe { ArrayView::from_shape_ptr(dim, ptr) } - } - Self::Owned(arr) => arr.view(), - } - } - - /// Convert to owned ArcArray. - /// - /// - **Borrowed**: Copies the data into a new ArcArray. - /// - **Owned + unique**: Returns the array without copying. - /// - **Owned + shared**: Clones the data. - pub fn into_owned(self) -> ArcArray { - match self { - Self::Borrowed { bytes, shape } => { - let ptr = bytes.as_ptr() as *const E; - let dim = IxDyn(&shape); - // SAFETY: Same as view() - bytes is valid for this scope - let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) }; - view.to_owned().into_shared() - } - Self::Owned(arr) => arr, - } - } - - /// Convert to shared ArcArray, suitable for returning from operations. - /// - /// This is equivalent to `into_owned()` but named for clarity. - #[inline] - pub fn into_shared(self) -> ArcArray { - self.into_owned() - } - - /// Get the shape of the tensor. - pub fn shape(&self) -> &[usize] { - match self { - Self::Borrowed { shape, .. } => shape, - Self::Owned(arr) => arr.shape(), - } - } - - /// Get the number of dimensions. - #[inline] - pub fn ndim(&self) -> usize { - self.shape().len() - } - - /// Get the total number of elements. - #[inline] - pub fn len(&self) -> usize { - self.shape().iter().product() - } - - /// Check if the tensor is empty. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns `true` if this is borrowed (zero-copy) storage. - #[inline] - pub fn is_borrowed(&self) -> bool { - matches!(self, Self::Borrowed { .. }) - } - - /// Returns `true` if this is owned storage. - #[inline] - pub fn is_owned(&self) -> bool { - matches!(self, Self::Owned(_)) - } - - /// Ensure owned and return mutable reference to the ArcArray. - /// - /// Converts borrowed to owned if necessary. - pub fn ensure_owned(&mut self) -> &mut ArcArray { - if let Self::Borrowed { bytes, shape } = self { - let ptr = bytes.as_ptr() as *const E; - let dim = IxDyn(shape); - // SAFETY: Same as view() - let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) }; - *self = Self::Owned(view.to_owned().into_shared()); - } - match self { - Self::Owned(arr) => arr, - Self::Borrowed { .. } => unreachable!(), - } - } -} - -/// Convert from ArcArray to NdArrayStorage. 
-impl From> for NdArrayStorage { - fn from(array: ArcArray) -> Self { - Self::Owned(array) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use alloc::{vec, vec::Vec}; - use burn_std::Bytes; - - #[test] - fn test_borrowed_is_not_unique() { - let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let bytes = Bytes::from_elems(data); - let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); - - assert!(!storage.is_unique()); - assert!(storage.is_borrowed()); - } - - #[test] - fn test_owned_unique_when_single_ref() { - let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared(); - let storage = NdArrayStorage::from_owned(array); - - assert!(storage.is_unique()); - assert!(storage.is_owned()); - } - - #[test] - fn test_owned_not_unique_when_cloned() { - let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared(); - let storage = NdArrayStorage::from_owned(array); - let _clone = storage.clone(); - - assert!(!storage.is_unique()); - } - - #[test] - fn test_view_zero_copy() { - let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let bytes = Bytes::from_elems(data); - let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); - - let view = storage.view(); - assert_eq!(view[[0, 0]], 1.0); - assert_eq!(view[[1, 1]], 4.0); - } - - #[test] - fn test_into_owned_copies_borrowed() { - let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let bytes = Bytes::from_elems(data); - let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); - - let owned = storage.into_owned(); - assert_eq!(owned[[0, 0]], 1.0); - assert_eq!(owned[[1, 1]], 4.0); - } - - #[test] - fn test_from_borrowed_validates_alignment() { - use burn_std::AllocationProperty; - - // Test 1: Properly aligned data should succeed - let aligned_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let aligned_bytes = Bytes::from_elems(aligned_data); - - // Verify test setup - should be 4-byte aligned for f32 - assert_eq!( - (aligned_bytes.as_ptr() as usize) % core::mem::align_of::(), - 0, - "Test setup: f32 data should be properly aligned" - ); - - let result = NdArrayStorage::::from_borrowed(aligned_bytes, [2, 2]); - assert!( - result.is_ok(), - "from_borrowed should succeed for properly aligned data" - ); - - // Test 2: Misaligned data should fail - // Create a buffer large enough to find a misaligned offset - // (static data placement varies by platform, so we find an offset dynamically) - let buffer: &[u8] = &[0u8; 32]; - let shared = bytes::Bytes::from_static(buffer); - let base = shared.as_ptr() as usize; - let align = core::mem::align_of::(); - - // Find an offset in 1..align that produces misalignment (at least one must exist) - let misalign_offset = (1..align) - .find(|&off| !(base + off).is_multiple_of(align)) - .expect("Should find a misaligned offset"); - - let sliced = shared.slice(misalign_offset..(misalign_offset + 16)); - let misaligned_bytes = Bytes::from_shared(sliced, AllocationProperty::Other); - - // Verify test setup - should NOT be 4-byte aligned - assert_ne!( - (misaligned_bytes.as_ptr() as usize) % align, - 0, - "Test setup: sliced data should be misaligned for f32" - ); - - let result = NdArrayStorage::::from_borrowed(misaligned_bytes, [4]); - assert!( - result.is_err(), - "from_borrowed should return Err for misaligned data" - ); - } - - #[test] - fn test_insufficient_size_returns_err() { - // Create bytes that are too small for the requested shape - let data: Vec = vec![1.0, 2.0]; // 8 bytes - let bytes = Bytes::from_elems(data); - - // Try 
to create storage for 4 elements (needs 16 bytes) - let result = NdArrayStorage::::from_borrowed(bytes, [4]); - assert!( - result.is_err(), - "from_borrowed should return Err when bytes are too small" - ); - } - - // ========================================================================== - // Zero-copy hardening tests - // These tests verify the zero-copy guarantee is maintained. If any of these - // fail, it indicates a regression in zero-copy functionality. - // ========================================================================== - - #[test] - fn test_zero_copy_native_allocation() { - // CRITICAL: Verify that native allocations (Bytes::from_elems) are zero-copy - // on initial load. The view() must return a pointer to the SAME memory. - // - // Note: Native allocations copy on clone (this is expected), but the initial - // load is still zero-copy, avoiding an extra copy in the common case where - // the tensor is used without cloning. - let data: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let bytes = Bytes::from_elems(data); - let original_ptr = bytes.as_ptr(); - - let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); - - // Initial load must be zero-copy - let view = storage.view(); - let view_ptr = view.as_ptr() as *const u8; - - assert_eq!( - original_ptr, view_ptr, - "ZERO-COPY REGRESSION: native allocation view() must return pointer to original bytes" - ); - - // Verify data integrity - assert_eq!(view[[0, 0]], 1.0); - assert_eq!(view[[0, 1]], 2.0); - assert_eq!(view[[1, 0]], 3.0); - assert_eq!(view[[1, 1]], 4.0); - } - - #[test] - fn test_zero_copy_shared_bytes_pointer_identity() { - // CRITICAL: Test with SharedBytesAllocationController for true zero-copy. - // This simulates the actual burnpack/mmap loading path. - use burn_std::AllocationProperty; - - // Create static-like data using bytes::Bytes - let data: &[u8] = &[ - 0, 0, 128, 63, // 1.0f32 in little-endian - 0, 0, 0, 64, // 2.0f32 - 0, 0, 64, 64, // 3.0f32 - 0, 0, 128, 64, // 4.0f32 - ]; - let shared = bytes::Bytes::from_static(data); - let original_ptr = shared.as_ptr(); - - // Create Bytes with SharedBytesAllocationController - let bytes = Bytes::from_shared(shared, AllocationProperty::Other); - - let storage = NdArrayStorage::::from_borrowed(bytes, [2, 2]).expect("should create"); - - // Verify pointer identity - let view_ptr = storage.view().as_ptr() as *const u8; - assert_eq!( - original_ptr, view_ptr, - "ZERO-COPY REGRESSION: SharedBytes view must point to original static data" - ); - - // Clone should also share the same memory - let cloned = storage.clone(); - let cloned_ptr = cloned.view().as_ptr() as *const u8; - assert_eq!( - original_ptr, cloned_ptr, - "ZERO-COPY REGRESSION: SharedBytes clone must share memory" - ); - } - - #[test] - fn test_clone_borrowed_stays_borrowed() { - // Verify that cloning borrowed storage produces another borrowed storage. - // Note: The underlying Bytes may or may not share memory depending on - // the allocation controller (native allocations copy, file-backed may share). 
-        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
-        let bytes = Bytes::from_elems(data);
-
-        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
-        let cloned = storage.clone();
-
-        // Both should still be borrowed (the storage type is preserved)
-        assert!(
-            storage.is_borrowed(),
-            "ZERO-COPY REGRESSION: original should remain borrowed after clone"
-        );
-        assert!(
-            cloned.is_borrowed(),
-            "ZERO-COPY REGRESSION: clone should be borrowed type"
-        );
-
-        // Both should report not unique (important for COW behavior)
-        assert!(
-            !storage.is_unique(),
-            "ZERO-COPY REGRESSION: original should not be unique after clone"
-        );
-        assert!(
-            !cloned.is_unique(),
-            "ZERO-COPY REGRESSION: clone should not be unique"
-        );
-
-        // Data should be identical
-        assert_eq!(storage.view(), cloned.view(), "Clone should have same data");
-    }
-
-    #[test]
-    fn test_zero_copy_triggers_copy_on_mutation() {
-        // Verify that into_owned() on borrowed data creates a NEW allocation
-        // (this is the "copy" in copy-on-write)
-        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
-        let bytes = Bytes::from_elems(data);
-        let original_ptr = bytes.as_ptr();
-
-        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
-
-        assert!(storage.is_borrowed(), "should start as borrowed");
-
-        let owned = storage.into_owned();
-        let owned_ptr = owned.as_ptr() as *const u8;
-
-        assert_ne!(
-            original_ptr, owned_ptr,
-            "into_owned() on borrowed data MUST allocate new memory (copy-on-write)"
-        );
-    }
-
-    #[test]
-    fn test_borrowed_reports_not_unique() {
-        // CRITICAL: Borrowed storage must report is_unique() == false
-        // This is what triggers copy-on-write in mutation operations
-        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
-        let bytes = Bytes::from_elems(data);
-        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
-
-        assert!(
-            !storage.is_unique(),
-            "ZERO-COPY REGRESSION: borrowed storage MUST report is_unique() == false \
-             to trigger copy-on-write. If this is true, mutations will corrupt shared data!"
-        );
-    }
-}
diff --git a/crates/burn/src/storage.rs b/crates/burn/src/storage.rs
new file mode 120000
index 00000000..8001bfba
--- /dev/null
+++ b/crates/burn/src/storage.rs
@@ -0,0 +1 @@
+../upstream/crates/burn-ndarray/src/storage.rs
\ No newline at end of file
diff --git a/crates/burn/src/tensor.rs b/crates/burn/src/tensor.rs
deleted file mode 100644
index 97699a1f..00000000
--- a/crates/burn/src/tensor.rs
+++ /dev/null
@@ -1,955 +0,0 @@
-use burn_backend::{
-    AllocationProperty, DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata,
-    quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue},
-};
-use burn_std::BoolStore;
-
-use crate::NdArrayStorage;
-use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
-use alloc::vec::Vec;
-use ndarray::{ArcArray, ArrayD, IxDyn};
-
-/// Concrete storage type for ndarray (owned with COW semantics via Arc)
-pub type SharedArray<E> = ArcArray<E, IxDyn>;
-
-/// Tensor primitive used by the [ndarray backend](crate::NdArray).
-///
-/// Supports both owned and borrowed (zero-copy) data via `NdArrayStorage`.
-/// When data is borrowed from external sources (like burnpack files),
-/// it remains zero-copy until a mutating operation is performed.
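To make the copy-on-write contract above concrete, here is a short sketch of the borrowed-to-owned flow, mirroring the storage tests earlier in this diff and assuming only the `NdArrayStorage` and `burn_std::Bytes` APIs they use:

use burn_std::Bytes;

fn cow_flow() {
    // Zero-copy load: the storage borrows the Bytes allocation directly.
    let bytes = Bytes::from_elems(vec![1.0f32, 2.0, 3.0, 4.0]);
    let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2])
        .expect("aligned and large enough for a 2x2 f32 tensor");

    // Borrowed storage always reports non-unique, forcing the copy path
    // for any mutating operation.
    assert!(!storage.is_unique());

    // Reads go through view() and stay zero-copy.
    assert_eq!(storage.view()[[1, 1]], 4.0);

    // The single copy happens here, once ownership is actually needed.
    let owned = storage.into_owned();
    assert_eq!(owned[[0, 0]], 1.0);
}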
-#[derive(Debug, Clone)]
-#[allow(missing_docs)]
-pub enum NdArrayTensor {
-    F64(NdArrayStorage<f64>),
-    F32(NdArrayStorage<f32>),
-    I64(NdArrayStorage<i64>),
-    I32(NdArrayStorage<i32>),
-    I16(NdArrayStorage<i16>),
-    I8(NdArrayStorage<i8>),
-    U64(NdArrayStorage<u64>),
-    U32(NdArrayStorage<u32>),
-    U16(NdArrayStorage<u16>),
-    U8(NdArrayStorage<u8>),
-    Bool(NdArrayStorage<bool>),
-}
-
-impl NdArrayTensor {
-    /// Extract bool array, converting to owned if necessary.
-    pub(crate) fn bool(self) -> SharedArray<bool> {
-        match self {
-            NdArrayTensor::Bool(storage) => storage.into_shared(),
-            _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
-        }
-    }
-
-    /// Returns true if this tensor uses borrowed (zero-copy) storage.
-    #[inline]
-    pub fn is_borrowed(&self) -> bool {
-        macro_rules! check {
-            ($($variant:ident),*) => {
-                match self {
-                    $(NdArrayTensor::$variant(s) => s.is_borrowed(),)*
-                }
-            };
-        }
-        check!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
-    }
-}
-
-pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
-where
-    NdArrayTensor: From<SharedArray<E1>>,
-{
-    fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
-        array.mapv(|a| a.elem()).into_shared()
-    }
-
-    if E1::dtype() == dtype {
-        return array.into();
-    }
-
-    match dtype {
-        DType::F64 => cast::<E1, f64>(array).into(),
-        DType::F32 => cast::<E1, f32>(array).into(),
-        DType::Flex32 => cast::<E1, f32>(array).into(),
-        DType::I64 => cast::<E1, i64>(array).into(),
-        DType::I32 => cast::<E1, i32>(array).into(),
-        DType::I16 => cast::<E1, i16>(array).into(),
-        DType::I8 => cast::<E1, i8>(array).into(),
-        DType::U64 => cast::<E1, u64>(array).into(),
-        DType::U32 => cast::<E1, u32>(array).into(),
-        DType::U16 => cast::<E1, u16>(array).into(),
-        DType::U8 => cast::<E1, u8>(array).into(),
-        DType::Bool(BoolStore::Native) => cast::<E1, bool>(array).into(),
-        dtype => panic!("Unsupported dtype: {dtype:?}"),
-    }
-}
-
-macro_rules! impl_from {
-    ($($ty: ty => $dtype: ident),*) => {
-        // From SharedArray (owned) -> NdArrayTensor
-        $(impl From<SharedArray<$ty>> for NdArrayTensor {
-            fn from(value: SharedArray<$ty>) -> NdArrayTensor {
-                NdArrayTensor::$dtype(NdArrayStorage::from_owned(value))
-            }
-        })*
-
-        // From NdArrayStorage -> NdArrayTensor
-        $(impl From<NdArrayStorage<$ty>> for NdArrayTensor {
-            fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor {
-                NdArrayTensor::$dtype(value)
-            }
-        })*
-    };
-}
-
-impl_from!(
-    f64 => F64, f32 => F32,
-    i64 => I64, i32 => I32, i16 => I16, i8 => I8,
-    u64 => U64, u32 => U32, u16 => U16, u8 => U8,
-    bool => Bool
-);
-
-/// Macro to execute an operation on a given element type.
-///
-/// Extracts the storage from NdArrayTensor, converts to SharedArray, and passes to operation.
-///
-/// # Panics
-/// Since there is no automatic type cast at this time, binary operations for different
-/// floating point precision data types will panic with a data type mismatch.
-#[macro_export]
-macro_rules! execute_with_dtype {
-    (($lhs:expr, $rhs:expr),$element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
-        let lhs_dtype = burn_backend::TensorMetadata::dtype(&$lhs);
-        let rhs_dtype = burn_backend::TensorMetadata::dtype(&$rhs);
-        match ($lhs, $rhs) {
-            $(
-                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
-                    #[allow(unused)]
-                    type $element = $ty;
-                    // Convert storage to SharedArray for compatibility with existing operations
-                    $op(lhs.into_shared(), rhs.into_shared()).into()
-                }
-            )*
-            _ => panic!(
-                "Data type mismatch (lhs: {:?}, rhs: {:?})",
-                lhs_dtype, rhs_dtype
-            ),
-        }
-    }};
-    // Binary op: type automatically inferred by the compiler
-    (($lhs:expr, $rhs:expr), $op:expr) => {{
-        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
-    }};
-
-    // Binary op: generic type cannot be inferred for an operation
-    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
-        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
-            F64 => f64, F32 => f32,
-            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
-            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
-            Bool => bool
-        ])
-    }};
-
-    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
-        match $tensor {
-            $(
-                $crate::NdArrayTensor::$dtype(storage) => {
-                    #[allow(unused)]
-                    type $element = $ty;
-                    // Convert to SharedArray for compatibility with most operations
-                    $op(storage.into_shared()).into()
-                }
-            )*
-            #[allow(unreachable_patterns)]
-            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
-        }
-    }};
-    // Unary op: type automatically inferred by the compiler
-    ($tensor:expr, $op:expr) => {{
-        $crate::execute_with_dtype!($tensor, E, $op)
-    }};
-
-    // Unary op: generic type cannot be inferred for an operation
-    ($tensor:expr, $element:ident, $op:expr) => {{
-        $crate::execute_with_dtype!($tensor, $element, $op, [
-            F64 => f64, F32 => f32,
-            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
-            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
-            Bool => bool
-        ])
-    }};
-}
-
-/// Macro to execute an operation on a given element type.
-/// Only handles float types.
-///
-/// # Panics
-/// Since there is no automatic type cast at this time, binary operations for different
-/// floating point precision data types will panic with a data type mismatch.
-#[macro_export]
-macro_rules! execute_with_float_dtype {
-    // Binary op: type automatically inferred by the compiler
-    (($lhs:expr, $rhs:expr), $op:expr) => {{
-        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
-    }};
-
-    // Binary op: generic type cannot be inferred for an operation
-    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
-        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
-            F64 => f64, F32 => f32
-        ])
-    }};
-
-    // Unary op: type automatically inferred by the compiler
-    ($tensor:expr, $op:expr) => {{
-        $crate::execute_with_float_dtype!($tensor, E, $op)
-    }};
-
-    // Unary op: generic type cannot be inferred for an operation
-    ($tensor:expr, $element:ident, $op:expr) => {{
-        $crate::execute_with_dtype!($tensor, $element, $op, [
-            F64 => f64, F32 => f32
-        ])
-    }};
-}
-
-/// Macro to execute an operation on a given element type.
-/// Only handles int types.
-///
-/// # Panics
-/// Since there is no automatic type cast at this time, binary operations for different
-/// integer precision data types will panic with a data type mismatch.
-#[macro_export]
-macro_rules! execute_with_int_dtype {
-    // Binary op: type automatically inferred by the compiler
-    (($lhs:expr, $rhs:expr), $op:expr) => {{
-        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
-    }};
-
-    // Binary op: generic type cannot be inferred for an operation
-    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
-        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
-            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
-            U64 => u64, U32 => u32, U16 => u16, U8 => u8
-        ])
-    }};
-
-    // Unary op: type automatically inferred by the compiler
-    ($tensor:expr, $op:expr) => {{
-        $crate::execute_with_int_dtype!($tensor, E, $op)
-    }};
-
-    // Unary op: generic type cannot be inferred for an operation
-    ($tensor:expr, $element:ident, $op:expr) => {{
-        $crate::execute_with_dtype!($tensor, $element, $op, [
-            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
-            U64 => u64, U32 => u32, U16 => u16, U8 => u8
-        ])
-    }};
-}
-
-/// Macro to execute an operation on a given element type.
-/// Only handles numeric types.
-///
-/// # Panics
-/// Since there is no automatic type cast at this time, binary operations for different
-/// precision data types will panic with a data type mismatch.
-#[macro_export]
-macro_rules! execute_with_numeric_dtype {
-    // Binary op: type automatically inferred by the compiler
-    (($lhs:expr, $rhs:expr), $op:expr) => {{
-        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
-    }};
-
-    // Binary op: generic type cannot be inferred for an operation
-    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
-        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
-            F64 => f64, F32 => f32,
-            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
-            U64 => u64, U32 => u32, U16 => u16, U8 => u8
-        ])
-    }};
-
-    // Unary op: type automatically inferred by the compiler
-    ($tensor:expr, $op:expr) => {{
-        $crate::execute_with_numeric_dtype!($tensor, E, $op)
-    }};
-
-    // Unary op: generic type cannot be inferred for an operation
-    ($tensor:expr, $element:ident, $op:expr) => {{
-        $crate::execute_with_dtype!($tensor, $element, $op, [
-            F64 => f64, F32 => f32,
-            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
-            U64 => u64, U32 => u32, U16 => u16, U8 => u8
-        ])
-    }};
-}
-
-/// Macro to execute a cat operation on a given set of element types.
-///
-/// Uses zero-copy views from storage for concatenation.
-///
-/// # Panics
-/// Since there is no automatic type cast at this time, concatenating tensors with
-/// different data types will panic with a data type mismatch.
-#[macro_export]
-macro_rules! cat_with_dtype {
-    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
-        match &$tensors[0] {
-            $(NdArrayTensor::$dtype(_) => {
-                let tensors = $tensors
-                    .iter()
-                    .map(|t| {
-                        if let NdArrayTensor::$dtype(storage) = t {
-                            // Use storage.view() for zero-copy access
-                            storage.view()
-                        } else {
-                            panic!("Concatenate data type mismatch (expected {:?}, got {:?})", $tensors[0].dtype(), t.dtype())
-                        }
-                    })
-                    .collect::<Vec<_>>();
-                NdArrayOps::concatenate(&tensors, $dim).into()
-            })*
-            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
-        }
-    };
-}
-
-/// Macro to execute an operation that returns a given element type.
-#[macro_export]
-macro_rules! execute_with_float_out_dtype {
-    ($out_dtype:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
-        match $out_dtype {
-            $(
-                burn_std::FloatDType::$dtype => {
-                    #[allow(unused)]
-                    type $element = $ty;
-                    $op
-                }
-            )*
-            #[allow(unreachable_patterns)]
-            other => unimplemented!("unsupported dtype: {other:?}")
-        }
-    }};
-    // Unary op: type automatically inferred by the compiler
-    ($out_dtype:expr, $op:expr) => {{
-        $crate::execute_with_float_out_dtype!($out_dtype, E, $op)
-    }};
-
-    // Unary op: generic type cannot be inferred for an operation
-    ($out_dtype:expr, $element:ident, $op:expr) => {{
-        $crate::execute_with_float_out_dtype!($out_dtype, $element, $op, [
-            F64 => f64, F32 => f32
-        ])
-    }};
-}
-
-/// Macro to execute an operation that returns a given element type.
-#[macro_export]
-macro_rules! execute_with_int_out_dtype {
-    ($out_dtype:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
-        match $out_dtype {
-            $(
-                burn_std::IntDType::$dtype => {
-                    #[allow(unused)]
-                    type $element = $ty;
-                    $op
-                }
-            )*
-            #[allow(unreachable_patterns)]
-            other => unimplemented!("unsupported dtype: {other:?}")
-        }
-    }};
-    // Unary op: type automatically inferred by the compiler
-    ($out_dtype:expr, $op:expr) => {{
-        $crate::execute_with_int_out_dtype!($out_dtype, E, $op)
-    }};
-
-    // Unary op: generic type cannot be inferred for an operation
-    ($out_dtype:expr, $element:ident, $op:expr) => {{
-        $crate::execute_with_int_out_dtype!($out_dtype, $element, $op, [
-            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
-            U64 => u64, U32 => u32, U16 => u16, U8 => u8
-        ])
-    }};
-}
-
-impl TensorMetadata for NdArrayTensor {
-    fn dtype(&self) -> DType {
-        match self {
-            NdArrayTensor::F64(_) => DType::F64,
-            NdArrayTensor::F32(_) => DType::F32,
-            NdArrayTensor::I64(_) => DType::I64,
-            NdArrayTensor::I32(_) => DType::I32,
-            NdArrayTensor::I16(_) => DType::I16,
-            NdArrayTensor::I8(_) => DType::I8,
-            NdArrayTensor::U64(_) => DType::U64,
-            NdArrayTensor::U32(_) => DType::U32,
-            NdArrayTensor::U16(_) => DType::U16,
-            NdArrayTensor::U8(_) => DType::U8,
-            NdArrayTensor::Bool(_) => DType::Bool(BoolStore::Native),
-        }
-    }
-
-    fn shape(&self) -> Shape {
-        // Use storage's shape method (works for both borrowed and owned)
-        macro_rules! get_shape {
-            ($($variant:ident),*) => {
-                match self {
-                    $(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)*
-                }
-            };
-        }
-        get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
-    }
-
-    fn rank(&self) -> usize {
-        self.shape().num_dims()
-    }
-}
-
-pub(crate) trait ShapeOps {
-    fn num_dims(self) -> usize;
-    fn num_elements(self) -> usize;
-    fn dims<const N: usize>(self) -> [usize; N];
-    fn into_shape(self) -> Shape;
-}
-
-impl ShapeOps for &[usize] {
-    fn num_dims(self) -> usize {
-        self.len()
-    }
-
-    fn num_elements(self) -> usize {
-        self.iter().product()
-    }
-
-    fn dims<const N: usize>(self) -> [usize; N] {
-        self.try_into().unwrap()
-    }
-
-    fn into_shape(self) -> Shape {
-        Shape::from(self)
-    }
-}
-
-mod utils {
-    use burn_std::tensor::is_contiguous;
-
-    use super::*;
-
-    impl NdArrayTensor {
-        pub(crate) fn into_data(self) -> TensorData {
-            let shape = self.shape();
-            let contiguous = self.is_contiguous();
-
-            fn inner<E: Element>(
-                shape: Shape,
-                is_contiguous: bool,
-                array: ArcArray<E, IxDyn>,
-            ) -> TensorData {
-                let vec = if is_contiguous {
-                    match array.try_into_owned_nocopy() {
-                        Ok(owned) => {
-                            let (mut vec, offset) = owned.into_raw_vec_and_offset();
-                            if let Some(offset) = offset {
-                                vec.drain(..offset);
-                            }
-                            if vec.len() > shape.num_elements() {
-                                vec.drain(shape.num_elements()..vec.len());
-                            }
-                            vec
-                        }
-                        Err(array) => array.into_iter().collect(),
-                    }
-                } else {
-                    array.into_iter().collect()
-                };
-
-                TensorData::new(vec, shape)
-            }
-
-            // Convert storage to owned array before extracting data
-            execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
-        }
-
-        pub(crate) fn is_contiguous(&self) -> bool {
-            // For borrowed data, we assume it's contiguous (it came from TensorData which is contiguous)
-            // For owned data, we check the strides
-            macro_rules! check_contiguous {
-                ($($variant:ident),*) => {
-                    match self {
-                        $(NdArrayTensor::$variant(storage) => {
-                            match storage {
-                                NdArrayStorage::Borrowed { .. } => {
-                                    // Borrowed storage requires contiguous row-major data
-                                    // (see NdArrayStorage::from_borrowed documentation)
-                                    true
-                                }
-                                NdArrayStorage::Owned(array) => {
-                                    let shape = array.shape();
-                                    let mut strides = Vec::with_capacity(array.strides().len());
-                                    for &stride in array.strides() {
-                                        if stride <= 0 {
-                                            return false;
-                                        }
-                                        strides.push(stride as usize);
-                                    }
-                                    is_contiguous(shape, &strides)
-                                }
-                            }
-                        })*
-                    }
-                };
-            }
-            check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
-        }
-    }
-}
-
-/// Converts a slice of usize to a typed dimension.
-#[macro_export(local_inner_macros)]
-macro_rules! to_typed_dims {
-    (
-        $n:expr,
-        $dims:expr,
-        justdim
-    ) => {{
-        let mut dims = [0; $n];
-        for i in 0..$n {
-            dims[i] = $dims[i];
-        }
-        let dim: Dim<[usize; $n]> = Dim(dims);
-        dim
-    }};
-}
-
-/// Reshapes an array into a tensor.
-#[macro_export(local_inner_macros)]
-macro_rules! reshape {
-    (
-        ty $ty:ty,
-        n $n:expr,
-        shape $shape:expr,
-        array $array:expr
-    ) => {{
-        let dim = $crate::to_typed_dims!($n, $shape, justdim);
-        let array = match $array.is_standard_layout() {
-            true => {
-                match $array.to_shape(dim) {
-                    Ok(val) => val.into_shared(),
-                    Err(err) => {
-                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
-                    }
-                }
-            },
-            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
-        };
-        array.into_dyn()
-    }};
-    (
-        ty $ty:ty,
-        shape $shape:expr,
-        array $array:expr,
-        d $D:expr
-    ) => {{
-        match $D {
-            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
-            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
-            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
-            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
-            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
-            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
-            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
-        }
-    }};
-}
-
-/// Slice a tensor.
-#[macro_export]
-macro_rules! slice {
-    ($tensor:expr, $slices:expr) => {
-        slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
-    };
-    ($tensor:expr, $slices:expr, $($variant:ident),*) => {
-        match $tensor {
-            $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })*
-        }
-    };
-}
-
-impl NdArrayTensor {
-    /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
-    ///
-    /// This method attempts zero-copy loading when possible. If the data has properly
-    /// aligned bytes that can be borrowed, it creates a borrowed tensor. Otherwise,
-    /// it falls back to copying the data.
-    ///
-    /// Zero-copy loading works when:
-    /// - The data's bytes are properly aligned for the element type
-    /// - The bytes can be borrowed (e.g., from an mmap'd file or static data)
-    pub fn from_data(data: TensorData) -> NdArrayTensor {
-        // Only use Borrowed storage for non-native allocations (e.g., burnpack mmap/file).
-        // For native Rust heap allocations (the common case), go directly to owned storage:
-        // `from_data_owned` reclaims the Vec zero-copy via `into_vec`, while
-        // Borrowed storage would trigger a full memcopy on every single operation.
-        if data.bytes.property() != AllocationProperty::Native {
-            match Self::try_from_data_borrowed(data) {
-                Ok(tensor) => return tensor,
-                Err(data) => return Self::from_data_owned(data),
-            }
-        }
-        Self::from_data_owned(data)
-    }
-
-    /// Try to create a tensor with borrowed storage (zero-copy).
-    ///
-    /// Takes ownership of TensorData and returns it back on failure.
-    /// No cloning occurs - bytes are moved into storage or returned on failure.
-    ///
-    /// Returns `Err(data)` if borrowing is not possible (e.g., misaligned data).
-    fn try_from_data_borrowed(data: TensorData) -> Result<NdArrayTensor, TensorData> {
-        let TensorData {
-            bytes,
-            shape,
-            dtype,
-        } = data;
-
-        macro_rules! try_borrow {
-            ($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => {
-                match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) {
-                    Ok(storage) => return Ok(NdArrayTensor::$variant(storage)),
-                    Err((bytes, shape)) => (bytes, shape),
-                }
-            };
-        }
-
-        // Try to create borrowed storage; get bytes back on failure
-        let (bytes, shape) = match dtype {
-            DType::F64 => try_borrow!(f64, F64, bytes, shape),
-            DType::F32 => try_borrow!(f32, F32, bytes, shape),
-            DType::I64 => try_borrow!(i64, I64, bytes, shape),
-            DType::I32 => try_borrow!(i32, I32, bytes, shape),
-            DType::I16 => try_borrow!(i16, I16, bytes, shape),
-            DType::I8 => try_borrow!(i8, I8, bytes, shape),
-            DType::U64 => try_borrow!(u64, U64, bytes, shape),
-            DType::U32 => try_borrow!(u32, U32, bytes, shape),
-            DType::U16 => try_borrow!(u16, U16, bytes, shape),
-            DType::U8 => try_borrow!(u8, U8, bytes, shape),
-            DType::Bool(BoolStore::Native) => try_borrow!(bool, Bool, bytes, shape),
-            _ => (bytes, shape), // QFloat not supported for zero-copy
-        };
-
-        Err(TensorData {
-            bytes,
-            shape,
-            dtype,
-        })
-    }
-
-    /// Create a tensor with owned storage.
-    ///
-    /// This may or may not copy data depending on whether the underlying bytes
-    /// can be reclaimed (via `try_into_vec`). If bytes are uniquely owned,
-    /// no copy occurs; otherwise data is copied to a new allocation.
-    fn from_data_owned(data: TensorData) -> NdArrayTensor {
-        let shape = data.shape.to_vec(); // TODO: into_vec
-
-        macro_rules! execute {
-            ($data: expr, [$($dtype: pat => $ty: ty),*]) => {
-                match $data.dtype {
-                    $( $dtype => {
-                        match data.into_vec::<$ty>() {
-                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
-                            Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
-                        }.into()
-                    }, )*
-                    other => unimplemented!("Unsupported dtype {other:?}"),
-                }
-            };
-        }
-
-        execute!(data, [
-            DType::F64 => f64, DType::F32 => f32,
-            DType::I64 => i64, DType::I32 => i32, DType::I16 => i16, DType::I8 => i8,
-            DType::U64 => u64, DType::U32 => u32, DType::U16 => u16, DType::U8 => u8,
-            DType::Bool(BoolStore::Native) => bool
-        ])
-    }
-}
-
-/// A quantized tensor for the ndarray backend.
-#[derive(Clone, Debug)]
-pub struct NdArrayQTensor {
-    /// The quantized tensor.
-    pub qtensor: NdArrayTensor,
-    /// The quantization scheme.
-    pub scheme: QuantScheme,
-    /// The quantization parameters.
-    pub qparams: Vec<QParams<f32>>,
-}
-
-impl NdArrayQTensor {
-    /// Returns the quantization strategy, including quantization parameters, for the given tensor.
-    pub fn strategy(&self) -> QuantizationStrategy {
-        match self.scheme {
-            QuantScheme {
-                level: QuantLevel::Tensor,
-                mode: QuantMode::Symmetric,
-                value:
-                    QuantValue::Q8F
-                    | QuantValue::Q8S
-                    | QuantValue::E4M3
-                    | QuantValue::E5M2
-                    | QuantValue::Q4F
-                    | QuantValue::Q4S
-                    | QuantValue::E2M1
-                    | QuantValue::Q2F
-                    | QuantValue::Q2S,
-                ..
-            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
-                self.qparams[0].scales,
-                self.scheme.value,
-            )),
-            QuantScheme {
-                level: QuantLevel::Block(block_size),
-                mode: QuantMode::Symmetric,
-                value:
-                    QuantValue::Q8F
-                    | QuantValue::Q8S
-                    | QuantValue::E4M3
-                    | QuantValue::E5M2
-                    | QuantValue::Q4F
-                    | QuantValue::Q4S
-                    | QuantValue::E2M1
-                    | QuantValue::Q2F
-                    | QuantValue::Q2S,
-                ..
-            } => QuantizationStrategy::PerBlockSymmetric(
-                self.qparams
-                    .iter()
-                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
-                    .collect(),
-                block_size,
-            ),
-        }
-    }
-}
-
-impl QTensorPrimitive for NdArrayQTensor {
-    fn scheme(&self) -> &QuantScheme {
-        &self.scheme
-    }
-
-    fn default_scheme() -> QuantScheme {
-        QuantScheme::default().with_store(burn_backend::quantization::QuantStore::Native)
-    }
-}
-
-impl TensorMetadata for NdArrayQTensor {
-    fn dtype(&self) -> DType {
-        DType::QFloat(self.scheme)
-    }
-
-    fn shape(&self) -> Shape {
-        self.qtensor.shape()
-    }
-
-    fn rank(&self) -> usize {
-        self.shape().num_dims()
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::NdArray;
-    use alloc::vec;
-
-    use super::*;
-    use burn_backend::{
-        Distribution,
-        ops::{FloatTensorOps, QTensorOps},
-        quantization::{QuantStore, QuantizationParametersPrimitive},
-    };
-    use burn_std::rand::get_seeded_rng;
-
-    #[test]
-    fn should_support_into_and_from_data_1d() {
-        let data_expected = TensorData::random::<f32>(
-            Shape::new([3]),
-            Distribution::Default,
-            &mut get_seeded_rng(),
-        );
-        let tensor = NdArrayTensor::from_data(data_expected.clone());
-
-        let data_actual = tensor.into_data();
-
-        assert_eq!(data_expected, data_actual);
-    }
-
-    #[test]
-    fn should_support_into_and_from_data_2d() {
-        let data_expected = TensorData::random::<f32>(
-            Shape::new([2, 3]),
-            Distribution::Default,
-            &mut get_seeded_rng(),
-        );
-        let tensor = NdArrayTensor::from_data(data_expected.clone());
-
-        let data_actual = tensor.into_data();
-
-        assert_eq!(data_expected, data_actual);
-    }
-
-    #[test]
-    fn should_support_into_and_from_data_3d() {
-        let data_expected = TensorData::random::<f32>(
-            Shape::new([2, 3, 4]),
-            Distribution::Default,
-            &mut get_seeded_rng(),
-        );
-        let tensor = NdArrayTensor::from_data(data_expected.clone());
-
-        let data_actual = tensor.into_data();
-
-        assert_eq!(data_expected, data_actual);
-    }
-
-    #[test]
-    fn should_support_into_and_from_data_4d() {
-        let data_expected = TensorData::random::<f32>(
-            Shape::new([2, 3, 4, 2]),
-            Distribution::Default,
-            &mut get_seeded_rng(),
-        );
-        let tensor = NdArrayTensor::from_data(data_expected.clone());
-
-        let data_actual = tensor.into_data();
-
-        assert_eq!(data_expected, data_actual);
-    }
-
-    #[test]
-    fn should_support_qtensor_strategy() {
-        type B = NdArray;
-        let scale: f32 = 0.009_019_608;
-        let device = Default::default();
-
-        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
-        let scheme = QuantScheme::default()
-            .with_value(QuantValue::Q8S)
-            .with_store(QuantStore::Native);
-        let qparams = QuantizationParametersPrimitive {
-            scales: B::float_from_data(TensorData::from([scale]), &device),
-        };
-        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);
-
-        assert_eq!(qtensor.scheme(), &scheme);
-        assert_eq!(
-            qtensor.strategy(),
-            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
-                scale,
-                QuantValue::Q8S
-            ))
-        );
-    }
-
-    // ==========================================================================
-    // Zero-copy integration tests
-    // These tests verify end-to-end zero-copy behavior through NdArrayTensor.
-    // ==========================================================================
-
-    #[test]
-    fn zero_copy_creates_borrowed_storage_for_non_native() {
-        // Verify that from_data creates borrowed storage for non-native allocations
-        // (e.g. burnpack mmap/file data tagged with AllocationProperty::Other or File).
-        // Native heap allocations intentionally use Owned storage for performance.
-        use burn_backend::AllocationProperty;
-        use burn_std::Bytes;
-
-        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
-        let bytes = Bytes::from_elems(data);
-        // Tag as Other to simulate burnpack / mmap data (non-native backing storage)
-        let non_native_bytes = Bytes::from_shared(
-            bytes::Bytes::copy_from_slice(&bytes),
-            AllocationProperty::Other,
-        );
-        let tensor_data = TensorData::from_bytes(non_native_bytes, Shape::new([2, 2]), DType::F32);
-
-        let tensor = NdArrayTensor::from_data(tensor_data);
-
-        match &tensor {
-            NdArrayTensor::F32(storage) => {
-                assert!(
-                    storage.is_borrowed(),
-                    "ZERO-COPY REGRESSION: from_data should create borrowed storage \
-                     for non-native (e.g. burnpack) TensorData"
-                );
-                assert!(
-                    !storage.is_unique(),
-                    "ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false"
-                );
-            }
-            _ => panic!("Expected F32 tensor"),
-        }
-    }
-
-    #[test]
-    fn native_alloc_creates_owned_storage() {
-        // Native heap allocations must use Owned storage to avoid the memcpy.
-        use burn_std::Bytes;
-
-        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
-        let bytes = Bytes::from_elems(data); // AllocationProperty::Native
-        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);
-
-        let tensor = NdArrayTensor::from_data(tensor_data);
-
-        match &tensor {
-            NdArrayTensor::F32(storage) => {
-                assert!(
-                    !storage.is_borrowed(),
-                    "PERF REGRESSION: from_data must NOT create borrowed storage \
-                     for native TensorData"
-                );
-            }
-            _ => panic!("Expected F32 tensor"),
-        }
-    }
-
-    #[test]
-    fn zero_copy_data_integrity() {
-        // Verify data is correctly accessible through the storage view
-        use burn_std::Bytes;
-
-        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
-        let bytes = Bytes::from_elems(data);
-        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);
-
-        let tensor = NdArrayTensor::from_data(tensor_data);
-
-        match &tensor {
-            NdArrayTensor::F32(storage) => {
-                let view = storage.view();
-                assert_eq!(view[[0, 0]], 1.0);
-                assert_eq!(view[[0, 1]], 2.0);
-                assert_eq!(view[[1, 0]], 3.0);
-                assert_eq!(view[[1, 1]], 4.0);
-            }
-            _ => panic!("Expected F32 tensor"),
-        }
-    }
-
-    #[test]
-    fn zero_copy_fallback_when_bytes_owned() {
-        // When TensorData owns bytes exclusively, it may use the copy path.
-        // This is expected behavior - verify it still works correctly.
-        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
-        let tensor = NdArrayTensor::from_data(data.clone());
-        let result = tensor.into_data();
-
-        assert_eq!(data, result, "Data should round-trip correctly");
-    }
-}
diff --git a/crates/burn/src/tensor.rs b/crates/burn/src/tensor.rs
new file mode 120000
index 00000000..8b223a80
--- /dev/null
+++ b/crates/burn/src/tensor.rs
@@ -0,0 +1 @@
+../upstream/crates/burn-ndarray/src/tensor.rs
\ No newline at end of file
diff --git a/crates/burn/upstream b/crates/burn/upstream
new file mode 160000
index 00000000..ed72d2b1
--- /dev/null
+++ b/crates/burn/upstream
@@ -0,0 +1 @@
+Subproject commit ed72d2b125a364aff18aed2a53396c128e01cb42
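
Reviewer notes. The sketches below are illustrative only; they are not part of the applied diff, and any type, function, or constant not named in the patch is a hypothetical stand-in.

The storage tests above pivot on one invariant: borrowed storage always reports is_unique() == false, so every mutation takes the copy path. A minimal, self-contained sketch of that copy-on-write gate, assuming a toy Storage enum (not the crate's NdArrayStorage, whose real implementation lives in the upstream submodule):

    // Minimal COW sketch (illustrative only; not the NdArrayStorage API).
    use std::sync::Arc;

    enum Storage {
        Borrowed(Arc<Vec<f32>>), // may alias an mmap or another tensor
        Owned(Vec<f32>),
    }

    impl Storage {
        fn is_unique(&self) -> bool {
            match self {
                // Borrowed data is never unique: mutation must go through copy.
                Storage::Borrowed(_) => false,
                Storage::Owned(_) => true,
            }
        }

        // The "copy" in copy-on-write: borrowed data is cloned into a fresh
        // allocation; owned data is handed out for in-place mutation.
        fn make_mut(&mut self) -> &mut Vec<f32> {
            if let Storage::Borrowed(shared) = self {
                let copied = shared.as_ref().clone();
                *self = Storage::Owned(copied);
            }
            match self {
                Storage::Owned(vec) => vec,
                Storage::Borrowed(_) => unreachable!(),
            }
        }
    }

    fn main() {
        let shared = Arc::new(vec![1.0, 2.0]);
        let mut s = Storage::Borrowed(Arc::clone(&shared));
        assert!(!s.is_unique());
        s.make_mut()[0] = 9.0;      // triggers the copy
        assert_eq!(shared[0], 1.0); // original bytes untouched
    }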
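
The execute_with_dtype! family relies on a macro trick worth calling out: each match arm re-binds a local type alias before expanding the caller's operation, so one untyped enum dispatches into fully typed code. A cut-down, standalone version of the pattern (the toy Tensor enum and execute! macro here are hypothetical):

    // Standalone sketch of the dispatch pattern behind execute_with_dtype!:
    // match once on the dtype enum, bind a concrete type alias, and expand
    // the same operation body for every variant.
    enum Tensor {
        F32(Vec<f32>),
        I32(Vec<i32>),
    }

    macro_rules! execute {
        ($tensor:expr, $element:ident, $op:expr) => {
            match $tensor {
                Tensor::F32(data) => {
                    #[allow(unused)]
                    type $element = f32; // `$element` names f32 inside `$op`
                    $op(data)
                }
                Tensor::I32(data) => {
                    #[allow(unused)]
                    type $element = i32; // same body, now typed as i32
                    $op(data)
                }
            }
        };
    }

    fn main() {
        let t = Tensor::F32(vec![1.0, 2.0, 3.0]);
        // `E` resolves to the per-arm alias, so one closure serves all dtypes.
        let len = execute!(t, E, |data: Vec<E>| data.len());
        assert_eq!(len, 3);
    }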
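
try_from_data_borrowed hands the TensorData back instead of borrowing whenever the bytes cannot be reinterpreted in place. The precondition is the usual one for viewing raw bytes as f32: enough bytes for the element count, and a pointer aligned for the element type. A sketch of that check, assuming a hypothetical try_view_as_f32 helper (the real check lives inside NdArrayStorage::from_borrowed):

    // Sketch of the size/alignment precondition behind the borrow path
    // (hypothetical helper; not the crate's code).
    fn try_view_as_f32(bytes: &[u8], len: usize) -> Option<&[f32]> {
        // Enough bytes for `len` elements?
        if bytes.len() < len * std::mem::size_of::<f32>() {
            return None;
        }
        // Pointer aligned for f32? Misaligned input forces the copy fallback.
        if bytes.as_ptr().align_offset(std::mem::align_of::<f32>()) != 0 {
            return None;
        }
        // SAFETY: size and alignment checked above; f32 has no invalid bit patterns.
        Some(unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, len) })
    }

    fn main() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes = unsafe {
            std::slice::from_raw_parts(data.as_ptr() as *const u8, 4 * std::mem::size_of::<f32>())
        };
        let view = try_view_as_f32(bytes, 4).expect("Vec<f32> backing is aligned");
        assert_eq!(view[0], 1.0);
        // A buffer offset by one byte would typically fail the align_offset
        // check and fall back to the copying path instead.
    }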
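
For context on the strategy() match: per-tensor symmetric quantization keeps the zero-point fixed at 0 and encodes each value as round(x / scale) in i8, so round-trip error is bounded by half a quantization step. A small worked sketch of that arithmetic (illustrative only; SymmetricQuantization in ops/quantization.rs is the real implementation):

    // Per-tensor symmetric Q8S arithmetic (illustrative, not the crate's code).
    fn quantize_q8s(values: &[f32], scale: f32) -> Vec<i8> {
        values
            .iter()
            .map(|&x| (x / scale).round().clamp(-128.0, 127.0) as i8)
            .collect()
    }

    fn dequantize_q8s(q: &[i8], scale: f32) -> Vec<f32> {
        q.iter().map(|&v| v as f32 * scale).collect()
    }

    fn main() {
        let input = [-1.8f32, -1.0, 0.0, 0.5];
        // Symmetric scheme: scale from the max magnitude, zero-point is 0.
        let max_abs = input.iter().fold(0.0f32, |m, &x| m.max(x.abs()));
        let scale = max_abs / 127.0;
        let q = quantize_q8s(&input, scale);
        let back = dequantize_q8s(&q, scale);
        for (a, b) in input.iter().zip(back.iter()) {
            // Error is at most half a step (plus float rounding slack).
            assert!((a - b).abs() <= scale / 2.0 + f32::EPSILON);
        }
    }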