diff --git a/async_postgres/pg_client.nim b/async_postgres/pg_client.nim index b755b0e..f36e645 100644 --- a/async_postgres/pg_client.nim +++ b/async_postgres/pg_client.nim @@ -18,6 +18,7 @@ type PreparedStatement* = object ## A server-side prepared statement returned by `prepare`. conn*: PgConnection name*: string + sql*: string fields*: seq[FieldDescription] paramOids*: seq[int32] @@ -345,15 +346,23 @@ proc exec*( ## On timeout the connection is marked closed (protocol desync) and cannot be ## reused; pooled connections are discarded automatically. var tag: string - if timeout > ZeroDuration: - try: - tag = await execImpl(conn, sql, params, timeout).wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "Exec timed out") - else: - tag = await execImpl(conn, sql, params) + withConnTracing( + conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: sql, params: params, isExec: true), + TraceQueryEndData, + TraceQueryEndData(commandTag: tag), + ): + if timeout > ZeroDuration: + try: + tag = await execImpl(conn, sql, params, timeout).wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "Exec timed out") + else: + tag = await execImpl(conn, sql, params) return initCommandResult(tag) template queryRecvLoop( @@ -753,17 +762,27 @@ proc queryEach*( ): Future[int64] {.async.} = ## Execute a query with typed parameters, invoking `callback` once per row. ## Returns the number of rows processed. - let resultFormats = resultFormat.toFormatCodes() - if timeout > ZeroDuration: - try: - return await queryEachImpl(conn, sql, params, callback, resultFormats, timeout) - .wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "queryEach timed out") - else: - return await queryEachImpl(conn, sql, params, callback, resultFormats) + var count: int64 + withConnTracing( + conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: sql, params: params, isExec: false), + TraceQueryEndData, + TraceQueryEndData(rowCount: count), + ): + let resultFormats = resultFormat.toFormatCodes() + if timeout > ZeroDuration: + try: + count = await queryEachImpl(conn, sql, params, callback, resultFormats, timeout) + .wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "queryEach timed out") + else: + count = await queryEachImpl(conn, sql, params, callback, resultFormats) + return count proc query( conn: PgConnection, @@ -800,16 +819,26 @@ proc query*( ## Execute a query with typed parameters. ## On timeout the connection is marked closed (protocol desync) and cannot be ## reused; pooled connections are discarded automatically. - let resultFormats = resultFormat.toFormatCodes() - if timeout > ZeroDuration: - try: - return await queryImpl(conn, sql, params, resultFormats, timeout).wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "Query timed out") - else: - return await queryImpl(conn, sql, params, resultFormats) + var qr: QueryResult + withConnTracing( + conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: sql, params: params, isExec: false), + TraceQueryEndData, + TraceQueryEndData(commandTag: qr.commandTag, rowCount: qr.rowCount), + ): + let resultFormats = resultFormat.toFormatCodes() + if timeout > ZeroDuration: + try: + qr = await queryImpl(conn, sql, params, resultFormats, timeout).wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "Query timed out") + else: + qr = await queryImpl(conn, sql, params, resultFormats) + return qr proc queryOne*( conn: PgConnection, @@ -968,7 +997,7 @@ proc prepareImpl( batch.addSync() await conn.sendMsg(batch) - var stmt = PreparedStatement(conn: conn, name: name) + var stmt = PreparedStatement(conn: conn, name: name, sql: sql) var queryError: ref PgQueryError block recvLoop: @@ -1003,15 +1032,25 @@ proc prepare*( ): Future[PreparedStatement] {.async.} = ## Prepare a named statement, returning metadata. ## On timeout, the connection is marked csClosed (protocol out of sync). - if timeout > ZeroDuration: - try: - return await prepareImpl(conn, name, sql, timeout).wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "Prepare timed out") - else: - return await prepareImpl(conn, name, sql) + var stmt: PreparedStatement + withConnTracing( + conn, + onPrepareStart, + onPrepareEnd, + TracePrepareStartData(name: name, sql: sql), + TracePrepareEndData, + TracePrepareEndData(), + ): + if timeout > ZeroDuration: + try: + stmt = await prepareImpl(conn, name, sql, timeout).wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "Prepare timed out") + else: + stmt = await prepareImpl(conn, name, sql) + return stmt proc executeImpl( stmt: PreparedStatement, @@ -1098,16 +1137,26 @@ proc execute*( timeout: Duration = ZeroDuration, ): Future[QueryResult] {.async.} = ## Execute a prepared statement with typed parameters. - let resultFormats = resultFormat.toFormatCodes() - if timeout > ZeroDuration: - try: - return await executeImpl(stmt, params, resultFormats, timeout).wait(timeout) - except AsyncTimeoutError: - stmt.conn.cancelNoWait() - stmt.conn.state = csClosed - raise newException(PgTimeoutError, "Execute timed out") - else: - return await executeImpl(stmt, params, resultFormats) + var qr: QueryResult + withConnTracing( + stmt.conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: stmt.sql, params: params, isExec: false), + TraceQueryEndData, + TraceQueryEndData(commandTag: qr.commandTag, rowCount: qr.rowCount), + ): + let resultFormats = resultFormat.toFormatCodes() + if timeout > ZeroDuration: + try: + qr = await executeImpl(stmt, params, resultFormats, timeout).wait(timeout) + except AsyncTimeoutError: + stmt.conn.cancelNoWait() + stmt.conn.state = csClosed + raise newException(PgTimeoutError, "Execute timed out") + else: + qr = await executeImpl(stmt, params, resultFormats) + return qr proc closeImpl( stmt: PreparedStatement, timeout: Duration = ZeroDuration @@ -1250,15 +1299,23 @@ proc copyIn*( ## Execute COPY ... FROM STDIN with a single contiguous ``seq[byte]``. ## Avoids the copy that the ``openArray[byte]`` overload performs. var tag: string - if timeout > ZeroDuration: - try: - tag = await copyInRawImpl(conn, sql, data, timeout).wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "COPY IN timed out") - else: - tag = await copyInRawImpl(conn, sql, data) + withConnTracing( + conn, + onCopyStart, + onCopyEnd, + TraceCopyStartData(sql: sql, direction: tcdIn), + TraceCopyEndData, + TraceCopyEndData(commandTag: tag), + ): + if timeout > ZeroDuration: + try: + tag = await copyInRawImpl(conn, sql, data, timeout).wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "COPY IN timed out") + else: + tag = await copyInRawImpl(conn, sql, data) return initCommandResult(tag) proc copyIn*( @@ -1410,15 +1467,25 @@ proc copyInStream*( ## ``seq[byte]`` signals EOF. If the callback raises, CopyFail is sent and ## the connection returns to csReady. ## On timeout, the connection is marked csClosed (protocol out of sync). - if timeout > ZeroDuration: - try: - return await copyInStreamImpl(conn, sql, callback, timeout).wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "COPY IN stream timed out") - else: - return await copyInStreamImpl(conn, sql, callback) + var info: CopyInInfo + withConnTracing( + conn, + onCopyStart, + onCopyEnd, + TraceCopyStartData(sql: sql, direction: tcdIn), + TraceCopyEndData, + TraceCopyEndData(commandTag: info.commandTag), + ): + if timeout > ZeroDuration: + try: + info = await copyInStreamImpl(conn, sql, callback, timeout).wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "COPY IN stream timed out") + else: + info = await copyInStreamImpl(conn, sql, callback) + return info proc copyOutImpl( conn: PgConnection, sql: string, timeout: Duration = ZeroDuration @@ -1464,15 +1531,25 @@ proc copyOut*( ## Execute COPY ... TO STDOUT via simple query protocol. ## Collects all CopyData messages and returns them in a CopyResult. ## On timeout, the connection is marked csClosed (protocol out of sync). - if timeout > ZeroDuration: - try: - return await copyOutImpl(conn, sql, timeout).wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "COPY OUT timed out") - else: - return await copyOutImpl(conn, sql) + var cr: CopyResult + withConnTracing( + conn, + onCopyStart, + onCopyEnd, + TraceCopyStartData(sql: sql, direction: tcdOut), + TraceCopyEndData, + TraceCopyEndData(commandTag: cr.commandTag), + ): + if timeout > ZeroDuration: + try: + cr = await copyOutImpl(conn, sql, timeout).wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "COPY OUT timed out") + else: + cr = await copyOutImpl(conn, sql) + return cr proc copyOutStreamImpl( conn: PgConnection, @@ -1541,15 +1618,25 @@ proc copyOutStream*( ## CopyData chunk through `callback`. The callback is awaited, providing ## natural TCP backpressure. If the callback raises, the connection is ## marked csClosed (protocol cannot be resynchronized). - if timeout > ZeroDuration: - try: - return await copyOutStreamImpl(conn, sql, callback, timeout).wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "COPY OUT stream timed out") - else: - return await copyOutStreamImpl(conn, sql, callback) + var info: CopyOutInfo + withConnTracing( + conn, + onCopyStart, + onCopyEnd, + TraceCopyStartData(sql: sql, direction: tcdOut), + TraceCopyEndData, + TraceCopyEndData(commandTag: info.commandTag), + ): + if timeout > ZeroDuration: + try: + info = await copyOutStreamImpl(conn, sql, callback, timeout).wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "COPY OUT stream timed out") + else: + info = await copyOutStreamImpl(conn, sql, callback) + return info proc hasReturnStmt*(n: NimNode): bool = ## Check whether an AST contains a `return` statement (excluding nested @@ -1817,20 +1904,30 @@ proc execInTransaction( ## Execute a statement inside a pipelined BEGIN/COMMIT transaction (1 round trip). ## On error, ROLLBACK is issued automatically. ## On timeout, the connection is marked csClosed (protocol out of sync). - if timeout > ZeroDuration: - try: - return await execInTransactionImpl( + var tag: string + withConnTracing( + conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: sql, isExec: true), + TraceQueryEndData, + TraceQueryEndData(commandTag: tag), + ): + if timeout > ZeroDuration: + try: + tag = await execInTransactionImpl( + conn, "BEGIN", sql, params, paramOids, paramFormats, timeout + ) + .wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "execInTransaction timed out") + else: + tag = await execInTransactionImpl( conn, "BEGIN", sql, params, paramOids, paramFormats, timeout ) - .wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "execInTransaction timed out") - else: - return await execInTransactionImpl( - conn, "BEGIN", sql, params, paramOids, paramFormats, timeout - ) + return tag proc execInTransaction*( conn: PgConnection, @@ -1851,22 +1948,30 @@ proc execInTransaction*( timeout: Duration = ZeroDuration, ): Future[CommandResult] {.async.} = ## Execute a statement inside a pipelined transaction with options. - let (oids, formats, values) = extractParams(params) - let beginSql = buildBeginSql(opts) var tag: string - if timeout > ZeroDuration: - try: - tag = await execInTransactionImpl( - conn, beginSql, sql, values, oids, formats, timeout - ) - .wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "execInTransaction timed out") - else: - tag = - await execInTransactionImpl(conn, beginSql, sql, values, oids, formats, timeout) + withConnTracing( + conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: sql, params: params, isExec: true), + TraceQueryEndData, + TraceQueryEndData(commandTag: tag), + ): + let (oids, formats, values) = extractParams(params) + let beginSql = buildBeginSql(opts) + if timeout > ZeroDuration: + try: + tag = await execInTransactionImpl( + conn, beginSql, sql, values, oids, formats, timeout + ) + .wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "execInTransaction timed out") + else: + tag = + await execInTransactionImpl(conn, beginSql, sql, values, oids, formats, timeout) return initCommandResult(tag) proc queryInTransactionImpl( @@ -1958,20 +2063,30 @@ proc queryInTransaction( ## Execute a query inside a pipelined BEGIN/COMMIT transaction (1 round trip). ## Returns rows. On error, ROLLBACK is issued automatically. ## On timeout, the connection is marked csClosed (protocol out of sync). - if timeout > ZeroDuration: - try: - return await queryInTransactionImpl( + var qr: QueryResult + withConnTracing( + conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: sql, isExec: false), + TraceQueryEndData, + TraceQueryEndData(commandTag: qr.commandTag, rowCount: qr.rowCount), + ): + if timeout > ZeroDuration: + try: + qr = await queryInTransactionImpl( + conn, "BEGIN", sql, params, paramOids, paramFormats, resultFormats, timeout + ) + .wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "queryInTransaction timed out") + else: + qr = await queryInTransactionImpl( conn, "BEGIN", sql, params, paramOids, paramFormats, resultFormats, timeout ) - .wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "queryInTransaction timed out") - else: - return await queryInTransactionImpl( - conn, "BEGIN", sql, params, paramOids, paramFormats, resultFormats, timeout - ) + return qr proc queryInTransaction*( conn: PgConnection, @@ -1995,23 +2110,33 @@ proc queryInTransaction*( timeout: Duration = ZeroDuration, ): Future[QueryResult] {.async.} = ## Execute a query inside a pipelined transaction with options. - let (oids, formats, values) = extractParams(params) - let resultFormats = resultFormat.toFormatCodes() - let beginSql = buildBeginSql(opts) - if timeout > ZeroDuration: - try: - return await queryInTransactionImpl( + var qr: QueryResult + withConnTracing( + conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: sql, params: params, isExec: false), + TraceQueryEndData, + TraceQueryEndData(commandTag: qr.commandTag, rowCount: qr.rowCount), + ): + let (oids, formats, values) = extractParams(params) + let resultFormats = resultFormat.toFormatCodes() + let beginSql = buildBeginSql(opts) + if timeout > ZeroDuration: + try: + qr = await queryInTransactionImpl( + conn, beginSql, sql, values, oids, formats, resultFormats, timeout + ) + .wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "queryInTransaction timed out") + else: + qr = await queryInTransactionImpl( conn, beginSql, sql, values, oids, formats, resultFormats, timeout ) - .wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "queryInTransaction timed out") - else: - return await queryInTransactionImpl( - conn, beginSql, sql, values, oids, formats, resultFormats, timeout - ) + return qr proc newPipeline*(conn: PgConnection): Pipeline = ## Create a new pipeline for batching multiple operations into a single round trip. @@ -2247,15 +2372,25 @@ proc execute*( ## On timeout, the connection is marked csClosed (protocol out of sync). if p.ops.len == 0: return @[] - if timeout > ZeroDuration: - try: - return await executeImpl(p, timeout).wait(timeout) - except AsyncTimeoutError: - p.conn.cancelNoWait() - p.conn.state = csClosed - raise newException(PgTimeoutError, "Pipeline execute timed out") - else: - return await executeImpl(p, timeout) + var results: seq[PipelineResult] + withConnTracing( + p.conn, + onPipelineStart, + onPipelineEnd, + TracePipelineStartData(opCount: p.ops.len), + TracePipelineEndData, + TracePipelineEndData(), + ): + if timeout > ZeroDuration: + try: + results = await executeImpl(p, timeout).wait(timeout) + except AsyncTimeoutError: + p.conn.cancelNoWait() + p.conn.state = csClosed + raise newException(PgTimeoutError, "Pipeline execute timed out") + else: + results = await executeImpl(p, timeout) + return results proc openCursorImpl( conn: PgConnection, diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index 2311d57..15fb213 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -115,6 +115,7 @@ type hosts*: seq[HostEntry] ## Multiple hosts for failover (empty = use host/port) targetSessionAttrs*: TargetSessionAttrs ## Target server type (default tsaAny) extraParams*: seq[(string, string)] ## Additional startup parameters + tracer*: PgTracer ## Optional tracer for connection-level hooks Notification* = object ## A NOTIFY message received from PostgreSQL. pid*: int32 @@ -181,6 +182,7 @@ type stmtCacheCapacity*: int ## 0=disabled, default 256 rowDataBuf*: RowData ## Reusable RowData buffer to avoid per-query allocation hstoreOid*: int32 ## Dynamic OID for hstore extension type; 0 if not available + tracer*: PgTracer ## Inherited from ConnConfig on connect QueryResult* = object ## Result of a query: field descriptions, row data, and command tag. @@ -206,6 +208,157 @@ type columnFormats*: seq[int16] commandTag*: string + # Tracing types + TraceContext* = RootRef + ## Opaque correlation token returned by trace Start hooks and passed to End hooks. + ## Users subtype RootObj (e.g. ``type Span = ref object of RootObj``) and return + ## it from Start hooks; End hooks downcast via ``Span(ctx)``. + + TraceCopyDirection* = enum + tcdIn + tcdOut + + TraceConnectStartData* = object ## Data passed to the connect start hook. + hosts*: seq[HostEntry] + + TraceConnectEndData* = object ## Data passed to the connect end hook. + conn*: PgConnection + err*: ref CatchableError + + TraceQueryStartData* = object ## Data passed to the query/exec start hook. + sql*: string + params*: seq[PgParam] + isExec*: bool ## true for exec, false for query + + TraceQueryEndData* = object ## Data passed to the query/exec end hook. + commandTag*: string + rowCount*: int64 + err*: ref CatchableError + + TracePrepareStartData* = object ## Data passed to the prepare start hook. + name*: string + sql*: string + + TracePrepareEndData* = object ## Data passed to the prepare end hook. + err*: ref CatchableError + + TracePipelineStartData* = object ## Data passed to the pipeline start hook. + opCount*: int + + TracePipelineEndData* = object ## Data passed to the pipeline end hook. + err*: ref CatchableError + + TraceCopyStartData* = object ## Data passed to the copy start hook. + sql*: string + direction*: TraceCopyDirection + + TraceCopyEndData* = object ## Data passed to the copy end hook. + commandTag*: string + err*: ref CatchableError + + TracePoolAcquireStartData* = object ## Data passed to the pool acquire start hook. + idleCount*: int + activeCount*: int + maxSize*: int + + TracePoolAcquireEndData* = object ## Data passed to the pool acquire end hook. + conn*: PgConnection + err*: ref CatchableError + wasCreated*: bool ## true if a new connection was created + + TracePoolReleaseStartData* = object ## Data passed to the pool release start hook. + conn*: PgConnection + + TracePoolReleaseEndData* = object ## Data passed to the pool release end hook. + wasClosed*: bool ## true if connection was closed instead of returned to pool + handedToWaiter*: bool ## true if connection was given directly to a waiting acquirer + + PgTracer* = ref object + ## Tracing hooks for async-postgres operations. + ## Set only the callbacks you need; nil callbacks are skipped with zero overhead. + ## + ## Start hooks return a ``TraceContext`` (opaque pointer) that is passed to the + ## corresponding End hook for correlation (e.g. timing, span linking). + ## Return nil from Start if you don't need correlation. + onConnectStart*: + proc(data: TraceConnectStartData): TraceContext {.gcsafe, raises: [].} + onConnectEnd*: + proc(ctx: TraceContext, data: TraceConnectEndData) {.gcsafe, raises: [].} + onQueryStart*: proc(conn: PgConnection, data: TraceQueryStartData): TraceContext {. + gcsafe, raises: [] + .} + onQueryEnd*: proc(ctx: TraceContext, conn: PgConnection, data: TraceQueryEndData) {. + gcsafe, raises: [] + .} + onPrepareStart*: proc(conn: PgConnection, data: TracePrepareStartData): TraceContext {. + gcsafe, raises: [] + .} + onPrepareEnd*: proc( + ctx: TraceContext, conn: PgConnection, data: TracePrepareEndData + ) {.gcsafe, raises: [].} + onPipelineStart*: proc( + conn: PgConnection, data: TracePipelineStartData + ): TraceContext {.gcsafe, raises: [].} + onPipelineEnd*: proc( + ctx: TraceContext, conn: PgConnection, data: TracePipelineEndData + ) {.gcsafe, raises: [].} + onCopyStart*: proc(conn: PgConnection, data: TraceCopyStartData): TraceContext {. + gcsafe, raises: [] + .} + onCopyEnd*: proc(ctx: TraceContext, conn: PgConnection, data: TraceCopyEndData) {. + gcsafe, raises: [] + .} + onPoolAcquireStart*: + proc(data: TracePoolAcquireStartData): TraceContext {.gcsafe, raises: [].} + onPoolAcquireEnd*: + proc(ctx: TraceContext, data: TracePoolAcquireEndData) {.gcsafe, raises: [].} + onPoolReleaseStart*: + proc(data: TracePoolReleaseStartData): TraceContext {.gcsafe, raises: [].} + onPoolReleaseEnd*: + proc(ctx: TraceContext, data: TracePoolReleaseEndData) {.gcsafe, raises: [].} + +template withConnTracing*( + conn: PgConnection, + startHook, endHook: untyped, + startData: typed, + EndDataType: typedesc, + endDataExpr: typed, + body: untyped, +) = + ## Wrap an operation with connection-scoped tracing hooks. + var traceCtx {.inject.}: TraceContext + if conn.tracer != nil and conn.tracer.startHook != nil: + traceCtx = conn.tracer.startHook(conn, startData) + try: + body + except CatchableError as e: + if conn.tracer != nil and conn.tracer.endHook != nil: + conn.tracer.endHook(traceCtx, conn, EndDataType(err: e)) + raise e + if conn.tracer != nil and conn.tracer.endHook != nil: + conn.tracer.endHook(traceCtx, conn, endDataExpr) + +template withTracing*( + tracer: PgTracer, + startHook, endHook: untyped, + startData: typed, + EndDataType: typedesc, + endDataExpr: typed, + body: untyped, +) = + ## Wrap an operation with non-connection tracing hooks (connect, pool). + var traceCtx {.inject.}: TraceContext + if tracer != nil and tracer.startHook != nil: + traceCtx = tracer.startHook(startData) + try: + body + except CatchableError as e: + if tracer != nil and tracer.endHook != nil: + tracer.endHook(traceCtx, EndDataType(err: e)) + raise e + if tracer != nil and tracer.endHook != nil: + tracer.endHook(traceCtx, endDataExpr) + proc isUnixSocket*(host: string): bool {.inline.} = ## True if `host` represents a Unix socket directory (starts with '/'). ## Compatible with libpq behavior. @@ -1115,42 +1268,60 @@ proc simpleQuery*(conn: PgConnection, sql: string): Future[seq[QueryResult]] {.a ## Execute one or more SQL statements via simple query protocol. ## Returns one `QueryResult` per statement. Supports multiple statements separated by semicolons. conn.checkReady() - conn.state = csBusy - await conn.sendMsg(encodeQuery(sql)) var results: seq[QueryResult] - var current = QueryResult() - var queryError: ref PgQueryError + var totalRows: int32 + var lastTag: string + # For multi-statement queries (e.g. "SELECT 1; SELECT 2"), the trace end hook + # receives the aggregated row count and only the last command tag. + withConnTracing( + conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: sql, isExec: false), + TraceQueryEndData, + TraceQueryEndData(commandTag: lastTag, rowCount: totalRows), + ): + conn.state = csBusy + await conn.sendMsg(encodeQuery(sql)) - block recvLoop: - while true: - while ( - let opt = conn.nextMessage(current.data, addr current.rowCount) - opt.isSome - ) - : - let msg = opt.get - case msg.kind - of bmkRowDescription: - current = - QueryResult(fields: msg.fields, data: newRowData(int16(msg.fields.len))) - of bmkCommandComplete: - current.commandTag = msg.commandTag - results.add(current) - current = QueryResult() - of bmkEmptyQueryResponse: - results.add(QueryResult()) - of bmkErrorResponse: - queryError = newPgQueryError(msg.errorFields) - of bmkReadyForQuery: - conn.txStatus = msg.txStatus - conn.state = csReady - if queryError != nil: - raise queryError - break recvLoop - else: - discard - await conn.fillRecvBuf() + var current = QueryResult() + var queryError: ref PgQueryError + + block recvLoop: + while true: + while ( + let opt = conn.nextMessage(current.data, addr current.rowCount) + opt.isSome + ) + : + let msg = opt.get + case msg.kind + of bmkRowDescription: + current = + QueryResult(fields: msg.fields, data: newRowData(int16(msg.fields.len))) + of bmkCommandComplete: + current.commandTag = msg.commandTag + results.add(current) + current = QueryResult() + of bmkEmptyQueryResponse: + results.add(QueryResult()) + of bmkErrorResponse: + queryError = newPgQueryError(msg.errorFields) + of bmkReadyForQuery: + conn.txStatus = msg.txStatus + conn.state = csReady + if queryError != nil: + raise queryError + break recvLoop + else: + discard + await conn.fillRecvBuf() + + for r in results: + totalRows += r.rowCount + if r.commandTag.len > 0: + lastTag = r.commandTag return results @@ -1251,15 +1422,23 @@ proc simpleExec*( ## Lighter than `exec` for parameter-less commands (no Parse/Bind/Describe overhead). ## On timeout, the connection is marked csClosed (protocol out of sync). var tag: string - if timeout > ZeroDuration: - try: - tag = await simpleExecImpl(conn, sql, timeout).wait(timeout) - except AsyncTimeoutError: - conn.cancelNoWait() - conn.state = csClosed - raise newException(PgTimeoutError, "simpleExec timed out") - else: - tag = await simpleExecImpl(conn, sql) + withConnTracing( + conn, + onQueryStart, + onQueryEnd, + TraceQueryStartData(sql: sql, isExec: true), + TraceQueryEndData, + TraceQueryEndData(commandTag: tag), + ): + if timeout > ZeroDuration: + try: + tag = await simpleExecImpl(conn, sql, timeout).wait(timeout) + except AsyncTimeoutError: + conn.cancelNoWait() + conn.state = csClosed + raise newException(PgTimeoutError, "simpleExec timed out") + else: + tag = await simpleExecImpl(conn, sql) return initCommandResult(tag) proc isConnected(conn: PgConnection): bool = @@ -1404,10 +1583,26 @@ proc connect*(config: ConnConfig): Future[PgConnection] = PgConnectionError, "Could not connect to any host: " & errors.join("; ") ) - if config.connectTimeout != default(Duration): - perform().wait(config.connectTimeout) - else: - perform() + proc wrapped(): Future[PgConnection] {.async.} = + let hosts = config.getHosts() + var conn: PgConnection + withTracing( + config.tracer, + onConnectStart, + onConnectEnd, + TraceConnectStartData(hosts: hosts), + TraceConnectEndData, + TraceConnectEndData(conn: conn), + ): + conn = + if config.connectTimeout != default(Duration): + await perform().wait(config.connectTimeout) + else: + await perform() + conn.tracer = config.tracer + return conn + + wrapped() proc onNotify*(conn: PgConnection, callback: NotifyCallback) = ## Set a callback invoked for each incoming NOTIFY message. diff --git a/async_postgres/pg_pool.nim b/async_postgres/pg_pool.nim index 4de434b..e921cec 100644 --- a/async_postgres/pg_pool.nim +++ b/async_postgres/pg_pool.nim @@ -29,6 +29,7 @@ type ## "DEALLOCATE ALL" (clear prepared statements only), ## "RESET ALL" (reset session parameters only). ## On failure, the connection is discarded. + tracer*: PgTracer ## Optional tracer for pool-level hooks (acquire/release) PooledConn = object ## An idle connection held by the pool with its last-used timestamp. @@ -259,27 +260,41 @@ proc release*(pool: PgPool, conn: PgConnection) = ## Return a connection to the pool. If the connection is broken or in a ## transaction, it is closed instead. If waiters are queued, the connection ## is handed directly to the next waiter. + var traceCtx: TraceContext + if pool.config.tracer != nil and pool.config.tracer.onPoolReleaseStart != nil: + traceCtx = + pool.config.tracer.onPoolReleaseStart(TracePoolReleaseStartData(conn: conn)) + + var wasClosed = false + var handedToWaiter = false if pool.closed or conn.state != csReady or conn.txStatus != tsIdle: if pool.active > 0: pool.active.dec pool.closeNoWait(conn) - return + wasClosed = true + else: + block dispatch: + while pool.waiters.len > 0: + let waiter = pool.waiters.popFirst() + if waiter.cancelled: + continue + pool.waiterCount.dec + waiter.fut.complete(conn) + handedToWaiter = true + break dispatch + if pool.active > 0: + pool.active.dec + pool.idle.addLast(PooledConn(conn: conn, lastUsedAt: pool.cachedNow)) - while pool.waiters.len > 0: - let waiter = pool.waiters.popFirst() - if waiter.cancelled: - continue - pool.waiterCount.dec - waiter.fut.complete(conn) - return - if pool.active > 0: - pool.active.dec - pool.idle.addLast(PooledConn(conn: conn, lastUsedAt: pool.cachedNow)) + if pool.config.tracer != nil and pool.config.tracer.onPoolReleaseEnd != nil: + pool.config.tracer.onPoolReleaseEnd( + traceCtx, + TracePoolReleaseEndData(wasClosed: wasClosed, handedToWaiter: handedToWaiter), + ) -proc acquire*(pool: PgPool): Future[PgConnection] {.async.} = - ## Acquire a connection from the pool. Tries idle connections first (with - ## health checks), creates a new one if under `maxSize`, or waits for a - ## release. Raises `PgPoolError` on timeout or if the pool is closed. +type AcquireResult = tuple[conn: PgConnection, wasCreated: bool] + +proc acquireImpl(pool: PgPool): Future[AcquireResult] {.async.} = if pool.closed: raise newException(PgPoolError, "Pool is closed") @@ -323,7 +338,7 @@ proc acquire*(pool: PgPool): Future[PgConnection] {.async.} = continue pool.active.inc recordAcquire() - return pc.conn + return (pc.conn, false) # No idle connections; create new if under limit if pool.active < pool.config.maxSize: @@ -332,7 +347,7 @@ proc acquire*(pool: PgPool): Future[PgConnection] {.async.} = let conn = await connect(pool.config.connConfig) pool.metrics.createCount.inc recordAcquire() - return conn + return (conn, true) except CatchableError as e: pool.active.dec raise e @@ -351,7 +366,7 @@ proc acquire*(pool: PgPool): Future[PgConnection] {.async.} = try: let conn = await fut.wait(pool.config.acquireTimeout) recordAcquire() - return conn + return (conn, false) except AsyncTimeoutError: waiter.cancelled = true pool.waiterCount.dec @@ -366,7 +381,25 @@ proc acquire*(pool: PgPool): Future[PgConnection] {.async.} = else: let conn = await fut recordAcquire() - return conn + return (conn, false) + +proc acquire*(pool: PgPool): Future[PgConnection] {.async.} = + ## Acquire a connection from the pool. Tries idle connections first (with + ## health checks), creates a new one if under `maxSize`, or waits for a + ## release. Raises `PgPoolError` on timeout or if the pool is closed. + var ar: AcquireResult + withTracing( + pool.config.tracer, + onPoolAcquireStart, + onPoolAcquireEnd, + TracePoolAcquireStartData( + idleCount: pool.idle.len, activeCount: pool.active, maxSize: pool.config.maxSize + ), + TracePoolAcquireEndData, + TracePoolAcquireEndData(conn: ar.conn, wasCreated: ar.wasCreated), + ): + ar = await pool.acquireImpl() + return ar.conn template withConnection*(pool: PgPool, conn, body: untyped) = ## Acquire a connection, execute `body`, then release it back to the pool. diff --git a/tests/all_tests.nim b/tests/all_tests.nim index 64ee273..dfc116e 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -1,4 +1,4 @@ import test_advisory_lock, test_auth, test_dsn, test_e2e, test_keepalive, test_largeobject, - test_pool, test_protocol, test_rowdata, test_sql, test_ssl, test_types, + test_pool, test_protocol, test_rowdata, test_sql, test_ssl, test_tracing, test_types, test_pool_cluster diff --git a/tests/test_tracing.nim b/tests/test_tracing.nim new file mode 100644 index 0000000..31d805b --- /dev/null +++ b/tests/test_tracing.nim @@ -0,0 +1,727 @@ +import std/[unittest, strutils] + +import ../async_postgres/async_backend + +import ../async_postgres/[pg_client, pg_types, pg_protocol] +import ../async_postgres/pg_pool {.all.} +import ../async_postgres/pg_connection {.all.} + +const + PgHost = "127.0.0.1" + PgPort = 15432 + PgUser = "test" + PgPassword = "test" + PgDatabase = "test" + +proc plainConfig(): ConnConfig = + ConnConfig( + host: PgHost, + port: PgPort, + user: PgUser, + password: PgPassword, + database: PgDatabase, + sslMode: sslDisable, + ) + +proc toBytes(s: string): seq[byte] = + @(s.toOpenArrayByte(0, s.high)) + +# TraceContext helpers +type + SpanKind = enum + skConnect + skQuery + skPrepare + skPipeline + skCopy + skPoolAcquire + skPoolRelease + + Span = ref object of RootObj + kind: SpanKind + id: int + +var nextSpanId = 0 + +proc newSpan(kind: SpanKind): Span = + nextSpanId.inc + Span(kind: kind, id: nextSpanId) + +# Record types +type + ConnectStartRec = object + hosts: seq[HostEntry] + + ConnectEndRec = object + spanId: int + hasConn: bool + hasErr: bool + + QueryStartRec = object + sql: string + isExec: bool + + QueryEndRec = object + spanId: int + commandTag: string + rowCount: int64 + hasErr: bool + + PrepareStartRec = object + name: string + sql: string + + PrepareEndRec = object + spanId: int + hasErr: bool + + PipelineStartRec = object + opCount: int + + PipelineEndRec = object + spanId: int + hasErr: bool + + CopyStartRec = object + sql: string + direction: TraceCopyDirection + + CopyEndRec = object + spanId: int + commandTag: string + hasErr: bool + + PoolAcquireStartRec = object + idleCount: int + activeCount: int + maxSize: int + + PoolAcquireEndRec = object + spanId: int + hasConn: bool + wasCreated: bool + hasErr: bool + + PoolReleaseStartRec = object + hasConn: bool + + PoolReleaseEndRec = object + spanId: int + wasClosed: bool + handedToWaiter: bool + + TraceLog = ref object + connectStarts: seq[ConnectStartRec] + connectEnds: seq[ConnectEndRec] + queryStarts: seq[QueryStartRec] + queryEnds: seq[QueryEndRec] + prepareStarts: seq[PrepareStartRec] + prepareEnds: seq[PrepareEndRec] + pipelineStarts: seq[PipelineStartRec] + pipelineEnds: seq[PipelineEndRec] + copyStarts: seq[CopyStartRec] + copyEnds: seq[CopyEndRec] + poolAcquireStarts: seq[PoolAcquireStartRec] + poolAcquireEnds: seq[PoolAcquireEndRec] + poolReleaseStarts: seq[PoolReleaseStartRec] + poolReleaseEnds: seq[PoolReleaseEndRec] + +proc newTraceLog(): TraceLog = + TraceLog() + +proc buildTracer(log: TraceLog): PgTracer = + let tracer = PgTracer() + + tracer.onConnectStart = proc( + data: TraceConnectStartData + ): TraceContext {.gcsafe, raises: [].} = + log.connectStarts.add(ConnectStartRec(hosts: data.hosts)) + let span = newSpan(skConnect) + return span + + tracer.onConnectEnd = proc( + ctx: TraceContext, data: TraceConnectEndData + ) {.gcsafe, raises: [].} = + let span = Span(ctx) + log.connectEnds.add( + ConnectEndRec(spanId: span.id, hasConn: data.conn != nil, hasErr: data.err != nil) + ) + + tracer.onQueryStart = proc( + conn: PgConnection, data: TraceQueryStartData + ): TraceContext {.gcsafe, raises: [].} = + log.queryStarts.add(QueryStartRec(sql: data.sql, isExec: data.isExec)) + let span = newSpan(skQuery) + return span + + tracer.onQueryEnd = proc( + ctx: TraceContext, conn: PgConnection, data: TraceQueryEndData + ) {.gcsafe, raises: [].} = + let span = Span(ctx) + log.queryEnds.add( + QueryEndRec( + spanId: span.id, + commandTag: data.commandTag, + rowCount: data.rowCount, + hasErr: data.err != nil, + ) + ) + + tracer.onPrepareStart = proc( + conn: PgConnection, data: TracePrepareStartData + ): TraceContext {.gcsafe, raises: [].} = + log.prepareStarts.add(PrepareStartRec(name: data.name, sql: data.sql)) + let span = newSpan(skPrepare) + return span + + tracer.onPrepareEnd = proc( + ctx: TraceContext, conn: PgConnection, data: TracePrepareEndData + ) {.gcsafe, raises: [].} = + let span = Span(ctx) + log.prepareEnds.add(PrepareEndRec(spanId: span.id, hasErr: data.err != nil)) + + tracer.onPipelineStart = proc( + conn: PgConnection, data: TracePipelineStartData + ): TraceContext {.gcsafe, raises: [].} = + log.pipelineStarts.add(PipelineStartRec(opCount: data.opCount)) + let span = newSpan(skPipeline) + return span + + tracer.onPipelineEnd = proc( + ctx: TraceContext, conn: PgConnection, data: TracePipelineEndData + ) {.gcsafe, raises: [].} = + let span = Span(ctx) + log.pipelineEnds.add(PipelineEndRec(spanId: span.id, hasErr: data.err != nil)) + + tracer.onCopyStart = proc( + conn: PgConnection, data: TraceCopyStartData + ): TraceContext {.gcsafe, raises: [].} = + log.copyStarts.add(CopyStartRec(sql: data.sql, direction: data.direction)) + let span = newSpan(skCopy) + return span + + tracer.onCopyEnd = proc( + ctx: TraceContext, conn: PgConnection, data: TraceCopyEndData + ) {.gcsafe, raises: [].} = + let span = Span(ctx) + log.copyEnds.add( + CopyEndRec(spanId: span.id, commandTag: data.commandTag, hasErr: data.err != nil) + ) + + tracer.onPoolAcquireStart = proc( + data: TracePoolAcquireStartData + ): TraceContext {.gcsafe, raises: [].} = + log.poolAcquireStarts.add( + PoolAcquireStartRec( + idleCount: data.idleCount, activeCount: data.activeCount, maxSize: data.maxSize + ) + ) + let span = newSpan(skPoolAcquire) + return span + + tracer.onPoolAcquireEnd = proc( + ctx: TraceContext, data: TracePoolAcquireEndData + ) {.gcsafe, raises: [].} = + let span = Span(ctx) + log.poolAcquireEnds.add( + PoolAcquireEndRec( + spanId: span.id, + hasConn: data.conn != nil, + wasCreated: data.wasCreated, + hasErr: data.err != nil, + ) + ) + + tracer.onPoolReleaseStart = proc( + data: TracePoolReleaseStartData + ): TraceContext {.gcsafe, raises: [].} = + log.poolReleaseStarts.add(PoolReleaseStartRec(hasConn: data.conn != nil)) + let span = newSpan(skPoolRelease) + return span + + tracer.onPoolReleaseEnd = proc( + ctx: TraceContext, data: TracePoolReleaseEndData + ) {.gcsafe, raises: [].} = + let span = Span(ctx) + log.poolReleaseEnds.add( + PoolReleaseEndRec( + spanId: span.id, wasClosed: data.wasClosed, handedToWaiter: data.handedToWaiter + ) + ) + + return tracer + +proc tracedConfig(tracer: PgTracer): ConnConfig = + var cfg = plainConfig() + cfg.tracer = tracer + return cfg + +suite "Tracing: connect": + test "onConnectStart and onConnectEnd are called with correct context": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + + doAssert log.connectStarts.len == 1 + doAssert log.connectStarts[0].hosts.len == 1 + doAssert log.connectStarts[0].hosts[0].host == PgHost + doAssert log.connectStarts[0].hosts[0].port == PgPort + doAssert log.connectEnds.len == 1 + doAssert log.connectEnds[0].hasConn + doAssert not log.connectEnds[0].hasErr + # Verify context correlation + doAssert log.connectEnds[0].spanId > 0 + + await conn.close() + + waitFor t() + + test "onConnectEnd receives error on connection failure": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + var cfg = tracedConfig(tracer) + cfg.port = 19999 # wrong port + + var raised = false + try: + let conn = await connect(cfg) + await conn.close() + except CatchableError: + raised = true + + doAssert raised + doAssert log.connectStarts.len == 1 + doAssert log.connectEnds.len == 1 + doAssert log.connectEnds[0].hasErr + + waitFor t() + +suite "Tracing: exec": + test "onQueryStart(isExec=true) and onQueryEnd with commandTag": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + log.queryStarts.setLen(0) + log.queryEnds.setLen(0) + + discard await conn.exec("SELECT 1") + + doAssert log.queryStarts.len == 1 + doAssert log.queryStarts[0].sql == "SELECT 1" + doAssert log.queryStarts[0].isExec == true + doAssert log.queryEnds.len == 1 + doAssert log.queryEnds[0].commandTag.len > 0 + doAssert not log.queryEnds[0].hasErr + # Context correlation + doAssert log.queryEnds[0].spanId > 0 + + await conn.close() + + waitFor t() + +suite "Tracing: query": + test "onQueryStart(isExec=false) and onQueryEnd with rowCount": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + log.queryStarts.setLen(0) + log.queryEnds.setLen(0) + + discard await conn.query("SELECT generate_series(1, 3)") + + doAssert log.queryStarts.len == 1 + doAssert log.queryStarts[0].sql == "SELECT generate_series(1, 3)" + doAssert log.queryStarts[0].isExec == false + doAssert log.queryEnds.len == 1 + doAssert log.queryEnds[0].rowCount == 3 + doAssert not log.queryEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: simpleExec": + test "onQueryStart and onQueryEnd are called": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + log.queryStarts.setLen(0) + log.queryEnds.setLen(0) + + discard await conn.simpleExec("SELECT 1") + + doAssert log.queryStarts.len == 1 + doAssert log.queryStarts[0].isExec == true + doAssert log.queryEnds.len == 1 + doAssert not log.queryEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: simpleQuery": + test "onQueryStart and onQueryEnd are called": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + log.queryStarts.setLen(0) + log.queryEnds.setLen(0) + + discard await conn.simpleQuery("SELECT 1; SELECT 2") + + doAssert log.queryStarts.len == 1 + doAssert log.queryStarts[0].isExec == false + doAssert log.queryEnds.len == 1 + doAssert log.queryEnds[0].rowCount == 2 # 1 row per statement + doAssert not log.queryEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: prepare": + test "onPrepareStart and onPrepareEnd with context correlation": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + + discard await conn.prepare("test_stmt", "SELECT $1::int4") + + doAssert log.prepareStarts.len == 1 + doAssert log.prepareStarts[0].name == "test_stmt" + doAssert log.prepareStarts[0].sql == "SELECT $1::int4" + doAssert log.prepareEnds.len == 1 + doAssert not log.prepareEnds[0].hasErr + doAssert log.prepareEnds[0].spanId > 0 + + await conn.close() + + waitFor t() + +suite "Tracing: PreparedStatement.execute": + test "onQueryStart and onQueryEnd are called": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + + let stmt = await conn.prepare("test_exec_stmt", "SELECT $1::int4") + log.queryStarts.setLen(0) + log.queryEnds.setLen(0) + + discard await stmt.execute(@[toPgParam(42'i32)]) + + doAssert log.queryStarts.len == 1 + doAssert log.queryStarts[0].sql == "SELECT $1::int4" + doAssert log.queryStarts[0].isExec == false + doAssert log.queryEnds.len == 1 + doAssert log.queryEnds[0].rowCount == 1 + doAssert not log.queryEnds[0].hasErr + doAssert log.queryEnds[0].spanId > 0 + + await conn.close() + + waitFor t() + +suite "Tracing: copyIn": + test "onCopyStart(tcdIn) and onCopyEnd with commandTag": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + + discard + await conn.exec("CREATE TEMP TABLE test_trace_copy_in (id int, name text)") + + var rows: seq[byte] + rows.add("1\tAlice\n".toBytes()) + rows.add("2\tBob\n".toBytes()) + discard await conn.copyIn("COPY test_trace_copy_in FROM STDIN", rows) + + doAssert log.copyStarts.len == 1 + doAssert log.copyStarts[0].direction == tcdIn + doAssert "test_trace_copy_in" in log.copyStarts[0].sql + doAssert log.copyEnds.len == 1 + doAssert log.copyEnds[0].commandTag.startsWith("COPY") + doAssert not log.copyEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: copyOut": + test "onCopyStart(tcdOut) and onCopyEnd with commandTag": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + + discard + await conn.exec("CREATE TEMP TABLE test_trace_copy_out (id int, name text)") + discard await conn.exec("INSERT INTO test_trace_copy_out VALUES (1, 'Alice')") + + log.copyStarts.setLen(0) + log.copyEnds.setLen(0) + + discard await conn.copyOut("COPY test_trace_copy_out TO STDOUT") + + doAssert log.copyStarts.len == 1 + doAssert log.copyStarts[0].direction == tcdOut + doAssert log.copyEnds.len == 1 + doAssert log.copyEnds[0].commandTag.startsWith("COPY") + doAssert not log.copyEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: copyInStream": + test "onCopyStart(tcdIn) and onCopyEnd with commandTag": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + + discard await conn.exec( + "CREATE TEMP TABLE test_trace_copy_in_stream (id int, name text)" + ) + + var chunks = @["1\tAlice\n".toBytes(), "2\tBob\n".toBytes()] + var idx = 0 + let cb = makeCopyInCallback: + if idx < chunks.len: + let data = chunks[idx] + idx.inc + data + else: + @[] + + log.copyStarts.setLen(0) + log.copyEnds.setLen(0) + + discard await conn.copyInStream("COPY test_trace_copy_in_stream FROM STDIN", cb) + + doAssert log.copyStarts.len == 1 + doAssert log.copyStarts[0].direction == tcdIn + doAssert "test_trace_copy_in_stream" in log.copyStarts[0].sql + doAssert log.copyEnds.len == 1 + doAssert log.copyEnds[0].commandTag.startsWith("COPY") + doAssert not log.copyEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: copyOutStream": + test "onCopyStart(tcdOut) and onCopyEnd with commandTag": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + + discard await conn.exec( + "CREATE TEMP TABLE test_trace_copy_out_stream (id int, name text)" + ) + discard + await conn.exec("INSERT INTO test_trace_copy_out_stream VALUES (1, 'Alice')") + + log.copyStarts.setLen(0) + log.copyEnds.setLen(0) + + var received: seq[seq[byte]] + let cb = makeCopyOutCallback: + received.add(data) + + discard await conn.copyOutStream("COPY test_trace_copy_out_stream TO STDOUT", cb) + + doAssert received.len > 0 + doAssert log.copyStarts.len == 1 + doAssert log.copyStarts[0].direction == tcdOut + doAssert log.copyEnds.len == 1 + doAssert log.copyEnds[0].commandTag.startsWith("COPY") + doAssert not log.copyEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: pipeline": + test "onPipelineStart and onPipelineEnd with opCount": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + + let p = newPipeline(conn) + p.addQuery("SELECT 1::int4") + p.addQuery("SELECT 2::int4") + p.addQuery("SELECT 3::int4") + discard await p.execute() + + doAssert log.pipelineStarts.len == 1 + doAssert log.pipelineStarts[0].opCount == 3 + doAssert log.pipelineEnds.len == 1 + doAssert not log.pipelineEnds[0].hasErr + doAssert log.pipelineEnds[0].spanId > 0 + + await conn.close() + + waitFor t() + +suite "Tracing: pool acquire/release": + test "acquire and release hooks are called": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + var poolCfg = initPoolConfig(tracedConfig(tracer), minSize = 0, maxSize = 2) + poolCfg.tracer = tracer + let pool = await newPool(poolCfg) + + let conn = await pool.acquire() + + doAssert log.poolAcquireStarts.len == 1 + doAssert log.poolAcquireStarts[0].maxSize == 2 + doAssert log.poolAcquireEnds.len == 1 + doAssert log.poolAcquireEnds[0].hasConn + doAssert log.poolAcquireEnds[0].wasCreated == true + doAssert not log.poolAcquireEnds[0].hasErr + + pool.release(conn) + + doAssert log.poolReleaseStarts.len == 1 + doAssert log.poolReleaseStarts[0].hasConn + doAssert log.poolReleaseEnds.len == 1 + doAssert not log.poolReleaseEnds[0].wasClosed + doAssert not log.poolReleaseEnds[0].handedToWaiter + + # Re-acquire should reuse the idle connection + let conn2 = await pool.acquire() + + doAssert log.poolAcquireStarts.len == 2 + doAssert log.poolAcquireEnds.len == 2 + doAssert log.poolAcquireEnds[1].wasCreated == false + + pool.release(conn2) + await pool.close() + + waitFor t() + + test "release hands connection to waiter": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + var poolCfg = initPoolConfig(tracedConfig(tracer), minSize = 0, maxSize = 1) + poolCfg.tracer = tracer + let pool = await newPool(poolCfg) + + let conn1 = await pool.acquire() + # maxSize=1, so next acquire will wait + let fut = pool.acquire() + # Release conn1 -- should hand to the waiter + pool.release(conn1) + let conn2 = await fut + + doAssert log.poolReleaseEnds.len == 1 + doAssert not log.poolReleaseEnds[0].wasClosed + doAssert log.poolReleaseEnds[0].handedToWaiter + + pool.release(conn2) + await pool.close() + + waitFor t() + +suite "Tracing: queryEach": + test "onQueryStart and onQueryEnd are called with rowCount": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + log.queryStarts.setLen(0) + log.queryEnds.setLen(0) + + var rows: seq[string] + discard await conn.queryEach( + "SELECT generate_series(1, 3)", + callback = proc(row: Row) = + rows.add(row.getStr(0)), + ) + + doAssert rows.len == 3 + doAssert log.queryStarts.len == 1 + doAssert log.queryStarts[0].sql == "SELECT generate_series(1, 3)" + doAssert log.queryStarts[0].isExec == false + doAssert log.queryEnds.len == 1 + doAssert log.queryEnds[0].rowCount == 3 + doAssert not log.queryEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: execInTransaction": + test "onQueryStart(isExec=true) and onQueryEnd with commandTag": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + log.queryStarts.setLen(0) + log.queryEnds.setLen(0) + + discard + await conn.execInTransaction("CREATE TEMP TABLE test_trace_exec_tx (id int)") + + doAssert log.queryStarts.len == 1 + doAssert log.queryStarts[0].isExec == true + doAssert log.queryEnds.len == 1 + doAssert log.queryEnds[0].commandTag.len > 0 + doAssert not log.queryEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: queryInTransaction": + test "onQueryStart(isExec=false) and onQueryEnd with rowCount": + proc t() {.async.} = + let log = newTraceLog() + let tracer = buildTracer(log) + let conn = await connect(tracedConfig(tracer)) + log.queryStarts.setLen(0) + log.queryEnds.setLen(0) + + let qr = await conn.queryInTransaction("SELECT generate_series(1, 3)") + + doAssert qr.rowCount == 3 + doAssert log.queryStarts.len == 1 + doAssert log.queryStarts[0].sql == "SELECT generate_series(1, 3)" + doAssert log.queryStarts[0].isExec == false + doAssert log.queryEnds.len == 1 + doAssert log.queryEnds[0].rowCount == 3 + doAssert not log.queryEnds[0].hasErr + + await conn.close() + + waitFor t() + +suite "Tracing: nil tracer": + test "operations work without tracer": + proc t() {.async.} = + let conn = await connect(plainConfig()) + doAssert conn.tracer == nil + + discard await conn.exec("SELECT 1") + let qr = await conn.query("SELECT 1") + doAssert qr.rowCount == 1 + discard await conn.simpleExec("SELECT 1") + discard await conn.simpleQuery("SELECT 1") + + await conn.close() + + waitFor t()