Skip to content

Commit 49934b6

Browse files
committed
go-msspi patches
1 parent cb4eee6 commit 49934b6

7 files changed

Lines changed: 282 additions & 0 deletions

File tree

.gitmodules

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[submodule "src/go-msspi"]
2+
path = src/go-msspi
3+
url = https://github.com/deemru/go-msspi.git
4+
[submodule "src/go-pointer"]
5+
path = src/go-pointer
6+
url = https://github.com/mattn/go-pointer.git

src/crypto/tls/common.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,9 @@ const (
525525
// modified. A Config may be reused; the tls package will also not
526526
// modify it.
527527
type Config struct {
528+
// msspi
529+
msspiConfig bool
530+
528531
// Rand provides the source of entropy for nonces and RSA blinding.
529532
// If Rand is nil, TLS uses the cryptographic random reader in package
530533
// crypto/rand.
@@ -1410,6 +1413,9 @@ var writerMutex sync.Mutex
14101413

14111414
// A Certificate is a chain of one or more certificates, leaf first.
14121415
type Certificate struct {
1416+
// msspi
1417+
msspiCert bool
1418+
14131419
Certificate [][]byte
14141420
// PrivateKey contains the private key corresponding to the public key in
14151421
// Leaf. This must implement crypto.Signer with an RSA, ECDSA or Ed25519 PublicKey.

src/crypto/tls/conn.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,18 @@ import (
2121
"sync"
2222
"sync/atomic"
2323
"time"
24+
25+
"go-msspi"
2426
)
2527

2628
// A Conn represents a secured connection.
2729
// It implements the net.Conn interface.
2830
type Conn struct {
31+
// msspi
32+
msspiConn bool
33+
msspiErr error
34+
msspi *msspi.Handler
35+
2936
// constant
3037
conn net.Conn
3138
isClient bool
@@ -1198,6 +1205,31 @@ func (c *Conn) Write(b []byte) (int, error) {
11981205
c.out.Lock()
11991206
defer c.out.Unlock()
12001207

1208+
// msspi
1209+
if c.msspiConn {
1210+
if c.msspi != nil {
1211+
n, err := c.msspi.Write(b)
1212+
if n > 0 {
1213+
return n, err
1214+
}
1215+
1216+
if err == nil {
1217+
if c.msspi.State(1) {
1218+
err = net.ErrClosed
1219+
} else if c.msspi.State(2) {
1220+
err = errShutdown
1221+
} else {
1222+
err = io.EOF
1223+
}
1224+
}
1225+
1226+
c.out.setErrorLocked(err)
1227+
return n, c.out.err
1228+
} else {
1229+
return 0, net.ErrClosed
1230+
}
1231+
}
1232+
12011233
if err := c.out.err; err != nil {
12021234
return 0, err
12031235
}
@@ -1366,6 +1398,31 @@ func (c *Conn) Read(b []byte) (int, error) {
13661398
c.in.Lock()
13671399
defer c.in.Unlock()
13681400

1401+
// msspi
1402+
if c.msspiConn {
1403+
if c.msspi != nil {
1404+
n, err := c.msspi.Read(b)
1405+
if n > 0 {
1406+
return n, err
1407+
}
1408+
1409+
if err == nil {
1410+
if c.msspi.State(1) {
1411+
err = net.ErrClosed
1412+
} else if c.msspi.State(2) {
1413+
err = io.EOF
1414+
} else {
1415+
err = io.EOF
1416+
}
1417+
}
1418+
1419+
c.in.setErrorLocked(err)
1420+
return n, c.in.err
1421+
} else {
1422+
return 0, net.ErrClosed
1423+
}
1424+
}
1425+
13691426
for c.input.Len() == 0 {
13701427
if err := c.readRecord(); err != nil {
13711428
return 0, err
@@ -1419,6 +1476,17 @@ func (c *Conn) Close() error {
14191476
return c.conn.Close()
14201477
}
14211478

1479+
// msspi
1480+
if c.msspiConn {
1481+
if c.msspi != nil {
1482+
err := c.msspi.Close()
1483+
c.msspi = nil
1484+
return err
1485+
} else {
1486+
return net.ErrClosed
1487+
}
1488+
}
1489+
14221490
var alertErr error
14231491
if c.isHandshakeComplete.Load() {
14241492
if err := c.closeNotify(); err != nil {
@@ -1449,6 +1517,15 @@ func (c *Conn) closeNotify() error {
14491517
c.out.Lock()
14501518
defer c.out.Unlock()
14511519

1520+
// msspi
1521+
if c.msspiConn {
1522+
if c.msspi != nil {
1523+
return c.msspi.Shutdown()
1524+
} else {
1525+
return errShutdown
1526+
}
1527+
}
1528+
14521529
if !c.closeNotifySent {
14531530
// Set a Write Deadline to prevent possibly blocking forever.
14541531
c.SetWriteDeadline(time.Now().Add(time.Second * 5))
@@ -1593,6 +1670,84 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
15931670
return c.handshakeErr
15941671
}
15951672

1673+
// msspi
1674+
// msspiHandshake handshakes msspi
1675+
func (c *Conn) msspiHandshake(ctx context.Context) error {
1676+
if c.msspi == nil {
1677+
return c.msspiErr
1678+
}
1679+
1680+
err := c.msspi.Handshake()
1681+
if err != nil {
1682+
return err
1683+
}
1684+
1685+
c.vers = c.msspi.VersionTLS()
1686+
c.cipherSuite = c.msspi.CipherSuite()
1687+
c.clientProtocol = c.msspi.ClientProtocol()
1688+
1689+
if c.config.ServerName != "" {
1690+
c.serverName = c.config.ServerName
1691+
}
1692+
1693+
var isPeerCertsRequest bool
1694+
var isPeerCertsRequire bool
1695+
var isPeerCertsVerify bool
1696+
1697+
if c.isClient {
1698+
isPeerCertsRequest = true
1699+
isPeerCertsRequire = true
1700+
isPeerCertsVerify = !c.config.InsecureSkipVerify
1701+
} else {
1702+
isPeerCertsRequest = c.config.ClientAuth != NoClientCert
1703+
isPeerCertsRequire = requiresClientCert(c.config.ClientAuth)
1704+
isPeerCertsVerify = c.config.ClientAuth >= VerifyClientCertIfGiven
1705+
}
1706+
1707+
if isPeerCertsRequest {
1708+
certificates := c.msspi.PeerCertificates()
1709+
1710+
certs := make([]*x509.Certificate, len(certificates))
1711+
for i, asn1Data := range certificates {
1712+
cert, err := x509.ParseCertificate(asn1Data)
1713+
if err == nil {
1714+
certs[i] = cert
1715+
}
1716+
}
1717+
1718+
c.peerCertificates = certs
1719+
}
1720+
1721+
isPeerCerts := len(c.peerCertificates) != 0
1722+
1723+
if isPeerCertsRequire && !isPeerCerts {
1724+
return errors.New("tls: peer didn't provide a certificate")
1725+
}
1726+
1727+
if isPeerCertsVerify && isPeerCerts {
1728+
certificates := c.msspi.VerifiedChains()
1729+
1730+
if certificates == nil {
1731+
return errors.New("tls: failed to verify peer certificate")
1732+
}
1733+
1734+
certs := make([]*x509.Certificate, len(certificates))
1735+
for i, asn1Data := range certificates {
1736+
cert, err := x509.ParseCertificate(asn1Data)
1737+
if err == nil {
1738+
certs[i] = cert
1739+
}
1740+
}
1741+
1742+
var verifiedChains [][]*x509.Certificate
1743+
c.verifiedChains = append(verifiedChains, certs)
1744+
}
1745+
1746+
atomic.StoreUint32(&c.handshakeStatus, 1)
1747+
1748+
return nil
1749+
}
1750+
15961751
// ConnectionState returns basic TLS details about the connection.
15971752
func (c *Conn) ConnectionState() ConnectionState {
15981753
c.handshakeMutex.Lock()

src/crypto/tls/tls.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"net"
2626
"os"
2727
"strings"
28+
29+
"go-msspi"
2830
)
2931

3032
// Server returns a new TLS server side connection
@@ -37,6 +39,28 @@ func Server(conn net.Conn, config *Config) *Conn {
3739
config: config,
3840
}
3941
c.handshakeFn = c.serverHandshake
42+
43+
// msspi
44+
if msspi.ByDefault || c.config.msspiConfig || (len(c.config.Certificates) > 0 && c.config.Certificates[0].msspiCert) {
45+
c.msspiConn = true
46+
c.handshakeFn = c.msspiHandshake
47+
48+
CertificateBytes := [][]byte{}
49+
for _, cert := range config.Certificates {
50+
CertificateBytes = append(CertificateBytes, cert.Certificate[0])
51+
}
52+
53+
c.msspi, c.msspiErr = msspi.Server(&c.conn, CertificateBytes, c.config.ClientAuth != NoClientCert)
54+
55+
if c.msspi == nil {
56+
return c
57+
}
58+
59+
if len(config.NextProtos) > 0 {
60+
c.msspi.SetNextProtos(config.NextProtos)
61+
}
62+
}
63+
4064
return c
4165
}
4266

@@ -51,6 +75,28 @@ func Client(conn net.Conn, config *Config) *Conn {
5175
isClient: true,
5276
}
5377
c.handshakeFn = c.clientHandshake
78+
79+
// msspi
80+
if msspi.ByDefault || c.config.msspiConfig || (len(c.config.Certificates) > 0 && c.config.Certificates[0].msspiCert) {
81+
c.msspiConn = true
82+
c.handshakeFn = c.msspiHandshake
83+
84+
CertificateBytes := [][]byte{}
85+
for _, cert := range config.Certificates {
86+
CertificateBytes = append(CertificateBytes, cert.Certificate[0])
87+
}
88+
89+
c.msspi, c.msspiErr = msspi.Client(&c.conn, CertificateBytes, c.config.ServerName)
90+
91+
if c.msspi == nil {
92+
return c
93+
}
94+
95+
if len(config.NextProtos) > 0 {
96+
c.msspi.SetNextProtos(config.NextProtos)
97+
}
98+
}
99+
54100
return c
55101
}
56102

@@ -245,6 +291,14 @@ func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
245291
func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
246292
fail := func(err error) (Certificate, error) { return Certificate{}, err }
247293

294+
// msspi
295+
if bytes.Equal(certPEMBlock, keyPEMBlock) {
296+
var cert Certificate
297+
cert.Certificate = append(cert.Certificate, certPEMBlock)
298+
cert.msspiCert = true
299+
return cert, nil
300+
}
301+
248302
var cert Certificate
249303
var skippedBlockTypes []string
250304
for {

src/go-msspi

Submodule go-msspi added at a24c4e7

src/go-msspi-test/msspi_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package msspitest
2+
3+
import (
4+
"crypto/tls"
5+
"fmt"
6+
"log"
7+
"net"
8+
"net/http"
9+
"runtime"
10+
"testing"
11+
"time"
12+
)
13+
14+
func clientGet(t *testing.T, host string, uri string) []byte {
15+
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, 44333), 2*time.Second)
16+
if err != nil {
17+
t.Fatalf("Unexpected error on dial: %v", err)
18+
}
19+
defer conn.Close()
20+
21+
tlsConn := tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
22+
defer tlsConn.Close()
23+
24+
wbuf := []byte("GET " + uri + " HTTP/1.1\r\nHost: " + host + "\r\n\r\n")
25+
if wlen, err := tlsConn.Write(wbuf); wlen != len(wbuf) || err != nil {
26+
t.Fatalf("Error sending: %v", err)
27+
}
28+
29+
rbuf := make([]byte, 16384)
30+
if rlen, err := tlsConn.Read(rbuf); rlen == 0 || err != nil {
31+
t.Fatalf("Error reading: %v", err)
32+
return nil
33+
} else {
34+
return rbuf[:rlen]
35+
}
36+
}
37+
38+
func HelloServer(w http.ResponseWriter, req *http.Request) {
39+
w.Header().Set("Content-Type", "text/plain")
40+
w.Write([]byte("This is an example server.\n"))
41+
}
42+
43+
func RunServer() {
44+
http.HandleFunc("/hello", HelloServer)
45+
err := http.ListenAndServeTLS(":44333", "server.crt", "server.crt", nil)
46+
if err != nil {
47+
log.Fatal("ListenAndServe: ", err)
48+
}
49+
}
50+
func TestMsspiServer(t *testing.T) {
51+
go RunServer()
52+
rbuf := clientGet(t, "localhost", "/hello")
53+
s := string(rbuf)
54+
fmt.Printf(s + "\n")
55+
for i := 0; i < 9999999; i++ {
56+
time.Sleep(time.Second)
57+
runtime.GC()
58+
}
59+
}

src/go-pointer

Submodule go-pointer added at 90e3959

0 commit comments

Comments
 (0)