Skip to content

Commit 2614daf

Browse files
Remove ipRulesProvider.get's skipRule parameter (#11610)
Related issues: buildbuddy-io/buildbuddy-internal#6797 --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent e1053a3 commit 2614daf

File tree

1 file changed

+31
-27
lines changed

1 file changed

+31
-27
lines changed

enterprise/server/ip_rules_enforcer/ip_rules_enforcer.go

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,43 +36,47 @@ const (
3636
cacheSize = 100_000
3737
)
3838

39+
type ipRule struct {
40+
id string
41+
allowed *net.IPNet
42+
}
43+
3944
type ipRuleCache interface {
40-
Add(groupID string, allowed []*net.IPNet) bool
45+
Add(groupID string, allowed []ipRule) bool
4146
Remove(groupID string) bool
42-
Get(groupID string) ([]*net.IPNet, bool)
47+
Get(groupID string) ([]ipRule, bool)
4348
}
4449

4550
type noopIpRuleCache struct {
4651
}
4752

48-
func (c *noopIpRuleCache) Add(groupID string, allowed []*net.IPNet) bool {
53+
func (c *noopIpRuleCache) Add(groupID string, allowed []ipRule) bool {
4954
return false
5055
}
5156

5257
func (c *noopIpRuleCache) Remove(groupID string) bool {
5358
return false
5459
}
5560

56-
func (c *noopIpRuleCache) Get(groupID string) ([]*net.IPNet, bool) {
61+
func (c *noopIpRuleCache) Get(groupID string) ([]ipRule, bool) {
5762
return nil, false
5863
}
5964

6065
func newIpRuleCache() (ipRuleCache, error) {
6166
if *cacheTTL == 0 {
6267
return &noopIpRuleCache{}, nil
6368
}
64-
return lru.New(&lru.Config[[]*net.IPNet]{
69+
return lru.New(&lru.Config[[]ipRule]{
6570
TTL: *cacheTTL,
6671
MaxSize: cacheSize,
67-
SizeFn: func(v []*net.IPNet) int64 { return int64(len(v)) },
72+
SizeFn: func(v []ipRule) int64 { return int64(len(v)) },
6873
ThreadSafe: true,
6974
})
7075
}
7176

7277
// An abstraction for retrieving IP rules from a source of truth.
7378
type ipRulesProvider interface {
74-
// TODO(iain): get rid of skipRuleID.
75-
get(ctx context.Context, groupID string, skipRuleID string) ([]*net.IPNet, error)
79+
get(ctx context.Context, groupID string) ([]ipRule, error)
7680
invalidate(ctx context.Context, groupID string)
7781
startRefresher(env environment.Env) error
7882
}
@@ -104,29 +108,29 @@ func (p *dbIPRulesProvider) loadRulesFromDB(ctx context.Context, groupID string)
104108
return rules, nil
105109
}
106110

107-
func (p *dbIPRulesProvider) loadParsedRulesFromDB(ctx context.Context, groupID string, skipRuleID string) ([]*net.IPNet, error) {
111+
func (p *dbIPRulesProvider) loadParsedRulesFromDB(ctx context.Context, groupID string) ([]ipRule, error) {
108112
rs, err := p.loadRulesFromDB(ctx, groupID)
109113
if err != nil {
110114
return nil, err
111115
}
112116

113-
var allowed []*net.IPNet
117+
var allowed []ipRule
114118
for _, r := range rs {
115-
if r.IPRuleID == skipRuleID {
116-
continue
117-
}
118119
_, ipNet, err := net.ParseCIDR(r.CIDR)
119120
if err != nil {
120121
alert.UnexpectedEvent("unparsable CIDR rule", "rule %q", r.CIDR)
121122
continue
122123
}
123-
allowed = append(allowed, ipNet)
124+
allowed = append(allowed, ipRule{
125+
id: r.IPRuleID,
126+
allowed: ipNet,
127+
})
124128
}
125129
return allowed, nil
126130
}
127131

128132
func (p *dbIPRulesProvider) refreshRules(ctx context.Context, groupID string) error {
129-
pr, err := p.loadParsedRulesFromDB(ctx, groupID, "" /*=skipRuleId*/)
133+
pr, err := p.loadParsedRulesFromDB(ctx, groupID)
130134
if err != nil {
131135
return err
132136
}
@@ -135,17 +139,14 @@ func (p *dbIPRulesProvider) refreshRules(ctx context.Context, groupID string) er
135139
return nil
136140
}
137141

138-
func (p *dbIPRulesProvider) get(ctx context.Context, groupID string, skipRuleID string) ([]*net.IPNet, error) {
142+
func (p *dbIPRulesProvider) get(ctx context.Context, groupID string) ([]ipRule, error) {
139143
allowed, ok := p.cache.Get(groupID)
140144
if !ok {
141-
pr, err := p.loadParsedRulesFromDB(ctx, groupID, skipRuleID)
145+
pr, err := p.loadParsedRulesFromDB(ctx, groupID)
142146
if err != nil {
143147
return nil, err
144148
}
145-
// if skipRuleID is set, the retrieved rule list may be incomplete.
146-
if skipRuleID == "" {
147-
p.cache.Add(groupID, pr)
148-
}
149+
p.cache.Add(groupID, pr)
149150
allowed = pr
150151
}
151152
return allowed, nil
@@ -240,7 +241,7 @@ func (n *NoOpEnforcer) AuthorizeHTTPRequest(ctx context.Context, r *http.Request
240241
func (n *NoOpEnforcer) InvalidateCache(ctx context.Context, groupID string) {
241242
}
242243

243-
func (n *NoOpEnforcer) Check(ctx context.Context, groupID string, skipRuleID string) error {
244+
func (n *NoOpEnforcer) Check(ctx context.Context, groupID, skipRuleID string) error {
244245
return nil
245246
}
246247

@@ -272,21 +273,24 @@ func Register(env *real_environment.RealEnv) error {
272273
return nil
273274
}
274275

275-
func (s *Enforcer) Check(ctx context.Context, groupID string, skipRuleID string) error {
276+
func (s *Enforcer) Check(ctx context.Context, groupID, skipRuleID string) error {
276277
rawClientIP := clientip.Get(ctx)
277278
clientIP := net.ParseIP(rawClientIP)
278279
// Client IP is not parsable.
279280
if clientIP == nil {
280281
return status.FailedPreconditionErrorf("client IP %q is not valid", rawClientIP)
281282
}
282283

283-
allowed, err := s.rulesProvider.get(ctx, groupID, skipRuleID)
284+
rules, err := s.rulesProvider.get(ctx, groupID)
284285
if err != nil {
285286
return err
286287
}
287288

288-
for _, a := range allowed {
289-
if a.Contains(clientIP) {
289+
for _, rule := range rules {
290+
if rule.id == skipRuleID {
291+
continue
292+
}
293+
if rule.allowed.Contains(clientIP) {
290294
return nil
291295
}
292296
}
@@ -296,7 +300,7 @@ func (s *Enforcer) Check(ctx context.Context, groupID string, skipRuleID string)
296300

297301
func (s *Enforcer) authorize(ctx context.Context, groupID string) error {
298302
start := time.Now()
299-
err := s.Check(ctx, groupID, "" /*skipRuleID*/)
303+
err := s.Check(ctx, groupID, "" /*=skipRuleID*/)
300304
metrics.IPRulesCheckLatencyUsec.With(
301305
prometheus.Labels{metrics.StatusHumanReadableLabel: status.MetricsLabel(err)},
302306
).Observe(float64(time.Since(start).Microseconds()))

0 commit comments

Comments
 (0)