From 89da0cc5719bc02a214b068c0c1c9f04f8609d46 Mon Sep 17 00:00:00 2001 From: Will McCutchen Date: Sat, 28 Dec 2024 11:34:07 -0500 Subject: [PATCH 1/2] perf: add initial benchmark More thorough benchmarks will likely require additional refactoring. --- Makefile | 11 ++- internal/testing/assert/assert.go | 28 +++---- internal/testing/must/must.go | 6 +- websocket_benchmark_test.go | 124 ++++++++++++++++++++++++++++++ websocket_test.go | 17 ++-- 5 files changed, 163 insertions(+), 23 deletions(-) create mode 100644 websocket_benchmark_test.go diff --git a/Makefile b/Makefile index 8fe106c..211f11c 100644 --- a/Makefile +++ b/Makefile @@ -50,13 +50,22 @@ testautobahn: AUTOBAHN_TESTS=1 AUTOBAHN_OPEN_REPORT=1 go test -run ^TestWebSocketServer$$ $(TEST_ARGS) ./... .PHONY: autobahntests + +bench: + go test -bench=. -benchmem +.PHONY: bench + lint: - test -z "$$($(CMD_GOFUMPT) -d -e .)" || (echo "Error: gofmt failed"; gofmt -d -e . ; exit 1) + test -z "$$($(CMD_GOFUMPT) -d -e .)" || (echo "Error: gofmt failed"; $(CMD_GOFUMPT) -d -e . ; exit 1) go vet ./... $(CMD_REVIVE) -set_exit_status ./... $(CMD_STATICCHECK) ./... .PHONY: lint +fmt: + $(CMD_GOFUMPT) -d -e -w . +.PHONY: fmt + clean: rm -rf $(OUT_DIR) $(COVERAGE_PATH) .PHONY: clean diff --git a/internal/testing/assert/assert.go b/internal/testing/assert/assert.go index 336e0d3..6b8556a 100644 --- a/internal/testing/assert/assert.go +++ b/internal/testing/assert/assert.go @@ -14,7 +14,7 @@ import ( ) // Equal asserts that two values are equal. -func Equal[T comparable](t *testing.T, got, want T, msg string, arg ...any) { +func Equal[T comparable](t testing.TB, got, want T, msg string, arg ...any) { t.Helper() if got != want { if msg == "" { @@ -26,7 +26,7 @@ func Equal[T comparable](t *testing.T, got, want T, msg string, arg ...any) { } // DeepEqual asserts that two values are deeply equal. -func DeepEqual[T any](t *testing.T, got, want T, msg string, arg ...any) { +func DeepEqual[T any](t testing.TB, got, want T, msg string, arg ...any) { t.Helper() if !reflect.DeepEqual(got, want) { if msg == "" { @@ -38,7 +38,7 @@ func DeepEqual[T any](t *testing.T, got, want T, msg string, arg ...any) { } // NilError asserts that an error is nil. -func NilError(t *testing.T, err error) { +func NilError(t testing.TB, err error) { t.Helper() if err != nil { t.Fatalf("expected nil error, got %q (%T)", err, err) @@ -46,14 +46,14 @@ func NilError(t *testing.T, err error) { } // Error asserts that an error is not nil. -func Error(t *testing.T, got, expected error) { +func Error(t testing.TB, got, expected error) { t.Helper() if !errorsMatch(t, got, expected) { t.Fatalf("expected error %q, got %v (%T vs %T)", expected, got, expected, got) } } -func errorsMatch(t *testing.T, got, expected error) bool { +func errorsMatch(t testing.TB, got, expected error) bool { t.Helper() switch { case got == expected: @@ -68,7 +68,7 @@ func errorsMatch(t *testing.T, got, expected error) bool { } // StatusCode asserts that a response has a specific status code. -func StatusCode(t *testing.T, resp *http.Response, code int) { +func StatusCode(t testing.TB, resp *http.Response, code int) { t.Helper() if resp.StatusCode != code { t.Fatalf("expected status code %d, got %d", code, resp.StatusCode) @@ -87,7 +87,7 @@ func isSafeContentType(ct string) bool { } // Header asserts that a header key has a specific value in a response. -func Header(t *testing.T, resp *http.Response, key, want string) { +func Header(t testing.TB, resp *http.Response, key, want string) { t.Helper() got := resp.Header.Get(key) if want != got { @@ -97,13 +97,13 @@ func Header(t *testing.T, resp *http.Response, key, want string) { // ContentType asserts that a response has a specific Content-Type header // value. -func ContentType(t *testing.T, resp *http.Response, contentType string) { +func ContentType(t testing.TB, resp *http.Response, contentType string) { t.Helper() Header(t, resp, "Content-Type", contentType) } // Contains asserts that needle is found in the given string. -func Contains(t *testing.T, s string, needle string, description string) { +func Contains(t testing.TB, s string, needle string, description string) { t.Helper() if !strings.Contains(s, needle) { t.Fatalf("expected string %q in %s %q", needle, description, s) @@ -111,28 +111,28 @@ func Contains(t *testing.T, s string, needle string, description string) { } // BodyContains asserts that a response body contains a specific substring. -func BodyContains(t *testing.T, resp *http.Response, needle string) { +func BodyContains(t testing.TB, resp *http.Response, needle string) { t.Helper() body := must.ReadAll(t, resp.Body) Contains(t, body, needle, "body") } // BodyEquals asserts that a response body is equal to a specific string. -func BodyEquals(t *testing.T, resp *http.Response, want string) { +func BodyEquals(t testing.TB, resp *http.Response, want string) { t.Helper() got := must.ReadAll(t, resp.Body) Equal(t, got, want, "incorrect response body") } // BodySize asserts that a response body is a specific size. -func BodySize(t *testing.T, resp *http.Response, want int) { +func BodySize(t testing.TB, resp *http.Response, want int) { t.Helper() got := must.ReadAll(t, resp.Body) Equal(t, len(got), want, "incorrect response body size") } // DurationRange asserts that a duration is within a specific range. -func DurationRange(t *testing.T, got, minVal, maxVal time.Duration) { +func DurationRange(t testing.TB, got, minVal, maxVal time.Duration) { t.Helper() if got < minVal || got > maxVal { t.Fatalf("expected duration between %s and %s, got %s", minVal, maxVal, got) @@ -144,7 +144,7 @@ type number interface { } // RoughlyEqual asserts that a numeric value is within a certain tolerance. -func RoughlyEqual[T number](t *testing.T, got, want T, epsilon T) { +func RoughlyEqual[T number](t testing.TB, got, want T, epsilon T) { t.Helper() if got < want-epsilon || got > want+epsilon { t.Fatalf("expected value between %v and %v, got %v", want-epsilon, want+epsilon, got) diff --git a/internal/testing/must/must.go b/internal/testing/must/must.go index 3796dd2..2b8940f 100644 --- a/internal/testing/must/must.go +++ b/internal/testing/must/must.go @@ -11,7 +11,7 @@ import ( ) // DoReq makes an HTTP request and fails the test if there is an error. -func DoReq(t *testing.T, client *http.Client, req *http.Request) *http.Response { +func DoReq(t testing.TB, client *http.Client, req *http.Request) *http.Response { t.Helper() start := time.Now() resp, err := client.Do(req) @@ -24,7 +24,7 @@ func DoReq(t *testing.T, client *http.Client, req *http.Request) *http.Response // ReadAll reads all bytes from an io.Reader and fails the test if there is an // error. -func ReadAll(t *testing.T, r io.Reader) string { +func ReadAll(t testing.TB, r io.Reader) string { t.Helper() body, err := io.ReadAll(r) if err != nil { @@ -38,7 +38,7 @@ func ReadAll(t *testing.T, r io.Reader) string { // Unmarshal unmarshals JSON from an io.Reader into a value and fails the test // if there is an error. -func Unmarshal[T any](t *testing.T, r io.Reader) T { +func Unmarshal[T any](t testing.TB, r io.Reader) T { t.Helper() var v T if err := json.NewDecoder(r).Decode(&v); err != nil { diff --git a/websocket_benchmark_test.go b/websocket_benchmark_test.go new file mode 100644 index 0000000..c8d42be --- /dev/null +++ b/websocket_benchmark_test.go @@ -0,0 +1,124 @@ +package websocket_test + +import ( + "bytes" + "strconv" + "testing" + + "github.com/mccutchen/websocket" +) + +func makeFrame(opcode websocket.Opcode, fin bool, payloadLen int) *websocket.Frame { + payload := make([]byte, payloadLen) + for i := range payload { + payload[i] = 0x20 + byte(i%95) // Map to range 0x20 (space) to 0x7E (~) + } + + return &websocket.Frame{ + Opcode: opcode, + Fin: fin, + Payload: payload, + } +} + +func BenchmarkReadFrame(b *testing.B) { + frameSizes := []int{ + 64, + 256, + 1024, + 64 * 1024, + 1024 * 1024, + } + + for _, size := range frameSizes { + frame := makeFrame(websocket.OpcodeText, true, size) + mask := [4]byte{1, 2, 3, 4} + + buf := &bytes.Buffer{} + websocket.WriteFrameMasked(buf, frame, mask) + + // Run sub-benchmarks for each payload size + b.Run(strconv.Itoa(size)+"b", func(b *testing.B) { + src := bytes.NewReader(buf.Bytes()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + src.Seek(0, 0) + _, err := websocket.ReadFrame(src) + if err != nil { + b.Fatalf("unexpected error: %v", err) + } + } + }) + } +} + +/* + +TODO: benchmark reading entire message after refactoring + +func BenchmarkReadMessage(b *testing.B) { + frameSizes := []int{ + 64, + // 256, + // 1024, + // 64 * 1024, + // 1024 * 1024, + } + + messageSizes := []int{ + 512, + // 1024, + // 256 * 1024, + // 1024 * 1024, + // 2 * 1024 * 1024, + } + + echoHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := websocket.Accept(w, r, websocket.Options{ + MaxFragmentSize: 1024 * 1024, + MaxMessageSize: 2 * 1024 * 1024, + // Hooks: newTestHooks(b), + }) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ws.Serve(r.Context(), websocket.EchoHandler) + }) + + for _, msgSize := range messageSizes { + for _, frameSize := range frameSizes { + if msgSize%frameSize != 0 { + continue + } + buf := &bytes.Buffer{} + frameCount := msgSize / frameSize + for i := 0; i < frameCount; i++ { + opcode := websocket.OpcodeText + if i > 0 { + opcode = websocket.OpcodeContinuation + } + fin := i == frameCount-1 + b.Logf("frame=%d frameCount=%d fin=%v", i, frameCount, fin) + frame := makeFrame(opcode, fin, frameSize) + websocket.WriteFrameMasked(buf, frame, makeMaskingKey()) + } + + name := fmt.Sprintf("MessageSize=%db/FrameSize=%db/FrameCount=%d", msgSize, frameSize, frameCount) + b.Run(name, func(b *testing.B) { + _, conn := setupRawConn(b, echoHandler) + b.ResetTimer() + for i := 0; i < b.N; i++ { + n, err := conn.Write(buf.Bytes()) + assert.NilError(b, err) + assert.Equal(b, n, len(buf.Bytes()), "incorrect number of bytes written") + resp, err := io.ReadAll(conn) + assert.NilError(b, err) + assert.Equal(b, len(resp) >= msgSize, true, "expected to read full message back") + } + }) + } + } +} + +*/ diff --git a/websocket_test.go b/websocket_test.go index 2a4b89d..e21ad7b 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -331,8 +331,7 @@ func TestProtocolBasics(t *testing.T) { Fin: true, Payload: []byte("hello"), } - mask := [4]byte{1, 2, 3, 4} - assert.NilError(t, websocket.WriteFrameMasked(conn, clientFrame, mask)) + assert.NilError(t, websocket.WriteFrameMasked(conn, clientFrame, makeMaskingKey())) // read server frame serverFrame, err := websocket.ReadFrame(conn) assert.NilError(t, err) @@ -353,13 +352,12 @@ func TestProtocolBasics(t *testing.T) { assert.NilError(t, websocket.WriteFrame(conn, frame)) validateCloseFrame(t, conn, websocket.StatusProtocolError, "received unmasked client frame") }) - } // setupRawConn is a test helpers that runs a test server and does the // initial websocket handshake. The returned connection is ready for use to // sent/receive websocket messages. -func setupRawConn(t *testing.T, handler http.Handler) (*httptest.Server, net.Conn) { +func setupRawConn(t testing.TB, handler http.Handler) (*httptest.Server, net.Conn) { t.Helper() srv := httptest.NewServer(handler) @@ -399,6 +397,15 @@ func makeClientKey() string { return base64.StdEncoding.EncodeToString(b) } +func makeMaskingKey() [4]byte { + var key [4]byte + _, err := rand.Read(key[:]) // Fill the key with 4 random bytes + if err != nil { + panic(fmt.Sprintf("failed to read random bytes for masking key: %s", err)) + } + return key +} + // validateCloseFrame ensures that we can read a close frame from the given // reader and optionally ensures that the close frame includes a specific // status code and message. @@ -447,7 +454,7 @@ var ( _ http.Hijacker = &brokenHijackResponseWriter{} ) -func newTestHooks(t *testing.T) websocket.Hooks { +func newTestHooks(t testing.TB) websocket.Hooks { t.Helper() return websocket.Hooks{ OnClose: func(key websocket.ClientKey, code websocket.StatusCode, err error) { From 45f9e804ca58dbce8e7194995b0780d7616bcad5 Mon Sep 17 00:00:00 2001 From: Will McCutchen Date: Sun, 29 Dec 2024 01:54:44 -0500 Subject: [PATCH 2/2] appease linter --- websocket_benchmark_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/websocket_benchmark_test.go b/websocket_benchmark_test.go index c8d42be..6992014 100644 --- a/websocket_benchmark_test.go +++ b/websocket_benchmark_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/mccutchen/websocket" + "github.com/mccutchen/websocket/internal/testing/assert" ) func makeFrame(opcode websocket.Opcode, fin bool, payloadLen int) *websocket.Frame { @@ -35,14 +36,14 @@ func BenchmarkReadFrame(b *testing.B) { mask := [4]byte{1, 2, 3, 4} buf := &bytes.Buffer{} - websocket.WriteFrameMasked(buf, frame, mask) + assert.NilError(b, websocket.WriteFrameMasked(buf, frame, mask)) // Run sub-benchmarks for each payload size b.Run(strconv.Itoa(size)+"b", func(b *testing.B) { src := bytes.NewReader(buf.Bytes()) b.ResetTimer() for i := 0; i < b.N; i++ { - src.Seek(0, 0) + _, _ = src.Seek(0, 0) _, err := websocket.ReadFrame(src) if err != nil { b.Fatalf("unexpected error: %v", err)