Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions cmd/vmcp/app/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
authserverconfig "github.com/stacklok/toolhive/pkg/authserver"
authserverrunner "github.com/stacklok/toolhive/pkg/authserver/runner"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
"github.com/stacklok/toolhive/pkg/container/runtime"
"github.com/stacklok/toolhive/pkg/groups"
"github.com/stacklok/toolhive/pkg/telemetry"
Expand Down Expand Up @@ -589,18 +590,20 @@ func runServe(cmd *cobra.Command, _ []string) error {
}
}

// Create an upstream token reader from the embedded auth server so that
// the OIDC middleware can enrich Identity with upstream provider tokens.
// This is required for the upstream_inject outgoing auth strategy.
// Extract dependencies from the embedded auth server so the OIDC middleware
// can (a) resolve JWKS keys in-process instead of self-referential HTTP
// calls, and (b) enrich Identity with upstream provider tokens.
var upstreamReader upstreamtoken.TokenReader
var keyProvider keys.PublicKeyProvider
if embeddedAuthServer != nil {
stor := embeddedAuthServer.IDPTokenStorage()
refresher := embeddedAuthServer.UpstreamTokenRefresher()
upstreamReader = upstreamtoken.NewInProcessService(stor, refresher)
keyProvider = embeddedAuthServer.KeyProvider()
}

authMiddleware, authzMiddleware, authInfoHandler, err :=
factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth, passThroughTools, upstreamReader)
factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth, passThroughTools, upstreamReader, keyProvider)
if err != nil {
return fmt.Errorf("failed to create authentication middleware: %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/vmcp/auth/factory/authz_not_wired_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestNewIncomingAuthMiddleware_AuthzEnforced(t *testing.T) {
},
}

authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil)
authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil, nil)
require.NoError(t, err, "middleware creation should succeed")
require.NotNil(t, authMw, "auth middleware should not be nil")
require.NotNil(t, authzMw, "authz middleware should not be nil")
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestNewIncomingAuthMiddleware_AuthzEnforced(t *testing.T) {
},
}

authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil)
authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil, nil)
require.NoError(t, err, "middleware creation should succeed")
require.NotNil(t, authMw, "auth middleware should not be nil")
require.NotNil(t, authzMw, "authz middleware should not be nil")
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestNewIncomingAuthMiddleware_AuthzApproveAndBlock(t *testing.T) {
},
}

authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil)
authMw, authzMw, _, err := NewIncomingAuthMiddleware(t.Context(), cfg, nil, nil, nil)
require.NoError(t, err, "middleware creation should succeed")
require.NotNil(t, authMw, "auth middleware should not be nil")
require.NotNil(t, authzMw, "authz middleware should not be nil")
Expand Down
13 changes: 10 additions & 3 deletions pkg/vmcp/auth/factory/incoming.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
"github.com/stacklok/toolhive/pkg/authz"
"github.com/stacklok/toolhive/pkg/authz/authorizers"
"github.com/stacklok/toolhive/pkg/authz/authorizers/cedar"
Expand Down Expand Up @@ -51,6 +52,7 @@ func NewIncomingAuthMiddleware(
cfg *config.IncomingAuthConfig,
passThroughTools map[string]struct{},
upstreamReader upstreamtoken.TokenReader,
keyProvider keys.PublicKeyProvider,
) (
authMw func(http.Handler) http.Handler,
authzMw func(http.Handler) http.Handler,
Expand All @@ -65,7 +67,7 @@ func NewIncomingAuthMiddleware(

switch cfg.Type {
case "oidc":
authMiddleware, authInfoHandler, err = newOIDCAuthMiddleware(ctx, cfg.OIDC, upstreamReader)
authMiddleware, authInfoHandler, err = newOIDCAuthMiddleware(ctx, cfg.OIDC, upstreamReader, keyProvider)
case "local":
authMiddleware, authInfoHandler, err = newLocalAuthMiddleware(ctx)
case "anonymous":
Expand Down Expand Up @@ -151,6 +153,7 @@ func newOIDCAuthMiddleware(
ctx context.Context,
oidcCfg *config.OIDCConfig,
reader upstreamtoken.TokenReader,
keyProvider keys.PublicKeyProvider,
) (func(http.Handler) http.Handler, http.Handler, error) {
if oidcCfg == nil {
return nil, nil, fmt.Errorf("OIDC configuration required when Type='oidc'")
Expand All @@ -175,9 +178,13 @@ func newOIDCAuthMiddleware(
Scopes: oidcCfg.Scopes,
}

// Wire the upstream token reader so the JWT validator can enrich Identity
// with upstream provider tokens (needed for upstream_inject auth strategy).
// Wire optional dependencies from the embedded auth server so the JWT
// validator can (a) resolve JWKS keys in-process instead of self-referential
// HTTP calls, and (b) enrich Identity with upstream provider tokens.
var opts []auth.TokenValidatorOption
if keyProvider != nil {
opts = append(opts, auth.WithKeyProvider(keyProvider))
}
if reader != nil {
opts = append(opts, auth.WithUpstreamTokenReader(reader))
}
Expand Down
213 changes: 213 additions & 0 deletions pkg/vmcp/auth/factory/incoming_keyprovider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package factory

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"

pkgauth "github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
keysmocks "github.com/stacklok/toolhive/pkg/authserver/server/keys/mocks"
"github.com/stacklok/toolhive/pkg/vmcp/config"
)

// TestNewOIDCAuthMiddleware_KeyProvider_LocalResolution verifies that when a
// PublicKeyProvider is wired in, key resolution happens in-process via the
// local provider rather than through an HTTP JWKS fetch.
func TestNewOIDCAuthMiddleware_KeyProvider_LocalResolution(t *testing.T) {
t.Parallel()

// Generate an ECDSA P-256 key pair (matching the embedded auth server's
// default GeneratingProvider algorithm).
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

const ecdsaKeyID = "test-ecdsa-key-1"

// Stand up a minimal OIDC discovery server so issuer validation passes.
// The JWKS endpoint returns an empty key set — all key resolution should
// happen through the local provider, not HTTP.
server, _ := newTestOIDCServer(t)
t.Cleanup(server.Close)

issuer := server.URL

oidcCfg := &config.OIDCConfig{
Issuer: issuer,
ClientID: "test-client",
Audience: "test-audience",
InsecureAllowHTTP: true,
JwksAllowPrivateIP: true,
}

ctrl := gomock.NewController(t)
mockProvider := keysmocks.NewMockPublicKeyProvider(ctrl)
mockProvider.EXPECT().
PublicKeys(gomock.Any()).
Return([]*keys.PublicKeyData{{
KeyID: ecdsaKeyID,
Algorithm: "ES256",
PublicKey: &privateKey.PublicKey,
CreatedAt: time.Now(),
}}, nil).
AnyTimes()

authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, mockProvider)
require.NoError(t, err, "middleware creation should succeed with key provider")
require.NotNil(t, authMw)

var capturedIdentity *pkgauth.Identity
handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context())
}))

// Sign a JWT with the ECDSA private key — only the local provider
// holds the matching public key.
tok := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{
"iss": issuer,
"aud": "test-audience",
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
})
tok.Header["kid"] = ecdsaKeyID
tokenString, err := tok.SignedString(privateKey)
require.NoError(t, err)

req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("Authorization", "Bearer "+tokenString)
rr := httptest.NewRecorder()

handler.ServeHTTP(rr, req)

require.Equal(t, http.StatusOK, rr.Code, "request should succeed via local key provider")
require.NotNil(t, capturedIdentity, "identity should be present in context")
assert.Equal(t, "test-user", capturedIdentity.Subject)
}

// TestNewOIDCAuthMiddleware_KeyProvider_HTTPFallback verifies that when the
// key provider is nil, key resolution falls back to an HTTP JWKS fetch.
func TestNewOIDCAuthMiddleware_KeyProvider_HTTPFallback(t *testing.T) {
t.Parallel()

// Use the RSA key from the test OIDC server (served via HTTP JWKS).
server, rsaPrivateKey := newTestOIDCServer(t)
t.Cleanup(server.Close)

issuer := server.URL
oidcCfg := &config.OIDCConfig{
Issuer: issuer,
ClientID: "test-client",
Audience: "test-audience",
InsecureAllowHTTP: true,
JwksAllowPrivateIP: true,
}

authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, nil)
require.NoError(t, err)
require.NotNil(t, authMw)

var capturedIdentity *pkgauth.Identity
handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context())
}))

token := signJWT(t, rsaPrivateKey, jwt.MapClaims{
"iss": issuer,
"aud": "test-audience",
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
})

req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
rr := httptest.NewRecorder()

handler.ServeHTTP(rr, req)

require.Equal(t, http.StatusOK, rr.Code, "request should succeed via HTTP JWKS fallback")
require.NotNil(t, capturedIdentity, "identity should be present in context")
assert.Equal(t, "test-user", capturedIdentity.Subject)
}

// TestNewOIDCAuthMiddleware_KeyProvider_KidMissFallback verifies that when the
// local PublicKeyProvider does not hold a key matching the JWT's kid, the
// validator falls back to HTTP JWKS and the request still succeeds. This
// confirms the end-to-end wiring for the kid-miss path at the factory level.
func TestNewOIDCAuthMiddleware_KeyProvider_KidMissFallback(t *testing.T) {
t.Parallel()

// Stand up a real OIDC server that serves the RSA key via HTTP JWKS.
server, rsaPrivateKey := newTestOIDCServer(t)
t.Cleanup(server.Close)

issuer := server.URL
oidcCfg := &config.OIDCConfig{
Issuer: issuer,
ClientID: "test-client",
Audience: "test-audience",
InsecureAllowHTTP: true,
JwksAllowPrivateIP: true,
}

// Wire a mock provider that returns a key with a *different* kid than the
// one in the JWT. The validator should call the local provider first, get a
// kid-miss (nil key returned), and then fall back to HTTP JWKS.
ctrl := gomock.NewController(t)
mockProvider := keysmocks.NewMockPublicKeyProvider(ctrl)

// Generate a throwaway ECDSA key so the mock returns a non-nil key list
// with a different kid.
throwawayKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

mockProvider.EXPECT().
PublicKeys(gomock.Any()).
Return([]*keys.PublicKeyData{{
KeyID: "unrelated-key-id", // does NOT match testKeyID used by signJWT
Algorithm: "ES256",
PublicKey: &throwawayKey.PublicKey,
CreatedAt: time.Now(),
}}, nil).
AnyTimes()

authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, mockProvider)
require.NoError(t, err)
require.NotNil(t, authMw)

var capturedIdentity *pkgauth.Identity
handler := authMw(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedIdentity, _ = pkgauth.IdentityFromContext(r.Context())
}))

// Sign the JWT with the RSA key from the test server (kid = testKeyID).
// The mock provider holds a key with a different kid, so the validator must
// fall back to HTTP JWKS to find the matching key.
token := signJWT(t, rsaPrivateKey, jwt.MapClaims{
"iss": issuer,
"aud": "test-audience",
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
})

req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
rr := httptest.NewRecorder()

handler.ServeHTTP(rr, req)

require.Equal(t, http.StatusOK, rr.Code, "request should succeed via HTTP JWKS fallback on kid-miss")
require.NotNil(t, capturedIdentity, "identity should be present in context")
assert.Equal(t, "test-user", capturedIdentity.Subject)
}
2 changes: 1 addition & 1 deletion pkg/vmcp/auth/factory/incoming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func TestNewIncomingAuthMiddleware(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

authMw, authzMw, authInfo, err := NewIncomingAuthMiddleware(t.Context(), tt.cfg, nil, nil)
authMw, authzMw, authInfo, err := NewIncomingAuthMiddleware(t.Context(), tt.cfg, nil, nil, nil)

if tt.wantErr {
require.Error(t, err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/vmcp/auth/factory/incoming_upstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func TestNewOIDCAuthMiddleware_UpstreamTokenReaderWiring(t *testing.T) {
GetAllValidTokens(gomock.Any(), "session-abc").
Return(map[string]string{"google": "gcp-access-token"}, nil)

authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader)
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader, nil)
require.NoError(t, err, "middleware creation should succeed with non-nil reader")
require.NotNil(t, authMw)

Expand Down Expand Up @@ -145,7 +145,7 @@ func TestNewOIDCAuthMiddleware_UpstreamTokenReaderWiring(t *testing.T) {
t.Run("upstream tokens nil when reader is nil", func(t *testing.T) {
t.Parallel()

authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil)
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, nil, nil)
require.NoError(t, err)
require.NotNil(t, authMw)

Expand Down Expand Up @@ -181,7 +181,7 @@ func TestNewOIDCAuthMiddleware_UpstreamTokenReaderWiring(t *testing.T) {
reader := upstreamtokenmocks.NewMockTokenReader(ctrl)
// No EXPECT -- reader should not be called when tsid is absent.

authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader)
authMw, _, err := newOIDCAuthMiddleware(t.Context(), oidcCfg, reader, nil)
require.NoError(t, err)
require.NotNil(t, authMw)

Expand Down
Loading