feat(hlo): DSL→StableHLO path for transformer/gemma3 export#661
Merged
Conversation
Re-applied on develop (0.25.0). Norm layers created void placeholder weights with `Any::class as KClass<T>`, erasing the element type to Object — which breaks weight-free graph tracing (VoidTensorOps alloc ops -> zeros(Object) throw "Unsupported dtype: Object"), blocking DAG->StableHLO before weights load. Fix the root cause: the DSL prescribes the logical element type. Layer/Group/ BatchNormalization get a non-breaking `dtype: KClass<T>? = null` (used as `(dtype ?: Any::class)`); the NetworkBuilder DSL builders pass `dtype = kClass`. Real (possibly dequantized) weights still override at load time. (RMSNormalization moved out of core in 0.25.0 -> handled transformer-side.) Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Closes the last gaps for lowering a gemma3 attention block to StableHLO (the other ops — matmul/reshape/transpose/squeeze/unsqueeze/concat/softmax/gather/ sqrt/addScalar/scaledDotProductAttention — are already covered on develop). - permute: arbitrary-axis transpose. convertTranspose already reads the `axes` parameter, so register `permute` as an alias routed to it. - narrow(dim,start,length): single-axis stablehlo.slice. Reads dim/start/length (the keys the graph tape records), builds start/limit/stride per dim. - NarrowPermuteConverterTest verifies both: permute -> transpose dims=[0,2,1], narrow -> slice start_indices=[0,2] limit_indices=[2,6]. Full hlo suite green. Remaining for full gemma: `split` (multi-output, used by RoPE). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The void/tracing gather collapsed the gathered axis to indices.shape[0], so a
[vocab,emb] table with [batch,seq] indices traced to [batch,emb] instead of
[batch,seq,emb] — breaking the embedding's downstream reshape ("volume mismatch
64 != 256") during weight-free tracing. Replace the axis with the FULL indices
shape, matching DefaultCpuOps.gather. Unblocks tracing full transformer
(gemma3) graphs to a ComputeGraph.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Lowers the last gemma3 op gap (split, used by RoPE). - ConversionContext: per-(nodeId, outputPort) SSA value names (port 0 keeps the bare-nodeId key, so all single-output callers are unchanged) + resolveOperands(node) that walks incoming edges in destinationInputIndex order and resolves each by the edge's sourceOutputIndex. - StableHloConverter: resolve operands via resolveOperands so a consumer of a multi-output op gets the right output (e.g. split chunk N, not chunk 0). - ShapeOperationsConverter: split/chunk -> N stablehlo.slice, each registered on its own output port. - Test: split -> 2 chunk slices + a relu consuming chunk 1 resolves to chunk 1. Full hlo suite green (operand-resolution change is equivalent for single-output). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Adds AttentionOperationsConverter: lowers the atomic scaledDotProductAttention op to the standard StableHLO subgraph — scores = Q·Kᵀ (dot_general, contract head_dim), * scale (arg or 1/sqrt(head_dim)), softmax over key length (stable max/sub/exp/sum/div), out = attn·V (dot_general). Batched [..,S,D]; batching dims = all leading dims. Registered in StableHloConverterFactory. v1: attention mask/causal not yet emitted (structurally correct, unmasked) — TODO. With this, a full gemma3 network lowers to StableHLO with ZERO gaps (verified by GemmaTraceTest over the composite build: 140 nodes -> 255 lines, 0 unsupported, 0 arity). SDPA is a core TensorOps op so its converter lives in core. Unit test asserts 2 dot_generals + softmax + scores shape. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…vmfb
Aligned converter emission to what iree-compile's stablehlo parser accepts
(verified by compiling the full gemma3 graph end-to-end):
- gather: use the GENERIC MLIR form `"stablehlo.gather"(%a,%b) <{...}>`
(stablehlo.gather has no custom assembly form).
- slice/narrow/split: canonical bracket form `stablehlo.slice %x [s:l:st, ...]
: (in) -> out` (attribute-dict form is rejected) — shared sliceLine() helper.
- concatenate: full functional type `(t0, t1, ...) -> out` (bare `: out` rejected).
- batch matmul: batch dims = leading dims shared by BOTH operands
(min(lhsRank,rhsRank)-2); fixes 3D-activation @ 2D-weight Linear projections
that previously emitted mismatched batching_dims=[0]x[0].
- updated converter tests to the valid forms.
Result: SKaiNET gemma3 DSL -> StableHLO -> iree-compile (llvm-cpu; +neon
aarch64) -> vmfb, both host x64 and aarch64 targets.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Dumps a small scaledDotProductAttention StableHLO ([1,1,2,4], scale 0.5); iree-compile + iree-run-module output matches a NumPy reference exactly to 5 decimals (QKᵀ·scale·softmax·V). Confirms the attention converter is numerically correct, not just structurally valid. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…idated) When the SDPA node's `causal` attr is set, emit an additive -inf mask before softmax: iota(query axis) / iota(key axis) -> compare GE -> select(keep, 0, -inf) -> add to scaled scores. Each query attends only to keys at or before it. Validated EXACT vs a NumPy causal reference (S=2): query0 -> v[0] only, query1 -> softmax over both keys. iota/compare/select lowering accepted by iree-compile and numerically correct. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…king finalize() now stores resolved weights as the primitive FloatArray instead of .toList() — boxing a real LLM weight (262153x640 embedding -> ~2.7GB List<Float>) OOMed the trace. ConstantOperationsConverter externalizes FloatArray directly (new floatArrayToLittleEndianBytes + tryMaterializeExternalFloats), and inlines via asList() for small/InlineAlways. IrpaWriter writes byte ranges in one shot (byte-at-a-time was pathological for ~670MB tensors). With this, the real FunctionGemma-270M bakes: 1 func arg (tokens) + 360 weights externalized to util.global #flow.parameter.named. (IrpaWriter's archive header is still IREE-incompatible (40B vs IREE's 88B v0 header + different segment layout) — tracked separately; bake currently routes weights via safetensors + iree-convert-parameters.) Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
AttentionOperationsConverter ignored operands[3] (the additive mask) and only masked when causal=true. Gemma sliding-window layers pass an explicit causal+window mask with causal=false, so those layers exported UNMASKED -> attended to future tokens -> A/B vs llama.cpp correct only at position 0. Now broadcast (trailing-aligned) the mask to the scores shape and add it before softmax; the built-in iota causal path remains for the no-mask case. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
📖 Documentation Preview The documentation has been built successfully for this PR. Generated Files:
Artifacts:
This comment will be updated automatically when the PR is updated. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
feat(hlo): DSL→StableHLO path for transformer/gemma3 export
Adds the core machinery to lower a SKaiNET DSL model (incl. gemma3) to
iree-compilable StableHLO, validated end-to-end to a vmfb that runs on
aarch64+NEON. Foundation for on-device LLM inference (DSL → StableHLO → IREE).
What's in here
instead of erasing to
Object(norm layers +NetworkBuilder); the loaderlater overrides with real weight storage type. Unblocks weight-free tracing.
permute,narrow,split(+ multi-outputconverter support),
scaledDotProductAttention(Q·Kᵀ·scale·softmax··V), and anumerically-validated causal mask for SDPA.
VoidTensorOps.gatheroutput shape for multi-dim indices.iree's StableHLO parser so the full gemma3 graph compiles to a vmfb.
.irpabaking — boxing-freeFloatArrayweight externalization.Validation
iree-compiles to bothgemma_cpu.vmfbandgemma_aarch64.vmfb(+neon).Notes
develop. No behavior change for existing paths(dtype params are non-breaking, default-null).
feature/stablehlo-dtype-mirror(the gemmamodule-side mirror + full-model A/B). Merge this first — the transformers
branch was verified against this via composite build.