Kotlin/Native inference produces deterministic nonsense across model families and runtime paths (JVM unaffected) #104

@michalharakal

Summary

On Kotlin/Native (verified on macosArm64 against SKaiNET-transformers develop @ commit 4cd1da9, SKaiNET 0.23.0), generating tokens via the kllama runtime, and also via OptimizedLLMRuntime + QwenNetworkLoader, produces deterministic but meaningless token sequences. The same model files generate sensible output on the JVM target. The issue persists across:

  • Both runtime paths: deprecated LlamaRuntime + LlamaIngestion, and the new OptimizedLLMRuntime + QwenNetworkLoader.fromGguf(...).load(...).
  • Both load paths: sequential (sourceProvider, slurps to ByteArray) and streaming (randomAccessProvider, exercises the new POSIX-pread RandomAccessSource shipped in 0.23.0 / #591).
  • Multiple model families: Llama-3.2-1B-Instruct (Q8), TinyLlama-1.1B-Chat-v1.0 (Q8, Llama-2 arch), Qwen2.5-0.5B-Instruct (Q8), Qwen3-1.7B-Q8_0 (Q8).
  • Multiple compute backends: the SKaiNET CPU backend, and an external macOS-native execution backend wired in via BackendRegistry.register(...). On Qwen2.5-0.5B, the two backends emit bit-for-bit identical output — same garbage, same tokens, same order, every run. That's the cleanest signal that the problem is upstream of the compute layer: two different TensorOps implementations agree perfectly, which only happens if the bug is somewhere they share (weight load, dequantization, the runtime's forward orchestration, or some K/N stdlib / cinterop quirk).

Reproduction

The shortest repro uses SKaiNET-transformers' own native CLI binary (no third-party code involved):

./gradlew :llm-runtime:kllama:linkReleaseExecutableMacosArm64
./llm-runtime/kllama/build/bin/macosArm64/releaseExecutable/kllama.kexe \
    /path/to/Llama-3.2-1B-Instruct-Q8_0.gguf \
    "The capital of France is" 32 0.0

Output, deterministic on every restart:

Generating 32 tokens with temperature=0.0...
---
 Bodies Bodies Bodies Bodies Bodies Bodies … reigning reigning reigning

Same kllama binary on tinyllama-1.1b-chat-v1.0.Q8_0.gguf:

witzwitzwitzwitzwitzwitzwitzwitzwitzwitzwitzwitz…

OptimizedLLMRuntime + QwenNetworkLoader on Qwen2.5-0.5B-Instruct-Q8_0.gguf (in a small consumer CLI we built to verify the streaming path works):

bynDRAM ASAP */čĊčĊčĊ appréci(animation mur Offline bénéfic coppia wrzeÅĽnia…

All deterministic at temperature 0.0; tokens are within vocab but semantically nonsense.

Cleanest single repro for triage

Two consumer CLI invocations, same prompt, same model, two different BackendProviders:

$ ./llama3-native-cli -m Qwen2.5-0.5B-Instruct-Q8_0.gguf --backend cpu              --steps 32 --temperature 0.0 "What is 17 * 23?"
Backend: CPU (SIMD)  (priority=0)
Tokenizer: BPE (model=gpt2)
Assistant: bynDRAM ASAP */čĊčĊčĊ appréci(animation mur Offline bénéfic coppia wrzeÅĽnia{i]';Ċodonå¡ijåįķä½įæĪĬ åŃĹéducationLUaniobihrÃł,GLæĭ¼éٳournemouth_endian)$_.AdapterView.nlmMurlu

$ ./llama3-native-cli -m Qwen2.5-0.5B-Instruct-Q8_0.gguf --backend platform-native  --steps 32 --temperature 0.0 "What is 17 * 23?"
Backend: Platform-native (macOS)  (priority=100)
Tokenizer: BPE (model=gpt2)
Assistant: bynDRAM ASAP */čĊčĊčĊ appréci(animation mur Offline bénéfic coppia wrzeÅĽnia{i]';Ċodonå¡ijåįķä½įæĪĬ åŃĹéducationLUaniobihrÃł,GLæĭ¼éٳournemouth_endian)$_.AdapterView.nlmMurlu

Bit-for-bit identical output across two completely different TensorOps implementations. Same prompt on the JVM target with the same loader stack produces sensible output.

What's been ruled out

  • EOS resolution: tokenizer.eosTokenId == 128009 (<|eot_id|>) for Llama-3.2 — correct. Confirmed via diagnostic print before runtime.generate(...).
  • Backend correctness on small models: Qwen2.5-0.5B → bit-identical output on two different backends (proof above). Rules out compute layer for small / single-block-size cases.
  • Tool calling / agent loop: gibberish reproduces with the raw kllama.kexe baseline that does no tool calling at all — runtime.generate(...) straight after LlamaIngestion.load(...).
  • GGUF reader path: same gibberish via the legacy Source-based reader and via the new POSIX-pread StreamingGGUFReader path. Not a slurp-vs-stream issue.
  • JVM: same loaders + runtimes, exercised via the existing JUnit test suite (gated on TINYLLAMA_MODEL_PATH), produce sensible output on the JVM target. K/N-specific.
  • 0.23.0 fixes: the K/N pread fix (#591) and the lazy zero-init API (#588) did not change the symptom.

Most likely places to look

Ordered by suspicion (the TensorOps-agnostic invariant pins this somewhere shared):

  1. GGUF Q8 dequantization on K/N. QuantPolicy.DEQUANTIZE_TO_FP32 is the only K/N-viable policy (MemSegWeightConverter is jvmMain-only). If the K/N Q8 dequant has a bug, every downstream tensor is silently corrupted; both backends consume the same corrupt tensors and produce the same nonsense — exactly what we see. A quick verification path: snapshot a known weight tensor (e.g. token_embd.weight[0,:]) on JVM and on K/N and compare element-wise; first divergence localises the bug.
  2. OptimizedLLMRuntime / LlamaRuntime forward orchestration on K/N. Both runtimes are independent code paths and both fail identically — but they share helpers (RoPE position table, attention mask construction, KV cache layout). A K/N-specific bug in one of those helpers would surface in both.
  3. K/N stdlib / cinterop edge case — less likely but possible. E.g. FloatArray of unusual size silently zeroing a tail, or a cinterop boundary on Accelerate calls passing wrong sizes.

(1) is by far the most likely. I would suggest tackling it first because it is the easiest to falsify with a one-page test.
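To make (1) concrete, here is a minimal sketch of a reference Q8_0 dequantizer in pure Kotlin, with no platform APIs, so the same source should compile for both JVM and K/N. It assumes the ggml Q8_0 block layout (a little-endian fp16 scale followed by 32 signed 8-bit quants, 34 bytes per block). The names halfToFloat and dequantizeQ8_0 are hypothetical and are not SKaiNET APIs; this is not how SKaiNET implements dequantization, only an independent oracle to compare against.

```kotlin
// Reference Q8_0 dequantization in pure Kotlin (common stdlib only).
// Assumed layout per block: 2-byte little-endian fp16 scale + 32 int8 quants.

const val QK8_0 = 32                    // values per Q8_0 block
const val Q8_0_BLOCK_BYTES = 2 + QK8_0  // bytes per Q8_0 block

/** Converts IEEE 754 binary16 bits (low 16 bits of [bits]) to a Float. */
fun halfToFloat(bits: Int): Float {
    val sign = (bits and 0x8000) shl 16
    val exp = (bits ushr 10) and 0x1F
    val mant = bits and 0x3FF
    val out = when (exp) {
        0 -> if (mant == 0) {
            sign // signed zero
        } else {
            // Subnormal half: shift the mantissa up until the implicit bit appears.
            var e = 127 - 15 + 1
            var m = mant
            while ((m and 0x400) == 0) {
                m = m shl 1
                e--
            }
            sign or (e shl 23) or ((m and 0x3FF) shl 13)
        }
        0x1F -> sign or 0x7F800000 or (mant shl 13) // Inf or NaN
        else -> sign or ((exp + 127 - 15) shl 23) or (mant shl 13)
    }
    return Float.fromBits(out)
}

/** Dequantizes [count] values from Q8_0-encoded [raw] bytes. */
fun dequantizeQ8_0(raw: ByteArray, count: Int): FloatArray {
    require(count % QK8_0 == 0) { "count must be a multiple of $QK8_0" }
    val blocks = count / QK8_0
    require(raw.size >= blocks * Q8_0_BLOCK_BYTES) { "buffer too small for $count values" }
    val out = FloatArray(count)
    for (b in 0 until blocks) {
        val base = b * Q8_0_BLOCK_BYTES
        // Assemble the fp16 scale from two little-endian bytes.
        val scaleBits = (raw[base].toInt() and 0xFF) or ((raw[base + 1].toInt() and 0xFF) shl 8)
        val scale = halfToFloat(scaleBits)
        for (i in 0 until QK8_0) {
            // Kotlin's Byte is signed, which matches the int8 quant encoding.
            out[b * QK8_0 + i] = scale * raw[base + 2 + i].toInt()
        }
    }
    return out
}
```

Feeding the raw bytes of one Q8 tensor row to this function on both targets, and comparing the result with what the K/N loader produces for the same row, would either clear or implicate the dequantization step.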

Separate, related bug (filing separately, not this issue)

On larger models (Qwen3-1.7B-Q8_0, ~1.8 GiB) the CPU backend and platform-native backend diverge — produce different nonsense. That divergence is real backend-side disagreement and points at one or more compute ops the larger / GQA-using architecture exercises that the small Qwen2.5-0.5B doesn't. This is a backend issue, separate from the K/N inference bug above; mentioned only because it confirms (1)/(2) as different from (3).

Why this matters

K/N is the only target for consumers shipping single-binary CLIs or non-JVM mobile/desktop apps on top of SKaiNET-transformers. The JVM-only happy path is fine for backend services but rules out a large slice of intended consumers. The K/N native CLI in this repo (kllama.kexe) currently doesn't generate sensible output for any tested model. That is a strong "K/N is broken end-to-end" signal, worth attending to even before any external consumer runs into it.

Triage helpers

  • Same machine, same JDK toolchain (21), same model files. Only the Kotlin target differs.
  • Tested against release/0.23.0 of SKaiNET (commit e0b228c2 plus the docs cleanup) and develop of SKaiNET-transformers @ 4cd1da9.
  • Happy to produce a JVM-vs-K/N logits diff (forward pass on a fixed input, position 0, first 16 logits) on request — that should localise where the K/N path diverges.
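The JVM-vs-K/N comparison (weights or logits) could be done with a helper along these lines. This is a sketch under assumptions: toHexLine, firstDivergence and Divergence are hypothetical names, not part of SKaiNET, and the tolerance is an arbitrary placeholder. Printing raw IEEE 754 bits rather than decimal text avoids any float-formatting differences between targets, so one target can print a snapshot and the other can diff against it.

```kotlin
import kotlin.math.abs

/** Location and values of the first mismatch between two snapshots. */
data class Divergence(val index: Int, val expected: Float, val actual: Float)

/** Renders floats as 8-digit hex bit patterns, one line per snapshot. */
fun toHexLine(values: FloatArray): String =
    values.joinToString(" ") { it.toRawBits().toUInt().toString(16).padStart(8, '0') }

/**
 * Returns the first index where [actual] differs from [expected] by more
 * than [tol], or null if the arrays agree everywhere.
 */
fun firstDivergence(expected: FloatArray, actual: FloatArray, tol: Float = 1e-5f): Divergence? {
    require(expected.size == actual.size) {
        "size mismatch: ${expected.size} vs ${actual.size}"
    }
    for (i in expected.indices) {
        val e = expected[i]
        val a = actual[i]
        // Exact equality covers matching infinities; NaN on both sides counts as a match.
        if (e == a || (e.isNaN() && a.isNaN())) continue
        // A NaN on one side makes the comparison false, so it is reported as a divergence.
        if (!(abs(e - a) <= tol)) return Divergence(i, e, a)
    }
    return null
}
```

For the logits diff, the two inputs would be the first 16 logits at position 0 from each target; for the weight check, a row such as token_embd.weight[0,:]. With tol = 0f the check becomes exact, which is the appropriate setting for dequantized weights, where both targets should agree exactly.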
