diff --git a/gradle.properties b/gradle.properties index 1a99b287..3a8751f8 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,5 @@ GROUP=sk.ainet.core -VERSION_NAME=0.27.0 +VERSION_NAME=0.28.0 POM_DESCRIPTION=SKaiNET POM_URL=https://github.com/SKaiNET-developers/skainet/ diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt index 5f829ebf..a86f639f 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ShapeOperationsConverter.kt @@ -4,6 +4,7 @@ import sk.ainet.compile.hlo.ConversionContext import sk.ainet.compile.hlo.ConversionResult import sk.ainet.compile.hlo.StableHloOperationConverter import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.ops.TensorSpec /** * Converter for shape manipulation operations. @@ -97,7 +98,24 @@ public class ShapeOperationsConverter : StableHloOperationConverter { ?: node.inputs.getOrNull(i)?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" } - val operation = "$resultValue = stablehlo.concatenate $operandList, dim = $axis : ($operandTypes) -> $outputType" + // Compute the result type ourselves: the extent on `axis` is the SUM of the + // operands' extents there (the other axes match operand 0). Trusting the node's + // declared output spec mis-infers the concatenated axis for >2 operands (#667). + val inShapes = node.inputs.mapNotNull { it.shape } + val resultType = + if (inShapes.size == node.inputs.size && inShapes.size == operands.size && + inShapes.isNotEmpty() && inShapes.all { it.size == inShapes[0].size } && + axis in inShapes[0].indices + ) { + val outShape = inShapes[0].toMutableList() + outShape[axis] = inShapes.sumOf { it[axis] } + context.getTypeMapper().mapTensorType( + TensorSpec("${node.id}_out", outShape, outputSpec?.dtype ?: node.inputs[0].dtype), + ) + } else { + outputType + } + val operation = "$resultValue = stablehlo.concatenate $operandList, dim = $axis : ($operandTypes) -> $resultType" context.emitOperation(operation) return ConversionResult.Success( @@ -297,23 +315,14 @@ public class ShapeOperationsConverter : StableHloOperationConverter { } val outputSpec = node.outputs.firstOrNull() - val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } - ?: "tensor" - - // Get the new shape from parameters or output spec - val newShape = when { - outputSpec?.shape != null -> outputSpec.shape - node.operation.parameters.containsKey("shape") -> { - @Suppress("UNCHECKED_CAST") - node.operation.parameters["shape"] as? List - } - node.operation.parameters.containsKey("newShape") -> { - @Suppress("UNCHECKED_CAST") - node.operation.parameters["newShape"] as? List - } - else -> null - } - + + // Get the new shape from the output spec or any of the parameter aliases the + // tape records it under (`outputShape` is the key the reshape trace uses, #666). + @Suppress("UNCHECKED_CAST") + fun param(key: String): List? = node.operation.parameters[key] as? List + val newShape: List? = outputSpec?.shape + ?: param("shape") ?: param("newShape") ?: param("outputShape") + if (newShape == null || newShape.isEmpty()) { return ConversionResult.Failure( "Reshape operation requires a target shape specification", @@ -321,9 +330,16 @@ public class ShapeOperationsConverter : StableHloOperationConverter { ) } + // Prefer the declared output type; otherwise build it from the resolved shape so + // a reshape whose target lives only in a parameter still emits a typed result. + val resultType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: context.getTypeMapper().mapTensorType( + TensorSpec("${node.id}_out", newShape, node.inputs.firstOrNull()?.dtype ?: "FP32"), + ) + val inputType = resolveOperandType(operands[0], node, context) val resultValue = context.nextTempValue() - val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType" + val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $resultType" context.emitOperation(operation) return ConversionResult.Success( diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ReshapeConcatShapeFixTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ReshapeConcatShapeFixTest.kt new file mode 100644 index 00000000..a6447bcd --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ReshapeConcatShapeFixTest.kt @@ -0,0 +1,79 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.Operation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.ValidationResult +import sk.ainet.lang.types.DType +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * Regression tests for the ShapeOperationsConverter shape bugs: + * - #667: multi-input `concatenate` must SUM the operands' extents on the axis. + * - #666: `reshape` whose target shape lives only in a parameter must still lower. + */ +class ReshapeConcatShapeFixTest { + + private fun op(name: String, type: String, params: Map = emptyMap()): Operation = + object : Operation { + override val name: String = name + override val type: String = type + override val parameters: Map = params + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("marker") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = inputs.take(1) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf("name" to name, "type" to type) + } + + @Test + fun multiInputConcat_sums_the_concatenated_axis() { + val g = DefaultComputeGraph() + val a = GraphNode("a", op("input", "input"), emptyList(), listOf(TensorSpec("a", listOf(1, 1, 8, 8), "FP32"))) + val b = GraphNode("b", op("input", "input"), emptyList(), listOf(TensorSpec("b", listOf(1, 4, 8, 8), "FP32"))) + val c = GraphNode("c", op("input", "input"), emptyList(), listOf(TensorSpec("c", listOf(1, 1, 8, 8), "FP32"))) + g.addNode(a); g.addNode(b); g.addNode(c) + + val cat = GraphNode( + id = "cat", + operation = op("concatenate", "shape", mapOf("axis" to 1)), + inputs = listOf(TensorSpec("a", listOf(1, 1, 8, 8), "FP32"), TensorSpec("b", listOf(1, 4, 8, 8), "FP32"), TensorSpec("c", listOf(1, 1, 8, 8), "FP32")), + // Deliberately WRONG declared output shape (operand-0 extent on the axis): + outputs = listOf(TensorSpec("y", listOf(1, 1, 8, 8), "FP32")), + ) + g.addNode(cat) + g.addEdge(GraphEdge("e0", a, cat, 0, 0, a.outputs[0])) + g.addEdge(GraphEdge("e1", b, cat, 0, 1, b.outputs[0])) + g.addEdge(GraphEdge("e2", c, cat, 0, 2, c.outputs[0])) + + val mlir = StableHloConverterFactory.createExtended().convert(g, "concat3").content + assertTrue(mlir.contains("stablehlo.concatenate"), "expected a concatenate op:\n$mlir") + assertTrue(mlir.contains("-> tensor<1x6x8x8xf32>"), "concat result must sum the axis to 6:\n$mlir") + assertFalse(mlir.contains("-> tensor<1x1x8x8xf32>"), "must not echo operand-0's axis extent:\n$mlir") + } + + @Test + fun reshape_with_shape_only_in_parameter_still_lowers() { + val g = DefaultComputeGraph() + val x = GraphNode("x", op("input", "input"), emptyList(), listOf(TensorSpec("x", listOf(1, 12), "FP32"))) + g.addNode(x) + val r = GraphNode( + id = "r", + operation = op("reshape", "shape", mapOf("outputShape" to listOf(1, 3, 4))), + inputs = listOf(TensorSpec("x", listOf(1, 12), "FP32")), + outputs = emptyList(), // no declared output spec — target lives only in the param + ) + g.addNode(r) + g.addEdge(GraphEdge("e0", x, r, 0, 0, x.outputs[0])) + + val mlir = StableHloConverterFactory.createExtended().convert(g, "reshape1").content + assertTrue(mlir.contains("stablehlo.reshape"), "reshape must lower (not an empty module):\n$mlir") + assertTrue(mlir.contains("tensor<1x3x4xf32>"), "reshape must carry the target shape:\n$mlir") + } +}