Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions accept/accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,32 @@ func (ls *logSampler) Cleanup() {

// ── Rate limiter ──────────────────────────────────────────────────────────────

// globalRateBucket caps total requests per second process-wide.
type globalRateBucket struct {
tokens float64
rate float64
maxFill float64
lastFill time.Time
}

func newGlobalRateBucket(rate float64) *globalRateBucket {
return &globalRateBucket{tokens: rate, rate: rate, maxFill: rate, lastFill: time.Now()}
}

func (gb *globalRateBucket) allow(now time.Time) bool {
elapsed := now.Sub(gb.lastFill).Seconds()
gb.tokens += elapsed * gb.rate
if gb.tokens > gb.maxFill {
gb.tokens = gb.maxFill
}
gb.lastFill = now
if gb.tokens < 1 {
return false
}
gb.tokens--
return true
}

// RateLimiter tracks per-IP registration attempts using a token bucket.
//
// Whitelist (optional): a list of CIDR ranges, each paired with an
Expand Down Expand Up @@ -502,6 +528,7 @@ type Acceptor struct {
connCount atomic.Int64
maxConnections int64
rateLimiter *RateLimiter
globalBucket *globalRateBucket
logSampler *logSampler
listener net.Listener
dispatcher Dispatcher
Expand All @@ -514,6 +541,7 @@ func NewAcceptor(maxConns int64, d Dispatcher) *Acceptor {
return &Acceptor{
maxConnections: maxConns,
rateLimiter: NewRateLimiter(100, time.Second, 50_000),
globalBucket: newGlobalRateBucket(1000), // 1000 req/s process-wide
logSampler: newLogSampler(1000),
dispatcher: d,
}
Expand Down Expand Up @@ -807,6 +835,14 @@ func (a *Acceptor) handleJSONConn(conn net.Conn, reader io.Reader) {
// Per-connection rate check with 5 s grace period.
connReqCount++
if elapsed := time.Since(connStart).Seconds(); elapsed >= 5 {
// Process-level global rate cap: reject when total request
// rate across all connections exceeds the global ceiling.
if !a.globalBucket.allow(time.Now()) {
slog.Warn("global rate limit exceeded, closing connection",
"remote", conn.RemoteAddr())
return
}

rate := float64(connReqCount) / elapsed
if rate > 500 {
slog.Warn("closing abusive connection", "remote", conn.RemoteAddr(), "rate", rate)
Expand Down Expand Up @@ -888,6 +924,14 @@ func (a *Acceptor) handleBinaryConn(conn net.Conn) {

connReqCount++
if elapsed := time.Since(connStart).Seconds(); elapsed >= 5 {
// Process-level global rate cap: reject when total request
// rate across all connections exceeds the global ceiling.
if !a.globalBucket.allow(time.Now()) {
slog.Warn("global rate limit exceeded, closing binary connection",
"remote", conn.RemoteAddr())
return
}

rate := float64(connReqCount) / elapsed
if rate > 500 {
slog.Warn("closing abusive binary connection", "remote", conn.RemoteAddr(), "rate", rate)
Expand Down
Loading