diff --git a/apps/web/src/rpc/protocol.ts b/apps/web/src/rpc/protocol.ts index 5c5c51d8990..0a4288df391 100644 --- a/apps/web/src/rpc/protocol.ts +++ b/apps/web/src/rpc/protocol.ts @@ -18,6 +18,7 @@ import { } from "./wsConnectionState"; export interface WsProtocolLifecycleHandlers { + readonly isActive?: () => boolean; readonly onAttempt?: (socketUrl: string) => void; readonly onOpen?: () => void; readonly onError?: (message: string) => void; @@ -49,6 +50,7 @@ function resolveWsRpcSocketUrl(rawUrl: string): string { function defaultLifecycleHandlers(): Required { return { + isActive: () => true, onAttempt: recordWsConnectionAttempt, onOpen: recordWsConnectionOpened, onError: (message) => { @@ -66,21 +68,35 @@ function composeLifecycleHandlers( handlers?: WsProtocolLifecycleHandlers, ): Required { const defaults = defaultLifecycleHandlers(); + const isActive = handlers?.isActive ?? (() => true); return { + isActive, onAttempt: (socketUrl) => { + if (!isActive()) { + return; + } defaults.onAttempt(socketUrl); handlers?.onAttempt?.(socketUrl); }, onOpen: () => { + if (!isActive()) { + return; + } defaults.onOpen(); handlers?.onOpen?.(); }, onError: (message) => { + if (!isActive()) { + return; + } defaults.onError(message); handlers?.onError?.(message); }, onClose: (details) => { + if (!isActive()) { + return; + } defaults.onClose(details); handlers?.onClose?.(details); }, diff --git a/apps/web/src/rpc/wsTransport.test.ts b/apps/web/src/rpc/wsTransport.test.ts index a5610a931f6..c9632d3f78d 100644 --- a/apps/web/src/rpc/wsTransport.test.ts +++ b/apps/web/src/rpc/wsTransport.test.ts @@ -324,6 +324,11 @@ describe("WsTransport", () => { const secondSocket = getSocket(); expect(secondSocket).not.toBe(firstSocket); expect(firstSocket.readyState).toBe(MockWebSocket.CLOSED); + expect(getWsConnectionStatus()).toMatchObject({ + closeCode: null, + closeReason: null, + phase: "connecting", + }); const requestPromise = transport.request((client) => client[WS_METHODS.serverUpsertKeybinding]({ @@ -361,6 +366,58 @@ describe("WsTransport", () => { await transport.dispose(); }); + it("ignores stale socket lifecycle events after a reconnect starts a new session", async () => { + const onClose = vi.fn(); + const transport = createTransport("ws://localhost:3020", { onClose }); + + await waitFor(() => { + expect(sockets).toHaveLength(1); + }); + + const firstSocket = getSocket(); + firstSocket.open(); + + await waitFor(() => { + expect(getWsConnectionStatus()).toMatchObject({ + hasConnected: true, + phase: "connected", + }); + }); + + await transport.reconnect(); + + await waitFor(() => { + expect(sockets).toHaveLength(2); + }); + + expect(onClose).not.toHaveBeenCalled(); + expect(getWsConnectionStatus()).toMatchObject({ + closeCode: null, + closeReason: null, + phase: "connecting", + }); + + const secondSocket = getSocket(); + secondSocket.open(); + + await waitFor(() => { + expect(getWsConnectionStatus()).toMatchObject({ + phase: "connected", + }); + }); + + firstSocket.close(1006, "stale close"); + + expect(onClose).not.toHaveBeenCalled(); + expect(getWsConnectionStatus()).toMatchObject({ + closeCode: null, + closeReason: null, + phase: "connected", + }); + + await transport.dispose(); + }); + it("marks unary requests as slow until the first server ack arrives", async () => { const slowAckThresholdMs = 25; setSlowRpcAckThresholdMsForTests(slowAckThresholdMs); diff --git a/apps/web/src/rpc/wsTransport.ts b/apps/web/src/rpc/wsTransport.ts index d9a50a9fad0..0b90b22431f 100644 --- a/apps/web/src/rpc/wsTransport.ts +++ b/apps/web/src/rpc/wsTransport.ts @@ -53,6 +53,8 @@ export class WsTransport { private disposed = false; private hasReportedTransportDisconnect = false; private reconnectChain: Promise = Promise.resolve(); + private nextSessionId = 0; + private activeSessionId = 0; private session: TransportSession; constructor( @@ -215,8 +217,17 @@ export class WsTransport { } private createSession(): TransportSession { + const sessionId = this.nextSessionId + 1; + this.nextSessionId = sessionId; + this.activeSessionId = sessionId; const runtime = ManagedRuntime.make( - Layer.mergeAll(createWsRpcProtocolLayer(this.url, this.lifecycleHandlers), ClientTracingLive), + Layer.mergeAll( + createWsRpcProtocolLayer(this.url, { + ...this.lifecycleHandlers, + isActive: () => !this.disposed && this.activeSessionId === sessionId, + }), + ClientTracingLive, + ), ); const clientScope = runtime.runSync(Scope.make()); return {