diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 6bd3544..7dbfeba 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -22,6 +22,13 @@ License along with FlashMQ. If not, see . #include "rwlockguard.h" #include "retainedmessagesdb.h" +ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr &ses, char qos) : + session(ses), + qos(qos) +{ + +} + SubscriptionNode::SubscriptionNode(const std::string &subtopic) : subtopic(subtopic) { @@ -260,7 +267,8 @@ bool SubscriptionStore::sessionPresent(const std::string &clientid) return result; } -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::unordered_map &subscribers, uint64_t &count) const +void SubscriptionStore::publishNonRecursively(const std::unordered_map &subscribers, + std::forward_list &targetSessions) const { for (auto &pair : subscribers) { @@ -269,7 +277,8 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st const std::shared_ptr session = sub.session.lock(); if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect. { - session->writePacket(packet, sub.qos, false, count); + ReceivingSubscriber x(session, sub.qos); + targetSessions.emplace_front(session, sub.qos); } } } @@ -282,16 +291,16 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st * @param packet * @param count as a reference (vs return value) because a return value introduces an extra call i.e. limits tail recursion optimization. * - * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the kernel. If you refactor this, + * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the compiler. If you refactor this, * look at objdump --disassemble --demangle to see how many calls (not jumps) to itself are made and compare. */ void SubscriptionStore::publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, - SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const + SubscriptionNode *this_node, std::forward_list &targetSessions) const { if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here. { if (this_node) - publishNonRecursively(packet, this_node->getSubscribers(), count); + publishNonRecursively(this_node->getSubscribers(), targetSessions); return; } @@ -308,18 +317,18 @@ void SubscriptionStore::publishRecursively(std::vector::const_itera if (this_node->childrenPound) { - publishNonRecursively(packet, this_node->childrenPound->getSubscribers(), count); + publishNonRecursively(this_node->childrenPound->getSubscribers(), targetSessions); } const auto &sub_node = this_node->children.find(cur_subtop); if (sub_node != this_node->children.end()) { - publishRecursively(next_subtopic, end, sub_node->second.get(), packet, count); + publishRecursively(next_subtopic, end, sub_node->second.get(), targetSessions); } if (this_node->childrenPlus) { - publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), packet, count); + publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), targetSessions); } } @@ -329,11 +338,19 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector SubscriptionNode *startNode = dollar ? &rootDollar : &root; - RWLockGuard lock_guard(&subscriptionsRwlock); - lock_guard.rdlock(); - uint64_t count = 0; - publishRecursively(subtopics.begin(), subtopics.end(), startNode, packet, count); + std::forward_list subscriberSessions; + + { + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.rdlock(); + publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions); + } + + for(const ReceivingSubscriber &x : subscriberSessions) + { + x.session->writePacket(packet, x.qos, false, count); + } std::shared_ptr sender = packet.getSender(); if (sender) diff --git a/subscriptionstore.h b/subscriptionstore.h index 7486b9f..f1c7a0e 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -41,6 +41,15 @@ struct Subscription void reset(); }; +struct ReceivingSubscriber +{ + const std::shared_ptr session; + const char qos; + +public: + ReceivingSubscriber(const std::shared_ptr &ses, char qos); +}; + class SubscriptionNode { std::string subtopic; @@ -94,9 +103,10 @@ class SubscriptionStore Logger *logger = Logger::getInstance(); - void publishNonRecursively(const MqttPacket &packet, const std::unordered_map &subscribers, uint64_t &count) const; + void publishNonRecursively(const std::unordered_map &subscribers, + std::forward_list &targetSessions) const; void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, - SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const; + SubscriptionNode *this_node, std::forward_list &targetSessions) const; void getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const; void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, std::unordered_map> &outputList) const;