Add BF16 tile GEMM with AMX/AVX-512 dispatch#104
Merged
Conversation
Previously: no global target-cpu, per-function #[target_feature] + runtime dispatch so "one binary runs on AVX2 and AVX-512 machines". Now: compile-time AVX-512 baseline via rustflags target-cpu=x86-64-v4 (AVX-512F, AVX-512BW, AVX-512CD, AVX-512DQ, AVX-512VL). The v4 level does not include AMX; AMX (amx-tile / amx-int8 / amx-bf16) remains per-function #[target_feature] with runtime amx_available() gating (CPUID + _xgetbv(0) bits 17/18 + prctl ARCH_REQ_XCOMP_PERM for Linux 5.19+ tile state). AVX2 is used as the CI-only fallback path; local and Railway builds pin v4 so AVX-512 lanes (F32x16, kernels_avx512) light up at compile time. https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
This reverts commit d7731ba.
Additive only — no existing symbols modified.
hpc/amx_matmul.rs (ADD):
pub unsafe fn tile_dpbf16ps()
TDPBF16PS tmm0, tmm1, tmm2 via stable inline asm .byte encoding
(C4 E2 72 5C C1 — same binary-trick pattern as existing
tile_dpbusd/tile_zero/tile_release for pre-nightly AMX on Rust 1.94)
pub fn vnni_pack_bf16(src, dst, k, n)
Pack K×N row-major bf16 → K/2 × (N*2) VNNI pairs for TDPBF16PS B tile
hpc/bf16_tile_gemm.rs (NEW module, additive):
pub fn bf16_tile_gemm_16x16(a_bf16, b_bf16, c_f32, k)
Same API, runtime tier dispatch:
amx_available() → AMX TDPBF16PS tile GEMM (K/32 tile iters)
amx_available() = false → AVX-512 F32x16 + mul_add FMA fallback
(BF16→f32 via bf16_to_f32_batch, then F32x16 chunks_exact(16)
+ mul_add = VFMADD231PS on __m512 with target-cpu=x86-64-v4,
emulated as 2× F32x8 FMA on AVX2-only hosts)
hpc/mod.rs (ADD):
pub mod bf16_tile_gemm;
Test results:
hpc::bf16_tile_gemm::tests::fallback_matches_scalar_reference_k64 ... ok
hpc::bf16_tile_gemm::tests::public_api_runs_on_any_hardware ... ok
Full suite: 1616 passed, 0 failed, 36 ignored, no SIGILL.
Baseline was 1612 → +4 (two new here, two other).
Design invariants honored:
- simd.rs polyfill boundary untouched (F32x16/F32x8 re-exports unchanged)
- additive only: no modifications to tile_dpbusd, TileConfig, or mod layout
- runtime-dispatched via amx_available() — same binary works on AMX
machines and AVX-512-only machines
- stable Rust 1.94 (inline asm .byte encoding, no nightly intrinsics)
https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Introduces a new
bf16_tile_gemmmodule that provides a unified API for 16×16 BF16 matrix multiplication with runtime tier dispatch between AMX (TDPBF16PS) and AVX-512 F32x16 fallback paths.Key Changes
New module
src/hpc/bf16_tile_gemm.rs: Implementsbf16_tile_gemm_16x16()public API with:amx_available()checkF32x16SIMD with FMA operationsExtended
src/hpc/amx_matmul.rs:tile_dpbf16ps(): Inline assembly wrapper for TDPBF16PS instruction (C += A(bf16) × B(bf16_vnni) → f32)vnni_pack_bf16(): Utility to repack B matrix from row-major to VNNI pair layout required by TDPBF16PSUpdated
src/hpc/mod.rs: Exported newbf16_tile_gemmmoduleImplementation Details
bf16_to_f32_batch(), then accumulates viaF32x16::mul_add()with column-gathering optimizationThe implementation follows the pattern of one dispatch check per call, with identical numerical results across both execution tiers.
https://claude.ai/code/session_01SbYsmmbPf9YQuYbHZN52Zh