Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 125 additions & 90 deletions proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ type Opcode uint8
// See the RFC for the set of defined opcodes:
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
const (
OpcodeContinuation Opcode = 0x0
OpcodeText Opcode = 0x1
OpcodeBinary Opcode = 0x2
OpcodeClose Opcode = 0x8
OpcodePing Opcode = 0x9
OpcodePong Opcode = 0xA
OpcodeContinuation Opcode = 0b0000_0000 // 0x0
OpcodeText Opcode = 0b0000_0001 // 0x1
OpcodeBinary Opcode = 0b0000_0010 // 0x2
OpcodeClose Opcode = 0b0000_1000 // 0x8
OpcodePing Opcode = 0b0000_1001 // 0x9
OpcodePong Opcode = 0b0000_1010 // 0xA
Comment on lines -39 to +44
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tbh this extra verbose binary representation makes it easier for me to understand all the bitwise operations happening here.

)

func (c Opcode) String() string {
Expand Down Expand Up @@ -82,33 +82,91 @@ const (
StatusServerError StatusCode = 1011
)

// RSV bit masks
const (
rsv1mask = 0b100
rsv2mask = 0b010
rsv3mask = 0b001
// first header byte
finMask = 0b1000_0000
opcodeMask = 0b0000_1111

// second header byte
maskedMask = 0b1000_0000
payloadLenMask = 0b0111_1111
)

// RSVBit is a bit mask for RSV bits 1-3
type RSVBit byte

// RSV bits 1-3
const (
RSV1 RSVBit = 0b0100_0000
RSV2 RSVBit = 0b0010_0000
RSV3 RSVBit = 0b0001_0000
)

// NewFrame creates a new websocket frame with the given opcode, fin bit, and
// payload.
func NewFrame(opcode Opcode, fin bool, payload []byte, rsv ...RSVBit) *Frame {
f := &Frame{
payload: payload,
}
// Encode FIN, RSV1-3, and OPCODE.
//
// The second header byte encodes mask bit and payload size, but in-memory
// frames are unmasked and the payload size is directly accessible.
//
// These bits will be encoded as necessary when writing a frame.
if fin {
f.header |= finMask
}
for _, r := range rsv {
f.header |= byte(r)
}
f.header |= uint8(opcode) & opcodeMask
return f
}

// NewCloseFrame creates a close frame with an optional error message.
func NewCloseFrame(code StatusCode, reason string) *Frame {
var payload []byte
if code > 0 {
payload = make([]byte, 0, 2+len(reason))
payload = binary.BigEndian.AppendUint16(payload, uint16(code))
payload = append(payload, []byte(reason)...)
}
return NewFrame(OpcodeClose, true, payload)
}

// Frame is a websocket protocol frame.
type Frame struct {
Fin bool
RSV byte // Bits 0-2 represent RSV1-3
Opcode Opcode
Payload []byte
Masked bool
header byte
payload []byte
}

// Fin returns a bool indicating whether the frame's FIN bit is set.
func (f *Frame) Fin() bool {
return f.header&finMask != 0
}

// Opcode returns the the frame's OPCODE.
func (f *Frame) Opcode() Opcode {
return Opcode(f.header & opcodeMask)
}

// RSV1 returns true if the frame's RSV1 bit is set
func (f *Frame) RSV1() bool { return f.RSV&rsv1mask != 0 }
func (f *Frame) RSV1() bool { return f.header&byte(RSV1) != 0 }

// RSV2 returns true if the frame's RSV2 bit is set
func (f *Frame) RSV2() bool { return f.RSV&rsv2mask != 0 }
func (f *Frame) RSV2() bool { return f.header&byte(RSV2) != 0 }

// RSV3 returns true if the frame's RSV3 bit is set
func (f *Frame) RSV3() bool { return f.RSV&rsv3mask != 0 }
func (f *Frame) RSV3() bool { return f.header&byte(RSV3) != 0 }

// Payload returns the frame's payload.
func (f *Frame) Payload() []byte {
return f.payload
}

func (f Frame) String() string {
return fmt.Sprintf("Frame{Fin: %v, Opcode: %v, Payload: %s}", f.Fin, f.Opcode, formatPayload(f.Payload))
return fmt.Sprintf("Frame{Fin: %v, Opcode: %v, Payload: %s}", f.Fin(), f.Opcode(), formatPayload(f.Payload()))
}

// Message is an application-level message from the client, which may be
Expand All @@ -132,27 +190,28 @@ var formatPayload = func(p []byte) string {
}

// ReadFrame reads a websocket frame from the wire.
func ReadFrame(buf io.Reader, maxPayloadLen int) (*Frame, error) {
bb := make([]byte, 2)
if _, err := io.ReadFull(buf, bb); err != nil {
func ReadFrame(buf io.Reader, mode Mode, maxPayloadLen int) (*Frame, error) {
header := make([]byte, 2)
if _, err := io.ReadFull(buf, header); err != nil {
return nil, fmt.Errorf("error reading frame header: %w", err)
}

// parse first header byte
var (
b0 = bb[0]
fin = b0&0b1000_0000 != 0
rsv = (b0 >> 4) & 0b0111
opcode = Opcode(b0 & 0b0000_1111)
)
frame := &Frame{
header: header[0],
}

// parse second header byte
// figure out how to parse payload
var (
b1 = bb[1]
masked = b1&0b1000_0000 != 0
payloadLen = uint64(b1 & 0b0111_1111)
masked = header[1]&maskedMask != 0
payloadLen = uint64(header[1] & payloadLenMask)
)

// If the data is being sent by the client, the frame(s) MUST be masked
// https://datatracker.ietf.org/doc/html/rfc6455#section-6.1
if mode == ServerMode && !masked {
return nil, ErrClientFrameUnmasked
}

// Payload lengths 0 to 125 encoded directly in last 7 bits of payload
// field, but we may need to read an "extended" payload length.
switch payloadLen {
Expand Down Expand Up @@ -186,23 +245,16 @@ func ReadFrame(buf io.Reader, maxPayloadLen int) (*Frame, error) {
}

// read & optionally unmask payload
payload := make([]byte, payloadLen)
if _, err := io.ReadFull(buf, payload); err != nil {
frame.payload = make([]byte, payloadLen)
if _, err := io.ReadFull(buf, frame.Payload()); err != nil {
return nil, fmt.Errorf("error reading %d byte payload: %w", payloadLen, err)
}
if masked {
for i, b := range payload {
payload[i] = b ^ mask[i%4]
for i, b := range frame.Payload() {
frame.Payload()[i] = b ^ mask[i%4]
}
}

return &Frame{
Fin: fin,
RSV: rsv,
Opcode: opcode,
Payload: payload,
Masked: masked,
}, nil
return frame, nil
}

// WriteFrame writes a masked websocket frame to the wire.
Expand All @@ -218,25 +270,18 @@ func WriteFrame(dst io.Writer, mask MaskingKey, frame *Frame) error {
func MarshalFrame(frame *Frame, mask MaskingKey) []byte {
// worst case payload size is 13 header bytes + payload size, where 13 is
// (1 byte header) + (1-8 byte length) + (0-4 byte mask key)
buf := make([]byte, 0, 13+len(frame.Payload))
buf := make([]byte, 0, marshaledSize(frame, mask))
masked := mask != Unmasked

// FIN, RSV1-3, OPCODE
var b0 byte
if frame.Fin {
b0 |= 0b1000_0000
}
b0 |= (frame.RSV & 0b111) << 4
b0 |= uint8(frame.Opcode) & 0b0000_1111
buf = append(buf, b0)
buf = append(buf, frame.header)

// Masked bit, payload length
var b1 byte
if masked {
b1 |= 0b1000_0000
}

payloadLen := int64(len(frame.Payload))
payloadLen := int64(len(frame.Payload()))
switch {
case payloadLen <= 125:
buf = append(buf, b1|byte(payloadLen))
Expand All @@ -251,33 +296,34 @@ func MarshalFrame(frame *Frame, mask MaskingKey) []byte {
// Optional masking key and actual payload
if masked {
buf = append(buf, mask[:]...)
for i, b := range frame.Payload {
for i, b := range frame.Payload() {
buf = append(buf, b^mask[i%4])
}
} else {
buf = append(buf, frame.Payload...)
buf = append(buf, frame.Payload()...)
}
return buf
}

// NewCloseFrame creates a close frame with an optional error message.
func NewCloseFrame(code StatusCode, reason string) *Frame {
var payload []byte
if code > 0 {
payload = make([]byte, 0, 2+len(reason))
payload = binary.BigEndian.AppendUint16(payload, uint16(code))
payload = append(payload, []byte(reason)...)
// marshaledSize returns the number of bytes required to marshal a frame.
func marshaledSize(f *Frame, mask MaskingKey) int {
payloadLen := len(f.Payload())
size := 2 + payloadLen
switch {
case payloadLen >= 64<<10:
size += 8
case payloadLen > 125:
size += 2
}
return &Frame{
Fin: true,
Opcode: OpcodeClose,
Payload: payload,
if mask != Unmasked {
size += 4
}
return size
}

// messageFrames splits a message into N frames with payloads of at most
// FrameMessage splits a message into N frames with payloads of at most
// frameSize bytes.
func messageFrames(msg *Message, frameSize int) []*Frame {
func FrameMessage(msg *Message, frameSize int) []*Frame {
var result []*Frame

fin := false
Expand All @@ -297,11 +343,7 @@ func messageFrames(msg *Message, frameSize int) []*Frame {
fin = true
end = dataLen
}
result = append(result, &Frame{
Fin: fin,
Opcode: opcode,
Payload: msg.Payload[offset:end],
})
result = append(result, NewFrame(opcode, fin, msg.Payload[offset:end]))
if fin {
break
}
Expand All @@ -327,50 +369,43 @@ var reservedStatusCodes = map[uint16]bool{
2999: true,
}

func validateFrame(frame *Frame, mode Mode) error {
func validateFrame(frame *Frame) error {
// We do not support any extensions, per the spec all RSV bits must be 0:
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
if frame.RSV != 0 {
if frame.header&uint8(RSV1|RSV2|RSV3) != 0 {
return ErrRSVBitsUnsupported
}

// If the data is being sent by the client, the frame(s) MUST be masked
// https://datatracker.ietf.org/doc/html/rfc6455#section-6.1
if mode == ServerMode && !frame.Masked {
return ErrClientFrameUnmasked
}

payloadLen := len(frame.Payload)

switch frame.Opcode {
payloadLen := len(frame.Payload())
switch frame.Opcode() {
case OpcodeClose, OpcodePing, OpcodePong:
// All control frames MUST have a payload length of 125 bytes or less
// and MUST NOT be fragmented.
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.5
if payloadLen > 125 {
return ErrControlFrameTooLarge
}
if !frame.Fin {
if !frame.Fin() {
return ErrControlFrameFragmented
}
}

if frame.Opcode == OpcodeClose {
if frame.Opcode() == OpcodeClose {
if payloadLen == 0 {
return nil
}
if payloadLen == 1 {
return ErrClosePayloadInvalid
}

code := binary.BigEndian.Uint16(frame.Payload[:2])
code := binary.BigEndian.Uint16(frame.Payload()[:2])
if code < 1000 || code >= 5000 {
return ErrCloseStatusInvalid
}
if reservedStatusCodes[code] {
return ErrCloseStatusReserved
}
if payloadLen > 2 && !utf8.Valid(frame.Payload[2:]) {
if payloadLen > 2 && !utf8.Valid(frame.Payload()[2:]) {
return ErrEncodingInvalid
}
}
Expand Down
Loading
Loading