diff --git a/async_postgres/pg_pool.nim b/async_postgres/pg_pool.nim index c4ad751..b8cb2d7 100644 --- a/async_postgres/pg_pool.nim +++ b/async_postgres/pg_pool.nim @@ -683,9 +683,14 @@ template withPipeline*(pool: PgPool, pipeline, body: untyped) = await pool.resetSession(conn) pool.release(conn) -proc close*(pool: PgPool): Future[void] {.async.} = +proc close*(pool: PgPool, timeout = ZeroDuration): Future[void] {.async.} = ## Close the pool: stop the maintenance loop, cancel all waiters, and close - ## all idle connections. Active connections are closed when released. + ## all idle and active connections. + ## + ## When `timeout > ZeroDuration`, waits up to `timeout` for active + ## connections to be released. Unreleased connections are closed when they + ## are eventually returned to the pool. Without a timeout (or + ## `ZeroDuration`), active connections are closed on release. pool.closed = true # Stop maintenance loop @@ -699,6 +704,12 @@ proc close*(pool: PgPool): Future[void] {.async.} = waiter.fut.fail(newException(PgError, "Pool closed")) pool.waiterCount = 0 + # Wait for active connections to drain + if timeout > ZeroDuration and pool.active > 0: + let deadline = Moment.now() + timeout + while pool.active > 0 and Moment.now() < deadline: + await sleepAsync(milliseconds(50)) + # Close all idle connections while pool.idle.len > 0: let pc = pool.idle.popFirst() diff --git a/async_postgres/pg_pool_cluster.nim b/async_postgres/pg_pool_cluster.nim index 17e85f8..8c2a91c 100644 --- a/async_postgres/pg_pool_cluster.nim +++ b/async_postgres/pg_pool_cluster.nim @@ -438,16 +438,16 @@ template withPipeline*(cluster: PgPoolCluster, pipeline, body: untyped) = cluster.primary.withPipeline(pipeline): body -proc close*(cluster: PgPoolCluster): Future[void] {.async.} = +proc close*(cluster: PgPoolCluster, timeout = ZeroDuration): Future[void] {.async.} = ## Close both primary and replica pools. cluster.closed = true var firstErr: ref CatchableError try: - await cluster.primary.close() + await cluster.primary.close(timeout) except CatchableError as e: firstErr = e try: - await cluster.replica.close() + await cluster.replica.close(timeout) except CatchableError as e: if firstErr == nil: firstErr = e diff --git a/tests/test_pool.nim b/tests/test_pool.nim index 63fc9ea..b26a355 100644 --- a/tests/test_pool.nim +++ b/tests/test_pool.nim @@ -387,6 +387,30 @@ suite "Pool close": waitFor pool.close() check pool.closed + test "close with timeout waits for active connections": + let pool = makePool() + pool.active = 1 + + # Simulate a connection being released after a short delay + proc releaseAfter(pool: PgPool) {.async.} = + await sleepAsync(milliseconds(20)) + let conn = mockConn(csClosed) + pool.active.dec + + let releaseFut = releaseAfter(pool) + waitFor pool.close(timeout = seconds(1)) + waitFor releaseFut + check pool.closed + check pool.active == 0 + + test "close with timeout expires when active not released": + let pool = makePool() + pool.active = 1 + + waitFor pool.close(timeout = milliseconds(100)) + check pool.closed + check pool.active == 1 + suite "Pool active count tracking": test "release then acquire roundtrip": let pool = makePool()