Skip to content

Commit 6507084

Browse files
authored
fix: improve client pool concurrency safety and performance (#67)
1 parent ba519ab commit 6507084

10 files changed

+932
-48
lines changed
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
// Licensed to Elasticsearch B.V. under one or more contributor
2+
// license agreements. See the NOTICE file distributed with
3+
// this work for additional information regarding copyright
4+
// ownership. Elasticsearch B.V. licenses this file to you under
5+
// the Apache License, Version 2.0 (the "License"); you may
6+
// not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//go:build !integration
19+
// +build !integration
20+
21+
package elastictransport
22+
23+
import (
24+
"fmt"
25+
"io"
26+
"net/http"
27+
"net/url"
28+
"strings"
29+
"sync"
30+
"testing"
31+
)
32+
33+
// benchTransport returns a minimal response without hitting the network.
34+
// Each call returns a fresh *http.Response so parallel goroutines don't share
35+
// mutable state.
36+
type benchTransport struct{}
37+
38+
func (t *benchTransport) RoundTrip(*http.Request) (*http.Response, error) {
39+
return &http.Response{
40+
StatusCode: 200,
41+
Body: io.NopCloser(strings.NewReader("")),
42+
ContentLength: 0,
43+
Header: http.Header{},
44+
}, nil
45+
}
46+
47+
// benchmarkPerform exercises Client.Perform under varying parallelism factors.
48+
// The sub-benchmarks vary the value passed to b.SetParallelism, which scales
49+
// parallel work relative to GOMAXPROCS rather than creating an exact number
50+
// of goroutines.
51+
func benchmarkPerform(b *testing.B, name string, client *Client) {
52+
b.Helper()
53+
b.Run(name, func(b *testing.B) {
54+
for _, p := range []int{1, 10, 100, 1000} {
55+
b.Run(fmt.Sprintf("parallelism-%d", p), func(b *testing.B) {
56+
b.SetParallelism(p)
57+
b.RunParallel(func(pb *testing.PB) {
58+
for pb.Next() {
59+
req, _ := http.NewRequest("GET", "/", nil)
60+
res, err := client.Perform(req)
61+
if err != nil {
62+
b.Fatalf("Perform: %s", err)
63+
}
64+
if err := res.Body.Close(); err != nil {
65+
b.Fatalf("Close response body: %s", err)
66+
}
67+
}
68+
})
69+
})
70+
}
71+
})
72+
}
73+
74+
func BenchmarkClientPoolAccess(b *testing.B) {
75+
b.ReportAllocs()
76+
77+
b.Run("SingleConnectionPool", func(b *testing.B) {
78+
u, _ := url.Parse("http://localhost:9200")
79+
tp, err := New(Config{
80+
URLs: []*url.URL{u},
81+
Transport: &benchTransport{},
82+
DisableRetry: true,
83+
})
84+
if err != nil {
85+
b.Fatal(err)
86+
}
87+
benchmarkPerform(b, "Perform", tp)
88+
})
89+
90+
b.Run("StatusConnectionPool", func(b *testing.B) {
91+
var urls []*url.URL
92+
for i := 0; i < 10; i++ {
93+
u, _ := url.Parse(fmt.Sprintf("http://node%d:9200", i))
94+
urls = append(urls, u)
95+
}
96+
tp, err := New(Config{
97+
URLs: urls,
98+
Transport: &benchTransport{},
99+
DisableRetry: true,
100+
})
101+
if err != nil {
102+
b.Fatal(err)
103+
}
104+
benchmarkPerform(b, "Perform", tp)
105+
})
106+
107+
b.Run("CustomPool/SynchronizedWrapper", func(b *testing.B) {
108+
var urls []*url.URL
109+
for i := 0; i < 10; i++ {
110+
u, _ := url.Parse(fmt.Sprintf("http://node%d:9200", i))
111+
urls = append(urls, u)
112+
}
113+
tp, err := New(Config{
114+
URLs: urls,
115+
ConnectionPoolFunc: func(conns []*Connection, sel Selector) ConnectionPool {
116+
return &benchCustomPool{conns: conns}
117+
},
118+
Transport: &benchTransport{},
119+
DisableRetry: true,
120+
})
121+
if err != nil {
122+
b.Fatal(err)
123+
}
124+
benchmarkPerform(b, "Perform", tp)
125+
})
126+
127+
b.Run("CustomPool/ConcurrentSafeOptIn", func(b *testing.B) {
128+
var urls []*url.URL
129+
for i := 0; i < 10; i++ {
130+
u, _ := url.Parse(fmt.Sprintf("http://node%d:9200", i))
131+
urls = append(urls, u)
132+
}
133+
tp, err := New(Config{
134+
URLs: urls,
135+
ConnectionPoolFunc: func(conns []*Connection, sel Selector) ConnectionPool {
136+
return &benchConcurrentSafeCustomPool{
137+
benchCustomPool: benchCustomPool{conns: conns},
138+
}
139+
},
140+
Transport: &benchTransport{},
141+
DisableRetry: true,
142+
})
143+
if err != nil {
144+
b.Fatal(err)
145+
}
146+
benchmarkPerform(b, "Perform", tp)
147+
})
148+
}
149+
150+
type benchCustomPool struct {
151+
mu sync.Mutex
152+
conns []*Connection
153+
curr int
154+
}
155+
156+
func (p *benchCustomPool) Next() (*Connection, error) {
157+
p.mu.Lock()
158+
defer p.mu.Unlock()
159+
c := p.conns[p.curr%len(p.conns)]
160+
p.curr++
161+
return c, nil
162+
}
163+
164+
func (p *benchCustomPool) OnSuccess(*Connection) error { return nil }
165+
func (p *benchCustomPool) OnFailure(*Connection) error { return nil }
166+
func (p *benchCustomPool) URLs() []*url.URL {
167+
p.mu.Lock()
168+
defer p.mu.Unlock()
169+
out := make([]*url.URL, len(p.conns))
170+
for i, c := range p.conns {
171+
out[i] = c.URL
172+
}
173+
return out
174+
}
175+
176+
type benchConcurrentSafeCustomPool struct {
177+
benchCustomPool
178+
}
179+
180+
func (p *benchConcurrentSafeCustomPool) ConcurrentSafe() {}

elastictransport/connection.go

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ var (
3939
_ ConnectionPool = (*statusConnectionPool)(nil)
4040
_ UpdatableConnectionPool = (*statusConnectionPool)(nil)
4141
_ CloseableConnectionPool = (*statusConnectionPool)(nil)
42+
_ ConnectionPool = (*synchronizedPool)(nil)
43+
_ CloseableConnectionPool = (*synchronizedPool)(nil)
44+
_ connectionable = (*synchronizedPool)(nil)
45+
_ ConnectionPool = (*synchronizedUpdatablePool)(nil)
46+
_ CloseableConnectionPool = (*synchronizedUpdatablePool)(nil)
47+
_ UpdatableConnectionPool = (*synchronizedUpdatablePool)(nil)
4248
_ Selector = (*roundRobinSelector)(nil)
4349
)
4450

@@ -59,11 +65,107 @@ type UpdatableConnectionPool interface {
5965
Update([]*Connection) error // Update injects newly found nodes in the cluster.
6066
}
6167

68+
// ConcurrentSafeConnectionPool marks a custom pool as safe for concurrent use
69+
// by the transport. Pools returned by ConnectionPoolFunc that implement this
70+
// interface are not wrapped by synchronizedPool.
71+
type ConcurrentSafeConnectionPool interface {
72+
ConnectionPool
73+
ConcurrentSafe()
74+
}
75+
6276
// CloseableConnectionPool defines the interface for the connection pool that can be closed.
6377
type CloseableConnectionPool interface {
6478
Close(context.Context) error
6579
}
6680

81+
// synchronizedPool wraps a ConnectionPool with a mutex to serialize all method
82+
// calls. Used to wrap user-provided pools from ConnectionPoolFunc that may not
83+
// be safe for concurrent use.
84+
//
85+
// Use newSynchronizedPool to construct; it returns a synchronizedUpdatablePool
86+
// when the inner pool implements UpdatableConnectionPool.
87+
type synchronizedPool struct {
88+
mu sync.Mutex
89+
pool ConnectionPool
90+
}
91+
92+
// synchronizedUpdatablePool extends synchronizedPool for inner pools that
93+
// support in-place updates. Discovery will prefer Update() over full pool
94+
// replacement when this interface is present.
95+
type synchronizedUpdatablePool struct {
96+
synchronizedPool
97+
}
98+
99+
// newSynchronizedPool wraps pool in a synchronized adapter by default.
100+
// ConcurrentSafeConnectionPool implementations are returned as-is.
101+
// If pool implements UpdatableConnectionPool the returned value will too, so
102+
// discovery can update the pool in place rather than replacing it.
103+
func newSynchronizedPool(pool ConnectionPool) ConnectionPool {
104+
if _, ok := pool.(ConcurrentSafeConnectionPool); ok {
105+
return pool
106+
}
107+
108+
if _, ok := pool.(UpdatableConnectionPool); ok {
109+
return &synchronizedUpdatablePool{synchronizedPool{pool: pool}}
110+
}
111+
return &synchronizedPool{pool: pool}
112+
}
113+
114+
func (sp *synchronizedPool) Next() (*Connection, error) {
115+
sp.mu.Lock()
116+
defer sp.mu.Unlock()
117+
return sp.pool.Next()
118+
}
119+
120+
func (sp *synchronizedPool) OnSuccess(c *Connection) error {
121+
sp.mu.Lock()
122+
defer sp.mu.Unlock()
123+
return sp.pool.OnSuccess(c)
124+
}
125+
126+
func (sp *synchronizedPool) OnFailure(c *Connection) error {
127+
sp.mu.Lock()
128+
defer sp.mu.Unlock()
129+
return sp.pool.OnFailure(c)
130+
}
131+
132+
func (sp *synchronizedPool) URLs() []*url.URL {
133+
sp.mu.Lock()
134+
defer sp.mu.Unlock()
135+
return sp.pool.URLs()
136+
}
137+
138+
func (sp *synchronizedPool) Close(ctx context.Context) error {
139+
sp.mu.Lock()
140+
cp, ok := sp.pool.(CloseableConnectionPool)
141+
sp.mu.Unlock()
142+
if !ok {
143+
return nil
144+
}
145+
// Avoid holding the wrapper lock during close; in-flight callbacks also take
146+
// this lock and may need to complete before the inner pool can fully close.
147+
return cp.Close(ctx)
148+
}
149+
150+
func (sp *synchronizedPool) connections() []*Connection {
151+
sp.mu.Lock()
152+
defer sp.mu.Unlock()
153+
if cp, ok := sp.pool.(connectionable); ok {
154+
return cp.connections()
155+
}
156+
return nil
157+
}
158+
159+
func (sp *synchronizedUpdatablePool) Update(conns []*Connection) error {
160+
sp.mu.Lock()
161+
defer sp.mu.Unlock()
162+
up, ok := sp.pool.(UpdatableConnectionPool)
163+
if !ok {
164+
return fmt.Errorf("inner pool %T does not implement UpdatableConnectionPool", sp.pool)
165+
}
166+
return up.Update(conns)
167+
}
168+
67169
// Connection represents a connection to a node.
68170
type Connection struct {
69171
sync.Mutex
@@ -240,8 +342,10 @@ func (cp *statusConnectionPool) OnFailure(c *Connection) error {
240342
}
241343

242344
// Update merges the existing live and dead connections with the latest nodes discovered from the cluster.
243-
// ConnectionPool must be locked before calling.
244345
func (cp *statusConnectionPool) Update(connections []*Connection) error {
346+
cp.Lock()
347+
defer cp.Unlock()
348+
245349
if len(connections) == 0 {
246350
return errors.New("no connections provided, connection pool left untouched")
247351
}
@@ -354,7 +458,9 @@ func (cp *statusConnectionPool) isClosed() bool {
354458
}
355459

356460
func (cp *statusConnectionPool) connections() []*Connection {
357-
var conns []*Connection
461+
cp.Lock()
462+
defer cp.Unlock()
463+
conns := make([]*Connection, 0, len(cp.live)+len(cp.dead))
358464
conns = append(conns, cp.live...)
359465
conns = append(conns, cp.dead...)
360466
return conns

0 commit comments

Comments
 (0)