Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions forge-cli/runtime/auth_audit_seq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package runtime

import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/initializ/forge/forge-core/auth"
coreruntime "github.com/initializ/forge/forge-core/runtime"
)

// TestAuthAudit_SeqStampedWhenCounterInstalled is the #174 regression
// pin: when the request's ctx carries a SequenceCounter (as it does
// after installSequenceCounterMiddleware wraps the auth chain),
// makeAuthAuditCallback's emit picks the counter up via
// EmitFromContext and stamps seq=1 on auth_verify.
//
// Pre-fix the callback used plain Emit and lost seq entirely.
func TestAuthAudit_SeqStampedWhenCounterInstalled(t *testing.T) {
var buf bytes.Buffer
cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf))

req := httptest.NewRequest(http.MethodPost, "/tasks", nil)
// Simulate the wrapper: install a fresh counter on req.Context().
ctx := coreruntime.WithSequenceCounter(req.Context(), new(coreruntime.SequenceCounter))
req = req.WithContext(ctx)

id := &auth.Identity{UserID: "alice", Source: "oidc"}
cb(req, id, nil, "jwt")

var ev coreruntime.AuditEvent
if err := json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &ev); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if ev.Event != coreruntime.EventAuthVerify {
t.Fatalf("Event = %q, want auth_verify", ev.Event)
}
if ev.Sequence != 1 {
t.Errorf("auth_verify seq = %d, want 1 (counter installed pre-auth)", ev.Sequence)
}
}

// TestAuthAudit_NoSeqWhenCounterAbsent confirms the no-counter path
// stays valid: when nothing installed a counter on ctx, the emit
// produces an event with seq=0 (and the omitempty JSON tag drops the
// field). This pins backward-compat for legacy embedders that wire
// their own server.Server without the wrapper.
func TestAuthAudit_NoSeqWhenCounterAbsent(t *testing.T) {
var buf bytes.Buffer
cb := makeAuthAuditCallback(coreruntime.NewAuditLogger(&buf))

req := httptest.NewRequest(http.MethodPost, "/tasks", nil) // no counter on ctx
cb(req, &auth.Identity{UserID: "alice", Source: "oidc"}, nil, "jwt")

body := strings.TrimSpace(buf.String())
if strings.Contains(body, `"seq"`) {
t.Errorf("seq field must be omitted when no counter is on ctx; got: %s", body)
}
}

// TestSequenceCounterMiddleware_InstallsCounterBeforeNext verifies the
// wrapper installs the counter on r.Context() before delegating to
// the wrapped middleware (and through to the next handler). The next
// handler reads the counter off the context to confirm.
func TestSequenceCounterMiddleware_InstallsCounterBeforeNext(t *testing.T) {
// A passthrough auth middleware — just calls next.
passthroughAuth := func(next http.Handler) http.Handler { return next }

var observed *coreruntime.SequenceCounter
terminal := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
observed = coreruntime.SequenceCounterFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})

wrapped := installSequenceCounterMiddleware(passthroughAuth)(terminal)

req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
wrapped.ServeHTTP(w, req)

if observed == nil {
t.Fatal("terminal handler saw no SequenceCounter on ctx")
}
// Counter starts at 0 and increments to 1 on first NextSequence call.
if got := coreruntime.NextSequence(coreruntime.WithSequenceCounter(context.Background(), observed)); got != 1 {
t.Errorf("first NextSequence on wrapper-installed counter = %d, want 1", got)
}
}

// TestEnsureSequenceCounter_ReusesExisting pins the runner-side
// invariant: the per-A2A-request setup must NOT clobber a counter
// already installed by the auth wrapper. EnsureSequenceCounter
// returns ctx unchanged when the counter is already present.
func TestEnsureSequenceCounter_ReusesExisting(t *testing.T) {
original := new(coreruntime.SequenceCounter)
ctx := coreruntime.WithSequenceCounter(context.Background(), original)
// Advance the counter once so we can detect a reset.
_ = coreruntime.NextSequence(ctx)

ctx2 := coreruntime.EnsureSequenceCounter(ctx)

got := coreruntime.SequenceCounterFromContext(ctx2)
if got != original {
t.Errorf("EnsureSequenceCounter replaced the existing counter; want pointer-equality")
}
// The counter must continue from where it left off (seq=2 next).
if next := coreruntime.NextSequence(ctx2); next != 2 {
t.Errorf("counter reset by EnsureSequenceCounter; got next=%d, want 2", next)
}
}

// TestEnsureSequenceCounter_InstallsFresh covers the --no-auth path
// where the wrapper never ran: EnsureSequenceCounter installs a
// fresh counter so per-A2A-request audit emit still gets seq stamped.
func TestEnsureSequenceCounter_InstallsFresh(t *testing.T) {
ctx := coreruntime.EnsureSequenceCounter(context.Background())
if coreruntime.SequenceCounterFromContext(ctx) == nil {
t.Fatal("EnsureSequenceCounter on empty ctx should install a counter")
}
if next := coreruntime.NextSequence(ctx); next != 1 {
t.Errorf("fresh counter's first NextSequence = %d, want 1", next)
}
}
37 changes: 27 additions & 10 deletions forge-cli/runtime/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ func (r *Runner) Run(ctx context.Context) error {
Host: r.cfg.Host,
ShutdownTimeout: r.cfg.ShutdownTimeout,
AgentCard: card,
AuthMiddleware: auth.Middleware(authCfg),
AuthMiddleware: installSequenceCounterMiddleware(auth.Middleware(authCfg)),
AllowedOrigins: corsOrigins,
RateLimit: rateLimit,
})
Expand Down Expand Up @@ -1189,8 +1189,12 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
// FWS-8: per-invocation sequence counter so every audit event
// emitted on behalf of this request carries a monotonically
// increasing `seq` field — consumers detect gaps + ordering
// at the export side.
ctx = coreruntime.WithSequenceCounter(ctx, new(coreruntime.SequenceCounter))
// at the export side. Reuse the counter
// installSequenceCounterMiddleware put on ctx before auth ran
// (so auth_verify=seq=1 and session_start=seq=2) — see #174.
// EnsureSequenceCounter installs a fresh one if missing
// (--no-auth path / direct test invocations).
ctx = coreruntime.EnsureSequenceCounter(ctx)
sseAcc := coreruntime.NewLLMUsageAccumulator()
ctx = coreruntime.WithLLMUsageAccumulator(ctx, sseAcc)
defer func() {
Expand Down Expand Up @@ -1403,7 +1407,11 @@ func (r *Runner) executeTask(
ctx = coreruntime.WithCorrelationID(ctx, correlationID)
ctx = coreruntime.WithTaskID(ctx, params.ID)
// FWS-8: per-invocation sequence counter (see issue #91 / FWS-8).
ctx = coreruntime.WithSequenceCounter(ctx, new(coreruntime.SequenceCounter))
// EnsureSequenceCounter reuses the counter the auth middleware
// wrapper installed pre-auth so auth_verify lands seq=1 and
// session_start lands seq=2 (#174); installs a fresh one when
// missing (--no-auth path / direct test invocations).
ctx = coreruntime.EnsureSequenceCounter(ctx)
// Per-invocation usage accumulator so AfterLLMCall hooks can fold
// each call's tokens/duration into running totals the response
// handler reads back for X-Forge-* headers + the
Expand Down Expand Up @@ -1686,8 +1694,10 @@ func (r *Runner) registerRESTHandlers(srv *server.Server, executor coreruntime.A
// FWS-8: per-invocation sequence counter so every audit event
// emitted on behalf of this request carries a monotonically
// increasing `seq` field — consumers detect gaps + ordering
// at the export side.
ctx = coreruntime.WithSequenceCounter(ctx, new(coreruntime.SequenceCounter))
// at the export side. Reuse the counter
// installSequenceCounterMiddleware put on ctx before auth ran
// (#174); install fresh on the --no-auth path.
ctx = coreruntime.EnsureSequenceCounter(ctx)
// Pull workflow correlation headers (issue #86 / FWS-2) before
// the accumulator setup so invocation_complete inherits workflow
// tagging via EmitFromContext.
Expand Down Expand Up @@ -2526,10 +2536,17 @@ func makeAuthAuditCallback(auditLogger *coreruntime.AuditLogger) func(*http.Requ
wc := coreruntime.WorkflowContextFromHTTPHeaders(req.Header)
// Same for the per-request tenancy override (#157). When
// absent, the AuditLogger's static deployment-time stamp still
// kicks in via plain Emit so auth events match the rest of
// the stream's org_id / workspace_id columns.
// kicks in so auth events match the rest of the stream's
// org_id / workspace_id columns.
tc := coreruntime.TenancyContextFromHTTPHeaders(req.Header)

// EmitFromContext stamps `seq` from the SequenceCounter the
// installSequenceCounterMiddleware wrapper installed on
// req.Context() before the auth chain ran (#174). The
// runner's per-A2A-request setup downstream calls
// EnsureSequenceCounter and reuses this counter, so
// session_start lands at seq=2 and the per-correlation_id
// sequence is gap-free for FWS-8 consumers.
if err == nil && id != nil {
// Success → auth_verify.
fields := map[string]any{
Expand All @@ -2542,7 +2559,7 @@ func makeAuthAuditCallback(auditLogger *coreruntime.AuditLogger) func(*http.Requ
"path": req.URL.Path,
"remote_addr": req.RemoteAddr,
}
auditLogger.Emit(coreruntime.AuditEvent{
auditLogger.EmitFromContext(req.Context(), coreruntime.AuditEvent{
Event: coreruntime.EventAuthVerify,
CorrelationID: correlationID,
WorkflowID: wc.WorkflowID,
Expand All @@ -2557,7 +2574,7 @@ func makeAuthAuditCallback(auditLogger *coreruntime.AuditLogger) func(*http.Requ
}

// Failure → auth_fail with reason code.
auditLogger.Emit(coreruntime.AuditEvent{
auditLogger.EmitFromContext(req.Context(), coreruntime.AuditEvent{
Event: coreruntime.EventAuthFail,
CorrelationID: correlationID,
WorkflowID: wc.WorkflowID,
Expand Down
43 changes: 43 additions & 0 deletions forge-cli/runtime/sequence_counter_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package runtime

import (
"net/http"

coreruntime "github.com/initializ/forge/forge-core/runtime"
)

// installSequenceCounterMiddleware wraps the auth middleware so the
// per-invocation SequenceCounter is installed on the request context
// BEFORE the auth chain runs. The auth chain's OnAuth callback (which
// emits auth_verify / auth_fail) then sees a counter on its
// req.Context() and stamps seq=1 on the first event. The runner's
// per-A2A-request setup further downstream calls
// coreruntime.EnsureSequenceCounter, which detects the existing
// counter and reuses it — so session_start lands at seq=2, llm_call
// at seq=3, and the per-correlation_id sequence is gap-free for
// FWS-8 consumers.
//
// Before this wrapper, the runner's setup installed the counter at
// the JSON-RPC / REST handler entry, which is downstream of auth.
// The auth callback's audit emits had to use plain Emit() and lost
// seq + trace_id + workflow-correlation tags. See issue #174.
//
// Cost: ~24 bytes per request for the SequenceCounter allocation.
// The wrapper runs even on auth-skipped paths
// (/.well-known/agent-card.json, /healthz). Those paths don't emit
// per-request audit events, so the counter is unused — but allocating
// unconditionally is simpler than threading skip-path knowledge into
// the wrapper, and the allocation is in the same ballpark as the
// request struct itself.
func installSequenceCounterMiddleware(authMW func(http.Handler) http.Handler) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
// Compose once: the auth middleware wraps next; we wrap THAT
// composition so the seq counter is installed before auth sees
// the request.
composed := authMW(next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := coreruntime.EnsureSequenceCounter(r.Context())
composed.ServeHTTP(w, r.WithContext(ctx))
})
}
}
14 changes: 14 additions & 0 deletions forge-core/runtime/audit_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,17 @@ func NextSequence(ctx context.Context) int64 {
}
return c.Add(1)
}

// EnsureSequenceCounter returns ctx unchanged when it already carries a
// SequenceCounter; otherwise it returns a new ctx with a fresh counter
// installed. Use at any invocation-entry point that may run downstream
// of an upstream middleware which already installed a counter — e.g.,
// the runner's per-A2A-request setup runs after the auth middleware
// (which installs a counter so auth_verify lands seq=1) and must not
// clobber it. See issue #174.
func EnsureSequenceCounter(ctx context.Context) context.Context {
if SequenceCounterFromContext(ctx) != nil {
return ctx
}
return WithSequenceCounter(ctx, new(SequenceCounter))
}
Loading