From 7d91a448b8dcde2380c11edab7fa2f076a68f45a Mon Sep 17 00:00:00 2001 From: fox0430 Date: Mon, 13 Apr 2026 16:55:11 +0900 Subject: [PATCH] Fix recvBufStart update timing to prevent double-parsing on exception --- async_postgres/pg_client.nim | 2 +- async_postgres/pg_connection.nim | 3 +- tests/test_protocol.nim | 73 +++++++++++++++++++++++++++++++- 3 files changed, 74 insertions(+), 4 deletions(-) diff --git a/async_postgres/pg_client.nim b/async_postgres/pg_client.nim index e3ebbdf..97c6658 100644 --- a/async_postgres/pg_client.nim +++ b/async_postgres/pg_client.nim @@ -612,9 +612,9 @@ template queryEachRecvLoop( conn.recvBuf.toOpenArray(pos, conn.recvBuf.len - 1), consumed, rd ) if res.state == psIncomplete: - conn.recvBufStart = pos break # need more data pos += consumed + conn.recvBufStart = pos if res.state == psDataRow: # DataRow was parsed into rd — invoke callback, then reset for next row if callbackError == nil: diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index e1d4abc..69ac0e5 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -651,9 +651,9 @@ proc nextMessage*( conn.recvBuf.toOpenArray(pos, conn.recvBuf.len - 1), consumed, rowData ) if res.state == psIncomplete: - conn.recvBufStart = pos return none(BackendMessage) pos += consumed + conn.recvBufStart = pos if res.state == psDataRow: # DataRow already parsed in-place into rowData; just count it if rowCount != nil: @@ -668,7 +668,6 @@ proc nextMessage*( if res.message.kind == bmkDataRow and rowCount != nil: rowCount[] += 1 continue - conn.recvBufStart = pos return some(res.message) proc recvMessage*( diff --git a/tests/test_protocol.nim b/tests/test_protocol.nim index e2bc894..24defed 100644 --- a/tests/test_protocol.nim +++ b/tests/test_protocol.nim @@ -1,7 +1,10 @@ -import std/[unittest, options, strutils] +import std/[unittest, options, strutils, tables, importutils] +import ../async_postgres/async_backend import ../async_postgres/[pg_protocol, pg_connection] +privateAccess(PgConnection) + proc parseBackendMessage(buf: var seq[byte]): ParseResult = ## Test-only wrapper that preserves the old var-buf interface. var consumed: int @@ -796,3 +799,71 @@ suite "Frontend encoding - edge cases": let msg = encodeCopyFail("") check msg[0] == byte('f') check decodeInt32(msg, 1) == int32(msg.len - 1) + +suite "nextMessage recvBufStart update": + proc buildMsg(msgType: char, body: seq[byte]): seq[byte] = + result = @[byte(msgType)] + result.addInt32(int32(4 + body.len)) + result.add(body) + + proc buildDataRowMsg(values: openArray[string]): seq[byte] = + var body: seq[byte] = @[] + body.addInt16(int16(values.len)) + for v in values: + body.addInt32(int32(v.len)) + for c in v: + body.add(byte(c)) + buildMsg('D', body) + + proc mockConn(): PgConnection = + PgConnection( + recvBuf: @[], + recvBufStart: 0, + state: csReady, + txStatus: tsIdle, + serverParams: initTable[string, string](), + createdAt: Moment.now(), + ) + + test "recvBufStart advances past DataRows on exception": + var conn = mockConn() + let validRow = buildDataRowMsg(["hello"]) + # Build a malformed DataRow: claims 1 column but body is truncated + var badBody: seq[byte] = @[] + badBody.addInt16(1) + badBody.addInt32(-2) # invalid column length + let badRow = buildMsg('D', badBody) + + conn.recvBuf = validRow & badRow + conn.recvBufStart = 0 + + var rd = newRowData(1) + var count: int32 = 0 + expect ProtocolError: + discard conn.nextMessage(rd, addr count) + + # recvBufStart should have advanced past the valid DataRow + check conn.recvBufStart == validRow.len + + test "recvBufStart advances past multiple DataRows before non-DataRow": + var conn = mockConn() + let row1 = buildDataRowMsg(["a"]) + let row2 = buildDataRowMsg(["bb"]) + # CommandComplete: 'C' with tag "SELECT 2\0" + var ccBody: seq[byte] = @[] + for c in "SELECT 2": + ccBody.add(byte(c)) + ccBody.add(0'u8) + let cc = buildMsg('C', ccBody) + + conn.recvBuf = row1 & row2 & cc + conn.recvBufStart = 0 + + var rd = newRowData(1) + var count: int32 = 0 + let opt = conn.nextMessage(rd, addr count) + + check opt.isSome + check opt.get.kind == bmkCommandComplete + check count == 2 + check conn.recvBufStart == row1.len + row2.len + cc.len