fix(dag+hlo): conv1d/gather/pooling/flatten + IREE-valid reduce_window — 7/7 models compile (#675)#676
Merged
Conversation
…reduce_window (#675) Follow-up to #674. Completes DAG-DSL StableHLO export so every conformance model and op compiles with iree-compile. inferDagOutputSpecs (skainet-lang-dag) gains shape rules for: - conv1d : (N,Cin,L) * (Cout,_,K) -> (N, Cout, floor((L+2p - d(K-1) - 1)/s) + 1) - gather : table[:axis] + indices.shape + table[axis+1:] - maxpool/avgpool : windowed (N, C, Hout, Wout) - flatten : collapse dims [startDim..endDim], preserving the leading batch dim (it was collapsing everything to rank-1, breaking the dense matmul in mnist-cnn) conv2d already inferred via Conv2dOperation; conv1d was a GenericOperation that fell back to echoing operand-0. reduce_window emission (NeuralNetOperationsConverter): emit the IREE-parseable generic region form %r = "stablehlo.reduce_window"(%in, %init) ({ ^bb0(%a, %b): ... }) instead of the pretty "applies <op> over window dimensions = ..." form IREE rejects ("has no custom assembly form"). Full NCHW-rank window attributes; region-local SSA names made unique per op so two pools in one function don't collide in the validator; avg-pool divisor splatted to the output type (was a scalar-vs-tensor mismatch). MlirValidator: register region block-argument SSA defs (^bb0(%a, %b)) and every "%x =" result on a line, so single-line region ops validate. Verified end-to-end via an unsigned 0.28.1-SNAPSHOT on mavenLocal + skainet-iree-conformance: 27/27 ops and 7/7 models (incl. mnist-cnn) iree-compile to a vmfb. DagConvGatherPoolExportTest 5/5 green; hlo / dag / lang-dag suites green.
|
📖 Documentation Preview The documentation has been built successfully for this PR. Generated Files:
Artifacts:
This comment will be updated automatically when the PR is updated. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Closes #675. Follow-up to #674 — completes the DAG-DSL StableHLO export so every conformance model and op compiles with
iree-compile.Fixes
inferDagOutputSpecs(skainet-lang-dag) — add output-shape rules for the ops that still echoed operand-0:conv1d—(N,Cin,L) * (Cout,_,K) → (N, Cout, floor((L+2p − d(K−1) − 1)/s) + 1). conv2d already inferred viaConv2dOperation; conv1d was aGenericOperation.gather—table[:axis] ⊕ indices.shape ⊕ table[axis+1:].maxpool2d/avgpool2d— windowed(N, C, Hout, Wout).flatten— collapse[startDim..endDim], preserving the leading batch dim (it was collapsing everything to rank-1, which broke the dense matmul in mnist-cnn).reduce_windowemission (NeuralNetOperationsConverter) — emit IREE's parseable generic region form instead of the prettyapplies … over windowform IREE rejects (has no custom assembly form). Full NCHW-rank window attributes; region-local SSA names made unique per op (two pools in one function); avg-pool divisor splatted to the output type (was a scalar-vs-tensor mismatch).MlirValidator— register region block-argument SSA defs (^bb0(%a, %b)) and every%x =result on a line, so single-line region ops validate.Verification (mavenLocal, no release risk)
Published an unsigned
0.28.1-SNAPSHOTto~/.m2and ranskainet-iree-conformanceagainst it:DagConvGatherPoolExportTest5/5 green;skainet-compile-hlo/skainet-compile-dag/skainet-lang-dagjvmTest suites green.After this merges
The export →
iree-compilepath is fully green for the conformance suite. Recommend cutting 0.28.1 (verify once more via mavenLocal), then the SKaiNET-transformers release decision reopens.