Commit 6141bbcc6a0f756f291059010c545f465538646d

Authored by Wiebe Cazemier
1 parent 760ec588

Store QoS in subscription and use it

Connected to this is preventing duplicate subscriptions. It's a bit
unclear what to do when you get a subscription for the same topic with a
different QoS? Change the Qos? Ignore?
session.cpp
@@ -21,11 +21,12 @@ std::shared_ptr<Client> Session::makeSharedClient() const @@ -21,11 +21,12 @@ std::shared_ptr<Client> Session::makeSharedClient() const
21 void Session::assignActiveConnection(std::shared_ptr<Client> &client) 21 void Session::assignActiveConnection(std::shared_ptr<Client> &client)
22 { 22 {
23 this->client = client; 23 this->client = client;
  24 + this->client_id = client->getClientId();
24 } 25 }
25 26
26 -void Session::writePacket(const MqttPacket &packet) 27 +void Session::writePacket(const MqttPacket &packet, char qos_arg)
27 { 28 {
28 - const char qos = packet.getQos(); 29 + const char qos = std::min<char>(packet.getQos(), qos_arg);
29 30
30 if (qos == 0) 31 if (qos == 0)
31 { 32 {
@@ -41,7 +42,7 @@ void Session::writePacket(const MqttPacket &amp;packet) @@ -41,7 +42,7 @@ void Session::writePacket(const MqttPacket &amp;packet)
41 std::unique_lock<std::mutex> locker(qosQueueMutex); 42 std::unique_lock<std::mutex> locker(qosQueueMutex);
42 if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) 43 if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
43 { 44 {
44 - logger->logf(LOG_WARNING, "Dropping QoS message for client 'TODO', because its QoS buffers were full."); 45 + logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str());
45 return; 46 return;
46 } 47 }
47 const uint16_t pid = nextPacketId++; 48 const uint16_t pid = nextPacketId++;
session.h
@@ -15,6 +15,7 @@ @@ -15,6 +15,7 @@
15 class Session 15 class Session
16 { 16 {
17 std::weak_ptr<Client> client; 17 std::weak_ptr<Client> client;
  18 + std::string client_id;
18 std::unordered_map<uint16_t, std::shared_ptr<MqttPacket>> qosPacketQueue; // TODO: because the max queue length should remain low-ish, perhaps a vector is better here. 19 std::unordered_map<uint16_t, std::shared_ptr<MqttPacket>> qosPacketQueue; // TODO: because the max queue length should remain low-ish, perhaps a vector is better here.
19 std::mutex qosQueueMutex; 20 std::mutex qosQueueMutex;
20 uint16_t nextPacketId = 0; 21 uint16_t nextPacketId = 0;
@@ -25,10 +26,11 @@ public: @@ -25,10 +26,11 @@ public:
25 Session(const Session &other) = delete; 26 Session(const Session &other) = delete;
26 Session(Session &&other) = delete; 27 Session(Session &&other) = delete;
27 28
  29 + const std::string &getClientId() const { return client_id; }
28 bool clientDisconnected() const; 30 bool clientDisconnected() const;
29 std::shared_ptr<Client> makeSharedClient() const; 31 std::shared_ptr<Client> makeSharedClient() const;
30 void assignActiveConnection(std::shared_ptr<Client> &client); 32 void assignActiveConnection(std::shared_ptr<Client> &client);
31 - void writePacket(const MqttPacket &packet); 33 + void writePacket(const MqttPacket &packet, char qos_arg);
32 void clearQosMessage(uint16_t packet_id); 34 void clearQosMessage(uint16_t packet_id);
33 void sendPendingQosMessages(); 35 void sendPendingQosMessages();
34 }; 36 };
subscriptionstore.cpp
@@ -11,6 +11,25 @@ SubscriptionNode::SubscriptionNode(const std::string &amp;subtopic) : @@ -11,6 +11,25 @@ SubscriptionNode::SubscriptionNode(const std::string &amp;subtopic) :
11 11
12 } 12 }
13 13
  14 +std::vector<Subscription> &SubscriptionNode::getSubscribers()
  15 +{
  16 + return subscribers;
  17 +}
  18 +
  19 +void SubscriptionNode::addSubscriber(const std::shared_ptr<Session> &subscriber, char qos)
  20 +{
  21 + Subscription sub;
  22 + sub.session = subscriber;
  23 + sub.qos = qos;
  24 +
  25 + // I'll have to decide whether to keep the subscriber as a vector. Vectors are
  26 + // fast, and relatively, you don't often add subscribers.
  27 + if (std::find(subscribers.begin(), subscribers.end(), sub) == subscribers.end())
  28 + {
  29 + subscribers.push_back(sub);
  30 + }
  31 +}
  32 +
14 SubscriptionStore::SubscriptionStore() : 33 SubscriptionStore::SubscriptionStore() :
15 root(new SubscriptionNode("root")), 34 root(new SubscriptionNode("root")),
16 sessionsByIdConst(sessionsById) 35 sessionsByIdConst(sessionsById)
@@ -52,7 +71,7 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top @@ -52,7 +71,7 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top
52 if (session_it != sessionsByIdConst.end()) 71 if (session_it != sessionsByIdConst.end())
53 { 72 {
54 std::weak_ptr<Session> b = session_it->second; 73 std::weak_ptr<Session> b = session_it->second;
55 - deepestNode->subscribers.push_front(b); 74 + deepestNode->addSubscriber(session_it->second, qos);
56 } 75 }
57 } 76 }
58 77
@@ -99,14 +118,15 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &amp;client) @@ -99,14 +118,15 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &amp;client)
99 } 118 }
100 119
101 // TODO: should I implement cache, this needs to be changed to returning a list of clients. 120 // TODO: should I implement cache, this needs to be changed to returning a list of clients.
102 -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::weak_ptr<Session>> &subscribers) const 121 +void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::vector<Subscription> &subscribers) const
103 { 122 {
104 - for (const std::weak_ptr<Session> session_weak : subscribers) 123 + for (const Subscription &sub : subscribers)
105 { 124 {
  125 + std::weak_ptr<Session> session_weak = sub.session;
106 if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. 126 if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect.
107 { 127 {
108 const std::shared_ptr<Session> session = session_weak.lock(); 128 const std::shared_ptr<Session> session = session_weak.lock();
109 - session->writePacket(packet); 129 + session->writePacket(packet, sub.qos);
110 } 130 }
111 } 131 }
112 } 132 }
@@ -116,7 +136,7 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera @@ -116,7 +136,7 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera
116 { 136 {
117 if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here. 137 if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here.
118 { 138 {
119 - publishNonRecursively(packet, this_node->subscribers); 139 + publishNonRecursively(packet, this_node->getSubscribers());
120 return; 140 return;
121 } 141 }
122 142
@@ -129,7 +149,7 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera @@ -129,7 +149,7 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera
129 149
130 if (this_node->childrenPound) 150 if (this_node->childrenPound)
131 { 151 {
132 - publishNonRecursively(packet, this_node->childrenPound->subscribers); 152 + publishNonRecursively(packet, this_node->childrenPound->getSubscribers());
133 } 153 }
134 154
135 auto sub_node = this_node->children.find(cur_subtop); 155 auto sub_node = this_node->children.find(cur_subtop);
@@ -197,4 +217,23 @@ void SubscriptionStore::setRetainedMessage(const std::string &amp;topic, const std:: @@ -197,4 +217,23 @@ void SubscriptionStore::setRetainedMessage(const std::string &amp;topic, const std::
197 retainedMessages.insert(std::move(rm)); 217 retainedMessages.insert(std::move(rm));
198 } 218 }
199 219
  220 +// QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The
  221 +// specs don't specify what to do there.
  222 +bool Subscription::operator==(const Subscription &rhs) const
  223 +{
  224 + if (session.expired() && rhs.session.expired())
  225 + return true;
  226 + if (session.expired() || rhs.session.expired())
  227 + return false;
  228 +
  229 + const std::shared_ptr<Session> lhs_ses = session.lock();
  230 + const std::shared_ptr<Session> rhs_ses = rhs.session.lock();
200 231
  232 + return lhs_ses->getClientId() == rhs_ses->getClientId();
  233 +}
  234 +
  235 +void Subscription::reset()
  236 +{
  237 + session.reset();
  238 + qos = 0;
  239 +}
subscriptionstore.h
@@ -21,19 +21,30 @@ struct RetainedPayload @@ -21,19 +21,30 @@ struct RetainedPayload
21 char qos; 21 char qos;
22 }; 22 };
23 23
  24 +struct Subscription
  25 +{
  26 + std::weak_ptr<Session> session; // Weak pointer expires when session has been cleaned by 'clean session' connect.
  27 + char qos;
  28 + bool operator==(const Subscription &rhs) const;
  29 + void reset();
  30 +};
  31 +
24 class SubscriptionNode 32 class SubscriptionNode
25 { 33 {
26 std::string subtopic; 34 std::string subtopic;
  35 + std::vector<Subscription> subscribers;
27 36
28 public: 37 public:
29 SubscriptionNode(const std::string &subtopic); 38 SubscriptionNode(const std::string &subtopic);
30 SubscriptionNode(const SubscriptionNode &node) = delete; 39 SubscriptionNode(const SubscriptionNode &node) = delete;
31 SubscriptionNode(SubscriptionNode &&node) = delete; 40 SubscriptionNode(SubscriptionNode &&node) = delete;
32 41
33 - std::forward_list<std::weak_ptr<Session>> subscribers; // TODO: a subscription class, with qos 42 + std::vector<Subscription> &getSubscribers();
  43 + void addSubscriber(const std::shared_ptr<Session> &subscriber, char qos);
34 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children; 44 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children;
35 std::unique_ptr<SubscriptionNode> childrenPlus; 45 std::unique_ptr<SubscriptionNode> childrenPlus;
36 std::unique_ptr<SubscriptionNode> childrenPound; 46 std::unique_ptr<SubscriptionNode> childrenPound;
  47 +
37 }; 48 };
38 49
39 class SubscriptionStore 50 class SubscriptionStore
@@ -48,7 +59,7 @@ class SubscriptionStore @@ -48,7 +59,7 @@ class SubscriptionStore
48 59
49 Logger *logger = Logger::getInstance(); 60 Logger *logger = Logger::getInstance();
50 61
51 - void publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::weak_ptr<Session>> &subscribers) const; 62 + void publishNonRecursively(const MqttPacket &packet, const std::vector<Subscription> &subscribers) const;
52 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, 63 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
53 std::unique_ptr<SubscriptionNode> &next, const MqttPacket &packet) const; 64 std::unique_ptr<SubscriptionNode> &next, const MqttPacket &packet) const;
54 public: 65 public: