Skip to content

Commit 1bb3849

Browse files
Remove IPRulesEnforcer.Check's skipCache parameter in favor of explicit cache invalidation
1 parent 905f657 commit 1bb3849

File tree

5 files changed

+72
-33
lines changed

5 files changed

+72
-33
lines changed

enterprise/server/ip_rules_enforcer/ip_rules_enforcer.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ const (
3636

3737
// An abstraction for retrieving IP rules from a source of truth.
3838
type ipRulesProvider interface {
39-
get(ctx context.Context, groupID string, skipCache bool) ([]*ipRule, error)
39+
get(ctx context.Context, groupID string) ([]*ipRule, error)
40+
remove(ctx context.Context, groupID string)
4041
startRefresher(env environment.Env) error
4142
}
4243

@@ -59,9 +60,9 @@ func newIPRulesProvider(db interfaces.DBHandle) (ipRulesProvider, error) {
5960
return &dbIPRulesProvider{db: db, cache: cache}, nil
6061
}
6162

62-
func (p *dbIPRulesProvider) get(ctx context.Context, groupID string, skipCache bool) ([]*ipRule, error) {
63+
func (p *dbIPRulesProvider) get(ctx context.Context, groupID string) ([]*ipRule, error) {
6364
allowed, ok := p.cache.Get(groupID)
64-
if ok && !skipCache {
65+
if ok {
6566
return allowed, nil
6667
}
6768

@@ -73,6 +74,10 @@ func (p *dbIPRulesProvider) get(ctx context.Context, groupID string, skipCache b
7374
return allowed, nil
7475
}
7576

77+
func (p *dbIPRulesProvider) remove(ctx context.Context, groupID string) {
78+
p.cache.Remove(groupID)
79+
}
80+
7681
func (p *dbIPRulesProvider) refresh(ctx context.Context, groupID string) error {
7782
allowed, err := p.loadParsedRulesFromDB(ctx, groupID)
7883
if err != nil {
@@ -160,6 +165,7 @@ func (p *dbIPRulesProvider) loadParsedRulesFromDB(ctx context.Context, groupID s
160165

161166
type ipRuleCache interface {
162167
Add(groupID string, allowed []*ipRule) bool
168+
Remove(groupID string) bool
163169
Get(groupID string) ([]*ipRule, bool)
164170
}
165171

@@ -170,6 +176,10 @@ func (c *noopIpRuleCache) Add(groupID string, allowed []*ipRule) bool {
170176
return false
171177
}
172178

179+
func (c *noopIpRuleCache) Remove(groupID string) bool {
180+
return false
181+
}
182+
173183
func (c *noopIpRuleCache) Get(groupID string) ([]*ipRule, bool) {
174184
return nil, false
175185
}
@@ -206,7 +216,10 @@ func (n *NoOpEnforcer) AuthorizeHTTPRequest(ctx context.Context, r *http.Request
206216
return nil
207217
}
208218

209-
func (n *NoOpEnforcer) Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error {
219+
func (s *NoOpEnforcer) InvalidateCachedRules(ctx context.Context, groupID string) {
220+
}
221+
222+
func (n *NoOpEnforcer) Check(ctx context.Context, groupID string, skipRuleID string) error {
210223
return nil
211224
}
212225

@@ -235,15 +248,19 @@ func Register(env *real_environment.RealEnv) error {
235248
return nil
236249
}
237250

238-
func (s *Enforcer) Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error {
251+
func (s *Enforcer) InvalidateCachedRules(ctx context.Context, groupID string) {
252+
s.rulesProvider.remove(ctx, groupID)
253+
}
254+
255+
func (s *Enforcer) Check(ctx context.Context, groupID string, skipRuleID string) error {
239256
rawClientIP := clientip.Get(ctx)
240257
clientIP := net.ParseIP(rawClientIP)
241258
// Client IP is not parsable.
242259
if clientIP == nil {
243260
return status.FailedPreconditionErrorf("client IP %q is not valid", rawClientIP)
244261
}
245262

246-
allowed, err := s.rulesProvider.get(ctx, groupID, skipCache)
263+
allowed, err := s.rulesProvider.get(ctx, groupID)
247264
if err != nil {
248265
return err
249266
}
@@ -262,7 +279,7 @@ func (s *Enforcer) Check(ctx context.Context, groupID string, skipCache bool, sk
262279

263280
func (s *Enforcer) authorize(ctx context.Context, groupID string) error {
264281
start := time.Now()
265-
err := s.Check(ctx, groupID, false /*=skipCache*/, "" /*skipRuleID*/)
282+
err := s.Check(ctx, groupID, "" /*skipRuleID*/)
266283
metrics.IPRulesCheckLatencyUsec.With(
267284
prometheus.Labels{metrics.StatusHumanReadableLabel: status.MetricsLabel(err)},
268285
).Observe(float64(time.Since(start).Microseconds()))

enterprise/server/ip_rules_enforcer/ip_rules_enforcer_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func TestNoOpEnforcer(t *testing.T) {
137137
require.NoError(t, enforcer.Authorize(context.Background()))
138138
require.NoError(t, enforcer.AuthorizeGroup(context.Background(), "G1"))
139139
require.NoError(t, enforcer.AuthorizeHTTPRequest(context.Background(), httptest.NewRequest("GET", "/rpc/BuildBuddyService/GetUser", nil)))
140-
require.NoError(t, enforcer.Check(context.Background(), "G1", true, ""))
140+
require.NoError(t, enforcer.Check(context.Background(), "G1", ""))
141141
}
142142

143143
func TestAuthorizeAndAuthorizeGroup_EnforcementNotEnabled(t *testing.T) {
@@ -211,10 +211,10 @@ func TestCheckSkipRuleID(t *testing.T) {
211211
ruleID := insertRule(t, env, groupID, "1.2.3.4/32", "rule1")
212212
ctx := context.WithValue(context.Background(), clientip.ContextKey, "1.2.3.4")
213213

214-
err := irs.Check(ctx, groupID, true /* skipCache */, "" /* skipRuleID */)
214+
err := irs.Check(ctx, groupID, "" /* skipRuleID */)
215215
require.NoError(t, err)
216216

217-
err = irs.Check(ctx, groupID, true /* skipCache */, ruleID)
217+
err = irs.Check(ctx, groupID, ruleID)
218218
require.Error(t, err)
219219
require.True(t, status.IsPermissionDeniedError(err))
220220
}
@@ -278,17 +278,17 @@ func TestRefresherStopsOnShutdown(t *testing.T) {
278278

279279
insertRule(t, env, groupID, "1.2.3.4/32", "rule1")
280280
ctx1 := context.WithValue(context.Background(), clientip.ContextKey, "1.2.3.4")
281-
require.NoError(t, irs.Check(ctx1, groupID, false /*=skipCache*/, "" /*=skipRuleID*/))
281+
require.NoError(t, irs.Check(ctx1, groupID, "" /*=skipRuleID*/))
282282

283283
insertRule(t, env, groupID, "4.5.6.7/32", "rule2")
284284
ctx2 := context.WithValue(context.Background(), clientip.ContextKey, "4.5.6.7")
285-
err = irs.Check(ctx2, groupID, false /*=skipCache*/, "" /*=skipRuleID*/)
285+
err = irs.Check(ctx2, groupID, "" /*=skipRuleID*/)
286286
require.Error(t, err)
287287
require.True(t, status.IsPermissionDeniedError(err))
288288

289289
sns.ch <- &snpb.InvalidateIPRulesCache{GroupId: groupID}
290290
require.Eventually(t, func() bool {
291-
return irs.Check(ctx2, groupID, false /*=skipCache*/, "" /*=skipRuleID*/) == nil
291+
return irs.Check(ctx2, groupID, "" /*=skipRuleID*/) == nil
292292
}, time.Second, 10*time.Millisecond)
293293

294294
env.GetHealthChecker().Shutdown()
@@ -298,7 +298,7 @@ func TestRefresherStopsOnShutdown(t *testing.T) {
298298
sns.ch <- &snpb.InvalidateIPRulesCache{GroupId: groupID}
299299

300300
ctx3 := context.WithValue(context.Background(), clientip.ContextKey, "8.9.10.11")
301-
err = irs.Check(ctx3, groupID, false /*=skipCache*/, "" /*=skipRuleID*/)
301+
err = irs.Check(ctx3, groupID, "" /*=skipRuleID*/)
302302
require.Error(t, err)
303303
require.True(t, status.IsPermissionDeniedError(err))
304304
}

enterprise/server/ip_rules_service/ip_rules_service.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ func (s *Service) SetIPRuleConfig(ctx context.Context, req *irpb.SetRulesConfigR
107107
}
108108

109109
if req.GetEnforceIpRules() {
110-
err := s.enforcer.Check(ctx, req.GetRequestContext().GetGroupId(), true /*=skipCache*/, "" /*=skipRuleID*/)
110+
groupID := req.GetRequestContext().GetGroupId()
111+
s.enforcer.InvalidateCachedRules(ctx, groupID)
112+
err := s.enforcer.Check(ctx, groupID, "" /*=skipRuleID*/)
111113
if err != nil {
112114
if status.IsPermissionDeniedError(err) {
113115
return nil, status.InvalidArgumentErrorf("Enabling IP rule enforcement would block your IP (%s) from accessing the organization.", clientip.Get(ctx))
@@ -237,7 +239,9 @@ func (s *Service) DeleteRule(ctx context.Context, req *irpb.DeleteRuleRequest) (
237239
if g.EnforceIPRules {
238240
// Check if deleting the rule would lock out the client calling this
239241
// API.
240-
err := s.enforcer.Check(ctx, req.GetRequestContext().GetGroupId(), true /*=skipCache*/, req.GetIpRuleId())
242+
groupID := req.GetRequestContext().GetGroupId()
243+
s.enforcer.InvalidateCachedRules(ctx, groupID)
244+
err := s.enforcer.Check(ctx, groupID, req.GetIpRuleId())
241245
if err != nil {
242246
if status.IsPermissionDeniedError(err) {
243247
return nil, status.InvalidArgumentErrorf("Deleting this rule would block your IP (%s) from accessing the organization.", clientip.Get(ctx))

enterprise/server/ip_rules_service/ip_rules_service_test.go

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ import (
2222
snpb "github.com/buildbuddy-io/buildbuddy/proto/server_notification"
2323
)
2424

25-
type checkCall struct {
26-
groupID string
27-
skipCache bool
28-
skipRuleID string
25+
type operations struct {
26+
groupID string
27+
invalidation bool
28+
skipRuleID string
2929
}
3030

3131
type fakeIPRulesEnforcer struct {
32-
checkCalls []checkCall
32+
checkCalls []operations
3333
checkErr error
3434
}
3535

@@ -45,10 +45,16 @@ func (f *fakeIPRulesEnforcer) AuthorizeHTTPRequest(ctx context.Context, r *http.
4545
return nil
4646
}
4747

48-
func (f *fakeIPRulesEnforcer) Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error {
49-
f.checkCalls = append(f.checkCalls, checkCall{
48+
func (f *fakeIPRulesEnforcer) InvalidateCachedRules(ctx context.Context, groupID string) {
49+
f.checkCalls = append(f.checkCalls, operations{
50+
groupID: groupID,
51+
invalidation: true,
52+
})
53+
}
54+
55+
func (f *fakeIPRulesEnforcer) Check(ctx context.Context, groupID string, skipRuleID string) error {
56+
f.checkCalls = append(f.checkCalls, operations{
5057
groupID: groupID,
51-
skipCache: skipCache,
5258
skipRuleID: skipRuleID,
5359
})
5460
return f.checkErr
@@ -262,7 +268,10 @@ func TestSetAndGetIPRuleConfig(t *testing.T) {
262268
EnforceIpRules: true,
263269
})
264270
require.NoError(t, err)
265-
require.Equal(t, []checkCall{{groupID: groupID, skipCache: true, skipRuleID: ""}}, enforcer.checkCalls)
271+
require.Equal(t, []operations{
272+
{groupID: groupID, invalidation: true},
273+
{groupID: groupID, skipRuleID: ""},
274+
}, enforcer.checkCalls)
266275

267276
cfgRsp, err = svc.GetIPRuleConfig(authCtx, &irpb.GetRulesConfigRequest{
268277
RequestContext: &ctxpb.RequestContext{GroupId: groupID},
@@ -275,7 +284,7 @@ func TestSetAndGetIPRuleConfig(t *testing.T) {
275284
EnforceIpRules: false,
276285
})
277286
require.NoError(t, err)
278-
require.Len(t, enforcer.checkCalls, 1)
287+
require.Len(t, enforcer.checkCalls, 2)
279288

280289
cfgRsp, err = svc.GetIPRuleConfig(authCtx, &irpb.GetRulesConfigRequest{
281290
RequestContext: &ctxpb.RequestContext{GroupId: groupID},
@@ -296,7 +305,10 @@ func TestSetIPRuleConfigRejectsLockout(t *testing.T) {
296305
require.Error(t, err)
297306
require.True(t, status.IsInvalidArgumentError(err))
298307
require.Contains(t, err.Error(), "9.8.7.6")
299-
require.Equal(t, []checkCall{{groupID: groupID, skipCache: true, skipRuleID: ""}}, enforcer.checkCalls)
308+
require.Equal(t, []operations{
309+
{groupID: groupID, invalidation: true},
310+
{groupID: groupID, skipRuleID: ""},
311+
}, enforcer.checkCalls)
300312

301313
cfgRsp, err := svc.GetIPRuleConfig(authCtx, &irpb.GetRulesConfigRequest{
302314
RequestContext: &ctxpb.RequestContext{GroupId: groupID},
@@ -330,11 +342,13 @@ func TestDeleteRuleRejectsLockout(t *testing.T) {
330342
require.Error(t, err)
331343
require.True(t, status.IsInvalidArgumentError(err))
332344
require.Contains(t, err.Error(), "1.2.3.4")
333-
require.Equal(t, []checkCall{{
334-
groupID: groupID,
335-
skipCache: true,
336-
skipRuleID: addRsp.GetRule().GetIpRuleId(),
337-
}}, enforcer.checkCalls)
345+
require.Equal(t, []operations{
346+
{groupID: groupID, invalidation: true},
347+
{
348+
groupID: groupID,
349+
skipRuleID: addRsp.GetRule().GetIpRuleId(),
350+
},
351+
}, enforcer.checkCalls)
338352

339353
_, err = svc.GetRule(authCtx, groupID, addRsp.GetRule().GetIpRuleId())
340354
require.NoError(t, err)

server/interfaces/interfaces.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1579,10 +1579,14 @@ type IPRulesEnforcer interface {
15791579
// context.
15801580
AuthorizeHTTPRequest(ctx context.Context, r *http.Request) error
15811581

1582+
// Invalidates all cached IP Rules for the provided group ID, ensuring the
1583+
// next authorization/check request will use fresh rules from the backend.
1584+
InvalidateCachedRules(ctx context.Context, groupID string)
1585+
15821586
// Performs an explicit IP rule check for the given group ID with the
15831587
// option to force refresh rules from the backend and skip specific rules
15841588
// (for testing rule changes made by IPRulesService).
1585-
Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error
1589+
Check(ctx context.Context, groupID string, skipRuleID string) error
15861590
}
15871591

15881592
type IPRulesService interface {

0 commit comments

Comments
 (0)