diff --git a/cookiecrypt.go b/cookiecrypt.go index 3722624..0960589 100644 --- a/cookiecrypt.go +++ b/cookiecrypt.go @@ -1,11 +1,14 @@ package cookiecrypt import ( + "bufio" "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/base64" "fmt" + "io" + "net" "net/http" "strings" @@ -133,6 +136,10 @@ func (cc *CookieCrypt) decrypt(ciphertext string) (string, error) { return string(plaintext), nil } +// cookieInterceptResponseWriter wraps the downstream ResponseWriter and +// preserves optional interfaces such as Hijacker, Flusher, and ReadFrom. +// The Unwrap/Hijack support is needed so Caddy's ResponseController can +// traverse wrapper layers and still upgrade WebSocket/CONNECT connections. type cookieInterceptResponseWriter struct { http.ResponseWriter logger *zap.Logger @@ -196,6 +203,56 @@ func (w *cookieInterceptResponseWriter) WriteHeader(statusCode int) { w.ResponseWriter.WriteHeader(statusCode) } +// Hijack delegates through wrapper layers to the first real http.Hijacker. +// This is required because Caddy may wrap the original ResponseWriter in +// layers such as headers.responseWriterWrapper and caddyhttp.responseRecorder. +func (w *cookieInterceptResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return hijackResponseWriter(w.ResponseWriter, w.logger) +} + +// Unwrap exposes the next underlying ResponseWriter so http.NewResponseController +// can traverse the wrapper chain and find optional interfaces like Hijacker. +func (w *cookieInterceptResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +// hijackResponseWriter walks nested Unwrap() wrappers until it finds a Hijacker. +// If none is found, it returns http.ErrNotSupported to preserve standard Go semantics. +func hijackResponseWriter(rw http.ResponseWriter, logger *zap.Logger) (net.Conn, *bufio.ReadWriter, error) { + chain := []string{} + for { + chain = append(chain, fmt.Sprintf("%T", rw)) + if hj, ok := rw.(http.Hijacker); ok { + logger.Debug("found hijack-capable writer in chain", + zap.String("writer_type", fmt.Sprintf("%T", rw)), + zap.Strings("writer_chain", chain), + ) + return hj.Hijack() + } + uw, ok := rw.(interface{ Unwrap() http.ResponseWriter }) + if !ok { + logger.Error("no hijack support found in response writer chain", + zap.Strings("writer_chain", chain), + ) + return nil, nil, http.ErrNotSupported + } + rw = uw.Unwrap() + } +} + +func (w *cookieInterceptResponseWriter) Flush() { + if fl, ok := w.ResponseWriter.(http.Flusher); ok { + fl.Flush() + } +} + +func (w *cookieInterceptResponseWriter) ReadFrom(r io.Reader) (int64, error) { + if rf, ok := w.ResponseWriter.(io.ReaderFrom); ok { + return rf.ReadFrom(r) + } + return io.Copy(w, r) +} + func (cc CookieCrypt) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { for _, c := range r.Cookies() { if !strings.HasPrefix(c.Name, cc.Prefix) { @@ -232,4 +289,7 @@ var ( _ caddy.Validator = (*CookieCrypt)(nil) _ caddyhttp.MiddlewareHandler = (*CookieCrypt)(nil) _ caddyfile.Unmarshaler = (*CookieCrypt)(nil) + _ http.Hijacker = (*cookieInterceptResponseWriter)(nil) + _ http.Flusher = (*cookieInterceptResponseWriter)(nil) + _ io.ReaderFrom = (*cookieInterceptResponseWriter)(nil) )