Commit 43f54d96d7b8400452b67011d9f2900b7282486d

Authored by Wiebe Cazemier
1 parent 3eb55c4d

Solve crashes because of race conditions in sessions and clients

One was confirmed: writing an mqtt packet into a client that
disconnected after checking the weak pointer for validity.

The rest made sense to change as well.
session.cpp
... ... @@ -90,11 +90,14 @@ std::unique_ptr<Session> Session::getCopy() const
90 90 return s;
91 91 }
92 92  
93   -bool Session::clientDisconnected() const
94   -{
95   - return client.expired();
96   -}
97   -
  93 +/**
  94 + * @brief Session::makeSharedClient get the client of the session, or a null when it has no active current client.
  95 + * @return Returns shared_ptr<Client>, which can contain null when the client has disconnected.
  96 + *
  97 + * The lock() operation is atomic and therefore is the only way to get the current active client without race condition, because
  98 + * typically, this method is called from other client's threads to perform writes, so you have to check validity after
  99 + * obtaining the shared pointer.
  100 + */
98 101 std::shared_ptr<Client> Session::makeSharedClient() const
99 102 {
100 103 return client.lock();
... ... @@ -126,9 +129,10 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u
126 129 {
127 130 if (qos == 0)
128 131 {
129   - if (!clientDisconnected())
  132 + std::shared_ptr<Client> c = makeSharedClient();
  133 +
  134 + if (c)
130 135 {
131   - std::shared_ptr<Client> c = makeSharedClient();
132 136 c->writeMqttPacketAndBlameThisClient(packet, qos);
133 137 count++;
134 138 }
... ... @@ -150,9 +154,9 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u
150 154 std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId);
151 155 locker.unlock();
152 156  
153   - if (!clientDisconnected())
  157 + std::shared_ptr<Client> c = makeSharedClient();
  158 + if (c)
154 159 {
155   - std::shared_ptr<Client> c = makeSharedClient();
156 160 c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos);
157 161 copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
158 162 count++;
... ... @@ -183,9 +187,9 @@ uint64_t Session::sendPendingQosMessages()
183 187 {
184 188 uint64_t count = 0;
185 189  
186   - if (!clientDisconnected())
  190 + std::shared_ptr<Client> c = makeSharedClient();
  191 + if (c)
187 192 {
188   - std::shared_ptr<Client> c = makeSharedClient();
189 193 std::lock_guard<std::mutex> locker(qosQueueMutex);
190 194 for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue)
191 195 {
... ... @@ -223,7 +227,7 @@ bool Session::hasExpired(int expireAfterSeconds)
223 227 {
224 228 std::chrono::seconds expireAfter(expireAfterSeconds);
225 229 std::chrono::time_point<std::chrono::steady_clock> now = std::chrono::steady_clock::now();
226   - return clientDisconnected() && (lastTouched + expireAfter) < now;
  230 + return client.expired() && (lastTouched + expireAfter) < now;
227 231 }
228 232  
229 233 void Session::addIncomingQoS2MessageId(uint16_t packet_id)
... ...
session.h
... ... @@ -66,7 +66,6 @@ public:
66 66 std::unique_ptr<Session> getCopy() const;
67 67  
68 68 const std::string &getClientId() const { return client_id; }
69   - bool clientDisconnected() const;
70 69 std::shared_ptr<Client> makeSharedClient() const;
71 70 void assignActiveConnection(std::shared_ptr<Client> &client);
72 71 void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count);
... ...
subscriptionstore.cpp
... ... @@ -212,13 +212,17 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr&lt;Client&gt;
212 212 {
213 213 session = session_it->second;
214 214  
215   - if (session && !session->clientDisconnected())
  215 + if (session)
216 216 {
217 217 std::shared_ptr<Client> cl = session->makeSharedClient();
218   - logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str());
219   - cl->setReadyForDisconnect();
220   - cl->getThreadData()->removeClient(cl);
221   - cl->markAsDisconnecting();
  218 +
  219 + if (cl)
  220 + {
  221 + logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str());
  222 + cl->setReadyForDisconnect();
  223 + cl->getThreadData()->removeClient(cl);
  224 + cl->markAsDisconnecting();
  225 + }
222 226 }
223 227 }
224 228  
... ... @@ -255,10 +259,9 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st
255 259 {
256 260 for (const Subscription &sub : subscribers)
257 261 {
258   - std::weak_ptr<Session> session_weak = sub.session;
259   - if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect.
  262 + const std::shared_ptr<Session> session = sub.session.lock();
  263 + if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect.
260 264 {
261   - const std::shared_ptr<Session> session = session_weak.lock();
262 265 session->writePacket(packet, sub.qos, false, count);
263 266 }
264 267 }
... ... @@ -469,7 +472,8 @@ int SubscriptionNode::cleanSubscriptions()
469 472 auto it = subscribers.begin();
470 473 while (it != subscribers.end())
471 474 {
472   - if (it->sessionGone())
  475 + std::shared_ptr<Session> ses = it->session.lock();
  476 + if (!ses)
473 477 {
474 478 Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector");
475 479 it = subscribers.erase(it);
... ... @@ -539,9 +543,10 @@ void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std:
539 543 {
540 544 for (const Subscription &node : this_node->getSubscribers())
541 545 {
542   - if (!node.sessionGone())
  546 + std::shared_ptr<Session> ses = node.session.lock();
  547 + if (ses)
543 548 {
544   - SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos);
  549 + SubscriptionForSerializing sub(ses->getClientId(), node.qos);
545 550 outputList[composedTopic].push_back(sub);
546 551 }
547 552 }
... ... @@ -702,7 +707,7 @@ bool Subscription::operator==(const Subscription &amp;rhs) const
702 707 const std::shared_ptr<Session> lhs_ses = session.lock();
703 708 const std::shared_ptr<Session> rhs_ses = rhs.session.lock();
704 709  
705   - return lhs_ses->getClientId() == rhs_ses->getClientId();
  710 + return lhs_ses && rhs_ses && lhs_ses->getClientId() == rhs_ses->getClientId();
706 711 }
707 712  
708 713 void Subscription::reset()
... ... @@ -711,11 +716,6 @@ void Subscription::reset()
711 716 qos = 0;
712 717 }
713 718  
714   -bool Subscription::sessionGone() const
715   -{
716   - return session.expired();
717   -}
718   -
719 719 void RetainedMessageNode::addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount)
720 720 {
721 721 const int64_t countBefore = retainedMessages.size();
... ...
subscriptionstore.h
... ... @@ -39,7 +39,6 @@ struct Subscription
39 39 char qos;
40 40 bool operator==(const Subscription &rhs) const;
41 41 void reset();
42   - bool sessionGone() const;
43 42 };
44 43  
45 44 class SubscriptionNode
... ...