Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ type
stmtCounter: int
stmtCacheCapacity: int ## 0=disabled, default 256
hstoreOid: int32 ## Dynamic OID for hstore extension type; 0 if not available
hstoreArrayOid: int32 ## Dynamic OID for hstore[] array; 0 if not available
tracer: PgTracer ## Inherited from ConnConfig on connect

QueryResult* = object
Expand Down Expand Up @@ -337,6 +338,27 @@ func hstoreOid*(conn: PgConnection): int32 {.inline.} =
## Dynamic OID for hstore extension type; 0 if not available.
conn.hstoreOid

func hstoreArrayOid*(conn: PgConnection): int32 {.inline.} =
## Dynamic OID for hstore[] array type; 0 if not available.
conn.hstoreArrayOid

proc toPgBinaryParam*(conn: PgConnection, v: PgHstore): PgParam {.inline.} =
## Convenience overload: encode hstore in binary using ``conn.hstoreOid``.
## Raises ``PgTypeError`` if the hstore extension OID has not been discovered
## (e.g. extension not installed on the server).
if conn.hstoreOid == 0:
raise newException(PgTypeError, "hstore OID not available on this connection")
toPgBinaryParam(v, conn.hstoreOid)

proc toPgBinaryParam*(conn: PgConnection, v: seq[PgHstore]): PgParam {.inline.} =
## Convenience overload: encode hstore[] in binary using ``conn.hstoreOid``
## and ``conn.hstoreArrayOid``. Raises ``PgTypeError`` if either OID has not
## been discovered.
if conn.hstoreOid == 0 or conn.hstoreArrayOid == 0:
raise
newException(PgTypeError, "hstore/hstore[] OIDs not available on this connection")
toPgBinaryParam(v, conn.hstoreOid, conn.hstoreArrayOid)

func notifyCallback*(conn: PgConnection): NotifyCallback {.inline.} =
## The callback invoked when a NOTIFY message arrives.
conn.notifyCallback
Expand Down Expand Up @@ -679,10 +701,12 @@ proc addStmtCache*(conn: PgConnection, sql: string, cached: CachedStmt) =
return # caller should have evicted; skip if still full
var entry = cached
if entry.resultFormats.len == 0 and entry.fields.len > 0:
var extraOids: seq[int32]
if conn.hstoreOid != 0:
entry.resultFormats = buildResultFormats(entry.fields, [conn.hstoreOid])
else:
entry.resultFormats = buildResultFormats(entry.fields)
extraOids.add(conn.hstoreOid)
if conn.hstoreArrayOid != 0:
extraOids.add(conn.hstoreArrayOid)
entry.resultFormats = buildResultFormats(entry.fields, extraOids)
entry.colFmts = newSeq[int16](entry.fields.len)
entry.colOids = newSeq[int32](entry.fields.len)
for i in 0 ..< entry.fields.len:
Expand Down Expand Up @@ -1277,7 +1301,7 @@ proc connectToHost(
# Discover extension type OIDs (hstore, etc.)
conn.state = csBusy
await conn.sendMsg(
encodeQuery("SELECT oid FROM pg_type WHERE typname = 'hstore' LIMIT 1")
encodeQuery("SELECT oid, typarray FROM pg_type WHERE typname = 'hstore' LIMIT 1")
)
block discoverLoop:
while true:
Expand All @@ -1292,6 +1316,11 @@ proc connectToHost(
conn.hstoreOid = int32(parseInt(bytesToString(msg.columns[0].get)))
except ValueError:
discard
if msg.columns.len > 1 and msg.columns[1].isSome:
try:
conn.hstoreArrayOid = int32(parseInt(bytesToString(msg.columns[1].get)))
except ValueError:
discard
of bmkCommandComplete, bmkEmptyQueryResponse:
discard
of bmkReadyForQuery:
Expand Down Expand Up @@ -1697,6 +1726,7 @@ proc reconnectInPlace(conn: PgConnection) {.async.} =
conn.state = csReady
conn.createdAt = newConn.createdAt
conn.hstoreOid = newConn.hstoreOid
conn.hstoreArrayOid = newConn.hstoreArrayOid
for ch in conn.listenChannels:
discard await conn.simpleQuery("LISTEN " & quoteIdentifier(ch))

Expand Down
76 changes: 63 additions & 13 deletions async_postgres/pg_types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1568,19 +1568,7 @@ genArrayEncoder(PgMoney, OidMoneyArray, OidMoney)

# Numeric / binary / JSON array encoders

proc toPgParam*(v: seq[PgNumeric]): PgParam =
if v.len == 0:
return PgParam(
oid: OidNumericArray, format: 1, value: some(encodeBinaryArrayEmpty(OidNumeric))
)
var elements = newSeq[seq[byte]](v.len)
for i, x in v:
elements[i] = encodeNumericBinary(x)
PgParam(
oid: OidNumericArray,
format: 1,
value: some(encodeBinaryArray(OidNumeric, elements)),
)
genArrayEncoder(PgNumeric, OidNumericArray, OidNumeric)

proc toPgByteaArrayParam*(v: seq[seq[byte]]): PgParam =
if v.len == 0:
Expand Down Expand Up @@ -1789,6 +1777,40 @@ proc toPgBinaryParam*(v: PgHstore, oid: int32): PgParam =
## (available as ``conn.hstoreOid`` after connection).
PgParam(oid: oid, format: 1, value: some(encodeHstoreBinary(v)))

proc toPgParam*(v: seq[PgHstore]): PgParam =
## Send hstore[] in text format using ``OidTextArray``. Requires an explicit
## ``::hstore[]`` cast in the SQL statement (e.g. ``SELECT $1::hstore[]``),
## since the parameter is typed as text[]. No connection-specific OID is
## needed; prefer ``toPgBinaryParam`` when a ``PgConnection`` with the
## discovered hstore OIDs is available (faster, no cast required).
if v.len == 0:
return PgParam(oid: OidTextArray, format: 0, value: some(toBytes("{}")))
var s = "{"
for i, h in v:
if i > 0:
s.add(',')
s.add('"')
for c in encodeHstoreText(h):
if c == '"' or c == '\\':
s.add('\\')
s.add(c)
s.add('"')
s.add('}')
PgParam(oid: OidTextArray, format: 0, value: some(toBytes(s)))

proc toPgBinaryParam*(v: seq[PgHstore], elemOid: int32, arrayOid: int32): PgParam =
## Encode hstore[] in binary format. Requires both the dynamic hstore OID
## and hstore[] OID (available as ``conn.hstoreOid`` and
## ``conn.hstoreArrayOid`` after connection). See also the ``PgConnection``
## overload in ``pg_connection`` which reads these OIDs automatically.
if v.len == 0:
return
PgParam(oid: arrayOid, format: 1, value: some(encodeBinaryArrayEmpty(elemOid)))
var elements = newSeq[seq[byte]](v.len)
for i, x in v:
elements[i] = encodeHstoreBinary(x)
PgParam(oid: arrayOid, format: 1, value: some(encodeBinaryArray(elemOid, elements)))

proc toPgBinaryParam*[T](v: Option[T]): PgParam =
if v.isSome:
result = toPgBinaryParam(v.get)
Expand Down Expand Up @@ -4099,6 +4121,28 @@ genStringArrayDecoder(getXmlArray, PgXml, "xml")
genStringArrayDecoder(getTsVectorArray, PgTsVector, "tsvector")
genStringArrayDecoder(getTsQueryArray, PgTsQuery, "tsquery")

proc getHstoreArray*(row: Row, col: int): seq[PgHstore] =
## Get a column value as ``seq[PgHstore]``. Handles both text and binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
raise newException(PgTypeError, "Column " & $col & " is NULL")
let decoded = decodeBinaryArray(row.data.buf.toOpenArray(off, off + clen - 1))
result = newSeq[PgHstore](decoded.elements.len)
for i, e in decoded.elements:
if e.len == -1:
raise newException(PgTypeError, "NULL element in hstore array")
result[i] = decodeHstoreBinary(
row.data.buf.toOpenArray(off + e.off, off + e.off + e.len - 1)
)
return
let s = row.getStr(col)
let elems = parseTextArray(s)
for e in elems:
if e.isNone:
raise newException(PgTypeError, "NULL element in hstore array")
result.add(parseHstoreText(e.get))

# Array Opt accessors (text format)

optAccessor(getIntArray, getIntArrayOpt, seq[int32])
Expand Down Expand Up @@ -4134,6 +4178,7 @@ optAccessor(getCircleArray, getCircleArrayOpt, seq[PgCircle])
optAccessor(getXmlArray, getXmlArrayOpt, seq[PgXml])
optAccessor(getTsVectorArray, getTsVectorArrayOpt, seq[PgTsVector])
optAccessor(getTsQueryArray, getTsQueryArrayOpt, seq[PgTsQuery])
optAccessor(getHstoreArray, getHstoreArrayOpt, seq[PgHstore])

# Coerce a binary PgParam to match the server-inferred type from a prepared
# statement. This handles the common case where e.g. int32.toPgParam is
Expand Down Expand Up @@ -6293,6 +6338,9 @@ proc get*(row: Row, col: int, T: typedesc[seq[PgTsVector]]): seq[PgTsVector] =
proc get*(row: Row, col: int, T: typedesc[seq[PgTsQuery]]): seq[PgTsQuery] =
row.getTsQueryArray(col)

proc get*(row: Row, col: int, T: typedesc[seq[PgHstore]]): seq[PgHstore] =
row.getHstoreArray(col)

# Range types (DateTime-based ranges excluded — see note above)

proc get*(row: Row, col: int, T: typedesc[PgRange[int32]]): PgRange[int32] =
Expand Down Expand Up @@ -6464,6 +6512,7 @@ nameAccessor(getCircleArray, seq[PgCircle])
nameAccessor(getXmlArray, seq[PgXml])
nameAccessor(getTsVectorArray, seq[PgTsVector])
nameAccessor(getTsQueryArray, seq[PgTsQuery])
nameAccessor(getHstoreArray, seq[PgHstore])
nameAccessor(getTimestampArrayOpt, Option[seq[DateTime]])
nameAccessor(getTimestampTzArrayOpt, Option[seq[DateTime]])
nameAccessor(getDateArrayOpt, Option[seq[DateTime]])
Expand All @@ -6489,6 +6538,7 @@ nameAccessor(getCircleArrayOpt, Option[seq[PgCircle]])
nameAccessor(getXmlArrayOpt, Option[seq[PgXml]])
nameAccessor(getTsVectorArrayOpt, Option[seq[PgTsVector]])
nameAccessor(getTsQueryArrayOpt, Option[seq[PgTsQuery]])
nameAccessor(getHstoreArrayOpt, Option[seq[PgHstore]])
nameAccessor(getInt4Range, PgRange[int32])
nameAccessor(getInt8Range, PgRange[int64])
nameAccessor(getNumRange, PgRange[PgNumeric])
Expand Down
54 changes: 54 additions & 0 deletions tests/test_e2e.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8050,6 +8050,60 @@ suite "E2E: Numeric / binary / JSON array types":

waitFor t()

test "hstore array roundtrip (text cast)":
proc t() {.async.} =
let conn = await connect(plainConfig())
discard await conn.simpleQuery("CREATE EXTENSION IF NOT EXISTS hstore;")
var h1: PgHstore = initTable[string, Option[string]]()
h1["a"] = some("1")
var h2: PgHstore = initTable[string, Option[string]]()
h2["b"] = none(string)
let res = await conn.query("SELECT $1::hstore[]", @[toPgParam(@[h1, h2])])
doAssert res.rows.len == 1
let arr = res.rows[0].getHstoreArray(0)
doAssert arr.len == 2
doAssert arr[0] == h1
doAssert arr[1] == h2
await conn.close()

waitFor t()

test "hstore array roundtrip (binary)":
proc t() {.async.} =
let conn = await connect(plainConfig())
discard await conn.simpleQuery("CREATE EXTENSION IF NOT EXISTS hstore;")
doAssert conn.hstoreOid != 0
doAssert conn.hstoreArrayOid != 0
var h1: PgHstore = initTable[string, Option[string]]()
h1["x"] = some("y")
var h2: PgHstore = initTable[string, Option[string]]()
h2["nul"] = none(string)
let bin = toPgBinaryParam(@[h1, h2], conn.hstoreOid, conn.hstoreArrayOid)
let res = await conn.query("SELECT $1", @[bin])
doAssert res.rows.len == 1
let arr = res.rows[0].getHstoreArray(0)
doAssert arr.len == 2
doAssert arr[0] == h1
doAssert arr[1] == h2
await conn.close()

waitFor t()

test "hstore array roundtrip (binary, conn overload)":
proc t() {.async.} =
let conn = await connect(plainConfig())
discard await conn.simpleQuery("CREATE EXTENSION IF NOT EXISTS hstore;")
var h1: PgHstore = initTable[string, Option[string]]()
h1["k"] = some("v")
let res = await conn.query("SELECT $1", @[conn.toPgBinaryParam(@[h1])])
doAssert res.rows.len == 1
let arr = res.rows[0].getHstoreArray(0)
doAssert arr.len == 1
doAssert arr[0] == h1
await conn.close()

waitFor t()

test "bytea array roundtrip":
proc t() {.async.} =
let conn = await connect(plainConfig())
Expand Down
67 changes: 67 additions & 0 deletions tests/test_types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -5011,6 +5011,73 @@ suite "hstore":
let row = mkRow(@[none(seq[byte])], fields)
check row.getHstoreOpt(0).isNone

test "toPgParam seq[PgHstore] text roundtrip":
var h1: PgHstore = initTable[string, Option[string]]()
h1["a"] = some("1")
var h2: PgHstore = initTable[string, Option[string]]()
h2["b"] = none(string)
let p = toPgParam(@[h1, h2])
check p.oid == OidTextArray
check p.format == 0
let fields = @[mkField(OidTextArray, 0'i16)]
let row = mkRow(@[p.value], fields)
let arr = row.getHstoreArray(0)
check arr.len == 2
check arr[0] == h1
check arr[1] == h2

test "toPgParam seq[PgHstore] empty":
let p = toPgParam(newSeq[PgHstore]())
check p.oid == OidTextArray
check p.format == 0
check toString(p.value.get) == "{}"

test "toPgBinaryParam seq[PgHstore] roundtrip":
var h1: PgHstore = initTable[string, Option[string]]()
h1["x"] = some("y")
var h2: PgHstore = initTable[string, Option[string]]()
h2["nul"] = none(string)
let p = toPgBinaryParam(@[h1, h2], 16385'i32, 16386'i32)
check p.oid == 16386'i32
check p.format == 1
let fields = @[mkField(16386'i32, 1'i16)]
let row = mkRow(@[p.value], fields)
let arr = row.getHstoreArray(0)
check arr.len == 2
check arr[0] == h1
check arr[1] == h2

test "toPgBinaryParam seq[PgHstore] empty":
let p = toPgBinaryParam(newSeq[PgHstore](), 16385'i32, 16386'i32)
check p.oid == 16386'i32
check p.format == 1
let fields = @[mkField(16386'i32, 1'i16)]
let row = mkRow(@[p.value], fields)
check row.getHstoreArray(0).len == 0

test "getHstoreArray text format":
let row: Row = @[some(toBytes("{\"\\\"a\\\"=>\\\"1\\\"\",\"\\\"b\\\"=>NULL\"}"))]
let arr = row.getHstoreArray(0)
check arr.len == 2
check arr[0]["a"] == some("1")
check arr[1]["b"] == none(string)

test "getHstoreArrayOpt some":
var h1: PgHstore = initTable[string, Option[string]]()
h1["k"] = some("v")
let p = toPgParam(@[h1])
let fields = @[mkField(OidTextArray, 0'i16)]
let row = mkRow(@[p.value], fields)
let r = row.getHstoreArrayOpt(0)
check r.isSome
check r.get.len == 1
check r.get[0] == h1

test "getHstoreArrayOpt none":
let fields = @[mkField(OidTextArray, 0'i16)]
let row = mkRow(@[none(seq[byte])], fields)
check row.getHstoreArrayOpt(0).isNone

suite "PgBit":
test "OID constants":
check OidBit == 1560'i32
Expand Down
Loading