Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 175 additions & 1 deletion crates/burn/src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,176 @@ pub fn clear_attention_cache() {
}

// ============================================================================
// VNNI u8 MatVec fast path — 64 MACs per instruction
// Compiled Linear Cache — O(k) replacing O(n_rows) for any weight matrix
// ============================================================================
//
// For any linear layer y = W @ x, where W is [n_rows, n_cols]:
// 1. Each row of W is assigned to one of 256 palette centroids (u8 index)
// 2. At inference: compute k=256 centroid dot products with input x
// 3. For each output row i: y[i] = centroid_outputs[assignment[i]]
//
// Cost: 256 × n_cols MACs + n_rows lookups (vs n_rows × n_cols MACs)
// For gate_proj [3072, 1024]: 256K MACs vs 3.1M MACs = 12× fewer.
//
// Keyed by (n_rows, n_cols) — the weight matrix shape.

/// A compiled linear layer: a small palette of centroid rows stands in for
/// the full `[n_rows, n_cols]` weight matrix.
///
/// Each output row of the original matrix is approximated by one of `k`
/// shared centroid vectors; `assignments[i]` names the centroid for row `i`.
#[cfg(feature = "std")]
#[derive(Clone)]
pub struct CompiledLinear {
    /// Centroid weight vectors: [k × n_cols] f32, row-major
    /// (centroid `c` occupies `centroids[c * n_cols .. (c + 1) * n_cols]`).
    pub centroids: Vec<f32>,
    /// Number of centroids (palette size, typically 256; the `u8`
    /// assignments below cap this at 256 addressable centroids).
    pub k: usize,
    /// Input dimension (n_cols of the original weight matrix).
    pub n_cols: usize,
    /// Output dimension (n_rows of the original weight matrix).
    pub n_rows: usize,
    /// Row assignment: for each of the n_rows output rows, which centroid it maps to.
    pub assignments: Vec<u8>,
}

/// Global cache of compiled linear layers.
/// Entries are stored in registration order and looked up by the original
/// weight-matrix shape `(n_rows, n_cols)`. Lookup returns a clone of the
/// FIRST entry whose shape matches and never removes it (see
/// `pop_compiled_linear`), so distinct layers sharing one shape all resolve
/// to the earliest registration.
#[cfg(feature = "std")]
static LINEAR_CACHE: LazyLock<RwLock<Vec<CompiledLinear>>> =
    LazyLock::new(|| RwLock::new(Vec::new()));

/// Add a compiled linear layer to the global cache.
///
/// Entries are kept in registration order; lookups match on the original
/// weight-matrix shape `(n_rows, n_cols)`.
#[cfg(feature = "std")]
pub fn register_compiled_linear(compiled: CompiledLinear) {
    LINEAR_CACHE.write().unwrap().push(compiled);
}

/// Look up a compiled linear layer for the given weight-matrix shape.
///
/// Despite the `pop_` name, this does NOT remove anything: it returns a
/// clone of the FIRST registered entry whose `(n_rows, n_cols)` matches,
/// so the same entry is reused across batches. Returns `None` if no
/// matching table exists.
///
/// NOTE(review): if several distinct layers share one shape (e.g. two
/// projections with identical dimensions), they all resolve to the first
/// registration — the "consumed in registration order" FIFO behavior this
/// function's name suggests is not implemented. Confirm this is intended.
#[cfg(feature = "std")]
fn pop_compiled_linear(n_rows: usize, n_cols: usize) -> Option<CompiledLinear> {
    let cache = LINEAR_CACHE.read().unwrap();
    // Find first matching entry (don't pop — layers may be reused across batches)
    cache.iter().find(|c| c.n_rows == n_rows && c.n_cols == n_cols).cloned()
}

/// Try to compute y = W @ x using compiled centroid matmul with VNNI acceleration.
///
/// Instead of n_rows × n_cols MACs per output column:
/// 1. Quantize centroids to u8 (global affine), input column to i8 (per-column affine)
/// 2. VNNI dot: k centroid × input dots at 64 MACs/instruction
/// 3. Dequantize i32 accumulators back to f32 via the two scale factors
/// 4. Broadcast via palette assignment: out[i] = centroid_out[assignment[i]]
///
/// Dequantization deliberately drops the zero-point cross terms (linear
/// approximation, accepted for inference — see comment at the dot loop).
///
/// Returns true if the compiled path was used; false falls through to BLAS.
#[cfg(feature = "std")]
fn try_compiled_linear<E: NdArrayElement>(
    _lhs: &ndarray::ArrayView2<'_, E>,
    _rhs: &ndarray::ArrayView2<'_, E>,
    out: &mut ndarray::ArrayViewMut2<'_, E>,
    m: usize,
    k_dim: usize,
    n: usize,
) -> bool {
    let compiled = match pop_compiled_linear(m, k_dim) {
        Some(c) => c,
        None => return false,
    };

    let k = compiled.k;
    let dim = compiled.n_cols.min(k_dim);

    // Reject malformed tables up front instead of panicking on the slices
    // below (the original only checked assignments and k, so a short
    // `centroids` buffer would panic at `centroids[..k * dim]`).
    if compiled.assignments.len() < m || k == 0 || compiled.centroids.len() < k * dim {
        return false;
    }

    // Pre-quantize centroids: f32 → u8 [0, 255] (done once, amortized across columns).
    // Uniform affine quantization over the global centroid min/max.
    let mut c_min = f32::MAX;
    let mut c_max = f32::MIN;
    for v in &compiled.centroids[..k * dim] {
        if *v < c_min { c_min = *v; }
        if *v > c_max { c_max = *v; }
    }
    let c_range = (c_max - c_min).max(1e-10);
    let c_scale = c_range / 255.0;

    let centroids_u8: Vec<u8> = compiled.centroids[..k * dim].iter()
        .map(|&v| (((v - c_min) / c_range) * 255.0).round().clamp(0.0, 255.0) as u8)
        .collect();

    // Select VNNI dot function (same tiered dispatch as build_distance_table_vnni)
    let dot_fn: fn(&[u8], &[i8]) -> i32 = {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512vnni") {
                |a, b| {
                    // SAFETY: avx512vnni confirmed
                    unsafe { ndarray::simd_amx::vnni_dot_u8_i8(a, b) }
                }
            } else {
                ndarray::simd_amx::vnni_dot_u8_i8_scalar
            }
        }
        #[cfg(not(target_arch = "x86_64"))]
        { ndarray::simd_amx::vnni_dot_u8_i8_scalar }
    };

    // Scratch buffers hoisted out of the column loop (were re-allocated per j).
    let mut col_f32 = vec![0.0f32; dim];
    let mut col_i8 = vec![0i8; dim];
    let mut centroid_out = vec![0.0f64; k];

    for j in 0..n {
        // Extract input column j.
        for d in 0..dim {
            col_f32[d] = _rhs[[d, j]].elem::<f64>() as f32;
        }
        // Per-column min/max for affine quantization.
        let mut x_min = f32::MAX;
        let mut x_max = f32::MIN;
        for &v in &col_f32[..] {
            if v < x_min { x_min = v; }
            if v > x_max { x_max = v; }
        }
        let x_range = (x_max - x_min).max(1e-10);
        let x_scale = x_range / 255.0;

        // Quantize to i8 in [-128, 127]. BUG FIX: the previous code mapped
        // into [0, 255] and bit-reinterpreted (`as u8 as i8`), wrapping all
        // values above 127 to negative — the VNNI u8×i8 dot treats this
        // operand as SIGNED, so half the input range produced sign-flipped
        // products. Shifting by 128 keeps the full range representable at
        // the same scale.
        for (q, &v) in col_i8.iter_mut().zip(col_f32.iter()) {
            let u = (((v - x_min) / x_range) * 255.0).round().clamp(0.0, 255.0);
            *q = (u - 128.0) as i8;
        }

        // VNNI dot: k centroid dots at 64 MACs/instruction.
        for c in 0..k {
            let c_row = &centroids_u8[c * dim..(c + 1) * dim];
            let raw_dot = dot_fn(c_row, &col_i8);

            // Dequantize: raw_dot was computed on quantized values.
            // Approximate: result ≈ c_scale × x_scale × raw_dot; the bias
            // from the zero-point offsets (sum(c_u8) × x_zero + sum(x_i8)
            // × c_zero + …) is dropped for speed — a linear approximation
            // accepted for inference.
            centroid_out[c] = raw_dot as f64 * c_scale as f64 * x_scale as f64;
        }

        // Broadcast via palette assignment. `min(k - 1)` guards against a
        // stale assignment index pointing past the centroid table.
        for i in 0..m {
            let c_idx = compiled.assignments[i] as usize;
            out[[i, j]] = centroid_out[c_idx.min(k - 1)].elem();
        }
    }

    true
}

/// Number of compiled linear layers currently registered in the cache.
#[cfg(feature = "std")]
pub fn compiled_linear_count() -> usize {
    let cache = LINEAR_CACHE.read().unwrap();
    cache.len()
}

/// Remove every registered compiled linear layer from the cache.
#[cfg(feature = "std")]
pub fn clear_compiled_linear_cache() {
    let mut cache = LINEAR_CACHE.write().unwrap();
    cache.clear();
}
//
// For quantized u8×i8 matmul (codebook distance table build):
// Input A: [m, k] u8 (codebook rows, quantized)
// Input B: [k, n] i8 (codebook cols, quantized)
Expand Down Expand Up @@ -355,6 +522,13 @@ pub(crate) fn matmul<E: NdArrayElement>(
.get()
.slice_mut(s!(out_batch, .., ..));

// Try compiled linear (centroid matmul, O(256) per column).
// Falls through to BLAS if no compiled layer matches.
#[cfg(feature = "std")]
if try_compiled_linear(&lhs_slice, &rhs_slice, &mut out_slice, m, k, n) {
return;
}

// Try compiled attention table (O(1) per element).
// Falls through to BLAS if no table is registered for d_head=k.
#[cfg(feature = "std")]
Expand Down
2 changes: 1 addition & 1 deletion crates/burn/upstream
Loading