From fd112d326cc7f3cb35ce30753578fb1faa518a5f Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Tue, 13 Jan 2026 14:57:12 +0000 Subject: [PATCH 1/6] Extend circuit breaker functionality to support per-model isolation. --- bridge.go | 59 ++++--- circuit_breaker_integration_test.go | 159 +++++++++++++++++-- circuitbreaker/circuitbreaker.go | 104 +++++++------ circuitbreaker/circuitbreaker_test.go | 212 ++++++++++++++++---------- metrics/metrics.go | 12 +- 5 files changed, 381 insertions(+), 165 deletions(-) diff --git a/bridge.go b/bridge.go index 9983a485..eff06975 100644 --- a/bridge.go +++ b/bridge.go @@ -70,17 +70,18 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re // Create per-provider circuit breaker if configured cfg := prov.CircuitBreakerConfig() providerName := prov.Name() - onChange := func(endpoint string, from, to gobreaker.State) { + onChange := func(endpoint, model string, from, to gobreaker.State) { logger.Info(context.Background(), "circuit breaker state change", slog.F("provider", providerName), slog.F("endpoint", endpoint), + slog.F("model", model), slog.F("from", from.String()), slog.F("to", to.String()), ) if m != nil { - m.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(circuitbreaker.StateToGaugeValue(to)) + m.CircuitBreakerState.WithLabelValues(providerName, endpoint, model).Set(circuitbreaker.StateToGaugeValue(to)) if to == gobreaker.StateOpen { - m.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc() + m.CircuitBreakerTrips.WithLabelValues(providerName, endpoint, model).Inc() } } } @@ -88,16 +89,8 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range prov.BridgedRoutes() { - // Initialize circuit breaker state metric to closed (0) for known routes - if m != nil && cbs != nil { - endpoint := strings.TrimPrefix(path, "/"+providerName) - m.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(0) - } - - handler := newInterceptionProcessor(prov, rec, mcpProxy, logger, m, tracer) - // Wrap with circuit breaker middleware (nil cbs passes through) - wrapped := circuitbreaker.Middleware(cbs, m, logger)(handler) - mux.Handle(path, wrapped) + handler := newInterceptionProcessor(prov, cbs, rec, mcpProxy, logger, m, tracer) + mux.Handle(path, handler) } // Any requests which passthrough to this will be reverse-proxied to the upstream. @@ -132,7 +125,8 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] rec. -func newInterceptionProcessor(p provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc { +// If cbs is non-nil, circuit breaker protection is applied per endpoint/model tuple. +func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderCircuitBreakers, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() @@ -202,18 +196,37 @@ func newInterceptionProcessor(p provider.Provider, rec recorder.Recorder, mcpPro }() } - if err := interceptor.ProcessRequest(w, r); err != nil { - if m != nil { - m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID).Add(1) + // Process request with circuit breaker protection if configured + processRequest := func(rw http.ResponseWriter) { + if err := interceptor.ProcessRequest(rw, r); err != nil { + if m != nil { + m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID).Add(1) + } + span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) + log.Warn(ctx, "interception failed", slog.Error(err)) + } else { + if m != nil { + m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID).Add(1) + } + log.Debug(ctx, "interception ended") } - span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) - log.Warn(ctx, "interception failed", slog.Error(err)) - } else { - if m != nil { - m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID).Add(1) + } + + if cbs != nil { + result := cbs.Execute(route, interceptor.Model(), w, processRequest) + if result.CircuitOpen { + if m != nil { + m.CircuitBreakerRejects.WithLabelValues(p.Name(), route, interceptor.Model()).Inc() + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(cbs.Timeout().Seconds()))) + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write(cbs.OpenErrorResponse()) } - log.Debug(ctx, "interception ended") + } else { + processRequest(w) } + asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()}) // Ensure all recording have completed before completing request. diff --git a/circuit_breaker_integration_test.go b/circuit_breaker_integration_test.go index 1cf1fba6..a004991c 100644 --- a/circuit_breaker_integration_test.go +++ b/circuit_breaker_integration_test.go @@ -2,6 +2,7 @@ package aibridge_test import ( "context" + "fmt" "io" "net" "net/http" @@ -36,6 +37,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { name string providerName string endpoint string + model string errorBody string successBody string requestBody string @@ -48,6 +50,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { name: "Anthropic", providerName: config.ProviderAnthropic, endpoint: "/v1/messages", + model: "claude-sonnet-4-20250514", errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, successBody: `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, @@ -67,6 +70,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { name: "OpenAI", providerName: config.ProviderOpenAI, endpoint: "/v1/chat/completions", + model: "gpt-4o", errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, successBody: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, @@ -166,13 +170,13 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open") // Verify metrics show circuit is open - trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") - state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") - rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should be 1") // Phase 3: Wait for timeout to transition to half-open @@ -188,7 +192,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") // Verify circuit is now closed - state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 0.0, state, "CircuitBreakerState should be 0 (closed) after recovery") // Phase 5: Verify circuit is fully functional again @@ -202,7 +206,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, upstreamCallsBefore+4, upstreamCalls.Load(), "All requests should reach upstream after circuit closes") // Rejects count should not have increased - rejects = promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + rejects = promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should still be 1 (no new rejects)") }) } @@ -217,6 +221,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { name string providerName string endpoint string + model string errorBody string requestBody string setupHeaders func(req *http.Request) @@ -228,6 +233,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { name: "Anthropic", providerName: config.ProviderAnthropic, endpoint: "/v1/messages", + model: "claude-sonnet-4-20250514", errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { @@ -246,6 +252,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { name: "OpenAI", providerName: config.ProviderOpenAI, endpoint: "/v1/chat/completions", + model: "gpt-4o", errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { @@ -330,7 +337,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { resp := makeRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") // Phase 2: Wait for half-open state @@ -348,10 +355,10 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens") // Verify metrics: trips should be 2 now (tripped twice) - trips = promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + trips = promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 2.0, trips, "CircuitBreakerTrips should be 2 after half-open failure") - state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open) after half-open failure") }) } @@ -366,6 +373,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { name string providerName string endpoint string + model string errorBody string successBody string requestBody string @@ -378,6 +386,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { name: "Anthropic", providerName: config.ProviderAnthropic, endpoint: "/v1/messages", + model: "claude-sonnet-4-20250514", errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, successBody: `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, @@ -397,6 +406,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { name: "OpenAI", providerName: config.ProviderOpenAI, endpoint: "/v1/chat/completions", + model: "gpt-4o", errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, successBody: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, @@ -536,9 +546,140 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { "%d requests should be rejected (ErrTooManyRequests)", totalRequests-maxRequests) // Verify rejects metric increased - rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, float64(1+totalRequests-maxRequests), rejects, "CircuitBreakerRejects should include half-open rejections") }) } } + +// TestCircuitBreaker_PerModelIsolation tests that circuit breakers are independent per model. +// Rate limits on one model should not affect other models on the same endpoint. +func TestCircuitBreaker_PerModelIsolation(t *testing.T) { + t.Parallel() + + var sonnetCalls, haikuCalls atomic.Int32 + var sonnetShouldFail atomic.Bool + sonnetShouldFail.Store(true) + + // Mock upstream that returns different responses based on model in request + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-should-retry", "false") + + if strings.Contains(string(body), "claude-sonnet-4-20250514") { + sonnetCalls.Add(1) + if sonnetShouldFail.Load() { + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`)) + } else { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) + } + } else if strings.Contains(string(body), "claude-3-5-haiku-20241022") { + haikuCalls.Add(1) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"msg_02","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-3-5-haiku-20241022","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) + } + })) + defer mockUpstream.Close() + + m := metrics.NewMetrics(prometheus.NewRegistry()) + + cbConfig := &config.CircuitBreaker{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + } + prov := provider.NewAnthropic(config.Anthropic{ + BaseURL: mockUpstream.URL, + Key: "test-key", + CircuitBreaker: cbConfig, + }, nil) + + ctx := t.Context() + tracer := otel.Tracer("forTesting") + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + bridge, err := aibridge.NewRequestBridge(ctx, + []provider.Provider{prov}, + &mockRecorderClient{}, + mcp.NewServerProxyManager(nil, tracer), + logger, + m, + tracer, + ) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(bridge) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, "test-user-id", nil) + } + mockSrv.Start() + + makeRequest := func(model string) *http.Response { + body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) + req, err := http.NewRequest("POST", mockSrv.URL+"/anthropic/v1/messages", strings.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", "test") + req.Header.Set("anthropic-version", "2023-06-01") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + return resp + } + + // Phase 1: Trip the circuit for sonnet model + for i := uint32(0); i < cbConfig.FailureThreshold; i++ { + resp := makeRequest("claude-sonnet-4-20250514") + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load()) + + // Verify sonnet circuit is open + resp := makeRequest("claude-sonnet-4-20250514") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Sonnet circuit should be open") + assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load(), "No new sonnet calls when circuit is open") + + // Verify sonnet metrics show circuit is open + sonnetTrips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) + assert.Equal(t, 1.0, sonnetTrips, "Sonnet CircuitBreakerTrips should be 1") + + sonnetState := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) + assert.Equal(t, 1.0, sonnetState, "Sonnet CircuitBreakerState should be 1 (open)") + + // Phase 2: Haiku model should still work (independent circuit) + resp = makeRequest("claude-3-5-haiku-20241022") + assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should succeed while sonnet circuit is open") + assert.Equal(t, int32(1), haikuCalls.Load(), "Haiku call should reach upstream") + + // Make multiple haiku requests - all should succeed + for i := 0; i < 3; i++ { + resp = makeRequest("claude-3-5-haiku-20241022") + assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should continue to succeed") + } + assert.Equal(t, int32(4), haikuCalls.Load(), "All haiku calls should reach upstream") + + // Verify haiku circuit is still closed (no trips) + haikuTrips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-3-5-haiku-20241022")) + assert.Equal(t, 0.0, haikuTrips, "Haiku CircuitBreakerTrips should be 0") + + haikuState := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-3-5-haiku-20241022")) + assert.Equal(t, 0.0, haikuState, "Haiku CircuitBreakerState should be 0 (closed)") + + // Phase 3: Sonnet recovers after timeout + time.Sleep(cbConfig.Timeout + 10*time.Millisecond) + sonnetShouldFail.Store(false) + + resp = makeRequest("claude-sonnet-4-20250514") + assert.Equal(t, http.StatusOK, resp.StatusCode, "Sonnet should recover after timeout") + + // Verify sonnet circuit is now closed + sonnetState = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) + assert.Equal(t, 0.0, sonnetState, "Sonnet CircuitBreakerState should be 0 (closed) after recovery") +} diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index 97d52c31..f1b9b187 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -4,13 +4,10 @@ import ( "errors" "fmt" "net/http" - "strconv" - "strings" "sync" + "time" - "cdr.dev/slog/v3" "github.com/coder/aibridge/config" - "github.com/coder/aibridge/metrics" "github.com/sony/gobreaker/v2" ) @@ -27,17 +24,17 @@ func DefaultIsFailure(statusCode int) bool { } } -// ProviderCircuitBreakers manages per-endpoint circuit breakers for a single provider. +// ProviderCircuitBreakers manages per-endpoint/model circuit breakers for a single provider. type ProviderCircuitBreakers struct { provider string config config.CircuitBreaker - breakers sync.Map // endpoint -> *gobreaker.CircuitBreaker[struct{}] - onChange func(endpoint string, from, to gobreaker.State) + breakers sync.Map // "endpoint:model" -> *gobreaker.CircuitBreaker[struct{}] + onChange func(endpoint, model string, from, to gobreaker.State) } // NewProviderCircuitBreakers creates circuit breakers for a single provider. // Returns nil if cfg is nil (no circuit breaker protection). -func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint string, from, to gobreaker.State)) *ProviderCircuitBreakers { +func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint, model string, from, to gobreaker.State)) *ProviderCircuitBreakers { if cfg == nil { return nil } @@ -65,14 +62,15 @@ func (p *ProviderCircuitBreakers) openErrorResponse() []byte { return []byte(`{"error":"circuit breaker is open"}`) } -// Get returns the circuit breaker for an endpoint, creating it if needed. -func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker[struct{}] { - if v, ok := p.breakers.Load(endpoint); ok { +// Get returns the circuit breaker for an endpoint/model tuple, creating it if needed. +func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.CircuitBreaker[struct{}] { + key := endpoint + ":" + model + if v, ok := p.breakers.Load(key); ok { return v.(*gobreaker.CircuitBreaker[struct{}]) } settings := gobreaker.Settings{ - Name: p.provider + ":" + endpoint, + Name: p.provider + ":" + endpoint + ":" + model, MaxRequests: p.config.MaxRequests, Interval: p.config.Interval, Timeout: p.config.Timeout, @@ -81,16 +79,24 @@ func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker }, OnStateChange: func(_ string, from, to gobreaker.State) { if p.onChange != nil { - p.onChange(endpoint, from, to) + p.onChange(endpoint, model, from, to) } }, } cb := gobreaker.NewCircuitBreaker[struct{}](settings) - actual, _ := p.breakers.LoadOrStore(endpoint, cb) + actual, _ := p.breakers.LoadOrStore(key, cb) return actual.(*gobreaker.CircuitBreaker[struct{}]) } +// ExecuteResult contains the result of a circuit breaker execution. +type ExecuteResult struct { + // StatusCode is the HTTP status code returned by the handler. + StatusCode int + // CircuitOpen is true if the request was rejected due to an open circuit. + CircuitOpen bool +} + // statusCapturingWriter wraps http.ResponseWriter to capture the status code. // It also implements http.Flusher to support streaming responses. type statusCapturingWriter struct { @@ -126,44 +132,44 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } -// Middleware returns middleware that wraps handlers with circuit breaker protection. -// It captures the response status code to determine success/failure without provider-specific logic. -// If cbs is nil, requests pass through without circuit breaker protection. -func Middleware(cbs *ProviderCircuitBreakers, m *metrics.Metrics, logger slog.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - // No circuit breaker configured - pass through - if cbs == nil { - return next +// Execute runs the given handler function within circuit breaker protection. +// It returns ExecuteResult with CircuitOpen=true if the circuit is open. +// The handler receives a wrapped ResponseWriter that captures the status code. +func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter)) ExecuteResult { + cb := p.Get(endpoint, model) + + // Wrap response writer to capture status code + sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} + + _, err := cb.Execute(func() (struct{}, error) { + handler(sw) + if p.isFailure(sw.statusCode) { + return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) } + return struct{}{}, nil + }) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - endpoint := strings.TrimPrefix(r.URL.Path, "/"+cbs.provider) - cb := cbs.Get(endpoint) - - // Wrap response writer to capture status code - sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} - - _, err := cb.Execute(func() (struct{}, error) { - next.ServeHTTP(sw, r) - if cbs.isFailure(sw.statusCode) { - return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) - } - return struct{}{}, nil - }) - - if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { - if m != nil { - m.CircuitBreakerRejects.WithLabelValues(cbs.provider, endpoint).Inc() - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Retry-After", strconv.FormatInt(int64(cbs.config.Timeout.Seconds()), 10)) - w.WriteHeader(http.StatusServiceUnavailable) - w.Write(cbs.openErrorResponse()) - } else if err != nil { - logger.Warn(r.Context(), "unexpected circuit breaker error", slog.Error(err)) - } - }) + if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { + return ExecuteResult{CircuitOpen: true} } + + return ExecuteResult{StatusCode: sw.statusCode} +} + +// Timeout returns the configured timeout duration for this circuit breaker. +func (p *ProviderCircuitBreakers) Timeout() time.Duration { + return p.config.Timeout +} + +// Provider returns the provider name for this circuit breaker. +func (p *ProviderCircuitBreakers) Provider() string { + return p.provider +} + +// OpenErrorResponse returns the error response body when the circuit is open. +// This is exposed for handlers to use when responding to rejected requests. +func (p *ProviderCircuitBreakers) OpenErrorResponse() []byte { + return p.openErrorResponse() } // StateToGaugeValue converts gobreaker.State to a gauge value. diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index 0e48f6c1..dcfc96ff 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -7,96 +7,106 @@ import ( "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge/config" "github.com/sony/gobreaker/v2" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestMiddleware_PerEndpointIsolation(t *testing.T) { +func TestExecute_PerModelIsolation(t *testing.T) { t.Parallel() - chatCalls := atomic.Int32{} - responsesCalls := atomic.Int32{} - - // Mock upstream - /chat returns 429, /responses returns 200 - upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/test/v1/chat/completions" { - chatCalls.Add(1) - w.WriteHeader(http.StatusTooManyRequests) - } else { - responsesCalls.Add(1) - w.WriteHeader(http.StatusOK) - } - }) + sonnetCalls := atomic.Int32{} + haikuCalls := atomic.Int32{} cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(endpoint string, from, to gobreaker.State) {}) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - handler := Middleware(cbs, nil, logger)(upstream) - server := httptest.NewServer(handler) - defer server.Close() - - // Trip circuit on /chat/completions - resp, err := http.Get(server.URL + "/test/v1/chat/completions") - require.NoError(t, err) - resp.Body.Close() - - // /chat/completions should now be blocked - resp, err = http.Get(server.URL + "/test/v1/chat/completions") - require.NoError(t, err) - resp.Body.Close() - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - assert.Equal(t, "60", resp.Header.Get("Retry-After")) // Timeout is 1 minute - assert.Equal(t, int32(1), chatCalls.Load()) // Only 1 call, second was blocked - - // /responses should still work - resp, err = http.Get(server.URL + "/test/v1/responses") - require.NoError(t, err) - resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, int32(1), responsesCalls.Load()) + }, func(endpoint, model string, from, to gobreaker.State) {}) + + endpoint := "/v1/messages" + sonnetModel := "claude-sonnet-4-20250514" + haikuModel := "claude-3-5-haiku-20241022" + + // Trip circuit on sonnet model (returns 429) + w := httptest.NewRecorder() + result := cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) { + sonnetCalls.Add(1) + rw.WriteHeader(http.StatusTooManyRequests) + }) + assert.False(t, result.CircuitOpen) + assert.Equal(t, http.StatusTooManyRequests, result.StatusCode) + assert.Equal(t, int32(1), sonnetCalls.Load()) + + // Second sonnet request should be blocked by circuit breaker + w = httptest.NewRecorder() + result = cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) { + sonnetCalls.Add(1) + rw.WriteHeader(http.StatusOK) + }) + assert.True(t, result.CircuitOpen) + assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call + + // Haiku model on same endpoint should still work (independent circuit) + w = httptest.NewRecorder() + result = cbs.Execute(endpoint, haikuModel, w, func(rw http.ResponseWriter) { + haikuCalls.Add(1) + rw.WriteHeader(http.StatusOK) + }) + assert.False(t, result.CircuitOpen) + assert.Equal(t, http.StatusOK, result.StatusCode) + assert.Equal(t, int32(1), haikuCalls.Load()) } -func TestMiddleware_NotConfigured(t *testing.T) { +func TestExecute_PerEndpointIsolation(t *testing.T) { t.Parallel() - var upstreamCalls atomic.Int32 + messagesCalls := atomic.Int32{} + completionsCalls := atomic.Int32{} - upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upstreamCalls.Add(1) - w.WriteHeader(http.StatusTooManyRequests) - }) + cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, func(endpoint, model string, from, to gobreaker.State) {}) - // No circuit breaker configured (nil) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - handler := Middleware(nil, nil, logger)(upstream) - server := httptest.NewServer(handler) - defer server.Close() - - // All requests should pass through even with 429s - for i := 0; i < 10; i++ { - resp, err := http.Get(server.URL + "/test/v1/messages") - require.NoError(t, err) - resp.Body.Close() - assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) - } - assert.Equal(t, int32(10), upstreamCalls.Load()) + model := "test-model" + + // Trip circuit on /v1/messages endpoint (returns 429) + w := httptest.NewRecorder() + result := cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) { + messagesCalls.Add(1) + rw.WriteHeader(http.StatusTooManyRequests) + }) + assert.False(t, result.CircuitOpen) + assert.Equal(t, int32(1), messagesCalls.Load()) + + // Second /v1/messages request should be blocked + w = httptest.NewRecorder() + result = cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) { + messagesCalls.Add(1) + rw.WriteHeader(http.StatusOK) + }) + assert.True(t, result.CircuitOpen) + assert.Equal(t, int32(1), messagesCalls.Load()) // No new call + + // /v1/chat/completions on same model should still work (different endpoint) + w = httptest.NewRecorder() + result = cbs.Execute("/v1/chat/completions", model, w, func(rw http.ResponseWriter) { + completionsCalls.Add(1) + rw.WriteHeader(http.StatusOK) + }) + assert.False(t, result.CircuitOpen) + assert.Equal(t, http.StatusOK, result.StatusCode) + assert.Equal(t, int32(1), completionsCalls.Load()) } -func TestMiddleware_CustomIsFailure(t *testing.T) { +func TestExecute_CustomIsFailure(t *testing.T) { t.Parallel() - upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadGateway) // 502 - }) + var calls atomic.Int32 // Custom IsFailure that treats 502 as failure cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ @@ -107,21 +117,67 @@ func TestMiddleware_CustomIsFailure(t *testing.T) { IsFailure: func(statusCode int) bool { return statusCode == http.StatusBadGateway }, - }, func(endpoint string, from, to gobreaker.State) {}) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - handler := Middleware(cbs, nil, logger)(upstream) - server := httptest.NewServer(handler) - defer server.Close() + }, func(endpoint, model string, from, to gobreaker.State) {}) // First request returns 502, trips circuit - resp, _ := http.Get(server.URL + "/test/v1/messages") - resp.Body.Close() + w := httptest.NewRecorder() + result := cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) { + calls.Add(1) + rw.WriteHeader(http.StatusBadGateway) + }) + assert.False(t, result.CircuitOpen) + assert.Equal(t, http.StatusBadGateway, result.StatusCode) + assert.Equal(t, int32(1), calls.Load()) // Second request should be blocked - resp, _ = http.Get(server.URL + "/test/v1/messages") - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - resp.Body.Close() + w = httptest.NewRecorder() + result = cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) { + calls.Add(1) + rw.WriteHeader(http.StatusOK) + }) + assert.True(t, result.CircuitOpen) + assert.Equal(t, int32(1), calls.Load()) // No new call +} + +func TestExecute_OnStateChange(t *testing.T) { + t.Parallel() + + var stateChanges []struct { + endpoint string + model string + from gobreaker.State + to gobreaker.State + } + + cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, func(endpoint, model string, from, to gobreaker.State) { + stateChanges = append(stateChanges, struct { + endpoint string + model string + from gobreaker.State + to gobreaker.State + }{endpoint, model, from, to}) + }) + + endpoint := "/v1/messages" + model := "claude-sonnet-4-20250514" + + // Trip circuit + w := httptest.NewRecorder() + cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) { + rw.WriteHeader(http.StatusTooManyRequests) + }) + + // Verify state change callback was called with correct parameters + assert.Len(t, stateChanges, 1) + assert.Equal(t, endpoint, stateChanges[0].endpoint) + assert.Equal(t, model, stateChanges[0].model) + assert.Equal(t, gobreaker.StateClosed, stateChanges[0].from) + assert.Equal(t, gobreaker.StateOpen, stateChanges[0].to) } func TestDefaultIsFailure(t *testing.T) { diff --git a/metrics/metrics.go b/metrics/metrics.go index 5ae6ac2d..0c5f5f8b 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -110,23 +110,23 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { // Circuit breaker metrics. - // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + // Pessimistic cardinality: 2 providers, 2 endpoints, 5 models = up to 20. CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ Subsystem: "circuit_breaker", Name: "state", Help: "Current state of the circuit breaker (0=closed, 0.5=half-open, 1=open).", - }, []string{"provider", "endpoint"}), - // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + }, []string{"provider", "endpoint", "model"}), + // Pessimistic cardinality: 2 providers, 2 endpoints, 5 models = up to 20. CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "trips_total", Help: "Total number of times the circuit breaker transitioned to open state.", - }, []string{"provider", "endpoint"}), - // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + }, []string{"provider", "endpoint", "model"}), + // Pessimistic cardinality: 2 providers, 2 endpoints, 5 models = up to 20. CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "rejects_total", Help: "Total number of requests rejected due to open circuit breaker.", - }, []string{"provider", "endpoint"}), + }, []string{"provider", "endpoint", "model"}), } } From 3e8fde5ca649e848d36354e845189ca94a8a8689 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 15 Jan 2026 14:50:19 +0000 Subject: [PATCH 2/6] Apply review suggestions --- bridge.go | 10 +---- circuit_breaker_integration_test.go | 65 ++++++++++++++++----------- circuitbreaker/circuitbreaker.go | 10 +++-- circuitbreaker/circuitbreaker_test.go | 4 -- 4 files changed, 48 insertions(+), 41 deletions(-) diff --git a/bridge.go b/bridge.go index eff06975..6973c8a2 100644 --- a/bridge.go +++ b/bridge.go @@ -214,14 +214,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC if cbs != nil { result := cbs.Execute(route, interceptor.Model(), w, processRequest) - if result.CircuitOpen { - if m != nil { - m.CircuitBreakerRejects.WithLabelValues(p.Name(), route, interceptor.Model()).Inc() - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(cbs.Timeout().Seconds()))) - w.WriteHeader(http.StatusServiceUnavailable) - _, _ = w.Write(cbs.OpenErrorResponse()) + if result.CircuitOpen && m != nil { + m.CircuitBreakerRejects.WithLabelValues(p.Name(), route, interceptor.Model()).Inc() } } else { processRequest(w) diff --git a/circuit_breaker_integration_test.go b/circuit_breaker_integration_test.go index a004991c..4678a7ae 100644 --- a/circuit_breaker_integration_test.go +++ b/circuit_breaker_integration_test.go @@ -28,6 +28,20 @@ import ( "go.opentelemetry.io/otel" ) +// Common response bodies for circuit breaker tests. +const ( + anthropicRateLimitError = `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}` + openAIRateLimitError = `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}` +) + +func anthropicSuccessResponse(model string) string { + return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"%s","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model) +} + +func openAISuccessResponse(model string) string { + return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model) +} + // TestCircuitBreaker_FullRecoveryCycle tests the complete circuit breaker lifecycle: // closed → open (after consecutive failures) → half-open (after timeout) → closed (after successful request) func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { @@ -42,6 +56,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { successBody string requestBody string setupHeaders func(req *http.Request) + createRequest func(t *testing.T, baseURL string, input []byte) *http.Request createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider } @@ -51,13 +66,14 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { providerName: config.ProviderAnthropic, endpoint: "/v1/messages", model: "claude-sonnet-4-20250514", - errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, - successBody: `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, + errorBody: anthropicRateLimitError, + successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, + createRequest: createAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -71,12 +87,13 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { providerName: config.ProviderOpenAI, endpoint: "/v1/chat/completions", model: "gpt-4o", - errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, - successBody: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, + errorBody: openAIRateLimitError, + successBody: openAISuccessResponse("gpt-4o"), requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, + createRequest: createOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -143,9 +160,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { mockSrv.Start() makeRequest := func() *http.Response { - req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") + req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -225,6 +240,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { errorBody string requestBody string setupHeaders func(req *http.Request) + createRequest func(t *testing.T, baseURL string, input []byte) *http.Request createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider } @@ -234,12 +250,13 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { providerName: config.ProviderAnthropic, endpoint: "/v1/messages", model: "claude-sonnet-4-20250514", - errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, + errorBody: anthropicRateLimitError, requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, + createRequest: createAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -253,11 +270,12 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { providerName: config.ProviderOpenAI, endpoint: "/v1/chat/completions", model: "gpt-4o", - errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, + errorBody: openAIRateLimitError, requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, + createRequest: createOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -315,9 +333,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { mockSrv.Start() makeRequest := func() *http.Response { - req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") + req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -378,6 +394,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { successBody string requestBody string setupHeaders func(req *http.Request) + createRequest func(t *testing.T, baseURL string, input []byte) *http.Request createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider } @@ -387,13 +404,14 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { providerName: config.ProviderAnthropic, endpoint: "/v1/messages", model: "claude-sonnet-4-20250514", - errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, - successBody: `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, + errorBody: anthropicRateLimitError, + successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, + createRequest: createAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -407,12 +425,13 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { providerName: config.ProviderOpenAI, endpoint: "/v1/chat/completions", model: "gpt-4o", - errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, - successBody: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, + errorBody: openAIRateLimitError, + successBody: openAISuccessResponse("gpt-4o"), requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, + createRequest: createOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -480,9 +499,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { mockSrv.Start() makeRequest := func() *http.Response { - req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") + req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -572,15 +589,15 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { sonnetCalls.Add(1) if sonnetShouldFail.Load() { w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`)) + _, _ = w.Write([]byte(anthropicRateLimitError)) } else { w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) + _, _ = w.Write([]byte(anthropicSuccessResponse("claude-sonnet-4-20250514"))) } } else if strings.Contains(string(body), "claude-3-5-haiku-20241022") { haikuCalls.Add(1) w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"msg_02","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-3-5-haiku-20241022","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) + _, _ = w.Write([]byte(anthropicSuccessResponse("claude-3-5-haiku-20241022"))) } })) defer mockUpstream.Close() @@ -621,9 +638,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { makeRequest := func(model string) *http.Response { body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) - req, err := http.NewRequest("POST", mockSrv.URL+"/anthropic/v1/messages", strings.NewReader(body)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") + req := createAnthropicMessagesReq(t, mockSrv.URL, []byte(body)) req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") resp, err := http.DefaultClient.Do(req) diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index f1b9b187..1e3e6273 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -70,7 +70,7 @@ func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.Circuit } settings := gobreaker.Settings{ - Name: p.provider + ":" + endpoint + ":" + model, + Name: p.provider + ":" + key, MaxRequests: p.config.MaxRequests, Interval: p.config.Interval, Timeout: p.config.Timeout, @@ -91,8 +91,6 @@ func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.Circuit // ExecuteResult contains the result of a circuit breaker execution. type ExecuteResult struct { - // StatusCode is the HTTP status code returned by the handler. - StatusCode int // CircuitOpen is true if the request was rejected due to an open circuit. CircuitOpen bool } @@ -150,10 +148,14 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons }) if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds()))) + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write(p.openErrorResponse()) return ExecuteResult{CircuitOpen: true} } - return ExecuteResult{StatusCode: sw.statusCode} + return ExecuteResult{} } // Timeout returns the configured timeout duration for this circuit breaker. diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index dcfc96ff..4c6865f5 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -36,7 +36,6 @@ func TestExecute_PerModelIsolation(t *testing.T) { rw.WriteHeader(http.StatusTooManyRequests) }) assert.False(t, result.CircuitOpen) - assert.Equal(t, http.StatusTooManyRequests, result.StatusCode) assert.Equal(t, int32(1), sonnetCalls.Load()) // Second sonnet request should be blocked by circuit breaker @@ -55,7 +54,6 @@ func TestExecute_PerModelIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) }) assert.False(t, result.CircuitOpen) - assert.Equal(t, http.StatusOK, result.StatusCode) assert.Equal(t, int32(1), haikuCalls.Load()) } @@ -99,7 +97,6 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) }) assert.False(t, result.CircuitOpen) - assert.Equal(t, http.StatusOK, result.StatusCode) assert.Equal(t, int32(1), completionsCalls.Load()) } @@ -126,7 +123,6 @@ func TestExecute_CustomIsFailure(t *testing.T) { rw.WriteHeader(http.StatusBadGateway) }) assert.False(t, result.CircuitOpen) - assert.Equal(t, http.StatusBadGateway, result.StatusCode) assert.Equal(t, int32(1), calls.Load()) // Second request should be blocked From da4f5ef656766d128cefcefe7bacec6cb416da46 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 15 Jan 2026 14:55:31 +0000 Subject: [PATCH 3/6] Update tests --- circuit_breaker_integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/circuit_breaker_integration_test.go b/circuit_breaker_integration_test.go index 4678a7ae..d821dd35 100644 --- a/circuit_breaker_integration_test.go +++ b/circuit_breaker_integration_test.go @@ -621,7 +621,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) bridge, err := aibridge.NewRequestBridge(ctx, []provider.Provider{prov}, - &mockRecorderClient{}, + &testutil.MockRecorder{}, mcp.NewServerProxyManager(nil, tracer), logger, m, From 1df56e18da8ffbfcaf8ae45498305a8d2ec42818 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 15 Jan 2026 15:27:56 +0000 Subject: [PATCH 4/6] Apply review suggestions --- bridge.go | 37 +++++++++++----------- circuitbreaker/circuitbreaker.go | 27 +++++++++------- circuitbreaker/circuitbreaker_test.go | 44 ++++++++++++++++----------- 3 files changed, 60 insertions(+), 48 deletions(-) diff --git a/bridge.go b/bridge.go index 6973c8a2..1671da45 100644 --- a/bridge.go +++ b/bridge.go @@ -2,6 +2,7 @@ package aibridge import ( "context" + "errors" "fmt" "net/http" "strings" @@ -197,28 +198,24 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC } // Process request with circuit breaker protection if configured - processRequest := func(rw http.ResponseWriter) { - if err := interceptor.ProcessRequest(rw, r); err != nil { - if m != nil { - m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID).Add(1) - } - span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) - log.Warn(ctx, "interception failed", slog.Error(err)) - } else { - if m != nil { - m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID).Add(1) - } - log.Debug(ctx, "interception ended") - } - } - - if cbs != nil { - result := cbs.Execute(route, interceptor.Model(), w, processRequest) - if result.CircuitOpen && m != nil { + switch err := cbs.Execute(route, interceptor.Model(), w, func(rw http.ResponseWriter) error { + return interceptor.ProcessRequest(rw, r) + }); { + case errors.Is(err, circuitbreaker.ErrCircuitOpen): + if m != nil { m.CircuitBreakerRejects.WithLabelValues(p.Name(), route, interceptor.Model()).Inc() } - } else { - processRequest(w) + case err != nil: + if m != nil { + m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID).Add(1) + } + span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) + log.Warn(ctx, "interception failed", slog.Error(err)) + default: + if m != nil { + m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID).Add(1) + } + log.Debug(ctx, "interception ended") } asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()}) diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index 1e3e6273..066ef47f 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -11,6 +11,10 @@ import ( "github.com/sony/gobreaker/v2" ) +// ErrCircuitOpen is returned by Execute when the circuit breaker is open +// and the request was rejected without calling the handler. +var ErrCircuitOpen = errors.New("circuit breaker is open") + // DefaultIsFailure returns true for standard HTTP status codes that typically // indicate upstream overload. func DefaultIsFailure(statusCode int) bool { @@ -89,12 +93,6 @@ func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.Circuit return actual.(*gobreaker.CircuitBreaker[struct{}]) } -// ExecuteResult contains the result of a circuit breaker execution. -type ExecuteResult struct { - // CircuitOpen is true if the request was rejected due to an open circuit. - CircuitOpen bool -} - // statusCapturingWriter wraps http.ResponseWriter to capture the status code. // It also implements http.Flusher to support streaming responses. type statusCapturingWriter struct { @@ -131,16 +129,23 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { } // Execute runs the given handler function within circuit breaker protection. -// It returns ExecuteResult with CircuitOpen=true if the circuit is open. +// It returns ErrCircuitOpen if the request was rejected due to an open circuit. +// Otherwise, it returns the handler's error (or nil on success). // The handler receives a wrapped ResponseWriter that captures the status code. -func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter)) ExecuteResult { +// If the receiver is nil (no circuit breaker configured), the handler is called directly. +func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error { + if p == nil { + return handler(w) + } + cb := p.Get(endpoint, model) // Wrap response writer to capture status code sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} + var handlerErr error _, err := cb.Execute(func() (struct{}, error) { - handler(sw) + handlerErr = handler(sw) if p.isFailure(sw.statusCode) { return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) } @@ -152,10 +157,10 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds()))) w.WriteHeader(http.StatusServiceUnavailable) _, _ = w.Write(p.openErrorResponse()) - return ExecuteResult{CircuitOpen: true} + return ErrCircuitOpen } - return ExecuteResult{} + return handlerErr } // Timeout returns the configured timeout duration for this circuit breaker. diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index 4c6865f5..ab6a30aa 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -1,6 +1,7 @@ package circuitbreaker import ( + "errors" "net/http" "net/http/httptest" "sync/atomic" @@ -31,29 +32,32 @@ func TestExecute_PerModelIsolation(t *testing.T) { // Trip circuit on sonnet model (returns 429) w := httptest.NewRecorder() - result := cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) { + err := cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error { sonnetCalls.Add(1) rw.WriteHeader(http.StatusTooManyRequests) + return nil }) - assert.False(t, result.CircuitOpen) + assert.NoError(t, err) assert.Equal(t, int32(1), sonnetCalls.Load()) // Second sonnet request should be blocked by circuit breaker w = httptest.NewRecorder() - result = cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) { + err = cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error { sonnetCalls.Add(1) rw.WriteHeader(http.StatusOK) + return nil }) - assert.True(t, result.CircuitOpen) + assert.True(t, errors.Is(err, ErrCircuitOpen)) assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call // Haiku model on same endpoint should still work (independent circuit) w = httptest.NewRecorder() - result = cbs.Execute(endpoint, haikuModel, w, func(rw http.ResponseWriter) { + err = cbs.Execute(endpoint, haikuModel, w, func(rw http.ResponseWriter) error { haikuCalls.Add(1) rw.WriteHeader(http.StatusOK) + return nil }) - assert.False(t, result.CircuitOpen) + assert.NoError(t, err) assert.Equal(t, int32(1), haikuCalls.Load()) } @@ -74,29 +78,32 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { // Trip circuit on /v1/messages endpoint (returns 429) w := httptest.NewRecorder() - result := cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) { + err := cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error { messagesCalls.Add(1) rw.WriteHeader(http.StatusTooManyRequests) + return nil }) - assert.False(t, result.CircuitOpen) + assert.NoError(t, err) assert.Equal(t, int32(1), messagesCalls.Load()) // Second /v1/messages request should be blocked w = httptest.NewRecorder() - result = cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) { + err = cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error { messagesCalls.Add(1) rw.WriteHeader(http.StatusOK) + return nil }) - assert.True(t, result.CircuitOpen) + assert.True(t, errors.Is(err, ErrCircuitOpen)) assert.Equal(t, int32(1), messagesCalls.Load()) // No new call // /v1/chat/completions on same model should still work (different endpoint) w = httptest.NewRecorder() - result = cbs.Execute("/v1/chat/completions", model, w, func(rw http.ResponseWriter) { + err = cbs.Execute("/v1/chat/completions", model, w, func(rw http.ResponseWriter) error { completionsCalls.Add(1) rw.WriteHeader(http.StatusOK) + return nil }) - assert.False(t, result.CircuitOpen) + assert.NoError(t, err) assert.Equal(t, int32(1), completionsCalls.Load()) } @@ -118,20 +125,22 @@ func TestExecute_CustomIsFailure(t *testing.T) { // First request returns 502, trips circuit w := httptest.NewRecorder() - result := cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) { + err := cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error { calls.Add(1) rw.WriteHeader(http.StatusBadGateway) + return nil }) - assert.False(t, result.CircuitOpen) + assert.NoError(t, err) assert.Equal(t, int32(1), calls.Load()) // Second request should be blocked w = httptest.NewRecorder() - result = cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) { + err = cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error { calls.Add(1) rw.WriteHeader(http.StatusOK) + return nil }) - assert.True(t, result.CircuitOpen) + assert.True(t, errors.Is(err, ErrCircuitOpen)) assert.Equal(t, int32(1), calls.Load()) // No new call } @@ -164,8 +173,9 @@ func TestExecute_OnStateChange(t *testing.T) { // Trip circuit w := httptest.NewRecorder() - cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) { + cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error { rw.WriteHeader(http.StatusTooManyRequests) + return nil }) // Verify state change callback was called with correct parameters From eb95b350226957c84a8fb8525ae53e52535846cd Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 15 Jan 2026 16:06:59 +0000 Subject: [PATCH 5/6] Apply review suggestions --- bridge.go | 14 ++++---------- circuitbreaker/circuitbreaker.go | 18 +++++++++++------- circuitbreaker/circuitbreaker_test.go | 18 ++++++++++-------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/bridge.go b/bridge.go index 1671da45..f2a2dd51 100644 --- a/bridge.go +++ b/bridge.go @@ -2,7 +2,6 @@ package aibridge import ( "context" - "errors" "fmt" "net/http" "strings" @@ -86,7 +85,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re } } } - cbs := circuitbreaker.NewProviderCircuitBreakers(providerName, cfg, onChange) + cbs := circuitbreaker.NewProviderCircuitBreakers(providerName, cfg, onChange, m) // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range prov.BridgedRoutes() { @@ -198,20 +197,15 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC } // Process request with circuit breaker protection if configured - switch err := cbs.Execute(route, interceptor.Model(), w, func(rw http.ResponseWriter) error { + if err := cbs.Execute(route, interceptor.Model(), w, func(rw http.ResponseWriter) error { return interceptor.ProcessRequest(rw, r) - }); { - case errors.Is(err, circuitbreaker.ErrCircuitOpen): - if m != nil { - m.CircuitBreakerRejects.WithLabelValues(p.Name(), route, interceptor.Model()).Inc() - } - case err != nil: + }); err != nil { if m != nil { m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID).Add(1) } span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) log.Warn(ctx, "interception failed", slog.Error(err)) - default: + } else { if m != nil { m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID).Add(1) } diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index 066ef47f..a8bc3180 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -8,13 +8,10 @@ import ( "time" "github.com/coder/aibridge/config" + "github.com/coder/aibridge/metrics" "github.com/sony/gobreaker/v2" ) -// ErrCircuitOpen is returned by Execute when the circuit breaker is open -// and the request was rejected without calling the handler. -var ErrCircuitOpen = errors.New("circuit breaker is open") - // DefaultIsFailure returns true for standard HTTP status codes that typically // indicate upstream overload. func DefaultIsFailure(statusCode int) bool { @@ -34,11 +31,14 @@ type ProviderCircuitBreakers struct { config config.CircuitBreaker breakers sync.Map // "endpoint:model" -> *gobreaker.CircuitBreaker[struct{}] onChange func(endpoint, model string, from, to gobreaker.State) + metrics *metrics.Metrics } // NewProviderCircuitBreakers creates circuit breakers for a single provider. // Returns nil if cfg is nil (no circuit breaker protection). -func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint, model string, from, to gobreaker.State)) *ProviderCircuitBreakers { +// onChange is called when circuit state changes. +// metrics is used to record circuit breaker reject counts (can be nil). +func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers { if cfg == nil { return nil } @@ -46,6 +46,7 @@ func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onC provider: provider, config: *cfg, onChange: onChange, + metrics: m, } } @@ -129,7 +130,7 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { } // Execute runs the given handler function within circuit breaker protection. -// It returns ErrCircuitOpen if the request was rejected due to an open circuit. +// If the circuit is open, the request is rejected with a 503 response and metrics are recorded. // Otherwise, it returns the handler's error (or nil on success). // The handler receives a wrapped ResponseWriter that captures the status code. // If the receiver is nil (no circuit breaker configured), the handler is called directly. @@ -153,11 +154,14 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons }) if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { + if p.metrics != nil { + p.metrics.CircuitBreakerRejects.WithLabelValues(p.provider, endpoint, model).Inc() + } w.Header().Set("Content-Type", "application/json") w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds()))) w.WriteHeader(http.StatusServiceUnavailable) _, _ = w.Write(p.openErrorResponse()) - return ErrCircuitOpen + return nil } return handlerErr diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index ab6a30aa..29ab8bb1 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -1,7 +1,6 @@ package circuitbreaker import ( - "errors" "net/http" "net/http/httptest" "sync/atomic" @@ -24,7 +23,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(endpoint, model string, from, to gobreaker.State) {}) + }, func(endpoint, model string, from, to gobreaker.State) {}, nil) endpoint := "/v1/messages" sonnetModel := "claude-sonnet-4-20250514" @@ -47,8 +46,9 @@ func TestExecute_PerModelIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.NoError(t, err) assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call + assert.Equal(t, http.StatusServiceUnavailable, w.Code) // Haiku model on same endpoint should still work (independent circuit) w = httptest.NewRecorder() @@ -72,7 +72,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(endpoint, model string, from, to gobreaker.State) {}) + }, func(endpoint, model string, from, to gobreaker.State) {}, nil) model := "test-model" @@ -93,8 +93,9 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.NoError(t, err) assert.Equal(t, int32(1), messagesCalls.Load()) // No new call + assert.Equal(t, http.StatusServiceUnavailable, w.Code) // /v1/chat/completions on same model should still work (different endpoint) w = httptest.NewRecorder() @@ -121,7 +122,7 @@ func TestExecute_CustomIsFailure(t *testing.T) { IsFailure: func(statusCode int) bool { return statusCode == http.StatusBadGateway }, - }, func(endpoint, model string, from, to gobreaker.State) {}) + }, func(endpoint, model string, from, to gobreaker.State) {}, nil) // First request returns 502, trips circuit w := httptest.NewRecorder() @@ -140,8 +141,9 @@ func TestExecute_CustomIsFailure(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.NoError(t, err) assert.Equal(t, int32(1), calls.Load()) // No new call + assert.Equal(t, http.StatusServiceUnavailable, w.Code) } func TestExecute_OnStateChange(t *testing.T) { @@ -166,7 +168,7 @@ func TestExecute_OnStateChange(t *testing.T) { from gobreaker.State to gobreaker.State }{endpoint, model, from, to}) - }) + }, nil) endpoint := "/v1/messages" model := "claude-sonnet-4-20250514" From dd13bd548ca522c772adb07e94834807c6e16502 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Fri, 16 Jan 2026 11:58:13 +0000 Subject: [PATCH 6/6] Apply review suggestions --- circuitbreaker/circuitbreaker.go | 9 +++++++-- circuitbreaker/circuitbreaker_test.go | 7 ++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index a8bc3180..4be1d2b8 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -12,6 +12,10 @@ import ( "github.com/sony/gobreaker/v2" ) +// ErrCircuitOpen is returned by Execute when the circuit breaker is open +// and the request was rejected without calling the handler. +var ErrCircuitOpen = errors.New("circuit breaker is open") + // DefaultIsFailure returns true for standard HTTP status codes that typically // indicate upstream overload. func DefaultIsFailure(statusCode int) bool { @@ -130,7 +134,8 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { } // Execute runs the given handler function within circuit breaker protection. -// If the circuit is open, the request is rejected with a 503 response and metrics are recorded. +// If the circuit is open, the request is rejected with a 503 response, metrics are recorded, +// and ErrCircuitOpen is returned. // Otherwise, it returns the handler's error (or nil on success). // The handler receives a wrapped ResponseWriter that captures the status code. // If the receiver is nil (no circuit breaker configured), the handler is called directly. @@ -161,7 +166,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds()))) w.WriteHeader(http.StatusServiceUnavailable) _, _ = w.Write(p.openErrorResponse()) - return nil + return ErrCircuitOpen } return handlerErr diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index 29ab8bb1..18913718 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -1,6 +1,7 @@ package circuitbreaker import ( + "errors" "net/http" "net/http/httptest" "sync/atomic" @@ -46,7 +47,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.NoError(t, err) + assert.True(t, errors.Is(err, ErrCircuitOpen)) assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call assert.Equal(t, http.StatusServiceUnavailable, w.Code) @@ -93,7 +94,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.NoError(t, err) + assert.True(t, errors.Is(err, ErrCircuitOpen)) assert.Equal(t, int32(1), messagesCalls.Load()) // No new call assert.Equal(t, http.StatusServiceUnavailable, w.Code) @@ -141,7 +142,7 @@ func TestExecute_CustomIsFailure(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.NoError(t, err) + assert.True(t, errors.Is(err, ErrCircuitOpen)) assert.Equal(t, int32(1), calls.Load()) // No new call assert.Equal(t, http.StatusServiceUnavailable, w.Code) }