diff --git a/client.cpp b/client.cpp index 1d7c6a2..c6add14 100644 --- a/client.cpp +++ b/client.cpp @@ -509,9 +509,9 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str this->maxOutgoingTopicAliasValue = maxOutgoingTopicAliasValue; } -void Client::setWill(Publish &&willPublish) +void Client::setWill(WillPublish &&willPublish) { - this->willPublish = std::make_shared(std::move(willPublish)); + this->willPublish = std::make_shared(std::move(willPublish)); } void Client::assignSession(std::shared_ptr &session) diff --git a/client.h b/client.h index d8da2bc..b2558e5 100644 --- a/client.h +++ b/client.h @@ -77,7 +77,7 @@ class Client std::string username; uint16_t keepalive = 0; - std::shared_ptr willPublish; + std::shared_ptr willPublish; std::shared_ptr threadData; std::mutex writeBufMutex; @@ -115,7 +115,7 @@ public: void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, uint32_t maxOutgoingPacketSize, uint16_t maxOutgoingTopicAliasValue); void setWill(const std::string &topic, const std::string &payload, bool retain, char qos); - void setWill(Publish &&willPublish); + void setWill(WillPublish &&willPublish); void clearWill(); void setAuthenticated(bool value) { authenticated = value;} bool getAuthenticated() { return authenticated; } @@ -123,7 +123,7 @@ public: std::shared_ptr getThreadData() { return threadData; } std::string &getClientId() { return this->clientid; } const std::string &getUsername() const { return this->username; } - std::shared_ptr &getWill() { return this->willPublish; } + std::shared_ptr &getWill() { return this->willPublish; } void assignSession(std::shared_ptr &session); std::shared_ptr getSession(); void setDisconnectReason(const std::string &reason); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 1fc8ab0..02aaf4a 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -423,7 +423,7 @@ void MqttPacket::handleConnect() std::string username; std::string password; - Publish willpublish; + WillPublish willpublish; willpublish.qos = will_qos; willpublish.retain = will_retain; diff --git a/session.cpp b/session.cpp index 37f1d99..15a9c2a 100644 --- a/session.cpp +++ b/session.cpp @@ -71,6 +71,7 @@ Session::Session(const Session &other) this->outgoingQoS2MessageIds = other.outgoingQoS2MessageIds; this->nextPacketId = other.nextPacketId; this->sessionExpiryInterval = other.sessionExpiryInterval; + this->willPublish = other.willPublish; // TODO: see git history for a change here. We now copy the whole queued publish. Do we want to address that? this->qosPacketQueue = other.qosPacketQueue; @@ -276,11 +277,16 @@ void Session::clearWill() this->willPublish.reset(); } -std::shared_ptr &Session::getWill() +std::shared_ptr &Session::getWill() { return this->willPublish; } +void Session::setWill(WillPublish &&pub) +{ + this->willPublish = std::make_shared(std::move(pub)); +} + void Session::addIncomingQoS2MessageId(uint16_t packet_id) { assert(packet_id > 0); diff --git a/session.h b/session.h index 9882638..173074b 100644 --- a/session.h +++ b/session.h @@ -50,7 +50,7 @@ class Session uint16_t maxQosMsgPending; uint16_t QoSLogPrintedAtId = 0; bool destroyOnDisconnect = false; - std::shared_ptr willPublish; + std::shared_ptr willPublish; Logger *logger = Logger::getInstance(); bool requiresPacketRetransmission() const; @@ -73,7 +73,8 @@ public: uint64_t sendPendingQosMessages(); bool hasActiveClient() const; void clearWill(); - std::shared_ptr &getWill(); + std::shared_ptr &getWill(); + void setWill(WillPublish &&pub); void addIncomingQoS2MessageId(uint16_t packet_id); bool incomingQoS2MessageIdInTransit(uint16_t packet_id); diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp index 7c5b827..c346875 100644 --- a/sessionsandsubscriptionsdb.cpp +++ b/sessionsandsubscriptionsdb.cpp @@ -167,6 +167,32 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() // it with a more relevant value. // The protocol version 5 is just dummy, to get the behavior I want. ses->setSessionProperties(maxQosPending, sessionExpiryInterval, 0, ProtocolVersion::Mqtt5); + + const uint16_t hasWill = readUint16(eofFound); + + if (hasWill) + { + const uint16_t fixed_header_length = readUint16(eofFound); + const uint32_t originalWillDelay = readUint32(eofFound); + const uint32_t originalWillQueueAge = readUint32(eofFound); + const uint32_t newWillDelayAfterMaybeAlreadyBeingQueued = originalWillQueueAge < originalWillDelay ? originalWillDelay - originalWillQueueAge : 0; + const uint32_t packlen = readUint32(eofFound); + + const uint32_t stateAgecompensatedWillDelay = + persistence_state_age > newWillDelayAfterMaybeAlreadyBeingQueued ? 0 : newWillDelayAfterMaybeAlreadyBeingQueued - persistence_state_age; + + cirbuf.reset(); + cirbuf.ensureFreeSpace(packlen + 32); + + readCheck(cirbuf.headPtr(), 1, packlen, f); + cirbuf.advanceHead(packlen); + MqttPacket publishpack(cirbuf, packlen, fixed_header_length, dummyClient); + publishpack.parsePublishData(); + WillPublish willPublish = publishpack.getPublishData(); + willPublish.will_delay = stateAgecompensatedWillDelay; + + ses->setWill(std::move(willPublish)); + } } const uint32_t nrOfSubscriptions = readUint32(eofFound); @@ -289,6 +315,30 @@ void SessionsAndSubscriptionsDB::saveData(const std::vectorsessionExpiryInterval); writeUint16(ses->maxQosMsgPending); + + const bool hasWillThatShouldSurviveRestart = ses->getWill().operator bool() && ses->getWill()->will_delay > 0; + writeUint16(static_cast(hasWillThatShouldSurviveRestart)); + + if (hasWillThatShouldSurviveRestart) + { + WillPublish &will = *ses->getWill().get(); + MqttPacket willpacket(ProtocolVersion::Mqtt5, will); + + // Dummy, to please the parser on reading. + if (will.qos > 0) + willpacket.setPacketId(666); + + const uint32_t packSize = willpacket.getSizeIncludingNonPresentHeader(); + cirbuf.reset(); + cirbuf.ensureFreeSpace(packSize + 32); + willpacket.readIntoBuf(cirbuf); + + writeUint16(willpacket.getFixedHeaderLength()); + writeUint32(will.will_delay); + writeUint32(will.getQueuedAtAge()); + writeUint32(packSize); + writeCheck(cirbuf.tailPtr(), 1, cirbuf.usedBytes(), f); + } } writeUint32(subscriptions.size()); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index e3052f2..5be6b50 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -326,7 +326,7 @@ void SubscriptionStore::sendQueuedWillMessages() * @param willMessage * @param forceNow */ -void SubscriptionStore::queueWillMessage(const std::shared_ptr &willMessage, const std::shared_ptr &session, bool forceNow) +void SubscriptionStore::queueWillMessage(const std::shared_ptr &willMessage, const std::shared_ptr &session, bool forceNow) { if (!willMessage) return; @@ -349,6 +349,8 @@ void SubscriptionStore::queueWillMessage(const std::shared_ptr &willMes return; } + willMessage->setQueuedAt(); + QueuedWill queuedWill(willMessage, session); std::lock_guard(this->pendingWillsMutex); @@ -613,7 +615,7 @@ void SubscriptionStore::removeSession(const std::shared_ptr &session) const std::string &clientid = session->getClientId(); logger->logf(LOG_DEBUG, "Removing session of client '%s'.", clientid.c_str()); - std::shared_ptr &will = session->getWill(); + std::shared_ptr &will = session->getWill(); if (will) { queueWillMessage(will, session, true); @@ -905,6 +907,7 @@ void SubscriptionStore::loadSessionsAndSubscriptions(const std::string &filePath { sessionsById[session->getClientId()] = session; queueSessionRemoval(session); + queueWillMessage(session->getWill(), session); } std::vector subtopics; @@ -1013,7 +1016,7 @@ std::shared_ptr QueuedSessionRemoval::getSession() const return session.lock(); } -QueuedWill::QueuedWill(const std::shared_ptr &will, const std::shared_ptr &session) : +QueuedWill::QueuedWill(const std::shared_ptr &will, const std::shared_ptr &session) : will(will), session(session), sendAt(std::chrono::steady_clock::now() + std::chrono::seconds(will->will_delay)) @@ -1021,7 +1024,7 @@ QueuedWill::QueuedWill(const std::shared_ptr &will, const std::shared_p } -const std::weak_ptr &QueuedWill::getWill() const +const std::weak_ptr &QueuedWill::getWill() const { return this->will; } @@ -1036,9 +1039,9 @@ std::shared_ptr QueuedWill::getSession() return this->session.lock(); } -bool willDelayCompare(const std::shared_ptr &a, const QueuedWill &b) +bool willDelayCompare(const std::shared_ptr &a, const QueuedWill &b) { - std::shared_ptr _b = b.getWill().lock(); + std::shared_ptr _b = b.getWill().lock(); if (!_b) return true; diff --git a/subscriptionstore.h b/subscriptionstore.h index 3124d3d..db46f61 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -103,14 +103,14 @@ public: class QueuedWill { - std::weak_ptr will; + std::weak_ptr will; std::weak_ptr session; std::chrono::time_point sendAt; public: - QueuedWill(const std::shared_ptr &will, const std::shared_ptr &session); + QueuedWill(const std::shared_ptr &will, const std::shared_ptr &session); - const std::weak_ptr &getWill() const; + const std::weak_ptr &getWill() const; std::chrono::time_point getSendAt() const; std::shared_ptr getSession(); }; @@ -165,7 +165,7 @@ public: std::shared_ptr lockSession(const std::string &clientid); void sendQueuedWillMessages(); - void queueWillMessage(const std::shared_ptr &willMessage, const std::shared_ptr &session, bool forceNow = false); + void queueWillMessage(const std::shared_ptr &willMessage, const std::shared_ptr &session, bool forceNow = false); void queuePacketAtSubscribers(PublishCopyFactory ©Factory, bool dollar = false); uint64_t giveClientRetainedMessages(const std::shared_ptr &ses, const std::vector &subscribeSubtopics, char max_qos); @@ -188,6 +188,6 @@ public: void queueSessionRemoval(const std::shared_ptr &session); }; -bool willDelayCompare(const std::shared_ptr &a, const QueuedWill &b); +bool willDelayCompare(const std::shared_ptr &a, const QueuedWill &b); #endif // SUBSCRIPTIONSTORE_H diff --git a/types.cpp b/types.cpp index 9cbd2da..d9b5c0b 100644 --- a/types.cpp +++ b/types.cpp @@ -216,6 +216,34 @@ Publish::Publish(const std::string &topic, const std::string &payload, char qos) } +WillPublish::WillPublish(const Publish &other) : + Publish(other) +{ + +} + +void WillPublish::setQueuedAt() +{ + this->isQueued = true; + this->queuedAt = std::chrono::steady_clock::now(); +} + +/** + * @brief WillPublish::getQueuedAtAge gets the time ago in seconds when this will was queued. The time is set externally by the queue action. + * @return + * + * This age is required when saving wills to disk, because the new will delay to set on load is not the original will delay, but minus the + * elapsed time after queueing. + */ +uint32_t WillPublish::getQueuedAtAge() const +{ + if (!isQueued) + return 0; + + const std::chrono::seconds age = std::chrono::duration_cast(std::chrono::steady_clock::now() - this->queuedAt); + return age.count(); +} + PubResponse::PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id) : packet_type(packet_type), protocol_version(protVersion), diff --git a/types.h b/types.h index 3afa14e..0c4697a 100644 --- a/types.h +++ b/types.h @@ -203,7 +203,6 @@ public: std::string payload; char qos = 0; bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] - uint32_t will_delay = 0; // if will, this is the delay. bool splitTopic = true; uint16_t topicAlias = 0; bool skipTopic = false; @@ -233,6 +232,18 @@ public: Publish(const std::string &topic, const std::string &payload, char qos); }; +class WillPublish : public Publish +{ + bool isQueued = false; + std::chrono::time_point queuedAt; +public: + uint32_t will_delay = 0; + WillPublish() = default; + WillPublish(const Publish &other); + void setQueuedAt(); + uint32_t getQueuedAtAge() const; +}; + class PubResponse { public: