diff --git a/bun.lock b/bun.lock index cf46ed14b7..ffb78df0d2 100644 --- a/bun.lock +++ b/bun.lock @@ -46,6 +46,7 @@ "@radix-ui/react-toggle-group": "^1.1.11", "@radix-ui/react-tooltip": "^1.2.8", "@radix-ui/react-visually-hidden": "^1.2.4", + "@vercel/ai-sdk-openai-websocket-fetch": "^1.0.0", "@xterm/addon-serialize": "^0.14.0", "@xterm/headless": "^6.0.0", "ai": "^6.0.72", @@ -1587,6 +1588,8 @@ "@unrs/resolver-binding-win32-x64-msvc": ["@unrs/resolver-binding-win32-x64-msvc@1.11.1", "", { "os": "win32", "cpu": "x64" }, "sha512-lrW200hZdbfRtztbygyaq/6jP6AKE8qQN2KvPcJ+x7wiD038YtnYtZ82IMNJ69GJibV7bwL3y9FgK+5w/pYt6g=="], + "@vercel/ai-sdk-openai-websocket-fetch": ["@vercel/ai-sdk-openai-websocket-fetch@1.0.0", "", { "dependencies": { "ws": "^8" }, "peerDependencies": { "@ai-sdk/openai": ">=3" }, "optionalPeers": ["@ai-sdk/openai"] }, "sha512-QTBex1bJogUTdh8lcmgxQhYkRZDVDMEdRyJjqy/Chnzv0i8UpLB2imFZvXORNUNdvu51aNPwgsclvx51F0Pc+A=="], + "@vercel/oidc": ["@vercel/oidc@3.1.0", "", {}, "sha512-Fw28YZpRnA3cAHHDlkt7xQHiJ0fcL+NRcIqsocZQUSmbzeIKRpwttJjik5ZGanXP+vlA4SbTg+AbA3bP363l+w=="], "@vitejs/plugin-react": ["@vitejs/plugin-react@4.7.0", "", { "dependencies": { "@babel/core": "^7.28.0", "@babel/plugin-transform-react-jsx-self": "^7.27.1", "@babel/plugin-transform-react-jsx-source": "^7.27.1", "@rolldown/pluginutils": "1.0.0-beta.27", "@types/babel__core": "^7.20.5", "react-refresh": "^0.17.0" }, "peerDependencies": { "vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" } }, "sha512-gUu9hwfWvvEDBBmgtAowQCojwZmJ5mcLn3aufeCsitijs3+f2NsrPtlAWIR6OPiqljl96GVCUbLe0HyqIpVaoA=="], diff --git a/flake.nix b/flake.nix index 717b38e3b0..d54949a7d0 100644 --- a/flake.nix +++ b/flake.nix @@ -83,7 +83,7 @@ outputHashMode = "recursive"; # Marker used by scripts/update_flake_hash.sh to update this hash in place. - outputHash = "sha256-nSkVmS55SWfLbUIscBGMzgR2su6vIlE9GcSRDLrn4eI="; # mux-offline-cache-hash + outputHash = "sha256-jHp/RsmtwHKsbrD0b86+nb+XQ75pT5y1tEeltT6hDVQ="; # mux-offline-cache-hash }; configurePhase = '' diff --git a/package.json b/package.json index e5fe833cba..e76a7a864f 100644 --- a/package.json +++ b/package.json @@ -88,6 +88,7 @@ "@radix-ui/react-toggle-group": "^1.1.11", "@radix-ui/react-tooltip": "^1.2.8", "@radix-ui/react-visually-hidden": "^1.2.4", + "@vercel/ai-sdk-openai-websocket-fetch": "^1.0.0", "@xterm/addon-serialize": "^0.14.0", "@xterm/headless": "^6.0.0", "ai": "^6.0.72", diff --git a/src/browser/features/Settings/Sections/ProvidersSection.test.tsx b/src/browser/features/Settings/Sections/ProvidersSection.test.tsx index 8e31956d30..53255a7a0c 100644 --- a/src/browser/features/Settings/Sections/ProvidersSection.test.tsx +++ b/src/browser/features/Settings/Sections/ProvidersSection.test.tsx @@ -150,12 +150,27 @@ function patchProviderMethods(client: APIClient, providersConfig: ProvidersConfi delete providersConfig[input.provider]; return Promise.resolve({ success: true as const, data: undefined }); }); + const setProviderConfig = mock((input) => { + const provider = providersConfig[input.provider]; + if (provider) { + const key = input.keyPath[0] as keyof ProviderConfigInfo | undefined; + if (key) { + if (input.value === "") { + delete provider[key]; + } else { + Object.assign(provider, { [key]: input.value }); + } + } + } + return Promise.resolve({ success: true as const, data: undefined }); + }); const onConfigChanged = mock(() => Promise.resolve(emptyConfigChangeIterator())); Object.assign(client.providers, { getConfig, addCustomOpenAICompatibleProvider, removeCustomProvider, + setProviderConfig, onConfigChanged, }); @@ -163,6 +178,7 @@ function patchProviderMethods(client: APIClient, providersConfig: ProvidersConfi addCustomOpenAICompatibleProvider, getConfig, removeCustomProvider, + setProviderConfig, }; } @@ -318,6 +334,108 @@ describe("ProvidersSection", () => { ).toBeTruthy(); }); + test("shows and persists the OpenAI WebSocket transport toggle", async () => { + const view = renderProvidersSection(); + const openAiButton = await view.findByRole("button", { name: /^OpenAI$/ }); + + fireEvent.click(openAiButton); + + const openAiCard = getProviderCard(openAiButton); + const webSocketToggle = within(openAiCard).getByRole("switch", { + name: /WebSocket transport/i, + }); + expect(webSocketToggle).toBeTruthy(); + + fireEvent.click(webSocketToggle); + + await waitFor(() => { + expect(view.setProviderConfig).toHaveBeenCalledWith({ + provider: "openai", + keyPath: ["webSocketTransportEnabled"], + value: true, + }); + }); + }); + + test("clears the OpenAI WebSocket transport preference when toggled off", async () => { + const view = renderProvidersSection(); + view.providersConfig.openai.webSocketTransportEnabled = true; + const openAiButton = await view.findByRole("button", { name: /^OpenAI$/ }); + + fireEvent.click(openAiButton); + + const openAiCard = getProviderCard(openAiButton); + const webSocketToggle = within(openAiCard).getByRole("switch", { + name: /WebSocket transport/i, + }); + + fireEvent.click(webSocketToggle); + + await waitFor(() => { + expect(view.setProviderConfig).toHaveBeenCalledWith({ + provider: "openai", + keyPath: ["webSocketTransportEnabled"], + value: "", + }); + }); + }); + + test("hides the OpenAI WebSocket transport toggle when Codex OAuth is the active default", async () => { + const view = renderProvidersSection(); + view.providersConfig.openai.codexOauthSet = true; + view.providersConfig.openai.apiKeySet = false; + view.providersConfig.openai.webSocketTransportEnabled = true; + const openAiButton = await view.findByRole("button", { name: /^OpenAI$/ }); + + fireEvent.click(openAiButton); + + const openAiCard = getProviderCard(openAiButton); + expect( + within(openAiCard).queryByRole("switch", { + name: /WebSocket transport/i, + }) + ).toBeNull(); + expect(view.providersConfig.openai.webSocketTransportEnabled).toBe(true); + }); + + test("hides the OpenAI WebSocket transport toggle when OpenAI uses a custom base URL", async () => { + const view = renderProvidersSection(); + view.providersConfig.openai.baseUrl = "https://proxy.openai.test/v1"; + view.providersConfig.openai.webSocketTransportEnabled = true; + const openAiButton = await view.findByRole("button", { name: /^OpenAI$/ }); + + fireEvent.click(openAiButton); + + const openAiCard = getProviderCard(openAiButton); + expect( + within(openAiCard).queryByRole("switch", { + name: /WebSocket transport/i, + }) + ).toBeNull(); + expect(view.providersConfig.openai.webSocketTransportEnabled).toBe(true); + }); + + test("hides the OpenAI WebSocket transport toggle for Chat Completions without clearing it", async () => { + const view = renderProvidersSection(); + view.providersConfig.openai.wireFormat = "chatCompletions"; + view.providersConfig.openai.webSocketTransportEnabled = true; + const openAiButton = await view.findByRole("button", { name: /^OpenAI$/ }); + + fireEvent.click(openAiButton); + + const openAiCard = getProviderCard(openAiButton); + expect( + within(openAiCard).queryByRole("switch", { + name: /WebSocket transport/i, + }) + ).toBeNull(); + expect(within(openAiCard).queryByText("WebSocket transport")).toBeNull(); + expect(view.providersConfig.openai.webSocketTransportEnabled).toBe(true); + expect(view.setProviderConfig).not.toHaveBeenCalledWith( + expect.objectContaining({ keyPath: ["webSocketTransportEnabled"], value: "" }) + ); + }); + test("shows remove only for expanded custom provider cards", async () => { const view = renderProvidersSection(); const customButton = await view.findByRole("button", { name: /Acme OpenAI/ }); diff --git a/src/browser/features/Settings/Sections/ProvidersSection.tsx b/src/browser/features/Settings/Sections/ProvidersSection.tsx index b7c140a04c..3d2c871801 100644 --- a/src/browser/features/Settings/Sections/ProvidersSection.tsx +++ b/src/browser/features/Settings/Sections/ProvidersSection.tsx @@ -2245,370 +2245,434 @@ export function ProvidersSection() { )} {/* OpenAI: ChatGPT OAuth + service tier */} - {provider === "openai" && ( -
-
- - - {codexOauthStatus === "starting" - ? "Starting..." - : codexOauthStatus === "waiting" - ? "Waiting for login..." - : codexOauthIsConnected - ? "Connected" - : "Not connected"} - -
- -
- {!isRemoteServer && ( - - )} - + {provider === "openai" && + (() => { + const openAIWireFormat = providerInfo?.wireFormat ?? "responses"; + const openAIBaseUrl = + typeof providerInfo?.baseUrl === "string" + ? providerInfo.baseUrl.trim() + : ""; + const openAIResolvedBaseUrl = + typeof providerInfo?.baseUrlResolved === "string" + ? providerInfo.baseUrlResolved.trim() + : ""; + const openAICodexOAuthIsDefault = + providerInfo?.codexOauthSet === true && + (providerInfo.apiKeySet !== true || + providerInfo.codexOauthDefaultAuth !== "apiKey"); + const openAIWebSocketTransportVisible = + openAIWireFormat === "responses" && + openAIBaseUrl.length === 0 && + openAIResolvedBaseUrl.length === 0 && + !openAICodexOAuthIsDefault; + return ( +
+
+ + + {codexOauthStatus === "starting" + ? "Starting..." + : codexOauthStatus === "waiting" + ? "Waiting for login..." + : codexOauthIsConnected + ? "Connected" + : "Not connected"} + +
- {codexOauthStatus === "waiting" && - !codexOauthDeviceFlow && - codexOauthAuthorizeUrl && ( +
+ {!isRemoteServer && ( + + )} - )} - {codexOauthLoginInProgress && ( - - )} + {codexOauthStatus === "waiting" && + !codexOauthDeviceFlow && + codexOauthAuthorizeUrl && ( + + )} - {codexOauthIsConnected && ( - - )} -
+ {codexOauthLoginInProgress && ( + + )} - {codexOauthDeviceFlow && ( -
-

- Enter this code on the OpenAI verification page: -

-
- - {codexOauthDeviceFlow.userCode} - - + {codexOauthIsConnected && ( + + )}
-

- - Waiting for authorization... -

-
- )} - - {codexOauthStatus === "waiting" && !codexOauthDeviceFlow && ( -

- - Waiting for authorization... -

- )} - {codexOauthStatus === "error" && codexOauthError && ( -

{codexOauthError}

- )} + {codexOauthDeviceFlow && ( +
+

+ Enter this code on the OpenAI verification page: +

+
+ + {codexOauthDeviceFlow.userCode} + + +
+

+ + Waiting for authorization... +

+
+ )} -
-
- -

- Applies to models that support both ChatGPT OAuth and API keys (e.g.{" "} - gpt-5.5). -

-
+ {codexOauthStatus === "waiting" && !codexOauthDeviceFlow && ( +

+ + Waiting for authorization... +

+ )} - { - if (!api) return; - if (next !== "oauth" && next !== "apiKey") { - return; - } + {codexOauthStatus === "error" && codexOauthError && ( +

{codexOauthError}

+ )} - updateOptimistically("openai", { codexOauthDefaultAuth: next }); - void api.providers.setProviderConfig({ - provider: "openai", - keyPath: ["codexOauthDefaultAuth"], - value: next, - }); - }} - size="sm" - className="h-9" - disabled={!api || !codexOauthDefaultAuthIsEditable} - > - - Use ChatGPT OAuth by default - - - Use OpenAI API key by default - -
+
+
+ +

+ Applies to models that support both ChatGPT OAuth and API keys + (e.g. gpt-5.5). +

+
-

- ChatGPT OAuth uses subscription billing (costs included). API key uses - OpenAI platform billing. -

+ { + if (!api) return; + if (next !== "oauth" && next !== "apiKey") { + return; + } + + updateOptimistically("openai", { codexOauthDefaultAuth: next }); + void api.providers.setProviderConfig({ + provider: "openai", + keyPath: ["codexOauthDefaultAuth"], + value: next, + }); + }} + size="sm" + className="h-9" + disabled={!api || !codexOauthDefaultAuthIsEditable} + > + + Use ChatGPT OAuth by default + + + Use OpenAI API key by default + + - {!codexOauthDefaultAuthIsEditable && ( -

- Connect ChatGPT OAuth and set an OpenAI API key to change this - setting. -

- )} -
+

+ ChatGPT OAuth uses subscription billing (costs included). API key + uses OpenAI platform billing. +

-
-
- - - - - - ? - - - -
-
OpenAI service tier
-
- auto: standard - behavior. -
-
- priority: lower - latency, higher cost. -
-
- flex: lower cost, - higher latency. -
-
-
-
-
-
- { + if (!api) return; + + if (next === OPENAI_SERVICE_TIER_UNSET) { + setOpenaiServiceTierSelectOverride(OPENAI_SERVICE_TIER_UNSET); + void api.providers + .setProviderConfig({ + provider: "openai", + keyPath: ["serviceTier"], + value: "", + }) + .then(() => refresh()) + .finally(() => setOpenaiServiceTierSelectOverride(null)); + return; + } + + if (!isOpenAIServiceTier(next)) { + return; + } + + setOpenaiServiceTierSelectOverride(null); + updateOptimistically("openai", { serviceTier: next }); + void api.providers.setProviderConfig({ provider: "openai", keyPath: ["serviceTier"], - value: "", - }) - .then(() => refresh()) - .finally(() => setOpenaiServiceTierSelectOverride(null)); - return; - } - - if (!isOpenAIServiceTier(next)) { - return; - } - - setOpenaiServiceTierSelectOverride(null); - updateOptimistically("openai", { serviceTier: next }); - void api.providers.setProviderConfig({ - provider: "openai", - keyPath: ["serviceTier"], - value: next, - }); - }} - > - - - - - - Not configured (omit service_tier) - - auto - default - flex - priority - - -
+ value: next, + }); + }} + > + + + + + + Not configured (omit service_tier) + + auto + default + flex + priority + + +
-
-
- - - - - - ? - - - -
-
OpenAI wire format
-
- responses: modern API - with persistence and built-in tools (default). -
-
- chat completions: - legacy /chat/completions endpoint. Use if your provider - doesn't support the Responses API (e.g. Azure Gov). -
+
+
+ + + + + + ? + + + +
+
OpenAI wire format
+
+ responses: modern + API with persistence and built-in tools (default). +
+
+ chat completions: + legacy /chat/completions endpoint. Use if your provider + doesn't support the Responses API (e.g. Azure Gov). +
+
+
+
+
+
+ +
+ {openAIWebSocketTransportVisible && ( +
+
+
+ + + Experimental: uses OpenAI's Responses WebSocket + transport for streaming Responses API requests. Unsupported + endpoints may fail. +
- - - -
- -
-
-
- - - - - - ? - - - -
-
OpenAI response storage
-
- enabled: OpenAI - stores responses for retrieval and context (default). -
-
- disabled: responses - are not stored. Required for zero data retention (ZDR) - endpoints. -
-
-
-
-
+
+
+ + + + + + ? + + + +
+
+ OpenAI response storage +
+
+ enabled: OpenAI + stores responses for retrieval and context (default). +
+
+ disabled: + responses are not stored. Required for zero data + retention (ZDR) endpoints. +
+
+
+
+
+
+ +
- -
-
- )} + ); + })()} {isCustomOpenAICompatible && (
diff --git a/src/common/config/schemas/providersConfig.test.ts b/src/common/config/schemas/providersConfig.test.ts index d360127e95..a36f277212 100644 --- a/src/common/config/schemas/providersConfig.test.ts +++ b/src/common/config/schemas/providersConfig.test.ts @@ -62,6 +62,27 @@ describe("ProvidersConfigSchema", () => { expect(ProvidersConfigSchema.safeParse(invalid).success).toBe(false); }); + it("accepts OpenAI WebSocket transport opt-in as an optional boolean", () => { + const valid = { + openai: { apiKey: "sk-openai-123", webSocketTransportEnabled: true }, + }; + + const parsed = ProvidersConfigSchema.safeParse(valid); + + expect(parsed.success).toBe(true); + if (parsed.success) { + expect(parsed.data.openai?.webSocketTransportEnabled).toBe(true); + } + }); + + it("rejects non-boolean OpenAI WebSocket transport values", () => { + const invalid = { + openai: { webSocketTransportEnabled: "true" }, + }; + + expect(ProvidersConfigSchema.safeParse(invalid).success).toBe(false); + }); + describe("modelParameters", () => { it("accepts valid per-model and wildcard overrides", () => { const valid = { diff --git a/src/common/config/schemas/providersConfig.ts b/src/common/config/schemas/providersConfig.ts index 8748293e27..957e4331a4 100644 --- a/src/common/config/schemas/providersConfig.ts +++ b/src/common/config/schemas/providersConfig.ts @@ -34,6 +34,7 @@ export const OpenAIProviderConfigSchema = BaseProviderConfigSchema.extend({ codexOauth: z.record(z.string(), z.unknown()).optional(), defaultModel: z.string().optional(), apiVersion: z.string().optional(), + webSocketTransportEnabled: z.boolean().optional(), }); export const BedrockProviderConfigSchema = BaseProviderConfigSchema.extend({ diff --git a/src/common/orpc/schemas/api.test.ts b/src/common/orpc/schemas/api.test.ts index 0c9e9871d4..7e5c7a13b9 100644 --- a/src/common/orpc/schemas/api.test.ts +++ b/src/common/orpc/schemas/api.test.ts @@ -116,6 +116,7 @@ describe("ProviderConfigInfoSchema conformance", () => { models: ["claude-3-opus", "claude-3-sonnet"], serviceTier: "flex", store: false, + webSocketTransportEnabled: true, cacheTtl: "1h", disableBetaFeatures: true, codexOauthSet: true, @@ -150,6 +151,7 @@ describe("ProviderConfigInfoSchema conformance", () => { expect(parsed.models).toEqual(full.models); expect(parsed.serviceTier).toBe(full.serviceTier); expect(parsed.store).toBe(full.store); + expect(parsed.webSocketTransportEnabled).toBe(full.webSocketTransportEnabled); expect(parsed.cacheTtl).toBe(full.cacheTtl); expect(parsed.disableBetaFeatures).toBe(full.disableBetaFeatures); expect(parsed.codexOauthSet).toBe(full.codexOauthSet); diff --git a/src/common/orpc/schemas/api.ts b/src/common/orpc/schemas/api.ts index cf6669549c..630020a930 100644 --- a/src/common/orpc/schemas/api.ts +++ b/src/common/orpc/schemas/api.ts @@ -205,6 +205,7 @@ export const ProviderConfigInfoSchema = z.object({ serviceTier: ServiceTierSchema.optional(), wireFormat: z.enum(["responses", "chatCompletions"]).optional(), store: z.boolean().optional(), + webSocketTransportEnabled: z.boolean().optional(), /** Anthropic-specific fields */ cacheTtl: CacheTtlSchema.optional(), disableBetaFeatures: z.boolean().optional(), diff --git a/src/node/services/languageModelCleanup.test.ts b/src/node/services/languageModelCleanup.test.ts new file mode 100644 index 0000000000..45efab5d9b --- /dev/null +++ b/src/node/services/languageModelCleanup.test.ts @@ -0,0 +1,91 @@ +import { describe, expect, test } from "bun:test"; +import type { LanguageModel } from "ai"; + +import { + attachLanguageModelCleanup, + hasLanguageModelCleanup, + moveLanguageModelCleanup, + runLanguageModelCleanup, +} from "./languageModelCleanup"; + +function createModel(): LanguageModel { + return { + specificationVersion: "v3", + provider: "test", + modelId: "test-model", + supportedUrls: {}, + doGenerate: () => Promise.reject(new Error("doGenerate is unused in cleanup tests")), + doStream: () => Promise.reject(new Error("doStream is unused in cleanup tests")), + }; +} + +describe("language model cleanup", () => { + test("runs attached cleanup exactly once", () => { + const model = createModel(); + let cleanupCalls = 0; + + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + }); + + runLanguageModelCleanup(model); + runLanguageModelCleanup(model); + + expect(cleanupCalls).toBe(1); + }); + + test("reports whether cleanup is attached", () => { + const model = createModel(); + + expect(hasLanguageModelCleanup(model)).toBe(false); + attachLanguageModelCleanup(model, () => undefined); + expect(hasLanguageModelCleanup(model)).toBe(true); + runLanguageModelCleanup(model); + expect(hasLanguageModelCleanup(model)).toBe(false); + }); + + test("moves cleanup to a wrapper model", () => { + const inner = createModel(); + const outer = createModel(); + let cleanupCalls = 0; + + attachLanguageModelCleanup(inner, () => { + cleanupCalls += 1; + }); + + moveLanguageModelCleanup(inner, outer); + + expect(hasLanguageModelCleanup(inner)).toBe(false); + expect(hasLanguageModelCleanup(outer)).toBe(true); + runLanguageModelCleanup(inner); + runLanguageModelCleanup(outer); + expect(cleanupCalls).toBe(1); + }); + + test("rejects double attach before cleanup is moved or run", () => { + const model = createModel(); + attachLanguageModelCleanup(model, () => undefined); + + expect(() => attachLanguageModelCleanup(model, () => undefined)).toThrow( + "language model already has cleanup attached" + ); + }); + + test("models without cleanup are safe", () => { + expect(() => runLanguageModelCleanup(createModel())).not.toThrow(); + }); + + test("cleanup errors are swallowed after the first attempt", () => { + const model = createModel(); + let cleanupCalls = 0; + + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + throw new Error("close failed"); + }); + + expect(() => runLanguageModelCleanup(model)).not.toThrow(); + expect(() => runLanguageModelCleanup(model)).not.toThrow(); + expect(cleanupCalls).toBe(1); + }); +}); diff --git a/src/node/services/languageModelCleanup.ts b/src/node/services/languageModelCleanup.ts new file mode 100644 index 0000000000..1b96c9138e --- /dev/null +++ b/src/node/services/languageModelCleanup.ts @@ -0,0 +1,60 @@ +import assert from "node:assert"; +import type { LanguageModel } from "ai"; + +import { log } from "./log"; + +const languageModelCleanupSymbol = Symbol("mux.languageModelCleanup"); + +type LanguageModelCleanup = () => void; +type LanguageModelWithCleanup = LanguageModel & { + [languageModelCleanupSymbol]?: LanguageModelCleanup; +}; + +export function attachLanguageModelCleanup( + model: LanguageModel, + cleanup: LanguageModelCleanup +): void { + assert(typeof cleanup === "function", "language model cleanup must be a function"); + assert( + !hasLanguageModelCleanup(model), + "language model already has cleanup attached; call moveLanguageModelCleanup instead" + ); + const modelWithCleanup = model as LanguageModelWithCleanup; + modelWithCleanup[languageModelCleanupSymbol] = cleanup; +} + +export function moveLanguageModelCleanup(source: LanguageModel, target: LanguageModel): void { + const sourceWithCleanup = source as LanguageModelWithCleanup; + const cleanup = sourceWithCleanup[languageModelCleanupSymbol]; + if (cleanup === undefined) { + return; + } + + delete sourceWithCleanup[languageModelCleanupSymbol]; + attachLanguageModelCleanup(target, cleanup); +} + +export function hasLanguageModelCleanup(model: LanguageModel): boolean { + const modelWithCleanup = model as LanguageModelWithCleanup; + return typeof modelWithCleanup[languageModelCleanupSymbol] === "function"; +} + +export function runLanguageModelCleanup(model: LanguageModel | undefined): void { + if (model === undefined) { + return; + } + + const modelWithCleanup = model as LanguageModelWithCleanup; + const cleanup = modelWithCleanup[languageModelCleanupSymbol]; + if (cleanup === undefined) { + return; + } + + delete modelWithCleanup[languageModelCleanupSymbol]; + + try { + cleanup(); + } catch (error) { + log.warn("Failed to clean up language model resources", { error }); + } +} diff --git a/src/node/services/openAIWebSocketTransportFetch.test.ts b/src/node/services/openAIWebSocketTransportFetch.test.ts new file mode 100644 index 0000000000..b7e98e02e4 --- /dev/null +++ b/src/node/services/openAIWebSocketTransportFetch.test.ts @@ -0,0 +1,322 @@ +import { describe, expect, test } from "bun:test"; + +import { + consumeCapturedRequestHeaders, + DEVTOOLS_RUN_METADATA_ID_HEADER, + DEVTOOLS_STEP_ID_HEADER, +} from "./devToolsHeaderCapture"; +import { createOpenAIWebSocketTransportFetch } from "./openAIWebSocketTransportFetch"; + +function getFetchInputUrl(input: RequestInfo | URL): string { + if (input instanceof URL) { + return input.toString(); + } + if (typeof input === "string") { + return input; + } + return input.url; +} + +function createTestFetch( + handler: (input: RequestInfo | URL, init?: RequestInit) => Promise +): typeof fetch { + return Object.assign(handler, { preconnect: fetch.preconnect.bind(fetch) }) as typeof fetch; +} + +function createTestWebSocketFetch( + handler: (input: RequestInfo | URL, init?: RequestInit) => Promise, + close: () => void = () => undefined +): typeof fetch & { close: () => void } { + return Object.assign(createTestFetch(handler), { close }); +} + +describe("createOpenAIWebSocketTransportFetch", () => { + test("disabled transport keeps using the base fetch and exposes inactive cleanup", async () => { + const baseCalls: string[] = []; + const baseFetch = createTestFetch((input: RequestInfo | URL, _init?: RequestInit) => { + baseCalls.push(getFetchInputUrl(input)); + return Promise.resolve(new Response("base")); + }); + + const transport = createOpenAIWebSocketTransportFetch({ + enabled: false, + baseFetch, + createWebSocketFetch: () => { + throw new Error("WebSocket fetch should not be created when disabled"); + }, + }); + + const response = await transport.fetch("https://api.openai.com/v1/responses", { + method: "POST", + body: JSON.stringify({ stream: true }), + }); + + expect(await response.text()).toBe("base"); + expect(baseCalls).toEqual(["https://api.openai.com/v1/responses"]); + expect(transport.active).toBe(false); + expect(() => transport.close()).not.toThrow(); + }); + + test("enabled transport creates the WebSocket fetch lazily", async () => { + let created = false; + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch: createTestFetch(() => Promise.resolve(new Response("base"))), + createWebSocketFetch: () => { + created = true; + return createTestWebSocketFetch(() => Promise.resolve(new Response("ws"))); + }, + }); + + expect(created).toBe(false); + await transport.fetch("https://api.openai.com/v1/models", { method: "GET" }); + expect(created).toBe(false); + await transport.fetch("https://api.openai.com/v1/responses", { + method: "POST", + body: JSON.stringify({ stream: true }), + }); + expect(created).toBe(true); + }); + + test("enabled transport sends streaming Responses API posts through WebSocket fetch", async () => { + const wsCalls: string[] = []; + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch: createTestFetch(() => Promise.resolve(new Response("base"))), + createWebSocketFetch: () => { + return createTestWebSocketFetch((input: RequestInfo | URL, _init?: RequestInit) => { + wsCalls.push(getFetchInputUrl(input)); + return Promise.resolve(new Response("ws")); + }); + }, + }); + + const response = await transport.fetch("https://api.openai.com/v1/responses", { + method: "POST", + body: JSON.stringify({ stream: true }), + }); + + expect(await response.text()).toBe("ws"); + expect(wsCalls).toEqual(["https://api.openai.com/v1/responses"]); + expect(transport.active).toBe(true); + }); + + test("enabled transport keeps non-eligible requests on the base fetch", async () => { + const baseCalls: string[] = []; + const wsCalls: string[] = []; + const baseFetch = createTestFetch((input: RequestInfo | URL, _init?: RequestInit) => { + baseCalls.push(getFetchInputUrl(input)); + return Promise.resolve(new Response("base")); + }); + + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch, + createWebSocketFetch: () => { + return createTestWebSocketFetch((input: RequestInfo | URL, _init?: RequestInit) => { + wsCalls.push(getFetchInputUrl(input)); + return Promise.resolve(new Response("ws")); + }); + }, + }); + + const response = await transport.fetch("https://api.openai.com/v1/models", { + method: "GET", + }); + + expect(await response.text()).toBe("base"); + expect(baseCalls).toEqual(["https://api.openai.com/v1/models"]); + expect(wsCalls).toEqual([]); + }); + + test("enabled transport strips DevTools headers before WebSocket dispatch", async () => { + let webSocketHeaders: Headers | undefined; + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch: createTestFetch(() => Promise.resolve(new Response("base"))), + createWebSocketFetch: () => + createTestWebSocketFetch((_input: RequestInfo | URL, init?: RequestInit) => { + webSocketHeaders = new Headers(init?.headers); + return Promise.resolve(new Response("ws")); + }), + }); + + await transport.fetch("https://api.openai.com/v1/responses", { + method: "POST", + headers: { + Authorization: "Bearer test-key", + [DEVTOOLS_STEP_ID_HEADER]: "step-ws-1", + [DEVTOOLS_RUN_METADATA_ID_HEADER]: "run-metadata-1", + }, + body: JSON.stringify({ stream: true }), + }); + + expect(webSocketHeaders).toBeDefined(); + if (!webSocketHeaders) { + throw new Error("Expected WebSocket fetch to receive request headers"); + } + expect(webSocketHeaders.get(DEVTOOLS_STEP_ID_HEADER)).toBeNull(); + expect(webSocketHeaders.get(DEVTOOLS_RUN_METADATA_ID_HEADER)).toBeNull(); + const captured = consumeCapturedRequestHeaders("step-ws-1"); + expect(captured).toEqual({ authorization: "[REDACTED]" }); + }); + + test("enabled transport keeps non-streaming Responses posts on the base fetch", async () => { + const baseBodies: string[] = []; + const wsCalls: string[] = []; + const baseFetch = createTestFetch((_input: RequestInfo | URL, init?: RequestInit) => { + baseBodies.push(typeof init?.body === "string" ? init.body : ""); + return Promise.resolve(new Response("base")); + }); + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch, + createWebSocketFetch: () => + createTestWebSocketFetch((input: RequestInfo | URL) => { + wsCalls.push(getFetchInputUrl(input)); + return Promise.resolve(new Response("ws")); + }), + }); + + const streamFalse = await transport.fetch("https://api.openai.com/v1/responses", { + method: "POST", + body: JSON.stringify({ stream: false }), + }); + const streamAbsent = await transport.fetch("https://api.openai.com/v1/responses", { + method: "POST", + body: JSON.stringify({}), + }); + + expect(await streamFalse.text()).toBe("base"); + expect(await streamAbsent.text()).toBe("base"); + expect(baseBodies).toEqual([JSON.stringify({ stream: false }), JSON.stringify({})]); + expect(wsCalls).toEqual([]); + }); + + test("enabled transport recognizes streaming Responses Request objects", async () => { + const wsCalls: string[] = []; + let webSocketHeaders: Headers | undefined; + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch: createTestFetch(() => Promise.resolve(new Response("base"))), + createWebSocketFetch: () => + createTestWebSocketFetch((input: RequestInfo | URL, init?: RequestInit) => { + wsCalls.push(getFetchInputUrl(input)); + webSocketHeaders = new Headers(init?.headers); + return Promise.resolve(new Response("ws")); + }), + }); + const request = new Request("https://api.openai.com/v1/responses", { + method: "POST", + headers: { Authorization: "Bearer request-key" }, + body: JSON.stringify({ stream: true }), + }); + + const response = await transport.fetch(request); + + expect(await response.text()).toBe("ws"); + expect(wsCalls).toEqual(["https://api.openai.com/v1/responses"]); + expect(webSocketHeaders?.get("authorization")).toBe("Bearer request-key"); + }); + + test("enabled transport recognizes Responses URLs with query parameters", async () => { + const wsCalls: string[] = []; + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch: createTestFetch(() => Promise.resolve(new Response("base"))), + createWebSocketFetch: () => + createTestWebSocketFetch((input: RequestInfo | URL) => { + wsCalls.push(getFetchInputUrl(input)); + return Promise.resolve(new Response("ws")); + }), + }); + + const response = await transport.fetch("https://api.openai.com/v1/responses?beta=2", { + method: "POST", + body: JSON.stringify({ stream: true }), + }); + + expect(await response.text()).toBe("ws"); + expect(wsCalls).toEqual(["https://api.openai.com/v1/responses?beta=2"]); + }); + + test("close retries after a connection-establishment race", async () => { + let closeCalls = 0; + let resolveWebSocketFetch: ((response: Response) => void) | undefined; + const webSocketFetchPromise = new Promise((resolve) => { + resolveWebSocketFetch = resolve; + }); + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch: createTestFetch(() => Promise.resolve(new Response("base"))), + createWebSocketFetch: () => + createTestWebSocketFetch( + () => webSocketFetchPromise, + () => { + closeCalls += 1; + } + ), + }); + + const responsePromise = transport.fetch("https://api.openai.com/v1/responses", { + method: "POST", + body: JSON.stringify({ stream: true }), + }); + await Promise.resolve(); + transport.close(); + if (!resolveWebSocketFetch) { + throw new Error("Expected test WebSocket fetch resolver to be initialized"); + } + resolveWebSocketFetch(new Response("ws")); + + expect(await (await responsePromise).text()).toBe("ws"); + expect(closeCalls).toBe(2); + }); + + test("close retry failure does not mask a resolved WebSocket response", async () => { + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch: createTestFetch(() => Promise.resolve(new Response("base"))), + createWebSocketFetch: () => + createTestWebSocketFetch( + () => Promise.resolve(new Response("ws")), + () => { + throw new Error("close failed"); + } + ), + }); + + const responsePromise = transport.fetch("https://api.openai.com/v1/responses", { + method: "POST", + body: JSON.stringify({ stream: true }), + }); + await Promise.resolve(); + expect(() => transport.close()).toThrow("close failed"); + + expect(await (await responsePromise).text()).toBe("ws"); + }); + + test("close is idempotent after WebSocket fetch creation", async () => { + let closeCalls = 0; + const transport = createOpenAIWebSocketTransportFetch({ + enabled: true, + baseFetch: createTestFetch(() => Promise.resolve(new Response("base"))), + createWebSocketFetch: () => + createTestWebSocketFetch( + () => Promise.resolve(new Response("ws")), + () => { + closeCalls += 1; + } + ), + }); + + await transport.fetch("https://api.openai.com/v1/responses", { + method: "POST", + body: JSON.stringify({ stream: true }), + }); + transport.close(); + transport.close(); + + expect(closeCalls).toBe(1); + }); +}); diff --git a/src/node/services/openAIWebSocketTransportFetch.ts b/src/node/services/openAIWebSocketTransportFetch.ts new file mode 100644 index 0000000000..5ad9d800d8 --- /dev/null +++ b/src/node/services/openAIWebSocketTransportFetch.ts @@ -0,0 +1,130 @@ +import assert from "node:assert"; +import { captureAndStripDevToolsHeader } from "./devToolsHeaderCapture"; +import { createWebSocketFetch as createOpenAIWebSocketFetch } from "@vercel/ai-sdk-openai-websocket-fetch"; + +type WebSocketFetch = ((input: RequestInfo | URL, init?: RequestInit) => Promise) & { + close: () => void; +}; +type WebSocketFetchFactory = () => WebSocketFetch; + +interface CreateOpenAIWebSocketTransportFetchOptions { + enabled: boolean; + baseFetch: typeof fetch; + createWebSocketFetch?: WebSocketFetchFactory; +} + +interface OpenAIWebSocketTransportFetch { + fetch: typeof fetch; + close: () => void; + active: boolean; +} + +function getRequestUrl(input: RequestInfo | URL): string { + if (input instanceof URL) { + return input.toString(); + } + if (typeof input === "string") { + return input; + } + return input.url; +} + +async function isStreamingResponsesRequest( + input: RequestInfo | URL, + init?: RequestInit +): Promise { + const method = init?.method ?? (input instanceof Request ? input.method : "GET"); + if (method.toUpperCase() !== "POST") { + return false; + } + + if (!/\/v1\/responses(\?|$)/.test(getRequestUrl(input))) { + return false; + } + + const bodyText = + typeof init?.body === "string" + ? init.body + : init?.body == null && input instanceof Request + ? await input.clone().text() + : undefined; + if (bodyText === undefined) { + return false; + } + + try { + const body = JSON.parse(bodyText) as { stream?: unknown }; + return body.stream === true; + } catch { + return false; + } +} + +export function createOpenAIWebSocketTransportFetch( + options: CreateOpenAIWebSocketTransportFetchOptions +): OpenAIWebSocketTransportFetch { + if (!options.enabled) { + return { + fetch: options.baseFetch, + close: () => undefined, + active: false, + }; + } + + const webSocketFetchFactory = options.createWebSocketFetch ?? createOpenAIWebSocketFetch; + let webSocketFetch: WebSocketFetch | null = null; + + const getWebSocketFetch = (): WebSocketFetch => { + webSocketFetch ??= webSocketFetchFactory(); + assert( + typeof webSocketFetch.close === "function", + "OpenAI WebSocket fetch must expose close()" + ); + return webSocketFetch; + }; + + let closeRequested = false; + const close = (): void => { + if (closeRequested) { + return; + } + closeRequested = true; + webSocketFetch?.close(); + }; + + const baseFetchWithPreconnect = options.baseFetch as typeof fetch & { + preconnect?: typeof fetch.preconnect; + }; + const fetchExtras = + typeof baseFetchWithPreconnect.preconnect === "function" + ? { preconnect: baseFetchWithPreconnect.preconnect.bind(baseFetchWithPreconnect) } + : {}; + const transportFetch = Object.assign(async (input: RequestInfo | URL, init?: RequestInit) => { + // The upstream package falls through to globalThis.fetch for non-WebSocket requests. + // Pre-filter here so Mux's existing fetch wrappers keep handling those HTTP paths. + if (!(await isStreamingResponsesRequest(input, init))) { + return options.baseFetch(input, init); + } + + const activeWebSocketFetch = getWebSocketFetch(); + const headers = new Headers( + init?.headers ?? (input instanceof Request ? input.headers : undefined) + ); + captureAndStripDevToolsHeader(headers); + const response = await activeWebSocketFetch(input, { ...(init ?? {}), headers }); + if (closeRequested) { + try { + activeWebSocketFetch.close(); + } catch { + // Cleanup after a cancellation race must not mask the successful fetch response. + } + } + return response; + }, fetchExtras) as typeof fetch; + + return { + fetch: transportFetch, + close, + active: true, + }; +} diff --git a/src/node/services/providerModelFactory.test.ts b/src/node/services/providerModelFactory.test.ts index cff450fe36..de9c0e54b6 100644 --- a/src/node/services/providerModelFactory.test.ts +++ b/src/node/services/providerModelFactory.test.ts @@ -21,6 +21,8 @@ import { wrapFetchWithAnthropicCacheControl, } from "./providerModelFactory"; import { MUX_ANTHROPIC_EFFORT_OVERRIDE_HEADER } from "@/common/utils/ai/providerOptions"; +import { hasLanguageModelCleanup } from "./languageModelCleanup"; +import type { DevToolsService } from "./devToolsService"; import { CodexOauthService } from "./codexOauthService"; import { PolicyService } from "./policyService"; import { ProviderService } from "./providerService"; @@ -40,6 +42,27 @@ async function withTempConfig( } } +async function withOpenAIBaseUrlEnvUnset(run: () => Promise): Promise { + const savedBaseUrl = process.env.OPENAI_BASE_URL; + const savedApiBase = process.env.OPENAI_API_BASE; + delete process.env.OPENAI_BASE_URL; + delete process.env.OPENAI_API_BASE; + try { + await run(); + } finally { + if (savedBaseUrl === undefined) { + delete process.env.OPENAI_BASE_URL; + } else { + process.env.OPENAI_BASE_URL = savedBaseUrl; + } + if (savedApiBase === undefined) { + delete process.env.OPENAI_API_BASE; + } else { + process.env.OPENAI_API_BASE = savedApiBase; + } + } +} + async function withTempPolicyProviderFactory( policy: unknown, run: ( @@ -781,6 +804,146 @@ describe("ProviderModelFactory GitHub Copilot", () => { }); }); +describe("ProviderModelFactory OpenAI WebSocket transport", () => { + it("attaches cleanup when enabled for Responses models", async () => { + await withOpenAIBaseUrlEnvUnset(async () => + withTempConfig(async (config, factory) => { + config.saveProvidersConfig({ + openai: { + apiKey: "sk-test", + webSocketTransportEnabled: true, + }, + }); + + const result = await factory.createModel("openai:gpt-4.1-mini"); + + expect(result.success).toBe(true); + if (!result.success) { + return; + } + expect(hasLanguageModelCleanup(result.data)).toBe(true); + }) + ); + }); + + it("does not attach cleanup for Codex OAuth routed models", async () => { + await withTempConfig(async (config, factory) => { + config.saveProvidersConfig({ + openai: { + webSocketTransportEnabled: true, + codexOauth: { + type: "oauth", + access: "test-access-token", + refresh: "test-refresh-token", + expires: Date.now() + 60_000, + accountId: "test-account-id", + }, + }, + }); + + const result = await factory.createModel(KNOWN_MODELS.GPT_53_CODEX.id); + + expect(result.success).toBe(true); + if (!result.success) { + return; + } + expect(hasLanguageModelCleanup(result.data)).toBe(false); + expect(modelCostsIncluded(result.data)).toBe(true); + }); + }); + + it("does not attach cleanup when a custom OpenAI base URL is configured", async () => { + await withTempConfig(async (config, factory) => { + config.saveProvidersConfig({ + openai: { + apiKey: "sk-test", + baseURL: "https://proxy.openai.test/v1", + webSocketTransportEnabled: true, + }, + }); + + const result = await factory.createModel("openai:gpt-4.1-mini"); + + expect(result.success).toBe(true); + if (!result.success) { + return; + } + expect(hasLanguageModelCleanup(result.data)).toBe(false); + }); + }); + + it("preserves cleanup when DevTools wraps an OpenAI WebSocket model", async () => { + await withOpenAIBaseUrlEnvUnset(async () => + withTempConfig(async (config) => { + config.saveProvidersConfig({ + openai: { + apiKey: "sk-test", + webSocketTransportEnabled: true, + }, + }); + const providerService = new ProviderService(config); + const devToolsService = { enabled: true } as unknown as DevToolsService; + const factory = new ProviderModelFactory( + config, + providerService, + undefined, + undefined, + devToolsService + ); + + const result = await factory.createModel("openai:gpt-4.1-mini", undefined, { + workspaceId: "devtools-workspace", + }); + + expect(result.success).toBe(true); + if (!result.success) { + return; + } + expect(hasLanguageModelCleanup(result.data)).toBe(true); + }) + ); + }); + + it("does not attach cleanup when Chat Completions is selected", async () => { + await withTempConfig(async (config, factory) => { + config.saveProvidersConfig({ + openai: { + apiKey: "sk-test", + wireFormat: "chatCompletions", + webSocketTransportEnabled: true, + }, + }); + + const result = await factory.createModel("openai:gpt-4.1-mini"); + + expect(result.success).toBe(true); + if (!result.success) { + return; + } + expect(hasLanguageModelCleanup(result.data)).toBe(false); + }); + }); + + it("ignores invalid persisted WebSocket transport values", async () => { + await withTempConfig(async (config, factory) => { + config.saveProvidersConfig({ + openai: { + apiKey: "sk-test", + webSocketTransportEnabled: "true", + }, + } as unknown as Parameters[0]); + + const result = await factory.createModel("openai:gpt-4.1-mini"); + + expect(result.success).toBe(true); + if (!result.success) { + return; + } + expect(hasLanguageModelCleanup(result.data)).toBe(false); + }); + }); +}); + describe("ProviderModelFactory modelCostsIncluded", () => { it("marks gpt-5.3-codex as subscription-covered when routed through Codex OAuth", async () => { await withTempConfig(async (config, factory) => { diff --git a/src/node/services/providerModelFactory.ts b/src/node/services/providerModelFactory.ts index fea730ba1f..c11087a04b 100644 --- a/src/node/services/providerModelFactory.ts +++ b/src/node/services/providerModelFactory.ts @@ -41,6 +41,11 @@ import type { CodexOauthService } from "@/node/services/codexOauthService"; import type { DevToolsService } from "@/node/services/devToolsService"; import { captureAndStripDevToolsHeader } from "@/node/services/devToolsHeaderCapture"; import { createDevToolsMiddleware } from "@/node/services/devToolsMiddleware"; +import { + attachLanguageModelCleanup, + moveLanguageModelCleanup, +} from "@/node/services/languageModelCleanup"; +import { createOpenAIWebSocketTransportFetch } from "@/node/services/openAIWebSocketTransportFetch"; import { log } from "@/node/services/log"; import { MUX_ANTHROPIC_EFFORT_OVERRIDE_HEADER, @@ -958,10 +963,12 @@ export class ProviderModelFactory { const workspaceId = opts?.workspaceId; const devToolsService = this.devToolsService; if (workspaceId != null && devToolsService?.enabled) { + const innerModel = model; model = wrapLanguageModel({ model, middleware: createDevToolsMiddleware(workspaceId, devToolsService), }); + moveLanguageModelCleanup(innerModel, model); } return Ok(model); @@ -1296,6 +1303,11 @@ export class ProviderModelFactory { const baseFetch = getProviderFetch(providerConfig); const codexOauthService = this.codexOauthService; + const openAIHasCustomBaseURL = + providerConfig.baseURL != null || providerConfig.baseUrl != null || creds.baseUrl != null; + const webSocketTransportEnabled = + (providerConfig as { webSocketTransportEnabled?: unknown }).webSocketTransportEnabled === + true && !openAIHasCustomBaseURL; // Wrap fetch so Codex OAuth Responses requests are normalized before // they are rerouted from api.openai.com to chatgpt.com's Codex backend. @@ -1378,9 +1390,17 @@ export class ProviderModelFactory { : {} ); + const webSocketTransport = createOpenAIWebSocketTransportFetch({ + enabled: + webSocketTransportEnabled && + effectiveWireFormat === "responses" && + !shouldRouteThroughCodexOauth, + baseFetch: fetchWithOpenAICodexNormalization as typeof fetch, + }); + // Lazy-load OpenAI provider to reduce startup time const { createOpenAI } = await PROVIDER_REGISTRY.openai(); - const providerFetch = fetchWithOpenAICodexNormalization as typeof fetch; + const providerFetch = webSocketTransport.fetch; const provider = createOpenAI({ ...configWithCreds, // Cast is safe: our fetch implementation is compatible with the SDK's fetch type. @@ -1393,6 +1413,10 @@ export class ProviderModelFactory { effectiveWireFormat === "chatCompletions" ? provider.chat(modelId) : provider.responses(modelId); + if (webSocketTransport.active) { + attachLanguageModelCleanup(model, webSocketTransport.close); + } + const injectModelOpenAIStore = (storeValue: unknown, mode: "default" | "force"): void => { assert(typeof storeValue === "boolean", "OpenAI store override must be boolean"); const store = storeValue; diff --git a/src/node/services/providerService.test.ts b/src/node/services/providerService.test.ts index bc2ec9bdef..cdf62757e0 100644 --- a/src/node/services/providerService.test.ts +++ b/src/node/services/providerService.test.ts @@ -204,6 +204,42 @@ describe("ProviderService.getConfig", () => { }); }); + it("surfaces valid OpenAI WebSocket transport preference", () => { + withTempConfig((config, service) => { + config.saveProvidersConfig({ + openai: { + apiKey: "sk-test", + webSocketTransportEnabled: true, + }, + }); + + const cfg = service.getConfig(); + + expect(cfg.openai.webSocketTransportEnabled).toBe(true); + expect(Object.prototype.hasOwnProperty.call(cfg.openai, "webSocketTransportEnabled")).toBe( + true + ); + }); + }); + + it("omits invalid OpenAI WebSocket transport preference", () => { + withTempConfig((config, service) => { + config.saveProvidersConfig({ + openai: { + apiKey: "sk-test", + webSocketTransportEnabled: "true", + }, + }); + + const cfg = service.getConfig(); + + expect(cfg.openai.webSocketTransportEnabled).toBeUndefined(); + expect(Object.prototype.hasOwnProperty.call(cfg.openai, "webSocketTransportEnabled")).toBe( + false + ); + }); + }); + it("surfaces non-secret op:// API key references", () => { withTempConfig((config, service) => { const opRef = "op://Personal/Anthropic/credential"; diff --git a/src/node/services/providerService.ts b/src/node/services/providerService.ts index 0dd93364e9..5ed17ba8de 100644 --- a/src/node/services/providerService.ts +++ b/src/node/services/providerService.ts @@ -255,6 +255,7 @@ export class ProviderService { serviceTier?: string; wireFormat?: string; store?: unknown; + webSocketTransportEnabled?: unknown; cacheTtl?: unknown; disableBetaFeatures?: unknown; /** OpenAI-only: default auth precedence for Codex-OAuth-allowed models. */ @@ -333,6 +334,10 @@ export class ProviderService { providerInfo.store = config.store; } + if (provider === "openai" && typeof config.webSocketTransportEnabled === "boolean") { + providerInfo.webSocketTransportEnabled = config.webSocketTransportEnabled; + } + // Anthropic-specific fields const cacheTtl = config.cacheTtl; if (provider === "anthropic" && (cacheTtl === "5m" || cacheTtl === "1h")) { diff --git a/src/node/services/streamManager.test.ts b/src/node/services/streamManager.test.ts index c4f079d2f2..75d562229a 100644 --- a/src/node/services/streamManager.test.ts +++ b/src/node/services/streamManager.test.ts @@ -5,7 +5,14 @@ import { KNOWN_MODELS } from "@/common/constants/knownModels"; import type { ToolPolicy } from "@/common/utils/tools/toolPolicy"; import { StreamManager, stripEncryptedContent } from "./streamManager"; import * as aiSdk from "ai"; -import { APICallError, RetryError, tool, type ModelMessage, type Tool } from "ai"; +import { + APICallError, + RetryError, + tool, + type LanguageModel, + type ModelMessage, + type Tool, +} from "ai"; import { z } from "zod"; import * as modelStatsModule from "@/common/utils/tokens/modelStats"; import type { HistoryService } from "./historyService"; @@ -17,8 +24,20 @@ import { shouldRunIntegrationTests, validateApiKeys } from "../../../tests/testU import { DisposableTempDir } from "@/node/services/tempDir"; import type { ExecOptions, ExecStream, Runtime } from "@/node/runtime/Runtime"; import { createRuntime } from "@/node/runtime/runtimeFactory"; +import { attachLanguageModelCleanup } from "./languageModelCleanup"; import { shellQuote } from "@/common/utils/shell"; +function createTestLanguageModel(modelId = "cleanup-model"): LanguageModel { + return { + specificationVersion: "v3", + provider: "test", + modelId, + supportedUrls: {}, + doGenerate: () => Promise.reject(new Error("doGenerate is unused in StreamManager tests")), + doStream: () => Promise.reject(new Error("doStream is unused in StreamManager tests")), + }; +} + // Skip integration tests if TEST_INTEGRATION is not set const describeIntegration = shouldRunIntegrationTests() ? describe : describe.skip; @@ -918,6 +937,364 @@ describe("StreamManager - call settings overrides", () => { }); }); +describe("StreamManager - language model cleanup", () => { + const runtime = createRuntime({ type: "local", srcBaseDir: "/tmp" }); + + test("runs model cleanup when stream processing finishes", async () => { + const streamManager = new StreamManager(historyService); + streamManager.on("error", () => undefined); + const workspaceId = "cleanup-workspace"; + const messageId = "cleanup-message"; + const historySequence = 1; + let cleanupCalls = 0; + const model = createTestLanguageModel(); + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + }); + + const appendResult = await historyService.appendToHistory(workspaceId, { + id: messageId, + role: "assistant", + metadata: { historySequence, partial: true }, + parts: [], + }); + expect(appendResult.success).toBe(true); + + const processStreamWithCleanup = Reflect.get(streamManager, "processStreamWithCleanup") as ( + workspaceId: string, + streamInfo: unknown, + historySequence: number + ) => Promise; + expect(typeof processStreamWithCleanup).toBe("function"); + + const streamInfo = { + state: "streaming", + streamResult: { + fullStream: (async function* () { + await Promise.resolve(); + yield* [] as unknown[]; + })(), + totalUsage: Promise.resolve({ inputTokens: 1, outputTokens: 1, totalTokens: 2 }), + usage: Promise.resolve({ inputTokens: 1, outputTokens: 1, totalTokens: 2 }), + providerMetadata: Promise.resolve(undefined), + steps: Promise.resolve([]), + }, + abortController: new AbortController(), + messageId, + token: "cleanup-token", + startTime: Date.now(), + lastPartTimestamp: Date.now(), + toolCompletionTimestamps: new Map(), + model: "openai:gpt-4.1-mini", + metadataModel: "openai:gpt-4.1-mini", + historySequence, + request: { model, messages: [], providerOptions: undefined }, + toolModelUsages: [], + parts: [{ type: "text" as const, text: "done", timestamp: Date.now() }], + lastPartialWriteTime: 0, + partialWritePromise: undefined, + processingPromise: Promise.resolve(), + softInterrupt: { pending: false as const }, + runtimeTempDir: "", + runtime, + cumulativeUsage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + cumulativeProviderMetadata: undefined, + didRetryPreviousResponseIdAtStep: false, + currentStepStartIndex: 0, + stepTracker: {}, + }; + + const workspaceStreamsValue: unknown = Reflect.get(streamManager, "workspaceStreams"); + expect(workspaceStreamsValue instanceof Map).toBe(true); + if (!(workspaceStreamsValue instanceof Map)) { + throw new Error("Expected StreamManager.workspaceStreams to be a Map"); + } + workspaceStreamsValue.set(workspaceId, streamInfo); + + await processStreamWithCleanup.call(streamManager, workspaceId, streamInfo, historySequence); + + expect(cleanupCalls).toBe(1); + }); + test("runs model cleanup when startStream exits before processing after abort", async () => { + const streamManager = new StreamManager(historyService); + let cleanupCalls = 0; + const model = createTestLanguageModel("cleanup-preabort-model"); + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + }); + const abortController = new AbortController(); + abortController.abort(new Error("pre-abort")); + + const result = await streamManager.startStream( + "cleanup-preabort-workspace", + [{ role: "user", content: "hello" }], + model, + "openai:gpt-4.1-mini", + 1, + "system", + runtime, + "cleanup-preabort-message", + abortController.signal + ); + + expect(result.success).toBe(true); + expect(cleanupCalls).toBe(1); + }); + + test("runs model cleanup when stream creation throws before processing", async () => { + const streamManager = new StreamManager(historyService); + let cleanupCalls = 0; + const model = createTestLanguageModel("cleanup-create-throw-model"); + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + }); + const replaceCreateStreamResult = Reflect.set(streamManager, "createStreamResult", () => { + throw new Error("create stream failed"); + }); + expect(replaceCreateStreamResult).toBe(true); + + const result = await streamManager.startStream( + "cleanup-create-throw-workspace", + [{ role: "user", content: "hello" }], + model, + "openai:gpt-4.1-mini", + 1, + "system", + runtime, + "cleanup-create-throw-message" + ); + + expect(result.success).toBe(false); + expect(cleanupCalls).toBe(1); + }); + + test("keeps model cleanup until a multi-step tool stream finishes", async () => { + const streamManager = new StreamManager(historyService); + streamManager.on("error", () => undefined); + const workspaceId = "cleanup-multistep-workspace"; + const messageId = "cleanup-multistep-message"; + const historySequence = 1; + let cleanupCalls = 0; + const model = createTestLanguageModel("cleanup-multistep-model"); + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + }); + + const appendResult = await historyService.appendToHistory(workspaceId, { + id: messageId, + role: "assistant", + metadata: { historySequence, partial: true }, + parts: [], + }); + expect(appendResult.success).toBe(true); + + const processStreamWithCleanup = Reflect.get(streamManager, "processStreamWithCleanup") as ( + workspaceId: string, + streamInfo: unknown, + historySequence: number + ) => Promise; + expect(typeof processStreamWithCleanup).toBe("function"); + + const streamInfo = { + state: "streaming", + streamResult: { + fullStream: (async function* () { + await Promise.resolve(); + yield { + type: "tool-call", + toolCallId: "call-1", + toolName: "test_tool", + input: { value: 1 }, + }; + expect(cleanupCalls).toBe(0); + yield { + type: "tool-result", + toolCallId: "call-1", + toolName: "test_tool", + output: { ok: true }, + }; + expect(cleanupCalls).toBe(0); + yield { type: "text-delta", text: "done" }; + expect(cleanupCalls).toBe(0); + yield { type: "finish", finishReason: "stop" }; + })(), + totalUsage: Promise.resolve({ inputTokens: 1, outputTokens: 1, totalTokens: 2 }), + usage: Promise.resolve({ inputTokens: 1, outputTokens: 1, totalTokens: 2 }), + providerMetadata: Promise.resolve(undefined), + steps: Promise.resolve([]), + }, + abortController: new AbortController(), + messageId, + token: "cleanup-multistep-token", + startTime: Date.now(), + lastPartTimestamp: Date.now(), + toolCompletionTimestamps: new Map(), + model: "openai:gpt-4.1-mini", + metadataModel: "openai:gpt-4.1-mini", + historySequence, + request: { model, messages: [], providerOptions: undefined }, + toolModelUsages: [], + parts: [], + lastPartialWriteTime: 0, + partialWritePromise: undefined, + processingPromise: Promise.resolve(), + softInterrupt: { pending: false as const }, + runtimeTempDir: "", + runtime, + cumulativeUsage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + cumulativeProviderMetadata: undefined, + didRetryPreviousResponseIdAtStep: false, + currentStepStartIndex: 0, + stepTracker: {}, + }; + + await processStreamWithCleanup.call(streamManager, workspaceId, streamInfo, historySequence); + + expect(cleanupCalls).toBe(1); + }); + + test("runs model cleanup when stream processing fails", async () => { + const streamManager = new StreamManager(historyService); + streamManager.on("error", () => undefined); + const workspaceId = "cleanup-error-workspace"; + const messageId = "cleanup-error-message"; + const historySequence = 1; + let cleanupCalls = 0; + const model = createTestLanguageModel("cleanup-error-model"); + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + }); + + const appendResult = await historyService.appendToHistory(workspaceId, { + id: messageId, + role: "assistant", + metadata: { historySequence, partial: true }, + parts: [], + }); + expect(appendResult.success).toBe(true); + + const processStreamWithCleanup = Reflect.get(streamManager, "processStreamWithCleanup") as ( + workspaceId: string, + streamInfo: unknown, + historySequence: number + ) => Promise; + expect(typeof processStreamWithCleanup).toBe("function"); + + const streamInfo = { + state: "streaming", + streamResult: { + fullStream: (async function* () { + await Promise.resolve(); + throw new Error("stream failed before output"); + yield* [] as unknown[]; + })(), + totalUsage: Promise.resolve({ inputTokens: 1, outputTokens: 0, totalTokens: 1 }), + usage: Promise.resolve({ inputTokens: 1, outputTokens: 0, totalTokens: 1 }), + providerMetadata: Promise.resolve(undefined), + steps: Promise.resolve([]), + }, + abortController: new AbortController(), + messageId, + token: "cleanup-error-token", + startTime: Date.now(), + lastPartTimestamp: Date.now(), + toolCompletionTimestamps: new Map(), + model: "openai:gpt-4.1-mini", + metadataModel: "openai:gpt-4.1-mini", + historySequence, + request: { model, messages: [], providerOptions: undefined }, + toolModelUsages: [], + parts: [], + lastPartialWriteTime: 0, + partialWritePromise: undefined, + processingPromise: Promise.resolve(), + softInterrupt: { pending: false as const }, + runtimeTempDir: "", + runtime, + cumulativeUsage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + cumulativeProviderMetadata: undefined, + didRetryPreviousResponseIdAtStep: false, + currentStepStartIndex: 0, + stepTracker: {}, + }; + + await processStreamWithCleanup.call(streamManager, workspaceId, streamInfo, historySequence); + + expect(cleanupCalls).toBe(1); + }); + + test("runs model cleanup when stream processing is aborted", async () => { + const streamManager = new StreamManager(historyService); + streamManager.on("error", () => undefined); + const workspaceId = "cleanup-abort-workspace"; + const messageId = "cleanup-abort-message"; + const historySequence = 1; + let cleanupCalls = 0; + const model = createTestLanguageModel("cleanup-abort-model"); + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + }); + + const appendResult = await historyService.appendToHistory(workspaceId, { + id: messageId, + role: "assistant", + metadata: { historySequence, partial: true }, + parts: [], + }); + expect(appendResult.success).toBe(true); + + const processStreamWithCleanup = Reflect.get(streamManager, "processStreamWithCleanup") as ( + workspaceId: string, + streamInfo: unknown, + historySequence: number + ) => Promise; + expect(typeof processStreamWithCleanup).toBe("function"); + const abortController = new AbortController(); + abortController.abort(new Error("test abort")); + + const streamInfo = { + state: "streaming", + streamResult: { + fullStream: (async function* () { + await Promise.resolve(); + yield* [] as unknown[]; + })(), + totalUsage: Promise.resolve({ inputTokens: 1, outputTokens: 0, totalTokens: 1 }), + usage: Promise.resolve({ inputTokens: 1, outputTokens: 0, totalTokens: 1 }), + providerMetadata: Promise.resolve(undefined), + steps: Promise.resolve([]), + }, + abortController, + messageId, + token: "cleanup-abort-token", + startTime: Date.now(), + lastPartTimestamp: Date.now(), + toolCompletionTimestamps: new Map(), + model: "openai:gpt-4.1-mini", + metadataModel: "openai:gpt-4.1-mini", + historySequence, + request: { model, messages: [], providerOptions: undefined }, + toolModelUsages: [], + parts: [], + lastPartialWriteTime: 0, + partialWritePromise: undefined, + processingPromise: Promise.resolve(), + softInterrupt: { pending: false as const }, + runtimeTempDir: "", + runtime, + cumulativeUsage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + cumulativeProviderMetadata: undefined, + didRetryPreviousResponseIdAtStep: false, + currentStepStartIndex: 0, + stepTracker: {}, + }; + + await processStreamWithCleanup.call(streamManager, workspaceId, streamInfo, historySequence); + + expect(cleanupCalls).toBe(1); + }); +}); + describe("StreamManager - stripEncryptedContent", () => { test("strips encryptedContent from array output shape", () => { const output = [ @@ -1733,6 +2110,7 @@ describe("StreamManager - TTFT metadata persistence", () => { historySequence: params.historySequence, initialMetadata: params.initialMetadata, toolModelUsages: [], + request: { model: createTestLanguageModel(), messages: [], providerOptions: undefined }, parts: params.parts, lastPartialWriteTime: 0, partialWriteTimer: undefined, diff --git a/src/node/services/streamManager.ts b/src/node/services/streamManager.ts index d296647ce8..8ed4312190 100644 --- a/src/node/services/streamManager.ts +++ b/src/node/services/streamManager.ts @@ -60,6 +60,7 @@ import { withSequentialExecution } from "@/node/services/tools/withSequentialExe import type { ResolvedCallSettingsOverrides } from "@/common/config/schemas/modelParameters"; import { resolveModelForMetadata } from "@/common/utils/providers/modelEntries"; import { getErrorMessage } from "@/common/utils/errors"; +import { runLanguageModelCleanup } from "./languageModelCleanup"; import { shellQuote } from "@/common/utils/shell"; import { classify429Capacity } from "@/common/utils/errors/classify429Capacity"; import { normalizeLiteralRequiredToolPattern } from "@/common/utils/agentTools"; @@ -2334,6 +2335,8 @@ export class StreamManager extends EventEmitter { streamInfo.partialWriteTimer = undefined; } + runLanguageModelCleanup(streamInfo.request?.model); + streamInfo.unlinkAbortSignal?.(); streamInfo.unlinkAbortSignal = undefined; @@ -3013,6 +3016,7 @@ export class StreamManager extends EventEmitter { return Ok(streamToken); } finally { if (!streamRegistered) { + runLanguageModelCleanup(model); unlinkAbortSignal(); if (runtimeTempDir) { this.cleanupStreamTempDir(runtime, runtimeTempDir); diff --git a/src/node/services/workspaceTitleGenerator.test.ts b/src/node/services/workspaceTitleGenerator.test.ts index 71b75bcf08..089a6b08b2 100644 --- a/src/node/services/workspaceTitleGenerator.test.ts +++ b/src/node/services/workspaceTitleGenerator.test.ts @@ -1,10 +1,19 @@ -import { APICallError, NoOutputGeneratedError, RetryError } from "ai"; -import { describe, expect, test } from "bun:test"; +import * as aiSdk from "ai"; +import { APICallError, NoOutputGeneratedError, RetryError, type LanguageModel } from "ai"; +import { afterEach, describe, expect, mock, spyOn, test } from "bun:test"; import { buildWorkspaceIdentityPrompt, + generateWorkspaceIdentity, mapModelCreationError, mapNameGenerationError, } from "./workspaceTitleGenerator"; +import { Ok } from "@/common/types/result"; +import type { AIService } from "./aiService"; +import { attachLanguageModelCleanup } from "./languageModelCleanup"; + +afterEach(() => { + mock.restore(); +}); describe("buildWorkspaceIdentityPrompt", () => { test("includes overall-scope guidance, conversation turns, and latest-user context without precedence", () => { @@ -58,6 +67,139 @@ const createApiCallError = ( responseBody: overrides?.responseBody, }); +describe("generateWorkspaceIdentity cleanup", () => { + function createTitleModel(modelId = "title-model"): LanguageModel { + return { + specificationVersion: "v3", + provider: "test", + modelId, + supportedUrls: {}, + doGenerate: () => Promise.reject(new Error("doGenerate is unused in cleanup tests")), + doStream: () => Promise.reject(new Error("doStream is unused in cleanup tests")), + }; + } + + function createTitleAIService(model: LanguageModel): AIService { + return { createModel: () => Promise.resolve(Ok(model)) } as unknown as AIService; + } + + test("cleans up the model after a successful title stream", async () => { + let cleanupCalls = 0; + const model = createTitleModel(); + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + }); + const titleAiService = createTitleAIService(model); + + spyOn(aiSdk, "streamText").mockReturnValue({ + toolResults: Promise.resolve([ + { + dynamic: false, + toolName: "propose_name", + output: { name: "settings", title: "Add setting" }, + }, + ]), + } as unknown as ReturnType); + + const result = await generateWorkspaceIdentity( + "Add setting", + ["openai:gpt-4.1-mini"], + titleAiService + ); + + expect(result.success).toBe(true); + expect(cleanupCalls).toBe(1); + }); + + test("cleans up when title stream throws before trying the next candidate", async () => { + let cleanupCalls = 0; + const failingModel = createTitleModel("title-failing-model"); + attachLanguageModelCleanup(failingModel, () => { + cleanupCalls += 1; + }); + const aiService = createTitleAIService(failingModel); + + spyOn(aiSdk, "streamText").mockImplementation(() => { + throw new Error("title stream failed"); + }); + + const result = await generateWorkspaceIdentity( + "Add setting", + ["openai:gpt-4.1-mini"], + aiService + ); + + expect(result.success).toBe(false); + expect(cleanupCalls).toBe(1); + }); + + test("cleans up each candidate when title generation retries", async () => { + let firstCleanupCalls = 0; + let secondCleanupCalls = 0; + const firstModel = createTitleModel("title-first-model"); + const secondModel = createTitleModel("title-second-model"); + attachLanguageModelCleanup(firstModel, () => { + firstCleanupCalls += 1; + }); + attachLanguageModelCleanup(secondModel, () => { + secondCleanupCalls += 1; + }); + const aiService = { + createModel: mock((modelString: string) => + Promise.resolve(Ok(modelString.includes("first") ? firstModel : secondModel)) + ), + } as unknown as AIService; + let streamTextCalls = 0; + spyOn(aiSdk, "streamText").mockImplementation((() => { + streamTextCalls += 1; + if (streamTextCalls === 1) { + throw new Error("first candidate failed"); + } + return { + toolResults: Promise.resolve([ + { + dynamic: false, + toolName: "propose_name", + output: { name: "settings", title: "Add setting" }, + }, + ]), + } as unknown as ReturnType; + }) as unknown as typeof aiSdk.streamText); + + const result = await generateWorkspaceIdentity( + "Add setting", + ["openai:first", "openai:second"], + aiService + ); + + expect(result.success).toBe(true); + expect(firstCleanupCalls).toBe(1); + expect(secondCleanupCalls).toBe(1); + }); + + test("cleans up when title stream returns no propose_name result", async () => { + let cleanupCalls = 0; + const model = createTitleModel("title-no-tool-model"); + attachLanguageModelCleanup(model, () => { + cleanupCalls += 1; + }); + const aiService = createTitleAIService(model); + + spyOn(aiSdk, "streamText").mockReturnValue({ + toolResults: Promise.resolve([]), + } as unknown as ReturnType); + + const result = await generateWorkspaceIdentity( + "Add setting", + ["openai:gpt-4.1-mini"], + aiService + ); + + expect(result.success).toBe(false); + expect(cleanupCalls).toBe(1); + }); +}); + describe("workspaceTitleGenerator error mappers", () => { describe("mapNameGenerationError", () => { test("preserves provider context for auth and permission API failures", () => { diff --git a/src/node/services/workspaceTitleGenerator.ts b/src/node/services/workspaceTitleGenerator.ts index 35b933f7fb..0cea63887d 100644 --- a/src/node/services/workspaceTitleGenerator.ts +++ b/src/node/services/workspaceTitleGenerator.ts @@ -7,6 +7,7 @@ import type { NameGenerationError, SendMessageError } from "@/common/types/error import { getErrorMessage } from "@/common/utils/errors"; import { classify429Capacity } from "@/common/utils/errors/classify429Capacity"; import { TOOL_DEFINITIONS, ProposeNameToolArgsSchema } from "@/common/utils/tools/toolDefinitions"; +import { runLanguageModelCleanup } from "./languageModelCleanup"; import crypto from "crypto"; export interface WorkspaceIdentity { @@ -272,6 +273,8 @@ export async function generateWorkspaceIdentity( lastError = mapNameGenerationError(error, modelString); log.warn("Name generation failed, trying next candidate", { modelString, error: lastError }); continue; + } finally { + runLanguageModelCleanup(modelResult.data); } }