diff --git a/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/DagShapeExportConformanceTest.kt b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/DagShapeExportConformanceTest.kt new file mode 100644 index 00000000..f7dc02bd --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/DagShapeExportConformanceTest.kt @@ -0,0 +1,81 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.dag.concat +import sk.ainet.lang.dag.dag +import sk.ainet.lang.dag.matmul +import sk.ainet.lang.dag.reshape +import sk.ainet.lang.graph.dsl.toComputeGraph +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.types.FP32 +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * End-to-end regression tests that exercise the REAL `sk.ainet.lang.dag` DSL path + * (`dag { … }.toComputeGraph()` → extended converter), exactly as the conformance + * harness builds its op modules. + * + * Distinct from [ReshapeConcatShapeFixTest], which constructs synthetic `GraphNode`s + * directly: that test passes while these fail, because the bug lives in how the DAG + * DSL records a shape-changing op's output spec — not in the converter's shape math. + * + * Each asserts the emitted module is IREE-compilable shape-wise: the op lowers AND the + * declared result/return type matches the value it actually produces. + */ +class DagShapeExportConformanceTest { + + private fun lower(name: String, build: sk.ainet.lang.dag.DagBuilder.() -> Unit): String = + StableHloConverterFactory.createExtended().convert(dag(build).toComputeGraph(), name).content + + @Test + fun reshape_via_dag_dsl_lowers_with_target_shape() { + // reshape (1,4) -> (2,2). Harness: OpsModel.reshapeMlir(). + val mlir = lower("op_reshape") { + val a = input("a", TensorSpec("a", listOf(1, 4), "FP32")) + output(reshape(a, Shape(intArrayOf(2, 2)))) + } + assertFalse( + mlir.contains("Missing shape parameter") || mlir.contains("requires a target shape"), + "reshape dropped its target shape — empty/invalid module:\n$mlir", + ) + assertTrue(mlir.contains("stablehlo.reshape"), "reshape must lower:\n$mlir") + assertTrue(mlir.contains("tensor<2x2xf32>"), "reshape must carry target shape 2x2:\n$mlir") + } + + @Test + fun concat_via_dag_dsl_propagates_summed_axis_to_return() { + // concat([(1,4),(1,4)], dim=1) -> (1,8). Harness: OpsModel.concatMlir(). + val mlir = lower("op_concat") { + val a = input("a", TensorSpec("a", listOf(1, 4), "FP32")) + val b = input("b", TensorSpec("b", listOf(1, 4), "FP32")) + output(concat(listOf(a, b), 1)) + } + assertTrue(mlir.contains("-> tensor<1x8xf32>"), "concat op must type the axis-sum 1x8:\n$mlir") + // The function return must agree with the value it returns (else iree-compile rejects it). + assertFalse( + Regex("""return %\w+ : tensor<1x4xf32>""").containsMatchIn(mlir), + "function return type still 1x4 but the concat value is 1x8 — type mismatch:\n$mlir", + ) + } + + @Test + fun matmul_via_dag_dsl_declares_inferred_result_shape() { + // (1,4)·(4,3) -> (1,3). Harness: OpsModel.matmulMlir(). + // dot_general contracts dim 1 x 0, so the result is 1x3 — but the export + // declares the result/return as 1x4 (echoes operand-0), which iree-compile + // rejects: "inferred shape '[1,3]' is incompatible with return type tensor<1x4xf32>". + val mlir = lower("op_matmul") { + val a = input("a", TensorSpec("a", listOf(1, 4), "FP32")) + val w = input("w", TensorSpec("w", listOf(4, 3), "FP32")) + output(matmul(a, w)) + } + assertTrue(mlir.contains("stablehlo.dot_general") || mlir.contains("stablehlo.dot"), "matmul must lower:\n$mlir") + assertTrue(mlir.contains("-> tensor<1x3xf32>"), "matmul result must be the inferred 1x3:\n$mlir") + assertFalse( + Regex("""return %\w+ : tensor<1x4xf32>""").containsMatchIn(mlir), + "function return type still 1x4 but the matmul value is 1x3 — type mismatch:\n$mlir", + ) + } +} diff --git a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt index 8888f51f..07fff86f 100644 --- a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt +++ b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt @@ -1,5 +1,6 @@ package sk.ainet.lang.dag +import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.ops.GenericOperation import sk.ainet.lang.tensor.ops.InputOperation import sk.ainet.lang.tensor.ops.Operation @@ -171,17 +172,65 @@ public class DagBuilder { inputs: List>, nodeId: String ): List? { - if (operation.name.lowercase() !in setOf("sum", "mean", "variance")) return null - val input = inputs.firstOrNull()?.spec ?: return null - val outputShape = reductionOutputShape(input.shape, operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int) - return listOf( + val input = inputs.firstOrNull()?.spec + fun spec(shape: List?): List = listOf( TensorSpec( name = "${nodeId}_out0", - shape = outputShape, - dtype = input.dtype, - requiresGrad = input.requiresGrad - ) + shape = shape, + dtype = input?.dtype ?: "unknown", + requiresGrad = input?.requiresGrad ?: false, + ), ) + + // Shape-changing ops whose output extent differs from operand-0. Without these, + // ensureOutputSpecs falls back to echoing operand-0's shape, producing modules + // whose declared result/return type contradicts the op's real output (the value + // iree-compile actually sees) — e.g. matmul (1,4)x(4,3) declared as 1x4 not 1x3, + // concat summed axis lost, reshape target dropped. (SKaiNET#673) + when (operation.name.lowercase()) { + "sum", "mean", "variance" -> { + input ?: return null + return spec(reductionOutputShape(input.shape, operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int)) + } + "reshape", "view" -> { + val target = reshapeTargetShape(operation) ?: return null + return spec(target) + } + "matmul", "dot", "mm", "bmm", "batch_matmul" -> { + val lhs = inputs.getOrNull(0)?.spec?.shape + val rhs = inputs.getOrNull(1)?.spec?.shape + if (lhs == null || rhs == null || lhs.size < 2 || rhs.size < 2) return null + // (..., M, K) @ (..., K, N) -> (..., M, N) + return spec(lhs.dropLast(1) + rhs.last()) + } + "concat", "concatenate", "cat" -> { + val shapes = inputs.mapNotNull { it.spec.shape } + if (shapes.size != inputs.size || shapes.isEmpty()) return null + if (shapes.any { it.size != shapes[0].size }) return null + val rank = shapes[0].size + val rawAxis = operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int ?: return null + val axis = if (rawAxis < 0) rank + rawAxis else rawAxis + if (axis !in 0 until rank) return null + val out = shapes[0].toMutableList() + out[axis] = shapes.sumOf { it[axis] } + return spec(out) + } + } + return null + } + + /** Recover a reshape/view target shape from the op's `newShape`/`shape` parameter. */ + private fun reshapeTargetShape(operation: Operation): List? { + val raw = operation.parameters["newShape"] + ?: operation.parameters["shape"] + ?: operation.parameters["outputShape"] + ?: return null + return when (raw) { + is Shape -> raw.dimensions.toList() + is IntArray -> raw.toList() + is List<*> -> raw.filterIsInstance().takeIf { it.size == raw.size } + else -> null + } } private fun reductionOutputShape(shape: List?, dim: Int?): List? {