Commit 6141bbcc6a0f756f291059010c545f465538646d

Authored by Wiebe Cazemier
1 parent 760ec588

Store QoS in subscription and use it

Connected to this is preventing duplicate subscriptions. It's a bit
unclear what to do when you get a subscription for the same topic with a
different QoS? Change the Qos? Ignore?
session.cpp
... ... @@ -21,11 +21,12 @@ std::shared_ptr<Client> Session::makeSharedClient() const
21 21 void Session::assignActiveConnection(std::shared_ptr<Client> &client)
22 22 {
23 23 this->client = client;
  24 + this->client_id = client->getClientId();
24 25 }
25 26  
26   -void Session::writePacket(const MqttPacket &packet)
  27 +void Session::writePacket(const MqttPacket &packet, char qos_arg)
27 28 {
28   - const char qos = packet.getQos();
  29 + const char qos = std::min<char>(packet.getQos(), qos_arg);
29 30  
30 31 if (qos == 0)
31 32 {
... ... @@ -41,7 +42,7 @@ void Session::writePacket(const MqttPacket &amp;packet)
41 42 std::unique_lock<std::mutex> locker(qosQueueMutex);
42 43 if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
43 44 {
44   - logger->logf(LOG_WARNING, "Dropping QoS message for client 'TODO', because its QoS buffers were full.");
  45 + logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str());
45 46 return;
46 47 }
47 48 const uint16_t pid = nextPacketId++;
... ...
session.h
... ... @@ -15,6 +15,7 @@
15 15 class Session
16 16 {
17 17 std::weak_ptr<Client> client;
  18 + std::string client_id;
18 19 std::unordered_map<uint16_t, std::shared_ptr<MqttPacket>> qosPacketQueue; // TODO: because the max queue length should remain low-ish, perhaps a vector is better here.
19 20 std::mutex qosQueueMutex;
20 21 uint16_t nextPacketId = 0;
... ... @@ -25,10 +26,11 @@ public:
25 26 Session(const Session &other) = delete;
26 27 Session(Session &&other) = delete;
27 28  
  29 + const std::string &getClientId() const { return client_id; }
28 30 bool clientDisconnected() const;
29 31 std::shared_ptr<Client> makeSharedClient() const;
30 32 void assignActiveConnection(std::shared_ptr<Client> &client);
31   - void writePacket(const MqttPacket &packet);
  33 + void writePacket(const MqttPacket &packet, char qos_arg);
32 34 void clearQosMessage(uint16_t packet_id);
33 35 void sendPendingQosMessages();
34 36 };
... ...
subscriptionstore.cpp
... ... @@ -11,6 +11,25 @@ SubscriptionNode::SubscriptionNode(const std::string &amp;subtopic) :
11 11  
12 12 }
13 13  
  14 +std::vector<Subscription> &SubscriptionNode::getSubscribers()
  15 +{
  16 + return subscribers;
  17 +}
  18 +
  19 +void SubscriptionNode::addSubscriber(const std::shared_ptr<Session> &subscriber, char qos)
  20 +{
  21 + Subscription sub;
  22 + sub.session = subscriber;
  23 + sub.qos = qos;
  24 +
  25 + // I'll have to decide whether to keep the subscriber as a vector. Vectors are
  26 + // fast, and relatively, you don't often add subscribers.
  27 + if (std::find(subscribers.begin(), subscribers.end(), sub) == subscribers.end())
  28 + {
  29 + subscribers.push_back(sub);
  30 + }
  31 +}
  32 +
14 33 SubscriptionStore::SubscriptionStore() :
15 34 root(new SubscriptionNode("root")),
16 35 sessionsByIdConst(sessionsById)
... ... @@ -52,7 +71,7 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top
52 71 if (session_it != sessionsByIdConst.end())
53 72 {
54 73 std::weak_ptr<Session> b = session_it->second;
55   - deepestNode->subscribers.push_front(b);
  74 + deepestNode->addSubscriber(session_it->second, qos);
56 75 }
57 76 }
58 77  
... ... @@ -99,14 +118,15 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &amp;client)
99 118 }
100 119  
101 120 // TODO: should I implement cache, this needs to be changed to returning a list of clients.
102   -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::weak_ptr<Session>> &subscribers) const
  121 +void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::vector<Subscription> &subscribers) const
103 122 {
104   - for (const std::weak_ptr<Session> session_weak : subscribers)
  123 + for (const Subscription &sub : subscribers)
105 124 {
  125 + std::weak_ptr<Session> session_weak = sub.session;
106 126 if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect.
107 127 {
108 128 const std::shared_ptr<Session> session = session_weak.lock();
109   - session->writePacket(packet);
  129 + session->writePacket(packet, sub.qos);
110 130 }
111 131 }
112 132 }
... ... @@ -116,7 +136,7 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera
116 136 {
117 137 if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here.
118 138 {
119   - publishNonRecursively(packet, this_node->subscribers);
  139 + publishNonRecursively(packet, this_node->getSubscribers());
120 140 return;
121 141 }
122 142  
... ... @@ -129,7 +149,7 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera
129 149  
130 150 if (this_node->childrenPound)
131 151 {
132   - publishNonRecursively(packet, this_node->childrenPound->subscribers);
  152 + publishNonRecursively(packet, this_node->childrenPound->getSubscribers());
133 153 }
134 154  
135 155 auto sub_node = this_node->children.find(cur_subtop);
... ... @@ -197,4 +217,23 @@ void SubscriptionStore::setRetainedMessage(const std::string &amp;topic, const std::
197 217 retainedMessages.insert(std::move(rm));
198 218 }
199 219  
  220 +// QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The
  221 +// specs don't specify what to do there.
  222 +bool Subscription::operator==(const Subscription &rhs) const
  223 +{
  224 + if (session.expired() && rhs.session.expired())
  225 + return true;
  226 + if (session.expired() || rhs.session.expired())
  227 + return false;
  228 +
  229 + const std::shared_ptr<Session> lhs_ses = session.lock();
  230 + const std::shared_ptr<Session> rhs_ses = rhs.session.lock();
200 231  
  232 + return lhs_ses->getClientId() == rhs_ses->getClientId();
  233 +}
  234 +
  235 +void Subscription::reset()
  236 +{
  237 + session.reset();
  238 + qos = 0;
  239 +}
... ...
subscriptionstore.h
... ... @@ -21,19 +21,30 @@ struct RetainedPayload
21 21 char qos;
22 22 };
23 23  
  24 +struct Subscription
  25 +{
  26 + std::weak_ptr<Session> session; // Weak pointer expires when session has been cleaned by 'clean session' connect.
  27 + char qos;
  28 + bool operator==(const Subscription &rhs) const;
  29 + void reset();
  30 +};
  31 +
24 32 class SubscriptionNode
25 33 {
26 34 std::string subtopic;
  35 + std::vector<Subscription> subscribers;
27 36  
28 37 public:
29 38 SubscriptionNode(const std::string &subtopic);
30 39 SubscriptionNode(const SubscriptionNode &node) = delete;
31 40 SubscriptionNode(SubscriptionNode &&node) = delete;
32 41  
33   - std::forward_list<std::weak_ptr<Session>> subscribers; // TODO: a subscription class, with qos
  42 + std::vector<Subscription> &getSubscribers();
  43 + void addSubscriber(const std::shared_ptr<Session> &subscriber, char qos);
34 44 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children;
35 45 std::unique_ptr<SubscriptionNode> childrenPlus;
36 46 std::unique_ptr<SubscriptionNode> childrenPound;
  47 +
37 48 };
38 49  
39 50 class SubscriptionStore
... ... @@ -48,7 +59,7 @@ class SubscriptionStore
48 59  
49 60 Logger *logger = Logger::getInstance();
50 61  
51   - void publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::weak_ptr<Session>> &subscribers) const;
  62 + void publishNonRecursively(const MqttPacket &packet, const std::vector<Subscription> &subscribers) const;
52 63 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
53 64 std::unique_ptr<SubscriptionNode> &next, const MqttPacket &packet) const;
54 65 public:
... ...