From 1145e33d49e7b54d815078f4bd0f1ebb87888524 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Fri, 5 Jun 2026 14:35:10 +0200 Subject: [PATCH] Fix StableHLO DSL export for constants and reductions Rebind HLO generator sample inputs to the tracing context so tensor-bound forward passes emit real StableHLO ops. Preserve symbolic DAG constant values and lower them as embedded constants instead of function arguments. Also infer reduction output shapes so reduced dimensions are dropped before downstream shape ops. Closes #663 --- .../compile/hlo/generate/HloGenerator.kt | 8 ++- .../compile/hlo/generate/HloGeneratorTest.kt | 36 +++++++++++ .../kotlin/sk/ainet/lang/dag/GraphDsl.kt | 64 +++++++++++++++++-- .../sk/ainet/lang/dag/SymbolicDataDsl.kt | 26 +++++--- 4 files changed, 118 insertions(+), 16 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt b/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt index 4285bbd3..e3ec8a20 100644 --- a/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt +++ b/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt @@ -6,6 +6,7 @@ import sk.ainet.compile.hlo.StableHloModule import sk.ainet.lang.graph.DefaultGraphExecutionContext import sk.ainet.lang.model.Model import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.operators.bind import sk.ainet.lang.tensor.ops.VoidTensorOps import sk.ainet.lang.tape.toComputeGraph import sk.ainet.lang.types.DType @@ -32,16 +33,17 @@ public object HloGenerator { functionName: String = "main" ): StableHloModule { val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps()) + val traceInput = sampleInput.bind(ctx) - // Capture the sample input's tensor ref ID so we can mark it as a function argument + // Capture the rebound sample input's tensor ref ID so we can mark it as a function argument. @Suppress("UNCHECKED_CAST") - val inputRefId = ctx.session.refOf(sampleInput as sk.ainet.lang.tensor.Tensor<*, *>).id + val inputRefId = ctx.session.refOf(traceInput as sk.ainet.lang.tensor.Tensor<*, *>).id val (tape, _) = ctx.record { @Suppress("UNCHECKED_CAST") traceForwardPass( model as Model, Tensor>, - sampleInput as Tensor, + traceInput as Tensor, ctx ) } diff --git a/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorTest.kt b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorTest.kt index 2ed7eb60..7ab3c8b4 100644 --- a/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorTest.kt @@ -1,10 +1,17 @@ package sk.ainet.compile.hlo.generate import kotlinx.coroutines.test.runTest +import sk.ainet.compile.hlo.StableHloConverterFactory +import sk.ainet.lang.dag.dag +import sk.ainet.lang.dag.multiply +import sk.ainet.lang.dag.sum +import sk.ainet.lang.dag.unsqueeze import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.graph.dsl.toComputeGraph import sk.ainet.lang.model.compute.Rgb2GrayScale import sk.ainet.lang.model.compute.Rgb2GrayScaleMatMul import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.ops.TensorSpec import sk.ainet.lang.tensor.ops.VoidTensorOps import sk.ainet.lang.types.FP16 import sk.ainet.lang.types.FP32 @@ -48,9 +55,38 @@ class HloGeneratorTest { assertTrue(module.content.contains("module {"), "Expected 'module {' in MLIR output") assertTrue(module.content.contains("func.func"), "Expected 'func.func' in MLIR output") assertTrue(module.content.contains("@rgb2grayscale_matmul"), "Expected function name in MLIR output") + assertTrue(module.content.contains("stablehlo."), "Expected tensor-bound forward pass to emit StableHLO ops") assertTrue(module.content.length > 50, "Expected non-trivial MLIR output") } + @Test + fun testDagConstantsAreInlinedAndReductionDropsDimension() { + val h = 8 + val w = 8 + val program = dag { + val x = input("input", TensorSpec("input", listOf(1, 3, h, w), "FP16")) + val luma = constant("luma") { + fromArray(floatArrayOf(0.2989f, 0.5870f, 0.1140f), shape = listOf(1, 3, 1, 1)) + } + val weighted = multiply(x, luma) + val grayHW = sum(weighted, 1) + output(unsqueeze(grayHW, 1)) + } + + val mlir = StableHloConverterFactory.createExtended() + .convert(program.toComputeGraph(), "grayscale") + .content + + assertTrue( + mlir.contains("stablehlo.constant") && !Regex("""@grayscale\([^)]*,[^)]*\)""").containsMatchIn(mlir), + "luma constant should be baked into the StableHLO module:\n$mlir" + ) + assertTrue( + mlir.contains("tensor<1x1x${h}x${w}xf16>"), + "sum over the channel dimension should drop that dimension before unsqueeze:\n$mlir" + ) + } + @Test fun testGenerateDefaultFunctionName() = runTest { val model = Rgb2GrayScale() diff --git a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt index cdf13d79..8888f51f 100644 --- a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt +++ b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt @@ -1,5 +1,6 @@ package sk.ainet.lang.dag +import sk.ainet.lang.tensor.ops.GenericOperation import sk.ainet.lang.tensor.ops.InputOperation import sk.ainet.lang.tensor.ops.Operation import sk.ainet.lang.tensor.ops.TensorSpec @@ -80,6 +81,8 @@ public class DagBuilder { inputs: List>, nodeId: String ): List { + inferDagOutputSpecs(operation, inputs, nodeId)?.let { return it } + val inputSpecs = inputs.map { it.spec } val inferred = runCatching { operation.inferOutputs(inputSpecs) } .getOrElse { @@ -150,11 +153,11 @@ public class DagBuilder { } /** - * Declare a constant placeholder (treated like an input node). + * Declare a constant tensor with any available initializer data embedded in the graph. */ @DagDsl public fun constant(name: String, spec: TensorSpec): GraphValue { - val op = InputOperation(parameters = mapOf("kind" to "const")) + val op = GenericOperation("weight", constantParameters(spec), type = "constant") val recorded = recordNode("const", op, emptyList(), id = "const_$name").first() @Suppress("UNCHECKED_CAST") val typed = (recorded as GraphValue) @@ -163,8 +166,61 @@ public class DagBuilder { return updated } + private fun inferDagOutputSpecs( + operation: Operation, + inputs: List>, + nodeId: String + ): List? { + if (operation.name.lowercase() !in setOf("sum", "mean", "variance")) return null + val input = inputs.firstOrNull()?.spec ?: return null + val outputShape = reductionOutputShape(input.shape, operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int) + return listOf( + TensorSpec( + name = "${nodeId}_out0", + shape = outputShape, + dtype = input.dtype, + requiresGrad = input.requiresGrad + ) + ) + } + + private fun reductionOutputShape(shape: List?, dim: Int?): List? { + if (shape == null) return null + if (dim == null) return listOf(1) + + val actualDim = if (dim < 0) shape.size + dim else dim + require(actualDim in shape.indices) { + "Reduction dimension $dim is out of bounds for tensor rank ${shape.size}" + } + + val reduced = shape.filterIndexed { index, _ -> index != actualDim } + return reduced.ifEmpty { listOf(1) } + } + + private fun constantParameters(spec: TensorSpec): Map { + val params = mutableMapOf("trainable" to false) + when (spec.metadata["init"] as? String) { + "fromArray" -> { + val values = spec.metadata["values"] as? FloatArray + if (values != null) params["initial_value"] = values + } + "fromIntArray" -> { + val values = spec.metadata["values"] as? IntArray + if (values != null) params["initial_value"] = values.toList() + } + "ones" -> params["initial_value"] = 1.0f + "zeros" -> params["initial_value"] = 0.0f + else -> { + if ((spec.metadata["init"] as? String)?.startsWith("full(") == true) { + params["initial_value"] = spec.metadata["value"] as? Number ?: 0.0f + } + } + } + return params + } + /** - * Parameter helper that reuses a symbolic, allocation-free data DSL to declare shape/dtype. + * Parameter helper that reuses the symbolic data DSL to declare shape/dtype. * * Example: * ``` @@ -182,7 +238,7 @@ public class DagBuilder { } /** - * Constant helper that reuses a symbolic, allocation-free data DSL to declare shape/dtype. + * Constant helper that reuses the symbolic data DSL to declare shape/dtype and initializer data. */ @DagDsl public inline fun constant( diff --git a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/SymbolicDataDsl.kt b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/SymbolicDataDsl.kt index 37293848..8426dd5d 100644 --- a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/SymbolicDataDsl.kt +++ b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/SymbolicDataDsl.kt @@ -8,8 +8,8 @@ import kotlin.reflect.KClass internal fun dtypeName(kClass: KClass): String = kClass.simpleName ?: kClass.toString() /** - * Lightweight, allocation-free builder that mimics the shape/initializer style of the data DSL - * but produces only [TensorSpec] metadata for the DAG DSL. + * Lightweight builder that mimics the shape/initializer style of the data DSL + * and produces [TensorSpec] metadata for the DAG DSL. */ @DagDsl public class SymbolicTensorBuilder( @@ -42,7 +42,7 @@ public class SymbolicTensorBuilder( } /** - * Infer shape from a flat float array. Stores initializer metadata only; no allocation is performed. + * Infer shape from a flat float array and retain the values for constant materialization. */ @DagDsl public fun fromArray(values: FloatArray, shape: List? = null): TensorSpec { @@ -51,12 +51,12 @@ public class SymbolicTensorBuilder( name = defaultName, shape = inferredShape, dtype = dtypeName, - metadata = mapOf("init" to "fromArray", "size" to values.size) + metadata = mapOf("init" to "fromArray", "size" to values.size, "values" to values.copyOf()) ) } /** - * Infer shape from a flat int array. Stores initializer metadata only; no allocation is performed. + * Infer shape from a flat int array and retain the values for constant materialization. */ @DagDsl public fun fromArray(values: IntArray, shape: List? = null): TensorSpec { @@ -65,7 +65,7 @@ public class SymbolicTensorBuilder( name = defaultName, shape = inferredShape, dtype = dtypeName, - metadata = mapOf("init" to "fromIntArray", "size" to values.size) + metadata = mapOf("init" to "fromIntArray", "size" to values.size, "values" to values.copyOf()) ) } } @@ -76,11 +76,19 @@ public class SymbolicTensorBuilder( @DagDsl public class SymbolicInit { private var kind: String = "unspecified" + private var value: Number? = null @DagDsl public fun ones() { kind = "ones" } @DagDsl public fun zeros() { kind = "zeros" } - @DagDsl public fun full(value: Number) { kind = "full($value)" } + @DagDsl public fun full(value: Number) { + kind = "full($value)" + this.value = value + } - internal fun metadata(): Map = - if (kind == "unspecified") emptyMap() else mapOf("init" to kind) + internal fun metadata(): Map { + if (kind == "unspecified") return emptyMap() + val metadata = mutableMapOf("init" to kind) + value?.let { metadata["value"] = it } + return metadata + } }