async_postgres/pg_client.nim: 207 additions, 0 deletions
@@ -92,6 +92,11 @@ type
conn: PgConnection
ops: seq[PipelineOp]

IsolatedPipelineResults* = object
## Results from `executeIsolated`: per-op error isolation via per-query SYNC.
results*: seq[PipelineResult]
errors*: seq[ref CatchableError] ## errors[i] is nil if ops[i] succeeded
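    ## Both seqs are parallel to the queued ops: one entry per op, in queue order.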

proc buildBeginSql*(opts: TransactionOptions): string =
## Build a BEGIN SQL statement with the specified transaction options
## (isolation level, access mode, deferrable mode).
@@ -2392,6 +2397,208 @@ proc execute*(
results = await executeImpl(p, timeout)
return results

proc executeIsolatedImpl(
p: Pipeline, timeout: Duration = ZeroDuration
): Future[IsolatedPipelineResults] {.async.} =
## Execute pipeline ops with per-query SYNC for error isolation.
## Each op gets its own ReadyForQuery; a failed op does not abort others.
let conn = p.conn
conn.checkReady()
conn.state = csBusy

# Send Phase (same as executeImpl but SYNC per op)
conn.sendBuf.setLen(0)
var cachedStmts: seq[CachedStmt]
var hasCachedStmts = false
var pendingCacheAdds = 0
var defaultFormats: seq[int16]

for i in 0 ..< p.ops.len:
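    # With no caller-supplied parameter formats, fall back to a zero-filled seq
    # (format code 0 = text), reused across ops with the same parameter count.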
let formats =
if p.ops[i].paramFormats.len > 0:
p.ops[i].paramFormats
else:
let needed = p.ops[i].params.len
if defaultFormats.len != needed:
defaultFormats = newSeq[int16](needed)
defaultFormats

let cached = conn.lookupStmtCache(p.ops[i].sql)
p.ops[i].cacheHit = cached != nil
p.ops[i].cacheMiss = false

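    # Three send paths: reuse a cached named statement (Bind + Execute only),
    # parse and cache a new named statement, or fall back to the unnamed
    # statement when the statement cache is disabled.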
if cached != nil:
p.ops[i].stmtName = cached.name
if p.ops[i].kind == pokQuery:
if not hasCachedStmts:
cachedStmts = newSeq[CachedStmt](p.ops.len)
hasCachedStmts = true
cachedStmts[i] = cached[]
var effectiveResultFormats: seq[int16]
if p.ops[i].kind == pokQuery:
effectiveResultFormats =
if p.ops[i].resultFormats.len == 0:
cached.resultFormats
else:
p.ops[i].resultFormats
p.ops[i].resultFormats = effectiveResultFormats
conn.sendBuf.addBind(
"", cached.name, formats, p.ops[i].params, effectiveResultFormats
)
conn.sendBuf.addExecute("", 0)
elif conn.stmtCacheCapacity > 0:
p.ops[i].cacheMiss = true
p.ops[i].stmtName = conn.nextStmtName()
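      # Cache at capacity: evict one entry and close its server-side statement first.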
if conn.stmtCache.len + pendingCacheAdds >= conn.stmtCacheCapacity and
conn.stmtCache.len > 0:
let evicted = conn.evictStmtCache()
conn.sendBuf.addClose(dkStatement, evicted.name)
inc pendingCacheAdds
conn.sendBuf.addParse(p.ops[i].stmtName, p.ops[i].sql, p.ops[i].paramOids)
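      # Describe the new named statement so its RowDescription can fill the result
      # fields and seed the statement-cache entry.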
conn.sendBuf.addDescribe(dkStatement, p.ops[i].stmtName)
conn.sendBuf.addBind(
"", p.ops[i].stmtName, formats, p.ops[i].params, p.ops[i].resultFormats
)
conn.sendBuf.addExecute("", 0)
else:
conn.sendBuf.addParse("", p.ops[i].sql, p.ops[i].paramOids)
conn.sendBuf.addBind("", "", formats, p.ops[i].params, p.ops[i].resultFormats)
if p.ops[i].kind == pokQuery:
conn.sendBuf.addDescribe(dkPortal, "")
conn.sendBuf.addExecute("", 0)

conn.sendBuf.addSync() # Per-op SYNC for error isolation

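  # Everything (each op terminated by its own Sync) goes out in one flush, so
  # per-op isolation does not add network round trips.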
await conn.sendBufMsg()

# Receive Phase (per-op ReadyForQuery)
var results = newSeq[PipelineResult](p.ops.len)
var errors = newSeq[ref CatchableError](p.ops.len)

  # Pre-initialize results; cache-hit queries reuse cached RowDescription metadata
  # since no Describe was sent for them.
for i in 0 ..< p.ops.len:
if p.ops[i].kind == pokQuery:
results[i] = PipelineResult(kind: prkQuery)
if p.ops[i].cacheHit:
let c = cachedStmts[i]
results[i].queryResult.fields = c.fields
if p.ops[i].resultFormats.len > 0 and c.colFmts.len > 0:
for j in 0 ..< results[i].queryResult.fields.len:
results[i].queryResult.fields[j].formatCode = c.colFmts[j]
if results[i].queryResult.fields.len > 0:
results[i].queryResult.data =
newRowData(int16(results[i].queryResult.fields.len), c.colFmts, c.colOids)
else:
results[i] = PipelineResult(kind: prkExec)

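  # Drain the stream one op at a time, up to that op's ReadyForQuery. When an op
  # fails, the server discards its remaining messages up to the matching Sync,
  # so later ops in the same batch still run.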
for opIdx in 0 ..< p.ops.len:
var opError: ref PgQueryError
var cachedFields: seq[FieldDescription]

block opRecv:
while true:
var rowData: RowData = nil
var rowCount: ptr int32 = nil
if p.ops[opIdx].kind == pokQuery:
rowData = results[opIdx].queryResult.data
rowCount = addr results[opIdx].queryResult.rowCount

while (let opt = conn.nextMessage(rowData, rowCount); opt.isSome):
let msg = opt.get
case msg.kind
of bmkParseComplete, bmkBindComplete, bmkCloseComplete:
discard
of bmkParameterDescription:
discard
of bmkRowDescription:
if p.ops[opIdx].kind == pokQuery:
if p.ops[opIdx].cacheMiss:
cachedFields = msg.fields
results[opIdx].queryResult.fields = msg.fields
var cf: seq[int16]
var co: seq[int32]
if p.ops[opIdx].resultFormats.len > 0:
cf = newSeq[int16](msg.fields.len)
co = newSeq[int32](msg.fields.len)
for j in 0 ..< msg.fields.len:
co[j] = msg.fields[j].typeOid
if p.ops[opIdx].resultFormats.len == 1:
results[opIdx].queryResult.fields[j].formatCode =
p.ops[opIdx].resultFormats[0]
cf[j] = p.ops[opIdx].resultFormats[0]
elif j < p.ops[opIdx].resultFormats.len:
results[opIdx].queryResult.fields[j].formatCode =
p.ops[opIdx].resultFormats[j]
cf[j] = p.ops[opIdx].resultFormats[j]
results[opIdx].queryResult.data =
newRowData(int16(msg.fields.len), cf, co)
rowData = results[opIdx].queryResult.data
rowCount = addr results[opIdx].queryResult.rowCount
else:
results[opIdx].queryResult.fields = msg.fields
results[opIdx].queryResult.data = newRowData(int16(msg.fields.len))
rowData = results[opIdx].queryResult.data
rowCount = addr results[opIdx].queryResult.rowCount
of bmkNoData:
discard
of bmkCommandComplete:
if p.ops[opIdx].kind == pokExec:
results[opIdx].commandResult = initCommandResult(msg.commandTag)
else:
results[opIdx].queryResult.commandTag = msg.commandTag
of bmkEmptyQueryResponse:
discard
of bmkErrorResponse:
if opError == nil:
opError = newPgQueryError(msg.errorFields)
of bmkReadyForQuery:
conn.txStatus = msg.txStatus
if opError != nil:
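              # SQLSTATE 26000 (invalid_sql_statement_name): the server-side prepared
              # statement is gone, so evict the stale cache entry and let a retry re-parse.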
if opError.sqlState == "26000" and p.ops[opIdx].cacheHit:
conn.removeStmtCache(p.ops[opIdx].sql)
errors[opIdx] = opError
elif p.ops[opIdx].cacheMiss:
conn.addStmtCache(
p.ops[opIdx].sql,
CachedStmt(name: p.ops[opIdx].stmtName, fields: cachedFields),
)
break opRecv
else:
discard
await conn.fillRecvBuf(timeout)

conn.state = csReady
return IsolatedPipelineResults(results: results, errors: errors)

proc executeIsolated*(
p: Pipeline, timeout: Duration = ZeroDuration
): Future[IsolatedPipelineResults] {.async.} =
## Execute all queued pipeline operations with per-query error isolation.
## Each operation gets its own SYNC message, so a failed operation does not
## abort subsequent ones. Returns results and per-op errors.
## On timeout, the connection is marked csClosed (protocol out of sync).
if p.ops.len == 0:
return IsolatedPipelineResults(results: @[], errors: @[])
var ir: IsolatedPipelineResults
withConnTracing(
p.conn,
onPipelineStart,
onPipelineEnd,
TracePipelineStartData(opCount: p.ops.len),
TracePipelineEndData,
TracePipelineEndData(),
):
if timeout > ZeroDuration:
try:
ir = await executeIsolatedImpl(p, timeout).wait(timeout)
except AsyncTimeoutError:
p.conn.cancelNoWait()
p.conn.state = csClosed
raise newException(PgTimeoutError, "Pipeline executeIsolated timed out")
else:
ir = await executeIsolatedImpl(p, timeout)
return ir

proc openCursorImpl(
conn: PgConnection,
sql: string,
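A minimal usage sketch of the new executeIsolated API, for reviewers. Only executeIsolated and IsolatedPipelineResults are taken from this diff; connect, newPipeline, addQuery, and the std/asyncdispatch backend are assumptions about the surrounding module and may differ from the real API.

import std/asyncdispatch   # assumed async backend; the module may use a different one

proc demo(conn: PgConnection) {.async.} =
  ## Queue three queries; the middle one fails but does not abort the others.
  var p = conn.newPipeline()                       # assumed constructor
  p.addQuery("SELECT 1")                           # assumed queueing helper
  p.addQuery("SELECT * FROM table_that_is_missing")
  p.addQuery("SELECT 2")
  let ir = await p.executeIsolated()               # from this diff
  for i in 0 ..< ir.results.len:
    if ir.errors[i] != nil:
      echo "op ", i, " failed: ", ir.errors[i].msg
    else:
      echo "op ", i, " succeeded"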