diff --git a/internal/interpreter/recording_store.go b/internal/interpreter/recording_store.go new file mode 100644 index 00000000..e12c735b --- /dev/null +++ b/internal/interpreter/recording_store.go @@ -0,0 +1,76 @@ +package interpreter + +import ( + "context" + "math/big" +) + +// recordingStore wraps a Store and records all balance and metadata reads, +// preserving the order in which the underlying store returned them. +// +// It is used by ResolveDependencies to discover which data a script depends on. +type recordingStore struct { + inner Store + balanceReads Balances + metadataReads AccountsMetadata +} + +func newRecordingStore(inner Store) *recordingStore { + return &recordingStore{ + inner: inner, + balanceReads: Balances{}, + metadataReads: AccountsMetadata{}, + } +} + +func (r *recordingStore) GetBalances(ctx context.Context, query BalanceQuery) (Balances, error) { + result, err := r.inner.GetBalances(ctx, query) + if err != nil { + return nil, err + } + + for _, row := range result { + if r.balanceReads.hasRow(row.Account, row.Asset, row.Color) { + continue + } + amount := new(big.Int) + if row.Amount != nil { + amount.Set(row.Amount) + } + r.balanceReads = append(r.balanceReads, BalanceRow{ + Account: row.Account, + Asset: row.Asset, + Color: row.Color, + Amount: amount, + }) + } + + return result, nil +} + +func (r *recordingStore) GetAccountsMetadata(ctx context.Context, query MetadataQuery) (AccountsMetadata, error) { + result, err := r.inner.GetAccountsMetadata(ctx, query) + if err != nil { + return nil, err + } + + for account, meta := range result { + if _, ok := r.metadataReads[account]; !ok { + r.metadataReads[account] = AccountMetadata{} + } + for key, value := range meta { + r.metadataReads[account][key] = value + } + } + + return result, nil +} + +func (rows Balances) hasRow(account, asset, color string) bool { + for i := range rows { + if rows[i].Account == account && rows[i].Asset == asset && rows[i].Color == color { + return true + } + } + return false +} diff --git a/internal/interpreter/resolve_dependencies.go b/internal/interpreter/resolve_dependencies.go new file mode 100644 index 00000000..0d868ec7 --- /dev/null +++ b/internal/interpreter/resolve_dependencies.go @@ -0,0 +1,273 @@ +package interpreter + +import ( + "context" + "maps" + "slices" + + "github.com/formancehq/numscript/internal/flags" + "github.com/formancehq/numscript/internal/parser" + "github.com/formancehq/numscript/internal/utils" +) + +// ResolvedDependencies summarizes what a script reads from and writes to the +// store. The caller can use it to preload data and to detect input drift +// between successive runs. +type ResolvedDependencies struct { + // Reads contains the data the script read from the store while resolving. + Reads ResolvedReads + + // Writes contains the (account, asset, color) tuples whose balance can be + // impacted by a posting emitted by the script. + Writes ResolvedWrites +} + +// ResolvedReads holds the data read from the store while resolving the +// script's dependencies. +type ResolvedReads struct { + // Volumes contains every (account, asset, color) → balance row read from + // the store, in the order it was returned. + Volumes Balances + + // Metadata contains all (account, key) → value pairs read from the store. + Metadata AccountsMetadata +} + +// ResolvedWrites holds the data the script may write to the store. +type ResolvedWrites struct { + // Volumes lists every (account, asset, color) tuple that may be impacted + // by a posting emitted by the script. + Volumes BalanceQuery +} + +// ResolveDependenciesOptions configures ResolveDependencies behavior. +type ResolveDependenciesOptions struct { + // FeatureFlags enables additional experimental features + // (same semantics as RunWithFeatureFlags). + FeatureFlags map[string]struct{} +} + +// ResolveDependencies discovers which data a script reads from the store and +// which (account, asset, color) tuples it may write to, without executing any +// posting. +// +// It performs variable resolution and source preloading — the two phases that +// RunProgram runs before executing statements — then walks the send statements +// to collect the touched accounts. No transfers are simulated, so the call is +// cheap and does not depend on the script's runtime semantics (allotments, +// overdraft, etc.). +// +// Store calls (GetBalances/GetAccountsMetadata) are issued in a deterministic +// order across runs with identical inputs, so the caller can hash them to +// detect input drift. +func ResolveDependencies( + ctx context.Context, + program parser.Program, + vars map[string]string, + store Store, + opts ResolveDependenciesOptions, +) (*ResolvedDependencies, InterpreterError) { + recorder := newRecordingStore(store) + + featureFlags := maps.Clone(opts.FeatureFlags) + if featureFlags == nil { + featureFlags = make(map[string]struct{}, len(program.Flags)) + } + for _, flag := range program.Flags { + if slices.Index(flags.AllFlags, flag.String) == -1 { + return nil, InvalidFeature{Feature: flag.String} + } + featureFlags[flag.String] = struct{}{} + } + + st := programState{ + ParsedVars: make(map[string]Value), + TxMeta: make(map[string]Value), + CachedAccountsMeta: AccountsMetadata{}, + CachedBalances: InternalBalances{}, + SetAccountsMeta: AccountsMetadata{}, + Store: recorder, + Postings: make([]Posting, 0), + fundsQueue: newFundsQueue(nil), + CurrentBalanceQuery: BalanceQuery{}, + ctx: ctx, + FeatureFlags: featureFlags, + } + + st.varOriginPosition = true + if program.Vars != nil { + if err := st.parseVars(program.Vars.Declarations, vars); err != nil { + return nil, err + } + } + st.varOriginPosition = false + + for _, statement := range program.Statements { + if err := st.findBalancesQueriesInStatement(statement); err != nil { + return nil, err + } + } + if err := st.runBalancesQuery(); err != nil { + return nil, QueryBalanceError{WrappedError: err} + } + + writes := BalanceQuery{} + for _, statement := range program.Statements { + send, ok := statement.(*parser.SendStatement) + if !ok { + continue + } + if err := st.collectSendWrites(*send, &writes); err != nil { + return nil, err + } + } + + return &ResolvedDependencies{ + Reads: ResolvedReads{ + Volumes: recorder.balanceReads, + Metadata: recorder.metadataReads, + }, + Writes: ResolvedWrites{Volumes: writes}, + }, nil +} + +func (st *programState) collectSendWrites( + send parser.SendStatement, + writes *BalanceQuery, +) InterpreterError { + asset, _, err := st.evaluateSentAmt(send.SentValue) + if err != nil { + return err + } + st.CurrentAsset = asset + + if err := st.collectSourceWrites(send.Source, writes); err != nil { + return err + } + return st.collectDestinationWrites(send.Destination, writes) +} + +func (st *programState) collectSourceWrites( + source parser.Source, + writes *BalanceQuery, +) InterpreterError { + switch source := source.(type) { + case *parser.SourceAccount: + return st.touchAccount(source.ValueExpr, source.Color, writes) + + case *parser.SourceOverdraft: + return st.touchAccount(source.Address, source.Color, writes) + + case *parser.SourceWithScaling: + return st.touchAccount(source.Address, nil, writes) + + case *parser.SourceInorder: + for _, sub := range source.Sources { + if err := st.collectSourceWrites(sub, writes); err != nil { + return err + } + } + return nil + + case *parser.SourceOneof: + for _, sub := range source.Sources { + if err := st.collectSourceWrites(sub, writes); err != nil { + return err + } + } + return nil + + case *parser.SourceCapped: + return st.collectSourceWrites(source.From, writes) + + case *parser.SourceAllotment: + for _, item := range source.Items { + if err := st.collectSourceWrites(item.From, writes); err != nil { + return err + } + } + return nil + + default: + utils.NonExhaustiveMatchPanic[any](source) + return nil + } +} + +func (st *programState) collectDestinationWrites( + dest parser.Destination, + writes *BalanceQuery, +) InterpreterError { + switch dest := dest.(type) { + case *parser.DestinationAccount: + return st.touchAccount(dest.ValueExpr, nil, writes) + + case *parser.DestinationInorder: + for _, clause := range dest.Clauses { + if err := st.collectKeptOrDestWrites(clause.To, writes); err != nil { + return err + } + } + return st.collectKeptOrDestWrites(dest.Remaining, writes) + + case *parser.DestinationOneof: + for _, clause := range dest.Clauses { + if err := st.collectKeptOrDestWrites(clause.To, writes); err != nil { + return err + } + } + return st.collectKeptOrDestWrites(dest.Remaining, writes) + + case *parser.DestinationAllotment: + for _, item := range dest.Items { + if err := st.collectKeptOrDestWrites(item.To, writes); err != nil { + return err + } + } + return nil + + default: + utils.NonExhaustiveMatchPanic[any](dest) + return nil + } +} + +func (st *programState) collectKeptOrDestWrites( + k parser.KeptOrDestination, + writes *BalanceQuery, +) InterpreterError { + switch k := k.(type) { + case *parser.DestinationKept: + return nil + case *parser.DestinationTo: + return st.collectDestinationWrites(k.Destination, writes) + default: + utils.NonExhaustiveMatchPanic[any](k) + return nil + } +} + +func (st *programState) touchAccount( + accountExpr parser.ValueExpr, + colorExpr parser.ValueExpr, + writes *BalanceQuery, +) InterpreterError { + account, err := evaluateExprAs(st, accountExpr, expectAccount) + if err != nil { + return err + } + color, err := evaluateOptExprAs(st, colorExpr, expectString) + if err != nil { + return err + } + + item := BalanceQueryItem{ + Account: string(account), + Asset: string(st.CurrentAsset), + Color: string(color), + } + if !slices.Contains(*writes, item) { + *writes = append(*writes, item) + } + return nil +} diff --git a/internal/interpreter/resolve_dependencies_test.go b/internal/interpreter/resolve_dependencies_test.go new file mode 100644 index 00000000..004c4f95 --- /dev/null +++ b/internal/interpreter/resolve_dependencies_test.go @@ -0,0 +1,342 @@ +package interpreter + +import ( + "context" + "math/big" + "testing" + + "github.com/formancehq/numscript/internal/parser" + "github.com/stretchr/testify/require" +) + +func resolveTest(t *testing.T, script string, vars map[string]string, store Store) *ResolvedDependencies { + t.Helper() + parsed := parser.Parse(script) + require.Empty(t, parsed.Errors, "script should parse without errors") + + deps, err := ResolveDependencies(context.Background(), parsed.Value, vars, store, ResolveDependenciesOptions{}) + require.NoError(t, err) + require.NotNil(t, deps) + + return deps +} + +func readVolume(b Balances, account, asset string) *big.Int { + for _, row := range b { + if row.Account == account && row.Asset == asset { + return row.Amount + } + } + return nil +} + +func hasWrite(q BalanceQuery, account, asset string) bool { + for _, item := range q { + if item.Account == account && item.Asset == asset { + return true + } + } + return false +} + +func TestResolveDependencies_SimpleTransfer(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + send [USD/2 100] ( + source = @alice + destination = @bob + ) + `, nil, StaticStore{ + Balances: Balances{ + {Account: "alice", Asset: "USD/2", Amount: big.NewInt(500)}, + }, + }) + + require.Equal(t, big.NewInt(500), readVolume(deps.Reads.Volumes, "alice", "USD/2")) + + require.True(t, hasWrite(deps.Writes.Volumes, "alice", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "bob", "USD/2")) +} + +func TestResolveDependencies_WorldSource(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + send [USD/2 100] ( + source = @world + destination = @bob + ) + `, nil, StaticStore{}) + + require.Empty(t, deps.Reads.Volumes) + require.True(t, hasWrite(deps.Writes.Volumes, "world", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "bob", "USD/2")) +} + +func TestResolveDependencies_MetaCall(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + vars { + account $dest = meta(@config, "default_dest") + } + send [USD/2 100] ( + source = @world + destination = $dest + ) + `, nil, StaticStore{ + Meta: AccountsMetadata{ + "config": AccountMetadata{"default_dest": "treasury"}, + }, + }) + + require.Equal(t, "treasury", deps.Reads.Metadata["config"]["default_dest"]) + require.True(t, hasWrite(deps.Writes.Volumes, "treasury", "USD/2")) +} + +func TestResolveDependencies_MultipleSources(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + send [USD/2 200] ( + source = { + @checking + @savings + } + destination = @merchant + ) + `, nil, StaticStore{ + Balances: Balances{ + {Account: "checking", Asset: "USD/2", Amount: big.NewInt(50)}, + {Account: "savings", Asset: "USD/2", Amount: big.NewInt(300)}, + }, + }) + + require.NotNil(t, readVolume(deps.Reads.Volumes, "checking", "USD/2")) + require.NotNil(t, readVolume(deps.Reads.Volumes, "savings", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "checking", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "savings", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "merchant", "USD/2")) +} + +func TestResolveDependencies_Variables(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + vars { + account $src + monetary $amount + } + send $amount ( + source = $src + destination = @dest + ) + `, map[string]string{ + "src": "users:alice", + "amount": "EUR/2 1000", + }, StaticStore{ + Balances: Balances{ + {Account: "users:alice", Asset: "EUR/2", Amount: big.NewInt(5000)}, + }, + }) + + require.NotNil(t, readVolume(deps.Reads.Volumes, "users:alice", "EUR/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "users:alice", "EUR/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "dest", "EUR/2")) +} + +func TestResolveDependencies_BalanceFunction(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + vars { + monetary $bal = balance(@src, USD/2) + } + send $bal ( + source = @src + destination = @dest + ) + `, nil, StaticStore{ + Balances: Balances{ + {Account: "src", Asset: "USD/2", Amount: big.NewInt(750)}, + }, + }) + + require.Equal(t, big.NewInt(750), readVolume(deps.Reads.Volumes, "src", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "src", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "dest", "USD/2")) +} + +func TestResolveDependencies_MultipleSends(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + send [USD/2 50] ( + source = @world + destination = @a + ) + send [EUR/2 100] ( + source = @b + destination = @c + ) + `, nil, StaticStore{ + Balances: Balances{ + {Account: "b", Asset: "EUR/2", Amount: big.NewInt(200)}, + }, + }) + + require.Nil(t, readVolume(deps.Reads.Volumes, "world", "USD/2")) + require.NotNil(t, readVolume(deps.Reads.Volumes, "b", "EUR/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "a", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "c", "EUR/2")) +} + +func TestResolveDependencies_SetAccountMeta(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + set_account_meta(@alice, "status", "active") + send [USD/2 100] ( + source = @world + destination = @alice + ) + `, nil, StaticStore{}) + + require.Empty(t, deps.Reads.Metadata, "set_account_meta should not produce metadata reads") + require.True(t, hasWrite(deps.Writes.Volumes, "alice", "USD/2")) +} + +func TestResolveDependencies_MetaChain(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + vars { + string $key = meta(@config, "key_name") + account $dest = meta(@routing, $key) + } + send [USD/2 100] ( + source = @world + destination = $dest + ) + `, nil, StaticStore{ + Meta: AccountsMetadata{ + "config": AccountMetadata{"key_name": "destination"}, + "routing": AccountMetadata{"destination": "treasury"}, + }, + }) + + require.Equal(t, "destination", deps.Reads.Metadata["config"]["key_name"]) + require.Equal(t, "treasury", deps.Reads.Metadata["routing"]["destination"]) + require.True(t, hasWrite(deps.Writes.Volumes, "treasury", "USD/2")) +} + +func TestResolveDependencies_SendAll(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + send [USD/2 *] ( + source = @src + destination = @dest + ) + `, nil, StaticStore{ + Balances: Balances{ + {Account: "src", Asset: "USD/2", Amount: big.NewInt(999)}, + }, + }) + + require.Equal(t, big.NewInt(999), readVolume(deps.Reads.Volumes, "src", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "src", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "dest", "USD/2")) +} + +func TestResolveDependencies_EmptyReads(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` + send [USD/2 100] ( + source = @world + destination = @dest + ) + `, nil, StaticStore{}) + + require.Empty(t, deps.Reads.Volumes) + require.Empty(t, deps.Reads.Metadata) + require.True(t, hasWrite(deps.Writes.Volumes, "world", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "dest", "USD/2")) +} + +func TestResolveDependencies_Nested(t *testing.T) { + t.Parallel() + + script := `vars { + account $s1 + account $s2 = meta(@account_that_needs_meta, "k") + number $b = balance(@account_that_needs_balance, USD/2) +} + +send [COIN 100] ( + source = { + $s1 + $s2 + @source3 + @world + } + destination = @dest +) +` + + deps := resolveTest(t, + script, + map[string]string{"s1": "source1"}, + StaticStore{ + Balances: Balances{ + {Account: "source1", Asset: "COIN", Amount: big.NewInt(123)}, + {Account: "source2", Asset: "COIN", Amount: big.NewInt(456)}, + {Account: "source3", Asset: "COIN", Amount: big.NewInt(55)}, + {Account: "account_that_needs_balance", Asset: "USD/2", Amount: big.NewInt(42)}, + }, + Meta: AccountsMetadata{ + "account_that_needs_meta": {"k": "source2"}, + }, + }) + + require.Equal(t, big.NewInt(123), readVolume(deps.Reads.Volumes, "source1", "COIN")) + require.Equal(t, big.NewInt(456), readVolume(deps.Reads.Volumes, "source2", "COIN")) + require.Equal(t, big.NewInt(55), readVolume(deps.Reads.Volumes, "source3", "COIN")) + require.Equal(t, big.NewInt(42), readVolume(deps.Reads.Volumes, "account_that_needs_balance", "USD/2")) + + require.Equal(t, AccountsMetadata{ + "account_that_needs_meta": {"k": "source2"}, + }, deps.Reads.Metadata) + + // Writes is a conservative over-approximation: every account that appears + // as a source or destination is listed, even if the actual run would not + // touch all of them. + for _, acc := range []string{"source1", "source2", "source3", "world", "dest"} { + require.True(t, hasWrite(deps.Writes.Volumes, acc, "COIN"), "expected %s in writes", acc) + } +} + +func TestResolveDependencies_MidScriptBalance(t *testing.T) { + t.Parallel() + + deps := resolveTest(t, ` +#![feature("experimental-mid-script-function-call")] +send [USD/2 100] ( + source = @world + destination = @acc +) +send balance(@acc, USD/2) ( + source = @acc + destination = @dest +) +`, nil, StaticStore{}) + + // The balance call hits the store during preload, recording acc/USD/2 = 0. + require.Equal(t, big.NewInt(0), readVolume(deps.Reads.Volumes, "acc", "USD/2")) + + require.True(t, hasWrite(deps.Writes.Volumes, "world", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "acc", "USD/2")) + require.True(t, hasWrite(deps.Writes.Volumes, "dest", "USD/2")) +} diff --git a/numscript.go b/numscript.go index 6d11d1c6..49b52bdc 100644 --- a/numscript.go +++ b/numscript.go @@ -110,3 +110,26 @@ func (p ParseResult) GetSource() string { func (p ParseResult) GetInvolvedAccounts(vars VariablesMap) ([]accounts.InvolvedAccount, []accounts.InvolvedMeta, InterpreterError) { return interpreter.GetInvolvedAccounts(vars, p.parseResult.Value) } + +type ( + ResolvedDependencies = interpreter.ResolvedDependencies + ResolvedReads = interpreter.ResolvedReads + ResolvedWrites = interpreter.ResolvedWrites + ResolveDependenciesOptions = interpreter.ResolveDependenciesOptions +) + +// ResolveDependencies executes the script in dry-run mode and returns the +// (account, asset) → balance and (account, key) → value pairs that were read +// from the store, together with the (account, asset) pairs touched by the +// resulting postings. +// +// Store calls are issued in a deterministic order across runs with identical +// inputs, so the caller can hash them to detect input drift. +func (p ParseResult) ResolveDependencies( + ctx context.Context, + vars VariablesMap, + store Store, + opts ResolveDependenciesOptions, +) (*ResolvedDependencies, InterpreterError) { + return interpreter.ResolveDependencies(ctx, p.parseResult.Value, vars, store, opts) +}