diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index 9802c65..a51014e 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -75,6 +75,15 @@ type cbDisable ## Never use SCRAM-SHA-256-PLUS; only SCRAM-SHA-256. cbRequire ## Require SCRAM-SHA-256-PLUS; fail if unavailable. + AuthMethod* = enum + ## Individual authentication methods for `ConnConfig.requireAuth` + ## allowlisting (libpq `require_auth` parity). + amNone ## AuthenticationOk with no challenge (trust/peer/ident) + amPassword ## cleartext password (libpq: "password") + amMd5 ## MD5 challenge (libpq: "md5") + amScramSha256 ## SASL SCRAM-SHA-256 (libpq: "scram-sha-256") + amScramSha256Plus ## SASL SCRAM-SHA-256-PLUS (libpq: "scram-sha-256-plus") + TargetSessionAttrs* = enum ## Target server type for multi-host failover (libpq compatible). tsaAny ## Connect to any server (default) @@ -100,6 +109,16 @@ type channelBinding*: ChannelBindingMode ## SCRAM channel binding policy (default cbPrefer). `cbRequire` fails the ## connection if SCRAM-SHA-256-PLUS cannot actually be used (libpq parity). + requireAuth*: set[AuthMethod] + ## Allowlist of auth methods the client will accept. An empty set + ## (default) means "allow any" — matching libpq when `require_auth` is + ## unset. If the server requests a method outside this set, connect + ## fails with `PgConnectionError`. For SASL, advertised mechanisms are + ## filtered and the selected mechanism is validated. + ## + ## Note: libpq's `!`-prefix negation syntax (e.g. `!password`) is not + ## yet supported by `parseRequireAuth` — specify the allowed methods + ## positively instead. applicationName*: string connectTimeout*: Duration ## TCP connect timeout (default: no timeout) keepAlive*: bool ## Enable TCP keepalive (default true via parseDsn) @@ -284,6 +303,15 @@ type conn*: PgConnection err*: ref CatchableError + TraceInsecureAuthData* = object + ## Advisory notification that a server-requested auth method is + ## considered insecure in the current transport context. Currently fires + ## for cleartext password over a non-SSL connection. The connection is + ## NOT aborted — use `ConnConfig.requireAuth` for actual enforcement. + conn*: PgConnection + authMethod*: AuthMethod ## The method the server requested + sslEnabled*: bool ## Transport state at the time of the auth step + PgTracer* = ref object ## Tracing hooks for async-postgres operations. ## Set only the callbacks you need; nil callbacks are skipped with zero overhead. @@ -328,6 +356,10 @@ type onPoolReleaseEnd*: proc(ctx: TraceContext, data: TracePoolReleaseEndData) {.gcsafe, raises: [].} onPoolCloseError*: proc(data: TracePoolCloseErrorData) {.gcsafe, raises: [].} + onInsecureAuth*: proc(data: TraceInsecureAuthData) {.gcsafe, raises: [].} + ## Fires when an auth method is used over an insecure transport + ## (currently: cleartext password without SSL). Advisory only; does + ## not abort the connection. Use `ConnConfig.requireAuth` to enforce. # Public API: read-only getters @@ -1125,6 +1157,42 @@ when defined(posix): "TCP keepalive timing options (idle/interval/count) are not supported on this platform and will be ignored" .} +proc fireInsecureAuth(conn: PgConnection, authMethod: AuthMethod) = + let t = conn.config.tracer + if t != nil and t.onInsecureAuth != nil: + t.onInsecureAuth( + TraceInsecureAuthData( + conn: conn, authMethod: authMethod, sslEnabled: conn.sslEnabled + ) + ) + +proc enforceAuthAllowed( + authMethod: AuthMethod, allowed: set[AuthMethod], offered: string = "" +) {.raises: [PgConnectionError].} = + if allowed.len > 0 and authMethod notin allowed: + var msg = + "server requested auth method '" & $authMethod & + "' which is not in require_auth allowlist " & $allowed + if offered.len > 0: + msg.add(" (server offered: ") + msg.add(offered) + msg.add(")") + raise newException(PgConnectionError, msg) + +proc filterSaslByRequireAuth*( + mechs: seq[string], allowed: set[AuthMethod] +): seq[string] = + ## Filter a server-offered SASL mechanism list by the client's + ## `requireAuth` policy. An empty `allowed` set performs no filtering + ## (matching libpq semantics when `require_auth` is unset). + if allowed.len == 0: + return mechs + for m in mechs: + if m == "SCRAM-SHA-256-PLUS" and amScramSha256Plus in allowed: + result.add(m) + elif m == "SCRAM-SHA-256" and amScramSha256 in allowed: + result.add(m) + proc bytesToString(data: seq[byte]): string = result = newString(data.len) for i in 0 ..< data.len: @@ -1307,23 +1375,51 @@ proc connectToHost( # Authentication loop var scramState: ScramState + var sawAuthRequest = false block authLoop: while true: while (let opt = conn.nextMessage(); opt.isSome): let msg = opt.get case msg.kind of bmkAuthenticationOk: + if not sawAuthRequest: + enforceAuthAllowed(amNone, config.requireAuth) break authLoop of bmkAuthenticationCleartextPassword: + sawAuthRequest = true + if not conn.sslEnabled: + fireInsecureAuth(conn, amPassword) + enforceAuthAllowed(amPassword, config.requireAuth) await conn.sendMsg(encodePassword(config.password)) of bmkAuthenticationMD5Password: + sawAuthRequest = true + enforceAuthAllowed(amMd5, config.requireAuth) let hash = md5AuthHash(config.user, config.password, msg.md5Salt) await conn.sendMsg(encodePassword(hash)) of bmkAuthenticationSASL: + sawAuthRequest = true + let filtered = + filterSaslByRequireAuth(msg.saslMechanisms, config.requireAuth) + if config.requireAuth.len > 0 and filtered.len == 0: + raise newException( + PgConnectionError, + "server offered SASL mechanisms " & $msg.saslMechanisms & + " but none match require_auth allowlist " & $config.requireAuth, + ) let choice = selectScramMechanism( - conn.sslEnabled, conn.serverCertDer, msg.saslMechanisms, - config.channelBinding, + conn.sslEnabled, conn.serverCertDer, filtered, config.channelBinding ) + let chosen = + if choice.mechanism == "SCRAM-SHA-256-PLUS": + amScramSha256Plus + else: + amScramSha256 + # Defensive: filterSaslByRequireAuth above already dropped + # disallowed mechanisms, so `chosen` is guaranteed allowed. This + # re-check guards against future changes to selectScramMechanism + # introducing a bypass (e.g. a fallback that reaches past the + # filtered list). + enforceAuthAllowed(chosen, config.requireAuth, $msg.saslMechanisms) let clientFirst = scramClientFirstMessage( config.user, scramState, choice.cbType, choice.cbData ) @@ -1988,6 +2084,33 @@ proc parseChannelBindingMode(s: string): ChannelBindingMode = else: raise newException(PgError, "Invalid channel_binding: " & s) +proc parseAuthMethod(s: string): AuthMethod = + case s + of "none": + amNone + of "password": + amPassword + of "md5": + amMd5 + of "scram-sha-256": + amScramSha256 + of "scram-sha-256-plus": + amScramSha256Plus + else: + raise newException(PgError, "Invalid require_auth method: " & s) + +proc parseRequireAuth*(s: string): set[AuthMethod] = + ## Parse a comma-separated list of auth method names into a set + ## (libpq `require_auth` syntax; negation prefix `!` is not yet supported). + ## Empty input returns the empty set (allow any). + if s.len == 0: + return {} + for raw in s.split(','): + let tok = raw.strip() + if tok.len == 0: + raise newException(PgError, "Empty entry in require_auth list: " & s) + result.incl(parseAuthMethod(tok)) + proc parseTargetSessionAttrs(s: string): TargetSessionAttrs = case s of "any": @@ -2030,6 +2153,8 @@ proc applyParam(result: var ConnConfig, key, val: string) = result.sslMode = parseSslMode(val) of "channel_binding": result.channelBinding = parseChannelBindingMode(val) + of "require_auth": + result.requireAuth = parseRequireAuth(val) of "application_name": result.applicationName = val of "connect_timeout": @@ -2268,6 +2393,7 @@ proc initConnConfig*( keepAliveCount = 0, hosts: seq[HostEntry] = @[], targetSessionAttrs = tsaAny, + requireAuth: set[AuthMethod] = {}, extraParams: seq[(string, string)] = @[], ): ConnConfig = ## Create a connection configuration with sensible defaults. @@ -2289,6 +2415,7 @@ proc initConnConfig*( keepAliveCount: keepAliveCount, hosts: hosts, targetSessionAttrs: targetSessionAttrs, + requireAuth: requireAuth, extraParams: extraParams, ) diff --git a/tests/test_dsn.nim b/tests/test_dsn.nim index f98adcb..cfc7829 100644 --- a/tests/test_dsn.nim +++ b/tests/test_dsn.nim @@ -173,6 +173,47 @@ suite "parseDsn": expect PgError: discard parseDsn("postgresql://host/db?channel_binding=bogus") + test "require_auth default is empty set": + let cfg = parseDsn("postgresql://host/db") + check cfg.requireAuth == {} + + test "ConnConfig zero init has empty requireAuth": + let cfg = ConnConfig() + check cfg.requireAuth == {} + + test "query param require_auth single value": + let cases = { + "none": amNone, + "password": amPassword, + "md5": amMd5, + "scram-sha-256": amScramSha256, + "scram-sha-256-plus": amScramSha256Plus, + } + for (name, expected) in cases: + let cfg = parseDsn("postgresql://host/db?require_auth=" & name) + check cfg.requireAuth == {expected} + + test "query param require_auth comma list": + let cfg = + parseDsn("postgresql://host/db?require_auth=scram-sha-256,scram-sha-256-plus") + check cfg.requireAuth == {amScramSha256, amScramSha256Plus} + + test "query param require_auth tolerates whitespace": + let cfg = parseDsn("postgresql://host/db?require_auth=scram-sha-256,%20md5") + check cfg.requireAuth == {amScramSha256, amMd5} + + test "keyword=value form require_auth": + let cfg = parseDsn("host=127.0.0.1 require_auth=md5,password") + check cfg.requireAuth == {amMd5, amPassword} + + test "error: unknown require_auth method": + expect PgError: + discard parseDsn("postgresql://host/db?require_auth=sha1") + + test "error: empty entry in require_auth list": + expect PgError: + discard parseDsn("postgresql://host/db?require_auth=md5,,password") + test "error: invalid connect_timeout": expect PgError: discard parseDsn("postgresql://host/db?connect_timeout=abc") diff --git a/tests/test_tracing.nim b/tests/test_tracing.nim index 9306e0c..93d4e1e 100644 --- a/tests/test_tracing.nim +++ b/tests/test_tracing.nim @@ -114,6 +114,11 @@ type hasConn: bool errMsg: string + InsecureAuthRec = object + hasConn: bool + authMethod: AuthMethod + sslEnabled: bool + TraceLog = ref object connectStarts: seq[ConnectStartRec] connectEnds: seq[ConnectEndRec] @@ -130,6 +135,7 @@ type poolReleaseStarts: seq[PoolReleaseStartRec] poolReleaseEnds: seq[PoolReleaseEndRec] poolCloseErrors: seq[PoolCloseErrorRec] + insecureAuths: seq[InsecureAuthRec] proc newTraceLog(): TraceLog = TraceLog() @@ -261,6 +267,15 @@ proc buildTracer(log: TraceLog): PgTracer = ) ) + tracer.onInsecureAuth = proc(data: TraceInsecureAuthData) {.gcsafe, raises: [].} = + log.insecureAuths.add( + InsecureAuthRec( + hasConn: data.conn != nil, + authMethod: data.authMethod, + sslEnabled: data.sslEnabled, + ) + ) + return tracer proc tracedConfig(tracer: PgTracer): ConnConfig = @@ -821,3 +836,99 @@ suite "Tracing: nil tracer": await conn.close() waitFor t() + +suite "Tracing: onInsecureAuth": + # These unit tests exercise the tracer closure directly via the public + # hook. End-to-end firing from the auth loop (cleartext over plaintext) + # requires a PG server configured to request `password` auth, which is + # out of scope for the current docker-compose setup. + test "closure receives method and transport state (plaintext)": + let log = newTraceLog() + let tracer = buildTracer(log) + let fake = PgConnection() + + tracer.onInsecureAuth( + TraceInsecureAuthData(conn: fake, authMethod: amPassword, sslEnabled: false) + ) + + check log.insecureAuths.len == 1 + check log.insecureAuths[0].hasConn + check log.insecureAuths[0].authMethod == amPassword + check log.insecureAuths[0].sslEnabled == false + + test "closure receives sslEnabled=true": + let log = newTraceLog() + let tracer = buildTracer(log) + let fake = PgConnection() + + tracer.onInsecureAuth( + TraceInsecureAuthData(conn: fake, authMethod: amPassword, sslEnabled: true) + ) + + check log.insecureAuths.len == 1 + check log.insecureAuths[0].sslEnabled == true + + test "PgTracer with nil onInsecureAuth hook is safe to leave unset": + let tracer = PgTracer() + check tracer.onInsecureAuth == nil + +suite "filterSaslByRequireAuth": + test "empty allowed set performs no filtering": + let mechs = @["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"] + check filterSaslByRequireAuth(mechs, {}) == mechs + + test "drops PLUS when only SCRAM allowed": + let mechs = @["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"] + check filterSaslByRequireAuth(mechs, {amScramSha256}) == @["SCRAM-SHA-256"] + + test "drops SCRAM when only PLUS allowed": + let mechs = @["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"] + check filterSaslByRequireAuth(mechs, {amScramSha256Plus}) == @["SCRAM-SHA-256-PLUS"] + + test "empty result when nothing matches": + let mechs = @["SCRAM-SHA-256"] + check filterSaslByRequireAuth(mechs, {amMd5}).len == 0 + + test "unrelated non-SCRAM methods in allowlist do not add mechanisms": + let mechs = @["SCRAM-SHA-256"] + check filterSaslByRequireAuth(mechs, {amScramSha256, amMd5, amPassword}) == + @["SCRAM-SHA-256"] + +suite "Tracing: requireAuth happy path": + test "accepts SCRAM when explicitly allowed": + proc t() {.async.} = + var cfg = plainConfig() + cfg.requireAuth = {amScramSha256, amScramSha256Plus} + let conn = await connect(cfg) + doAssert conn != nil + discard await conn.exec("SELECT 1") + await conn.close() + + waitFor t() + +suite "Tracing: requireAuth negative path": + # These tests require the docker-compose PG to use SCRAM auth (the default + # for modern postgres images). If the server ever switches to trust/md5, + # the expected error message will differ but the connect must still fail. + proc connectRaises(requireAuth: set[AuthMethod]): bool = + var raised = false + proc t() {.async.} = + var cfg = plainConfig() + cfg.requireAuth = requireAuth + try: + let conn = await connect(cfg) + await conn.close() + except PgConnectionError: + raised = true + + waitFor t() + raised + + test "rejects when only md5 is allowed against SCRAM server": + check connectRaises({amMd5}) + + test "rejects when only password is allowed against SCRAM server": + check connectRaises({amPassword}) + + test "rejects when only amNone is allowed against SCRAM server": + check connectRaises({amNone})