Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ public class sk/ainet/exec/tensor/ops/DefaultCpuOpsBase : sk/ainet/lang/tensor/o
public fun addScalar (Lsk/ainet/lang/tensor/Tensor;Ljava/lang/Number;)Lsk/ainet/lang/tensor/Tensor;
public fun avgPool2d (Lsk/ainet/lang/tensor/Tensor;Lkotlin/Pair;Lkotlin/Pair;Lkotlin/Pair;Z)Lsk/ainet/lang/tensor/Tensor;
protected final fun broadcastShapes (Lsk/ainet/lang/tensor/Shape;Lsk/ainet/lang/tensor/Shape;)Lsk/ainet/lang/tensor/Shape;
protected final fun chooseQuantizedMatmulHeap (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
public fun clamp (Lsk/ainet/lang/tensor/Tensor;FF)Lsk/ainet/lang/tensor/Tensor;
public fun concat (Ljava/util/List;I)Lsk/ainet/lang/tensor/Tensor;
public fun conv1d (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;IIII)Lsk/ainet/lang/tensor/Tensor;
Expand All @@ -192,6 +193,7 @@ public class sk/ainet/exec/tensor/ops/DefaultCpuOpsBase : sk/ainet/lang/tensor/o
public fun divide (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
protected final fun elementwise (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lkotlin/jvm/functions/Function3;)Lsk/ainet/lang/tensor/Tensor;
public fun elu (Lsk/ainet/lang/tensor/Tensor;F)Lsk/ainet/lang/tensor/Tensor;
protected fun ensureKernelProviders ()V
public fun exp (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
public fun expm1 (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
public fun flatten (Lsk/ainet/lang/tensor/Tensor;II)Lsk/ainet/lang/tensor/Tensor;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package sk.ainet.exec.tensor.ops

import sk.ainet.backend.api.kernel.KernelRegistry
import sk.ainet.exec.kernel.ScalarKernelProvider
import sk.ainet.lang.tensor.data.TensorDataFactory
import sk.ainet.lang.tensor.ops.TensorOps

internal actual fun platformDefaultCpuOpsFactory(): (TensorDataFactory) -> TensorOps =
{ factory -> DefaultCpuOps(factory) }
internal actual fun platformDefaultCpuOpsFactory(): (TensorDataFactory) -> TensorOps {
// Non-JVM has no ServiceLoader; register the scalar packed-quant kernels
// (Q4_K/Q6_K/Q5_1/Q5_0/Q8_0/Q4_0) so DefaultCpuOpsBase can dispatch them.
KernelRegistry.register(ScalarKernelProvider)
return { factory -> DefaultCpuOps(factory) }
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package sk.ainet.exec.tensor.ops

import sk.ainet.backend.api.kernel.KernelRegistry
import sk.ainet.exec.kernel.ScalarKernelProvider
import sk.ainet.lang.tensor.data.TensorDataFactory
import sk.ainet.lang.tensor.ops.TensorOps

internal actual fun platformDefaultCpuOpsFactory(): (TensorDataFactory) -> TensorOps {
println("[SKaiNET] Using Accelerate-backed CPU operations (ARM NEON + AMX)")
// Accelerate overrides dense FP32 matmul; packed-quant weights still flow through
// DefaultCpuOpsBase, so register the scalar packed kernels (no ServiceLoader on Native).
KernelRegistry.register(ScalarKernelProvider)
return { factory -> AccelerateCpuOps(factory) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,20 @@ import sk.ainet.lang.types.DType
import sk.ainet.lang.ops.Backend
import sk.ainet.lang.ops.TensorOp
import sk.ainet.lang.ops.InProgress
import sk.ainet.backend.api.kernel.KernelProvider
import sk.ainet.backend.api.kernel.KernelRegistry
import sk.ainet.lang.tensor.data.FloatArrayTensorData
import sk.ainet.lang.tensor.data.IntArrayTensorData
import sk.ainet.lang.tensor.data.Q4_0TensorData
import sk.ainet.lang.tensor.data.Q8_0TensorData
import sk.ainet.lang.tensor.data.Q4_KTensorData
import sk.ainet.lang.tensor.data.Q4_KBlockTensorData
import sk.ainet.lang.tensor.data.Q6_KTensorData
import sk.ainet.lang.tensor.data.Q6_KBlockTensorData
import sk.ainet.lang.tensor.data.Q5_1TensorData
import sk.ainet.lang.tensor.data.Q5_1BlockTensorData
import sk.ainet.lang.tensor.data.Q5_0TensorData
import sk.ainet.lang.tensor.data.Q5_0BlockTensorData
import sk.ainet.lang.tensor.data.TensorData
import sk.ainet.lang.tensor.data.TensorDataFactory
import sk.ainet.lang.tensor.ops.UpsampleMode
Expand Down Expand Up @@ -304,13 +316,74 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
}

@TensorOp()
/**
* Hook to populate [KernelRegistry] before the platform-neutral packed-quant
* dispatch resolves kernels. No-op in the base (callers register providers
* directly, e.g. the non-JVM platform factories register [ScalarKernelProvider]);
* the JVM ops override this to auto-install ServiceLoader-discovered providers.
*/
protected open fun ensureKernelProviders() {}

private fun resolveProvider(test: (KernelProvider) -> Boolean): KernelProvider? {
ensureKernelProviders()
return KernelRegistry.providers().firstOrNull { it.isAvailable() && test(it) }
}

private val q8_0Kernel by lazy { resolveProvider { it.matmulQ8_0() != null }?.matmulQ8_0() }
private val q4_0Kernel by lazy { resolveProvider { it.matmulQ4_0() != null }?.matmulQ4_0() }
private val q4kKernel by lazy { resolveProvider { it.matmulQ4K() != null }?.matmulQ4K() }
private val q6kKernel by lazy { resolveProvider { it.matmulQ6K() != null }?.matmulQ6K() }
private val q5_1Kernel by lazy { resolveProvider { it.matmulQ5_1() != null }?.matmulQ5_1() }
private val q5_0Kernel by lazy { resolveProvider { it.matmulQ5_0() != null }?.matmulQ5_0() }

/**
* Platform-neutral packed-quant matmul: `FP32 input × packed-quant weight`,
* resolving the kernel via [KernelRegistry] (scalar on Native/JS/WASM, Panama/
* native-FFM on JVM). Returns `null` when the weight isn't a heap-packed quant
* type or no provider carries a kernel, so callers fall through. The JVM ops
* intercept Q4_K/Q6_K/Q8_0/Q4_0 (+ MemSeg) before this runs; Q5_1/Q5_0 (and the
* whole set on non-JVM) resolve here.
*/
protected fun <T : DType, V> chooseQuantizedMatmulHeap(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V>? {
if (a.dtype != FP32::class || a.shape.rank != 2 || b.shape.rank != 2) return null
if (a.shape[1] != b.shape[0]) return null
val inputBuffer = (a.data as? FloatArrayTensorData<*>)?.buffer ?: return null
val batchSize = a.shape[0]
val inputDim = a.shape[1]
val outputDim = b.shape[1]

fun run(packed: ByteArray, kernel: (FloatArray, Int, ByteArray, Int, Int, Int, FloatArray, Int) -> Unit): Tensor<T, V> {
val out = FloatArray(batchSize * outputDim)
for (batch in 0 until batchSize) {
val bi = if (batchSize == 1) inputBuffer else inputBuffer.copyOfRange(batch * inputDim, (batch + 1) * inputDim)
kernel(bi, 0, packed, 0, inputDim, outputDim, out, batch * outputDim)
}
@Suppress("UNCHECKED_CAST")
val outData = dataFactory.fromFloatArray<T, Float>(Shape(batchSize, outputDim), a.dtype, out) as TensorData<T, V>
return newTensor(outData, a.dtype, a, b)
}

return when (val bd = b.data) {
is Q5_1TensorData -> q5_1Kernel?.let { k -> run(bd.packedData, k::matmul) }
is Q5_0TensorData -> q5_0Kernel?.let { k -> run(bd.packedData, k::matmul) }
is Q4_KTensorData -> q4kKernel?.let { k -> run(bd.packedData, k::matmul) }
is Q6_KTensorData -> q6kKernel?.let { k -> run(bd.packedData, k::matmul) }
is Q8_0TensorData -> q8_0Kernel?.let { k -> run(bd.packedData, k::matmul) }
is Q4_0TensorData -> q4_0Kernel?.let { k -> run(bd.packedData, k::matmul) }
else -> null
}
}

override fun <T : DType, V> matmul(
a: Tensor<T, V>,
b: Tensor<T, V>
): Tensor<T, V> {
require(a.rank >= 1 && b.rank >= 1) { "Matrix multiplication requires tensors with at least 1 dimension per operand" }
require(a.dtype == b.dtype) { "DType mismatch: ${a.dtype} vs ${b.dtype}" }

// Packed-quant fast path (FP32 input × packed weight), resolved via KernelRegistry.
chooseQuantizedMatmulHeap(a, b)?.let { return it }

// Fast path: 2D × 2D with FloatArray backing — direct buffer access, no per-element allocation
if (a.rank == 2 && b.rank == 2
&& (a.dtype == FP32::class)
Expand Down Expand Up @@ -516,6 +589,22 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
val rows = tensor.shape[rank - 2]
val cols = tensor.shape[rank - 1]

// Lazy transpose for heap-packed quant weights (Q4_K/Q6_K/Q5_1/Q5_0): the
// matmul kernels index the packed bytes input-block-major from the post-swap
// (inputDim, outputDim), so transpose is a pure shape swap — same bytes, no copy.
// Lets `ops.matmul(x, ops.transpose(W))` run on every platform without a dequant
// round-trip. (The JVM ops intercept Q4_K/Q6_K + MemSeg before reaching here.)
if (rank == 2) {
@Suppress("UNCHECKED_CAST")
when (val d = tensor.data) {
is Q4_KTensorData -> return newTensor(Q4_KBlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
is Q6_KTensorData -> return newTensor(Q6_KBlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
is Q5_1TensorData -> return newTensor(Q5_1BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
is Q5_0TensorData -> return newTensor(Q5_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
else -> {}
}
}

// Fast path: 2D float tensor — direct buffer swap
if (rank == 2 && tensor.data is FloatArrayTensorData<*>) {
val buf = (tensor.data as FloatArrayTensorData<*>).buffer
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package sk.ainet.exec.tensor.ops

import kotlin.math.abs
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertTrue
import sk.ainet.context.DirectCpuExecutionContext
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.data.Q4_KBlockTensorData
import sk.ainet.lang.tensor.data.Q5_1BlockTensorData
import sk.ainet.lang.tensor.data.TensorData
import sk.ainet.lang.types.FP32

/**
* End-to-end proof that packed-quant weights flow through `ctx.ops.matmul(x, ops.transpose(W))`
* on EVERY platform — exercising the lazy-transpose shape-swap + `chooseQuantizedMatmulHeap` in
* DefaultCpuOpsBase, resolving the registered kernel (scalar on Native/JS/WASM, Panama/FFM on JVM).
* Runs on jvmTest AND linuxX64Test; a green linuxX64 run is the headline "Native packed matmul works".
*/
class PackedMatmulDispatchTest {

private val ctx = DirectCpuExecutionContext()

private fun half(v: Float): Int {
val b = v.toRawBits(); val s = (b ushr 16) and 0x8000
val e = ((b ushr 23) and 0xFF) - 127 + 15; val m = b and 0x7FFFFF
if (e <= 0) return s; if (e >= 31) return s or 0x7C00
return s or (e shl 10) or (m ushr 13)
}
private fun le16(b: ByteArray, o: Int, h: Int) { b[o] = (h and 0xFF).toByte(); b[o + 1] = ((h ushr 8) and 0xFF).toByte() }

/** Random block-major Q5_1 bytes for [out,in] + the FP32 weight they dequantize to (row-major). */
private fun q5_1(inDim: Int, outDim: Int, rng: Random): Pair<ByteArray, FloatArray> {
val blocks = inDim / 32; val bytes = ByteArray(outDim * blocks * 24); val wf = FloatArray(outDim * inDim)
for (o in 0 until outDim) for (bI in 0 until blocks) {
val off = (bI * outDim + o) * 24; val dst = o * inDim + bI * 32
val d = rng.nextFloat() * 0.05f + 0.01f; val m = rng.nextFloat() - 0.5f
le16(bytes, off, half(d)); le16(bytes, off + 2, half(m))
val qh = IntArray(4) { rng.nextInt(256) }; for (k in 0 until 4) bytes[off + 4 + k] = qh[k].toByte()
for (k in 0 until 16) bytes[off + 8 + k] = rng.nextInt(256).toByte()
for (j in 0 until 16) {
val q = bytes[off + 8 + j].toInt() and 0xFF
val bl = (qh[j / 8] ushr (j % 8)) and 1; val bh = (qh[(j + 16) / 8] ushr ((j + 16) % 8)) and 1
wf[dst + j] = d * ((q and 0xF) + (bl shl 4)) + m; wf[dst + 16 + j] = d * ((q ushr 4) + (bh shl 4)) + m
}
}
return bytes to wf
}

/** Random block-major Q4_K bytes for [out,in] + the FP32 weight. */
private fun q4_k(inDim: Int, outDim: Int, rng: Random): Pair<ByteArray, FloatArray> {
val blocks = inDim / 256; val bytes = ByteArray(outDim * blocks * 144); val wf = FloatArray(outDim * inDim)
for (o in 0 until outDim) for (bI in 0 until blocks) {
val off = (bI * outDim + o) * 144; val dst = o * inDim + bI * 256
val d = rng.nextFloat() * 0.02f + 0.005f; val dMin = rng.nextFloat() * 0.02f + 0.005f
le16(bytes, off, half(d)); le16(bytes, off + 2, half(dMin))
for (k in 0 until 140) bytes[off + 4 + k] = rng.nextInt(256).toByte()
val sc = off + 4; val si = IntArray(8); val mi = IntArray(8)
for (s in 0 until 4) { si[s] = bytes[sc + s].toInt() and 0x3F; mi[s] = bytes[sc + s + 4].toInt() and 0x3F }
for (s in 4 until 8) {
si[s] = (bytes[sc + s + 4].toInt() and 0x0F) or (((bytes[sc + s - 4].toInt() and 0xFF) ushr 6) shl 4)
mi[s] = ((bytes[sc + s + 4].toInt() and 0xFF) ushr 4) or (((bytes[sc + s].toInt() and 0xFF) ushr 6) shl 4)
}
val codes = off + 16
for (g in 0 until 4) for (h in 0 until 2) {
val s = 2 * g + h
for (i in 0 until 32) {
val by = bytes[codes + g * 32 + i].toInt() and 0xFF
val code = if (h == 0) (by and 0x0F) else (by ushr 4)
wf[dst + s * 32 + i] = code * (d * si[s]) - dMin * mi[s]
}
}
}
return bytes to wf
}

private fun run(fmt: String, inDim: Int, outDim: Int, seed: Int) {
val rng = Random(seed)
val (bytes, wf) = if (fmt == "Q5_1") q5_1(inDim, outDim, rng) else q4_k(inDim, outDim, rng)
@Suppress("UNCHECKED_CAST")
val w = ctx.fromData(
(if (fmt == "Q5_1") Q5_1BlockTensorData(Shape(outDim, inDim), bytes)
else Q4_KBlockTensorData(Shape(outDim, inDim), bytes)) as TensorData<FP32, Float>,
FP32::class,
)
val xf = FloatArray(inDim) { rng.nextFloat() - 0.5f }
val x = ctx.fromFloatArray<FP32, Float>(Shape(1, inDim), FP32::class, xf)
val out = ctx.ops.matmul(x, ctx.ops.transpose(w)).data.copyToFloatArray()
val expected = FloatArray(outDim) { o -> var s = 0f; for (i in 0 until inDim) s += xf[i] * wf[o * inDim + i]; s }
var maxErr = 0f; var maxAbs = 1f
for (o in 0 until outDim) { maxErr = maxOf(maxErr, abs(expected[o] - out[o])); maxAbs = maxOf(maxAbs, abs(expected[o])) }
assertTrue(maxErr < 5e-3f * maxAbs, "$fmt e2e matmul deviates: maxErr=$maxErr (maxAbs=$maxAbs)")
}

@Test fun q5_1_through_ops_matmul_transpose() = run("Q5_1", inDim = 128, outDim = 16, seed = 7)
@Test fun q4_k_through_ops_matmul_transpose() = run("Q4_K", inDim = 256, outDim = 12, seed = 8)
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
package sk.ainet.exec.tensor.ops

internal actual fun platformDefaultCpuOpsFactory(): (sk.ainet.lang.tensor.data.TensorDataFactory) -> sk.ainet.lang.tensor.ops.TensorOps =
{ factory -> DefaultCpuOps(factory) }
import sk.ainet.backend.api.kernel.KernelRegistry
import sk.ainet.exec.kernel.ScalarKernelProvider

internal actual fun platformDefaultCpuOpsFactory(): (sk.ainet.lang.tensor.data.TensorDataFactory) -> sk.ainet.lang.tensor.ops.TensorOps {
KernelRegistry.register(ScalarKernelProvider)
return { factory -> DefaultCpuOps(factory) }
}
Loading
Loading