From 4acec7f45223a5578ae17714f7ae78373d2c3a0d Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 15 Jun 2026 15:42:09 +0200 Subject: [PATCH] feat(gemma): optional maxInferenceLen on GemmaNetworkLoader.load() (#178) The eager network sizes its KV cache + RoPE tables for maxInferenceLen (= min(contextLength, 4096) by default). On the 1.9 GB SL2610 that ~0.4 GB KV cache (allocated at the first forward) OOMs the board even after the packed Q8_0 lm_head dropped the weight footprint to ~1.06 GB resident. Thread an optional `maxInferenceLen: Int? = null` through load() -> applyWeightsToNetwork -> applyWeightsToNetworkNonReified -> gemmaNetwork so a constrained-device consumer can cap the context (e.g. 32 for a short tool-call prompt), shrinking the KV cache ~100x. Default null preserves the existing min(contextLength, 4096) behaviour. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../sk/ainet/models/gemma/GemmaNetworkLoader.kt | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt b/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt index abc8c3e..e9d6772 100644 --- a/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt +++ b/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt @@ -120,7 +120,8 @@ public class GemmaNetworkLoader @PublishedApi internal constructor( * Load weights and build a fully initialized DSL network. */ public suspend inline fun load( - ctx: ExecutionContext + ctx: ExecutionContext, + maxInferenceLen: Int? = null, ): Module { val rawWeights: Gemma4Weights = when (val wp = weightsProvider) { is WeightsProvider.GgufSource -> { @@ -160,14 +161,15 @@ public class GemmaNetworkLoader @PublishedApi internal constructor( rawWeights } - return applyWeightsToNetwork(ctx, weights) + return applyWeightsToNetwork(ctx, weights, maxInferenceLen) } @PublishedApi internal inline fun applyWeightsToNetwork( ctx: ExecutionContext, - weights: Gemma4Weights - ): Module = applyWeightsToNetworkNonReified(ctx, weights, T::class, debug) + weights: Gemma4Weights, + maxInferenceLen: Int? = null, + ): Module = applyWeightsToNetworkNonReified(ctx, weights, T::class, debug, maxInferenceLen) } /** Shared non-reified impl used by both the inline-reified companion helpers @@ -177,7 +179,8 @@ internal fun applyWeightsToNetworkNonReified( ctx: ExecutionContext, weights: Gemma4Weights, dtype: kotlin.reflect.KClass, - debug: Boolean + debug: Boolean, + maxInferenceLen: Int? = null, ): Module { // Enable optional Gemma 4 features iff the checkpoint actually carries // their weights. Real Gemma 4 GGUFs do; synthetic toy-model tests do not, @@ -197,6 +200,7 @@ internal fun applyWeightsToNetworkNonReified( val model = gemmaNetwork( weights.metadata, dtype, + maxInferenceLen = maxInferenceLen ?: minOf(weights.metadata.contextLength, 4096), qkNorm = hasQKNorm, sandwichNorms = hasSandwichNorms, layerOutputScale = hasLayerOutputScale,