39 changes: 31 additions & 8 deletions .claude/AMX_GOTCHAS.md
@@ -66,18 +66,41 @@ For CPUID leaf 7 (AMX detection): use `__cpuid_count()`, not inline asm.

---

## Gotcha 4: OS must enable AMX via XSETBV + process must request permission

AMX tiles are large (8 KB of state). Two levels of OS enablement required:

1. **Kernel enables tile state in XCR0** (bits 17+18). Linux 5.19+ does this.
2. **Process requests XCOMP_PERM** via `arch_prctl(ARCH_REQ_XCOMP_PERM, 18)` (syscall 158, not plain `prctl`).
Without this, LDTILECFG will SIGILL even if XCR0 bits are set.

**Detection (stable)**:
```rust
// Step 1: CPUID — does CPU support AMX?
let cpuid = core::arch::x86_64::__cpuid_count(7, 0);
let amx_tile = (cpuid.edx >> 24) & 1;
let amx_int8 = (cpuid.edx >> 25) & 1;

// Step 2: OSXSAVE — does OS support XSAVE?
let cpuid_01 = core::arch::x86_64::__cpuid(1);
let osxsave = (cpuid_01.ecx >> 27) & 1;

// Step 3: _xgetbv(0) — did OS ACTUALLY enable tile state?
// ⚠ Do NOT use __cpuid_count(0xD, 0) — that reports what CPU SUPPORTS,
// not what the OS ENABLED. _xgetbv(0) reads the actual XCR0 register.
let xcr0: u64 = unsafe { core::arch::x86_64::_xgetbv(0) };
let tilecfg = (xcr0 >> 17) & 1; // bit 17 = XTILECFG
let tiledata = (xcr0 >> 18) & 1; // bit 18 = XTILEDATA

// Step 4: arch_prctl — request tile permission for this process
// SYS_arch_prctl = 158, ARCH_REQ_XCOMP_PERM = 0x1023, XFEATURE_XTILEDATA = 18
// Returns 0 on success, -errno on failure. Idempotent.
```

**Previous bug**: `__cpuid_count(0xD, 0)` reports XSAVE state component bitmap
(what the CPU *supports*), NOT the actual XCR0 value (what the OS *enabled*).
On hypervisors that advertise AMX in CPUID but don't enable tile state,
the old check returned `true` → SIGILL on LDTILECFG.

---

87 changes: 76 additions & 11 deletions src/simd_amx.rs
@@ -31,17 +31,74 @@
// ═══════════════════════════════════════════════════════════════════════════

/// Check if AMX hardware is present AND OS-enabled.
///
/// Four checks required:
/// 1. CPUID.07H.0H:EDX bits 24 (AMX-TILE) + 25 (AMX-INT8) = CPU supports it
/// 2. CPUID.01H:ECX bit 27 (OSXSAVE) = OS uses XSAVE at all
/// 3. XCR0 bits 17 (TILECFG) + 18 (TILEDATA) = OS has enabled tile state
/// 4. arch_prctl(ARCH_REQ_XCOMP_PERM, 18) = process holds tile permission (Linux)
///
/// The XCR0 check is critical: even if CPUID reports AMX, the hypervisor
/// may not have enabled the XSTATE for tiles. Without OS enablement,
/// LDTILECFG will SIGILL.
///
/// Previous bug: used CPUID leaf 0xD (reports what CPU supports for XSAVE)
/// instead of _xgetbv(0) (reports what OS actually enabled). The old check
/// could return true on a hypervisor that advertises AMX in CPUID but
/// hasn't set XCR0 bits 17+18.
#[cfg(target_arch = "x86_64")]
pub fn amx_available() -> bool {
// Step 1: CPU supports AMX-TILE + AMX-INT8?
let cpuid = core::arch::x86_64::__cpuid_count(7, 0);
let amx_tile = (cpuid.edx >> 24) & 1;
let amx_int8 = (cpuid.edx >> 25) & 1;
if amx_tile == 0 || amx_int8 == 0 { return false; }

// Step 2: OS enabled XSAVE? (CPUID.01H:ECX bit 27 = OSXSAVE)
let cpuid_01 = core::arch::x86_64::__cpuid(1);
let osxsave = (cpuid_01.ecx >> 27) & 1;
if osxsave == 0 { return false; }

// Step 3: OS actually enabled tile state in XCR0?
// _xgetbv(0) reads the ACTUAL XCR0 register (what the OS set),
// not the CPUID-reported capability.
// Bit 17 = TILECFG, Bit 18 = TILEDATA. Both must be set.
let xcr0: u64 = unsafe { core::arch::x86_64::_xgetbv(0) };
let tilecfg = (xcr0 >> 17) & 1;
let tiledata = (xcr0 >> 18) & 1;
if tilecfg == 0 || tiledata == 0 { return false; }

// Step 4: Request XCOMP_PERM for TILEDATA.
// Linux kernel 5.19+: processes must call arch_prctl(ARCH_REQ_XCOMP_PERM, 18)
// to request permission for TILEDATA (XFEATURE 18) before using AMX.
// Without this, LDTILECFG will SIGILL even if XCR0 bits are set.
// The call either succeeds (0) or fails (-errno); idempotent, safe to call
// multiple times.
#[cfg(target_os = "linux")]
{
const SYS_ARCH_PRCTL: i64 = 158; // x86_64 syscall number for arch_prctl
const ARCH_REQ_XCOMP_PERM: i64 = 0x1023;
const XFEATURE_XTILEDATA: i64 = 18;
// NOTE: ARCH_REQ_XCOMP_PERM is an arch_prctl operation (syscall 158),
// not prctl (syscall 157); prctl rejects option 0x1023 with EINVAL.
// SAFETY: arch_prctl(ARCH_REQ_XCOMP_PERM, 18) is a simple permission
// request. It either grants tile permission (returns 0) or fails (returns
// -errno). No side effects on failure. Idempotent.
let ret: i64;
unsafe {
core::arch::asm!(
"syscall",
inlateout("rax") SYS_ARCH_PRCTL => ret,
in("rdi") ARCH_REQ_XCOMP_PERM,
in("rsi") XFEATURE_XTILEDATA,
lateout("rcx") _,
lateout("r11") _,
options(nostack),
);
}
if ret != 0 { return false; }
}

true
}

#[cfg(not(target_arch = "x86_64"))]
@@ -203,17 +260,25 @@ pub fn vnni_matvec_scalar(

/// Runtime-dispatched VNNI MatVec: avx512vnni → avxvnniint8 → scalar i32.
///
/// Three tiers, checked in order (first match wins):
/// avx512vnni — 64 MACs/instr (zmm, Cascade Lake+, Zen 4+)
/// avxvnniint8 — 32 MACs/instr (ymm, Arrow Lake, NUC 14 i9-185H)
/// scalar i32 — only for non-x86 or testing
///
/// IMPORTANT: avxvnniint8 (VNNI2, 256-bit) is NEVER reached when
/// avx512vnni (VNNI512) is present. This is correct:
/// - CPUs with avx512vnni always have 512-bit VPDPBUSD (faster)
/// - avxvnniint8 exists ONLY for CPUs that dropped AVX-512
/// but added 256-bit VNNI (Arrow Lake, Meteor Lake U-series)
/// - The two instructions have DIFFERENT encodings:
/// avx512vnni: EVEX-encoded VPDPBUSD zmm (512-bit)
/// avxvnniint8: VEX-encoded VPDPBUSD ymm (256-bit)
/// - Running EVEX VPDPBUSD on a VEX-only CPU = SIGILL
/// - Running VEX VPDPBUSD on an EVEX CPU = works but wastes half the width
///
/// NOTE: The scalar path here does i32 multiply-accumulate, NOT f32.
/// For the thinking engine, F32x16 FMA (16 MACs/instr) is the true floor.
/// This scalar path exists only for correctness on non-x86 targets.
/// The thinking engine's cycle_auto() dispatches:
/// VNNI detected → cycle_vnni() → this function
/// No VNNI → cycle() → F32x16 FMA (never reaches here)
pub fn matvec_dispatch(
table: &[u8],
energy_i8: &[i8],