diff --git a/README.md b/README.md index 6679825..b9f1e62 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,6 @@ HTTP API providing user/client message handling for an fmsg host. Exposes CRUD o | `FMSG_TLS_KEY` | *(optional)* | Path to the TLS private key file (e.g. `/etc/letsencrypt/live/example.com/privkey.pem`). Must be set together with `FMSG_TLS_CERT`. | | `FMSG_API_PORT` | `443` (TLS) / `8000` (plain) | TCP port to listen on. | | `FMSG_ID_URL` | `http://127.0.0.1:8080` | Base URL of the fmsgid identity service | -| `FMSG_API_RATE_LIMIT`| `10` | Max sustained requests per second per IP | -| `FMSG_API_RATE_BURST`| `20` | Max burst size for the per-IP rate limiter | | `FMSG_API_MAX_DATA_SIZE`| `10` | Maximum message data size in megabytes | | `FMSG_API_MAX_ATTACH_SIZE`| `10` | Maximum attachment file size in megabytes | | `FMSG_API_MAX_MSG_SIZE`| `20` | Maximum total message size (data + attachments) in megabytes | @@ -133,12 +131,8 @@ maximum long-poll duration (60 s) so connections are not dropped prematurely. All routes are prefixed with `/fmsg` and require a valid `Authorization: Bearer ` header. -All routes are subject to per-IP rate limiting. When the limit is exceeded, the -server responds with `429 Too Many Requests`: - -```json -{"error": "rate limit exceeded"} -``` +Rate limiting is enforced at the host level (e.g. `nftables`) rather than in +the application. | Method | Path | Description | | -------- | ------------------------------------------- | ------------------------ | diff --git a/src/go.mod b/src/go.mod index 9534b24..41be0f0 100644 --- a/src/go.mod +++ b/src/go.mod @@ -8,7 +8,6 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.1 github.com/jackc/pgx/v5 v5.8.0 github.com/joho/godotenv v1.5.1 - golang.org/x/time v0.15.0 ) require ( @@ -42,8 +41,9 @@ require ( golang.org/x/arch v0.22.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.51.0 // indirect - golang.org/x/sync v0.19.0 // indirect + golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect + golang.org/x/time v0.15.0 // indirect google.golang.org/protobuf v1.36.10 // indirect ) diff --git a/src/go.sum b/src/go.sum index 7daea34..7a75f93 100644 --- a/src/go.sum +++ b/src/go.sum @@ -95,6 +95,8 @@ golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= diff --git a/src/main.go b/src/main.go index 2b88d9c..f45812b 100644 --- a/src/main.go +++ b/src/main.go @@ -45,8 +45,6 @@ func main() { // Optional configuration with defaults. idURL := envOrDefault("FMSG_ID_URL", "http://127.0.0.1:8080") - rateLimit := envOrDefaultInt("FMSG_API_RATE_LIMIT", 10) - rateBurst := envOrDefaultInt("FMSG_API_RATE_BURST", 20) maxDataSize := int64(envOrDefaultInt("FMSG_API_MAX_DATA_SIZE", 10)) * 1024 * 1024 maxAttachSize := int64(envOrDefaultInt("FMSG_API_MAX_ATTACH_SIZE", 10)) * 1024 * 1024 maxMsgSize := int64(envOrDefaultInt("FMSG_API_MAX_MSG_SIZE", 20)) * 1024 * 1024 @@ -88,8 +86,7 @@ func main() { log.Printf("CORS enabled for origins: %s", strings.Join(corsOrigins, ", ")) } - // Global rate limiter. - router.Use(middleware.NewRateLimiter(ctx, float64(rateLimit), rateBurst)) + // Global rate limiting is handled by nftables at the host level. // Instantiate handlers. msgHandler := handlers.NewMessageHandler(database, dataDir, maxDataSize, maxMsgSize, shortTextSize) diff --git a/src/middleware/jwt.go b/src/middleware/jwt.go index 8a3c353..1657292 100644 --- a/src/middleware/jwt.go +++ b/src/middleware/jwt.go @@ -7,10 +7,12 @@ import ( "log" "net/http" "strings" + "sync" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" + "golang.org/x/sync/singleflight" ) // IdentityKey is the Gin context key under which the authenticated user @@ -221,28 +223,95 @@ func IsValidAddr(addr string) bool { return strings.Contains(rest, "@") } +// fmsgIDClient is a dedicated HTTP client with a bounded timeout so that a +// slow or hung fmsgid never blocks an API request goroutine indefinitely +// (which would otherwise hold the inbound HTTP connection open and exhaust +// the browser's per-host connection limit). +var fmsgIDClient = &http.Client{Timeout: 5 * time.Second} + +// fmsgIDCacheTTL is how long a positive fmsgid lookup is cached. Tokens are +// re-validated every time, but the relatively expensive network round-trip to +// fmsgid is short-circuited for this window. Negative results are not cached. +const fmsgIDCacheTTL = 30 * time.Second + +type fmsgIDEntry struct { + expires time.Time + code int + acceptingNew bool +} + +var fmsgIDCache sync.Map // map[string]fmsgIDEntry, key = addr + +// fmsgIDGroup coalesces concurrent lookups for the same address so that a +// burst of cache misses (e.g. several browser requests arriving before the +// first response is cached) results in a single upstream fmsgid call. +var fmsgIDGroup singleflight.Group + +type fmsgIDResult struct { + code int + acceptingNew bool +} + // checkFmsgID queries the fmsgid service for a user address. -// Returns (statusCode, acceptingNew, error). +// Returns (statusCode, acceptingNew, error). Successful 200 responses are +// cached for fmsgIDCacheTTL to avoid hammering fmsgid when a browser fires +// many concurrent requests with the same JWT. Concurrent cache misses for +// the same address are deduplicated via singleflight. func checkFmsgID(idURL, addr string) (int, bool, error) { - url := strings.TrimRight(idURL, "/") + "/fmsgid/" + addr - resp, err := http.Get(url) //nolint:gosec // URL constructed from trusted config + validated addr + if v, ok := fmsgIDCache.Load(addr); ok { + entry := v.(fmsgIDEntry) + if time.Now().Before(entry.expires) { + return entry.code, entry.acceptingNew, nil + } + fmsgIDCache.Delete(addr) + } + + v, err, _ := fmsgIDGroup.Do(addr, func() (interface{}, error) { + // Re-check inside the singleflight in case another goroutine just + // populated the cache while we were waiting to enter. + if v, ok := fmsgIDCache.Load(addr); ok { + entry := v.(fmsgIDEntry) + if time.Now().Before(entry.expires) { + return fmsgIDResult{code: entry.code, acceptingNew: entry.acceptingNew}, nil + } + } + return fetchFmsgID(idURL, addr) + }) if err != nil { return 0, false, err } + res := v.(fmsgIDResult) + return res.code, res.acceptingNew, nil +} + +// fetchFmsgID performs the actual HTTP call to fmsgid and stores positive +// results in the cache. +func fetchFmsgID(idURL, addr string) (fmsgIDResult, error) { + url := strings.TrimRight(idURL, "/") + "/fmsgid/" + addr + resp, err := fmsgIDClient.Get(url) //nolint:gosec // URL constructed from trusted config + validated addr + if err != nil { + return fmsgIDResult{}, err + } defer resp.Body.Close() if resp.StatusCode == http.StatusNotFound { - return http.StatusNotFound, false, nil + return fmsgIDResult{code: http.StatusNotFound}, nil } if resp.StatusCode != http.StatusOK { - return resp.StatusCode, false, nil + return fmsgIDResult{code: resp.StatusCode}, nil } var result struct { AcceptingNew bool `json:"acceptingNew"` } if err := decodeJSON(resp.Body, &result); err != nil { - return http.StatusOK, true, nil // assume accepting if parse fails + return fmsgIDResult{code: http.StatusOK, acceptingNew: true}, nil // assume accepting if parse fails } - return http.StatusOK, result.AcceptingNew, nil + + fmsgIDCache.Store(addr, fmsgIDEntry{ + expires: time.Now().Add(fmsgIDCacheTTL), + code: http.StatusOK, + acceptingNew: result.AcceptingNew, + }) + return fmsgIDResult{code: http.StatusOK, acceptingNew: result.AcceptingNew}, nil } diff --git a/src/middleware/jwt_test.go b/src/middleware/jwt_test.go index b98f653..a88f507 100644 --- a/src/middleware/jwt_test.go +++ b/src/middleware/jwt_test.go @@ -308,6 +308,7 @@ func TestEdDSAMode_Reuse(t *testing.T) { } func TestEdDSAMode_FmsgIDUnavailable(t *testing.T) { + fmsgIDCache.Delete("@alice@example.com") srv := fmsgIDServer(t, http.StatusInternalServerError, false) defer srv.Close() priv, jwks := newEdDSAFixture(t) diff --git a/src/middleware/ratelimit.go b/src/middleware/ratelimit.go deleted file mode 100644 index 905649f..0000000 --- a/src/middleware/ratelimit.go +++ /dev/null @@ -1,84 +0,0 @@ -package middleware - -import ( - "context" - "log" - "net/http" - "sync" - "sync/atomic" - "time" - - "github.com/gin-gonic/gin" - "golang.org/x/time/rate" -) - -type visitor struct { - limiter *rate.Limiter - lastSeen atomic.Int64 // UnixNano -} - -type rateLimiter struct { - visitors sync.Map - rps rate.Limit - burst int -} - -// NewRateLimiter returns Gin middleware that enforces a per-IP token-bucket -// rate limit. rps is the sustained requests-per-second rate and burst is the -// maximum burst size allowed. The cleanup goroutine runs until ctx is cancelled. -func NewRateLimiter(ctx context.Context, rps float64, burst int) gin.HandlerFunc { - rl := &rateLimiter{ - rps: rate.Limit(rps), - burst: burst, - } - go rl.cleanup(ctx) - return rl.handler -} - -func (rl *rateLimiter) getVisitor(ip string) *rate.Limiter { - now := time.Now().UnixNano() - if val, ok := rl.visitors.Load(ip); ok { - v := val.(*visitor) - v.lastSeen.Store(now) - return v.limiter - } - v := &visitor{limiter: rate.NewLimiter(rl.rps, rl.burst)} - v.lastSeen.Store(now) - if actual, loaded := rl.visitors.LoadOrStore(ip, v); loaded { - v = actual.(*visitor) - v.lastSeen.Store(now) - } - return v.limiter -} - -func (rl *rateLimiter) handler(c *gin.Context) { - ip := c.ClientIP() - limiter := rl.getVisitor(ip) - if !limiter.Allow() { - log.Printf("rate limit exceeded: ip=%s", ip) - c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "rate limit exceeded"}) - return - } - c.Next() -} - -// cleanup removes visitors that have not been seen for 5 minutes. -func (rl *rateLimiter) cleanup(ctx context.Context) { - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - now := time.Now().UnixNano() - rl.visitors.Range(func(key, value any) bool { - v := value.(*visitor) - if now-v.lastSeen.Load() > int64(5*time.Minute) { - rl.visitors.Delete(key) - } - return true - }) - } - } -} diff --git a/src/middleware/ratelimit_test.go b/src/middleware/ratelimit_test.go deleted file mode 100644 index f793001..0000000 --- a/src/middleware/ratelimit_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package middleware - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" -) - -func init() { - gin.SetMode(gin.TestMode) -} - -func setupRateLimitRouter(rps float64, burst int) *gin.Engine { - r := gin.New() - r.Use(NewRateLimiter(context.Background(), rps, burst)) - r.GET("/test", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) - return r -} - -func TestRateLimiterAllowsUnderLimit(t *testing.T) { - router := setupRateLimitRouter(10, 5) - - for i := 0; i < 5; i++ { - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/test", nil) - req.RemoteAddr = "1.2.3.4:1234" - router.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("request %d: expected 200, got %d", i, w.Code) - } - } -} - -func TestRateLimiterBlocksExcessBurst(t *testing.T) { - router := setupRateLimitRouter(1, 3) // 1 rps, burst of 3 - - // First 3 requests should succeed (burst). - for i := 0; i < 3; i++ { - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/test", nil) - req.RemoteAddr = "1.2.3.4:1234" - router.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("request %d: expected 200, got %d", i, w.Code) - } - } - - // Next request should be rate-limited. - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/test", nil) - req.RemoteAddr = "1.2.3.4:1234" - router.ServeHTTP(w, req) - if w.Code != http.StatusTooManyRequests { - t.Fatalf("expected 429, got %d", w.Code) - } - - var body map[string]string - if err := json.NewDecoder(w.Body).Decode(&body); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - if body["error"] != "rate limit exceeded" { - t.Fatalf("unexpected error message: %s", body["error"]) - } -} - -func TestRateLimiterTracksIPsIndependently(t *testing.T) { - router := setupRateLimitRouter(1, 1) // 1 rps, burst of 1 - - // Exhaust IP A. - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/test", nil) - req.RemoteAddr = "10.0.0.1:1000" - router.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("IP A first request: expected 200, got %d", w.Code) - } - - w = httptest.NewRecorder() - req, _ = http.NewRequest("GET", "/test", nil) - req.RemoteAddr = "10.0.0.1:1000" - router.ServeHTTP(w, req) - if w.Code != http.StatusTooManyRequests { - t.Fatalf("IP A second request: expected 429, got %d", w.Code) - } - - // IP B should still be allowed. - w = httptest.NewRecorder() - req, _ = http.NewRequest("GET", "/test", nil) - req.RemoteAddr = "10.0.0.2:2000" - router.ServeHTTP(w, req) - if w.Code != http.StatusOK { - t.Fatalf("IP B first request: expected 200, got %d", w.Code) - } -}