Skip to content

Commit 1c35361

Browse files
author
Ryan Clarke
committed
[feature] Add policy filters to file adapter
Extend the default file adapter to support the new policy filtering feature.
1 parent c8d1293 commit 1c35361

File tree

8 files changed

+112
-25
lines changed

8 files changed

+112
-25
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ import (
212212
adapter := mongodbadapter.NewFilteredAdapter("127.0.0.1:27017")
213213
enforcer := casbin.NewEnforcer("examples/rbac_with_domains_model.conf", adapter)
214214
215+
// Values of type `map[string]interface{}` may be used here instead of BSON.
215216
filter := &bson.M{
216217
"$or": []bson.M{
217218
bson.M{"ptype": "p", "v1": "mydomain"},

enforcer.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ import (
3131

3232
// Enforcer is the main interface for authorization enforcement and policy management.
3333
type Enforcer struct {
34-
modelPath string
35-
model model.Model
36-
fm model.FunctionMap
37-
eft effect.Effector
34+
modelPath string
35+
model model.Model
36+
fm model.FunctionMap
37+
eft effect.Effector
3838

39-
adapter persist.Adapter
40-
watcher persist.Watcher
41-
rm rbac.RoleManager
39+
adapter persist.Adapter
40+
watcher persist.Watcher
41+
rm rbac.RoleManager
4242

4343
enabled bool
4444
autoSave bool
@@ -192,7 +192,7 @@ func (e *Enforcer) SetAdapter(adapter persist.Adapter) {
192192
// SetWatcher sets the current watcher.
193193
func (e *Enforcer) SetWatcher(watcher persist.Watcher) {
194194
e.watcher = watcher
195-
watcher.SetUpdateCallback(func (string) {e.LoadPolicy()})
195+
watcher.SetUpdateCallback(func(string) { e.LoadPolicy() })
196196
}
197197

198198
// SetRoleManager sets the current role manager.

enforcer_safe.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,4 @@ func (e *Enforcer) RemoveFilteredPolicySafe(fieldIndex int, fieldValues ...strin
9797
result = e.RemoveFilteredNamedPolicy("p", fieldIndex, fieldValues...)
9898
err = nil
9999
return
100-
}
100+
}

enforcer_synced.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import (
2424

2525
type SyncedEnforcer struct {
2626
*Enforcer
27-
m sync.RWMutex
27+
m sync.RWMutex
2828
autoLoad bool
2929
}
3030

@@ -50,10 +50,10 @@ func (e *SyncedEnforcer) StartAutoLoadPolicy(d time.Duration) {
5050
e.LoadPolicy()
5151
// Uncomment this line to see when the policy is loaded.
5252
// log.Print("Load policy for time: ", n)
53-
n ++
53+
n++
5454
time.Sleep(d)
5555
}
56-
} ()
56+
}()
5757
}
5858

5959
func (e *SyncedEnforcer) StopAutoLoadPolicy() {
@@ -63,7 +63,7 @@ func (e *SyncedEnforcer) StopAutoLoadPolicy() {
6363
// SetWatcher sets the current watcher.
6464
func (e *SyncedEnforcer) SetWatcher(watcher persist.Watcher) {
6565
e.watcher = watcher
66-
watcher.SetUpdateCallback(func (string) {e.LoadPolicy()})
66+
watcher.SetUpdateCallback(func(string) { e.LoadPolicy() })
6767
}
6868

6969
// ClearPolicy clears all policy.

enforcer_test.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,4 +400,19 @@ func TestInitEmpty(t *testing.T) {
400400
e.LoadPolicy()
401401

402402
testEnforce(t, e, "alice", "/alice_data/resource1", "GET", true)
403-
}
403+
}
404+
405+
func TestLoadFilteredPolicy(t *testing.T) {
406+
e := NewEnforcer("examples/rbac_with_domains_model.conf", "examples/rbac_with_domains_policy.csv")
407+
408+
testHasPolicy(t, e, []string{"admin", "domain1", "data1", "read"}, true)
409+
testHasPolicy(t, e, []string{"admin", "domain2", "data2", "read"}, true)
410+
411+
e.LoadFilteredPolicy(&fileadapter.Filter{
412+
P: []string{"", "domain1"},
413+
G: []string{"", "", "domain1"},
414+
})
415+
416+
testHasPolicy(t, e, []string{"admin", "domain1", "data1", "read"}, true)
417+
testHasPolicy(t, e, []string{"admin", "domain2", "data2", "read"}, false)
418+
}

model_b_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ func BenchmarkRBACModelSmall(b *testing.B) {
5959
e.EnableAutoBuildRoleLinks(false)
6060
// 100 roles, 10 resources.
6161
for i := 0; i < 100; i++ {
62-
e.AddPolicy(fmt.Sprintf("group%d", i), fmt.Sprintf("data%d", i / 10), "read")
62+
e.AddPolicy(fmt.Sprintf("group%d", i), fmt.Sprintf("data%d", i/10), "read")
6363
}
6464
// 1000 users.
6565
for i := 0; i < 1000; i++ {
66-
e.AddGroupingPolicy(fmt.Sprintf("user%d", i), fmt.Sprintf("group%d", i / 10))
66+
e.AddGroupingPolicy(fmt.Sprintf("user%d", i), fmt.Sprintf("group%d", i/10))
6767
}
6868
e.BuildRoleLinks()
6969

@@ -79,11 +79,11 @@ func BenchmarkRBACModelMedium(b *testing.B) {
7979
e.EnableAutoBuildRoleLinks(false)
8080
// 1000 roles, 100 resources.
8181
for i := 0; i < 1000; i++ {
82-
e.AddPolicy(fmt.Sprintf("group%d", i), fmt.Sprintf("data%d", i / 10), "read")
82+
e.AddPolicy(fmt.Sprintf("group%d", i), fmt.Sprintf("data%d", i/10), "read")
8383
}
8484
// 10000 users.
8585
for i := 0; i < 10000; i++ {
86-
e.AddGroupingPolicy(fmt.Sprintf("user%d", i), fmt.Sprintf("group%d", i / 10))
86+
e.AddGroupingPolicy(fmt.Sprintf("user%d", i), fmt.Sprintf("group%d", i/10))
8787
}
8888
e.BuildRoleLinks()
8989

@@ -99,11 +99,11 @@ func BenchmarkRBACModelLarge(b *testing.B) {
9999
e.EnableAutoBuildRoleLinks(false)
100100
// 10000 roles, 1000 resources.
101101
for i := 0; i < 10000; i++ {
102-
e.AddPolicy(fmt.Sprintf("group%d", i), fmt.Sprintf("data%d", i / 10), "read")
102+
e.AddPolicy(fmt.Sprintf("group%d", i), fmt.Sprintf("data%d", i/10), "read")
103103
}
104104
// 100000 users.
105105
for i := 0; i < 100000; i++ {
106-
e.AddGroupingPolicy(fmt.Sprintf("user%d", i), fmt.Sprintf("group%d", i / 10))
106+
e.AddGroupingPolicy(fmt.Sprintf("user%d", i), fmt.Sprintf("group%d", i/10))
107107
}
108108
e.BuildRoleLinks()
109109

model_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,12 @@ func NewRoleManager() rbac.RoleManager {
286286
return &testCustomRoleManager{}
287287
}
288288
func (rm *testCustomRoleManager) Clear() error { return nil }
289-
func (rm *testCustomRoleManager) AddLink(name1 string, name2 string, domain ...string) error { return nil }
290-
func (rm *testCustomRoleManager) DeleteLink(name1 string, name2 string, domain ...string) error { return nil }
289+
func (rm *testCustomRoleManager) AddLink(name1 string, name2 string, domain ...string) error {
290+
return nil
291+
}
292+
func (rm *testCustomRoleManager) DeleteLink(name1 string, name2 string, domain ...string) error {
293+
return nil
294+
}
291295
func (rm *testCustomRoleManager) HasLink(name1 string, name2 string, domain ...string) (bool, error) {
292296
if name1 == "alice" && name2 == "alice" {
293297
return true, nil

persist/file-adapter/adapter.go

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ import (
3131
// It can load policy from file or save policy to file.
3232
type Adapter struct {
3333
filePath string
34+
filtered bool
35+
}
36+
37+
// Filter defines filter options for Adater.
38+
type Filter struct {
39+
// P is the filter option for "p" policies.
40+
P []string
41+
// G is the filter option for "g" policies.
42+
G []string
3443
}
3544

3645
// NewAdapter is the constructor for Adapter.
@@ -40,18 +49,46 @@ func NewAdapter(filePath string) *Adapter {
4049
return &a
4150
}
4251

52+
// NewFilteredAdapter is the constructor for FilteredAdapter.
53+
func NewFilteredAdapter(filePath string) *Adapter {
54+
return NewAdapter(filePath)
55+
}
56+
4357
// LoadPolicy loads all policy rules from the storage.
4458
func (a *Adapter) LoadPolicy(model model.Model) error {
59+
return a.LoadFilteredPolicy(model, nil)
60+
}
61+
62+
// LoadFilteredPolicy loads matching policy rules from the storage.
63+
func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
64+
var filterValue *Filter
65+
if filter == nil {
66+
a.filtered = false
67+
} else {
68+
var ok bool
69+
filterValue, ok = filter.(*Filter)
70+
if !ok {
71+
return errors.New("invalid filter type, expected *Filter")
72+
}
73+
a.filtered = true
74+
}
4575
if a.filePath == "" {
4676
return errors.New("invalid file path, file path cannot be empty")
4777
}
4878

49-
err := a.loadPolicyFile(model, persist.LoadPolicyLine)
50-
return err
79+
return a.loadPolicyFile(model, filterValue, persist.LoadPolicyLine)
80+
}
81+
82+
// IsFiltered returns true if the loaded policy is filtered.
83+
func (a *Adapter) IsFiltered() bool {
84+
return a.filtered
5185
}
5286

5387
// SavePolicy saves all policy rules to the storage.
5488
func (a *Adapter) SavePolicy(model model.Model) error {
89+
if a.filtered == true {
90+
return errors.New("cannot save a filtered policy")
91+
}
5592
if a.filePath == "" {
5693
return errors.New("invalid file path, file path cannot be empty")
5794
}
@@ -78,7 +115,7 @@ func (a *Adapter) SavePolicy(model model.Model) error {
78115
return err
79116
}
80117

81-
func (a *Adapter) loadPolicyFile(model model.Model, handler func(string, model.Model)) error {
118+
func (a *Adapter) loadPolicyFile(model model.Model, filter *Filter, handler func(string, model.Model)) error {
82119
f, err := os.Open(a.filePath)
83120
if err != nil {
84121
return err
@@ -89,6 +126,14 @@ func (a *Adapter) loadPolicyFile(model model.Model, handler func(string, model.M
89126
for {
90127
line, err := buf.ReadString('\n')
91128
line = strings.TrimSpace(line)
129+
130+
// If a filter is defined, apply it to the policy
131+
if filter != nil {
132+
if filterLine(line, filter) {
133+
continue
134+
}
135+
}
136+
92137
handler(line, model)
93138
if err != nil {
94139
if err == io.EOF {
@@ -99,6 +144,28 @@ func (a *Adapter) loadPolicyFile(model model.Model, handler func(string, model.M
99144
}
100145
}
101146

147+
func filterLine(line string, filter *Filter) bool {
148+
p := strings.Split(line, ",")
149+
if len(p) == 0 {
150+
return true
151+
}
152+
var filterSet []string
153+
if strings.TrimSpace(p[0]) == "p" {
154+
filterSet = filter.P
155+
}
156+
if strings.TrimSpace(p[0]) == "g" {
157+
filterSet = filter.G
158+
}
159+
var skip bool
160+
for i, v := range filterSet {
161+
if len(p) < i+2 || len(v) > 0 && strings.TrimSpace(v) != strings.TrimSpace(p[i+1]) {
162+
skip = true
163+
break
164+
}
165+
}
166+
return skip
167+
}
168+
102169
func (a *Adapter) savePolicyFile(text string) error {
103170
f, err := os.Create(a.filePath)
104171
if err != nil {

0 commit comments

Comments
 (0)