Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions async_postgres/pg_protocol.nim
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ const
701, # float8
718, # circle
1043, # varchar
1560, # bit
1561, # bit[]
1562, # varbit
1563, # varbit[]
3904, # int4range
3905, # int4range[]
3906, # numrange
Expand Down
114 changes: 114 additions & 0 deletions async_postgres/pg_types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ type

PgXml* = distinct string ## PostgreSQL xml type.

PgBit* = object ## PostgreSQL bit / bit varying type.
nbits*: int32 ## number of bits
data*: seq[byte] ## packed bit data (MSB first)

PgPoint* = object ## PostgreSQL point type: (x, y).
x*: float64
y*: float64
Expand Down Expand Up @@ -214,6 +218,11 @@ const

OidXml* = 142'i32

OidBit* = 1560'i32
OidVarbit* = 1562'i32
OidBitArray* = 1561'i32
OidVarbitArray* = 1563'i32

rangeEmpty* = 0x01'u8 ## Range flag: range is empty.
rangeHasLower* = 0x02'u8 ## Range flag: lower bound present.
rangeHasUpper* = 0x04'u8 ## Range flag: upper bound present.
Expand All @@ -239,6 +248,34 @@ proc `==`*(a, b: PgTsQuery): bool {.borrow.}
proc `$`*(v: PgXml): string {.borrow.}
proc `==`*(a, b: PgXml): bool {.borrow.}

proc `$`*(v: PgBit): string =
## Convert PgBit to a bit string like "10110011".
result = newStringOfCap(v.nbits)
for i in 0 ..< v.nbits:
let byteIdx = i div 8
let bitIdx = 7 - (i mod 8)
if (v.data[byteIdx].int shr bitIdx and 1) == 1:
result.add('1')
else:
result.add('0')

proc `==`*(a, b: PgBit): bool =
a.nbits == b.nbits and a.data == b.data

proc parseBitString*(s: string): PgBit =
## Parse a bit string like "10110011" into PgBit.
let nbits = int32(s.len)
let nBytes = (nbits + 7) div 8
var data = newSeq[byte](nBytes)
for i in 0 ..< nbits:
if s[i] == '1':
let byteIdx = i div 8
let bitIdx = 7 - (i mod 8)
data[byteIdx] = data[byteIdx] or byte(1 shl bitIdx)
elif s[i] != '0':
raise newException(PgTypeError, "Invalid bit character: " & $s[i])
PgBit(nbits: nbits, data: data)

proc parsePgNumeric*(s: string): PgNumeric =
## Parse a decimal string (e.g. "123.45", "-0.001", "NaN") into PgNumeric.
if s.len == 0:
Expand Down Expand Up @@ -678,6 +715,9 @@ proc toPgParam*(v: PgTsQuery): PgParam =
proc toPgParam*(v: PgXml): PgParam =
PgParam(oid: OidXml, format: 0, value: some(toBytes(string(v))))

proc toPgParam*(v: PgBit): PgParam =
PgParam(oid: OidVarbit, format: 0, value: some(toBytes($v)))

proc toPgParam*(v: PgPoint): PgParam =
PgParam(oid: OidPoint, format: 0, value: some(toBytes($v)))

Expand Down Expand Up @@ -1039,6 +1079,31 @@ proc toPgBinaryParam*(v: PgXml): PgParam =
## Binary wire format for xml is the text representation itself.
PgParam(oid: OidXml, format: 1, value: some(toBytes(string(v))))

proc toPgBinaryParam*(v: PgBit): PgParam =
## Binary format: 4-byte bit count (big-endian) + packed bit data.
var data = newSeq[byte](4 + v.data.len)
let beNbits = toBE32(v.nbits)
for i in 0 ..< 4:
data[i] = beNbits[i]
for i in 0 ..< v.data.len:
data[4 + i] = v.data[i]
PgParam(oid: OidVarbit, format: 1, value: some(data))

proc toPgBinaryParam*(v: seq[PgBit]): PgParam =
if v.len == 0:
return PgParam(
oid: OidVarbitArray, format: 1, value: some(encodeBinaryArrayEmpty(OidVarbit))
)
var elements = newSeq[seq[byte]](v.len)
for i, x in v:
elements[i] = toPgBinaryParam(x).value.get
PgParam(
oid: OidVarbitArray, format: 1, value: some(encodeBinaryArray(OidVarbit, elements))
)

proc toPgParam*(v: seq[PgBit]): PgParam =
toPgBinaryParam(v)

proc encodePointBinary(p: PgPoint): seq[byte] =
## Encode a point as 16 bytes (two float64 big-endian).
result = newSeq[byte](16)
Expand Down Expand Up @@ -2073,6 +2138,22 @@ proc getMacAddr8*(row: Row, col: int): PgMacAddr8 =
return PgMacAddr8(parts.join(":"))
PgMacAddr8(row.getStr(col))

proc getBit*(row: Row, col: int): PgBit =
## Get a column value as PgBit. Handles both text and binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
raise newException(PgTypeError, "Column " & $col & " is NULL")
if clen < 4:
raise newException(PgTypeError, "Invalid binary bit data: too short")
let nbits = fromBE32(row.data.buf[off .. off + 3])
let dataLen = clen - 4
var data = newSeq[byte](dataLen)
for i in 0 ..< dataLen:
data[i] = row.data.buf[off + 4 + i]
return PgBit(nbits: nbits, data: data)
parseBitString(row.getStr(col))

proc decodeBinaryTsVector(data: openArray[byte]): string =
## Decode PostgreSQL binary tsvector to text representation.
if data.len < 4:
Expand Down Expand Up @@ -2472,6 +2553,7 @@ optAccessor(getMacAddr8, getMacAddr8Opt, PgMacAddr8)
optAccessor(getTsVector, getTsVectorOpt, PgTsVector)
optAccessor(getTsQuery, getTsQueryOpt, PgTsQuery)
optAccessor(getXml, getXmlOpt, PgXml)
optAccessor(getBit, getBitOpt, PgBit)
optAccessor(getHstore, getHstoreOpt, PgHstore)
optAccessor(getPoint, getPointOpt, PgPoint)
optAccessor(getLine, getLineOpt, PgLine)
Expand Down Expand Up @@ -2683,6 +2765,33 @@ proc getStrArray*(row: Row, col: int): seq[string] =
raise newException(PgTypeError, "NULL element in string array")
result.add(e.get)

proc getBitArray*(row: Row, col: int): seq[PgBit] =
## Get a column value as a seq of PgBit. Handles both text and binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
raise newException(PgTypeError, "Column " & $col & " is NULL")
let decoded = decodeBinaryArray(row.data.buf.toOpenArray(off, off + clen - 1))
result = newSeq[PgBit](decoded.elements.len)
for i, e in decoded.elements:
if e.len == -1:
raise newException(PgTypeError, "NULL element in bit array")
if e.len < 4:
raise newException(PgTypeError, "Invalid binary bit element: too short")
let nbits = fromBE32(row.data.buf.toOpenArray(off + e.off, off + e.off + 3))
let dataLen = e.len - 4
var data = newSeq[byte](dataLen)
for j in 0 ..< dataLen:
data[j] = row.data.buf[off + e.off + 4 + j]
result[i] = PgBit(nbits: nbits, data: data)
return
let s = row.getStr(col)
let elems = parseTextArray(s)
for e in elems:
if e.isNone:
raise newException(PgTypeError, "NULL element in bit array")
result.add(parseBitString(e.get))

# Array Opt accessors (text format)

optAccessor(getIntArray, getIntArrayOpt, seq[int32])
Expand All @@ -2692,6 +2801,7 @@ optAccessor(getFloatArray, getFloatArrayOpt, seq[float64])
optAccessor(getFloat32Array, getFloat32ArrayOpt, seq[float32])
optAccessor(getBoolArray, getBoolArrayOpt, seq[bool])
optAccessor(getStrArray, getStrArrayOpt, seq[string])
optAccessor(getBitArray, getBitArrayOpt, seq[PgBit])

# Coerce a binary PgParam to match the server-inferred type from a prepared
# statement. This handles the common case where e.g. int32.toPgParam is
Expand Down Expand Up @@ -4538,6 +4648,7 @@ nameAccessor(getMacAddr8, PgMacAddr8)
nameAccessor(getTsVector, PgTsVector)
nameAccessor(getTsQuery, PgTsQuery)
nameAccessor(getXml, PgXml)
nameAccessor(getBit, PgBit)
nameAccessor(getHstore, PgHstore)
nameAccessor(getPoint, PgPoint)
nameAccessor(getLine, PgLine)
Expand All @@ -4562,6 +4673,7 @@ nameAccessor(getMacAddr8Opt, Option[PgMacAddr8])
nameAccessor(getTsVectorOpt, Option[PgTsVector])
nameAccessor(getTsQueryOpt, Option[PgTsQuery])
nameAccessor(getXmlOpt, Option[PgXml])
nameAccessor(getBitOpt, Option[PgBit])
nameAccessor(getHstoreOpt, Option[PgHstore])
nameAccessor(getPointOpt, Option[PgPoint])
nameAccessor(getLineOpt, Option[PgLine])
Expand All @@ -4577,13 +4689,15 @@ nameAccessor(getFloatArray, seq[float64])
nameAccessor(getFloat32Array, seq[float32])
nameAccessor(getBoolArray, seq[bool])
nameAccessor(getStrArray, seq[string])
nameAccessor(getBitArray, seq[PgBit])
nameAccessor(getIntArrayOpt, Option[seq[int32]])
nameAccessor(getInt16ArrayOpt, Option[seq[int16]])
nameAccessor(getInt64ArrayOpt, Option[seq[int64]])
nameAccessor(getFloatArrayOpt, Option[seq[float64]])
nameAccessor(getFloat32ArrayOpt, Option[seq[float32]])
nameAccessor(getBoolArrayOpt, Option[seq[bool]])
nameAccessor(getStrArrayOpt, Option[seq[string]])
nameAccessor(getBitArrayOpt, Option[seq[PgBit]])
nameAccessor(getInt4Range, PgRange[int32])
nameAccessor(getInt8Range, PgRange[int64])
nameAccessor(getNumRange, PgRange[PgNumeric])
Expand Down
154 changes: 154 additions & 0 deletions tests/test_types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -4429,3 +4429,157 @@ suite "hstore":
let fields = @[mkField(OidText, 0'i16)]
let row = mkRow(@[none(seq[byte])], fields)
check row.getHstoreOpt(0).isNone

suite "PgBit":
test "OID constants":
check OidBit == 1560'i32
check OidVarbit == 1562'i32
check OidBitArray == 1561'i32
check OidVarbitArray == 1563'i32

test "parseBitString and $ roundtrip":
let b = parseBitString("10110011")
check b.nbits == 8
check b.data == @[0b10110011'u8]
check $b == "10110011"

test "parseBitString non-byte-aligned":
let b = parseBitString("101")
check b.nbits == 3
check $b == "101"
# data should be 0b10100000
check b.data == @[0b10100000'u8]

test "parseBitString empty":
let b = parseBitString("")
check b.nbits == 0
check b.data.len == 0
check $b == ""

test "== operator":
check parseBitString("1010") == parseBitString("1010")
check parseBitString("1010") != parseBitString("1011")
check parseBitString("10") != parseBitString("1000")

test "toPgParam PgBit":
let b = parseBitString("10110011")
let p = toPgParam(b)
check p.oid == OidVarbit
check p.format == 0
check toString(p.value.get) == "10110011"

test "toPgBinaryParam PgBit":
let b = parseBitString("10110011")
let p = toPgBinaryParam(b)
check p.oid == OidVarbit
check p.format == 1
let data = p.value.get
check data.len == 5 # 4 bytes for nbits + 1 byte for data
# nbits = 8 in big-endian
check data[0 .. 3] == @[0'u8, 0, 0, 8]
check data[4] == 0b10110011'u8

test "toPgBinaryParam PgBit non-byte-aligned":
let b = parseBitString("101")
let p = toPgBinaryParam(b)
let data = p.value.get
check data.len == 5
# nbits = 3 in big-endian
check data[0 .. 3] == @[0'u8, 0, 0, 3]
check data[4] == 0b10100000'u8

test "getBit text format":
let data = toBytes("10110011")
let fields = @[mkField(OidVarbit, 0'i16)]
let row = mkRow(@[some(data)], fields)
let b = row.getBit(0)
check b.nbits == 8
check $b == "10110011"

test "getBit binary format":
# Binary: 4 bytes nbits (8) + 1 byte data
var data: seq[byte] = @[]
data.add(@(toBE32(8'i32)))
data.add(0b10110011'u8)
let fields = @[mkField(OidVarbit, 1'i16)]
let row = mkRow(@[some(data)], fields)
let b = row.getBit(0)
check b.nbits == 8
check $b == "10110011"

test "getBit binary format non-byte-aligned":
var data: seq[byte] = @[]
data.add(@(toBE32(3'i32)))
data.add(0b10100000'u8)
let fields = @[mkField(OidBit, 1'i16)]
let row = mkRow(@[some(data)], fields)
let b = row.getBit(0)
check b.nbits == 3
check $b == "101"

test "getBitOpt with value":
let data = toBytes("10110011")
let fields = @[mkField(OidVarbit, 0'i16)]
let row = mkRow(@[some(data)], fields)
let b = row.getBitOpt(0)
check b.isSome
check $b.get == "10110011"

test "getBitOpt with NULL":
let fields = @[mkField(OidVarbit, 0'i16)]
let row = mkRow(@[none(seq[byte])], fields)
check row.getBitOpt(0).isNone

test "toPgParam seq[PgBit]":
let v = @[parseBitString("1010"), parseBitString("110")]
let p = toPgParam(v)
check p.oid == OidVarbitArray
check p.format == 1

test "toPgParam seq[PgBit] empty":
let v: seq[PgBit] = @[]
let p = toPgParam(v)
check p.oid == OidVarbitArray
check p.format == 1

test "toPgBinaryParam seq[PgBit]":
let v = @[parseBitString("1010"), parseBitString("110")]
let p = toPgBinaryParam(v)
check p.oid == OidVarbitArray
check p.format == 1

test "getBitArray text format":
let data = toBytes("{1010,110,00001111}")
let fields = @[mkField(OidVarbitArray, 0'i16)]
let row = mkRow(@[some(data)], fields)
let arr = row.getBitArray(0)
check arr.len == 3
check $arr[0] == "1010"
check $arr[1] == "110"
check $arr[2] == "00001111"

test "getBitArray binary format":
# Encode two elements using toPgBinaryParam
let v = @[parseBitString("1010"), parseBitString("11001100")]
let p = toPgBinaryParam(v)
let data = p.value.get
let fields = @[mkField(OidVarbitArray, 1'i16)]
let row = mkRow(@[some(data)], fields)
let arr = row.getBitArray(0)
check arr.len == 2
check $arr[0] == "1010"
check $arr[1] == "11001100"

test "getBitArrayOpt with value":
let data = toBytes("{101}")
let fields = @[mkField(OidVarbitArray, 0'i16)]
let row = mkRow(@[some(data)], fields)
let arr = row.getBitArrayOpt(0)
check arr.isSome
check arr.get.len == 1
check $arr.get[0] == "101"

test "getBitArrayOpt with NULL":
let fields = @[mkField(OidVarbitArray, 0'i16)]
let row = mkRow(@[none(seq[byte])], fields)
check row.getBitArrayOpt(0).isNone
Loading