Phase 3: final wiring of the BF16 dispatch chain. Follows up on #610 (Bf16TensorData, merged) and #612 (loader KEEP_NATIVE policy, merged).
What's missing
After #610 + #612, BF16 weights loaded via SafeTensorsParametersLoader with bf16Policy = KEEP_NATIVE arrive as Bf16DenseTensorData-backed tensors. But DefaultCpuOps.matmul doesn't recognise this storage type yet — chooseQuantizedMatmul only matches Q4_K, Q6_K, Q8_0, and MemorySegment-backed quantised data. A BF16-tensor matmul falls through to super.matmul, which scalar-loops through the FP32 path (decoding BF16 per-element via Bf16DenseTensorData.get). The SIMD BF16 kernel from #605 stays unreachable.
This PR adds the is Bf16TensorData -> branch — the final piece of the chain.
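To make the cost of the fallback concrete, here is a minimal sketch of a scalar matmul that decodes each BF16 weight element on every use. It relies only on the BF16 definition (the upper 16 bits of an IEEE 754 float32). The names `bf16ToFloat`, `floatToBf16` and `scalarMatmulBf16` are hypothetical and are not the SKaiNET API; the real fallback path may differ in detail.

```kotlin
// Hypothetical helpers for illustration only, not the SKaiNET API.

// BF16 is the top 16 bits of an IEEE 754 float32, so decoding is a 16-bit left shift.
fun bf16ToFloat(bits: Short): Float =
    Float.fromBits((bits.toInt() and 0xFFFF) shl 16)

// Truncating encode (drops the low 16 mantissa bits); production encoders usually round.
fun floatToBf16(value: Float): Short =
    (value.toRawBits() ushr 16).toShort()

// Scalar [m, k] x [k, n] matmul that decodes every BF16 weight element on each use,
// which is roughly what a per-element get() fallback amounts to.
fun scalarMatmulBf16(a: FloatArray, b: ShortArray, m: Int, n: Int, k: Int): FloatArray {
    require(a.size == m * k && b.size == k * n) { "shape mismatch" }
    val out = FloatArray(m * n)
    for (i in 0 until m) {
        for (j in 0 until n) {
            var acc = 0f
            for (p in 0 until k) {
                acc += a[i * k + p] * bf16ToFloat(b[p * n + j])
            }
            out[i * n + j] = acc
        }
    }
    return out
}
```

The inner loop performs one decode per multiply-accumulate and has no vectorisation, which is the overhead a dedicated SIMD kernel is meant to avoid.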
Scope
Add private val bf16MatmulKernel: Bf16MatmulKernel by lazy { ... } in DefaultCpuOpsJvm. Mirrors fp32MatmulKernel (non-null with scalar fallback) rather than q8_0MatmulKernel (nullable with legacy fallback) because there's no legacy non-SPI BF16 kernel — the scalar SPI provider is the floor.
Add is Bf16TensorData -> branch in chooseQuantizedMatmul's when (bData) block. Calls the SPI's matmul with dense (m, n, k) parameters — the BF16 kernel is a full 2D SGEMM, not a matvec like Q4_K/Q8_0, so no per-batch loop is needed.
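The two scope items can be sketched together as follows. This is an illustration under stated assumptions, not the code in this PR: `TensorData`, `Fp32Data`, `Bf16Data`, `Bf16Kernel`, `ScalarBf16Kernel`, `CpuOpsSketch` and the `simdProvider` parameter are hypothetical stand-ins, and the kernel signature is assumed rather than taken from the real SPI.

```kotlin
// Hypothetical stand-ins for the real SKaiNET types; all signatures are assumptions.
sealed interface TensorData
class Fp32Data(val values: FloatArray) : TensorData
class Bf16Data(val bits: ShortArray) : TensorData

interface Bf16Kernel {
    // Dense GEMM: out[m, n] = a[m, k] x b[k, n], with b stored as raw BF16 bits.
    fun matmul(a: FloatArray, b: ShortArray, out: FloatArray, m: Int, n: Int, k: Int)
}

// Scalar floor: always available, so the kernel reference can be non-null.
object ScalarBf16Kernel : Bf16Kernel {
    override fun matmul(a: FloatArray, b: ShortArray, out: FloatArray, m: Int, n: Int, k: Int) {
        for (i in 0 until m) for (j in 0 until n) {
            var acc = 0f
            for (p in 0 until k) {
                // BF16 -> FP32: place the 16 stored bits in the upper half of a float.
                acc += a[i * k + p] * Float.fromBits((b[p * n + j].toInt() and 0xFFFF) shl 16)
            }
            out[i * n + j] = acc
        }
    }
}

class CpuOpsSketch(private val simdProvider: () -> Bf16Kernel? = { null }) {
    // Resolved once on first use; falls back to the scalar kernel if no SIMD provider exists.
    private val bf16Kernel: Bf16Kernel by lazy { simdProvider() ?: ScalarBf16Kernel }

    // Returns null when the storage type has no specialised kernel, so the caller
    // can fall through to its generic path.
    fun chooseMatmul(a: FloatArray, bData: TensorData, m: Int, n: Int, k: Int): FloatArray? =
        when (bData) {
            is Bf16Data -> FloatArray(m * n).also { out ->
                // One dense call covers every row of a; no per-batch loop.
                bf16Kernel.matmul(a, bData.bits, out, m, n, k)
            }
            is Fp32Data -> null
        }
}
```

Because the lazy property always resolves to some kernel, the `when` branch needs no null check, which is the difference from a nullable kernel with a legacy fallback.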
Tests
New Bf16MatmulDispatchTest in commonTest (mirrors Q8_0MatmulDispatchTest):
single_batch_matmul_against_bf16_weight_routes_correctly — [1, k] × [k, n] matmul against a Bf16DenseTensorData weight; compares against ScalarBf16MatmulKernel reference within 1e-2 * k tolerance.
multi_batch_matmul_against_bf16_weight_routes_correctly — [m, k] × [k, n] with m>1.
llm_typical_attention_proj_matmul_routes_correctly: 512² shape.
Out of scope (separate follow-ups)
SKaiNET-transformers/llm-inference/gemma: exercises the full chain on a real model).
Branch
feature/bf16-dispatch off develop.