Skip to content

Commit ecc2f9f

Browse files
Remove IPRulesEnforcer.Check's skipCache parameter (#11609)
Related issues: buildbuddy-io/buildbuddy-internal#6797
1 parent d71d04f commit ecc2f9f

File tree

5 files changed

+77
-36
lines changed

5 files changed

+77
-36
lines changed

enterprise/server/ip_rules_enforcer/ip_rules_enforcer.go

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ const (
3636

3737
type ipRuleCache interface {
3838
Add(groupID string, allowed []*net.IPNet) bool
39+
Remove(groupID string) bool
3940
Get(groupID string) ([]*net.IPNet, bool)
4041
}
4142

@@ -46,6 +47,10 @@ func (c *noopIpRuleCache) Add(groupID string, allowed []*net.IPNet) bool {
4647
return false
4748
}
4849

50+
func (c *noopIpRuleCache) Remove(groupID string) bool {
51+
return false
52+
}
53+
4954
func (c *noopIpRuleCache) Get(groupID string) ([]*net.IPNet, bool) {
5055
return nil, false
5156
}
@@ -64,8 +69,9 @@ func newIpRuleCache() (ipRuleCache, error) {
6469

6570
// An abstraction for retrieving IP rules from a source of truth.
6671
type ipRulesProvider interface {
67-
// TODO(iain): get rid of skipCache and skipRuleID.
68-
get(ctx context.Context, groupID string, skipCache bool, skipRuleID string) ([]*net.IPNet, error)
72+
// TODO(iain): get rid of skipRuleID.
73+
get(ctx context.Context, groupID string, skipRuleID string) ([]*net.IPNet, error)
74+
invalidate(ctx context.Context, groupID string)
6975
startRefresher(env environment.Env) error
7076
}
7177

@@ -127,9 +133,9 @@ func (p *dbIPRulesProvider) refreshRules(ctx context.Context, groupID string) er
127133
return nil
128134
}
129135

130-
func (p *dbIPRulesProvider) get(ctx context.Context, groupID string, skipCache bool, skipRuleID string) ([]*net.IPNet, error) {
136+
func (p *dbIPRulesProvider) get(ctx context.Context, groupID string, skipRuleID string) ([]*net.IPNet, error) {
131137
allowed, ok := p.cache.Get(groupID)
132-
if !ok || skipCache {
138+
if !ok {
133139
pr, err := p.loadParsedRulesFromDB(ctx, groupID, skipRuleID)
134140
if err != nil {
135141
return nil, err
@@ -143,6 +149,10 @@ func (p *dbIPRulesProvider) get(ctx context.Context, groupID string, skipCache b
143149
return allowed, nil
144150
}
145151

152+
func (p *dbIPRulesProvider) invalidate(ctx context.Context, groupID string) {
153+
p.cache.Remove(groupID)
154+
}
155+
146156
// TODO(iain): halt goroutine on server exit.
147157
func (p *dbIPRulesProvider) startRefresher(env environment.Env) error {
148158
sns := env.GetServerNotificationService()
@@ -183,7 +193,10 @@ func (n *NoOpEnforcer) AuthorizeHTTPRequest(ctx context.Context, r *http.Request
183193
return nil
184194
}
185195

186-
func (n *NoOpEnforcer) Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error {
196+
func (n *NoOpEnforcer) InvalidateCache(ctx context.Context, groupID string) {
197+
}
198+
199+
func (n *NoOpEnforcer) Check(ctx context.Context, groupID string, skipRuleID string) error {
187200
return nil
188201
}
189202

@@ -215,15 +228,15 @@ func Register(env *real_environment.RealEnv) error {
215228
return nil
216229
}
217230

218-
func (s *Enforcer) Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error {
231+
func (s *Enforcer) Check(ctx context.Context, groupID string, skipRuleID string) error {
219232
rawClientIP := clientip.Get(ctx)
220233
clientIP := net.ParseIP(rawClientIP)
221234
// Client IP is not parsable.
222235
if clientIP == nil {
223236
return status.FailedPreconditionErrorf("client IP %q is not valid", rawClientIP)
224237
}
225238

226-
allowed, err := s.rulesProvider.get(ctx, groupID, skipCache, skipRuleID)
239+
allowed, err := s.rulesProvider.get(ctx, groupID, skipRuleID)
227240
if err != nil {
228241
return err
229242
}
@@ -239,7 +252,7 @@ func (s *Enforcer) Check(ctx context.Context, groupID string, skipCache bool, sk
239252

240253
func (s *Enforcer) authorize(ctx context.Context, groupID string) error {
241254
start := time.Now()
242-
err := s.Check(ctx, groupID, false /*=skipCache*/, "" /*skipRuleID*/)
255+
err := s.Check(ctx, groupID, "" /*skipRuleID*/)
243256
metrics.IPRulesCheckLatencyUsec.With(
244257
prometheus.Labels{metrics.StatusHumanReadableLabel: status.MetricsLabel(err)},
245258
).Observe(float64(time.Since(start).Microseconds()))
@@ -330,3 +343,7 @@ func (s *Enforcer) AuthorizeHTTPRequest(ctx context.Context, r *http.Request) er
330343

331344
return nil
332345
}
346+
347+
func (s *Enforcer) InvalidateCache(ctx context.Context, groupID string) {
348+
s.rulesProvider.invalidate(ctx, groupID)
349+
}

enterprise/server/ip_rules_enforcer/ip_rules_enforcer_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func TestNoOpEnforcer(t *testing.T) {
116116
require.NoError(t, enforcer.Authorize(context.Background()))
117117
require.NoError(t, enforcer.AuthorizeGroup(context.Background(), "G1"))
118118
require.NoError(t, enforcer.AuthorizeHTTPRequest(context.Background(), httptest.NewRequest("GET", "/rpc/BuildBuddyService/GetUser", nil)))
119-
require.NoError(t, enforcer.Check(context.Background(), "G1", true, ""))
119+
require.NoError(t, enforcer.Check(context.Background(), "G1", ""))
120120
}
121121

122122
func TestAuthorizeAndAuthorizeGroup_EnforcementNotEnabled(t *testing.T) {
@@ -190,10 +190,10 @@ func TestCheckSkipRuleID(t *testing.T) {
190190
ruleID := insertRule(t, env, groupID, "1.2.3.4/32", "rule1")
191191
ctx := context.WithValue(context.Background(), clientip.ContextKey, "1.2.3.4")
192192

193-
err := irs.Check(ctx, groupID, true /* skipCache */, "" /* skipRuleID */)
193+
err := irs.Check(ctx, groupID, "" /* skipRuleID */)
194194
require.NoError(t, err)
195195

196-
err = irs.Check(ctx, groupID, true /* skipCache */, ruleID)
196+
err = irs.Check(ctx, groupID, ruleID)
197197
require.Error(t, err)
198198
require.True(t, status.IsPermissionDeniedError(err))
199199
}

enterprise/server/ip_rules_service/ip_rules_service.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ func (s *Service) SetIPRuleConfig(ctx context.Context, req *irpb.SetRulesConfigR
107107
}
108108

109109
if req.GetEnforceIpRules() {
110-
err := s.enforcer.Check(ctx, req.GetRequestContext().GetGroupId(), true /*=skipCache*/, "" /*=skipRuleID*/)
110+
groupID := req.GetRequestContext().GetGroupId()
111+
s.enforcer.InvalidateCache(ctx, groupID)
112+
err := s.enforcer.Check(ctx, groupID, "" /*=skipRuleID*/)
111113
if err != nil {
112114
if status.IsPermissionDeniedError(err) {
113115
return nil, status.InvalidArgumentErrorf("Enabling IP rule enforcement would block your IP (%s) from accessing the organization.", clientip.Get(ctx))
@@ -235,9 +237,10 @@ func (s *Service) DeleteRule(ctx context.Context, req *irpb.DeleteRuleRequest) (
235237
return nil, err
236238
}
237239
if g.EnforceIPRules {
238-
// Check if deleting the rule would lock out the client calling this
239-
// API.
240-
err := s.enforcer.Check(ctx, req.GetRequestContext().GetGroupId(), true /*=skipCache*/, req.GetIpRuleId())
240+
// Check if deleting the rule would block the client calling this API.
241+
groupID := req.GetRequestContext().GetGroupId()
242+
s.enforcer.InvalidateCache(ctx, groupID)
243+
err := s.enforcer.Check(ctx, groupID, req.GetIpRuleId())
241244
if err != nil {
242245
if status.IsPermissionDeniedError(err) {
243246
return nil, status.InvalidArgumentErrorf("Deleting this rule would block your IP (%s) from accessing the organization.", clientip.Get(ctx))

enterprise/server/ip_rules_service/ip_rules_service_test.go

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@ import (
2222
snpb "github.com/buildbuddy-io/buildbuddy/proto/server_notification"
2323
)
2424

25-
type checkCall struct {
26-
groupID string
27-
skipCache bool
28-
skipRuleID string
25+
type call struct {
26+
groupID string
27+
invalidation bool
28+
check bool
29+
skipRuleID string
2930
}
3031

3132
type fakeIPRulesEnforcer struct {
32-
checkCalls []checkCall
33-
checkErr error
33+
calls []call
34+
checkErr error
3435
}
3536

3637
func (f *fakeIPRulesEnforcer) Authorize(ctx context.Context) error {
@@ -45,10 +46,17 @@ func (f *fakeIPRulesEnforcer) AuthorizeHTTPRequest(ctx context.Context, r *http.
4546
return nil
4647
}
4748

48-
func (f *fakeIPRulesEnforcer) Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error {
49-
f.checkCalls = append(f.checkCalls, checkCall{
49+
func (f *fakeIPRulesEnforcer) InvalidateCache(ctx context.Context, groupID string) {
50+
f.calls = append(f.calls, call{
51+
groupID: groupID,
52+
invalidation: true,
53+
})
54+
}
55+
56+
func (f *fakeIPRulesEnforcer) Check(ctx context.Context, groupID string, skipRuleID string) error {
57+
f.calls = append(f.calls, call{
5058
groupID: groupID,
51-
skipCache: skipCache,
59+
check: true,
5260
skipRuleID: skipRuleID,
5361
})
5462
return f.checkErr
@@ -262,7 +270,10 @@ func TestSetAndGetIPRuleConfig(t *testing.T) {
262270
EnforceIpRules: true,
263271
})
264272
require.NoError(t, err)
265-
require.Equal(t, []checkCall{{groupID: groupID, skipCache: true, skipRuleID: ""}}, enforcer.checkCalls)
273+
require.Equal(t, []call{
274+
{groupID: groupID, invalidation: true},
275+
{groupID: groupID, check: true},
276+
}, enforcer.calls)
266277

267278
cfgRsp, err = svc.GetIPRuleConfig(authCtx, &irpb.GetRulesConfigRequest{
268279
RequestContext: &ctxpb.RequestContext{GroupId: groupID},
@@ -275,7 +286,7 @@ func TestSetAndGetIPRuleConfig(t *testing.T) {
275286
EnforceIpRules: false,
276287
})
277288
require.NoError(t, err)
278-
require.Len(t, enforcer.checkCalls, 1)
289+
require.Len(t, enforcer.calls, 2)
279290

280291
cfgRsp, err = svc.GetIPRuleConfig(authCtx, &irpb.GetRulesConfigRequest{
281292
RequestContext: &ctxpb.RequestContext{GroupId: groupID},
@@ -296,7 +307,10 @@ func TestSetIPRuleConfigRejectsLockout(t *testing.T) {
296307
require.Error(t, err)
297308
require.True(t, status.IsInvalidArgumentError(err))
298309
require.Contains(t, err.Error(), "9.8.7.6")
299-
require.Equal(t, []checkCall{{groupID: groupID, skipCache: true, skipRuleID: ""}}, enforcer.checkCalls)
310+
require.Equal(t, []call{
311+
{groupID: groupID, invalidation: true},
312+
{groupID: groupID, check: true},
313+
}, enforcer.calls)
300314

301315
cfgRsp, err := svc.GetIPRuleConfig(authCtx, &irpb.GetRulesConfigRequest{
302316
RequestContext: &ctxpb.RequestContext{GroupId: groupID},
@@ -330,11 +344,16 @@ func TestDeleteRuleRejectsLockout(t *testing.T) {
330344
require.Error(t, err)
331345
require.True(t, status.IsInvalidArgumentError(err))
332346
require.Contains(t, err.Error(), "1.2.3.4")
333-
require.Equal(t, []checkCall{{
334-
groupID: groupID,
335-
skipCache: true,
336-
skipRuleID: addRsp.GetRule().GetIpRuleId(),
337-
}}, enforcer.checkCalls)
347+
require.Equal(t, []call{
348+
{
349+
groupID: groupID,
350+
invalidation: true,
351+
},
352+
{
353+
groupID: groupID,
354+
check: true,
355+
skipRuleID: addRsp.GetRule().GetIpRuleId(),
356+
}}, enforcer.calls)
338357

339358
_, err = svc.GetRule(authCtx, groupID, addRsp.GetRule().GetIpRuleId())
340359
require.NoError(t, err)

server/interfaces/interfaces.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,10 +1579,12 @@ type IPRulesEnforcer interface {
15791579
// context.
15801580
AuthorizeHTTPRequest(ctx context.Context, r *http.Request) error
15811581

1582-
// Performs an explicit IP rule check for the given group ID with the
1583-
// option to force refresh rules from the backend and skip specific rules
1584-
// (for testing rule changes made by IPRulesService).
1585-
Check(ctx context.Context, groupID string, skipCache bool, skipRuleID string) error
1582+
// Invalidates all cached IP rules for the specified group ID.
1583+
InvalidateCache(ctx context.Context, groupID string)
1584+
1585+
// Performs an explicit IP rule check for the given group ID, skipping the
1586+
// rule with the provided ID, if specified.
1587+
Check(ctx context.Context, groupID string, skipRuleID string) error
15861588
}
15871589

15881590
type IPRulesService interface {

0 commit comments

Comments
 (0)