Skip to content

Commit 3fcb1a6

Browse files
authored
database/postgres: add inline certificate authentication fields (#28024)
* add inline cert auth to postres db plugin * handle both sslinline and new TLS plugin fields * refactor PrepareTestContainerWithSSL * add tests for postgres inline TLS fields * changelog * revert back to errwrap since the middleware sanitizing depends on it * enable only setting sslrootcert
1 parent a19195c commit 3fcb1a6

File tree

10 files changed

+371
-94
lines changed

10 files changed

+371
-94
lines changed

builtin/logical/database/backend_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,8 @@ func TestBackend_config_connection(t *testing.T) {
345345
assert.Equal(t, "plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["name"])
346346
}
347347

348+
// TestBackend_BadConnectionString tests that an error response resulting from
349+
// a failed connection does not expose the URL. The middleware should sanitize it.
348350
func TestBackend_BadConnectionString(t *testing.T) {
349351
cluster, sys := getClusterPostgresDB(t)
350352
defer cluster.Cleanup()

changelog/28024.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
```release-note:improvement
2+
database/postgres: Add new fields to the plugin's config endpoint for client certificate authentication.
3+
```

helper/testhelpers/postgresql/postgresqlhelper.go

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ import (
99
"fmt"
1010
"net/url"
1111
"os"
12+
"strconv"
1213
"testing"
14+
"time"
1315

1416
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
15-
"github.com/hashicorp/vault/sdk/database/helper/connutil"
1617
"github.com/hashicorp/vault/sdk/helper/docker"
18+
"github.com/hashicorp/vault/sdk/helper/pluginutil"
1719
)
1820

1921
const (
@@ -68,7 +70,13 @@ func PrepareTestContainerWithVaultUser(t *testing.T, ctx context.Context) (func(
6870

6971
// PrepareTestContainerWithSSL will setup a test container with SSL enabled so
7072
// that we can test client certificate authentication.
71-
func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode string, useFallback bool) (func(), string) {
73+
func PrepareTestContainerWithSSL(
74+
t *testing.T,
75+
sslMode string,
76+
caCert certhelpers.Certificate,
77+
clientCert certhelpers.Certificate,
78+
useFallback bool,
79+
) (func(), string) {
7280
runOpts := defaultRunOpts(t)
7381
runner, err := docker.NewServiceRunner(runOpts)
7482
if err != nil {
@@ -82,21 +90,11 @@ func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode stri
8290
}
8391

8492
// Create certificates for postgres authentication
85-
caCert := certhelpers.NewCert(t,
86-
certhelpers.CommonName("ca"),
87-
certhelpers.IsCA(true),
88-
certhelpers.SelfSign(),
89-
)
9093
serverCert := certhelpers.NewCert(t,
9194
certhelpers.CommonName("server"),
9295
certhelpers.DNS("localhost"),
9396
certhelpers.Parent(caCert),
9497
)
95-
clientCert := certhelpers.NewCert(t,
96-
certhelpers.CommonName("postgres"),
97-
certhelpers.DNS("localhost"),
98-
certhelpers.Parent(caCert),
99-
)
10098

10199
bCtx := docker.NewBuildContext()
102100
bCtx["ca.crt"] = docker.PathContentsFromBytes(caCert.CombinedPEM())
@@ -133,6 +131,9 @@ EOF
133131
t.Fatalf("failed to copy to container: %v", err)
134132
}
135133

134+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
135+
defer cancel()
136+
136137
// overwrite the postgresql.conf config file with our ssl settings
137138
mustRunCommand(t, ctx, runner, id,
138139
[]string{"bash", "/var/lib/postgresql/pg-conf.sh"})
@@ -150,7 +151,7 @@ EOF
150151
return svc.Cleanup, svc.Config.URL().String()
151152
}
152153

153-
sslConfig, err := connectPostgresSSL(
154+
sslConfig := getPostgresSSLConfig(
154155
t,
155156
svc.Config.URL().Host,
156157
sslMode,
@@ -197,42 +198,40 @@ func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password stri
197198
return runner, svc.Cleanup, svc.Config.URL().String(), containerID
198199
}
199200

200-
// connectPostgresSSL is used to verify the connection of our test container
201-
// and construct the connection string that is used in tests.
202-
//
203-
// NOTE: The RawQuery component of the url sets the custom sslinline field and
204-
// inlines the certificate material in the sslrootcert, sslcert, and sslkey
205-
// fields. This feature will be removed in a future version of the SDK.
206-
func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) (docker.ServiceConfig, error) {
201+
func getPostgresSSLConfig(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) docker.ServiceConfig {
207202
if useFallback {
208203
// set the first host to a bad address so we can test the fallback logic
209204
host = "localhost:55," + host
210205
}
211-
u := url.URL{
212-
Scheme: "postgres",
213-
User: url.User("postgres"),
214-
Host: host,
215-
Path: "postgres",
216-
RawQuery: url.Values{
217-
"sslmode": {sslMode},
218-
"sslinline": {"true"},
219-
"sslrootcert": {caCert},
220-
"sslcert": {clientCert},
221-
"sslkey": {clientKey},
222-
}.Encode(),
223-
}
224206

225-
// TODO: remove this deprecated function call in a future SDK version
226-
db, err := connutil.OpenPostgres("pgx", u.String())
227-
if err != nil {
228-
return nil, err
207+
u := url.URL{}
208+
209+
if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); ok {
210+
// TODO: remove this when we remove the underlying feature in a future SDK version
211+
u = url.URL{
212+
Scheme: "postgres",
213+
User: url.User("postgres"),
214+
Host: host,
215+
Path: "postgres",
216+
RawQuery: url.Values{
217+
"sslmode": {sslMode},
218+
"sslinline": {"true"},
219+
"sslrootcert": {caCert},
220+
"sslcert": {clientCert},
221+
"sslkey": {clientKey},
222+
}.Encode(),
223+
}
224+
} else {
225+
u = url.URL{
226+
Scheme: "postgres",
227+
User: url.User("postgres"),
228+
Host: host,
229+
Path: "postgres",
230+
RawQuery: url.Values{"sslmode": {sslMode}}.Encode(),
231+
}
229232
}
230-
defer db.Close()
231233

232-
if err = db.Ping(); err != nil {
233-
return nil, err
234-
}
235-
return docker.NewServiceURL(u), nil
234+
return docker.NewServiceURL(u)
236235
}
237236

238237
func connectPostgres(password, repo string, useFallback bool) docker.ServiceAdapter {

plugins/database/mysql/connection_producer.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,8 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
123123
}
124124

125125
// validate auth_type if provided
126-
authType := c.AuthType
127-
if authType != "" {
128-
if ok := connutil.ValidateAuthType(authType); !ok {
129-
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
130-
}
126+
if ok := connutil.ValidateAuthType(c.AuthType); !ok {
127+
return nil, fmt.Errorf("invalid auth_type: %s", c.AuthType)
131128
}
132129

133130
if c.AuthType == connutil.AuthTypeGCPIAM {

plugins/database/postgresql/postgresql.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ package postgresql
55

66
import (
77
"context"
8+
"crypto/tls"
9+
"crypto/x509"
810
"database/sql"
11+
"encoding/pem"
12+
"errors"
913
"fmt"
1014
"regexp"
1115
"strings"
@@ -79,11 +83,65 @@ func new() *PostgreSQL {
7983
type PostgreSQL struct {
8084
*connutil.SQLConnectionProducer
8185

86+
TLSCertificateData []byte `json:"tls_certificate" structs:"-" mapstructure:"tls_certificate"`
87+
TLSPrivateKey []byte `json:"tls_private_key" structs:"-" mapstructure:"tls_private_key"`
88+
TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"`
89+
8290
usernameProducer template.StringTemplate
8391
passwordAuthentication passwordAuthentication
8492
}
8593

8694
func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
95+
sslcert, err := strutil.GetString(req.Config, "tls_certificate")
96+
if err != nil {
97+
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_certificate: %w", err)
98+
}
99+
100+
sslkey, err := strutil.GetString(req.Config, "tls_private_key")
101+
if err != nil {
102+
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_private_key: %w", err)
103+
}
104+
105+
sslrootcert, err := strutil.GetString(req.Config, "tls_ca")
106+
if err != nil {
107+
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_ca: %w", err)
108+
}
109+
110+
useTLS := false
111+
tlsConfig := &tls.Config{}
112+
if sslrootcert != "" {
113+
caCertPool := x509.NewCertPool()
114+
if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) {
115+
return dbplugin.InitializeResponse{}, errors.New("unable to add CA to cert pool")
116+
}
117+
118+
tlsConfig.RootCAs = caCertPool
119+
tlsConfig.ClientCAs = caCertPool
120+
p.TLSConfig = tlsConfig
121+
useTLS = true
122+
}
123+
124+
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
125+
return dbplugin.InitializeResponse{}, errors.New(`both "sslcert" and "sslkey" are required`)
126+
}
127+
128+
if sslcert != "" && sslkey != "" {
129+
block, _ := pem.Decode([]byte(sslkey))
130+
131+
cert, err := tls.X509KeyPair([]byte(sslcert), pem.EncodeToMemory(block))
132+
if err != nil {
133+
return dbplugin.InitializeResponse{}, fmt.Errorf("unable to load cert: %w", err)
134+
}
135+
tlsConfig.Certificates = []tls.Certificate{cert}
136+
p.TLSConfig = tlsConfig
137+
useTLS = true
138+
}
139+
140+
if !useTLS {
141+
// set to nil to flag that this connection does not use a custom TLS config
142+
p.TLSConfig = nil
143+
}
144+
87145
newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
88146
if err != nil {
89147
return dbplugin.InitializeResponse{}, err

0 commit comments

Comments
 (0)