Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions .claude/AMX_GOTCHAS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# AMX Gotchas — Resolved on Stable Rust 1.94

> Updated: 2026-04-03
> CPU: Sapphire Rapids (AMX-TILE + AMX-INT8 + AMX-BF16 confirmed)
> Kernel: 6.18.5 (XCR0 bits 17+18 enabled)

---

## Status

AMX works on **stable Rust 1.94** via `asm!()`. No nightly needed.

```
LDTILECFG: ✓ (load tile configuration)
TILEZERO: ✓ (zero a tile register)
TILERELEASE: ✓ (release tiles)
TDPBUSD: ✓ (u8×i8 tile dot product, 256 MACs/instruction)
```

---

## Gotcha 1: Rust intrinsics are NIGHTLY ONLY

```rust
// This DOES NOT compile on stable:
use std::arch::x86_64::_tile_loadconfig; // error: unstable feature x86_amx_intrinsics
```

**Fix**: Use `asm!()` (stable since Rust 1.59):
```rust
asm!("ldtilecfg [{}]", in(reg) config.data.as_ptr(), options(nostack));
```

Tracking issue: https://github.com/rust-lang/rust/issues/126622

---

## Gotcha 2: Tile config MUST be 64-byte aligned

```rust
// This SEGFAULTS:
let config = [0u8; 64]; // stack-allocated, no alignment guarantee

// This WORKS:
#[repr(C, align(64))]
struct TileConfig { data: [u8; 64] }
let config = TileConfig { data: [0u8; 64] };
```

LDTILECFG reads 64 bytes from the pointer. If not 64-byte aligned,
the CPU raises #GP (general protection fault) → SIGSEGV.

---

## Gotcha 3: rbx is LLVM-reserved

```rust
// This DOES NOT compile:
asm!("cpuid", out("ebx") ebx, ...); // error: rbx is used internally by LLVM

// This WORKS:
let result = core::arch::x86_64::__cpuid_count(7, 0); // stable, handles rbx internally
```

For CPUID leaf 7 (AMX detection): use `__cpuid_count()`, not inline asm.

---

## Gotcha 4: OS must enable AMX via XSETBV

AMX tiles are large (8 KB of state). The OS must opt in via XCR0 bits 17+18.
Linux 5.19+ enables AMX by default. Older kernels: SIGILL on tile instructions.

**Detection (stable)**:
```rust
let xcr0 = core::arch::x86_64::__cpuid_count(0xD, 0);
let tilecfg = (xcr0.eax >> 17) & 1; // bit 17 = XTILECFG
let tiledata = (xcr0.eax >> 18) & 1; // bit 18 = XTILEDATA
// Both must be 1
```

---

## Gotcha 5: TILEZERO/TILERELEASE need manual byte encoding

The Rust assembler on some toolchains doesn't know AMX mnemonics.
Use raw instruction bytes:

```rust
// TILEZERO tmm0
asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc0", options(nostack, nomem));

// TILEZERO tmm1
asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc8", options(nostack, nomem));

// TILEZERO tmm2
asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xd0", options(nostack, nomem));

// TILEZERO tmm3
asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xd8", options(nostack, nomem));

// TILERELEASE
asm!(".byte 0xc4, 0xe2, 0x78, 0x49, 0xc0", options(nostack, nomem));

// TDPBUSD tmm0, tmm1, tmm2 (C += A × B)
asm!(".byte 0xc4, 0xe2, 0x73, 0x5e, 0xc1", options(nostack, nomem));
```

Note: LDTILECFG works as a mnemonic:
```rust
asm!("ldtilecfg [{}]", in(reg) ptr, options(nostack));
```

---

## Gotcha 6: Tile config field layout is not obvious

The 64-byte tile config structure:
```
Byte 0: palette (must be 1)
Bytes 1-15: reserved (zero)
Bytes 16-23: rows per tile (tile 0 at byte 16, tile 1 at byte 17, ...)
Bytes 24-47: reserved (zero)
Bytes 48-63: colbytes per tile (tile 0 at [48..49] as u16 LE, tile 1 at [50..51], ...)
```

For TDPBUSD (u8×i8 → i32):
- Tile 0 (C result): rows=16, colbytes=64 (16 × i32 = 64 bytes per row)
- Tile 1 (A input): rows=16, colbytes=64 (16 × 64 u8)
- Tile 2 (B input): rows=16, colbytes=64 (transposed for column access)

**IMPORTANT**: colbytes is a u16 at byte offset 48+2*tile_id (little-endian).
For values ≤ 64, only the low byte matters.

---

## Gotcha 7: TILEZERO with wrong config = SEGFAULT

If you configure tile 0 as 16 rows × 64 colbytes but then TILEZERO tmm0,
it works. But if the config doesn't match what the hardware expects (e.g.,
palette=0 or all zeros), TILEZERO will SEGFAULT.

**Fix**: Always start with the minimal working config:
```rust
cfg.data[0] = 1; // palette 1 (MUST be 1, not 0)
cfg.data[16] = 1; // at least 1 row
cfg.data[48] = 4; // at least 4 colbytes (1 × i32)
```

Then expand to full 16×64 after verifying the minimal config works.

---

## Gotcha 8: is_x86_feature_detected!("amx-tile") is NIGHTLY ONLY

```rust
// DOES NOT compile on stable:
is_x86_feature_detected!("amx-tile") // error: unstable x86_amx_intrinsics

// WORKS on stable:
fn amx_available() -> bool {
let cpuid = core::arch::x86_64::__cpuid_count(7, 0);
let amx_tile = (cpuid.edx >> 24) & 1;
let amx_int8 = (cpuid.edx >> 25) & 1;
amx_tile == 1 && amx_int8 == 1
}
```

Use `__cpuid_count` (stable) for detection, not `is_x86_feature_detected!`.

---

## Hardware Tiers (this session)

```
Tier Feature MACs/instr Detection (stable) CPU
──── ─────── ────────── ────────────────── ───
3 AMX 256 __cpuid_count(7,0).edx bit 24 Sapphire Rapids+
2 avx512vnni 64 is_x86_feature_detected! Cascade Lake+, Zen 4+
1 avxvnniint8 32 is_x86_feature_detected! Arrow Lake (NUC 14)
0 scalar 1 always any
```

Also detectable but not yet kernelized:
- `avxvnniint16`: i16×i16 dot product (VPDPWSSD)
- `amx-bf16`: TDPBF16PS (BF16 tile matmul, for calibration)

---

## Files

```
ndarray/src/simd_amx.rs — AMX detection + VNNI/VNNI2 kernels + quantize
ndarray/src/hpc/amx_matmul.rs — AMX tile ops via inline asm (TDPBUSD)
ndarray/crates/burn/src/ops/matmul.rs — 4-tier dispatch in distance table builder
```

---

## What AMX Enables

```
Distance table build (4096² = 16M dot products):
AMX: ~20 min (all models combined)
avx512vnni: ~1:20h
avxvnniint8: ~2:40h (NUC 14)
scalar: ~24-48h

ThinkingEngine MatVec (per cycle):
AMX: ~44 μs (L1 table fits in 4 tile registers)
avx512vnni: ~175 μs
avxvnniint8: ~350 μs
scalar: ~5 ms
```
65 changes: 59 additions & 6 deletions crates/burn/src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,69 @@ pub fn try_vnni_matmul_u8(
false
}

/// Build a k×k distance table from k centroids using VNNI if available.
/// Build a k×k COSINE SIMILARITY table from f32 centroids.
///
/// Takes raw f32 centroids, normalizes to unit vectors, quantizes,
/// runs tiered VNNI/AMX dot product, maps to u8 [0, 255].
///
/// This IS the ThinkingEngine's brain. cosine[-1,1] → u8[0,255].
/// 128 = orthogonal. 255 = identical. 0 = opposite.
///
/// centroids_f32: [k × dim] raw f32 centroids (row-major)
/// Returns: [k × k] u8 cosine similarity table
#[cfg(feature = "std")]
pub fn build_cosine_table(centroids_f32: &[f32], k: usize, dim: usize) -> Vec<u8> {
assert_eq!(centroids_f32.len(), k * dim);

// Step 1: Normalize each centroid to unit vector
let mut normed = vec![0.0f32; k * dim];
for i in 0..k {
let row = &centroids_f32[i * dim..(i + 1) * dim];
let norm: f32 = row.iter().map(|v| v * v).sum::<f32>().sqrt();
let inv_norm = if norm > 1e-10 { 1.0 / norm } else { 0.0 };
for d in 0..dim {
normed[i * dim + d] = row[d] * inv_norm;
}
}

// Step 2: Quantize normalized [-1, 1] → u8 [0, 255]
// After normalization, values are in [-1, 1].
// Map: u8 = round((value + 1.0) * 127.5)
let centroids_u8: Vec<u8> = normed.iter()
.map(|&v| ((v + 1.0) * 127.5).round().clamp(0.0, 255.0) as u8)
.collect();

// Step 3: Compute dot products using tiered VNNI dispatch
let raw_dots = build_distance_table_vnni(&centroids_u8, k, dim);

// Step 4: Map i32 dot products → u8 cosine similarity [0, 255]
// The dot product of two unit vectors quantized to u8 [0,255]:
// max dot (identical) = sum of (u8_i)² over dim
// min dot (opposite) = much lower
// Find actual min/max to scale properly
let min_dot = raw_dots.iter().copied().min().unwrap_or(0) as f64;
let max_dot = raw_dots.iter().copied().max().unwrap_or(1) as f64;
let range = (max_dot - min_dot).max(1.0);

let mut table = vec![128u8; k * k]; // 128 = default orthogonal
for i in 0..k {
for j in 0..k {
let raw = raw_dots[i * k + j] as f64;
let normalized = (raw - min_dot) / range; // [0, 1]
table[i * k + j] = (normalized * 255.0).round().clamp(0.0, 255.0) as u8;
}
}

table
}

/// Build a k×k RAW DOT PRODUCT table from u8 centroids using VNNI if available.
///
/// centroids_u8: [k × dim] quantized codebook centroids (u8, row-major)
/// Returns: [k × k] i32 dot product matrix (symmetric)
///
/// Uses VNNI dot product (64 MACs/instruction) for each centroid pair.
/// Symmetric: only computes upper triangle, mirrors to lower.
///
/// This IS the ThinkingEngine's brain construction step.
/// 4096² = 16M dot products. With VNNI: ~1:20h for large dim.
/// For cosine: use build_cosine_table() which normalizes first.
/// This function is for raw dot products when centroids are already u8.
#[cfg(feature = "std")]
pub fn build_distance_table_vnni(centroids_u8: &[u8], k: usize, dim: usize) -> Vec<i32> {
assert_eq!(centroids_u8.len(), k * dim);
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ mod dimension;
/// Portable SIMD types — `crate::simd::f32x16` today, `std::simd::f32x16` tomorrow.
#[cfg(feature = "std")]
#[allow(missing_docs)]
pub(crate) mod simd;
pub mod simd;
#[cfg(all(feature = "std", target_arch = "x86_64"))]
#[allow(missing_docs, dead_code)]
pub(crate) mod simd_avx512;
Expand Down
Loading