From 862f093347dafa9e69bd096b38bb67247199ea9e Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Tue, 5 Jan 2021 21:03:54 +0100 Subject: [PATCH] Sessions and 'clean session' works --- client.cpp | 3 ++- client.h | 4 +++- mqttpacket.cpp | 2 +- session.cpp | 12 ++++++------ session.h | 4 +++- subscriptionstore.cpp | 37 +++++++++++++++++++++++++------------ subscriptionstore.h | 8 ++++---- 7 files changed, 44 insertions(+), 26 deletions(-) diff --git a/client.cpp b/client.cpp index b743efe..73baba5 100644 --- a/client.cpp +++ b/client.cpp @@ -327,12 +327,13 @@ bool Client::bufferToMqttPackets(std::vector &packetQueueIn, Client_ return true; } -void Client::setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive) +void Client::setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession) { this->clientid = clientId; this->username = username; this->connectPacketSeen = connectPacketSeen; this->keepalive = keepalive; + this->cleanSession = cleanSession; } void Client::setWill(const std::string &topic, const std::string &payload, bool retain, char qos) diff --git a/client.h b/client.h index 3199336..b7b49ca 100644 --- a/client.h +++ b/client.h @@ -41,6 +41,7 @@ class Client std::string clientid; std::string username; uint16_t keepalive = 0; + bool cleanSession = false; std::string will_topic; std::string will_payload; @@ -64,13 +65,14 @@ public: void markAsDisconnecting(); bool readFdIntoBuffer(); bool bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender); - void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive); + void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); void setWill(const std::string &topic, const std::string &payload, bool retain, char qos); void setAuthenticated(bool value) { authenticated = value;} bool getAuthenticated() { return authenticated; } bool hasConnectPacketSeen() { return connectPacketSeen; } ThreadData_p getThreadData() { return threadData; } std::string &getClientId() { return this->clientid; } + bool getCleanSession() { return cleanSession; } void writePingResp(); void writeMqttPacket(const MqttPacket &packet); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index e4e0283..12b900e 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -216,7 +216,7 @@ void MqttPacket::handleConnect() return; } - sender->setClientProperties(client_id, username, true, keep_alive); + sender->setClientProperties(client_id, username, true, keep_alive, clean_session); sender->setWill(will_topic, will_payload, will_retain, will_qos); if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) diff --git a/session.cpp b/session.cpp index a228919..964806a 100644 --- a/session.cpp +++ b/session.cpp @@ -5,11 +5,6 @@ Session::Session() } -Session::Session(std::shared_ptr &client) -{ - this->client = client; -} - bool Session::clientDisconnected() const { return client.expired(); @@ -17,5 +12,10 @@ bool Session::clientDisconnected() const std::shared_ptr Session::makeSharedClient() const { - return client.lock(); + return client.lock(); +} + +void Session::assignActiveConnection(std::shared_ptr &client) +{ + this->client = client; } diff --git a/session.h b/session.h index ef6cb60..02b21d0 100644 --- a/session.h +++ b/session.h @@ -11,10 +11,12 @@ class Session // TODO: qos message queue, as some kind of movable pointer. public: Session(); - Session(std::shared_ptr &client); + Session(const Session &other) = delete; + Session(Session &&other) = delete; bool clientDisconnected() const; std::shared_ptr makeSharedClient() const; + void assignActiveConnection(std::shared_ptr &client); }; #endif // SESSION_H diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index b6b4d9e..bc52ba3 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -48,7 +48,12 @@ void SubscriptionStore::addSubscription(Client_p &client, const std::string &top if (deepestNode) { - deepestNode->subscribers.push_front(client->getClientId()); + auto session_it = sessionsByIdConst.find(client->getClientId()); + if (session_it != sessionsByIdConst.end()) + { + std::weak_ptr b = session_it->second; + deepestNode->subscribers.push_front(b); + } } lock_guard.unlock(); @@ -65,35 +70,43 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) if (client->getClientId().empty()) throw ProtocolError("Trying to store client without an ID."); + std::shared_ptr session; auto session_it = sessionsById.find(client->getClientId()); if (session_it != sessionsById.end()) { - Session &session = session_it->second; + session = session_it->second; - if (!session.clientDisconnected()) + if (session && !session->clientDisconnected()) { - std::shared_ptr cl = session.makeSharedClient(); + std::shared_ptr cl = session->makeSharedClient(); logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); cl->setReadyForDisconnect(); cl->getThreadData()->removeClient(cl); cl->markAsDisconnecting(); } } - sessionsById[client->getClientId()] = client; + + if (!session || client->getCleanSession()) + { + session.reset(new Session()); + sessionsById[client->getClientId()] = session; + } + + session->assignActiveConnection(client); } // TODO: should I implement cache, this needs to be changed to returning a list of clients. -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::forward_list &subscribers) const +void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::forward_list> &subscribers) const { - for (const std::string &client_id : subscribers) + for (const std::weak_ptr session_weak : subscribers) { - auto session_it = sessionsByIdConst.find(client_id); - if (session_it != sessionsByIdConst.end()) + if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. { - const Session &session = session_it->second; - if (!session.clientDisconnected()) + const std::shared_ptr session = session_weak.lock(); + + if (!session->clientDisconnected()) { - Client_p c = session.makeSharedClient(); + Client_p c = session->makeSharedClient(); c->writeMqttPacketAndBlameThisClient(packet); } } diff --git a/subscriptionstore.h b/subscriptionstore.h index ffd94cc..dbb02d1 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -30,7 +30,7 @@ public: SubscriptionNode(const SubscriptionNode &node) = delete; SubscriptionNode(SubscriptionNode &&node) = delete; - std::forward_list subscribers; // The idea is to store subscriptions by client id, to support persistent sessions. + std::forward_list> subscribers; // The idea is to store subscriptions by client id, to support persistent sessions. std::unordered_map> children; std::unique_ptr childrenPlus; std::unique_ptr childrenPound; @@ -40,15 +40,15 @@ class SubscriptionStore { std::unique_ptr root; pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER; - std::unordered_map sessionsById; - const std::unordered_map &sessionsByIdConst; + std::unordered_map> sessionsById; + const std::unordered_map> &sessionsByIdConst; pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER; std::unordered_set retainedMessages; Logger *logger = Logger::getInstance(); - void publishNonRecursively(const MqttPacket &packet, const std::forward_list &subscribers) const; + void publishNonRecursively(const MqttPacket &packet, const std::forward_list> &subscribers) const; void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, std::unique_ptr &next, const MqttPacket &packet) const; public: -- libgit2 0.21.4