Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 28 additions & 34 deletions src/AsyncWebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ AsyncWebSocketClient::AsyncWebSocketClient(AsyncClient *client, AsyncWebSocket *

AsyncWebSocketClient::~AsyncWebSocketClient() {
{
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);
_messageQueue.clear();
_controlQueue.clear();
}
Expand All @@ -313,7 +313,7 @@ void AsyncWebSocketClient::_clearQueue() {
void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
_lastMessageTime = millis();

asyncsrv::unique_lock_type lock(_lock);
asyncsrv::unique_lock_type lock(_queue_lock);

async_ws_log_v("[%s][%" PRIu32 "] START ACK(%u, %" PRIu32 ") Q:%u", _server->url(), _clientId, len, time, _messageQueue.size());

Expand Down Expand Up @@ -357,7 +357,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
}

void AsyncWebSocketClient::_onPoll() {
asyncsrv::unique_lock_type lock(_lock);
asyncsrv::unique_lock_type lock(_queue_lock);

if (!_client) {
return;
Expand Down Expand Up @@ -430,22 +430,22 @@ void AsyncWebSocketClient::_runQueue() {
}

bool AsyncWebSocketClient::queueIsFull() const {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);
return (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) || (_status != WS_CONNECTED);
}

size_t AsyncWebSocketClient::queueLen() const {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);
return _messageQueue.size();
}

bool AsyncWebSocketClient::canSend() const {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);
return _messageQueue.size() < WS_MAX_QUEUED_MESSAGES;
}

bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);

if (!_client) {
return false;
Expand All @@ -462,7 +462,7 @@ bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si
}

bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint8_t opcode, bool mask) {
asyncsrv::unique_lock_type lock(_lock);
asyncsrv::unique_lock_type lock(_queue_lock);

if (!_client || !buffer || buffer->empty() || _status != WS_CONNECTED) {
return false;
Expand Down Expand Up @@ -531,7 +531,9 @@ void AsyncWebSocketClient::close(uint16_t code, const char *message) {
return;
} else {
async_ws_log_e("Failed to allocate");
_client->abort();
if (_client) {
_client->abort();
}
}
}
_queueControl(WS_DISCONNECT);
Expand All @@ -546,7 +548,6 @@ void AsyncWebSocketClient::_onError(int8_t err) {
}

void AsyncWebSocketClient::_onTimeout(uint32_t time) {
asyncsrv::lock_guard_type lock(_lock);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because this is wrong to use the _queue_lock to protect the _client ptr (which is not guarded anywhere else it should be)

if (!_client) {
return;
}
Expand All @@ -555,7 +556,6 @@ void AsyncWebSocketClient::_onTimeout(uint32_t time) {
}

void AsyncWebSocketClient::_onDisconnect() {
asyncsrv::lock_guard_type lock(_lock);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because this is wrong to use the _queue_lock to protect the _client ptr (which is not guarded anywhere else it should be)

async_ws_log_v("[%s][%" PRIu32 "] DISCONNECT", _server->url(), _clientId);
_status = WS_DISCONNECTED;
_client = nullptr;
Expand Down Expand Up @@ -951,22 +951,16 @@ bool AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) {
#endif

IPAddress AsyncWebSocketClient::remoteIP() const {
asyncsrv::lock_guard_type lock(_lock);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because this is wrong to use the _queue_lock to protect the _client ptr (which is not guarded anywhere else it should be)


if (!_client) {
return IPAddress((uint32_t)0U);
}

return _client->remoteIP();
}

uint16_t AsyncWebSocketClient::remotePort() const {
asyncsrv::lock_guard_type lock(_lock);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because this is wrong to use the _queue_lock to protect the _client ptr (which is not guarded anywhere else it should be)


if (!_client) {
return 0;
}

return _client->remotePort();
}

Expand All @@ -981,7 +975,7 @@ void AsyncWebSocket::_handleEvent(AsyncWebSocketClient *client, AwsEventType typ
}

AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
_clients.emplace_back(request, this);
// we've just detached AsyncTCP client from AsyncWebServerRequest
_handleEvent(&_clients.back(), WS_EVT_CONNECT, request, NULL, 0);
Expand All @@ -991,7 +985,7 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request)
}

void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
const auto client_id = client->id();
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [client_id](const AsyncWebSocketClient &c) {
return c.id() == client_id;
Expand All @@ -1002,14 +996,14 @@ void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) {
}

bool AsyncWebSocket::availableForWriteAll() {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
return std::none_of(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
return c.queueIsFull();
});
}

bool AsyncWebSocket::availableForWrite(uint32_t id) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [id](const AsyncWebSocketClient &c) {
return c.id() == id;
});
Expand All @@ -1020,14 +1014,14 @@ bool AsyncWebSocket::availableForWrite(uint32_t id) {
}

size_t AsyncWebSocket::count() const {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
return std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
return c.status() == WS_CONNECTED;
});
}

AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
const auto iter = std::find_if(_clients.begin(), _clients.end(), [id](const AsyncWebSocketClient &c) {
return c.id() == id && c.status() == WS_CONNECTED;
});
Expand All @@ -1039,14 +1033,14 @@ AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) {
}

void AsyncWebSocket::close(uint32_t id, uint16_t code, const char *message) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
if (AsyncWebSocketClient *c = client(id)) {
c->close(code, message);
}
}

void AsyncWebSocket::closeAll(uint16_t code, const char *message) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
for (auto &c : _clients) {
if (c.status() == WS_CONNECTED) {
c.close(code, message);
Expand All @@ -1055,7 +1049,7 @@ void AsyncWebSocket::closeAll(uint16_t code, const char *message) {
}

void AsyncWebSocket::cleanupClients(uint16_t maxClients) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
const size_t c = count();
if (c > maxClients) {
async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), _clients.front().id(), c, maxClients);
Expand All @@ -1071,13 +1065,13 @@ void AsyncWebSocket::cleanupClients(uint16_t maxClients) {
}

bool AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->ping(data, len);
}

AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t len) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
size_t hit = 0;
size_t miss = 0;
for (auto &c : _clients) {
Expand All @@ -1091,7 +1085,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t l
}

bool AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->text(makeSharedBuffer(message, len));
}
Expand Down Expand Up @@ -1138,7 +1132,7 @@ bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketMessageBuffer *buffer) {
return enqueued;
}
bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketSharedBuffer buffer) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->text(buffer);
}
Expand Down Expand Up @@ -1188,7 +1182,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer *
}

AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer buffer) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
size_t hit = 0;
size_t miss = 0;
for (auto &c : _clients) {
Expand All @@ -1202,7 +1196,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer bu
}

bool AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->binary(makeSharedBuffer(message, len));
}
Expand Down Expand Up @@ -1239,7 +1233,7 @@ bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketMessageBuffer *buffer) {
return enqueued;
}
bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketSharedBuffer buffer) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->binary(buffer);
}
Expand Down Expand Up @@ -1280,7 +1274,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer
return status;
}
AsyncWebSocket::SendStatus AsyncWebSocket::binaryAll(AsyncWebSocketSharedBuffer buffer) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
size_t hit = 0;
size_t miss = 0;
for (auto &c : _clients) {
Expand Down
5 changes: 2 additions & 3 deletions src/AsyncWebSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class AsyncWebSocketClient {
uint8_t _pstate;
uint32_t _lastMessageTime;
uint32_t _keepAlivePeriod;
mutable asyncsrv::mutex_type _lock;
mutable asyncsrv::mutex_type _queue_lock;
std::deque<AsyncWebSocketControl> _controlQueue;
std::deque<AsyncWebSocketMessage> _messageQueue;
bool closeWhenFull = true;
Expand Down Expand Up @@ -303,7 +303,6 @@ class AsyncWebSocketClient {
uint16_t remotePort() const;

bool shouldBeDeleted() const {
asyncsrv::lock_guard_type lock(_lock);
return !_client;
}

Expand Down Expand Up @@ -371,7 +370,7 @@ class AsyncWebSocket : public AsyncWebHandler {
AwsEventHandler _eventHandler;
AwsHandshakeHandler _handshakeHandler;
bool _enabled;
mutable asyncsrv::mutex_type _lock;
mutable asyncsrv::mutex_type _ws_clients_lock;

public:
typedef enum {
Expand Down
Loading