Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
if reader := runner.GetUpstreamTokenReader(); reader != nil {
opts = append(opts, WithUpstreamTokenReader(reader))
}
if provider := runner.GetKeyProvider(); provider != nil {
opts = append(opts, WithKeyProvider(provider))
}

middleware, authInfoHandler, err := GetAuthenticationMiddleware(context.Background(), params.OIDCConfig, opts...)
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions pkg/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ func TestCreateMiddleware_WithoutOIDCConfig(t *testing.T) {
// Create mock runner
mockRunner := mocks.NewMockMiddlewareRunner(ctrl)

// Expect GetUpstreamTokenReader to be called (returns nil = no auth server)
// Expect GetUpstreamTokenReader and GetKeyProvider to be called (returns nil = no auth server)
mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil)
mockRunner.EXPECT().GetKeyProvider().Return(nil)

// Expect AddMiddleware to be called with a middleware instance
mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any()).Do(func(name string, mw types.Middleware) {
Expand Down Expand Up @@ -213,8 +214,9 @@ func TestCreateMiddleware_EmptyParameters(t *testing.T) {

mockRunner := mocks.NewMockMiddlewareRunner(ctrl)

// Expect GetUpstreamTokenReader to be called (returns nil = no auth server)
// Expect GetUpstreamTokenReader and GetKeyProvider to be called (returns nil = no auth server)
mockRunner.EXPECT().GetUpstreamTokenReader().Return(nil)
mockRunner.EXPECT().GetKeyProvider().Return(nil)

// Expect AddMiddleware to be called
mockRunner.EXPECT().AddMiddleware(gomock.Any(), gomock.Any())
Expand Down
121 changes: 99 additions & 22 deletions pkg/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/stacklok/toolhive-core/env"
"github.com/stacklok/toolhive/pkg/auth/oauth"
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
"github.com/stacklok/toolhive/pkg/networking"
oauthproto "github.com/stacklok/toolhive/pkg/oauth"
)
Expand Down Expand Up @@ -372,6 +373,12 @@ type TokenValidator struct {
// nil means no enrichment (no embedded auth server).
upstreamTokenReader upstreamtoken.TokenReader

// keyProvider provides in-process JWKS key lookups from the embedded auth
// server's KeyProvider. When set, getKeyFromJWKS resolves keys locally
// before falling back to HTTP. Eliminates self-referential HTTP calls.
// nil when no embedded auth server is configured.
keyProvider keys.KeyProvider

// Lazy JWKS registration
jwksRegistered bool
jwksRegistrationMu sync.Mutex
Expand Down Expand Up @@ -547,6 +554,7 @@ func registerIntrospectionProviders(config TokenValidatorConfig, clientSecret st
type tokenValidatorOptions struct {
envReader env.Reader
upstreamTokenReader upstreamtoken.TokenReader
keyProvider keys.KeyProvider
}

// TokenValidatorOption is a functional option for NewTokenValidator.
Expand All @@ -570,6 +578,29 @@ func WithUpstreamTokenReader(reader upstreamtoken.TokenReader) TokenValidatorOpt
}
}

// WithKeyProvider configures the token validator to use an in-process key
// provider for JWKS lookups instead of fetching keys over HTTP. This is used
// when the embedded auth server's KeyProvider is available in the same process,
// eliminating self-referential HTTP calls and the need for insecureAllowHTTP
// and jwksAllowPrivateIP flags.
func WithKeyProvider(provider keys.KeyProvider) TokenValidatorOption {
return func(o *tokenValidatorOptions) {
o.keyProvider = provider
}
}

// resolveClientSecret returns the client secret from the config, falling back
// to the TOOLHIVE_OIDC_CLIENT_SECRET environment variable if not set.
func resolveClientSecret(configSecret string, envReader env.Reader) string {
if configSecret != "" {
return configSecret
}
if envSecret := envReader.Getenv("TOOLHIVE_OIDC_CLIENT_SECRET"); envSecret != "" {
return envSecret
}
return ""
}

// NewTokenValidator creates a new token validator.
func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ...TokenValidatorOption) (*TokenValidator, error) {
// Apply functional options
Expand Down Expand Up @@ -611,8 +642,9 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ..
slog.Debug("OIDC discovery deferred - will discover on first validation request", "issuer", config.Issuer)
}

// Ensure we have either an explicit JWKS URL or an issuer to discover from
if jwksURL == "" && config.Issuer == "" {
// Ensure we have either an explicit JWKS URL, an issuer to discover from,
// or a local key provider (embedded auth server).
if jwksURL == "" && config.Issuer == "" && o.keyProvider == nil {
return nil, ErrMissingIssuerAndJWKSURL
}

Expand All @@ -638,14 +670,8 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ..

// Skip synchronous JWKS registration - will be done lazily on first use

// Load client secret from environment variable if not provided in config
// This allows secrets to be injected via Kubernetes Secret references
clientSecret := config.ClientSecret
if clientSecret == "" {
if envSecret := o.envReader.Getenv("TOOLHIVE_OIDC_CLIENT_SECRET"); envSecret != "" {
clientSecret = envSecret
}
}
// Resolve client secret from config or environment variable
clientSecret := resolveClientSecret(config.ClientSecret, o.envReader)

// Register introspection providers
registry, err := registerIntrospectionProviders(config, clientSecret)
Expand All @@ -667,6 +693,7 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig, opts ..
registry: registry,
insecureAllowHTTP: config.InsecureAllowHTTP,
upstreamTokenReader: o.upstreamTokenReader,
keyProvider: o.keyProvider,
}

return validator, nil
Expand Down Expand Up @@ -802,8 +829,67 @@ func (v *TokenValidator) ensureOIDCDiscovered(ctx context.Context) error {
return nil
}

// getKeyFromLocalProvider attempts to find a verification key from the local
// key provider (embedded auth server). Returns (key, nil) on success,
// (nil, nil) to signal fallback to HTTP, or (nil, error) for hard failures.
// validateTokenHeader checks the signing method is supported (RSA or ECDSA) and
// extracts the key ID from the token header. Returns an error for unsupported
// methods or a missing kid claim.
func validateTokenHeader(token *jwt.Token) (string, error) {
switch token.Method.(type) {
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
// Supported signing methods
default:
return "", fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

kid, ok := token.Header["kid"].(string)
if !ok {
return "", fmt.Errorf("token header missing kid")
}
return kid, nil
}

func (v *TokenValidator) getKeyFromLocalProvider(ctx context.Context, token *jwt.Token) (interface{}, error) {
if v.keyProvider == nil {
return nil, nil
}

kid, err := validateTokenHeader(token)
if err != nil {
return nil, err
}

pubKeys, err := v.keyProvider.PublicKeys(ctx)
if err != nil {
slog.Debug("local JWKS provider failed, falling back to HTTP", "error", err)
return nil, nil
}

for _, k := range pubKeys {
if k.KeyID == kid {
slog.Debug("resolved JWKS key from embedded auth server", "kid", kid)
return k.PublicKey, nil
}
}

// Key not found locally — fall back to HTTP JWKS
slog.Debug("key not found in local JWKS provider, falling back to HTTP", "kid", kid)
return nil, nil
}

// getKeyFromJWKS gets the key from the JWKS.
func (v *TokenValidator) getKeyFromJWKS(ctx context.Context, token *jwt.Token) (interface{}, error) {
// Try local key provider first (embedded auth server in-process keys).
// This avoids self-referential HTTP calls when the auth server and
// token validator run in the same process.
if key, err := v.getKeyFromLocalProvider(ctx, token); err != nil {
return nil, err
} else if key != nil {
return key, nil
}

// Fall through to HTTP-based JWKS lookup.
// Defensive check: JWKS URL must be set before calling this function.
// This invariant is normally guaranteed by ValidateToken calling ensureOIDCDiscovered first.
if v.jwksURL == "" {
Expand All @@ -815,18 +901,9 @@ func (v *TokenValidator) getKeyFromJWKS(ctx context.Context, token *jwt.Token) (
return nil, fmt.Errorf("JWKS registration failed: %w", err)
}

// Validate the signing method
switch token.Method.(type) {
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
// Supported RSA signing methods
default:
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

// Get the key ID from the token header
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, fmt.Errorf("token header missing kid")
kid, err := validateTokenHeader(token)
if err != nil {
return nil, err
}

// Get the key set from the JWKS
Expand Down
133 changes: 133 additions & 0 deletions pkg/auth/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
envmocks "github.com/stacklok/toolhive-core/env/mocks"
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
upstreamtokenmocks "github.com/stacklok/toolhive/pkg/auth/upstreamtoken/mocks"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
keysmocks "github.com/stacklok/toolhive/pkg/authserver/server/keys/mocks"
"github.com/stacklok/toolhive/pkg/networking"
oauthproto "github.com/stacklok/toolhive/pkg/oauth"
)
Expand Down Expand Up @@ -2466,3 +2468,134 @@ func TestMiddleware_UpstreamTokenEnrichment(t *testing.T) {
require.Nil(t, captured.UpstreamTokens)
})
}

func TestWithKeyProvider(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockKeyProvider(ctrl)
opt := WithKeyProvider(provider)

o := &tokenValidatorOptions{}
opt(o)

require.Equal(t, provider, o.keyProvider)
}

func TestGetKeyFromLocalProvider(t *testing.T) {
t.Parallel()

// Generate a test RSA key pair for verification
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

t.Run("returns nil when no provider configured", func(t *testing.T) {
t.Parallel()

v := &TokenValidator{} // no keyProvider
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{"kid": "test-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.NoError(t, err)
require.Nil(t, key)
})

t.Run("returns key when kid matches", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockKeyProvider(ctrl)
provider.EXPECT().PublicKeys(gomock.Any()).Return([]*keys.PublicKeyData{
{KeyID: "other-kid", PublicKey: &privateKey.PublicKey},
{KeyID: "target-kid", PublicKey: &privateKey.PublicKey},
}, nil)

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{"kid": "target-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.NoError(t, err)
require.NotNil(t, key)
require.Equal(t, &privateKey.PublicKey, key)
})

t.Run("falls back when kid not found", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockKeyProvider(ctrl)
provider.EXPECT().PublicKeys(gomock.Any()).Return([]*keys.PublicKeyData{
{KeyID: "other-kid", PublicKey: &privateKey.PublicKey},
}, nil)

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{"kid": "missing-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.NoError(t, err)
require.Nil(t, key, "should return nil to signal HTTP fallback")
})

t.Run("falls back when provider returns error", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockKeyProvider(ctrl)
provider.EXPECT().PublicKeys(gomock.Any()).Return(nil, errors.New("key unavailable"))

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{"kid": "test-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.NoError(t, err, "provider errors should trigger fallback, not hard failure")
require.Nil(t, key)
})

t.Run("rejects unsupported signing method", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockKeyProvider(ctrl)

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodHS256,
Header: map[string]interface{}{"alg": "HS256", "kid": "test-kid"},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.Error(t, err)
require.Contains(t, err.Error(), "unexpected signing method")
require.Nil(t, key)
})

t.Run("rejects missing kid", func(t *testing.T) {
t.Parallel()

ctrl := gomock.NewController(t)
provider := keysmocks.NewMockKeyProvider(ctrl)

v := &TokenValidator{keyProvider: provider}
token := &jwt.Token{
Method: jwt.SigningMethodRS256,
Header: map[string]interface{}{},
}

key, err := v.getKeyFromLocalProvider(context.Background(), token)
require.Error(t, err)
require.Contains(t, err.Error(), "token header missing kid")
require.Nil(t, key)
})
}
17 changes: 13 additions & 4 deletions pkg/authserver/runner/embeddedauthserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ const (
// It handles configuration transformation from authserver.RunConfig to authserver.Config,
// manages resource lifecycle, and provides HTTP handlers for OAuth/OIDC endpoints.
type EmbeddedAuthServer struct {
server authserver.Server
closeOnce sync.Once
closeErr error
server authserver.Server
keyProvider keys.KeyProvider
closeOnce sync.Once
closeErr error
}

// NewEmbeddedAuthServer creates an EmbeddedAuthServer from authserver.RunConfig.
Expand Down Expand Up @@ -105,7 +106,8 @@ func NewEmbeddedAuthServer(ctx context.Context, cfg *authserver.RunConfig) (*Emb
}

return &EmbeddedAuthServer{
server: server,
server: server,
keyProvider: keyProvider,
}, nil
}

Expand Down Expand Up @@ -142,6 +144,13 @@ func (e *EmbeddedAuthServer) UpstreamTokenRefresher() storage.UpstreamTokenRefre
return e.server.UpstreamTokenRefresher()
}

// KeyProvider returns the signing key provider used by the authorization server.
// This enables in-process JWKS key lookups, eliminating the need for
// self-referential HTTP calls when the token validator runs in the same process.
func (e *EmbeddedAuthServer) KeyProvider() keys.KeyProvider {
return e.keyProvider
}

// Routes returns the authorization server's HTTP route map.
//
// The /.well-known/ paths are registered explicitly because that namespace is shared:
Expand Down
Loading
Loading