diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index 9ae6d4d..2311d57 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -78,6 +78,7 @@ type SslMode* = enum ## SSL/TLS negotiation mode for the connection. sslDisable ## Disable SSL (default) + sslAllow ## Try plaintext; fall back to SSL if refused sslPrefer ## Try SSL; fall back to plaintext if refused sslRequire ## Require SSL (no certificate verification) sslVerifyCa ## Require SSL + verify CA chain (no hostname verification) @@ -887,6 +888,25 @@ proc connectToHost( config: ConnConfig, hostAddr: string, hostPort: int ): Future[PgConnection] {.async.} = ## Connect to a single PostgreSQL host. Internal helper for multi-host connect. + + if config.sslMode == sslAllow: + # sslAllow: try plaintext first, then fall back to SSL. + var plainConfig = config + plainConfig.sslMode = sslDisable + try: + return await connectToHost(plainConfig, hostAddr, hostPort) + except CancelledError as e: + raise e + except CatchableError: + # Plaintext failed — retry with SSL. + # WARNING: This is vulnerable to MITM downgrade attacks. A network + # attacker can force the first attempt to fail and then intercept + # the SSL connection. Use sslRequire or stronger if security is needed. + stderr.writeLine "pg_connection: plaintext connection failed, retrying with SSL (sslmode=allow)" + var sslConfig = config + sslConfig.sslMode = sslPrefer + return await connectToHost(sslConfig, hostAddr, hostPort) + var conn: PgConnection let isUnix = isUnixSocket(hostAddr) @@ -1583,6 +1603,8 @@ proc parseSslMode(s: string): SslMode = case s of "disable": sslDisable + of "allow": + sslAllow of "prefer": sslPrefer of "require": diff --git a/tests/test_dsn.nim b/tests/test_dsn.nim index 2b21f93..210d47a 100644 --- a/tests/test_dsn.nim +++ b/tests/test_dsn.nim @@ -76,11 +76,13 @@ suite "parseDsn": check cfg.database == "my/db" test "query param sslmode": - for mode in ["disable", "prefer", "require", "verify-ca", "verify-full"]: + for mode in ["disable", "allow", "prefer", "require", "verify-ca", "verify-full"]: let cfg = parseDsn("postgresql://host/db?sslmode=" & mode) case mode of "disable": check cfg.sslMode == sslDisable + of "allow": + check cfg.sslMode == sslAllow of "prefer": check cfg.sslMode == sslPrefer of "require": diff --git a/tests/test_e2e.nim b/tests/test_e2e.nim index 9c70bd5..4b8ce3c 100644 --- a/tests/test_e2e.nim +++ b/tests/test_e2e.nim @@ -315,6 +315,15 @@ suite "E2E: SSL Connection": waitFor t() + test "sslAllow connects without SSL when server accepts plaintext": + proc t() {.async.} = + let conn = await connect(sslConfig(sslAllow)) + doAssert conn.state == csReady + doAssert conn.sslEnabled == false + await conn.close() + + waitFor t() + test "query over SSL connection": proc t() {.async.} = let conn = await connect(sslConfig(sslRequire)) diff --git a/tests/test_ssl.nim b/tests/test_ssl.nim index 795f3f5..4c072d6 100644 --- a/tests/test_ssl.nim +++ b/tests/test_ssl.nim @@ -373,6 +373,10 @@ suite "SSL negotiation - sslVerifyCa": waitFor testBody() check raised + test "sslAllow ordinal is between sslDisable and sslPrefer": + check ord(sslAllow) > ord(sslDisable) + check ord(sslAllow) < ord(sslPrefer) + test "sslVerifyCa ordinal is between sslRequire and sslVerifyFull": check ord(sslVerifyCa) > ord(sslRequire) check ord(sslVerifyCa) < ord(sslVerifyFull) @@ -381,6 +385,119 @@ suite "SSL negotiation - sslVerifyCa": let config = ConnConfig() check config.sslRootCert == "" +suite "SSL negotiation - sslAllow": + test "sslAllow connects without SSL when server accepts plaintext": + var connState: PgConnState + var connSslEnabled: bool + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + await drainStartupMessage(st) + await sendAuthOkAndReady(st) + await drainUntilClose(st) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + database: "test", + sslMode: sslAllow, + ) + + let conn = await connect(config) + connState = conn.state + connSslEnabled = conn.sslEnabled + await conn.close() + + await serverFut + await closeServer(ms) + + waitFor testBody() + check connState == csReady + check connSslEnabled == false + + test "sslAllow retries with SSL when plaintext is rejected": + var connState: PgConnState + var connSslEnabled: bool + var attemptCount: int = 0 + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + # First connection: reject plaintext with FATAL error + block: + let st = await ms.accept() + attemptCount.inc + try: + discard await readN(st, 8) # read StartupMessage header + # Send FATAL error response (server requires SSL) + var body: seq[byte] = @[] + body.add(byte('S')) + for c in "FATAL": + body.add(byte(c)) + body.add(0) + body.add(byte('M')) + for c in "no pg_hba.conf entry": + body.add(byte(c)) + body.add(0) + body.add(0) # terminator + var msg: seq[byte] = @[byte('E')] + msg.addInt32(int32(4 + body.len)) + msg.add(body) + await sendBytes(st, msg) + except CatchableError: + discard + await closeClient(st) + + # Second connection: accept SSL and complete handshake + block: + let st = await ms.accept() + attemptCount.inc + try: + # Read SSLRequest + discard await readN(st, 8) + # Refuse SSL (sslPrefer will fall back to plaintext) + await sendBytes(st, @[byte('N')]) + await drainStartupMessage(st) + await sendAuthOkAndReady(st) + await drainUntilClose(st) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + database: "test", + sslMode: sslAllow, + ) + + let conn = await connect(config) + connState = conn.state + connSslEnabled = conn.sslEnabled + await conn.close() + + await serverFut + await closeServer(ms) + + waitFor testBody() + check attemptCount == 2 + check connState == csReady + check connSslEnabled == false + suite "SSL negotiation - sslDisable": test "sslDisable sends StartupMessage directly without SSLRequest": var firstMsgVersion: int32 = 0