From 25af0430cdc461e2f92d12d82bb38f208ab7f91c Mon Sep 17 00:00:00 2001
From: Michal Harakal
Date: Wed, 29 Apr 2026 23:10:21 +0200
Subject: [PATCH] feat(native-cpu): zero-copy Q4_K MemSeg kernel + SPI sibling (PR 3 of 5)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PR 3 of the staged native-FFM rollout per docs/.../perf/native-ffm-plan.adoc.
Closes the M4↔M5 zero-copy story for mmap'd Q4_K weights: callers that
already hold off-heap weight bytes (mmap'd .gguf files, shared arenas)
skip the staging ByteArray → MemorySegment copy that
NativeQ4KMatmulKernel.matmul performs on every call.

SPI surface (skainet-backend-api/src/jvmMain):
- Q4KMemSegMatmulKernel — JVM-only sibling of Q4KMatmulKernel; same
  block layout / lazy-dmin contract, but `weight` is a
  java.lang.foreign.MemorySegment with a Long byte offset. KMP-safe
  positioning: lives in jvmMain (not commonMain) because
  java.lang.foreign isn't available on Native / JS / Wasm targets.
- MemSegKernelProvider — JVM-only sibling of KernelProvider that
  exposes a `matmulQ4KMemSeg(): Q4KMemSegMatmulKernel?` accessor with
  a `null`-defaulting body. Lookup pattern at the call site:

      val kernel = (KernelRegistry.bestAvailable() as? MemSegKernelProvider)
          ?.matmulQ4KMemSeg()
          ?: heapFallback()

  Doesn't fork the registry — providers opt into MemSeg surfaces by
  implementing both interfaces; smart-cast does the rest. Adding
  `matmulQ4KMemSeg` directly to KernelProvider would have broken
  commonMain (MemorySegment is JVM-only).

Native side (skainet-backend-native-cpu):
- NativeQ4KMemSegMatmulKernel reuses PR 2's skainet_q4k_matmul C
  symbol — the kernel just sees `const uint8_t*` and is oblivious to
  whether the bytes were staged through an arena or read directly
  from a caller-owned segment. The weight pointer is forwarded
  straight through; only input/output go through small confined-arena
  copies (those are usually a few KB and produced/consumed on the
  heap by the surrounding forward pass).
- Validates the segment is large enough for
  `(inputDim/256) * outputDim * 144` bytes from the given offset and
  rejects undersized segments with IllegalArgumentException — without
  it, an undersized segment would crash the JVM with SIGSEGV from the
  C side.
- weightByteOffset is Long on the Kotlin side narrowing to int32_t at
  the FFM boundary; we require <= Int.MAX_VALUE for now and document
  the eventual int64_t-offset overload as a follow-up. No current LLM
  single-tensor exceeds 2 GB.
- NativeKernelProvider now implements both KernelProvider and
  MemSegKernelProvider; NativeKernelProviderFactory delegates both via
  `by NativeKernelProvider`. Without the second `by`, the factory
  instance the registry hands out would fail the smart-cast even
  though the underlying singleton implements both interfaces.

Tests (skainet-backend-native-cpu/src/jvmTest):
- NativeQ4KMemSegMatmulKernelParityTest — 7 tests asserting
  bit-identical output (compared via Float.toRawBits, no tolerance)
  to NativeQ4KMatmulKernel across single-block / multi-block /
  LLM-typical shapes. The bit-identical contract is the right bar:
  same C symbol, same inputs ⇒ same outputs; any drift means the
  wrapper added arithmetic.
- Honors-non-zero-weight-byte-offset and rejects-undersized-segment
  cases for the new validation logic.
- Provider/factory smart-cast tests confirm the SPI plumbing works
  end-to-end (NativeKernelProvider as MemSegKernelProvider succeeds;
  factory ditto).
- Q4KMatmulMicrobenchTest extended: heap-copy vs zero-copy at LLM
  shapes. Weight segment pre-allocated in an Arena.ofShared outside
  the timed region — that's the realistic load profile (mmap once,
  reuse across forward passes).
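
The size check described under "Native side" can be sketched in
isolation. This is a minimal illustration, not the code in this patch:
the helper names `requiredQ4KBytes` and `fitsInSegment` are hypothetical,
and a plain Long stands in for `MemorySegment.byteSize()` so the snippet
runs without java.lang.foreign. The only facts taken from the commit are
the 256-element / 144-byte Q4_K block figures and the formula
`(inputDim/256) * outputDim * 144`.

```kotlin
// Hypothetical helper names; the constants are the Q4_K figures quoted in this commit.
const val Q4K_BLOCK_ELEMENTS = 256 // elements per super-block
const val Q4K_BLOCK_BYTES = 144L   // packed bytes per super-block

/** Packed bytes occupied by a Q4_K weight tensor of the given shape. */
fun requiredQ4KBytes(inputDim: Int, outputDim: Int): Long {
    require(inputDim % Q4K_BLOCK_ELEMENTS == 0) {
        "inputDim must be a multiple of $Q4K_BLOCK_ELEMENTS; got $inputDim"
    }
    // Widen to Long before multiplying so large shapes cannot overflow Int.
    return (inputDim / Q4K_BLOCK_ELEMENTS).toLong() * outputDim * Q4K_BLOCK_BYTES
}

/** True when a segment of [segmentBytes] covers the tensor starting at [offset]. */
fun fitsInSegment(segmentBytes: Long, offset: Long, inputDim: Int, outputDim: Int): Boolean =
    offset >= 0 && offset + requiredQ4KBytes(inputDim, outputDim) <= segmentBytes

fun main() {
    println(requiredQ4KBytes(1024, 1024)) // 589824  (under 1 MB)
    println(requiredQ4KBytes(4096, 4096)) // 9437184 (the 9 MB copy the heap path stages)
    println(fitsInSegment(575L, 0L, 256, 4)) // false: one byte short of 4 * 144 = 576
}
```

Running the check before the downcall is what turns an undersized
segment into an IllegalArgumentException on the Kotlin side instead of
an out-of-bounds read in native code.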

Microbench numbers (Linux x86_64, JDK 21.0.10, gcc 13.3 -O3 -ffast-math;
warmup=20, samples=21, median µs):

  shape    heap   memseg   zero-copy speedup   memseg vs panama
  1024²     360      369         0.98×              5.05×
  2048²    1317     1284         1.03×              4.66×
  4096²    6206     5184         1.20×              4.48×

Honest read: zero-copy is noise at small shapes (the staged copy is
sub-1MB; arena allocator + memcpy throughput hide it) and a real 1.20×
speedup at 4096² (about 16% less time per call; the 9 MB weight copy
starts to dominate cache pressure). Production loads on actual LLMs
will be larger still and will benefit more — plus they'll save on
resident memory because the heap path materializes a copy of every
weight in JVM heap on top of the off-heap segment.

Verification (linux-x86_64, JDK 21.0.10):
- :skainet-backends:skainet-backend-native-cpu:jvmTest — 15/15
  (3 pipeline + 5 heap-parity + 7 memseg-parity, microbench skipped)
- :skainet-backends:skainet-backend-cpu:jvmTest — 218/218 (no regression)
- :skainet-backends:skainet-backend-api:jvmTest — 0/0 (no tests yet)

Out of scope (deferred per asciidoc staging):
- PR 4: NEON / AVX2 intrinsics + cross-arch CI matrix
- PR 5: native FP32 / Q6_K / Q8_0 kernels
- int64_t weight offset overload (current int32_t limit hit at 2 GB
  per single segment slice)
- Panama priority-50 implementation of MemSegKernelProvider — Panama
  already has Q4_K MemSeg internals; exposing through the new SPI is a
  small follow-up and lets the smart-cast cascade work even when the
  native provider is unavailable

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 .../api/kernel/MemSegKernelProvider.kt        |  37 ++++
 .../api/kernel/Q4KMemSegMatmulKernel.kt       |  52 ++++++
 .../ainet/exec/kernel/NativeKernelProvider.kt |  30 ++-
 .../kernel/NativeKernelProviderFactory.kt     |  21 ++-
 .../kernel/NativeQ4KMemSegMatmulKernel.kt     | 107 +++++++++++
 .../NativeQ4KMemSegMatmulKernelParityTest.kt  | 174 ++++++++++++++++++
 .../exec/kernel/Q4KMatmulMicrobenchTest.kt    |  72 +++++---
 7 files changed, 461 insertions(+), 32 deletions(-)
 create mode 100644 skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/MemSegKernelProvider.kt
 create mode 100644 skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/Q4KMemSegMatmulKernel.kt
 create mode 100644 skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ4KMemSegMatmulKernel.kt
 create mode 100644 skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMemSegMatmulKernelParityTest.kt

diff --git a/skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/MemSegKernelProvider.kt b/skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/MemSegKernelProvider.kt
new file mode 100644
index 00000000..527401a5
--- /dev/null
+++ b/skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/MemSegKernelProvider.kt
@@ -0,0 +1,37 @@
+package sk.ainet.backend.api.kernel
+
+/**
+ * JVM-only sibling of [KernelProvider] for kernels whose interface
+ * surface depends on `java.lang.foreign.MemorySegment`. Kept separate
+ * because [KernelProvider] lives in `commonMain` — adding
+ * `MemorySegment` accessors there would break Kotlin/Native, JS, and
+ * Wasm targets.
+ *
+ * Providers that ship MemSeg-input kernels declare both interfaces:
+ *
+ * ```kotlin
+ * public object MyProvider : KernelProvider, MemSegKernelProvider { ... }
+ * ```
+ *
+ * Lookup pattern at the call site:
+ *
+ * ```kotlin
+ * val kernel = (KernelRegistry.bestAvailable() as? MemSegKernelProvider)
+ *     ?.matmulQ4KMemSeg()
+ *     ?: fallbackHeapPath()
+ * ```
+ *
+ * No automatic registry lookup helper for now — the smart-cast is
+ * sufficient and avoids a second registry. If a third MemSeg surface
+ * lands (FP32 matmul-MemSeg, Q6_K matmul-MemSeg, ...) it joins this
+ * interface as another `null`-defaulting accessor.
+ */
+public interface MemSegKernelProvider {
+    /**
+     * F32 × Q4_K matmul-MemSeg kernel exposed by this provider, or
+     * `null` if this provider does not specialize the MemSeg path.
+     * Default returns `null` so providers that pre-date the MemSeg SPI
+     * keep compiling.
+     */
+    public fun matmulQ4KMemSeg(): Q4KMemSegMatmulKernel? = null
+}
diff --git a/skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/Q4KMemSegMatmulKernel.kt b/skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/Q4KMemSegMatmulKernel.kt
new file mode 100644
index 00000000..c59fd0b9
--- /dev/null
+++ b/skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/Q4KMemSegMatmulKernel.kt
@@ -0,0 +1,52 @@
+package sk.ainet.backend.api.kernel
+
+import java.lang.foreign.MemorySegment
+
+/**
+ * F32 input × Q4_K-packed weights matrix-vector multiply where the
+ * **weight tensor is supplied as a `java.lang.foreign.MemorySegment`**
+ * rather than a heap [ByteArray]. JVM-only sibling of [Q4KMatmulKernel].
+ *
+ * Use this kernel when the Q4_K weight bytes already live in an
+ * off-heap segment — typically because they were `mmap`'d from a
+ * `.gguf` / `.safetensors` file, or because they were materialized
+ * into an `Arena.ofShared` segment at load time. Letting a backend
+ * read those bytes directly avoids the staging copy that
+ * [Q4KMatmulKernel.matmul] does on every call (heap `ByteArray` →
+ * temporary off-heap segment → native).
+ *
+ * The block layout, scale-pair packing, and lazy-`dmin` math are
+ * identical to [Q4KMatmulKernel] (canonical ggml super-block, 256
+ * elements, 144 bytes/block; see that kernel's kdoc for the byte
+ * map). Implementations MUST NOT mutate `input` or `weight`, MUST
+ * fully write `outputDim` floats starting at `output[outputOffset]`,
+ * and MAY assume no aliasing between the inputs and the output.
+ *
+ * Lifetime contract: the caller owns the [weight] segment's [Arena].
+ * The kernel must not retain pointers past the [matmul] call return —
+ * no asynchronous reads, no caching of dereferenced addresses across
+ * calls. Callers in turn must keep the segment's arena alive for the
+ * duration of the call.
+ */
+public interface Q4KMemSegMatmulKernel {
+    /**
+     * @param input FP32 input vector (single row), heap array.
+     * @param inputOffset element offset into [input] where the row starts.
+     * @param weight off-heap `MemorySegment` holding the packed Q4_K
+     *   weights for the full `outputDim × inputDim` tensor in canonical
+     *   block-major layout `(blockIdx * outputDim + o) * 144` bytes.
+     * @param weightByteOffset byte offset into [weight] where block
+     *   `(0, 0)` starts.
+     * @param inputDim contraction dimension; must be a multiple of 256.
+     * @param outputDim number of output cells.
+     * @param output FP32 output vector, heap array.
+     * @param outputOffset element offset into [output] where the row
+     *   starts.
+     */
+    public fun matmul(
+        input: FloatArray, inputOffset: Int,
+        weight: MemorySegment, weightByteOffset: Long,
+        inputDim: Int, outputDim: Int,
+        output: FloatArray, outputOffset: Int,
+    )
+}
diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt
index 745ec7ca..adb9cabf 100644
--- a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt
+++ b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt
@@ -2,23 +2,34 @@ package sk.ainet.exec.kernel
 
 import sk.ainet.backend.api.kernel.Fp32MatmulKernel
 import sk.ainet.backend.api.kernel.KernelProvider
+import sk.ainet.backend.api.kernel.MemSegKernelProvider
 import sk.ainet.backend.api.kernel.Q4KMatmulKernel
+import sk.ainet.backend.api.kernel.Q4KMemSegMatmulKernel
 
 /**
- * Native (FFM) [KernelProvider]. Sits at priority `100`, above
- * [PanamaVectorKernelProvider] (`50`) and the scalar reference (`0`).
+ * Native (FFM) [KernelProvider] / [MemSegKernelProvider]. Sits at
+ * priority `100`, above [PanamaVectorKernelProvider] (`50`) and the
+ * scalar reference (`0`).
  *
  * Availability is gated on [NativeQ4KMatmulKernel.isAvailable] — the
- * bundled `libskainet_kernels` shared library has to load AND the
- * `skainet_q4k_matmul` symbol has to resolve via FFM. When either
- * fails (missing arch, sandbox, JDK without FFM, kill-switch),
+ * bundled `libskainet_kernels` shared library has to load AND
+ * `skainet_q4k_matmul` has to resolve via FFM. When either fails
+ * (missing arch, sandbox, JDK without FFM, kill-switch),
  * `KernelRegistry.bestAvailable()` cleanly cascades to
  * [PanamaVectorKernelProvider] at priority 50.
  *
- * PR 2 of the staged rollout: real Q4_K matmul wired into the SPI.
- * `matmulFp32` follows in a later PR alongside a native FP32 kernel.
+ * The MemSeg surface ([matmulQ4KMemSeg]) is the JVM-only zero-copy
+ * path for mmap'd Q4_K weights — sized for inference loops that
+ * project against pre-loaded `MemorySegment`-backed tensors. Heap
+ * callers stick with [matmulQ4K]; both wrap the same C symbol so
+ * outputs are bit-for-bit identical.
+ *
+ * Staged rollout cursor (see `native-ffm-plan` asciidoc):
+ * - PR 2: real Q4_K matmul wired into the heap SPI.
+ * - PR 3 (this commit): MemSeg-input zero-copy sibling.
+ * - Later: native `matmulFp32`, `matmulQ6K`, `matmulQ8_0`.
  */
-public object NativeKernelProvider : KernelProvider {
+public object NativeKernelProvider : KernelProvider, MemSegKernelProvider {
     override val name: String = "native-ffm"
     override val priority: Int = 100
 
@@ -28,4 +39,7 @@ public object NativeKernelProvider : KernelProvider {
 
     override fun matmulQ4K(): Q4KMatmulKernel? =
         if (NativeQ4KMatmulKernel.isAvailable()) NativeQ4KMatmulKernel else null
+
+    override fun matmulQ4KMemSeg(): Q4KMemSegMatmulKernel? =
+        if (NativeQ4KMemSegMatmulKernel.isAvailable()) NativeQ4KMemSegMatmulKernel else null
 }
diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProviderFactory.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProviderFactory.kt
index b9a5480a..4d26b2d4 100644
--- a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProviderFactory.kt
+++ b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProviderFactory.kt
@@ -1,16 +1,33 @@
 package sk.ainet.exec.kernel
 
 import sk.ainet.backend.api.kernel.KernelProvider
+import sk.ainet.backend.api.kernel.MemSegKernelProvider
 
 /**
  * `ServiceLoader`-friendly wrapper around [NativeKernelProvider]. The
  * platform `ServiceLoader` machinery requires a public no-arg
  * constructor, which a Kotlin `object` does not expose; this factory
- * delegates every [KernelProvider] member back to the singleton.
+ * delegates every [KernelProvider] / [MemSegKernelProvider] member
+ * back to the singleton.
+ *
+ * Implementing both interfaces here matters for the MemSeg lookup
+ * pattern at the call site:
+ *
+ * ```kotlin
+ * val provider = KernelRegistry.bestAvailable() // KernelProvider
+ * val memSeg = (provider as? MemSegKernelProvider) // smart-cast
+ *     ?.matmulQ4KMemSeg()
+ * ```
+ *
+ * Without the second `by`, the factory instance the registry hands out
+ * wouldn't satisfy the smart-cast even though the underlying singleton
+ * implements both interfaces.
  *
  * Listed in
  * `META-INF/services/sk.ainet.backend.api.kernel.KernelProvider` so
  * `KernelServiceLoader.installAll()` discovers the provider on JVM
  * startup.
  */
-public class NativeKernelProviderFactory : KernelProvider by NativeKernelProvider
+public class NativeKernelProviderFactory :
+    KernelProvider by NativeKernelProvider,
+    MemSegKernelProvider by NativeKernelProvider
diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ4KMemSegMatmulKernel.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ4KMemSegMatmulKernel.kt
new file mode 100644
index 00000000..a3b10b13
--- /dev/null
+++ b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ4KMemSegMatmulKernel.kt
@@ -0,0 +1,107 @@
+package sk.ainet.exec.kernel
+
+import java.lang.foreign.Arena
+import java.lang.foreign.FunctionDescriptor
+import java.lang.foreign.Linker
+import java.lang.foreign.MemorySegment
+import java.lang.foreign.ValueLayout
+import java.lang.invoke.MethodHandle
+import sk.ainet.backend.api.kernel.Q4KMemSegMatmulKernel
+
+/**
+ * Zero-copy native [Q4KMemSegMatmulKernel] implementation.
+ *
+ * Reuses the same `skainet_q4k_matmul` C symbol as
+ * [NativeQ4KMatmulKernel] — the C side just sees `const uint8_t*` and
+ * doesn't care whether the Kotlin caller backed those bytes by a
+ * staged copy of a `ByteArray` or by an mmap'd off-heap segment. The
+ * win on this path is that the weight bytes (which dominate the
+ * payload — typical LLM Q4_K tensor: tens to hundreds of MB per layer)
+ * never round-trip through the heap.
+ *
+ * Per-call cost vs [NativeQ4KMatmulKernel]:
+ * - skips `MemorySegment.copy(weight, ...)` of `inputDim/256 * outputDim
+ *   * 144` bytes (e.g. 9 MB at 4096² shape).
+ * - still copies `inputDim * 4` bytes for the input vector and
+ *   `outputDim * 4` bytes for the output — the input/output are
+ *   typically heap arrays produced/consumed by the surrounding
+ *   forward pass.
+ *
+ * PR 3 of the staged native-FFM rollout — see the `native-ffm-plan`
+ * asciidoc.
+ */
+internal object NativeQ4KMemSegMatmulKernel : Q4KMemSegMatmulKernel {
+
+    private const val BLOCK_SIZE = 256
+
+    fun isAvailable(): Boolean = handle != null
+
+    override fun matmul(
+        input: FloatArray, inputOffset: Int,
+        weight: MemorySegment, weightByteOffset: Long,
+        inputDim: Int, outputDim: Int,
+        output: FloatArray, outputOffset: Int,
+    ) {
+        require(inputDim % BLOCK_SIZE == 0) {
+            "NativeQ4KMemSegMatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
+        }
+        require(weightByteOffset >= 0) {
+            "NativeQ4KMemSegMatmulKernel: weightByteOffset must be non-negative; got $weightByteOffset"
+        }
+        require(weightByteOffset <= Int.MAX_VALUE) {
+            "NativeQ4KMemSegMatmulKernel: weightByteOffset $weightByteOffset exceeds Int range — " +
+                "the C kernel takes int32_t today; slice the segment first or wait for the int64_t overload"
+        }
+        if (outputDim == 0 || inputDim == 0) return
+        val mh = handle
+            ?: error("NativeQ4KMemSegMatmulKernel.matmul invoked while native library unavailable")
+
+        // The C kernel reads weight from offset 0..weightBytesUsed, so
+        // require that the caller's segment is large enough. This catches
+        // scope/aliasing bugs early; without it, an undersized segment
+        // would crash the JVM with SIGSEGV from native code.
+        val weightBytesUsed = ((inputDim / BLOCK_SIZE).toLong() * outputDim) * 144L
+        require(weightByteOffset + weightBytesUsed <= weight.byteSize()) {
+            "NativeQ4KMemSegMatmulKernel: weight segment too small — needs " +
+                "$weightBytesUsed bytes from offset $weightByteOffset, " +
+                "segment is ${weight.byteSize()} bytes"
+        }
+
+        Arena.ofConfined().use { arena ->
+            val inSeg = arena.allocate(
+                inputDim.toLong() * java.lang.Float.BYTES,
+                ValueLayout.JAVA_FLOAT.byteAlignment(),
+            )
+            val outSeg = arena.allocate(
+                outputDim.toLong() * java.lang.Float.BYTES,
+                ValueLayout.JAVA_FLOAT.byteAlignment(),
+            )
+            MemorySegment.copy(input, inputOffset, inSeg, ValueLayout.JAVA_FLOAT, 0L, inputDim)
+
+            mh.invoke(
+                inSeg, 0,
+                weight, weightByteOffset.toInt(),
+                inputDim, outputDim,
+                outSeg, 0,
+            )
+
+            MemorySegment.copy(outSeg, ValueLayout.JAVA_FLOAT, 0L, output, outputOffset, outputDim)
+        }
+    }
+
+    private val handle: MethodHandle? by lazy {
+        val lookup = NativeLibraryLoader.lookup() ?: return@lazy null
+        val symbol = lookup.find("skainet_q4k_matmul").orElse(null) ?: return@lazy null
+        val descriptor = FunctionDescriptor.ofVoid(
+            ValueLayout.ADDRESS, // input
+            ValueLayout.JAVA_INT, // input_offset
+            ValueLayout.ADDRESS, // weight (passed straight through from caller)
+            ValueLayout.JAVA_INT, // weight_byte_offset
+            ValueLayout.JAVA_INT, // input_dim
+            ValueLayout.JAVA_INT, // output_dim
+            ValueLayout.ADDRESS, // output
+            ValueLayout.JAVA_INT, // output_offset
+        )
+        runCatching { Linker.nativeLinker().downcallHandle(symbol, descriptor) }.getOrNull()
+    }
+}
diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMemSegMatmulKernelParityTest.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMemSegMatmulKernelParityTest.kt
new file mode 100644
index 00000000..9da60a66
--- /dev/null
+++ b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMemSegMatmulKernelParityTest.kt
@@ -0,0 +1,174 @@
+package sk.ainet.exec.kernel
+
+import java.lang.foreign.Arena
+import java.lang.foreign.MemorySegment
+import java.lang.foreign.ValueLayout
+import kotlin.random.Random
+import kotlin.test.BeforeTest
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import kotlin.test.assertTrue
+import sk.ainet.backend.api.kernel.MemSegKernelProvider
+
+/**
+ * Parity tests for [NativeQ4KMemSegMatmulKernel] vs the heap variant
+ * [NativeQ4KMatmulKernel]. Both paths invoke the same
+ * `skainet_q4k_matmul` C symbol, so output must agree **exactly** —
+ * the only difference is whether the weight bytes were staged through
+ * an arena copy or read directly from a caller-owned segment.
+ *
+ * Bit-identical assertion (no tolerance) is the contract: any drift
+ * here means the wrapper added arithmetic, which is a bug.
+ *
+ * Also asserts the SPI plumbing works end-to-end:
+ * - [NativeKernelProvider] reports itself as [MemSegKernelProvider]
+ *   so the smart-cast at the call site succeeds.
+ * - The factory class hands out the provider via the same path
+ *   `KernelServiceLoader` would use.
+ */
+class NativeQ4KMemSegMatmulKernelParityTest {
+
+    private val blockSize = 256
+    private val bytesPerBlock = 144
+
+    @BeforeTest
+    fun checkAvailable() {
+        assertTrue(NativeQ4KMatmulKernel.isAvailable(), "Heap kernel must be available")
+        assertTrue(NativeQ4KMemSegMatmulKernel.isAvailable(), "MemSeg kernel must be available")
+    }
+
+    private fun randomQ4KBytes(numBlocks: Int, seed: Int): ByteArray {
+        val rng = Random(seed)
+        val bytes = ByteArray(numBlocks * bytesPerBlock)
+        rng.nextBytes(bytes)
+        for (block in 0 until numBlocks) {
+            val base = block * bytesPerBlock
+            bytes[base + 0] = 0x00.toByte()
+            bytes[base + 1] = 0x3C.toByte()
+            bytes[base + 2] = 0x00.toByte()
+            bytes[base + 3] = 0x3C.toByte()
+        }
+        return bytes
+    }
+
+    private fun assertBitIdentical(inputDim: Int, outputDim: Int, seed: Int) {
+        val numBlocks = (inputDim / blockSize) * outputDim
+        val packed = randomQ4KBytes(numBlocks, seed)
+        val input = FloatArray(inputDim) { Random(seed + it).nextFloat() - 0.5f }
+
+        val heapOut = FloatArray(outputDim)
+        NativeQ4KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, heapOut, 0)
+
+        val memSegOut = FloatArray(outputDim)
+        Arena.ofConfined().use { arena ->
+            val weightSeg = arena.allocate(packed.size.toLong(), 1L)
+            MemorySegment.copy(packed, 0, weightSeg, ValueLayout.JAVA_BYTE, 0L, packed.size)
+            NativeQ4KMemSegMatmulKernel.matmul(
+                input, 0,
+                weightSeg, 0L,
+                inputDim, outputDim,
+                memSegOut, 0,
+            )
+        }
+
+        for (o in 0 until outputDim) {
+            assertEquals(
+                heapOut[o].toRawBits(),
+                memSegOut[o].toRawBits(),
+                "row $o diverged: heap=${heapOut[o]} memSeg=${memSegOut[o]}",
+            )
+        }
+    }
+
+    @Test
+    fun bit_identical_single_block_single_row() {
+        assertBitIdentical(inputDim = 256, outputDim = 1, seed = 42)
+    }
+
+    @Test
+    fun bit_identical_single_block_multi_row() {
+        assertBitIdentical(inputDim = 256, outputDim = 16, seed = 7)
+    }
+
+    @Test
+    fun bit_identical_multi_block_multi_row() {
+        assertBitIdentical(inputDim = 1024, outputDim = 64, seed = 123)
+    }
+
+    @Test
+    fun bit_identical_llm_typical_shape() {
+        assertBitIdentical(inputDim = 4096, outputDim = 64, seed = 999)
+    }
+
+    @Test
+    fun honors_non_zero_weight_byte_offset() {
+        // Same weights laid out at byte offset 257 inside a larger
+        // segment — kernel must skip the leading bytes correctly.
+        val inputDim = 256
+        val outputDim = 4
+        val seed = 17
+        val numBlocks = (inputDim / blockSize) * outputDim
+        val packed = randomQ4KBytes(numBlocks, seed)
+        val input = FloatArray(inputDim) { Random(seed + it).nextFloat() - 0.5f }
+
+        val heapOut = FloatArray(outputDim)
+        NativeQ4KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, heapOut, 0)
+
+        val memSegOut = FloatArray(outputDim)
+        val leadingPadBytes = 257L
+        Arena.ofConfined().use { arena ->
+            val weightSeg = arena.allocate(packed.size + leadingPadBytes, 1L)
+            MemorySegment.copy(packed, 0, weightSeg, ValueLayout.JAVA_BYTE, leadingPadBytes, packed.size)
+            NativeQ4KMemSegMatmulKernel.matmul(
+                input, 0,
+                weightSeg, leadingPadBytes,
+                inputDim, outputDim,
+                memSegOut, 0,
+            )
+        }
+
+        for (o in 0 until outputDim) {
+            assertEquals(heapOut[o].toRawBits(), memSegOut[o].toRawBits(), "row $o offset path diverged")
+        }
+    }
+
+    @Test
+    fun rejects_undersized_weight_segment() {
+        val inputDim = 256
+        val outputDim = 4
+        val numBlocks = (inputDim / blockSize) * outputDim
+        val needed = numBlocks.toLong() * bytesPerBlock // 4 * 144 = 576
+        val input = FloatArray(inputDim)
+        val output = FloatArray(outputDim)
+
+        Arena.ofConfined().use { arena ->
+            val tooSmall = arena.allocate(needed - 1L, 1L)
+            try {
+                NativeQ4KMemSegMatmulKernel.matmul(
+                    input, 0,
+                    tooSmall, 0L,
+                    inputDim, outputDim,
+                    output, 0,
+                )
+                kotlin.test.fail("expected IllegalArgumentException for undersized segment")
+            } catch (_: IllegalArgumentException) {
+                // expected
+            }
+        }
+    }
+
+    @Test
+    fun provider_smart_casts_to_MemSegKernelProvider() {
+        val provider: Any = NativeKernelProvider
+        assertTrue(provider is MemSegKernelProvider, "NativeKernelProvider must implement MemSegKernelProvider")
+        assertNotNull(provider.matmulQ4KMemSeg())
+    }
+
+    @Test
+    fun factory_smart_casts_to_MemSegKernelProvider() {
+        val factory: Any = NativeKernelProviderFactory()
+        assertTrue(factory is MemSegKernelProvider, "Factory must implement MemSegKernelProvider for ServiceLoader path")
+        assertNotNull(factory.matmulQ4KMemSeg())
+    }
+}
diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/Q4KMatmulMicrobenchTest.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/Q4KMatmulMicrobenchTest.kt
index e92177a0..a83b5c13 100644
--- a/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/Q4KMatmulMicrobenchTest.kt
+++ b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/Q4KMatmulMicrobenchTest.kt
@@ -1,5 +1,8 @@
 package sk.ainet.exec.kernel
 
+import java.lang.foreign.Arena
+import java.lang.foreign.MemorySegment
+import java.lang.foreign.ValueLayout
 import kotlin.math.abs
 import kotlin.random.Random
 import kotlin.test.Test
@@ -86,30 +89,55 @@ class Q4KMatmulMicrobenchTest {
         println("Host: ${System.getProperty("os.name")} ${System.getProperty("os.arch")} | JDK ${System.getProperty("java.version")}")
         println()
 
-        for ((inputDim, outputDim, seed) in shapes) {
-            val numBlocks = (inputDim / blockSize) * outputDim
-            val packed = randomQ4KBytes(numBlocks, seed)
-            val input = FloatArray(inputDim) { Random(seed + it).nextFloat() - 0.5f }
-            val outNative = FloatArray(outputDim)
-            val outPanama = FloatArray(outputDim)
+        // Pre-allocate weight segments outside the timed region — the
+        // MemSeg path's whole point is that weights are loaded ONCE
+        // (mmap, Arena.ofShared) and reused across forward passes.
+        Arena.ofShared().use { sharedArena ->
+            for ((inputDim, outputDim, seed) in shapes) {
+                val numBlocks = (inputDim / blockSize) * outputDim
+                val packed = randomQ4KBytes(numBlocks, seed)
+                val input = FloatArray(inputDim) { Random(seed + it).nextFloat() - 0.5f }
+                val outNative = FloatArray(outputDim)
+                val outNativeMemSeg = FloatArray(outputDim)
+                val outPanama = FloatArray(outputDim)
 
-            println("[inputDim=$inputDim, outputDim=$outputDim]")
-            val nativeNs = benchOne("native", warmup = 20, samples = 21) {
-                NativeQ4KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, outNative, 0)
-            }
-            val panamaNs = benchOne("panama", warmup = 20, samples = 21) {
-                PanamaVectorQ4KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, outPanama, 0)
+                val weightSeg: MemorySegment = sharedArena.allocate(packed.size.toLong(), 1L)
+                MemorySegment.copy(packed, 0, weightSeg, ValueLayout.JAVA_BYTE, 0L, packed.size)
+
+                println("[inputDim=$inputDim, outputDim=$outputDim]")
+                val nativeNs = benchOne("native (heap) ", warmup = 20, samples = 21) {
+                    NativeQ4KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, outNative, 0)
+                }
+                val nativeMemSegNs = benchOne("native (memseg)", warmup = 20, samples = 21) {
+                    NativeQ4KMemSegMatmulKernel.matmul(
+                        input, 0,
+                        weightSeg, 0L,
+                        inputDim, outputDim,
+                        outNativeMemSeg, 0,
+                    )
+                }
+                val panamaNs = benchOne("panama ", warmup = 20, samples = 21) {
+                    PanamaVectorQ4KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, outPanama, 0)
+                }
+
+                val heapVsMemSeg = nativeNs.toDouble() / nativeMemSegNs.toDouble()
+                val memSegVsPanama = panamaNs.toDouble() / nativeMemSegNs.toDouble()
+                println(
+                    " zero-copy speedup: %.2fx over native heap-copy (heap=%dµs vs memseg=%dµs)".format(
+                        heapVsMemSeg,
+                        nativeNs / 1_000,
+                        nativeMemSegNs / 1_000,
+                    ),
+                )
+                println(
+                    " ratio: native (memseg) is %.2fx panama (%.1f%% %s)".format(
+                        memSegVsPanama,
+                        abs((memSegVsPanama - 1.0) * 100.0),
+                        if (memSegVsPanama >= 1.0) "faster" else "slower",
+                    ),
+                )
+                println()
             }
-            val ratio = panamaNs.toDouble() / nativeNs.toDouble()
-            val pct = (ratio - 1.0) * 100.0
-            println(
-                " ratio: native is %.2fx panama (%.1f%% %s)".format(
-                    ratio,
-                    abs(pct),
-                    if (ratio >= 1.0) "faster" else "slower",
-                ),
-            )
-            println()
         }
     }
 }