From 04e6a906d37d9d72c857fea38c1a3f62c9fe0d57 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 8 Jun 2026 10:33:10 +0200 Subject: [PATCH] feat(backend): commonMain scalar Q5_1/Q5_0/Q4_K/Q6_K kernels + SPI (Native parity) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part of #708. Brings quantized matmul to Kotlin/Native (and JS/WASM), which previously only had FP32/BF16/Q8_0/Q4_0 scalar kernels — Q4_K/Q6_K/Q5_x were JVM-only (Panama/FFM), so on non-JVM targets packed-quant matmul had no kernel. SPI (skainet-backend-api, commonMain): - New Q5_1MatmulKernel / Q5_0MatmulKernel / Q6KMatmulKernel interfaces (block-major `(blockIdx*outputDim+o)*BYTES_PER_BLOCK`, exact dequant in kdoc). - KernelProvider: matmulQ5_1()/matmulQ5_0()/matmulQ6K() accessors (default null) + supports() keys for "Q5_1"/"Q5_0"/"Q6_K". Scalar kernels (skainet-backend-cpu, commonMain — available on every target): - ScalarQ5_1/Q5_0/Q4_K/Q6_KMatmulKernel, math ported from JvmQuantizedVectorKernels / DequantOps (Q4_K get_scale_min_k4 + sub-block codeSum*scale - inputSum*offset; Q6_K ql/qh 6-bit reassembly). Shared decodeHalf() FP16 helper. - ScalarKernelProvider now overrides matmulQ4K/Q6K/Q5_1/Q5_0 → the scalar floor carries every packed format. Test: ScalarPackedKernelParityTest (commonTest) validates each kernel's matmul against an independent inline dequant; passes on jvmTest AND linuxX64Test, proving Native packed-matmul correctness (relative tol for the FP reassociation of the per-sub-block accumulation). Note: dispatch wiring (so ops.matmul routes packed tensors to these kernels on non-JVM) + non-JVM provider registration land in follow-up commits; this commit is the kernels + SPI surface. Co-Authored-By: Claude Opus 4.8 --- .../backend/api/kernel/KernelProvider.kt | 21 +++ .../backend/api/kernel/Q5_0MatmulKernel.kt | 40 +++++ .../backend/api/kernel/Q5_1MatmulKernel.kt | 53 +++++++ .../backend/api/kernel/Q6KMatmulKernel.kt | 37 +++++ .../api/jvm/skainet-backend-cpu.api | 32 ++++ .../kotlin/sk/ainet/exec/kernel/ScalarHalf.kt | 31 ++++ .../ainet/exec/kernel/ScalarKernelProvider.kt | 8 + .../exec/kernel/ScalarQ4_KMatmulKernel.kt | 82 ++++++++++ .../exec/kernel/ScalarQ5_0MatmulKernel.kt | 60 +++++++ .../exec/kernel/ScalarQ5_1MatmulKernel.kt | 62 ++++++++ .../exec/kernel/ScalarQ6_KMatmulKernel.kt | 80 ++++++++++ .../kernel/ScalarPackedKernelParityTest.kt | 148 ++++++++++++++++++ 12 files changed, 654 insertions(+) create mode 100644 skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q5_0MatmulKernel.kt create mode 100644 skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q5_1MatmulKernel.kt create mode 100644 skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q6KMatmulKernel.kt create mode 100644 skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarHalf.kt create mode 100644 skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ4_KMatmulKernel.kt create mode 100644 skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ5_0MatmulKernel.kt create mode 100644 skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ5_1MatmulKernel.kt create mode 100644 skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ6_KMatmulKernel.kt create mode 100644 skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/kernel/ScalarPackedKernelParityTest.kt diff --git a/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt index fd22f37f..09d99a9c 100644 --- a/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt +++ b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt @@ -73,6 +73,24 @@ public interface KernelProvider { */ public fun matmulQ4_0(): Q4_0MatmulKernel? = null + /** + * F32 × Q6_K matmul kernel exposed by this provider, or `null` if + * this provider does not specialize Q6_K. Same fall-through pattern. + */ + public fun matmulQ6K(): Q6KMatmulKernel? = null + + /** + * F32 × Q5_1 matmul kernel exposed by this provider, or `null` if + * this provider does not specialize Q5_1. Same fall-through pattern. + */ + public fun matmulQ5_1(): Q5_1MatmulKernel? = null + + /** + * F32 × Q5_0 matmul kernel exposed by this provider, or `null` if + * this provider does not specialize Q5_0. Same fall-through pattern. + */ + public fun matmulQ5_0(): Q5_0MatmulKernel? = null + /** * Capability query: does this provider carry a kernel for * [opName] with the given [dtypeKeys]? @@ -107,6 +125,9 @@ public interface KernelProvider { "Q4_K" -> matmulQ4K() != null "Q8_0" -> matmulQ8_0() != null "Q4_0" -> matmulQ4_0() != null + "Q6_K" -> matmulQ6K() != null + "Q5_1" -> matmulQ5_1() != null + "Q5_0" -> matmulQ5_0() != null else -> false } } diff --git a/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q5_0MatmulKernel.kt b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q5_0MatmulKernel.kt new file mode 100644 index 00000000..3d51aa43 --- /dev/null +++ b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q5_0MatmulKernel.kt @@ -0,0 +1,40 @@ +package sk.ainet.backend.api.kernel + +/** + * F32 input × Q5_0-packed weights matrix-vector multiply, in canonical + * ggml block layout. + * + * output[outputOffset + o] = Σ_j input[inputOffset + j] · dequant(weight[o, j]) + * for j ∈ [0, inputDim), o ∈ [0, outputDim) + * + * Block layout (32-element block, 22 bytes/block; see + * [sk.ainet.lang.tensor.data.Q5_0BlockTensorData] kdoc): + * - bytes 0..1 : `d` (block scale, FP16 LE) + * - bytes 2..5 : `qh[0..3]` (the 5th/high bit of each of the 32 codes) + * - bytes 6..21 : `qs[0..15]` (low 4 bits, two nibbles per byte) + * + * Per element, with `lo = qs[j] & 0x0F`, `hi = qs[j] >>> 4`, and the high + * bits `bitLo = (qh[j/8] >>> (j%8)) & 1`, `bitHi = (qh[(j+16)/8] >>> ((j+16)%8)) & 1`: + * + * element[j] = d * (lo + (bitLo shl 4) - 16) for j ∈ [0, 16) + * element[j + 16] = d * (hi + (bitHi shl 4) - 16) + * + * The `- 16` bias centres the unsigned 5-bit code around zero (no per-block + * min). Matches `sk.ainet.io.gguf.dequant.DequantOps.dequantQ5_0FromBytes`. + * + * Implementations MUST NOT mutate `input` or `weight`. They MUST fully + * write the `outputDim` floats starting at `output[outputOffset]`. + * + * Packed-weight **block-major** row contract: `weight` holds blocks laid + * out `(blockIdx * outputDim + o) * 22`. Matches `Q5_0BlockTensorData.packedData`. + * + * `inputDim` MUST be a multiple of 32 (the Q5_0 block size). + */ +public interface Q5_0MatmulKernel { + public fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) +} diff --git a/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q5_1MatmulKernel.kt b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q5_1MatmulKernel.kt new file mode 100644 index 00000000..1f6d0a99 --- /dev/null +++ b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q5_1MatmulKernel.kt @@ -0,0 +1,53 @@ +package sk.ainet.backend.api.kernel + +/** + * F32 input × Q5_1-packed weights matrix-vector multiply, in canonical + * ggml block layout. + * + * output[outputOffset + o] = Σ_j input[inputOffset + j] · dequant(weight[o, j]) + * for j ∈ [0, inputDim), o ∈ [0, outputDim) + * + * Block layout (32-element block, 24 bytes/block; see + * [sk.ainet.lang.tensor.data.Q5_1BlockTensorData] kdoc): + * - bytes 0..1 : `d` (block scale, FP16 LE) + * - bytes 2..3 : `m` (block minimum, FP16 LE) + * - bytes 4..7 : `qh[0..3]` (the 5th/high bit of each of the 32 codes) + * - bytes 8..23 : `qs[0..15]` (low 4 bits, two nibbles per byte) + * + * Per element, with `lo = qs[j] & 0x0F`, `hi = qs[j] >>> 4`, and the high + * bits `bitLo = (qh[j/8] >>> (j%8)) & 1`, `bitHi = (qh[(j+16)/8] >>> ((j+16)%8)) & 1`: + * + * element[j] = d * (lo + (bitLo shl 4)) + m for j ∈ [0, 16) + * element[j + 16] = d * (hi + (bitHi shl 4)) + m + * + * Matches `sk.ainet.io.gguf.dequant.DequantOps.dequantQ5_1FromBytes`. + * + * Implementations MUST NOT mutate `input` or `weight`. They MAY assume + * the arrays do not alias each other or `output`. They MUST fully write + * the `outputDim` floats starting at `output[outputOffset]`. + * + * Packed-weight **block-major** row contract: `weight` holds blocks laid + * out `(blockIdx * outputDim + o) * 24` for output row `o` and input + * block index `blockIdx`. This matches `Q5_1BlockTensorData.packedData` + * after the GGUF row-major → input-block-major re-layout. + * + * `inputDim` MUST be a multiple of 32 (the Q5_1 block size). + */ +public interface Q5_1MatmulKernel { + /** + * @param input FP32 input vector (single row). + * @param inputOffset element offset into [input] where the row starts. + * @param weight packed Q5_1 bytes for the full `outputDim × inputDim` weight tensor. + * @param weightByteOffset byte offset into [weight] where block (0, 0) starts. + * @param inputDim contraction dimension (must be a multiple of 32). + * @param outputDim number of output cells. + * @param output FP32 output vector. + * @param outputOffset element offset into [output] where the row starts. + */ + public fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) +} diff --git a/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q6KMatmulKernel.kt b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q6KMatmulKernel.kt new file mode 100644 index 00000000..8d05d3fd --- /dev/null +++ b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/Q6KMatmulKernel.kt @@ -0,0 +1,37 @@ +package sk.ainet.backend.api.kernel + +/** + * F32 input × Q6_K-packed weights matrix-vector multiply, in canonical + * ggml block layout. + * + * output[outputOffset + o] = Σ_j input[inputOffset + j] · dequant(weight[o, j]) + * for j ∈ [0, inputDim), o ∈ [0, outputDim) + * + * Q6_K super-block layout (256 elements, 210 bytes/block; see + * [sk.ainet.lang.tensor.data.Q6_KBlockTensorData]): + * - bytes 0..127 : `ql[0..127]` (lower 4 bits of each code) + * - bytes 128..191 : `qh[0..63]` (upper 2 bits of each code) + * - bytes 192..207 : `scales[0..15]`(int8 per-16-element sub-block scales) + * - bytes 208..209 : `d` (super-block scale, FP16 LE) + * + * The 6-bit signed code is reassembled from `ql`/`qh` (see ggml + * `dequantize_row_q6_K`); per element `dequant = d * scales[sub] * (code - 32)`. + * Matches `sk.ainet.io.gguf.dequant.DequantOps.dequantQ6KFromBytes` — that is + * the authoritative reference; implementations MUST agree with it. + * + * Implementations MUST NOT mutate `input` or `weight`. They MUST fully + * write the `outputDim` floats starting at `output[outputOffset]`. + * + * Packed-weight **block-major** row contract: blocks laid out + * `(blockIdx * outputDim + o) * 210`. Matches `Q6_KBlockTensorData.packedData`. + * + * `inputDim` MUST be a multiple of 256 (the Q6_K super-block size). + */ +public interface Q6KMatmulKernel { + public fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) +} diff --git a/skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api b/skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api index 39da5c0a..c9a8b256 100644 --- a/skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api +++ b/skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api @@ -54,6 +54,9 @@ public final class sk/ainet/exec/kernel/PanamaVectorKernelProvider : sk/ainet/ba public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel; public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel; public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel; + public fun matmulQ5_0 ()Lsk/ainet/backend/api/kernel/Q5_0MatmulKernel; + public fun matmulQ5_1 ()Lsk/ainet/backend/api/kernel/Q5_1MatmulKernel; + public fun matmulQ6K ()Lsk/ainet/backend/api/kernel/Q6KMatmulKernel; public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel; public fun supports (Ljava/lang/String;Ljava/util/List;)Z } @@ -67,6 +70,9 @@ public final class sk/ainet/exec/kernel/PanamaVectorKernelProviderFactory : sk/a public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel; public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel; public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel; + public fun matmulQ5_0 ()Lsk/ainet/backend/api/kernel/Q5_0MatmulKernel; + public fun matmulQ5_1 ()Lsk/ainet/backend/api/kernel/Q5_1MatmulKernel; + public fun matmulQ6K ()Lsk/ainet/backend/api/kernel/Q6KMatmulKernel; public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel; public fun supports (Ljava/lang/String;Ljava/util/List;)Z } @@ -105,6 +111,9 @@ public final class sk/ainet/exec/kernel/ScalarKernelProvider : sk/ainet/backend/ public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel; public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel; public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel; + public fun matmulQ5_0 ()Lsk/ainet/backend/api/kernel/Q5_0MatmulKernel; + public fun matmulQ5_1 ()Lsk/ainet/backend/api/kernel/Q5_1MatmulKernel; + public fun matmulQ6K ()Lsk/ainet/backend/api/kernel/Q6KMatmulKernel; public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel; public fun supports (Ljava/lang/String;Ljava/util/List;)Z } @@ -118,6 +127,9 @@ public final class sk/ainet/exec/kernel/ScalarKernelProviderFactory : sk/ainet/b public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel; public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel; public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel; + public fun matmulQ5_0 ()Lsk/ainet/backend/api/kernel/Q5_0MatmulKernel; + public fun matmulQ5_1 ()Lsk/ainet/backend/api/kernel/Q5_1MatmulKernel; + public fun matmulQ6K ()Lsk/ainet/backend/api/kernel/Q6KMatmulKernel; public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel; public fun supports (Ljava/lang/String;Ljava/util/List;)Z } @@ -132,6 +144,26 @@ public final class sk/ainet/exec/kernel/ScalarQ4_0MatmulKernel : sk/ainet/backen public fun matmul ([FI[BIII[FI)V } +public final class sk/ainet/exec/kernel/ScalarQ4_KMatmulKernel : sk/ainet/backend/api/kernel/Q4KMatmulKernel { + public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarQ4_KMatmulKernel; + public fun matmul ([FI[BIII[FI)V +} + +public final class sk/ainet/exec/kernel/ScalarQ5_0MatmulKernel : sk/ainet/backend/api/kernel/Q5_0MatmulKernel { + public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarQ5_0MatmulKernel; + public fun matmul ([FI[BIII[FI)V +} + +public final class sk/ainet/exec/kernel/ScalarQ5_1MatmulKernel : sk/ainet/backend/api/kernel/Q5_1MatmulKernel { + public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarQ5_1MatmulKernel; + public fun matmul ([FI[BIII[FI)V +} + +public final class sk/ainet/exec/kernel/ScalarQ6_KMatmulKernel : sk/ainet/backend/api/kernel/Q6KMatmulKernel { + public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarQ6_KMatmulKernel; + public fun matmul ([FI[BIII[FI)V +} + public final class sk/ainet/exec/kernel/ScalarQ8_0MatmulKernel : sk/ainet/backend/api/kernel/Q8_0MatmulKernel { public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarQ8_0MatmulKernel; public fun matmul ([FI[BIII[FI)V diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarHalf.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarHalf.kt new file mode 100644 index 00000000..dd903a0b --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarHalf.kt @@ -0,0 +1,31 @@ +package sk.ainet.exec.kernel + +/** + * Convert a 16-bit IEEE-754 half-precision value (low 16 bits of [hbits]) + * to FP32. Shared by the scalar packed-quant kernels in this package + * (Q5_1/Q5_0/Q4_K/Q6_K). Mirrors the inlined helpers in + * [ScalarQ4_0MatmulKernel] / [ScalarQ8_0MatmulKernel]. + */ +internal fun decodeHalf(hbits: Int): Float { + val sign = (hbits and 0x8000) shl 16 + val exp = (hbits and 0x7C00) shr 10 + val mant = hbits and 0x03FF + return when (exp) { + 0 -> { + if (mant == 0) { + Float.fromBits(sign) + } else { + var m = mant + var e = -14 + while ((m and 0x400) == 0) { + m = m shl 1 + e-- + } + m = m and 0x3FF + Float.fromBits(sign or ((e + 127) shl 23) or (m shl 13)) + } + } + 31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13)) + else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13)) + } +} diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarKernelProvider.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarKernelProvider.kt index a7c13ccd..0611ce76 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarKernelProvider.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarKernelProvider.kt @@ -3,7 +3,11 @@ package sk.ainet.exec.kernel import sk.ainet.backend.api.kernel.Bf16MatmulKernel import sk.ainet.backend.api.kernel.Fp32MatmulKernel import sk.ainet.backend.api.kernel.KernelProvider +import sk.ainet.backend.api.kernel.Q4KMatmulKernel import sk.ainet.backend.api.kernel.Q4_0MatmulKernel +import sk.ainet.backend.api.kernel.Q5_0MatmulKernel +import sk.ainet.backend.api.kernel.Q5_1MatmulKernel +import sk.ainet.backend.api.kernel.Q6KMatmulKernel import sk.ainet.backend.api.kernel.Q8_0MatmulKernel /** @@ -27,4 +31,8 @@ public object ScalarKernelProvider : KernelProvider { override fun matmulBf16(): Bf16MatmulKernel = ScalarBf16MatmulKernel override fun matmulQ8_0(): Q8_0MatmulKernel = ScalarQ8_0MatmulKernel override fun matmulQ4_0(): Q4_0MatmulKernel = ScalarQ4_0MatmulKernel + override fun matmulQ4K(): Q4KMatmulKernel = ScalarQ4_KMatmulKernel + override fun matmulQ6K(): Q6KMatmulKernel = ScalarQ6_KMatmulKernel + override fun matmulQ5_1(): Q5_1MatmulKernel = ScalarQ5_1MatmulKernel + override fun matmulQ5_0(): Q5_0MatmulKernel = ScalarQ5_0MatmulKernel } diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ4_KMatmulKernel.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ4_KMatmulKernel.kt new file mode 100644 index 00000000..df003437 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ4_KMatmulKernel.kt @@ -0,0 +1,82 @@ +package sk.ainet.exec.kernel + +import sk.ainet.backend.api.kernel.Q4KMatmulKernel + +/** + * Scalar reference [Q4KMatmulKernel] — commonMain, so Q4_K packed matmul works + * on Kotlin/Native / JS / WASM, not only the JVM SIMD path. + * + * Q4_K super-block: 256 elements / 144 bytes, block-major `(blockIdx*outputDim+o)*144`: + * `d`(f16) `dMin`(f16) 12 scale bytes (ggml `get_scale_min_k4` packing) 128 code bytes. + * Each of the 8 sub-blocks (32 elts) contributes `codeSum*scale - inputSum*offset`, + * with `scale = d*scaleIdx`, `offset = dMin*minIdx`. Math mirrors + * `JvmQuantizedVectorKernels.matmulQ4_KVec` / `DequantOps.dequantQ4KFromBytes`. + */ +public object ScalarQ4_KMatmulKernel : Q4KMatmulKernel { + + private const val BLOCK_SIZE = 256 + private const val SUB_BLOCK = 32 + private const val BYTES_PER_BLOCK = 144 + + override fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) { + require(inputDim % BLOCK_SIZE == 0) { + "ScalarQ4_KMatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim" + } + if (outputDim == 0) return + if (inputDim == 0) { for (o in 0 until outputDim) output[outputOffset + o] = 0f; return } + val blocksPerInputDim = inputDim / BLOCK_SIZE + val scaleIdx = IntArray(8) + val minIdx = IntArray(8) + + for (o in 0 until outputDim) { + var acc = 0f + for (blockIdx in 0 until blocksPerInputDim) { + val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK + val d = decodeHalf(((weight[blockBase + 1].toInt() and 0xFF) shl 8) or (weight[blockBase].toInt() and 0xFF)) + val dMin = decodeHalf(((weight[blockBase + 3].toInt() and 0xFF) shl 8) or (weight[blockBase + 2].toInt() and 0xFF)) + + // ggml get_scale_min_k4 over the 12 scale bytes. + val sc = blockBase + 4 + for (sb in 0 until 4) { + scaleIdx[sb] = weight[sc + sb].toInt() and 0x3F + minIdx[sb] = weight[sc + sb + 4].toInt() and 0x3F + } + for (sb in 4 until 8) { + val low4S = weight[sc + sb + 4].toInt() and 0x0F + val high2S = (weight[sc + sb - 4].toInt() and 0xFF) ushr 6 + scaleIdx[sb] = low4S or (high2S shl 4) + val low4M = (weight[sc + sb + 4].toInt() and 0xFF) ushr 4 + val high2M = (weight[sc + sb].toInt() and 0xFF) ushr 6 + minIdx[sb] = low4M or (high2M shl 4) + } + + val codesOffset = blockBase + 16 + val inBlockBase = inputOffset + blockIdx * BLOCK_SIZE + for (groupJ in 0 until 4) { + val qsRegion = codesOffset + groupJ * 32 + // sub-block lo (low nibbles) then hi (high nibbles) of the same 32 bytes. + for (half in 0 until 2) { + val sb = 2 * groupJ + half + val inStart = inBlockBase + sb * SUB_BLOCK + var codeSum = 0f + var inputSum = 0f + for (i in 0 until 32) { + val b = weight[qsRegion + i].toInt() and 0xFF + val code = if (half == 0) (b and 0x0F) else (b ushr 4) + val v = input[inStart + i] + codeSum += v * code + inputSum += v + } + acc += codeSum * (d * scaleIdx[sb]) - inputSum * (dMin * minIdx[sb]) + } + } + } + output[outputOffset + o] = acc + } + } +} diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ5_0MatmulKernel.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ5_0MatmulKernel.kt new file mode 100644 index 00000000..56361286 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ5_0MatmulKernel.kt @@ -0,0 +1,60 @@ +package sk.ainet.exec.kernel + +import sk.ainet.backend.api.kernel.Q5_0MatmulKernel + +/** + * Scalar reference [Q5_0MatmulKernel] — per-block dequant + per-element FMA, + * no SIMD. commonMain, so available on every KMP target. + * + * Block layout (32-elt, 22 B; block-major `(blockIdx*outputDim+o)*22`): + * `d`(f16) `qh[4]` `qs[16]`. Dequant matches `DequantOps.dequantQ5_0FromBytes`: + * `w = d*(code + (highBit shl 4) - 16)` (symmetric, no per-block min). + */ +public object ScalarQ5_0MatmulKernel : Q5_0MatmulKernel { + + private const val BLOCK_SIZE = 32 + private const val BYTES_PER_BLOCK = 22 + + override fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) { + require(inputDim % BLOCK_SIZE == 0) { + "ScalarQ5_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim" + } + if (outputDim == 0) return + if (inputDim == 0) { for (o in 0 until outputDim) output[outputOffset + o] = 0f; return } + val blocksPerInputDim = inputDim / BLOCK_SIZE + + for (o in 0 until outputDim) { + var acc = 0f + for (blockIdx in 0 until blocksPerInputDim) { + val base = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK + val d = decodeHalf(((weight[base + 1].toInt() and 0xFF) shl 8) or (weight[base].toInt() and 0xFF)) + val qh0 = weight[base + 2].toInt() and 0xFF + val qh1 = weight[base + 3].toInt() and 0xFF + val qh2 = weight[base + 4].toInt() and 0xFF + val qh3 = weight[base + 5].toInt() and 0xFF + val qsBase = base + 6 + val inputBase = inputOffset + blockIdx * BLOCK_SIZE + for (j in 0 until 16) { + val q = weight[qsBase + j].toInt() and 0xFF + val lo = q and 0x0F + val hi = q ushr 4 + val bitLo = (qh(qh0, qh1, qh2, qh3, j) ushr (j % 8)) and 0x01 + val jh = j + 16 + val bitHi = (qh(qh0, qh1, qh2, qh3, jh) ushr (jh % 8)) and 0x01 + val wLo = d * (lo + (bitLo shl 4) - 16) + val wHi = d * (hi + (bitHi shl 4) - 16) + acc += input[inputBase + j] * wLo + input[inputBase + 16 + j] * wHi + } + } + output[outputOffset + o] = acc + } + } + + private inline fun qh(q0: Int, q1: Int, q2: Int, q3: Int, idx: Int): Int = + when (idx / 8) { 0 -> q0; 1 -> q1; 2 -> q2; else -> q3 } +} diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ5_1MatmulKernel.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ5_1MatmulKernel.kt new file mode 100644 index 00000000..80d08bfa --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ5_1MatmulKernel.kt @@ -0,0 +1,62 @@ +package sk.ainet.exec.kernel + +import sk.ainet.backend.api.kernel.Q5_1MatmulKernel + +/** + * Scalar reference [Q5_1MatmulKernel] — per-block dequant + per-element FMA, + * no SIMD. Always available on every KMP target (commonMain), so Q5_1 packed + * matmul works on Kotlin/Native, JS and WASM, not only the JVM. + * + * Block layout (32-elt, 24 B; block-major `(blockIdx*outputDim+o)*24`): + * `d`(f16) `m`(f16) `qh[4]` `qs[16]`. Dequant matches + * `DequantOps.dequantQ5_1FromBytes`: `w = d*(code + (highBit shl 4)) + m`. + */ +public object ScalarQ5_1MatmulKernel : Q5_1MatmulKernel { + + private const val BLOCK_SIZE = 32 + private const val BYTES_PER_BLOCK = 24 + + override fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) { + require(inputDim % BLOCK_SIZE == 0) { + "ScalarQ5_1MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim" + } + if (outputDim == 0) return + if (inputDim == 0) { for (o in 0 until outputDim) output[outputOffset + o] = 0f; return } + val blocksPerInputDim = inputDim / BLOCK_SIZE + + for (o in 0 until outputDim) { + var acc = 0f + for (blockIdx in 0 until blocksPerInputDim) { + val base = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK + val d = decodeHalf(((weight[base + 1].toInt() and 0xFF) shl 8) or (weight[base].toInt() and 0xFF)) + val m = decodeHalf(((weight[base + 3].toInt() and 0xFF) shl 8) or (weight[base + 2].toInt() and 0xFF)) + val qh0 = weight[base + 4].toInt() and 0xFF + val qh1 = weight[base + 5].toInt() and 0xFF + val qh2 = weight[base + 6].toInt() and 0xFF + val qh3 = weight[base + 7].toInt() and 0xFF + val qsBase = base + 8 + val inputBase = inputOffset + blockIdx * BLOCK_SIZE + for (j in 0 until 16) { + val q = weight[qsBase + j].toInt() and 0xFF + val lo = q and 0x0F + val hi = q ushr 4 + val bitLo = (qh(qh0, qh1, qh2, qh3, j) ushr (j % 8)) and 0x01 + val jh = j + 16 + val bitHi = (qh(qh0, qh1, qh2, qh3, jh) ushr (jh % 8)) and 0x01 + val wLo = d * (lo + (bitLo shl 4)) + m + val wHi = d * (hi + (bitHi shl 4)) + m + acc += input[inputBase + j] * wLo + input[inputBase + 16 + j] * wHi + } + } + output[outputOffset + o] = acc + } + } + + private inline fun qh(q0: Int, q1: Int, q2: Int, q3: Int, idx: Int): Int = + when (idx / 8) { 0 -> q0; 1 -> q1; 2 -> q2; else -> q3 } +} diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ6_KMatmulKernel.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ6_KMatmulKernel.kt new file mode 100644 index 00000000..e736a4a5 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarQ6_KMatmulKernel.kt @@ -0,0 +1,80 @@ +package sk.ainet.exec.kernel + +import sk.ainet.backend.api.kernel.Q6KMatmulKernel + +/** + * Scalar reference [Q6KMatmulKernel] — commonMain, so Q6_K packed matmul works + * on Kotlin/Native / JS / WASM, not only the JVM SIMD path. + * + * Q6_K super-block: 256 elements / 210 bytes, block-major `(blockIdx*outputDim+o)*210`: + * `ql[128]` (low 4 bits) `qh[64]` (high 2 bits) `scales[16]` (int8) `d`(f16). + * Each block is dequantized to 256 floats (matching the scalar path of + * `JvmQuantizedVectorKernels.dequantQ6_KBlock` / `DequantOps.dequantQ6KFromBytes`) + * and dotted with the matching input window. + */ +public object ScalarQ6_KMatmulKernel : Q6KMatmulKernel { + + private const val BLOCK_SIZE = 256 + private const val BYTES_PER_BLOCK = 210 + + override fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) { + require(inputDim % BLOCK_SIZE == 0) { + "ScalarQ6_KMatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim" + } + if (outputDim == 0) return + if (inputDim == 0) { for (o in 0 until outputDim) output[outputOffset + o] = 0f; return } + val blocksPerInputDim = inputDim / BLOCK_SIZE + val scratch = FloatArray(BLOCK_SIZE) + + for (o in 0 until outputDim) { + var acc = 0f + for (blockIdx in 0 until blocksPerInputDim) { + val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK + dequantBlock(weight, blockBase, scratch) + val inBase = inputOffset + blockIdx * BLOCK_SIZE + for (i in 0 until BLOCK_SIZE) acc += input[inBase + i] * scratch[i] + } + output[outputOffset + o] = acc + } + } + + /** Dequant one 256-element Q6_K block into [scratch]. Mirrors the scalar path of ggml `dequantize_row_q6_K`. */ + private fun dequantBlock(w: ByteArray, blockBase: Int, scratch: FloatArray) { + val qlBase0 = blockBase + val qhBase0 = blockBase + 128 + val scBase0 = blockBase + 192 + val d = decodeHalf(((w[blockBase + 209].toInt() and 0xFF) shl 8) or (w[blockBase + 208].toInt() and 0xFF)) + + for (half in 0..1) { + val qlBase = qlBase0 + half * 64 + val qhBase = qhBase0 + half * 32 + val scBase = scBase0 + half * 8 + val outBase = half * 128 + for (isIdx in 0..1) { + val sc1 = d * w[scBase + isIdx + 0].toInt() + val sc2 = d * w[scBase + isIdx + 2].toInt() + val sc3 = d * w[scBase + isIdx + 4].toInt() + val sc4 = d * w[scBase + isIdx + 6].toInt() + val lStart = isIdx * 16 + for (l in lStart until lStart + 16) { + val ql0 = w[qlBase + l].toInt() and 0xFF + val ql32 = w[qlBase + l + 32].toInt() and 0xFF + val qhL = w[qhBase + l].toInt() and 0xFF + val q1 = ((ql0 and 0x0F) or ((qhL and 0x03) shl 4)) - 32 + val q2 = ((ql32 and 0x0F) or (((qhL ushr 2) and 0x03) shl 4)) - 32 + val q3 = ((ql0 ushr 4) or (((qhL ushr 4) and 0x03) shl 4)) - 32 + val q4 = ((ql32 ushr 4) or (((qhL ushr 6) and 0x03) shl 4)) - 32 + scratch[outBase + l + 0] = sc1 * q1 + scratch[outBase + l + 32] = sc2 * q2 + scratch[outBase + l + 64] = sc3 * q3 + scratch[outBase + l + 96] = sc4 * q4 + } + } + } + } +} diff --git a/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/kernel/ScalarPackedKernelParityTest.kt b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/kernel/ScalarPackedKernelParityTest.kt new file mode 100644 index 00000000..4625e7e0 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/kernel/ScalarPackedKernelParityTest.kt @@ -0,0 +1,148 @@ +package sk.ainet.exec.kernel + +import kotlin.math.abs +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertTrue + +/** + * Verifies the commonMain scalar packed-quant matmul kernels (Q5_1/Q5_0/Q4_K/Q6_K) + * against an independent inline dequant: for block-major packed bytes, the kernel's + * `Σ input·dequant` must match a reference that dequantizes the same bytes and does the + * dot in FP32. Runs on every target the module is compiled for (JVM + linuxX64), proving + * Native packed-matmul correctness, not only JVM. + */ +class ScalarPackedKernelParityTest { + + private fun half(v: Float): Int { + val bits = v.toRawBits() + val sign = (bits ushr 16) and 0x8000 + val expo = ((bits ushr 23) and 0xFF) - 127 + 15 + val mant = bits and 0x7FFFFF + if (expo <= 0) return sign + if (expo >= 31) return sign or 0x7C00 + return sign or (expo shl 10) or (mant ushr 13) + } + + private fun le16(b: ByteArray, off: Int, h: Int) { + b[off] = (h and 0xFF).toByte(); b[off + 1] = ((h ushr 8) and 0xFF).toByte() + } + + /** Build random block-major bytes + the row-major [out,in] FP32 weight they dequantize to. */ + private fun build( + fmt: String, inputDim: Int, outputDim: Int, rng: Random, + ): Pair { + val bs = if (fmt == "Q4_K" || fmt == "Q6_K") 256 else 32 + val bpb = when (fmt) { "Q5_1" -> 24; "Q5_0" -> 22; "Q4_K" -> 144; else -> 210 } + val blocks = inputDim / bs + val bytes = ByteArray(outputDim * blocks * bpb) + val wf = FloatArray(outputDim * inputDim) + for (o in 0 until outputDim) for (bI in 0 until blocks) { + val off = (bI * outputDim + o) * bpb + val dst = o * inputDim + bI * bs + when (fmt) { + "Q5_1" -> blkQ5_1(bytes, off, wf, dst, rng) + "Q5_0" -> blkQ5_0(bytes, off, wf, dst, rng) + "Q4_K" -> blkQ4_K(bytes, off, wf, dst, rng) + "Q6_K" -> blkQ6_K(bytes, off, wf, dst, rng) + } + } + return bytes to wf + } + + private fun blkQ5_1(b: ByteArray, off: Int, wf: FloatArray, dst: Int, rng: Random) { + val d = (rng.nextFloat() * 0.05f + 0.01f); val m = (rng.nextFloat() - 0.5f) + le16(b, off, half(d)); le16(b, off + 2, half(m)) + val qh = IntArray(4) { rng.nextInt(256) }; for (k in 0 until 4) b[off + 4 + k] = qh[k].toByte() + for (k in 0 until 16) b[off + 8 + k] = rng.nextInt(256).toByte() + for (j in 0 until 16) { + val q = b[off + 8 + j].toInt() and 0xFF; val lo = q and 0xF; val hi = q ushr 4 + val bl = (qh[j / 8] ushr (j % 8)) and 1; val bh = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 1 + wf[dst + j] = d * (lo + (bl shl 4)) + m; wf[dst + 16 + j] = d * (hi + (bh shl 4)) + m + } + } + + private fun blkQ5_0(b: ByteArray, off: Int, wf: FloatArray, dst: Int, rng: Random) { + val d = (rng.nextFloat() * 0.05f + 0.01f); le16(b, off, half(d)) + val qh = IntArray(4) { rng.nextInt(256) }; for (k in 0 until 4) b[off + 2 + k] = qh[k].toByte() + for (k in 0 until 16) b[off + 6 + k] = rng.nextInt(256).toByte() + for (j in 0 until 16) { + val q = b[off + 6 + j].toInt() and 0xFF; val lo = q and 0xF; val hi = q ushr 4 + val bl = (qh[j / 8] ushr (j % 8)) and 1; val bh = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 1 + wf[dst + j] = d * (lo + (bl shl 4) - 16); wf[dst + 16 + j] = d * (hi + (bh shl 4) - 16) + } + } + + private fun blkQ4_K(b: ByteArray, off: Int, wf: FloatArray, dst: Int, rng: Random) { + val d = rng.nextFloat() * 0.02f + 0.005f; val dMin = rng.nextFloat() * 0.02f + 0.005f + le16(b, off, half(d)); le16(b, off + 2, half(dMin)) + for (k in 0 until 140) b[off + 4 + k] = rng.nextInt(256).toByte() // 12 scales + 128 codes + val sc = off + 4; val scaleIdx = IntArray(8); val minIdx = IntArray(8) + for (s in 0 until 4) { scaleIdx[s] = b[sc + s].toInt() and 0x3F; minIdx[s] = b[sc + s + 4].toInt() and 0x3F } + for (s in 4 until 8) { + scaleIdx[s] = (b[sc + s + 4].toInt() and 0x0F) or (((b[sc + s - 4].toInt() and 0xFF) ushr 6) shl 4) + minIdx[s] = ((b[sc + s + 4].toInt() and 0xFF) ushr 4) or (((b[sc + s].toInt() and 0xFF) ushr 6) shl 4) + } + val codes = off + 16 + for (g in 0 until 4) for (half in 0 until 2) { + val s = 2 * g + half; val scale = d * scaleIdx[s]; val offs = dMin * minIdx[s] + for (i in 0 until 32) { + val by = b[codes + g * 32 + i].toInt() and 0xFF + val code = if (half == 0) (by and 0x0F) else (by ushr 4) + wf[dst + s * 32 + i] = code * scale - offs + } + } + } + + private fun blkQ6_K(b: ByteArray, off: Int, wf: FloatArray, dst: Int, rng: Random) { + for (k in 0 until 208) b[off + k] = rng.nextInt(256).toByte() // ql+qh+scales + val d = rng.nextFloat() * 0.01f + 0.002f; le16(b, off + 208, half(d)) + for (h in 0..1) { + val qlB = off + h * 64; val qhB = off + 128 + h * 32; val scB = off + 192 + h * 8; val ob = h * 128 + for (isIdx in 0..1) { + val sc1 = d * b[scB + isIdx].toInt(); val sc2 = d * b[scB + isIdx + 2].toInt() + val sc3 = d * b[scB + isIdx + 4].toInt(); val sc4 = d * b[scB + isIdx + 6].toInt() + for (l in isIdx * 16 until isIdx * 16 + 16) { + val ql0 = b[qlB + l].toInt() and 0xFF; val ql32 = b[qlB + l + 32].toInt() and 0xFF + val qhL = b[qhB + l].toInt() and 0xFF + wf[dst + ob + l + 0] = sc1 * (((ql0 and 0xF) or ((qhL and 3) shl 4)) - 32) + wf[dst + ob + l + 32] = sc2 * (((ql32 and 0xF) or (((qhL ushr 2) and 3) shl 4)) - 32) + wf[dst + ob + l + 64] = sc3 * (((ql0 ushr 4) or (((qhL ushr 4) and 3) shl 4)) - 32) + wf[dst + ob + l + 96] = sc4 * (((ql32 ushr 4) or (((qhL ushr 6) and 3) shl 4)) - 32) + } + } + } + } + + private fun check(fmt: String, inputDim: Int, outputDim: Int, seed: Int) { + val rng = Random(seed) + val (bytes, wf) = build(fmt, inputDim, outputDim, rng) + val input = FloatArray(inputDim) { rng.nextFloat() - 0.5f } + val expected = FloatArray(outputDim) { o -> var s = 0f; for (i in 0 until inputDim) s += input[i] * wf[o * inputDim + i]; s } + val actual = FloatArray(outputDim) + when (fmt) { + "Q5_1" -> ScalarQ5_1MatmulKernel.matmul(input, 0, bytes, 0, inputDim, outputDim, actual, 0) + "Q5_0" -> ScalarQ5_0MatmulKernel.matmul(input, 0, bytes, 0, inputDim, outputDim, actual, 0) + "Q4_K" -> ScalarQ4_KMatmulKernel.matmul(input, 0, bytes, 0, inputDim, outputDim, actual, 0) + "Q6_K" -> ScalarQ6_KMatmulKernel.matmul(input, 0, bytes, 0, inputDim, outputDim, actual, 0) + } + var maxErr = 0f + var maxAbs = 1f + for (o in 0 until outputDim) { + maxErr = maxOf(maxErr, abs(expected[o] - actual[o])) + maxAbs = maxOf(maxAbs, abs(expected[o])) + } + // Relative tolerance: the kernel accumulates per sub-block (with a + // codeSum*scale - inputSum*offset cancellation for Q4_K), so it differs + // from the flat reference sum only by FP reassociation (~1e-3 rel). + assertTrue( + maxErr < 5e-3f * maxAbs, + "$fmt scalar kernel deviates from inline dequant: maxErr=$maxErr (maxAbs=$maxAbs)", + ) + } + + @Test fun q5_1() = check("Q5_1", inputDim = 128, outputDim = 16, seed = 1) + @Test fun q5_0() = check("Q5_0", inputDim = 96, outputDim = 24, seed = 2) + @Test fun q4_k() = check("Q4_K", inputDim = 256, outputDim = 12, seed = 3) + @Test fun q6_k() = check("Q6_K", inputDim = 512, outputDim = 8, seed = 4) +}