diff --git a/bridge.go b/bridge.go index 9983a485..f2a2dd51 100644 --- a/bridge.go +++ b/bridge.go @@ -70,34 +70,27 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re // Create per-provider circuit breaker if configured cfg := prov.CircuitBreakerConfig() providerName := prov.Name() - onChange := func(endpoint string, from, to gobreaker.State) { + onChange := func(endpoint, model string, from, to gobreaker.State) { logger.Info(context.Background(), "circuit breaker state change", slog.F("provider", providerName), slog.F("endpoint", endpoint), + slog.F("model", model), slog.F("from", from.String()), slog.F("to", to.String()), ) if m != nil { - m.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(circuitbreaker.StateToGaugeValue(to)) + m.CircuitBreakerState.WithLabelValues(providerName, endpoint, model).Set(circuitbreaker.StateToGaugeValue(to)) if to == gobreaker.StateOpen { - m.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc() + m.CircuitBreakerTrips.WithLabelValues(providerName, endpoint, model).Inc() } } } - cbs := circuitbreaker.NewProviderCircuitBreakers(providerName, cfg, onChange) + cbs := circuitbreaker.NewProviderCircuitBreakers(providerName, cfg, onChange, m) // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range prov.BridgedRoutes() { - // Initialize circuit breaker state metric to closed (0) for known routes - if m != nil && cbs != nil { - endpoint := strings.TrimPrefix(path, "/"+providerName) - m.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(0) - } - - handler := newInterceptionProcessor(prov, rec, mcpProxy, logger, m, tracer) - // Wrap with circuit breaker middleware (nil cbs passes through) - wrapped := circuitbreaker.Middleware(cbs, m, logger)(handler) - mux.Handle(path, wrapped) + handler := newInterceptionProcessor(prov, cbs, rec, mcpProxy, logger, m, tracer) + mux.Handle(path, handler) } // Any requests which passthrough to this will be reverse-proxied to the upstream. @@ -132,7 +125,8 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] rec. -func newInterceptionProcessor(p provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc { +// If cbs is non-nil, circuit breaker protection is applied per endpoint/model tuple. +func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderCircuitBreakers, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() @@ -202,7 +196,10 @@ func newInterceptionProcessor(p provider.Provider, rec recorder.Recorder, mcpPro }() } - if err := interceptor.ProcessRequest(w, r); err != nil { + // Process request with circuit breaker protection if configured + if err := cbs.Execute(route, interceptor.Model(), w, func(rw http.ResponseWriter) error { + return interceptor.ProcessRequest(rw, r) + }); err != nil { if m != nil { m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID).Add(1) } @@ -214,6 +211,7 @@ func newInterceptionProcessor(p provider.Provider, rec recorder.Recorder, mcpPro } log.Debug(ctx, "interception ended") } + asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()}) // Ensure all recording have completed before completing request. diff --git a/circuit_breaker_integration_test.go b/circuit_breaker_integration_test.go index 1cf1fba6..d821dd35 100644 --- a/circuit_breaker_integration_test.go +++ b/circuit_breaker_integration_test.go @@ -2,6 +2,7 @@ package aibridge_test import ( "context" + "fmt" "io" "net" "net/http" @@ -27,6 +28,20 @@ import ( "go.opentelemetry.io/otel" ) +// Common response bodies for circuit breaker tests. +const ( + anthropicRateLimitError = `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}` + openAIRateLimitError = `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}` +) + +func anthropicSuccessResponse(model string) string { + return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"%s","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model) +} + +func openAISuccessResponse(model string) string { + return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model) +} + // TestCircuitBreaker_FullRecoveryCycle tests the complete circuit breaker lifecycle: // closed → open (after consecutive failures) → half-open (after timeout) → closed (after successful request) func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { @@ -36,10 +51,12 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { name string providerName string endpoint string + model string errorBody string successBody string requestBody string setupHeaders func(req *http.Request) + createRequest func(t *testing.T, baseURL string, input []byte) *http.Request createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider } @@ -48,13 +65,15 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { name: "Anthropic", providerName: config.ProviderAnthropic, endpoint: "/v1/messages", - errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, - successBody: `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, + model: "claude-sonnet-4-20250514", + errorBody: anthropicRateLimitError, + successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, + createRequest: createAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -67,12 +86,14 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { name: "OpenAI", providerName: config.ProviderOpenAI, endpoint: "/v1/chat/completions", - errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, - successBody: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, + model: "gpt-4o", + errorBody: openAIRateLimitError, + successBody: openAISuccessResponse("gpt-4o"), requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, + createRequest: createOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -139,9 +160,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { mockSrv.Start() makeRequest := func() *http.Response { - req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") + req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -166,13 +185,13 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open") // Verify metrics show circuit is open - trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") - state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") - rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should be 1") // Phase 3: Wait for timeout to transition to half-open @@ -188,7 +207,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") // Verify circuit is now closed - state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 0.0, state, "CircuitBreakerState should be 0 (closed) after recovery") // Phase 5: Verify circuit is fully functional again @@ -202,7 +221,7 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { assert.Equal(t, upstreamCallsBefore+4, upstreamCalls.Load(), "All requests should reach upstream after circuit closes") // Rejects count should not have increased - rejects = promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + rejects = promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should still be 1 (no new rejects)") }) } @@ -217,9 +236,11 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { name string providerName string endpoint string + model string errorBody string requestBody string setupHeaders func(req *http.Request) + createRequest func(t *testing.T, baseURL string, input []byte) *http.Request createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider } @@ -228,12 +249,14 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { name: "Anthropic", providerName: config.ProviderAnthropic, endpoint: "/v1/messages", - errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, + model: "claude-sonnet-4-20250514", + errorBody: anthropicRateLimitError, requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, + createRequest: createAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -246,11 +269,13 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { name: "OpenAI", providerName: config.ProviderOpenAI, endpoint: "/v1/chat/completions", - errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, + model: "gpt-4o", + errorBody: openAIRateLimitError, requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, + createRequest: createOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -308,9 +333,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { mockSrv.Start() makeRequest := func() *http.Response { - req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") + req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -330,7 +353,7 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { resp := makeRequest() assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") // Phase 2: Wait for half-open state @@ -348,10 +371,10 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens") // Verify metrics: trips should be 2 now (tripped twice) - trips = promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + trips = promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 2.0, trips, "CircuitBreakerTrips should be 2 after half-open failure") - state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open) after half-open failure") }) } @@ -366,10 +389,12 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { name string providerName string endpoint string + model string errorBody string successBody string requestBody string setupHeaders func(req *http.Request) + createRequest func(t *testing.T, baseURL string, input []byte) *http.Request createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider } @@ -378,13 +403,15 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { name: "Anthropic", providerName: config.ProviderAnthropic, endpoint: "/v1/messages", - errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, - successBody: `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, + model: "claude-sonnet-4-20250514", + errorBody: anthropicRateLimitError, + successBody: anthropicSuccessResponse("claude-sonnet-4-20250514"), requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("x-api-key", "test") req.Header.Set("anthropic-version", "2023-06-01") }, + createRequest: createAnthropicMessagesReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewAnthropic(config.Anthropic{ BaseURL: baseURL, @@ -397,12 +424,14 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { name: "OpenAI", providerName: config.ProviderOpenAI, endpoint: "/v1/chat/completions", - errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, - successBody: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, + model: "gpt-4o", + errorBody: openAIRateLimitError, + successBody: openAISuccessResponse("gpt-4o"), requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, setupHeaders: func(req *http.Request) { req.Header.Set("Authorization", "Bearer test-key") }, + createRequest: createOpenAIChatCompletionsReq, createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { return provider.NewOpenAI(config.OpenAI{ BaseURL: baseURL, @@ -470,9 +499,7 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { mockSrv.Start() makeRequest := func() *http.Response { - req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") + req := tc.createRequest(t, mockSrv.URL, []byte(tc.requestBody)) tc.setupHeaders(req) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -536,9 +563,138 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { "%d requests should be rejected (ErrTooManyRequests)", totalRequests-maxRequests) // Verify rejects metric increased - rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint, tc.model)) assert.Equal(t, float64(1+totalRequests-maxRequests), rejects, "CircuitBreakerRejects should include half-open rejections") }) } } + +// TestCircuitBreaker_PerModelIsolation tests that circuit breakers are independent per model. +// Rate limits on one model should not affect other models on the same endpoint. +func TestCircuitBreaker_PerModelIsolation(t *testing.T) { + t.Parallel() + + var sonnetCalls, haikuCalls atomic.Int32 + var sonnetShouldFail atomic.Bool + sonnetShouldFail.Store(true) + + // Mock upstream that returns different responses based on model in request + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-should-retry", "false") + + if strings.Contains(string(body), "claude-sonnet-4-20250514") { + sonnetCalls.Add(1) + if sonnetShouldFail.Load() { + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(anthropicRateLimitError)) + } else { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(anthropicSuccessResponse("claude-sonnet-4-20250514"))) + } + } else if strings.Contains(string(body), "claude-3-5-haiku-20241022") { + haikuCalls.Add(1) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(anthropicSuccessResponse("claude-3-5-haiku-20241022"))) + } + })) + defer mockUpstream.Close() + + m := metrics.NewMetrics(prometheus.NewRegistry()) + + cbConfig := &config.CircuitBreaker{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + } + prov := provider.NewAnthropic(config.Anthropic{ + BaseURL: mockUpstream.URL, + Key: "test-key", + CircuitBreaker: cbConfig, + }, nil) + + ctx := t.Context() + tracer := otel.Tracer("forTesting") + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + bridge, err := aibridge.NewRequestBridge(ctx, + []provider.Provider{prov}, + &testutil.MockRecorder{}, + mcp.NewServerProxyManager(nil, tracer), + logger, + m, + tracer, + ) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(bridge) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, "test-user-id", nil) + } + mockSrv.Start() + + makeRequest := func(model string) *http.Response { + body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) + req := createAnthropicMessagesReq(t, mockSrv.URL, []byte(body)) + req.Header.Set("x-api-key", "test") + req.Header.Set("anthropic-version", "2023-06-01") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + return resp + } + + // Phase 1: Trip the circuit for sonnet model + for i := uint32(0); i < cbConfig.FailureThreshold; i++ { + resp := makeRequest("claude-sonnet-4-20250514") + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load()) + + // Verify sonnet circuit is open + resp := makeRequest("claude-sonnet-4-20250514") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Sonnet circuit should be open") + assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load(), "No new sonnet calls when circuit is open") + + // Verify sonnet metrics show circuit is open + sonnetTrips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) + assert.Equal(t, 1.0, sonnetTrips, "Sonnet CircuitBreakerTrips should be 1") + + sonnetState := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) + assert.Equal(t, 1.0, sonnetState, "Sonnet CircuitBreakerState should be 1 (open)") + + // Phase 2: Haiku model should still work (independent circuit) + resp = makeRequest("claude-3-5-haiku-20241022") + assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should succeed while sonnet circuit is open") + assert.Equal(t, int32(1), haikuCalls.Load(), "Haiku call should reach upstream") + + // Make multiple haiku requests - all should succeed + for i := 0; i < 3; i++ { + resp = makeRequest("claude-3-5-haiku-20241022") + assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should continue to succeed") + } + assert.Equal(t, int32(4), haikuCalls.Load(), "All haiku calls should reach upstream") + + // Verify haiku circuit is still closed (no trips) + haikuTrips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-3-5-haiku-20241022")) + assert.Equal(t, 0.0, haikuTrips, "Haiku CircuitBreakerTrips should be 0") + + haikuState := promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-3-5-haiku-20241022")) + assert.Equal(t, 0.0, haikuState, "Haiku CircuitBreakerState should be 0 (closed)") + + // Phase 3: Sonnet recovers after timeout + time.Sleep(cbConfig.Timeout + 10*time.Millisecond) + sonnetShouldFail.Store(false) + + resp = makeRequest("claude-sonnet-4-20250514") + assert.Equal(t, http.StatusOK, resp.StatusCode, "Sonnet should recover after timeout") + + // Verify sonnet circuit is now closed + sonnetState = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) + assert.Equal(t, 0.0, sonnetState, "Sonnet CircuitBreakerState should be 0 (closed) after recovery") +} diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index 97d52c31..4be1d2b8 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -4,16 +4,18 @@ import ( "errors" "fmt" "net/http" - "strconv" - "strings" "sync" + "time" - "cdr.dev/slog/v3" "github.com/coder/aibridge/config" "github.com/coder/aibridge/metrics" "github.com/sony/gobreaker/v2" ) +// ErrCircuitOpen is returned by Execute when the circuit breaker is open +// and the request was rejected without calling the handler. +var ErrCircuitOpen = errors.New("circuit breaker is open") + // DefaultIsFailure returns true for standard HTTP status codes that typically // indicate upstream overload. func DefaultIsFailure(statusCode int) bool { @@ -27,17 +29,20 @@ func DefaultIsFailure(statusCode int) bool { } } -// ProviderCircuitBreakers manages per-endpoint circuit breakers for a single provider. +// ProviderCircuitBreakers manages per-endpoint/model circuit breakers for a single provider. type ProviderCircuitBreakers struct { provider string config config.CircuitBreaker - breakers sync.Map // endpoint -> *gobreaker.CircuitBreaker[struct{}] - onChange func(endpoint string, from, to gobreaker.State) + breakers sync.Map // "endpoint:model" -> *gobreaker.CircuitBreaker[struct{}] + onChange func(endpoint, model string, from, to gobreaker.State) + metrics *metrics.Metrics } // NewProviderCircuitBreakers creates circuit breakers for a single provider. // Returns nil if cfg is nil (no circuit breaker protection). -func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint string, from, to gobreaker.State)) *ProviderCircuitBreakers { +// onChange is called when circuit state changes. +// metrics is used to record circuit breaker reject counts (can be nil). +func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers { if cfg == nil { return nil } @@ -45,6 +50,7 @@ func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onC provider: provider, config: *cfg, onChange: onChange, + metrics: m, } } @@ -65,14 +71,15 @@ func (p *ProviderCircuitBreakers) openErrorResponse() []byte { return []byte(`{"error":"circuit breaker is open"}`) } -// Get returns the circuit breaker for an endpoint, creating it if needed. -func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker[struct{}] { - if v, ok := p.breakers.Load(endpoint); ok { +// Get returns the circuit breaker for an endpoint/model tuple, creating it if needed. +func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.CircuitBreaker[struct{}] { + key := endpoint + ":" + model + if v, ok := p.breakers.Load(key); ok { return v.(*gobreaker.CircuitBreaker[struct{}]) } settings := gobreaker.Settings{ - Name: p.provider + ":" + endpoint, + Name: p.provider + ":" + key, MaxRequests: p.config.MaxRequests, Interval: p.config.Interval, Timeout: p.config.Timeout, @@ -81,13 +88,13 @@ func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker }, OnStateChange: func(_ string, from, to gobreaker.State) { if p.onChange != nil { - p.onChange(endpoint, from, to) + p.onChange(endpoint, model, from, to) } }, } cb := gobreaker.NewCircuitBreaker[struct{}](settings) - actual, _ := p.breakers.LoadOrStore(endpoint, cb) + actual, _ := p.breakers.LoadOrStore(key, cb) return actual.(*gobreaker.CircuitBreaker[struct{}]) } @@ -126,44 +133,59 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } -// Middleware returns middleware that wraps handlers with circuit breaker protection. -// It captures the response status code to determine success/failure without provider-specific logic. -// If cbs is nil, requests pass through without circuit breaker protection. -func Middleware(cbs *ProviderCircuitBreakers, m *metrics.Metrics, logger slog.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - // No circuit breaker configured - pass through - if cbs == nil { - return next +// Execute runs the given handler function within circuit breaker protection. +// If the circuit is open, the request is rejected with a 503 response, metrics are recorded, +// and ErrCircuitOpen is returned. +// Otherwise, it returns the handler's error (or nil on success). +// The handler receives a wrapped ResponseWriter that captures the status code. +// If the receiver is nil (no circuit breaker configured), the handler is called directly. +func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error { + if p == nil { + return handler(w) + } + + cb := p.Get(endpoint, model) + + // Wrap response writer to capture status code + sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} + + var handlerErr error + _, err := cb.Execute(func() (struct{}, error) { + handlerErr = handler(sw) + if p.isFailure(sw.statusCode) { + return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) } + return struct{}{}, nil + }) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - endpoint := strings.TrimPrefix(r.URL.Path, "/"+cbs.provider) - cb := cbs.Get(endpoint) - - // Wrap response writer to capture status code - sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} - - _, err := cb.Execute(func() (struct{}, error) { - next.ServeHTTP(sw, r) - if cbs.isFailure(sw.statusCode) { - return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) - } - return struct{}{}, nil - }) - - if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { - if m != nil { - m.CircuitBreakerRejects.WithLabelValues(cbs.provider, endpoint).Inc() - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Retry-After", strconv.FormatInt(int64(cbs.config.Timeout.Seconds()), 10)) - w.WriteHeader(http.StatusServiceUnavailable) - w.Write(cbs.openErrorResponse()) - } else if err != nil { - logger.Warn(r.Context(), "unexpected circuit breaker error", slog.Error(err)) - } - }) + if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { + if p.metrics != nil { + p.metrics.CircuitBreakerRejects.WithLabelValues(p.provider, endpoint, model).Inc() + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds()))) + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write(p.openErrorResponse()) + return ErrCircuitOpen } + + return handlerErr +} + +// Timeout returns the configured timeout duration for this circuit breaker. +func (p *ProviderCircuitBreakers) Timeout() time.Duration { + return p.config.Timeout +} + +// Provider returns the provider name for this circuit breaker. +func (p *ProviderCircuitBreakers) Provider() string { + return p.provider +} + +// OpenErrorResponse returns the error response body when the circuit is open. +// This is exposed for handlers to use when responding to rejected requests. +func (p *ProviderCircuitBreakers) OpenErrorResponse() []byte { + return p.openErrorResponse() } // StateToGaugeValue converts gobreaker.State to a gauge value. diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index 0e48f6c1..18913718 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -1,102 +1,118 @@ package circuitbreaker import ( + "errors" "net/http" "net/http/httptest" "sync/atomic" "testing" "time" - "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge/config" "github.com/sony/gobreaker/v2" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestMiddleware_PerEndpointIsolation(t *testing.T) { +func TestExecute_PerModelIsolation(t *testing.T) { t.Parallel() - chatCalls := atomic.Int32{} - responsesCalls := atomic.Int32{} - - // Mock upstream - /chat returns 429, /responses returns 200 - upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/test/v1/chat/completions" { - chatCalls.Add(1) - w.WriteHeader(http.StatusTooManyRequests) - } else { - responsesCalls.Add(1) - w.WriteHeader(http.StatusOK) - } - }) + sonnetCalls := atomic.Int32{} + haikuCalls := atomic.Int32{} cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(endpoint string, from, to gobreaker.State) {}) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - handler := Middleware(cbs, nil, logger)(upstream) - server := httptest.NewServer(handler) - defer server.Close() - - // Trip circuit on /chat/completions - resp, err := http.Get(server.URL + "/test/v1/chat/completions") - require.NoError(t, err) - resp.Body.Close() - - // /chat/completions should now be blocked - resp, err = http.Get(server.URL + "/test/v1/chat/completions") - require.NoError(t, err) - resp.Body.Close() - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - assert.Equal(t, "60", resp.Header.Get("Retry-After")) // Timeout is 1 minute - assert.Equal(t, int32(1), chatCalls.Load()) // Only 1 call, second was blocked - - // /responses should still work - resp, err = http.Get(server.URL + "/test/v1/responses") - require.NoError(t, err) - resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, int32(1), responsesCalls.Load()) + }, func(endpoint, model string, from, to gobreaker.State) {}, nil) + + endpoint := "/v1/messages" + sonnetModel := "claude-sonnet-4-20250514" + haikuModel := "claude-3-5-haiku-20241022" + + // Trip circuit on sonnet model (returns 429) + w := httptest.NewRecorder() + err := cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error { + sonnetCalls.Add(1) + rw.WriteHeader(http.StatusTooManyRequests) + return nil + }) + assert.NoError(t, err) + assert.Equal(t, int32(1), sonnetCalls.Load()) + + // Second sonnet request should be blocked by circuit breaker + w = httptest.NewRecorder() + err = cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error { + sonnetCalls.Add(1) + rw.WriteHeader(http.StatusOK) + return nil + }) + assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + + // Haiku model on same endpoint should still work (independent circuit) + w = httptest.NewRecorder() + err = cbs.Execute(endpoint, haikuModel, w, func(rw http.ResponseWriter) error { + haikuCalls.Add(1) + rw.WriteHeader(http.StatusOK) + return nil + }) + assert.NoError(t, err) + assert.Equal(t, int32(1), haikuCalls.Load()) } -func TestMiddleware_NotConfigured(t *testing.T) { +func TestExecute_PerEndpointIsolation(t *testing.T) { t.Parallel() - var upstreamCalls atomic.Int32 + messagesCalls := atomic.Int32{} + completionsCalls := atomic.Int32{} - upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upstreamCalls.Add(1) - w.WriteHeader(http.StatusTooManyRequests) - }) + cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, func(endpoint, model string, from, to gobreaker.State) {}, nil) - // No circuit breaker configured (nil) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - handler := Middleware(nil, nil, logger)(upstream) - server := httptest.NewServer(handler) - defer server.Close() - - // All requests should pass through even with 429s - for i := 0; i < 10; i++ { - resp, err := http.Get(server.URL + "/test/v1/messages") - require.NoError(t, err) - resp.Body.Close() - assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) - } - assert.Equal(t, int32(10), upstreamCalls.Load()) + model := "test-model" + + // Trip circuit on /v1/messages endpoint (returns 429) + w := httptest.NewRecorder() + err := cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error { + messagesCalls.Add(1) + rw.WriteHeader(http.StatusTooManyRequests) + return nil + }) + assert.NoError(t, err) + assert.Equal(t, int32(1), messagesCalls.Load()) + + // Second /v1/messages request should be blocked + w = httptest.NewRecorder() + err = cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error { + messagesCalls.Add(1) + rw.WriteHeader(http.StatusOK) + return nil + }) + assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.Equal(t, int32(1), messagesCalls.Load()) // No new call + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + + // /v1/chat/completions on same model should still work (different endpoint) + w = httptest.NewRecorder() + err = cbs.Execute("/v1/chat/completions", model, w, func(rw http.ResponseWriter) error { + completionsCalls.Add(1) + rw.WriteHeader(http.StatusOK) + return nil + }) + assert.NoError(t, err) + assert.Equal(t, int32(1), completionsCalls.Load()) } -func TestMiddleware_CustomIsFailure(t *testing.T) { +func TestExecute_CustomIsFailure(t *testing.T) { t.Parallel() - upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadGateway) // 502 - }) + var calls atomic.Int32 // Custom IsFailure that treats 502 as failure cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ @@ -107,21 +123,70 @@ func TestMiddleware_CustomIsFailure(t *testing.T) { IsFailure: func(statusCode int) bool { return statusCode == http.StatusBadGateway }, - }, func(endpoint string, from, to gobreaker.State) {}) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - handler := Middleware(cbs, nil, logger)(upstream) - server := httptest.NewServer(handler) - defer server.Close() + }, func(endpoint, model string, from, to gobreaker.State) {}, nil) // First request returns 502, trips circuit - resp, _ := http.Get(server.URL + "/test/v1/messages") - resp.Body.Close() + w := httptest.NewRecorder() + err := cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error { + calls.Add(1) + rw.WriteHeader(http.StatusBadGateway) + return nil + }) + assert.NoError(t, err) + assert.Equal(t, int32(1), calls.Load()) // Second request should be blocked - resp, _ = http.Get(server.URL + "/test/v1/messages") - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - resp.Body.Close() + w = httptest.NewRecorder() + err = cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error { + calls.Add(1) + rw.WriteHeader(http.StatusOK) + return nil + }) + assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.Equal(t, int32(1), calls.Load()) // No new call + assert.Equal(t, http.StatusServiceUnavailable, w.Code) +} + +func TestExecute_OnStateChange(t *testing.T) { + t.Parallel() + + var stateChanges []struct { + endpoint string + model string + from gobreaker.State + to gobreaker.State + } + + cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, func(endpoint, model string, from, to gobreaker.State) { + stateChanges = append(stateChanges, struct { + endpoint string + model string + from gobreaker.State + to gobreaker.State + }{endpoint, model, from, to}) + }, nil) + + endpoint := "/v1/messages" + model := "claude-sonnet-4-20250514" + + // Trip circuit + w := httptest.NewRecorder() + cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error { + rw.WriteHeader(http.StatusTooManyRequests) + return nil + }) + + // Verify state change callback was called with correct parameters + assert.Len(t, stateChanges, 1) + assert.Equal(t, endpoint, stateChanges[0].endpoint) + assert.Equal(t, model, stateChanges[0].model) + assert.Equal(t, gobreaker.StateClosed, stateChanges[0].from) + assert.Equal(t, gobreaker.StateOpen, stateChanges[0].to) } func TestDefaultIsFailure(t *testing.T) { diff --git a/metrics/metrics.go b/metrics/metrics.go index 5ae6ac2d..0c5f5f8b 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -110,23 +110,23 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { // Circuit breaker metrics. - // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + // Pessimistic cardinality: 2 providers, 2 endpoints, 5 models = up to 20. CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ Subsystem: "circuit_breaker", Name: "state", Help: "Current state of the circuit breaker (0=closed, 0.5=half-open, 1=open).", - }, []string{"provider", "endpoint"}), - // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + }, []string{"provider", "endpoint", "model"}), + // Pessimistic cardinality: 2 providers, 2 endpoints, 5 models = up to 20. CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "trips_total", Help: "Total number of times the circuit breaker transitioned to open state.", - }, []string{"provider", "endpoint"}), - // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + }, []string{"provider", "endpoint", "model"}), + // Pessimistic cardinality: 2 providers, 2 endpoints, 5 models = up to 20. CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "rejects_total", Help: "Total number of requests rejected due to open circuit breaker.", - }, []string{"provider", "endpoint"}), + }, []string{"provider", "endpoint", "model"}), } }