Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ public final class sk/ainet/exec/kernel/PanamaVectorQ4KMatmulKernel : sk/ainet/b
public fun matmul ([FI[BIII[FI)V
}

public final class sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel : sk/ainet/backend/api/kernel/Q4_0MatmulKernel {
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel;
public fun matmul ([FI[BIII[FI)V
}

public final class sk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel : sk/ainet/backend/api/kernel/Q8_0MatmulKernel {
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel;
public fun matmul ([FI[BIII[FI)V
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import sk.ainet.backend.api.kernel.Bf16MatmulKernel
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
import sk.ainet.backend.api.kernel.KernelProvider
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
import sk.ainet.exec.tensor.ops.JvmCpuBackendConfig

Expand Down Expand Up @@ -49,6 +50,9 @@ public object PanamaVectorKernelProvider : KernelProvider {
override fun matmulQ8_0(): Q8_0MatmulKernel? =
if (isAvailable()) PanamaVectorQ8_0MatmulKernel else null

override fun matmulQ4_0(): Q4_0MatmulKernel? =
if (isAvailable()) PanamaVectorQ4_0MatmulKernel else null

private fun isVectorApiClassLoaded(): Boolean = runCatching {
Class.forName("jdk.incubator.vector.FloatVector")
Class.forName("jdk.incubator.vector.VectorSpecies")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package sk.ainet.exec.kernel

import jdk.incubator.vector.FloatVector
import jdk.incubator.vector.VectorOperators
import jdk.incubator.vector.VectorSpecies
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel

/**
* SIMD-vectorized FP32 × Q4_0 matmul on the JDK Vector API.
*
* Pipeline per 32-element block:
* 1. Decode the 2-byte FP16 scale `d` once.
* 2. Unpack the 16 code bytes into 32 sign-corrected floats (`nibble - 8`)
* in a reusable scratch buffer, using the canonical ggml **split**
* layout (low nibbles → elements 0..15, high nibbles → 16..31). The
* nibble-pair-per-byte packing makes a fully-fused `ByteVector`
* pipeline awkward, so this kernel keeps the scratch-then-FMA shape
* (same approach as the legacy `JvmQuantizedVectorKernels` Q4_0 path).
* 3. SIMD-FMA the scratch against the matching input window into a
* lane-wise block accumulator, reduce across lanes, and fold `* d`
* exactly once per block.
*
* Numerical equivalence with [ScalarQ4_0MatmulKernel] is within FMA +
* reordered-reduction tolerance — the same bar the Q8_0 / Q4_K Panama
* kernels use.
*/
public object PanamaVectorQ4_0MatmulKernel : Q4_0MatmulKernel {

private const val BLOCK_SIZE = 32
private const val BYTES_PER_BLOCK = 18

private val floatSpecies: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED

override fun matmul(
input: FloatArray, inputOffset: Int,
weight: ByteArray, weightByteOffset: Int,
inputDim: Int, outputDim: Int,
output: FloatArray, outputOffset: Int,
) {
require(inputDim % BLOCK_SIZE == 0) {
"PanamaVectorQ4_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
}
if (outputDim == 0) return
if (inputDim == 0) {
for (o in 0 until outputDim) output[outputOffset + o] = 0f
return
}
val blocksPerInputDim = inputDim / BLOCK_SIZE
val step = floatSpecies.length()
val loopBound = floatSpecies.loopBound(BLOCK_SIZE)
val codeBuf = FloatArray(BLOCK_SIZE)

for (o in 0 until outputDim) {
var acc = 0f
for (blockIdx in 0 until blocksPerInputDim) {
val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK
// FP16 scale — two LE bytes.
val dBits = (weight[blockBase].toInt() and 0xFF) or
((weight[blockBase + 1].toInt() and 0xFF) shl 8)
val d = halfToFloat(dBits)

// Split-layout unpack: low nibbles → 0..15, high → 16..31.
val codesBase = blockBase + 2
for (j in 0 until 16) {
val b = weight[codesBase + j].toInt() and 0xFF
codeBuf[j] = ((b and 0x0F) - 8).toFloat()
codeBuf[16 + j] = ((b ushr 4) - 8).toFloat()
}

val inputBase = inputOffset + blockIdx * BLOCK_SIZE
var blockAccVec = FloatVector.zero(floatSpecies)
var k = 0
while (k < loopBound) {
val inV = FloatVector.fromArray(floatSpecies, input, inputBase + k)
val cV = FloatVector.fromArray(floatSpecies, codeBuf, k)
blockAccVec = inV.fma(cV, blockAccVec)
k += step
}
var blockAcc = blockAccVec.reduceLanes(VectorOperators.ADD)
// Scalar tail (only if floatSpecies.length() doesn't divide 32 — rare).
while (k < BLOCK_SIZE) {
blockAcc += input[inputBase + k] * codeBuf[k]
k++
}
acc += blockAcc * d
}
output[outputOffset + o] = acc
}
}

/** Same FP16 → FP32 conversion as [ScalarQ4_0MatmulKernel]. */
private fun halfToFloat(hbits: Int): Float {
val sign = (hbits and 0x8000) shl 16
val exp = (hbits and 0x7C00) shr 10
val mant = hbits and 0x03FF
return when (exp) {
0 -> {
if (mant == 0) Float.fromBits(sign)
else {
var m = mant
var e = -14
while ((m and 0x400) == 0) {
m = m shl 1
e--
}
m = m and 0x3FF
Float.fromBits(sign or ((e + 127) shl 23) or (m shl 13))
}
}
31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13))
else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -549,13 +549,14 @@ internal object JvmQuantizedVectorKernels {
// Read f16 scale
val scale = halfToFloat(read2BytesLE(weightSeg, blockByteOffset))

// Unpack 16 packed bytes → 32 sign-corrected nibbles. Two
// nibbles per byte load means half the byte traffic of the
// straight scalar dot product.
// Unpack 16 packed bytes → 32 sign-corrected nibbles in the
// canonical ggml *split* layout: low nibbles decode elements
// 0..15, high nibbles decode elements 16..31. (Matches
// DequantOps.dequantQ4_0FromBytes and Q4_0BlockTensorData.)
for (k in 0 until 16) {
val b = weightSeg.get(JAVA_BYTE_LE, codesOffset + k.toLong()).toInt() and 0xFF
codeBuf[2 * k] = (b and 0x0F).toFloat() - 8f
codeBuf[2 * k + 1] = (b ushr 4).toFloat() - 8f
codeBuf[k] = (b and 0x0F).toFloat() - 8f
codeBuf[16 + k] = (b ushr 4).toFloat() - 8f
}

// SIMD FMA dot product.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package sk.ainet.exec.kernel

import kotlin.math.abs
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue

/**
* Numerical parity tests for [PanamaVectorQ4_0MatmulKernel] against
* [ScalarQ4_0MatmulKernel]. Both kernels apply the same FP16-scale
* decode + `(nibble - 8)` dequant in the canonical ggml split layout;
* differences come from FMA + reordered-reduction order only.
*
* Tolerance scales with the number of Q4_0 blocks processed: `1e-2 *
* blocksPerInputDim`, clamped to a `1e-2` floor — mirrors the Q8_0
* parity test convention.
*/
class PanamaVectorQ4_0MatmulKernelParityTest {

private val blockSize = 32
private val bytesPerBlock = 18

/** Random Q4_0 packed bytes; scales clamped to a small positive FP16. */
private fun randomQ4_0Bytes(blocksPerInputDim: Int, outputDim: Int, seed: Int): ByteArray {
val rng = Random(seed)
val numBlocks = blocksPerInputDim * outputDim
val bytes = ByteArray(numBlocks * bytesPerBlock)
rng.nextBytes(bytes)
for (block in 0 until numBlocks) {
val base = block * bytesPerBlock
bytes[base + 0] = 0x00.toByte()
bytes[base + 1] = 0x22.toByte() // FP16 0x2200 ≈ 7.6e-3
}
return bytes
}

private fun assertParity(
inputDim: Int,
outputDim: Int,
seed: Int,
tolPerBlock: Float = 1e-2f,
) {
val blocksPerInputDim = inputDim / blockSize
val rng = Random(seed)
val input = FloatArray(inputDim) { rng.nextFloat() - 0.5f }
val weight = randomQ4_0Bytes(blocksPerInputDim, outputDim, seed)
val outScalar = FloatArray(outputDim)
val outPanama = FloatArray(outputDim)

ScalarQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, outScalar, 0)
PanamaVectorQ4_0MatmulKernel.matmul(input, 0, weight, 0, inputDim, outputDim, outPanama, 0)

val tol = (tolPerBlock * blocksPerInputDim.coerceAtLeast(1)).coerceAtLeast(tolPerBlock)
for (i in outScalar.indices) {
val diff = abs(outScalar[i] - outPanama[i])
assertTrue(
diff <= tol,
"mismatch at $i: scalar=${outScalar[i]} panama=${outPanama[i]} diff=$diff tol=$tol",
)
}
}

@Test fun single_block_single_output_matches_scalar() =
assertParity(inputDim = 32, outputDim = 1, seed = 1)

@Test fun single_block_multiple_outputs_matches_scalar() =
assertParity(inputDim = 32, outputDim = 7, seed = 2)

@Test fun multiple_blocks_single_output_matches_scalar() =
assertParity(inputDim = 256, outputDim = 1, seed = 3)

@Test fun llm_typical_attention_proj_matches_scalar() =
assertParity(inputDim = 512, outputDim = 512, seed = 4)

@Test fun llm_typical_ffn_proj_matches_scalar() =
assertParity(inputDim = 256, outputDim = 1024, seed = 5)

@Test fun rejects_non_block_aligned_input_dim() {
assertFailsWith<IllegalArgumentException> {
PanamaVectorQ4_0MatmulKernel.matmul(
FloatArray(31), 0,
ByteArray(bytesPerBlock), 0,
31, 1,
FloatArray(1), 0,
)
}
}

@Test fun zero_input_dim_zeros_output() {
val out = FloatArray(5) { 9f }
PanamaVectorQ4_0MatmulKernel.matmul(
FloatArray(0), 0,
ByteArray(0), 0,
0, 5,
out, 0,
)
for (v in out) assertEquals(0f, v, "output should be zeroed for inputDim=0")
}

@Test fun provider_returns_panama_q4_0_when_available() {
val kernel = PanamaVectorKernelProvider.matmulQ4_0()
if (PanamaVectorKernelProvider.isAvailable()) {
assertTrue(
kernel === PanamaVectorQ4_0MatmulKernel,
"Provider must hand out the Panama Q4_0 kernel when available",
)
} else {
assertEquals(null, kernel, "Provider must return null when Vector API unavailable")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class QuantizedMemSegMatmulTest {

/**
* Encode a single Q4_0 block: 32 float values -> 18 bytes (2 scale + 16 packed nibbles).
* Uses the canonical ggml *split* layout: code[j] is the low nibble of
* byte j, code[j+16] is the high nibble of byte j.
*/
private fun encodeQ4_0Block(values: FloatArray): ByteArray {
require(values.size == 32)
Expand All @@ -62,8 +64,8 @@ class QuantizedMemSegMatmulTest {
val out = ByteArray(18)
out[0] = (scaleHalf and 0xFF).toByte()
out[1] = ((scaleHalf shr 8) and 0xFF).toByte()
for (i in 0 until 16) {
out[2 + i] = ((codes[2 * i + 1] shl 4) or codes[2 * i]).toByte()
for (j in 0 until 16) {
out[2 + j] = ((codes[j + 16] shl 4) or codes[j]).toByte()
}
return out
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_library(skainet_kernels SHARED
src/fp32_matmul.c
src/bf16_matmul.c
src/q8_0_matmul.c
src/q4_0_matmul.c
)

target_include_directories(skainet_kernels PUBLIC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,32 @@ SKAINET_API void skainet_q8_0_matmul(
int32_t output_offset
);

/*
* Q4_0 matrix-vector multiply.
*
* output[output_offset + o] = sum_j input[input_offset + j] *
* dequant(weight[block, o, j])
*
* Block layout: canonical ggml Q4_0, 32 elements per block, 18 bytes
* per block (2 B FP16 scale + 16 B packed 4-bit codes in split layout —
* low nibbles → elements 0..15, high nibbles → 16..31), with packed
* weights laid out as
* weight + weight_byte_offset + (block_idx * output_dim + o) * 18
*
* Dequant per element: `(code - 8) * d`. input_dim must be a multiple
* of 32.
*/
SKAINET_API void skainet_q4_0_matmul(
const float* input,
int32_t input_offset,
const uint8_t* weight,
int32_t weight_byte_offset,
int32_t input_dim,
int32_t output_dim,
float* output,
int32_t output_offset
);

#ifdef __cplusplus
}
#endif
Expand Down
Loading
Loading