Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
* Load weights and build a fully initialized DSL network.
*/
public suspend inline fun <reified T : DType, V> load(
ctx: ExecutionContext
ctx: ExecutionContext,
maxInferenceLen: Int? = null,
): Module<T, V> {
val rawWeights: Gemma4Weights<T, V> = when (val wp = weightsProvider) {
is WeightsProvider.GgufSource -> {
Expand Down Expand Up @@ -160,14 +161,15 @@ public class GemmaNetworkLoader @PublishedApi internal constructor(
rawWeights
}

return applyWeightsToNetwork(ctx, weights)
return applyWeightsToNetwork(ctx, weights, maxInferenceLen)
}

@PublishedApi
internal inline fun <reified T : DType, V> applyWeightsToNetwork(
ctx: ExecutionContext,
weights: Gemma4Weights<T, V>
): Module<T, V> = applyWeightsToNetworkNonReified(ctx, weights, T::class, debug)
weights: Gemma4Weights<T, V>,
maxInferenceLen: Int? = null,
): Module<T, V> = applyWeightsToNetworkNonReified(ctx, weights, T::class, debug, maxInferenceLen)
}

/** Shared non-reified impl used by both the inline-reified companion helpers
Expand All @@ -177,7 +179,8 @@ internal fun <T : DType, V> applyWeightsToNetworkNonReified(
ctx: ExecutionContext,
weights: Gemma4Weights<T, V>,
dtype: kotlin.reflect.KClass<T>,
debug: Boolean
debug: Boolean,
maxInferenceLen: Int? = null,
): Module<T, V> {
// Enable optional Gemma 4 features iff the checkpoint actually carries
// their weights. Real Gemma 4 GGUFs do; synthetic toy-model tests do not,
Expand All @@ -197,6 +200,7 @@ internal fun <T : DType, V> applyWeightsToNetworkNonReified(
val model = gemmaNetwork<T, V>(
weights.metadata,
dtype,
maxInferenceLen = maxInferenceLen ?: minOf(weights.metadata.contextLength, 4096),
qkNorm = hasQKNorm,
sandwichNorms = hasSandwichNorms,
layerOutputScale = hasLayerOutputScale,
Expand Down
Loading