diff --git a/src/backend/mkl.rs b/src/backend/mkl.rs
index 6805db73..f8424850 100644
--- a/src/backend/mkl.rs
+++ b/src/backend/mkl.rs
@@ -7,11 +7,18 @@ #![allow(non_snake_case)]
+use crate::{ArrayView2, ArrayViewMut2, BF16};
 use std::os::raw::{c_double, c_float, c_int, c_long, c_void};
 
 const CBLAS_ROW_MAJOR: c_int = 101;
 const CBLAS_NO_TRANS: c_int = 111;
+const CBLAS_TRANS: c_int = 112;
+
+// `cblas_gemm_s8u8s32` takes a CBLAS_OFFSET enum as its `offsetc` argument
+// (offset mode): `RowOffset = 171`, `ColOffset = 172`, `FixOffset = 173`.
+// We always use `FixOffset` with zero offsets, which matches Burn /
+// rustyblas behaviour (a plain matmul with no zero-point correction).
+const CBLAS_OFFSET_FIX: c_int = 173;
+
 // ═══════════════════════════════════════════════════════════════
 // CBLAS (shared API surface with OpenBLAS)
 // ═══════════════════════════════════════════════════════════════
@@ -56,6 +63,32 @@ extern "C" {
         x: *const c_double, incx: c_int,
         beta: c_double, y: *mut c_double, incy: c_int,
     );
+
+    // Mixed-precision GEMM: BF16 inputs, F32 accumulator.
+    // MKL takes `*const u16` for BF16 operands (there is no native bf16
+    // type in the C ABI).
+    // Reference: oneAPI MKL Developer Reference, "cblas_gemm_bf16bf16f32".
+    fn cblas_gemm_bf16bf16f32(
+        layout: c_int, transa: c_int, transb: c_int,
+        m: c_int, n: c_int, k: c_int,
+        alpha: c_float, a: *const u16, lda: c_int,
+        b: *const u16, ldb: c_int,
+        beta: c_float, c: *mut c_float, ldc: c_int,
+    );
+
+    // Integer GEMM: i8 × u8 → i32. MKL only ships this mixed-sign variant;
+    // there is no `cblas_gemm_s8s8s32` in the oneAPI MKL CBLAS interface.
+    // The trailing offset arguments take a CBLAS_OFFSET mode, per-operand
+    // zero points `oa` / `ob`, and a pointer to the C offset(s). Passing
+    // zeros everywhere gives a plain matmul without zero-point correction.
+    // Reference: oneAPI MKL Developer Reference, "cblas_gemm_s8u8s32".
+    fn cblas_gemm_s8u8s32(
+        layout: c_int, transa: c_int, transb: c_int, offsetc: c_int,
+        m: c_int, n: c_int, k: c_int,
+        alpha: c_float,
+        a: *const i8, lda: c_int, oa: i8,
+        b: *const u8, ldb: c_int, ob: i8,
+        beta: c_float, c: *mut i32, ldc: c_int,
+        co: *const i32,
+    );
 }
 
 // ═══════════════════════════════════════════════════════════════
@@ -235,3 +268,237 @@ pub const fn sgemm_nr() -> usize { 16 }
 pub const fn sgemm_mr() -> usize { 6 }
 pub const fn dgemm_nr() -> usize { 8 }
 pub const fn dgemm_mr() -> usize { 6 }
+
+// ═══════════════════════════════════════════════════════════════
+// Public ndarray-shaped GEMM API (Burn integration surface)
+// ═══════════════════════════════════════════════════════════════
+//
+// These wrappers accept `ArrayView2` / `ArrayViewMut2` and forward to the
+// CBLAS FFI declared above. They detect row-major / column-major layout
+// from ndarray strides and return a structured error if an operand has no
+// unit-stride dimension (a layout CBLAS cannot express).
+
+/// Errors returned from the MKL ndarray-shaped GEMM wrappers.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum MklError {
+    /// Inner dimensions of A and B don't match (`A.cols != B.rows`).
+    ShapeMismatch {
+        a_shape: (usize, usize),
+        b_shape: (usize, usize),
+    },
+    /// Output `C` dimensions don't match `(A.rows, B.cols)`.
+    OutputShapeMismatch {
+        expected: (usize, usize),
+        got: (usize, usize),
+    },
+    /// One of the arrays is not stride-compatible with CBLAS.
+    ///
+    /// CBLAS requires that one of the two strides is `1` (the contiguous
+    /// dimension) and that the other is at least the extent of the
+    /// contiguous dimension. Arbitrary striding (e.g. from a non-contiguous
+    /// slice) is not supported; copy to a contiguous buffer first.
+    NonContiguous { which: &'static str },
+    /// The bf16 / int8 routines are not available in this MKL build, or are
+    /// stubbed out (older MKL versions predate these `cblas_gemm_*` kernels).
+    Unsupported(&'static str),
+}
+
+impl core::fmt::Display for MklError {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        match self {
+            MklError::ShapeMismatch { a_shape, b_shape } => write!(
+                f,
+                "MKL GEMM shape mismatch: A is {:?}, B is {:?}",
+                a_shape, b_shape
+            ),
+            MklError::OutputShapeMismatch { expected, got } => write!(
+                f,
+                "MKL GEMM output shape mismatch: expected {:?}, got {:?}",
+                expected, got
+            ),
+            MklError::NonContiguous { which } => write!(
+                f,
+                "MKL GEMM operand `{}` is not stride-compatible with CBLAS",
+                which
+            ),
+            MklError::Unsupported(msg) => write!(f, "MKL GEMM unsupported: {}", msg),
+        }
+    }
+}
+
+impl std::error::Error for MklError {}
+
+/// CBLAS layout descriptor extracted from an ndarray view.
+struct BlasLayout {
+    layout: c_int,
+    trans: c_int,
+    ld: c_int,
+}
+
+/// Inspect a 2D shape and its strides and produce a CBLAS
+/// layout/transpose/leading-dimension triple. Returns `None` if neither
+/// dimension has stride 1 (i.e. the matrix is non-contiguous in both
+/// directions). Taking the raw shape and stride slice keeps this helper
+/// agnostic over `ArrayView2` vs `ArrayViewMut2`.
+fn blas_layout(dim: (usize, usize), strides: &[isize]) -> Option<BlasLayout> {
+    let (rows, cols) = dim;
+    let rs = strides[0];
+    let cs = strides[1];
+    // Row-major: stride between rows is the leading dim, columns are
+    // stride 1. For degenerate shapes (rows <= 1) the row stride can be
+    // arbitrary, so clamp the leading dimension to at least the row length,
+    // as CBLAS requires.
+    if cs == 1 && (rs >= cols as isize || rows <= 1) {
+        return Some(BlasLayout {
+            layout: CBLAS_ROW_MAJOR,
+            trans: CBLAS_NO_TRANS,
+            ld: rs.max(cols as isize).max(1) as c_int,
+        });
+    }
+    // Column-major: stride between cols is the leading dim, rows are
+    // stride 1. We expose this to CBLAS as a *row-major transposed* matrix
+    // so we keep a single `layout` argument across all three operands.
+    if rs == 1 && (cs >= rows as isize || cols <= 1) {
+        return Some(BlasLayout {
+            layout: CBLAS_ROW_MAJOR,
+            trans: CBLAS_TRANS,
+            ld: cs.max(rows as isize).max(1) as c_int,
+        });
+    }
+    None
+}
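+
+// Illustrative sanity checks for `blas_layout`, a minimal sketch: because
+// the helper takes a plain `(rows, cols)` shape and a stride slice, the
+// stride arithmetic can be exercised without constructing array views.
+#[cfg(test)]
+mod layout_tests {
+    use super::*;
+
+    // A contiguous row-major 2x3 matrix has strides [3, 1]: unit column
+    // stride, row stride 3 (the leading dimension).
+    #[test]
+    fn detects_row_major() {
+        let l = blas_layout((2, 3), &[3, 1]).unwrap();
+        assert_eq!((l.layout, l.trans, l.ld), (CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, 3));
+    }
+
+    // The transposed view of the same buffer (3x2, strides [1, 3]) is
+    // column-major and is reported to CBLAS as row-major + transposed.
+    #[test]
+    fn detects_column_major() {
+        let l = blas_layout((3, 2), &[1, 3]).unwrap();
+        assert_eq!((l.layout, l.trans, l.ld), (CBLAS_ROW_MAJOR, CBLAS_TRANS, 3));
+    }
+
+    // A view sliced with step 2 in both dimensions has no unit stride,
+    // which CBLAS cannot express.
+    #[test]
+    fn rejects_doubly_strided() {
+        assert!(blas_layout((2, 2), &[8, 2]).is_none());
+    }
+}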
+/// `C := alpha * A * B + beta * C` for `f32` matrices via MKL `cblas_sgemm`.
+pub fn sgemm(
+    a: ArrayView2<'_, f32>,
+    b: ArrayView2<'_, f32>,
+    mut c: ArrayViewMut2<'_, f32>,
+    alpha: f32,
+    beta: f32,
+) -> Result<(), MklError> {
+    let (m, k) = a.dim();
+    let (kb, n) = b.dim();
+    if k != kb {
+        return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() });
+    }
+    if c.dim() != (m, n) {
+        return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() });
+    }
+    let la = blas_layout(a.dim(), a.strides()).ok_or(MklError::NonContiguous { which: "a" })?;
+    let lb = blas_layout(b.dim(), b.strides()).ok_or(MklError::NonContiguous { which: "b" })?;
+    let lc = blas_layout(c.dim(), c.strides()).ok_or(MklError::NonContiguous { which: "c" })?;
+    // CBLAS has no transpose flag for C: with a single row-major layout
+    // argument, a column-major C cannot be expressed.
+    if lc.trans != CBLAS_NO_TRANS {
+        return Err(MklError::NonContiguous { which: "c" });
+    }
+    unsafe {
+        cblas_sgemm(
+            lc.layout, la.trans, lb.trans,
+            m as c_int, n as c_int, k as c_int,
+            alpha, a.as_ptr(), la.ld,
+            b.as_ptr(), lb.ld,
+            beta, c.as_mut_ptr(), lc.ld,
+        );
+    }
+    Ok(())
+}
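+
+// Illustrative smoke test for `sgemm`, a minimal sketch: it assumes
+// `crate::ArrayView2` / `ArrayViewMut2` expose ndarray-style `from_shape`
+// constructors, and it only passes where the test binary actually links
+// against MKL.
+#[cfg(test)]
+mod sgemm_tests {
+    use super::*;
+
+    #[test]
+    fn multiplies_2x2() {
+        // [1 2; 3 4] * [5 6; 7 8] = [19 22; 43 50]
+        let a = [1.0_f32, 2.0, 3.0, 4.0];
+        let b = [5.0_f32, 6.0, 7.0, 8.0];
+        let mut c = [0.0_f32; 4];
+        let av = ArrayView2::from_shape((2, 2), &a).unwrap();
+        let bv = ArrayView2::from_shape((2, 2), &b).unwrap();
+        let cv = ArrayViewMut2::from_shape((2, 2), &mut c).unwrap();
+        sgemm(av, bv, cv, 1.0, 0.0).unwrap();
+        assert_eq!(c, [19.0, 22.0, 43.0, 50.0]);
+    }
+}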
+/// `C := alpha * A * B + beta * C` for `f64` matrices via MKL `cblas_dgemm`.
+pub fn dgemm(
+    a: ArrayView2<'_, f64>,
+    b: ArrayView2<'_, f64>,
+    mut c: ArrayViewMut2<'_, f64>,
+    alpha: f64,
+    beta: f64,
+) -> Result<(), MklError> {
+    let (m, k) = a.dim();
+    let (kb, n) = b.dim();
+    if k != kb {
+        return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() });
+    }
+    if c.dim() != (m, n) {
+        return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() });
+    }
+    let la = blas_layout(a.dim(), a.strides()).ok_or(MklError::NonContiguous { which: "a" })?;
+    let lb = blas_layout(b.dim(), b.strides()).ok_or(MklError::NonContiguous { which: "b" })?;
+    let lc = blas_layout(c.dim(), c.strides()).ok_or(MklError::NonContiguous { which: "c" })?;
+    if lc.trans != CBLAS_NO_TRANS {
+        return Err(MklError::NonContiguous { which: "c" });
+    }
+    unsafe {
+        cblas_dgemm(
+            lc.layout, la.trans, lb.trans,
+            m as c_int, n as c_int, k as c_int,
+            alpha, a.as_ptr(), la.ld,
+            b.as_ptr(), lb.ld,
+            beta, c.as_mut_ptr(), lc.ld,
+        );
+    }
+    Ok(())
+}
+
+/// `C := alpha * A * B + beta * C` with BF16 inputs and an `f32`
+/// accumulator, via MKL `cblas_gemm_bf16bf16f32`.
+///
+/// This requires Intel MKL >= 2020 (which introduced the bf16 GEMM kernel).
+/// On older MKL builds the symbol is simply absent, so the binary fails to
+/// link (or, when linking dynamically, fails to load); there is no
+/// compile-time fallback.
+pub fn sgemm_bf16(
+    a: ArrayView2<'_, BF16>,
+    b: ArrayView2<'_, BF16>,
+    mut c: ArrayViewMut2<'_, f32>,
+    alpha: f32,
+    beta: f32,
+) -> Result<(), MklError> {
+    let (m, k) = a.dim();
+    let (kb, n) = b.dim();
+    if k != kb {
+        return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() });
+    }
+    if c.dim() != (m, n) {
+        return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() });
+    }
+    let la = blas_layout(a.dim(), a.strides()).ok_or(MklError::NonContiguous { which: "a" })?;
+    let lb = blas_layout(b.dim(), b.strides()).ok_or(MklError::NonContiguous { which: "b" })?;
+    let lc = blas_layout(c.dim(), c.strides()).ok_or(MklError::NonContiguous { which: "c" })?;
+    if lc.trans != CBLAS_NO_TRANS {
+        return Err(MklError::NonContiguous { which: "c" });
+    }
+    // `BF16` is `#[repr(transparent)] (pub u16)`, so the pointer cast is
+    // sound.
+    unsafe {
+        cblas_gemm_bf16bf16f32(
+            lc.layout, la.trans, lb.trans,
+            m as c_int, n as c_int, k as c_int,
+            alpha,
+            a.as_ptr() as *const u16, la.ld,
+            b.as_ptr() as *const u16, lb.ld,
+            beta, c.as_mut_ptr(), lc.ld,
+        );
+    }
+    Ok(())
+}
+
+/// `C := A * B` with `i8` × `u8` inputs and an `i32` accumulator, via MKL
+/// `cblas_gemm_s8u8s32` with zero offsets (no zero-point correction).
+///
+/// MKL's integer GEMM is mixed-sign by design: A is signed, B is unsigned
+/// (there is no signed-signed `s8s8s32` routine). Alpha/beta are fixed at
+/// `1.0` / `0.0` for the simple `Burn`-style signature; if you need scaling,
+/// call the FFI directly. Requires Intel MKL >= 2018, which introduced
+/// integer GEMM.
+pub fn sgemm_int8(
+    a: ArrayView2<'_, i8>,
+    b: ArrayView2<'_, u8>,
+    mut c: ArrayViewMut2<'_, i32>,
+) -> Result<(), MklError> {
+    let (m, k) = a.dim();
+    let (kb, n) = b.dim();
+    if k != kb {
+        return Err(MklError::ShapeMismatch { a_shape: a.dim(), b_shape: b.dim() });
+    }
+    if c.dim() != (m, n) {
+        return Err(MklError::OutputShapeMismatch { expected: (m, n), got: c.dim() });
+    }
+    let la = blas_layout(a.dim(), a.strides()).ok_or(MklError::NonContiguous { which: "a" })?;
+    let lb = blas_layout(b.dim(), b.strides()).ok_or(MklError::NonContiguous { which: "b" })?;
+    let lc = blas_layout(c.dim(), c.strides()).ok_or(MklError::NonContiguous { which: "c" })?;
+    if lc.trans != CBLAS_NO_TRANS {
+        return Err(MklError::NonContiguous { which: "c" });
+    }
+    // With `FixOffset`, `co` points at a single offset value for all of C.
+    let co: i32 = 0;
+    unsafe {
+        cblas_gemm_s8u8s32(
+            lc.layout, la.trans, lb.trans, CBLAS_OFFSET_FIX,
+            m as c_int, n as c_int, k as c_int,
+            1.0_f32,
+            a.as_ptr(), la.ld, 0_i8,
+            b.as_ptr(), lb.ld, 0_i8,
+            0.0_f32, c.as_mut_ptr(), lc.ld,
+            &co as *const i32,
+        );
+    }
+    Ok(())
+}
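+
+// Illustrative smoke tests for the mixed-precision wrappers, a sketch under
+// the same assumptions as `sgemm_tests`, plus one more: `BF16` is
+// constructible from its raw bits as `BF16(bits)`, per the transparent
+// `(pub u16)` comment in `sgemm_bf16`.
+#[cfg(test)]
+mod mixed_precision_tests {
+    use super::*;
+
+    // bf16 keeps the upper 16 bits of the f32 bit layout; truncation is
+    // exact for small integers, which is all these tests need.
+    fn to_bf16(x: f32) -> BF16 {
+        BF16((x.to_bits() >> 16) as u16)
+    }
+
+    #[test]
+    fn bf16_times_identity_is_exact() {
+        let a: Vec<BF16> = [1.0_f32, 2.0, 3.0, 4.0].iter().map(|&x| to_bf16(x)).collect();
+        let id: Vec<BF16> = [1.0_f32, 0.0, 0.0, 1.0].iter().map(|&x| to_bf16(x)).collect();
+        let mut c = [0.0_f32; 4];
+        let av = ArrayView2::from_shape((2, 2), &a).unwrap();
+        let bv = ArrayView2::from_shape((2, 2), &id).unwrap();
+        let cv = ArrayViewMut2::from_shape((2, 2), &mut c).unwrap();
+        sgemm_bf16(av, bv, cv, 1.0, 0.0).unwrap();
+        assert_eq!(c, [1.0, 2.0, 3.0, 4.0]);
+    }
+
+    #[test]
+    fn int8_matmul_accumulates_in_i32() {
+        // [1 -2; 3 4] (i8) * [5 6; 7 8] (u8) = [-9 -10; 43 50] (i32)
+        let a = [1_i8, -2, 3, 4];
+        let b = [5_u8, 6, 7, 8];
+        let mut c = [0_i32; 4];
+        let av = ArrayView2::from_shape((2, 2), &a).unwrap();
+        let bv = ArrayView2::from_shape((2, 2), &b).unwrap();
+        let cv = ArrayViewMut2::from_shape((2, 2), &mut c).unwrap();
+        sgemm_int8(av, bv, cv).unwrap();
+        assert_eq!(c, [-9, -10, 43, 50]);
+    }
+}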
diff --git a/src/backend/mod.rs b/src/backend/mod.rs
index d1d78c0e..52e4cf96 100644
--- a/src/backend/mod.rs
+++ b/src/backend/mod.rs
@@ -16,7 +16,7 @@ pub(crate) mod kernels_avx512;
 
 #[cfg(feature = "intel-mkl")]
-mod mkl;
+pub mod mkl;
 
 #[cfg(feature = "openblas")]
 mod openblas;
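
Note: the `src/backend/mod.rs` change (`mod mkl` to `pub mod mkl`) is what exposes the wrappers to the Burn integration layer; with the `intel-mkl` feature enabled they become reachable as `crate::backend::mkl::{sgemm, dgemm, sgemm_bf16, sgemm_int8}`. A minimal call site might look like this (illustrative sketch; assumes an ndarray-style owned array type with `view()` / `view_mut()`, error handling elided):

    use crate::backend::mkl;

    // C := 1.0 * A * B + 0.0 * C, i.e. a plain f32 matmul.
    mkl::sgemm(a.view(), b.view(), c.view_mut(), 1.0, 0.0)?;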