Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deps/go_deps.MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ use_repo(
"org_golang_x_time",
"org_golang_x_tools",
"org_uber_go_atomic",
"org_uber_go_goleak",
)

inject_repo(
Expand Down
3 changes: 3 additions & 0 deletions enterprise/server/ip_rules_enforcer/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ go_library(
"//server/util/db",
"//server/util/log",
"//server/util/lru",
"//server/util/proto",
"//server/util/status",
"@com_github_prometheus_client_golang//prometheus",
],
Expand All @@ -37,9 +38,11 @@ go_test(
"//server/testutil/testenv",
"//server/util/authutil",
"//server/util/clientip",
"//server/util/proto",
"//server/util/status",
"//server/util/testing/flags",
"@com_github_stretchr_testify//require",
"@org_golang_google_grpc//metadata",
"@org_uber_go_goleak//:goleak",
],
)
67 changes: 56 additions & 11 deletions enterprise/server/ip_rules_enforcer/ip_rules_enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"net/http"
"strings"
"sync"
"time"

"github.com/buildbuddy-io/buildbuddy/server/environment"
Expand All @@ -18,6 +19,7 @@ import (
"github.com/buildbuddy-io/buildbuddy/server/util/db"
"github.com/buildbuddy-io/buildbuddy/server/util/log"
"github.com/buildbuddy-io/buildbuddy/server/util/lru"
"github.com/buildbuddy-io/buildbuddy/server/util/proto"
"github.com/buildbuddy-io/buildbuddy/server/util/status"
"github.com/prometheus/client_golang/prometheus"

Expand Down Expand Up @@ -153,25 +155,68 @@ func (p *dbIPRulesProvider) invalidate(ctx context.Context, groupID string) {
p.cache.Remove(groupID)
}

// TODO(iain): halt goroutine on server exit.
func (p *dbIPRulesProvider) startRefresher(env environment.Env) error {
sns := env.GetServerNotificationService()
if sns == nil {
return nil
}
go func() {
for msg := range sns.Subscribe(&snpb.InvalidateIPRulesCache{}) {
ic, ok := msg.(*snpb.InvalidateIPRulesCache)
hc := env.GetHealthChecker()
if hc == nil {
return status.FailedPreconditionError("Missing health checker")
}
stop := make(chan struct{})
done := make(chan struct{})
var shutdownOnce sync.Once
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can make this a bit neater by doing:

closeStop := sync.OnceFunc(func() {close(stop)})

Then you can just call pass that to shutdownRefresher, instead of passing both stop and shutdownOnce. Or you can even just call closeStop() before calling shutdownRefresher, since it's pretty weird to pass in a function just to call once it right away. Then maybe shutdownRefresher should be called waitForShutdown.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, fixed

sub := sns.Subscribe(&snpb.InvalidateIPRulesCache{})
go p.runRefresher(env.GetServerContext(), sub, stop, done)
hc.RegisterShutdownFunction(func(ctx context.Context) error {
return p.shutdownRefresher(ctx, stop, done, &shutdownOnce)
})
return nil
}

// runRefresher listens for cache invalidation messages and refreshes the IP
// rules cache accordingly. It closes the done channel when it exits, and can be
// stopped by another goroutine via the stop channel.
func (p *dbIPRulesProvider) runRefresher(ctx context.Context, sub <-chan proto.Message, stop <-chan struct{}, done chan<- struct{}) {
defer close(done)
for {
select {
case <-stop:
return
case msg, ok := <-sub:
if !ok {
alert.UnexpectedEvent("iprules_invalid_proto_type", "received proto type %T", msg)
continue
}
if err := p.refreshRules(env.GetServerContext(), ic.GetGroupId()); err != nil {
log.Warningf("could not refresh IP rules for group %q: %s", ic.GetGroupId(), err)
return
}
p.handleRefresherMessage(ctx, msg)
}
}()
return nil
}
}

func (p *dbIPRulesProvider) handleRefresherMessage(ctx context.Context, msg proto.Message) {
ic, ok := msg.(*snpb.InvalidateIPRulesCache)
if !ok {
alert.UnexpectedEvent("iprules_invalid_proto_type", "received proto type %T", msg)
return
}
if err := p.refreshRules(ctx, ic.GetGroupId()); err != nil {
log.Warningf("could not refresh IP rules for group %q: %s", ic.GetGroupId(), err)
}
}

// The notification service does not expose an unsubscribe API, so shutdown is a
// two-step handshake: signal the refresher to stop waiting on the subscription
// channel, then wait for the goroutine to confirm it has exited.
func (p *dbIPRulesProvider) shutdownRefresher(ctx context.Context, stop chan struct{}, done <-chan struct{}, shutdownOnce *sync.Once) error {
shutdownOnce.Do(func() {
close(stop)
})
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

type Enforcer struct {
Expand Down
30 changes: 30 additions & 0 deletions enterprise/server/ip_rules_enforcer/ip_rules_enforcer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,26 @@ import (
"github.com/buildbuddy-io/buildbuddy/server/testutil/testenv"
"github.com/buildbuddy-io/buildbuddy/server/util/authutil"
"github.com/buildbuddy-io/buildbuddy/server/util/clientip"
"github.com/buildbuddy-io/buildbuddy/server/util/proto"
"github.com/buildbuddy-io/buildbuddy/server/util/status"
"github.com/buildbuddy-io/buildbuddy/server/util/testing/flags"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/grpc/metadata"
)

type fakeServerNotificationService struct {
ch chan proto.Message
}

func (f *fakeServerNotificationService) Subscribe(msgType proto.Message) <-chan proto.Message {
return f.ch
}

func (f *fakeServerNotificationService) Publish(ctx context.Context, msg proto.Message) error {
return nil
}

func newIPRulesEnforcer(t *testing.T, env environment.Env) *ip_rules_enforcer.Enforcer {
t.Helper()

Expand Down Expand Up @@ -241,3 +255,19 @@ func TestAuthorize_TrustedClientIdentityBypasses(t *testing.T) {
err := irs.Authorize(authCtx)
require.NoError(t, err)
}

func TestRefresherStopsOnShutdown(t *testing.T) {
env := getEnv(t)

// Install a noop server notification service to ensure the refresher runs.
env.SetServerNotificationService(&fakeServerNotificationService{
ch: make(chan proto.Message),
})

// The env starts some goroutines that aren't cleaned up. Ignore them.
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
_ = newIPRulesEnforcer(t, env)

env.GetHealthChecker().Shutdown()
env.GetHealthChecker().WaitForGracefulShutdown()
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ require (
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/goleak v1.3.0
go.uber.org/mock v0.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
Expand Down
Loading