diff --git a/client.cpp b/client.cpp index 629d618..85e23d9 100644 --- a/client.cpp +++ b/client.cpp @@ -345,7 +345,7 @@ bool Client::keepAliveExpired() if (!authenticated) return lastActivity + std::chrono::seconds(20) < now; - std::chrono::seconds x(keepalive*10/5); + std::chrono::seconds x(keepalive + keepalive/2); bool result = (lastActivity + x) < now; return result; } @@ -648,6 +648,34 @@ void Client::setDisconnectReason(const std::string &reason) this->disconnectReason.append(reason); } +/** + * @brief Client::getSecondsTillKillTime gets the amount of seconds from now at which this client should be killed when it was quiet. + * @return + * + * "If the Keep Alive value is non-zero and the Server does not receive an MQTT Control Packet from the Client within one and a + * half times the Keep Alive time period, it MUST close the Network Connection to the Client as if the network had failed [MQTT-3.1.2-22]. + */ +std::chrono::seconds Client::getSecondsTillKillTime() const +{ + if (!this->authenticated) + return std::chrono::seconds(30); + + if (this->keepalive == 0) + return std::chrono::seconds(0); + + const uint32_t timeOfSilenceMeansKill = this->keepalive + (this->keepalive / 2) + 2; + std::chrono::time_point killTime = this->lastActivity + std::chrono::seconds(timeOfSilenceMeansKill); + + std::chrono::seconds secondsTillKillTime = std::chrono::duration_cast(killTime - std::chrono::steady_clock::now()); + + // We floor it, but also protect against the theoretically impossible negative value. Kill time shouldn't be in the past, because then we would + // have killed it already. + if (secondsTillKillTime < std::chrono::seconds(5)) + return std::chrono::seconds(5); + + return secondsTillKillTime; +} + void Client::clearWill() { willPublish.reset(); diff --git a/client.h b/client.h index 33eb457..d8d4011 100644 --- a/client.h +++ b/client.h @@ -146,6 +146,7 @@ public: void assignSession(std::shared_ptr &session); std::shared_ptr getSession(); void setDisconnectReason(const std::string &reason); + std::chrono::seconds getSecondsTillKillTime() const; void writeText(const std::string &text); void writePingResp(); diff --git a/mainapp.cpp b/mainapp.cpp index 1449c7d..802a135 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -68,7 +68,7 @@ MainApp::MainApp(const std::string &configFilePath) : } auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this); - timer.addCallback(fKeepAlive, 30000, "keep-alive check"); + timer.addCallback(fKeepAlive, 5000, "keep-alive check"); auto fPasswordFileReload = std::bind(&MainApp::queuePasswordFileReloadAllThreads, this); timer.addCallback(fPasswordFileReload, 2000, "Password file reload."); diff --git a/mqtt5properties.cpp b/mqtt5properties.cpp index fe4165a..7d31ea0 100644 --- a/mqtt5properties.cpp +++ b/mqtt5properties.cpp @@ -44,6 +44,11 @@ std::shared_ptr>> Mqtt5PropertyB return this->userProperties; } +void Mqtt5PropertyBuilder::writeServerKeepAlive(uint16_t val) +{ + writeUint16(Mqtt5Properties::ServerKeepAlive, val); +} + void Mqtt5PropertyBuilder::writeSessionExpiry(uint32_t val) { writeUint32(Mqtt5Properties::SessionExpiryInterval, val, genericBytes); diff --git a/mqtt5properties.h b/mqtt5properties.h index 6e324c8..689c3cc 100644 --- a/mqtt5properties.h +++ b/mqtt5properties.h @@ -29,6 +29,7 @@ public: void clearClientSpecificBytes(); std::shared_ptr>> getUserProperties() const; + void writeServerKeepAlive(uint16_t val); void writeSessionExpiry(uint32_t val); void writeReceiveMax(uint16_t val); void writeRetainAvailable(uint8_t val); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 73a602e..a945133 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -397,6 +397,8 @@ void MqttPacket::handleConnect() if (protocolVersion == ProtocolVersion::Mqtt5) { + keep_alive = std::max(keep_alive, 5); + const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; @@ -624,6 +626,7 @@ void MqttPacket::handleConnect() connAck->propertyBuilder->writeWildcardSubscriptionAvailable(1); connAck->propertyBuilder->writeSubscriptionIdentifiersAvailable(0); connAck->propertyBuilder->writeSharedSubscriptionAvailable(0); + connAck->propertyBuilder->writeServerKeepAlive(keep_alive); if (!authenticationMethod.empty()) { diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 5da087c..65655ab 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -226,6 +226,8 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr // Removes an existing client when it already exists [MQTT-3.1.4-2]. void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval) { + client->getThreadData()->queueClientNextKeepAliveCheckLocked(client, true); + RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.wrlock(); diff --git a/threaddata.cpp b/threaddata.cpp index 519b8f9..b021d59 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -20,6 +20,12 @@ License along with FlashMQ. If not, see . #include #include +KeepAliveCheck::KeepAliveCheck(const std::shared_ptr client) : + client(client) +{ + +} + ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, std::shared_ptr settings) : subscriptionStore(subscriptionStore), settingsLocalCopy(*settings.get()), @@ -109,6 +115,26 @@ void ThreadData::queueRemoveExpiredSessions() wakeUpThread(); } +void ThreadData::queueClientNextKeepAliveCheck(std::shared_ptr &client, bool keepRechecking) +{ + const std::chrono::seconds k = client->getSecondsTillKillTime(); + + if (k == std::chrono::seconds(0)) + return; + + const std::chrono::seconds when = std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch() + k); + + KeepAliveCheck check(client); + check.recheck = keepRechecking; + queuedKeepAliveChecks[when].push_back(check); +} + +void ThreadData::queueClientNextKeepAliveCheckLocked(std::shared_ptr &client, bool keepRechecking) +{ + std::lock_guard locker(this->queuedKeepAliveMutex); + queueClientNextKeepAliveCheck(client, keepRechecking); +} + void ThreadData::publishStatsOnDollarTopic(std::vector> &threads) { uint nrOfClients = 0; @@ -229,10 +255,14 @@ void ThreadData::removeQueuedClients() void ThreadData::giveClient(std::shared_ptr client) { - clients_by_fd_mutex.lock(); - int fd = client->getFd(); - clients_by_fd[fd] = client; - clients_by_fd_mutex.unlock(); + const int fd = client->getFd(); + + { + std::lock_guard locker(clients_by_fd_mutex); + clients_by_fd[fd] = client; + } + + queueClientNextKeepAliveCheckLocked(client, false); struct epoll_event ev; memset(&ev, 0, sizeof (struct epoll_event)); @@ -444,32 +474,77 @@ void ThreadData::queueSendDisconnects() wakeUpThread(); } -// TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? void ThreadData::doKeepAliveCheck() { - // We don't need to stall normal connects and disconnects for keep-alive checking. We can do it later. - std::unique_lock lock(clients_by_fd_mutex, std::try_to_lock); - if (!lock.owns_lock()) - return; + logger->logf(LOG_DEBUG, "doKeepAliveCheck in thread %d", threadnr); - logger->logf(LOG_DEBUG, "Doing keep-alive check in thread %d", threadnr); + const std::chrono::seconds now = std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()); try { - auto it = clients_by_fd.begin(); - while (it != clients_by_fd.end()) + // Put clients to delete in here, to avoid holding two locks. + std::vector> clientsToRemove; + + std::vector> clientsToRecheck; + + const int slotsTotal = this->queuedKeepAliveChecks.size(); + int slotsProcessed = 0; + int clientsChecked = 0; + { - std::shared_ptr &client = it->second; - if (client && client->keepAliveExpired()) + logger->logf(LOG_DEBUG, "Checking clients with pending keep-alive checks in thread %d", threadnr); + + std::lock_guard locker(this->queuedKeepAliveMutex); + + auto pos = this->queuedKeepAliveChecks.begin(); + while (pos != this->queuedKeepAliveChecks.end()) { - client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString()); - it = clients_by_fd.erase(it); + const std::chrono::seconds &doCheckAt = pos->first; + + if (doCheckAt > now) + break; + + slotsProcessed++; + + std::vector &checks = pos->second; + + for (KeepAliveCheck &k : checks) + { + std::shared_ptr client = k.client.lock(); + if (client) + { + clientsChecked++; + + if (client->keepAliveExpired()) + { + clientsToRemove.push_back(client); + } + else if (k.recheck) + { + clientsToRecheck.push_back(client); + } + } + } + + pos = this->queuedKeepAliveChecks.erase(pos); } - else + + for (std::shared_ptr &c : clientsToRecheck) + { + c->resetBuffersIfEligible(); + queueClientNextKeepAliveCheck(c, true); + } + } + + logger->logf(LOG_DEBUG, "Checked %d clients in %d of %d keep-alive slots in thread %d", clientsChecked, slotsProcessed, slotsTotal, threadnr); + + { + std::unique_lock lock(clients_by_fd_mutex); + + for (std::shared_ptr c : clientsToRemove) { - if (client) - client->resetBuffersIfEligible(); - it++; + c->setDisconnectReason("Keep-alive expired: " + c->getKeepAliveInfoString()); + clients_by_fd.erase(c->getFd()); } } } diff --git a/threaddata.h b/threaddata.h index ef7cb0f..2c7a358 100644 --- a/threaddata.h +++ b/threaddata.h @@ -40,6 +40,14 @@ License along with FlashMQ. If not, see . typedef void (*thread_f)(ThreadData *); +struct KeepAliveCheck +{ + std::weak_ptr client; + bool recheck = true; + + KeepAliveCheck(const std::shared_ptr client); +}; + class ThreadData { std::unordered_map> clients_by_fd; @@ -58,6 +66,9 @@ class ThreadData std::mutex clientsToRemoveMutex; std::forward_list> clientsQueuedForRemoving; + std::mutex queuedKeepAliveMutex; + std::map> queuedKeepAliveChecks; + void reload(std::shared_ptr settings); void wakeUpThread(); void doKeepAliveCheck(); @@ -68,6 +79,7 @@ class ThreadData void removeExpiredSessions(); void sendAllWills(); void sendAllDisconnects(); + void queueClientNextKeepAliveCheck(std::shared_ptr &client, bool keepRechecking); void removeQueuedClients(); @@ -108,6 +120,7 @@ public: void queuePublishStatsOnDollarTopic(std::vector> &threads); void queueSendingQueuedWills(); void queueRemoveExpiredSessions(); + void queueClientNextKeepAliveCheckLocked(std::shared_ptr &client, bool keepRechecking); int getNrOfClients() const;