diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 43b8c00..f2c7cd1 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -1007,14 +1007,12 @@ void MainTests::testSavingSessions() std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings, false)); c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60); store->registerClientAndKickExistingOne(c1, false, 512, 120); - c1->getSession()->touch(); c1->getSession()->addIncomingQoS2MessageId(2); c1->getSession()->addIncomingQoS2MessageId(3); std::shared_ptr c2(new Client(0, t, nullptr, false, nullptr, settings, false)); c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60); store->registerClientAndKickExistingOne(c2, false, 512, 120); - c2->getSession()->touch(); c2->getSession()->addOutgoingQoS2MessageId(55); c2->getSession()->addOutgoingQoS2MessageId(66); diff --git a/client.cpp b/client.cpp index 03fb9bd..81fbd20 100644 --- a/client.cpp +++ b/client.cpp @@ -154,8 +154,6 @@ bool Client::readFdIntoBuffer() } lastActivity = std::chrono::steady_clock::now(); - if (session) - session->touch(lastActivity); return true; } diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 57f44de..27a9883 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -568,7 +568,15 @@ void MqttPacket::handleConnect() if (accessGranted) { - bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_start && subscriptionStore->sessionPresent(client_id); + bool sessionPresent = false; + std::shared_ptr existingSession; + + if (protocolVersion >= ProtocolVersion::Mqtt311 && !clean_start) + { + existingSession = subscriptionStore->lockSession(client_id); + if (existingSession) + sessionPresent = true; + } sender->setAuthenticated(true); ConnAck connAck(protocolVersion, ReasonCodes::Success, sessionPresent); diff --git a/session.cpp b/session.cpp index 9828e08..842e69b 100644 --- a/session.cpp +++ b/session.cpp @@ -22,8 +22,6 @@ License along with FlashMQ. If not, see . #include "threadglobals.h" #include "threadglobals.h" -std::chrono::time_point appStartTime = std::chrono::steady_clock::now(); - Session::Session() { const Settings &settings = *ThreadGlobals::getSettings(); @@ -33,44 +31,6 @@ Session::Session() this->sessionExpiryInterval = settings.expireSessionsAfterSeconds; } -int64_t Session::getProgramStartedAtUnixTimestamp() -{ - auto secondsSinceEpoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); - const std::chrono::seconds age = std::chrono::duration_cast(std::chrono::steady_clock::now() - appStartTime); - int64_t result = secondsSinceEpoch - age.count(); - return result; -} - -void Session::setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp) -{ - auto secondsSinceEpoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()); - const std::chrono::seconds _unix_timestamp = std::chrono::seconds(unix_timestamp); - const std::chrono::seconds age_in_s = secondsSinceEpoch - _unix_timestamp; - appStartTime = std::chrono::steady_clock::now() - age_in_s; -} - -/** - * @brief Session::getSessionRelativeAgeInMs is used to get the value to store on disk when saving sessions. - * @return - */ -int64_t Session::getSessionRelativeAgeInMs() const -{ - const std::chrono::milliseconds sessionAge = std::chrono::duration_cast(lastTouched - appStartTime); - const int64_t sInMs = sessionAge.count(); - return sInMs; -} - -/** - * @brief Session::setSessionTouch is the set 'lastTouched' value relative to the app start time when a session is loaded from disk. - * @param ageInMs - */ -void Session::setSessionTouch(int64_t ageInMs) -{ - std::chrono::milliseconds ms(ageInMs); - std::chrono::time_point point = appStartTime + ms; - lastTouched = point; -} - bool Session::requiresPacketRetransmission() const { const std::shared_ptr client = makeSharedClient(); @@ -110,7 +70,6 @@ Session::Session(const Session &other) this->incomingQoS2MessageIds = other.incomingQoS2MessageIds; this->outgoingQoS2MessageIds = other.outgoingQoS2MessageIds; this->nextPacketId = other.nextPacketId; - this->lastTouched = other.lastTouched; // TODO: see git history for a change here. We now copy the whole queued publish. Do we want to address that? this->qosPacketQueue = other.qosPacketQueue; @@ -284,28 +243,9 @@ uint64_t Session::sendPendingQosMessages() return count; } -/** - * @brief Session::touch with a time value allowed touching without causing another sys/lib call to get the time. - * @param newval - */ -void Session::touch(std::chrono::time_point newval) +bool Session::hasActiveClient() const { - lastTouched = newval; -} - -void Session::touch() -{ - lastTouched = std::chrono::steady_clock::now(); -} - -bool Session::hasExpired() const -{ - if (!client.expired()) - return false; - - std::chrono::seconds expireAfter(sessionExpiryInterval); - std::chrono::time_point now = std::chrono::steady_clock::now(); - return (lastTouched + expireAfter) < now; + return !client.expired(); } void Session::clearWill() diff --git a/session.h b/session.h index f44487b..64b0fa6 100644 --- a/session.h +++ b/session.h @@ -50,12 +50,9 @@ class Session uint16_t maxQosMsgPending; uint16_t QoSLogPrintedAtId = 0; bool destroyOnDisconnect = false; - std::chrono::time_point lastTouched = std::chrono::steady_clock::now(); std::shared_ptr willPublish; Logger *logger = Logger::getInstance(); - int64_t getSessionRelativeAgeInMs() const; - void setSessionTouch(int64_t ageInMs); bool requiresPacketRetransmission() const; void increasePacketId(); @@ -66,9 +63,6 @@ public: Session(Session &&other) = delete; ~Session(); - static int64_t getProgramStartedAtUnixTimestamp(); - static void setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp); - std::unique_ptr getCopy() const; const std::string &getClientId() const { return client_id; } @@ -77,9 +71,7 @@ public: void writePacket(PublishCopyFactory ©Factory, const char max_qos, uint64_t &count); void clearQosMessage(uint16_t packet_id); uint64_t sendPendingQosMessages(); - void touch(std::chrono::time_point val); - void touch(); - bool hasExpired() const; + bool hasActiveClient() const; void clearWill(); std::shared_ptr &getWill(); diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp index 8c205f1..fcefe86 100644 --- a/sessionsandsubscriptionsdb.cpp +++ b/sessionsandsubscriptionsdb.cpp @@ -67,12 +67,14 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() { bool eofFound = false; - const int64_t programStartStamp = readInt64(eofFound); + const int64_t fileSavedAt = readInt64(eofFound); if (eofFound) continue; - logger->logf(LOG_DEBUG, "Setting first app start time to timestamp %ld", programStartStamp); - Session::setProgramStartedAtUnixTimestamp(programStartStamp); + const int64_t now_epoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + const int64_t persistence_state_age = fileSavedAt > now_epoch ? 0 : now_epoch - fileSavedAt; + + logger->logf(LOG_DEBUG, "Session file was saved at %ld. That's %ld seconds ago.", fileSavedAt, persistence_state_age); const uint32_t nrOfSessions = readUint32(eofFound); @@ -146,19 +148,16 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() logger->logf(LOG_DEBUG, "Loaded next packetid %d.", ses->nextPacketId); ses->nextPacketId = nextPacketId; - int64_t sessionAge = readInt64(eofFound); - logger->logf(LOG_DEBUG, "Loaded session age: %ld ms.", sessionAge); - ses->setSessionTouch(sessionAge); + const uint32_t originalSessionExpiryInterval = readUint32(eofFound); + const uint32_t compensatedSessionExpiry = persistence_state_age > originalSessionExpiryInterval ? 0 : originalSessionExpiryInterval - persistence_state_age; + const uint32_t sessionExpiryInterval = std::min(compensatedSessionExpiry, settings->getExpireSessionAfterSeconds()); - const uint32_t sessionExpiryInterval = std::min(readUint32(eofFound), settings->getExpireSessionAfterSeconds()); const uint16_t maxQosPending = std::min(readUint16(eofFound), settings->maxQosMsgPendingPerClient); - // TODO: perhaps I should calculate a new sessionExpiryInterval, minus the time it was off? - - // Setting the sessionExpiryInterval back to what it was is somewhat naive, in that when you have the - // server off for a week, you basically suspended time and will delay all session destructions. But, - // I'm chosing that option versus kicking out all sessions if the server was off for a longer period. - ses->setSessionProperties(maxQosPending, sessionExpiryInterval, 0, ProtocolVersion::Mqtt5); // The protocol version is just dummy, to get the behavior I want. + // We will set the session expiry interval as it would have had time continued. If a connection picks up session, it will update + // it with a more relevant value. + // The protocol version 5 is just dummy, to get the behavior I want. + ses->setSessionProperties(maxQosPending, sessionExpiryInterval, 0, ProtocolVersion::Mqtt5); } const uint32_t nrOfSubscriptions = readUint32(eofFound); @@ -206,9 +205,9 @@ void SessionsAndSubscriptionsDB::saveData(const std::vectorlogf(LOG_DEBUG, "Saving program first start time stamp as %ld", start_stamp); - writeInt64(start_stamp); + const int64_t now_epoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + logger->logf(LOG_DEBUG, "Saving current time stamp %ld", now_epoch); + writeInt64(now_epoch); writeUint32(sessions.size()); @@ -269,10 +268,6 @@ void SessionsAndSubscriptionsDB::saveData(const std::vectorlogf(LOG_DEBUG, "Writing next packetid %d.", ses->nextPacketId); writeUint16(ses->nextPacketId); - const int64_t sInMs = ses->getSessionRelativeAgeInMs(); - logger->logf(LOG_DEBUG, "Writing session age: %ld ms.", sInMs); - writeInt64(sInMs); - writeUint32(ses->sessionExpiryInterval); writeUint16(ses->maxQosMsgPending); } diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index caaa3d7..737382f 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -253,20 +253,23 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr client->getThreadData()->incrementSentMessageCount(count); } -bool SubscriptionStore::sessionPresent(const std::string &clientid) +/** + * @brief SubscriptionStore::lockSession returns the session if it exists. Returning is done keep the shared pointer active, to + * avoid race conditions with session removal. + * @param clientid + * @return + */ +std::shared_ptr SubscriptionStore::lockSession(const std::string &clientid) { RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.rdlock(); - bool result = false; - auto it = sessionsByIdConst.find(clientid); if (it != sessionsByIdConst.end()) { - it->second->touch(); // Touching to avoid a race condition between using the session after this, and it expiring. - result = true; + return it->second; } - return result; + return std::shared_ptr(); } void SubscriptionStore::sendQueuedWillMessages() @@ -612,7 +615,7 @@ void SubscriptionStore::removeExpiredSessionsClients() } // A session could have been picked up again, so we have to verify its expiration status. - if (session->hasExpired()) + if (!session->hasActiveClient()) { removeSession(session); } @@ -632,6 +635,10 @@ void SubscriptionStore::removeExpiredSessionsClients() } } +/** + * @brief SubscriptionStore::queueSessionRemoval places session efficiently in a sorted list that is periodically dequeued. + * @param session + */ void SubscriptionStore::queueSessionRemoval(const std::shared_ptr &session) { if (!session) diff --git a/subscriptionstore.h b/subscriptionstore.h index d626b91..e23513e 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -84,6 +84,12 @@ class RetainedMessageNode RetainedMessageNode *getChildren(const std::string &subtopic) const; }; +/** + * @brief A QueuedSessionRemoval is a sort of delayed request for removal. They are kept in a sorted list for fast insertion, + * and fast dequeueing of expired entries from the start. + * + * You can have multiple of these in the pending list. If a client has picked up the session again, the removal is not executed. + */ class QueuedSessionRemoval { std::weak_ptr session; @@ -142,7 +148,7 @@ public: void removeSubscription(std::shared_ptr &client, const std::string &topic); void registerClientAndKickExistingOne(std::shared_ptr &client); void registerClientAndKickExistingOne(std::shared_ptr &client, bool clean_start, uint16_t maxQosPackets, uint32_t sessionExpiryInterval); - bool sessionPresent(const std::string &clientid); + std::shared_ptr lockSession(const std::string &clientid); void sendQueuedWillMessages(); void queueWillMessage(std::shared_ptr &willMessage, bool forceNow = false);