Commit 33d19cec63a85e6b0a226a3dcfab5eb1d345c0b4

Authored by Wiebe Cazemier
1 parent 70896f68

I think this covers retained messages

CMakeLists.txt
@@ -18,6 +18,7 @@ add_executable(FlashMQ @@ -18,6 +18,7 @@ add_executable(FlashMQ
18 types.cpp 18 types.cpp
19 subscriptionstore.cpp 19 subscriptionstore.cpp
20 rwlockguard.cpp 20 rwlockguard.cpp
  21 + retainedmessage.cpp
21 ) 22 )
22 23
23 target_link_libraries(FlashMQ pthread) 24 target_link_libraries(FlashMQ pthread)
retainedmessage.cpp 0 → 100644
  1 +#include "retainedmessage.h"
  2 +
  3 +RetainedMessage::RetainedMessage(const std::string &topic, const std::string &payload, char qos) :
  4 + topic(topic),
  5 + payload(payload),
  6 + qos(qos)
  7 +{
  8 +
  9 +}
  10 +
  11 +bool RetainedMessage::operator==(const RetainedMessage &rhs) const
  12 +{
  13 + return this->topic == rhs.topic;
  14 +}
retainedmessage.h 0 → 100644
  1 +#ifndef RETAINEDMESSAGE_H
  2 +#define RETAINEDMESSAGE_H
  3 +
  4 +#include <string>
  5 +
  6 +struct RetainedMessage
  7 +{
  8 + std::string topic;
  9 + std::string payload;
  10 + char qos;
  11 +
  12 + RetainedMessage(const std::string &topic, const std::string &payload, char qos);
  13 +
  14 + bool operator==(const RetainedMessage &rhs) const;
  15 +};
  16 +
  17 +namespace std {
  18 +
  19 + template <>
  20 + struct hash<RetainedMessage>
  21 + {
  22 + std::size_t operator()(const RetainedMessage& k) const
  23 + {
  24 + using std::size_t;
  25 + using std::hash;
  26 + using std::string;
  27 +
  28 + return hash<string>()(k.topic);
  29 + }
  30 + };
  31 +
  32 +}
  33 +
  34 +#endif // RETAINEDMESSAGE_H
subscriptionstore.cpp
@@ -46,7 +46,7 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top @@ -46,7 +46,7 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top
46 clients_by_id[client->getClientId()] = client; 46 clients_by_id[client->getClientId()] = client;
47 lock_guard.unlock(); 47 lock_guard.unlock();
48 48
49 - giveClientRetainedMessage(client, topic); // TODO: wildcards 49 + giveClientRetainedMessages(client, topic);
50 } 50 }
51 51
52 void SubscriptionStore::removeClient(const Client_p &client) 52 void SubscriptionStore::removeClient(const Client_p &client)
@@ -122,22 +122,20 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::string &amp;topic, const @@ -122,22 +122,20 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::string &amp;topic, const
122 publishRecursively(subtopics.begin(), subtopics.end(), root, packet); 122 publishRecursively(subtopics.begin(), subtopics.end(), root, packet);
123 } 123 }
124 124
125 -void SubscriptionStore::giveClientRetainedMessage(Client_p &client, const std::string &topic) 125 +void SubscriptionStore::giveClientRetainedMessages(Client_p &client, const std::string &subscribe_topic)
126 { 126 {
127 RWLockGuard locker(&retainedMessagesRwlock); 127 RWLockGuard locker(&retainedMessagesRwlock);
128 locker.rdlock(); 128 locker.rdlock();
129 129
130 - auto retained_ptr = retainedMessages.find(topic);  
131 -  
132 - if (retained_ptr == retainedMessages.end())  
133 - return;  
134 -  
135 - const RetainedPayload &m = retained_ptr->second; 130 + for(const RetainedMessage &rm : retainedMessages)
  131 + {
  132 + Publish publish(rm.topic, rm.payload, rm.qos);
  133 + publish.retain = true;
  134 + const MqttPacket packet(publish);
136 135
137 - Publish publish(topic, m.payload, m.qos);  
138 - publish.retain = true;  
139 - const MqttPacket packet(publish);  
140 - client->writeMqttPacket(packet); 136 + if (topicsMatch(subscribe_topic, rm.topic))
  137 + client->writeMqttPacket(packet);
  138 + }
141 } 139 }
142 140
143 void SubscriptionStore::setRetainedMessage(const std::string &topic, const std::string &payload, char qos) 141 void SubscriptionStore::setRetainedMessage(const std::string &topic, const std::string &payload, char qos)
@@ -145,7 +143,9 @@ void SubscriptionStore::setRetainedMessage(const std::string &amp;topic, const std:: @@ -145,7 +143,9 @@ void SubscriptionStore::setRetainedMessage(const std::string &amp;topic, const std::
145 RWLockGuard locker(&retainedMessagesRwlock); 143 RWLockGuard locker(&retainedMessagesRwlock);
146 locker.wrlock(); 144 locker.wrlock();
147 145
148 - auto retained_ptr = retainedMessages.find(topic); 146 + RetainedMessage rm(topic, payload, qos);
  147 +
  148 + auto retained_ptr = retainedMessages.find(rm);
149 bool retained_found = retained_ptr != retainedMessages.end(); 149 bool retained_found = retained_ptr != retainedMessages.end();
150 150
151 if (!retained_found && payload.empty()) 151 if (!retained_found && payload.empty())
@@ -153,13 +153,11 @@ void SubscriptionStore::setRetainedMessage(const std::string &amp;topic, const std:: @@ -153,13 +153,11 @@ void SubscriptionStore::setRetainedMessage(const std::string &amp;topic, const std::
153 153
154 if (retained_found && payload.empty()) 154 if (retained_found && payload.empty())
155 { 155 {
156 - retainedMessages.erase(topic); 156 + retainedMessages.erase(rm);
157 return; 157 return;
158 } 158 }
159 159
160 - RetainedPayload &m = retainedMessages[topic];  
161 - m.payload = payload;  
162 - m.qos = qos; 160 + retainedMessages.insert(std::move(rm));
163 } 161 }
164 162
165 163
subscriptionstore.h
@@ -11,6 +11,7 @@ @@ -11,6 +11,7 @@
11 11
12 #include "client.h" 12 #include "client.h"
13 #include "utils.h" 13 #include "utils.h"
  14 +#include "retainedmessage.h"
14 15
15 struct RetainedPayload 16 struct RetainedPayload
16 { 17 {
@@ -39,7 +40,7 @@ class SubscriptionStore @@ -39,7 +40,7 @@ class SubscriptionStore
39 const std::unordered_map<std::string, Client_p> &clients_by_id_const; 40 const std::unordered_map<std::string, Client_p> &clients_by_id_const;
40 41
41 pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER; 42 pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER;
42 - std::unordered_map<std::string, RetainedPayload> retainedMessages; 43 + std::unordered_set<RetainedMessage> retainedMessages;
43 44
44 bool publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::string> &subscribers) const; 45 bool publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::string> &subscribers) const;
45 bool publishRecursively(std::list<std::string>::const_iterator cur_subtopic_it, std::list<std::string>::const_iterator end, 46 bool publishRecursively(std::list<std::string>::const_iterator cur_subtopic_it, std::list<std::string>::const_iterator end,
@@ -51,7 +52,7 @@ public: @@ -51,7 +52,7 @@ public:
51 void removeClient(const Client_p &client); 52 void removeClient(const Client_p &client);
52 53
53 void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); 54 void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender);
54 - void giveClientRetainedMessage(Client_p &client, const std::string &topic); 55 + void giveClientRetainedMessages(Client_p &client, const std::string &subscribe_topic);
55 56
56 void setRetainedMessage(const std::string &topic, const std::string &payload, char qos); 57 void setRetainedMessage(const std::string &topic, const std::string &payload, char qos);
57 }; 58 };
utils.cpp
@@ -18,3 +18,33 @@ std::list&lt;std::__cxx11::string&gt; split(const std::string &amp;input, const char sep, @@ -18,3 +18,33 @@ std::list&lt;std::__cxx11::string&gt; split(const std::string &amp;input, const char sep,
18 list.push_back(input.substr(start, std::string::npos)); 18 list.push_back(input.substr(start, std::string::npos));
19 return list; 19 return list;
20 } 20 }
  21 +
  22 +
  23 +bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTopic)
  24 +{
  25 + if (subscribeTopic.find("+") == std::string::npos && subscribeTopic.find("#") == std::string::npos)
  26 + return subscribeTopic == publishTopic;
  27 +
  28 + const std::list<std::string> subscribeParts = split(subscribeTopic, '/');
  29 + const std::list<std::string> publishParts = split(publishTopic, '/');
  30 +
  31 + auto subscribe_itr = subscribeParts.begin();
  32 + auto publish_itr = publishParts.begin();
  33 +
  34 + bool result = true;
  35 + while (subscribe_itr != subscribeParts.end() && publish_itr != publishParts.end())
  36 + {
  37 + const std::string &subscribe_subtopic = *subscribe_itr++;
  38 + const std::string &publish_subtopic = *publish_itr++;
  39 +
  40 + if (subscribe_subtopic == "+")
  41 + continue;
  42 + if (subscribe_subtopic == "#")
  43 + return true;
  44 + if (subscribe_subtopic != publish_subtopic)
  45 + return false;
  46 + }
  47 +
  48 + result = subscribe_itr == subscribeParts.end() && publish_itr == publishParts.end();
  49 + return result;
  50 +}
@@ -21,4 +21,6 @@ template&lt;typename T&gt; int check(int rc) @@ -21,4 +21,6 @@ template&lt;typename T&gt; int check(int rc)
21 21
22 std::list<std::string> split(const std::string &input, const char sep, size_t max = std::numeric_limits<int>::max(), bool keep_empty_parts = true); 22 std::list<std::string> split(const std::string &input, const char sep, size_t max = std::numeric_limits<int>::max(), bool keep_empty_parts = true);
23 23
  24 +bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTopic);
  25 +
24 #endif // UTILS_H 26 #endif // UTILS_H