Skip to content

Commit 5ebc58f

Browse files
committed
refactor(executor): remove legacy connCreateSent logic and standardize response.create usage for all websocket events
- Simplified connection logic by removing `connCreateSent` and related state handling. - Updated `buildCodexWebsocketRequestBody` to always use `response.create`. - Added unit tests to validate `response.create` behavior and beta header preservation. - Dropped unsupported `response.append` and outdated `response.done` event types.
1 parent 2b609dd commit 5ebc58f

File tree

4 files changed

+130
-109
lines changed

4 files changed

+130
-109
lines changed

internal/runtime/executor/codex_websockets_executor.go

Lines changed: 15 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import (
3131
)
3232

3333
const (
34-
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04"
34+
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
3535
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
3636
codexResponsesWebsocketHandshakeTO = 30 * time.Second
3737
)
@@ -57,11 +57,6 @@ type codexWebsocketSession struct {
5757
wsURL string
5858
authID string
5959

60-
// connCreateSent tracks whether a `response.create` message has been successfully sent
61-
// on the current websocket connection. The upstream expects the first message on each
62-
// connection to be `response.create`.
63-
connCreateSent bool
64-
6560
writeMu sync.Mutex
6661

6762
activeMu sync.Mutex
@@ -212,13 +207,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
212207
defer sess.reqMu.Unlock()
213208
}
214209

215-
allowAppend := true
216-
if sess != nil {
217-
sess.connMu.Lock()
218-
allowAppend = sess.connCreateSent
219-
sess.connMu.Unlock()
220-
}
221-
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
210+
wsReqBody := buildCodexWebsocketRequestBody(body)
222211
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
223212
URL: wsURL,
224213
Method: "WEBSOCKET",
@@ -280,10 +269,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
280269
// execution session.
281270
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
282271
if errDialRetry == nil && connRetry != nil {
283-
sess.connMu.Lock()
284-
allowAppend = sess.connCreateSent
285-
sess.connMu.Unlock()
286-
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
272+
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
287273
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
288274
URL: wsURL,
289275
Method: "WEBSOCKET",
@@ -312,7 +298,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
312298
return resp, errSend
313299
}
314300
}
315-
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
316301

317302
for {
318303
if ctx != nil && ctx.Err() != nil {
@@ -403,26 +388,20 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
403388
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
404389

405390
var authID, authLabel, authType, authValue string
406-
if auth != nil {
407-
authID = auth.ID
408-
authLabel = auth.Label
409-
authType, authValue = auth.AccountInfo()
410-
}
391+
authID = auth.ID
392+
authLabel = auth.Label
393+
authType, authValue = auth.AccountInfo()
411394

412395
executionSessionID := executionSessionIDFromOptions(opts)
413396
var sess *codexWebsocketSession
414397
if executionSessionID != "" {
415398
sess = e.getOrCreateSession(executionSessionID)
416-
sess.reqMu.Lock()
399+
if sess != nil {
400+
sess.reqMu.Lock()
401+
}
417402
}
418403

419-
allowAppend := true
420-
if sess != nil {
421-
sess.connMu.Lock()
422-
allowAppend = sess.connCreateSent
423-
sess.connMu.Unlock()
424-
}
425-
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
404+
wsReqBody := buildCodexWebsocketRequestBody(body)
426405
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
427406
URL: wsURL,
428407
Method: "WEBSOCKET",
@@ -483,10 +462,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
483462
sess.reqMu.Unlock()
484463
return nil, errDialRetry
485464
}
486-
sess.connMu.Lock()
487-
allowAppend = sess.connCreateSent
488-
sess.connMu.Unlock()
489-
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
465+
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
490466
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
491467
URL: wsURL,
492468
Method: "WEBSOCKET",
@@ -515,7 +491,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
515491
return nil, errSend
516492
}
517493
}
518-
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
519494

520495
out := make(chan cliproxyexecutor.StreamChunk)
521496
go func() {
@@ -657,31 +632,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con
657632
return conn.WriteMessage(websocket.TextMessage, payload)
658633
}
659634

660-
func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte {
635+
func buildCodexWebsocketRequestBody(body []byte) []byte {
661636
if len(body) == 0 {
662637
return nil
663638
}
664639

665-
// Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns.
666-
// The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation).
667-
// Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive.
668-
//
669-
// NOTE: The upstream expects the first websocket event on each connection to be `response.create`,
670-
// so we only use `response.append` after we have initialized the current connection.
671-
if allowAppend {
672-
if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" {
673-
inputNode := gjson.GetBytes(body, "input")
674-
wsReqBody := []byte(`{}`)
675-
wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append")
676-
if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" {
677-
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw))
678-
return wsReqBody
679-
}
680-
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]"))
681-
return wsReqBody
682-
}
683-
}
684-
640+
// Match codex-rs websocket v2 semantics: every request is `response.create`.
641+
// Incremental follow-up turns continue on the same websocket using
642+
// `previous_response_id` + incremental `input`, not `response.append`.
685643
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
686644
if errSet == nil && len(wsReqBody) > 0 {
687645
return wsReqBody
@@ -725,21 +683,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession,
725683
}
726684
}
727685

728-
func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) {
729-
if sess == nil || conn == nil || len(payload) == 0 {
730-
return
731-
}
732-
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
733-
return
734-
}
735-
736-
sess.connMu.Lock()
737-
if sess.conn == conn {
738-
sess.connCreateSent = true
739-
}
740-
sess.connMu.Unlock()
741-
}
742-
743686
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
744687
dialer := &websocket.Dialer{
745688
Proxy: http.ProxyFromEnvironment,
@@ -1017,36 +960,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
1017960
}
1018961
}
1019962

1020-
func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
1021-
done := make(chan struct{})
1022-
if ctx == nil || conn == nil {
1023-
return done
1024-
}
1025-
go func() {
1026-
select {
1027-
case <-done:
1028-
case <-ctx.Done():
1029-
_ = conn.Close()
1030-
}
1031-
}()
1032-
return done
1033-
}
1034-
1035-
func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
1036-
done := make(chan struct{})
1037-
if ctx == nil || conn == nil {
1038-
return done
1039-
}
1040-
go func() {
1041-
select {
1042-
case <-done:
1043-
case <-ctx.Done():
1044-
_ = conn.SetReadDeadline(time.Now())
1045-
}
1046-
}()
1047-
return done
1048-
}
1049-
1050963
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
1051964
if len(opts.Metadata) == 0 {
1052965
return ""
@@ -1120,7 +1033,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
11201033
sess.conn = conn
11211034
sess.wsURL = wsURL
11221035
sess.authID = authID
1123-
sess.connCreateSent = false
11241036
sess.readerConn = conn
11251037
sess.connMu.Unlock()
11261038

@@ -1206,7 +1118,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
12061118
return
12071119
}
12081120
sess.conn = nil
1209-
sess.connCreateSent = false
12101121
if sess.readerConn == conn {
12111122
sess.readerConn = nil
12121123
}
@@ -1273,7 +1184,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess
12731184
authID := sess.authID
12741185
wsURL := sess.wsURL
12751186
sess.conn = nil
1276-
sess.connCreateSent = false
12771187
if sess.readerConn == conn {
12781188
sess.readerConn = nil
12791189
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package executor
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"testing"
7+
8+
"github.com/tidwall/gjson"
9+
)
10+
11+
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
12+
body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`)
13+
14+
wsReqBody := buildCodexWebsocketRequestBody(body)
15+
16+
if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" {
17+
t.Fatalf("type = %s, want response.create", got)
18+
}
19+
if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" {
20+
t.Fatalf("previous_response_id = %s, want resp-1", got)
21+
}
22+
if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" {
23+
t.Fatalf("input item id mismatch")
24+
}
25+
if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" {
26+
t.Fatalf("unexpected websocket request type: %s", got)
27+
}
28+
}
29+
30+
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
31+
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "")
32+
33+
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
34+
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
35+
}
36+
}

sdk/api/handlers/openai/openai_responses_websocket.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ const (
2626
wsRequestTypeAppend = "response.append"
2727
wsEventTypeError = "error"
2828
wsEventTypeCompleted = "response.completed"
29-
wsEventTypeDone = "response.done"
3029
wsDoneMarker = "[DONE]"
3130
wsTurnStateHeader = "x-codex-turn-state"
3231
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
@@ -469,9 +468,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
469468
for i := range payloads {
470469
eventType := gjson.GetBytes(payloads[i], "type").String()
471470
if eventType == wsEventTypeCompleted {
472-
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
473-
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
474-
475471
completed = true
476472
completedOutput = responseCompletedOutputFromPayload(payloads[i])
477473
}

sdk/api/handlers/openai/openai_responses_websocket_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ package openai
22

33
import (
44
"bytes"
5+
"errors"
56
"net/http"
67
"net/http/httptest"
78
"strings"
89
"testing"
910

1011
"github.com/gin-gonic/gin"
12+
"github.com/gorilla/websocket"
13+
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
1114
"github.com/tidwall/gjson"
1215
)
1316

@@ -247,3 +250,79 @@ func TestSetWebsocketRequestBody(t *testing.T) {
247250
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
248251
}
249252
}
253+
254+
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
255+
gin.SetMode(gin.TestMode)
256+
257+
serverErrCh := make(chan error, 1)
258+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
259+
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
260+
if err != nil {
261+
serverErrCh <- err
262+
return
263+
}
264+
defer func() {
265+
errClose := conn.Close()
266+
if errClose != nil {
267+
serverErrCh <- errClose
268+
}
269+
}()
270+
271+
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
272+
ctx.Request = r
273+
274+
data := make(chan []byte, 1)
275+
errCh := make(chan *interfaces.ErrorMessage)
276+
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
277+
close(data)
278+
close(errCh)
279+
280+
var bodyLog strings.Builder
281+
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
282+
ctx,
283+
conn,
284+
func(...interface{}) {},
285+
data,
286+
errCh,
287+
&bodyLog,
288+
"session-1",
289+
)
290+
if err != nil {
291+
serverErrCh <- err
292+
return
293+
}
294+
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
295+
serverErrCh <- errors.New("completed output not captured")
296+
return
297+
}
298+
serverErrCh <- nil
299+
}))
300+
defer server.Close()
301+
302+
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
303+
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
304+
if err != nil {
305+
t.Fatalf("dial websocket: %v", err)
306+
}
307+
defer func() {
308+
errClose := conn.Close()
309+
if errClose != nil {
310+
t.Fatalf("close websocket: %v", errClose)
311+
}
312+
}()
313+
314+
_, payload, errReadMessage := conn.ReadMessage()
315+
if errReadMessage != nil {
316+
t.Fatalf("read websocket message: %v", errReadMessage)
317+
}
318+
if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted {
319+
t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted)
320+
}
321+
if strings.Contains(string(payload), "response.done") {
322+
t.Fatalf("payload unexpectedly rewrote completed event: %s", payload)
323+
}
324+
325+
if errServer := <-serverErrCh; errServer != nil {
326+
t.Fatalf("server error: %v", errServer)
327+
}
328+
}

0 commit comments

Comments
 (0)