diff --git a/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt b/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt index 4285bbd3..e3ec8a20 100644 --- a/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt +++ b/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt @@ -6,6 +6,7 @@ import sk.ainet.compile.hlo.StableHloModule import sk.ainet.lang.graph.DefaultGraphExecutionContext import sk.ainet.lang.model.Model import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.operators.bind import sk.ainet.lang.tensor.ops.VoidTensorOps import sk.ainet.lang.tape.toComputeGraph import sk.ainet.lang.types.DType @@ -32,16 +33,17 @@ public object HloGenerator { functionName: String = "main" ): StableHloModule { val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps()) + val traceInput = sampleInput.bind(ctx) - // Capture the sample input's tensor ref ID so we can mark it as a function argument + // Capture the rebound sample input's tensor ref ID so we can mark it as a function argument. @Suppress("UNCHECKED_CAST") - val inputRefId = ctx.session.refOf(sampleInput as sk.ainet.lang.tensor.Tensor<*, *>).id + val inputRefId = ctx.session.refOf(traceInput as sk.ainet.lang.tensor.Tensor<*, *>).id val (tape, _) = ctx.record { @Suppress("UNCHECKED_CAST") traceForwardPass( model as Model, Tensor>, - sampleInput as Tensor, + traceInput as Tensor, ctx ) } diff --git a/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorTest.kt b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorTest.kt index 2ed7eb60..7ab3c8b4 100644 --- a/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/jvmTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorTest.kt @@ -1,10 +1,17 @@ package sk.ainet.compile.hlo.generate import kotlinx.coroutines.test.runTest +import sk.ainet.compile.hlo.StableHloConverterFactory +import sk.ainet.lang.dag.dag +import sk.ainet.lang.dag.multiply +import sk.ainet.lang.dag.sum +import sk.ainet.lang.dag.unsqueeze import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.graph.dsl.toComputeGraph import sk.ainet.lang.model.compute.Rgb2GrayScale import sk.ainet.lang.model.compute.Rgb2GrayScaleMatMul import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.ops.TensorSpec import sk.ainet.lang.tensor.ops.VoidTensorOps import sk.ainet.lang.types.FP16 import sk.ainet.lang.types.FP32 @@ -48,9 +55,38 @@ class HloGeneratorTest { assertTrue(module.content.contains("module {"), "Expected 'module {' in MLIR output") assertTrue(module.content.contains("func.func"), "Expected 'func.func' in MLIR output") assertTrue(module.content.contains("@rgb2grayscale_matmul"), "Expected function name in MLIR output") + assertTrue(module.content.contains("stablehlo."), "Expected tensor-bound forward pass to emit StableHLO ops") assertTrue(module.content.length > 50, "Expected non-trivial MLIR output") } + @Test + fun testDagConstantsAreInlinedAndReductionDropsDimension() { + val h = 8 + val w = 8 + val program = dag { + val x = input("input", TensorSpec("input", listOf(1, 3, h, w), "FP16")) + val luma = constant("luma") { + fromArray(floatArrayOf(0.2989f, 0.5870f, 0.1140f), shape = listOf(1, 3, 1, 1)) + } + val weighted = multiply(x, luma) + val grayHW = sum(weighted, 1) + output(unsqueeze(grayHW, 1)) + } + + val mlir = StableHloConverterFactory.createExtended() + .convert(program.toComputeGraph(), "grayscale") + .content + + assertTrue( + mlir.contains("stablehlo.constant") && !Regex("""@grayscale\([^)]*,[^)]*\)""").containsMatchIn(mlir), + "luma constant should be baked into the StableHLO module:\n$mlir" + ) + assertTrue( + mlir.contains("tensor<1x1x${h}x${w}xf16>"), + "sum over the channel dimension should drop that dimension before unsqueeze:\n$mlir" + ) + } + @Test fun testGenerateDefaultFunctionName() = runTest { val model = Rgb2GrayScale() diff --git a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt index cdf13d79..8888f51f 100644 --- a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt +++ b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/GraphDsl.kt @@ -1,5 +1,6 @@ package sk.ainet.lang.dag +import sk.ainet.lang.tensor.ops.GenericOperation import sk.ainet.lang.tensor.ops.InputOperation import sk.ainet.lang.tensor.ops.Operation import sk.ainet.lang.tensor.ops.TensorSpec @@ -80,6 +81,8 @@ public class DagBuilder { inputs: List>, nodeId: String ): List { + inferDagOutputSpecs(operation, inputs, nodeId)?.let { return it } + val inputSpecs = inputs.map { it.spec } val inferred = runCatching { operation.inferOutputs(inputSpecs) } .getOrElse { @@ -150,11 +153,11 @@ public class DagBuilder { } /** - * Declare a constant placeholder (treated like an input node). + * Declare a constant tensor with any available initializer data embedded in the graph. */ @DagDsl public fun constant(name: String, spec: TensorSpec): GraphValue { - val op = InputOperation(parameters = mapOf("kind" to "const")) + val op = GenericOperation("weight", constantParameters(spec), type = "constant") val recorded = recordNode("const", op, emptyList(), id = "const_$name").first() @Suppress("UNCHECKED_CAST") val typed = (recorded as GraphValue) @@ -163,8 +166,61 @@ public class DagBuilder { return updated } + private fun inferDagOutputSpecs( + operation: Operation, + inputs: List>, + nodeId: String + ): List? { + if (operation.name.lowercase() !in setOf("sum", "mean", "variance")) return null + val input = inputs.firstOrNull()?.spec ?: return null + val outputShape = reductionOutputShape(input.shape, operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int) + return listOf( + TensorSpec( + name = "${nodeId}_out0", + shape = outputShape, + dtype = input.dtype, + requiresGrad = input.requiresGrad + ) + ) + } + + private fun reductionOutputShape(shape: List?, dim: Int?): List? { + if (shape == null) return null + if (dim == null) return listOf(1) + + val actualDim = if (dim < 0) shape.size + dim else dim + require(actualDim in shape.indices) { + "Reduction dimension $dim is out of bounds for tensor rank ${shape.size}" + } + + val reduced = shape.filterIndexed { index, _ -> index != actualDim } + return reduced.ifEmpty { listOf(1) } + } + + private fun constantParameters(spec: TensorSpec): Map { + val params = mutableMapOf("trainable" to false) + when (spec.metadata["init"] as? String) { + "fromArray" -> { + val values = spec.metadata["values"] as? FloatArray + if (values != null) params["initial_value"] = values + } + "fromIntArray" -> { + val values = spec.metadata["values"] as? IntArray + if (values != null) params["initial_value"] = values.toList() + } + "ones" -> params["initial_value"] = 1.0f + "zeros" -> params["initial_value"] = 0.0f + else -> { + if ((spec.metadata["init"] as? String)?.startsWith("full(") == true) { + params["initial_value"] = spec.metadata["value"] as? Number ?: 0.0f + } + } + } + return params + } + /** - * Parameter helper that reuses a symbolic, allocation-free data DSL to declare shape/dtype. + * Parameter helper that reuses the symbolic data DSL to declare shape/dtype. * * Example: * ``` @@ -182,7 +238,7 @@ public class DagBuilder { } /** - * Constant helper that reuses a symbolic, allocation-free data DSL to declare shape/dtype. + * Constant helper that reuses the symbolic data DSL to declare shape/dtype and initializer data. */ @DagDsl public inline fun constant( diff --git a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/SymbolicDataDsl.kt b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/SymbolicDataDsl.kt index 37293848..8426dd5d 100644 --- a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/SymbolicDataDsl.kt +++ b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/SymbolicDataDsl.kt @@ -8,8 +8,8 @@ import kotlin.reflect.KClass internal fun dtypeName(kClass: KClass): String = kClass.simpleName ?: kClass.toString() /** - * Lightweight, allocation-free builder that mimics the shape/initializer style of the data DSL - * but produces only [TensorSpec] metadata for the DAG DSL. + * Lightweight builder that mimics the shape/initializer style of the data DSL + * and produces [TensorSpec] metadata for the DAG DSL. */ @DagDsl public class SymbolicTensorBuilder( @@ -42,7 +42,7 @@ public class SymbolicTensorBuilder( } /** - * Infer shape from a flat float array. Stores initializer metadata only; no allocation is performed. + * Infer shape from a flat float array and retain the values for constant materialization. */ @DagDsl public fun fromArray(values: FloatArray, shape: List? = null): TensorSpec { @@ -51,12 +51,12 @@ public class SymbolicTensorBuilder( name = defaultName, shape = inferredShape, dtype = dtypeName, - metadata = mapOf("init" to "fromArray", "size" to values.size) + metadata = mapOf("init" to "fromArray", "size" to values.size, "values" to values.copyOf()) ) } /** - * Infer shape from a flat int array. Stores initializer metadata only; no allocation is performed. + * Infer shape from a flat int array and retain the values for constant materialization. */ @DagDsl public fun fromArray(values: IntArray, shape: List? = null): TensorSpec { @@ -65,7 +65,7 @@ public class SymbolicTensorBuilder( name = defaultName, shape = inferredShape, dtype = dtypeName, - metadata = mapOf("init" to "fromIntArray", "size" to values.size) + metadata = mapOf("init" to "fromIntArray", "size" to values.size, "values" to values.copyOf()) ) } } @@ -76,11 +76,19 @@ public class SymbolicTensorBuilder( @DagDsl public class SymbolicInit { private var kind: String = "unspecified" + private var value: Number? = null @DagDsl public fun ones() { kind = "ones" } @DagDsl public fun zeros() { kind = "zeros" } - @DagDsl public fun full(value: Number) { kind = "full($value)" } + @DagDsl public fun full(value: Number) { + kind = "full($value)" + this.value = value + } - internal fun metadata(): Map = - if (kind == "unspecified") emptyMap() else mapOf("init" to kind) + internal fun metadata(): Map { + if (kind == "unspecified") return emptyMap() + val metadata = mutableMapOf("init" to kind) + value?.let { metadata["value"] = it } + return metadata + } }