Commit 510ac00769b66ba0e53fb4ce7a0d3c70a13b34f9

Authored by Patric Stout
1 parent b3d744fe

feat(client): call the correct callback when receiving messages

Subscriptions are now stored in a tree-like structure, to quickly
find the correct callbacks. This not only reduces the complexity
from O(n) to O(logn), but also doesn't require stuff like regex.

It does however require slightly more memory.
README.md
1 1 # TrueMQTT - A modern C++ MQTT Client library
2 2  
3   -This project is currently a Work In Progress, and is not functional.
  3 +This project is currently a Work In Progress.
  4 +Although the basics are functional, it is untested.
4 5  
5 6 ## Development
6 7  
... ...
example/pubsub/main.cpp
... ... @@ -21,20 +21,34 @@ int main()
21 21 { std::cout << "Error " << error << ": " << message << std::endl; });
22 22  
23 23 client.connect();
  24 + std::this_thread::sleep_for(std::chrono::milliseconds(100));
24 25  
25   - bool stop = false;
  26 + int stop = 0;
26 27  
27 28 // Subscribe to the topic we will be publishing under in a bit.
28   - client.subscribe("test", [&stop](const std::string &topic, const std::string &payload)
  29 + client.subscribe("test/test/test", [&stop](const std::string topic, const std::string payload)
29 30 {
30   - std::cout << "Received message on topic " << topic << ": " << payload << std::endl;
31   - stop = true; });
  31 + std::cout << "Received message on exact topic " << topic << ": " << payload << std::endl;
  32 + stop++; });
  33 + client.subscribe("test/+/test", [&stop](const std::string topic, const std::string payload)
  34 + {
  35 + std::cout << "Received message on single wildcard topic " << topic << ": " << payload << std::endl;
  36 + stop++; });
  37 + client.subscribe("test/#", [&stop](const std::string topic, const std::string payload)
  38 + {
  39 + std::cout << "Received message on multi wildcard topic " << topic << ": " << payload << std::endl;
  40 + stop++; });
  41 + client.subscribe("test/test/+", [&stop](const std::string topic, const std::string payload)
  42 + {
  43 + /* Never actually called */ });
  44 +
  45 + client.unsubscribe("test/test/+");
32 46  
33 47 // Publish a message on the same topic as we subscribed too.
34   - client.publish("test", "Hello World!", false);
  48 + client.publish("test/test/test", "Hello World!", false);
35 49  
36 50 // Wait till we receive the message back on our subscription.
37   - while (!stop)
  51 + while (stop != 3)
38 52 {
39 53 std::this_thread::sleep_for(std::chrono::milliseconds(10));
40 54 }
... ...
src/Client.cpp
... ... @@ -10,6 +10,8 @@
10 10 #include "ClientImpl.h"
11 11 #include "Log.h"
12 12  
  13 +#include <sstream>
  14 +
13 15 using TrueMQTT::Client;
14 16  
15 17 Client::Client(const std::string &host, int port, const std::string &client_id, int connection_timeout, int connection_backoff_max, int keep_alive_interval)
... ... @@ -135,8 +137,21 @@ void Client::subscribe(const std::string &amp;topic, std::function&lt;void(std::string,
135 137  
136 138 LOG_DEBUG(this->m_impl, "Subscribing to topic '" + topic + "'");
137 139  
138   - this->m_impl->subscriptions[topic] = callback;
  140 + // Split the topic on /, to find each part.
  141 + std::string part;
  142 + std::stringstream stopic(topic);
  143 + std::getline(stopic, part, '/');
  144 +
  145 + // Find the root node, and walk down till we find the leaf node.
  146 + Client::Impl::SubscriptionPart *subscriptions = &this->m_impl->subscriptions.try_emplace(part, Client::Impl::SubscriptionPart()).first->second;
  147 + while (std::getline(stopic, part, '/'))
  148 + {
  149 + subscriptions = &subscriptions->children.try_emplace(part, Client::Impl::SubscriptionPart()).first->second;
  150 + }
  151 + // Add the callback to the leaf node.
  152 + subscriptions->callbacks.push_back(callback);
139 153  
  154 + this->m_impl->subscription_topics.insert(topic);
140 155 if (this->m_impl->state == Client::Impl::State::CONNECTED)
141 156 {
142 157 this->m_impl->sendSubscribe(topic);
... ... @@ -155,8 +170,46 @@ void Client::unsubscribe(const std::string &amp;topic)
155 170  
156 171 LOG_DEBUG(this->m_impl, "Unsubscribing from topic '" + topic + "'");
157 172  
158   - this->m_impl->subscriptions.erase(topic);
  173 + // Split the topic on /, to find each part.
  174 + std::string part;
  175 + std::stringstream stopic(topic);
  176 + std::getline(stopic, part, '/');
159 177  
  178 + // Find the root node, and walk down till we find the leaf node.
  179 + std::vector<std::tuple<std::string, Client::Impl::SubscriptionPart *>> reverse;
  180 + Client::Impl::SubscriptionPart *subscriptions = &this->m_impl->subscriptions[part];
  181 + reverse.push_back({part, subscriptions});
  182 + while (std::getline(stopic, part, '/'))
  183 + {
  184 + subscriptions = &subscriptions->children[part];
  185 + reverse.push_back({part, subscriptions});
  186 + }
  187 + // Clear the callbacks in the leaf node.
  188 + subscriptions->callbacks.clear();
  189 +
  190 + // Bookkeeping: remove any empty nodes.
  191 + // Otherwise we will slowly grow in memory if a user does a lot of unsubscribes
  192 + // on different topics.
  193 + std::string remove_next = "";
  194 + for (auto it = reverse.rbegin(); it != reverse.rend(); it++)
  195 + {
  196 + if (!remove_next.empty())
  197 + {
  198 + std::get<1>(*it)->children.erase(remove_next);
  199 + remove_next = "";
  200 + }
  201 +
  202 + if (std::get<1>(*it)->callbacks.empty() && std::get<1>(*it)->children.empty())
  203 + {
  204 + remove_next = std::get<0>(*it);
  205 + }
  206 + }
  207 + if (!remove_next.empty())
  208 + {
  209 + this->m_impl->subscriptions.erase(remove_next);
  210 + }
  211 +
  212 + this->m_impl->subscription_topics.erase(topic);
160 213 if (this->m_impl->state == Client::Impl::State::CONNECTED)
161 214 {
162 215 this->m_impl->sendUnsubscribe(topic);
... ... @@ -182,9 +235,9 @@ void Client::Impl::connectionStateChange(bool connected)
182 235 // implementation.
183 236  
184 237 // First restore any subscription.
185   - for (auto &subscription : this->subscriptions)
  238 + for (auto &subscription : this->subscription_topics)
186 239 {
187   - this->sendSubscribe(subscription.first);
  240 + this->sendSubscribe(subscription);
188 241 }
189 242 // Flush the publish queue.
190 243 for (auto &message : this->publish_queue)
... ... @@ -233,9 +286,80 @@ void Client::Impl::toPublishQueue(const std::string &amp;topic, const std::string &amp;p
233 286 this->publish_queue.push_back({topic, payload, retain});
234 287 }
235 288  
  289 +void Client::Impl::findSubscriptionMatch(std::vector<std::function<void(std::string, std::string)>> &matching_callbacks, const std::map<std::string, Client::Impl::SubscriptionPart> &subscriptions, std::deque<std::string> &parts)
  290 +{
  291 + // If we reached the end of the topic, do nothing anymore.
  292 + if (parts.size() == 0)
  293 + {
  294 + return;
  295 + }
  296 +
  297 + LOG_TRACE(this, "Finding subscription match for part '" + parts.front() + "'");
  298 +
  299 + // Find the match based on the part.
  300 + auto it = subscriptions.find(parts.front());
  301 + if (it != subscriptions.end())
  302 + {
  303 + LOG_TRACE(this, "Found subscription match for part '" + parts.front() + "' with " + std::to_string(it->second.callbacks.size()) + " callbacks");
  304 +
  305 + matching_callbacks.insert(matching_callbacks.end(), it->second.callbacks.begin(), it->second.callbacks.end());
  306 +
  307 + std::deque<std::string> remaining_parts(parts.begin() + 1, parts.end());
  308 + findSubscriptionMatch(matching_callbacks, it->second.children, remaining_parts);
  309 + }
  310 +
  311 + // Find the match if this part is a wildcard.
  312 + it = subscriptions.find("+");
  313 + if (it != subscriptions.end())
  314 + {
  315 + LOG_TRACE(this, "Found subscription match for '+' with " + std::to_string(it->second.callbacks.size()) + " callbacks");
  316 +
  317 + matching_callbacks.insert(matching_callbacks.end(), it->second.callbacks.begin(), it->second.callbacks.end());
  318 +
  319 + std::deque<std::string> remaining_parts(parts.begin() + 1, parts.end());
  320 + findSubscriptionMatch(matching_callbacks, it->second.children, remaining_parts);
  321 + }
  322 +
  323 + // Find the match if the remaining is a wildcard.
  324 + it = subscriptions.find("#");
  325 + if (it != subscriptions.end())
  326 + {
  327 + LOG_TRACE(this, "Found subscription match for '#' with " + std::to_string(it->second.callbacks.size()) + " callbacks");
  328 +
  329 + matching_callbacks.insert(matching_callbacks.end(), it->second.callbacks.begin(), it->second.callbacks.end());
  330 + // No more recursion here, as we implicit consume the rest of the parts too.
  331 + }
  332 +}
  333 +
236 334 void Client::Impl::messageReceived(std::string topic, std::string payload)
237 335 {
238 336 LOG_TRACE(this, "Message received on topic '" + topic + "': " + payload);
239 337  
240   - // TODO -- Find which subscriptions match, and call the callbacks.
  338 + // Split the topic on the / in parts.
  339 + std::string part;
  340 + std::stringstream stopic(topic);
  341 + std::deque<std::string> parts;
  342 + while (std::getline(stopic, part, '/'))
  343 + {
  344 + parts.emplace_back(part);
  345 + }
  346 +
  347 + // Find the matching subscription(s) with recursion.
  348 + std::vector<std::function<void(std::string, std::string)>> matching_callbacks;
  349 + findSubscriptionMatch(matching_callbacks, subscriptions, parts);
  350 +
  351 + LOG_TRACE(this, "Found " + std::to_string(matching_callbacks.size()) + " subscription(s) for topic '" + topic + "'");
  352 +
  353 + if (matching_callbacks.size() == 1)
  354 + {
  355 + // For a single callback there is no need to copy the topic/payload.
  356 + matching_callbacks[0](std::move(topic), std::move(payload));
  357 + }
  358 + else
  359 + {
  360 + for (auto &callback : matching_callbacks)
  361 + {
  362 + callback(topic, payload);
  363 + }
  364 + }
241 365 }
... ...
src/ClientImpl.h
... ... @@ -14,6 +14,7 @@
14 14 #include <deque>
15 15 #include <map>
16 16 #include <mutex>
  17 +#include <set>
17 18 #include <string>
18 19 #include <thread>
19 20  
... ... @@ -39,6 +40,13 @@ public:
39 40 CONNECTED, ///< The client is connected to the broker.
40 41 };
41 42  
  43 + class SubscriptionPart
  44 + {
  45 + public:
  46 + std::map<std::string, SubscriptionPart> children;
  47 + std::vector<std::function<void(std::string, std::string)>> callbacks;
  48 + };
  49 +
42 50 void connect(); ///< Connect to the broker.
43 51 void disconnect(); ///< Disconnect from the broker.
44 52 void sendPublish(const std::string &topic, const std::string &payload, bool retain); ///< Send a publish message to the broker.
... ... @@ -46,7 +54,9 @@ public:
46 54 void sendUnsubscribe(const std::string &topic); ///< Send an unsubscribe message to the broker.
47 55 void connectionStateChange(bool connected); ///< Called when a connection goes from CONNECTING state to CONNECTED state or visa versa.
48 56 void toPublishQueue(const std::string &topic, const std::string &payload, bool retain); ///< Add a publish message to the publish queue.
49   - void messageReceived(std::string topic, std::string payload); ///< Called when a message is received from the broker.
  57 + void messageReceived(std::string topic, std::string payload); ///< Called when a message is received from the broker.
  58 +
  59 + void findSubscriptionMatch(std::vector<std::function<void(std::string, std::string)>> &callbacks, const std::map<std::string, SubscriptionPart> &subscriptions, std::deque<std::string> &parts); ///< Recursive function to find any matching subscription based on parts.
50 60  
51 61 State state = State::DISCONNECTED; ///< The current state of the client.
52 62 std::mutex state_mutex; ///< Mutex to protect state changes.
... ... @@ -71,7 +81,8 @@ public:
71 81 size_t publish_queue_size = -1; ///< Size of the publish queue.
72 82 std::deque<std::tuple<std::string, std::string, bool>> publish_queue; ///< Queue of publish messages to send to the broker.
73 83  
74   - std::map<std::string, std::function<void(std::string, std::string)>> subscriptions; ///< Map of active subscriptions.
  84 + std::set<std::string> subscription_topics; ///< Flat list of topics the client is subscribed to.
  85 + std::map<std::string, SubscriptionPart> subscriptions; ///< Tree of active subscriptions build up from the parts on the topic.
75 86  
76 87 std::unique_ptr<Connection> connection; ///< Connection to the broker.
77 88 uint16_t packet_id = 0; ///< The next packet ID to use. Will overflow on 65535 to 0.
... ...
src/Connection.cpp
... ... @@ -36,6 +36,8 @@ Connection::Connection(TrueMQTT::Client::LogLevel log_level,
36 36  
37 37 Connection::~Connection()
38 38 {
  39 + m_state = State::STOP;
  40 +
39 41 // Make sure the connection thread is terminated.
40 42 if (m_thread.joinable())
41 43 {
... ... @@ -90,6 +92,10 @@ void Connection::run()
90 92 {
91 93 if (!recvLoop())
92 94 {
  95 + if (m_state == State::STOP)
  96 + {
  97 + break;
  98 + }
93 99 if (m_socket != INVALID_SOCKET)
94 100 {
95 101 closesocket(m_socket);
... ... @@ -100,6 +106,9 @@ void Connection::run()
100 106 }
101 107 break;
102 108 }
  109 +
  110 + case State::STOP:
  111 + return;
103 112 }
104 113 }
105 114 }
... ...
src/Connection.h
... ... @@ -55,6 +55,7 @@ private:
55 55 AUTHENTICATING,
56 56 CONNECTED,
57 57 BACKOFF,
  58 + STOP,
58 59 };
59 60  
60 61 TrueMQTT::Client::LogLevel log_level;
... ...
src/Packet.cpp
... ... @@ -123,6 +123,29 @@ public:
123 123  
124 124 ssize_t Connection::recv(char *buffer, size_t length)
125 125 {
  126 + // We idle-check every 100ms if we are requested to stop, as otherwise
  127 + // this thread will block till the server disconnects us.
  128 + while (m_state != State::STOP)
  129 + {
  130 + // Check if there is any data available on the socket.
  131 + fd_set read_fds;
  132 + FD_ZERO(&read_fds);
  133 + FD_SET(m_socket, &read_fds);
  134 + timeval timeout = {0, 100};
  135 + size_t ret = select(m_socket + 1, &read_fds, nullptr, nullptr, &timeout);
  136 +
  137 + if (ret == 0)
  138 + {
  139 + continue;
  140 + }
  141 + break;
  142 + }
  143 + if (m_state == State::STOP)
  144 + {
  145 + LOG_TRACE(this, "Closing connection as STOP has been requested");
  146 + return -1;
  147 + }
  148 +
126 149 ssize_t res = ::recv(m_socket, buffer, length, 0);
127 150 if (res == 0)
128 151 {
... ... @@ -277,6 +300,20 @@ bool Connection::recvLoop()
277 300  
278 301 break;
279 302 }
  303 + case Packet::PacketType::UNSUBACK:
  304 + {
  305 + uint16_t packet_id;
  306 +
  307 + if (!packet.read_uint16(packet_id))
  308 + {
  309 + LOG_ERROR(this, "Malformed packet received, closing connection");
  310 + return false;
  311 + }
  312 +
  313 + LOG_DEBUG(this, "Received UNSUBACK with packet id " + std::to_string(packet_id));
  314 +
  315 + break;
  316 + }
280 317 default:
281 318 LOG_ERROR(this, "Received unexpected packet type " + std::string(magic_enum::enum_name(packet_type)) + " from broker, closing connection");
282 319 return false;
... ... @@ -384,4 +421,17 @@ void TrueMQTT::Client::Impl::sendSubscribe(const std::string &amp;topic)
384 421 void TrueMQTT::Client::Impl::sendUnsubscribe(const std::string &topic)
385 422 {
386 423 LOG_TRACE(this, "Sending unsubscribe message for topic '" + topic + "'");
  424 +
  425 + Packet packet(Packet::PacketType::UNSUBSCRIBE, 2);
  426 +
  427 + // By specs, packet-id zero is not allowed.
  428 + if (packet_id == 0)
  429 + {
  430 + packet_id++;
  431 + }
  432 +
  433 + packet.write_uint16(packet_id++);
  434 + packet.write_string(topic);
  435 +
  436 + connection->send(packet);
387 437 }
... ...