diff --git a/bridge.go b/bridge.go index d0361954..9983a485 100644 --- a/bridge.go +++ b/bridge.go @@ -10,16 +10,17 @@ import ( "time" "cdr.dev/slog/v3" + "github.com/coder/aibridge/circuitbreaker" aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" + "github.com/hashicorp/go-multierror" + "github.com/sony/gobreaker/v2" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" - - "github.com/hashicorp/go-multierror" ) // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; @@ -59,22 +60,53 @@ const recordingTimeout = time.Second * 5 // A [intercept.Recorder] is also required to record prompt, tool, and token use. // // mcpProxy will be closed when the [RequestBridge] is closed. +// +// Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method. +// Providers returning nil will not have circuit breaker protection. func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) (*RequestBridge, error) { mux := http.NewServeMux() - for _, provider := range providers { + for _, prov := range providers { + // Create per-provider circuit breaker if configured + cfg := prov.CircuitBreakerConfig() + providerName := prov.Name() + onChange := func(endpoint string, from, to gobreaker.State) { + logger.Info(context.Background(), "circuit breaker state change", + slog.F("provider", providerName), + slog.F("endpoint", endpoint), + slog.F("from", from.String()), + slog.F("to", to.String()), + ) + if m != nil { + m.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(circuitbreaker.StateToGaugeValue(to)) + if to == gobreaker.StateOpen { + m.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc() + } + } + } + cbs := circuitbreaker.NewProviderCircuitBreakers(providerName, cfg, onChange) + // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). - for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, rec, mcpProxy, logger, m, tracer)) + for _, path := range prov.BridgedRoutes() { + // Initialize circuit breaker state metric to closed (0) for known routes + if m != nil && cbs != nil { + endpoint := strings.TrimPrefix(path, "/"+providerName) + m.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(0) + } + + handler := newInterceptionProcessor(prov, rec, mcpProxy, logger, m, tracer) + // Wrap with circuit breaker middleware (nil cbs passes through) + wrapped := circuitbreaker.Middleware(cbs, m, logger)(handler) + mux.Handle(path, wrapped) } // Any requests which passthrough to this will be reverse-proxied to the upstream. // // We have to whitelist the known-safe routes because an API key with elevated privileges (i.e. admin) might be // configured, so we should just reverse-proxy known-safe routes. - ftr := newPassthroughRouter(provider, logger.Named(fmt.Sprintf("passthrough.%s", provider.Name())), m, tracer) - for _, path := range provider.PassthroughRoutes() { - prefix := fmt.Sprintf("/%s", provider.Name()) + ftr := newPassthroughRouter(prov, logger.Named(fmt.Sprintf("passthrough.%s", prov.Name())), m, tracer) + for _, path := range prov.PassthroughRoutes() { + prefix := fmt.Sprintf("/%s", prov.Name()) route := fmt.Sprintf("%s%s", prefix, path) mux.HandleFunc(route, http.StripPrefix(prefix, ftr).ServeHTTP) } @@ -100,7 +132,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] rec. -func newInterceptionProcessor(p Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc { +func newInterceptionProcessor(p provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() diff --git a/circuit_breaker_integration_test.go b/circuit_breaker_integration_test.go new file mode 100644 index 00000000..1cf1fba6 --- /dev/null +++ b/circuit_breaker_integration_test.go @@ -0,0 +1,544 @@ +package aibridge_test + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge" + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/metrics" + "github.com/coder/aibridge/provider" + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" +) + +// TestCircuitBreaker_FullRecoveryCycle tests the complete circuit breaker lifecycle: +// closed → open (after consecutive failures) → half-open (after timeout) → closed (after successful request) +func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + providerName string + endpoint string + errorBody string + successBody string + requestBody string + setupHeaders func(req *http.Request) + createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider + } + + tests := []testCase{ + { + name: "Anthropic", + providerName: config.ProviderAnthropic, + endpoint: "/v1/messages", + errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, + successBody: `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, + requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, + setupHeaders: func(req *http.Request) { + req.Header.Set("x-api-key", "test") + req.Header.Set("anthropic-version", "2023-06-01") + }, + createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { + return provider.NewAnthropic(config.Anthropic{ + BaseURL: baseURL, + Key: "test-key", + CircuitBreaker: cbConfig, + }, nil) + }, + }, + { + name: "OpenAI", + providerName: config.ProviderOpenAI, + endpoint: "/v1/chat/completions", + errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, + successBody: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, + requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, + setupHeaders: func(req *http.Request) { + req.Header.Set("Authorization", "Bearer test-key") + }, + createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { + return provider.NewOpenAI(config.OpenAI{ + BaseURL: baseURL, + Key: "test-key", + CircuitBreaker: cbConfig, + }) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + var shouldFail atomic.Bool + shouldFail.Store(true) + + // Mock upstream that returns 429 or 200 based on shouldFail flag. + // x-should-retry: false is required to disable SDK automatic retries (default MaxRetries=2). + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-should-retry", "false") + if shouldFail.Load() { + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(tc.errorBody)) + } else { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(tc.successBody)) + } + })) + defer mockUpstream.Close() + + metrics := metrics.NewMetrics(prometheus.NewRegistry()) + + // Create provider with circuit breaker config + cbConfig := &config.CircuitBreaker{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + } + prov := tc.createProvider(mockUpstream.URL, cbConfig) + + ctx := t.Context() + tracer := otel.Tracer("forTesting") + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + bridge, err := aibridge.NewRequestBridge(ctx, + []provider.Provider{prov}, + &testutil.MockRecorder{}, + mcp.NewServerProxyManager(nil, tracer), + logger, + metrics, + tracer, + ) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(bridge) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, "test-user-id", nil) + } + mockSrv.Start() + + makeRequest := func() *http.Response { + req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + tc.setupHeaders(req) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + return resp + } + + // Phase 1: Trip the circuit breaker + // First FailureThreshold requests hit upstream, get 429 + for i := uint32(0); i < cbConfig.FailureThreshold; i++ { + resp := makeRequest() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load()) + + // Phase 2: Verify circuit is open + // Request should be blocked by circuit breaker (no upstream call) + resp := makeRequest() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open") + + // Verify metrics show circuit is open + trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") + + state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") + + rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should be 1") + + // Phase 3: Wait for timeout to transition to half-open + time.Sleep(cbConfig.Timeout + 10*time.Millisecond) + + // Switch upstream to return success + shouldFail.Store(false) + + // Phase 4: Recovery - request in half-open state should succeed and close circuit + upstreamCallsBefore := upstreamCalls.Load() + resp = makeRequest() + assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in half-open state") + assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") + + // Verify circuit is now closed + state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 0.0, state, "CircuitBreakerState should be 0 (closed) after recovery") + + // Phase 5: Verify circuit is fully functional again + // Multiple requests should all succeed and reach upstream + for i := 0; i < 3; i++ { + resp = makeRequest() + assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed after circuit closes") + } + + // All requests should have reached upstream + assert.Equal(t, upstreamCallsBefore+4, upstreamCalls.Load(), "All requests should reach upstream after circuit closes") + + // Rejects count should not have increased + rejects = promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should still be 1 (no new rejects)") + }) + } +} + +// TestCircuitBreaker_HalfOpenFailure tests that a failed request in half-open state +// returns the circuit to open: closed → open → half-open → open +func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + providerName string + endpoint string + errorBody string + requestBody string + setupHeaders func(req *http.Request) + createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider + } + + tests := []testCase{ + { + name: "Anthropic", + providerName: config.ProviderAnthropic, + endpoint: "/v1/messages", + errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, + requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, + setupHeaders: func(req *http.Request) { + req.Header.Set("x-api-key", "test") + req.Header.Set("anthropic-version", "2023-06-01") + }, + createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { + return provider.NewAnthropic(config.Anthropic{ + BaseURL: baseURL, + Key: "test-key", + CircuitBreaker: cbConfig, + }, nil) + }, + }, + { + name: "OpenAI", + providerName: config.ProviderOpenAI, + endpoint: "/v1/chat/completions", + errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, + requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, + setupHeaders: func(req *http.Request) { + req.Header.Set("Authorization", "Bearer test-key") + }, + createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { + return provider.NewOpenAI(config.OpenAI{ + BaseURL: baseURL, + Key: "test-key", + CircuitBreaker: cbConfig, + }) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + + // Mock upstream that always returns 429. + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-should-retry", "false") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(tc.errorBody)) + })) + defer mockUpstream.Close() + + metrics := metrics.NewMetrics(prometheus.NewRegistry()) + + cbConfig := &config.CircuitBreaker{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + } + prov := tc.createProvider(mockUpstream.URL, cbConfig) + + ctx := t.Context() + tracer := otel.Tracer("forTesting") + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + bridge, err := aibridge.NewRequestBridge(ctx, + []provider.Provider{prov}, + &testutil.MockRecorder{}, + mcp.NewServerProxyManager(nil, tracer), + logger, + metrics, + tracer, + ) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(bridge) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, "test-user-id", nil) + } + mockSrv.Start() + + makeRequest := func() *http.Response { + req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + tc.setupHeaders(req) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + return resp + } + + // Phase 1: Trip the circuit + for i := uint32(0); i < cbConfig.FailureThreshold; i++ { + resp := makeRequest() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + + // Verify circuit is open + resp := makeRequest() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + + trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") + + // Phase 2: Wait for half-open state + time.Sleep(cbConfig.Timeout + 10*time.Millisecond) + + // Phase 3: Request in half-open state fails, circuit should re-open + upstreamCallsBefore := upstreamCalls.Load() + resp = makeRequest() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "Request should fail in half-open state") + assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") + + // Circuit should be open again - next request should be rejected immediately + resp = makeRequest() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Circuit should be open again after half-open failure") + assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens") + + // Verify metrics: trips should be 2 now (tripped twice) + trips = promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 2.0, trips, "CircuitBreakerTrips should be 2 after half-open failure") + + state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open) after half-open failure") + }) + } +} + +// TestCircuitBreaker_HalfOpenMaxRequests tests that MaxRequests limits concurrent +// requests in half-open state. Requests beyond the limit should be rejected. +func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + providerName string + endpoint string + errorBody string + successBody string + requestBody string + setupHeaders func(req *http.Request) + createProvider func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider + } + + tests := []testCase{ + { + name: "Anthropic", + providerName: config.ProviderAnthropic, + endpoint: "/v1/messages", + errorBody: `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`, + successBody: `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, + requestBody: `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, + setupHeaders: func(req *http.Request) { + req.Header.Set("x-api-key", "test") + req.Header.Set("anthropic-version", "2023-06-01") + }, + createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { + return provider.NewAnthropic(config.Anthropic{ + BaseURL: baseURL, + Key: "test-key", + CircuitBreaker: cbConfig, + }, nil) + }, + }, + { + name: "OpenAI", + providerName: config.ProviderOpenAI, + endpoint: "/v1/chat/completions", + errorBody: `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`, + successBody: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, + requestBody: `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`, + setupHeaders: func(req *http.Request) { + req.Header.Set("Authorization", "Bearer test-key") + }, + createProvider: func(baseURL string, cbConfig *config.CircuitBreaker) provider.Provider { + return provider.NewOpenAI(config.OpenAI{ + BaseURL: baseURL, + Key: "test-key", + CircuitBreaker: cbConfig, + }) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + var shouldFail atomic.Bool + shouldFail.Store(true) + + // Upstream is slow to ensure concurrent requests overlap in half-open state. + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-should-retry", "false") + if shouldFail.Load() { + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(tc.errorBody)) + } else { + // Slow response to ensure requests overlap + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(tc.successBody)) + } + })) + defer mockUpstream.Close() + + metrics := metrics.NewMetrics(prometheus.NewRegistry()) + + const maxRequests = 2 + cbConfig := &config.CircuitBreaker{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: maxRequests, // Allow only 2 concurrent requests in half-open + } + prov := tc.createProvider(mockUpstream.URL, cbConfig) + + ctx := t.Context() + tracer := otel.Tracer("forTesting") + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + bridge, err := aibridge.NewRequestBridge(ctx, + []provider.Provider{prov}, + &testutil.MockRecorder{}, + mcp.NewServerProxyManager(nil, tracer), + logger, + metrics, + tracer, + ) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(bridge) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, "test-user-id", nil) + } + mockSrv.Start() + + makeRequest := func() *http.Response { + req, err := http.NewRequest("POST", mockSrv.URL+"/"+tc.providerName+tc.endpoint, strings.NewReader(tc.requestBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + tc.setupHeaders(req) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + resp.Body.Close() + return resp + } + + // Phase 1: Trip the circuit + for i := uint32(0); i < cbConfig.FailureThreshold; i++ { + resp := makeRequest() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + + // Verify circuit is open + resp := makeRequest() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + + // Phase 2: Wait for half-open state and switch upstream to success + time.Sleep(cbConfig.Timeout + 10*time.Millisecond) + shouldFail.Store(false) + upstreamCalls.Store(0) + + // Phase 3: Send concurrent requests (more than MaxRequests) + const totalRequests = 5 + var wg sync.WaitGroup + responses := make(chan int, totalRequests) + + for i := 0; i < totalRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp := makeRequest() + responses <- resp.StatusCode + }() + } + + wg.Wait() + close(responses) + + // Count results + var successCount, rejectedCount int + for status := range responses { + switch status { + case http.StatusOK: + successCount++ + case http.StatusServiceUnavailable: + rejectedCount++ + } + } + + // Verify only MaxRequests reached upstream + assert.Equal(t, int32(maxRequests), upstreamCalls.Load(), + "Only MaxRequests (%d) should reach upstream in half-open state", maxRequests) + + // Verify request counts + assert.Equal(t, maxRequests, successCount, + "Only %d requests should succeed (MaxRequests)", maxRequests) + assert.Equal(t, totalRequests-maxRequests, rejectedCount, + "%d requests should be rejected (ErrTooManyRequests)", totalRequests-maxRequests) + + // Verify rejects metric increased + rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(tc.providerName, tc.endpoint)) + assert.Equal(t, float64(1+totalRequests-maxRequests), rejects, + "CircuitBreakerRejects should include half-open rejections") + }) + } +} diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go new file mode 100644 index 00000000..97d52c31 --- /dev/null +++ b/circuitbreaker/circuitbreaker.go @@ -0,0 +1,182 @@ +package circuitbreaker + +import ( + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + + "cdr.dev/slog/v3" + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/metrics" + "github.com/sony/gobreaker/v2" +) + +// DefaultIsFailure returns true for standard HTTP status codes that typically +// indicate upstream overload. +func DefaultIsFailure(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests, // 429 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout: // 504 + return true + default: + return false + } +} + +// ProviderCircuitBreakers manages per-endpoint circuit breakers for a single provider. +type ProviderCircuitBreakers struct { + provider string + config config.CircuitBreaker + breakers sync.Map // endpoint -> *gobreaker.CircuitBreaker[struct{}] + onChange func(endpoint string, from, to gobreaker.State) +} + +// NewProviderCircuitBreakers creates circuit breakers for a single provider. +// Returns nil if cfg is nil (no circuit breaker protection). +func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint string, from, to gobreaker.State)) *ProviderCircuitBreakers { + if cfg == nil { + return nil + } + return &ProviderCircuitBreakers{ + provider: provider, + config: *cfg, + onChange: onChange, + } +} + +// isFailure checks if the status code should count as a failure. +// Falls back to DefaultIsFailure if no custom function is configured. +func (p *ProviderCircuitBreakers) isFailure(statusCode int) bool { + if p.config.IsFailure != nil { + return p.config.IsFailure(statusCode) + } + return DefaultIsFailure(statusCode) +} + +// openErrorResponse returns the error response body when the circuit is open. +func (p *ProviderCircuitBreakers) openErrorResponse() []byte { + if p.config.OpenErrorResponse != nil { + return p.config.OpenErrorResponse() + } + return []byte(`{"error":"circuit breaker is open"}`) +} + +// Get returns the circuit breaker for an endpoint, creating it if needed. +func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker[struct{}] { + if v, ok := p.breakers.Load(endpoint); ok { + return v.(*gobreaker.CircuitBreaker[struct{}]) + } + + settings := gobreaker.Settings{ + Name: p.provider + ":" + endpoint, + MaxRequests: p.config.MaxRequests, + Interval: p.config.Interval, + Timeout: p.config.Timeout, + ReadyToTrip: func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= p.config.FailureThreshold + }, + OnStateChange: func(_ string, from, to gobreaker.State) { + if p.onChange != nil { + p.onChange(endpoint, from, to) + } + }, + } + + cb := gobreaker.NewCircuitBreaker[struct{}](settings) + actual, _ := p.breakers.LoadOrStore(endpoint, cb) + return actual.(*gobreaker.CircuitBreaker[struct{}]) +} + +// statusCapturingWriter wraps http.ResponseWriter to capture the status code. +// It also implements http.Flusher to support streaming responses. +type statusCapturingWriter struct { + http.ResponseWriter + statusCode int + headerWritten bool +} + +func (w *statusCapturingWriter) WriteHeader(code int) { + if !w.headerWritten { + w.statusCode = code + w.headerWritten = true + } + w.ResponseWriter.WriteHeader(code) +} + +func (w *statusCapturingWriter) Write(b []byte) (int, error) { + if !w.headerWritten { + w.statusCode = http.StatusOK + w.headerWritten = true + } + return w.ResponseWriter.Write(b) +} + +func (w *statusCapturingWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// Unwrap returns the underlying ResponseWriter for interface checks. +func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +// Middleware returns middleware that wraps handlers with circuit breaker protection. +// It captures the response status code to determine success/failure without provider-specific logic. +// If cbs is nil, requests pass through without circuit breaker protection. +func Middleware(cbs *ProviderCircuitBreakers, m *metrics.Metrics, logger slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + // No circuit breaker configured - pass through + if cbs == nil { + return next + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + endpoint := strings.TrimPrefix(r.URL.Path, "/"+cbs.provider) + cb := cbs.Get(endpoint) + + // Wrap response writer to capture status code + sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} + + _, err := cb.Execute(func() (struct{}, error) { + next.ServeHTTP(sw, r) + if cbs.isFailure(sw.statusCode) { + return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) + } + return struct{}{}, nil + }) + + if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { + if m != nil { + m.CircuitBreakerRejects.WithLabelValues(cbs.provider, endpoint).Inc() + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", strconv.FormatInt(int64(cbs.config.Timeout.Seconds()), 10)) + w.WriteHeader(http.StatusServiceUnavailable) + w.Write(cbs.openErrorResponse()) + } else if err != nil { + logger.Warn(r.Context(), "unexpected circuit breaker error", slog.Error(err)) + } + }) + } +} + +// StateToGaugeValue converts gobreaker.State to a gauge value. +// closed=0, half-open=0.5, open=1 +func StateToGaugeValue(s gobreaker.State) float64 { + switch s { + case gobreaker.StateClosed: + return 0 + case gobreaker.StateHalfOpen: + return 0.5 + case gobreaker.StateOpen: + return 1 + default: + return 0 + } +} diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go new file mode 100644 index 00000000..0e48f6c1 --- /dev/null +++ b/circuitbreaker/circuitbreaker_test.go @@ -0,0 +1,155 @@ +package circuitbreaker + +import ( + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge/config" + "github.com/sony/gobreaker/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMiddleware_PerEndpointIsolation(t *testing.T) { + t.Parallel() + + chatCalls := atomic.Int32{} + responsesCalls := atomic.Int32{} + + // Mock upstream - /chat returns 429, /responses returns 200 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test/v1/chat/completions" { + chatCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + } else { + responsesCalls.Add(1) + w.WriteHeader(http.StatusOK) + } + }) + + cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, func(endpoint string, from, to gobreaker.State) {}) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + handler := Middleware(cbs, nil, logger)(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // Trip circuit on /chat/completions + resp, err := http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + + // /chat/completions should now be blocked + resp, err = http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, "60", resp.Header.Get("Retry-After")) // Timeout is 1 minute + assert.Equal(t, int32(1), chatCalls.Load()) // Only 1 call, second was blocked + + // /responses should still work + resp, err = http.Get(server.URL + "/test/v1/responses") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(1), responsesCalls.Load()) +} + +func TestMiddleware_NotConfigured(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + }) + + // No circuit breaker configured (nil) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + handler := Middleware(nil, nil, logger)(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // All requests should pass through even with 429s + for i := 0; i < 10; i++ { + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(10), upstreamCalls.Load()) +} + +func TestMiddleware_CustomIsFailure(t *testing.T) { + t.Parallel() + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) // 502 + }) + + // Custom IsFailure that treats 502 as failure + cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + IsFailure: func(statusCode int) bool { + return statusCode == http.StatusBadGateway + }, + }, func(endpoint string, from, to gobreaker.State) {}) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + handler := Middleware(cbs, nil, logger)(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // First request returns 502, trips circuit + resp, _ := http.Get(server.URL + "/test/v1/messages") + resp.Body.Close() + + // Second request should be blocked + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + resp.Body.Close() +} + +func TestDefaultIsFailure(t *testing.T) { + t.Parallel() + + tests := []struct { + statusCode int + isFailure bool + }{ + {http.StatusOK, false}, + {http.StatusBadRequest, false}, + {http.StatusUnauthorized, false}, + {http.StatusTooManyRequests, true}, // 429 + {http.StatusInternalServerError, false}, + {http.StatusBadGateway, false}, + {http.StatusServiceUnavailable, true}, // 503 + {http.StatusGatewayTimeout, true}, // 504 + } + + for _, tt := range tests { + assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode) + } +} + +func TestStateToGaugeValue(t *testing.T) { + t.Parallel() + + assert.Equal(t, float64(0), StateToGaugeValue(gobreaker.StateClosed)) + assert.Equal(t, float64(0.5), StateToGaugeValue(gobreaker.StateHalfOpen)) + assert.Equal(t, float64(1), StateToGaugeValue(gobreaker.StateOpen)) +} diff --git a/config/config.go b/config/config.go index aa091911..5b953989 100644 --- a/config/config.go +++ b/config/config.go @@ -1,13 +1,44 @@ package config +import "time" + const ( ProviderAnthropic = "anthropic" ProviderOpenAI = "openai" ) +// CircuitBreaker holds configuration for circuit breakers. +type CircuitBreaker struct { + // MaxRequests is the maximum number of requests allowed in half-open state. + MaxRequests uint32 + // Interval is the cyclic period of the closed state for clearing internal counts. + Interval time.Duration + // Timeout is how long the circuit stays open before transitioning to half-open. + Timeout time.Duration + // FailureThreshold is the number of consecutive failures that triggers the circuit to open. + FailureThreshold uint32 + // IsFailure determines if a status code should count as a failure. + // If nil, defaults to DefaultIsFailure. + IsFailure func(statusCode int) bool + // OpenErrorResponse returns the response body when the circuit is open. + // This should match the provider's error format. + OpenErrorResponse func() []byte +} + +// DefaultCircuitBreaker returns sensible defaults for circuit breaker configuration. +func DefaultCircuitBreaker() CircuitBreaker { + return CircuitBreaker{ + FailureThreshold: 5, + Interval: 10 * time.Second, + Timeout: 30 * time.Second, + MaxRequests: 3, + } +} + type Anthropic struct { - BaseURL string - Key string + BaseURL string + Key string + CircuitBreaker *CircuitBreaker } type AWSBedrock struct { @@ -20,6 +51,7 @@ type AWSBedrock struct { } type OpenAI struct { - BaseURL string - Key string + BaseURL string + Key string + CircuitBreaker *CircuitBreaker } diff --git a/go.mod b/go.mod index 45aceca9..829d8374 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/mark3labs/mcp-go v0.38.0 github.com/prometheus/client_golang v1.23.2 + github.com/sony/gobreaker/v2 v2.3.0 github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 07932903..64d8e169 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sony/gobreaker/v2 v2.3.0 h1:7VYxZ69QXRQ2Q4eEawHn6eU4FiuwovzJwsUMA03Lu4I= +github.com/sony/gobreaker/v2 v2.3.0/go.mod h1:pTyFJgcZ3h2tdQVLZZruK2C0eoFL1fb/G83wK1ZQl+s= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/metrics/metrics.go b/metrics/metrics.go index bd654d78..5ae6ac2d 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -28,6 +28,11 @@ type Metrics struct { // Tool-related metrics. InjectedToolUseCount *prometheus.CounterVec NonInjectedToolUseCount *prometheus.CounterVec + + // Circuit breaker metrics. + CircuitBreakerState *prometheus.GaugeVec // Current state (0=closed, 0.5=half-open, 1=open) + CircuitBreakerTrips *prometheus.CounterVec // Total times circuit opened + CircuitBreakerRejects *prometheus.CounterVec // Requests rejected due to open circuit } // NewMetrics creates AND registers metrics. It will panic if a collector has already been registered. @@ -102,5 +107,26 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { Name: "total", Help: "The number of times an AI model selected a tool to be invoked by the client.", }, append(baseLabels, "name")), + + // Circuit breaker metrics. + + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ + Subsystem: "circuit_breaker", + Name: "state", + Help: "Current state of the circuit breaker (0=closed, 0.5=half-open, 1=open).", + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "trips_total", + Help: "Total number of times the circuit breaker transitioned to open state.", + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "rejects_total", + Help: "Total number of requests rejected due to open circuit breaker.", + }, []string{"provider", "endpoint"}), } } diff --git a/provider/anthropic.go b/provider/anthropic.go index 36a5f075..b90fa5f5 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -7,6 +7,7 @@ import ( "net/http" "os" + "github.com/coder/aibridge/circuitbreaker" "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/messages" @@ -26,6 +27,18 @@ type Anthropic struct { const routeMessages = "/anthropic/v1/messages" // https://docs.anthropic.com/en/api/messages +var anthropicOpenErrorResponse = func() []byte { + return []byte(`{"type":"error","error":{"type":"overloaded_error","message":"circuit breaker is open"}}`) +} + +var anthropicIsFailure = func(statusCode int) bool { + // https://platform.claude.com/docs/en/api/errors + if statusCode == 529 { + return true + } + return circuitbreaker.DefaultIsFailure(statusCode) +} + func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropic { if cfg.BaseURL == "" { cfg.BaseURL = "https://api.anthropic.com/" @@ -33,6 +46,10 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi if cfg.Key == "" { cfg.Key = os.Getenv("ANTHROPIC_API_KEY") } + if cfg.CircuitBreaker != nil { + cfg.CircuitBreaker.IsFailure = anthropicIsFailure + cfg.CircuitBreaker.OpenErrorResponse = anthropicOpenErrorResponse + } return &Anthropic{ cfg: cfg, @@ -102,3 +119,7 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), p.cfg.Key) } + +func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker { + return p.cfg.CircuitBreaker +} diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go new file mode 100644 index 00000000..49b2c6c3 --- /dev/null +++ b/provider/anthropic_test.go @@ -0,0 +1,31 @@ +package provider + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_anthropicIsFailure(t *testing.T) { + t.Parallel() + + tests := []struct { + statusCode int + isFailure bool + }{ + {http.StatusOK, false}, + {http.StatusBadRequest, false}, + {http.StatusUnauthorized, false}, + {http.StatusTooManyRequests, true}, // 429 + {http.StatusInternalServerError, false}, + {http.StatusBadGateway, false}, + {http.StatusServiceUnavailable, true}, // 503 + {http.StatusGatewayTimeout, true}, // 504 + {529, true}, // Anthropic Overloaded + } + + for _, tt := range tests { + assert.Equal(t, tt.isFailure, anthropicIsFailure(tt.statusCode), "status code %d", tt.statusCode) + } +} diff --git a/provider/openai.go b/provider/openai.go index 8e91b109..24ce3d81 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -22,10 +22,15 @@ const ( routeResponses = "/openai/v1/responses" // https://platform.openai.com/docs/api-reference/responses ) +var openAIOpenErrorResponse = func() []byte { + return []byte(`{"error":{"message":"circuit breaker is open","type":"server_error","code":"service_unavailable"}}`) +} + // OpenAI allows for interactions with the OpenAI API. type OpenAI struct { - baseURL string - key string + baseURL string + key string + circuitBreaker *config.CircuitBreaker } var _ Provider = &OpenAI{} @@ -39,9 +44,14 @@ func NewOpenAI(cfg config.OpenAI) *OpenAI { cfg.Key = os.Getenv("OPENAI_API_KEY") } + if cfg.CircuitBreaker != nil { + cfg.CircuitBreaker.OpenErrorResponse = openAIOpenErrorResponse + } + return &OpenAI{ - baseURL: cfg.BaseURL, - key: cfg.Key, + baseURL: cfg.BaseURL, + key: cfg.Key, + circuitBreaker: cfg.CircuitBreaker, } } @@ -132,3 +142,7 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), "Bearer "+p.key) } + +func (p *OpenAI) CircuitBreakerConfig() *config.CircuitBreaker { + return p.circuitBreaker +} diff --git a/provider/provider.go b/provider/provider.go index 562b55bb..2013a0cb 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -4,6 +4,7 @@ import ( "errors" "net/http" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept" "go.opentelemetry.io/otel/trace" ) @@ -37,4 +38,7 @@ type Provider interface { AuthHeader() string // InjectAuthHeader allows [Provider]s to set its authentication header. InjectAuthHeader(*http.Header) + + // CircuitBreakerConfig returns the circuit breaker configuration for the provider. + CircuitBreakerConfig() *config.CircuitBreaker }