@@ -21,11 +21,18 @@ import (
2121 "sync"
2222 "sync/atomic"
2323 "time"
24+
25+ "go-msspi"
2426)
2527
2628// A Conn represents a secured connection.
2729// It implements the net.Conn interface.
2830type Conn struct {
31+ // msspi
32+ msspiConn bool
33+ msspiErr error
34+ msspi * msspi.Handler
35+
2936 // constant
3037 conn net.Conn
3138 isClient bool
@@ -1198,6 +1205,31 @@ func (c *Conn) Write(b []byte) (int, error) {
11981205 c .out .Lock ()
11991206 defer c .out .Unlock ()
12001207
1208+ // msspi
1209+ if c .msspiConn {
1210+ if c .msspi != nil {
1211+ n , err := c .msspi .Write (b )
1212+ if n > 0 {
1213+ return n , err
1214+ }
1215+
1216+ if err == nil {
1217+ if c .msspi .State (1 ) {
1218+ err = net .ErrClosed
1219+ } else if c .msspi .State (2 ) {
1220+ err = errShutdown
1221+ } else {
1222+ err = io .EOF
1223+ }
1224+ }
1225+
1226+ c .out .setErrorLocked (err )
1227+ return n , c .out .err
1228+ } else {
1229+ return 0 , net .ErrClosed
1230+ }
1231+ }
1232+
12011233 if err := c .out .err ; err != nil {
12021234 return 0 , err
12031235 }
@@ -1366,6 +1398,31 @@ func (c *Conn) Read(b []byte) (int, error) {
13661398 c .in .Lock ()
13671399 defer c .in .Unlock ()
13681400
1401+ // msspi
1402+ if c .msspiConn {
1403+ if c .msspi != nil {
1404+ n , err := c .msspi .Read (b )
1405+ if n > 0 {
1406+ return n , err
1407+ }
1408+
1409+ if err == nil {
1410+ if c .msspi .State (1 ) {
1411+ err = net .ErrClosed
1412+ } else if c .msspi .State (2 ) {
1413+ err = io .EOF
1414+ } else {
1415+ err = io .EOF
1416+ }
1417+ }
1418+
1419+ c .in .setErrorLocked (err )
1420+ return n , c .in .err
1421+ } else {
1422+ return 0 , net .ErrClosed
1423+ }
1424+ }
1425+
13691426 for c .input .Len () == 0 {
13701427 if err := c .readRecord (); err != nil {
13711428 return 0 , err
@@ -1419,6 +1476,17 @@ func (c *Conn) Close() error {
14191476 return c .conn .Close ()
14201477 }
14211478
1479+ // msspi
1480+ if c .msspiConn {
1481+ if c .msspi != nil {
1482+ err := c .msspi .Close ()
1483+ c .msspi = nil
1484+ return err
1485+ } else {
1486+ return net .ErrClosed
1487+ }
1488+ }
1489+
14221490 var alertErr error
14231491 if c .isHandshakeComplete .Load () {
14241492 if err := c .closeNotify (); err != nil {
@@ -1449,6 +1517,15 @@ func (c *Conn) closeNotify() error {
14491517 c .out .Lock ()
14501518 defer c .out .Unlock ()
14511519
1520+ // msspi
1521+ if c .msspiConn {
1522+ if c .msspi != nil {
1523+ return c .msspi .Shutdown ()
1524+ } else {
1525+ return errShutdown
1526+ }
1527+ }
1528+
14521529 if ! c .closeNotifySent {
14531530 // Set a Write Deadline to prevent possibly blocking forever.
14541531 c .SetWriteDeadline (time .Now ().Add (time .Second * 5 ))
@@ -1593,6 +1670,84 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
15931670 return c .handshakeErr
15941671}
15951672
1673+ // msspi
1674+ // msspiHandshake handshakes msspi
1675+ func (c * Conn ) msspiHandshake (ctx context.Context ) error {
1676+ if c .msspi == nil {
1677+ return c .msspiErr
1678+ }
1679+
1680+ err := c .msspi .Handshake ()
1681+ if err != nil {
1682+ return err
1683+ }
1684+
1685+ c .vers = c .msspi .VersionTLS ()
1686+ c .cipherSuite = c .msspi .CipherSuite ()
1687+ c .clientProtocol = c .msspi .ClientProtocol ()
1688+
1689+ if c .config .ServerName != "" {
1690+ c .serverName = c .config .ServerName
1691+ }
1692+
1693+ var isPeerCertsRequest bool
1694+ var isPeerCertsRequire bool
1695+ var isPeerCertsVerify bool
1696+
1697+ if c .isClient {
1698+ isPeerCertsRequest = true
1699+ isPeerCertsRequire = true
1700+ isPeerCertsVerify = ! c .config .InsecureSkipVerify
1701+ } else {
1702+ isPeerCertsRequest = c .config .ClientAuth != NoClientCert
1703+ isPeerCertsRequire = requiresClientCert (c .config .ClientAuth )
1704+ isPeerCertsVerify = c .config .ClientAuth >= VerifyClientCertIfGiven
1705+ }
1706+
1707+ if isPeerCertsRequest {
1708+ certificates := c .msspi .PeerCertificates ()
1709+
1710+ certs := make ([]* x509.Certificate , len (certificates ))
1711+ for i , asn1Data := range certificates {
1712+ cert , err := x509 .ParseCertificate (asn1Data )
1713+ if err == nil {
1714+ certs [i ] = cert
1715+ }
1716+ }
1717+
1718+ c .peerCertificates = certs
1719+ }
1720+
1721+ isPeerCerts := len (c .peerCertificates ) != 0
1722+
1723+ if isPeerCertsRequire && ! isPeerCerts {
1724+ return errors .New ("tls: peer didn't provide a certificate" )
1725+ }
1726+
1727+ if isPeerCertsVerify && isPeerCerts {
1728+ certificates := c .msspi .VerifiedChains ()
1729+
1730+ if certificates == nil {
1731+ return errors .New ("tls: failed to verify peer certificate" )
1732+ }
1733+
1734+ certs := make ([]* x509.Certificate , len (certificates ))
1735+ for i , asn1Data := range certificates {
1736+ cert , err := x509 .ParseCertificate (asn1Data )
1737+ if err == nil {
1738+ certs [i ] = cert
1739+ }
1740+ }
1741+
1742+ var verifiedChains [][]* x509.Certificate
1743+ c .verifiedChains = append (verifiedChains , certs )
1744+ }
1745+
1746+ atomic .StoreUint32 (& c .handshakeStatus , 1 )
1747+
1748+ return nil
1749+ }
1750+
15961751// ConnectionState returns basic TLS details about the connection.
15971752func (c * Conn ) ConnectionState () ConnectionState {
15981753 c .handshakeMutex .Lock ()
0 commit comments