feat(hlo): DSL→StableHLO path for transformer/gemma3 export #661

Merged
michalharakal merged 10 commits into develop from feature/stablehlo-dtype on Jun 4, 2026

@michalharakal

feat(hlo): DSL→StableHLO path for transformer/gemma3 export

Adds the core machinery to lower a SKaiNET DSL model (incl. gemma3) to
iree-compilable StableHLO, validated end-to-end to a vmfb that runs on
aarch64+NEON. Foundation for on-device LLM inference (DSL → StableHLO → IREE).

What's in here

  • dtype threading: placeholders carry the DSL-prescribed element type
    instead of erasing to Object (norm layers + NetworkBuilder); the loader
    later overrides with the real weight storage type. Unblocks weight-free tracing.
  • StableHLO converters: permute, narrow, split (+ multi-output
    converter support), scaledDotProductAttention (softmax(Q·Kᵀ·scale)·V), and a
    numerically-validated causal mask for SDPA.
  • trace fix: VoidTensorOps.gather output shape for multi-dim indices.
  • iree-valid emission: gather/slice/concat/batch-matmul syntax aligned to
    iree's StableHLO parser so the full gemma3 graph compiles to a vmfb.
  • .irpa baking: boxing-free FloatArray weight externalization.

Validation

  • SDPA lowering matches NumPy to 5 decimals (masked + unmasked), plus an add+relu smoke test.
  • Full gemma3 graph lowers with 0 unsupported ops / 0 arity gaps and
    iree-compiles to both gemma_cpu.vmfb and gemma_aarch64.vmfb (+neon).

Notes

  • 10 commits, branched from develop. No behavior change for existing paths
    (dtype params are non-breaking, default-null).
  • Paired with SKaiNET-transformers feature/stablehlo-dtype-mirror (the gemma
    module-side mirror + full-model A/B). Merge this first — the transformers
    branch was verified against this via composite build.

michalharakal and others added 10 commits June 1, 2026 08:25
Re-applied on develop (0.25.0). Norm layers created void placeholder weights
with `Any::class as KClass<T>`, erasing the element type to Object — which
breaks weight-free graph tracing (VoidTensorOps alloc ops -> zeros(Object)
throw "Unsupported dtype: Object"), blocking DAG->StableHLO before weights load.

Fix the root cause: the DSL prescribes the logical element type. Layer/Group/
BatchNormalization get a non-breaking `dtype: KClass<T>? = null` (used as
`(dtype ?: Any::class)`); the NetworkBuilder DSL builders pass `dtype = kClass`.
Real (possibly dequantized) weights still override at load time.

(RMSNormalization moved out of core in 0.25.0 -> handled transformer-side.)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Closes the last gaps for lowering a gemma3 attention block to StableHLO (the
other ops — matmul/reshape/transpose/squeeze/unsqueeze/concat/softmax/gather/
sqrt/addScalar/scaledDotProductAttention — are already covered on develop).

- permute: arbitrary-axis transpose. convertTranspose already reads the `axes`
  parameter, so register `permute` as an alias routed to it.
- narrow(dim,start,length): single-axis stablehlo.slice. Reads dim/start/length
  (the keys the graph tape records), builds start/limit/stride per dim.
- NarrowPermuteConverterTest verifies both: permute -> transpose dims=[0,2,1],
  narrow -> slice start_indices=[0,2] limit_indices=[2,6]. Full hlo suite green.

Remaining for full gemma: `split` (multi-output, used by RoPE).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
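
The narrow-to-slice mapping described in the commit above can be sketched as follows. This is an illustrative Python model, not the converter's Kotlin code: the function name `narrow_to_slice` is hypothetical, and the input shape `[2, 8]` is an assumption chosen so that the result reproduces the `start_indices=[0,2]` / `limit_indices=[2,6]` values the test asserts.

```python
def narrow_to_slice(shape, dim, start, length):
    """Return (start_indices, limit_indices, strides) for a single-axis narrow.

    Every axis keeps its full extent except `dim`, which is restricted
    to the half-open range [start, start + length).
    """
    if not 0 <= dim < len(shape):
        raise ValueError(f"dim {dim} out of range for rank {len(shape)}")
    if start < 0 or length < 0 or start + length > shape[dim]:
        raise ValueError("narrow range exceeds the axis extent")
    starts = [0] * len(shape)
    limits = list(shape)
    strides = [1] * len(shape)
    starts[dim] = start
    limits[dim] = start + length
    return starts, limits, strides


# Assumed input shape [2, 8]; narrow(dim=1, start=2, length=4).
starts, limits, strides = narrow_to_slice([2, 8], dim=1, start=2, length=4)
print(starts, limits, strides)  # [0, 2] [2, 6] [1, 1]
```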
The void/tracing gather collapsed the gathered axis to indices.shape[0], so a
[vocab,emb] table with [batch,seq] indices traced to [batch,emb] instead of
[batch,seq,emb] — breaking the embedding's downstream reshape ("volume mismatch
64 != 256") during weight-free tracing. Replace the axis with the FULL indices
shape, matching DefaultCpuOps.gather. Unblocks tracing full transformer
(gemma3) graphs to a ComputeGraph.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
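
A minimal sketch of the shape rule from the gather fix above, written in Python for illustration. The function names are hypothetical, and the concrete sizes (vocab 1000, batch 1, seq 4, emb 64) are assumptions: they are one combination that happens to reproduce the "64 != 256" volume mismatch quoted in the message.

```python
from math import prod


def gather_shape_collapsed(table_shape, indices_shape, axis=0):
    """Old tracing behaviour: the gathered axis becomes indices.shape[0] only."""
    return list(table_shape[:axis]) + [indices_shape[0]] + list(table_shape[axis + 1:])


def gather_shape_full(table_shape, indices_shape, axis=0):
    """Fixed behaviour: the gathered axis is replaced by the FULL indices shape."""
    return list(table_shape[:axis]) + list(indices_shape) + list(table_shape[axis + 1:])


table = [1000, 64]   # [vocab, emb], assumed sizes
indices = [1, 4]     # [batch, seq], assumed sizes

old = gather_shape_collapsed(table, indices)
new = gather_shape_full(table, indices)
print(old, prod(old))  # [1, 64] 64
print(new, prod(new))  # [1, 4, 64] 256
```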
Lowers the last gemma3 op gap (split, used by RoPE).

- ConversionContext: per-(nodeId, outputPort) SSA value names (port 0 keeps the
  bare-nodeId key, so all single-output callers are unchanged) + resolveOperands(node)
  that walks incoming edges in destinationInputIndex order and resolves each by
  the edge's sourceOutputIndex.
- StableHloConverter: resolve operands via resolveOperands so a consumer of a
  multi-output op gets the right output (e.g. split chunk N, not chunk 0).
- ShapeOperationsConverter: split/chunk -> N stablehlo.slice, each registered on
  its own output port.
- Test: split -> 2 chunk slices + a relu consuming chunk 1 resolves to chunk 1.
  Full hlo suite green (operand-resolution change is equivalent for single-output).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
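
The per-port naming and operand resolution described above can be modelled roughly like this. It is a Python sketch under stated assumptions: the class name, the `"nodeId:port"` key format for non-zero ports, and the dictionary-based edge records are all hypothetical stand-ins for the real ConversionContext and graph edge types.

```python
class ConversionContextSketch:
    """Maps (node id, output port) to an SSA value name."""

    def __init__(self):
        self._names = {}

    @staticmethod
    def _key(node_id, port):
        # Port 0 keeps the bare node id, so single-output callers are unchanged.
        return node_id if port == 0 else f"{node_id}:{port}"

    def set_value_name(self, node_id, name, port=0):
        self._names[self._key(node_id, port)] = name

    def resolve_operands(self, node_id, edges):
        """Walk incoming edges in destination-input order and resolve each
        operand through the edge's source output port."""
        incoming = sorted(
            (e for e in edges if e["dst"] == node_id),
            key=lambda e: e["dst_input"],
        )
        return [self._names[self._key(e["src"], e["src_output"])] for e in incoming]


ctx = ConversionContextSketch()
ctx.set_value_name("split0", "%v1", port=0)  # chunk 0
ctx.set_value_name("split0", "%v2", port=1)  # chunk 1

# A relu that consumes chunk 1 of the split.
edges = [{"src": "split0", "src_output": 1, "dst": "relu0", "dst_input": 0}]
print(ctx.resolve_operands("relu0", edges))  # ['%v2']
```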
Adds AttentionOperationsConverter: lowers the atomic scaledDotProductAttention
op to the standard StableHLO subgraph — scores = Q·Kᵀ (dot_general, contract
head_dim), * scale (arg or 1/sqrt(head_dim)), softmax over key length (stable
max/sub/exp/sum/div), out = attn·V (dot_general). Batched [..,S,D]; batching
dims = all leading dims. Registered in StableHloConverterFactory.
v1: attention mask/causal not yet emitted (structurally correct, unmasked) — TODO.

With this, a full gemma3 network lowers to StableHLO with ZERO gaps (verified by
GemmaTraceTest over the composite build: 140 nodes -> 255 lines, 0 unsupported,
0 arity). SDPA is a core TensorOps op so its converter lives in core. Unit test
asserts 2 dot_generals + softmax + scores shape.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
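
For reference, the unmasked decomposition listed above (scores, scale, stable softmax, weighted sum) looks like this for a single head in plain Python. This is only a readable model of the math, not the emitted StableHLO; the function name `sdpa` and the list-of-lists layout are assumptions made for the example.

```python
import math


def sdpa(q, k, v, scale=None):
    """Unmasked scaled dot-product attention for one head.

    q: [S_q][D], k: [S_k][D], v: [S_k][D_v] as nested lists.
    """
    head_dim = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)
    out = []
    for q_row in q:
        # scores = (q . k^T) * scale, contracting over head_dim
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        # numerically stable softmax over the key axis: max / sub / exp / sum / div
        row_max = max(scores)
        exps = [math.exp(s - row_max) for s in scores]
        total = sum(exps)
        attn = [e / total for e in exps]
        # out = attn . v
        out.append([sum(w * v_row[c] for w, v_row in zip(attn, v))
                    for c in range(len(v[0]))])
    return out


# A zero query gives equal scores, so the output is the mean of the value rows.
print(sdpa([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))
# [[2.0, 3.0]]
```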
…vmfb

Aligned converter emission to what iree-compile's stablehlo parser accepts
(verified by compiling the full gemma3 graph end-to-end):
- gather: use the GENERIC MLIR form `"stablehlo.gather"(%a,%b) <{...}>`
  (stablehlo.gather has no custom assembly form).
- slice/narrow/split: canonical bracket form `stablehlo.slice %x [s:l:st, ...]
  : (in) -> out` (attribute-dict form is rejected) — shared sliceLine() helper.
- concatenate: full functional type `(t0, t1, ...) -> out` (bare `: out` rejected).
- batch matmul: batch dims = leading dims shared by BOTH operands
  (min(lhsRank,rhsRank)-2); fixes 3D-activation @ 2D-weight Linear projections
  that previously emitted mismatched batching_dims=[0]x[0].
- updated converter tests to the valid forms.

Result: SKaiNET gemma3 DSL -> StableHLO -> iree-compile (llvm-cpu; +neon
aarch64) -> vmfb, both host x64 and aarch64 targets.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
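
The batch-dimension rule for matmul quoted above, `min(lhsRank, rhsRank) - 2`, can be checked with a few ranks. The helper below is a hypothetical Python illustration of that arithmetic only; it does not model contracting dimensions or the emitted `dot_general` text.

```python
def batch_dims(lhs_rank, rhs_rank):
    """Leading dims shared by both operands are the batching dims."""
    count = min(lhs_rank, rhs_rank) - 2
    if count < 0:
        raise ValueError("both operands need rank >= 2 for a matmul")
    return list(range(count))


print(batch_dims(3, 2))  # []      3D activation @ 2D weight: no batching dims
print(batch_dims(3, 3))  # [0]
print(batch_dims(4, 4))  # [0, 1]
```

The first case is the Linear projection from the message: with a 2D weight there is no shared leading dimension, so emitting `batching_dims=[0]x[0]` would be a mismatch.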
Dumps a small scaledDotProductAttention StableHLO ([1,1,2,4], scale 0.5);
iree-compile + iree-run-module output matches a NumPy reference exactly to 5
decimals (QKᵀ·scale·softmax·V). Confirms the attention converter is numerically
correct, not just structurally valid.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…idated)

When the SDPA node's `causal` attr is set, emit an additive -inf mask before
softmax: iota(query axis) / iota(key axis) -> compare GE -> select(keep, 0, -inf)
-> add to scaled scores. Each query attends only to keys at or before it.

Validated EXACT vs a NumPy causal reference (S=2): query0 -> v[0] only,
query1 -> softmax over both keys. iota/compare/select lowering accepted by
iree-compile and numerically correct.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
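
The causal path described above (iota on both axes, compare GE, select 0 or -inf, add before softmax) can be modelled in plain Python. Function names are hypothetical, and the all-zero scores are an assumption used to make the S=2 behaviour easy to read.

```python
import math


def causal_mask(num_queries, num_keys):
    """Additive mask: 0 where query index >= key index, -inf otherwise."""
    neg_inf = float("-inf")
    return [[0.0 if qi >= ki else neg_inf for ki in range(num_keys)]
            for qi in range(num_queries)]


def masked_softmax(scores, mask):
    """Add the mask to the scores, then apply a stable softmax per row."""
    rows = []
    for score_row, mask_row in zip(scores, mask):
        masked = [s + m for s, m in zip(score_row, mask_row)]
        row_max = max(masked)
        exps = [math.exp(x - row_max) for x in masked]  # exp(-inf) == 0.0
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


scores = [[0.0, 0.0], [0.0, 0.0]]  # S = 2, assumed equal scores
print(masked_softmax(scores, causal_mask(2, 2)))
# [[1.0, 0.0], [0.5, 0.5]]
```

Row 0 puts all weight on key 0, and row 1 spreads weight over both keys, which is the S=2 behaviour the commit reports.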
…king

finalize() now stores resolved weights as the primitive FloatArray instead of
.toList() — boxing a real LLM weight (262153x640 embedding -> ~2.7GB List<Float>)
OOMed the trace. ConstantOperationsConverter externalizes FloatArray directly
(new floatArrayToLittleEndianBytes + tryMaterializeExternalFloats), and inlines
via asList() for small/InlineAlways. IrpaWriter writes byte ranges in one shot
(byte-at-a-time was pathological for ~670MB tensors).

With this, the real FunctionGemma-270M bakes: 1 func arg (tokens) + 360 weights
externalized to util.global #flow.parameter.named. Note that IrpaWriter's archive
header is still IREE-incompatible (40B vs IREE's 88B v0 header, plus a different
segment layout); this is tracked separately, and bake currently routes weights via
safetensors + iree-convert-parameters.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
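
To make the memory figures above concrete, here is a small Python sketch of the same idea: pack floats straight into little-endian bytes, and compare the primitive footprint with a boxed one. The function name mirrors `floatArrayToLittleEndianBytes` but is not the Kotlin implementation, and the 16 bytes per boxed element is an assumption (it is the per-element cost that matches the ~2.7GB figure).

```python
import struct


def floats_to_little_endian_bytes(values):
    """Pack float32 values into one contiguous little-endian byte string."""
    return struct.pack(f"<{len(values)}f", *values)


data = floats_to_little_endian_bytes([1.0, -2.0])
print(data.hex())  # 0000803f000000c0

# Footprint of the 262153 x 640 embedding from the commit message.
elements = 262153 * 640
primitive_bytes = elements * 4   # float32 storage
boxed_bytes = elements * 16      # assumed cost per boxed Float
print(elements)         # 167777920
print(primitive_bytes)  # 671111680  (about 670 MB)
print(boxed_bytes)      # 2684446720 (about 2.7 GB)
```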
AttentionOperationsConverter ignored operands[3] (the additive mask) and only
masked when causal=true. Gemma sliding-window layers pass an explicit
causal+window mask with causal=false, so those layers exported UNMASKED ->
attended to future tokens -> A/B vs llama.cpp correct only at position 0.
Now broadcast (trailing-aligned) the mask to the scores shape and add it before
softmax; the built-in iota causal path remains for the no-mask case.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
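
A rough Python sketch of the two pieces involved in the explicit-mask path above: an additive causal+window mask, and the trailing-aligned axis mapping needed to broadcast it to the scores shape. Both function names are hypothetical. The window rule (`0 <= q - k < window`) is a common sliding-window definition assumed here, not taken from the Gemma module, and expressing the broadcast as a list of target axes is likewise an assumption about how the converter might describe it.

```python
def sliding_window_mask(seq_len, window):
    """Additive mask: 0 where the key is within `window` positions at or
    before the query, -inf elsewhere (assumed window definition)."""
    neg_inf = float("-inf")
    return [[0.0 if 0 <= qi - ki < window else neg_inf for ki in range(seq_len)]
            for qi in range(seq_len)]


def trailing_broadcast_dims(mask_shape, scores_shape):
    """Map each mask axis to the scores axis it aligns with, counting from
    the trailing end. A mask axis must be 1 or equal to its target extent."""
    offset = len(scores_shape) - len(mask_shape)
    if offset < 0:
        raise ValueError("mask rank exceeds scores rank")
    dims = []
    for i, extent in enumerate(mask_shape):
        target = scores_shape[offset + i]
        if extent != 1 and extent != target:
            raise ValueError(f"mask axis {i} ({extent}) does not match {target}")
        dims.append(offset + i)
    return dims


for row in sliding_window_mask(3, 2):
    print(row)
# [0.0, -inf, -inf]
# [0.0, 0.0, -inf]
# [-inf, 0.0, 0.0]

print(trailing_broadcast_dims([4, 4], [1, 8, 4, 4]))        # [2, 3]
print(trailing_broadcast_dims([1, 1, 4, 4], [1, 8, 4, 4]))  # [0, 1, 2, 3]
```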
@michalharakal michalharakal merged commit 167e1e9 into develop Jun 4, 2026
6 of 7 checks passed
@michalharakal michalharakal deleted the feature/stablehlo-dtype branch June 4, 2026 17:36
@github-actions

github-actions Bot commented Jun 4, 2026

📖 Documentation Preview

The documentation has been built successfully for this PR.

Generated Files:

  • Operator documentation: docs/modules/operators/_generated_/
  • JSON schema output: operators.json

Artifacts:

  • Download the documentation-preview-661 artifact to view the complete documentation locally.

This comment will be updated automatically when the PR is updated.
