Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Ix/Aiur.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ public import Ix.Aiur.Compiler.Concretize
public import Ix.Aiur.Compiler.Layout
public import Ix.Aiur.Compiler.Lower
public import Ix.Aiur.Compiler.Dedup
public import Ix.Aiur.Compiler.Split
public import Ix.Aiur.Compiler
public import Ix.Aiur.Statistics
6 changes: 4 additions & 2 deletions Ix/Aiur/Compiler.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ public import Ix.Aiur.Compiler.Lower
public import Ix.Aiur.Compiler.Dedup
public import Ix.Aiur.Compiler.Concretize
public import Ix.Aiur.Compiler.Simple
public import Ix.Aiur.Compiler.Split

/-!
Aiur compiler pipeline: type-check, simplify, concretize, lower, deduplicate.
Expand Down Expand Up @@ -68,7 +69,7 @@ def Bytecode.Ctrl.collectConstrainedCallees (c : Bytecode.Ctrl) :
| some block => branchCallees ++ block.collectConstrainedCallees
| none => branchCallees
withDefault ++ continuation.collectConstrainedCallees
| .return _ _ | .yield _ _ => #[]
| .return _ _ _ | .yield _ _ => #[]
termination_by (sizeOf c, 0)
decreasing_by
all_goals first
Expand Down Expand Up @@ -117,9 +118,10 @@ def Source.Toplevel.compile (t : Source.Toplevel) : Except String CompiledToplev
let (bytecodeRaw, preNameMap) ← concDecls.toBytecode
let (bytecodeDedup, remap) := bytecodeRaw.deduplicate
let needs := bytecodeDedup.needsCircuit
let bytecode := { bytecodeDedup with
let bytecodeConstrained : Bytecode.Toplevel := { bytecodeDedup with
functions := bytecodeDedup.functions.mapIdx fun i f =>
{ f with constrained := needs[i]! } }
let bytecode := bytecodeConstrained.computeFiltered
let nameMap := preNameMap.fold (init := (∅ : Std.HashMap Global Bytecode.FunIdx))
fun acc name idx => acc.insert name (remap idx)
pure (CompiledToplevel.mk t bytecode nameMap)
Expand Down
5 changes: 5 additions & 0 deletions Ix/Aiur/Compiler/Check.lean
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,9 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with
| some sub => do pure (some (← inferTerm sub))
let ret' ← inferTerm ret
pure (Typed.Term.debug ret'.typ ret'.escapes label term' ret')
| .retGroup name inner => do
let inner' ← inferTerm inner
pure (Typed.Term.retGroup inner'.typ inner'.escapes name inner')
termination_by (sizeOf t, 0)
decreasing_by
all_goals first
Expand Down Expand Up @@ -917,6 +920,8 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with
| none => pure none
| some sub => do pure (some (← zonkTypedTerm sub))
pure (.debug (← zonkTyp τ) e label t' (← zonkTypedTerm r))
| .retGroup τ e name inner => do
pure (.retGroup (← zonkTyp τ) e name (← zonkTypedTerm inner))
termination_by sizeOf t
decreasing_by
all_goals first
Expand Down
8 changes: 8 additions & 0 deletions Ix/Aiur/Compiler/Concretize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ def termToConcrete
| none => pure none
| some sub => do pure (some (← termToConcrete mono sub))
pure (.debug (← typToConcrete mono τ) e l t' (← termToConcrete mono r))
| .retGroup τ e name inner => do
pure (.retGroup (← typToConcrete mono τ) e name (← termToConcrete mono inner))
termination_by t => sizeOf t
decreasing_by
all_goals first
Expand Down Expand Up @@ -541,6 +543,8 @@ def rewriteTypedTerm (decls : Typed.Decls)
| none => none
| some sub => some (rewriteTypedTerm decls subst mono sub)
.debug (rewriteTyp subst mono τ) e l t' (rewriteTypedTerm decls subst mono r)
| .retGroup τ e name inner =>
.retGroup (rewriteTyp subst mono τ) e name (rewriteTypedTerm decls subst mono inner)
termination_by t => sizeOf t
decreasing_by
all_goals first
Expand Down Expand Up @@ -625,6 +629,7 @@ def collectInTypedTerm (seen : Std.HashSet (Global × Array Typ)) :
let seen := collectInTyp seen τ
let seen := match t with | some t => collectInTypedTerm seen t | none => seen
collectInTypedTerm seen r
| .retGroup τ _ _ inner => collectInTypedTerm (collectInTyp seen τ) inner
termination_by t => sizeOf t
decreasing_by
all_goals first
Expand Down Expand Up @@ -683,6 +688,7 @@ def collectCalls (decls : Typed.Decls)
| .debug _ _ _ t r =>
let seen := match t with | some t => collectCalls decls seen t | none => seen
collectCalls decls seen r
| .retGroup _ _ _ inner => collectCalls decls seen inner
termination_by t => sizeOf t
decreasing_by
all_goals first
Expand Down Expand Up @@ -771,6 +777,8 @@ def substInTypedTerm (subst : Global → Option Typ) : Typed.Term → Typed.Term
| none => none
| some sub => some (substInTypedTerm subst sub)
.debug (Typ.instantiate subst τ) e l t' (substInTypedTerm subst r)
| .retGroup τ e name inner =>
.retGroup (Typ.instantiate subst τ) e name (substInTypedTerm subst inner)
termination_by t => sizeOf t
decreasing_by
all_goals first
Expand Down
7 changes: 4 additions & 3 deletions Ix/Aiur/Compiler/Dedup.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ mutual
| .match v branches def_ =>
.match v (branches.attach.map fun ⟨(g, b), _⟩ => (g, skeletonBlock b))
(match def_ with | none => none | some b => some (skeletonBlock b))
| .return s vs => .return s vs
| .return s g vs => .return s g vs
| .yield s vs => .yield s vs
| .matchContinue v branches def_ outputSize sharedAux sharedLookups cont =>
.matchContinue v (branches.attach.map fun ⟨(g, b), _⟩ => (g, skeletonBlock b))
Expand Down Expand Up @@ -107,7 +107,7 @@ mutual
| .match v branches def_ =>
.match v (branches.attach.map fun ⟨(g, b), _⟩ => (g, rewriteBlock f b))
(match def_ with | none => none | some b => some (rewriteBlock f b))
| .return s vs => .return s vs
| .return s g vs => .return s g vs
| .yield s vs => .yield s vs
| .matchContinue v branches def_ outputSize sharedAux sharedLookups cont =>
.matchContinue v (branches.attach.map fun ⟨(g, b), _⟩ => (g, rewriteBlock f b))
Expand Down Expand Up @@ -196,7 +196,8 @@ def deduplicate_newFunctions (functions : Array Function) (classes : Array Nat)
if can then
let entry := deduplicate_class_entry functions classes cls
let body := rewriteBlock remapFn f.body
acc.push { body, layout := f.layout, entry, constrained := false }
acc.push { body, layout := f.layout, groupNames := f.groupNames,
entry, constrained := false }
else acc)
#[]

Expand Down
74 changes: 61 additions & 13 deletions Ix/Aiur/Compiler/Lower.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ structure CompilerState where
ops : Array Bytecode.Op
selIdx : Bytecode.SelIdx
degrees : Array Nat
/-- Top of the `#[return_group(…)]` annotation stack — the display name
whose index will tag every `Ctrl.return` emitted inside its scope. The
index is looked up (or allocated) lazily at emit time. -/
currentReturnGroupName : String := ""
/-- Group display names allocated so far; position `i` is the name for
group index `i`. -/
groupNames : Array String := #[]
/-- Inverse of `groupNames`: maps name → allocated index. -/
groupNameMap : Std.HashMap String USize := {}
deriving Inhabited

abbrev CompileM := EStateM String CompilerState
Expand Down Expand Up @@ -122,6 +131,20 @@ def pushOp (op : Bytecode.Op) (size : Nat := 1) : CompileM (Array Bytecode.ValId
def extractOps : CompileM (Array Bytecode.Op) :=
modifyGet fun s => (s.ops, {s with ops := #[]})

/-- Look up the `USize` index for the current return-group name, allocating
fresh storage in `groupNames`/`groupNameMap` on first encounter. -/
def allocCurrentGroup : CompileM USize := do
let st ← get
let name := st.currentReturnGroupName
match st.groupNameMap[name]? with
| some idx => pure idx
| none =>
let idx : USize := USize.ofNat st.groupNames.size
modify fun s => { s with
groupNameMap := s.groupNameMap.insert name idx
groupNames := s.groupNames.push name }
pure idx

open Concrete in
mutual

Expand Down Expand Up @@ -299,6 +322,7 @@ def toIndex
| some sub => do pure (some (← toIndex layoutMap bindings sub))
modify fun stt => { stt with ops := stt.ops.push (.debug label term) }
toIndex layoutMap bindings ret
| .retGroup _ _ _ inner => toIndex layoutMap bindings inner
termination_by (sizeOf term, 0)
decreasing_by
all_goals first
Expand Down Expand Up @@ -446,6 +470,12 @@ def Concrete.Term.compile
let data ← toIndex layoutMap bindings data
modify fun stt => { stt with ops := stt.ops.push (.ioWrite data) }
ret.compile returnTyp layoutMap bindings yieldCtrl
| .retGroup _ _ name inner => do
let oldGroup := (← get).currentReturnGroupName
modify fun s => { s with currentReturnGroupName := name }
let blk ← inner.compile returnTyp layoutMap bindings yieldCtrl
modify fun s => { s with currentReturnGroupName := oldGroup }
pure blk
| .match _ _ scrut cases defaultOpt => do
let idxs := bindings[scrut]?.getD #[0]
let ops ← extractOps
Expand All @@ -460,21 +490,24 @@ def Concrete.Term.compile
pure ({ ops, ctrl } : Bytecode.Block)
| .ret _ _ term => do
let idxs ← toIndex layoutMap bindings term
let groupIdx ← allocCurrentGroup
let state ← get
let state := { state with selIdx := state.selIdx + 1 }
set state
let ops := state.ops
let id := state.selIdx
pure ({ ops, ctrl := .return (id - 1) idxs } : Bytecode.Block)
pure ({ ops, ctrl := .return (id - 1) groupIdx idxs } : Bytecode.Block)
| _ => do
let idxs ← toIndex layoutMap bindings term
let groupIdx ← allocCurrentGroup
let state ← get
let state := { state with selIdx := state.selIdx + 1 }
set state
let ops := state.ops
let id := state.selIdx
let ctrl : Bytecode.Ctrl :=
if yieldCtrl && !term.escapes then .yield (id - 1) idxs else .return (id - 1) idxs
if yieldCtrl && !term.escapes then .yield (id - 1) idxs
else .return (id - 1) groupIdx idxs
pure ({ ops, ctrl } : Bytecode.Block)
termination_by (sizeOf term, 0)
decreasing_by
Expand All @@ -497,7 +530,10 @@ def Concrete.addCase
| .field g => do
let initState ← get
let term ← term.compile returnTyp layoutMap bindings yieldCtrl
set { initState with selIdx := (← get).selIdx }
let cur ← get
set { initState with selIdx := cur.selIdx,
groupNames := cur.groupNames,
groupNameMap := cur.groupNameMap }
pure (cases.push (g, term), defaultBlock)
| .ref global pats => do
let (index, offsets) ← match layoutMap[global]? with
Expand All @@ -516,22 +552,30 @@ def Concrete.addCase
acc.insert patLocal slice
let initState ← get
let term ← term.compile returnTyp layoutMap ptrBindings yieldCtrl
set { initState with selIdx := (← get).selIdx }
let cur ← get
set { initState with selIdx := cur.selIdx,
groupNames := cur.groupNames,
groupNameMap := cur.groupNameMap }
pure (cases.push (.ofNat index, term), defaultBlock)
| .wildcard => do
let initState ← get
let term ← term.compile returnTyp layoutMap bindings yieldCtrl
set { initState with selIdx := (← get).selIdx }
let cur ← get
set { initState with selIdx := cur.selIdx,
groupNames := cur.groupNames,
groupNameMap := cur.groupNameMap }
pure (cases, .some term)
| _ => throw "addCase: unsupported pattern in concrete lower"
termination_by _ pair => (sizeOf pair.snd, 1)
decreasing_by all_goals first | decreasing_tactic | grind

end

/-- Lower a full concrete function to bytecode. -/
/-- Lower a full concrete function to bytecode. Returns the body, layout
state, and the per-function `groupNames` table (position `i` is the display
name for group index `i`). -/
def Concrete.Function.compile (layoutMap : LayoutMap) (f : Concrete.Function) :
Except String (Bytecode.Block × Bytecode.LayoutMState) := do
Except String (Bytecode.Block × Bytecode.LayoutMState × Array String) := do
let (_inputSize, _outputSize) ← match layoutMap[f.name]? with
| some (.function layout) => pure (layout.inputSize, layout.outputSize)
| _ => throw s!"`{f.name}` should be a function"
Expand All @@ -542,15 +586,16 @@ def Concrete.Function.compile (layoutMap : LayoutMap) (f : Concrete.Function) :
| .ok len => pure len
let indices := Array.range' valIdx len
pure (valIdx + len, bindings.insert arg indices)
let state := { valIdx, selIdx := 0, ops := #[], degrees := Array.replicate valIdx 1 }
let state : CompilerState := { valIdx, selIdx := 0, ops := #[],
degrees := Array.replicate valIdx 1 }
match f.body.compile f.output layoutMap bindings |>.run state with
| .error e _ => throw e
| .ok body _ =>
| .ok body finalState =>
let (_, layoutMState) := Bytecode.blockLayout body |>.run (.new valIdx)
let layoutMState := { layoutMState with functionLayout :=
{ layoutMState.functionLayout with
lookups := layoutMState.functionLayout.lookups + 1 } }
pure (body, layoutMState)
pure (body, layoutMState, finalState.groupNames)

def Concrete.Decls.toBytecode (decls : Concrete.Decls) :
Except String (Bytecode.Toplevel × Std.HashMap Global Bytecode.FunIdx) := do
Expand All @@ -559,13 +604,16 @@ def Concrete.Decls.toBytecode (decls : Concrete.Decls) :
let (functions, memSizes, nameMap) ← decls.foldlM (init := (#[], initMemSizes, {}))
fun acc@(functions, memSizes, nameMap) (_, decl) => match decl with
| .function function => do
let (body, layoutMState) ← function.compile layout
let (body, layoutMState, groupNames) ← function.compile layout
let nameMap := nameMap.insert function.name functions.size
let function := ⟨body, layoutMState.functionLayout, function.entry, false⟩
let groupNames := if groupNames.isEmpty then #[""] else groupNames
let function : Bytecode.Function :=
{ body, layout := layoutMState.functionLayout,
groupNames, entry := function.entry, constrained := false }
let memSizes := layoutMState.memSizes.fold (·.insert ·) memSizes
pure (functions.push function, memSizes, nameMap)
| _ => pure acc
pure (functions, memSizes.toArray, nameMap)
pure ({ functions, memorySizes := memSizes.toArray : Bytecode.Toplevel }, nameMap)

end Aiur

Expand Down
1 change: 1 addition & 0 deletions Ix/Aiur/Compiler/Match.lean
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def typedToSimple : Term → Simple.Term
| .debug τ e l t r =>
let t' := match t with | none => none | some sub => some (typedToSimple sub)
.debug τ e l t' (typedToSimple r)
| .retGroup τ e name inner => .retGroup τ e name (typedToSimple inner)
termination_by t => sizeOf t
decreasing_by all_goals first | decreasing_tactic | grind

Expand Down
3 changes: 3 additions & 0 deletions Ix/Aiur/Compiler/Simple.lean
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term
let a' ← simplifyTypedTerm decls a
let b' ← simplifyTypedTerm decls b
pure (.u32LessThan τ e a' b')
| .retGroup τ e name inner => do
let inner' ← simplifyTypedTerm decls inner
pure (.retGroup τ e name inner')
| t => pure t
termination_by t => sizeOf t
decreasing_by
Expand Down
Loading