From 99a8b3d90bf974086ab64443cdfdf3523acd46ae Mon Sep 17 00:00:00 2001 From: fox0430 Date: Mon, 20 Apr 2026 17:09:00 +0900 Subject: [PATCH] Add channelBinding mode to reject SCRAM-SHA-256-PLUS downgrade --- README.md | 1 + async_postgres/pg_connection.nim | 101 +++++++++-- tests/test_dsn.nim | 33 ++++ tests/test_ssl.nim | 297 ++++++++++++++++++++++++++++++- 4 files changed, 413 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index d35daf0..03fa648 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Async PostgreSQL client in Nim. - Pool cluster with read replica routing - SSL/TLS support (disable, allow, prefer, require, verify-ca, verify-full) - MD5, SCRAM-SHA-256 and SCRAM-SHA-256-PLUS authentication +- `channel_binding` policy (disable, prefer, require) to harden SCRAM against downgrade - DSN connection string parsing - Unix socket connection - Multi-host failover diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index 0d0b164..6ef2d0e 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -69,6 +69,12 @@ type sslVerifyCa ## Require SSL + verify CA chain (no hostname verification) sslVerifyFull ## Require SSL + verify CA chain and hostname + ChannelBindingMode* = enum + ## SCRAM channel binding policy (libpq-compatible). + cbPrefer ## Use SCRAM-SHA-256-PLUS when SSL and server support it (default). + cbDisable ## Never use SCRAM-SHA-256-PLUS; only SCRAM-SHA-256. + cbRequire ## Require SCRAM-SHA-256-PLUS; fail if unavailable. + TargetSessionAttrs* = enum ## Target server type for multi-host failover (libpq compatible). tsaAny ## Connect to any server (default) @@ -91,6 +97,9 @@ type database*: string sslMode*: SslMode sslRootCert*: string ## PEM-encoded CA certificate(s) for sslVerifyCa/sslVerifyFull + channelBinding*: ChannelBindingMode + ## SCRAM channel binding policy (default cbPrefer). `cbRequire` fails the + ## connection if SCRAM-SHA-256-PLUS cannot actually be used (libpq parity). applicationName*: string connectTimeout*: Duration ## TCP connect timeout (default: no timeout) keepAlive*: bool ## Enable TCP keepalive (default true via parseDsn) @@ -1112,6 +1121,57 @@ proc bytesToString(data: seq[byte]): string = for i in 0 ..< data.len: result[i] = char(data[i]) +proc selectScramMechanism( + sslEnabled: bool, + serverCertDer: openArray[byte], + saslMechanisms: seq[string], + mode: ChannelBindingMode, +): tuple[mechanism: string, cbType: string, cbData: seq[byte]] = + ## Pick the SCRAM mechanism and channel-binding material for a SASL + ## authentication attempt. Raises `PgConnectionError` when the server-offered + ## mechanisms cannot satisfy `mode`. + let serverHasPlus = "SCRAM-SHA-256-PLUS" in saslMechanisms + let serverHasScram = "SCRAM-SHA-256" in saslMechanisms + let canUsePlus = sslEnabled and serverCertDer.len > 0 and serverHasPlus + case mode + of cbRequire: + if not sslEnabled: + raise newException( + PgConnectionError, "channel binding is required, but SSL is not in use" + ) + if not serverHasPlus: + raise newException( + PgConnectionError, + "channel binding is required, but server did not offer SCRAM-SHA-256-PLUS", + ) + if serverCertDer.len == 0: + raise newException( + PgConnectionError, + "channel binding is required, but server certificate is unavailable", + ) + result.mechanism = "SCRAM-SHA-256-PLUS" + result.cbType = "tls-server-end-point" + result.cbData = computeTlsServerEndpoint(serverCertDer) + of cbPrefer: + if canUsePlus: + result.mechanism = "SCRAM-SHA-256-PLUS" + result.cbType = "tls-server-end-point" + result.cbData = computeTlsServerEndpoint(serverCertDer) + elif serverHasScram: + result.mechanism = "SCRAM-SHA-256" + else: + raise newException( + PgConnectionError, "server doesn't support SCRAM-SHA-256 or SCRAM-SHA-256-PLUS" + ) + of cbDisable: + if serverHasScram: + result.mechanism = "SCRAM-SHA-256" + else: + raise newException( + PgConnectionError, + "channel binding is disabled, but server only offered SCRAM-SHA-256-PLUS", + ) + proc connectToHost( config: ConnConfig, hostAddr: string, hostPort: int ): Future[PgConnection] {.async.} = @@ -1251,24 +1311,14 @@ proc connectToHost( let hash = md5AuthHash(config.user, config.password, msg.md5Salt) await conn.sendMsg(encodePassword(hash)) of bmkAuthenticationSASL: - var mechanism: string - var cbType = "" - var cbData: seq[byte] - if conn.sslEnabled and conn.serverCertDer.len > 0 and - "SCRAM-SHA-256-PLUS" in msg.saslMechanisms: - mechanism = "SCRAM-SHA-256-PLUS" - cbType = "tls-server-end-point" - cbData = computeTlsServerEndpoint(conn.serverCertDer) - elif "SCRAM-SHA-256" in msg.saslMechanisms: - mechanism = "SCRAM-SHA-256" - else: - raise newException( - PgConnectionError, - "Server doesn't support SCRAM-SHA-256 or SCRAM-SHA-256-PLUS", - ) - let clientFirst = - scramClientFirstMessage(config.user, scramState, cbType, cbData) - await conn.sendMsg(encodeSASLInitialResponse(mechanism, clientFirst)) + let choice = selectScramMechanism( + conn.sslEnabled, conn.serverCertDer, msg.saslMechanisms, + config.channelBinding, + ) + let clientFirst = scramClientFirstMessage( + config.user, scramState, choice.cbType, choice.cbData + ) + await conn.sendMsg(encodeSASLInitialResponse(choice.mechanism, clientFirst)) of bmkAuthenticationSASLContinue: let clientFinal = scramClientFinalMessage(config.password, msg.saslData, scramState) @@ -1918,6 +1968,17 @@ proc parseSslMode(s: string): SslMode = else: raise newException(PgError, "Invalid sslmode: " & s) +proc parseChannelBindingMode(s: string): ChannelBindingMode = + case s + of "disable": + cbDisable + of "prefer": + cbPrefer + of "require": + cbRequire + else: + raise newException(PgError, "Invalid channel_binding: " & s) + proc parseTargetSessionAttrs(s: string): TargetSessionAttrs = case s of "any": @@ -1958,6 +2019,8 @@ proc applyParam(result: var ConnConfig, key, val: string) = result.password = val of "sslmode": result.sslMode = parseSslMode(val) + of "channel_binding": + result.channelBinding = parseChannelBindingMode(val) of "application_name": result.applicationName = val of "connect_timeout": @@ -2187,6 +2250,7 @@ proc initConnConfig*( database = "", sslMode = sslDisable, sslRootCert = "", + channelBinding = cbPrefer, applicationName = "", connectTimeout = ZeroDuration, keepAlive = true, @@ -2207,6 +2271,7 @@ proc initConnConfig*( database: database, sslMode: sslMode, sslRootCert: sslRootCert, + channelBinding: channelBinding, applicationName: applicationName, connectTimeout: connectTimeout, keepAlive: keepAlive, diff --git a/tests/test_dsn.nim b/tests/test_dsn.nim index 210d47a..f98adcb 100644 --- a/tests/test_dsn.nim +++ b/tests/test_dsn.nim @@ -148,6 +148,31 @@ suite "parseDsn": expect PgError: discard parseDsn("postgresql://host/db?sslmode=bogus") + test "query param channel_binding": + for mode in ["disable", "prefer", "require"]: + let cfg = parseDsn("postgresql://host/db?channel_binding=" & mode) + case mode + of "disable": + check cfg.channelBinding == cbDisable + of "prefer": + check cfg.channelBinding == cbPrefer + of "require": + check cfg.channelBinding == cbRequire + else: + discard + + test "channel_binding default is prefer": + let cfg = parseDsn("postgresql://host/db") + check cfg.channelBinding == cbPrefer + + test "ConnConfig zero init has cbPrefer": + let cfg = ConnConfig() + check cfg.channelBinding == cbPrefer + + test "error: invalid channel_binding": + expect PgError: + discard parseDsn("postgresql://host/db?channel_binding=bogus") + test "error: invalid connect_timeout": expect PgError: discard parseDsn("postgresql://host/db?connect_timeout=abc") @@ -459,6 +484,14 @@ suite "parseDsn keyword=value": expect PgError: discard parseDsn("host=h sslmode=bogus") + test "channel_binding parameter": + let cfg = parseDsn("host=h channel_binding=require") + check cfg.channelBinding == cbRequire + + test "error: invalid channel_binding": + expect PgError: + discard parseDsn("host=h channel_binding=bogus") + test "error: invalid connect_timeout": expect PgError: discard parseDsn("host=h connect_timeout=abc") diff --git a/tests/test_ssl.nim b/tests/test_ssl.nim index 4c072d6..59cabd4 100644 --- a/tests/test_ssl.nim +++ b/tests/test_ssl.nim @@ -1,6 +1,8 @@ import std/[unittest, strutils] -import ../async_postgres/[async_backend, pg_protocol, pg_connection] +import ../async_postgres/[async_backend, pg_protocol] + +import ../async_postgres/pg_connection {.all.} when hasAsyncDispatch: import std/asyncnet @@ -545,3 +547,296 @@ suite "SSL negotiation - sslDisable": check firstMsgVersion == 196608'i32 check connState == csReady check connSslEnabled == false + +proc sendAuthSasl(client: MockClient, mechanisms: seq[string]): Future[void] {.async.} = + var body: seq[byte] = @[] + body.addInt32(10) # AuthenticationSASL + for m in mechanisms: + body.addCString(m) + body.add(0'u8) # terminator + await sendBytes(client, buildBackendMsg('R', body)) + +proc readSaslInitialResponseMechanism(client: MockClient): Future[string] {.async.} = + ## Read a frontend 'p' message (SASLInitialResponse) and return the mechanism name. + discard await readN(client, 1) # 'p' message type + let lenBuf = await readN(client, 4) + let msgLen = decodeInt32(lenBuf, 0) + let body = await readN(client, msgLen - 4) + result = "" + var i = 0 + while i < body.len and body[i] != 0: + result.add(char(body[i])) + inc i + +suite "SCRAM channel binding enforcement": + test "cbRequire without SSL raises PgError": + var raised = false + var msgMatches = false + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + await drainStartupMessage(st) + await sendAuthSasl(st, @["SCRAM-SHA-256"]) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + password: "test", + database: "test", + sslMode: sslDisable, + channelBinding: cbRequire, + ) + + try: + let conn = await connect(config) + await conn.close() + except PgError as e: + raised = true + msgMatches = "SSL is not in use" in e.msg + + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + check msgMatches + + test "cbRequire errors when server offers only SCRAM-SHA-256": + var raised = false + var msgMatches = false + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + await drainStartupMessage(st) + await sendAuthSasl(st, @["SCRAM-SHA-256"]) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + password: "test", + database: "test", + sslMode: sslDisable, + channelBinding: cbRequire, + ) + + try: + let conn = await connect(config) + await conn.close() + except PgError as e: + raised = true + msgMatches = "channel binding" in e.msg + + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + check msgMatches + + test "cbDisable picks SCRAM-SHA-256 even when PLUS is offered": + var pickedNonPlus = false + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + await drainStartupMessage(st) + await sendAuthSasl(st, @["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]) + let mech = await readSaslInitialResponseMechanism(st) + pickedNonPlus = mech == "SCRAM-SHA-256" + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + password: "test", + database: "test", + sslMode: sslDisable, + channelBinding: cbDisable, + ) + + try: + let conn = await connect(config) + await conn.close() + except PgError: + discard + + await serverFut + await closeServer(ms) + + waitFor testBody() + check pickedNonPlus + + test "cbDisable errors when server offers only PLUS": + var raised = false + var msgMatches = false + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + await drainStartupMessage(st) + await sendAuthSasl(st, @["SCRAM-SHA-256-PLUS"]) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + password: "test", + database: "test", + sslMode: sslDisable, + channelBinding: cbDisable, + ) + + try: + let conn = await connect(config) + await conn.close() + except PgError as e: + raised = true + msgMatches = "channel binding" in e.msg + + await serverFut + await closeServer(ms) + + waitFor testBody() + check raised + check msgMatches + + test "cbPrefer without SSL accepts SCRAM-SHA-256": + var pickedScram = false + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + await drainStartupMessage(st) + await sendAuthSasl(st, @["SCRAM-SHA-256"]) + let mech = await readSaslInitialResponseMechanism(st) + pickedScram = mech == "SCRAM-SHA-256" + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + password: "test", + database: "test", + sslMode: sslDisable, + channelBinding: cbPrefer, + ) + + try: + let conn = await connect(config) + await conn.close() + except PgError: + discard + + await serverFut + await closeServer(ms) + + waitFor testBody() + check pickedScram + +suite "selectScramMechanism": + const fakeCert = @[byte 0x30, 0x82, 0x01, 0x22] # dummy DER prefix + const bothMechs = @["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"] + + test "cbDisable rejects PLUS when SSL is available": + let choice = selectScramMechanism( + sslEnabled = true, + serverCertDer = fakeCert, + saslMechanisms = bothMechs, + mode = cbDisable, + ) + check choice.mechanism == "SCRAM-SHA-256" + check choice.cbType == "" + check choice.cbData.len == 0 + + test "cbPrefer picks PLUS when SSL + cert + server support all present": + let choice = selectScramMechanism( + sslEnabled = true, + serverCertDer = fakeCert, + saslMechanisms = bothMechs, + mode = cbPrefer, + ) + check choice.mechanism == "SCRAM-SHA-256-PLUS" + check choice.cbType == "tls-server-end-point" + check choice.cbData.len > 0 + + test "cbPrefer falls back to SCRAM-SHA-256 when cert is missing": + let choice = selectScramMechanism( + sslEnabled = true, + serverCertDer = @[], + saslMechanisms = bothMechs, + mode = cbPrefer, + ) + check choice.mechanism == "SCRAM-SHA-256" + + test "cbRequire succeeds when SSL + cert + PLUS all available": + let choice = selectScramMechanism( + sslEnabled = true, + serverCertDer = fakeCert, + saslMechanisms = bothMechs, + mode = cbRequire, + ) + check choice.mechanism == "SCRAM-SHA-256-PLUS" + check choice.cbType == "tls-server-end-point" + check choice.cbData.len > 0 + + test "cbRequire raises when cert is missing even with SSL": + expect PgConnectionError: + discard selectScramMechanism( + sslEnabled = true, + serverCertDer = @[], + saslMechanisms = bothMechs, + mode = cbRequire, + ) + + test "cbPrefer raises when no SCRAM mechanism is offered": + expect PgConnectionError: + discard selectScramMechanism( + sslEnabled = false, + serverCertDer = @[], + saslMechanisms = @["SOMETHING-ELSE"], + mode = cbPrefer, + )