diff --git a/intercept/responses/base.go b/intercept/responses/base.go index a0d6068b..dcd72a0d 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -52,6 +52,12 @@ type responsesInterceptionBase struct { func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { opts := []option.RequestOption{option.WithBaseURL(i.cfg.BaseURL), option.WithAPIKey(i.cfg.Key)} + // Add extra headers if configured. + // Some providers require additional headers that are not added by the SDK. + for key, value := range i.cfg.ExtraHeaders { + opts = append(opts, option.WithHeader(key, value)) + } + // Add API dump middleware if configured if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) diff --git a/provider/copilot_test.go b/provider/copilot_test.go index 119c114e..697b6990 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -221,6 +221,53 @@ func TestCopilot_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal responses request body") }) + t.Run("Responses_ForwardsHeadersToUpstream", func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + // Mock upstream that captures headers + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"resp-123","object":"responses.response","created":1677652288,"model":"gpt-5-mini","output":[],"usage":{"input_tokens":5,"output_tokens":10,"total_tokens":15}}`)) + })) + t.Cleanup(mockUpstream.Close) + + // Create provider with mock upstream URL + provider := NewCopilot(config.Copilot{ + BaseURL: mockUpstream.URL, + }) + + body := `{"model": "gpt-5-mini", "input": "hello", "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeCopilotResponses, bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Editor-Version", "vscode/1.85.0") + req.Header.Set("Copilot-Integration-Id", "test-integration") + req.Header.Set("X-Custom-Header", "should-not-forward") + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + // Setup and process request + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + + processReq := httptest.NewRequest(http.MethodPost, routeCopilotResponses, nil) + err = interceptor.ProcessRequest(w, processReq) + require.NoError(t, err) + + // Verify headers were forwarded + assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) + assert.Equal(t, "test-integration", receivedHeaders.Get("Copilot-Integration-Id")) + + // Verify non-Copilot headers are not forwarded + assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Copilot headers should not be forwarded") + }) + t.Run("UnknownRoute", func(t *testing.T) { t.Parallel()