Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 73 additions & 5 deletions dashboard/dashboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ type Handler struct {
maintenanceBanner string
bannerPath string

// Per-IP rate limiter for public badge endpoints.
badgeLimiter *ipRateLimiter

// Probe state — per-probe health history ring.
probeMu sync.Mutex
probeStates map[string]*ProbeState
Expand All @@ -129,7 +132,7 @@ type Handler struct {

// NewHandler creates a ready-to-use dashboard Handler backed by cb.
func NewHandler(cb Callbacks) *Handler {
return &Handler{cb: cb}
return &Handler{cb: cb, badgeLimiter: newIPRateLimiter()}
}

// --------------------------------------------------------------------------
Expand Down Expand Up @@ -421,6 +424,71 @@ func readSmallBody(r *http.Request, maxBytes int64) (string, error) {
return strings.TrimRight(string(data), "\r\n\t "), nil
}

// --------------------------------------------------------------------------
// Per-IP rate limiter for public endpoints (e.g. badge SVGs)
// --------------------------------------------------------------------------

// ipRateLimiter is a per-IP sliding-window rate limiter. Safe for concurrent use.
type ipRateLimiter struct {
mu sync.Mutex
ips map[string]*ipBucket
}

type ipBucket struct {
count int
resetAt time.Time
}

func newIPRateLimiter() *ipRateLimiter {
return &ipRateLimiter{ips: make(map[string]*ipBucket)}
}

// extractClientIP returns the client IP, respecting X-Real-IP when the
// direct connection is from localhost.
func extractClientIP(r *http.Request) string {
remoteIP, _, _ := net.SplitHostPort(r.RemoteAddr)
if remoteIP == "127.0.0.1" || remoteIP == "::1" || remoteIP == "localhost" {
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
return realIP
}
}
return remoteIP
}

// middleware returns an http.HandlerFunc that rate-limits by client IP.
// maxReqs is the burst size; window is reset after it elapses.
func (l *ipRateLimiter) middleware(maxReqs int, window time.Duration, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ip := extractClientIP(r)

l.mu.Lock()
b, ok := l.ips[ip]
now := time.Now()
if !ok || now.After(b.resetAt) {
l.ips[ip] = &ipBucket{count: 1, resetAt: now.Add(window)}
l.mu.Unlock()
next(w, r)
return
}
if b.count >= maxReqs {
l.mu.Unlock()
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
b.count++
// Periodic sweep: every 1000th increment, purge expired entries.
if b.count%1000 == 0 {
for k, v := range l.ips {
if now.After(v.resetAt) {
delete(l.ips, k)
}
}
}
l.mu.Unlock()
next(w, r)
}
}

// localhostOnly rejects requests not originating from loopback.
// Trusts X-Real-IP only when the direct connection is from localhost.
func localhostOnly(next http.HandlerFunc) http.HandlerFunc {
Expand Down Expand Up @@ -625,21 +693,21 @@ func (h *Handler) Serve(addr string) error {
}
}

mux.HandleFunc("/api/badge/nodes", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/api/badge/nodes", h.badgeLimiter.middleware(30, time.Minute, func(w http.ResponseWriter, r *http.Request) {
payload := h.cb.BuildStatsPayload(false)
activeNodes, _ := payload["active_nodes"].(int)
c := "#4c1"
if activeNodes == 0 {
c = "#9f9f9f"
}
serveBadge(w, "online nodes", fmtCount(activeNodes), c)
})
}))

mux.HandleFunc("/api/badge/requests", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/api/badge/requests", h.badgeLimiter.middleware(30, time.Minute, func(w http.ResponseWriter, r *http.Request) {
payload := h.cb.BuildStatsPayload(false)
totalReqs, _ := payload["total_requests"].(int64)
serveBadge(w, "requests", fmtCount(int(totalReqs)), "#a855f7")
})
}))

// Snapshot trigger endpoint (POST only, localhost only).
mux.HandleFunc("/api/snapshot", func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading