Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions internal/gateway/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func handleAuthenticateFrame(ctx context.Context, frame MessageFrame) MessageFra
}

// handleBindStreamFrame 处理 gateway.bindStream 并注册连接订阅关系。
func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame {
func handleBindStreamFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame {
params, err := decodeBindStreamParams(frame.Payload)
if err != nil {
return errorFrame(frame, err)
Expand All @@ -120,13 +120,18 @@ func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame
return errorFrame(frame, NewFrameError(ErrorCodeInternalError, "stream relay context is unavailable"))
}

if validationFrame := validateBindStreamSession(ctx, frame, runtimePort, params.SessionID); validationFrame != nil {
return *validationFrame
}

if bindErr := relay.BindConnection(connectionID, StreamBinding{
SessionID: params.SessionID,
RunID: params.RunID,
Channel: params.Channel,
Role: params.Role,
State: cloneMapValue(params.State),
Explicit: true,
SessionID: params.SessionID,
RunID: params.RunID,
WorkspaceHash: WorkspaceHashFromContext(ctx),
Channel: params.Channel,
Role: params.Role,
State: cloneMapValue(params.State),
Explicit: true,
}); bindErr != nil {
return errorFrame(frame, bindErr)
}
Expand All @@ -146,6 +151,34 @@ func handleBindStreamFrame(ctx context.Context, frame MessageFrame) MessageFrame
}
}

// validateBindStreamSession 确认事件流绑定的会话在当前工作区 runtime 中可见。
func validateBindStreamSession(
ctx context.Context,
frame MessageFrame,
runtimePort RuntimePort,
sessionID string,
) *MessageFrame {
if runtimePort == nil {
return nil
}
normalizedSessionID := strings.TrimSpace(sessionID)
if normalizedSessionID == "" {
return nil
}

callCtx, cancel := withRuntimeOperationTimeout(ctx)
defer cancel()
_, err := runtimePort.LoadSession(callCtx, LoadSessionInput{
SubjectID: AuthenticatedSubjectIDFromContext(ctx),
SessionID: normalizedSessionID,
})
if err == nil {
return nil
}
failedFrame := runtimeCallFailedFrame(callCtx, frame, err, "bind_stream")
return &failedFrame
}

// handleAskFrame 处理 gateway.ask 请求,并以异步方式转发到底层 Ask 编排能力。
func handleAskFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame {
if runtimePort == nil {
Expand Down
108 changes: 106 additions & 2 deletions internal/gateway/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,7 @@ func TestHandleBindStreamFrameErrors(t *testing.T) {
Payload: protocol.BindStreamParams{
SessionID: "session-1",
},
})
}, nil)
if response.Type != FrameTypeError {
t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError)
}
Expand Down Expand Up @@ -1271,7 +1271,7 @@ func TestHandleBindStreamFrameErrors(t *testing.T) {
SessionID: "session-1",
Channel: "ipc",
},
})
}, nil)
if response.Type != FrameTypeError {
t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError)
}
Expand All @@ -1281,6 +1281,110 @@ func TestHandleBindStreamFrameErrors(t *testing.T) {
})
}

func TestHandleBindStreamFrameRejectsSessionOutsideCurrentWorkspace(t *testing.T) {
relay := NewStreamRelay(StreamRelayOptions{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

connectionID := NewConnectionID()
workspaceState := NewConnectionWorkspaceState()
workspaceState.SetWorkspaceHash("workspace-b")
connectionCtx := WithConnectionID(ctx, connectionID)
connectionCtx = WithConnectionWorkspaceState(connectionCtx, workspaceState)
connectionCtx = WithStreamRelay(connectionCtx, relay)
if err := relay.RegisterConnection(ConnectionRegistration{
ConnectionID: connectionID,
Channel: StreamChannelIPC,
Context: connectionCtx,
Cancel: cancel,
Write: func(message RelayMessage) error {
_ = message
return nil
},
Close: func() {},
}); err != nil {
t.Fatalf("register connection: %v", err)
}
defer relay.dropConnection(connectionID)

runtimeStub := &bootstrapRuntimeStub{
loadSessionFn: func(context.Context, LoadSessionInput) (Session, error) {
return Session{}, ErrRuntimeResourceNotFound
},
}
response := handleBindStreamFrame(connectionCtx, MessageFrame{
Type: FrameTypeRequest,
Action: FrameActionBindStream,
RequestID: "bind-cross-workspace",
Payload: protocol.BindStreamParams{
SessionID: "session-from-workspace-a",
Channel: "all",
},
}, runtimeStub)
if response.Type != FrameTypeError {
t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError)
}
if response.Error == nil || response.Error.Code != ErrorCodeResourceNotFound.String() {
t.Fatalf("response error = %#v, want resource_not_found", response.Error)
}
if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-b"); got != "" {
t.Fatalf("binding should not be written after validation failure, got fallback %q", got)
}
}

func TestHandleBindStreamFrameValidatesVisibleSessionBeforeBinding(t *testing.T) {
relay := NewStreamRelay(StreamRelayOptions{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

connectionID := NewConnectionID()
workspaceState := NewConnectionWorkspaceState()
workspaceState.SetWorkspaceHash("workspace-a")
connectionCtx := WithConnectionID(ctx, connectionID)
connectionCtx = WithConnectionWorkspaceState(connectionCtx, workspaceState)
connectionCtx = WithStreamRelay(connectionCtx, relay)
if err := relay.RegisterConnection(ConnectionRegistration{
ConnectionID: connectionID,
Channel: StreamChannelIPC,
Context: connectionCtx,
Cancel: cancel,
Write: func(message RelayMessage) error {
_ = message
return nil
},
Close: func() {},
}); err != nil {
t.Fatalf("register connection: %v", err)
}
defer relay.dropConnection(connectionID)

var loaded LoadSessionInput
runtimeStub := &bootstrapRuntimeStub{
loadSessionFn: func(_ context.Context, input LoadSessionInput) (Session, error) {
loaded = input
return Session{ID: input.SessionID}, nil
},
}
response := handleBindStreamFrame(connectionCtx, MessageFrame{
Type: FrameTypeRequest,
Action: FrameActionBindStream,
RequestID: "bind-visible-session",
Payload: protocol.BindStreamParams{
SessionID: "session-visible",
Channel: "all",
},
}, runtimeStub)
if response.Type != FrameTypeAck {
t.Fatalf("response type = %q, want %q: %#v", response.Type, FrameTypeAck, response.Error)
}
if loaded.SessionID != "session-visible" {
t.Fatalf("validated session_id = %q, want %q", loaded.SessionID, "session-visible")
}
if got := relay.ResolveFallbackSessionIDForWorkspace(connectionID, "workspace-a"); got != "session-visible" {
t.Fatalf("fallback session = %q, want %q", got, "session-visible")
}
}

func TestHandleTriggerActionFrame(t *testing.T) {
registerConnection := func(
t *testing.T,
Expand Down
4 changes: 1 addition & 3 deletions internal/gateway/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ func (r *ActionRegistry) initCore() {
r.core[FrameActionPing] = func(ctx context.Context, frame MessageFrame, _ RuntimePort) MessageFrame {
return handlePingFrame(ctx, frame)
}
r.core[FrameActionBindStream] = func(ctx context.Context, frame MessageFrame, _ RuntimePort) MessageFrame {
return handleBindStreamFrame(ctx, frame)
}
r.core[FrameActionBindStream] = handleBindStreamFrame
r.core[FrameActionAsk] = handleAskFrame
r.core[FrameActionDeleteAskSession] = handleDeleteAskSessionFrame
r.core[FrameActionTriggerAction] = handleTriggerActionFrame
Expand Down
22 changes: 13 additions & 9 deletions internal/gateway/request_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ import (

// RequestLogEntry 表示统一结构化请求日志字段。
type RequestLogEntry struct {
RequestID string `json:"request_id"`
SessionID string `json:"session_id"`
Method string `json:"method"`
Source string `json:"source"`
Status string `json:"status"`
GatewayCode string `json:"gateway_code,omitempty"`
LatencyMS int64 `json:"latency_ms"`
ConnectionID string `json:"connection_id,omitempty"`
AuthState string `json:"auth_state,omitempty"`
RequestID string `json:"request_id"`
SessionID string `json:"session_id"`
Method string `json:"method"`
Source string `json:"source"`
Status string `json:"status"`
WorkspaceHash string `json:"workspace_hash,omitempty"`
GatewayCode string `json:"gateway_code,omitempty"`
LatencyMS int64 `json:"latency_ms"`
ConnectionID string `json:"connection_id,omitempty"`
AuthState string `json:"auth_state,omitempty"`
}

// emitRequestLog 输出网关结构化日志。
Expand All @@ -37,6 +38,9 @@ func emitRequestLog(ctx context.Context, logger *log.Logger, entry RequestLogEnt
if connectionID, ok := ConnectionIDFromContext(ctx); ok {
entry.ConnectionID = string(connectionID)
}
if entry.WorkspaceHash == "" {
entry.WorkspaceHash = WorkspaceHashFromContext(ctx)
}
if authState, ok := ConnectionAuthStateFromContext(ctx); ok && authState.IsAuthenticated() {
entry.AuthState = "authenticated"
} else if _, ok := TokenAuthenticatorFromContext(ctx); ok {
Expand Down
5 changes: 4 additions & 1 deletion internal/gateway/rpc_dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,10 @@ func hydrateFrameSessionFromConnection(ctx context.Context, frame MessageFrame)
return frame
}

frame.SessionID = strings.TrimSpace(relay.ResolveFallbackSessionID(connectionID))
frame.SessionID = strings.TrimSpace(relay.ResolveFallbackSessionIDForWorkspace(
connectionID,
WorkspaceHashFromContext(ctx),
))
return frame
}

Expand Down
62 changes: 40 additions & 22 deletions internal/gateway/stream_relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ type ConnectionRegistration struct {

// StreamBinding 描述连接绑定到会话路由表的一条订阅关系。
type StreamBinding struct {
SessionID string
RunID string
Channel StreamChannel
Role StreamRole
State map[string]any
Explicit bool
SessionID string
RunID string
WorkspaceHash string
Channel StreamChannel
Role StreamRole
State map[string]any
Explicit bool
}

// StreamRelayOptions 描述会话路由与流式中继的可选配置。
Expand Down Expand Up @@ -87,14 +88,15 @@ type bindingKey struct {
}

type bindingState struct {
sessionID string
runID string
channel StreamChannel
role StreamRole
state map[string]any
explicit bool
expireAt time.Time
lastSeen time.Time
sessionID string
runID string
workspaceHash string
channel StreamChannel
role StreamRole
state map[string]any
explicit bool
expireAt time.Time
lastSeen time.Time
}

// StreamRelay 维护连接-会话-运行态映射,并负责运行事件的精确中继。
Expand Down Expand Up @@ -405,6 +407,11 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi
return NewFrameError(ErrorCodeInvalidAction, "bind channel does not match connection channel")
}

workspaceHash := strings.TrimSpace(binding.WorkspaceHash)
if workspaceHash == "" {
workspaceHash = WorkspaceHashFromContext(connection.ctx)
}

key := bindingKey{sessionID: sessionID, runID: runID}
connectionBindingMap := r.connectionBindings[normalizedConnectionID]
if connectionBindingMap == nil {
Expand All @@ -424,14 +431,15 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi
return NewFrameError(ErrorCodeInvalidAction, "too many stream bindings for connection")
}
connectionBindingMap[key] = &bindingState{
sessionID: sessionID,
runID: runID,
channel: channel,
role: role,
state: state,
explicit: binding.Explicit,
expireAt: now.Add(r.bindingTTL),
lastSeen: now,
sessionID: sessionID,
runID: runID,
workspaceHash: workspaceHash,
channel: channel,
role: role,
state: state,
explicit: binding.Explicit,
expireAt: now.Add(r.bindingTTL),
lastSeen: now,
}
r.addConnectionToSessionIndexLocked(sessionID, normalizedConnectionID)
if runID != "" {
Expand All @@ -443,6 +451,11 @@ func (r *StreamRelay) BindConnection(connectionID ConnectionID, binding StreamBi

// ResolveFallbackSessionID 返回连接当前可用绑定中的会话兜底值(取最近续期的绑定)。
func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string {
return r.ResolveFallbackSessionIDForWorkspace(connectionID, "")
}

// ResolveFallbackSessionIDForWorkspace 返回指定工作区内最近续期的连接兜底会话。
func (r *StreamRelay) ResolveFallbackSessionIDForWorkspace(connectionID ConnectionID, workspaceHash string) string {
if r == nil {
return ""
}
Expand All @@ -453,6 +466,7 @@ func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string
}

now := time.Now()
normalizedWorkspaceHash := strings.TrimSpace(workspaceHash)

r.mu.RLock()
connectionBindingMap := r.connectionBindings[normalizedConnectionID]
Expand All @@ -464,6 +478,10 @@ func (r *StreamRelay) ResolveFallbackSessionID(connectionID ConnectionID) string
if state == nil || state.expireAt.Before(now) {
continue
}
if normalizedWorkspaceHash != "" &&
!strings.EqualFold(strings.TrimSpace(state.workspaceHash), normalizedWorkspaceHash) {
continue
}
if state.lastSeen.After(latestSeen) {
latestSeen = state.lastSeen
latestSessionID = state.sessionID
Expand Down
Loading
Loading