From 9b1ece8a046cc72e2a7bf7ec2bcdcd2b6efa0bd7 Mon Sep 17 00:00:00 2001 From: fox0430 Date: Wed, 1 Apr 2026 16:15:17 +0900 Subject: [PATCH 1/2] Add PgHsotre --- async_postgres/pg_connection.nim | 47 ++++++-- async_postgres/pg_protocol.nim | 15 +++ async_postgres/pg_types.nim | 177 +++++++++++++++++++++++++++++++ tests/test_types.nim | 142 +++++++++++++++++++++++++ 4 files changed, 375 insertions(+), 6 deletions(-) diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index 67a1e83..f777e60 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -177,6 +177,7 @@ type stmtCounter*: int stmtCacheCapacity*: int ## 0=disabled, default 256 rowDataBuf*: RowData ## Reusable RowData buffer to avoid per-query allocation + hstoreOid*: int32 ## Dynamic OID for hstore extension type; 0 if not available QueryResult* = object ## Result of a query: field descriptions, row data, and command tag. @@ -386,7 +387,10 @@ proc addStmtCache*(conn: PgConnection, sql: string, cached: CachedStmt) = return # caller should have evicted; skip if still full var entry = cached if entry.resultFormats.len == 0 and entry.fields.len > 0: - entry.resultFormats = buildResultFormats(entry.fields) + if conn.hstoreOid != 0: + entry.resultFormats = buildResultFormats(entry.fields, [conn.hstoreOid]) + else: + entry.resultFormats = buildResultFormats(entry.fields) entry.colFmts = newSeq[int16](entry.fields.len) entry.colOids = newSeq[int32](entry.fields.len) for i in 0 ..< entry.fields.len: @@ -863,6 +867,11 @@ when defined(posix): "TCP keepalive timing options (idle/interval/count) are not supported on this platform and will be ignored" .} +proc bytesToString(data: seq[byte]): string = + result = newString(data.len) + for i in 0 ..< data.len: + result[i] = char(data[i]) + proc connectToHost( config: ConnConfig, hostAddr: string, hostPort: int ): Future[PgConnection] {.async.} = @@ -993,6 +1002,36 @@ proc connectToHost( discard await conn.fillRecvBuf() + # Discover extension type OIDs (hstore, etc.) + conn.state = csBusy + await conn.sendMsg( + encodeQuery("SELECT oid FROM pg_type WHERE typname = 'hstore' LIMIT 1") + ) + block discoverLoop: + while true: + while (let opt = conn.nextMessage(); opt.isSome): + let msg = opt.get + case msg.kind + of bmkRowDescription: + discard + of bmkDataRow: + if msg.columns.len > 0 and msg.columns[0].isSome: + try: + conn.hstoreOid = int32(parseInt(bytesToString(msg.columns[0].get))) + except ValueError: + discard + of bmkCommandComplete, bmkEmptyQueryResponse: + discard + of bmkReadyForQuery: + conn.txStatus = msg.txStatus + conn.state = csReady + break discoverLoop + of bmkErrorResponse: + discard + else: + discard + await conn.fillRecvBuf() + conn.createdAt = Moment.now() return conn except CatchableError as e: @@ -1201,11 +1240,6 @@ proc close*(conn: PgConnection): Future[void] {.async.} = conn.notifyWaiter.fail(newException(PgError, "Connection closed")) await conn.closeTransport() -proc bytesToString(data: seq[byte]): string = - result = newString(data.len) - for i in 0 ..< data.len: - result[i] = char(data[i]) - proc checkSessionAttrs( conn: PgConnection, attrs: TargetSessionAttrs ): Future[bool] {.async.} = @@ -1320,6 +1354,7 @@ proc reconnectInPlace(conn: PgConnection) {.async.} = conn.txStatus = newConn.txStatus conn.state = csReady conn.createdAt = newConn.createdAt + conn.hstoreOid = newConn.hstoreOid for ch in conn.listenChannels: discard await conn.simpleQuery("LISTEN " & quoteIdentifier(ch)) diff --git a/async_postgres/pg_protocol.nim b/async_postgres/pg_protocol.nim index cc09e3f..c51c91d 100644 --- a/async_postgres/pg_protocol.nim +++ b/async_postgres/pg_protocol.nim @@ -784,6 +784,21 @@ proc buildResultFormats*(fields: openArray[FieldDescription]): seq[int16] = for i, f in fields: result[i] = if isBinarySafeOid(f.typeOid): 1'i16 else: 0'i16 +proc buildResultFormats*( + fields: openArray[FieldDescription], extraBinaryOids: openArray[int32] +): seq[int16] = + ## Build per-column binary format codes with additional runtime-safe OIDs. + result = newSeq[int16](fields.len) + for i, f in fields: + if isBinarySafeOid(f.typeOid): + result[i] = 1'i16 + else: + result[i] = 0'i16 + for oid in extraBinaryOids: + if f.typeOid == oid: + result[i] = 1'i16 + break + proc parseDataRowInto*(body: openArray[byte], rd: RowData) = ## Parse a DataRow message body directly into a RowData flat buffer. ## Column data is appended to rd.buf and (offset, length) pairs to rd.cellIndex. diff --git a/async_postgres/pg_types.nim b/async_postgres/pg_types.nim index ff10562..1025115 100644 --- a/async_postgres/pg_types.nim +++ b/async_postgres/pg_types.nim @@ -95,6 +95,9 @@ type PgMultirange*[T] = distinct seq[PgRange[T]] ## PostgreSQL multirange value (PostgreSQL 14+). A sorted set of non-overlapping ranges. + PgHstore* = Table[string, Option[string]] + ## PostgreSQL hstore type: a set of key/value pairs where values may be NULL. + PgParam* = object ## A single query parameter in binary wire format, ready to send to PostgreSQL. oid*: int32 @@ -692,6 +695,34 @@ proc toPgParam*(v: PgCircle): PgParam = proc toPgParam*(v: JsonNode): PgParam = PgParam(oid: OidJsonb, format: 0, value: some(toBytes($v))) +proc encodeHstoreText*(v: PgHstore): string = + ## Encode hstore as PostgreSQL text format: ``"key1"=>"val1", "key2"=>NULL``. + var parts: seq[string] + for k, v in v.pairs: + var keyEsc = newStringOfCap(k.len + 2) + keyEsc.add('"') + for c in k: + if c == '"' or c == '\\': + keyEsc.add('\\') + keyEsc.add(c) + keyEsc.add('"') + if v.isSome: + var valEsc = newStringOfCap(v.get.len + 2) + valEsc.add('"') + for c in v.get: + if c == '"' or c == '\\': + valEsc.add('\\') + valEsc.add(c) + valEsc.add('"') + parts.add(keyEsc & "=>" & valEsc) + else: + parts.add(keyEsc & "=>NULL") + parts.join(", ") + +proc toPgParam*(v: PgHstore): PgParam = + ## Send hstore as text format. PostgreSQL casts text to hstore implicitly. + PgParam(oid: OidText, format: 0, value: some(toBytes(encodeHstoreText(v)))) + proc encodeBinaryArray*(elemOid: int32, elements: seq[seq[byte]]): seq[byte] = ## Encode a 1-dimensional PostgreSQL binary array. ## Header: ndim(4) + has_null(4) + elem_oid(4) + dim_len(4) + lower_bound(4) = 20 bytes @@ -1089,6 +1120,75 @@ proc toPgBinaryParam*(v: Option[JsonNode]): PgParam = else: PgParam(oid: OidJsonb, format: 1, value: none(seq[byte])) +proc encodeHstoreBinary*(v: PgHstore): seq[byte] = + ## Encode hstore as PostgreSQL binary format. + ## Format: numPairs(int32) + [keyLen(int32) + keyData + valLen(int32) + valData]... + var size = 4 + for k, val in v.pairs: + size += 4 + k.len + 4 + if val.isSome: + size += val.get.len + result = newSeq[byte](size) + let np = toBE32(int32(v.len)) + copyMem(addr result[0], unsafeAddr np[0], 4) + var pos = 4 + for k, val in v.pairs: + let kLen = toBE32(int32(k.len)) + copyMem(addr result[pos], unsafeAddr kLen[0], 4) + pos += 4 + if k.len > 0: + copyMem(addr result[pos], unsafeAddr k[0], k.len) + pos += k.len + if val.isSome: + let vLen = toBE32(int32(val.get.len)) + copyMem(addr result[pos], unsafeAddr vLen[0], 4) + pos += 4 + if val.get.len > 0: + copyMem(addr result[pos], unsafeAddr val.get[0], val.get.len) + pos += val.get.len + else: + let nullLen = toBE32(-1'i32) + copyMem(addr result[pos], unsafeAddr nullLen[0], 4) + pos += 4 + +proc decodeHstoreBinary*(data: openArray[byte]): PgHstore = + ## Decode PostgreSQL binary hstore format. + result = initTable[string, Option[string]]() + if data.len < 4: + raise newException(PgTypeError, "hstore binary data too short") + let numPairs = int(fromBE32(data.toOpenArray(0, 3))) + var pos = 4 + for _ in 0 ..< numPairs: + if pos + 4 > data.len: + raise newException(PgTypeError, "hstore binary: truncated key length") + let keyLen = int(fromBE32(data.toOpenArray(pos, pos + 3))) + pos += 4 + if keyLen < 0 or pos + keyLen > data.len: + raise newException(PgTypeError, "hstore binary: truncated key data") + var key = newString(keyLen) + if keyLen > 0: + copyMem(addr key[0], unsafeAddr data[pos], keyLen) + pos += keyLen + if pos + 4 > data.len: + raise newException(PgTypeError, "hstore binary: truncated value length") + let valLen = int(fromBE32(data.toOpenArray(pos, pos + 3))) + pos += 4 + if valLen == -1: + result[key] = none(string) + else: + if valLen < 0 or pos + valLen > data.len: + raise newException(PgTypeError, "hstore binary: truncated value data") + var val = newString(valLen) + if valLen > 0: + copyMem(addr val[0], unsafeAddr data[pos], valLen) + pos += valLen + result[key] = some(val) + +proc toPgBinaryParam*(v: PgHstore, oid: int32): PgParam = + ## Encode hstore in binary format. Requires the dynamic hstore OID + ## (available as ``conn.hstoreOid`` after connection). + PgParam(oid: oid, format: 1, value: some(encodeHstoreBinary(v))) + proc toPgBinaryParam*[T](v: Option[T]): PgParam = if v.isSome: result = toPgBinaryParam(v.get) @@ -1704,6 +1804,71 @@ proc getDate*(row: Row, col: int): DateTime = except TimeParseError: raise newException(PgTypeError, "Invalid date: " & s) +proc parseHstoreText*(s: string): PgHstore = + ## Parse PostgreSQL hstore text format: ``"key1"=>"val1", "key2"=>NULL``. + result = initTable[string, Option[string]]() + if s.len == 0: + return + var i = 0 + while i < s.len: + # Skip whitespace and commas + while i < s.len and s[i] in {' ', ',', '\t', '\n', '\r'}: + i += 1 + if i >= s.len: + break + # Parse key (must be quoted) + if s[i] != '"': + raise newException(PgTypeError, "hstore: expected '\"' at position " & $i) + i += 1 + var key = "" + while i < s.len: + if s[i] == '\\' and i + 1 < s.len: + i += 1 + key.add(s[i]) + elif s[i] == '"': + break + else: + key.add(s[i]) + i += 1 + if i >= s.len: + raise newException(PgTypeError, "hstore: unterminated key string") + i += 1 # skip closing quote + # Skip whitespace + while i < s.len and s[i] == ' ': + i += 1 + # Expect => + if i + 1 >= s.len or s[i] != '=' or s[i + 1] != '>': + raise newException(PgTypeError, "hstore: expected '=>' at position " & $i) + i += 2 + # Skip whitespace + while i < s.len and s[i] == ' ': + i += 1 + # Parse value (NULL or quoted string) + if i + 3 < s.len and s[i] == 'N' and s[i + 1] == 'U' and s[i + 2] == 'L' and + s[i + 3] == 'L' and (i + 4 >= s.len or s[i + 4] in {',', ' ', '\t', '\n', '\r'}): + result[key] = none(string) + i += 4 + elif i < s.len and s[i] == '"': + i += 1 + var val = "" + while i < s.len: + if s[i] == '\\' and i + 1 < s.len: + i += 1 + val.add(s[i]) + elif s[i] == '"': + break + else: + val.add(s[i]) + i += 1 + if i >= s.len: + raise newException(PgTypeError, "hstore: unterminated value string") + i += 1 # skip closing quote + result[key] = some(val) + else: + raise newException( + PgTypeError, "hstore: expected NULL or quoted string at position " & $i + ) + proc getJson*(row: Row, col: int): JsonNode = ## Get a column value as a parsed JsonNode. Handles binary json/jsonb format. if row.isBinaryCol(col): @@ -2055,6 +2220,15 @@ proc getXml*(row: Row, col: int): PgXml = return PgXml(s) PgXml(row.getStr(col)) +proc getHstore*(row: Row, col: int): PgHstore = + ## Get a column value as PgHstore. Handles both text and binary format. + if row.isBinaryCol(col): + let (off, clen) = cellInfo(row, col) + if clen == -1: + raise newException(PgTypeError, "Column " & $col & " is NULL") + return decodeHstoreBinary(row.data.buf.toOpenArray(off, off + clen - 1)) + parseHstoreText(row.getStr(col)) + # Geometry text format parsers proc parsePointText(s: string): PgPoint = @@ -2291,6 +2465,7 @@ optAccessor(getMacAddr8, getMacAddr8Opt, PgMacAddr8) optAccessor(getTsVector, getTsVectorOpt, PgTsVector) optAccessor(getTsQuery, getTsQueryOpt, PgTsQuery) optAccessor(getXml, getXmlOpt, PgXml) +optAccessor(getHstore, getHstoreOpt, PgHstore) optAccessor(getPoint, getPointOpt, PgPoint) optAccessor(getLine, getLineOpt, PgLine) optAccessor(getLseg, getLsegOpt, PgLseg) @@ -4356,6 +4531,7 @@ nameAccessor(getMacAddr8, PgMacAddr8) nameAccessor(getTsVector, PgTsVector) nameAccessor(getTsQuery, PgTsQuery) nameAccessor(getXml, PgXml) +nameAccessor(getHstore, PgHstore) nameAccessor(getPoint, PgPoint) nameAccessor(getLine, PgLine) nameAccessor(getLseg, PgLseg) @@ -4379,6 +4555,7 @@ nameAccessor(getMacAddr8Opt, Option[PgMacAddr8]) nameAccessor(getTsVectorOpt, Option[PgTsVector]) nameAccessor(getTsQueryOpt, Option[PgTsQuery]) nameAccessor(getXmlOpt, Option[PgXml]) +nameAccessor(getHstoreOpt, Option[PgHstore]) nameAccessor(getPointOpt, Option[PgPoint]) nameAccessor(getLineOpt, Option[PgLine]) nameAccessor(getLsegOpt, Option[PgLseg]) diff --git a/tests/test_types.nim b/tests/test_types.nim index dbdea44..d0641f3 100644 --- a/tests/test_types.nim +++ b/tests/test_types.nim @@ -4287,3 +4287,145 @@ suite "xml": let fields = @[mkField(OidXml, 0'i16)] let row = mkRow(@[none(seq[byte])], fields) check row.getXmlOpt(0).isNone + +suite "hstore": + test "encodeHstoreText empty": + let h: PgHstore = initTable[string, Option[string]]() + check encodeHstoreText(h) == "" + + test "encodeHstoreText single pair": + var h: PgHstore = initTable[string, Option[string]]() + h["key"] = some("val") + check encodeHstoreText(h) == "\"key\"=>\"val\"" + + test "encodeHstoreText NULL value": + var h: PgHstore = initTable[string, Option[string]]() + h["key"] = none(string) + check encodeHstoreText(h) == "\"key\"=>NULL" + + test "encodeHstoreText escape": + var h: PgHstore = initTable[string, Option[string]]() + h["k\"ey"] = some("v\\al") + check encodeHstoreText(h) == "\"k\\\"ey\"=>\"v\\\\al\"" + + test "parseHstoreText empty": + let h = parseHstoreText("") + check h.len == 0 + + test "parseHstoreText single pair": + let h = parseHstoreText("\"key\"=>\"val\"") + check h.len == 1 + check h["key"] == some("val") + + test "parseHstoreText NULL value": + let h = parseHstoreText("\"key\"=>NULL") + check h.len == 1 + check h["key"] == none(string) + + test "parseHstoreText multiple pairs": + let h = parseHstoreText("\"a\"=>\"1\", \"b\"=>NULL, \"c\"=>\"3\"") + check h.len == 3 + check h["a"] == some("1") + check h["b"] == none(string) + check h["c"] == some("3") + + test "parseHstoreText escaped": + let h = parseHstoreText("\"k\\\"ey\"=>\"v\\\\al\"") + check h.len == 1 + check h["k\"ey"] == some("v\\al") + + test "parseHstoreText roundtrip": + var h: PgHstore = initTable[string, Option[string]]() + h["hello"] = some("world") + h["null_val"] = none(string) + let encoded = encodeHstoreText(h) + let decoded = parseHstoreText(encoded) + check decoded == h + + test "encodeHstoreBinary empty": + let h: PgHstore = initTable[string, Option[string]]() + let data = encodeHstoreBinary(h) + check data == @[byte 0, 0, 0, 0] # numPairs = 0 + + test "decodeHstoreBinary empty": + let data = @[byte 0, 0, 0, 0] + let h = decodeHstoreBinary(data) + check h.len == 0 + + test "encodeHstoreBinary and decodeHstoreBinary roundtrip": + var h: PgHstore = initTable[string, Option[string]]() + h["key"] = some("val") + h["nul"] = none(string) + let data = encodeHstoreBinary(h) + let decoded = decodeHstoreBinary(data) + check decoded == h + + test "decodeHstoreBinary single pair with NULL": + # numPairs=1, key="a" (len=1), val=NULL (len=-1) + let data = @[ + byte 0, + 0, + 0, + 1, # numPairs = 1 + 0, + 0, + 0, + 1, # keyLen = 1 + byte('a'), # key data + 0xFF, + 0xFF, + 0xFF, + 0xFF, # valLen = -1 (NULL) + ] + let h = decodeHstoreBinary(data) + check h.len == 1 + check h["a"] == none(string) + + test "toPgParam hstore": + var h: PgHstore = initTable[string, Option[string]]() + h["k"] = some("v") + let p = toPgParam(h) + check p.oid == OidText + check p.format == 0 + check p.value.isSome + check toString(p.value.get) == "\"k\"=>\"v\"" + + test "toPgBinaryParam hstore": + var h: PgHstore = initTable[string, Option[string]]() + h["k"] = some("v") + let p = toPgBinaryParam(h, 16385'i32) + check p.oid == 16385'i32 + check p.format == 1 + check p.value.isSome + let decoded = decodeHstoreBinary(p.value.get) + check decoded == h + + test "getHstore text format": + let data = toBytes("\"a\"=>\"1\", \"b\"=>NULL") + let fields = @[mkField(OidText, 0'i16)] + let row = mkRow(@[some(data)], fields) + let h = row.getHstore(0) + check h.len == 2 + check h["a"] == some("1") + check h["b"] == none(string) + + test "getHstore binary format": + var h: PgHstore = initTable[string, Option[string]]() + h["key"] = some("val") + let data = encodeHstoreBinary(h) + let fields = @[mkField(16385'i32, 1'i16)] + let row = mkRow(@[some(data)], fields) + check row.getHstore(0) == h + + test "getHstoreOpt some": + let data = toBytes("\"a\"=>\"1\"") + let fields = @[mkField(OidText, 0'i16)] + let row = mkRow(@[some(data)], fields) + let r = row.getHstoreOpt(0) + check r.isSome + check r.get["a"] == some("1") + + test "getHstoreOpt none": + let fields = @[mkField(OidText, 0'i16)] + let row = mkRow(@[none(seq[byte])], fields) + check row.getHstoreOpt(0).isNone From 5d2e4ca8894f14d051a571f7d0433a105c92eee3 Mon Sep 17 00:00:00 2001 From: fox0430 Date: Wed, 1 Apr 2026 16:52:46 +0900 Subject: [PATCH 2/2] fix tests --- tests/test_ssl.nim | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/test_ssl.nim b/tests/test_ssl.nim index c8f6588..795f3f5 100644 --- a/tests/test_ssl.nim +++ b/tests/test_ssl.nim @@ -96,9 +96,35 @@ proc sendAuthOkAndReady(client: MockClient): Future[void] {.async.} = resp.add(buildBackendMsg('Z', @[byte('I')])) await sendBytes(client, resp) +proc sendEmptyQueryResult(client: MockClient): Future[void] {.async.} = + ## Respond to the hstore OID discovery query with an empty result set. + var resp: seq[byte] + # CommandComplete: "SELECT 0" + var ccBody: seq[byte] = @[] + for c in "SELECT 0": + ccBody.add(byte(c)) + ccBody.add(0'u8) + resp.add(buildBackendMsg('C', ccBody)) + # ReadyForQuery + resp.add(buildBackendMsg('Z', @[byte('I')])) + await sendBytes(client, resp) + +proc drainFrontendMessage(client: MockClient): Future[void] {.async.} = + ## Read a frontend message (type byte + int32 length + body). + discard await readN(client, 1) # message type + let lenBuf = await readN(client, 4) + let msgLen = decodeInt32(lenBuf, 0) + if msgLen > 4: + discard await readN(client, msgLen - 4) + proc drainUntilClose(client: MockClient): Future[void] {.async.} = try: - discard await readN(client, 64) + # Drain the hstore discovery query ('Q' message) + await drainFrontendMessage(client) + # Send back an empty result + ReadyForQuery + await sendEmptyQueryResult(client) + # Drain the Terminate message + await drainFrontendMessage(client) except CatchableError: discard