From 2d43f46dab65acd233ec2214d36dd63fd0a66ae0 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Thu, 12 Mar 2026 12:32:10 +0000 Subject: [PATCH 1/3] feat: forward client request headers to upstream providers in bridge routes --- intercept/chatcompletions/base.go | 16 ++ intercept/chatcompletions/blocking.go | 22 +- intercept/chatcompletions/streaming.go | 22 +- intercept/chatcompletions/streaming_test.go | 10 +- intercept/client_headers.go | 70 +++++++ intercept/client_headers_test.go | 219 ++++++++++++++++++++ intercept/messages/base.go | 16 ++ intercept/messages/blocking.go | 28 ++- intercept/messages/streaming.go | 27 ++- intercept/responses/base.go | 27 ++- intercept/responses/blocking.go | 28 ++- intercept/responses/streaming.go | 28 ++- provider/anthropic.go | 4 +- provider/anthropic_test.go | 13 +- provider/copilot.go | 8 +- provider/copilot_test.go | 22 +- provider/openai.go | 8 +- provider/openai_test.go | 90 ++++++++ 18 files changed, 583 insertions(+), 75 deletions(-) create mode 100644 intercept/client_headers.go create mode 100644 intercept/client_headers_test.go diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 7a755e06..086b1c78 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -9,6 +9,7 @@ import ( "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -29,6 +30,10 @@ type interceptionBase struct { req *ChatCompletionNewParamsWrapper cfg config.OpenAI + // clientHeaders are the original HTTP headers from the client request. + clientHeaders http.Header + authHeaderName string + logger slog.Logger tracer trace.Tracer @@ -41,10 +46,21 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. + // TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192 for key, value := range i.cfg.ExtraHeaders { opts = append(opts, option.WithHeader(key, value)) } + // Forward client headers to upstream. This middleware runs after the SDK + // has built the request, and replaces the outgoing headers with the sanitized + // client headers plus provider auth. + if i.clientHeaders != nil { + opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName) + return next(req) + })) + } + // Add API dump middleware if configured if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index 9a84d143..2816ed7a 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -28,12 +28,21 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor( + id uuid.UUID, + req *ChatCompletionNewParamsWrapper, + cfg config.OpenAI, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - cfg: cfg, - tracer: tracer, + id: id, + req: req, + cfg: cfg, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }} } @@ -78,6 +87,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*600)) + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) } diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index ff3b78c6..231601c0 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -33,12 +33,21 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor( + id uuid.UUID, + req *ChatCompletionNewParamsWrapper, + cfg config.OpenAI, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - cfg: cfg, - tracer: tracer, + id: id, + req: req, + cfg: cfg, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }} } @@ -121,6 +130,9 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) var opts []option.RequestOption + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) } diff --git a/intercept/chatcompletions/streaming_test.go b/intercept/chatcompletions/streaming_test.go index 7d8d4d57..233831e6 100644 --- a/intercept/chatcompletions/streaming_test.go +++ b/intercept/chatcompletions/streaming_test.go @@ -81,16 +81,16 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { Stream: true, } + // Create test request + w := httptest.NewRecorder() + httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) + tracer := otel.Tracer("test") - interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, tracer) + interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, httpReq.Header, "Authorization", tracer) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) interceptor.Setup(logger, &testutil.MockRecorder{}, nil) - // Create test request - w := httptest.NewRecorder() - httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) - // Process the request err := interceptor.ProcessRequest(w, httpReq) diff --git a/intercept/client_headers.go b/intercept/client_headers.go new file mode 100644 index 00000000..6eaae541 --- /dev/null +++ b/intercept/client_headers.go @@ -0,0 +1,70 @@ +package intercept + +import "net/http" + +// hopByHopHeaders are connection-level headers specific to the connection +// between client and AI Bridge, not meant for the upstream. +// See https://www.rfc-editor.org/rfc/rfc2616#section-13.5.1 +var hopByHopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +// nonForwardedHeaders are transport-level headers managed by aibridge or +// Go's HTTP transport that must not be forwarded to the upstream provider. +var nonForwardedHeaders = []string{ + "Host", + "Accept-Encoding", + "Content-Length", +} + +// authHeaders are headers that carry authentication credentials from the +// client. These are stripped because the SDK re-injects the correct +// provider credentials (API key or per-user token). +var authHeaders = []string{ + "Authorization", + "X-Api-Key", +} + +// SanitizeClientHeaders returns a copy of the client headers with hop-by-hop, +// transport, and auth headers removed. +func SanitizeClientHeaders(clientHeaders http.Header) http.Header { + sanitized := clientHeaders.Clone() + for _, h := range hopByHopHeaders { + sanitized.Del(h) + } + for _, h := range nonForwardedHeaders { + sanitized.Del(h) + } + for _, h := range authHeaders { + sanitized.Del(h) + } + return sanitized +} + +// BuildUpstreamHeaders produces the header set for an upstream SDK request. +// It starts from the sanitized client headers, then preserves specific +// headers from the SDK-built request that must not be overwritten. +func BuildUpstreamHeaders(sdkHeader http.Header, clientHeaders http.Header, authHeaderName string) http.Header { + headers := SanitizeClientHeaders(clientHeaders) + + // Preserve the auth header set by the SDK from the provider configuration. + if v := sdkHeader.Get(authHeaderName); v != "" { + headers.Set(authHeaderName, v) + } + + // Preserve actor headers injected by aibridge as per-request SDK options. + for name, values := range sdkHeader { + if IsActorHeader(name) { + headers[name] = values + } + } + + return headers +} diff --git a/intercept/client_headers_test.go b/intercept/client_headers_test.go new file mode 100644 index 00000000..c696e82a --- /dev/null +++ b/intercept/client_headers_test.go @@ -0,0 +1,219 @@ +package intercept + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSanitizeClientHeaders(t *testing.T) { + t.Parallel() + + t.Run("nil input returns empty header", func(t *testing.T) { + t.Parallel() + + result := SanitizeClientHeaders(nil) + require.Empty(t, result) + }) + + t.Run("hop-by-hop headers are removed", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "Connection": {"keep-alive"}, + "Keep-Alive": {"timeout=5"}, + "Transfer-Encoding": {"chunked"}, + "Upgrade": {"websocket"}, + "X-Custom": {"preserved"}, + } + + result := SanitizeClientHeaders(input) + + assert.Empty(t, result.Get("Connection")) + assert.Empty(t, result.Get("Keep-Alive")) + assert.Empty(t, result.Get("Transfer-Encoding")) + assert.Empty(t, result.Get("Upgrade")) + assert.Equal(t, "preserved", result.Get("X-Custom")) + }) + + t.Run("non-forwarded headers are removed", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "Host": {"example.com"}, + "Accept-Encoding": {"gzip"}, + "Content-Length": {"42"}, + "X-Custom": {"preserved"}, + } + + result := SanitizeClientHeaders(input) + + assert.Empty(t, result.Get("Host")) + assert.Empty(t, result.Get("Accept-Encoding")) + assert.Empty(t, result.Get("Content-Length")) + assert.Equal(t, "preserved", result.Get("X-Custom")) + }) + + t.Run("auth headers are removed", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "Authorization": {"Bearer coder-session-token"}, + "X-Api-Key": {"sk-client-key"}, + "X-Custom": {"preserved"}, + } + + result := SanitizeClientHeaders(input) + + assert.Empty(t, result.Get("Authorization")) + assert.Empty(t, result.Get("X-Api-Key")) + assert.Equal(t, "preserved", result.Get("X-Custom")) + }) + + t.Run("multi-value headers are preserved", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "X-Custom": {"value-1", "value-2"}, + } + + result := SanitizeClientHeaders(input) + + require.Equal(t, []string{"value-1", "value-2"}, result["X-Custom"]) + }) + + t.Run("input is not mutated", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "Connection": {"keep-alive"}, + "X-Custom": {"preserved"}, + } + originalCopy := input.Clone() + + _ = SanitizeClientHeaders(input) + + require.Equal(t, originalCopy, input) + }) +} + +func TestBuildUpstreamHeaders(t *testing.T) { + t.Parallel() + + t.Run("preserves auth from SDK", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + "Authorization": {"Bearer sk-provider-key"}, + } + clientHeaders := http.Header{ + "Authorization": {"Bearer coder-session-token"}, + "User-Agent": {"claude-code/1.0"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + assert.Equal(t, "Bearer sk-provider-key", result.Get("Authorization")) + assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) + }) + + t.Run("preserves X-Api-Key from SDK and strips client Authorization", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + "X-Api-Key": {"sk-ant-provider-key"}, + } + clientHeaders := http.Header{ + "X-Api-Key": {"sk-ant-client-key"}, + "Authorization": {"Bearer coder-session-token"}, + "Anthropic-Beta": {"prompt-caching-2024-07-31"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key") + + assert.Equal(t, "sk-ant-provider-key", result.Get("X-Api-Key")) + assert.Empty(t, result.Get("Authorization")) + assert.Equal(t, "prompt-caching-2024-07-31", result.Get("Anthropic-Beta")) + }) + + t.Run("preserves actor headers from SDK", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + "Authorization": {"Bearer sk-key"}, + "X-Ai-Bridge-Actor-Id": {"user-123"}, + "X-Ai-Bridge-Actor-Metadata-Name": {"alice"}, + } + clientHeaders := http.Header{ + "Authorization": {"Bearer coder-token"}, + "User-Agent": {"claude-code/1.0"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + assert.Equal(t, "Bearer sk-key", result.Get("Authorization")) + assert.Equal(t, "user-123", result.Get("X-Ai-Bridge-Actor-Id")) + assert.Equal(t, "alice", result.Get("X-Ai-Bridge-Actor-Metadata-Name")) + assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) + }) + + t.Run("strips hop-by-hop and transport headers", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + "Authorization": {"Bearer sk-key"}, + } + clientHeaders := http.Header{ + "Connection": {"keep-alive"}, + "Host": {"bridge.example.com"}, + "Content-Length": {"99"}, + "Accept-Encoding": {"gzip"}, + "Transfer-Encoding": {"chunked"}, + "User-Agent": {"claude-code/1.0"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + assert.Empty(t, result.Get("Connection")) + assert.Empty(t, result.Get("Host")) + assert.Empty(t, result.Get("Content-Length")) + assert.Empty(t, result.Get("Accept-Encoding")) + assert.Empty(t, result.Get("Transfer-Encoding")) + assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) + }) + + t.Run("empty auth header in SDK is not injected", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{} + clientHeaders := http.Header{ + "User-Agent": {"claude-code/1.0"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + assert.Empty(t, result.Get("Authorization")) + assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) + }) + + t.Run("does not mutate inputs", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + "Authorization": {"Bearer sk-key"}, + } + clientHeaders := http.Header{ + "Authorization": {"Bearer coder-token"}, + "Connection": {"keep-alive"}, + } + sdkCopy := sdkHeader.Clone() + clientCopy := clientHeaders.Clone() + + _ = BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + require.Equal(t, sdkCopy, sdkHeader) + require.Equal(t, clientCopy, clientHeaders) + }) +} diff --git a/intercept/messages/base.go b/intercept/messages/base.go index c61a3a7c..ff21c96c 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -18,6 +18,7 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" aibconfig "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -40,6 +41,10 @@ type interceptionBase struct { cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock + // clientHeaders are the original HTTP headers from the client request. + clientHeaders http.Header + authHeaderName string + tracer trace.Tracer logger slog.Logger @@ -183,10 +188,21 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. + // TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192 for key, value := range i.cfg.ExtraHeaders { opts = append(opts, option.WithHeader(key, value)) } + // Forward client headers to upstream. This middleware runs after the SDK + // has built the request, and replaces the outgoing headers with the sanitized + // client headers plus provider auth. + if i.clientHeaders != nil { + opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName) + return next(req) + })) + } + // Add API dump middleware if configured if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index e22b97f8..75c43479 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -28,14 +28,25 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor( + id uuid.UUID, + req *MessageNewParamsWrapper, + payload []byte, + cfg config.Anthropic, + bedrockCfg *config.AWSBedrock, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - payload: payload, - cfg: cfg, - bedrockCfg: bedrockCfg, - tracer: tracer, + id: id, + req: req, + payload: payload, + cfg: cfg, + bedrockCfg: bedrockCfg, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }} } @@ -74,6 +85,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req } opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 600)} + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsAnthropicOpts(actor)...) } diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 4e87fd85..78eb287e 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -34,14 +34,25 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor( + id uuid.UUID, + req *MessageNewParamsWrapper, + payload []byte, + cfg config.Anthropic, + bedrockCfg *config.AWSBedrock, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - payload: payload, - cfg: cfg, - bedrockCfg: bedrockCfg, - tracer: tracer, + id: id, + req: req, + payload: payload, + cfg: cfg, + bedrockCfg: bedrockCfg, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }} } @@ -110,6 +121,8 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re streamCtx, streamCancel := context.WithCancelCause(ctx) defer streamCancel(errors.New("deferred")) + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. var opts []option.RequestOption if actor := aibcontext.ActorFromContext(ctx); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsAnthropicOpts(actor)...) diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 8b7c3ded..f4272031 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -17,6 +17,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" @@ -42,11 +43,16 @@ type responsesInterceptionBase struct { reqPayload []byte cfg config.OpenAI model string - recorder recorder.Recorder - mcpProxy mcp.ServerProxier - logger slog.Logger - metrics metrics.Metrics - tracer trace.Tracer + + // clientHeaders are the original HTTP headers from the client request. + clientHeaders http.Header + authHeaderName string + + recorder recorder.Recorder + mcpProxy mcp.ServerProxier + logger slog.Logger + metrics metrics.Metrics + tracer trace.Tracer } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { @@ -54,10 +60,21 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. + // TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192 for key, value := range i.cfg.ExtraHeaders { opts = append(opts, option.WithHeader(key, value)) } + // Forward client headers to upstream. This middleware runs after the SDK + // has built the request, and replaces the outgoing headers with the sanitized + // client headers plus provider auth. + if i.clientHeaders != nil { + opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName) + return next(req) + })) + } + // Add API dump middleware if configured if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 3e94a6cc..687ce5b7 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -25,15 +25,26 @@ type BlockingResponsesInterceptor struct { responsesInterceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *BlockingResponsesInterceptor { +func NewBlockingInterceptor( + id uuid.UUID, + req *ResponsesNewParamsWrapper, + reqPayload []byte, + cfg config.OpenAI, + model string, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *BlockingResponsesInterceptor { return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + model: model, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }, } } @@ -80,6 +91,9 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * opts := i.requestOptions(&respCopy) opts = append(opts, option.WithRequestTimeout(time.Second*600)) + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) } diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 6925d86f..99c6b33a 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -32,15 +32,26 @@ type StreamingResponsesInterceptor struct { responsesInterceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *StreamingResponsesInterceptor { +func NewStreamingInterceptor( + id uuid.UUID, + req *ResponsesNewParamsWrapper, + reqPayload []byte, + cfg config.OpenAI, + model string, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *StreamingResponsesInterceptor { return &StreamingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + model: model, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }, } } @@ -98,6 +109,9 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r respCopy = responseCopier{} opts := i.requestOptions(&respCopy) + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) } diff --git a/provider/anthropic.go b/provider/anthropic.go index e682fdb7..4a79cd42 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -112,9 +112,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr var interceptor intercept.Interceptor if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) + interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) + interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go index 924c0f98..d34fd029 100644 --- a/provider/anthropic_test.go +++ b/provider/anthropic_test.go @@ -61,7 +61,7 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal request body") }) - t.Run("Messages_ForwardsAnthropicBetaHeaderToUpstream", func(t *testing.T) { + t.Run("Messages_ClientHeaders", func(t *testing.T) { t.Parallel() var receivedHeaders http.Header @@ -86,7 +86,9 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}], "stream": false}` req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) req.Header.Set("Anthropic-Beta", betaHeader) - req.Header.Set("X-Custom-Header", "should-not-forward") + // Simulate a client sending its own auth credential, which must be replaced + // by aibridge with the configured provider key. + req.Header.Set("Authorization", "Bearer fake-client-bearer") w := httptest.NewRecorder() interceptor, err := provider.CreateInterceptor(w, req, testTracer) @@ -101,10 +103,11 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { require.NoError(t, err) // Verify the full Anthropic-Beta header (all betas) was forwarded unchanged. - assert.Equal(t, betaHeader, receivedHeaders.Get("Anthropic-Beta")) + assert.Equal(t, betaHeader, receivedHeaders.Get("Anthropic-Beta"), "Anthropic-Beta header must be forwarded unchanged to upstream") - // Verify non-Anthropic headers are not forwarded. - assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Anthropic headers should not be forwarded") + // Verify aibridge's configured key was used and the client's auth credential was not forwarded. + assert.Equal(t, "test-key", receivedHeaders.Get("X-Api-Key"), "upstream must receive configured provider key") + assert.Empty(t, receivedHeaders.Get("Authorization"), "client Authorization header must not reach upstream") }) t.Run("UnknownRoute", func(t *testing.T) { diff --git a/provider/copilot.go b/provider/copilot.go index 9b128cab..4bdf6a29 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -148,9 +148,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) } case routeCopilotResponses: @@ -164,9 +164,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, tracer) + interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, tracer) + interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, r.Header, p.AuthHeader(), tracer) } default: diff --git a/provider/copilot_test.go b/provider/copilot_test.go index 697b6990..a5df7bd4 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -129,7 +129,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal chat completions request body") }) - t.Run("ChatCompletions_ForwardsHeadersToUpstream", func(t *testing.T) { + t.Run("ChatCompletions_ClientHeaders", func(t *testing.T) { t.Parallel() var receivedHeaders http.Header @@ -153,7 +153,6 @@ func TestCopilot_CreateInterceptor(t *testing.T) { req.Header.Set("Authorization", "Bearer test-token") req.Header.Set("Editor-Version", "vscode/1.85.0") req.Header.Set("Copilot-Integration-Id", "test-integration") - req.Header.Set("X-Custom-Header", "should-not-forward") w := httptest.NewRecorder() interceptor, err := provider.CreateInterceptor(w, req, testTracer) @@ -168,12 +167,12 @@ func TestCopilot_CreateInterceptor(t *testing.T) { err = interceptor.ProcessRequest(w, processReq) require.NoError(t, err) - // Verify headers were forwarded + // Verify Copilot-specific headers were forwarded. assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) assert.Equal(t, "test-integration", receivedHeaders.Get("Copilot-Integration-Id")) - - // Verify non-Copilot headers are not forwarded - assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Copilot headers should not be forwarded") + // Copilot uses per-user tokens: the client's Authorization must reach upstream as-is. + assert.Equal(t, "Bearer test-token", receivedHeaders.Get("Authorization"), "client Authorization must be used as provider key") + assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") }) t.Run("Responses_NonStreamingRequest_BlockingInterceptor", func(t *testing.T) { @@ -221,7 +220,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal responses request body") }) - t.Run("Responses_ForwardsHeadersToUpstream", func(t *testing.T) { + t.Run("Responses_ClientHeaders", func(t *testing.T) { t.Parallel() var receivedHeaders http.Header @@ -245,7 +244,6 @@ func TestCopilot_CreateInterceptor(t *testing.T) { req.Header.Set("Authorization", "Bearer test-token") req.Header.Set("Editor-Version", "vscode/1.85.0") req.Header.Set("Copilot-Integration-Id", "test-integration") - req.Header.Set("X-Custom-Header", "should-not-forward") w := httptest.NewRecorder() interceptor, err := provider.CreateInterceptor(w, req, testTracer) @@ -260,12 +258,12 @@ func TestCopilot_CreateInterceptor(t *testing.T) { err = interceptor.ProcessRequest(w, processReq) require.NoError(t, err) - // Verify headers were forwarded + // Verify Copilot-specific headers were forwarded. assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) assert.Equal(t, "test-integration", receivedHeaders.Get("Copilot-Integration-Id")) - - // Verify non-Copilot headers are not forwarded - assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Copilot headers should not be forwarded") + // Copilot uses per-user tokens: the client's Authorization must reach upstream as-is. + assert.Equal(t, "Bearer test-token", receivedHeaders.Get("Authorization"), "client Authorization must be used as provider key") + assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") }) t.Run("UnknownRoute", func(t *testing.T) { diff --git a/provider/openai.go b/provider/openai.go index 43d6811e..dd68f0d9 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -105,9 +105,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.cfg, tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.cfg, tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.cfg, r.Header, p.AuthHeader(), tracer) } case routeResponses: @@ -120,9 +120,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("unmarshal request body: %w", err) } if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model), tracer) + interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model), tracer) + interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) } default: diff --git a/provider/openai_test.go b/provider/openai_test.go index f2654b07..d30d0179 100644 --- a/provider/openai_test.go +++ b/provider/openai_test.go @@ -9,7 +9,11 @@ import ( "strings" "testing" + "cdr.dev/slog/v3" "github.com/coder/aibridge/config" + "github.com/coder/aibridge/internal/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace/noop" "golang.org/x/sync/errgroup" ) @@ -150,6 +154,92 @@ func generateResponsesPayload(payloadSize int, inputCount int, stream bool) []by return bodyBytes } +func TestOpenAI_CreateInterceptor(t *testing.T) { + t.Parallel() + + t.Run("ChatCompletions_ClientHeaders", func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`)) + })) + t.Cleanup(mockUpstream.Close) + + provider := NewOpenAI(config.OpenAI{ + BaseURL: mockUpstream.URL, + Key: "test-key", + }) + + body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}` + req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+routeChatCompletions, bytes.NewBufferString(body)) + // Simulate a client sending its own auth credential, which must be replaced + // by aibridge with the configured provider key. + req.Header.Set("Authorization", "Bearer fake-client-bearer") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + + processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+routeChatCompletions, nil) + err = interceptor.ProcessRequest(w, processReq) + require.NoError(t, err) + + // Verify aibridge's configured key was used and the client's auth credential was not forwarded. + assert.Equal(t, "Bearer test-key", receivedHeaders.Get("Authorization"), "upstream must receive configured provider key") + assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") + }) + + t.Run("Responses_ClientHeaders", func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"resp-123","object":"response","created_at":1677652288,"model":"gpt-5","output":[],"usage":{"input_tokens":5,"output_tokens":10,"total_tokens":15}}`)) + })) + t.Cleanup(mockUpstream.Close) + + provider := NewOpenAI(config.OpenAI{ + BaseURL: mockUpstream.URL, + Key: "test-key", + }) + + body := `{"model": "gpt-5", "input": "hello", "stream": false}` + req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+routeResponses, bytes.NewBufferString(body)) + // Simulate a client sending its own auth credential, which must be replaced + // by aibridge with the configured provider key. + req.Header.Set("Authorization", "Bearer fake-client-bearer") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + + processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+routeResponses, nil) + err = interceptor.ProcessRequest(w, processReq) + require.NoError(t, err) + + // Verify aibridge's configured key was used and the client's auth credential was not forwarded. + assert.Equal(t, "Bearer test-key", receivedHeaders.Get("Authorization"), "upstream must receive configured provider key") + assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") + }) +} + func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) { provider := NewOpenAI(config.OpenAI{ BaseURL: "https://api.openai.com/v1/", From f1b1fc0b0e4909999fac78a2304de3643286dce5 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Mon, 16 Mar 2026 09:49:58 +0000 Subject: [PATCH 2/3] chore: address comments --- intercept/client_headers.go | 24 +++++++++++++----------- intercept/client_headers_test.go | 14 +++++++------- provider/openai_test.go | 11 +++++++++-- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/intercept/client_headers.go b/intercept/client_headers.go index 6eaae541..3e4678f3 100644 --- a/intercept/client_headers.go +++ b/intercept/client_headers.go @@ -25,34 +25,36 @@ var nonForwardedHeaders = []string{ } // authHeaders are headers that carry authentication credentials from the -// client. These are stripped because the SDK re-injects the correct -// provider credentials (API key or per-user token). +// client. The upstream request is built by the SDK, which sets the correct +// provider credentials via option.WithAPIKey. Client auth headers are +// stripped here and the provider credentials are re-injected by +// BuildUpstreamHeaders from the SDK-built request. var authHeaders = []string{ "Authorization", "X-Api-Key", } -// SanitizeClientHeaders returns a copy of the client headers with hop-by-hop, +// PrepareClientHeaders returns a copy of the client headers with hop-by-hop, // transport, and auth headers removed. -func SanitizeClientHeaders(clientHeaders http.Header) http.Header { - sanitized := clientHeaders.Clone() +func PrepareClientHeaders(clientHeaders http.Header) http.Header { + prepared := clientHeaders.Clone() for _, h := range hopByHopHeaders { - sanitized.Del(h) + prepared.Del(h) } for _, h := range nonForwardedHeaders { - sanitized.Del(h) + prepared.Del(h) } for _, h := range authHeaders { - sanitized.Del(h) + prepared.Del(h) } - return sanitized + return prepared } // BuildUpstreamHeaders produces the header set for an upstream SDK request. -// It starts from the sanitized client headers, then preserves specific +// It starts from the prepared client headers, then preserves specific // headers from the SDK-built request that must not be overwritten. func BuildUpstreamHeaders(sdkHeader http.Header, clientHeaders http.Header, authHeaderName string) http.Header { - headers := SanitizeClientHeaders(clientHeaders) + headers := PrepareClientHeaders(clientHeaders) // Preserve the auth header set by the SDK from the provider configuration. if v := sdkHeader.Get(authHeaderName); v != "" { diff --git a/intercept/client_headers_test.go b/intercept/client_headers_test.go index c696e82a..ecd2f018 100644 --- a/intercept/client_headers_test.go +++ b/intercept/client_headers_test.go @@ -8,13 +8,13 @@ import ( "github.com/stretchr/testify/require" ) -func TestSanitizeClientHeaders(t *testing.T) { +func TestPrepareClientHeaders(t *testing.T) { t.Parallel() t.Run("nil input returns empty header", func(t *testing.T) { t.Parallel() - result := SanitizeClientHeaders(nil) + result := PrepareClientHeaders(nil) require.Empty(t, result) }) @@ -29,7 +29,7 @@ func TestSanitizeClientHeaders(t *testing.T) { "X-Custom": {"preserved"}, } - result := SanitizeClientHeaders(input) + result := PrepareClientHeaders(input) assert.Empty(t, result.Get("Connection")) assert.Empty(t, result.Get("Keep-Alive")) @@ -48,7 +48,7 @@ func TestSanitizeClientHeaders(t *testing.T) { "X-Custom": {"preserved"}, } - result := SanitizeClientHeaders(input) + result := PrepareClientHeaders(input) assert.Empty(t, result.Get("Host")) assert.Empty(t, result.Get("Accept-Encoding")) @@ -65,7 +65,7 @@ func TestSanitizeClientHeaders(t *testing.T) { "X-Custom": {"preserved"}, } - result := SanitizeClientHeaders(input) + result := PrepareClientHeaders(input) assert.Empty(t, result.Get("Authorization")) assert.Empty(t, result.Get("X-Api-Key")) @@ -79,7 +79,7 @@ func TestSanitizeClientHeaders(t *testing.T) { "X-Custom": {"value-1", "value-2"}, } - result := SanitizeClientHeaders(input) + result := PrepareClientHeaders(input) require.Equal(t, []string{"value-1", "value-2"}, result["X-Custom"]) }) @@ -93,7 +93,7 @@ func TestSanitizeClientHeaders(t *testing.T) { } originalCopy := input.Clone() - _ = SanitizeClientHeaders(input) + _ = PrepareClientHeaders(input) require.Equal(t, originalCopy, input) }) diff --git a/provider/openai_test.go b/provider/openai_test.go index d30d0179..bdea142b 100644 --- a/provider/openai_test.go +++ b/provider/openai_test.go @@ -18,6 +18,11 @@ import ( "golang.org/x/sync/errgroup" ) +const ( + chatCompletionResponse = `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}` + responsesAPIResponse = `{"id":"resp-123","object":"response","created_at":1677652288,"model":"gpt-5","output":[],"usage":{"input_tokens":5,"output_tokens":10,"total_tokens":15}}` +) + type message struct { Role string Content string @@ -166,7 +171,8 @@ func TestOpenAI_CreateInterceptor(t *testing.T) { receivedHeaders = r.Header.Clone() w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`)) + _, err := w.Write([]byte(chatCompletionResponse)) + require.NoError(t, err) })) t.Cleanup(mockUpstream.Close) @@ -207,7 +213,8 @@ func TestOpenAI_CreateInterceptor(t *testing.T) { receivedHeaders = r.Header.Clone() w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"resp-123","object":"response","created_at":1677652288,"model":"gpt-5","output":[],"usage":{"input_tokens":5,"output_tokens":10,"total_tokens":15}}`)) + _, err := w.Write([]byte(responsesAPIResponse)) + require.NoError(t, err) })) t.Cleanup(mockUpstream.Close) From c9298ccea915ecb014635a12795cf56f946a957f Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Mon, 16 Mar 2026 10:13:42 +0000 Subject: [PATCH 3/3] test: update openai tests to table driven format --- provider/openai_test.go | 123 +++++++++++++++++----------------------- 1 file changed, 51 insertions(+), 72 deletions(-) diff --git a/provider/openai_test.go b/provider/openai_test.go index bdea142b..4add332e 100644 --- a/provider/openai_test.go +++ b/provider/openai_test.go @@ -162,89 +162,68 @@ func generateResponsesPayload(payloadSize int, inputCount int, stream bool) []by func TestOpenAI_CreateInterceptor(t *testing.T) { t.Parallel() - t.Run("ChatCompletions_ClientHeaders", func(t *testing.T) { - t.Parallel() - - var receivedHeaders http.Header - - mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(chatCompletionResponse)) - require.NoError(t, err) - })) - t.Cleanup(mockUpstream.Close) - - provider := NewOpenAI(config.OpenAI{ - BaseURL: mockUpstream.URL, - Key: "test-key", - }) + tests := []struct { + name string + route string + requestBody string + responseBody string + }{ + { + name: "ChatCompletions_ClientHeaders", + route: routeChatCompletions, + requestBody: `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}`, + responseBody: chatCompletionResponse, + }, + { + name: "Responses_ClientHeaders", + route: routeResponses, + requestBody: `{"model": "gpt-5", "input": "hello", "stream": false}`, + responseBody: responsesAPIResponse, + }, + } - body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}` - req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+routeChatCompletions, bytes.NewBufferString(body)) - // Simulate a client sending its own auth credential, which must be replaced - // by aibridge with the configured provider key. - req.Header.Set("Authorization", "Bearer fake-client-bearer") - w := httptest.NewRecorder() + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - interceptor, err := provider.CreateInterceptor(w, req, testTracer) - require.NoError(t, err) - require.NotNil(t, interceptor) + var receivedHeaders http.Header - logger := slog.Make() - interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(tc.responseBody)) + require.NoError(t, err) + })) + t.Cleanup(mockUpstream.Close) - processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+routeChatCompletions, nil) - err = interceptor.ProcessRequest(w, processReq) - require.NoError(t, err) + provider := NewOpenAI(config.OpenAI{ + BaseURL: mockUpstream.URL, + Key: "test-key", + }) - // Verify aibridge's configured key was used and the client's auth credential was not forwarded. - assert.Equal(t, "Bearer test-key", receivedHeaders.Get("Authorization"), "upstream must receive configured provider key") - assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") - }) + req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+tc.route, bytes.NewBufferString(tc.requestBody)) + // Simulate a client sending its own auth credential, which must be replaced + // by aibridge with the configured provider key. + req.Header.Set("Authorization", "Bearer fake-client-bearer") + w := httptest.NewRecorder() - t.Run("Responses_ClientHeaders", func(t *testing.T) { - t.Parallel() + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) - var receivedHeaders http.Header + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) - mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte(responsesAPIResponse)) + processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+tc.route, nil) + err = interceptor.ProcessRequest(w, processReq) require.NoError(t, err) - })) - t.Cleanup(mockUpstream.Close) - provider := NewOpenAI(config.OpenAI{ - BaseURL: mockUpstream.URL, - Key: "test-key", + // Verify aibridge's configured key was used and the client's auth credential was not forwarded. + assert.Equal(t, "Bearer test-key", receivedHeaders.Get("Authorization"), "upstream must receive configured provider key") + assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") }) - - body := `{"model": "gpt-5", "input": "hello", "stream": false}` - req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+routeResponses, bytes.NewBufferString(body)) - // Simulate a client sending its own auth credential, which must be replaced - // by aibridge with the configured provider key. - req.Header.Set("Authorization", "Bearer fake-client-bearer") - w := httptest.NewRecorder() - - interceptor, err := provider.CreateInterceptor(w, req, testTracer) - require.NoError(t, err) - require.NotNil(t, interceptor) - - logger := slog.Make() - interceptor.Setup(logger, &testutil.MockRecorder{}, nil) - - processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+routeResponses, nil) - err = interceptor.ProcessRequest(w, processReq) - require.NoError(t, err) - - // Verify aibridge's configured key was used and the client's auth credential was not forwarded. - assert.Equal(t, "Bearer test-key", receivedHeaders.Get("Authorization"), "upstream must receive configured provider key") - assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") - }) + } } func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) {