From e3ecc7c441537703b412dc0bead04ac1c89f40c7 Mon Sep 17 00:00:00 2001 From: fox0430 Date: Tue, 21 Apr 2026 19:13:57 +0900 Subject: [PATCH] Add protocol decoder fuzz tests and bounds guards --- async_postgres/pg_protocol.nim | 16 + async_postgres/pg_types/decoding.nim | 36 ++ tests/all_tests.nim | 4 +- tests/test_protocol_fuzz.nim | 594 +++++++++++++++++++++++++++ 4 files changed, 648 insertions(+), 2 deletions(-) create mode 100644 tests/test_protocol_fuzz.nim diff --git a/async_postgres/pg_protocol.nim b/async_postgres/pg_protocol.nim index 6dab3ba..a8a0612 100644 --- a/async_postgres/pg_protocol.nim +++ b/async_postgres/pg_protocol.nim @@ -757,6 +757,8 @@ proc parseDataRow(body: openArray[byte]): BackendMessage = raise newException(ProtocolError, "DataRow message too short") result = BackendMessage(kind: bmkDataRow) let numCols = decodeInt16(body, 0) + if numCols < 0: + raise newException(ProtocolError, "DataRow: invalid column count " & $numCols) result.columns = newSeq[Option[seq[byte]]](numCols) var offset = 2 for i in 0 ..< numCols: @@ -819,6 +821,9 @@ proc parseRowDescription(body: openArray[byte]): BackendMessage = raise newException(ProtocolError, "RowDescription message too short") result = BackendMessage(kind: bmkRowDescription) let numFields = decodeInt16(body, 0) + if numFields < 0: + raise + newException(ProtocolError, "RowDescription: invalid field count " & $numFields) result.fields = newSeq[FieldDescription](numFields) var offset = 2 for i in 0 ..< numFields: @@ -856,6 +861,10 @@ proc parseParameterDescription(body: openArray[byte]): BackendMessage = raise newException(ProtocolError, "ParameterDescription too short") result = BackendMessage(kind: bmkParameterDescription) let numParams = decodeInt16(body, 0) + if numParams < 0: + raise newException( + ProtocolError, "ParameterDescription: invalid param count " & $numParams + ) result.paramTypeOids = newSeq[int32](numParams) var offset = 2 for i in 0 ..< numParams: @@ -873,6 +882,8 @@ proc parseCopyResponse(body: openArray[byte], isIn: bool): BackendMessage = result = BackendMessage(kind: bmkCopyOutResponse) result.copyFormat = if body[0] == 0: cfText else: cfBinary let numCols = decodeInt16(body, 1) + if numCols < 0: + raise newException(ProtocolError, "CopyResponse: invalid column count " & $numCols) result.copyColumnFormats = newSeq[int16](numCols) var offset = 3 for i in 0 ..< numCols: @@ -887,6 +898,9 @@ proc parseCopyBothResponse(body: openArray[byte]): BackendMessage = result = BackendMessage(kind: bmkCopyBothResponse) result.copyFormat = if body[0] == 0: cfText else: cfBinary let numCols = decodeInt16(body, 1) + if numCols < 0: + raise + newException(ProtocolError, "CopyBothResponse: invalid column count " & $numCols) result.copyColumnFormats = newSeq[int16](numCols) var offset = 3 for i in 0 ..< numCols: @@ -1010,6 +1024,8 @@ proc parseDataRowInto*(body: openArray[byte], rd: RowData) = if body.len < 2: raise newException(ProtocolError, "DataRow message too short") let numCols = decodeInt16(body, 0) + if numCols < 0: + raise newException(ProtocolError, "DataRow: invalid column count " & $numCols) # Pre-extend cellIndex for this row let cellBase = rd.cellIndex.len rd.cellIndex.setLen(cellBase + int(numCols) * 2) diff --git a/async_postgres/pg_types/decoding.nim b/async_postgres/pg_types/decoding.nim index ec067fd..fafb0f4 100644 --- a/async_postgres/pg_types/decoding.nim +++ b/async_postgres/pg_types/decoding.nim @@ -74,6 +74,8 @@ proc decodeNumericBinary*(data: openArray[byte]): PgNumeric = PgNumeric(weight: weight, sign: sign, dscale: dscale, digits: digits) proc decodeBinaryTimestamp*(data: openArray[byte]): DateTime = + if data.len < 8: + raise newException(PgTypeError, "Binary timestamp data too short: " & $data.len) let pgUs = fromBE64(data) let unixUs = pgUs + pgEpochUnix * 1_000_000 var unixSec = unixUs div 1_000_000 @@ -84,12 +86,21 @@ proc decodeBinaryTimestamp*(data: openArray[byte]): DateTime = initTime(unixSec, int(fracUs * 1000)).utc() proc decodeBinaryDate*(data: openArray[byte]): DateTime = + if data.len < 4: + raise newException(PgTypeError, "Binary date data too short: " & $data.len) let pgDays = fromBE32(data) let unixSec = (int64(pgDays) + int64(pgEpochDaysOffset)) * 86400 initTime(unixSec, 0).utc() +const pgTimeMaxUs = 86_400_000_000'i64 + ## PostgreSQL time-of-day is microseconds since midnight in [0, 86_400_000_000). + proc decodeBinaryTime*(data: openArray[byte]): PgTime = + if data.len < 8: + raise newException(PgTypeError, "Binary time data too short: " & $data.len) let us = fromBE64(data) + if us < 0 or us >= pgTimeMaxUs: + raise newException(PgTypeError, "Binary time: microseconds out of range " & $us) let hours = int32(us div 3_600_000_000) let rem1 = us mod 3_600_000_000 let minutes = int32(rem1 div 60_000_000) @@ -99,7 +110,11 @@ proc decodeBinaryTime*(data: openArray[byte]): PgTime = PgTime(hour: hours, minute: minutes, second: seconds, microsecond: microseconds) proc decodeBinaryTimeTz*(data: openArray[byte]): PgTimeTz = + if data.len < 12: + raise newException(PgTypeError, "Binary timetz data too short: " & $data.len) let us = fromBE64(data) + if us < 0 or us >= pgTimeMaxUs: + raise newException(PgTypeError, "Binary timetz: microseconds out of range " & $us) let pgOffset = fromBE32(data.toOpenArray(8, 11)) let hours = int32(us div 3_600_000_000) let rem1 = us mod 3_600_000_000 @@ -122,16 +137,22 @@ proc decodeInetBinary*(data: openArray[byte]): tuple[address: IpAddress, mask: u ## 1 byte: is_cidr (0 or 1) ## 1 byte: addrlen (4 or 16) ## N bytes: address + if data.len < 4: + raise newException(PgTypeError, "Binary inet data too short: " & $data.len) let family = data[0] let bits = data[1] # data[2] = is_cidr, ignored for decoding # data[3] = addrlen if family == 2: + if data.len < 8: + raise newException(PgTypeError, "Binary inet IPv4 data too short: " & $data.len) var ip = IpAddress(family: IpAddressFamily.IPv4) for i in 0 ..< 4: ip.address_v4[i] = data[4 + i] (ip, bits) else: + if data.len < 20: + raise newException(PgTypeError, "Binary inet IPv6 data too short: " & $data.len) var ip = IpAddress(family: IpAddressFamily.IPv6) for i in 0 ..< 16: ip.address_v6[i] = data[4 + i] @@ -139,6 +160,8 @@ proc decodeInetBinary*(data: openArray[byte]): tuple[address: IpAddress, mask: u proc decodePointBinary*(data: openArray[byte], off: int): PgPoint = ## Decode a point from 16 bytes at offset. + if off < 0 or off + 16 > data.len: + raise newException(PgTypeError, "Binary point data truncated at offset " & $off) let xBits = uint64( (uint64(data[off]) shl 56) or (uint64(data[off + 1]) shl 48) or (uint64(data[off + 2]) shl 40) or (uint64(data[off + 3]) shl 32) or @@ -175,6 +198,11 @@ proc decodeBinaryArray*( let dimLen = int(fromBE32(data.toOpenArray(12, 15))) if dimLen < 0: raise newException(PgTypeError, "Binary array: invalid dimension length " & $dimLen) + # Each element carries at least a 4-byte length prefix after the 20-byte + # header, so dimLen cannot exceed (data.len - 20) div 4. This guard stops a + # crafted header from triggering a multi-GB allocation on malformed input. + if dimLen > (data.len - 20) div 4: + raise newException(PgTypeError, "Binary array: dimension length exceeds data") # lower_bound at offset 16, ignored result.elements = newSeq[tuple[off: int, len: int]](dimLen) var pos = 20 @@ -205,6 +233,10 @@ proc decodeBinaryComposite*( if numFields < 0: raise newException(PgTypeError, "Binary composite: invalid field count " & $numFields) + # Each field carries at least an 8-byte header (oid + len) after the 4-byte + # count, so numFields cannot exceed (data.len - 4) div 8. + if numFields > (data.len - 4) div 8: + raise newException(PgTypeError, "Binary composite: field count exceeds data") result = newSeq[tuple[oid: int32, off: int, len: int]](numFields) var pos = 4 for i in 0 ..< numFields: @@ -479,6 +511,10 @@ proc decodeBinaryTsVector*(data: openArray[byte]): string = if nlexemes < 0: raise newException(PgTypeError, "tsvector binary: invalid lexeme count " & $nlexemes) + # Each lexeme needs at least a null terminator (1 byte) + 2-byte position + # count after the 4-byte count, so nlexemes cannot exceed (data.len - 4) div 3. + if nlexemes > (data.len - 4) div 3: + raise newException(PgTypeError, "tsvector binary: lexeme count exceeds data") var pos = 4 var parts = newSeq[string](nlexemes) const weightChars = ['D', 'C', 'B', 'A'] diff --git a/tests/all_tests.nim b/tests/all_tests.nim index 578e0c2..eb6226f 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -1,6 +1,6 @@ {.push warning[UnusedImport]: off.} import test_advisory_lock, test_auth, test_dsn, test_e2e, test_keepalive, test_largeobject, - test_pool, test_protocol, test_rowdata, test_sql, test_ssl, test_tracing, test_types, - test_pool_cluster + test_pool, test_protocol, test_protocol_fuzz, test_rowdata, test_sql, test_ssl, + test_tracing, test_types, test_pool_cluster {.pop.} diff --git a/tests/test_protocol_fuzz.nim b/tests/test_protocol_fuzz.nim new file mode 100644 index 0000000..75f60ad --- /dev/null +++ b/tests/test_protocol_fuzz.nim @@ -0,0 +1,594 @@ +## Protocol decoder fuzz / negative-path tests. +## +## Goal: prove that `parseBackendMessage` and the binary type decoders never +## crash (no `Defect`) on arbitrary byte input. Acceptable outcomes for any +## input are: +## * `psIncomplete` (need more bytes) +## * `psComplete` / `psDataRow` (valid message recognised) +## * raising `ProtocolError` (wire protocol violation) +## * raising `PgTypeError` (unparseable type payload) +## Anything else — in particular `IndexDefect`, `RangeDefect`, `DivByZeroDefect` +## — is a bug in the decoder. + +import std/[unittest, random] + +import ../async_postgres/pg_protocol +import ../async_postgres/pg_types/[core, decoding] + +# Helpers + +proc wrap(msgType: char, body: openArray[byte]): seq[byte] = + ## Wrap a body with a backend-message header. `msgLen` = 4 + body.len. + result = newSeq[byte](5 + body.len) + result[0] = byte(msgType) + let msgLen = int32(4 + body.len) + result[1] = byte((msgLen shr 24) and 0xFF) + result[2] = byte((msgLen shr 16) and 0xFF) + result[3] = byte((msgLen shr 8) and 0xFF) + result[4] = byte(msgLen and 0xFF) + for i, b in body: + result[5 + i] = b + +proc header(msgType: char, claimedLen: int32): seq[byte] = + ## Header with an explicit `msgLen` — useful for testing length violations. + result = newSeq[byte](5) + result[0] = byte(msgType) + result[1] = byte((claimedLen shr 24) and 0xFF) + result[2] = byte((claimedLen shr 16) and 0xFF) + result[3] = byte((claimedLen shr 8) and 0xFF) + result[4] = byte(claimedLen and 0xFF) + +template expectParseError(body: untyped) = + ## Assert that `body` raises `ProtocolError` (not `Defect`). + var raised = false + try: + body + except ProtocolError: + raised = true + check raised + +template expectTypeError(body: untyped) = + var raised = false + try: + body + except PgTypeError: + raised = true + check raised + +proc tryParse(buf: openArray[byte]): tuple[state: ParseState, consumed: int] = + ## Run `parseBackendMessage` and normalise the outcome. Raises `ProtocolError` + ## on wire-protocol violations; all other exceptions escape (test failure). + var consumed = 0 + let res = parseBackendMessage(buf, consumed) + (res.state, consumed) + +# Hand-written negative paths: framing + +suite "parseBackendMessage: framing": + test "empty buffer is incomplete": + var consumed = 0 + let res = parseBackendMessage(@[], consumed) + check res.state == psIncomplete + check consumed == 0 + + test "4-byte buffer is incomplete": + var consumed = 0 + let res = parseBackendMessage(@[byte('C'), 0, 0, 0], consumed) + check res.state == psIncomplete + check consumed == 0 + + test "msgLen == 0 raises ProtocolError": + expectParseError: + discard tryParse(header('C', 0)) + + test "msgLen == 3 (below minimum 4) raises ProtocolError": + expectParseError: + discard tryParse(header('C', 3)) + + test "negative msgLen raises ProtocolError": + expectParseError: + discard tryParse(header('C', -1)) + expectParseError: + discard tryParse(header('C', int32.low)) + + test "huge msgLen with short buffer is incomplete": + # 2 GiB - 1 but we only supplied 5 bytes. + let (state, consumed) = tryParse(header('C', int32.high)) + check state == psIncomplete + check consumed == 0 + + test "unknown message type 'X' raises ProtocolError": + expectParseError: + discard tryParse(wrap('X', @[])) + + test "bytes unassigned on the backend side raise ProtocolError": + # 'P', 'B', 'F', 'Q' are frontend message types; the rest are truly + # unassigned. All should be rejected by the backend parser. + for ch in ['x', 'Y', 'Q', 'P', 'B', 'F', '!', '@']: + expectParseError: + discard tryParse(wrap(ch, @[])) + +# Hand-written negative paths: per-message-kind body truncation + +suite "parseBackendMessage: per-kind malformed bodies": + test "Authentication (R) with body too short": + # Needs 4 bytes for the auth type tag. + expectParseError: + discard tryParse(wrap('R', @[byte 0])) + expectParseError: + discard tryParse(wrap('R', @[byte 0, 0, 0])) + + test "Authentication MD5 without 4-byte salt": + # authType=5 present, but salt missing. + expectParseError: + discard tryParse(wrap('R', @[byte 0, 0, 0, 5])) + expectParseError: + discard tryParse(wrap('R', @[byte 0, 0, 0, 5, 0, 0])) + + test "Authentication unknown authType": + expectParseError: + discard tryParse(wrap('R', @[byte 0, 0, 0, 99])) + + test "BackendKeyData (K) body shorter than 8": + expectParseError: + discard tryParse(wrap('K', @[])) + expectParseError: + discard tryParse(wrap('K', newSeq[byte](7))) + + test "DataRow (D) body too short for column count": + expectParseError: + discard tryParse(wrap('D', @[])) + expectParseError: + discard tryParse(wrap('D', @[byte 0])) + + test "DataRow with more columns than data": + # numCols=2 but only 4 bytes (enough for one colLen header, not two). + let body = @[byte 0, 2, 0, 0, 0, 0] + expectParseError: + discard tryParse(wrap('D', body)) + + test "DataRow with invalid (< -1) column length": + # numCols=1, colLen=-2. + let body = @[byte 0, 1, 0xFF'u8, 0xFF'u8, 0xFF'u8, 0xFE'u8] + expectParseError: + discard tryParse(wrap('D', body)) + + test "DataRow with colLen exceeding body": + # numCols=1, colLen=100, no actual bytes. + let body = @[byte 0, 1, 0'u8, 0'u8, 0'u8, 100'u8] + expectParseError: + discard tryParse(wrap('D', body)) + + test "DataRow with negative column count raises ProtocolError": + let body = @[byte 0xFF, 0xFF] + expectParseError: + discard tryParse(wrap('D', body)) + + test "RowDescription (T) shorter than 2 bytes": + expectParseError: + discard tryParse(wrap('T', @[])) + expectParseError: + discard tryParse(wrap('T', @[byte 0])) + + test "RowDescription claims a field without enough metadata": + # numFields=1, then a CString "a" followed by only a few bytes (need 18). + let body = @[byte 0, 1, byte('a'), 0'u8, 0, 0, 0, 0] + expectParseError: + discard tryParse(wrap('T', body)) + + test "RowDescription field name missing null terminator": + # numFields=1, "abcd" with no trailing 0. + let body = @[byte 0, 1, byte('a'), byte('b'), byte('c'), byte('d')] + expectParseError: + discard tryParse(wrap('T', body)) + + test "CommandComplete (C) without null terminator": + expectParseError: + discard tryParse(wrap('C', @[byte('S'), byte('E'), byte('L')])) + + test "ErrorResponse (E) / NoticeResponse (N) with non-terminated field": + for kind in ['E', 'N']: + # fieldType byte present, value CString missing null. + let body = @[byte('M'), byte('x'), byte('y')] + expectParseError: + discard tryParse(wrap(kind, body)) + + test "ParameterStatus (S) missing value terminator": + # "key\0" but value has no null. + let body = @[byte('k'), byte('e'), byte('y'), 0'u8, byte('v'), byte('a')] + expectParseError: + discard tryParse(wrap('S', body)) + + test "ReadyForQuery (Z) empty or bad status": + expectParseError: + discard tryParse(wrap('Z', @[])) + expectParseError: + discard tryParse(wrap('Z', @[byte('X')])) + + test "ParameterDescription (t) short header": + expectParseError: + discard tryParse(wrap('t', @[])) + expectParseError: + discard tryParse(wrap('t', @[byte 0])) + + test "ParameterDescription claims params without OID bytes": + # numParams=2 but only 4 bytes of OID data (need 8). + let body = @[byte 0, 2, 0, 0, 0, 0] + expectParseError: + discard tryParse(wrap('t', body)) + + test "CopyInResponse (G) / CopyOutResponse (H) short": + for kind in ['G', 'H']: + expectParseError: + discard tryParse(wrap(kind, @[])) + # Format + numCols header present (3 bytes), but one column format missing. + let body = @[byte 0, 0, 1] + expectParseError: + discard tryParse(wrap(kind, body)) + + test "CopyBothResponse (W) short": + expectParseError: + discard tryParse(wrap('W', @[])) + expectParseError: + discard tryParse(wrap('W', @[byte 0, 0, 1])) + + test "NotificationResponse (A) short and malformed": + expectParseError: + discard tryParse(wrap('A', @[])) + # pid present but channel missing null terminator. + let body = @[byte 0, 0, 0, 1, byte('c'), byte('h')] + expectParseError: + discard tryParse(wrap('A', body)) + + test "Zero-body messages round-trip successfully": + for ch in ['1', '2', '3', 'I', 'n', 's', 'c']: + let (state, consumed) = tryParse(wrap(ch, @[])) + check state == psComplete + check consumed == 5 # 1 type + 4 length + +# Seeded random fuzz on parseBackendMessage + +suite "parseBackendMessage: seeded random fuzz": + const seeds = [1'i64, 42, 0xDEADBEEF, 0xCAFEBABE, 0xA5A5A5A5, 1234567890] + const itersPerSeed = 2000 + + test "random bytes never crash parseBackendMessage": + # The contract: any input produces psComplete/psDataRow/psIncomplete or + # raises ProtocolError. Any other exception escapes and fails the test. + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let n = r.rand(0 .. 512) + var buf = newSeq[byte](n) + for i in 0 ..< n: + buf[i] = byte(r.rand(255)) + try: + var consumed = 0 + let res = parseBackendMessage(buf, consumed) + if res.state == psIncomplete: + check consumed == 0 + except ProtocolError: + discard + + test "random buffers whose first byte picks a valid message type": + # Increase hit rate on per-kind decoders by forcing a legal leading byte. + const types = [ + '1', '2', '3', 'A', 'C', 'D', 'E', 'G', 'H', 'I', 'K', 'N', 'R', 'S', 'T', 'W', + 'Z', 'c', 'd', 'n', 's', 't', + ] + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let n = r.rand(5 .. 512) + var buf = newSeq[byte](n) + buf[0] = byte(types[r.rand(0 ..< types.len)]) + # Random length field, random body. + for i in 1 ..< n: + buf[i] = byte(r.rand(255)) + try: + var consumed = 0 + discard parseBackendMessage(buf, consumed) + except ProtocolError: + discard + + test "random buffer fed into streaming RowData sink never crashes": + # Exercise the `parseDataRowInto` branch with random data. + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed div 4: + let n = r.rand(5 .. 512) + var buf = newSeq[byte](n) + buf[0] = byte('D') + for i in 1 ..< n: + buf[i] = byte(r.rand(255)) + let rd = newRowData(1'i16) + try: + var consumed = 0 + discard parseBackendMessage(buf, consumed, rd) + except ProtocolError: + discard + +# Hand-written negative paths: binary type decoders + +suite "Binary type decoders: malformed input": + test "decodeBinaryTimestamp on short input": + expectTypeError: + discard decodeBinaryTimestamp(@[]) + expectTypeError: + discard decodeBinaryTimestamp(newSeq[byte](7)) + + test "decodeBinaryDate on short input": + expectTypeError: + discard decodeBinaryDate(@[]) + expectTypeError: + discard decodeBinaryDate(newSeq[byte](3)) + + test "decodeBinaryTime on short input": + expectTypeError: + discard decodeBinaryTime(@[]) + expectTypeError: + discard decodeBinaryTime(newSeq[byte](7)) + + test "decodeBinaryTimeTz on short input": + expectTypeError: + discard decodeBinaryTimeTz(newSeq[byte](11)) + + test "decodeInetBinary on short input": + expectTypeError: + discard decodeInetBinary(@[]) + expectTypeError: + discard decodeInetBinary(newSeq[byte](3)) + # family=2 (IPv4) needs 8 bytes total. + expectTypeError: + discard decodeInetBinary(@[byte 2, 32, 0, 4]) + # family=3 (IPv6) needs 20 bytes total. + expectTypeError: + discard decodeInetBinary(@[byte 3, 128, 0, 16]) + + test "decodePointBinary bounds": + expectTypeError: + discard decodePointBinary(@[], 0) + expectTypeError: + discard decodePointBinary(newSeq[byte](16), 1) # off + 16 > len + expectTypeError: + discard decodePointBinary(newSeq[byte](16), -1) + + test "decodeNumericBinary short header": + expectTypeError: + discard decodeNumericBinary(@[]) + expectTypeError: + discard decodeNumericBinary(newSeq[byte](7)) + + test "decodeNumericBinary negative ndigits": + let data = @[byte 0xFF, 0xFF, 0, 0, 0, 0, 0, 0] + expectTypeError: + discard decodeNumericBinary(data) + + test "decodeNumericBinary unknown sign": + let data = @[byte 0, 0, 0, 0, 0x12, 0x34, 0, 0] + expectTypeError: + discard decodeNumericBinary(data) + + test "decodeNumericBinary claims more digits than data": + # ndigits=100 — payload not present. + let data = @[byte 0, 100, 0, 0, 0, 0, 0, 0] + expectTypeError: + discard decodeNumericBinary(data) + + test "decodeBinaryArray short": + expectTypeError: + discard decodeBinaryArray(@[]) + expectTypeError: + discard decodeBinaryArray(newSeq[byte](11)) + + test "decodeBinaryArray ndim != 0, 1 is rejected": + # ndim=2 with 20 bytes of header. + var data = newSeq[byte](20) + data[3] = 2 # ndim low byte + expectTypeError: + discard decodeBinaryArray(data) + + test "decodeBinaryArray ndim=1 header truncated": + # ndim=1 but only 12 bytes (need 20). + var data = newSeq[byte](12) + data[3] = 1 + expectTypeError: + discard decodeBinaryArray(data) + + test "decodeBinaryArray bad dimLen": + var data = newSeq[byte](20) + data[3] = 1 # ndim=1 + data[12] = 0xFF # dimLen=-1 (after sign extension) + data[13] = 0xFF + data[14] = 0xFF + data[15] = 0xFF + expectTypeError: + discard decodeBinaryArray(data) + + test "decodeBinaryArray element truncated": + # ndim=1, dimLen=1, then claim eLen=100 with no payload. + var data = newSeq[byte](24) + data[3] = 1 # ndim=1 + data[15] = 1 # dimLen=1 + data[23] = 100 # eLen=100 + expectTypeError: + discard decodeBinaryArray(data) + + test "decodeBinaryComposite short": + expectTypeError: + discard decodeBinaryComposite(@[]) + expectTypeError: + discard decodeBinaryComposite(newSeq[byte](3)) + + test "decodeBinaryComposite negative numFields": + var data = newSeq[byte](4) + data[0] = 0xFF + data[1] = 0xFF + data[2] = 0xFF + data[3] = 0xFF + expectTypeError: + discard decodeBinaryComposite(data) + + test "decodeBinaryComposite field truncated": + # numFields=1 but no field header. + var data = newSeq[byte](4) + data[3] = 1 + expectTypeError: + discard decodeBinaryComposite(data) + + test "decodeHstoreBinary short": + expectTypeError: + discard decodeHstoreBinary(@[]) + expectTypeError: + discard decodeHstoreBinary(newSeq[byte](3)) + + test "decodeHstoreBinary truncated key": + # numPairs=1, keyLen=100 with no data. + var data = newSeq[byte](8) + data[3] = 1 + data[7] = 100 + expectTypeError: + discard decodeHstoreBinary(data) + + test "decodeBinaryTsVector short": + expectTypeError: + discard decodeBinaryTsVector(@[]) + expectTypeError: + discard decodeBinaryTsVector(newSeq[byte](3)) + + test "decodeBinaryTsVector negative nlexemes": + var data = newSeq[byte](4) + data[0] = 0xFF + data[1] = 0xFF + data[2] = 0xFF + data[3] = 0xFF + expectTypeError: + discard decodeBinaryTsVector(data) + + test "decodeBinaryTsVector missing null terminator": + # nlexemes=1, then bytes without a null terminator. + let data = @[byte 0, 0, 0, 1, byte('a'), byte('b')] + expectTypeError: + discard decodeBinaryTsVector(data) + + test "decodeBinaryTsQuery short": + expectTypeError: + discard decodeBinaryTsQuery(@[]) + expectTypeError: + discard decodeBinaryTsQuery(newSeq[byte](3)) + + test "decodeBinaryTsQuery unknown token type": + # ntokens=1, then token type byte = 99 (unknown). + let data = @[byte 0, 0, 0, 1, 99'u8] + expectTypeError: + discard decodeBinaryTsQuery(data) + +# Seeded random fuzz on binary type decoders + +suite "Binary type decoders: seeded random fuzz": + const seeds = [1'i64, 42, 0xBEEF, 0xC0FFEE, 0x13371337] + const itersPerSeed = 500 + + proc randomBuf(r: var Rand, maxLen: int): seq[byte] = + let n = r.rand(0 .. maxLen) + result = newSeq[byte](n) + for i in 0 ..< n: + result[i] = byte(r.rand(255)) + + test "decodeNumericBinary never crashes": + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let buf = randomBuf(r, 64) + try: + discard decodeNumericBinary(buf) + except PgTypeError: + discard + + test "decodeBinaryTimestamp / Date / Time / TimeTz never crash": + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let buf = randomBuf(r, 32) + try: + discard decodeBinaryTimestamp(buf) + except PgTypeError: + discard + try: + discard decodeBinaryDate(buf) + except PgTypeError: + discard + try: + discard decodeBinaryTime(buf) + except PgTypeError: + discard + try: + discard decodeBinaryTimeTz(buf) + except PgTypeError: + discard + + test "decodeInetBinary never crashes": + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let buf = randomBuf(r, 32) + try: + discard decodeInetBinary(buf) + except PgTypeError: + discard + + test "decodePointBinary never crashes": + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let buf = randomBuf(r, 32) + # Try a range of offsets, including negative and past the end. + for off in [-1, 0, 1, 8, 16, buf.len, buf.len + 1]: + try: + discard decodePointBinary(buf, off) + except PgTypeError: + discard + + test "decodeBinaryArray never crashes": + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let buf = randomBuf(r, 128) + try: + discard decodeBinaryArray(buf) + except PgTypeError: + discard + + test "decodeBinaryComposite never crashes": + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let buf = randomBuf(r, 128) + try: + discard decodeBinaryComposite(buf) + except PgTypeError: + discard + + test "decodeHstoreBinary never crashes": + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let buf = randomBuf(r, 128) + try: + discard decodeHstoreBinary(buf) + except PgTypeError: + discard + + test "decodeBinaryTsVector / TsQuery never crash": + for seed in seeds: + var r = initRand(seed) + for _ in 0 ..< itersPerSeed: + let buf = randomBuf(r, 128) + try: + discard decodeBinaryTsVector(buf) + except PgTypeError: + discard + try: + discard decodeBinaryTsQuery(buf) + except PgTypeError: + discard