Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ var (
ErrUnsupportedRSVBits = errors.New("frame has unsupported RSV bits set")
)

var zeroMask [4]byte

// ClientKey is a websocket client key.
type ClientKey string

Expand Down Expand Up @@ -157,7 +159,7 @@ func ReadFrame(buf io.Reader) (*Frame, error) {
// read & optionally unmask payload
payload := make([]byte, payloadLength)
if _, err := io.ReadFull(buf, payload); err != nil {
return nil, fmt.Errorf("error reading payload: %w", err)
return nil, fmt.Errorf("error reading %d byte payload: %w", payloadLength, err)
}
if masked {
for i, b := range payload {
Expand All @@ -176,11 +178,18 @@ func ReadFrame(buf io.Reader) (*Frame, error) {
}, nil
}

// WriteFrame writes a websocket frame to the wire.
// WriteFrame writes an unmasked (i.e. server-side) websocket frame to the
// wire.
func WriteFrame(dst io.Writer, frame *Frame) error {
return WriteFrameMasked(dst, frame, zeroMask)
}

// WriteFrameMasked writes a masked websocket frame to the wire.
func WriteFrameMasked(dst io.Writer, frame *Frame, mask [4]byte) error {
// worst case payload size is 13 header bytes + payload size, where 13 is
// (1 byte header) + (1-8 byte length) + (0-4 byte mask key)
buf := make([]byte, 0, 13+len(frame.Payload))
masked := mask != zeroMask

// FIN, RSV1-3, OPCODE
var b0 byte
Expand All @@ -199,21 +208,33 @@ func WriteFrame(dst io.Writer, frame *Frame) error {
b0 |= uint8(frame.Opcode) & 0b00001111
buf = append(buf, b0)

// payload length
// Masked bit, payload length
var b1 byte
if masked {
b1 |= 0b10000000
}

payloadLen := int64(len(frame.Payload))
switch {
case payloadLen <= 125:
buf = append(buf, byte(payloadLen))
buf = append(buf, b1|byte(payloadLen))
case payloadLen <= 65535:
buf = append(buf, 126)
buf = append(buf, b1|126)
buf = binary.BigEndian.AppendUint16(buf, uint16(payloadLen))
default:
buf = append(buf, 127)
buf = append(buf, b1|127)
buf = binary.BigEndian.AppendUint64(buf, uint64(payloadLen))
}

// payload
buf = append(buf, frame.Payload...)
// Optional masking key and actual payload
if masked {
buf = append(buf, mask[:]...)
for i, b := range frame.Payload {
buf = append(buf, b^mask[i%4])
}
} else {
buf = append(buf, frame.Payload...)
}

n, err := dst.Write(buf)
if err != nil {
Expand Down
34 changes: 34 additions & 0 deletions proto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package websocket_test

import (
"bytes"
"testing"

"github.com/mccutchen/websocket"
"github.com/mccutchen/websocket/internal/testing/assert"
)

func TestFrameRoundTrip(t *testing.T) {
// Basic test to ensure that we can read back the same frame that we
// write.
t.Parallel()

// write masked "client" frame to buffer
clientFrame := &websocket.Frame{
Opcode: websocket.OpcodeText,
Fin: true,
Payload: []byte("hello"),
}
mask := [4]byte{1, 2, 3, 4}
buf := &bytes.Buffer{}
assert.NilError(t, websocket.WriteFrameMasked(buf, clientFrame, mask))

// read "server" frame from buffer
serverFrame, err := websocket.ReadFrame(buf)
assert.NilError(t, err)

// ensure client and server frame match
assert.Equal(t, serverFrame.Fin, clientFrame.Fin, "expected matching FIN bits")
assert.Equal(t, serverFrame.Opcode, clientFrame.Opcode, "expected matching opcodes")
assert.Equal(t, string(serverFrame.Payload), string(clientFrame.Payload), "expected matching payloads")
}
78 changes: 76 additions & 2 deletions websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package websocket_test
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -291,6 +293,69 @@ func TestConnectionLimits(t *testing.T) {
})
}

// TODO: flesh out basic protocol test cases
// - successful echo
// - successful echo across multiple frames
// - frame size limits
// - message size limits
// - utf8 validation
// - unexpected continuation frames
func TestProtocolBasics(t *testing.T) {
var (
maxDuration = 250 * time.Millisecond
maxFragmentSize = 16
maxMessageSize = 32
)

echoHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := websocket.Accept(w, r, websocket.Options{
ReadTimeout: maxDuration,
WriteTimeout: maxDuration,
MaxFragmentSize: maxFragmentSize,
MaxMessageSize: maxMessageSize,
Hooks: newTestHooks(t),
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
ws.Serve(r.Context(), websocket.EchoHandler)
})

t.Run("basic echo functionality", func(t *testing.T) {
t.Parallel()
_, conn := setupRawConn(t, echoHandler)
// write client frame
clientFrame := &websocket.Frame{
Opcode: websocket.OpcodeText,
Fin: true,
Payload: []byte("hello"),
}
mask := [4]byte{1, 2, 3, 4}
assert.NilError(t, websocket.WriteFrameMasked(conn, clientFrame, mask))
// read server frame
serverFrame, err := websocket.ReadFrame(conn)
assert.NilError(t, err)
// ensure we get back the same frame
assert.Equal(t, serverFrame.Fin, clientFrame.Fin, "expected matching FIN bits")
assert.Equal(t, serverFrame.Opcode, clientFrame.Opcode, "expected matching opcodes")
assert.Equal(t, string(serverFrame.Payload), string(clientFrame.Payload), "expected matching payloads")
})

t.Run("server requires masked frames", func(t *testing.T) {
t.Parallel()
_, conn := setupRawConn(t, echoHandler)
frame := &websocket.Frame{
Opcode: websocket.OpcodeText,
Fin: true,
Payload: []byte("hello"),
}
assert.NilError(t, websocket.WriteFrame(conn, frame))
validateCloseFrame(t, conn, websocket.StatusProtocolError, "received unmasked client frame")
})

}

// setupRawConn is a test helpers that runs a test server and does the
// initial websocket handshake. The returned connection is ready for use to
// sent/receive websocket messages.
Expand All @@ -310,7 +375,7 @@ func setupRawConn(t *testing.T, handler http.Handler) (*httptest.Server, net.Con
for k, v := range map[string]string{
"Connection": "upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Key": makeClientKey(),
"Sec-WebSocket-Version": "13",
} {
handshakeReq.Header.Set(k, v)
Expand All @@ -325,6 +390,15 @@ func setupRawConn(t *testing.T, handler http.Handler) (*httptest.Server, net.Con
return srv, conn
}

func makeClientKey() string {
b := make([]byte, 16)
_, err := rand.Read(b)
if err != nil {
panic(fmt.Sprintf("failed to read random bytes: %s", err))
}
return base64.StdEncoding.EncodeToString(b)
}

// validateCloseFrame ensures that we can read a close frame from the given
// reader and optionally ensures that the close frame includes a specific
// status code and message.
Expand All @@ -344,7 +418,7 @@ func validateCloseFrame(t *testing.T, r io.Reader, wantStatus websocket.StatusCo
statusCode := websocket.StatusCode(binary.BigEndian.Uint16(frame.Payload[:2]))
closeMsg := string(frame.Payload[2:])
t.Logf("got close frame: code=%v msg=%q", statusCode, closeMsg)
assert.Equal(t, statusCode, websocket.StatusServerError, "got incorrect close status code")
assert.Equal(t, int(statusCode), int(wantStatus), "got incorrect close status code")
assert.Contains(t, closeMsg, wantMsg, "got incorrect close message")
}

Expand Down
Loading