diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 62a2c31..32ac75f 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -28,7 +28,7 @@ SubscriptionNode::SubscriptionNode(const std::string &subtopic) : } -std::vector &SubscriptionNode::getSubscribers() +std::unordered_map &SubscriptionNode::getSubscribers() { return subscribers; } @@ -44,18 +44,8 @@ void SubscriptionNode::addSubscriber(const std::shared_ptr &subscriber, sub.session = subscriber; sub.qos = qos; - // I'll have to decide whether to keep the subscriber as a vector. Vectors are - // fast, and relatively, you don't often add subscribers. - auto subscriber_it = std::find(subscribers.begin(), subscribers.end(), sub); - if (subscriber_it == subscribers.end()) - { - subscribers.push_back(sub); - } - else - { - Subscription &existingSub = *subscriber_it; - existingSub = sub; - } + const std::string &client_id = subscriber->getClientId(); + subscribers[client_id] = sub; } void SubscriptionNode::removeSubscriber(const std::shared_ptr &subscriber) @@ -64,7 +54,7 @@ void SubscriptionNode::removeSubscriber(const std::shared_ptr &subscrib sub.session = subscriber; sub.qos = 0; - auto it = std::find(subscribers.begin(), subscribers.end(), sub); + auto it = subscribers.find(subscriber->getClientId()); if (it != subscribers.end()) { @@ -261,10 +251,12 @@ bool SubscriptionStore::sessionPresent(const std::string &clientid) return result; } -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers, uint64_t &count) const +void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::unordered_map &subscribers, uint64_t &count) const { - for (const Subscription &sub : subscribers) + for (auto &pair : subscribers) { + const Subscription &sub = pair.second; + const std::shared_ptr session = sub.session.lock(); if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect. { @@ -479,10 +471,10 @@ int SubscriptionNode::cleanSubscriptions() auto it = subscribers.begin(); while (it != subscribers.end()) { - std::shared_ptr ses = it->session.lock(); + std::shared_ptr ses = it->second.session.lock(); if (!ses) { - Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector"); + Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers map"); it = subscribers.erase(it); } else @@ -567,8 +559,9 @@ void SubscriptionStore::getRetainedMessages(RetainedMessageNode *this_node, std: void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, std::unordered_map> &outputList) const { - for (const Subscription &node : this_node->getSubscribers()) + for (auto &pair : this_node->getSubscribers()) { + const Subscription &node = pair.second; std::shared_ptr ses = node.session.lock(); if (ses) { diff --git a/subscriptionstore.h b/subscriptionstore.h index dec8e57..7486b9f 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -44,14 +44,14 @@ struct Subscription class SubscriptionNode { std::string subtopic; - std::vector subscribers; + std::unordered_map subscribers; public: SubscriptionNode(const std::string &subtopic); SubscriptionNode(const SubscriptionNode &node) = delete; SubscriptionNode(SubscriptionNode &&node) = delete; - std::vector &getSubscribers(); + std::unordered_map &getSubscribers(); const std::string &getSubtopic() const; void addSubscriber(const std::shared_ptr &subscriber, char qos); void removeSubscriber(const std::shared_ptr &subscriber); @@ -94,7 +94,7 @@ class SubscriptionStore Logger *logger = Logger::getInstance(); - void publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers, uint64_t &count) const; + void publishNonRecursively(const MqttPacket &packet, const std::unordered_map &subscribers, uint64_t &count) const; void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const; void getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const;