diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index 6ab009b3e..269056cbd 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -23,6 +23,20 @@ import ( var logConn = logger.New("mcp:connection") +// isHTTPConnectionError checks if an error is a network connection error +// This helper reduces code duplication for checking common connection error patterns. +// Note: Uses string matching which is fragile but consistent with existing patterns in the codebase. +// TODO: Consider using errors.Is() or type assertions (*net.OpError) for more robust error classification. +func isHTTPConnectionError(err error) bool { + if err == nil { + return false + } + errMsg := err.Error() + return strings.Contains(errMsg, "connection refused") || + strings.Contains(errMsg, "no such host") || + strings.Contains(errMsg, "network is unreachable") +} + // parseSSEResponse extracts JSON data from SSE-formatted response // SSE format: "event: message\ndata: {json}\n\n" func parseSSEResponse(body []byte) ([]byte, error) { @@ -584,9 +598,7 @@ func (c *Connection) initializeHTTPSession() (string, error) { httpResp, err := c.httpClient.Do(httpReq) if err != nil { // Check if it's a connection error (cannot connect at all) - if strings.Contains(err.Error(), "connection refused") || - strings.Contains(err.Error(), "no such host") || - strings.Contains(err.Error(), "network is unreachable") { + if isHTTPConnectionError(err) { return "", fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err) } return "", fmt.Errorf("failed to send initialize request to %s: %w", c.httpURL, err) @@ -698,9 +710,7 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params httpResp, err := c.httpClient.Do(httpReq) if err != nil { // Check if it's a connection error (cannot connect at all) - if strings.Contains(err.Error(), "connection refused") || - strings.Contains(err.Error(), "no such host") || - strings.Contains(err.Error(), "network is unreachable") { + if isHTTPConnectionError(err) { return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err) } return nil, fmt.Errorf("failed to send HTTP request to %s: %w", c.httpURL, err) diff --git a/internal/mcp/connection_test.go b/internal/mcp/connection_test.go index 9ee392c42..dadcaf180 100644 --- a/internal/mcp/connection_test.go +++ b/internal/mcp/connection_test.go @@ -703,3 +703,50 @@ func TestNewHTTPConnection(t *testing.T) { assert.Equal(t, httpClient, conn.httpClient, "HTTP client should match") assert.Equal(t, HTTPTransportStreamable, conn.httpTransportType, "Transport type should match") } + +// TestIsHTTPConnectionError tests the HTTP connection error detection helper +func TestIsHTTPConnectionError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "connection refused error", + err: fmt.Errorf("dial tcp: connection refused"), + expected: true, + }, + { + name: "no such host error", + err: fmt.Errorf("dial tcp: lookup example.invalid: no such host"), + expected: true, + }, + { + name: "network is unreachable error", + err: fmt.Errorf("dial tcp: network is unreachable"), + expected: true, + }, + { + name: "other error", + err: fmt.Errorf("some other error"), + expected: false, + }, + { + name: "timeout error", + err: fmt.Errorf("context deadline exceeded"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isHTTPConnectionError(tt.err) + assert.Equal(t, tt.expected, result, "isHTTPConnectionError should return %v for %v", tt.expected, tt.err) + }) + } +} diff --git a/internal/server/call_tool_result_test.go b/internal/server/call_tool_result_test.go index 90b097688..0fda5da79 100644 --- a/internal/server/call_tool_result_test.go +++ b/internal/server/call_tool_result_test.go @@ -1,6 +1,7 @@ package server import ( + "fmt" "testing" sdk "github.com/modelcontextprotocol/go-sdk/mcp" @@ -151,3 +152,41 @@ func TestConvertToCallToolResult_NilCheck(t *testing.T) { t.Log("✓ CallToolResult is properly non-nil and structured") } + +// TestNewErrorCallToolResult tests the error CallToolResult helper +func TestNewErrorCallToolResult(t *testing.T) { + tests := []struct { + name string + err error + expectError bool + }{ + { + name: "simple error", + err: fmt.Errorf("test error"), + expectError: true, + }, + { + name: "formatted error", + err: fmt.Errorf("formatted error: %s", "test"), + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, data, err := newErrorCallToolResult(tt.err) + + // Verify the error is returned + assert.Equal(t, tt.err, err, "Error should be returned as-is") + + // Verify data is nil + assert.Nil(t, data, "Data should be nil for error results") + + // Verify CallToolResult is properly structured + require.NotNil(t, result, "CallToolResult should not be nil") + assert.True(t, result.IsError, "IsError should be true") + + t.Logf("✓ Error CallToolResult properly created with IsError=%v", result.IsError) + }) + } +} diff --git a/internal/server/unified.go b/internal/server/unified.go index a994ab1f8..a539124cb 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -320,7 +320,7 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { toolArgs, err := parseToolArguments(req) if err != nil { logger.LogError("client", "Failed to unmarshal tool arguments, tool=%s, error=%v", toolNameCopy, err) - return &sdk.CallToolResult{IsError: true}, nil, err + return newErrorCallToolResult(err) } // Log the MCP tool call request @@ -331,7 +331,7 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { // Check session is initialized if err := us.requireSession(ctx); err != nil { logger.LogError("client", "MCP tool call failed: session not initialized, session=%s, tool=%s", sessionID, toolNameCopy) - return &sdk.CallToolResult{IsError: true}, nil, err + return newErrorCallToolResult(err) } result, data, err := us.callBackendTool(ctx, serverIDCopy, toolNameCopy, toolArgs) @@ -395,7 +395,7 @@ func (us *UnifiedServer) registerSysTools() error { toolArgs, err := parseToolArguments(req) if err != nil { logger.LogError("client", "Failed to unmarshal sys_init arguments, error=%v", err) - return &sdk.CallToolResult{IsError: true}, nil, err + return newErrorCallToolResult(err) } // Extract token from args @@ -408,7 +408,7 @@ func (us *UnifiedServer) registerSysTools() error { sessionID := us.getSessionID(ctx) if sessionID == "" { logger.LogError("client", "MCP session initialization failed: no session ID provided") - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("no session ID provided") + return newErrorCallToolResult(fmt.Errorf("no session ID provided")) } logger.LogInfo("client", "MCP session initialization started, session=%s, has_token=%v", sessionID, token != "") @@ -436,7 +436,7 @@ func (us *UnifiedServer) registerSysTools() error { result, err := us.sysServer.HandleRequest("tools/call", json.RawMessage(params)) if err != nil { logger.LogError("client", "MCP session initialization: sys_init call failed, session=%s, error=%v", sessionID, err) - return &sdk.CallToolResult{IsError: true}, nil, err + return newErrorCallToolResult(err) } resultJSON, _ := json.Marshal(result) @@ -486,7 +486,7 @@ func (us *UnifiedServer) registerSysTools() error { // Check session is initialized if err := us.requireSession(ctx); err != nil { logger.LogError("client", "MCP sys_list_servers failed: session not initialized, session=%s", sessionID) - return &sdk.CallToolResult{IsError: true}, nil, err + return newErrorCallToolResult(err) } params, _ := json.Marshal(map[string]interface{}{ @@ -496,7 +496,7 @@ func (us *UnifiedServer) registerSysTools() error { result, err := us.sysServer.HandleRequest("tools/call", json.RawMessage(params)) if err != nil { logger.LogError("client", "MCP sys_list_servers error, session=%s, error=%v", sessionID, err) - return &sdk.CallToolResult{IsError: true}, nil, err + return newErrorCallToolResult(err) } resultJSON, _ := json.Marshal(result) @@ -657,7 +657,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName resource, operation, err := g.LabelResource(ctx, toolName, args, backendCaller, us.capabilities) if err != nil { log.Printf("[DIFC] Guard labeling failed: %v", err) - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("guard labeling failed: %w", err) + return newErrorCallToolResult(fmt.Errorf("guard labeling failed: %w", err)) } log.Printf("[DIFC] Resource: %s | Operation: %s | Secrecy: %v | Integrity: %v", @@ -688,7 +688,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName sessionID := us.getSessionID(ctx) conn, err := launcher.GetOrLaunchForSession(us.launcher, serverID, sessionID) if err != nil { - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to connect: %w", err) + return newErrorCallToolResult(fmt.Errorf("failed to connect: %w", err)) } response, err := conn.SendRequestWithServerID(ctx, "tools/call", map[string]interface{}{ @@ -696,25 +696,25 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName "arguments": args, }, serverID) if err != nil { - return &sdk.CallToolResult{IsError: true}, nil, err + return newErrorCallToolResult(err) } // Check if the backend returned an error if response.Error != nil { - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("backend error: code=%d, message=%s", response.Error.Code, response.Error.Message) + return newErrorCallToolResult(fmt.Errorf("backend error: code=%d, message=%s", response.Error.Code, response.Error.Message)) } // Parse the backend result var backendResult interface{} if err := json.Unmarshal(response.Result, &backendResult); err != nil { - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to parse result: %w", err) + return newErrorCallToolResult(fmt.Errorf("failed to parse result: %w", err)) } // **Phase 4: Guard labels the response data (for fine-grained filtering)** labeledData, err := g.LabelResponse(ctx, toolName, backendResult, backendCaller, us.capabilities) if err != nil { log.Printf("[DIFC] Response labeling failed: %v", err) - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("response labeling failed: %w", err) + return newErrorCallToolResult(fmt.Errorf("response labeling failed: %w", err)) } // **Phase 5: Reference Monitor performs fine-grained filtering (if applicable)** @@ -735,13 +735,13 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName // Convert filtered data to result finalResult, err = filtered.ToResult() if err != nil { - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to convert filtered data: %w", err) + return newErrorCallToolResult(fmt.Errorf("failed to convert filtered data: %w", err)) } } else { // Simple labeled data - already passed coarse-grained check finalResult, err = labeledData.ToResult() if err != nil { - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to convert labeled data: %w", err) + return newErrorCallToolResult(fmt.Errorf("failed to convert labeled data: %w", err)) } } @@ -767,7 +767,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName // Convert finalResult to SDK CallToolResult format callResult, err := convertToCallToolResult(finalResult) if err != nil { - return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to convert result: %w", err) + return newErrorCallToolResult(fmt.Errorf("failed to convert result: %w", err)) } return callResult, finalResult, nil @@ -779,6 +779,13 @@ func (us *UnifiedServer) Run(transport sdk.Transport) error { return us.server.Run(us.ctx, transport) } +// newErrorCallToolResult creates a standard error CallToolResult +// This helper reduces code duplication for error returns following the pattern: +// return &sdk.CallToolResult{IsError: true}, nil, err +func newErrorCallToolResult(err error) (*sdk.CallToolResult, interface{}, error) { + return &sdk.CallToolResult{IsError: true}, nil, err +} + // parseToolArguments extracts and unmarshals tool arguments from a CallToolRequest // Returns the parsed arguments as a map, or an error if parsing fails func parseToolArguments(req *sdk.CallToolRequest) (map[string]interface{}, error) {