Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,22 @@ testautobahn:
AUTOBAHN_TESTS=1 AUTOBAHN_OPEN_REPORT=1 go test -run ^TestWebSocketServer$$ $(TEST_ARGS) ./...
.PHONY: autobahntests


bench:
go test -bench=. -benchmem
.PHONY: bench

lint:
test -z "$$($(CMD_GOFUMPT) -d -e .)" || (echo "Error: gofmt failed"; gofmt -d -e . ; exit 1)
test -z "$$($(CMD_GOFUMPT) -d -e .)" || (echo "Error: gofmt failed"; $(CMD_GOFUMPT) -d -e . ; exit 1)
go vet ./...
$(CMD_REVIVE) -set_exit_status ./...
$(CMD_STATICCHECK) ./...
.PHONY: lint

fmt:
$(CMD_GOFUMPT) -d -e -w .
.PHONY: fmt

clean:
rm -rf $(OUT_DIR) $(COVERAGE_PATH)
.PHONY: clean
Expand Down
28 changes: 14 additions & 14 deletions internal/testing/assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

// Equal asserts that two values are equal.
func Equal[T comparable](t *testing.T, got, want T, msg string, arg ...any) {
func Equal[T comparable](t testing.TB, got, want T, msg string, arg ...any) {
t.Helper()
if got != want {
if msg == "" {
Expand All @@ -26,7 +26,7 @@ func Equal[T comparable](t *testing.T, got, want T, msg string, arg ...any) {
}

// DeepEqual asserts that two values are deeply equal.
func DeepEqual[T any](t *testing.T, got, want T, msg string, arg ...any) {
func DeepEqual[T any](t testing.TB, got, want T, msg string, arg ...any) {
t.Helper()
if !reflect.DeepEqual(got, want) {
if msg == "" {
Expand All @@ -38,22 +38,22 @@ func DeepEqual[T any](t *testing.T, got, want T, msg string, arg ...any) {
}

// NilError asserts that an error is nil.
func NilError(t *testing.T, err error) {
func NilError(t testing.TB, err error) {
t.Helper()
if err != nil {
t.Fatalf("expected nil error, got %q (%T)", err, err)
}
}

// Error asserts that an error is not nil.
func Error(t *testing.T, got, expected error) {
func Error(t testing.TB, got, expected error) {
t.Helper()
if !errorsMatch(t, got, expected) {
t.Fatalf("expected error %q, got %v (%T vs %T)", expected, got, expected, got)
}
}

func errorsMatch(t *testing.T, got, expected error) bool {
func errorsMatch(t testing.TB, got, expected error) bool {
t.Helper()
switch {
case got == expected:
Expand All @@ -68,7 +68,7 @@ func errorsMatch(t *testing.T, got, expected error) bool {
}

// StatusCode asserts that a response has a specific status code.
func StatusCode(t *testing.T, resp *http.Response, code int) {
func StatusCode(t testing.TB, resp *http.Response, code int) {
t.Helper()
if resp.StatusCode != code {
t.Fatalf("expected status code %d, got %d", code, resp.StatusCode)
Expand All @@ -87,7 +87,7 @@ func isSafeContentType(ct string) bool {
}

// Header asserts that a header key has a specific value in a response.
func Header(t *testing.T, resp *http.Response, key, want string) {
func Header(t testing.TB, resp *http.Response, key, want string) {
t.Helper()
got := resp.Header.Get(key)
if want != got {
Expand All @@ -97,42 +97,42 @@ func Header(t *testing.T, resp *http.Response, key, want string) {

// ContentType asserts that a response has a specific Content-Type header
// value.
func ContentType(t *testing.T, resp *http.Response, contentType string) {
func ContentType(t testing.TB, resp *http.Response, contentType string) {
t.Helper()
Header(t, resp, "Content-Type", contentType)
}

// Contains asserts that needle is found in the given string.
func Contains(t *testing.T, s string, needle string, description string) {
func Contains(t testing.TB, s string, needle string, description string) {
t.Helper()
if !strings.Contains(s, needle) {
t.Fatalf("expected string %q in %s %q", needle, description, s)
}
}

// BodyContains asserts that a response body contains a specific substring.
func BodyContains(t *testing.T, resp *http.Response, needle string) {
func BodyContains(t testing.TB, resp *http.Response, needle string) {
t.Helper()
body := must.ReadAll(t, resp.Body)
Contains(t, body, needle, "body")
}

// BodyEquals asserts that a response body is equal to a specific string.
func BodyEquals(t *testing.T, resp *http.Response, want string) {
func BodyEquals(t testing.TB, resp *http.Response, want string) {
t.Helper()
got := must.ReadAll(t, resp.Body)
Equal(t, got, want, "incorrect response body")
}

// BodySize asserts that a response body is a specific size.
func BodySize(t *testing.T, resp *http.Response, want int) {
func BodySize(t testing.TB, resp *http.Response, want int) {
t.Helper()
got := must.ReadAll(t, resp.Body)
Equal(t, len(got), want, "incorrect response body size")
}

// DurationRange asserts that a duration is within a specific range.
func DurationRange(t *testing.T, got, minVal, maxVal time.Duration) {
func DurationRange(t testing.TB, got, minVal, maxVal time.Duration) {
t.Helper()
if got < minVal || got > maxVal {
t.Fatalf("expected duration between %s and %s, got %s", minVal, maxVal, got)
Expand All @@ -144,7 +144,7 @@ type number interface {
}

// RoughlyEqual asserts that a numeric value is within a certain tolerance.
func RoughlyEqual[T number](t *testing.T, got, want T, epsilon T) {
func RoughlyEqual[T number](t testing.TB, got, want T, epsilon T) {
t.Helper()
if got < want-epsilon || got > want+epsilon {
t.Fatalf("expected value between %v and %v, got %v", want-epsilon, want+epsilon, got)
Expand Down
6 changes: 3 additions & 3 deletions internal/testing/must/must.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

// DoReq makes an HTTP request and fails the test if there is an error.
func DoReq(t *testing.T, client *http.Client, req *http.Request) *http.Response {
func DoReq(t testing.TB, client *http.Client, req *http.Request) *http.Response {
t.Helper()
start := time.Now()
resp, err := client.Do(req)
Expand All @@ -24,7 +24,7 @@ func DoReq(t *testing.T, client *http.Client, req *http.Request) *http.Response

// ReadAll reads all bytes from an io.Reader and fails the test if there is an
// error.
func ReadAll(t *testing.T, r io.Reader) string {
func ReadAll(t testing.TB, r io.Reader) string {
t.Helper()
body, err := io.ReadAll(r)
if err != nil {
Expand All @@ -38,7 +38,7 @@ func ReadAll(t *testing.T, r io.Reader) string {

// Unmarshal unmarshals JSON from an io.Reader into a value and fails the test
// if there is an error.
func Unmarshal[T any](t *testing.T, r io.Reader) T {
func Unmarshal[T any](t testing.TB, r io.Reader) T {
t.Helper()
var v T
if err := json.NewDecoder(r).Decode(&v); err != nil {
Expand Down
125 changes: 125 additions & 0 deletions websocket_benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package websocket_test

import (
"bytes"
"strconv"
"testing"

"github.com/mccutchen/websocket"
"github.com/mccutchen/websocket/internal/testing/assert"
)

func makeFrame(opcode websocket.Opcode, fin bool, payloadLen int) *websocket.Frame {
payload := make([]byte, payloadLen)
for i := range payload {
payload[i] = 0x20 + byte(i%95) // Map to range 0x20 (space) to 0x7E (~)
}

return &websocket.Frame{
Opcode: opcode,
Fin: fin,
Payload: payload,
}
}

func BenchmarkReadFrame(b *testing.B) {
frameSizes := []int{
64,
256,
1024,
64 * 1024,
1024 * 1024,
}

for _, size := range frameSizes {
frame := makeFrame(websocket.OpcodeText, true, size)
mask := [4]byte{1, 2, 3, 4}

buf := &bytes.Buffer{}
assert.NilError(b, websocket.WriteFrameMasked(buf, frame, mask))

// Run sub-benchmarks for each payload size
b.Run(strconv.Itoa(size)+"b", func(b *testing.B) {
src := bytes.NewReader(buf.Bytes())
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = src.Seek(0, 0)
_, err := websocket.ReadFrame(src)
if err != nil {
b.Fatalf("unexpected error: %v", err)
}
}
})
}
}

/*

TODO: benchmark reading entire message after refactoring

func BenchmarkReadMessage(b *testing.B) {
frameSizes := []int{
64,
// 256,
// 1024,
// 64 * 1024,
// 1024 * 1024,
}

messageSizes := []int{
512,
// 1024,
// 256 * 1024,
// 1024 * 1024,
// 2 * 1024 * 1024,
}

echoHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := websocket.Accept(w, r, websocket.Options{
MaxFragmentSize: 1024 * 1024,
MaxMessageSize: 2 * 1024 * 1024,
// Hooks: newTestHooks(b),
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
ws.Serve(r.Context(), websocket.EchoHandler)
})

for _, msgSize := range messageSizes {
for _, frameSize := range frameSizes {
if msgSize%frameSize != 0 {
continue
}
buf := &bytes.Buffer{}
frameCount := msgSize / frameSize
for i := 0; i < frameCount; i++ {
opcode := websocket.OpcodeText
if i > 0 {
opcode = websocket.OpcodeContinuation
}
fin := i == frameCount-1
b.Logf("frame=%d frameCount=%d fin=%v", i, frameCount, fin)
frame := makeFrame(opcode, fin, frameSize)
websocket.WriteFrameMasked(buf, frame, makeMaskingKey())
}

name := fmt.Sprintf("MessageSize=%db/FrameSize=%db/FrameCount=%d", msgSize, frameSize, frameCount)
b.Run(name, func(b *testing.B) {
_, conn := setupRawConn(b, echoHandler)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n, err := conn.Write(buf.Bytes())
assert.NilError(b, err)
assert.Equal(b, n, len(buf.Bytes()), "incorrect number of bytes written")
resp, err := io.ReadAll(conn)
assert.NilError(b, err)
assert.Equal(b, len(resp) >= msgSize, true, "expected to read full message back")
}
})
}
}
}

*/
17 changes: 12 additions & 5 deletions websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,7 @@ func TestProtocolBasics(t *testing.T) {
Fin: true,
Payload: []byte("hello"),
}
mask := [4]byte{1, 2, 3, 4}
assert.NilError(t, websocket.WriteFrameMasked(conn, clientFrame, mask))
assert.NilError(t, websocket.WriteFrameMasked(conn, clientFrame, makeMaskingKey()))
// read server frame
serverFrame, err := websocket.ReadFrame(conn)
assert.NilError(t, err)
Expand All @@ -353,13 +352,12 @@ func TestProtocolBasics(t *testing.T) {
assert.NilError(t, websocket.WriteFrame(conn, frame))
validateCloseFrame(t, conn, websocket.StatusProtocolError, "received unmasked client frame")
})

}

// setupRawConn is a test helpers that runs a test server and does the
// initial websocket handshake. The returned connection is ready for use to
// sent/receive websocket messages.
func setupRawConn(t *testing.T, handler http.Handler) (*httptest.Server, net.Conn) {
func setupRawConn(t testing.TB, handler http.Handler) (*httptest.Server, net.Conn) {
t.Helper()

srv := httptest.NewServer(handler)
Expand Down Expand Up @@ -399,6 +397,15 @@ func makeClientKey() string {
return base64.StdEncoding.EncodeToString(b)
}

func makeMaskingKey() [4]byte {
var key [4]byte
_, err := rand.Read(key[:]) // Fill the key with 4 random bytes
if err != nil {
panic(fmt.Sprintf("failed to read random bytes for masking key: %s", err))
}
return key
}

// validateCloseFrame ensures that we can read a close frame from the given
// reader and optionally ensures that the close frame includes a specific
// status code and message.
Expand Down Expand Up @@ -447,7 +454,7 @@ var (
_ http.Hijacker = &brokenHijackResponseWriter{}
)

func newTestHooks(t *testing.T) websocket.Hooks {
func newTestHooks(t testing.TB) websocket.Hooks {
t.Helper()
return websocket.Hooks{
OnClose: func(key websocket.ClientKey, code websocket.StatusCode, err error) {
Expand Down
Loading