From 4bfef9db006a58c8fc1a94a96db78f30b05a2d12 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sun, 14 Nov 2021 11:33:41 +0100 Subject: [PATCH] Fix several deadlocks --- client.cpp | 7 ++++++- client.h | 2 ++ mqttpacket.cpp | 2 +- subscriptionstore.cpp | 13 +++++++++++-- threaddata.cpp | 95 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------- threaddata.h | 7 ++++++- threadloop.cpp | 15 ++++++++++++--- 7 files changed, 123 insertions(+), 18 deletions(-) diff --git a/client.cpp b/client.cpp index bf2b25f..790f408 100644 --- a/client.cpp +++ b/client.cpp @@ -222,7 +222,7 @@ void Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const c } catch (std::exception &ex) { - threadData->removeClient(fd); + threadData->removeClientQueued(fd); } } @@ -322,6 +322,11 @@ void Client::resetBuffersIfEligible() writebuf.resetSizeIfEligable(initialBufferSize); } +void Client::setCleanSession(bool val) +{ + this->cleanSession = val; +} + #ifndef NDEBUG /** * @brief IoWrapper::setFakeUpgraded(). diff --git a/client.h b/client.h index 94c70b3..c8f8f12 100644 --- a/client.h +++ b/client.h @@ -134,6 +134,8 @@ public: std::string getKeepAliveInfoString() const; void resetBuffersIfEligible(); + void setCleanSession(bool val); + #ifndef NDEBUG void setFakeUpgraded(); #endif diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 3b0101c..78dd626 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -412,7 +412,7 @@ void MqttPacket::handleDisconnect() sender->setDisconnectReason("MQTT Disconnect received."); sender->markAsDisconnecting(); sender->clearWill(); - sender->getThreadData()->removeClient(sender); + sender->getThreadData()->removeClientQueued(sender); } void MqttPacket::handleSubscribe() diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 32ac75f..0b0cadc 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -202,6 +202,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr if (client->getClientId().empty()) throw ProtocolError("Trying to store client without an ID."); + bool originalClientDemandsSessionDestruction = false; std::shared_ptr session; auto session_it = sessionsById.find(client->getClientId()); if (session_it != sessionsById.end()) @@ -215,14 +216,22 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr if (cl) { logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); + cl->setDisconnectReason("Another client with this ID connected"); + + // We have to set session to false, because it's no longer up to the destruction of that client + // to destroy the session. We either do it in this function, or not at all. + originalClientDemandsSessionDestruction = cl->getCleanSession(); + cl->setCleanSession(false); + cl->setReadyForDisconnect(); - cl->getThreadData()->removeClient(cl); + cl->getThreadData()->removeClientQueued(cl); cl->markAsDisconnecting(); } + } } - if (!session || client->getCleanSession()) + if (!session || client->getCleanSession() || originalClientDemandsSessionDestruction) { session.reset(new Session()); diff --git a/threaddata.cpp b/threaddata.cpp index e88993e..9d59331 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -131,6 +131,36 @@ void ThreadData::publishStat(const std::string &topic, uint64_t n) subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); } +void ThreadData::removeQueuedClients() +{ + std::vector fds; + fds.reserve(1024); // 1024 is arbitrary... + + { + std::lock_guard lck2(clientsToRemoveMutex); + + for (const std::weak_ptr &c : clientsQueuedForRemoving) + { + std::shared_ptr client = c.lock(); + if (client) + { + int fd = client->getFd(); + fds.push_back(fd); + } + } + + clientsQueuedForRemoving.clear(); + } + + { + std::lock_guard lck(clients_by_fd_mutex); + for(int fd : fds) + { + clients_by_fd.erase(fd); + } + } +} + void ThreadData::giveClient(std::shared_ptr client) { clients_by_fd_mutex.lock(); @@ -151,25 +181,70 @@ std::shared_ptr ThreadData::getClient(int fd) return this->clients_by_fd[fd]; } -void ThreadData::removeClient(std::shared_ptr client) +void ThreadData::removeClientQueued(const std::shared_ptr &client) { - client->markAsDisconnecting(); + bool wakeUpNeeded = true; - std::lock_guard lck(clients_by_fd_mutex); - clients_by_fd.erase(client->getFd()); + { + std::lock_guard locker(clientsToRemoveMutex); + wakeUpNeeded = clientsQueuedForRemoving.empty(); + clientsQueuedForRemoving.push_front(client); + } + + if (wakeUpNeeded) + { + auto f = std::bind(&ThreadData::removeQueuedClients, this); + std::lock_guard lockertaskQueue(taskQueueMutex); + taskQueue.push_front(f); + + wakeUpThread(); + } } -void ThreadData::removeClient(int fd) +void ThreadData::removeClientQueued(int fd) { - std::lock_guard lck(clients_by_fd_mutex); - auto client_it = this->clients_by_fd.find(fd); - if (client_it != this->clients_by_fd.end()) + bool wakeUpNeeded = true; + std::shared_ptr clientFound; + + { + std::lock_guard lck(clients_by_fd_mutex); + auto client_it = this->clients_by_fd.find(fd); + if (client_it != this->clients_by_fd.end()) + { + clientFound = client_it->second; + } + } + + if (clientFound) { - client_it->second->markAsDisconnecting(); - this->clients_by_fd.erase(fd); + { + std::lock_guard locker(clientsToRemoveMutex); + wakeUpNeeded = clientsQueuedForRemoving.empty(); + clientsQueuedForRemoving.push_front(clientFound); + } + + if (wakeUpNeeded) + { + auto f = std::bind(&ThreadData::removeQueuedClients, this); + std::lock_guard lockertaskQueue(taskQueueMutex); + taskQueue.push_front(f); + + wakeUpThread(); + } } } +void ThreadData::removeClient(std::shared_ptr client) +{ + // This function is only for same-thread calling. + assert(pthread_self() == thread.native_handle()); + + client->markAsDisconnecting(); + + std::lock_guard lck(clients_by_fd_mutex); + clients_by_fd.erase(client->getFd()); +} + std::shared_ptr &ThreadData::getSubscriptionStore() { return subscriptionStore; diff --git a/threaddata.h b/threaddata.h index 713c905..a4e45a8 100644 --- a/threaddata.h +++ b/threaddata.h @@ -56,6 +56,8 @@ class ThreadData uint64_t sentMessageCountPrevious = 0; std::chrono::time_point sentMessagePreviousTime = std::chrono::steady_clock::now(); + std::mutex clientsToRemoveMutex; + std::forward_list> clientsQueuedForRemoving; void reload(std::shared_ptr settings); void wakeUpThread(); @@ -64,6 +66,8 @@ class ThreadData void publishStatsOnDollarTopic(std::vector> &threads); void publishStat(const std::string &topic, uint64_t n); + void removeQueuedClients(); + public: Settings settingsLocalCopy; // Is updated on reload, within the thread loop. Authentication authentication; @@ -84,8 +88,9 @@ public: void giveClient(std::shared_ptr client); std::shared_ptr getClient(int fd); + void removeClientQueued(const std::shared_ptr &client); + void removeClientQueued(int fd); void removeClient(std::shared_ptr client); - void removeClient(int fd); std::shared_ptr &getSubscriptionStore(); void initAuthPlugin(); diff --git a/threadloop.cpp b/threadloop.cpp index a9b1903..a790578 100644 --- a/threadloop.cpp +++ b/threadloop.cpp @@ -64,12 +64,21 @@ void do_thread_work(ThreadData *threadData) uint64_t eventfd_value = 0; check(read(fd, &eventfd_value, sizeof(uint64_t))); - std::lock_guard locker(threadData->taskQueueMutex); - for(auto &f : threadData->taskQueue) + std::forward_list> copiedTasks; + + { + std::lock_guard locker(threadData->taskQueueMutex); + for(auto &f : threadData->taskQueue) + { + copiedTasks.push_front(std::move(f)); + } + threadData->taskQueue.clear(); + } + + for(auto &f : copiedTasks) { f(); } - threadData->taskQueue.clear(); continue; } -- libgit2 0.21.4