Commit 43f54d96d7b8400452b67011d9f2900b7282486d

Authored by Wiebe Cazemier
1 parent 3eb55c4d

Solve crashes because of race conditions in sessions and clients

One was confirmed: writing an MQTT packet to a client that disconnected
after the weak pointer had been checked for validity but before it was locked.

The rest made sense to change as well.
session.cpp
@@ -90,11 +90,14 @@ std::unique_ptr<Session> Session::getCopy() const @@ -90,11 +90,14 @@ std::unique_ptr<Session> Session::getCopy() const
90 return s; 90 return s;
91 } 91 }
92 92
93 -bool Session::clientDisconnected() const  
94 -{  
95 - return client.expired();  
96 -}  
97 - 93 +/**
  94 + * @brief Session::makeSharedClient get the client of the session, or a null when it has no active current client.
  95 + * @return Returns shared_ptr<Client>, which can contain null when the client has disconnected.
  96 + *
  97 + * The lock() operation is atomic and therefore is the only way to get the current active client without race condition, because
  98 + * typically, this method is called from other client's threads to perform writes, so you have to check validity after
  99 + * obtaining the shared pointer.
  100 + */
98 std::shared_ptr<Client> Session::makeSharedClient() const 101 std::shared_ptr<Client> Session::makeSharedClient() const
99 { 102 {
100 return client.lock(); 103 return client.lock();
@@ -126,9 +129,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u @@ -126,9 +129,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u
126 { 129 {
127 if (qos == 0) 130 if (qos == 0)
128 { 131 {
129 - if (!clientDisconnected()) 132 + std::shared_ptr<Client> c = makeSharedClient();
  133 +
  134 + if (c)
130 { 135 {
131 - std::shared_ptr<Client> c = makeSharedClient();  
132 c->writeMqttPacketAndBlameThisClient(packet, qos); 136 c->writeMqttPacketAndBlameThisClient(packet, qos);
133 count++; 137 count++;
134 } 138 }
@@ -150,9 +154,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u @@ -150,9 +154,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u
150 std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); 154 std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId);
151 locker.unlock(); 155 locker.unlock();
152 156
153 - if (!clientDisconnected()) 157 + std::shared_ptr<Client> c = makeSharedClient();
  158 + if (c)
154 { 159 {
155 - std::shared_ptr<Client> c = makeSharedClient();  
156 c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos); 160 c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos);
157 copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. 161 copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
158 count++; 162 count++;
@@ -183,9 +187,9 @@ uint64_t Session::sendPendingQosMessages() @@ -183,9 +187,9 @@ uint64_t Session::sendPendingQosMessages()
183 { 187 {
184 uint64_t count = 0; 188 uint64_t count = 0;
185 189
186 - if (!clientDisconnected()) 190 + std::shared_ptr<Client> c = makeSharedClient();
  191 + if (c)
187 { 192 {
188 - std::shared_ptr<Client> c = makeSharedClient();  
189 std::lock_guard<std::mutex> locker(qosQueueMutex); 193 std::lock_guard<std::mutex> locker(qosQueueMutex);
190 for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) 194 for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue)
191 { 195 {
@@ -223,7 +227,7 @@ bool Session::hasExpired(int expireAfterSeconds) @@ -223,7 +227,7 @@ bool Session::hasExpired(int expireAfterSeconds)
223 { 227 {
224 std::chrono::seconds expireAfter(expireAfterSeconds); 228 std::chrono::seconds expireAfter(expireAfterSeconds);
225 std::chrono::time_point<std::chrono::steady_clock> now = std::chrono::steady_clock::now(); 229 std::chrono::time_point<std::chrono::steady_clock> now = std::chrono::steady_clock::now();
226 - return clientDisconnected() && (lastTouched + expireAfter) < now; 230 + return client.expired() && (lastTouched + expireAfter) < now;
227 } 231 }
228 232
229 void Session::addIncomingQoS2MessageId(uint16_t packet_id) 233 void Session::addIncomingQoS2MessageId(uint16_t packet_id)
session.h
@@ -66,7 +66,6 @@ public: @@ -66,7 +66,6 @@ public:
66 std::unique_ptr<Session> getCopy() const; 66 std::unique_ptr<Session> getCopy() const;
67 67
68 const std::string &getClientId() const { return client_id; } 68 const std::string &getClientId() const { return client_id; }
69 - bool clientDisconnected() const;  
70 std::shared_ptr<Client> makeSharedClient() const; 69 std::shared_ptr<Client> makeSharedClient() const;
71 void assignActiveConnection(std::shared_ptr<Client> &client); 70 void assignActiveConnection(std::shared_ptr<Client> &client);
72 void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); 71 void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count);
subscriptionstore.cpp
@@ -212,13 +212,17 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> @@ -212,13 +212,17 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client>
212 { 212 {
213 session = session_it->second; 213 session = session_it->second;
214 214
215 - if (session && !session->clientDisconnected()) 215 + if (session)
216 { 216 {
217 std::shared_ptr<Client> cl = session->makeSharedClient(); 217 std::shared_ptr<Client> cl = session->makeSharedClient();
218 - logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str());  
219 - cl->setReadyForDisconnect();  
220 - cl->getThreadData()->removeClient(cl);  
221 - cl->markAsDisconnecting(); 218 +
  219 + if (cl)
  220 + {
  221 + logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str());
  222 + cl->setReadyForDisconnect();
  223 + cl->getThreadData()->removeClient(cl);
  224 + cl->markAsDisconnecting();
  225 + }
222 } 226 }
223 } 227 }
224 228
@@ -255,10 +259,9 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st @@ -255,10 +259,9 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st
255 { 259 {
256 for (const Subscription &sub : subscribers) 260 for (const Subscription &sub : subscribers)
257 { 261 {
258 - std::weak_ptr<Session> session_weak = sub.session;  
259 - if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. 262 + const std::shared_ptr<Session> session = sub.session.lock();
  263 + if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect.
260 { 264 {
261 - const std::shared_ptr<Session> session = session_weak.lock();  
262 session->writePacket(packet, sub.qos, false, count); 265 session->writePacket(packet, sub.qos, false, count);
263 } 266 }
264 } 267 }
@@ -469,7 +472,8 @@ int SubscriptionNode::cleanSubscriptions() @@ -469,7 +472,8 @@ int SubscriptionNode::cleanSubscriptions()
469 auto it = subscribers.begin(); 472 auto it = subscribers.begin();
470 while (it != subscribers.end()) 473 while (it != subscribers.end())
471 { 474 {
472 - if (it->sessionGone()) 475 + std::shared_ptr<Session> ses = it->session.lock();
  476 + if (!ses)
473 { 477 {
474 Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector"); 478 Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector");
475 it = subscribers.erase(it); 479 it = subscribers.erase(it);
@@ -539,9 +543,10 @@ void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std: @@ -539,9 +543,10 @@ void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std:
539 { 543 {
540 for (const Subscription &node : this_node->getSubscribers()) 544 for (const Subscription &node : this_node->getSubscribers())
541 { 545 {
542 - if (!node.sessionGone()) 546 + std::shared_ptr<Session> ses = node.session.lock();
  547 + if (ses)
543 { 548 {
544 - SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos); 549 + SubscriptionForSerializing sub(ses->getClientId(), node.qos);
545 outputList[composedTopic].push_back(sub); 550 outputList[composedTopic].push_back(sub);
546 } 551 }
547 } 552 }
@@ -702,7 +707,7 @@ bool Subscription::operator==(const Subscription &rhs) const @@ -702,7 +707,7 @@ bool Subscription::operator==(const Subscription &rhs) const
702 const std::shared_ptr<Session> lhs_ses = session.lock(); 707 const std::shared_ptr<Session> lhs_ses = session.lock();
703 const std::shared_ptr<Session> rhs_ses = rhs.session.lock(); 708 const std::shared_ptr<Session> rhs_ses = rhs.session.lock();
704 709
705 - return lhs_ses->getClientId() == rhs_ses->getClientId(); 710 + return lhs_ses && rhs_ses && lhs_ses->getClientId() == rhs_ses->getClientId();
706 } 711 }
707 712
708 void Subscription::reset() 713 void Subscription::reset()
@@ -711,11 +716,6 @@ void Subscription::reset() @@ -711,11 +716,6 @@ void Subscription::reset()
711 qos = 0; 716 qos = 0;
712 } 717 }
713 718
714 -bool Subscription::sessionGone() const  
715 -{  
716 - return session.expired();  
717 -}  
718 -  
719 void RetainedMessageNode::addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount) 719 void RetainedMessageNode::addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount)
720 { 720 {
721 const int64_t countBefore = retainedMessages.size(); 721 const int64_t countBefore = retainedMessages.size();
subscriptionstore.h
@@ -39,7 +39,6 @@ struct Subscription @@ -39,7 +39,6 @@ struct Subscription
39 char qos; 39 char qos;
40 bool operator==(const Subscription &rhs) const; 40 bool operator==(const Subscription &rhs) const;
41 void reset(); 41 void reset();
42 - bool sessionGone() const;  
43 }; 42 };
44 43
45 class SubscriptionNode 44 class SubscriptionNode