Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ public interface KernelProvider {
*/
public fun matmulQ8_0(): Q8_0MatmulKernel? = null

/**
* F32 × Q4_0 matmul kernel exposed by this provider, or `null` if
* this provider does not specialize Q4_0. Same fall-through pattern.
*/
public fun matmulQ4_0(): Q4_0MatmulKernel? = null

/**
* Capability query: does this provider carry a kernel for
* [opName] with the given [dtypeKeys]?
Expand Down Expand Up @@ -100,6 +106,7 @@ public interface KernelProvider {
"BFloat16" -> matmulBf16() != null
"Q4_K" -> matmulQ4K() != null
"Q8_0" -> matmulQ8_0() != null
"Q4_0" -> matmulQ4_0() != null
else -> false
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package sk.ainet.backend.api.kernel

/**
* F32 input × Q4_0-packed weights matrix-vector multiply, in canonical
* ggml block layout.
*
* output[outputOffset + o] = Σ_j input[inputOffset + j] · dequant(weight[o, j])
* for j ∈ [0, inputDim), o ∈ [0, outputDim)
*
* Block layout (32-element block, 18 bytes/block; see
* [sk.ainet.lang.tensor.data.Q4_0BlockTensorData] kdoc):
* - bytes 0..1 : `d` (block scale, FP16 LE)
* - bytes 2..17 : 16 bytes packing 32 4-bit codes (split layout — low
* nibbles decode elements 0..15, high nibbles decode elements 16..31)
*
* Per element: `dequant = (code - 8) * d` (the `- 8` bias centres the
* unsigned 4-bit code around zero). Q4_0 has no per-block min / offset.
*
* Implementations MUST NOT mutate `input` or `weight`. They MAY assume
* the arrays do not alias each other or `output`. They MUST fully
* write the `outputDim` floats starting at `output[outputOffset]`.
*
* Packed-weight row-major contract: `weight` holds blocks laid out
* `(blockIdx * outputDim + o) * 18` for output row `o` and input block
* index `blockIdx`. This matches `Q4_0BlockTensorData.packedData`.
*
* `inputDim` MUST be a multiple of 32 (the Q4_0 block size).
*/
public interface Q4_0MatmulKernel {
/**
* @param input FP32 input vector (single row).
* @param inputOffset element offset into [input] where the row starts.
* @param weight packed Q4_0 bytes for the full `outputDim × inputDim` weight tensor.
* @param weightByteOffset byte offset into [weight] where block (0, 0) starts.
* @param inputDim contraction dimension (must be a multiple of 32).
* @param outputDim number of output cells.
* @param output FP32 output vector.
* @param outputOffset element offset into [output] where the row starts.
*/
public fun matmul(
input: FloatArray, inputOffset: Int,
weight: ByteArray, weightByteOffset: Int,
inputDim: Int, outputDim: Int,
output: FloatArray, outputOffset: Int,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public final class sk/ainet/exec/kernel/PanamaVectorKernelProvider : sk/ainet/ba
public fun matmulBf16 ()Lsk/ainet/backend/api/kernel/Bf16MatmulKernel;
public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel;
public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel;
public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel;
public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel;
public fun supports (Ljava/lang/String;Ljava/util/List;)Z
}
Expand All @@ -65,6 +66,7 @@ public final class sk/ainet/exec/kernel/PanamaVectorKernelProviderFactory : sk/a
public fun matmulBf16 ()Lsk/ainet/backend/api/kernel/Bf16MatmulKernel;
public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel;
public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel;
public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel;
public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel;
public fun supports (Ljava/lang/String;Ljava/util/List;)Z
}
Expand Down Expand Up @@ -97,6 +99,7 @@ public final class sk/ainet/exec/kernel/ScalarKernelProvider : sk/ainet/backend/
public fun matmulBf16 ()Lsk/ainet/backend/api/kernel/Bf16MatmulKernel;
public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel;
public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel;
public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel;
public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel;
public fun supports (Ljava/lang/String;Ljava/util/List;)Z
}
Expand All @@ -109,6 +112,7 @@ public final class sk/ainet/exec/kernel/ScalarKernelProviderFactory : sk/ainet/b
public fun matmulBf16 ()Lsk/ainet/backend/api/kernel/Bf16MatmulKernel;
public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel;
public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel;
public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel;
public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel;
public fun supports (Ljava/lang/String;Ljava/util/List;)Z
}
Expand All @@ -118,6 +122,11 @@ public final class sk/ainet/exec/kernel/ScalarMatmulKernel : sk/ainet/backend/ap
public fun matmul ([FII[FII[FIIIII)V
}

public final class sk/ainet/exec/kernel/ScalarQ4_0MatmulKernel : sk/ainet/backend/api/kernel/Q4_0MatmulKernel {
public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarQ4_0MatmulKernel;
public fun matmul ([FI[BIII[FI)V
}

public final class sk/ainet/exec/kernel/ScalarQ8_0MatmulKernel : sk/ainet/backend/api/kernel/Q8_0MatmulKernel {
public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarQ8_0MatmulKernel;
public fun matmul ([FI[BIII[FI)V
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sk.ainet.exec.kernel
import sk.ainet.backend.api.kernel.Bf16MatmulKernel
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
import sk.ainet.backend.api.kernel.KernelProvider
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel

/**
Expand All @@ -25,4 +26,5 @@ public object ScalarKernelProvider : KernelProvider {
override fun matmulFp32(): Fp32MatmulKernel = ScalarMatmulKernel
override fun matmulBf16(): Bf16MatmulKernel = ScalarBf16MatmulKernel
override fun matmulQ8_0(): Q8_0MatmulKernel = ScalarQ8_0MatmulKernel
override fun matmulQ4_0(): Q4_0MatmulKernel = ScalarQ4_0MatmulKernel
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package sk.ainet.exec.kernel

import sk.ainet.backend.api.kernel.Q4_0MatmulKernel

/**
* Scalar reference implementation of [Q4_0MatmulKernel] — straight
* per-block dequant + per-element FMA, no SIMD. Always available on
* every KMP target. Used as:
*
* - The correctness reference that accelerated kernels (Panama Vector,
* native FFM) must match within FP order tolerance.
* - A guaranteed fallback when no accelerated provider is registered.
*
* Block layout (32-element block, 18 bytes):
* - bytes 0..1 : FP16 little-endian scale (`d`)
* - bytes 2..17: 16 bytes packing 32 4-bit codes (split layout)
*
* Dequant per element: `(code - 8) * d`. No min / offset.
*
* Performance is intentionally modest; production paths should pick the
* Panama Vector or native variant via the kernel registry.
*/
public object ScalarQ4_0MatmulKernel : Q4_0MatmulKernel {

private const val BLOCK_SIZE = 32
private const val BYTES_PER_BLOCK = 18

override fun matmul(
input: FloatArray, inputOffset: Int,
weight: ByteArray, weightByteOffset: Int,
inputDim: Int, outputDim: Int,
output: FloatArray, outputOffset: Int,
) {
require(inputDim % BLOCK_SIZE == 0) {
"ScalarQ4_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
}
if (outputDim == 0 || inputDim == 0) {
if (outputDim > 0) {
for (o in 0 until outputDim) output[outputOffset + o] = 0f
}
return
}
val blocksPerInputDim = inputDim / BLOCK_SIZE

for (o in 0 until outputDim) {
var acc = 0f
for (blockIdx in 0 until blocksPerInputDim) {
val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK
// FP16 scale: two LE bytes.
val dBits = (weight[blockBase].toInt() and 0xFF) or
((weight[blockBase + 1].toInt() and 0xFF) shl 8)
val d = halfToFloat(dBits)
// 32 codes, blockIdx-th window of the input vector. Split
// layout: low nibbles → elements 0..15, high → 16..31.
val inputBase = inputOffset + blockIdx * BLOCK_SIZE
val codesBase = blockBase + 2
for (j in 0 until 16) {
val b = weight[codesBase + j].toInt() and 0xFF
val lo = (b and 0x0F) - 8
val hi = (b ushr 4) - 8
acc += input[inputBase + j] * lo * d
acc += input[inputBase + 16 + j] * hi * d
}
}
output[outputOffset + o] = acc
}
}

/**
* Convert a 16-bit IEEE-754 half-precision value (low 16 bits of
* [hbits]) to FP32. Mirrors [ScalarQ8_0MatmulKernel]'s inlined helper
* — the skainet-lang-core dequant helper is internal to that module.
*/
private fun halfToFloat(hbits: Int): Float {
val sign = (hbits and 0x8000) shl 16
val exp = (hbits and 0x7C00) shr 10
val mant = hbits and 0x03FF
return when (exp) {
0 -> {
if (mant == 0) Float.fromBits(sign)
else {
var m = mant
var e = -14
while ((m and 0x400) == 0) {
m = m shl 1
e--
}
m = m and 0x3FF
Float.fromBits(sign or ((e + 127) shl 23) or (m shl 13))
}
}
31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13))
else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import sk.ainet.backend.api.kernel.KernelRegistry
import sk.ainet.backend.api.kernel.KernelServiceLoader
import sk.ainet.backend.api.kernel.KernelStrictness
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
import sk.ainet.exec.kernel.ScalarBf16MatmulKernel
import sk.ainet.exec.kernel.ScalarMatmulKernel
import sk.ainet.exec.kernel.ScalarQ4_0MatmulKernel
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData
Expand All @@ -21,6 +23,7 @@ import sk.ainet.lang.tensor.data.MemorySegmentTensorData
import sk.ainet.lang.tensor.data.Q4MemorySegmentMarker
import sk.ainet.lang.tensor.data.Q4MemorySegmentTensorData
import sk.ainet.lang.tensor.data.Bf16TensorData
import sk.ainet.lang.tensor.data.Q4_0TensorData
import sk.ainet.lang.tensor.data.Q8_0TensorData
import sk.ainet.lang.tensor.data.Q8MemorySegmentMarker
import sk.ainet.lang.tensor.data.Q8MemorySegmentTensorData
Expand Down Expand Up @@ -113,6 +116,24 @@ internal class DefaultCpuOpsJvm(
?: ScalarBf16MatmulKernel
}

/**
* Q4_0 matmul kernel resolved via [KernelRegistry]. Mirrors
* [bf16MatmulKernel]: non-null, picks the highest-priority provider
* that carries a Q4_0 kernel (native FFM at 100, Panama Vector at
* 50), falling back to [ScalarQ4_0MatmulKernel] — the scalar SPI
* kernel is the floor (every `KernelProvider` carries one), so Q4_0
* has no pre-SPI legacy fallback to thread through.
*/
private val q4_0MatmulKernel: Q4_0MatmulKernel by lazy {
if (KernelRegistry.providers().isEmpty()) {
KernelServiceLoader.installAll()
}
KernelRegistry.providers()
.firstOrNull { it.isAvailable() && it.matmulQ4_0() != null }
?.matmulQ4_0()
?: ScalarQ4_0MatmulKernel
}

override fun <T : DType, V> add(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> {
vectorFloatBinary(a, b, { x, y -> x.add(y) }) { x, y -> x + y }?.let { return it }
return super.add(a, b)
Expand Down Expand Up @@ -521,6 +542,22 @@ internal class DefaultCpuOpsJvm(
@Suppress("UNCHECKED_CAST")
CpuTensor(outData as TensorData<T, V>, this, a.dtype)
}
is Q4_0TensorData -> {
val outBuffer = FloatArray(batchSize * outputDim)
for (batch in 0 until batchSize) {
val batchInput = if (batchSize == 1) inputBuffer
else inputBuffer.copyOfRange(batch * inputDim, (batch + 1) * inputDim)
q4_0MatmulKernel.matmul(
batchInput, 0,
bData.packedData, 0,
inputDim, outputDim,
outBuffer, batch * outputDim,
)
}
val outData = DenseFloatArrayTensorData<T>(Shape(batchSize, outputDim), outBuffer)
@Suppress("UNCHECKED_CAST")
CpuTensor(outData as TensorData<T, V>, this, a.dtype)
}
is Q4_KTensorData -> {
val outBuffer = FloatArray(batchSize * outputDim)
val spiKernel = q4kMatmulKernel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class KernelProviderSupportsTest {
p.supports("matmul", listOf("Float32", "Q8_0")),
"Q8_0 matmul support must mirror matmulQ8_0() != null",
)
assertEquals(
p.matmulQ4_0() != null,
p.supports("matmul", listOf("Float32", "Q4_0")),
"Q4_0 matmul support must mirror matmulQ4_0() != null",
)
}

@Test
Expand All @@ -62,6 +67,9 @@ class KernelProviderSupportsTest {
p.matmulQ4K() != null,
p.supports("matmul", listOf("Float32", "Q4_K")),
)
// Scalar carries the Q4_0 floor kernel, so the capability query
// must report it as supported.
assertTrue(p.supports("matmul", listOf("Float32", "Q4_0")))
}

@Test
Expand Down
Loading
Loading