Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions middleware/rate_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,25 @@ package middleware
import (
"math"
"net/http"
"strconv"
"sync"
"time"

"github.com/labstack/echo/v4"
"golang.org/x/time/rate"
)

const (
HeaderXRateLimitLimit = "X-RateLimit-Limit"
HeaderXRateLimitRemaining = "X-RateLimit-Remaining"
)

// rateLimiterStoreContext is an optional interface for RateLimiterStore implementations
// that can set rate limit response headers on the given echo.Context.
type rateLimiterStoreContext interface {
AllowContext(c echo.Context, identifier string) (bool, error)
}

// RateLimiterStore is the interface to be implemented by custom stores.
type RateLimiterStore interface {
// Stores for the rate limiter have to implement the Allow method
Expand Down Expand Up @@ -140,7 +152,13 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
return nil
}

if allow, err := config.Store.Allow(identifier); !allow {
var allow bool
if store, ok := config.Store.(rateLimiterStoreContext); ok {
allow, err = store.AllowContext(c, identifier)
} else {
allow, err = config.Store.Allow(identifier)
}
if !allow {
c.Error(config.DenyHandler(c, identifier, err))
return nil
}
Expand Down Expand Up @@ -238,7 +256,21 @@ var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{

// Allow implements RateLimiterStore.Allow
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
_, allowed := store.allow(identifier)
return allowed, nil
}

// AllowContext implements rateLimiterStoreContext for RateLimiterMemoryStore.
func (store *RateLimiterMemoryStore) AllowContext(c echo.Context, identifier string) (bool, error) {
limiter, allowed := store.allow(identifier)
store.setRateLimitHeaders(c, limiter, allowed)
return allowed, nil
}

func (store *RateLimiterMemoryStore) allow(identifier string) (*rate.Limiter, bool) {
store.mutex.Lock()
defer store.mutex.Unlock()

limiter, exists := store.visitors[identifier]
if !exists {
limiter = new(Visitor)
Expand All @@ -250,9 +282,28 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
if now.Sub(store.lastCleanup) > store.expiresIn {
store.cleanupStaleVisitors()
}
allowed := limiter.AllowN(now, 1)
store.mutex.Unlock()
return allowed, nil
return limiter.Limiter, limiter.AllowN(now, 1)
}

func (store *RateLimiterMemoryStore) setRateLimitHeaders(c echo.Context, limiter *rate.Limiter, allowed bool) {
res := c.Response()
res.Header().Set(HeaderXRateLimitLimit, strconv.Itoa(store.burst))

remaining := int(math.Floor(limiter.Tokens()))
if remaining < 0 {
remaining = 0
}
res.Header().Set(HeaderXRateLimitRemaining, strconv.Itoa(remaining))

if !allowed {
now := store.timeNow()
reservation := limiter.ReserveN(now, 1)
delay := reservation.Delay()
if delay > 0 {
res.Header().Set(echo.HeaderRetryAfter, strconv.Itoa(int(math.Ceil(delay.Seconds()))))
}
reservation.Cancel()
}
}

/*
Expand Down
37 changes: 37 additions & 0 deletions middleware/rate_limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"math/rand"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -624,3 +625,39 @@ func TestRateLimiterMemoryStore_TimeOrdering(t *testing.T) {
allowed4, _ := store.Allow("user1")
assert.True(t, allowed4, "Request 4 should be allowed (1 token available)")
}

func TestRateLimiterMemoryStore_AllowContext_SetsHeaders(t *testing.T) {
store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
e := echo.New()

handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
mw := RateLimiter(store)

for i := 0; i < 3; i++ {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXRealIP, "127.0.0.1")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

err := mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit))
assert.Equal(t, strconv.Itoa(2-i), rec.Header().Get(HeaderXRateLimitRemaining))
assert.Empty(t, rec.Header().Get(echo.HeaderRetryAfter))
}

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXRealIP, "127.0.0.1")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

err := mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit))
assert.Equal(t, "0", rec.Header().Get(HeaderXRateLimitRemaining))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderRetryAfter))
}