diff --git a/skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api b/skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api index a953d311..39da5c0a 100644 --- a/skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api +++ b/skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api @@ -81,6 +81,11 @@ public final class sk/ainet/exec/kernel/PanamaVectorQ4KMatmulKernel : sk/ainet/b public fun matmul ([FI[BIII[FI)V } +public final class sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel : sk/ainet/backend/api/kernel/Q4_0MatmulKernel { + public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel; + public fun matmul ([FI[BIII[FI)V +} + public final class sk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel : sk/ainet/backend/api/kernel/Q8_0MatmulKernel { public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel; public fun matmul ([FI[BIII[FI)V diff --git a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorKernelProvider.kt b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorKernelProvider.kt index ba978052..ecc68cf5 100644 --- a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorKernelProvider.kt +++ b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorKernelProvider.kt @@ -4,6 +4,7 @@ import sk.ainet.backend.api.kernel.Bf16MatmulKernel import sk.ainet.backend.api.kernel.Fp32MatmulKernel import sk.ainet.backend.api.kernel.KernelProvider import sk.ainet.backend.api.kernel.Q4KMatmulKernel +import sk.ainet.backend.api.kernel.Q4_0MatmulKernel import sk.ainet.backend.api.kernel.Q8_0MatmulKernel import sk.ainet.exec.tensor.ops.JvmCpuBackendConfig @@ -49,6 +50,9 @@ public object PanamaVectorKernelProvider : KernelProvider { override fun matmulQ8_0(): Q8_0MatmulKernel? = if (isAvailable()) PanamaVectorQ8_0MatmulKernel else null + override fun matmulQ4_0(): Q4_0MatmulKernel? = + if (isAvailable()) PanamaVectorQ4_0MatmulKernel else null + private fun isVectorApiClassLoaded(): Boolean = runCatching { Class.forName("jdk.incubator.vector.FloatVector") Class.forName("jdk.incubator.vector.VectorSpecies") diff --git a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel.kt b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel.kt new file mode 100644 index 00000000..d3ca54b9 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel.kt @@ -0,0 +1,114 @@ +package sk.ainet.exec.kernel + +import jdk.incubator.vector.FloatVector +import jdk.incubator.vector.VectorOperators +import jdk.incubator.vector.VectorSpecies +import sk.ainet.backend.api.kernel.Q4_0MatmulKernel + +/** + * SIMD-vectorized FP32 × Q4_0 matmul on the JDK Vector API. + * + * Pipeline per 32-element block: + * 1. Decode the 2-byte FP16 scale `d` once. + * 2. Unpack the 16 code bytes into 32 sign-corrected floats (`nibble - 8`) + * in a reusable scratch buffer, using the canonical ggml **split** + * layout (low nibbles → elements 0..15, high nibbles → 16..31). The + * nibble-pair-per-byte packing makes a fully-fused `ByteVector` + * pipeline awkward, so this kernel keeps the scratch-then-FMA shape + * (same approach as the legacy `JvmQuantizedVectorKernels` Q4_0 path). + * 3. SIMD-FMA the scratch against the matching input window into a + * lane-wise block accumulator, reduce across lanes, and fold `* d` + * exactly once per block. + * + * Numerical equivalence with [ScalarQ4_0MatmulKernel] is within FMA + + * reordered-reduction tolerance — the same bar the Q8_0 / Q4_K Panama + * kernels use. + */ +public object PanamaVectorQ4_0MatmulKernel : Q4_0MatmulKernel { + + private const val BLOCK_SIZE = 32 + private const val BYTES_PER_BLOCK = 18 + + private val floatSpecies: VectorSpecies = FloatVector.SPECIES_PREFERRED + + override fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) { + require(inputDim % BLOCK_SIZE == 0) { + "PanamaVectorQ4_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim" + } + if (outputDim == 0) return + if (inputDim == 0) { + for (o in 0 until outputDim) output[outputOffset + o] = 0f + return + } + val blocksPerInputDim = inputDim / BLOCK_SIZE + val step = floatSpecies.length() + val loopBound = floatSpecies.loopBound(BLOCK_SIZE) + val codeBuf = FloatArray(BLOCK_SIZE) + + for (o in 0 until outputDim) { + var acc = 0f + for (blockIdx in 0 until blocksPerInputDim) { + val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK + // FP16 scale — two LE bytes. + val dBits = (weight[blockBase].toInt() and 0xFF) or + ((weight[blockBase + 1].toInt() and 0xFF) shl 8) + val d = halfToFloat(dBits) + + // Split-layout unpack: low nibbles → 0..15, high → 16..31. + val codesBase = blockBase + 2 + for (j in 0 until 16) { + val b = weight[codesBase + j].toInt() and 0xFF + codeBuf[j] = ((b and 0x0F) - 8).toFloat() + codeBuf[16 + j] = ((b ushr 4) - 8).toFloat() + } + + val inputBase = inputOffset + blockIdx * BLOCK_SIZE + var blockAccVec = FloatVector.zero(floatSpecies) + var k = 0 + while (k < loopBound) { + val inV = FloatVector.fromArray(floatSpecies, input, inputBase + k) + val cV = FloatVector.fromArray(floatSpecies, codeBuf, k) + blockAccVec = inV.fma(cV, blockAccVec) + k += step + } + var blockAcc = blockAccVec.reduceLanes(VectorOperators.ADD) + // Scalar tail (only if floatSpecies.length() doesn't divide 32 — rare). + while (k < BLOCK_SIZE) { + blockAcc += input[inputBase + k] * codeBuf[k] + k++ + } + acc += blockAcc * d + } + output[outputOffset + o] = acc + } + } + + /** Same FP16 → FP32 conversion as [ScalarQ4_0MatmulKernel]. */ + private fun halfToFloat(hbits: Int): Float { + val sign = (hbits and 0x8000) shl 16 + val exp = (hbits and 0x7C00) shr 10 + val mant = hbits and 0x03FF + return when (exp) { + 0 -> { + if (mant == 0) Float.fromBits(sign) + else { + var m = mant + var e = -14 + while ((m and 0x400) == 0) { + m = m shl 1 + e-- + } + m = m and 0x3FF + Float.fromBits(sign or ((e + 127) shl 23) or (m shl 13)) + } + } + 31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13)) + else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13)) + } + } +} diff --git a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt index 94cb5202..8f726ef6 100644 --- a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt +++ b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt @@ -549,13 +549,14 @@ internal object JvmQuantizedVectorKernels { // Read f16 scale val scale = halfToFloat(read2BytesLE(weightSeg, blockByteOffset)) - // Unpack 16 packed bytes → 32 sign-corrected nibbles. Two - // nibbles per byte load means half the byte traffic of the - // straight scalar dot product. + // Unpack 16 packed bytes → 32 sign-corrected nibbles in the + // canonical ggml *split* layout: low nibbles decode elements + // 0..15, high nibbles decode elements 16..31. (Matches + // DequantOps.dequantQ4_0FromBytes and Q4_0BlockTensorData.) for (k in 0 until 16) { val b = weightSeg.get(JAVA_BYTE_LE, codesOffset + k.toLong()).toInt() and 0xFF - codeBuf[2 * k] = (b and 0x0F).toFloat() - 8f - codeBuf[2 * k + 1] = (b ushr 4).toFloat() - 8f + codeBuf[k] = (b and 0x0F).toFloat() - 8f + codeBuf[16 + k] = (b ushr 4).toFloat() - 8f } // SIMD FMA dot product. diff --git a/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernelParityTest.kt b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernelParityTest.kt new file mode 100644 index 00000000..d45e6c99 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernelParityTest.kt @@ -0,0 +1,113 @@ +package sk.ainet.exec.kernel + +import kotlin.math.abs +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +/** + * Numerical parity tests for [PanamaVectorQ4_0MatmulKernel] against + * [ScalarQ4_0MatmulKernel]. Both kernels apply the same FP16-scale + * decode + `(nibble - 8)` dequant in the canonical ggml split layout; + * differences come from FMA + reordered-reduction order only. + * + * Tolerance scales with the number of Q4_0 blocks processed: `1e-2 * + * blocksPerInputDim`, clamped to a `1e-2` floor — mirrors the Q8_0 + * parity test convention. + */ +class PanamaVectorQ4_0MatmulKernelParityTest { + + private val blockSize = 32 + private val bytesPerBlock = 18 + + /** Random Q4_0 packed bytes; scales clamped to a small positive FP16. */ + private fun randomQ4_0Bytes(blocksPerInputDim: Int, outputDim: Int, seed: Int): ByteArray { + val rng = Random(seed) + val numBlocks = blocksPerInputDim * outputDim + val bytes = ByteArray(numBlocks * bytesPerBlock) + rng.nextBytes(bytes) + for (block in 0 until numBlocks) { + val base = block * bytesPerBlock + bytes[base + 0] = 0x00.toByte() + bytes[base + 1] = 0x22.toByte() // FP16 0x2200 ≈ 7.6e-3 + } + return bytes + } + + private fun assertParity( + inputDim: Int, + outputDim: Int, + seed: Int, + tolPerBlock: Float = 1e-2f, + ) { + val blocksPerInputDim = inputDim / blockSize + val rng = Random(seed) + val input = FloatArray(inputDim) { rng.nextFloat() - 0.5f } + val weight = randomQ4_0Bytes(blocksPerInputDim, outputDim, seed) + val outScalar = FloatArray(outputDim) + val outPanama = FloatArray(outputDim) + + ScalarQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, outScalar, 0) + PanamaVectorQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, outPanama, 0) + + val tol = (tolPerBlock * blocksPerInputDim.coerceAtLeast(1)).coerceAtLeast(tolPerBlock) + for (i in outScalar.indices) { + val diff = abs(outScalar[i] - outPanama[i]) + assertTrue( + diff <= tol, + "mismatch at $i: scalar=${outScalar[i]} panama=${outPanama[i]} diff=$diff tol=$tol", + ) + } + } + + @Test fun single_block_single_output_matches_scalar() = + assertParity(inputDim = 32, outputDim = 1, seed = 1) + + @Test fun single_block_multiple_outputs_matches_scalar() = + assertParity(inputDim = 32, outputDim = 7, seed = 2) + + @Test fun multiple_blocks_single_output_matches_scalar() = + assertParity(inputDim = 256, outputDim = 1, seed = 3) + + @Test fun llm_typical_attention_proj_matches_scalar() = + assertParity(inputDim = 512, outputDim = 512, seed = 4) + + @Test fun llm_typical_ffn_proj_matches_scalar() = + assertParity(inputDim = 256, outputDim = 1024, seed = 5) + + @Test fun rejects_non_block_aligned_input_dim() { + assertFailsWith { + PanamaVectorQ4_0MatmulKernel.matmul( + FloatArray(31), 0, + ByteArray(bytesPerBlock), 0, + 31, 1, + FloatArray(1), 0, + ) + } + } + + @Test fun zero_input_dim_zeros_output() { + val out = FloatArray(5) { 9f } + PanamaVectorQ4_0MatmulKernel.matmul( + FloatArray(0), 0, + ByteArray(0), 0, + 0, 5, + out, 0, + ) + for (v in out) assertEquals(0f, v, "output should be zeroed for inputDim=0") + } + + @Test fun provider_returns_panama_q4_0_when_available() { + val kernel = PanamaVectorKernelProvider.matmulQ4_0() + if (PanamaVectorKernelProvider.isAvailable()) { + assertTrue( + kernel === PanamaVectorQ4_0MatmulKernel, + "Provider must hand out the Panama Q4_0 kernel when available", + ) + } else { + assertEquals(null, kernel, "Provider must return null when Vector API unavailable") + } + } +} diff --git a/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/QuantizedMemSegMatmulTest.kt b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/QuantizedMemSegMatmulTest.kt index 38f5593e..30c3fd07 100644 --- a/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/QuantizedMemSegMatmulTest.kt +++ b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/QuantizedMemSegMatmulTest.kt @@ -47,6 +47,8 @@ class QuantizedMemSegMatmulTest { /** * Encode a single Q4_0 block: 32 float values -> 18 bytes (2 scale + 16 packed nibbles). + * Uses the canonical ggml *split* layout: code[j] is the low nibble of + * byte j, code[j+16] is the high nibble of byte j. */ private fun encodeQ4_0Block(values: FloatArray): ByteArray { require(values.size == 32) @@ -62,8 +64,8 @@ class QuantizedMemSegMatmulTest { val out = ByteArray(18) out[0] = (scaleHalf and 0xFF).toByte() out[1] = ((scaleHalf shr 8) and 0xFF).toByte() - for (i in 0 until 16) { - out[2 + i] = ((codes[2 * i + 1] shl 4) or codes[2 * i]).toByte() + for (j in 0 until 16) { + out[2 + j] = ((codes[j + 16] shl 4) or codes[j]).toByte() } return out } diff --git a/skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt b/skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt index a6655e67..ade06d41 100644 --- a/skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt +++ b/skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt @@ -15,6 +15,7 @@ add_library(skainet_kernels SHARED src/fp32_matmul.c src/bf16_matmul.c src/q8_0_matmul.c + src/q4_0_matmul.c ) target_include_directories(skainet_kernels PUBLIC diff --git a/skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h b/skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h index caadf814..a0fa3ff7 100644 --- a/skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h +++ b/skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h @@ -119,6 +119,32 @@ SKAINET_API void skainet_q8_0_matmul( int32_t output_offset ); +/* + * Q4_0 matrix-vector multiply. + * + * output[output_offset + o] = sum_j input[input_offset + j] * + * dequant(weight[block, o, j]) + * + * Block layout: canonical ggml Q4_0, 32 elements per block, 18 bytes + * per block (2 B FP16 scale + 16 B packed 4-bit codes in split layout — + * low nibbles → elements 0..15, high nibbles → 16..31), with packed + * weights laid out as + * weight + weight_byte_offset + (block_idx * output_dim + o) * 18 + * + * Dequant per element: `(code - 8) * d`. input_dim must be a multiple + * of 32. + */ +SKAINET_API void skainet_q4_0_matmul( + const float* input, + int32_t input_offset, + const uint8_t* weight, + int32_t weight_byte_offset, + int32_t input_dim, + int32_t output_dim, + float* output, + int32_t output_offset +); + #ifdef __cplusplus } #endif diff --git a/skainet-backends/skainet-backend-native-cpu/native/src/q4_0_matmul.c b/skainet-backends/skainet-backend-native-cpu/native/src/q4_0_matmul.c new file mode 100644 index 00000000..97111ccf --- /dev/null +++ b/skainet-backends/skainet-backend-native-cpu/native/src/q4_0_matmul.c @@ -0,0 +1,94 @@ +#include "skainet_kernels.h" + +#include +#include +#include + +/* + * Native FP32 × Q4_0 matrix-vector matmul matching the + * sk.ainet.backend.api.kernel.Q4_0MatmulKernel SPI. + * + * Block layout (canonical ggml Q4_0, 32 elements, 18 bytes): + * - bytes 0..1 : FP16 little-endian scale `d` + * - bytes 2..17 : 16 bytes packing 32 4-bit codes in the *split* + * layout — low nibbles decode elements 0..15, high nibbles decode + * elements 16..31. + * + * Per-block packed weight layout: + * weight + weight_byte_offset + (block_idx * output_dim + o) * 18 + * + * Dequant per element: `(code - 8) * d`. The `- 8` bias centres the + * unsigned 4-bit code. Scale `d` is folded once after the block + * accumulator (cheaper than broadcasting it across every inner FMA). + */ + +/* Portable FP16 → FP32 conversion. Matches the Kotlin + * `Q4_0BlockTensorData.halfToFloat` algorithm bit-for-bit. */ +static inline float skainet_q4_0_fp16_to_fp32(uint16_t h) { + uint32_t sign = ((uint32_t)(h & 0x8000u)) << 16; + uint32_t exp = (h >> 10) & 0x1Fu; + uint32_t mant = h & 0x3FFu; + uint32_t bits; + if (exp == 0) { + if (mant == 0) { + bits = sign; + } else { + int e = -14; + while ((mant & 0x400u) == 0) { + mant <<= 1; + --e; + } + mant &= 0x3FFu; + bits = sign | ((uint32_t)(e + 127) << 23) | (mant << 13); + } + } else if (exp == 0x1Fu) { + bits = sign | 0x7F800000u | (mant << 13); + } else { + bits = sign | ((uint32_t)(exp - 15 + 127) << 23) | (mant << 13); + } + float r; + memcpy(&r, &bits, sizeof(r)); + return r; +} + +SKAINET_API void skainet_q4_0_matmul( + const float* SKAINET_RESTRICT input, int32_t input_offset, + const uint8_t* SKAINET_RESTRICT weight, int32_t weight_byte_offset, + int32_t input_dim, int32_t output_dim, + float* SKAINET_RESTRICT output, int32_t output_offset +) { + if (output_dim <= 0) return; + if (input_dim <= 0) { + for (int32_t o = 0; o < output_dim; ++o) { + output[output_offset + o] = 0.0f; + } + return; + } + + const int32_t BLOCK_SIZE = 32; + const int32_t BYTES_PER_BLOCK = 18; + const int32_t blocks_per_input_dim = input_dim / BLOCK_SIZE; + + for (int32_t o = 0; o < output_dim; ++o) { + float acc = 0.0f; + for (int32_t block_idx = 0; block_idx < blocks_per_input_dim; ++block_idx) { + const uint8_t* SKAINET_RESTRICT block = + weight + weight_byte_offset + + (size_t)(block_idx * output_dim + o) * BYTES_PER_BLOCK; + uint16_t d_bits = (uint16_t) block[0] | ((uint16_t) block[1] << 8); + float d = skainet_q4_0_fp16_to_fp32(d_bits); + const uint8_t* SKAINET_RESTRICT codes = block + 2; + const float* SKAINET_RESTRICT input_block = + input + input_offset + (size_t) block_idx * BLOCK_SIZE; + float block_sum = 0.0f; + for (int32_t k = 0; k < 16; ++k) { + int32_t lo = (int32_t)(codes[k] & 0x0F) - 8; + int32_t hi = (int32_t)(codes[k] >> 4) - 8; + block_sum += input_block[k] * (float) lo; + block_sum += input_block[k + 16] * (float) hi; + } + acc += block_sum * d; + } + output[output_offset + o] = acc; + } +} diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt index becb0393..60dd45e2 100644 --- a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt +++ b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt @@ -6,6 +6,7 @@ import sk.ainet.backend.api.kernel.KernelProvider import sk.ainet.backend.api.kernel.MemSegKernelProvider import sk.ainet.backend.api.kernel.Q4KMatmulKernel import sk.ainet.backend.api.kernel.Q4KMemSegMatmulKernel +import sk.ainet.backend.api.kernel.Q4_0MatmulKernel import sk.ainet.backend.api.kernel.Q8_0MatmulKernel /** @@ -93,4 +94,7 @@ public object NativeKernelProvider : KernelProvider, MemSegKernelProvider { override fun matmulQ8_0(): Q8_0MatmulKernel? = if (NativeQ8_0MatmulKernel.isAvailable()) NativeQ8_0MatmulKernel else null + + override fun matmulQ4_0(): Q4_0MatmulKernel? = + if (NativeQ4_0MatmulKernel.isAvailable()) NativeQ4_0MatmulKernel else null } diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ4_0MatmulKernel.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ4_0MatmulKernel.kt new file mode 100644 index 00000000..718b4917 --- /dev/null +++ b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ4_0MatmulKernel.kt @@ -0,0 +1,103 @@ +package sk.ainet.exec.kernel + +import java.lang.foreign.Arena +import java.lang.foreign.FunctionDescriptor +import java.lang.foreign.Linker +import java.lang.foreign.MemorySegment +import java.lang.foreign.ValueLayout +import java.lang.invoke.MethodHandle +import sk.ainet.backend.api.kernel.Q4_0MatmulKernel + +/** + * Native (FFM) implementation of [Q4_0MatmulKernel]. + * + * Wraps the bundled C symbol + * + * void skainet_q4_0_matmul( + * const float* input, int32_t input_offset, + * const uint8_t* weight, int32_t weight_byte_offset, + * int32_t input_dim, int32_t output_dim, + * float* output, int32_t output_offset); + * + * The C kernel decodes the ggml-canonical Q4_0 block (FP16 scale + 16 + * packed bytes, split nibble layout) with `(code - 8) * d` dequant and a + * tight inner FMA the compiler auto-vectorizes under -O3 -ffast-math. + * + * Numerical parity vs [ScalarQ4_0MatmulKernel] is asserted by + * `NativeQ4_0MatmulKernelParityTest` within the same `1e-2 * + * blocksPerInputDim` band the Panama parity uses. + */ +internal object NativeQ4_0MatmulKernel : Q4_0MatmulKernel { + + fun isAvailable(): Boolean = handle != null + + override fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) { + require(inputDim % BLOCK_SIZE == 0) { + "NativeQ4_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim" + } + if (outputDim == 0) return + + val mh = handle + ?: error("NativeQ4_0MatmulKernel.matmul invoked while native library unavailable") + + val blocksPerInputDim = inputDim / BLOCK_SIZE + val inputReachFloats = if (inputDim == 0) 0 else inputOffset + inputDim + val weightReachBytes = if (inputDim == 0 || outputDim == 0) 0 + else weightByteOffset + blocksPerInputDim * outputDim * BYTES_PER_BLOCK + val outputReachFloats = outputOffset + outputDim + + Arena.ofConfined().use { arena -> + val fAlign = ValueLayout.JAVA_FLOAT.byteAlignment() + val bAlign = ValueLayout.JAVA_BYTE.byteAlignment() + + val inputSeg: MemorySegment = if (inputReachFloats > 0) + arena.allocate(inputReachFloats.toLong() * java.lang.Float.BYTES, fAlign) + else MemorySegment.NULL + val weightSeg: MemorySegment = if (weightReachBytes > 0) + arena.allocate(weightReachBytes.toLong(), bAlign) + else MemorySegment.NULL + val outputSeg: MemorySegment = + arena.allocate(outputReachFloats.toLong() * java.lang.Float.BYTES, fAlign) + + if (inputReachFloats > 0) { + MemorySegment.copy(input, 0, inputSeg, ValueLayout.JAVA_FLOAT, 0L, inputReachFloats) + } + if (weightReachBytes > 0) { + MemorySegment.copy(weight, 0, weightSeg, ValueLayout.JAVA_BYTE, 0L, weightReachBytes) + } + + mh.invoke( + inputSeg, inputOffset, + weightSeg, weightByteOffset, + inputDim, outputDim, + outputSeg, outputOffset, + ) + + MemorySegment.copy(outputSeg, ValueLayout.JAVA_FLOAT, 0L, output, 0, outputReachFloats) + } + } + + private const val BLOCK_SIZE = 32 + private const val BYTES_PER_BLOCK = 18 + + private val handle: MethodHandle? by lazy { + val lookup = NativeLibraryLoader.lookup() ?: return@lazy null + val symbol = lookup.find("skainet_q4_0_matmul").orElse(null) ?: return@lazy null + val descriptor = FunctionDescriptor.ofVoid( + ValueLayout.ADDRESS, // input + ValueLayout.JAVA_INT, // input_offset + ValueLayout.ADDRESS, // weight + ValueLayout.JAVA_INT, // weight_byte_offset + ValueLayout.JAVA_INT, // input_dim + ValueLayout.JAVA_INT, // output_dim + ValueLayout.ADDRESS, // output + ValueLayout.JAVA_INT, // output_offset + ) + runCatching { Linker.nativeLinker().downcallHandle(symbol, descriptor) }.getOrNull() + } +} diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4_0MatmulKernelParityTest.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4_0MatmulKernelParityTest.kt new file mode 100644 index 00000000..e6e1a198 --- /dev/null +++ b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4_0MatmulKernelParityTest.kt @@ -0,0 +1,117 @@ +package sk.ainet.exec.kernel + +import kotlin.math.abs +import kotlin.random.Random +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +/** + * Numerical parity tests for [NativeQ4_0MatmulKernel] against + * [PanamaVectorQ4_0MatmulKernel]. Same FP16 scale decode + split-layout + * `(nibble - 8)` dequant in both kernels; differences come from FMA + + * reordered-reduction only. + * + * Tolerance: `1e-2 * blocksPerInputDim` (matches the Panama / Q8_0 + * parity convention). + */ +class NativeQ4_0MatmulKernelParityTest { + + private val blockSize = 32 + private val bytesPerBlock = 18 + + @BeforeTest + fun checkAvailable() { + assertTrue( + NativeQ4_0MatmulKernel.isAvailable(), + "Native Q4_0 kernel must be available — bundled libskainet_kernels missing or " + + "skainet_q4_0_matmul symbol unresolved", + ) + } + + private fun randomQ4_0Bytes(blocksPerInputDim: Int, outputDim: Int, seed: Int): ByteArray { + val rng = Random(seed) + val numBlocks = blocksPerInputDim * outputDim + val bytes = ByteArray(numBlocks * bytesPerBlock) + rng.nextBytes(bytes) + for (block in 0 until numBlocks) { + val base = block * bytesPerBlock + bytes[base + 0] = 0x00.toByte() + bytes[base + 1] = 0x22.toByte() // FP16 ~ 7.6e-3, comfortably finite + non-zero + } + return bytes + } + + private fun assertParity( + inputDim: Int, + outputDim: Int, + seed: Int, + tolPerBlock: Float = 1e-2f, + ) { + val blocksPerInputDim = inputDim / blockSize + val rng = Random(seed) + val input = FloatArray(inputDim) { rng.nextFloat() - 0.5f } + val weight = randomQ4_0Bytes(blocksPerInputDim, outputDim, seed) + val outPanama = FloatArray(outputDim) + val outNative = FloatArray(outputDim) + + PanamaVectorQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, outPanama, 0) + NativeQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, outNative, 0) + + val tol = (tolPerBlock * blocksPerInputDim.coerceAtLeast(1)).coerceAtLeast(tolPerBlock) + for (i in outPanama.indices) { + val diff = abs(outPanama[i] - outNative[i]) + assertTrue( + diff <= tol, + "mismatch at $i: panama=${outPanama[i]} native=${outNative[i]} diff=$diff tol=$tol", + ) + } + } + + @Test fun single_block_single_output_matches_panama() = + assertParity(inputDim = 32, outputDim = 1, seed = 1) + + @Test fun single_block_multiple_outputs_matches_panama() = + assertParity(inputDim = 32, outputDim = 7, seed = 2) + + @Test fun multiple_blocks_single_output_matches_panama() = + assertParity(inputDim = 256, outputDim = 1, seed = 3) + + @Test fun llm_typical_attention_proj_matches_panama() = + assertParity(inputDim = 512, outputDim = 512, seed = 4) + + @Test fun llm_typical_ffn_proj_matches_panama() = + assertParity(inputDim = 256, outputDim = 1024, seed = 5) + + @Test fun rejects_non_block_aligned_input_dim() { + assertFailsWith { + NativeQ4_0MatmulKernel.matmul( + FloatArray(31), 0, + ByteArray(bytesPerBlock), 0, + 31, 1, + FloatArray(1), 0, + ) + } + } + + @Test fun zero_input_dim_zeros_output() { + val out = FloatArray(5) { 9f } + NativeQ4_0MatmulKernel.matmul( + FloatArray(0), 0, + ByteArray(0), 0, + 0, 5, + out, 0, + ) + for (v in out) assertEquals(0f, v, "output should be zeroed for inputDim=0") + } + + @Test fun provider_returns_native_q4_0_when_available() { + val kernel = NativeKernelProvider.matmulQ4_0() + assertTrue( + kernel === NativeQ4_0MatmulKernel, + "Provider must hand out the native Q4_0 kernel when bundled lib is loaded", + ) + } +} diff --git a/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/data/Q4MemorySegmentTensorData.kt b/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/data/Q4MemorySegmentTensorData.kt index c8617307..e2125903 100644 --- a/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/data/Q4MemorySegmentTensorData.kt +++ b/skainet-lang/skainet-lang-core/src/jvmMain/kotlin/sk/ainet/lang/tensor/data/Q4MemorySegmentTensorData.kt @@ -27,9 +27,11 @@ public interface Q4MemorySegmentMarker : MemorySegmentBackedData { * * Q4_0 block layout (18 bytes per 32 elements): * - 2 bytes: f16 scale (little-endian) - * - 16 bytes: packed 4-bit codes (32 values, 2 per byte) + * - 16 bytes: packed 4-bit codes (32 values) in the canonical ggml + * *split* layout — low nibbles decode elements 0..15, high nibbles + * decode elements 16..31. * - * Dequantization: output[i] = (nibble[i] - 8) * scale + * Dequantization: output[j] = (nibble[j] - 8) * scale * * The segment is arena-managed and 64-byte aligned for SIMD access. */ @@ -52,9 +54,12 @@ public class Q4MemorySegmentTensorData( val flatIndex = calcFlatIndex(indices) val blockIdx = flatIndex / blockSize val elemIdx = flatIndex % blockSize - val codesByteOffset = segmentByteOffset + blockIdx.toLong() * bytesPerBlock + 2 + (elemIdx / 2).toLong() + // Split layout: elements 0..15 are low nibbles of bytes 0..15, + // elements 16..31 are the high nibbles of the same bytes. + val byteInBlock = if (elemIdx < 16) elemIdx else elemIdx - 16 + val codesByteOffset = segmentByteOffset + blockIdx.toLong() * bytesPerBlock + 2 + byteInBlock.toLong() val packedByte = segment.get(JAVA_BYTE, codesByteOffset).toInt() and 0xFF - val code = if (elemIdx % 2 == 0) packedByte and 0x0F else packedByte ushr 4 + val code = if (elemIdx < 16) packedByte and 0x0F else packedByte ushr 4 return code.toByte() } @@ -62,10 +67,11 @@ public class Q4MemorySegmentTensorData( val flatIndex = calcFlatIndex(indices) val blockIdx = flatIndex / blockSize val elemIdx = flatIndex % blockSize - val codesByteOffset = segmentByteOffset + blockIdx.toLong() * bytesPerBlock + 2 + (elemIdx / 2).toLong() + val byteInBlock = if (elemIdx < 16) elemIdx else elemIdx - 16 + val codesByteOffset = segmentByteOffset + blockIdx.toLong() * bytesPerBlock + 2 + byteInBlock.toLong() val currentByte = segment.get(JAVA_BYTE, codesByteOffset).toInt() and 0xFF val newNibble = value.toInt() and 0x0F - val updated = if (elemIdx % 2 == 0) { + val updated = if (elemIdx < 16) { (currentByte and 0xF0) or newNibble } else { (currentByte and 0x0F) or (newNibble shl 4) @@ -83,11 +89,14 @@ public class Q4MemorySegmentTensorData( val scale = halfToFloat((b1 shl 8) or b0) val elemsInBlock = minOf(blockSize, shape.volume - outIdx) for (i in 0 until elemsInBlock) { - val codeOff = blockOff + 2 + (i / 2).toLong() + // Split layout: i<16 → low nibble of byte i; i>=16 → high nibble of byte i-16. + val byteInBlock = if (i < 16) i else i - 16 + val codeOff = blockOff + 2 + byteInBlock.toLong() val packedByte = segment.get(JAVA_BYTE, codeOff).toInt() and 0xFF - val code = if (i % 2 == 0) packedByte and 0x0F else packedByte ushr 4 - result[outIdx++] = (code - 8).toFloat() * scale + val code = if (i < 16) packedByte and 0x0F else packedByte ushr 4 + result[outIdx + i] = (code - 8).toFloat() * scale } + outIdx += elemsInBlock } return result }