diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index d121a0a7b..16b5dd7f3 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -11,6 +11,7 @@ import ( "net/http" "os/exec" "strings" + "sync" "time" "github.com/github/gh-aw-mcpg/internal/difc" @@ -71,6 +72,27 @@ type Connection struct { httpClient *http.Client httpSessionID string // Session ID returned by the HTTP backend httpTransportType HTTPTransportType // Type of HTTP transport in use + // sessionMu protects the mutable session fields: httpSessionID, session, and client. + // Always use getHTTPSessionID() or getSDKSession() to read these fields; the + // reconnect functions (reconnectPlainJSON, reconnectSDKTransport) hold the full Lock. + sessionMu sync.RWMutex +} + +// getSDKSession returns a snapshot of the current SDK session under a read lock. +// Returns nil if no session is available (e.g. plain JSON-RPC transport). +func (c *Connection) getSDKSession() *sdk.ClientSession { + c.sessionMu.RLock() + s := c.session + c.sessionMu.RUnlock() + return s +} + +// getHTTPSessionID returns a snapshot of the current HTTP session ID under a read lock. +func (c *Connection) getHTTPSessionID() string { + c.sessionMu.RLock() + id := c.httpSessionID + c.sessionMu.RUnlock() + return id } // NewConnection creates a new MCP connection using the official SDK @@ -255,6 +277,95 @@ func (c *Connection) GetHTTPHeaders() map[string]string { return c.headers } +// reconnectPlainJSON re-initialises the plain JSON-RPC session with the HTTP backend. +// It is safe for concurrent callers: only one reconnect runs at a time, and the updated +// session ID is available to all callers once the lock is released. +func (c *Connection) reconnectPlainJSON() error { + c.sessionMu.Lock() + defer c.sessionMu.Unlock() + + logConn.Printf("Session expired, reconnecting plain JSON-RPC for serverID=%s", c.serverID) + logger.LogWarn("backend", "MCP session expired for %s, attempting to reconnect...", c.serverID) + + sessionID, err := c.initializeHTTPSession() + if err != nil { + logger.LogError("backend", "Session reconnect failed for %s: %v", c.serverID, err) + return fmt.Errorf("session reconnect failed: %w", err) + } + + c.httpSessionID = sessionID + logConn.Printf("Reconnected plain JSON-RPC session for serverID=%s, new sessionID=%s", c.serverID, sessionID) + logger.LogInfo("backend", "Session successfully reconnected for %s", c.serverID) + return nil +} + +// reconnectSDKTransport re-establishes the SDK session for streamable or SSE transports. +// It is safe for concurrent callers: only one reconnect runs at a time. +func (c *Connection) reconnectSDKTransport() error { + c.sessionMu.Lock() + defer c.sessionMu.Unlock() + + logConn.Printf("Session expired, reconnecting SDK transport for serverID=%s, type=%s", c.serverID, c.httpTransportType) + logger.LogWarn("backend", "MCP session expired for %s, attempting to reconnect...", c.serverID) + + // Close the existing session gracefully (ignore error – it's already dead). + if c.session != nil { + _ = c.session.Close() + } + + // Build the appropriate transport. + client := newMCPClient(logConn) + var transport sdk.Transport + switch c.httpTransportType { + case HTTPTransportStreamable: + transport = &sdk.StreamableClientTransport{ + Endpoint: c.httpURL, + HTTPClient: c.httpClient, + MaxRetries: 0, + } + case HTTPTransportSSE: + transport = &sdk.SSEClientTransport{ + Endpoint: c.httpURL, + HTTPClient: c.httpClient, + } + default: + return fmt.Errorf("cannot reconnect: unsupported transport type %s", c.httpTransportType) + } + + connectCtx, cancel := context.WithTimeout(c.ctx, 10*time.Second) + defer cancel() + + session, err := client.Connect(connectCtx, transport, nil) + if err != nil { + logger.LogError("backend", "Session reconnect failed for %s: %v", c.serverID, err) + return fmt.Errorf("session reconnect failed: %w", err) + } + + c.client = client + c.session = session + + logConn.Printf("Reconnected SDK session for serverID=%s", c.serverID) + logger.LogInfo("backend", "Session successfully reconnected for %s", c.serverID) + return nil +} + +// callSDKMethodWithReconnect calls the SDK method and, if the session has expired, +// reconnects and retries exactly once before propagating the error. +func (c *Connection) callSDKMethodWithReconnect(method string, params interface{}) (*Response, error) { + result, err := c.callSDKMethod(method, params) + if err != nil && isSessionNotFoundError(err) { + logConn.Printf("Session not found error from SDK (serverID=%s), attempting reconnect", c.serverID) + if reconnErr := c.reconnectSDKTransport(); reconnErr != nil { + logConn.Printf("SDK session reconnect failed for serverID=%s: %v; returning original error", c.serverID, reconnErr) + logger.LogError("backend", "SDK session reconnect failed for %s: %v", c.serverID, reconnErr) + // Return the original session-not-found error so the caller sees a meaningful message. + return result, err + } + result, err = c.callSDKMethod(method, params) + } + return result, err +} + // SendRequest sends a JSON-RPC request and waits for the response // The serverID parameter is used for logging to associate the request with a backend server func (c *Connection) SendRequest(method string, params interface{}) (*Response, error) { @@ -301,7 +412,7 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string, } // For streamable and SSE transports, use SDK session methods - result, err = c.callSDKMethod(method, params) + result, err = c.callSDKMethodWithReconnect(method, params) // Log the response from backend server var responsePayload []byte if result != nil { @@ -374,7 +485,7 @@ func marshalToResponse(result interface{}) (*Response, error) { // This helper centralizes session validation logic across all MCP method wrappers. // Returns an error if the session is nil (e.g., for plain JSON-RPC transport). func (c *Connection) requireSession() error { - if c.session == nil { + if c.getSDKSession() == nil { return fmt.Errorf("SDK session not available for plain JSON-RPC transport") } return nil @@ -429,7 +540,7 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in func (c *Connection) listTools() (*Response, error) { logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID) return c.callListMethod(func() (interface{}, error) { - result, err := c.session.ListTools(c.ctx, &sdk.ListToolsParams{}) + result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{}) if err == nil { logConn.Printf("listTools: received %d tools from serverID=%s", len(result.Tools), c.serverID) } @@ -445,7 +556,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) { p.Arguments = make(map[string]interface{}) } logConn.Printf("callTool: parsed name=%s, arguments=%+v", p.Name, p.Arguments) - return c.session.CallTool(c.ctx, &sdk.CallToolParams{ + return c.getSDKSession().CallTool(c.ctx, &sdk.CallToolParams{ Name: p.Name, Arguments: p.Arguments, }) @@ -455,7 +566,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) { func (c *Connection) listResources() (*Response, error) { logConn.Printf("listResources: requesting resource list from backend serverID=%s", c.serverID) return c.callListMethod(func() (interface{}, error) { - result, err := c.session.ListResources(c.ctx, &sdk.ListResourcesParams{}) + result, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{}) if err == nil { logConn.Printf("listResources: received %d resources from serverID=%s", len(result.Resources), c.serverID) } @@ -469,7 +580,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) { } return callParamMethod(c, params, func(p readResourceParams) (interface{}, error) { logConn.Printf("readResource: reading resource uri=%s from serverID=%s", p.URI, c.serverID) - return c.session.ReadResource(c.ctx, &sdk.ReadResourceParams{ + return c.getSDKSession().ReadResource(c.ctx, &sdk.ReadResourceParams{ URI: p.URI, }) }) @@ -478,7 +589,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) { func (c *Connection) listPrompts() (*Response, error) { logConn.Printf("listPrompts: requesting prompt list from backend serverID=%s", c.serverID) return c.callListMethod(func() (interface{}, error) { - result, err := c.session.ListPrompts(c.ctx, &sdk.ListPromptsParams{}) + result, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{}) if err == nil { logConn.Printf("listPrompts: received %d prompts from serverID=%s", len(result.Prompts), c.serverID) } @@ -493,7 +604,7 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) { } return callParamMethod(c, params, func(p getPromptParams) (interface{}, error) { logConn.Printf("getPrompt: getting prompt name=%s from serverID=%s", p.Name, c.serverID) - return c.session.GetPrompt(c.ctx, &sdk.GetPromptParams{ + return c.getSDKSession().GetPrompt(c.ctx, &sdk.GetPromptParams{ Name: p.Name, Arguments: p.Arguments, }) @@ -504,8 +615,8 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) { func (c *Connection) Close() error { logConn.Printf("Closing connection: serverID=%s, isHTTP=%v", c.serverID, c.isHTTP) c.cancel() - if c.session != nil { - return c.session.Close() + if session := c.getSDKSession(); session != nil { + return session.Close() } return nil } diff --git a/internal/mcp/http_transport.go b/internal/mcp/http_transport.go index 933df4dca..2621ed828 100644 --- a/internal/mcp/http_transport.go +++ b/internal/mcp/http_transport.go @@ -62,6 +62,24 @@ func isHTTPConnectionError(err error) bool { return false } +// isSessionNotFoundError checks if an error message indicates a backend MCP session has expired +// or is not found. This is used to detect when automatic reconnection to the backend is needed. +func isSessionNotFoundError(err error) bool { + if err == nil { + return false + } + return strings.Contains(strings.ToLower(err.Error()), "session not found") +} + +// isSessionNotFoundHTTPResponse checks if an HTTP response indicates the backend session was not found. +// MCP backends return HTTP 404 with a "session not found" body when a session has expired. +func isSessionNotFoundHTTPResponse(statusCode int, body []byte) bool { + if statusCode != http.StatusNotFound { + return false + } + return strings.Contains(strings.ToLower(string(body)), "session not found") +} + // parseSSEResponse extracts JSON data from SSE-formatted response // SSE format: "event: message\ndata: {json}\n\n" func parseSSEResponse(body []byte) ([]byte, error) { @@ -436,58 +454,47 @@ func (c *Connection) initializeHTTPSession() (string, error) { return sessionID, nil } -// sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server -// The ctx parameter is used to extract session ID for the Mcp-Session-Id header -func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) { - // Generate unique request ID using atomic counter - requestID := atomic.AddUint64(&requestIDCounter, 1) - - // For tools/call, ensure arguments field always exists (MCP protocol requirement) - if method == "tools/call" { - params = ensureToolCallArguments(params) - } - - logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID) - - // Execute HTTP request with custom header modification for session ID - result, err := c.executeHTTPRequest(ctx, method, params, requestID, func(httpReq *http.Request) { - // Add Mcp-Session-Id header with priority: - // 1) Context session ID (if explicitly provided for this request) - // 2) Stored httpSessionID from initialization +// buildSessionHeaderModifier returns a header modifier function that adds the Mcp-Session-Id header. +// Priority: context session ID > stored connection session ID. +// Context session IDs are static for the lifetime of a single request and are captured once at +// construction time. Connection session IDs can change during a reconnect, so getHTTPSessionID() +// is called at request time to always pick up the current value. +func (c *Connection) buildSessionHeaderModifier(ctx context.Context) func(*http.Request) { + // Capture any context-provided session ID once (it never changes for this request). + ctxSessionID, _ := ctx.Value(SessionIDContextKey).(string) + return func(httpReq *http.Request) { var sessionID string - if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" { + if ctxSessionID != "" { sessionID = ctxSessionID logConn.Printf("Using session ID from context: %s", sessionID) - } else if c.httpSessionID != "" { - sessionID = c.httpSessionID + } else if id := c.getHTTPSessionID(); id != "" { + sessionID = id logConn.Printf("Using stored session ID from initialization: %s", sessionID) } - if sessionID != "" { httpReq.Header.Set("Mcp-Session-Id", sessionID) } else { logConn.Printf("No session ID available (backend may not require session management)") } - }) - if err != nil { - return nil, err } +} - logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody)) - - // Parse JSON-RPC response - // The response might be in SSE format (event: message\ndata: {...}) +// parseHTTPResult converts a raw httpRequestResult into a JSON-RPC Response, handling non-OK +// HTTP status codes by synthesising a JSON-RPC error when the server did not provide one. +func parseHTTPResult(result *httpRequestResult) (*Response, error) { + // Parse JSON-RPC response. + // The response might be in SSE format (event: message\ndata: {...}). rpcResponse, err := parseJSONRPCResponseWithSSE(result.ResponseBody, result.StatusCode, "JSON-RPC response") if err != nil { return nil, err } - // Check for HTTP errors after parsing + // Check for HTTP errors after parsing. // If we have a non-OK status but successfully parsed a JSON-RPC response, - // pass it through (it may already contain an error field) + // pass it through (it may already contain an error field). if result.StatusCode != http.StatusOK { logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", result.StatusCode) - // If the response doesn't already have an error, construct one + // If the response doesn't already have an error, construct one. if rpcResponse.Error == nil { rpcResponse.Error = &ResponseError{ Code: -32603, // Internal error @@ -499,3 +506,44 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params return rpcResponse, nil } + +// sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server. +// The ctx parameter is used to extract session ID for the Mcp-Session-Id header. +// If the backend returns a "session not found" (HTTP 404) response, it attempts a one-time +// session reconnect and retries the request transparently. +func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) { + // For tools/call, ensure arguments field always exists (MCP protocol requirement) + if method == "tools/call" { + params = ensureToolCallArguments(params) + } + + headerModifier := c.buildSessionHeaderModifier(ctx) + + requestID := atomic.AddUint64(&requestIDCounter, 1) + logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID) + + result, err := c.executeHTTPRequest(ctx, method, params, requestID, headerModifier) + if err != nil { + return nil, err + } + + logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody)) + + // If the backend reported that the session has expired, reconnect and retry once. + if isSessionNotFoundHTTPResponse(result.StatusCode, result.ResponseBody) { + logConn.Printf("Session not found from %s (serverID=%s), attempting reconnect", c.httpURL, c.serverID) + if reconnErr := c.reconnectPlainJSON(); reconnErr == nil { + requestID = atomic.AddUint64(&requestIDCounter, 1) + logConn.Printf("Retrying HTTP request after reconnect: method=%s, id=%d", method, requestID) + result, err = c.executeHTTPRequest(ctx, method, params, requestID, headerModifier) + if err != nil { + return nil, err + } + logConn.Printf("Retry HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody)) + } else { + logConn.Printf("Session reconnect failed (%v), returning original session-not-found error", reconnErr) + } + } + + return parseHTTPResult(result) +} diff --git a/internal/mcp/http_transport_test.go b/internal/mcp/http_transport_test.go index e40991991..177291a7c 100644 --- a/internal/mcp/http_transport_test.go +++ b/internal/mcp/http_transport_test.go @@ -856,3 +856,228 @@ func TestSendHTTPRequest_NonToolsCallMethodDoesNotAddArguments(t *testing.T) { assert.False(t, hasArgs, "arguments should NOT be added for non tools/call methods") } } + +// ============================================================================= +// isSessionNotFoundError tests +// ============================================================================= + +func TestIsSessionNotFoundError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {name: "nil error returns false", err: nil, want: false}, + {name: "unrelated error returns false", err: fmt.Errorf("internal server error"), want: false}, + {name: "exact match returns true", err: fmt.Errorf("session not found"), want: true}, + {name: "uppercase returns true", err: fmt.Errorf("Session Not Found"), want: true}, + {name: "embedded in longer message returns true", err: fmt.Errorf("Streamable HTTP error: Error POSTing to endpoint: session not found"), want: true}, + {name: "session expired message returns false", err: fmt.Errorf("session expired"), want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, isSessionNotFoundError(tt.err)) + }) + } +} + +// ============================================================================= +// isSessionNotFoundHTTPResponse tests +// ============================================================================= + +func TestIsSessionNotFoundHTTPResponse(t *testing.T) { + tests := []struct { + name string + statusCode int + body []byte + want bool + }{ + {name: "200 OK returns false", statusCode: http.StatusOK, body: []byte("session not found"), want: false}, + {name: "500 returns false", statusCode: http.StatusInternalServerError, body: []byte("session not found"), want: false}, + {name: "404 with unrelated body returns false", statusCode: http.StatusNotFound, body: []byte("resource not found"), want: false}, + {name: "404 with session not found body returns true", statusCode: http.StatusNotFound, body: []byte("session not found"), want: true}, + {name: "404 with uppercase body returns true", statusCode: http.StatusNotFound, body: []byte("Session Not Found"), want: true}, + {name: "404 with session not found embedded in JSON returns true", statusCode: http.StatusNotFound, body: []byte(`{"error":"session not found"}`), want: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, isSessionNotFoundHTTPResponse(tt.statusCode, tt.body)) + }) + } +} + +// ============================================================================= +// Session reconnect tests (plain JSON-RPC) +// ============================================================================= + +// TestSendHTTPRequest_ReconnectsOnSessionNotFound verifies that when the backend returns +// a 404 "session not found" response, sendHTTPRequest reconnects and retries the request. +func TestSendHTTPRequest_ReconnectsOnSessionNotFound(t *testing.T) { + requestCount := 0 + var receivedSessionIDs []string + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) //nolint:errcheck + method, _ := body["method"].(string) + + switch method { + case "initialize": + requestCount++ + sessionID := fmt.Sprintf("session-%d", requestCount) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", sessionID) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ //nolint:errcheck + "jsonrpc": "2.0", + "id": body["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "serverInfo": map[string]interface{}{"name": "test"}, + }, + }) + + case "tools/list": + sessionID := r.Header.Get("Mcp-Session-Id") + receivedSessionIDs = append(receivedSessionIDs, sessionID) + + // Simulate first tool call failing with session-not-found (session-1 expired) + if sessionID == "session-1" { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, "session not found") + return + } + + // Subsequent calls with the new session succeed + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ //nolint:errcheck + "jsonrpc": "2.0", + "id": body["id"], + "result": map[string]interface{}{"tools": []interface{}{}}, + }) + } + })) + defer testServer.Close() + + conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{ + "Authorization": "test-token", + }) + require.NoError(t, err) + require.NotNil(t, conn) + defer conn.Close() + + // The initial session is "session-1". The first tools/list call should trigger a + // reconnect (because the server returns 404 session-not-found for session-1), + // get a new "session-2", and then succeed on retry. + resp, err := conn.sendHTTPRequest(context.Background(), "tools/list", nil) + require.NoError(t, err) + require.NotNil(t, resp) + require.Nil(t, resp.Error, "response should not contain an error after reconnect") + + // Verify the reconnect happened: session-1 failed, session-2 succeeded. + require.Len(t, receivedSessionIDs, 2, "expected two tool calls: initial (failed) + retry (succeeded)") + assert.Equal(t, "session-1", receivedSessionIDs[0], "first attempt should use the initial session") + assert.Equal(t, "session-2", receivedSessionIDs[1], "retry should use the reconnected session") + assert.Equal(t, "session-2", conn.httpSessionID, "connection should store the new session ID") +} + +// TestSendHTTPRequest_ReconnectFailure verifies that when reconnection itself fails, +// the original session-not-found response is returned to the caller. +func TestSendHTTPRequest_ReconnectFailure(t *testing.T) { + firstInitDone := false + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) //nolint:errcheck + method, _ := body["method"].(string) + + switch method { + case "initialize": + if !firstInitDone { + firstInitDone = true + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", "session-original") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ //nolint:errcheck + "jsonrpc": "2.0", + "id": body["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "serverInfo": map[string]interface{}{"name": "test"}, + }, + }) + } else { + // Reconnect attempt also fails + w.WriteHeader(http.StatusInternalServerError) + } + + case "tools/list": + // Always return session not found + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, "session not found") + } + })) + defer testServer.Close() + + conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{ + "Authorization": "test-token", + }) + require.NoError(t, err) + require.NotNil(t, conn) + defer conn.Close() + + // The tools/list call gets session-not-found, reconnect fails (500 on initialize), + // so the original session-not-found response is passed through. + resp, err := conn.sendHTTPRequest(context.Background(), "tools/list", nil) + require.NoError(t, err, "should not return a Go error, but a JSON-RPC error response") + require.NotNil(t, resp) + require.NotNil(t, resp.Error, "response should contain an error when session-not-found and reconnect failed") +} + +// TestSendHTTPRequest_NoReconnectOnOtherErrors verifies that non-session errors +// (e.g. 500 internal server error) do not trigger a reconnect attempt. +func TestSendHTTPRequest_NoReconnectOnOtherErrors(t *testing.T) { + initCount := 0 + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) //nolint:errcheck + method, _ := body["method"].(string) + + if method == "initialize" { + initCount++ + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", "session-1") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ //nolint:errcheck + "jsonrpc": "2.0", + "id": body["id"], + "result": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "serverInfo": map[string]interface{}{"name": "test"}, + }, + }) + return + } + + // Return 500 – should not trigger reconnect + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, "internal server error") + })) + defer testServer.Close() + + conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{ + "Authorization": "test-token", + }) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.sendHTTPRequest(context.Background(), "tools/list", nil) + require.NoError(t, err) + + // initCount should be 1 (initial only) – no reconnect was attempted. + assert.Equal(t, 1, initCount, "should not reconnect on non-session-not-found errors") +} diff --git a/internal/server/transport.go b/internal/server/transport.go index c3f7c2720..81dd560d2 100644 --- a/internal/server/transport.go +++ b/internal/server/transport.go @@ -37,7 +37,7 @@ func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey st }, &sdk.StreamableHTTPOptions{ Stateless: false, // Support stateful sessions Logger: logger.NewSlogLoggerWithHandler(logTransport), // Integrate SDK logging with project logger - SessionTimeout: 30 * time.Minute, // Prevent resource leaks from idle connections + SessionTimeout: 2 * time.Hour, // 2h accommodates long-running workflows with idle periods }) // Apply standard middleware stack (SDK logging → shutdown check → auth)