Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion async_postgres/pg_types.nim
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import
std/[
hashes, json, macros, math, options, parseutils, sequtils, strutils, tables, times,
net,
typetraits, net,
]

import pg_protocol
Expand Down Expand Up @@ -3096,6 +3096,63 @@ proc getCompositeOpt*[T: object](row: Row, col: int): Option[T] =
else:
some(getComposite[T](row, col))

# User-defined domain type support
#
# PostgreSQL domain types are named constraints over a base type.
# They have their own OID but share the wire format of the base type.
# In Nim, they map naturally to ``distinct`` types.
#
# Usage:
# type UsPostalCode = distinct string
#
# pgDomain(UsPostalCode, string) # OID = 0; base type OID used
# pgDomain(UsPostalCode, string, 12345'i32) # explicit domain OID
#
# Reading rows:
# let z = row.getDomain[UsPostalCode](0)
# let z = row.getDomainOpt[UsPostalCode](0)

macro pgDomain*(T: typedesc, Base: typedesc, oid: int32 = 0'i32): untyped =
## Generate ``toPgParam`` for a Nim distinct type as a PostgreSQL domain type.
## Encoding delegates to the base type's ``toPgParam``.
## When OID is 0 (default), the base type's OID is used.
let tSym = T.getType[1]
let bSym = Base.getType[1]
result = newStmtList()
result.add quote do:
proc toPgParam*(v: `tSym`): PgParam =
result = toPgParam(`bSym`(v))
if `oid` != 0'i32:
result.oid = `oid`

proc getDomain*[T: distinct](row: Row, col: int): T =
## Read a PostgreSQL domain column as a Nim distinct type.
## The base type determines which row accessor is used.
when distinctBase(T) is string:
T(row.getStr(col))
elif distinctBase(T) is int16:
T(int16(row.getInt(col)))
elif distinctBase(T) is int32:
T(row.getInt(col))
elif distinctBase(T) is int64:
T(row.getInt64(col))
elif distinctBase(T) is float64:
T(row.getFloat(col))
elif distinctBase(T) is bool:
T(row.getBool(col))
else:
{.
error:
"Unsupported domain base type: use string, int16, int32, int64, float64, or bool"
.}

proc getDomainOpt*[T: distinct](row: Row, col: int): Option[T] =
## NULL-safe version of ``getDomain``.
if row.isNull(col):
none(T)
else:
some(getDomain[T](row, col))

# Range type support
#
# PostgreSQL range types represent a range of values of some element type.
Expand Down
185 changes: 184 additions & 1 deletion tests/test_types.nim
Original file line number Diff line number Diff line change
@@ -1,8 +1,32 @@
import std/[json, unittest, options, strutils, tables, times, math, net]
import std/[json, unittest, options, strutils, tables, times, math, net, typetraits]

import ../async_postgres/pg_protocol
import ../async_postgres/pg_types {.all.}

type
UsPostalCode = distinct string
SmallCount = distinct int16
PositiveInt = distinct int32
ProbabilityF = distinct float64
BigCount = distinct int64
IsActive = distinct bool

proc `==`(a, b: UsPostalCode): bool {.borrow.}
proc `==`(a, b: SmallCount): bool {.borrow.}
proc `==`(a, b: PositiveInt): bool {.borrow.}
proc `==`(a, b: BigCount): bool {.borrow.}
proc `==`(a, b: IsActive): bool {.borrow.}
proc `$`(v: UsPostalCode): string {.borrow.}
proc `$`(v: SmallCount): string {.borrow.}
proc `$`(v: PositiveInt): string {.borrow.}

pgDomain(UsPostalCode, string)
pgDomain(SmallCount, int16)
pgDomain(PositiveInt, int32)
pgDomain(ProbabilityF, float64, 90001)
pgDomain(BigCount, int64)
pgDomain(IsActive, bool)

proc toString(data: seq[byte]): string =
result = newString(data.len)
for i in 0 ..< data.len:
Expand Down Expand Up @@ -2856,6 +2880,165 @@ suite "User-defined composite":
check p.oid == 0'i32
check p.value.isNone

suite "User-defined domain":
test "pgDomain generates toPgParam with base type OID":
let p = toPgParam(UsPostalCode("12345"))
check p.oid == OidText
check p.format == 0'i16
check p.value.isSome
check toString(p.value.get) == "12345"

test "pgDomain int16 base type":
let p = toPgParam(SmallCount(7))
check p.oid == OidInt2
check p.format == 1'i16
check p.value.isSome

test "pgDomain int32 base type":
let p = toPgParam(PositiveInt(42))
check p.oid == OidInt4
check p.format == 1'i16 # binary inherited from int32
check p.value.isSome

test "pgDomain with explicit OID":
let p = toPgParam(ProbabilityF(0.95))
check p.oid == 90001'i32
check p.value.isSome

test "getDomain text format string":
let row: Row = @[some(toBytes("12345"))]
check getDomain[UsPostalCode](row, 0) == UsPostalCode("12345")

test "getDomain text format int16":
let row: Row = @[some(toBytes("7"))]
check getDomain[SmallCount](row, 0) == SmallCount(7)

test "getDomain text format int32":
let row: Row = @[some(toBytes("42"))]
check getDomain[PositiveInt](row, 0) == PositiveInt(42)

test "getDomain text format int64":
let row: Row = @[some(toBytes("1000000000"))]
check getDomain[BigCount](row, 0) == BigCount(1000000000'i64)

test "getDomain text format float64":
let row: Row = @[some(toBytes("0.95"))]
let v = getDomain[ProbabilityF](row, 0)
check abs(float64(v) - 0.95) < 1e-10

test "getDomain text format bool":
let rowT: Row = @[some(toBytes("t"))]
check getDomain[IsActive](rowT, 0) == IsActive(true)
let rowF: Row = @[some(toBytes("f"))]
check getDomain[IsActive](rowF, 0) == IsActive(false)

test "getDomain binary format int16":
let fields = @[mkField(OidInt2, 1'i16)]
let row = mkRow(@[some(@(toBE16(3'i16)))], fields)
check getDomain[SmallCount](row, 0) == SmallCount(3)

test "getDomain binary format int32":
let fields = @[mkField(OidInt4, 1'i16)]
let row = mkRow(@[some(@(toBE32(99'i32)))], fields)
check getDomain[PositiveInt](row, 0) == PositiveInt(99)

test "getDomain binary format int64":
let fields = @[mkField(OidInt8, 1'i16)]
let row = mkRow(@[some(@(toBE64(1000000000'i64)))], fields)
check getDomain[BigCount](row, 0) == BigCount(1000000000'i64)

test "getDomain binary format float64":
let fields = @[mkField(OidFloat8, 1'i16)]
let row = mkRow(@[some(@(toBE64(cast[int64](3.14'f64))))], fields)
let v = getDomain[ProbabilityF](row, 0)
check abs(float64(v) - 3.14) < 1e-10

test "getDomain binary format bool":
let fields = @[mkField(OidBool, 1'i16)]
let rowT = mkRow(@[some(@[1'u8])], fields)
check getDomain[IsActive](rowT, 0) == IsActive(true)
let rowF = mkRow(@[some(@[0'u8])], fields)
check getDomain[IsActive](rowF, 0) == IsActive(false)

test "getDomain raises on NULL":
let row: Row = @[none(seq[byte])]
var raised = false
try:
discard getDomain[UsPostalCode](row, 0)
except PgTypeError:
raised = true
check raised

test "getDomainOpt some":
let row: Row = @[some(toBytes("12345"))]
check getDomainOpt[UsPostalCode](row, 0) == some(UsPostalCode("12345"))

test "getDomainOpt none":
let row: Row = @[none(seq[byte])]
check getDomainOpt[UsPostalCode](row, 0) == none(UsPostalCode)

test "getDomainOpt binary NULL":
let fields = @[mkField(OidInt4, 1'i16)]
let row = mkRow(@[none(seq[byte])], fields)
check getDomainOpt[PositiveInt](row, 0) == none(PositiveInt)

test "Option[Domain] toPgParam some":
let p = toPgParam(some(UsPostalCode("12345")))
check p.oid == OidText
check p.value.isSome
check toString(p.value.get) == "12345"

test "Option[Domain] toPgParam none":
let p = toPgParam(none(UsPostalCode))
check p.oid == OidText
check p.value.isNone

test "roundtrip text string":
let orig = UsPostalCode("90210")
let p = toPgParam(orig)
let row: Row = @[p.value]
check getDomain[UsPostalCode](row, 0) == orig

test "roundtrip binary int16":
let orig = SmallCount(5)
let p = toPgParam(orig)
let fields = @[mkField(OidInt2, 1'i16)]
let row = mkRow(@[p.value], fields)
check getDomain[SmallCount](row, 0) == orig

test "roundtrip binary int32":
let orig = PositiveInt(7)
let p = toPgParam(orig)
let fields = @[mkField(OidInt4, 1'i16)]
let row = mkRow(@[p.value], fields)
check getDomain[PositiveInt](row, 0) == orig

test "pgDomain int64 base type":
let p = toPgParam(BigCount(999999'i64))
check p.oid == OidInt8
check p.format == 1'i16
check p.value.isSome

test "pgDomain bool base type":
let p = toPgParam(IsActive(true))
check p.oid == OidBool
check p.format == 1'i16
check p.value.isSome

test "roundtrip binary int64":
let orig = BigCount(123456789'i64)
let p = toPgParam(orig)
let fields = @[mkField(OidInt8, 1'i16)]
let row = mkRow(@[p.value], fields)
check getDomain[BigCount](row, 0) == orig

test "roundtrip binary bool":
let orig = IsActive(true)
let p = toPgParam(orig)
let fields = @[mkField(OidBool, 1'i16)]
let row = mkRow(@[p.value], fields)
check getDomain[IsActive](row, 0) == orig

suite "Range OID constants":
test "range OID values":
check OidInt4Range == 3904'i32
Expand Down
Loading