diff --git a/.claude/AMX_GOTCHAS.md b/.claude/AMX_GOTCHAS.md new file mode 100644 index 00000000..3e8dfe48 --- /dev/null +++ b/.claude/AMX_GOTCHAS.md @@ -0,0 +1,214 @@ +# AMX Gotchas — Resolved on Stable Rust 1.94 + +> Updated: 2026-04-03 +> CPU: Sapphire Rapids (AMX-TILE + AMX-INT8 + AMX-BF16 confirmed) +> Kernel: 6.18.5 (XCR0 bits 17+18 enabled) + +--- + +## Status + +AMX works on **stable Rust 1.94** via `asm!()`. No nightly needed. + +``` +LDTILECFG: ✓ (load tile configuration) +TILEZERO: ✓ (zero a tile register) +TILERELEASE: ✓ (release tiles) +TDPBUSD: ✓ (u8×i8 tile dot product, 256 MACs/instruction) +``` + +--- + +## Gotcha 1: Rust intrinsics are NIGHTLY ONLY + +```rust +// This DOES NOT compile on stable: +use std::arch::x86_64::_tile_loadconfig; // error: unstable feature x86_amx_intrinsics +``` + +**Fix**: Use `asm!()` (stable since Rust 1.59): +```rust +asm!("ldtilecfg [{}]", in(reg) config.data.as_ptr(), options(nostack)); +``` + +Tracking issue: https://github.com/rust-lang/rust/issues/126622 + +--- + +## Gotcha 2: Tile config MUST be 64-byte aligned + +```rust +// This SEGFAULTS: +let config = [0u8; 64]; // stack-allocated, no alignment guarantee + +// This WORKS: +#[repr(C, align(64))] +struct TileConfig { data: [u8; 64] } +let config = TileConfig { data: [0u8; 64] }; +``` + +LDTILECFG reads 64 bytes from the pointer. If not 64-byte aligned, +the CPU raises #GP (general protection fault) → SIGSEGV. + +--- + +## Gotcha 3: rbx is LLVM-reserved + +```rust +// This DOES NOT compile: +asm!("cpuid", out("ebx") ebx, ...); // error: rbx is used internally by LLVM + +// This WORKS: +let result = core::arch::x86_64::__cpuid_count(7, 0); // stable, handles rbx internally +``` + +For CPUID leaf 7 (AMX detection): use `__cpuid_count()`, not inline asm. + +--- + +## Gotcha 4: OS must enable AMX via XSETBV + +AMX tiles are large (8 KB of state). The OS must opt in via XCR0 bits 17+18. +Linux 5.19+ enables AMX by default. Older kernels: SIGILL on tile instructions. 
+ +**Detection (stable)**: +```rust +let xcr0 = core::arch::x86_64::__cpuid_count(0xD, 0); +let tilecfg = (xcr0.eax >> 17) & 1; // bit 17 = XTILECFG +let tiledata = (xcr0.eax >> 18) & 1; // bit 18 = XTILEDATA +// Both must be 1. CPUID leaf 0xD only reports which XCR0 bits the CPU supports; read XCR0 with XGETBV (ECX=0) to confirm the OS enabled them. +``` + +--- + +## Gotcha 5: TILEZERO/TILERELEASE need manual byte encoding + +The Rust assembler on some toolchains doesn't know AMX mnemonics. +Use raw instruction bytes: + +```rust +// TILEZERO tmm0 +asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc0", options(nostack, nomem)); + +// TILEZERO tmm1 +asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc8", options(nostack, nomem)); + +// TILEZERO tmm2 +asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xd0", options(nostack, nomem)); + +// TILEZERO tmm3 +asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xd8", options(nostack, nomem)); + +// TILERELEASE +asm!(".byte 0xc4, 0xe2, 0x78, 0x49, 0xc0", options(nostack, nomem)); + +// TDPBUSD tmm0, tmm1, tmm2 (C += A × B) +asm!(".byte 0xc4, 0xe2, 0x69, 0x5e, 0xc1", options(nostack, nomem)); +``` + +Note: LDTILECFG works as a mnemonic: +```rust +asm!("ldtilecfg [{}]", in(reg) ptr, options(nostack)); +``` + +--- + +## Gotcha 6: Tile config field layout is not obvious + +The 64-byte tile config structure: +``` +Byte 0: palette (must be 1) +Bytes 1-15: start_row at byte 1, rest reserved (zero) +Bytes 16-31: colbytes per tile (tile 0 at [16..17] as u16 LE, tile 1 at [18..19], ...) +Bytes 32-47: reserved (zero) +Bytes 48-63: rows per tile (tile 0 at byte 48, tile 1 at byte 49, ...; bytes 56-63 reserved) +``` + +For TDPBUSD (u8×i8 → i32): +- Tile 0 (C result): rows=16, colbytes=64 (16 × i32 = 64 bytes per row) +- Tile 1 (A input): rows=16, colbytes=64 (16 × 64 u8) +- Tile 2 (B input): rows=16, colbytes=64 (transposed for column access) + +**IMPORTANT**: colbytes is a u16 at byte offset 16+2*tile_id (little-endian). +For values ≤ 64, only the low byte matters. Rows is a single u8 at byte offset 48+tile_id. + +--- + +## Gotcha 7: TILEZERO with wrong config = SEGFAULT + +If you configure tile 0 as 16 rows × 64 colbytes but then TILEZERO tmm0, +it works. 
But if the config doesn't match what the hardware expects (e.g., +palette=0 or all zeros), TILEZERO will SEGFAULT. + +**Fix**: Always start with the minimal working config: +```rust +cfg.data[0] = 1; // palette 1 (MUST be 1, not 0) +cfg.data[16] = 4; // tile 0 colbytes (u16 LE low byte): at least 4 (1 × i32) +cfg.data[48] = 1; // tile 0 rows: at least 1 +``` + +Then expand to full 16×64 after verifying the minimal config works. + +--- + +## Gotcha 8: is_x86_feature_detected!("amx-tile") is NIGHTLY ONLY + +```rust +// DOES NOT compile on stable: +is_x86_feature_detected!("amx-tile") // error: unstable x86_amx_intrinsics + +// WORKS on stable: +fn amx_available() -> bool { + let cpuid = core::arch::x86_64::__cpuid_count(7, 0); + let amx_tile = (cpuid.edx >> 24) & 1; + let amx_int8 = (cpuid.edx >> 25) & 1; + amx_tile == 1 && amx_int8 == 1 +} +``` + +Use `__cpuid_count` (stable) for detection, not `is_x86_feature_detected!`. + +--- + +## Hardware Tiers (this session) + +``` +Tier Feature MACs/instr Detection (stable) CPU +──── ─────── ────────── ────────────────── ─── +3 AMX 256 __cpuid_count(7,0).edx bit 24 Sapphire Rapids+ +2 avx512vnni 64 is_x86_feature_detected! Cascade Lake+, Zen 4+ +1 avxvnniint8 32 is_x86_feature_detected! 
Arrow Lake (NUC 14) +0 scalar 1 always any +``` + +Also detectable but not yet kernelized: +- `avxvnniint16`: i16×i16 dot product (VPDPWSSD) +- `amx-bf16`: TDPBF16PS (BF16 tile matmul, for calibration) + +--- + +## Files + +``` +ndarray/src/simd_amx.rs — AMX detection + VNNI/VNNI2 kernels + quantize +ndarray/src/hpc/amx_matmul.rs — AMX tile ops via inline asm (TDPBUSD) +ndarray/crates/burn/src/ops/matmul.rs — 4-tier dispatch in distance table builder +``` + +--- + +## What AMX Enables + +``` +Distance table build (4096² = 16M dot products): + AMX: ~20 min (all models combined) + avx512vnni: ~1:20h + avxvnniint8: ~2:40h (NUC 14) + scalar: ~24-48h + +ThinkingEngine MatVec (per cycle): + AMX: ~44 μs (L1 table fits in 4 tile registers) + avx512vnni: ~175 μs + avxvnniint8: ~350 μs + scalar: ~5 ms +``` diff --git a/crates/burn/src/ops/matmul.rs b/crates/burn/src/ops/matmul.rs index 60065dd3..4dabc3cc 100644 --- a/crates/burn/src/ops/matmul.rs +++ b/crates/burn/src/ops/matmul.rs @@ -125,16 +125,69 @@ pub fn try_vnni_matmul_u8( false } -/// Build a k×k distance table from k centroids using VNNI if available. +/// Build a k×k COSINE SIMILARITY table from f32 centroids. +/// +/// Takes raw f32 centroids, normalizes to unit vectors, quantizes, +/// runs tiered VNNI/AMX dot product, maps to u8 [0, 255]. +/// +/// This IS the ThinkingEngine's brain. cosine[-1,1] → u8[0,255]. +/// 128 = orthogonal. 255 = identical. 0 = opposite. 
+/// +/// centroids_f32: [k × dim] raw f32 centroids (row-major) +/// Returns: [k × k] u8 cosine similarity table +#[cfg(feature = "std")] +pub fn build_cosine_table(centroids_f32: &[f32], k: usize, dim: usize) -> Vec<u8> { + assert_eq!(centroids_f32.len(), k * dim); + + // Step 1: Normalize each centroid to unit vector + let mut normed = vec![0.0f32; k * dim]; + for i in 0..k { + let row = &centroids_f32[i * dim..(i + 1) * dim]; + let norm: f32 = row.iter().map(|v| v * v).sum::<f32>().sqrt(); + let inv_norm = if norm > 1e-10 { 1.0 / norm } else { 0.0 }; + for d in 0..dim { + normed[i * dim + d] = row[d] * inv_norm; + } + } + + // Step 2: Quantize normalized [-1, 1] → u8 [0, 255] + // After normalization, values are in [-1, 1]. + // Map: u8 = round((value + 1.0) * 127.5) + let centroids_u8: Vec<u8> = normed.iter() + .map(|&v| ((v + 1.0) * 127.5).round().clamp(0.0, 255.0) as u8) + .collect(); + + // Step 3: Compute dot products using tiered VNNI dispatch + let raw_dots = build_distance_table_vnni(&centroids_u8, k, dim); + + // Step 4: Map i32 dot products → u8 cosine similarity [0, 255] + // The dot product of two unit vectors quantized to u8 [0,255]: + // max dot (identical) = sum of (u8_i)² over dim + // min dot (opposite) = much lower + // Find actual min/max to scale properly + let min_dot = raw_dots.iter().copied().min().unwrap_or(0) as f64; + let max_dot = raw_dots.iter().copied().max().unwrap_or(1) as f64; + let range = (max_dot - min_dot).max(1.0); + + let mut table = vec![128u8; k * k]; // 128 = default orthogonal + for i in 0..k { + for j in 0..k { + let raw = raw_dots[i * k + j] as f64; + let normalized = (raw - min_dot) / range; // [0, 1] + table[i * k + j] = (normalized * 255.0).round().clamp(0.0, 255.0) as u8; + } + } + + table +} + +/// Build a k×k RAW DOT PRODUCT table from u8 centroids using VNNI if available. 
/// /// centroids_u8: [k × dim] quantized codebook centroids (u8, row-major) /// Returns: [k × k] i32 dot product matrix (symmetric) /// -/// Uses VNNI dot product (64 MACs/instruction) for each centroid pair. -/// Symmetric: only computes upper triangle, mirrors to lower. -/// -/// This IS the ThinkingEngine's brain construction step. -/// 4096² = 16M dot products. With VNNI: ~1:20h for large dim. +/// For cosine: use build_cosine_table() which normalizes first. +/// This function is for raw dot products when centroids are already u8. #[cfg(feature = "std")] pub fn build_distance_table_vnni(centroids_u8: &[u8], k: usize, dim: usize) -> Vec<i32> { assert_eq!(centroids_u8.len(), k * dim); diff --git a/src/lib.rs b/src/lib.rs index e1ea8ddc..d98558be 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -232,7 +232,7 @@ mod dimension; /// Portable SIMD types — `crate::simd::f32x16` today, `std::simd::f32x16` tomorrow. #[cfg(feature = "std")] #[allow(missing_docs)] -pub(crate) mod simd; +pub mod simd; #[cfg(all(feature = "std", target_arch = "x86_64"))] #[allow(missing_docs, dead_code)] pub(crate) mod simd_avx512;