Skip to content

Commit 42925ab

Browse files
committed
revert back loadStore.Stop() to accept context
1 parent 609f505 commit 42925ab

File tree

8 files changed

+74
-51
lines changed

8 files changed

+74
-51
lines changed

xds/internal/balancer/clusterimpl/clusterimpl.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
package clusterimpl
2525

2626
import (
27+
"context"
2728
"encoding/json"
2829
"fmt"
2930
"sync"
@@ -99,7 +100,7 @@ type clusterImplBalancer struct {
99100
// The following fields are only accessed from balancer API methods, which
100101
// are guaranteed to be called serially by gRPC.
101102
xdsClient xdsclient.XDSClient // Sent down in ResolverState attributes.
102-
cancelLoadReport func(time.Duration) // To stop reporting load through the above xDS client.
103+
cancelLoadReport func(context.Context) // To stop reporting load through the above xDS client.
103104
edsServiceName string // EDS service name to report load for.
104105
lrsServer *bootstrap.ServerConfig // Load reporting server configuration.
105106
dropCategories []DropConfig // The categories for drops.
@@ -221,7 +222,9 @@ func (b *clusterImplBalancer) updateLoadStore(newConfig *LBConfig) error {
221222

222223
if stopOldLoadReport {
223224
if b.cancelLoadReport != nil {
224-
b.cancelLoadReport(loadStoreStopTimeout)
225+
stopCtx, stopCancel := context.WithTimeout(context.Background(), loadStoreStopTimeout)
226+
defer stopCancel()
227+
b.cancelLoadReport(stopCtx)
225228
b.cancelLoadReport = nil
226229
if !startNewLoadReport {
227230
// If a new LRS stream will be started later, no need to update
@@ -347,7 +350,9 @@ func (b *clusterImplBalancer) Close() {
347350
b.childState = balancer.State{}
348351

349352
if b.cancelLoadReport != nil {
350-
b.cancelLoadReport(loadStoreStopTimeout)
353+
stopCtx, stopCancel := context.WithTimeout(context.Background(), loadStoreStopTimeout)
354+
defer stopCancel()
355+
b.cancelLoadReport(stopCtx)
351356
b.cancelLoadReport = nil
352357
}
353358
b.logger.Infof("Shutdown")

xds/internal/clients/lrsclient/load_store.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package lrsclient
2020

2121
import (
22+
"context"
2223
"sync"
2324
"sync/atomic"
2425
"time"
@@ -36,7 +37,7 @@ import (
3637
// It is safe for concurrent use.
3738
type LoadStore struct {
3839
// stop is the function to call to Stop the LoadStore reporting.
39-
stop func(timeout time.Duration)
40+
stop func(ctx context.Context)
4041

4142
// mu only protects the map (2 layers). The read/write to
4243
// *PerClusterReporter doesn't need to hold the mu.
@@ -66,13 +67,13 @@ func newLoadStore() *LoadStore {
6667
// Stop signals the LoadStore to stop reporting.
6768
//
6869
// Before closing the underlying LRS stream, this method may block until a
69-
// final load report send attempt completes or the provided timeout duration
70+
// final load report send attempt completes or the provided context `ctx`
7071
// expires.
7172
//
72-
// The `timeout` duration should be set to prevent Stop from blocking
73-
// indefinitely if the final send attempt fails to complete.
74-
func (ls *LoadStore) Stop(timeout time.Duration) {
75-
ls.stop(timeout)
73+
// The provided context must have a deadline or timeout set to prevent Stop
74+
// from blocking indefinitely if the final send attempt fails to complete.
75+
func (ls *LoadStore) Stop(ctx context.Context) {
76+
ls.stop(ctx)
7677
}
7778

7879
// ReporterForCluster returns the PerClusterReporter for the given cluster and

xds/internal/clients/lrsclient/loadreport_test.go

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
147147
if err != nil {
148148
t.Fatalf("client.ReportLoad() failed: %v", err)
149149
}
150-
defer loadStore1.Stop(defaultTestShortTimeout)
150+
ssCtx, ssCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
151+
defer ssCancel()
152+
defer loadStore1.Stop(ssCtx)
151153

152154
// Call the load reporting API to report load to the first management
153155
// server, and ensure that a connection to the server is created.
@@ -232,7 +234,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
232234
}
233235

234236
// Stop this load reporting stream, server should see error canceled.
235-
loadStore2.Stop(defaultTestShortTimeout)
237+
ssCtx, ssCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
238+
defer ssCancel()
239+
loadStore2.Stop(ssCtx)
236240

237241
// Server should receive a stream canceled error. There may be additional
238242
// load reports from the client in the channel.
@@ -419,15 +423,19 @@ func (s) TestReportLoad_StreamCreation(t *testing.T) {
419423

420424
// Cancel the first load reporting call, and ensure that the stream does not
421425
// close (because we have another call open).
422-
loadStore1.Stop(defaultTestShortTimeout)
426+
ssCtx, ssCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
427+
defer ssCancel()
428+
loadStore1.Stop(ssCtx)
423429
sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
424430
defer sCancel()
425431
if _, err := lrsServer.LRSStreamCloseChan.Receive(sCtx); err != context.DeadlineExceeded {
426432
t.Fatal("LRS stream closed when expected to stay open")
427433
}
428434

429435
// Stop the second load reporting call, and ensure the stream is closed.
430-
loadStore2.Stop(defaultTestShortTimeout)
436+
ssCtx, ssCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
437+
defer ssCancel()
438+
loadStore2.Stop(ssCtx)
431439
if _, err := lrsServer.LRSStreamCloseChan.Receive(ctx); err != nil {
432440
t.Fatal("Timeout waiting for LRS stream to close")
433441
}
@@ -442,16 +450,18 @@ func (s) TestReportLoad_StreamCreation(t *testing.T) {
442450
if _, err := lrsServer.LRSStreamOpenChan.Receive(ctx); err != nil {
443451
t.Fatalf("Timeout when waiting for LRS stream to be created: %v", err)
444452
}
445-
loadStore3.Stop(defaultTestShortTimeout)
453+
ssCtx, ssCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
454+
defer ssCancel()
455+
loadStore3.Stop(ssCtx)
446456
}
447457

448-
// TestReportLoad_StopWithTimeout tests the behavior of LoadStore.Stop() when
449-
// called with a timeout duration. It verifies that:
450-
// - Stop() blocks until the timeout expires or final load send attempt is
458+
// TestReportLoad_StopWithContext tests the behavior of LoadStore.Stop() when
459+
// called with a context. It verifies that:
460+
// - Stop() blocks until the context expires or final load send attempt is
451461
// made.
452462
// - Final load report is seen on the server after stop is called.
453463
// - The stream is closed after Stop() returns.
454-
func (s) TestReportLoad_StopWithTimeout(t *testing.T) {
464+
func (s) TestReportLoad_StopWithContext(t *testing.T) {
455465
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
456466
defer cancel()
457467

@@ -536,11 +546,11 @@ func (s) TestReportLoad_StopWithTimeout(t *testing.T) {
536546
t.Fatalf("Unexpected diff in LRS request (-got, +want):\n%s", diff)
537547
}
538548

539-
// Create a timeout duration for Stop() that remains until the end of test
540-
// to ensure that only possibility of Stop() to finish is if final load
541-
// send attempt is made. If final load attempt is not made, test itself
542-
// will timeout.
543-
largeStopTimeout := 10 * defaultTestTimeout
549+
// Create a context for Stop() that remains until the end of test to ensure
550+
// that only possibility of Stop()s to finish is if final load send attempt
551+
// is made. If final load attempt is not made, test will timeout.
552+
stopCtx, stopCancel := context.WithCancel(ctx)
553+
defer stopCancel()
544554

545555
// Push more loads.
546556
loadStore.ReporterForCluster("cluster2", "eds2").CallDropped("test")
@@ -549,7 +559,7 @@ func (s) TestReportLoad_StopWithTimeout(t *testing.T) {
549559
// Call Stop in a separate goroutine. It will block until
550560
// final load send attempt is made.
551561
go func() {
552-
loadStore.Stop(largeStopTimeout)
562+
loadStore.Stop(stopCtx)
553563
close(stopCalled)
554564
}()
555565
<-stopCalled

xds/internal/clients/lrsclient/lrsclient.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
package lrsclient
2323

2424
import (
25+
"context"
2526
"errors"
2627
"fmt"
2728
"sync"
@@ -140,8 +141,8 @@ func (c *LRSClient) getOrCreateLRSStream(serverIdentifier clients.ServerIdentifi
140141
// the LRS stream when the last reference is removed and closes the
141142
// transport and removes the lrs stream and its references from the
142143
// respective maps. Before closing the stream, it waits for the provided
143-
// timeout duration for the final load report attempt to complete.
144-
stop := func(timeout time.Duration) {
144+
// context to be done (timeout or cancellation).
145+
stop := func(ctx context.Context) {
145146
c.mu.Lock()
146147
defer c.mu.Unlock()
147148

@@ -156,16 +157,13 @@ func (c *LRSClient) getOrCreateLRSStream(serverIdentifier clients.ServerIdentifi
156157

157158
lrs.finalSendRequest <- struct{}{}
158159

159-
timer := time.NewTimer(timeout)
160-
defer timer.Stop()
161-
162160
select {
163161
case err := <-lrs.finalSendDone:
164162
if err != nil {
165163
c.logger.Warningf("Final send attempt failed: %v", err)
166164
}
167-
case <-timer.C:
168-
c.logger.Warningf("Timed out before finishing the final send attempt: %v", err)
165+
case <-ctx.Done():
166+
c.logger.Warningf("Context canceled before finishing the final send attempt: %v", err)
169167
}
170168

171169
lrs.cancelStream()

xds/internal/testutils/fakeclient/client.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ package fakeclient
2121

2222
import (
2323
"context"
24-
"time"
2524

2625
"google.golang.org/grpc/internal/testutils"
2726
"google.golang.org/grpc/internal/xds/bootstrap"
@@ -81,14 +80,14 @@ func (*stream) Recv() ([]byte, error) {
8180
}
8281

8382
// ReportLoad starts reporting load about clusterName to server.
84-
func (xdsC *Client) ReportLoad(server *bootstrap.ServerConfig) (loadStore *lrsclient.LoadStore, cancel func(time.Duration)) {
83+
func (xdsC *Client) ReportLoad(server *bootstrap.ServerConfig) (loadStore *lrsclient.LoadStore, cancel func(context.Context)) {
8584
lrsClient, _ := lrsclient.New(lrsclient.Config{Node: clients.Node{ID: "fake-node-id"}, TransportBuilder: &transportBuilder{}})
8685
xdsC.loadStore, _ = lrsClient.ReportLoad(clients.ServerIdentifier{ServerURI: server.ServerURI()})
8786

8887
xdsC.loadReportCh.Send(ReportLoadArgs{Server: server})
8988

90-
return xdsC.loadStore, func(timeout time.Duration) {
91-
xdsC.loadStore.Stop(timeout)
89+
return xdsC.loadStore, func(ctx context.Context) {
90+
xdsC.loadStore.Stop(ctx)
9291
xdsC.lrsCancelCh.Send(nil)
9392
}
9493
}

xds/internal/xdsclient/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
package xdsclient
2222

2323
import (
24-
"time"
24+
"context"
2525

2626
v3statuspb "github.com/envoyproxy/go-control-plane/envoy/service/status/v3"
2727
"google.golang.org/grpc/internal/xds/bootstrap"
@@ -49,7 +49,7 @@ type XDSClient interface {
4949
// the watcher is canceled. Callers need to handle this case.
5050
WatchResource(rType xdsresource.Type, resourceName string, watcher xdsresource.ResourceWatcher) (cancel func())
5151

52-
ReportLoad(*bootstrap.ServerConfig) (*lrsclient.LoadStore, func(time.Duration))
52+
ReportLoad(*bootstrap.ServerConfig) (*lrsclient.LoadStore, func(context.Context))
5353

5454
BootstrapConfig() *bootstrap.Config
5555
}

xds/internal/xdsclient/clientimpl_loadreport.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
package xdsclient
1919

2020
import (
21+
"context"
2122
"sync"
22-
"time"
2323

2424
"google.golang.org/grpc/internal/xds/bootstrap"
2525
"google.golang.org/grpc/xds/internal/clients"
@@ -31,26 +31,24 @@ import (
3131
// reports to the same server share the LRS stream.
3232
//
3333
// It returns a lrsclient.LoadStore for the user to report loads.
34-
func (c *clientImpl) ReportLoad(server *bootstrap.ServerConfig) (*lrsclient.LoadStore, func(time.Duration)) {
34+
func (c *clientImpl) ReportLoad(server *bootstrap.ServerConfig) (*lrsclient.LoadStore, func(context.Context)) {
3535
if c.lrsClient == nil {
3636
lrsConfig := lrsclient.Config{Node: c.gConfig.Node, TransportBuilder: c.gConfig.TransportBuilder}
3737
lrsC, err := lrsclient.New(lrsConfig)
3838
if err != nil {
3939
c.logger.Warningf("Failed to create an lrs client to the management server to report load: %v", server, err)
40-
return nil, func(time.Duration) {}
40+
return nil, func(context.Context) {}
4141
}
4242
c.lrsClient = lrsC
4343
}
4444

4545
load, err := c.lrsClient.ReportLoad(clients.ServerIdentifier{ServerURI: server.ServerURI(), Extensions: grpctransport.ServerIdentifierExtension{ConfigName: server.SelectedCreds().Type}})
4646
if err != nil {
4747
c.logger.Warningf("Failed to create a load store to the management server to report load: %v", server, err)
48-
return nil, func(time.Duration) {}
48+
return nil, func(context.Context) {}
4949
}
5050
var loadStop sync.Once
51-
return load, func(timeout time.Duration) {
52-
loadStop.Do(func() {
53-
load.Stop(timeout)
54-
})
51+
return load, func(ctx context.Context) {
52+
loadStop.Do(func() { load.Stop(ctx) })
5553
}
5654
}

xds/internal/xdsclient/tests/loadreport_test.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
144144
// Call the load reporting API to report load to the first management
145145
// server, and ensure that a connection to the server is created.
146146
store1, lrsCancel1 := client.ReportLoad(serverCfg1)
147-
defer lrsCancel1(defaultTestShortTimeout)
147+
sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
148+
defer sCancel()
149+
defer lrsCancel1(sCtx)
148150
if _, err := newConnChan1.Receive(ctx); err != nil {
149151
t.Fatal("Timeout when waiting for a connection to the first management server, after starting load reporting")
150152
}
@@ -159,7 +161,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
159161
// Call the load reporting API to report load to the second management
160162
// server, and ensure that a connection to the server is created.
161163
store2, lrsCancel2 := client.ReportLoad(serverCfg2)
162-
defer lrsCancel2(defaultTestShortTimeout)
164+
sCtx2, sCancel2 := context.WithTimeout(ctx, defaultTestShortTimeout)
165+
defer sCancel2()
166+
defer lrsCancel2(sCtx2)
163167
if _, err := newConnChan2.Receive(ctx); err != nil {
164168
t.Fatal("Timeout when waiting for a connection to the second management server, after starting load reporting")
165169
}
@@ -227,7 +231,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
227231
}
228232

229233
// Cancel this load reporting stream, server should see error canceled.
230-
lrsCancel2(defaultTestShortTimeout)
234+
sCtx2, sCancel2 = context.WithTimeout(ctx, defaultTestShortTimeout)
235+
defer sCancel2()
236+
lrsCancel2(sCtx2)
231237

232238
// Server should receive a stream canceled error. There may be additional
233239
// load reports from the client in the channel.
@@ -403,15 +409,19 @@ func (s) TestReportLoad_StreamCreation(t *testing.T) {
403409

404410
// Cancel the first load reporting call, and ensure that the stream does not
405411
// close (because we have another call open).
406-
cancel1(defaultTestShortTimeout)
412+
sCtx1, sCancel1 := context.WithTimeout(ctx, defaultTestShortTimeout)
413+
defer sCancel1()
414+
cancel1(sCtx1)
407415
sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
408416
defer sCancel()
409417
if _, err := lrsServer.LRSStreamCloseChan.Receive(sCtx); err != context.DeadlineExceeded {
410418
t.Fatal("LRS stream closed when expected to stay open")
411419
}
412420

413421
// Cancel the second load reporting call, and ensure the stream is closed.
414-
cancel2(defaultTestShortTimeout)
422+
sCtx2, sCancel2 := context.WithTimeout(ctx, defaultTestShortTimeout)
423+
defer sCancel2()
424+
cancel2(sCtx2)
415425
if _, err := lrsServer.LRSStreamCloseChan.Receive(ctx); err != nil {
416426
t.Fatal("Timeout waiting for LRS stream to close")
417427
}
@@ -423,5 +433,7 @@ func (s) TestReportLoad_StreamCreation(t *testing.T) {
423433
if _, err := lrsServer.LRSStreamOpenChan.Receive(ctx); err != nil {
424434
t.Fatalf("Timeout when waiting for LRS stream to be created: %v", err)
425435
}
426-
cancel3(defaultTestShortTimeout)
436+
sCtx3, sCancel3 := context.WithTimeout(ctx, defaultTestShortTimeout)
437+
defer sCancel3()
438+
cancel3(sCtx3)
427439
}

0 commit comments

Comments
 (0)