Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 98 additions & 28 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ when hasChronos:
import bearssl/[x509, rsa, ec]
elif hasAsyncDispatch:
import std/asyncnet
from std/nativesockets import Domain, SockType, Protocol
when defined(ssl):
import std/[net, tempfiles, os]

Expand Down Expand Up @@ -203,6 +204,15 @@ type
columnFormats*: seq[int16]
commandTag*: string

proc isUnixSocket*(host: string): bool {.inline.} =
## True if `host` represents a Unix socket directory (starts with '/').
## Compatible with libpq behavior.
host.len > 0 and host[0] == '/'

proc unixSocketPath*(host: string, port: int): string =
## Build the libpq-compatible Unix socket file path: ``{dir}/.s.PGSQL.{port}``.
host & "/.s.PGSQL." & $port

proc getHosts*(config: ConnConfig): seq[HostEntry] =
## Return the list of hosts to try. If `hosts` is populated, return it;
## otherwise synthesize a single entry from `host`/`port`.
Expand Down Expand Up @@ -781,14 +791,14 @@ proc negotiateSSL(conn: PgConnection, config: ConnConfig) {.async.} =
when defined(posix):
var TCP_NODELAY {.importc, header: "<netinet/tcp.h>".}: cint

proc configureTcpNoDelay(fd: SocketHandle) =
proc configureTcpNoDelay(fd: posix.SocketHandle) =
## Disable Nagle's algorithm for low-latency sends.
var optval: cint = 1
discard setsockopt(
fd, cint(posix.IPPROTO_TCP), TCP_NODELAY, addr optval, sizeof(optval).SockLen
)

proc configureKeepalive(fd: SocketHandle, config: ConnConfig) =
proc configureKeepalive(fd: posix.SocketHandle, config: ConnConfig) =
## Set TCP keepalive options on the socket.
if not config.keepAlive:
return
Expand Down Expand Up @@ -878,21 +888,33 @@ proc connectToHost(
## Connect to a single PostgreSQL host. Internal helper for multi-host connect.
var conn: PgConnection

let isUnix = isUnixSocket(hostAddr)

when hasChronos:
let addresses = resolveTAddress(hostAddr, Port(hostPort))
if addresses.len == 0:
raise newException(PgConnectionError, "Could not resolve host: " & hostAddr)
let transport = await connect(addresses[0])
let transport =
if isUnix:
when defined(posix):
await connect(initTAddress(unixSocketPath(hostAddr, hostPort)))
else:
raise newException(
PgConnectionError, "Unix sockets are not supported on this platform"
)
else:
let addresses = resolveTAddress(hostAddr, Port(hostPort))
if addresses.len == 0:
raise newException(PgConnectionError, "Could not resolve host: " & hostAddr)
await connect(addresses[0])
when defined(posix):
try:
configureTcpNoDelay(SocketHandle(transport.fd))
configureKeepalive(SocketHandle(transport.fd), config)
except CatchableError as e:
if not isUnix:
try:
await noCancel transport.closeWait()
except CatchableError:
discard
raise newException(PgConnectionError, e.msg, e)
configureTcpNoDelay(posix.SocketHandle(transport.fd))
configureKeepalive(posix.SocketHandle(transport.fd), config)
except CatchableError as e:
try:
await noCancel transport.closeWait()
except CatchableError:
discard
raise newException(PgConnectionError, e.msg, e)
conn = PgConnection(
transport: transport,
recvBuf: @[],
Expand All @@ -905,12 +927,31 @@ proc connectToHost(
stmtCacheCapacity: 256,
)
elif hasAsyncDispatch:
let sock = newAsyncSocket(buffered = false)
let sock =
if isUnix:
when defined(posix):
newAsyncSocket(
Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP, buffered = false
)
else:
raise newException(
PgConnectionError, "Unix sockets are not supported on this platform"
)
else:
newAsyncSocket(buffered = false)
try:
await sock.connect(hostAddr, Port(hostPort))
when defined(posix):
configureTcpNoDelay(SocketHandle(sock.getFd()))
configureKeepalive(SocketHandle(sock.getFd()), config)
if isUnix:
when defined(posix):
await sock.connectUnix(unixSocketPath(hostAddr, hostPort))
else:
raise newException(
PgConnectionError, "Unix sockets are not supported on this platform"
)
else:
await sock.connect(hostAddr, Port(hostPort))
when defined(posix):
configureTcpNoDelay(posix.SocketHandle(sock.getFd()))
configureKeepalive(posix.SocketHandle(sock.getFd()), config)
except CatchableError:
sock.close()
raise
Expand All @@ -927,8 +968,8 @@ proc connectToHost(
)

try:
# SSL negotiation (before StartupMessage)
if config.sslMode != sslDisable:
# SSL negotiation (before StartupMessage) — skip for Unix sockets
if config.sslMode != sslDisable and not isUnix:
await negotiateSSL(conn, config)

when hasChronos:
Expand Down Expand Up @@ -1123,21 +1164,50 @@ proc simpleExecImpl(
return commandTag

proc cancel*(conn: PgConnection): Future[void] {.async.} =
## Send a CancelRequest over a separate TCP connection to abort the running query.
## Send a CancelRequest over a separate connection to abort the running query.
let isUnix = isUnixSocket(conn.host)
when hasChronos:
let addresses = resolveTAddress(conn.host, Port(conn.port))
if addresses.len == 0:
raise newException(PgConnectionError, "Could not resolve host: " & conn.host)
let transport = await connect(addresses[0])
let transport =
if isUnix:
when defined(posix):
await connect(initTAddress(unixSocketPath(conn.host, conn.port)))
else:
raise newException(
PgConnectionError, "Unix sockets are not supported on this platform"
)
else:
let addresses = resolveTAddress(conn.host, Port(conn.port))
if addresses.len == 0:
raise newException(PgConnectionError, "Could not resolve host: " & conn.host)
await connect(addresses[0])
try:
let msg = encodeCancelRequest(conn.pid, conn.secretKey)
discard await transport.write(msg)
finally:
await transport.closeWait()
elif hasAsyncDispatch:
let sock = newAsyncSocket(buffered = false)
let sock =
if isUnix:
when defined(posix):
newAsyncSocket(
Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP, buffered = false
)
else:
raise newException(
PgConnectionError, "Unix sockets are not supported on this platform"
)
else:
newAsyncSocket(buffered = false)
try:
await sock.connect(conn.host, Port(conn.port))
if isUnix:
when defined(posix):
await sock.connectUnix(unixSocketPath(conn.host, conn.port))
else:
raise newException(
PgConnectionError, "Unix sockets are not supported on this platform"
)
else:
await sock.connect(conn.host, Port(conn.port))
let msg = encodeCancelRequest(conn.pid, conn.secretKey)
await sock.sendRawBytes(msg)
finally:
Expand Down
14 changes: 7 additions & 7 deletions async_postgres/pg_types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ proc toPgBinaryParam*(v: Option[JsonNode]): PgParam =

proc encodeHstoreBinary*(v: PgHstore): seq[byte] =
## Encode hstore as PostgreSQL binary format.
## Format: numPairs(int32) + [keyLen(int32) + keyData + valLen(int32) + valData]...
## Format: ``numPairs(int32) + [keyLen(int32) + keyData + valLen(int32) + valData]...``
var size = 4
for k, val in v.pairs:
size += 4 + k.len + 4
Expand Down Expand Up @@ -4165,7 +4165,7 @@ optAccessor(getDateMultirange, getDateMultirangeOpt, PgMultirange[DateTime])
# Binary format: standard array container with range elements.

proc getInt4RangeArray*(row: Row, col: int): seq[PgRange[int32]] =
## Get a column value as an int4range[]. Handles binary format.
## Get a column value as an ``int4range[]``. Handles binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
Expand Down Expand Up @@ -4193,7 +4193,7 @@ proc getInt4RangeArray*(row: Row, col: int): seq[PgRange[int32]] =
)

proc getInt8RangeArray*(row: Row, col: int): seq[PgRange[int64]] =
## Get a column value as an int8range[]. Handles binary format.
## Get a column value as an ``int8range[]``. Handles binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
Expand Down Expand Up @@ -4221,7 +4221,7 @@ proc getInt8RangeArray*(row: Row, col: int): seq[PgRange[int64]] =
)

proc getNumRangeArray*(row: Row, col: int): seq[PgRange[PgNumeric]] =
## Get a column value as a numrange[]. Handles binary format.
## Get a column value as a ``numrange[]``. Handles binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
Expand Down Expand Up @@ -4249,7 +4249,7 @@ proc getNumRangeArray*(row: Row, col: int): seq[PgRange[PgNumeric]] =
)

proc getTsRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] =
## Get a column value as a tsrange[]. Handles binary format.
## Get a column value as a ``tsrange[]``. Handles binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
Expand Down Expand Up @@ -4283,7 +4283,7 @@ proc getTsRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] =
)

proc getTsTzRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] =
## Get a column value as a tstzrange[]. Handles binary format.
## Get a column value as a ``tstzrange[]``. Handles binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
Expand Down Expand Up @@ -4321,7 +4321,7 @@ proc getTsTzRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] =
)

proc getDateRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] =
## Get a column value as a daterange[]. Handles binary format.
## Get a column value as a ``daterange[]``. Handles binary format.
if row.isBinaryCol(col):
let (off, clen) = cellInfo(row, col)
if clen == -1:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_dsn.nim
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,34 @@ suite "parseDsn keyword=value":
check cfg.user == "myuser"
check cfg.host == "dbhost"
check cfg.port == 5433

suite "Unix socket":
test "isUnixSocket":
check isUnixSocket("/var/run/postgresql") == true
check isUnixSocket("/tmp") == true
check isUnixSocket("localhost") == false
check isUnixSocket("127.0.0.1") == false
check isUnixSocket("") == false

test "unixSocketPath":
check unixSocketPath("/var/run/postgresql", 5432) ==
"/var/run/postgresql/.s.PGSQL.5432"
check unixSocketPath("/tmp", 5433) == "/tmp/.s.PGSQL.5433"

test "key-value DSN with unix socket host":
let cfg = parseDsn("host=/var/run/postgresql port=5432 dbname=test user=myuser")
check cfg.host == "/var/run/postgresql"
check cfg.port == 5432
check cfg.database == "test"
check cfg.user == "myuser"

test "key-value DSN with unix socket host default port":
let cfg = parseDsn("host=/tmp dbname=test")
check cfg.host == "/tmp"
check cfg.port == 5432

test "URI DSN with unix socket via query param":
let cfg = parseDsn("postgresql:///mydb?host=/var/run/postgresql")
check cfg.host == "/var/run/postgresql"
check cfg.database == "mydb"
check cfg.port == 5432
Loading