diff --git a/async_postgres/pg_client.nim b/async_postgres/pg_client.nim index 642a95e..5dfbd4f 100644 --- a/async_postgres/pg_client.nim +++ b/async_postgres/pg_client.nim @@ -1566,6 +1566,29 @@ proc hasReturnStmt*(n: NimNode): bool = return true return false +proc buildTxBeginAndTimeout*(arg: NimNode): tuple[beginSql, txTimeout: NimNode] = + ## Shared helper for `withTransaction` macros. + ## Uses `when ... is` to dispatch on the argument type at compile time. + let buildBeginSqlSym = bindSym"buildBeginSql" + let zeroDurSym = bindSym"ZeroDuration" + let txOptsSym = bindSym"TransactionOptions" + let durSym = bindSym"Duration" + let beginSql = quote: + when `arg` is `txOptsSym`: + `buildBeginSqlSym`(`arg`) + elif `arg` is `durSym`: + "BEGIN" + else: + {.error: "withTransaction expects TransactionOptions or Duration".} + let txTimeout = quote: + when `arg` is `txOptsSym`: + `zeroDurSym` + elif `arg` is `durSym`: + `arg` + else: + {.error: "withTransaction expects TransactionOptions or Duration".} + (beginSql, txTimeout) + macro withTransaction*(conn: PgConnection, args: varargs[untyped]): untyped = ## Execute `body` inside a BEGIN/COMMIT transaction. ## On exception, ROLLBACK is issued automatically. @@ -1580,10 +1603,6 @@ macro withTransaction*(conn: PgConnection, args: varargs[untyped]): untyped = ## await conn.exec(...) ## conn.withTransaction(TransactionOptions(...), seconds(5)): ## await conn.exec(...) - ## - ## **Note:** `TransactionOptions` must be passed as a constructor literal, not - ## through a variable (the macro uses AST node kind to distinguish options - ## from timeout). var body: NimNode var beginSql: NimNode var txTimeout: NimNode @@ -1593,15 +1612,8 @@ macro withTransaction*(conn: PgConnection, args: varargs[untyped]): untyped = beginSql = newStrLitNode("BEGIN") txTimeout = bindSym"ZeroDuration" of 2: - if args[0].kind == nnkObjConstr: - # conn.withTransaction(TransactionOptions(...)): body - beginSql = newCall(bindSym"buildBeginSql", args[0]) - txTimeout = bindSym"ZeroDuration" - else: - # conn.withTransaction(timeout): body - beginSql = newStrLitNode("BEGIN") - txTimeout = args[0] body = args[1] + (beginSql, txTimeout) = buildTxBeginAndTimeout(args[0]) of 3: let opts = args[0] txTimeout = args[1] diff --git a/async_postgres/pg_pool.nim b/async_postgres/pg_pool.nim index 6b25f85..4de434b 100644 --- a/async_postgres/pg_pool.nim +++ b/async_postgres/pg_pool.nim @@ -641,10 +641,6 @@ macro withTransaction*(pool: PgPool, args: varargs[untyped]): untyped = ## pool.withTransaction(conn, opts, seconds(5)): ## conn.exec(...) ## - ## **Note:** `TransactionOptions` must be passed as a constructor literal, not - ## through a variable (the macro uses AST node kind to distinguish options - ## from timeout). - ## ## **Warning:** Inside the body, use `conn.exec(...)` / `conn.query(...)` ## directly — not `pool.exec(...)` / `pool.query(...)`. Pool methods acquire ## a separate connection, so those statements would run outside this transaction. @@ -659,15 +655,8 @@ macro withTransaction*(pool: PgPool, args: varargs[untyped]): untyped = txTimeout = bindSym"ZeroDuration" of 3: connIdent = args[0] - if args[1].kind == nnkObjConstr: - # pool.withTransaction(conn, TransactionOptions(...)): body - beginSql = newCall(bindSym"buildBeginSql", args[1]) - txTimeout = bindSym"ZeroDuration" - else: - # pool.withTransaction(conn, timeout): body - beginSql = newStrLitNode("BEGIN") - txTimeout = args[1] body = args[2] + (beginSql, txTimeout) = buildTxBeginAndTimeout(args[1]) of 4: connIdent = args[0] let opts = args[1] diff --git a/tests/test_e2e.nim b/tests/test_e2e.nim index f61b57f..e80e932 100644 --- a/tests/test_e2e.nim +++ b/tests/test_e2e.nim @@ -1053,6 +1053,101 @@ suite "E2E: Transaction": waitFor t() + test "withTransaction with TransactionOptions variable": + proc t() {.async.} = + let conn = await connect(plainConfig()) + discard await conn.exec("DROP TABLE IF EXISTS test_tx_opts_var") + discard await conn.exec( + "CREATE TABLE test_tx_opts_var (id serial PRIMARY KEY, val text)" + ) + + let opts = TransactionOptions(isolation: ilSerializable) + conn.withTransaction(opts): + discard await conn.exec( + "INSERT INTO test_tx_opts_var (val) VALUES ($1)", @[toPgParam("opts_var")] + ) + + let res = await conn.query("SELECT val FROM test_tx_opts_var") + doAssert res.rows.len == 1 + doAssert res.rows[0].getStr(0) == "opts_var" + + discard await conn.exec("DROP TABLE test_tx_opts_var") + await conn.close() + + waitFor t() + + test "withTransaction with Duration variable": + proc t() {.async.} = + let conn = await connect(plainConfig()) + discard await conn.exec("DROP TABLE IF EXISTS test_tx_dur_var") + discard await conn.exec( + "CREATE TABLE test_tx_dur_var (id serial PRIMARY KEY, val text)" + ) + + let timeout = seconds(5) + conn.withTransaction(timeout): + discard await conn.exec( + "INSERT INTO test_tx_dur_var (val) VALUES ($1)", @[toPgParam("dur_var")] + ) + + let res = await conn.query("SELECT val FROM test_tx_dur_var") + doAssert res.rows.len == 1 + doAssert res.rows[0].getStr(0) == "dur_var" + + discard await conn.exec("DROP TABLE test_tx_dur_var") + await conn.close() + + waitFor t() + + test "pool.withTransaction with TransactionOptions variable": + proc t() {.async.} = + let pool = + await newPool(PoolConfig(connConfig: plainConfig(), minSize: 1, maxSize: 3)) + discard await pool.exec("DROP TABLE IF EXISTS test_ptx_opts_var") + discard await pool.exec( + "CREATE TABLE test_ptx_opts_var (id serial PRIMARY KEY, val text)" + ) + + let opts = TransactionOptions(isolation: ilRepeatableRead) + pool.withTransaction(conn, opts): + discard await conn.exec( + "INSERT INTO test_ptx_opts_var (val) VALUES ($1)", + @[toPgParam("pool_opts_var")], + ) + + let res = await pool.query("SELECT val FROM test_ptx_opts_var") + doAssert res.rows.len == 1 + doAssert res.rows[0].getStr(0) == "pool_opts_var" + + discard await pool.exec("DROP TABLE test_ptx_opts_var") + await pool.close() + + waitFor t() + + test "pool.withTransaction with Duration variable": + proc t() {.async.} = + let pool = + await newPool(PoolConfig(connConfig: plainConfig(), minSize: 1, maxSize: 3)) + discard await pool.exec("DROP TABLE IF EXISTS test_ptx_dur_var") + discard await pool.exec( + "CREATE TABLE test_ptx_dur_var (id serial PRIMARY KEY, val text)" + ) + + let timeout = seconds(5) + pool.withTransaction(conn, timeout): + discard await conn.exec( + "INSERT INTO test_ptx_dur_var (val) VALUES ($1)", @[toPgParam("pool_dur_var")] + ) + + let res = await pool.query("SELECT val FROM test_ptx_dur_var") + doAssert res.rows.len == 1 + doAssert res.rows[0].getStr(0) == "pool_dur_var" + + discard await pool.exec("DROP TABLE test_ptx_dur_var") + await pool.close() + + waitFor t() + test "withSavepoint releases on success": proc t() {.async.} = let conn = await connect(plainConfig())