@@ -36,43 +36,47 @@ const (
3636 cacheSize = 100_000
3737)
3838
39+ type ipRule struct {
40+ id string
41+ allowed * net.IPNet
42+ }
43+
3944type ipRuleCache interface {
40- Add (groupID string , allowed []* net. IPNet ) bool
45+ Add (groupID string , allowed []ipRule ) bool
4146 Remove (groupID string ) bool
42- Get (groupID string ) ([]* net. IPNet , bool )
47+ Get (groupID string ) ([]ipRule , bool )
4348}
4449
4550type noopIpRuleCache struct {
4651}
4752
48- func (c * noopIpRuleCache ) Add (groupID string , allowed []* net. IPNet ) bool {
53+ func (c * noopIpRuleCache ) Add (groupID string , allowed []ipRule ) bool {
4954 return false
5055}
5156
5257func (c * noopIpRuleCache ) Remove (groupID string ) bool {
5358 return false
5459}
5560
56- func (c * noopIpRuleCache ) Get (groupID string ) ([]* net. IPNet , bool ) {
61+ func (c * noopIpRuleCache ) Get (groupID string ) ([]ipRule , bool ) {
5762 return nil , false
5863}
5964
6065func newIpRuleCache () (ipRuleCache , error ) {
6166 if * cacheTTL == 0 {
6267 return & noopIpRuleCache {}, nil
6368 }
64- return lru .New (& lru.Config [[]* net. IPNet ]{
69+ return lru .New (& lru.Config [[]ipRule ]{
6570 TTL : * cacheTTL ,
6671 MaxSize : cacheSize ,
67- SizeFn : func (v []* net. IPNet ) int64 { return int64 (len (v )) },
72+ SizeFn : func (v []ipRule ) int64 { return int64 (len (v )) },
6873 ThreadSafe : true ,
6974 })
7075}
7176
7277// An abstraction for retrieving IP rules from a source of truth.
7378type ipRulesProvider interface {
74- // TODO(iain): get rid of skipRuleID.
75- get (ctx context.Context , groupID string , skipRuleID string ) ([]* net.IPNet , error )
79+ get (ctx context.Context , groupID string ) ([]ipRule , error )
7680 invalidate (ctx context.Context , groupID string )
7781 startRefresher (env environment.Env ) error
7882}
@@ -104,29 +108,29 @@ func (p *dbIPRulesProvider) loadRulesFromDB(ctx context.Context, groupID string)
104108 return rules , nil
105109}
106110
107- func (p * dbIPRulesProvider ) loadParsedRulesFromDB (ctx context.Context , groupID string , skipRuleID string ) ([]* net. IPNet , error ) {
111+ func (p * dbIPRulesProvider ) loadParsedRulesFromDB (ctx context.Context , groupID string ) ([]ipRule , error ) {
108112 rs , err := p .loadRulesFromDB (ctx , groupID )
109113 if err != nil {
110114 return nil , err
111115 }
112116
113- var allowed []* net. IPNet
117+ var allowed []ipRule
114118 for _ , r := range rs {
115- if r .IPRuleID == skipRuleID {
116- continue
117- }
118119 _ , ipNet , err := net .ParseCIDR (r .CIDR )
119120 if err != nil {
120121 alert .UnexpectedEvent ("unparsable CIDR rule" , "rule %q" , r .CIDR )
121122 continue
122123 }
123- allowed = append (allowed , ipNet )
124+ allowed = append (allowed , ipRule {
125+ id : r .IPRuleID ,
126+ allowed : ipNet ,
127+ })
124128 }
125129 return allowed , nil
126130}
127131
128132func (p * dbIPRulesProvider ) refreshRules (ctx context.Context , groupID string ) error {
129- pr , err := p .loadParsedRulesFromDB (ctx , groupID , "" /*=skipRuleId*/ )
133+ pr , err := p .loadParsedRulesFromDB (ctx , groupID )
130134 if err != nil {
131135 return err
132136 }
@@ -135,17 +139,14 @@ func (p *dbIPRulesProvider) refreshRules(ctx context.Context, groupID string) er
135139 return nil
136140}
137141
138- func (p * dbIPRulesProvider ) get (ctx context.Context , groupID string , skipRuleID string ) ([]* net. IPNet , error ) {
142+ func (p * dbIPRulesProvider ) get (ctx context.Context , groupID string ) ([]ipRule , error ) {
139143 allowed , ok := p .cache .Get (groupID )
140144 if ! ok {
141- pr , err := p .loadParsedRulesFromDB (ctx , groupID , skipRuleID )
145+ pr , err := p .loadParsedRulesFromDB (ctx , groupID )
142146 if err != nil {
143147 return nil , err
144148 }
145- // if skipRuleID is set, the retrieved rule list may be incomplete.
146- if skipRuleID == "" {
147- p .cache .Add (groupID , pr )
148- }
149+ p .cache .Add (groupID , pr )
149150 allowed = pr
150151 }
151152 return allowed , nil
@@ -240,7 +241,7 @@ func (n *NoOpEnforcer) AuthorizeHTTPRequest(ctx context.Context, r *http.Request
240241func (n * NoOpEnforcer ) InvalidateCache (ctx context.Context , groupID string ) {
241242}
242243
243- func (n * NoOpEnforcer ) Check (ctx context.Context , groupID string , skipRuleID string ) error {
244+ func (n * NoOpEnforcer ) Check (ctx context.Context , groupID , skipRuleID string ) error {
244245 return nil
245246}
246247
@@ -272,21 +273,24 @@ func Register(env *real_environment.RealEnv) error {
272273 return nil
273274}
274275
275- func (s * Enforcer ) Check (ctx context.Context , groupID string , skipRuleID string ) error {
276+ func (s * Enforcer ) Check (ctx context.Context , groupID , skipRuleID string ) error {
276277 rawClientIP := clientip .Get (ctx )
277278 clientIP := net .ParseIP (rawClientIP )
278279 // Client IP is not parsable.
279280 if clientIP == nil {
280281 return status .FailedPreconditionErrorf ("client IP %q is not valid" , rawClientIP )
281282 }
282283
283- allowed , err := s .rulesProvider .get (ctx , groupID , skipRuleID )
284+ rules , err := s .rulesProvider .get (ctx , groupID )
284285 if err != nil {
285286 return err
286287 }
287288
288- for _ , a := range allowed {
289- if a .Contains (clientIP ) {
289+ for _ , rule := range rules {
290+ if rule .id == skipRuleID {
291+ continue
292+ }
293+ if rule .allowed .Contains (clientIP ) {
290294 return nil
291295 }
292296 }
@@ -296,7 +300,7 @@ func (s *Enforcer) Check(ctx context.Context, groupID string, skipRuleID string)
296300
297301func (s * Enforcer ) authorize (ctx context.Context , groupID string ) error {
298302 start := time .Now ()
299- err := s .Check (ctx , groupID , "" /*skipRuleID*/ )
303+ err := s .Check (ctx , groupID , "" /*= skipRuleID*/ )
300304 metrics .IPRulesCheckLatencyUsec .With (
301305 prometheus.Labels {metrics .StatusHumanReadableLabel : status .MetricsLabel (err )},
302306 ).Observe (float64 (time .Since (start ).Microseconds ()))
0 commit comments