Skip to content

Commit ef31006

Browse files
authored
refactor: store RSV bits more efficiently (#38)
Storing a single byte rather than three boolean values should be more efficient, especially since (for now) we only need to test whether _any_ RSV bit is set, and now that can happen in a single comparison.
1 parent a558272 commit ef31006

File tree

4 files changed

+211
-28
lines changed

4 files changed

+211
-28
lines changed

internal/testing/assert/assert.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func NilError(t testing.TB, err error) {
4949
func Error(t testing.TB, got, expected error) {
5050
t.Helper()
5151
if !errorsMatch(t, got, expected) {
52-
t.Fatalf("expected error %q, got %v (%T vs %T)", expected, got, expected, got)
52+
t.Fatalf("expected error %q, got %q (%T vs %T)", expected, got, expected, got)
5353
}
5454
}
5555

proto.go

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,31 @@ const (
8282
StatusServerError StatusCode = 1011
8383
)
8484

85+
// RSV bit masks
86+
const (
87+
rsv1mask = 0b100
88+
rsv2mask = 0b010
89+
rsv3mask = 0b001
90+
)
91+
8592
// Frame is a websocket protocol frame.
8693
type Frame struct {
8794
Fin bool
88-
RSV1 bool
89-
RSV3 bool
90-
RSV2 bool
95+
RSV byte // Bits 0-2 represent RSV1-3
9196
Opcode Opcode
9297
Payload []byte
9398
Masked bool
9499
}
95100

101+
// RSV1 returns true if the frame's RSV1 bit is set
102+
func (f *Frame) RSV1() bool { return f.RSV&rsv1mask != 0 }
103+
104+
// RSV2 returns true if the frame's RSV2 bit is set
105+
func (f *Frame) RSV2() bool { return f.RSV&rsv2mask != 0 }
106+
107+
// RSV3 returns true if the frame's RSV3 bit is set
108+
func (f *Frame) RSV3() bool { return f.RSV&rsv3mask != 0 }
109+
96110
func (f Frame) String() string {
97111
return fmt.Sprintf("Frame{Fin: %v, Opcode: %v, Payload: %s}", f.Fin, f.Opcode, formatPayload(f.Payload))
98112
}
@@ -127,18 +141,16 @@ func ReadFrame(buf io.Reader, maxPayloadLen int) (*Frame, error) {
127141
// parse first header byte
128142
var (
129143
b0 = bb[0]
130-
fin = b0&0b10000000 != 0
131-
rsv1 = b0&0b01000000 != 0
132-
rsv2 = b0&0b00100000 != 0
133-
rsv3 = b0&0b00010000 != 0
134-
opcode = Opcode(b0 & 0b00001111)
144+
fin = b0&0b1000_0000 != 0
145+
rsv = (b0 >> 4) & 0b0111
146+
opcode = Opcode(b0 & 0b0000_1111)
135147
)
136148

137149
// parse second header byte
138150
var (
139151
b1 = bb[1]
140-
masked = b1&0b10000000 != 0
141-
payloadLen = uint64(b1 & 0b01111111)
152+
masked = b1&0b1000_0000 != 0
153+
payloadLen = uint64(b1 & 0b0111_1111)
142154
)
143155

144156
// Payload lengths 0 to 125 encoded directly in last 7 bits of payload
@@ -186,9 +198,7 @@ func ReadFrame(buf io.Reader, maxPayloadLen int) (*Frame, error) {
186198

187199
return &Frame{
188200
Fin: fin,
189-
RSV1: rsv1,
190-
RSV2: rsv2,
191-
RSV3: rsv3,
201+
RSV: rsv,
192202
Opcode: opcode,
193203
Payload: payload,
194204
Masked: masked,
@@ -214,24 +224,16 @@ func MarshalFrame(frame *Frame, mask MaskingKey) []byte {
214224
// FIN, RSV1-3, OPCODE
215225
var b0 byte
216226
if frame.Fin {
217-
b0 |= 0b10000000
218-
}
219-
if frame.RSV1 {
220-
b0 |= 0b01000000
221-
}
222-
if frame.RSV2 {
223-
b0 |= 0b00100000
224-
}
225-
if frame.RSV3 {
226-
b0 |= 0b00010000
227+
b0 |= 0b1000_0000
227228
}
228-
b0 |= uint8(frame.Opcode) & 0b00001111
229+
b0 |= (frame.RSV & 0b111) << 4
230+
b0 |= uint8(frame.Opcode) & 0b0000_1111
229231
buf = append(buf, b0)
230232

231233
// Masked bit, payload length
232234
var b1 byte
233235
if masked {
234-
b1 |= 0b10000000
236+
b1 |= 0b1000_0000
235237
}
236238

237239
payloadLen := int64(len(frame.Payload))
@@ -328,7 +330,7 @@ var reservedStatusCodes = map[uint16]bool{
328330
func validateFrame(frame *Frame, mode Mode) error {
329331
// We do not support any extensions, per the spec all RSV bits must be 0:
330332
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
331-
if frame.RSV1 || frame.RSV2 || frame.RSV3 {
333+
if frame.RSV != 0 {
332334
return ErrRSVBitsUnsupported
333335
}
334336

proto_test.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package websocket_test
22

33
import (
44
"bytes"
5+
"errors"
56
"testing"
67

78
"github.com/mccutchen/websocket"
@@ -51,3 +52,183 @@ func TestMaxFrameSize(t *testing.T) {
5152
assert.Error(t, err, websocket.ErrFrameTooLarge)
5253
assert.Equal(t, serverFrame, nil, "expected nil frame on error")
5354
}
55+
56+
func TestRSV(t *testing.T) {
57+
// We don't currently support any extensions, so RSV bits are not allowed.
58+
// But we still need to be able properly parse and marshal them.
59+
const (
60+
finBit = 0b1000_0000
61+
txtOpcodeBit = 0b0000_0001
62+
rsv1bit = 0b0100_0000
63+
rsv2bit = 0b0010_0000
64+
rsv3bit = 0b0001_0000
65+
)
66+
testCases := map[string]struct {
67+
rawBytes []byte
68+
wantRSV1 bool
69+
wantRSV2 bool
70+
wantRSV3 bool
71+
}{
72+
"no RSV bits set": {
73+
rawBytes: []byte{0x81, 0x00},
74+
},
75+
"RSV1 set": {
76+
rawBytes: []byte{finBit | rsv1bit | txtOpcodeBit, 0x00},
77+
wantRSV1: true,
78+
},
79+
"RSV2 set": {
80+
rawBytes: []byte{finBit | rsv2bit | txtOpcodeBit, 0x00},
81+
wantRSV2: true,
82+
},
83+
"RSV3 set": {
84+
rawBytes: []byte{finBit | rsv3bit | txtOpcodeBit, 0x00},
85+
wantRSV3: true,
86+
},
87+
"all RSV bits set": {
88+
rawBytes: []byte{finBit | rsv1bit | rsv2bit | rsv3bit | txtOpcodeBit, 0x00},
89+
wantRSV1: true,
90+
wantRSV2: true,
91+
wantRSV3: true,
92+
},
93+
}
94+
for name, tc := range testCases {
95+
tc := tc
96+
t.Run(name, func(t *testing.T) {
97+
t.Parallel()
98+
buf := bytes.NewReader(tc.rawBytes)
99+
frame := mustReadFrame(t, buf, len(tc.rawBytes))
100+
assert.Equal(t, frame.RSV1(), tc.wantRSV1, "incorrect RSV1")
101+
assert.Equal(t, frame.RSV2(), tc.wantRSV2, "incorrect RSV2")
102+
assert.Equal(t, frame.RSV3(), tc.wantRSV3, "incorrect RSV3")
103+
})
104+
}
105+
}
106+
107+
func TestExampleFramesFromRFC(t *testing.T) {
108+
// This tests every example provided in RFC 6455 section 5.7:
109+
// https://datatracker.ietf.org/doc/html/rfc6455#section-5.7
110+
testCases := map[string]struct {
111+
rawBytes []byte
112+
wantFrame *websocket.Frame
113+
}{
114+
"single-frame unmasked text": {
115+
rawBytes: []byte{0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
116+
wantFrame: &websocket.Frame{
117+
Fin: true,
118+
Opcode: websocket.OpcodeText,
119+
Payload: []byte("Hello"),
120+
Masked: false,
121+
},
122+
},
123+
"single-frame masked text": {
124+
rawBytes: []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58},
125+
wantFrame: &websocket.Frame{
126+
Fin: true,
127+
Opcode: websocket.OpcodeText,
128+
Payload: []byte("Hello"),
129+
Masked: true,
130+
},
131+
},
132+
"fragmented unmasked text part 1": {
133+
rawBytes: []byte{0x01, 0x03, 0x48, 0x65, 0x6c},
134+
wantFrame: &websocket.Frame{
135+
Fin: false,
136+
Opcode: websocket.OpcodeText,
137+
Payload: []byte("Hel"),
138+
},
139+
},
140+
"fragmented unmasked text part 2": {
141+
rawBytes: []byte{0x80, 0x02, 0x6c, 0x6f},
142+
wantFrame: &websocket.Frame{
143+
Fin: true,
144+
Opcode: websocket.OpcodeContinuation,
145+
Payload: []byte("lo"),
146+
},
147+
},
148+
"unmasked ping": {
149+
rawBytes: []byte{
150+
0x89, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f,
151+
},
152+
wantFrame: &websocket.Frame{
153+
Fin: true,
154+
Opcode: websocket.OpcodePing,
155+
Payload: []byte("Hello"),
156+
},
157+
},
158+
"masked ping response": {
159+
rawBytes: []byte{0x8a, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58},
160+
wantFrame: &websocket.Frame{
161+
Fin: true,
162+
Opcode: websocket.OpcodePong,
163+
Payload: []byte("Hello"),
164+
Masked: true,
165+
},
166+
},
167+
"256 bytes binary message": {
168+
rawBytes: append(
169+
[]byte{0x82, 0x7E, 0x01, 0x00},
170+
make([]byte, 256)...,
171+
),
172+
wantFrame: &websocket.Frame{
173+
Fin: true,
174+
Opcode: websocket.OpcodeBinary,
175+
Payload: make([]byte, 256),
176+
},
177+
},
178+
"64KiB binary message": {
179+
rawBytes: append(
180+
[]byte{0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00},
181+
make([]byte, 65536)...,
182+
),
183+
wantFrame: &websocket.Frame{
184+
Fin: true,
185+
Opcode: websocket.OpcodeBinary,
186+
Payload: make([]byte, 65536),
187+
},
188+
},
189+
}
190+
191+
for name, tc := range testCases {
192+
tc := tc
193+
t.Run(name, func(t *testing.T) {
194+
t.Parallel()
195+
buf := bytes.NewReader(tc.rawBytes)
196+
got := mustReadFrame(t, buf, len(tc.rawBytes))
197+
assert.DeepEqual(t, got, tc.wantFrame, "frames do not match")
198+
})
199+
}
200+
}
201+
202+
func TestIncompleteFrames(t *testing.T) {
203+
testCases := map[string]struct {
204+
rawBytes []byte
205+
wantErr error
206+
}{
207+
"2-byte extended payload can't be read": {
208+
rawBytes: []byte{0x82, 0x7E},
209+
wantErr: errors.New("error reading 2-byte extended payload length: EOF"),
210+
},
211+
"8-byte extended payload can't be read": {
212+
rawBytes: []byte{0x82, 0x7F},
213+
wantErr: errors.New("error reading 8-byte extended payload length: EOF"),
214+
},
215+
"mask can't be read": {
216+
rawBytes: []byte{0x81, 0x85},
217+
wantErr: errors.New("error reading mask key: EOF"),
218+
},
219+
"payload can't be read": {
220+
rawBytes: []byte{0x81, 0x05},
221+
wantErr: errors.New("error reading 5 byte payload: EOF"),
222+
},
223+
}
224+
225+
for name, tc := range testCases {
226+
tc := tc
227+
t.Run(name, func(t *testing.T) {
228+
t.Parallel()
229+
buf := bytes.NewReader(tc.rawBytes)
230+
_, err := websocket.ReadFrame(buf, 70000)
231+
assert.Error(t, err, tc.wantErr)
232+
})
233+
}
234+
}

websocket_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ func TestProtocolErrors(t *testing.T) {
618618
frames: []*websocket.Frame{
619619
{
620620
Opcode: websocket.OpcodeText,
621-
RSV1: true,
621+
RSV: 0b100,
622622
Fin: true,
623623
Payload: []byte("hello"),
624624
},

0 commit comments

Comments
 (0)