Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion async_postgres/pg_client.nim
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,9 @@ template queryEachRecvLoop(
conn.recvBuf.toOpenArray(pos, conn.recvBuf.len - 1), consumed, rd
)
if res.state == psIncomplete:
conn.recvBufStart = pos
break # need more data
pos += consumed
conn.recvBufStart = pos
if res.state == psDataRow:
# DataRow was parsed into rd — invoke callback, then reset for next row
if callbackError == nil:
Expand Down
3 changes: 1 addition & 2 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -651,9 +651,9 @@ proc nextMessage*(
conn.recvBuf.toOpenArray(pos, conn.recvBuf.len - 1), consumed, rowData
)
if res.state == psIncomplete:
conn.recvBufStart = pos
return none(BackendMessage)
pos += consumed
conn.recvBufStart = pos
if res.state == psDataRow:
# DataRow already parsed in-place into rowData; just count it
if rowCount != nil:
Expand All @@ -668,7 +668,6 @@ proc nextMessage*(
if res.message.kind == bmkDataRow and rowCount != nil:
rowCount[] += 1
continue
conn.recvBufStart = pos
return some(res.message)

proc recvMessage*(
Expand Down
73 changes: 72 additions & 1 deletion tests/test_protocol.nim
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import std/[unittest, options, strutils]
import std/[unittest, options, strutils, tables, importutils]

import ../async_postgres/async_backend
import ../async_postgres/[pg_protocol, pg_connection]

privateAccess(PgConnection)

proc parseBackendMessage(buf: var seq[byte]): ParseResult =
## Test-only wrapper that preserves the old var-buf interface.
var consumed: int
Expand Down Expand Up @@ -796,3 +799,71 @@ suite "Frontend encoding - edge cases":
let msg = encodeCopyFail("")
check msg[0] == byte('f')
check decodeInt32(msg, 1) == int32(msg.len - 1)

suite "nextMessage recvBufStart update":
proc buildMsg(msgType: char, body: seq[byte]): seq[byte] =
result = @[byte(msgType)]
result.addInt32(int32(4 + body.len))
result.add(body)

proc buildDataRowMsg(values: openArray[string]): seq[byte] =
var body: seq[byte] = @[]
body.addInt16(int16(values.len))
for v in values:
body.addInt32(int32(v.len))
for c in v:
body.add(byte(c))
buildMsg('D', body)

proc mockConn(): PgConnection =
PgConnection(
recvBuf: @[],
recvBufStart: 0,
state: csReady,
txStatus: tsIdle,
serverParams: initTable[string, string](),
createdAt: Moment.now(),
)

test "recvBufStart advances past DataRows on exception":
var conn = mockConn()
let validRow = buildDataRowMsg(["hello"])
# Build a malformed DataRow: claims 1 column but body is truncated
var badBody: seq[byte] = @[]
badBody.addInt16(1)
badBody.addInt32(-2) # invalid column length
let badRow = buildMsg('D', badBody)

conn.recvBuf = validRow & badRow
conn.recvBufStart = 0

var rd = newRowData(1)
var count: int32 = 0
expect ProtocolError:
discard conn.nextMessage(rd, addr count)

# recvBufStart should have advanced past the valid DataRow
check conn.recvBufStart == validRow.len

test "recvBufStart advances past multiple DataRows before non-DataRow":
var conn = mockConn()
let row1 = buildDataRowMsg(["a"])
let row2 = buildDataRowMsg(["bb"])
# CommandComplete: 'C' with tag "SELECT 2\0"
var ccBody: seq[byte] = @[]
for c in "SELECT 2":
ccBody.add(byte(c))
ccBody.add(0'u8)
let cc = buildMsg('C', ccBody)

conn.recvBuf = row1 & row2 & cc
conn.recvBufStart = 0

var rd = newRowData(1)
var count: int32 = 0
let opt = conn.nextMessage(rd, addr count)

check opt.isSome
check opt.get.kind == bmkCommandComplete
check count == 2
check conn.recvBufStart == row1.len + row2.len + cc.len
Loading