diff --git a/async_postgres/pg_types.nim b/async_postgres/pg_types.nim
index 8984186..af231d7 100644
--- a/async_postgres/pg_types.nim
+++ b/async_postgres/pg_types.nim
@@ -48,6 +48,9 @@ type
 
   PgMacAddr8* = distinct string ## EUI-64 MAC address as "08:00:2b:01:02:03:04:05"
 
+  PgTsVector* = distinct string ## PostgreSQL tsvector (full-text search document).
+  PgTsQuery* = distinct string ## PostgreSQL tsquery (full-text search query).
+
   PgPoint* = object ## PostgreSQL point type: (x, y).
     x*: float64
     y*: float64
@@ -192,6 +195,10 @@ const
   OidDateMultirange* = 4535'i32
   OidInt8Multirange* = 4536'i32
 
+  # Full-text search types
+  OidTsVector* = 3614'i32
+  OidTsQuery* = 3615'i32
+
   rangeEmpty* = 0x01'u8 ## Range flag: range is empty.
   rangeHasLower* = 0x02'u8 ## Range flag: lower bound present.
   rangeHasUpper* = 0x04'u8 ## Range flag: upper bound present.
@@ -208,6 +215,12 @@ proc `==`*(a, b: PgMacAddr): bool {.borrow.}
 proc `$`*(v: PgMacAddr8): string {.borrow.}
 proc `==`*(a, b: PgMacAddr8): bool {.borrow.}
 
+proc `$`*(v: PgTsVector): string {.borrow.}
+proc `==`*(a, b: PgTsVector): bool {.borrow.}
+
+proc `$`*(v: PgTsQuery): string {.borrow.}
+proc `==`*(a, b: PgTsQuery): bool {.borrow.}
+
 proc parsePgNumeric*(s: string): PgNumeric =
   ## Parse a decimal string (e.g. "123.45", "-0.001", "NaN") into PgNumeric.
   if s.len == 0:
@@ -631,6 +644,12 @@ proc toPgParam*(v: PgMacAddr): PgParam =
 proc toPgParam*(v: PgMacAddr8): PgParam =
   PgParam(oid: OidMacAddr8, format: 0, value: some(toBytes(string(v))))
 
+proc toPgParam*(v: PgTsVector): PgParam =
+  PgParam(oid: OidTsVector, format: 0, value: some(toBytes(string(v))))
+
+proc toPgParam*(v: PgTsQuery): PgParam =
+  PgParam(oid: OidTsQuery, format: 0, value: some(toBytes(string(v))))
+
 proc toPgParam*(v: PgPoint): PgParam =
   PgParam(oid: OidPoint, format: 0, value: some(toBytes($v)))
 
@@ -952,6 +971,14 @@ proc toPgBinaryParam*(v: PgMacAddr8): PgParam =
     data[i] = byte(parseHexInt(parts[i]))
   PgParam(oid: OidMacAddr8, format: 1, value: some(data))
 
+proc toPgBinaryParam*(v: PgTsVector): PgParam =
+  ## Send as text format — PostgreSQL handles the parsing.
+  PgParam(oid: OidTsVector, format: 0, value: some(toBytes(string(v))))
+
+proc toPgBinaryParam*(v: PgTsQuery): PgParam =
+  ## Send as text format — PostgreSQL handles the parsing.
+  PgParam(oid: OidTsQuery, format: 0, value: some(toBytes(string(v))))
+
 proc encodePointBinary(p: PgPoint): seq[byte] =
   ## Encode a point as 16 bytes (two float64 big-endian).
   result = newSeq[byte](16)
@@ -1852,6 +1879,178 @@ proc getMacAddr8*(row: Row, col: int): PgMacAddr8 =
     return PgMacAddr8(parts.join(":"))
   PgMacAddr8(row.getStr(col))
 
+proc quoteTsLexeme(data: openArray[byte], first, stop: int): string =
+  ## Render bytes `first ..< stop` of `data` as a single-quoted lexeme literal.
+  ## Embedded `'` and `\` are doubled, matching PostgreSQL's text output, so
+  ## the result can be parsed back as tsvector/tsquery text.
+  result = newStringOfCap(stop - first + 2)
+  result.add('\'')
+  for i in first ..< stop:
+    let c = char(data[i])
+    if c == '\'' or c == '\\':
+      result.add(c) # double the metacharacter
+    result.add(c)
+  result.add('\'')
+
+proc decodeBinaryTsVector(data: openArray[byte]): string =
+  ## Decode PostgreSQL binary tsvector to text representation.
+  ##
+  ## Layout as read here: a 32-bit lexeme count, then per lexeme a
+  ## NUL-terminated string, a 16-bit position count and that many 16-bit
+  ## entries (low 14 bits: position, high 2 bits: weight, 3 = A .. 0 = D).
+  ## Raises `PgTypeError` when the buffer is truncated or the count is invalid.
+  if data.len < 4:
+    raise newException(PgTypeError, "tsvector binary data too short")
+  let nlexemes = int(fromBE32(data.toOpenArray(0, 3)))
+  # Each lexeme needs at least 3 bytes (NUL + position count); reject corrupt
+  # counts up front instead of allocating from an untrusted length.
+  if nlexemes < 0 or nlexemes > (data.len - 4) div 3:
+    raise newException(PgTypeError, "tsvector binary: invalid lexeme count")
+  var pos = 4
+  var parts = newSeq[string](nlexemes)
+  const weightChars = ['D', 'C', 'B', 'A']
+  for i in 0 ..< nlexemes:
+    # Read null-terminated lexeme
+    var lexEnd = pos
+    while lexEnd < data.len and data[lexEnd] != 0:
+      inc lexEnd
+    if lexEnd >= data.len:
+      raise newException(PgTypeError, "tsvector binary: lexeme missing null terminator")
+    var part = quoteTsLexeme(data, pos, lexEnd)
+    pos = lexEnd + 1 # skip null terminator
+    # Read positions
+    if pos + 1 >= data.len:
+      raise newException(PgTypeError, "tsvector binary truncated at position count")
+    let npos = int(fromBE16(data.toOpenArray(pos, pos + 1)))
+    pos += 2
+    if npos > 0:
+      part.add(':')
+      for j in 0 ..< npos:
+        if pos + 1 >= data.len:
+          raise newException(PgTypeError, "tsvector binary truncated at position")
+        let posVal = uint16(fromBE16(data.toOpenArray(pos, pos + 1)))
+        pos += 2
+        let position = posVal and 0x3FFF
+        let weight = int((posVal shr 14) and 0x3)
+        if j > 0:
+          part.add(',')
+        part.add($position)
+        if weight > 0: # weight D (0) is the default and is not printed
+          part.add(weightChars[weight])
+    parts[i] = part
+  parts.join(" ")
+
+proc parseTsQueryNode(
+    data: openArray[byte], pos: var int, parentPrio = 0, rightOfPhrase = false
+): string =
+  ## Decode the tsquery node at `pos` (advancing `pos` past it) to infix text.
+  ##
+  ## `parentPrio` is the binding strength of the enclosing operator
+  ## (OR = 1, AND = 2, PHRASE = 3, NOT = 4; 0 at the root) and `rightOfPhrase`
+  ## marks the right operand of a phrase operator; both only decide whether
+  ## the node must be wrapped in parentheses to keep its meaning.
+  ## Raises `PgTypeError` on truncated data or unknown token/operator codes.
+  if pos >= data.len:
+    raise newException(PgTypeError, "tsquery binary truncated")
+  let tokenType = data[pos]
+  inc pos
+  case tokenType
+  of 1: # operand
+    if pos + 2 >= data.len:
+      raise newException(PgTypeError, "tsquery operand truncated")
+    let weightByte = data[pos]
+    inc pos
+    let prefix = data[pos] != 0
+    inc pos
+    var strEnd = pos
+    while strEnd < data.len and data[strEnd] != 0:
+      inc strEnd
+    if strEnd >= data.len:
+      raise newException(PgTypeError, "tsquery binary: operand missing null terminator")
+    var s = quoteTsLexeme(data, pos, strEnd)
+    pos = strEnd + 1
+    var suffix = ""
+    if (weightByte and 0x08) != 0:
+      suffix.add('A')
+    if (weightByte and 0x04) != 0:
+      suffix.add('B')
+    if (weightByte and 0x02) != 0:
+      suffix.add('C')
+    if (weightByte and 0x01) != 0:
+      suffix.add('D')
+    if suffix.len > 0 or prefix:
+      s.add(':')
+      s.add(suffix)
+      if prefix:
+        s.add('*')
+    s
+  of 2: # operator
+    if pos >= data.len:
+      raise newException(PgTypeError, "tsquery operator truncated")
+    let op = data[pos]
+    inc pos
+    case op
+    of 1: # NOT
+      # NOT binds tightest, so any operator beneath it gets parenthesized.
+      "!" & parseTsQueryNode(data, pos, 4)
+    of 2, 3, 4: # AND, OR, PHRASE
+      var distance = 1
+      if op == 4:
+        if pos + 1 >= data.len:
+          raise newException(PgTypeError, "tsquery PHRASE distance truncated")
+        distance = int(fromBE16(data.toOpenArray(pos, pos + 1)))
+        pos += 2
+      let prio = if op == 3: 1 elif op == 2: 2 else: 3
+      # PostgreSQL stores the RIGHT operand directly after the operator and
+      # the LEFT operand after it, so read them in that order.
+      let right = parseTsQueryNode(data, pos, prio, op == 4)
+      let left = parseTsQueryNode(data, pos, prio)
+      let opText =
+        if op == 2: " & "
+        elif op == 3: " | "
+        elif distance == 1: " <-> "
+        else: " <" & $distance & "> "
+      let expr = left & opText & right
+      # Parenthesize when the parent binds tighter, or when a phrase is the
+      # right operand of another phrase (phrase grouping is order-sensitive).
+      if prio < parentPrio or (op == 4 and rightOfPhrase):
+        "( " & expr & " )"
+      else:
+        expr
+    else:
+      raise newException(PgTypeError, "Unknown tsquery operator: " & $op)
+  else:
+    raise newException(PgTypeError, "Unknown tsquery token type: " & $tokenType)
+
+proc decodeBinaryTsQuery(data: openArray[byte]): string =
+  ## Decode PostgreSQL binary tsquery (prefix/preorder) to text representation (infix).
+  if data.len < 4:
+    raise newException(PgTypeError, "tsquery binary data too short")
+  let ntokens = int(fromBE32(data.toOpenArray(0, 3)))
+  if ntokens == 0:
+    return ""
+  var pos = 4
+  parseTsQueryNode(data, pos)
+
+proc getTsVector*(row: Row, col: int): PgTsVector =
+  ## Get a column value as PgTsVector. Handles both text and binary format.
+  if row.isBinaryCol(col):
+    let (off, clen) = cellInfo(row, col)
+    if clen == -1:
+      raise newException(PgTypeError, "Column " & $col & " is NULL")
+    return
+      PgTsVector(decodeBinaryTsVector(row.data.buf.toOpenArray(off, off + clen - 1)))
+  PgTsVector(row.getStr(col))
+
+proc getTsQuery*(row: Row, col: int): PgTsQuery =
+  ## Get a column value as PgTsQuery. Handles both text and binary format.
+  if row.isBinaryCol(col):
+    let (off, clen) = cellInfo(row, col)
+    if clen == -1:
+      raise newException(PgTypeError, "Column " & $col & " is NULL")
+    return PgTsQuery(decodeBinaryTsQuery(row.data.buf.toOpenArray(off, off + clen - 1)))
+  PgTsQuery(row.getStr(col))
+
 # Geometry text format parsers
 
 proc parsePointText(s: string): PgPoint =
@@ -2085,6 +2284,8 @@ optAccessor(getInet, getInetOpt, PgInet)
 optAccessor(getCidr, getCidrOpt, PgCidr)
 optAccessor(getMacAddr, getMacAddrOpt, PgMacAddr)
 optAccessor(getMacAddr8, getMacAddr8Opt, PgMacAddr8)
+optAccessor(getTsVector, getTsVectorOpt, PgTsVector)
+optAccessor(getTsQuery, getTsQueryOpt, PgTsQuery)
 optAccessor(getPoint, getPointOpt, PgPoint)
 optAccessor(getLine, getLineOpt, PgLine)
 optAccessor(getLseg, getLsegOpt, PgLseg)
@@ -3665,6 +3866,12 @@ proc get*(row: Row, col: int, T: typedesc[PgMacAddr]): PgMacAddr =
 proc get*(row: Row, col: int, T: typedesc[PgMacAddr8]): PgMacAddr8 =
   row.getMacAddr8(col)
 
+proc get*(row: Row, col: int, T: typedesc[PgTsVector]): PgTsVector =
+  row.getTsVector(col)
+
+proc get*(row: Row, col: int, T: typedesc[PgTsQuery]): PgTsQuery =
+  row.getTsQuery(col)
+
 proc get*(row: Row, col: int, T: typedesc[PgPoint]): PgPoint =
   row.getPoint(col)
 
@@ -3785,6 +3992,8 @@ nameAccessor(getInet, PgInet)
 nameAccessor(getCidr, PgCidr)
 nameAccessor(getMacAddr, PgMacAddr)
 nameAccessor(getMacAddr8, PgMacAddr8)
+nameAccessor(getTsVector, PgTsVector)
+nameAccessor(getTsQuery, PgTsQuery)
 nameAccessor(getPoint, PgPoint)
 nameAccessor(getLine, PgLine)
 nameAccessor(getLseg, PgLseg)
@@ -3805,6 +4014,8 @@ nameAccessor(getInetOpt, Option[PgInet])
 nameAccessor(getCidrOpt, Option[PgCidr])
 nameAccessor(getMacAddrOpt, Option[PgMacAddr])
 nameAccessor(getMacAddr8Opt, Option[PgMacAddr8])
+nameAccessor(getTsVectorOpt, Option[PgTsVector])
+nameAccessor(getTsQueryOpt, Option[PgTsQuery])
 nameAccessor(getPointOpt, Option[PgPoint])
 nameAccessor(getLineOpt, Option[PgLine])
 nameAccessor(getLsegOpt, Option[PgLseg])
diff --git a/tests/test_e2e.nim b/tests/test_e2e.nim
index e14bfab..edb0d4c 100644
--- a/tests/test_e2e.nim
+++ b/tests/test_e2e.nim
@@ -6046,3 +6046,108 @@ suite "E2E: cancelNoWait":
       await conn.close()
 
     waitFor t()
+
+  test "tsvector roundtrip":
+    proc t() {.async.} =
+      let conn = await connect(plainConfig())
+      let v = PgTsVector("'cat':1A 'dog':3")
+      let res = await conn.query("SELECT $1::tsvector", @[toPgParam(v)])
+      doAssert res.rows.len == 1
+      let got = res.rows[0].getTsVector(0)
+      doAssert $got == "'cat':1A 'dog':3"
+      await conn.close()
+
+    waitFor t()
+
+  test "to_tsvector function":
+    proc t() {.async.} =
+      let conn = await connect(plainConfig())
+      let res =
+        await conn.query("SELECT to_tsvector('english', 'The fat cat sat on the mat')")
+      doAssert res.rows.len == 1
+      let v = res.rows[0].getTsVector(0)
+      let s = $v
+      doAssert "'cat'" in s
+      doAssert "'fat'" in s
+      doAssert "'mat'" in s
+      doAssert "'sat'" in s
+      await conn.close()
+
+    waitFor t()
+
+  test "tsquery roundtrip":
+    proc t() {.async.} =
+      let conn = await connect(plainConfig())
+      let q = PgTsQuery("'fat' & 'rat'")
+      let res = await conn.query("SELECT $1::tsquery", @[toPgParam(q)])
+      doAssert res.rows.len == 1
+      let got = res.rows[0].getTsQuery(0)
+      doAssert "'fat' & 'rat'" == $got
+      await conn.close()
+
+    waitFor t()
+
+  test "full-text search with @@ operator":
+    proc t() {.async.} =
+      let conn = await connect(plainConfig())
+      let res = await conn.query(
+        "SELECT to_tsvector('english', 'the fat cat') @@ to_tsquery('english', 'fat & cat')"
+      )
+      doAssert res.rows.len == 1
+      doAssert res.rows[0].getBool(0) == true
+      let res2 = await conn.query(
+        "SELECT to_tsvector('english', 'the fat cat') @@ to_tsquery('english', 'fat & dog')"
+      )
+      doAssert res2.rows[0].getBool(0) == false
+      await conn.close()
+
+    waitFor t()
+
+  test "NULL tsvector and tsquery":
+    proc t() {.async.} =
+      let conn = await connect(plainConfig())
+      let res = await conn.query("SELECT NULL::tsvector, NULL::tsquery")
+      doAssert res.rows.len == 1
+      doAssert res.rows[0].getTsVectorOpt(0).isNone
+      doAssert res.rows[0].getTsQueryOpt(1).isNone
+      await conn.close()
+
+    waitFor t()
+
+  test "tsvector binary results":
+    proc t() {.async.} =
+      let conn = await connect(plainConfig())
+      let res =
+        await conn.query("SELECT 'cat:1A dog:3'::tsvector", resultFormat = rfBinary)
+      doAssert res.rows.len == 1
+      let v = res.rows[0].getTsVector(0)
+      let s = $v
+      doAssert "'cat'" in s
+      doAssert "'dog'" in s
+      await conn.close()
+
+    waitFor t()
+
+  test "tsquery binary results":
+    proc t() {.async.} =
+      let conn = await connect(plainConfig())
+      let res = await conn.query("SELECT 'fat & rat'::tsquery", resultFormat = rfBinary)
+      doAssert res.rows.len == 1
+      let q = res.rows[0].getTsQuery(0)
+      let s = $q
+      doAssert "'fat'" in s
+      doAssert "'rat'" in s
+      await conn.close()
+
+    waitFor t()
+
+  test "tsquery binary results keep phrase operand order":
+    proc t() {.async.} =
+      let conn = await connect(plainConfig())
+      let res =
+        await conn.query("SELECT 'fat <-> rat'::tsquery", resultFormat = rfBinary)
+      doAssert res.rows.len == 1
+      doAssert $res.rows[0].getTsQuery(0) == "'fat' <-> 'rat'"
+      await conn.close()
+
+    waitFor t()
diff --git a/tests/test_types.nim b/tests/test_types.nim
index 877b0a0..49ba3b6 100644
--- a/tests/test_types.nim
+++ b/tests/test_types.nim
@@ -67,6 +67,8 @@ suite "OID constants":
     check OidCidr == 650'i32
     check OidMacAddr == 829'i32
     check OidMacAddr8 == 774'i32
+    check OidTsVector == 3614'i32
+    check OidTsQuery == 3615'i32
 
 suite "toPgParam":
   test "string":
@@ -3711,3 +3713,238 @@ suite "Geometry types":
     let fields = @[mkField(OidCircle, 1'i16)]
     let row = mkRow(@[none(seq[byte])], fields)
     check row.getCircleOpt(0).isNone
+
+suite "tsvector / tsquery":
+  test "toPgParam PgTsVector":
+    let v = PgTsVector("'cat':1A 'dog':3")
+    let p = toPgParam(v)
+    check p.oid == OidTsVector
+    check p.format == 0
+    check toString(p.value.get) == "'cat':1A 'dog':3"
+
+  test "toPgParam PgTsQuery":
+    let q = PgTsQuery("'fat' & 'rat'")
+    let p = toPgParam(q)
+    check p.oid == OidTsQuery
+    check p.format == 0
+    check toString(p.value.get) == "'fat' & 'rat'"
+
+  test "toPgBinaryParam PgTsVector sends text format":
+    let v = PgTsVector("'cat':1A 'dog':3")
+    let p = toPgBinaryParam(v)
+    check p.oid == OidTsVector
+    check p.format == 0
+
+  test "toPgBinaryParam PgTsQuery sends text format":
+    let q = PgTsQuery("'fat' & 'rat'")
+    let p = toPgBinaryParam(q)
+    check p.oid == OidTsQuery
+    check p.format == 0
+
+  test "$ PgTsVector":
+    let v = PgTsVector("'cat':1A 'dog':3")
+    check $v == "'cat':1A 'dog':3"
+
+  test "$ PgTsQuery":
+    let q = PgTsQuery("'fat' & 'rat'")
+    check $q == "'fat' & 'rat'"
+
+  test "== PgTsVector":
+    check PgTsVector("'a':1") == PgTsVector("'a':1")
+    check PgTsVector("'a':1") != PgTsVector("'b':1")
+
+  test "== PgTsQuery":
+    check PgTsQuery("'a' & 'b'") == PgTsQuery("'a' & 'b'")
+    check PgTsQuery("'a'") != PgTsQuery("'b'")
+
+  test "getTsVector text format":
+    let data = toBytes("'cat':1A,2B 'dog':3")
+    let fields = @[mkField(OidTsVector, 0'i16)]
+    let row = mkRow(@[some(data)], fields)
+    check $row.getTsVector(0) == "'cat':1A,2B 'dog':3"
+
+  test "getTsQuery text format":
+    let data = toBytes("'fat' & 'rat'")
+    let fields = @[mkField(OidTsQuery, 0'i16)]
+    let row = mkRow(@[some(data)], fields)
+    check $row.getTsQuery(0) == "'fat' & 'rat'"
+
+  test "getTsVector binary format":
+    # Binary tsvector for 'cat':1A (1 lexeme, 1 position with weight A)
+    # Header: nlexemes = 1
+    # Lexeme: "cat\0" + npos=1 + position=1 with weight A (3 << 14 | 1 = 0xC001)
+    var data: seq[byte] = @[]
+    # nlexemes = 1
+    data.add(@(toBE32(1'i32)))
+    # lexeme "cat" + null terminator
+    for c in "cat":
+      data.add(byte(c))
+    data.add(0'u8)
+    # npos = 1
+    data.add(@(toBE16(1'i16)))
+    # position 1 with weight A (weight=3, so 3 << 14 | 1 = 0xC001)
+    data.add(@(toBE16(cast[int16](0xC001'u16))))
+    let fields = @[mkField(OidTsVector, 1'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let v = row.getTsVector(0)
+    check "'cat':1A" == $v
+
+  test "getTsVector binary format multiple lexemes":
+    # Binary tsvector for 'bar' 'foo':2B
+    var data: seq[byte] = @[]
+    # nlexemes = 2
+    data.add(@(toBE32(2'i32)))
+    # lexeme "bar" + null + npos=0
+    for c in "bar":
+      data.add(byte(c))
+    data.add(0'u8)
+    data.add(@(toBE16(0'i16)))
+    # lexeme "foo" + null + npos=1 + position 2 with weight B (weight=2, 2 << 14 | 2 = 0x8002)
+    for c in "foo":
+      data.add(byte(c))
+    data.add(0'u8)
+    data.add(@(toBE16(1'i16)))
+    data.add(@(toBE16(cast[int16](0x8002'u16))))
+    let fields = @[mkField(OidTsVector, 1'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let v = row.getTsVector(0)
+    check "'bar' 'foo':2B" == $v
+
+  test "getTsQuery binary format simple AND":
+    # Binary tsquery for 'cat' & 'dog' (wire order: AND, right 'dog', left 'cat')
+    var data: seq[byte] = @[]
+    # ntokens = 3
+    data.add(@(toBE32(3'i32)))
+    # AND operator: type=2, op=2
+    data.add(2'u8)
+    data.add(2'u8)
+    # right operand "dog": type=1, weight=0, prefix=0, "dog\0"
+    data.add(1'u8) # type
+    data.add(0'u8) # weight
+    data.add(0'u8) # prefix
+    for c in "dog":
+      data.add(byte(c))
+    data.add(0'u8)
+    # left operand "cat": type=1, weight=0, prefix=0, "cat\0"
+    data.add(1'u8)
+    data.add(0'u8)
+    data.add(0'u8)
+    for c in "cat":
+      data.add(byte(c))
+    data.add(0'u8)
+    let fields = @[mkField(OidTsQuery, 1'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let q = row.getTsQuery(0)
+    check "'cat' & 'dog'" == $q
+
+  test "getTsQuery binary format NOT":
+    # Binary tsquery for !'cat' (prefix: NOT, cat)
+    var data: seq[byte] = @[]
+    data.add(@(toBE32(2'i32)))
+    # NOT operator: type=2, op=1
+    data.add(2'u8)
+    data.add(1'u8)
+    # operand "cat"
+    data.add(1'u8)
+    data.add(0'u8)
+    data.add(0'u8)
+    for c in "cat":
+      data.add(byte(c))
+    data.add(0'u8)
+    let fields = @[mkField(OidTsQuery, 1'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let q = row.getTsQuery(0)
+    check "!'cat'" == $q
+
+  test "getTsQuery binary format PHRASE":
+    # Binary tsquery for 'cat' <-> 'dog' (wire order: PHRASE dist=1, right, left)
+    var data: seq[byte] = @[]
+    data.add(@(toBE32(3'i32)))
+    # PHRASE operator: type=2, op=4, distance=1
+    data.add(2'u8)
+    data.add(4'u8)
+    data.add(@(toBE16(1'i16)))
+    # right operand "dog"
+    data.add(1'u8)
+    data.add(0'u8)
+    data.add(0'u8)
+    for c in "dog":
+      data.add(byte(c))
+    data.add(0'u8)
+    # left operand "cat"
+    data.add(1'u8)
+    data.add(0'u8)
+    data.add(0'u8)
+    for c in "cat":
+      data.add(byte(c))
+    data.add(0'u8)
+    let fields = @[mkField(OidTsQuery, 1'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let q = row.getTsQuery(0)
+    check "'cat' <-> 'dog'" == $q
+
+  test "getTsQuery binary format with weight and prefix":
+    # Binary tsquery for 'cat':AB* (single operand with weights A+B and prefix)
+    var data: seq[byte] = @[]
+    data.add(@(toBE32(1'i32)))
+    # operand "cat": type=1, weight=0x0C (A=0x08 + B=0x04), prefix=1
+    data.add(1'u8)
+    data.add(0x0C'u8) # A + B
+    data.add(1'u8) # prefix
+    for c in "cat":
+      data.add(byte(c))
+    data.add(0'u8)
+    let fields = @[mkField(OidTsQuery, 1'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let q = row.getTsQuery(0)
+    check "'cat':AB*" == $q
+
+  test "getTsQuery binary format parenthesizes operator under NOT":
+    # Wire order: NOT, AND, right operand 'b', left operand 'a'
+    var data: seq[byte] = @[]
+    data.add(@(toBE32(4'i32)))
+    data.add(@[2'u8, 1'u8]) # NOT
+    data.add(@[2'u8, 2'u8]) # AND
+    data.add(@[1'u8, 0'u8, 0'u8, byte('b'), 0'u8]) # right operand
+    data.add(@[1'u8, 0'u8, 0'u8, byte('a'), 0'u8]) # left operand
+    let fields = @[mkField(OidTsQuery, 1'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let q = row.getTsQuery(0)
+    check "!( 'a' & 'b' )" == $q
+
+  test "getTsVector binary format escapes quotes in lexemes":
+    # Lexeme o'k must come back as 'o''k' so the text can be re-parsed
+    var data: seq[byte] = @[]
+    data.add(@(toBE32(1'i32)))
+    data.add(@[byte('o'), byte('\''), byte('k'), 0'u8])
+    data.add(@(toBE16(0'i16)))
+    let fields = @[mkField(OidTsVector, 1'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let v = row.getTsVector(0)
+    check "'o''k'" == $v
+
+  test "getTsVectorOpt text some":
+    let data = toBytes("'cat':1A")
+    let fields = @[mkField(OidTsVector, 0'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let r = row.getTsVectorOpt(0)
+    check r.isSome
+    check $r.get == "'cat':1A"
+
+  test "getTsVectorOpt none":
+    let fields = @[mkField(OidTsVector, 0'i16)]
+    let row = mkRow(@[none(seq[byte])], fields)
+    check row.getTsVectorOpt(0).isNone
+
+  test "getTsQueryOpt text some":
+    let data = toBytes("'cat' & 'dog'")
+    let fields = @[mkField(OidTsQuery, 0'i16)]
+    let row = mkRow(@[some(data)], fields)
+    let r = row.getTsQueryOpt(0)
+    check r.isSome
+    check $r.get == "'cat' & 'dog'"
+
+  test "getTsQueryOpt none":
+    let fields = @[mkField(OidTsQuery, 0'i16)]
+    let row = mkRow(@[none(seq[byte])], fields)
+    check row.getTsQueryOpt(0).isNone