Skip to content

Commit 0dbe072

Browse files
authored
refactor: complete rewrite of the frame type (#39)
With this experimental refactor, a decoded websocket frame is represented in memory as a single header byte and a payload byte slice. The single header byte gives us all the info we need at the application level (opcode, fin bit, RSV bits). Because we unmask the payload at read time and already require callers to specify a mask when writing a frame, we do not need to store the masked bit or the mask key in the frame itself. (This does, however, require passing a mode into ReadFrame so that we can still reject unmasked client frames.) And because correctly constructing the header byte for a frame is tricky, a new `NewFrame` constructor is now the only way to create a frame.
1 parent ef31006 commit 0dbe072

File tree

5 files changed

+196
-344
lines changed

5 files changed

+196
-344
lines changed

proto.go

Lines changed: 115 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ type Opcode uint8
3636
// See the RFC for the set of defined opcodes:
3737
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
3838
const (
39-
OpcodeContinuation Opcode = 0x0
40-
OpcodeText Opcode = 0x1
41-
OpcodeBinary Opcode = 0x2
42-
OpcodeClose Opcode = 0x8
43-
OpcodePing Opcode = 0x9
44-
OpcodePong Opcode = 0xA
39+
OpcodeContinuation Opcode = 0b0000_0000 // 0x0
40+
OpcodeText Opcode = 0b0000_0001 // 0x1
41+
OpcodeBinary Opcode = 0b0000_0010 // 0x2
42+
OpcodeClose Opcode = 0b0000_1000 // 0x8
43+
OpcodePing Opcode = 0b0000_1001 // 0x9
44+
OpcodePong Opcode = 0b0000_1010 // 0xA
4545
)
4646

4747
func (c Opcode) String() string {
@@ -82,33 +82,86 @@ const (
8282
StatusServerError StatusCode = 1011
8383
)
8484

85-
// RSV bit masks
8685
const (
87-
rsv1mask = 0b100
88-
rsv2mask = 0b010
89-
rsv3mask = 0b001
86+
// first header byte
87+
finMask = 0b1000_0000
88+
opcodeMask = 0b0000_1111
89+
90+
// second header byte
91+
maskedMask = 0b1000_0000
92+
payloadLenMask = 0b0111_1111
9093
)
9194

95+
// RSVBit is a bit mask for RSV bits 1-3
96+
type RSVBit byte
97+
98+
// RSV bits 1-3
99+
const (
100+
RSV1 RSVBit = 0b0100_0000
101+
RSV2 RSVBit = 0b0010_0000
102+
RSV3 RSVBit = 0b0001_0000
103+
)
104+
105+
// NewFrame creates a new websocket frame with the given opcode, fin bit, and
106+
// payload.
107+
func NewFrame(opcode Opcode, fin bool, payload []byte, rsv ...RSVBit) *Frame {
108+
f := &Frame{
109+
Payload: payload,
110+
}
111+
// Encode FIN, RSV1-3, and OPCODE.
112+
//
113+
// The second header byte encodes mask bit and payload size, but in-memory
114+
// frames are unmasked and the payload size is directly accessible.
115+
//
116+
// These bits will be encoded as necessary when writing a frame.
117+
if fin {
118+
f.header |= finMask
119+
}
120+
for _, r := range rsv {
121+
f.header |= byte(r)
122+
}
123+
f.header |= uint8(opcode) & opcodeMask
124+
return f
125+
}
126+
127+
// NewCloseFrame creates a close frame with an optional error message.
128+
func NewCloseFrame(code StatusCode, reason string) *Frame {
129+
var payload []byte
130+
if code > 0 {
131+
payload = make([]byte, 0, 2+len(reason))
132+
payload = binary.BigEndian.AppendUint16(payload, uint16(code))
133+
payload = append(payload, []byte(reason)...)
134+
}
135+
return NewFrame(OpcodeClose, true, payload)
136+
}
137+
92138
// Frame is a websocket protocol frame.
93139
type Frame struct {
94-
Fin bool
95-
RSV byte // Bits 0-2 represent RSV1-3
96-
Opcode Opcode
140+
header byte
97141
Payload []byte
98-
Masked bool
142+
}
143+
144+
// Fin returns a bool indicating whether the frame's FIN bit is set.
145+
func (f *Frame) Fin() bool {
146+
return f.header&finMask != 0
147+
}
148+
149+
// Opcode returns the the frame's OPCODE.
150+
func (f *Frame) Opcode() Opcode {
151+
return Opcode(f.header & opcodeMask)
99152
}
100153

101154
// RSV1 returns true if the frame's RSV1 bit is set
102-
func (f *Frame) RSV1() bool { return f.RSV&rsv1mask != 0 }
155+
func (f *Frame) RSV1() bool { return f.header&byte(RSV1) != 0 }
103156

104157
// RSV2 returns true if the frame's RSV2 bit is set
105-
func (f *Frame) RSV2() bool { return f.RSV&rsv2mask != 0 }
158+
func (f *Frame) RSV2() bool { return f.header&byte(RSV2) != 0 }
106159

107160
// RSV3 returns true if the frame's RSV3 bit is set
108-
func (f *Frame) RSV3() bool { return f.RSV&rsv3mask != 0 }
161+
func (f *Frame) RSV3() bool { return f.header&byte(RSV3) != 0 }
109162

110163
func (f Frame) String() string {
111-
return fmt.Sprintf("Frame{Fin: %v, Opcode: %v, Payload: %s}", f.Fin, f.Opcode, formatPayload(f.Payload))
164+
return fmt.Sprintf("Frame{Fin: %v, Opcode: %v, Payload: %s}", f.Fin(), f.Opcode(), formatPayload(f.Payload))
112165
}
113166

114167
// Message is an application-level message from the client, which may be
@@ -132,26 +185,23 @@ var formatPayload = func(p []byte) string {
132185
}
133186

134187
// ReadFrame reads a websocket frame from the wire.
135-
func ReadFrame(buf io.Reader, maxPayloadLen int) (*Frame, error) {
136-
bb := make([]byte, 2)
137-
if _, err := io.ReadFull(buf, bb); err != nil {
188+
func ReadFrame(buf io.Reader, mode Mode, maxPayloadLen int) (*Frame, error) {
189+
header := make([]byte, 2)
190+
if _, err := io.ReadFull(buf, header); err != nil {
138191
return nil, fmt.Errorf("error reading frame header: %w", err)
139192
}
140193

141-
// parse first header byte
194+
// figure out how to parse payload
142195
var (
143-
b0 = bb[0]
144-
fin = b0&0b1000_0000 != 0
145-
rsv = (b0 >> 4) & 0b0111
146-
opcode = Opcode(b0 & 0b0000_1111)
196+
masked = header[1]&maskedMask != 0
197+
payloadLen = uint64(header[1] & payloadLenMask)
147198
)
148199

149-
// parse second header byte
150-
var (
151-
b1 = bb[1]
152-
masked = b1&0b1000_0000 != 0
153-
payloadLen = uint64(b1 & 0b0111_1111)
154-
)
200+
// If the data is being sent by the client, the frame(s) MUST be masked
201+
// https://datatracker.ietf.org/doc/html/rfc6455#section-6.1
202+
if mode == ServerMode && !masked {
203+
return nil, ErrClientFrameUnmasked
204+
}
155205

156206
// Payload lengths 0 to 125 encoded directly in last 7 bits of payload
157207
// field, but we may need to read an "extended" payload length.
@@ -195,13 +245,9 @@ func ReadFrame(buf io.Reader, maxPayloadLen int) (*Frame, error) {
195245
payload[i] = b ^ mask[i%4]
196246
}
197247
}
198-
199248
return &Frame{
200-
Fin: fin,
201-
RSV: rsv,
202-
Opcode: opcode,
249+
header: header[0],
203250
Payload: payload,
204-
Masked: masked,
205251
}, nil
206252
}
207253

@@ -218,17 +264,10 @@ func WriteFrame(dst io.Writer, mask MaskingKey, frame *Frame) error {
218264
func MarshalFrame(frame *Frame, mask MaskingKey) []byte {
219265
// worst case payload size is 13 header bytes + payload size, where 13 is
220266
// (1 byte header) + (1-8 byte length) + (0-4 byte mask key)
221-
buf := make([]byte, 0, 13+len(frame.Payload))
267+
buf := make([]byte, 0, marshaledSize(frame, mask))
222268
masked := mask != Unmasked
223269

224-
// FIN, RSV1-3, OPCODE
225-
var b0 byte
226-
if frame.Fin {
227-
b0 |= 0b1000_0000
228-
}
229-
b0 |= (frame.RSV & 0b111) << 4
230-
b0 |= uint8(frame.Opcode) & 0b0000_1111
231-
buf = append(buf, b0)
270+
buf = append(buf, frame.header)
232271

233272
// Masked bit, payload length
234273
var b1 byte
@@ -260,24 +299,25 @@ func MarshalFrame(frame *Frame, mask MaskingKey) []byte {
260299
return buf
261300
}
262301

263-
// NewCloseFrame creates a close frame with an optional error message.
264-
func NewCloseFrame(code StatusCode, reason string) *Frame {
265-
var payload []byte
266-
if code > 0 {
267-
payload = make([]byte, 0, 2+len(reason))
268-
payload = binary.BigEndian.AppendUint16(payload, uint16(code))
269-
payload = append(payload, []byte(reason)...)
302+
// marshaledSize returns the number of bytes required to marshal a frame.
303+
func marshaledSize(f *Frame, mask MaskingKey) int {
304+
payloadLen := len(f.Payload)
305+
size := 2 + payloadLen
306+
switch {
307+
case payloadLen >= 64<<10:
308+
size += 8
309+
case payloadLen > 125:
310+
size += 2
270311
}
271-
return &Frame{
272-
Fin: true,
273-
Opcode: OpcodeClose,
274-
Payload: payload,
312+
if mask != Unmasked {
313+
size += 4
275314
}
315+
return size
276316
}
277317

278-
// messageFrames splits a message into N frames with payloads of at most
318+
// FrameMessage splits a message into N frames with payloads of at most
279319
// frameSize bytes.
280-
func messageFrames(msg *Message, frameSize int) []*Frame {
320+
func FrameMessage(msg *Message, frameSize int) []*Frame {
281321
var result []*Frame
282322

283323
fin := false
@@ -297,11 +337,7 @@ func messageFrames(msg *Message, frameSize int) []*Frame {
297337
fin = true
298338
end = dataLen
299339
}
300-
result = append(result, &Frame{
301-
Fin: fin,
302-
Opcode: opcode,
303-
Payload: msg.Payload[offset:end],
304-
})
340+
result = append(result, NewFrame(opcode, fin, msg.Payload[offset:end]))
305341
if fin {
306342
break
307343
}
@@ -327,38 +363,40 @@ var reservedStatusCodes = map[uint16]bool{
327363
2999: true,
328364
}
329365

330-
func validateFrame(frame *Frame, mode Mode) error {
366+
// validateFrame ensures that the frame is valid.
367+
//
368+
// TODO: validate in ReadFrame instead of ReadMessage.
369+
func validateFrame(frame *Frame) error {
331370
// We do not support any extensions, per the spec all RSV bits must be 0:
332371
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
333-
if frame.RSV != 0 {
372+
if frame.header&uint8(RSV1|RSV2|RSV3) != 0 {
334373
return ErrRSVBitsUnsupported
335374
}
336375

337-
// If the data is being sent by the client, the frame(s) MUST be masked
338-
// https://datatracker.ietf.org/doc/html/rfc6455#section-6.1
339-
if mode == ServerMode && !frame.Masked {
340-
return ErrClientFrameUnmasked
341-
}
342-
343-
payloadLen := len(frame.Payload)
344-
345-
switch frame.Opcode {
376+
var (
377+
opcode = frame.Opcode()
378+
payloadLen = len(frame.Payload)
379+
)
380+
switch opcode {
346381
case OpcodeClose, OpcodePing, OpcodePong:
347382
// All control frames MUST have a payload length of 125 bytes or less
348383
// and MUST NOT be fragmented.
349384
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.5
350385
if payloadLen > 125 {
351386
return ErrControlFrameTooLarge
352387
}
353-
if !frame.Fin {
388+
if !frame.Fin() {
354389
return ErrControlFrameFragmented
355390
}
356391
}
357392

358-
if frame.Opcode == OpcodeClose {
393+
if opcode == OpcodeClose {
359394
if payloadLen == 0 {
360395
return nil
361396
}
397+
// if a close frame has a payload, the first two bytes MUST encode a
398+
// closing status code.
399+
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
362400
if payloadLen == 1 {
363401
return ErrClosePayloadInvalid
364402
}

0 commit comments

Comments
 (0)