diff --git a/cmd/serve.go b/cmd/serve.go index 07ed23a13..5c0dee646 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -412,8 +412,9 @@ func buildAPIDependencies( roleService := role.NewService(roleRepository, relationService, permissionService, auditRecordRepository, cfg.App.PAT.DeniedPermissionsSet()) policyService := policy.NewService(policyPGRepository, relationService, roleService) userService := user.NewService(userRepository, relationService, policyService, roleService) + patValidator := userpat.NewValidator(logger, userPATRepo, cfg.App.PAT) authnService := authenticate.NewService(logger, cfg.App.Authentication, - postgres.NewFlowRepository(logger, dbc), mailDialer, tokenService, sessionService, userService, serviceUserService, webAuthConfig) + postgres.NewFlowRepository(logger, dbc), mailDialer, tokenService, sessionService, userService, serviceUserService, webAuthConfig, patValidator) groupService := group.NewService(groupRepository, relationService, authnService, policyService) organizationService := organization.NewService(organizationRepository, relationService, userService, authnService, policyService, preferenceService, auditRecordRepository) diff --git a/core/authenticate/authenticate.go b/core/authenticate/authenticate.go index 553cb6288..f84fb5f79 100644 --- a/core/authenticate/authenticate.go +++ b/core/authenticate/authenticate.go @@ -7,6 +7,7 @@ import ( "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" + pat "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/pkg/metadata" @@ -42,6 +43,8 @@ const ( // ClientCredentialsClientAssertion is used to authenticate using client_id and client_secret // that provides access token for the client ClientCredentialsClientAssertion ClientAssertion = "client_credentials" + // PATClientAssertion is used to authenticate using Personal Access Token + PATClientAssertion ClientAssertion = "pat" // PassthroughHeaderClientAssertion is used to authenticate using headers passed by the client // this is non secure way of authenticating client in test environments PassthroughHeaderClientAssertion ClientAssertion = "passthrough_header" @@ -53,9 +56,10 @@ func (a ClientAssertion) String() string { var APIAssertions = []ClientAssertion{ SessionClientAssertion, + PATClientAssertion, AccessTokenClientAssertion, - OpaqueTokenClientAssertion, JWTGrantClientAssertion, + OpaqueTokenClientAssertion, // ClientCredentialsClientAssertion should be removed in future to avoid DDOS attacks on CPU // and should only be allowed to be used get access token for the client ClientCredentialsClientAssertion, @@ -131,9 +135,10 @@ type Principal struct { // ID is the unique identifier of principal ID string // Type is the namespace of principal - // E.g. app/user, app/serviceuser + // E.g. app/user, app/serviceuser, app/pat Type string User *user.User ServiceUser *serviceuser.ServiceUser + PAT *pat.PAT } diff --git a/core/authenticate/authenticators.go b/core/authenticate/authenticators.go new file mode 100644 index 000000000..5dd361d36 --- /dev/null +++ b/core/authenticate/authenticators.go @@ -0,0 +1,227 @@ +package authenticate + +import ( + "context" + "encoding/base64" + "fmt" + "strings" + + "github.com/lestrrat-go/jwx/v2/jwt" + frontiersession "github.com/raystack/frontier/core/authenticate/session" + "github.com/raystack/frontier/core/authenticate/token" + patErrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/internal/bootstrap/schema" + "github.com/raystack/frontier/pkg/errors" + "github.com/raystack/frontier/pkg/utils" +) + +// AuthenticatorFunc attempts to authenticate a request. +// Returns (Principal, nil) on success, errSkip if not applicable (try next), +// or any other error for a terminal authentication failure. +type AuthenticatorFunc func(ctx context.Context, s *Service) (Principal, error) + +// authenticators maps each ClientAssertion to its authentication function. +var authenticators = map[ClientAssertion]AuthenticatorFunc{ + SessionClientAssertion: authenticateWithSession, + PATClientAssertion: authenticateWithPAT, + AccessTokenClientAssertion: authenticateWithAccessToken, + JWTGrantClientAssertion: authenticateWithJWTGrant, + ClientCredentialsClientAssertion: authenticateWithClientCredentials, + OpaqueTokenClientAssertion: authenticateWithClientCredentials, + PassthroughHeaderClientAssertion: authenticateWithPassthroughHeader, +} + +// authenticateWithSession extracts user from session cookie. +// Copied from original GetPrincipal session block. +func authenticateWithSession(ctx context.Context, s *Service) (Principal, error) { + session, err := s.sessionService.ExtractFromContext(ctx) + if err == nil && session.IsValid(s.Now()) && utils.IsValidUUID(session.UserID) { + // userID is a valid uuid + currentUser, err := s.userService.GetByID(ctx, session.UserID) + if err != nil { + s.log.Debug(fmt.Sprintf("unable to get session user by id: %v", err)) + return Principal{}, err + } + return Principal{ + ID: currentUser.ID, + Type: schema.UserPrincipal, + User: ¤tUser, + }, nil + } + if err != nil && !errors.Is(err, frontiersession.ErrNoSession) { + s.log.Debug(fmt.Sprintf("unable to extract session from context: %v", err)) + return Principal{}, err + } + return Principal{}, errSkip +} + +// authenticateWithPAT validates a personal access token. +func authenticateWithPAT(ctx context.Context, s *Service) (Principal, error) { + value, ok := GetTokenFromContext(ctx) + if !ok { + return Principal{}, errSkip + } + + pat, err := s.userPATService.Validate(ctx, value) + if err != nil { + if errors.Is(err, patErrors.ErrInvalidPAT) || errors.Is(err, patErrors.ErrDisabled) { + return Principal{}, errSkip + } + s.log.Debug("PAT validation failed", "err", err) + return Principal{}, err + } + + // resolve the owning user so downstream handlers can access principal.User + currentUser, err := s.userService.GetByID(ctx, pat.UserID) + if err != nil { + s.log.Debug("failed to get PAT owner", "err", err) + return Principal{}, err + } + + return Principal{ + ID: pat.ID, + Type: schema.PATPrincipal, + PAT: &pat, + User: ¤tUser, + }, nil +} + +// authenticateWithAccessToken validates a Frontier-issued JWT access token. +// Copied from original GetPrincipal access token block. +func authenticateWithAccessToken(ctx context.Context, s *Service) (Principal, error) { + userToken, ok := GetTokenFromContext(ctx) + if !ok { + return Principal{}, errSkip + } + + insecureJWT, err := jwt.ParseInsecure([]byte(userToken)) + if err != nil { + // NOTE: in the original code, AccessToken and JWTGrant were in the same if-block, + // so JWT parse failure fell through to GetByJWT. With separate authenticators, + // errSkip is required to preserve that behavior. + s.log.Debug(fmt.Sprintf("unable to parse token: %v", err)) + return Principal{}, errSkip + } + + // check type of jwt + if genClaim, ok := insecureJWT.Get(token.GeneratedClaimKey); ok { + // jwt generated by frontier using public key + claimVal, ok := genClaim.(string) + if !ok || claimVal != token.GeneratedClaimValue { + s.log.Debug("generated claim value mismatch") + return Principal{}, errors.ErrUnauthenticated + } + + // extract user from token if present as its created by frontier + userID, claims, err := s.internalTokenService.Parse(ctx, []byte(userToken)) + if err != nil || !utils.IsValidUUID(userID) { + s.log.Debug("failed to parse as internal token ", "err", err) + return Principal{}, errors.ErrUnauthenticated + } + + // userID is a valid uuid + if claims[token.SubTypeClaimsKey] == schema.ServiceUserPrincipal { + currentUser, err := s.serviceUserService.Get(ctx, userID) + if err != nil { + s.log.Debug("failed to get service user", "err", err) + return Principal{}, err + } + return Principal{ + ID: currentUser.ID, + Type: schema.ServiceUserPrincipal, + ServiceUser: ¤tUser, + }, nil + } + + currentUser, err := s.userService.GetByID(ctx, userID) + if err != nil { + s.log.Debug("failed to get user", "err", err) + return Principal{}, err + } + return Principal{ + ID: currentUser.ID, + Type: schema.UserPrincipal, + User: ¤tUser, + }, nil + } + + // NOTE: in the original code, a valid JWT without GeneratedClaimKey fell through to + // GetByJWT within the same if-block. errSkip preserves that behavior. + return Principal{}, errSkip +} + +// authenticateWithJWTGrant validates a service user JWT grant token. +// Copied from original GetPrincipal jwt grant block. +func authenticateWithJWTGrant(ctx context.Context, s *Service) (Principal, error) { + userToken, ok := GetTokenFromContext(ctx) + if !ok { + return Principal{}, errSkip + } + + serviceUser, err := s.serviceUserService.GetByJWT(ctx, userToken) + if err == nil { + return Principal{ + ID: serviceUser.ID, + Type: schema.ServiceUserPrincipal, + ServiceUser: &serviceUser, + }, nil + } + s.log.Debug("failed to parse as user token ", "err", err) + return Principal{}, errors.ErrUnauthenticated +} + +// authenticateWithClientCredentials validates client_id:client_secret credentials. +// Copied from original GetPrincipal client credentials block. +func authenticateWithClientCredentials(ctx context.Context, s *Service) (Principal, error) { + userSecretRaw, ok := GetSecretFromContext(ctx) + if !ok { + return Principal{}, errSkip + } + + // verify client secret + userSecret, err := base64.StdEncoding.DecodeString(userSecretRaw) + if err != nil { + s.log.Debug("failed to decode user secret", "err", err) + return Principal{}, errors.ErrUnauthenticated + } + userSecretParts := strings.Split(string(userSecret), ":") + if len(userSecretParts) != 2 { + s.log.Debug("failed to parse user secret") + return Principal{}, errors.ErrUnauthenticated + } + clientID, clientSecret := userSecretParts[0], userSecretParts[1] + + // extract user from secret if it's a service user + serviceUser, err := s.serviceUserService.GetBySecret(ctx, clientID, clientSecret) + if err == nil { + return Principal{ + ID: serviceUser.ID, + Type: schema.ServiceUserPrincipal, + ServiceUser: &serviceUser, + }, nil + } + s.log.Debug("failed to authenticate with client credentials", "err", err) + return Principal{}, errors.ErrUnauthenticated +} + +// authenticateWithPassthroughHeader extracts user from email header. +// Copied from original GetPrincipal passthrough block. +func authenticateWithPassthroughHeader(ctx context.Context, s *Service) (Principal, error) { + // check if header with user email is set + // TODO(kushsharma): this should ideally be deprecated + val, ok := GetEmailFromContext(ctx) + if !ok || len(val) == 0 { + return Principal{}, errSkip + } + + currentUser, err := s.getOrCreateUser(ctx, strings.TrimSpace(val), strings.Split(val, "@")[0]) + if err != nil { + s.log.Debug("failed to get user", "err", err) + return Principal{}, err + } + return Principal{ + ID: currentUser.ID, + Type: schema.UserPrincipal, + User: ¤tUser, + }, nil +} diff --git a/core/authenticate/errors.go b/core/authenticate/errors.go index d2e4172bc..94002e7a5 100644 --- a/core/authenticate/errors.go +++ b/core/authenticate/errors.go @@ -4,4 +4,7 @@ import "errors" var ( ErrInvalidID = errors.New("user id is invalid") + + // errSkip signals that this authenticator doesn't apply to the request. + errSkip = errors.New("skip authenticator") ) diff --git a/core/authenticate/mocks/user_pat_service.go b/core/authenticate/mocks/user_pat_service.go new file mode 100644 index 000000000..f4da52e89 --- /dev/null +++ b/core/authenticate/mocks/user_pat_service.go @@ -0,0 +1,94 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + models "github.com/raystack/frontier/core/userpat/models" + mock "github.com/stretchr/testify/mock" +) + +// UserPATService is an autogenerated mock type for the UserPATService type +type UserPATService struct { + mock.Mock +} + +type UserPATService_Expecter struct { + mock *mock.Mock +} + +func (_m *UserPATService) EXPECT() *UserPATService_Expecter { + return &UserPATService_Expecter{mock: &_m.Mock} +} + +// Validate provides a mock function with given fields: ctx, value +func (_m *UserPATService) Validate(ctx context.Context, value string) (models.PAT, error) { + ret := _m.Called(ctx, value) + + if len(ret) == 0 { + panic("no return value specified for Validate") + } + + var r0 models.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (models.PAT, error)); ok { + return rf(ctx, value) + } + if rf, ok := ret.Get(0).(func(context.Context, string) models.PAT); ok { + r0 = rf(ctx, value) + } else { + r0 = ret.Get(0).(models.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, value) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UserPATService_Validate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Validate' +type UserPATService_Validate_Call struct { + *mock.Call +} + +// Validate is a helper method to define mock.On call +// - ctx context.Context +// - value string +func (_e *UserPATService_Expecter) Validate(ctx interface{}, value interface{}) *UserPATService_Validate_Call { + return &UserPATService_Validate_Call{Call: _e.mock.On("Validate", ctx, value)} +} + +func (_c *UserPATService_Validate_Call) Run(run func(ctx context.Context, value string)) *UserPATService_Validate_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *UserPATService_Validate_Call) Return(_a0 models.PAT, _a1 error) *UserPATService_Validate_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *UserPATService_Validate_Call) RunAndReturn(run func(context.Context, string) (models.PAT, error)) *UserPATService_Validate_Call { + _c.Call.Return(run) + return _c +} + +// NewUserPATService creates a new instance of UserPATService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewUserPATService(t interface { + mock.TestingT + Cleanup(func()) +}) *UserPATService { + mock := &UserPATService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/authenticate/service.go b/core/authenticate/service.go index abdacf861..45a7754be 100644 --- a/core/authenticate/service.go +++ b/core/authenticate/service.go @@ -18,10 +18,9 @@ import ( "golang.org/x/exp/slices" - "github.com/lestrrat-go/jwx/v2/jwt" - frontiersession "github.com/raystack/frontier/core/authenticate/session" "github.com/raystack/frontier/core/serviceuser" + patModels "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/internal/metrics" "github.com/raystack/frontier/pkg/errors" @@ -89,6 +88,10 @@ type TokenService interface { Parse(ctx context.Context, userToken []byte) (string, map[string]any, error) } +type UserPATService interface { + Validate(ctx context.Context, value string) (patModels.PAT, error) +} + type Service struct { log log.Logger cron *cron.Cron @@ -100,12 +103,14 @@ type Service struct { internalTokenService TokenService sessionService SessionService serviceUserService ServiceUserService + userPATService UserPATService webAuth *webauthn.WebAuthn } func NewService(logger log.Logger, config Config, flowRepo FlowRepository, mailDialer mailer.Dialer, tokenService TokenService, sessionService SessionService, - userService UserService, serviceUserService ServiceUserService, webAuthConfig *webauthn.WebAuthn) *Service { + userService UserService, serviceUserService ServiceUserService, webAuthConfig *webauthn.WebAuthn, + userPATService UserPATService) *Service { r := &Service{ log: logger, cron: cron.New(cron.WithChain( @@ -122,6 +127,7 @@ func NewService(logger log.Logger, config Config, flowRepo FlowRepository, internalTokenService: tokenService, sessionService: sessionService, serviceUserService: serviceUserService, + userPATService: userPATService, webAuth: webAuthConfig, } return r @@ -746,157 +752,26 @@ func (s Service) GetPrincipal(ctx context.Context, assertions ...ClientAssertion defer promCollect() } - var currentPrincipal Principal - if len(assertions) == 0 { - // check all assertions - assertions = APIAssertions - } - - // check if already enriched by auth middleware if val, ok := GetPrincipalFromContext(ctx); ok { - currentPrincipal = *val - return currentPrincipal, nil + return *val, nil } - // extract user from session if present - if slices.Contains[[]ClientAssertion](assertions, SessionClientAssertion) { - session, err := s.sessionService.ExtractFromContext(ctx) - if err == nil && session.IsValid(s.Now()) && utils.IsValidUUID(session.UserID) { - // userID is a valid uuid - currentUser, err := s.userService.GetByID(ctx, session.UserID) - if err != nil { - s.log.Debug(fmt.Sprintf("unable to get session user by id: %v", err)) - return Principal{}, err - } - return Principal{ - ID: currentUser.ID, - Type: schema.UserPrincipal, - User: ¤tUser, - }, nil - } - if err != nil && !errors.Is(err, frontiersession.ErrNoSession) { - s.log.Debug(fmt.Sprintf("unable to extract session from context: %v", err)) - return Principal{}, err - } + if len(assertions) == 0 { + // check all assertions + assertions = APIAssertions } - // check for token - userToken, tokenOK := GetTokenFromContext(ctx) - if tokenOK { - if slices.Contains[[]ClientAssertion](assertions, AccessTokenClientAssertion) { - insecureJWT, err := jwt.ParseInsecure([]byte(userToken)) - if err != nil { - s.log.Debug(fmt.Sprintf("unable to parse token: %v", err)) - return Principal{}, errors.ErrUnauthenticated - } - // check type of jwt - if genClaim, ok := insecureJWT.Get(token.GeneratedClaimKey); ok { - // jwt generated by frontier using public key - claimVal, ok := genClaim.(string) - if !ok || claimVal != token.GeneratedClaimValue { - s.log.Debug("generated claim value mismatch") - return Principal{}, errors.ErrUnauthenticated - } - - // extract user from token if present as its created by frontier - userID, claims, err := s.internalTokenService.Parse(ctx, []byte(userToken)) - if err != nil || !utils.IsValidUUID(userID) { - s.log.Debug("failed to parse as internal token ", "err", err) - return Principal{}, errors.ErrUnauthenticated - } - - // userID is a valid uuid - if claims[token.SubTypeClaimsKey] == schema.ServiceUserPrincipal { - currentUser, err := s.serviceUserService.Get(ctx, userID) - if err != nil { - s.log.Debug("failed to get service user", "err", err) - return Principal{}, err - } - return Principal{ - ID: currentUser.ID, - Type: schema.ServiceUserPrincipal, - ServiceUser: ¤tUser, - }, nil - } - - currentUser, err := s.userService.GetByID(ctx, userID) - if err != nil { - s.log.Debug("failed to get user", "err", err) - return Principal{}, err - } - return Principal{ - ID: currentUser.ID, - Type: schema.UserPrincipal, - User: ¤tUser, - }, nil - } - } - - // extract user from token if it's a service user - if slices.Contains[[]ClientAssertion](assertions, JWTGrantClientAssertion) { - serviceUser, err := s.serviceUserService.GetByJWT(ctx, userToken) - if err == nil { - return Principal{ - ID: serviceUser.ID, - Type: schema.ServiceUserPrincipal, - ServiceUser: &serviceUser, - }, nil - } - if err != nil { - s.log.Debug("failed to parse as user token ", "err", err) - return Principal{}, errors.ErrUnauthenticated - } + for _, assertion := range assertions { + authenticator, exists := authenticators[assertion] + if !exists { + continue } - } - - // check for client secret - if slices.Contains[[]ClientAssertion](assertions, ClientCredentialsClientAssertion) || - slices.Contains[[]ClientAssertion](assertions, OpaqueTokenClientAssertion) { - userSecretRaw, secretOK := GetSecretFromContext(ctx) - if secretOK { - // verify client secret - userSecret, err := base64.StdEncoding.DecodeString(userSecretRaw) - if err != nil { - s.log.Debug("failed to decode user secret", "err", err) - return Principal{}, errors.ErrUnauthenticated - } - userSecretParts := strings.Split(string(userSecret), ":") - if len(userSecretParts) != 2 { - s.log.Debug("failed to parse user secret", "err", err) - return Principal{}, errors.ErrUnauthenticated - } - clientID, clientSecret := userSecretParts[0], userSecretParts[1] - - // extract user from secret if it's a service user - serviceUser, err := s.serviceUserService.GetBySecret(ctx, clientID, clientSecret) - if err == nil { - return Principal{ - ID: serviceUser.ID, - Type: schema.ServiceUserPrincipal, - ServiceUser: &serviceUser, - }, nil - } - if err != nil { - s.log.Debug("failed to parse as user token ", "err", err) - return Principal{}, errors.ErrUnauthenticated - } + principal, err := authenticator(ctx, &s) + if err == nil { + return principal, nil } - } - - if slices.Contains[[]ClientAssertion](assertions, PassthroughHeaderClientAssertion) { - // check if header with user email is set - // TODO(kushsharma): this should ideally be deprecated - if val, ok := GetEmailFromContext(ctx); ok && len(val) > 0 { - currentUser, err := s.getOrCreateUser(ctx, strings.TrimSpace(val), strings.Split(val, "@")[0]) - if err != nil { - s.log.Debug("failed to get user", "err", err) - return Principal{}, err - } - return Principal{ - ID: currentUser.ID, - Type: schema.UserPrincipal, - User: ¤tUser, - }, nil + if !errors.Is(err, errSkip) { + return Principal{}, err } } diff --git a/core/authenticate/service_test.go b/core/authenticate/service_test.go index 7ce41de71..17d0069d9 100644 --- a/core/authenticate/service_test.go +++ b/core/authenticate/service_test.go @@ -76,7 +76,7 @@ func TestService_GetPrincipal(t *testing.T) { }, wantErr: false, setup: func() *authenticate.Service { - return authenticate.NewService(nil, authenticate.Config{}, nil, nil, nil, nil, nil, nil, nil) + return authenticate.NewService(nil, authenticate.Config{}, nil, nil, nil, nil, nil, nil, nil, nil) }, }, { @@ -111,7 +111,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -135,7 +135,7 @@ func TestService_GetPrincipal(t *testing.T) { mockSessionService.EXPECT().ExtractFromContext(mock.Anything).Return(mockSess, nil) return authenticate.NewService(log.NewLogrus(), authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -163,7 +163,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -181,7 +181,7 @@ func TestService_GetPrincipal(t *testing.T) { mockTokenService.EXPECT().Parse(mock.Anything, tokenBytes).Return("", map[string]interface{}{}, errors.New("invalid token")) return authenticate.NewService(log.NewLogrus(), authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -208,7 +208,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -226,7 +226,7 @@ func TestService_GetPrincipal(t *testing.T) { mockServiceUserService.EXPECT().GetByJWT(mock.Anything, string(tokenBytes)).Return(serviceuser.ServiceUser{}, errors.New("invalid")) return authenticate.NewService(log.NewLogrus(), authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -253,7 +253,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -280,7 +280,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, } @@ -339,7 +339,7 @@ func TestService_StartFlow(t *testing.T) { wantErr: authenticate.ErrUnsupportedMethod, setup: func() *authenticate.Service { return authenticate.NewService(nil, authenticate.Config{}, nil, nil, - nil, nil, nil, nil, nil) + nil, nil, nil, nil, nil, nil) }, }, { @@ -370,7 +370,7 @@ func TestService_StartFlow(t *testing.T) { TestUsers: testusers.Config{Enabled: true, OTP: "111111", Domain: "example.com"}, }, mockFlowRepo, mockDialer, nil, nil, - nil, nil, nil) + nil, nil, nil, nil) srv.Now = func() time.Time { return timeNow } @@ -402,7 +402,7 @@ func TestService_StartFlow(t *testing.T) { TestUsers: testusers.Config{Enabled: true, OTP: "111111", Domain: "example.com"}, }, mockFlowRepo, mockDialer, nil, nil, - nil, nil, nil) + nil, nil, nil, nil) srv.Now = func() time.Time { return timeNow } @@ -433,7 +433,7 @@ func TestService_StartFlow(t *testing.T) { MailOTP: authenticate.MailOTPConfig{}, }, mockFlowRepo, mockDialer, nil, nil, - nil, nil, nil) + nil, nil, nil, nil) srv.Now = func() time.Time { return timeNow } diff --git a/core/relation/errors.go b/core/relation/errors.go index 9d09538d1..cd87d886b 100644 --- a/core/relation/errors.go +++ b/core/relation/errors.go @@ -11,4 +11,5 @@ var ( ErrCreatingRelationInStore = errors.New("error while creating relation") ErrCreatingRelationInAuthzEngine = errors.New("error while creating relation in authz engine") ErrFetchingUser = errors.New("error while fetching user") + ErrSubjectNotAllowed = errors.New("subject type is not allowed on this relation") ) diff --git a/core/relation/service.go b/core/relation/service.go index 0549da742..25341e217 100644 --- a/core/relation/service.go +++ b/core/relation/service.go @@ -5,6 +5,10 @@ import ( "errors" "fmt" "regexp" + + "github.com/raystack/frontier/internal/bootstrap/schema" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type Service struct { @@ -35,6 +39,12 @@ func (s Service) Create(ctx context.Context, rel Relation) (Relation, error) { err = s.authzRepository.Add(ctx, createdRelation) if err != nil { + // PAT subjects may be rejected by the authz schema for relations they are not allowed on + if createdRelation.Subject.Namespace == schema.PATPrincipal { + if st, ok := status.FromError(err); ok && st.Code() == codes.InvalidArgument { + return Relation{}, fmt.Errorf("%w: %s", ErrSubjectNotAllowed, st.Message()) + } + } return Relation{}, fmt.Errorf("%w: %s", ErrCreatingRelationInAuthzEngine, err.Error()) } diff --git a/core/userpat/errors.go b/core/userpat/errors/errors.go similarity index 84% rename from core/userpat/errors.go rename to core/userpat/errors/errors.go index 297069d6b..ea4c1063d 100644 --- a/core/userpat/errors.go +++ b/core/userpat/errors/errors.go @@ -1,4 +1,4 @@ -package userpat +package errors import "errors" @@ -6,7 +6,8 @@ var ( ErrNotFound = errors.New("personal access token not found") ErrConflict = errors.New("personal access token with this name already exists") ErrExpired = errors.New("personal access token has expired") - ErrInvalidToken = errors.New("personal access token is invalid") + ErrInvalidPAT = errors.New("not a personal access token") + ErrMalformedPAT = errors.New("personal access token is malformed") ErrLimitExceeded = errors.New("maximum number of personal access tokens reached") ErrDisabled = errors.New("personal access tokens are not enabled") ErrExpiryExceeded = errors.New("expiry exceeds maximum allowed lifetime") diff --git a/core/userpat/mocks/repository.go b/core/userpat/mocks/repository.go index aae660f04..b4a022075 100644 --- a/core/userpat/mocks/repository.go +++ b/core/userpat/mocks/repository.go @@ -5,8 +5,10 @@ package mocks import ( context "context" - userpat "github.com/raystack/frontier/core/userpat" + models "github.com/raystack/frontier/core/userpat/models" mock "github.com/stretchr/testify/mock" + + time "time" ) // Repository is an autogenerated mock type for the Repository type @@ -81,25 +83,25 @@ func (_c *Repository_CountActive_Call) RunAndReturn(run func(context.Context, st } // Create provides a mock function with given fields: ctx, pat -func (_m *Repository) Create(ctx context.Context, pat userpat.PAT) (userpat.PAT, error) { +func (_m *Repository) Create(ctx context.Context, pat models.PAT) (models.PAT, error) { ret := _m.Called(ctx, pat) if len(ret) == 0 { panic("no return value specified for Create") } - var r0 userpat.PAT + var r0 models.PAT var r1 error - if rf, ok := ret.Get(0).(func(context.Context, userpat.PAT) (userpat.PAT, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, models.PAT) (models.PAT, error)); ok { return rf(ctx, pat) } - if rf, ok := ret.Get(0).(func(context.Context, userpat.PAT) userpat.PAT); ok { + if rf, ok := ret.Get(0).(func(context.Context, models.PAT) models.PAT); ok { r0 = rf(ctx, pat) } else { - r0 = ret.Get(0).(userpat.PAT) + r0 = ret.Get(0).(models.PAT) } - if rf, ok := ret.Get(1).(func(context.Context, userpat.PAT) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, models.PAT) error); ok { r1 = rf(ctx, pat) } else { r1 = ret.Error(1) @@ -115,24 +117,129 @@ type Repository_Create_Call struct { // Create is a helper method to define mock.On call // - ctx context.Context -// - pat userpat.PAT +// - pat models.PAT func (_e *Repository_Expecter) Create(ctx interface{}, pat interface{}) *Repository_Create_Call { return &Repository_Create_Call{Call: _e.mock.On("Create", ctx, pat)} } -func (_c *Repository_Create_Call) Run(run func(ctx context.Context, pat userpat.PAT)) *Repository_Create_Call { +func (_c *Repository_Create_Call) Run(run func(ctx context.Context, pat models.PAT)) *Repository_Create_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(userpat.PAT)) + run(args[0].(context.Context), args[1].(models.PAT)) }) return _c } -func (_c *Repository_Create_Call) Return(_a0 userpat.PAT, _a1 error) *Repository_Create_Call { +func (_c *Repository_Create_Call) Return(_a0 models.PAT, _a1 error) *Repository_Create_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, userpat.PAT) (userpat.PAT, error)) *Repository_Create_Call { +func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, models.PAT) (models.PAT, error)) *Repository_Create_Call { + _c.Call.Return(run) + return _c +} + +// GetBySecretHash provides a mock function with given fields: ctx, secretHash +func (_m *Repository) GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) { + ret := _m.Called(ctx, secretHash) + + if len(ret) == 0 { + panic("no return value specified for GetBySecretHash") + } + + var r0 models.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (models.PAT, error)); ok { + return rf(ctx, secretHash) + } + if rf, ok := ret.Get(0).(func(context.Context, string) models.PAT); ok { + r0 = rf(ctx, secretHash) + } else { + r0 = ret.Get(0).(models.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, secretHash) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetBySecretHash_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBySecretHash' +type Repository_GetBySecretHash_Call struct { + *mock.Call +} + +// GetBySecretHash is a helper method to define mock.On call +// - ctx context.Context +// - secretHash string +func (_e *Repository_Expecter) GetBySecretHash(ctx interface{}, secretHash interface{}) *Repository_GetBySecretHash_Call { + return &Repository_GetBySecretHash_Call{Call: _e.mock.On("GetBySecretHash", ctx, secretHash)} +} + +func (_c *Repository_GetBySecretHash_Call) Run(run func(ctx context.Context, secretHash string)) *Repository_GetBySecretHash_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetBySecretHash_Call) Return(_a0 models.PAT, _a1 error) *Repository_GetBySecretHash_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetBySecretHash_Call) RunAndReturn(run func(context.Context, string) (models.PAT, error)) *Repository_GetBySecretHash_Call { + _c.Call.Return(run) + return _c +} + +// UpdateLastUsedAt provides a mock function with given fields: ctx, id, at +func (_m *Repository) UpdateLastUsedAt(ctx context.Context, id string, at time.Time) error { + ret := _m.Called(ctx, id, at) + + if len(ret) == 0 { + panic("no return value specified for UpdateLastUsedAt") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, time.Time) error); ok { + r0 = rf(ctx, id, at) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_UpdateLastUsedAt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateLastUsedAt' +type Repository_UpdateLastUsedAt_Call struct { + *mock.Call +} + +// UpdateLastUsedAt is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - at time.Time +func (_e *Repository_Expecter) UpdateLastUsedAt(ctx interface{}, id interface{}, at interface{}) *Repository_UpdateLastUsedAt_Call { + return &Repository_UpdateLastUsedAt_Call{Call: _e.mock.On("UpdateLastUsedAt", ctx, id, at)} +} + +func (_c *Repository_UpdateLastUsedAt_Call) Run(run func(ctx context.Context, id string, at time.Time)) *Repository_UpdateLastUsedAt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(time.Time)) + }) + return _c +} + +func (_c *Repository_UpdateLastUsedAt_Call) Return(_a0 error) *Repository_UpdateLastUsedAt_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_UpdateLastUsedAt_Call) RunAndReturn(run func(context.Context, string, time.Time) error) *Repository_UpdateLastUsedAt_Call { _c.Call.Return(run) return _c } diff --git a/core/userpat/models/pat.go b/core/userpat/models/pat.go new file mode 100644 index 000000000..618843d07 --- /dev/null +++ b/core/userpat/models/pat.go @@ -0,0 +1,20 @@ +package models + +import ( + "time" + + "github.com/raystack/frontier/pkg/metadata" +) + +type PAT struct { + ID string `rql:"name=id,type=string"` + UserID string `rql:"name=user_id,type=string"` + OrgID string `rql:"name=org_id,type=string"` + Title string `rql:"name=title,type=string"` + SecretHash string `json:"-"` + Metadata metadata.Metadata + LastUsedAt *time.Time `rql:"name=last_used_at,type=datetime"` // last_used_at can be null + ExpiresAt time.Time `rql:"name=expires_at,type=datetime"` + CreatedAt time.Time `rql:"name=created_at,type=datetime"` + UpdatedAt time.Time `rql:"name=updated_at,type=datetime"` +} diff --git a/core/userpat/service.go b/core/userpat/service.go index 78ce57e50..da9ca16e4 100644 --- a/core/userpat/service.go +++ b/core/userpat/service.go @@ -15,6 +15,8 @@ import ( "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/policy" "github.com/raystack/frontier/core/role" + paterrors "github.com/raystack/frontier/core/userpat/errors" + patmodels "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" pkgAuditRecord "github.com/raystack/frontier/pkg/auditrecord" "github.com/raystack/salt/log" @@ -77,19 +79,19 @@ type CreateRequest struct { // the configured maximum PAT lifetime. func (s *Service) ValidateExpiry(expiresAt time.Time) error { if !expiresAt.After(time.Now()) { - return ErrExpiryInPast + return paterrors.ErrExpiryInPast } if expiresAt.After(time.Now().Add(s.config.MaxExpiry())) { - return ErrExpiryExceeded + return paterrors.ErrExpiryExceeded } return nil } // Create generates a new PAT and returns it with the plaintext value. // The plaintext value is only available at creation time. -func (s *Service) Create(ctx context.Context, req CreateRequest) (PAT, string, error) { +func (s *Service) Create(ctx context.Context, req CreateRequest) (patmodels.PAT, string, error) { if !s.config.Enabled { - return PAT{}, "", ErrDisabled + return patmodels.PAT{}, "", paterrors.ErrDisabled } // NOTE: CountActive + Create is not atomic (TOCTOU race). Two concurrent requests @@ -98,23 +100,23 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (PAT, string, e // use an atomic INSERT ... SELECT with a count subquery in the WHERE clause. count, err := s.repo.CountActive(ctx, req.UserID, req.OrgID) if err != nil { - return PAT{}, "", fmt.Errorf("counting active PATs: %w", err) + return patmodels.PAT{}, "", fmt.Errorf("counting active PATs: %w", err) } if count >= s.config.MaxPerUserPerOrg { - return PAT{}, "", ErrLimitExceeded + return patmodels.PAT{}, "", paterrors.ErrLimitExceeded } roles, err := s.resolveAndValidateRoles(ctx, req.RoleIDs) if err != nil { - return PAT{}, "", err + return patmodels.PAT{}, "", err } patValue, secretHash, err := s.generatePAT() if err != nil { - return PAT{}, "", err + return patmodels.PAT{}, "", err } - pat := PAT{ + pat := patmodels.PAT{ UserID: req.UserID, OrgID: req.OrgID, Title: req.Title, @@ -125,11 +127,11 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (PAT, string, e created, err := s.repo.Create(ctx, pat) if err != nil { - return PAT{}, "", err + return patmodels.PAT{}, "", err } if err := s.createPolicies(ctx, created.ID, req.OrgID, roles, req.ProjectIDs); err != nil { - return PAT{}, "", fmt.Errorf("creating policies: %w", err) + return patmodels.PAT{}, "", fmt.Errorf("creating policies: %w", err) } // TODO: move audit record creation into the same transaction as PAT creation to avoid partial state where PAT exists but audit record doesn't. @@ -144,7 +146,7 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (PAT, string, e } // createAuditRecord logs a PAT lifecycle event with org context and PAT metadata. -func (s *Service) createAuditRecord(ctx context.Context, event pkgAuditRecord.Event, pat PAT, occurredAt time.Time, targetMetadata map[string]any) error { +func (s *Service) createAuditRecord(ctx context.Context, event pkgAuditRecord.Event, pat patmodels.PAT, occurredAt time.Time, targetMetadata map[string]any) error { orgName := "" if org, err := s.orgService.GetRaw(ctx, pat.OrgID); err == nil { orgName = org.Title @@ -194,7 +196,7 @@ func (s *Service) resolveAndValidateRoles(ctx context.Context, roleIDs []string) missing = append(missing, id) } } - return nil, fmt.Errorf("role IDs not found: %v: %w", missing, ErrRoleNotFound) + return nil, fmt.Errorf("role IDs not found: %v: %w", missing, paterrors.ErrRoleNotFound) } if err := s.validateRolePermissions(roles); err != nil { @@ -203,11 +205,11 @@ func (s *Service) resolveAndValidateRoles(ctx context.Context, roleIDs []string) for _, r := range roles { if len(r.Scopes) == 0 { - return nil, fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, ErrUnsupportedScope) + return nil, fmt.Errorf("role %s has no scopes defined: %w", r.Name, paterrors.ErrUnsupportedScope) } for _, scope := range r.Scopes { if scope != schema.ProjectNamespace && scope != schema.OrganizationNamespace { - return nil, fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, ErrUnsupportedScope) + return nil, fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, paterrors.ErrUnsupportedScope) } } } @@ -229,7 +231,7 @@ func (s *Service) createPolicies(ctx context.Context, patID, orgID string, roles case slices.Contains(r.Scopes, schema.OrganizationNamespace): err = s.createOrgScopedPolicy(ctx, patID, orgID, r) default: - err = fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, ErrUnsupportedScope) + err = fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, paterrors.ErrUnsupportedScope) } if err != nil { return err @@ -243,7 +245,7 @@ func (s *Service) validateRolePermissions(roles []role.Role) error { for _, r := range roles { for _, perm := range r.Permissions { if _, denied := s.deniedPerms[perm]; denied { - return fmt.Errorf("role %s has denied permission %s: %w", r.Name, perm, ErrDeniedRole) + return fmt.Errorf("role %s has denied permission %s: %w", r.Name, perm, paterrors.ErrDeniedRole) } } } diff --git a/core/userpat/service_test.go b/core/userpat/service_test.go index 33b2f2074..dd1629472 100644 --- a/core/userpat/service_test.go +++ b/core/userpat/service_test.go @@ -10,12 +10,14 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/raystack/frontier/core/auditrecord/models" + auditmodels "github.com/raystack/frontier/core/auditrecord/models" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/policy" "github.com/raystack/frontier/core/role" "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" "github.com/raystack/frontier/core/userpat/mocks" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/salt/log" "github.com/stretchr/testify/mock" @@ -46,7 +48,7 @@ func newSuccessMocks(t *testing.T) (*mocks.OrganizationService, *mocks.RoleServi Return(policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). - Return(models.AuditRecord{}, nil).Maybe() + Return(auditmodels.AuditRecord{}, nil).Maybe() return orgSvc, roleSvc, policySvc, auditRepo } @@ -60,7 +62,7 @@ func TestService_Create(t *testing.T) { wantErr bool wantErrIs error wantErrMsg string - validateFunc func(t *testing.T, got userpat.PAT, tokenValue string) + validateFunc func(t *testing.T, got models.PAT, tokenValue string) }{ { name: "should return ErrDisabled when PAT feature is disabled", @@ -72,7 +74,7 @@ func TestService_Create(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), }, wantErr: true, - wantErrIs: userpat.ErrDisabled, + wantErrIs: paterrors.ErrDisabled, setup: func() *userpat.Service { repo := mocks.NewRepository(t) orgSvc := mocks.NewOrganizationService(t) @@ -112,7 +114,7 @@ func TestService_Create(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), }, wantErr: true, - wantErrIs: userpat.ErrLimitExceeded, + wantErrIs: paterrors.ErrLimitExceeded, setup: func() *userpat.Service { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). @@ -132,7 +134,7 @@ func TestService_Create(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), }, wantErr: true, - wantErrIs: userpat.ErrLimitExceeded, + wantErrIs: paterrors.ErrLimitExceeded, setup: func() *userpat.Service { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). @@ -157,8 +159,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{}, errors.New("insert failed")) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{}, errors.New("insert failed")) orgSvc := mocks.NewOrganizationService(t) auditRepo := mocks.NewAuditRecordRepository(t) roleSvc := mocks.NewRoleService(t) @@ -178,13 +180,13 @@ func TestService_Create(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), }, wantErr: true, - wantErrIs: userpat.ErrConflict, + wantErrIs: paterrors.ErrConflict, setup: func() *userpat.Service { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{}, userpat.ErrConflict) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{}, paterrors.ErrConflict) orgSvc := mocks.NewOrganizationService(t) auditRepo := mocks.NewAuditRecordRepository(t) roleSvc := mocks.NewRoleService(t) @@ -210,8 +212,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Run(func(ctx context.Context, pat userpat.PAT) { + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Run(func(ctx context.Context, pat models.PAT) { if pat.UserID != "user-1" { t.Errorf("Create() UserID = %v, want %v", pat.UserID, "user-1") } @@ -231,7 +233,7 @@ func TestService_Create(t *testing.T) { t.Errorf("Create() ExpiresAt = %v, want %v", pat.ExpiresAt, futureExpiry) } }). - Return(userpat.PAT{ + Return(models.PAT{ ID: "pat-id-1", UserID: "user-1", OrgID: "org-1", @@ -243,7 +245,7 @@ func TestService_Create(t *testing.T) { orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, - validateFunc: func(t *testing.T, got userpat.PAT, tokenValue string) { + validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() if got.ID != "pat-id-1" { t.Errorf("Create() ID = %v, want %v", got.ID, "pat-id-1") @@ -270,12 +272,12 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, - validateFunc: func(t *testing.T, got userpat.PAT, tokenValue string) { + validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() if !strings.HasPrefix(tokenValue, "fpt_") { t.Errorf("token should start with prefix fpt_, got %v", tokenValue) @@ -307,12 +309,12 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, - validateFunc: func(t *testing.T, got userpat.PAT, tokenValue string) { + validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() parts := strings.SplitN(tokenValue, "_", 2) if len(parts) != 2 { @@ -343,8 +345,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, userpat.Config{ Enabled: true, @@ -353,7 +355,7 @@ func TestService_Create(t *testing.T) { MaxLifetime: "8760h", }, orgSvc, roleSvc, policySvc, auditRepo) }, - validateFunc: func(t *testing.T, got userpat.PAT, tokenValue string) { + validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() if !strings.HasPrefix(tokenValue, "custom_") { t.Errorf("token should start with custom_, got %v", tokenValue) @@ -374,8 +376,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(49), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, @@ -394,8 +396,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, @@ -428,8 +430,8 @@ func TestService_Create_UniquePATs(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil).Times(2) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil).Times(2) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil).Times(2) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) svc := userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) @@ -460,11 +462,11 @@ func TestService_Create_HashVerification(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Run(func(ctx context.Context, pat userpat.PAT) { + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Run(func(ctx context.Context, pat models.PAT) { capturedHash = pat.SecretHash }). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) svc := userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) @@ -499,14 +501,14 @@ func TestService_Create_HashVerification(t *testing.T) { func TestService_CreatePolicies_OrgScopedRole(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything).Return(models.AuditRecord{}, nil).Maybe() + auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"org-role-1"}}).Return([]role.Role{{ @@ -541,14 +543,14 @@ func TestService_CreatePolicies_OrgScopedRole(t *testing.T) { func TestService_CreatePolicies_ProjectScopedAllProjects(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything).Return(models.AuditRecord{}, nil).Maybe() + auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-role-1"}}).Return([]role.Role{{ @@ -584,14 +586,14 @@ func TestService_CreatePolicies_ProjectScopedAllProjects(t *testing.T) { func TestService_CreatePolicies_ProjectScopedSpecificProjects(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything).Return(models.AuditRecord{}, nil).Maybe() + auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-role-1"}}).Return([]role.Role{{ @@ -663,7 +665,7 @@ func TestService_CreatePolicies_DeniedPermission(t *testing.T) { if err == nil { t.Fatal("Create() expected error for denied permission, got nil") } - if !errors.Is(err, userpat.ErrDeniedRole) { + if !errors.Is(err, paterrors.ErrDeniedRole) { t.Errorf("Create() error = %v, want ErrDeniedRole", err) } } @@ -727,7 +729,7 @@ func TestService_CreatePolicies_UnsupportedScope(t *testing.T) { if err == nil { t.Fatal("Create() expected error for unsupported scope, got nil") } - if !errors.Is(err, userpat.ErrUnsupportedScope) { + if !errors.Is(err, paterrors.ErrUnsupportedScope) { t.Errorf("Create() error = %v, want ErrUnsupportedScope", err) } } @@ -761,7 +763,7 @@ func TestService_CreatePolicies_MissingRoleID(t *testing.T) { if err == nil { t.Fatal("Create() expected error for missing role, got nil") } - if !errors.Is(err, userpat.ErrRoleNotFound) { + if !errors.Is(err, paterrors.ErrRoleNotFound) { t.Errorf("Create() error = %v, want ErrRoleNotFound", err) } if !strings.Contains(err.Error(), "role-b") { @@ -772,8 +774,8 @@ func TestService_CreatePolicies_MissingRoleID(t *testing.T) { func TestService_CreatePolicies_NoRoles(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) svc := userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) @@ -1005,7 +1007,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, // no policies should be created wantErr: true, - wantErrIs: userpat.ErrDeniedRole, + wantErrIs: paterrors.ErrDeniedRole, }, { name: "unsupported scope rejects before any policy creation", @@ -1017,7 +1019,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, // scope validation happens upfront — no token or policies created wantErr: true, - wantErrIs: userpat.ErrUnsupportedScope, + wantErrIs: paterrors.ErrUnsupportedScope, }, { name: "role with mixed supported and unsupported scopes is rejected", @@ -1028,7 +1030,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, wantErr: true, - wantErrIs: userpat.ErrUnsupportedScope, + wantErrIs: paterrors.ErrUnsupportedScope, }, { name: "role with empty scopes is unsupported", @@ -1039,7 +1041,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, wantErr: true, - wantErrIs: userpat.ErrUnsupportedScope, + wantErrIs: paterrors.ErrUnsupportedScope, }, { name: "role count mismatch: requested 2 but found 1", @@ -1050,7 +1052,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, wantErr: true, - wantErrIs: userpat.ErrRoleNotFound, + wantErrIs: paterrors.ErrRoleNotFound, }, } @@ -1065,15 +1067,15 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) // Only mock repo.Create for success cases — validation errors fail before token creation if !tt.wantErr { - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) } orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything).Return(models.AuditRecord{}, nil).Maybe() + auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() // --- roleService: return the test's roles roleSvc := mocks.NewRoleService(t) @@ -1172,8 +1174,8 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { func TestService_CreatePolicies_PolicyCreateFailure(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc := mocks.NewOrganizationService(t) auditRepo := mocks.NewAuditRecordRepository(t) diff --git a/core/userpat/userpat.go b/core/userpat/userpat.go index 6e2965813..0a3dcc3f7 100644 --- a/core/userpat/userpat.go +++ b/core/userpat/userpat.go @@ -4,23 +4,12 @@ import ( "context" "time" - "github.com/raystack/frontier/pkg/metadata" + "github.com/raystack/frontier/core/userpat/models" ) -type PAT struct { - ID string `rql:"name=id,type=string"` - UserID string `rql:"name=user_id,type=string"` - OrgID string `rql:"name=org_id,type=string"` - Title string `rql:"name=title,type=string"` - SecretHash string `json:"-"` - Metadata metadata.Metadata - LastUsedAt *time.Time `rql:"name=last_used_at,type=datetime"` - ExpiresAt time.Time `rql:"name=expires_at,type=datetime"` - CreatedAt time.Time `rql:"name=created_at,type=datetime"` - UpdatedAt time.Time `rql:"name=updated_at,type=datetime"` -} - type Repository interface { - Create(ctx context.Context, pat PAT) (PAT, error) + Create(ctx context.Context, pat models.PAT) (models.PAT, error) CountActive(ctx context.Context, userID, orgID string) (int64, error) + GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) + UpdateLastUsedAt(ctx context.Context, id string, at time.Time) error } diff --git a/core/userpat/validator.go b/core/userpat/validator.go new file mode 100644 index 000000000..f62d3972d --- /dev/null +++ b/core/userpat/validator.go @@ -0,0 +1,69 @@ +package userpat + +import ( + "context" + "encoding/base64" + "encoding/hex" + "fmt" + "strings" + "time" + + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" + "github.com/raystack/salt/log" + "golang.org/x/crypto/sha3" +) + +// Validator validates PAT values during authentication. +type Validator struct { + repo Repository + config Config + logger log.Logger +} + +func NewValidator(logger log.Logger, repo Repository, config Config) *Validator { + return &Validator{ + repo: repo, + config: config, + logger: logger, + } +} + +// Validate checks a PAT value and returns the corresponding PAT. +// Returns ErrInvalidPAT if the value doesn't match the configured prefix (allowing +// the auth chain to fall through to the next authenticator). +// Returns ErrMalformedPAT, ErrExpired, ErrNotFound, or ErrDisabled for terminal auth failures. +func (v *Validator) Validate(ctx context.Context, value string) (models.PAT, error) { + if !v.config.Enabled { + return models.PAT{}, paterrors.ErrDisabled + } + + prefix := v.config.Prefix + "_" + if !strings.HasPrefix(value, prefix) { + return models.PAT{}, paterrors.ErrInvalidPAT + } + + encoded := value[len(prefix):] + secretBytes, err := base64.RawURLEncoding.DecodeString(encoded) + if err != nil { + return models.PAT{}, fmt.Errorf("%w: invalid encoding", paterrors.ErrMalformedPAT) + } + + hash := sha3.Sum256(secretBytes) + secretHash := hex.EncodeToString(hash[:]) + + pat, err := v.repo.GetBySecretHash(ctx, secretHash) + if err != nil { + return models.PAT{}, err + } + + if pat.ExpiresAt.Before(time.Now()) { + return models.PAT{}, paterrors.ErrExpired + } + + if err := v.repo.UpdateLastUsedAt(ctx, pat.ID, time.Now()); err != nil { + return models.PAT{}, fmt.Errorf("updating last_used_at: %w", err) + } + + return pat, nil +} diff --git a/core/userpat/validator_test.go b/core/userpat/validator_test.go new file mode 100644 index 000000000..c9a87ad55 --- /dev/null +++ b/core/userpat/validator_test.go @@ -0,0 +1,148 @@ +package userpat_test + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "errors" + "testing" + "time" + + "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/mocks" + "github.com/raystack/frontier/core/userpat/models" + "github.com/raystack/salt/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/sha3" +) + +func validPATValue(t *testing.T, prefix string) (value string, secretHash string) { + t.Helper() + secretBytes := make([]byte, 32) + _, err := rand.Read(secretBytes) + require.NoError(t, err) + value = prefix + "_" + base64.RawURLEncoding.EncodeToString(secretBytes) + hash := sha3.Sum256(secretBytes) + secretHash = hex.EncodeToString(hash[:]) + return value, secretHash +} + +func TestValidator_Validate(t *testing.T) { + const prefix = "fpt" + cfg := userpat.Config{ + Enabled: true, + Prefix: prefix, + } + + t.Run("disabled feature returns ErrDisabled", func(t *testing.T) { + v := userpat.NewValidator(log.NewNoop(), nil, userpat.Config{Enabled: false}) + _, err := v.Validate(context.Background(), "fpt_anything") + assert.ErrorIs(t, err, paterrors.ErrDisabled) + }) + + t.Run("wrong prefix returns ErrInvalidPAT", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + _, err := v.Validate(context.Background(), "ghp_sometoken") + assert.ErrorIs(t, err, paterrors.ErrInvalidPAT) + }) + + t.Run("no prefix separator returns ErrInvalidPAT", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + _, err := v.Validate(context.Background(), "randomstring") + assert.ErrorIs(t, err, paterrors.ErrInvalidPAT) + }) + + t.Run("malformed base64 returns ErrMalformedPAT", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + _, err := v.Validate(context.Background(), "fpt_!!!not-base64!!!") + assert.ErrorIs(t, err, paterrors.ErrMalformedPAT) + assert.NotErrorIs(t, err, paterrors.ErrInvalidPAT) + }) + + t.Run("unknown hash returns ErrNotFound", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + value, secretHash := validPATValue(t, prefix) + repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(models.PAT{}, paterrors.ErrNotFound) + + _, err := v.Validate(context.Background(), value) + assert.ErrorIs(t, err, paterrors.ErrNotFound) + }) + + t.Run("expired PAT returns ErrExpired", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + value, secretHash := validPATValue(t, prefix) + repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(models.PAT{ + ID: "pat-1", + ExpiresAt: time.Now().Add(-time.Hour), + }, nil) + + _, err := v.Validate(context.Background(), value) + assert.ErrorIs(t, err, paterrors.ErrExpired) + }) + + t.Run("db error propagates as-is", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + value, secretHash := validPATValue(t, prefix) + dbErr := errors.New("connection refused") + repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(models.PAT{}, dbErr) + + _, err := v.Validate(context.Background(), value) + assert.ErrorIs(t, err, dbErr) + assert.NotErrorIs(t, err, paterrors.ErrInvalidPAT) + }) + + t.Run("UpdateLastUsedAt failure returns error", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + value, secretHash := validPATValue(t, prefix) + repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(models.PAT{ + ID: "pat-1", + ExpiresAt: time.Now().Add(time.Hour), + }, nil) + dbErr := errors.New("connection refused") + repo.EXPECT().UpdateLastUsedAt(mock.Anything, "pat-1", mock.AnythingOfType("time.Time")).Return(dbErr) + + _, err := v.Validate(context.Background(), value) + assert.ErrorIs(t, err, dbErr) + }) + + t.Run("valid PAT returns PAT and updates last_used_at", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + value, secretHash := validPATValue(t, prefix) + expectedPAT := models.PAT{ + ID: "pat-1", + UserID: "user-1", + OrgID: "org-1", + Title: "my-pat", + ExpiresAt: time.Now().Add(time.Hour), + } + repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(expectedPAT, nil) + repo.EXPECT().UpdateLastUsedAt(mock.Anything, "pat-1", mock.AnythingOfType("time.Time")).Return(nil) + + pat, err := v.Validate(context.Background(), value) + require.NoError(t, err) + assert.Equal(t, expectedPAT.ID, pat.ID) + assert.Equal(t, expectedPAT.UserID, pat.UserID) + assert.Equal(t, expectedPAT.OrgID, pat.OrgID) + assert.Equal(t, expectedPAT.Title, pat.Title) + }) +} diff --git a/internal/api/v1beta1connect/authenticate.go b/internal/api/v1beta1connect/authenticate.go index 32dbd3cce..bb4aa4818 100644 --- a/internal/api/v1beta1connect/authenticate.go +++ b/internal/api/v1beta1connect/authenticate.go @@ -14,6 +14,7 @@ import ( "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/user" + patErrors "github.com/raystack/frontier/core/userpat/errors" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/server/consts" sessionutils "github.com/raystack/frontier/pkg/session" @@ -302,6 +303,11 @@ func (h *ConnectHandler) GetLoggedInPrincipal(ctx context.Context, via ...authen return principal, connect.NewError(connect.CodeNotFound, ErrUserNotExist) case errors.Is(err, errors.ErrUnauthenticated): return principal, connect.NewError(connect.CodeUnauthenticated, ErrUnauthenticated) + case errors.Is(err, patErrors.ErrMalformedPAT), + errors.Is(err, patErrors.ErrNotFound), + errors.Is(err, patErrors.ErrExpired), + errors.Is(err, patErrors.ErrDisabled): + return principal, connect.NewError(connect.CodeUnauthenticated, ErrUnauthenticated) default: return principal, connect.NewError(connect.CodeInternal, err) } diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index 235477aa3..3738f39ae 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -47,6 +47,7 @@ import ( "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/core/userpat" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/core/webhook" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/metadata" @@ -400,5 +401,5 @@ type AuditRecordService interface { type UserPATService interface { ValidateExpiry(expiresAt time.Time) error - Create(ctx context.Context, req userpat.CreateRequest) (userpat.PAT, string, error) + Create(ctx context.Context, req userpat.CreateRequest) (models.PAT, string, error) } diff --git a/internal/api/v1beta1connect/mocks/user_pat_service.go b/internal/api/v1beta1connect/mocks/user_pat_service.go index ecbdb4362..16e89e76e 100644 --- a/internal/api/v1beta1connect/mocks/user_pat_service.go +++ b/internal/api/v1beta1connect/mocks/user_pat_service.go @@ -4,10 +4,12 @@ package mocks import ( context "context" - time "time" + models "github.com/raystack/frontier/core/userpat/models" mock "github.com/stretchr/testify/mock" + time "time" + userpat "github.com/raystack/frontier/core/userpat" ) @@ -25,23 +27,23 @@ func (_m *UserPATService) EXPECT() *UserPATService_Expecter { } // Create provides a mock function with given fields: ctx, req -func (_m *UserPATService) Create(ctx context.Context, req userpat.CreateRequest) (userpat.PAT, string, error) { +func (_m *UserPATService) Create(ctx context.Context, req userpat.CreateRequest) (models.PAT, string, error) { ret := _m.Called(ctx, req) if len(ret) == 0 { panic("no return value specified for Create") } - var r0 userpat.PAT + var r0 models.PAT var r1 string var r2 error - if rf, ok := ret.Get(0).(func(context.Context, userpat.CreateRequest) (userpat.PAT, string, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, userpat.CreateRequest) (models.PAT, string, error)); ok { return rf(ctx, req) } - if rf, ok := ret.Get(0).(func(context.Context, userpat.CreateRequest) userpat.PAT); ok { + if rf, ok := ret.Get(0).(func(context.Context, userpat.CreateRequest) models.PAT); ok { r0 = rf(ctx, req) } else { - r0 = ret.Get(0).(userpat.PAT) + r0 = ret.Get(0).(models.PAT) } if rf, ok := ret.Get(1).(func(context.Context, userpat.CreateRequest) string); ok { @@ -78,12 +80,12 @@ func (_c *UserPATService_Create_Call) Run(run func(ctx context.Context, req user return _c } -func (_c *UserPATService_Create_Call) Return(_a0 userpat.PAT, _a1 string, _a2 error) *UserPATService_Create_Call { +func (_c *UserPATService_Create_Call) Return(_a0 models.PAT, _a1 string, _a2 error) *UserPATService_Create_Call { _c.Call.Return(_a0, _a1, _a2) return _c } -func (_c *UserPATService_Create_Call) RunAndReturn(run func(context.Context, userpat.CreateRequest) (userpat.PAT, string, error)) *UserPATService_Create_Call { +func (_c *UserPATService_Create_Call) RunAndReturn(run func(context.Context, userpat.CreateRequest) (models.PAT, string, error)) *UserPATService_Create_Call { _c.Call.Return(run) return _c } diff --git a/internal/api/v1beta1connect/organization.go b/internal/api/v1beta1connect/organization.go index 6e80d9a43..30ac95f60 100644 --- a/internal/api/v1beta1connect/organization.go +++ b/internal/api/v1beta1connect/organization.go @@ -9,6 +9,7 @@ import ( "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/policy" "github.com/raystack/frontier/core/project" + "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/role" "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" @@ -144,6 +145,11 @@ func (h *ConnectHandler) CreateOrganization(ctx context.Context, request *connec return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) case errors.Is(err, organization.ErrConflict): return nil, connect.NewError(connect.CodeAlreadyExists, ErrConflictRequest) + case errors.Is(err, relation.ErrSubjectNotAllowed): + errorLogger.LogServiceError(ctx, request, "CreateOrganization.Create", err, + zap.String("org_name", request.Msg.GetBody().GetName()), + zap.String("org_title", request.Msg.GetBody().GetTitle())) + return nil, connect.NewError(connect.CodePermissionDenied, ErrUnauthorized) default: errorLogger.LogServiceError(ctx, request, "CreateOrganization.Create", err, zap.String("org_name", request.Msg.GetBody().GetName()), diff --git a/internal/api/v1beta1connect/user_pat.go b/internal/api/v1beta1connect/user_pat.go index 35a54918f..35bf6dd89 100644 --- a/internal/api/v1beta1connect/user_pat.go +++ b/internal/api/v1beta1connect/user_pat.go @@ -6,6 +6,8 @@ import ( "connectrpc.com/connect" "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/metadata" frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1" @@ -47,18 +49,18 @@ func (h *ConnectHandler) CreateCurrentUserPAT(ctx context.Context, request *conn zap.String("org_id", request.Msg.GetOrgId())) switch { - case errors.Is(err, userpat.ErrDisabled): + case errors.Is(err, paterrors.ErrDisabled): return nil, connect.NewError(connect.CodeFailedPrecondition, err) - case errors.Is(err, userpat.ErrConflict): + case errors.Is(err, paterrors.ErrConflict): return nil, connect.NewError(connect.CodeAlreadyExists, err) - case errors.Is(err, userpat.ErrLimitExceeded): + case errors.Is(err, paterrors.ErrLimitExceeded): return nil, connect.NewError(connect.CodeResourceExhausted, err) - case errors.Is(err, userpat.ErrRoleNotFound): - return nil, connect.NewError(connect.CodeInvalidArgument, userpat.ErrRoleNotFound) - case errors.Is(err, userpat.ErrDeniedRole): - return nil, connect.NewError(connect.CodeInvalidArgument, userpat.ErrDeniedRole) - case errors.Is(err, userpat.ErrUnsupportedScope): - return nil, connect.NewError(connect.CodeInvalidArgument, userpat.ErrUnsupportedScope) + case errors.Is(err, paterrors.ErrRoleNotFound): + return nil, connect.NewError(connect.CodeInvalidArgument, paterrors.ErrRoleNotFound) + case errors.Is(err, paterrors.ErrDeniedRole): + return nil, connect.NewError(connect.CodeInvalidArgument, paterrors.ErrDeniedRole) + case errors.Is(err, paterrors.ErrUnsupportedScope): + return nil, connect.NewError(connect.CodeInvalidArgument, paterrors.ErrUnsupportedScope) default: return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) } @@ -69,7 +71,7 @@ func (h *ConnectHandler) CreateCurrentUserPAT(ctx context.Context, request *conn }), nil } -func transformPATToPB(pat userpat.PAT, patValue string) *frontierv1beta1.PAT { +func transformPATToPB(pat models.PAT, patValue string) *frontierv1beta1.PAT { pbPAT := &frontierv1beta1.PAT{ Id: pat.ID, Title: pat.Title, diff --git a/internal/api/v1beta1connect/user_pat_test.go b/internal/api/v1beta1connect/user_pat_test.go index 2ccd9191c..f42e028e8 100644 --- a/internal/api/v1beta1connect/user_pat_test.go +++ b/internal/api/v1beta1connect/user_pat_test.go @@ -10,6 +10,8 @@ import ( "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/api/v1beta1connect/mocks" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/errors" @@ -74,7 +76,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { Type: schema.UserPrincipal, User: &user.User{ID: testUserID}, }, nil) - ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(userpat.ErrExpiryInPast) + ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(paterrors.ErrExpiryInPast) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -83,7 +85,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrExpiryInPast), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrExpiryInPast), }, { name: "should return invalid argument when expiry exceeds max lifetime", @@ -93,7 +95,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { Type: schema.UserPrincipal, User: &user.User{ID: testUserID}, }, nil) - ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(userpat.ErrExpiryExceeded) + ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(paterrors.ErrExpiryExceeded) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -102,7 +104,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(time.Now().Add(48 * time.Hour)), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrExpiryExceeded), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrExpiryExceeded), }, { name: "should return failed precondition when PAT is disabled", @@ -114,7 +116,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", userpat.ErrDisabled) + Return(models.PAT{}, "", paterrors.ErrDisabled) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -123,7 +125,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeFailedPrecondition, userpat.ErrDisabled), + wantErr: connect.NewError(connect.CodeFailedPrecondition, paterrors.ErrDisabled), }, { name: "should return already exists when title conflicts", @@ -135,7 +137,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", userpat.ErrConflict) + Return(models.PAT{}, "", paterrors.ErrConflict) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -144,7 +146,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeAlreadyExists, userpat.ErrConflict), + wantErr: connect.NewError(connect.CodeAlreadyExists, paterrors.ErrConflict), }, { name: "should return resource exhausted when limit exceeded", @@ -156,7 +158,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", userpat.ErrLimitExceeded) + Return(models.PAT{}, "", paterrors.ErrLimitExceeded) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -165,7 +167,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeResourceExhausted, userpat.ErrLimitExceeded), + wantErr: connect.NewError(connect.CodeResourceExhausted, paterrors.ErrLimitExceeded), }, { name: "should return invalid argument when role is not found", @@ -177,7 +179,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", fmt.Errorf("fetching roles: %w", userpat.ErrRoleNotFound)) + Return(models.PAT{}, "", fmt.Errorf("fetching roles: %w", paterrors.ErrRoleNotFound)) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -186,7 +188,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrRoleNotFound), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrRoleNotFound), }, { name: "should return invalid argument when role is denied", @@ -198,7 +200,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", fmt.Errorf("creating policies: %w", userpat.ErrDeniedRole)) + Return(models.PAT{}, "", fmt.Errorf("creating policies: %w", paterrors.ErrDeniedRole)) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -207,7 +209,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrDeniedRole), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrDeniedRole), }, { name: "should return invalid argument when role scope is unsupported", @@ -219,7 +221,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", fmt.Errorf("creating policies: %w", userpat.ErrUnsupportedScope)) + Return(models.PAT{}, "", fmt.Errorf("creating policies: %w", paterrors.ErrUnsupportedScope)) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -228,7 +230,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrUnsupportedScope), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrUnsupportedScope), }, { name: "should return internal error for unknown service failure", @@ -240,7 +242,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", errors.New("unexpected error")) + Return(models.PAT{}, "", errors.New("unexpected error")) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -265,7 +267,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { req.OrgID == testOrgID && req.Title == "my-token" && len(req.RoleIDs) == 1 && req.RoleIDs[0] == testRoleID - })).Return(userpat.PAT{ + })).Return(models.PAT{ ID: "pat-1", UserID: testUserID, OrgID: testOrgID, @@ -305,7 +307,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{ + Return(models.PAT{ ID: "pat-1", UserID: testUserID, OrgID: testOrgID, @@ -390,13 +392,13 @@ func TestTransformPATToPB(t *testing.T) { tests := []struct { name string - pat userpat.PAT + pat models.PAT patValue string want *frontierv1beta1.PAT }{ { name: "should transform minimal PAT", - pat: userpat.PAT{ + pat: models.PAT{ ID: "pat-1", UserID: "user-1", OrgID: "org-1", @@ -418,7 +420,7 @@ func TestTransformPATToPB(t *testing.T) { }, { name: "should include token value when provided", - pat: userpat.PAT{ + pat: models.PAT{ ID: "pat-1", UserID: "user-1", OrgID: "org-1", @@ -441,7 +443,7 @@ func TestTransformPATToPB(t *testing.T) { }, { name: "should include last_used_at when set", - pat: userpat.PAT{ + pat: models.PAT{ ID: "pat-1", UserID: "user-1", OrgID: "org-1", @@ -465,7 +467,7 @@ func TestTransformPATToPB(t *testing.T) { }, { name: "should include metadata when set", - pat: userpat.PAT{ + pat: models.PAT{ ID: "pat-1", UserID: "user-1", OrgID: "org-1", diff --git a/internal/store/postgres/userpat.go b/internal/store/postgres/userpat.go index 67d2c03e8..c629d93d1 100644 --- a/internal/store/postgres/userpat.go +++ b/internal/store/postgres/userpat.go @@ -4,7 +4,7 @@ import ( "encoding/json" "time" - "github.com/raystack/frontier/core/userpat" + "github.com/raystack/frontier/core/userpat/models" ) type UserPAT struct { @@ -21,14 +21,14 @@ type UserPAT struct { DeletedAt *time.Time `db:"deleted_at"` } -func (t UserPAT) transform() (userpat.PAT, error) { +func (t UserPAT) transform() (models.PAT, error) { var unmarshalledMetadata map[string]any if len(t.Metadata) > 0 { if err := json.Unmarshal(t.Metadata, &unmarshalledMetadata); err != nil { - return userpat.PAT{}, err + return models.PAT{}, err } } - return userpat.PAT{ + return models.PAT{ ID: t.ID, UserID: t.UserID, OrgID: t.OrgID, diff --git a/internal/store/postgres/userpat_repository.go b/internal/store/postgres/userpat_repository.go index 63bf49c51..d53f9bf6c 100644 --- a/internal/store/postgres/userpat_repository.go +++ b/internal/store/postgres/userpat_repository.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -11,7 +12,8 @@ import ( "github.com/doug-martin/goqu/v9" "github.com/google/uuid" - "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/pkg/db" ) @@ -25,14 +27,14 @@ func NewUserPATRepository(dbc *db.Client) *UserPATRepository { } } -func (r UserPATRepository) Create(ctx context.Context, pat userpat.PAT) (userpat.PAT, error) { +func (r UserPATRepository) Create(ctx context.Context, pat models.PAT) (models.PAT, error) { if strings.TrimSpace(pat.ID) == "" { pat.ID = uuid.New().String() } marshaledMetadata, err := json.Marshal(pat.Metadata) if err != nil { - return userpat.PAT{}, fmt.Errorf("%w: %w", parseErr, err) + return models.PAT{}, fmt.Errorf("%w: %w", parseErr, err) } var model UserPAT @@ -47,7 +49,7 @@ func (r UserPATRepository) Create(ctx context.Context, pat userpat.PAT) (userpat "expires_at": pat.ExpiresAt, }).Returning(&UserPAT{}).ToSQL() if err != nil { - return userpat.PAT{}, fmt.Errorf("%w: %w", queryErr, err) + return models.PAT{}, fmt.Errorf("%w: %w", queryErr, err) } if err = r.dbc.WithTimeout(ctx, TABLE_USER_PATS, "Create", func(ctx context.Context) error { @@ -55,9 +57,9 @@ func (r UserPATRepository) Create(ctx context.Context, pat userpat.PAT) (userpat }); err != nil { err = checkPostgresError(err) if errors.Is(err, ErrDuplicateKey) { - return userpat.PAT{}, userpat.ErrConflict + return models.PAT{}, paterrors.ErrConflict } - return userpat.PAT{}, fmt.Errorf("%w: %w", dbErr, err) + return models.PAT{}, fmt.Errorf("%w: %w", dbErr, err) } return model.transform() @@ -84,3 +86,47 @@ func (r UserPATRepository) CountActive(ctx context.Context, userID, orgID string return count, nil } + +func (r UserPATRepository) GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) { + query, params, err := dialect.From(TABLE_USER_PATS). + Select(&UserPAT{}). + Where( + goqu.Ex{"secret_hash": secretHash}, + goqu.Ex{"deleted_at": nil}, + ).Limit(1).ToSQL() + if err != nil { + return models.PAT{}, fmt.Errorf("%w: %w", queryErr, err) + } + + var model UserPAT + if err = r.dbc.WithTimeout(ctx, TABLE_USER_PATS, "GetBySecretHash", func(ctx context.Context) error { + return r.dbc.GetContext(ctx, &model, query, params...) + }); err != nil { + err = checkPostgresError(err) + if errors.Is(err, sql.ErrNoRows) { + return models.PAT{}, paterrors.ErrNotFound + } + return models.PAT{}, fmt.Errorf("%w: %w", dbErr, err) + } + + return model.transform() +} + +func (r UserPATRepository) UpdateLastUsedAt(ctx context.Context, id string, at time.Time) error { + query, params, err := dialect.Update(TABLE_USER_PATS). + Set(goqu.Record{"last_used_at": at}). + Where(goqu.Ex{"id": id}). + ToSQL() + if err != nil { + return fmt.Errorf("%w: %w", queryErr, err) + } + + if err = r.dbc.WithTimeout(ctx, TABLE_USER_PATS, "UpdateLastUsedAt", func(ctx context.Context) error { + _, err := r.dbc.ExecContext(ctx, query, params...) + return err + }); err != nil { + return fmt.Errorf("%w: %w", dbErr, err) + } + + return nil +} diff --git a/internal/store/postgres/userpat_repository_test.go b/internal/store/postgres/userpat_repository_test.go index d01bb7d65..a124ab61c 100644 --- a/internal/store/postgres/userpat_repository_test.go +++ b/internal/store/postgres/userpat_repository_test.go @@ -10,7 +10,8 @@ import ( "github.com/ory/dockertest" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/user" - "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/store/postgres" "github.com/raystack/frontier/pkg/db" "github.com/raystack/salt/log" @@ -74,7 +75,7 @@ func (s *UserPATRepositoryTestSuite) cleanup() error { func (s *UserPATRepositoryTestSuite) TestCreate() { s.Run("should create a token and return it with generated ID", func() { - pat := userpat.PAT{ + pat := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "test-token", @@ -96,7 +97,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { s.Run("should use provided ID if set", func() { customID := uuid.New().String() - pat := userpat.PAT{ + pat := models.PAT{ ID: customID, UserID: s.users[0].ID, OrgID: s.orgs[0].ID, @@ -111,7 +112,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { }) s.Run("should store and return metadata", func() { - pat := userpat.PAT{ + pat := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "token-with-meta", @@ -127,7 +128,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { }) s.Run("should return ErrConflict for duplicate title per user per org", func() { - pat := userpat.PAT{ + pat := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "duplicate-title", @@ -141,11 +142,11 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { pat.ID = "" pat.SecretHash = "hashB" _, err = s.repository.Create(s.ctx, pat) - s.ErrorIs(err, userpat.ErrConflict) + s.ErrorIs(err, paterrors.ErrConflict) }) s.Run("should return ErrConflict for duplicate secret hash", func() { - pat1 := userpat.PAT{ + pat1 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "token-unique-hash-1", @@ -156,7 +157,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { _, err := s.repository.Create(s.ctx, pat1) s.Require().NoError(err) - pat2 := userpat.PAT{ + pat2 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "token-unique-hash-2", @@ -164,11 +165,11 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { ExpiresAt: time.Now().Add(24 * time.Hour), } _, err = s.repository.Create(s.ctx, pat2) - s.ErrorIs(err, userpat.ErrConflict) + s.ErrorIs(err, paterrors.ErrConflict) }) s.Run("should allow same title for different users in same org", func() { - pat1 := userpat.PAT{ + pat1 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "shared-title", @@ -178,7 +179,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { _, err := s.repository.Create(s.ctx, pat1) s.Require().NoError(err) - pat2 := userpat.PAT{ + pat2 := models.PAT{ UserID: s.users[1].ID, OrgID: s.orgs[0].ID, Title: "shared-title", @@ -190,7 +191,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { }) s.Run("should allow same title for same user in different orgs", func() { - pat1 := userpat.PAT{ + pat1 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "cross-org-title", @@ -200,7 +201,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { _, err := s.repository.Create(s.ctx, pat1) s.Require().NoError(err) - pat2 := userpat.PAT{ + pat2 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[1].ID, Title: "cross-org-title", @@ -229,7 +230,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_ExcludesExpired() { s.truncateTokens() // create an active token - _, err := s.repository.Create(s.ctx, userpat.PAT{ + _, err := s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "active-token", @@ -239,7 +240,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_ExcludesExpired() { s.Require().NoError(err) // create an expired token - _, err = s.repository.Create(s.ctx, userpat.PAT{ + _, err = s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "expired-token", @@ -257,7 +258,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_FiltersByUserAndOrg() { s.truncateTokens() // token for user[0] in org[0] - _, err := s.repository.Create(s.ctx, userpat.PAT{ + _, err := s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "user0-org0", @@ -267,7 +268,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_FiltersByUserAndOrg() { s.Require().NoError(err) // token for user[1] in org[0] - _, err = s.repository.Create(s.ctx, userpat.PAT{ + _, err = s.repository.Create(s.ctx, models.PAT{ UserID: s.users[1].ID, OrgID: s.orgs[0].ID, Title: "user1-org0", @@ -277,7 +278,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_FiltersByUserAndOrg() { s.Require().NoError(err) // token for user[0] in org[1] - _, err = s.repository.Create(s.ctx, userpat.PAT{ + _, err = s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[1].ID, Title: "user0-org1", @@ -295,7 +296,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_MultipleTokens() { s.truncateTokens() for i := 0; i < 3; i++ { - _, err := s.repository.Create(s.ctx, userpat.PAT{ + _, err := s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: fmt.Sprintf("multi-token-%d", i), diff --git a/pkg/server/connect_interceptors/session.go b/pkg/server/connect_interceptors/session.go index 4a44f7378..4c782f7c0 100644 --- a/pkg/server/connect_interceptors/session.go +++ b/pkg/server/connect_interceptors/session.go @@ -8,6 +8,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" "github.com/raystack/frontier/core/authenticate" + "github.com/raystack/frontier/core/userpat" "github.com/raystack/frontier/internal/api/v1beta1connect" "github.com/raystack/frontier/pkg/server/consts" @@ -22,6 +23,7 @@ type SessionInterceptor struct { // use secure cookie EncodeMulti/DecodeMulti cookieCodec securecookie.Codec conf authenticate.SessionConfig + patConf userpat.Config h *v1beta1connect.ConnectHandler } @@ -66,6 +68,8 @@ func (s *SessionInterceptor) WrapStreamingHandler(next connect.StreamingHandlerF if token.JwtID() != "" && token.Expiration().After(time.Now().UTC()) { incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } + } else if s.patConf.Prefix != "" && strings.HasPrefix(tokenVal, s.patConf.Prefix+"_") { + incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } secretVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Basic ")) if len(secretVal) > 0 { @@ -112,6 +116,8 @@ func (s *SessionInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc if token.JwtID() != "" && token.Expiration().After(time.Now().UTC()) { incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } + } else if s.patConf.Prefix != "" && strings.HasPrefix(tokenVal, s.patConf.Prefix+"_") { + incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } secretVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Basic ")) if len(secretVal) > 0 { @@ -197,12 +203,13 @@ func (s *SessionInterceptor) UnaryConnectResponseInterceptor() connect.UnaryInte return connect.UnaryInterceptorFunc(interceptor) } -func NewSessionInterceptor(cookieCutter securecookie.Codec, conf authenticate.SessionConfig, h *v1beta1connect.ConnectHandler) *SessionInterceptor { +func NewSessionInterceptor(cookieCutter securecookie.Codec, conf authenticate.SessionConfig, h *v1beta1connect.ConnectHandler, patConf userpat.Config) *SessionInterceptor { return &SessionInterceptor{ // could be nil if not configured by user cookieCodec: cookieCutter, conf: conf, h: h, + patConf: patConf, } } @@ -256,6 +263,8 @@ func (s *SessionInterceptor) UnaryConnectRequestHeadersAnnotator() connect.Unary if token.JwtID() != "" && token.Expiration().After(time.Now().UTC()) { incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } + } else if s.patConf.Prefix != "" && strings.HasPrefix(tokenVal, s.patConf.Prefix+"_") { + incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } secretVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Basic ")) if len(secretVal) > 0 { diff --git a/pkg/server/server.go b/pkg/server/server.go index bb9a766f0..802f1ee6f 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -146,7 +146,7 @@ func ServeConnect(ctx context.Context, logger log.Logger, cfg Config, deps api.D authNInterceptor := connectinterceptors.NewAuthenticationInterceptor(frontierService, cfg.Authentication.Session.Headers) authZInterceptor := connectinterceptors.NewAuthorizationInterceptor(frontierService) - sessionInterceptor := connectinterceptors.NewSessionInterceptor(sessionCookieCutter, cfg.Authentication.Session, frontierService) + sessionInterceptor := connectinterceptors.NewSessionInterceptor(sessionCookieCutter, cfg.Authentication.Session, frontierService, cfg.PAT) auditInterceptor := connectinterceptors.NewAuditInterceptor(deps.AuditService) interceptors := connect.WithInterceptors(