Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package sk.ainet.compile.hlo

import sk.ainet.lang.dag.concat
import sk.ainet.lang.dag.dag
import sk.ainet.lang.dag.matmul
import sk.ainet.lang.dag.reshape
import sk.ainet.lang.graph.dsl.toComputeGraph
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.types.FP32
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

/**
* End-to-end regression tests that exercise the REAL `sk.ainet.lang.dag` DSL path
* (`dag { … }.toComputeGraph()` → extended converter), exactly as the conformance
* harness builds its op modules.
*
* Distinct from [ReshapeConcatShapeFixTest], which constructs synthetic `GraphNode`s
* directly: that test passes while these fail, because the bug lives in how the DAG
* DSL records a shape-changing op's output spec — not in the converter's shape math.
*
* Each asserts the emitted module is IREE-compilable shape-wise: the op lowers AND the
* declared result/return type matches the value it actually produces.
*/
class DagShapeExportConformanceTest {

private fun lower(name: String, build: sk.ainet.lang.dag.DagBuilder.() -> Unit): String =
StableHloConverterFactory.createExtended().convert(dag(build).toComputeGraph(), name).content

@Test
fun reshape_via_dag_dsl_lowers_with_target_shape() {
// reshape (1,4) -> (2,2). Harness: OpsModel.reshapeMlir().
val mlir = lower("op_reshape") {
val a = input<FP32>("a", TensorSpec("a", listOf(1, 4), "FP32"))
output(reshape(a, Shape(intArrayOf(2, 2))))
}
assertFalse(
mlir.contains("Missing shape parameter") || mlir.contains("requires a target shape"),
"reshape dropped its target shape — empty/invalid module:\n$mlir",
)
assertTrue(mlir.contains("stablehlo.reshape"), "reshape must lower:\n$mlir")
assertTrue(mlir.contains("tensor<2x2xf32>"), "reshape must carry target shape 2x2:\n$mlir")
}

@Test
fun concat_via_dag_dsl_propagates_summed_axis_to_return() {
// concat([(1,4),(1,4)], dim=1) -> (1,8). Harness: OpsModel.concatMlir().
val mlir = lower("op_concat") {
val a = input<FP32>("a", TensorSpec("a", listOf(1, 4), "FP32"))
val b = input<FP32>("b", TensorSpec("b", listOf(1, 4), "FP32"))
output(concat(listOf(a, b), 1))
}
assertTrue(mlir.contains("-> tensor<1x8xf32>"), "concat op must type the axis-sum 1x8:\n$mlir")
// The function return must agree with the value it returns (else iree-compile rejects it).
assertFalse(
Regex("""return %\w+ : tensor<1x4xf32>""").containsMatchIn(mlir),
"function return type still 1x4 but the concat value is 1x8 — type mismatch:\n$mlir",
)
}

@Test
fun matmul_via_dag_dsl_declares_inferred_result_shape() {
// (1,4)·(4,3) -> (1,3). Harness: OpsModel.matmulMlir().
// dot_general contracts dim 1 x 0, so the result is 1x3 — but the export
// declares the result/return as 1x4 (echoes operand-0), which iree-compile
// rejects: "inferred shape '[1,3]' is incompatible with return type tensor<1x4xf32>".
val mlir = lower("op_matmul") {
val a = input<FP32>("a", TensorSpec("a", listOf(1, 4), "FP32"))
val w = input<FP32>("w", TensorSpec("w", listOf(4, 3), "FP32"))
output(matmul(a, w))
}
assertTrue(mlir.contains("stablehlo.dot_general") || mlir.contains("stablehlo.dot"), "matmul must lower:\n$mlir")
assertTrue(mlir.contains("-> tensor<1x3xf32>"), "matmul result must be the inferred 1x3:\n$mlir")
assertFalse(
Regex("""return %\w+ : tensor<1x4xf32>""").containsMatchIn(mlir),
"function return type still 1x4 but the matmul value is 1x3 — type mismatch:\n$mlir",
)
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package sk.ainet.lang.dag

import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.ops.GenericOperation
import sk.ainet.lang.tensor.ops.InputOperation
import sk.ainet.lang.tensor.ops.Operation
Expand Down Expand Up @@ -171,17 +172,65 @@ public class DagBuilder {
inputs: List<GraphValue<*>>,
nodeId: String
): List<TensorSpec>? {
if (operation.name.lowercase() !in setOf("sum", "mean", "variance")) return null
val input = inputs.firstOrNull()?.spec ?: return null
val outputShape = reductionOutputShape(input.shape, operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int)
return listOf(
val input = inputs.firstOrNull()?.spec
fun spec(shape: List<Int>?): List<TensorSpec> = listOf(
TensorSpec(
name = "${nodeId}_out0",
shape = outputShape,
dtype = input.dtype,
requiresGrad = input.requiresGrad
)
shape = shape,
dtype = input?.dtype ?: "unknown",
requiresGrad = input?.requiresGrad ?: false,
),
)

// Shape-changing ops whose output extent differs from operand-0. Without these,
// ensureOutputSpecs falls back to echoing operand-0's shape, producing modules
// whose declared result/return type contradicts the op's real output (the value
// iree-compile actually sees) — e.g. matmul (1,4)x(4,3) declared as 1x4 not 1x3,
// concat summed axis lost, reshape target dropped. (SKaiNET#673)
when (operation.name.lowercase()) {
"sum", "mean", "variance" -> {
input ?: return null
return spec(reductionOutputShape(input.shape, operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int))
}
"reshape", "view" -> {
val target = reshapeTargetShape(operation) ?: return null
return spec(target)
}
"matmul", "dot", "mm", "bmm", "batch_matmul" -> {
val lhs = inputs.getOrNull(0)?.spec?.shape
val rhs = inputs.getOrNull(1)?.spec?.shape
if (lhs == null || rhs == null || lhs.size < 2 || rhs.size < 2) return null
// (..., M, K) @ (..., K, N) -> (..., M, N)
return spec(lhs.dropLast(1) + rhs.last())
}
"concat", "concatenate", "cat" -> {
val shapes = inputs.mapNotNull { it.spec.shape }
if (shapes.size != inputs.size || shapes.isEmpty()) return null
if (shapes.any { it.size != shapes[0].size }) return null
val rank = shapes[0].size
val rawAxis = operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int ?: return null
val axis = if (rawAxis < 0) rank + rawAxis else rawAxis
if (axis !in 0 until rank) return null
val out = shapes[0].toMutableList()
out[axis] = shapes.sumOf { it[axis] }
return spec(out)
}
}
return null
}

/** Recover a reshape/view target shape from the op's `newShape`/`shape` parameter. */
private fun reshapeTargetShape(operation: Operation): List<Int>? {
val raw = operation.parameters["newShape"]
?: operation.parameters["shape"]
?: operation.parameters["outputShape"]
?: return null
return when (raw) {
is Shape -> raw.dimensions.toList()
is IntArray -> raw.toList()
is List<*> -> raw.filterIsInstance<Int>().takeIf { it.size == raw.size }
else -> null
}
}

private fun reductionOutputShape(shape: List<Int>?, dim: Int?): List<Int>? {
Expand Down
Loading