From 367579cc24954a856a89a48e819ba4ba788ef6e4 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sun, 30 May 2021 19:19:54 +0200 Subject: [PATCH] Properly handle dollar topics --- FlashMQTests/tst_maintests.cpp | 4 ++++ mainapp.cpp | 41 +++++++++++++++++++++++++++++++++++++++++ mainapp.h | 4 ++++ mqttpacket.cpp | 2 ++ session.cpp | 11 +++++++++-- session.h | 4 ++-- subscriptionstore.cpp | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++----------------- subscriptionstore.h | 9 +++++---- threaddata.cpp | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ threaddata.h | 20 ++++++++++++++++++++ threadlocalutils.cpp | 3 +++ utils.cpp | 3 +++ 12 files changed, 196 insertions(+), 25 deletions(-) diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 6b1374f..390fb78 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -323,6 +323,9 @@ void MainTests::test_validSubscribePath() QVERIFY(isValidSubscribePath("")); QVERIFY(isValidSubscribePath("hello")); + QVERIFY(isValidSubscribePath("$SYS/hello")); + QVERIFY(isValidSubscribePath("hello/$SYS")); // Hmm, is this valid? + QVERIFY(!isValidSubscribePath("one/tw+o/three")); QVERIFY(!isValidSubscribePath("one/+o/three")); QVERIFY(!isValidSubscribePath("one/a+/three")); @@ -707,6 +710,7 @@ void MainTests::test_validUtf8Sse() QVERIFY(!data.isValidUtf8("+", true)); QVERIFY(!data.isValidUtf8("🩰+asdfasdfasdf", true)); QVERIFY(!data.isValidUtf8("+asdfasdfasdf", true)); + QVERIFY(!data.isValidUtf8("$SYS/asdfasdfasdf", true)); std::memset(m, 0, 16); m[0] = 'a'; diff --git a/mainapp.cpp b/mainapp.cpp index 94a3d61..e7a1261 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -187,6 +187,10 @@ MainApp::MainApp(const std::string &configFilePath) : auto fPasswordFileReload = std::bind(&MainApp::queuePasswordFileReloadAllThreads, this); timer.addCallback(fPasswordFileReload, 2000, "Password file reload."); + + auto fPublishStats = std::bind(&MainApp::publishStatsOnDollarTopic, this); + timer.addCallback(fPublishStats, 10000, "Publish stats on $SYS"); + publishStatsOnDollarTopic(); } MainApp::~MainApp() @@ -306,6 +310,43 @@ void MainApp::setFuzzFile(const std::string &fuzzFilePath) this->fuzzFilePath = fuzzFilePath; } +void MainApp::publishStatsOnDollarTopic() +{ + uint nrOfClients = 0; + uint64_t receivedMessageCountPerSecond = 0; + uint64_t receivedMessageCount = 0; + uint64_t sentMessageCountPerSecond = 0; + uint64_t sentMessageCount = 0; + + for (std::shared_ptr &thread : threads) + { + nrOfClients += thread->getNrOfClients(); + + receivedMessageCountPerSecond += thread->getReceivedMessagePerSecond(); + receivedMessageCount += thread->getReceivedMessageCount(); + + sentMessageCountPerSecond += thread->getSentMessagePerSecond(); + sentMessageCount += thread->getSentMessageCount(); + } + + publishStat("$SYS/broker/clients/total", nrOfClients); + + publishStat("$SYS/broker/load/messages/received/total", receivedMessageCount); + publishStat("$SYS/broker/load/messages/received/persecond", receivedMessageCountPerSecond); + + publishStat("$SYS/broker/load/messages/sent/total", sentMessageCount); + publishStat("$SYS/broker/load/messages/sent/persecond", sentMessageCountPerSecond); +} + +void MainApp::publishStat(const std::string &topic, uint64_t n) +{ + std::vector *subtopics = utils.splitTopic(topic); + const std::string payload = std::to_string(n); + Publish p(topic, payload, 0); + subscriptionStore->queuePacketAtSubscribers(*subtopics, p, true); + subscriptionStore->setRetainedMessage(topic, payload, 0); +} + void MainApp::initMainApp(int argc, char *argv[]) { if (instance != nullptr) diff --git a/mainapp.h b/mainapp.h index e411152..5cc1afb 100644 --- a/mainapp.h +++ b/mainapp.h @@ -41,6 +41,7 @@ License along with FlashMQ. If not, see . #include "timer.h" #include "scopedsocket.h" #include "oneinstancelock.h" +#include "threadlocalutils.h" #define VERSION "0.7.0" @@ -64,6 +65,7 @@ class MainApp std::mutex quitMutex; std::string fuzzFilePath; OneInstanceLock oneInstanceLock; + Utils utils; Logger *logger = Logger::getInstance(); @@ -77,6 +79,8 @@ class MainApp void queueKeepAliveCheckAtAllThreads(); void queuePasswordFileReloadAllThreads(); void setFuzzFile(const std::string &fuzzFilePath); + void publishStatsOnDollarTopic(); + void publishStat(const std::string &topic, uint64_t n); MainApp(const std::string &configFilePath); public: diff --git a/mqttpacket.cpp b/mqttpacket.cpp index c34ba3e..392685d 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -497,6 +497,8 @@ void MqttPacket::handlePublish() logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", topic.c_str(), qos, retain, dup); #endif + sender->getThreadData()->incrementReceivedMessageCount(); + if (qos) { packet_id_pos = pos; diff --git a/session.cpp b/session.cpp index 7a1442e..92dda1a 100644 --- a/session.cpp +++ b/session.cpp @@ -48,7 +48,7 @@ void Session::assignActiveConnection(std::shared_ptr &client) this->thread = client->getThreadData(); } -void Session::writePacket(const MqttPacket &packet, char max_qos) +void Session::writePacket(const MqttPacket &packet, char max_qos, uint64_t &count) { assert(max_qos <= 2); @@ -62,6 +62,7 @@ void Session::writePacket(const MqttPacket &packet, char max_qos) { std::shared_ptr c = makeSharedClient(); c->writeMqttPacketAndBlameThisClient(packet, qos); + count++; } } else if (qos > 0) @@ -93,6 +94,7 @@ void Session::writePacket(const MqttPacket &packet, char max_qos) std::shared_ptr c = makeSharedClient(); c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos); copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + count++; } } } @@ -136,8 +138,10 @@ void Session::clearQosMessage(uint16_t packet_id) // // There is a bit of a hole there, I think. When we write out a packet to a receiver, it may decide to drop it, if its buffers // are full, for instance. We are not required to (periodically) retry. TODO Perhaps I will implement that retry anyway. -void Session::sendPendingQosMessages() +uint64_t Session::sendPendingQosMessages() { + uint64_t count = 0; + if (!clientDisconnected()) { std::shared_ptr c = makeSharedClient(); @@ -146,6 +150,7 @@ void Session::sendPendingQosMessages() { c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos()); qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + count++; } for (const uint16_t packet_id : outgoingQoS2MessageIds) @@ -155,6 +160,8 @@ void Session::sendPendingQosMessages() c->writeMqttPacketAndBlameThisClient(packet, 2); } } + + return count; } void Session::touch(time_t val) diff --git a/session.h b/session.h index 596a0d6..94851e1 100644 --- a/session.h +++ b/session.h @@ -64,9 +64,9 @@ public: bool clientDisconnected() const; std::shared_ptr makeSharedClient() const; void assignActiveConnection(std::shared_ptr &client); - void writePacket(const MqttPacket &packet, char max_qos); + void writePacket(const MqttPacket &packet, char max_qos, uint64_t &count); void clearQosMessage(uint16_t packet_id); - void sendPendingQosMessages(); + uint64_t sendPendingQosMessages(); void touch(time_t val = 0); bool hasExpired(); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index f82208a..9243dd3 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -78,6 +78,7 @@ SubscriptionNode *SubscriptionNode::getChildren(const std::string &subtopic) con SubscriptionStore::SubscriptionStore() : root("root"), + rootDollar("rootDollar"), sessionsByIdConst(sessionsById) { @@ -87,10 +88,13 @@ void SubscriptionStore::addSubscription(std::shared_ptr &client, const s { const std::list subtopics = split(topic, '/'); + SubscriptionNode *deepestNode = &root; + if (topic.length() > 0 && topic[0] == '$') + deepestNode = &rootDollar; + RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.wrlock(); - SubscriptionNode *deepestNode = &root; for(const std::string &subtopic : subtopics) { std::unique_ptr *selectedChildren = nullptr; @@ -120,25 +124,27 @@ void SubscriptionStore::addSubscription(std::shared_ptr &client, const s { const std::shared_ptr &ses = session_it->second; deepestNode->addSubscriber(ses, qos); - giveClientRetainedMessages(ses, topic, qos); + uint64_t count = giveClientRetainedMessages(ses, topic, qos); + client->getThreadData()->incrementSentMessageCount(count); } } lock_guard.unlock(); - - } void SubscriptionStore::removeSubscription(std::shared_ptr &client, const std::string &topic) { const std::list subtopics = split(topic, '/'); + SubscriptionNode *deepestNode = &root; + if (topic.length() > 0 && topic[0] == '$') + deepestNode = &rootDollar; + RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.wrlock(); // This code looks like that for addSubscription(), but it's specifically different in that we don't want to default-create non-existing // nodes. We need to abort when that happens. - SubscriptionNode *deepestNode = &root; for(const std::string &subtopic : subtopics) { SubscriptionNode *selectedChildren = nullptr; @@ -208,7 +214,8 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr session->assignActiveConnection(client); client->assignSession(session); - session->sendPendingQosMessages(); + uint64_t count = session->sendPendingQosMessages(); + client->getThreadData()->incrementSentMessageCount(count); } bool SubscriptionStore::sessionPresent(const std::string &clientid) @@ -227,7 +234,7 @@ bool SubscriptionStore::sessionPresent(const std::string &clientid) return result; } -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers) const +void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers, uint64_t &count) const { for (const Subscription &sub : subscribers) { @@ -235,18 +242,29 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. { const std::shared_ptr session = session_weak.lock(); - session->writePacket(packet, sub.qos); + session->writePacket(packet, sub.qos, count); } } } +/** + * @brief SubscriptionStore::publishRecursively + * @param cur_subtopic_it + * @param end + * @param this_node + * @param packet + * @param count as a reference (vs return value) because a return value introduces an extra call i.e. limits tail recursion optimization. + * + * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the kernel. If you refactor this, + * look at objdump --disassemble --demangle to see how many calls (not jumps) to itself are made and compare. + */ void SubscriptionStore::publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, - SubscriptionNode *this_node, const MqttPacket &packet) const + SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const { if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here. { if (this_node) - publishNonRecursively(packet, this_node->getSubscribers()); + publishNonRecursively(packet, this_node->getSubscribers(), count); return; } @@ -263,33 +281,44 @@ void SubscriptionStore::publishRecursively(std::vector::const_itera if (this_node->childrenPound) { - publishNonRecursively(packet, this_node->childrenPound->getSubscribers()); + publishNonRecursively(packet, this_node->childrenPound->getSubscribers(), count); } const auto &sub_node = this_node->children.find(cur_subtop); if (sub_node != this_node->children.end()) { - publishRecursively(next_subtopic, end, sub_node->second.get(), packet); + publishRecursively(next_subtopic, end, sub_node->second.get(), packet, count); } if (this_node->childrenPlus) { - publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), packet); + publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), packet, count); } } -void SubscriptionStore::queuePacketAtSubscribers(const std::vector &subtopics, const MqttPacket &packet) +void SubscriptionStore::queuePacketAtSubscribers(const std::vector &subtopics, const MqttPacket &packet, bool dollar) { assert(subtopics.size() > 0); + SubscriptionNode *startNode = dollar ? &rootDollar : &root; + RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.rdlock(); - publishRecursively(subtopics.begin(), subtopics.end(), &root, packet); + uint64_t count = 0; + publishRecursively(subtopics.begin(), subtopics.end(), startNode, packet, count); + + std::shared_ptr sender = packet.getSender(); + if (sender) + { + sender->getThreadData()->incrementSentMessageCount(count); + } } -void SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr &ses, const std::string &subscribe_topic, char max_qos) +uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr &ses, const std::string &subscribe_topic, char max_qos) { + uint64_t count = 0; + RWLockGuard locker(&retainedMessagesRwlock); locker.rdlock(); @@ -300,8 +329,12 @@ void SubscriptionStore::giveClientRetainedMessages(const std::shared_ptrwritePacket(packet, max_qos); + { + ses->writePacket(packet, max_qos, count); + } } + + return count; } void SubscriptionStore::setRetainedMessage(const std::string &topic, const std::string &payload, char qos) diff --git a/subscriptionstore.h b/subscriptionstore.h index 86ddee0..6f4eeba 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -72,6 +72,7 @@ public: class SubscriptionStore { SubscriptionNode root; + SubscriptionNode rootDollar; pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER; std::unordered_map> sessionsById; const std::unordered_map> &sessionsByIdConst; @@ -81,9 +82,9 @@ class SubscriptionStore Logger *logger = Logger::getInstance(); - void publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers) const; + void publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers, uint64_t &count) const; void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, - SubscriptionNode *this_node, const MqttPacket &packet) const; + SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const; public: SubscriptionStore(); @@ -93,8 +94,8 @@ public: void registerClientAndKickExistingOne(std::shared_ptr &client); bool sessionPresent(const std::string &clientid); - void queuePacketAtSubscribers(const std::vector &subtopics, const MqttPacket &packet); - void giveClientRetainedMessages(const std::shared_ptr &ses, const std::string &subscribe_topic, char max_qos); + void queuePacketAtSubscribers(const std::vector &subtopics, const MqttPacket &packet, bool dollar = false); + uint64_t giveClientRetainedMessages(const std::shared_ptr &ses, const std::string &subscribe_topic, char max_qos); void setRetainedMessage(const std::string &topic, const std::string &payload, char qos); diff --git a/threaddata.cpp b/threaddata.cpp index 45cb9cd..fb987b1 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -157,6 +157,59 @@ void ThreadData::queuePasswdFileReload() wakeUpThread(); } +int ThreadData::getNrOfClients() const +{ + return clients_by_fd.size(); +} + +void ThreadData::incrementReceivedMessageCount() +{ + receivedMessageCount++; +} + +uint64_t ThreadData::getReceivedMessageCount() const +{ + return receivedMessageCount; +} + +/** + * @brief ThreadData::getReceivedMessagePerSecond gets the amount of seconds received, averaged over the last time this was called. + * @return + * + * Locking is not required, because the counter is not written to from here. + */ +uint64_t ThreadData::getReceivedMessagePerSecond() +{ + std::chrono::time_point now = std::chrono::steady_clock::now(); + std::chrono::milliseconds msSinceLastTime = std::chrono::duration_cast(now - receivedMessagePreviousTime); + uint64_t messagesTimes1000 = (receivedMessageCount - receivedMessageCountPrevious) * 1000; + uint64_t result = messagesTimes1000 / (msSinceLastTime.count() + 1); // branchless avoidance of div by 0; + receivedMessagePreviousTime = now; + receivedMessageCountPrevious = receivedMessageCount; + return result; +} + +void ThreadData::incrementSentMessageCount(uint64_t n) +{ + sentMessageCount += n; +} + +uint64_t ThreadData::getSentMessageCount() const +{ + return sentMessageCount; +} + +uint64_t ThreadData::getSentMessagePerSecond() +{ + std::chrono::time_point now = std::chrono::steady_clock::now(); + std::chrono::milliseconds msSinceLastTime = std::chrono::duration_cast(now - sentMessagePreviousTime); + uint64_t messagesTimes1000 = (sentMessageCount - sentMessageCountPrevious) * 1000; + uint64_t result = messagesTimes1000 / (msSinceLastTime.count() + 1); // branchless avoidance of div by 0; + sentMessagePreviousTime = now; + sentMessageCountPrevious = sentMessageCount; + return result; +} + // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? void ThreadData::doKeepAliveCheck() { diff --git a/threaddata.h b/threaddata.h index 798ac51..b36cbeb 100644 --- a/threaddata.h +++ b/threaddata.h @@ -28,6 +28,7 @@ License along with FlashMQ. If not, see . #include #include #include +#include #include "forward_declarations.h" @@ -47,6 +48,15 @@ class ThreadData std::shared_ptr subscriptionStore; Logger *logger; + uint64_t receivedMessageCount = 0; + uint64_t receivedMessageCountPrevious = 0; + std::chrono::time_point receivedMessagePreviousTime = std::chrono::steady_clock::now(); + + uint64_t sentMessageCount = 0; + uint64_t sentMessageCountPrevious = 0; + std::chrono::time_point sentMessagePreviousTime = std::chrono::steady_clock::now(); + + void reload(std::shared_ptr settings); void wakeUpThread(); void doKeepAliveCheck(); @@ -81,6 +91,16 @@ public: void queueQuit(); void waitForQuit(); void queuePasswdFileReload(); + + int getNrOfClients() const; + + void incrementReceivedMessageCount(); + uint64_t getReceivedMessageCount() const; + uint64_t getReceivedMessagePerSecond(); + + void incrementSentMessageCount(uint64_t n); + uint64_t getSentMessageCount() const; + uint64_t getSentMessagePerSecond(); }; #endif // THREADDATA_H diff --git a/threadlocalutils.cpp b/threadlocalutils.cpp index 485133f..7de86d5 100644 --- a/threadlocalutils.cpp +++ b/threadlocalutils.cpp @@ -60,6 +60,9 @@ bool Utils::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) std::memcpy(topicCopy.data(), s.c_str(), len); std::memset(&topicCopy.data()[len], 0x20, 16); // I fill out with spaces, as valid chars + if (alsoCheckInvalidPublishChars && len > 0 && s[0] == '$') + return false; + int n = 0; const char *i = topicCopy.data(); while (n < len) diff --git a/utils.cpp b/utils.cpp index 07e6704..236a77c 100644 --- a/utils.cpp +++ b/utils.cpp @@ -57,6 +57,9 @@ bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTo if (subscribeTopic.find("+") == std::string::npos && subscribeTopic.find("#") == std::string::npos) return subscribeTopic == publishTopic; + if (!subscribeTopic.empty() && !publishTopic.empty() && publishTopic[0] == '$' && subscribeTopic[0] != '$') + return false; + const std::vector subscribeParts = splitToVector(subscribeTopic, '/'); const std::vector publishParts = splitToVector(publishTopic, '/'); -- libgit2 0.21.4