diff --git a/session.cpp b/session.cpp index 9b3a475..ae06448 100644 --- a/session.cpp +++ b/session.cpp @@ -90,11 +90,14 @@ std::unique_ptr Session::getCopy() const return s; } -bool Session::clientDisconnected() const -{ - return client.expired(); -} - +/** + * @brief Session::makeSharedClient get the client of the session, or a null when it has no active current client. + * @return Returns shared_ptr, which can contain null when the client has disconnected. + * + * The lock() operation is atomic and therefore is the only way to get the current active client without race condition, because + * typically, this method is called from other client's threads to perform writes, so you have to check validity after + * obtaining the shared pointer. + */ std::shared_ptr Session::makeSharedClient() const { return client.lock(); @@ -126,9 +129,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u { if (qos == 0) { - if (!clientDisconnected()) + std::shared_ptr c = makeSharedClient(); + + if (c) { - std::shared_ptr c = makeSharedClient(); c->writeMqttPacketAndBlameThisClient(packet, qos); count++; } @@ -150,9 +154,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u std::shared_ptr copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); locker.unlock(); - if (!clientDisconnected()) + std::shared_ptr c = makeSharedClient(); + if (c) { - std::shared_ptr c = makeSharedClient(); c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos); copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. count++; @@ -183,9 +187,9 @@ uint64_t Session::sendPendingQosMessages() { uint64_t count = 0; - if (!clientDisconnected()) + std::shared_ptr c = makeSharedClient(); + if (c) { - std::shared_ptr c = makeSharedClient(); std::lock_guard locker(qosQueueMutex); for (const std::shared_ptr &qosMessage : qosPacketQueue) { @@ -223,7 +227,7 @@ bool Session::hasExpired(int expireAfterSeconds) { std::chrono::seconds expireAfter(expireAfterSeconds); std::chrono::time_point now = std::chrono::steady_clock::now(); - return clientDisconnected() && (lastTouched + expireAfter) < now; + return client.expired() && (lastTouched + expireAfter) < now; } void Session::addIncomingQoS2MessageId(uint16_t packet_id) diff --git a/session.h b/session.h index ec2770e..c4c319d 100644 --- a/session.h +++ b/session.h @@ -66,7 +66,6 @@ public: std::unique_ptr getCopy() const; const std::string &getClientId() const { return client_id; } - bool clientDisconnected() const; std::shared_ptr makeSharedClient() const; void assignActiveConnection(std::shared_ptr &client); void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 6628ebb..0922f27 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -212,13 +212,17 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr { session = session_it->second; - if (session && !session->clientDisconnected()) + if (session) { std::shared_ptr cl = session->makeSharedClient(); - logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); - cl->setReadyForDisconnect(); - cl->getThreadData()->removeClient(cl); - cl->markAsDisconnecting(); + + if (cl) + { + logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); + cl->setReadyForDisconnect(); + cl->getThreadData()->removeClient(cl); + cl->markAsDisconnecting(); + } } } @@ -255,10 +259,9 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st { for (const Subscription &sub : subscribers) { - std::weak_ptr session_weak = sub.session; - if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. + const std::shared_ptr session = sub.session.lock(); + if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect. { - const std::shared_ptr session = session_weak.lock(); session->writePacket(packet, sub.qos, false, count); } } @@ -469,7 +472,8 @@ int SubscriptionNode::cleanSubscriptions() auto it = subscribers.begin(); while (it != subscribers.end()) { - if (it->sessionGone()) + std::shared_ptr ses = it->session.lock(); + if (!ses) { Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector"); it = subscribers.erase(it); @@ -539,9 +543,10 @@ void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std: { for (const Subscription &node : this_node->getSubscribers()) { - if (!node.sessionGone()) + std::shared_ptr ses = node.session.lock(); + if (ses) { - SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos); + SubscriptionForSerializing sub(ses->getClientId(), node.qos); outputList[composedTopic].push_back(sub); } } @@ -702,7 +707,7 @@ bool Subscription::operator==(const Subscription &rhs) const const std::shared_ptr lhs_ses = session.lock(); const std::shared_ptr rhs_ses = rhs.session.lock(); - return lhs_ses->getClientId() == rhs_ses->getClientId(); + return lhs_ses && rhs_ses && lhs_ses->getClientId() == rhs_ses->getClientId(); } void Subscription::reset() @@ -711,11 +716,6 @@ void Subscription::reset() qos = 0; } -bool Subscription::sessionGone() const -{ - return session.expired(); -} - void RetainedMessageNode::addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount) { const int64_t countBefore = retainedMessages.size(); diff --git a/subscriptionstore.h b/subscriptionstore.h index 93e3dc6..078aad5 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -39,7 +39,6 @@ struct Subscription char qos; bool operator==(const Subscription &rhs) const; void reset(); - bool sessionGone() const; }; class SubscriptionNode