diff --git a/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api b/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api index 62483e10..ffbaa568 100644 --- a/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api +++ b/skainet-compile/skainet-compile-hlo/api/jvm/skainet-compile-hlo.api @@ -471,15 +471,25 @@ public final class sk/ainet/compile/hlo/generate/HloGenerator { public static final field INSTANCE Lsk/ainet/compile/hlo/generate/HloGenerator; public final fun generate (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun generate$default (Lsk/ainet/compile/hlo/generate/HloGenerator;Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; - public static final fun generateBlocking (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/compile/hlo/StableHloModule; - public static final fun generateBlocking (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;)Lsk/ainet/compile/hlo/StableHloModule; - public static synthetic fun generateBlocking$default (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule; +} + +public final class sk/ainet/compile/hlo/generate/HloGeneratorJvmKt { + public static final fun generateBlocking (Lsk/ainet/compile/hlo/generate/HloGenerator;Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/compile/hlo/StableHloModule; + public static final fun generateBlocking (Lsk/ainet/compile/hlo/generate/HloGenerator;Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;)Lsk/ainet/compile/hlo/StableHloModule; + public static synthetic fun generateBlocking$default (Lsk/ainet/compile/hlo/generate/HloGenerator;Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule; } public final class sk/ainet/compile/hlo/generate/HloGeneratorMainKt { public static final fun main ([Ljava/lang/String;)V } +public final class sk/ainet/compile/hlo/generate/JvmHloGenerator { + public static final field INSTANCE Lsk/ainet/compile/hlo/generate/JvmHloGenerator; + public static final fun generateBlocking (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/compile/hlo/StableHloModule; + public static final fun generateBlocking (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;)Lsk/ainet/compile/hlo/StableHloModule; + public static synthetic fun generateBlocking$default (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule; +} + public final class sk/ainet/compile/hlo/validation/BenchmarkResults { public fun (Ljava/lang/String;Ljava/util/List;Lsk/ainet/compile/hlo/validation/ConversionMetrics;Lsk/ainet/compile/hlo/validation/ConversionMetrics;Lsk/ainet/compile/hlo/validation/ConversionMetrics;)V public final fun component1 ()Ljava/lang/String; diff --git a/skainet-compile/skainet-compile-hlo/build.gradle.kts b/skainet-compile/skainet-compile-hlo/build.gradle.kts index 71526c75..01fd77c5 100644 --- a/skainet-compile/skainet-compile-hlo/build.gradle.kts +++ b/skainet-compile/skainet-compile-hlo/build.gradle.kts @@ -43,21 +43,19 @@ kotlin { sourceSets { commonMain.dependencies { api(project(":skainet-lang:skainet-lang-core")) + api(project(":skainet-lang:skainet-lang-models")) api(project(":skainet-compile:skainet-compile-core")) api(project(":skainet-compile:skainet-compile-dag")) } commonTest.dependencies { implementation(libs.kotlin.test) + implementation(libs.kotlinx.coroutines.test) implementation(project(":skainet-backends:skainet-backend-cpu")) } jvmMain.dependencies { - // HloGenerator records traces with VoidTensorOps from - // skainet-lang-core — the JVM production path never needs a - // real backend implementation. No CPU-specific imports here. - implementation(project(":skainet-lang:skainet-lang-models")) implementation(libs.kotlinx.coroutines) } diff --git a/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt similarity index 57% rename from skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt rename to skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt index e3ec8a20..661fdc06 100644 --- a/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/generate/HloGenerator.kt @@ -1,6 +1,5 @@ package sk.ainet.compile.hlo.generate -import kotlinx.coroutines.runBlocking import sk.ainet.compile.hlo.StableHloConverterFactory import sk.ainet.compile.hlo.StableHloModule import sk.ainet.lang.graph.DefaultGraphExecutionContext @@ -23,8 +22,8 @@ public object HloGenerator { /** * Generate StableHLO from any [Model] and a sample input tensor. * - * @param model The model whose forward pass will be traced. - * @param sampleInput A tensor with the desired input shape/dtype (values don't matter). + * @param model The model whose forward pass will be traced. + * @param sampleInput A tensor with the desired input shape/dtype (values do not matter). * @param functionName The MLIR function name in the emitted module. */ public suspend fun generate( @@ -37,7 +36,7 @@ public object HloGenerator { // Capture the rebound sample input's tensor ref ID so we can mark it as a function argument. @Suppress("UNCHECKED_CAST") - val inputRefId = ctx.session.refOf(traceInput as sk.ainet.lang.tensor.Tensor<*, *>).id + val inputRefId = ctx.session.refOf(traceInput as Tensor<*, *>).id val (tape, _) = ctx.record { @Suppress("UNCHECKED_CAST") @@ -57,45 +56,6 @@ public object HloGenerator { return converter.convert(computeGraph, functionName) } - internal suspend fun generate(descriptor: ModelDescriptor, height: Int, width: Int, batch: Int): StableHloModule { - val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps()) - - var sampleInputRefId: String? = null - val (tape, _) = ctx.record { - val (model, sampleInput) = descriptor.createModelAndInput(ctx, height, width, batch) - @Suppress("UNCHECKED_CAST") - sampleInputRefId = ctx.session.refOf(sampleInput as sk.ainet.lang.tensor.Tensor<*, *>).id - @Suppress("UNCHECKED_CAST") - traceForwardPass( - model as Model, Tensor>, - sampleInput as Tensor, - ctx - ) - } - - val inputIds = sampleInputRefId?.let { setOf(it) } ?: emptySet() - val computeGraph = tape?.toComputeGraph( - synthesizeExternalInputs = true, - inputTensorIds = inputIds - ) ?: error("Failed to create compute graph: no execution tape was recorded") - - val converter = StableHloConverterFactory.createExtended() - return converter.convert(computeGraph, descriptor.functionName) - } - - /** - * Blocking variant of [generate] for Java callers who cannot use `suspend`. - */ - @JvmStatic - @JvmOverloads - public fun generateBlocking( - model: Model, Tensor>, - sampleInput: Tensor, - functionName: String = "main" - ): StableHloModule = runBlocking { - generate(model, sampleInput, functionName) - } - private suspend fun traceForwardPass( model: Model, Tensor>, sampleInput: Tensor, diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorCommonTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorCommonTest.kt new file mode 100644 index 00000000..dbf001ef --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorCommonTest.kt @@ -0,0 +1,50 @@ +package sk.ainet.compile.hlo.generate + +import kotlinx.coroutines.test.runTest +import sk.ainet.context.ExecutionContext +import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.model.Model +import sk.ainet.lang.model.ModelCard +import sk.ainet.lang.nn.Module +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.VoidTensorOps +import sk.ainet.lang.tensor.times +import sk.ainet.lang.types.FP32 +import kotlin.test.Test +import kotlin.test.assertTrue + +class HloGeneratorCommonTest { + + @Test + fun commonGeneratorEmitsStableHloOps() = runTest { + val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps()) + val sampleInput = ctx.fromFloatArray( + shape = Shape(2, 2), + dtype = FP32::class, + data = floatArrayOf(1f, 2f, 3f, 4f) + ) + + val module = HloGenerator.generate(SquareModel(), sampleInput, "square") + + assertTrue(module.content.contains("func.func @square"), "Expected generated function name") + assertTrue(module.content.contains("stablehlo.multiply"), "Expected traced multiply op") + assertTrue(module.content.contains("tensor<2x2xf32>"), "Expected input/output tensor type") + } + + private class SquareModel : Model, Tensor> { + override fun create(executionContext: ExecutionContext): Module = object : Module() { + override val name: String = "square" + override val modules: List> = emptyList() + } + + override suspend fun calculate( + module: Module, + inputValue: Tensor, + executionContext: ExecutionContext, + reportProgress: suspend (current: Int, total: Int, message: String?) -> Unit + ): Tensor = inputValue * inputValue + + override fun modelCard(): ModelCard = error("Not needed for HLO generator tests") + } +} diff --git a/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorJvm.kt b/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorJvm.kt new file mode 100644 index 00000000..9606a8de --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/jvmMain/kotlin/sk/ainet/compile/hlo/generate/HloGeneratorJvm.kt @@ -0,0 +1,53 @@ +package sk.ainet.compile.hlo.generate + +import kotlinx.coroutines.runBlocking +import sk.ainet.compile.hlo.StableHloModule +import sk.ainet.context.ExecutionContext +import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.model.Model +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.VoidTensorOps +import sk.ainet.lang.types.DType + +/** + * Blocking JVM convenience wrapper for Java callers and CLI-style integrations. + */ +@JvmName("generateBlocking") +@JvmOverloads +public fun HloGenerator.generateBlocking( + model: Model, Tensor>, + sampleInput: Tensor, + functionName: String = "main" +): StableHloModule = runBlocking { + generate(model, sampleInput, functionName) +} + +internal suspend fun HloGenerator.generate( + descriptor: ModelDescriptor, + height: Int, + width: Int, + batch: Int +): StableHloModule { + val ctx: ExecutionContext = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps()) + val (model, sampleInput) = descriptor.createModelAndInput(ctx, height, width, batch) + + @Suppress("UNCHECKED_CAST") + return generate( + model as Model, Tensor>, + sampleInput as Tensor, + descriptor.functionName + ) +} + +/** + * JVM-named facade for callers that prefer static Java interop over Kotlin extension syntax. + */ +public object JvmHloGenerator { + @JvmStatic + @JvmOverloads + public fun generateBlocking( + model: Model, Tensor>, + sampleInput: Tensor, + functionName: String = "main" + ): StableHloModule = HloGenerator.generateBlocking(model, sampleInput, functionName) +}