From e91a03f623087527577d43a0d8ff8ab3e2da5842 Mon Sep 17 00:00:00 2001 From: fox0430 Date: Tue, 21 Apr 2026 19:35:29 +0900 Subject: [PATCH 1/2] Harden decoder against malformed input and add fuzz and network failure tests --- async_postgres/pg_connection.nim | 15 +- tests/all_tests.nim | 4 +- tests/mock_pg_server.nim | 209 ++++++++++++++++++ tests/test_network_failure.nim | 357 +++++++++++++++++++++++++++++++ 4 files changed, 580 insertions(+), 5 deletions(-) create mode 100644 tests/mock_pg_server.nim create mode 100644 tests/test_network_failure.nim diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index f7502cf..eb9478c 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -856,12 +856,21 @@ proc nextMessage*( ## Returns none if the buffer doesn't contain a complete message. ## Notification/Notice messages are dispatched internally. ## DataRow messages are counted (if rowCount != nil) and consumed. + ## + ## On `ProtocolError` the protocol stream is desynchronised — the connection + ## is transitioned to `csClosed` before re-raising so that it is never + ## reused (in particular, by the connection pool). var pos = conn.recvBufStart while true: var consumed: int - let res = parseBackendMessage( - conn.recvBuf.toOpenArray(pos, conn.recvBuf.len - 1), consumed, rowData - ) + let res = + try: + parseBackendMessage( + conn.recvBuf.toOpenArray(pos, conn.recvBuf.len - 1), consumed, rowData + ) + except ProtocolError: + conn.state = csClosed + raise if res.state == psIncomplete: return none(BackendMessage) pos += consumed diff --git a/tests/all_tests.nim b/tests/all_tests.nim index eb6226f..fdfd9fe 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -1,6 +1,6 @@ {.push warning[UnusedImport]: off.} import test_advisory_lock, test_auth, test_dsn, test_e2e, test_keepalive, test_largeobject, - test_pool, test_protocol, test_protocol_fuzz, test_rowdata, test_sql, test_ssl, - test_tracing, test_types, test_pool_cluster + test_network_failure, test_pool, test_protocol, test_protocol_fuzz, test_rowdata, + test_sql, test_ssl, test_tracing, test_types, test_pool_cluster {.pop.} diff --git a/tests/mock_pg_server.nim b/tests/mock_pg_server.nim new file mode 100644 index 0000000..10719a2 --- /dev/null +++ b/tests/mock_pg_server.nim @@ -0,0 +1,209 @@ +## In-process PostgreSQL wire-protocol mock server. +## +## Starts a TCP listener on 127.0.0.1 with an ephemeral port and lets test code +## script arbitrary byte sequences back to a real `PgConnection`. Used to +## exercise code paths that a real PostgreSQL server would never reproduce on +## demand: mid-message disconnects, malformed responses, truncated frames, +## stalled senders, etc. +## +## Works with both `chronos` and `asyncdispatch` via the same unified API. +## The chronos and asyncdispatch branches expose identical `MockServer` / +## `MockClient` types and procs so tests can be backend-agnostic. + +import ../async_postgres/[async_backend, pg_protocol] + +when hasAsyncDispatch: + import std/asyncnet + +# Types and low-level transport + +when hasChronos: + type + MockServer* = object + server: StreamServer + port*: int + + MockClient* = StreamTransport + + proc startMockServer*(): MockServer = + let server = createStreamServer(initTAddress("127.0.0.1", 0)) + MockServer(server: server, port: int(server.localAddress().port)) + + proc accept*(ms: MockServer): Future[MockClient] = + ms.server.accept() + + proc closeServer*(ms: MockServer) {.async.} = + await ms.server.closeWait() + + proc closeClient*(client: MockClient) {.async.} = + await client.closeWait() + + proc readN*(client: MockClient, n: int): Future[seq[byte]] {.async.} = + result = newSeq[byte](n) + var offset = 0 + while offset < n: + let bytesRead = await client.readOnce(addr result[offset], n - offset) + if bytesRead == 0: + raise newException(CatchableError, "Connection closed prematurely") + offset += bytesRead + + proc sendBytes*(client: MockClient, data: seq[byte]) {.async.} = + if data.len > 0: + discard await client.write(data) + +elif hasAsyncDispatch: + type + MockServer* = object + socket: AsyncSocket + port*: int + + MockClient* = AsyncSocket + + proc startMockServer*(): MockServer = + let sock = newAsyncSocket(buffered = false) + sock.setSockOpt(OptReuseAddr, true) + sock.bindAddr(Port(0)) + let port = int(sock.getLocalAddr()[1]) + sock.listen() + MockServer(socket: sock, port: port) + + proc accept*(ms: MockServer): Future[MockClient] = + ms.socket.accept() + + proc closeServer*(ms: MockServer) {.async.} = + ms.socket.close() + + proc closeClient*(client: MockClient) {.async.} = + client.close() + + proc readN*(client: MockClient, n: int): Future[seq[byte]] {.async.} = + result = newSeq[byte](n) + var offset = 0 + while offset < n: + let data = await client.recv(n - offset) + if data.len == 0: + raise newException(CatchableError, "Connection closed prematurely") + copyMem(addr result[offset], addr data[0], data.len) + offset += data.len + + proc sendBytes*(client: MockClient, data: seq[byte]) {.async.} = + if data.len > 0: + await client.send(cast[string](data)) + +# Message-building helpers + +proc buildBackendMsg*(msgType: char, body: openArray[byte]): seq[byte] = + ## Wrap `body` with a backend message header (1 type byte + 4 length bytes). + result = @[byte(msgType)] + result.addInt32(int32(4 + body.len)) + result.add(@body) + +proc buildAuthOk*(): seq[byte] = + buildBackendMsg('R', @[byte 0, 0, 0, 0]) + +proc buildParameterStatus*(name, value: string): seq[byte] = + var body: seq[byte] + for c in name: + body.add(byte(c)) + body.add(0'u8) + for c in value: + body.add(byte(c)) + body.add(0'u8) + buildBackendMsg('S', body) + +proc buildBackendKeyData*(pid, secretKey: int32): seq[byte] = + var body: seq[byte] + body.addInt32(pid) + body.addInt32(secretKey) + buildBackendMsg('K', body) + +proc buildReadyForQuery*(status: char = 'I'): seq[byte] = + buildBackendMsg('Z', @[byte(status)]) + +proc buildCommandComplete*(tag: string): seq[byte] = + var body: seq[byte] + for c in tag: + body.add(byte(c)) + body.add(0'u8) + buildBackendMsg('C', body) + +proc buildErrorResponse*(sqlState, message: string): seq[byte] = + ## Minimal ErrorResponse with severity 'S', sqlstate 'C', message 'M'. + var body: seq[byte] + body.add(byte('S')) + for c in "ERROR": + body.add(byte(c)) + body.add(0'u8) + body.add(byte('C')) + for c in sqlState: + body.add(byte(c)) + body.add(0'u8) + body.add(byte('M')) + for c in message: + body.add(byte(c)) + body.add(0'u8) + body.add(0'u8) # field list terminator + buildBackendMsg('E', body) + +# Frontend readers + +proc drainStartupMessage*(client: MockClient) {.async.} = + ## Consume the initial StartupMessage sent by the client (no type byte, + ## 4-byte length prefix). + let lenBuf = await readN(client, 4) + let msgLen = decodeInt32(lenBuf, 0) + if msgLen > 4: + discard await readN(client, msgLen - 4) + +proc drainFrontendMessage*( + client: MockClient +): Future[tuple[msgType: char, body: seq[byte]]] {.async.} = + ## Read a single post-startup frontend message: 1 type byte, int32 length, body. + let head = await readN(client, 1) + result.msgType = char(head[0]) + let lenBuf = await readN(client, 4) + let msgLen = decodeInt32(lenBuf, 0) + if msgLen > 4: + result.body = await readN(client, msgLen - 4) + +# Full handshake shortcut + +proc sendFullHandshake*( + client: MockClient, + pid: int32 = 1234, + secretKey: int32 = 5678, + params: seq[(string, string)] = @[], +) {.async.} = + ## Send AuthOk + ParameterStatus* + BackendKeyData + ReadyForQuery in one + ## round-trip. Use after `drainStartupMessage` and before the client's first + ## real query. + var resp: seq[byte] + resp.add(buildAuthOk()) + for (k, v) in params: + resp.add(buildParameterStatus(k, v)) + resp.add(buildBackendKeyData(pid, secretKey)) + resp.add(buildReadyForQuery('I')) + await sendBytes(client, resp) + +proc sendEmptyHstoreDiscovery*(client: MockClient) {.async.} = + ## Respond to the post-handshake hstore OID discovery query with an empty + ## result set (CommandComplete "SELECT 0" + ReadyForQuery). Leaves the + ## client in `csReady` state. + var resp: seq[byte] + resp.add(buildCommandComplete("SELECT 0")) + resp.add(buildReadyForQuery('I')) + await sendBytes(client, resp) + +proc acceptAndReady*( + ms: MockServer, pid: int32 = 1234, secretKey: int32 = 5678 +): Future[MockClient] {.async.} = + ## End-to-end helper: accept a client, complete startup + handshake, and + ## answer the hstore OID discovery query with an empty result. On return + ## the client is positioned at `csReady` with no outstanding requests. + let client = await ms.accept() + await drainStartupMessage(client) + await sendFullHandshake(client, pid, secretKey) + # The connect() path issues the hstore discovery query next. + discard await drainFrontendMessage(client) + await sendEmptyHstoreDiscovery(client) + client diff --git a/tests/test_network_failure.nim b/tests/test_network_failure.nim new file mode 100644 index 0000000..3ac4efc --- /dev/null +++ b/tests/test_network_failure.nim @@ -0,0 +1,357 @@ +## Network-failure tests using an in-process mock PostgreSQL server. +## +## These exercise code paths that a real server will not reproduce on demand: +## mid-handshake disconnects, mid-query disconnects, malformed backend +## messages, and unknown message type bytes. Verifies that the client raises +## the expected exception type and leaves the connection in a state that +## prevents accidental reuse (`csClosed`). + +import std/[unittest] + +import ../async_postgres/[async_backend, pg_protocol] +import ../async_postgres/pg_connection {.all.} + +import ./mock_pg_server + +proc mockConfig(port: int): ConnConfig = + ConnConfig( + host: "127.0.0.1", port: port, user: "test", database: "test", sslMode: sslDisable + ) + +# Handshake failures + +suite "Network failure: handshake": + test "server closes immediately after accept": + var raised = false + + proc testBody() {.async.} = + let ms = startMockServer() + proc serverHandler() {.async.} = + let st = await ms.accept() + await closeClient(st) + + let serverFut = serverHandler() + try: + let conn = await connect(mockConfig(ms.port)) + await conn.close() + except CatchableError: + raised = true + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + + test "server reads startup then disconnects": + var raised = false + + proc testBody() {.async.} = + let ms = startMockServer() + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + await drainStartupMessage(st) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + try: + let conn = await connect(mockConfig(ms.port)) + await conn.close() + except CatchableError: + raised = true + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + + test "handshake succeeds but hstore discovery is dropped mid-response": + # Server completes auth + ready, but closes the socket after the client's + # hstore discovery query, before sending any response. + var raised = false + + proc testBody() {.async.} = + let ms = startMockServer() + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + await drainStartupMessage(st) + await sendFullHandshake(st) + discard await drainFrontendMessage(st) # the hstore query + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + try: + let conn = await connect(mockConfig(ms.port)) + await conn.close() + except CatchableError: + raised = true + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + +# Malformed / truncated backend messages + +suite "Network failure: malformed server messages": + test "unknown backend message type 'X' raises ProtocolError": + var raised = false + var finalState: PgConnState + + proc testBody() {.async.} = + let ms = startMockServer() + proc serverHandler() {.async.} = + let st = await acceptAndReady(ms) + try: + discard await drainFrontendMessage(st) # SELECT 1 + # Reply with an unknown message type wrapped in a valid frame. + await sendBytes(st, buildBackendMsg('X', @[byte 0])) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + let conn = await connect(mockConfig(ms.port)) + try: + discard await conn.simpleQuery("SELECT 1") + except ProtocolError: + raised = true + except CatchableError: + raised = true + finalState = conn.state + try: + await conn.close() + except CatchableError: + discard + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + check finalState == csClosed + + test "truncated RowDescription before full body arrives": + # Server sends a valid header claiming a larger body than it will deliver, + # then closes. The read loop should return an incomplete-parse state until + # the close signals EOF, which raises an error. + var raised = false + + proc testBody() {.async.} = + let ms = startMockServer() + proc serverHandler() {.async.} = + let st = await acceptAndReady(ms) + try: + discard await drainFrontendMessage(st) + # Claim a 100-byte message but only send 10 bytes then close. + var truncated: seq[byte] + truncated.add(byte('T')) + truncated.addInt32(100'i32) + for _ in 0 ..< 5: + truncated.add(0'u8) + await sendBytes(st, truncated) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + let conn = await connect(mockConfig(ms.port)) + try: + discard await conn.simpleQuery("SELECT 1") + except CatchableError: + raised = true + try: + await conn.close() + except CatchableError: + discard + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + + test "claimed msgLen below minimum (3) raises ProtocolError": + var raised = false + var gotProtocolError = false + + proc testBody() {.async.} = + let ms = startMockServer() + proc serverHandler() {.async.} = + let st = await acceptAndReady(ms) + try: + discard await drainFrontendMessage(st) + # 5-byte frame with msgLen=3 — parser must reject. + var buf = newSeq[byte](5) + buf[0] = byte('C') + buf[1] = 0 + buf[2] = 0 + buf[3] = 0 + buf[4] = 3 + await sendBytes(st, buf) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + let conn = await connect(mockConfig(ms.port)) + try: + discard await conn.simpleQuery("SELECT 1") + except ProtocolError: + raised = true + gotProtocolError = true + except CatchableError: + raised = true + try: + await conn.close() + except CatchableError: + discard + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + check gotProtocolError + + test "malformed CommandComplete without null terminator raises ProtocolError": + var raised = false + var gotProtocolError = false + + proc testBody() {.async.} = + let ms = startMockServer() + proc serverHandler() {.async.} = + let st = await acceptAndReady(ms) + try: + discard await drainFrontendMessage(st) + # CommandComplete body with no NUL terminator. + var buf: seq[byte] + buf.add(byte('C')) + buf.addInt32(int32(4 + 3)) + buf.add(byte('S')) + buf.add(byte('E')) + buf.add(byte('L')) # no trailing 0 + await sendBytes(st, buf) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + let conn = await connect(mockConfig(ms.port)) + try: + discard await conn.simpleQuery("SELECT 1") + except ProtocolError: + raised = true + gotProtocolError = true + except CatchableError: + raised = true + try: + await conn.close() + except CatchableError: + discard + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + check gotProtocolError + +# Mid-query disconnects + +suite "Network failure: mid-query disconnects": + test "server closes after sending partial DataRow": + # Note: csClosed here is reached via the EOF / socket-close path, not via + # the `ProtocolError -> csClosed` transition in `nextMessage`. If future + # changes decouple socket close from state transition, this assertion may + # start failing silently (the state would no longer be csClosed). + var raised = false + var finalState: PgConnState + + proc testBody() {.async.} = + let ms = startMockServer() + proc serverHandler() {.async.} = + let st = await acceptAndReady(ms) + try: + discard await drainFrontendMessage(st) + # Send a valid RowDescription with one int4 column (oid=23). + var rd: seq[byte] + rd.addInt16(1) # 1 field + for c in "n": + rd.add(byte(c)) + rd.add(0'u8) # name terminator + rd.addInt32(0'i32) # tableOid + rd.addInt16(0'i16) # columnAttrNum + rd.addInt32(23'i32) # typeOid (int4) + rd.addInt16(4'i16) # typeSize + rd.addInt32(-1'i32) # typeMod + rd.addInt16(0'i16) # formatCode + await sendBytes(st, buildBackendMsg('T', rd)) + # Start a DataRow but truncate it mid-column. + var dr: seq[byte] + dr.addInt16(1) # 1 column + dr.addInt32(100'i32) # claim 100 bytes + dr.add(byte(0)) # partial + # Claim the whole msgLen is correct for our buffer so parser reads it. + let body = dr + var frame: seq[byte] + frame.add(byte('D')) + frame.addInt32(int32(4 + body.len)) + frame.add(body) + await sendBytes(st, frame) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + let conn = await connect(mockConfig(ms.port)) + try: + discard await conn.simpleQuery("SELECT 1") + except CatchableError: + raised = true + finalState = conn.state + try: + await conn.close() + except CatchableError: + discard + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + check finalState == csClosed + + test "server closes after sending ErrorResponse without final field terminator": + var raised = false + + proc testBody() {.async.} = + let ms = startMockServer() + proc serverHandler() {.async.} = + let st = await acceptAndReady(ms) + try: + discard await drainFrontendMessage(st) + # ErrorResponse body with field 'M' value missing its NUL terminator. + var body: seq[byte] + body.add(byte('M')) + body.add(byte('o')) # start value, no NUL + await sendBytes(st, buildBackendMsg('E', body)) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + let conn = await connect(mockConfig(ms.port)) + try: + discard await conn.simpleQuery("SELECT 1") + except CatchableError: + raised = true + try: + await conn.close() + except CatchableError: + discard + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised From 0bf46d6bcb1800a10af695e717b3903cb1da202a Mon Sep 17 00:00:00 2001 From: fox0430 Date: Tue, 21 Apr 2026 19:55:06 +0900 Subject: [PATCH 2/2] fix --- async_postgres/pg_connection.nim | 17 +++++++++++------ async_postgres/pg_protocol.nim | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index eb9478c..9e488b1 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -698,7 +698,7 @@ else: const RecvBufSize = 131072 ## Size of the temporary read buffer for recv operations -proc dispatchNotification*(conn: PgConnection, msg: BackendMessage) = +proc dispatchNotification*(conn: PgConnection, msg: BackendMessage) {.raises: [].} = let notif = Notification( pid: msg.notifPid, channel: msg.notifChannel, payload: msg.notifPayload ) @@ -713,11 +713,16 @@ proc dispatchNotification*(conn: PgConnection, msg: BackendMessage) = if droppedNow > 0 and conn.notifyOverflowCallback != nil: conn.notifyOverflowCallback(droppedNow) if conn.notifyWaiter != nil and not conn.notifyWaiter.finished: - conn.notifyWaiter.complete() + # asyncdispatch's `Future.complete` has inferred effect `Exception` + # via the callback chain; swallow it to keep this proc `raises: []`. + try: + conn.notifyWaiter.complete() + except Exception: + discard if conn.notifyCallback != nil: conn.notifyCallback(notif) -proc dispatchNotice*(conn: PgConnection, msg: BackendMessage) = +proc dispatchNotice*(conn: PgConnection, msg: BackendMessage) {.raises: [].} = if conn.noticeCallback != nil: conn.noticeCallback(Notice(fields: msg.noticeFields)) @@ -851,7 +856,7 @@ proc fillRecvBuf*( proc nextMessage*( conn: PgConnection, rowData: RowData = nil, rowCount: ptr int32 = nil -): Option[BackendMessage] = +): Option[BackendMessage] {.raises: [ProtocolError].} = ## Synchronously parse the next message from the receive buffer. ## Returns none if the buffer doesn't contain a complete message. ## Notification/Notice messages are dispatched internally. @@ -868,9 +873,9 @@ proc nextMessage*( parseBackendMessage( conn.recvBuf.toOpenArray(pos, conn.recvBuf.len - 1), consumed, rowData ) - except ProtocolError: + except ProtocolError as e: conn.state = csClosed - raise + raise e if res.state == psIncomplete: return none(BackendMessage) pos += consumed diff --git a/async_postgres/pg_protocol.nim b/async_postgres/pg_protocol.nim index a8a0612..762f907 100644 --- a/async_postgres/pg_protocol.nim +++ b/async_postgres/pg_protocol.nim @@ -1069,7 +1069,7 @@ proc parseDataRowInto*(body: openArray[byte], rd: RowData) = proc parseBackendMessage*( buf: openArray[byte], consumed: var int, rowData: RowData = nil -): ParseResult = +): ParseResult {.raises: [ProtocolError].} = ## Parse a single backend message from `buf`. ## On success, sets `consumed` to the number of bytes used. ## The caller is responsible for discarding those bytes from the buffer.