From d9f760fa673271bb36676bafb32e42be9118e4f0 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 8 Jun 2026 00:34:25 +0200 Subject: [PATCH] feat(backend-cpu): packed Q5_1 / Q5_0 matmul kernels + lazy transpose MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #708. Adds a packed SIMD-path (currently scalar) matmul for the GGML Q5_1 and Q5_0 quantized formats, mirroring the existing Q4_0/Q8_0/Q4_K/Q6_K chain, so these weights can be consumed packed instead of dequantized to FP32 (avoids the FP32 memory blow-up; e.g. functiongemma-270m's attention/FFN weights are Q5_1). Changes: - TensorEncoding: add Q5_0 (22 B/block) and Q5_1 (24 B/block) data objects. - New Q5_1TensorData / Q5_0TensorData interfaces + Q5_{1,0}BlockTensorData, with dequantizeBlock matching DequantOps.dequantQ5_{1,0}FromBytes exactly (w = d*(code + (highBit<<4)) + m for Q5_1; d*(code + (highBit<<4) - 16) for Q5_0). - JvmQuantizedVectorKernels.matmulQ5_1Vec / matmulQ5_0Vec: row-major [out, in] packing (output row o's `in` weights are in/32 contiguous blocks), so the kernel consumes raw GGUF bytes directly — no block-major re-layout. - DefaultCpuOpsJvm: matmul dispatch branches for Q5_1TensorData / Q5_0TensorData, and lazy transpose branches (pure shape swap, keep packed bytes) so `ops.matmul(x, ops.transpose(W))` runs without a dequant round-trip. Layout note: Q4_K/Q6_K kernels are block-major and need a converter re-layout; Q5_1/Q5_0 are intentionally row-major so the downstream converter case (SKaiNET-transformers#170) just wraps the raw bytes. Tests (Q5MatmulDispatchTest): packed Q5_1/Q5_0 matmul through ops.matmul(x, transpose(W)) matches the FP32-dequant matmul to <1e-3, across single/multi-batch, with the FP32 reference dequantized inline (independent of the data-type code under test). Existing Q8_0/Q4_0/MemSeg/transpose tests stay green. 
Scalar inner loop keeps the weights packed (the memory win); SIMD vectorization of the dequant+dot loop is a follow-up. Co-Authored-By: Claude Opus 4.8 --- .../ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt | 45 ++++++ .../tensor/ops/JvmQuantizedVectorKernels.kt | 94 ++++++++++++ .../exec/tensor/ops/Q5MatmulDispatchTest.kt | 140 ++++++++++++++++++ .../ainet/lang/tensor/data/Q5_0TensorData.kt | 132 +++++++++++++++++ .../ainet/lang/tensor/data/Q5_1TensorData.kt | 140 ++++++++++++++++++ .../lang/tensor/storage/TensorEncoding.kt | 24 +++ 6 files changed, 575 insertions(+) create mode 100644 skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/Q5MatmulDispatchTest.kt create mode 100644 skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q5_0TensorData.kt create mode 100644 skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q5_1TensorData.kt diff --git a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt index b70abfd9..145fdbba 100644 --- a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt +++ b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt @@ -31,6 +31,10 @@ import sk.ainet.lang.tensor.data.Q4_KBlockTensorData import sk.ainet.lang.tensor.data.Q4_KTensorData import sk.ainet.lang.tensor.data.Q6_KBlockTensorData import sk.ainet.lang.tensor.data.Q6_KTensorData +import sk.ainet.lang.tensor.data.Q5_1BlockTensorData +import sk.ainet.lang.tensor.data.Q5_1TensorData +import sk.ainet.lang.tensor.data.Q5_0BlockTensorData +import sk.ainet.lang.tensor.data.Q5_0TensorData import sk.ainet.lang.tensor.data.TensorData import sk.ainet.lang.types.DType import sk.ainet.lang.types.FP16 @@ -224,6 +228,21 @@ internal class DefaultCpuOpsJvm( 
@Suppress("UNCHECKED_CAST") return newTensor(transposed as TensorData, tensor.dtype, tensor) } + // Q5_1 / Q5_0 packed bytes use a row-major `[out, in]` layout that the + // `matmulQ5_1Vec` / `matmulQ5_0Vec` kernels index by output row, so the + // transpose is a pure shape swap: the packed bytes are reused as-is and only + // the shape changes (lets `ops.matmul(x, ops.transpose(W))` run without a + // dequant round-trip). NOTE(review): bytes are not physically transposed — confirm only the packed matmul path reads this result. + if (data is Q5_1TensorData) { + val transposed = Q5_1BlockTensorData(Shape(cols, rows), data.packedData) + @Suppress("UNCHECKED_CAST") + return newTensor(transposed as TensorData, tensor.dtype, tensor) + } + if (data is Q5_0TensorData) { + val transposed = Q5_0BlockTensorData(Shape(cols, rows), data.packedData) + @Suppress("UNCHECKED_CAST") + return newTensor(transposed as TensorData, tensor.dtype, tensor) + } // MemorySegment FP32 fast path: physical transpose via SIMD. // Uses Arena.ofAuto() so the result segment is reclaimed by GC // when the wrapping Tensor is no longer reachable. 
Earlier @@ -558,6 +577,32 @@ internal class DefaultCpuOpsJvm( @Suppress("UNCHECKED_CAST") CpuTensor(outData as TensorData, this, a.dtype) } + is Q5_1TensorData -> { + val outBuffer = FloatArray(batchSize * outputDim) + for (batch in 0 until batchSize) { + val batchInput = if (batchSize == 1) inputBuffer + else inputBuffer.copyOfRange(batch * inputDim, (batch + 1) * inputDim) + JvmQuantizedVectorKernels.matmulQ5_1Vec( + batchInput, bData.packedData, inputDim, outputDim, outBuffer, batch * outputDim, + ) + } + val outData = DenseFloatArrayTensorData(Shape(batchSize, outputDim), outBuffer) + @Suppress("UNCHECKED_CAST") + CpuTensor(outData as TensorData, this, a.dtype) + } + is Q5_0TensorData -> { + val outBuffer = FloatArray(batchSize * outputDim) + for (batch in 0 until batchSize) { + val batchInput = if (batchSize == 1) inputBuffer + else inputBuffer.copyOfRange(batch * inputDim, (batch + 1) * inputDim) + JvmQuantizedVectorKernels.matmulQ5_0Vec( + batchInput, bData.packedData, inputDim, outputDim, outBuffer, batch * outputDim, + ) + } + val outData = DenseFloatArrayTensorData(Shape(batchSize, outputDim), outBuffer) + @Suppress("UNCHECKED_CAST") + CpuTensor(outData as TensorData, this, a.dtype) + } is Q4_KTensorData -> { val outBuffer = FloatArray(batchSize * outputDim) val spiKernel = q4kMatmulKernel diff --git a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt index 8f726ef6..009c188c 100644 --- a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt +++ b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt @@ -909,4 +909,98 @@ internal object JvmQuantizedVectorKernels { output[outputOffset + o] = accVec.reduceLanes(VectorOperators.ADD) + accScalar } } + + /** + * Q5_1 matrix-vector 
multiply: `output = input · Wᵀ` for a packed Q5_1 weight. + * + * Packed weights are in the natural GGUF **row-major** `[outputDim, inputDim]` + * layout: output row `o`'s `inputDim` weights are `inputDim / 32` contiguous + * 24-byte blocks. Dequant matches `DequantOps.dequantQ5_1FromBytes` exactly: + * `w = d * (code + (highBit shl 4)) + m`. Scalar (keeps weights packed — the + * memory win; SIMD vectorization of the inner loop is a follow-up). + */ + fun matmulQ5_1Vec( + input: FloatArray, + packedWeights: ByteArray, + inputDim: Int, + outputDim: Int, + output: FloatArray, + outputOffset: Int = 0, + ) { + val bytesPerBlock = 24 + val blocksPerInputDim = (inputDim + 31) / 32 + for (o in 0 until outputDim) { + var acc = 0f + val rowBase = o * blocksPerInputDim * bytesPerBlock + for (blk in 0 until blocksPerInputDim) { + val base = rowBase + blk * bytesPerBlock + val d = halfToFloat(((packedWeights[base + 1].toInt() and 0xFF) shl 8) or (packedWeights[base].toInt() and 0xFF)) + val m = halfToFloat(((packedWeights[base + 3].toInt() and 0xFF) shl 8) or (packedWeights[base + 2].toInt() and 0xFF)) + val qh = intArrayOf( + packedWeights[base + 4].toInt() and 0xFF, + packedWeights[base + 5].toInt() and 0xFF, + packedWeights[base + 6].toInt() and 0xFF, + packedWeights[base + 7].toInt() and 0xFF, + ) + val qsBase = base + 8 + val inBase = blk * 32 + for (j in 0 until 16) { + val q = packedWeights[qsBase + j].toInt() and 0xFF + val lo = q and 0x0F + val hi = q ushr 4 + val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01 + val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01 + val wLo = d * (lo + (bitLo shl 4)) + m + val wHi = d * (hi + (bitHi shl 4)) + m + acc += input[inBase + j] * wLo + input[inBase + 16 + j] * wHi + } + } + output[outputOffset + o] = acc + } + } + + /** + * Q5_0 matrix-vector multiply: `output = input · Wᵀ` for a packed Q5_0 weight. + * + * Row-major `[outputDim, inputDim]` packing of 22-byte blocks. 
Dequant matches + * `DequantOps.dequantQ5_0FromBytes`: `w = d * (code + (highBit shl 4) - 16)`. + */ + fun matmulQ5_0Vec( + input: FloatArray, + packedWeights: ByteArray, + inputDim: Int, + outputDim: Int, + output: FloatArray, + outputOffset: Int = 0, + ) { + val bytesPerBlock = 22 + val blocksPerInputDim = (inputDim + 31) / 32 + for (o in 0 until outputDim) { + var acc = 0f + val rowBase = o * blocksPerInputDim * bytesPerBlock + for (blk in 0 until blocksPerInputDim) { + val base = rowBase + blk * bytesPerBlock + val d = halfToFloat(((packedWeights[base + 1].toInt() and 0xFF) shl 8) or (packedWeights[base].toInt() and 0xFF)) + val qh = intArrayOf( + packedWeights[base + 2].toInt() and 0xFF, + packedWeights[base + 3].toInt() and 0xFF, + packedWeights[base + 4].toInt() and 0xFF, + packedWeights[base + 5].toInt() and 0xFF, + ) + val qsBase = base + 6 + val inBase = blk * 32 + for (j in 0 until 16) { + val q = packedWeights[qsBase + j].toInt() and 0xFF + val lo = q and 0x0F + val hi = q ushr 4 + val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01 + val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01 + val wLo = d * (lo + (bitLo shl 4) - 16) + val wHi = d * (hi + (bitHi shl 4) - 16) + acc += input[inBase + j] * wLo + input[inBase + 16 + j] * wHi + } + } + output[outputOffset + o] = acc + } + } } diff --git a/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/Q5MatmulDispatchTest.kt b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/Q5MatmulDispatchTest.kt new file mode 100644 index 00000000..65d1cc27 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/Q5MatmulDispatchTest.kt @@ -0,0 +1,140 @@ +package sk.ainet.exec.tensor.ops + +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import sk.ainet.context.DirectCpuExecutionContext +import sk.ainet.lang.tensor.Shape +import 
sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.Q5_0BlockTensorData +import sk.ainet.lang.tensor.data.Q5_1BlockTensorData +import sk.ainet.lang.tensor.data.TensorData +import sk.ainet.lang.types.FP32 + +/** + * Validates the packed Q5_1 / Q5_0 matmul kernels + lazy transpose: feeding a packed + * weight through `ops.matmul(x, ops.transpose(W))` must match feeding the FP32-dequantized + * weight through the same path. The FP32 reference is dequantized inline (independent of the + * `Q5_*BlockTensorData.dequantizeBlock` code under test), matching ggml / `DequantOps`. + */ +class Q5MatmulDispatchTest { + + private val ctx = DirectCpuExecutionContext() + + private fun f16(v: Float): Int { + // float -> IEEE half bits (mantissa truncated, not rounded; good enough for test weights) + val bits = v.toRawBits() + val sign = (bits ushr 16) and 0x8000 + var expo = ((bits ushr 23) and 0xFF) - 127 + 15 + val mant = bits and 0x7FFFFF + if (expo <= 0) return sign // flush tiny to signed zero + if (expo >= 31) return sign or 0x7C00 // overflow, inf and NaN all map to inf + return sign or (expo shl 10) or (mant ushr 13) + } + + private fun halfToFloat(h: Int): Float { + val sign = (h and 0x8000) shl 16 + val exp = (h and 0x7C00) shr 10 + val mant = h and 0x03FF + return when (exp) { + 0 -> Float.fromBits(sign) // (subnormals flushed by f16() above) + 31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13)) + else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13)) + } + } + + // --- Q5_1: 24 bytes/block (d, m, qh[4], qs[16]) --------------------------------------- + + private fun randomQ5_1Block(rng: Random, out: ByteArray, off: Int) { + val d = f16(0.02f + rng.nextFloat() * 0.05f) + val m = f16(-0.3f + rng.nextFloat() * 0.6f) + out[off] = (d and 0xFF).toByte(); out[off + 1] = ((d ushr 8) and 0xFF).toByte() + out[off + 2] = (m and 0xFF).toByte(); out[off + 3] = ((m ushr 8) and 0xFF).toByte() + for (k in 0 until 4) out[off + 4 + k] = rng.nextInt(256).toByte() // qh + for (k in 0 until 16) 
out[off + 8 + k] = rng.nextInt(256).toByte() // qs + } + + private fun dequantQ5_1Block(b: ByteArray, off: Int, dst: FloatArray, dstOff: Int) { + val d = halfToFloat(((b[off + 1].toInt() and 0xFF) shl 8) or (b[off].toInt() and 0xFF)) + val m = halfToFloat(((b[off + 3].toInt() and 0xFF) shl 8) or (b[off + 2].toInt() and 0xFF)) + val qh = intArrayOf(b[off + 4].toInt() and 0xFF, b[off + 5].toInt() and 0xFF, b[off + 6].toInt() and 0xFF, b[off + 7].toInt() and 0xFF) + for (j in 0 until 16) { + val q = b[off + 8 + j].toInt() and 0xFF + val lo = q and 0x0F; val hi = q ushr 4 + val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01 + val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01 + dst[dstOff + j] = d * (lo + (bitLo shl 4)) + m + dst[dstOff + 16 + j] = d * (hi + (bitHi shl 4)) + m + } + } + + // --- Q5_0: 22 bytes/block (d, qh[4], qs[16]), symmetric -16 -------------------------- + + private fun randomQ5_0Block(rng: Random, out: ByteArray, off: Int) { + val d = f16(0.02f + rng.nextFloat() * 0.05f) + out[off] = (d and 0xFF).toByte(); out[off + 1] = ((d ushr 8) and 0xFF).toByte() + for (k in 0 until 4) out[off + 2 + k] = rng.nextInt(256).toByte() + for (k in 0 until 16) out[off + 6 + k] = rng.nextInt(256).toByte() + } + + private fun dequantQ5_0Block(b: ByteArray, off: Int, dst: FloatArray, dstOff: Int) { + val d = halfToFloat(((b[off + 1].toInt() and 0xFF) shl 8) or (b[off].toInt() and 0xFF)) + val qh = intArrayOf(b[off + 2].toInt() and 0xFF, b[off + 3].toInt() and 0xFF, b[off + 4].toInt() and 0xFF, b[off + 5].toInt() and 0xFF) + for (j in 0 until 16) { + val q = b[off + 6 + j].toInt() and 0xFF + val lo = q and 0x0F; val hi = q ushr 4 + val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01 + val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01 + dst[dstOff + j] = d * (lo + (bitLo shl 4) - 16) + dst[dstOff + 16 + j] = d * (hi + (bitHi shl 4) - 16) + } + } + + private fun assertPackedMatchesFp32( + encoding: String, inputDim: Int, outputDim: Int, batchSize: Int, seed: Int, 
+ ) { + val rng = Random(seed) + val blocksPerRow = inputDim / 32 + val bytesPerBlock = if (encoding == "Q5_1") 24 else 22 + val bytes = ByteArray(outputDim * blocksPerRow * bytesPerBlock) + val wf = FloatArray(outputDim * inputDim) // row-major [out, in] + for (o in 0 until outputDim) { + for (blk in 0 until blocksPerRow) { + val off = (o * blocksPerRow + blk) * bytesPerBlock + val dstOff = o * inputDim + blk * 32 + if (encoding == "Q5_1") { randomQ5_1Block(rng, bytes, off); dequantQ5_1Block(bytes, off, wf, dstOff) } + else { randomQ5_0Block(rng, bytes, off); dequantQ5_0Block(bytes, off, wf, dstOff) } + } + } + + val packed: Tensor = if (encoding == "Q5_1") + ctx.fromData(Q5_1BlockTensorData(Shape(outputDim, inputDim), bytes) as TensorData, FP32::class) + else + ctx.fromData(Q5_0BlockTensorData(Shape(outputDim, inputDim), bytes) as TensorData, FP32::class) + val fp32 = ctx.fromFloatArray(Shape(outputDim, inputDim), FP32::class, wf) + + val input = ctx.fromFloatArray( + Shape(batchSize, inputDim), FP32::class, FloatArray(batchSize * inputDim) { (rng.nextFloat() - 0.5f) }, + ) + val outPacked = ctx.ops.matmul(input, ctx.ops.transpose(packed)).data.copyToFloatArray() + val outFp32 = ctx.ops.matmul(input, ctx.ops.transpose(fp32)).data.copyToFloatArray() + + assertEquals(outFp32.size, outPacked.size, "$encoding output size") + var maxErr = 0f + for (i in outFp32.indices) maxErr = maxOf(maxErr, kotlin.math.abs(outFp32[i] - outPacked[i])) + assertTrue(maxErr < 1e-3f, "$encoding packed matmul deviates from FP32 dequant: maxErr=$maxErr") + } + + @Test fun q5_1_matmul_matches_fp32_dequant_single_batch() = + assertPackedMatchesFp32("Q5_1", inputDim = 128, outputDim = 64, batchSize = 1, seed = 1) + + @Test fun q5_1_matmul_matches_fp32_dequant_multi_batch() = + assertPackedMatchesFp32("Q5_1", inputDim = 256, outputDim = 96, batchSize = 3, seed = 2) + + @Test fun q5_0_matmul_matches_fp32_dequant_single_batch() = + assertPackedMatchesFp32("Q5_0", inputDim = 128, outputDim = 64, 
batchSize = 1, seed = 3) + + @Test fun q5_0_matmul_matches_fp32_dequant_multi_batch() = + assertPackedMatchesFp32("Q5_0", inputDim = 192, outputDim = 48, batchSize = 2, seed = 4) +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q5_0TensorData.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q5_0TensorData.kt new file mode 100644 index 00000000..d2795e0f --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q5_0TensorData.kt @@ -0,0 +1,132 @@ +package sk.ainet.lang.tensor.data + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.storage.PackedBlockStorage +import sk.ainet.lang.tensor.storage.TensorEncoding +import sk.ainet.lang.types.DType + +/** + * Tensor data interface for the GGML **Q5_0** quantized format (5-bit, symmetric). + * + * Q5_0 block format (32 elements per block, 22 bytes per block): + * - 2 bytes: f16 scale (`d`) + * - 4 bytes: `qh[0..3]` — the 5th (high) bit of each of the 32 codes + * - 16 bytes: `qs[0..15]` — the low 4 bits, two nibbles per byte + * + * Dequantization (matching `sk.ainet.io.gguf.dequant.DequantOps.dequantQ5_0FromBytes`): + * for `j ∈ [0, 16)`, with `q = qs[j]`, `lo = q & 0x0F`, `hi = q >>> 4`, and the + * high bits `bitLo = (qh[j/8] >>> (j%8)) & 1`, `bitHi = (qh[(j+16)/8] >>> ((j+16)%8)) & 1`: + * + * element[j] = d * (lo + (bitLo shl 4) - 16) + * element[j + 16] = d * (hi + (bitHi shl 4) - 16) + * + * The `- 16` bias makes the 5-bit code symmetric around zero. + */ +public interface Q5_0TensorData : TensorData { + /** Number of Q5_0 blocks in the tensor. */ + public val blockCount: Int + + /** Raw packed data containing all blocks. */ + public val packedData: ByteArray + + public companion object { + /** Elements per Q5_0 block. */ + public const val BLOCK_SIZE: Int = 32 + + /** Bytes per Q5_0 block (2 `d` + 4 `qh` + 16 `qs`). 
*/ + public const val BYTES_PER_BLOCK: Int = 22 + } +} + +/** + * Implementation of [Q5_0TensorData] backed by a packed byte array, in the + * natural GGUF **row-major** `[out, in]` layout. `matmulQ5_0Vec` indexes the + * packed bytes row-major, so no block-major re-layout is needed. + */ +public class Q5_0BlockTensorData( + initialShape: Shape, + private val data: ByteArray +) : Q5_0TensorData, PackedBlockStorage { + + override val shape: Shape = Shape(initialShape.dimensions.copyOf()) + private val strides: IntArray = shape.computeStrides() + override val packedData: ByteArray get() = data + + override val blockCount: Int = (shape.volume + Q5_0TensorData.BLOCK_SIZE - 1) / Q5_0TensorData.BLOCK_SIZE + + override val encoding: TensorEncoding get() = TensorEncoding.Q5_0 + override val blockSize: Int get() = Q5_0TensorData.BLOCK_SIZE + + init { + val requiredBytes = blockCount * Q5_0TensorData.BYTES_PER_BLOCK + require(data.size >= requiredBytes) { + "Data size ${data.size} is less than required $requiredBytes bytes for $blockCount blocks" + } + } + + override fun dequantizeBlock(blockIdx: Int, output: FloatArray, outputOffset: Int) { + require(blockIdx in 0 until blockCount) { "Block index $blockIdx out of bounds (0..$blockCount)" } + val base = blockIdx * Q5_0TensorData.BYTES_PER_BLOCK + val d = Q4_0BlockTensorData.halfToFloat(((data[base + 1].toInt() and 0xFF) shl 8) or (data[base].toInt() and 0xFF)) + val qh0 = data[base + 2].toInt() and 0xFF + val qh1 = data[base + 3].toInt() and 0xFF + val qh2 = data[base + 4].toInt() and 0xFF + val qh3 = data[base + 5].toInt() and 0xFF + val qh = intArrayOf(qh0, qh1, qh2, qh3) + val qsBase = base + 6 + val elemsInBlock = minOf(Q5_0TensorData.BLOCK_SIZE, shape.volume - blockIdx * Q5_0TensorData.BLOCK_SIZE) + for (j in 0 until 16) { + val q = data[qsBase + j].toInt() and 0xFF + val lo = q and 0x0F + val hi = q ushr 4 + val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01 + val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01 + 
val o0 = outputOffset + j + if (j < elemsInBlock && o0 < output.size) output[o0] = d * (lo + (bitLo shl 4) - 16) + val o1 = outputOffset + 16 + j + if (16 + j < elemsInBlock && o1 < output.size) output[o1] = d * (hi + (bitHi shl 4) - 16) + } + } + + override fun get(vararg indices: Int): Byte { + val flatIndex = calcFlatIndex(indices) + val tmp = FloatArray(Q5_0TensorData.BLOCK_SIZE) + dequantizeBlock(flatIndex / Q5_0TensorData.BLOCK_SIZE, tmp, 0) + return tmp[flatIndex % Q5_0TensorData.BLOCK_SIZE].toInt().toByte() + } + + override fun set(vararg indices: Int, value: Byte) { + throw UnsupportedOperationException("Q5_0BlockTensorData is read-only (packed quantized weights)") + } + + private fun calcFlatIndex(indices: IntArray): Int { + require(indices.size == shape.dimensions.size) { + "Number of indices (${indices.size}) must match tensor dimensions (${shape.dimensions.size})" + } + var flatIndex = 0 + for (i in indices.indices) { + val idx = indices[i] + require(idx >= 0 && idx < shape.dimensions[i]) { + "Index $idx out of bounds for dimension $i with size ${shape.dimensions[i]}" + } + flatIndex += idx * strides[i] + } + return flatIndex + } + + public companion object { + /** Create [Q5_0BlockTensorData] from raw packed Q5_0 bytes (GGUF row-major). */ + public fun fromRawBytes(shape: Shape, bytes: ByteArray): Q5_0BlockTensorData = + Q5_0BlockTensorData(shape, bytes) + } +} + +/** Dequantize Q5_0 tensor data to a FloatArray (row-major, matching the packed layout). 
*/ +public fun Q5_0TensorData.toFloatArray(): FloatArray { + val result = FloatArray(shape.volume) + val block = this as Q5_0BlockTensorData + for (blockIdx in 0 until blockCount) { + block.dequantizeBlock(blockIdx, result, blockIdx * Q5_0TensorData.BLOCK_SIZE) + } + return result +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q5_1TensorData.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q5_1TensorData.kt new file mode 100644 index 00000000..1aab1b54 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/Q5_1TensorData.kt @@ -0,0 +1,140 @@ +package sk.ainet.lang.tensor.data + +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.storage.PackedBlockStorage +import sk.ainet.lang.tensor.storage.TensorEncoding +import sk.ainet.lang.types.DType + +/** + * Tensor data interface for the GGML **Q5_1** quantized format (5-bit, with a + * per-block minimum). + * + * Q5_1 block format (32 elements per block, 24 bytes per block): + * - 2 bytes: f16 scale (`d`) + * - 2 bytes: f16 minimum (`m`) + * - 4 bytes: `qh[0..3]` — the 5th (high) bit of each of the 32 codes + * - 16 bytes: `qs[0..15]` — the low 4 bits, two nibbles per byte + * + * Dequantization (matching `sk.ainet.io.gguf.dequant.DequantOps.dequantQ5_1FromBytes`): + * for `j ∈ [0, 16)`, with `q = qs[j]`, `lo = q & 0x0F`, `hi = q >>> 4`, and the + * high bits `bitLo = (qh[j/8] >>> (j%8)) & 1`, `bitHi = (qh[(j+16)/8] >>> ((j+16)%8)) & 1`: + * + * element[j] = d * (lo + (bitLo shl 4)) + m + * element[j + 16] = d * (hi + (bitHi shl 4)) + m + * + * Enables direct quantized matmul without full dequantization, mirroring + * [Q4_0TensorData] / [Q8_0TensorData]. + */ +public interface Q5_1TensorData : TensorData { + /** Number of Q5_1 blocks in the tensor. */ + public val blockCount: Int + + /** Raw packed data containing all blocks. 
*/ + public val packedData: ByteArray + + public companion object { + /** Elements per Q5_1 block. */ + public const val BLOCK_SIZE: Int = 32 + + /** Bytes per Q5_1 block (2 `d` + 2 `m` + 4 `qh` + 16 `qs`). */ + public const val BYTES_PER_BLOCK: Int = 24 + } +} + +/** + * Implementation of [Q5_1TensorData] backed by a packed byte array, in the + * natural GGUF **row-major** `[out, in]` layout (each logical row's elements are + * packed sequentially as `in / 32` blocks). `matmulQ5_1Vec` indexes the packed + * bytes row-major, so no block-major re-layout is needed. + */ +public class Q5_1BlockTensorData( + initialShape: Shape, + private val data: ByteArray +) : Q5_1TensorData, PackedBlockStorage { + + override val shape: Shape = Shape(initialShape.dimensions.copyOf()) + private val strides: IntArray = shape.computeStrides() + override val packedData: ByteArray get() = data + + override val blockCount: Int = (shape.volume + Q5_1TensorData.BLOCK_SIZE - 1) / Q5_1TensorData.BLOCK_SIZE + + override val encoding: TensorEncoding get() = TensorEncoding.Q5_1 + override val blockSize: Int get() = Q5_1TensorData.BLOCK_SIZE + + init { + val requiredBytes = blockCount * Q5_1TensorData.BYTES_PER_BLOCK + require(data.size >= requiredBytes) { + "Data size ${data.size} is less than required $requiredBytes bytes for $blockCount blocks" + } + } + + override fun dequantizeBlock(blockIdx: Int, output: FloatArray, outputOffset: Int) { + require(blockIdx in 0 until blockCount) { "Block index $blockIdx out of bounds (0..$blockCount)" } + val base = blockIdx * Q5_1TensorData.BYTES_PER_BLOCK + val d = Q4_0BlockTensorData.halfToFloat(((data[base + 1].toInt() and 0xFF) shl 8) or (data[base].toInt() and 0xFF)) + val m = Q4_0BlockTensorData.halfToFloat(((data[base + 3].toInt() and 0xFF) shl 8) or (data[base + 2].toInt() and 0xFF)) + val qh0 = data[base + 4].toInt() and 0xFF + val qh1 = data[base + 5].toInt() and 0xFF + val qh2 = data[base + 6].toInt() and 0xFF + val qh3 = data[base + 7].toInt() 
and 0xFF + val qh = intArrayOf(qh0, qh1, qh2, qh3) + val qsBase = base + 8 + val elemsInBlock = minOf(Q5_1TensorData.BLOCK_SIZE, shape.volume - blockIdx * Q5_1TensorData.BLOCK_SIZE) + for (j in 0 until 16) { + val q = data[qsBase + j].toInt() and 0xFF + val lo = q and 0x0F + val hi = q ushr 4 + val bitLo = (qh[j / 8] ushr (j % 8)) and 0x01 + val bitHi = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 0x01 + val o0 = outputOffset + j + if (j < elemsInBlock && o0 < output.size) output[o0] = d * (lo + (bitLo shl 4)) + m + val o1 = outputOffset + 16 + j + if (16 + j < elemsInBlock && o1 < output.size) output[o1] = d * (hi + (bitHi shl 4)) + m + } + } + + override fun get(vararg indices: Int): Byte { + val flatIndex = calcFlatIndex(indices) + val tmp = FloatArray(Q5_1TensorData.BLOCK_SIZE) + val blockIdx = flatIndex / Q5_1TensorData.BLOCK_SIZE + dequantizeBlock(blockIdx, tmp, 0) + // Q5_1 dequantizes to real-valued weights; truncating one to a Byte is not a + // meaningful code, so this accessor is best-effort for debugging only. + return tmp[flatIndex % Q5_1TensorData.BLOCK_SIZE].toInt().toByte() + } + + override fun set(vararg indices: Int, value: Byte) { + throw UnsupportedOperationException("Q5_1BlockTensorData is read-only (packed quantized weights)") + } + + private fun calcFlatIndex(indices: IntArray): Int { + require(indices.size == shape.dimensions.size) { + "Number of indices (${indices.size}) must match tensor dimensions (${shape.dimensions.size})" + } + var flatIndex = 0 + for (i in indices.indices) { + val idx = indices[i] + require(idx >= 0 && idx < shape.dimensions[i]) { + "Index $idx out of bounds for dimension $i with size ${shape.dimensions[i]}" + } + flatIndex += idx * strides[i] + } + return flatIndex + } + + public companion object { + /** Create [Q5_1BlockTensorData] from raw packed Q5_1 bytes (GGUF row-major). 
*/ + public fun fromRawBytes(shape: Shape, bytes: ByteArray): Q5_1BlockTensorData = + Q5_1BlockTensorData(shape, bytes) + } +} + +/** Dequantize Q5_1 tensor data to a FloatArray (row-major, matching the packed layout). */ +public fun Q5_1TensorData.toFloatArray(): FloatArray { + val result = FloatArray(shape.volume) + val block = this as Q5_1BlockTensorData + for (blockIdx in 0 until blockCount) { + block.dequantizeBlock(blockIdx, result, blockIdx * Q5_1TensorData.BLOCK_SIZE) + } + return result +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorEncoding.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorEncoding.kt index bd781a4f..509b6704 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorEncoding.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/storage/TensorEncoding.kt @@ -76,6 +76,30 @@ public sealed interface TensorEncoding { } } + /** GGML Q5_0 block quantization: 32 elements per 22-byte block. */ + public data object Q5_0 : TensorEncoding { + public const val BLOCK_SIZE: Int = 32 + public const val BYTES_PER_BLOCK: Int = 22 + + override val name: String get() = "Q5_0" + override fun physicalBytes(elementCount: Long): Long { + val blocks = (elementCount + BLOCK_SIZE - 1) / BLOCK_SIZE + return blocks * BYTES_PER_BLOCK + } + } + + /** GGML Q5_1 block quantization: 32 elements per 24-byte block. */ + public data object Q5_1 : TensorEncoding { + public const val BLOCK_SIZE: Int = 32 + public const val BYTES_PER_BLOCK: Int = 24 + + override val name: String get() = "Q5_1" + override fun physicalBytes(elementCount: Long): Long { + val blocks = (elementCount + BLOCK_SIZE - 1) / BLOCK_SIZE + return blocks * BYTES_PER_BLOCK + } + } + /** Ternary encoding: 2 bits per element, packed 4 elements per byte. 
*/ public data object TernaryPacked : TensorEncoding { override val name: String get() = "Ternary"