@@ -62,72 +62,32 @@ func newIpRuleCache() (ipRuleCache, error) {
6262 })
6363}
6464
65- type Enforcer struct {
66- env environment.Env
67-
68- cache ipRuleCache
69- }
70-
71- type NoOpEnforcer struct {}
72-
73- func (n * NoOpEnforcer ) Authorize (ctx context.Context ) error {
74- return nil
75- }
76-
77- func (n * NoOpEnforcer ) AuthorizeGroup (ctx context.Context , groupID string ) error {
78- return nil
79- }
80-
81- func (n * NoOpEnforcer ) AuthorizeHTTPRequest (ctx context.Context , r * http.Request ) error {
82- return nil
65+ // An abstraction for retrieving IP rules from a source of truth.
66+ type ipRulesProvider interface {
67+ // TODO(iain): get rid of skipCache and skipRuleID.
68+ get (ctx context.Context , groupID string , skipCache bool , skipRuleID string ) ([]* net.IPNet , error )
69+ startRefresher (env environment.Env ) error
8370}
8471
85- func (n * NoOpEnforcer ) Check (ctx context.Context , groupID string , skipCache bool , skipRuleID string ) error {
86- return nil
72+ // An implementation of ipRulesProvider that retrieves IP rules from a database.
73+ type dbIPRulesProvider struct {
74+ db interfaces.DBHandle
75+ cache ipRuleCache
8776}
8877
89- func New (env environment.Env ) (* Enforcer , error ) {
78+ func newDBIPRulesProvider (env environment.Env ) (* dbIPRulesProvider , error ) {
9079 cache , err := newIpRuleCache ()
9180 if err != nil {
9281 return nil , err
9382 }
94-
95- svc := & Enforcer {
96- env : env ,
83+ return & dbIPRulesProvider {
84+ db : env .GetDBHandle (),
9785 cache : cache ,
98- }
99- if sns := env .GetServerNotificationService (); sns != nil {
100- go func () {
101- for msg := range sns .Subscribe (& snpb.InvalidateIPRulesCache {}) {
102- ic , ok := msg .(* snpb.InvalidateIPRulesCache )
103- if ! ok {
104- alert .UnexpectedEvent ("iprules_invalid_proto_type" , "received proto type %T" , msg )
105- continue
106- }
107- if err := svc .refreshRules (env .GetServerContext (), ic .GetGroupId ()); err != nil {
108- log .Warningf ("could not refresh IP rules for group %q: %s" , ic .GetGroupId (), err )
109- }
110- }
111- }()
112- }
113- return svc , nil
86+ }, nil
11487}
11588
116- func Register (env * real_environment.RealEnv ) error {
117- var enforcer interfaces.IPRulesEnforcer = & NoOpEnforcer {}
118- if * enableIPRules {
119- realEnforcer , err := New (env )
120- if err != nil {
121- return err
122- }
123- enforcer = realEnforcer
124- }
125- env .SetIPRulesEnforcer (enforcer )
126- return nil
127- }
128-
129- func (s * Enforcer ) loadRulesFromDB (ctx context.Context , groupID string ) ([]* tables.IPRule , error ) {
130- rq := s .env .GetDBHandle ().NewQuery (ctx , "iprules_load_rules" ).Raw (
89+ func (p * dbIPRulesProvider ) loadRulesFromDB (ctx context.Context , groupID string ) ([]* tables.IPRule , error ) {
90+ rq := p .db .NewQuery (ctx , "iprules_load_rules" ).Raw (
13191 `SELECT * FROM "IPRules" WHERE group_id = ? ORDER BY created_at_usec` , groupID )
13292 rules , err := db .ScanAll (rq , & tables.IPRule {})
13393 if err != nil {
@@ -136,8 +96,8 @@ func (s *Enforcer) loadRulesFromDB(ctx context.Context, groupID string) ([]*tabl
13696 return rules , nil
13797}
13898
139- func (s * Enforcer ) loadParsedRulesFromDB (ctx context.Context , groupID string , skipRuleID string ) ([]* net.IPNet , error ) {
140- rs , err := s .loadRulesFromDB (ctx , groupID )
99+ func (p * dbIPRulesProvider ) loadParsedRulesFromDB (ctx context.Context , groupID string , skipRuleID string ) ([]* net.IPNet , error ) {
100+ rs , err := p .loadRulesFromDB (ctx , groupID )
141101 if err != nil {
142102 return nil , err
143103 }
@@ -157,16 +117,104 @@ func (s *Enforcer) loadParsedRulesFromDB(ctx context.Context, groupID string, sk
157117 return allowed , nil
158118}
159119
160- func (s * Enforcer ) refreshRules (ctx context.Context , groupID string ) error {
161- pr , err := s .loadParsedRulesFromDB (ctx , groupID , "" /*=skipRuleId*/ )
120+ func (p * dbIPRulesProvider ) refreshRules (ctx context.Context , groupID string ) error {
121+ pr , err := p .loadParsedRulesFromDB (ctx , groupID , "" /*=skipRuleId*/ )
162122 if err != nil {
163123 return err
164124 }
165- s .cache .Add (groupID , pr )
125+ p .cache .Add (groupID , pr )
166126 log .CtxInfof (ctx , "refreshed IP rules for group %s" , groupID )
167127 return nil
168128}
169129
130+ func (p * dbIPRulesProvider ) get (ctx context.Context , groupID string , skipCache bool , skipRuleID string ) ([]* net.IPNet , error ) {
131+ allowed , ok := p .cache .Get (groupID )
132+ if ! ok || skipCache {
133+ pr , err := p .loadParsedRulesFromDB (ctx , groupID , skipRuleID )
134+ if err != nil {
135+ return nil , err
136+ }
137+ // if skipRuleID is set, the retrieved rule list may be incomplete.
138+ if skipRuleID == "" {
139+ p .cache .Add (groupID , pr )
140+ }
141+ allowed = pr
142+ }
143+ return allowed , nil
144+ }
145+
146+ // TODO(iain): halt goroutine on server exit.
147+ func (p * dbIPRulesProvider ) startRefresher (env environment.Env ) error {
148+ sns := env .GetServerNotificationService ()
149+ if sns == nil {
150+ return nil
151+ }
152+ go func () {
153+ for msg := range sns .Subscribe (& snpb.InvalidateIPRulesCache {}) {
154+ ic , ok := msg .(* snpb.InvalidateIPRulesCache )
155+ if ! ok {
156+ alert .UnexpectedEvent ("iprules_invalid_proto_type" , "received proto type %T" , msg )
157+ continue
158+ }
159+ if err := p .refreshRules (env .GetServerContext (), ic .GetGroupId ()); err != nil {
160+ log .Warningf ("could not refresh IP rules for group %q: %s" , ic .GetGroupId (), err )
161+ }
162+ }
163+ }()
164+ return nil
165+ }
166+
167+ type Enforcer struct {
168+ env environment.Env
169+ rulesProvider ipRulesProvider
170+ }
171+
172+ type NoOpEnforcer struct {}
173+
174+ func (n * NoOpEnforcer ) Authorize (ctx context.Context ) error {
175+ return nil
176+ }
177+
178+ func (n * NoOpEnforcer ) AuthorizeGroup (ctx context.Context , groupID string ) error {
179+ return nil
180+ }
181+
182+ func (n * NoOpEnforcer ) AuthorizeHTTPRequest (ctx context.Context , r * http.Request ) error {
183+ return nil
184+ }
185+
186+ func (n * NoOpEnforcer ) Check (ctx context.Context , groupID string , skipCache bool , skipRuleID string ) error {
187+ return nil
188+ }
189+
190+ func New (env environment.Env ) (* Enforcer , error ) {
191+ rulesProvider , err := newDBIPRulesProvider (env )
192+ if err != nil {
193+ return nil , err
194+ }
195+
196+ if err := rulesProvider .startRefresher (env ); err != nil {
197+ return nil , err
198+ }
199+ return & Enforcer {
200+ env : env ,
201+ rulesProvider : rulesProvider ,
202+ }, nil
203+ }
204+
205+ func Register (env * real_environment.RealEnv ) error {
206+ var enforcer interfaces.IPRulesEnforcer = & NoOpEnforcer {}
207+ if * enableIPRules {
208+ realEnforcer , err := New (env )
209+ if err != nil {
210+ return err
211+ }
212+ enforcer = realEnforcer
213+ }
214+ env .SetIPRulesEnforcer (enforcer )
215+ return nil
216+ }
217+
170218func (s * Enforcer ) Check (ctx context.Context , groupID string , skipCache bool , skipRuleID string ) error {
171219 rawClientIP := clientip .Get (ctx )
172220 clientIP := net .ParseIP (rawClientIP )
@@ -175,17 +223,9 @@ func (s *Enforcer) Check(ctx context.Context, groupID string, skipCache bool, sk
175223 return status .FailedPreconditionErrorf ("client IP %q is not valid" , rawClientIP )
176224 }
177225
178- allowed , ok := s .cache .Get (groupID )
179- if ! ok || skipCache {
180- pr , err := s .loadParsedRulesFromDB (ctx , groupID , skipRuleID )
181- if err != nil {
182- return err
183- }
184- // if skipRuleID is set, the retrieved rule list may be incomplete.
185- if skipRuleID == "" {
186- s .cache .Add (groupID , pr )
187- }
188- allowed = pr
226+ allowed , err := s .rulesProvider .get (ctx , groupID , skipCache , skipRuleID )
227+ if err != nil {
228+ return err
189229 }
190230
191231 for _ , a := range allowed {
0 commit comments