From 9a34fc462fc342e97029fd886bf02111ffcb9c08 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sun, 6 Mar 2022 21:48:03 +0100 Subject: [PATCH] Change clean session into MQTT5 symantics --- FlashMQTests/tst_maintests.cpp | 12 ++++++------ client.cpp | 19 ++++++------------- client.h | 8 ++------ mainapp.cpp | 4 ++-- mqttpacket.cpp | 12 ++++++------ session.cpp | 31 +++++++++++++++++++++---------- session.h | 6 ++++-- subscriptionstore.cpp | 16 +++++----------- subscriptionstore.h | 2 +- 9 files changed, 53 insertions(+), 57 deletions(-) diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 11155eb..46f6338 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -1001,15 +1001,15 @@ void MainTests::testSavingSessions() ThreadGlobals::assign(&auth); std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings, false)); - c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); - store->registerClientAndKickExistingOne(c1); + c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60); + store->registerClientAndKickExistingOne(c1, false, 512, 120); c1->getSession()->touch(); c1->getSession()->addIncomingQoS2MessageId(2); c1->getSession()->addIncomingQoS2MessageId(3); std::shared_ptr c2(new Client(0, t, nullptr, false, nullptr, settings, false)); - c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60, false); - store->registerClientAndKickExistingOne(c2); + c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60); + store->registerClientAndKickExistingOne(c2, false, 512, 120); c2->getSession()->touch(); c2->getSession()->addOutgoingQoS2MessageId(55); c2->getSession()->addOutgoingQoS2MessageId(66); @@ -1120,8 +1120,8 @@ void testCopyPacketHelper(const std::string &topic, char from_qos, char to_qos, ThreadGlobals::assign(&auth); std::shared_ptr dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false)); - dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60, false); - store->registerClientAndKickExistingOne(dummyClient); + dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60); + store->registerClientAndKickExistingOne(dummyClient, false, 512, 120); uint16_t packetid = 66; for (int len = 0; len < 150; len++ ) diff --git a/client.cpp b/client.cpp index 152644b..b17d8d2 100644 --- a/client.cpp +++ b/client.cpp @@ -75,7 +75,7 @@ Client::~Client() } // MQTT-3.1.2-6 - if (cleanSession) + if (session->getDestroyOnDisconnect()) { store->removeSession(clientid); } @@ -286,9 +286,9 @@ bool Client::writeBufIntoFd() std::string Client::repr() { - std::string s = formatString("[ClientID='%s', username='%s', fd=%d, keepalive=%ds, transport='%s', address='%s', cleanses=%d, prot=%s]", + std::string s = formatString("[ClientID='%s', username='%s', fd=%d, keepalive=%ds, transport='%s', address='%s', prot=%s]", clientid.c_str(), username.c_str(), fd, keepalive, this->transportStr.c_str(), this->address.c_str(), - cleanSession, protocolVersionString(protocolVersion).c_str()); + protocolVersionString(protocolVersion).c_str()); return s; } @@ -334,11 +334,6 @@ void Client::resetBuffersIfEligible() writebuf.resetSizeIfEligable(initialBufferSize); } -void Client::setCleanSession(bool val) -{ - this->cleanSession = val; -} - #ifndef NDEBUG /** * @brief IoWrapper::setFakeUpgraded(). @@ -421,24 +416,22 @@ void Client::bufferToMqttPackets(std::vector &packetQueueIn, std::sh setReadyForReading(readbuf.freeSpace() > 0); } -void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession) +void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive) { const Settings *settings = ThreadGlobals::getSettings(); - setClientProperties(protocolVersion, clientId, username, connectPacketSeen, keepalive, cleanSession, - settings->maxPacketSize, 0); + setClientProperties(protocolVersion, clientId, username, connectPacketSeen, keepalive, settings->maxPacketSize, 0); } void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, - bool cleanSession, uint32_t maxPacketSize, uint16_t maxTopicAliases) + uint32_t maxPacketSize, uint16_t maxTopicAliases) { this->protocolVersion = protocolVersion; this->clientid = clientId; this->username = username; this->connectPacketSeen = connectPacketSeen; this->keepalive = keepalive; - this->cleanSession = cleanSession; this->maxPacketSize = maxPacketSize; this->maxTopicAliases = maxTopicAliases; } diff --git a/client.h b/client.h index c0da255..16577ce 100644 --- a/client.h +++ b/client.h @@ -74,7 +74,6 @@ class Client std::string clientid; std::string username; uint16_t keepalive = 0; - bool cleanSession = false; std::string will_topic; std::string will_payload; @@ -108,9 +107,9 @@ public: void markAsDisconnecting(); bool readFdIntoBuffer(); void bufferToMqttPackets(std::vector &packetQueueIn, std::shared_ptr &sender); - void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); + void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive); void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, - bool cleanSession, uint32_t maxPacketSize, uint16_t maxTopicAliases); + uint32_t maxPacketSize, uint16_t maxTopicAliases); void setWill(const std::string &topic, const std::string &payload, bool retain, char qos); void clearWill(); void setAuthenticated(bool value) { authenticated = value;} @@ -119,7 +118,6 @@ public: std::shared_ptr getThreadData() { return threadData; } std::string &getClientId() { return this->clientid; } const std::string &getUsername() const { return this->username; } - bool getCleanSession() { return cleanSession; } void assignSession(std::shared_ptr &session); std::shared_ptr getSession(); void setDisconnectReason(const std::string &reason); @@ -140,8 +138,6 @@ public: std::string getKeepAliveInfoString() const; void resetBuffersIfEligible(); - void setCleanSession(bool val); - #ifndef NDEBUG void setFakeUpgraded(); #endif diff --git a/mainapp.cpp b/mainapp.cpp index a13e7da..c3d44cd 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -442,11 +442,11 @@ void MainApp::start() std::shared_ptr client = std::make_shared(fd, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true); std::shared_ptr subscriber = std::make_shared(fdnull, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true); - subscriber->setClientProperties(ProtocolVersion::Mqtt311, "subscriber", "subuser", true, 60, true); + subscriber->setClientProperties(ProtocolVersion::Mqtt311, "subscriber", "subuser", true, 60); subscriber->setAuthenticated(true); std::shared_ptr websocketsubscriber = std::make_shared(fdnull2, threaddata, nullptr, true, nullptr, settings, true); - websocketsubscriber->setClientProperties(ProtocolVersion::Mqtt311, "websocketsubscriber", "websocksubuser", true, 60, true); + websocketsubscriber->setClientProperties(ProtocolVersion::Mqtt311, "websocketsubscriber", "websocksubuser", true, 60); websocketsubscriber->setAuthenticated(true); websocketsubscriber->setFakeUpgraded(); subscriptionStore->registerClientAndKickExistingOne(websocketsubscriber); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 7fd7890..abf8939 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -363,7 +363,7 @@ void MqttPacket::handleConnect() bool will_retain = !!(flagByte & 0b00100000); char will_qos = (flagByte & 0b00011000) >> 3; bool will_flag = !!(flagByte & 0b00000100); - bool clean_session = !!(flagByte & 0b00000010); + bool clean_start = !!(flagByte & 0b00000010); if (will_qos > 2) throw ProtocolError("Invalid QoS for will."); @@ -467,9 +467,9 @@ void MqttPacket::handleConnect() logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false.", client_id.c_str()); validClientId = false; } - else if (!clean_session && client_id.empty()) + else if (!clean_start && client_id.empty()) { - logger->logf(LOG_ERR, "ClientID empty and clean session 0, which is incompatible"); + logger->logf(LOG_ERR, "ClientID empty and clean start 0, which is incompatible"); validClientId = false; } else if (protocolVersion < ProtocolVersion::Mqtt311 && client_id.empty()) @@ -493,7 +493,7 @@ void MqttPacket::handleConnect() client_id = getSecureRandomString(23); } - sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session, max_packet_size, max_topic_aliases); + sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, max_packet_size, max_topic_aliases); sender->setWill(will_topic, will_payload, will_retain, will_qos); bool accessGranted = false; @@ -518,7 +518,7 @@ void MqttPacket::handleConnect() if (accessGranted) { - bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_session && subscriptionStore->sessionPresent(client_id); + bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_start && subscriptionStore->sessionPresent(client_id); sender->setAuthenticated(true); ConnAck connAck(ConnAckReturnCodes::Accepted, sessionPresent); @@ -526,7 +526,7 @@ void MqttPacket::handleConnect() sender->writeMqttPacket(response); logger->logf(LOG_NOTICE, "Client '%s' logged in successfully", sender->repr().c_str()); - subscriptionStore->registerClientAndKickExistingOne(sender, max_qos_packets, session_expire); + subscriptionStore->registerClientAndKickExistingOne(sender, clean_start, max_qos_packets, session_expire); } else { diff --git a/session.cpp b/session.cpp index 705a2e6..52e9694 100644 --- a/session.cpp +++ b/session.cpp @@ -76,8 +76,7 @@ bool Session::requiresPacketRetransmission() const if (client->getProtocolVersion() < ProtocolVersion::Mqtt311) return true; - // TODO: for MQTT5, the rules are different. - return !client->getCleanSession(); + return !destroyOnDisconnect; } void Session::increasePacketId() @@ -343,18 +342,30 @@ void Session::removeOutgoingQoS2MessageId(u_int16_t packet_id) outgoingQoS2MessageIds.erase(it); } -bool Session::getCleanSession() const +/** + * @brief Session::getDestroyOnDisconnect + * @return + * + * MQTT5: Setting Clean Start to 1 and a Session Expiry Interval of 0, is equivalent to setting CleanSession to 1 in the MQTT Specification Version 3.1.1. + */ +bool Session::getDestroyOnDisconnect() const { - auto c = client.lock(); - - if (!c) - return false; - - return c->getCleanSession(); + return destroyOnDisconnect; } -void Session::setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval) +void Session::setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval, bool clean_start, ProtocolVersion protocol_version) { this->maxQosMsgPending = maxQosPackets; this->sessionExpiryInterval = sessionExpiryInterval; + + if (protocol_version <= ProtocolVersion::Mqtt311 && clean_start) + destroyOnDisconnect = true; + else + destroyOnDisconnect = sessionExpiryInterval == 0; } + +uint32_t Session::getSessionExpiryInterval() const +{ + return this->sessionExpiryInterval; +} + diff --git a/session.h b/session.h index a525914..264116e 100644 --- a/session.h +++ b/session.h @@ -49,6 +49,7 @@ class Session uint32_t sessionExpiryInterval = 0; uint16_t maxQosMsgPending; uint16_t QoSLogPrintedAtId = 0; + bool destroyOnDisconnect = false; std::chrono::time_point lastTouched = std::chrono::steady_clock::now(); Logger *logger = Logger::getInstance(); @@ -86,9 +87,10 @@ public: void addOutgoingQoS2MessageId(uint16_t packet_id); void removeOutgoingQoS2MessageId(u_int16_t packet_id); - bool getCleanSession() const; + bool getDestroyOnDisconnect() const; - void setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval); + void setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval, bool clean_start, ProtocolVersion protocol_version); + uint32_t getSessionExpiryInterval() const; }; #endif // SESSION_H diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index a0d0cce..60ccc72 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -204,11 +204,11 @@ void SubscriptionStore::removeSubscription(std::shared_ptr &client, cons void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client) { const Settings *settings = ThreadGlobals::getSettings(); - registerClientAndKickExistingOne(client, settings->maxQosMsgPendingPerClient, settings->expireSessionsAfterSeconds); + registerClientAndKickExistingOne(client, true, settings->maxQosMsgPendingPerClient, settings->expireSessionsAfterSeconds); } // Removes an existing client when it already exists [MQTT-3.1.4-2]. -void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client, uint16_t maxQosPackets, uint32_t sessionExpiryInterval) +void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client, bool clean_start, uint16_t maxQosPackets, uint32_t sessionExpiryInterval) { RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.wrlock(); @@ -216,7 +216,6 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr if (client->getClientId().empty()) throw ProtocolError("Trying to store client without an ID."); - bool originalClientDemandsSessionDestruction = false; std::shared_ptr session; auto session_it = sessionsById.find(client->getClientId()); if (session_it != sessionsById.end()) @@ -232,11 +231,6 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); cl->setDisconnectReason("Another client with this ID connected"); - // We have to set session to false, because it's no longer up to the destruction of that client - // to destroy the session. We either do it in this function, or not at all. - originalClientDemandsSessionDestruction = cl->getCleanSession(); - cl->setCleanSession(false); - cl->setReadyForDisconnect(); cl->getThreadData()->removeClientQueued(cl); cl->markAsDisconnecting(); @@ -245,7 +239,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr } } - if (!session || client->getCleanSession() || originalClientDemandsSessionDestruction) + if (!session || session->getDestroyOnDisconnect()) { session = std::make_shared(); @@ -254,7 +248,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr session->assignActiveConnection(client); client->assignSession(session); - session->setSessionProperties(maxQosPackets, sessionExpiryInterval); + session->setSessionProperties(maxQosPackets, sessionExpiryInterval, clean_start, client->getProtocolVersion()); uint64_t count = session->sendPendingQosMessages(); client->getThreadData()->incrementSentMessageCount(count); } @@ -744,7 +738,7 @@ void SubscriptionStore::saveSessionsAndSubscriptions(const std::string &filePath const Session &org = *pair.second.get(); // Sessions created with clean session need to be destroyed when disconnecting, so no point in saving them. - if (org.getCleanSession()) + if (org.getDestroyOnDisconnect()) continue; sessionCopies.push_back(org.getCopy()); diff --git a/subscriptionstore.h b/subscriptionstore.h index e7e549c..f47500b 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -119,7 +119,7 @@ public: void addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos); void removeSubscription(std::shared_ptr &client, const std::string &topic); void registerClientAndKickExistingOne(std::shared_ptr &client); - void registerClientAndKickExistingOne(std::shared_ptr &client, uint16_t maxQosPackets, uint32_t sessionExpiryInterval); + void registerClientAndKickExistingOne(std::shared_ptr &client, bool clean_start, uint16_t maxQosPackets, uint32_t sessionExpiryInterval); bool sessionPresent(const std::string &clientid); void queuePacketAtSubscribers(const std::vector &subtopics, MqttPacket &packet, bool dollar = false); -- libgit2 0.21.4