diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index c18cf71..1b55b8a 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -152,6 +152,11 @@ type colOids*: seq[int32] ## Per-column type OIDs for RowData lruNode*: DoublyLinkedNode[string] ## Embedded LRU list node + PgPoolOwner* = ref object of RootObj + ## Opaque base for pool-ownership back-references on `PgConnection`. + ## The concrete type is `PgPool` (defined in `pg_pool`); this base lives + ## here to avoid a circular import. Consumers should not subclass this. + PgConnection* = ref object ## A single PostgreSQL connection with buffered I/O and statement caching. when hasChronos: @@ -198,6 +203,11 @@ type hstoreOid: int32 ## Dynamic OID for hstore extension type; 0 if not available hstoreArrayOid: int32 ## Dynamic OID for hstore[] array; 0 if not available tracer: PgTracer ## Inherited from ConnConfig on connect + ownerPool*: PgPoolOwner + ## Owning pool back-reference. Set when this connection is managed by + ## a `PgPool` (or a pool inside `PgPoolCluster`); `nil` for standalone + ## connections created via `connect`. Used by `release(conn)` to route + ## the connection back to the correct pool. QueryResult* = object ## Result of a query: field descriptions, row data, and command tag. diff --git a/async_postgres/pg_pool.nim b/async_postgres/pg_pool.nim index 0fe7b4e..c49cbce 100644 --- a/async_postgres/pg_pool.nim +++ b/async_postgres/pg_pool.nim @@ -70,7 +70,8 @@ type execFut: Future[CommandResult] ## Non-nil for popExec queryFut: Future[QueryResult] ## Non-nil for popQuery - PgPool* = ref object ## Connection pool that manages a set of PostgreSQL connections. + PgPool* = ref object of PgPoolOwner + ## Connection pool that manages a set of PostgreSQL connections. config: PoolConfig idle: Deque[PooledConn] active: int @@ -274,6 +275,7 @@ proc maintenanceLoop(pool: PgPool) {.async.} = break try: let conn = await connect(pool.config.connConfig).wait(replenishTimeout) + conn.ownerPool = pool pool.metrics.createCount.inc pool.idle.addLast(PooledConn(conn: conn, lastUsedAt: now)) except CatchableError: @@ -301,6 +303,7 @@ proc newPool*(config: PoolConfig): Future[PgPool] {.async.} = pool.cachedNow = Moment.now() for i in 0 ..< cfg.minSize: let conn = await connect(cfg.connConfig) + conn.ownerPool = pool pool.metrics.createCount.inc pool.idle.addLast(PooledConn(conn: conn, lastUsedAt: pool.cachedNow)) except CatchableError as e: @@ -312,10 +315,11 @@ proc newPool*(config: PoolConfig): Future[PgPool] {.async.} = pool.maintenanceTask = maintenanceLoop(pool) return pool -proc release*(pool: PgPool, conn: PgConnection) = - ## Return a connection to the pool. If the connection is broken or in a - ## transaction, it is closed instead. If waiters are queued, the connection - ## is handed directly to the next waiter. +proc releaseImpl(pool: PgPool, conn: PgConnection) = + ## Implementation of `release(conn)`; called once the owning pool is known. + ## Returns the connection to the pool. If the connection is broken or in + ## a transaction, it is closed instead. If waiters are queued, the + ## connection is handed directly to the next waiter. ## ## Discard criteria (`conn.state != csReady`): ## - A timed-out request reaches us via `invalidateOnTimeout` with @@ -358,6 +362,26 @@ proc release*(pool: PgPool, conn: PgConnection) = TracePoolReleaseEndData(wasClosed: wasClosed, handedToWaiter: handedToWaiter), ) +proc release*(conn: PgConnection) = + ## Return a connection to its owning pool. If the connection is broken or + ## in a transaction, it is closed instead; if waiters are queued, it is + ## handed directly to the next waiter. + ## + ## The owning pool is tracked on `conn.ownerPool`, set automatically when + ## the connection is acquired from a `PgPool` (including pools inside a + ## `PgPoolCluster`). For standalone connections created with `connect` + ## this field is `nil` and calling `release` raises `PgError` — use + ## `conn.close()` instead. + ## + ## `withConnection`, `withReadConnection`, `withWriteConnection`, + ## `withPipeline`, and `withTransaction` call this automatically; direct + ## callers only need it when they manage `acquire`/`release` manually. + if conn.ownerPool == nil: + raise newException( + PgError, "release() called on a standalone connection; use conn.close() instead" + ) + PgPool(conn.ownerPool).releaseImpl(conn) + type AcquireResult = tuple[conn: PgConnection, wasCreated: bool] proc acquireImpl(pool: PgPool): Future[AcquireResult] {.async.} = @@ -402,6 +426,7 @@ proc acquireImpl(pool: PgPool): Future[AcquireResult] {.async.} = pool.active.inc try: let conn = await connect(pool.config.connConfig) + conn.ownerPool = pool pool.metrics.createCount.inc recordAcquire() return (conn, true) @@ -433,7 +458,7 @@ proc acquireImpl(pool: PgPool): Future[AcquireResult] {.async.} = # the future just before the timeout fired, return the connection # to the pool instead of leaking it. if fut.completed(): - pool.release(fut.read()) + fut.read().release() raise newException(PgPoolError, "Pool acquire timeout") else: let conn = await fut @@ -467,7 +492,7 @@ template withConnection*(pool: PgPool, conn, body: untyped) = body finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc failPendingOp(op: PendingPoolOp, e: ref CatchableError) = ## Fail a pending op's future if not already finished. @@ -531,7 +556,7 @@ proc executeBatch( failPendingOp(op, e) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc dispatchBatchImpl(pool: PgPool) {.async.} = ## Drain the pending ops queue and execute them via pipelined connections. @@ -564,7 +589,7 @@ proc dispatchBatchImpl(pool: PgPool) {.async.} = op.queryFut.complete(r) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() except CatchableError as e: failPendingOp(op, e) return @@ -596,7 +621,7 @@ proc dispatchBatchImpl(pool: PgPool) {.async.} = for ci in 0 ..< conns.len: if connOps[ci].len == 0: await pool.resetSession(conns[ci]) - pool.release(conns[ci]) + conns[ci].release() continue batchFuts.add(executeBatch(pool, conns[ci], connOps[ci])) @@ -657,7 +682,7 @@ proc exec*( return await conn.exec(sql, params, timeout = timeout) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc query*( pool: PgPool, @@ -688,7 +713,7 @@ proc query*( return await conn.query(sql, params, resultFormat = resultFormat, timeout = timeout) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc queryEach*( pool: PgPool, @@ -708,7 +733,7 @@ proc queryEach*( return await conn.queryEach(sql, params, callback, resultFormat, timeout) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc queryRowOpt*( pool: PgPool, @@ -873,7 +898,7 @@ proc simpleQuery*(pool: PgPool, sql: string): Future[seq[QueryResult]] {.async.} return await conn.simpleQuery(sql) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc simpleExec*( pool: PgPool, sql: string, timeout: Duration = ZeroDuration @@ -885,7 +910,7 @@ proc simpleExec*( return await conn.simpleExec(sql, timeout) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc execInTransaction*( pool: PgPool, @@ -899,7 +924,7 @@ proc execInTransaction*( return await conn.execInTransaction(sql, params, timeout) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc queryInTransaction*( pool: PgPool, @@ -914,7 +939,7 @@ proc queryInTransaction*( return await conn.queryInTransaction(sql, params, resultFormat, timeout) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc notify*( pool: PgPool, @@ -928,7 +953,7 @@ proc notify*( await conn.notify(channel, payload, timeout) finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() macro withTransaction*(pool: PgPool, args: varargs[untyped]): untyped = ## Execute `body` inside a BEGIN/COMMIT transaction using a pooled connection. @@ -999,7 +1024,7 @@ macro withTransaction*(pool: PgPool, args: varargs[untyped]): untyped = raise `eSym` finally: await `resetSessionSym`(`poolSym`, `connIdent`) - `poolSym`.release(`connIdent`) + `connIdent`.release() template withPipeline*(pool: PgPool, pipeline, body: untyped) = ## Acquire a connection, create a Pipeline, execute body, then release. @@ -1010,7 +1035,7 @@ template withPipeline*(pool: PgPool, pipeline, body: untyped) = body finally: await pool.resetSession(conn) - pool.release(conn) + conn.release() proc close*(pool: PgPool, timeout = ZeroDuration): Future[void] {.async.} = ## Close the pool: stop the maintenance loop, cancel all waiters, and close diff --git a/async_postgres/pg_pool_cluster.nim b/async_postgres/pg_pool_cluster.nim index 5533646..c8a5555 100644 --- a/async_postgres/pg_pool_cluster.nim +++ b/async_postgres/pg_pool_cluster.nim @@ -112,7 +112,7 @@ template withReadConnection*(cluster: PgPoolCluster, conn, body: untyped) = body finally: await connPool.resetSession(conn) - connPool.release(conn) + conn.release() template withWriteConnection*(cluster: PgPoolCluster, conn, body: untyped) = ## Acquire a write connection from the primary pool, execute `body`, then release. @@ -122,7 +122,7 @@ template withWriteConnection*(cluster: PgPoolCluster, conn, body: untyped) = body finally: await cluster.primary.resetSession(conn) - cluster.primary.release(conn) + conn.release() # Macro to generate cluster forwarding procs from compact declarations. # Each entry is a bodiless `proc` whose name starts with "read" or "write". @@ -185,7 +185,7 @@ macro clusterForwards(mode: static[string], body: untyped): untyped = ident"await", newCall(newDotExpr(ident"pool", ident"resetSession"), ident"conn"), ), - newCall(newDotExpr(ident"pool", ident"release"), ident"conn"), + newCall(newDotExpr(ident"conn", ident"release")), ) else: let primary = newDotExpr(ident"cluster", ident"primary") @@ -200,7 +200,7 @@ macro clusterForwards(mode: static[string], body: untyped): untyped = ident"await", newCall(newDotExpr(primary.copyNimTree(), ident"resetSession"), ident"conn"), ), - newCall(newDotExpr(primary.copyNimTree(), ident"release"), ident"conn"), + newCall(newDotExpr(ident"conn", ident"release")), ) let tryFinally = @@ -523,7 +523,7 @@ macro withTransaction*(cluster: PgPoolCluster, args: varargs[untyped]): untyped raise `eSym` finally: await `resetSessionSym`(`clusterSym`.primary, `connIdent`) - `clusterSym`.primary.release(`connIdent`) + `connIdent`.release() template withPipeline*(cluster: PgPoolCluster, pipeline, body: untyped) = ## Create a pipeline on the primary pool. diff --git a/tests/test_e2e.nim b/tests/test_e2e.nim index 3da6b29..afb8855 100644 --- a/tests/test_e2e.nim +++ b/tests/test_e2e.nim @@ -495,7 +495,7 @@ suite "E2E: Connection Pool": await newPool(PoolConfig(connConfig: plainConfig(), minSize: 1, maxSize: 3)) let conn = await pool.acquire() doAssert conn.state == csReady - pool.release(conn) + conn.release() await pool.close() waitFor t() @@ -521,9 +521,9 @@ suite "E2E: Connection Pool": doAssert c1.state == csReady doAssert c2.state == csReady doAssert c3.state == csReady - pool.release(c1) - pool.release(c2) - pool.release(c3) + c1.release() + c2.release() + c3.release() await pool.close() waitFor t() @@ -544,7 +544,7 @@ suite "E2E: Connection Pool": let conn1 = await pool.acquire() let pid1 = conn1.pid doAssert conn1.state == csReady - pool.release(conn1) + conn1.release() # Wait for maxLifetime to expire await sleepAsync(milliseconds(600)) @@ -555,7 +555,7 @@ suite "E2E: Connection Pool": doAssert conn2.state == csReady # The new connection should be different (different pid from server) doAssert conn2.pid != pid1 - pool.release(conn2) + conn2.release() await pool.close() @@ -576,7 +576,7 @@ suite "E2E: Connection Pool": # Create and release a connection so it sits idle let conn = await pool.acquire() doAssert conn.state == csReady - pool.release(conn) + conn.release() # Wait for idleTimeout + maintenance cycle await sleepAsync(milliseconds(500)) @@ -604,9 +604,9 @@ suite "E2E: Connection Pool": let c1 = await pool.acquire() let c2 = await pool.acquire() let c3 = await pool.acquire() - pool.release(c1) - pool.release(c2) - pool.release(c3) + c1.release() + c2.release() + c3.release() doAssert pool.idleCount == 3 # Wait for idleTimeout + maintenance cycles @@ -633,7 +633,7 @@ suite "E2E: Connection Pool": let conn1 = await pool.acquire() let pid1 = conn1.pid - pool.release(conn1) + conn1.release() # Wait for maxLifetime to expire await sleepAsync(milliseconds(400)) @@ -642,7 +642,7 @@ suite "E2E: Connection Pool": let conn2 = await pool.acquire() doAssert conn2.state == csReady doAssert conn2.pid != pid1 - pool.release(conn2) + conn2.release() await pool.close() @@ -661,7 +661,7 @@ suite "E2E: Connection Pool": ) let conn1 = await pool.acquire() - pool.release(conn1) + conn1.release() # Wait for maxLifetime to expire and maintenance to clean up await sleepAsync(milliseconds(500)) @@ -690,12 +690,12 @@ suite "E2E: Connection Pool": let conn1 = await pool.acquire() let pid1 = conn1.pid - pool.release(conn1) + conn1.release() # Immediate re-acquire should return the same connection let conn2 = await pool.acquire() doAssert conn2.pid == pid1 - pool.release(conn2) + conn2.release() await pool.close() @@ -2764,7 +2764,7 @@ suite "E2E: Pool minSize Replenishment": doAssert conn.state == csReady let res = await conn.simpleQuery("SELECT 1") doAssert res[0].rows[0][0].get().toString() == "1" - pool.release(conn) + conn.release() await pool.close() @@ -2992,7 +2992,7 @@ suite "E2E: Pool Stress": var raised = false try: let conn2 = await pool.acquire() - pool.release(conn2) + conn2.release() except PgError as e: raised = true doAssert "timeout" in e.msg.toLowerAscii() @@ -3000,7 +3000,7 @@ suite "E2E: Pool Stress": doAssert raised # Release and verify pool still works - pool.release(conn) + conn.release() let res = await pool.query("SELECT 1") doAssert res.rows.len == 1 doAssert res.rows[0].getStr(0) == "1" @@ -3028,7 +3028,7 @@ suite "E2E: Pool Stress": let conn = await pool.acquire() let pidRes = await conn.query("SELECT pg_backend_pid()") let pid = pidRes.rows[0].getInt(0) - pool.release(conn) + conn.release() # Kill the backend via a separate connection let killer = await connect(plainConfig()) diff --git a/tests/test_pool.nim b/tests/test_pool.nim index 3aa2c04..74eed01 100644 --- a/tests/test_pool.nim +++ b/tests/test_pool.nim @@ -15,13 +15,14 @@ privateAccess(PgConnection) privateAccess(PooledConn) privateAccess(Waiter) -proc mockConn(state: PgConnState = csReady): PgConnection = +proc mockConn(state: PgConnState = csReady, pool: PgPool = nil): PgConnection = result = PgConnection( recvBuf: @[], state: state, txStatus: tsIdle, serverParams: initTable[string, string](), createdAt: Moment.now(), + ownerPool: pool, ) proc makePool(minSize: int = 0, maxSize: int = 5): PgPool = @@ -43,6 +44,14 @@ proc makePool(minSize: int = 0, maxSize: int = 5): PgPool = proc toPooled(conn: PgConnection): PooledConn = PooledConn(conn: conn, lastUsedAt: Moment.now()) +proc release(pool: PgPool, conn: PgConnection) = + ## Test-only shim that wires `ownerPool` on throw-away mock connections + ## and delegates to the public `conn.release()` API. Production callers + ## should use `conn.release()` directly; pool-acquired connections already + ## have `ownerPool` set. + conn.ownerPool = pool + conn.release() + suite "initConnConfig": test "defaults": let cfg = initConnConfig() @@ -272,6 +281,12 @@ suite "Pool release": check pool.active == 0 check pool.idle.len == 0 + test "release on standalone connection raises PgError": + let conn = mockConn() + check conn.ownerPool == nil + expect PgError: + conn.release() + suite "Pool resetSession": test "resetSession is no-op when resetQuery is empty": let pool = makePool() diff --git a/tests/test_pool_cluster.nim b/tests/test_pool_cluster.nim index 9a2b261..462725b 100644 --- a/tests/test_pool_cluster.nim +++ b/tests/test_pool_cluster.nim @@ -10,13 +10,14 @@ privateAccess(PooledConn) privateAccess(Waiter) privateAccess(PgPoolCluster) -proc mockConn(state: PgConnState = csReady): PgConnection = +proc mockConn(state: PgConnState = csReady, pool: PgPool = nil): PgConnection = PgConnection( recvBuf: @[], state: state, txStatus: tsIdle, serverParams: initTable[string, string](), createdAt: Moment.now(), + ownerPool: pool, ) proc makePool(minSize: int = 0, maxSize: int = 5): PgPool = @@ -49,6 +50,14 @@ proc makeCluster( closed: false, ) +proc mockIdle(pool: PgPool, conn: PgConnection) = + ## Place a mock connection in the pool's idle queue, wiring `ownerPool` + ## the same way production `newPool` does. Without this, a subsequent + ## `conn.release()` (e.g. via `withReadConnection`) would raise because + ## the back-reference is nil. + conn.ownerPool = pool + pool.idle.addLast(PooledConn(conn: conn, lastUsedAt: Moment.now())) + suite "newPoolCluster targetSessionAttrs": test "auto-sets tsaReadWrite for primary and tsaPreferStandby for replica when tsaAny": var pCfg = PoolConfig( @@ -96,7 +105,7 @@ suite "Read routing": test "acquireRead returns connection from replica pool": let cluster = makeCluster() let conn = mockConn() - cluster.replica.idle.addLast(PooledConn(conn: conn, lastUsedAt: Moment.now())) + cluster.replica.mockIdle(conn) let (acquired, pool) = waitFor acquireRead(cluster) check acquired == conn @@ -108,7 +117,7 @@ suite "Read routing": proc t() {.async.} = let cluster = makeCluster() let conn = mockConn() - cluster.replica.idle.addLast(PooledConn(conn: conn, lastUsedAt: Moment.now())) + cluster.replica.mockIdle(conn) cluster.withReadConnection(c): doAssert c == conn @@ -123,7 +132,7 @@ suite "Read routing": proc t() {.async.} = let cluster = makeCluster() let conn = mockConn() - cluster.primary.idle.addLast(PooledConn(conn: conn, lastUsedAt: Moment.now())) + cluster.primary.mockIdle(conn) cluster.withWriteConnection(c): doAssert c == conn @@ -139,8 +148,8 @@ suite "Read routing": let cluster = makeCluster() let wConn = mockConn() let rConn = mockConn() - cluster.primary.idle.addLast(PooledConn(conn: wConn, lastUsedAt: Moment.now())) - cluster.replica.idle.addLast(PooledConn(conn: rConn, lastUsedAt: Moment.now())) + cluster.primary.mockIdle(wConn) + cluster.replica.mockIdle(rConn) cluster.withWriteConnection(conn): doAssert conn == wConn @@ -160,7 +169,7 @@ suite "Exception safety": proc t() {.async.} = let cluster = makeCluster() let conn = mockConn() - cluster.replica.idle.addLast(PooledConn(conn: conn, lastUsedAt: Moment.now())) + cluster.replica.mockIdle(conn) var caught = false try: @@ -180,7 +189,7 @@ suite "Exception safety": proc t() {.async.} = let cluster = makeCluster() let conn = mockConn() - cluster.primary.idle.addLast(PooledConn(conn: conn, lastUsedAt: Moment.now())) + cluster.primary.mockIdle(conn) var caught = false try: diff --git a/tests/test_tracing.nim b/tests/test_tracing.nim index 93d4e1e..0e97876 100644 --- a/tests/test_tracing.nim +++ b/tests/test_tracing.nim @@ -620,7 +620,7 @@ suite "Tracing: pool acquire/release": doAssert log.poolAcquireEnds[0].wasCreated == true doAssert not log.poolAcquireEnds[0].hasErr - pool.release(conn) + conn.release() doAssert log.poolReleaseStarts.len == 1 doAssert log.poolReleaseStarts[0].hasConn @@ -635,7 +635,7 @@ suite "Tracing: pool acquire/release": doAssert log.poolAcquireEnds.len == 2 doAssert log.poolAcquireEnds[1].wasCreated == false - pool.release(conn2) + conn2.release() await pool.close() waitFor t() @@ -652,14 +652,14 @@ suite "Tracing: pool acquire/release": # maxSize=1, so next acquire will wait let fut = pool.acquire() # Release conn1 -- should hand to the waiter - pool.release(conn1) + conn1.release() let conn2 = await fut doAssert log.poolReleaseEnds.len == 1 doAssert not log.poolReleaseEnds[0].wasClosed doAssert log.poolReleaseEnds[0].handedToWaiter - pool.release(conn2) + conn2.release() await pool.close() waitFor t() @@ -686,7 +686,7 @@ suite "Tracing: pool close errors": doAssert log.poolCloseErrors[0].hasConn doAssert log.poolCloseErrors[0].errMsg == "simulated close failure" - pool.release(conn) + conn.release() await pool.close() waitFor t() @@ -703,7 +703,7 @@ suite "Tracing: pool close errors": # Must not raise even though the hook is nil. pool.reportCloseError(conn, newException(PgError, "ignored")) - pool.release(conn) + conn.release() await pool.close() waitFor t()