Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions async_postgres/pg_protocol.nim
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,8 @@ proc parseDataRow(body: openArray[byte]): BackendMessage =
raise newException(ProtocolError, "DataRow message too short")
result = BackendMessage(kind: bmkDataRow)
let numCols = decodeInt16(body, 0)
if numCols < 0:
raise newException(ProtocolError, "DataRow: invalid column count " & $numCols)
result.columns = newSeq[Option[seq[byte]]](numCols)
var offset = 2
for i in 0 ..< numCols:
Expand Down Expand Up @@ -819,6 +821,9 @@ proc parseRowDescription(body: openArray[byte]): BackendMessage =
raise newException(ProtocolError, "RowDescription message too short")
result = BackendMessage(kind: bmkRowDescription)
let numFields = decodeInt16(body, 0)
if numFields < 0:
raise
newException(ProtocolError, "RowDescription: invalid field count " & $numFields)
result.fields = newSeq[FieldDescription](numFields)
var offset = 2
for i in 0 ..< numFields:
Expand Down Expand Up @@ -856,6 +861,10 @@ proc parseParameterDescription(body: openArray[byte]): BackendMessage =
raise newException(ProtocolError, "ParameterDescription too short")
result = BackendMessage(kind: bmkParameterDescription)
let numParams = decodeInt16(body, 0)
if numParams < 0:
raise newException(
ProtocolError, "ParameterDescription: invalid param count " & $numParams
)
result.paramTypeOids = newSeq[int32](numParams)
var offset = 2
for i in 0 ..< numParams:
Expand All @@ -873,6 +882,8 @@ proc parseCopyResponse(body: openArray[byte], isIn: bool): BackendMessage =
result = BackendMessage(kind: bmkCopyOutResponse)
result.copyFormat = if body[0] == 0: cfText else: cfBinary
let numCols = decodeInt16(body, 1)
if numCols < 0:
raise newException(ProtocolError, "CopyResponse: invalid column count " & $numCols)
result.copyColumnFormats = newSeq[int16](numCols)
var offset = 3
for i in 0 ..< numCols:
Expand All @@ -887,6 +898,9 @@ proc parseCopyBothResponse(body: openArray[byte]): BackendMessage =
result = BackendMessage(kind: bmkCopyBothResponse)
result.copyFormat = if body[0] == 0: cfText else: cfBinary
let numCols = decodeInt16(body, 1)
if numCols < 0:
raise
newException(ProtocolError, "CopyBothResponse: invalid column count " & $numCols)
result.copyColumnFormats = newSeq[int16](numCols)
var offset = 3
for i in 0 ..< numCols:
Expand Down Expand Up @@ -1010,6 +1024,8 @@ proc parseDataRowInto*(body: openArray[byte], rd: RowData) =
if body.len < 2:
raise newException(ProtocolError, "DataRow message too short")
let numCols = decodeInt16(body, 0)
if numCols < 0:
raise newException(ProtocolError, "DataRow: invalid column count " & $numCols)
# Pre-extend cellIndex for this row
let cellBase = rd.cellIndex.len
rd.cellIndex.setLen(cellBase + int(numCols) * 2)
Expand Down
36 changes: 36 additions & 0 deletions async_postgres/pg_types/decoding.nim
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ proc decodeNumericBinary*(data: openArray[byte]): PgNumeric =
PgNumeric(weight: weight, sign: sign, dscale: dscale, digits: digits)

proc decodeBinaryTimestamp*(data: openArray[byte]): DateTime =
if data.len < 8:
raise newException(PgTypeError, "Binary timestamp data too short: " & $data.len)
let pgUs = fromBE64(data)
let unixUs = pgUs + pgEpochUnix * 1_000_000
var unixSec = unixUs div 1_000_000
Expand All @@ -84,12 +86,21 @@ proc decodeBinaryTimestamp*(data: openArray[byte]): DateTime =
initTime(unixSec, int(fracUs * 1000)).utc()

proc decodeBinaryDate*(data: openArray[byte]): DateTime =
if data.len < 4:
raise newException(PgTypeError, "Binary date data too short: " & $data.len)
let pgDays = fromBE32(data)
let unixSec = (int64(pgDays) + int64(pgEpochDaysOffset)) * 86400
initTime(unixSec, 0).utc()

const pgTimeMaxUs = 86_400_000_000'i64
## PostgreSQL time-of-day is microseconds since midnight in [0, 86_400_000_000).

proc decodeBinaryTime*(data: openArray[byte]): PgTime =
if data.len < 8:
raise newException(PgTypeError, "Binary time data too short: " & $data.len)
let us = fromBE64(data)
if us < 0 or us >= pgTimeMaxUs:
raise newException(PgTypeError, "Binary time: microseconds out of range " & $us)
let hours = int32(us div 3_600_000_000)
let rem1 = us mod 3_600_000_000
let minutes = int32(rem1 div 60_000_000)
Expand All @@ -99,7 +110,11 @@ proc decodeBinaryTime*(data: openArray[byte]): PgTime =
PgTime(hour: hours, minute: minutes, second: seconds, microsecond: microseconds)

proc decodeBinaryTimeTz*(data: openArray[byte]): PgTimeTz =
if data.len < 12:
raise newException(PgTypeError, "Binary timetz data too short: " & $data.len)
let us = fromBE64(data)
if us < 0 or us >= pgTimeMaxUs:
raise newException(PgTypeError, "Binary timetz: microseconds out of range " & $us)
let pgOffset = fromBE32(data.toOpenArray(8, 11))
let hours = int32(us div 3_600_000_000)
let rem1 = us mod 3_600_000_000
Expand All @@ -122,23 +137,31 @@ proc decodeInetBinary*(data: openArray[byte]): tuple[address: IpAddress, mask: u
## 1 byte: is_cidr (0 or 1)
## 1 byte: addrlen (4 or 16)
## N bytes: address
if data.len < 4:
raise newException(PgTypeError, "Binary inet data too short: " & $data.len)
let family = data[0]
let bits = data[1]
# data[2] = is_cidr, ignored for decoding
# data[3] = addrlen
if family == 2:
if data.len < 8:
raise newException(PgTypeError, "Binary inet IPv4 data too short: " & $data.len)
var ip = IpAddress(family: IpAddressFamily.IPv4)
for i in 0 ..< 4:
ip.address_v4[i] = data[4 + i]
(ip, bits)
else:
if data.len < 20:
raise newException(PgTypeError, "Binary inet IPv6 data too short: " & $data.len)
var ip = IpAddress(family: IpAddressFamily.IPv6)
for i in 0 ..< 16:
ip.address_v6[i] = data[4 + i]
(ip, bits)

proc decodePointBinary*(data: openArray[byte], off: int): PgPoint =
## Decode a point from 16 bytes at offset.
if off < 0 or off + 16 > data.len:
raise newException(PgTypeError, "Binary point data truncated at offset " & $off)
let xBits = uint64(
(uint64(data[off]) shl 56) or (uint64(data[off + 1]) shl 48) or
(uint64(data[off + 2]) shl 40) or (uint64(data[off + 3]) shl 32) or
Expand Down Expand Up @@ -175,6 +198,11 @@ proc decodeBinaryArray*(
let dimLen = int(fromBE32(data.toOpenArray(12, 15)))
if dimLen < 0:
raise newException(PgTypeError, "Binary array: invalid dimension length " & $dimLen)
# Each element carries at least a 4-byte length prefix after the 20-byte
# header, so dimLen cannot exceed (data.len - 20) div 4. This guard stops a
# crafted header from triggering a multi-GB allocation on malformed input.
if dimLen > (data.len - 20) div 4:
raise newException(PgTypeError, "Binary array: dimension length exceeds data")
# lower_bound at offset 16, ignored
result.elements = newSeq[tuple[off: int, len: int]](dimLen)
var pos = 20
Expand Down Expand Up @@ -205,6 +233,10 @@ proc decodeBinaryComposite*(
if numFields < 0:
raise
newException(PgTypeError, "Binary composite: invalid field count " & $numFields)
# Each field carries at least an 8-byte header (oid + len) after the 4-byte
# count, so numFields cannot exceed (data.len - 4) div 8.
if numFields > (data.len - 4) div 8:
raise newException(PgTypeError, "Binary composite: field count exceeds data")
result = newSeq[tuple[oid: int32, off: int, len: int]](numFields)
var pos = 4
for i in 0 ..< numFields:
Expand Down Expand Up @@ -479,6 +511,10 @@ proc decodeBinaryTsVector*(data: openArray[byte]): string =
if nlexemes < 0:
raise
newException(PgTypeError, "tsvector binary: invalid lexeme count " & $nlexemes)
# Each lexeme needs at least a null terminator (1 byte) + 2-byte position
# count after the 4-byte count, so nlexemes cannot exceed (data.len - 4) div 3.
if nlexemes > (data.len - 4) div 3:
raise newException(PgTypeError, "tsvector binary: lexeme count exceeds data")
var pos = 4
var parts = newSeq[string](nlexemes)
const weightChars = ['D', 'C', 'B', 'A']
Expand Down
4 changes: 2 additions & 2 deletions tests/all_tests.nim
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{.push warning[UnusedImport]: off.}
import
test_advisory_lock, test_auth, test_dsn, test_e2e, test_keepalive, test_largeobject,
test_pool, test_protocol, test_rowdata, test_sql, test_ssl, test_tracing, test_types,
test_pool_cluster
test_pool, test_protocol, test_protocol_fuzz, test_rowdata, test_sql, test_ssl,
test_tracing, test_types, test_pool_cluster
{.pop.}
Loading