Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Async PostgreSQL client in Nim.
- Pool cluster with read replica routing
- SSL/TLS support (disable, allow, prefer, require, verify-ca, verify-full)
- MD5, SCRAM-SHA-256 and SCRAM-SHA-256-PLUS authentication
- `channel_binding` policy (disable, prefer, require) to harden SCRAM against downgrade
- DSN connection string parsing
- Unix socket connection
- Multi-host failover
Expand Down
101 changes: 83 additions & 18 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ type
sslVerifyCa ## Require SSL + verify CA chain (no hostname verification)
sslVerifyFull ## Require SSL + verify CA chain and hostname

ChannelBindingMode* = enum
## SCRAM channel binding policy (libpq-compatible).
cbPrefer ## Use SCRAM-SHA-256-PLUS when SSL and server support it (default).
cbDisable ## Never use SCRAM-SHA-256-PLUS; only SCRAM-SHA-256.
cbRequire ## Require SCRAM-SHA-256-PLUS; fail if unavailable.

TargetSessionAttrs* = enum
## Target server type for multi-host failover (libpq compatible).
tsaAny ## Connect to any server (default)
Expand All @@ -91,6 +97,9 @@ type
database*: string
sslMode*: SslMode
sslRootCert*: string ## PEM-encoded CA certificate(s) for sslVerifyCa/sslVerifyFull
channelBinding*: ChannelBindingMode
## SCRAM channel binding policy (default cbPrefer). `cbRequire` fails the
## connection if SCRAM-SHA-256-PLUS cannot actually be used (libpq parity).
applicationName*: string
connectTimeout*: Duration ## TCP connect timeout (default: no timeout)
keepAlive*: bool ## Enable TCP keepalive (default true via parseDsn)
Expand Down Expand Up @@ -1112,6 +1121,57 @@ proc bytesToString(data: seq[byte]): string =
for i in 0 ..< data.len:
result[i] = char(data[i])

proc selectScramMechanism(
sslEnabled: bool,
serverCertDer: openArray[byte],
saslMechanisms: seq[string],
mode: ChannelBindingMode,
): tuple[mechanism: string, cbType: string, cbData: seq[byte]] =
## Pick the SCRAM mechanism and channel-binding material for a SASL
## authentication attempt. Raises `PgConnectionError` when the server-offered
## mechanisms cannot satisfy `mode`.
let serverHasPlus = "SCRAM-SHA-256-PLUS" in saslMechanisms
let serverHasScram = "SCRAM-SHA-256" in saslMechanisms
let canUsePlus = sslEnabled and serverCertDer.len > 0 and serverHasPlus
case mode
of cbRequire:
if not sslEnabled:
raise newException(
PgConnectionError, "channel binding is required, but SSL is not in use"
)
if not serverHasPlus:
raise newException(
PgConnectionError,
"channel binding is required, but server did not offer SCRAM-SHA-256-PLUS",
)
if serverCertDer.len == 0:
raise newException(
PgConnectionError,
"channel binding is required, but server certificate is unavailable",
)
result.mechanism = "SCRAM-SHA-256-PLUS"
result.cbType = "tls-server-end-point"
result.cbData = computeTlsServerEndpoint(serverCertDer)
of cbPrefer:
if canUsePlus:
result.mechanism = "SCRAM-SHA-256-PLUS"
result.cbType = "tls-server-end-point"
result.cbData = computeTlsServerEndpoint(serverCertDer)
elif serverHasScram:
result.mechanism = "SCRAM-SHA-256"
else:
raise newException(
PgConnectionError, "server doesn't support SCRAM-SHA-256 or SCRAM-SHA-256-PLUS"
)
of cbDisable:
if serverHasScram:
result.mechanism = "SCRAM-SHA-256"
else:
raise newException(
PgConnectionError,
"channel binding is disabled, but server only offered SCRAM-SHA-256-PLUS",
)

proc connectToHost(
config: ConnConfig, hostAddr: string, hostPort: int
): Future[PgConnection] {.async.} =
Expand Down Expand Up @@ -1251,24 +1311,14 @@ proc connectToHost(
let hash = md5AuthHash(config.user, config.password, msg.md5Salt)
await conn.sendMsg(encodePassword(hash))
of bmkAuthenticationSASL:
var mechanism: string
var cbType = ""
var cbData: seq[byte]
if conn.sslEnabled and conn.serverCertDer.len > 0 and
"SCRAM-SHA-256-PLUS" in msg.saslMechanisms:
mechanism = "SCRAM-SHA-256-PLUS"
cbType = "tls-server-end-point"
cbData = computeTlsServerEndpoint(conn.serverCertDer)
elif "SCRAM-SHA-256" in msg.saslMechanisms:
mechanism = "SCRAM-SHA-256"
else:
raise newException(
PgConnectionError,
"Server doesn't support SCRAM-SHA-256 or SCRAM-SHA-256-PLUS",
)
let clientFirst =
scramClientFirstMessage(config.user, scramState, cbType, cbData)
await conn.sendMsg(encodeSASLInitialResponse(mechanism, clientFirst))
let choice = selectScramMechanism(
conn.sslEnabled, conn.serverCertDer, msg.saslMechanisms,
config.channelBinding,
)
let clientFirst = scramClientFirstMessage(
config.user, scramState, choice.cbType, choice.cbData
)
await conn.sendMsg(encodeSASLInitialResponse(choice.mechanism, clientFirst))
of bmkAuthenticationSASLContinue:
let clientFinal =
scramClientFinalMessage(config.password, msg.saslData, scramState)
Expand Down Expand Up @@ -1918,6 +1968,17 @@ proc parseSslMode(s: string): SslMode =
else:
raise newException(PgError, "Invalid sslmode: " & s)

proc parseChannelBindingMode(s: string): ChannelBindingMode =
case s
of "disable":
cbDisable
of "prefer":
cbPrefer
of "require":
cbRequire
else:
raise newException(PgError, "Invalid channel_binding: " & s)

proc parseTargetSessionAttrs(s: string): TargetSessionAttrs =
case s
of "any":
Expand Down Expand Up @@ -1958,6 +2019,8 @@ proc applyParam(result: var ConnConfig, key, val: string) =
result.password = val
of "sslmode":
result.sslMode = parseSslMode(val)
of "channel_binding":
result.channelBinding = parseChannelBindingMode(val)
of "application_name":
result.applicationName = val
of "connect_timeout":
Expand Down Expand Up @@ -2187,6 +2250,7 @@ proc initConnConfig*(
database = "",
sslMode = sslDisable,
sslRootCert = "",
channelBinding = cbPrefer,
applicationName = "",
connectTimeout = ZeroDuration,
keepAlive = true,
Expand All @@ -2207,6 +2271,7 @@ proc initConnConfig*(
database: database,
sslMode: sslMode,
sslRootCert: sslRootCert,
channelBinding: channelBinding,
applicationName: applicationName,
connectTimeout: connectTimeout,
keepAlive: keepAlive,
Expand Down
33 changes: 33 additions & 0 deletions tests/test_dsn.nim
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,31 @@ suite "parseDsn":
expect PgError:
discard parseDsn("postgresql://host/db?sslmode=bogus")

test "query param channel_binding":
for mode in ["disable", "prefer", "require"]:
let cfg = parseDsn("postgresql://host/db?channel_binding=" & mode)
case mode
of "disable":
check cfg.channelBinding == cbDisable
of "prefer":
check cfg.channelBinding == cbPrefer
of "require":
check cfg.channelBinding == cbRequire
else:
discard

test "channel_binding default is prefer":
let cfg = parseDsn("postgresql://host/db")
check cfg.channelBinding == cbPrefer

test "ConnConfig zero init has cbPrefer":
let cfg = ConnConfig()
check cfg.channelBinding == cbPrefer

test "error: invalid channel_binding":
expect PgError:
discard parseDsn("postgresql://host/db?channel_binding=bogus")

test "error: invalid connect_timeout":
expect PgError:
discard parseDsn("postgresql://host/db?connect_timeout=abc")
Expand Down Expand Up @@ -459,6 +484,14 @@ suite "parseDsn keyword=value":
expect PgError:
discard parseDsn("host=h sslmode=bogus")

test "channel_binding parameter":
let cfg = parseDsn("host=h channel_binding=require")
check cfg.channelBinding == cbRequire

test "error: invalid channel_binding":
expect PgError:
discard parseDsn("host=h channel_binding=bogus")

test "error: invalid connect_timeout":
expect PgError:
discard parseDsn("host=h connect_timeout=abc")
Expand Down
Loading
Loading