diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index f777e60..cb8b2e4 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -9,6 +9,7 @@ when hasChronos: import bearssl/[x509, rsa, ec] elif hasAsyncDispatch: import std/asyncnet + from std/nativesockets import Domain, SockType, Protocol when defined(ssl): import std/[net, tempfiles, os] @@ -203,6 +204,15 @@ type columnFormats*: seq[int16] commandTag*: string +proc isUnixSocket*(host: string): bool {.inline.} = + ## True if `host` represents a Unix socket directory (starts with '/'). + ## Compatible with libpq behavior. + host.len > 0 and host[0] == '/' + +proc unixSocketPath*(host: string, port: int): string = + ## Build the libpq-compatible Unix socket file path: ``{dir}/.s.PGSQL.{port}``. + host & "/.s.PGSQL." & $port + proc getHosts*(config: ConnConfig): seq[HostEntry] = ## Return the list of hosts to try. If `hosts` is populated, return it; ## otherwise synthesize a single entry from `host`/`port`. @@ -781,14 +791,14 @@ proc negotiateSSL(conn: PgConnection, config: ConnConfig) {.async.} = when defined(posix): var TCP_NODELAY {.importc, header: "".}: cint - proc configureTcpNoDelay(fd: SocketHandle) = + proc configureTcpNoDelay(fd: posix.SocketHandle) = ## Disable Nagle's algorithm for low-latency sends. var optval: cint = 1 discard setsockopt( fd, cint(posix.IPPROTO_TCP), TCP_NODELAY, addr optval, sizeof(optval).SockLen ) - proc configureKeepalive(fd: SocketHandle, config: ConnConfig) = + proc configureKeepalive(fd: posix.SocketHandle, config: ConnConfig) = ## Set TCP keepalive options on the socket. if not config.keepAlive: return @@ -878,21 +888,33 @@ proc connectToHost( ## Connect to a single PostgreSQL host. Internal helper for multi-host connect. var conn: PgConnection + let isUnix = isUnixSocket(hostAddr) + when hasChronos: - let addresses = resolveTAddress(hostAddr, Port(hostPort)) - if addresses.len == 0: - raise newException(PgConnectionError, "Could not resolve host: " & hostAddr) - let transport = await connect(addresses[0]) + let transport = + if isUnix: + when defined(posix): + await connect(initTAddress(unixSocketPath(hostAddr, hostPort))) + else: + raise newException( + PgConnectionError, "Unix sockets are not supported on this platform" + ) + else: + let addresses = resolveTAddress(hostAddr, Port(hostPort)) + if addresses.len == 0: + raise newException(PgConnectionError, "Could not resolve host: " & hostAddr) + await connect(addresses[0]) when defined(posix): - try: - configureTcpNoDelay(SocketHandle(transport.fd)) - configureKeepalive(SocketHandle(transport.fd), config) - except CatchableError as e: + if not isUnix: try: - await noCancel transport.closeWait() - except CatchableError: - discard - raise newException(PgConnectionError, e.msg, e) + configureTcpNoDelay(posix.SocketHandle(transport.fd)) + configureKeepalive(posix.SocketHandle(transport.fd), config) + except CatchableError as e: + try: + await noCancel transport.closeWait() + except CatchableError: + discard + raise newException(PgConnectionError, e.msg, e) conn = PgConnection( transport: transport, recvBuf: @[], @@ -905,12 +927,31 @@ proc connectToHost( stmtCacheCapacity: 256, ) elif hasAsyncDispatch: - let sock = newAsyncSocket(buffered = false) + let sock = + if isUnix: + when defined(posix): + newAsyncSocket( + Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP, buffered = false + ) + else: + raise newException( + PgConnectionError, "Unix sockets are not supported on this platform" + ) + else: + newAsyncSocket(buffered = false) try: - await sock.connect(hostAddr, Port(hostPort)) - when defined(posix): - configureTcpNoDelay(SocketHandle(sock.getFd())) - configureKeepalive(SocketHandle(sock.getFd()), config) + if isUnix: + when defined(posix): + await sock.connectUnix(unixSocketPath(hostAddr, hostPort)) + else: + raise newException( + PgConnectionError, "Unix sockets are not supported on this platform" + ) + else: + await sock.connect(hostAddr, Port(hostPort)) + when defined(posix): + configureTcpNoDelay(posix.SocketHandle(sock.getFd())) + configureKeepalive(posix.SocketHandle(sock.getFd()), config) except CatchableError: sock.close() raise @@ -927,8 +968,8 @@ proc connectToHost( ) try: - # SSL negotiation (before StartupMessage) - if config.sslMode != sslDisable: + # SSL negotiation (before StartupMessage) — skip for Unix sockets + if config.sslMode != sslDisable and not isUnix: await negotiateSSL(conn, config) when hasChronos: @@ -1123,21 +1164,50 @@ proc simpleExecImpl( return commandTag proc cancel*(conn: PgConnection): Future[void] {.async.} = - ## Send a CancelRequest over a separate TCP connection to abort the running query. + ## Send a CancelRequest over a separate connection to abort the running query. + let isUnix = isUnixSocket(conn.host) when hasChronos: - let addresses = resolveTAddress(conn.host, Port(conn.port)) - if addresses.len == 0: - raise newException(PgConnectionError, "Could not resolve host: " & conn.host) - let transport = await connect(addresses[0]) + let transport = + if isUnix: + when defined(posix): + await connect(initTAddress(unixSocketPath(conn.host, conn.port))) + else: + raise newException( + PgConnectionError, "Unix sockets are not supported on this platform" + ) + else: + let addresses = resolveTAddress(conn.host, Port(conn.port)) + if addresses.len == 0: + raise newException(PgConnectionError, "Could not resolve host: " & conn.host) + await connect(addresses[0]) try: let msg = encodeCancelRequest(conn.pid, conn.secretKey) discard await transport.write(msg) finally: await transport.closeWait() elif hasAsyncDispatch: - let sock = newAsyncSocket(buffered = false) + let sock = + if isUnix: + when defined(posix): + newAsyncSocket( + Domain.AF_UNIX, SockType.SOCK_STREAM, Protocol.IPPROTO_IP, buffered = false + ) + else: + raise newException( + PgConnectionError, "Unix sockets are not supported on this platform" + ) + else: + newAsyncSocket(buffered = false) try: - await sock.connect(conn.host, Port(conn.port)) + if isUnix: + when defined(posix): + await sock.connectUnix(unixSocketPath(conn.host, conn.port)) + else: + raise newException( + PgConnectionError, "Unix sockets are not supported on this platform" + ) + else: + await sock.connect(conn.host, Port(conn.port)) let msg = encodeCancelRequest(conn.pid, conn.secretKey) await sock.sendRawBytes(msg) finally: diff --git a/async_postgres/pg_types.nim b/async_postgres/pg_types.nim index eddfea7..bbb77dd 100644 --- a/async_postgres/pg_types.nim +++ b/async_postgres/pg_types.nim @@ -1129,7 +1129,7 @@ proc toPgBinaryParam*(v: Option[JsonNode]): PgParam = proc encodeHstoreBinary*(v: PgHstore): seq[byte] = ## Encode hstore as PostgreSQL binary format. - ## Format: numPairs(int32) + [keyLen(int32) + keyData + valLen(int32) + valData]... + ## Format: ``numPairs(int32) + [keyLen(int32) + keyData + valLen(int32) + valData]...`` var size = 4 for k, val in v.pairs: size += 4 + k.len + 4 @@ -4165,7 +4165,7 @@ optAccessor(getDateMultirange, getDateMultirangeOpt, PgMultirange[DateTime]) # Binary format: standard array container with range elements. proc getInt4RangeArray*(row: Row, col: int): seq[PgRange[int32]] = - ## Get a column value as an int4range[]. Handles binary format. + ## Get a column value as an ``int4range[]``. Handles binary format. if row.isBinaryCol(col): let (off, clen) = cellInfo(row, col) if clen == -1: @@ -4193,7 +4193,7 @@ proc getInt4RangeArray*(row: Row, col: int): seq[PgRange[int32]] = ) proc getInt8RangeArray*(row: Row, col: int): seq[PgRange[int64]] = - ## Get a column value as an int8range[]. Handles binary format. + ## Get a column value as an ``int8range[]``. Handles binary format. if row.isBinaryCol(col): let (off, clen) = cellInfo(row, col) if clen == -1: @@ -4221,7 +4221,7 @@ proc getInt8RangeArray*(row: Row, col: int): seq[PgRange[int64]] = ) proc getNumRangeArray*(row: Row, col: int): seq[PgRange[PgNumeric]] = - ## Get a column value as a numrange[]. Handles binary format. + ## Get a column value as a ``numrange[]``. Handles binary format. if row.isBinaryCol(col): let (off, clen) = cellInfo(row, col) if clen == -1: @@ -4249,7 +4249,7 @@ proc getNumRangeArray*(row: Row, col: int): seq[PgRange[PgNumeric]] = ) proc getTsRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] = - ## Get a column value as a tsrange[]. Handles binary format. + ## Get a column value as a ``tsrange[]``. Handles binary format. if row.isBinaryCol(col): let (off, clen) = cellInfo(row, col) if clen == -1: @@ -4283,7 +4283,7 @@ proc getTsRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] = ) proc getTsTzRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] = - ## Get a column value as a tstzrange[]. Handles binary format. + ## Get a column value as a ``tstzrange[]``. Handles binary format. if row.isBinaryCol(col): let (off, clen) = cellInfo(row, col) if clen == -1: @@ -4321,7 +4321,7 @@ proc getTsTzRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] = ) proc getDateRangeArray*(row: Row, col: int): seq[PgRange[DateTime]] = - ## Get a column value as a daterange[]. Handles binary format. + ## Get a column value as a ``daterange[]``. Handles binary format. if row.isBinaryCol(col): let (off, clen) = cellInfo(row, col) if clen == -1: diff --git a/tests/test_dsn.nim b/tests/test_dsn.nim index 3595ab4..2b21f93 100644 --- a/tests/test_dsn.nim +++ b/tests/test_dsn.nim @@ -478,3 +478,34 @@ suite "parseDsn keyword=value": check cfg.user == "myuser" check cfg.host == "dbhost" check cfg.port == 5433 + +suite "Unix socket": + test "isUnixSocket": + check isUnixSocket("/var/run/postgresql") == true + check isUnixSocket("/tmp") == true + check isUnixSocket("localhost") == false + check isUnixSocket("127.0.0.1") == false + check isUnixSocket("") == false + + test "unixSocketPath": + check unixSocketPath("/var/run/postgresql", 5432) == + "/var/run/postgresql/.s.PGSQL.5432" + check unixSocketPath("/tmp", 5433) == "/tmp/.s.PGSQL.5433" + + test "key-value DSN with unix socket host": + let cfg = parseDsn("host=/var/run/postgresql port=5432 dbname=test user=myuser") + check cfg.host == "/var/run/postgresql" + check cfg.port == 5432 + check cfg.database == "test" + check cfg.user == "myuser" + + test "key-value DSN with unix socket host default port": + let cfg = parseDsn("host=/tmp dbname=test") + check cfg.host == "/tmp" + check cfg.port == 5432 + + test "URI DSN with unix socket via query param": + let cfg = parseDsn("postgresql:///mydb?host=/var/run/postgresql") + check cfg.host == "/var/run/postgresql" + check cfg.database == "mydb" + check cfg.port == 5432