diff --git a/session.cpp b/session.cpp index 15a9c2a..853dbb3 100644 --- a/session.cpp +++ b/session.cpp @@ -72,6 +72,11 @@ Session::Session(const Session &other) this->nextPacketId = other.nextPacketId; this->sessionExpiryInterval = other.sessionExpiryInterval; this->willPublish = other.willPublish; + this->removalQueued = other.removalQueued; + this->removalQueuedAt = other.removalQueuedAt; + + + // TODO: perhaps this copy constructor is nonsense now. // TODO: see git history for a change here. We now copy the whole queued publish. Do we want to address that? this->qosPacketQueue = other.qosPacketQueue; @@ -107,6 +112,7 @@ void Session::assignActiveConnection(std::shared_ptr &client) this->client_id = client->getClientId(); this->username = client->getUsername(); this->willPublish = client->getWill(); + this->removalQueued = false; } /** @@ -371,8 +377,25 @@ void Session::setSessionExpiryInterval(uint32_t newVal) this->sessionExpiryInterval = newVal; } +void Session::setQueuedRemovalAt() +{ + this->removalQueuedAt = std::chrono::steady_clock::now(); + this->removalQueued = true; +} + uint32_t Session::getSessionExpiryInterval() const { return this->sessionExpiryInterval; } +uint32_t Session::getCurrentSessionExpiryInterval() const +{ + if (!this->removalQueued || hasActiveClient()) + return this->sessionExpiryInterval; + + const std::chrono::seconds age = std::chrono::duration_cast(std::chrono::steady_clock::now() - this->removalQueuedAt); + const uint32_t ageInSeconds = age.count(); + const uint32_t result = ageInSeconds <= this->sessionExpiryInterval ? this->sessionExpiryInterval - age.count() : 0; + return result; +} + diff --git a/session.h b/session.h index 173074b..f21ef41 100644 --- a/session.h +++ b/session.h @@ -51,6 +51,8 @@ class Session uint16_t QoSLogPrintedAtId = 0; bool destroyOnDisconnect = false; std::shared_ptr willPublish; + bool removalQueued = false; + std::chrono::time_point removalQueuedAt; Logger *logger = Logger::getInstance(); bool requiresPacketRetransmission() const; @@ -87,7 +89,9 @@ public: void setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval, bool clean_start, ProtocolVersion protocol_version); void setSessionExpiryInterval(uint32_t newVal); + void setQueuedRemovalAt(); uint32_t getSessionExpiryInterval() const; + uint32_t getCurrentSessionExpiryInterval() const; }; #endif // SESSION_H diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp index c346875..823613e 100644 --- a/sessionsandsubscriptionsdb.cpp +++ b/sessionsandsubscriptionsdb.cpp @@ -313,7 +313,7 @@ void SessionsAndSubscriptionsDB::saveData(const std::vectorlogf(LOG_DEBUG, "Writing next packetid %d.", ses->nextPacketId); writeUint16(ses->nextPacketId); - writeUint32(ses->sessionExpiryInterval); + writeUint32(ses->getCurrentSessionExpiryInterval()); writeUint16(ses->maxQosMsgPending); const bool hasWillThatShouldSurviveRestart = ses->getWill().operator bool() && ses->getWill()->will_delay > 0; diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 5be6b50..ccbd0f3 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -695,6 +695,8 @@ void SubscriptionStore::queueSessionRemoval(const std::shared_ptr &sess return a.getExpiresAt() < b.getExpiresAt(); }; + session->setQueuedRemovalAt(); + std::lock_guard(this->queuedSessionRemovalsMutex); auto pos = std::upper_bound(this->queuedSessionRemovals.begin(), this->queuedSessionRemovals.end(), qsr, comp); this->queuedSessionRemovals.insert(pos, qsr);