diff --git a/docs/EVAL.md b/docs/EVAL.md index a3eff62..8a0301b 100644 --- a/docs/EVAL.md +++ b/docs/EVAL.md @@ -222,6 +222,47 @@ tests: --- +## Trace (Behavioral) Assertions + +Content matching checks *what* an agent answered. Trace assertions check *how* it got +there — which tools it called, how many LLM calls it made, and the path it took. They run +**after** the content match and are configured under `expect.trace` (the canonical schema +lives in `internal/eval/types.go`). + +After a successful invoke, AGK fetches the run's trace from the EvalServer +(`GET /traces/{id}`) and evaluates the assertions. Tool calls also use the `tools_called` +field from the invoke response, so `tool_calls` is checked even if the trace can't be fetched. + +```yaml +tests: + - name: "Answers about Paris using search, efficiently" + input: "What's the weather in Paris?" + expect: + type: contains + values: ["Paris"] + trace: + tool_calls: ["search"] # each listed tool must have been called + llm_calls: 2 # exact LLM-call count + execution_path: ["research", "format"] # must appear, in order, as a subsequence + min_steps: 2 # observed steps >= 2 + max_steps: 8 # observed steps <= 8 +``` + +### Trace Fields (`expect.trace`) + +| Field | Type | Check | +|-------|------|-------| +| `tool_calls` | string[] | Every listed tool must appear among the called tools (subset). | +| `llm_calls` | int | When > 0, the observed LLM-call count must match **exactly**. | +| `execution_path` | string[] | The listed span names must appear **in order** (gaps allowed). | +| `min_steps` | int | Observed step count (total spans) must be **≥** this. | +| `max_steps` | int | Observed step count (total spans) must be **≤** this. | + +A test fails if any assertion fails; the report lists every failed assertion. Omit a field +to skip that check. + +--- + ## Semantic Matching Strategies ### 1. Embedding Strategy diff --git a/internal/eval/http_target.go b/internal/eval/http_target.go index c8a3335..bf753c4 100644 --- a/internal/eval/http_target.go +++ b/internal/eval/http_target.go @@ -91,6 +91,31 @@ func (ht *HTTPTarget) Invoke(input string, timeout int) (*InvokeResponse, error) return &invokeResp, nil } +// FetchTrace retrieves a trace by ID from the EvalServer's GET /traces/{id} endpoint. +// It is used to validate trace (behavioral) expectations after an invoke. +func (ht *HTTPTarget) FetchTrace(traceID string) (*evalTrace, error) { + if traceID == "" { + return nil, fmt.Errorf("empty trace id") + } + + resp, err := ht.client.Get(ht.baseURL + "/traces/" + traceID) + if err != nil { + return nil, fmt.Errorf("trace request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("trace fetch returned HTTP %d: %s", resp.StatusCode, string(body)) + } + + var trace evalTrace + if err := json.NewDecoder(resp.Body).Decode(&trace); err != nil { + return nil, fmt.Errorf("failed to parse trace: %w", err) + } + return &trace, nil +} + // Health checks if the target is healthy func (ht *HTTPTarget) Health() error { resp, err := ht.client.Get(ht.baseURL + "/health") diff --git a/internal/eval/http_target_test.go b/internal/eval/http_target_test.go new file mode 100644 index 0000000..fbfb02a --- /dev/null +++ b/internal/eval/http_target_test.go @@ -0,0 +1,68 @@ +package eval + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestFetchTrace(t *testing.T) { + const traceID = "trace-abc" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/traces/"+traceID { + http.Error(w, "not found", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id": "trace-abc", + "spans": [ + {"name": "agk.agent.run"}, + {"name": "agk.llm.generate"}, + {"name": "agk.tool.call", "attributes": {"agk.tool.name": "search"}} + ] + }`)) + })) + defer server.Close() + + target := NewHTTPTarget(server.URL, 5*time.Second) + + trace, err := target.FetchTrace(traceID) + if err != nil { + t.Fatalf("FetchTrace error: %v", err) + } + if trace.ID != traceID { + t.Errorf("trace ID = %q, want %q", trace.ID, traceID) + } + if len(trace.Spans) != 3 { + t.Fatalf("got %d spans, want 3", len(trace.Spans)) + } + + // Round-trip through the normalizer to confirm the wire format is consumable. + obs := buildObservedTrace(trace, nil) + if obs.LLMCalls != 1 { + t.Errorf("LLMCalls = %d, want 1", obs.LLMCalls) + } + if len(obs.ToolCalls) != 1 || obs.ToolCalls[0] != "search" { + t.Errorf("ToolCalls = %v, want [search]", obs.ToolCalls) + } +} + +func TestFetchTraceErrors(t *testing.T) { + target := NewHTTPTarget("http://127.0.0.1:0", time.Second) + + if _, err := target.FetchTrace(""); err == nil { + t.Error("expected error for empty trace id") + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "nope", http.StatusNotFound) + })) + defer server.Close() + + target = NewHTTPTarget(server.URL, 5*time.Second) + if _, err := target.FetchTrace("missing"); err == nil { + t.Error("expected error for 404 trace fetch") + } +} diff --git a/internal/eval/runner.go b/internal/eval/runner.go index 64fc127..3f84b3f 100644 --- a/internal/eval/runner.go +++ b/internal/eval/runner.go @@ -3,6 +3,7 @@ package eval import ( "context" "fmt" + "strings" "time" ) @@ -190,8 +191,32 @@ func (r *Runner) runTest(test Test, target *HTTPTarget) TestResult { return result } - // TODO: Validate trace expectations if specified (test.Expect.Trace) + // Validate trace (behavioral) expectations if specified + if test.Expect.Trace != nil { + if failures := r.validateTraceExpectation(test, target, resp); len(failures) > 0 { + result.Passed = false + result.ErrorMessage = "trace assertion failed: " + strings.Join(failures, "; ") + return result + } + } result.Passed = true return result } + +// validateTraceExpectation fetches the run's trace (when available) and checks it +// against the test's trace expectation. Tool calls come from the invoke response; +// LLM-call count, execution path, and step counts come from the fetched trace. +func (r *Runner) validateTraceExpectation(test Test, target *HTTPTarget, resp *InvokeResponse) []string { + var observed *evalTrace + if resp.TraceID != "" { + if t, err := target.FetchTrace(resp.TraceID); err == nil { + observed = t + } else if r.config.Verbose { + fmt.Printf(" [trace] could not fetch trace %s: %v\n", resp.TraceID, err) + } + } + + obs := buildObservedTrace(observed, resp.ToolsCalled) + return ValidateTrace(test.Expect.Trace, obs) +} diff --git a/internal/eval/trace_validator.go b/internal/eval/trace_validator.go new file mode 100644 index 0000000..07aa61d --- /dev/null +++ b/internal/eval/trace_validator.go @@ -0,0 +1,160 @@ +package eval + +import ( + "fmt" + "strings" +) + +// evalTrace mirrors the subset of the EvalServer's GET /traces/{id} response that +// trace assertions need. The server type lives in the framework (v1beta.EvalTrace); +// this is a decode-only copy so the CLI doesn't depend on the framework package. +type evalTrace struct { + ID string `json:"id"` + Spans []*evalSpan `json:"spans"` +} + +type evalSpan struct { + Name string `json:"name"` + Attributes map[string]interface{} `json:"attributes"` +} + +// ObservedTrace is the normalized view of a run's behavior used for assertions. +type ObservedTrace struct { + ToolCalls []string // distinct tool names invoked, in first-seen order + LLMCalls int // number of LLM spans + Path []string // ordered span names (the execution path) + Steps int // total spans (a proxy for execution steps) +} + +// buildObservedTrace normalizes a fetched trace (and the tools_called list from the +// invoke response) into an ObservedTrace. Either source may be empty; tools_called is +// treated as the authoritative tool list and augmented with any tool spans found. +func buildObservedTrace(t *evalTrace, toolsCalled []string) ObservedTrace { + obs := ObservedTrace{} + seen := make(map[string]bool) + + addTool := func(name string) { + name = strings.TrimSpace(name) + if name == "" || seen[name] { + return + } + seen[name] = true + obs.ToolCalls = append(obs.ToolCalls, name) + } + + for _, name := range toolsCalled { + addTool(name) + } + + if t != nil { + for _, sp := range t.Spans { + if sp == nil { + continue + } + obs.Path = append(obs.Path, sp.Name) + + lname := strings.ToLower(sp.Name) + if strings.Contains(lname, "llm") { + obs.LLMCalls++ + } + if isToolSpan(lname) { + addTool(toolNameFromSpan(sp)) + } + } + obs.Steps = len(t.Spans) + } + + return obs +} + +func isToolSpan(lowerName string) bool { + return strings.Contains(lowerName, "tool.call") || strings.Contains(lowerName, "tool_call") +} + +// toolNameFromSpan extracts a tool name from a tool span's attributes, trying the +// AgenticGoKit key first and a couple of common fallbacks. +func toolNameFromSpan(sp *evalSpan) string { + for _, key := range []string{"agk.tool.name", "tool.name", "tool"} { + if v, ok := sp.Attributes[key]; ok { + if s, ok := v.(string); ok && s != "" { + return s + } + } + } + return "" +} + +// ValidateTrace checks an ObservedTrace against a TraceExpectation and returns a list +// of human-readable failure messages (empty means all assertions passed). +// +// Semantics: +// - tool_calls: every listed tool must have been called (subset check) +// - llm_calls: when > 0, the observed LLM-call count must match exactly +// - min_steps: observed step count must be >= min +// - max_steps: observed step count must be <= max +// - execution_path: the listed names must appear, in order, as a subsequence +func ValidateTrace(exp *TraceExpectation, obs ObservedTrace) []string { + if exp == nil { + return nil + } + + var failures []string + + if len(exp.ToolCalls) > 0 { + have := make(map[string]bool, len(obs.ToolCalls)) + for _, t := range obs.ToolCalls { + have[t] = true + } + var missing []string + for _, want := range exp.ToolCalls { + if !have[want] { + missing = append(missing, want) + } + } + if len(missing) > 0 { + failures = append(failures, fmt.Sprintf( + "expected tool call(s) not found: %v (called: %v)", missing, orNone(obs.ToolCalls))) + } + } + + if exp.LLMCalls > 0 && obs.LLMCalls != exp.LLMCalls { + failures = append(failures, fmt.Sprintf( + "expected %d LLM call(s), observed %d", exp.LLMCalls, obs.LLMCalls)) + } + + if exp.MinSteps > 0 && obs.Steps < exp.MinSteps { + failures = append(failures, fmt.Sprintf( + "expected at least %d step(s), observed %d", exp.MinSteps, obs.Steps)) + } + + if exp.MaxSteps > 0 && obs.Steps > exp.MaxSteps { + failures = append(failures, fmt.Sprintf( + "expected at most %d step(s), observed %d", exp.MaxSteps, obs.Steps)) + } + + if len(exp.ExecutionPath) > 0 && !isOrderedSubsequence(exp.ExecutionPath, obs.Path) { + failures = append(failures, fmt.Sprintf( + "expected execution path %v not found (in order) within observed %v", + exp.ExecutionPath, orNone(obs.Path))) + } + + return failures +} + +// isOrderedSubsequence reports whether want appears within have in order (gaps allowed). +func isOrderedSubsequence(want, have []string) bool { + i := 0 + for _, h := range have { + if i < len(want) && h == want[i] { + i++ + } + } + return i == len(want) +} + +func orNone(s []string) []string { + if len(s) == 0 { + return []string{""} + } + return s +} diff --git a/internal/eval/trace_validator_test.go b/internal/eval/trace_validator_test.go new file mode 100644 index 0000000..4a93e29 --- /dev/null +++ b/internal/eval/trace_validator_test.go @@ -0,0 +1,145 @@ +package eval + +import ( + "reflect" + "testing" +) + +func TestValidateTrace(t *testing.T) { + tests := []struct { + name string + exp *TraceExpectation + obs ObservedTrace + wantFailure bool + }{ + { + name: "tool calls present", + exp: &TraceExpectation{ToolCalls: []string{"search"}}, + obs: ObservedTrace{ToolCalls: []string{"search", "calculator"}}, + }, + { + name: "tool call missing", + exp: &TraceExpectation{ToolCalls: []string{"search", "weather"}}, + obs: ObservedTrace{ToolCalls: []string{"search"}}, + wantFailure: true, + }, + { + name: "llm calls exact match", + exp: &TraceExpectation{LLMCalls: 3}, + obs: ObservedTrace{LLMCalls: 3}, + }, + { + name: "llm calls mismatch", + exp: &TraceExpectation{LLMCalls: 3}, + obs: ObservedTrace{LLMCalls: 2}, + wantFailure: true, + }, + { + name: "min steps satisfied", + exp: &TraceExpectation{MinSteps: 2}, + obs: ObservedTrace{Steps: 5}, + }, + { + name: "min steps violated", + exp: &TraceExpectation{MinSteps: 5}, + obs: ObservedTrace{Steps: 2}, + wantFailure: true, + }, + { + name: "max steps violated", + exp: &TraceExpectation{MaxSteps: 3}, + obs: ObservedTrace{Steps: 7}, + wantFailure: true, + }, + { + name: "execution path ordered subsequence", + exp: &TraceExpectation{ExecutionPath: []string{"research", "format"}}, + obs: ObservedTrace{Path: []string{"start", "research", "summarize", "format", "done"}}, + }, + { + name: "execution path out of order", + exp: &TraceExpectation{ExecutionPath: []string{"format", "research"}}, + obs: ObservedTrace{Path: []string{"research", "format"}}, + wantFailure: true, + }, + { + name: "nil expectation passes", + exp: nil, + obs: ObservedTrace{}, + }, + { + name: "multiple failures combine", + exp: &TraceExpectation{ToolCalls: []string{"x"}, LLMCalls: 2, MaxSteps: 1}, + obs: ObservedTrace{ToolCalls: []string{"y"}, LLMCalls: 5, Steps: 9}, + wantFailure: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + failures := ValidateTrace(tc.exp, tc.obs) + if got := len(failures) > 0; got != tc.wantFailure { + t.Fatalf("wantFailure=%v, got failures=%v", tc.wantFailure, failures) + } + }) + } +} + +func TestBuildObservedTrace(t *testing.T) { + tr := &evalTrace{ + ID: "trace-1", + Spans: []*evalSpan{ + {Name: "agk.agent.run"}, + {Name: "agk.llm.generate"}, + {Name: "agk.tool.call", Attributes: map[string]interface{}{"agk.tool.name": "search"}}, + {Name: "agk.llm.generate"}, + }, + } + + obs := buildObservedTrace(tr, []string{"calculator"}) + + if obs.LLMCalls != 2 { + t.Errorf("LLMCalls = %d, want 2", obs.LLMCalls) + } + if obs.Steps != 4 { + t.Errorf("Steps = %d, want 4", obs.Steps) + } + // tools_called ("calculator") plus the tool span ("search"), de-duplicated and ordered. + wantTools := []string{"calculator", "search"} + if !reflect.DeepEqual(obs.ToolCalls, wantTools) { + t.Errorf("ToolCalls = %v, want %v", obs.ToolCalls, wantTools) + } + wantPath := []string{"agk.agent.run", "agk.llm.generate", "agk.tool.call", "agk.llm.generate"} + if !reflect.DeepEqual(obs.Path, wantPath) { + t.Errorf("Path = %v, want %v", obs.Path, wantPath) + } +} + +func TestBuildObservedTraceNilTrace(t *testing.T) { + obs := buildObservedTrace(nil, []string{"search", "search", ""}) + // Only tools_called is available; duplicates and empties removed. + if !reflect.DeepEqual(obs.ToolCalls, []string{"search"}) { + t.Errorf("ToolCalls = %v, want [search]", obs.ToolCalls) + } + if obs.Steps != 0 || obs.LLMCalls != 0 { + t.Errorf("expected zero Steps/LLMCalls with nil trace, got steps=%d llm=%d", obs.Steps, obs.LLMCalls) + } +} + +func TestIsOrderedSubsequence(t *testing.T) { + cases := []struct { + want, have []string + ok bool + }{ + {[]string{"a", "c"}, []string{"a", "b", "c"}, true}, + {[]string{"a", "b", "c"}, []string{"a", "b", "c"}, true}, + {[]string{"c", "a"}, []string{"a", "b", "c"}, false}, + {[]string{"a", "d"}, []string{"a", "b", "c"}, false}, + {[]string{}, []string{"a"}, true}, + } + for _, c := range cases { + if got := isOrderedSubsequence(c.want, c.have); got != c.ok { + t.Errorf("isOrderedSubsequence(%v, %v) = %v, want %v", c.want, c.have, got, c.ok) + } + } +}