diff --git a/CMakeLists.txt b/CMakeLists.txt index 52e74e7..8062466 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,6 +18,7 @@ add_executable(FlashMQ types.cpp subscriptionstore.cpp rwlockguard.cpp + retainedmessage.cpp ) target_link_libraries(FlashMQ pthread) diff --git a/retainedmessage.cpp b/retainedmessage.cpp new file mode 100644 index 0000000..1432a90 --- /dev/null +++ b/retainedmessage.cpp @@ -0,0 +1,14 @@ +#include "retainedmessage.h" + +RetainedMessage::RetainedMessage(const std::string &topic, const std::string &payload, char qos) : + topic(topic), + payload(payload), + qos(qos) +{ + +} + +bool RetainedMessage::operator==(const RetainedMessage &rhs) const +{ + return this->topic == rhs.topic; +} diff --git a/retainedmessage.h b/retainedmessage.h new file mode 100644 index 0000000..8792e3c --- /dev/null +++ b/retainedmessage.h @@ -0,0 +1,34 @@ +#ifndef RETAINEDMESSAGE_H +#define RETAINEDMESSAGE_H + +#include + +struct RetainedMessage +{ + std::string topic; + std::string payload; + char qos; + + RetainedMessage(const std::string &topic, const std::string &payload, char qos); + + bool operator==(const RetainedMessage &rhs) const; +}; + +namespace std { + + template <> + struct hash + { + std::size_t operator()(const RetainedMessage& k) const + { + using std::size_t; + using std::hash; + using std::string; + + return hash()(k.topic); + } + }; + +} + +#endif // RETAINEDMESSAGE_H diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 5869799..d5ae976 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -46,7 +46,7 @@ void SubscriptionStore::addSubscription(Client_p &client, const std::string &top clients_by_id[client->getClientId()] = client; lock_guard.unlock(); - giveClientRetainedMessage(client, topic); // TODO: wildcards + giveClientRetainedMessages(client, topic); } void SubscriptionStore::removeClient(const Client_p &client) @@ -122,22 +122,20 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::string &topic, const publishRecursively(subtopics.begin(), subtopics.end(), root, packet); } -void SubscriptionStore::giveClientRetainedMessage(Client_p &client, const std::string &topic) +void SubscriptionStore::giveClientRetainedMessages(Client_p &client, const std::string &subscribe_topic) { RWLockGuard locker(&retainedMessagesRwlock); locker.rdlock(); - auto retained_ptr = retainedMessages.find(topic); - - if (retained_ptr == retainedMessages.end()) - return; - - const RetainedPayload &m = retained_ptr->second; + for(const RetainedMessage &rm : retainedMessages) + { + Publish publish(rm.topic, rm.payload, rm.qos); + publish.retain = true; + const MqttPacket packet(publish); - Publish publish(topic, m.payload, m.qos); - publish.retain = true; - const MqttPacket packet(publish); - client->writeMqttPacket(packet); + if (topicsMatch(subscribe_topic, rm.topic)) + client->writeMqttPacket(packet); + } } void SubscriptionStore::setRetainedMessage(const std::string &topic, const std::string &payload, char qos) @@ -145,7 +143,9 @@ void SubscriptionStore::setRetainedMessage(const std::string &topic, const std:: RWLockGuard locker(&retainedMessagesRwlock); locker.wrlock(); - auto retained_ptr = retainedMessages.find(topic); + RetainedMessage rm(topic, payload, qos); + + auto retained_ptr = retainedMessages.find(rm); bool retained_found = retained_ptr != retainedMessages.end(); if (!retained_found && payload.empty()) @@ -153,13 +153,11 @@ void SubscriptionStore::setRetainedMessage(const std::string &topic, const std:: if (retained_found && payload.empty()) { - retainedMessages.erase(topic); + retainedMessages.erase(rm); return; } - RetainedPayload &m = retainedMessages[topic]; - m.payload = payload; - m.qos = qos; + retainedMessages.insert(std::move(rm)); } diff --git a/subscriptionstore.h b/subscriptionstore.h index 14a96ee..deb242f 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -11,6 +11,7 @@ #include "client.h" #include "utils.h" +#include "retainedmessage.h" struct RetainedPayload { @@ -39,7 +40,7 @@ class SubscriptionStore const std::unordered_map &clients_by_id_const; pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER; - std::unordered_map retainedMessages; + std::unordered_set retainedMessages; bool publishNonRecursively(const MqttPacket &packet, const std::forward_list &subscribers) const; bool publishRecursively(std::list::const_iterator cur_subtopic_it, std::list::const_iterator end, @@ -51,7 +52,7 @@ public: void removeClient(const Client_p &client); void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); - void giveClientRetainedMessage(Client_p &client, const std::string &topic); + void giveClientRetainedMessages(Client_p &client, const std::string &subscribe_topic); void setRetainedMessage(const std::string &topic, const std::string &payload, char qos); }; diff --git a/utils.cpp b/utils.cpp index fe36c47..1106b5b 100644 --- a/utils.cpp +++ b/utils.cpp @@ -18,3 +18,33 @@ std::list split(const std::string &input, const char sep, list.push_back(input.substr(start, std::string::npos)); return list; } + + +bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTopic) +{ + if (subscribeTopic.find("+") == std::string::npos && subscribeTopic.find("#") == std::string::npos) + return subscribeTopic == publishTopic; + + const std::list subscribeParts = split(subscribeTopic, '/'); + const std::list publishParts = split(publishTopic, '/'); + + auto subscribe_itr = subscribeParts.begin(); + auto publish_itr = publishParts.begin(); + + bool result = true; + while (subscribe_itr != subscribeParts.end() && publish_itr != publishParts.end()) + { + const std::string &subscribe_subtopic = *subscribe_itr++; + const std::string &publish_subtopic = *publish_itr++; + + if (subscribe_subtopic == "+") + continue; + if (subscribe_subtopic == "#") + return true; + if (subscribe_subtopic != publish_subtopic) + return false; + } + + result = subscribe_itr == subscribeParts.end() && publish_itr == publishParts.end(); + return result; +} diff --git a/utils.h b/utils.h index aebc401..5a121d3 100644 --- a/utils.h +++ b/utils.h @@ -21,4 +21,6 @@ template int check(int rc) std::list split(const std::string &input, const char sep, size_t max = std::numeric_limits::max(), bool keep_empty_parts = true); +bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTopic); + #endif // UTILS_H