diff --git a/client.cpp b/client.cpp index 69f303c..124ebe8 100644 --- a/client.cpp +++ b/client.cpp @@ -407,6 +407,32 @@ void Client::sendOrQueueWill() this->willPublish.reset(); } +/** + * @brief Client::serverInitiatedDisconnect queues a disconnect packet and when the last bytes are written, the thread loop will disconnect it. + * @param reason is an MQTT5 reason code. + * + * There is a chance that an client's TCP buffers are full (when the client is gone, for example) and epoll will not report the + * fd as EPOLLOUT, which means the disconnect will not happen. It will then be up to the keep-alive mechanism to kick the client out. + * + * Sending clients disconnect packets is only supported by MQTT >= 5, so in case of MQTT3, just close the connection. + */ +void Client::serverInitiatedDisconnect(ReasonCodes reason) +{ + setDisconnectReason(formatString("Server initiating disconnect with reason code '%d'", static_cast(reason))); + + if (this->protocolVersion >= ProtocolVersion::Mqtt5) + { + setReadyForDisconnect(); + Disconnect d(ProtocolVersion::Mqtt5, reason); + writeMqttPacket(d); + } + else + { + markAsDisconnecting(); + threadData->removeClientQueued(fd); + } +} + #ifndef NDEBUG /** * @brief IoWrapper::setFakeUpgraded(). diff --git a/client.h b/client.h index b2558e5..4cc4061 100644 --- a/client.h +++ b/client.h @@ -150,6 +150,7 @@ public: uint32_t getMaxIncomingPacketSize() const; void sendOrQueueWill(); + void serverInitiatedDisconnect(ReasonCodes reason); #ifndef NDEBUG void setFakeUpgraded(); diff --git a/mainapp.cpp b/mainapp.cpp index 60ae030..9cab856 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -285,9 +285,17 @@ void MainApp::queueRemoveExpiredSessions() } } -void MainApp::waitForAllThreadsQueuedWills() +void MainApp::waitForWillsQueued() { - while(std::any_of(threads.begin(), threads.end(), [](std::shared_ptr t){ return !t->allWilssSentForExit; })) + while(std::any_of(threads.begin(), threads.end(), [](std::shared_ptr t){ return !t->allWillsQueued; })) + { + usleep(1000); + } +} + +void MainApp::waitForDisconnectsInitiated() +{ + while(std::any_of(threads.begin(), threads.end(), [](std::shared_ptr t){ return !t->allDisconnectsSent; })) { usleep(1000); } @@ -618,10 +626,16 @@ void MainApp::start() logger->logf(LOG_DEBUG, "Having all client in all threads send or queue their will."); for(std::shared_ptr &thread : threads) { - thread->queueSendAllWills(); + thread->queueSendWills(); } + waitForWillsQueued(); - waitForAllThreadsQueuedWills(); + logger->logf(LOG_DEBUG, "Having all client in all threads send a disconnect packet."); + for(std::shared_ptr &thread : threads) + { + thread->queueSendDisconnects(); + } + waitForDisconnectsInitiated(); oneInstanceLock.unlock(); diff --git a/mainapp.h b/mainapp.h index 0629c8e..4e2da90 100644 --- a/mainapp.h +++ b/mainapp.h @@ -93,7 +93,8 @@ class MainApp void saveStateInThread(); void queueSendQueuedWills(); void queueRemoveExpiredSessions(); - void waitForAllThreadsQueuedWills(); + void waitForWillsQueued(); + void waitForDisconnectsInitiated(); MainApp(const std::string &configFilePath); public: diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index ccbd0f3..f19003c 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -230,10 +230,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr { logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); cl->setDisconnectReason("Another client with this ID connected"); - - cl->setReadyForDisconnect(); - cl->getThreadData()->removeClientQueued(cl); - cl->markAsDisconnecting(); + cl->serverInitiatedDisconnect(ReasonCodes::SessionTakenOver); } } diff --git a/threaddata.cpp b/threaddata.cpp index f3b62b8..519b8f9 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -162,16 +162,39 @@ void ThreadData::removeExpiredSessions() subscriptionStore->removeExpiredSessionsClients(); } -void ThreadData::sendAllWils() +void ThreadData::sendAllWills() { std::lock_guard lck(clients_by_fd_mutex); - for(auto pairs : clients_by_fd) + for(auto &pair : clients_by_fd) { - pairs.second->sendOrQueueWill(); + std::shared_ptr &c = pair.second; + c->sendOrQueueWill(); } - allWilssSentForExit = true; + allWillsQueued = true; +} + +void ThreadData::sendAllDisconnects() +{ + std::vector> clientsFound; + + { + std::lock_guard lck(clients_by_fd_mutex); + clientsFound.reserve(clients_by_fd.size()); + + for(auto &pair : clients_by_fd) + { + clientsFound.push_back(pair.second); + } + } + + for (std::shared_ptr &c : clientsFound) + { + c->serverInitiatedDisconnect(ReasonCodes::ServerShuttingDown); + } + + allDisconnectsSent = true; } void ThreadData::removeQueuedClients() @@ -401,11 +424,21 @@ void ThreadData::authPluginPeriodicEvent() authentication.periodicEvent(); } -void ThreadData::queueSendAllWills() +void ThreadData::queueSendWills() +{ + std::lock_guard locker(taskQueueMutex); + + auto f = std::bind(&ThreadData::sendAllWills, this); + taskQueue.push_front(f); + + wakeUpThread(); +} + +void ThreadData::queueSendDisconnects() { std::lock_guard locker(taskQueueMutex); - auto f = std::bind(&ThreadData::sendAllWils, this); + auto f = std::bind(&ThreadData::sendAllDisconnects, this); taskQueue.push_front(f); wakeUpThread(); diff --git a/threaddata.h b/threaddata.h index 3e0727b..ef7cb0f 100644 --- a/threaddata.h +++ b/threaddata.h @@ -66,7 +66,8 @@ class ThreadData void publishStat(const std::string &topic, uint64_t n); void sendQueuedWills(); void removeExpiredSessions(); - void sendAllWils(); + void sendAllWills(); + void sendAllDisconnects(); void removeQueuedClients(); @@ -75,7 +76,8 @@ public: Authentication authentication; bool running = true; bool finished = false; - bool allWilssSentForExit = false; + bool allWillsQueued = false; + bool allDisconnectsSent = false; std::thread thread; int threadnr = 0; int epollfd = 0; @@ -120,7 +122,8 @@ public: void queueAuthPluginPeriodicEvent(); void authPluginPeriodicEvent(); - void queueSendAllWills(); + void queueSendWills(); + void queueSendDisconnects(); }; #endif // THREADDATA_H diff --git a/threadloop.cpp b/threadloop.cpp index a19b349..9e60ae5 100644 --- a/threadloop.cpp +++ b/threadloop.cpp @@ -144,6 +144,10 @@ void do_thread_work(ThreadData *threadData) MqttPacket p(d); client->writeMqttPacket(p); client->setReadyForDisconnect(); + + // When a client's TCP buffers are full (when the client is gone, for instance), EPOLLOUT will never be + // reported. In those cases, the client is not removed; not until the keep-alive mechanism anyway. Is + // that a problem? } else {