Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llm-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ kotlin {
// versions. Bumping the engine is then a one-line change at the
// top of `gradle/libs.versions.toml`.
implementation(project.dependencies.platform(project(":llm-bom")))
api(project(":transformer-core"))
implementation(libs.skainet.lang.core)
implementation(libs.skainet.compile.dag)
implementation(libs.skainet.compile.opt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,11 @@ public class HybridTransformerBlock<T : DType, V>(
// the same name — so MHA can't gate its own dump on the block id.
// Toggle the static flag from here, where we know which block we're in.
val isMhaCall = dumpMha && module is MultiHeadAttention<*, *>
if (isMhaCall) sk.ainet.lang.nn.transformer.MultiHeadAttentionDiag.shouldDumpThisCall = true
if (isMhaCall) {
// wire transformer-core's MHA diagnostic sink to llm-core's platform dumpStats (idempotent)
sk.ainet.lang.nn.transformer.mhaStatSink = { l, t -> sk.ainet.apps.llm.diag.dumpStats(l, t) }
sk.ainet.lang.nn.transformer.MultiHeadAttentionDiag.shouldDumpThisCall = true
}
tmp = module.forward(tmp, ctx)
if (isMhaCall) sk.ainet.lang.nn.transformer.MultiHeadAttentionDiag.shouldDumpThisCall = false
outputs[i + 1] = tmp
Expand Down
1 change: 1 addition & 0 deletions settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ if (providers.gradleProperty("useLocalSkainet").orNull == "true") {
rootProject.name = "SKaiNET-transformers"

include("llm-api")
include("transformer-core")
include("llm-core")
include("llm-agent")
include("llm-providers")
Expand Down
41 changes: 41 additions & 0 deletions transformer-core/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# transformer-core

Framework NN primitives — attention, the KV-cache family, embedding, norms, RoPE, SwiGLU/GeGLU FFN,
residual, linear projection — extracted from `llm-core` so they build on the **full Kotlin target matrix
including `androidNativeArm32/Arm64`** (the on-device ARM path). Depends only on `skainet-lang-core`
(which has androidNative); no io/compile/backend deps.

`llm-core` `api`-depends on this module and **re-exports** it, so existing consumers are unaffected.
ARM-native consumers (e.g. `skainet-whisper-kmp`) depend on `transformer-core` directly and reuse
KV-cache/attention instead of reimplementing.

## Why
`llm-core`'s primitives only need `lang-core`, but were trapped there: `llm-core`'s *other* deps
(`io-gguf`, `io-core`, `compile-*`, `backend-cpu`) lack androidNative, so ARM-native consumers couldn't
depend on it. The primitives are **dtype-agnostic** (just call `ops.*`), so this target generalization is
orthogonal to the quant/dtype generalization (issue #178) — they meet cleanly at these primitives.

## What moved (15 files, lang-core-only)
`transformer/*` (KVCache, RoPE, ResidualAdd, MultiHeadAttention, GeGLUFFN, SwiGLUFFN, XIELUActivation,
LayerScalarMul, LinearProjection, VoidDense), `layers/*` (Embedding*), `normalization/RMSNormalization`,
`dsl/TransformerDsl`. **Kept in `llm-core`:** `dsl/decoder/*` (DecoderTransformerNetwork needs
`apps.llm.HybridTransformerBlock`, which is compile-opt-coupled).

One back-reference decoupled: `MultiHeadAttention`'s diagnostic `dumpStats` → a settable `mhaStatSink`
(default no-op) that `HybridTransformerBlock` wires to llm-core's platform `dumpStats` (no behaviour lost).

## Verified
`:transformer-core:` compiles for jvm + androidNativeArm32 + arm64; `:llm-core:jvmTest` green (5/5) via
the re-export.

## Landing (for the maintainer)
Branch `feature/transformer-core` was cut from `release/0.31.0`. To land on `develop` (which has #178's
merged #179/#180):
1. `git fetch origin && git rebase origin/develop` — **no conflicts expected on the moved files**: #178's
merged work is in the model layer (`GemmaPackedWeights`) + engine (`ops.transpose` Q8_0/Q4_0), not these
primitives. (Verified against local refs; re-check against fresh `develop`.)
2. Build the full target matrix + `:llm-core:` tests; PR; CI-publish; bump the `skainet`/transformers pins.
3. **Note for future quant work:** the pre-transpose-marker (#178 "Solution C") will land in
`LinearProjection.kt`, which now lives **here**, not `llm-core`. And `RowDequantSource` + packed-weight
packing (today in `sk.ainet.models.gemma`) are the next candidates to hoist into a shared `quant` layer
or this module — that's what makes quant reusable across models *and* whisper.
43 changes: 43 additions & 0 deletions transformer-core/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

plugins {
alias(libs.plugins.kotlinMultiplatform)
alias(libs.plugins.androidMultiplatformLibrary)
alias(libs.plugins.vanniktech.mavenPublish)
}

// Framework NN primitives (attention, KV-cache family, embedding, norms, RoPE, FFNs) extracted from
// llm-core so they build on the FULL target matrix — including androidNative (the 32-bit box + phones).
// Depends ONLY on skainet-lang-core (which has androidNative); no io/compile/backend deps. llm-core
// re-exports this module, so existing consumers are unaffected; ARM-native consumers depend on it directly.
kotlin {
android {
namespace = "sk.ainet.lang.nn"
compileSdk = libs.versions.android.compileSdk.get().toInt()
minSdk = libs.versions.android.minSdk.get().toInt()
compilerOptions { jvmTarget.set(JvmTarget.JVM_11) }
}

jvm()
androidNativeArm32()
androidNativeArm64()
iosArm64()
iosSimulatorArm64()
linuxX64()
linuxArm64()
macosArm64()
js { browser() }
@OptIn(ExperimentalWasmDsl::class) wasmJs { browser() }
@OptIn(ExperimentalWasmDsl::class) wasmWasi { nodejs() }

sourceSets {
commonMain.dependencies {
implementation(project.dependencies.platform(project(":llm-bom")))
api(libs.skainet.lang.core) // public API is lang-core-typed (Tensor/Module/ExecutionContext)
}
commonTest.dependencies {
implementation(libs.kotlin.test)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import kotlin.reflect.KClass
*
* Computes: down_proj(gelu(gate_proj(x)) * up_proj(x))
*
* Identical parameter layout to [SwiGLUFFN] so [sk.ainet.apps.llm.weights.LlamaGGUFNameResolver]
* Identical parameter layout to [SwiGLUFFN] so `LlamaGGUFNameResolver` (llm-core)
* maps the same GGUF tensor names. The only difference is the activation
* (GELU instead of SiLU).
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,16 @@ public class MultiHeadAttention<T : DType, V>(
* references. Stats are over the *whole* tensor, not just last
* position — different MHA substeps have different shapes. */
private fun mhaDumpStat(label: String, t: Tensor<T, V>) {
sk.ainet.apps.llm.diag.dumpStats(label, t)
mhaStatSink?.invoke(label, t)
}
}

/**
* Optional diagnostic sink for MHA substep stats — decouples `transformer-core` from llm-core's
* platform `dumpStats`. Defaults to no-op (diagnostics off); llm-core wires its `dumpStats` into it.
*/
public var mhaStatSink: ((String, Tensor<*, *>) -> Unit)? = null

/**
* Per-call MHA-substep dump gate. The MHA module itself is named just `"attn"`
* (every block's MHA shares that name), so we can't gate the dump from inside
Expand Down