From 94fef20bb62a2e014f2d6d9d480b11a07767a9ec Mon Sep 17 00:00:00 2001 From: wuyangfan Date: Mon, 25 May 2026 13:47:45 +0800 Subject: [PATCH] fix(middleware): set rate limit headers on memory store Add optional rateLimiterStoreContext interface so stores can set X-RateLimit-Limit, X-RateLimit-Remaining, and Retry-After headers. RateLimiterMemoryStore implements it using golang.org/x/time/rate. Fixes #2961 Co-authored-by: Cursor --- middleware/rate_limiter.go | 59 ++++++++++++++++++++++++++++++--- middleware/rate_limiter_test.go | 37 +++++++++++++++++++++ 2 files changed, 92 insertions(+), 4 deletions(-) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 2746a3de1..6ac617bf8 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -6,6 +6,7 @@ package middleware import ( "math" "net/http" + "strconv" "sync" "time" @@ -13,6 +14,17 @@ import ( "golang.org/x/time/rate" ) +const ( + HeaderXRateLimitLimit = "X-RateLimit-Limit" + HeaderXRateLimitRemaining = "X-RateLimit-Remaining" +) + +// rateLimiterStoreContext is an optional interface for RateLimiterStore implementations +// that can set rate limit response headers on the given echo.Context. +type rateLimiterStoreContext interface { + AllowContext(c echo.Context, identifier string) (bool, error) +} + // RateLimiterStore is the interface to be implemented by custom stores. type RateLimiterStore interface { // Stores for the rate limiter have to implement the Allow method @@ -140,7 +152,13 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { return nil } - if allow, err := config.Store.Allow(identifier); !allow { + var allow bool + if store, ok := config.Store.(rateLimiterStoreContext); ok { + allow, err = store.AllowContext(c, identifier) + } else { + allow, err = config.Store.Allow(identifier) + } + if !allow { c.Error(config.DenyHandler(c, identifier, err)) return nil } @@ -238,7 +256,21 @@ var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ // Allow implements RateLimiterStore.Allow func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { + _, allowed := store.allow(identifier) + return allowed, nil +} + +// AllowContext implements rateLimiterStoreContext for RateLimiterMemoryStore. +func (store *RateLimiterMemoryStore) AllowContext(c echo.Context, identifier string) (bool, error) { + limiter, allowed := store.allow(identifier) + store.setRateLimitHeaders(c, limiter, allowed) + return allowed, nil +} + +func (store *RateLimiterMemoryStore) allow(identifier string) (*rate.Limiter, bool) { store.mutex.Lock() + defer store.mutex.Unlock() + limiter, exists := store.visitors[identifier] if !exists { limiter = new(Visitor) @@ -250,9 +282,28 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { if now.Sub(store.lastCleanup) > store.expiresIn { store.cleanupStaleVisitors() } - allowed := limiter.AllowN(now, 1) - store.mutex.Unlock() - return allowed, nil + return limiter.Limiter, limiter.AllowN(now, 1) +} + +func (store *RateLimiterMemoryStore) setRateLimitHeaders(c echo.Context, limiter *rate.Limiter, allowed bool) { + res := c.Response() + res.Header().Set(HeaderXRateLimitLimit, strconv.Itoa(store.burst)) + + remaining := int(math.Floor(limiter.Tokens())) + if remaining < 0 { + remaining = 0 + } + res.Header().Set(HeaderXRateLimitRemaining, strconv.Itoa(remaining)) + + if !allowed { + now := store.timeNow() + reservation := limiter.ReserveN(now, 1) + delay := reservation.Delay() + if delay > 0 { + res.Header().Set(echo.HeaderRetryAfter, strconv.Itoa(int(math.Ceil(delay.Seconds())))) + } + reservation.Cancel() + } } /* diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 655d4731d..a5af74650 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -8,6 +8,7 @@ import ( "math/rand" "net/http" "net/http/httptest" + "strconv" "sync" "sync/atomic" "testing" @@ -624,3 +625,39 @@ func TestRateLimiterMemoryStore_TimeOrdering(t *testing.T) { allowed4, _ := store.Allow("user1") assert.True(t, allowed4, "Request 4 should be allowed (1 token available)") } + +func TestRateLimiterMemoryStore_AllowContext_SetsHeaders(t *testing.T) { + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + mw := RateLimiter(store) + + for i := 0; i < 3; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderXRealIP, "127.0.0.1") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := mw(handler)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit)) + assert.Equal(t, strconv.Itoa(2-i), rec.Header().Get(HeaderXRateLimitRemaining)) + assert.Empty(t, rec.Header().Get(echo.HeaderRetryAfter)) + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderXRealIP, "127.0.0.1") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := mw(handler)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, rec.Code) + assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit)) + assert.Equal(t, "0", rec.Header().Get(HeaderXRateLimitRemaining)) + assert.NotEmpty(t, rec.Header().Get(echo.HeaderRetryAfter)) +}