Skip to content

Commit 3b6c56b

Browse files
committed
Make HTTP implementation thread safe
1 parent 3d8f7ca commit 3b6c56b

12 files changed

Lines changed: 378 additions & 232 deletions

clickhouse.go

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import (
2727

2828
_ "time/tzdata"
2929

30-
chproto "github.com/ClickHouse/ch-go/proto"
3130
"github.com/ClickHouse/clickhouse-go/v2/contributors"
3231
"github.com/ClickHouse/clickhouse-go/v2/lib/column"
3332
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
@@ -84,28 +83,42 @@ func Open(opt *Options) (driver.Conn, error) {
8483
}
8584
o := opt.setDefaults()
8685

87-
if o.Protocol == HTTP {
88-
httpConn, err := dialHttp(context.Background(), opt.Addr[0], 0, opt)
89-
if err != nil {
90-
return nil, err
91-
}
92-
93-
return &clickhouseHTTP{conn: httpConn}, nil
94-
}
95-
9686
conn := &clickhouse{
9787
opt: o,
98-
idle: make(chan *connect, o.MaxIdleConns),
88+
idle: make(chan nativeTransport, o.MaxIdleConns),
9989
open: make(chan struct{}, o.MaxOpenConns),
10090
exit: make(chan struct{}),
10191
}
10292
go conn.startAutoCloseIdleConnections()
10393
return conn, nil
10494
}
10595

96+
// nativeTransport represents an implementation (TCP or HTTP) that can be pooled by the main clickhouse struct.
97+
// Implementations are not expected to be thread safe, which is why we provide acquire/release functions.
98+
type nativeTransport interface {
99+
serverVersion() (*ServerVersion, error)
100+
query(ctx context.Context, release nativeTransportRelease, query string, args ...any) (*rows, error)
101+
queryRow(ctx context.Context, release nativeTransportRelease, query string, args ...any) *row
102+
prepareBatch(ctx context.Context, release nativeTransportRelease, acquire nativeTransportAcquire, query string, opts driver.PrepareBatchOptions) (driver.Batch, error)
103+
exec(ctx context.Context, query string, args ...any) error
104+
asyncInsert(ctx context.Context, query string, wait bool, args ...any) error
105+
ping(context.Context) error
106+
isBad() bool
107+
connID() int
108+
connectedAtTime() time.Time
109+
isReleased() bool
110+
setReleased(released bool)
111+
debugf(format string, v ...any)
112+
// freeBuffer is called if Options.FreeBufOnConnRelease is set
113+
freeBuffer()
114+
close() error
115+
}
116+
type nativeTransportAcquire func(context.Context) (nativeTransport, error)
117+
type nativeTransportRelease func(nativeTransport, error)
118+
106119
type clickhouse struct {
107120
opt *Options
108-
idle chan *connect
121+
idle chan nativeTransport
109122
open chan struct{}
110123
exit chan struct{}
111124
connID int64
@@ -128,16 +141,16 @@ func (ch *clickhouse) ServerVersion() (*driver.ServerVersion, error) {
128141
if err != nil {
129142
return nil, err
130143
}
131-
ch.release(conn, nil)
132-
return &conn.server, nil
144+
defer ch.release(conn, nil)
145+
return conn.serverVersion()
133146
}
134147

135148
func (ch *clickhouse) Query(ctx context.Context, query string, args ...any) (rows driver.Rows, err error) {
136149
conn, err := ch.acquire(ctx)
137150
if err != nil {
138151
return nil, err
139152
}
140-
conn.debugf("[acquired] connection [%d]", conn.id)
153+
conn.debugf("[acquired] connection [%d]", conn.connID())
141154
return conn.query(ctx, ch.release, query, args...)
142155
}
143156

@@ -148,7 +161,7 @@ func (ch *clickhouse) QueryRow(ctx context.Context, query string, args ...any) d
148161
err: err,
149162
}
150163
}
151-
conn.debugf("[acquired] connection [%d]", conn.id)
164+
conn.debugf("[acquired] connection [%d]", conn.connID())
152165
return conn.queryRow(ctx, ch.release, query, args...)
153166
}
154167

@@ -157,7 +170,7 @@ func (ch *clickhouse) Exec(ctx context.Context, query string, args ...any) error
157170
if err != nil {
158171
return err
159172
}
160-
conn.debugf("[acquired] connection [%d]", conn.id)
173+
conn.debugf("[acquired] connection [%d]", conn.connID())
161174

162175
if err := conn.exec(ctx, query, args...); err != nil {
163176
ch.release(conn, err)
@@ -172,7 +185,7 @@ func (ch *clickhouse) PrepareBatch(ctx context.Context, query string, opts ...dr
172185
if err != nil {
173186
return nil, err
174187
}
175-
batch, err := conn.prepareBatch(ctx, query, getPrepareBatchOptions(opts...), ch.release, ch.acquire)
188+
batch, err := conn.prepareBatch(ctx, ch.release, ch.acquire, query, getPrepareBatchOptions(opts...))
176189
if err != nil {
177190
return nil, err
178191
}
@@ -224,11 +237,18 @@ func (ch *clickhouse) Stats() driver.Stats {
224237
}
225238
}
226239

227-
func (ch *clickhouse) dial(ctx context.Context) (conn *connect, err error) {
240+
func (ch *clickhouse) dial(ctx context.Context) (conn nativeTransport, err error) {
228241
connID := int(atomic.AddInt64(&ch.connID, 1))
229242

230243
dialFunc := func(ctx context.Context, addr string, opt *Options) (DialResult, error) {
231-
conn, err := dial(ctx, addr, connID, opt)
244+
var conn nativeTransport
245+
var err error
246+
switch opt.Protocol {
247+
case HTTP:
248+
conn, err = dialHttp(context.Background(), addr, connID, opt)
249+
default:
250+
conn, err = dial(ctx, addr, connID, opt)
251+
}
232252

233253
return DialResult{conn}, err
234254
}
@@ -270,7 +290,7 @@ func DefaultDialStrategy(ctx context.Context, connID int, opt *Options, dial Dia
270290
return r, err
271291
}
272292

273-
func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) {
293+
func (ch *clickhouse) acquire(ctx context.Context) (conn nativeTransport, err error) {
274294
timer := time.NewTimer(ch.opt.DialTimeout)
275295
defer timer.Stop()
276296
select {
@@ -303,7 +323,7 @@ func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) {
303323
return nil, err
304324
}
305325
}
306-
conn.released = false
326+
conn.setReleased(false)
307327
return conn, nil
308328
default:
309329
}
@@ -336,7 +356,7 @@ func (ch *clickhouse) closeIdleExpired() {
336356
for {
337357
select {
338358
case conn := <-ch.idle:
339-
if conn.connectedAt.Before(cutoff) {
359+
if conn.connectedAtTime().Before(cutoff) {
340360
conn.close()
341361
} else {
342362
select {
@@ -352,24 +372,23 @@ func (ch *clickhouse) closeIdleExpired() {
352372
}
353373
}
354374

355-
func (ch *clickhouse) release(conn *connect, err error) {
356-
if conn.released {
375+
func (ch *clickhouse) release(conn nativeTransport, err error) {
376+
if conn.isReleased() {
357377
return
358378
}
359-
conn.released = true
360-
conn.debugf("[released] connection [%d]", conn.id)
379+
conn.setReleased(true)
380+
conn.debugf("[released] connection [%d]", conn.connID())
361381

362382
select {
363383
case <-ch.open:
364384
default:
365385
}
366-
if err != nil || time.Since(conn.connectedAt) >= ch.opt.ConnMaxLifetime {
386+
if err != nil || time.Since(conn.connectedAtTime()) >= ch.opt.ConnMaxLifetime {
367387
conn.close()
368388
return
369389
}
370390
if ch.opt.FreeBufOnConnRelease {
371-
conn.buffer = new(chproto.Buffer)
372-
conn.compressor.Data = nil
391+
conn.freeBuffer()
373392
}
374393
select {
375394
case ch.idle <- conn:

clickhouse_http.go

Lines changed: 0 additions & 83 deletions
This file was deleted.

clickhouse_options.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ func ParseDSN(dsn string) (*Options, error) {
124124

125125
type Dial func(ctx context.Context, addr string, opt *Options) (DialResult, error)
126126
type DialResult struct {
127-
conn *connect
127+
conn nativeTransport
128128
}
129129

130130
type HTTPProxy func(*http.Request) (*url.URL, error)

clickhouse_std.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import (
3434
"syscall"
3535

3636
"github.com/ClickHouse/clickhouse-go/v2/lib/column"
37-
ldriver "github.com/ClickHouse/clickhouse-go/v2/lib/driver"
37+
chdriver "github.com/ClickHouse/clickhouse-go/v2/lib/driver"
3838
)
3939

4040
var globalConnID int64
@@ -196,10 +196,10 @@ func OpenDB(opt *Options) *sql.DB {
196196
type stdConnect interface {
197197
isBad() bool
198198
close() error
199-
query(ctx context.Context, release func(*connect, error), query string, args ...any) (*rows, error)
199+
query(ctx context.Context, release nativeTransportRelease, query string, args ...any) (*rows, error)
200200
exec(ctx context.Context, query string, args ...any) error
201201
ping(ctx context.Context) (err error)
202-
prepareBatch(ctx context.Context, query string, options ldriver.PrepareBatchOptions, release func(*connect, error), acquire func(context.Context) (*connect, error)) (ldriver.Batch, error)
202+
prepareBatch(ctx context.Context, release nativeTransportRelease, acquire nativeTransportAcquire, query string, options chdriver.PrepareBatchOptions) (chdriver.Batch, error)
203203
asyncInsert(ctx context.Context, query string, wait bool, args ...any) error
204204
}
205205

@@ -333,7 +333,7 @@ func (std *stdDriver) QueryContext(ctx context.Context, query string, args []dri
333333
return nil, driver.ErrBadConn
334334
}
335335

336-
r, err := std.conn.query(ctx, func(*connect, error) {}, query, rebind(args)...)
336+
r, err := std.conn.query(ctx, func(nativeTransport, error) {}, query, rebind(args)...)
337337
if isConnBrokenError(err) {
338338
std.debugf("QueryContext got a fatal error, resetting connection: %v\n", err)
339339
return nil, driver.ErrBadConn
@@ -358,7 +358,7 @@ func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver.
358358
return nil, driver.ErrBadConn
359359
}
360360

361-
batch, err := std.conn.prepareBatch(ctx, query, ldriver.PrepareBatchOptions{}, func(*connect, error) {}, func(context.Context) (*connect, error) { return nil, nil })
361+
batch, err := std.conn.prepareBatch(ctx, func(nativeTransport, error) {}, func(context.Context) (nativeTransport, error) { return nil, nil }, query, chdriver.PrepareBatchOptions{})
362362
if err != nil {
363363
if isConnBrokenError(err) {
364364
std.debugf("PrepareContext got a fatal error, resetting connection: %v\n", err)
@@ -387,7 +387,7 @@ func (std *stdDriver) Close() error {
387387
}
388388

389389
type stdBatch struct {
390-
batch ldriver.Batch
390+
batch chdriver.Batch
391391
debugf func(format string, v ...any)
392392
}
393393

conn.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
9696
id: num,
9797
opt: opt,
9898
conn: conn,
99-
debugf: debugf,
99+
debugfFunc: debugf,
100100
buffer: new(chproto.Buffer),
101101
reader: chproto.NewReader(conn),
102102
revision: ClientTCPProtocolVersion,
@@ -144,7 +144,7 @@ type connect struct {
144144
id int
145145
opt *Options
146146
conn net.Conn
147-
debugf func(format string, v ...any)
147+
debugfFunc func(format string, v ...any)
148148
server ServerVersion
149149
closed bool
150150
buffer *chproto.Buffer
@@ -162,6 +162,22 @@ type connect struct {
162162
closeMutex sync.Mutex
163163
}
164164

165+
func (c *connect) debugf(format string, v ...any) {
166+
c.debugfFunc(format, v...)
167+
}
168+
169+
func (c *connect) connID() int {
170+
return c.id
171+
}
172+
173+
func (c *connect) connectedAtTime() time.Time {
174+
return c.connectedAt
175+
}
176+
177+
func (c *connect) serverVersion() (*ServerVersion, error) {
178+
return &c.server, nil
179+
}
180+
165181
func (c *connect) settings(querySettings Settings) []proto.Setting {
166182
settingToProtoSetting := func(k string, v any) proto.Setting {
167183
isCustom := false
@@ -206,6 +222,14 @@ func (c *connect) isBad() bool {
206222
return false
207223
}
208224

225+
func (c *connect) isReleased() bool {
226+
return c.released
227+
}
228+
229+
func (c *connect) setReleased(released bool) {
230+
c.released = released
231+
}
232+
209233
func (c *connect) isClosed() bool {
210234
c.closeMutex.Lock()
211235
defer c.closeMutex.Unlock()
@@ -372,6 +396,11 @@ func (c *connect) readData(ctx context.Context, packet byte, compressible bool)
372396
return &block, nil
373397
}
374398

399+
func (c *connect) freeBuffer() {
400+
c.buffer = new(chproto.Buffer)
401+
c.compressor.Data = nil
402+
}
403+
375404
func (c *connect) flush() error {
376405
if len(c.buffer.Buf) == 0 {
377406
// Nothing to flush.

0 commit comments

Comments
 (0)