Commit 862f093347dafa9e69bd096b38bb67247199ea9e

Authored by Wiebe Cazemier
1 parent 3cb1fae8

Sessions and 'clean session' works

client.cpp
... ... @@ -327,12 +327,13 @@ bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_
327 327 return true;
328 328 }
329 329  
330   -void Client::setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive)
  330 +void Client::setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession)
331 331 {
332 332 this->clientid = clientId;
333 333 this->username = username;
334 334 this->connectPacketSeen = connectPacketSeen;
335 335 this->keepalive = keepalive;
  336 + this->cleanSession = cleanSession;
336 337 }
337 338  
338 339 void Client::setWill(const std::string &topic, const std::string &payload, bool retain, char qos)
... ...
client.h
... ... @@ -41,6 +41,7 @@ class Client
41 41 std::string clientid;
42 42 std::string username;
43 43 uint16_t keepalive = 0;
  44 + bool cleanSession = false;
44 45  
45 46 std::string will_topic;
46 47 std::string will_payload;
... ... @@ -64,13 +65,14 @@ public:
64 65 void markAsDisconnecting();
65 66 bool readFdIntoBuffer();
66 67 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender);
67   - void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive);
  68 + void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
68 69 void setWill(const std::string &topic, const std::string &payload, bool retain, char qos);
69 70 void setAuthenticated(bool value) { authenticated = value;}
70 71 bool getAuthenticated() { return authenticated; }
71 72 bool hasConnectPacketSeen() { return connectPacketSeen; }
72 73 ThreadData_p getThreadData() { return threadData; }
73 74 std::string &getClientId() { return this->clientid; }
  75 + bool getCleanSession() { return cleanSession; }
74 76  
75 77 void writePingResp();
76 78 void writeMqttPacket(const MqttPacket &packet);
... ...
mqttpacket.cpp
... ... @@ -216,7 +216,7 @@ void MqttPacket::handleConnect()
216 216 return;
217 217 }
218 218  
219   - sender->setClientProperties(client_id, username, true, keep_alive);
  219 + sender->setClientProperties(client_id, username, true, keep_alive, clean_session);
220 220 sender->setWill(will_topic, will_payload, will_retain, will_qos);
221 221  
222 222 if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success)
... ...
session.cpp
... ... @@ -5,11 +5,6 @@ Session::Session()
5 5  
6 6 }
7 7  
8   -Session::Session(std::shared_ptr<Client> &client)
9   -{
10   - this->client = client;
11   -}
12   -
13 8 bool Session::clientDisconnected() const
14 9 {
15 10 return client.expired();
... ... @@ -17,5 +12,10 @@ bool Session::clientDisconnected() const
17 12  
18 13 std::shared_ptr<Client> Session::makeSharedClient() const
19 14 {
20   - return client.lock();
  15 + return client.lock();
  16 +}
  17 +
  18 +void Session::assignActiveConnection(std::shared_ptr<Client> &client)
  19 +{
  20 + this->client = client;
21 21 }
... ...
session.h
... ... @@ -11,10 +11,12 @@ class Session
11 11 // TODO: qos message queue, as some kind of movable pointer.
12 12 public:
13 13 Session();
14   - Session(std::shared_ptr<Client> &client);
  14 + Session(const Session &other) = delete;
  15 + Session(Session &&other) = delete;
15 16  
16 17 bool clientDisconnected() const;
17 18 std::shared_ptr<Client> makeSharedClient() const;
  19 + void assignActiveConnection(std::shared_ptr<Client> &client);
18 20 };
19 21  
20 22 #endif // SESSION_H
... ...
subscriptionstore.cpp
... ... @@ -48,7 +48,12 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top
48 48  
49 49 if (deepestNode)
50 50 {
51   - deepestNode->subscribers.push_front(client->getClientId());
  51 + auto session_it = sessionsByIdConst.find(client->getClientId());
  52 + if (session_it != sessionsByIdConst.end())
  53 + {
  54 + std::weak_ptr<Session> b = session_it->second;
  55 + deepestNode->subscribers.push_front(b);
  56 + }
52 57 }
53 58  
54 59 lock_guard.unlock();
... ... @@ -65,35 +70,43 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &amp;client)
65 70 if (client->getClientId().empty())
66 71 throw ProtocolError("Trying to store client without an ID.");
67 72  
  73 + std::shared_ptr<Session> session;
68 74 auto session_it = sessionsById.find(client->getClientId());
69 75 if (session_it != sessionsById.end())
70 76 {
71   - Session &session = session_it->second;
  77 + session = session_it->second;
72 78  
73   - if (!session.clientDisconnected())
  79 + if (session && !session->clientDisconnected())
74 80 {
75   - std::shared_ptr<Client> cl = session.makeSharedClient();
  81 + std::shared_ptr<Client> cl = session->makeSharedClient();
76 82 logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str());
77 83 cl->setReadyForDisconnect();
78 84 cl->getThreadData()->removeClient(cl);
79 85 cl->markAsDisconnecting();
80 86 }
81 87 }
82   - sessionsById[client->getClientId()] = client;
  88 +
  89 + if (!session || client->getCleanSession())
  90 + {
  91 + session.reset(new Session());
  92 + sessionsById[client->getClientId()] = session;
  93 + }
  94 +
  95 + session->assignActiveConnection(client);
83 96 }
84 97  
85 98 // TODO: should I implement cache, this needs to be changed to returning a list of clients.
86   -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::string> &subscribers) const
  99 +void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::weak_ptr<Session>> &subscribers) const
87 100 {
88   - for (const std::string &client_id : subscribers)
  101 + for (const std::weak_ptr<Session> session_weak : subscribers)
89 102 {
90   - auto session_it = sessionsByIdConst.find(client_id);
91   - if (session_it != sessionsByIdConst.end())
  103 + if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect.
92 104 {
93   - const Session &session = session_it->second;
94   - if (!session.clientDisconnected())
  105 + const std::shared_ptr<Session> session = session_weak.lock();
  106 +
  107 + if (!session->clientDisconnected())
95 108 {
96   - Client_p c = session.makeSharedClient();
  109 + Client_p c = session->makeSharedClient();
97 110 c->writeMqttPacketAndBlameThisClient(packet);
98 111 }
99 112 }
... ...
subscriptionstore.h
... ... @@ -30,7 +30,7 @@ public:
30 30 SubscriptionNode(const SubscriptionNode &node) = delete;
31 31 SubscriptionNode(SubscriptionNode &&node) = delete;
32 32  
33   - std::forward_list<std::string> subscribers; // The idea is to store subscriptions by client id, to support persistent sessions.
  33 + std::forward_list<std::weak_ptr<Session>> subscribers; // The idea is to store subscriptions by client id, to support persistent sessions.
34 34 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children;
35 35 std::unique_ptr<SubscriptionNode> childrenPlus;
36 36 std::unique_ptr<SubscriptionNode> childrenPound;
... ... @@ -40,15 +40,15 @@ class SubscriptionStore
40 40 {
41 41 std::unique_ptr<SubscriptionNode> root;
42 42 pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER;
43   - std::unordered_map<std::string, Session> sessionsById;
44   - const std::unordered_map<std::string, Session> &sessionsByIdConst;
  43 + std::unordered_map<std::string, std::shared_ptr<Session>> sessionsById;
  44 + const std::unordered_map<std::string, std::shared_ptr<Session>> &sessionsByIdConst;
45 45  
46 46 pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER;
47 47 std::unordered_set<RetainedMessage> retainedMessages;
48 48  
49 49 Logger *logger = Logger::getInstance();
50 50  
51   - void publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::string> &subscribers) const;
  51 + void publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::weak_ptr<Session>> &subscribers) const;
52 52 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
53 53 std::unique_ptr<SubscriptionNode> &next, const MqttPacket &packet) const;
54 54 public:
... ...