Skip to content

Commit b08ea15

Browse files
authored
feat(bigtable): add connection factory to abstract connection (#13755)
lift and shift of sushanb#5
1 parent 09bb990 commit b08ea15

File tree

3 files changed

+246
-73
lines changed

3 files changed

+246
-73
lines changed

bigtable/internal/transport/connpool.go

Lines changed: 88 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"time"
3232

3333
btpb "cloud.google.com/go/bigtable/apiv2/bigtablepb"
34+
"github.com/googleapis/gax-go/v2"
3435
"go.opentelemetry.io/otel/attribute"
3536
"go.opentelemetry.io/otel/metric"
3637
gtransport "google.golang.org/api/transport/grpc"
@@ -320,7 +321,10 @@ type BigtableChannelPool struct {
320321
appProfile string
321322
instanceName string
322323
featureFlagsMD metadata.MD
323-
meterProvider metric.MeterProvider
324+
325+
factory *connectionFactory // Use the factory for connection creation
326+
327+
meterProvider metric.MeterProvider
324328
// configs
325329
metricsConfig btopt.MetricsReporterConfig
326330

@@ -367,6 +371,15 @@ func NewBigtableChannelPool(ctx context.Context, connPoolSize int, strategy btop
367371
opt(pool)
368372
}
369373

374+
// Initialize the connectionFactory
375+
pool.factory = &connectionFactory{
376+
dial: dial,
377+
instanceName: pool.instanceName,
378+
appProfile: pool.appProfile,
379+
featureFlagsMD: pool.featureFlagsMD,
380+
logger: pool.logger,
381+
}
382+
370383
// Set the selection function based on the strategy
371384
switch strategy {
372385
case btopt.LeastInFlight:
@@ -379,6 +392,7 @@ func NewBigtableChannelPool(ctx context.Context, connPoolSize int, strategy btop
379392

380393
var exitSignal error
381394

395+
// TODO: Replace this logic with addConnections(...).
382396
initialConns := make([]*connEntry, connPoolSize)
383397
for i := 0; i < connPoolSize; i++ {
384398
select {
@@ -391,21 +405,12 @@ func NewBigtableChannelPool(ctx context.Context, connPoolSize int, strategy btop
391405
break
392406
}
393407

394-
conn, err := dial()
408+
entry, err := pool.factory.newEntry(ctx)
395409
if err != nil {
396410
exitSignal = err
397411
break
398412
}
399-
400-
entry := &connEntry{conn: conn}
401-
initialConns[i] = entry // Note, we keep non primed conns in conns
402-
// Prime the new connection in a non-blocking goroutine to warm it up.
403-
go func(e *connEntry) {
404-
err := e.conn.Prime(ctx, pool.instanceName, pool.appProfile, pool.featureFlagsMD)
405-
if err != nil {
406-
btopt.Debugf(pool.logger, "bigtable_connpool: failed to prime initial connection: %v\n", err)
407-
}
408-
}(entry)
413+
initialConns[i] = entry
409414
}
410415
if exitSignal != nil {
411416
btopt.Debugf(pool.logger, "bigtable_connpool: error during initial connection creation: %v\n", exitSignal)
@@ -530,25 +535,13 @@ func (p *BigtableChannelPool) replaceConnection(oldEntry *connEntry) {
530535
return
531536
default:
532537
}
533-
newConn, err := p.dial()
538+
newEntry, err := p.factory.newEntry(p.poolCtx)
534539
if err != nil {
535-
btopt.Debugf(p.logger, "bigtable_connpool: Failed to redial connection at index %d: %v\n", idx, err)
540+
btopt.Debugf(p.logger, "bigtable_connpool: Failed to replace connection at index %d: %v. Closing new conn. Old connection remains (draining).\n", idx, err)
536541
return
537542
}
538543

539-
err = newConn.Prime(p.poolCtx, p.instanceName, p.appProfile, p.featureFlagsMD)
540-
541-
if err != nil {
542-
btopt.Debugf(p.logger, "bigtable_connpool: Failed to prime replacement connection at index %d: %v. Closing new conn. Old connection remains (draining).\n", idx, err)
543-
newConn.Close() //
544-
return // Abort
545-
}
546-
547544
btopt.Debugf(p.logger, "bigtable_connpool: Successfully primed new connection. Replacing connection at index %d\n", idx)
548-
newEntry := &connEntry{
549-
conn: newConn,
550-
}
551-
552545
// Copy-on-write
553546
newConns := make([]*connEntry, len(currentConns))
554547
copy(newConns, currentConns)
@@ -785,20 +778,13 @@ func (p *BigtableChannelPool) addConnections(increaseDelta, maxConns int) bool {
785778
default:
786779
}
787780

788-
conn, err := p.dial()
789-
if err != nil {
790-
btopt.Debugf(p.logger, "bigtable_connpool: Failed to dial new connection for scale up: %v\n", err)
791-
return
792-
}
793-
794-
err = conn.Prime(p.poolCtx, p.instanceName, p.appProfile, p.featureFlagsMD)
781+
entry, err := p.factory.newEntry(p.poolCtx)
795782
if err != nil {
796-
btopt.Debugf(p.logger, "bigtable_connpool: Failed to prime new connection: %v. Connection will not be added.\n", err)
797-
conn.Close()
783+
btopt.Debugf(p.logger, "bigtable_connpool: Failed to add new connection: %v. Connection will not be added.\n", err)
798784
return
799785
}
800786

801-
results <- &connEntry{conn: conn}
787+
results <- entry
802788
}()
803789
}
804790
// Goroutine to close the results channel once all workers are done.
@@ -904,6 +890,72 @@ func (p *BigtableChannelPool) removeConnections(decreaseDelta, minConns, maxRemo
904890

905891
}
906892

893+
// connectionFactory is responsible for creating and priming new Bigtable connections.
894+
// TODO remove these members from BigtableConnPool struct
895+
type connectionFactory struct {
896+
dial func() (*BigtableConn, error)
897+
instanceName string
898+
appProfile string
899+
featureFlagsMD metadata.MD
900+
logger *log.Logger
901+
}
902+
903+
// newEntry creates a new connection, primes it, and returns it as a connEntry.
904+
// Blocks until the connection is successfully primed, or returns an error.
905+
func (cf *connectionFactory) newEntry(ctx context.Context) (*connEntry, error) {
906+
conn, err := cf.dial()
907+
if err != nil {
908+
return nil, fmt.Errorf("factory dial failed: %w", err)
909+
}
910+
911+
if err := cf.primeWithRetry(ctx, conn); err != nil {
912+
conn.Close()
913+
return nil, fmt.Errorf("bigtable_connpool: connection factory prime failed: %w", err)
914+
}
915+
916+
return &connEntry{conn: conn}, nil
917+
}
918+
919+
// primeWithRetry attempts to prime the connection, retrying with exponential backoff.
920+
func (cf *connectionFactory) primeWithRetry(ctx context.Context, conn *BigtableConn) error {
921+
backoffPolicy := gax.Backoff{
922+
Initial: 100 * time.Millisecond,
923+
Max: 2 * time.Second,
924+
Multiplier: 1.2,
925+
}
926+
maxAttempts := 3
927+
var lastErr error
928+
for attempt := 0; attempt < maxAttempts; attempt++ {
929+
930+
// ctx.Done() returns a error
931+
if err := ctx.Err(); err != nil {
932+
return fmt.Errorf("bigtable_connpool: error before prime attempt %d: %w", attempt, err)
933+
}
934+
935+
lastErr = conn.Prime(ctx, cf.instanceName, cf.appProfile, cf.featureFlagsMD)
936+
if lastErr == nil {
937+
return nil
938+
}
939+
940+
if attempt == maxAttempts-1 {
941+
// no need to pause(), short circuit
942+
break
943+
}
944+
945+
pause := backoffPolicy.Pause()
946+
btopt.Debugf(cf.logger, "bigtable_connpool: Prime failed with error on attempt %d, retrying in %v: %v", attempt+1, pause, lastErr)
947+
948+
select {
949+
case <-ctx.Done():
950+
return fmt.Errorf("context done while backing off for prime: %w", ctx.Err())
951+
case <-time.After(pause):
952+
}
953+
}
954+
955+
return fmt.Errorf("factory prime failed after %d attempts: %w", maxAttempts, lastErr)
956+
957+
}
958+
907959
type multiError []error
908960

909961
func (m multiError) Error() string {

bigtable/internal/transport/connpool_helper_test.go

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,36 +41,29 @@ type fakeService struct {
4141
delay time.Duration // To simulate work
4242
serverErr error // Error to return from server
4343
lastPingAndWarmMetadata metadata.MD // Stores metadata from the last PingAndWarm call
44-
pingErr error // Error to return from PingAndWarm
45-
pingErrMu sync.Mutex // Protects pingErr
44+
pingErrs []error // Errors to return from PingAndWarm
4645
streamRecvErr error // Error to return from stream.Recv()
4746
streamSendErr error // Error to return from stream.Send()
4847
}
4948

50-
func (s *fakeService) setPingErr(err error) {
51-
s.pingErrMu.Lock()
52-
defer s.pingErrMu.Unlock()
53-
s.pingErr = err
49+
func (s *fakeService) setPingErr(errs ...error) {
50+
s.mu.Lock()
51+
defer s.mu.Unlock()
52+
s.pingErrs = errs
5453
}
5554

5655
func (s *fakeService) setDelay(duration time.Duration) {
57-
s.pingErrMu.Lock()
58-
defer s.pingErrMu.Unlock()
56+
s.mu.Lock()
57+
defer s.mu.Unlock()
5958
s.delay = duration
6059
}
6160

6261
func (s *fakeService) getDelay() time.Duration {
63-
s.pingErrMu.Lock()
64-
defer s.pingErrMu.Unlock()
62+
s.mu.Lock()
63+
defer s.mu.Unlock()
6564
return s.delay
6665
}
6766

68-
func (s *fakeService) getPingErr() error {
69-
s.pingErrMu.Lock()
70-
defer s.pingErrMu.Unlock()
71-
return s.pingErr
72-
}
73-
7467
func (s *fakeService) UnaryCall(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
7568
s.mu.Lock()
7669
s.callCount++
@@ -84,6 +77,12 @@ func (s *fakeService) UnaryCall(ctx context.Context, req *testpb.SimpleRequest)
8477
return &testpb.SimpleResponse{Payload: req.GetPayload()}, nil
8578
}
8679

80+
func (f *fakeService) getPingCallCount() int {
81+
f.mu.Lock()
82+
defer f.mu.Unlock()
83+
return f.pingCount
84+
}
85+
8786
func (s *fakeService) StreamingCall(stream testpb.BenchmarkService_StreamingCallServer) error {
8887
s.mu.Lock()
8988
s.callCount++
@@ -124,7 +123,7 @@ func (f *fakeService) reset() {
124123
f.callCount = 0
125124
f.pingCount = 0
126125
f.serverErr = nil
127-
f.pingErr = nil
126+
f.pingErrs = nil
128127
f.delay = 0
129128
f.lastPingAndWarmMetadata = nil
130129
if f.streamSema != nil {
@@ -138,30 +137,36 @@ func (f *fakeService) reset() {
138137

139138
func (s *fakeService) PingAndWarm(ctx context.Context, req *btpb.PingAndWarmRequest) (*btpb.PingAndWarmResponse, error) {
140139
s.mu.Lock()
140+
callNum := s.pingCount
141141
s.pingCount++
142-
defer s.mu.Unlock()
143142

144-
// Capture metadata
145-
if md, ok := metadata.FromIncomingContext(ctx); ok {
146-
s.lastPingAndWarmMetadata = md.Copy()
143+
var err error
144+
if len(s.pingErrs) > 0 {
145+
if callNum < len(s.pingErrs) {
146+
err = s.pingErrs[callNum]
147+
} else {
148+
// If callCount exceeds provided errors, use the last one for subsequent calls
149+
err = s.pingErrs[len(s.pingErrs)-1]
150+
}
147151
}
148152

149-
delay := s.getDelay()
150-
153+
delay := s.delay
154+
// Capture metadata on the first call, assuming headers are constant
155+
if callNum == 0 {
156+
s.lastPingAndWarmMetadata, _ = metadata.FromIncomingContext(ctx)
157+
}
158+
s.mu.Unlock()
151159
if delay > 0 {
152160
select {
153161
case <-time.After(delay):
154162
case <-ctx.Done():
155163
return nil, ctx.Err()
156164
}
157165
}
158-
159-
if err := ctx.Err(); err != nil {
160-
return nil, err
161-
}
162-
if err := s.getPingErr(); err != nil {
166+
if err != nil {
163167
return nil, err
164168
}
169+
165170
return &btpb.PingAndWarmResponse{}, nil
166171
}
167172

0 commit comments

Comments
 (0)