From eab9b29d40b45a313a69fe8437fb16bc49d54249 Mon Sep 17 00:00:00 2001
From: Michal Harakal
Date: Sat, 6 Jun 2026 18:03:14 +0200
Subject: [PATCH] fix(dag+hlo): conv1d/gather/pooling/flatten output-spec + IREE-valid reduce_window (#675)

Follow-up to #674. Completes DAG-DSL StableHLO export so every conformance
model and op compiles with iree-compile.

inferDagOutputSpecs (skainet-lang-dag) gains shape rules for:
- conv1d : (N,Cin,L) * (Cout,_,K) -> (N, Cout, floor((L+2p - d(K-1) - 1)/s) + 1)
- gather : table[:axis] + indices.shape + table[axis+1:]
- maxpool/avgpool : windowed (N, C, Hout, Wout)
- flatten : collapse dims [startDim..endDim], preserving the leading batch dim
  (it was collapsing everything to rank-1, breaking the dense matmul in mnist-cnn)

conv2d already inferred via Conv2dOperation; conv1d was a GenericOperation that
fell back to echoing operand-0.

reduce_window emission (NeuralNetOperationsConverter): emit the IREE-parseable
generic region form
  %r = "stablehlo.reduce_window"(%in, %init) ({ ^bb0(%a, %b): ... })
instead of the pretty "applies over window dimensions = ..." form IREE rejects
("has no custom assembly form"). Full NCHW-rank window attributes; region-local
SSA names made unique per op so two pools in one function don't collide in the
validator; avg-pool divisor splatted to the output type (was a scalar-vs-tensor
mismatch).

MlirValidator: register region block-argument SSA defs (^bb0(%a, %b)) and every
"%x =" result on a line, so single-line region ops validate.

Verified end-to-end via an unsigned 0.28.1-SNAPSHOT on mavenLocal +
skainet-iree-conformance: 27/27 ops and 7/7 models (incl. mnist-cnn)
iree-compile to a vmfb. DagConvGatherPoolExportTest 5/5 green; hlo / dag /
lang-dag suites green.
--- .../sk/ainet/compile/hlo/MlirValidator.kt | 29 +++--- .../NeuralNetOperationsConverter.kt | 99 +++++++++++++------ .../hlo/NeuralNetOperationsConverterTest.kt | 6 +- .../hlo/DagConvGatherPoolExportTest.kt | 91 +++++++++++++++++ .../kotlin/sk/ainet/lang/dag/GraphDsl.kt | 61 ++++++++++++ 5 files changed, 243 insertions(+), 43 deletions(-) create mode 100644 skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/DagConvGatherPoolExportTest.kt diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt index afe13de9..8f295928 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/MlirValidator.kt @@ -112,20 +112,25 @@ public class MlirValidator { // not). if (trimmed.isEmpty() || trimmed.startsWith("//") || trimmed.startsWith("module")) continue - // Extract defined SSA values - if (trimmed.contains(" = ")) { - val parts = trimmed.split(" = ", limit = 2) - if (parts.size == 2) { - val valueName = parts[0].trim() - if (valueName.startsWith("%")) { - if (definedValues.contains(valueName)) { - errors.add("Line ${lineNum + 1}: SSA value $valueName redefined") - } - definedValues.add(valueName) - } + // Extract defined SSA values. A line may carry more than one result + // definition when an op with a region is emitted on a single line + // (e.g. `reduce_window … ({ ^bb0(%a, %b): %r = … })`), so register every + // `%name =` result, not just the leading assignment. 
+ Regex("(%[a-zA-Z0-9_]+)\\s*=").findAll(trimmed).forEach { m -> + val valueName = m.groupValues[1] + if (definedValues.contains(valueName)) { + errors.add("Line ${lineNum + 1}: SSA value $valueName redefined") } + definedValues.add(valueName) } - + + // Register block-argument definitions from region entry blocks + // (`^bb0(%lhs: tensor, %rhs: tensor): …`). These bind SSA + // values without a ` = `, so they must be collected separately. + Regex("\\^[a-zA-Z0-9_]*\\(([^)]*)\\)").findAll(trimmed).forEach { block -> + Regex("%[a-zA-Z0-9_]+").findAll(block.groupValues[1]).forEach { definedValues.add(it.value) } + } + // Extract used SSA values val usedInLine = extractUsedValues(trimmed) usedValues.addAll(usedInLine) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt index 2de8907b..85cd7374 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt @@ -184,11 +184,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { val padding = extractPadding(params) val resultValue = context.nextTempValue() - + val inputType = node.inputs.firstOrNull()?.let { context.getTypeMapper().mapTensorType(it) } ?: outputType + // Build StableHLO reduce_window operation for max pooling val operations = buildMaxPoolOperations( resultValue = resultValue, input = operands[0], + inputType = inputType, outputType = outputType, kernelSize = kernelSize, stride = stride, @@ -225,11 +227,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { val padding = extractPadding(params) val resultValue = context.nextTempValue() - + val inputType = 
node.inputs.firstOrNull()?.let { context.getTypeMapper().mapTensorType(it) } ?: outputType + // Build StableHLO reduce_window operation for average pooling val operations = buildAvgPoolOperations( resultValue = resultValue, input = operands[0], + inputType = inputType, outputType = outputType, kernelSize = kernelSize, stride = stride, @@ -691,64 +695,103 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { } } + /** The MLIR element type ("f32"/"f16"/…) parsed from a `tensor<…xT>` string. */ + private fun elementTypeOf(tensorType: String): String = + tensorType.substringAfterLast('x').substringBefore('>').ifBlank { "f32" } + + /** + * Emit a `reduce_window` in IREE's parseable **generic region** form. The pretty + * `… applies over window dimensions = …` form is rejected by IREE's StableHLO + * parser ("has no custom assembly form"), and its 2-element window only covered H/W; + * the generic form carries full NCHW-rank (`[1, 1, kH, kW]`) window attributes. (#675) + */ + private fun reduceWindowGeneric( + resultValue: String, + input: String, + inputType: String, + initValue: String, + reduceOp: String, + elem: String, + kernelSize: Pair, + stride: Pair, + padding: Pair, + outputType: String, + ): String { + val (kH, kW) = kernelSize + val (sH, sW) = stride + val (pH, pW) = padding + // Single line: MLIR treats newlines as whitespace, and the line-based MLIR + // validator only handles one op per line. The region body ops are separated + // by spaces, which the MLIR parser accepts. + // Region-local SSA names are derived from the (unique) result value so two + // pooling ops in one function don't collide in the flat validator (they are + // region-scoped in MLIR, but the validator tracks names globally). 
+ val t = resultValue.removePrefix("%") + return "$resultValue = \"stablehlo.reduce_window\"($input, $initValue) ({ " + + "^bb0(%lhs_$t: tensor<$elem>, %rhs_$t: tensor<$elem>): " + + "%out_$t = $reduceOp %lhs_$t, %rhs_$t : tensor<$elem> " + + "stablehlo.return %out_$t : tensor<$elem> " + + "}) {window_dimensions = array, " + + "window_strides = array, " + + "base_dilations = array, " + + "window_dilations = array, " + + "padding = dense<[[0, 0], [0, 0], [$pH, $pH], [$pW, $pW]]> : tensor<4x2xi64>} : " + + "($inputType, tensor<$elem>) -> $outputType" + } + private fun buildMaxPoolOperations( resultValue: String, input: String, + inputType: String, outputType: String, kernelSize: Pair, stride: Pair, padding: Pair, context: ConversionContext ): List { - // For max pooling, we need to create a negative infinity constant as the initial value + val elem = elementTypeOf(outputType) val initValue = context.nextTempValue() - val initConstant = "$initValue = stablehlo.constant dense<-3.4028235e+38> : tensor" - - val poolOp = "$resultValue = stablehlo.reduce_window($input, $initValue) " + - "applies stablehlo.maximum " + - "over window dimensions = [${kernelSize.first}, ${kernelSize.second}] " + - "stride = [${stride.first}, ${stride.second}] " + - "pad = [[${padding.first}, ${padding.first}], [${padding.second}, ${padding.second}]] : $outputType" - - // Emit operations through context + val initConstant = "$initValue = stablehlo.constant dense<-3.4028235e+38> : tensor<$elem>" + val poolOp = reduceWindowGeneric( + resultValue, input, inputType, initValue, "stablehlo.maximum", + elem, kernelSize, stride, padding, outputType, + ) context.emitOperation(initConstant) context.emitOperation(poolOp) - return listOf(initConstant, poolOp) } - + private fun buildAvgPoolOperations( resultValue: String, input: String, + inputType: String, outputType: String, kernelSize: Pair, stride: Pair, padding: Pair, context: ConversionContext ): List { - // Average pooling requires sum + division by 
kernel size + // Average pooling requires sum + division by kernel size. + val elem = elementTypeOf(outputType) val kernelArea = kernelSize.first * kernelSize.second val initZero = context.nextTempValue() val kernelAreaConst = context.nextTempValue() val sumResult = context.nextTempValue() - - val initConstant = "$initZero = stablehlo.constant dense<0.0> : tensor" - val areaConstant = "$kernelAreaConst = stablehlo.constant dense<$kernelArea.0> : tensor" - - val sumOp = "$sumResult = stablehlo.reduce_window($input, $initZero) " + - "applies stablehlo.add " + - "over window dimensions = [${kernelSize.first}, ${kernelSize.second}] " + - "stride = [${stride.first}, ${stride.second}] " + - "pad = [[${padding.first}, ${padding.first}], [${padding.second}, ${padding.second}]] : $outputType" - + + val initConstant = "$initZero = stablehlo.constant dense<0.0> : tensor<$elem>" + // Splat over the output type so the divide is element-type consistent (a scalar + // tensor divisor was a latent type mismatch). 
+ val areaConstant = "$kernelAreaConst = stablehlo.constant dense<$kernelArea.0> : $outputType" + val sumOp = reduceWindowGeneric( + sumResult, input, inputType, initZero, "stablehlo.add", + elem, kernelSize, stride, padding, outputType, + ) val divideOp = "$resultValue = stablehlo.divide $sumResult, $kernelAreaConst : $outputType" - - // Emit operations through context + context.emitOperation(initConstant) context.emitOperation(areaConstant) context.emitOperation(sumOp) context.emitOperation(divideOp) - return listOf(initConstant, areaConstant, sumOp, divideOp) } diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NeuralNetOperationsConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NeuralNetOperationsConverterTest.kt index ee39ef73..d65c1ee8 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NeuralNetOperationsConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NeuralNetOperationsConverterTest.kt @@ -73,9 +73,9 @@ class NeuralNetOperationsConverterTest { assertNotNull(module) assertContains(module.content, "stablehlo.reduce_window") - // Should contain kernel size and stride information - assertContains(module.content, "window dimensions") - assertContains(module.content, "stride") + // Generic region form (IREE-parseable): window_dimensions / window_strides attrs. 
+ assertContains(module.content, "window_dimensions") + assertContains(module.content, "window_strides") } private fun createConv2dGraph(): DefaultComputeGraph { diff --git a/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/DagConvGatherPoolExportTest.kt b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/DagConvGatherPoolExportTest.kt new file mode 100644 index 00000000..18713861 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/DagConvGatherPoolExportTest.kt @@ -0,0 +1,91 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.dag.avgPool2d +import sk.ainet.lang.dag.conv1d +import sk.ainet.lang.dag.dag +import sk.ainet.lang.dag.flatten +import sk.ainet.lang.dag.gather +import sk.ainet.lang.dag.maxPool2d +import sk.ainet.lang.graph.dsl.toComputeGraph +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int32 +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * Remaining post-#674 DAG-DSL export bugs (tracked under the #674 follow-up issue). + * + * #674 fixed reshape/matmul/concat output-spec inference. These ops still declare a + * result/return type that contradicts the value they produce (conv/gather: `inferDagOutputSpecs` + * has no shape rule for them; pooling: also emits `reduce_window` in a form IREE rejects). + * All RED on develop after #674; lock for the follow-up fix. + */ +class DagConvGatherPoolExportTest { + + private fun lower(name: String, build: sk.ainet.lang.dag.DagBuilder.() -> Unit): String = + StableHloConverterFactory.createExtended().convert(dag(build).toComputeGraph(), name).content + + @Test + fun conv1d_declares_inferred_output_channels_and_length() { + // input (1,3,8), weight (4,3,3) stride 1 pad 0 -> (1,4,6). Currently declared 1x3x8. 
+ val mlir = lower("op_conv1d") { + val x = input("x", TensorSpec("x", listOf(1, 3, 8), "FP32")) + val w = input("w", TensorSpec("w", listOf(4, 3, 3), "FP32")) + val b = input("b", TensorSpec("b", listOf(4), "FP32")) + output(conv1d(x, w, b, 1, 0, 1, 1)) + } + assertTrue(mlir.contains("-> tensor<1x4x6xf32>"), "conv1d result must be inferred 1x4x6:\n$mlir") + assertFalse(Regex("""return %\w+ : tensor<1x3x8xf32>""").containsMatchIn(mlir), "return must not echo the input shape:\n$mlir") + } + + @Test + fun gather_declares_inferred_rows() { + // table (8,4), 3 indices -> (3,4). Currently declared 8x4. + val mlir = lower("op_gather") { + val t = input("t", TensorSpec("t", listOf(8, 4), "FP32")) + val idx = input("idx", TensorSpec("idx", listOf(3), "INT32")) + output(gather(t, idx, 0)) + } + assertTrue(mlir.contains("-> tensor<3x4xf32>"), "gather result must be inferred 3x4:\n$mlir") + assertFalse(Regex("""return %\w+ : tensor<8x4xf32>""").containsMatchIn(mlir), "return must not echo the table shape:\n$mlir") + } + + @Test + fun maxpool2d_declares_pooled_shape_and_iree_valid_reduce_window() { + // input (1,3,8,8), 2x2 stride 2 -> (1,3,4,4). Currently declared 1x3x8x8. + val mlir = lower("op_maxpool2d") { + val x = input("x", TensorSpec("x", listOf(1, 3, 8, 8), "FP32")) + output(maxPool2d(x, 2 to 2, 2 to 2, 0 to 0)) + } + assertTrue(mlir.contains("tensor<1x3x4x4xf32>"), "maxpool output must be the pooled 1x3x4x4:\n$mlir") + // IREE's parser rejects the pretty `applies … over window` form; it needs the generic + // region-based reduce_window. Assert we are not emitting the rejected pretty form. + assertFalse( + Regex("""reduce_window\([^)]*\)\s+applies""").containsMatchIn(mlir), + "reduce_window must use the IREE-parseable generic form, not 'applies … over window':\n$mlir", + ) + } + + @Test + fun flatten_preserves_leading_batch_dim() { + // (1,16,7,7) flatten dims 1..3 -> (1, 784); must NOT collapse to rank-1 (784), + // which breaks a downstream dense matmul (mnist-cnn). 
+ val mlir = lower("op_flatten") { + val x = input("x", TensorSpec("x", listOf(1, 16, 7, 7), "FP32")) + output(flatten(x, 1, 3)) + } + assertTrue(mlir.contains("tensor<1x784xf32>"), "flatten must keep batch: (1,16,7,7)->(1,784):\n$mlir") + assertFalse(Regex("""-> tensor<784xf32>""").containsMatchIn(mlir), "flatten must not collapse the batch dim:\n$mlir") + } + + @Test + fun avgpool2d_declares_pooled_shape() { + val mlir = lower("op_avgpool2d") { + val x = input("x", TensorSpec("x", listOf(1, 3, 8, 8), "FP32")) + output(avgPool2d(x, 2 to 2, 2 to 2, 0 to 0, false)) + } + assertTrue(mlir.contains("tensor<1x3x4x4xf32>"), "avgpool output must be the pooled 1x3x4x4:\n$mlir") + } +} diff --git a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt index 07fff86f..34eb9cc4 100644 --- a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt +++ b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt @@ -196,6 +196,20 @@ public class DagBuilder { val target = reshapeTargetShape(operation) ?: return null return spec(target) } + "flatten" -> { + // Collapse dims [startDim..endDim] into one, preserving the others + // (notably the leading batch dim). Without this, flatten echoes operand-0 + // or collapses everything, so a downstream dense matmul mis-types. (#675) + val inS = input?.shape ?: return null + val rank = inS.size + val rawStart = operation.parameters["startDim"] as? Int ?: 1 + val rawEnd = operation.parameters["endDim"] as? 
Int ?: -1 + val start = if (rawStart < 0) rank + rawStart else rawStart + val end = if (rawEnd < 0) rank + rawEnd else rawEnd + if (start !in 0 until rank || end !in 0 until rank || start > end) return null + val collapsed = inS.subList(start, end + 1).fold(1) { a, b -> a * b } + return spec(inS.subList(0, start) + collapsed + inS.subList(end + 1, rank)) + } "matmul", "dot", "mm", "bmm", "batch_matmul" -> { val lhs = inputs.getOrNull(0)?.spec?.shape val rhs = inputs.getOrNull(1)?.spec?.shape @@ -215,10 +229,57 @@ public class DagBuilder { out[axis] = shapes.sumOf { it[axis] } return spec(out) } + "conv1d" -> { + // (N, Cin, L) * (Cout, Cin/groups, K) -> (N, Cout, Lout). conv2d already + // infers via Conv2dOperation; conv1d is a GenericOperation with no inference. (#675) + val inS = inputs.getOrNull(0)?.spec?.shape + val wS = inputs.getOrNull(1)?.spec?.shape + if (inS == null || wS == null || inS.size != 3 || wS.size != 3) return null + val stride = operation.parameters["stride"] as? Int ?: 1 + val pad = operation.parameters["padding"] as? Int ?: 0 + val dil = operation.parameters["dilation"] as? Int ?: 1 + return spec(listOf(inS[0], wS[0], windowedOutput(inS[2], wS[2], stride, pad, dil))) + } + "gather" -> { + // table[..axis..] gathered by `indices` -> table[:axis] ⊕ indices.shape ⊕ table[axis+1:]. (#675) + val table = inputs.getOrNull(0)?.spec?.shape + val idx = inputs.getOrNull(1)?.spec?.shape + if (table == null || idx == null || table.isEmpty()) return null + val rawAxis = operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int ?: -1 + val axis = if (rawAxis < 0) table.size + rawAxis else rawAxis + if (axis !in table.indices) return null + return spec(table.subList(0, axis) + idx + table.subList(axis + 1, table.size)) + } + "maxpool2d", "avgpool2d" -> { + // (N, C, H, W) windowed by kernel/stride/padding -> (N, C, Hout, Wout). 
(#675) + val inS = inputs.getOrNull(0)?.spec?.shape + if (inS == null || inS.size != 4) return null + val k = pairParam(operation, "kernel") ?: pairParam(operation, "kernelSize") ?: return null + val s = pairParam(operation, "stride") ?: (1 to 1) + val p = pairParam(operation, "padding") ?: (0 to 0) + return spec( + listOf( + inS[0], inS[1], + windowedOutput(inS[2], k.first, s.first, p.first, 1), + windowedOutput(inS[3], k.second, s.second, p.second, 1), + ), + ) + } } return null } + /** Windowed (conv/pool) output extent: floor((in + 2·pad − dilation·(k−1) − 1) / stride) + 1. */ + private fun windowedOutput(inDim: Int, k: Int, stride: Int, pad: Int, dilation: Int): Int = + (inDim + 2 * pad - dilation * (k - 1) - 1) / stride + 1 + + private fun pairParam(operation: Operation, key: String): Pair? = + (operation.parameters[key] as? Pair<*, *>)?.let { + val a = it.first as? Int + val b = it.second as? Int + if (a != null && b != null) a to b else null + } + /** Recover a reshape/view target shape from the op's `newShape`/`shape` parameter. */ private fun reshapeTargetShape(operation: Operation): List? { val raw = operation.parameters["newShape"]