@@ -33,11 +33,13 @@ import (
3333 "go.opentelemetry.io/otel/trace"
3434 "go.uber.org/mock/gomock"
3535
36+ pkgauth "github.com/stacklok/toolhive/pkg/auth"
3637 "github.com/stacklok/toolhive/pkg/vmcp"
3738 "github.com/stacklok/toolhive/pkg/vmcp/auth"
3839 authmocks "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks"
3940 "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies"
4041 authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
42+ healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
4143)
4244
4345func TestHTTPBackendClient_ListCapabilities_WithMockFactory (t * testing.T ) {
@@ -1020,3 +1022,125 @@ func TestWrapBackendError(t *testing.T) {
10201022 })
10211023 }
10221024}
1025+
1026+ // ---------------------------------------------------------------------------
1027+ // identityPropagatingRoundTripper
1028+ // ---------------------------------------------------------------------------
1029+
1030+ func TestIdentityPropagatingRoundTripper_WithIdentity_PropagatesIdentityInContext (t * testing.T ) {
1031+ t .Parallel ()
1032+
1033+ base := & mockRoundTripper {response : & http.Response {StatusCode : http .StatusOK }}
1034+ identity := & pkgauth.Identity {PrincipalInfo : pkgauth.PrincipalInfo {Subject : "user-1" }}
1035+ rt := & identityPropagatingRoundTripper {base : base , identity : identity }
1036+
1037+ req , err := http .NewRequestWithContext (context .Background (), http .MethodGet , "http://backend.example.com/mcp" , nil )
1038+ require .NoError (t , err )
1039+
1040+ _ , err = rt .RoundTrip (req )
1041+ require .NoError (t , err )
1042+
1043+ require .NotNil (t , base .capturedReq )
1044+ got , ok := pkgauth .IdentityFromContext (base .capturedReq .Context ())
1045+ require .True (t , ok , "identity should be in downstream request context" )
1046+ assert .Equal (t , "user-1" , got .Subject )
1047+ }
1048+
1049+ func TestIdentityPropagatingRoundTripper_NilIdentity_NoIdentityInContext (t * testing.T ) {
1050+ t .Parallel ()
1051+
1052+ base := & mockRoundTripper {response : & http.Response {StatusCode : http .StatusOK }}
1053+ rt := & identityPropagatingRoundTripper {base : base , identity : nil }
1054+
1055+ req , err := http .NewRequestWithContext (context .Background (), http .MethodGet , "http://backend.example.com/mcp" , nil )
1056+ require .NoError (t , err )
1057+
1058+ _ , err = rt .RoundTrip (req )
1059+ require .NoError (t , err )
1060+
1061+ require .NotNil (t , base .capturedReq )
1062+ _ , ok := pkgauth .IdentityFromContext (base .capturedReq .Context ())
1063+ assert .False (t , ok , "no identity should be in downstream context when nil identity configured" )
1064+ }
1065+
1066+ func TestIdentityPropagatingRoundTripper_HealthCheck_PropagatesMarker (t * testing.T ) {
1067+ t .Parallel ()
1068+
1069+ base := & mockRoundTripper {response : & http.Response {StatusCode : http .StatusOK }}
1070+ rt := & identityPropagatingRoundTripper {base : base , identity : nil , isHealthCheck : true }
1071+
1072+ // Simulate mcp-go Close(): request created with context.Background(), no health check marker.
1073+ req , err := http .NewRequestWithContext (context .Background (), http .MethodDelete , "http://backend.example.com/mcp" , nil )
1074+ require .NoError (t , err )
1075+
1076+ _ , err = rt .RoundTrip (req )
1077+ require .NoError (t , err )
1078+
1079+ require .NotNil (t , base .capturedReq )
1080+ assert .True (t , healthcontext .IsHealthCheck (base .capturedReq .Context ()),
1081+ "health check marker should be propagated even when original request context lacks it" )
1082+ }
1083+
1084+ func TestIdentityPropagatingRoundTripper_NonHealthCheck_NoMarkerAdded (t * testing.T ) {
1085+ t .Parallel ()
1086+
1087+ base := & mockRoundTripper {response : & http.Response {StatusCode : http .StatusOK }}
1088+ rt := & identityPropagatingRoundTripper {base : base , identity : nil , isHealthCheck : false }
1089+
1090+ req , err := http .NewRequestWithContext (context .Background (), http .MethodGet , "http://backend.example.com/mcp" , nil )
1091+ require .NoError (t , err )
1092+
1093+ _ , err = rt .RoundTrip (req )
1094+ require .NoError (t , err )
1095+
1096+ require .NotNil (t , base .capturedReq )
1097+ assert .False (t , healthcontext .IsHealthCheck (base .capturedReq .Context ()),
1098+ "health check marker should not be injected for non-health-check transports" )
1099+ }
1100+
1101+ func TestIdentityPropagatingRoundTripper_HealthCheckWithIdentity_PropagatesBoth (t * testing.T ) {
1102+ t .Parallel ()
1103+
1104+ base := & mockRoundTripper {response : & http.Response {StatusCode : http .StatusOK }}
1105+ identity := & pkgauth.Identity {PrincipalInfo : pkgauth.PrincipalInfo {Subject : "svc-account" }}
1106+ rt := & identityPropagatingRoundTripper {base : base , identity : identity , isHealthCheck : true }
1107+
1108+ req , err := http .NewRequestWithContext (context .Background (), http .MethodGet , "http://backend.example.com/mcp" , nil )
1109+ require .NoError (t , err )
1110+
1111+ _ , err = rt .RoundTrip (req )
1112+ require .NoError (t , err )
1113+
1114+ require .NotNil (t , base .capturedReq )
1115+ got , ok := pkgauth .IdentityFromContext (base .capturedReq .Context ())
1116+ require .True (t , ok )
1117+ assert .Equal (t , "svc-account" , got .Subject )
1118+ assert .True (t , healthcontext .IsHealthCheck (base .capturedReq .Context ()))
1119+ }
1120+
1121+ // TestIdentityPropagatingRoundTripper_HealthCheckClose_OriginalRequestContextUnchanged verifies
1122+ // that when the transport is in health-check mode, RoundTrip injects the health-check marker
1123+ // into the downstream request's context without mutating the original request context. This
1124+ // covers requests (e.g. the DELETE mcp-go emits on Close()) whose context does not already
1125+ // carry the marker.
1126+ func TestIdentityPropagatingRoundTripper_HealthCheckClose_OriginalRequestContextUnchanged (t * testing.T ) {
1127+ t .Parallel ()
1128+
1129+ base := & mockRoundTripper {response : & http.Response {StatusCode : http .StatusOK }}
1130+ rt := & identityPropagatingRoundTripper {base : base , identity : nil , isHealthCheck : true }
1131+
1132+ originalCtx := context .Background () // no health check marker — simulates mcp-go Close()
1133+ req , err := http .NewRequestWithContext (originalCtx , http .MethodDelete , "http://backend.example.com/mcp" , nil )
1134+ require .NoError (t , err )
1135+
1136+ _ , err = rt .RoundTrip (req )
1137+ require .NoError (t , err )
1138+
1139+ // Original request context must NOT be modified.
1140+ assert .False (t , healthcontext .IsHealthCheck (originalCtx ),
1141+ "original request context must not be mutated" )
1142+ // But downstream context MUST have the marker.
1143+ require .NotNil (t , base .capturedReq )
1144+ assert .True (t , healthcontext .IsHealthCheck (base .capturedReq .Context ()),
1145+ "downstream request must carry health check marker" )
1146+ }
0 commit comments