Commit 43f54d96d7b8400452b67011d9f2900b7282486d
1 parent
3eb55c4d
Solve crashes because of race conditions in sessions and clients
One race was confirmed: an MQTT packet was written to a client that had disconnected after its weak pointer was checked for validity but before it was locked. The remaining check-then-lock sites made sense to change as well.
Showing 4 changed files with 33 additions and 31 deletions
session.cpp
| @@ -90,11 +90,14 @@ std::unique_ptr<Session> Session::getCopy() const | @@ -90,11 +90,14 @@ std::unique_ptr<Session> Session::getCopy() const | ||
| 90 | return s; | 90 | return s; |
| 91 | } | 91 | } |
| 92 | 92 | ||
| 93 | -bool Session::clientDisconnected() const | ||
| 94 | -{ | ||
| 95 | - return client.expired(); | ||
| 96 | -} | ||
| 97 | - | 93 | +/** |
| 94 | + * @brief Session::makeSharedClient gets the client of the session, or null when it has no active client. | ||
| 95 | + * @return Returns shared_ptr<Client>, which can contain null when the client has disconnected. | ||
| 96 | + * | ||
| 97 | + * The lock() operation is atomic and is therefore the only way to get the currently active client without a race condition, because | ||
| 98 | + * this method is typically called from other clients' threads to perform writes. You therefore have to check validity after | ||
| 99 | + * obtaining the shared pointer. | ||
| 100 | + */ | ||
| 98 | std::shared_ptr<Client> Session::makeSharedClient() const | 101 | std::shared_ptr<Client> Session::makeSharedClient() const |
| 99 | { | 102 | { |
| 100 | return client.lock(); | 103 | return client.lock(); |
| @@ -126,9 +129,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u | @@ -126,9 +129,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u | ||
| 126 | { | 129 | { |
| 127 | if (qos == 0) | 130 | if (qos == 0) |
| 128 | { | 131 | { |
| 129 | - if (!clientDisconnected()) | 132 | + std::shared_ptr<Client> c = makeSharedClient(); |
| 133 | + | ||
| 134 | + if (c) | ||
| 130 | { | 135 | { |
| 131 | - std::shared_ptr<Client> c = makeSharedClient(); | ||
| 132 | c->writeMqttPacketAndBlameThisClient(packet, qos); | 136 | c->writeMqttPacketAndBlameThisClient(packet, qos); |
| 133 | count++; | 137 | count++; |
| 134 | } | 138 | } |
| @@ -150,9 +154,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u | @@ -150,9 +154,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u | ||
| 150 | std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); | 154 | std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); |
| 151 | locker.unlock(); | 155 | locker.unlock(); |
| 152 | 156 | ||
| 153 | - if (!clientDisconnected()) | 157 | + std::shared_ptr<Client> c = makeSharedClient(); |
| 158 | + if (c) | ||
| 154 | { | 159 | { |
| 155 | - std::shared_ptr<Client> c = makeSharedClient(); | ||
| 156 | c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos); | 160 | c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos); |
| 157 | copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | 161 | copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. |
| 158 | count++; | 162 | count++; |
| @@ -183,9 +187,9 @@ uint64_t Session::sendPendingQosMessages() | @@ -183,9 +187,9 @@ uint64_t Session::sendPendingQosMessages() | ||
| 183 | { | 187 | { |
| 184 | uint64_t count = 0; | 188 | uint64_t count = 0; |
| 185 | 189 | ||
| 186 | - if (!clientDisconnected()) | 190 | + std::shared_ptr<Client> c = makeSharedClient(); |
| 191 | + if (c) | ||
| 187 | { | 192 | { |
| 188 | - std::shared_ptr<Client> c = makeSharedClient(); | ||
| 189 | std::lock_guard<std::mutex> locker(qosQueueMutex); | 193 | std::lock_guard<std::mutex> locker(qosQueueMutex); |
| 190 | for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) | 194 | for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) |
| 191 | { | 195 | { |
| @@ -223,7 +227,7 @@ bool Session::hasExpired(int expireAfterSeconds) | @@ -223,7 +227,7 @@ bool Session::hasExpired(int expireAfterSeconds) | ||
| 223 | { | 227 | { |
| 224 | std::chrono::seconds expireAfter(expireAfterSeconds); | 228 | std::chrono::seconds expireAfter(expireAfterSeconds); |
| 225 | std::chrono::time_point<std::chrono::steady_clock> now = std::chrono::steady_clock::now(); | 229 | std::chrono::time_point<std::chrono::steady_clock> now = std::chrono::steady_clock::now(); |
| 226 | - return clientDisconnected() && (lastTouched + expireAfter) < now; | 230 | + return client.expired() && (lastTouched + expireAfter) < now; |
| 227 | } | 231 | } |
| 228 | 232 | ||
| 229 | void Session::addIncomingQoS2MessageId(uint16_t packet_id) | 233 | void Session::addIncomingQoS2MessageId(uint16_t packet_id) |
session.h
| @@ -66,7 +66,6 @@ public: | @@ -66,7 +66,6 @@ public: | ||
| 66 | std::unique_ptr<Session> getCopy() const; | 66 | std::unique_ptr<Session> getCopy() const; |
| 67 | 67 | ||
| 68 | const std::string &getClientId() const { return client_id; } | 68 | const std::string &getClientId() const { return client_id; } |
| 69 | - bool clientDisconnected() const; | ||
| 70 | std::shared_ptr<Client> makeSharedClient() const; | 69 | std::shared_ptr<Client> makeSharedClient() const; |
| 71 | void assignActiveConnection(std::shared_ptr<Client> &client); | 70 | void assignActiveConnection(std::shared_ptr<Client> &client); |
| 72 | void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); | 71 | void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); |
subscriptionstore.cpp
| @@ -212,13 +212,17 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> | @@ -212,13 +212,17 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> | ||
| 212 | { | 212 | { |
| 213 | session = session_it->second; | 213 | session = session_it->second; |
| 214 | 214 | ||
| 215 | - if (session && !session->clientDisconnected()) | 215 | + if (session) |
| 216 | { | 216 | { |
| 217 | std::shared_ptr<Client> cl = session->makeSharedClient(); | 217 | std::shared_ptr<Client> cl = session->makeSharedClient(); |
| 218 | - logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); | ||
| 219 | - cl->setReadyForDisconnect(); | ||
| 220 | - cl->getThreadData()->removeClient(cl); | ||
| 221 | - cl->markAsDisconnecting(); | 218 | + |
| 219 | + if (cl) | ||
| 220 | + { | ||
| 221 | + logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); | ||
| 222 | + cl->setReadyForDisconnect(); | ||
| 223 | + cl->getThreadData()->removeClient(cl); | ||
| 224 | + cl->markAsDisconnecting(); | ||
| 225 | + } | ||
| 222 | } | 226 | } |
| 223 | } | 227 | } |
| 224 | 228 | ||
| @@ -255,10 +259,9 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st | @@ -255,10 +259,9 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st | ||
| 255 | { | 259 | { |
| 256 | for (const Subscription &sub : subscribers) | 260 | for (const Subscription &sub : subscribers) |
| 257 | { | 261 | { |
| 258 | - std::weak_ptr<Session> session_weak = sub.session; | ||
| 259 | - if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. | 262 | + const std::shared_ptr<Session> session = sub.session.lock(); |
| 263 | + if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect. | ||
| 260 | { | 264 | { |
| 261 | - const std::shared_ptr<Session> session = session_weak.lock(); | ||
| 262 | session->writePacket(packet, sub.qos, false, count); | 265 | session->writePacket(packet, sub.qos, false, count); |
| 263 | } | 266 | } |
| 264 | } | 267 | } |
| @@ -469,7 +472,8 @@ int SubscriptionNode::cleanSubscriptions() | @@ -469,7 +472,8 @@ int SubscriptionNode::cleanSubscriptions() | ||
| 469 | auto it = subscribers.begin(); | 472 | auto it = subscribers.begin(); |
| 470 | while (it != subscribers.end()) | 473 | while (it != subscribers.end()) |
| 471 | { | 474 | { |
| 472 | - if (it->sessionGone()) | 475 | + std::shared_ptr<Session> ses = it->session.lock(); |
| 476 | + if (!ses) | ||
| 473 | { | 477 | { |
| 474 | Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector"); | 478 | Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector"); |
| 475 | it = subscribers.erase(it); | 479 | it = subscribers.erase(it); |
| @@ -539,9 +543,10 @@ void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std: | @@ -539,9 +543,10 @@ void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std: | ||
| 539 | { | 543 | { |
| 540 | for (const Subscription &node : this_node->getSubscribers()) | 544 | for (const Subscription &node : this_node->getSubscribers()) |
| 541 | { | 545 | { |
| 542 | - if (!node.sessionGone()) | 546 | + std::shared_ptr<Session> ses = node.session.lock(); |
| 547 | + if (ses) | ||
| 543 | { | 548 | { |
| 544 | - SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos); | 549 | + SubscriptionForSerializing sub(ses->getClientId(), node.qos); |
| 545 | outputList[composedTopic].push_back(sub); | 550 | outputList[composedTopic].push_back(sub); |
| 546 | } | 551 | } |
| 547 | } | 552 | } |
| @@ -702,7 +707,7 @@ bool Subscription::operator==(const Subscription &rhs) const | @@ -702,7 +707,7 @@ bool Subscription::operator==(const Subscription &rhs) const | ||
| 702 | const std::shared_ptr<Session> lhs_ses = session.lock(); | 707 | const std::shared_ptr<Session> lhs_ses = session.lock(); |
| 703 | const std::shared_ptr<Session> rhs_ses = rhs.session.lock(); | 708 | const std::shared_ptr<Session> rhs_ses = rhs.session.lock(); |
| 704 | 709 | ||
| 705 | - return lhs_ses->getClientId() == rhs_ses->getClientId(); | 710 | + return lhs_ses && rhs_ses && lhs_ses->getClientId() == rhs_ses->getClientId(); |
| 706 | } | 711 | } |
| 707 | 712 | ||
| 708 | void Subscription::reset() | 713 | void Subscription::reset() |
| @@ -711,11 +716,6 @@ void Subscription::reset() | @@ -711,11 +716,6 @@ void Subscription::reset() | ||
| 711 | qos = 0; | 716 | qos = 0; |
| 712 | } | 717 | } |
| 713 | 718 | ||
| 714 | -bool Subscription::sessionGone() const | ||
| 715 | -{ | ||
| 716 | - return session.expired(); | ||
| 717 | -} | ||
| 718 | - | ||
| 719 | void RetainedMessageNode::addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount) | 719 | void RetainedMessageNode::addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount) |
| 720 | { | 720 | { |
| 721 | const int64_t countBefore = retainedMessages.size(); | 721 | const int64_t countBefore = retainedMessages.size(); |
subscriptionstore.h
| @@ -39,7 +39,6 @@ struct Subscription | @@ -39,7 +39,6 @@ struct Subscription | ||
| 39 | char qos; | 39 | char qos; |
| 40 | bool operator==(const Subscription &rhs) const; | 40 | bool operator==(const Subscription &rhs) const; |
| 41 | void reset(); | 41 | void reset(); |
| 42 | - bool sessionGone() const; | ||
| 43 | }; | 42 | }; |
| 44 | 43 | ||
| 45 | class SubscriptionNode | 44 | class SubscriptionNode |