From 1e08b4208665ec31113c4335367e46626261b51e Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 08:25:32 +0200 Subject: [PATCH 01/10] feat(hlo): DSL prescribes element dtype for placeholder weights MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Re-applied on develop (0.25.0). Norm layers created void placeholder weights with `Any::class as KClass`, erasing the element type to Object — which breaks weight-free graph tracing (VoidTensorOps alloc ops -> zeros(Object) throw "Unsupported dtype: Object"), blocking DAG->StableHLO before weights load. Fix the root cause: the DSL prescribes the logical element type. Layer/Group/ BatchNormalization get a non-breaking `dtype: KClass? = null` (used as `(dtype ?: Any::class)`); the NetworkBuilder DSL builders pass `dtype = kClass`. Real (possibly dequantized) weights still override at load time. (RMSNormalization moved out of core in 0.25.0 -> handled transformer-side.) Co-Authored-By: Claude Opus 4.8 (1M context) --- .../sk/ainet/lang/nn/dsl/NetworkBuilder.kt | 18 ++++++++++++------ .../nn/normalization/BatchNormalization.kt | 8 +++++--- .../nn/normalization/GroupNormalization.kt | 8 +++++--- .../nn/normalization/LayerNormalization.kt | 8 +++++--- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt index 40dd6ac4..4605d2af 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt @@ -1299,7 +1299,8 @@ public class StageImpl( eps = eps, momentum = momentum, affine = affine, - name = getDefaultName(id, "BatchNorm", modules.size) + name = getDefaultName(id, "BatchNorm", modules.size), + dtype = kClass ) ) } @@ -1317,7 +1318,8 @@ 
public class StageImpl( numChannels = numChannels, eps = eps, affine = affine, - name = getDefaultName(id, "GroupNorm", modules.size) + name = getDefaultName(id, "GroupNorm", modules.size), + dtype = kClass ) ) } @@ -1333,7 +1335,8 @@ public class StageImpl( normalizedShape = normalizedShape, eps = eps, elementwiseAffine = elementwiseAffine, - name = getDefaultName(id, "LayerNorm", modules.size) + name = getDefaultName(id, "LayerNorm", modules.size), + dtype = kClass ) ) } @@ -1675,7 +1678,8 @@ public class NeuralNetworkDslImpl( eps = eps, momentum = momentum, affine = affine, - name = getDefaultName(id, "BatchNorm", modules.size) + name = getDefaultName(id, "BatchNorm", modules.size), + dtype = kClass ) ) } @@ -1693,7 +1697,8 @@ public class NeuralNetworkDslImpl( numChannels = numChannels, eps = eps, affine = affine, - name = getDefaultName(id, "GroupNorm", modules.size) + name = getDefaultName(id, "GroupNorm", modules.size), + dtype = kClass ) ) } @@ -1709,7 +1714,8 @@ public class NeuralNetworkDslImpl( normalizedShape = normalizedShape, eps = eps, elementwiseAffine = elementwiseAffine, - name = getDefaultName(id, "LayerNorm", modules.size) + name = getDefaultName(id, "LayerNorm", modules.size), + dtype = kClass ) ) } diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/BatchNormalization.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/BatchNormalization.kt index 8540da78..943c617a 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/BatchNormalization.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/BatchNormalization.kt @@ -31,7 +31,9 @@ public class BatchNormalization( private val affine: Boolean = true, override val name: String = "BatchNormalization", initGamma: Tensor? = null, - initBeta: Tensor? = null + initBeta: Tensor? 
= null, + // Logical element type prescribed by the DSL; keeps placeholder weights typed. + private val dtype: KClass? = null, ) : Module(), ModuleParameters { // Running statistics for inference mode @@ -61,7 +63,7 @@ public class BatchNormalization( override fun get(vararg indices: Int): V = 1.0f as V override fun set(vararg indices: Int, value: V) {} }, - Any::class as KClass + (dtype ?: Any::class) as KClass ) } @@ -75,7 +77,7 @@ public class BatchNormalization( override fun get(vararg indices: Int): V = 0.0f as V override fun set(vararg indices: Int, value: V) {} }, - Any::class as KClass + (dtype ?: Any::class) as KClass ) } diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/GroupNormalization.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/GroupNormalization.kt index fbf247c2..4140e515 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/GroupNormalization.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/GroupNormalization.kt @@ -29,7 +29,9 @@ public class GroupNormalization( private val affine: Boolean = true, override val name: String = "GroupNormalization", initGamma: Tensor? = null, - initBeta: Tensor? = null + initBeta: Tensor? = null, + // Logical element type prescribed by the DSL; keeps placeholder weights typed. + private val dtype: KClass? 
= null, ) : Module(), ModuleParameters { init { @@ -62,7 +64,7 @@ public class GroupNormalization( override fun get(vararg indices: Int): V = 1.0f as V override fun set(vararg indices: Int, value: V) {} }, - Any::class as KClass + (dtype ?: Any::class) as KClass ) } @@ -76,7 +78,7 @@ public class GroupNormalization( override fun get(vararg indices: Int): V = 0.0f as V override fun set(vararg indices: Int, value: V) {} }, - Any::class as KClass + (dtype ?: Any::class) as KClass ) } diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/LayerNormalization.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/LayerNormalization.kt index 1522e4a5..18d27a34 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/LayerNormalization.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/LayerNormalization.kt @@ -27,7 +27,9 @@ public class LayerNormalization( private val elementwiseAffine: Boolean = true, override val name: String = "LayerNormalization", initGamma: Tensor? = null, - initBeta: Tensor? = null + initBeta: Tensor? = null, + // Logical element type prescribed by the DSL; keeps placeholder weights typed. + private val dtype: KClass? 
= null, ) : Module(), ModuleParameters { override val params: List> = if (elementwiseAffine) { @@ -52,7 +54,7 @@ public class LayerNormalization( override fun get(vararg indices: Int): V = 1.0f as V override fun set(vararg indices: Int, value: V) {} }, - Any::class as KClass + (dtype ?: Any::class) as KClass ) } @@ -66,7 +68,7 @@ public class LayerNormalization( override fun get(vararg indices: Int): V = 0.0f as V override fun set(vararg indices: Int, value: V) {} }, - Any::class as KClass + (dtype ?: Any::class) as KClass ) } From ced96c08131616e9df75e3050a6354a3660eba61 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 08:44:14 +0200 Subject: [PATCH 02/10] feat(hlo): add permute + narrow StableHLO converters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the last gaps for lowering a gemma3 attention block to StableHLO (the other ops — matmul/reshape/transpose/squeeze/unsqueeze/concat/softmax/gather/ sqrt/addScalar/scaledDotProductAttention — are already covered on develop). - permute: arbitrary-axis transpose. convertTranspose already reads the `axes` parameter, so register `permute` as an alias routed to it. - narrow(dim,start,length): single-axis stablehlo.slice. Reads dim/start/length (the keys the graph tape records), builds start/limit/stride per dim. - NarrowPermuteConverterTest verifies both: permute -> transpose dims=[0,2,1], narrow -> slice start_indices=[0,2] limit_indices=[2,6]. Full hlo suite green. Remaining for full gemma: `split` (multi-output, used by RoPE). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../converters/LinalgOperationsConverter.kt | 7 +- .../converters/ShapeOperationsConverter.kt | 64 ++++++++++++++- .../compile/hlo/NarrowPermuteConverterTest.kt | 77 +++++++++++++++++++ 3 files changed, 145 insertions(+), 3 deletions(-) create mode 100644 skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt index 1f5744de..ecc78757 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt @@ -26,7 +26,10 @@ public class LinalgOperationsConverter : StableHloOperationConverter { override val supportedOperations: Set = setOf( "matmul", "transpose", // Common aliases - "dot", "mm", "bmm", "batch_matmul" + "dot", "mm", "bmm", "batch_matmul", + // permute is an arbitrary-axis transpose; convertTranspose already + // reads the `axes` parameter, so route it through the same lowering. + "permute" ) override fun convert( @@ -41,7 +44,7 @@ public class LinalgOperationsConverter : StableHloOperationConverter { // where `[1] x [0]` is wrong — contract last/second-to-last. 
"matmul", "dot", "mm", "bmm", "batch_matmul" -> convertBatchMatmul(node, operands, context) - "transpose" -> convertTranspose(node, operands, context) + "transpose", "permute" -> convertTranspose(node, operands, context) else -> ConversionResult.Unsupported( node.operation.name, "Operation not supported by LinalgOperationsConverter" diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt index 86608118..a21303f4 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt @@ -27,7 +27,10 @@ public class ShapeOperationsConverter : StableHloOperationConverter { // flatten / squeeze. concat glues tensors along an axis, // slice extracts a static window of a tensor. "concat", "concatenate", "cat", "stack", - "slice" + "slice", + // narrow(dim, start, length) is a single-axis slice — RoPE / attention + // head splitting use it heavily. + "narrow" ) override fun convert( @@ -42,6 +45,7 @@ public class ShapeOperationsConverter : StableHloOperationConverter { "unsqueeze" -> convertUnsqueeze(node, operands, context) "concat", "concatenate", "cat", "stack" -> convertConcat(node, operands, context) "slice" -> convertSlice(node, operands, context) + "narrow" -> convertNarrow(node, operands, context) else -> ConversionResult.Unsupported( node.operation.name, "Operation not supported by ShapeOperationsConverter" @@ -150,6 +154,64 @@ public class ShapeOperationsConverter : StableHloOperationConverter { ) } + /** + * Convert narrow(dim, start, length) to a single-axis stablehlo.slice. + * + * narrow keeps `[start, start+length)` along `dim` and the full extent of + * every other axis. 
Reads `dim`/`start`/`length` from parameters (the keys + * the graph tape records); falls back to the output shape for `length`. + * + * %out = stablehlo.slice %x {start_indices=[..], limit_indices=[..], strides=[..]} : + */ + private fun convertNarrow( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + if (operands.size != 1) { + return ConversionResult.Failure( + "Narrow operation requires exactly 1 operand, got ${operands.size}", + "Unsupported narrow arity for node ${node.id}" + ) + } + + val outputSpec = node.outputs.firstOrNull() + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + + val inputShape = node.inputs.firstOrNull()?.shape ?: emptyList() + val rank = inputShape.size + if (rank == 0) { + return ConversionResult.Failure( + "Narrow requires a known input rank", + "Missing input shape for narrow node ${node.id}" + ) + } + + val rawDim = node.operation.parameters["dim"] as? Int ?: 0 + val dim = if (rawDim < 0) rank + rawDim else rawDim + val start = node.operation.parameters["start"] as? Int ?: 0 + val length = node.operation.parameters["length"] as? Int + ?: outputSpec?.shape?.getOrNull(dim) + ?: (inputShape[dim] - start) + + val starts = List(rank) { if (it == dim) start else 0 } + val limits = List(rank) { if (it == dim) start + length else inputShape[it] } + val strides = List(rank) { 1 } + + val resultValue = context.nextTempValue() + val operation = "$resultValue = stablehlo.slice ${operands[0]} " + + "{start_indices = [${starts.joinToString(", ")}], " + + "limit_indices = [${limits.joinToString(", ")}], " + + "strides = [${strides.joinToString(", ")}]} : $outputType" + context.emitOperation(operation) + + return ConversionResult.Success( + outputValueName = resultValue, + emittedOperations = listOf(operation) + ) + } + /** * Convert reshape operation using stablehlo.reshape. * Handles both static and dynamic shape specifications. 
diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt new file mode 100644 index 00000000..19bc033a --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt @@ -0,0 +1,77 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.ops.InputOperation +import sk.ainet.lang.tensor.ops.Operation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.types.DType +import kotlin.test.Test +import kotlin.test.assertTrue + +/** + * Covers the converters added for the transformer (gemma3) export path: + * - permute -> stablehlo.transpose (arbitrary axes) + * - narrow(dim,start,length) -> single-axis stablehlo.slice + */ +class NarrowPermuteConverterTest { + + private fun opNode( + id: String, + opName: String, + opType: String, + params: Map, + input: TensorSpec, + output: TensorSpec, + ): GraphNode = GraphNode( + id = id, + operation = object : Operation { + override val name = opName + override val type = opType + override val parameters = params + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test op") + override fun validateInputs(inputs: List): sk.ainet.lang.tensor.ops.ValidationResult = + sk.ainet.lang.tensor.ops.ValidationResult.Valid + override fun inferOutputs(inputs: List): List = listOf(output) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = params + }, + inputs = listOf(input), + outputs = listOf(output), + ) + + @Test + fun permuteLowersToTranspose() { + val g = DefaultComputeGraph() + val a = GraphNode("a", InputOperation(), emptyList(), listOf(TensorSpec("a", listOf(2, 3, 4), "FP32"))) + val p = 
opNode( + "p", "permute", "linalg", mapOf("axes" to listOf(0, 2, 1)), + TensorSpec("a", listOf(2, 3, 4), "FP32"), TensorSpec("b", listOf(2, 4, 3), "FP32"), + ) + g.addNode(a); g.addNode(p) + g.addEdge(GraphEdge("e1", a, p, 0, 0, a.outputs[0])) + + val mlir = StableHloConverterFactory.createBasic().convert(g, "permute_test").content + assertTrue(mlir.contains("stablehlo.transpose"), "expected transpose in:\n$mlir") + assertTrue(mlir.contains("dims = [0, 2, 1]"), "expected axes [0,2,1] in:\n$mlir") + } + + @Test + fun narrowLowersToSlice() { + val g = DefaultComputeGraph() + val a = GraphNode("a", InputOperation(), emptyList(), listOf(TensorSpec("a", listOf(2, 8), "FP32"))) + val n = opNode( + "n", "narrow", "shape", mapOf("dim" to 1, "start" to 2, "length" to 4), + TensorSpec("a", listOf(2, 8), "FP32"), TensorSpec("b", listOf(2, 4), "FP32"), + ) + g.addNode(a); g.addNode(n) + g.addEdge(GraphEdge("e1", a, n, 0, 0, a.outputs[0])) + + val mlir = StableHloConverterFactory.createBasic().convert(g, "narrow_test").content + assertTrue(mlir.contains("stablehlo.slice"), "expected slice in:\n$mlir") + assertTrue(mlir.contains("start_indices = [0, 2]"), "expected start [0,2] in:\n$mlir") + assertTrue(mlir.contains("limit_indices = [2, 6]"), "expected limit [2,6] in:\n$mlir") + } +} From 63dd06cb10d22f4557b53eb20a6f78dfd74b5215 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 16:04:53 +0200 Subject: [PATCH 03/10] fix(trace): VoidTensorOps.gather output shape for multi-dim indices MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The void/tracing gather collapsed the gathered axis to indices.shape[0], so a [vocab,emb] table with [batch,seq] indices traced to [batch,emb] instead of [batch,seq,emb] — breaking the embedding's downstream reshape ("volume mismatch 64 != 256") during weight-free tracing. Replace the axis with the FULL indices shape, matching DefaultCpuOps.gather. 
Unblocks tracing full transformer (gemma3) graphs to a ComputeGraph. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt index 225af4b8..92dad4da 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt @@ -466,9 +466,15 @@ public class VoidTensorOps : TensorOps { } override fun gather(input: Tensor, indices: Tensor, dim: Int): Tensor { - // Gather selects rows along dim — output shape replaces dim with indices length - val resultDims = input.shape.dimensions.copyOf() - resultDims[dim] = indices.shape.dimensions[0] + // Gather selects rows along `dim`, replacing that axis with the FULL + // indices shape (not just its first dim). Matches DefaultCpuOps.gather: + // for a [vocab, emb] table and [batch, seq] indices the result is + // [batch, seq, emb]. The previous `resultDims[dim] = indices.shape[0]` + // collapsed multi-dim indices to a single row, breaking the downstream + // reshape (e.g. embedding lookup) during weight-free tracing. + val inDims = input.shape.dimensions + val idxDims = indices.shape.dimensions + val resultDims = inDims.copyOfRange(0, dim) + idxDims + inDims.copyOfRange(dim + 1, inDims.size) val resultData = dataFactory.zeros(Shape(resultDims), input.dtype) return VoidOpsTensor(resultData, input.dtype) } From e42c85a981ff51e0a2958d110a31826caaef6ffb Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 16:19:01 +0200 Subject: [PATCH 04/10] feat(hlo): multi-output converter support + split converter Lowers the last gemma3 op gap (split, used by RoPE). 
- ConversionContext: per-(nodeId, outputPort) SSA value names (port 0 keeps the bare-nodeId key, so all single-output callers are unchanged) + resolveOperands(node) that walks incoming edges in destinationInputIndex order and resolves each by the edge's sourceOutputIndex. - StableHloConverter: resolve operands via resolveOperands so a consumer of a multi-output op gets the right output (e.g. split chunk N, not chunk 0). - ShapeOperationsConverter: split/chunk -> N stablehlo.slice, each registered on its own output port. - Test: split -> 2 chunk slices + a relu consuming chunk 1 resolves to chunk 1. Full hlo suite green (operand-resolution change is equivalent for single-output). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../sk/ainet/compile/hlo/ConversionContext.kt | 28 +++++++ .../ainet/compile/hlo/StableHloConverter.kt | 8 +- .../converters/ShapeOperationsConverter.kt | 73 ++++++++++++++++++- .../compile/hlo/NarrowPermuteConverterTest.kt | 41 +++++++++++ 4 files changed, 146 insertions(+), 4 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt index d288fc25..3b335be0 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConversionContext.kt @@ -44,6 +44,34 @@ public class ConversionContext @kotlin.jvm.JvmOverloads constructor( valueNames[nodeId] = valueName } + // --- Multi-output support ------------------------------------------------ + // A node may produce several results (e.g. `split` -> N chunks). Each output + // port gets its own SSA name; port 0 stays keyed by the bare nodeId so all + // existing single-output callers are unchanged. 
+ private fun key(nodeId: String, port: Int): String = if (port == 0) nodeId else "$nodeId#$port" + + /** Set the SSA value name for a specific output port of a node. */ + public fun setValueName(nodeId: String, port: Int, valueName: String) { + valueNames[key(nodeId, port)] = valueName + } + + /** Get the SSA value name for a specific output port of a node. */ + public fun getValueName(nodeId: String, port: Int): String? = valueNames[key(nodeId, port)] + + /** + * Resolve a node's input operands in input-port order, honoring the source + * output port of each incoming edge (so a consumer of `split`'s chunk N gets + * chunk N, not chunk 0). Equivalent to the old node-based resolution for + * single-output producers (all source ports are 0). + */ + public fun resolveOperands(node: GraphNode): List { + val g = graph ?: return emptyList() + return g.edges + .filter { it.destination.id == node.id } + .sortedBy { it.destinationInputIndex } + .mapNotNull { getValueName(it.source.id, it.sourceOutputIndex) } + } + /** * Record the MLIR tensor type associated with an SSA value name. 
* diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt index 7bc6940e..9f2e428c 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverter.kt @@ -194,9 +194,11 @@ public class StableHloConverter @kotlin.jvm.JvmOverloads constructor( val converter = registry.getConverter(node.operation.name) ?: throw UnsupportedOperationException("No converter found for operation: ${node.operation.name}") - // Get input operands from context - val inputNodes = context.getInputNodes(node) - val operands = inputNodes.mapNotNull { context.getValueName(it.id) } + // Get input operands in input-port order, honoring each incoming edge's + // source output port so consumers of a multi-output op (e.g. split) get + // the right chunk. Equivalent to the prior node-based resolution for + // single-output producers. 
+ val operands = context.resolveOperands(node) // Surface any physical storage encoding declared on this node's // result specs as an MLIR comment before the operation is diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt index a21303f4..3ea07bc4 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt @@ -30,7 +30,10 @@ public class ShapeOperationsConverter : StableHloOperationConverter { "slice", // narrow(dim, start, length) is a single-axis slice — RoPE / attention // head splitting use it heavily. - "narrow" + "narrow", + // split(splitSize, dim) -> N equal chunks along dim. Multi-output: each + // chunk is a stablehlo.slice registered on its own output port. + "split", "chunk" ) override fun convert( @@ -46,6 +49,7 @@ public class ShapeOperationsConverter : StableHloOperationConverter { "concat", "concatenate", "cat", "stack" -> convertConcat(node, operands, context) "slice" -> convertSlice(node, operands, context) "narrow" -> convertNarrow(node, operands, context) + "split", "chunk" -> convertSplit(node, operands, context) else -> ConversionResult.Unsupported( node.operation.name, "Operation not supported by ShapeOperationsConverter" @@ -212,6 +216,73 @@ public class ShapeOperationsConverter : StableHloOperationConverter { ) } + /** + * Convert split(splitSize, dim) / chunk to N stablehlo.slice ops — one per + * output chunk. Multi-output: each chunk's SSA name is registered on its own + * output port (context.setValueName(node.id, port, name)) so downstream + * consumers, resolved by their incoming edge's source port, pick the right + * chunk. 
Returns chunk 0 as the nominal result. + */ + private fun convertSplit( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + if (operands.size != 1) { + return ConversionResult.Failure( + "Split operation requires exactly 1 operand, got ${operands.size}", + "Unsupported split arity for node ${node.id}" + ) + } + val inputShape = node.inputs.firstOrNull()?.shape ?: emptyList() + val rank = inputShape.size + if (rank == 0) { + return ConversionResult.Failure( + "Split requires a known input rank", + "Missing input shape for split node ${node.id}" + ) + } + val rawDim = node.operation.parameters["dim"] as? Int ?: 0 + val dim = if (rawDim < 0) rank + rawDim else rawDim + val splitSize = (node.operation.parameters["splitSize"] as? Int) + ?: (node.operation.parameters["split_size"] as? Int) + ?: return ConversionResult.Failure( + "Split requires a 'splitSize' parameter", + "Missing splitSize for split node ${node.id}" + ) + val axisLen = inputShape[dim] + val nChunks = node.outputs.size.takeIf { it > 0 } + ?: ((axisLen + splitSize - 1) / splitSize) + + val emitted = mutableListOf() + var firstName: String? 
= null + for (i in 0 until nChunks) { + val start = i * splitSize + if (start >= axisLen) break + val end = minOf(start + splitSize, axisLen) + val starts = List(rank) { if (it == dim) start else 0 } + val limits = List(rank) { if (it == dim) end else inputShape[it] } + val strides = List(rank) { 1 } + val outType = node.outputs.getOrNull(i) + ?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" + val v = context.nextTempValue() + val op = "$v = stablehlo.slice ${operands[0]} " + + "{start_indices = [${starts.joinToString(", ")}], " + + "limit_indices = [${limits.joinToString(", ")}], " + + "strides = [${strides.joinToString(", ")}]} : $outType" + context.emitOperation(op) + context.setValueName(node.id, i, v) + context.setValueType(v, outType) + emitted += op + if (i == 0) firstName = v + } + return if (firstName != null) { + ConversionResult.Success(outputValueName = firstName, emittedOperations = emitted) + } else { + ConversionResult.Failure("Split produced no chunks", "Empty split for node ${node.id}") + } + } + /** * Convert reshape operation using stablehlo.reshape. * Handles both static and dynamic shape specifications. diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt index 19bc033a..7deb7ff1 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt @@ -74,4 +74,45 @@ class NarrowPermuteConverterTest { assertTrue(mlir.contains("start_indices = [0, 2]"), "expected start [0,2] in:\n$mlir") assertTrue(mlir.contains("limit_indices = [2, 6]"), "expected limit [2,6] in:\n$mlir") } + + /** Multi-output: split -> N slices, and a consumer of chunk 1 resolves to chunk 1. 
*/ + @Test + fun splitMultiOutputAndPortResolution() { + val g = DefaultComputeGraph() + val a = GraphNode("a", InputOperation(), emptyList(), listOf(TensorSpec("a", listOf(2, 8), "FP32"))) + val c0 = TensorSpec("c0", listOf(2, 4), "FP32") + val c1 = TensorSpec("c1", listOf(2, 4), "FP32") + val split = GraphNode( + id = "s", + operation = object : Operation { + override val name = "split" + override val type = "shape" + override val parameters = mapOf("splitSize" to 4, "dim" to 1) + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test op") + override fun validateInputs(inputs: List) = sk.ainet.lang.tensor.ops.ValidationResult.Valid + override fun inferOutputs(inputs: List): List = listOf(c0, c1) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = parameters + }, + inputs = listOf(TensorSpec("a", listOf(2, 8), "FP32")), + outputs = listOf(c0, c1), + ) + // relu consumes split output PORT 1 (chunk 1) + val relu = opNode("r", "relu", "activation", emptyMap(), c1, TensorSpec("o", listOf(2, 4), "FP32")) + g.addNode(a); g.addNode(split); g.addNode(relu) + g.addEdge(GraphEdge("e1", a, split, 0, 0, a.outputs[0])) + g.addEdge(GraphEdge("e2", split, relu, 1, 0, c1)) + + val mlir = StableHloConverterFactory.createBasic().convert(g, "split_test").content + // Two chunk slices: chunk0 = [.., 0:4], chunk1 = [.., 4:8]. + assertTrue(mlir.contains("limit_indices = [2, 4]"), "expected chunk0 slice in:\n$mlir") + assertTrue(mlir.contains("start_indices = [0, 4]") && mlir.contains("limit_indices = [2, 8]"), + "expected chunk1 slice [0,4]..[2,8] in:\n$mlir") + // The chunk-1 slice's SSA value must be the operand the relu consumes. 
+ val chunk1Val = mlir.lines().first { it.contains("stablehlo.slice") && it.contains("limit_indices = [2, 8]") } + .trim().substringBefore(" =") + val reluLine = mlir.lines().first { it.contains("stablehlo.maximum") } + assertTrue(reluLine.contains(chunk1Val), "relu must consume chunk1 ($chunk1Val): $reluLine") + } } From bf76bb076168d6ed7dde1dd18a51ec4cfcb2b87b Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 16:42:29 +0200 Subject: [PATCH 05/10] =?UTF-8?q?feat(hlo):=20scaledDotProductAttention=20?= =?UTF-8?q?converter=20=E2=80=94=20full=20gemma3=20lowers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds AttentionOperationsConverter: lowers the atomic scaledDotProductAttention op to the standard StableHLO subgraph — scores = Q·Kᵀ (dot_general, contract head_dim), * scale (arg or 1/sqrt(head_dim)), softmax over key length (stable max/sub/exp/sum/div), out = attn·V (dot_general). Batched [..,S,D]; batching dims = all leading dims. Registered in StableHloConverterFactory. v1: attention mask/causal not yet emitted (structurally correct, unmasked) — TODO. With this, a full gemma3 network lowers to StableHLO with ZERO gaps (verified by GemmaTraceTest over the composite build: 140 nodes -> 255 lines, 0 unsupported, 0 arity). SDPA is a core TensorOps op so its converter lives in core. Unit test asserts 2 dot_generals + softmax + scores shape. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../compile/hlo/StableHloConverterFactory.kt | 7 ++ .../AttentionOperationsConverter.kt | 112 ++++++++++++++++++ .../compile/hlo/NarrowPermuteConverterTest.kt | 37 ++++++ 3 files changed, 156 insertions(+) create mode 100644 skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt index 3aa92591..dd8d8c8e 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/StableHloConverterFactory.kt @@ -1,6 +1,7 @@ package sk.ainet.compile.hlo import sk.ainet.compile.hlo.converters.ActivationOperationsConverter +import sk.ainet.compile.hlo.converters.AttentionOperationsConverter import sk.ainet.compile.hlo.converters.ConstantOperationsConverter import sk.ainet.compile.hlo.converters.GatherOperationsConverter import sk.ainet.compile.hlo.converters.LegacyOperationsConverter @@ -67,6 +68,9 @@ public object StableHloConverterFactory { // Register constant operations converter registry.register(ConstantOperationsConverter()) + // Register attention (scaledDotProductAttention) converter + registry.register(AttentionOperationsConverter()) + // Register gather / embedding / index_select converter — the // LLM front-door op for token-id \u2192 embedding lookups. 
registry.register(GatherOperationsConverter()) @@ -121,6 +125,9 @@ public object StableHloConverterFactory { // Register constant operations converter registry.register(ConstantOperationsConverter()) + // Register attention (scaledDotProductAttention) converter + registry.register(AttentionOperationsConverter()) + // Register gather / embedding / index_select converter — the // LLM front-door op for token-id \u2192 embedding lookups. registry.register(GatherOperationsConverter()) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt new file mode 100644 index 00000000..f02569d9 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt @@ -0,0 +1,112 @@ +package sk.ainet.compile.hlo.converters + +import sk.ainet.compile.hlo.ConversionContext +import sk.ainet.compile.hlo.ConversionResult +import sk.ainet.compile.hlo.StableHloOperationConverter +import sk.ainet.lang.graph.GraphNode +import kotlin.math.sqrt + +/** + * Converts `scaledDotProductAttention` to the standard StableHLO attention + * subgraph: + * + * scores = Q · Kᵀ (dot_general, contract head_dim) + * scaled = scores * scale (scale = arg, or 1/sqrt(head_dim)) + * attn = softmax(scaled, axis = -1) (max/sub/exp/sum/div, numerically stable) + * out = attn · V (dot_general, contract key length) + * + * Q/K/V are batched `[.., S, D]`; batching dims are every leading dim except the + * last two. The softmax decomposition mirrors ActivationOperationsConverter. + * + * SDPA is a core `TensorOps` op (KSP-generated), so its converter lives here in + * core alongside dot_general/softmax — the transformer modules just decompose to it. 
+ * + * v1 limitation: the optional attention `mask` / `causal` flag is not yet + * emitted (structurally correct, numerically unmasked). TODO: emit a causal mask + * (iota + compare + select, additive -inf) before the softmax. + */ +public class AttentionOperationsConverter : StableHloOperationConverter { + + override val supportedOperations: Set = setOf( + "scaledDotProductAttention", "scaleddotproductattention", "sdpa" + ) + + override fun convert( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + if (operands.size < 3) { + return ConversionResult.Failure( + "scaledDotProductAttention requires q, k, v (>=3 operands), got ${operands.size}", + "Unsupported SDPA arity for node ${node.id}" + ) + } + + val qShape = node.inputs.getOrNull(0)?.shape + ?: return ConversionResult.Failure("SDPA requires a known query shape", "Missing query shape for ${node.id}") + val kShape = node.inputs.getOrNull(1)?.shape ?: qShape + val vShape = node.inputs.getOrNull(2)?.shape ?: qShape + val rank = qShape.size + if (rank < 2) { + return ConversionResult.Failure("SDPA query must have rank >= 2", "Bad query rank for ${node.id}") + } + + val outSpec = node.outputs.firstOrNull() + val mapper = context.getTypeMapper() + val elem = outSpec?.let { mapper.mapDType(it.dtype) } ?: "f32" + fun typeOf(shape: List): String = "tensor<${shape.joinToString("x")}x$elem>" + + val qType = context.getValueType(operands[0]) ?: typeOf(qShape) + val kType = context.getValueType(operands[1]) ?: typeOf(kShape) + val vType = context.getValueType(operands[2]) ?: typeOf(vShape) + + val headDim = qShape[rank - 1] + val keyLen = kShape[rank - 2] + val scoresShape = qShape.dropLast(1) + keyLen // [.., Sq, Sk] + val scoresType = typeOf(scoresShape) + val outputType = outSpec?.let { mapper.mapTensorType(it) } ?: typeOf(qShape.dropLast(1) + headDim) + + val scaleParam = (node.operation.parameters["scale"] as? 
Number)?.toFloat() ?: 0f + val scaleVal = if (scaleParam != 0f) scaleParam else (1.0f / sqrt(headDim.toFloat())) + + val hasBatch = rank > 2 + val batchList = (0 until rank - 2).joinToString(", ") + val batchClause = if (hasBatch) "batching_dims = [$batchList] x [$batchList], " else "" + val contractQK = rank - 1 // contract head_dim of Q and K + val sdAxis = scoresShape.size - 1 // softmax over key length + val reducedShape = scoresShape.dropLast(1) + val reducedType = if (reducedShape.isEmpty()) "tensor<$elem>" else "tensor<${reducedShape.joinToString("x")}x$elem>" + val bcastDims = (scoresShape.indices).filter { it != sdAxis }.joinToString(", ") + val contractAttn = scoresShape.size - 1 // attn key-length axis + val contractV = rank - 2 // V key-length axis + + val scores = context.nextTempValue() + val scaleC = context.nextTempValue() + val scaled = context.nextTempValue() + val maxInit = context.nextTempValue(); val maxV = context.nextTempValue(); val maxB = context.nextTempValue() + val shifted = context.nextTempValue(); val expV = context.nextTempValue() + val sumInit = context.nextTempValue(); val sumV = context.nextTempValue(); val sumB = context.nextTempValue() + val attn = context.nextTempValue() + val out = context.nextTempValue() + + val ops = listOf( + "$scores = stablehlo.dot_general ${operands[0]}, ${operands[1]}, ${batchClause}contracting_dims = [$contractQK] x [$contractQK] : ($qType, $kType) -> $scoresType", + "$scaleC = stablehlo.constant dense<$scaleVal> : $scoresType", + "$scaled = stablehlo.multiply $scores, $scaleC : $scoresType", + // softmax(scaled) over the key-length axis + "$maxInit = stablehlo.constant dense<0xFF800000> : tensor<$elem>", + "$maxV = stablehlo.reduce($scaled init: $maxInit) applies stablehlo.maximum across dimensions = [$sdAxis] : ($scoresType, tensor<$elem>) -> $reducedType", + "$maxB = stablehlo.broadcast_in_dim $maxV, dims = [$bcastDims] : ($reducedType) -> $scoresType", + "$shifted = stablehlo.subtract $scaled, $maxB 
: $scoresType", + "$expV = stablehlo.exponential $shifted : $scoresType", + "$sumInit = stablehlo.constant dense<0.0> : tensor<$elem>", + "$sumV = stablehlo.reduce($expV init: $sumInit) applies stablehlo.add across dimensions = [$sdAxis] : ($scoresType, tensor<$elem>) -> $reducedType", + "$sumB = stablehlo.broadcast_in_dim $sumV, dims = [$bcastDims] : ($reducedType) -> $scoresType", + "$attn = stablehlo.divide $expV, $sumB : $scoresType", + "$out = stablehlo.dot_general $attn, ${operands[2]}, ${batchClause}contracting_dims = [$contractAttn] x [$contractV] : ($scoresType, $vType) -> $outputType", + ) + ops.forEach { context.emitOperation(it) } + return ConversionResult.Success(outputValueName = out, emittedOperations = ops) + } +} diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt index 7deb7ff1..cfd0de4e 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt @@ -115,4 +115,41 @@ class NarrowPermuteConverterTest { val reluLine = mlir.lines().first { it.contains("stablehlo.maximum") } assertTrue(reluLine.contains(chunk1Val), "relu must consume chunk1 ($chunk1Val): $reluLine") } + + /** scaledDotProductAttention lowers to QKᵀ · scale · softmax · ·V (two dot_generals). 
*/ + @Test + fun sdpaLowersToAttentionSubgraph() { + val g = DefaultComputeGraph() + // batched [B=1, H=2, S=4, D=8] + val qkvShape = listOf(1, 2, 4, 8) + fun inNode(id: String) = GraphNode(id, InputOperation(), emptyList(), listOf(TensorSpec(id, qkvShape, "FP32"))) + val q = inNode("q"); val k = inNode("k"); val v = inNode("v") + val sdpa = GraphNode( + id = "att", + operation = object : Operation { + override val name = "scaledDotProductAttention" + override val type = "trace" + override val parameters = mapOf("scale" to 0.0f, "causal" to true) + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test op") + override fun validateInputs(inputs: List) = sk.ainet.lang.tensor.ops.ValidationResult.Valid + override fun inferOutputs(inputs: List): List = listOf(TensorSpec("o", qkvShape, "FP32")) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = parameters + }, + inputs = listOf(TensorSpec("q", qkvShape, "FP32"), TensorSpec("k", qkvShape, "FP32"), TensorSpec("v", qkvShape, "FP32")), + outputs = listOf(TensorSpec("o", qkvShape, "FP32")), + ) + g.addNode(q); g.addNode(k); g.addNode(v); g.addNode(sdpa) + g.addEdge(GraphEdge("e0", q, sdpa, 0, 0, q.outputs[0])) + g.addEdge(GraphEdge("e1", k, sdpa, 0, 1, k.outputs[0])) + g.addEdge(GraphEdge("e2", v, sdpa, 0, 2, v.outputs[0])) + + val mlir = StableHloConverterFactory.createBasic().convert(g, "sdpa_test").content + // Two dot_generals (scores = QKᵀ, out = attn·V). + assertTrue(Regex("stablehlo\\.dot_general").findAll(mlir).count() == 2, "expected 2 dot_general in:\n$mlir") + // Softmax decomposition present (exponential), and scores shape is [1,2,4,4]. 
+ assertTrue(mlir.contains("stablehlo.exponential"), "expected softmax exp in:\n$mlir") + assertTrue(mlir.contains("1x2x4x4xf32"), "expected scores type 1x2x4x4 in:\n$mlir") + } } From 0c7b961d321ec69e125ee9cb75de53b2dc924345 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 17:04:07 +0200 Subject: [PATCH 06/10] =?UTF-8?q?fix(hlo):=20emit=20iree-valid=20StableHLO?= =?UTF-8?q?=20syntax=20=E2=80=94=20full=20gemma3=20compiles=20to=20vmfb?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Aligned converter emission to what iree-compile's stablehlo parser accepts (verified by compiling the full gemma3 graph end-to-end): - gather: use the GENERIC MLIR form `"stablehlo.gather"(%a,%b) <{...}>` (stablehlo.gather has no custom assembly form). - slice/narrow/split: canonical bracket form `stablehlo.slice %x [s:l:st, ...] : (in) -> out` (attribute-dict form is rejected) — shared sliceLine() helper. - concatenate: full functional type `(t0, t1, ...) -> out` (bare `: out` rejected). - batch matmul: batch dims = leading dims shared by BOTH operands (min(lhsRank,rhsRank)-2); fixes 3D-activation @ 2D-weight Linear projections that previously emitted mismatched batching_dims=[0]x[0]. - updated converter tests to the valid forms. Result: SKaiNET gemma3 DSL -> StableHLO -> iree-compile (llvm-cpu; +neon aarch64) -> vmfb, both host x64 and aarch64 targets. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../converters/GatherOperationsConverter.kt | 8 +-- .../converters/LinalgOperationsConverter.kt | 18 ++++--- .../converters/ShapeOperationsConverter.kt | 49 ++++++++++++------- .../hlo/ConcatSliceCastConverterTest.kt | 14 +++--- .../ainet/compile/hlo/GatherConverterTest.kt | 9 ++-- .../compile/hlo/NarrowPermuteConverterTest.kt | 12 ++--- 6 files changed, 65 insertions(+), 45 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/GatherOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/GatherOperationsConverter.kt index 98b8eff3..cbc05ab6 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/GatherOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/GatherOperationsConverter.kt @@ -133,14 +133,16 @@ public class GatherOperationsConverter : StableHloOperationConverter { val weightOperand = operands[0] val indicesOperand = operands[1] val resultValue = context.nextTempValue() - val gatherOp = "$resultValue = stablehlo.gather($weightOperand, $indicesOperand) " + - "{ dimension_numbers = #stablehlo.gather<" + + // stablehlo.gather has no custom (pretty) assembly form — emit the + // generic MLIR op form: "stablehlo.gather"(%operand, %indices) <{attrs}>. 
+ val gatherOp = "$resultValue = \"stablehlo.gather\"($weightOperand, $indicesOperand) " + + "<{dimension_numbers = #stablehlo.gather<" + "offset_dims = [$offsetDims], " + "collapsed_slice_dims = [$collapsedSliceDims], " + "start_index_map = [$startIndexMap], " + "index_vector_dim = $indexVectorDim>, " + "slice_sizes = array, " + - "indices_are_sorted = false } " + + "indices_are_sorted = false}> " + ": ($weightType, $indicesType) -> $outputType" context.emitOperation(gatherOp) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt index ecc78757..54c4af47 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/LinalgOperationsConverter.kt @@ -83,19 +83,23 @@ public class LinalgOperationsConverter : StableHloOperationConverter { val rhsType = rhsSpec?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" - // Infer batching rank: for A[..., M, K] x B[..., K, N], batching dims - // are all leading dims except the last two. Falls back to rank 3 if - // shape is unknown (matches prior hard-coded behavior). - val rank = lhsSpec?.shape?.size ?: rhsSpec?.shape?.size ?: 3 - val batchCount = (rank - 2).coerceAtLeast(0) + // Batching dims are the leading dims shared by BOTH operands. When the + // ranks differ — e.g. batched activations A[..,M,K] times a 2-D weight + // B[K,N] (the common Linear-projection case) — there are no batch dims: + // A's leading dims are free output dims, not batch dims. Using lhsRank + // for both (the old behavior) produced batching_dims=[0]x[0] with + // mismatched sizes (1 vs 64). 
+ val lhsRank = lhsSpec?.shape?.size ?: 3 + val rhsRank = rhsSpec?.shape?.size ?: lhsRank + val batchCount = (minOf(lhsRank, rhsRank) - 2).coerceAtLeast(0) val explicitBatch = node.operation.parameters["batch_dims"] as? List<*> val batchDimsList = if (explicitBatch != null && explicitBatch.isNotEmpty()) { explicitBatch.map { it.toString() } } else { (0 until batchCount).map { it.toString() } } - val contractingLhs = rank - 1 - val contractingRhs = (rank - 2).coerceAtLeast(0) + val contractingLhs = lhsRank - 1 + val contractingRhs = (rhsRank - 2).coerceAtLeast(0) val resultValue = context.nextTempValue() diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt index 3ea07bc4..5f829ebf 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt @@ -90,7 +90,14 @@ public class ShapeOperationsConverter : StableHloOperationConverter { val resultValue = context.nextTempValue() val operandList = operands.joinToString(", ") - val operation = "$resultValue = stablehlo.concatenate $operandList, dim = $axis : $outputType" + // concatenate's custom form needs the full functional type: + // (t0, t1, ...) -> outType (a bare `: outType` is rejected). 
+ val operandTypes = operands.indices.joinToString(", ") { i -> + context.getValueType(operands[i]) + ?: node.inputs.getOrNull(i)?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + } + val operation = "$resultValue = stablehlo.concatenate $operandList, dim = $axis : ($operandTypes) -> $outputType" context.emitOperation(operation) return ConversionResult.Success( @@ -141,15 +148,9 @@ public class ShapeOperationsConverter : StableHloOperationConverter { val strides = (node.operation.parameters["strides"] as? List) ?: List(rank) { 1 } - val startsAttr = starts.joinToString(", ") - val limitsAttr = limits.joinToString(", ") - val stridesAttr = strides.joinToString(", ") - val resultValue = context.nextTempValue() - val operation = "$resultValue = stablehlo.slice ${operands[0]} " + - "{start_indices = [$startsAttr], " + - "limit_indices = [$limitsAttr], " + - "strides = [$stridesAttr]} : $outputType" + val operation = sliceLine(resultValue, operands[0], starts, limits, strides, + resolveOperandType(operands[0], node, context), outputType) context.emitOperation(operation) return ConversionResult.Success( @@ -204,10 +205,8 @@ public class ShapeOperationsConverter : StableHloOperationConverter { val strides = List(rank) { 1 } val resultValue = context.nextTempValue() - val operation = "$resultValue = stablehlo.slice ${operands[0]} " + - "{start_indices = [${starts.joinToString(", ")}], " + - "limit_indices = [${limits.joinToString(", ")}], " + - "strides = [${strides.joinToString(", ")}]} : $outputType" + val operation = sliceLine(resultValue, operands[0], starts, limits, strides, + resolveOperandType(operands[0], node, context), outputType) context.emitOperation(operation) return ConversionResult.Success( @@ -266,10 +265,8 @@ public class ShapeOperationsConverter : StableHloOperationConverter { val outType = node.outputs.getOrNull(i) ?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" val v = context.nextTempValue() - val op = "$v = 
stablehlo.slice ${operands[0]} " + - "{start_indices = [${starts.joinToString(", ")}], " + - "limit_indices = [${limits.joinToString(", ")}], " + - "strides = [${strides.joinToString(", ")}]} : $outType" + val op = sliceLine(v, operands[0], starts, limits, strides, + resolveOperandType(operands[0], node, context), outType) context.emitOperation(op) context.setValueName(node.id, i, v) context.setValueType(v, outType) @@ -454,6 +451,24 @@ public class ShapeOperationsConverter : StableHloOperationConverter { ) } + /** + * Emit a `stablehlo.slice` in the canonical bracket assembly form + * `%out = stablehlo.slice %x [s0:l0:st0, s1:l1:st1, ...] : (inType) -> outType`. + * (stablehlo.slice has no attribute-dict custom form — iree-compile rejects it.) + */ + private fun sliceLine( + result: String, + operand: String, + starts: List, + limits: List, + strides: List, + inType: String, + outType: String, + ): String { + val ranges = starts.indices.joinToString(", ") { "${starts[it]}:${limits[it]}:${strides[it]}" } + return "$result = stablehlo.slice $operand [$ranges] : ($inType) -> $outType" + } + /** * Look up the MLIR type of an SSA operand. * diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConcatSliceCastConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConcatSliceCastConverterTest.kt index 9233a66c..d4efe579 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConcatSliceCastConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ConcatSliceCastConverterTest.kt @@ -104,17 +104,15 @@ class ConcatSliceCastConverterTest { fun slice_carries_start_limit_stride_attributes() { val module = buildSliceModule() println("[DEBUG_LOG] slice export:\n${module.content}") + // Canonical bracket form `stablehlo.slice %x [s:l:st, ...] : (in) -> out` + // (stablehlo.slice has no attribute-dict custom form). 
assertTrue( - module.content.contains("start_indices"), - "slice must emit start_indices" + module.content.contains("[0:4:1, 0:8:1]"), + "slice must emit per-dim start:limit:stride brackets, got:\n${module.content}" ) assertTrue( - module.content.contains("limit_indices"), - "slice must emit limit_indices" - ) - assertTrue( - module.content.contains("strides"), - "slice must emit strides" + module.content.contains("(tensor<8x16xf32>) -> tensor<4x8xf32>"), + "slice must carry (operand) -> result types" ) } diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/GatherConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/GatherConverterTest.kt index c875ec63..7aff69e5 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/GatherConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/GatherConverterTest.kt @@ -88,12 +88,15 @@ class GatherConverterTest { // (Earlier draft accidentally emitted // `stablehlo.gather([%arg0, %arg1][0], [%arg0, %arg1][1])` // because of a `$operands[0]` Kotlin string-template pitfall.) + // Generic MLIR form ("stablehlo.gather" has no custom assembly form): + // "stablehlo.gather"(%operand, %indices) <{...}> + // Operands must be bare SSA values, not a `[%arg0, %arg1][0]` expression. 
assertTrue( - module.content.contains("stablehlo.gather(%arg0, %arg1)"), - "gather must reference operands as bare SSA values, not `[%arg0, %arg1][0]`" + module.content.contains("\"stablehlo.gather\"(%arg0, %arg1)"), + "gather must reference operands as bare SSA values in the generic form" ) assertFalse( - module.content.contains("stablehlo.gather([%"), + module.content.contains("gather\"([%"), "gather must not emit operand lists as Kotlin-string `[..., ...][0]` junk" ) } diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt index cfd0de4e..425fa647 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NarrowPermuteConverterTest.kt @@ -71,8 +71,7 @@ class NarrowPermuteConverterTest { val mlir = StableHloConverterFactory.createBasic().convert(g, "narrow_test").content assertTrue(mlir.contains("stablehlo.slice"), "expected slice in:\n$mlir") - assertTrue(mlir.contains("start_indices = [0, 2]"), "expected start [0,2] in:\n$mlir") - assertTrue(mlir.contains("limit_indices = [2, 6]"), "expected limit [2,6] in:\n$mlir") + assertTrue(mlir.contains("[0:2:1, 2:6:1]"), "expected bracket slice [0:2:1, 2:6:1] in:\n$mlir") } /** Multi-output: split -> N slices, and a consumer of chunk 1 resolves to chunk 1. */ @@ -105,12 +104,11 @@ class NarrowPermuteConverterTest { g.addEdge(GraphEdge("e2", split, relu, 1, 0, c1)) val mlir = StableHloConverterFactory.createBasic().convert(g, "split_test").content - // Two chunk slices: chunk0 = [.., 0:4], chunk1 = [.., 4:8]. 
- assertTrue(mlir.contains("limit_indices = [2, 4]"), "expected chunk0 slice in:\n$mlir") - assertTrue(mlir.contains("start_indices = [0, 4]") && mlir.contains("limit_indices = [2, 8]"), - "expected chunk1 slice [0,4]..[2,8] in:\n$mlir") + // Two chunk slices: chunk0 = [.., 0:4], chunk1 = [.., 4:8] (bracket form). + assertTrue(mlir.contains("[0:2:1, 0:4:1]"), "expected chunk0 slice [0:2:1, 0:4:1] in:\n$mlir") + assertTrue(mlir.contains("[0:2:1, 4:8:1]"), "expected chunk1 slice [0:2:1, 4:8:1] in:\n$mlir") // The chunk-1 slice's SSA value must be the operand the relu consumes. - val chunk1Val = mlir.lines().first { it.contains("stablehlo.slice") && it.contains("limit_indices = [2, 8]") } + val chunk1Val = mlir.lines().first { it.contains("stablehlo.slice") && it.contains("4:8:1") } .trim().substringBefore(" =") val reluLine = mlir.lines().first { it.contains("stablehlo.maximum") } assertTrue(reluLine.contains(chunk1Val), "relu must consume chunk1 ($chunk1Val): $reluLine") From 6b6e089342c2b0df121692f311267de2697453c1 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 18:30:04 +0200 Subject: [PATCH 07/10] test: numerical validation harness for SDPA lowering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dumps a small scaledDotProductAttention StableHLO ([1,1,2,4], scale 0.5); iree-compile + iree-run-module output matches a NumPy reference exactly to 5 decimals (QKᵀ·scale·softmax·V). Confirms the attention converter is numerically correct, not just structurally valid. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../ainet/compile/hlo/SdpaNumericDumpTest.kt | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/SdpaNumericDumpTest.kt diff --git a/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/SdpaNumericDumpTest.kt b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/SdpaNumericDumpTest.kt new file mode 100644 index 00000000..e1cf4cef --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/SdpaNumericDumpTest.kt @@ -0,0 +1,52 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.ops.InputOperation +import sk.ainet.lang.tensor.ops.Operation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.types.DType +import java.io.File +import kotlin.test.Test + +/** + * Dumps a small scaledDotProductAttention graph to StableHLO for numerical + * validation against a NumPy reference (see docker iree-run-module + numpy). + * Shapes [B=1, H=1, S=2, D=4]; scale = 1/sqrt(4) = 0.5. 
+ */ +class SdpaNumericDumpTest { + @Test + fun dumpSdpaMlir() { + val shape = listOf(1, 1, 2, 4) + fun inNode(id: String) = GraphNode(id, InputOperation(), emptyList(), listOf(TensorSpec(id, shape, "FP32"))) + val q = inNode("q"); val k = inNode("k"); val v = inNode("v") + val sdpa = GraphNode( + id = "att", + operation = object : Operation { + override val name = "scaledDotProductAttention" + override val type = "trace" + override val parameters = mapOf("scale" to 0.0f, "causal" to false) + override fun execute(inputs: List>) = + throw UnsupportedOperationException("test op") + override fun validateInputs(inputs: List) = sk.ainet.lang.tensor.ops.ValidationResult.Valid + override fun inferOutputs(inputs: List) = listOf(TensorSpec("o", shape, "FP32")) + override fun clone(newParameters: Map): Operation = this + override fun serialize() = parameters + }, + inputs = listOf(TensorSpec("q", shape, "FP32"), TensorSpec("k", shape, "FP32"), TensorSpec("v", shape, "FP32")), + outputs = listOf(TensorSpec("o", shape, "FP32")), + ) + val g = DefaultComputeGraph() + g.addNode(q); g.addNode(k); g.addNode(v); g.addNode(sdpa) + g.addEdge(GraphEdge("e0", q, sdpa, 0, 0, q.outputs[0])) + g.addEdge(GraphEdge("e1", k, sdpa, 0, 1, k.outputs[0])) + g.addEdge(GraphEdge("e2", v, sdpa, 0, 2, v.outputs[0])) + + val mlir = StableHloConverterFactory.createBasic().convert(g, "sdpa").content + val out = File(System.getProperty("sdpaMlirOut") ?: "/home/miso/projects/coral/build-mlir/sdpa.mlir") + out.parentFile?.mkdirs() + out.writeText(mlir) + println("WROTE_SDPA ${out.absolutePath}") + } +} From fa20edeb7ff5735b9230611b37a2196bfd1333a5 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 18:37:57 +0200 Subject: [PATCH 08/10] feat(hlo): causal mask for scaledDotProductAttention (numerically validated) When the SDPA node's `causal` attr is set, emit an additive -inf mask before softmax: iota(query axis) / iota(key axis) -> compare GE -> select(keep, 0, -inf) -> add to scaled 
scores. Each query attends only to keys at or before it. Validated EXACT vs a NumPy causal reference (S=2): query0 -> v[0] only, query1 -> softmax over both keys. iota/compare/select lowering accepted by iree-compile and numerically correct. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../AttentionOperationsConverter.kt | 55 ++++++++++++++----- .../ainet/compile/hlo/SdpaNumericDumpTest.kt | 2 +- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt index f02569d9..183e2175 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt @@ -21,9 +21,10 @@ import kotlin.math.sqrt * SDPA is a core `TensorOps` op (KSP-generated), so its converter lives here in * core alongside dot_general/softmax — the transformer modules just decompose to it. * - * v1 limitation: the optional attention `mask` / `causal` flag is not yet - * emitted (structurally correct, numerically unmasked). TODO: emit a causal mask - * (iota + compare + select, additive -inf) before the softmax. + * Causal masking: when the `causal` attribute is set, an additive -inf mask + * (built from iota row/col indices + compare + select) is added to the scaled + * scores before softmax so each query only attends to keys at or before it. + * An explicit `mask` operand is not yet consumed (TODO: add operands[3]). 
*/ public class AttentionOperationsConverter : StableHloOperationConverter { @@ -81,6 +82,11 @@ public class AttentionOperationsConverter : StableHloOperationConverter { val contractAttn = scoresShape.size - 1 // attn key-length axis val contractV = rank - 2 // V key-length axis + val causal = (node.operation.parameters["causal"] as? Boolean) ?: false + val qAxis = rank - 2 // query position in scores [.., Sq, Sk] + val scoresI32Type = "tensor<${scoresShape.joinToString("x")}xi32>" + val scoresI1Type = "tensor<${scoresShape.joinToString("x")}xi1>" + val scores = context.nextTempValue() val scaleC = context.nextTempValue() val scaled = context.nextTempValue() @@ -90,22 +96,41 @@ public class AttentionOperationsConverter : StableHloOperationConverter { val attn = context.nextTempValue() val out = context.nextTempValue() - val ops = listOf( + val ops = mutableListOf( "$scores = stablehlo.dot_general ${operands[0]}, ${operands[1]}, ${batchClause}contracting_dims = [$contractQK] x [$contractQK] : ($qType, $kType) -> $scoresType", "$scaleC = stablehlo.constant dense<$scaleVal> : $scoresType", "$scaled = stablehlo.multiply $scores, $scaleC : $scoresType", - // softmax(scaled) over the key-length axis - "$maxInit = stablehlo.constant dense<0xFF800000> : tensor<$elem>", - "$maxV = stablehlo.reduce($scaled init: $maxInit) applies stablehlo.maximum across dimensions = [$sdAxis] : ($scoresType, tensor<$elem>) -> $reducedType", - "$maxB = stablehlo.broadcast_in_dim $maxV, dims = [$bcastDims] : ($reducedType) -> $scoresType", - "$shifted = stablehlo.subtract $scaled, $maxB : $scoresType", - "$expV = stablehlo.exponential $shifted : $scoresType", - "$sumInit = stablehlo.constant dense<0.0> : tensor<$elem>", - "$sumV = stablehlo.reduce($expV init: $sumInit) applies stablehlo.add across dimensions = [$sdAxis] : ($scoresType, tensor<$elem>) -> $reducedType", - "$sumB = stablehlo.broadcast_in_dim $sumV, dims = [$bcastDims] : ($reducedType) -> $scoresType", - "$attn = stablehlo.divide 
$expV, $sumB : $scoresType", - "$out = stablehlo.dot_general $attn, ${operands[2]}, ${batchClause}contracting_dims = [$contractAttn] x [$contractV] : ($scoresType, $vType) -> $outputType", ) + + // Causal mask: keep key_index <= query_index, set the rest to -inf + // before softmax (additive mask, built from iota row/col indices). + var softmaxIn = scaled + if (causal) { + val iotaQ = context.nextTempValue(); val iotaK = context.nextTempValue() + val keep = context.nextTempValue(); val zeros = context.nextTempValue() + val ninf = context.nextTempValue(); val maskAdd = context.nextTempValue() + val masked = context.nextTempValue() + ops += "$iotaQ = stablehlo.iota dim = $qAxis : $scoresI32Type" + ops += "$iotaK = stablehlo.iota dim = $sdAxis : $scoresI32Type" + ops += "$keep = stablehlo.compare GE, $iotaQ, $iotaK : ($scoresI32Type, $scoresI32Type) -> $scoresI1Type" + ops += "$zeros = stablehlo.constant dense<0.0> : $scoresType" + ops += "$ninf = stablehlo.constant dense<0xFF800000> : $scoresType" + ops += "$maskAdd = stablehlo.select $keep, $zeros, $ninf : $scoresI1Type, $scoresType" + ops += "$masked = stablehlo.add $scaled, $maskAdd : $scoresType" + softmaxIn = masked + } + + // softmax(softmaxIn) over the key-length axis + ops += "$maxInit = stablehlo.constant dense<0xFF800000> : tensor<$elem>" + ops += "$maxV = stablehlo.reduce($softmaxIn init: $maxInit) applies stablehlo.maximum across dimensions = [$sdAxis] : ($scoresType, tensor<$elem>) -> $reducedType" + ops += "$maxB = stablehlo.broadcast_in_dim $maxV, dims = [$bcastDims] : ($reducedType) -> $scoresType" + ops += "$shifted = stablehlo.subtract $softmaxIn, $maxB : $scoresType" + ops += "$expV = stablehlo.exponential $shifted : $scoresType" + ops += "$sumInit = stablehlo.constant dense<0.0> : tensor<$elem>" + ops += "$sumV = stablehlo.reduce($expV init: $sumInit) applies stablehlo.add across dimensions = [$sdAxis] : ($scoresType, tensor<$elem>) -> $reducedType" + ops += "$sumB = stablehlo.broadcast_in_dim 
$sumV, dims = [$bcastDims] : ($reducedType) -> $scoresType" + ops += "$attn = stablehlo.divide $expV, $sumB : $scoresType" + ops += "$out = stablehlo.dot_general $attn, ${operands[2]}, ${batchClause}contracting_dims = [$contractAttn] x [$contractV] : ($scoresType, $vType) -> $outputType" ops.forEach { context.emitOperation(it) } return ConversionResult.Success(outputValueName = out, emittedOperations = ops) } diff --git a/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/SdpaNumericDumpTest.kt b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/SdpaNumericDumpTest.kt index e1cf4cef..3aac9553 100644 --- a/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/SdpaNumericDumpTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/SdpaNumericDumpTest.kt @@ -26,7 +26,7 @@ class SdpaNumericDumpTest { operation = object : Operation { override val name = "scaledDotProductAttention" override val type = "trace" - override val parameters = mapOf("scale" to 0.0f, "causal" to false) + override val parameters = mapOf("scale" to 0.0f, "causal" to true) override fun execute(inputs: List>) = throw UnsupportedOperationException("test op") override fun validateInputs(inputs: List) = sk.ainet.lang.tensor.ops.ValidationResult.Valid From c69984f9e09b09cf60f000a02b62bd29b825633f Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 19:38:54 +0200 Subject: [PATCH 09/10] feat(hlo): boxing-free FloatArray weight externalization for .irpa baking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit finalize() now stores resolved weights as the primitive FloatArray instead of .toList() — boxing a real LLM weight (262153x640 embedding -> ~2.7GB List<Float>) OOMed the trace.
ConstantOperationsConverter externalizes FloatArray directly (new floatArrayToLittleEndianBytes + tryMaterializeExternalFloats), and inlines via asList() for small/InlineAlways. IrpaWriter writes byte ranges in one shot (byte-at-a-time was pathological for ~670MB tensors). With this, the real FunctionGemma-270M bakes: 1 func arg (tokens) + 360 weights externalized to util.global #flow.parameter.named. (IrpaWriter's archive header is still IREE-incompatible (40B vs IREE's 88B v0 header + different segment layout) — tracked separately; bake currently routes weights via safetensors + iree-convert-parameters.) Co-Authored-By: Claude Opus 4.8 (1M context) --- .../ainet/lang/trace/TraceToGraphBuilder.kt | 7 +- .../compile/hlo/ConstantByteSerializer.kt | 44 ++++++++++ .../converters/ConstantOperationsConverter.kt | 84 ++++++++++++++++++- .../kotlin/sk/ainet/io/irpa/IrpaWriter.kt | 14 ++-- 4 files changed, 135 insertions(+), 14 deletions(-) diff --git a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt index 8b80f8d7..1bca097b 100644 --- a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt +++ b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/trace/TraceToGraphBuilder.kt @@ -277,7 +277,12 @@ public class TraceToGraphBuilder( name = "weight", type = "constant", parameters = mapOf( - "initial_value" to constantValues.toList(), + // Store the primitive FloatArray, NOT .toList(): boxing a + // real LLM weight (e.g. 262153x640 embedding) into a + // List<Float> is ~2.7GB and OOMs the trace. The HLO + // converter handles FloatArray for both inline and + // external (.irpa) materialization.
+ "initial_value" to constantValues, "trainable" to false ) ) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantByteSerializer.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantByteSerializer.kt index d21eb646..9021ba9b 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantByteSerializer.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/ConstantByteSerializer.kt @@ -77,6 +77,50 @@ internal fun numberListToLittleEndianBytes( } } +/** + * Boxing-free variant for tensors whose values arrive as a primitive + * [FloatArray] — the form produced by `TraceToGraphBuilder.finalize` + * for resolved (dequantized) weights. Avoids the `FloatArray.toList()` + * boxing that turns a 262153x640 embedding into a ~2.7GB `List<Float>` + * and OOMs the trace. FP32 / I32 only (the dtypes a float-backed + * weight resolves to); other dtypes throw so the caller falls back.
+ */ +internal fun floatArrayToLittleEndianBytes( + values: FloatArray, + dtype: String, + expectedElements: Int +): ByteArray { + val count = expectedElements.coerceAtLeast(values.size) + val n = minOf(count, values.size) + return when (dtype.uppercase()) { + "FP32", "F32", "FLOAT32" -> { + val bytes = ByteArray(count * 4) + for (i in 0 until n) { + val bits = values[i].toRawBits() + bytes[i * 4] = (bits and 0xff).toByte() + bytes[i * 4 + 1] = (bits ushr 8 and 0xff).toByte() + bytes[i * 4 + 2] = (bits ushr 16 and 0xff).toByte() + bytes[i * 4 + 3] = (bits ushr 24 and 0xff).toByte() + } + bytes + } + "I32", "INT32" -> { + val bytes = ByteArray(count * 4) + for (i in 0 until n) { + val v = values[i].toInt() + bytes[i * 4] = (v and 0xff).toByte() + bytes[i * 4 + 1] = (v ushr 8 and 0xff).toByte() + bytes[i * 4 + 2] = (v ushr 16 and 0xff).toByte() + bytes[i * 4 + 3] = (v ushr 24 and 0xff).toByte() + } + bytes + } + else -> throw IllegalArgumentException( + "Boxing-free external materialization supports FP32 / I32 FloatArray; got dtype=$dtype." + ) + } +} + /** * Expected element count for a (possibly empty) shape. 
Empty shape * (scalar) means one element; `null` / absent dims degrade to 0 so the diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt index 44de7868..5c44df61 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ConstantOperationsConverter.kt @@ -6,6 +6,7 @@ import sk.ainet.compile.hlo.ConversionResult import sk.ainet.compile.hlo.ExternalParameterRef import sk.ainet.compile.hlo.StableHloOperationConverter import sk.ainet.compile.hlo.elementCountFromShape +import sk.ainet.compile.hlo.floatArrayToLittleEndianBytes import sk.ainet.compile.hlo.numberListToLittleEndianBytes import sk.ainet.lang.graph.GraphNode import sk.ainet.lang.tensor.ops.TensorSpec @@ -227,10 +228,13 @@ public class ConstantOperationsConverter : StableHloOperationConverter { val paramType = if (isTrainable) "trainable parameter" else "frozen parameter" context.emitComment("${node.operation.name} ${node.id}: $paramType") - // Policy seam — same shape as convertTensorConstant. Only a - // List-valued initial_value can be externalized with bytes - // available today; Number (splat) and missing cases fall - // through to the inline path intentionally. + // Policy seam — same shape as convertTensorConstant. List- and + // FloatArray-valued initial_value can be externalized with bytes; + // Number (splat) and missing cases fall through to inline. + // FloatArray is the boxing-free form from finalize for real weights. 
+ if (initialValue is FloatArray) { + tryMaterializeExternalFloats(node, outputSpec, outputType, initialValue, context)?.let { return it } + } if (initialValue is List<*>) { tryMaterializeExternal(node, outputSpec, outputType, initialValue, context)?.let { return it } } @@ -241,6 +245,13 @@ public class ConstantOperationsConverter : StableHloOperationConverter { initialValue is Number -> { "$resultValue = stablehlo.constant dense<${formatConstantValue(initialValue)}> : $outputType" } + initialValue is FloatArray -> { + // Inline path (small tensors / InlineAlways): asList() boxes + // lazily during formatting; acceptable since big tensors take + // the external branch above. + val formattedValues = formatTensorValues(initialValue.asList(), outputSpec) + "$resultValue = stablehlo.constant dense<$formattedValues> : $outputType" + } initialValue is List<*> -> { val formattedValues = formatTensorValues(initialValue, outputSpec) "$resultValue = stablehlo.constant dense<$formattedValues> : $outputType" @@ -454,6 +465,71 @@ public class ConstantOperationsConverter : StableHloOperationConverter { ) } + /** + * Boxing-free twin of [tryMaterializeExternal] for [FloatArray] + * weights (the form `finalize` now produces). Serializes the + * primitive array straight to little-endian bytes — never building + * a `List<Float>` — so a 262153x640 embedding externalizes without + * the ~2.7GB boxing that OOMs the trace. Same emission/registration + * contract as the List variant. + */ + private fun tryMaterializeExternalFloats( + node: GraphNode, + outputSpec: TensorSpec?, + outputType: String, + values: FloatArray, + context: ConversionContext + ): ConversionResult?
{ + val policy = context.materializationPolicy + if (policy is ConstantMaterializationPolicy.InlineAlways) return null + if (outputSpec == null) return null + + val encoding = outputSpec.tensorEncoding + ?: TensorEncoding.Dense(bytesPerElement = bytesPerElement(outputSpec.dtype)) + val elementCount = elementCountFromShape(outputSpec.shape) + if (elementCount <= 0) return null + + val logicalBytes = encoding.physicalBytes(elementCount.toLong()) ?: return null + val scope = when (policy) { + is ConstantMaterializationPolicy.InlineAlways -> return null + is ConstantMaterializationPolicy.ExternalAlways -> policy.scope + is ConstantMaterializationPolicy.SizeThreshold -> { + if (logicalBytes < policy.bytes) return null + policy.scope + } + } + + val bytes = try { + floatArrayToLittleEndianBytes(values, outputSpec.dtype, elementCount) + } catch (e: IllegalArgumentException) { + context.emitComment( + "external materialization fell back to inline for ${node.id}: ${e.message}" + ) + return null + } + + val key = outputSpec.name.ifEmpty { node.id } + context.registerExternalParameter( + ExternalParameterRef( + scope = scope, + key = key, + encoding = encoding, + source = BufferHandle.Owned(bytes) + ) + ) + context.emitModuleDeclaration( + "util.global private @${key} = " + + "#flow.parameter.named<\"${scope}\"::\"${key}\"> : $outputType" + ) + val resultValue = context.nextTempValue() + val operation = "$resultValue = util.global.load @${key} : $outputType" + context.emitOperation(operation) + return ConversionResult.Success( + outputValueName = resultValue, + emittedOperations = listOf(operation) + ) + } + /** * Rough bytes-per-element for the default [TensorEncoding.Dense] * fallback when a spec does not carry an explicit encoding. 
diff --git a/skainet-io/skainet-io-iree-params/src/commonMain/kotlin/sk/ainet/io/irpa/IrpaWriter.kt b/skainet-io/skainet-io-iree-params/src/commonMain/kotlin/sk/ainet/io/irpa/IrpaWriter.kt index 796c9518..e3f6caed 100644 --- a/skainet-io/skainet-io-iree-params/src/commonMain/kotlin/sk/ainet/io/irpa/IrpaWriter.kt +++ b/skainet-io/skainet-io-iree-params/src/commonMain/kotlin/sk/ainet/io/irpa/IrpaWriter.kt @@ -236,15 +236,11 @@ public class IrpaWriter { } private fun writeByteArray(sink: Sink, data: ByteArray, offset: Int, length: Int) { - // Byte-at-a-time for the same reason noted below — and because - // under the sizes we see in practice for single-op values - // (tens to a few thousand bytes) the overhead is lost in the - // wider write cost. FileBacked paths use a chunked copy on - // their platform-specific side, which is where the byte - // volume is meaningful. - for (i in offset until offset + length) { - sink.writeByte(data[i]) - } + // Bulk range write. Owned/Borrowed buffers can be large for real + // LLM weights (a 262153x640 FP32 embedding = ~670MB); a + // byte-at-a-time loop there is pathological (hundreds of millions + // of calls). kotlinx-io's range write copies in one shot. + sink.write(data, offset, offset + length) } private fun writePadding(sink: Sink, bytes: Int) { From e1daab124d6c71687d9f6ddff216af2e343ca4a4 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 1 Jun 2026 20:03:05 +0200 Subject: [PATCH 10/10] fix(hlo attention): consume explicit SDPA mask operand AttentionOperationsConverter ignored operands[3] (the additive mask) and only masked when causal=true. Gemma sliding-window layers pass an explicit causal+window mask with causal=false, so those layers exported UNMASKED -> attended to future tokens -> A/B vs llama.cpp correct only at position 0. Now broadcast (trailing-aligned) the mask to the scores shape and add it before softmax; the built-in iota causal path remains for the no-mask case. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../AttentionOperationsConverter.kt | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt index 183e2175..9f35ad02 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/AttentionOperationsConverter.kt @@ -102,10 +102,30 @@ public class AttentionOperationsConverter : StableHloOperationConverter { "$scaled = stablehlo.multiply $scores, $scaleC : $scoresType", ) - // Causal mask: keep key_index <= query_index, set the rest to -inf - // before softmax (additive mask, built from iota row/col indices). + // Explicit additive mask (operands[3]) — e.g. a sliding-window+causal + // mask the caller built and passed with causal=false. It already + // encodes causality/window, so it takes priority over the built-in + // iota causal path. Broadcast (trailing-aligned) to the scores shape + // and add. Without this the masked layers run UNMASKED (attend to + // future tokens) — correct only at position 0. 
var softmaxIn = scaled - if (causal) { + val maskOperand = operands.getOrNull(3) + if (maskOperand != null) { + val maskShape = node.inputs.getOrNull(3)?.shape ?: scoresShape + val maskType = context.getValueType(maskOperand) ?: typeOf(maskShape) + val maskBc = if (maskShape == scoresShape) { + maskOperand + } else { + val mb = context.nextTempValue() + val offset = scoresShape.size - maskShape.size + val dims = maskShape.indices.joinToString(", ") { (it + offset).toString() } + ops += "$mb = stablehlo.broadcast_in_dim $maskOperand, dims = [$dims] : ($maskType) -> $scoresType" + mb + } + val masked = context.nextTempValue() + ops += "$masked = stablehlo.add $scaled, $maskBc : $scoresType" + softmaxIn = masked + } else if (causal) { val iotaQ = context.nextTempValue(); val iotaK = context.nextTempValue() val keep = context.nextTempValue(); val zeros = context.nextTempValue() val ninf = context.nextTempValue(); val maskAdd = context.nextTempValue()