diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp
index 6bd3544..7dbfeba 100644
--- a/subscriptionstore.cpp
+++ b/subscriptionstore.cpp
@@ -22,6 +22,13 @@ License along with FlashMQ. If not, see .
#include "rwlockguard.h"
#include "retainedmessagesdb.h"
+ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr &ses, char qos) :
+ session(ses),
+ qos(qos)
+{
+
+}
+
SubscriptionNode::SubscriptionNode(const std::string &subtopic) :
subtopic(subtopic)
{
@@ -260,7 +267,8 @@ bool SubscriptionStore::sessionPresent(const std::string &clientid)
return result;
}
-void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::unordered_map &subscribers, uint64_t &count) const
+void SubscriptionStore::publishNonRecursively(const std::unordered_map &subscribers,
+ std::forward_list &targetSessions) const
{
for (auto &pair : subscribers)
{
@@ -269,7 +277,8 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st
const std::shared_ptr session = sub.session.lock();
if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect.
{
- session->writePacket(packet, sub.qos, false, count);
+ ReceivingSubscriber x(session, sub.qos);
+ targetSessions.emplace_front(session, sub.qos);
}
}
}
@@ -282,16 +291,16 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st
* @param packet
* @param count as a reference (vs return value) because a return value introduces an extra call i.e. limits tail recursion optimization.
*
- * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the kernel. If you refactor this,
+ * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the compiler. If you refactor this,
* look at objdump --disassemble --demangle to see how many calls (not jumps) to itself are made and compare.
*/
void SubscriptionStore::publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end,
- SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const
+ SubscriptionNode *this_node, std::forward_list &targetSessions) const
{
if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here.
{
if (this_node)
- publishNonRecursively(packet, this_node->getSubscribers(), count);
+ publishNonRecursively(this_node->getSubscribers(), targetSessions);
return;
}
@@ -308,18 +317,18 @@ void SubscriptionStore::publishRecursively(std::vector::const_itera
if (this_node->childrenPound)
{
- publishNonRecursively(packet, this_node->childrenPound->getSubscribers(), count);
+ publishNonRecursively(this_node->childrenPound->getSubscribers(), targetSessions);
}
const auto &sub_node = this_node->children.find(cur_subtop);
if (sub_node != this_node->children.end())
{
- publishRecursively(next_subtopic, end, sub_node->second.get(), packet, count);
+ publishRecursively(next_subtopic, end, sub_node->second.get(), targetSessions);
}
if (this_node->childrenPlus)
{
- publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), packet, count);
+ publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), targetSessions);
}
}
@@ -329,11 +338,19 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector
SubscriptionNode *startNode = dollar ? &rootDollar : &root;
- RWLockGuard lock_guard(&subscriptionsRwlock);
- lock_guard.rdlock();
-
uint64_t count = 0;
- publishRecursively(subtopics.begin(), subtopics.end(), startNode, packet, count);
+ std::forward_list subscriberSessions;
+
+ {
+ RWLockGuard lock_guard(&subscriptionsRwlock);
+ lock_guard.rdlock();
+ publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions);
+ }
+
+ for(const ReceivingSubscriber &x : subscriberSessions)
+ {
+ x.session->writePacket(packet, x.qos, false, count);
+ }
std::shared_ptr sender = packet.getSender();
if (sender)
diff --git a/subscriptionstore.h b/subscriptionstore.h
index 7486b9f..f1c7a0e 100644
--- a/subscriptionstore.h
+++ b/subscriptionstore.h
@@ -41,6 +41,15 @@ struct Subscription
void reset();
};
+struct ReceivingSubscriber
+{
+ const std::shared_ptr session;
+ const char qos;
+
+public:
+ ReceivingSubscriber(const std::shared_ptr &ses, char qos);
+};
+
class SubscriptionNode
{
std::string subtopic;
@@ -94,9 +103,10 @@ class SubscriptionStore
Logger *logger = Logger::getInstance();
- void publishNonRecursively(const MqttPacket &packet, const std::unordered_map &subscribers, uint64_t &count) const;
+ void publishNonRecursively(const std::unordered_map &subscribers,
+ std::forward_list &targetSessions) const;
void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end,
- SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const;
+ SubscriptionNode *this_node, std::forward_list &targetSessions) const;
void getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const;
void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
std::unordered_map> &outputList) const;