Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -471,15 +471,25 @@ public final class sk/ainet/compile/hlo/generate/HloGenerator {
public static final field INSTANCE Lsk/ainet/compile/hlo/generate/HloGenerator;
public final fun generate (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun generate$default (Lsk/ainet/compile/hlo/generate/HloGenerator;Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public static final fun generateBlocking (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/compile/hlo/StableHloModule;
public static final fun generateBlocking (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;)Lsk/ainet/compile/hlo/StableHloModule;
public static synthetic fun generateBlocking$default (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule;
}

public final class sk/ainet/compile/hlo/generate/HloGeneratorJvmKt {
public static final fun generateBlocking (Lsk/ainet/compile/hlo/generate/HloGenerator;Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/compile/hlo/StableHloModule;
public static final fun generateBlocking (Lsk/ainet/compile/hlo/generate/HloGenerator;Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;)Lsk/ainet/compile/hlo/StableHloModule;
public static synthetic fun generateBlocking$default (Lsk/ainet/compile/hlo/generate/HloGenerator;Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule;
}

public final class sk/ainet/compile/hlo/generate/HloGeneratorMainKt {
public static final fun main ([Ljava/lang/String;)V
}

public final class sk/ainet/compile/hlo/generate/JvmHloGenerator {
public static final field INSTANCE Lsk/ainet/compile/hlo/generate/JvmHloGenerator;
public static final fun generateBlocking (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/compile/hlo/StableHloModule;
public static final fun generateBlocking (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;)Lsk/ainet/compile/hlo/StableHloModule;
public static synthetic fun generateBlocking$default (Lsk/ainet/lang/model/Model;Lsk/ainet/lang/tensor/Tensor;Ljava/lang/String;ILjava/lang/Object;)Lsk/ainet/compile/hlo/StableHloModule;
}

public final class sk/ainet/compile/hlo/validation/BenchmarkResults {
public fun <init> (Ljava/lang/String;Ljava/util/List;Lsk/ainet/compile/hlo/validation/ConversionMetrics;Lsk/ainet/compile/hlo/validation/ConversionMetrics;Lsk/ainet/compile/hlo/validation/ConversionMetrics;)V
public final fun component1 ()Ljava/lang/String;
Expand Down
6 changes: 2 additions & 4 deletions skainet-compile/skainet-compile-hlo/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,19 @@ kotlin {
sourceSets {
commonMain.dependencies {
api(project(":skainet-lang:skainet-lang-core"))
api(project(":skainet-lang:skainet-lang-models"))
api(project(":skainet-compile:skainet-compile-core"))
api(project(":skainet-compile:skainet-compile-dag"))
}

commonTest.dependencies {
implementation(libs.kotlin.test)
implementation(libs.kotlinx.coroutines.test)
implementation(project(":skainet-backends:skainet-backend-cpu"))

}

jvmMain.dependencies {
// HloGenerator records traces with VoidTensorOps from
// skainet-lang-core — the JVM production path never needs a
// real backend implementation. No CPU-specific imports here.
implementation(project(":skainet-lang:skainet-lang-models"))
implementation(libs.kotlinx.coroutines)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package sk.ainet.compile.hlo.generate

import kotlinx.coroutines.runBlocking
import sk.ainet.compile.hlo.StableHloConverterFactory
import sk.ainet.compile.hlo.StableHloModule
import sk.ainet.lang.graph.DefaultGraphExecutionContext
Expand All @@ -23,8 +22,8 @@ public object HloGenerator {
/**
* Generate StableHLO from any [Model] and a sample input tensor.
*
* @param model The model whose forward pass will be traced.
* @param sampleInput A tensor with the desired input shape/dtype (values don't matter).
* @param model The model whose forward pass will be traced.
* @param sampleInput A tensor with the desired input shape/dtype (values do not matter).
* @param functionName The MLIR function name in the emitted module.
*/
public suspend fun <D : DType, V> generate(
Expand All @@ -37,7 +36,7 @@ public object HloGenerator {

// Capture the rebound sample input's tensor ref ID so we can mark it as a function argument.
@Suppress("UNCHECKED_CAST")
val inputRefId = ctx.session.refOf(traceInput as sk.ainet.lang.tensor.Tensor<*, *>).id
val inputRefId = ctx.session.refOf(traceInput as Tensor<*, *>).id

val (tape, _) = ctx.record {
@Suppress("UNCHECKED_CAST")
Expand All @@ -57,45 +56,6 @@ public object HloGenerator {
return converter.convert(computeGraph, functionName)
}

internal suspend fun generate(descriptor: ModelDescriptor, height: Int, width: Int, batch: Int): StableHloModule {
val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())

var sampleInputRefId: String? = null
val (tape, _) = ctx.record {
val (model, sampleInput) = descriptor.createModelAndInput(ctx, height, width, batch)
@Suppress("UNCHECKED_CAST")
sampleInputRefId = ctx.session.refOf(sampleInput as sk.ainet.lang.tensor.Tensor<*, *>).id
@Suppress("UNCHECKED_CAST")
traceForwardPass(
model as Model<DType, Any?, Tensor<DType, Any?>, Tensor<DType, Any?>>,
sampleInput as Tensor<DType, Any?>,
ctx
)
}

val inputIds = sampleInputRefId?.let { setOf(it) } ?: emptySet()
val computeGraph = tape?.toComputeGraph(
synthesizeExternalInputs = true,
inputTensorIds = inputIds
) ?: error("Failed to create compute graph: no execution tape was recorded")

val converter = StableHloConverterFactory.createExtended()
return converter.convert(computeGraph, descriptor.functionName)
}

/**
* Blocking variant of [generate] for Java callers who cannot use `suspend`.
*/
@JvmStatic
@JvmOverloads
public fun <D : DType, V> generateBlocking(
model: Model<D, V, Tensor<D, V>, Tensor<D, V>>,
sampleInput: Tensor<D, V>,
functionName: String = "main"
): StableHloModule = runBlocking {
generate(model, sampleInput, functionName)
}

private suspend fun traceForwardPass(
model: Model<DType, Any?, Tensor<DType, Any?>, Tensor<DType, Any?>>,
sampleInput: Tensor<DType, Any?>,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package sk.ainet.compile.hlo.generate

import kotlinx.coroutines.test.runTest
import sk.ainet.context.ExecutionContext
import sk.ainet.lang.graph.DefaultGraphExecutionContext
import sk.ainet.lang.model.Model
import sk.ainet.lang.model.ModelCard
import sk.ainet.lang.nn.Module
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.ops.VoidTensorOps
import sk.ainet.lang.tensor.times
import sk.ainet.lang.types.FP32
import kotlin.test.Test
import kotlin.test.assertTrue

class HloGeneratorCommonTest {

@Test
fun commonGeneratorEmitsStableHloOps() = runTest {
val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())
val sampleInput = ctx.fromFloatArray<FP32, Float>(
shape = Shape(2, 2),
dtype = FP32::class,
data = floatArrayOf(1f, 2f, 3f, 4f)
)

val module = HloGenerator.generate(SquareModel(), sampleInput, "square")

assertTrue(module.content.contains("func.func @square"), "Expected generated function name")
assertTrue(module.content.contains("stablehlo.multiply"), "Expected traced multiply op")
assertTrue(module.content.contains("tensor<2x2xf32>"), "Expected input/output tensor type")
}

private class SquareModel : Model<FP32, Float, Tensor<FP32, Float>, Tensor<FP32, Float>> {
override fun create(executionContext: ExecutionContext): Module<FP32, Float> = object : Module<FP32, Float>() {
override val name: String = "square"
override val modules: List<Module<FP32, Float>> = emptyList()
}

override suspend fun calculate(
module: Module<FP32, Float>,
inputValue: Tensor<FP32, Float>,
executionContext: ExecutionContext,
reportProgress: suspend (current: Int, total: Int, message: String?) -> Unit
): Tensor<FP32, Float> = inputValue * inputValue

override fun modelCard(): ModelCard = error("Not needed for HLO generator tests")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package sk.ainet.compile.hlo.generate

import kotlinx.coroutines.runBlocking
import sk.ainet.compile.hlo.StableHloModule
import sk.ainet.context.ExecutionContext
import sk.ainet.lang.graph.DefaultGraphExecutionContext
import sk.ainet.lang.model.Model
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.ops.VoidTensorOps
import sk.ainet.lang.types.DType

/**
* Blocking JVM convenience wrapper for Java callers and CLI-style integrations.
*/
@JvmName("generateBlocking")
@JvmOverloads
public fun <D : DType, V> HloGenerator.generateBlocking(
model: Model<D, V, Tensor<D, V>, Tensor<D, V>>,
sampleInput: Tensor<D, V>,
functionName: String = "main"
): StableHloModule = runBlocking {
generate(model, sampleInput, functionName)
}

internal suspend fun HloGenerator.generate(
descriptor: ModelDescriptor,
height: Int,
width: Int,
batch: Int
): StableHloModule {
val ctx: ExecutionContext = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())
val (model, sampleInput) = descriptor.createModelAndInput(ctx, height, width, batch)

@Suppress("UNCHECKED_CAST")
return generate(
model as Model<DType, Any?, Tensor<DType, Any?>, Tensor<DType, Any?>>,
sampleInput as Tensor<DType, Any?>,
descriptor.functionName
)
}

/**
* JVM-named facade for callers that prefer static Java interop over Kotlin extension syntax.
*/
public object JvmHloGenerator {
@JvmStatic
@JvmOverloads
public fun <D : DType, V> generateBlocking(
model: Model<D, V, Tensor<D, V>, Tensor<D, V>>,
sampleInput: Tensor<D, V>,
functionName: String = "main"
): StableHloModule = HloGenerator.generateBlocking(model, sampleInput, functionName)
}
Loading