From ae7691b29e279b27b421a92baf556f167e2ed8d6 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 15 Jun 2026 14:00:11 +0200 Subject: [PATCH] fix(cpu-ops): lazy transpose for Q4_0 too; cover all packed matmul dtypes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to #736 (Q8_0). The transpose lazy-rewrap `when` was still missing Q4_0 — a packed type chooseQuantizedMatmulHeap dispatches — so a packed Q4_0 matmul weight through linearProject (matmul(x, transpose(W))) hit the generic FP32 path and threw `Byte cannot be cast to Float`. Add the Q4_0 case so the `when` now covers EVERY packed type that can be a matmul weight (Q4_K/Q5_K/Q6_K/Q5_0/Q5_1/Q8_0/Q4_0). Adds `transpose_preserves_every_packed_quant_type` to PackedMatmulDispatchTest: transposes a 2-D tensor of each of the 7 packed types and asserts the shape flips and the packed encoding is preserved (no FP32 fallback / no crash). Content-agnostic, runs on every platform (jvm + linuxX64). See SKaiNET-transformers#178. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../sk/ainet/exec/tensor/ops/DefaultCpuOps.kt | 12 ++++-- .../tensor/ops/PackedMatmulDispatchTest.kt | 41 +++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt index 0e2c889a..10926261 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt @@ -24,6 +24,7 @@ import sk.ainet.lang.tensor.data.Q5_1TensorData import sk.ainet.lang.tensor.data.Q5_1BlockTensorData import sk.ainet.lang.tensor.data.Q5_0TensorData import sk.ainet.lang.tensor.data.Q5_0BlockTensorData +import sk.ainet.lang.tensor.data.Q4_0BlockTensorData import sk.ainet.lang.tensor.data.Q8_0BlockTensorData import sk.ainet.lang.tensor.data.TensorData import sk.ainet.lang.tensor.data.TensorDataFactory @@ -607,12 +608,15 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory is Q6_KTensorData -> return newTensor(Q6_KBlockTensorData(Shape(cols, rows), d.packedData) as TensorData, tensor.dtype, tensor) is Q5_1TensorData -> return newTensor(Q5_1BlockTensorData(Shape(cols, rows), d.packedData) as TensorData, tensor.dtype, tensor) is Q5_0TensorData -> return newTensor(Q5_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData, tensor.dtype, tensor) - // Q8_0 lazy transpose: rewrap the same input-block-major bytes with - // flipped shape (bytes are layout-agnostic to the [out,in] kernel - // convention) so a packed Q8_0 weight (e.g. gemma's tied lm_head) + // Q8_0 / Q4_0 lazy transpose: rewrap the same input-block-major bytes + // with flipped shape (bytes are layout-agnostic to the [out,in] kernel + // convention) so a packed weight (e.g. gemma's tied Q8_0 lm_head) // survives linearProject's transpose instead of hitting the generic - // FP32 path (Byte→Float ClassCastException). See transformers #178. + // FP32 path (Byte→Float ClassCastException). This `when` now covers + // every quant type chooseQuantizedMatmulHeap dispatches — i.e. every + // packed type that can be a matmul weight. See transformers #178. is Q8_0TensorData -> return newTensor(Q8_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData, tensor.dtype, tensor) + is Q4_0TensorData -> return newTensor(Q4_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData, tensor.dtype, tensor) else -> {} } } diff --git a/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/tensor/ops/PackedMatmulDispatchTest.kt b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/tensor/ops/PackedMatmulDispatchTest.kt index 0c460b04..6593645b 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/tensor/ops/PackedMatmulDispatchTest.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/tensor/ops/PackedMatmulDispatchTest.kt @@ -3,12 +3,17 @@ package sk.ainet.exec.tensor.ops import kotlin.math.abs import kotlin.random.Random import kotlin.test.Test +import kotlin.test.assertEquals import kotlin.test.assertTrue import sk.ainet.context.DirectCpuExecutionContext import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.data.Q4_0BlockTensorData import sk.ainet.lang.tensor.data.Q4_KBlockTensorData +import sk.ainet.lang.tensor.data.Q5_0BlockTensorData import sk.ainet.lang.tensor.data.Q5_1BlockTensorData +import sk.ainet.lang.tensor.data.Q5_KBlockTensorData import sk.ainet.lang.tensor.data.Q6_KBlockTensorData +import sk.ainet.lang.tensor.data.Q8_0BlockTensorData import sk.ainet.lang.tensor.data.TensorData import sk.ainet.lang.types.FP32 @@ -129,4 +134,40 @@ class PackedMatmulDispatchTest { @Test fun q5_1_through_ops_matmul_transpose() = run("Q5_1", inDim = 128, outDim = 16, seed = 7) @Test fun q4_k_through_ops_matmul_transpose() = run("Q4_K", inDim = 256, outDim = 12, seed = 8) @Test fun q6_k_through_ops_matmul_transpose() = run("Q6_K", inDim = 512, outDim = 8, seed = 9) + + /** + * `ops.transpose` must lazily rewrap EVERY packed quant type that can be a + * matmul weight (the full `chooseQuantizedMatmulHeap` set) — flipping the + * shape while keeping the same packed bytes — instead of falling into the + * generic FP32 path, which casts the Byte-backed buffer to Float and throws + * `ClassCastException`. Regression guard for transformers #178 (Q8_0/Q4_0 + * were the gaps). Content-agnostic: zero bytes, sized per block geometry. + */ + @Test + fun transpose_preserves_every_packed_quant_type() { + val outDim = 8 + // name -> (blockElems, bytesPerBlock, builder) + val cases: List, (Shape, ByteArray) -> TensorData>> = listOf( + Triple("Q4_K", 256 to 144) { s, b -> Q4_KBlockTensorData(s, b) as TensorData }, + Triple("Q5_K", 256 to 176) { s, b -> Q5_KBlockTensorData(s, b) as TensorData }, + Triple("Q6_K", 256 to 210) { s, b -> Q6_KBlockTensorData(s, b) as TensorData }, + Triple("Q8_0", 32 to 34) { s, b -> Q8_0BlockTensorData(s, b) as TensorData }, + Triple("Q4_0", 32 to 18) { s, b -> Q4_0BlockTensorData(s, b) as TensorData }, + Triple("Q5_0", 32 to 22) { s, b -> Q5_0BlockTensorData(s, b) as TensorData }, + Triple("Q5_1", 32 to 24) { s, b -> Q5_1BlockTensorData(s, b) as TensorData }, + ) + for ((name, geom, build) in cases) { + val (blockElems, bpb) = geom + val inDim = blockElems // one block per row + val bytes = ByteArray(outDim * (inDim / blockElems) * bpb) + val w = ctx.fromData(build(Shape(outDim, inDim), bytes), FP32::class) + // The bug threw here for unhandled packed types. + val t = ctx.ops.transpose(w) + assertEquals(Shape(inDim, outDim), t.shape, "$name: transpose did not flip shape") + assertTrue( + t.data::class.simpleName?.contains("Block") == true, + "$name: transpose dropped the packed encoding (got ${t.data::class.simpleName})", + ) + } + } }