Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_library(skainet_kernels SHARED
src/fp32_matmul.c
src/bf16_matmul.c
src/q8_0_matmul.c
src/q4_0_matmul.c
)

target_include_directories(skainet_kernels PUBLIC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,32 @@ SKAINET_API void skainet_q8_0_matmul(
int32_t output_offset
);

/*
 * Q4_0 matrix-vector multiply.
 *
 *   output[output_offset + o] = sum_j input[input_offset + j] *
 *                               dequant(weight[block, o, j])
 *
 * Block layout: canonical ggml Q4_0, 32 elements per block, 18 bytes
 * per block (2 B FP16 scale + 16 B packed 4-bit codes in split layout —
 * low nibbles → elements 0..15, high nibbles → 16..31), with packed
 * weights laid out as
 *   weight + weight_byte_offset + (block_idx * output_dim + o) * 18
 *
 * Dequant per element: `(code - 8) * d`. input_dim must be a multiple
 * of 32.
 *
 * Parameters:
 *   input              base pointer of the FP32 input vector
 *   input_offset       element offset of the first input value
 *   weight             base pointer of the packed Q4_0 weight bytes
 *   weight_byte_offset byte offset of the first weight block
 *   input_dim          number of input elements (columns)
 *   output_dim         number of output elements (rows)
 *   output             base pointer of the FP32 output vector
 *   output_offset      element offset of the first output value
 *
 * Edge cases (as implemented in src/q4_0_matmul.c):
 *   - output_dim <= 0: returns without touching any buffer.
 *   - input_dim <= 0: writes 0.0f to the output_dim output elements and
 *     never reads input or weight.
 *   - The kernel does not validate the multiple-of-32 requirement; the
 *     block count is input_dim / 32 (integer division), so trailing
 *     elements beyond the last full block are ignored.
 *   - No bounds checks are performed; the caller owns buffer sizing.
 */
SKAINET_API void skainet_q4_0_matmul(
const float* input,
int32_t input_offset,
const uint8_t* weight,
int32_t weight_byte_offset,
int32_t input_dim,
int32_t output_dim,
float* output,
int32_t output_offset
);

#ifdef __cplusplus
}
#endif
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include "skainet_kernels.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

/*
* Native FP32 × Q4_0 matrix-vector matmul matching the
* sk.ainet.backend.api.kernel.Q4_0MatmulKernel SPI.
*
* Block layout (canonical ggml Q4_0, 32 elements, 18 bytes):
* - bytes 0..1 : FP16 little-endian scale `d`
* - bytes 2..17 : 16 bytes packing 32 4-bit codes in the *split*
* layout — low nibbles decode elements 0..15, high nibbles decode
* elements 16..31.
*
* Per-block packed weight layout:
* weight + weight_byte_offset + (block_idx * output_dim + o) * 18
*
* Dequant per element: `(code - 8) * d`. The `- 8` bias centres the
* unsigned 4-bit code. Scale `d` is folded once after the block
* accumulator (cheaper than broadcasting it across every inner FMA).
*/

/*
 * Widen an IEEE 754 binary16 bit pattern to a binary32 float using only
 * integer operations (no hardware half-float support required).
 * Intended to mirror the Kotlin `Q4_0BlockTensorData.halfToFloat`
 * routine bit-for-bit.
 *
 *   h        raw 16-bit half-precision pattern
 *   returns  the float with the same value; zeros keep their sign,
 *            subnormals are renormalised, and Inf/NaN keep their
 *            mantissa payload (shifted into the wider field)
 */
static inline float skainet_q4_0_fp16_to_fp32(uint16_t h) {
    const uint32_t sign_bit = ((uint32_t) h & 0x8000u) << 16;
    const uint32_t exponent = ((uint32_t) h >> 10) & 0x1Fu;
    uint32_t fraction = (uint32_t) h & 0x3FFu;
    uint32_t word = sign_bit;   /* signed zero unless a branch adds more */

    if (exponent == 0x1Fu) {
        /* Inf / NaN: all-ones float exponent, payload carried over. */
        word |= 0x7F800000u | (fraction << 13);
    } else if (exponent != 0u) {
        /* Normal number: rebias the exponent from 15 to 127 (+112). */
        word |= ((exponent + 112u) << 23) | (fraction << 13);
    } else if (fraction != 0u) {
        /* Subnormal half: shift the fraction up until the implicit
         * leading one (bit 10) appears, lowering the exponent once per
         * shift. 113 is the biased form of the half minimum exponent
         * (-14 + 127). The fraction is below 0x400 here, so at least
         * one shift is always required. */
        uint32_t biased_exp = 113u;
        do {
            fraction <<= 1;
            --biased_exp;
        } while ((fraction & 0x400u) == 0u);
        fraction &= 0x3FFu;     /* drop the now-implicit leading one */
        word |= (biased_exp << 23) | (fraction << 13);
    }

    float result;
    memcpy(&result, &word, sizeof(result));   /* type-pun without aliasing UB */
    return result;
}

/*
 * FP32 x Q4_0 matrix-vector product.
 *
 * For every output row `o` in [0, output_dim):
 *   output[output_offset + o] =
 *       sum over blocks b of d(b, o) * sum_k input[...] * (code - 8)
 *
 * Parameters:
 *   input / input_offset        FP32 input vector and its element offset
 *   weight / weight_byte_offset packed Q4_0 blocks and their byte offset;
 *                               block (b, o) starts at
 *                               (b * output_dim + o) * 18 bytes past it
 *   input_dim                   number of input elements; only
 *                               input_dim / 32 full blocks are consumed
 *   output_dim                  number of output rows
 *   output / output_offset      FP32 output vector and its element offset
 *
 * Edge cases:
 *   - output_dim <= 0 is a no-op.
 *   - input_dim <= 0 zero-fills the output_dim outputs without reading
 *     input or weight.
 *
 * No bounds checks are performed; the caller owns buffer sizing.
 */
SKAINET_API void skainet_q4_0_matmul(
    const float* SKAINET_RESTRICT input, int32_t input_offset,
    const uint8_t* SKAINET_RESTRICT weight, int32_t weight_byte_offset,
    int32_t input_dim, int32_t output_dim,
    float* SKAINET_RESTRICT output, int32_t output_offset
) {
    if (output_dim <= 0) return;
    if (input_dim <= 0) {
        for (int32_t o = 0; o < output_dim; ++o) {
            output[output_offset + o] = 0.0f;
        }
        return;
    }

    const int32_t BLOCK_SIZE = 32;
    const int32_t BYTES_PER_BLOCK = 18;
    const int32_t blocks_per_input_dim = input_dim / BLOCK_SIZE;

    /* Loop-invariant base pointers, resolved once. */
    const uint8_t* weight_base = weight + weight_byte_offset;
    const float* input_base = input + input_offset;

    for (int32_t o = 0; o < output_dim; ++o) {
        float acc = 0.0f;
        for (int32_t block_idx = 0; block_idx < blocks_per_input_dim; ++block_idx) {
            /* Widen to size_t BEFORE multiplying: block_idx * output_dim
             * evaluated in int32_t is undefined behaviour once the block
             * count exceeds INT32_MAX. */
            const size_t block_no =
                (size_t) block_idx * (size_t) output_dim + (size_t) o;
            const uint8_t* SKAINET_RESTRICT block =
                weight_base + block_no * (size_t) BYTES_PER_BLOCK;

            /* Bytes 0..1: little-endian FP16 scale. */
            uint16_t d_bits = (uint16_t) block[0] | ((uint16_t) block[1] << 8);
            float d = skainet_q4_0_fp16_to_fp32(d_bits);

            /* Bytes 2..17: 32 packed 4-bit codes, split layout. */
            const uint8_t* SKAINET_RESTRICT codes = block + 2;
            const float* SKAINET_RESTRICT input_block =
                input_base + (size_t) block_idx * (size_t) BLOCK_SIZE;

            float block_sum = 0.0f;
            for (int32_t k = 0; k < 16; ++k) {
                /* Low nibble -> element k, high nibble -> element k + 16;
                 * subtracting 8 recentres the unsigned code. */
                int32_t lo = (int32_t)(codes[k] & 0x0F) - 8;
                int32_t hi = (int32_t)(codes[k] >> 4) - 8;
                block_sum += input_block[k] * (float) lo;
                block_sum += input_block[k + 16] * (float) hi;
            }
            /* Apply the scale once per block instead of per element. */
            acc += block_sum * d;
        }
        output[output_offset + o] = acc;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import sk.ainet.backend.api.kernel.KernelProvider
import sk.ainet.backend.api.kernel.MemSegKernelProvider
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
import sk.ainet.backend.api.kernel.Q4KMemSegMatmulKernel
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel

/**
Expand Down Expand Up @@ -93,4 +94,7 @@ public object NativeKernelProvider : KernelProvider, MemSegKernelProvider {

/** Hands out the native Q8_0 kernel, or `null` when it reports itself unavailable. */
override fun matmulQ8_0(): Q8_0MatmulKernel? {
    if (!NativeQ8_0MatmulKernel.isAvailable()) return null
    return NativeQ8_0MatmulKernel
}

/** Hands out the native Q4_0 kernel, or `null` when it reports itself unavailable. */
override fun matmulQ4_0(): Q4_0MatmulKernel? {
    if (!NativeQ4_0MatmulKernel.isAvailable()) return null
    return NativeQ4_0MatmulKernel
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package sk.ainet.exec.kernel

import java.lang.foreign.Arena
import java.lang.foreign.FunctionDescriptor
import java.lang.foreign.Linker
import java.lang.foreign.MemorySegment
import java.lang.foreign.ValueLayout
import java.lang.invoke.MethodHandle
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel

/**
 * Native (FFM) implementation of [Q4_0MatmulKernel].
 *
 * Wraps the bundled C symbol
 *
 *     void skainet_q4_0_matmul(
 *         const float* input,    int32_t input_offset,
 *         const uint8_t* weight, int32_t weight_byte_offset,
 *         int32_t input_dim,     int32_t output_dim,
 *         float* output,         int32_t output_offset);
 *
 * The C kernel decodes the ggml-canonical Q4_0 block (FP16 scale + 16
 * packed bytes, split nibble layout) with `(code - 8) * d` dequant.
 *
 * Marshalling: only the slices the kernel actually touches are copied
 * into native memory — `inputDim` floats starting at `inputOffset`, the
 * `blocks * outputDim * 18` weight bytes starting at `weightByteOffset`,
 * and `outputDim` result floats. The native call therefore always
 * receives zero offsets, and on return only
 * `output[outputOffset until outputOffset + outputDim]` is written;
 * every other element of `output` is left untouched.
 *
 * Numerical parity vs [PanamaVectorQ4_0MatmulKernel] is asserted by
 * `NativeQ4_0MatmulKernelParityTest` within a `1e-2 * blocksPerInputDim`
 * band.
 */
internal object NativeQ4_0MatmulKernel : Q4_0MatmulKernel {

    /** `true` when the native library was located and the downcall handle could be linked. */
    fun isAvailable(): Boolean = handle != null

    /**
     * Computes `output[outputOffset + o] = Σ_j input[inputOffset + j] * dequant(weight[o, j])`
     * for `o` in `0 until outputDim` using the native kernel.
     *
     * @param input FP32 input vector; `inputDim` elements are read from [inputOffset].
     * @param weight packed Q4_0 blocks; read from [weightByteOffset].
     * @param inputDim number of input elements; must be a multiple of 32.
     * @param outputDim number of output rows; `0` is a no-op.
     * @param output destination; exactly `outputDim` elements are written from [outputOffset].
     * @throws IllegalArgumentException if [inputDim] is not a multiple of 32.
     * @throws IllegalStateException if the native library is unavailable.
     * @throws ArithmeticException if the weight byte count does not fit in an `Int`.
     * @throws IndexOutOfBoundsException if an offset/length pair falls outside its array.
     */
    override fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: ByteArray, weightByteOffset: Int,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    ) {
        require(inputDim % BLOCK_SIZE == 0) {
            "NativeQ4_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
        }
        if (outputDim == 0) return

        val mh = handle
            ?: error("NativeQ4_0MatmulKernel.matmul invoked while native library unavailable")

        // The native kernel zero-fills the output and reads nothing when input_dim <= 0,
        // so input/weight are only marshalled for a positive inputDim.
        val hasInput = inputDim > 0
        // Exact arithmetic: a silently wrapped byte count would under-allocate the native
        // weight buffer and let the kernel read past its end.
        val weightBytes = if (hasInput && outputDim > 0) {
            Math.multiplyExact(Math.multiplyExact(inputDim / BLOCK_SIZE, outputDim), BYTES_PER_BLOCK)
        } else {
            0
        }

        Arena.ofConfined().use { arena ->
            val floatBytes = ValueLayout.JAVA_FLOAT.byteSize()
            val floatAlign = ValueLayout.JAVA_FLOAT.byteAlignment()
            val byteAlign = ValueLayout.JAVA_BYTE.byteAlignment()

            val inputSeg: MemorySegment = if (hasInput) {
                arena.allocate(inputDim * floatBytes, floatAlign).also { seg ->
                    MemorySegment.copy(input, inputOffset, seg, ValueLayout.JAVA_FLOAT, 0L, inputDim)
                }
            } else {
                MemorySegment.NULL
            }

            val weightSeg: MemorySegment = if (weightBytes > 0) {
                arena.allocate(weightBytes.toLong(), byteAlign).also { seg ->
                    MemorySegment.copy(weight, weightByteOffset, seg, ValueLayout.JAVA_BYTE, 0L, weightBytes)
                }
            } else {
                MemorySegment.NULL
            }

            val outputSeg: MemorySegment = arena.allocate(outputDim * floatBytes, floatAlign)

            // Offsets were already applied while copying, so the native side sees
            // tightly packed buffers that start at element/byte 0.
            val nativeOffset = 0
            mh.invoke(
                inputSeg, nativeOffset,
                weightSeg, nativeOffset,
                inputDim, outputDim,
                outputSeg, nativeOffset,
            )

            // Write back only the computed window; elements of `output` before
            // `outputOffset` must keep the caller's values.
            MemorySegment.copy(outputSeg, ValueLayout.JAVA_FLOAT, 0L, output, outputOffset, outputDim)
        }
    }

    /** Elements per Q4_0 block. */
    private const val BLOCK_SIZE = 32

    /** Bytes per Q4_0 block: 2-byte FP16 scale + 16 bytes of packed 4-bit codes. */
    private const val BYTES_PER_BLOCK = 18

    /**
     * Downcall handle for `skainet_q4_0_matmul`, resolved on first use. `null` when the
     * library lookup is unavailable, the symbol is missing, or linking fails
     * (deliberately best-effort so the provider can fall back to another kernel).
     */
    private val handle: MethodHandle? by lazy {
        val lookup = NativeLibraryLoader.lookup() ?: return@lazy null
        val symbol = lookup.find("skainet_q4_0_matmul").orElse(null) ?: return@lazy null
        val descriptor = FunctionDescriptor.ofVoid(
            ValueLayout.ADDRESS,   // input
            ValueLayout.JAVA_INT,  // input_offset
            ValueLayout.ADDRESS,   // weight
            ValueLayout.JAVA_INT,  // weight_byte_offset
            ValueLayout.JAVA_INT,  // input_dim
            ValueLayout.JAVA_INT,  // output_dim
            ValueLayout.ADDRESS,   // output
            ValueLayout.JAVA_INT,  // output_offset
        )
        runCatching { Linker.nativeLinker().downcallHandle(symbol, descriptor) }.getOrNull()
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package sk.ainet.exec.kernel

import kotlin.math.abs
import kotlin.random.Random
import kotlin.test.BeforeTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue

/**
 * Checks that [NativeQ4_0MatmulKernel] agrees numerically with
 * [PanamaVectorQ4_0MatmulKernel] on randomly generated Q4_0 weights,
 * and covers the argument-validation, empty-input and provider-wiring
 * paths of the native kernel.
 *
 * Allowed absolute error per output element: `1e-2 * blocksPerInputDim`
 * (the same convention as the Panama / Q8_0 parity tests).
 */
class NativeQ4_0MatmulKernelParityTest {

    private val blockSize = 32
    private val bytesPerBlock = 18

    /** Every test needs the bundled native library; fail loudly when it is absent. */
    @BeforeTest
    fun checkAvailable() {
        val available = NativeQ4_0MatmulKernel.isAvailable()
        assertTrue(
            available,
            "Native Q4_0 kernel must be available — bundled libskainet_kernels missing or " +
                "skainet_q4_0_matmul symbol unresolved",
        )
    }

    /**
     * Builds `blocksPerInputDim * outputDim` Q4_0 blocks with random nibble
     * payloads and a fixed scale in every block.
     */
    private fun randomQ4_0Bytes(blocksPerInputDim: Int, outputDim: Int, seed: Int): ByteArray {
        val blockCount = blocksPerInputDim * outputDim
        val packed = Random(seed).nextBytes(blockCount * bytesPerBlock)
        for (scaleAt in 0 until packed.size step bytesPerBlock) {
            // Little-endian FP16 0x2200 = 1.5 * 2^-7 ≈ 1.17e-2: finite and non-zero.
            packed[scaleAt] = 0x00.toByte()
            packed[scaleAt + 1] = 0x22.toByte()
        }
        return packed
    }

    /**
     * Runs both kernels on identical random data and requires every output
     * element to match within `tolPerBlock` scaled by the block count.
     */
    private fun assertParity(
        inputDim: Int,
        outputDim: Int,
        seed: Int,
        tolPerBlock: Float = 1e-2f,
    ) {
        val blocksPerInputDim = inputDim / blockSize
        val generator = Random(seed)
        val input = FloatArray(inputDim) { generator.nextFloat() - 0.5f }
        val weight = randomQ4_0Bytes(blocksPerInputDim, outputDim, seed)

        val outPanama = FloatArray(outputDim).also { expected ->
            PanamaVectorQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, expected, 0)
        }
        val outNative = FloatArray(outputDim).also { actual ->
            NativeQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, actual, 0)
        }

        val scaledTol = tolPerBlock * blocksPerInputDim.coerceAtLeast(1)
        val tol = scaledTol.coerceAtLeast(tolPerBlock)
        outPanama.forEachIndexed { i, expected ->
            val diff = abs(expected - outNative[i])
            assertTrue(
                diff <= tol,
                "mismatch at $i: panama=${outPanama[i]} native=${outNative[i]} diff=$diff tol=$tol",
            )
        }
    }

    @Test
    fun single_block_single_output_matches_panama() {
        assertParity(inputDim = 32, outputDim = 1, seed = 1)
    }

    @Test
    fun single_block_multiple_outputs_matches_panama() {
        assertParity(inputDim = 32, outputDim = 7, seed = 2)
    }

    @Test
    fun multiple_blocks_single_output_matches_panama() {
        assertParity(inputDim = 256, outputDim = 1, seed = 3)
    }

    @Test
    fun llm_typical_attention_proj_matches_panama() {
        assertParity(inputDim = 512, outputDim = 512, seed = 4)
    }

    @Test
    fun llm_typical_ffn_proj_matches_panama() {
        assertParity(inputDim = 256, outputDim = 1024, seed = 5)
    }

    /** An inputDim that is not a whole number of 32-element blocks must be rejected. */
    @Test
    fun rejects_non_block_aligned_input_dim() {
        val input = FloatArray(31)
        val weight = ByteArray(bytesPerBlock)
        val output = FloatArray(1)
        assertFailsWith<IllegalArgumentException> {
            NativeQ4_0MatmulKernel.matmul(input, 0, weight, 0, 31, 1, output, 0)
        }
    }

    /** With no input columns every requested output element is set to zero. */
    @Test
    fun zero_input_dim_zeros_output() {
        val out = FloatArray(5) { 9f }
        NativeQ4_0MatmulKernel.matmul(FloatArray(0), 0, ByteArray(0), 0, 0, 5, out, 0)
        out.forEach { value ->
            assertEquals(0f, value, "output should be zeroed for inputDim=0")
        }
    }

    /** The provider must expose the native singleton itself, not a wrapper or fallback. */
    @Test
    fun provider_returns_native_q4_0_when_available() {
        val kernel = NativeKernelProvider.matmulQ4_0()
        val isNativeSingleton = kernel === NativeQ4_0MatmulKernel
        assertTrue(
            isNativeSingleton,
            "Provider must hand out the native Q4_0 kernel when bundled lib is loaded",
        )
    }
}
Loading