Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import sk.ainet.lang.tensor.data.Q5_1TensorData
import sk.ainet.lang.tensor.data.Q5_1BlockTensorData
import sk.ainet.lang.tensor.data.Q5_0TensorData
import sk.ainet.lang.tensor.data.Q5_0BlockTensorData
import sk.ainet.lang.tensor.data.Q4_0BlockTensorData
import sk.ainet.lang.tensor.data.Q8_0BlockTensorData
import sk.ainet.lang.tensor.data.TensorData
import sk.ainet.lang.tensor.data.TensorDataFactory
Expand Down Expand Up @@ -607,12 +608,15 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
is Q6_KTensorData -> return newTensor(Q6_KBlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
is Q5_1TensorData -> return newTensor(Q5_1BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
is Q5_0TensorData -> return newTensor(Q5_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
// Q8_0 / Q4_0 lazy transpose: rewrap the same input-block-major bytes
// with flipped shape (bytes are layout-agnostic to the [out,in] kernel
// convention) so a packed weight (e.g. gemma's tied Q8_0 lm_head)
// survives linearProject's transpose instead of hitting the generic
// FP32 path (Byte→Float ClassCastException). This `when` now covers
// every quant type chooseQuantizedMatmulHeap dispatches — i.e. every
// packed type that can be a matmul weight. See transformers #178.
is Q8_0TensorData -> return newTensor(Q8_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
is Q4_0TensorData -> return newTensor(Q4_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
else -> {}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@ package sk.ainet.exec.tensor.ops
import kotlin.math.abs
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import sk.ainet.context.DirectCpuExecutionContext
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.data.Q4_0BlockTensorData
import sk.ainet.lang.tensor.data.Q4_KBlockTensorData
import sk.ainet.lang.tensor.data.Q5_0BlockTensorData
import sk.ainet.lang.tensor.data.Q5_1BlockTensorData
import sk.ainet.lang.tensor.data.Q5_KBlockTensorData
import sk.ainet.lang.tensor.data.Q6_KBlockTensorData
import sk.ainet.lang.tensor.data.Q8_0BlockTensorData
import sk.ainet.lang.tensor.data.TensorData
import sk.ainet.lang.types.FP32

Expand Down Expand Up @@ -129,4 +134,40 @@ class PackedMatmulDispatchTest {
/** Runs the class's shared `run` helper for the "Q5_1" case (inDim 128, outDim 16, seed 7). */
@Test
fun q5_1_through_ops_matmul_transpose() =
    run(
        "Q5_1",
        inDim = 128,
        outDim = 16,
        seed = 7,
    )
/** Runs the class's shared `run` helper for the "Q4_K" case (inDim 256, outDim 12, seed 8). */
@Test
fun q4_k_through_ops_matmul_transpose() =
    run(
        "Q4_K",
        inDim = 256,
        outDim = 12,
        seed = 8,
    )
/** Runs the class's shared `run` helper for the "Q6_K" case (inDim 512, outDim 8, seed 9). */
@Test
fun q6_k_through_ops_matmul_transpose() =
    run(
        "Q6_K",
        inDim = 512,
        outDim = 8,
        seed = 9,
    )

/**
 * Regression guard for transformers #178 (Q8_0/Q4_0 were the gaps).
 *
 * `ops.transpose` must lazily rewrap EVERY packed quant type that can be a
 * matmul weight (the full `chooseQuantizedMatmulHeap` set) — flipping the
 * shape while keeping the same packed encoding — instead of falling into the
 * generic FP32 path, which casts the Byte-backed buffer to Float and throws
 * `ClassCastException`.
 *
 * The check is content-agnostic: each weight is an all-zero byte buffer sized
 * from its block geometry (one block per row, `outDim` rows). For every case
 * the test asserts that
 *  - the transposed shape is `[inDim, outDim]`,
 *  - the result is still backed by a `*Block*` packed container, and
 *  - that container is the SAME class as the input's, so a branch that
 *    rewraps the bytes as the wrong quant type is caught too (the name-only
 *    "Block" check cannot tell Q4_0 from Q8_0).
 */
@Test
fun transpose_preserves_every_packed_quant_type() {
    val outDim = 8
    // name -> (blockElems, bytesPerBlock, builder)
    val cases: List<Triple<String, Pair<Int, Int>, (Shape, ByteArray) -> TensorData<FP32, Float>>> = listOf(
        Triple("Q4_K", 256 to 144) { s, b -> Q4_KBlockTensorData(s, b) as TensorData<FP32, Float> },
        Triple("Q5_K", 256 to 176) { s, b -> Q5_KBlockTensorData(s, b) as TensorData<FP32, Float> },
        Triple("Q6_K", 256 to 210) { s, b -> Q6_KBlockTensorData(s, b) as TensorData<FP32, Float> },
        Triple("Q8_0", 32 to 34) { s, b -> Q8_0BlockTensorData(s, b) as TensorData<FP32, Float> },
        Triple("Q4_0", 32 to 18) { s, b -> Q4_0BlockTensorData(s, b) as TensorData<FP32, Float> },
        Triple("Q5_0", 32 to 22) { s, b -> Q5_0BlockTensorData(s, b) as TensorData<FP32, Float> },
        Triple("Q5_1", 32 to 24) { s, b -> Q5_1BlockTensorData(s, b) as TensorData<FP32, Float> },
    )
    for ((name, geom, build) in cases) {
        val (blockElems, bpb) = geom
        val inDim = blockElems // one block per row
        val bytes = ByteArray(outDim * (inDim / blockElems) * bpb)
        val w = ctx.fromData(build(Shape(outDim, inDim), bytes), FP32::class)
        // The bug threw here for unhandled packed types.
        val t = ctx.ops.transpose(w)
        assertEquals(Shape(inDim, outDim), t.shape, "$name: transpose did not flip shape")
        assertTrue(
            t.data::class.simpleName?.contains("Block") == true,
            "$name: transpose dropped the packed encoding (got ${t.data::class.simpleName})",
        )
        // Stronger than the name check: the packed container type itself must
        // survive, otherwise the bytes would be decoded with the wrong layout.
        assertEquals(
            w.data::class,
            t.data::class,
            "$name: transpose rewrapped the bytes as a different packed type",
        )
    }
}
}
Loading