From 0f9c4f117c0763db23d4b204973abf1fbf1c41f6 Mon Sep 17 00:00:00 2001 From: Mitch Bradley Date: Wed, 8 Apr 2026 17:48:39 -1000 Subject: [PATCH 1/3] Fix websocket client lifetime races during broadcast --- src/AsyncWebSocket.cpp | 79 +++++++++++++++++++++++++++--------------- src/AsyncWebSocket.h | 1 + 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index f4e2c408..d2a244a8 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -37,6 +37,15 @@ using namespace asyncsrv; +namespace { +AsyncWebSocketClient *find_connected_client_locked(std::list &clients, uint32_t id) { + const auto iter = std::find_if(clients.begin(), clients.end(), [id](const AsyncWebSocketClient &client) { + return client.id() == id && client.status() == WS_CONNECTED; + }); + return iter == clients.end() ? nullptr : &(*iter); +} +} // namespace + size_t webSocketSendFrameWindow(AsyncClient *client) { if (!client || !client->canSend()) { return 0; @@ -357,11 +366,11 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) { } void AsyncWebSocketClient::_onPoll() { + asyncsrv::unique_lock_type lock(_lock); if (!_client) { return; } - asyncsrv::unique_lock_type lock(_lock); if (_client && _client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) { _runQueue(); } else if (_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) { @@ -444,12 +453,11 @@ bool AsyncWebSocketClient::canSend() const { } bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) { + asyncsrv::lock_guard_type lock(_lock); if (!_client) { return false; } - asyncsrv::lock_guard_type lock(_lock); - _controlQueue.emplace_back(opcode, data, len, mask); async_ws_log_v("[%s][%" PRIu32 "] QUEUE CTRL (%u) << %" PRIu8, _server->url(), _clientId, _controlQueue.size(), opcode); @@ -461,12 +469,11 @@ bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si } bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint8_t opcode, bool mask) { - if (!_client || buffer->size() == 0 || _status != WS_CONNECTED) { + asyncsrv::unique_lock_type lock(_lock); + if (!_client || !buffer || buffer->empty() || _status != WS_CONNECTED) { return false; } - asyncsrv::unique_lock_type lock(_lock); - if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) { if (closeWhenFull) { _status = WS_DISCONNECTED; @@ -545,6 +552,7 @@ void AsyncWebSocketClient::_onError(int8_t err) { } void AsyncWebSocketClient::_onTimeout(uint32_t time) { + asyncsrv::lock_guard_type lock(_lock); if (!_client) { return; } @@ -553,7 +561,9 @@ void AsyncWebSocketClient::_onTimeout(uint32_t time) { } void AsyncWebSocketClient::_onDisconnect() { + asyncsrv::lock_guard_type lock(_lock); async_ws_log_v("[%s][%" PRIu32 "] DISCONNECT", _server->url(), _clientId); + _status = WS_DISCONNECTED; _client = nullptr; _server->_handleDisconnect(this); } @@ -947,6 +957,7 @@ bool AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) { #endif IPAddress AsyncWebSocketClient::remoteIP() const { + asyncsrv::lock_guard_type lock(_lock); if (!_client) { return IPAddress((uint32_t)0U); } @@ -955,6 +966,7 @@ IPAddress AsyncWebSocketClient::remoteIP() const { } uint16_t AsyncWebSocketClient::remotePort() const { + asyncsrv::lock_guard_type lock(_lock); if (!_client) { return 0; } @@ -983,14 +995,10 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) } void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) { - asyncsrv::lock_guard_type lock(_lock); - const auto client_id = client->id(); - const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [client_id](const AsyncWebSocketClient &c) { - return c.id() == client_id; - }); - if (iter != std::end(_clients)) { - _clients.erase(iter); - } + (void)client; + // Defer removal to cleanupClients(). Disconnect callbacks can fire while + // iterating _clients for broadcast sends, and erasing here invalidates the + // active iterator in the caller. } bool AsyncWebSocket::availableForWriteAll() { @@ -1031,7 +1039,8 @@ AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) { } void AsyncWebSocket::close(uint32_t id, uint16_t code, const char *message) { - if (AsyncWebSocketClient *c = client(id)) { + asyncsrv::lock_guard_type lock(_lock); + if (AsyncWebSocketClient *c = find_connected_client_locked(_clients, id)) { c->close(code, message); } } @@ -1047,22 +1056,32 @@ void AsyncWebSocket::closeAll(uint16_t code, const char *message) { void AsyncWebSocket::cleanupClients(uint16_t maxClients) { asyncsrv::lock_guard_type lock(_lock); - const size_t c = count(); - if (c > maxClients) { - async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), _clients.front().id(), c, maxClients); - _clients.front().close(); + const size_t connected = std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { + return c.status() == WS_CONNECTED; + }); + + if (connected > maxClients) { + const auto connected_iter = std::find_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { + return c.status() == WS_CONNECTED; + }); + if (connected_iter != std::end(_clients)) { + async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), connected_iter->id(), connected, maxClients); + connected_iter->close(); + } } - for (auto i = _clients.begin(); i != _clients.end(); ++i) { - if (i->shouldBeDeleted()) { - _clients.erase(i); - break; + for (auto iter = _clients.begin(); iter != _clients.end();) { + if (iter->shouldBeDeleted()) { + iter = _clients.erase(iter); + } else { + ++iter; } } } bool AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) { - AsyncWebSocketClient *c = client(id); + asyncsrv::lock_guard_type lock(_lock); + AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); return c && c->ping(data, len); } @@ -1081,7 +1100,8 @@ AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t l } bool AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) { - AsyncWebSocketClient *c = client(id); + asyncsrv::lock_guard_type lock(_lock); + AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); return c && c->text(makeSharedBuffer(message, len)); } bool AsyncWebSocket::text(uint32_t id, const char *message, size_t len) { @@ -1127,7 +1147,8 @@ bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketMessageBuffer *buffer) { return enqueued; } bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketSharedBuffer buffer) { - AsyncWebSocketClient *c = client(id); + asyncsrv::lock_guard_type lock(_lock); + AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); return c && c->text(buffer); } @@ -1190,7 +1211,8 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer bu } bool AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) { - AsyncWebSocketClient *c = client(id); + asyncsrv::lock_guard_type lock(_lock); + AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); return c && c->binary(makeSharedBuffer(message, len)); } bool AsyncWebSocket::binary(uint32_t id, const char *message, size_t len) { @@ -1226,7 +1248,8 @@ bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketMessageBuffer *buffer) { return enqueued; } bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketSharedBuffer buffer) { - AsyncWebSocketClient *c = client(id); + asyncsrv::lock_guard_type lock(_lock); + AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); return c && c->binary(buffer); } diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index 1e9a78af..afed25f4 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -303,6 +303,7 @@ class AsyncWebSocketClient { uint16_t remotePort() const; bool shouldBeDeleted() const { + asyncsrv::lock_guard_type lock(_lock); return !_client; } From 53662c407cccf1d415ff048a764ce343f7733a73 Mon Sep 17 00:00:00 2001 From: Mitch Bradley Date: Wed, 8 Apr 2026 17:59:49 -1000 Subject: [PATCH 2/3] Address websocket review feedback --- src/AsyncWebSocket.cpp | 36 ++++++++++++++++++++---------------- src/AsyncWebSocket.h | 1 + 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index d2a244a8..317b0b0d 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -1055,26 +1055,30 @@ void AsyncWebSocket::closeAll(uint16_t code, const char *message) { } void AsyncWebSocket::cleanupClients(uint16_t maxClients) { - asyncsrv::lock_guard_type lock(_lock); - const size_t connected = std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { - return c.status() == WS_CONNECTED; - }); - - if (connected > maxClients) { - const auto connected_iter = std::find_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { + std::list removed_clients; + { + asyncsrv::lock_guard_type lock(_lock); + const size_t connected = std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { return c.status() == WS_CONNECTED; }); - if (connected_iter != std::end(_clients)) { - async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), connected_iter->id(), connected, maxClients); - connected_iter->close(); + + if (connected > maxClients) { + const auto connected_iter = std::find_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { + return c.status() == WS_CONNECTED; + }); + if (connected_iter != std::end(_clients)) { + async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), connected_iter->id(), connected, maxClients); + connected_iter->close(); + } } - } - for (auto iter = _clients.begin(); iter != _clients.end();) { - if (iter->shouldBeDeleted()) { - iter = _clients.erase(iter); - } else { - ++iter; + for (auto iter = _clients.begin(); iter != _clients.end();) { + if (iter->shouldBeDeleted()) { + auto current = iter++; + removed_clients.splice(removed_clients.end(), _clients, current); + } else { + ++iter; + } } } } diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index afed25f4..1740a3ec 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -256,6 +256,7 @@ class AsyncWebSocketClient { return _clientId; } AwsClientStatus status() const { + asyncsrv::lock_guard_type lock(_lock); return _status; } AsyncClient *client() { From 8544c50b320a2d0a236fdd81256036ea668cab0c Mon Sep 17 00:00:00 2001 From: Mathieu Carbou Date: Thu, 9 Apr 2026 14:59:07 +0200 Subject: [PATCH 3/3] Removed behavior changes to keep teh change set to fixes following commit 085f27806ddc6b72b8598dc1ad494e7c7f869e12 --- src/AsyncWebSocket.cpp | 71 +++++++++++++++++------------------------- src/AsyncWebSocket.h | 1 - 2 files changed, 29 insertions(+), 43 deletions(-) diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index 317b0b0d..16a51b4b 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -37,15 +37,6 @@ using namespace asyncsrv; -namespace { -AsyncWebSocketClient *find_connected_client_locked(std::list &clients, uint32_t id) { - const auto iter = std::find_if(clients.begin(), clients.end(), [id](const AsyncWebSocketClient &client) { - return client.id() == id && client.status() == WS_CONNECTED; - }); - return iter == clients.end() ? nullptr : &(*iter); -} -} // namespace - size_t webSocketSendFrameWindow(AsyncClient *client) { if (!client || !client->canSend()) { return 0; @@ -367,6 +358,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) { void AsyncWebSocketClient::_onPoll() { asyncsrv::unique_lock_type lock(_lock); + if (!_client) { return; } @@ -454,6 +446,7 @@ bool AsyncWebSocketClient::canSend() const { bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) { asyncsrv::lock_guard_type lock(_lock); + if (!_client) { return false; } @@ -470,6 +463,7 @@ bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint8_t opcode, bool mask) { asyncsrv::unique_lock_type lock(_lock); + if (!_client || !buffer || buffer->empty() || _status != WS_CONNECTED) { return false; } @@ -958,6 +952,7 @@ bool AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) { IPAddress AsyncWebSocketClient::remoteIP() const { asyncsrv::lock_guard_type lock(_lock); + if (!_client) { return IPAddress((uint32_t)0U); } @@ -967,6 +962,7 @@ IPAddress AsyncWebSocketClient::remoteIP() const { uint16_t AsyncWebSocketClient::remotePort() const { asyncsrv::lock_guard_type lock(_lock); + if (!_client) { return 0; } @@ -995,10 +991,14 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) } void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) { - (void)client; - // Defer removal to cleanupClients(). Disconnect callbacks can fire while - // iterating _clients for broadcast sends, and erasing here invalidates the - // active iterator in the caller. + asyncsrv::lock_guard_type lock(_lock); + const auto client_id = client->id(); + const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [client_id](const AsyncWebSocketClient &c) { + return c.id() == client_id; + }); + if (iter != std::end(_clients)) { + _clients.erase(iter); + } } bool AsyncWebSocket::availableForWriteAll() { @@ -1040,7 +1040,7 @@ AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) { void AsyncWebSocket::close(uint32_t id, uint16_t code, const char *message) { asyncsrv::lock_guard_type lock(_lock); - if (AsyncWebSocketClient *c = find_connected_client_locked(_clients, id)) { + if (AsyncWebSocketClient *c = client(id)) { c->close(code, message); } } @@ -1055,37 +1055,24 @@ void AsyncWebSocket::closeAll(uint16_t code, const char *message) { } void AsyncWebSocket::cleanupClients(uint16_t maxClients) { - std::list removed_clients; - { - asyncsrv::lock_guard_type lock(_lock); - const size_t connected = std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { - return c.status() == WS_CONNECTED; - }); - - if (connected > maxClients) { - const auto connected_iter = std::find_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) { - return c.status() == WS_CONNECTED; - }); - if (connected_iter != std::end(_clients)) { - async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), connected_iter->id(), connected, maxClients); - connected_iter->close(); - } - } + asyncsrv::lock_guard_type lock(_lock); + const size_t c = count(); + if (c > maxClients) { + async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), _clients.front().id(), c, maxClients); + _clients.front().close(); + } - for (auto iter = _clients.begin(); iter != _clients.end();) { - if (iter->shouldBeDeleted()) { - auto current = iter++; - removed_clients.splice(removed_clients.end(), _clients, current); - } else { - ++iter; - } + for (auto i = _clients.begin(); i != _clients.end(); ++i) { + if (i->shouldBeDeleted()) { + _clients.erase(i); + break; } } } bool AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) { asyncsrv::lock_guard_type lock(_lock); - AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); + AsyncWebSocketClient *c = client(id); return c && c->ping(data, len); } @@ -1105,7 +1092,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t l bool AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) { asyncsrv::lock_guard_type lock(_lock); - AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); + AsyncWebSocketClient *c = client(id); return c && c->text(makeSharedBuffer(message, len)); } bool AsyncWebSocket::text(uint32_t id, const char *message, size_t len) { @@ -1152,7 +1139,7 @@ bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketMessageBuffer *buffer) { } bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketSharedBuffer buffer) { asyncsrv::lock_guard_type lock(_lock); - AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); + AsyncWebSocketClient *c = client(id); return c && c->text(buffer); } @@ -1216,7 +1203,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer bu bool AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) { asyncsrv::lock_guard_type lock(_lock); - AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); + AsyncWebSocketClient *c = client(id); return c && c->binary(makeSharedBuffer(message, len)); } bool AsyncWebSocket::binary(uint32_t id, const char *message, size_t len) { @@ -1253,7 +1240,7 @@ bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketMessageBuffer *buffer) { } bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketSharedBuffer buffer) { asyncsrv::lock_guard_type lock(_lock); - AsyncWebSocketClient *c = find_connected_client_locked(_clients, id); + AsyncWebSocketClient *c = client(id); return c && c->binary(buffer); } diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index 1740a3ec..afed25f4 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -256,7 +256,6 @@ class AsyncWebSocketClient { return _clientId; } AwsClientStatus status() const { - asyncsrv::lock_guard_type lock(_lock); return _status; } AsyncClient *client() {