diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 7a755e06..086b1c78 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -9,6 +9,7 @@ import ( "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -29,6 +30,10 @@ type interceptionBase struct { req *ChatCompletionNewParamsWrapper cfg config.OpenAI + // clientHeaders are the original HTTP headers from the client request. + clientHeaders http.Header + authHeaderName string + logger slog.Logger tracer trace.Tracer @@ -41,10 +46,21 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. + // TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192 for key, value := range i.cfg.ExtraHeaders { opts = append(opts, option.WithHeader(key, value)) } + // Forward client headers to upstream. This middleware runs after the SDK + // has built the request, and replaces the outgoing headers with the sanitized + // client headers plus provider auth. 
+ if i.clientHeaders != nil { + opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName) + return next(req) + })) + } + // Add API dump middleware if configured if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index 9a84d143..2816ed7a 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -28,12 +28,21 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor( + id uuid.UUID, + req *ChatCompletionNewParamsWrapper, + cfg config.OpenAI, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - cfg: cfg, - tracer: tracer, + id: id, + req: req, + cfg: cfg, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }} } @@ -78,6 +87,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req var opts []option.RequestOption opts = append(opts, option.WithRequestTimeout(time.Second*600)) + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) 
} diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index ff3b78c6..231601c0 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -33,12 +33,21 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor( + id uuid.UUID, + req *ChatCompletionNewParamsWrapper, + cfg config.OpenAI, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - cfg: cfg, - tracer: tracer, + id: id, + req: req, + cfg: cfg, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }} } @@ -121,6 +130,9 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) var opts []option.RequestOption + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) 
} diff --git a/intercept/chatcompletions/streaming_test.go b/intercept/chatcompletions/streaming_test.go index 7d8d4d57..233831e6 100644 --- a/intercept/chatcompletions/streaming_test.go +++ b/intercept/chatcompletions/streaming_test.go @@ -81,16 +81,16 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { Stream: true, } + // Create test request + w := httptest.NewRecorder() + httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) + tracer := otel.Tracer("test") - interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, tracer) + interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, httpReq.Header, "Authorization", tracer) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) interceptor.Setup(logger, &testutil.MockRecorder{}, nil) - // Create test request - w := httptest.NewRecorder() - httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) - // Process the request err := interceptor.ProcessRequest(w, httpReq) diff --git a/intercept/client_headers.go b/intercept/client_headers.go new file mode 100644 index 00000000..3e4678f3 --- /dev/null +++ b/intercept/client_headers.go @@ -0,0 +1,72 @@ +package intercept + +import "net/http" + +// hopByHopHeaders are connection-level headers specific to the connection +// between client and AI Bridge, not meant for the upstream. +// See https://www.rfc-editor.org/rfc/rfc2616#section-13.5.1 +var hopByHopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +// nonForwardedHeaders are transport-level headers managed by aibridge or +// Go's HTTP transport that must not be forwarded to the upstream provider. +var nonForwardedHeaders = []string{ + "Host", + "Accept-Encoding", + "Content-Length", +} + +// authHeaders are headers that carry authentication credentials from the +// client. 
The upstream request is built by the SDK, which sets the correct +// provider credentials via option.WithAPIKey. Client auth headers are +// stripped here and the provider credentials are re-injected by +// BuildUpstreamHeaders from the SDK-built request. +var authHeaders = []string{ + "Authorization", + "X-Api-Key", +} + +// PrepareClientHeaders returns a copy of the client headers with hop-by-hop, +// transport, and auth headers removed. +func PrepareClientHeaders(clientHeaders http.Header) http.Header { + prepared := clientHeaders.Clone() + for _, h := range hopByHopHeaders { + prepared.Del(h) + } + for _, h := range nonForwardedHeaders { + prepared.Del(h) + } + for _, h := range authHeaders { + prepared.Del(h) + } + return prepared +} + +// BuildUpstreamHeaders produces the header set for an upstream SDK request. +// It starts from the prepared client headers, then preserves specific +// headers from the SDK-built request that must not be overwritten. +func BuildUpstreamHeaders(sdkHeader http.Header, clientHeaders http.Header, authHeaderName string) http.Header { + headers := PrepareClientHeaders(clientHeaders) + + // Preserve the auth header set by the SDK from the provider configuration. + if v := sdkHeader.Get(authHeaderName); v != "" { + headers.Set(authHeaderName, v) + } + + // Preserve actor headers injected by aibridge as per-request SDK options. 
+ for name, values := range sdkHeader { + if IsActorHeader(name) { + headers[name] = values + } + } + + return headers +} diff --git a/intercept/client_headers_test.go b/intercept/client_headers_test.go new file mode 100644 index 00000000..ecd2f018 --- /dev/null +++ b/intercept/client_headers_test.go @@ -0,0 +1,219 @@ +package intercept + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPrepareClientHeaders(t *testing.T) { + t.Parallel() + + t.Run("nil input returns empty header", func(t *testing.T) { + t.Parallel() + + result := PrepareClientHeaders(nil) + require.Empty(t, result) + }) + + t.Run("hop-by-hop headers are removed", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "Connection": {"keep-alive"}, + "Keep-Alive": {"timeout=5"}, + "Transfer-Encoding": {"chunked"}, + "Upgrade": {"websocket"}, + "X-Custom": {"preserved"}, + } + + result := PrepareClientHeaders(input) + + assert.Empty(t, result.Get("Connection")) + assert.Empty(t, result.Get("Keep-Alive")) + assert.Empty(t, result.Get("Transfer-Encoding")) + assert.Empty(t, result.Get("Upgrade")) + assert.Equal(t, "preserved", result.Get("X-Custom")) + }) + + t.Run("non-forwarded headers are removed", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "Host": {"example.com"}, + "Accept-Encoding": {"gzip"}, + "Content-Length": {"42"}, + "X-Custom": {"preserved"}, + } + + result := PrepareClientHeaders(input) + + assert.Empty(t, result.Get("Host")) + assert.Empty(t, result.Get("Accept-Encoding")) + assert.Empty(t, result.Get("Content-Length")) + assert.Equal(t, "preserved", result.Get("X-Custom")) + }) + + t.Run("auth headers are removed", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "Authorization": {"Bearer coder-session-token"}, + "X-Api-Key": {"sk-client-key"}, + "X-Custom": {"preserved"}, + } + + result := PrepareClientHeaders(input) + + assert.Empty(t, 
result.Get("Authorization")) + assert.Empty(t, result.Get("X-Api-Key")) + assert.Equal(t, "preserved", result.Get("X-Custom")) + }) + + t.Run("multi-value headers are preserved", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "X-Custom": {"value-1", "value-2"}, + } + + result := PrepareClientHeaders(input) + + require.Equal(t, []string{"value-1", "value-2"}, result["X-Custom"]) + }) + + t.Run("input is not mutated", func(t *testing.T) { + t.Parallel() + + input := http.Header{ + "Connection": {"keep-alive"}, + "X-Custom": {"preserved"}, + } + originalCopy := input.Clone() + + _ = PrepareClientHeaders(input) + + require.Equal(t, originalCopy, input) + }) +} + +func TestBuildUpstreamHeaders(t *testing.T) { + t.Parallel() + + t.Run("preserves auth from SDK", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + "Authorization": {"Bearer sk-provider-key"}, + } + clientHeaders := http.Header{ + "Authorization": {"Bearer coder-session-token"}, + "User-Agent": {"claude-code/1.0"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + assert.Equal(t, "Bearer sk-provider-key", result.Get("Authorization")) + assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) + }) + + t.Run("preserves X-Api-Key from SDK and strips client Authorization", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + "X-Api-Key": {"sk-ant-provider-key"}, + } + clientHeaders := http.Header{ + "X-Api-Key": {"sk-ant-client-key"}, + "Authorization": {"Bearer coder-session-token"}, + "Anthropic-Beta": {"prompt-caching-2024-07-31"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key") + + assert.Equal(t, "sk-ant-provider-key", result.Get("X-Api-Key")) + assert.Empty(t, result.Get("Authorization")) + assert.Equal(t, "prompt-caching-2024-07-31", result.Get("Anthropic-Beta")) + }) + + t.Run("preserves actor headers from SDK", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + 
"Authorization": {"Bearer sk-key"}, + "X-Ai-Bridge-Actor-Id": {"user-123"}, + "X-Ai-Bridge-Actor-Metadata-Name": {"alice"}, + } + clientHeaders := http.Header{ + "Authorization": {"Bearer coder-token"}, + "User-Agent": {"claude-code/1.0"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + assert.Equal(t, "Bearer sk-key", result.Get("Authorization")) + assert.Equal(t, "user-123", result.Get("X-Ai-Bridge-Actor-Id")) + assert.Equal(t, "alice", result.Get("X-Ai-Bridge-Actor-Metadata-Name")) + assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) + }) + + t.Run("strips hop-by-hop and transport headers", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + "Authorization": {"Bearer sk-key"}, + } + clientHeaders := http.Header{ + "Connection": {"keep-alive"}, + "Host": {"bridge.example.com"}, + "Content-Length": {"99"}, + "Accept-Encoding": {"gzip"}, + "Transfer-Encoding": {"chunked"}, + "User-Agent": {"claude-code/1.0"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + assert.Empty(t, result.Get("Connection")) + assert.Empty(t, result.Get("Host")) + assert.Empty(t, result.Get("Content-Length")) + assert.Empty(t, result.Get("Accept-Encoding")) + assert.Empty(t, result.Get("Transfer-Encoding")) + assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) + }) + + t.Run("empty auth header in SDK is not injected", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{} + clientHeaders := http.Header{ + "User-Agent": {"claude-code/1.0"}, + } + + result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + assert.Empty(t, result.Get("Authorization")) + assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) + }) + + t.Run("does not mutate inputs", func(t *testing.T) { + t.Parallel() + + sdkHeader := http.Header{ + "Authorization": {"Bearer sk-key"}, + } + clientHeaders := http.Header{ + "Authorization": {"Bearer coder-token"}, + "Connection": 
{"keep-alive"}, + } + sdkCopy := sdkHeader.Clone() + clientCopy := clientHeaders.Clone() + + _ = BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + + require.Equal(t, sdkCopy, sdkHeader) + require.Equal(t, clientCopy, clientHeaders) + }) +} diff --git a/intercept/messages/base.go b/intercept/messages/base.go index c61a3a7c..ff21c96c 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -18,6 +18,7 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" aibconfig "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -40,6 +41,10 @@ type interceptionBase struct { cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock + // clientHeaders are the original HTTP headers from the client request. + clientHeaders http.Header + authHeaderName string + tracer trace.Tracer logger slog.Logger @@ -183,10 +188,21 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. + // TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192 for key, value := range i.cfg.ExtraHeaders { opts = append(opts, option.WithHeader(key, value)) } + // Forward client headers to upstream. This middleware runs after the SDK + // has built the request, and replaces the outgoing headers with the sanitized + // client headers plus provider auth. 
+ if i.clientHeaders != nil { + opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName) + return next(req) + })) + } + // Add API dump middleware if configured if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index e22b97f8..75c43479 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -28,14 +28,25 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor( + id uuid.UUID, + req *MessageNewParamsWrapper, + payload []byte, + cfg config.Anthropic, + bedrockCfg *config.AWSBedrock, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - payload: payload, - cfg: cfg, - bedrockCfg: bedrockCfg, - tracer: tracer, + id: id, + req: req, + payload: payload, + cfg: cfg, + bedrockCfg: bedrockCfg, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }} } @@ -74,6 +85,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req } opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 600)} + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. 
if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsAnthropicOpts(actor)...) } diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 4e87fd85..78eb287e 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -34,14 +34,25 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor( + id uuid.UUID, + req *MessageNewParamsWrapper, + payload []byte, + cfg config.Anthropic, + bedrockCfg *config.AWSBedrock, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - payload: payload, - cfg: cfg, - bedrockCfg: bedrockCfg, - tracer: tracer, + id: id, + req: req, + payload: payload, + cfg: cfg, + bedrockCfg: bedrockCfg, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }} } @@ -110,6 +121,8 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re streamCtx, streamCancel := context.WithCancelCause(ctx) defer streamCancel(errors.New("deferred")) + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. var opts []option.RequestOption if actor := aibcontext.ActorFromContext(ctx); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsAnthropicOpts(actor)...) 
diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 8b7c3ded..f4272031 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -17,6 +17,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" @@ -42,11 +43,16 @@ type responsesInterceptionBase struct { reqPayload []byte cfg config.OpenAI model string - recorder recorder.Recorder - mcpProxy mcp.ServerProxier - logger slog.Logger - metrics metrics.Metrics - tracer trace.Tracer + + // clientHeaders are the original HTTP headers from the client request. + clientHeaders http.Header + authHeaderName string + + recorder recorder.Recorder + mcpProxy mcp.ServerProxier + logger slog.Logger + metrics metrics.Metrics + tracer trace.Tracer } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { @@ -54,10 +60,21 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. + // TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192 for key, value := range i.cfg.ExtraHeaders { opts = append(opts, option.WithHeader(key, value)) } + // Forward client headers to upstream. This middleware runs after the SDK + // has built the request, and replaces the outgoing headers with the sanitized + // client headers plus provider auth. 
+ if i.clientHeaders != nil { + opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName) + return next(req) + })) + } + // Add API dump middleware if configured if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 3e94a6cc..687ce5b7 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -25,15 +25,26 @@ type BlockingResponsesInterceptor struct { responsesInterceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *BlockingResponsesInterceptor { +func NewBlockingInterceptor( + id uuid.UUID, + req *ResponsesNewParamsWrapper, + reqPayload []byte, + cfg config.OpenAI, + model string, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *BlockingResponsesInterceptor { return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + model: model, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }, } } @@ -80,6 +91,9 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * opts := i.requestOptions(&respCopy) opts = append(opts, option.WithRequestTimeout(time.Second*600)) + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. 
if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) } diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 6925d86f..99c6b33a 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -32,15 +32,26 @@ type StreamingResponsesInterceptor struct { responsesInterceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *StreamingResponsesInterceptor { +func NewStreamingInterceptor( + id uuid.UUID, + req *ResponsesNewParamsWrapper, + reqPayload []byte, + cfg config.OpenAI, + model string, + clientHeaders http.Header, + authHeaderName string, + tracer trace.Tracer, +) *StreamingResponsesInterceptor { return &StreamingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + model: model, + clientHeaders: clientHeaders, + authHeaderName: authHeaderName, + tracer: tracer, }, } } @@ -98,6 +109,9 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r respCopy = responseCopier{} opts := i.requestOptions(&respCopy) + + // TODO(ssncferreira): inject actor headers directly in the client-header + // middleware instead of using SDK options. if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders { opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...) 
} diff --git a/provider/anthropic.go b/provider/anthropic.go index e682fdb7..4a79cd42 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -112,9 +112,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr var interceptor intercept.Interceptor if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) + interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) + interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go index 924c0f98..d34fd029 100644 --- a/provider/anthropic_test.go +++ b/provider/anthropic_test.go @@ -61,7 +61,7 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal request body") }) - t.Run("Messages_ForwardsAnthropicBetaHeaderToUpstream", func(t *testing.T) { + t.Run("Messages_ClientHeaders", func(t *testing.T) { t.Parallel() var receivedHeaders http.Header @@ -86,7 +86,9 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}], "stream": false}` req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) req.Header.Set("Anthropic-Beta", betaHeader) - req.Header.Set("X-Custom-Header", "should-not-forward") + // Simulate a client sending its own auth credential, which must be replaced + // by aibridge with the configured provider key. 
+ req.Header.Set("Authorization", "Bearer fake-client-bearer") w := httptest.NewRecorder() interceptor, err := provider.CreateInterceptor(w, req, testTracer) @@ -101,10 +103,11 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { require.NoError(t, err) // Verify the full Anthropic-Beta header (all betas) was forwarded unchanged. - assert.Equal(t, betaHeader, receivedHeaders.Get("Anthropic-Beta")) + assert.Equal(t, betaHeader, receivedHeaders.Get("Anthropic-Beta"), "Anthropic-Beta header must be forwarded unchanged to upstream") - // Verify non-Anthropic headers are not forwarded. - assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Anthropic headers should not be forwarded") + // Verify aibridge's configured key was used and the client's auth credential was not forwarded. + assert.Equal(t, "test-key", receivedHeaders.Get("X-Api-Key"), "upstream must receive configured provider key") + assert.Empty(t, receivedHeaders.Get("Authorization"), "client Authorization header must not reach upstream") }) t.Run("UnknownRoute", func(t *testing.T) { diff --git a/provider/copilot.go b/provider/copilot.go index 9b128cab..4bdf6a29 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -148,9 +148,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) } case routeCopilotResponses: @@ -164,9 +164,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, tracer) + interceptor = responses.NewStreamingInterceptor(id, 
&req, payload, cfg, req.Model, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, tracer) + interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, r.Header, p.AuthHeader(), tracer) } default: diff --git a/provider/copilot_test.go b/provider/copilot_test.go index 697b6990..a5df7bd4 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -129,7 +129,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal chat completions request body") }) - t.Run("ChatCompletions_ForwardsHeadersToUpstream", func(t *testing.T) { + t.Run("ChatCompletions_ClientHeaders", func(t *testing.T) { t.Parallel() var receivedHeaders http.Header @@ -153,7 +153,6 @@ func TestCopilot_CreateInterceptor(t *testing.T) { req.Header.Set("Authorization", "Bearer test-token") req.Header.Set("Editor-Version", "vscode/1.85.0") req.Header.Set("Copilot-Integration-Id", "test-integration") - req.Header.Set("X-Custom-Header", "should-not-forward") w := httptest.NewRecorder() interceptor, err := provider.CreateInterceptor(w, req, testTracer) @@ -168,12 +167,12 @@ func TestCopilot_CreateInterceptor(t *testing.T) { err = interceptor.ProcessRequest(w, processReq) require.NoError(t, err) - // Verify headers were forwarded + // Verify Copilot-specific headers were forwarded. assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) assert.Equal(t, "test-integration", receivedHeaders.Get("Copilot-Integration-Id")) - - // Verify non-Copilot headers are not forwarded - assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Copilot headers should not be forwarded") + // Copilot uses per-user tokens: the client's Authorization must reach upstream as-is. 
+ assert.Equal(t, "Bearer test-token", receivedHeaders.Get("Authorization"), "client Authorization must be used as provider key") + assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") }) t.Run("Responses_NonStreamingRequest_BlockingInterceptor", func(t *testing.T) { @@ -221,7 +220,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal responses request body") }) - t.Run("Responses_ForwardsHeadersToUpstream", func(t *testing.T) { + t.Run("Responses_ClientHeaders", func(t *testing.T) { t.Parallel() var receivedHeaders http.Header @@ -245,7 +244,6 @@ func TestCopilot_CreateInterceptor(t *testing.T) { req.Header.Set("Authorization", "Bearer test-token") req.Header.Set("Editor-Version", "vscode/1.85.0") req.Header.Set("Copilot-Integration-Id", "test-integration") - req.Header.Set("X-Custom-Header", "should-not-forward") w := httptest.NewRecorder() interceptor, err := provider.CreateInterceptor(w, req, testTracer) @@ -260,12 +258,12 @@ func TestCopilot_CreateInterceptor(t *testing.T) { err = interceptor.ProcessRequest(w, processReq) require.NoError(t, err) - // Verify headers were forwarded + // Verify Copilot-specific headers were forwarded. assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) assert.Equal(t, "test-integration", receivedHeaders.Get("Copilot-Integration-Id")) - - // Verify non-Copilot headers are not forwarded - assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Copilot headers should not be forwarded") + // Copilot uses per-user tokens: the client's Authorization must reach upstream as-is. 
+ assert.Equal(t, "Bearer test-token", receivedHeaders.Get("Authorization"), "client Authorization must be used as provider key") + assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") }) t.Run("UnknownRoute", func(t *testing.T) { diff --git a/provider/openai.go b/provider/openai.go index 43d6811e..dd68f0d9 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -105,9 +105,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.cfg, tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.cfg, tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.cfg, r.Header, p.AuthHeader(), tracer) } case routeResponses: @@ -120,9 +120,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("unmarshal request body: %w", err) } if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model), tracer) + interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model), tracer) + interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, p.AuthHeader(), tracer) } default: diff --git a/provider/openai_test.go b/provider/openai_test.go index f2654b07..4add332e 100644 --- a/provider/openai_test.go +++ b/provider/openai_test.go @@ -9,11 +9,20 @@ import ( "strings" "testing" + "cdr.dev/slog/v3" "github.com/coder/aibridge/config" + "github.com/coder/aibridge/internal/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" 
"go.opentelemetry.io/otel/trace/noop" "golang.org/x/sync/errgroup" ) +const ( + chatCompletionResponse = `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}` + responsesAPIResponse = `{"id":"resp-123","object":"response","created_at":1677652288,"model":"gpt-5","output":[],"usage":{"input_tokens":5,"output_tokens":10,"total_tokens":15}}` +) + type message struct { Role string Content string @@ -150,6 +159,73 @@ func generateResponsesPayload(payloadSize int, inputCount int, stream bool) []by return bodyBytes } +func TestOpenAI_CreateInterceptor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + route string + requestBody string + responseBody string + }{ + { + name: "ChatCompletions_ClientHeaders", + route: routeChatCompletions, + requestBody: `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}`, + responseBody: chatCompletionResponse, + }, + { + name: "Responses_ClientHeaders", + route: routeResponses, + requestBody: `{"model": "gpt-5", "input": "hello", "stream": false}`, + responseBody: responsesAPIResponse, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(tc.responseBody)) + assert.NoError(t, err) // assert, not require: FailNow is only valid on the test goroutine, not the server's handler goroutine. + })) + t.Cleanup(mockUpstream.Close) + + provider := NewOpenAI(config.OpenAI{ + BaseURL: mockUpstream.URL, + Key: "test-key", + }) + + req := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+tc.route, bytes.NewBufferString(tc.requestBody)) + // Simulate a client sending its own 
auth credential, which must be replaced + // by aibridge with the configured provider key. + req.Header.Set("Authorization", "Bearer fake-client-bearer") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + + processReq := httptest.NewRequest(http.MethodPost, provider.RoutePrefix()+tc.route, nil) + err = interceptor.ProcessRequest(w, processReq) + require.NoError(t, err) + + // Verify aibridge's configured key was used and the client's auth credential was not forwarded. + assert.Equal(t, "Bearer test-key", receivedHeaders.Get("Authorization"), "upstream must receive configured provider key") + assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") + }) + } +} + func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) { provider := NewOpenAI(config.OpenAI{ BaseURL: "https://api.openai.com/v1/",