Skip to content

Commit 2789396

Browse files
committed
fix: ensure connection-scoped headers are filtered in upstream requests
- Added `connectionScopedHeaders` utility to respect "Connection" header directives. - Updated `FilterUpstreamHeaders` to remove connection-scoped headers dynamically. - Refactored and tested upstream header filtering with additional validations. - Adjusted upstream header handling during retries to replace headers safely.
1 parent 61da7bd commit 2789396

File tree

7 files changed

+136
-19
lines changed

7 files changed

+136
-19
lines changed

internal/runtime/executor/codex_websockets_executor.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
363363
}
364364
}
365365

366-
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
366+
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
367367
log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model)
368368
if ctx == nil {
369369
ctx = context.Background()
@@ -436,7 +436,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
436436
})
437437

438438
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
439+
var upstreamHeaders http.Header
439440
if respHS != nil {
441+
upstreamHeaders = respHS.Header.Clone()
440442
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
441443
}
442444
if errDial != nil {
@@ -516,7 +518,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
516518
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
517519

518520
out := make(chan cliproxyexecutor.StreamChunk)
519-
stream = out
520521
go func() {
521522
terminateReason := "completed"
522523
var terminateErr error
@@ -627,7 +628,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
627628
}
628629
}()
629630

630-
return stream, nil
631+
return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil
631632
}
632633

633634
func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
@@ -1343,7 +1344,7 @@ func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
13431344
return e.httpExec.Execute(ctx, auth, req, opts)
13441345
}
13451346

1346-
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
1347+
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
13471348
if e == nil || e.httpExec == nil || e.wsExec == nil {
13481349
return nil, fmt.Errorf("codex auto executor: executor is nil")
13491350
}

sdk/api/handlers/handlers.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,11 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
593593
return nil, nil, errChan
594594
}
595595
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
596-
upstreamHeaders := FilterUpstreamHeaders(streamResult.Headers)
596+
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
597+
upstreamHeaders := cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
598+
if upstreamHeaders == nil {
599+
upstreamHeaders = make(http.Header)
600+
}
597601
chunks := streamResult.Chunks
598602
dataChan := make(chan []byte)
599603
errChan := make(chan *interfaces.ErrorMessage, 1)
@@ -670,6 +674,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
670674
bootstrapRetries++
671675
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
672676
if retryErr == nil {
677+
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
673678
chunks = retryResult.Chunks
674679
continue outer
675680
}
@@ -761,6 +766,26 @@ func cloneBytes(src []byte) []byte {
761766
return dst
762767
}
763768

769+
func cloneHeader(src http.Header) http.Header {
770+
if src == nil {
771+
return nil
772+
}
773+
dst := make(http.Header, len(src))
774+
for key, values := range src {
775+
dst[key] = append([]string(nil), values...)
776+
}
777+
return dst
778+
}
779+
780+
func replaceHeader(dst http.Header, src http.Header) {
781+
for key := range dst {
782+
delete(dst, key)
783+
}
784+
for key, values := range src {
785+
dst[key] = append([]string(nil), values...)
786+
}
787+
}
788+
764789
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
765790
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
766791
status := http.StatusInternalServerError

sdk/api/handlers/handlers_stream_bootstrap_test.go

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth,
4040
},
4141
}
4242
close(ch)
43-
return &coreexecutor.StreamResult{Chunks: ch}, nil
43+
return &coreexecutor.StreamResult{
44+
Headers: http.Header{"X-Upstream-Attempt": {"1"}},
45+
Chunks: ch,
46+
}, nil
4447
}
4548

4649
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
4750
close(ch)
48-
return &coreexecutor.StreamResult{Chunks: ch}, nil
51+
return &coreexecutor.StreamResult{
52+
Headers: http.Header{"X-Upstream-Attempt": {"2"}},
53+
Chunks: ch,
54+
}, nil
4955
}
5056

5157
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
@@ -134,7 +140,7 @@ func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coree
134140
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
135141
}
136142

137-
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
143+
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) {
138144
_ = ctx
139145
_ = req
140146
_ = opts
@@ -160,12 +166,12 @@ func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *corea
160166
},
161167
}
162168
close(ch)
163-
return ch, nil
169+
return &coreexecutor.StreamResult{Chunks: ch}, nil
164170
}
165171

166172
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
167173
close(ch)
168-
return ch, nil
174+
return &coreexecutor.StreamResult{Chunks: ch}, nil
169175
}
170176

171177
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
@@ -235,7 +241,7 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
235241
BootstrapRetries: 1,
236242
},
237243
}, manager)
238-
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
244+
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
239245
if dataChan == nil || errChan == nil {
240246
t.Fatalf("expected non-nil channels")
241247
}
@@ -257,6 +263,10 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
257263
if executor.Calls() != 2 {
258264
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
259265
}
266+
upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt")
267+
if upstreamAttemptHeader != "2" {
268+
t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader)
269+
}
260270
}
261271

262272
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
@@ -367,7 +377,7 @@ func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T)
367377
},
368378
}, manager)
369379
ctx := WithPinnedAuthID(context.Background(), "auth1")
370-
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
380+
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
371381
if dataChan == nil || errChan == nil {
372382
t.Fatalf("expected non-nil channels")
373383
}
@@ -431,7 +441,7 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
431441
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
432442
selectedAuthID = authID
433443
})
434-
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
444+
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
435445
if dataChan == nil || errChan == nil {
436446
t.Fatalf("expected non-nil channels")
437447
}

sdk/api/handlers/header_filter.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package handlers
22

3-
import "net/http"
3+
import (
4+
"net/http"
5+
"strings"
6+
)
47

58
// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT
69
// be forwarded by proxies, plus security-sensitive headers that should not leak.
@@ -27,9 +30,14 @@ func FilterUpstreamHeaders(src http.Header) http.Header {
2730
if src == nil {
2831
return nil
2932
}
33+
connectionScoped := connectionScopedHeaders(src)
3034
dst := make(http.Header)
3135
for key, values := range src {
32-
if _, blocked := hopByHopHeaders[http.CanonicalHeaderKey(key)]; blocked {
36+
canonicalKey := http.CanonicalHeaderKey(key)
37+
if _, blocked := hopByHopHeaders[canonicalKey]; blocked {
38+
continue
39+
}
40+
if _, scoped := connectionScoped[canonicalKey]; scoped {
3341
continue
3442
}
3543
dst[key] = values
@@ -40,6 +48,20 @@ func FilterUpstreamHeaders(src http.Header) http.Header {
4048
return dst
4149
}
4250

51+
func connectionScopedHeaders(src http.Header) map[string]struct{} {
52+
scoped := make(map[string]struct{})
53+
for _, rawValue := range src.Values("Connection") {
54+
for _, token := range strings.Split(rawValue, ",") {
55+
headerName := strings.TrimSpace(token)
56+
if headerName == "" {
57+
continue
58+
}
59+
scoped[http.CanonicalHeaderKey(headerName)] = struct{}{}
60+
}
61+
}
62+
return scoped
63+
}
64+
4365
// WriteUpstreamHeaders writes filtered upstream headers to the gin response writer.
4466
// Headers already set by CPA (e.g., Content-Type) are NOT overwritten.
4567
func WriteUpstreamHeaders(dst http.Header, src http.Header) {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package handlers
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
)
7+
8+
func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) {
9+
src := http.Header{}
10+
src.Add("Connection", "keep-alive, x-hop-a, x-hop-b")
11+
src.Add("Connection", "x-hop-c")
12+
src.Set("Keep-Alive", "timeout=5")
13+
src.Set("X-Hop-A", "a")
14+
src.Set("X-Hop-B", "b")
15+
src.Set("X-Hop-C", "c")
16+
src.Set("X-Request-Id", "req-1")
17+
src.Set("Set-Cookie", "session=secret")
18+
19+
filtered := FilterUpstreamHeaders(src)
20+
if filtered == nil {
21+
t.Fatalf("expected filtered headers, got nil")
22+
}
23+
24+
requestID := filtered.Get("X-Request-Id")
25+
if requestID != "req-1" {
26+
t.Fatalf("expected X-Request-Id to be preserved, got %q", requestID)
27+
}
28+
29+
blockedHeaderKeys := []string{
30+
"Connection",
31+
"Keep-Alive",
32+
"X-Hop-A",
33+
"X-Hop-B",
34+
"X-Hop-C",
35+
"Set-Cookie",
36+
}
37+
for _, key := range blockedHeaderKeys {
38+
value := filtered.Get(key)
39+
if value != "" {
40+
t.Fatalf("expected %s to be removed, got %q", key, value)
41+
}
42+
}
43+
}
44+
45+
func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) {
46+
src := http.Header{}
47+
src.Add("Connection", "x-hop-a")
48+
src.Set("X-Hop-A", "a")
49+
src.Set("Set-Cookie", "session=secret")
50+
51+
filtered := FilterUpstreamHeaders(src)
52+
if filtered != nil {
53+
t.Fatalf("expected nil when all headers are filtered, got %#v", filtered)
54+
}
55+
}

sdk/api/handlers/openai/openai_responses_websocket.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
153153
pinnedAuthID = strings.TrimSpace(authID)
154154
})
155155
}
156-
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
156+
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
157157

158158
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
159159
if errForward != nil {

sdk/cliproxy/auth/conductor_executor_replace_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.
2424
return cliproxyexecutor.Response{}, nil
2525
}
2626

27-
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
27+
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
2828
ch := make(chan cliproxyexecutor.StreamChunk)
2929
close(ch)
30-
return ch, nil
30+
return &cliproxyexecutor.StreamResult{Chunks: ch}, nil
3131
}
3232

3333
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
@@ -89,7 +89,11 @@ func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
8989
if !okResolved {
9090
t.Fatal("expected registered executor to be found")
9191
}
92-
if resolved != current {
92+
resolvedExecutor, okResolvedExecutor := resolved.(*replaceAwareExecutor)
93+
if !okResolvedExecutor {
94+
t.Fatalf("expected resolved executor type %T, got %T", current, resolved)
95+
}
96+
if resolvedExecutor != current {
9397
t.Fatal("expected resolved executor to match registered executor")
9498
}
9599

0 commit comments

Comments
 (0)