From 888fcba38dec004dd0071faa7a938fcb8ef57b95 Mon Sep 17 00:00:00 2001
From: Michal Harakal
Date: Mon, 25 May 2026 12:46:00 +0200
Subject: [PATCH 1/2] update to SKaiNET 0.25.0, expose DTypePolicy on loaders, lock 3 reference smoke models

Aligns this repo with the engine's 0.25.0 RFC implementation of the
hybrid adaptive DSL with optional dtype constraints (engine Issue #615,
PR #616). The integration is additive on purpose: every change is opt-in
and the existing test surface keeps the same behaviour.

- catalog: skainet 0.23.1 -> 0.25.0; gradle.properties
  VERSION_NAME=0.25.0. The repo-internal :llm-bom continues to re-export
  sk.ainet:skainet-bom, so consumers pick up the bump transparently. The
  catalog-only versionless-aliases refactor (drop version.ref, add
  platform(:llm-bom) in every consumer) is deferred to a follow-up to
  keep this PR scoped.

- llm-core: new DTypePolicyValidation helper that mirrors the engine's
  StreamingGgufParametersLoader.validatePolicy() /
  SafeTensorsParametersLoader.mapPolicyToBf16() rejection rules.
  Fail-fast on Require(BF16) for GGUF-only paths and on any other
  non-FP32 Require(...) target everywhere; Any/Prefer/OneOf always pass.
  decoderTransformerNetwork(...) gains an inline dtypePolicy parameter
  (no ABI surface because of `reified T`).

- *NetworkLoader (Llama, Qwen, Gemma, Apertus, Voxtral): add a
  binary-compatible withDtypePolicy(policy) builder + dtypePolicy getter
  on each loader. Existing companion factories are untouched so the
  binary-compatibility-validator baselines only gain *additive* lines.
  Per-tensor enforcement inside DecoderGgufWeightLoader /
  DecoderSafeTensorsLoader is the next milestone (CHANGELOG documents
  this explicitly).

- 3 reference smoke tests under @Tag("smoke-reference"):
  - Qwen3-1.7B Q8 GGUF (kllama) - exercises new 0.25.0 Q8_0 matmul kernel
  - Gemma-4 E2B SafeTensors (kgemma) - sliding-window attention +
    per-layer KV sharing; complex arch as the second pillar
  - BERT + LEAF SafeTensors (llm-test-java) - encoder smoke through the
    Java KBertJava consumer surface with a paraphrase-cosine sanity check
  Each test self-skips via JUnit Assumptions when the model artifact is
  not reachable through the standard env-var / ~/.lmstudio /
  ~/.cache/huggingface fallback chain, so CI without models stays green.

- root build.gradle.kts honors -PsmokeReference (parallel to the existing
  -PincludeIntegration filter) to run only the smoke-reference tier:
  ./gradlew test -PsmokeReference -PincludeIntegration.

- tests/smoke/smoke-models.json: mark the three reference picks with
  "reference": true so the shell smoke harness and the JVM smoke tier
  point at the same artifacts (smoke-test.sh consumption is follow-up).

- All affected .api baselines regenerated via apiDump and validated.
  ./gradlew allTests passes locally (9m49s on this branch).

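The rejection rules from the llm-core bullet above can be tried without the
engine on the classpath. The sketch below is a minimal approximation under
stated assumptions, not the code in this patch: `DType`, `DTypePolicy`,
`validate` and `accepts` are simplified stand-ins declared locally (the real
types live in sk.ainet.lang.types), `INT8` only stands for "some other dtype",
and the error text is illustrative.

```kotlin
// Stand-in types: simplified, hypothetical models of the engine's DType and
// DTypePolicy, declared locally so the sketch compiles on its own.
enum class DType { FP32, BF16, FP16, INT8 }

sealed interface DTypePolicy {
    object Any : DTypePolicy
    data class Require(val target: DType) : DTypePolicy
    data class Prefer(val target: DType) : DTypePolicy
    data class OneOf(val targets: Set<DType>) : DTypePolicy
}

// Any / Prefer / OneOf always pass. Require(FP32) passes, Require(BF16) passes
// only when the loader is SafeTensors-backed, every other target is rejected.
fun validate(policy: DTypePolicy, loaderName: String, allowBf16Require: Boolean) {
    if (policy !is DTypePolicy.Require) return
    val satisfiable = when (policy.target) {
        DType.FP32 -> true
        DType.BF16 -> allowBf16Require
        else -> false
    }
    // require() throws IllegalArgumentException, so callers fail fast.
    require(satisfiable) { "$loaderName: Require(${policy.target}) is not supported" }
}

// Demo helper: true when validation accepts the policy.
fun accepts(policy: DTypePolicy, allowBf16Require: Boolean): Boolean =
    runCatching { validate(policy, "DemoLoader", allowBf16Require) }.isSuccess

fun main() {
    println(accepts(DTypePolicy.Prefer(DType.BF16), allowBf16Require = false))  // GGUF, soft preference
    println(accepts(DTypePolicy.Require(DType.BF16), allowBf16Require = true))  // SafeTensors, hard requirement
    println(accepts(DTypePolicy.Require(DType.BF16), allowBf16Require = false)) // GGUF, hard requirement
    println(accepts(DTypePolicy.Require(DType.FP16), allowBf16Require = true))  // FP16 on any loader
}
```

The first two calls are accepted and the last two are rejected, which is the
split between soft (`Prefer`) and hard (`Require`) constraints that the
loaders enforce in `withDtypePolicy`.
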
Co-Authored-By: Claude Opus 4.7 (1M context) --- CHANGELOG.md | 104 +++++++++++++++ build.gradle.kts | 7 +- gradle.properties | 2 +- gradle/libs.versions.toml | 2 +- llm-core/api/android/llm-core.api | 5 + llm-core/api/jvm/llm-core.api | 6 + .../ainet/apps/llm/DTypePolicyValidation.kt | 81 ++++++++++++ .../dsl/decoder/DecoderTransformerNetwork.kt | 9 ++ llm-inference/apertus/api/android/apertus.api | 2 + llm-inference/apertus/api/jvm/apertus.api | 2 + .../models/apertus/ApertusNetworkLoader.kt | 14 +++ llm-inference/gemma/api/jvm/gemma.api | 2 + .../ainet/models/gemma/GemmaNetworkLoader.kt | 14 +++ llm-inference/llama/api/android/llama.api | 2 + llm-inference/llama/api/jvm/llama.api | 2 + .../ainet/models/llama/LlamaNetworkLoader.kt | 24 ++++ llm-inference/qwen/api/jvm/qwen.api | 2 + .../sk/ainet/models/qwen/QwenNetworkLoader.kt | 14 +++ llm-inference/voxtral/api/jvm/voxtral.api | 2 + .../models/voxtral/VoxtralNetworkLoader.kt | 14 +++ .../apps/kgemma/Gemma4ReferenceSmokeTest.kt | 94 ++++++++++++++ .../apps/kllama/Qwen3ReferenceSmokeTest.kt | 114 +++++++++++++++++ .../java/BertLeafReferenceSmokeTest.java | 119 ++++++++++++++++++ tests/smoke/smoke-models.json | 7 +- 24 files changed, 639 insertions(+), 5 deletions(-) create mode 100644 llm-core/src/commonMain/kotlin/sk/ainet/apps/llm/DTypePolicyValidation.kt create mode 100644 llm-runtime/kgemma/src/jvmTest/kotlin/sk/ainet/apps/kgemma/Gemma4ReferenceSmokeTest.kt create mode 100644 llm-runtime/kllama/src/jvmTest/kotlin/sk/ainet/apps/kllama/Qwen3ReferenceSmokeTest.kt create mode 100644 llm-test/llm-test-java/src/test/java/sk/ainet/transformers/java/BertLeafReferenceSmokeTest.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f9041f8b..b06df22c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,110 @@ version line is kept in lock-step with the underlying SKaiNET engine The format roughly follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0.html). +## [0.25.0] - 2026-05-25 + +Version-aligned with **SKaiNET 0.25.0**. Skips 0.24.x: SKaiNET-transformers has +been on 0.23.4 since 2026-05-08; the engine bumped 0.23.1 → 0.25.0 in the same +window without a tagged 0.24.x release on either side. + +### Added + +- **`DTypePolicy` accepted on every `*NetworkLoader` through a new `withDtypePolicy(policy)` + builder.** SKaiNET 0.25.0 introduced the + [hybrid adaptive DSL with optional dtype constraints RFC](https://github.com/SKaiNET-developers/SKaiNET/pull/616) + with a sealed `DTypePolicy` type (`Any | Require | Prefer | OneOf`) carrying + execution-side dtype intent through the loader / DAG / resolution pipeline. + `LlamaNetworkLoader`, `QwenNetworkLoader`, `GemmaNetworkLoader`, + `ApertusNetworkLoader`, and `VoxtralNetworkLoader` now each expose + `withDtypePolicy(policy)` and a `dtypePolicy` getter (default `DTypePolicy.Any`); the companion + factories are unchanged. The policy is validated eagerly against the loader's + output dtypes when `withDtypePolicy` is called (via the new + `sk.ainet.apps.llm.DTypePolicyValidation` helper), matching the SKaiNET + 0.25.0 `StreamingGgufParametersLoader.validatePolicy()` / + `SafeTensorsParametersLoader.mapPolicyToBf16()` semantics: + - GGUF-backed loaders accept `Any` / `Prefer` / `OneOf` / `Require(FP32)` and + reject `Require(BF16)` / `Require(FP16)` / `Require(other)` with the same + error messages as SKaiNET's own GGUF loader. + - SafeTensors-backed loaders additionally accept `Require(BF16)` (matching the + `KEEP_NATIVE` precedent that `Bf16LoadPolicy.toDTypePolicy()` is built on + upstream). + - All loaders behave exactly as before with the default + `Any` value, so the bump is fully backward compatible. +- **`decoderTransformerNetwork(dtypePolicy = …)`** parameter on the shared + decoder-only builder in `llm-core`: a declarative slot for the top-level + block policy. 
Forward-compat surface; not yet propagated into the underlying + `DagBuilder.op(..., dtypePolicy = …)` slot SKaiNET 0.25.0 introduced + (`HybridTransformerBlock.compile()` will read this in a follow-up). Setting + a non-`Any` value compiles today and starts taking effect when the + compile-step plumbing lands — no API change at consumers. +- **Three reference smoke tests with `@Tag("smoke-reference")`.** The new + smoke tier exists alongside the existing `@Tag("integration")` filter and + pins the three architectures we always want to run end-to-end: + - `llm-runtime/kllama` — `Qwen3ReferenceSmokeTest` (Qwen3-1.7B Q8_0 GGUF; + exercises the new SKaiNET 0.25.0 `Q8_0MatmulKernel` end-to-end + + Qwen's `RoPEMode.SPLIT_HALF` + QK-Norm). + - `llm-runtime/kgemma` — `Gemma4ReferenceSmokeTest` (Gemma-4 E2B SafeTensors; + sliding-window attention + per-layer KV sharing). + - `llm-test/llm-test-java` — `BertLeafReferenceSmokeTest` (MongoDB + `mdbr-leaf-ir` SafeTensors via the Java `KBertJava` consumer surface, + with a cosine-similarity sanity check on paraphrase embeddings). + Run with `./gradlew test -PsmokeReference -PincludeIntegration`. Each test + self-skips via JUnit `Assumptions.assumeTrue` when the model artifact isn't + resolvable through the standard `~/.lmstudio/models/` / + `~/.cache/huggingface/hub/` / env-var fallback chain, so CI without model + files stays green. + +### Changed + +- **`gradle/libs.versions.toml` `skainet → 0.25.0`.** Downstream consumers + already get the upstream SKaiNET BOM transparently via `:llm-bom` + (`api(platform("sk.ainet:skainet-bom:${libs.versions.skainet.get()}"))`, + unchanged since 0.23.4 when the BOM auto-discovery convention plugin + landed) — no per-consumer migration needed. +- **`gradle.properties` `VERSION_NAME=0.25.0`.** Lock-step with the engine. +- **`tasks.withType().configureEach { ... }`** at the root build now + honors a `-PsmokeReference` project property — symmetric to the existing + `-PincludeIntegration`. 
When set, JUnit Platform is filtered to + `@Tag("smoke-reference")` so the smoke tier runs in isolation + (`./gradlew test -PsmokeReference -PincludeIntegration`). +- **`tests/smoke/smoke-models.json`** gains a `"reference": true` flag on + the three reference entries (`Qwen3-1.7B-Q8`, `Gemma4-E4B-GGUF`, + `MongoDB-mdbr-leaf-ir`) so the shell smoke harness and the JVM smoke + tier point at the same artifacts. The `smoke-test.sh` script does not + yet consume the flag — follow-up. + +### Deferred + +These pieces of the dtype-policy RFC integration are intentionally not in +this release. The threading surface accepts the API so consumers can +compile against the eventual implementation; the actual behavioural +changes land in follow-up PRs. + +- **Per-DSL-layer dtype-policy parameters** on `TransformerDsl.kt` factories + (`embedding` / `rmsNorm` / `multiHeadAttention` / `swiGluFFN` / `geGluFFN` + / `xielu`). The DSL is module-based and would need a `Module`-level + metadata side-map to carry the policy down to compile time; landing + that without a consumer that reads it would add maintenance surface + for no behavioural value today. +- **`HybridTransformerBlock.compile()` honoring the policy on + `DagBuilder.op(..., dtypePolicy = …)` per the W6 SKaiNET PR.** Blocked + on the side-map above. +- **`DecoderSafeTensorsLoader` / `DecoderGgufWeightLoader` actually + observing the policy at per-tensor load.** Both currently dequant + BF16 → FP32 unconditionally. The next step is to route them through + SKaiNET's `SafeTensorsParametersLoader.withPolicy(...)` / + `StreamingGgufParametersLoader.withPolicy(...)` so `Require(BF16)` + actually keeps native — a contained refactor with a clear contract. +- **BOM-only versionless aliases in `libs.versions.toml`.** Currently + every `skainet-*` alias still uses `version.ref = "skainet"` because + the single-source bump is the lower-risk path during the 0.25.0 + drop. 
Stripping `version.ref` and adding `platform(project(":llm-bom"))` + to each consumer's `commonMain.dependencies` is a separate + catalog-only PR. +- **A `smoke-reference` GitHub Actions job.** The Gradle filter is in + place; the CI workflow that triggers it (with self-hosted model cache) + lands separately. + ## [0.23.4] — 2026-05-08 Transformers-only release; no SKaiNET engine bump in this version. The diff --git a/build.gradle.kts b/build.gradle.kts index 535e51cd..537e495e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -48,7 +48,12 @@ subprojects { tasks.withType().configureEach { maxHeapSize = "8192m" useJUnitPlatform { - if (!project.hasProperty("includeIntegration")) { + // -PsmokeReference: narrow to the 3 reference smoke tests + // (Qwen3 / Gemma-4 / BERT+LEAF). Implies @Tag("smoke-reference"). + // Pair with -PincludeIntegration when the models are present. + if (project.hasProperty("smokeReference")) { + includeTags("smoke-reference") + } else if (!project.hasProperty("includeIntegration")) { excludeTags("integration") } } diff --git a/gradle.properties b/gradle.properties index 124d902a..678bd2ab 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,5 @@ GROUP=sk.ainet.transformers -VERSION_NAME=0.23.4 +VERSION_NAME=0.25.0 POM_DESCRIPTION=SKaiNET-transformers diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c2fd0ba5..5f63e39d 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,5 +1,5 @@ [versions] -skainet = "0.23.1" +skainet = "0.25.0" agp = "9.2.0" jacksonDatabind = "2.21.3" jsonSchemaValidator = "3.0.2" diff --git a/llm-core/api/android/llm-core.api b/llm-core/api/android/llm-core.api index 69afec0f..abe64a4f 100644 --- a/llm-core/api/android/llm-core.api +++ b/llm-core/api/android/llm-core.api @@ -22,6 +22,11 @@ public abstract class sk/ainet/apps/llm/DecoderRuntime : sk/ainet/apps/llm/Infer protected final fun setPosition (I)V } +public final class 
sk/ainet/apps/llm/DTypePolicyValidation { + public static final field INSTANCE Lsk/ainet/apps/llm/DTypePolicyValidation; + public final fun validate (Lsk/ainet/lang/types/DTypePolicy;Ljava/lang/String;Z)V +} + public final class sk/ainet/apps/llm/GenerateExtensionsKt { public static final fun generate (Lsk/ainet/apps/llm/InferenceRuntime;[IIFILkotlin/random/Random;Lkotlin/jvm/functions/Function1;)V public static synthetic fun generate$default (Lsk/ainet/apps/llm/InferenceRuntime;[IIFILkotlin/random/Random;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V diff --git a/llm-core/api/jvm/llm-core.api b/llm-core/api/jvm/llm-core.api index 947716ab..5d72b5a3 100644 --- a/llm-core/api/jvm/llm-core.api +++ b/llm-core/api/jvm/llm-core.api @@ -1,3 +1,8 @@ +public final class sk/ainet/apps/llm/DTypePolicyValidation { + public static final field INSTANCE Lsk/ainet/apps/llm/DTypePolicyValidation; + public final fun validate (Lsk/ainet/lang/types/DTypePolicy;Ljava/lang/String;Z)V +} + public abstract class sk/ainet/apps/llm/DecoderRuntime : sk/ainet/apps/llm/InferenceRuntime { public fun <init> ()V public fun <init> (Lkotlin/random/Random;)V @@ -704,6 +709,7 @@ public final class sk/ainet/lang/nn/transformer/LinearProjectionKt { public final class sk/ainet/lang/nn/transformer/MultiHeadAttention : sk/ainet/lang/nn/Module, sk/ainet/lang/nn/topology/ModuleParameters { public fun <init> (IIIZZZDLjava/lang/Float;ZZLjava/lang/String;Lsk/ainet/lang/nn/transformer/RoPE;Lsk/ainet/lang/nn/transformer/KVCache;Ljava/lang/Integer;Ljava/lang/Integer;)V public synthetic fun <init> (IIIZZZDLjava/lang/Float;ZZLjava/lang/String;Lsk/ainet/lang/nn/transformer/RoPE;Lsk/ainet/lang/nn/transformer/KVCache;Ljava/lang/Integer;Ljava/lang/Integer;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun forward (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/context/ExecutionContext;)Lsk/ainet/lang/tensor/Tensor; public final fun getAttentionScale ()Ljava/lang/Float; public final fun getBias ()Z 
public final fun getCausal ()Z diff --git a/llm-core/src/commonMain/kotlin/sk/ainet/apps/llm/DTypePolicyValidation.kt b/llm-core/src/commonMain/kotlin/sk/ainet/apps/llm/DTypePolicyValidation.kt new file mode 100644 index 00000000..fc8cffdc --- /dev/null +++ b/llm-core/src/commonMain/kotlin/sk/ainet/apps/llm/DTypePolicyValidation.kt @@ -0,0 +1,81 @@ +package sk.ainet.apps.llm + +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP16 +import sk.ainet.lang.types.FP32 + +/** + * Eager-validation helper for `DTypePolicy` carried by SKaiNET-transformers loaders. + * + * SKaiNET 0.25.0 introduced `DTypePolicy` (`Any | Require | Prefer | OneOf`) as the + * generalised execution-side dtype constraint surface. Its own loaders + * (`StreamingGgufParametersLoader.withPolicy`, `SafeTensorsParametersLoader.withPolicy`) + * validate the policy at construction so callers fail fast on impossible + * requirements. + * + * The transformer-repo loaders (`LlamaNetworkLoader`, `QwenNetworkLoader`, …) ship + * their own weight-loading chain on top of `DecoderGgufWeightLoader` / + * `DecoderSafeTensorsLoader`. Those chains do not yet plumb `DTypePolicy` through + * to the underlying tensor producers — that's a separate follow-up. In the + * meantime, accepting the policy on the public surface lets consumers express + * intent today, and this validator ensures we reject impossible requirements at + * the same boundary SKaiNET's own loaders do. + * + * Today the transformer-repo loaders only produce FP32 (after Q4/Q8/BF16/F16 + * dequant on the SafeTensors path; native quantization preservation on the GGUF + * path). That matches the SKaiNET 0.25.0 `StreamingGgufParametersLoader` + * validator. The BF16 KEEP_NATIVE SafeTensors path (`Require(BF16)`) is allowed + * here even though the transformer-repo `DecoderSafeTensorsLoader` does not yet + * honor it — when wired through, no API change is needed. 
+ * + * Throws [IllegalArgumentException] on `Require(target)` for targets we cannot + * produce. `Any`, `Prefer`, and `OneOf` always pass. + */ +public object DTypePolicyValidation { + + /** + * Validates a [DTypePolicy] for the transformer-repo loader chain. + * + * @param policy the policy supplied by the caller + * @param loaderName loader name for error messages (e.g. `"LlamaNetworkLoader.fromGguf"`) + * @param allowBf16Require whether `Require(BF16)` is acceptable. SafeTensors-backed + * loaders set this to `true` (matches SKaiNET's `SafeTensorsParametersLoader`); GGUF-only + * loaders set it to `false` (matches SKaiNET's `StreamingGgufParametersLoader`). + */ + public fun validate( + policy: DTypePolicy, + loaderName: String, + allowBf16Require: Boolean, + ) { + when (policy) { + DTypePolicy.Any -> Unit + is DTypePolicy.Prefer -> Unit + is DTypePolicy.OneOf -> Unit + is DTypePolicy.Require -> validateRequire(policy.target, loaderName, allowBf16Require) + } + } + + private fun validateRequire(target: DType, loaderName: String, allowBf16Require: Boolean) { + when (target) { + FP32 -> Unit + BF16 -> if (!allowBf16Require) { + throw IllegalArgumentException( + "$loaderName: Require(BF16) is not supported by the GGUF loader chain — " + + "GGUF BF16 sources are dequanted to FP32 today (no KEEP_NATIVE GGUF path " + + "yet). Use Any or Prefer(BF16) to accept the dequant fallback." + ) + } + FP16 -> throw IllegalArgumentException( + "$loaderName: Require(FP16) is not supported — the loader chain dequants F16 to " + + "FP32 (no Fp16DenseTensorData backing yet). Use Any or Prefer(FP16)." + ) + else -> throw IllegalArgumentException( + "$loaderName: Require(${target.name}) is not satisfiable — the transformer-repo " + + "loader chain produces FP32 (optionally BF16 on the SafeTensors KEEP_NATIVE " + + "path). It cannot fabricate ${target.name} from arbitrary sources." 
+ ) + } + } +} diff --git a/llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/decoder/DecoderTransformerNetwork.kt b/llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/decoder/DecoderTransformerNetwork.kt index 4ee2d4f9..341e6c42 100644 --- a/llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/decoder/DecoderTransformerNetwork.kt +++ b/llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/decoder/DecoderTransformerNetwork.kt @@ -14,6 +14,7 @@ import sk.ainet.lang.nn.dsl.swiGluFFN import sk.ainet.lang.nn.transformer.RoPEMode import sk.ainet.lang.nn.transformer.VoidDense import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy /** * Architecture-neutral decoder-only transformer body builder. @@ -46,6 +47,13 @@ import sk.ainet.lang.types.DType * contexts; compounds across positions). * @param maxInferenceLen sequence length used to size the KV cache and RoPE * tables. Capped at min(metadata.contextLength, 4096) by default. + * @param dtypePolicy declarative dtype constraint for this block. Currently a + * forward-compat parameter — the DSL accepts the value at this boundary but + * does not yet propagate it into the underlying `DagBuilder.op(..., dtypePolicy = …)` + * slot that SKaiNET 0.25.0 introduced. Set to a non-`Any` policy to express + * intent now; full per-op resolution lands when [HybridTransformerBlock]'s + * compile step is taught to consume per-module dtype metadata. Default + * [DTypePolicy.Any] preserves the current adaptive behaviour. 
 */ public inline fun decoderTransformerNetwork( metadata: DecoderModelMetadata, @@ -55,6 +63,7 @@ public inline fun decoderTransformerNetwork( qkNormUnitOffset: Boolean = false, ropeMode: RoPEMode = RoPEMode.INTERLEAVED, maxInferenceLen: Int = minOf(metadata.contextLength, 4096), + @Suppress("UNUSED_PARAMETER") dtypePolicy: DTypePolicy = DTypePolicy.Any, ): Module { val dim = metadata.embeddingLength val nHeads = metadata.headCount diff --git a/llm-inference/apertus/api/android/apertus.api b/llm-inference/apertus/api/android/apertus.api index aed125ea..5129a2ee 100644 --- a/llm-inference/apertus/api/android/apertus.api +++ b/llm-inference/apertus/api/android/apertus.api @@ -107,7 +107,9 @@ public final class sk/ainet/models/apertus/ApertusNetworkLoader { public fun <init> (Lsk/ainet/models/apertus/ApertusNetworkLoader$WeightsProvider;Z)V public synthetic fun <init> (Lsk/ainet/models/apertus/ApertusNetworkLoader$WeightsProvider;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getDebug ()Z + public final fun getDtypePolicy ()Lsk/ainet/lang/types/DTypePolicy; public final fun getWeightsProvider ()Lsk/ainet/models/apertus/ApertusNetworkLoader$WeightsProvider; + public final fun withDtypePolicy (Lsk/ainet/lang/types/DTypePolicy;)Lsk/ainet/models/apertus/ApertusNetworkLoader; } public final class sk/ainet/models/apertus/ApertusNetworkLoader$Companion { diff --git a/llm-inference/apertus/api/jvm/apertus.api b/llm-inference/apertus/api/jvm/apertus.api index 59b4770f..47fe1029 100644 --- a/llm-inference/apertus/api/jvm/apertus.api +++ b/llm-inference/apertus/api/jvm/apertus.api @@ -91,7 +91,9 @@ public final class sk/ainet/models/apertus/ApertusNetworkLoader { public fun <init> (Lsk/ainet/models/apertus/ApertusNetworkLoader$WeightsProvider;Z)V public synthetic fun <init> (Lsk/ainet/models/apertus/ApertusNetworkLoader$WeightsProvider;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getDebug ()Z + public final fun getDtypePolicy ()Lsk/ainet/lang/types/DTypePolicy; 
public final fun getWeightsProvider ()Lsk/ainet/models/apertus/ApertusNetworkLoader$WeightsProvider; + public final fun withDtypePolicy (Lsk/ainet/lang/types/DTypePolicy;)Lsk/ainet/models/apertus/ApertusNetworkLoader; } public final class sk/ainet/models/apertus/ApertusNetworkLoader$Companion { diff --git a/llm-inference/apertus/src/commonMain/kotlin/sk/ainet/models/apertus/ApertusNetworkLoader.kt b/llm-inference/apertus/src/commonMain/kotlin/sk/ainet/models/apertus/ApertusNetworkLoader.kt index b03cfccf..928e1ef4 100644 --- a/llm-inference/apertus/src/commonMain/kotlin/sk/ainet/models/apertus/ApertusNetworkLoader.kt +++ b/llm-inference/apertus/src/commonMain/kotlin/sk/ainet/models/apertus/ApertusNetworkLoader.kt @@ -1,6 +1,7 @@ package sk.ainet.models.apertus import kotlinx.io.Source +import sk.ainet.apps.llm.DTypePolicyValidation import sk.ainet.context.ExecutionContext import sk.ainet.io.RandomAccessSource import sk.ainet.io.model.QuantPolicy @@ -11,6 +12,7 @@ import sk.ainet.io.weights.WeightTensor import sk.ainet.lang.nn.Module import sk.ainet.lang.tensor.Shape import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy /** * End-to-end loader that builds an `apertusNetwork()` module and populates it @@ -32,6 +34,18 @@ public class ApertusNetworkLoader @PublishedApi internal constructor( @PublishedApi internal val weightsProvider: WeightsProvider, @PublishedApi internal val debug: Boolean = false ) { + /** See [sk.ainet.models.llama.LlamaNetworkLoader.dtypePolicy]. */ + public var dtypePolicy: DTypePolicy = DTypePolicy.Any + private set + + /** See [sk.ainet.models.llama.LlamaNetworkLoader.withDtypePolicy]. 
*/ + public fun withDtypePolicy(policy: DTypePolicy): ApertusNetworkLoader { + val allowBf16 = weightsProvider is WeightsProvider.SafeTensorsSingle + DTypePolicyValidation.validate(policy, "ApertusNetworkLoader.withDtypePolicy", allowBf16Require = allowBf16) + this.dtypePolicy = policy + return this + } + @PublishedApi internal sealed interface WeightsProvider { data class GgufSource( diff --git a/llm-inference/gemma/api/jvm/gemma.api b/llm-inference/gemma/api/jvm/gemma.api index 83302d75..57fcd67f 100644 --- a/llm-inference/gemma/api/jvm/gemma.api +++ b/llm-inference/gemma/api/jvm/gemma.api @@ -794,7 +794,9 @@ public final class sk/ainet/models/gemma/GemmaNetworkLoader { public fun <init> (Lsk/ainet/models/gemma/GemmaNetworkLoader$WeightsProvider;Z)V public synthetic fun <init> (Lsk/ainet/models/gemma/GemmaNetworkLoader$WeightsProvider;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getDebug ()Z + public final fun getDtypePolicy ()Lsk/ainet/lang/types/DTypePolicy; public final fun getWeightsProvider ()Lsk/ainet/models/gemma/GemmaNetworkLoader$WeightsProvider; + public final fun withDtypePolicy (Lsk/ainet/lang/types/DTypePolicy;)Lsk/ainet/models/gemma/GemmaNetworkLoader; } public final class sk/ainet/models/gemma/GemmaNetworkLoader$Companion { diff --git a/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt b/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt index 8604a5b1..f73b3ac8 100644 --- a/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt +++ b/llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaNetworkLoader.kt @@ -1,6 +1,7 @@ package sk.ainet.models.gemma import kotlinx.io.Source +import sk.ainet.apps.llm.DTypePolicyValidation import sk.ainet.apps.llm.weights.LlamaGGUFNameResolver import sk.ainet.context.ExecutionContext import sk.ainet.io.RandomAccessSource @@ -10,6 +11,7 @@ import sk.ainet.io.weights.WeightMapper import 
sk.ainet.io.weights.WeightTensor import sk.ainet.lang.nn.Module import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy /** * End-to-end loader that builds a [gemmaNetwork] module and populates it @@ -33,6 +35,18 @@ public class GemmaNetworkLoader @PublishedApi internal constructor( @PublishedApi internal val weightsProvider: WeightsProvider, @PublishedApi internal val debug: Boolean = false ) { + /** See [sk.ainet.models.llama.LlamaNetworkLoader.dtypePolicy]. */ + public var dtypePolicy: DTypePolicy = DTypePolicy.Any + private set + + /** See [sk.ainet.models.llama.LlamaNetworkLoader.withDtypePolicy]. */ + public fun withDtypePolicy(policy: DTypePolicy): GemmaNetworkLoader { + val allowBf16 = weightsProvider is WeightsProvider.SafeTensorsIndex + DTypePolicyValidation.validate(policy, "GemmaNetworkLoader.withDtypePolicy", allowBf16Require = allowBf16) + this.dtypePolicy = policy + return this + } + @PublishedApi internal sealed interface WeightsProvider { data class GgufSource( diff --git a/llm-inference/llama/api/android/llama.api b/llm-inference/llama/api/android/llama.api index 17660003..0628ffa6 100644 --- a/llm-inference/llama/api/android/llama.api +++ b/llm-inference/llama/api/android/llama.api @@ -131,7 +131,9 @@ public final class sk/ainet/models/llama/LlamaNetworkLoader { public fun <init> (Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider;Z)V public synthetic fun <init> (Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getDebug ()Z + public final fun getDtypePolicy ()Lsk/ainet/lang/types/DTypePolicy; public final fun getWeightsProvider ()Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider; + public final fun withDtypePolicy (Lsk/ainet/lang/types/DTypePolicy;)Lsk/ainet/models/llama/LlamaNetworkLoader; } public final class sk/ainet/models/llama/LlamaNetworkLoader$Companion { diff --git a/llm-inference/llama/api/jvm/llama.api b/llm-inference/llama/api/jvm/llama.api 
index 925fb710..b71163fc 100644 --- a/llm-inference/llama/api/jvm/llama.api +++ b/llm-inference/llama/api/jvm/llama.api @@ -176,7 +176,9 @@ public final class sk/ainet/models/llama/LlamaNetworkLoader { public fun <init> (Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider;Z)V public synthetic fun <init> (Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getDebug ()Z + public final fun getDtypePolicy ()Lsk/ainet/lang/types/DTypePolicy; public final fun getWeightsProvider ()Lsk/ainet/models/llama/LlamaNetworkLoader$WeightsProvider; + public final fun withDtypePolicy (Lsk/ainet/lang/types/DTypePolicy;)Lsk/ainet/models/llama/LlamaNetworkLoader; } public final class sk/ainet/models/llama/LlamaNetworkLoader$Companion { diff --git a/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaNetworkLoader.kt b/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaNetworkLoader.kt index 7c6dcde0..646e639b 100644 --- a/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaNetworkLoader.kt +++ b/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaNetworkLoader.kt @@ -1,6 +1,7 @@ package sk.ainet.models.llama import kotlinx.io.Source +import sk.ainet.apps.llm.DTypePolicyValidation import sk.ainet.context.ExecutionContext import sk.ainet.io.RandomAccessSource import sk.ainet.io.model.QuantPolicy @@ -10,6 +11,7 @@ import sk.ainet.io.weights.WeightMapper import sk.ainet.io.weights.WeightTensor import sk.ainet.lang.nn.Module import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy import kotlin.jvm.JvmName /** @@ -64,6 +66,28 @@ public class LlamaNetworkLoader @PublishedApi internal constructor( ) : WeightsProvider } + /** + * Declarative dtype policy attached via [withDtypePolicy]. 
SKaiNET 0.25.0 + * `DTypePolicy` is a forward-compat hook here: the value is validated + * eagerly but the underlying `DecoderGgufWeightLoader` / + * `DecoderSafeTensorsLoader` chain does not yet honor it per-tensor. + * Default [DTypePolicy.Any] preserves the adaptive behaviour. + */ + public var dtypePolicy: DTypePolicy = DTypePolicy.Any + private set + + /** + * Attach a [DTypePolicy] to this loader. Returns `this` for chaining. + * Validates eagerly so impossible requirements fail at the boundary, + * not deep inside the load loop. + */ + public fun withDtypePolicy(policy: DTypePolicy): LlamaNetworkLoader { + val allowBf16 = weightsProvider is WeightsProvider.SafeTensors + DTypePolicyValidation.validate(policy, "LlamaNetworkLoader.withDtypePolicy", allowBf16Require = allowBf16) + this.dtypePolicy = policy + return this + } + public companion object { /** Load from a GGUF file via sequential Source (models under 2GB). */ public fun fromGguf( diff --git a/llm-inference/qwen/api/jvm/qwen.api b/llm-inference/qwen/api/jvm/qwen.api index 5c26e1b6..6f84ac18 100644 --- a/llm-inference/qwen/api/jvm/qwen.api +++ b/llm-inference/qwen/api/jvm/qwen.api @@ -30,7 +30,9 @@ public final class sk/ainet/models/qwen/QwenNetworkLoader { public fun <init> (Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider;Z)V public synthetic fun <init> (Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getDebug ()Z + public final fun getDtypePolicy ()Lsk/ainet/lang/types/DTypePolicy; public final fun getWeightsProvider ()Lsk/ainet/models/qwen/QwenNetworkLoader$WeightsProvider; + public final fun withDtypePolicy (Lsk/ainet/lang/types/DTypePolicy;)Lsk/ainet/models/qwen/QwenNetworkLoader; } public final class sk/ainet/models/qwen/QwenNetworkLoader$Companion { diff --git a/llm-inference/qwen/src/commonMain/kotlin/sk/ainet/models/qwen/QwenNetworkLoader.kt 
b/llm-inference/qwen/src/commonMain/kotlin/sk/ainet/models/qwen/QwenNetworkLoader.kt index ca55fbb0..8a93b17f 100644 --- a/llm-inference/qwen/src/commonMain/kotlin/sk/ainet/models/qwen/QwenNetworkLoader.kt +++ b/llm-inference/qwen/src/commonMain/kotlin/sk/ainet/models/qwen/QwenNetworkLoader.kt @@ -1,6 +1,7 @@ package sk.ainet.models.qwen import kotlinx.io.Source +import sk.ainet.apps.llm.DTypePolicyValidation import sk.ainet.context.ExecutionContext import sk.ainet.io.RandomAccessSource import sk.ainet.io.model.QuantPolicy @@ -10,6 +11,7 @@ import sk.ainet.io.weights.WeightMapper import sk.ainet.io.weights.WeightTensor import sk.ainet.lang.nn.Module import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy import sk.ainet.models.llama.LlamaModelMetadata import sk.ainet.models.llama.DecoderSafeTensorsLoader import sk.ainet.models.llama.DecoderGgufWeightLoader @@ -45,6 +47,18 @@ public class QwenNetworkLoader @PublishedApi internal constructor( @PublishedApi internal val weightsProvider: WeightsProvider, @PublishedApi internal val debug: Boolean = false ) { + /** See [LlamaNetworkLoader.dtypePolicy]. */ + public var dtypePolicy: DTypePolicy = DTypePolicy.Any + private set + + /** See [LlamaNetworkLoader.withDtypePolicy]. 
*/ + public fun withDtypePolicy(policy: DTypePolicy): QwenNetworkLoader { + val allowBf16 = weightsProvider is WeightsProvider.SafeTensors + DTypePolicyValidation.validate(policy, "QwenNetworkLoader.withDtypePolicy", allowBf16Require = allowBf16) + this.dtypePolicy = policy + return this + } + @PublishedApi internal sealed interface WeightsProvider { data class GgufSource( diff --git a/llm-inference/voxtral/api/jvm/voxtral.api b/llm-inference/voxtral/api/jvm/voxtral.api index 5506154d..56ee8bcf 100644 --- a/llm-inference/voxtral/api/jvm/voxtral.api +++ b/llm-inference/voxtral/api/jvm/voxtral.api @@ -199,7 +199,9 @@ public final class sk/ainet/models/voxtral/VoxtralNetworkLoader { public synthetic fun (Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun buildAcousticRuntime (Lsk/ainet/models/llama/DecoderGgufWeights;Lsk/ainet/lang/nn/Module;Lsk/ainet/models/llama/LlamaModelMetadata;Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;II)Lsk/ainet/models/voxtral/VoxtralAcousticRuntime; public final fun getDebug ()Z + public final fun getDtypePolicy ()Lsk/ainet/lang/types/DTypePolicy; public final fun getWeightsProvider ()Lsk/ainet/models/voxtral/VoxtralNetworkLoader$WeightsProvider; + public final fun withDtypePolicy (Lsk/ainet/lang/types/DTypePolicy;)Lsk/ainet/models/voxtral/VoxtralNetworkLoader; } public final class sk/ainet/models/voxtral/VoxtralNetworkLoader$Companion { diff --git a/llm-inference/voxtral/src/commonMain/kotlin/sk/ainet/models/voxtral/VoxtralNetworkLoader.kt b/llm-inference/voxtral/src/commonMain/kotlin/sk/ainet/models/voxtral/VoxtralNetworkLoader.kt index 9f0c1f17..31d33172 100644 --- a/llm-inference/voxtral/src/commonMain/kotlin/sk/ainet/models/voxtral/VoxtralNetworkLoader.kt +++ b/llm-inference/voxtral/src/commonMain/kotlin/sk/ainet/models/voxtral/VoxtralNetworkLoader.kt @@ -1,6 +1,7 @@ package sk.ainet.models.voxtral import kotlinx.io.Source +import 
sk.ainet.apps.llm.DTypePolicyValidation import sk.ainet.context.ExecutionContext import sk.ainet.io.RandomAccessSource import sk.ainet.io.model.QuantPolicy @@ -11,6 +12,7 @@ import sk.ainet.lang.nn.Module import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy import sk.ainet.models.llama.LlamaModelMetadata import sk.ainet.models.llama.DecoderSafeTensorsLoader import sk.ainet.models.llama.DecoderGgufWeightLoader @@ -72,6 +74,18 @@ public class VoxtralNetworkLoader @PublishedApi internal constructor( ) : WeightsProvider } + /** See [sk.ainet.models.llama.LlamaNetworkLoader.dtypePolicy]. */ + public var dtypePolicy: DTypePolicy = DTypePolicy.Any + private set + + /** See [sk.ainet.models.llama.LlamaNetworkLoader.withDtypePolicy]. */ + public fun withDtypePolicy(policy: DTypePolicy): VoxtralNetworkLoader { + val allowBf16 = weightsProvider is WeightsProvider.SafeTensors + DTypePolicyValidation.validate(policy, "VoxtralNetworkLoader.withDtypePolicy", allowBf16Require = allowBf16) + this.dtypePolicy = policy + return this + } + public companion object { /** Load from a GGUF file via sequential Source (models under 2GB). 
*/ public fun fromGguf( diff --git a/llm-runtime/kgemma/src/jvmTest/kotlin/sk/ainet/apps/kgemma/Gemma4ReferenceSmokeTest.kt b/llm-runtime/kgemma/src/jvmTest/kotlin/sk/ainet/apps/kgemma/Gemma4ReferenceSmokeTest.kt new file mode 100644 index 00000000..27d21aa4 --- /dev/null +++ b/llm-runtime/kgemma/src/jvmTest/kotlin/sk/ainet/apps/kgemma/Gemma4ReferenceSmokeTest.kt @@ -0,0 +1,94 @@ +package sk.ainet.apps.kgemma + +import java.nio.file.Path +import kotlin.io.path.exists +import kotlin.io.path.isDirectory +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Assumptions.assumeTrue +import sk.ainet.llm.api.ChatOptions +import sk.ainet.llm.api.ChatRequest +import sk.ainet.llm.api.FinishReason + +/** + * Reference smoke test #2 — Gemma-4 E2B SafeTensors (kgemma runner). + * + * Locked in as part of the SKaiNET 0.25.0 bump. Exercises: + * - SafeTensors load path. With SKaiNET 0.25.0 the BF16 SafeTensors path + * becomes policy-aware via `SafeTensorsParametersLoader.withPolicy` — + * this test stays on the default FP32-dequant adaptive default but pins + * the loader integration end-to-end. + * - Complex architecture: sliding-window attention + per-layer KV sharing + * (Gemma-4 specifics that the new policy-resolution pass must not break). + * - `Gemma4ChatModel.fromSafeTensors` → `InferenceRuntime` → `Tokenizer` → + * `Gemma4ChatTemplate` → `SkaiNetChatModel` wiring. + * + * Tagged `@Tag("smoke-reference")` so it runs only under + * `./gradlew test -PsmokeReference -PincludeIntegration`. Self-skips when + * `GEMMA4_E2B_SAFETENSORS_PATH` is not set. 
+ */
+@Tag("smoke-reference")
+@Tag("integration")
+class Gemma4ReferenceSmokeTest {
+
+    @Test
+    fun `Gemma-4 E2B SafeTensors produces non-empty greedy text`() {
+        val indexPath = locateCheckpoint()
+        assumeTrue(indexPath != null, "GEMMA4_E2B_SAFETENSORS_PATH not set or path missing.")
+
+        val maxTokens = probeMaxTokens(default = 32)
+
+        val model = Gemma4ChatModel.fromSafeTensors(
+            indexPath = indexPath!!.toString(),
+            options = ChatOptions(
+                temperature = 0f,
+                maxTokens = maxTokens,
+            ),
+        )
+        try {
+            val response = model.call(ChatRequest("Say hello in one short sentence."))
+
+            println("[smoke-reference] Gemma-4 modelId=${response.modelId}")
+            println("[smoke-reference] Gemma-4 finish=${response.generations.firstOrNull()?.finishReason}")
+            println("[smoke-reference] Gemma-4 text='${response.text.replace("\n", "\\n")}'")
+
+            assertTrue(response.text.isNotBlank(), "Gemma-4 chat returned blank text")
+            val finish = response.generations.firstOrNull()?.finishReason
+            assertNotNull(finish, "Gemma-4 missing finish reason")
+            assertTrue(
+                finish == FinishReason.STOP || finish == FinishReason.LENGTH ||
+                    finish == FinishReason.TOOL_CALL,
+                "Gemma-4 unexpected finish reason: $finish",
+            )
+            val usage = response.usage
+            assertNotNull(usage, "Gemma-4 usage should be reported")
+            assertTrue(usage.completionTokens > 0, "Gemma-4 expected at least one completion token")
+        } finally {
+            model.close()
+        }
+    }
+
+    private fun probeMaxTokens(default: Int): Int =
+        System.getenv("GEMMA4_SMOKE_MAX_TOKENS")?.trim()?.toIntOrNull()?.coerceAtLeast(1) ?: default
+
+    private fun locateCheckpoint(): Path? {
+        val raw = System.getenv("GEMMA4_E2B_SAFETENSORS_PATH")?.trim().orEmpty()
+        if (raw.isEmpty()) return null
+        val p = Path.of(raw)
+        if (!p.exists()) return null
+        return when {
+            p.isDirectory() -> {
+                val idx = p.resolve("model.safetensors.index.json")
+                val single = p.resolve("model.safetensors")
+                when {
+                    idx.exists() -> idx
+                    single.exists() -> single
+                    else -> null
+                }
+            }
+            else -> p
+        }
+    }
+}
diff --git a/llm-runtime/kllama/src/jvmTest/kotlin/sk/ainet/apps/kllama/Qwen3ReferenceSmokeTest.kt b/llm-runtime/kllama/src/jvmTest/kotlin/sk/ainet/apps/kllama/Qwen3ReferenceSmokeTest.kt
new file mode 100644
index 00000000..486f60ab
--- /dev/null
+++ b/llm-runtime/kllama/src/jvmTest/kotlin/sk/ainet/apps/kllama/Qwen3ReferenceSmokeTest.kt
@@ -0,0 +1,114 @@
+package sk.ainet.apps.kllama
+
+import java.io.File
+import java.nio.file.Path
+import kotlin.io.path.exists
+import kotlin.test.Test
+import kotlin.test.assertTrue
+import kotlin.time.measureTime
+import kotlinx.coroutines.runBlocking
+import org.junit.jupiter.api.Tag
+import org.junit.jupiter.api.Assumptions.assumeTrue
+import sk.ainet.apps.llm.OptimizedLLMMode
+import sk.ainet.apps.llm.OptimizedLLMRuntime
+import sk.ainet.context.DirectCpuExecutionContext
+import sk.ainet.io.JvmRandomAccessSource
+import sk.ainet.io.model.QuantPolicy
+import sk.ainet.lang.types.FP32
+import sk.ainet.models.qwen.QwenNetworkLoader
+
+/**
+ * Reference smoke test #1 — Qwen3-1.7B Q8_0 (kllama runner, GGUF).
+ *
+ * Locked in as part of the SKaiNET 0.25.0 bump. Exercises:
+ *  - GGUF Q8_0 load path through the new 0.25.0 `Q8_0MatmulKernel` (BF16/Q8_0
+ *    matmul kernels were re-platformed onto the `KernelRegistry` SPI in 0.25.0).
+ *  - Decoder-only LLM generation via [OptimizedLLMRuntime] DIRECT mode.
+ *  - Qwen-specific `RoPEMode.SPLIT_HALF` + QK-Norm code paths.
+ *
+ * Tagged `@Tag("smoke-reference")` so it runs only under
+ * `./gradlew test -PsmokeReference -PincludeIntegration`. Self-skips via
+ * [assumeTrue] when the model file is not present, so CI without the artifact
+ * stays green.
+ *
+ * Model fallback chain (first match wins):
+ *  1. `QWEN3_1B7_MODEL_PATH` env var
+ *  2. `~/.cache/standapp/models/Qwen3-1.7B-Q8_0.gguf`
+ *  3. Recursive scan under `~/.lmstudio/models/` for `Qwen3-1.7B-Q8_0.gguf`
+ *  4. Recursive scan under `~/.cache/huggingface/hub/` for the same filename
+ */
+@Tag("smoke-reference")
+@Tag("integration")
+class Qwen3ReferenceSmokeTest {
+
+    @Test
+    fun `Qwen3-1_7B Q8_0 generates non-empty greedy continuation`() {
+        val modelPath = locateModel()
+        assumeTrue(modelPath != null, "No Qwen3-1.7B-Q8_0 GGUF found — set QWEN3_1B7_MODEL_PATH.")
+
+        runBlocking {
+            val ctx = DirectCpuExecutionContext()
+
+            println("[smoke-reference] Loading Qwen3 tokenizer from $modelPath")
+            val tokenizer = JvmRandomAccessSource.open(modelPath.toString()).use { source ->
+                GGUFTokenizer.fromRandomAccessSource(source)
+            }
+
+            println("[smoke-reference] Loading Qwen3 model (Q8_0, DEQUANTIZE_TO_FP32)")
+            val model = QwenNetworkLoader.fromGguf(
+                randomAccessProvider = { JvmRandomAccessSource.open(modelPath.toString()) },
+                quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
+            ).load(ctx)
+
+            val runtime = OptimizedLLMRuntime(
+                model = model,
+                ctx = ctx,
+                mode = OptimizedLLMMode.DIRECT,
+                dtype = FP32::class,
+            )
+
+            val prompt = "What is the capital of France?"
+            val promptTokens = tokenizer.encode(prompt)
+            val steps = 16
+
+            val output = StringBuilder()
+            val elapsed = measureTime {
+                runtime.generate(prompt = promptTokens, steps = steps, temperature = 0.0f) { tokenId ->
+                    output.append(tokenizer.decode(tokenId))
+                }
+            }
+            val tokPerSec = steps.toDouble() / elapsed.inWholeMilliseconds * 1000
+            println("[smoke-reference] Qwen3 produced '${output}' (${"%.2f".format(tokPerSec)} tok/s)")
+
+            assertTrue(output.isNotBlank(), "Qwen3 produced blank text — generation pipeline broke")
+        }
+    }
+
+    private fun locateModel(): Path? {
+        val explicit = System.getenv("QWEN3_1B7_MODEL_PATH")?.trim().orEmpty()
+        if (explicit.isNotEmpty()) {
+            val p = Path.of(explicit)
+            return if (p.exists()) p else null
+        }
+        val home = System.getProperty("user.home")
+        val direct = listOf(
+            Path.of(home, ".cache", "standapp", "models", "Qwen3-1.7B-Q8_0.gguf"),
+        )
+        for (p in direct) if (p.exists()) return p
+
+        val searchRoots = listOf(
+            Path.of(home, ".lmstudio", "models"),
+            Path.of(home, ".cache", "huggingface", "hub"),
+        )
+        val targetName = "Qwen3-1.7B-Q8_0.gguf"
+        for (root in searchRoots) {
+            val rootFile = root.toFile()
+            if (!rootFile.isDirectory) continue
+            rootFile.walkTopDown()
+                .filter { it.isFile && it.name == targetName }
+                .firstOrNull()
+                ?.let { return it.toPath() }
+        }
+        return null
+    }
+}
diff --git a/llm-test/llm-test-java/src/test/java/sk/ainet/transformers/java/BertLeafReferenceSmokeTest.java b/llm-test/llm-test-java/src/test/java/sk/ainet/transformers/java/BertLeafReferenceSmokeTest.java
new file mode 100644
index 00000000..bc513557
--- /dev/null
+++ b/llm-test/llm-test-java/src/test/java/sk/ainet/transformers/java/BertLeafReferenceSmokeTest.java
@@ -0,0 +1,119 @@
+package sk.ainet.transformers.java;
+
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import sk.ainet.models.bert.java.KBertJava;
+import sk.ainet.models.bert.java.KBertSession;
+
+import java.io.File;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assumptions.assumeTrue;
+
+/**
+ * Reference smoke test #3 — BERT encoder + LEAF retrieval (kbert runner).
+ *
+ * Locked in as part of the SKaiNET 0.25.0 bump. Exercises:
+ *  - SafeTensors load path for BERT-family encoders via {@link KBertJava}.
+ *    With SKaiNET 0.25.0's new {@code DTypePolicy} surface the SafeTensors
+ *    loader is policy-aware; this test stays on the adaptive default
+ *    (FP32 dequant) and pins the consumer Java API end-to-end.
+ *  - Cosine-similarity sanity check: paraphrases of "reset my password" must
+ *    embed closer to each other than to an unrelated topic. Catches silent
+ *    embedding regressions the way the existing
+ *    {@link Llama3LeafSmokeTest} does for the LEAF-only path.
+ *  - Java consumer surface (no Kotlin glue) — proves the published Java API
+ *    contract still resolves against 0.25.0.
+ *
+ * Tagged {@code @Tag("smoke-reference")} so it runs only under
+ * {@code ./gradlew test -PsmokeReference -PincludeIntegration}. Self-skips
+ * via {@link org.junit.jupiter.api.Assumptions#assumeTrue} when the LEAF
+ * checkpoint is not available, so CI without artifacts stays green.
+ *
+ * Locator chain (first match wins):
+ *  1. {@code LEAF_MODEL_DIR} env var
+ *  2. {@code ~/.deliverance/MongoDB_mdbr-leaf-ir/}
+ *  3. {@code ~/.cache/huggingface/hub/models--MongoDB--mdbr-leaf-ir/snapshots//}
+ */
+@Tag("smoke-reference")
+@Tag("integration")
+class BertLeafReferenceSmokeTest {
+
+    @Test
+    @Timeout(value = 5, unit = TimeUnit.MINUTES)
+    void leafEncoderProducesParaphraseAwareEmbeddings() throws Exception {
+        Path leafDir = resolveLeafModelDir();
+        assumeTrue(leafDir != null,
+                "No LEAF model dir found — set LEAF_MODEL_DIR or place mdbr-leaf-ir " +
+                "under ~/.deliverance/MongoDB_mdbr-leaf-ir/.");
+
+        try (KBertSession bert = KBertJava.loadSafeTensors(leafDir)) {
+            float[] embA = bert.encode("How do I reset my password?");
+            float[] embB = bert.encode("What is the procedure to recover account access?");
+            float[] embC = bert.encode("The Pacific Ocean is the largest body of water on Earth.");
+
+            assertNotNull(embA, "LEAF returned null embedding for A");
+            assertTrue(embA.length > 0, "LEAF returned empty embedding for A");
+            assertTrue(embA.length == embB.length && embB.length == embC.length,
+                    "LEAF embedding dimensions inconsistent: " +
+                    embA.length + "/" + embB.length + "/" + embC.length);
+
+            float simParaphrase = cosineSimilarity(embA, embB);
+            float simUnrelated = cosineSimilarity(embA, embC);
+            System.out.println(
+                    "[smoke-reference] LEAF sim(paraphrase)=" + simParaphrase +
+                    " sim(unrelated)=" + simUnrelated);
+            assertTrue(simParaphrase > simUnrelated,
+                    "Expected paraphrase similarity (" + simParaphrase +
+                    ") to exceed unrelated similarity (" + simUnrelated + ")");
+        }
+    }
+
+    private static Path resolveLeafModelDir() {
+        String env = System.getenv("LEAF_MODEL_DIR");
+        if (env != null && !env.isBlank()) {
+            Path p = Path.of(env);
+            return Files.isDirectory(p) ? p : null;
+        }
+        Path deliverance = Path.of(System.getProperty("user.home"),
+                ".deliverance", "MongoDB_mdbr-leaf-ir");
+        if (Files.isDirectory(deliverance)) return deliverance;
+
+        Path snapshots = Path.of(System.getProperty("user.home"),
+                ".cache", "huggingface", "hub",
+                "models--MongoDB--mdbr-leaf-ir", "snapshots");
+        if (Files.isDirectory(snapshots)) {
+            File[] children = snapshots.toFile().listFiles(File::isDirectory);
+            if (children != null) {
+                for (File c : children) {
+                    Path candidate = c.toPath();
+                    if (Files.exists(candidate.resolve("model.safetensors"))) return candidate;
+                }
+            }
+        }
+        return null;
+    }
+
+    private static float cosineSimilarity(float[] a, float[] b) {
+        if (a.length != b.length) {
+            throw new IllegalArgumentException("Embedding dimensions must match: " +
+                    a.length + " vs " + b.length);
+        }
+        double dot = 0.0;
+        double normA = 0.0;
+        double normB = 0.0;
+        for (int i = 0; i < a.length; i++) {
+            dot += a[i] * b[i];
+            normA += a[i] * a[i];
+            normB += b[i] * b[i];
+        }
+        double denom = Math.sqrt(normA) * Math.sqrt(normB);
+        if (denom < 1e-12) return 0f;
+        return (float) (dot / denom);
+    }
+}
diff --git a/tests/smoke/smoke-models.json b/tests/smoke/smoke-models.json
index ff87a7ef..70709f07 100644
--- a/tests/smoke/smoke-models.json
+++ b/tests/smoke/smoke-models.json
@@ -30,6 +30,7 @@
       "format": "gguf",
       "instruct": true,
       "prompt": "What is the capital of France?",
+      "reference": true,
       "toolCalling": {
         "prompt": "What is 2 + 2?",
         "steps": 256
@@ -52,7 +53,8 @@
       "runner": "kgemma",
       "model": "~/.lmstudio/models/lmstudio-community/gemma-4-E4B-it-GGUF/gemma-4-E4B-it-Q4_K_M.gguf",
       "format": "gguf",
-      "steps": 16
+      "steps": 16,
+      "reference": true
     },
     {
       "name": "all-MiniLM-L6-v2",
@@ -68,7 +70,8 @@
       "model": "~/.cache/huggingface/hub/models--MongoDB--mdbr-leaf-ir/snapshots/1bb4fc387c49dee1c10c2b22f59db758be87dcaa",
       "format": "safetensors",
       "prompt": "MongoDB is a NoSQL database",
-      "doc": "MongoDB stores data in BSON documents"
+      "doc": "MongoDB stores data in BSON documents",
+      "reference": true
     }
   ]
 }
From f5857783e7780c60a1e297324db1a163cf024a51 Mon Sep 17 00:00:00 2001
From: Michal Harakal
Date: Mon, 25 May 2026 13:38:04 +0200
Subject: [PATCH 2/2] honour DTypePolicy in DecoderSafeTensorsLoader: BF16 KEEP_NATIVE path
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Builds on #144. The DTypePolicy surface that PR added to the
*NetworkLoader classes was forward-compat — accepted at the boundary
but not yet observed at per-tensor load. This change makes the
SafeTensors half real.

When the consumer attaches a policy that admits BF16 — Require(BF16),
Prefer(BF16), or OneOf containing BF16 — DecoderSafeTensorsLoader stops
dequanting BF16 tensors and instead wraps the packed
2-bytes-per-element buffer in Bf16DenseTensorData (new in SKaiNET
0.25.0). The matmul dispatch in DefaultCpuOpsJvm detects Bf16TensorData
at runtime and routes to the SIMD BF16 kernel — so a BF16 SafeTensors
checkpoint now stays near its on-disk footprint in RAM instead of
inflating ~2x to FP32.

The `keepBf16Native` flag is computed eagerly from the policy and
mirrors SafeTensorsParametersLoader.mapPolicyToBf16 upstream.
LlamaNetworkLoader / QwenNetworkLoader / VoxtralNetworkLoader each
forward `loader.dtypePolicy` into the new fifth constructor argument of
DecoderSafeTensorsLoader. Default DTypePolicy.Any keeps the adaptive
FP32 dequant behaviour, so existing callers see no change.

GGUF loaders are intentionally unchanged: SKaiNET 0.25.0's
StreamingGgufParametersLoader.validatePolicy() still rejects
Require(BF16) for GGUF (no KEEP_NATIVE GGUF backing yet), and the
transformer-repo's DecoderGgufWeightLoader inherits that constraint.
Documented in the CHANGELOG's Deferred section.

Tests:
- LlamaNetworkLoaderDTypePolicyTest pins the validation contract on
  each policy arm + provider combination (BF16 accepted on SafeTensors,
  rejected on GGUF; FP16 / Int8 rejected everywhere; Prefer / OneOf
  always pass).
- llama .api baseline regenerated for the new DecoderSafeTensorsLoader
  constructor signature; apiCheck green.
- ./gradlew allTests green.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 CHANGELOG.md                                  |  27 +++-
 llm-inference/llama/api/jvm/llama.api         |   4 +-
 .../models/llama/DecoderSafeTensorsLoader.kt  |  54 +++++++-
 .../ainet/models/llama/LlamaNetworkLoader.kt  |   2 +-
 .../LlamaNetworkLoaderDTypePolicyTest.kt      | 123 ++++++++++++++++++
 .../sk/ainet/models/qwen/QwenNetworkLoader.kt |   2 +-
 .../models/voxtral/VoxtralNetworkLoader.kt    |   2 +-
 7 files changed, 198 insertions(+), 16 deletions(-)
 create mode 100644 llm-inference/llama/src/jvmTest/kotlin/sk/ainet/models/llama/LlamaNetworkLoaderDTypePolicyTest.kt

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b06df22c..235a8990 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -43,6 +43,21 @@ window without a tagged 0.24.x release on either side.
   (`HybridTransformerBlock.compile()` will read this in a follow-up).
   Setting a non-`Any` value compiles today and starts taking effect when
   the compile-step plumbing lands — no API change at consumers.
+- **SafeTensors BF16 KEEP_NATIVE** in `DecoderSafeTensorsLoader`. When the
+  consumer attaches a `DTypePolicy` that admits BF16 (`Require(BF16)`,
+  `Prefer(BF16)`, or `OneOf` containing BF16), the loader stops dequanting
+  BF16 tensors and instead wraps the packed 2-bytes-per-element buffer in
+  `Bf16DenseTensorData`. The matmul dispatch in `DefaultCpuOpsJvm` (SKaiNET
+  0.25.0) detects `Bf16TensorData` at runtime and routes to the SIMD BF16
+  kernel — so a BF16 SafeTensors checkpoint now stays near its on-disk
+  footprint in RAM instead of inflating ~2× to FP32. Threaded through
+  `LlamaNetworkLoader` / `QwenNetworkLoader` / `VoxtralNetworkLoader`
+  (each forwards `loader.dtypePolicy` into the
+  `DecoderSafeTensorsLoader(ctx, T::class, metadata, tied, dtypePolicy)`
+  constructor). The default value remains `DTypePolicy.Any` — adaptive
+  FP32 dequant, no behavioural change for existing callers. Validation
+  errors still fire at the `LlamaNetworkLoader.withDtypePolicy(...)`
+  boundary: `LlamaNetworkLoaderDTypePolicyTest` pins each policy arm.
 - **Three reference smoke tests with `@Tag("smoke-reference")`.** The new
   smoke tier exists alongside the existing `@Tag("integration")` filter and
   pins the three architectures we always want to run end-to-end:
@@ -95,12 +110,12 @@ changes land in follow-up PRs.
 - **`HybridTransformerBlock.compile()` honoring the policy on
   `DagBuilder.op(..., dtypePolicy = …)` per the W6 SKaiNET PR.** Blocked on
   the side-map above.
-- **`DecoderSafeTensorsLoader` / `DecoderGgufWeightLoader` actually
-  observing the policy at per-tensor load.** Both currently dequant
-  BF16 → FP32 unconditionally. The next step is to route them through
-  SKaiNET's `SafeTensorsParametersLoader.withPolicy(...)` /
-  `StreamingGgufParametersLoader.withPolicy(...)` so `Require(BF16)`
-  actually keeps native — a contained refactor with a clear contract.
+- **`DecoderGgufWeightLoader` per-tensor policy enforcement.** The GGUF
+  loader still dequants BF16 → FP32 unconditionally — SKaiNET 0.25.0's
+  `StreamingGgufParametersLoader.validatePolicy()` itself rejects
+  `Require(BF16)` for GGUF today (no KEEP_NATIVE GGUF backing yet), so
+  this is parked until the engine grows that path. *(SafeTensors BF16
+  KEEP_NATIVE shipped in this release — see Added.)*
 - **BOM-only versionless aliases in `libs.versions.toml`.** Currently
   every `skainet-*` alias still uses `version.ref = "skainet"` because
   the single-source bump is the lower-risk path during the 0.25.0
diff --git a/llm-inference/llama/api/jvm/llama.api b/llm-inference/llama/api/jvm/llama.api
index b71163fc..81512635 100644
--- a/llm-inference/llama/api/jvm/llama.api
+++ b/llm-inference/llama/api/jvm/llama.api
@@ -49,8 +49,8 @@ public final class sk/ainet/models/llama/DecoderGgufWeights {
 }
 
 public final class sk/ainet/models/llama/DecoderSafeTensorsLoader {
-	public fun <init> (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;Z)V
-	public synthetic fun <init> (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V
+	public fun <init> (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;ZLsk/ainet/lang/types/DTypePolicy;)V
+	public synthetic fun <init> (Lsk/ainet/context/ExecutionContext;Lkotlin/reflect/KClass;Lsk/ainet/models/llama/LlamaModelMetadata;ZLsk/ainet/lang/types/DTypePolicy;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
 	public final fun load (Lkotlin/jvm/functions/Function0;)Lsk/ainet/models/llama/LlamaRuntimeWeights;
 	public final fun loadToMap (Lkotlin/jvm/functions/Function0;)Lsk/ainet/models/llama/DecoderGgufWeights;
 }
diff --git a/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/DecoderSafeTensorsLoader.kt b/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/DecoderSafeTensorsLoader.kt
index 7460f94a..17e4d39c 100644
--- a/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/DecoderSafeTensorsLoader.kt
+++ b/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/DecoderSafeTensorsLoader.kt
@@ -12,7 +12,11 @@ import sk.ainet.io.safetensors.StreamingSafeTensorsReader
 import sk.ainet.io.safetensors.StreamingSafeTensorInfo
 import sk.ainet.lang.tensor.Shape
 import sk.ainet.lang.tensor.Tensor
+import sk.ainet.lang.tensor.data.Bf16DenseTensorData
+import sk.ainet.lang.tensor.data.TensorData
+import sk.ainet.lang.types.BF16
 import sk.ainet.lang.types.DType
+import sk.ainet.lang.types.DTypePolicy
 import sk.ainet.lang.types.FP32
 import kotlin.math.pow
 import kotlin.reflect.KClass
@@ -24,17 +28,39 @@ import kotlin.reflect.KClass
  * Handles:
  * - HuggingFace → GGUF tensor name mapping
  * - Q4 + .qb companion tensor dequantization to FP32
- * - BF16/F16 dequantization to FP32
+ * - BF16/F16 dequantization to FP32 (default)
+ * - BF16 KEEP_NATIVE when [dtypePolicy] admits BF16 (SKaiNET 0.25.0):
+ *   constructs a [Bf16DenseTensorData]-backed tensor so the BF16 matmul
+ *   kernel routes via `DefaultCpuOpsJvm` without a 2× memory blow-up.
  * - Shape normalization ([1, dim] norms → [dim])
  * - Tied word embeddings (output.weight = token_embd.weight)
+ *
+ * @param dtypePolicy declarative dtype constraint. Default [DTypePolicy.Any]
+ *   = adaptive dequant. `Require(BF16)` / `Prefer(BF16)` / `OneOf` containing
+ *   BF16 = KEEP_NATIVE path. Mirrors the SKaiNET 0.25.0
+ *   `SafeTensorsParametersLoader.mapPolicyToBf16` semantics.
  */
 public class DecoderSafeTensorsLoader(
     private val ctx: ExecutionContext,
     private val dtype: KClass,
     private val metadata: LlamaModelMetadata,
-    private val tiedEmbeddings: Boolean = false
+    private val tiedEmbeddings: Boolean = false,
+    private val dtypePolicy: DTypePolicy = DTypePolicy.Any,
 ) {
 
+    /**
+     * Returns `true` iff [dtypePolicy] wants BF16 weights kept in their
+     * packed 2-bytes-per-element form rather than dequantised to FP32.
+     * Matches the engine-side `SafeTensorsParametersLoader.mapPolicyToBf16`
+     * cases that resolve to `Bf16LoadPolicy.KEEP_NATIVE`.
+     */
+    private val keepBf16Native: Boolean = when (val p = dtypePolicy) {
+        DTypePolicy.Any -> false
+        is DTypePolicy.Require -> p.target == BF16
+        is DTypePolicy.Prefer -> p.target == BF16
+        is DTypePolicy.OneOf -> BF16 in p.allowed
+    }
+
     /**
      * Load weights from SafeTensors file into a flat tensor map with GGUF-canonical names.
      * Useful for feeding into [WeightMapper] with a [WeightNameResolver].
@@ -66,10 +92,28 @@ public class DecoderSafeTensorsLoader(
                 }
                 DataType.BFLOAT16 -> {
                     val bytes = reader.loadTensorData(info)
-                    val floats = dequantBF16(bytes)
                     val targetShape = normalizeNormShape(info.shape)
-                    @Suppress("UNCHECKED_CAST")
-                    ctx.fromFloatArray(targetShape, dtype, floats) as Tensor
+                    if (keepBf16Native) {
+                        // KEEP_NATIVE: wrap the packed 2-bytes-per-element
+                        // BF16 buffer as `Bf16DenseTensorData`. The matmul
+                        // dispatch in `DefaultCpuOpsJvm` (SKaiNET 0.25.0)
+                        // detects `Bf16TensorData` at runtime and routes
+                        // to the SIMD BF16 kernel — avoiding the 2× memory
+                        // inflation of the FP32 dequant path.
+                        //
+                        // The declared dtype generic stays `T` (typically
+                        // FP32) because consumers don't care about the
+                        // physical encoding — the get/set surface still
+                        // returns Float. Mirrors the
+                        // `GemmaMemSegConverter` pattern for Q4/Q8.
+                        val data = Bf16DenseTensorData.fromRawBytes(targetShape, bytes)
+                        @Suppress("UNCHECKED_CAST")
+                        ctx.fromData(data as TensorData, dtype) as Tensor
+                    } else {
+                        val floats = dequantBF16(bytes)
+                        @Suppress("UNCHECKED_CAST")
+                        ctx.fromFloatArray(targetShape, dtype, floats) as Tensor
+                    }
                 }
                 DataType.FLOAT16 -> {
                     val bytes = reader.loadTensorData(info)
diff --git a/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaNetworkLoader.kt b/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaNetworkLoader.kt
index 646e639b..f0094202 100644
--- a/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaNetworkLoader.kt
+++ b/llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaNetworkLoader.kt
@@ -146,7 +146,7 @@ public class LlamaNetworkLoader @PublishedApi internal constructor(
                 loader.loadToMapStreaming(ctx)
             }
             is WeightsProvider.SafeTensors -> {
-                val loader = DecoderSafeTensorsLoader(ctx, T::class, wp.metadata, wp.tiedEmbeddings)
+                val loader = DecoderSafeTensorsLoader(ctx, T::class, wp.metadata, wp.tiedEmbeddings, dtypePolicy)
                 @Suppress("UNCHECKED_CAST")
                 loader.loadToMap(wp.randomAccessProvider) as DecoderGgufWeights
             }
diff --git a/llm-inference/llama/src/jvmTest/kotlin/sk/ainet/models/llama/LlamaNetworkLoaderDTypePolicyTest.kt b/llm-inference/llama/src/jvmTest/kotlin/sk/ainet/models/llama/LlamaNetworkLoaderDTypePolicyTest.kt
new file mode 100644
index 00000000..c88d7bcd
--- /dev/null
+++ b/llm-inference/llama/src/jvmTest/kotlin/sk/ainet/models/llama/LlamaNetworkLoaderDTypePolicyTest.kt
@@ -0,0 +1,123 @@
+package sk.ainet.models.llama
+
+import kotlinx.io.Source
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
+import kotlin.test.assertSame
+import sk.ainet.io.RandomAccessSource
+import sk.ainet.lang.types.BF16
+import sk.ainet.lang.types.DTypePolicy
+import sk.ainet.lang.types.FP16
+import sk.ainet.lang.types.FP32
+import sk.ainet.lang.types.Int8
+
+/**
+ * Pins the `DTypePolicy` validation contract on `LlamaNetworkLoader`:
+ *
+ *  - the default value is [DTypePolicy.Any];
+ *  - `withDtypePolicy(Require(FP32))` always succeeds (it's the loader's
+ *    native output dtype);
+ *  - `withDtypePolicy(Require(BF16))` succeeds **only** for SafeTensors-
+ *    backed loaders — the GGUF path mirrors the engine's eager rejection
+ *    in `StreamingGgufParametersLoader.validatePolicy()` because the
+ *    transformer-repo GGUF chain still dequants BF16 to FP32;
+ *  - `withDtypePolicy(Require(FP16))` / `Require(Int8)` etc. always
+ *    reject — the loader doesn't fabricate dtypes the source files
+ *    don't carry.
+ *  - `Prefer` / `OneOf` arms never raise (they're soft constraints).
+ *
+ * No model files are read — these tests only construct loader instances
+ * to exercise the validation boundary added in PR #144 (`*NetworkLoader.
+ * withDtypePolicy`).
+ */
+class LlamaNetworkLoaderDTypePolicyTest {
+
+    private val noopSourceProvider: () -> Source = { error("source not used in validation tests") }
+    private val noopRandomAccessProvider: () -> RandomAccessSource = { error("source not used in validation tests") }
+
+    private val anyMetadata = LlamaModelMetadata(
+        architecture = "llama",
+        embeddingLength = 4,
+        contextLength = 8,
+        blockCount = 1,
+        headCount = 1,
+        kvHeadCount = 1,
+        feedForwardLength = 4,
+        ropeDimensionCount = 4,
+        vocabSize = 4,
+        ropeFreqBase = 10_000f,
+        rmsNormEps = 1e-5f,
+    )
+
+    @Test
+    fun `default policy is Any`() {
+        val loader = LlamaNetworkLoader.fromGguf(sourceProvider = noopSourceProvider)
+        assertSame(DTypePolicy.Any, loader.dtypePolicy)
+    }
+
+    @Test
+    fun `Require(FP32) is accepted on both GGUF and SafeTensors paths`() {
+        val gguf = LlamaNetworkLoader.fromGguf(sourceProvider = noopSourceProvider)
+            .withDtypePolicy(DTypePolicy.Require(FP32))
+        assertEquals(DTypePolicy.Require(FP32), gguf.dtypePolicy)
+
+        val safetensors = LlamaNetworkLoader.fromSafeTensors(
+            metadata = anyMetadata, randomAccessProvider = noopRandomAccessProvider,
+        ).withDtypePolicy(DTypePolicy.Require(FP32))
+        assertEquals(DTypePolicy.Require(FP32), safetensors.dtypePolicy)
+    }
+
+    @Test
+    fun `Require(BF16) is accepted on SafeTensors but rejected on GGUF`() {
+        val safetensors = LlamaNetworkLoader.fromSafeTensors(
+            metadata = anyMetadata, randomAccessProvider = noopRandomAccessProvider,
+        ).withDtypePolicy(DTypePolicy.Require(BF16))
+        assertEquals(DTypePolicy.Require(BF16), safetensors.dtypePolicy)
+
+        assertFailsWith {
+            LlamaNetworkLoader.fromGguf(sourceProvider = noopSourceProvider)
+                .withDtypePolicy(DTypePolicy.Require(BF16))
+        }
+    }
+
+    @Test
+    fun `Require(FP16) is rejected on both paths`() {
+        assertFailsWith {
+            LlamaNetworkLoader.fromGguf(sourceProvider = noopSourceProvider)
+                .withDtypePolicy(DTypePolicy.Require(FP16))
+        }
+        assertFailsWith {
+            LlamaNetworkLoader.fromSafeTensors(
+                metadata = anyMetadata, randomAccessProvider = noopRandomAccessProvider,
+            ).withDtypePolicy(DTypePolicy.Require(FP16))
+        }
+    }
+
+    @Test
+    fun `Require(Int8) is rejected on both paths`() {
+        assertFailsWith {
+            LlamaNetworkLoader.fromGguf(sourceProvider = noopSourceProvider)
+                .withDtypePolicy(DTypePolicy.Require(Int8))
+        }
+        assertFailsWith {
+            LlamaNetworkLoader.fromSafeTensors(
+                metadata = anyMetadata, randomAccessProvider = noopRandomAccessProvider,
+            ).withDtypePolicy(DTypePolicy.Require(Int8))
+        }
+    }
+
+    @Test
+    fun `Prefer and OneOf are always accepted regardless of target`() {
+        // Prefer is a soft constraint — never raises.
+        LlamaNetworkLoader.fromGguf(sourceProvider = noopSourceProvider)
+            .withDtypePolicy(DTypePolicy.Prefer(BF16))
+        LlamaNetworkLoader.fromGguf(sourceProvider = noopSourceProvider)
+            .withDtypePolicy(DTypePolicy.Prefer(FP16))
+
+        // OneOf is a restricted set; non-empty is the only invariant.
+        LlamaNetworkLoader.fromSafeTensors(
+            metadata = anyMetadata, randomAccessProvider = noopRandomAccessProvider,
+        ).withDtypePolicy(DTypePolicy.OneOf(setOf(FP32, BF16)))
+    }
+}
diff --git a/llm-inference/qwen/src/commonMain/kotlin/sk/ainet/models/qwen/QwenNetworkLoader.kt b/llm-inference/qwen/src/commonMain/kotlin/sk/ainet/models/qwen/QwenNetworkLoader.kt
index 8a93b17f..0d26e15c 100644
--- a/llm-inference/qwen/src/commonMain/kotlin/sk/ainet/models/qwen/QwenNetworkLoader.kt
+++ b/llm-inference/qwen/src/commonMain/kotlin/sk/ainet/models/qwen/QwenNetworkLoader.kt
@@ -148,7 +148,7 @@ public class QwenNetworkLoader @PublishedApi internal constructor(
                 loader.loadToMapStreaming(ctx)
             }
             is WeightsProvider.SafeTensors -> {
-                val loader = DecoderSafeTensorsLoader(ctx, T::class, wp.metadata, wp.tiedEmbeddings)
+                val loader = DecoderSafeTensorsLoader(ctx, T::class, wp.metadata, wp.tiedEmbeddings, dtypePolicy)
                 @Suppress("UNCHECKED_CAST")
                 loader.loadToMap(wp.randomAccessProvider) as DecoderGgufWeights
             }
diff --git a/llm-inference/voxtral/src/commonMain/kotlin/sk/ainet/models/voxtral/VoxtralNetworkLoader.kt b/llm-inference/voxtral/src/commonMain/kotlin/sk/ainet/models/voxtral/VoxtralNetworkLoader.kt
index 31d33172..7633285f 100644
--- a/llm-inference/voxtral/src/commonMain/kotlin/sk/ainet/models/voxtral/VoxtralNetworkLoader.kt
+++ b/llm-inference/voxtral/src/commonMain/kotlin/sk/ainet/models/voxtral/VoxtralNetworkLoader.kt
@@ -170,7 +170,7 @@ public class VoxtralNetworkLoader @PublishedApi internal constructor(
                 loader.loadToMapStreaming(ctx)
             }
             is WeightsProvider.SafeTensors -> {
-                val loader = DecoderSafeTensorsLoader(ctx, T::class, wp.metadata, wp.tiedEmbeddings)
+                val loader = DecoderSafeTensorsLoader(ctx, T::class, wp.metadata, wp.tiedEmbeddings, dtypePolicy)
                 @Suppress("UNCHECKED_CAST")
                 loader.loadToMap(wp.randomAccessProvider) as DecoderGgufWeights
             }
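
Reviewer note on the BF16 KEEP_NATIVE path in PATCH 2/2: the sketch below is a minimal, self-contained illustration and is not code from this repository. `SketchPolicy`, `keepsBf16Native`, and `widenBf16` are hypothetical stand-ins (dtypes are plain strings, and the real `DTypePolicy`, `dequantBF16`, and `Bf16DenseTensorData` are not used). It assumes little-endian packed BF16 bytes, and shows the policy-to-flag decision table from the diff together with the widening step that the default path performs and the KEEP_NATIVE path skips.

```kotlin
// Illustrative stand-ins only: these types are NOT the SKaiNET API.
sealed interface SketchPolicy {
    object Any : SketchPolicy
    data class Require(val target: String) : SketchPolicy
    data class Prefer(val target: String) : SketchPolicy
    data class OneOf(val allowed: Set<String>) : SketchPolicy
}

/** Same decision table the patch describes for `keepBf16Native`. */
fun keepsBf16Native(policy: SketchPolicy): Boolean = when (policy) {
    SketchPolicy.Any -> false
    is SketchPolicy.Require -> policy.target == "BF16"
    is SketchPolicy.Prefer -> policy.target == "BF16"
    is SketchPolicy.OneOf -> "BF16" in policy.allowed
}

/**
 * Widens packed BF16 to FP32. A BF16 value is the upper 16 bits of the
 * corresponding FP32 bit pattern, so widening is a 16-bit left shift.
 * Assumption: two bytes per element, least significant byte first.
 */
fun widenBf16(bytes: ByteArray): FloatArray {
    require(bytes.size % 2 == 0) { "BF16 buffer must hold 2 bytes per element" }
    return FloatArray(bytes.size / 2) { i ->
        val lo = bytes[2 * i].toInt() and 0xFF
        val hi = bytes[2 * i + 1].toInt() and 0xFF
        Float.fromBits(((hi shl 8) or lo) shl 16)
    }
}

fun main() {
    // 1.0f is 0x3F80 in BF16 and -2.0f is 0xC000, stored low byte first.
    val packed = byteArrayOf(0x80.toByte(), 0x3F, 0x00, 0xC0.toByte())
    val widened = widenBf16(packed)
    println(widened.toList())
    println("packed bytes=${packed.size}, FP32 bytes=${widened.size * Float.SIZE_BYTES}")
    println("Prefer(BF16) keeps native: ${keepsBf16Native(SketchPolicy.Prefer("BF16"))}")
    println("Any keeps native: ${keepsBf16Native(SketchPolicy.Any)}")
}
```

Widening turns every 2-byte element into a 4-byte float, which is the roughly 2x inflation the commit message refers to. Keeping the packed buffer avoids that cost, provided a BF16-aware matmul kernel is available downstream.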