Skip to content

Commit 9c974a3

Browse files
yroblataskbot
andauthored
Fix health-check close failing auth in vMCP BackendClient (#4613)
* Fix health-check close failing auth in vMCP BackendClient mcp-go's StreamableHTTP.Close() creates a DELETE request with context.Background(), discarding the health-check marker and identity from the original call context. The identityPropagatingRoundTripper only stored identity, so when ListCapabilities deferred c.Close(), the DELETE hit auth strategies with neither a health-check marker nor an identity, producing "no identity found in context" errors. Fix by capturing isHealthCheck at transport creation time and re-injecting both identity and the health-check marker into every outgoing request, including the synthetic DELETE from Close(). Fixes #4573 * fixes from review --------- Co-authored-by: taskbot <taskbot@users.noreply.github.com>
1 parent d851c69 commit 9c974a3

File tree

11 files changed

+223
-44
lines changed

11 files changed

+223
-44
lines changed

pkg/vmcp/auth/strategies/header_injection_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
"github.com/stretchr/testify/require"
1414

1515
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
16-
"github.com/stacklok/toolhive/pkg/vmcp/health"
16+
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
1717
)
1818

1919
func TestHeaderInjectionStrategy_Name(t *testing.T) {
@@ -43,7 +43,7 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) {
4343
HeaderValue: "secret-key-123",
4444
},
4545
},
46-
setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) },
46+
setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) },
4747
expectError: false,
4848
checkHeader: func(t *testing.T, req *http.Request) {
4949
t.Helper()

pkg/vmcp/auth/strategies/tokenexchange.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
"github.com/stacklok/toolhive/pkg/auth"
1919
"github.com/stacklok/toolhive/pkg/auth/tokenexchange"
2020
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
21-
"github.com/stacklok/toolhive/pkg/vmcp/health"
21+
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
2222
)
2323

2424
const (
@@ -107,7 +107,7 @@ func (s *TokenExchangeStrategy) Authenticate(
107107
// For health checks there is no user identity to exchange. If client credentials
108108
// are configured, use a client credentials grant to authenticate the probe request.
109109
// Otherwise skip authentication — the backend will be probed unauthenticated.
110-
if health.IsHealthCheck(ctx) {
110+
if healthcontext.IsHealthCheck(ctx) {
111111
if config.ClientID != "" && config.ClientSecret != "" {
112112
return s.authenticateWithClientCredentials(ctx, req, config)
113113
}

pkg/vmcp/auth/strategies/tokenexchange_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
"github.com/stacklok/toolhive-core/env/mocks"
1919
"github.com/stacklok/toolhive/pkg/auth"
2020
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
21-
"github.com/stacklok/toolhive/pkg/vmcp/health"
21+
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
2222
)
2323

2424
// Test constants
@@ -108,7 +108,7 @@ func TestTokenExchangeStrategy_Authenticate(t *testing.T) {
108108
}{
109109
{
110110
name: "health check without client credentials skips authentication",
111-
setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) },
111+
setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) },
112112
setupServer: func() *httptest.Server {
113113
return httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
114114
t.Error("token endpoint should not be called when no client credentials are configured")
@@ -125,7 +125,7 @@ func TestTokenExchangeStrategy_Authenticate(t *testing.T) {
125125
},
126126
{
127127
name: "health check with client credentials uses client credentials grant",
128-
setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) },
128+
setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) },
129129
setupServer: func() *httptest.Server {
130130
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
131131
t.Helper()

pkg/vmcp/auth/strategies/upstream_inject.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010

1111
"github.com/stacklok/toolhive/pkg/auth"
1212
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
13-
"github.com/stacklok/toolhive/pkg/vmcp/health"
13+
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
1414
)
1515

1616
// UpstreamInjectStrategy injects an upstream IDP token into backend request headers.
@@ -62,7 +62,7 @@ func (*UpstreamInjectStrategy) Authenticate(
6262
ctx context.Context, req *http.Request, strategy *authtypes.BackendAuthStrategy,
6363
) error {
6464
// Health checks have no user identity — skip authentication.
65-
if health.IsHealthCheck(ctx) {
65+
if healthcontext.IsHealthCheck(ctx) {
6666
return nil
6767
}
6868

pkg/vmcp/auth/strategies/upstream_inject_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"github.com/stretchr/testify/require"
1515

1616
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" // BackendAuthStrategy, ErrUpstreamTokenNotFound
17-
"github.com/stacklok/toolhive/pkg/vmcp/health"
17+
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
1818
)
1919

2020
func TestUpstreamInjectStrategy_Name(t *testing.T) {
@@ -118,7 +118,7 @@ func TestUpstreamInjectStrategy_Authenticate(t *testing.T) {
118118
},
119119
{
120120
name: "health check bypass",
121-
setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) },
121+
setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) },
122122
strategy: &authtypes.BackendAuthStrategy{
123123
Type: authtypes.StrategyTypeUpstreamInject,
124124
UpstreamInject: &authtypes.UpstreamInjectConfig{

pkg/vmcp/client/client.go

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
3333
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
3434
"github.com/stacklok/toolhive/pkg/vmcp/conversion"
35+
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
3536
)
3637

3738
const (
@@ -168,19 +169,31 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
168169
return f(req)
169170
}
170171

171-
// identityPropagatingRoundTripper propagates identity to backend HTTP requests.
172+
// identityPropagatingRoundTripper propagates identity and health-check markers to backend HTTP requests.
172173
// This ensures that identity information from the vMCP handler is available for authentication
173174
// strategies that need it (e.g., token exchange).
175+
//
176+
// The health-check marker is stored at transport creation time and re-injected into every
177+
// outgoing request, including the DELETE that mcp-go sends when closing a streamable-HTTP
178+
// session. Without this, mcp-go's Close() creates a fresh context.Background()-based request
179+
// that loses the health-check marker, causing auth strategies (UpstreamInjectStrategy,
180+
// TokenExchangeStrategy) to fail with "no identity found in context".
174181
type identityPropagatingRoundTripper struct {
175-
base http.RoundTripper
176-
identity *auth.Identity
182+
base http.RoundTripper
183+
identity *auth.Identity
184+
isHealthCheck bool
177185
}
178186

179-
// RoundTrip implements http.RoundTripper by adding identity to the request context.
187+
// RoundTrip implements http.RoundTripper by adding identity and health-check marker to the request context.
180188
func (i *identityPropagatingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
189+
ctx := req.Context()
181190
if i.identity != nil {
182-
// Add identity to the request's context
183-
ctx := auth.WithIdentity(req.Context(), i.identity)
191+
ctx = auth.WithIdentity(ctx, i.identity)
192+
}
193+
if i.isHealthCheck {
194+
ctx = healthcontext.WithHealthCheckMarker(ctx)
195+
}
196+
if i.identity != nil || i.isHealthCheck {
184197
req = req.Clone(ctx)
185198
}
186199
return i.base.RoundTrip(req)
@@ -283,12 +296,16 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm
283296
target: target,
284297
}
285298

286-
// Extract identity from context and propagate it to backend requests
287-
// This ensures authentication strategies (e.g., token exchange) can access identity
299+
// Extract identity and health-check marker from context and propagate them to backend
300+
// requests. The health-check marker must be carried through to the DELETE request that
301+
// mcp-go emits when closing a streamable-HTTP session: mcp-go creates that request with
302+
// context.Background(), which loses both the identity and the health-check marker that
303+
// were present on the original ListCapabilities call context.
288304
identity, _ := auth.IdentityFromContext(ctx)
289305
baseTransport = &identityPropagatingRoundTripper{
290-
base: baseTransport,
291-
identity: identity,
306+
base: baseTransport,
307+
identity: identity,
308+
isHealthCheck: healthcontext.IsHealthCheck(ctx),
292309
}
293310

294311
// Inject W3C Trace Context headers (traceparent/tracestate) into outgoing requests.

pkg/vmcp/client/client_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ import (
3333
"go.opentelemetry.io/otel/trace"
3434
"go.uber.org/mock/gomock"
3535

36+
pkgauth "github.com/stacklok/toolhive/pkg/auth"
3637
"github.com/stacklok/toolhive/pkg/vmcp"
3738
"github.com/stacklok/toolhive/pkg/vmcp/auth"
3839
authmocks "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks"
3940
"github.com/stacklok/toolhive/pkg/vmcp/auth/strategies"
4041
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
42+
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
4143
)
4244

4345
func TestHTTPBackendClient_ListCapabilities_WithMockFactory(t *testing.T) {
@@ -1020,3 +1022,125 @@ func TestWrapBackendError(t *testing.T) {
10201022
})
10211023
}
10221024
}
1025+
1026+
// ---------------------------------------------------------------------------
1027+
// identityPropagatingRoundTripper
1028+
// ---------------------------------------------------------------------------
1029+
1030+
func TestIdentityPropagatingRoundTripper_WithIdentity_PropagatesIdentityInContext(t *testing.T) {
1031+
t.Parallel()
1032+
1033+
base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
1034+
identity := &pkgauth.Identity{PrincipalInfo: pkgauth.PrincipalInfo{Subject: "user-1"}}
1035+
rt := &identityPropagatingRoundTripper{base: base, identity: identity}
1036+
1037+
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil)
1038+
require.NoError(t, err)
1039+
1040+
_, err = rt.RoundTrip(req)
1041+
require.NoError(t, err)
1042+
1043+
require.NotNil(t, base.capturedReq)
1044+
got, ok := pkgauth.IdentityFromContext(base.capturedReq.Context())
1045+
require.True(t, ok, "identity should be in downstream request context")
1046+
assert.Equal(t, "user-1", got.Subject)
1047+
}
1048+
1049+
func TestIdentityPropagatingRoundTripper_NilIdentity_NoIdentityInContext(t *testing.T) {
1050+
t.Parallel()
1051+
1052+
base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
1053+
rt := &identityPropagatingRoundTripper{base: base, identity: nil}
1054+
1055+
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil)
1056+
require.NoError(t, err)
1057+
1058+
_, err = rt.RoundTrip(req)
1059+
require.NoError(t, err)
1060+
1061+
require.NotNil(t, base.capturedReq)
1062+
_, ok := pkgauth.IdentityFromContext(base.capturedReq.Context())
1063+
assert.False(t, ok, "no identity should be in downstream context when nil identity configured")
1064+
}
1065+
1066+
func TestIdentityPropagatingRoundTripper_HealthCheck_PropagatesMarker(t *testing.T) {
1067+
t.Parallel()
1068+
1069+
base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
1070+
rt := &identityPropagatingRoundTripper{base: base, identity: nil, isHealthCheck: true}
1071+
1072+
// Simulate mcp-go Close(): request created with context.Background(), no health check marker.
1073+
req, err := http.NewRequestWithContext(context.Background(), http.MethodDelete, "http://backend.example.com/mcp", nil)
1074+
require.NoError(t, err)
1075+
1076+
_, err = rt.RoundTrip(req)
1077+
require.NoError(t, err)
1078+
1079+
require.NotNil(t, base.capturedReq)
1080+
assert.True(t, healthcontext.IsHealthCheck(base.capturedReq.Context()),
1081+
"health check marker should be propagated even when original request context lacks it")
1082+
}
1083+
1084+
func TestIdentityPropagatingRoundTripper_NonHealthCheck_NoMarkerAdded(t *testing.T) {
1085+
t.Parallel()
1086+
1087+
base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
1088+
rt := &identityPropagatingRoundTripper{base: base, identity: nil, isHealthCheck: false}
1089+
1090+
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil)
1091+
require.NoError(t, err)
1092+
1093+
_, err = rt.RoundTrip(req)
1094+
require.NoError(t, err)
1095+
1096+
require.NotNil(t, base.capturedReq)
1097+
assert.False(t, healthcontext.IsHealthCheck(base.capturedReq.Context()),
1098+
"health check marker should not be injected for non-health-check transports")
1099+
}
1100+
1101+
func TestIdentityPropagatingRoundTripper_HealthCheckWithIdentity_PropagatesBoth(t *testing.T) {
1102+
t.Parallel()
1103+
1104+
base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
1105+
identity := &pkgauth.Identity{PrincipalInfo: pkgauth.PrincipalInfo{Subject: "svc-account"}}
1106+
rt := &identityPropagatingRoundTripper{base: base, identity: identity, isHealthCheck: true}
1107+
1108+
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil)
1109+
require.NoError(t, err)
1110+
1111+
_, err = rt.RoundTrip(req)
1112+
require.NoError(t, err)
1113+
1114+
require.NotNil(t, base.capturedReq)
1115+
got, ok := pkgauth.IdentityFromContext(base.capturedReq.Context())
1116+
require.True(t, ok)
1117+
assert.Equal(t, "svc-account", got.Subject)
1118+
assert.True(t, healthcontext.IsHealthCheck(base.capturedReq.Context()))
1119+
}
1120+
1121+
// TestIdentityPropagatingRoundTripper_HealthCheckClose_OriginalRequestContextUnchanged verifies
1122+
// that when the transport is in health-check mode, RoundTrip injects the health-check marker
1123+
// into the downstream request's context without mutating the original request context. This
1124+
// covers requests (e.g. the DELETE mcp-go emits on Close()) whose context does not already
1125+
// carry the marker.
1126+
func TestIdentityPropagatingRoundTripper_HealthCheckClose_OriginalRequestContextUnchanged(t *testing.T) {
1127+
t.Parallel()
1128+
1129+
base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
1130+
rt := &identityPropagatingRoundTripper{base: base, identity: nil, isHealthCheck: true}
1131+
1132+
originalCtx := context.Background() // no health check marker — simulates mcp-go Close()
1133+
req, err := http.NewRequestWithContext(originalCtx, http.MethodDelete, "http://backend.example.com/mcp", nil)
1134+
require.NoError(t, err)
1135+
1136+
_, err = rt.RoundTrip(req)
1137+
require.NoError(t, err)
1138+
1139+
// Original request context must NOT be modified.
1140+
assert.False(t, healthcontext.IsHealthCheck(originalCtx),
1141+
"original request context must not be mutated")
1142+
// But downstream context MUST have the marker.
1143+
require.NotNil(t, base.capturedReq)
1144+
assert.True(t, healthcontext.IsHealthCheck(base.capturedReq.Context()),
1145+
"downstream request must carry health check marker")
1146+
}

pkg/vmcp/health/context/context.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
// Package healthcontext provides a lightweight, dependency-free context marker
5+
// for identifying health check requests. Keeping this in a separate package
6+
// allows packages like pkg/vmcp/client and pkg/vmcp/auth/strategies to use
7+
// the marker without pulling in the heavyweight pkg/vmcp/health dependencies
8+
// (e.g. k8s.io/apimachinery).
9+
package healthcontext
10+
11+
import "context"
12+
13+
// healthCheckContextKey is an unexported key type for the health check marker.
14+
type healthCheckContextKey struct{}
15+
16+
// WithHealthCheckMarker marks a context as a health check request.
17+
// Authentication layers can use IsHealthCheck to identify and skip authentication
18+
// for health check requests.
19+
func WithHealthCheckMarker(ctx context.Context) context.Context {
20+
return context.WithValue(ctx, healthCheckContextKey{}, true)
21+
}
22+
23+
// IsHealthCheck returns true if the context is marked as a health check.
24+
// Authentication strategies use this to bypass authentication for health checks,
25+
// since health checks verify backend availability and should not require user credentials.
26+
// Returns false for nil contexts.
27+
func IsHealthCheck(ctx context.Context) bool {
28+
if ctx == nil {
29+
return false
30+
}
31+
val, ok := ctx.Value(healthCheckContextKey{}).(bool)
32+
return ok && val
33+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package healthcontext
5+
6+
import (
7+
"context"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestIsHealthCheck_WrongValueType(t *testing.T) {
14+
t.Parallel()
15+
16+
ctx := context.WithValue(context.Background(), healthCheckContextKey{}, "not-a-bool")
17+
assert.False(t, IsHealthCheck(ctx), "non-bool value should not be treated as health check marker")
18+
}
19+
20+
func TestIsHealthCheck_FalseValue(t *testing.T) {
21+
t.Parallel()
22+
23+
ctx := context.WithValue(context.Background(), healthCheckContextKey{}, false)
24+
assert.False(t, IsHealthCheck(ctx), "explicit false value should not be treated as health check marker")
25+
}

pkg/vmcp/health/monitor.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,22 @@ import (
1414
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1515

1616
"github.com/stacklok/toolhive/pkg/vmcp"
17+
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
1718
)
1819

19-
// healthCheckContextKey is a marker for health check requests.
20-
type healthCheckContextKey struct{}
21-
2220
// WithHealthCheckMarker marks a context as a health check request.
2321
// Authentication layers can use IsHealthCheck to identify and skip authentication
2422
// for health check requests.
2523
func WithHealthCheckMarker(ctx context.Context) context.Context {
26-
return context.WithValue(ctx, healthCheckContextKey{}, true)
24+
return healthcontext.WithHealthCheckMarker(ctx)
2725
}
2826

2927
// IsHealthCheck returns true if the context is marked as a health check.
3028
// Authentication strategies use this to bypass authentication for health checks,
3129
// since health checks verify backend availability and should not require user credentials.
3230
// Returns false for nil contexts.
3331
func IsHealthCheck(ctx context.Context) bool {
34-
if ctx == nil {
35-
return false
36-
}
37-
val, ok := ctx.Value(healthCheckContextKey{}).(bool)
38-
return ok && val
32+
return healthcontext.IsHealthCheck(ctx)
3933
}
4034

4135
// StatusProvider provides read-only access to backend health status.

0 commit comments

Comments
 (0)