From a659705f4a5a6ca80cfa8fc09d4c6e439c216905 Mon Sep 17 00:00:00 2001 From: fox0430 Date: Wed, 8 Apr 2026 17:45:35 +0900 Subject: [PATCH 1/3] Add SCRAM-SHA-256-PLUS --- async_postgres/pg_auth.nim | 39 +++++- async_postgres/pg_bearssl.nim | 197 ++++++++++++++++++++++++++++++ async_postgres/pg_connection.nim | 172 ++++++++------------------ tests/test_auth.nim | 202 ++++++++++++++++++++++++++++++- 4 files changed, 482 insertions(+), 128 deletions(-) create mode 100644 async_postgres/pg_bearssl.nim diff --git a/async_postgres/pg_auth.nim b/async_postgres/pg_auth.nim index 8f8cb70..7ec581f 100644 --- a/async_postgres/pg_auth.nim +++ b/async_postgres/pg_auth.nim @@ -9,6 +9,8 @@ type ScramState* = object clientNonce*: string clientFirstBare*: string serverSignature*: array[32, byte] + gs2Header*: string ## GS2 header: "n,," (no binding) or "p=tls-server-end-point,," + channelBindingData*: seq[byte] ## Channel binding data (empty for non-PLUS) proc toBytes(s: string): seq[byte] = result = newSeq[byte](s.len) @@ -34,23 +36,42 @@ proc scramEscapeUsername*(user: string): string = ## '=' is encoded as '=3D' and ',' is encoded as '=2C'. result = user.replace("=", "=3D").replace(",", "=2C") -proc scramClientFirstMessage*(user: string, state: var ScramState): seq[byte] = +proc scramClientFirstMessage*( + user: string, state: var ScramState, cbType: string = "", cbData: seq[byte] = @[] +): seq[byte] = ## Generate the SCRAM-SHA-256 client-first message with a random nonce. + ## When `cbType` is non-empty, use channel binding (SCRAM-SHA-256-PLUS). var nonceBuf: array[24, byte] let n = randomBytes(nonceBuf) if n != 24: raise newException(CatchableError, "SCRAM: failed to generate random nonce") state.clientNonce = base64.encode(nonceBuf) state.clientFirstBare = "n=" & scramEscapeUsername(user) & ",r=" & state.clientNonce - result = toBytes("n,," & state.clientFirstBare) + state.gs2Header = + if cbType.len > 0: + "p=" & cbType & ",," + else: + "n,," + state.channelBindingData = cbData + result = toBytes(state.gs2Header & state.clientFirstBare) proc scramClientFirstMessage*( - user: string, nonce: string, state: var ScramState + user: string, + nonce: string, + state: var ScramState, + cbType: string = "", + cbData: seq[byte] = @[], ): seq[byte] = ## Overload with explicit nonce for testing. state.clientNonce = nonce state.clientFirstBare = "n=" & scramEscapeUsername(user) & ",r=" & nonce - result = toBytes("n,," & state.clientFirstBare) + state.gs2Header = + if cbType.len > 0: + "p=" & cbType & ",," + else: + "n,," + state.channelBindingData = cbData + result = toBytes(state.gs2Header & state.clientFirstBare) proc scramClientFinalMessage*( password: string, serverFirstData: openArray[byte], state: var ScramState @@ -102,7 +123,9 @@ proc scramClientFinalMessage*( let saltedPassword = sha256.pbkdf2(password, salt, iterations, 32) let clientKey = sha256.hmac(saltedPassword, "Client Key").data let storedKey = sha256.digest(clientKey).data - let clientFinalWithoutProof = "c=biws,r=" & combinedNonce + var cbindInput = toBytes(state.gs2Header) + cbindInput.add(state.channelBindingData) + let clientFinalWithoutProof = "c=" & base64.encode(cbindInput) & ",r=" & combinedNonce let authMessage = state.clientFirstBare & "," & serverFirstMsg & "," & clientFinalWithoutProof let clientSignature = sha256.hmac(storedKey, authMessage).data @@ -116,6 +139,12 @@ proc scramClientFinalMessage*( result = toBytes(clientFinalWithoutProof & ",p=" & base64.encode(clientProof)) +proc computeTlsServerEndpoint*(certDer: openArray[byte]): seq[byte] = + ## Compute tls-server-end-point channel binding data per RFC 5929. + ## Always uses SHA-256, matching PostgreSQL (libpq) behavior. + let hash = sha256.digest(certDer) + result = @(hash.data) + proc scramVerifyServerFinal*( serverFinalData: openArray[byte], state: ScramState ): bool = diff --git a/async_postgres/pg_bearssl.nim b/async_postgres/pg_bearssl.nim new file mode 100644 index 0000000..d79e2bc --- /dev/null +++ b/async_postgres/pg_bearssl.nim @@ -0,0 +1,197 @@ +## BearSSL X509 certificate handling for SCRAM-SHA-256-PLUS channel binding. +## Wraps BearSSL callbacks to capture the leaf certificate DER bytes during +## TLS handshake, and provides trust anchor parsing from PEM data. + +import async_backend, pg_types + +when hasChronos: + import chronos/streams/tlsstream + import bearssl/[x509, rsa, ec, ssl] + + type + X509CertCaptureContext* = object + ## X509 callback wrapper that captures the leaf certificate DER bytes + ## during TLS handshake for SCRAM-SHA-256-PLUS channel binding. + vtable: ptr X509Class + inner: X509ClassPointerConst ## Original X509 engine to delegate to + certDer: ptr seq[byte] ## Points to PgConnection.serverCertDer + depth: int ## Certificate depth in chain (0 = leaf) + capturing: bool ## True while capturing leaf cert bytes + + TrustAnchorResult* = object + store*: TrustAnchorStore + backing*: seq[seq[byte]] ## Owns memory pointed to by trust anchor fields + + proc appendDnCallback( + ctx: pointer, buf: pointer, len: uint + ) {.exportc: "pg_append_dn_nim", cdecl, gcsafe, noSideEffect, raises: [].} = + ## DN accumulation callback + let s = cast[ptr seq[byte]](ctx) + let p = cast[ptr UncheckedArray[byte]](buf) + for i in 0 ..< int(len): + s[].add(p[i]) + + # C shim with const void* to satisfy BearSSL's br_x509_decoder_init signature + {. + emit: """ + static void pg_append_dn_shim(void *ctx, const void *buf, size_t len) { + pg_append_dn_nim(ctx, (void*)buf, len); + } + """ + .} + + proc initX509Decoder(ctx: var X509DecoderContext, appendDnCtx: pointer) = + {. + emit: ["br_x509_decoder_init(&", ctx, ", pg_append_dn_shim, ", appendDnCtx, ");"] + .} + + # X509 certificate capture callbacks + # Intercepts BearSSL X509 callbacks to capture the leaf certificate DER bytes, + # then delegates to the original X509 engine for actual validation. + # + # BearSSL X509 callbacks expect `const br_x509_class**` but the Nim binding + # maps them to `ptr ptr X509Class` (non-const). Suppress the resulting + # incompatible-pointer-types error from GCC for this module. + {.localPassC: "-Wno-incompatible-pointer-types".} + + proc x509CaptureStartChain(ctx: ptr ptr X509Class, serverName: cstring) {.cdecl.} = + let self = cast[ptr X509CertCaptureContext](ctx) + self.depth = 0 + self.capturing = false + let inner = cast[ptr ptr X509Class](self.inner) + inner[].startChain(inner, serverName) + + proc x509CaptureStartCert(ctx: ptr ptr X509Class, length: uint32) {.cdecl.} = + let self = cast[ptr X509CertCaptureContext](ctx) + if self.depth == 0: + self.capturing = true + self.certDer[].setLen(0) + let inner = cast[ptr ptr X509Class](self.inner) + inner[].startCert(inner, length) + + proc x509CaptureAppend( + ctx: ptr ptr X509Class, buf: ptr byte, len: csize_t + ) {.cdecl.} = + let self = cast[ptr X509CertCaptureContext](ctx) + if self.capturing: + let oldLen = self.certDer[].len + self.certDer[].setLen(oldLen + int(len)) + copyMem(addr self.certDer[][oldLen], buf, int(len)) + let inner = cast[ptr ptr X509Class](self.inner) + inner[].append(inner, buf, len) + + proc x509CaptureEndCert(ctx: ptr ptr X509Class) {.cdecl.} = + let self = cast[ptr X509CertCaptureContext](ctx) + if self.capturing: + self.capturing = false + self.depth += 1 + let inner = cast[ptr ptr X509Class](self.inner) + inner[].endCert(inner) + + proc x509CaptureEndChain(ctx: ptr ptr X509Class): cuint {.cdecl.} = + let self = cast[ptr X509CertCaptureContext](ctx) + let inner = cast[ptr ptr X509Class](self.inner) + result = inner[].endChain(inner) + + proc x509CaptureGetPkey( + ctx: ptr ptr X509Class, usages: ptr cuint + ): ptr X509Pkey {.cdecl.} = + let self = cast[ptr X509CertCaptureContext](ctx) + let inner = cast[ptr ptr X509Class](self.inner) + result = inner[].getPkey(inner, usages) + + var x509CertCaptureVtable {.global.} = X509Class( + contextSize: uint(sizeof(X509CertCaptureContext)), + startChain: x509CaptureStartChain, + startCert: x509CaptureStartCert, + append: x509CaptureAppend, + endCert: x509CaptureEndCert, + endChain: x509CaptureEndChain, + getPkey: x509CaptureGetPkey, + ) + + # Public API + + proc installX509Capture*( + captureCtx: var X509CertCaptureContext, + eng: var SslEngineContext, + serverCertDer: ptr seq[byte], + ) = + ## Install X509 capture wrapper to intercept server certificate DER bytes. + captureCtx.inner = eng.x509ctx + captureCtx.certDer = serverCertDer + captureCtx.vtable = addr x509CertCaptureVtable + sslEngineSetX509(eng, X509ClassPointerConst(addr captureCtx.vtable)) + + proc parseTrustAnchors*(pemData: string): TrustAnchorResult = + ## Parse PEM-encoded CA certificates into a TrustAnchorStore. + ## Returns both the store and the backing memory that anchor pointers reference. + ## + ## IMPORTANT: X509TrustAnchor contains raw `ptr byte` fields (dn.data, + ## pkey.key.rsa.n/e, pkey.key.ec.q). TrustAnchorStore.new() only shallow-copies + ## these structs, and BearSSL only stores a pointer to the anchor array. + ## The caller MUST keep `result.backing` alive for the lifetime of the TLS session. + let items = pemDecode(pemData) + var anchors: seq[X509TrustAnchor] + var backing: seq[seq[byte]] + + for item in items: + if item.name != "CERTIFICATE": + continue + + var dnBuf: seq[byte] + var decoder: X509DecoderContext + initX509Decoder(decoder, addr dnBuf) + x509DecoderPush(decoder, unsafeAddr item.data[0], uint(item.data.len)) + + if x509DecoderLastError(decoder) != 0: + continue + + let pkey = x509DecoderGetPkey(decoder) + if pkey.isNil: + continue + + # Deep-copy DN + backing.add(dnBuf) + let dnData = addr backing[^1][0] + + # Deep-copy public key and build anchor + var anchor: X509TrustAnchor + anchor.dn = X500Name(data: dnData, len: uint(dnBuf.len)) + anchor.flags = + if x509DecoderIsCA(decoder) != 0: + cuint(X509_TA_CA) + else: + 0 + anchor.pkey.keyType = pkey.keyType + + if pkey.keyType == byte(KEYTYPE_RSA): + var nBuf = newSeq[byte](pkey.key.rsa.nlen) + copyMem(addr nBuf[0], pkey.key.rsa.n, nBuf.len) + backing.add(nBuf) + var eBuf = newSeq[byte](pkey.key.rsa.elen) + copyMem(addr eBuf[0], pkey.key.rsa.e, eBuf.len) + backing.add(eBuf) + anchor.pkey.key.rsa = RsaPublicKey( + n: addr backing[^2][0], + nlen: uint(nBuf.len), + e: addr backing[^1][0], + elen: uint(eBuf.len), + ) + elif pkey.keyType == byte(KEYTYPE_EC): + var qBuf = newSeq[byte](pkey.key.ec.qlen) + copyMem(addr qBuf[0], pkey.key.ec.q, qBuf.len) + backing.add(qBuf) + anchor.pkey.key.ec = EcPublicKey( + curve: pkey.key.ec.curve, q: addr backing[^1][0], qlen: uint(qBuf.len) + ) + else: + continue + + anchors.add(anchor) + + if anchors.len == 0: + raise + newException(PgConnectionError, "No valid CA certificates found in PEM data") + + result = TrustAnchorResult(store: TrustAnchorStore.new(anchors), backing: backing) diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index 15fb213..db75828 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -6,12 +6,12 @@ import async_backend, pg_protocol, pg_auth, pg_types when hasChronos: import chronos/streams/tlsstream - import bearssl/[x509, rsa, ec] + import pg_bearssl elif hasAsyncDispatch: import std/asyncnet from std/nativesockets import Domain, SockType, Protocol when defined(ssl): - import std/[net, tempfiles, os] + import std/[net, openssl, tempfiles, os] export PgError @@ -50,21 +50,6 @@ type PgNotifyOverflowError* = object of PgError dropped*: int ## Number of notifications dropped due to queue overflow -proc newPgQueryError*(fields: seq[ErrorField]): ref PgQueryError = - ## Create a PgQueryError from server ErrorResponse fields. - let sqlState = getErrorField(fields, 'C') - let severity = getErrorField(fields, 'S') - let detail = getErrorField(fields, 'D') - let hint = getErrorField(fields, 'H') - result = (ref PgQueryError)( - msg: formatError(fields), - sqlState: sqlState, - severity: severity, - detail: detail, - hint: hint, - ) - -type PgConnState* = enum ## Connection lifecycle state. csConnecting @@ -149,8 +134,10 @@ type writer: AsyncStreamWriter tlsStream: TLSAsyncStream trustAnchorBufs: seq[seq[byte]] ## Backing memory for custom trust anchor pointers + x509Capture: X509CertCaptureContext ## X509 wrapper for cert capture elif hasAsyncDispatch: socket*: AsyncSocket + serverCertDer: seq[byte] ## DER-encoded server certificate for SCRAM channel binding sslEnabled*: bool recvBuf*: seq[byte] recvBufStart*: int ## Read pointer into recvBuf; bytes before this are consumed @@ -317,6 +304,20 @@ type onPoolReleaseEnd*: proc(ctx: TraceContext, data: TracePoolReleaseEndData) {.gcsafe, raises: [].} +proc newPgQueryError*(fields: seq[ErrorField]): ref PgQueryError = + ## Create a PgQueryError from server ErrorResponse fields. + let sqlState = getErrorField(fields, 'C') + let severity = getErrorField(fields, 'S') + let detail = getErrorField(fields, 'D') + let hint = getErrorField(fields, 'H') + result = (ref PgQueryError)( + msg: formatError(fields), + sqlState: sqlState, + severity: severity, + detail: detail, + hint: hint, + ) + template withConnTracing*( conn: PgConnection, startHook, endHook: untyped, @@ -411,10 +412,6 @@ when hasChronos: proc(): Future[seq[byte]] {.async: (raises: [CatchableError]), gcsafe.} ## Callback supplying data chunks during streaming COPY IN. Return empty seq to finish. - type TrustAnchorResult = object - store: TrustAnchorStore - backing: seq[seq[byte]] ## Owns memory pointed to by trust anchor fields - else: type CopyOutCallback* = proc(data: seq[byte]): Future[void] {.gcsafe.} ## Callback receiving each chunk during streaming COPY OUT. @@ -707,102 +704,6 @@ proc sendBufMsg*(conn: PgConnection): Future[void] {.async.} = if conn.sendBuf.len > 0: await conn.socket.sendRawData(unsafeAddr conn.sendBuf[0], conn.sendBuf.len) -when hasChronos: - proc appendDnCallback( - ctx: pointer, buf: pointer, len: uint - ) {.exportc: "pg_append_dn_nim", cdecl, gcsafe, noSideEffect, raises: [].} = - let s = cast[ptr seq[byte]](ctx) - let p = cast[ptr UncheckedArray[byte]](buf) - for i in 0 ..< int(len): - s[].add(p[i]) - - # C shim with const void* to satisfy BearSSL's br_x509_decoder_init signature - {. - emit: """ - static void pg_append_dn_shim(void *ctx, const void *buf, size_t len) { - pg_append_dn_nim(ctx, (void*)buf, len); - } - """ - .} - - proc initX509Decoder(ctx: var X509DecoderContext, appendDnCtx: pointer) = - {. - emit: ["br_x509_decoder_init(&", ctx, ", pg_append_dn_shim, ", appendDnCtx, ");"] - .} - - proc parseTrustAnchors(pemData: string): TrustAnchorResult = - ## Parse PEM-encoded CA certificates into a TrustAnchorStore. - ## Returns both the store and the backing memory that anchor pointers reference. - ## - ## IMPORTANT: X509TrustAnchor contains raw `ptr byte` fields (dn.data, - ## pkey.key.rsa.n/e, pkey.key.ec.q). TrustAnchorStore.new() only shallow-copies - ## these structs, and BearSSL only stores a pointer to the anchor array. - ## The caller MUST keep `result.backing` alive for the lifetime of the TLS session. - let items = pemDecode(pemData) - var anchors: seq[X509TrustAnchor] - var backing: seq[seq[byte]] - - for item in items: - if item.name != "CERTIFICATE": - continue - - var dnBuf: seq[byte] - var decoder: X509DecoderContext - initX509Decoder(decoder, addr dnBuf) - x509DecoderPush(decoder, unsafeAddr item.data[0], uint(item.data.len)) - - if x509DecoderLastError(decoder) != 0: - continue - - let pkey = x509DecoderGetPkey(decoder) - if pkey.isNil: - continue - - # Deep-copy DN - backing.add(dnBuf) - let dnData = addr backing[^1][0] - - # Deep-copy public key and build anchor - var anchor: X509TrustAnchor - anchor.dn = X500Name(data: dnData, len: uint(dnBuf.len)) - anchor.flags = - if x509DecoderIsCA(decoder) != 0: - cuint(X509_TA_CA) - else: - 0 - anchor.pkey.keyType = pkey.keyType - - if pkey.keyType == byte(KEYTYPE_RSA): - var nBuf = newSeq[byte](pkey.key.rsa.nlen) - copyMem(addr nBuf[0], pkey.key.rsa.n, nBuf.len) - backing.add(nBuf) - var eBuf = newSeq[byte](pkey.key.rsa.elen) - copyMem(addr eBuf[0], pkey.key.rsa.e, eBuf.len) - backing.add(eBuf) - anchor.pkey.key.rsa = RsaPublicKey( - n: addr backing[^2][0], - nlen: uint(nBuf.len), - e: addr backing[^1][0], - elen: uint(eBuf.len), - ) - elif pkey.keyType == byte(KEYTYPE_EC): - var qBuf = newSeq[byte](pkey.key.ec.qlen) - copyMem(addr qBuf[0], pkey.key.ec.q, qBuf.len) - backing.add(qBuf) - anchor.pkey.key.ec = EcPublicKey( - curve: pkey.key.ec.curve, q: addr backing[^1][0], qlen: uint(qBuf.len) - ) - else: - continue - - anchors.add(anchor) - - if anchors.len == 0: - raise - newException(PgConnectionError, "No valid CA certificates found in PEM data") - - result = TrustAnchorResult(store: TrustAnchorStore.new(anchors), backing: backing) - proc closeTransport(conn: PgConnection) {.async.} = ## Close transport resources without sending Terminate. when hasChronos: @@ -896,6 +797,9 @@ proc negotiateSSL(conn: PgConnection, config: ConnConfig) {.async.} = minVersion = TLSVersion.TLS12, maxVersion = TLSVersion.TLS12, ) + installX509Capture( + conn.x509Capture, conn.tlsStream.ccontext.eng, addr conn.serverCertDer + ) await conn.tlsStream.handshake() conn.reader = conn.tlsStream.reader conn.writer = conn.tlsStream.writer @@ -926,6 +830,17 @@ proc negotiateSSL(conn: PgConnection, config: ConnConfig) {.async.} = let hostname = if config.sslMode == sslVerifyFull: config.host else: "" wrapConnectedSocket(ctx, conn.socket, handshakeAsClient, hostname) conn.sslEnabled = true + # Extract server certificate DER for SCRAM-SHA-256-PLUS channel binding + let peerCert = SSL_get_peer_certificate(conn.socket.sslHandle) + if peerCert != nil: + try: + let derStr = i2d_X509(peerCert) + if derStr.len > 0: + conn.serverCertDer = newSeq[byte](derStr.len) + for i in 0 ..< derStr.len: + conn.serverCertDer[i] = byte(derStr[i]) + finally: + X509_free(peerCert) finally: if tmpPath.len > 0: removeFile(tmpPath) @@ -1176,11 +1091,24 @@ proc connectToHost( let hash = md5AuthHash(config.user, config.password, msg.md5Salt) await conn.sendMsg(encodePassword(hash)) of bmkAuthenticationSASL: - if "SCRAM-SHA-256" notin msg.saslMechanisms: - raise - newException(PgConnectionError, "Server doesn't support SCRAM-SHA-256") - let clientFirst = scramClientFirstMessage(config.user, scramState) - await conn.sendMsg(encodeSASLInitialResponse("SCRAM-SHA-256", clientFirst)) + var mechanism: string + var cbType = "" + var cbData: seq[byte] + if conn.sslEnabled and conn.serverCertDer.len > 0 and + "SCRAM-SHA-256-PLUS" in msg.saslMechanisms: + mechanism = "SCRAM-SHA-256-PLUS" + cbType = "tls-server-end-point" + cbData = computeTlsServerEndpoint(conn.serverCertDer) + elif "SCRAM-SHA-256" in msg.saslMechanisms: + mechanism = "SCRAM-SHA-256" + else: + raise newException( + PgConnectionError, + "Server doesn't support SCRAM-SHA-256 or SCRAM-SHA-256-PLUS", + ) + let clientFirst = + scramClientFirstMessage(config.user, scramState, cbType, cbData) + await conn.sendMsg(encodeSASLInitialResponse(mechanism, clientFirst)) of bmkAuthenticationSASLContinue: let clientFinal = scramClientFinalMessage(config.password, msg.saslData, scramState) diff --git a/tests/test_auth.nim b/tests/test_auth.nim index 6554f94..cf0332b 100644 --- a/tests/test_auth.nim +++ b/tests/test_auth.nim @@ -1,4 +1,7 @@ -import std/[unittest, strutils] +import std/[unittest, strutils, base64] + +import pkg/nimcrypto +import pkg/nimcrypto/pbkdf2 import ../async_postgres/pg_auth @@ -159,3 +162,200 @@ suite "SCRAM-SHA-256": test "scramVerifyServerFinal rejects invalid base64 signature": var state: ScramState check scramVerifyServerFinal(toBytes("v=!!!invalid!!!"), state) == false + +suite "SCRAM-SHA-256-PLUS channel binding": + test "clientFirstMessage with channel binding": + var state: ScramState + let cbData = @[0x01'u8, 0x02, 0x03] + let msg = scramClientFirstMessage( + "user", "testNonce", state, cbType = "tls-server-end-point", cbData = cbData + ) + check toString(msg) == "p=tls-server-end-point,,n=user,r=testNonce" + check state.gs2Header == "p=tls-server-end-point,," + check state.channelBindingData == cbData + check state.clientFirstBare == "n=user,r=testNonce" + + test "clientFirstMessage without channel binding preserves gs2Header": + var state: ScramState + let msg = scramClientFirstMessage("user", "testNonce", state) + check toString(msg) == "n,,n=user,r=testNonce" + check state.gs2Header == "n,," + check state.channelBindingData.len == 0 + + test "clientFinalMessage with channel binding encodes cbind-input": + var state: ScramState + let cbData = @[0xAA'u8, 0xBB, 0xCC] + discard scramClientFirstMessage( + "user", + "rOprNGfwEbeRWgbNEkqO", + state, + cbType = "tls-server-end-point", + cbData = cbData, + ) + let serverFirst = + "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0," & + "s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" + let clientFinal = scramClientFinalMessage("pencil", toBytes(serverFirst), state) + let clientFinalStr = toString(clientFinal) + # c= should be base64(gs2Header + cbData), NOT "biws" + let expectedCbind = base64.encode("p=tls-server-end-point,," & "\xAA\xBB\xCC") + check clientFinalStr.startsWith("c=" & expectedCbind & ",r=") + + test "clientFinalMessage without channel binding produces c=biws": + var state: ScramState + discard scramClientFirstMessage("user", "rOprNGfwEbeRWgbNEkqO", state) + let serverFirst = + "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0," & + "s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" + let clientFinal = scramClientFinalMessage("pencil", toBytes(serverFirst), state) + let clientFinalStr = toString(clientFinal) + check clientFinalStr.startsWith("c=biws,r=") + + test "RFC 7677 test vectors still pass with default params": + # Backward compatibility: existing test vectors work unchanged + var state: ScramState + let clientFirst = scramClientFirstMessage("user", "rOprNGfwEbeRWgbNEkqO", state) + check toString(clientFirst) == "n,,n=user,r=rOprNGfwEbeRWgbNEkqO" + let serverFirst = + "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0," & + "s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" + let clientFinal = scramClientFinalMessage("pencil", toBytes(serverFirst), state) + let expected = + "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0," & + "p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=" + check toString(clientFinal) == expected + let serverFinal = "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=" + check scramVerifyServerFinal(toBytes(serverFinal), state) == true + + test "full SCRAM-SHA-256-PLUS exchange with server signature verification": + # Simulate a full exchange with channel binding and verify the server + # signature is correctly computed (round-trip: the authMessage changes + # when channel binding is used, so serverSignature must reflect that). + var state: ScramState + let cbData = @[0xDE'u8, 0xAD, 0xBE, 0xEF] + let clientFirst = scramClientFirstMessage( + "user", + "rOprNGfwEbeRWgbNEkqO", + state, + cbType = "tls-server-end-point", + cbData = cbData, + ) + check toString(clientFirst) == + "p=tls-server-end-point,,n=user,r=rOprNGfwEbeRWgbNEkqO" + + let serverFirst = + "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0," & + "s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" + let clientFinal = scramClientFinalMessage("pencil", toBytes(serverFirst), state) + let clientFinalStr = toString(clientFinal) + + # c= value must NOT be "biws" (that's the non-PLUS value) + check not clientFinalStr.startsWith("c=biws,") + + # Verify server signature: manually compute expected value + let salt = base64.decode("W22ZaJ0SNY7soEsUEjb6gQ==") + let saltedPassword = sha256.pbkdf2("pencil", salt, 4096, 32) + let serverKey = sha256.hmac(saltedPassword, "Server Key").data + var cbindInput = toBytes("p=tls-server-end-point,,") + cbindInput.add(cbData) + let clientFinalWithoutProof = + "c=" & base64.encode(cbindInput) & + ",r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0" + let authMessage = + "n=user,r=rOprNGfwEbeRWgbNEkqO" & "," & serverFirst & "," & clientFinalWithoutProof + let expectedSig = sha256.hmac(serverKey, authMessage).data + + # Build server-final message and verify + let serverFinal = "v=" & base64.encode(expectedSig) + check scramVerifyServerFinal(toBytes(serverFinal), state) == true + + # Wrong signature must fail + check scramVerifyServerFinal( + toBytes("v=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="), state + ) == false + + test "channel binding changes server signature vs non-PLUS": + # The server signature for the same credentials must differ between + # SCRAM-SHA-256 and SCRAM-SHA-256-PLUS because authMessage changes. + var stateNormal: ScramState + discard scramClientFirstMessage("user", "rOprNGfwEbeRWgbNEkqO", stateNormal) + let serverFirst = + "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0," & + "s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" + discard scramClientFinalMessage("pencil", toBytes(serverFirst), stateNormal) + + var statePlus: ScramState + discard scramClientFirstMessage( + "user", + "rOprNGfwEbeRWgbNEkqO", + statePlus, + cbType = "tls-server-end-point", + cbData = @[0x01'u8], + ) + discard scramClientFinalMessage("pencil", toBytes(serverFirst), statePlus) + + check stateNormal.serverSignature != statePlus.serverSignature + + test "clientFirstMessage with random nonce and channel binding": + var state: ScramState + let cbData = @[0xFF'u8, 0xFE] + let msg = scramClientFirstMessage( + "user", state, cbType = "tls-server-end-point", cbData = cbData + ) + let msgStr = toString(msg) + check msgStr.startsWith("p=tls-server-end-point,,n=user,r=") + check state.gs2Header == "p=tls-server-end-point,," + check state.channelBindingData == cbData + check state.clientNonce.len > 0 + + test "clientFirstMessage with channel binding and special username": + var state: ScramState + let cbData = @[0x01'u8, 0x02] + let msg = scramClientFirstMessage( + "u=ser,1", "testNonce", state, cbType = "tls-server-end-point", cbData = cbData + ) + check toString(msg) == "p=tls-server-end-point,,n=u=3Dser=2C1,r=testNonce" + check state.clientFirstBare == "n=u=3Dser=2C1,r=testNonce" + check state.gs2Header == "p=tls-server-end-point,," + + test "clientFirstMessage with channel binding type but empty cbData": + var state: ScramState + let msg = scramClientFirstMessage( + "user", "testNonce", state, cbType = "tls-server-end-point", cbData = @[] + ) + check toString(msg) == "p=tls-server-end-point,,n=user,r=testNonce" + check state.gs2Header == "p=tls-server-end-point,," + check state.channelBindingData.len == 0 + # c= should be base64("p=tls-server-end-point,,") with no trailing binding data + let serverFirst = + "r=testNonce%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0," & "s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" + let clientFinal = scramClientFinalMessage("pencil", toBytes(serverFirst), state) + let clientFinalStr = toString(clientFinal) + let expectedCbind = base64.encode("p=tls-server-end-point,,") + check clientFinalStr.startsWith("c=" & expectedCbind & ",r=") + + test "computeTlsServerEndpoint with empty input": + let binding = computeTlsServerEndpoint(@[]) + check binding.len == 32 + # SHA-256 of empty input is a well-known value + let expected = sha256.digest(@(newSeq[byte](0))).data + check binding == @(expected) + + test "computeTlsServerEndpoint matches known SHA-256": + let input = @[0x30'u8, 0x82, 0x01, 0x00] + let binding = computeTlsServerEndpoint(input) + check binding.len == 32 + # Verify against independently computed SHA-256 + let expected = sha256.digest(input).data + check binding == @(expected) + + test "computeTlsServerEndpoint is deterministic": + let cert = @[0x01'u8, 0x02, 0x03, 0x04, 0x05] + let b1 = computeTlsServerEndpoint(cert) + let b2 = computeTlsServerEndpoint(cert) + check b1 == b2 + + test "computeTlsServerEndpoint differs for different certs": + let cert1 = @[0x01'u8, 0x02, 0x03] + let cert2 = @[0x04'u8, 0x05, 0x06] + check computeTlsServerEndpoint(cert1) != computeTlsServerEndpoint(cert2) From 23f74817eab60c5f9e3deaa017265b1fa81611a4 Mon Sep 17 00:00:00 2001 From: fox0430 Date: Wed, 8 Apr 2026 17:52:28 +0900 Subject: [PATCH 2/3] fix --- README.md | 2 +- async_postgres/pg_bearssl.nim | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 684b21b..3f5f8d9 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Async PostgreSQL client in Nim. - Connection pooling with health checks and maintenance - Pool cluster with read replica routing - SSL/TLS support (disable, allow, prefer, require, verify-ca, verify-full) -- MD5 and SCRAM-SHA-256 authentication +- MD5, SCRAM-SHA-256 and SCRAM-SHA-256-PLUS authentication - DSN connection string parsing - Unix socket connection - Multi-host failover diff --git a/async_postgres/pg_bearssl.nim b/async_postgres/pg_bearssl.nim index d79e2bc..3a64645 100644 --- a/async_postgres/pg_bearssl.nim +++ b/async_postgres/pg_bearssl.nim @@ -192,6 +192,6 @@ when hasChronos: if anchors.len == 0: raise - newException(PgConnectionError, "No valid CA certificates found in PEM data") + newException(PgError, "No valid CA certificates found in PEM data") result = TrustAnchorResult(store: TrustAnchorStore.new(anchors), backing: backing) From dba4e021a9f828bf0ec1abc43cb4cb9685756340 Mon Sep 17 00:00:00 2001 From: fox0430 Date: Wed, 8 Apr 2026 17:53:18 +0900 Subject: [PATCH 3/3] nph --- async_postgres/pg_bearssl.nim | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/async_postgres/pg_bearssl.nim b/async_postgres/pg_bearssl.nim index 3a64645..f8a359c 100644 --- a/async_postgres/pg_bearssl.nim +++ b/async_postgres/pg_bearssl.nim @@ -191,7 +191,6 @@ when hasChronos: anchors.add(anchor) if anchors.len == 0: - raise - newException(PgError, "No valid CA certificates found in PEM data") + raise newException(PgError, "No valid CA certificates found in PEM data") result = TrustAnchorResult(store: TrustAnchorStore.new(anchors), backing: backing)