Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 129 additions & 2 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ type
cbDisable ## Never use SCRAM-SHA-256-PLUS; only SCRAM-SHA-256.
cbRequire ## Require SCRAM-SHA-256-PLUS; fail if unavailable.

AuthMethod* = enum
## Individual authentication methods for `ConnConfig.requireAuth`
## allowlisting (libpq `require_auth` parity).
amNone ## AuthenticationOk with no challenge (trust/peer/ident)
amPassword ## cleartext password (libpq: "password")
amMd5 ## MD5 challenge (libpq: "md5")
amScramSha256 ## SASL SCRAM-SHA-256 (libpq: "scram-sha-256")
amScramSha256Plus ## SASL SCRAM-SHA-256-PLUS (libpq: "scram-sha-256-plus")

TargetSessionAttrs* = enum
## Target server type for multi-host failover (libpq compatible).
tsaAny ## Connect to any server (default)
Expand All @@ -100,6 +109,16 @@ type
channelBinding*: ChannelBindingMode
## SCRAM channel binding policy (default cbPrefer). `cbRequire` fails the
## connection if SCRAM-SHA-256-PLUS cannot actually be used (libpq parity).
requireAuth*: set[AuthMethod]
## Allowlist of auth methods the client will accept. An empty set
## (default) means "allow any" — matching libpq when `require_auth` is
## unset. If the server requests a method outside this set, connect
## fails with `PgConnectionError`. For SASL, advertised mechanisms are
## filtered and the selected mechanism is validated.
##
## Note: libpq's `!`-prefix negation syntax (e.g. `!password`) is not
## yet supported by `parseRequireAuth` — specify the allowed methods
## positively instead.
applicationName*: string
connectTimeout*: Duration ## TCP connect timeout (default: no timeout)
keepAlive*: bool ## Enable TCP keepalive (default true via parseDsn)
Expand Down Expand Up @@ -284,6 +303,15 @@ type
conn*: PgConnection
err*: ref CatchableError

TraceInsecureAuthData* = object
## Advisory notification that a server-requested auth method is
## considered insecure in the current transport context. Currently fires
## for cleartext password over a non-SSL connection. The connection is
## NOT aborted — use `ConnConfig.requireAuth` for actual enforcement.
conn*: PgConnection
authMethod*: AuthMethod ## The method the server requested
sslEnabled*: bool ## Transport state at the time of the auth step

PgTracer* = ref object
## Tracing hooks for async-postgres operations.
## Set only the callbacks you need; nil callbacks are skipped with zero overhead.
Expand Down Expand Up @@ -328,6 +356,10 @@ type
onPoolReleaseEnd*:
proc(ctx: TraceContext, data: TracePoolReleaseEndData) {.gcsafe, raises: [].}
onPoolCloseError*: proc(data: TracePoolCloseErrorData) {.gcsafe, raises: [].}
onInsecureAuth*: proc(data: TraceInsecureAuthData) {.gcsafe, raises: [].}
## Fires when an auth method is used over an insecure transport
## (currently: cleartext password without SSL). Advisory only; does
## not abort the connection. Use `ConnConfig.requireAuth` to enforce.

# Public API: read-only getters

Expand Down Expand Up @@ -1125,6 +1157,42 @@ when defined(posix):
"TCP keepalive timing options (idle/interval/count) are not supported on this platform and will be ignored"
.}

proc fireInsecureAuth(conn: PgConnection, authMethod: AuthMethod) =
let t = conn.config.tracer
if t != nil and t.onInsecureAuth != nil:
t.onInsecureAuth(
TraceInsecureAuthData(
conn: conn, authMethod: authMethod, sslEnabled: conn.sslEnabled
)
)

proc enforceAuthAllowed(
authMethod: AuthMethod, allowed: set[AuthMethod], offered: string = ""
) {.raises: [PgConnectionError].} =
if allowed.len > 0 and authMethod notin allowed:
var msg =
"server requested auth method '" & $authMethod &
"' which is not in require_auth allowlist " & $allowed
if offered.len > 0:
msg.add(" (server offered: ")
msg.add(offered)
msg.add(")")
raise newException(PgConnectionError, msg)

proc filterSaslByRequireAuth*(
mechs: seq[string], allowed: set[AuthMethod]
): seq[string] =
## Filter a server-offered SASL mechanism list by the client's
## `requireAuth` policy. An empty `allowed` set performs no filtering
## (matching libpq semantics when `require_auth` is unset).
if allowed.len == 0:
return mechs
for m in mechs:
if m == "SCRAM-SHA-256-PLUS" and amScramSha256Plus in allowed:
result.add(m)
elif m == "SCRAM-SHA-256" and amScramSha256 in allowed:
result.add(m)

proc bytesToString(data: seq[byte]): string =
result = newString(data.len)
for i in 0 ..< data.len:
Expand Down Expand Up @@ -1307,23 +1375,51 @@ proc connectToHost(

# Authentication loop
var scramState: ScramState
var sawAuthRequest = false
block authLoop:
while true:
while (let opt = conn.nextMessage(); opt.isSome):
let msg = opt.get
case msg.kind
of bmkAuthenticationOk:
if not sawAuthRequest:
enforceAuthAllowed(amNone, config.requireAuth)
break authLoop
of bmkAuthenticationCleartextPassword:
sawAuthRequest = true
if not conn.sslEnabled:
fireInsecureAuth(conn, amPassword)
enforceAuthAllowed(amPassword, config.requireAuth)
await conn.sendMsg(encodePassword(config.password))
of bmkAuthenticationMD5Password:
sawAuthRequest = true
enforceAuthAllowed(amMd5, config.requireAuth)
let hash = md5AuthHash(config.user, config.password, msg.md5Salt)
await conn.sendMsg(encodePassword(hash))
of bmkAuthenticationSASL:
sawAuthRequest = true
let filtered =
filterSaslByRequireAuth(msg.saslMechanisms, config.requireAuth)
if config.requireAuth.len > 0 and filtered.len == 0:
raise newException(
PgConnectionError,
"server offered SASL mechanisms " & $msg.saslMechanisms &
" but none match require_auth allowlist " & $config.requireAuth,
)
let choice = selectScramMechanism(
conn.sslEnabled, conn.serverCertDer, msg.saslMechanisms,
config.channelBinding,
conn.sslEnabled, conn.serverCertDer, filtered, config.channelBinding
)
let chosen =
if choice.mechanism == "SCRAM-SHA-256-PLUS":
amScramSha256Plus
else:
amScramSha256
# Defensive: filterSaslByRequireAuth above already dropped
# disallowed mechanisms, so `chosen` is guaranteed allowed. This
# re-check guards against future changes to selectScramMechanism
# introducing a bypass (e.g. a fallback that reaches past the
# filtered list).
enforceAuthAllowed(chosen, config.requireAuth, $msg.saslMechanisms)
let clientFirst = scramClientFirstMessage(
config.user, scramState, choice.cbType, choice.cbData
)
Expand Down Expand Up @@ -1988,6 +2084,33 @@ proc parseChannelBindingMode(s: string): ChannelBindingMode =
else:
raise newException(PgError, "Invalid channel_binding: " & s)

proc parseAuthMethod(s: string): AuthMethod =
case s
of "none":
amNone
of "password":
amPassword
of "md5":
amMd5
of "scram-sha-256":
amScramSha256
of "scram-sha-256-plus":
amScramSha256Plus
else:
raise newException(PgError, "Invalid require_auth method: " & s)

proc parseRequireAuth*(s: string): set[AuthMethod] =
## Parse a comma-separated list of auth method names into a set
## (libpq `require_auth` syntax; negation prefix `!` is not yet supported).
## Empty input returns the empty set (allow any).
if s.len == 0:
return {}
for raw in s.split(','):
let tok = raw.strip()
if tok.len == 0:
raise newException(PgError, "Empty entry in require_auth list: " & s)
result.incl(parseAuthMethod(tok))

proc parseTargetSessionAttrs(s: string): TargetSessionAttrs =
case s
of "any":
Expand Down Expand Up @@ -2030,6 +2153,8 @@ proc applyParam(result: var ConnConfig, key, val: string) =
result.sslMode = parseSslMode(val)
of "channel_binding":
result.channelBinding = parseChannelBindingMode(val)
of "require_auth":
result.requireAuth = parseRequireAuth(val)
of "application_name":
result.applicationName = val
of "connect_timeout":
Expand Down Expand Up @@ -2268,6 +2393,7 @@ proc initConnConfig*(
keepAliveCount = 0,
hosts: seq[HostEntry] = @[],
targetSessionAttrs = tsaAny,
requireAuth: set[AuthMethod] = {},
extraParams: seq[(string, string)] = @[],
): ConnConfig =
## Create a connection configuration with sensible defaults.
Expand All @@ -2289,6 +2415,7 @@ proc initConnConfig*(
keepAliveCount: keepAliveCount,
hosts: hosts,
targetSessionAttrs: targetSessionAttrs,
requireAuth: requireAuth,
extraParams: extraParams,
)

Expand Down
41 changes: 41 additions & 0 deletions tests/test_dsn.nim
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,47 @@ suite "parseDsn":
expect PgError:
discard parseDsn("postgresql://host/db?channel_binding=bogus")

test "require_auth default is empty set":
let cfg = parseDsn("postgresql://host/db")
check cfg.requireAuth == {}

test "ConnConfig zero init has empty requireAuth":
let cfg = ConnConfig()
check cfg.requireAuth == {}

test "query param require_auth single value":
let cases = {
"none": amNone,
"password": amPassword,
"md5": amMd5,
"scram-sha-256": amScramSha256,
"scram-sha-256-plus": amScramSha256Plus,
}
for (name, expected) in cases:
let cfg = parseDsn("postgresql://host/db?require_auth=" & name)
check cfg.requireAuth == {expected}

test "query param require_auth comma list":
let cfg =
parseDsn("postgresql://host/db?require_auth=scram-sha-256,scram-sha-256-plus")
check cfg.requireAuth == {amScramSha256, amScramSha256Plus}

test "query param require_auth tolerates whitespace":
let cfg = parseDsn("postgresql://host/db?require_auth=scram-sha-256,%20md5")
check cfg.requireAuth == {amScramSha256, amMd5}

test "keyword=value form require_auth":
let cfg = parseDsn("host=127.0.0.1 require_auth=md5,password")
check cfg.requireAuth == {amMd5, amPassword}

test "error: unknown require_auth method":
expect PgError:
discard parseDsn("postgresql://host/db?require_auth=sha1")

test "error: empty entry in require_auth list":
expect PgError:
discard parseDsn("postgresql://host/db?require_auth=md5,,password")

test "error: invalid connect_timeout":
expect PgError:
discard parseDsn("postgresql://host/db?connect_timeout=abc")
Expand Down
111 changes: 111 additions & 0 deletions tests/test_tracing.nim
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ type
hasConn: bool
errMsg: string

InsecureAuthRec = object
hasConn: bool
authMethod: AuthMethod
sslEnabled: bool

TraceLog = ref object
connectStarts: seq[ConnectStartRec]
connectEnds: seq[ConnectEndRec]
Expand All @@ -130,6 +135,7 @@ type
poolReleaseStarts: seq[PoolReleaseStartRec]
poolReleaseEnds: seq[PoolReleaseEndRec]
poolCloseErrors: seq[PoolCloseErrorRec]
insecureAuths: seq[InsecureAuthRec]

proc newTraceLog(): TraceLog =
TraceLog()
Expand Down Expand Up @@ -261,6 +267,15 @@ proc buildTracer(log: TraceLog): PgTracer =
)
)

tracer.onInsecureAuth = proc(data: TraceInsecureAuthData) {.gcsafe, raises: [].} =
log.insecureAuths.add(
InsecureAuthRec(
hasConn: data.conn != nil,
authMethod: data.authMethod,
sslEnabled: data.sslEnabled,
)
)

return tracer

proc tracedConfig(tracer: PgTracer): ConnConfig =
Expand Down Expand Up @@ -821,3 +836,99 @@ suite "Tracing: nil tracer":
await conn.close()

waitFor t()

suite "Tracing: onInsecureAuth":
# These unit tests exercise the tracer closure directly via the public
# hook. End-to-end firing from the auth loop (cleartext over plaintext)
# requires a PG server configured to request `password` auth, which is
# out of scope for the current docker-compose setup.
test "closure receives method and transport state (plaintext)":
let log = newTraceLog()
let tracer = buildTracer(log)
let fake = PgConnection()

tracer.onInsecureAuth(
TraceInsecureAuthData(conn: fake, authMethod: amPassword, sslEnabled: false)
)

check log.insecureAuths.len == 1
check log.insecureAuths[0].hasConn
check log.insecureAuths[0].authMethod == amPassword
check log.insecureAuths[0].sslEnabled == false

test "closure receives sslEnabled=true":
let log = newTraceLog()
let tracer = buildTracer(log)
let fake = PgConnection()

tracer.onInsecureAuth(
TraceInsecureAuthData(conn: fake, authMethod: amPassword, sslEnabled: true)
)

check log.insecureAuths.len == 1
check log.insecureAuths[0].sslEnabled == true

test "PgTracer with nil onInsecureAuth hook is safe to leave unset":
let tracer = PgTracer()
check tracer.onInsecureAuth == nil

suite "filterSaslByRequireAuth":
test "empty allowed set performs no filtering":
let mechs = @["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]
check filterSaslByRequireAuth(mechs, {}) == mechs

test "drops PLUS when only SCRAM allowed":
let mechs = @["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]
check filterSaslByRequireAuth(mechs, {amScramSha256}) == @["SCRAM-SHA-256"]

test "drops SCRAM when only PLUS allowed":
let mechs = @["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]
check filterSaslByRequireAuth(mechs, {amScramSha256Plus}) == @["SCRAM-SHA-256-PLUS"]

test "empty result when nothing matches":
let mechs = @["SCRAM-SHA-256"]
check filterSaslByRequireAuth(mechs, {amMd5}).len == 0

test "unrelated non-SCRAM methods in allowlist do not add mechanisms":
let mechs = @["SCRAM-SHA-256"]
check filterSaslByRequireAuth(mechs, {amScramSha256, amMd5, amPassword}) ==
@["SCRAM-SHA-256"]

suite "Tracing: requireAuth happy path":
test "accepts SCRAM when explicitly allowed":
proc t() {.async.} =
var cfg = plainConfig()
cfg.requireAuth = {amScramSha256, amScramSha256Plus}
let conn = await connect(cfg)
doAssert conn != nil
discard await conn.exec("SELECT 1")
await conn.close()

waitFor t()

suite "Tracing: requireAuth negative path":
# These tests require the docker-compose PG to use SCRAM auth (the default
# for modern postgres images). If the server ever switches to trust/md5,
# the expected error message will differ but the connect must still fail.
proc connectRaises(requireAuth: set[AuthMethod]): bool =
var raised = false
proc t() {.async.} =
var cfg = plainConfig()
cfg.requireAuth = requireAuth
try:
let conn = await connect(cfg)
await conn.close()
except PgConnectionError:
raised = true

waitFor t()
raised

test "rejects when only md5 is allowed against SCRAM server":
check connectRaises({amMd5})

test "rejects when only password is allowed against SCRAM server":
check connectRaises({amPassword})

test "rejects when only amNone is allowed against SCRAM server":
check connectRaises({amNone})
Loading