diff --git a/async_postgres/pg_protocol.nim b/async_postgres/pg_protocol.nim index 9ffbccd..7743a74 100644 --- a/async_postgres/pg_protocol.nim +++ b/async_postgres/pg_protocol.nim @@ -174,6 +174,10 @@ const 701, # float8 718, # circle 1043, # varchar + 1560, # bit + 1561, # bit[] + 1562, # varbit + 1563, # varbit[] 3904, # int4range 3905, # int4range[] 3906, # numrange diff --git a/async_postgres/pg_types.nim b/async_postgres/pg_types.nim index bbb77dd..72dbd78 100644 --- a/async_postgres/pg_types.nim +++ b/async_postgres/pg_types.nim @@ -53,6 +53,10 @@ type PgXml* = distinct string ## PostgreSQL xml type. + PgBit* = object ## PostgreSQL bit / bit varying type. + nbits*: int32 ## number of bits + data*: seq[byte] ## packed bit data (MSB first) + PgPoint* = object ## PostgreSQL point type: (x, y). x*: float64 y*: float64 @@ -214,6 +218,11 @@ const OidXml* = 142'i32 + OidBit* = 1560'i32 + OidVarbit* = 1562'i32 + OidBitArray* = 1561'i32 + OidVarbitArray* = 1563'i32 + rangeEmpty* = 0x01'u8 ## Range flag: range is empty. rangeHasLower* = 0x02'u8 ## Range flag: lower bound present. rangeHasUpper* = 0x04'u8 ## Range flag: upper bound present. @@ -239,6 +248,34 @@ proc `==`*(a, b: PgTsQuery): bool {.borrow.} proc `$`*(v: PgXml): string {.borrow.} proc `==`*(a, b: PgXml): bool {.borrow.} +proc `$`*(v: PgBit): string = + ## Convert PgBit to a bit string like "10110011". + result = newStringOfCap(v.nbits) + for i in 0 ..< v.nbits: + let byteIdx = i div 8 + let bitIdx = 7 - (i mod 8) + if (v.data[byteIdx].int shr bitIdx and 1) == 1: + result.add('1') + else: + result.add('0') + +proc `==`*(a, b: PgBit): bool = + a.nbits == b.nbits and a.data == b.data + +proc parseBitString*(s: string): PgBit = + ## Parse a bit string like "10110011" into PgBit. + let nbits = int32(s.len) + let nBytes = (nbits + 7) div 8 + var data = newSeq[byte](nBytes) + for i in 0 ..< nbits: + if s[i] == '1': + let byteIdx = i div 8 + let bitIdx = 7 - (i mod 8) + data[byteIdx] = data[byteIdx] or byte(1 shl bitIdx) + elif s[i] != '0': + raise newException(PgTypeError, "Invalid bit character: " & $s[i]) + PgBit(nbits: nbits, data: data) + proc parsePgNumeric*(s: string): PgNumeric = ## Parse a decimal string (e.g. "123.45", "-0.001", "NaN") into PgNumeric. if s.len == 0: @@ -678,6 +715,9 @@ proc toPgParam*(v: PgTsQuery): PgParam = proc toPgParam*(v: PgXml): PgParam = PgParam(oid: OidXml, format: 0, value: some(toBytes(string(v)))) +proc toPgParam*(v: PgBit): PgParam = + PgParam(oid: OidVarbit, format: 0, value: some(toBytes($v))) + proc toPgParam*(v: PgPoint): PgParam = PgParam(oid: OidPoint, format: 0, value: some(toBytes($v))) @@ -1039,6 +1079,31 @@ proc toPgBinaryParam*(v: PgXml): PgParam = ## Binary wire format for xml is the text representation itself. PgParam(oid: OidXml, format: 1, value: some(toBytes(string(v)))) +proc toPgBinaryParam*(v: PgBit): PgParam = + ## Binary format: 4-byte bit count (big-endian) + packed bit data. + var data = newSeq[byte](4 + v.data.len) + let beNbits = toBE32(v.nbits) + for i in 0 ..< 4: + data[i] = beNbits[i] + for i in 0 ..< v.data.len: + data[4 + i] = v.data[i] + PgParam(oid: OidVarbit, format: 1, value: some(data)) + +proc toPgBinaryParam*(v: seq[PgBit]): PgParam = + if v.len == 0: + return PgParam( + oid: OidVarbitArray, format: 1, value: some(encodeBinaryArrayEmpty(OidVarbit)) + ) + var elements = newSeq[seq[byte]](v.len) + for i, x in v: + elements[i] = toPgBinaryParam(x).value.get + PgParam( + oid: OidVarbitArray, format: 1, value: some(encodeBinaryArray(OidVarbit, elements)) + ) + +proc toPgParam*(v: seq[PgBit]): PgParam = + toPgBinaryParam(v) + proc encodePointBinary(p: PgPoint): seq[byte] = ## Encode a point as 16 bytes (two float64 big-endian). result = newSeq[byte](16) @@ -2073,6 +2138,22 @@ proc getMacAddr8*(row: Row, col: int): PgMacAddr8 = return PgMacAddr8(parts.join(":")) PgMacAddr8(row.getStr(col)) +proc getBit*(row: Row, col: int): PgBit = + ## Get a column value as PgBit. Handles both text and binary format. + if row.isBinaryCol(col): + let (off, clen) = cellInfo(row, col) + if clen == -1: + raise newException(PgTypeError, "Column " & $col & " is NULL") + if clen < 4: + raise newException(PgTypeError, "Invalid binary bit data: too short") + let nbits = fromBE32(row.data.buf[off .. off + 3]) + let dataLen = clen - 4 + var data = newSeq[byte](dataLen) + for i in 0 ..< dataLen: + data[i] = row.data.buf[off + 4 + i] + return PgBit(nbits: nbits, data: data) + parseBitString(row.getStr(col)) + proc decodeBinaryTsVector(data: openArray[byte]): string = ## Decode PostgreSQL binary tsvector to text representation. if data.len < 4: @@ -2472,6 +2553,7 @@ optAccessor(getMacAddr8, getMacAddr8Opt, PgMacAddr8) optAccessor(getTsVector, getTsVectorOpt, PgTsVector) optAccessor(getTsQuery, getTsQueryOpt, PgTsQuery) optAccessor(getXml, getXmlOpt, PgXml) +optAccessor(getBit, getBitOpt, PgBit) optAccessor(getHstore, getHstoreOpt, PgHstore) optAccessor(getPoint, getPointOpt, PgPoint) optAccessor(getLine, getLineOpt, PgLine) @@ -2683,6 +2765,33 @@ proc getStrArray*(row: Row, col: int): seq[string] = raise newException(PgTypeError, "NULL element in string array") result.add(e.get) +proc getBitArray*(row: Row, col: int): seq[PgBit] = + ## Get a column value as a seq of PgBit. Handles both text and binary format. + if row.isBinaryCol(col): + let (off, clen) = cellInfo(row, col) + if clen == -1: + raise newException(PgTypeError, "Column " & $col & " is NULL") + let decoded = decodeBinaryArray(row.data.buf.toOpenArray(off, off + clen - 1)) + result = newSeq[PgBit](decoded.elements.len) + for i, e in decoded.elements: + if e.len == -1: + raise newException(PgTypeError, "NULL element in bit array") + if e.len < 4: + raise newException(PgTypeError, "Invalid binary bit element: too short") + let nbits = fromBE32(row.data.buf.toOpenArray(off + e.off, off + e.off + 3)) + let dataLen = e.len - 4 + var data = newSeq[byte](dataLen) + for j in 0 ..< dataLen: + data[j] = row.data.buf[off + e.off + 4 + j] + result[i] = PgBit(nbits: nbits, data: data) + return + let s = row.getStr(col) + let elems = parseTextArray(s) + for e in elems: + if e.isNone: + raise newException(PgTypeError, "NULL element in bit array") + result.add(parseBitString(e.get)) + # Array Opt accessors (text format) optAccessor(getIntArray, getIntArrayOpt, seq[int32]) @@ -2692,6 +2801,7 @@ optAccessor(getFloatArray, getFloatArrayOpt, seq[float64]) optAccessor(getFloat32Array, getFloat32ArrayOpt, seq[float32]) optAccessor(getBoolArray, getBoolArrayOpt, seq[bool]) optAccessor(getStrArray, getStrArrayOpt, seq[string]) +optAccessor(getBitArray, getBitArrayOpt, seq[PgBit]) # Coerce a binary PgParam to match the server-inferred type from a prepared # statement. This handles the common case where e.g. int32.toPgParam is @@ -4538,6 +4648,7 @@ nameAccessor(getMacAddr8, PgMacAddr8) nameAccessor(getTsVector, PgTsVector) nameAccessor(getTsQuery, PgTsQuery) nameAccessor(getXml, PgXml) +nameAccessor(getBit, PgBit) nameAccessor(getHstore, PgHstore) nameAccessor(getPoint, PgPoint) nameAccessor(getLine, PgLine) @@ -4562,6 +4673,7 @@ nameAccessor(getMacAddr8Opt, Option[PgMacAddr8]) nameAccessor(getTsVectorOpt, Option[PgTsVector]) nameAccessor(getTsQueryOpt, Option[PgTsQuery]) nameAccessor(getXmlOpt, Option[PgXml]) +nameAccessor(getBitOpt, Option[PgBit]) nameAccessor(getHstoreOpt, Option[PgHstore]) nameAccessor(getPointOpt, Option[PgPoint]) nameAccessor(getLineOpt, Option[PgLine]) @@ -4577,6 +4689,7 @@ nameAccessor(getFloatArray, seq[float64]) nameAccessor(getFloat32Array, seq[float32]) nameAccessor(getBoolArray, seq[bool]) nameAccessor(getStrArray, seq[string]) +nameAccessor(getBitArray, seq[PgBit]) nameAccessor(getIntArrayOpt, Option[seq[int32]]) nameAccessor(getInt16ArrayOpt, Option[seq[int16]]) nameAccessor(getInt64ArrayOpt, Option[seq[int64]]) @@ -4584,6 +4697,7 @@ nameAccessor(getFloatArrayOpt, Option[seq[float64]]) nameAccessor(getFloat32ArrayOpt, Option[seq[float32]]) nameAccessor(getBoolArrayOpt, Option[seq[bool]]) nameAccessor(getStrArrayOpt, Option[seq[string]]) +nameAccessor(getBitArrayOpt, Option[seq[PgBit]]) nameAccessor(getInt4Range, PgRange[int32]) nameAccessor(getInt8Range, PgRange[int64]) nameAccessor(getNumRange, PgRange[PgNumeric]) diff --git a/tests/test_types.nim b/tests/test_types.nim index d0641f3..de0fb81 100644 --- a/tests/test_types.nim +++ b/tests/test_types.nim @@ -4429,3 +4429,157 @@ suite "hstore": let fields = @[mkField(OidText, 0'i16)] let row = mkRow(@[none(seq[byte])], fields) check row.getHstoreOpt(0).isNone + +suite "PgBit": + test "OID constants": + check OidBit == 1560'i32 + check OidVarbit == 1562'i32 + check OidBitArray == 1561'i32 + check OidVarbitArray == 1563'i32 + + test "parseBitString and $ roundtrip": + let b = parseBitString("10110011") + check b.nbits == 8 + check b.data == @[0b10110011'u8] + check $b == "10110011" + + test "parseBitString non-byte-aligned": + let b = parseBitString("101") + check b.nbits == 3 + check $b == "101" + # data should be 0b10100000 + check b.data == @[0b10100000'u8] + + test "parseBitString empty": + let b = parseBitString("") + check b.nbits == 0 + check b.data.len == 0 + check $b == "" + + test "== operator": + check parseBitString("1010") == parseBitString("1010") + check parseBitString("1010") != parseBitString("1011") + check parseBitString("10") != parseBitString("1000") + + test "toPgParam PgBit": + let b = parseBitString("10110011") + let p = toPgParam(b) + check p.oid == OidVarbit + check p.format == 0 + check toString(p.value.get) == "10110011" + + test "toPgBinaryParam PgBit": + let b = parseBitString("10110011") + let p = toPgBinaryParam(b) + check p.oid == OidVarbit + check p.format == 1 + let data = p.value.get + check data.len == 5 # 4 bytes for nbits + 1 byte for data + # nbits = 8 in big-endian + check data[0 .. 3] == @[0'u8, 0, 0, 8] + check data[4] == 0b10110011'u8 + + test "toPgBinaryParam PgBit non-byte-aligned": + let b = parseBitString("101") + let p = toPgBinaryParam(b) + let data = p.value.get + check data.len == 5 + # nbits = 3 in big-endian + check data[0 .. 3] == @[0'u8, 0, 0, 3] + check data[4] == 0b10100000'u8 + + test "getBit text format": + let data = toBytes("10110011") + let fields = @[mkField(OidVarbit, 0'i16)] + let row = mkRow(@[some(data)], fields) + let b = row.getBit(0) + check b.nbits == 8 + check $b == "10110011" + + test "getBit binary format": + # Binary: 4 bytes nbits (8) + 1 byte data + var data: seq[byte] = @[] + data.add(@(toBE32(8'i32))) + data.add(0b10110011'u8) + let fields = @[mkField(OidVarbit, 1'i16)] + let row = mkRow(@[some(data)], fields) + let b = row.getBit(0) + check b.nbits == 8 + check $b == "10110011" + + test "getBit binary format non-byte-aligned": + var data: seq[byte] = @[] + data.add(@(toBE32(3'i32))) + data.add(0b10100000'u8) + let fields = @[mkField(OidBit, 1'i16)] + let row = mkRow(@[some(data)], fields) + let b = row.getBit(0) + check b.nbits == 3 + check $b == "101" + + test "getBitOpt with value": + let data = toBytes("10110011") + let fields = @[mkField(OidVarbit, 0'i16)] + let row = mkRow(@[some(data)], fields) + let b = row.getBitOpt(0) + check b.isSome + check $b.get == "10110011" + + test "getBitOpt with NULL": + let fields = @[mkField(OidVarbit, 0'i16)] + let row = mkRow(@[none(seq[byte])], fields) + check row.getBitOpt(0).isNone + + test "toPgParam seq[PgBit]": + let v = @[parseBitString("1010"), parseBitString("110")] + let p = toPgParam(v) + check p.oid == OidVarbitArray + check p.format == 1 + + test "toPgParam seq[PgBit] empty": + let v: seq[PgBit] = @[] + let p = toPgParam(v) + check p.oid == OidVarbitArray + check p.format == 1 + + test "toPgBinaryParam seq[PgBit]": + let v = @[parseBitString("1010"), parseBitString("110")] + let p = toPgBinaryParam(v) + check p.oid == OidVarbitArray + check p.format == 1 + + test "getBitArray text format": + let data = toBytes("{1010,110,00001111}") + let fields = @[mkField(OidVarbitArray, 0'i16)] + let row = mkRow(@[some(data)], fields) + let arr = row.getBitArray(0) + check arr.len == 3 + check $arr[0] == "1010" + check $arr[1] == "110" + check $arr[2] == "00001111" + + test "getBitArray binary format": + # Encode two elements using toPgBinaryParam + let v = @[parseBitString("1010"), parseBitString("11001100")] + let p = toPgBinaryParam(v) + let data = p.value.get + let fields = @[mkField(OidVarbitArray, 1'i16)] + let row = mkRow(@[some(data)], fields) + let arr = row.getBitArray(0) + check arr.len == 2 + check $arr[0] == "1010" + check $arr[1] == "11001100" + + test "getBitArrayOpt with value": + let data = toBytes("{101}") + let fields = @[mkField(OidVarbitArray, 0'i16)] + let row = mkRow(@[some(data)], fields) + let arr = row.getBitArrayOpt(0) + check arr.isSome + check arr.get.len == 1 + check $arr.get[0] == "101" + + test "getBitArrayOpt with NULL": + let fields = @[mkField(OidVarbitArray, 0'i16)] + let row = mkRow(@[none(seq[byte])], fields) + check row.getBitArrayOpt(0).isNone