Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Default flags used by the test, testci, testcover targets
COVERAGE_PATH ?= coverage.out
COVERAGE_ARGS ?= -covermode=atomic -coverprofile=$(COVERAGE_PATH)
TEST_ARGS ?= -race -count=1 -timeout=10s
TEST_ARGS ?= -race -count=1 -timeout=15s
AUTOBAHN_ARGS ?= -race -count=1 -timeout=120s
BENCH_COUNT ?= 10
BENCH_ARGS ?= -bench=. -benchmem -count=$(BENCH_COUNT) -run=^$$
Expand Down
36 changes: 14 additions & 22 deletions proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,64 +283,56 @@ func ReadFrame(src io.Reader, mode Mode, maxPayloadLen int) (*Frame, error) {
// WriteFrame writes a [Frame] to dst with the given masking key. To write an
// unmasked frame, use the special [Unmasked] key.
func WriteFrame(dst io.Writer, mask MaskingKey, frame *Frame) error {
_, err := dst.Write(MarshalFrame(frame, mask))
if err != nil {
if _, err := dst.Write(marshalFrame(frame, mask)); err != nil {
return newError(StatusAbnormalClose, "error writing frame: %w", err)
}
return nil
}

// MarshalFrame marshals a [Frame] into bytes for transmission.
func MarshalFrame(frame *Frame, mask MaskingKey) []byte {
// marshalFrame marshals a [Frame] into bytes for transmission.
func marshalFrame(frame *Frame, mask MaskingKey) []byte {
var (
payloadLen = len(frame.Payload)
payloadOffset = 2 // at least 2 bytes will be taken by header
buf = make([]byte, marshaledSize(payloadLen, mask))
masked = mask != Unmasked
)

// Right-size buffer with initial capacity of 2 because we will always write
// two header bytes.
buf := make([]byte, 2, marshaledSize(payloadLen, mask))

// First header byte can be written directly
buf[0] = frame.header

// Second header byte depends on mask and payload size
masked := mask != Unmasked
// Second header byte encodes masked bit and payload size, additional
// header bytes may be written for extended payload sizes.
if masked {
buf[1] |= 0b1000_0000
}

switch {
case payloadLen <= 125:
buf[1] |= byte(payloadLen)
case payloadLen <= 65535:
buf[1] |= 126
buf = binary.BigEndian.AppendUint16(buf, uint16(payloadLen))
binary.BigEndian.PutUint16(buf[payloadOffset:], uint16(payloadLen))
payloadOffset += 2
default:
buf[1] |= 127
buf = binary.BigEndian.AppendUint64(buf, uint64(payloadLen))
binary.BigEndian.PutUint64(buf[payloadOffset:], uint64(payloadLen))
payloadOffset += 8
}

// Optional masking key and actual payload
//
// Note that we manually extend capacity of buffer as necessary to enable
// use of `copy()` instead of `append()`
if masked {
buf = buf[:payloadOffset+payloadLen+4]
copy(buf[payloadOffset:payloadOffset+4], mask[:])
copy(buf[payloadOffset:], mask[:])
payloadOffset += 4
copy(buf[payloadOffset:payloadOffset+payloadLen], frame.Payload)
copy(buf[payloadOffset:], frame.Payload)
applyMask(buf[payloadOffset:payloadOffset+payloadLen], mask)
} else {
buf = buf[:payloadOffset+payloadLen]
copy(buf[payloadOffset:payloadOffset+payloadLen], frame.Payload)
copy(buf[payloadOffset:], frame.Payload)
}
return buf
}

// marshaledSize returns the number of bytes required to marshal a frame.
// marshaledSize returns the number of bytes required to marshal a
// [Frame] with the given payload length and [MaskingKey].
func marshaledSize(payloadLen int, mask MaskingKey) int {
size := 2 + payloadLen
switch {
Expand Down
4 changes: 3 additions & 1 deletion proto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ func TestRSV(t *testing.T) {
// We don't currently support any extensions, so RSV bits are not allowed.
// But we still need to be able properly parse and marshal them.
marshalledFrame := func(rsvBits ...websocket.RSVBit) []byte {
buf := &bytes.Buffer{}
frame := websocket.NewFrame(websocket.OpcodeText, true, nil, rsvBits...)
return websocket.MarshalFrame(frame, websocket.Unmasked)
assert.NilError(t, websocket.WriteFrame(buf, websocket.Unmasked, frame))
return buf.Bytes()
}

testCases := map[string]struct {
Expand Down
Loading