From be60c40a300fe34ccfc6c787eb129caf5e02c2c4 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Mon, 18 May 2026 21:36:37 -0300 Subject: [PATCH 01/13] Return groups added to Source --- Ix/Aiur/Compiler/Check.lean | 1 + Ix/Aiur/Interpret.lean | 1 + Ix/Aiur/Meta.lean | 6 ++++++ Ix/Aiur/Semantics/SourceEval.lean | 1 + Ix/Aiur/Stages/Source.lean | 1 + Tests/Aiur/Aiur.lean | 30 ++++++++++++++++++++++++++++++ 6 files changed, 40 insertions(+) diff --git a/Ix/Aiur/Compiler/Check.lean b/Ix/Aiur/Compiler/Check.lean index 6970abbb..c0dd5c8f 100644 --- a/Ix/Aiur/Compiler/Check.lean +++ b/Ix/Aiur/Compiler/Check.lean @@ -790,6 +790,7 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with | some sub => do pure (some (← inferTerm sub)) let ret' ← inferTerm ret pure (Typed.Term.debug ret'.typ ret'.escapes label term' ret') + | .retGroup _ inner => inferTerm inner termination_by (sizeOf t, 0) decreasing_by all_goals first diff --git a/Ix/Aiur/Interpret.lean b/Ix/Aiur/Interpret.lean index 1926af1b..d7751b15 100644 --- a/Ix/Aiur/Interpret.lean +++ b/Ix/Aiur/Interpret.lean @@ -419,6 +419,7 @@ partial def interp (decls : Decls) (bindings : Bindings) : Term → InterpM Valu let dataGs ← expectFieldArray (← interp decls bindings data) modifyIOBuffer fun io => { io with data := io.data ++ dataGs } interp decls bindings ret + | .retGroup _ inner => interp decls bindings inner end diff --git a/Ix/Aiur/Meta.lean b/Ix/Aiur/Meta.lean index c85a5b94..5f3930e2 100644 --- a/Ix/Aiur/Meta.lean +++ b/Ix/Aiur/Meta.lean @@ -180,6 +180,7 @@ syntax "u8_or" "(" aiur_trm ", " aiur_trm ")" : ai syntax "u8_less_than" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u32_less_than" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "dbg!" "(" str (", " aiur_trm)? ")" ";" (aiur_trm)? 
: aiur_trm +syntax "#[" "return_group" "(" ident ")" "]" aiur_trm : aiur_trm syntax aiur_trm "[" "@" noWs ident "]" : aiur_trm syntax "set" "(" aiur_trm ", " "@" noWs ident ", " aiur_trm ")" : aiur_trm @@ -310,6 +311,8 @@ partial def elabTrm : ElabStxCat `aiur_trm | none => mkAppOptM ``Option.none #[some (mkConst ``Source.Term)] | some t => mkAppM ``Option.some #[← elabTrm t] mkAppM ``Source.Term.debug #[mkStrLit label.getString, t, ← elabRet ret] + | `(aiur_trm| #[return_group($name:ident)] $t:aiur_trm) => do + mkAppM ``Source.Term.retGroup #[mkStrLit name.getId.toString, ← elabTrm t] -- Template function calls: explicit type args are dropped (inferred) | `(aiur_trm| $f:ident‹$_:aiur_typ $[, $_:aiur_typ]*›()) => do let g ← mkAppM ``Global.mk #[toExpr f.getId] @@ -511,6 +514,9 @@ where let t' ← t.mapM $ replaceToken old new let ret' ← ret.mapM $ replaceToken old new `(aiur_trm| dbg!($label $[, $t']?); $[$ret']?) + | `(aiur_trm| #[return_group($name:ident)] $t:aiur_trm) => do + let t ← replaceToken old new t + `(aiur_trm| #[return_group($name)] $t) | `(aiur_trm| fold($i .. $j, $init, |$acc, @$v| $body)) => do let init ← replaceToken old new init -- Don't conflict with shadowing tokens. 
diff --git a/Ix/Aiur/Semantics/SourceEval.lean b/Ix/Aiur/Semantics/SourceEval.lean index c0cffea7..fb8c8f8d 100644 --- a/Ix/Aiur/Semantics/SourceEval.lean +++ b/Ix/Aiur/Semantics/SourceEval.lean @@ -474,6 +474,7 @@ def interp (decls : Decls) (fuel : Nat) (bindings : Bindings) { st'.ioBuffer with data := st'.ioBuffer.data ++ dataGs } } interp decls fuel bindings ret st'' | _ => .error (.typeMismatch "ioWrite") + | .retGroup _ inner => interp decls fuel bindings inner st termination_by (fuel, 2, sizeOf t) decreasing_by all_goals first diff --git a/Ix/Aiur/Stages/Source.lean b/Ix/Aiur/Stages/Source.lean index f1209021..e989e470 100644 --- a/Ix/Aiur/Stages/Source.lean +++ b/Ix/Aiur/Stages/Source.lean @@ -389,6 +389,7 @@ inductive Term | u8LessThan : Term → Term → Term | u32LessThan : Term → Term → Term | debug : String → Option Term → Term → Term + | retGroup : String → Term → Term deriving Repr, BEq, Hashable, Inhabited end Source diff --git a/Tests/Aiur/Aiur.lean b/Tests/Aiur/Aiur.lean index 6a0913b3..dd61de5a 100644 --- a/Tests/Aiur/Aiur.lean +++ b/Tests/Aiur/Aiur.lean @@ -705,6 +705,26 @@ def toplevel := ⟦ let (x, y) = ntm_tuple(a); x + y } + -- Return-group annotation: ignored by compiler/typechecker, must pass through + -- to inner term. Match with distinct group labels per arm. 
+ pub fn match_return_groups(x: G) -> G { + match x { + 0 => + #[return_group(zero)] + 100, + 1 => + #[return_group(one)] + x + 200, + 2 => + #[return_group(two_squared)] + x * x * x, + _ => + #[return_group(default)] + #[return_group(nested)] + x + 1, + } + } + pub fn non_tail_match() -> G { -- Basic, early return, sequential, nested, const mul let r1 = ntm_basic(0) + ntm_basic(5); @@ -884,6 +904,16 @@ def aiurTestCases : List AiurTestCase := [ -- Non-tail match: all patterns in one proof .noIO `non_tail_match #[] #[2281], + + -- Return-group annotation: passthrough across match arms + { AiurTestCase.noIO `match_return_groups #[0] #[100] + with label := "match_return_groups(0)" }, + { AiurTestCase.noIO `match_return_groups #[1] #[201] + with label := "match_return_groups(1)" }, + { AiurTestCase.noIO `match_return_groups #[2] #[8] + with label := "match_return_groups(2)" }, + { AiurTestCase.noIO `match_return_groups #[7] #[8] + with label := "match_return_groups(7)" }, ] end From 70dcbfb2db197a2cdf0ddf7263ad04280f7949cf Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Mon, 18 May 2026 21:45:23 -0300 Subject: [PATCH 02/13] Return groups added to other stages --- Ix/Aiur/Compiler/Check.lean | 6 +++++- Ix/Aiur/Compiler/Concretize.lean | 8 ++++++++ Ix/Aiur/Compiler/Lower.lean | 3 +++ Ix/Aiur/Compiler/Match.lean | 1 + Ix/Aiur/Compiler/Simple.lean | 3 +++ Ix/Aiur/Stages/Concrete.lean | 5 +++-- Ix/Aiur/Stages/Simple.lean | 5 +++-- Ix/Aiur/Stages/Typed.lean | 5 +++-- 8 files changed, 29 insertions(+), 7 deletions(-) diff --git a/Ix/Aiur/Compiler/Check.lean b/Ix/Aiur/Compiler/Check.lean index c0dd5c8f..eeaacc03 100644 --- a/Ix/Aiur/Compiler/Check.lean +++ b/Ix/Aiur/Compiler/Check.lean @@ -790,7 +790,9 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with | some sub => do pure (some (← inferTerm sub)) let ret' ← inferTerm ret pure (Typed.Term.debug ret'.typ ret'.escapes label term' ret') - | .retGroup _ inner => inferTerm inner + | .retGroup name inner => do + 
let inner' ← inferTerm inner + pure (Typed.Term.retGroup inner'.typ inner'.escapes name inner') termination_by (sizeOf t, 0) decreasing_by all_goals first @@ -918,6 +920,8 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with | none => pure none | some sub => do pure (some (← zonkTypedTerm sub)) pure (.debug (← zonkTyp τ) e label t' (← zonkTypedTerm r)) + | .retGroup τ e name inner => do + pure (.retGroup (← zonkTyp τ) e name (← zonkTypedTerm inner)) termination_by sizeOf t decreasing_by all_goals first diff --git a/Ix/Aiur/Compiler/Concretize.lean b/Ix/Aiur/Compiler/Concretize.lean index b9924aee..f6931ae9 100644 --- a/Ix/Aiur/Compiler/Concretize.lean +++ b/Ix/Aiur/Compiler/Concretize.lean @@ -349,6 +349,8 @@ def termToConcrete | none => pure none | some sub => do pure (some (← termToConcrete mono sub)) pure (.debug (← typToConcrete mono τ) e l t' (← termToConcrete mono r)) + | .retGroup τ e name inner => do + pure (.retGroup (← typToConcrete mono τ) e name (← termToConcrete mono inner)) termination_by t => sizeOf t decreasing_by all_goals first @@ -541,6 +543,8 @@ def rewriteTypedTerm (decls : Typed.Decls) | none => none | some sub => some (rewriteTypedTerm decls subst mono sub) .debug (rewriteTyp subst mono τ) e l t' (rewriteTypedTerm decls subst mono r) + | .retGroup τ e name inner => + .retGroup (rewriteTyp subst mono τ) e name (rewriteTypedTerm decls subst mono inner) termination_by t => sizeOf t decreasing_by all_goals first @@ -625,6 +629,7 @@ def collectInTypedTerm (seen : Std.HashSet (Global × Array Typ)) : let seen := collectInTyp seen τ let seen := match t with | some t => collectInTypedTerm seen t | none => seen collectInTypedTerm seen r + | .retGroup τ _ _ inner => collectInTypedTerm (collectInTyp seen τ) inner termination_by t => sizeOf t decreasing_by all_goals first @@ -683,6 +688,7 @@ def collectCalls (decls : Typed.Decls) | .debug _ _ _ t r => let seen := match t with | some t => collectCalls decls seen t | none => seen 
collectCalls decls seen r + | .retGroup _ _ _ inner => collectCalls decls seen inner termination_by t => sizeOf t decreasing_by all_goals first @@ -771,6 +777,8 @@ def substInTypedTerm (subst : Global → Option Typ) : Typed.Term → Typed.Term | none => none | some sub => some (substInTypedTerm subst sub) .debug (Typ.instantiate subst τ) e l t' (substInTypedTerm subst r) + | .retGroup τ e name inner => + .retGroup (Typ.instantiate subst τ) e name (substInTypedTerm subst inner) termination_by t => sizeOf t decreasing_by all_goals first diff --git a/Ix/Aiur/Compiler/Lower.lean b/Ix/Aiur/Compiler/Lower.lean index a50f4ef2..dd9e63d3 100644 --- a/Ix/Aiur/Compiler/Lower.lean +++ b/Ix/Aiur/Compiler/Lower.lean @@ -299,6 +299,7 @@ def toIndex | some sub => do pure (some (← toIndex layoutMap bindings sub)) modify fun stt => { stt with ops := stt.ops.push (.debug label term) } toIndex layoutMap bindings ret + | .retGroup _ _ _ inner => toIndex layoutMap bindings inner termination_by (sizeOf term, 0) decreasing_by all_goals first @@ -446,6 +447,8 @@ def Concrete.Term.compile let data ← toIndex layoutMap bindings data modify fun stt => { stt with ops := stt.ops.push (.ioWrite data) } ret.compile returnTyp layoutMap bindings yieldCtrl + | .retGroup _ _ _ inner => + inner.compile returnTyp layoutMap bindings yieldCtrl | .match _ _ scrut cases defaultOpt => do let idxs := bindings[scrut]?.getD #[0] let ops ← extractOps diff --git a/Ix/Aiur/Compiler/Match.lean b/Ix/Aiur/Compiler/Match.lean index 3a92d00e..15266385 100644 --- a/Ix/Aiur/Compiler/Match.lean +++ b/Ix/Aiur/Compiler/Match.lean @@ -389,6 +389,7 @@ def typedToSimple : Term → Simple.Term | .debug τ e l t r => let t' := match t with | none => none | some sub => some (typedToSimple sub) .debug τ e l t' (typedToSimple r) + | .retGroup τ e name inner => .retGroup τ e name (typedToSimple inner) termination_by t => sizeOf t decreasing_by all_goals first | decreasing_tactic | grind diff --git a/Ix/Aiur/Compiler/Simple.lean 
b/Ix/Aiur/Compiler/Simple.lean index ec33f504..64f03477 100644 --- a/Ix/Aiur/Compiler/Simple.lean +++ b/Ix/Aiur/Compiler/Simple.lean @@ -108,6 +108,9 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term let a' ← simplifyTypedTerm decls a let b' ← simplifyTypedTerm decls b pure (.u32LessThan τ e a' b') + | .retGroup τ e name inner => do + let inner' ← simplifyTypedTerm decls inner + pure (.retGroup τ e name inner') | t => pure t termination_by t => sizeOf t decreasing_by diff --git a/Ix/Aiur/Stages/Concrete.lean b/Ix/Aiur/Stages/Concrete.lean index 3e9eaa50..dc6f69f0 100644 --- a/Ix/Aiur/Stages/Concrete.lean +++ b/Ix/Aiur/Stages/Concrete.lean @@ -85,6 +85,7 @@ inductive Term : Type where | u8LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u32LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | debug (typ : Typ) (escapes : Bool) (label : String) (t : Option Term) (r : Term) : Term + | retGroup (typ : Typ) (escapes : Bool) (name : String) (inner : Term) : Term deriving Repr, Inhabited /-- Get the type annotation of a Concrete.Term, regardless of constructor. -/ @@ -101,7 +102,7 @@ def Term.typ : Term → Typ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ - | .debug t _ _ _ _ => t + | .debug t _ _ _ _ | .retGroup t _ _ _ => t /-- Get the escapes flag of a Concrete.Term. 
-/ def Term.escapes : Term → Bool @@ -117,7 +118,7 @@ def Term.escapes : Term → Bool | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ - | .debug _ e _ _ _ => e + | .debug _ e _ _ _ | .retGroup _ e _ _ => e structure Constructor where nameHead : String diff --git a/Ix/Aiur/Stages/Simple.lean b/Ix/Aiur/Stages/Simple.lean index 996b5d21..28eb7d86 100644 --- a/Ix/Aiur/Stages/Simple.lean +++ b/Ix/Aiur/Stages/Simple.lean @@ -81,6 +81,7 @@ inductive Term : Type where | u8LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u32LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | debug (typ : Typ) (escapes : Bool) (label : String) (t : Option Term) (r : Term) : Term + | retGroup (typ : Typ) (escapes : Bool) (name : String) (inner : Term) : Term deriving Repr, Inhabited def Term.typ : Term → Typ @@ -96,7 +97,7 @@ def Term.typ : Term → Typ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ - | .debug t _ _ _ _ => t + | .debug t _ _ _ _ | .retGroup t _ _ _ => t def Term.escapes : Term → Bool | .unit _ e | .var _ e _ | .ref _ e _ _ | .field _ e _ @@ -111,7 +112,7 @@ def Term.escapes : Term → Bool | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ - | .debug _ e _ _ _ => e + | .debug _ e _ _ _ | .retGroup _ e _ _ => e structure Function where name : Global diff --git a/Ix/Aiur/Stages/Typed.lean b/Ix/Aiur/Stages/Typed.lean index 2a3c6f33..7f84218e 100644 --- a/Ix/Aiur/Stages/Typed.lean +++ b/Ix/Aiur/Stages/Typed.lean @@ -57,6 +57,7 @@ inductive Term : Type 
where | u8LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u32LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | debug (typ : Typ) (escapes : Bool) (label : String) (t : Option Term) (r : Term) : Term + | retGroup (typ : Typ) (escapes : Bool) (name : String) (inner : Term) : Term deriving Repr, Inhabited /-- Get the type annotation, regardless of constructor. -/ @@ -72,7 +73,7 @@ def Term.typ : Term → Typ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ - | .debug t _ _ _ _ => t + | .debug t _ _ _ _ | .retGroup t _ _ _ => t /-- Get the escapes flag. -/ def Term.escapes : Term → Bool @@ -87,7 +88,7 @@ def Term.escapes : Term → Bool | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ - | .debug _ e _ _ _ => e + | .debug _ e _ _ _ | .retGroup _ e _ _ => e structure Function where name : Global From 843befa679964ed701353686068ff595f3980b61 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Mon, 18 May 2026 21:53:21 -0300 Subject: [PATCH 03/13] Added groups to Lean `Bytecode.return` --- Ix/Aiur/Compiler.lean | 2 +- Ix/Aiur/Compiler/Dedup.lean | 4 ++-- Ix/Aiur/Compiler/Lower.lean | 17 ++++++++++++----- Ix/Aiur/Semantics/BytecodeEval.lean | 2 +- Ix/Aiur/Stages/Bytecode.lean | 2 +- src/ffi/aiur/toplevel.rs | 2 +- 6 files changed, 18 insertions(+), 11 deletions(-) diff --git a/Ix/Aiur/Compiler.lean b/Ix/Aiur/Compiler.lean index ae250399..4ae98ad4 100644 --- a/Ix/Aiur/Compiler.lean +++ b/Ix/Aiur/Compiler.lean @@ -68,7 +68,7 @@ def Bytecode.Ctrl.collectConstrainedCallees (c : Bytecode.Ctrl) : | some block => branchCallees ++ block.collectConstrainedCallees | none => branchCallees withDefault ++ 
continuation.collectConstrainedCallees - | .return _ _ | .yield _ _ => #[] + | .return _ _ _ | .yield _ _ => #[] termination_by (sizeOf c, 0) decreasing_by all_goals first diff --git a/Ix/Aiur/Compiler/Dedup.lean b/Ix/Aiur/Compiler/Dedup.lean index 8a9d70bf..6b364d9a 100644 --- a/Ix/Aiur/Compiler/Dedup.lean +++ b/Ix/Aiur/Compiler/Dedup.lean @@ -31,7 +31,7 @@ mutual | .match v branches def_ => .match v (branches.attach.map fun ⟨(g, b), _⟩ => (g, skeletonBlock b)) (match def_ with | none => none | some b => some (skeletonBlock b)) - | .return s vs => .return s vs + | .return s g vs => .return s g vs | .yield s vs => .yield s vs | .matchContinue v branches def_ outputSize sharedAux sharedLookups cont => .matchContinue v (branches.attach.map fun ⟨(g, b), _⟩ => (g, skeletonBlock b)) @@ -107,7 +107,7 @@ mutual | .match v branches def_ => .match v (branches.attach.map fun ⟨(g, b), _⟩ => (g, rewriteBlock f b)) (match def_ with | none => none | some b => some (rewriteBlock f b)) - | .return s vs => .return s vs + | .return s g vs => .return s g vs | .yield s vs => .yield s vs | .matchContinue v branches def_ outputSize sharedAux sharedLookups cont => .matchContinue v (branches.attach.map fun ⟨(g, b), _⟩ => (g, rewriteBlock f b)) diff --git a/Ix/Aiur/Compiler/Lower.lean b/Ix/Aiur/Compiler/Lower.lean index dd9e63d3..bad47fbe 100644 --- a/Ix/Aiur/Compiler/Lower.lean +++ b/Ix/Aiur/Compiler/Lower.lean @@ -94,6 +94,7 @@ structure CompilerState where ops : Array Bytecode.Op selIdx : Bytecode.SelIdx degrees : Array Nat + currentReturnGroup : String := "" deriving Inhabited abbrev CompileM := EStateM String CompilerState @@ -447,8 +448,12 @@ def Concrete.Term.compile let data ← toIndex layoutMap bindings data modify fun stt => { stt with ops := stt.ops.push (.ioWrite data) } ret.compile returnTyp layoutMap bindings yieldCtrl - | .retGroup _ _ _ inner => - inner.compile returnTyp layoutMap bindings yieldCtrl + | .retGroup _ _ name inner => do + let oldGroup := (← 
get).currentReturnGroup + modify fun s => { s with currentReturnGroup := name } + let blk ← inner.compile returnTyp layoutMap bindings yieldCtrl + modify fun s => { s with currentReturnGroup := oldGroup } + pure blk | .match _ _ scrut cases defaultOpt => do let idxs := bindings[scrut]?.getD #[0] let ops ← extractOps @@ -468,7 +473,7 @@ def Concrete.Term.compile set state let ops := state.ops let id := state.selIdx - pure ({ ops, ctrl := .return (id - 1) idxs } : Bytecode.Block) + pure ({ ops, ctrl := .return (id - 1) state.currentReturnGroup idxs } : Bytecode.Block) | _ => do let idxs ← toIndex layoutMap bindings term let state ← get @@ -477,7 +482,8 @@ def Concrete.Term.compile let ops := state.ops let id := state.selIdx let ctrl : Bytecode.Ctrl := - if yieldCtrl && !term.escapes then .yield (id - 1) idxs else .return (id - 1) idxs + if yieldCtrl && !term.escapes then .yield (id - 1) idxs + else .return (id - 1) state.currentReturnGroup idxs pure ({ ops, ctrl } : Bytecode.Block) termination_by (sizeOf term, 0) decreasing_by @@ -545,7 +551,8 @@ def Concrete.Function.compile (layoutMap : LayoutMap) (f : Concrete.Function) : | .ok len => pure len let indices := Array.range' valIdx len pure (valIdx + len, bindings.insert arg indices) - let state := { valIdx, selIdx := 0, ops := #[], degrees := Array.replicate valIdx 1 } + let state := { valIdx, selIdx := 0, ops := #[], degrees := Array.replicate valIdx 1, + currentReturnGroup := "" } match f.body.compile f.output layoutMap bindings |>.run state with | .error e _ => throw e | .ok body _ => diff --git a/Ix/Aiur/Semantics/BytecodeEval.lean b/Ix/Aiur/Semantics/BytecodeEval.lean index a45c377e..2db2cae5 100644 --- a/Ix/Aiur/Semantics/BytecodeEval.lean +++ b/Ix/Aiur/Semantics/BytecodeEval.lean @@ -301,7 +301,7 @@ decreasing_by def evalCtrl (t : Bytecode.Toplevel) (fuel : Nat) (ctrl : Ctrl) (st : EvalState) : Except BytecodeError (Array G × EvalState) := match ctrl with - | .return _ outs => + | .return _ _ outs => match 
readIdxs st outs with | .error e => .error e | .ok gs => .ok (gs, st) diff --git a/Ix/Aiur/Stages/Bytecode.lean b/Ix/Aiur/Stages/Bytecode.lean index 527e5431..be8aeeaf 100644 --- a/Ix/Aiur/Stages/Bytecode.lean +++ b/Ix/Aiur/Stages/Bytecode.lean @@ -49,7 +49,7 @@ inductive Op mutual inductive Ctrl where | match : ValIdx → Array (G × Block) → Option Block → Ctrl - | return : SelIdx → Array ValIdx → Ctrl + | return : SelIdx → (group : String) → Array ValIdx → Ctrl | yield : SelIdx → Array ValIdx → Ctrl | matchContinue : ValIdx → Array (G × Block) → Option Block → (outputSize : Nat) → (sharedAuxiliaries : Nat) → (sharedLookups : Nat) diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index e1f0df73..b98dfb40 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -169,7 +169,7 @@ fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { Ctrl::Match(val_idx, cases, default) }, 1 => { - let [sel_idx_obj, val_idxs_obj] = ctor.objs::<2>(); + let [sel_idx_obj, _group_obj, val_idxs_obj] = ctor.objs::<3>(); let sel_idx = lean_unbox_nat_as_usize(&sel_idx_obj); let val_idxs = decode_vec_val_idx(val_idxs_obj); Ctrl::Return(sel_idx, val_idxs) From a2a7b7fd81548f69720369f62bfb58b85d1157ad Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Tue, 19 May 2026 09:24:16 -0300 Subject: [PATCH 04/13] Added return groups to `bytecode.rs` --- src/aiur/bytecode.rs | 2 +- src/aiur/constraints.rs | 4 ++-- src/aiur/execute.rs | 2 +- src/aiur/trace.rs | 2 +- src/ffi/aiur/toplevel.rs | 5 +++-- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 10e27cf7..b75b5372 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -63,7 +63,7 @@ pub enum Op { pub enum Ctrl { Match(ValIdx, FxIndexMap, Option>), - Return(SelIdx, Vec), + Return(SelIdx, String, Vec), Yield(SelIdx, Vec), MatchContinue( ValIdx, diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index 4f0d6c68..aec608fb 100644 --- 
a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -172,7 +172,7 @@ impl Block { /// be double-counted. fn get_block_selector(&self, state: &ConstraintState) -> Expr { match &self.ctrl { - Ctrl::Return(sel, _) | Ctrl::Yield(sel, _) => { + Ctrl::Return(sel, _, _) | Ctrl::Yield(sel, _) => { var(state.selector_index(*sel)) }, Ctrl::Match(_, cases, def) | Ctrl::MatchContinue(_, cases, def, ..) => { @@ -234,7 +234,7 @@ impl Ctrl { #[allow(clippy::needless_pass_by_value)] fn collect_constraints(&self, sel: Expr, state: &mut ConstraintState) { match self { - Ctrl::Return(_, values) => { + Ctrl::Return(_, _, values) => { // channel and function index let mut args = vec![ sel.clone() * function_channel(), diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index 9e607ebc..6527be2b 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -480,7 +480,7 @@ impl Function { map.extend(yielded); push_block_exec_entries!(cont.block); }, - ExecEntry::Ctrl(Ctrl::Return(_, output)) => { + ExecEntry::Ctrl(Ctrl::Return(_, _, output)) => { // Register the query. 
let input_size = toplevel.functions[fun_idx].layout.input_size; let args = map[..input_size].to_vec(); diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index 02d5a928..6b556f1f 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -203,7 +203,7 @@ impl Ctrl { io_buffer: &IOBuffer, ) -> PopulateResult { match self { - Ctrl::Return(sel, _) => { + Ctrl::Return(sel, _, _) => { slice.selectors[*sel] = G::ONE; let lookup = function_lookup( -context.multiplicity, diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index b98dfb40..f01098c3 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -169,10 +169,11 @@ fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { Ctrl::Match(val_idx, cases, default) }, 1 => { - let [sel_idx_obj, _group_obj, val_idxs_obj] = ctor.objs::<3>(); + let [sel_idx_obj, group_obj, val_idxs_obj] = ctor.objs::<3>(); let sel_idx = lean_unbox_nat_as_usize(&sel_idx_obj); + let group = group_obj.as_string().to_string(); let val_idxs = decode_vec_val_idx(val_idxs_obj); - Ctrl::Return(sel_idx, val_idxs) + Ctrl::Return(sel_idx, group, val_idxs) }, 2 => { let [sel_idx_obj, val_idxs_obj] = ctor.objs::<2>(); From 5f7559abc634a440604c16399b3075ae06fac8d1 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Tue, 19 May 2026 09:34:18 -0300 Subject: [PATCH 05/13] Implemented `Function.split` which collects all filtered sub-functions for each particular return group --- src/aiur.rs | 1 + src/aiur/bytecode.rs | 3 + src/aiur/split.rs | 138 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+) create mode 100644 src/aiur/split.rs diff --git a/src/aiur.rs b/src/aiur.rs index 2e904331..3ec67f78 100644 --- a/src/aiur.rs +++ b/src/aiur.rs @@ -3,6 +3,7 @@ pub mod constraints; pub mod execute; pub mod gadgets; pub mod memory; +pub mod split; pub mod synthesis; pub mod trace; diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index b75b5372..fbcc3c0c 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -28,11 
+28,13 @@ impl FunctionLayout { } } +#[derive(Clone)] pub struct Block { pub(crate) ops: Vec, pub(crate) ctrl: Ctrl, } +#[derive(Clone)] pub enum Op { Const(G), Add(ValIdx, ValIdx), @@ -61,6 +63,7 @@ pub enum Op { Debug(String, Option>), } +#[derive(Clone)] pub enum Ctrl { Match(ValIdx, FxIndexMap, Option>), Return(SelIdx, String, Vec), diff --git a/src/aiur/split.rs b/src/aiur/split.rs new file mode 100644 index 00000000..574f737a --- /dev/null +++ b/src/aiur/split.rs @@ -0,0 +1,138 @@ +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::FxIndexMap; + +use super::bytecode::{Block, Ctrl, Function}; + +impl Function { + pub fn return_groups(&self) -> FxHashSet { + let mut groups = FxHashSet::default(); + collect_block(&self.body, &mut groups); + groups + } + + /// Build a `Function` containing only the control-flow paths that can reach a + /// `Ctrl::Return` whose group matches `target`. Returns `None` when no such + /// path exists. `Yield` branches are preserved unconditionally — their + /// reachability is decided by the surrounding `MatchContinue`'s continuation. 
+ pub fn filter_group(&self, target: &str) -> Option { + let body = filter_block(&self.body, target)?; + Some(Function { + body, + layout: self.layout, + entry: self.entry, + constrained: self.constrained, + }) + } + + pub fn split(&self) -> FxHashMap { + self + .return_groups() + .into_iter() + .map(|group| { + let filtered = self.filter_group(&group).unwrap_or_else(|| { + panic!("function contains an unreachable group: {group}") + }); + (group, filtered) + }) + .collect() + } +} + +fn filter_block(block: &Block, target: &str) -> Option { + let ctrl = filter_ctrl(&block.ctrl, target)?; + Some(Block { ops: block.ops.clone(), ctrl }) +} + +fn filter_ctrl(ctrl: &Ctrl, target: &str) -> Option { + match ctrl { + Ctrl::Return(sel, group, vs) => { + if group == target { + Some(Ctrl::Return(*sel, group.clone(), vs.clone())) + } else { + None + } + }, + Ctrl::Yield(sel, vs) => Some(Ctrl::Yield(*sel, vs.clone())), + Ctrl::Match(scrut, cases, default) => { + let new_cases = filter_cases(cases, target); + let new_default = + default.as_ref().and_then(|b| filter_block(b, target).map(Box::new)); + if new_cases.is_empty() && new_default.is_none() { + None + } else { + Some(Ctrl::Match(*scrut, new_cases, new_default)) + } + }, + Ctrl::MatchContinue( + scrut, + cases, + default, + output_size, + shared_aux, + shared_lookups, + cont, + ) => { + let new_cont = filter_block(cont, target)?; + let new_cases = filter_cases(cases, target); + let new_default = + default.as_ref().and_then(|b| filter_block(b, target).map(Box::new)); + if new_cases.is_empty() && new_default.is_none() { + None + } else { + Some(Ctrl::MatchContinue( + *scrut, + new_cases, + new_default, + *output_size, + *shared_aux, + *shared_lookups, + Box::new(new_cont), + )) + } + }, + } +} + +fn filter_cases( + cases: &FxIndexMap, + target: &str, +) -> FxIndexMap { + let mut new_cases = FxIndexMap::default(); + for (&k, blk) in cases { + if let Some(b) = filter_block(blk, target) { + new_cases.insert(k, b); + } + } + 
new_cases +} + +fn collect_block(block: &Block, groups: &mut FxHashSet) { + collect_ctrl(&block.ctrl, groups); +} + +fn collect_ctrl(ctrl: &Ctrl, groups: &mut FxHashSet) { + match ctrl { + Ctrl::Return(_, group, _) => { + groups.insert(group.clone()); + }, + Ctrl::Yield(..) => {}, + Ctrl::Match(_, cases, default) => { + for branch in cases.values() { + collect_block(branch, groups); + } + if let Some(branch) = default { + collect_block(branch, groups); + } + }, + Ctrl::MatchContinue(_, cases, default, _, _, _, cont) => { + for branch in cases.values() { + collect_block(branch, groups); + } + if let Some(branch) = default { + collect_block(branch, groups); + } + collect_block(cont, groups); + }, + } +} From e2528b6a6ee07fbb3c85ae63b641736666388552 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Tue, 19 May 2026 11:07:11 -0300 Subject: [PATCH 06/13] Layout fix on split functions --- src/aiur.rs | 1 + src/aiur/bytecode.rs | 4 + src/aiur/layout.rs | 247 +++++++++++++++++++++++++++++++++++++++ src/aiur/split.rs | 43 ++++++- src/ffi/aiur/toplevel.rs | 3 +- 5 files changed, 296 insertions(+), 2 deletions(-) create mode 100644 src/aiur/layout.rs diff --git a/src/aiur.rs b/src/aiur.rs index 3ec67f78..bf6005d4 100644 --- a/src/aiur.rs +++ b/src/aiur.rs @@ -2,6 +2,7 @@ pub mod bytecode; pub mod constraints; pub mod execute; pub mod gadgets; +pub mod layout; pub mod memory; pub mod split; pub mod synthesis; diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index fbcc3c0c..9306edff 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -1,9 +1,13 @@ +use rustc_hash::FxHashMap; + use crate::FxIndexMap; use super::G; pub struct Toplevel { pub(crate) functions: Vec, + #[allow(dead_code)] + pub(crate) filtered_functions: Vec>, pub(crate) memory_sizes: Vec, } diff --git a/src/aiur/layout.rs b/src/aiur/layout.rs new file mode 100644 index 00000000..25b73298 --- /dev/null +++ b/src/aiur/layout.rs @@ -0,0 +1,247 @@ +use std::collections::BTreeSet; + +use 
super::bytecode::{Block, Ctrl, Function, FunctionLayout, Op, ValIdx}; + +/// Compute the `FunctionLayout` (input_size / selectors / auxiliaries / +/// lookups) of a `Function` by walking its bytecode. Port of +/// `Ix/Aiur/Compiler/Layout.lean::blockLayout`. +pub fn compute_layout(function: &Function) -> FunctionLayout { + let input_size = function.layout.input_size; + let mut state = LayoutState::new(input_size); + state.block_layout(&function.body); + // Reserve one lookup slot for the function's return lookup (mirrors the + // `+ 1` added after `blockLayout` in `Concrete.Function.compile`). + state.function_layout.lookups += 1; + state.function_layout +} + +#[derive(Clone, Copy)] +struct SharedData { + auxiliaries: usize, + lookups: usize, +} + +impl SharedData { + fn maximals(self, other: Self) -> Self { + SharedData { + auxiliaries: self.auxiliaries.max(other.auxiliaries), + lookups: self.lookups.max(other.lookups), + } + } +} + +struct LayoutState { + function_layout: FunctionLayout, + #[allow(dead_code)] + mem_sizes: BTreeSet, + degrees: Vec, +} + +impl LayoutState { + fn new(input_size: usize) -> Self { + LayoutState { + function_layout: FunctionLayout { + input_size, + selectors: 0, + auxiliaries: 1, + lookups: 0, + }, + mem_sizes: BTreeSet::new(), + degrees: vec![1; input_size], + } + } + + fn bump_selectors(&mut self, n: usize) { + self.function_layout.selectors += n; + } + + fn bump_lookups(&mut self, n: usize) { + self.function_layout.lookups += n; + } + + fn bump_auxiliaries(&mut self, n: usize) { + self.function_layout.auxiliaries += n; + } + + fn add_mem_size(&mut self, size: usize) { + self.mem_sizes.insert(size); + } + + fn push_degree(&mut self, d: usize) { + self.degrees.push(d); + } + + fn push_degrees(&mut self, ds: &[usize]) { + self.degrees.extend_from_slice(ds); + } + + fn get_degree(&self, v: ValIdx) -> usize { + self.degrees.get(v).copied().unwrap_or(0) + } + + fn get_shared(&self) -> SharedData { + SharedData { + auxiliaries: 
self.function_layout.auxiliaries, + lookups: self.function_layout.lookups, + } + } + + fn set_shared(&mut self, s: SharedData) { + self.function_layout.auxiliaries = s.auxiliaries; + self.function_layout.lookups = s.lookups; + } + + fn op_layout(&mut self, op: &Op) { + match op { + Op::Const(_) => self.push_degree(0), + Op::Add(a, b) | Op::Sub(a, b) => { + let d = self.get_degree(*a).max(self.get_degree(*b)); + self.push_degree(d); + }, + Op::Mul(a, b) => { + let d = self.get_degree(*a) + self.get_degree(*b); + if d < 2 { + self.push_degree(d); + } else { + self.push_degree(1); + self.bump_auxiliaries(1); + } + }, + Op::EqZero(a) => { + let d = self.get_degree(*a); + if d == 0 { + self.push_degree(0); + } else { + self.push_degree(1); + self.bump_auxiliaries(2); + } + }, + Op::Call(_, _, output_size, unconstrained) => { + let ones = vec![1usize; *output_size]; + self.push_degrees(&ones); + self.bump_auxiliaries(*output_size); + if !*unconstrained { + self.bump_lookups(1); + } + }, + Op::Store(values) => { + self.push_degree(1); + self.bump_auxiliaries(1); + self.bump_lookups(1); + self.add_mem_size(values.len()); + }, + Op::Load(size, _) => { + let ones = vec![1usize; *size]; + self.push_degrees(&ones); + self.bump_auxiliaries(*size); + self.bump_lookups(1); + self.add_mem_size(*size); + }, + Op::AssertEq(..) + | Op::IOSetInfo(..) + | Op::IOWrite(..) + | Op::Debug(..) => {}, + Op::IOGetInfo(_) => { + self.push_degrees(&[1, 1]); + self.bump_auxiliaries(2); + }, + Op::IORead(_, len) => { + let ones = vec![1usize; *len]; + self.push_degrees(&ones); + self.bump_auxiliaries(*len); + }, + Op::U8BitDecomposition(_) => { + self.push_degrees(&[1; 8]); + self.bump_auxiliaries(8); + self.bump_lookups(1); + }, + Op::U8ShiftLeft(_) + | Op::U8ShiftRight(_) + | Op::U8Xor(..) + | Op::U8And(..) + | Op::U8Or(..) + | Op::U8LessThan(..) => { + self.push_degree(1); + self.bump_auxiliaries(1); + self.bump_lookups(1); + }, + Op::U8Add(..) | Op::U8Mul(..) | Op::U8Sub(..) 
=> { + self.push_degrees(&[1, 1]); + self.bump_auxiliaries(2); + self.bump_lookups(1); + }, + Op::U32LessThan(..) => { + self.push_degree(1); + self.bump_auxiliaries(12); + self.bump_lookups(6); + }, + } + } + + fn block_layout(&mut self, block: &Block) { + for op in &block.ops { + self.op_layout(op); + } + self.ctrl_layout(&block.ctrl); + } + + fn ctrl_layout(&mut self, ctrl: &Ctrl) { + match ctrl { + Ctrl::Return(..) | Ctrl::Yield(..) => self.bump_selectors(1), + Ctrl::Match(_, branches, default) => { + let init_shared = self.get_shared(); + let degrees_save = self.degrees.clone(); + let mut max_shared = init_shared; + for branch in branches.values() { + self.set_shared(init_shared); + self.block_layout(branch); + let branch_shared = self.get_shared(); + self.degrees = degrees_save.clone(); + max_shared = max_shared.maximals(branch_shared); + } + if let Some(default_block) = default { + self.set_shared(init_shared); + self.bump_auxiliaries(branches.len()); + self.block_layout(default_block); + let default_shared = self.get_shared(); + self.degrees = degrees_save.clone(); + max_shared = max_shared.maximals(default_shared); + } + self.set_shared(max_shared); + }, + Ctrl::MatchContinue( + _, + branches, + default, + output_size, + _shared_aux, + _shared_lookups, + continuation, + ) => { + let init_shared = self.get_shared(); + let degrees_save = self.degrees.clone(); + let mut max_shared = init_shared; + for branch in branches.values() { + self.set_shared(init_shared); + self.block_layout(branch); + let branch_shared = self.get_shared(); + self.degrees = degrees_save.clone(); + max_shared = max_shared.maximals(branch_shared); + } + if let Some(default_block) = default { + self.set_shared(init_shared); + self.bump_auxiliaries(branches.len()); + self.block_layout(default_block); + let default_shared = self.get_shared(); + self.degrees = degrees_save.clone(); + max_shared = max_shared.maximals(default_shared); + } + self.set_shared(max_shared); + 
self.bump_auxiliaries(*output_size); + let ones = vec![1usize; *output_size]; + self.push_degrees(&ones); + self.block_layout(continuation); + }, + } + } +} diff --git a/src/aiur/split.rs b/src/aiur/split.rs index 574f737a..f58e79dc 100644 --- a/src/aiur/split.rs +++ b/src/aiur/split.rs @@ -3,6 +3,7 @@ use rustc_hash::{FxHashMap, FxHashSet}; use crate::FxIndexMap; use super::bytecode::{Block, Ctrl, Function}; +use super::layout::compute_layout; impl Function { pub fn return_groups(&self) -> FxHashSet { @@ -25,6 +26,16 @@ impl Function { }) } + /// Renumber selectors in traversal order and recompute the circuit layout. + /// Intended for functions produced by `filter_group`, whose selector indices + /// and layout are inherited from the original (pre-filter) function. + fn fix(mut self) -> Self { + let mut counter: usize = 0; + fix_block_sel(&mut self.body, &mut counter); + self.layout = compute_layout(&self); + self + } + pub fn split(&self) -> FxHashMap { self .return_groups() @@ -33,7 +44,7 @@ impl Function { let filtered = self.filter_group(&group).unwrap_or_else(|| { panic!("function contains an unreachable group: {group}") }); - (group, filtered) + (group, filtered.fix()) }) .collect() } @@ -94,6 +105,36 @@ fn filter_ctrl(ctrl: &Ctrl, target: &str) -> Option { } } +fn fix_block_sel(block: &mut Block, counter: &mut usize) { + fix_ctrl_sel(&mut block.ctrl, counter); +} + +fn fix_ctrl_sel(ctrl: &mut Ctrl, counter: &mut usize) { + match ctrl { + Ctrl::Return(sel, _, _) | Ctrl::Yield(sel, _) => { + *sel = *counter; + *counter += 1; + }, + Ctrl::Match(_, cases, default) => { + for branch in cases.values_mut() { + fix_block_sel(branch, counter); + } + if let Some(branch) = default { + fix_block_sel(branch, counter); + } + }, + Ctrl::MatchContinue(_, cases, default, _, _, _, cont) => { + for branch in cases.values_mut() { + fix_block_sel(branch, counter); + } + if let Some(branch) = default { + fix_block_sel(branch, counter); + } + fix_block_sel(cont, counter); + }, + 
} +} + fn filter_cases( cases: &FxIndexMap, target: &str, diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index f01098c3..945537a2 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -255,5 +255,6 @@ pub(crate) fn decode_toplevel( functions_obj.as_array().map(|o| decode_function(o.as_ctor())); let memory_sizes = memory_sizes_obj.as_array().map(|x| lean_unbox_nat_as_usize(&x)); - Toplevel { functions, memory_sizes } + let filtered_functions = functions.iter().map(|f| f.split()).collect(); + Toplevel { functions, memory_sizes, filtered_functions } } From 3fbc8219801c2cf8a30ca8d956491c1ade2cb302 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Tue, 19 May 2026 17:38:47 -0300 Subject: [PATCH 07/13] Per group circuits --- Ix/Aiur/Semantics/BytecodeFfi.lean | 25 +++++-- Ix/Aiur/Statistics.lean | 27 +++++--- Kernel.lean | 2 +- src/aiur/constraints.rs | 5 +- src/aiur/execute.rs | 5 +- src/aiur/synthesis.rs | 53 ++++++++++----- src/aiur/trace.rs | 9 ++- src/ffi/aiur/protocol.rs | 104 ++++++++++++++++++++++------- 8 files changed, 166 insertions(+), 64 deletions(-) diff --git a/Ix/Aiur/Semantics/BytecodeFfi.lean b/Ix/Aiur/Semantics/BytecodeFfi.lean index 08a99894..b0a6bed3 100644 --- a/Ix/Aiur/Semantics/BytecodeFfi.lean +++ b/Ix/Aiur/Semantics/BytecodeFfi.lean @@ -61,27 +61,38 @@ structure QueryCount where namespace Bytecode.Toplevel +/-- Per-function execution stats. One entry per split (return group), sorted +by group name. Each quadruple is `(group, totalWidth, uniqueRows, totalHits)`. -/ +abbrev FunctionStats := Array (Array (String × Nat × Nat × Nat)) + +/-- Per-memory-size `(uniqueRows, totalHits)` pairs. -/ +abbrev MemoryCounts := Array (Nat × Nat) + +/-- Query counts shipped back from the Rust executor: per-function split stats +plus per-memory pairs. 
-/ +abbrev QueryCounts := FunctionStats × MemoryCounts + @[extern "rs_aiur_toplevel_execute"] private opaque execute' : @& Bytecode.Toplevel → @& Bytecode.FunIdx → @& Array G → (ioData : @& Array G) → (ioMap : @& Array (Array G × IOKeyInfo)) → - Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array (Nat × Nat)) + Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × QueryCounts) /-- Executes the bytecode function `funIdx` with the given `args` and `ioBuffer`, -returning the raw output of the function, the updated `IOBuffer`, and an array -of per-circuit `QueryCount`s. Returns `Except.error msg` when execution -fails (e.g. `assert_eq!` mismatch from a typechecker rejecting a constant), so -callers can recover instead of crashing. -/ +returning the raw output of the function, the updated `IOBuffer`, and a +`QueryCounts` (per-function split stats + per-memory `(uniqueRows, totalHits)` +pairs). Returns `Except.error msg` when execution fails (e.g. `assert_eq!` +mismatch from a typechecker rejecting a constant), so callers can recover +instead of crashing. 
-/ def execute (toplevel : @& Bytecode.Toplevel) (funIdx : @& Bytecode.FunIdx) (args : @& Array G) (ioBuffer : IOBuffer) : - Except String (Array G × IOBuffer × Array QueryCount) := + Except String (Array G × IOBuffer × QueryCounts) := let ioData := ioBuffer.data let ioMap := ioBuffer.map match execute' toplevel funIdx args ioData ioMap.toArray with | .error e => .error e | .ok (output, (ioData, ioMap), queryCounts) => let ioMap := ioMap.foldl (fun acc (k, v) => acc.insert k v) ∅ - let queryCounts := queryCounts.map fun (uniqueRows, totalHits) => { uniqueRows, totalHits } .ok (output, ⟨ioData, ioMap⟩, queryCounts) end Bytecode.Toplevel diff --git a/Ix/Aiur/Statistics.lean b/Ix/Aiur/Statistics.lean index ddcf082b..4b85307f 100644 --- a/Ix/Aiur/Statistics.lean +++ b/Ix/Aiur/Statistics.lean @@ -39,8 +39,9 @@ def fftCost (w h : Nat) : Float := let hf := h.toFloat wf * hf * (max hf 2.0).log2 -def computeStats (compiled : CompiledToplevel) (queryCounts : Array QueryCount) : - ExecutionStats := +def computeStats (compiled : CompiledToplevel) + (functionStats : Array (Array (String × Nat × Nat × Nat))) + (memoryCounts : Array (Nat × Nat)) : ExecutionStats := let t := compiled.bytecode -- Invert nameMap to get FunIdx → String let reverseMap := compiled.nameMap.fold (init := (∅ : Std.HashMap Bytecode.FunIdx String)) @@ -50,18 +51,22 @@ def computeStats (compiled : CompiledToplevel) (queryCounts : Array QueryCount) let mut acc := #[] for i in [:nAllFuns] do if t.functions[i]!.constrained then - let w := t.functions[i]!.layout.totalWidth - let qc := queryCounts[i]! - let h := qc.uniqueRows - let hits := qc.totalHits - qc.uniqueRows - let name := reverseMap[i]?.getD s!"" - acc := acc.push { name, width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } + let baseName := reverseMap[i]?.getD s!"" + for quad in functionStats[i]! 
do + let group := quad.1 + let w := quad.2.1 + let h := quad.2.2.1 + let totalHits := quad.2.2.2 + let hits := totalHits - h + let name := if group.isEmpty then baseName else s!"{baseName} [{group}]" + acc := acc.push { name, width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } acc let memoryCircuits := t.memorySizes.mapIdx fun i size => let w := size + 11 - let qc := queryCounts[nAllFuns + i]! - let h := qc.uniqueRows - let hits := qc.totalHits - qc.uniqueRows + let pair := memoryCounts[i]! + let h := pair.1 + let totalHits := pair.2 + let hits := totalHits - h { name := s!"memory[{size}]", width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } let circuits := (functionCircuits ++ memoryCircuits).qsort (·.fftCost > ·.fftCost) diff --git a/Kernel.lean b/Kernel.lean index 54fe8019..3b7634c3 100644 --- a/Kernel.lean +++ b/Kernel.lean @@ -59,7 +59,7 @@ where if ioBuffer != testCase.expectedIOBuffer then IO.eprintln s!"{name}: IOBuffer mismatch" return 1 - let stats := Aiur.computeStats compiled queryCounts + let stats := Aiur.computeStats compiled queryCounts.1 queryCounts.2 Aiur.printStats stats pure 0 interpCheck decls name env : IO UInt32 := do diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index aec608fb..6e18acd7 100644 --- a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -110,8 +110,11 @@ impl Toplevel { pub fn build_constraints( &self, function_index: usize, + group: &str, ) -> (Constraints, Vec>) { - let function = &self.functions[function_index]; + let function = self.filtered_functions[function_index] + .get(group) + .expect("Missing filtered function for group"); let constraints = Constraints { zeros: vec![], selectors: 0..0, diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index 6527be2b..df9447a7 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -18,6 +18,7 @@ use crate::{ pub struct QueryResult { pub(crate) output: Vec, pub(crate) multiplicity: G, + 
pub(crate) return_group: String, } pub type QueryMap = FxIndexMap, QueryResult>; @@ -256,6 +257,7 @@ impl Function { let result = QueryResult { output: vec![ptr], multiplicity: G::from_bool(!unconstrained), + return_group: String::new(), }; memory_queries.insert(values, result); map.push(ptr); @@ -480,7 +482,7 @@ impl Function { map.extend(yielded); push_block_exec_entries!(cont.block); }, - ExecEntry::Ctrl(Ctrl::Return(_, _, output)) => { + ExecEntry::Ctrl(Ctrl::Return(_, group, output)) => { // Register the query. let input_size = toplevel.functions[fun_idx].layout.input_size; let args = map[..input_size].to_vec(); @@ -488,6 +490,7 @@ impl Function { let result = QueryResult { output: output.clone(), multiplicity: G::from_bool(!unconstrained), + return_group: group.clone(), }; record.function_queries[fun_idx].insert(args, result); if let Some(CallerState { diff --git a/src/aiur/synthesis.rs b/src/aiur/synthesis.rs index 1dd7f971..fe13a4d4 100644 --- a/src/aiur/synthesis.rs +++ b/src/aiur/synthesis.rs @@ -71,7 +71,7 @@ where } enum CircuitType { - Function { idx: usize }, + Function { idx: usize, group: String }, Memory { width: usize }, Bytes1, Bytes2, @@ -82,14 +82,23 @@ impl AiurSystem { toplevel: Toplevel, commitment_parameters: CommitmentParameters, ) -> Self { - let function_circuits = (0..toplevel.functions.len()).filter_map(|i| { - if !toplevel.functions[i].constrained { - None - } else { - let (constraints, lookups) = toplevel.build_constraints(i); - Some(LookupAir::new(AiurCircuit::Function(constraints), lookups)) - } - }); + let toplevel_ref = &toplevel; + let function_circuits = + (0..toplevel_ref.functions.len()).flat_map(move |i| { + let groups: Vec = if toplevel_ref.functions[i].constrained { + let mut gs: Vec = + toplevel_ref.filtered_functions[i].keys().cloned().collect(); + gs.sort(); + gs + } else { + vec![] + }; + groups.into_iter().map(move |group| { + let (constraints, lookups) = + toplevel_ref.build_constraints(i, &group); + 
LookupAir::new(AiurCircuit::Function(constraints), lookups) + }) + }); let memory_circuits = toplevel.memory_sizes.iter().map(|&width| { let (memory, lookups) = Memory::build(width); LookupAir::new(AiurCircuit::Memory(memory), lookups) @@ -131,14 +140,22 @@ impl AiurSystem { // Build the `SystemWitness` let _g = tracing::info_span!("aiur/witness").entered(); - let functions = - (0..self.toplevel.functions.len()).into_par_iter().filter_map(|idx| { - if self.toplevel.functions[idx].constrained { - Some(CircuitType::Function { idx }) + let functions: Vec = (0..self.toplevel.functions.len()) + .flat_map(|idx| { + let groups: Vec = if self.toplevel.functions[idx].constrained { + let mut gs: Vec = + self.toplevel.filtered_functions[idx].keys().cloned().collect(); + gs.sort(); + gs } else { - None - } - }); + vec![] + }; + groups + .into_iter() + .map(move |group| CircuitType::Function { idx, group }) + }) + .collect(); + let functions = functions.into_par_iter(); let memories = self .toplevel .memory_sizes @@ -149,8 +166,8 @@ impl AiurSystem { .chain(memories) .chain(gadgets) .map(|circuit_type| match circuit_type { - CircuitType::Function { idx } => { - self.toplevel.witness_data(idx, &query_record, io_buffer) + CircuitType::Function { idx, group } => { + self.toplevel.witness_data(idx, &group, &query_record, io_buffer) }, CircuitType::Memory { width } => { Memory::witness_data(width, &query_record) diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index 6b556f1f..c9689b49 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -77,15 +77,20 @@ impl Toplevel { pub fn witness_data( &self, function_index: usize, + group: &str, query_record: &QueryRecord, io_buffer: &IOBuffer, ) -> (RowMajorMatrix, Vec>>) { - let func = &self.functions[function_index]; + let func = self.filtered_functions[function_index] + .get(group) + .expect("Missing filtered function for group"); let width = func.width(); let unfiltered_queries = &query_record.function_queries[function_index]; let 
queries = unfiltered_queries .iter() - .filter(|(_, res)| !res.multiplicity.is_zero()) + .filter(|(_, res)| { + !res.multiplicity.is_zero() && res.return_group == group + }) .collect::>(); let height_no_padding = queries.len(); let height = height_no_padding.next_power_of_two(); diff --git a/src/ffi/aiur/protocol.rs b/src/ffi/aiur/protocol.rs index 772b18ef..35bb950c 100644 --- a/src/ffi/aiur/protocol.rs +++ b/src/ffi/aiur/protocol.rs @@ -8,7 +8,7 @@ use std::sync::LazyLock; use lean_ffi::object::{ ExternalClass, LeanArray, LeanBorrowed, LeanByteArray, LeanExcept, - LeanExternal, LeanNat, LeanOwned, LeanProd, LeanRef, + LeanExternal, LeanNat, LeanOwned, LeanProd, LeanRef, LeanString, }; use crate::{ @@ -87,9 +87,11 @@ extern "C" fn rs_aiur_system_verify( } /// `Bytecode.Toplevel.execute`: runs execution only (no proof) and returns -/// `Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array (Nat × Nat))`. -/// The trailing `Array (Nat × Nat)` is one `(uniqueRows, totalHits)` pair per -/// function circuit followed by one per memory size. +/// `Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) +/// × (Array (Array (String × Nat × Nat × Nat)) × Array (Nat × Nat)))`. +/// The function side is per-function array of per-split +/// `(group, width, uniqueRows, totalHits)` quadruples. The memory side is +/// per-memory-size `(uniqueRows, totalHits)` pairs. /// On execution failure (e.g. assertion mismatch from a typechecker /// rejecting a constant), returns `Except.error msg` instead of panicking /// — letting Lean test runners (`KernelArena.lean`) classify failures. @@ -114,14 +116,11 @@ extern "C" fn rs_aiur_toplevel_execute( Err(err) => return LeanExcept::error_string(&err.to_string()), }; - // Build per-circuit (unique_rows, total_hits) pairs: - // one per function, then one per memory size. 
`unique_rows` is the trace - // height (number of distinct queries); `total_hits` is the sum of - // multiplicities (how often those rows were hit). - let mut query_counts: Vec<(usize, usize)> = Vec::with_capacity( - query_record.function_queries.len() + toplevel.memory_sizes.len(), - ); - let summarize = |q: &crate::aiur::execute::QueryMap| -> (usize, usize) { + // Summarize a query map into `(unique_rows, total_hits)`. `unique_rows` is + // the trace height (number of distinct queries with nonzero multiplicity); + // `total_hits` is the sum of multiplicities (how often those rows were + // hit). + let summarize_pair = |q: &crate::aiur::execute::QueryMap| -> (usize, usize) { let mut rows = 0usize; let mut hits = 0usize; for (_, res) in q.iter() { @@ -134,25 +133,84 @@ extern "C" fn rs_aiur_toplevel_execute( } (rows, hits) }; - for queries in &query_record.function_queries { - query_counts.push(summarize(queries)); - } - for size in &toplevel.memory_sizes { - let pair = query_record.memory_queries.get(size).map_or((0, 0), summarize); - query_counts.push(pair); - } - let lean_query_counts = { - let arr = LeanArray::alloc(query_counts.len()); - for (i, &(rows, hits)) in query_counts.iter().enumerate() { + + // Per-function array of `(group, width, unique_rows, total_hits)` quadruples, + // one per split. Queries within a function are partitioned by return group. 
+ let function_stats: Vec> = (0..toplevel + .functions + .len()) + .map(|i| { + let queries = &query_record.function_queries[i]; + let mut stats: Vec<(String, usize, usize, usize)> = toplevel + .filtered_functions[i] + .iter() + .map(|(group, func)| { + let (rows, hits) = queries + .iter() + .filter(|(_, res)| res.return_group == *group) + .fold((0usize, 0usize), |(r, h), (_, res)| { + let m = usize::try_from(res.multiplicity.as_canonical_u64()) + .expect("multiplicity exceeds usize"); + if m != 0 { (r + 1, h + m) } else { (r, h) } + }); + let l = func.layout; + let width = l.input_size + + l.selectors + + l.auxiliaries + + 4 * (1 + l.lookups); + (group.clone(), width, rows, hits) + }) + .collect(); + stats.sort_by(|a, b| a.0.cmp(&b.0)); + stats + }) + .collect(); + + let memory_counts: Vec<(usize, usize)> = toplevel + .memory_sizes + .iter() + .map(|size| { + query_record.memory_queries.get(size).map_or((0, 0), summarize_pair) + }) + .collect(); + + let lean_function_stats = { + let outer = LeanArray::alloc(function_stats.len()); + for (i, per_fn) in function_stats.iter().enumerate() { + let inner = LeanArray::alloc(per_fn.len()); + for (j, (group, width, rows, hits)) in per_fn.iter().enumerate() { + // (String × Nat × Nat × Nat) — right-nested pair encoding + let quad = LeanProd::new( + LeanString::new(group), + LeanProd::new( + LeanOwned::box_usize(*width), + LeanProd::new( + LeanOwned::box_usize(*rows), + LeanOwned::box_usize(*hits), + ), + ), + ); + inner.set(j, quad); + } + outer.set(i, inner); + } + outer + }; + let lean_memory_counts = { + let arr = LeanArray::alloc(memory_counts.len()); + for (i, &(rows, hits)) in memory_counts.iter().enumerate() { let pair = LeanProd::new(LeanOwned::box_usize(rows), LeanOwned::box_usize(hits)); arr.set(i, pair); } arr }; + let lean_query_counts = + LeanProd::new(lean_function_stats, lean_memory_counts); let lean_io = build_lean_io_buffer(&io_buffer); - // (Array G, (Array G × Array (Array G × IOKeyInfo), Array (Nat × 
Nat))) + // (Array G, (Array G × Array (Array G × IOKeyInfo), + // Array (Array (String × Nat × Nat × Nat)) × Array (Nat × Nat))) let io_counts = LeanProd::new(lean_io, lean_query_counts); let result = LeanProd::new(build_g_array(&output), io_counts); LeanExcept::ok(result) From fefc24013ed1af2c2a297976dd9a6c440b29abb5 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Tue, 19 May 2026 17:49:37 -0300 Subject: [PATCH 08/13] `String` -> `Arc` --- src/aiur/bytecode.rs | 6 ++++-- src/aiur/execute.rs | 5 +++-- src/aiur/split.rs | 12 +++++++----- src/aiur/synthesis.rs | 13 ++++++++----- src/aiur/trace.rs | 2 +- src/ffi/aiur/protocol.rs | 17 +++++++---------- src/ffi/aiur/toplevel.rs | 4 +++- 7 files changed, 33 insertions(+), 26 deletions(-) diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 9306edff..762a3c2a 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use rustc_hash::FxHashMap; use crate::FxIndexMap; @@ -7,7 +9,7 @@ use super::G; pub struct Toplevel { pub(crate) functions: Vec, #[allow(dead_code)] - pub(crate) filtered_functions: Vec>, + pub(crate) filtered_functions: Vec, Function>>, pub(crate) memory_sizes: Vec, } @@ -70,7 +72,7 @@ pub enum Op { #[derive(Clone)] pub enum Ctrl { Match(ValIdx, FxIndexMap, Option>), - Return(SelIdx, String, Vec), + Return(SelIdx, Arc, Vec), Yield(SelIdx, Vec), MatchContinue( ValIdx, diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index df9447a7..b2500d2e 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -1,6 +1,7 @@ use multi_stark::p3_field::{PrimeCharacteristicRing, PrimeField64}; use rustc_hash::FxHashMap; use std::collections::hash_map::Entry; +use std::sync::Arc; use crate::{ FxIndexMap, @@ -18,7 +19,7 @@ use crate::{ pub struct QueryResult { pub(crate) output: Vec, pub(crate) multiplicity: G, - pub(crate) return_group: String, + pub(crate) return_group: Arc, } pub type QueryMap = FxIndexMap, QueryResult>; @@ -257,7 +258,7 @@ impl Function { 
let result = QueryResult { output: vec![ptr], multiplicity: G::from_bool(!unconstrained), - return_group: String::new(), + return_group: Arc::from(""), }; memory_queries.insert(values, result); map.push(ptr); diff --git a/src/aiur/split.rs b/src/aiur/split.rs index f58e79dc..5cb73ca3 100644 --- a/src/aiur/split.rs +++ b/src/aiur/split.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use rustc_hash::{FxHashMap, FxHashSet}; use crate::FxIndexMap; @@ -6,7 +8,7 @@ use super::bytecode::{Block, Ctrl, Function}; use super::layout::compute_layout; impl Function { - pub fn return_groups(&self) -> FxHashSet { + pub fn return_groups(&self) -> FxHashSet> { let mut groups = FxHashSet::default(); collect_block(&self.body, &mut groups); groups @@ -36,7 +38,7 @@ impl Function { self } - pub fn split(&self) -> FxHashMap { + pub fn split(&self) -> FxHashMap, Function> { self .return_groups() .into_iter() @@ -58,7 +60,7 @@ fn filter_block(block: &Block, target: &str) -> Option { fn filter_ctrl(ctrl: &Ctrl, target: &str) -> Option { match ctrl { Ctrl::Return(sel, group, vs) => { - if group == target { + if group.as_ref() == target { Some(Ctrl::Return(*sel, group.clone(), vs.clone())) } else { None @@ -148,11 +150,11 @@ fn filter_cases( new_cases } -fn collect_block(block: &Block, groups: &mut FxHashSet) { +fn collect_block(block: &Block, groups: &mut FxHashSet>) { collect_ctrl(&block.ctrl, groups); } -fn collect_ctrl(ctrl: &Ctrl, groups: &mut FxHashSet) { +fn collect_ctrl(ctrl: &Ctrl, groups: &mut FxHashSet>) { match ctrl { Ctrl::Return(_, group, _) => { groups.insert(group.clone()); diff --git a/src/aiur/synthesis.rs b/src/aiur/synthesis.rs index fe13a4d4..f1efa8ff 100644 --- a/src/aiur/synthesis.rs +++ b/src/aiur/synthesis.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use multi_stark::{ lookup::LookupAir, p3_air::{Air, AirBuilder, BaseAir}, @@ -71,7 +73,7 @@ where } enum CircuitType { - Function { idx: usize, group: String }, + Function { idx: usize, group: Arc }, Memory { width: usize }, 
Bytes1, Bytes2, @@ -85,8 +87,8 @@ impl AiurSystem { let toplevel_ref = &toplevel; let function_circuits = (0..toplevel_ref.functions.len()).flat_map(move |i| { - let groups: Vec = if toplevel_ref.functions[i].constrained { - let mut gs: Vec = + let groups: Vec> = if toplevel_ref.functions[i].constrained { + let mut gs: Vec> = toplevel_ref.filtered_functions[i].keys().cloned().collect(); gs.sort(); gs @@ -142,8 +144,9 @@ impl AiurSystem { let _g = tracing::info_span!("aiur/witness").entered(); let functions: Vec = (0..self.toplevel.functions.len()) .flat_map(|idx| { - let groups: Vec = if self.toplevel.functions[idx].constrained { - let mut gs: Vec = + let groups: Vec> = if self.toplevel.functions[idx].constrained + { + let mut gs: Vec> = self.toplevel.filtered_functions[idx].keys().cloned().collect(); gs.sort(); gs diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index c9689b49..a1ebba36 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -89,7 +89,7 @@ impl Toplevel { let queries = unfiltered_queries .iter() .filter(|(_, res)| { - !res.multiplicity.is_zero() && res.return_group == group + !res.multiplicity.is_zero() && res.return_group.as_ref() == group }) .collect::>(); let height_no_padding = queries.len(); diff --git a/src/ffi/aiur/protocol.rs b/src/ffi/aiur/protocol.rs index 35bb950c..245b76b5 100644 --- a/src/ffi/aiur/protocol.rs +++ b/src/ffi/aiur/protocol.rs @@ -136,28 +136,25 @@ extern "C" fn rs_aiur_toplevel_execute( // Per-function array of `(group, width, unique_rows, total_hits)` quadruples, // one per split. Queries within a function are partitioned by return group. 
- let function_stats: Vec> = (0..toplevel - .functions - .len()) + let function_stats: Vec, usize, usize, usize)>> = (0 + ..toplevel.functions.len()) .map(|i| { let queries = &query_record.function_queries[i]; - let mut stats: Vec<(String, usize, usize, usize)> = toplevel + let mut stats: Vec<(std::sync::Arc, usize, usize, usize)> = toplevel .filtered_functions[i] .iter() .map(|(group, func)| { let (rows, hits) = queries .iter() - .filter(|(_, res)| res.return_group == *group) + .filter(|(_, res)| res.return_group.as_ref() == group.as_ref()) .fold((0usize, 0usize), |(r, h), (_, res)| { let m = usize::try_from(res.multiplicity.as_canonical_u64()) .expect("multiplicity exceeds usize"); if m != 0 { (r + 1, h + m) } else { (r, h) } }); let l = func.layout; - let width = l.input_size - + l.selectors - + l.auxiliaries - + 4 * (1 + l.lookups); + let width = + l.input_size + l.selectors + l.auxiliaries + 4 * (1 + l.lookups); (group.clone(), width, rows, hits) }) .collect(); @@ -181,7 +178,7 @@ extern "C" fn rs_aiur_toplevel_execute( for (j, (group, width, rows, hits)) in per_fn.iter().enumerate() { // (String × Nat × Nat × Nat) — right-nested pair encoding let quad = LeanProd::new( - LeanString::new(group), + LeanString::new(group.as_ref()), LeanProd::new( LeanOwned::box_usize(*width), LeanProd::new( diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index 945537a2..e2f6d261 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -171,7 +171,9 @@ fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { 1 => { let [sel_idx_obj, group_obj, val_idxs_obj] = ctor.objs::<3>(); let sel_idx = lean_unbox_nat_as_usize(&sel_idx_obj); - let group = group_obj.as_string().to_string(); + let group_lean = group_obj.as_string(); + let group: std::sync::Arc = + std::sync::Arc::from(group_lean.as_str()); let val_idxs = decode_vec_val_idx(val_idxs_obj); Ctrl::Return(sel_idx, group, val_idxs) }, From 8de0ae2cf5e430923bcfabe863ee561abdee2489 Mon Sep 17 00:00:00 2001 From: Gabriel 
Barreto Date: Wed, 20 May 2026 13:40:54 -0300 Subject: [PATCH 09/13] `address_eq` fast and slow paths --- Ix/IxVM/Ingress.lean | 45 +++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/Ix/IxVM/Ingress.lean b/Ix/IxVM/Ingress.lean index 9466eaa7..0bcc8868 100644 --- a/Ix/IxVM/Ingress.lean +++ b/Ix/IxVM/Ingress.lean @@ -33,25 +33,32 @@ def ingress := ⟦ -- Compare two 32-byte addresses for equality fn address_eq(a: Addr, b: Addr) -> G { - let [a0, a1, a2, a3, a4, a5, a6, a7, - a8, a9, a10, a11, a12, a13, a14, a15, - a16, a17, a18, a19, a20, a21, a22, a23, - a24, a25, a26, a27, a28, a29, a30, a31] = load(a); - let [b0, b1, b2, b3, b4, b5, b6, b7, - b8, b9, b10, b11, b12, b13, b14, b15, - b16, b17, b18, b19, b20, b21, b22, b23, - b24, b25, b26, b27, b28, b29, b30, b31] = load(b); - match [a0 - b0, a1 - b1, a2 - b2, a3 - b3, - a4 - b4, a5 - b5, a6 - b6, a7 - b7, - a8 - b8, a9 - b9, a10 - b10, a11 - b11, - a12 - b12, a13 - b13, a14 - b14, a15 - b15, - a16 - b16, a17 - b17, a18 - b18, a19 - b19, - a20 - b20, a21 - b21, a22 - b22, a23 - b23, - a24 - b24, a25 - b25, a26 - b26, a27 - b27, - a28 - b28, a29 - b29, a30 - b30, a31 - b31] { - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 1, - _ => 0, + match ptr_val(a) - ptr_val(b) { + 0 => + #[return_group(fast_path)] + 1, + _ => + #[return_group(slow_path)] + let [a0, a1, a2, a3, a4, a5, a6, a7, + a8, a9, a10, a11, a12, a13, a14, a15, + a16, a17, a18, a19, a20, a21, a22, a23, + a24, a25, a26, a27, a28, a29, a30, a31] = load(a); + let [b0, b1, b2, b3, b4, b5, b6, b7, + b8, b9, b10, b11, b12, b13, b14, b15, + b16, b17, b18, b19, b20, b21, b22, b23, + b24, b25, b26, b27, b28, b29, b30, b31] = load(b); + match [a0 - b0, a1 - b1, a2 - b2, a3 - b3, + a4 - b4, a5 - b5, a6 - b6, a7 - b7, + a8 - b8, a9 - b9, a10 - b10, a11 - b11, + a12 - b12, a13 - b13, a14 - b14, a15 - b15, + a16 - b16, a17 - b17, a18 - b18, a19 - b19, + a20 - 
b20, a21 - b21, a22 - b22, a23 - b23, + a24 - b24, a25 - b25, a26 - b26, a27 - b27, + a28 - b28, a29 - b29, a30 - b30, a31 - b31] { + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 1, + _ => 0, + }, } } From 876e06c4693a8b9a364703e81d0527e0a541630f Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Wed, 20 May 2026 06:29:34 -0700 Subject: [PATCH 10/13] Move split + layout from Rust to Lean MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lean now owns return-group splitting and layout recompute. Adds `Bytecode.Function.split` (with `returnGroups`/`filterGroup`/`fix`) and `Toplevel.computeFiltered`, populated in `Source.Toplevel.compile` after dedup + needsCircuit. Layout recompute reuses the existing `Compiler.Layout.blockLayout`. `Bytecode.Toplevel` gains `filteredFunctions : Array (Array (String × Function))`. The Rust FFI decoder reads it straight from Lean and interns group names into `Arc` shared across every `Ctrl::Return` and every `filtered_functions` key. Deletes the Rust duplicates (`src/aiur/split.rs`, `src/aiur/layout.rs`) so there is one source of truth for bytecode shape. Tests: adds `shared_return_groups` exercising two arms sharing `even`, two sharing `odd`, and one unannotated (`""`) arm, aggregated into a single entry `shared_return_groups_all` so the STARK suite proves every group circuit in one run. 
--- Ix/Aiur.lean | 1 + Ix/Aiur/Compiler.lean | 4 +- Ix/Aiur/Compiler/Lower.lean | 2 +- Ix/Aiur/Compiler/Split.lean | 216 ++++++++++++++++++++++++++++++ Ix/Aiur/Stages/Bytecode.lean | 4 + Tests/Aiur/Aiur.lean | 23 ++++ Tests/Aiur/Cross.lean | 23 +++- src/aiur.rs | 2 - src/aiur/bytecode.rs | 4 - src/aiur/layout.rs | 247 ----------------------------------- src/aiur/split.rs | 181 ------------------------- src/ffi/aiur/toplevel.rs | 116 ++++++++++++---- src/lean.rs | 2 +- 13 files changed, 365 insertions(+), 460 deletions(-) create mode 100644 Ix/Aiur/Compiler/Split.lean delete mode 100644 src/aiur/layout.rs delete mode 100644 src/aiur/split.rs diff --git a/Ix/Aiur.lean b/Ix/Aiur.lean index 67949f1e..6330b32f 100644 --- a/Ix/Aiur.lean +++ b/Ix/Aiur.lean @@ -20,5 +20,6 @@ public import Ix.Aiur.Compiler.Concretize public import Ix.Aiur.Compiler.Layout public import Ix.Aiur.Compiler.Lower public import Ix.Aiur.Compiler.Dedup +public import Ix.Aiur.Compiler.Split public import Ix.Aiur.Compiler public import Ix.Aiur.Statistics diff --git a/Ix/Aiur/Compiler.lean b/Ix/Aiur/Compiler.lean index 4ae98ad4..3a008b73 100644 --- a/Ix/Aiur/Compiler.lean +++ b/Ix/Aiur/Compiler.lean @@ -3,6 +3,7 @@ public import Ix.Aiur.Compiler.Lower public import Ix.Aiur.Compiler.Dedup public import Ix.Aiur.Compiler.Concretize public import Ix.Aiur.Compiler.Simple +public import Ix.Aiur.Compiler.Split /-! Aiur compiler pipeline: type-check, simplify, concretize, lower, deduplicate. @@ -117,9 +118,10 @@ def Source.Toplevel.compile (t : Source.Toplevel) : Except String CompiledToplev let (bytecodeRaw, preNameMap) ← concDecls.toBytecode let (bytecodeDedup, remap) := bytecodeRaw.deduplicate let needs := bytecodeDedup.needsCircuit - let bytecode := { bytecodeDedup with + let bytecodeConstrained : Bytecode.Toplevel := { bytecodeDedup with functions := bytecodeDedup.functions.mapIdx fun i f => { f with constrained := needs[i]! 
} } + let bytecode := bytecodeConstrained.computeFiltered let nameMap := preNameMap.fold (init := (∅ : Std.HashMap Global Bytecode.FunIdx)) fun acc name idx => acc.insert name (remap idx) pure (CompiledToplevel.mk t bytecode nameMap) diff --git a/Ix/Aiur/Compiler/Lower.lean b/Ix/Aiur/Compiler/Lower.lean index bad47fbe..b83bd70d 100644 --- a/Ix/Aiur/Compiler/Lower.lean +++ b/Ix/Aiur/Compiler/Lower.lean @@ -575,7 +575,7 @@ def Concrete.Decls.toBytecode (decls : Concrete.Decls) : let memSizes := layoutMState.memSizes.fold (·.insert ·) memSizes pure (functions.push function, memSizes, nameMap) | _ => pure acc - pure (⟨functions, memSizes.toArray⟩, nameMap) + pure ({ functions, memorySizes := memSizes.toArray : Bytecode.Toplevel }, nameMap) end Aiur diff --git a/Ix/Aiur/Compiler/Split.lean b/Ix/Aiur/Compiler/Split.lean new file mode 100644 index 00000000..ecdaffd4 --- /dev/null +++ b/Ix/Aiur/Compiler/Split.lean @@ -0,0 +1,216 @@ +module +public import Ix.Aiur.Stages.Bytecode +public import Ix.Aiur.Compiler.Layout + +/-! +Return-group splitting for Aiur bytecode. + +Each `Bytecode.Function` may carry multiple `Ctrl.return` sites tagged with +distinct group names. `Function.split` carves the function into one filtered +sub-function per group — control-flow paths that cannot reach a `return` of +the target group are pruned. The result is a sorted array of +`(groupName, filteredFunction)` pairs with selectors renumbered in +traversal order and `FunctionLayout` recomputed. +-/ + +public section +@[expose] section + +namespace Aiur + +namespace Bytecode + +/-- Termination helper for the `Block`/`Ctrl` traversal below. -/ +private theorem Block.sizeOf_ctrl_lt_split (b : Block) : + sizeOf b.ctrl < sizeOf b := by + rcases b with ⟨ops, ctrl⟩ + show sizeOf ctrl < 1 + sizeOf ops + sizeOf ctrl + omega + +/-! ## Collect return-group names -/ + +mutual +def Ctrl.collectGroups (c : Ctrl) : Array String := match c with + | .return _ g _ => #[g] + | .yield .. 
=> #[] + | .match _ cases default? => + let branchGroups := cases.attach.foldl (init := #[]) fun acc ⟨(_, blk), _⟩ => + acc ++ Block.collectGroups blk + match default? with + | some blk => branchGroups ++ Block.collectGroups blk + | none => branchGroups + | .matchContinue _ cases default? _ _ _ continuation => + let branchGroups := cases.attach.foldl (init := #[]) fun acc ⟨(_, blk), _⟩ => + acc ++ Block.collectGroups blk + let withDefault := match default? with + | some blk => branchGroups ++ Block.collectGroups blk + | none => branchGroups + withDefault ++ Block.collectGroups continuation +termination_by (sizeOf c, 0) +decreasing_by + all_goals first + | decreasing_tactic + | (have := Array.sizeOf_lt_of_mem ‹_ ∈ _›; grind) + | grind + +def Block.collectGroups (b : Block) : Array String := Ctrl.collectGroups b.ctrl +termination_by (sizeOf b, 1) +decreasing_by + all_goals first + | decreasing_tactic + | (apply Prod.Lex.left; exact Block.sizeOf_ctrl_lt_split _) +end + +def Function.returnGroups (f : Function) : Std.HashSet String := + (Block.collectGroups f.body).foldl (init := ({} : Std.HashSet String)) + fun acc g => acc.insert g + +/-! ## Filter control-flow tree by target group + +Single (non-mutual) recursion on `Ctrl`. The conditional `Option`-typed +constructor wrap is factored into `mkMatch`/`mkMatchContinue` helpers so +that each match arm of `filterGroup` is a single function-call expression +— Lean can then derive the unfold equation by `rfl`. + +`Yield` branches always survive — they're anchored by the surrounding +`MatchContinue`'s continuation, not by terminal returns. -/ + +def Ctrl.mkMatch (scrut : ValIdx) (cases : Array (G × Block)) + (default? : Option Block) : Option Ctrl := + if cases.isEmpty && default?.isNone then none + else some (.match scrut cases default?) + +def Ctrl.mkMatchContinue (scrut : ValIdx) (cases : Array (G × Block)) + (default? 
: Option Block) (outputSize sharedAux sharedLookups : Nat) + (cont : Block) : Option Ctrl := + if cases.isEmpty && default?.isNone then none + else some (.matchContinue scrut cases default? outputSize sharedAux sharedLookups cont) + +def Ctrl.filterGroup (target : String) : Ctrl → Option Ctrl + | .return sel g vs => if g = target then some (.return sel g vs) else none + | .yield sel vs => some (.yield sel vs) + | .match scrut cases default? => + Ctrl.mkMatch scrut + (cases.attach.foldl (init := (#[] : Array (G × Block))) fun acc ⟨(k, blk), _⟩ => + match Ctrl.filterGroup target blk.ctrl with + | none => acc + | some c => acc.push (k, { blk with ctrl := c })) + (match default? with + | none => none + | some blk => (Ctrl.filterGroup target blk.ctrl).map ({ blk with ctrl := · })) + | .matchContinue scrut cases default? outputSize sharedAux sharedLookups cont => + (Ctrl.filterGroup target cont.ctrl).bind fun c => + Ctrl.mkMatchContinue scrut + (cases.attach.foldl (init := (#[] : Array (G × Block))) fun acc ⟨(k, blk), _⟩ => + match Ctrl.filterGroup target blk.ctrl with + | none => acc + | some c => acc.push (k, { blk with ctrl := c })) + (match default? with + | none => none + | some blk => (Ctrl.filterGroup target blk.ctrl).map ({ blk with ctrl := · })) + outputSize sharedAux sharedLookups + { cont with ctrl := c } +termination_by c => sizeOf c +decreasing_by + all_goals first + | decreasing_tactic + | (have hb := Array.sizeOf_lt_of_mem ‹_ ∈ _› + have hc := Block.sizeOf_ctrl_lt_split ‹Block› + grind) + | (have := Block.sizeOf_ctrl_lt_split ‹Block›; grind) + | grind + +def Block.filterGroup (target : String) (b : Block) : Option Block := + (Ctrl.filterGroup target b.ctrl).map ({ b with ctrl := · }) + +/-! ## Renumber selectors in traversal order -/ + +/-- Single recursion on `Ctrl` threading a `Nat` counter for selectors. The +returned pair is `(rewrittenCtrl, nextCounter)`. 
-/ +def Ctrl.fixSelectors : Ctrl → Nat → Ctrl × Nat + | .return _ g vs, n => (.return n g vs, n + 1) + | .yield _ vs, n => (.yield n vs, n + 1) + | .match scrut cases default?, n => + let (newCases, n) := cases.attach.foldl + (init := ((#[] : Array (G × Block)), n)) + fun (acc, n) ⟨(k, blk), _⟩ => + let (c', n') := Ctrl.fixSelectors blk.ctrl n + (acc.push (k, { blk with ctrl := c' }), n') + let (newDefault, n) : Option Block × Nat := match default? with + | none => (none, n) + | some blk => + let (c', n') := Ctrl.fixSelectors blk.ctrl n + (some { blk with ctrl := c' }, n') + (.match scrut newCases newDefault, n) + | .matchContinue scrut cases default? outputSize sharedAux sharedLookups cont, n => + let (newCases, n) := cases.attach.foldl + (init := ((#[] : Array (G × Block)), n)) + fun (acc, n) ⟨(k, blk), _⟩ => + let (c', n') := Ctrl.fixSelectors blk.ctrl n + (acc.push (k, { blk with ctrl := c' }), n') + let (newDefault, n) : Option Block × Nat := match default? with + | none => (none, n) + | some blk => + let (c', n') := Ctrl.fixSelectors blk.ctrl n + (some { blk with ctrl := c' }, n') + let (c', n') := Ctrl.fixSelectors cont.ctrl n + let newCont : Block := { cont with ctrl := c' } + (.matchContinue scrut newCases newDefault outputSize sharedAux sharedLookups newCont, n') +termination_by c _ => sizeOf c +decreasing_by + all_goals first + | decreasing_tactic + | (have hb := Array.sizeOf_lt_of_mem ‹_ ∈ _› + have hc := Block.sizeOf_ctrl_lt_split ‹Block› + grind) + | (have := Block.sizeOf_ctrl_lt_split ‹Block›; grind) + | grind + +def Block.fixSelectors (b : Block) (n : Nat) : Block × Nat := + let (c', n') := Ctrl.fixSelectors b.ctrl n + ({ b with ctrl := c' }, n') + +/-! ## Layout recompute -/ + +/-- Recompute the `FunctionLayout` of `f` by replaying `blockLayout` over the +current body. Mirrors `Concrete.Function.compile`: `+1` lookup for the +function's own return lookup. 
-/ +def Function.recomputeLayout (f : Function) : FunctionLayout := + let (_, st) := Concrete.Bytecode.blockLayout f.body |>.run + (.new f.layout.inputSize) + { st.functionLayout with lookups := st.functionLayout.lookups + 1 } + +/-- Renumber selectors via DFS counter, then recompute layout. Intended for +functions produced by `filterGroup`, whose selector indices and layout are +inherited from the pre-filter function. -/ +def Function.fix (f : Function) : Function := + let (body', _) := Block.fixSelectors f.body 0 + let f := { f with body := body' } + { f with layout := f.recomputeLayout } + +/-! ## Top-level split -/ + +/-- Carve `f` into one filtered sub-function per return group. Result is sorted +by group name. Panics if a discovered group has no reachable path (cannot +happen for well-formed bytecode). -/ +def Function.split (f : Function) : Array (String × Function) := + let groups := f.returnGroups.toArray.qsort fun a b => decide (a < b) + groups.map fun g => + match Block.filterGroup g f.body with + | none => panic! s!"function contains an unreachable group: {g}" + | some body => + let filtered : Function := { f with body } + (g, filtered.fix) + +/-- Populate `t.filteredFunctions` by splitting every function. Idempotent — +overwrites any existing entries. Should run after `deduplicate` and +`needsCircuit` so the split reflects the final function set. -/ +def Toplevel.computeFiltered (t : Toplevel) : Toplevel := + { t with filteredFunctions := t.functions.map (·.split) } + +end Bytecode + +end Aiur + +end -- @[expose] section +end diff --git a/Ix/Aiur/Stages/Bytecode.lean b/Ix/Aiur/Stages/Bytecode.lean index be8aeeaf..7cc4e184 100644 --- a/Ix/Aiur/Stages/Bytecode.lean +++ b/Ix/Aiur/Stages/Bytecode.lean @@ -90,6 +90,10 @@ structure Function where structure Toplevel where functions : Array Function + /-- Per-function split by return group: one entry per `functions[i]`, each a + sorted array of `(groupName, filteredFunction)` pairs. 
Populated by + `Toplevel.computeFiltered` after `deduplicate` + `needsCircuit`. -/ + filteredFunctions : Array (Array (String × Function)) := #[] memorySizes : Array Nat deriving Repr diff --git a/Tests/Aiur/Aiur.lean b/Tests/Aiur/Aiur.lean index dd61de5a..d06e749c 100644 --- a/Tests/Aiur/Aiur.lean +++ b/Tests/Aiur/Aiur.lean @@ -705,6 +705,25 @@ def toplevel := ⟦ let (x, y) = ntm_tuple(a); x + y } + -- Return-group annotation with two arms sharing a group + one unannotated + -- arm (default `""` group). Exercises per-group multi-row witnesses: the + -- `even` and `odd` circuits each get 2 return sites, the `""` circuit 1. + fn shared_return_groups(x: G) -> G { + match x { + 0 => #[return_group(even)] x + 100, + 1 => #[return_group(odd)] x + 200, + 2 => #[return_group(even)] x * 10, + 3 => #[return_group(odd)] x * 100, + _ => x + 1, + } + } + -- Aggregate one call per arm so the STARK proves every group circuit with + -- the right multiplicities in a single test case. + pub fn shared_return_groups_all() -> G { + shared_return_groups(0) + shared_return_groups(1) + + shared_return_groups(2) + shared_return_groups(3) + + shared_return_groups(7) + } -- Return-group annotation: ignored by compiler/typechecker, must pass through -- to inner term. Match with distinct group labels per arm. pub fn match_return_groups(x: G) -> G { @@ -914,6 +933,10 @@ def aiurTestCases : List AiurTestCase := [ with label := "match_return_groups(2)" }, { AiurTestCase.noIO `match_return_groups #[7] #[8] with label := "match_return_groups(7)" }, + + -- Shared-group + unannotated arm packed into a single STARK run: + -- 100 + 201 + 20 + 300 + 8 = 629. 
+ .noIO `shared_return_groups_all #[] #[629], ] end diff --git a/Tests/Aiur/Cross.lean b/Tests/Aiur/Cross.lean index 25d1e988..76014790 100644 --- a/Tests/Aiur/Cross.lean +++ b/Tests/Aiur/Cross.lean @@ -1020,6 +1020,25 @@ def toplevel : Source.Toplevel := ⟦ r1 + r2 + r3 + r4 + r5 + r6 + r7 + r8 + r9 + r10 + r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20 } + + -- Return-group annotation: two arms share `even`, two share `odd`, one + -- arm is unannotated (default `""` group). Source.Eval treats the + -- annotation as identity; Bytecode.Eval must agree on every input. + fn shared_return_groups(x: G) -> G { + match x { + 0 => #[return_group(even)] x + 100, + 1 => #[return_group(odd)] x + 200, + 2 => #[return_group(even)] x * 10, + 3 => #[return_group(odd)] x * 100, + _ => x + 1, + } + } + -- One entry point hitting every arm in a single run. + pub fn shared_return_groups_all() -> G { + shared_return_groups(0) + shared_return_groups(1) + + shared_return_groups(2) + shared_return_groups(3) + + shared_return_groups(7) + } ⟧ /-- Generic helper: run both evaluators on `entryName` with `inputs` as @@ -1263,7 +1282,9 @@ def tests : TestSeq := runAgreement "fibonacci(6)" "fibonacci" [6] ++ -- End-to-end non-tail match entry aggregating many ntm_* helpers runAgreement "ntm_recursive_test" "ntm_recursive_test" [] ++ - runAgreement "non_tail_match" "non_tail_match" [] + runAgreement "non_tail_match" "non_tail_match" [] ++ + -- Shared-group + unannotated arm: aggregate over every arm in one call. 
+ runAgreement "shared_return_groups_all" "shared_return_groups_all" [] end AiurTests.Cross diff --git a/src/aiur.rs b/src/aiur.rs index bf6005d4..2e904331 100644 --- a/src/aiur.rs +++ b/src/aiur.rs @@ -2,9 +2,7 @@ pub mod bytecode; pub mod constraints; pub mod execute; pub mod gadgets; -pub mod layout; pub mod memory; -pub mod split; pub mod synthesis; pub mod trace; diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 762a3c2a..5d428239 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -8,7 +8,6 @@ use super::G; pub struct Toplevel { pub(crate) functions: Vec, - #[allow(dead_code)] pub(crate) filtered_functions: Vec, Function>>, pub(crate) memory_sizes: Vec, } @@ -34,13 +33,11 @@ impl FunctionLayout { } } -#[derive(Clone)] pub struct Block { pub(crate) ops: Vec, pub(crate) ctrl: Ctrl, } -#[derive(Clone)] pub enum Op { Const(G), Add(ValIdx, ValIdx), @@ -69,7 +66,6 @@ pub enum Op { Debug(String, Option>), } -#[derive(Clone)] pub enum Ctrl { Match(ValIdx, FxIndexMap, Option>), Return(SelIdx, Arc, Vec), diff --git a/src/aiur/layout.rs b/src/aiur/layout.rs deleted file mode 100644 index 25b73298..00000000 --- a/src/aiur/layout.rs +++ /dev/null @@ -1,247 +0,0 @@ -use std::collections::BTreeSet; - -use super::bytecode::{Block, Ctrl, Function, FunctionLayout, Op, ValIdx}; - -/// Compute the `FunctionLayout` (input_size / selectors / auxiliaries / -/// lookups) of a `Function` by walking its bytecode. Port of -/// `Ix/Aiur/Compiler/Layout.lean::blockLayout`. -pub fn compute_layout(function: &Function) -> FunctionLayout { - let input_size = function.layout.input_size; - let mut state = LayoutState::new(input_size); - state.block_layout(&function.body); - // Reserve one lookup slot for the function's return lookup (mirrors the - // `+ 1` added after `blockLayout` in `Concrete.Function.compile`). 
- state.function_layout.lookups += 1; - state.function_layout -} - -#[derive(Clone, Copy)] -struct SharedData { - auxiliaries: usize, - lookups: usize, -} - -impl SharedData { - fn maximals(self, other: Self) -> Self { - SharedData { - auxiliaries: self.auxiliaries.max(other.auxiliaries), - lookups: self.lookups.max(other.lookups), - } - } -} - -struct LayoutState { - function_layout: FunctionLayout, - #[allow(dead_code)] - mem_sizes: BTreeSet, - degrees: Vec, -} - -impl LayoutState { - fn new(input_size: usize) -> Self { - LayoutState { - function_layout: FunctionLayout { - input_size, - selectors: 0, - auxiliaries: 1, - lookups: 0, - }, - mem_sizes: BTreeSet::new(), - degrees: vec![1; input_size], - } - } - - fn bump_selectors(&mut self, n: usize) { - self.function_layout.selectors += n; - } - - fn bump_lookups(&mut self, n: usize) { - self.function_layout.lookups += n; - } - - fn bump_auxiliaries(&mut self, n: usize) { - self.function_layout.auxiliaries += n; - } - - fn add_mem_size(&mut self, size: usize) { - self.mem_sizes.insert(size); - } - - fn push_degree(&mut self, d: usize) { - self.degrees.push(d); - } - - fn push_degrees(&mut self, ds: &[usize]) { - self.degrees.extend_from_slice(ds); - } - - fn get_degree(&self, v: ValIdx) -> usize { - self.degrees.get(v).copied().unwrap_or(0) - } - - fn get_shared(&self) -> SharedData { - SharedData { - auxiliaries: self.function_layout.auxiliaries, - lookups: self.function_layout.lookups, - } - } - - fn set_shared(&mut self, s: SharedData) { - self.function_layout.auxiliaries = s.auxiliaries; - self.function_layout.lookups = s.lookups; - } - - fn op_layout(&mut self, op: &Op) { - match op { - Op::Const(_) => self.push_degree(0), - Op::Add(a, b) | Op::Sub(a, b) => { - let d = self.get_degree(*a).max(self.get_degree(*b)); - self.push_degree(d); - }, - Op::Mul(a, b) => { - let d = self.get_degree(*a) + self.get_degree(*b); - if d < 2 { - self.push_degree(d); - } else { - self.push_degree(1); - self.bump_auxiliaries(1); 
- } - }, - Op::EqZero(a) => { - let d = self.get_degree(*a); - if d == 0 { - self.push_degree(0); - } else { - self.push_degree(1); - self.bump_auxiliaries(2); - } - }, - Op::Call(_, _, output_size, unconstrained) => { - let ones = vec![1usize; *output_size]; - self.push_degrees(&ones); - self.bump_auxiliaries(*output_size); - if !*unconstrained { - self.bump_lookups(1); - } - }, - Op::Store(values) => { - self.push_degree(1); - self.bump_auxiliaries(1); - self.bump_lookups(1); - self.add_mem_size(values.len()); - }, - Op::Load(size, _) => { - let ones = vec![1usize; *size]; - self.push_degrees(&ones); - self.bump_auxiliaries(*size); - self.bump_lookups(1); - self.add_mem_size(*size); - }, - Op::AssertEq(..) - | Op::IOSetInfo(..) - | Op::IOWrite(..) - | Op::Debug(..) => {}, - Op::IOGetInfo(_) => { - self.push_degrees(&[1, 1]); - self.bump_auxiliaries(2); - }, - Op::IORead(_, len) => { - let ones = vec![1usize; *len]; - self.push_degrees(&ones); - self.bump_auxiliaries(*len); - }, - Op::U8BitDecomposition(_) => { - self.push_degrees(&[1; 8]); - self.bump_auxiliaries(8); - self.bump_lookups(1); - }, - Op::U8ShiftLeft(_) - | Op::U8ShiftRight(_) - | Op::U8Xor(..) - | Op::U8And(..) - | Op::U8Or(..) - | Op::U8LessThan(..) => { - self.push_degree(1); - self.bump_auxiliaries(1); - self.bump_lookups(1); - }, - Op::U8Add(..) | Op::U8Mul(..) | Op::U8Sub(..) => { - self.push_degrees(&[1, 1]); - self.bump_auxiliaries(2); - self.bump_lookups(1); - }, - Op::U32LessThan(..) => { - self.push_degree(1); - self.bump_auxiliaries(12); - self.bump_lookups(6); - }, - } - } - - fn block_layout(&mut self, block: &Block) { - for op in &block.ops { - self.op_layout(op); - } - self.ctrl_layout(&block.ctrl); - } - - fn ctrl_layout(&mut self, ctrl: &Ctrl) { - match ctrl { - Ctrl::Return(..) | Ctrl::Yield(..) 
=> self.bump_selectors(1), - Ctrl::Match(_, branches, default) => { - let init_shared = self.get_shared(); - let degrees_save = self.degrees.clone(); - let mut max_shared = init_shared; - for branch in branches.values() { - self.set_shared(init_shared); - self.block_layout(branch); - let branch_shared = self.get_shared(); - self.degrees = degrees_save.clone(); - max_shared = max_shared.maximals(branch_shared); - } - if let Some(default_block) = default { - self.set_shared(init_shared); - self.bump_auxiliaries(branches.len()); - self.block_layout(default_block); - let default_shared = self.get_shared(); - self.degrees = degrees_save.clone(); - max_shared = max_shared.maximals(default_shared); - } - self.set_shared(max_shared); - }, - Ctrl::MatchContinue( - _, - branches, - default, - output_size, - _shared_aux, - _shared_lookups, - continuation, - ) => { - let init_shared = self.get_shared(); - let degrees_save = self.degrees.clone(); - let mut max_shared = init_shared; - for branch in branches.values() { - self.set_shared(init_shared); - self.block_layout(branch); - let branch_shared = self.get_shared(); - self.degrees = degrees_save.clone(); - max_shared = max_shared.maximals(branch_shared); - } - if let Some(default_block) = default { - self.set_shared(init_shared); - self.bump_auxiliaries(branches.len()); - self.block_layout(default_block); - let default_shared = self.get_shared(); - self.degrees = degrees_save.clone(); - max_shared = max_shared.maximals(default_shared); - } - self.set_shared(max_shared); - self.bump_auxiliaries(*output_size); - let ones = vec![1usize; *output_size]; - self.push_degrees(&ones); - self.block_layout(continuation); - }, - } - } -} diff --git a/src/aiur/split.rs b/src/aiur/split.rs deleted file mode 100644 index 5cb73ca3..00000000 --- a/src/aiur/split.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::sync::Arc; - -use rustc_hash::{FxHashMap, FxHashSet}; - -use crate::FxIndexMap; - -use super::bytecode::{Block, Ctrl, Function}; -use 
super::layout::compute_layout; - -impl Function { - pub fn return_groups(&self) -> FxHashSet> { - let mut groups = FxHashSet::default(); - collect_block(&self.body, &mut groups); - groups - } - - /// Build a `Function` containing only the control-flow paths that can reach a - /// `Ctrl::Return` whose group matches `target`. Returns `None` when no such - /// path exists. `Yield` branches are preserved unconditionally — their - /// reachability is decided by the surrounding `MatchContinue`'s continuation. - pub fn filter_group(&self, target: &str) -> Option { - let body = filter_block(&self.body, target)?; - Some(Function { - body, - layout: self.layout, - entry: self.entry, - constrained: self.constrained, - }) - } - - /// Renumber selectors in traversal order and recompute the circuit layout. - /// Intended for functions produced by `filter_group`, whose selector indices - /// and layout are inherited from the original (pre-filter) function. - fn fix(mut self) -> Self { - let mut counter: usize = 0; - fix_block_sel(&mut self.body, &mut counter); - self.layout = compute_layout(&self); - self - } - - pub fn split(&self) -> FxHashMap, Function> { - self - .return_groups() - .into_iter() - .map(|group| { - let filtered = self.filter_group(&group).unwrap_or_else(|| { - panic!("function contains an unreachable group: {group}") - }); - (group, filtered.fix()) - }) - .collect() - } -} - -fn filter_block(block: &Block, target: &str) -> Option { - let ctrl = filter_ctrl(&block.ctrl, target)?; - Some(Block { ops: block.ops.clone(), ctrl }) -} - -fn filter_ctrl(ctrl: &Ctrl, target: &str) -> Option { - match ctrl { - Ctrl::Return(sel, group, vs) => { - if group.as_ref() == target { - Some(Ctrl::Return(*sel, group.clone(), vs.clone())) - } else { - None - } - }, - Ctrl::Yield(sel, vs) => Some(Ctrl::Yield(*sel, vs.clone())), - Ctrl::Match(scrut, cases, default) => { - let new_cases = filter_cases(cases, target); - let new_default = - default.as_ref().and_then(|b| filter_block(b, 
target).map(Box::new)); - if new_cases.is_empty() && new_default.is_none() { - None - } else { - Some(Ctrl::Match(*scrut, new_cases, new_default)) - } - }, - Ctrl::MatchContinue( - scrut, - cases, - default, - output_size, - shared_aux, - shared_lookups, - cont, - ) => { - let new_cont = filter_block(cont, target)?; - let new_cases = filter_cases(cases, target); - let new_default = - default.as_ref().and_then(|b| filter_block(b, target).map(Box::new)); - if new_cases.is_empty() && new_default.is_none() { - None - } else { - Some(Ctrl::MatchContinue( - *scrut, - new_cases, - new_default, - *output_size, - *shared_aux, - *shared_lookups, - Box::new(new_cont), - )) - } - }, - } -} - -fn fix_block_sel(block: &mut Block, counter: &mut usize) { - fix_ctrl_sel(&mut block.ctrl, counter); -} - -fn fix_ctrl_sel(ctrl: &mut Ctrl, counter: &mut usize) { - match ctrl { - Ctrl::Return(sel, _, _) | Ctrl::Yield(sel, _) => { - *sel = *counter; - *counter += 1; - }, - Ctrl::Match(_, cases, default) => { - for branch in cases.values_mut() { - fix_block_sel(branch, counter); - } - if let Some(branch) = default { - fix_block_sel(branch, counter); - } - }, - Ctrl::MatchContinue(_, cases, default, _, _, _, cont) => { - for branch in cases.values_mut() { - fix_block_sel(branch, counter); - } - if let Some(branch) = default { - fix_block_sel(branch, counter); - } - fix_block_sel(cont, counter); - }, - } -} - -fn filter_cases( - cases: &FxIndexMap, - target: &str, -) -> FxIndexMap { - let mut new_cases = FxIndexMap::default(); - for (&k, blk) in cases { - if let Some(b) = filter_block(blk, target) { - new_cases.insert(k, b); - } - } - new_cases -} - -fn collect_block(block: &Block, groups: &mut FxHashSet>) { - collect_ctrl(&block.ctrl, groups); -} - -fn collect_ctrl(ctrl: &Ctrl, groups: &mut FxHashSet>) { - match ctrl { - Ctrl::Return(_, group, _) => { - groups.insert(group.clone()); - }, - Ctrl::Yield(..) 
=> {}, - Ctrl::Match(_, cases, default) => { - for branch in cases.values() { - collect_block(branch, groups); - } - if let Some(branch) = default { - collect_block(branch, groups); - } - }, - Ctrl::MatchContinue(_, cases, default, _, _, _, cont) => { - for branch in cases.values() { - collect_block(branch, groups); - } - if let Some(branch) = default { - collect_block(branch, groups); - } - collect_block(cont, groups); - }, - } -} diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index e2f6d261..00f1ce13 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -1,4 +1,7 @@ +use std::sync::Arc; + use multi_stark::p3_field::PrimeCharacteristicRing; +use rustc_hash::FxHashMap; use lean_ffi::object::{LeanBorrowed, LeanCtor, LeanRef}; @@ -15,6 +18,26 @@ use crate::{ use crate::ffi::aiur::{lean_unbox_g, lean_unbox_nat_as_usize}; +/// Per-decode interner for return-group names. Shared across every +/// `Ctrl::Return` decoded inside the `functions` tree and every key in +/// `filtered_functions`, so the same group name resolves to a single +/// `Arc` allocation. 
+#[derive(Default)] +struct GroupInterner { + map: FxHashMap>, +} + +impl GroupInterner { + fn intern(&mut self, s: &str) -> Arc { + if let Some(existing) = self.map.get(s) { + return existing.clone(); + } + let arc: Arc = Arc::from(s); + self.map.insert(s.to_string(), arc.clone()); + arc + } +} + fn decode_vec_val_idx(obj: LeanBorrowed<'_>) -> Vec { obj.as_array().map(|x| lean_unbox_nat_as_usize(&x)) } @@ -144,26 +167,34 @@ fn decode_op(ctor: LeanCtor>) -> Op { } } -fn decode_g_block_pair(ctor: LeanCtor>) -> (G, Block) { +fn decode_g_block_pair( + ctor: LeanCtor>, + interner: &mut GroupInterner, +) -> (G, Block) { let [g_obj, block_obj] = ctor.objs::<2>(); let g = lean_unbox_g(&g_obj); - let block = decode_block(block_obj.as_ctor()); + let block = decode_block(block_obj.as_ctor(), interner); (g, block) } -fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { +fn decode_ctrl( + ctor: LeanCtor>, + interner: &mut GroupInterner, +) -> Ctrl { match ctor.tag() { 0 => { let [val_idx_obj, cases_obj, default_obj] = ctor.objs::<3>(); let val_idx = lean_unbox_nat_as_usize(&val_idx_obj); - let vec_cases = - cases_obj.as_array().map(|o| decode_g_block_pair(o.as_ctor())); - let cases = FxIndexMap::from_iter(vec_cases); + let cases: FxIndexMap = cases_obj + .as_array() + .iter() + .map(|o| decode_g_block_pair(o.as_ctor(), interner)) + .collect(); let default = if default_obj.is_scalar() { None } else { let inner_ctor = default_obj.as_ctor(); - let block = decode_block(inner_ctor.get(0).as_ctor()); + let block = decode_block(inner_ctor.get(0).as_ctor(), interner); Some(Box::new(block)) }; Ctrl::Match(val_idx, cases, default) @@ -172,8 +203,7 @@ fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { let [sel_idx_obj, group_obj, val_idxs_obj] = ctor.objs::<3>(); let sel_idx = lean_unbox_nat_as_usize(&sel_idx_obj); let group_lean = group_obj.as_string(); - let group: std::sync::Arc = - std::sync::Arc::from(group_lean.as_str()); + let group = interner.intern(group_lean.as_str()); let val_idxs = 
decode_vec_val_idx(val_idxs_obj); Ctrl::Return(sel_idx, group, val_idxs) }, @@ -194,20 +224,23 @@ fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { cont_obj, ] = ctor.objs::<7>(); let val_idx = lean_unbox_nat_as_usize(&val_idx_obj); - let vec_cases = - cases_obj.as_array().map(|o| decode_g_block_pair(o.as_ctor())); - let cases = FxIndexMap::from_iter(vec_cases); + let cases: FxIndexMap = cases_obj + .as_array() + .iter() + .map(|o| decode_g_block_pair(o.as_ctor(), interner)) + .collect(); let default = if default_obj.is_scalar() { None } else { let inner_ctor = default_obj.as_ctor(); - let block = decode_block(inner_ctor.get(0).as_ctor()); + let block = decode_block(inner_ctor.get(0).as_ctor(), interner); Some(Box::new(block)) }; let output_size = lean_unbox_nat_as_usize(&output_size_obj); let shared_aux = lean_unbox_nat_as_usize(&shared_aux_obj); let shared_lookups = lean_unbox_nat_as_usize(&shared_lookups_obj); - let continuation = Box::new(decode_block(cont_obj.as_ctor())); + let continuation = + Box::new(decode_block(cont_obj.as_ctor(), interner)); Ctrl::MatchContinue( val_idx, cases, @@ -222,10 +255,13 @@ fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { } } -fn decode_block(ctor: LeanCtor>) -> Block { +fn decode_block( + ctor: LeanCtor>, + interner: &mut GroupInterner, +) -> Block { let [ops_obj, ctrl_obj] = ctor.objs::<2>(); let ops = ops_obj.as_array().map(|o| decode_op(o.as_ctor())); - let ctrl = decode_ctrl(ctrl_obj.as_ctor()); + let ctrl = decode_ctrl(ctrl_obj.as_ctor(), interner); Block { ops, ctrl } } @@ -239,24 +275,60 @@ fn decode_function_layout(ctor: LeanCtor>) -> FunctionLayout { } } -fn decode_function(ctor: LeanCtor>) -> Function { +fn decode_function( + ctor: LeanCtor>, + interner: &mut GroupInterner, +) -> Function { let ctor = LeanAiurFunction::from_ctor(ctor); - let body = decode_block(ctor.get_obj(0).as_ctor()); + let body = decode_block(ctor.get_obj(0).as_ctor(), interner); let layout = decode_function_layout(ctor.get_obj(1).as_ctor()); let entry = 
ctor.get_num_8(0) != 0; let constrained = ctor.get_num_8(1) != 0; Function { body, layout, entry, constrained } } +/// Decode a single `(String × Function)` Lean product into its Rust form, +/// sharing the group `Arc` with every other site that mentions the +/// same name (via the interner). +fn decode_group_function_pair( + ctor: LeanCtor>, + interner: &mut GroupInterner, +) -> (Arc, Function) { + let [group_obj, fn_obj] = ctor.objs::<2>(); + let group = interner.intern(group_obj.as_string().as_str()); + let function = decode_function(fn_obj.as_ctor(), interner); + (group, function) +} + pub(crate) fn decode_toplevel( obj: &LeanAiurToplevel, ) -> Toplevel { let ctor = obj.as_ctor(); - let [functions_obj, memory_sizes_obj] = ctor.objs::<2>(); - let functions = - functions_obj.as_array().map(|o| decode_function(o.as_ctor())); + let [functions_obj, filtered_functions_obj, memory_sizes_obj] = + ctor.objs::<3>(); + let mut interner = GroupInterner::default(); + // `LeanArray::map` requires `Fn`, so we cannot borrow `interner` mutably + // inside it. Use `iter()` for the interner-touching layers. + let functions: Vec = functions_obj + .as_array() + .iter() + .map(|o| decode_function(o.as_ctor(), &mut interner)) + .collect(); + let filtered_functions: Vec, Function>> = + filtered_functions_obj + .as_array() + .iter() + .map(|inner_obj| { + inner_obj + .as_array() + .iter() + .map(|pair_obj| { + decode_group_function_pair(pair_obj.as_ctor(), &mut interner) + }) + .collect::>() + }) + .collect(); let memory_sizes = memory_sizes_obj.as_array().map(|x| lean_unbox_nat_as_usize(&x)); - let filtered_functions = functions.iter().map(|f| f.split()).collect(); Toplevel { functions, memory_sizes, filtered_functions } } diff --git a/src/lean.rs b/src/lean.rs index ef41347b..24190ee6 100644 --- a/src/lean.rs +++ b/src/lean.rs @@ -234,7 +234,7 @@ lean_ffi::lean_inductive! 
{ // --- Aiur types --- - LeanAiurToplevel [ { num_obj: 2 } ]; + LeanAiurToplevel [ { num_obj: 3 } ]; LeanAiurFunction [ { num_obj: 2, num_8: 2 } ]; // --- Block / comparison types --- From b529b2a5e29a5be9fba77fbb8111d1fa99b466f8 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Wed, 20 May 2026 07:10:32 -0700 Subject: [PATCH 11/13] Index return groups by USize at bytecode level MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Source / Typed / Concrete / Simple keep `#[return_group(name)]` as a `String` (user-written labels). At lowering time `Concrete.Function.compile` allocates a `USize` per distinct group name in encounter order, stores the table in `Bytecode.Function.groupNames`, and stamps every `Ctrl.return` with the index. `Bytecode.Toplevel.filteredFunctions` becomes `Array (Array Function)` keyed positionally by group index. Rust mirrors: `Ctrl::Return(SelIdx, usize, Vec<ValIdx>)`, `Toplevel.filtered_functions: Vec<Vec<Function>>`, `QueryResult.return_group: usize`. The `Arc<str>` interner is gone — group identity is now a `Copy` integer. `LeanAiurCtrl` is declared with `num_usize: 1` on the `return` ctor so `get_usize(0)` is bounds-checked. `LeanAiurFunction` bumps `num_obj: 2 → 3` to carry `groupNames`; the Rust decoder reads it Lean-side only (display lookup in `Statistics`). Execute ships `(groupIdx, width, rows, hits)` quads instead of name strings; `Statistics.computeStats` resolves the display name via `t.functions[i].groupNames[groupIdx]`.
--- Ix/Aiur/Compiler/Lower.lean | 57 ++++++++--- Ix/Aiur/Compiler/Split.lean | 72 ++++---------- Ix/Aiur/Semantics/BytecodeFfi.lean | 8 +- Ix/Aiur/Stages/Bytecode.lean | 13 ++- Ix/Aiur/Statistics.lean | 6 +- src/aiur/bytecode.rs | 10 +- src/aiur/constraints.rs | 6 +- src/aiur/execute.rs | 7 +- src/aiur/synthesis.rs | 34 +++---- src/aiur/trace.rs | 8 +- src/ffi/aiur/protocol.rs | 70 +++++++------- src/ffi/aiur/toplevel.rs | 147 +++++++++-------------------- src/lean.rs | 12 ++- 13 files changed, 194 insertions(+), 256 deletions(-) diff --git a/Ix/Aiur/Compiler/Lower.lean b/Ix/Aiur/Compiler/Lower.lean index b83bd70d..27f70ed0 100644 --- a/Ix/Aiur/Compiler/Lower.lean +++ b/Ix/Aiur/Compiler/Lower.lean @@ -94,7 +94,15 @@ structure CompilerState where ops : Array Bytecode.Op selIdx : Bytecode.SelIdx degrees : Array Nat - currentReturnGroup : String := "" + /-- Top of the `#[return_group(…)]` annotation stack — the display name + whose index will tag every `Ctrl.return` emitted inside its scope. The + index is looked up (or allocated) lazily at emit time. -/ + currentReturnGroupName : String := "" + /-- Group display names allocated so far; position `i` is the name for + group index `i`. -/ + groupNames : Array String := #[] + /-- Inverse of `groupNames`: maps name → allocated index. -/ + groupNameMap : Std.HashMap String USize := {} deriving Inhabited abbrev CompileM := EStateM String CompilerState @@ -123,6 +131,20 @@ def pushOp (op : Bytecode.Op) (size : Nat := 1) : CompileM (Array Bytecode.ValId def extractOps : CompileM (Array Bytecode.Op) := modifyGet fun s => (s.ops, {s with ops := #[]}) +/-- Look up the `USize` index for the current return-group name, allocating +fresh storage in `groupNames`/`groupNameMap` on first encounter. -/ +def allocCurrentGroup : CompileM USize := do + let st ← get + let name := st.currentReturnGroupName + match st.groupNameMap[name]? 
with + | some idx => pure idx + | none => + let idx : USize := USize.ofNat st.groupNames.size + modify fun s => { s with + groupNameMap := s.groupNameMap.insert name idx + groupNames := s.groupNames.push name } + pure idx + open Concrete in mutual @@ -449,10 +471,10 @@ def Concrete.Term.compile modify fun stt => { stt with ops := stt.ops.push (.ioWrite data) } ret.compile returnTyp layoutMap bindings yieldCtrl | .retGroup _ _ name inner => do - let oldGroup := (← get).currentReturnGroup - modify fun s => { s with currentReturnGroup := name } + let oldGroup := (← get).currentReturnGroupName + modify fun s => { s with currentReturnGroupName := name } let blk ← inner.compile returnTyp layoutMap bindings yieldCtrl - modify fun s => { s with currentReturnGroup := oldGroup } + modify fun s => { s with currentReturnGroupName := oldGroup } pure blk | .match _ _ scrut cases defaultOpt => do let idxs := bindings[scrut]?.getD #[0] @@ -468,14 +490,16 @@ def Concrete.Term.compile pure ({ ops, ctrl } : Bytecode.Block) | .ret _ _ term => do let idxs ← toIndex layoutMap bindings term + let groupIdx ← allocCurrentGroup let state ← get let state := { state with selIdx := state.selIdx + 1 } set state let ops := state.ops let id := state.selIdx - pure ({ ops, ctrl := .return (id - 1) state.currentReturnGroup idxs } : Bytecode.Block) + pure ({ ops, ctrl := .return (id - 1) groupIdx idxs } : Bytecode.Block) | _ => do let idxs ← toIndex layoutMap bindings term + let groupIdx ← allocCurrentGroup let state ← get let state := { state with selIdx := state.selIdx + 1 } set state @@ -483,7 +507,7 @@ def Concrete.Term.compile let id := state.selIdx let ctrl : Bytecode.Ctrl := if yieldCtrl && !term.escapes then .yield (id - 1) idxs - else .return (id - 1) state.currentReturnGroup idxs + else .return (id - 1) groupIdx idxs pure ({ ops, ctrl } : Bytecode.Block) termination_by (sizeOf term, 0) decreasing_by @@ -538,9 +562,11 @@ decreasing_by all_goals first | decreasing_tactic | grind end -/-- 
Lower a full concrete function to bytecode. -/ +/-- Lower a full concrete function to bytecode. Returns the body, layout +state, and the per-function `groupNames` table (position `i` is the display +name for group index `i`). -/ def Concrete.Function.compile (layoutMap : LayoutMap) (f : Concrete.Function) : - Except String (Bytecode.Block × Bytecode.LayoutMState) := do + Except String (Bytecode.Block × Bytecode.LayoutMState × Array String) := do let (_inputSize, _outputSize) ← match layoutMap[f.name]? with | some (.function layout) => pure (layout.inputSize, layout.outputSize) | _ => throw s!"`{f.name}` should be a function" @@ -551,16 +577,16 @@ def Concrete.Function.compile (layoutMap : LayoutMap) (f : Concrete.Function) : | .ok len => pure len let indices := Array.range' valIdx len pure (valIdx + len, bindings.insert arg indices) - let state := { valIdx, selIdx := 0, ops := #[], degrees := Array.replicate valIdx 1, - currentReturnGroup := "" } + let state : CompilerState := { valIdx, selIdx := 0, ops := #[], + degrees := Array.replicate valIdx 1 } match f.body.compile f.output layoutMap bindings |>.run state with | .error e _ => throw e - | .ok body _ => + | .ok body finalState => let (_, layoutMState) := Bytecode.blockLayout body |>.run (.new valIdx) let layoutMState := { layoutMState with functionLayout := { layoutMState.functionLayout with lookups := layoutMState.functionLayout.lookups + 1 } } - pure (body, layoutMState) + pure (body, layoutMState, finalState.groupNames) def Concrete.Decls.toBytecode (decls : Concrete.Decls) : Except String (Bytecode.Toplevel × Std.HashMap Global Bytecode.FunIdx) := do @@ -569,9 +595,12 @@ def Concrete.Decls.toBytecode (decls : Concrete.Decls) : let (functions, memSizes, nameMap) ← decls.foldlM (init := (#[], initMemSizes, {})) fun acc@(functions, memSizes, nameMap) (_, decl) => match decl with | .function function => do - let (body, layoutMState) ← function.compile layout + let (body, layoutMState, groupNames) ← 
function.compile layout let nameMap := nameMap.insert function.name functions.size - let function := ⟨body, layoutMState.functionLayout, function.entry, false⟩ + let groupNames := if groupNames.isEmpty then #[""] else groupNames + let function : Bytecode.Function := + { body, layout := layoutMState.functionLayout, + groupNames, entry := function.entry, constrained := false } let memSizes := layoutMState.memSizes.fold (·.insert ·) memSizes pure (functions.push function, memSizes, nameMap) | _ => pure acc diff --git a/Ix/Aiur/Compiler/Split.lean b/Ix/Aiur/Compiler/Split.lean index ecdaffd4..b0f6d768 100644 --- a/Ix/Aiur/Compiler/Split.lean +++ b/Ix/Aiur/Compiler/Split.lean @@ -6,11 +6,11 @@ public import Ix.Aiur.Compiler.Layout Return-group splitting for Aiur bytecode. Each `Bytecode.Function` may carry multiple `Ctrl.return` sites tagged with -distinct group names. `Function.split` carves the function into one filtered -sub-function per group — control-flow paths that cannot reach a `return` of -the target group are pruned. The result is a sorted array of -`(groupName, filteredFunction)` pairs with selectors renumbered in -traversal order and `FunctionLayout` recomputed. +distinct `USize` group indices. `Function.split` carves the function into one +filtered sub-function per index — control-flow paths that cannot reach a +`return` of the target group are pruned. The result is positional: index +`i` in the output equals the group index `i` in the function's +`groupNames` table. -/ public section @@ -27,44 +27,6 @@ private theorem Block.sizeOf_ctrl_lt_split (b : Block) : show sizeOf ctrl < 1 + sizeOf ops + sizeOf ctrl omega -/-! ## Collect return-group names -/ - -mutual -def Ctrl.collectGroups (c : Ctrl) : Array String := match c with - | .return _ g _ => #[g] - | .yield .. => #[] - | .match _ cases default? => - let branchGroups := cases.attach.foldl (init := #[]) fun acc ⟨(_, blk), _⟩ => - acc ++ Block.collectGroups blk - match default? 
with - | some blk => branchGroups ++ Block.collectGroups blk - | none => branchGroups - | .matchContinue _ cases default? _ _ _ continuation => - let branchGroups := cases.attach.foldl (init := #[]) fun acc ⟨(_, blk), _⟩ => - acc ++ Block.collectGroups blk - let withDefault := match default? with - | some blk => branchGroups ++ Block.collectGroups blk - | none => branchGroups - withDefault ++ Block.collectGroups continuation -termination_by (sizeOf c, 0) -decreasing_by - all_goals first - | decreasing_tactic - | (have := Array.sizeOf_lt_of_mem ‹_ ∈ _›; grind) - | grind - -def Block.collectGroups (b : Block) : Array String := Ctrl.collectGroups b.ctrl -termination_by (sizeOf b, 1) -decreasing_by - all_goals first - | decreasing_tactic - | (apply Prod.Lex.left; exact Block.sizeOf_ctrl_lt_split _) -end - -def Function.returnGroups (f : Function) : Std.HashSet String := - (Block.collectGroups f.body).foldl (init := ({} : Std.HashSet String)) - fun acc g => acc.insert g - /-! ## Filter control-flow tree by target group Single (non-mutual) recursion on `Ctrl`. The conditional `Option`-typed @@ -86,7 +48,7 @@ def Ctrl.mkMatchContinue (scrut : ValIdx) (cases : Array (G × Block)) if cases.isEmpty && default?.isNone then none else some (.matchContinue scrut cases default? outputSize sharedAux sharedLookups cont) -def Ctrl.filterGroup (target : String) : Ctrl → Option Ctrl +def Ctrl.filterGroup (target : USize) : Ctrl → Option Ctrl | .return sel g vs => if g = target then some (.return sel g vs) else none | .yield sel vs => some (.yield sel vs) | .match scrut cases default? => @@ -120,7 +82,7 @@ decreasing_by | (have := Block.sizeOf_ctrl_lt_split ‹Block›; grind) | grind -def Block.filterGroup (target : String) (b : Block) : Option Block := +def Block.filterGroup (target : USize) (b : Block) : Option Block := (Ctrl.filterGroup target b.ctrl).map ({ b with ctrl := · }) /-! 
## Renumber selectors in traversal order -/ @@ -190,17 +152,19 @@ def Function.fix (f : Function) : Function := /-! ## Top-level split -/ -/-- Carve `f` into one filtered sub-function per return group. Result is sorted -by group name. Panics if a discovered group has no reachable path (cannot -happen for well-formed bytecode). -/ -def Function.split (f : Function) : Array (String × Function) := - let groups := f.returnGroups.toArray.qsort fun a b => decide (a < b) - groups.map fun g => - match Block.filterGroup g f.body with - | none => panic! s!"function contains an unreachable group: {g}" +/-- Carve `f` into one filtered sub-function per group index in +`f.groupNames`. Position `i` in the result equals group index `i`. Panics +if a stored group index has no reachable `Return` site (should not happen +for well-formed bytecode produced by `Concrete.Function.compile`). -/ +def Function.split (f : Function) : Array Function := + (Array.range f.groupNames.size).map fun i => + let target : USize := USize.ofNat i + match Block.filterGroup target f.body with + | none => + panic! s!"function contains an unreachable group: {f.groupNames[i]!}" | some body => let filtered : Function := { f with body } - (g, filtered.fix) + filtered.fix /-- Populate `t.filteredFunctions` by splitting every function. Idempotent — overwrites any existing entries. Should run after `deduplicate` and diff --git a/Ix/Aiur/Semantics/BytecodeFfi.lean b/Ix/Aiur/Semantics/BytecodeFfi.lean index b0a6bed3..2a67a96e 100644 --- a/Ix/Aiur/Semantics/BytecodeFfi.lean +++ b/Ix/Aiur/Semantics/BytecodeFfi.lean @@ -61,9 +61,11 @@ structure QueryCount where namespace Bytecode.Toplevel -/-- Per-function execution stats. One entry per split (return group), sorted -by group name. Each quadruple is `(group, totalWidth, uniqueRows, totalHits)`. -/ -abbrev FunctionStats := Array (Array (String × Nat × Nat × Nat)) +/-- Per-function execution stats. One entry per split (return group), keyed +by group index. 
Each quadruple is +`(groupIdx, totalWidth, uniqueRows, totalHits)`. The display name is looked +up via `Function.groupNames[groupIdx]`. -/ +abbrev FunctionStats := Array (Array (Nat × Nat × Nat × Nat)) /-- Per-memory-size `(uniqueRows, totalHits)` pairs. -/ abbrev MemoryCounts := Array (Nat × Nat) diff --git a/Ix/Aiur/Stages/Bytecode.lean b/Ix/Aiur/Stages/Bytecode.lean index 7cc4e184..af8ef8bf 100644 --- a/Ix/Aiur/Stages/Bytecode.lean +++ b/Ix/Aiur/Stages/Bytecode.lean @@ -49,7 +49,7 @@ inductive Op mutual inductive Ctrl where | match : ValIdx → Array (G × Block) → Option Block → Ctrl - | return : SelIdx → (group : String) → Array ValIdx → Ctrl + | return : SelIdx → (group : USize) → Array ValIdx → Ctrl | yield : SelIdx → Array ValIdx → Ctrl | matchContinue : ValIdx → Array (G × Block) → Option Block → (outputSize : Nat) → (sharedAuxiliaries : Nat) → (sharedLookups : Nat) @@ -84,16 +84,21 @@ def FunctionLayout.totalWidth (l : FunctionLayout) : Nat := structure Function where body : Block layout: FunctionLayout + /-- Display names for the return groups used in `body`. Position `i` in the + array maps to the `USize` group index `i` carried by `Ctrl.return`. Defaults + to the singleton `#[""]` so functions with no `#[return_group(…)]` + annotations have a single unnamed group at index `0`. -/ + groupNames : Array String := #[""] entry : Bool constrained : Bool deriving Inhabited, Repr structure Toplevel where functions : Array Function - /-- Per-function split by return group: one entry per `functions[i]`, each a - sorted array of `(groupName, filteredFunction)` pairs. Populated by + /-- Per-function split by return group: position in the inner array equals + the `USize` group index used by `Ctrl.return`. Populated by `Toplevel.computeFiltered` after `deduplicate` + `needsCircuit`. 
-/ - filteredFunctions : Array (Array (String × Function)) := #[] + filteredFunctions : Array (Array Function) := #[] memorySizes : Array Nat deriving Repr diff --git a/Ix/Aiur/Statistics.lean b/Ix/Aiur/Statistics.lean index 4b85307f..e90ffe8c 100644 --- a/Ix/Aiur/Statistics.lean +++ b/Ix/Aiur/Statistics.lean @@ -40,7 +40,7 @@ def fftCost (w h : Nat) : Float := wf * hf * (max hf 2.0).log2 def computeStats (compiled : CompiledToplevel) - (functionStats : Array (Array (String × Nat × Nat × Nat))) + (functionStats : Array (Array (Nat × Nat × Nat × Nat))) (memoryCounts : Array (Nat × Nat)) : ExecutionStats := let t := compiled.bytecode -- Invert nameMap to get FunIdx → String @@ -52,12 +52,14 @@ def computeStats (compiled : CompiledToplevel) for i in [:nAllFuns] do if t.functions[i]!.constrained then let baseName := reverseMap[i]?.getD s!"" + let groupNames := t.functions[i]!.groupNames for quad in functionStats[i]! do - let group := quad.1 + let groupIdx := quad.1 let w := quad.2.1 let h := quad.2.2.1 let totalHits := quad.2.2.2 let hits := totalHits - h + let group := groupNames[groupIdx]?.getD "" let name := if group.isEmpty then baseName else s!"{baseName} [{group}]" acc := acc.push { name, width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } acc diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 5d428239..86c05e70 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -1,14 +1,12 @@ -use std::sync::Arc; - -use rustc_hash::FxHashMap; - use crate::FxIndexMap; use super::G; pub struct Toplevel { pub(crate) functions: Vec, - pub(crate) filtered_functions: Vec, Function>>, + /// Per-function split by return-group index: outer index = `FunIdx`, inner + /// index = the `usize` group index carried by `Ctrl::Return`. 
+ pub(crate) filtered_functions: Vec>, pub(crate) memory_sizes: Vec, } @@ -68,7 +66,7 @@ pub enum Op { pub enum Ctrl { Match(ValIdx, FxIndexMap, Option>), - Return(SelIdx, Arc, Vec), + Return(SelIdx, usize, Vec), Yield(SelIdx, Vec), MatchContinue( ValIdx, diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index 6e18acd7..2b1ec926 100644 --- a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -110,11 +110,9 @@ impl Toplevel { pub fn build_constraints( &self, function_index: usize, - group: &str, + group: usize, ) -> (Constraints, Vec>) { - let function = self.filtered_functions[function_index] - .get(group) - .expect("Missing filtered function for group"); + let function = &self.filtered_functions[function_index][group]; let constraints = Constraints { zeros: vec![], selectors: 0..0, diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index b2500d2e..8a0a03f7 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -1,7 +1,6 @@ use multi_stark::p3_field::{PrimeCharacteristicRing, PrimeField64}; use rustc_hash::FxHashMap; use std::collections::hash_map::Entry; -use std::sync::Arc; use crate::{ FxIndexMap, @@ -19,7 +18,7 @@ use crate::{ pub struct QueryResult { pub(crate) output: Vec, pub(crate) multiplicity: G, - pub(crate) return_group: Arc, + pub(crate) return_group: usize, } pub type QueryMap = FxIndexMap, QueryResult>; @@ -258,7 +257,7 @@ impl Function { let result = QueryResult { output: vec![ptr], multiplicity: G::from_bool(!unconstrained), - return_group: Arc::from(""), + return_group: 0, }; memory_queries.insert(values, result); map.push(ptr); @@ -491,7 +490,7 @@ impl Function { let result = QueryResult { output: output.clone(), multiplicity: G::from_bool(!unconstrained), - return_group: group.clone(), + return_group: *group, }; record.function_queries[fun_idx].insert(args, result); if let Some(CallerState { diff --git a/src/aiur/synthesis.rs b/src/aiur/synthesis.rs index f1efa8ff..4122d788 100644 --- a/src/aiur/synthesis.rs +++ 
b/src/aiur/synthesis.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use multi_stark::{ lookup::LookupAir, p3_air::{Air, AirBuilder, BaseAir}, @@ -73,7 +71,7 @@ where } enum CircuitType { - Function { idx: usize, group: Arc }, + Function { idx: usize, group: usize }, Memory { width: usize }, Bytes1, Bytes2, @@ -87,17 +85,13 @@ impl AiurSystem { let toplevel_ref = &toplevel; let function_circuits = (0..toplevel_ref.functions.len()).flat_map(move |i| { - let groups: Vec> = if toplevel_ref.functions[i].constrained { - let mut gs: Vec> = - toplevel_ref.filtered_functions[i].keys().cloned().collect(); - gs.sort(); - gs + let group_count = if toplevel_ref.functions[i].constrained { + toplevel_ref.filtered_functions[i].len() } else { - vec![] + 0 }; - groups.into_iter().map(move |group| { - let (constraints, lookups) = - toplevel_ref.build_constraints(i, &group); + (0..group_count).map(move |group| { + let (constraints, lookups) = toplevel_ref.build_constraints(i, group); LookupAir::new(AiurCircuit::Function(constraints), lookups) }) }); @@ -144,18 +138,12 @@ impl AiurSystem { let _g = tracing::info_span!("aiur/witness").entered(); let functions: Vec = (0..self.toplevel.functions.len()) .flat_map(|idx| { - let groups: Vec> = if self.toplevel.functions[idx].constrained - { - let mut gs: Vec> = - self.toplevel.filtered_functions[idx].keys().cloned().collect(); - gs.sort(); - gs + let group_count = if self.toplevel.functions[idx].constrained { + self.toplevel.filtered_functions[idx].len() } else { - vec![] + 0 }; - groups - .into_iter() - .map(move |group| CircuitType::Function { idx, group }) + (0..group_count).map(move |group| CircuitType::Function { idx, group }) }) .collect(); let functions = functions.into_par_iter(); @@ -170,7 +158,7 @@ impl AiurSystem { .chain(gadgets) .map(|circuit_type| match circuit_type { CircuitType::Function { idx, group } => { - self.toplevel.witness_data(idx, &group, &query_record, io_buffer) + self.toplevel.witness_data(idx, group, &query_record, 
io_buffer) }, CircuitType::Memory { width } => { Memory::witness_data(width, &query_record) diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index a1ebba36..740dda5e 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -77,19 +77,17 @@ impl Toplevel { pub fn witness_data( &self, function_index: usize, - group: &str, + group: usize, query_record: &QueryRecord, io_buffer: &IOBuffer, ) -> (RowMajorMatrix, Vec>>) { - let func = self.filtered_functions[function_index] - .get(group) - .expect("Missing filtered function for group"); + let func = &self.filtered_functions[function_index][group]; let width = func.width(); let unfiltered_queries = &query_record.function_queries[function_index]; let queries = unfiltered_queries .iter() .filter(|(_, res)| { - !res.multiplicity.is_zero() && res.return_group.as_ref() == group + !res.multiplicity.is_zero() && res.return_group == group }) .collect::>(); let height_no_padding = queries.len(); diff --git a/src/ffi/aiur/protocol.rs b/src/ffi/aiur/protocol.rs index 245b76b5..43b9e0a0 100644 --- a/src/ffi/aiur/protocol.rs +++ b/src/ffi/aiur/protocol.rs @@ -8,7 +8,7 @@ use std::sync::LazyLock; use lean_ffi::object::{ ExternalClass, LeanArray, LeanBorrowed, LeanByteArray, LeanExcept, - LeanExternal, LeanNat, LeanOwned, LeanProd, LeanRef, LeanString, + LeanExternal, LeanNat, LeanOwned, LeanProd, LeanRef, }; use crate::{ @@ -88,9 +88,10 @@ extern "C" fn rs_aiur_system_verify( /// `Bytecode.Toplevel.execute`: runs execution only (no proof) and returns /// `Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) -/// × (Array (Array (String × Nat × Nat × Nat)) × Array (Nat × Nat)))`. +/// × (Array (Array (Nat × Nat × Nat × Nat)) × Array (Nat × Nat)))`. /// The function side is per-function array of per-split -/// `(group, width, uniqueRows, totalHits)` quadruples. The memory side is +/// `(groupIdx, width, uniqueRows, totalHits)` quadruples (display names are +/// looked up Lean-side via `Function.groupNames`). 
The memory side is /// per-memory-size `(uniqueRows, totalHits)` pairs. /// On execution failure (e.g. assertion mismatch from a typechecker /// rejecting a constant), returns `Except.error msg` instead of panicking @@ -134,34 +135,33 @@ extern "C" fn rs_aiur_toplevel_execute( (rows, hits) }; - // Per-function array of `(group, width, unique_rows, total_hits)` quadruples, - // one per split. Queries within a function are partitioned by return group. - let function_stats: Vec, usize, usize, usize)>> = (0 - ..toplevel.functions.len()) - .map(|i| { - let queries = &query_record.function_queries[i]; - let mut stats: Vec<(std::sync::Arc, usize, usize, usize)> = toplevel - .filtered_functions[i] - .iter() - .map(|(group, func)| { - let (rows, hits) = queries - .iter() - .filter(|(_, res)| res.return_group.as_ref() == group.as_ref()) - .fold((0usize, 0usize), |(r, h), (_, res)| { - let m = usize::try_from(res.multiplicity.as_canonical_u64()) - .expect("multiplicity exceeds usize"); - if m != 0 { (r + 1, h + m) } else { (r, h) } - }); - let l = func.layout; - let width = - l.input_size + l.selectors + l.auxiliaries + 4 * (1 + l.lookups); - (group.clone(), width, rows, hits) - }) - .collect(); - stats.sort_by(|a, b| a.0.cmp(&b.0)); - stats - }) - .collect(); + // Per-function array of `(group_idx, width, unique_rows, total_hits)` + // quadruples — one entry per split, in group-index order. Queries within + // a function are partitioned by `QueryResult.return_group`. 
+ let function_stats: Vec> = + (0..toplevel.functions.len()) + .map(|i| { + let queries = &query_record.function_queries[i]; + toplevel.filtered_functions[i] + .iter() + .enumerate() + .map(|(group_idx, func)| { + let (rows, hits) = queries + .iter() + .filter(|(_, res)| res.return_group == group_idx) + .fold((0usize, 0usize), |(r, h), (_, res)| { + let m = usize::try_from(res.multiplicity.as_canonical_u64()) + .expect("multiplicity exceeds usize"); + if m != 0 { (r + 1, h + m) } else { (r, h) } + }); + let l = func.layout; + let width = + l.input_size + l.selectors + l.auxiliaries + 4 * (1 + l.lookups); + (group_idx, width, rows, hits) + }) + .collect() + }) + .collect(); let memory_counts: Vec<(usize, usize)> = toplevel .memory_sizes @@ -175,10 +175,10 @@ extern "C" fn rs_aiur_toplevel_execute( let outer = LeanArray::alloc(function_stats.len()); for (i, per_fn) in function_stats.iter().enumerate() { let inner = LeanArray::alloc(per_fn.len()); - for (j, (group, width, rows, hits)) in per_fn.iter().enumerate() { - // (String × Nat × Nat × Nat) — right-nested pair encoding + for (j, (group_idx, width, rows, hits)) in per_fn.iter().enumerate() { + // (Nat × Nat × Nat × Nat) — right-nested pair encoding let quad = LeanProd::new( - LeanString::new(group.as_ref()), + LeanOwned::box_usize(*group_idx), LeanProd::new( LeanOwned::box_usize(*width), LeanProd::new( @@ -207,7 +207,7 @@ extern "C" fn rs_aiur_toplevel_execute( let lean_io = build_lean_io_buffer(&io_buffer); // (Array G, (Array G × Array (Array G × IOKeyInfo), - // Array (Array (String × Nat × Nat × Nat)) × Array (Nat × Nat))) + // Array (Array (Nat × Nat × Nat × Nat)) × Array (Nat × Nat))) let io_counts = LeanProd::new(lean_io, lean_query_counts); let result = LeanProd::new(build_g_array(&output), io_counts); LeanExcept::ok(result) diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index 00f1ce13..ed281ff2 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -1,11 +1,8 @@ -use 
std::sync::Arc; - use multi_stark::p3_field::PrimeCharacteristicRing; -use rustc_hash::FxHashMap; use lean_ffi::object::{LeanBorrowed, LeanCtor, LeanRef}; -use crate::lean::LeanAiurFunction; +use crate::lean::{LeanAiurCtrl, LeanAiurFunction}; use crate::{ FxIndexMap, @@ -18,26 +15,6 @@ use crate::{ use crate::ffi::aiur::{lean_unbox_g, lean_unbox_nat_as_usize}; -/// Per-decode interner for return-group names. Shared across every -/// `Ctrl::Return` decoded inside the `functions` tree and every key in -/// `filtered_functions`, so the same group name resolves to a single -/// `Arc` allocation. -#[derive(Default)] -struct GroupInterner { - map: FxHashMap>, -} - -impl GroupInterner { - fn intern(&mut self, s: &str) -> Arc { - if let Some(existing) = self.map.get(s) { - return existing.clone(); - } - let arc: Arc = Arc::from(s); - self.map.insert(s.to_string(), arc.clone()); - arc - } -} - fn decode_vec_val_idx(obj: LeanBorrowed<'_>) -> Vec { obj.as_array().map(|x| lean_unbox_nat_as_usize(&x)) } @@ -167,80 +144,78 @@ fn decode_op(ctor: LeanCtor>) -> Op { } } -fn decode_g_block_pair( - ctor: LeanCtor>, - interner: &mut GroupInterner, -) -> (G, Block) { +fn decode_g_block_pair(ctor: LeanCtor>) -> (G, Block) { let [g_obj, block_obj] = ctor.objs::<2>(); let g = lean_unbox_g(&g_obj); - let block = decode_block(block_obj.as_ctor(), interner); + let block = decode_block(block_obj.as_ctor()); (g, block) } -fn decode_ctrl( - ctor: LeanCtor>, - interner: &mut GroupInterner, -) -> Ctrl { - match ctor.tag() { +fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { + let typed = LeanAiurCtrl::from_ctor(ctor); + match typed.as_ctor().tag() { 0 => { - let [val_idx_obj, cases_obj, default_obj] = ctor.objs::<3>(); + let val_idx_obj = typed.get_obj(0); + let cases_obj = typed.get_obj(1); + let default_obj = typed.get_obj(2); let val_idx = lean_unbox_nat_as_usize(&val_idx_obj); let cases: FxIndexMap = cases_obj .as_array() - .iter() - .map(|o| decode_g_block_pair(o.as_ctor(), interner)) + .map(|o| 
decode_g_block_pair(o.as_ctor())) + .into_iter() .collect(); let default = if default_obj.is_scalar() { None } else { let inner_ctor = default_obj.as_ctor(); - let block = decode_block(inner_ctor.get(0).as_ctor(), interner); + let block = decode_block(inner_ctor.get(0).as_ctor()); Some(Box::new(block)) }; Ctrl::Match(val_idx, cases, default) }, 1 => { - let [sel_idx_obj, group_obj, val_idxs_obj] = ctor.objs::<3>(); + // `Ctrl.return : SelIdx → USize → Array ValIdx → Ctrl` — `sel_idx` and + // `val_idxs` are boxed; `group` is a `USize` scalar field declared + // via `num_usize: 1` in `LeanAiurCtrl`. + let sel_idx_obj = typed.get_obj(0); + let val_idxs_obj = typed.get_obj(1); let sel_idx = lean_unbox_nat_as_usize(&sel_idx_obj); - let group_lean = group_obj.as_string(); - let group = interner.intern(group_lean.as_str()); + let group = typed.get_usize(0); let val_idxs = decode_vec_val_idx(val_idxs_obj); Ctrl::Return(sel_idx, group, val_idxs) }, 2 => { - let [sel_idx_obj, val_idxs_obj] = ctor.objs::<2>(); + let sel_idx_obj = typed.get_obj(0); + let val_idxs_obj = typed.get_obj(1); let sel_idx = lean_unbox_nat_as_usize(&sel_idx_obj); let val_idxs = decode_vec_val_idx(val_idxs_obj); Ctrl::Yield(sel_idx, val_idxs) }, 3 => { - let [ - val_idx_obj, - cases_obj, - default_obj, - output_size_obj, - shared_aux_obj, - shared_lookups_obj, - cont_obj, - ] = ctor.objs::<7>(); + let val_idx_obj = typed.get_obj(0); + let cases_obj = typed.get_obj(1); + let default_obj = typed.get_obj(2); + let output_size_obj = typed.get_obj(3); + let shared_aux_obj = typed.get_obj(4); + let shared_lookups_obj = typed.get_obj(5); + let cont_obj = typed.get_obj(6); let val_idx = lean_unbox_nat_as_usize(&val_idx_obj); let cases: FxIndexMap = cases_obj .as_array() - .iter() - .map(|o| decode_g_block_pair(o.as_ctor(), interner)) + .map(|o| decode_g_block_pair(o.as_ctor())) + .into_iter() .collect(); let default = if default_obj.is_scalar() { None } else { let inner_ctor = default_obj.as_ctor(); - let 
block = decode_block(inner_ctor.get(0).as_ctor(), interner); + let block = decode_block(inner_ctor.get(0).as_ctor()); Some(Box::new(block)) }; let output_size = lean_unbox_nat_as_usize(&output_size_obj); let shared_aux = lean_unbox_nat_as_usize(&shared_aux_obj); let shared_lookups = lean_unbox_nat_as_usize(&shared_lookups_obj); - let continuation = - Box::new(decode_block(cont_obj.as_ctor(), interner)); + let continuation = Box::new(decode_block(cont_obj.as_ctor())); Ctrl::MatchContinue( val_idx, cases, @@ -255,13 +230,10 @@ fn decode_ctrl( } } -fn decode_block( - ctor: LeanCtor>, - interner: &mut GroupInterner, -) -> Block { +fn decode_block(ctor: LeanCtor>) -> Block { let [ops_obj, ctrl_obj] = ctor.objs::<2>(); let ops = ops_obj.as_array().map(|o| decode_op(o.as_ctor())); - let ctrl = decode_ctrl(ctrl_obj.as_ctor(), interner); + let ctrl = decode_ctrl(ctrl_obj.as_ctor()); Block { ops, ctrl } } @@ -275,59 +247,32 @@ fn decode_function_layout(ctor: LeanCtor>) -> FunctionLayout { } } -fn decode_function( - ctor: LeanCtor>, - interner: &mut GroupInterner, -) -> Function { +fn decode_function(ctor: LeanCtor>) -> Function { + // Lean `Bytecode.Function` has 3 boxed fields (body, layout, groupNames) + // and 2 scalar booleans (entry, constrained). The `groupNames` slot is the + // display table for return-group indices and is only consumed Lean-side + // (via `Statistics.computeStats`), so the Rust decoder skips it. let ctor = LeanAiurFunction::from_ctor(ctor); - let body = decode_block(ctor.get_obj(0).as_ctor(), interner); + let body = decode_block(ctor.get_obj(0).as_ctor()); let layout = decode_function_layout(ctor.get_obj(1).as_ctor()); let entry = ctor.get_num_8(0) != 0; let constrained = ctor.get_num_8(1) != 0; Function { body, layout, entry, constrained } } -/// Decode a single `(String × Function)` Lean product into its Rust form, -/// sharing the group `Arc` with every other site that mentions the -/// same name (via the interner). 
-fn decode_group_function_pair( - ctor: LeanCtor>, - interner: &mut GroupInterner, -) -> (Arc, Function) { - let [group_obj, fn_obj] = ctor.objs::<2>(); - let group = interner.intern(group_obj.as_string().as_str()); - let function = decode_function(fn_obj.as_ctor(), interner); - (group, function) -} - pub(crate) fn decode_toplevel( obj: &LeanAiurToplevel, ) -> Toplevel { let ctor = obj.as_ctor(); let [functions_obj, filtered_functions_obj, memory_sizes_obj] = ctor.objs::<3>(); - let mut interner = GroupInterner::default(); - // `LeanArray::map` requires `Fn`, so we cannot borrow `interner` mutably - // inside it. Use `iter()` for the interner-touching layers. - let functions: Vec = functions_obj - .as_array() - .iter() - .map(|o| decode_function(o.as_ctor(), &mut interner)) - .collect(); - let filtered_functions: Vec, Function>> = - filtered_functions_obj - .as_array() - .iter() - .map(|inner_obj| { - inner_obj - .as_array() - .iter() - .map(|pair_obj| { - decode_group_function_pair(pair_obj.as_ctor(), &mut interner) - }) - .collect::>() - }) - .collect(); + let functions: Vec = + functions_obj.as_array().map(|o| decode_function(o.as_ctor())); + // `filteredFunctions : Array (Array Function)` — positional by group index. + let filtered_functions: Vec> = + filtered_functions_obj.as_array().map(|inner_obj| { + inner_obj.as_array().map(|fn_obj| decode_function(fn_obj.as_ctor())) + }); let memory_sizes = memory_sizes_obj.as_array().map(|x| lean_unbox_nat_as_usize(&x)); Toplevel { functions, memory_sizes, filtered_functions } diff --git a/src/lean.rs b/src/lean.rs index 24190ee6..b02320aa 100644 --- a/src/lean.rs +++ b/src/lean.rs @@ -235,7 +235,17 @@ lean_ffi::lean_inductive! { // --- Aiur types --- LeanAiurToplevel [ { num_obj: 3 } ]; - LeanAiurFunction [ { num_obj: 2, num_8: 2 } ]; + LeanAiurFunction [ { num_obj: 3, num_8: 2 } ]; + + // `Bytecode.Ctrl` — `return` carries a `USize` group index as a scalar + // ctor field. 
Declared so `get_usize(0)` is bounds-checked against + // `num_usize: 1` rather than reaching into raw memory. + LeanAiurCtrl [ + { num_obj: 3 }, // tag 0: match + { num_obj: 2, num_usize: 1 }, // tag 1: return + { num_obj: 2 }, // tag 2: yield + { num_obj: 7 }, // tag 3: matchContinue + ]; // --- Block / comparison types --- From 42c4cebc7d4ec131e520a92097c37bdab19f9688 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Wed, 20 May 2026 07:17:50 -0700 Subject: [PATCH 12/13] Use named structures for QueryCounts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the anonymous `Nat × Nat × Nat × Nat` quads / `Nat × Nat` pairs shipped from the executor with `GroupStats { groupIdx, totalWidth, uniqueRows, totalHits }` and `MemoryCount { uniqueRows, totalHits }`, bundled into a `QueryCounts { functionStats, memoryCounts }` struct. The FFI extern keeps the tuple shape (so the Rust side can build it without matching a Lean structure ctor); `Bytecode.Toplevel.execute` converts to the structured form at the public API boundary. Callers (`Kernel.lean`, `Statistics.computeStats`) lose the `.1`/`.2.2.1` indexing and use named field access. --- Ix/Aiur/Semantics/BytecodeFfi.lean | 65 +++++++++++++++++++----------- Ix/Aiur/Statistics.lean | 29 ++++++------- Kernel.lean | 2 +- 3 files changed, 54 insertions(+), 42 deletions(-) diff --git a/Ix/Aiur/Semantics/BytecodeFfi.lean b/Ix/Aiur/Semantics/BytecodeFfi.lean index 2a67a96e..1d3fd1dd 100644 --- a/Ix/Aiur/Semantics/BytecodeFfi.lean +++ b/Ix/Aiur/Semantics/BytecodeFfi.lean @@ -50,42 +50,53 @@ instance : BEq IOBuffer where -- via `Std.HashMap.beq_iff_equiv` + `Std.HashMap.Equiv.{refl,symm,trans}`, -- bypassing the need for `LawfulBEq` on the outer `IOBuffer`. -/-- Per-circuit query counts for one circuit (one per function circuit, then -one per memory size). `uniqueRows` is the trace height; `totalHits` is the sum -of query multiplicities. 
The difference `totalHits - uniqueRows` is the number -of cache hits. -/ -structure QueryCount where +namespace Bytecode.Toplevel + +/-- Per-split execution stats for one return-group of one function. +`groupIdx` keys the corresponding entry in `Function.groupNames`. -/ +structure GroupStats where + groupIdx : Nat + totalWidth : Nat uniqueRows : Nat totalHits : Nat - deriving Inhabited + deriving Inhabited, Repr -namespace Bytecode.Toplevel +/-- Per-memory-size counts. `uniqueRows` is the trace height; `totalHits` +is the sum of multiplicities. `totalHits - uniqueRows` is the cache-hit +count. -/ +structure MemoryCount where + uniqueRows : Nat + totalHits : Nat + deriving Inhabited, Repr -/-- Per-function execution stats. One entry per split (return group), keyed -by group index. Each quadruple is -`(groupIdx, totalWidth, uniqueRows, totalHits)`. The display name is looked -up via `Function.groupNames[groupIdx]`. -/ -abbrev FunctionStats := Array (Array (Nat × Nat × Nat × Nat)) +/-- Per-function execution stats. Outer index is `FunIdx`; inner index is +the `USize` group index used by `Ctrl.return`. -/ +abbrev FunctionStats := Array (Array GroupStats) -/-- Per-memory-size `(uniqueRows, totalHits)` pairs. -/ -abbrev MemoryCounts := Array (Nat × Nat) +/-- Per-memory-size counts, parallel to `Toplevel.memorySizes`. -/ +abbrev MemoryCounts := Array MemoryCount -/-- Query counts shipped back from the Rust executor: per-function split stats -plus per-memory pairs. -/ -abbrev QueryCounts := FunctionStats × MemoryCounts +/-- Query counts shipped back from the Rust executor. -/ +structure QueryCounts where + functionStats : FunctionStats + memoryCounts : MemoryCounts + deriving Inhabited +/-- Raw FFI tuple shape — kept tuple-flat so the Rust side can build it +without declaring matching Lean structure ctors. `execute` wraps the +result in the structured `QueryCounts` immediately. 
-/ @[extern "rs_aiur_toplevel_execute"] private opaque execute' : @& Bytecode.Toplevel → @& Bytecode.FunIdx → @& Array G → (ioData : @& Array G) → (ioMap : @& Array (Array G × IOKeyInfo)) → - Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × QueryCounts) + Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) + × (Array (Array (Nat × Nat × Nat × Nat)) × Array (Nat × Nat))) /-- Executes the bytecode function `funIdx` with the given `args` and `ioBuffer`, returning the raw output of the function, the updated `IOBuffer`, and a -`QueryCounts` (per-function split stats + per-memory `(uniqueRows, totalHits)` -pairs). Returns `Except.error msg` when execution fails (e.g. `assert_eq!` -mismatch from a typechecker rejecting a constant), so callers can recover -instead of crashing. -/ +`QueryCounts`. Returns `Except.error msg` when execution fails (e.g. +`assert_eq!` mismatch from a typechecker rejecting a constant), so callers +can recover instead of crashing. -/ def execute (toplevel : @& Bytecode.Toplevel) (funIdx : @& Bytecode.FunIdx) (args : @& Array G) (ioBuffer : IOBuffer) : Except String (Array G × IOBuffer × QueryCounts) := @@ -93,9 +104,15 @@ def execute (toplevel : @& Bytecode.Toplevel) let ioMap := ioBuffer.map match execute' toplevel funIdx args ioData ioMap.toArray with | .error e => .error e - | .ok (output, (ioData, ioMap), queryCounts) => + | .ok (output, (ioData, ioMap), rawFn, rawMem) => let ioMap := ioMap.foldl (fun acc (k, v) => acc.insert k v) ∅ - .ok (output, ⟨ioData, ioMap⟩, queryCounts) + let functionStats : FunctionStats := rawFn.map fun perFn => + perFn.map fun quad => + { groupIdx := quad.1, totalWidth := quad.2.1, + uniqueRows := quad.2.2.1, totalHits := quad.2.2.2 } + let memoryCounts : MemoryCounts := rawMem.map fun pair => + { uniqueRows := pair.1, totalHits := pair.2 } + .ok (output, ⟨ioData, ioMap⟩, { functionStats, memoryCounts }) end Bytecode.Toplevel diff --git a/Ix/Aiur/Statistics.lean 
b/Ix/Aiur/Statistics.lean index e90ffe8c..5de795f6 100644 --- a/Ix/Aiur/Statistics.lean +++ b/Ix/Aiur/Statistics.lean @@ -39,9 +39,8 @@ def fftCost (w h : Nat) : Float := let hf := h.toFloat wf * hf * (max hf 2.0).log2 -def computeStats (compiled : CompiledToplevel) - (functionStats : Array (Array (Nat × Nat × Nat × Nat))) - (memoryCounts : Array (Nat × Nat)) : ExecutionStats := +def computeStats (compiled : CompiledToplevel) (qc : Bytecode.Toplevel.QueryCounts) : + ExecutionStats := let t := compiled.bytecode -- Invert nameMap to get FunIdx → String let reverseMap := compiled.nameMap.fold (init := (∅ : Std.HashMap Bytecode.FunIdx String)) @@ -53,24 +52,20 @@ def computeStats (compiled : CompiledToplevel) if t.functions[i]!.constrained then let baseName := reverseMap[i]?.getD s!"" let groupNames := t.functions[i]!.groupNames - for quad in functionStats[i]! do - let groupIdx := quad.1 - let w := quad.2.1 - let h := quad.2.2.1 - let totalHits := quad.2.2.2 - let hits := totalHits - h - let group := groupNames[groupIdx]?.getD "" + for gs in qc.functionStats[i]! do + let group := groupNames[gs.groupIdx]?.getD "" let name := if group.isEmpty then baseName else s!"{baseName} [{group}]" - acc := acc.push { name, width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } + let hits := gs.totalHits - gs.uniqueRows + acc := acc.push + { name, width := gs.totalWidth, height := gs.uniqueRows, + cacheHits := hits, fftCost := fftCost gs.totalWidth gs.uniqueRows : CircuitStats } acc let memoryCircuits := t.memorySizes.mapIdx fun i size => let w := size + 11 - let pair := memoryCounts[i]! - let h := pair.1 - let totalHits := pair.2 - let hits := totalHits - h - { name := s!"memory[{size}]", - width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } + let mc := qc.memoryCounts[i]! 
+ let hits := mc.totalHits - mc.uniqueRows + { name := s!"memory[{size}]", width := w, height := mc.uniqueRows, + cacheHits := hits, fftCost := fftCost w mc.uniqueRows : CircuitStats } let circuits := (functionCircuits ++ memoryCircuits).qsort (·.fftCost > ·.fftCost) let totalFftCost := circuits.foldl (· + ·.fftCost) 0.0 let totalUncachedFftCost := circuits.foldl (fun acc cs => acc + fftCost cs.width (cs.height + cs.cacheHits)) 0.0 diff --git a/Kernel.lean b/Kernel.lean index 3b7634c3..54fe8019 100644 --- a/Kernel.lean +++ b/Kernel.lean @@ -59,7 +59,7 @@ where if ioBuffer != testCase.expectedIOBuffer then IO.eprintln s!"{name}: IOBuffer mismatch" return 1 - let stats := Aiur.computeStats compiled queryCounts.1 queryCounts.2 + let stats := Aiur.computeStats compiled queryCounts Aiur.printStats stats pure 0 interpCheck decls name env : IO UInt32 := do From 60c38d472fc975cffbab7b2fefa99c579c4dd5f0 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Wed, 20 May 2026 09:50:25 -0700 Subject: [PATCH 13/13] Preserve groupNames through dedup + match arms MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two sites dropped the per-function `groupNames` table: 1. `Dedup.deduplicate_newFunctions` rebuilt the `Function` struct as `{ body, layout, entry, constrained }`, letting the field default `:= #[""]` collapse the table to a singleton. Every dedup'd function then split into one group at idx 0, so multi-group stats showed a single row instead of one per path. 2. `Lower.addCase` snapshotted the compiler state per arm and restored `initState` (patching only `selIdx`). Allocations made by `allocCurrentGroup` inside the arm — pushes to `groupNames` / `groupNameMap` — were discarded, so every arm reused index 0 and the final table held only the surviving allocation. Fix: thread `groupNames`/`groupNameMap` through both rollbacks. 
Verified via `lake exe kernel Nat.add_comm` with three return-group annotations on `blake3_compress_chunks` — stats now show one row per split. --- Ix/Aiur/Compiler/Dedup.lean | 3 ++- Ix/Aiur/Compiler/Lower.lean | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/Ix/Aiur/Compiler/Dedup.lean b/Ix/Aiur/Compiler/Dedup.lean index 6b364d9a..ecb15231 100644 --- a/Ix/Aiur/Compiler/Dedup.lean +++ b/Ix/Aiur/Compiler/Dedup.lean @@ -196,7 +196,8 @@ def deduplicate_newFunctions (functions : Array Function) (classes : Array Nat) if can then let entry := deduplicate_class_entry functions classes cls let body := rewriteBlock remapFn f.body - acc.push { body, layout := f.layout, entry, constrained := false } + acc.push { body, layout := f.layout, groupNames := f.groupNames, + entry, constrained := false } else acc) #[] diff --git a/Ix/Aiur/Compiler/Lower.lean b/Ix/Aiur/Compiler/Lower.lean index 27f70ed0..6666eb0d 100644 --- a/Ix/Aiur/Compiler/Lower.lean +++ b/Ix/Aiur/Compiler/Lower.lean @@ -530,7 +530,10 @@ def Concrete.addCase | .field g => do let initState ← get let term ← term.compile returnTyp layoutMap bindings yieldCtrl - set { initState with selIdx := (← get).selIdx } + let cur ← get + set { initState with selIdx := cur.selIdx, + groupNames := cur.groupNames, + groupNameMap := cur.groupNameMap } pure (cases.push (g, term), defaultBlock) | .ref global pats => do let (index, offsets) ← match layoutMap[global]? 
with @@ -549,12 +552,18 @@ def Concrete.addCase acc.insert patLocal slice let initState ← get let term ← term.compile returnTyp layoutMap ptrBindings yieldCtrl - set { initState with selIdx := (← get).selIdx } + let cur ← get + set { initState with selIdx := cur.selIdx, + groupNames := cur.groupNames, + groupNameMap := cur.groupNameMap } pure (cases.push (.ofNat index, term), defaultBlock) | .wildcard => do let initState ← get let term ← term.compile returnTyp layoutMap bindings yieldCtrl - set { initState with selIdx := (← get).selIdx } + let cur ← get + set { initState with selIdx := cur.selIdx, + groupNames := cur.groupNames, + groupNameMap := cur.groupNameMap } pure (cases, .some term) | _ => throw "addCase: unsupported pattern in concrete lower" termination_by _ pair => (sizeOf pair.snd, 1)