diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 215eb9c..ff3fdc8 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -208,7 +208,7 @@ void MqttPacket::handlePublish(std::shared_ptr &subscriptionS size_t payload_length = remainingAfterPos(); std::string payload(readBytes(payload_length), payload_length); - subscriptionStore->queueAtClientsTemp(topic, *this, sender); + subscriptionStore->queuePacketAtSubscribers(topic, *this, sender); } diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 3db2e6a..16bcd78 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -12,7 +12,8 @@ SubscriptionNode::SubscriptionNode(const std::string &subtopic) : } SubscriptionStore::SubscriptionStore() : - subscriptions2(new SubscriptionNode("root")) + root(new SubscriptionNode("root")), + clients_by_id_const(clients_by_id) { } @@ -24,7 +25,7 @@ void SubscriptionStore::addSubscription(Client_p &client, std::string &topic) RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.wrlock(); - SubscriptionNode *deepestNode = subscriptions2.get(); + SubscriptionNode *deepestNode = root.get(); for(const std::string &subtopic : subtopics) { SubscriptionNode &nodeRef = *deepestNode; @@ -39,7 +40,7 @@ void SubscriptionStore::addSubscription(Client_p &client, std::string &topic) if (deepestNode) { - deepestNode->subscribers.insert(client->getClientId()); + deepestNode->subscribers.push_front(client->getClientId()); } clients_by_id[client->getClientId()] = client; @@ -52,33 +53,51 @@ void SubscriptionStore::removeClient(const Client_p &client) clients_by_id.erase(client->getClientId()); } -void SubscriptionStore::queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender) +// TODO: keep a cache of topics vs clients + +bool SubscriptionStore::publishRecursively(std::list::const_iterator cur_subtopic_it, std::list::const_iterator end, + std::unique_ptr &this_node, const MqttPacket &packet) const { - const std::list subtopics = split(topic, '/'); - const auto &clients = clients_by_id; + if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here. + { + for (const std::string &client_id : this_node->subscribers) + { + auto client_it = clients_by_id_const.find(client_id); + if (client_it != clients_by_id_const.end()) + client_it->second->writeMqttPacket(packet); + } - RWLockGuard lock_guard(&subscriptionsRwlock); - lock_guard.rdlock(); + return true; + } - const SubscriptionNode *deepestNode = subscriptions2.get(); - for(const std::string &subtopic : subtopics) - { - auto sub_iter = deepestNode->children.find(subtopic); - if (sub_iter == deepestNode->children.end()) - return; + std::string cur_subtop = *cur_subtopic_it; + auto sub_node = this_node->children.find(cur_subtop); + + const auto next_subtopic = ++cur_subtopic_it; - const std::unique_ptr &sub_node = sub_iter->second; - assert(sub_node); // because any empty unique_ptr's is a bug - deepestNode = sub_node.get(); + if (sub_node != this_node->children.end()) + { + publishRecursively(next_subtopic, end, sub_node->second, packet); } - for (const std::string &client_id : deepestNode->subscribers) + const auto plus_sign_node = this_node->children.find("+"); + + if (plus_sign_node != this_node->children.end()) { - std::cout << "Publishing to " << client_id << std::endl; - auto client_it = clients.find(client_id); - if (client_it != clients.end()) - client_it->second->writeMqttPacket(packet); + publishRecursively(next_subtopic, end, plus_sign_node->second, packet); } + + return false; +} + +void SubscriptionStore::queuePacketAtSubscribers(std::string &topic, const MqttPacket &packet, const Client_p &sender) +{ + const std::list subtopics = split(topic, '/'); + + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.rdlock(); + + publishRecursively(subtopics.begin(), subtopics.end(), root, packet); } diff --git a/subscriptionstore.h b/subscriptionstore.h index 46222d0..949018a 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -2,7 +2,7 @@ #define SUBSCRIPTIONSTORE_H #include -#include +#include #include #include #include @@ -21,27 +21,27 @@ public: SubscriptionNode(const SubscriptionNode &node) = delete; SubscriptionNode(SubscriptionNode &&node) = delete; - std::unordered_set subscribers; // The idea is to store subscriptions by client id, to support persistent sessions. + std::forward_list subscribers; // The idea is to store subscriptions by client id, to support persistent sessions. std::unordered_map> children; - }; class SubscriptionStore { - std::unique_ptr subscriptions2; + std::unique_ptr root; pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER; std::unordered_map clients_by_id; + const std::unordered_map &clients_by_id_const; + + bool publishRecursively(std::list::const_iterator cur_subtopic_it, std::list::const_iterator end, + std::unique_ptr &next, const MqttPacket &packet) const; public: SubscriptionStore(); void addSubscription(Client_p &client, std::string &topic); void removeClient(const Client_p &client); - // work with read copies intead of mutex/lock over the central store - void getReadCopy(); // TODO - - void queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender); + void queuePacketAtSubscribers(std::string &topic, const MqttPacket &packet, const Client_p &sender); }; #endif // SUBSCRIPTIONSTORE_H