Skip to content

Commit 6adbe2b

Browse files
authored
Resolve JWKS keys in-process for embedded auth server (MCP server) (#4502)
When the embedded auth server is enabled, token validation currently fails silently because the token validator fetches JWKS keys over HTTP from the proxy's own endpoint. This self-referential HTTP call requires operators to set `insecureAllowHTTP` and/or `jwksAllowPrivateIP` flags — insecure workarounds that are difficult to debug when missing. This PR eliminates the self-referential HTTP fetch by wiring the embedded auth server's `KeyProvider` directly into the token validator. When both components run in the same process, JWKS keys are resolved in-memory with a graceful fallback to HTTP for cases where the local provider cannot satisfy the request. Note: this only addresses the issue for the runner and proxy runner - vMCP wiring will come in a separate change.
1 parent 6c4e023 commit 6adbe2b

File tree

11 files changed

+385
-32
lines changed

11 files changed

+385
-32
lines changed

pkg/auth/middleware.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
5656
if reader := runner.GetUpstreamTokenReader(); reader != nil {
5757
opts = append(opts, WithUpstreamTokenReader(reader))
5858
}
59+
if provider := runner.GetKeyProvider(); provider != nil {
60+
opts = append(opts, WithKeyProvider(provider))
61+
}
5962

6063
middleware, authInfoHandler, err := GetAuthenticationMiddleware(context.Background(), params.OIDCConfig, opts...)
6164
if err != nil {

pkg/auth/middleware_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ func TestCreateMiddleware_WithoutOIDCConfig(t *testing.T) {
102102
// Create mock runner
103103
mockRunner := mocks.NewMockMiddlewareRunner(ctrl)
104104

105-
// Expect GetUpstreamTokenReader to be called (returns nil = no auth server)
105+
// Expect GetUpstreamTokenReader and GetKeyProvider to be called (returns nil = no auth server)
106106
mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil)
107+
mockRunner.EXPECT().GetKeyProvider().Return(nil)
107108

108109
// Expect AddMiddleware to be called with a middleware instance
109110
mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(name string, mw types.Middleware) {
@@ -213,8 +214,9 @@ func TestCreateMiddleware_EmptyParameters(t *testing.T) {
213214

214215
mockRunner := mocks.NewMockMiddlewareRunner(ctrl)
215216

216-
// Expect GetUpstreamTokenReader to be called (returns nil = no auth server)
217+
// Expect GetUpstreamTokenReader and GetKeyProvider to be called (returns nil = no auth server)
217218
mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil)
219+
mockRunner.EXPECT().GetKeyProvider().Return(nil)
218220

219221
// Expect AddMiddleware to be called
220222
mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any())

pkg/auth/token.go

Lines changed: 101 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/stacklok/toolhive-core/env"
2727
"github.com/stacklok/toolhive/pkg/auth/oauth"
2828
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
29+
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
2930
"github.com/stacklok/toolhive/pkg/networking"
3031
oauthproto "github.com/stacklok/toolhive/pkg/oauth"
3132
)
@@ -372,6 +373,12 @@ type TokenValidator struct {
372373
// nil means no enrichment (no embedded auth server).
373374
upstreamTokenReader upstreamtoken.TokenReader
374375

376+
// keyProvider provides in-process JWKS key lookups from the embedded auth
377+
// server's key provider. When set, getKeyFromJWKS resolves keys locally
378+
// before falling back to HTTP. Eliminates self-referential HTTP calls.
379+
// nil when no embedded auth server is configured.
380+
keyProvider keys.PublicKeyProvider
381+
375382
// Lazy JWKS registration
376383
jwksRegistered bool
377384
jwksRegistrationMu sync.Mutex
@@ -547,6 +554,7 @@ func registerIntrospectionProviders(config TokenValidatorConfig, clientSecret st
547554
type tokenValidatorOptions struct {
548555
envReader env.Reader
549556
upstreamTokenReader upstreamtoken.TokenReader
557+
keyProvider keys.PublicKeyProvider
550558
}
551559

552560
// TokenValidatorOption is a functional option for NewTokenValidator.
@@ -570,6 +578,31 @@ func WithUpstreamTokenReader(reader upstreamtoken.TokenReader) TokenValidatorOpt
570578
}
571579
}
572580

581+
// WithKeyProvider configures the token validator to use an in-process key
582+
// provider for JWKS lookups instead of fetching keys over HTTP. This is used
583+
// when the embedded auth server's key provider is available in the same process,
584+
// eliminating self-referential HTTP calls and the need for insecureAllowHTTP
585+
// and jwksAllowPrivateIP flags.
586+
//
587+
// Only PublicKeyProvider is required — the validator never signs tokens.
588+
func WithKeyProvider(provider keys.PublicKeyProvider) TokenValidatorOption {
589+
return func(o *tokenValidatorOptions) {
590+
o.keyProvider = provider
591+
}
592+
}
593+
594+
// resolveClientSecret returns the client secret from the config, falling back
595+
// to the TOOLHIVE_OIDC_CLIENT_SECRET environment variable if not set.
596+
func resolveClientSecret(configSecret string, envReader env.Reader) string {
597+
if configSecret != "" {
598+
return configSecret
599+
}
600+
if envSecret := envReader.Getenv("TOOLHIVE_OIDC_CLIENT_SECRET"); envSecret != "" {
601+
return envSecret
602+
}
603+
return ""
604+
}
605+
573606
// NewTokenValidator creates a new token validator.
574607
func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ...TokenValidatorOption) (*TokenValidator, error) {
575608
// Apply functional options
@@ -611,8 +644,9 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ..
611644
slog.Debug("OIDC discovery deferred - will discover on first validation request", "issuer", config.Issuer)
612645
}
613646

614-
// Ensure we have either an explicit JWKS URL or an issuer to discover from
615-
if jwksURL == "" && config.Issuer == "" {
647+
// Ensure we have either an explicit JWKS URL, an issuer to discover from,
648+
// or a local key provider (embedded auth server).
649+
if jwksURL == "" && config.Issuer == "" && o.keyProvider == nil {
616650
return nil, ErrMissingIssuerAndJWKSURL
617651
}
618652

@@ -638,14 +672,8 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ..
638672

639673
// Skip synchronous JWKS registration - will be done lazily on first use
640674

641-
// Load client secret from environment variable if not provided in config
642-
// This allows secrets to be injected via Kubernetes Secret references
643-
clientSecret := config.ClientSecret
644-
if clientSecret == "" {
645-
if envSecret := o.envReader.Getenv("TOOLHIVE_OIDC_CLIENT_SECRET"); envSecret != "" {
646-
clientSecret = envSecret
647-
}
648-
}
675+
// Resolve client secret from config or environment variable
676+
clientSecret := resolveClientSecret(config.ClientSecret, o.envReader)
649677

650678
// Register introspection providers
651679
registry, err := registerIntrospectionProviders(config, clientSecret)
@@ -667,6 +695,7 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ..
667695
registry: registry,
668696
insecureAllowHTTP: config.InsecureAllowHTTP,
669697
upstreamTokenReader: o.upstreamTokenReader,
698+
keyProvider: o.keyProvider,
670699
}
671700

672701
return validator, nil
@@ -802,8 +831,67 @@ func (v *TokenValidator) ensureOIDCDiscovered(ctx context.Context) error {
802831
return nil
803832
}
804833

834+
// getKeyFromLocalProvider attempts to find a verification key from the local
835+
// key provider (embedded auth server). Returns (key, nil) on success,
836+
// (nil, nil) to signal fallback to HTTP, or (nil, error) for hard failures.
837+
// validateTokenHeader checks the signing method is supported (RSA or ECDSA) and
838+
// extracts the key ID from the token header. Returns an error for unsupported
839+
// methods or a missing kid claim.
840+
func validateTokenHeader(token *jwt.Token) (string, error) {
841+
switch token.Method.(type) {
842+
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
843+
// Supported signing methods
844+
default:
845+
return "", fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
846+
}
847+
848+
kid, ok := token.Header["kid"].(string)
849+
if !ok {
850+
return "", fmt.Errorf("token header missing kid")
851+
}
852+
return kid, nil
853+
}
854+
855+
func (v *TokenValidator) getKeyFromLocalProvider(ctx context.Context, token *jwt.Token) (interface{}, error) {
856+
if v.keyProvider == nil {
857+
return nil, nil
858+
}
859+
860+
kid, err := validateTokenHeader(token)
861+
if err != nil {
862+
return nil, err
863+
}
864+
865+
pubKeys, err := v.keyProvider.PublicKeys(ctx)
866+
if err != nil {
867+
slog.Debug("local JWKS provider failed, falling back to HTTP", "error", err)
868+
return nil, nil
869+
}
870+
871+
for _, k := range pubKeys {
872+
if k.KeyID == kid {
873+
slog.Debug("resolved JWKS key from embedded auth server", "kid", kid)
874+
return k.PublicKey, nil
875+
}
876+
}
877+
878+
// Key not found locally — fall back to HTTP JWKS
879+
slog.Debug("key not found in local JWKS provider, falling back to HTTP", "kid", kid)
880+
return nil, nil
881+
}
882+
805883
// getKeyFromJWKS gets the key from the JWKS.
806884
func (v *TokenValidator) getKeyFromJWKS(ctx context.Context, token *jwt.Token) (interface{}, error) {
885+
// Try local key provider first (embedded auth server in-process keys).
886+
// This avoids self-referential HTTP calls when the auth server and
887+
// token validator run in the same process.
888+
if key, err := v.getKeyFromLocalProvider(ctx, token); err != nil {
889+
return nil, err
890+
} else if key != nil {
891+
return key, nil
892+
}
893+
894+
// Fall through to HTTP-based JWKS lookup.
807895
// Defensive check: JWKS URL must be set before calling this function.
808896
// This invariant is normally guaranteed by ValidateToken calling ensureOIDCDiscovered first.
809897
if v.jwksURL == "" {
@@ -815,18 +903,9 @@ func (v *TokenValidator) getKeyFromJWKS(ctx context.Context, token *jwt.Token) (
815903
return nil, fmt.Errorf("JWKS registration failed: %w", err)
816904
}
817905

818-
// Validate the signing method
819-
switch token.Method.(type) {
820-
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
821-
// Supported RSA signing methods
822-
default:
823-
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
824-
}
825-
826-
// Get the key ID from the token header
827-
kid, ok := token.Header["kid"].(string)
828-
if !ok {
829-
return nil, fmt.Errorf("token header missing kid")
906+
kid, err := validateTokenHeader(token)
907+
if err != nil {
908+
return nil, err
830909
}
831910

832911
// Get the key set from the JWKS

pkg/auth/token_test.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import (
2626
envmocks "github.com/stacklok/toolhive-core/env/mocks"
2727
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
2828
upstreamtokenmocks "github.com/stacklok/toolhive/pkg/auth/upstreamtoken/mocks"
29+
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
30+
keysmocks "github.com/stacklok/toolhive/pkg/authserver/server/keys/mocks"
2931
"github.com/stacklok/toolhive/pkg/networking"
3032
oauthproto "github.com/stacklok/toolhive/pkg/oauth"
3133
)
@@ -2466,3 +2468,134 @@ func TestMiddleware_UpstreamTokenEnrichment(t *testing.T) {
24662468
require.Nil(t, captured.UpstreamTokens)
24672469
})
24682470
}
2471+
2472+
func TestWithKeyProvider(t *testing.T) {
2473+
t.Parallel()
2474+
2475+
ctrl := gomock.NewController(t)
2476+
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
2477+
opt := WithKeyProvider(provider)
2478+
2479+
o := &tokenValidatorOptions{}
2480+
opt(o)
2481+
2482+
require.Equal(t, provider, o.keyProvider)
2483+
}
2484+
2485+
func TestGetKeyFromLocalProvider(t *testing.T) {
2486+
t.Parallel()
2487+
2488+
// Generate a test RSA key pair for verification
2489+
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
2490+
require.NoError(t, err)
2491+
2492+
t.Run("returns nil when no provider configured", func(t *testing.T) {
2493+
t.Parallel()
2494+
2495+
v := &TokenValidator{} // no keyProvider
2496+
token := &jwt.Token{
2497+
Method: jwt.SigningMethodRS256,
2498+
Header: map[string]interface{}{"kid": "test-kid"},
2499+
}
2500+
2501+
key, err := v.getKeyFromLocalProvider(context.Background(), token)
2502+
require.NoError(t, err)
2503+
require.Nil(t, key)
2504+
})
2505+
2506+
t.Run("returns key when kid matches", func(t *testing.T) {
2507+
t.Parallel()
2508+
2509+
ctrl := gomock.NewController(t)
2510+
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
2511+
provider.EXPECT().PublicKeys(gomock.Any()).Return([]*keys.PublicKeyData{
2512+
{KeyID: "other-kid", PublicKey: &privateKey.PublicKey},
2513+
{KeyID: "target-kid", PublicKey: &privateKey.PublicKey},
2514+
}, nil)
2515+
2516+
v := &TokenValidator{keyProvider: provider}
2517+
token := &jwt.Token{
2518+
Method: jwt.SigningMethodRS256,
2519+
Header: map[string]interface{}{"kid": "target-kid"},
2520+
}
2521+
2522+
key, err := v.getKeyFromLocalProvider(context.Background(), token)
2523+
require.NoError(t, err)
2524+
require.NotNil(t, key)
2525+
require.Equal(t, &privateKey.PublicKey, key)
2526+
})
2527+
2528+
t.Run("falls back when kid not found", func(t *testing.T) {
2529+
t.Parallel()
2530+
2531+
ctrl := gomock.NewController(t)
2532+
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
2533+
provider.EXPECT().PublicKeys(gomock.Any()).Return([]*keys.PublicKeyData{
2534+
{KeyID: "other-kid", PublicKey: &privateKey.PublicKey},
2535+
}, nil)
2536+
2537+
v := &TokenValidator{keyProvider: provider}
2538+
token := &jwt.Token{
2539+
Method: jwt.SigningMethodRS256,
2540+
Header: map[string]interface{}{"kid": "missing-kid"},
2541+
}
2542+
2543+
key, err := v.getKeyFromLocalProvider(context.Background(), token)
2544+
require.NoError(t, err)
2545+
require.Nil(t, key, "should return nil to signal HTTP fallback")
2546+
})
2547+
2548+
t.Run("falls back when provider returns error", func(t *testing.T) {
2549+
t.Parallel()
2550+
2551+
ctrl := gomock.NewController(t)
2552+
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
2553+
provider.EXPECT().PublicKeys(gomock.Any()).Return(nil, errors.New("key unavailable"))
2554+
2555+
v := &TokenValidator{keyProvider: provider}
2556+
token := &jwt.Token{
2557+
Method: jwt.SigningMethodRS256,
2558+
Header: map[string]interface{}{"kid": "test-kid"},
2559+
}
2560+
2561+
key, err := v.getKeyFromLocalProvider(context.Background(), token)
2562+
require.NoError(t, err, "provider errors should trigger fallback, not hard failure")
2563+
require.Nil(t, key)
2564+
})
2565+
2566+
t.Run("rejects unsupported signing method", func(t *testing.T) {
2567+
t.Parallel()
2568+
2569+
ctrl := gomock.NewController(t)
2570+
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
2571+
2572+
v := &TokenValidator{keyProvider: provider}
2573+
token := &jwt.Token{
2574+
Method: jwt.SigningMethodHS256,
2575+
Header: map[string]interface{}{"alg": "HS256", "kid": "test-kid"},
2576+
}
2577+
2578+
key, err := v.getKeyFromLocalProvider(context.Background(), token)
2579+
require.Error(t, err)
2580+
require.Contains(t, err.Error(), "unexpected signing method")
2581+
require.Nil(t, key)
2582+
})
2583+
2584+
t.Run("rejects missing kid", func(t *testing.T) {
2585+
t.Parallel()
2586+
2587+
ctrl := gomock.NewController(t)
2588+
provider := keysmocks.NewMockPublicKeyProvider(ctrl)
2589+
2590+
v := &TokenValidator{keyProvider: provider}
2591+
token := &jwt.Token{
2592+
Method: jwt.SigningMethodRS256,
2593+
Header: map[string]interface{}{},
2594+
}
2595+
2596+
key, err := v.getKeyFromLocalProvider(context.Background(), token)
2597+
require.Error(t, err)
2598+
require.Contains(t, err.Error(), "token header missing kid")
2599+
require.Nil(t, key)
2600+
})
2601+
}

0 commit comments

Comments
 (0)