Commit d125dc4917664850d3bba2e9eedfffa76ebbf9f6

Authored by Wiebe Cazemier
1 parent efc917f9

Disconnect clients with good reason code

On server shutdown and when taking over a session.

On disconnect, wills are queued first, we wait for the queueing to be done, then
initiate disconnect.

When TCP buffers are full and fds are not reported by epoll, the thread
loop still exits and clients are just closed on exit.
client.cpp
@@ -407,6 +407,32 @@ void Client::sendOrQueueWill() @@ -407,6 +407,32 @@ void Client::sendOrQueueWill()
407 this->willPublish.reset(); 407 this->willPublish.reset();
408 } 408 }
409 409
  410 +/**
  411 + * @brief Client::serverInitiatedDisconnect queues a disconnect packet and when the last bytes are written, the thread loop will disconnect it.
  412 + * @param reason is an MQTT5 reason code.
  413 + *
  414 + * There is a chance that an client's TCP buffers are full (when the client is gone, for example) and epoll will not report the
  415 + * fd as EPOLLOUT, which means the disconnect will not happen. It will then be up to the keep-alive mechanism to kick the client out.
  416 + *
  417 + * Sending clients disconnect packets is only supported by MQTT >= 5, so in case of MQTT3, just close the connection.
  418 + */
  419 +void Client::serverInitiatedDisconnect(ReasonCodes reason)
  420 +{
  421 + setDisconnectReason(formatString("Server initiating disconnect with reason code '%d'", static_cast<uint8_t>(reason)));
  422 +
  423 + if (this->protocolVersion >= ProtocolVersion::Mqtt5)
  424 + {
  425 + setReadyForDisconnect();
  426 + Disconnect d(ProtocolVersion::Mqtt5, reason);
  427 + writeMqttPacket(d);
  428 + }
  429 + else
  430 + {
  431 + markAsDisconnecting();
  432 + threadData->removeClientQueued(fd);
  433 + }
  434 +}
  435 +
410 #ifndef NDEBUG 436 #ifndef NDEBUG
411 /** 437 /**
412 * @brief IoWrapper::setFakeUpgraded(). 438 * @brief IoWrapper::setFakeUpgraded().
client.h
@@ -150,6 +150,7 @@ public: @@ -150,6 +150,7 @@ public:
150 uint32_t getMaxIncomingPacketSize() const; 150 uint32_t getMaxIncomingPacketSize() const;
151 151
152 void sendOrQueueWill(); 152 void sendOrQueueWill();
  153 + void serverInitiatedDisconnect(ReasonCodes reason);
153 154
154 #ifndef NDEBUG 155 #ifndef NDEBUG
155 void setFakeUpgraded(); 156 void setFakeUpgraded();
mainapp.cpp
@@ -285,9 +285,17 @@ void MainApp::queueRemoveExpiredSessions() @@ -285,9 +285,17 @@ void MainApp::queueRemoveExpiredSessions()
285 } 285 }
286 } 286 }
287 287
288 -void MainApp::waitForAllThreadsQueuedWills() 288 +void MainApp::waitForWillsQueued()
289 { 289 {
290 - while(std::any_of(threads.begin(), threads.end(), [](std::shared_ptr<ThreadData> t){ return !t->allWilssSentForExit; })) 290 + while(std::any_of(threads.begin(), threads.end(), [](std::shared_ptr<ThreadData> t){ return !t->allWillsQueued; }))
  291 + {
  292 + usleep(1000);
  293 + }
  294 +}
  295 +
  296 +void MainApp::waitForDisconnectsInitiated()
  297 +{
  298 + while(std::any_of(threads.begin(), threads.end(), [](std::shared_ptr<ThreadData> t){ return !t->allDisconnectsSent; }))
291 { 299 {
292 usleep(1000); 300 usleep(1000);
293 } 301 }
@@ -618,10 +626,16 @@ void MainApp::start() @@ -618,10 +626,16 @@ void MainApp::start()
618 logger->logf(LOG_DEBUG, "Having all client in all threads send or queue their will."); 626 logger->logf(LOG_DEBUG, "Having all client in all threads send or queue their will.");
619 for(std::shared_ptr<ThreadData> &thread : threads) 627 for(std::shared_ptr<ThreadData> &thread : threads)
620 { 628 {
621 - thread->queueSendAllWills(); 629 + thread->queueSendWills();
622 } 630 }
  631 + waitForWillsQueued();
623 632
624 - waitForAllThreadsQueuedWills(); 633 + logger->logf(LOG_DEBUG, "Having all client in all threads send a disconnect packet.");
  634 + for(std::shared_ptr<ThreadData> &thread : threads)
  635 + {
  636 + thread->queueSendDisconnects();
  637 + }
  638 + waitForDisconnectsInitiated();
625 639
626 oneInstanceLock.unlock(); 640 oneInstanceLock.unlock();
627 641
mainapp.h
@@ -93,7 +93,8 @@ class MainApp @@ -93,7 +93,8 @@ class MainApp
93 void saveStateInThread(); 93 void saveStateInThread();
94 void queueSendQueuedWills(); 94 void queueSendQueuedWills();
95 void queueRemoveExpiredSessions(); 95 void queueRemoveExpiredSessions();
96 - void waitForAllThreadsQueuedWills(); 96 + void waitForWillsQueued();
  97 + void waitForDisconnectsInitiated();
97 98
98 MainApp(const std::string &configFilePath); 99 MainApp(const std::string &configFilePath);
99 public: 100 public:
subscriptionstore.cpp
@@ -230,10 +230,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr&lt;Client&gt; @@ -230,10 +230,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr&lt;Client&gt;
230 { 230 {
231 logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); 231 logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str());
232 cl->setDisconnectReason("Another client with this ID connected"); 232 cl->setDisconnectReason("Another client with this ID connected");
233 -  
234 - cl->setReadyForDisconnect();  
235 - cl->getThreadData()->removeClientQueued(cl);  
236 - cl->markAsDisconnecting(); 233 + cl->serverInitiatedDisconnect(ReasonCodes::SessionTakenOver);
237 } 234 }
238 235
239 } 236 }
threaddata.cpp
@@ -162,16 +162,39 @@ void ThreadData::removeExpiredSessions() @@ -162,16 +162,39 @@ void ThreadData::removeExpiredSessions()
162 subscriptionStore->removeExpiredSessionsClients(); 162 subscriptionStore->removeExpiredSessionsClients();
163 } 163 }
164 164
165 -void ThreadData::sendAllWils() 165 +void ThreadData::sendAllWills()
166 { 166 {
167 std::lock_guard<std::mutex> lck(clients_by_fd_mutex); 167 std::lock_guard<std::mutex> lck(clients_by_fd_mutex);
168 168
169 - for(auto pairs : clients_by_fd) 169 + for(auto &pair : clients_by_fd)
170 { 170 {
171 - pairs.second->sendOrQueueWill(); 171 + std::shared_ptr<Client> &c = pair.second;
  172 + c->sendOrQueueWill();
172 } 173 }
173 174
174 - allWilssSentForExit = true; 175 + allWillsQueued = true;
  176 +}
  177 +
  178 +void ThreadData::sendAllDisconnects()
  179 +{
  180 + std::vector<std::shared_ptr<Client>> clientsFound;
  181 +
  182 + {
  183 + std::lock_guard<std::mutex> lck(clients_by_fd_mutex);
  184 + clientsFound.reserve(clients_by_fd.size());
  185 +
  186 + for(auto &pair : clients_by_fd)
  187 + {
  188 + clientsFound.push_back(pair.second);
  189 + }
  190 + }
  191 +
  192 + for (std::shared_ptr<Client> &c : clientsFound)
  193 + {
  194 + c->serverInitiatedDisconnect(ReasonCodes::ServerShuttingDown);
  195 + }
  196 +
  197 + allDisconnectsSent = true;
175 } 198 }
176 199
177 void ThreadData::removeQueuedClients() 200 void ThreadData::removeQueuedClients()
@@ -401,11 +424,21 @@ void ThreadData::authPluginPeriodicEvent() @@ -401,11 +424,21 @@ void ThreadData::authPluginPeriodicEvent()
401 authentication.periodicEvent(); 424 authentication.periodicEvent();
402 } 425 }
403 426
404 -void ThreadData::queueSendAllWills() 427 +void ThreadData::queueSendWills()
  428 +{
  429 + std::lock_guard<std::mutex> locker(taskQueueMutex);
  430 +
  431 + auto f = std::bind(&ThreadData::sendAllWills, this);
  432 + taskQueue.push_front(f);
  433 +
  434 + wakeUpThread();
  435 +}
  436 +
  437 +void ThreadData::queueSendDisconnects()
405 { 438 {
406 std::lock_guard<std::mutex> locker(taskQueueMutex); 439 std::lock_guard<std::mutex> locker(taskQueueMutex);
407 440
408 - auto f = std::bind(&ThreadData::sendAllWils, this); 441 + auto f = std::bind(&ThreadData::sendAllDisconnects, this);
409 taskQueue.push_front(f); 442 taskQueue.push_front(f);
410 443
411 wakeUpThread(); 444 wakeUpThread();
threaddata.h
@@ -66,7 +66,8 @@ class ThreadData @@ -66,7 +66,8 @@ class ThreadData
66 void publishStat(const std::string &topic, uint64_t n); 66 void publishStat(const std::string &topic, uint64_t n);
67 void sendQueuedWills(); 67 void sendQueuedWills();
68 void removeExpiredSessions(); 68 void removeExpiredSessions();
69 - void sendAllWils(); 69 + void sendAllWills();
  70 + void sendAllDisconnects();
70 71
71 void removeQueuedClients(); 72 void removeQueuedClients();
72 73
@@ -75,7 +76,8 @@ public: @@ -75,7 +76,8 @@ public:
75 Authentication authentication; 76 Authentication authentication;
76 bool running = true; 77 bool running = true;
77 bool finished = false; 78 bool finished = false;
78 - bool allWilssSentForExit = false; 79 + bool allWillsQueued = false;
  80 + bool allDisconnectsSent = false;
79 std::thread thread; 81 std::thread thread;
80 int threadnr = 0; 82 int threadnr = 0;
81 int epollfd = 0; 83 int epollfd = 0;
@@ -120,7 +122,8 @@ public: @@ -120,7 +122,8 @@ public:
120 void queueAuthPluginPeriodicEvent(); 122 void queueAuthPluginPeriodicEvent();
121 void authPluginPeriodicEvent(); 123 void authPluginPeriodicEvent();
122 124
123 - void queueSendAllWills(); 125 + void queueSendWills();
  126 + void queueSendDisconnects();
124 }; 127 };
125 128
126 #endif // THREADDATA_H 129 #endif // THREADDATA_H
threadloop.cpp
@@ -144,6 +144,10 @@ void do_thread_work(ThreadData *threadData) @@ -144,6 +144,10 @@ void do_thread_work(ThreadData *threadData)
144 MqttPacket p(d); 144 MqttPacket p(d);
145 client->writeMqttPacket(p); 145 client->writeMqttPacket(p);
146 client->setReadyForDisconnect(); 146 client->setReadyForDisconnect();
  147 +
  148 + // When a client's TCP buffers are full (when the client is gone, for instance), EPOLLOUT will never be
  149 + // reported. In those cases, the client is not removed; not until the keep-alive mechanism anyway. Is
  150 + // that a problem?
147 } 151 }
148 else 152 else
149 { 153 {