From 4cd9300cf9a76afcfcc421ef24ce3cde9a5be27d Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sun, 20 Mar 2022 16:32:57 +0100 Subject: [PATCH] The new will structure, with delays, works --- client.cpp | 4 +--- client.h | 1 + mainapp.cpp | 39 +++++++++++++++++++++++++++++++++++---- mainapp.h | 3 +++ mqttpacket.cpp | 26 +++++++++++++++++++++++++- session.cpp | 16 +++++++++++++++- session.h | 3 +++ subscriptionstore.cpp | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++-------------- subscriptionstore.h | 4 ++++ threaddata.cpp | 30 ++++++++++++++++++++++++++++++ threaddata.h | 4 ++++ timer.cpp | 2 +- types.cpp | 10 ++++++++++ types.h | 2 ++ 14 files changed, 186 insertions(+), 24 deletions(-) diff --git a/client.cpp b/client.cpp index 4cfe57b..ec4ab55 100644 --- a/client.cpp +++ b/client.cpp @@ -436,7 +436,6 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str void Client::setWill(Publish &&willPublish) { this->willPublish = std::make_shared(std::move(willPublish)); - // TODO: also session. Or only the session? } void Client::assignSession(std::shared_ptr &session) @@ -459,7 +458,6 @@ void Client::setDisconnectReason(const std::string &reason) void Client::clearWill() { willPublish.reset(); - // TODO: the session too? I still need to make that 'send will when session ends' thing. - + session->clearWill(); } diff --git a/client.h b/client.h index 04e9ef6..fc397f7 100644 --- a/client.h +++ b/client.h @@ -116,6 +116,7 @@ public: std::shared_ptr getThreadData() { return threadData; } std::string &getClientId() { return this->clientid; } const std::string &getUsername() const { return this->username; } + std::shared_ptr &getWill() { return this->willPublish; } void assignSession(std::shared_ptr &session); std::shared_ptr getSession(); void setDisconnectReason(const std::string &reason); diff --git a/mainapp.cpp b/mainapp.cpp index c3d44cd..932a5f4 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -62,9 +62,9 @@ MainApp::MainApp(const std::string &configFilePath) : if (settings->expireSessionsAfterSeconds > 0) { auto f = std::bind(&MainApp::queueCleanup, this); - const uint64_t derrivedSessionCheckInterval = std::max((settings->expireSessionsAfterSeconds)*1000*2, 600000); - const uint64_t sessionCheckInterval = std::min(derrivedSessionCheckInterval, 86400000); - timer.addCallback(f, sessionCheckInterval, "session expiration"); + //const uint64_t derrivedSessionCheckInterval = std::max((settings->expireSessionsAfterSeconds)*1000*2, 600000); + //const uint64_t sessionCheckInterval = std::min(derrivedSessionCheckInterval, 86400000); + timer.addCallback(f, 10000, "session expiration"); } auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this); @@ -90,6 +90,9 @@ MainApp::MainApp(const std::string &configFilePath) : auto fSaveState = std::bind(&MainApp::saveStateInThread, this); timer.addCallback(fSaveState, 900000, "Save state."); + + auto fSendPendingWills = std::bind(&MainApp::queueSendQueuedWills, this); + timer.addCallback(fSendPendingWills, 2000, "Publish pending wills."); } MainApp::~MainApp() @@ -254,6 +257,34 @@ void MainApp::saveStateInThread() pthread_setname_np(native, "SaveState"); } +void MainApp::queueSendQueuedWills() +{ + std::lock_guard locker(eventMutex); + + if (!threads.empty()) + { + std::shared_ptr t = threads[nextThreadForTasks++ % threads.size()]; + auto f = std::bind(&ThreadData::queueSendingQueuedWills, t.get()); + taskQueue.push_front(f); + + wakeUpThread(); + } +} + +void MainApp::queueRemoveExpiredSessions() +{ + std::lock_guard locker(eventMutex); + + if (!threads.empty()) + { + std::shared_ptr t = threads[nextThreadForTasks++ % threads.size()]; + auto f = std::bind(&ThreadData::queueRemoveExpiredSessions, t.get()); + taskQueue.push_front(f); + + wakeUpThread(); + } +} + void MainApp::saveState() { std::lock_guard lg(saveStateMutex); @@ -713,7 +744,7 @@ void MainApp::queueCleanup() { std::lock_guard locker(eventMutex); - auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get()); + auto f = std::bind(&MainApp::queueRemoveExpiredSessions, this); taskQueue.push_front(f); wakeUpThread(); diff --git a/mainapp.h b/mainapp.h index 38c6c6b..71809ba 100644 --- a/mainapp.h +++ b/mainapp.h @@ -61,6 +61,7 @@ class MainApp int taskEventFd = -1; std::mutex eventMutex; Timer timer; + uint16_t nextThreadForTasks = 0; // We need to keep a settings copy as well as a shared pointer, depending on threads, queueing of config reloads, etc. std::shared_ptr settings; @@ -90,6 +91,8 @@ class MainApp void queuePublishStatsOnDollarTopic(); void saveState(); void saveStateInThread(); + void queueSendQueuedWills(); + void queueRemoveExpiredSessions(); MainApp(const std::string &configFilePath); public: diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 18d2f03..635dd5f 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -401,6 +401,7 @@ void MqttPacket::handleConnect() { case Mqtt5Properties::WillDelayInterval: willpublish.will_delay = readFourBytesToUint32(); + willpublish.createdAt = std::chrono::steady_clock::now(); break; case Mqtt5Properties::PayloadFormatIndicator: willpublish.propertyBuilder->writePayloadFormatIndicator(readByte()); @@ -504,7 +505,9 @@ void MqttPacket::handleConnect() } sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, max_packet_size, max_topic_aliases); - sender->setWill(std::move(willpublish)); + + if (will_flag) + sender->setWill(std::move(willpublish)); bool accessGranted = false; std::string denyLogMsg; @@ -596,6 +599,27 @@ void MqttPacket::handleSubscribe() throw ProtocolError("Packet ID 0 when subscribing is invalid."); // [MQTT-2.3.1-1] } + if (protocolVersion == ProtocolVersion::Mqtt5) + { + const size_t proplen = decodeVariableByteIntAtPos(); + const size_t prop_end_at = pos + proplen; + + while (pos < prop_end_at) + { + const Mqtt5Properties prop = static_cast(readByte()); + + switch (prop) + { + case Mqtt5Properties::SubscriptionIdentifier: + break; + case Mqtt5Properties::UserProperty: + break; + default: + throw ProtocolError("Invalid subscribe property."); + } + } + } + Authentication &authentication = *ThreadGlobals::getAuth(); std::list subs_reponse_codes; diff --git a/session.cpp b/session.cpp index 21d9821..9e88783 100644 --- a/session.cpp +++ b/session.cpp @@ -138,6 +138,7 @@ void Session::assignActiveConnection(std::shared_ptr &client) this->client = client; this->client_id = client->getClientId(); this->username = client->getUsername(); + this->willPublish = client->getWill(); } /** @@ -292,9 +293,22 @@ void Session::touch() bool Session::hasExpired() const { + if (!client.expired()) + return false; + std::chrono::seconds expireAfter(sessionExpiryInterval); std::chrono::time_point now = std::chrono::steady_clock::now(); - return client.expired() && (lastTouched + expireAfter) < now; + return (lastTouched + expireAfter) < now; +} + +void Session::clearWill() +{ + this->willPublish.reset(); +} + +std::shared_ptr &Session::getWill() +{ + return this->willPublish; } void Session::addIncomingQoS2MessageId(uint16_t packet_id) diff --git a/session.h b/session.h index 264116e..f44487b 100644 --- a/session.h +++ b/session.h @@ -51,6 +51,7 @@ class Session uint16_t QoSLogPrintedAtId = 0; bool destroyOnDisconnect = false; std::chrono::time_point lastTouched = std::chrono::steady_clock::now(); + std::shared_ptr willPublish; Logger *logger = Logger::getInstance(); int64_t getSessionRelativeAgeInMs() const; @@ -79,6 +80,8 @@ public: void touch(std::chrono::time_point val); void touch(); bool hasExpired() const; + void clearWill(); + std::shared_ptr &getWill(); void addIncomingQoS2MessageId(uint16_t packet_id); bool incomingQoS2MessageIdInTransit(uint16_t packet_id); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index b459165..21fd2b5 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -269,15 +269,50 @@ bool SubscriptionStore::sessionPresent(const std::string &clientid) return result; } -void SubscriptionStore::sendQueuedWillMessages() +/** + * @brief SubscriptionStore::purgeEmptyWills doesn't lock a mutex, because it's a helper for elsewhere. + */ +void SubscriptionStore::purgeEmptyWills() { - // TODO: walk the list + auto it = pendingWillMessages.begin(); + while (it != pendingWillMessages.end()) + { + std::shared_ptr p = (*it).lock(); + if (!p) + { + it = pendingWillMessages.erase(it); + } + } +} +void SubscriptionStore::sendQueuedWillMessages() +{ std::lock_guard(this->pendingWillsMutex); + + auto it = pendingWillMessages.begin(); + while (it != pendingWillMessages.end()) + { + std::shared_ptr p = (*it).lock(); + if (p) + { + if (p->createdAt + std::chrono::seconds(p->will_delay) > std::chrono::steady_clock::now()) + break; + + logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() ); + PublishCopyFactory factory(p.get()); + queuePacketAtSubscribers(factory); + } + it = pendingWillMessages.erase(it); + } } void SubscriptionStore::queueWillMessage(std::shared_ptr &willMessage) { + if (!willMessage) + return; + + logger->logf(LOG_DEBUG, "Queueing will on topic '%s', with delay %d seconds.", willMessage->topic.c_str(), willMessage->will_delay ); + if (willMessage->will_delay == 0) { PublishCopyFactory factory(willMessage.get()); @@ -285,15 +320,9 @@ void SubscriptionStore::queueWillMessage(std::shared_ptr &willMessage) return; } - /* TODO - auto delay_compare = [](std::weak_ptr &a, std::shared_ptr &b) - { - return true; - }; std::lock_guard(this->pendingWillsMutex); - auto pos = std::upper_bound(this->pendingWillMessages.begin(), this->pendingWillMessages.end(), willMessage, delay_compare); + auto pos = std::upper_bound(this->pendingWillMessages.begin(), this->pendingWillMessages.end(), willMessage, WillDelayCompare); this->pendingWillMessages.insert(pos, willMessage); - */ } void SubscriptionStore::publishNonRecursively(const std::unordered_map &subscribers, @@ -569,11 +598,11 @@ void SubscriptionStore::removeSession(const std::string &clientid) */ void SubscriptionStore::removeExpiredSessionsClients() { + logger->logf(LOG_DEBUG, "Cleaning out old sessions"); + RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.wrlock(); - logger->logf(LOG_NOTICE, "Cleaning out old sessions"); - auto session_it = sessionsById.begin(); while (session_it != sessionsById.end()) { @@ -582,15 +611,24 @@ void SubscriptionStore::removeExpiredSessionsClients() if (session->hasExpired()) { logger->logf(LOG_DEBUG, "Removing expired session from store %s", session->getClientId().c_str()); + std::shared_ptr &will = session->getWill(); + if (will) + { + will->will_delay = 0; + queueWillMessage(will); + } session_it = sessionsById.erase(session_it); } else session_it++; } - logger->logf(LOG_NOTICE, "Rebuilding subscription tree"); - - root.cleanSubscriptions(); + if (lastTreeCleanup + std::chrono::minutes(30) < std::chrono::steady_clock::now()) + { + logger->logf(LOG_NOTICE, "Rebuilding subscription tree"); + root.cleanSubscriptions(); + lastTreeCleanup = std::chrono::steady_clock::now(); + } } int64_t SubscriptionStore::getRetainedMessageCount() const diff --git a/subscriptionstore.h b/subscriptionstore.h index f28e022..187cd25 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -104,6 +104,8 @@ class SubscriptionStore std::mutex pendingWillsMutex; std::list> pendingWillMessages; + std::chrono::time_point lastTreeCleanup; + Logger *logger = Logger::getInstance(); void publishNonRecursively(const std::unordered_map &subscribers, @@ -119,6 +121,8 @@ class SubscriptionStore void countSubscriptions(SubscriptionNode *this_node, int64_t &count) const; SubscriptionNode *getDeepestNode(const std::string &topic, const std::vector &subtopics); + + void purgeEmptyWills(); public: SubscriptionStore(); diff --git a/threaddata.cpp b/threaddata.cpp index 95b93ea..d3ad1d8 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -89,6 +89,26 @@ void ThreadData::queuePublishStatsOnDollarTopic(std::vector locker(taskQueueMutex); + + auto f = std::bind(&ThreadData::sendQueuedWills, this); + taskQueue.push_front(f); + + wakeUpThread(); +} + +void ThreadData::queueRemoveExpiredSessions() +{ + std::lock_guard locker(taskQueueMutex); + + auto f = std::bind(&ThreadData::removeExpiredSessions, this); + taskQueue.push_front(f); + + wakeUpThread(); +} + void ThreadData::publishStatsOnDollarTopic(std::vector> &threads) { uint nrOfClients = 0; @@ -132,6 +152,16 @@ void ThreadData::publishStat(const std::string &topic, uint64_t n) subscriptionStore->setRetainedMessage(topic, factory.getSubtopics(), payload, 0); } +void ThreadData::sendQueuedWills() +{ + subscriptionStore->sendQueuedWillMessages(); +} + +void ThreadData::removeExpiredSessions() +{ + subscriptionStore->removeExpiredSessionsClients(); +} + void ThreadData::removeQueuedClients() { std::vector fds; diff --git a/threaddata.h b/threaddata.h index f136fe8..35ecdb7 100644 --- a/threaddata.h +++ b/threaddata.h @@ -64,6 +64,8 @@ class ThreadData void quit(); void publishStatsOnDollarTopic(std::vector> &threads); void publishStat(const std::string &topic, uint64_t n); + void sendQueuedWills(); + void removeExpiredSessions(); void removeQueuedClients(); @@ -100,6 +102,8 @@ public: void waitForQuit(); void queuePasswdFileReload(); void queuePublishStatsOnDollarTopic(std::vector> &threads); + void queueSendingQueuedWills(); + void queueRemoveExpiredSessions(); int getNrOfClients() const; diff --git a/timer.cpp b/timer.cpp index 8c45338..a253b5d 100644 --- a/timer.cpp +++ b/timer.cpp @@ -104,7 +104,7 @@ void Timer::process() while (running) { - logger->logf(LOG_DEBUG, "Timer sleeping for %d ms until event '%s' or callbacks are added.", sleeptime, callbacks.front().name.c_str()); + //logger->logf(LOG_DEBUG, "Timer sleeping for %d ms until event '%s' or callbacks are added.", sleeptime, callbacks.front().name.c_str()); int num_fds = epoll_wait(this->epollfd, events, MAX_TIMER_EVENTS, sleeptime); if (!running) diff --git a/types.cpp b/types.cpp index 3ee73bc..2fe1fc0 100644 --- a/types.cpp +++ b/types.cpp @@ -152,6 +152,16 @@ Publish::Publish(const std::string &topic, const std::string &payload, char qos) } +bool WillDelayCompare(const std::shared_ptr &a, const std::weak_ptr &b) +{ + std::shared_ptr _b = b.lock(); + + if (!_b) + return true; + + return a->will_delay < _b->will_delay; +}; + PubAck::PubAck(uint16_t packet_id) : packet_id(packet_id) { diff --git a/types.h b/types.h index 595472e..b690859 100644 --- a/types.h +++ b/types.h @@ -222,6 +222,8 @@ public: Publish(const std::string &topic, const std::string &payload, char qos); }; +bool WillDelayCompare(const std::shared_ptr &a, const std::weak_ptr &b); + class PubAck { public: -- libgit2 0.21.4