fix(dag): infer output spec for reshape/matmul/concat (#673)#674
Merged
Conversation
…/concat)
Exercises the REAL sk.ainet.lang.dag DSL path (dag{}.toComputeGraph() -> extended
converter), exactly as the skainet-iree-conformance harness builds its op modules.
3 tests, all RED on develop/0.28.0:
- reshape (1,4)->(2,2): target shape lost -> empty module (no stablehlo.reshape)
- matmul (1,4)x(4,3): declares -> 1x4 (echoes operand-0); true result is 1x3
- concat [(1,4),(1,4)] dim1: op types 1x8 but function return stays 1x4
Root cause: shape-changing ops declare result/return types from a stale output
spec instead of the op's inferred output. iree-compile rejects all three
('inferred shape incompatible with return type'). Unlike ReshapeConcatShapeFixTest
(synthetic GraphNodes, green), this hits the real DSL path and stays red until the
dag->graph output-spec inference is fixed.
Lock for the follow-up fix.
DagBuilder.ensureOutputSpecs fell back to echoing operand-0's shape whenever an
op had no special-case shape inference, so shape-changing ops exported a declared
result/return type that contradicted the value they produce -> iree-compile
rejected them ('inferred shape incompatible with return type'):
- reshape (1,4)->(2,2): target shape (a Shape in parameters["newShape"]) dropped
- matmul (1,4)x(4,3): declared 1x4 (operand-0), true result 1x3
- concat [(1,4),(1,4)] dim 1: summed axis lost; return stayed 1x4
inferDagOutputSpecs now computes the real output spec for reshape/view (from the
newShape/shape param), matmul/dot/mm/bmm (lhs[..,:-1]+rhs[-1]) and concat/cat
(sum the axis). The corrected spec flows to both the converter result type and the
func.func return type, so the modules are self-consistent for iree-compile.
DagShapeExportConformanceTest (real dag{} DSL path) goes 3/3 green; full
skainet-compile-hlo / skainet-compile-dag / skainet-lang-dag suites still green.
|
📖 Documentation Preview The documentation has been built successfully for this PR. Generated Files:
Artifacts:
This comment will be updated automatically when the PR is updated. |
This was referenced Jun 6, 2026
Closed
michalharakal
added a commit
that referenced
this pull request
Jun 6, 2026
…reduce_window (#675) Follow-up to #674. Completes DAG-DSL StableHLO export so every conformance model and op compiles with iree-compile. inferDagOutputSpecs (skainet-lang-dag) gains shape rules for: - conv1d : (N,Cin,L) * (Cout,_,K) -> (N, Cout, floor((L+2p - d(K-1) - 1)/s) + 1) - gather : table[:axis] + indices.shape + table[axis+1:] - maxpool/avgpool : windowed (N, C, Hout, Wout) - flatten : collapse dims [startDim..endDim], preserving the leading batch dim (it was collapsing everything to rank-1, breaking the dense matmul in mnist-cnn) conv2d already inferred via Conv2dOperation; conv1d was a GenericOperation that fell back to echoing operand-0. reduce_window emission (NeuralNetOperationsConverter): emit the IREE-parseable generic region form %r = "stablehlo.reduce_window"(%in, %init) ({ ^bb0(%a, %b): ... }) instead of the pretty "applies <op> over window dimensions = ..." form IREE rejects ("has no custom assembly form"). Full NCHW-rank window attributes; region-local SSA names made unique per op so two pools in one function don't collide in the validator; avg-pool divisor splatted to the output type (was a scalar-vs-tensor mismatch). MlirValidator: register region block-argument SSA defs (^bb0(%a, %b)) and every "%x =" result on a line, so single-line region ops validate. Verified end-to-end via an unsigned 0.28.1-SNAPSHOT on mavenLocal + skainet-iree-conformance: 27/27 ops and 7/7 models (incl. mnist-cnn) iree-compile to a vmfb. DagConvGatherPoolExportTest 5/5 green; hlo / dag / lang-dag suites green.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Addresses the result-type class of #673.
Problem
DagBuilder.ensureOutputSpecs(skainet-lang-dag) fell back to echoing operand-0's shape for any op without special-case shape inference. So shape-changing ops exported a declared result/return type that contradicts the value they produce, andiree-compilerejected them:reshape (1,4)->(2,2)tensor<?>/ dropped2x2matmul (1,4)x(4,3)1x41x3inferred shape '[1,3]' is incompatible with return typeconcat [(1,4),(1,4)] dim 11x41x8The converter's own #666/#667 fixes (PR #670) couldn't engage because the node never carried a usable output spec — the bug is upstream, in graph construction.
Fix
inferDagOutputSpecsnow computes the real output spec for:parameters["newShape"](aShape, which the converter'sas? List<Int>had missed), falling back toshape/outputShape.lhs[.., M, K] @ rhs[.., K, N] -> [.., M, N].The corrected spec flows to both the converter result type and the
func.funcreturn type, so each module is self-consistent foriree-compile.Verification
DagShapeExportConformanceTest(exercises the realdag{}.toComputeGraph()path, same as the conformance harness) — was 3/3 RED, now 3/3 green.skainet-compile-hlo,skainet-compile-dag,skainet-lang-dagjvmTest suites: green, no regressions.Still open in #673 (separate root causes, not this PR)
conv1d/gather/maxpool2d/avgpool2doutput-shape inference (sameinferDagOutputSpecsgap, different shape math).stablehlo.reduce_windowemitted in a form IREE rejects (has no custom assembly form) — a converter emission-syntax fix.