From 24e9a117823dc16caf83dc45057c6cf2d416bac5 Mon Sep 17 00:00:00 2001 From: matthew-pilot Date: Sat, 30 May 2026 07:13:09 +0000 Subject: [PATCH] fix(dashboard): add per-IP rate limiting to /api/badge/* endpoints (PILOT-338) The two public badge endpoints (/api/badge/nodes, /api/badge/requests) had no per-IP rate limiting, allowing unlimited scrape by any client. While these are non-confidential public stats, unbounded request volume wastes CPU on SVG generation. Add a sliding-window per-IP rate limiter (30 req/min per IP) to both badge endpoints, following the same client-IP extraction pattern used by localhostOnly (respects X-Real-IP from trusted reverse proxies). Closes PILOT-338 --- dashboard/dashboard.go | 78 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 5 deletions(-) diff --git a/dashboard/dashboard.go b/dashboard/dashboard.go index 3c63687..56d5504 100644 --- a/dashboard/dashboard.go +++ b/dashboard/dashboard.go @@ -115,6 +115,9 @@ type Handler struct { maintenanceBanner string bannerPath string + // Per-IP rate limiter for public badge endpoints. + badgeLimiter *ipRateLimiter + // Probe state — per-probe health history ring. probeMu sync.Mutex probeStates map[string]*ProbeState @@ -129,7 +132,7 @@ type Handler struct { // NewHandler creates a ready-to-use dashboard Handler backed by cb. func NewHandler(cb Callbacks) *Handler { - return &Handler{cb: cb} + return &Handler{cb: cb, badgeLimiter: newIPRateLimiter()} } // -------------------------------------------------------------------------- @@ -421,6 +424,71 @@ func readSmallBody(r *http.Request, maxBytes int64) (string, error) { return strings.TrimRight(string(data), "\r\n\t "), nil } +// -------------------------------------------------------------------------- +// Per-IP rate limiter for public endpoints (e.g. badge SVGs) +// -------------------------------------------------------------------------- + +// ipRateLimiter is a per-IP sliding-window rate limiter. Safe for concurrent use. +type ipRateLimiter struct { + mu sync.Mutex + ips map[string]*ipBucket +} + +type ipBucket struct { + count int + resetAt time.Time +} + +func newIPRateLimiter() *ipRateLimiter { + return &ipRateLimiter{ips: make(map[string]*ipBucket)} +} + +// extractClientIP returns the client IP, respecting X-Real-IP when the +// direct connection is from localhost. +func extractClientIP(r *http.Request) string { + remoteIP, _, _ := net.SplitHostPort(r.RemoteAddr) + if remoteIP == "127.0.0.1" || remoteIP == "::1" || remoteIP == "localhost" { + if realIP := r.Header.Get("X-Real-IP"); realIP != "" { + return realIP + } + } + return remoteIP +} + +// middleware returns an http.HandlerFunc that rate-limits by client IP. +// maxReqs is the burst size; window is reset after it elapses. +func (l *ipRateLimiter) middleware(maxReqs int, window time.Duration, next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ip := extractClientIP(r) + + l.mu.Lock() + b, ok := l.ips[ip] + now := time.Now() + if !ok || now.After(b.resetAt) { + l.ips[ip] = &ipBucket{count: 1, resetAt: now.Add(window)} + l.mu.Unlock() + next(w, r) + return + } + if b.count >= maxReqs { + l.mu.Unlock() + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return + } + b.count++ + // Periodic sweep: every 1000th increment, purge expired entries. + if b.count%1000 == 0 { + for k, v := range l.ips { + if now.After(v.resetAt) { + delete(l.ips, k) + } + } + } + l.mu.Unlock() + next(w, r) + } +} + // localhostOnly rejects requests not originating from loopback. // Trusts X-Real-IP only when the direct connection is from localhost. func localhostOnly(next http.HandlerFunc) http.HandlerFunc { @@ -625,7 +693,7 @@ func (h *Handler) Serve(addr string) error { } } - mux.HandleFunc("/api/badge/nodes", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/badge/nodes", h.badgeLimiter.middleware(30, time.Minute, func(w http.ResponseWriter, r *http.Request) { payload := h.cb.BuildStatsPayload(false) activeNodes, _ := payload["active_nodes"].(int) c := "#4c1" @@ -633,13 +701,13 @@ func (h *Handler) Serve(addr string) error { c = "#9f9f9f" } serveBadge(w, "online nodes", fmtCount(activeNodes), c) - }) + })) - mux.HandleFunc("/api/badge/requests", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/badge/requests", h.badgeLimiter.middleware(30, time.Minute, func(w http.ResponseWriter, r *http.Request) { payload := h.cb.BuildStatsPayload(false) totalReqs, _ := payload["total_requests"].(int64) serveBadge(w, "requests", fmtCount(int(totalReqs)), "#a855f7") - }) + })) // Snapshot trigger endpoint (POST only, localhost only). mux.HandleFunc("/api/snapshot", func(w http.ResponseWriter, r *http.Request) {