Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 121 additions & 10 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"
"os/exec"
"strings"
"sync"
"time"

"github.com/github/gh-aw-mcpg/internal/difc"
Expand Down Expand Up @@ -71,6 +72,27 @@ type Connection struct {
httpClient *http.Client
httpSessionID string // Session ID returned by the HTTP backend
httpTransportType HTTPTransportType // Type of HTTP transport in use
// sessionMu protects the mutable session fields: httpSessionID, session, and client.
// Always use getHTTPSessionID() or getSDKSession() to read these fields; the
// reconnect functions (reconnectPlainJSON, reconnectSDKTransport) hold the full Lock.
sessionMu sync.RWMutex
}

// getSDKSession returns a snapshot of the current SDK session under a read lock.
// Returns nil if no session is available (e.g. plain JSON-RPC transport).
func (c *Connection) getSDKSession() *sdk.ClientSession {
c.sessionMu.RLock()
s := c.session
c.sessionMu.RUnlock()
return s
}

// getHTTPSessionID returns a snapshot of the current HTTP session ID under a read lock.
func (c *Connection) getHTTPSessionID() string {
c.sessionMu.RLock()
id := c.httpSessionID
c.sessionMu.RUnlock()
return id
}

// NewConnection creates a new MCP connection using the official SDK
Expand Down Expand Up @@ -255,6 +277,95 @@ func (c *Connection) GetHTTPHeaders() map[string]string {
return c.headers
}

// reconnectPlainJSON re-initialises the plain JSON-RPC session with the HTTP backend.
// It is safe for concurrent callers: only one reconnect runs at a time, and the updated
// session ID is available to all callers once the lock is released.
func (c *Connection) reconnectPlainJSON() error {
c.sessionMu.Lock()
defer c.sessionMu.Unlock()

logConn.Printf("Session expired, reconnecting plain JSON-RPC for serverID=%s", c.serverID)
logger.LogWarn("backend", "MCP session expired for %s, attempting to reconnect...", c.serverID)

sessionID, err := c.initializeHTTPSession()
if err != nil {
logger.LogError("backend", "Session reconnect failed for %s: %v", c.serverID, err)
return fmt.Errorf("session reconnect failed: %w", err)
}

c.httpSessionID = sessionID
logConn.Printf("Reconnected plain JSON-RPC session for serverID=%s, new sessionID=%s", c.serverID, sessionID)
logger.LogInfo("backend", "Session successfully reconnected for %s", c.serverID)
return nil
Comment on lines +290 to +299

Copilot AI Mar 26, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reconnectPlainJSON calls initializeHTTPSession(), which uses context.Background() internally. That means a reconnect ignores caller cancellation/deadlines and can hang until the HTTP client timeout, even if the original request context is canceled. Consider plumbing a context into initializeHTTPSession (or adding a context-aware variant) and using a bounded timeout (similar to reconnectSDKTransport) so reconnect attempts are cancelable and predictable.

Suggested change
sessionID, err := c.initializeHTTPSession()
if err != nil {
logger.LogError("backend", "Session reconnect failed for %s: %v", c.serverID, err)
return fmt.Errorf("session reconnect failed: %w", err)
}
c.httpSessionID = sessionID
logConn.Printf("Reconnected plain JSON-RPC session for serverID=%s, new sessionID=%s", c.serverID, sessionID)
logger.LogInfo("backend", "Session successfully reconnected for %s", c.serverID)
return nil
// Use a bounded context so reconnect attempts are cancelable and predictable,
// similar to reconnectSDKTransport.
connectCtx, cancel := context.WithTimeout(c.ctx, 10*time.Second)
defer cancel()
type reconnectResult struct {
sessionID string
err error
}
resultCh := make(chan reconnectResult, 1)
go func() {
sessionID, err := c.initializeHTTPSession()
resultCh <- reconnectResult{sessionID: sessionID, err: err}
}()
select {
case <-connectCtx.Done():
err := connectCtx.Err()
logger.LogError("backend", "Session reconnect canceled or timed out for %s: %v", c.serverID, err)
return fmt.Errorf("session reconnect canceled or timed out: %w", err)
case res := <-resultCh:
if res.err != nil {
logger.LogError("backend", "Session reconnect failed for %s: %v", c.serverID, res.err)
return fmt.Errorf("session reconnect failed: %w", res.err)
}
c.httpSessionID = res.sessionID
logConn.Printf("Reconnected plain JSON-RPC session for serverID=%s, new sessionID=%s", c.serverID, res.sessionID)
logger.LogInfo("backend", "Session successfully reconnected for %s", c.serverID)
return nil
}

Copilot uses AI. Check for mistakes.
}

// reconnectSDKTransport re-establishes the SDK session for streamable or SSE transports.
// It is safe for concurrent callers: only one reconnect runs at a time.
func (c *Connection) reconnectSDKTransport() error {
c.sessionMu.Lock()
defer c.sessionMu.Unlock()

logConn.Printf("Session expired, reconnecting SDK transport for serverID=%s, type=%s", c.serverID, c.httpTransportType)
logger.LogWarn("backend", "MCP session expired for %s, attempting to reconnect...", c.serverID)

// Close the existing session gracefully (ignore error – it's already dead).
if c.session != nil {
_ = c.session.Close()
}

// Build the appropriate transport.
client := newMCPClient(logConn)
var transport sdk.Transport
switch c.httpTransportType {
case HTTPTransportStreamable:
transport = &sdk.StreamableClientTransport{
Endpoint: c.httpURL,
HTTPClient: c.httpClient,
MaxRetries: 0,
}
case HTTPTransportSSE:
transport = &sdk.SSEClientTransport{
Endpoint: c.httpURL,
HTTPClient: c.httpClient,
}
default:
return fmt.Errorf("cannot reconnect: unsupported transport type %s", c.httpTransportType)
}

connectCtx, cancel := context.WithTimeout(c.ctx, 10*time.Second)
defer cancel()

session, err := client.Connect(connectCtx, transport, nil)
if err != nil {
logger.LogError("backend", "Session reconnect failed for %s: %v", c.serverID, err)
return fmt.Errorf("session reconnect failed: %w", err)
}

c.client = client
c.session = session

logConn.Printf("Reconnected SDK session for serverID=%s", c.serverID)
logger.LogInfo("backend", "Session successfully reconnected for %s", c.serverID)
return nil
}

// callSDKMethodWithReconnect calls the SDK method and, if the session has expired,
// reconnects and retries exactly once before propagating the error.
func (c *Connection) callSDKMethodWithReconnect(method string, params interface{}) (*Response, error) {
result, err := c.callSDKMethod(method, params)
if err != nil && isSessionNotFoundError(err) {
logConn.Printf("Session not found error from SDK (serverID=%s), attempting reconnect", c.serverID)
if reconnErr := c.reconnectSDKTransport(); reconnErr != nil {
logConn.Printf("SDK session reconnect failed for serverID=%s: %v; returning original error", c.serverID, reconnErr)
logger.LogError("backend", "SDK session reconnect failed for %s: %v", c.serverID, reconnErr)
// Return the original session-not-found error so the caller sees a meaningful message.
return result, err
}
result, err = c.callSDKMethod(method, params)
}
return result, err
}

// SendRequest sends a JSON-RPC request and waits for the response
// The serverID parameter is used for logging to associate the request with a backend server
func (c *Connection) SendRequest(method string, params interface{}) (*Response, error) {
Expand Down Expand Up @@ -301,7 +412,7 @@ func (c *Connection) SendRequestWithServerID(ctx context.Context, method string,
}

// For streamable and SSE transports, use SDK session methods
result, err = c.callSDKMethod(method, params)
result, err = c.callSDKMethodWithReconnect(method, params)
// Log the response from backend server
var responsePayload []byte
if result != nil {
Expand Down Expand Up @@ -374,7 +485,7 @@ func marshalToResponse(result interface{}) (*Response, error) {
// This helper centralizes session validation logic across all MCP method wrappers.
// Returns an error if the session is nil (e.g., for plain JSON-RPC transport).
func (c *Connection) requireSession() error {
if c.session == nil {
if c.getSDKSession() == nil {
return fmt.Errorf("SDK session not available for plain JSON-RPC transport")
}
return nil
Expand Down Expand Up @@ -429,7 +540,7 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in
func (c *Connection) listTools() (*Response, error) {
logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID)
return c.callListMethod(func() (interface{}, error) {
result, err := c.session.ListTools(c.ctx, &sdk.ListToolsParams{})
result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{})
if err == nil {
logConn.Printf("listTools: received %d tools from serverID=%s", len(result.Tools), c.serverID)
}
Expand All @@ -445,7 +556,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) {
p.Arguments = make(map[string]interface{})
}
logConn.Printf("callTool: parsed name=%s, arguments=%+v", p.Name, p.Arguments)
return c.session.CallTool(c.ctx, &sdk.CallToolParams{
return c.getSDKSession().CallTool(c.ctx, &sdk.CallToolParams{
Name: p.Name,
Arguments: p.Arguments,
})
Expand All @@ -455,7 +566,7 @@ func (c *Connection) callTool(params interface{}) (*Response, error) {
func (c *Connection) listResources() (*Response, error) {
logConn.Printf("listResources: requesting resource list from backend serverID=%s", c.serverID)
return c.callListMethod(func() (interface{}, error) {
result, err := c.session.ListResources(c.ctx, &sdk.ListResourcesParams{})
result, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{})
if err == nil {
logConn.Printf("listResources: received %d resources from serverID=%s", len(result.Resources), c.serverID)
}
Expand All @@ -469,7 +580,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) {
}
return callParamMethod(c, params, func(p readResourceParams) (interface{}, error) {
logConn.Printf("readResource: reading resource uri=%s from serverID=%s", p.URI, c.serverID)
return c.session.ReadResource(c.ctx, &sdk.ReadResourceParams{
return c.getSDKSession().ReadResource(c.ctx, &sdk.ReadResourceParams{
URI: p.URI,
})
})
Expand All @@ -478,7 +589,7 @@ func (c *Connection) readResource(params interface{}) (*Response, error) {
func (c *Connection) listPrompts() (*Response, error) {
logConn.Printf("listPrompts: requesting prompt list from backend serverID=%s", c.serverID)
return c.callListMethod(func() (interface{}, error) {
result, err := c.session.ListPrompts(c.ctx, &sdk.ListPromptsParams{})
result, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{})
if err == nil {
logConn.Printf("listPrompts: received %d prompts from serverID=%s", len(result.Prompts), c.serverID)
}
Expand All @@ -493,7 +604,7 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) {
}
return callParamMethod(c, params, func(p getPromptParams) (interface{}, error) {
logConn.Printf("getPrompt: getting prompt name=%s from serverID=%s", p.Name, c.serverID)
return c.session.GetPrompt(c.ctx, &sdk.GetPromptParams{
return c.getSDKSession().GetPrompt(c.ctx, &sdk.GetPromptParams{
Name: p.Name,
Arguments: p.Arguments,
})
Expand All @@ -504,8 +615,8 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) {
func (c *Connection) Close() error {
logConn.Printf("Closing connection: serverID=%s, isHTTP=%v", c.serverID, c.isHTTP)
c.cancel()
if c.session != nil {
return c.session.Close()
if session := c.getSDKSession(); session != nil {
return session.Close()
}
return nil
}
112 changes: 80 additions & 32 deletions internal/mcp/http_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,24 @@ func isHTTPConnectionError(err error) bool {
return false
}

// isSessionNotFoundError checks if an error message indicates a backend MCP session has expired
// or is not found. This is used to detect when automatic reconnection to the backend is needed.
func isSessionNotFoundError(err error) bool {
if err == nil {
return false
}
return strings.Contains(strings.ToLower(err.Error()), "session not found")
}

// isSessionNotFoundHTTPResponse checks if an HTTP response indicates the backend session was not found.
// MCP backends return HTTP 404 with a "session not found" body when a session has expired.
func isSessionNotFoundHTTPResponse(statusCode int, body []byte) bool {
if statusCode != http.StatusNotFound {
return false
}
return strings.Contains(strings.ToLower(string(body)), "session not found")
}

// parseSSEResponse extracts JSON data from SSE-formatted response
// SSE format: "event: message\ndata: {json}\n\n"
func parseSSEResponse(body []byte) ([]byte, error) {
Expand Down Expand Up @@ -436,58 +454,47 @@ func (c *Connection) initializeHTTPSession() (string, error) {
return sessionID, nil
}

// sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server
// The ctx parameter is used to extract session ID for the Mcp-Session-Id header
func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) {
// Generate unique request ID using atomic counter
requestID := atomic.AddUint64(&requestIDCounter, 1)

// For tools/call, ensure arguments field always exists (MCP protocol requirement)
if method == "tools/call" {
params = ensureToolCallArguments(params)
}

logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID)

// Execute HTTP request with custom header modification for session ID
result, err := c.executeHTTPRequest(ctx, method, params, requestID, func(httpReq *http.Request) {
// Add Mcp-Session-Id header with priority:
// 1) Context session ID (if explicitly provided for this request)
// 2) Stored httpSessionID from initialization
// buildSessionHeaderModifier returns a header modifier function that adds the Mcp-Session-Id header.
// Priority: context session ID > stored connection session ID.
// Context session IDs are static for the lifetime of a single request and are captured once at
// construction time. Connection session IDs can change during a reconnect, so getHTTPSessionID()
// is called at request time to always pick up the current value.
func (c *Connection) buildSessionHeaderModifier(ctx context.Context) func(*http.Request) {
// Capture any context-provided session ID once (it never changes for this request).
ctxSessionID, _ := ctx.Value(SessionIDContextKey).(string)
return func(httpReq *http.Request) {
var sessionID string
if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" {
if ctxSessionID != "" {
sessionID = ctxSessionID
logConn.Printf("Using session ID from context: %s", sessionID)
} else if c.httpSessionID != "" {
sessionID = c.httpSessionID
} else if id := c.getHTTPSessionID(); id != "" {
sessionID = id
logConn.Printf("Using stored session ID from initialization: %s", sessionID)
}

if sessionID != "" {
httpReq.Header.Set("Mcp-Session-Id", sessionID)
} else {
logConn.Printf("No session ID available (backend may not require session management)")
}
})
if err != nil {
return nil, err
}
}

logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody))

// Parse JSON-RPC response
// The response might be in SSE format (event: message\ndata: {...})
// parseHTTPResult converts a raw httpRequestResult into a JSON-RPC Response, handling non-OK
// HTTP status codes by synthesising a JSON-RPC error when the server did not provide one.
func parseHTTPResult(result *httpRequestResult) (*Response, error) {
// Parse JSON-RPC response.
// The response might be in SSE format (event: message\ndata: {...}).
rpcResponse, err := parseJSONRPCResponseWithSSE(result.ResponseBody, result.StatusCode, "JSON-RPC response")
if err != nil {
return nil, err
}

// Check for HTTP errors after parsing
// Check for HTTP errors after parsing.
// If we have a non-OK status but successfully parsed a JSON-RPC response,
// pass it through (it may already contain an error field)
// pass it through (it may already contain an error field).
if result.StatusCode != http.StatusOK {
Comment on lines +492 to 495

Copilot AI Mar 26, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment here suggests that for non-200 HTTP statuses we might still "successfully parse a JSON-RPC response" and pass it through. However parseJSONRPCResponseWithSSE currently returns a synthetic JSON-RPC error for any statusCode != 200 even when the body is valid JSON-RPC, so this branch/doc is misleading and may hide server-provided error details. Consider either (a) adjusting parseJSONRPCResponseWithSSE/parseHTTPResult to preserve parsed JSON-RPC error bodies on non-200 responses, or (b) updating the comment/logging and simplifying the non-OK handling accordingly.

Copilot uses AI. Check for mistakes.
logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", result.StatusCode)
// If the response doesn't already have an error, construct one
// If the response doesn't already have an error, construct one.
if rpcResponse.Error == nil {
rpcResponse.Error = &ResponseError{
Code: -32603, // Internal error
Expand All @@ -499,3 +506,44 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params

return rpcResponse, nil
}

// sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server.
// The ctx parameter is used to extract session ID for the Mcp-Session-Id header.
// If the backend returns a "session not found" (HTTP 404) response, it attempts a one-time
// session reconnect and retries the request transparently.
func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) {
// For tools/call, ensure arguments field always exists (MCP protocol requirement)
if method == "tools/call" {
params = ensureToolCallArguments(params)
}

headerModifier := c.buildSessionHeaderModifier(ctx)

requestID := atomic.AddUint64(&requestIDCounter, 1)
logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID)

result, err := c.executeHTTPRequest(ctx, method, params, requestID, headerModifier)
if err != nil {
return nil, err
}

logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody))

// If the backend reported that the session has expired, reconnect and retry once.
if isSessionNotFoundHTTPResponse(result.StatusCode, result.ResponseBody) {
logConn.Printf("Session not found from %s (serverID=%s), attempting reconnect", c.httpURL, c.serverID)
if reconnErr := c.reconnectPlainJSON(); reconnErr == nil {
requestID = atomic.AddUint64(&requestIDCounter, 1)
logConn.Printf("Retrying HTTP request after reconnect: method=%s, id=%d", method, requestID)
result, err = c.executeHTTPRequest(ctx, method, params, requestID, headerModifier)
if err != nil {
return nil, err
}
logConn.Printf("Retry HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody))
} else {
logConn.Printf("Session reconnect failed (%v), returning original session-not-found error", reconnErr)
}
}

return parseHTTPResult(result)
}
Loading
Loading