From 016e8963802fb56f17e105a5a6b4b2a699eaf1aa Mon Sep 17 00:00:00 2001 From: Abhishek Sah Date: Wed, 22 Apr 2026 11:12:57 +0530 Subject: [PATCH 1/5] refactor: move project member mutations into membership package Move SetProjectMemberRole and RemoveProjectMember from core/project/ into core/membership/ for consistency with the centralized membership pattern. Handlers now call membershipService instead of projectService for project member mutations. Key changes: - Add validateProjectRole with org-scoping (rejects cross-org custom roles) - Add validateOrgMembership with disabled user rejection - Add ServiceuserService dependency for org membership validation - Add audit records at service layer for project member mutations - Use principal-agnostic error messages and audit target types - Reuse principalTypeToAuditType from RemoveOrganizationMember Ref: https://github.com/raystack/frontier/issues/1478 Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/serve.go | 2 +- core/membership/errors.go | 2 + core/membership/mocks/group_service.go | 53 +++++ core/membership/mocks/project_service.go | 53 +++++ core/membership/mocks/serviceuser_service.go | 95 ++++++++ core/membership/service.go | 221 ++++++++++++++++++ core/membership/service_test.go | 174 +++++++++++++- internal/api/v1beta1connect/errors.go | 4 +- internal/api/v1beta1connect/interfaces.go | 4 +- .../mocks/membership_service.go | 30 +++ internal/api/v1beta1connect/project.go | 19 +- pkg/auditrecord/consts.go | 4 + 12 files changed, 643 insertions(+), 18 deletions(-) create mode 100644 core/membership/mocks/serviceuser_service.go diff --git a/cmd/serve.go b/cmd/serve.go index 342eed92b..bcfc1dde6 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -434,7 +434,7 @@ func buildAPIDependencies( projectService := project.NewService(projectRepository, relationService, userService, policyService, authnService, serviceUserService, groupService, roleService) - membershipService := membership.NewService(logger, policyService, relationService, roleService, organizationService, userService, projectService, groupService, auditRecordRepository) + membershipService := membership.NewService(logger, policyService, relationService, roleService, organizationService, userService, projectService, groupService, serviceUserService, auditRecordRepository) // Setter injection: org → membership is circular (membership needs org for validation, // org needs membership for Create/AdminCreate). Break the cycle with a post-init setter. organizationService.SetMembershipService(membershipService) diff --git a/core/membership/errors.go b/core/membership/errors.go index 01eb2e235..0e01732ea 100644 --- a/core/membership/errors.go +++ b/core/membership/errors.go @@ -9,4 +9,6 @@ var ( ErrLastOwnerRole = errors.New("cannot change role: this is the last owner of the organization") ErrInvalidPrincipal = errors.New("only user principals are supported") ErrInvalidPrincipalType = errors.New("unsupported principal type") + ErrNotOrgMember = errors.New("principal is not a member of the organization") + ErrInvalidProjectRole = errors.New("role is not valid for project scope") ) diff --git a/core/membership/mocks/group_service.go b/core/membership/mocks/group_service.go index 844ff8c8e..287a6ad03 100644 --- a/core/membership/mocks/group_service.go +++ b/core/membership/mocks/group_service.go @@ -82,6 +82,59 @@ func (_c *GroupService_List_Call) RunAndReturn(run func(context.Context, group.F return _c } +// Get provides a mock function with given fields: ctx, idOrName +func (_m *GroupService) Get(ctx context.Context, idOrName string) (group.Group, error) { + ret := _m.Called(ctx, idOrName) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 group.Group + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (group.Group, error)); ok { + return rf(ctx, idOrName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) group.Group); ok { + r0 = rf(ctx, idOrName) + } else { + r0 = ret.Get(0).(group.Group) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, idOrName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type GroupService_Get_Call struct { + *mock.Call +} + +func (_e *GroupService_Expecter) Get(ctx interface{}, idOrName interface{}) *GroupService_Get_Call { + return &GroupService_Get_Call{Call: _e.mock.On("Get", ctx, idOrName)} +} + +func (_c *GroupService_Get_Call) Run(run func(ctx context.Context, idOrName string)) *GroupService_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *GroupService_Get_Call) Return(_a0 group.Group, _a1 error) *GroupService_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *GroupService_Get_Call) RunAndReturn(run func(context.Context, string) (group.Group, error)) *GroupService_Get_Call { + _c.Call.Return(run) + return _c +} + // NewGroupService creates a new instance of GroupService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewGroupService(t interface { diff --git a/core/membership/mocks/project_service.go b/core/membership/mocks/project_service.go index 59dd48b65..f53ac32d9 100644 --- a/core/membership/mocks/project_service.go +++ b/core/membership/mocks/project_service.go @@ -82,6 +82,59 @@ func (_c *ProjectService_List_Call) RunAndReturn(run func(context.Context, proje return _c } +// Get provides a mock function with given fields: ctx, idOrName +func (_m *ProjectService) Get(ctx context.Context, idOrName string) (project.Project, error) { + ret := _m.Called(ctx, idOrName) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 project.Project + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (project.Project, error)); ok { + return rf(ctx, idOrName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) project.Project); ok { + r0 = rf(ctx, idOrName) + } else { + r0 = ret.Get(0).(project.Project) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, idOrName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ProjectService_Get_Call struct { + *mock.Call +} + +func (_e *ProjectService_Expecter) Get(ctx interface{}, idOrName interface{}) *ProjectService_Get_Call { + return &ProjectService_Get_Call{Call: _e.mock.On("Get", ctx, idOrName)} +} + +func (_c *ProjectService_Get_Call) Run(run func(ctx context.Context, idOrName string)) *ProjectService_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *ProjectService_Get_Call) Return(_a0 project.Project, _a1 error) *ProjectService_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ProjectService_Get_Call) RunAndReturn(run func(context.Context, string) (project.Project, error)) *ProjectService_Get_Call { + _c.Call.Return(run) + return _c +} + // NewProjectService creates a new instance of ProjectService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewProjectService(t interface { diff --git a/core/membership/mocks/serviceuser_service.go b/core/membership/mocks/serviceuser_service.go new file mode 100644 index 000000000..230b7ee98 --- /dev/null +++ b/core/membership/mocks/serviceuser_service.go @@ -0,0 +1,95 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + serviceuser "github.com/raystack/frontier/core/serviceuser" +) + +// ServiceuserService is an autogenerated mock type for the ServiceuserService type +type ServiceuserService struct { + mock.Mock +} + +type ServiceuserService_Expecter struct { + mock *mock.Mock +} + +func (_m *ServiceuserService) EXPECT() *ServiceuserService_Expecter { + return &ServiceuserService_Expecter{mock: &_m.Mock} +} + +// Get provides a mock function with given fields: ctx, id +func (_m *ServiceuserService) Get(ctx context.Context, id string) (serviceuser.ServiceUser, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 serviceuser.ServiceUser + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (serviceuser.ServiceUser, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) serviceuser.ServiceUser); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(serviceuser.ServiceUser) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ServiceuserService_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type ServiceuserService_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *ServiceuserService_Expecter) Get(ctx interface{}, id interface{}) *ServiceuserService_Get_Call { + return &ServiceuserService_Get_Call{Call: _e.mock.On("Get", ctx, id)} +} + +func (_c *ServiceuserService_Get_Call) Run(run func(ctx context.Context, id string)) *ServiceuserService_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *ServiceuserService_Get_Call) Return(_a0 serviceuser.ServiceUser, _a1 error) *ServiceuserService_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ServiceuserService_Get_Call) RunAndReturn(run func(context.Context, string) (serviceuser.ServiceUser, error)) *ServiceuserService_Get_Call { + _c.Call.Return(run) + return _c +} + +// NewServiceuserService creates a new instance of ServiceuserService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewServiceuserService(t interface { + mock.TestingT + Cleanup(func()) +}) *ServiceuserService { + mock := &ServiceuserService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/membership/service.go b/core/membership/service.go index 5aa358709..e5a77eeb5 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -15,6 +15,7 @@ import ( "github.com/raystack/frontier/core/project" "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/role" + "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/internal/bootstrap/schema" pkgAuditRecord "github.com/raystack/frontier/pkg/auditrecord" @@ -46,13 +47,19 @@ type UserService interface { } type ProjectService interface { + Get(ctx context.Context, idOrName string) (project.Project, error) List(ctx context.Context, flt project.Filter) ([]project.Project, error) } type GroupService interface { + Get(ctx context.Context, idOrName string) (group.Group, error) List(ctx context.Context, flt group.Filter) ([]group.Group, error) } +type ServiceuserService interface { + Get(ctx context.Context, id string) (serviceuser.ServiceUser, error) +} + type AuditRecordRepository interface { Create(ctx context.Context, auditRecord auditrecord.AuditRecord) (auditrecord.AuditRecord, error) } @@ -66,6 +73,7 @@ type Service struct { userService UserService projectService ProjectService groupService GroupService + serviceuserService ServiceuserService auditRecordRepository AuditRecordRepository } @@ -78,6 +86,7 @@ func NewService( userService UserService, projectService ProjectService, groupService GroupService, + serviceuserService ServiceuserService, auditRecordRepository AuditRecordRepository, ) *Service { return &Service{ @@ -89,6 +98,7 @@ func NewService( userService: userService, projectService: projectService, groupService: groupService, + serviceuserService: serviceuserService, auditRecordRepository: auditRecordRepository, } } @@ -653,3 +663,214 @@ func principalTypeToAuditType(principalType string) (pkgAuditRecord.EntityType, return "", ErrInvalidPrincipalType } } + +// SetProjectMemberRole sets or changes a principal's role in a project (upsert). +// It validates the role is project-scoped and the principal is a member of the parent org. +// No explicit SpiceDB relations are managed — projects use policies only. +func (s *Service) SetProjectMemberRole(ctx context.Context, projectID, principalID, principalType, roleID string) error { + prj, err := s.projectService.Get(ctx, projectID) + if err != nil { + return err + } + + fetchedRole, err := s.validateProjectRole(ctx, roleID, prj.Organization.ID) + if err != nil { + return err + } + resolvedRoleID := fetchedRole.ID + + if err := s.validateOrgMembership(ctx, prj.Organization.ID, principalID, principalType); err != nil { + return err + } + + existing, err := s.policyService.List(ctx, policy.Filter{ + ProjectID: projectID, + PrincipalID: principalID, + PrincipalType: principalType, + }) + if err != nil { + return fmt.Errorf("list existing policies: %w", err) + } + + // skip if the principal already has exactly this role + if len(existing) == 1 && existing[0].RoleID == resolvedRoleID { + return nil + } + + if err := s.replacePolicy(ctx, projectID, schema.ProjectNamespace, principalID, principalType, resolvedRoleID, existing); err != nil { + return err + } + + s.auditProjectMemberRoleChanged(ctx, prj, principalID, principalType, resolvedRoleID) + return nil +} + +// RemoveProjectMember removes a principal from a project by deleting all their project-level policies. +func (s *Service) RemoveProjectMember(ctx context.Context, projectID, principalID, principalType string) error { + switch principalType { + case schema.UserPrincipal, schema.ServiceUserPrincipal, schema.GroupPrincipal: + default: + return ErrInvalidPrincipalType + } + + removed, err := s.removeAllPolicies(ctx, projectID, schema.ProjectNamespace, principalID, principalType) + if err != nil { + return err + } + if removed == 0 { + return ErrNotMember + } + + // best-effort audit — fetch project for context, skip if it fails + if prj, err := s.projectService.Get(ctx, projectID); err == nil { + s.auditProjectMemberRemoved(ctx, prj, principalID, principalType) + } + return nil +} + +// removeAllPolicies finds and deletes all policies for a principal on a resource. +// Returns the number of policies deleted. +func (s *Service) removeAllPolicies(ctx context.Context, resourceID, resourceType, principalID, principalType string) (int, error) { + f := policyFilterForResource(resourceID, resourceType, principalID, principalType) + existing, err := s.policyService.List(ctx, f) + if err != nil { + return 0, fmt.Errorf("list policies: %w", err) + } + for _, pol := range existing { + if err := s.policyService.Delete(ctx, pol.ID); err != nil { + return 0, fmt.Errorf("delete policy %s: %w", pol.ID, err) + } + } + return len(existing), nil +} + +// policyFilterForResource builds a policy.Filter with the correct resource-type field set. +func policyFilterForResource(resourceID, resourceType, principalID, principalType string) policy.Filter { + f := policy.Filter{ + PrincipalID: principalID, + PrincipalType: principalType, + } + switch resourceType { + case schema.OrganizationNamespace: + f.OrgID = resourceID + case schema.ProjectNamespace: + f.ProjectID = resourceID + case schema.GroupNamespace: + f.GroupID = resourceID + } + return f +} + +// validateProjectRole checks that the role is valid for project scope: +// - a platform-wide role scoped to projects, or +// - a custom role created for the project's parent organization. +func (s *Service) validateProjectRole(ctx context.Context, roleID, orgID string) (role.Role, error) { + fetchedRole, err := s.roleService.Get(ctx, roleID) + if err != nil { + return role.Role{}, err + } + if !slices.Contains(fetchedRole.Scopes, schema.ProjectNamespace) { + return role.Role{}, ErrInvalidProjectRole + } + + // custom role belonging to the project's parent org + if fetchedRole.OrgID == orgID { + return fetchedRole, nil + } + + // platform-wide role (no org ownership) + if utils.IsNullUUID(fetchedRole.OrgID) { + return fetchedRole, nil + } + + return role.Role{}, ErrInvalidProjectRole +} + +// validateOrgMembership checks that the principal exists and belongs to the given org. +// For users, org membership is verified via org-level policies. +// For service users and groups, org membership is verified via their org ID field. +func (s *Service) validateOrgMembership(ctx context.Context, orgID, principalID, principalType string) error { + switch principalType { + case schema.UserPrincipal: + usr, err := s.userService.GetByID(ctx, principalID) + if err != nil { + return err + } + if usr.State == user.Disabled { + return user.ErrDisabled + } + orgPolicies, err := s.policyService.List(ctx, policy.Filter{ + OrgID: orgID, + PrincipalID: principalID, + PrincipalType: principalType, + }) + if err != nil { + return err + } + if len(orgPolicies) == 0 { + return ErrNotOrgMember + } + case schema.ServiceUserPrincipal: + su, err := s.serviceuserService.Get(ctx, principalID) + if err != nil { + return err + } + if su.OrgID != orgID { + return ErrNotOrgMember + } + case schema.GroupPrincipal: + grp, err := s.groupService.Get(ctx, principalID) + if err != nil { + return err + } + if grp.OrganizationID != orgID { + return ErrNotOrgMember + } + default: + return ErrInvalidPrincipalType + } + return nil +} + +func (s *Service) auditProjectMemberRoleChanged(ctx context.Context, prj project.Project, principalID, principalType, roleID string) { + targetType, _ := principalTypeToAuditType(principalType) + s.auditRecordRepository.Create(ctx, auditrecord.AuditRecord{ + Event: pkgAuditRecord.ProjectMemberRoleChangedEvent, + Resource: auditrecord.Resource{ + ID: prj.ID, + Type: pkgAuditRecord.ProjectType, + Name: prj.Title, + }, + Target: &auditrecord.Target{ + ID: principalID, + Type: targetType, + Metadata: map[string]any{ + "principal_type": principalType, + "role_id": roleID, + }, + }, + OrgID: prj.Organization.ID, + OccurredAt: time.Now(), + }) +} + +func (s *Service) auditProjectMemberRemoved(ctx context.Context, prj project.Project, principalID, principalType string) { + targetType, _ := principalTypeToAuditType(principalType) + s.auditRecordRepository.Create(ctx, auditrecord.AuditRecord{ + Event: pkgAuditRecord.ProjectMemberRemovedEvent, + Resource: auditrecord.Resource{ + ID: prj.ID, + Type: pkgAuditRecord.ProjectType, + Name: prj.Title, + }, + Target: &auditrecord.Target{ + ID: principalID, + Type: targetType, + Metadata: map[string]any{ + "principal_type": principalType, + }, + }, + OrgID: prj.Organization.ID, + OccurredAt: time.Now(), + }) +} diff --git a/core/membership/service_test.go b/core/membership/service_test.go index 2d7db615e..de81d6fbd 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -15,6 +15,7 @@ import ( "github.com/raystack/frontier/core/project" "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/role" + "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/salt/log" @@ -256,7 +257,7 @@ func TestService_AddOrganizationMember(t *testing.T) { tt.setup(mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mockUserSvc, mockAuditRepo) } - svc := membership.NewService(log.NewNoop(), mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mockUserSvc, mocks.NewProjectService(t), mocks.NewGroupService(t), mockAuditRepo) + svc := membership.NewService(log.NewNoop(), mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mockUserSvc, mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) principalType := tt.principalType if principalType == "" { @@ -447,7 +448,7 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { tt.setup(mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mockUserSvc, mockAuditRepo) } - svc := membership.NewService(log.NewNoop(), mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mockUserSvc, mocks.NewProjectService(t), mocks.NewGroupService(t), mockAuditRepo) + svc := membership.NewService(log.NewNoop(), mockPolicySvc, mockRelSvc, mockRoleSvc, mockOrgSvc, mockUserSvc, mocks.NewProjectService(t), mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) principalType := tt.principalType if principalType == "" { @@ -671,7 +672,7 @@ func TestService_RemoveOrganizationMember(t *testing.T) { tt.setup(d) } - svc := membership.NewService(log.NewNoop(), d.policySvc, d.relSvc, d.roleSvc, d.orgSvc, mocks.NewUserService(t), d.projSvc, d.grpSvc, d.auditRepo) + svc := membership.NewService(log.NewNoop(), d.policySvc, d.relSvc, d.roleSvc, d.orgSvc, mocks.NewUserService(t), d.projSvc, d.grpSvc, mocks.NewServiceuserService(t), d.auditRepo) principalType := tt.principalType if principalType == "" { @@ -689,3 +690,170 @@ func TestService_RemoveOrganizationMember(t *testing.T) { }) } } + +func TestService_SetProjectMemberRole(t *testing.T) { + ctx := context.Background() + projectID := uuid.New().String() + orgID := uuid.New().String() + userID := uuid.New().String() + suID := uuid.New().String() + roleID := uuid.New().String() + + prj := project.Project{ + ID: projectID, + Organization: organization.Organization{ID: orgID}, + } + + tests := []struct { + name string + setup func(*mocks.PolicyService, *mocks.RoleService, *mocks.ProjectService, *mocks.UserService, *mocks.ServiceuserService, *mocks.GroupService, *mocks.AuditRecordRepository) + principalID string + principalType string + roleID string + wantErr error + }{ + { + name: "should return error if project does not exist", + setup: func(_ *mocks.PolicyService, _ *mocks.RoleService, prjSvc *mocks.ProjectService, _ *mocks.UserService, _ *mocks.ServiceuserService, _ *mocks.GroupService, _ *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(project.Project{}, project.ErrNotExist) + }, + principalID: userID, principalType: schema.UserPrincipal, roleID: roleID, + wantErr: project.ErrNotExist, + }, + { + name: "should return error if role is not project-scoped", + setup: func(_ *mocks.PolicyService, roleSvc *mocks.RoleService, prjSvc *mocks.ProjectService, _ *mocks.UserService, _ *mocks.ServiceuserService, _ *mocks.GroupService, _ *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) + roleSvc.EXPECT().Get(ctx, roleID).Return(role.Role{ID: roleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + }, + principalID: userID, principalType: schema.UserPrincipal, roleID: roleID, + wantErr: membership.ErrInvalidProjectRole, + }, + { + name: "should return error if user is not org member", + setup: func(policySvc *mocks.PolicyService, roleSvc *mocks.RoleService, prjSvc *mocks.ProjectService, userSvc *mocks.UserService, _ *mocks.ServiceuserService, _ *mocks.GroupService, _ *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) + roleSvc.EXPECT().Get(ctx, roleID).Return(role.Role{ID: roleID, Scopes: []string{schema.ProjectNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(user.User{ID: userID, State: user.Enabled}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + }, + principalID: userID, principalType: schema.UserPrincipal, roleID: roleID, + wantErr: membership.ErrNotOrgMember, + }, + { + name: "should return error if service user is not in org", + setup: func(_ *mocks.PolicyService, roleSvc *mocks.RoleService, prjSvc *mocks.ProjectService, _ *mocks.UserService, suSvc *mocks.ServiceuserService, _ *mocks.GroupService, _ *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) + roleSvc.EXPECT().Get(ctx, roleID).Return(role.Role{ID: roleID, Scopes: []string{schema.ProjectNamespace}}, nil) + suSvc.EXPECT().Get(ctx, suID).Return(serviceuser.ServiceUser{ID: suID, OrgID: "other-org"}, nil) + }, + principalID: suID, principalType: schema.ServiceUserPrincipal, roleID: roleID, + wantErr: membership.ErrNotOrgMember, + }, + { + name: "should succeed adding new user to project", + setup: func(policySvc *mocks.PolicyService, roleSvc *mocks.RoleService, prjSvc *mocks.ProjectService, userSvc *mocks.UserService, _ *mocks.ServiceuserService, _ *mocks.GroupService, auditRepo *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) + roleSvc.EXPECT().Get(ctx, roleID).Return(role.Role{ID: roleID, Scopes: []string{schema.ProjectNamespace}}, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(user.User{ID: userID, State: user.Enabled}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1"}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + policySvc.EXPECT().Create(ctx, policy.Policy{ + RoleID: roleID, ResourceID: projectID, ResourceType: schema.ProjectNamespace, + PrincipalID: userID, PrincipalType: schema.UserPrincipal, + }).Return(policy.Policy{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + principalID: userID, principalType: schema.UserPrincipal, roleID: roleID, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockRoleSvc := mocks.NewRoleService(t) + mockPrjSvc := mocks.NewProjectService(t) + mockUserSvc := mocks.NewUserService(t) + mockSuSvc := mocks.NewServiceuserService(t) + mockGrpSvc := mocks.NewGroupService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + if tt.setup != nil { + tt.setup(mockPolicySvc, mockRoleSvc, mockPrjSvc, mockUserSvc, mockSuSvc, mockGrpSvc, mockAuditRepo) + } + + svc := membership.NewService(log.NewNoop(), mockPolicySvc, mocks.NewRelationService(t), mockRoleSvc, mocks.NewOrgService(t), mockUserSvc, mockPrjSvc, mockGrpSvc, mockSuSvc, mockAuditRepo) + err := svc.SetProjectMemberRole(ctx, projectID, tt.principalID, tt.principalType, tt.roleID) + + if tt.wantErr != nil { + assert.ErrorIs(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestService_RemoveProjectMember(t *testing.T) { + ctx := context.Background() + projectID := uuid.New().String() + userID := uuid.New().String() + + prj := project.Project{ + ID: projectID, + Title: "Test Project", + Organization: organization.Organization{ID: uuid.New().String()}, + } + + tests := []struct { + name string + setup func(*mocks.PolicyService, *mocks.ProjectService, *mocks.AuditRecordRepository) + principalType string + wantErr error + }{ + { + name: "should return error for invalid principal type", + principalType: "app/invalid", + wantErr: membership.ErrInvalidPrincipalType, + }, + { + name: "should return error if not a member", + setup: func(policySvc *mocks.PolicyService, _ *mocks.ProjectService, _ *mocks.AuditRecordRepository) { + policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) + }, + principalType: schema.UserPrincipal, + wantErr: membership.ErrNotMember, + }, + { + name: "should succeed removing a member", + setup: func(policySvc *mocks.PolicyService, prjSvc *mocks.ProjectService, auditRepo *mocks.AuditRecordRepository) { + policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1"}}, nil) + policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + principalType: schema.UserPrincipal, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPolicySvc := mocks.NewPolicyService(t) + mockPrjSvc := mocks.NewProjectService(t) + mockAuditRepo := mocks.NewAuditRecordRepository(t) + + if tt.setup != nil { + tt.setup(mockPolicySvc, mockPrjSvc, mockAuditRepo) + } + + svc := membership.NewService(log.NewNoop(), mockPolicySvc, mocks.NewRelationService(t), mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mockPrjSvc, mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) + err := svc.RemoveProjectMember(ctx, projectID, userID, tt.principalType) + + if tt.wantErr != nil { + assert.ErrorIs(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/api/v1beta1connect/errors.go b/internal/api/v1beta1connect/errors.go index 27fb89345..db757aafb 100644 --- a/internal/api/v1beta1connect/errors.go +++ b/internal/api/v1beta1connect/errors.go @@ -34,8 +34,8 @@ var ( ErrDomainMismatch = errors.New("user and org's whitelisted domains doesn't match") ErrInvitationNotFound = errors.New("invitation not found") ErrInvitationExpired = errors.New("invitation expired") - ErrAlreadyMember = errors.New("user is already a member of the organization") - ErrNotMember = errors.New("user is not a member of the organization") + ErrAlreadyMember = errors.New("principal is already a member of the organization") + ErrNotMember = errors.New("principal is not a member of the organization") ErrInvalidOrgRole = errors.New("role is not valid for organization scope") ErrInvalidProjectRole = errors.New("role is not valid for project scope") ErrEmptyEmailID = errors.New("email id is empty") diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index 7c04dcfec..f3969a259 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -354,8 +354,6 @@ type ProjectService interface { ListGroups(ctx context.Context, id string) ([]group.Group, error) Enable(ctx context.Context, id string) error Disable(ctx context.Context, id string) error - SetMemberRole(ctx context.Context, projectID, principalID, principalType, newRoleID string) error - RemoveMember(ctx context.Context, projectID, principalID, principalType string) error } type OrgUsersService interface { @@ -409,6 +407,8 @@ type MembershipService interface { AddOrganizationMember(ctx context.Context, orgID, principalID, principalType, roleID string) error SetOrganizationMemberRole(ctx context.Context, orgID, principalID, principalType, roleID string) error RemoveOrganizationMember(ctx context.Context, orgID, principalID, principalType string) error + SetProjectMemberRole(ctx context.Context, projectID, principalID, principalType, roleID string) error + RemoveProjectMember(ctx context.Context, projectID, principalID, principalType string) error } type UserPATService interface { diff --git a/internal/api/v1beta1connect/mocks/membership_service.go b/internal/api/v1beta1connect/mocks/membership_service.go index 790e69c71..eaa821ab0 100644 --- a/internal/api/v1beta1connect/mocks/membership_service.go +++ b/internal/api/v1beta1connect/mocks/membership_service.go @@ -170,6 +170,36 @@ func (_c *MembershipService_SetOrganizationMemberRole_Call) RunAndReturn(run fun return _c } +// SetProjectMemberRole provides a mock function with given fields: ctx, projectID, principalID, principalType, roleID +func (_m *MembershipService) SetProjectMemberRole(ctx context.Context, projectID string, principalID string, principalType string, roleID string) error { + ret := _m.Called(ctx, projectID, principalID, principalType, roleID) + if len(ret) == 0 { + panic("no return value specified for SetProjectMemberRole") + } + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, projectID, principalID, principalType, roleID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// RemoveProjectMember provides a mock function with given fields: ctx, projectID, principalID, principalType +func (_m *MembershipService) RemoveProjectMember(ctx context.Context, projectID string, principalID string, principalType string) error { + ret := _m.Called(ctx, projectID, principalID, principalType) + if len(ret) == 0 { + panic("no return value specified for RemoveProjectMember") + } + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, projectID, principalID, principalType) + } else { + r0 = ret.Error(0) + } + return r0 +} + // NewMembershipService creates a new instance of MembershipService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMembershipService(t interface { diff --git a/internal/api/v1beta1connect/project.go b/internal/api/v1beta1connect/project.go index 93c45a895..affa199b2 100644 --- a/internal/api/v1beta1connect/project.go +++ b/internal/api/v1beta1connect/project.go @@ -6,6 +6,7 @@ import ( "connectrpc.com/connect" "github.com/raystack/frontier/core/audit" "github.com/raystack/frontier/core/group" + "github.com/raystack/frontier/core/membership" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/project" "github.com/raystack/frontier/core/role" @@ -370,7 +371,7 @@ func (h *ConnectHandler) SetProjectMemberRole(ctx context.Context, request *conn principalType := request.Msg.GetPrincipalType() roleID := request.Msg.GetRoleId() - if err := h.projectService.SetMemberRole(ctx, projectID, principalID, principalType, roleID); err != nil { + if err := h.membershipService.SetProjectMemberRole(ctx, projectID, principalID, principalType, roleID); err != nil { errorLogger.LogServiceError(ctx, request, "SetProjectMemberRole", err, zap.String("project_id", projectID), zap.String("principal_id", principalID), @@ -386,15 +387,15 @@ func (h *ConnectHandler) SetProjectMemberRole(ctx context.Context, request *conn return nil, connect.NewError(connect.CodeNotFound, ErrServiceUserNotFound) case errors.Is(err, group.ErrNotExist): return nil, connect.NewError(connect.CodeNotFound, ErrGroupNotFound) - case errors.Is(err, project.ErrNotOrgMember): + case errors.Is(err, membership.ErrNotOrgMember): return nil, connect.NewError(connect.CodeFailedPrecondition, ErrNotMember) case errors.Is(err, role.ErrNotExist): return nil, connect.NewError(connect.CodeNotFound, ErrInvalidRoleID) case errors.Is(err, role.ErrInvalidID): return nil, connect.NewError(connect.CodeInvalidArgument, ErrInvalidRoleID) - case errors.Is(err, project.ErrInvalidProjectRole): + case errors.Is(err, membership.ErrInvalidProjectRole): return nil, connect.NewError(connect.CodeInvalidArgument, ErrInvalidProjectRole) - case errors.Is(err, project.ErrInvalidPrincipalType): + case errors.Is(err, membership.ErrInvalidPrincipalType): return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) default: return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) @@ -416,18 +417,16 @@ func (h *ConnectHandler) RemoveProjectMember(ctx context.Context, request *conne principalID := request.Msg.GetPrincipalId() principalType := request.Msg.GetPrincipalType() - if err := h.projectService.RemoveMember(ctx, projectID, principalID, principalType); err != nil { + if err := h.membershipService.RemoveProjectMember(ctx, projectID, principalID, principalType); err != nil { errorLogger.LogServiceError(ctx, request, "RemoveProjectMember", err, zap.String("project_id", projectID), zap.String("principal_id", principalID), zap.String("principal_type", principalType)) switch { - case errors.Is(err, project.ErrNotExist): - return nil, connect.NewError(connect.CodeNotFound, ErrProjectNotFound) - case errors.Is(err, project.ErrNotMember): - return nil, connect.NewError(connect.CodeNotFound, project.ErrNotMember) - case errors.Is(err, project.ErrInvalidPrincipalType): + case errors.Is(err, membership.ErrNotMember): + return nil, connect.NewError(connect.CodeNotFound, ErrNotMember) + case errors.Is(err, membership.ErrInvalidPrincipalType): return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) default: return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) diff --git a/pkg/auditrecord/consts.go b/pkg/auditrecord/consts.go index 01166acb7..f1f321d25 100644 --- a/pkg/auditrecord/consts.go +++ b/pkg/auditrecord/consts.go @@ -39,6 +39,10 @@ const ( OrganizationMemberRoleChangedEvent Event = "organization.role_changed" OrganizationInvitationAcceptedEvent Event = "organization.accepted" + // Project Member Events + ProjectMemberRoleChangedEvent Event = "project.member.role_changed" + ProjectMemberRemovedEvent Event = "project.member.removed" + // KYC Events KYCVerifiedEvent Event = "kyc.verified" KYCUnverifiedEvent Event = "kyc.unverified" From 52fabf9575b51f78f3377be6b4b1eb76cc45e55e Mon Sep 17 00:00:00 2001 From: Abhishek Sah Date: Wed, 22 Apr 2026 11:24:13 +0530 Subject: [PATCH 2/5] fix: fetch project before deleting policies in RemoveProjectMember Validate project existence and capture context for audit before performing any destructive operations. This ensures stale policies for invalid project IDs are not silently deleted and audit records are always emitted on success. Co-Authored-By: Claude Opus 4.6 (1M context) --- core/membership/service.go | 10 ++++++---- core/membership/service_test.go | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/core/membership/service.go b/core/membership/service.go index e5a77eeb5..0366bbc62 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -713,6 +713,11 @@ func (s *Service) RemoveProjectMember(ctx context.Context, projectID, principalI return ErrInvalidPrincipalType } + prj, err := s.projectService.Get(ctx, projectID) + if err != nil { + return err + } + removed, err := s.removeAllPolicies(ctx, projectID, schema.ProjectNamespace, principalID, principalType) if err != nil { return err @@ -721,10 +726,7 @@ func (s *Service) RemoveProjectMember(ctx context.Context, projectID, principalI return ErrNotMember } - // best-effort audit — fetch project for context, skip if it fails - if prj, err := s.projectService.Get(ctx, projectID); err == nil { - s.auditProjectMemberRemoved(ctx, prj, principalID, principalType) - } + s.auditProjectMemberRemoved(ctx, prj, principalID, principalType) return nil } diff --git a/core/membership/service_test.go b/core/membership/service_test.go index de81d6fbd..dd87262a7 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -818,7 +818,8 @@ func TestService_RemoveProjectMember(t *testing.T) { }, { name: "should return error if not a member", - setup: func(policySvc *mocks.PolicyService, _ *mocks.ProjectService, _ *mocks.AuditRecordRepository) { + setup: func(policySvc *mocks.PolicyService, prjSvc *mocks.ProjectService, _ *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) }, principalType: schema.UserPrincipal, @@ -827,9 +828,9 @@ func TestService_RemoveProjectMember(t *testing.T) { { name: "should succeed removing a member", setup: func(policySvc *mocks.PolicyService, prjSvc *mocks.ProjectService, auditRepo *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1"}}, nil) policySvc.EXPECT().Delete(ctx, "p1").Return(nil) - prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) }, principalType: schema.UserPrincipal, From c657180d6a9aba66b44f92125e969532a3ed40d9 Mon Sep 17 00:00:00 2001 From: Abhishek Sah Date: Wed, 22 Apr 2026 11:27:42 +0530 Subject: [PATCH 3/5] chore: regenerate mocks with mockery Co-Authored-By: Claude Opus 4.6 (1M context) --- core/membership/mocks/group_service.go | 90 +++++++++-------- core/membership/mocks/project_service.go | 90 +++++++++-------- .../mocks/membership_service.go | 95 +++++++++++++++--- .../v1beta1connect/mocks/project_service.go | 99 ------------------- 4 files changed, 176 insertions(+), 198 deletions(-) diff --git a/core/membership/mocks/group_service.go b/core/membership/mocks/group_service.go index 287a6ad03..2e84c6765 100644 --- a/core/membership/mocks/group_service.go +++ b/core/membership/mocks/group_service.go @@ -23,29 +23,27 @@ func (_m *GroupService) EXPECT() *GroupService_Expecter { return &GroupService_Expecter{mock: &_m.Mock} } -// List provides a mock function with given fields: ctx, flt -func (_m *GroupService) List(ctx context.Context, flt group.Filter) ([]group.Group, error) { - ret := _m.Called(ctx, flt) +// Get provides a mock function with given fields: ctx, idOrName +func (_m *GroupService) Get(ctx context.Context, idOrName string) (group.Group, error) { + ret := _m.Called(ctx, idOrName) if len(ret) == 0 { - panic("no return value specified for List") + panic("no return value specified for Get") } - var r0 []group.Group + var r0 group.Group var r1 error - if rf, ok := ret.Get(0).(func(context.Context, group.Filter) ([]group.Group, error)); ok { - return rf(ctx, flt) + if rf, ok := ret.Get(0).(func(context.Context, string) (group.Group, error)); ok { + return rf(ctx, idOrName) } - if rf, ok := ret.Get(0).(func(context.Context, group.Filter) []group.Group); ok { - r0 = rf(ctx, flt) + if rf, ok := ret.Get(0).(func(context.Context, string) group.Group); ok { + r0 = rf(ctx, idOrName) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]group.Group) - } + r0 = ret.Get(0).(group.Group) } - if rf, ok := ret.Get(1).(func(context.Context, group.Filter) error); ok { - r1 = rf(ctx, flt) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, idOrName) } else { r1 = ret.Error(1) } @@ -53,56 +51,58 @@ func (_m *GroupService) List(ctx context.Context, flt group.Filter) ([]group.Gro return r0, r1 } -// GroupService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' -type GroupService_List_Call struct { +// GroupService_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type GroupService_Get_Call struct { *mock.Call } -// List is a helper method to define mock.On call +// Get is a helper method to define mock.On call // - ctx context.Context -// - flt group.Filter -func (_e *GroupService_Expecter) List(ctx interface{}, flt interface{}) *GroupService_List_Call { - return &GroupService_List_Call{Call: _e.mock.On("List", ctx, flt)} +// - idOrName string +func (_e *GroupService_Expecter) Get(ctx interface{}, idOrName interface{}) *GroupService_Get_Call { + return &GroupService_Get_Call{Call: _e.mock.On("Get", ctx, idOrName)} } -func (_c *GroupService_List_Call) Run(run func(ctx context.Context, flt group.Filter)) *GroupService_List_Call { +func (_c *GroupService_Get_Call) Run(run func(ctx context.Context, idOrName string)) *GroupService_Get_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(group.Filter)) + run(args[0].(context.Context), args[1].(string)) }) return _c } -func (_c *GroupService_List_Call) Return(_a0 []group.Group, _a1 error) *GroupService_List_Call { +func (_c *GroupService_Get_Call) Return(_a0 group.Group, _a1 error) *GroupService_Get_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *GroupService_List_Call) RunAndReturn(run func(context.Context, group.Filter) ([]group.Group, error)) *GroupService_List_Call { +func (_c *GroupService_Get_Call) RunAndReturn(run func(context.Context, string) (group.Group, error)) *GroupService_Get_Call { _c.Call.Return(run) return _c } -// Get provides a mock function with given fields: ctx, idOrName -func (_m *GroupService) Get(ctx context.Context, idOrName string) (group.Group, error) { - ret := _m.Called(ctx, idOrName) +// List provides a mock function with given fields: ctx, flt +func (_m *GroupService) List(ctx context.Context, flt group.Filter) ([]group.Group, error) { + ret := _m.Called(ctx, flt) if len(ret) == 0 { - panic("no return value specified for Get") + panic("no return value specified for List") } - var r0 group.Group + var r0 []group.Group var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (group.Group, error)); ok { - return rf(ctx, idOrName) + if rf, ok := ret.Get(0).(func(context.Context, group.Filter) ([]group.Group, error)); ok { + return rf(ctx, flt) } - if rf, ok := ret.Get(0).(func(context.Context, string) group.Group); ok { - r0 = rf(ctx, idOrName) + if rf, ok := ret.Get(0).(func(context.Context, group.Filter) []group.Group); ok { + r0 = rf(ctx, flt) } else { - r0 = ret.Get(0).(group.Group) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]group.Group) + } } - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, idOrName) + if rf, ok := ret.Get(1).(func(context.Context, group.Filter) error); ok { + r1 = rf(ctx, flt) } else { r1 = ret.Error(1) } @@ -110,27 +110,31 @@ func (_m *GroupService) Get(ctx context.Context, idOrName string) (group.Group, return r0, r1 } -type GroupService_Get_Call struct { +// GroupService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type GroupService_List_Call struct { *mock.Call } -func (_e *GroupService_Expecter) Get(ctx interface{}, idOrName interface{}) *GroupService_Get_Call { - return &GroupService_Get_Call{Call: _e.mock.On("Get", ctx, idOrName)} +// List is a helper method to define mock.On call +// - ctx context.Context +// - flt group.Filter +func (_e *GroupService_Expecter) List(ctx interface{}, flt interface{}) *GroupService_List_Call { + return &GroupService_List_Call{Call: _e.mock.On("List", ctx, flt)} } -func (_c *GroupService_Get_Call) Run(run func(ctx context.Context, idOrName string)) *GroupService_Get_Call { +func (_c *GroupService_List_Call) Run(run func(ctx context.Context, flt group.Filter)) *GroupService_List_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string)) + run(args[0].(context.Context), args[1].(group.Filter)) }) return _c } -func (_c *GroupService_Get_Call) Return(_a0 group.Group, _a1 error) *GroupService_Get_Call { +func (_c *GroupService_List_Call) Return(_a0 []group.Group, _a1 error) *GroupService_List_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *GroupService_Get_Call) RunAndReturn(run func(context.Context, string) (group.Group, error)) *GroupService_Get_Call { +func (_c *GroupService_List_Call) RunAndReturn(run func(context.Context, group.Filter) ([]group.Group, error)) *GroupService_List_Call { _c.Call.Return(run) return _c } diff --git a/core/membership/mocks/project_service.go b/core/membership/mocks/project_service.go index f53ac32d9..71561b589 100644 --- a/core/membership/mocks/project_service.go +++ b/core/membership/mocks/project_service.go @@ -23,29 +23,27 @@ func (_m *ProjectService) EXPECT() *ProjectService_Expecter { return &ProjectService_Expecter{mock: &_m.Mock} } -// List provides a mock function with given fields: ctx, flt -func (_m *ProjectService) List(ctx context.Context, flt project.Filter) ([]project.Project, error) { - ret := _m.Called(ctx, flt) +// Get provides a mock function with given fields: ctx, idOrName +func (_m *ProjectService) Get(ctx context.Context, idOrName string) (project.Project, error) { + ret := _m.Called(ctx, idOrName) if len(ret) == 0 { - panic("no return value specified for List") + panic("no return value specified for Get") } - var r0 []project.Project + var r0 project.Project var r1 error - if rf, ok := ret.Get(0).(func(context.Context, project.Filter) ([]project.Project, error)); ok { - return rf(ctx, flt) + if rf, ok := ret.Get(0).(func(context.Context, string) (project.Project, error)); ok { + return rf(ctx, idOrName) } - if rf, ok := ret.Get(0).(func(context.Context, project.Filter) []project.Project); ok { - r0 = rf(ctx, flt) + if rf, ok := ret.Get(0).(func(context.Context, string) project.Project); ok { + r0 = rf(ctx, idOrName) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]project.Project) - } + r0 = ret.Get(0).(project.Project) } - if rf, ok := ret.Get(1).(func(context.Context, project.Filter) error); ok { - r1 = rf(ctx, flt) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, idOrName) } else { r1 = ret.Error(1) } @@ -53,56 +51,58 @@ func (_m *ProjectService) List(ctx context.Context, flt project.Filter) ([]proje return r0, r1 } -// ProjectService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' -type ProjectService_List_Call struct { +// ProjectService_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type ProjectService_Get_Call struct { *mock.Call } -// List is a helper method to define mock.On call +// Get is a helper method to define mock.On call // - ctx context.Context -// - flt project.Filter -func (_e *ProjectService_Expecter) List(ctx interface{}, flt interface{}) *ProjectService_List_Call { - return &ProjectService_List_Call{Call: _e.mock.On("List", ctx, flt)} +// - idOrName string +func (_e *ProjectService_Expecter) Get(ctx interface{}, idOrName interface{}) *ProjectService_Get_Call { + return &ProjectService_Get_Call{Call: _e.mock.On("Get", ctx, idOrName)} } -func (_c *ProjectService_List_Call) Run(run func(ctx context.Context, flt project.Filter)) *ProjectService_List_Call { +func (_c *ProjectService_Get_Call) Run(run func(ctx context.Context, idOrName string)) *ProjectService_Get_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(project.Filter)) + run(args[0].(context.Context), args[1].(string)) }) return _c } -func (_c *ProjectService_List_Call) Return(_a0 []project.Project, _a1 error) *ProjectService_List_Call { +func (_c *ProjectService_Get_Call) Return(_a0 project.Project, _a1 error) *ProjectService_Get_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *ProjectService_List_Call) RunAndReturn(run func(context.Context, project.Filter) ([]project.Project, error)) *ProjectService_List_Call { +func (_c *ProjectService_Get_Call) RunAndReturn(run func(context.Context, string) (project.Project, error)) *ProjectService_Get_Call { _c.Call.Return(run) return _c } -// Get provides a mock function with given fields: ctx, idOrName -func (_m *ProjectService) Get(ctx context.Context, idOrName string) (project.Project, error) { - ret := _m.Called(ctx, idOrName) +// List provides a mock function with given fields: ctx, flt +func (_m *ProjectService) List(ctx context.Context, flt project.Filter) ([]project.Project, error) { + ret := _m.Called(ctx, flt) if len(ret) == 0 { - panic("no return value specified for Get") + panic("no return value specified for List") } - var r0 project.Project + var r0 []project.Project var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (project.Project, error)); ok { - return rf(ctx, idOrName) + if rf, ok := ret.Get(0).(func(context.Context, project.Filter) ([]project.Project, error)); ok { + return rf(ctx, flt) } - if rf, ok := ret.Get(0).(func(context.Context, string) project.Project); ok { - r0 = rf(ctx, idOrName) + if rf, ok := ret.Get(0).(func(context.Context, project.Filter) []project.Project); ok { + r0 = rf(ctx, flt) } else { - r0 = ret.Get(0).(project.Project) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]project.Project) + } } - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, idOrName) + if rf, ok := ret.Get(1).(func(context.Context, project.Filter) error); ok { + r1 = rf(ctx, flt) } else { r1 = ret.Error(1) } @@ -110,27 +110,31 @@ func (_m *ProjectService) Get(ctx context.Context, idOrName string) (project.Pro return r0, r1 } -type ProjectService_Get_Call struct { +// ProjectService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type ProjectService_List_Call struct { *mock.Call } -func (_e *ProjectService_Expecter) Get(ctx interface{}, idOrName interface{}) *ProjectService_Get_Call { - return &ProjectService_Get_Call{Call: _e.mock.On("Get", ctx, idOrName)} +// List is a helper method to define mock.On call +// - ctx context.Context +// - flt project.Filter +func (_e *ProjectService_Expecter) List(ctx interface{}, flt interface{}) *ProjectService_List_Call { + return &ProjectService_List_Call{Call: _e.mock.On("List", ctx, flt)} } -func (_c *ProjectService_Get_Call) Run(run func(ctx context.Context, idOrName string)) *ProjectService_Get_Call { +func (_c *ProjectService_List_Call) Run(run func(ctx context.Context, flt project.Filter)) *ProjectService_List_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string)) + run(args[0].(context.Context), args[1].(project.Filter)) }) return _c } -func (_c *ProjectService_Get_Call) Return(_a0 project.Project, _a1 error) *ProjectService_Get_Call { +func (_c *ProjectService_List_Call) Return(_a0 []project.Project, _a1 error) *ProjectService_List_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *ProjectService_Get_Call) RunAndReturn(run func(context.Context, string) (project.Project, error)) *ProjectService_Get_Call { +func (_c *ProjectService_List_Call) RunAndReturn(run func(context.Context, project.Filter) ([]project.Project, error)) *ProjectService_List_Call { _c.Call.Return(run) return _c } diff --git a/internal/api/v1beta1connect/mocks/membership_service.go b/internal/api/v1beta1connect/mocks/membership_service.go index eaa821ab0..3b677ba33 100644 --- a/internal/api/v1beta1connect/mocks/membership_service.go +++ b/internal/api/v1beta1connect/mocks/membership_service.go @@ -120,6 +120,55 @@ func (_c *MembershipService_RemoveOrganizationMember_Call) RunAndReturn(run func return _c } +// RemoveProjectMember provides a mock function with given fields: ctx, projectID, principalID, principalType +func (_m *MembershipService) RemoveProjectMember(ctx context.Context, projectID string, principalID string, principalType string) error { + ret := _m.Called(ctx, projectID, principalID, principalType) + + if len(ret) == 0 { + panic("no return value specified for RemoveProjectMember") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, projectID, principalID, principalType) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MembershipService_RemoveProjectMember_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveProjectMember' +type MembershipService_RemoveProjectMember_Call struct { + *mock.Call +} + +// RemoveProjectMember is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - principalID string +// - principalType string +func (_e *MembershipService_Expecter) RemoveProjectMember(ctx interface{}, projectID interface{}, principalID interface{}, principalType interface{}) *MembershipService_RemoveProjectMember_Call { + return &MembershipService_RemoveProjectMember_Call{Call: _e.mock.On("RemoveProjectMember", ctx, projectID, principalID, principalType)} +} + +func (_c *MembershipService_RemoveProjectMember_Call) Run(run func(ctx context.Context, projectID string, principalID string, principalType string)) *MembershipService_RemoveProjectMember_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MembershipService_RemoveProjectMember_Call) Return(_a0 error) *MembershipService_RemoveProjectMember_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_RemoveProjectMember_Call) RunAndReturn(run func(context.Context, string, string, string) error) *MembershipService_RemoveProjectMember_Call { + _c.Call.Return(run) + return _c +} + // SetOrganizationMemberRole provides a mock function with given fields: ctx, orgID, principalID, principalType, roleID func (_m *MembershipService) SetOrganizationMemberRole(ctx context.Context, orgID string, principalID string, principalType string, roleID string) error { ret := _m.Called(ctx, orgID, principalID, principalType, roleID) @@ -173,31 +222,51 @@ func (_c *MembershipService_SetOrganizationMemberRole_Call) RunAndReturn(run fun // SetProjectMemberRole provides a mock function with given fields: ctx, projectID, principalID, principalType, roleID func (_m *MembershipService) SetProjectMemberRole(ctx context.Context, projectID string, principalID string, principalType string, roleID string) error { ret := _m.Called(ctx, projectID, principalID, principalType, roleID) + if len(ret) == 0 { panic("no return value specified for SetProjectMemberRole") } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { r0 = rf(ctx, projectID, principalID, principalType, roleID) } else { r0 = ret.Error(0) } + return r0 } -// RemoveProjectMember provides a mock function with given fields: ctx, projectID, principalID, principalType -func (_m *MembershipService) RemoveProjectMember(ctx context.Context, projectID string, principalID string, principalType string) error { - ret := _m.Called(ctx, projectID, principalID, principalType) - if len(ret) == 0 { - panic("no return value specified for RemoveProjectMember") - } - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { - r0 = rf(ctx, projectID, principalID, principalType) - } else { - r0 = ret.Error(0) - } - return r0 +// MembershipService_SetProjectMemberRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetProjectMemberRole' +type MembershipService_SetProjectMemberRole_Call struct { + *mock.Call +} + +// SetProjectMemberRole is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - principalID string +// - principalType string +// - roleID string +func (_e *MembershipService_Expecter) SetProjectMemberRole(ctx interface{}, projectID interface{}, principalID interface{}, principalType interface{}, roleID interface{}) *MembershipService_SetProjectMemberRole_Call { + return &MembershipService_SetProjectMemberRole_Call{Call: _e.mock.On("SetProjectMemberRole", ctx, projectID, principalID, principalType, roleID)} +} + +func (_c *MembershipService_SetProjectMemberRole_Call) Run(run func(ctx context.Context, projectID string, principalID string, principalType string, roleID string)) *MembershipService_SetProjectMemberRole_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *MembershipService_SetProjectMemberRole_Call) Return(_a0 error) *MembershipService_SetProjectMemberRole_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MembershipService_SetProjectMemberRole_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *MembershipService_SetProjectMemberRole_Call { + _c.Call.Return(run) + return _c } // NewMembershipService creates a new instance of MembershipService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/internal/api/v1beta1connect/mocks/project_service.go b/internal/api/v1beta1connect/mocks/project_service.go index 84504884e..3635dccb6 100644 --- a/internal/api/v1beta1connect/mocks/project_service.go +++ b/internal/api/v1beta1connect/mocks/project_service.go @@ -537,105 +537,6 @@ func (_c *ProjectService_ListUsers_Call) RunAndReturn(run func(context.Context, return _c } -// RemoveMember provides a mock function with given fields: ctx, projectID, principalID, principalType -func (_m *ProjectService) RemoveMember(ctx context.Context, projectID string, principalID string, principalType string) error { - ret := _m.Called(ctx, projectID, principalID, principalType) - - if len(ret) == 0 { - panic("no return value specified for RemoveMember") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { - r0 = rf(ctx, projectID, principalID, principalType) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ProjectService_RemoveMember_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveMember' -type ProjectService_RemoveMember_Call struct { - *mock.Call -} - -// RemoveMember is a helper method to define mock.On call -// - ctx context.Context -// - projectID string -// - principalID string -// - principalType string -func (_e *ProjectService_Expecter) RemoveMember(ctx interface{}, projectID interface{}, principalID interface{}, principalType interface{}) *ProjectService_RemoveMember_Call { - return &ProjectService_RemoveMember_Call{Call: _e.mock.On("RemoveMember", ctx, projectID, principalID, principalType)} -} - -func (_c *ProjectService_RemoveMember_Call) Run(run func(ctx context.Context, projectID string, principalID string, principalType string)) *ProjectService_RemoveMember_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) - }) - return _c -} - -func (_c *ProjectService_RemoveMember_Call) Return(_a0 error) *ProjectService_RemoveMember_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *ProjectService_RemoveMember_Call) RunAndReturn(run func(context.Context, string, string, string) error) *ProjectService_RemoveMember_Call { - _c.Call.Return(run) - return _c -} - -// SetMemberRole provides a mock function with given fields: ctx, projectID, principalID, principalType, newRoleID -func (_m *ProjectService) SetMemberRole(ctx context.Context, projectID string, principalID string, principalType string, newRoleID string) error { - ret := _m.Called(ctx, projectID, principalID, principalType, newRoleID) - - if len(ret) == 0 { - panic("no return value specified for SetMemberRole") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { - r0 = rf(ctx, projectID, principalID, principalType, newRoleID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ProjectService_SetMemberRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMemberRole' -type ProjectService_SetMemberRole_Call struct { - *mock.Call -} - -// SetMemberRole is a helper method to define mock.On call -// - ctx context.Context -// - projectID string -// - principalID string -// - principalType string -// - newRoleID string -func (_e *ProjectService_Expecter) SetMemberRole(ctx interface{}, projectID interface{}, principalID interface{}, principalType interface{}, newRoleID interface{}) *ProjectService_SetMemberRole_Call { - return &ProjectService_SetMemberRole_Call{Call: _e.mock.On("SetMemberRole", ctx, projectID, principalID, principalType, newRoleID)} -} - -func (_c *ProjectService_SetMemberRole_Call) Run(run func(ctx context.Context, projectID string, principalID string, principalType string, newRoleID string)) *ProjectService_SetMemberRole_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) - }) - return _c -} - -func (_c *ProjectService_SetMemberRole_Call) Return(_a0 error) *ProjectService_SetMemberRole_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *ProjectService_SetMemberRole_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *ProjectService_SetMemberRole_Call { - _c.Call.Return(run) - return _c -} - // Update provides a mock function with given fields: ctx, toUpdate func (_m *ProjectService) Update(ctx context.Context, toUpdate project.Project) (project.Project, error) { ret := _m.Called(ctx, toUpdate) From 9da2b1671ff76390c4c2cd499e570290f10f435f Mon Sep 17 00:00:00 2001 From: Abhishek Sah Date: Wed, 22 Apr 2026 11:38:47 +0530 Subject: [PATCH 4/5] test: add service user and group success cases for project member mutations Cover all principal types in SetProjectMemberRole and RemoveProjectMember tests to protect the principal-type audit mapping. Co-Authored-By: Claude Opus 4.6 (1M context) --- core/membership/service_test.go | 51 +++++++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/core/membership/service_test.go b/core/membership/service_test.go index dd87262a7..1779cde17 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -697,6 +697,7 @@ func TestService_SetProjectMemberRole(t *testing.T) { orgID := uuid.New().String() userID := uuid.New().String() suID := uuid.New().String() + groupID := uuid.New().String() roleID := uuid.New().String() prj := project.Project{ @@ -766,6 +767,36 @@ func TestService_SetProjectMemberRole(t *testing.T) { }, principalID: userID, principalType: schema.UserPrincipal, roleID: roleID, }, + { + name: "should succeed adding service user to project", + setup: func(policySvc *mocks.PolicyService, roleSvc *mocks.RoleService, prjSvc *mocks.ProjectService, _ *mocks.UserService, suSvc *mocks.ServiceuserService, _ *mocks.GroupService, auditRepo *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) + roleSvc.EXPECT().Get(ctx, roleID).Return(role.Role{ID: roleID, Scopes: []string{schema.ProjectNamespace}}, nil) + suSvc.EXPECT().Get(ctx, suID).Return(serviceuser.ServiceUser{ID: suID, OrgID: orgID}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: suID, PrincipalType: schema.ServiceUserPrincipal}).Return([]policy.Policy{}, nil) + policySvc.EXPECT().Create(ctx, policy.Policy{ + RoleID: roleID, ResourceID: projectID, ResourceType: schema.ProjectNamespace, + PrincipalID: suID, PrincipalType: schema.ServiceUserPrincipal, + }).Return(policy.Policy{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + principalID: suID, principalType: schema.ServiceUserPrincipal, roleID: roleID, + }, + { + name: "should succeed adding group to project", + setup: func(policySvc *mocks.PolicyService, roleSvc *mocks.RoleService, prjSvc *mocks.ProjectService, _ *mocks.UserService, _ *mocks.ServiceuserService, grpSvc *mocks.GroupService, auditRepo *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) + roleSvc.EXPECT().Get(ctx, roleID).Return(role.Role{ID: roleID, Scopes: []string{schema.ProjectNamespace}}, nil) + grpSvc.EXPECT().Get(ctx, groupID).Return(group.Group{ID: groupID, OrganizationID: orgID}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: groupID, PrincipalType: schema.GroupPrincipal}).Return([]policy.Policy{}, nil) + policySvc.EXPECT().Create(ctx, policy.Policy{ + RoleID: roleID, ResourceID: projectID, ResourceType: schema.ProjectNamespace, + PrincipalID: groupID, PrincipalType: schema.GroupPrincipal, + }).Return(policy.Policy{}, nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + principalID: groupID, principalType: schema.GroupPrincipal, roleID: roleID, + }, } for _, tt := range tests { @@ -798,6 +829,7 @@ func TestService_RemoveProjectMember(t *testing.T) { ctx := context.Background() projectID := uuid.New().String() userID := uuid.New().String() + suID := uuid.New().String() prj := project.Project{ ID: projectID, @@ -808,11 +840,13 @@ func TestService_RemoveProjectMember(t *testing.T) { tests := []struct { name string setup func(*mocks.PolicyService, *mocks.ProjectService, *mocks.AuditRecordRepository) + principalID string principalType string wantErr error }{ { name: "should return error for invalid principal type", + principalID: userID, principalType: "app/invalid", wantErr: membership.ErrInvalidPrincipalType, }, @@ -822,19 +856,32 @@ func TestService_RemoveProjectMember(t *testing.T) { prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{}, nil) }, + principalID: userID, principalType: schema.UserPrincipal, wantErr: membership.ErrNotMember, }, { - name: "should succeed removing a member", + name: "should succeed removing a user", setup: func(policySvc *mocks.PolicyService, prjSvc *mocks.ProjectService, auditRepo *mocks.AuditRecordRepository) { prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1"}}, nil) policySvc.EXPECT().Delete(ctx, "p1").Return(nil) auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) }, + principalID: userID, principalType: schema.UserPrincipal, }, + { + name: "should succeed removing a service user", + setup: func(policySvc *mocks.PolicyService, prjSvc *mocks.ProjectService, auditRepo *mocks.AuditRecordRepository) { + prjSvc.EXPECT().Get(ctx, projectID).Return(prj, nil) + policySvc.EXPECT().List(ctx, policy.Filter{ProjectID: projectID, PrincipalID: suID, PrincipalType: schema.ServiceUserPrincipal}).Return([]policy.Policy{{ID: "p1"}}, nil) + policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + auditRepo.EXPECT().Create(ctx, mock.Anything).Return(auditrecord.AuditRecord{}, nil) + }, + principalID: suID, + principalType: schema.ServiceUserPrincipal, + }, } for _, tt := range tests { @@ -848,7 +895,7 @@ func TestService_RemoveProjectMember(t *testing.T) { } svc := membership.NewService(log.NewNoop(), mockPolicySvc, mocks.NewRelationService(t), mocks.NewRoleService(t), mocks.NewOrgService(t), mocks.NewUserService(t), mockPrjSvc, mocks.NewGroupService(t), mocks.NewServiceuserService(t), mockAuditRepo) - err := svc.RemoveProjectMember(ctx, projectID, userID, tt.principalType) + err := svc.RemoveProjectMember(ctx, projectID, tt.principalID, tt.principalType) if tt.wantErr != nil { assert.ErrorIs(t, err, tt.wantErr) From 041c3d88b8bf00422421cc4da087eb4307da302b Mon Sep 17 00:00:00 2001 From: Abhishek Sah Date: Wed, 22 Apr 2026 14:58:04 +0530 Subject: [PATCH 5/5] fix: address review comments from Aman and Copilot 1. Handle user.ErrDisabled in SetProjectMemberRole handler 2. Handle project.ErrNotExist in RemoveProjectMember handler 3. Make ErrAlreadyMember/ErrNotMember resource-agnostic 4. Follow {resource}.{action} pattern for audit events 5. Merge auditProjectMemberRoleChanged/Removed into single auditProjectMember helper with event + metadata params Co-Authored-By: Claude Opus 4.6 (1M context) --- core/membership/service.go | 42 +++++++------------------- internal/api/v1beta1connect/errors.go | 4 +-- internal/api/v1beta1connect/project.go | 4 +++ pkg/auditrecord/consts.go | 4 +-- 4 files changed, 19 insertions(+), 35 deletions(-) diff --git a/core/membership/service.go b/core/membership/service.go index 0366bbc62..4c900c2cf 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -701,7 +701,7 @@ func (s *Service) SetProjectMemberRole(ctx context.Context, projectID, principal return err } - s.auditProjectMemberRoleChanged(ctx, prj, principalID, principalType, resolvedRoleID) + s.auditProjectMember(ctx, pkgAuditRecord.ProjectMemberRoleChangedEvent, prj, principalID, principalType, map[string]any{"role_id": resolvedRoleID}) return nil } @@ -726,7 +726,7 @@ func (s *Service) RemoveProjectMember(ctx context.Context, projectID, principalI return ErrNotMember } - s.auditProjectMemberRemoved(ctx, prj, principalID, principalType) + s.auditProjectMember(ctx, pkgAuditRecord.ProjectMemberRemovedEvent, prj, principalID, principalType, nil) return nil } @@ -834,43 +834,23 @@ func (s *Service) validateOrgMembership(ctx context.Context, orgID, principalID, return nil } -func (s *Service) auditProjectMemberRoleChanged(ctx context.Context, prj project.Project, principalID, principalType, roleID string) { - targetType, _ := principalTypeToAuditType(principalType) - s.auditRecordRepository.Create(ctx, auditrecord.AuditRecord{ - Event: pkgAuditRecord.ProjectMemberRoleChangedEvent, - Resource: auditrecord.Resource{ - ID: prj.ID, - Type: pkgAuditRecord.ProjectType, - Name: prj.Title, - }, - Target: &auditrecord.Target{ - ID: principalID, - Type: targetType, - Metadata: map[string]any{ - "principal_type": principalType, - "role_id": roleID, - }, - }, - OrgID: prj.Organization.ID, - OccurredAt: time.Now(), - }) -} - -func (s *Service) auditProjectMemberRemoved(ctx context.Context, prj project.Project, principalID, principalType string) { +func (s *Service) auditProjectMember(ctx context.Context, event pkgAuditRecord.Event, prj project.Project, principalID, principalType string, meta map[string]any) { targetType, _ := principalTypeToAuditType(principalType) + if meta == nil { + meta = map[string]any{} + } + meta["principal_type"] = principalType s.auditRecordRepository.Create(ctx, auditrecord.AuditRecord{ - Event: pkgAuditRecord.ProjectMemberRemovedEvent, + Event: event, Resource: auditrecord.Resource{ ID: prj.ID, Type: pkgAuditRecord.ProjectType, Name: prj.Title, }, Target: &auditrecord.Target{ - ID: principalID, - Type: targetType, - Metadata: map[string]any{ - "principal_type": principalType, - }, + ID: principalID, + Type: targetType, + Metadata: meta, }, OrgID: prj.Organization.ID, OccurredAt: time.Now(), diff --git a/internal/api/v1beta1connect/errors.go b/internal/api/v1beta1connect/errors.go index db757aafb..08bada279 100644 --- a/internal/api/v1beta1connect/errors.go +++ b/internal/api/v1beta1connect/errors.go @@ -34,8 +34,8 @@ var ( ErrDomainMismatch = errors.New("user and org's whitelisted domains doesn't match") ErrInvitationNotFound = errors.New("invitation not found") ErrInvitationExpired = errors.New("invitation expired") - ErrAlreadyMember = errors.New("principal is already a member of the organization") - ErrNotMember = errors.New("principal is not a member of the organization") + ErrAlreadyMember = errors.New("principal is already a member of the resource") + ErrNotMember = errors.New("principal is not a member of the resource") ErrInvalidOrgRole = errors.New("role is not valid for organization scope") ErrInvalidProjectRole = errors.New("role is not valid for project scope") ErrEmptyEmailID = errors.New("email id is empty") diff --git a/internal/api/v1beta1connect/project.go b/internal/api/v1beta1connect/project.go index affa199b2..d35cd70e1 100644 --- a/internal/api/v1beta1connect/project.go +++ b/internal/api/v1beta1connect/project.go @@ -389,6 +389,8 @@ func (h *ConnectHandler) SetProjectMemberRole(ctx context.Context, request *conn return nil, connect.NewError(connect.CodeNotFound, ErrGroupNotFound) case errors.Is(err, membership.ErrNotOrgMember): return nil, connect.NewError(connect.CodeFailedPrecondition, ErrNotMember) + case errors.Is(err, user.ErrDisabled): + return nil, connect.NewError(connect.CodeFailedPrecondition, ErrBadRequest) case errors.Is(err, role.ErrNotExist): return nil, connect.NewError(connect.CodeNotFound, ErrInvalidRoleID) case errors.Is(err, role.ErrInvalidID): @@ -424,6 +426,8 @@ func (h *ConnectHandler) RemoveProjectMember(ctx context.Context, request *conne zap.String("principal_type", principalType)) switch { + case errors.Is(err, project.ErrNotExist): + return nil, connect.NewError(connect.CodeNotFound, ErrProjectNotFound) case errors.Is(err, membership.ErrNotMember): return nil, connect.NewError(connect.CodeNotFound, ErrNotMember) case errors.Is(err, membership.ErrInvalidPrincipalType): diff --git a/pkg/auditrecord/consts.go b/pkg/auditrecord/consts.go index f1f321d25..ca8a09bf2 100644 --- a/pkg/auditrecord/consts.go +++ b/pkg/auditrecord/consts.go @@ -40,8 +40,8 @@ const ( OrganizationInvitationAcceptedEvent Event = "organization.accepted" // Project Member Events - ProjectMemberRoleChangedEvent Event = "project.member.role_changed" - ProjectMemberRemovedEvent Event = "project.member.removed" + ProjectMemberRoleChangedEvent Event = "project.member_role_changed" + ProjectMemberRemovedEvent Event = "project.member_removed" // KYC Events KYCVerifiedEvent Event = "kyc.verified"