Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,25 @@ public class MlirValidator {
// not).
if (trimmed.isEmpty() || trimmed.startsWith("//") || trimmed.startsWith("module")) continue

// Extract defined SSA values
if (trimmed.contains(" = ")) {
val parts = trimmed.split(" = ", limit = 2)
if (parts.size == 2) {
val valueName = parts[0].trim()
if (valueName.startsWith("%")) {
if (definedValues.contains(valueName)) {
errors.add("Line ${lineNum + 1}: SSA value $valueName redefined")
}
definedValues.add(valueName)
}
// Extract defined SSA values. A line may carry more than one result
// definition when an op with a region is emitted on a single line
// (e.g. `reduce_window … ({ ^bb0(%a, %b): %r = … })`), so register every
// `%name =` result, not just the leading assignment.
Regex("(%[a-zA-Z0-9_]+)\\s*=").findAll(trimmed).forEach { m ->
val valueName = m.groupValues[1]
if (definedValues.contains(valueName)) {
errors.add("Line ${lineNum + 1}: SSA value $valueName redefined")
}
definedValues.add(valueName)
}


// Register block-argument definitions from region entry blocks
// (`^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): …`). These bind SSA
// values without a ` = `, so they must be collected separately.
Regex("\\^[a-zA-Z0-9_]*\\(([^)]*)\\)").findAll(trimmed).forEach { block ->
Regex("%[a-zA-Z0-9_]+").findAll(block.groupValues[1]).forEach { definedValues.add(it.value) }
}

// Extract used SSA values
val usedInLine = extractUsedValues(trimmed)
usedValues.addAll(usedInLine)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
val padding = extractPadding(params)

val resultValue = context.nextTempValue()

val inputType = node.inputs.firstOrNull()?.let { context.getTypeMapper().mapTensorType(it) } ?: outputType

// Build StableHLO reduce_window operation for max pooling
val operations = buildMaxPoolOperations(
resultValue = resultValue,
input = operands[0],
inputType = inputType,
outputType = outputType,
kernelSize = kernelSize,
stride = stride,
Expand Down Expand Up @@ -225,11 +227,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
val padding = extractPadding(params)

val resultValue = context.nextTempValue()

val inputType = node.inputs.firstOrNull()?.let { context.getTypeMapper().mapTensorType(it) } ?: outputType

// Build StableHLO reduce_window operation for average pooling
val operations = buildAvgPoolOperations(
resultValue = resultValue,
input = operands[0],
inputType = inputType,
outputType = outputType,
kernelSize = kernelSize,
stride = stride,
Expand Down Expand Up @@ -691,64 +695,103 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
}
}

/**
 * Extracts the MLIR element type ("f32", "f16", "index", ...) from a tensor type string.
 *
 * Handles ranked types (`tensor<1x3x4x4xf32>`), dynamic/unranked dimensions
 * (`tensor<?x4xf16>`, `tensor<*xf32>`), rank-0 types with no dimension list
 * (`tensor<f32>`), and a trailing encoding attribute (`tensor<4xf32, #enc>`).
 *
 * @param tensorType textual MLIR tensor type.
 * @return the element type, or "f32" when none can be recovered (e.g. blank input).
 */
private fun elementTypeOf(tensorType: String): String {
    val body = tensorType.trim()
        .removePrefix("tensor<")
        .substringBeforeLast('>')
        // Anything after the first comma is an encoding attribute, not part of the element type.
        .substringBefore(',')
        .trim()
    // Strip only leading `<dim>x` tokens. Splitting on the last 'x' is wrong for rank-0
    // tensors (no 'x' at all) and for element types that themselves contain an 'x'.
    val element = body.replaceFirst(Regex("""^(?:(?:\d+|\?|\*)x)*"""), "")
    return element.ifBlank { "f32" }
}

/**
 * Renders a single-line `"stablehlo.reduce_window"` op in MLIR's generic (region) form.
 *
 * The generic form is used because, per the note on #675, IREE's StableHLO parser rejects
 * the pretty `... applies <op> over window dimensions = ...` spelling, and that spelling's
 * 2-element window only described H/W. Here every window attribute is rank-4
 * (`[1, 1, h, w]`), i.e. the leading two dimensions are never windowed or padded.
 *
 * The op is kept on one line because the line-based MLIR validator handles one op per
 * line; the region body ops are separated by spaces instead of newlines.
 *
 * @param resultValue SSA name receiving the result (e.g. `%v3`); also used, without the
 *   `%`, as a suffix for the region-local names so two pooling ops never share names in
 *   a validator that tracks SSA names globally.
 * @param input SSA name of the tensor being reduced.
 * @param inputType MLIR type of [input].
 * @param initValue SSA name of the scalar init value.
 * @param reduceOp op applied in the region body (e.g. `stablehlo.maximum`).
 * @param elem element type used for the scalar `tensor<elem>` block arguments.
 * @param kernelSize window extent as (height, width).
 * @param stride window stride as (height, width).
 * @param padding symmetric padding as (height, width); each value is applied low and high.
 * @param outputType MLIR type of the result.
 * @return the complete op text.
 */
private fun reduceWindowGeneric(
    resultValue: String,
    input: String,
    inputType: String,
    initValue: String,
    reduceOp: String,
    elem: String,
    kernelSize: Pair<Int, Int>,
    stride: Pair<Int, Int>,
    padding: Pair<Int, Int>,
    outputType: String,
): String {
    val suffix = resultValue.removePrefix("%")
    val scalar = "tensor<$elem>"
    val lhs = "%lhs_$suffix"
    val rhs = "%rhs_$suffix"
    val acc = "%out_$suffix"

    // Region: one entry block taking two scalars, applying the reduction, returning it.
    val region = "({ ^bb0($lhs: $scalar, $rhs: $scalar): " +
        "$acc = $reduceOp $lhs, $rhs : $scalar " +
        "stablehlo.return $acc : $scalar })"

    val attributes = listOf(
        "window_dimensions = array<i64: 1, 1, ${kernelSize.first}, ${kernelSize.second}>",
        "window_strides = array<i64: 1, 1, ${stride.first}, ${stride.second}>",
        "base_dilations = array<i64: 1, 1, 1, 1>",
        "window_dilations = array<i64: 1, 1, 1, 1>",
        "padding = dense<[[0, 0], [0, 0], " +
            "[${padding.first}, ${padding.first}], [${padding.second}, ${padding.second}]]> : tensor<4x2xi64>",
    ).joinToString(separator = ", ", prefix = "{", postfix = "}")

    return "$resultValue = \"stablehlo.reduce_window\"($input, $initValue) " +
        "$region $attributes : ($inputType, $scalar) -> $outputType"
}

/**
 * Emits the StableHLO operations for 2-D max pooling and returns them in emission order.
 *
 * Two ops are produced: a scalar init constant and a generic-form `reduce_window`
 * applying `stablehlo.maximum`. Both use the element type parsed from [outputType].
 *
 * @param resultValue SSA name that receives the pooled tensor.
 * @param input SSA name of the tensor to pool.
 * @param inputType MLIR type of [input].
 * @param outputType MLIR type of the pooled result.
 * @param kernelSize window extent as (height, width).
 * @param stride window stride as (height, width).
 * @param padding symmetric padding as (height, width).
 * @param context supplies fresh temp SSA names and receives the emitted ops.
 * @return the init constant followed by the `reduce_window` op.
 */
private fun buildMaxPoolOperations(
    resultValue: String,
    input: String,
    inputType: String,
    outputType: String,
    kernelSize: Pair<Int, Int>,
    stride: Pair<Int, Int>,
    padding: Pair<Int, Int>,
    context: ConversionContext
): List<String> {
    val elem = elementTypeOf(outputType)
    val initValue = context.nextTempValue()

    // Identity for max: -3.4028235e+38 is the most negative *finite* f32 value (not -inf).
    // NOTE(review): the literal is f32-specific while the tensor type follows `elem`;
    // confirm it is acceptable for non-f32 element types.
    val initConstant = "$initValue = stablehlo.constant dense<-3.4028235e+38> : tensor<$elem>"
    val poolOp = reduceWindowGeneric(
        resultValue = resultValue,
        input = input,
        inputType = inputType,
        initValue = initValue,
        reduceOp = "stablehlo.maximum",
        elem = elem,
        kernelSize = kernelSize,
        stride = stride,
        padding = padding,
        outputType = outputType,
    )

    // Emit in definition order: the constant must precede the op that consumes it.
    context.emitOperation(initConstant)
    context.emitOperation(poolOp)

    return listOf(initConstant, poolOp)
}

/**
 * Emits the StableHLO operations for 2-D average pooling and returns them in emission order.
 *
 * Average pooling is lowered as a windowed sum (generic-form `reduce_window` applying
 * `stablehlo.add`) divided element-wise by the kernel area. Four ops are produced:
 * zero init constant, kernel-area constant, windowed sum, divide.
 *
 * @param resultValue SSA name that receives the averaged tensor.
 * @param input SSA name of the tensor to pool.
 * @param inputType MLIR type of [input].
 * @param outputType MLIR type of the pooled result.
 * @param kernelSize window extent as (height, width).
 * @param stride window stride as (height, width).
 * @param padding symmetric padding as (height, width).
 * @param context supplies fresh temp SSA names and receives the emitted ops.
 * @return the four ops in the order they were emitted.
 */
private fun buildAvgPoolOperations(
    resultValue: String,
    input: String,
    inputType: String,
    outputType: String,
    kernelSize: Pair<Int, Int>,
    stride: Pair<Int, Int>,
    padding: Pair<Int, Int>,
    context: ConversionContext
): List<String> {
    val elem = elementTypeOf(outputType)
    val kernelArea = kernelSize.first * kernelSize.second

    // Temp names are requested in this fixed order so SSA numbering stays stable.
    val initZero = context.nextTempValue()
    val kernelAreaConst = context.nextTempValue()
    val sumResult = context.nextTempValue()

    val initConstant = "$initZero = stablehlo.constant dense<0.0> : tensor<$elem>"
    // Splat over the output type so the divide is element-type consistent (a scalar
    // tensor<f32> divisor was a latent type mismatch).
    val areaConstant = "$kernelAreaConst = stablehlo.constant dense<$kernelArea.0> : $outputType"
    val sumOp = reduceWindowGeneric(
        resultValue = sumResult,
        input = input,
        inputType = inputType,
        initValue = initZero,
        reduceOp = "stablehlo.add",
        elem = elem,
        kernelSize = kernelSize,
        stride = stride,
        padding = padding,
        outputType = outputType,
    )
    val divideOp = "$resultValue = stablehlo.divide $sumResult, $kernelAreaConst : $outputType"

    // Emit in definition order: both constants, then the sum, then the divide that uses them.
    context.emitOperation(initConstant)
    context.emitOperation(areaConstant)
    context.emitOperation(sumOp)
    context.emitOperation(divideOp)

    return listOf(initConstant, areaConstant, sumOp, divideOp)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ class NeuralNetOperationsConverterTest {

assertNotNull(module)
assertContains(module.content, "stablehlo.reduce_window")
// Should contain kernel size and stride information
assertContains(module.content, "window dimensions")
assertContains(module.content, "stride")
// Generic region form (IREE-parseable): window_dimensions / window_strides attrs.
assertContains(module.content, "window_dimensions")
assertContains(module.content, "window_strides")
}

private fun createConv2dGraph(): DefaultComputeGraph {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package sk.ainet.compile.hlo

import sk.ainet.lang.dag.avgPool2d
import sk.ainet.lang.dag.conv1d
import sk.ainet.lang.dag.dag
import sk.ainet.lang.dag.flatten
import sk.ainet.lang.dag.gather
import sk.ainet.lang.dag.maxPool2d
import sk.ainet.lang.graph.dsl.toComputeGraph
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.types.FP32
import sk.ainet.lang.types.Int32
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

/**
 * Export regression tests for DAG-DSL ops whose declared result type must match the value
 * they produce: conv1d, gather, flatten and 2-D max/avg pooling (follow-up to #674, which
 * covered reshape/matmul/concat output-spec inference).
 *
 * Each test lowers a single-op graph to StableHLO text and asserts on that text: the
 * inferred result shape must be present, and the operand's shape must not be echoed as
 * the result. The max-pool test also rejects the pretty `reduce_window(...) applies ...`
 * spelling in favour of the generic region form.
 */
class DagConvGatherPoolExportTest {

    /** Builds a DAG via [build], converts it with the extended converter, and returns the MLIR text. */
    private fun lower(name: String, build: sk.ainet.lang.dag.DagBuilder.() -> Unit): String {
        val converter = StableHloConverterFactory.createExtended()
        val graph = dag(build).toComputeGraph()
        return converter.convert(graph, name).content
    }

    @Test
    fun conv1d_declares_inferred_output_channels_and_length() {
        // input (1,3,8), weight (4,3,3), stride 1, pad 0 -> (1,4,6); the input's 1x3x8
        // must not reappear as the returned type.
        val echoedInputShape = Regex("""return %\w+ : tensor<1x3x8xf32>""")
        val mlir = lower("op_conv1d") {
            val signal = input<FP32>("x", TensorSpec("x", listOf(1, 3, 8), "FP32"))
            val weight = input<FP32>("w", TensorSpec("w", listOf(4, 3, 3), "FP32"))
            val bias = input<FP32>("b", TensorSpec("b", listOf(4), "FP32"))
            output(conv1d(signal, weight, bias, 1, 0, 1, 1))
        }
        assertTrue("-> tensor<1x4x6xf32>" in mlir, "conv1d result must be inferred 1x4x6:\n$mlir")
        assertFalse(mlir.contains(echoedInputShape), "return must not echo the input shape:\n$mlir")
    }

    @Test
    fun gather_declares_inferred_rows() {
        // table (8,4) gathered with 3 indices -> (3,4); the table's 8x4 must not be returned.
        val echoedTableShape = Regex("""return %\w+ : tensor<8x4xf32>""")
        val mlir = lower("op_gather") {
            val table = input<FP32>("t", TensorSpec("t", listOf(8, 4), "FP32"))
            val indices = input<Int32>("idx", TensorSpec("idx", listOf(3), "INT32"))
            output(gather(table, indices, 0))
        }
        assertTrue("-> tensor<3x4xf32>" in mlir, "gather result must be inferred 3x4:\n$mlir")
        assertFalse(mlir.contains(echoedTableShape), "return must not echo the table shape:\n$mlir")
    }

    @Test
    fun maxpool2d_declares_pooled_shape_and_iree_valid_reduce_window() {
        // input (1,3,8,8), 2x2 window, stride 2 -> (1,3,4,4).
        val prettyForm = Regex("""reduce_window\([^)]*\)\s+applies""")
        val mlir = lower("op_maxpool2d") {
            val image = input<FP32>("x", TensorSpec("x", listOf(1, 3, 8, 8), "FP32"))
            output(maxPool2d(image, 2 to 2, 2 to 2, 0 to 0))
        }
        assertTrue("tensor<1x3x4x4xf32>" in mlir, "maxpool output must be the pooled 1x3x4x4:\n$mlir")
        // The pretty `applies … over window` spelling is the one IREE's parser rejects;
        // only the generic region-based reduce_window may be emitted.
        assertFalse(
            mlir.contains(prettyForm),
            "reduce_window must use the IREE-parseable generic form, not 'applies … over window':\n$mlir",
        )
    }

    @Test
    fun flatten_preserves_leading_batch_dim() {
        // (1,16,7,7) flattened over dims 1..3 -> (1,784). Collapsing to rank-1 (784)
        // would break a downstream dense matmul (mnist-cnn).
        val collapsedBatch = Regex("""-> tensor<784xf32>""")
        val mlir = lower("op_flatten") {
            val features = input<FP32>("x", TensorSpec("x", listOf(1, 16, 7, 7), "FP32"))
            output(flatten(features, 1, 3))
        }
        assertTrue("tensor<1x784xf32>" in mlir, "flatten must keep batch: (1,16,7,7)->(1,784):\n$mlir")
        assertFalse(mlir.contains(collapsedBatch), "flatten must not collapse the batch dim:\n$mlir")
    }

    @Test
    fun avgpool2d_declares_pooled_shape() {
        // input (1,3,8,8), 2x2 window, stride 2 -> (1,3,4,4).
        val mlir = lower("op_avgpool2d") {
            val image = input<FP32>("x", TensorSpec("x", listOf(1, 3, 8, 8), "FP32"))
            output(avgPool2d(image, 2 to 2, 2 to 2, 0 to 0, false))
        }
        assertTrue("tensor<1x3x4x4xf32>" in mlir, "avgpool output must be the pooled 1x3x4x4:\n$mlir")
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ public class DagBuilder {
val target = reshapeTargetShape(operation) ?: return null
return spec(target)
}
"flatten" -> {
// Collapse dims [startDim..endDim] into one, preserving the others
// (notably the leading batch dim). Without this, flatten echoes operand-0
// or collapses everything, so a downstream dense matmul mis-types. (#675)
val inS = input?.shape ?: return null
val rank = inS.size
val rawStart = operation.parameters["startDim"] as? Int ?: 1
val rawEnd = operation.parameters["endDim"] as? Int ?: -1
val start = if (rawStart < 0) rank + rawStart else rawStart
val end = if (rawEnd < 0) rank + rawEnd else rawEnd
if (start !in 0 until rank || end !in 0 until rank || start > end) return null
val collapsed = inS.subList(start, end + 1).fold(1) { a, b -> a * b }
return spec(inS.subList(0, start) + collapsed + inS.subList(end + 1, rank))
}
"matmul", "dot", "mm", "bmm", "batch_matmul" -> {
val lhs = inputs.getOrNull(0)?.spec?.shape
val rhs = inputs.getOrNull(1)?.spec?.shape
Expand All @@ -215,10 +229,57 @@ public class DagBuilder {
out[axis] = shapes.sumOf { it[axis] }
return spec(out)
}
"conv1d" -> {
// (N, Cin, L) * (Cout, Cin/groups, K) -> (N, Cout, Lout). conv2d already
// infers via Conv2dOperation; conv1d is a GenericOperation with no inference. (#675)
val inS = inputs.getOrNull(0)?.spec?.shape
val wS = inputs.getOrNull(1)?.spec?.shape
if (inS == null || wS == null || inS.size != 3 || wS.size != 3) return null
val stride = operation.parameters["stride"] as? Int ?: 1
val pad = operation.parameters["padding"] as? Int ?: 0
val dil = operation.parameters["dilation"] as? Int ?: 1
return spec(listOf(inS[0], wS[0], windowedOutput(inS[2], wS[2], stride, pad, dil)))
}
"gather" -> {
// table[..axis..] gathered by `indices` -> table[:axis] ⊕ indices.shape ⊕ table[axis+1:]. (#675)
val table = inputs.getOrNull(0)?.spec?.shape
val idx = inputs.getOrNull(1)?.spec?.shape
if (table == null || idx == null || table.isEmpty()) return null
val rawAxis = operation.parameters["dim"] as? Int ?: operation.parameters["axis"] as? Int ?: -1
val axis = if (rawAxis < 0) table.size + rawAxis else rawAxis
if (axis !in table.indices) return null
return spec(table.subList(0, axis) + idx + table.subList(axis + 1, table.size))
}
"maxpool2d", "avgpool2d" -> {
// (N, C, H, W) windowed by kernel/stride/padding -> (N, C, Hout, Wout). (#675)
val inS = inputs.getOrNull(0)?.spec?.shape
if (inS == null || inS.size != 4) return null
val k = pairParam(operation, "kernel") ?: pairParam(operation, "kernelSize") ?: return null
val s = pairParam(operation, "stride") ?: (1 to 1)
val p = pairParam(operation, "padding") ?: (0 to 0)
return spec(
listOf(
inS[0], inS[1],
windowedOutput(inS[2], k.first, s.first, p.first, 1),
windowedOutput(inS[3], k.second, s.second, p.second, 1),
),
)
}
}
return null
}

/**
 * Windowed (conv/pool) output extent: floor((in + 2·pad − dilation·(k−1) − 1) / stride) + 1,
 * clamped to 0 when the dilated window does not fit in the padded input.
 *
 * @param inDim input extent along the windowed axis.
 * @param k kernel extent along that axis.
 * @param stride step between consecutive windows.
 * @param pad padding applied on each side.
 * @param dilation spacing between kernel taps.
 * @return number of window positions; never negative.
 */
private fun windowedOutput(inDim: Int, k: Int, stride: Int, pad: Int, dilation: Int): Int {
    val span = inDim + 2 * pad - dilation * (k - 1) - 1
    // Kotlin's `/` truncates toward zero, so a negative span would round up and report a
    // spurious window (e.g. span = -1, stride = 2 -> 1). No window fits in that case.
    return if (span < 0) 0 else span / stride + 1
}

/**
 * Reads the parameter [key] of [operation] as a pair of ints.
 *
 * @return the pair, or `null` when the parameter is absent, is not a [Pair], or either
 *   component is not an [Int].
 */
private fun pairParam(operation: Operation, key: String): Pair<Int, Int>? {
    val raw = operation.parameters[key] as? Pair<*, *> ?: return null
    val first = raw.first as? Int ?: return null
    val second = raw.second as? Int ?: return null
    return first to second
}

/** Recover a reshape/view target shape from the op's `newShape`/`shape` parameter. */
private fun reshapeTargetShape(operation: Operation): List<Int>? {
val raw = operation.parameters["newShape"]
Expand Down
Loading