Skip to content
Open
178 changes: 162 additions & 16 deletions protocol/http/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,44 @@ package http

import (
std_bufio "bufio"
"bytes"
"context"
"io"
"net"
"net/http"
"sort"
"strings"
"time"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
)

const defaultProxyAuthRetryTimeout = 5 * time.Second

type HTTPServerOptions struct {
ProxyAuthRetryTimeout time.Duration
Logger logger.ContextLogger
}

func normalizeHTTPServerOptions(options HTTPServerOptions) HTTPServerOptions {
if options.ProxyAuthRetryTimeout <= 0 {
options.ProxyAuthRetryTimeout = defaultProxyAuthRetryTimeout
}
if options.Logger == nil {
options.Logger = logger.NOP()
}
return options
}

func HandleConnectionEx(
ctx context.Context,
conn net.Conn,
Expand All @@ -28,37 +49,68 @@ func HandleConnectionEx(
source M.Socksaddr,
onClose N.CloseHandlerFunc,
) error {
return HandleConnectionExWithOptions(ctx, conn, reader, authenticator, handler, source, onClose, HTTPServerOptions{})
}

func HandleConnectionExWithOptions(
ctx context.Context,
conn net.Conn,
reader *std_bufio.Reader,
authenticator *auth.Authenticator,
handler N.TCPConnectionHandlerEx,
source M.Socksaddr,
onClose N.CloseHandlerFunc,
options HTTPServerOptions,
) error {
options = normalizeHTTPServerOptions(options)
missingProxyAuthorizationRetried := false
waitingRetryProxyAuthentication := false
for {
request, err := ReadRequest(reader)
if err != nil {
return E.Cause(err, "read http request")
}
printRequestHeaders(ctx, options.Logger, request)
if waitingRetryProxyAuthentication {
waitingRetryProxyAuthentication = false
err = conn.SetReadDeadline(time.Time{})
if err != nil {
return E.Cause(err, "clear retry-proxy-authentication timeout")
}
}
retryMissingProxyAuthorization := shouldRetryMissingProxyAuthorization(request)
if authenticator != nil {
username, password, authOk := ParseBasicAuth(request.Header.Get("Proxy-Authorization"))
authOk = authOk && authenticator.Verify(username, password)
if authOk {
ctx = auth.ContextWithUser(ctx, username)
} else {
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" && request.ContentLength == 0
// Since no one else is using the library, use a fixed realm until rewritten
headers := []string{"Proxy-Authenticate", `Basic realm="sing-box" charset="UTF-8"`}
if !keepAlive {
headers = append(headers, "Connection", "close")
}
err = responseWith(request, http.StatusProxyAuthRequired, headers...).Write(conn)
proxyAuthRequiredResponse := responseWith(
request, http.StatusProxyAuthRequired,
"Proxy-Authenticate", `Basic realm="sing-box", charset="UTF-8"`,
)
printResponseHeaders(ctx, options.Logger, proxyAuthRequiredResponse)
err = writeResponseBuffered(conn, proxyAuthRequiredResponse)
if err != nil {
return err
}
if keepAlive {
continue
return E.Cause(err, "write proxy authentication required response")
}
authorization := request.Header.Get("Proxy-Authorization")
switch {
case username != "":
return E.New("http: authentication failed, username=", username, ", password=", password)
case authorization != "":
return E.New("http: authentication failed, Proxy-Authorization=", authorization)
default:
} else {
if retryMissingProxyAuthorization && !missingProxyAuthorizationRetried {
missingProxyAuthorizationRetried = true
err = conn.SetReadDeadline(time.Now().Add(options.ProxyAuthRetryTimeout))
if err != nil {
return E.Cause(err, "set retry-proxy-authentication timeout")
}
waitingRetryProxyAuthentication = true
continue
}
return E.New("http: authentication failed, no Proxy-Authorization header")
}
}
Expand Down Expand Up @@ -125,7 +177,7 @@ func HandleConnectionEx(
}
return bufio.CopyConn(ctx, conn, serverConn)
} else {
err = handleHTTPConnection(ctx, handler, conn, request, source)
err = handleHTTPConnection(ctx, handler, conn, request, source, options.Logger)
if err != nil {
return err
}
Expand All @@ -137,9 +189,11 @@ func handleHTTPConnection(
ctx context.Context,
handler N.TCPConnectionHandlerEx,
conn net.Conn,
request *http.Request, source M.Socksaddr,
request *http.Request,
source M.Socksaddr,
contextLogger logger.ContextLogger,
) error {
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
keepAlive := isProxyKeepAlive(request)
request.RequestURI = ""

removeHopByHopHeaders(request.Header)
Expand All @@ -152,7 +206,9 @@ func handleHTTPConnection(
}

if request.URL.Scheme == "" || request.URL.Host == "" {
return responseWith(request, http.StatusBadRequest).Write(conn)
badRequestResponse := responseWith(request, http.StatusBadRequest)
printResponseHeaders(ctx, contextLogger, badRequestResponse)
return badRequestResponse.Write(conn)
}

var innerErr common.TypedValue[error]
Expand All @@ -178,7 +234,9 @@ func handleHTTPConnection(
response, err := httpClient.Do(request.WithContext(requestCtx))
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
badGatewayResponse := responseWith(request, http.StatusBadGateway)
printResponseHeaders(ctx, contextLogger, badGatewayResponse)
return E.Errors(innerErr.Load(), err, badGatewayResponse.Write(conn))
}

removeHopByHopHeaders(response.Header)
Expand All @@ -190,6 +248,7 @@ func handleHTTPConnection(
}

response.Close = !keepAlive
printResponseHeaders(ctx, contextLogger, response)

err = response.Write(conn)
if err != nil {
Expand All @@ -204,6 +263,77 @@ func handleHTTPConnection(
return nil
}

func isProxyKeepAlive(request *http.Request) bool {
connection := request.Header.Get("Connection")
proxyConnection := request.Header.Get("Proxy-Connection")

if request.ProtoMajor > 1 || (request.ProtoMajor == 1 && request.ProtoMinor >= 1) {
// HTTP/1.1+ connections are persistent unless explicitly closed.
return !hasHeaderToken(connection, "close") && !hasHeaderToken(proxyConnection, "close")
}

if request.ProtoMajor == 1 && request.ProtoMinor == 0 {
// HTTP/1.0 defaults to close unless keep-alive is requested.
if hasHeaderToken(connection, "close") || hasHeaderToken(proxyConnection, "close") {
return false
}
return hasHeaderToken(connection, "keep-alive") || hasHeaderToken(proxyConnection, "keep-alive")
}

return false
}

func hasHeaderToken(headerValue string, token string) bool {
for _, h := range strings.Split(headerValue, ",") {
if strings.EqualFold(strings.TrimSpace(h), token) {
return true
}
}
return false
}

func shouldRetryMissingProxyAuthorization(request *http.Request) bool {
return isProxyKeepAlive(request) &&
!request.Close &&
!hasHeaderToken(request.Header.Get("Connection"), "upgrade") &&
request.ContentLength == 0 &&
len(request.TransferEncoding) == 0
}

func printRequestHeaders(ctx context.Context, contextLogger logger.ContextLogger, request *http.Request) {
contextLogger.TraceContext(ctx, "request protocol: ", request.Proto)
printHeaders(ctx, contextLogger, "request", request.Header)
}

func printResponseHeaders(ctx context.Context, contextLogger logger.ContextLogger, response *http.Response) {
contextLogger.TraceContext(ctx, "response: protocol=", response.Proto, " status=", response.StatusCode)
printHeaders(ctx, contextLogger, "response", response.Header)
}

func printHeaders(ctx context.Context, contextLogger logger.ContextLogger, kind string, header http.Header) {
keys := make([]string, 0, len(header))
for key := range header {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
redacted := shouldRedactHeaderValue(key)
for _, value := range header[key] {
if redacted {
value = "[redacted]"
}
contextLogger.TraceContext(ctx, kind, " header: ", key, ": ", value)
}
}
}

func shouldRedactHeaderValue(headerKey string) bool {
return strings.EqualFold(headerKey, "Authorization") ||
strings.EqualFold(headerKey, "Proxy-Authorization") ||
strings.EqualFold(headerKey, "Cookie") ||
strings.EqualFold(headerKey, "Set-Cookie")
}

func removeHopByHopHeaders(header http.Header) {
// Strip hop-by-hop header based on RFC:
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
Expand Down Expand Up @@ -244,6 +374,22 @@ func removeExtraHTTPHostPort(req *http.Request) {
req.URL.Host = host
}

func writeResponseBuffered(conn net.Conn, response *http.Response) error {
var responseBuffer bytes.Buffer
err := response.Write(&responseBuffer)
if err != nil {
return err
}
n, err := conn.Write(responseBuffer.Bytes())
if err != nil {
return err
}
if n != responseBuffer.Len() {
return io.ErrShortWrite
}
return nil
}

func responseWith(request *http.Request, statusCode int, headers ...string) *http.Response {
var header http.Header
if len(headers) > 0 {
Expand Down
Loading