From abdf670263fc05fb42570e80933466a3b1a8025f Mon Sep 17 00:00:00 2001 From: fox0430 Date: Tue, 31 Mar 2026 16:54:35 +0900 Subject: [PATCH] Add Domain types --- async_postgres/pg_types.nim | 59 +++++++++++- tests/test_types.nim | 185 +++++++++++++++++++++++++++++++++++- 2 files changed, 242 insertions(+), 2 deletions(-) diff --git a/async_postgres/pg_types.nim b/async_postgres/pg_types.nim index c64d79b..ff10562 100644 --- a/async_postgres/pg_types.nim +++ b/async_postgres/pg_types.nim @@ -1,7 +1,7 @@ import std/[ hashes, json, macros, math, options, parseutils, sequtils, strutils, tables, times, - net, + typetraits, net, ] import pg_protocol @@ -3096,6 +3096,63 @@ proc getCompositeOpt*[T: object](row: Row, col: int): Option[T] = else: some(getComposite[T](row, col)) +# User-defined domain type support +# +# PostgreSQL domain types are named constraints over a base type. +# They have their own OID but share the wire format of the base type. +# In Nim, they map naturally to ``distinct`` types. +# +# Usage: +# type UsPostalCode = distinct string +# +# pgDomain(UsPostalCode, string) # OID = 0; base type OID used +# pgDomain(UsPostalCode, string, 12345'i32) # explicit domain OID +# +# Reading rows: +# let z = row.getDomain[UsPostalCode](0) +# let z = row.getDomainOpt[UsPostalCode](0) + +macro pgDomain*(T: typedesc, Base: typedesc, oid: int32 = 0'i32): untyped = + ## Generate ``toPgParam`` for a Nim distinct type as a PostgreSQL domain type. + ## Encoding delegates to the base type's ``toPgParam``. + ## When OID is 0 (default), the base type's OID is used. + let tSym = T.getType[1] + let bSym = Base.getType[1] + result = newStmtList() + result.add quote do: + proc toPgParam*(v: `tSym`): PgParam = + result = toPgParam(`bSym`(v)) + if `oid` != 0'i32: + result.oid = `oid` + +proc getDomain*[T: distinct](row: Row, col: int): T = + ## Read a PostgreSQL domain column as a Nim distinct type. + ## The base type determines which row accessor is used. + when distinctBase(T) is string: + T(row.getStr(col)) + elif distinctBase(T) is int16: + T(int16(row.getInt(col))) + elif distinctBase(T) is int32: + T(row.getInt(col)) + elif distinctBase(T) is int64: + T(row.getInt64(col)) + elif distinctBase(T) is float64: + T(row.getFloat(col)) + elif distinctBase(T) is bool: + T(row.getBool(col)) + else: + {. + error: + "Unsupported domain base type: use string, int16, int32, int64, float64, or bool" + .} + +proc getDomainOpt*[T: distinct](row: Row, col: int): Option[T] = + ## NULL-safe version of ``getDomain``. + if row.isNull(col): + none(T) + else: + some(getDomain[T](row, col)) + # Range type support # # PostgreSQL range types represent a range of values of some element type. diff --git a/tests/test_types.nim b/tests/test_types.nim index 4ed664d..dbdea44 100644 --- a/tests/test_types.nim +++ b/tests/test_types.nim @@ -1,8 +1,32 @@ -import std/[json, unittest, options, strutils, tables, times, math, net] +import std/[json, unittest, options, strutils, tables, times, math, net, typetraits] import ../async_postgres/pg_protocol import ../async_postgres/pg_types {.all.} +type + UsPostalCode = distinct string + SmallCount = distinct int16 + PositiveInt = distinct int32 + ProbabilityF = distinct float64 + BigCount = distinct int64 + IsActive = distinct bool + +proc `==`(a, b: UsPostalCode): bool {.borrow.} +proc `==`(a, b: SmallCount): bool {.borrow.} +proc `==`(a, b: PositiveInt): bool {.borrow.} +proc `==`(a, b: BigCount): bool {.borrow.} +proc `==`(a, b: IsActive): bool {.borrow.} +proc `$`(v: UsPostalCode): string {.borrow.} +proc `$`(v: SmallCount): string {.borrow.} +proc `$`(v: PositiveInt): string {.borrow.} + +pgDomain(UsPostalCode, string) +pgDomain(SmallCount, int16) +pgDomain(PositiveInt, int32) +pgDomain(ProbabilityF, float64, 90001) +pgDomain(BigCount, int64) +pgDomain(IsActive, bool) + proc toString(data: seq[byte]): string = result = newString(data.len) for i in 0 ..< data.len: @@ -2856,6 +2880,165 @@ suite "User-defined composite": check p.oid == 0'i32 check p.value.isNone +suite "User-defined domain": + test "pgDomain generates toPgParam with base type OID": + let p = toPgParam(UsPostalCode("12345")) + check p.oid == OidText + check p.format == 0'i16 + check p.value.isSome + check toString(p.value.get) == "12345" + + test "pgDomain int16 base type": + let p = toPgParam(SmallCount(7)) + check p.oid == OidInt2 + check p.format == 1'i16 + check p.value.isSome + + test "pgDomain int32 base type": + let p = toPgParam(PositiveInt(42)) + check p.oid == OidInt4 + check p.format == 1'i16 # binary inherited from int32 + check p.value.isSome + + test "pgDomain with explicit OID": + let p = toPgParam(ProbabilityF(0.95)) + check p.oid == 90001'i32 + check p.value.isSome + + test "getDomain text format string": + let row: Row = @[some(toBytes("12345"))] + check getDomain[UsPostalCode](row, 0) == UsPostalCode("12345") + + test "getDomain text format int16": + let row: Row = @[some(toBytes("7"))] + check getDomain[SmallCount](row, 0) == SmallCount(7) + + test "getDomain text format int32": + let row: Row = @[some(toBytes("42"))] + check getDomain[PositiveInt](row, 0) == PositiveInt(42) + + test "getDomain text format int64": + let row: Row = @[some(toBytes("1000000000"))] + check getDomain[BigCount](row, 0) == BigCount(1000000000'i64) + + test "getDomain text format float64": + let row: Row = @[some(toBytes("0.95"))] + let v = getDomain[ProbabilityF](row, 0) + check abs(float64(v) - 0.95) < 1e-10 + + test "getDomain text format bool": + let rowT: Row = @[some(toBytes("t"))] + check getDomain[IsActive](rowT, 0) == IsActive(true) + let rowF: Row = @[some(toBytes("f"))] + check getDomain[IsActive](rowF, 0) == IsActive(false) + + test "getDomain binary format int16": + let fields = @[mkField(OidInt2, 1'i16)] + let row = mkRow(@[some(@(toBE16(3'i16)))], fields) + check getDomain[SmallCount](row, 0) == SmallCount(3) + + test "getDomain binary format int32": + let fields = @[mkField(OidInt4, 1'i16)] + let row = mkRow(@[some(@(toBE32(99'i32)))], fields) + check getDomain[PositiveInt](row, 0) == PositiveInt(99) + + test "getDomain binary format int64": + let fields = @[mkField(OidInt8, 1'i16)] + let row = mkRow(@[some(@(toBE64(1000000000'i64)))], fields) + check getDomain[BigCount](row, 0) == BigCount(1000000000'i64) + + test "getDomain binary format float64": + let fields = @[mkField(OidFloat8, 1'i16)] + let row = mkRow(@[some(@(toBE64(cast[int64](3.14'f64))))], fields) + let v = getDomain[ProbabilityF](row, 0) + check abs(float64(v) - 3.14) < 1e-10 + + test "getDomain binary format bool": + let fields = @[mkField(OidBool, 1'i16)] + let rowT = mkRow(@[some(@[1'u8])], fields) + check getDomain[IsActive](rowT, 0) == IsActive(true) + let rowF = mkRow(@[some(@[0'u8])], fields) + check getDomain[IsActive](rowF, 0) == IsActive(false) + + test "getDomain raises on NULL": + let row: Row = @[none(seq[byte])] + var raised = false + try: + discard getDomain[UsPostalCode](row, 0) + except PgTypeError: + raised = true + check raised + + test "getDomainOpt some": + let row: Row = @[some(toBytes("12345"))] + check getDomainOpt[UsPostalCode](row, 0) == some(UsPostalCode("12345")) + + test "getDomainOpt none": + let row: Row = @[none(seq[byte])] + check getDomainOpt[UsPostalCode](row, 0) == none(UsPostalCode) + + test "getDomainOpt binary NULL": + let fields = @[mkField(OidInt4, 1'i16)] + let row = mkRow(@[none(seq[byte])], fields) + check getDomainOpt[PositiveInt](row, 0) == none(PositiveInt) + + test "Option[Domain] toPgParam some": + let p = toPgParam(some(UsPostalCode("12345"))) + check p.oid == OidText + check p.value.isSome + check toString(p.value.get) == "12345" + + test "Option[Domain] toPgParam none": + let p = toPgParam(none(UsPostalCode)) + check p.oid == OidText + check p.value.isNone + + test "roundtrip text string": + let orig = UsPostalCode("90210") + let p = toPgParam(orig) + let row: Row = @[p.value] + check getDomain[UsPostalCode](row, 0) == orig + + test "roundtrip binary int16": + let orig = SmallCount(5) + let p = toPgParam(orig) + let fields = @[mkField(OidInt2, 1'i16)] + let row = mkRow(@[p.value], fields) + check getDomain[SmallCount](row, 0) == orig + + test "roundtrip binary int32": + let orig = PositiveInt(7) + let p = toPgParam(orig) + let fields = @[mkField(OidInt4, 1'i16)] + let row = mkRow(@[p.value], fields) + check getDomain[PositiveInt](row, 0) == orig + + test "pgDomain int64 base type": + let p = toPgParam(BigCount(999999'i64)) + check p.oid == OidInt8 + check p.format == 1'i16 + check p.value.isSome + + test "pgDomain bool base type": + let p = toPgParam(IsActive(true)) + check p.oid == OidBool + check p.format == 1'i16 + check p.value.isSome + + test "roundtrip binary int64": + let orig = BigCount(123456789'i64) + let p = toPgParam(orig) + let fields = @[mkField(OidInt8, 1'i16)] + let row = mkRow(@[p.value], fields) + check getDomain[BigCount](row, 0) == orig + + test "roundtrip binary bool": + let orig = IsActive(true) + let p = toPgParam(orig) + let fields = @[mkField(OidBool, 1'i16)] + let row = mkRow(@[p.value], fields) + check getDomain[IsActive](row, 0) == orig + suite "Range OID constants": test "range OID values": check OidInt4Range == 3904'i32