diff --git a/cmd/serve.go b/cmd/serve.go index 64bd9e482..80119fcbb 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -441,11 +441,13 @@ func buildAPIDependencies( authnService, serviceUserService, groupService, roleService) membershipService := membership.NewService(logger, policyService, relationService, roleService, organizationService, userService, projectService, groupService, serviceUserService, auditRecordRepository) - // Setter injection: org/group → membership is circular (membership needs them - // for validation; they need membership for Create). Break the cycle post-init. + // Setter injection: org/group/project → membership is circular (membership + // needs them for validation; they need membership for resource-by-principal + // listing). Break the cycle post-init. organizationService.SetMembershipService(membershipService) serviceUserService.SetMembershipService(membershipService) groupService.SetMembershipService(membershipService) + projectService.SetMembershipService(membershipService) orgKycRepository := postgres.NewOrgKycRepository(dbc) orgKycService := kyc.NewService(orgKycRepository) diff --git a/core/aggregates/orgpats/mocks/project_service.go b/core/aggregates/orgpats/mocks/project_service.go index b9643e796..3dd430db6 100644 --- a/core/aggregates/orgpats/mocks/project_service.go +++ b/core/aggregates/orgpats/mocks/project_service.go @@ -5,8 +5,6 @@ package mocks import ( context "context" - authenticate "github.com/raystack/frontier/core/authenticate" - mock "github.com/stretchr/testify/mock" project "github.com/raystack/frontier/core/project" @@ -25,29 +23,29 @@ func (_m *ProjectService) EXPECT() *ProjectService_Expecter { return &ProjectService_Expecter{mock: &_m.Mock} } -// ListByUser provides a mock function with given fields: ctx, principal, flt -func (_m *ProjectService) ListByUser(ctx context.Context, principal authenticate.Principal, flt project.Filter) ([]project.Project, error) { - ret := _m.Called(ctx, principal, flt) +// List provides a mock function with given fields: ctx, flt +func (_m *ProjectService) List(ctx context.Context, flt project.Filter) ([]project.Project, error) { + ret := _m.Called(ctx, flt) if len(ret) == 0 { - panic("no return value specified for ListByUser") + panic("no return value specified for List") } var r0 []project.Project var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, project.Filter) ([]project.Project, error)); ok { - return rf(ctx, principal, flt) + if rf, ok := ret.Get(0).(func(context.Context, project.Filter) ([]project.Project, error)); ok { + return rf(ctx, flt) } - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, project.Filter) []project.Project); ok { - r0 = rf(ctx, principal, flt) + if rf, ok := ret.Get(0).(func(context.Context, project.Filter) []project.Project); ok { + r0 = rf(ctx, flt) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]project.Project) } } - if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, project.Filter) error); ok { - r1 = rf(ctx, principal, flt) + if rf, ok := ret.Get(1).(func(context.Context, project.Filter) error); ok { + r1 = rf(ctx, flt) } else { r1 = ret.Error(1) } @@ -55,32 +53,31 @@ func (_m *ProjectService) ListByUser(ctx context.Context, principal authenticate return r0, r1 } -// ProjectService_ListByUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListByUser' -type ProjectService_ListByUser_Call struct { +// ProjectService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type ProjectService_List_Call struct { *mock.Call } -// ListByUser is a helper method to define mock.On call +// List is a helper method to define mock.On call // - ctx context.Context -// - principal authenticate.Principal // - flt project.Filter -func (_e *ProjectService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *ProjectService_ListByUser_Call { - return &ProjectService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} +func (_e *ProjectService_Expecter) List(ctx interface{}, flt interface{}) *ProjectService_List_Call { + return &ProjectService_List_Call{Call: _e.mock.On("List", ctx, flt)} } -func (_c *ProjectService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt project.Filter)) *ProjectService_ListByUser_Call { +func (_c *ProjectService_List_Call) Run(run func(ctx context.Context, flt project.Filter)) *ProjectService_List_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(project.Filter)) + run(args[0].(context.Context), args[1].(project.Filter)) }) return _c } -func (_c *ProjectService_ListByUser_Call) Return(_a0 []project.Project, _a1 error) *ProjectService_ListByUser_Call { +func (_c *ProjectService_List_Call) Return(_a0 []project.Project, _a1 error) *ProjectService_List_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *ProjectService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, project.Filter) ([]project.Project, error)) *ProjectService_ListByUser_Call { +func (_c *ProjectService_List_Call) RunAndReturn(run func(context.Context, project.Filter) ([]project.Project, error)) *ProjectService_List_Call { _c.Call.Return(run) return _c } @@ -97,4 +94,4 @@ func NewProjectService(t interface { t.Cleanup(func() { mock.AssertExpectations(t) }) return mock -} +} \ No newline at end of file diff --git a/core/aggregates/orgpats/service.go b/core/aggregates/orgpats/service.go index 6e4e488ac..452d9d84a 100644 --- a/core/aggregates/orgpats/service.go +++ b/core/aggregates/orgpats/service.go @@ -17,7 +17,7 @@ type Repository interface { } type ProjectService interface { - ListByUser(ctx context.Context, principal authenticate.Principal, flt project.Filter) ([]project.Project, error) + List(ctx context.Context, flt project.Filter) ([]project.Project, error) } type Service struct { @@ -80,8 +80,9 @@ func (s *Service) Search(ctx context.Context, orgID string, query *rql.Query) (O return result, nil } -// resolveAllProjectsScope populates ResourceIDs for all-projects scopes by calling SpiceDB. -// Groups PATs by user_id to minimize SpiceDB calls. +// resolveAllProjectsScope populates ResourceIDs for all-projects scopes by +// listing projects the underlying user can see via membership. Groups PATs by +// user_id to minimize project-service calls. func (s *Service) resolveAllProjectsScope(ctx context.Context, orgID string, pats []AggregatedPAT) error { // Collect users that have all-projects scopes type allProjectsRef struct { @@ -108,7 +109,7 @@ func (s *Service) resolveAllProjectsScope(ctx context.Context, orgID string, pat ID: userID, Type: schema.UserPrincipal, } - projects, err := s.projectService.ListByUser(ctx, principal, project.Filter{OrgID: orgID}) + projects, err := s.projectService.List(ctx, project.Filter{Principal: &principal, OrgID: orgID}) if err != nil { return err } diff --git a/core/aggregates/orgpats/service_test.go b/core/aggregates/orgpats/service_test.go index 359ae4a87..b106070fa 100644 --- a/core/aggregates/orgpats/service_test.go +++ b/core/aggregates/orgpats/service_test.go @@ -90,17 +90,15 @@ func TestService_Search(t *testing.T) { } repo.EXPECT().Search(mock.Anything, orgID, query).Return(repoResult, nil) - projSvc.EXPECT().ListByUser(mock.Anything, mock.Anything, mock.Anything). - Return([]project.Project{{ID: "proj-1"}, {ID: "proj-2"}}, nil).Maybe() + projSvc.EXPECT().List(mock.Anything, mock.MatchedBy(func(f project.Filter) bool { + return f.OrgID == orgID && f.Principal != nil && f.Principal.ID == "user-1" && f.Principal.Type == schema.UserPrincipal + })).Return([]project.Project{{ID: "proj-1"}, {ID: "proj-2"}}, nil).Once() svc := orgpats.NewService(repo, projSvc) result, err := svc.Search(ctx, orgID, query) assert.NoError(t, err) assert.Len(t, result.PATs, 1) - // After resolution, the all-projects scope should have project IDs - if len(result.PATs[0].Scopes[0].ResourceIDs) > 0 { - assert.Contains(t, result.PATs[0].Scopes[0].ResourceIDs, "proj-1") - } + assert.ElementsMatch(t, []string{"proj-1", "proj-2"}, result.PATs[0].Scopes[0].ResourceIDs) }) t.Run("skips resolution when no all-projects scopes", func(t *testing.T) { @@ -127,7 +125,7 @@ func TestService_Search(t *testing.T) { } repo.EXPECT().Search(mock.Anything, orgID, query).Return(repoResult, nil) - // ProjectService.ListByUser should NOT be called + // ProjectService.List should NOT be called svc := orgpats.NewService(repo, projSvc) result, err := svc.Search(ctx, orgID, query) assert.NoError(t, err) diff --git a/core/membership/service.go b/core/membership/service.go index 69662fdc8..87e608984 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -1575,6 +1575,12 @@ func (s *Service) ListGroupsByPrincipal(ctx context.Context, principal authentic return s.listResourcesForPrincipal(ctx, subjectID, subjectType, schema.GroupNamespace, ResourceFilter{OrgID: orgID}) } +// ListProjectsByPrincipal Shim for the project package (project → membership +// would cycle). Delegates to ListResourcesByPrincipal so PAT scope is intersected. +func (s *Service) ListProjectsByPrincipal(ctx context.Context, principal authenticate.Principal, orgID string, nonInherited bool) ([]string, error) { + return s.ListResourcesByPrincipal(ctx, principal, schema.ProjectNamespace, ResourceFilter{OrgID: orgID, NonInherited: nonInherited}) +} + // ListResourcesByPrincipal returns the resource IDs of the given type on which // the principal has at least one policy. Reads Postgres policies — no SpiceDB. // With a PAT, runs the algorithm twice (user, then PAT-as-principal) and diff --git a/core/membership/service_test.go b/core/membership/service_test.go index a82e0b3e9..d5fbfd734 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -2393,3 +2393,124 @@ func TestService_ListGroupsByPrincipal(t *testing.T) { }) } } + +func TestService_ListProjectsByPrincipal(t *testing.T) { + ctx := context.Background() + + userID := uuid.New().String() + patID := uuid.New().String() + orgA := uuid.New().String() + projDirect := uuid.New().String() + projPATScope := uuid.New().String() + + t.Run("user principal with NonInherited=true skips org-inheritance branch", func(t *testing.T) { + mp := mocks.NewPolicyService(t) + // Direct project policies fetch. + mp.EXPECT().List(ctx, policy.Filter{ + PrincipalID: userID, + PrincipalType: schema.UserPrincipal, + ResourceType: schema.ProjectNamespace, + RolePermissions: schema.ProjectDirectVisibilityPerms, + }).Return([]policy.Policy{{ResourceID: projDirect}}, nil) + // Group expansion: principal has no groups (NonInherited=true on inner call). + mp.EXPECT().List(ctx, policy.Filter{ + PrincipalID: userID, + PrincipalType: schema.UserPrincipal, + ResourceType: schema.GroupNamespace, + }).Return([]policy.Policy{}, nil) + // NO org-inheritance fetch must happen — that's the NonInherited contract. + + svc := membership.NewService( + slog.New(slog.NewTextHandler(io.Discard, nil)), + mp, + mocks.NewRelationService(t), + mocks.NewRoleService(t), + mocks.NewOrgService(t), + mocks.NewUserService(t), + mocks.NewProjectService(t), + mocks.NewGroupService(t), + mocks.NewServiceuserService(t), + mocks.NewAuditRecordRepository(t), + ) + + got, err := svc.ListProjectsByPrincipal( + ctx, + authenticate.Principal{ID: userID, Type: schema.UserPrincipal}, + "", + true, + ) + assert.NoError(t, err) + assert.ElementsMatch(t, []string{projDirect}, got) + }) + + t.Run("PAT principal — runs both user-side and PAT-side queries and intersects (unlike groups)", func(t *testing.T) { + mp := mocks.NewPolicyService(t) + // User-side: direct project policies + (no groups) + org-inheritance branch. + mp.EXPECT().List(ctx, policy.Filter{ + PrincipalID: userID, + PrincipalType: schema.UserPrincipal, + ResourceType: schema.ProjectNamespace, + RolePermissions: schema.ProjectDirectVisibilityPerms, + }).Return([]policy.Policy{ + {ResourceID: projDirect}, + {ResourceID: projPATScope}, + }, nil) + mp.EXPECT().List(ctx, policy.Filter{ + PrincipalID: userID, + PrincipalType: schema.UserPrincipal, + ResourceType: schema.GroupNamespace, + }).Return([]policy.Policy{}, nil) + mp.EXPECT().List(ctx, policy.Filter{ + PrincipalID: userID, + PrincipalType: schema.UserPrincipal, + ResourceType: schema.OrganizationNamespace, + RolePermissions: schema.OrganizationProjectInheritPerms, + }).Return([]policy.Policy{}, nil) + + // PAT-side: same fanout under PAT principal type — PAT only scopes projPATScope. + mp.EXPECT().List(ctx, policy.Filter{ + PrincipalID: patID, + PrincipalType: schema.PATPrincipal, + ResourceType: schema.ProjectNamespace, + RolePermissions: schema.ProjectDirectVisibilityPerms, + }).Return([]policy.Policy{{ResourceID: projPATScope}}, nil) + mp.EXPECT().List(ctx, policy.Filter{ + PrincipalID: patID, + PrincipalType: schema.PATPrincipal, + ResourceType: schema.GroupNamespace, + }).Return([]policy.Policy{}, nil) + mp.EXPECT().List(ctx, policy.Filter{ + PrincipalID: patID, + PrincipalType: schema.PATPrincipal, + ResourceType: schema.OrganizationNamespace, + RolePermissions: schema.OrganizationProjectInheritPerms, + }).Return([]policy.Policy{}, nil) + + svc := membership.NewService( + slog.New(slog.NewTextHandler(io.Discard, nil)), + mp, + mocks.NewRelationService(t), + mocks.NewRoleService(t), + mocks.NewOrgService(t), + mocks.NewUserService(t), + mocks.NewProjectService(t), + mocks.NewGroupService(t), + mocks.NewServiceuserService(t), + mocks.NewAuditRecordRepository(t), + ) + + got, err := svc.ListProjectsByPrincipal( + ctx, + authenticate.Principal{ + ID: userID, + Type: schema.UserPrincipal, + PAT: &pat.PAT{ID: patID, UserID: userID, OrgID: orgA}, + }, + "", + false, + ) + assert.NoError(t, err) + // PAT narrows: user sees [direct, patScope]; PAT sees [patScope]; intersect → [patScope]. + assert.ElementsMatch(t, []string{projPATScope}, got) + }) +} diff --git a/core/project/filter.go b/core/project/filter.go index 10e4dd7b4..c55cde33a 100644 --- a/core/project/filter.go +++ b/core/project/filter.go @@ -1,6 +1,9 @@ package project -import "github.com/raystack/frontier/pkg/pagination" +import ( + "github.com/raystack/frontier/core/authenticate" + "github.com/raystack/frontier/pkg/pagination" +) type Filter struct { OrgID string @@ -17,4 +20,10 @@ type Filter struct { // are set, projects must satisfy both (intersection) — typically yields // no rows unless OrgID is one of OrgIDs. OrgIDs []string + + // Principal narrows results to projects on which the principal has a + // policy (direct, via group membership, or org-inheritance unless + // NonInherited is set). When combined with ProjectIDs the two are + // intersected. Resolved by membership.Service. + Principal *authenticate.Principal } diff --git a/core/project/mocks/group_service.go b/core/project/mocks/group_service.go index df1e0bfc1..6d08b9ad8 100644 --- a/core/project/mocks/group_service.go +++ b/core/project/mocks/group_service.go @@ -79,124 +79,6 @@ func (_c *GroupService_Get_Call) RunAndReturn(run func(context.Context, string) return _c } -// GetByIDs provides a mock function with given fields: ctx, ids -func (_m *GroupService) GetByIDs(ctx context.Context, ids []string) ([]group.Group, error) { - ret := _m.Called(ctx, ids) - - if len(ret) == 0 { - panic("no return value specified for GetByIDs") - } - - var r0 []group.Group - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, []string) ([]group.Group, error)); ok { - return rf(ctx, ids) - } - if rf, ok := ret.Get(0).(func(context.Context, []string) []group.Group); ok { - r0 = rf(ctx, ids) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]group.Group) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { - r1 = rf(ctx, ids) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GroupService_GetByIDs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByIDs' -type GroupService_GetByIDs_Call struct { - *mock.Call -} - -// GetByIDs is a helper method to define mock.On call -// - ctx context.Context -// - ids []string -func (_e *GroupService_Expecter) GetByIDs(ctx interface{}, ids interface{}) *GroupService_GetByIDs_Call { - return &GroupService_GetByIDs_Call{Call: _e.mock.On("GetByIDs", ctx, ids)} -} - -func (_c *GroupService_GetByIDs_Call) Run(run func(ctx context.Context, ids []string)) *GroupService_GetByIDs_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].([]string)) - }) - return _c -} - -func (_c *GroupService_GetByIDs_Call) Return(_a0 []group.Group, _a1 error) *GroupService_GetByIDs_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *GroupService_GetByIDs_Call) RunAndReturn(run func(context.Context, []string) ([]group.Group, error)) *GroupService_GetByIDs_Call { - _c.Call.Return(run) - return _c -} - -// List provides a mock function with given fields: ctx, flt -func (_m *GroupService) List(ctx context.Context, flt group.Filter) ([]group.Group, error) { - ret := _m.Called(ctx, flt) - - if len(ret) == 0 { - panic("no return value specified for List") - } - - var r0 []group.Group - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, group.Filter) ([]group.Group, error)); ok { - return rf(ctx, flt) - } - if rf, ok := ret.Get(0).(func(context.Context, group.Filter) []group.Group); ok { - r0 = rf(ctx, flt) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]group.Group) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, group.Filter) error); ok { - r1 = rf(ctx, flt) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// GroupService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' -type GroupService_List_Call struct { - *mock.Call -} - -// List is a helper method to define mock.On call -// - ctx context.Context -// - flt group.Filter -func (_e *GroupService_Expecter) List(ctx interface{}, flt interface{}) *GroupService_List_Call { - return &GroupService_List_Call{Call: _e.mock.On("List", ctx, flt)} -} - -func (_c *GroupService_List_Call) Run(run func(ctx context.Context, flt group.Filter)) *GroupService_List_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(group.Filter)) - }) - return _c -} - -func (_c *GroupService_List_Call) Return(_a0 []group.Group, _a1 error) *GroupService_List_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *GroupService_List_Call) RunAndReturn(run func(context.Context, group.Filter) ([]group.Group, error)) *GroupService_List_Call { - _c.Call.Return(run) - return _c -} - // NewGroupService creates a new instance of GroupService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewGroupService(t interface { diff --git a/core/project/mocks/membership_service.go b/core/project/mocks/membership_service.go new file mode 100644 index 000000000..3bdc05427 --- /dev/null +++ b/core/project/mocks/membership_service.go @@ -0,0 +1,99 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + authenticate "github.com/raystack/frontier/core/authenticate" + + mock "github.com/stretchr/testify/mock" +) + +// MembershipService is an autogenerated mock type for the MembershipService type +type MembershipService struct { + mock.Mock +} + +type MembershipService_Expecter struct { + mock *mock.Mock +} + +func (_m *MembershipService) EXPECT() *MembershipService_Expecter { + return &MembershipService_Expecter{mock: &_m.Mock} +} + +// ListProjectsByPrincipal provides a mock function with given fields: ctx, principal, orgID, nonInherited +func (_m *MembershipService) ListProjectsByPrincipal(ctx context.Context, principal authenticate.Principal, orgID string, nonInherited bool) ([]string, error) { + ret := _m.Called(ctx, principal, orgID, nonInherited) + + if len(ret) == 0 { + panic("no return value specified for ListProjectsByPrincipal") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string, bool) ([]string, error)); ok { + return rf(ctx, principal, orgID, nonInherited) + } + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, string, bool) []string); ok { + r0 = rf(ctx, principal, orgID, nonInherited) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, string, bool) error); ok { + r1 = rf(ctx, principal, orgID, nonInherited) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MembershipService_ListProjectsByPrincipal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListProjectsByPrincipal' +type MembershipService_ListProjectsByPrincipal_Call struct { + *mock.Call +} + +// ListProjectsByPrincipal is a helper method to define mock.On call +// - ctx context.Context +// - principal authenticate.Principal +// - orgID string +// - nonInherited bool +func (_e *MembershipService_Expecter) ListProjectsByPrincipal(ctx interface{}, principal interface{}, orgID interface{}, nonInherited interface{}) *MembershipService_ListProjectsByPrincipal_Call { + return &MembershipService_ListProjectsByPrincipal_Call{Call: _e.mock.On("ListProjectsByPrincipal", ctx, principal, orgID, nonInherited)} +} + +func (_c *MembershipService_ListProjectsByPrincipal_Call) Run(run func(ctx context.Context, principal authenticate.Principal, orgID string, nonInherited bool)) *MembershipService_ListProjectsByPrincipal_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(string), args[3].(bool)) + }) + return _c +} + +func (_c *MembershipService_ListProjectsByPrincipal_Call) Return(_a0 []string, _a1 error) *MembershipService_ListProjectsByPrincipal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MembershipService_ListProjectsByPrincipal_Call) RunAndReturn(run func(context.Context, authenticate.Principal, string, bool) ([]string, error)) *MembershipService_ListProjectsByPrincipal_Call { + _c.Call.Return(run) + return _c +} + +// NewMembershipService creates a new instance of MembershipService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMembershipService(t interface { + mock.TestingT + Cleanup(func()) +}) *MembershipService { + mock := &MembershipService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/project/mocks/relation_service.go b/core/project/mocks/relation_service.go index cadbfdc2e..81faa7759 100644 --- a/core/project/mocks/relation_service.go +++ b/core/project/mocks/relation_service.go @@ -127,65 +127,6 @@ func (_c *RelationService_Delete_Call) RunAndReturn(run func(context.Context, re return _c } -// LookupResources provides a mock function with given fields: ctx, rel -func (_m *RelationService) LookupResources(ctx context.Context, rel relation.Relation) ([]string, error) { - ret := _m.Called(ctx, rel) - - if len(ret) == 0 { - panic("no return value specified for LookupResources") - } - - var r0 []string - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) ([]string, error)); ok { - return rf(ctx, rel) - } - if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) []string); ok { - r0 = rf(ctx, rel) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, relation.Relation) error); ok { - r1 = rf(ctx, rel) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// RelationService_LookupResources_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LookupResources' -type RelationService_LookupResources_Call struct { - *mock.Call -} - -// LookupResources is a helper method to define mock.On call -// - ctx context.Context -// - rel relation.Relation -func (_e *RelationService_Expecter) LookupResources(ctx interface{}, rel interface{}) *RelationService_LookupResources_Call { - return &RelationService_LookupResources_Call{Call: _e.mock.On("LookupResources", ctx, rel)} -} - -func (_c *RelationService_LookupResources_Call) Run(run func(ctx context.Context, rel relation.Relation)) *RelationService_LookupResources_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(relation.Relation)) - }) - return _c -} - -func (_c *RelationService_LookupResources_Call) Return(_a0 []string, _a1 error) *RelationService_LookupResources_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *RelationService_LookupResources_Call) RunAndReturn(run func(context.Context, relation.Relation) ([]string, error)) *RelationService_LookupResources_Call { - _c.Call.Return(run) - return _c -} - // LookupSubjects provides a mock function with given fields: ctx, rel func (_m *RelationService) LookupSubjects(ctx context.Context, rel relation.Relation) ([]string, error) { ret := _m.Called(ctx, rel) diff --git a/core/project/service.go b/core/project/service.go index d6572059a..deac0f329 100644 --- a/core/project/service.go +++ b/core/project/service.go @@ -24,7 +24,6 @@ import ( type RelationService interface { Create(ctx context.Context, rel relation.Relation) (relation.Relation, error) LookupSubjects(ctx context.Context, rel relation.Relation) ([]string, error) - LookupResources(ctx context.Context, rel relation.Relation) ([]string, error) Delete(ctx context.Context, rel relation.Relation) error } @@ -56,19 +55,22 @@ type AuthnService interface { type GroupService interface { Get(ctx context.Context, id string) (group.Group, error) - GetByIDs(ctx context.Context, ids []string) ([]group.Group, error) - List(ctx context.Context, flt group.Filter) ([]group.Group, error) +} + +type MembershipService interface { + ListProjectsByPrincipal(ctx context.Context, principal authenticate.Principal, orgID string, nonInherited bool) ([]string, error) } type Service struct { - repository Repository - relationService RelationService - userService UserService - suserService ServiceuserService - policyService PolicyService - authnService AuthnService - groupService GroupService - roleService RoleService + repository Repository + relationService RelationService + userService UserService + suserService ServiceuserService + policyService PolicyService + authnService AuthnService + groupService GroupService + roleService RoleService + membershipService MembershipService } func NewService(repository Repository, relationService RelationService, userService UserService, @@ -86,6 +88,12 @@ func NewService(repository Repository, relationService RelationService, userServ } } +// SetMembershipService sets the membership dependency after construction to +// break the circular init order between project and membership services. +func (s *Service) SetMembershipService(ms MembershipService) { + s.membershipService = ms +} + func (s Service) Get(ctx context.Context, idOrName string) (Project, error) { if utils.IsValidUUID(idOrName) { return s.repository.GetByID(ctx, idOrName) @@ -124,6 +132,26 @@ func (s Service) Create(ctx context.Context, prj Project) (Project, error) { } func (s Service) List(ctx context.Context, f Filter) ([]Project, error) { + if f.Principal != nil { + if f.Principal.ID == "" || f.Principal.Type == "" { + return nil, fmt.Errorf("project: invalid principal filter") + } + if s.membershipService == nil { + return nil, fmt.Errorf("project: membership service not wired") + } + ids, err := s.membershipService.ListProjectsByPrincipal(ctx, *f.Principal, f.OrgID, f.NonInherited) + if err != nil { + return nil, err + } + if len(f.ProjectIDs) > 0 { + ids = utils.Intersection(ids, f.ProjectIDs) + } + if len(ids) == 0 { + return []Project{}, nil + } + f.ProjectIDs = ids + } + projects, err := s.repository.List(ctx, f) if err != nil { return nil, err @@ -150,95 +178,6 @@ func (s Service) List(ctx context.Context, f Filter) ([]Project, error) { return projects, nil } -func (s Service) ListByUser(ctx context.Context, principal authenticate.Principal, - flt Filter) ([]Project, error) { - subjectID, subjectType := principal.ResolveSubject() - - var projIDs []string - var err error - if flt.NonInherited { - // direct added users - projIDs, err = s.listNonInheritedProjectIDs(ctx, subjectID, subjectType) - } else { - projIDs, err = s.relationService.LookupResources(ctx, relation.Relation{ - Object: relation.Object{Namespace: schema.ProjectNamespace}, - Subject: relation.Subject{Namespace: subjectType, ID: subjectID}, - RelationName: MemberPermission, - }) - } - if err != nil { - return nil, err - } - - projIDs = utils.Deduplicate(projIDs) - projIDs, err = s.intersectPATScope(ctx, principal, schema.ProjectNamespace, projIDs) - if err != nil { - return nil, err - } - if len(projIDs) == 0 { - return []Project{}, nil - } - - flt.ProjectIDs = projIDs - return s.List(ctx, flt) -} - -// listNonInheritedProjectIDs returns project IDs where the principal has direct -// role assignments (not inherited through org), including via group memberships. -func (s Service) listNonInheritedProjectIDs(ctx context.Context, principalID, principalType string) ([]string, error) { - policies, err := s.policyService.List(ctx, policy.Filter{ - PrincipalType: principalType, - PrincipalID: principalID, - ResourceType: schema.ProjectNamespace, - }) - if err != nil { - return nil, err - } - var projIDs []string - for _, pol := range policies { - projIDs = append(projIDs, pol.ResourceID) - } - - // projects added via group memberships - principal := authenticate.Principal{ID: principalID, Type: principalType} - groups, err := s.groupService.List(ctx, group.Filter{Principal: &principal}) - if err != nil { - return nil, err - } - groupIDs := utils.Map(groups, func(g group.Group) string { return g.ID }) - if len(groupIDs) > 0 { - policies, err = s.policyService.List(ctx, policy.Filter{ - PrincipalType: schema.GroupPrincipal, - PrincipalIDs: groupIDs, - ResourceType: schema.ProjectNamespace, - }) - if err != nil { - return nil, err - } - for _, pol := range policies { - projIDs = append(projIDs, pol.ResourceID) - } - } - return projIDs, nil -} - -// intersectPATScope narrows resource IDs to only those the PAT is scoped to. -func (s Service) intersectPATScope(ctx context.Context, principal authenticate.Principal, - namespace string, resourceIDs []string) ([]string, error) { - if principal.PAT == nil || len(resourceIDs) == 0 { - return resourceIDs, nil - } - patIDs, err := s.relationService.LookupResources(ctx, relation.Relation{ - Object: relation.Object{Namespace: namespace}, - Subject: relation.Subject{ID: principal.PAT.ID, Namespace: schema.PATPrincipal}, - RelationName: schema.GetPermission, - }) - if err != nil { - return nil, err - } - return utils.Intersection(resourceIDs, patIDs), nil -} - func (s Service) Update(ctx context.Context, prj Project) (Project, error) { if utils.IsValidUUID(prj.ID) { return s.repository.UpdateByID(ctx, prj) diff --git a/core/project/service_test.go b/core/project/service_test.go index 9bec3baa5..7289ff1b8 100644 --- a/core/project/service_test.go +++ b/core/project/service_test.go @@ -17,7 +17,6 @@ import ( "github.com/raystack/frontier/core/role" "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" - pat "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/stretchr/testify/assert" ) @@ -287,495 +286,184 @@ func TestService_List(t *testing.T) { } } -func TestService_ListByUser(t *testing.T) { +func TestService_List_WithPrincipal(t *testing.T) { ctx := context.Background() - type args struct { - principal authenticate.Principal - flt project.Filter - } + userPrincipal := authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal} + tests := []struct { name string - setup func() *project.Service - args args + setup func(*testing.T) *project.Service + filter project.Filter want []project.Project wantErr bool }{ { - name: "list all projects by user successfully", - args: args{ - principal: authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, - flt: project.Filter{}, - }, - want: []project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - { - ID: "project-id-2", - Name: "test-2", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - { - ID: "project-id-3", - Name: "test-3", - Organization: organization.Organization{ - ID: "org-id-2", - }, - }, - }, - wantErr: false, - setup: func() *project.Service { + name: "errors when membership service is not wired", + filter: project.Filter{Principal: &userPrincipal}, + setup: func(t *testing.T) *project.Service { + t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) - _ = roleService - relationService.EXPECT().LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.ProjectNamespace, - }, - Subject: relation.Subject{ - Namespace: schema.UserPrincipal, - ID: "user-id", - }, - RelationName: project.MemberPermission, - }).Return([]string{"project-id", "project-id-2", "project-id-3"}, nil) - - repo.EXPECT().List(ctx, project.Filter{ - ProjectIDs: []string{"project-id", "project-id-2", "project-id-3"}, - }).Return([]project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - { - ID: "project-id-2", - Name: "test-2", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - { - ID: "project-id-3", - Name: "test-3", - Organization: organization.Organization{ - ID: "org-id-2", - }, - }, - }, nil) + // Intentionally skip SetMembershipService. return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) }, + wantErr: true, }, { - name: "list all projects by user with non-inherited policies (with no groups)", - args: args{ - principal: authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, - flt: project.Filter{ - NonInherited: true, - }, - }, - want: []project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - }, - wantErr: false, - setup: func() *project.Service { + name: "errors when Principal has empty ID", + filter: project.Filter{Principal: &authenticate.Principal{Type: schema.UserPrincipal}}, + setup: func(t *testing.T) *project.Service { + t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) - _ = roleService - policyService.EXPECT().List(ctx, policy.Filter{ - PrincipalType: schema.UserPrincipal, - PrincipalID: "user-id", - ResourceType: schema.ProjectNamespace, - }).Return([]policy.Policy{ - { - ResourceID: "project-id", - ResourceType: schema.ProjectNamespace, - PrincipalID: "user-id", - PrincipalType: schema.UserPrincipal, - }, - }, nil) - - groupService.EXPECT().List(ctx, group.Filter{ - Principal: &authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, - }).Return([]group.Group{}, nil) - - repo.EXPECT().List(ctx, project.Filter{ - ProjectIDs: []string{"project-id"}, - NonInherited: true, - }).Return([]project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - }, nil) return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) }, + wantErr: true, }, { - name: "list all projects by user with non-inherited policies (with groups)", - args: args{ - principal: authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, - flt: project.Filter{ - NonInherited: true, - }, - }, - want: []project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - { - ID: "project-id-2", - Name: "test-2", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - }, - wantErr: false, - setup: func() *project.Service { + name: "errors when Principal has empty Type", + filter: project.Filter{Principal: &authenticate.Principal{ID: "user-id"}}, + setup: func(t *testing.T) *project.Service { + t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) - _ = roleService - policyService.EXPECT().List(ctx, policy.Filter{ - PrincipalType: schema.UserPrincipal, - PrincipalID: "user-id", - ResourceType: schema.ProjectNamespace, - }).Return([]policy.Policy{ - { - ResourceID: "project-id", - ResourceType: schema.ProjectNamespace, - PrincipalID: "user-id", - PrincipalType: schema.UserPrincipal, - }, - }, nil) - - groupService.EXPECT().List(ctx, group.Filter{ - Principal: &authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, - }).Return([]group.Group{ - { - ID: "group-id", - }, - }, nil) - - policyService.EXPECT().List(ctx, policy.Filter{ - PrincipalType: schema.GroupPrincipal, - PrincipalIDs: []string{"group-id"}, - ResourceType: schema.ProjectNamespace, - }).Return([]policy.Policy{ - { - ResourceID: "project-id-2", - ResourceType: schema.ProjectNamespace, - PrincipalID: "group-id", - PrincipalType: schema.GroupPrincipal, - }, - }, nil) - - repo.EXPECT().List(ctx, project.Filter{ - ProjectIDs: []string{"project-id", "project-id-2"}, - NonInherited: true, - }).Return([]project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - { - ID: "project-id-2", - Name: "test-2", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - }, nil) return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) }, + wantErr: true, }, { - name: "PAT principal should resolve to user and intersect with PAT project scope", - args: args{ - principal: authenticate.Principal{ - ID: "pat-456", - Type: schema.PATPrincipal, - PAT: &pat.PAT{ID: "pat-456", UserID: "user-id", OrgID: "org-1"}, - }, - flt: project.Filter{}, - }, + name: "returns projects from the membership shim", + filter: project.Filter{Principal: &userPrincipal}, want: []project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, + {ID: "p1", Name: "p1"}, + {ID: "p2", Name: "p2"}, }, - wantErr: false, - setup: func() *project.Service { + setup: func(t *testing.T) *project.Service { + t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) - _ = roleService - // LookupResources for user's project memberships (resolved from PAT) - relationService.EXPECT().LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.ProjectNamespace, - }, - Subject: relation.Subject{ - Namespace: schema.UserPrincipal, - ID: "user-id", - }, - RelationName: project.MemberPermission, - }).Return([]string{"project-id", "project-id-2", "project-id-3"}, nil) - - // LookupResources for PAT's project scope - relationService.EXPECT().LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.ProjectNamespace, - }, - Subject: relation.Subject{ - ID: "pat-456", - Namespace: schema.PATPrincipal, - }, - RelationName: schema.GetPermission, - }).Return([]string{"project-id"}, nil) - - // Repo called with intersection - repo.EXPECT().List(ctx, project.Filter{ - ProjectIDs: []string{"project-id"}, - }).Return([]project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - }, nil) - return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) - }, - }, - { - name: "PAT principal with non-inherited should resolve to user and intersect", - args: args{ - principal: authenticate.Principal{ - ID: "pat-456", - Type: schema.PATPrincipal, - PAT: &pat.PAT{ID: "pat-456", UserID: "user-id", OrgID: "org-1"}, - }, - flt: project.Filter{ - NonInherited: true, - }, - }, - want: []project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - }, - wantErr: false, - setup: func() *project.Service { + membershipService := mocks.NewMembershipService(t) + membershipService.EXPECT(). + ListProjectsByPrincipal(ctx, userPrincipal, "", false). + Return([]string{"p1", "p2"}, nil) + repo.EXPECT(). + List(ctx, project.Filter{Principal: &userPrincipal, ProjectIDs: []string{"p1", "p2"}}). + Return([]project.Project{{ID: "p1", Name: "p1"}, {ID: "p2", Name: "p2"}}, nil) + svc := project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + svc.SetMembershipService(membershipService) + return svc + }, + }, + { + name: "passes OrgID and NonInherited through to the shim", + filter: project.Filter{Principal: &userPrincipal, OrgID: "org-1", NonInherited: true}, + want: []project.Project{{ID: "p1", Organization: organization.Organization{ID: "org-1"}}}, + setup: func(t *testing.T) *project.Service { + t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) - _ = roleService - // Direct policies for user (resolved from PAT) - policyService.EXPECT().List(ctx, policy.Filter{ - PrincipalType: schema.UserPrincipal, - PrincipalID: "user-id", - ResourceType: schema.ProjectNamespace, - }).Return([]policy.Policy{ - { - ResourceID: "project-id", - ResourceType: schema.ProjectNamespace, - PrincipalID: "user-id", - PrincipalType: schema.UserPrincipal, - }, - { - ResourceID: "project-id-2", - ResourceType: schema.ProjectNamespace, - PrincipalID: "user-id", - PrincipalType: schema.UserPrincipal, - }, - }, nil) - - // Group lookup uses user-only principal (no double PAT filtering) - groupService.EXPECT().List(ctx, group.Filter{ - Principal: &authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, - }).Return([]group.Group{}, nil) - - // PAT scope intersection - relationService.EXPECT().LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.ProjectNamespace, - }, - Subject: relation.Subject{ - ID: "pat-456", - Namespace: schema.PATPrincipal, - }, - RelationName: schema.GetPermission, - }).Return([]string{"project-id"}, nil) - - // Repo called with intersection result - repo.EXPECT().List(ctx, project.Filter{ - ProjectIDs: []string{"project-id"}, - NonInherited: true, - }).Return([]project.Project{ - { - ID: "project-id", - Name: "test", - Organization: organization.Organization{ - ID: "org-id", - }, - }, - }, nil) - return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + membershipService := mocks.NewMembershipService(t) + membershipService.EXPECT(). + ListProjectsByPrincipal(ctx, userPrincipal, "org-1", true). + Return([]string{"p1"}, nil) + repo.EXPECT(). + List(ctx, project.Filter{Principal: &userPrincipal, OrgID: "org-1", NonInherited: true, ProjectIDs: []string{"p1"}}). + Return([]project.Project{{ID: "p1", Organization: organization.Organization{ID: "org-1"}}}, nil) + svc := project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + svc.SetMembershipService(membershipService) + return svc + }, + }, + { + name: "intersects shim result with caller-supplied ProjectIDs", + filter: project.Filter{Principal: &userPrincipal, ProjectIDs: []string{"p2", "p3", "p4"}}, + want: []project.Project{{ID: "p2"}, {ID: "p3"}}, + setup: func(t *testing.T) *project.Service { + t.Helper() + repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) + membershipService := mocks.NewMembershipService(t) + membershipService.EXPECT(). + ListProjectsByPrincipal(ctx, userPrincipal, "", false). + Return([]string{"p1", "p2", "p3"}, nil) + repo.EXPECT(). + List(ctx, project.Filter{Principal: &userPrincipal, ProjectIDs: []string{"p2", "p3"}}). + Return([]project.Project{{ID: "p2"}, {ID: "p3"}}, nil) + svc := project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + svc.SetMembershipService(membershipService) + return svc + }, + }, + { + name: "short-circuits to empty slice when shim returns no IDs", + filter: project.Filter{Principal: &userPrincipal}, + want: []project.Project{}, + setup: func(t *testing.T) *project.Service { + t.Helper() + repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) + membershipService := mocks.NewMembershipService(t) + membershipService.EXPECT(). + ListProjectsByPrincipal(ctx, userPrincipal, "", false). + Return(nil, nil) + // repo.List must NOT be called. + svc := project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + svc.SetMembershipService(membershipService) + return svc }, }, { - name: "PAT principal with non-inherited surfaces group-mediated projects intersected with PAT scope", - args: args{ - principal: authenticate.Principal{ - ID: "pat-456", - Type: schema.PATPrincipal, - PAT: &pat.PAT{ID: "pat-456", UserID: "user-id", OrgID: "org-1"}, - }, - flt: project.Filter{ - NonInherited: true, - }, - }, - want: []project.Project{ - { - ID: "project-via-group", - Name: "group-project", - Organization: organization.Organization{ID: "org-1"}, - }, - }, - wantErr: false, - setup: func() *project.Service { + name: "propagates membership shim error", + filter: project.Filter{Principal: &userPrincipal}, + setup: func(t *testing.T) *project.Service { + t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) - _ = roleService - // User has no direct project policies; everything comes via group - policyService.EXPECT().List(ctx, policy.Filter{ - PrincipalType: schema.UserPrincipal, - PrincipalID: "user-id", - ResourceType: schema.ProjectNamespace, - }).Return([]policy.Policy{}, nil) - - // User is a member of a group (PAT resolved to user before this call) - groupService.EXPECT().List(ctx, group.Filter{ - Principal: &authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, - }).Return([]group.Group{{ID: "group-id"}}, nil) - - // That group has policy on a project - policyService.EXPECT().List(ctx, policy.Filter{ - PrincipalType: schema.GroupPrincipal, - PrincipalIDs: []string{"group-id"}, - ResourceType: schema.ProjectNamespace, - }).Return([]policy.Policy{ - { - ResourceID: "project-via-group", - ResourceType: schema.ProjectNamespace, - PrincipalID: "group-id", - PrincipalType: schema.GroupPrincipal, - }, - }, nil) - - // PAT scope grants the same project - relationService.EXPECT().LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.ProjectNamespace, - }, - Subject: relation.Subject{ - ID: "pat-456", - Namespace: schema.PATPrincipal, - }, - RelationName: schema.GetPermission, - }).Return([]string{"project-via-group"}, nil) - - repo.EXPECT().List(ctx, project.Filter{ - ProjectIDs: []string{"project-via-group"}, - NonInherited: true, - }).Return([]project.Project{ - { - ID: "project-via-group", - Name: "group-project", - Organization: organization.Organization{ID: "org-1"}, - }, - }, nil) - return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + membershipService := mocks.NewMembershipService(t) + membershipService.EXPECT(). + ListProjectsByPrincipal(ctx, userPrincipal, "", false). + Return(nil, errors.New("membership boom")) + svc := project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + svc.SetMembershipService(membershipService) + return svc }, + wantErr: true, }, { - name: "PAT principal with no project overlap returns empty", - args: args{ - principal: authenticate.Principal{ - ID: "pat-456", - Type: schema.PATPrincipal, - PAT: &pat.PAT{ID: "pat-456", UserID: "user-id", OrgID: "org-1"}, - }, - flt: project.Filter{}, + name: "composes Filter.Principal with WithMemberCount enrichment", + filter: project.Filter{Principal: &userPrincipal, OrgID: "org-1", WithMemberCount: true}, + want: []project.Project{ + {ID: "p1", Organization: organization.Organization{ID: "org-1"}, MemberCount: 5}, + {ID: "p2", Organization: organization.Organization{ID: "org-1"}, MemberCount: 2}, }, - want: []project.Project{}, - wantErr: false, - setup: func() *project.Service { + setup: func(t *testing.T) *project.Service { + t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) - _ = roleService - // User has projects - relationService.EXPECT().LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.ProjectNamespace, - }, - Subject: relation.Subject{ - Namespace: schema.UserPrincipal, - ID: "user-id", - }, - RelationName: project.MemberPermission, - }).Return([]string{"project-id-1"}, nil) - - // PAT scoped to different projects - relationService.EXPECT().LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.ProjectNamespace, - }, - Subject: relation.Subject{ - ID: "pat-456", - Namespace: schema.PATPrincipal, - }, - RelationName: schema.GetPermission, - }).Return([]string{"project-id-2"}, nil) - - return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + membershipService := mocks.NewMembershipService(t) + membershipService.EXPECT(). + ListProjectsByPrincipal(ctx, userPrincipal, "org-1", false). + Return([]string{"p1", "p2"}, nil) + repo.EXPECT(). + List(ctx, project.Filter{Principal: &userPrincipal, OrgID: "org-1", WithMemberCount: true, ProjectIDs: []string{"p1", "p2"}}). + Return([]project.Project{ + {ID: "p1", Organization: organization.Organization{ID: "org-1"}}, + {ID: "p2", Organization: organization.Organization{ID: "org-1"}}, + }, nil) + policyService.EXPECT(). + ProjectMemberCount(ctx, []string{"p1", "p2"}). + Return([]policy.MemberCount{ + {ID: "p1", Count: 5}, + {ID: "p2", Count: 2}, + }, nil) + svc := project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + svc.SetMembershipService(membershipService) + return svc }, }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := tt.setup() - got, err := s.ListByUser(ctx, tt.args.principal, tt.args.flt) + s := tt.setup(t) + got, err := s.List(ctx, tt.filter) if (err != nil) != tt.wantErr { - t.Errorf("ListByUser() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("List() error = %v, wantErr %v", err, tt.wantErr) return } if diff := cmp.Diff(got, tt.want); diff != "" { - t.Errorf("ListByUser() mismatch (-want +got):\n%s", diff) + t.Errorf("List() mismatch (-want +got):\n%s", diff) } }) } diff --git a/core/userpat/mocks/project_service.go b/core/userpat/mocks/project_service.go index b9643e796..3dd430db6 100644 --- a/core/userpat/mocks/project_service.go +++ b/core/userpat/mocks/project_service.go @@ -5,8 +5,6 @@ package mocks import ( context "context" - authenticate "github.com/raystack/frontier/core/authenticate" - mock "github.com/stretchr/testify/mock" project "github.com/raystack/frontier/core/project" @@ -25,29 +23,29 @@ func (_m *ProjectService) EXPECT() *ProjectService_Expecter { return &ProjectService_Expecter{mock: &_m.Mock} } -// ListByUser provides a mock function with given fields: ctx, principal, flt -func (_m *ProjectService) ListByUser(ctx context.Context, principal authenticate.Principal, flt project.Filter) ([]project.Project, error) { - ret := _m.Called(ctx, principal, flt) +// List provides a mock function with given fields: ctx, flt +func (_m *ProjectService) List(ctx context.Context, flt project.Filter) ([]project.Project, error) { + ret := _m.Called(ctx, flt) if len(ret) == 0 { - panic("no return value specified for ListByUser") + panic("no return value specified for List") } var r0 []project.Project var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, project.Filter) ([]project.Project, error)); ok { - return rf(ctx, principal, flt) + if rf, ok := ret.Get(0).(func(context.Context, project.Filter) ([]project.Project, error)); ok { + return rf(ctx, flt) } - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, project.Filter) []project.Project); ok { - r0 = rf(ctx, principal, flt) + if rf, ok := ret.Get(0).(func(context.Context, project.Filter) []project.Project); ok { + r0 = rf(ctx, flt) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]project.Project) } } - if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, project.Filter) error); ok { - r1 = rf(ctx, principal, flt) + if rf, ok := ret.Get(1).(func(context.Context, project.Filter) error); ok { + r1 = rf(ctx, flt) } else { r1 = ret.Error(1) } @@ -55,32 +53,31 @@ func (_m *ProjectService) ListByUser(ctx context.Context, principal authenticate return r0, r1 } -// ProjectService_ListByUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListByUser' -type ProjectService_ListByUser_Call struct { +// ProjectService_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type ProjectService_List_Call struct { *mock.Call } -// ListByUser is a helper method to define mock.On call +// List is a helper method to define mock.On call // - ctx context.Context -// - principal authenticate.Principal // - flt project.Filter -func (_e *ProjectService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *ProjectService_ListByUser_Call { - return &ProjectService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} +func (_e *ProjectService_Expecter) List(ctx interface{}, flt interface{}) *ProjectService_List_Call { + return &ProjectService_List_Call{Call: _e.mock.On("List", ctx, flt)} } -func (_c *ProjectService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt project.Filter)) *ProjectService_ListByUser_Call { +func (_c *ProjectService_List_Call) Run(run func(ctx context.Context, flt project.Filter)) *ProjectService_List_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(project.Filter)) + run(args[0].(context.Context), args[1].(project.Filter)) }) return _c } -func (_c *ProjectService_ListByUser_Call) Return(_a0 []project.Project, _a1 error) *ProjectService_ListByUser_Call { +func (_c *ProjectService_List_Call) Return(_a0 []project.Project, _a1 error) *ProjectService_List_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *ProjectService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, project.Filter) ([]project.Project, error)) *ProjectService_ListByUser_Call { +func (_c *ProjectService_List_Call) RunAndReturn(run func(context.Context, project.Filter) ([]project.Project, error)) *ProjectService_List_Call { _c.Call.Return(run) return _c } @@ -97,4 +94,4 @@ func NewProjectService(t interface { t.Cleanup(func() { mock.AssertExpectations(t) }) return mock -} +} \ No newline at end of file diff --git a/core/userpat/service.go b/core/userpat/service.go index a1ef57fb6..98472e6b1 100644 --- a/core/userpat/service.go +++ b/core/userpat/service.go @@ -50,7 +50,7 @@ type PolicyService interface { } type ProjectService interface { - ListByUser(ctx context.Context, principal authenticate.Principal, flt project.Filter) ([]project.Project, error) + List(ctx context.Context, flt project.Filter) ([]project.Project, error) } type AuditRecordRepository interface { @@ -503,7 +503,7 @@ func (s *Service) validateProjectAccess(ctx context.Context, userID, orgID strin ID: userID, Type: schema.UserPrincipal, } - userProjects, err := s.projectService.ListByUser(ctx, principal, project.Filter{OrgID: orgID}) + userProjects, err := s.projectService.List(ctx, project.Filter{Principal: &principal, OrgID: orgID}) if err != nil { return fmt.Errorf("listing user projects: %w", err) } diff --git a/core/userpat/service_test.go b/core/userpat/service_test.go index 843a8de6f..498f37859 100644 --- a/core/userpat/service_test.go +++ b/core/userpat/service_test.go @@ -646,7 +646,9 @@ func TestService_CreatePolicies_ProjectScopedSpecificProjects(t *testing.T) { policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) - projSvc.On("ListByUser", mock.Anything, mock.Anything, mock.Anything).Return([]project.Project{ + projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { + return f.OrgID == "org-1" && f.Principal != nil && f.Principal.ID == "user-1" && f.Principal.Type == schema.UserPrincipal + })).Return([]project.Project{ {ID: "proj-a"}, {ID: "proj-b"}, }, nil).Maybe() @@ -1170,7 +1172,9 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) - projSvc.On("ListByUser", mock.Anything, mock.Anything, mock.Anything).Return([]project.Project{ + projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { + return f.OrgID == "org-1" && f.Principal != nil && f.Principal.ID == "user-1" && f.Principal.Type == schema.UserPrincipal + })).Return([]project.Project{ {ID: "proj-1"}, {ID: "proj-2"}, {ID: "proj-3"}, {ID: "proj-a"}, {ID: "proj-b"}, }, nil).Maybe() @@ -2559,7 +2563,9 @@ func TestService_ValidateProjectAccess(t *testing.T) { {ID: "role-1", Name: "proj_viewer", Scopes: []string{schema.ProjectNamespace}, Permissions: []string{"app_project_get"}}, }, nil) projSvc := mocks.NewProjectService(t) - projSvc.On("ListByUser", mock.Anything, mock.Anything, mock.Anything).Return([]project.Project{ + projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { + return f.OrgID == "org-1" && f.Principal != nil && f.Principal.ID == "user-1" && f.Principal.Type == schema.UserPrincipal + })).Return([]project.Project{ {ID: "proj-in-org"}, }, nil) @@ -2597,7 +2603,9 @@ func TestService_ValidateProjectAccess(t *testing.T) { policySvc.On("Create", mock.Anything, mock.Anything).Return(policy.Policy{}, nil).Maybe() policySvc.On("List", mock.Anything, mock.Anything).Return([]policy.Policy{}, nil).Maybe() projSvc := mocks.NewProjectService(t) - projSvc.On("ListByUser", mock.Anything, mock.Anything, mock.Anything).Return([]project.Project{ + projSvc.On("List", mock.Anything, mock.MatchedBy(func(f project.Filter) bool { + return f.OrgID == "org-1" && f.Principal != nil && f.Principal.ID == "user-1" && f.Principal.Type == schema.UserPrincipal + })).Return([]project.Project{ {ID: "proj-in-org"}, }, nil) auditRepo := mocks.NewAuditRecordRepository(t) diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index ac952fc52..33b35e0e0 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -342,7 +342,6 @@ type ProjectService interface { Get(ctx context.Context, idOrName string) (project.Project, error) Create(ctx context.Context, prj project.Project) (project.Project, error) List(ctx context.Context, f project.Filter) ([]project.Project, error) - ListByUser(ctx context.Context, principal authenticate.Principal, flt project.Filter) ([]project.Project, error) Update(ctx context.Context, toUpdate project.Project) (project.Project, error) Enable(ctx context.Context, id string) error Disable(ctx context.Context, id string) error diff --git a/internal/api/v1beta1connect/mocks/project_service.go b/internal/api/v1beta1connect/mocks/project_service.go index cce5ed11d..907cc0b13 100644 --- a/internal/api/v1beta1connect/mocks/project_service.go +++ b/internal/api/v1beta1connect/mocks/project_service.go @@ -5,11 +5,8 @@ package mocks import ( context "context" - authenticate "github.com/raystack/frontier/core/authenticate" - - mock "github.com/stretchr/testify/mock" - project "github.com/raystack/frontier/core/project" + mock "github.com/stretchr/testify/mock" ) // ProjectService is an autogenerated mock type for the ProjectService type @@ -292,66 +289,6 @@ func (_c *ProjectService_List_Call) RunAndReturn(run func(context.Context, proje return _c } -// ListByUser provides a mock function with given fields: ctx, principal, flt -func (_m *ProjectService) ListByUser(ctx context.Context, principal authenticate.Principal, flt project.Filter) ([]project.Project, error) { - ret := _m.Called(ctx, principal, flt) - - if len(ret) == 0 { - panic("no return value specified for ListByUser") - } - - var r0 []project.Project - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, project.Filter) ([]project.Project, error)); ok { - return rf(ctx, principal, flt) - } - if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, project.Filter) []project.Project); ok { - r0 = rf(ctx, principal, flt) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]project.Project) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, project.Filter) error); ok { - r1 = rf(ctx, principal, flt) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ProjectService_ListByUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListByUser' -type ProjectService_ListByUser_Call struct { - *mock.Call -} - -// ListByUser is a helper method to define mock.On call -// - ctx context.Context -// - principal authenticate.Principal -// - flt project.Filter -func (_e *ProjectService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *ProjectService_ListByUser_Call { - return &ProjectService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} -} - -func (_c *ProjectService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt project.Filter)) *ProjectService_ListByUser_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(project.Filter)) - }) - return _c -} - -func (_c *ProjectService_ListByUser_Call) Return(_a0 []project.Project, _a1 error) *ProjectService_ListByUser_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *ProjectService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, project.Filter) ([]project.Project, error)) *ProjectService_ListByUser_Call { - _c.Call.Return(run) - return _c -} - // Update provides a mock function with given fields: ctx, toUpdate func (_m *ProjectService) Update(ctx context.Context, toUpdate project.Project) (project.Project, error) { ret := _m.Called(ctx, toUpdate) diff --git a/internal/api/v1beta1connect/serviceuser.go b/internal/api/v1beta1connect/serviceuser.go index a936d49f3..76daf229d 100644 --- a/internal/api/v1beta1connect/serviceuser.go +++ b/internal/api/v1beta1connect/serviceuser.go @@ -463,11 +463,13 @@ func (h *ConnectHandler) ListServiceUserProjects(ctx context.Context, request *c errorLogger := NewErrorLogger() serviceUserID := request.Msg.GetId() orgID := request.Msg.GetOrgId() + if serviceUserID == "" { + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + } - projList, err := h.projectService.ListByUser(ctx, authenticate.Principal{ - ID: serviceUserID, Type: schema.ServiceUserPrincipal, - }, project.Filter{ - OrgID: orgID, + projList, err := h.projectService.List(ctx, project.Filter{ + Principal: &authenticate.Principal{ID: serviceUserID, Type: schema.ServiceUserPrincipal}, + OrgID: orgID, }) if err != nil { errorLogger.LogServiceError(ctx, request, "ListServiceUserProjects", err, diff --git a/internal/api/v1beta1connect/serviceuser_test.go b/internal/api/v1beta1connect/serviceuser_test.go index eb0282f04..1f15a2059 100644 --- a/internal/api/v1beta1connect/serviceuser_test.go +++ b/internal/api/v1beta1connect/serviceuser_test.go @@ -1338,7 +1338,7 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { Id: "1", }), setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { - projSvc.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}, project.Filter{}).Return(nil, errors.New("test error")) + projSvc.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}}).Return(nil, errors.New("test error")) }, want: nil, wantErr: ErrInternalServerError, @@ -1354,7 +1354,7 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { for _, projectID := range testProjectIDList { projects = append(projects, testProjectMap[projectID]) } - projSvc.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}, project.Filter{}).Return(projects, nil) + projSvc.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}}).Return(projects, nil) }, want: connect.NewResponse(&frontierv1beta1.ListServiceUserProjectsResponse{ Projects: []*frontierv1beta1.Project{{ @@ -1386,6 +1386,34 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { wantErr: nil, errCode: connect.Code(0), }, + { + name: "should return invalid argument when id is empty", + request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ + Id: "", + }), + setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { + // projectService.List must NOT be called. + }, + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, + { + name: "should forward org_id to project.Filter when set", + request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ + Id: "1", + OrgId: "9f256f86-31a3-11ec-8d3d-0242ac130003", + }), + setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { + projSvc.EXPECT().List(mock.Anything, project.Filter{ + Principal: &authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}, + OrgID: "9f256f86-31a3-11ec-8d3d-0242ac130003", + }).Return([]project.Project{}, nil) + }, + want: connect.NewResponse(&frontierv1beta1.ListServiceUserProjectsResponse{}), + wantErr: nil, + errCode: connect.Code(0), + }, { name: "should return project list with access pairs if withPermission is passed", request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ @@ -1399,7 +1427,7 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { } ctx := mock.Anything - projSvc.EXPECT().ListByUser(ctx, authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}, project.Filter{}).Return(projects, nil) + projSvc.EXPECT().List(ctx, project.Filter{Principal: &authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}}).Return(projects, nil) permSvc.EXPECT().Get(ctx, "app/project:get").Return( permission.Permission{ diff --git a/internal/api/v1beta1connect/user.go b/internal/api/v1beta1connect/user.go index 8c859e135..45ff7e43b 100644 --- a/internal/api/v1beta1connect/user.go +++ b/internal/api/v1beta1connect/user.go @@ -864,12 +864,15 @@ func (h *ConnectHandler) ListOrganizationsByCurrentUser(ctx context.Context, req func (h *ConnectHandler) ListProjectsByUser(ctx context.Context, request *connect.Request[frontierv1beta1.ListProjectsByUserRequest]) (*connect.Response[frontierv1beta1.ListProjectsByUserResponse], error) { errorLogger := NewErrorLogger() userID := request.Msg.GetId() + if userID == "" { + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + } - projList, err := h.projectService.ListByUser(ctx, authenticate.Principal{ - ID: userID, Type: schema.UserPrincipal, - }, project.Filter{}) + projList, err := h.projectService.List(ctx, project.Filter{ + Principal: &authenticate.Principal{ID: userID, Type: schema.UserPrincipal}, + }) if err != nil { - errorLogger.LogServiceError(ctx, request, "ListProjectsByUser.ListByUser", err, + errorLogger.LogServiceError(ctx, request, "ListProjectsByUser.List", err, "user_id", userID) switch { @@ -908,7 +911,8 @@ func (h *ConnectHandler) ListProjectsByCurrentUser(ctx context.Context, request } paginate := pagination.NewPagination(request.Msg.GetPageNum(), request.Msg.GetPageSize()) - projList, err := h.projectService.ListByUser(ctx, principal, project.Filter{ + projList, err := h.projectService.List(ctx, project.Filter{ + Principal: &principal, OrgID: request.Msg.GetOrgId(), NonInherited: request.Msg.GetNonInherited(), WithMemberCount: request.Msg.GetWithMemberCount(), diff --git a/internal/api/v1beta1connect/user_test.go b/internal/api/v1beta1connect/user_test.go index 8dffaf9b6..73e475170 100644 --- a/internal/api/v1beta1connect/user_test.go +++ b/internal/api/v1beta1connect/user_test.go @@ -1569,7 +1569,7 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should list user projects successfully", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}, project.Filter{}).Return([]project.Project{ + ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}}).Return([]project.Project{ { ID: "project-1", Name: "test-project-1", @@ -1616,7 +1616,7 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should return empty list when user has no projects", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}, project.Filter{}).Return([]project.Project{}, nil) + ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}}).Return([]project.Project{}, nil) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "user-1"}, want: &frontierv1beta1.ListProjectsByUserResponse{ @@ -1627,7 +1627,7 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should return not found error when user does not exist", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "non-existent-user", Type: schema.UserPrincipal}, project.Filter{}).Return(nil, user.ErrNotExist) + ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "non-existent-user", Type: schema.UserPrincipal}}).Return(nil, user.ErrNotExist) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "non-existent-user"}, want: nil, @@ -1636,7 +1636,7 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should return bad request error for invalid user ID", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "invalid-id", Type: schema.UserPrincipal}, project.Filter{}).Return(nil, user.ErrInvalidUUID) + ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "invalid-id", Type: schema.UserPrincipal}}).Return(nil, user.ErrInvalidUUID) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "invalid-id"}, want: nil, @@ -1645,12 +1645,21 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should return internal error for project service failure", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}, project.Filter{}).Return(nil, errors.New("database error")) + ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}}).Return(nil, errors.New("database error")) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "user-1"}, want: nil, err: connect.CodeInternal, }, + { + title: "should return invalid argument when id is empty", + setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { + // projectService.List must NOT be called. + }, + req: &frontierv1beta1.ListProjectsByUserRequest{Id: ""}, + want: nil, + err: connect.CodeInvalidArgument, + }, } for _, tt := range tests { @@ -1712,7 +1721,10 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - ps.EXPECT().ListByUser(mock.Anything, mockPrincipal, mock.MatchedBy(func(filter project.Filter) bool { + ps.EXPECT().List(mock.Anything, mock.MatchedBy(func(filter project.Filter) bool { + if filter.Principal == nil || *filter.Principal != mockPrincipal { + return false + } return filter.OrgID == "" })).Return([]project.Project{ { @@ -1769,7 +1781,10 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - ps.EXPECT().ListByUser(mock.Anything, mockPrincipal, mock.MatchedBy(func(filter project.Filter) bool { + ps.EXPECT().List(mock.Anything, mock.MatchedBy(func(filter project.Filter) bool { + if filter.Principal == nil || *filter.Principal != mockPrincipal { + return false + } return filter.OrgID == "org-1" })).Return([]project.Project{ { @@ -1809,7 +1824,10 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - ps.EXPECT().ListByUser(mock.Anything, mockPrincipal, mock.MatchedBy(func(filter project.Filter) bool { + ps.EXPECT().List(mock.Anything, mock.MatchedBy(func(filter project.Filter) bool { + if filter.Principal == nil || *filter.Principal != mockPrincipal { + return false + } return filter.OrgID == "" })).Return([]project.Project{}, nil) }, @@ -1839,7 +1857,10 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - ps.EXPECT().ListByUser(mock.Anything, mockPrincipal, mock.MatchedBy(func(filter project.Filter) bool { + ps.EXPECT().List(mock.Anything, mock.MatchedBy(func(filter project.Filter) bool { + if filter.Principal == nil || *filter.Principal != mockPrincipal { + return false + } return filter.OrgID == "" })).Return(nil, errors.New("database error")) }, diff --git a/test/e2e/regression/api_test.go b/test/e2e/regression/api_test.go index 096cb2534..eca5df037 100644 --- a/test/e2e/regression/api_test.go +++ b/test/e2e/regression/api_test.go @@ -747,6 +747,77 @@ func (s *APIRegressionTestSuite) TestProjectAPI() { }) } +func (s *APIRegressionTestSuite) TestProjectAPI_StaleRelationRegression() { + ctxOrgAdminAuth := testbench.ContextWithAuth(context.Background(), s.adminCookie) + + s.Run("org owner demoted to viewer no longer sees inherited projects", func() { + createOrgResp, err := s.testBench.Client.CreateOrganization(ctxOrgAdminAuth, connect.NewRequest(&frontierv1beta1.CreateOrganizationRequest{ + Body: &frontierv1beta1.OrganizationRequestBody{ + Title: "stale relation regression org", + Name: "stale-rel-org", + }, + })) + s.Require().NoError(err) + orgID := createOrgResp.Msg.GetOrganization().GetId() + + createProjResp, err := s.testBench.Client.CreateProject(ctxOrgAdminAuth, connect.NewRequest(&frontierv1beta1.CreateProjectRequest{ + Body: &frontierv1beta1.ProjectRequestBody{ + Name: "stale-rel-proj", + Title: "stale relation regression project", + OrgId: orgID, + }, + })) + s.Require().NoError(err) + projectID := createProjResp.Msg.GetProject().GetId() + + createUserResp, err := s.testBench.Client.CreateUser(ctxOrgAdminAuth, connect.NewRequest(&frontierv1beta1.CreateUserRequest{ + Body: &frontierv1beta1.UserRequestBody{ + Title: "stale rel user", + Email: "stale-rel-user@raystack.org", + Name: "stale_rel_user", + }, + })) + s.Require().NoError(err) + userID := createUserResp.Msg.GetUser().GetId() + + addMembersResp, err := s.testBench.AdminClient.AddOrganizationMembers(ctxOrgAdminAuth, connect.NewRequest(&frontierv1beta1.AddOrganizationMembersRequest{ + OrgId: orgID, + Members: []*frontierv1beta1.OrgMemberEntry{{ + UserId: userID, + RoleId: s.orgOwnerRole, + }}, + })) + requireAddOrgMembersSuccess(s.T(), addMembersResp, err) + + userCookie, err := testbench.AuthenticateUser(context.Background(), s.testBench.Client, createUserResp.Msg.GetUser().GetEmail()) + s.Require().NoError(err) + ctxUserAuth := testbench.ContextWithAuth(context.Background(), userCookie) + + listAsOwnerResp, err := s.testBench.Client.ListProjectsByCurrentUser(ctxUserAuth, connect.NewRequest(&frontierv1beta1.ListProjectsByCurrentUserRequest{ + OrgId: orgID, + })) + s.Require().NoError(err) + s.Assert().True(slices.ContainsFunc(listAsOwnerResp.Msg.GetProjects(), func(p *frontierv1beta1.Project) bool { + return p.GetId() == projectID + }), "org owner should see inherited project before demotion") + + _, err = s.testBench.Client.SetOrganizationMemberRole(ctxOrgAdminAuth, connect.NewRequest(&frontierv1beta1.SetOrganizationMemberRoleRequest{ + OrgId: orgID, + UserId: userID, + RoleId: s.orgViewerRole, + })) + s.Require().NoError(err) + + listAsViewerResp, err := s.testBench.Client.ListProjectsByCurrentUser(ctxUserAuth, connect.NewRequest(&frontierv1beta1.ListProjectsByCurrentUserRequest{ + OrgId: orgID, + })) + s.Require().NoError(err) + s.Assert().False(slices.ContainsFunc(listAsViewerResp.Msg.GetProjects(), func(p *frontierv1beta1.Project) bool { + return p.GetId() == projectID + }), "demoted viewer must not inherit project visibility") + }) +} + func (s *APIRegressionTestSuite) TestGroupAPI() { var newGroup *frontierv1beta1.Group ctxOrgAdminAuth := testbench.ContextWithAuth(context.Background(), s.adminCookie)