Commit 43f54d96d7b8400452b67011d9f2900b7282486d
1 parent
3eb55c4d
Solve crashes because of race conditions in sessions and clients
One was confirmed: writing an mqtt packet into a client that disconnected after checking the weak pointer for validity. The rest made sense to change as well.
Showing
4 changed files
with
33 additions
and
31 deletions
session.cpp
| ... | ... | @@ -90,11 +90,14 @@ std::unique_ptr<Session> Session::getCopy() const |
| 90 | 90 | return s; |
| 91 | 91 | } |
| 92 | 92 | |
| 93 | -bool Session::clientDisconnected() const | |
| 94 | -{ | |
| 95 | - return client.expired(); | |
| 96 | -} | |
| 97 | - | |
| 93 | +/** | |
| 94 | + * @brief Session::makeSharedClient get the client of the session, or a null when it has no active current client. | |
| 95 | + * @return Returns shared_ptr<Client>, which can contain null when the client has disconnected. | |
| 96 | + * | |
| 97 | + * The lock() operation is atomic and therefore is the only way to get the current active client without race condition, because | |
| 98 | + * typically, this method is called from other client's threads to perform writes, so you have to check validity after | |
| 99 | + * obtaining the shared pointer. | |
| 100 | + */ | |
| 98 | 101 | std::shared_ptr<Client> Session::makeSharedClient() const |
| 99 | 102 | { |
| 100 | 103 | return client.lock(); |
| ... | ... | @@ -126,9 +129,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u |
| 126 | 129 | { |
| 127 | 130 | if (qos == 0) |
| 128 | 131 | { |
| 129 | - if (!clientDisconnected()) | |
| 132 | + std::shared_ptr<Client> c = makeSharedClient(); | |
| 133 | + | |
| 134 | + if (c) | |
| 130 | 135 | { |
| 131 | - std::shared_ptr<Client> c = makeSharedClient(); | |
| 132 | 136 | c->writeMqttPacketAndBlameThisClient(packet, qos); |
| 133 | 137 | count++; |
| 134 | 138 | } |
| ... | ... | @@ -150,9 +154,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u |
| 150 | 154 | std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); |
| 151 | 155 | locker.unlock(); |
| 152 | 156 | |
| 153 | - if (!clientDisconnected()) | |
| 157 | + std::shared_ptr<Client> c = makeSharedClient(); | |
| 158 | + if (c) | |
| 154 | 159 | { |
| 155 | - std::shared_ptr<Client> c = makeSharedClient(); | |
| 156 | 160 | c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos); |
| 157 | 161 | copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. |
| 158 | 162 | count++; |
| ... | ... | @@ -183,9 +187,9 @@ uint64_t Session::sendPendingQosMessages() |
| 183 | 187 | { |
| 184 | 188 | uint64_t count = 0; |
| 185 | 189 | |
| 186 | - if (!clientDisconnected()) | |
| 190 | + std::shared_ptr<Client> c = makeSharedClient(); | |
| 191 | + if (c) | |
| 187 | 192 | { |
| 188 | - std::shared_ptr<Client> c = makeSharedClient(); | |
| 189 | 193 | std::lock_guard<std::mutex> locker(qosQueueMutex); |
| 190 | 194 | for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) |
| 191 | 195 | { |
| ... | ... | @@ -223,7 +227,7 @@ bool Session::hasExpired(int expireAfterSeconds) |
| 223 | 227 | { |
| 224 | 228 | std::chrono::seconds expireAfter(expireAfterSeconds); |
| 225 | 229 | std::chrono::time_point<std::chrono::steady_clock> now = std::chrono::steady_clock::now(); |
| 226 | - return clientDisconnected() && (lastTouched + expireAfter) < now; | |
| 230 | + return client.expired() && (lastTouched + expireAfter) < now; | |
| 227 | 231 | } |
| 228 | 232 | |
| 229 | 233 | void Session::addIncomingQoS2MessageId(uint16_t packet_id) | ... | ... |
session.h
| ... | ... | @@ -66,7 +66,6 @@ public: |
| 66 | 66 | std::unique_ptr<Session> getCopy() const; |
| 67 | 67 | |
| 68 | 68 | const std::string &getClientId() const { return client_id; } |
| 69 | - bool clientDisconnected() const; | |
| 70 | 69 | std::shared_ptr<Client> makeSharedClient() const; |
| 71 | 70 | void assignActiveConnection(std::shared_ptr<Client> &client); |
| 72 | 71 | void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); | ... | ... |
subscriptionstore.cpp
| ... | ... | @@ -212,13 +212,17 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> |
| 212 | 212 | { |
| 213 | 213 | session = session_it->second; |
| 214 | 214 | |
| 215 | - if (session && !session->clientDisconnected()) | |
| 215 | + if (session) | |
| 216 | 216 | { |
| 217 | 217 | std::shared_ptr<Client> cl = session->makeSharedClient(); |
| 218 | - logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); | |
| 219 | - cl->setReadyForDisconnect(); | |
| 220 | - cl->getThreadData()->removeClient(cl); | |
| 221 | - cl->markAsDisconnecting(); | |
| 218 | + | |
| 219 | + if (cl) | |
| 220 | + { | |
| 221 | + logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); | |
| 222 | + cl->setReadyForDisconnect(); | |
| 223 | + cl->getThreadData()->removeClient(cl); | |
| 224 | + cl->markAsDisconnecting(); | |
| 225 | + } | |
| 222 | 226 | } |
| 223 | 227 | } |
| 224 | 228 | |
| ... | ... | @@ -255,10 +259,9 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st |
| 255 | 259 | { |
| 256 | 260 | for (const Subscription &sub : subscribers) |
| 257 | 261 | { |
| 258 | - std::weak_ptr<Session> session_weak = sub.session; | |
| 259 | - if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. | |
| 262 | + const std::shared_ptr<Session> session = sub.session.lock(); | |
| 263 | + if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect. | |
| 260 | 264 | { |
| 261 | - const std::shared_ptr<Session> session = session_weak.lock(); | |
| 262 | 265 | session->writePacket(packet, sub.qos, false, count); |
| 263 | 266 | } |
| 264 | 267 | } |
| ... | ... | @@ -469,7 +472,8 @@ int SubscriptionNode::cleanSubscriptions() |
| 469 | 472 | auto it = subscribers.begin(); |
| 470 | 473 | while (it != subscribers.end()) |
| 471 | 474 | { |
| 472 | - if (it->sessionGone()) | |
| 475 | + std::shared_ptr<Session> ses = it->session.lock(); | |
| 476 | + if (!ses) | |
| 473 | 477 | { |
| 474 | 478 | Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector"); |
| 475 | 479 | it = subscribers.erase(it); |
| ... | ... | @@ -539,9 +543,10 @@ void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std: |
| 539 | 543 | { |
| 540 | 544 | for (const Subscription &node : this_node->getSubscribers()) |
| 541 | 545 | { |
| 542 | - if (!node.sessionGone()) | |
| 546 | + std::shared_ptr<Session> ses = node.session.lock(); | |
| 547 | + if (ses) | |
| 543 | 548 | { |
| 544 | - SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos); | |
| 549 | + SubscriptionForSerializing sub(ses->getClientId(), node.qos); | |
| 545 | 550 | outputList[composedTopic].push_back(sub); |
| 546 | 551 | } |
| 547 | 552 | } |
| ... | ... | @@ -702,7 +707,7 @@ bool Subscription::operator==(const Subscription &rhs) const |
| 702 | 707 | const std::shared_ptr<Session> lhs_ses = session.lock(); |
| 703 | 708 | const std::shared_ptr<Session> rhs_ses = rhs.session.lock(); |
| 704 | 709 | |
| 705 | - return lhs_ses->getClientId() == rhs_ses->getClientId(); | |
| 710 | + return lhs_ses && rhs_ses && lhs_ses->getClientId() == rhs_ses->getClientId(); | |
| 706 | 711 | } |
| 707 | 712 | |
| 708 | 713 | void Subscription::reset() |
| ... | ... | @@ -711,11 +716,6 @@ void Subscription::reset() |
| 711 | 716 | qos = 0; |
| 712 | 717 | } |
| 713 | 718 | |
| 714 | -bool Subscription::sessionGone() const | |
| 715 | -{ | |
| 716 | - return session.expired(); | |
| 717 | -} | |
| 718 | - | |
| 719 | 719 | void RetainedMessageNode::addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount) |
| 720 | 720 | { |
| 721 | 721 | const int64_t countBefore = retainedMessages.size(); | ... | ... |