diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index 3029edf..1e5617a 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -168,6 +168,7 @@ type stmtCounter: int stmtCacheCapacity: int ## 0=disabled, default 256 hstoreOid: int32 ## Dynamic OID for hstore extension type; 0 if not available + hstoreArrayOid: int32 ## Dynamic OID for hstore[] array; 0 if not available tracer: PgTracer ## Inherited from ConnConfig on connect QueryResult* = object @@ -337,6 +338,27 @@ func hstoreOid*(conn: PgConnection): int32 {.inline.} = ## Dynamic OID for hstore extension type; 0 if not available. conn.hstoreOid +func hstoreArrayOid*(conn: PgConnection): int32 {.inline.} = + ## Dynamic OID for hstore[] array type; 0 if not available. + conn.hstoreArrayOid + +proc toPgBinaryParam*(conn: PgConnection, v: PgHstore): PgParam {.inline.} = + ## Convenience overload: encode hstore in binary using ``conn.hstoreOid``. + ## Raises ``PgTypeError`` if the hstore extension OID has not been discovered + ## (e.g. extension not installed on the server). + if conn.hstoreOid == 0: + raise newException(PgTypeError, "hstore OID not available on this connection") + toPgBinaryParam(v, conn.hstoreOid) + +proc toPgBinaryParam*(conn: PgConnection, v: seq[PgHstore]): PgParam {.inline.} = + ## Convenience overload: encode hstore[] in binary using ``conn.hstoreOid`` + ## and ``conn.hstoreArrayOid``. Raises ``PgTypeError`` if either OID has not + ## been discovered. + if conn.hstoreOid == 0 or conn.hstoreArrayOid == 0: + raise + newException(PgTypeError, "hstore/hstore[] OIDs not available on this connection") + toPgBinaryParam(v, conn.hstoreOid, conn.hstoreArrayOid) + func notifyCallback*(conn: PgConnection): NotifyCallback {.inline.} = ## The callback invoked when a NOTIFY message arrives. conn.notifyCallback @@ -679,10 +701,12 @@ proc addStmtCache*(conn: PgConnection, sql: string, cached: CachedStmt) = return # caller should have evicted; skip if still full var entry = cached if entry.resultFormats.len == 0 and entry.fields.len > 0: + var extraOids: seq[int32] if conn.hstoreOid != 0: - entry.resultFormats = buildResultFormats(entry.fields, [conn.hstoreOid]) - else: - entry.resultFormats = buildResultFormats(entry.fields) + extraOids.add(conn.hstoreOid) + if conn.hstoreArrayOid != 0: + extraOids.add(conn.hstoreArrayOid) + entry.resultFormats = buildResultFormats(entry.fields, extraOids) entry.colFmts = newSeq[int16](entry.fields.len) entry.colOids = newSeq[int32](entry.fields.len) for i in 0 ..< entry.fields.len: @@ -1277,7 +1301,7 @@ proc connectToHost( # Discover extension type OIDs (hstore, etc.) conn.state = csBusy await conn.sendMsg( - encodeQuery("SELECT oid FROM pg_type WHERE typname = 'hstore' LIMIT 1") + encodeQuery("SELECT oid, typarray FROM pg_type WHERE typname = 'hstore' LIMIT 1") ) block discoverLoop: while true: @@ -1292,6 +1316,11 @@ proc connectToHost( conn.hstoreOid = int32(parseInt(bytesToString(msg.columns[0].get))) except ValueError: discard + if msg.columns.len > 1 and msg.columns[1].isSome: + try: + conn.hstoreArrayOid = int32(parseInt(bytesToString(msg.columns[1].get))) + except ValueError: + discard of bmkCommandComplete, bmkEmptyQueryResponse: discard of bmkReadyForQuery: @@ -1697,6 +1726,7 @@ proc reconnectInPlace(conn: PgConnection) {.async.} = conn.state = csReady conn.createdAt = newConn.createdAt conn.hstoreOid = newConn.hstoreOid + conn.hstoreArrayOid = newConn.hstoreArrayOid for ch in conn.listenChannels: discard await conn.simpleQuery("LISTEN " & quoteIdentifier(ch)) diff --git a/async_postgres/pg_types.nim b/async_postgres/pg_types.nim index a38b7b5..82c508d 100644 --- a/async_postgres/pg_types.nim +++ b/async_postgres/pg_types.nim @@ -1568,19 +1568,7 @@ genArrayEncoder(PgMoney, OidMoneyArray, OidMoney) # Numeric / binary / JSON array encoders -proc toPgParam*(v: seq[PgNumeric]): PgParam = - if v.len == 0: - return PgParam( - oid: OidNumericArray, format: 1, value: some(encodeBinaryArrayEmpty(OidNumeric)) - ) - var elements = newSeq[seq[byte]](v.len) - for i, x in v: - elements[i] = encodeNumericBinary(x) - PgParam( - oid: OidNumericArray, - format: 1, - value: some(encodeBinaryArray(OidNumeric, elements)), - ) +genArrayEncoder(PgNumeric, OidNumericArray, OidNumeric) proc toPgByteaArrayParam*(v: seq[seq[byte]]): PgParam = if v.len == 0: @@ -1789,6 +1777,40 @@ proc toPgBinaryParam*(v: PgHstore, oid: int32): PgParam = ## (available as ``conn.hstoreOid`` after connection). PgParam(oid: oid, format: 1, value: some(encodeHstoreBinary(v))) +proc toPgParam*(v: seq[PgHstore]): PgParam = + ## Send hstore[] in text format using ``OidTextArray``. Requires an explicit + ## ``::hstore[]`` cast in the SQL statement (e.g. ``SELECT $1::hstore[]``), + ## since the parameter is typed as text[]. No connection-specific OID is + ## needed; prefer ``toPgBinaryParam`` when a ``PgConnection`` with the + ## discovered hstore OIDs is available (faster, no cast required). + if v.len == 0: + return PgParam(oid: OidTextArray, format: 0, value: some(toBytes("{}"))) + var s = "{" + for i, h in v: + if i > 0: + s.add(',') + s.add('"') + for c in encodeHstoreText(h): + if c == '"' or c == '\\': + s.add('\\') + s.add(c) + s.add('"') + s.add('}') + PgParam(oid: OidTextArray, format: 0, value: some(toBytes(s))) + +proc toPgBinaryParam*(v: seq[PgHstore], elemOid: int32, arrayOid: int32): PgParam = + ## Encode hstore[] in binary format. Requires both the dynamic hstore OID + ## and hstore[] OID (available as ``conn.hstoreOid`` and + ## ``conn.hstoreArrayOid`` after connection). See also the ``PgConnection`` + ## overload in ``pg_connection`` which reads these OIDs automatically. + if v.len == 0: + return + PgParam(oid: arrayOid, format: 1, value: some(encodeBinaryArrayEmpty(elemOid))) + var elements = newSeq[seq[byte]](v.len) + for i, x in v: + elements[i] = encodeHstoreBinary(x) + PgParam(oid: arrayOid, format: 1, value: some(encodeBinaryArray(elemOid, elements))) + proc toPgBinaryParam*[T](v: Option[T]): PgParam = if v.isSome: result = toPgBinaryParam(v.get) @@ -4099,6 +4121,28 @@ genStringArrayDecoder(getXmlArray, PgXml, "xml") genStringArrayDecoder(getTsVectorArray, PgTsVector, "tsvector") genStringArrayDecoder(getTsQueryArray, PgTsQuery, "tsquery") +proc getHstoreArray*(row: Row, col: int): seq[PgHstore] = + ## Get a column value as ``seq[PgHstore]``. Handles both text and binary format. + if row.isBinaryCol(col): + let (off, clen) = cellInfo(row, col) + if clen == -1: + raise newException(PgTypeError, "Column " & $col & " is NULL") + let decoded = decodeBinaryArray(row.data.buf.toOpenArray(off, off + clen - 1)) + result = newSeq[PgHstore](decoded.elements.len) + for i, e in decoded.elements: + if e.len == -1: + raise newException(PgTypeError, "NULL element in hstore array") + result[i] = decodeHstoreBinary( + row.data.buf.toOpenArray(off + e.off, off + e.off + e.len - 1) + ) + return + let s = row.getStr(col) + let elems = parseTextArray(s) + for e in elems: + if e.isNone: + raise newException(PgTypeError, "NULL element in hstore array") + result.add(parseHstoreText(e.get)) + # Array Opt accessors (text format) optAccessor(getIntArray, getIntArrayOpt, seq[int32]) @@ -4134,6 +4178,7 @@ optAccessor(getCircleArray, getCircleArrayOpt, seq[PgCircle]) optAccessor(getXmlArray, getXmlArrayOpt, seq[PgXml]) optAccessor(getTsVectorArray, getTsVectorArrayOpt, seq[PgTsVector]) optAccessor(getTsQueryArray, getTsQueryArrayOpt, seq[PgTsQuery]) +optAccessor(getHstoreArray, getHstoreArrayOpt, seq[PgHstore]) # Coerce a binary PgParam to match the server-inferred type from a prepared # statement. This handles the common case where e.g. int32.toPgParam is @@ -6293,6 +6338,9 @@ proc get*(row: Row, col: int, T: typedesc[seq[PgTsVector]]): seq[PgTsVector] = proc get*(row: Row, col: int, T: typedesc[seq[PgTsQuery]]): seq[PgTsQuery] = row.getTsQueryArray(col) +proc get*(row: Row, col: int, T: typedesc[seq[PgHstore]]): seq[PgHstore] = + row.getHstoreArray(col) + # Range types (DateTime-based ranges excluded — see note above) proc get*(row: Row, col: int, T: typedesc[PgRange[int32]]): PgRange[int32] = @@ -6464,6 +6512,7 @@ nameAccessor(getCircleArray, seq[PgCircle]) nameAccessor(getXmlArray, seq[PgXml]) nameAccessor(getTsVectorArray, seq[PgTsVector]) nameAccessor(getTsQueryArray, seq[PgTsQuery]) +nameAccessor(getHstoreArray, seq[PgHstore]) nameAccessor(getTimestampArrayOpt, Option[seq[DateTime]]) nameAccessor(getTimestampTzArrayOpt, Option[seq[DateTime]]) nameAccessor(getDateArrayOpt, Option[seq[DateTime]]) @@ -6489,6 +6538,7 @@ nameAccessor(getCircleArrayOpt, Option[seq[PgCircle]]) nameAccessor(getXmlArrayOpt, Option[seq[PgXml]]) nameAccessor(getTsVectorArrayOpt, Option[seq[PgTsVector]]) nameAccessor(getTsQueryArrayOpt, Option[seq[PgTsQuery]]) +nameAccessor(getHstoreArrayOpt, Option[seq[PgHstore]]) nameAccessor(getInt4Range, PgRange[int32]) nameAccessor(getInt8Range, PgRange[int64]) nameAccessor(getNumRange, PgRange[PgNumeric]) diff --git a/tests/test_e2e.nim b/tests/test_e2e.nim index 275b099..c6c3fb8 100644 --- a/tests/test_e2e.nim +++ b/tests/test_e2e.nim @@ -8050,6 +8050,60 @@ suite "E2E: Numeric / binary / JSON array types": waitFor t() + test "hstore array roundtrip (text cast)": + proc t() {.async.} = + let conn = await connect(plainConfig()) + discard await conn.simpleQuery("CREATE EXTENSION IF NOT EXISTS hstore;") + var h1: PgHstore = initTable[string, Option[string]]() + h1["a"] = some("1") + var h2: PgHstore = initTable[string, Option[string]]() + h2["b"] = none(string) + let res = await conn.query("SELECT $1::hstore[]", @[toPgParam(@[h1, h2])]) + doAssert res.rows.len == 1 + let arr = res.rows[0].getHstoreArray(0) + doAssert arr.len == 2 + doAssert arr[0] == h1 + doAssert arr[1] == h2 + await conn.close() + + waitFor t() + + test "hstore array roundtrip (binary)": + proc t() {.async.} = + let conn = await connect(plainConfig()) + discard await conn.simpleQuery("CREATE EXTENSION IF NOT EXISTS hstore;") + doAssert conn.hstoreOid != 0 + doAssert conn.hstoreArrayOid != 0 + var h1: PgHstore = initTable[string, Option[string]]() + h1["x"] = some("y") + var h2: PgHstore = initTable[string, Option[string]]() + h2["nul"] = none(string) + let bin = toPgBinaryParam(@[h1, h2], conn.hstoreOid, conn.hstoreArrayOid) + let res = await conn.query("SELECT $1", @[bin]) + doAssert res.rows.len == 1 + let arr = res.rows[0].getHstoreArray(0) + doAssert arr.len == 2 + doAssert arr[0] == h1 + doAssert arr[1] == h2 + await conn.close() + + waitFor t() + + test "hstore array roundtrip (binary, conn overload)": + proc t() {.async.} = + let conn = await connect(plainConfig()) + discard await conn.simpleQuery("CREATE EXTENSION IF NOT EXISTS hstore;") + var h1: PgHstore = initTable[string, Option[string]]() + h1["k"] = some("v") + let res = await conn.query("SELECT $1", @[conn.toPgBinaryParam(@[h1])]) + doAssert res.rows.len == 1 + let arr = res.rows[0].getHstoreArray(0) + doAssert arr.len == 1 + doAssert arr[0] == h1 + await conn.close() + + waitFor t() + test "bytea array roundtrip": proc t() {.async.} = let conn = await connect(plainConfig()) diff --git a/tests/test_types.nim b/tests/test_types.nim index 47cab97..70234a2 100644 --- a/tests/test_types.nim +++ b/tests/test_types.nim @@ -5011,6 +5011,73 @@ suite "hstore": let row = mkRow(@[none(seq[byte])], fields) check row.getHstoreOpt(0).isNone + test "toPgParam seq[PgHstore] text roundtrip": + var h1: PgHstore = initTable[string, Option[string]]() + h1["a"] = some("1") + var h2: PgHstore = initTable[string, Option[string]]() + h2["b"] = none(string) + let p = toPgParam(@[h1, h2]) + check p.oid == OidTextArray + check p.format == 0 + let fields = @[mkField(OidTextArray, 0'i16)] + let row = mkRow(@[p.value], fields) + let arr = row.getHstoreArray(0) + check arr.len == 2 + check arr[0] == h1 + check arr[1] == h2 + + test "toPgParam seq[PgHstore] empty": + let p = toPgParam(newSeq[PgHstore]()) + check p.oid == OidTextArray + check p.format == 0 + check toString(p.value.get) == "{}" + + test "toPgBinaryParam seq[PgHstore] roundtrip": + var h1: PgHstore = initTable[string, Option[string]]() + h1["x"] = some("y") + var h2: PgHstore = initTable[string, Option[string]]() + h2["nul"] = none(string) + let p = toPgBinaryParam(@[h1, h2], 16385'i32, 16386'i32) + check p.oid == 16386'i32 + check p.format == 1 + let fields = @[mkField(16386'i32, 1'i16)] + let row = mkRow(@[p.value], fields) + let arr = row.getHstoreArray(0) + check arr.len == 2 + check arr[0] == h1 + check arr[1] == h2 + + test "toPgBinaryParam seq[PgHstore] empty": + let p = toPgBinaryParam(newSeq[PgHstore](), 16385'i32, 16386'i32) + check p.oid == 16386'i32 + check p.format == 1 + let fields = @[mkField(16386'i32, 1'i16)] + let row = mkRow(@[p.value], fields) + check row.getHstoreArray(0).len == 0 + + test "getHstoreArray text format": + let row: Row = @[some(toBytes("{\"\\\"a\\\"=>\\\"1\\\"\",\"\\\"b\\\"=>NULL\"}"))] + let arr = row.getHstoreArray(0) + check arr.len == 2 + check arr[0]["a"] == some("1") + check arr[1]["b"] == none(string) + + test "getHstoreArrayOpt some": + var h1: PgHstore = initTable[string, Option[string]]() + h1["k"] = some("v") + let p = toPgParam(@[h1]) + let fields = @[mkField(OidTextArray, 0'i16)] + let row = mkRow(@[p.value], fields) + let r = row.getHstoreArrayOpt(0) + check r.isSome + check r.get.len == 1 + check r.get[0] == h1 + + test "getHstoreArrayOpt none": + let fields = @[mkField(OidTextArray, 0'i16)] + let row = mkRow(@[none(seq[byte])], fields) + check row.getHstoreArrayOpt(0).isNone + suite "PgBit": test "OID constants": check OidBit == 1560'i32