Skip to content

Commit 9230799

Browse files
stuff
1 parent d0dbc74 commit 9230799

File tree

1 file changed

+111
-71
lines changed

1 file changed

+111
-71
lines changed

enterprise/server/ip_rules_enforcer/ip_rules_enforcer.go

Lines changed: 111 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -62,72 +62,32 @@ func newIpRuleCache() (ipRuleCache, error) {
6262
})
6363
}
6464

65-
type Enforcer struct {
66-
env environment.Env
67-
68-
cache ipRuleCache
69-
}
70-
71-
type NoOpEnforcer struct{}
72-
73-
func (n *NoOpEnforcer) Authorize(ctx context.Context) error {
74-
return nil
75-
}
76-
77-
func (n *NoOpEnforcer) AuthorizeGroup(ctx context.Context, groupID string) error {
78-
return nil
79-
}
80-
81-
func (n *NoOpEnforcer) AuthorizeHTTPRequest(ctx context.Context, r *http.Request) error {
82-
return nil
65+
// An abstraction for retrieving IP rules from a source of truth.
66+
type ipRulesProvider interface {
67+
// TODO(iain): get rid of skipCache and skipRuleID.
68+
get(ctx context.Context, groupID string, skipCache bool, skipRuleID string) ([]*net.IPNet, error)
69+
startRefresher(env environment.Env) error
8370
}
8471

85-
func (n *NoOpEnforcer) Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error {
86-
return nil
72+
// An implementation of ipRulesProvider that retrieves IP rules from a database.
73+
type dbIPRulesProvider struct {
74+
db interfaces.DBHandle
75+
cache ipRuleCache
8776
}
8877

89-
func New(env environment.Env) (*Enforcer, error) {
78+
func newDBIPRulesProvider(env environment.Env) (*dbIPRulesProvider, error) {
9079
cache, err := newIpRuleCache()
9180
if err != nil {
9281
return nil, err
9382
}
94-
95-
svc := &Enforcer{
96-
env: env,
83+
return &dbIPRulesProvider{
84+
db: env.GetDBHandle(),
9785
cache: cache,
98-
}
99-
if sns := env.GetServerNotificationService(); sns != nil {
100-
go func() {
101-
for msg := range sns.Subscribe(&snpb.InvalidateIPRulesCache{}) {
102-
ic, ok := msg.(*snpb.InvalidateIPRulesCache)
103-
if !ok {
104-
alert.UnexpectedEvent("iprules_invalid_proto_type", "received proto type %T", msg)
105-
continue
106-
}
107-
if err := svc.refreshRules(env.GetServerContext(), ic.GetGroupId()); err != nil {
108-
log.Warningf("could not refresh IP rules for group %q: %s", ic.GetGroupId(), err)
109-
}
110-
}
111-
}()
112-
}
113-
return svc, nil
86+
}, nil
11487
}
11588

116-
func Register(env *real_environment.RealEnv) error {
117-
var enforcer interfaces.IPRulesEnforcer = &NoOpEnforcer{}
118-
if *enableIPRules {
119-
realEnforcer, err := New(env)
120-
if err != nil {
121-
return err
122-
}
123-
enforcer = realEnforcer
124-
}
125-
env.SetIPRulesEnforcer(enforcer)
126-
return nil
127-
}
128-
129-
func (s *Enforcer) loadRulesFromDB(ctx context.Context, groupID string) ([]*tables.IPRule, error) {
130-
rq := s.env.GetDBHandle().NewQuery(ctx, "iprules_load_rules").Raw(
89+
func (p *dbIPRulesProvider) loadRulesFromDB(ctx context.Context, groupID string) ([]*tables.IPRule, error) {
90+
rq := p.db.NewQuery(ctx, "iprules_load_rules").Raw(
13191
`SELECT * FROM "IPRules" WHERE group_id = ? ORDER BY created_at_usec`, groupID)
13292
rules, err := db.ScanAll(rq, &tables.IPRule{})
13393
if err != nil {
@@ -136,8 +96,8 @@ func (s *Enforcer) loadRulesFromDB(ctx context.Context, groupID string) ([]*tabl
13696
return rules, nil
13797
}
13898

139-
func (s *Enforcer) loadParsedRulesFromDB(ctx context.Context, groupID string, skipRuleID string) ([]*net.IPNet, error) {
140-
rs, err := s.loadRulesFromDB(ctx, groupID)
99+
func (p *dbIPRulesProvider) loadParsedRulesFromDB(ctx context.Context, groupID string, skipRuleID string) ([]*net.IPNet, error) {
100+
rs, err := p.loadRulesFromDB(ctx, groupID)
141101
if err != nil {
142102
return nil, err
143103
}
@@ -157,16 +117,104 @@ func (s *Enforcer) loadParsedRulesFromDB(ctx context.Context, groupID string, sk
157117
return allowed, nil
158118
}
159119

160-
func (s *Enforcer) refreshRules(ctx context.Context, groupID string) error {
161-
pr, err := s.loadParsedRulesFromDB(ctx, groupID, "" /*=skipRuleId*/)
120+
func (p *dbIPRulesProvider) refreshRules(ctx context.Context, groupID string) error {
121+
pr, err := p.loadParsedRulesFromDB(ctx, groupID, "" /*=skipRuleId*/)
162122
if err != nil {
163123
return err
164124
}
165-
s.cache.Add(groupID, pr)
125+
p.cache.Add(groupID, pr)
166126
log.CtxInfof(ctx, "refreshed IP rules for group %s", groupID)
167127
return nil
168128
}
169129

130+
func (p *dbIPRulesProvider) get(ctx context.Context, groupID string, skipCache bool, skipRuleID string) ([]*net.IPNet, error) {
131+
allowed, ok := p.cache.Get(groupID)
132+
if !ok || skipCache {
133+
pr, err := p.loadParsedRulesFromDB(ctx, groupID, skipRuleID)
134+
if err != nil {
135+
return nil, err
136+
}
137+
// if skipRuleID is set, the retrieved rule list may be incomplete.
138+
if skipRuleID == "" {
139+
p.cache.Add(groupID, pr)
140+
}
141+
allowed = pr
142+
}
143+
return allowed, nil
144+
}
145+
146+
// TODO(iain): halt goroutine on server exit.
147+
func (p *dbIPRulesProvider) startRefresher(env environment.Env) error {
148+
sns := env.GetServerNotificationService()
149+
if sns == nil {
150+
return nil
151+
}
152+
go func() {
153+
for msg := range sns.Subscribe(&snpb.InvalidateIPRulesCache{}) {
154+
ic, ok := msg.(*snpb.InvalidateIPRulesCache)
155+
if !ok {
156+
alert.UnexpectedEvent("iprules_invalid_proto_type", "received proto type %T", msg)
157+
continue
158+
}
159+
if err := p.refreshRules(env.GetServerContext(), ic.GetGroupId()); err != nil {
160+
log.Warningf("could not refresh IP rules for group %q: %s", ic.GetGroupId(), err)
161+
}
162+
}
163+
}()
164+
return nil
165+
}
166+
167+
type Enforcer struct {
168+
env environment.Env
169+
rulesProvider ipRulesProvider
170+
}
171+
172+
type NoOpEnforcer struct{}
173+
174+
func (n *NoOpEnforcer) Authorize(ctx context.Context) error {
175+
return nil
176+
}
177+
178+
func (n *NoOpEnforcer) AuthorizeGroup(ctx context.Context, groupID string) error {
179+
return nil
180+
}
181+
182+
func (n *NoOpEnforcer) AuthorizeHTTPRequest(ctx context.Context, r *http.Request) error {
183+
return nil
184+
}
185+
186+
func (n *NoOpEnforcer) Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error {
187+
return nil
188+
}
189+
190+
func New(env environment.Env) (*Enforcer, error) {
191+
rulesProvider, err := newDBIPRulesProvider(env)
192+
if err != nil {
193+
return nil, err
194+
}
195+
196+
if err := rulesProvider.startRefresher(env); err != nil {
197+
return nil, err
198+
}
199+
return &Enforcer{
200+
env: env,
201+
rulesProvider: rulesProvider,
202+
}, nil
203+
}
204+
205+
func Register(env *real_environment.RealEnv) error {
206+
var enforcer interfaces.IPRulesEnforcer = &NoOpEnforcer{}
207+
if *enableIPRules {
208+
realEnforcer, err := New(env)
209+
if err != nil {
210+
return err
211+
}
212+
enforcer = realEnforcer
213+
}
214+
env.SetIPRulesEnforcer(enforcer)
215+
return nil
216+
}
217+
170218
func (s *Enforcer) Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error {
171219
rawClientIP := clientip.Get(ctx)
172220
clientIP := net.ParseIP(rawClientIP)
@@ -175,17 +223,9 @@ func (s *Enforcer) Check(ctx context.Context, groupID string, skipCache bool, sk
175223
return status.FailedPreconditionErrorf("client IP %q is not valid", rawClientIP)
176224
}
177225

178-
allowed, ok := s.cache.Get(groupID)
179-
if !ok || skipCache {
180-
pr, err := s.loadParsedRulesFromDB(ctx, groupID, skipRuleID)
181-
if err != nil {
182-
return err
183-
}
184-
// if skipRuleID is set, the retrieved rule list may be incomplete.
185-
if skipRuleID == "" {
186-
s.cache.Add(groupID, pr)
187-
}
188-
allowed = pr
226+
allowed, err := s.rulesProvider.get(ctx, groupID, skipCache, skipRuleID)
227+
if err != nil {
228+
return err
189229
}
190230

191231
for _, a := range allowed {

0 commit comments

Comments
 (0)