Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import sk.ainet.compile.hlo.StableHloModule
import sk.ainet.lang.graph.DefaultGraphExecutionContext
import sk.ainet.lang.model.Model
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.operators.bind
import sk.ainet.lang.tensor.ops.VoidTensorOps
import sk.ainet.lang.tape.toComputeGraph
import sk.ainet.lang.types.DType
Expand All @@ -32,16 +33,17 @@ public object HloGenerator {
functionName: String = "main"
): StableHloModule {
val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())
val traceInput = sampleInput.bind(ctx)

// Capture the sample input's tensor ref ID so we can mark it as a function argument
// Capture the rebound sample input's tensor ref ID so we can mark it as a function argument.
@Suppress("UNCHECKED_CAST")
val inputRefId = ctx.session.refOf(sampleInput as sk.ainet.lang.tensor.Tensor<*, *>).id
val inputRefId = ctx.session.refOf(traceInput as sk.ainet.lang.tensor.Tensor<*, *>).id

val (tape, _) = ctx.record {
@Suppress("UNCHECKED_CAST")
traceForwardPass(
model as Model<DType, Any?, Tensor<DType, Any?>, Tensor<DType, Any?>>,
sampleInput as Tensor<DType, Any?>,
traceInput as Tensor<DType, Any?>,
ctx
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
package sk.ainet.compile.hlo.generate

import kotlinx.coroutines.test.runTest
import sk.ainet.compile.hlo.StableHloConverterFactory
import sk.ainet.lang.dag.dag
import sk.ainet.lang.dag.multiply
import sk.ainet.lang.dag.sum
import sk.ainet.lang.dag.unsqueeze
import sk.ainet.lang.graph.DefaultGraphExecutionContext
import sk.ainet.lang.graph.dsl.toComputeGraph
import sk.ainet.lang.model.compute.Rgb2GrayScale
import sk.ainet.lang.model.compute.Rgb2GrayScaleMatMul
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.VoidTensorOps
import sk.ainet.lang.types.FP16
import sk.ainet.lang.types.FP32
Expand Down Expand Up @@ -48,9 +55,38 @@ class HloGeneratorTest {
assertTrue(module.content.contains("module {"), "Expected 'module {' in MLIR output")
assertTrue(module.content.contains("func.func"), "Expected 'func.func' in MLIR output")
assertTrue(module.content.contains("@rgb2grayscale_matmul"), "Expected function name in MLIR output")
assertTrue(module.content.contains("stablehlo."), "Expected tensor-bound forward pass to emit StableHLO ops")
assertTrue(module.content.length > 50, "Expected non-trivial MLIR output")
}

@Test
fun testDagConstantsAreInlinedAndReductionDropsDimension() {
val h = 8
val w = 8
val program = dag {
val x = input<FP16>("input", TensorSpec("input", listOf(1, 3, h, w), "FP16"))
val luma = constant<FP16, Float>("luma") {
fromArray(floatArrayOf(0.2989f, 0.5870f, 0.1140f), shape = listOf(1, 3, 1, 1))
}
val weighted = multiply(x, luma)
val grayHW = sum(weighted, 1)
output(unsqueeze(grayHW, 1))
}

val mlir = StableHloConverterFactory.createExtended()
.convert(program.toComputeGraph(), "grayscale")
.content

assertTrue(
mlir.contains("stablehlo.constant") && !Regex("""@grayscale\([^)]*,[^)]*\)""").containsMatchIn(mlir),
"luma constant should be baked into the StableHLO module:\n$mlir"
)
assertTrue(
mlir.contains("tensor<1x1x${h}x${w}xf16>"),
"sum over the channel dimension should drop that dimension before unsqueeze:\n$mlir"
)
}

@Test
fun testGenerateDefaultFunctionName() = runTest {
val model = Rgb2GrayScale()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package sk.ainet.lang.dag

import sk.ainet.lang.tensor.ops.GenericOperation
import sk.ainet.lang.tensor.ops.InputOperation
import sk.ainet.lang.tensor.ops.Operation
import sk.ainet.lang.tensor.ops.TensorSpec
Expand Down Expand Up @@ -80,6 +81,8 @@ public class DagBuilder {
inputs: List<GraphValue<*>>,
nodeId: String
): List<TensorSpec> {
inferDagOutputSpecs(operation, inputs, nodeId)?.let { return it }

val inputSpecs = inputs.map { it.spec }
val inferred = runCatching { operation.inferOutputs(inputSpecs) }
.getOrElse {
Expand Down Expand Up @@ -150,11 +153,11 @@ public class DagBuilder {
}

/**
* Declare a constant placeholder (treated like an input node).
* Declare a constant tensor with any available initializer data embedded in the graph.
*/
@DagDsl
public fun <T : DType> constant(name: String, spec: TensorSpec): GraphValue<T> {
val op = InputOperation<T, Any>(parameters = mapOf("kind" to "const"))
val op = GenericOperation("weight", constantParameters(spec), type = "constant")
val recorded = recordNode("const", op, emptyList(), id = "const_$name").first()
@Suppress("UNCHECKED_CAST")
val typed = (recorded as GraphValue<T>)
Expand All @@ -163,8 +166,61 @@ public class DagBuilder {
return updated
}

private fun inferDagOutputSpecs(
operation: Operation,
inputs: List<GraphValue<*>>,
nodeId: String
): List<TensorSpec>? {
if (operation.name.lowercase() !in setOf("sum", "mean", "variance")) return null
val input = inputs.firstOrNull()?.spec ?: return null
val outputShape = reductionOutputShape(input.shape, operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int)
return listOf(
TensorSpec(
name = "${nodeId}_out0",
shape = outputShape,
dtype = input.dtype,
requiresGrad = input.requiresGrad
)
)
}

private fun reductionOutputShape(shape: List<Int>?, dim: Int?): List<Int>? {
if (shape == null) return null
if (dim == null) return listOf(1)

val actualDim = if (dim < 0) shape.size + dim else dim
require(actualDim in shape.indices) {
"Reduction dimension $dim is out of bounds for tensor rank ${shape.size}"
}

val reduced = shape.filterIndexed { index, _ -> index != actualDim }
return reduced.ifEmpty { listOf(1) }
}

private fun constantParameters(spec: TensorSpec): Map<String, Any> {
val params = mutableMapOf<String, Any>("trainable" to false)
when (spec.metadata["init"] as? String) {
"fromArray" -> {
val values = spec.metadata["values"] as? FloatArray
if (values != null) params["initial_value"] = values
}
"fromIntArray" -> {
val values = spec.metadata["values"] as? IntArray
if (values != null) params["initial_value"] = values.toList()
}
"ones" -> params["initial_value"] = 1.0f
"zeros" -> params["initial_value"] = 0.0f
else -> {
if ((spec.metadata["init"] as? String)?.startsWith("full(") == true) {
params["initial_value"] = spec.metadata["value"] as? Number ?: 0.0f
}
}
}
return params
}

/**
* Parameter helper that reuses a symbolic, allocation-free data DSL to declare shape/dtype.
* Parameter helper that reuses the symbolic data DSL to declare shape/dtype.
*
* Example:
* ```
Expand All @@ -182,7 +238,7 @@ public class DagBuilder {
}

/**
* Constant helper that reuses a symbolic, allocation-free data DSL to declare shape/dtype.
* Constant helper that reuses the symbolic data DSL to declare shape/dtype and initializer data.
*/
@DagDsl
public inline fun <reified T : DType, V> constant(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import kotlin.reflect.KClass
internal fun <T : DType> dtypeName(kClass: KClass<T>): String = kClass.simpleName ?: kClass.toString()

/**
* Lightweight, allocation-free builder that mimics the shape/initializer style of the data DSL
* but produces only [TensorSpec] metadata for the DAG DSL.
* Lightweight builder that mimics the shape/initializer style of the data DSL
* and produces [TensorSpec] metadata for the DAG DSL.
*/
@DagDsl
public class SymbolicTensorBuilder<T : DType>(
Expand Down Expand Up @@ -42,7 +42,7 @@ public class SymbolicTensorBuilder<T : DType>(
}

/**
* Infer shape from a flat float array. Stores initializer metadata only; no allocation is performed.
* Infer shape from a flat float array and retain the values for constant materialization.
*/
@DagDsl
public fun fromArray(values: FloatArray, shape: List<Int>? = null): TensorSpec {
Expand All @@ -51,12 +51,12 @@ public class SymbolicTensorBuilder<T : DType>(
name = defaultName,
shape = inferredShape,
dtype = dtypeName,
metadata = mapOf("init" to "fromArray", "size" to values.size)
metadata = mapOf("init" to "fromArray", "size" to values.size, "values" to values.copyOf())
)
}

/**
* Infer shape from a flat int array. Stores initializer metadata only; no allocation is performed.
* Infer shape from a flat int array and retain the values for constant materialization.
*/
@DagDsl
public fun fromArray(values: IntArray, shape: List<Int>? = null): TensorSpec {
Expand All @@ -65,7 +65,7 @@ public class SymbolicTensorBuilder<T : DType>(
name = defaultName,
shape = inferredShape,
dtype = dtypeName,
metadata = mapOf("init" to "fromIntArray", "size" to values.size)
metadata = mapOf("init" to "fromIntArray", "size" to values.size, "values" to values.copyOf())
)
}
}
Expand All @@ -76,11 +76,19 @@ public class SymbolicTensorBuilder<T : DType>(
@DagDsl
public class SymbolicInit {
private var kind: String = "unspecified"
private var value: Number? = null

@DagDsl public fun ones() { kind = "ones" }
@DagDsl public fun zeros() { kind = "zeros" }
@DagDsl public fun full(value: Number) { kind = "full($value)" }
@DagDsl public fun full(value: Number) {
kind = "full($value)"
this.value = value
}

internal fun metadata(): Map<String, Any> =
if (kind == "unspecified") emptyMap() else mapOf("init" to kind)
internal fun metadata(): Map<String, Any> {
if (kind == "unspecified") return emptyMap()
val metadata = mutableMapOf<String, Any>("init" to kind)
value?.let { metadata["value"] = it }
return metadata
}
}
Loading