Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions async_postgres/pg_client.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,29 @@ proc hasReturnStmt*(n: NimNode): bool =
return true
return false

proc buildTxBeginAndTimeout*(arg: NimNode): tuple[beginSql, txTimeout: NimNode] =
## Shared helper for `withTransaction` macros.
## Uses `when ... is` to dispatch on the argument type at compile time.
let buildBeginSqlSym = bindSym"buildBeginSql"
let zeroDurSym = bindSym"ZeroDuration"
let txOptsSym = bindSym"TransactionOptions"
let durSym = bindSym"Duration"
let beginSql = quote:
when `arg` is `txOptsSym`:
`buildBeginSqlSym`(`arg`)
elif `arg` is `durSym`:
"BEGIN"
else:
{.error: "withTransaction expects TransactionOptions or Duration".}
let txTimeout = quote:
when `arg` is `txOptsSym`:
`zeroDurSym`
elif `arg` is `durSym`:
`arg`
else:
{.error: "withTransaction expects TransactionOptions or Duration".}
(beginSql, txTimeout)

macro withTransaction*(conn: PgConnection, args: varargs[untyped]): untyped =
## Execute `body` inside a BEGIN/COMMIT transaction.
## On exception, ROLLBACK is issued automatically.
Expand All @@ -1580,10 +1603,6 @@ macro withTransaction*(conn: PgConnection, args: varargs[untyped]): untyped =
## await conn.exec(...)
## conn.withTransaction(TransactionOptions(...), seconds(5)):
## await conn.exec(...)
##
## **Note:** `TransactionOptions` must be passed as a constructor literal, not
## through a variable (the macro uses AST node kind to distinguish options
## from timeout).
var body: NimNode
var beginSql: NimNode
var txTimeout: NimNode
Expand All @@ -1593,15 +1612,8 @@ macro withTransaction*(conn: PgConnection, args: varargs[untyped]): untyped =
beginSql = newStrLitNode("BEGIN")
txTimeout = bindSym"ZeroDuration"
of 2:
if args[0].kind == nnkObjConstr:
# conn.withTransaction(TransactionOptions(...)): body
beginSql = newCall(bindSym"buildBeginSql", args[0])
txTimeout = bindSym"ZeroDuration"
else:
# conn.withTransaction(timeout): body
beginSql = newStrLitNode("BEGIN")
txTimeout = args[0]
body = args[1]
(beginSql, txTimeout) = buildTxBeginAndTimeout(args[0])
of 3:
let opts = args[0]
txTimeout = args[1]
Expand Down
13 changes: 1 addition & 12 deletions async_postgres/pg_pool.nim
Original file line number Diff line number Diff line change
Expand Up @@ -641,10 +641,6 @@ macro withTransaction*(pool: PgPool, args: varargs[untyped]): untyped =
## pool.withTransaction(conn, opts, seconds(5)):
## conn.exec(...)
##
## **Note:** `TransactionOptions` must be passed as a constructor literal, not
## through a variable (the macro uses AST node kind to distinguish options
## from timeout).
##
## **Warning:** Inside the body, use `conn.exec(...)` / `conn.query(...)`
## directly — not `pool.exec(...)` / `pool.query(...)`. Pool methods acquire
## a separate connection, so those statements would run outside this transaction.
Expand All @@ -659,15 +655,8 @@ macro withTransaction*(pool: PgPool, args: varargs[untyped]): untyped =
txTimeout = bindSym"ZeroDuration"
of 3:
connIdent = args[0]
if args[1].kind == nnkObjConstr:
# pool.withTransaction(conn, TransactionOptions(...)): body
beginSql = newCall(bindSym"buildBeginSql", args[1])
txTimeout = bindSym"ZeroDuration"
else:
# pool.withTransaction(conn, timeout): body
beginSql = newStrLitNode("BEGIN")
txTimeout = args[1]
body = args[2]
(beginSql, txTimeout) = buildTxBeginAndTimeout(args[1])
of 4:
connIdent = args[0]
let opts = args[1]
Expand Down
95 changes: 95 additions & 0 deletions tests/test_e2e.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,101 @@ suite "E2E: Transaction":

waitFor t()

test "withTransaction with TransactionOptions variable":
proc t() {.async.} =
let conn = await connect(plainConfig())
discard await conn.exec("DROP TABLE IF EXISTS test_tx_opts_var")
discard await conn.exec(
"CREATE TABLE test_tx_opts_var (id serial PRIMARY KEY, val text)"
)

let opts = TransactionOptions(isolation: ilSerializable)
conn.withTransaction(opts):
discard await conn.exec(
"INSERT INTO test_tx_opts_var (val) VALUES ($1)", @[toPgParam("opts_var")]
)

let res = await conn.query("SELECT val FROM test_tx_opts_var")
doAssert res.rows.len == 1
doAssert res.rows[0].getStr(0) == "opts_var"

discard await conn.exec("DROP TABLE test_tx_opts_var")
await conn.close()

waitFor t()

test "withTransaction with Duration variable":
proc t() {.async.} =
let conn = await connect(plainConfig())
discard await conn.exec("DROP TABLE IF EXISTS test_tx_dur_var")
discard await conn.exec(
"CREATE TABLE test_tx_dur_var (id serial PRIMARY KEY, val text)"
)

let timeout = seconds(5)
conn.withTransaction(timeout):
discard await conn.exec(
"INSERT INTO test_tx_dur_var (val) VALUES ($1)", @[toPgParam("dur_var")]
)

let res = await conn.query("SELECT val FROM test_tx_dur_var")
doAssert res.rows.len == 1
doAssert res.rows[0].getStr(0) == "dur_var"

discard await conn.exec("DROP TABLE test_tx_dur_var")
await conn.close()

waitFor t()

test "pool.withTransaction with TransactionOptions variable":
proc t() {.async.} =
let pool =
await newPool(PoolConfig(connConfig: plainConfig(), minSize: 1, maxSize: 3))
discard await pool.exec("DROP TABLE IF EXISTS test_ptx_opts_var")
discard await pool.exec(
"CREATE TABLE test_ptx_opts_var (id serial PRIMARY KEY, val text)"
)

let opts = TransactionOptions(isolation: ilRepeatableRead)
pool.withTransaction(conn, opts):
discard await conn.exec(
"INSERT INTO test_ptx_opts_var (val) VALUES ($1)",
@[toPgParam("pool_opts_var")],
)

let res = await pool.query("SELECT val FROM test_ptx_opts_var")
doAssert res.rows.len == 1
doAssert res.rows[0].getStr(0) == "pool_opts_var"

discard await pool.exec("DROP TABLE test_ptx_opts_var")
await pool.close()

waitFor t()

test "pool.withTransaction with Duration variable":
proc t() {.async.} =
let pool =
await newPool(PoolConfig(connConfig: plainConfig(), minSize: 1, maxSize: 3))
discard await pool.exec("DROP TABLE IF EXISTS test_ptx_dur_var")
discard await pool.exec(
"CREATE TABLE test_ptx_dur_var (id serial PRIMARY KEY, val text)"
)

let timeout = seconds(5)
pool.withTransaction(conn, timeout):
discard await conn.exec(
"INSERT INTO test_ptx_dur_var (val) VALUES ($1)", @[toPgParam("pool_dur_var")]
)

let res = await pool.query("SELECT val FROM test_ptx_dur_var")
doAssert res.rows.len == 1
doAssert res.rows[0].getStr(0) == "pool_dur_var"

discard await pool.exec("DROP TABLE test_ptx_dur_var")
await pool.close()

waitFor t()

test "withSavepoint releases on success":
proc t() {.async.} =
let conn = await connect(plainConfig())
Expand Down
Loading