From f94ce6c5e97f7286710a6658a68cdfebd3c297b2 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 15 Jun 2026 13:07:07 +0200 Subject: [PATCH] fix(gemma): keep tied Q8_0 lm_head packed in eager NATIVE_OPTIMIZED path (#178) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FunctionGemma's token_embd is Q8_0 and tied, so convertGemmaWeightsPacked was dequanting BOTH token_embd AND output to FP32 (2×~0.67 GB) — OOM on the 1.9 GB SL2610. `output`/lm_head is a real matmul weight, not an embedding: - packGemmaKQuant: add Q8_0 (32-elem/34B blocks → Q8_0BlockTensorData); generalize the row-major→block-major relayout with a blockSize param. - convertGemmaWeightsPacked: drop OUTPUT_WEIGHT from the isEmbed FP32 branch so it packs like the other matmul weights and runs on the (NEON) Q8_0 kernel. token_embd stays FP32 (it's gathered) but is now wrapped no-copy via DenseFloatArrayTensorData instead of ctx.fromFloatArray (which allocates a second ~0.67 GB buffer). Footprint for the tied embed/lm_head drops ~1.34 GB → ~0.67 GB (embed FP32) + ~0.09 GB (packed Q8_0 lm_head). Requires the engine Q8_0 case in ops.transpose (SKaiNET fix/q8_0-lazy-transpose) so linearProject can transpose the packed weight. Verified: GemmaQ5KPackedParityTest (composite -PuseLocalSkainet) — eager load(NATIVE_OPTIMIZED) decodes byte-identically to the FP32 baseline; lm_head packed as Q8_0. (token_embd row-dequant gather to drop the last ~0.67 GB is the remaining follow-up in #178.) 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../ainet/models/gemma/GemmaPackedWeights.kt | 15 ++++++-- .../sk/ainet/models/gemma/GemmaQuantLayout.kt | 38 +++++++++++++------ 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaPackedWeights.kt b/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaPackedWeights.kt index ec52eb4..24596ed 100644 --- a/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaPackedWeights.kt +++ b/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaPackedWeights.kt @@ -5,6 +5,7 @@ import sk.ainet.io.gguf.GGMLQuantizationType import sk.ainet.io.gguf.dequant.DequantOps import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData import sk.ainet.lang.tensor.data.IntArrayTensorData import sk.ainet.lang.tensor.data.TensorData import sk.ainet.lang.types.DType @@ -48,8 +49,12 @@ public fun convertGemmaWeightsPacked( tensor // unknown 2-D layout — leave as-is } else { val bytes = extractRawBytes(tensor.data) - val isEmbed = name == Gemma4TensorNames.TOKEN_EMBEDDINGS || - name == Gemma4TensorNames.OUTPUT_WEIGHT + // Only the token-embedding table is gathered (row lookup) and so + // must be FP32 here. `output`/lm_head is a real matmul weight — + // it stays packed (FunctionGemma's tied output is Q8_0 → NEON + // Q8_0 kernel, transposed lazily by ops.transpose) instead of a + // second ~0.67 GB FP32 copy that would OOM the 1.9 GB board. 
+ val isEmbed = name == Gemma4TensorNames.TOKEN_EMBEDDINGS val packed = if (!isEmbed) packGemmaKQuant(bytes, qt, shape) else null when { packed != null -> { @@ -76,7 +81,11 @@ private fun dequantNoTranspose( ctx: ExecutionContext, ): Tensor { val floats = DequantOps.dequantFromBytes(bytes, qt, shape.volume) - return ctx.fromFloatArray(shape, FP32::class, floats) as Tensor + // Wrap the dequant array directly (no-copy) rather than ctx.fromFloatArray, + // which routes through BufferHandleFactory.owned and allocates a second + // full-size buffer — for the 262k×640 FP32 token_embd (~0.67 GB) that + // transient double is itself enough to OOM the 1.9 GB board. + return ctx.fromData(DenseFloatArrayTensorData(shape, floats), FP32::class) as Tensor } /** diff --git a/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaQuantLayout.kt b/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaQuantLayout.kt index 7f4e7b9..608f2ab 100644 --- a/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaQuantLayout.kt +++ b/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaQuantLayout.kt @@ -5,6 +5,7 @@ import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.data.Q4_KBlockTensorData import sk.ainet.lang.tensor.data.Q5_KBlockTensorData import sk.ainet.lang.tensor.data.Q6_KBlockTensorData +import sk.ainet.lang.tensor.data.Q8_0BlockTensorData import sk.ainet.lang.tensor.data.TensorData import sk.ainet.lang.types.DType @@ -66,8 +67,8 @@ internal fun relayoutKSeriesRowMajorToBlockMajor( bytes: ByteArray, shape: Shape, bytesPerBlock: Int, + blockSize: Int = 256, ): ByteArray { - val blockSize = 256 require(shape.rank == 2) { "K-series weight must be 2D, got rank ${shape.rank}" } val outDim = shape[0] val inDim = shape[1] @@ -88,19 +89,31 @@ internal fun relayoutKSeriesRowMajorToBlockMajor( return out } -/** Bytes per ggml block for the K-quant types this packer handles. 
*/ -private fun kQuantBytesPerBlock(qt: GGMLQuantizationType): Int? = when (qt) { - GGMLQuantizationType.Q4_K -> 144 - GGMLQuantizationType.Q5_K -> 176 - GGMLQuantizationType.Q6_K -> 210 +/** + * Block geometry `(blockElems, bytesPerBlock)` for the quant types this packer + * handles. The K-series are 256-element super-blocks; Q8_0 is a 32-element block + * (f16 scale + 32 int8). All four have a first-class CPU matmul kernel + a lazy + * transpose in `ops.transpose`, so all four can stay packed instead of FP32. + */ +private fun quantBlockLayout(qt: GGMLQuantizationType): Pair<Int, Int>? = when (qt) { + GGMLQuantizationType.Q4_K -> 256 to 144 + GGMLQuantizationType.Q5_K -> 256 to 176 + GGMLQuantizationType.Q6_K -> 256 to 210 + GGMLQuantizationType.Q8_0 -> 32 to 34 else -> null } /** - * Pack raw GGUF K-quant `bytes` of logical `[out, in]` shape into the - * heap-packed block tensor data the matmul kernels read directly (Q4_K / Q5_K / - * Q6_K). Performs the row-major → block-major relayout. Returns `null` for - * non-K-quant types (caller dequantizes those to FP32). + * Pack raw GGUF `bytes` of logical `[out, in]` shape into the heap-packed block + * tensor data the matmul kernels read directly (Q4_K / Q5_K / Q6_K / Q8_0). + * Performs the row-major → block-major relayout. Returns `null` for types + * without a packed kernel (caller dequantizes those to FP32). + * + * Q8_0 matters for gemma's tied `output`/lm_head: FunctionGemma's token_embd is + * Q8_0, so keeping the lm_head packed (vs ~0.67 GB FP32) is what lets the eager + * decode fit the 1.9 GB board, and it runs on the NEON Q8_0 kernel. (Requires + * the Q8_0 case in `ops.transpose` — engine — so `linearProject` can transpose + * the packed weight; see transformers #178.) * * commonMain → works on JVM and Kotlin/Native alike (no MemSeg / Arena). */ @@ -109,13 +122,14 @@ internal fun packGemmaKQuant( qt: GGMLQuantizationType, shape: Shape, ): TensorData? 
{ - val bpb = kQuantBytesPerBlock(qt) ?: return null - val relaid = relayoutKSeriesRowMajorToBlockMajor(bytes, shape, bpb) + val (blockElems, bpb) = quantBlockLayout(qt) ?: return null + val relaid = relayoutKSeriesRowMajorToBlockMajor(bytes, shape, bpb, blockElems) @Suppress("UNCHECKED_CAST") return when (qt) { GGMLQuantizationType.Q4_K -> Q4_KBlockTensorData(shape, relaid) as TensorData GGMLQuantizationType.Q5_K -> Q5_KBlockTensorData(shape, relaid) as TensorData GGMLQuantizationType.Q6_K -> Q6_KBlockTensorData(shape, relaid) as TensorData + GGMLQuantizationType.Q8_0 -> Q8_0BlockTensorData(shape, relaid) as TensorData else -> null } }