From c4de7fb589c5ede35b69cd40e1872204e853b88d Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Sun, 17 May 2026 12:36:25 +0200 Subject: [PATCH] Wire BF16 matmul dispatch in DefaultCpuOpsJvm.chooseQuantizedMatmul (#613) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Final phase of the three-phase BF16 dispatch chain. Follow-ups to #610 (Bf16TensorData) and #612 (loader KEEP_NATIVE policy) — both merged. After this PR, a consumer that flips `bf16Policy = KEEP_NATIVE` on SafeTensorsParametersLoader (or constructs a `Bf16DenseTensorData` directly) gets the SIMD-vectorised BF16 matmul path with zero other code changes. Native FFM kernel (priority 100) wins when the bundled libskainet_kernels.so is loaded; falls through to Panama Vector (50) then to the scalar SPI reference (0). Implementation: - New `bf16MatmulKernel: Bf16MatmulKernel` lazy in DefaultCpuOpsJvm. Non-null with `ScalarBf16MatmulKernel` floor — mirrors `fp32MatmulKernel`'s pattern rather than the nullable `q4kMatmulKernel` / `q8_0MatmulKernel` pattern (which exist because Q4_K / Q8_0 have legacy non-SPI fallbacks via `JvmQuantizedVectorKernels`; BF16 has no such legacy). - New `is Bf16TensorData ->` branch in `chooseQuantizedMatmul`'s `when (bData)` block. The BF16 SPI kernel is a full SGEMM `(m, n, k)` with byte-strides on the B operand — no per-batch matvec loop like Q4_K/Q8_0/Q6_K need. 3 integration tests in `Bf16MatmulDispatchTest`: - single-batch matmul (`[1, k] × [k, n]` BF16) matches scalar reference within `1e-2 * k`. - multi-batch matmul (`m=3, k=256, n=32`) — exercises a 2D output. - LLM-typical 512² attention projection. Refs #613. Full `:skainet-backends:skainet-backend-cpu:jvmTest` and `:skainet-backends:skainet-backend-native-cpu:jvmTest` suites pass on linux-x86_64 / JDK 21 with `--add-modules jdk.incubator.vector`. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt | 38 +++++++ .../exec/tensor/ops/Bf16MatmulDispatchTest.kt | 102 ++++++++++++++++++ 2 files changed, 140 insertions(+) create mode 100644 skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/Bf16MatmulDispatchTest.kt diff --git a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt index a035d6a5..e9be28bd 100644 --- a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt +++ b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt @@ -3,11 +3,13 @@ package sk.ainet.exec.tensor.ops import jdk.incubator.vector.FloatVector import jdk.incubator.vector.VectorSpecies import jdk.incubator.vector.VectorOperators +import sk.ainet.backend.api.kernel.Bf16MatmulKernel import sk.ainet.backend.api.kernel.Fp32MatmulKernel import sk.ainet.backend.api.kernel.KernelRegistry import sk.ainet.backend.api.kernel.KernelServiceLoader import sk.ainet.backend.api.kernel.Q4KMatmulKernel import sk.ainet.backend.api.kernel.Q8_0MatmulKernel +import sk.ainet.exec.kernel.ScalarBf16MatmulKernel import sk.ainet.exec.kernel.ScalarMatmulKernel import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor @@ -17,6 +19,7 @@ import sk.ainet.lang.tensor.data.MemorySegmentBackedData import sk.ainet.lang.tensor.data.MemorySegmentTensorData import sk.ainet.lang.tensor.data.Q4MemorySegmentMarker import sk.ainet.lang.tensor.data.Q4MemorySegmentTensorData +import sk.ainet.lang.tensor.data.Bf16TensorData import sk.ainet.lang.tensor.data.Q8_0TensorData import sk.ainet.lang.tensor.data.Q8MemorySegmentMarker import sk.ainet.lang.tensor.data.Q8MemorySegmentTensorData @@ -89,6 +92,26 @@ internal class DefaultCpuOpsJvm( ?.matmulQ8_0() } + /** + * BF16 matmul kernel resolved via [KernelRegistry]. Unlike the Q4_K + * and Q8_0 lookups (nullable, with legacy `JvmQuantizedVectorKernels` + * fallbacks), BF16 has no pre-SPI implementation in this codebase — + * the scalar SPI kernel is the floor. We mirror [fp32MatmulKernel]'s + * pattern: non-null, picks the highest-priority provider that carries + * a BF16 kernel (native FFM at 100, Panama Vector at 50), falls back + * to [ScalarBf16MatmulKernel] when no SIMD provider reports + * availability (e.g. tests that explicitly clear the registry). + */ + private val bf16MatmulKernel: Bf16MatmulKernel by lazy { + if (KernelRegistry.providers().isEmpty()) { + KernelServiceLoader.installAll() + } + KernelRegistry.providers() + .firstOrNull { it.isAvailable() && it.matmulBf16() != null } + ?.matmulBf16() + ?: ScalarBf16MatmulKernel + } + override fun add(a: Tensor, b: Tensor): Tensor { vectorFloatBinary(a, b, { x, y -> x.add(y) }) { x, y -> x + y }?.let { return it } return super.add(a, b) @@ -511,6 +534,21 @@ internal class DefaultCpuOpsJvm( @Suppress("UNCHECKED_CAST") CpuTensor(outData as TensorData, this, a.dtype) } + is Bf16TensorData -> { + // BF16 is dense (not block-quantized) and the kernel SPI is a + // full SGEMM with `(m, n, k)` strides — no per-batch loop needed, + // unlike the matvec-shaped Q4_K / Q8_0 / Q6_K branches. + val outBuffer = FloatArray(batchSize * outputDim) + bf16MatmulKernel.matmul( + inputBuffer, 0, inputDim, + bData.packedData, 0, outputDim * Bf16TensorData.BYTES_PER_ELEMENT, + outBuffer, 0, outputDim, + batchSize, outputDim, inputDim, + ) + val outData = DenseFloatArrayTensorData(Shape(batchSize, outputDim), outBuffer) + @Suppress("UNCHECKED_CAST") + CpuTensor(outData as TensorData, this, a.dtype) + } is Q6_KTensorData -> { val outBuffer = FloatArray(batchSize * outputDim) for (batch in 0 until batchSize) { diff --git a/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/Bf16MatmulDispatchTest.kt b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/Bf16MatmulDispatchTest.kt new file mode 100644 index 00000000..2d7796a7 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/Bf16MatmulDispatchTest.kt @@ -0,0 +1,102 @@ +package sk.ainet.exec.tensor.ops + +import kotlin.math.abs +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertTrue +import sk.ainet.context.DirectCpuExecutionContext +import sk.ainet.exec.kernel.ScalarBf16MatmulKernel +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.Bf16DenseTensorData +import sk.ainet.lang.tensor.data.TensorData +import sk.ainet.lang.types.FP32 + +/** + * Integration tests for the FP32 × BF16 dispatch path in + * [DefaultCpuOpsJvm.matmul]. Confirms `ops.matmul` against a + * `Bf16DenseTensorData` weight produces the same output as the + * `ScalarBf16MatmulKernel` reference — proving the new `is + * Bf16TensorData ->` branch in `chooseQuantizedMatmul` is reached + * and routes through the BF16 SPI (or its scalar fallback when no + * SIMD provider resolves). Mirrors `Q8_0MatmulDispatchTest` in + * shape and coverage. + */ +class Bf16MatmulDispatchTest { + + private val ctx = DirectCpuExecutionContext() + + /** BF16 has 7 mantissa bits — accumulated error scales with `k`. */ + private val bf16TolPerK = 1e-2f + + /** Truncate FP32 → BF16 (high 16 bits, zero rounding), pack LE bytes. */ + private fun fp32ToBf16Bytes(values: FloatArray): ByteArray { + val out = ByteArray(values.size * 2) + for (i in values.indices) { + val bf16 = (values[i].toRawBits() ushr 16) and 0xFFFF + out[i * 2] = (bf16 and 0xFF).toByte() + out[i * 2 + 1] = ((bf16 ushr 8) and 0xFF).toByte() + } + return out + } + + @Suppress("UNCHECKED_CAST") + private fun bf16Weight(inputDim: Int, outputDim: Int, seed: Int): Pair, ByteArray> { + val rng = Random(seed) + val values = FloatArray(inputDim * outputDim) { rng.nextFloat() - 0.5f } + val bytes = fp32ToBf16Bytes(values) + val data = Bf16DenseTensorData(Shape(inputDim, outputDim), bytes) as TensorData + return ctx.fromData(data, FP32::class) to bytes + } + + private fun scalarReference( + input: FloatArray, weightBytes: ByteArray, + m: Int, n: Int, k: Int, + ): FloatArray { + val out = FloatArray(m * n) + ScalarBf16MatmulKernel.matmul( + input, 0, k, + weightBytes, 0, n * 2, + out, 0, n, + m, n, k, + ) + return out + } + + private fun assertDispatchMatchesScalar( + m: Int, k: Int, n: Int, seed: Int, + ) { + val rng = Random(seed) + val inputFloats = FloatArray(m * k) { rng.nextFloat() - 0.5f } + val (weight, weightBytes) = bf16Weight(k, n, seed) + val input = ctx.fromFloatArray(Shape(m, k), FP32::class, inputFloats) + + val out = ctx.ops.matmul(input, weight) + val outArr = out.data.copyToFloatArray() + val expected = scalarReference(inputFloats, weightBytes, m, n, k) + + val tol = (bf16TolPerK * k.coerceAtLeast(1)).coerceAtLeast(bf16TolPerK) + for (i in expected.indices) { + val diff = abs(expected[i] - outArr[i]) + assertTrue( + diff <= tol, + "BF16 dispatch mismatch at $i: expected=${expected[i]} got=${outArr[i]} diff=$diff tol=$tol", + ) + } + } + + @Test + fun single_batch_matmul_against_bf16_weight_routes_correctly() { + assertDispatchMatchesScalar(m = 1, k = 128, n = 64, seed = 1) + } + + @Test + fun multi_batch_matmul_against_bf16_weight_routes_correctly() { + assertDispatchMatchesScalar(m = 3, k = 256, n = 32, seed = 2) + } + + @Test + fun llm_typical_attention_proj_matmul_routes_correctly() { + assertDispatchMatchesScalar(m = 1, k = 512, n = 512, seed = 3) + } +}