diff --git a/Ix/Aiur.lean b/Ix/Aiur.lean index 67949f1e..6330b32f 100644 --- a/Ix/Aiur.lean +++ b/Ix/Aiur.lean @@ -20,5 +20,6 @@ public import Ix.Aiur.Compiler.Concretize public import Ix.Aiur.Compiler.Layout public import Ix.Aiur.Compiler.Lower public import Ix.Aiur.Compiler.Dedup +public import Ix.Aiur.Compiler.Split public import Ix.Aiur.Compiler public import Ix.Aiur.Statistics diff --git a/Ix/Aiur/Compiler.lean b/Ix/Aiur/Compiler.lean index ae250399..3a008b73 100644 --- a/Ix/Aiur/Compiler.lean +++ b/Ix/Aiur/Compiler.lean @@ -3,6 +3,7 @@ public import Ix.Aiur.Compiler.Lower public import Ix.Aiur.Compiler.Dedup public import Ix.Aiur.Compiler.Concretize public import Ix.Aiur.Compiler.Simple +public import Ix.Aiur.Compiler.Split /-! Aiur compiler pipeline: type-check, simplify, concretize, lower, deduplicate. @@ -68,7 +69,7 @@ def Bytecode.Ctrl.collectConstrainedCallees (c : Bytecode.Ctrl) : | some block => branchCallees ++ block.collectConstrainedCallees | none => branchCallees withDefault ++ continuation.collectConstrainedCallees - | .return _ _ | .yield _ _ => #[] + | .return _ _ _ | .yield _ _ => #[] termination_by (sizeOf c, 0) decreasing_by all_goals first @@ -117,9 +118,10 @@ def Source.Toplevel.compile (t : Source.Toplevel) : Except String CompiledToplev let (bytecodeRaw, preNameMap) ← concDecls.toBytecode let (bytecodeDedup, remap) := bytecodeRaw.deduplicate let needs := bytecodeDedup.needsCircuit - let bytecode := { bytecodeDedup with + let bytecodeConstrained : Bytecode.Toplevel := { bytecodeDedup with functions := bytecodeDedup.functions.mapIdx fun i f => { f with constrained := needs[i]! 
} } + let bytecode := bytecodeConstrained.computeFiltered let nameMap := preNameMap.fold (init := (∅ : Std.HashMap Global Bytecode.FunIdx)) fun acc name idx => acc.insert name (remap idx) pure (CompiledToplevel.mk t bytecode nameMap) diff --git a/Ix/Aiur/Compiler/Check.lean b/Ix/Aiur/Compiler/Check.lean index 6970abbb..eeaacc03 100644 --- a/Ix/Aiur/Compiler/Check.lean +++ b/Ix/Aiur/Compiler/Check.lean @@ -790,6 +790,9 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with | some sub => do pure (some (← inferTerm sub)) let ret' ← inferTerm ret pure (Typed.Term.debug ret'.typ ret'.escapes label term' ret') + | .retGroup name inner => do + let inner' ← inferTerm inner + pure (Typed.Term.retGroup inner'.typ inner'.escapes name inner') termination_by (sizeOf t, 0) decreasing_by all_goals first @@ -917,6 +920,8 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with | none => pure none | some sub => do pure (some (← zonkTypedTerm sub)) pure (.debug (← zonkTyp τ) e label t' (← zonkTypedTerm r)) + | .retGroup τ e name inner => do + pure (.retGroup (← zonkTyp τ) e name (← zonkTypedTerm inner)) termination_by sizeOf t decreasing_by all_goals first diff --git a/Ix/Aiur/Compiler/Concretize.lean b/Ix/Aiur/Compiler/Concretize.lean index b9924aee..f6931ae9 100644 --- a/Ix/Aiur/Compiler/Concretize.lean +++ b/Ix/Aiur/Compiler/Concretize.lean @@ -349,6 +349,8 @@ def termToConcrete | none => pure none | some sub => do pure (some (← termToConcrete mono sub)) pure (.debug (← typToConcrete mono τ) e l t' (← termToConcrete mono r)) + | .retGroup τ e name inner => do + pure (.retGroup (← typToConcrete mono τ) e name (← termToConcrete mono inner)) termination_by t => sizeOf t decreasing_by all_goals first @@ -541,6 +543,8 @@ def rewriteTypedTerm (decls : Typed.Decls) | none => none | some sub => some (rewriteTypedTerm decls subst mono sub) .debug (rewriteTyp subst mono τ) e l t' (rewriteTypedTerm decls subst mono r) + | .retGroup τ e name inner => + .retGroup 
(rewriteTyp subst mono τ) e name (rewriteTypedTerm decls subst mono inner) termination_by t => sizeOf t decreasing_by all_goals first @@ -625,6 +629,7 @@ def collectInTypedTerm (seen : Std.HashSet (Global × Array Typ)) : let seen := collectInTyp seen τ let seen := match t with | some t => collectInTypedTerm seen t | none => seen collectInTypedTerm seen r + | .retGroup τ _ _ inner => collectInTypedTerm (collectInTyp seen τ) inner termination_by t => sizeOf t decreasing_by all_goals first @@ -683,6 +688,7 @@ def collectCalls (decls : Typed.Decls) | .debug _ _ _ t r => let seen := match t with | some t => collectCalls decls seen t | none => seen collectCalls decls seen r + | .retGroup _ _ _ inner => collectCalls decls seen inner termination_by t => sizeOf t decreasing_by all_goals first @@ -771,6 +777,8 @@ def substInTypedTerm (subst : Global → Option Typ) : Typed.Term → Typed.Term | none => none | some sub => some (substInTypedTerm subst sub) .debug (Typ.instantiate subst τ) e l t' (substInTypedTerm subst r) + | .retGroup τ e name inner => + .retGroup (Typ.instantiate subst τ) e name (substInTypedTerm subst inner) termination_by t => sizeOf t decreasing_by all_goals first diff --git a/Ix/Aiur/Compiler/Dedup.lean b/Ix/Aiur/Compiler/Dedup.lean index 8a9d70bf..ecb15231 100644 --- a/Ix/Aiur/Compiler/Dedup.lean +++ b/Ix/Aiur/Compiler/Dedup.lean @@ -31,7 +31,7 @@ mutual | .match v branches def_ => .match v (branches.attach.map fun ⟨(g, b), _⟩ => (g, skeletonBlock b)) (match def_ with | none => none | some b => some (skeletonBlock b)) - | .return s vs => .return s vs + | .return s g vs => .return s g vs | .yield s vs => .yield s vs | .matchContinue v branches def_ outputSize sharedAux sharedLookups cont => .matchContinue v (branches.attach.map fun ⟨(g, b), _⟩ => (g, skeletonBlock b)) @@ -107,7 +107,7 @@ mutual | .match v branches def_ => .match v (branches.attach.map fun ⟨(g, b), _⟩ => (g, rewriteBlock f b)) (match def_ with | none => none | some b => some (rewriteBlock f 
b)) - | .return s vs => .return s vs + | .return s g vs => .return s g vs | .yield s vs => .yield s vs | .matchContinue v branches def_ outputSize sharedAux sharedLookups cont => .matchContinue v (branches.attach.map fun ⟨(g, b), _⟩ => (g, rewriteBlock f b)) @@ -196,7 +196,8 @@ def deduplicate_newFunctions (functions : Array Function) (classes : Array Nat) if can then let entry := deduplicate_class_entry functions classes cls let body := rewriteBlock remapFn f.body - acc.push { body, layout := f.layout, entry, constrained := false } + acc.push { body, layout := f.layout, groupNames := f.groupNames, + entry, constrained := false } else acc) #[] diff --git a/Ix/Aiur/Compiler/Lower.lean b/Ix/Aiur/Compiler/Lower.lean index a50f4ef2..6666eb0d 100644 --- a/Ix/Aiur/Compiler/Lower.lean +++ b/Ix/Aiur/Compiler/Lower.lean @@ -94,6 +94,15 @@ structure CompilerState where ops : Array Bytecode.Op selIdx : Bytecode.SelIdx degrees : Array Nat + /-- Top of the `#[return_group(…)]` annotation stack — the display name + whose index will tag every `Ctrl.return` emitted inside its scope. The + index is looked up (or allocated) lazily at emit time. -/ + currentReturnGroupName : String := "" + /-- Group display names allocated so far; position `i` is the name for + group index `i`. -/ + groupNames : Array String := #[] + /-- Inverse of `groupNames`: maps name → allocated index. -/ + groupNameMap : Std.HashMap String USize := {} deriving Inhabited abbrev CompileM := EStateM String CompilerState @@ -122,6 +131,20 @@ def pushOp (op : Bytecode.Op) (size : Nat := 1) : CompileM (Array Bytecode.ValId def extractOps : CompileM (Array Bytecode.Op) := modifyGet fun s => (s.ops, {s with ops := #[]}) +/-- Look up the `USize` index for the current return-group name, allocating +fresh storage in `groupNames`/`groupNameMap` on first encounter. -/ +def allocCurrentGroup : CompileM USize := do + let st ← get + let name := st.currentReturnGroupName + match st.groupNameMap[name]? 
with + | some idx => pure idx + | none => + let idx : USize := USize.ofNat st.groupNames.size + modify fun s => { s with + groupNameMap := s.groupNameMap.insert name idx + groupNames := s.groupNames.push name } + pure idx + open Concrete in mutual @@ -299,6 +322,7 @@ def toIndex | some sub => do pure (some (← toIndex layoutMap bindings sub)) modify fun stt => { stt with ops := stt.ops.push (.debug label term) } toIndex layoutMap bindings ret + | .retGroup _ _ _ inner => toIndex layoutMap bindings inner termination_by (sizeOf term, 0) decreasing_by all_goals first @@ -446,6 +470,12 @@ def Concrete.Term.compile let data ← toIndex layoutMap bindings data modify fun stt => { stt with ops := stt.ops.push (.ioWrite data) } ret.compile returnTyp layoutMap bindings yieldCtrl + | .retGroup _ _ name inner => do + let oldGroup := (← get).currentReturnGroupName + modify fun s => { s with currentReturnGroupName := name } + let blk ← inner.compile returnTyp layoutMap bindings yieldCtrl + modify fun s => { s with currentReturnGroupName := oldGroup } + pure blk | .match _ _ scrut cases defaultOpt => do let idxs := bindings[scrut]?.getD #[0] let ops ← extractOps @@ -460,21 +490,24 @@ def Concrete.Term.compile pure ({ ops, ctrl } : Bytecode.Block) | .ret _ _ term => do let idxs ← toIndex layoutMap bindings term + let groupIdx ← allocCurrentGroup let state ← get let state := { state with selIdx := state.selIdx + 1 } set state let ops := state.ops let id := state.selIdx - pure ({ ops, ctrl := .return (id - 1) idxs } : Bytecode.Block) + pure ({ ops, ctrl := .return (id - 1) groupIdx idxs } : Bytecode.Block) | _ => do let idxs ← toIndex layoutMap bindings term + let groupIdx : USize ← if yieldCtrl && !term.escapes then pure 0 else allocCurrentGroup /- a `yield` carries no group tag, so only the `return` path may allocate an index in `groupNames`; the `0` is a placeholder that the `.yield` ctrl below never stores -/ let state ← get let state := { state with selIdx := state.selIdx + 1 } set state let ops := state.ops let id := state.selIdx let ctrl : Bytecode.Ctrl := - if yieldCtrl && !term.escapes then .yield (id - 1) idxs else .return (id - 1) idxs + if yieldCtrl && !term.escapes then .yield (id - 1) idxs + 
else .return (id - 1) groupIdx idxs pure ({ ops, ctrl } : Bytecode.Block) termination_by (sizeOf term, 0) decreasing_by @@ -497,7 +530,10 @@ def Concrete.addCase | .field g => do let initState ← get let term ← term.compile returnTyp layoutMap bindings yieldCtrl - set { initState with selIdx := (← get).selIdx } + let cur ← get + set { initState with selIdx := cur.selIdx, + groupNames := cur.groupNames, + groupNameMap := cur.groupNameMap } pure (cases.push (g, term), defaultBlock) | .ref global pats => do let (index, offsets) ← match layoutMap[global]? with @@ -516,12 +552,18 @@ def Concrete.addCase acc.insert patLocal slice let initState ← get let term ← term.compile returnTyp layoutMap ptrBindings yieldCtrl - set { initState with selIdx := (← get).selIdx } + let cur ← get + set { initState with selIdx := cur.selIdx, + groupNames := cur.groupNames, + groupNameMap := cur.groupNameMap } pure (cases.push (.ofNat index, term), defaultBlock) | .wildcard => do let initState ← get let term ← term.compile returnTyp layoutMap bindings yieldCtrl - set { initState with selIdx := (← get).selIdx } + let cur ← get + set { initState with selIdx := cur.selIdx, + groupNames := cur.groupNames, + groupNameMap := cur.groupNameMap } pure (cases, .some term) | _ => throw "addCase: unsupported pattern in concrete lower" termination_by _ pair => (sizeOf pair.snd, 1) @@ -529,9 +571,11 @@ decreasing_by all_goals first | decreasing_tactic | grind end -/-- Lower a full concrete function to bytecode. -/ +/-- Lower a full concrete function to bytecode. Returns the body, layout +state, and the per-function `groupNames` table (position `i` is the display +name for group index `i`). -/ def Concrete.Function.compile (layoutMap : LayoutMap) (f : Concrete.Function) : - Except String (Bytecode.Block × Bytecode.LayoutMState) := do + Except String (Bytecode.Block × Bytecode.LayoutMState × Array String) := do let (_inputSize, _outputSize) ← match layoutMap[f.name]? 
with | some (.function layout) => pure (layout.inputSize, layout.outputSize) | _ => throw s!"`{f.name}` should be a function" @@ -542,15 +586,16 @@ def Concrete.Function.compile (layoutMap : LayoutMap) (f : Concrete.Function) : | .ok len => pure len let indices := Array.range' valIdx len pure (valIdx + len, bindings.insert arg indices) - let state := { valIdx, selIdx := 0, ops := #[], degrees := Array.replicate valIdx 1 } + let state : CompilerState := { valIdx, selIdx := 0, ops := #[], + degrees := Array.replicate valIdx 1 } match f.body.compile f.output layoutMap bindings |>.run state with | .error e _ => throw e - | .ok body _ => + | .ok body finalState => let (_, layoutMState) := Bytecode.blockLayout body |>.run (.new valIdx) let layoutMState := { layoutMState with functionLayout := { layoutMState.functionLayout with lookups := layoutMState.functionLayout.lookups + 1 } } - pure (body, layoutMState) + pure (body, layoutMState, finalState.groupNames) def Concrete.Decls.toBytecode (decls : Concrete.Decls) : Except String (Bytecode.Toplevel × Std.HashMap Global Bytecode.FunIdx) := do @@ -559,13 +604,16 @@ def Concrete.Decls.toBytecode (decls : Concrete.Decls) : let (functions, memSizes, nameMap) ← decls.foldlM (init := (#[], initMemSizes, {})) fun acc@(functions, memSizes, nameMap) (_, decl) => match decl with | .function function => do - let (body, layoutMState) ← function.compile layout + let (body, layoutMState, groupNames) ← function.compile layout let nameMap := nameMap.insert function.name functions.size - let function := ⟨body, layoutMState.functionLayout, function.entry, false⟩ + let groupNames := if groupNames.isEmpty then #[""] else groupNames + let function : Bytecode.Function := + { body, layout := layoutMState.functionLayout, + groupNames, entry := function.entry, constrained := false } let memSizes := layoutMState.memSizes.fold (·.insert ·) memSizes pure (functions.push function, memSizes, nameMap) | _ => pure acc - pure (⟨functions, 
memSizes.toArray⟩, nameMap) + pure ({ functions, memorySizes := memSizes.toArray : Bytecode.Toplevel }, nameMap) end Aiur diff --git a/Ix/Aiur/Compiler/Match.lean b/Ix/Aiur/Compiler/Match.lean index 3a92d00e..15266385 100644 --- a/Ix/Aiur/Compiler/Match.lean +++ b/Ix/Aiur/Compiler/Match.lean @@ -389,6 +389,7 @@ def typedToSimple : Term → Simple.Term | .debug τ e l t r => let t' := match t with | none => none | some sub => some (typedToSimple sub) .debug τ e l t' (typedToSimple r) + | .retGroup τ e name inner => .retGroup τ e name (typedToSimple inner) termination_by t => sizeOf t decreasing_by all_goals first | decreasing_tactic | grind diff --git a/Ix/Aiur/Compiler/Simple.lean b/Ix/Aiur/Compiler/Simple.lean index ec33f504..64f03477 100644 --- a/Ix/Aiur/Compiler/Simple.lean +++ b/Ix/Aiur/Compiler/Simple.lean @@ -108,6 +108,9 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term let a' ← simplifyTypedTerm decls a let b' ← simplifyTypedTerm decls b pure (.u32LessThan τ e a' b') + | .retGroup τ e name inner => do + let inner' ← simplifyTypedTerm decls inner + pure (.retGroup τ e name inner') | t => pure t termination_by t => sizeOf t decreasing_by diff --git a/Ix/Aiur/Compiler/Split.lean b/Ix/Aiur/Compiler/Split.lean new file mode 100644 index 00000000..b0f6d768 --- /dev/null +++ b/Ix/Aiur/Compiler/Split.lean @@ -0,0 +1,180 @@ +module +public import Ix.Aiur.Stages.Bytecode +public import Ix.Aiur.Compiler.Layout + +/-! +Return-group splitting for Aiur bytecode. + +Each `Bytecode.Function` may carry multiple `Ctrl.return` sites tagged with +distinct `USize` group indices. `Function.split` carves the function into one +filtered sub-function per index — control-flow paths that cannot reach a +`return` of the target group are pruned. The result is positional: index +`i` in the output equals the group index `i` in the function's +`groupNames` table. 
+-/ + +public section +@[expose] section + +namespace Aiur + +namespace Bytecode + +/-- Termination helper for the `Block`/`Ctrl` traversal below. -/ +private theorem Block.sizeOf_ctrl_lt_split (b : Block) : + sizeOf b.ctrl < sizeOf b := by + rcases b with ⟨ops, ctrl⟩ + show sizeOf ctrl < 1 + sizeOf ops + sizeOf ctrl + omega + +/-! ## Filter control-flow tree by target group + +Single (non-mutual) recursion on `Ctrl`. The conditional `Option`-typed +constructor wrap is factored into `mkMatch`/`mkMatchContinue` helpers so +that each match arm of `filterGroup` is a single function-call expression +— Lean can then derive the unfold equation by `rfl`. + +`Yield` branches always survive — they're anchored by the surrounding +`MatchContinue`'s continuation, not by terminal returns. For a `matchContinue` the continuation is filtered first: when it contains no `return` of the target group the whole node is dropped, including any case whose block itself ends in a `return` of that group. NOTE(review): confirm lowering never needs such a case to be kept — an escaping `return` inside a case would be pruned here. -/ + +def Ctrl.mkMatch (scrut : ValIdx) (cases : Array (G × Block)) + (default? : Option Block) : Option Ctrl := + if cases.isEmpty && default?.isNone then none + else some (.match scrut cases default?) + +def Ctrl.mkMatchContinue (scrut : ValIdx) (cases : Array (G × Block)) + (default? : Option Block) (outputSize sharedAux sharedLookups : Nat) + (cont : Block) : Option Ctrl := + if cases.isEmpty && default?.isNone then none + else some (.matchContinue scrut cases default? outputSize sharedAux sharedLookups cont) + +def Ctrl.filterGroup (target : USize) : Ctrl → Option Ctrl + | .return sel g vs => if g = target then some (.return sel g vs) else none + | .yield sel vs => some (.yield sel vs) + | .match scrut cases default? => + Ctrl.mkMatch scrut + (cases.attach.foldl (init := (#[] : Array (G × Block))) fun acc ⟨(k, blk), _⟩ => + match Ctrl.filterGroup target blk.ctrl with + | none => acc + | some c => acc.push (k, { blk with ctrl := c })) + (match default? with + | none => none + | some blk => (Ctrl.filterGroup target blk.ctrl).map ({ blk with ctrl := · })) + | .matchContinue scrut cases default? 
outputSize sharedAux sharedLookups cont => + (Ctrl.filterGroup target cont.ctrl).bind fun c => + Ctrl.mkMatchContinue scrut + (cases.attach.foldl (init := (#[] : Array (G × Block))) fun acc ⟨(k, blk), _⟩ => + match Ctrl.filterGroup target blk.ctrl with + | none => acc + | some c => acc.push (k, { blk with ctrl := c })) + (match default? with + | none => none + | some blk => (Ctrl.filterGroup target blk.ctrl).map ({ blk with ctrl := · })) + outputSize sharedAux sharedLookups + { cont with ctrl := c } +termination_by c => sizeOf c +decreasing_by + all_goals first + | decreasing_tactic + | (have hb := Array.sizeOf_lt_of_mem ‹_ ∈ _› + have hc := Block.sizeOf_ctrl_lt_split ‹Block› + grind) + | (have := Block.sizeOf_ctrl_lt_split ‹Block›; grind) + | grind + +def Block.filterGroup (target : USize) (b : Block) : Option Block := + (Ctrl.filterGroup target b.ctrl).map ({ b with ctrl := · }) + +/-! ## Renumber selectors in traversal order -/ + +/-- Single recursion on `Ctrl` threading a `Nat` counter for selectors. The +returned pair is `(rewrittenCtrl, nextCounter)`. -/ +def Ctrl.fixSelectors : Ctrl → Nat → Ctrl × Nat + | .return _ g vs, n => (.return n g vs, n + 1) + | .yield _ vs, n => (.yield n vs, n + 1) + | .match scrut cases default?, n => + let (newCases, n) := cases.attach.foldl + (init := ((#[] : Array (G × Block)), n)) + fun (acc, n) ⟨(k, blk), _⟩ => + let (c', n') := Ctrl.fixSelectors blk.ctrl n + (acc.push (k, { blk with ctrl := c' }), n') + let (newDefault, n) : Option Block × Nat := match default? with + | none => (none, n) + | some blk => + let (c', n') := Ctrl.fixSelectors blk.ctrl n + (some { blk with ctrl := c' }, n') + (.match scrut newCases newDefault, n) + | .matchContinue scrut cases default? 
outputSize sharedAux sharedLookups cont, n => + let (newCases, n) := cases.attach.foldl + (init := ((#[] : Array (G × Block)), n)) + fun (acc, n) ⟨(k, blk), _⟩ => + let (c', n') := Ctrl.fixSelectors blk.ctrl n + (acc.push (k, { blk with ctrl := c' }), n') + let (newDefault, n) : Option Block × Nat := match default? with + | none => (none, n) + | some blk => + let (c', n') := Ctrl.fixSelectors blk.ctrl n + (some { blk with ctrl := c' }, n') + let (c', n') := Ctrl.fixSelectors cont.ctrl n + let newCont : Block := { cont with ctrl := c' } + (.matchContinue scrut newCases newDefault outputSize sharedAux sharedLookups newCont, n') +termination_by c _ => sizeOf c +decreasing_by + all_goals first + | decreasing_tactic + | (have hb := Array.sizeOf_lt_of_mem ‹_ ∈ _› + have hc := Block.sizeOf_ctrl_lt_split ‹Block› + grind) + | (have := Block.sizeOf_ctrl_lt_split ‹Block›; grind) + | grind + +def Block.fixSelectors (b : Block) (n : Nat) : Block × Nat := + let (c', n') := Ctrl.fixSelectors b.ctrl n + ({ b with ctrl := c' }, n') + +/-! ## Layout recompute -/ + +/-- Recompute the `FunctionLayout` of `f` by replaying `blockLayout` over the +current body. Mirrors `Concrete.Function.compile`: `+1` lookup for the +function's own return lookup. -/ +def Function.recomputeLayout (f : Function) : FunctionLayout := + let (_, st) := Concrete.Bytecode.blockLayout f.body |>.run + (.new f.layout.inputSize) + { st.functionLayout with lookups := st.functionLayout.lookups + 1 } + +/-- Renumber selectors via DFS counter, then recompute layout. Intended for +functions produced by `filterGroup`, whose selector indices and layout are +inherited from the pre-filter function. -/ +def Function.fix (f : Function) : Function := + let (body', _) := Block.fixSelectors f.body 0 + let f := { f with body := body' } + { f with layout := f.recomputeLayout } + +/-! ## Top-level split -/ + +/-- Carve `f` into one filtered sub-function per group index in +`f.groupNames`. 
Position `i` in the result equals group index `i`. Panics +if a stored group index has no reachable `Return` site (should not happen +for well-formed bytecode produced by `Concrete.Function.compile`). -/ +def Function.split (f : Function) : Array Function := + (Array.range f.groupNames.size).map fun i => + let target : USize := USize.ofNat i + match Block.filterGroup target f.body with + | none => + panic! s!"function contains an unreachable group: {f.groupNames[i]!}" + | some body => + let filtered : Function := { f with body } + filtered.fix + +/-- Populate `t.filteredFunctions` by splitting every function. Idempotent — +overwrites any existing entries. Should run after `deduplicate` and +`needsCircuit` so the split reflects the final function set. -/ +def Toplevel.computeFiltered (t : Toplevel) : Toplevel := + { t with filteredFunctions := t.functions.map (·.split) } + +end Bytecode + +end Aiur + +end -- @[expose] section +end diff --git a/Ix/Aiur/Interpret.lean b/Ix/Aiur/Interpret.lean index 1926af1b..d7751b15 100644 --- a/Ix/Aiur/Interpret.lean +++ b/Ix/Aiur/Interpret.lean @@ -419,6 +419,7 @@ partial def interp (decls : Decls) (bindings : Bindings) : Term → InterpM Valu let dataGs ← expectFieldArray (← interp decls bindings data) modifyIOBuffer fun io => { io with data := io.data ++ dataGs } interp decls bindings ret + | .retGroup _ inner => interp decls bindings inner end diff --git a/Ix/Aiur/Meta.lean b/Ix/Aiur/Meta.lean index c85a5b94..5f3930e2 100644 --- a/Ix/Aiur/Meta.lean +++ b/Ix/Aiur/Meta.lean @@ -180,6 +180,7 @@ syntax "u8_or" "(" aiur_trm ", " aiur_trm ")" : ai syntax "u8_less_than" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "u32_less_than" "(" aiur_trm ", " aiur_trm ")" : aiur_trm syntax "dbg!" "(" str (", " aiur_trm)? ")" ";" (aiur_trm)? 
: aiur_trm +syntax "#[" "return_group" "(" ident ")" "]" aiur_trm : aiur_trm syntax aiur_trm "[" "@" noWs ident "]" : aiur_trm syntax "set" "(" aiur_trm ", " "@" noWs ident ", " aiur_trm ")" : aiur_trm @@ -310,6 +311,8 @@ partial def elabTrm : ElabStxCat `aiur_trm | none => mkAppOptM ``Option.none #[some (mkConst ``Source.Term)] | some t => mkAppM ``Option.some #[← elabTrm t] mkAppM ``Source.Term.debug #[mkStrLit label.getString, t, ← elabRet ret] + | `(aiur_trm| #[return_group($name:ident)] $t:aiur_trm) => do + mkAppM ``Source.Term.retGroup #[mkStrLit name.getId.toString, ← elabTrm t] -- Template function calls: explicit type args are dropped (inferred) | `(aiur_trm| $f:ident‹$_:aiur_typ $[, $_:aiur_typ]*›()) => do let g ← mkAppM ``Global.mk #[toExpr f.getId] @@ -511,6 +514,9 @@ where let t' ← t.mapM $ replaceToken old new let ret' ← ret.mapM $ replaceToken old new `(aiur_trm| dbg!($label $[, $t']?); $[$ret']?) + | `(aiur_trm| #[return_group($name:ident)] $t:aiur_trm) => do + let t ← replaceToken old new t + `(aiur_trm| #[return_group($name)] $t) | `(aiur_trm| fold($i .. $j, $init, |$acc, @$v| $body)) => do let init ← replaceToken old new init -- Don't conflict with shadowing tokens. 
diff --git a/Ix/Aiur/Semantics/BytecodeEval.lean b/Ix/Aiur/Semantics/BytecodeEval.lean index a45c377e..2db2cae5 100644 --- a/Ix/Aiur/Semantics/BytecodeEval.lean +++ b/Ix/Aiur/Semantics/BytecodeEval.lean @@ -301,7 +301,7 @@ decreasing_by def evalCtrl (t : Bytecode.Toplevel) (fuel : Nat) (ctrl : Ctrl) (st : EvalState) : Except BytecodeError (Array G × EvalState) := match ctrl with - | .return _ outs => + | .return _ _ outs => match readIdxs st outs with | .error e => .error e | .ok gs => .ok (gs, st) diff --git a/Ix/Aiur/Semantics/BytecodeFfi.lean b/Ix/Aiur/Semantics/BytecodeFfi.lean index 08a99894..1d3fd1dd 100644 --- a/Ix/Aiur/Semantics/BytecodeFfi.lean +++ b/Ix/Aiur/Semantics/BytecodeFfi.lean @@ -50,39 +50,69 @@ instance : BEq IOBuffer where -- via `Std.HashMap.beq_iff_equiv` + `Std.HashMap.Equiv.{refl,symm,trans}`, -- bypassing the need for `LawfulBEq` on the outer `IOBuffer`. -/-- Per-circuit query counts for one circuit (one per function circuit, then -one per memory size). `uniqueRows` is the trace height; `totalHits` is the sum -of query multiplicities. The difference `totalHits - uniqueRows` is the number -of cache hits. -/ -structure QueryCount where +namespace Bytecode.Toplevel + +/-- Per-split execution stats for one return-group of one function. +`groupIdx` keys the corresponding entry in `Function.groupNames`. -/ +structure GroupStats where + groupIdx : Nat + totalWidth : Nat uniqueRows : Nat totalHits : Nat - deriving Inhabited + deriving Inhabited, Repr -namespace Bytecode.Toplevel +/-- Per-memory-size counts. `uniqueRows` is the trace height; `totalHits` +is the sum of multiplicities. `totalHits - uniqueRows` is the cache-hit +count. -/ +structure MemoryCount where + uniqueRows : Nat + totalHits : Nat + deriving Inhabited, Repr + +/-- Per-function execution stats. Outer index is `FunIdx`; inner index is +the `USize` group index used by `Ctrl.return`. 
-/ +abbrev FunctionStats := Array (Array GroupStats) + +/-- Per-memory-size counts, parallel to `Toplevel.memorySizes`. -/ +abbrev MemoryCounts := Array MemoryCount + +/-- Query counts shipped back from the Rust executor. -/ +structure QueryCounts where + functionStats : FunctionStats + memoryCounts : MemoryCounts + deriving Inhabited +/-- Raw FFI tuple shape — kept tuple-flat so the Rust side can build it +without declaring matching Lean structure ctors. `execute` wraps the +result in the structured `QueryCounts` immediately. -/ @[extern "rs_aiur_toplevel_execute"] private opaque execute' : @& Bytecode.Toplevel → @& Bytecode.FunIdx → @& Array G → (ioData : @& Array G) → (ioMap : @& Array (Array G × IOKeyInfo)) → - Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array (Nat × Nat)) + Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) + × (Array (Array (Nat × Nat × Nat × Nat)) × Array (Nat × Nat))) /-- Executes the bytecode function `funIdx` with the given `args` and `ioBuffer`, -returning the raw output of the function, the updated `IOBuffer`, and an array -of per-circuit `QueryCount`s. Returns `Except.error msg` when execution -fails (e.g. `assert_eq!` mismatch from a typechecker rejecting a constant), so -callers can recover instead of crashing. -/ +returning the raw output of the function, the updated `IOBuffer`, and a +`QueryCounts`. Returns `Except.error msg` when execution fails (e.g. +`assert_eq!` mismatch from a typechecker rejecting a constant), so callers +can recover instead of crashing. 
-/ def execute (toplevel : @& Bytecode.Toplevel) (funIdx : @& Bytecode.FunIdx) (args : @& Array G) (ioBuffer : IOBuffer) : - Except String (Array G × IOBuffer × Array QueryCount) := + Except String (Array G × IOBuffer × QueryCounts) := let ioData := ioBuffer.data let ioMap := ioBuffer.map match execute' toplevel funIdx args ioData ioMap.toArray with | .error e => .error e - | .ok (output, (ioData, ioMap), queryCounts) => + | .ok (output, (ioData, ioMap), rawFn, rawMem) => let ioMap := ioMap.foldl (fun acc (k, v) => acc.insert k v) ∅ - let queryCounts := queryCounts.map fun (uniqueRows, totalHits) => { uniqueRows, totalHits } - .ok (output, ⟨ioData, ioMap⟩, queryCounts) + let functionStats : FunctionStats := rawFn.map fun perFn => + perFn.map fun quad => + { groupIdx := quad.1, totalWidth := quad.2.1, + uniqueRows := quad.2.2.1, totalHits := quad.2.2.2 } + let memoryCounts : MemoryCounts := rawMem.map fun pair => + { uniqueRows := pair.1, totalHits := pair.2 } + .ok (output, ⟨ioData, ioMap⟩, { functionStats, memoryCounts }) end Bytecode.Toplevel diff --git a/Ix/Aiur/Semantics/SourceEval.lean b/Ix/Aiur/Semantics/SourceEval.lean index c0cffea7..fb8c8f8d 100644 --- a/Ix/Aiur/Semantics/SourceEval.lean +++ b/Ix/Aiur/Semantics/SourceEval.lean @@ -474,6 +474,7 @@ def interp (decls : Decls) (fuel : Nat) (bindings : Bindings) { st'.ioBuffer with data := st'.ioBuffer.data ++ dataGs } } interp decls fuel bindings ret st'' | _ => .error (.typeMismatch "ioWrite") + | .retGroup _ inner => interp decls fuel bindings inner st termination_by (fuel, 2, sizeOf t) decreasing_by all_goals first diff --git a/Ix/Aiur/Stages/Bytecode.lean b/Ix/Aiur/Stages/Bytecode.lean index 527e5431..af8ef8bf 100644 --- a/Ix/Aiur/Stages/Bytecode.lean +++ b/Ix/Aiur/Stages/Bytecode.lean @@ -49,7 +49,7 @@ inductive Op mutual inductive Ctrl where | match : ValIdx → Array (G × Block) → Option Block → Ctrl - | return : SelIdx → Array ValIdx → Ctrl + | return : SelIdx → (group : USize) → Array ValIdx → Ctrl | 
yield : SelIdx → Array ValIdx → Ctrl | matchContinue : ValIdx → Array (G × Block) → Option Block → (outputSize : Nat) → (sharedAuxiliaries : Nat) → (sharedLookups : Nat) @@ -84,12 +84,21 @@ def FunctionLayout.totalWidth (l : FunctionLayout) : Nat := structure Function where body : Block layout: FunctionLayout + /-- Display names for the return groups used in `body`. Position `i` in the + array maps to the `USize` group index `i` carried by `Ctrl.return`. Defaults + to the singleton `#[""]` so functions with no `#[return_group(…)]` + annotations have a single unnamed group at index `0`. -/ + groupNames : Array String := #[""] entry : Bool constrained : Bool deriving Inhabited, Repr structure Toplevel where functions : Array Function + /-- Per-function split by return group: position in the inner array equals + the `USize` group index used by `Ctrl.return`. Populated by + `Toplevel.computeFiltered` after `deduplicate` + `needsCircuit`. -/ + filteredFunctions : Array (Array Function) := #[] memorySizes : Array Nat deriving Repr diff --git a/Ix/Aiur/Stages/Concrete.lean b/Ix/Aiur/Stages/Concrete.lean index 3e9eaa50..dc6f69f0 100644 --- a/Ix/Aiur/Stages/Concrete.lean +++ b/Ix/Aiur/Stages/Concrete.lean @@ -85,6 +85,7 @@ inductive Term : Type where | u8LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u32LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | debug (typ : Typ) (escapes : Bool) (label : String) (t : Option Term) (r : Term) : Term + | retGroup (typ : Typ) (escapes : Bool) (name : String) (inner : Term) : Term deriving Repr, Inhabited /-- Get the type annotation of a Concrete.Term, regardless of constructor. 
-/ @@ -101,7 +102,7 @@ def Term.typ : Term → Typ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ - | .debug t _ _ _ _ => t + | .debug t _ _ _ _ | .retGroup t _ _ _ => t /-- Get the escapes flag of a Concrete.Term. -/ def Term.escapes : Term → Bool @@ -117,7 +118,7 @@ def Term.escapes : Term → Bool | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ - | .debug _ e _ _ _ => e + | .debug _ e _ _ _ | .retGroup _ e _ _ => e structure Constructor where nameHead : String diff --git a/Ix/Aiur/Stages/Simple.lean b/Ix/Aiur/Stages/Simple.lean index 996b5d21..28eb7d86 100644 --- a/Ix/Aiur/Stages/Simple.lean +++ b/Ix/Aiur/Stages/Simple.lean @@ -81,6 +81,7 @@ inductive Term : Type where | u8LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u32LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | debug (typ : Typ) (escapes : Bool) (label : String) (t : Option Term) (r : Term) : Term + | retGroup (typ : Typ) (escapes : Bool) (name : String) (inner : Term) : Term deriving Repr, Inhabited def Term.typ : Term → Typ @@ -96,7 +97,7 @@ def Term.typ : Term → Typ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ - | .debug t _ _ _ _ => t + | .debug t _ _ _ _ | .retGroup t _ _ _ => t def Term.escapes : Term → Bool | .unit _ e | .var _ e _ | .ref _ e _ _ | .field _ e _ @@ -111,7 +112,7 @@ def Term.escapes : Term → Bool | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ 
e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ - | .debug _ e _ _ _ => e + | .debug _ e _ _ _ | .retGroup _ e _ _ => e structure Function where name : Global diff --git a/Ix/Aiur/Stages/Source.lean b/Ix/Aiur/Stages/Source.lean index f1209021..e989e470 100644 --- a/Ix/Aiur/Stages/Source.lean +++ b/Ix/Aiur/Stages/Source.lean @@ -389,6 +389,7 @@ inductive Term | u8LessThan : Term → Term → Term | u32LessThan : Term → Term → Term | debug : String → Option Term → Term → Term + | retGroup : String → Term → Term deriving Repr, BEq, Hashable, Inhabited end Source diff --git a/Ix/Aiur/Stages/Typed.lean b/Ix/Aiur/Stages/Typed.lean index 2a3c6f33..7f84218e 100644 --- a/Ix/Aiur/Stages/Typed.lean +++ b/Ix/Aiur/Stages/Typed.lean @@ -57,6 +57,7 @@ inductive Term : Type where | u8LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | u32LessThan (typ : Typ) (escapes : Bool) (a : Term) (b : Term) : Term | debug (typ : Typ) (escapes : Bool) (label : String) (t : Option Term) (r : Term) : Term + | retGroup (typ : Typ) (escapes : Bool) (name : String) (inner : Term) : Term deriving Repr, Inhabited /-- Get the type annotation, regardless of constructor. -/ @@ -72,7 +73,7 @@ def Term.typ : Term → Typ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ - | .debug t _ _ _ _ => t + | .debug t _ _ _ _ | .retGroup t _ _ _ => t /-- Get the escapes flag. 
-/ def Term.escapes : Term → Bool @@ -87,7 +88,7 @@ def Term.escapes : Term → Bool | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ - | .debug _ e _ _ _ => e + | .debug _ e _ _ _ | .retGroup _ e _ _ => e structure Function where name : Global diff --git a/Ix/Aiur/Statistics.lean b/Ix/Aiur/Statistics.lean index ddcf082b..5de795f6 100644 --- a/Ix/Aiur/Statistics.lean +++ b/Ix/Aiur/Statistics.lean @@ -39,7 +39,7 @@ def fftCost (w h : Nat) : Float := let hf := h.toFloat wf * hf * (max hf 2.0).log2 -def computeStats (compiled : CompiledToplevel) (queryCounts : Array QueryCount) : +def computeStats (compiled : CompiledToplevel) (qc : Bytecode.Toplevel.QueryCounts) : ExecutionStats := let t := compiled.bytecode -- Invert nameMap to get FunIdx → String @@ -50,20 +50,22 @@ def computeStats (compiled : CompiledToplevel) (queryCounts : Array QueryCount) let mut acc := #[] for i in [:nAllFuns] do if t.functions[i]!.constrained then - let w := t.functions[i]!.layout.totalWidth - let qc := queryCounts[i]! - let h := qc.uniqueRows - let hits := qc.totalHits - qc.uniqueRows - let name := reverseMap[i]?.getD s!"" - acc := acc.push { name, width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } + let baseName := reverseMap[i]?.getD s!"" + let groupNames := t.functions[i]!.groupNames + for gs in qc.functionStats[i]! do + let group := groupNames[gs.groupIdx]?.getD "" + let name := if group.isEmpty then baseName else s!"{baseName} [{group}]" + let hits := gs.totalHits - gs.uniqueRows + acc := acc.push + { name, width := gs.totalWidth, height := gs.uniqueRows, + cacheHits := hits, fftCost := fftCost gs.totalWidth gs.uniqueRows : CircuitStats } acc let memoryCircuits := t.memorySizes.mapIdx fun i size => let w := size + 11 - let qc := queryCounts[nAllFuns + i]! 
- let h := qc.uniqueRows - let hits := qc.totalHits - qc.uniqueRows - { name := s!"memory[{size}]", - width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } + let mc := qc.memoryCounts[i]! + let hits := mc.totalHits - mc.uniqueRows + { name := s!"memory[{size}]", width := w, height := mc.uniqueRows, + cacheHits := hits, fftCost := fftCost w mc.uniqueRows : CircuitStats } let circuits := (functionCircuits ++ memoryCircuits).qsort (·.fftCost > ·.fftCost) let totalFftCost := circuits.foldl (· + ·.fftCost) 0.0 let totalUncachedFftCost := circuits.foldl (fun acc cs => acc + fftCost cs.width (cs.height + cs.cacheHits)) 0.0 diff --git a/Ix/IxVM/Ingress.lean b/Ix/IxVM/Ingress.lean index 9466eaa7..0bcc8868 100644 --- a/Ix/IxVM/Ingress.lean +++ b/Ix/IxVM/Ingress.lean @@ -33,25 +33,32 @@ def ingress := ⟦ -- Compare two 32-byte addresses for equality fn address_eq(a: Addr, b: Addr) -> G { - let [a0, a1, a2, a3, a4, a5, a6, a7, - a8, a9, a10, a11, a12, a13, a14, a15, - a16, a17, a18, a19, a20, a21, a22, a23, - a24, a25, a26, a27, a28, a29, a30, a31] = load(a); - let [b0, b1, b2, b3, b4, b5, b6, b7, - b8, b9, b10, b11, b12, b13, b14, b15, - b16, b17, b18, b19, b20, b21, b22, b23, - b24, b25, b26, b27, b28, b29, b30, b31] = load(b); - match [a0 - b0, a1 - b1, a2 - b2, a3 - b3, - a4 - b4, a5 - b5, a6 - b6, a7 - b7, - a8 - b8, a9 - b9, a10 - b10, a11 - b11, - a12 - b12, a13 - b13, a14 - b14, a15 - b15, - a16 - b16, a17 - b17, a18 - b18, a19 - b19, - a20 - b20, a21 - b21, a22 - b22, a23 - b23, - a24 - b24, a25 - b25, a26 - b26, a27 - b27, - a28 - b28, a29 - b29, a30 - b30, a31 - b31] { - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 1, - _ => 0, + match ptr_val(a) - ptr_val(b) { + 0 => + #[return_group(fast_path)] + 1, + _ => + #[return_group(slow_path)] + let [a0, a1, a2, a3, a4, a5, a6, a7, + a8, a9, a10, a11, a12, a13, a14, a15, + a16, a17, a18, a19, a20, a21, a22, a23, + a24, a25, a26, a27, a28, 
a29, a30, a31] = load(a); + let [b0, b1, b2, b3, b4, b5, b6, b7, + b8, b9, b10, b11, b12, b13, b14, b15, + b16, b17, b18, b19, b20, b21, b22, b23, + b24, b25, b26, b27, b28, b29, b30, b31] = load(b); + match [a0 - b0, a1 - b1, a2 - b2, a3 - b3, + a4 - b4, a5 - b5, a6 - b6, a7 - b7, + a8 - b8, a9 - b9, a10 - b10, a11 - b11, + a12 - b12, a13 - b13, a14 - b14, a15 - b15, + a16 - b16, a17 - b17, a18 - b18, a19 - b19, + a20 - b20, a21 - b21, a22 - b22, a23 - b23, + a24 - b24, a25 - b25, a26 - b26, a27 - b27, + a28 - b28, a29 - b29, a30 - b30, a31 - b31] { + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 1, + _ => 0, + }, } } diff --git a/Tests/Aiur/Aiur.lean b/Tests/Aiur/Aiur.lean index 6a0913b3..d06e749c 100644 --- a/Tests/Aiur/Aiur.lean +++ b/Tests/Aiur/Aiur.lean @@ -705,6 +705,45 @@ def toplevel := ⟦ let (x, y) = ntm_tuple(a); x + y } + -- Return-group annotation with two arms sharing a group + one unannotated + -- arm (default `""` group). Exercises per-group multi-row witnesses: the + -- `even` and `odd` circuits each get 2 return sites, the `""` circuit 1. + fn shared_return_groups(x: G) -> G { + match x { + 0 => #[return_group(even)] x + 100, + 1 => #[return_group(odd)] x + 200, + 2 => #[return_group(even)] x * 10, + 3 => #[return_group(odd)] x * 100, + _ => x + 1, + } + } + -- Aggregate one call per arm so the STARK proves every group circuit with + -- the right multiplicities in a single test case. + pub fn shared_return_groups_all() -> G { + shared_return_groups(0) + shared_return_groups(1) + + shared_return_groups(2) + shared_return_groups(3) + + shared_return_groups(7) + } + -- Return-group annotation: ignored by compiler/typechecker, must pass through + -- to inner term. Match with distinct group labels per arm. 
+ pub fn match_return_groups(x: G) -> G { + match x { + 0 => + #[return_group(zero)] + 100, + 1 => + #[return_group(one)] + x + 200, + 2 => + #[return_group(two_squared)] + x * x * x, + _ => + #[return_group(default)] + #[return_group(nested)] + x + 1, + } + } + pub fn non_tail_match() -> G { -- Basic, early return, sequential, nested, const mul let r1 = ntm_basic(0) + ntm_basic(5); @@ -884,6 +923,20 @@ def aiurTestCases : List AiurTestCase := [ -- Non-tail match: all patterns in one proof .noIO `non_tail_match #[] #[2281], + + -- Return-group annotation: passthrough across match arms + { AiurTestCase.noIO `match_return_groups #[0] #[100] + with label := "match_return_groups(0)" }, + { AiurTestCase.noIO `match_return_groups #[1] #[201] + with label := "match_return_groups(1)" }, + { AiurTestCase.noIO `match_return_groups #[2] #[8] + with label := "match_return_groups(2)" }, + { AiurTestCase.noIO `match_return_groups #[7] #[8] + with label := "match_return_groups(7)" }, + + -- Shared-group + unannotated arm packed into a single STARK run: + -- 100 + 201 + 20 + 300 + 8 = 629. + .noIO `shared_return_groups_all #[] #[629], ] end diff --git a/Tests/Aiur/Cross.lean b/Tests/Aiur/Cross.lean index 25d1e988..76014790 100644 --- a/Tests/Aiur/Cross.lean +++ b/Tests/Aiur/Cross.lean @@ -1020,6 +1020,25 @@ def toplevel : Source.Toplevel := ⟦ r1 + r2 + r3 + r4 + r5 + r6 + r7 + r8 + r9 + r10 + r11 + r12 + r13 + r14 + r15 + r16 + r17 + r18 + r19 + r20 } + + -- Return-group annotation: two arms share `even`, two share `odd`, one + -- arm is unannotated (default `""` group). Source.Eval treats the + -- annotation as identity; Bytecode.Eval must agree on every input. + fn shared_return_groups(x: G) -> G { + match x { + 0 => #[return_group(even)] x + 100, + 1 => #[return_group(odd)] x + 200, + 2 => #[return_group(even)] x * 10, + 3 => #[return_group(odd)] x * 100, + _ => x + 1, + } + } + -- One entry point hitting every arm in a single run. 
+ pub fn shared_return_groups_all() -> G { + shared_return_groups(0) + shared_return_groups(1) + + shared_return_groups(2) + shared_return_groups(3) + + shared_return_groups(7) + } ⟧ /-- Generic helper: run both evaluators on `entryName` with `inputs` as @@ -1263,7 +1282,9 @@ def tests : TestSeq := runAgreement "fibonacci(6)" "fibonacci" [6] ++ -- End-to-end non-tail match entry aggregating many ntm_* helpers runAgreement "ntm_recursive_test" "ntm_recursive_test" [] ++ - runAgreement "non_tail_match" "non_tail_match" [] + runAgreement "non_tail_match" "non_tail_match" [] ++ + -- Shared-group + unannotated arm: aggregate over every arm in one call. + runAgreement "shared_return_groups_all" "shared_return_groups_all" [] end AiurTests.Cross diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 10e27cf7..86c05e70 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -4,6 +4,9 @@ use super::G; pub struct Toplevel { pub(crate) functions: Vec, + /// Per-function split by return-group index: outer index = `FunIdx`, inner + /// index = the `usize` group index carried by `Ctrl::Return`. + pub(crate) filtered_functions: Vec>, pub(crate) memory_sizes: Vec, } @@ -63,7 +66,7 @@ pub enum Op { pub enum Ctrl { Match(ValIdx, FxIndexMap, Option>), - Return(SelIdx, Vec), + Return(SelIdx, usize, Vec), Yield(SelIdx, Vec), MatchContinue( ValIdx, diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index 4f0d6c68..2b1ec926 100644 --- a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -110,8 +110,9 @@ impl Toplevel { pub fn build_constraints( &self, function_index: usize, + group: usize, ) -> (Constraints, Vec>) { - let function = &self.functions[function_index]; + let function = &self.filtered_functions[function_index][group]; let constraints = Constraints { zeros: vec![], selectors: 0..0, @@ -172,7 +173,7 @@ impl Block { /// be double-counted. 
fn get_block_selector(&self, state: &ConstraintState) -> Expr { match &self.ctrl { - Ctrl::Return(sel, _) | Ctrl::Yield(sel, _) => { + Ctrl::Return(sel, _, _) | Ctrl::Yield(sel, _) => { var(state.selector_index(*sel)) }, Ctrl::Match(_, cases, def) | Ctrl::MatchContinue(_, cases, def, ..) => { @@ -234,7 +235,7 @@ impl Ctrl { #[allow(clippy::needless_pass_by_value)] fn collect_constraints(&self, sel: Expr, state: &mut ConstraintState) { match self { - Ctrl::Return(_, values) => { + Ctrl::Return(_, _, values) => { // channel and function index let mut args = vec![ sel.clone() * function_channel(), diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index 9e607ebc..8a0a03f7 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -18,6 +18,7 @@ use crate::{ pub struct QueryResult { pub(crate) output: Vec, pub(crate) multiplicity: G, + pub(crate) return_group: usize, } pub type QueryMap = FxIndexMap, QueryResult>; @@ -256,6 +257,7 @@ impl Function { let result = QueryResult { output: vec![ptr], multiplicity: G::from_bool(!unconstrained), + return_group: 0, }; memory_queries.insert(values, result); map.push(ptr); @@ -480,7 +482,7 @@ impl Function { map.extend(yielded); push_block_exec_entries!(cont.block); }, - ExecEntry::Ctrl(Ctrl::Return(_, output)) => { + ExecEntry::Ctrl(Ctrl::Return(_, group, output)) => { // Register the query. 
let input_size = toplevel.functions[fun_idx].layout.input_size; let args = map[..input_size].to_vec(); @@ -488,6 +490,7 @@ impl Function { let result = QueryResult { output: output.clone(), multiplicity: G::from_bool(!unconstrained), + return_group: *group, }; record.function_queries[fun_idx].insert(args, result); if let Some(CallerState { diff --git a/src/aiur/synthesis.rs b/src/aiur/synthesis.rs index 1dd7f971..4122d788 100644 --- a/src/aiur/synthesis.rs +++ b/src/aiur/synthesis.rs @@ -71,7 +71,7 @@ where } enum CircuitType { - Function { idx: usize }, + Function { idx: usize, group: usize }, Memory { width: usize }, Bytes1, Bytes2, @@ -82,14 +82,19 @@ impl AiurSystem { toplevel: Toplevel, commitment_parameters: CommitmentParameters, ) -> Self { - let function_circuits = (0..toplevel.functions.len()).filter_map(|i| { - if !toplevel.functions[i].constrained { - None - } else { - let (constraints, lookups) = toplevel.build_constraints(i); - Some(LookupAir::new(AiurCircuit::Function(constraints), lookups)) - } - }); + let toplevel_ref = &toplevel; + let function_circuits = + (0..toplevel_ref.functions.len()).flat_map(move |i| { + let group_count = if toplevel_ref.functions[i].constrained { + toplevel_ref.filtered_functions[i].len() + } else { + 0 + }; + (0..group_count).map(move |group| { + let (constraints, lookups) = toplevel_ref.build_constraints(i, group); + LookupAir::new(AiurCircuit::Function(constraints), lookups) + }) + }); let memory_circuits = toplevel.memory_sizes.iter().map(|&width| { let (memory, lookups) = Memory::build(width); LookupAir::new(AiurCircuit::Memory(memory), lookups) @@ -131,14 +136,17 @@ impl AiurSystem { // Build the `SystemWitness` let _g = tracing::info_span!("aiur/witness").entered(); - let functions = - (0..self.toplevel.functions.len()).into_par_iter().filter_map(|idx| { - if self.toplevel.functions[idx].constrained { - Some(CircuitType::Function { idx }) + let functions: Vec = (0..self.toplevel.functions.len()) + .flat_map(|idx| { 
+ let group_count = if self.toplevel.functions[idx].constrained { + self.toplevel.filtered_functions[idx].len() } else { - None - } - }); + 0 + }; + (0..group_count).map(move |group| CircuitType::Function { idx, group }) + }) + .collect(); + let functions = functions.into_par_iter(); let memories = self .toplevel .memory_sizes @@ -149,8 +157,8 @@ impl AiurSystem { .chain(memories) .chain(gadgets) .map(|circuit_type| match circuit_type { - CircuitType::Function { idx } => { - self.toplevel.witness_data(idx, &query_record, io_buffer) + CircuitType::Function { idx, group } => { + self.toplevel.witness_data(idx, group, &query_record, io_buffer) }, CircuitType::Memory { width } => { Memory::witness_data(width, &query_record) diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index 02d5a928..740dda5e 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -77,15 +77,18 @@ impl Toplevel { pub fn witness_data( &self, function_index: usize, + group: usize, query_record: &QueryRecord, io_buffer: &IOBuffer, ) -> (RowMajorMatrix, Vec>>) { - let func = &self.functions[function_index]; + let func = &self.filtered_functions[function_index][group]; let width = func.width(); let unfiltered_queries = &query_record.function_queries[function_index]; let queries = unfiltered_queries .iter() - .filter(|(_, res)| !res.multiplicity.is_zero()) + .filter(|(_, res)| { + !res.multiplicity.is_zero() && res.return_group == group + }) .collect::>(); let height_no_padding = queries.len(); let height = height_no_padding.next_power_of_two(); @@ -203,7 +206,7 @@ impl Ctrl { io_buffer: &IOBuffer, ) -> PopulateResult { match self { - Ctrl::Return(sel, _) => { + Ctrl::Return(sel, _, _) => { slice.selectors[*sel] = G::ONE; let lookup = function_lookup( -context.multiplicity, diff --git a/src/ffi/aiur/protocol.rs b/src/ffi/aiur/protocol.rs index 772b18ef..43b9e0a0 100644 --- a/src/ffi/aiur/protocol.rs +++ b/src/ffi/aiur/protocol.rs @@ -87,9 +87,12 @@ extern "C" fn rs_aiur_system_verify( } /// 
`Bytecode.Toplevel.execute`: runs execution only (no proof) and returns -/// `Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array (Nat × Nat))`. -/// The trailing `Array (Nat × Nat)` is one `(uniqueRows, totalHits)` pair per -/// function circuit followed by one per memory size. +/// `Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) +/// × (Array (Array (Nat × Nat × Nat × Nat)) × Array (Nat × Nat)))`. +/// The function side is per-function array of per-split +/// `(groupIdx, width, uniqueRows, totalHits)` quadruples (display names are +/// looked up Lean-side via `Function.groupNames`). The memory side is +/// per-memory-size `(uniqueRows, totalHits)` pairs. /// On execution failure (e.g. assertion mismatch from a typechecker /// rejecting a constant), returns `Except.error msg` instead of panicking /// — letting Lean test runners (`KernelArena.lean`) classify failures. @@ -114,14 +117,11 @@ extern "C" fn rs_aiur_toplevel_execute( Err(err) => return LeanExcept::error_string(&err.to_string()), }; - // Build per-circuit (unique_rows, total_hits) pairs: - // one per function, then one per memory size. `unique_rows` is the trace - // height (number of distinct queries); `total_hits` is the sum of - // multiplicities (how often those rows were hit). - let mut query_counts: Vec<(usize, usize)> = Vec::with_capacity( - query_record.function_queries.len() + toplevel.memory_sizes.len(), - ); - let summarize = |q: &crate::aiur::execute::QueryMap| -> (usize, usize) { + // Summarize a query map into `(unique_rows, total_hits)`. `unique_rows` is + // the trace height (number of distinct queries with nonzero multiplicity); + // `total_hits` is the sum of multiplicities (how often those rows were + // hit). 
+ let summarize_pair = |q: &crate::aiur::execute::QueryMap| -> (usize, usize) { let mut rows = 0usize; let mut hits = 0usize; for (_, res) in q.iter() { @@ -134,25 +134,80 @@ extern "C" fn rs_aiur_toplevel_execute( } (rows, hits) }; - for queries in &query_record.function_queries { - query_counts.push(summarize(queries)); - } - for size in &toplevel.memory_sizes { - let pair = query_record.memory_queries.get(size).map_or((0, 0), summarize); - query_counts.push(pair); - } - let lean_query_counts = { - let arr = LeanArray::alloc(query_counts.len()); - for (i, &(rows, hits)) in query_counts.iter().enumerate() { + + // Per-function array of `(group_idx, width, unique_rows, total_hits)` + // quadruples — one entry per split, in group-index order. Queries within + // a function are partitioned by `QueryResult.return_group`. + let function_stats: Vec> = + (0..toplevel.functions.len()) + .map(|i| { + let queries = &query_record.function_queries[i]; + toplevel.filtered_functions[i] + .iter() + .enumerate() + .map(|(group_idx, func)| { + let (rows, hits) = queries + .iter() + .filter(|(_, res)| res.return_group == group_idx) + .fold((0usize, 0usize), |(r, h), (_, res)| { + let m = usize::try_from(res.multiplicity.as_canonical_u64()) + .expect("multiplicity exceeds usize"); + if m != 0 { (r + 1, h + m) } else { (r, h) } + }); + let l = func.layout; + let width = + l.input_size + l.selectors + l.auxiliaries + 4 * (1 + l.lookups); + (group_idx, width, rows, hits) + }) + .collect() + }) + .collect(); + + let memory_counts: Vec<(usize, usize)> = toplevel + .memory_sizes + .iter() + .map(|size| { + query_record.memory_queries.get(size).map_or((0, 0), summarize_pair) + }) + .collect(); + + let lean_function_stats = { + let outer = LeanArray::alloc(function_stats.len()); + for (i, per_fn) in function_stats.iter().enumerate() { + let inner = LeanArray::alloc(per_fn.len()); + for (j, (group_idx, width, rows, hits)) in per_fn.iter().enumerate() { + // (Nat × Nat × Nat × Nat) — 
right-nested pair encoding + let quad = LeanProd::new( + LeanOwned::box_usize(*group_idx), + LeanProd::new( + LeanOwned::box_usize(*width), + LeanProd::new( + LeanOwned::box_usize(*rows), + LeanOwned::box_usize(*hits), + ), + ), + ); + inner.set(j, quad); + } + outer.set(i, inner); + } + outer + }; + let lean_memory_counts = { + let arr = LeanArray::alloc(memory_counts.len()); + for (i, &(rows, hits)) in memory_counts.iter().enumerate() { let pair = LeanProd::new(LeanOwned::box_usize(rows), LeanOwned::box_usize(hits)); arr.set(i, pair); } arr }; + let lean_query_counts = + LeanProd::new(lean_function_stats, lean_memory_counts); let lean_io = build_lean_io_buffer(&io_buffer); - // (Array G, (Array G × Array (Array G × IOKeyInfo), Array (Nat × Nat))) + // (Array G, (Array G × Array (Array G × IOKeyInfo), + // Array (Array (Nat × Nat × Nat × Nat)) × Array (Nat × Nat))) let io_counts = LeanProd::new(lean_io, lean_query_counts); let result = LeanProd::new(build_g_array(&output), io_counts); LeanExcept::ok(result) diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index e1f0df73..ed281ff2 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -2,7 +2,7 @@ use multi_stark::p3_field::PrimeCharacteristicRing; use lean_ffi::object::{LeanBorrowed, LeanCtor, LeanRef}; -use crate::lean::LeanAiurFunction; +use crate::lean::{LeanAiurCtrl, LeanAiurFunction}; use crate::{ FxIndexMap, @@ -152,13 +152,18 @@ fn decode_g_block_pair(ctor: LeanCtor>) -> (G, Block) { } fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { - match ctor.tag() { + let typed = LeanAiurCtrl::from_ctor(ctor); + match typed.as_ctor().tag() { 0 => { - let [val_idx_obj, cases_obj, default_obj] = ctor.objs::<3>(); + let val_idx_obj = typed.get_obj(0); + let cases_obj = typed.get_obj(1); + let default_obj = typed.get_obj(2); let val_idx = lean_unbox_nat_as_usize(&val_idx_obj); - let vec_cases = - cases_obj.as_array().map(|o| decode_g_block_pair(o.as_ctor())); - let cases = 
FxIndexMap::from_iter(vec_cases); + let cases: FxIndexMap = cases_obj + .as_array() + .map(|o| decode_g_block_pair(o.as_ctor())) + .into_iter() + .collect(); let default = if default_obj.is_scalar() { None } else { @@ -169,31 +174,37 @@ fn decode_ctrl(ctor: LeanCtor>) -> Ctrl { Ctrl::Match(val_idx, cases, default) }, 1 => { - let [sel_idx_obj, val_idxs_obj] = ctor.objs::<2>(); + // `Ctrl.return : SelIdx → USize → Array ValIdx → Ctrl` — `sel_idx` and + // `val_idxs` are boxed; `group` is a `USize` scalar field declared + // via `num_usize: 1` in `LeanAiurCtrl`. + let sel_idx_obj = typed.get_obj(0); + let val_idxs_obj = typed.get_obj(1); let sel_idx = lean_unbox_nat_as_usize(&sel_idx_obj); + let group = typed.get_usize(0); let val_idxs = decode_vec_val_idx(val_idxs_obj); - Ctrl::Return(sel_idx, val_idxs) + Ctrl::Return(sel_idx, group, val_idxs) }, 2 => { - let [sel_idx_obj, val_idxs_obj] = ctor.objs::<2>(); + let sel_idx_obj = typed.get_obj(0); + let val_idxs_obj = typed.get_obj(1); let sel_idx = lean_unbox_nat_as_usize(&sel_idx_obj); let val_idxs = decode_vec_val_idx(val_idxs_obj); Ctrl::Yield(sel_idx, val_idxs) }, 3 => { - let [ - val_idx_obj, - cases_obj, - default_obj, - output_size_obj, - shared_aux_obj, - shared_lookups_obj, - cont_obj, - ] = ctor.objs::<7>(); + let val_idx_obj = typed.get_obj(0); + let cases_obj = typed.get_obj(1); + let default_obj = typed.get_obj(2); + let output_size_obj = typed.get_obj(3); + let shared_aux_obj = typed.get_obj(4); + let shared_lookups_obj = typed.get_obj(5); + let cont_obj = typed.get_obj(6); let val_idx = lean_unbox_nat_as_usize(&val_idx_obj); - let vec_cases = - cases_obj.as_array().map(|o| decode_g_block_pair(o.as_ctor())); - let cases = FxIndexMap::from_iter(vec_cases); + let cases: FxIndexMap = cases_obj + .as_array() + .map(|o| decode_g_block_pair(o.as_ctor())) + .into_iter() + .collect(); let default = if default_obj.is_scalar() { None } else { @@ -237,6 +248,10 @@ fn decode_function_layout(ctor: LeanCtor>) -> 
FunctionLayout { } fn decode_function(ctor: LeanCtor>) -> Function { + // Lean `Bytecode.Function` has 3 boxed fields (body, layout, groupNames) + // and 2 scalar booleans (entry, constrained). The `groupNames` slot is the + // display table for return-group indices and is only consumed Lean-side + // (via `Statistics.computeStats`), so the Rust decoder skips it. let ctor = LeanAiurFunction::from_ctor(ctor); let body = decode_block(ctor.get_obj(0).as_ctor()); let layout = decode_function_layout(ctor.get_obj(1).as_ctor()); @@ -249,10 +264,16 @@ pub(crate) fn decode_toplevel( obj: &LeanAiurToplevel, ) -> Toplevel { let ctor = obj.as_ctor(); - let [functions_obj, memory_sizes_obj] = ctor.objs::<2>(); - let functions = + let [functions_obj, filtered_functions_obj, memory_sizes_obj] = + ctor.objs::<3>(); + let functions: Vec = functions_obj.as_array().map(|o| decode_function(o.as_ctor())); + // `filteredFunctions : Array (Array Function)` — positional by group index. + let filtered_functions: Vec> = + filtered_functions_obj.as_array().map(|inner_obj| { + inner_obj.as_array().map(|fn_obj| decode_function(fn_obj.as_ctor())) + }); let memory_sizes = memory_sizes_obj.as_array().map(|x| lean_unbox_nat_as_usize(&x)); - Toplevel { functions, memory_sizes } + Toplevel { functions, memory_sizes, filtered_functions } } diff --git a/src/lean.rs b/src/lean.rs index ef41347b..b02320aa 100644 --- a/src/lean.rs +++ b/src/lean.rs @@ -234,8 +234,18 @@ lean_ffi::lean_inductive! { // --- Aiur types --- - LeanAiurToplevel [ { num_obj: 2 } ]; - LeanAiurFunction [ { num_obj: 2, num_8: 2 } ]; + LeanAiurToplevel [ { num_obj: 3 } ]; + LeanAiurFunction [ { num_obj: 3, num_8: 2 } ]; + + // `Bytecode.Ctrl` — `return` carries a `USize` group index as a scalar + // ctor field. Declared so `get_usize(0)` is bounds-checked against + // `num_usize: 1` rather than reaching into raw memory. 
+ LeanAiurCtrl [ + { num_obj: 3 }, // tag 0: match + { num_obj: 2, num_usize: 1 }, // tag 1: return + { num_obj: 2 }, // tag 2: yield + { num_obj: 7 }, // tag 3: matchContinue + ]; // --- Block / comparison types ---