Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ import (

var logConn = logger.New("mcp:connection")

// isHTTPConnectionError checks if an error is a network connection error
// This helper reduces code duplication for checking common connection error patterns.
// Note: Uses string matching which is fragile but consistent with existing patterns in the codebase.
// TODO: Consider using errors.Is() or type assertions (*net.OpError) for more robust error classification.
func isHTTPConnectionError(err error) bool {
if err == nil {
return false
}
errMsg := err.Error()
return strings.Contains(errMsg, "connection refused") ||
strings.Contains(errMsg, "no such host") ||
strings.Contains(errMsg, "network is unreachable")
}

// parseSSEResponse extracts JSON data from SSE-formatted response
// SSE format: "event: message\ndata: {json}\n\n"
func parseSSEResponse(body []byte) ([]byte, error) {
Expand Down Expand Up @@ -584,9 +598,7 @@ func (c *Connection) initializeHTTPSession() (string, error) {
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
// Check if it's a connection error (cannot connect at all)
if strings.Contains(err.Error(), "connection refused") ||
strings.Contains(err.Error(), "no such host") ||
strings.Contains(err.Error(), "network is unreachable") {
if isHTTPConnectionError(err) {
return "", fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err)
}
return "", fmt.Errorf("failed to send initialize request to %s: %w", c.httpURL, err)
Expand Down Expand Up @@ -698,9 +710,7 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
// Check if it's a connection error (cannot connect at all)
if strings.Contains(err.Error(), "connection refused") ||
strings.Contains(err.Error(), "no such host") ||
strings.Contains(err.Error(), "network is unreachable") {
if isHTTPConnectionError(err) {
return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err)
}
return nil, fmt.Errorf("failed to send HTTP request to %s: %w", c.httpURL, err)
Expand Down
47 changes: 47 additions & 0 deletions internal/mcp/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,3 +703,50 @@ func TestNewHTTPConnection(t *testing.T) {
assert.Equal(t, httpClient, conn.httpClient, "HTTP client should match")
assert.Equal(t, HTTPTransportStreamable, conn.httpTransportType, "Transport type should match")
}

// TestIsHTTPConnectionError tests the HTTP connection error detection helper
func TestIsHTTPConnectionError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "connection refused error",
err: fmt.Errorf("dial tcp: connection refused"),
expected: true,
},
{
name: "no such host error",
err: fmt.Errorf("dial tcp: lookup example.invalid: no such host"),
expected: true,
},
{
name: "network is unreachable error",
err: fmt.Errorf("dial tcp: network is unreachable"),
expected: true,
},
{
name: "other error",
err: fmt.Errorf("some other error"),
expected: false,
},
{
name: "timeout error",
err: fmt.Errorf("context deadline exceeded"),
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isHTTPConnectionError(tt.err)
assert.Equal(t, tt.expected, result, "isHTTPConnectionError should return %v for %v", tt.expected, tt.err)
})
}
}
39 changes: 39 additions & 0 deletions internal/server/call_tool_result_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"fmt"
"testing"

sdk "github.com/modelcontextprotocol/go-sdk/mcp"
Expand Down Expand Up @@ -151,3 +152,41 @@ func TestConvertToCallToolResult_NilCheck(t *testing.T) {

t.Log("✓ CallToolResult is properly non-nil and structured")
}

// TestNewErrorCallToolResult tests the error CallToolResult helper
func TestNewErrorCallToolResult(t *testing.T) {
tests := []struct {
name string
err error
expectError bool
}{
{
name: "simple error",
err: fmt.Errorf("test error"),
expectError: true,
},
{
name: "formatted error",
err: fmt.Errorf("formatted error: %s", "test"),
expectError: true,
Comment on lines +159 to +171

Copilot AI Feb 7, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expectError in the test case struct is never used (all assertions are unconditional). This is misleading/unused state—either remove the field or use it to drive assertions (e.g., to support a future nil-error case).

Suggested change
name string
err error
expectError bool
}{
{
name: "simple error",
err: fmt.Errorf("test error"),
expectError: true,
},
{
name: "formatted error",
err: fmt.Errorf("formatted error: %s", "test"),
expectError: true,
name string
err error
}{
{
name: "simple error",
err: fmt.Errorf("test error"),
},
{
name: "formatted error",
err: fmt.Errorf("formatted error: %s", "test"),

Copilot uses AI. Check for mistakes.
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, data, err := newErrorCallToolResult(tt.err)

// Verify the error is returned
assert.Equal(t, tt.err, err, "Error should be returned as-is")

// Verify data is nil
assert.Nil(t, data, "Data should be nil for error results")

// Verify CallToolResult is properly structured
require.NotNil(t, result, "CallToolResult should not be nil")
assert.True(t, result.IsError, "IsError should be true")

t.Logf("✓ Error CallToolResult properly created with IsError=%v", result.IsError)
})
}
}
39 changes: 23 additions & 16 deletions internal/server/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error {
toolArgs, err := parseToolArguments(req)
if err != nil {
logger.LogError("client", "Failed to unmarshal tool arguments, tool=%s, error=%v", toolNameCopy, err)
return &sdk.CallToolResult{IsError: true}, nil, err
return newErrorCallToolResult(err)
}

// Log the MCP tool call request
Expand All @@ -331,7 +331,7 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error {
// Check session is initialized
if err := us.requireSession(ctx); err != nil {
logger.LogError("client", "MCP tool call failed: session not initialized, session=%s, tool=%s", sessionID, toolNameCopy)
return &sdk.CallToolResult{IsError: true}, nil, err
return newErrorCallToolResult(err)
}

result, data, err := us.callBackendTool(ctx, serverIDCopy, toolNameCopy, toolArgs)
Expand Down Expand Up @@ -395,7 +395,7 @@ func (us *UnifiedServer) registerSysTools() error {
toolArgs, err := parseToolArguments(req)
if err != nil {
logger.LogError("client", "Failed to unmarshal sys_init arguments, error=%v", err)
return &sdk.CallToolResult{IsError: true}, nil, err
return newErrorCallToolResult(err)
}

// Extract token from args
Expand All @@ -408,7 +408,7 @@ func (us *UnifiedServer) registerSysTools() error {
sessionID := us.getSessionID(ctx)
if sessionID == "" {
logger.LogError("client", "MCP session initialization failed: no session ID provided")
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("no session ID provided")
return newErrorCallToolResult(fmt.Errorf("no session ID provided"))
}

logger.LogInfo("client", "MCP session initialization started, session=%s, has_token=%v", sessionID, token != "")
Expand Down Expand Up @@ -436,7 +436,7 @@ func (us *UnifiedServer) registerSysTools() error {
result, err := us.sysServer.HandleRequest("tools/call", json.RawMessage(params))
if err != nil {
logger.LogError("client", "MCP session initialization: sys_init call failed, session=%s, error=%v", sessionID, err)
return &sdk.CallToolResult{IsError: true}, nil, err
return newErrorCallToolResult(err)
}

resultJSON, _ := json.Marshal(result)
Expand Down Expand Up @@ -486,7 +486,7 @@ func (us *UnifiedServer) registerSysTools() error {
// Check session is initialized
if err := us.requireSession(ctx); err != nil {
logger.LogError("client", "MCP sys_list_servers failed: session not initialized, session=%s", sessionID)
return &sdk.CallToolResult{IsError: true}, nil, err
return newErrorCallToolResult(err)
}

params, _ := json.Marshal(map[string]interface{}{
Expand All @@ -496,7 +496,7 @@ func (us *UnifiedServer) registerSysTools() error {
result, err := us.sysServer.HandleRequest("tools/call", json.RawMessage(params))
if err != nil {
logger.LogError("client", "MCP sys_list_servers error, session=%s, error=%v", sessionID, err)
return &sdk.CallToolResult{IsError: true}, nil, err
return newErrorCallToolResult(err)
}

resultJSON, _ := json.Marshal(result)
Expand Down Expand Up @@ -657,7 +657,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
resource, operation, err := g.LabelResource(ctx, toolName, args, backendCaller, us.capabilities)
if err != nil {
log.Printf("[DIFC] Guard labeling failed: %v", err)
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("guard labeling failed: %w", err)
return newErrorCallToolResult(fmt.Errorf("guard labeling failed: %w", err))
}

log.Printf("[DIFC] Resource: %s | Operation: %s | Secrecy: %v | Integrity: %v",
Expand Down Expand Up @@ -688,33 +688,33 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
sessionID := us.getSessionID(ctx)
conn, err := launcher.GetOrLaunchForSession(us.launcher, serverID, sessionID)
if err != nil {
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to connect: %w", err)
return newErrorCallToolResult(fmt.Errorf("failed to connect: %w", err))
}

response, err := conn.SendRequestWithServerID(ctx, "tools/call", map[string]interface{}{
"name": toolName,
"arguments": args,
}, serverID)
if err != nil {
return &sdk.CallToolResult{IsError: true}, nil, err
return newErrorCallToolResult(err)
}

// Check if the backend returned an error
if response.Error != nil {
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("backend error: code=%d, message=%s", response.Error.Code, response.Error.Message)
return newErrorCallToolResult(fmt.Errorf("backend error: code=%d, message=%s", response.Error.Code, response.Error.Message))
}

// Parse the backend result
var backendResult interface{}
if err := json.Unmarshal(response.Result, &backendResult); err != nil {
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to parse result: %w", err)
return newErrorCallToolResult(fmt.Errorf("failed to parse result: %w", err))
}

// **Phase 4: Guard labels the response data (for fine-grained filtering)**
labeledData, err := g.LabelResponse(ctx, toolName, backendResult, backendCaller, us.capabilities)
if err != nil {
log.Printf("[DIFC] Response labeling failed: %v", err)
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("response labeling failed: %w", err)
return newErrorCallToolResult(fmt.Errorf("response labeling failed: %w", err))
}

// **Phase 5: Reference Monitor performs fine-grained filtering (if applicable)**
Expand All @@ -735,13 +735,13 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
// Convert filtered data to result
finalResult, err = filtered.ToResult()
if err != nil {
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to convert filtered data: %w", err)
return newErrorCallToolResult(fmt.Errorf("failed to convert filtered data: %w", err))
}
} else {
// Simple labeled data - already passed coarse-grained check
finalResult, err = labeledData.ToResult()
if err != nil {
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to convert labeled data: %w", err)
return newErrorCallToolResult(fmt.Errorf("failed to convert labeled data: %w", err))
}
}

Expand All @@ -767,7 +767,7 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
// Convert finalResult to SDK CallToolResult format
callResult, err := convertToCallToolResult(finalResult)
if err != nil {
return &sdk.CallToolResult{IsError: true}, nil, fmt.Errorf("failed to convert result: %w", err)
return newErrorCallToolResult(fmt.Errorf("failed to convert result: %w", err))
}

return callResult, finalResult, nil
Expand All @@ -779,6 +779,13 @@ func (us *UnifiedServer) Run(transport sdk.Transport) error {
return us.server.Run(us.ctx, transport)
}

// newErrorCallToolResult creates a standard error CallToolResult
// This helper reduces code duplication for error returns following the pattern:
// return &sdk.CallToolResult{IsError: true}, nil, err
func newErrorCallToolResult(err error) (*sdk.CallToolResult, interface{}, error) {
return &sdk.CallToolResult{IsError: true}, nil, err
}

// parseToolArguments extracts and unmarshals tool arguments from a CallToolRequest
// Returns the parsed arguments as a map, or an error if parsing fails
func parseToolArguments(req *sdk.CallToolRequest) (map[string]interface{}, error) {
Expand Down
Loading