Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type
SslMode* = enum
## SSL/TLS negotiation mode for the connection.
sslDisable ## Disable SSL (default)
sslAllow ## Try plaintext; fall back to SSL if refused
sslPrefer ## Try SSL; fall back to plaintext if refused
sslRequire ## Require SSL (no certificate verification)
sslVerifyCa ## Require SSL + verify CA chain (no hostname verification)
Expand Down Expand Up @@ -887,6 +888,25 @@ proc connectToHost(
config: ConnConfig, hostAddr: string, hostPort: int
): Future[PgConnection] {.async.} =
## Connect to a single PostgreSQL host. Internal helper for multi-host connect.

if config.sslMode == sslAllow:
# sslAllow: try plaintext first, then fall back to SSL.
var plainConfig = config
plainConfig.sslMode = sslDisable
try:
return await connectToHost(plainConfig, hostAddr, hostPort)
except CancelledError as e:
raise e
except CatchableError:
# Plaintext failed — retry with SSL.
# WARNING: This is vulnerable to MITM downgrade attacks. A network
# attacker can force the first attempt to fail and then intercept
# the SSL connection. Use sslRequire or stronger if security is needed.
stderr.writeLine "pg_connection: plaintext connection failed, retrying with SSL (sslmode=allow)"
var sslConfig = config
sslConfig.sslMode = sslPrefer
return await connectToHost(sslConfig, hostAddr, hostPort)

var conn: PgConnection

let isUnix = isUnixSocket(hostAddr)
Expand Down Expand Up @@ -1583,6 +1603,8 @@ proc parseSslMode(s: string): SslMode =
case s
of "disable":
sslDisable
of "allow":
sslAllow
of "prefer":
sslPrefer
of "require":
Expand Down
4 changes: 3 additions & 1 deletion tests/test_dsn.nim
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,13 @@ suite "parseDsn":
check cfg.database == "my/db"

test "query param sslmode":
for mode in ["disable", "prefer", "require", "verify-ca", "verify-full"]:
for mode in ["disable", "allow", "prefer", "require", "verify-ca", "verify-full"]:
let cfg = parseDsn("postgresql://host/db?sslmode=" & mode)
case mode
of "disable":
check cfg.sslMode == sslDisable
of "allow":
check cfg.sslMode == sslAllow
of "prefer":
check cfg.sslMode == sslPrefer
of "require":
Expand Down
9 changes: 9 additions & 0 deletions tests/test_e2e.nim
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,15 @@ suite "E2E: SSL Connection":

waitFor t()

test "sslAllow connects without SSL when server accepts plaintext":
proc t() {.async.} =
let conn = await connect(sslConfig(sslAllow))
doAssert conn.state == csReady
doAssert conn.sslEnabled == false
await conn.close()

waitFor t()

test "query over SSL connection":
proc t() {.async.} =
let conn = await connect(sslConfig(sslRequire))
Expand Down
117 changes: 117 additions & 0 deletions tests/test_ssl.nim
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,10 @@ suite "SSL negotiation - sslVerifyCa":
waitFor testBody()
check raised

test "sslAllow ordinal is between sslDisable and sslPrefer":
check ord(sslAllow) > ord(sslDisable)
check ord(sslAllow) < ord(sslPrefer)

test "sslVerifyCa ordinal is between sslRequire and sslVerifyFull":
check ord(sslVerifyCa) > ord(sslRequire)
check ord(sslVerifyCa) < ord(sslVerifyFull)
Expand All @@ -381,6 +385,119 @@ suite "SSL negotiation - sslVerifyCa":
let config = ConnConfig()
check config.sslRootCert == ""

suite "SSL negotiation - sslAllow":
test "sslAllow connects without SSL when server accepts plaintext":
var connState: PgConnState
var connSslEnabled: bool

proc testBody() {.async.} =
let ms = startMockServer()

proc serverHandler() {.async.} =
let st = await ms.accept()
try:
await drainStartupMessage(st)
await sendAuthOkAndReady(st)
await drainUntilClose(st)
except CatchableError:
discard
await closeClient(st)

let serverFut = serverHandler()

let config = ConnConfig(
host: "127.0.0.1",
port: ms.port,
user: "test",
database: "test",
sslMode: sslAllow,
)

let conn = await connect(config)
connState = conn.state
connSslEnabled = conn.sslEnabled
await conn.close()

await serverFut
await closeServer(ms)

waitFor testBody()
check connState == csReady
check connSslEnabled == false

test "sslAllow retries with SSL when plaintext is rejected":
var connState: PgConnState
var connSslEnabled: bool
var attemptCount: int = 0

proc testBody() {.async.} =
let ms = startMockServer()

proc serverHandler() {.async.} =
# First connection: reject plaintext with FATAL error
block:
let st = await ms.accept()
attemptCount.inc
try:
discard await readN(st, 8) # read StartupMessage header
# Send FATAL error response (server requires SSL)
var body: seq[byte] = @[]
body.add(byte('S'))
for c in "FATAL":
body.add(byte(c))
body.add(0)
body.add(byte('M'))
for c in "no pg_hba.conf entry":
body.add(byte(c))
body.add(0)
body.add(0) # terminator
var msg: seq[byte] = @[byte('E')]
msg.addInt32(int32(4 + body.len))
msg.add(body)
await sendBytes(st, msg)
except CatchableError:
discard
await closeClient(st)

# Second connection: accept SSL and complete handshake
block:
let st = await ms.accept()
attemptCount.inc
try:
# Read SSLRequest
discard await readN(st, 8)
# Refuse SSL (sslPrefer will fall back to plaintext)
await sendBytes(st, @[byte('N')])
await drainStartupMessage(st)
await sendAuthOkAndReady(st)
await drainUntilClose(st)
except CatchableError:
discard
await closeClient(st)

let serverFut = serverHandler()

let config = ConnConfig(
host: "127.0.0.1",
port: ms.port,
user: "test",
database: "test",
sslMode: sslAllow,
)

let conn = await connect(config)
connState = conn.state
connSslEnabled = conn.sslEnabled
await conn.close()

await serverFut
await closeServer(ms)

waitFor testBody()
check attemptCount == 2
check connState == csReady
check connSslEnabled == false

suite "SSL negotiation - sslDisable":
test "sslDisable sends StartupMessage directly without SSLRequest":
var firstMsgVersion: int32 = 0
Expand Down
Loading