Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
// Enforcer is the main interface for authorization enforcement and policy management.
type Enforcer struct {
modelPath string
model *model.Model
model model.Model
fm model.FunctionMap
eft effect.Effector

Expand Down Expand Up @@ -99,7 +99,7 @@ func NewEnforcer(params ...interface{}) (*Enforcer, error) {
case string:
return nil, errors.New("invalid parameters for enforcer")
default:
err := e.InitWithModelAndAdapter(p0.(*model.Model), params[1].(persist.Adapter))
err := e.InitWithModelAndAdapter(p0.(model.Model), params[1].(persist.Adapter))
if err != nil {
return nil, err
}
Expand All @@ -113,7 +113,7 @@ func NewEnforcer(params ...interface{}) (*Enforcer, error) {
return nil, err
}
default:
err := e.InitWithModelAndAdapter(p0.(*model.Model), nil)
err := e.InitWithModelAndAdapter(p0.(model.Model), nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -150,7 +150,7 @@ func (e *Enforcer) InitWithAdapter(modelPath string, adapter persist.Adapter) er
}

// InitWithModelAndAdapter initializes an enforcer with a model and a database adapter.
func (e *Enforcer) InitWithModelAndAdapter(m *model.Model, adapter persist.Adapter) error {
func (e *Enforcer) InitWithModelAndAdapter(m model.Model, adapter persist.Adapter) error {
e.adapter = adapter

e.model = m
Expand Down Expand Up @@ -200,12 +200,12 @@ func (e *Enforcer) LoadModel() error {
}

// GetModel gets the current model.
func (e *Enforcer) GetModel() *model.Model {
func (e *Enforcer) GetModel() model.Model {
return e.model
}

// SetModel sets the current model.
func (e *Enforcer) SetModel(m *model.Model) {
func (e *Enforcer) SetModel(m model.Model) {
e.model = m
e.fm = model.LoadFunctionMap()

Expand Down
6 changes: 3 additions & 3 deletions enforcer_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ type IEnforcer interface {
/* Enforcer API */
InitWithFile(modelPath string, policyPath string) error
InitWithAdapter(modelPath string, adapter persist.Adapter) error
InitWithModelAndAdapter(m *model.Model, adapter persist.Adapter) error
InitWithModelAndAdapter(m model.Model, adapter persist.Adapter) error
LoadModel() error
GetModel() *model.Model
SetModel(m *model.Model)
GetModel() model.Model
SetModel(m model.Model)
GetAdapter() persist.Adapter
SetAdapter(adapter persist.Adapter)
SetWatcher(watcher persist.Watcher) error
Expand Down
97 changes: 65 additions & 32 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,41 @@ import (
"github.com/casbin/casbin/v3/util"
)

// Model represents the whole access control model.
type Model struct {
// Model is an interface that manager the whole access control model.
type Model interface {
AddDef(sec string, key string, value string) bool
AddPolicy(sec string, ptype string, rule []string)
AddPolicies(sec string, ptype string, rules [][]string)
BuildRoleLinks(rm rbac.RoleManager) error
BuildIncrementalRoleLinks(rm rbac.RoleManager, op PolicyOp, sec string, ptype string, rules [][]string) error
ClearPolicy()
GenerateFunctions(fm FunctionMap) map[string]govaluate.ExpressionFunction
GetPtypes(sec string) []string
GetTokens(sec string, ptype string) map[string]int
GetMatcher() string
GetPolicy(sec string, ptype string) [][]string
GetFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) [][]string
GetValuesForFieldInPolicy(sec string, ptype string, fieldIndex int) []string
GetValuesForFieldInPolicyAllTypes(sec string, fieldIndex int) []string
GetEffectExpression() string
GetRoleManager(sec string, ptype string) rbac.RoleManager
HasPolicy(sec string, ptype string, rule []string) bool
HasPolicies(sec string, ptype string, rules [][]string) bool
PrintPolicy()
PrintModel()
RemovePolicy(sec string, ptype string, rule []string) bool
RemovePolicies(sec string, ptype string, rules [][]string) bool
RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) (bool, [][]string)
}

// DefaultModel provides a default implementation.
type DefaultModel struct {
data map[string]AssertionMap
mutex sync.RWMutex
}

var _ Model = (*DefaultModel)(nil)

// AssertionMap is the collection of assertions, can be "r", "p", "g", "e", "m".
type AssertionMap map[string]*Assertion

Expand All @@ -48,19 +77,19 @@ var sectionNameMap = map[string]string{
// Minimal required sections for a model to be valid
var requiredSections = []string{"r", "p", "e", "m"}

func loadAssertion(model *Model, cfg config.ConfigInterface, sec string, key string) bool {
func loadAssertion(model *DefaultModel, cfg config.ConfigInterface, sec string, key string) bool {
value := cfg.String(sectionNameMap[sec] + "::" + key)
return model.addDef(sec, key, value)
}

// AddDef adds an assertion to the model.
func (model *Model) AddDef(sec string, key string, value string) bool {
func (model *DefaultModel) AddDef(sec string, key string, value string) bool {
model.mutex.Lock()
defer model.mutex.Unlock()
return model.addDef(sec, key, value)
}

func (model *Model) addDef(sec string, key string, value string) bool {
func (model *DefaultModel) addDef(sec string, key string, value string) bool {
if value == "" {
return false
}
Expand Down Expand Up @@ -96,7 +125,7 @@ func getKeySuffix(i int) string {
return strconv.Itoa(i)
}

func loadSection(model *Model, cfg config.ConfigInterface, sec string) {
func loadSection(model *DefaultModel, cfg config.ConfigInterface, sec string) {
i := 1
for {
if !loadAssertion(model, cfg, sec, sec+getKeySuffix(i)) {
Expand All @@ -106,40 +135,44 @@ func loadSection(model *Model, cfg config.ConfigInterface, sec string) {
}
}
}

// NewModel creates an empty model.
func NewModel() *Model {
m := new(Model)
// newDefaultModel creates an empty model.
func newDefaultModel() *DefaultModel {
m := new(DefaultModel)
m.data = make(map[string]AssertionMap)
return m
}

// NewModelFromFile creates a model from a .CONF file.
func NewModelFromFile(path string) (*Model, error) {
m := NewModel()
// NewModel creates an empty model.
func NewModel() Model {
return newDefaultModel()
}

err := m.LoadModel(path)
// NewModelFromFile creates a model from a .CONF file.
func NewModelFromFile(path string) (Model, error) {
m := newDefaultModel()
cfg, err := config.NewConfig(path)
if err != nil {
return nil, err
}

return m, nil
err = m.loadModelFromConfig(cfg)
if err != nil {
return nil, err
}
return m, err
}

// NewModelFromString creates a model from a string which contains model text.
func NewModelFromString(text string) (*Model, error) {
m := NewModel()

func NewModelFromString(text string) (Model, error) {
m := newDefaultModel()
err := m.LoadModelFromText(text)
if err != nil {
return nil, err
}

return m, nil
}

// LoadModel loads the model from model CONF file.
func (model *Model) LoadModel(path string) error {
func (model *DefaultModel) LoadModel(path string) error {
cfg, err := config.NewConfig(path)
if err != nil {
return err
Expand All @@ -149,7 +182,7 @@ func (model *Model) LoadModel(path string) error {
}

// LoadModelFromText loads the model from the text.
func (model *Model) LoadModelFromText(text string) error {
func (model *DefaultModel) LoadModelFromText(text string) error {
cfg, err := config.NewConfigFromText(text)
if err != nil {
return err
Expand All @@ -158,7 +191,7 @@ func (model *Model) LoadModelFromText(text string) error {
return model.loadModelFromConfig(cfg)
}

func (model *Model) loadModelFromConfig(cfg config.ConfigInterface) error {
func (model *DefaultModel) loadModelFromConfig(cfg config.ConfigInterface) error {
model.mutex.Lock()
defer model.mutex.Unlock()
for s := range sectionNameMap {
Expand All @@ -176,16 +209,16 @@ func (model *Model) loadModelFromConfig(cfg config.ConfigInterface) error {
return nil
}

func (model *Model) hasSection(sec string) bool {
func (model *DefaultModel) hasSection(sec string) bool {
section := model.data[sec]
return section != nil
}

// PrintModel prints the model to the log.
func (model *Model) PrintModel() {
func (model *DefaultModel) PrintModel() {
model.mutex.RLock()
defer model.mutex.RUnlock()
log.LogPrint("Model:")
log.LogPrint("DefaultModel:")
for k, v := range model.data {
for i, j := range v {
log.LogPrintf("%s.%s: %s", k, i, j.Value)
Expand All @@ -194,28 +227,28 @@ func (model *Model) PrintModel() {
}

// GetMatcher gets the matcher.
func (model *Model) GetMatcher() string {
func (model *DefaultModel) GetMatcher() string {
model.mutex.RLock()
defer model.mutex.RUnlock()
return model.data["m"]["m"].Value
}

// GetEffectExpression gets the effect expression.
func (model *Model) GetEffectExpression() string {
func (model *DefaultModel) GetEffectExpression() string {
model.mutex.RLock()
defer model.mutex.RUnlock()
return model.data["e"]["e"].Value
}

// GetRoleManager gets the current role manager used in ptype.
func (model *Model) GetRoleManager(sec string, ptype string) rbac.RoleManager {
func (model *DefaultModel) GetRoleManager(sec string, ptype string) rbac.RoleManager {
model.mutex.RLock()
defer model.mutex.RUnlock()
return model.data[sec][ptype].RM
}

// GetTokens returns a map with all the tokens
func (model *Model) GetTokens(sec string, ptype string) map[string]int {
func (model *DefaultModel) GetTokens(sec string, ptype string) map[string]int {
model.mutex.RLock()
defer model.mutex.RUnlock()
tokens := make(map[string]int, len(model.data[sec][ptype].Tokens))
Expand All @@ -227,7 +260,7 @@ func (model *Model) GetTokens(sec string, ptype string) map[string]int {
}

// GetPtypes returns a slice for all ptype
func (model *Model) GetPtypes(sec string) []string {
func (model *DefaultModel) GetPtypes(sec string) []string {
model.mutex.RLock()
defer model.mutex.RUnlock()
var res []string
Expand All @@ -238,7 +271,7 @@ func (model *Model) GetPtypes(sec string) []string {
}

// GenerateFunctions return a map with all the functions
func (model *Model) GenerateFunctions(fm FunctionMap) map[string]govaluate.ExpressionFunction {
func (model *DefaultModel) GenerateFunctions(fm FunctionMap) map[string]govaluate.ExpressionFunction {
model.mutex.RLock()
defer model.mutex.RUnlock()
functions := fm.GetFunctions()
Expand Down
Loading