diff --git a/session.cpp b/session.cpp index 3295fc9..702794c 100644 --- a/session.cpp +++ b/session.cpp @@ -21,11 +21,12 @@ std::shared_ptr Session::makeSharedClient() const void Session::assignActiveConnection(std::shared_ptr &client) { this->client = client; + this->client_id = client->getClientId(); } -void Session::writePacket(const MqttPacket &packet) +void Session::writePacket(const MqttPacket &packet, char qos_arg) { - const char qos = packet.getQos(); + const char qos = std::min(packet.getQos(), qos_arg); if (qos == 0) { @@ -41,7 +42,7 @@ void Session::writePacket(const MqttPacket &packet) std::unique_lock locker(qosQueueMutex); if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) { - logger->logf(LOG_WARNING, "Dropping QoS message for client 'TODO', because its QoS buffers were full."); + logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); return; } const uint16_t pid = nextPacketId++; diff --git a/session.h b/session.h index a866c32..ac16714 100644 --- a/session.h +++ b/session.h @@ -15,6 +15,7 @@ class Session { std::weak_ptr client; + std::string client_id; std::unordered_map> qosPacketQueue; // TODO: because the max queue length should remain low-ish, perhaps a vector is better here. std::mutex qosQueueMutex; uint16_t nextPacketId = 0; @@ -25,10 +26,11 @@ public: Session(const Session &other) = delete; Session(Session &&other) = delete; + const std::string &getClientId() const { return client_id; } bool clientDisconnected() const; std::shared_ptr makeSharedClient() const; void assignActiveConnection(std::shared_ptr &client); - void writePacket(const MqttPacket &packet); + void writePacket(const MqttPacket &packet, char qos_arg); void clearQosMessage(uint16_t packet_id); void sendPendingQosMessages(); }; diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 6ef0f34..a4ddb55 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -11,6 +11,25 @@ SubscriptionNode::SubscriptionNode(const std::string &subtopic) : } +std::vector &SubscriptionNode::getSubscribers() +{ + return subscribers; +} + +void SubscriptionNode::addSubscriber(const std::shared_ptr &subscriber, char qos) +{ + Subscription sub; + sub.session = subscriber; + sub.qos = qos; + + // I'll have to decide whether to keep the subscriber as a vector. Vectors are + // fast, and relatively, you don't often add subscribers. + if (std::find(subscribers.begin(), subscribers.end(), sub) == subscribers.end()) + { + subscribers.push_back(sub); + } +} + SubscriptionStore::SubscriptionStore() : root(new SubscriptionNode("root")), sessionsByIdConst(sessionsById) @@ -52,7 +71,7 @@ void SubscriptionStore::addSubscription(Client_p &client, const std::string &top if (session_it != sessionsByIdConst.end()) { std::weak_ptr b = session_it->second; - deepestNode->subscribers.push_front(b); + deepestNode->addSubscriber(session_it->second, qos); } } @@ -99,14 +118,15 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) } // TODO: should I implement cache, this needs to be changed to returning a list of clients. -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::forward_list> &subscribers) const +void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers) const { - for (const std::weak_ptr session_weak : subscribers) + for (const Subscription &sub : subscribers) { + std::weak_ptr session_weak = sub.session; if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. { const std::shared_ptr session = session_weak.lock(); - session->writePacket(packet); + session->writePacket(packet, sub.qos); } } } @@ -116,7 +136,7 @@ void SubscriptionStore::publishRecursively(std::vector::const_itera { if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here. { - publishNonRecursively(packet, this_node->subscribers); + publishNonRecursively(packet, this_node->getSubscribers()); return; } @@ -129,7 +149,7 @@ void SubscriptionStore::publishRecursively(std::vector::const_itera if (this_node->childrenPound) { - publishNonRecursively(packet, this_node->childrenPound->subscribers); + publishNonRecursively(packet, this_node->childrenPound->getSubscribers()); } auto sub_node = this_node->children.find(cur_subtop); @@ -197,4 +217,23 @@ void SubscriptionStore::setRetainedMessage(const std::string &topic, const std:: retainedMessages.insert(std::move(rm)); } +// QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The +// specs don't specify what to do there. +bool Subscription::operator==(const Subscription &rhs) const +{ + if (session.expired() && rhs.session.expired()) + return true; + if (session.expired() || rhs.session.expired()) + return false; + + const std::shared_ptr lhs_ses = session.lock(); + const std::shared_ptr rhs_ses = rhs.session.lock(); + return lhs_ses->getClientId() == rhs_ses->getClientId(); +} + +void Subscription::reset() +{ + session.reset(); + qos = 0; +} diff --git a/subscriptionstore.h b/subscriptionstore.h index 4ac3a77..d1c009d 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -21,19 +21,30 @@ struct RetainedPayload char qos; }; +struct Subscription +{ + std::weak_ptr session; // Weak pointer expires when session has been cleaned by 'clean session' connect. + char qos; + bool operator==(const Subscription &rhs) const; + void reset(); +}; + class SubscriptionNode { std::string subtopic; + std::vector subscribers; public: SubscriptionNode(const std::string &subtopic); SubscriptionNode(const SubscriptionNode &node) = delete; SubscriptionNode(SubscriptionNode &&node) = delete; - std::forward_list> subscribers; // TODO: a subscription class, with qos + std::vector &getSubscribers(); + void addSubscriber(const std::shared_ptr &subscriber, char qos); std::unordered_map> children; std::unique_ptr childrenPlus; std::unique_ptr childrenPound; + }; class SubscriptionStore @@ -48,7 +59,7 @@ class SubscriptionStore Logger *logger = Logger::getInstance(); - void publishNonRecursively(const MqttPacket &packet, const std::forward_list> &subscribers) const; + void publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers) const; void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, std::unique_ptr &next, const MqttPacket &packet) const; public: