Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
GROUP=sk.ainet.core
VERSION_NAME=0.27.0
VERSION_NAME=0.28.0
POM_DESCRIPTION=SKaiNET

POM_URL=https://github.com/SKaiNET-developers/skainet/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import sk.ainet.compile.hlo.ConversionContext
import sk.ainet.compile.hlo.ConversionResult
import sk.ainet.compile.hlo.StableHloOperationConverter
import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.ops.TensorSpec

/**
* Converter for shape manipulation operations.
Expand Down Expand Up @@ -97,7 +98,24 @@ public class ShapeOperationsConverter : StableHloOperationConverter {
?: node.inputs.getOrNull(i)?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?xf32>"
}
val operation = "$resultValue = stablehlo.concatenate $operandList, dim = $axis : ($operandTypes) -> $outputType"
// Compute the result type ourselves: the extent on `axis` is the SUM of the
// operands' extents there (the other axes match operand 0). Trusting the node's
// declared output spec mis-infers the concatenated axis for >2 operands (#667).
val inShapes = node.inputs.mapNotNull { it.shape }
val resultType =
if (inShapes.size == node.inputs.size && inShapes.size == operands.size &&
inShapes.isNotEmpty() && inShapes.all { it.size == inShapes[0].size } &&
axis in inShapes[0].indices
) {
val outShape = inShapes[0].toMutableList()
outShape[axis] = inShapes.sumOf { it[axis] }
context.getTypeMapper().mapTensorType(
TensorSpec("${node.id}_out", outShape, outputSpec?.dtype ?: node.inputs[0].dtype),
)
} else {
outputType
}
val operation = "$resultValue = stablehlo.concatenate $operandList, dim = $axis : ($operandTypes) -> $resultType"
context.emitOperation(operation)

return ConversionResult.Success(
Expand Down Expand Up @@ -297,33 +315,31 @@ public class ShapeOperationsConverter : StableHloOperationConverter {
}

val outputSpec = node.outputs.firstOrNull()
val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: "tensor<?xf32>"

// Get the new shape from parameters or output spec
val newShape = when {
outputSpec?.shape != null -> outputSpec.shape
node.operation.parameters.containsKey("shape") -> {
@Suppress("UNCHECKED_CAST")
node.operation.parameters["shape"] as? List<Int>
}
node.operation.parameters.containsKey("newShape") -> {
@Suppress("UNCHECKED_CAST")
node.operation.parameters["newShape"] as? List<Int>
}
else -> null
}


// Get the new shape from the output spec or any of the parameter aliases the
// tape records it under (`outputShape` is the key the reshape trace uses, #666).
@Suppress("UNCHECKED_CAST")
fun param(key: String): List<Int>? = node.operation.parameters[key] as? List<Int>
val newShape: List<Int>? = outputSpec?.shape
?: param("shape") ?: param("newShape") ?: param("outputShape")

if (newShape == null || newShape.isEmpty()) {
return ConversionResult.Failure(
"Reshape operation requires a target shape specification",
"Missing shape parameter for reshape node ${node.id}"
)
}

// Prefer the declared output type; otherwise build it from the resolved shape so
// a reshape whose target lives only in a parameter still emits a typed result.
val resultType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
?: context.getTypeMapper().mapTensorType(
TensorSpec("${node.id}_out", newShape, node.inputs.firstOrNull()?.dtype ?: "FP32"),
)

val inputType = resolveOperandType(operands[0], node, context)
val resultValue = context.nextTempValue()
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $outputType"
val operation = "$resultValue = stablehlo.reshape ${operands[0]} : ($inputType) -> $resultType"
context.emitOperation(operation)

return ConversionResult.Success(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package sk.ainet.compile.hlo

import sk.ainet.lang.graph.DefaultComputeGraph
import sk.ainet.lang.graph.GraphEdge
import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.ops.Operation
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.ValidationResult
import sk.ainet.lang.types.DType
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

/**
 * Regression coverage for two shape-lowering bugs in ShapeOperationsConverter:
 * - #667: `concatenate` with several inputs must emit the SUM of the operand extents on the concat axis.
 * - #666: `reshape` must still lower when its target shape is recorded only as an operation parameter.
 */
class ReshapeConcatShapeFixTest {

    /**
     * Creates a bare-bones [Operation] stub that exposes only a name, a type and a parameter map.
     * Its `execute` always throws [UnsupportedOperationException].
     */
    private fun stubOp(opName: String, opType: String, opParams: Map<String, Any> = emptyMap()): Operation =
        object : Operation {
            override val name: String = opName
            override val type: String = opType
            override val parameters: Map<String, Any> = opParams

            override fun <T : DType, V> execute(inputs: List<Tensor<T, V>>): List<Tensor<T, V>> {
                throw UnsupportedOperationException("marker")
            }

            override fun validateInputs(inputs: List<TensorSpec>): ValidationResult {
                return ValidationResult.Valid
            }

            override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
                return inputs.take(1)
            }

            override fun clone(newParameters: Map<String, Any>): Operation {
                return this
            }

            override fun serialize(): Map<String, Any> {
                return mapOf("name" to name, "type" to type)
            }
        }

    /** Builds a graph input node whose single FP32 output spec carries the node's own id as its name. */
    private fun inputNode(id: String, shape: List<Int>): GraphNode =
        GraphNode(id, stubOp("input", "input"), emptyList(), listOf(TensorSpec(id, shape, "FP32")))

    /** Converts [graph] with the extended StableHLO converter and returns the emitted module content. */
    private fun lowerToMlir(graph: DefaultComputeGraph, moduleName: String) =
        StableHloConverterFactory.createExtended().convert(graph, moduleName).content

    @Test
    fun multiInputConcat_sums_the_concatenated_axis() {
        val graph = DefaultComputeGraph()
        val sources = listOf(
            inputNode("a", listOf(1, 1, 8, 8)),
            inputNode("b", listOf(1, 4, 8, 8)),
            inputNode("c", listOf(1, 1, 8, 8)),
        )
        sources.forEach { graph.addNode(it) }

        val concatNode = GraphNode(
            id = "cat",
            operation = stubOp("concatenate", "shape", mapOf("axis" to 1)),
            inputs = listOf(
                TensorSpec("a", listOf(1, 1, 8, 8), "FP32"),
                TensorSpec("b", listOf(1, 4, 8, 8), "FP32"),
                TensorSpec("c", listOf(1, 1, 8, 8), "FP32"),
            ),
            // The declared output shape is intentionally WRONG: it repeats operand 0's extent on the axis.
            outputs = listOf(TensorSpec("y", listOf(1, 1, 8, 8), "FP32")),
        )
        graph.addNode(concatNode)

        // Wire source i (output port 0) into concat input slot i, producing edges e0, e1, e2 in order.
        sources.forEachIndexed { slot, source ->
            graph.addEdge(GraphEdge("e$slot", source, concatNode, 0, slot, source.outputs[0]))
        }

        val emitted = lowerToMlir(graph, "concat3")
        assertTrue(emitted.contains("stablehlo.concatenate"), "expected a concatenate op:\n$emitted")
        assertTrue(emitted.contains("-> tensor<1x6x8x8xf32>"), "concat result must sum the axis to 6:\n$emitted")
        assertFalse(emitted.contains("-> tensor<1x1x8x8xf32>"), "must not echo operand-0's axis extent:\n$emitted")
    }

    @Test
    fun reshape_with_shape_only_in_parameter_still_lowers() {
        val graph = DefaultComputeGraph()
        val source = inputNode("x", listOf(1, 12))
        graph.addNode(source)

        val reshapeNode = GraphNode(
            id = "r",
            operation = stubOp("reshape", "shape", mapOf("outputShape" to listOf(1, 3, 4))),
            inputs = listOf(TensorSpec("x", listOf(1, 12), "FP32")),
            // No output spec is declared, so the target shape is available only through the parameter.
            outputs = emptyList(),
        )
        graph.addNode(reshapeNode)
        graph.addEdge(GraphEdge("e0", source, reshapeNode, 0, 0, source.outputs[0]))

        val emitted = lowerToMlir(graph, "reshape1")
        assertTrue(emitted.contains("stablehlo.reshape"), "reshape must lower (not an empty module):\n$emitted")
        assertTrue(emitted.contains("tensor<1x3x4xf32>"), "reshape must carry the target shape:\n$emitted")
    }
}
Loading