Skip to content

Commit 0f9c4f1

Browse files
committed
Fix websocket client lifetime races during broadcast
1 parent b01d236 commit 0f9c4f1

2 files changed

Lines changed: 52 additions & 28 deletions

File tree

src/AsyncWebSocket.cpp

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@
3737

3838
using namespace asyncsrv;
3939

40+
namespace {
41+
AsyncWebSocketClient *find_connected_client_locked(std::list<AsyncWebSocketClient> &clients, uint32_t id) {
42+
const auto iter = std::find_if(clients.begin(), clients.end(), [id](const AsyncWebSocketClient &client) {
43+
return client.id() == id && client.status() == WS_CONNECTED;
44+
});
45+
return iter == clients.end() ? nullptr : &(*iter);
46+
}
47+
} // namespace
48+
4049
size_t webSocketSendFrameWindow(AsyncClient *client) {
4150
if (!client || !client->canSend()) {
4251
return 0;
@@ -357,11 +366,11 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
357366
}
358367

359368
void AsyncWebSocketClient::_onPoll() {
369+
asyncsrv::unique_lock_type lock(_lock);
360370
if (!_client) {
361371
return;
362372
}
363373

364-
asyncsrv::unique_lock_type lock(_lock);
365374
if (_client && _client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) {
366375
_runQueue();
367376
} else if (_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) {
@@ -444,12 +453,11 @@ bool AsyncWebSocketClient::canSend() const {
444453
}
445454

446455
bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) {
456+
asyncsrv::lock_guard_type lock(_lock);
447457
if (!_client) {
448458
return false;
449459
}
450460

451-
asyncsrv::lock_guard_type lock(_lock);
452-
453461
_controlQueue.emplace_back(opcode, data, len, mask);
454462
async_ws_log_v("[%s][%" PRIu32 "] QUEUE CTRL (%u) << %" PRIu8, _server->url(), _clientId, _controlQueue.size(), opcode);
455463

@@ -461,12 +469,11 @@ bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si
461469
}
462470

463471
bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint8_t opcode, bool mask) {
464-
if (!_client || buffer->size() == 0 || _status != WS_CONNECTED) {
472+
asyncsrv::unique_lock_type lock(_lock);
473+
if (!_client || !buffer || buffer->empty() || _status != WS_CONNECTED) {
465474
return false;
466475
}
467476

468-
asyncsrv::unique_lock_type lock(_lock);
469-
470477
if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) {
471478
if (closeWhenFull) {
472479
_status = WS_DISCONNECTED;
@@ -545,6 +552,7 @@ void AsyncWebSocketClient::_onError(int8_t err) {
545552
}
546553

547554
void AsyncWebSocketClient::_onTimeout(uint32_t time) {
555+
asyncsrv::lock_guard_type lock(_lock);
548556
if (!_client) {
549557
return;
550558
}
@@ -553,7 +561,9 @@ void AsyncWebSocketClient::_onTimeout(uint32_t time) {
553561
}
554562

555563
void AsyncWebSocketClient::_onDisconnect() {
564+
asyncsrv::lock_guard_type lock(_lock);
556565
async_ws_log_v("[%s][%" PRIu32 "] DISCONNECT", _server->url(), _clientId);
566+
_status = WS_DISCONNECTED;
557567
_client = nullptr;
558568
_server->_handleDisconnect(this);
559569
}
@@ -947,6 +957,7 @@ bool AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) {
947957
#endif
948958

949959
IPAddress AsyncWebSocketClient::remoteIP() const {
960+
asyncsrv::lock_guard_type lock(_lock);
950961
if (!_client) {
951962
return IPAddress((uint32_t)0U);
952963
}
@@ -955,6 +966,7 @@ IPAddress AsyncWebSocketClient::remoteIP() const {
955966
}
956967

957968
uint16_t AsyncWebSocketClient::remotePort() const {
969+
asyncsrv::lock_guard_type lock(_lock);
958970
if (!_client) {
959971
return 0;
960972
}
@@ -983,14 +995,10 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request)
983995
}
984996

985997
void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) {
986-
asyncsrv::lock_guard_type lock(_lock);
987-
const auto client_id = client->id();
988-
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [client_id](const AsyncWebSocketClient &c) {
989-
return c.id() == client_id;
990-
});
991-
if (iter != std::end(_clients)) {
992-
_clients.erase(iter);
993-
}
998+
(void)client;
999+
// Defer removal to cleanupClients(). Disconnect callbacks can fire while
1000+
// iterating _clients for broadcast sends, and erasing here invalidates the
1001+
// active iterator in the caller.
9941002
}
9951003

9961004
bool AsyncWebSocket::availableForWriteAll() {
@@ -1031,7 +1039,8 @@ AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) {
10311039
}
10321040

10331041
void AsyncWebSocket::close(uint32_t id, uint16_t code, const char *message) {
1034-
if (AsyncWebSocketClient *c = client(id)) {
1042+
asyncsrv::lock_guard_type lock(_lock);
1043+
if (AsyncWebSocketClient *c = find_connected_client_locked(_clients, id)) {
10351044
c->close(code, message);
10361045
}
10371046
}
@@ -1047,22 +1056,32 @@ void AsyncWebSocket::closeAll(uint16_t code, const char *message) {
10471056

10481057
void AsyncWebSocket::cleanupClients(uint16_t maxClients) {
10491058
asyncsrv::lock_guard_type lock(_lock);
1050-
const size_t c = count();
1051-
if (c > maxClients) {
1052-
async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), _clients.front().id(), c, maxClients);
1053-
_clients.front().close();
1059+
const size_t connected = std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
1060+
return c.status() == WS_CONNECTED;
1061+
});
1062+
1063+
if (connected > maxClients) {
1064+
const auto connected_iter = std::find_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
1065+
return c.status() == WS_CONNECTED;
1066+
});
1067+
if (connected_iter != std::end(_clients)) {
1068+
async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), connected_iter->id(), connected, maxClients);
1069+
connected_iter->close();
1070+
}
10541071
}
10551072

1056-
for (auto i = _clients.begin(); i != _clients.end(); ++i) {
1057-
if (i->shouldBeDeleted()) {
1058-
_clients.erase(i);
1059-
break;
1073+
for (auto iter = _clients.begin(); iter != _clients.end();) {
1074+
if (iter->shouldBeDeleted()) {
1075+
iter = _clients.erase(iter);
1076+
} else {
1077+
++iter;
10601078
}
10611079
}
10621080
}
10631081

10641082
bool AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) {
1065-
AsyncWebSocketClient *c = client(id);
1083+
asyncsrv::lock_guard_type lock(_lock);
1084+
AsyncWebSocketClient *c = find_connected_client_locked(_clients, id);
10661085
return c && c->ping(data, len);
10671086
}
10681087

@@ -1081,7 +1100,8 @@ AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t l
10811100
}
10821101

10831102
bool AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) {
1084-
AsyncWebSocketClient *c = client(id);
1103+
asyncsrv::lock_guard_type lock(_lock);
1104+
AsyncWebSocketClient *c = find_connected_client_locked(_clients, id);
10851105
return c && c->text(makeSharedBuffer(message, len));
10861106
}
10871107
bool AsyncWebSocket::text(uint32_t id, const char *message, size_t len) {
@@ -1127,7 +1147,8 @@ bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketMessageBuffer *buffer) {
11271147
return enqueued;
11281148
}
11291149
bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketSharedBuffer buffer) {
1130-
AsyncWebSocketClient *c = client(id);
1150+
asyncsrv::lock_guard_type lock(_lock);
1151+
AsyncWebSocketClient *c = find_connected_client_locked(_clients, id);
11311152
return c && c->text(buffer);
11321153
}
11331154

@@ -1190,7 +1211,8 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer bu
11901211
}
11911212

11921213
bool AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) {
1193-
AsyncWebSocketClient *c = client(id);
1214+
asyncsrv::lock_guard_type lock(_lock);
1215+
AsyncWebSocketClient *c = find_connected_client_locked(_clients, id);
11941216
return c && c->binary(makeSharedBuffer(message, len));
11951217
}
11961218
bool AsyncWebSocket::binary(uint32_t id, const char *message, size_t len) {
@@ -1226,7 +1248,8 @@ bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketMessageBuffer *buffer) {
12261248
return enqueued;
12271249
}
12281250
bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketSharedBuffer buffer) {
1229-
AsyncWebSocketClient *c = client(id);
1251+
asyncsrv::lock_guard_type lock(_lock);
1252+
AsyncWebSocketClient *c = find_connected_client_locked(_clients, id);
12301253
return c && c->binary(buffer);
12311254
}
12321255

src/AsyncWebSocket.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ class AsyncWebSocketClient {
303303
uint16_t remotePort() const;
304304

305305
bool shouldBeDeleted() const {
306+
asyncsrv::lock_guard_type lock(_lock);
306307
return !_client;
307308
}
308309

0 commit comments

Comments
 (0)