diff --git a/gateway/internal/inject/inject.go b/gateway/internal/inject/inject.go index 2003dc0..85e8572 100644 --- a/gateway/internal/inject/inject.go +++ b/gateway/internal/inject/inject.go @@ -16,6 +16,17 @@ import ( "sync" ) +const ( + // compressMinBytes is the smallest response body we'll bother gzipping. + // Below this, gzip overhead can exceed savings. + compressMinBytes = 1024 + + // cacheMaxEntries bounds the in-memory cache so a runaway URL space + // can't blow up memory. Static exports rarely exceed a few hundred + // HTML routes, so this is generous. + cacheMaxEntries = 1000 +) + // Middleware wraps an http.Handler and injects the REP script tag into HTML responses. type Middleware struct { // next is the upstream handler (reverse proxy or file server). @@ -28,9 +39,27 @@ type Middleware struct { mu sync.RWMutex logger *slog.Logger + + // cache stores fully-processed (injected, optionally gzipped) responses + // keyed by request path. nil means caching is disabled — the default. + // Per REP-RFC-0001 §4.3 the gateway MUST NOT cache when SENSITIVE vars + // are present (the encrypted blob may rotate), so the server only opts + // in via EnableCache when it's safe. + cache map[string]*cacheEntry + cacheMu sync.RWMutex } -// New creates a new injection middleware. +// cacheEntry is a fully-processed response stored under a path key. +// Both encodings are pre-computed so cache hits never re-compress. +type cacheEntry struct { + statusCode int + headers http.Header + identity []byte // pre-injected identity-encoded bytes + gzipped []byte // pre-compressed bytes (nil if compression wasn't worthwhile) +} + +// New creates a new injection middleware. Caching is off by default; +// callers opt in via EnableCache when no SENSITIVE variables are present. func New(next http.Handler, scriptTag string, logger *slog.Logger) *Middleware { return &Middleware{ next: next, @@ -39,11 +68,32 @@ func New(next http.Handler, scriptTag string, logger *slog.Logger) *Middleware { } } -// UpdateScriptTag replaces the script tag (used during hot reload). +// EnableCache turns on response caching for processed HTML. +// +// Per REP-RFC-0001 §4.3, the gateway MUST NOT cache injected HTML when +// SENSITIVE variables are present (the encrypted blob may rotate). Callers +// must only enable this when no SENSITIVE vars are configured, and should +// also leave it off when hot-reload is active. +func (m *Middleware) EnableCache() { + m.cacheMu.Lock() + if m.cache == nil { + m.cache = make(map[string]*cacheEntry) + } + m.cacheMu.Unlock() +} + +// UpdateScriptTag replaces the script tag (used during hot reload) and +// invalidates any cached responses (they contain the previous tag). func (m *Middleware) UpdateScriptTag(scriptTag string) { m.mu.Lock() m.scriptTag = []byte(scriptTag) m.mu.Unlock() + + m.cacheMu.Lock() + if m.cache != nil { + m.cache = make(map[string]*cacheEntry) + } + m.cacheMu.Unlock() } // ServeHTTP intercepts HTML responses and injects the REP payload. @@ -56,11 +106,26 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Strip Accept-Encoding from the request so the upstream always responds - // with identity encoding. This ensures we can reliably search for - // in the response body for injection. + // Capture the client's preferred response encoding before we strip it. + // We strip Accept-Encoding so the upstream always returns identity (we + // need to byte-search for ); on the way back out we honour the + // client's original preference and re-compress if appropriate. + clientAccepts := r.Header.Get("Accept-Encoding") r.Header.Del("Accept-Encoding") + // Cache lookup — only for GET, only when caching is enabled, only for + // requests that don't carry per-user identity (Cookie/Authorization). + // The cache is keyed by request URI (path + query) so URLs that vary + // by query don't collide. + cacheKey := r.URL.RequestURI() + if r.Method == http.MethodGet && requestIsCacheable(r) { + if entry := m.cacheGet(cacheKey); entry != nil { + m.writeCached(w, entry, clientAccepts) + m.logger.Debug("rep.inject.cache_hit", "path", cacheKey) + return + } + } + // Wrap the response writer to capture the response. rec := &responseRecorder{ ResponseWriter: w, @@ -72,6 +137,19 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Serve the request to the upstream handler. m.next.ServeHTTP(rec, r) + // Statuses that MUST NOT carry a body (RFC 9110 §15) — pass the + // upstream response through unmodified. Injecting into a 304 / 204 / + // 1xx would generate a non-empty body and a Content-Length, violating + // HTTP semantics and breaking downstream conditional-request flows. + if isBodylessStatus(rec.statusCode) { + copyHeaders(w.Header(), rec.header) + w.WriteHeader(rec.statusCode) + if _, err := w.Write(rec.body.Bytes()); err != nil { + m.logger.Debug("rep.inject.write_error", "path", r.URL.Path, "error", err) + } + return + } + // Check if the response is HTML. contentType := rec.header.Get("Content-Type") if !isHTML(contentType) { @@ -114,16 +192,76 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Inject the REP script tag into the HTML. injected := injectIntoHTML(body, tag) - copyHeaders(w.Header(), rec.header) + // Compute the gzipped variant only when we'll actually use it: + // - this client accepts gzip (we'll send it now), OR + // - caching is enabled (we may send it to a future client that does). + // Otherwise the work would be thrown away. + var gzipped []byte + if len(injected) >= compressMinBytes { + needsGzip := acceptsGzip(clientAccepts) || m.cacheActive() + if needsGzip { + var err error + gzipped, err = gzipCompress(injected) + if err != nil { + m.logger.Debug("rep.inject.gzip_error", "path", r.URL.Path, "error", err) + gzipped = nil + } + } + } - // Update Content-Length to reflect the injected content. - w.Header().Set("Content-Length", strconv.Itoa(len(injected))) + // Build the response headers we'll send to the client. Strip + // Content-Encoding/Length (we own them now). Strip ETag and + // Last-Modified because the upstream computed them for the + // pre-injection body — keeping them would mislead conditional-request + // flows. Announce that the body varies on Accept-Encoding so caches + // don't serve the wrong form. + respHeader := make(http.Header) + copyHeaders(respHeader, rec.header) + respHeader.Del("Content-Encoding") + respHeader.Del("Content-Length") + respHeader.Del("ETag") + respHeader.Del("Last-Modified") + addVary(respHeader, "Accept-Encoding") + + // Cache eligibility — many guards because the cache is keyed by URI + // only and content can be per-user in proxy mode: + // + // - GET only + // - 200 OK only + // - request has no Cookie/Authorization (would otherwise be per-user) + // - response has no Set-Cookie (per-user state being established) + // - response is not marked Cache-Control: private/no-store/no-cache + // - response doesn't Vary by Cookie/Authorization + if r.Method == http.MethodGet && + rec.statusCode == http.StatusOK && + requestIsCacheable(r) && + responseIsCacheable(rec.header) { + m.cachePut(cacheKey, &cacheEntry{ + statusCode: rec.statusCode, + headers: respHeader.Clone(), + identity: injected, + gzipped: gzipped, + }) + } - // Remove Content-Encoding since we've modified the body. - w.Header().Del("Content-Encoding") + // Pick the encoding the client wants and ship it. + outBody, outEncoding := pickVariant(injected, gzipped, clientAccepts) + if outEncoding != "" { + respHeader.Set("Content-Encoding", outEncoding) + } + respHeader.Set("Content-Length", strconv.Itoa(len(outBody))) + dst := w.Header() + for k := range dst { + dst.Del(k) + } + for k, values := range respHeader { + for _, value := range values { + dst.Add(k, value) + } + } w.WriteHeader(rec.statusCode) - if _, err := w.Write(injected); err != nil { + if _, err := w.Write(outBody); err != nil { m.logger.Debug("rep.inject.write_error", "path", r.URL.Path, "error", err) } @@ -131,7 +269,153 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { "path", r.URL.Path, "original_size", len(body), "injected_size", len(injected), + "sent_size", len(outBody), + "encoding", outEncoding, + ) +} + +// writeCached emits a cached entry, picking identity or gzip per the +// client's Accept-Encoding. +func (m *Middleware) writeCached(w http.ResponseWriter, entry *cacheEntry, clientAccepts string) { + body, encoding := pickVariant(entry.identity, entry.gzipped, clientAccepts) + + dst := w.Header() + for k := range dst { + dst.Del(k) + } + for k, values := range entry.headers { + for _, value := range values { + dst.Add(k, value) + } + } + if encoding != "" { + dst.Set("Content-Encoding", encoding) + } + dst.Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(entry.statusCode) + _, _ = w.Write(body) +} + +// cacheActive reports whether caching is enabled. +func (m *Middleware) cacheActive() bool { + m.cacheMu.RLock() + defer m.cacheMu.RUnlock() + return m.cache != nil +} + +func (m *Middleware) cacheGet(path string) *cacheEntry { + m.cacheMu.RLock() + defer m.cacheMu.RUnlock() + if m.cache == nil { + return nil + } + return m.cache[path] +} + +func (m *Middleware) cachePut(path string, entry *cacheEntry) { + m.cacheMu.Lock() + defer m.cacheMu.Unlock() + if m.cache == nil { + return + } + if len(m.cache) >= cacheMaxEntries { + // Bounded; drop new additions until the next UpdateScriptTag clears. + // In practice this never trips for static exports. + return + } + m.cache[path] = entry +} + +// pickVariant chooses identity or gzipped based on the client's Accept-Encoding. +// Falls back to identity if we didn't pre-compute gzip for this response. +func pickVariant(identity, gzipped []byte, accept string) (body []byte, encoding string) { + if len(gzipped) > 0 && acceptsGzip(accept) { + return gzipped, "gzip" + } + return identity, "" +} + +// acceptsGzip parses Accept-Encoding and returns true if gzip is acceptable. +// +// Per RFC 9110 §12.5.3, an explicit coding parameter takes precedence over +// the `*` wildcard. So `gzip;q=0, *;q=0.5` rejects gzip even though `*` +// would otherwise allow it. +func acceptsGzip(accept string) bool { + if accept == "" { + return false + } + + var ( + explicitGzipSeen bool + explicitGzipQ float64 + wildcardSeen bool + wildcardQ float64 ) + + for _, part := range strings.Split(accept, ",") { + token := strings.TrimSpace(part) + if token == "" { + continue + } + name, params, _ := strings.Cut(token, ";") + name = strings.ToLower(strings.TrimSpace(name)) + if name != "gzip" && name != "*" { + continue + } + q := 1.0 + for _, p := range strings.Split(params, ";") { + p = strings.TrimSpace(p) + if k, v, ok := strings.Cut(p, "="); ok && strings.EqualFold(strings.TrimSpace(k), "q") { + if parsed, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil { + q = parsed + } + } + } + if name == "gzip" { + explicitGzipSeen = true + explicitGzipQ = q + } else { // "*" + wildcardSeen = true + wildcardQ = q + } + } + + switch { + case explicitGzipSeen: + return explicitGzipQ > 0 + case wildcardSeen: + return wildcardQ > 0 + default: + return false + } +} + +// addVary appends a token to the Vary header if it isn't already present. +// Handles both repeated `Vary:` headers and single comma-separated values +// (`Vary: Origin, Accept-Encoding`), so we never duplicate a token. +func addVary(h http.Header, value string) { + for _, v := range h.Values("Vary") { + for _, existing := range strings.Split(v, ",") { + if strings.EqualFold(strings.TrimSpace(existing), value) { + return + } + } + } + h.Add("Vary", value) +} + +// gzipCompress returns the gzip-encoded form of body. +func gzipCompress(body []byte) ([]byte, error) { + var buf bytes.Buffer + w := gzip.NewWriter(&buf) + if _, err := w.Write(body); err != nil { + _ = w.Close() + return nil, err + } + if err := w.Close(); err != nil { + return nil, err + } + return buf.Bytes(), nil } // decompressBody decompresses a response body based on Content-Encoding. @@ -250,6 +534,61 @@ func isInsideComment(html []byte, pos int, open, close []byte) bool { return false } +// isBodylessStatus reports whether an HTTP status code MUST NOT carry a +// response body, per RFC 9110 §15. The middleware bypasses injection, +// compression, and caching for these so we don't fabricate a body that +// breaks downstream conditional-request flows. +func isBodylessStatus(status int) bool { + if status >= 100 && status < 200 { + return true + } + switch status { + case http.StatusNoContent, http.StatusNotModified: + return true + } + return false +} + +// requestIsCacheable reports whether a request can safely use the path- +// keyed in-memory cache. Skipped if the request carries identity headers +// that would normally personalise the response. +func requestIsCacheable(r *http.Request) bool { + if r.Header.Get("Cookie") != "" { + return false + } + if r.Header.Get("Authorization") != "" { + return false + } + return true +} + +// responseIsCacheable reports whether the upstream response can be +// stored in the in-memory cache. Honours upstream Cache-Control +// directives and rejects responses that vary by per-user headers. +func responseIsCacheable(h http.Header) bool { + for _, v := range h.Values("Set-Cookie") { + _ = v + return false // any Set-Cookie disqualifies + } + for _, v := range h.Values("Cache-Control") { + for _, directive := range strings.Split(v, ",") { + d := strings.ToLower(strings.TrimSpace(directive)) + if d == "private" || d == "no-store" || d == "no-cache" { + return false + } + } + } + for _, v := range h.Values("Vary") { + for _, token := range strings.Split(v, ",") { + t := strings.ToLower(strings.TrimSpace(token)) + if t == "cookie" || t == "authorization" || t == "*" { + return false + } + } + } + return true +} + // isWebSocketUpgrade reports whether the request is a WebSocket upgrade. func isWebSocketUpgrade(r *http.Request) bool { return strings.EqualFold(r.Header.Get("Connection"), "upgrade") && diff --git a/gateway/internal/inject/inject_perf_test.go b/gateway/internal/inject/inject_perf_test.go new file mode 100644 index 0000000..90e04bc --- /dev/null +++ b/gateway/internal/inject/inject_perf_test.go @@ -0,0 +1,581 @@ +package inject + +import ( + "bytes" + "compress/gzip" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" +) + +// largeHTMLBody returns an HTML body just over compressMinBytes so the +// middleware will compress it for clients that accept gzip. +func largeHTMLBody() string { + var b strings.Builder + b.WriteString("t") + b.WriteString(strings.Repeat("Lorem ipsum dolor sit amet. ", 100)) + b.WriteString("") + return b.String() +} + +func TestMiddleware_GzipEncodingWhenAccepted(t *testing.T) { + html := largeHTMLBody() + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // We expect the middleware to strip Accept-Encoding before the + // upstream sees it; if a value sneaks through we want to know. + if got := r.Header.Get("Accept-Encoding"); got != "" { + t.Errorf("upstream should not see Accept-Encoding, got %q", got) + } + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(html)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Accept-Encoding", "gzip, deflate, br") + rec := httptest.NewRecorder() + + m.ServeHTTP(rec, req) + + if got := rec.Header().Get("Content-Encoding"); got != "gzip" { + t.Fatalf("expected Content-Encoding=gzip, got %q", got) + } + + if got := rec.Header().Get("Vary"); !containsToken(got, "Accept-Encoding") { + t.Errorf("expected Vary to include Accept-Encoding, got %q", got) + } + + // The body should be valid gzip and decompress to the original-with-tag. + gz, err := gzip.NewReader(bytes.NewReader(rec.Body.Bytes())) + if err != nil { + t.Fatalf("response body is not valid gzip: %v", err) + } + defer func() { _ = gz.Close() }() + + decompressed, err := io.ReadAll(gz) + if err != nil { + t.Fatalf("decompressing body: %v", err) + } + if !strings.Contains(string(decompressed), `id="__rep__"`) { + t.Error("decompressed body should contain the injected script tag") + } + + // Compressed wire size should be smaller than the original — that is + // the whole point. + if len(rec.Body.Bytes()) >= len(html) { + t.Errorf("compressed body (%d) should be smaller than original (%d)", + len(rec.Body.Bytes()), len(html)) + } +} + +func TestMiddleware_NoGzipWhenClientDoesNotAccept(t *testing.T) { + html := largeHTMLBody() + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(html)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + // No Accept-Encoding header. + rec := httptest.NewRecorder() + + m.ServeHTTP(rec, req) + + if got := rec.Header().Get("Content-Encoding"); got != "" { + t.Fatalf("expected no Content-Encoding, got %q", got) + } + if !strings.Contains(rec.Body.String(), `id="__rep__"`) { + t.Error("identity body should still contain the injected tag") + } +} + +func TestMiddleware_NoGzipForSmallBodies(t *testing.T) { + // Shorter than compressMinBytes — gzip overhead would exceed savings. + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(`Hi`)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Accept-Encoding", "gzip") + rec := httptest.NewRecorder() + + m.ServeHTTP(rec, req) + + if got := rec.Header().Get("Content-Encoding"); got != "" { + t.Errorf("small response should not be gzipped, got encoding %q", got) + } +} + +func TestMiddleware_GzipQZeroRejection(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(largeHTMLBody())) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + // Client says "I'd accept anything except gzip." + req.Header.Set("Accept-Encoding", "gzip;q=0, identity") + rec := httptest.NewRecorder() + + m.ServeHTTP(rec, req) + + if got := rec.Header().Get("Content-Encoding"); got != "" { + t.Errorf("gzip;q=0 should not be gzipped, got encoding %q", got) + } +} + +func TestMiddleware_CacheDisabledByDefault(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(`x`)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + for i := 0; i < 3; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + m.ServeHTTP(rec, req) + } + + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("cache disabled by default — expected 3 upstream calls, got %d", got) + } +} + +func TestMiddleware_CacheHitSkipsUpstream(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(`cached`)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + // First request populates the cache. + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + rec1 := httptest.NewRecorder() + m.ServeHTTP(rec1, req1) + body1 := rec1.Body.String() + + // Second request should hit the cache. + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + rec2 := httptest.NewRecorder() + m.ServeHTTP(rec2, req2) + body2 := rec2.Body.String() + + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("expected 1 upstream call (cache hit on second request), got %d", got) + } + if body1 != body2 { + t.Errorf("cached response should be byte-identical: %q vs %q", body1, body2) + } + if !strings.Contains(body2, `id="__rep__"`) { + t.Error("cached response should still contain the injected tag") + } +} + +func TestMiddleware_CacheHitRespectsAcceptEncoding(t *testing.T) { + html := largeHTMLBody() + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(html)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + // Populate cache with a no-gzip request. + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + rec1 := httptest.NewRecorder() + m.ServeHTTP(rec1, req1) + if got := rec1.Header().Get("Content-Encoding"); got != "" { + t.Errorf("first (no-AE) request should be identity, got %q", got) + } + + // Same path, this time with gzip — should hit cache and serve gzipped variant. + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("Accept-Encoding", "gzip") + rec2 := httptest.NewRecorder() + m.ServeHTTP(rec2, req2) + if got := rec2.Header().Get("Content-Encoding"); got != "gzip" { + t.Errorf("cache hit should serve gzip when client asks for it, got %q", got) + } +} + +func TestMiddleware_UpdateScriptTagInvalidatesCache(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + // Populate. + rec := httptest.NewRecorder() + m.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + + // Update tag — cache should be invalidated. + newTag := `` + m.UpdateScriptTag(newTag) + + rec = httptest.NewRecorder() + m.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + + if got := atomic.LoadInt32(&calls); got != 2 { + t.Errorf("UpdateScriptTag should invalidate the cache (expected 2 upstream calls, got %d)", got) + } + if !strings.Contains(rec.Body.String(), `"X":"1"`) { + t.Error("response after UpdateScriptTag should reflect the new tag") + } +} + +func TestMiddleware_CacheSkipsSetCookieResponses(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + w.Header().Add("Set-Cookie", "session=abc; Path=/") + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + for i := 0; i < 3; i++ { + rec := httptest.NewRecorder() + m.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + } + + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("Set-Cookie responses must not be cached — expected 3 upstream calls, got %d", got) + } +} + +func TestMiddleware_CacheSkipsNon200(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`fail`)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + for i := 0; i < 2; i++ { + rec := httptest.NewRecorder() + m.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + } + + if got := atomic.LoadInt32(&calls); got != 2 { + t.Errorf("non-200 responses must not be cached — expected 2 upstream calls, got %d", got) + } +} + +func TestMiddleware_CacheKeyIncludesQueryString(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(`` + r.URL.RawQuery + ``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + for _, q := range []string{"a=1", "a=2", "a=1"} { + rec := httptest.NewRecorder() + m.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/page?"+q, nil)) + } + + // Two distinct queries → two upstream calls. The third repeats `a=1` + // so it should hit the cache. + if got := atomic.LoadInt32(&calls); got != 2 { + t.Errorf("cache key should distinguish query strings: expected 2 upstream calls, got %d", got) + } +} + +func TestMiddleware_CacheSkipsCookieRequests(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + for i := 0; i < 3; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Cookie", "session=abc") + rec := httptest.NewRecorder() + m.ServeHTTP(rec, req) + } + + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("requests with Cookie must not be cached: expected 3 upstream calls, got %d", got) + } +} + +func TestMiddleware_CacheSkipsAuthorizationRequests(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + for i := 0; i < 3; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer xyz") + rec := httptest.NewRecorder() + m.ServeHTTP(rec, req) + } + + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("requests with Authorization must not be cached: expected 3 upstream calls, got %d", got) + } +} + +func TestMiddleware_CacheSkipsCacheControlPrivate(t *testing.T) { + cases := []string{"private", "no-store", "no-cache", "private, max-age=60"} + for _, cc := range cases { + t.Run(cc, func(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Cache-Control", cc) + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + for i := 0; i < 2; i++ { + rec := httptest.NewRecorder() + m.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + } + + if got := atomic.LoadInt32(&calls); got != 2 { + t.Errorf("Cache-Control %q should disable caching, got %d upstream calls", cc, got) + } + }) + } +} + +func TestMiddleware_CacheSkipsVaryByCookie(t *testing.T) { + var calls int32 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Vary", "Origin, Cookie") + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + m.EnableCache() + + for i := 0; i < 2; i++ { + rec := httptest.NewRecorder() + m.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + } + + if got := atomic.LoadInt32(&calls); got != 2 { + t.Errorf("Vary by Cookie should disable caching: expected 2 upstream calls, got %d", got) + } +} + +func TestMiddleware_StripsETagAndLastModified(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("ETag", `"upstream-etag-abc"`) + w.Header().Set("Last-Modified", "Wed, 21 Oct 2026 07:28:00 GMT") + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + rec := httptest.NewRecorder() + m.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + + if got := rec.Header().Get("ETag"); got != "" { + t.Errorf("ETag should be stripped (the body has been modified), got %q", got) + } + if got := rec.Header().Get("Last-Modified"); got != "" { + t.Errorf("Last-Modified should be stripped (the body has been modified), got %q", got) + } +} + +func TestMiddleware_BodylessStatusPassThrough(t *testing.T) { + cases := []int{ + http.StatusContinue, // 100 + http.StatusNoContent, // 204 + http.StatusNotModified, // 304 + http.StatusSwitchingProtocols, // 101 (1xx range) + } + for _, status := range cases { + t.Run(http.StatusText(status), func(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") // tempting to inject! + w.Header().Set("ETag", `"keep-me"`) + w.WriteHeader(status) + // No body — bodyless statuses must not write one. + }) + + m := New(upstream, testScriptTag, slog.Default()) + rec := httptest.NewRecorder() + m.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + + if rec.Code != status { + t.Errorf("status passed through wrong: got %d, want %d", rec.Code, status) + } + if rec.Body.Len() != 0 { + t.Errorf("bodyless status %d gained a body of %d bytes", status, rec.Body.Len()) + } + // Validators on bodyless responses should pass through untouched + // (they describe the upstream's representation, which we didn't + // modify because we didn't write a body). + if got := rec.Header().Get("ETag"); got != `"keep-me"` { + t.Errorf("bodyless response should preserve upstream ETag, got %q", got) + } + }) + } +} + +func TestMiddleware_VaryHeaderPresent(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + m.ServeHTTP(rec, req) + + if got := rec.Header().Get("Vary"); !containsToken(got, "Accept-Encoding") { + t.Errorf("Vary header should include Accept-Encoding, got %q", got) + } +} + +func TestAcceptsGzip(t *testing.T) { + cases := map[string]bool{ + "": false, + "identity": false, + "gzip": true, + "gzip, deflate": true, + "deflate, gzip;q=0.8": true, + "gzip;q=0": false, + "identity, gzip ; q = 0": false, + "*": true, + "*;q=0": false, + "deflate, *;q=0.5": true, + // RFC 9110 §12.5.3: an explicit coding parameter takes precedence + // over `*`. These two cases caught a regression early on. + "gzip;q=0, *;q=0.5": false, // explicit gzip rejection wins + "*;q=0, gzip": true, // explicit gzip allowance wins + } + for header, want := range cases { + got := acceptsGzip(header) + if got != want { + t.Errorf("acceptsGzip(%q) = %v, want %v", header, got, want) + } + } +} + +func TestAddVary_DoesNotDuplicate(t *testing.T) { + cases := []struct { + name string + initial []string // existing Vary header values (multiple = repeated header) + add string + want []string + }{ + { + name: "empty", + initial: nil, + add: "Accept-Encoding", + want: []string{"Accept-Encoding"}, + }, + { + name: "single matching value", + initial: []string{"Accept-Encoding"}, + add: "Accept-Encoding", + want: []string{"Accept-Encoding"}, + }, + { + name: "comma-separated existing", + initial: []string{"Origin, Accept-Encoding"}, + add: "Accept-Encoding", + want: []string{"Origin, Accept-Encoding"}, + }, + { + name: "case-insensitive match", + initial: []string{"origin, accept-encoding"}, + add: "Accept-Encoding", + want: []string{"origin, accept-encoding"}, + }, + { + name: "different value adds", + initial: []string{"Origin"}, + add: "Accept-Encoding", + want: []string{"Origin", "Accept-Encoding"}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + h := http.Header{} + for _, v := range tc.initial { + h.Add("Vary", v) + } + addVary(h, tc.add) + + got := h.Values("Vary") + if len(got) != len(tc.want) { + t.Fatalf("Vary header count = %d, want %d (got %v)", len(got), len(tc.want), got) + } + for i, want := range tc.want { + if got[i] != want { + t.Errorf("Vary[%d] = %q, want %q", i, got[i], want) + } + } + }) + } +} + +// containsToken reports whether a comma-separated header value contains the +// given token (case-insensitive). Used to assert Vary contents. +func containsToken(header, token string) bool { + for _, part := range strings.Split(header, ",") { + if strings.EqualFold(strings.TrimSpace(part), token) { + return true + } + } + return false +} diff --git a/gateway/internal/server/server.go b/gateway/internal/server/server.go index 9f7d4b9..ac9b7a2 100644 --- a/gateway/internal/server/server.go +++ b/gateway/internal/server/server.go @@ -123,6 +123,16 @@ func New(cfg *config.Config, logger *slog.Logger, version string) (*Server, erro // Create the injection middleware wrapping the upstream. s.injector = inject.New(upstream, scriptTag, logger) + // Enable response caching when it's safe: + // - hot-reload off (file content can change at runtime when on) + // - no SENSITIVE vars present (per REP-RFC-0001 §4.3, the gateway + // MUST NOT cache injected HTML if the encrypted blob may rotate) + if !cfg.HotReload && len(vars.Sensitive) == 0 { + s.injector.EnableCache() + logger.Info("rep.inject.cache_enabled", + "reason", "no SENSITIVE vars and hot-reload disabled") + } + // Step 9: Create hot reload hub if enabled. if cfg.HotReload { s.hotReloadHub = hotreload.NewHub(logger)