From 0e3d6b8d0133b759b898dede07f594512adc4212 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Mon, 18 May 2026 23:49:38 -0400 Subject: [PATCH 1/2] =?UTF-8?q?fix(gateway):=20=E9=9A=94=E7=A6=BB=20Web=20?= =?UTF-8?q?=E5=B7=A5=E4=BD=9C=E5=8C=BA=E4=BC=9A=E8=AF=9D=E7=BB=91=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/gateway/bootstrap.go | 47 +++++++++++--- internal/gateway/bootstrap_test.go | 55 +++++++++++++++- internal/gateway/registry.go | 4 +- internal/gateway/request_logging.go | 22 ++++--- internal/gateway/rpc_dispatch.go | 5 +- internal/gateway/stream_relay.go | 62 ++++++++++++------- internal/gateway/stream_relay_test.go | 43 +++++++++++++ web/src/components/chat/ChatInput.test.tsx | 16 +++++ web/src/components/chat/ChatInput.tsx | 8 ++- .../RuntimeProvider.lifecycle.test.tsx | 41 ++++++++++++ web/src/context/RuntimeProvider.tsx | 52 +++++++++++----- 11 files changed, 294 insertions(+), 61 deletions(-) diff --git a/internal/gateway/bootstrap.go b/internal/gateway/bootstrap.go index 309053620..40b378b02 100644 --- a/internal/gateway/bootstrap.go +++ b/internal/gateway/bootstrap.go @@ -108,7 +108,7 @@ func handleAuthenticateFrame(ctx context.Context, frame MessageFrame) MessageFra } // handleBindStreamFrame 处理 gateway.bindStream 并注册连接订阅关系。 -func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame { +func handleBindStreamFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { params, err := decodeBindStreamParams(frame.Payload) if err != nil { return errorFrame(frame, err) @@ -120,13 +120,18 @@ func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame return errorFrame(frame, NewFrameError(ErrorCodeInternalError, "stream relay context is unavailable")) } + if validationFrame := validateBindStreamSession(ctx, frame, runtimePort, params.SessionID); validationFrame != nil { + return *validationFrame + } + if bindErr := relay.BindConnection(connectionID, StreamBinding{ - SessionID: params.SessionID, - RunID: params.RunID, - Channel: params.Channel, - Role: params.Role, - State: cloneMapValue(params.State), - Explicit: true, + SessionID: params.SessionID, + RunID: params.RunID, + WorkspaceHash: WorkspaceHashFromContext(ctx), + Channel: params.Channel, + Role: params.Role, + State: cloneMapValue(params.State), + Explicit: true, }); bindErr != nil { return errorFrame(frame, bindErr) } @@ -146,6 +151,34 @@ func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame } } +// validateBindStreamSession 确认事件流绑定的会话在当前工作区 runtime 中可见。 +func validateBindStreamSession( + ctx context.Context, + frame MessageFrame, + runtimePort RuntimePort, + sessionID string, +) *MessageFrame { + if runtimePort == nil { + return nil + } + normalizedSessionID := strings.TrimSpace(sessionID) + if normalizedSessionID == "" { + return nil + } + + callCtx, cancel := withRuntimeOperationTimeout(ctx) + defer cancel() + _, err := runtimePort.LoadSession(callCtx, LoadSessionInput{ + SubjectID: AuthenticatedSubjectIDFromContext(ctx), + SessionID: normalizedSessionID, + }) + if err == nil { + return nil + } + failedFrame := runtimeCallFailedFrame(callCtx, frame, err, "bind_stream") + return &failedFrame +} + // handleAskFrame 处理 gateway.ask 请求,并以异步方式转发到底层 Ask 编排能力。 func handleAskFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { if runtimePort == nil { diff --git a/internal/gateway/bootstrap_test.go b/internal/gateway/bootstrap_test.go index e5551f5fd..b36ca7984 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -1232,7 +1232,7 @@ func TestHandleBindStreamFrameErrors(t *testing.T) { Payload: protocol.BindStreamParams{ SessionID: "session-1", }, - }) + }, nil) if response.Type != FrameTypeError { t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) } @@ -1271,7 +1271,7 @@ func TestHandleBindStreamFrameErrors(t *testing.T) { SessionID: "session-1", Channel: "ipc", }, - }) + }, nil) if response.Type != FrameTypeError { t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) } @@ -1281,6 +1281,57 @@ func TestHandleBindStreamFrameErrors(t *testing.T) { }) } +func TestHandleBindStreamFrameRejectsSessionOutsideCurrentWorkspace(t *testing.T) { + relay := NewStreamRelay(StreamRelayOptions{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connectionID := NewConnectionID() + workspaceState := NewConnectionWorkspaceState() + workspaceState.SetWorkspaceHash("workspace-b") + connectionCtx := WithConnectionID(ctx, connectionID) + connectionCtx = WithConnectionWorkspaceState(connectionCtx, workspaceState) + connectionCtx = WithStreamRelay(connectionCtx, relay) + if err := relay.RegisterConnection(ConnectionRegistration{ + ConnectionID: connectionID, + Channel: StreamChannelIPC, + Context: connectionCtx, + Cancel: cancel, + Write: func(message RelayMessage) error { + _ = message + return nil + }, + Close: func() {}, + }); err != nil { + t.Fatalf("register connection: %v", err) + } + defer relay.dropConnection(connectionID) + + runtimeStub := &bootstrapRuntimeStub{ + loadSessionFn: func(context.Context, LoadSessionInput) (Session, error) { + return Session{}, ErrRuntimeResourceNotFound + }, + } + response := handleBindStreamFrame(connectionCtx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionBindStream, + RequestID: "bind-cross-workspace", + Payload: protocol.BindStreamParams{ + SessionID: "session-from-workspace-a", + Channel: "all", + }, + }, runtimeStub) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeResourceNotFound.String() { + t.Fatalf("response error = %#v, want resource_not_found", response.Error) + } + if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-b"); got != "" { + t.Fatalf("binding should not be written after validation failure, got fallback %q", got) + } +} + func TestHandleTriggerActionFrame(t *testing.T) { registerConnection := func( t *testing.T, diff --git a/internal/gateway/registry.go b/internal/gateway/registry.go index 518bbcedc..853a665f3 100644 --- a/internal/gateway/registry.go +++ b/internal/gateway/registry.go @@ -38,9 +38,7 @@ func (r *ActionRegistry) initCore() { r.core[FrameActionPing] = func(ctx context.Context, frame MessageFrame, _ RuntimePort) MessageFrame { return handlePingFrame(ctx, frame) } - r.core[FrameActionBindStream] = func(ctx context.Context, frame MessageFrame, _ RuntimePort) MessageFrame { - return handleBindStreamFrame(ctx, frame) - } + r.core[FrameActionBindStream] = handleBindStreamFrame r.core[FrameActionAsk] = handleAskFrame r.core[FrameActionDeleteAskSession] = handleDeleteAskSessionFrame r.core[FrameActionTriggerAction] = handleTriggerActionFrame diff --git a/internal/gateway/request_logging.go b/internal/gateway/request_logging.go index 22a4c23c4..74cfcbcfd 100644 --- a/internal/gateway/request_logging.go +++ b/internal/gateway/request_logging.go @@ -12,15 +12,16 @@ import ( // RequestLogEntry 表示统一结构化请求日志字段。 type RequestLogEntry struct { - RequestID string `json:"request_id"` - SessionID string `json:"session_id"` - Method string `json:"method"` - Source string `json:"source"` - Status string `json:"status"` - GatewayCode string `json:"gateway_code,omitempty"` - LatencyMS int64 `json:"latency_ms"` - ConnectionID string `json:"connection_id,omitempty"` - AuthState string `json:"auth_state,omitempty"` + RequestID string `json:"request_id"` + SessionID string `json:"session_id"` + Method string `json:"method"` + Source string `json:"source"` + Status string `json:"status"` + WorkspaceHash string `json:"workspace_hash,omitempty"` + GatewayCode string `json:"gateway_code,omitempty"` + LatencyMS int64 `json:"latency_ms"` + ConnectionID string `json:"connection_id,omitempty"` + AuthState string `json:"auth_state,omitempty"` } // emitRequestLog 输出网关结构化日志。 @@ -37,6 +38,9 @@ func emitRequestLog(ctx context.Context, logger *log.Logger, entry RequestLogEnt if connectionID, ok := ConnectionIDFromContext(ctx); ok { entry.ConnectionID = string(connectionID) } + if entry.WorkspaceHash == "" { + entry.WorkspaceHash = WorkspaceHashFromContext(ctx) + } if authState, ok := ConnectionAuthStateFromContext(ctx); ok && authState.IsAuthenticated() { entry.AuthState = "authenticated" } else if _, ok := TokenAuthenticatorFromContext(ctx); ok { diff --git a/internal/gateway/rpc_dispatch.go b/internal/gateway/rpc_dispatch.go index 0256ce01b..0360b000d 100644 --- a/internal/gateway/rpc_dispatch.go +++ b/internal/gateway/rpc_dispatch.go @@ -277,7 +277,10 @@ func hydrateFrameSessionFromConnection(ctx context.Context, frame MessageFrame) return frame } - frame.SessionID = strings.TrimSpace(relay.ResolveFallbackSessionID(connectionID)) + frame.SessionID = strings.TrimSpace(relay.ResolveFallbackSessionIDForWorkspace( + connectionID, + WorkspaceHashFromContext(ctx), + )) return frame } diff --git a/internal/gateway/stream_relay.go b/internal/gateway/stream_relay.go index 87616b65c..11374fb52 100644 --- a/internal/gateway/stream_relay.go +++ b/internal/gateway/stream_relay.go @@ -50,12 +50,13 @@ type ConnectionRegistration struct { // StreamBinding 描述连接绑定到会话路由表的一条订阅关系。 type StreamBinding struct { - SessionID string - RunID string - Channel StreamChannel - Role StreamRole - State map[string]any - Explicit bool + SessionID string + RunID string + WorkspaceHash string + Channel StreamChannel + Role StreamRole + State map[string]any + Explicit bool } // StreamRelayOptions 描述会话路由与流式中继的可选配置。 @@ -87,14 +88,15 @@ type bindingKey struct { } type bindingState struct { - sessionID string - runID string - channel StreamChannel - role StreamRole - state map[string]any - explicit bool - expireAt time.Time - lastSeen time.Time + sessionID string + runID string + workspaceHash string + channel StreamChannel + role StreamRole + state map[string]any + explicit bool + expireAt time.Time + lastSeen time.Time } // StreamRelay 维护连接-会话-运行态映射,并负责运行事件的精确中继。 @@ -405,6 +407,11 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi return NewFrameError(ErrorCodeInvalidAction, "bind channel does not match connection channel") } + workspaceHash := strings.TrimSpace(binding.WorkspaceHash) + if workspaceHash == "" { + workspaceHash = WorkspaceHashFromContext(connection.ctx) + } + key := bindingKey{sessionID: sessionID, runID: runID} connectionBindingMap := r.connectionBindings[normalizedConnectionID] if connectionBindingMap == nil { @@ -424,14 +431,15 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi return NewFrameError(ErrorCodeInvalidAction, "too many stream bindings for connection") } connectionBindingMap[key] = &bindingState{ - sessionID: sessionID, - runID: runID, - channel: channel, - role: role, - state: state, - explicit: binding.Explicit, - expireAt: now.Add(r.bindingTTL), - lastSeen: now, + sessionID: sessionID, + runID: runID, + workspaceHash: workspaceHash, + channel: channel, + role: role, + state: state, + explicit: binding.Explicit, + expireAt: now.Add(r.bindingTTL), + lastSeen: now, } r.addConnectionToSessionIndexLocked(sessionID, normalizedConnectionID) if runID != "" { @@ -443,6 +451,11 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi // ResolveFallbackSessionID 返回连接当前可用绑定中的会话兜底值(取最近续期的绑定)。 func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string { + return r.ResolveFallbackSessionIDForWorkspace(connectionID, "") +} + +// ResolveFallbackSessionIDForWorkspace 返回指定工作区内最近续期的连接兜底会话。 +func (r *StreamRelay) ResolveFallbackSessionIDForWorkspace(connectionID ConnectionID, workspaceHash string) string { if r == nil { return "" } @@ -453,6 +466,7 @@ func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string } now := time.Now() + normalizedWorkspaceHash := strings.TrimSpace(workspaceHash) r.mu.RLock() connectionBindingMap := r.connectionBindings[normalizedConnectionID] @@ -464,6 +478,10 @@ func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string if state == nil || state.expireAt.Before(now) { continue } + if normalizedWorkspaceHash != "" && + !strings.EqualFold(strings.TrimSpace(state.workspaceHash), normalizedWorkspaceHash) { + continue + } if state.lastSeen.After(latestSeen) { latestSeen = state.lastSeen latestSessionID = state.sessionID diff --git a/internal/gateway/stream_relay_test.go b/internal/gateway/stream_relay_test.go index 2e980bcfd..3dd430873 100644 --- a/internal/gateway/stream_relay_test.go +++ b/internal/gateway/stream_relay_test.go @@ -47,6 +47,49 @@ func TestStreamRelayBindAndFallbackSession(t *testing.T) { } } +func TestStreamRelayFallbackSessionIsWorkspaceScoped(t *testing.T) { + relay := NewStreamRelay(StreamRelayOptions{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connectionID := NewConnectionID() + workspaceState := NewConnectionWorkspaceState() + workspaceState.SetWorkspaceHash("workspace-a") + connectionCtx := WithConnectionID(ctx, connectionID) + connectionCtx = WithConnectionWorkspaceState(connectionCtx, workspaceState) + connectionCtx = WithStreamRelay(connectionCtx, relay) + if err := relay.RegisterConnection(ConnectionRegistration{ + ConnectionID: connectionID, + Channel: StreamChannelIPC, + Context: connectionCtx, + Cancel: cancel, + Write: func(message RelayMessage) error { + _ = message + return nil + }, + Close: func() {}, + }); err != nil { + t.Fatalf("register connection: %v", err) + } + defer relay.dropConnection(connectionID) + + if bindErr := relay.BindConnection(connectionID, StreamBinding{ + SessionID: "session-a", + Channel: StreamChannelAll, + Explicit: true, + }); bindErr != nil { + t.Fatalf("bind workspace-a: %v", bindErr) + } + + workspaceState.SetWorkspaceHash("workspace-b") + if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-b"); got != "" { + t.Fatalf("workspace-b fallback session id = %q, want empty", got) + } + if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-a"); got != "session-a" { + t.Fatalf("workspace-a fallback session id = %q, want session-a", got) + } +} + func TestStreamRelayPublishRuntimeEventNoCrossSession(t *testing.T) { relay := NewStreamRelay(StreamRelayOptions{}) ctx, cancel := context.WithCancel(context.Background()) diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index e1993b3f2..17f97f477 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -5,6 +5,7 @@ import { useChatStore } from '@/stores/useChatStore' import { useComposerStore } from '@/stores/useComposerStore' import { useSessionStore } from '@/stores/useSessionStore' import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' +import { useGatewayStore } from '@/stores/useGatewayStore' const mockGatewayAPI = { listAvailableSkills: vi.fn(), @@ -70,6 +71,7 @@ describe('ChatInput', () => { useComposerStore.setState({ composerText: '' }) useSessionStore.setState({ currentSessionId: '' } as never) + useGatewayStore.setState({ currentRunId: '' } as never) useRuntimeInsightStore.getState().reset() useChatStore.setState({ isGenerating: false, @@ -335,4 +337,18 @@ describe('ChatInput', () => { expect(ring).toHaveAttribute('stroke', 'var(--error)') }) + + it('sends session id when cancelling an active run', async () => { + useSessionStore.setState({ currentSessionId: 'session-1' } as never) + useGatewayStore.setState({ currentRunId: 'run-1' } as never) + useChatStore.setState({ isGenerating: true } as never) + mockGatewayAPI.cancel.mockResolvedValueOnce({ payload: { canceled: true, run_id: 'run-1' } }) + render() + + fireEvent.click(screen.getByTitle('停止生成')) + + await waitFor(() => { + expect(mockGatewayAPI.cancel).toHaveBeenCalledWith({ session_id: 'session-1', run_id: 'run-1' }) + }) + }) }) diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index 834e322d4..9997c51a2 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -424,9 +424,15 @@ export default function ChatInput() { async function handleCancel() { runCancelledRef.current = true const runId = useGatewayStore.getState().currentRunId + const currentSessionId = useSessionStore.getState().currentSessionId if (runId && gatewayAPI) { + if (!isValidSessionId(currentSessionId)) { + useUIStore.getState().showToast('Cannot cancel run without an active session', 'error') + useChatStore.getState().resetGeneratingState() + return + } try { - await gatewayAPI.cancel({ run_id: runId }) + await gatewayAPI.cancel({ session_id: currentSessionId, run_id: runId }) } catch (err) { console.error('Cancel failed:', err) } diff --git a/web/src/context/RuntimeProvider.lifecycle.test.tsx b/web/src/context/RuntimeProvider.lifecycle.test.tsx index a21b9d551..73e305551 100644 --- a/web/src/context/RuntimeProvider.lifecycle.test.tsx +++ b/web/src/context/RuntimeProvider.lifecycle.test.tsx @@ -41,6 +41,7 @@ vi.mock('@/api/wsClient', () => ({ } }), _emitState: (s: any) => onState?.(s), + _emitReconnect: () => onReconnect?.(), } clients.push(client) return client @@ -149,5 +150,45 @@ describe('RuntimeProvider lifecycle', () => { expect(chatClear).toHaveBeenCalled() expect(runtimeSnapshot.status).toBe('needs_config') }) + + it('restores workspace context before rebinding session on reconnect', async () => { + sessionStorage.setItem( + 'neocode.browserRuntimeConfig', + JSON.stringify({ mode: 'browser', gatewayBaseURL: 'http://127.0.0.1:8080', token: 'tok' }), + ) + useWorkspaceStore.setState({ + fetchWorkspaces: vi.fn().mockResolvedValue(undefined), + workspaces: [{ hash: 'w2', path: '/workspace-two', name: 'Two', createdAt: '', updatedAt: '' }], + currentWorkspaceHash: 'w2', + } as any) + useSessionStore.setState({ + ...useSessionStore.getState(), + currentSessionId: 'session-2', + fetchSessions: vi.fn().mockResolvedValue(undefined), + } as any) + + let runtimeSnapshot: any = null + render( + + { runtimeSnapshot = rt }} /> + , + ) + await waitFor(() => expect(runtimeSnapshot?.status).toBe('connected')) + const client = clients[0] + client.call.mockClear() + + await act(async () => { + await client._emitReconnect() + }) + + const methods = client.call.mock.calls.map((call: any[]) => call[0]) + const switchIndex = methods.indexOf('gateway.switchWorkspace') + const bindIndex = methods.indexOf('gateway.bindStream') + expect(switchIndex).toBeGreaterThanOrEqual(0) + expect(bindIndex).toBeGreaterThanOrEqual(0) + expect(switchIndex).toBeLessThan(bindIndex) + expect(client.call).toHaveBeenCalledWith('gateway.switchWorkspace', { workspace_hash: 'w2' }) + expect(client.call).toHaveBeenCalledWith('gateway.bindStream', { session_id: 'session-2', channel: 'all' }) + }) }) diff --git a/web/src/context/RuntimeProvider.tsx b/web/src/context/RuntimeProvider.tsx index 2486936ae..93c059f5c 100644 --- a/web/src/context/RuntimeProvider.tsx +++ b/web/src/context/RuntimeProvider.tsx @@ -6,6 +6,7 @@ import { useGatewayStore } from '@/stores/useGatewayStore' import { useSessionStore, isValidSessionId } from '@/stores/useSessionStore' import { useUIStore } from '@/stores/useUIStore' import { useWorkspaceStore } from '@/stores/useWorkspaceStore' +import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' import { handleGatewayEvent } from '@/utils/eventBridge' const browserRuntimeStorageKey = 'neocode.browserRuntimeConfig' @@ -67,6 +68,34 @@ async function refreshPendingUserQuestion(gatewayAPI: GatewayAPI, sessionId: str } } +/** syncWorkspaceContext 将前端选中的工作区恢复到当前 Gateway 连接上下文。 */ +async function syncWorkspaceContext(gatewayAPI: GatewayAPI): Promise { + const workspaceStore = useWorkspaceStore.getState() + await workspaceStore.fetchWorkspaces(gatewayAPI) + + const nextState = useWorkspaceStore.getState() + if (nextState.workspaces.length === 0) return false + + const currentHash = nextState.currentWorkspaceHash.trim() + const target = + nextState.workspaces.find((workspace) => workspace.hash === currentHash) ?? + nextState.workspaces[0] + if (!target?.hash) return false + + if (target.hash !== currentHash) { + useWorkspaceStore.getState().setCurrentWorkspaceHash(target.hash) + useChatStore.getState().clearMessages() + useSessionStore.getState().setCurrentSessionId('') + useSessionStore.getState().setCurrentProjectId('') + useGatewayStore.getState().setCurrentRunId('') + useRuntimeInsightStore.getState().reset() + useUIStore.getState().clearFileChanges() + } + + await gatewayAPI.switchWorkspace(target.hash) + return true +} + /** RuntimeProvider 装配前端运行时,并为业务组件提供当前 Gateway 客户端。 */ export function RuntimeProvider({ children }: { children: ReactNode }) { const mode = useMemo(detectRuntimeMode, []) @@ -148,21 +177,12 @@ export function RuntimeProvider({ children }: { children: ReactNode }) { await api.authenticate(nextConfig.token) useGatewayStore.getState().setAuthenticated(true) - // Re-bind stream for current session (skip temporary IDs) - const sessionId = useSessionStore.getState().currentSessionId - if (isValidSessionId(sessionId)) { - await api.bindStream({ session_id: sessionId, channel: 'all' }) - } - - // Refresh workspace list (best-effort) and session list - try { - await useWorkspaceStore.getState().fetchWorkspaces(api) - } catch (workspaceErr) { - console.warn('[RuntimeProvider] reconnect fetchWorkspaces failed:', workspaceErr) - } - - const hasWorkspaces = useWorkspaceStore.getState().workspaces.length > 0 + const hasWorkspaces = await syncWorkspaceContext(api) if (hasWorkspaces) { + const sessionId = useSessionStore.getState().currentSessionId + if (isValidSessionId(sessionId)) { + await api.bindStream({ session_id: sessionId, channel: 'all' }) + } await useSessionStore.getState().fetchSessions(api, true) } await refreshPendingUserQuestion(api, useSessionStore.getState().currentSessionId) @@ -194,14 +214,14 @@ export function RuntimeProvider({ children }: { children: ReactNode }) { useGatewayStore.getState().setAuthenticated(true) // Fetch workspaces (best-effort; gracefully degrades if backend not upgraded) + let hasWorkspaces = false try { - await useWorkspaceStore.getState().fetchWorkspaces(api) + hasWorkspaces = await syncWorkspaceContext(api) } catch (workspaceErr) { console.warn('[RuntimeProvider] fetchWorkspaces failed, falling back to single workspace:', workspaceErr) } // Fetch sessions and initialize only when workspaces exist - const hasWorkspaces = useWorkspaceStore.getState().workspaces.length > 0 if (hasWorkspaces) { await useSessionStore.getState().fetchSessions(api, true) await useSessionStore.getState().initializeActiveSession(api) From 82a77f7cdc0341fe63861b88947b1aed54706464 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Tue, 19 May 2026 02:38:02 -0400 Subject: [PATCH 2/2] fix(web): handle stale session state on reconnect --- internal/gateway/bootstrap_test.go | 53 +++++++++++++++ web/src/components/chat/ChatInput.test.tsx | 30 +++++++++ web/src/components/chat/ChatInput.tsx | 13 ++-- .../RuntimeProvider.lifecycle.test.tsx | 64 ++++++++++++++++++- web/src/context/RuntimeProvider.tsx | 28 ++++++-- web/src/stores/useSessionStore.ts | 8 ++- 6 files changed, 182 insertions(+), 14 deletions(-) diff --git a/internal/gateway/bootstrap_test.go b/internal/gateway/bootstrap_test.go index b36ca7984..441372589 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -1332,6 +1332,59 @@ func TestHandleBindStreamFrameRejectsSessionOutsideCurrentWorkspace(t *testing.T } } +func TestHandleBindStreamFrameValidatesVisibleSessionBeforeBinding(t *testing.T) { + relay := NewStreamRelay(StreamRelayOptions{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connectionID := NewConnectionID() + workspaceState := NewConnectionWorkspaceState() + workspaceState.SetWorkspaceHash("workspace-a") + connectionCtx := WithConnectionID(ctx, connectionID) + connectionCtx = WithConnectionWorkspaceState(connectionCtx, workspaceState) + connectionCtx = WithStreamRelay(connectionCtx, relay) + if err := relay.RegisterConnection(ConnectionRegistration{ + ConnectionID: connectionID, + Channel: StreamChannelIPC, + Context: connectionCtx, + Cancel: cancel, + Write: func(message RelayMessage) error { + _ = message + return nil + }, + Close: func() {}, + }); err != nil { + t.Fatalf("register connection: %v", err) + } + defer relay.dropConnection(connectionID) + + var loaded LoadSessionInput + runtimeStub := &bootstrapRuntimeStub{ + loadSessionFn: func(_ context.Context, input LoadSessionInput) (Session, error) { + loaded = input + return Session{ID: input.SessionID}, nil + }, + } + response := handleBindStreamFrame(connectionCtx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionBindStream, + RequestID: "bind-visible-session", + Payload: protocol.BindStreamParams{ + SessionID: "session-visible", + Channel: "all", + }, + }, runtimeStub) + if response.Type != FrameTypeAck { + t.Fatalf("response type = %q, want %q: %#v", response.Type, FrameTypeAck, response.Error) + } + if loaded.SessionID != "session-visible" { + t.Fatalf("validated session_id = %q, want %q", loaded.SessionID, "session-visible") + } + if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-a"); got != "session-visible" { + t.Fatalf("fallback session = %q, want %q", got, "session-visible") + } +} + func TestHandleTriggerActionFrame(t *testing.T) { registerConnection := func( t *testing.T, diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index 17f97f477..53829db08 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -351,4 +351,34 @@ describe('ChatInput', () => { expect(mockGatewayAPI.cancel).toHaveBeenCalledWith({ session_id: 'session-1', run_id: 'run-1' }) }) }) + + it('falls back to run id only when cancelling without an active session', async () => { + useSessionStore.setState({ currentSessionId: '' } as never) + useGatewayStore.setState({ currentRunId: 'run-1' } as never) + useChatStore.setState({ isGenerating: true } as never) + mockGatewayAPI.cancel.mockResolvedValueOnce({ payload: { canceled: true, run_id: 'run-1' } }) + render() + + fireEvent.click(screen.getByTitle(/停止生成/)) + + await waitFor(() => { + expect(mockGatewayAPI.cancel).toHaveBeenCalledWith({ run_id: 'run-1' }) + }) + }) + + it('does not reset generating state when no cancel request is sent', async () => { + const resetGeneratingState = vi.spyOn(useChatStore.getState(), 'resetGeneratingState') + useSessionStore.setState({ currentSessionId: '' } as never) + useGatewayStore.setState({ currentRunId: '' } as never) + useChatStore.setState({ isGenerating: true } as never) + render() + + fireEvent.click(screen.getByTitle(/停止生成/)) + + await waitFor(() => { + expect(mockGatewayAPI.cancel).not.toHaveBeenCalled() + }) + expect(resetGeneratingState).not.toHaveBeenCalled() + resetGeneratingState.mockRestore() + }) }) diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index 9997c51a2..6291a702b 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -426,18 +426,17 @@ export default function ChatInput() { const runId = useGatewayStore.getState().currentRunId const currentSessionId = useSessionStore.getState().currentSessionId if (runId && gatewayAPI) { - if (!isValidSessionId(currentSessionId)) { - useUIStore.getState().showToast('Cannot cancel run without an active session', 'error') - useChatStore.getState().resetGeneratingState() - return - } try { - await gatewayAPI.cancel({ session_id: currentSessionId, run_id: runId }) + const cancelParams = isValidSessionId(currentSessionId) + ? { session_id: currentSessionId, run_id: runId } + : { run_id: runId } + await gatewayAPI.cancel(cancelParams) + useChatStore.getState().resetGeneratingState() } catch (err) { console.error('Cancel failed:', err) } + return } - useChatStore.getState().resetGeneratingState() } const isEmpty = !text.trim() diff --git a/web/src/context/RuntimeProvider.lifecycle.test.tsx b/web/src/context/RuntimeProvider.lifecycle.test.tsx index 73e305551..2a385d0ed 100644 --- a/web/src/context/RuntimeProvider.lifecycle.test.tsx +++ b/web/src/context/RuntimeProvider.lifecycle.test.tsx @@ -164,7 +164,14 @@ describe('RuntimeProvider lifecycle', () => { useSessionStore.setState({ ...useSessionStore.getState(), currentSessionId: 'session-2', - fetchSessions: vi.fn().mockResolvedValue(undefined), + fetchSessions: vi.fn(async (gatewayAPI: any) => { + await gatewayAPI.listSessions() + }), + projects: [{ + id: 'group_today', + name: 'Today', + sessions: [{ id: 'session-2', title: 'Two', time: new Date(0).toISOString() }], + }], } as any) let runtimeSnapshot: any = null @@ -183,12 +190,65 @@ describe('RuntimeProvider lifecycle', () => { const methods = client.call.mock.calls.map((call: any[]) => call[0]) const switchIndex = methods.indexOf('gateway.switchWorkspace') + const fetchIndex = methods.indexOf('gateway.listSessions') const bindIndex = methods.indexOf('gateway.bindStream') expect(switchIndex).toBeGreaterThanOrEqual(0) + expect(fetchIndex).toBeGreaterThanOrEqual(0) expect(bindIndex).toBeGreaterThanOrEqual(0) - expect(switchIndex).toBeLessThan(bindIndex) + expect(switchIndex).toBeLessThan(fetchIndex) + expect(fetchIndex).toBeLessThan(bindIndex) expect(client.call).toHaveBeenCalledWith('gateway.switchWorkspace', { workspace_hash: 'w2' }) expect(client.call).toHaveBeenCalledWith('gateway.bindStream', { session_id: 'session-2', channel: 'all' }) }) + + it('recovers reconnect when rebinding a stale session fails', async () => { + sessionStorage.setItem( + 'neocode.browserRuntimeConfig', + JSON.stringify({ mode: 'browser', gatewayBaseURL: 'http://127.0.0.1:8080', token: 'tok' }), + ) + useWorkspaceStore.setState({ + fetchWorkspaces: vi.fn().mockResolvedValue(undefined), + workspaces: [{ hash: 'w2', path: '/workspace-two', name: 'Two', createdAt: '', updatedAt: '' }], + currentWorkspaceHash: 'w2', + } as any) + useSessionStore.setState({ + ...useSessionStore.getState(), + currentSessionId: 'session-stale', + fetchSessions: vi.fn().mockResolvedValue(undefined), + projects: [{ + id: 'group_today', + name: 'Today', + sessions: [{ id: 'session-stale', title: 'Stale', time: new Date(0).toISOString() }], + }], + } as any) + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + + let runtimeSnapshot: any = null + render( + + { runtimeSnapshot = rt }} /> + , + ) + await waitFor(() => expect(runtimeSnapshot?.status).toBe('connected')) + const client = clients[0] + client.call.mockImplementation(async (method: string, params?: any) => { + if (method === 'gateway.bindStream' && params?.session_id === 'session-stale') { + throw new Error('session not found') + } + if (method === 'gateway.authenticate') return { payload: {} } + if (method === 'gateway.listWorkspaces') return { payload: { workspaces: [] } } + if (method === 'gateway.ping') return { payload: {} } + return { payload: {} } + }) + + await act(async () => { + await client._emitReconnect() + }) + + await waitFor(() => expect(runtimeSnapshot?.status).toBe('connected')) + expect(runtimeSnapshot.error).toBe('') + expect(useSessionStore.getState().setCurrentSessionId).toHaveBeenCalledWith('') + warnSpy.mockRestore() + }) }) diff --git a/web/src/context/RuntimeProvider.tsx b/web/src/context/RuntimeProvider.tsx index 93c059f5c..d48b63704 100644 --- a/web/src/context/RuntimeProvider.tsx +++ b/web/src/context/RuntimeProvider.tsx @@ -96,6 +96,29 @@ async function syncWorkspaceContext(gatewayAPI: GatewayAPI): Promise { return true } +function sessionExistsInProjects(sessionId: string) { + return useSessionStore.getState().projects.some((project) => + project.sessions.some((session) => session.id === sessionId), + ) +} + +async function bindCurrentSessionForReconnect(gatewayAPI: GatewayAPI) { + const sessionId = useSessionStore.getState().currentSessionId + if (!isValidSessionId(sessionId)) return + if (!sessionExistsInProjects(sessionId)) { + useSessionStore.getState().setCurrentSessionId('') + useSessionStore.getState().setCurrentProjectId('') + return + } + try { + await gatewayAPI.bindStream({ session_id: sessionId, channel: 'all' }) + } catch (err) { + console.warn('[RuntimeProvider] Reconnect bindStream skipped stale session:', err) + useSessionStore.getState().setCurrentSessionId('') + useSessionStore.getState().setCurrentProjectId('') + } +} + /** RuntimeProvider 装配前端运行时,并为业务组件提供当前 Gateway 客户端。 */ export function RuntimeProvider({ children }: { children: ReactNode }) { const mode = useMemo(detectRuntimeMode, []) @@ -179,11 +202,8 @@ export function RuntimeProvider({ children }: { children: ReactNode }) { const hasWorkspaces = await syncWorkspaceContext(api) if (hasWorkspaces) { - const sessionId = useSessionStore.getState().currentSessionId - if (isValidSessionId(sessionId)) { - await api.bindStream({ session_id: sessionId, channel: 'all' }) - } await useSessionStore.getState().fetchSessions(api, true) + await bindCurrentSessionForReconnect(api) } await refreshPendingUserQuestion(api, useSessionStore.getState().currentSessionId) diff --git a/web/src/stores/useSessionStore.ts b/web/src/stores/useSessionStore.ts index 63ff495bc..6cee19384 100644 --- a/web/src/stores/useSessionStore.ts +++ b/web/src/stores/useSessionStore.ts @@ -414,8 +414,14 @@ export const useSessionStore = create((set, get) => ({ const projects = mapSessionsToProjects(sessions) set({ projects, loading: false }) - const state = get() + let state = get() if (requestSeq !== _fetchSessionsSeq) return + const currentSessionVisible = isValidSessionId(state.currentSessionId) && + sessions.some((session) => session.id === state.currentSessionId) + if (isValidSessionId(state.currentSessionId) && !currentSessionVisible) { + set({ currentSessionId: '', currentProjectId: '', _initialBindDone: false }) + state = get() + } if (!isValidSessionId(state.currentSessionId) && sessions.length > 0) { const firstSession = sessions[0] set({ currentSessionId: firstSession.id })