Skip to content

fix(dag): infer output spec for reshape/matmul/concat (#673)#674

Merged
michalharakal merged 2 commits into
developfrom
fix/dag-shape-output-spec-673
Jun 6, 2026
Merged

fix(dag): infer output spec for reshape/matmul/concat (#673)#674
michalharakal merged 2 commits into
developfrom
fix/dag-shape-output-spec-673

Conversation

@michalharakal

Copy link
Copy Markdown
Contributor

Addresses the result-type class of #673.

Problem

DagBuilder.ensureOutputSpecs (skainet-lang-dag) fell back to echoing operand-0's shape for any op without special-case shape inference. So shape-changing ops exported a declared result/return type that contradicts the value they produce, and iree-compile rejected them:

op declared true iree-compile
reshape (1,4)->(2,2) tensor<?> / dropped 2x2 empty module ("Missing shape parameter")
matmul (1,4)x(4,3) 1x4 1x3 inferred shape '[1,3]' is incompatible with return type
concat [(1,4),(1,4)] dim 1 return 1x4 1x8 type mismatch

The converter's own #666/#667 fixes (PR #670) couldn't engage because the node never carried a usable output spec — the bug is upstream, in graph construction.

The synthetic-node ReshapeConcatShapeFixTest (PR #670) passed because it hand-populated the spec the real dag{} DSL never sets. This is why the conformance harness — which uses the real DSL — still failed on 0.28.0.

Fix

inferDagOutputSpecs now computes the real output spec for:

  • reshape/view — from parameters["newShape"] (a Shape, which the converter's as? List<Int> had missed), falling back to shape/outputShape.
  • matmul/dot/mm/bmm/batch_matmullhs[.., M, K] @ rhs[.., K, N] -> [.., M, N].
  • concat/concatenate/cat — sum the extents along the (normalized) axis.

The corrected spec flows to both the converter result type and the func.func return type, so each module is self-consistent for iree-compile.

Verification

  • New DagShapeExportConformanceTest (exercises the real dag{}.toComputeGraph() path, same as the conformance harness) — was 3/3 RED, now 3/3 green.
  • Full skainet-compile-hlo, skainet-compile-dag, skainet-lang-dag jvmTest suites: green, no regressions.

Still open in #673 (separate root causes, not this PR)

  • conv1d / gather / maxpool2d / avgpool2d output-shape inference (same inferDagOutputSpecs gap, different shape math).
  • stablehlo.reduce_window emitted in a form IREE rejects (has no custom assembly form) — a converter emission-syntax fix.

…/concat)

Exercises the REAL sk.ainet.lang.dag DSL path (dag{}.toComputeGraph() -> extended
converter), exactly as the skainet-iree-conformance harness builds its op modules.
3 tests, all RED on develop/0.28.0:

- reshape (1,4)->(2,2): target shape lost -> empty module (no stablehlo.reshape)
- matmul (1,4)x(4,3): declares -> 1x4 (echoes operand-0); true result is 1x3
- concat [(1,4),(1,4)] dim1: op types 1x8 but function return stays 1x4

Root cause: shape-changing ops declare result/return types from a stale output
spec instead of the op's inferred output. iree-compile rejects all three
('inferred shape incompatible with return type'). Unlike ReshapeConcatShapeFixTest
(synthetic GraphNodes, green), this hits the real DSL path and stays red until the
dag->graph output-spec inference is fixed.

Lock for the follow-up fix.
DagBuilder.ensureOutputSpecs fell back to echoing operand-0's shape whenever an
op had no special-case shape inference, so shape-changing ops exported a declared
result/return type that contradicted the value they produce -> iree-compile
rejected them ('inferred shape incompatible with return type'):

  - reshape (1,4)->(2,2): target shape (a Shape in parameters["newShape"]) dropped
  - matmul  (1,4)x(4,3): declared 1x4 (operand-0), true result 1x3
  - concat  [(1,4),(1,4)] dim 1: summed axis lost; return stayed 1x4

inferDagOutputSpecs now computes the real output spec for reshape/view (from the
newShape/shape param), matmul/dot/mm/bmm (lhs[..,:-1]+rhs[-1]) and concat/cat
(sum the axis). The corrected spec flows to both the converter result type and the
func.func return type, so the modules are self-consistent for iree-compile.

DagShapeExportConformanceTest (real dag{} DSL path) goes 3/3 green; full
skainet-compile-hlo / skainet-compile-dag / skainet-lang-dag suites still green.
@github-actions

github-actions Bot commented Jun 6, 2026

Copy link
Copy Markdown

📖 Documentation Preview

The documentation has been built successfully for this PR.

Generated Files:

  • Operator documentation: docs/modules/operators/_generated_/
  • JSON schema output: operators.json

Artifacts:

  • Download the documentation-preview-674 artifact to view the complete documentation locally.

This comment will be updated automatically when the PR is updated.

@michalharakal michalharakal merged commit 3b8aff3 into develop Jun 6, 2026
7 checks passed
michalharakal added a commit that referenced this pull request Jun 6, 2026
…reduce_window (#675)

Follow-up to #674. Completes DAG-DSL StableHLO export so every conformance model
and op compiles with iree-compile.

inferDagOutputSpecs (skainet-lang-dag) gains shape rules for:
- conv1d  : (N,Cin,L) * (Cout,_,K) -> (N, Cout, floor((L+2p - d(K-1) - 1)/s) + 1)
- gather  : table[:axis] + indices.shape + table[axis+1:]
- maxpool/avgpool : windowed (N, C, Hout, Wout)
- flatten : collapse dims [startDim..endDim], preserving the leading batch dim
            (it was collapsing everything to rank-1, breaking the dense matmul in mnist-cnn)
conv2d already inferred via Conv2dOperation; conv1d was a GenericOperation that fell
back to echoing operand-0.

reduce_window emission (NeuralNetOperationsConverter): emit the IREE-parseable generic
region form  %r = "stablehlo.reduce_window"(%in, %init) ({ ^bb0(%a, %b): ... })  instead
of the pretty  "applies <op> over window dimensions = ..."  form IREE rejects
("has no custom assembly form"). Full NCHW-rank window attributes; region-local SSA
names made unique per op so two pools in one function don't collide in the validator;
avg-pool divisor splatted to the output type (was a scalar-vs-tensor mismatch).

MlirValidator: register region block-argument SSA defs (^bb0(%a, %b)) and every
"%x =" result on a line, so single-line region ops validate.

Verified end-to-end via an unsigned 0.28.1-SNAPSHOT on mavenLocal + skainet-iree-conformance:
27/27 ops and 7/7 models (incl. mnist-cnn) iree-compile to a vmfb.
DagConvGatherPoolExportTest 5/5 green; hlo / dag / lang-dag suites green.
@michalharakal michalharakal deleted the fix/dag-shape-output-spec-673 branch June 6, 2026 21:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant