Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,15 @@ func NewConnection(ctx context.Context, serverID, command string, args []string,
// For HTTP servers that are already running, we connect and initialize a session
//
// This function implements a fallback strategy for HTTP transports:
// 1. If custom headers are provided, skip SDK transports (they don't support custom headers)
// and use plain JSON-RPC 2.0 over HTTP POST (for safeinputs compatibility)
// 2. Otherwise, try standard transports:
// 1. Try standard transports in order:
// a. Streamable HTTP (2025-03-26 spec) using SDK's StreamableClientTransport
// b. SSE (2024-11-05 spec) using SDK's SSEClientTransport
// c. Plain JSON-RPC 2.0 over HTTP POST as final fallback
//
// Custom headers (e.g. Authorization) are injected into every outgoing request via a
// custom http.RoundTripper, so the SDK transports are used even when authentication
// headers are configured.
//
// This ensures compatibility with all types of HTTP MCP servers.
func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[string]string) (*Connection, error) {
logger.LogInfo("backend", "Creating HTTP MCP connection with transport fallback, url=%s", url)
Expand All @@ -206,26 +208,16 @@ func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[st
},
}

// If custom headers are provided, skip SDK transports as they don't support headers
// This is typical for backends like safeinputs that require authentication
if len(headers) > 0 {
logConn.Printf("Custom headers detected, using plain JSON-RPC transport for %s", url)
conn, err := tryPlainJSONTransport(ctx, cancel, serverID, url, headers, httpClient)
if err == nil {
logger.LogInfo("backend", "Successfully connected using plain JSON-RPC transport, url=%s", url)
log.Printf("Configured HTTP MCP server with plain JSON-RPC transport: %s", url)
return conn, nil
}
cancel()
logger.LogError("backend", "Plain JSON-RPC transport failed for url=%s, error=%v", url, err)
return nil, fmt.Errorf("failed to connect with plain JSON-RPC transport: %w", err)
}
// Build a header-injecting client so that all SDK transports send custom headers
// (e.g. Authorization) on every request. When no headers are configured the
// original client is returned unchanged.
headerClient := buildHTTPClientWithHeaders(httpClient, headers)

// Try standard transports in order: streamable HTTP → SSE → plain JSON-RPC

// Try 1: Streamable HTTP (2025-03-26 spec)
logConn.Printf("Attempting streamable HTTP transport for %s", url)
conn, err := tryStreamableHTTPTransport(ctx, cancel, serverID, url, headers, httpClient)
conn, err := tryStreamableHTTPTransport(ctx, cancel, serverID, url, headers, headerClient)
if err == nil {
logger.LogInfo("backend", "Successfully connected using streamable HTTP transport, url=%s", url)
log.Printf("Configured HTTP MCP server with streamable transport: %s", url)
Expand All @@ -235,7 +227,7 @@ func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[st

// Try 2: SSE (2024-11-05 spec)
logConn.Printf("Attempting SSE transport for %s", url)
conn, err = trySSETransport(ctx, cancel, serverID, url, headers, httpClient)
conn, err = trySSETransport(ctx, cancel, serverID, url, headers, headerClient)
if err == nil {
logger.LogWarn("backend", "⚠️ MCP over SSE has been deprecated. Connected using SSE transport for url=%s. Please migrate to streamable HTTP transport (2025-03-26 spec).", url)
log.Printf("⚠️ WARNING: MCP over SSE (2024-11-05 spec) has been DEPRECATED")
Expand Down Expand Up @@ -313,20 +305,23 @@ func (c *Connection) reconnectSDKTransport() error {
_ = c.session.Close()
}

// Rebuild the header-injecting client so custom auth headers are preserved on reconnect.
headerClient := buildHTTPClientWithHeaders(c.httpClient, c.headers)

// Build the appropriate transport.
client := newMCPClient(logConn)
var transport sdk.Transport
switch c.httpTransportType {
case HTTPTransportStreamable:
transport = &sdk.StreamableClientTransport{
Endpoint: c.httpURL,
HTTPClient: c.httpClient,
HTTPClient: headerClient,
MaxRetries: 0,
}
case HTTPTransportSSE:
transport = &sdk.SSEClientTransport{
Endpoint: c.httpURL,
HTTPClient: c.httpClient,
HTTPClient: headerClient,
}
default:
return fmt.Errorf("cannot reconnect: unsupported transport type %s", c.httpTransportType)
Expand Down
12 changes: 11 additions & 1 deletion internal/mcp/connection_arguments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,19 @@ func TestCallTool_ArgumentsPassed(t *testing.T) {
bodyBytes, err := io.ReadAll(r.Body)
require.NoError(t, err, "Failed to read request body")

// Ignore requests with empty or non-JSON bodies (e.g. GET/DELETE from
// the Streamable HTTP transport during session lifecycle management).
if len(bodyBytes) == 0 {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}

var request map[string]interface{}
err = json.Unmarshal(bodyBytes, &request)
require.NoError(t, err, "Failed to parse request JSON")
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}

method, _ := request["method"].(string)

Expand Down
8 changes: 7 additions & 1 deletion internal/mcp/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,14 @@ func TestHTTPRequest_ErrorResponses(t *testing.T) {
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
// Silently reject empty-body requests (e.g. GET/DELETE from Streamable
// transport during session lifecycle); they are not part of this test.
if len(bodyBytes) == 0 {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if err := json.Unmarshal(bodyBytes, &reqBody); err != nil {
t.Errorf("Failed to unmarshal request body: %v", err)
// Silently reject non-JSON bodies (probe requests from SDK transports).
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
Expand Down
35 changes: 20 additions & 15 deletions internal/mcp/http_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@ import (
"github.com/stretchr/testify/require"
)

// TestNewHTTPConnection_WithCustomHeaders tests that custom headers skip SDK transports
// and use plain JSON-RPC transport directly
// TestNewHTTPConnection_WithCustomHeaders tests that custom headers are injected into the
// SDK-managed Streamable HTTP transport (not bypassed to plain JSON-RPC).
func TestNewHTTPConnection_WithCustomHeaders(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

// Track which transport was attempted
// Track which requests were received
serverCallCount := 0

// Create test server that responds to initialize requests
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
serverCallCount++

// Verify custom headers are present
// Verify custom headers are injected into every request
assert.Equal("test-auth-token", r.Header.Get("Authorization"))
assert.Equal("custom-value", r.Header.Get("X-Custom-Header"))

Expand Down Expand Up @@ -62,11 +62,12 @@ func TestNewHTTPConnection_WithCustomHeaders(t *testing.T) {
// Verify connection properties
assert.True(conn.IsHTTP(), "Connection should be HTTP")
assert.Equal(testServer.URL, conn.GetHTTPURL())
assert.Equal(HTTPTransportPlainJSON, conn.httpTransportType, "Should use plain JSON transport")
// Custom headers are now injected via RoundTripper so the SDK Streamable transport is used
assert.Equal(HTTPTransportStreamable, conn.httpTransportType, "Should use Streamable HTTP transport even with custom headers")
assert.Equal("session-123", conn.httpSessionID, "Session ID should be captured")

// Verify only one call was made (plain JSON transport, no fallback attempts)
assert.Equal(1, serverCallCount, "Should only attempt plain JSON transport with custom headers")
// Verify at least one call was made (Streamable transport connects successfully)
assert.GreaterOrEqual(serverCallCount, 1, "Server should have received at least one request")
}

// TestNewHTTPConnection_WithoutHeaders_FallbackSequence tests connection without custom headers.
Expand Down Expand Up @@ -282,8 +283,10 @@ func TestTryPlainJSONTransport_InitializeFailure(t *testing.T) {
}
}

// TestTryPlainJSONTransport_SSEFormattedResponse tests handling of SSE-formatted responses
func TestTryPlainJSONTransport_SSEFormattedResponse(t *testing.T) {
// TestHTTPConnection_SSEFormattedResponse tests handling of SSE-formatted responses.
// Even when the server returns SSE-formatted data, the streamable HTTP transport
// (which is tried first) is able to handle it.
func TestHTTPConnection_SSEFormattedResponse(t *testing.T) {
require := require.New(t)

// Create test server that returns SSE-formatted initialize response
Expand Down Expand Up @@ -311,11 +314,14 @@ data: {"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","serverIn

// Verify session was captured
assert.Equal(t, "sse-session-456", conn.httpSessionID)
assert.Equal(t, HTTPTransportPlainJSON, conn.httpTransportType)
// The streamable transport handles SSE-formatted text/event-stream responses
assert.Equal(t, HTTPTransportStreamable, conn.httpTransportType)
}

// TestTryPlainJSONTransport_NoSessionIDInResponse tests handling when server doesn't return session ID
func TestTryPlainJSONTransport_NoSessionIDInResponse(t *testing.T) {
// TestHTTPConnection_NoSessionIDInResponse tests handling when server doesn't return session ID.
// When the streamable transport is used and the server omits Mcp-Session-Id, the
// connection still succeeds; the httpSessionID will be empty in that case.
func TestHTTPConnection_NoSessionIDInResponse(t *testing.T) {
require := require.New(t)

// Create test server that doesn't return Mcp-Session-Id header
Expand Down Expand Up @@ -347,9 +353,8 @@ func TestTryPlainJSONTransport_NoSessionIDInResponse(t *testing.T) {
require.NotNil(conn)
defer conn.Close()

// Should have a temporary session ID
assert.NotEmpty(t, conn.httpSessionID, "Should have temporary session ID")
assert.Contains(t, conn.httpSessionID, "awmg-init-", "Should be temporary session ID")
// Session ID may be empty when the server does not return one; the connection is still valid
assert.Equal(t, HTTPTransportStreamable, conn.httpTransportType)
}

// TestNewHTTPConnection_HeadersPropagation tests that custom headers are properly propagated
Expand Down
34 changes: 34 additions & 0 deletions internal/mcp/http_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,40 @@ func newHTTPConnection(ctx context.Context, cancel context.CancelFunc, client *s
}
}

// headerInjectingRoundTripper is an http.RoundTripper that injects a fixed set of
// HTTP headers into every outgoing request. It is used so that SDK-managed transports
// (StreamableClientTransport, SSEClientTransport) can send custom auth headers even
// though those transports do not expose a per-request header API.
type headerInjectingRoundTripper struct {
base http.RoundTripper
headers map[string]string
}

func (rt *headerInjectingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request so we don't mutate the caller's copy.
reqCopy := req.Clone(req.Context())
for k, v := range rt.headers {
reqCopy.Header.Set(k, v)
}
return rt.base.RoundTrip(reqCopy)
}

// buildHTTPClientWithHeaders returns a copy of baseClient whose transport injects
// the provided headers into every outgoing request. When headers is empty the
// original baseClient is returned unchanged.
func buildHTTPClientWithHeaders(baseClient *http.Client, headers map[string]string) *http.Client {
if len(headers) == 0 {
return baseClient
}
base := baseClient.Transport
if base == nil {
base = http.DefaultTransport
}
clone := *baseClient
clone.Transport = &headerInjectingRoundTripper{base: base, headers: headers}
return &clone
}

// createJSONRPCRequest creates a JSON-RPC 2.0 request map
func createJSONRPCRequest(requestID uint64, method string, params interface{}) map[string]interface{} {
return map[string]interface{}{
Expand Down
91 changes: 81 additions & 10 deletions internal/mcp/http_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -783,15 +783,14 @@ func TestSendHTTPRequest_SessionIDFromConnection(t *testing.T) {
}))
defer testServer.Close()

conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{
"Authorization": "test-token",
})
require.NoError(t, err)
// Use plain JSON transport directly: this test exercises the session-ID propagation
// logic in sendHTTPRequest, which is specific to the plain JSON-RPC code path.
conn := newPlainJSONConn(t, testServer.URL, map[string]string{"Authorization": "test-token"})
require.NotNil(t, conn)
defer conn.Close()

// No session ID in context - should use stored session from initialization
_, err = conn.sendHTTPRequest(context.Background(), "tools/list", nil)
_, err := conn.sendHTTPRequest(context.Background(), "tools/list", nil)
require.NoError(t, err)

require.Len(t, receivedSessionIDs, 1)
Expand Down Expand Up @@ -1069,15 +1068,87 @@ func TestSendHTTPRequest_NoReconnectOnOtherErrors(t *testing.T) {
}))
defer testServer.Close()

conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{
"Authorization": "test-token",
})
require.NoError(t, err)
// Use plain JSON transport directly: this test verifies the no-reconnect behaviour
// on 500 errors, which is specific to the plain JSON-RPC sendHTTPRequest path.
conn := newPlainJSONConn(t, testServer.URL, map[string]string{"Authorization": "test-token"})
require.NotNil(t, conn)
defer conn.Close()

_, err = conn.sendHTTPRequest(context.Background(), "tools/list", nil)
_, err := conn.sendHTTPRequest(context.Background(), "tools/list", nil)
require.NoError(t, err)

// initCount should be 1 (initial only) – no reconnect was attempted.
assert.Equal(t, 1, initCount, "should not reconnect on non-session-not-found errors")
}

// =============================================================================
// headerInjectingRoundTripper / buildHTTPClientWithHeaders tests
// =============================================================================

// TestHeaderInjectingRoundTripper verifies that every request made through the
// custom RoundTripper receives the configured headers.
func TestHeaderInjectingRoundTripper(t *testing.T) {
received := make(map[string]string)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
received["Authorization"] = r.Header.Get("Authorization")
received["X-Custom"] = r.Header.Get("X-Custom")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

rt := &headerInjectingRoundTripper{
base: http.DefaultTransport,
headers: map[string]string{
"Authorization": "Basic dXNlcjpwYXNz",
"X-Custom": "hello",
},
}
client := &http.Client{Transport: rt}

req, err := http.NewRequestWithContext(context.Background(), "GET", srv.URL, nil)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
resp.Body.Close()

assert.Equal(t, "Basic dXNlcjpwYXNz", received["Authorization"])
assert.Equal(t, "hello", received["X-Custom"])
}

// TestBuildHTTPClientWithHeaders_Empty verifies that an empty headers map returns
// the same client (pointer equality).
func TestBuildHTTPClientWithHeaders_Empty(t *testing.T) {
base := &http.Client{}
result := buildHTTPClientWithHeaders(base, nil)
assert.Same(t, base, result, "empty headers should return the original client unchanged")

result2 := buildHTTPClientWithHeaders(base, map[string]string{})
assert.Same(t, base, result2, "empty map should return the original client unchanged")
}

// TestBuildHTTPClientWithHeaders_NonEmpty verifies that a non-empty headers map
// returns a new client whose transport injects the headers.
func TestBuildHTTPClientWithHeaders_NonEmpty(t *testing.T) {
received := make(map[string]string)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
received["Authorization"] = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

base := &http.Client{Transport: http.DefaultTransport}
injected := buildHTTPClientWithHeaders(base, map[string]string{
"Authorization": "Bearer token123",
})
assert.NotSame(t, base, injected, "non-empty headers should return a new client")

req, err := http.NewRequestWithContext(context.Background(), "GET", srv.URL, nil)
require.NoError(t, err)

resp, err := injected.Do(req)
require.NoError(t, err)
resp.Body.Close()

assert.Equal(t, "Bearer token123", received["Authorization"])
}
Loading
Loading