Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ type
stmtCounter*: int
stmtCacheCapacity*: int ## 0=disabled, default 256
rowDataBuf*: RowData ## Reusable RowData buffer to avoid per-query allocation
hstoreOid*: int32 ## Dynamic OID for hstore extension type; 0 if not available

QueryResult* = object
## Result of a query: field descriptions, row data, and command tag.
Expand Down Expand Up @@ -386,7 +387,10 @@ proc addStmtCache*(conn: PgConnection, sql: string, cached: CachedStmt) =
return # caller should have evicted; skip if still full
var entry = cached
if entry.resultFormats.len == 0 and entry.fields.len > 0:
entry.resultFormats = buildResultFormats(entry.fields)
if conn.hstoreOid != 0:
entry.resultFormats = buildResultFormats(entry.fields, [conn.hstoreOid])
else:
entry.resultFormats = buildResultFormats(entry.fields)
entry.colFmts = newSeq[int16](entry.fields.len)
entry.colOids = newSeq[int32](entry.fields.len)
for i in 0 ..< entry.fields.len:
Expand Down Expand Up @@ -863,6 +867,11 @@ when defined(posix):
"TCP keepalive timing options (idle/interval/count) are not supported on this platform and will be ignored"
.}

proc bytesToString(data: seq[byte]): string =
result = newString(data.len)
for i in 0 ..< data.len:
result[i] = char(data[i])

proc connectToHost(
config: ConnConfig, hostAddr: string, hostPort: int
): Future[PgConnection] {.async.} =
Expand Down Expand Up @@ -993,6 +1002,36 @@ proc connectToHost(
discard
await conn.fillRecvBuf()

# Discover extension type OIDs (hstore, etc.)
conn.state = csBusy
await conn.sendMsg(
encodeQuery("SELECT oid FROM pg_type WHERE typname = 'hstore' LIMIT 1")
)
block discoverLoop:
while true:
while (let opt = conn.nextMessage(); opt.isSome):
let msg = opt.get
case msg.kind
of bmkRowDescription:
discard
of bmkDataRow:
if msg.columns.len > 0 and msg.columns[0].isSome:
try:
conn.hstoreOid = int32(parseInt(bytesToString(msg.columns[0].get)))
except ValueError:
discard
of bmkCommandComplete, bmkEmptyQueryResponse:
discard
of bmkReadyForQuery:
conn.txStatus = msg.txStatus
conn.state = csReady
break discoverLoop
of bmkErrorResponse:
discard
else:
discard
await conn.fillRecvBuf()

conn.createdAt = Moment.now()
return conn
except CatchableError as e:
Expand Down Expand Up @@ -1201,11 +1240,6 @@ proc close*(conn: PgConnection): Future[void] {.async.} =
conn.notifyWaiter.fail(newException(PgError, "Connection closed"))
await conn.closeTransport()

proc bytesToString(data: seq[byte]): string =
result = newString(data.len)
for i in 0 ..< data.len:
result[i] = char(data[i])

proc checkSessionAttrs(
conn: PgConnection, attrs: TargetSessionAttrs
): Future[bool] {.async.} =
Expand Down Expand Up @@ -1320,6 +1354,7 @@ proc reconnectInPlace(conn: PgConnection) {.async.} =
conn.txStatus = newConn.txStatus
conn.state = csReady
conn.createdAt = newConn.createdAt
conn.hstoreOid = newConn.hstoreOid
for ch in conn.listenChannels:
discard await conn.simpleQuery("LISTEN " & quoteIdentifier(ch))

Expand Down
15 changes: 15 additions & 0 deletions async_postgres/pg_protocol.nim
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,21 @@ proc buildResultFormats*(fields: openArray[FieldDescription]): seq[int16] =
for i, f in fields:
result[i] = if isBinarySafeOid(f.typeOid): 1'i16 else: 0'i16

proc buildResultFormats*(
fields: openArray[FieldDescription], extraBinaryOids: openArray[int32]
): seq[int16] =
## Build per-column binary format codes with additional runtime-safe OIDs.
result = newSeq[int16](fields.len)
for i, f in fields:
if isBinarySafeOid(f.typeOid):
result[i] = 1'i16
else:
result[i] = 0'i16
for oid in extraBinaryOids:
if f.typeOid == oid:
result[i] = 1'i16
break

proc parseDataRowInto*(body: openArray[byte], rd: RowData) =
## Parse a DataRow message body directly into a RowData flat buffer.
## Column data is appended to rd.buf and (offset, length) pairs to rd.cellIndex.
Expand Down
177 changes: 177 additions & 0 deletions async_postgres/pg_types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ type
PgMultirange*[T] = distinct seq[PgRange[T]]
## PostgreSQL multirange value (PostgreSQL 14+). A sorted set of non-overlapping ranges.

PgHstore* = Table[string, Option[string]]
## PostgreSQL hstore type: a set of key/value pairs where values may be NULL.

PgParam* = object
## A single query parameter in binary wire format, ready to send to PostgreSQL.
oid*: int32
Expand Down Expand Up @@ -692,6 +695,34 @@ proc toPgParam*(v: PgCircle): PgParam =
proc toPgParam*(v: JsonNode): PgParam =
PgParam(oid: OidJsonb, format: 0, value: some(toBytes($v)))

proc encodeHstoreText*(v: PgHstore): string =
## Encode hstore as PostgreSQL text format: ``"key1"=>"val1", "key2"=>NULL``.
var parts: seq[string]
for k, v in v.pairs:
var keyEsc = newStringOfCap(k.len + 2)
keyEsc.add('"')
for c in k:
if c == '"' or c == '\\':
keyEsc.add('\\')
keyEsc.add(c)
keyEsc.add('"')
if v.isSome:
var valEsc = newStringOfCap(v.get.len + 2)
valEsc.add('"')
for c in v.get:
if c == '"' or c == '\\':
valEsc.add('\\')
valEsc.add(c)
valEsc.add('"')
parts.add(keyEsc & "=>" & valEsc)
else:
parts.add(keyEsc & "=>NULL")
parts.join(", ")

proc toPgParam*(v: PgHstore): PgParam =
## Send hstore as text format. PostgreSQL casts text to hstore implicitly.
PgParam(oid: OidText, format: 0, value: some(toBytes(encodeHstoreText(v))))

proc encodeBinaryArray*(elemOid: int32, elements: seq[seq[byte]]): seq[byte] =
## Encode a 1-dimensional PostgreSQL binary array.
## Header: ndim(4) + has_null(4) + elem_oid(4) + dim_len(4) + lower_bound(4) = 20 bytes
Expand Down Expand Up @@ -1089,6 +1120,75 @@ proc toPgBinaryParam*(v: Option[JsonNode]): PgParam =
else:
PgParam(oid: OidJsonb, format: 1, value: none(seq[byte]))

proc encodeHstoreBinary*(v: PgHstore): seq[byte] =
## Encode hstore as PostgreSQL binary format.
## Format: numPairs(int32) + [keyLen(int32) + keyData + valLen(int32) + valData]...
var size = 4
for k, val in v.pairs:
size += 4 + k.len + 4
if val.isSome:
size += val.get.len
result = newSeq[byte](size)
let np = toBE32(int32(v.len))
copyMem(addr result[0], unsafeAddr np[0], 4)
var pos = 4
for k, val in v.pairs:
let kLen = toBE32(int32(k.len))
copyMem(addr result[pos], unsafeAddr kLen[0], 4)
pos += 4
if k.len > 0:
copyMem(addr result[pos], unsafeAddr k[0], k.len)
pos += k.len
if val.isSome:
let vLen = toBE32(int32(val.get.len))
copyMem(addr result[pos], unsafeAddr vLen[0], 4)
pos += 4
if val.get.len > 0:
copyMem(addr result[pos], unsafeAddr val.get[0], val.get.len)
pos += val.get.len
else:
let nullLen = toBE32(-1'i32)
copyMem(addr result[pos], unsafeAddr nullLen[0], 4)
pos += 4

proc decodeHstoreBinary*(data: openArray[byte]): PgHstore =
## Decode PostgreSQL binary hstore format.
result = initTable[string, Option[string]]()
if data.len < 4:
raise newException(PgTypeError, "hstore binary data too short")
let numPairs = int(fromBE32(data.toOpenArray(0, 3)))
var pos = 4
for _ in 0 ..< numPairs:
if pos + 4 > data.len:
raise newException(PgTypeError, "hstore binary: truncated key length")
let keyLen = int(fromBE32(data.toOpenArray(pos, pos + 3)))
pos += 4
if keyLen < 0 or pos + keyLen > data.len:
raise newException(PgTypeError, "hstore binary: truncated key data")
var key = newString(keyLen)
if keyLen > 0:
copyMem(addr key[0], unsafeAddr data[pos], keyLen)
pos += keyLen
if pos + 4 > data.len:
raise newException(PgTypeError, "hstore binary: truncated value length")
let valLen = int(fromBE32(data.toOpenArray(pos, pos + 3)))
pos += 4
if valLen == -1:
result[key] = none(string)
else:
if valLen < 0 or pos + valLen > data.len:
raise newException(PgTypeError, "hstore binary: truncated value data")
var val = newString(valLen)
if valLen > 0:
copyMem(addr val[0], unsafeAddr data[pos], valLen)
pos += valLen
result[key] = some(val)

proc toPgBinaryParam*(v: PgHstore, oid: int32): PgParam =
## Encode hstore in binary format. Requires the dynamic hstore OID
## (available as ``conn.hstoreOid`` after connection).
PgParam(oid: oid, format: 1, value: some(encodeHstoreBinary(v)))

proc toPgBinaryParam*[T](v: Option[T]): PgParam =
if v.isSome:
result = toPgBinaryParam(v.get)
Expand Down Expand Up @@ -1704,6 +1804,71 @@ proc getDate*(row: Row, col: int): DateTime =
except TimeParseError:
raise newException(PgTypeError, "Invalid date: " & s)

proc parseHstoreText*(s: string): PgHstore =
## Parse PostgreSQL hstore text format: ``"key1"=>"val1", "key2"=>NULL``.
result = initTable[string, Option[string]]()
if s.len == 0:
return
var i = 0
while i < s.len:
# Skip whitespace and commas
while i < s.len and s[i] in {' ', ',', '\t', '\n', '\r'}:
i += 1
if i >= s.len:
break
# Parse key (must be quoted)
if s[i] != '"':
raise newException(PgTypeError, "hstore: expected '\"' at position " & $i)
i += 1
var key = ""
while i < s.len:
if s[i] == '\\' and i + 1 < s.len:
i += 1
key.add(s[i])
elif s[i] == '"':
break
else:
key.add(s[i])
i += 1
if i >= s.len:
raise newException(PgTypeError, "hstore: unterminated key string")
i += 1 # skip closing quote
# Skip whitespace
while i < s.len and s[i] == ' ':
i += 1
# Expect =>
if i + 1 >= s.len or s[i] != '=' or s[i + 1] != '>':
raise newException(PgTypeError, "hstore: expected '=>' at position " & $i)
i += 2
# Skip whitespace
while i < s.len and s[i] == ' ':
i += 1
# Parse value (NULL or quoted string)
if i + 3 < s.len and s[i] == 'N' and s[i + 1] == 'U' and s[i + 2] == 'L' and
s[i + 3] == 'L' and (i + 4 >= s.len or s[i + 4] in {',', ' ', '\t', '\n', '\r'}):
result[key] = none(string)
i += 4
elif i < s.len and s[i] == '"':
i += 1
var val = ""
while i < s.len:
if s[i] == '\\' and i + 1 < s.len:
i += 1
val.add(s[i])
elif s[i] == '"':
break
else:
val.add(s[i])
i += 1
if i >= s.len:
raise newException(PgTypeError, "hstore: unterminated value string")
i += 1 # skip closing quote
result[key] = some(val)
else:
raise newException(
PgTypeError, "hstore: expected NULL or quoted string at position " & $i
)

proc getJson*(row: Row, col: int): JsonNode =
## Get a column value as a parsed JsonNode. Handles binary json/jsonb format.
if row.isBinaryCol(col):
Expand Down Expand Up @@ -2055,6 +2220,15 @@ proc getXml*(row: Row, col: int): PgXml =
return PgXml(s)
PgXml(row.getStr(col))

proc getHstore*(row: Row, col: int): PgHstore =
## Get a column value as PgHstore. Handles both text and binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
raise newException(PgTypeError, "Column " & $col & " is NULL")
return decodeHstoreBinary(row.data.buf.toOpenArray(off, off + clen - 1))
parseHstoreText(row.getStr(col))

# Geometry text format parsers

proc parsePointText(s: string): PgPoint =
Expand Down Expand Up @@ -2291,6 +2465,7 @@ optAccessor(getMacAddr8, getMacAddr8Opt, PgMacAddr8)
optAccessor(getTsVector, getTsVectorOpt, PgTsVector)
optAccessor(getTsQuery, getTsQueryOpt, PgTsQuery)
optAccessor(getXml, getXmlOpt, PgXml)
optAccessor(getHstore, getHstoreOpt, PgHstore)
optAccessor(getPoint, getPointOpt, PgPoint)
optAccessor(getLine, getLineOpt, PgLine)
optAccessor(getLseg, getLsegOpt, PgLseg)
Expand Down Expand Up @@ -4356,6 +4531,7 @@ nameAccessor(getMacAddr8, PgMacAddr8)
nameAccessor(getTsVector, PgTsVector)
nameAccessor(getTsQuery, PgTsQuery)
nameAccessor(getXml, PgXml)
nameAccessor(getHstore, PgHstore)
nameAccessor(getPoint, PgPoint)
nameAccessor(getLine, PgLine)
nameAccessor(getLseg, PgLseg)
Expand All @@ -4379,6 +4555,7 @@ nameAccessor(getMacAddr8Opt, Option[PgMacAddr8])
nameAccessor(getTsVectorOpt, Option[PgTsVector])
nameAccessor(getTsQueryOpt, Option[PgTsQuery])
nameAccessor(getXmlOpt, Option[PgXml])
nameAccessor(getHstoreOpt, Option[PgHstore])
nameAccessor(getPointOpt, Option[PgPoint])
nameAccessor(getLineOpt, Option[PgLine])
nameAccessor(getLsegOpt, Option[PgLseg])
Expand Down
28 changes: 27 additions & 1 deletion tests/test_ssl.nim
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,35 @@ proc sendAuthOkAndReady(client: MockClient): Future[void] {.async.} =
resp.add(buildBackendMsg('Z', @[byte('I')]))
await sendBytes(client, resp)

proc sendEmptyQueryResult(client: MockClient): Future[void] {.async.} =
## Respond to the hstore OID discovery query with an empty result set.
var resp: seq[byte]
# CommandComplete: "SELECT 0"
var ccBody: seq[byte] = @[]
for c in "SELECT 0":
ccBody.add(byte(c))
ccBody.add(0'u8)
resp.add(buildBackendMsg('C', ccBody))
# ReadyForQuery
resp.add(buildBackendMsg('Z', @[byte('I')]))
await sendBytes(client, resp)

proc drainFrontendMessage(client: MockClient): Future[void] {.async.} =
## Read a frontend message (type byte + int32 length + body).
discard await readN(client, 1) # message type
let lenBuf = await readN(client, 4)
let msgLen = decodeInt32(lenBuf, 0)
if msgLen > 4:
discard await readN(client, msgLen - 4)

proc drainUntilClose(client: MockClient): Future[void] {.async.} =
try:
discard await readN(client, 64)
# Drain the hstore discovery query ('Q' message)
await drainFrontendMessage(client)
# Send back an empty result + ReadyForQuery
await sendEmptyQueryResult(client)
# Drain the Terminate message
await drainFrontendMessage(client)
except CatchableError:
discard

Expand Down
Loading
Loading