Commit ec40e5b1572f04c526c958585e97e2a131787bb3

Authored by Wiebe Cazemier
1 parent e55fbfe4

Efficient client expiration checking

Check events are placed in a sorted map based on the last activity and
keep-alive interval of the client.

This makes it more accurate and reduces system load because it saves
unnecessary checking.
client.cpp
... ... @@ -345,7 +345,7 @@ bool Client::keepAliveExpired()
345 345 if (!authenticated)
346 346 return lastActivity + std::chrono::seconds(20) < now;
347 347  
348   - std::chrono::seconds x(keepalive*10/5);
  348 + std::chrono::seconds x(keepalive + keepalive/2);
349 349 bool result = (lastActivity + x) < now;
350 350 return result;
351 351 }
... ... @@ -648,6 +648,34 @@ void Client::setDisconnectReason(const std::string &amp;reason)
648 648 this->disconnectReason.append(reason);
649 649 }
650 650  
  651 +/**
  652 + * @brief Client::getSecondsTillKillTime gets the amount of seconds from now at which this client should be killed when it was quiet.
  653 + * @return
  654 + *
  655 + * "If the Keep Alive value is non-zero and the Server does not receive an MQTT Control Packet from the Client within one and a
  656 + * half times the Keep Alive time period, it MUST close the Network Connection to the Client as if the network had failed [MQTT-3.1.2-22].
  657 + */
  658 +std::chrono::seconds Client::getSecondsTillKillTime() const
  659 +{
  660 + if (!this->authenticated)
  661 + return std::chrono::seconds(30);
  662 +
  663 + if (this->keepalive == 0)
  664 + return std::chrono::seconds(0);
  665 +
  666 + const uint32_t timeOfSilenceMeansKill = this->keepalive + (this->keepalive / 2) + 2;
  667 + std::chrono::time_point<std::chrono::steady_clock> killTime = this->lastActivity + std::chrono::seconds(timeOfSilenceMeansKill);
  668 +
  669 + std::chrono::seconds secondsTillKillTime = std::chrono::duration_cast<std::chrono::seconds>(killTime - std::chrono::steady_clock::now());
  670 +
  671 + // We floor it, but also protect against the theoretically impossible negative value. Kill time shouldn't be in the past, because then we would
  672 + // have killed it already.
  673 + if (secondsTillKillTime < std::chrono::seconds(5))
  674 + return std::chrono::seconds(5);
  675 +
  676 + return secondsTillKillTime;
  677 +}
  678 +
651 679 void Client::clearWill()
652 680 {
653 681 willPublish.reset();
... ...
client.h
... ... @@ -146,6 +146,7 @@ public:
146 146 void assignSession(std::shared_ptr<Session> &session);
147 147 std::shared_ptr<Session> getSession();
148 148 void setDisconnectReason(const std::string &reason);
  149 + std::chrono::seconds getSecondsTillKillTime() const;
149 150  
150 151 void writeText(const std::string &text);
151 152 void writePingResp();
... ...
mainapp.cpp
... ... @@ -68,7 +68,7 @@ MainApp::MainApp(const std::string &amp;configFilePath) :
68 68 }
69 69  
70 70 auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this);
71   - timer.addCallback(fKeepAlive, 30000, "keep-alive check");
  71 + timer.addCallback(fKeepAlive, 5000, "keep-alive check");
72 72  
73 73 auto fPasswordFileReload = std::bind(&MainApp::queuePasswordFileReloadAllThreads, this);
74 74 timer.addCallback(fPasswordFileReload, 2000, "Password file reload.");
... ...
mqtt5properties.cpp
... ... @@ -44,6 +44,11 @@ std::shared_ptr&lt;std::vector&lt;std::pair&lt;std::string, std::string&gt;&gt;&gt; Mqtt5PropertyB
44 44 return this->userProperties;
45 45 }
46 46  
  47 +void Mqtt5PropertyBuilder::writeServerKeepAlive(uint16_t val)
  48 +{
  49 + writeUint16(Mqtt5Properties::ServerKeepAlive, val);
  50 +}
  51 +
47 52 void Mqtt5PropertyBuilder::writeSessionExpiry(uint32_t val)
48 53 {
49 54 writeUint32(Mqtt5Properties::SessionExpiryInterval, val, genericBytes);
... ...
mqtt5properties.h
... ... @@ -29,6 +29,7 @@ public:
29 29 void clearClientSpecificBytes();
30 30 std::shared_ptr<std::vector<std::pair<std::string, std::string>>> getUserProperties() const;
31 31  
  32 + void writeServerKeepAlive(uint16_t val);
32 33 void writeSessionExpiry(uint32_t val);
33 34 void writeReceiveMax(uint16_t val);
34 35 void writeRetainAvailable(uint8_t val);
... ...
mqttpacket.cpp
... ... @@ -397,6 +397,8 @@ void MqttPacket::handleConnect()
397 397  
398 398 if (protocolVersion == ProtocolVersion::Mqtt5)
399 399 {
  400 + keep_alive = std::max<uint16_t>(keep_alive, 5);
  401 +
400 402 const size_t proplen = decodeVariableByteIntAtPos();
401 403 const size_t prop_end_at = pos + proplen;
402 404  
... ... @@ -624,6 +626,7 @@ void MqttPacket::handleConnect()
624 626 connAck->propertyBuilder->writeWildcardSubscriptionAvailable(1);
625 627 connAck->propertyBuilder->writeSubscriptionIdentifiersAvailable(0);
626 628 connAck->propertyBuilder->writeSharedSubscriptionAvailable(0);
  629 + connAck->propertyBuilder->writeServerKeepAlive(keep_alive);
627 630  
628 631 if (!authenticationMethod.empty())
629 632 {
... ...
subscriptionstore.cpp
... ... @@ -226,6 +226,8 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr&lt;Client&gt;
226 226 // Removes an existing client when it already exists [MQTT-3.1.4-2].
227 227 void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval)
228 228 {
  229 + client->getThreadData()->queueClientNextKeepAliveCheckLocked(client, true);
  230 +
229 231 RWLockGuard lock_guard(&subscriptionsRwlock);
230 232 lock_guard.wrlock();
231 233  
... ...
threaddata.cpp
... ... @@ -20,6 +20,12 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
20 20 #include <sstream>
21 21 #include <cassert>
22 22  
  23 +KeepAliveCheck::KeepAliveCheck(const std::shared_ptr<Client> client) :
  24 + client(client)
  25 +{
  26 +
  27 +}
  28 +
23 29 ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, std::shared_ptr<Settings> settings) :
24 30 subscriptionStore(subscriptionStore),
25 31 settingsLocalCopy(*settings.get()),
... ... @@ -109,6 +115,26 @@ void ThreadData::queueRemoveExpiredSessions()
109 115 wakeUpThread();
110 116 }
111 117  
  118 +void ThreadData::queueClientNextKeepAliveCheck(std::shared_ptr<Client> &client, bool keepRechecking)
  119 +{
  120 + const std::chrono::seconds k = client->getSecondsTillKillTime();
  121 +
  122 + if (k == std::chrono::seconds(0))
  123 + return;
  124 +
  125 + const std::chrono::seconds when = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now().time_since_epoch() + k);
  126 +
  127 + KeepAliveCheck check(client);
  128 + check.recheck = keepRechecking;
  129 + queuedKeepAliveChecks[when].push_back(check);
  130 +}
  131 +
  132 +void ThreadData::queueClientNextKeepAliveCheckLocked(std::shared_ptr<Client> &client, bool keepRechecking)
  133 +{
  134 + std::lock_guard<std::mutex> locker(this->queuedKeepAliveMutex);
  135 + queueClientNextKeepAliveCheck(client, keepRechecking);
  136 +}
  137 +
112 138 void ThreadData::publishStatsOnDollarTopic(std::vector<std::shared_ptr<ThreadData>> &threads)
113 139 {
114 140 uint nrOfClients = 0;
... ... @@ -229,10 +255,14 @@ void ThreadData::removeQueuedClients()
229 255  
230 256 void ThreadData::giveClient(std::shared_ptr<Client> client)
231 257 {
232   - clients_by_fd_mutex.lock();
233   - int fd = client->getFd();
234   - clients_by_fd[fd] = client;
235   - clients_by_fd_mutex.unlock();
  258 + const int fd = client->getFd();
  259 +
  260 + {
  261 + std::lock_guard<std::mutex> locker(clients_by_fd_mutex);
  262 + clients_by_fd[fd] = client;
  263 + }
  264 +
  265 + queueClientNextKeepAliveCheckLocked(client, false);
236 266  
237 267 struct epoll_event ev;
238 268 memset(&ev, 0, sizeof (struct epoll_event));
... ... @@ -444,32 +474,77 @@ void ThreadData::queueSendDisconnects()
444 474 wakeUpThread();
445 475 }
446 476  
447   -// TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial?
448 477 void ThreadData::doKeepAliveCheck()
449 478 {
450   - // We don't need to stall normal connects and disconnects for keep-alive checking. We can do it later.
451   - std::unique_lock<std::mutex> lock(clients_by_fd_mutex, std::try_to_lock);
452   - if (!lock.owns_lock())
453   - return;
  479 + logger->logf(LOG_DEBUG, "doKeepAliveCheck in thread %d", threadnr);
454 480  
455   - logger->logf(LOG_DEBUG, "Doing keep-alive check in thread %d", threadnr);
  481 + const std::chrono::seconds now = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now().time_since_epoch());
456 482  
457 483 try
458 484 {
459   - auto it = clients_by_fd.begin();
460   - while (it != clients_by_fd.end())
  485 + // Put clients to delete in here, to avoid holding two locks.
  486 + std::vector<std::shared_ptr<Client>> clientsToRemove;
  487 +
  488 + std::vector<std::shared_ptr<Client>> clientsToRecheck;
  489 +
  490 + const int slotsTotal = this->queuedKeepAliveChecks.size();
  491 + int slotsProcessed = 0;
  492 + int clientsChecked = 0;
  493 +
461 494 {
462   - std::shared_ptr<Client> &client = it->second;
463   - if (client && client->keepAliveExpired())
  495 + logger->logf(LOG_DEBUG, "Checking clients with pending keep-alive checks in thread %d", threadnr);
  496 +
  497 + std::lock_guard<std::mutex> locker(this->queuedKeepAliveMutex);
  498 +
  499 + auto pos = this->queuedKeepAliveChecks.begin();
  500 + while (pos != this->queuedKeepAliveChecks.end())
464 501 {
465   - client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString());
466   - it = clients_by_fd.erase(it);
  502 + const std::chrono::seconds &doCheckAt = pos->first;
  503 +
  504 + if (doCheckAt > now)
  505 + break;
  506 +
  507 + slotsProcessed++;
  508 +
  509 + std::vector<KeepAliveCheck> &checks = pos->second;
  510 +
  511 + for (KeepAliveCheck &k : checks)
  512 + {
  513 + std::shared_ptr<Client> client = k.client.lock();
  514 + if (client)
  515 + {
  516 + clientsChecked++;
  517 +
  518 + if (client->keepAliveExpired())
  519 + {
  520 + clientsToRemove.push_back(client);
  521 + }
  522 + else if (k.recheck)
  523 + {
  524 + clientsToRecheck.push_back(client);
  525 + }
  526 + }
  527 + }
  528 +
  529 + pos = this->queuedKeepAliveChecks.erase(pos);
467 530 }
468   - else
  531 +
  532 + for (std::shared_ptr<Client> &c : clientsToRecheck)
  533 + {
  534 + c->resetBuffersIfEligible();
  535 + queueClientNextKeepAliveCheck(c, true);
  536 + }
  537 + }
  538 +
  539 + logger->logf(LOG_DEBUG, "Checked %d clients in %d of %d keep-alive slots in thread %d", clientsChecked, slotsProcessed, slotsTotal, threadnr);
  540 +
  541 + {
  542 + std::unique_lock<std::mutex> lock(clients_by_fd_mutex);
  543 +
  544 + for (std::shared_ptr<Client> c : clientsToRemove)
469 545 {
470   - if (client)
471   - client->resetBuffersIfEligible();
472   - it++;
  546 + c->setDisconnectReason("Keep-alive expired: " + c->getKeepAliveInfoString());
  547 + clients_by_fd.erase(c->getFd());
473 548 }
474 549 }
475 550 }
... ...
threaddata.h
... ... @@ -40,6 +40,14 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
40 40  
41 41 typedef void (*thread_f)(ThreadData *);
42 42  
  43 +struct KeepAliveCheck
  44 +{
  45 + std::weak_ptr<Client> client;
  46 + bool recheck = true;
  47 +
  48 + KeepAliveCheck(const std::shared_ptr<Client> client);
  49 +};
  50 +
43 51 class ThreadData
44 52 {
45 53 std::unordered_map<int, std::shared_ptr<Client>> clients_by_fd;
... ... @@ -58,6 +66,9 @@ class ThreadData
58 66 std::mutex clientsToRemoveMutex;
59 67 std::forward_list<std::weak_ptr<Client>> clientsQueuedForRemoving;
60 68  
  69 + std::mutex queuedKeepAliveMutex;
  70 + std::map<std::chrono::seconds, std::vector<KeepAliveCheck>> queuedKeepAliveChecks;
  71 +
61 72 void reload(std::shared_ptr<Settings> settings);
62 73 void wakeUpThread();
63 74 void doKeepAliveCheck();
... ... @@ -68,6 +79,7 @@ class ThreadData
68 79 void removeExpiredSessions();
69 80 void sendAllWills();
70 81 void sendAllDisconnects();
  82 + void queueClientNextKeepAliveCheck(std::shared_ptr<Client> &client, bool keepRechecking);
71 83  
72 84 void removeQueuedClients();
73 85  
... ... @@ -108,6 +120,7 @@ public:
108 120 void queuePublishStatsOnDollarTopic(std::vector<std::shared_ptr<ThreadData>> &threads);
109 121 void queueSendingQueuedWills();
110 122 void queueRemoveExpiredSessions();
  123 + void queueClientNextKeepAliveCheckLocked(std::shared_ptr<Client> &client, bool keepRechecking);
111 124  
112 125 int getNrOfClients() const;
113 126  
... ...