diff --git a/client.cpp b/client.cpp index 45a9255..53ad438 100644 --- a/client.cpp +++ b/client.cpp @@ -55,12 +55,18 @@ Client::~Client() if (!this->threadData) return; - std::shared_ptr &store = this->threadData->getSubscriptionStore(); - if (disconnectReason.empty()) disconnectReason = "not specified"; logger->logf(LOG_NOTICE, "Removing client '%s'. Reason(s): %s", repr().c_str(), disconnectReason.c_str()); + + std::shared_ptr &store = this->threadData->getSubscriptionStore(); + + if (willPublish) + { + store->queueWillMessage(willPublish, session); + } + if (fd > 0) // this check is essentially for testing, when working with a dummy fd. { if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) diff --git a/forward_declarations.h b/forward_declarations.h index fc9813a..052f004 100644 --- a/forward_declarations.h +++ b/forward_declarations.h @@ -27,6 +27,7 @@ class SubscriptionStore; class Session; class Settings; class Mqtt5PropertyBuilder; +class SessionsAndSubscriptionsDB; #endif // FORWARD_DECLARATIONS_H diff --git a/mqttpacket.cpp b/mqttpacket.cpp index cc1f77a..aac9e69 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -447,7 +447,6 @@ void MqttPacket::handleConnect() { case Mqtt5Properties::WillDelayInterval: willpublish.will_delay = readFourBytesToUint32(); - willpublish.setCreatedAt(std::chrono::steady_clock::now()); break; case Mqtt5Properties::PayloadFormatIndicator: willpublish.propertyBuilder->writePayloadFormatIndicator(readByte()); diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp index ec14b52..f5575da 100644 --- a/sessionsandsubscriptionsdb.cpp +++ b/sessionsandsubscriptionsdb.cpp @@ -130,7 +130,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() Publish pub(pack.getPublishData()); const uint32_t newPubAge = persistence_state_age + originalPubAge; - pub.setCreatedAt(timepointFromAge(newPubAge)); + pub.createdAt = timepointFromAge(newPubAge); logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s' for session '%s'.", pub.qos, pub.topic.c_str(), ses->getClientId().c_str()); ses->qosPacketQueue.queuePublish(std::move(pub), id); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 7526112..55795b0 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -280,6 +280,9 @@ std::shared_ptr SubscriptionStore::lockSession(const std::string &clien * The expiry interval as set in the properties of the will message is not used to check for expiration here. To * quote the specs: "If present, the Four Byte value is the lifetime of the Will Message in seconds and is sent as * the Publication Expiry Interval when the Server publishes the Will Message." + * + * If a new Network Connection to this Session is made before the Will Delay Interval has passed, the Server + * MUST NOT send the Will Message [MQTT-3.1.3-9]. */ void SubscriptionStore::sendQueuedWillMessages() { @@ -289,15 +292,27 @@ void SubscriptionStore::sendQueuedWillMessages() auto it = pendingWillMessages.begin(); while (it != pendingWillMessages.end()) { - std::shared_ptr p = (*it).lock(); + QueuedWill &qw = *it; + + std::shared_ptr p = qw.getWill().lock(); if (p) { - if (p->getCreatedAt() + std::chrono::seconds(p->will_delay) > now) + if (qw.getSendAt() > now) break; + std::shared_ptr s = qw.getSession(); + + if (!s || s->hasActiveClient()) + { + it = pendingWillMessages.erase(it); + continue; + } + logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() ); PublishCopyFactory factory(p.get()); queuePacketAtSubscribers(factory); + + s->clearWill(); } it = pendingWillMessages.erase(it); } @@ -308,7 +323,7 @@ void SubscriptionStore::sendQueuedWillMessages() * @param willMessage * @param forceNow */ -void SubscriptionStore::queueWillMessage(std::shared_ptr &willMessage, bool forceNow) +void SubscriptionStore::queueWillMessage(const std::shared_ptr &willMessage, const std::shared_ptr &session, bool forceNow) { if (!willMessage) return; @@ -320,12 +335,19 @@ void SubscriptionStore::queueWillMessage(std::shared_ptr &willMessage, { PublishCopyFactory factory(willMessage.get()); queuePacketAtSubscribers(factory); + + // Avoid sending two immediate wills when a session is destroyed with the client disconnect. + if (session) // session is null when you're destroying a client before a session is assigned. + session->clearWill(); + return; } + QueuedWill queuedWill(willMessage, session); + std::lock_guard(this->pendingWillsMutex); - auto pos = std::upper_bound(this->pendingWillMessages.begin(), this->pendingWillMessages.end(), willMessage, WillDelayCompare); - this->pendingWillMessages.insert(pos, willMessage); + auto pos = std::upper_bound(this->pendingWillMessages.begin(), this->pendingWillMessages.end(), willMessage, willDelayCompare); + this->pendingWillMessages.insert(pos, queuedWill); } void SubscriptionStore::publishNonRecursively(const std::unordered_map &subscribers, @@ -588,7 +610,7 @@ void SubscriptionStore::removeSession(const std::shared_ptr &session) std::shared_ptr &will = session->getWill(); if (will) { - queueWillMessage(will, true); + queueWillMessage(will, session, true); } RWLockGuard lock_guard(&subscriptionsRwlock); @@ -984,3 +1006,36 @@ std::shared_ptr QueuedSessionRemoval::getSession() const { return session.lock(); } + +QueuedWill::QueuedWill(const std::shared_ptr &will, const std::shared_ptr &session) : + will(will), + session(session), + sendAt(std::chrono::steady_clock::now() + std::chrono::seconds(will->will_delay)) +{ + +} + +const std::weak_ptr &QueuedWill::getWill() const +{ + return this->will; +} + +std::chrono::time_point QueuedWill::getSendAt() const +{ + return this->sendAt; +} + +std::shared_ptr QueuedWill::getSession() +{ + return this->session.lock(); +} + +bool willDelayCompare(const std::shared_ptr &a, const QueuedWill &b) +{ + std::shared_ptr _b = b.getWill().lock(); + + if (!_b) + return true; + + return a->will_delay < _b->will_delay; +}; diff --git a/subscriptionstore.h b/subscriptionstore.h index e23513e..63c1f84 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -101,6 +101,20 @@ public: std::shared_ptr getSession() const; }; +class QueuedWill +{ + std::weak_ptr will; + std::weak_ptr session; + std::chrono::time_point sendAt; + +public: + QueuedWill(const std::shared_ptr &will, const std::shared_ptr &session); + + const std::weak_ptr &getWill() const; + std::chrono::time_point getSendAt() const; + std::shared_ptr getSession(); +}; + class SubscriptionStore { #ifdef TESTING @@ -122,7 +136,7 @@ class SubscriptionStore int64_t retainedMessageCount = 0; std::mutex pendingWillsMutex; - std::list> pendingWillMessages; + std::list pendingWillMessages; std::chrono::time_point lastTreeCleanup; @@ -151,7 +165,7 @@ public: std::shared_ptr lockSession(const std::string &clientid); void sendQueuedWillMessages(); - void queueWillMessage(std::shared_ptr &willMessage, bool forceNow = false); + void queueWillMessage(const std::shared_ptr &willMessage, const std::shared_ptr &session, bool forceNow = false); void queuePacketAtSubscribers(PublishCopyFactory ©Factory, bool dollar = false); uint64_t giveClientRetainedMessages(const std::shared_ptr &client, const std::shared_ptr &ses, const std::vector &subscribeSubtopics, char max_qos); @@ -174,4 +188,6 @@ public: void queueSessionRemoval(const std::shared_ptr &session); }; +bool willDelayCompare(const std::shared_ptr &a, const QueuedWill &b); + #endif // SUBSCRIPTIONSTORE_H diff --git a/types.cpp b/types.cpp index cc125d1..9cbd2da 100644 --- a/types.cpp +++ b/types.cpp @@ -187,11 +187,6 @@ bool PublishBase::hasExpired() const return (expiresAfter > age); } -void PublishBase::setCreatedAt(std::chrono::time_point t) -{ - this->createdAt = t; -} - void PublishBase::setExpireAfter(uint32_t s) { this->createdAt = std::chrono::steady_clock::now(); @@ -221,16 +216,6 @@ Publish::Publish(const std::string &topic, const std::string &payload, char qos) } -bool WillDelayCompare(const std::shared_ptr &a, const std::weak_ptr &b) -{ - std::shared_ptr _b = b.lock(); - - if (!_b) - return true; - - return a->will_delay < _b->will_delay; -}; - PubResponse::PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id) : packet_type(packet_type), protocol_version(protVersion), diff --git a/types.h b/types.h index b68b1f4..3afa14e 100644 --- a/types.h +++ b/types.h @@ -192,6 +192,8 @@ public: */ class PublishBase { + friend class SessionsAndSubscriptionsDB; + bool hasExpireInfo = false; std::chrono::time_point createdAt; std::chrono::seconds expiresAfter; @@ -201,7 +203,7 @@ public: std::string payload; char qos = 0; bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] - uint32_t will_delay = 0; // if will, this is the delay. Just storing here, to avoid having to make a WillMessage class + uint32_t will_delay = 0; // if will, this is the delay. bool splitTopic = true; uint16_t topicAlias = 0; bool skipTopic = false; @@ -216,7 +218,6 @@ public: bool hasUserProperties() const; bool hasExpired() const; - void setCreatedAt(std::chrono::time_point t); void setExpireAfter(uint32_t s); bool getHasExpireInfo() const; const std::chrono::time_point getCreatedAt() const; @@ -232,8 +233,6 @@ public: Publish(const std::string &topic, const std::string &payload, char qos); }; -bool WillDelayCompare(const std::shared_ptr &a, const std::weak_ptr &b); - class PubResponse { public: