From 510ac00769b66ba0e53fb4ce7a0d3c70a13b34f9 Mon Sep 17 00:00:00 2001 From: Patric Stout Date: Sun, 11 Sep 2022 14:50:30 +0200 Subject: [PATCH] feat(client): call the correct callback when receiving messages --- README.md | 3 ++- example/pubsub/main.cpp | 26 ++++++++++++++++++++------ src/Client.cpp | 134 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----- src/ClientImpl.h | 15 +++++++++++++-- src/Connection.cpp | 9 +++++++++ src/Connection.h | 1 + src/Packet.cpp | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 224 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index f013736..a0ebbc1 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # TrueMQTT - A modern C++ MQTT Client library -This project is currently a Work In Progress, and is not functional. +This project is currently a Work In Progress. +Although the basics are functional, it is untested. ## Development diff --git a/example/pubsub/main.cpp b/example/pubsub/main.cpp index eab126f..30cd041 100644 --- a/example/pubsub/main.cpp +++ b/example/pubsub/main.cpp @@ -21,20 +21,34 @@ int main() { std::cout << "Error " << error << ": " << message << std::endl; }); client.connect(); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); - bool stop = false; + int stop = 0; // Subscribe to the topic we will be publishing under in a bit. - client.subscribe("test", [&stop](const std::string &topic, const std::string &payload) + client.subscribe("test/test/test", [&stop](const std::string topic, const std::string payload) { - std::cout << "Received message on topic " << topic << ": " << payload << std::endl; - stop = true; }); + std::cout << "Received message on exact topic " << topic << ": " << payload << std::endl; + stop++; }); + client.subscribe("test/+/test", [&stop](const std::string topic, const std::string payload) + { + std::cout << "Received message on single wildcard topic " << topic << ": " << payload << std::endl; + stop++; }); + client.subscribe("test/#", [&stop](const std::string topic, const std::string payload) + { + std::cout << "Received message on multi wildcard topic " << topic << ": " << payload << std::endl; + stop++; }); + client.subscribe("test/test/+", [&stop](const std::string topic, const std::string payload) + { + /* Never actually called */ }); + + client.unsubscribe("test/test/+"); // Publish a message on the same topic as we subscribed too. - client.publish("test", "Hello World!", false); + client.publish("test/test/test", "Hello World!", false); // Wait till we receive the message back on our subscription. - while (!stop) + while (stop != 3) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } diff --git a/src/Client.cpp b/src/Client.cpp index 8a49641..dc322c7 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -10,6 +10,8 @@ #include "ClientImpl.h" #include "Log.h" +#include + using TrueMQTT::Client; Client::Client(const std::string &host, int port, const std::string &client_id, int connection_timeout, int connection_backoff_max, int keep_alive_interval) @@ -135,8 +137,21 @@ void Client::subscribe(const std::string &topic, std::functionm_impl, "Subscribing to topic '" + topic + "'"); - this->m_impl->subscriptions[topic] = callback; + // Split the topic on /, to find each part. + std::string part; + std::stringstream stopic(topic); + std::getline(stopic, part, '/'); + + // Find the root node, and walk down till we find the leaf node. + Client::Impl::SubscriptionPart *subscriptions = &this->m_impl->subscriptions.try_emplace(part, Client::Impl::SubscriptionPart()).first->second; + while (std::getline(stopic, part, '/')) + { + subscriptions = &subscriptions->children.try_emplace(part, Client::Impl::SubscriptionPart()).first->second; + } + // Add the callback to the leaf node. + subscriptions->callbacks.push_back(callback); + this->m_impl->subscription_topics.insert(topic); if (this->m_impl->state == Client::Impl::State::CONNECTED) { this->m_impl->sendSubscribe(topic); @@ -155,8 +170,46 @@ void Client::unsubscribe(const std::string &topic) LOG_DEBUG(this->m_impl, "Unsubscribing from topic '" + topic + "'"); - this->m_impl->subscriptions.erase(topic); + // Split the topic on /, to find each part. + std::string part; + std::stringstream stopic(topic); + std::getline(stopic, part, '/'); + // Find the root node, and walk down till we find the leaf node. + std::vector> reverse; + Client::Impl::SubscriptionPart *subscriptions = &this->m_impl->subscriptions[part]; + reverse.push_back({part, subscriptions}); + while (std::getline(stopic, part, '/')) + { + subscriptions = &subscriptions->children[part]; + reverse.push_back({part, subscriptions}); + } + // Clear the callbacks in the leaf node. + subscriptions->callbacks.clear(); + + // Bookkeeping: remove any empty nodes. + // Otherwise we will slowly grow in memory if a user does a lot of unsubscribes + // on different topics. + std::string remove_next = ""; + for (auto it = reverse.rbegin(); it != reverse.rend(); it++) + { + if (!remove_next.empty()) + { + std::get<1>(*it)->children.erase(remove_next); + remove_next = ""; + } + + if (std::get<1>(*it)->callbacks.empty() && std::get<1>(*it)->children.empty()) + { + remove_next = std::get<0>(*it); + } + } + if (!remove_next.empty()) + { + this->m_impl->subscriptions.erase(remove_next); + } + + this->m_impl->subscription_topics.erase(topic); if (this->m_impl->state == Client::Impl::State::CONNECTED) { this->m_impl->sendUnsubscribe(topic); @@ -182,9 +235,9 @@ void Client::Impl::connectionStateChange(bool connected) // implementation. // First restore any subscription. - for (auto &subscription : this->subscriptions) + for (auto &subscription : this->subscription_topics) { - this->sendSubscribe(subscription.first); + this->sendSubscribe(subscription); } // Flush the publish queue. for (auto &message : this->publish_queue) @@ -233,9 +286,80 @@ void Client::Impl::toPublishQueue(const std::string &topic, const std::string &p this->publish_queue.push_back({topic, payload, retain}); } +void Client::Impl::findSubscriptionMatch(std::vector> &matching_callbacks, const std::map &subscriptions, std::deque &parts) +{ + // If we reached the end of the topic, do nothing anymore. + if (parts.size() == 0) + { + return; + } + + LOG_TRACE(this, "Finding subscription match for part '" + parts.front() + "'"); + + // Find the match based on the part. + auto it = subscriptions.find(parts.front()); + if (it != subscriptions.end()) + { + LOG_TRACE(this, "Found subscription match for part '" + parts.front() + "' with " + std::to_string(it->second.callbacks.size()) + " callbacks"); + + matching_callbacks.insert(matching_callbacks.end(), it->second.callbacks.begin(), it->second.callbacks.end()); + + std::deque remaining_parts(parts.begin() + 1, parts.end()); + findSubscriptionMatch(matching_callbacks, it->second.children, remaining_parts); + } + + // Find the match if this part is a wildcard. + it = subscriptions.find("+"); + if (it != subscriptions.end()) + { + LOG_TRACE(this, "Found subscription match for '+' with " + std::to_string(it->second.callbacks.size()) + " callbacks"); + + matching_callbacks.insert(matching_callbacks.end(), it->second.callbacks.begin(), it->second.callbacks.end()); + + std::deque remaining_parts(parts.begin() + 1, parts.end()); + findSubscriptionMatch(matching_callbacks, it->second.children, remaining_parts); + } + + // Find the match if the remaining is a wildcard. + it = subscriptions.find("#"); + if (it != subscriptions.end()) + { + LOG_TRACE(this, "Found subscription match for '#' with " + std::to_string(it->second.callbacks.size()) + " callbacks"); + + matching_callbacks.insert(matching_callbacks.end(), it->second.callbacks.begin(), it->second.callbacks.end()); + // No more recursion here, as we implicit consume the rest of the parts too. + } +} + void Client::Impl::messageReceived(std::string topic, std::string payload) { LOG_TRACE(this, "Message received on topic '" + topic + "': " + payload); - // TODO -- Find which subscriptions match, and call the callbacks. + // Split the topic on the / in parts. + std::string part; + std::stringstream stopic(topic); + std::deque parts; + while (std::getline(stopic, part, '/')) + { + parts.emplace_back(part); + } + + // Find the matching subscription(s) with recursion. + std::vector> matching_callbacks; + findSubscriptionMatch(matching_callbacks, subscriptions, parts); + + LOG_TRACE(this, "Found " + std::to_string(matching_callbacks.size()) + " subscription(s) for topic '" + topic + "'"); + + if (matching_callbacks.size() == 1) + { + // For a single callback there is no need to copy the topic/payload. + matching_callbacks[0](std::move(topic), std::move(payload)); + } + else + { + for (auto &callback : matching_callbacks) + { + callback(topic, payload); + } + } } diff --git a/src/ClientImpl.h b/src/ClientImpl.h index 92baead..ebc6617 100644 --- a/src/ClientImpl.h +++ b/src/ClientImpl.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -39,6 +40,13 @@ public: CONNECTED, ///< The client is connected to the broker. }; + class SubscriptionPart + { + public: + std::map children; + std::vector> callbacks; + }; + void connect(); ///< Connect to the broker. void disconnect(); ///< Disconnect from the broker. void sendPublish(const std::string &topic, const std::string &payload, bool retain); ///< Send a publish message to the broker. @@ -46,7 +54,9 @@ public: void sendUnsubscribe(const std::string &topic); ///< Send an unsubscribe message to the broker. void connectionStateChange(bool connected); ///< Called when a connection goes from CONNECTING state to CONNECTED state or visa versa. void toPublishQueue(const std::string &topic, const std::string &payload, bool retain); ///< Add a publish message to the publish queue. - void messageReceived(std::string topic, std::string payload); ///< Called when a message is received from the broker. + void messageReceived(std::string topic, std::string payload); ///< Called when a message is received from the broker. + + void findSubscriptionMatch(std::vector> &callbacks, const std::map &subscriptions, std::deque &parts); ///< Recursive function to find any matching subscription based on parts. State state = State::DISCONNECTED; ///< The current state of the client. std::mutex state_mutex; ///< Mutex to protect state changes. @@ -71,7 +81,8 @@ public: size_t publish_queue_size = -1; ///< Size of the publish queue. std::deque> publish_queue; ///< Queue of publish messages to send to the broker. - std::map> subscriptions; ///< Map of active subscriptions. + std::set subscription_topics; ///< Flat list of topics the client is subscribed to. + std::map subscriptions; ///< Tree of active subscriptions build up from the parts on the topic. std::unique_ptr connection; ///< Connection to the broker. uint16_t packet_id = 0; ///< The next packet ID to use. Will overflow on 65535 to 0. diff --git a/src/Connection.cpp b/src/Connection.cpp index 14048de..e4e85f0 100644 --- a/src/Connection.cpp +++ b/src/Connection.cpp @@ -36,6 +36,8 @@ Connection::Connection(TrueMQTT::Client::LogLevel log_level, Connection::~Connection() { + m_state = State::STOP; + // Make sure the connection thread is terminated. if (m_thread.joinable()) { @@ -90,6 +92,10 @@ void Connection::run() { if (!recvLoop()) { + if (m_state == State::STOP) + { + break; + } if (m_socket != INVALID_SOCKET) { closesocket(m_socket); @@ -100,6 +106,9 @@ void Connection::run() } break; } + + case State::STOP: + return; } } } diff --git a/src/Connection.h b/src/Connection.h index 03b037a..e0398d3 100644 --- a/src/Connection.h +++ b/src/Connection.h @@ -55,6 +55,7 @@ private: AUTHENTICATING, CONNECTED, BACKOFF, + STOP, }; TrueMQTT::Client::LogLevel log_level; diff --git a/src/Packet.cpp b/src/Packet.cpp index 7d5b96d..93a2e9c 100644 --- a/src/Packet.cpp +++ b/src/Packet.cpp @@ -123,6 +123,29 @@ public: ssize_t Connection::recv(char *buffer, size_t length) { + // We idle-check every 100ms if we are requested to stop, as otherwise + // this thread will block till the server disconnects us. + while (m_state != State::STOP) + { + // Check if there is any data available on the socket. + fd_set read_fds; + FD_ZERO(&read_fds); + FD_SET(m_socket, &read_fds); + timeval timeout = {0, 100}; + size_t ret = select(m_socket + 1, &read_fds, nullptr, nullptr, &timeout); + + if (ret == 0) + { + continue; + } + break; + } + if (m_state == State::STOP) + { + LOG_TRACE(this, "Closing connection as STOP has been requested"); + return -1; + } + ssize_t res = ::recv(m_socket, buffer, length, 0); if (res == 0) { @@ -277,6 +300,20 @@ bool Connection::recvLoop() break; } + case Packet::PacketType::UNSUBACK: + { + uint16_t packet_id; + + if (!packet.read_uint16(packet_id)) + { + LOG_ERROR(this, "Malformed packet received, closing connection"); + return false; + } + + LOG_DEBUG(this, "Received UNSUBACK with packet id " + std::to_string(packet_id)); + + break; + } default: LOG_ERROR(this, "Received unexpected packet type " + std::string(magic_enum::enum_name(packet_type)) + " from broker, closing connection"); return false; @@ -384,4 +421,17 @@ void TrueMQTT::Client::Impl::sendSubscribe(const std::string &topic) void TrueMQTT::Client::Impl::sendUnsubscribe(const std::string &topic) { LOG_TRACE(this, "Sending unsubscribe message for topic '" + topic + "'"); + + Packet packet(Packet::PacketType::UNSUBSCRIBE, 2); + + // By specs, packet-id zero is not allowed. + if (packet_id == 0) + { + packet_id++; + } + + packet.write_uint16(packet_id++); + packet.write_string(topic); + + connection->send(packet); } -- libgit2 0.21.4