Commit 510ac00769b66ba0e53fb4ce7a0d3c70a13b34f9

Authored by Patric Stout
1 parent b3d744fe

feat(client): call the correct callback when receiving messages

Subscriptions are now stored in a tree-like structure, to quickly
find the correct callbacks. This not only reduces the complexity
from O(n) to O(logn), but also doesn't require stuff like regex.

It does however require slightly more memory.
README.md
1 # TrueMQTT - A modern C++ MQTT Client library 1 # TrueMQTT - A modern C++ MQTT Client library
2 2
3 -This project is currently a Work In Progress, and is not functional. 3 +This project is currently a Work In Progress.
  4 +Although the basics are functional, it is untested.
4 5
5 ## Development 6 ## Development
6 7
example/pubsub/main.cpp
@@ -21,20 +21,34 @@ int main() @@ -21,20 +21,34 @@ int main()
21 { std::cout << "Error " << error << ": " << message << std::endl; }); 21 { std::cout << "Error " << error << ": " << message << std::endl; });
22 22
23 client.connect(); 23 client.connect();
  24 + std::this_thread::sleep_for(std::chrono::milliseconds(100));
24 25
25 - bool stop = false; 26 + int stop = 0;
26 27
27 // Subscribe to the topic we will be publishing under in a bit. 28 // Subscribe to the topic we will be publishing under in a bit.
28 - client.subscribe("test", [&stop](const std::string &topic, const std::string &payload) 29 + client.subscribe("test/test/test", [&stop](const std::string topic, const std::string payload)
29 { 30 {
30 - std::cout << "Received message on topic " << topic << ": " << payload << std::endl;  
31 - stop = true; }); 31 + std::cout << "Received message on exact topic " << topic << ": " << payload << std::endl;
  32 + stop++; });
  33 + client.subscribe("test/+/test", [&stop](const std::string topic, const std::string payload)
  34 + {
  35 + std::cout << "Received message on single wildcard topic " << topic << ": " << payload << std::endl;
  36 + stop++; });
  37 + client.subscribe("test/#", [&stop](const std::string topic, const std::string payload)
  38 + {
  39 + std::cout << "Received message on multi wildcard topic " << topic << ": " << payload << std::endl;
  40 + stop++; });
  41 + client.subscribe("test/test/+", [&stop](const std::string topic, const std::string payload)
  42 + {
  43 + /* Never actually called */ });
  44 +
  45 + client.unsubscribe("test/test/+");
32 46
33 // Publish a message on the same topic as we subscribed too. 47 // Publish a message on the same topic as we subscribed too.
34 - client.publish("test", "Hello World!", false); 48 + client.publish("test/test/test", "Hello World!", false);
35 49
36 // Wait till we receive the message back on our subscription. 50 // Wait till we receive the message back on our subscription.
37 - while (!stop) 51 + while (stop != 3)
38 { 52 {
39 std::this_thread::sleep_for(std::chrono::milliseconds(10)); 53 std::this_thread::sleep_for(std::chrono::milliseconds(10));
40 } 54 }
src/Client.cpp
@@ -10,6 +10,8 @@ @@ -10,6 +10,8 @@
10 #include "ClientImpl.h" 10 #include "ClientImpl.h"
11 #include "Log.h" 11 #include "Log.h"
12 12
  13 +#include <sstream>
  14 +
13 using TrueMQTT::Client; 15 using TrueMQTT::Client;
14 16
15 Client::Client(const std::string &host, int port, const std::string &client_id, int connection_timeout, int connection_backoff_max, int keep_alive_interval) 17 Client::Client(const std::string &host, int port, const std::string &client_id, int connection_timeout, int connection_backoff_max, int keep_alive_interval)
@@ -135,8 +137,21 @@ void Client::subscribe(const std::string &amp;topic, std::function&lt;void(std::string, @@ -135,8 +137,21 @@ void Client::subscribe(const std::string &amp;topic, std::function&lt;void(std::string,
135 137
136 LOG_DEBUG(this->m_impl, "Subscribing to topic '" + topic + "'"); 138 LOG_DEBUG(this->m_impl, "Subscribing to topic '" + topic + "'");
137 139
138 - this->m_impl->subscriptions[topic] = callback; 140 + // Split the topic on /, to find each part.
  141 + std::string part;
  142 + std::stringstream stopic(topic);
  143 + std::getline(stopic, part, '/');
  144 +
  145 + // Find the root node, and walk down till we find the leaf node.
  146 + Client::Impl::SubscriptionPart *subscriptions = &this->m_impl->subscriptions.try_emplace(part, Client::Impl::SubscriptionPart()).first->second;
  147 + while (std::getline(stopic, part, '/'))
  148 + {
  149 + subscriptions = &subscriptions->children.try_emplace(part, Client::Impl::SubscriptionPart()).first->second;
  150 + }
  151 + // Add the callback to the leaf node.
  152 + subscriptions->callbacks.push_back(callback);
139 153
  154 + this->m_impl->subscription_topics.insert(topic);
140 if (this->m_impl->state == Client::Impl::State::CONNECTED) 155 if (this->m_impl->state == Client::Impl::State::CONNECTED)
141 { 156 {
142 this->m_impl->sendSubscribe(topic); 157 this->m_impl->sendSubscribe(topic);
@@ -155,8 +170,46 @@ void Client::unsubscribe(const std::string &amp;topic) @@ -155,8 +170,46 @@ void Client::unsubscribe(const std::string &amp;topic)
155 170
156 LOG_DEBUG(this->m_impl, "Unsubscribing from topic '" + topic + "'"); 171 LOG_DEBUG(this->m_impl, "Unsubscribing from topic '" + topic + "'");
157 172
158 - this->m_impl->subscriptions.erase(topic); 173 + // Split the topic on /, to find each part.
  174 + std::string part;
  175 + std::stringstream stopic(topic);
  176 + std::getline(stopic, part, '/');
159 177
  178 + // Find the root node, and walk down till we find the leaf node.
  179 + std::vector<std::tuple<std::string, Client::Impl::SubscriptionPart *>> reverse;
  180 + Client::Impl::SubscriptionPart *subscriptions = &this->m_impl->subscriptions[part];
  181 + reverse.push_back({part, subscriptions});
  182 + while (std::getline(stopic, part, '/'))
  183 + {
  184 + subscriptions = &subscriptions->children[part];
  185 + reverse.push_back({part, subscriptions});
  186 + }
  187 + // Clear the callbacks in the leaf node.
  188 + subscriptions->callbacks.clear();
  189 +
  190 + // Bookkeeping: remove any empty nodes.
  191 + // Otherwise we will slowly grow in memory if a user does a lot of unsubscribes
  192 + // on different topics.
  193 + std::string remove_next = "";
  194 + for (auto it = reverse.rbegin(); it != reverse.rend(); it++)
  195 + {
  196 + if (!remove_next.empty())
  197 + {
  198 + std::get<1>(*it)->children.erase(remove_next);
  199 + remove_next = "";
  200 + }
  201 +
  202 + if (std::get<1>(*it)->callbacks.empty() && std::get<1>(*it)->children.empty())
  203 + {
  204 + remove_next = std::get<0>(*it);
  205 + }
  206 + }
  207 + if (!remove_next.empty())
  208 + {
  209 + this->m_impl->subscriptions.erase(remove_next);
  210 + }
  211 +
  212 + this->m_impl->subscription_topics.erase(topic);
160 if (this->m_impl->state == Client::Impl::State::CONNECTED) 213 if (this->m_impl->state == Client::Impl::State::CONNECTED)
161 { 214 {
162 this->m_impl->sendUnsubscribe(topic); 215 this->m_impl->sendUnsubscribe(topic);
@@ -182,9 +235,9 @@ void Client::Impl::connectionStateChange(bool connected) @@ -182,9 +235,9 @@ void Client::Impl::connectionStateChange(bool connected)
182 // implementation. 235 // implementation.
183 236
184 // First restore any subscription. 237 // First restore any subscription.
185 - for (auto &subscription : this->subscriptions) 238 + for (auto &subscription : this->subscription_topics)
186 { 239 {
187 - this->sendSubscribe(subscription.first); 240 + this->sendSubscribe(subscription);
188 } 241 }
189 // Flush the publish queue. 242 // Flush the publish queue.
190 for (auto &message : this->publish_queue) 243 for (auto &message : this->publish_queue)
@@ -233,9 +286,80 @@ void Client::Impl::toPublishQueue(const std::string &amp;topic, const std::string &amp;p @@ -233,9 +286,80 @@ void Client::Impl::toPublishQueue(const std::string &amp;topic, const std::string &amp;p
233 this->publish_queue.push_back({topic, payload, retain}); 286 this->publish_queue.push_back({topic, payload, retain});
234 } 287 }
235 288
  289 +void Client::Impl::findSubscriptionMatch(std::vector<std::function<void(std::string, std::string)>> &matching_callbacks, const std::map<std::string, Client::Impl::SubscriptionPart> &subscriptions, std::deque<std::string> &parts)
  290 +{
  291 + // If we reached the end of the topic, do nothing anymore.
  292 + if (parts.size() == 0)
  293 + {
  294 + return;
  295 + }
  296 +
  297 + LOG_TRACE(this, "Finding subscription match for part '" + parts.front() + "'");
  298 +
  299 + // Find the match based on the part.
  300 + auto it = subscriptions.find(parts.front());
  301 + if (it != subscriptions.end())
  302 + {
  303 + LOG_TRACE(this, "Found subscription match for part '" + parts.front() + "' with " + std::to_string(it->second.callbacks.size()) + " callbacks");
  304 +
  305 + matching_callbacks.insert(matching_callbacks.end(), it->second.callbacks.begin(), it->second.callbacks.end());
  306 +
  307 + std::deque<std::string> remaining_parts(parts.begin() + 1, parts.end());
  308 + findSubscriptionMatch(matching_callbacks, it->second.children, remaining_parts);
  309 + }
  310 +
  311 + // Find the match if this part is a wildcard.
  312 + it = subscriptions.find("+");
  313 + if (it != subscriptions.end())
  314 + {
  315 + LOG_TRACE(this, "Found subscription match for '+' with " + std::to_string(it->second.callbacks.size()) + " callbacks");
  316 +
  317 + matching_callbacks.insert(matching_callbacks.end(), it->second.callbacks.begin(), it->second.callbacks.end());
  318 +
  319 + std::deque<std::string> remaining_parts(parts.begin() + 1, parts.end());
  320 + findSubscriptionMatch(matching_callbacks, it->second.children, remaining_parts);
  321 + }
  322 +
  323 + // Find the match if the remaining is a wildcard.
  324 + it = subscriptions.find("#");
  325 + if (it != subscriptions.end())
  326 + {
  327 + LOG_TRACE(this, "Found subscription match for '#' with " + std::to_string(it->second.callbacks.size()) + " callbacks");
  328 +
  329 + matching_callbacks.insert(matching_callbacks.end(), it->second.callbacks.begin(), it->second.callbacks.end());
  330 + // No more recursion here, as we implicit consume the rest of the parts too.
  331 + }
  332 +}
  333 +
236 void Client::Impl::messageReceived(std::string topic, std::string payload) 334 void Client::Impl::messageReceived(std::string topic, std::string payload)
237 { 335 {
238 LOG_TRACE(this, "Message received on topic '" + topic + "': " + payload); 336 LOG_TRACE(this, "Message received on topic '" + topic + "': " + payload);
239 337
240 - // TODO -- Find which subscriptions match, and call the callbacks. 338 + // Split the topic on the / in parts.
  339 + std::string part;
  340 + std::stringstream stopic(topic);
  341 + std::deque<std::string> parts;
  342 + while (std::getline(stopic, part, '/'))
  343 + {
  344 + parts.emplace_back(part);
  345 + }
  346 +
  347 + // Find the matching subscription(s) with recursion.
  348 + std::vector<std::function<void(std::string, std::string)>> matching_callbacks;
  349 + findSubscriptionMatch(matching_callbacks, subscriptions, parts);
  350 +
  351 + LOG_TRACE(this, "Found " + std::to_string(matching_callbacks.size()) + " subscription(s) for topic '" + topic + "'");
  352 +
  353 + if (matching_callbacks.size() == 1)
  354 + {
  355 + // For a single callback there is no need to copy the topic/payload.
  356 + matching_callbacks[0](std::move(topic), std::move(payload));
  357 + }
  358 + else
  359 + {
  360 + for (auto &callback : matching_callbacks)
  361 + {
  362 + callback(topic, payload);
  363 + }
  364 + }
241 } 365 }
src/ClientImpl.h
@@ -14,6 +14,7 @@ @@ -14,6 +14,7 @@
14 #include <deque> 14 #include <deque>
15 #include <map> 15 #include <map>
16 #include <mutex> 16 #include <mutex>
  17 +#include <set>
17 #include <string> 18 #include <string>
18 #include <thread> 19 #include <thread>
19 20
@@ -39,6 +40,13 @@ public: @@ -39,6 +40,13 @@ public:
39 CONNECTED, ///< The client is connected to the broker. 40 CONNECTED, ///< The client is connected to the broker.
40 }; 41 };
41 42
  43 + class SubscriptionPart
  44 + {
  45 + public:
  46 + std::map<std::string, SubscriptionPart> children;
  47 + std::vector<std::function<void(std::string, std::string)>> callbacks;
  48 + };
  49 +
42 void connect(); ///< Connect to the broker. 50 void connect(); ///< Connect to the broker.
43 void disconnect(); ///< Disconnect from the broker. 51 void disconnect(); ///< Disconnect from the broker.
44 void sendPublish(const std::string &topic, const std::string &payload, bool retain); ///< Send a publish message to the broker. 52 void sendPublish(const std::string &topic, const std::string &payload, bool retain); ///< Send a publish message to the broker.
@@ -46,7 +54,9 @@ public: @@ -46,7 +54,9 @@ public:
46 void sendUnsubscribe(const std::string &topic); ///< Send an unsubscribe message to the broker. 54 void sendUnsubscribe(const std::string &topic); ///< Send an unsubscribe message to the broker.
47 void connectionStateChange(bool connected); ///< Called when a connection goes from CONNECTING state to CONNECTED state or visa versa. 55 void connectionStateChange(bool connected); ///< Called when a connection goes from CONNECTING state to CONNECTED state or visa versa.
48 void toPublishQueue(const std::string &topic, const std::string &payload, bool retain); ///< Add a publish message to the publish queue. 56 void toPublishQueue(const std::string &topic, const std::string &payload, bool retain); ///< Add a publish message to the publish queue.
49 - void messageReceived(std::string topic, std::string payload); ///< Called when a message is received from the broker. 57 + void messageReceived(std::string topic, std::string payload); ///< Called when a message is received from the broker.
  58 +
  59 + void findSubscriptionMatch(std::vector<std::function<void(std::string, std::string)>> &callbacks, const std::map<std::string, SubscriptionPart> &subscriptions, std::deque<std::string> &parts); ///< Recursive function to find any matching subscription based on parts.
50 60
51 State state = State::DISCONNECTED; ///< The current state of the client. 61 State state = State::DISCONNECTED; ///< The current state of the client.
52 std::mutex state_mutex; ///< Mutex to protect state changes. 62 std::mutex state_mutex; ///< Mutex to protect state changes.
@@ -71,7 +81,8 @@ public: @@ -71,7 +81,8 @@ public:
71 size_t publish_queue_size = -1; ///< Size of the publish queue. 81 size_t publish_queue_size = -1; ///< Size of the publish queue.
72 std::deque<std::tuple<std::string, std::string, bool>> publish_queue; ///< Queue of publish messages to send to the broker. 82 std::deque<std::tuple<std::string, std::string, bool>> publish_queue; ///< Queue of publish messages to send to the broker.
73 83
74 - std::map<std::string, std::function<void(std::string, std::string)>> subscriptions; ///< Map of active subscriptions. 84 + std::set<std::string> subscription_topics; ///< Flat list of topics the client is subscribed to.
  85 + std::map<std::string, SubscriptionPart> subscriptions; ///< Tree of active subscriptions build up from the parts on the topic.
75 86
76 std::unique_ptr<Connection> connection; ///< Connection to the broker. 87 std::unique_ptr<Connection> connection; ///< Connection to the broker.
77 uint16_t packet_id = 0; ///< The next packet ID to use. Will overflow on 65535 to 0. 88 uint16_t packet_id = 0; ///< The next packet ID to use. Will overflow on 65535 to 0.
src/Connection.cpp
@@ -36,6 +36,8 @@ Connection::Connection(TrueMQTT::Client::LogLevel log_level, @@ -36,6 +36,8 @@ Connection::Connection(TrueMQTT::Client::LogLevel log_level,
36 36
37 Connection::~Connection() 37 Connection::~Connection()
38 { 38 {
  39 + m_state = State::STOP;
  40 +
39 // Make sure the connection thread is terminated. 41 // Make sure the connection thread is terminated.
40 if (m_thread.joinable()) 42 if (m_thread.joinable())
41 { 43 {
@@ -90,6 +92,10 @@ void Connection::run() @@ -90,6 +92,10 @@ void Connection::run()
90 { 92 {
91 if (!recvLoop()) 93 if (!recvLoop())
92 { 94 {
  95 + if (m_state == State::STOP)
  96 + {
  97 + break;
  98 + }
93 if (m_socket != INVALID_SOCKET) 99 if (m_socket != INVALID_SOCKET)
94 { 100 {
95 closesocket(m_socket); 101 closesocket(m_socket);
@@ -100,6 +106,9 @@ void Connection::run() @@ -100,6 +106,9 @@ void Connection::run()
100 } 106 }
101 break; 107 break;
102 } 108 }
  109 +
  110 + case State::STOP:
  111 + return;
103 } 112 }
104 } 113 }
105 } 114 }
src/Connection.h
@@ -55,6 +55,7 @@ private: @@ -55,6 +55,7 @@ private:
55 AUTHENTICATING, 55 AUTHENTICATING,
56 CONNECTED, 56 CONNECTED,
57 BACKOFF, 57 BACKOFF,
  58 + STOP,
58 }; 59 };
59 60
60 TrueMQTT::Client::LogLevel log_level; 61 TrueMQTT::Client::LogLevel log_level;
src/Packet.cpp
@@ -123,6 +123,29 @@ public: @@ -123,6 +123,29 @@ public:
123 123
124 ssize_t Connection::recv(char *buffer, size_t length) 124 ssize_t Connection::recv(char *buffer, size_t length)
125 { 125 {
  126 + // We idle-check every 100ms if we are requested to stop, as otherwise
  127 + // this thread will block till the server disconnects us.
  128 + while (m_state != State::STOP)
  129 + {
  130 + // Check if there is any data available on the socket.
  131 + fd_set read_fds;
  132 + FD_ZERO(&read_fds);
  133 + FD_SET(m_socket, &read_fds);
  134 + timeval timeout = {0, 100};
  135 + size_t ret = select(m_socket + 1, &read_fds, nullptr, nullptr, &timeout);
  136 +
  137 + if (ret == 0)
  138 + {
  139 + continue;
  140 + }
  141 + break;
  142 + }
  143 + if (m_state == State::STOP)
  144 + {
  145 + LOG_TRACE(this, "Closing connection as STOP has been requested");
  146 + return -1;
  147 + }
  148 +
126 ssize_t res = ::recv(m_socket, buffer, length, 0); 149 ssize_t res = ::recv(m_socket, buffer, length, 0);
127 if (res == 0) 150 if (res == 0)
128 { 151 {
@@ -277,6 +300,20 @@ bool Connection::recvLoop() @@ -277,6 +300,20 @@ bool Connection::recvLoop()
277 300
278 break; 301 break;
279 } 302 }
  303 + case Packet::PacketType::UNSUBACK:
  304 + {
  305 + uint16_t packet_id;
  306 +
  307 + if (!packet.read_uint16(packet_id))
  308 + {
  309 + LOG_ERROR(this, "Malformed packet received, closing connection");
  310 + return false;
  311 + }
  312 +
  313 + LOG_DEBUG(this, "Received UNSUBACK with packet id " + std::to_string(packet_id));
  314 +
  315 + break;
  316 + }
280 default: 317 default:
281 LOG_ERROR(this, "Received unexpected packet type " + std::string(magic_enum::enum_name(packet_type)) + " from broker, closing connection"); 318 LOG_ERROR(this, "Received unexpected packet type " + std::string(magic_enum::enum_name(packet_type)) + " from broker, closing connection");
282 return false; 319 return false;
@@ -384,4 +421,17 @@ void TrueMQTT::Client::Impl::sendSubscribe(const std::string &amp;topic) @@ -384,4 +421,17 @@ void TrueMQTT::Client::Impl::sendSubscribe(const std::string &amp;topic)
384 void TrueMQTT::Client::Impl::sendUnsubscribe(const std::string &topic) 421 void TrueMQTT::Client::Impl::sendUnsubscribe(const std::string &topic)
385 { 422 {
386 LOG_TRACE(this, "Sending unsubscribe message for topic '" + topic + "'"); 423 LOG_TRACE(this, "Sending unsubscribe message for topic '" + topic + "'");
  424 +
  425 + Packet packet(Packet::PacketType::UNSUBSCRIBE, 2);
  426 +
  427 + // By specs, packet-id zero is not allowed.
  428 + if (packet_id == 0)
  429 + {
  430 + packet_id++;
  431 + }
  432 +
  433 + packet.write_uint16(packet_id++);
  434 + packet.write_string(topic);
  435 +
  436 + connection->send(packet);
387 } 437 }