Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[versions]
skainet = "0.28.1"
skainet = "0.30.0"
agp = "9.2.1"
jacksonDatabind = "2.22.0"
jsonSchemaValidator = "3.0.3"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
public suspend inline fun <reified T : DType, V> load(
ctx: ExecutionContext
): Module<T, V> {
val weights: Gemma4Weights<T, V> = when (val wp = weightsProvider) {
val rawWeights: Gemma4Weights<T, V> = when (val wp = weightsProvider) {
is WeightsProvider.GgufSource -> {
val loader = Gemma4WeightLoader(wp.sourceProvider, quantPolicy = wp.quantPolicy)
loader.loadToMap<T, V>(ctx)
Expand All @@ -142,6 +142,24 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
}
}

// NATIVE_OPTIMIZED yields raw-byte quant tensors the network mapper can't
// consume directly. Pack them (heap Q4/5/6_K + FP32 fallback) here — this
// is commonMain so it works on Kotlin/Native (the board) as well as the
// JVM, and replaces the JVM-only `convertGemmaWeightsToMemSeg` for the
// `load()` entry point.
val ggufPolicy = when (val wp = weightsProvider) {
is WeightsProvider.GgufSource -> wp.quantPolicy
is WeightsProvider.GgufRandomAccess -> wp.quantPolicy
else -> null
}
val weights: Gemma4Weights<T, V> =
if (ggufPolicy == QuantPolicy.NATIVE_OPTIMIZED) {
@Suppress("UNCHECKED_CAST")
convertGemmaWeightsPacked(rawWeights, ctx) as Gemma4Weights<T, V>
} else {
rawWeights
}

return applyWeightsToNetwork(ctx, weights)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package sk.ainet.models.gemma

import sk.ainet.context.ExecutionContext
import sk.ainet.io.gguf.GGMLQuantizationType
import sk.ainet.io.gguf.dequant.DequantOps
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.data.IntArrayTensorData
import sk.ainet.lang.tensor.data.TensorData
import sk.ainet.lang.types.DType
import sk.ainet.lang.types.FP32

/**
* commonMain (Kotlin/Native-capable) analogue of the jvmMain
* `convertGemmaWeightsToMemSeg`. Converts the raw-byte quantized tensors a
* `NATIVE_OPTIMIZED` load produces into the forms the DSL matmul path consumes:
*
* - **Q4_K / Q5_K / Q6_K matmul weights** → heap-packed `Q{4,5,6}_KBlockTensorData`
* (via [packGemmaKQuant], with the row-major→block-major relayout). These keep
* the GGUF footprint and run the in-kernel dequant matmul (NEON on the board).
* - **token_embd / output** → FP32 dequant in canonical `[vocab, embed]` order
* (the embedding is gathered, not matmul'd, so no transpose).
* - **everything else quantized** → FP32 dequant transposed to `[out, in]`
* row-major so `linearProject` (`x @ W.t()`) is correct.
*
* Unlike the MemSeg converter this uses no `java.lang.foreign` — it runs on the
* SL2610 board binary (Kotlin/Native) as well as the JVM. The JVM still prefers
* the MemSeg path (lazy transpose + Q4/Q8 MemSeg); this is the board path.
*/
public fun convertGemmaWeightsPacked(
    weights: Gemma4Weights<*, *>,
    ctx: ExecutionContext,
): Gemma4Weights<*, *> {
    @Suppress("UNCHECKED_CAST")
    val source = weights as Gemma4Weights<DType, Any>
    val quantByName = source.quantTypes
    // No tensor is recorded as quantized: hand the input back untouched.
    if (quantByName.isEmpty()) return weights

    val recordedShapes = source.logicalShapes
    // Insertion-ordered map, so the output keeps the iteration order of the input tensors.
    val converted = linkedMapOf<String, Tensor<DType, Any>>()
    for ((name, original) in source.tensors) {
        val quant = quantByName[name]
        // The shape is only resolved for quantized entries; a recorded shape wins
        // over the one derived from the tensor name and model metadata.
        val shape = if (quant != null) {
            recordedShapes[name] ?: logicalShapeFor(name, source.metadata)
        } else {
            null
        }
        if (quant == null || shape == null) {
            // Either not quantized, or no known 2-D layout: pass through as-is.
            converted[name] = original
            continue
        }

        val raw = extractRawBytes(original.data)
        val isGathered = name == Gemma4TensorNames.TOKEN_EMBEDDINGS ||
            name == Gemma4TensorNames.OUTPUT_WEIGHT
        converted[name] = if (isGathered) {
            // token_embd / output are never block-packed: FP32 in natural order.
            dequantNoTranspose(raw, quant, shape, ctx)
        } else {
            val blockData = packGemmaKQuant<FP32>(raw, quant, shape)
            if (blockData == null) {
                // Not one of the packable K-quant types: FP32 fallback.
                dequantTransposed(raw, quant, shape, ctx)
            } else {
                @Suppress("UNCHECKED_CAST")
                val wrapped = ctx.fromData(blockData as TensorData<FP32, Float>, FP32::class) as Tensor<DType, Any>
                wrapped
            }
        }
    }
    @Suppress("UNCHECKED_CAST")
    return Gemma4Weights(source.metadata, converted, source.quantTypes, source.logicalShapes) as Gemma4Weights<*, *>
}

/**
 * Dequantizes `bytes` to an FP32 tensor of `shape`, keeping the floats in the
 * order the dequantizer produced them (used for embeddings, which are
 * gathered rather than matmul'd).
 */
@Suppress("UNCHECKED_CAST")
private fun dequantNoTranspose(
    bytes: ByteArray,
    qt: GGMLQuantizationType,
    shape: Shape,
    ctx: ExecutionContext,
): Tensor<DType, Any> {
    val elementCount = shape.volume
    val values = DequantOps.dequantFromBytes(bytes, qt, elementCount)
    val tensor = ctx.fromFloatArray<FP32, Float>(shape, FP32::class, values)
    return tensor as Tensor<DType, Any>
}

/**
 * Dequantizes `bytes` to FP32 and reorders the result into a canonical
 * `[out, in]` row-major weight before wrapping it in a tensor of `shape`.
 *
 * The reorder is delegated to [DequantOps.transposeColumnMajorToRowMajor],
 * called with `(inDim, out)` taken from `shape[1]` and `shape[0]`; the stated
 * rationale is that GGUF stores K/legacy blocks column-major within a row
 * while `linearProject` (`x @ W.t()`) expects row-major.
 */
@Suppress("UNCHECKED_CAST")
private fun dequantTransposed(
    bytes: ByteArray,
    qt: GGMLQuantizationType,
    shape: Shape,
    ctx: ExecutionContext,
): Tensor<DType, Any> {
    val columnMajor = DequantOps.dequantFromBytes(bytes, qt, shape.volume)
    val outDim = shape[0]
    val inDim = shape[1]
    val reordered = DequantOps.transposeColumnMajorToRowMajor(columnMajor, inDim, outDim)
    val tensor = ctx.fromFloatArray<FP32, Float>(shape, FP32::class, reordered)
    return tensor as Tensor<DType, Any>
}

/**
 * Recovers the raw packed bytes held by a `NATIVE_OPTIMIZED` quant tensor.
 *
 * Two backings are handled: an [IntArrayTensorData] whose Int elements are
 * narrowed back to bytes (described as the JVM representation), and any other
 * tensor data whose elements read back as `Byte` or `Int` (described as the
 * Kotlin/Native representation).
 *
 * @throws IllegalStateException if an element is neither `Byte` nor `Int`.
 */
internal fun extractRawBytes(data: TensorData<*, *>): ByteArray {
    if (data is IntArrayTensorData<*>) {
        val widened = data.buffer
        return ByteArray(widened.size) { index -> widened[index].toByte() }
    }

    val result = ByteArray(data.shape.volume)
    @Suppress("UNCHECKED_CAST")
    val elements = data as TensorData<*, Any?>
    // Generic path: read element by element through the flat index accessor.
    for (index in result.indices) {
        val element = elements[index]
        result[index] = when (element) {
            is Byte -> element
            is Int -> element.toByte()
            else -> error(
                "convertGemmaWeightsPacked: cannot read bytes from ${data::class.simpleName} " +
                    "(element ${element?.let { e -> e::class.simpleName }})",
            )
        }
    }
    return result
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package sk.ainet.models.gemma

import sk.ainet.io.gguf.GGMLQuantizationType
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.data.Q4_KBlockTensorData
import sk.ainet.lang.tensor.data.Q5_KBlockTensorData
import sk.ainet.lang.tensor.data.Q6_KBlockTensorData
import sk.ainet.lang.tensor.data.TensorData
import sk.ainet.lang.types.DType

/**
* Platform-neutral (commonMain) layout helpers for Gemma 4 quantized weights.
*
* These were previously JVM-only (inside `GemmaMemSegConverter`), but the
* Kotlin/Native board path needs the same logic: on K/N there is no
* `java.lang.foreign` MemSeg conversion, so the eager runtime keeps K-quant
* weights as heap-packed `Q{4,5,6}_KBlockTensorData` produced here. The JVM
* MemSeg converter reuses the same relayout + shape recovery.
*/

/**
 * Derives the logical 2-D shape of a Gemma 4 weight tensor from its GGUF name
 * and the model metadata. Converters need this because `Gemma4WeightLoader`
 * with `NATIVE_OPTIMIZED` stores quantized tensors as 1-D byte arrays.
 *
 * - Token embeddings and the output weight map to `[vocab, embed]`.
 * - `blk.<layer>.` attention / feed-forward projection weights map to their
 *   `[out, in]` shape, using the per-layer head dimension.
 *
 * @return the shape, or `null` when the name is not one of the above (for
 *   example a name outside `blk.`, a non-numeric layer index, or an
 *   unrecognised suffix).
 */
internal fun logicalShapeFor(name: String, metadata: Gemma4ModelMetadata): Shape? {
    val embed = metadata.embeddingLength
    val vocab = metadata.vocabSize

    val isVocabSized = name == Gemma4TensorNames.TOKEN_EMBEDDINGS ||
        name == Gemma4TensorNames.OUTPUT_WEIGHT
    if (isVocabSized) return Shape(vocab, embed)
    if (!name.startsWith("blk.")) return null

    // "blk.<layer>.<rest>": the layer index sits between the prefix and the next dot.
    val layer = name.substringAfter("blk.").substringBefore('.').toIntOrNull() ?: return null
    val headDim = metadata.getHeadDim(layer)
    val qDim = metadata.headCount * headDim
    val kvDim = metadata.kvHeadCount * headDim
    val ffn = metadata.intermediateSize

    return when {
        name.endsWith(".attn_q.weight") -> Shape(qDim, embed)
        name.endsWith(".attn_k.weight") || name.endsWith(".attn_v.weight") -> Shape(kvDim, embed)
        name.endsWith(".attn_output.weight") -> Shape(embed, qDim)
        name.endsWith(".ffn_gate.weight") || name.endsWith(".ffn_up.weight") -> Shape(ffn, embed)
        name.endsWith(".ffn_down.weight") -> Shape(embed, ffn)
        else -> null
    }
}

/**
 * Reorders GGUF K-series blocks from row-major block order, where block
 * `(r, b)` starts at `(r * blocksPerRow + b) * bytesPerBlock`, to
 * input-block-major order, where it starts at `(b * outDim + r) * bytesPerBlock`
 * (the order the `matmulQ{K}` kernels are documented to expect). This is a 2-D
 * transpose at block granularity; the bytes inside each block are copied
 * verbatim.
 *
 * The result has the same length as `bytes`; any bytes beyond the expected
 * `outDim * blocksPerRow * bytesPerBlock` are not copied and stay zero.
 *
 * @param bytes raw block bytes in row-major block order.
 * @param shape logical `[outDim, inDim]` shape; `inDim` must be a multiple of 256.
 * @param bytesPerBlock 144 (Q4_K), 176 (Q5_K), 210 (Q6_K).
 * @throws IllegalArgumentException if `shape` is not 2-D, `inDim` is not a
 *   multiple of 256, or `bytes` is shorter than the expected size.
 */
internal fun relayoutKSeriesRowMajorToBlockMajor(
    bytes: ByteArray,
    shape: Shape,
    bytesPerBlock: Int,
): ByteArray {
    val elementsPerBlock = 256
    require(shape.rank == 2) { "K-series weight must be 2D, got rank ${shape.rank}" }
    val rows = shape[0]
    val cols = shape[1]
    require(cols % elementsPerBlock == 0) {
        "K-series weight inDim ($cols) must be a multiple of $elementsPerBlock"
    }
    val blocksInRow = cols / elementsPerBlock
    // Computed in Long so the size check itself cannot overflow.
    val requiredBytes = rows.toLong() * blocksInRow * bytesPerBlock
    require(bytes.size.toLong() >= requiredBytes) {
        "K-series byte buffer ${bytes.size} < expected $requiredBytes for [$rows, $cols] @ ${bytesPerBlock}B/block"
    }

    return ByteArray(bytes.size).also { relaid ->
        for (row in 0 until rows) {
            val rowBase = row * blocksInRow
            for (block in 0 until blocksInRow) {
                val from = (rowBase + block) * bytesPerBlock
                val to = (block * rows + row) * bytesPerBlock
                bytes.copyInto(relaid, destinationOffset = to, startIndex = from, endIndex = from + bytesPerBlock)
            }
        }
    }
}

/**
 * Size in bytes of one ggml block for the K-quant types this packer handles
 * (Q4_K, Q5_K, Q6_K); `null` for every other quantization type.
 */
private fun kQuantBytesPerBlock(qt: GGMLQuantizationType): Int? {
    if (qt == GGMLQuantizationType.Q4_K) return 144
    if (qt == GGMLQuantizationType.Q5_K) return 176
    if (qt == GGMLQuantizationType.Q6_K) return 210
    return null
}

/**
 * Wraps raw GGUF K-quant `bytes` of logical `[out, in]` `shape` in the matching
 * heap-packed block tensor data (Q4_K / Q5_K / Q6_K), after relaying the blocks
 * from row-major to block-major order.
 *
 * Lives in commonMain and uses no MemSeg / Arena, so it is usable from JVM and
 * Kotlin/Native alike.
 *
 * @return the packed tensor data, or `null` when `qt` is not one of the three
 *   supported K-quant types (the caller dequantizes those to FP32).
 * @throws IllegalArgumentException propagated from
 *   [relayoutKSeriesRowMajorToBlockMajor] when `shape` or `bytes` is unsuitable.
 */
internal fun <T : DType> packGemmaKQuant(
    bytes: ByteArray,
    qt: GGMLQuantizationType,
    shape: Shape,
): TensorData<T, *>? {
    // Unsupported types bail out here, before any relayout work is done.
    val blockBytes = kQuantBytesPerBlock(qt) ?: return null
    val blockMajor = relayoutKSeriesRowMajorToBlockMajor(bytes, shape, blockBytes)
    val packed: Any = when (qt) {
        GGMLQuantizationType.Q4_K -> Q4_KBlockTensorData(shape, blockMajor)
        GGMLQuantizationType.Q5_K -> Q5_KBlockTensorData(shape, blockMajor)
        GGMLQuantizationType.Q6_K -> Q6_KBlockTensorData(shape, blockMajor)
        else -> return null
    }
    @Suppress("UNCHECKED_CAST")
    return packed as TensorData<T, *>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package sk.ainet.models.gemma

import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull
import kotlin.test.assertTrue
import sk.ainet.context.DirectCpuExecutionContext
import sk.ainet.io.gguf.GGMLQuantizationType
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.data.Q5_KBlockTensorData
import sk.ainet.lang.types.FP32
import sk.ainet.lang.types.Int8

/**
 * Unit tests for the commonMain Gemma quant layout helpers: block relayout,
 * K-quant packing and raw-byte extraction. They exercise the helpers directly,
 * without loading a model, and are intended to run on every target
 * (JVM + Kotlin/Native).
 */
class GemmaQuantLayoutTest {

    @Test
    fun relayout_is_block_level_transpose() {
        // [rows=2, cols=512] -> 2 blocks per row, 4 Q5_K blocks of 176 B each.
        val bpb = 176
        val rows = 2
        val cols = 512
        val perRow = cols / 256
        val blockCount = rows * perRow
        // The first byte of every source block carries its row-major index.
        val input = ByteArray(blockCount * bpb).also { buf ->
            repeat(blockCount) { idx -> buf[idx * bpb] = idx.toByte() }
        }

        val relaid = relayoutKSeriesRowMajorToBlockMajor(input, Shape(rows, cols), bpb)

        // Source block (r*perRow + b) must land in destination block (b*rows + r).
        for (srcIdx in 0 until blockCount) {
            val r = srcIdx / perRow
            val b = srcIdx % perRow
            val dstIdx = b * rows + r
            assertEquals(srcIdx.toByte(), relaid[dstIdx * bpb], "block ($r,$b) misplaced")
        }
    }

    @Test
    fun pack_q5k_produces_block_tensor_with_relaid_bytes() {
        val shape = Shape(2, 512)
        val bytes = ByteArray(2 * 2 * 176)
        (1..4).forEach { tag -> bytes[(tag - 1) * 176] = tag.toByte() }

        val td = packGemmaKQuant<FP32>(bytes, GGMLQuantizationType.Q5_K, shape)
        assertTrue(td is Q5_KBlockTensorData, "Q5_K should pack to Q5_KBlockTensorData")
        // The packed payload must equal the block-major relayout of the input.
        val relaidReference = relayoutKSeriesRowMajorToBlockMajor(bytes, shape, 176)
        assertTrue(relaidReference.contentEquals(td.packedData))
    }

    @Test
    fun pack_non_kquant_returns_null() {
        val packed = packGemmaKQuant<FP32>(ByteArray(34), GGMLQuantizationType.Q8_0, Shape(1, 32))
        assertNull(packed)
    }

    @Test
    fun extract_raw_bytes_roundtrips_on_every_platform() {
        // Quant bytes wrapped via ctx.fromByteArray<Int8, Byte> must read back
        // unchanged whatever backing the platform picks for the tensor data.
        val ctx = DirectCpuExecutionContext.create()
        val original = ByteArray(176 * 3) { i -> ((i * 31 + 7) and 0xFF).toByte() }
        val tensor = ctx.fromByteArray<Int8, Byte>(Shape(original.size), Int8::class, original)
        val recovered = extractRawBytes(tensor.data)
        assertTrue(original.contentEquals(recovered), "extractRawBytes round-trip mismatch")
    }
}
Loading
Loading