Commit 4cd9300cf9a76afcfcc421ef24ce3cde9a5be27d

Authored by Wiebe Cazemier
1 parent c32f063d

The new will structure, with delays, works

client.cpp
... ... @@ -436,7 +436,6 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str
436 436 void Client::setWill(Publish &&willPublish)
437 437 {
438 438 this->willPublish = std::make_shared<Publish>(std::move(willPublish));
439   - // TODO: also session. Or only the session?
440 439 }
441 440  
442 441 void Client::assignSession(std::shared_ptr<Session> &session)
... ... @@ -459,7 +458,6 @@ void Client::setDisconnectReason(const std::string &amp;reason)
459 458 void Client::clearWill()
460 459 {
461 460 willPublish.reset();
462   - // TODO: the session too? I still need to make that 'send will when session ends' thing.
463   -
  461 + session->clearWill();
464 462 }
465 463  
... ...
client.h
... ... @@ -116,6 +116,7 @@ public:
116 116 std::shared_ptr<ThreadData> getThreadData() { return threadData; }
117 117 std::string &getClientId() { return this->clientid; }
118 118 const std::string &getUsername() const { return this->username; }
  119 + std::shared_ptr<Publish> &getWill() { return this->willPublish; }
119 120 void assignSession(std::shared_ptr<Session> &session);
120 121 std::shared_ptr<Session> getSession();
121 122 void setDisconnectReason(const std::string &reason);
... ...
mainapp.cpp
... ... @@ -62,9 +62,9 @@ MainApp::MainApp(const std::string &amp;configFilePath) :
62 62 if (settings->expireSessionsAfterSeconds > 0)
63 63 {
64 64 auto f = std::bind(&MainApp::queueCleanup, this);
65   - const uint64_t derrivedSessionCheckInterval = std::max<uint64_t>((settings->expireSessionsAfterSeconds)*1000*2, 600000);
66   - const uint64_t sessionCheckInterval = std::min<uint64_t>(derrivedSessionCheckInterval, 86400000);
67   - timer.addCallback(f, sessionCheckInterval, "session expiration");
  65 + //const uint64_t derrivedSessionCheckInterval = std::max<uint64_t>((settings->expireSessionsAfterSeconds)*1000*2, 600000);
  66 + //const uint64_t sessionCheckInterval = std::min<uint64_t>(derrivedSessionCheckInterval, 86400000);
  67 + timer.addCallback(f, 10000, "session expiration");
68 68 }
69 69  
70 70 auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this);
... ... @@ -90,6 +90,9 @@ MainApp::MainApp(const std::string &amp;configFilePath) :
90 90  
91 91 auto fSaveState = std::bind(&MainApp::saveStateInThread, this);
92 92 timer.addCallback(fSaveState, 900000, "Save state.");
  93 +
  94 + auto fSendPendingWills = std::bind(&MainApp::queueSendQueuedWills, this);
  95 + timer.addCallback(fSendPendingWills, 2000, "Publish pending wills.");
93 96 }
94 97  
95 98 MainApp::~MainApp()
... ... @@ -254,6 +257,34 @@ void MainApp::saveStateInThread()
254 257 pthread_setname_np(native, "SaveState");
255 258 }
256 259  
  260 +void MainApp::queueSendQueuedWills()
  261 +{
  262 + std::lock_guard<std::mutex> locker(eventMutex);
  263 +
  264 + if (!threads.empty())
  265 + {
  266 + std::shared_ptr<ThreadData> t = threads[nextThreadForTasks++ % threads.size()];
  267 + auto f = std::bind(&ThreadData::queueSendingQueuedWills, t.get());
  268 + taskQueue.push_front(f);
  269 +
  270 + wakeUpThread();
  271 + }
  272 +}
  273 +
  274 +void MainApp::queueRemoveExpiredSessions()
  275 +{
  276 + std::lock_guard<std::mutex> locker(eventMutex);
  277 +
  278 + if (!threads.empty())
  279 + {
  280 + std::shared_ptr<ThreadData> t = threads[nextThreadForTasks++ % threads.size()];
  281 + auto f = std::bind(&ThreadData::queueRemoveExpiredSessions, t.get());
  282 + taskQueue.push_front(f);
  283 +
  284 + wakeUpThread();
  285 + }
  286 +}
  287 +
257 288 void MainApp::saveState()
258 289 {
259 290 std::lock_guard<std::mutex> lg(saveStateMutex);
... ... @@ -713,7 +744,7 @@ void MainApp::queueCleanup()
713 744 {
714 745 std::lock_guard<std::mutex> locker(eventMutex);
715 746  
716   - auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get());
  747 + auto f = std::bind(&MainApp::queueRemoveExpiredSessions, this);
717 748 taskQueue.push_front(f);
718 749  
719 750 wakeUpThread();
... ...
mainapp.h
... ... @@ -61,6 +61,7 @@ class MainApp
61 61 int taskEventFd = -1;
62 62 std::mutex eventMutex;
63 63 Timer timer;
  64 + uint16_t nextThreadForTasks = 0;
64 65  
65 66 // We need to keep a settings copy as well as a shared pointer, depending on threads, queueing of config reloads, etc.
66 67 std::shared_ptr<Settings> settings;
... ... @@ -90,6 +91,8 @@ class MainApp
90 91 void queuePublishStatsOnDollarTopic();
91 92 void saveState();
92 93 void saveStateInThread();
  94 + void queueSendQueuedWills();
  95 + void queueRemoveExpiredSessions();
93 96  
94 97 MainApp(const std::string &configFilePath);
95 98 public:
... ...
mqttpacket.cpp
... ... @@ -401,6 +401,7 @@ void MqttPacket::handleConnect()
401 401 {
402 402 case Mqtt5Properties::WillDelayInterval:
403 403 willpublish.will_delay = readFourBytesToUint32();
  404 + willpublish.createdAt = std::chrono::steady_clock::now();
404 405 break;
405 406 case Mqtt5Properties::PayloadFormatIndicator:
406 407 willpublish.propertyBuilder->writePayloadFormatIndicator(readByte());
... ... @@ -504,7 +505,9 @@ void MqttPacket::handleConnect()
504 505 }
505 506  
506 507 sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, max_packet_size, max_topic_aliases);
507   - sender->setWill(std::move(willpublish));
  508 +
  509 + if (will_flag)
  510 + sender->setWill(std::move(willpublish));
508 511  
509 512 bool accessGranted = false;
510 513 std::string denyLogMsg;
... ... @@ -596,6 +599,27 @@ void MqttPacket::handleSubscribe()
596 599 throw ProtocolError("Packet ID 0 when subscribing is invalid."); // [MQTT-2.3.1-1]
597 600 }
598 601  
  602 + if (protocolVersion == ProtocolVersion::Mqtt5)
  603 + {
  604 + const size_t proplen = decodeVariableByteIntAtPos();
  605 + const size_t prop_end_at = pos + proplen;
  606 +
  607 + while (pos < prop_end_at)
  608 + {
  609 + const Mqtt5Properties prop = static_cast<Mqtt5Properties>(readByte());
  610 +
  611 + switch (prop)
  612 + {
  613 + case Mqtt5Properties::SubscriptionIdentifier:
  614 + break;
  615 + case Mqtt5Properties::UserProperty:
  616 + break;
  617 + default:
  618 + throw ProtocolError("Invalid subscribe property.");
  619 + }
  620 + }
  621 + }
  622 +
599 623 Authentication &authentication = *ThreadGlobals::getAuth();
600 624  
601 625 std::list<char> subs_reponse_codes;
... ...
session.cpp
... ... @@ -138,6 +138,7 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
138 138 this->client = client;
139 139 this->client_id = client->getClientId();
140 140 this->username = client->getUsername();
  141 + this->willPublish = client->getWill();
141 142 }
142 143  
143 144 /**
... ... @@ -292,9 +293,22 @@ void Session::touch()
292 293  
293 294 bool Session::hasExpired() const
294 295 {
  296 + if (!client.expired())
  297 + return false;
  298 +
295 299 std::chrono::seconds expireAfter(sessionExpiryInterval);
296 300 std::chrono::time_point<std::chrono::steady_clock> now = std::chrono::steady_clock::now();
297   - return client.expired() && (lastTouched + expireAfter) < now;
  301 + return (lastTouched + expireAfter) < now;
  302 +}
  303 +
  304 +void Session::clearWill()
  305 +{
  306 + this->willPublish.reset();
  307 +}
  308 +
  309 +std::shared_ptr<Publish> &Session::getWill()
  310 +{
  311 + return this->willPublish;
298 312 }
299 313  
300 314 void Session::addIncomingQoS2MessageId(uint16_t packet_id)
... ...
session.h
... ... @@ -51,6 +51,7 @@ class Session
51 51 uint16_t QoSLogPrintedAtId = 0;
52 52 bool destroyOnDisconnect = false;
53 53 std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now();
  54 + std::shared_ptr<Publish> willPublish;
54 55 Logger *logger = Logger::getInstance();
55 56  
56 57 int64_t getSessionRelativeAgeInMs() const;
... ... @@ -79,6 +80,8 @@ public:
79 80 void touch(std::chrono::time_point<std::chrono::steady_clock> val);
80 81 void touch();
81 82 bool hasExpired() const;
  83 + void clearWill();
  84 + std::shared_ptr<Publish> &getWill();
82 85  
83 86 void addIncomingQoS2MessageId(uint16_t packet_id);
84 87 bool incomingQoS2MessageIdInTransit(uint16_t packet_id);
... ...
subscriptionstore.cpp
... ... @@ -269,15 +269,50 @@ bool SubscriptionStore::sessionPresent(const std::string &amp;clientid)
269 269 return result;
270 270 }
271 271  
272   -void SubscriptionStore::sendQueuedWillMessages()
  272 +/**
  273 + * @brief SubscriptionStore::purgeEmptyWills doesn't lock a mutex, because it's a helper for elsewhere.
  274 + */
  275 +void SubscriptionStore::purgeEmptyWills()
273 276 {
274   - // TODO: walk the list
  277 + auto it = pendingWillMessages.begin();
  278 + while (it != pendingWillMessages.end())
  279 + {
  280 + std::shared_ptr<Publish> p = (*it).lock();
  281 + if (!p)
  282 + {
  283 + it = pendingWillMessages.erase(it);
  284 + }
  285 + }
  286 +}
275 287  
  288 +void SubscriptionStore::sendQueuedWillMessages()
  289 +{
276 290 std::lock_guard<std::mutex>(this->pendingWillsMutex);
  291 +
  292 + auto it = pendingWillMessages.begin();
  293 + while (it != pendingWillMessages.end())
  294 + {
  295 + std::shared_ptr<Publish> p = (*it).lock();
  296 + if (p)
  297 + {
  298 + if (p->createdAt + std::chrono::seconds(p->will_delay) > std::chrono::steady_clock::now())
  299 + break;
  300 +
  301 + logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() );
  302 + PublishCopyFactory factory(p.get());
  303 + queuePacketAtSubscribers(factory);
  304 + }
  305 + it = pendingWillMessages.erase(it);
  306 + }
277 307 }
278 308  
279 309 void SubscriptionStore::queueWillMessage(std::shared_ptr<Publish> &willMessage)
280 310 {
  311 + if (!willMessage)
  312 + return;
  313 +
  314 + logger->logf(LOG_DEBUG, "Queueing will on topic '%s', with delay %d seconds.", willMessage->topic.c_str(), willMessage->will_delay );
  315 +
281 316 if (willMessage->will_delay == 0)
282 317 {
283 318 PublishCopyFactory factory(willMessage.get());
... ... @@ -285,15 +320,9 @@ void SubscriptionStore::queueWillMessage(std::shared_ptr&lt;Publish&gt; &amp;willMessage)
285 320 return;
286 321 }
287 322  
288   - /* TODO
289   - auto delay_compare = [](std::weak_ptr<Publish> &a, std::shared_ptr<Publish> &b)
290   - {
291   - return true;
292   - };
293 323 std::lock_guard<std::mutex>(this->pendingWillsMutex);
294   - auto pos = std::upper_bound(this->pendingWillMessages.begin(), this->pendingWillMessages.end(), willMessage, delay_compare);
  324 + auto pos = std::upper_bound(this->pendingWillMessages.begin(), this->pendingWillMessages.end(), willMessage, WillDelayCompare);
295 325 this->pendingWillMessages.insert(pos, willMessage);
296   - */
297 326 }
298 327  
299 328 void SubscriptionStore::publishNonRecursively(const std::unordered_map<std::string, Subscription> &subscribers,
... ... @@ -569,11 +598,11 @@ void SubscriptionStore::removeSession(const std::string &amp;clientid)
569 598 */
570 599 void SubscriptionStore::removeExpiredSessionsClients()
571 600 {
  601 + logger->logf(LOG_DEBUG, "Cleaning out old sessions");
  602 +
572 603 RWLockGuard lock_guard(&subscriptionsRwlock);
573 604 lock_guard.wrlock();
574 605  
575   - logger->logf(LOG_NOTICE, "Cleaning out old sessions");
576   -
577 606 auto session_it = sessionsById.begin();
578 607 while (session_it != sessionsById.end())
579 608 {
... ... @@ -582,15 +611,24 @@ void SubscriptionStore::removeExpiredSessionsClients()
582 611 if (session->hasExpired())
583 612 {
584 613 logger->logf(LOG_DEBUG, "Removing expired session from store %s", session->getClientId().c_str());
  614 + std::shared_ptr<Publish> &will = session->getWill();
  615 + if (will)
  616 + {
  617 + will->will_delay = 0;
  618 + queueWillMessage(will);
  619 + }
585 620 session_it = sessionsById.erase(session_it);
586 621 }
587 622 else
588 623 session_it++;
589 624 }
590 625  
591   - logger->logf(LOG_NOTICE, "Rebuilding subscription tree");
592   -
593   - root.cleanSubscriptions();
  626 + if (lastTreeCleanup + std::chrono::minutes(30) < std::chrono::steady_clock::now())
  627 + {
  628 + logger->logf(LOG_NOTICE, "Rebuilding subscription tree");
  629 + root.cleanSubscriptions();
  630 + lastTreeCleanup = std::chrono::steady_clock::now();
  631 + }
594 632 }
595 633  
596 634 int64_t SubscriptionStore::getRetainedMessageCount() const
... ...
subscriptionstore.h
... ... @@ -104,6 +104,8 @@ class SubscriptionStore
104 104 std::mutex pendingWillsMutex;
105 105 std::list<std::weak_ptr<Publish>> pendingWillMessages;
106 106  
  107 + std::chrono::time_point<std::chrono::steady_clock> lastTreeCleanup;
  108 +
107 109 Logger *logger = Logger::getInstance();
108 110  
109 111 void publishNonRecursively(const std::unordered_map<std::string, Subscription> &subscribers,
... ... @@ -119,6 +121,8 @@ class SubscriptionStore
119 121 void countSubscriptions(SubscriptionNode *this_node, int64_t &count) const;
120 122  
121 123 SubscriptionNode *getDeepestNode(const std::string &topic, const std::vector<std::string> &subtopics);
  124 +
  125 + void purgeEmptyWills();
122 126 public:
123 127 SubscriptionStore();
124 128  
... ...
threaddata.cpp
... ... @@ -89,6 +89,26 @@ void ThreadData::queuePublishStatsOnDollarTopic(std::vector&lt;std::shared_ptr&lt;Thre
89 89 wakeUpThread();
90 90 }
91 91  
  92 +void ThreadData::queueSendingQueuedWills()
  93 +{
  94 + std::lock_guard<std::mutex> locker(taskQueueMutex);
  95 +
  96 + auto f = std::bind(&ThreadData::sendQueuedWills, this);
  97 + taskQueue.push_front(f);
  98 +
  99 + wakeUpThread();
  100 +}
  101 +
  102 +void ThreadData::queueRemoveExpiredSessions()
  103 +{
  104 + std::lock_guard<std::mutex> locker(taskQueueMutex);
  105 +
  106 + auto f = std::bind(&ThreadData::removeExpiredSessions, this);
  107 + taskQueue.push_front(f);
  108 +
  109 + wakeUpThread();
  110 +}
  111 +
92 112 void ThreadData::publishStatsOnDollarTopic(std::vector<std::shared_ptr<ThreadData>> &threads)
93 113 {
94 114 uint nrOfClients = 0;
... ... @@ -132,6 +152,16 @@ void ThreadData::publishStat(const std::string &amp;topic, uint64_t n)
132 152 subscriptionStore->setRetainedMessage(topic, factory.getSubtopics(), payload, 0);
133 153 }
134 154  
  155 +void ThreadData::sendQueuedWills()
  156 +{
  157 + subscriptionStore->sendQueuedWillMessages();
  158 +}
  159 +
  160 +void ThreadData::removeExpiredSessions()
  161 +{
  162 + subscriptionStore->removeExpiredSessionsClients();
  163 +}
  164 +
135 165 void ThreadData::removeQueuedClients()
136 166 {
137 167 std::vector<int> fds;
... ...
threaddata.h
... ... @@ -64,6 +64,8 @@ class ThreadData
64 64 void quit();
65 65 void publishStatsOnDollarTopic(std::vector<std::shared_ptr<ThreadData>> &threads);
66 66 void publishStat(const std::string &topic, uint64_t n);
  67 + void sendQueuedWills();
  68 + void removeExpiredSessions();
67 69  
68 70 void removeQueuedClients();
69 71  
... ... @@ -100,6 +102,8 @@ public:
100 102 void waitForQuit();
101 103 void queuePasswdFileReload();
102 104 void queuePublishStatsOnDollarTopic(std::vector<std::shared_ptr<ThreadData>> &threads);
  105 + void queueSendingQueuedWills();
  106 + void queueRemoveExpiredSessions();
103 107  
104 108 int getNrOfClients() const;
105 109  
... ...
timer.cpp
... ... @@ -104,7 +104,7 @@ void Timer::process()
104 104  
105 105 while (running)
106 106 {
107   - logger->logf(LOG_DEBUG, "Timer sleeping for %d ms until event '%s' or callbacks are added.", sleeptime, callbacks.front().name.c_str());
  107 + //logger->logf(LOG_DEBUG, "Timer sleeping for %d ms until event '%s' or callbacks are added.", sleeptime, callbacks.front().name.c_str());
108 108 int num_fds = epoll_wait(this->epollfd, events, MAX_TIMER_EVENTS, sleeptime);
109 109  
110 110 if (!running)
... ...
types.cpp
... ... @@ -152,6 +152,16 @@ Publish::Publish(const std::string &amp;topic, const std::string &amp;payload, char qos)
152 152  
153 153 }
154 154  
  155 +bool WillDelayCompare(const std::shared_ptr<Publish> &a, const std::weak_ptr<Publish> &b)
  156 +{
  157 + std::shared_ptr<Publish> _b = b.lock();
  158 +
  159 + if (!_b)
  160 + return true;
  161 +
  162 + return a->will_delay < _b->will_delay;
  163 +};
  164 +
155 165 PubAck::PubAck(uint16_t packet_id) :
156 166 packet_id(packet_id)
157 167 {
... ...
... ... @@ -222,6 +222,8 @@ public:
222 222 Publish(const std::string &topic, const std::string &payload, char qos);
223 223 };
224 224  
  225 +bool WillDelayCompare(const std::shared_ptr<Publish> &a, const std::weak_ptr<Publish> &b);
  226 +
225 227 class PubAck
226 228 {
227 229 public:
... ...