@@ -35,6 +35,7 @@ import (
3535 "google.golang.org/grpc/credentials/tls/certprovider"
3636 icredentials "google.golang.org/grpc/internal/credentials"
3737 xdsinternal "google.golang.org/grpc/internal/credentials/xds"
38+ "google.golang.org/grpc/internal/envconfig"
3839 "google.golang.org/grpc/internal/grpctest"
3940 "google.golang.org/grpc/internal/testutils"
4041 "google.golang.org/grpc/internal/xds/matcher"
@@ -219,7 +220,7 @@ func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
219220
220221// newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo
221222// context value added to it.
222- func newTestContextWithHandshakeInfo (parent context.Context , root , identity certprovider.Provider , sanExactMatch string ) context.Context {
223+ func newTestContextWithHandshakeInfo (parent context.Context , root , identity certprovider.Provider , sanExactMatch , sni string , validateSANUsingSNI bool ) context.Context {
223224 // Creating the HandshakeInfo and adding it to the attributes is very
224225 // similar to what the CDS balancer would do when it intercepts calls to
225226 // NewSubConn().
@@ -228,7 +229,7 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert
228229 sms = []matcher.StringMatcher {matcher .NewExactStringMatcher (sanExactMatch , false )}
229230 }
230231 var hiPtr atomic.Pointer [xdsinternal.HandshakeInfo ]
231- info := xdsinternal .NewHandshakeInfo (root , identity , sms , false )
232+ info := xdsinternal .NewHandshakeInfo (root , identity , sms , false , sni , validateSANUsingSNI )
232233 hiPtr .Store (info )
233234 addr := xdsinternal .SetHandshakeInfo (resolver.Address {}, & hiPtr )
234235
@@ -302,7 +303,7 @@ func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
302303
303304 pCtx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
304305 defer cancel ()
305- ctx := newTestContextWithHandshakeInfo (pCtx , nil , & fakeProvider {}, "" )
306+ ctx := newTestContextWithHandshakeInfo (pCtx , nil , & fakeProvider {}, "" , "" , false )
306307 if _ , _ , err := creds .ClientHandshake (ctx , authority , nil ); err == nil {
307308 t .Fatal ("ClientHandshake succeeded without root certificate provider in HandshakeInfo" )
308309 }
@@ -339,7 +340,7 @@ func (s) TestClientCredsProviderFailure(t *testing.T) {
339340 t .Run (test .desc , func (t * testing.T ) {
340341 ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
341342 defer cancel ()
342- ctx = newTestContextWithHandshakeInfo (ctx , test .rootProvider , test .identityProvider , "" )
343+ ctx = newTestContextWithHandshakeInfo (ctx , test .rootProvider , test .identityProvider , "" , "" , false )
343344 if _ , _ , err := creds .ClientHandshake (ctx , authority , nil ); err == nil || ! strings .Contains (err .Error (), test .wantErr ) {
344345 t .Fatalf ("ClientHandshake() returned error: %q, wantErr: %q" , err , test .wantErr )
345346 }
@@ -353,6 +354,7 @@ func (s) TestClientCredsSuccess(t *testing.T) {
353354 desc string
354355 handshakeFunc testHandshakeFunc
355356 handshakeInfoCtx func (ctx context.Context ) context.Context
357+ enableSNIFlag bool
356358 }{
357359 {
358360 desc : "fallback" ,
@@ -367,27 +369,59 @@ func (s) TestClientCredsSuccess(t *testing.T) {
367369 desc : "TLS" ,
368370 handshakeFunc : testServerTLSHandshake ,
369371 handshakeInfoCtx : func (ctx context.Context ) context.Context {
370- return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), nil , defaultTestCertSAN )
372+ return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), nil , defaultTestCertSAN , "" , false )
371373 },
372374 },
373375 {
374376 desc : "mTLS" ,
375377 handshakeFunc : testServerMutualTLSHandshake ,
376378 handshakeInfoCtx : func (ctx context.Context ) context.Context {
377- return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), makeIdentityProvider (t , "x509/server1_cert.pem" , "x509/server1_key.pem" ), defaultTestCertSAN )
379+ return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), makeIdentityProvider (t , "x509/server1_cert.pem" , "x509/server1_key.pem" ), defaultTestCertSAN , "" , false )
378380 },
379381 },
380382 {
381383 desc : "mTLS with no acceptedSANs specified" ,
382384 handshakeFunc : testServerMutualTLSHandshake ,
383385 handshakeInfoCtx : func (ctx context.Context ) context.Context {
384- return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), makeIdentityProvider (t , "x509/server1_cert.pem" , "x509/server1_key.pem" ), "" )
386+ return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), makeIdentityProvider (t , "x509/server1_cert.pem" , "x509/server1_key.pem" ), "" , "" , false )
385387 },
386388 },
389+ {
390+ desc : "TLS with SNI" ,
391+ handshakeFunc : testServerTLSHandshake ,
392+ handshakeInfoCtx : func (ctx context.Context ) context.Context {
393+ return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), nil , "bad-match" , defaultTestCertSAN , true )
394+ },
395+ enableSNIFlag : true ,
396+ },
397+ {
398+ desc : "TLS with SNI, env variable disabled, AutoSniSanValidation enabled" ,
399+ handshakeFunc : testServerTLSHandshake ,
400+ handshakeInfoCtx : func (ctx context.Context ) context.Context {
401+ return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), nil , defaultTestCertSAN , "bad-sni" , true )
402+ },
403+ },
404+ {
405+ desc : "TLS with SNI, env variable enabled but AutoSniSanValidation disabled" ,
406+ handshakeFunc : testServerTLSHandshake ,
407+ handshakeInfoCtx : func (ctx context.Context ) context.Context {
408+ return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), nil , defaultTestCertSAN , "bad-sni" , false )
409+ },
410+ enableSNIFlag : true ,
411+ },
412+ {
413+ desc : "TLS with empty SNI, env variable enabled, AutoSniSanValidation enabled" ,
414+ handshakeFunc : testServerTLSHandshake ,
415+ handshakeInfoCtx : func (ctx context.Context ) context.Context {
416+ return newTestContextWithHandshakeInfo (ctx , makeRootProvider (t , "x509/server_ca_cert.pem" ), nil , defaultTestCertSAN , "" , true )
417+ },
418+ enableSNIFlag : true ,
419+ },
387420 }
388421
389422 for _ , test := range tests {
390423 t .Run (test .desc , func (t * testing.T ) {
424+ testutils .SetEnvConfig (t , & envconfig .XDSSNIEnabled , test .enableSNIFlag )
391425 ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
392426 defer cancel ()
393427 ts := newTestServerWithHandshakeFunc (ctx , test .handshakeFunc )
@@ -444,7 +478,7 @@ func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
444478
445479 sCtx , sCancel := context .WithTimeout (context .Background (), defaultTestShortTimeout )
446480 defer sCancel ()
447- ctx = newTestContextWithHandshakeInfo (sCtx , makeRootProvider (t , "x509/server_ca_cert.pem" ), nil , defaultTestCertSAN )
481+ ctx = newTestContextWithHandshakeInfo (sCtx , makeRootProvider (t , "x509/server_ca_cert.pem" ), nil , defaultTestCertSAN , "" , false )
448482 if _ , _ , err := creds .ClientHandshake (ctx , authority , conn ); err == nil {
449483 t .Fatal ("ClientHandshake() succeeded when expected to timeout" )
450484 }
@@ -467,11 +501,14 @@ func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
467501// TestClientCredsHandshakeFailure verifies different handshake failure cases.
468502func (s ) TestClientCredsHandshakeFailure (t * testing.T ) {
469503 tests := []struct {
470- desc string
471- handshakeFunc testHandshakeFunc
472- rootProvider certprovider.Provider
473- san string
474- wantErr string
504+ desc string
505+ handshakeFunc testHandshakeFunc
506+ rootProvider certprovider.Provider
507+ san string
508+ sni string
509+ validateSANUsingSNI bool
510+ enableSNIFlag bool
511+ wantErr string
475512 }{
476513 {
477514 desc : "cert validation failure" ,
@@ -487,10 +524,49 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) {
487524 san : "bad-san" ,
488525 wantErr : "do not match any of the accepted SANs" ,
489526 },
527+ {
528+ desc : "SNI SAN mismatch" ,
529+ handshakeFunc : testServerTLSHandshake ,
530+ rootProvider : makeRootProvider (t , "x509/server_ca_cert.pem" ),
531+ sni : "bad-sni" ,
532+ validateSANUsingSNI : true ,
533+ wantErr : "do not match the SNI" ,
534+ enableSNIFlag : true ,
535+ },
536+ {
537+ desc : "SNI set, AutoSniSanValidation disabled with SAN mismatch" ,
538+ handshakeFunc : testServerTLSHandshake ,
539+ rootProvider : makeRootProvider (t , "x509/server_ca_cert.pem" ),
540+ sni : defaultTestCertSAN ,
541+ san : "bad-san" ,
542+ validateSANUsingSNI : false ,
543+ wantErr : "do not match any of the accepted SANs" ,
544+ enableSNIFlag : true ,
545+ },
546+ {
547+ desc : "SNI set with SAN mismatch and AutoSniSanValidation enabled, environment variable disabled" ,
548+ handshakeFunc : testServerTLSHandshake ,
549+ rootProvider : makeRootProvider (t , "x509/server_ca_cert.pem" ),
550+ sni : defaultTestCertSAN ,
551+ san : "bad-san" ,
552+ validateSANUsingSNI : true ,
553+ wantErr : "do not match any of the accepted SANs" ,
554+ },
555+ {
556+ desc : "SNI empty, AutoSniSanValidation enabled with SAN mismatch" ,
557+ handshakeFunc : testServerTLSHandshake ,
558+ rootProvider : makeRootProvider (t , "x509/server_ca_cert.pem" ),
559+ sni : "" ,
560+ san : "bad-san" ,
561+ validateSANUsingSNI : true ,
562+ wantErr : "do not match any of the accepted SANs" ,
563+ enableSNIFlag : true ,
564+ },
490565 }
491566
492567 for _ , test := range tests {
493568 t .Run (test .desc , func (t * testing.T ) {
569+ testutils .SetEnvConfig (t , & envconfig .XDSSNIEnabled , test .enableSNIFlag )
494570 ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
495571 defer cancel ()
496572 ts := newTestServerWithHandshakeFunc (ctx , test .handshakeFunc )
@@ -508,7 +584,7 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) {
508584 }
509585 defer conn .Close ()
510586
511- ctx = newTestContextWithHandshakeInfo (ctx , test .rootProvider , nil , test .san )
587+ ctx = newTestContextWithHandshakeInfo (ctx , test .rootProvider , nil , test .san , test . sni , test . validateSANUsingSNI )
512588 if _ , _ , err := creds .ClientHandshake (ctx , authority , conn ); err == nil || ! strings .Contains (err .Error (), test .wantErr ) {
513589 t .Fatalf ("ClientHandshake() returned %q, wantErr %q" , err , test .wantErr )
514590 }
@@ -542,7 +618,7 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
542618 // Create a root provider which will fail the handshake because it does not
543619 // use the correct trust roots.
544620 root1 := makeRootProvider (t , "x509/client_ca_cert.pem" )
545- handshakeInfo := xdsinternal .NewHandshakeInfo (root1 , nil , []matcher.StringMatcher {matcher .NewExactStringMatcher (defaultTestCertSAN , false )}, false )
621+ handshakeInfo := xdsinternal .NewHandshakeInfo (root1 , nil , []matcher.StringMatcher {matcher .NewExactStringMatcher (defaultTestCertSAN , false )}, false , "" , false )
546622 // We need to repeat most of what newTestContextWithHandshakeInfo() does
547623 // here because we need access to the underlying HandshakeInfo so that we
548624 // can update it before the next call to ClientHandshake().
@@ -569,7 +645,7 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
569645 // Create a new root provider which uses the correct trust roots. And update
570646 // the HandshakeInfo with the new provider.
571647 root2 := makeRootProvider (t , "x509/server_ca_cert.pem" )
572- handshakeInfo = xdsinternal .NewHandshakeInfo (root2 , nil , []matcher.StringMatcher {matcher .NewExactStringMatcher (defaultTestCertSAN , false )}, false )
648+ handshakeInfo = xdsinternal .NewHandshakeInfo (root2 , nil , []matcher.StringMatcher {matcher .NewExactStringMatcher (defaultTestCertSAN , false )}, false , "" , false )
573649 // Update the existing pointer, which address attribute will continue to
574650 // point to.
575651 hiPtr .Store (handshakeInfo )
0 commit comments