Commit 33ef5bdf3c10df486f997b497984ee460663454e

Authored by Wiebe Cazemier
1 parent 9e33ebda

Working on expiring sessions

This includes a timer mechanism.
CMakeLists.txt
... ... @@ -27,6 +27,7 @@ add_executable(FlashMQ
27 27 authplugin.cpp
28 28 configfileparser.cpp
29 29 sslctxmanager.cpp
  30 + timer.cpp
30 31 )
31 32  
32 33 target_link_libraries(FlashMQ pthread dl ssl crypto)
... ...
client.cpp
... ... @@ -212,6 +212,8 @@ bool Client::readFdIntoBuffer()
212 212 }
213 213  
214 214 lastActivity = time(NULL);
  215 + if (session)
  216 + session->touch(lastActivity);
215 217  
216 218 return true;
217 219 }
... ...
mainapp.cpp
... ... @@ -153,6 +153,9 @@ MainApp::MainApp(const std::string &configFilePath) :
153 153  
154 154 confFileParser.reset(new ConfigFileParser(configFilePath));
155 155 loadConfig();
  156 +
  157 + auto f = std::bind(&MainApp::queueCleanup, this);
  158 + timer.addCallback(f, 86400000, "session expiration");
156 159 }
157 160  
158 161 MainApp::~MainApp()
... ... @@ -237,6 +240,12 @@ int MainApp::createListenSocket(int portNr, bool ssl)
237 240 return listen_fd;
238 241 }
239 242  
  243 +void MainApp::wakeUpThread()
  244 +{
  245 + uint64_t one = 1;
  246 + write(taskEventFd, &one, sizeof(uint64_t));
  247 +}
  248 +
240 249 void MainApp::initMainApp(int argc, char *argv[])
241 250 {
242 251 if (instance != nullptr)
... ... @@ -318,6 +327,8 @@ MainApp *MainApp::getMainApp()
318 327  
319 328 void MainApp::start()
320 329 {
  330 + timer.start();
  331 +
321 332 int listen_fd_plain = createListenSocket(this->listenPort, false);
322 333 int listen_fd_ssl = createListenSocket(this->sslListenPort, true);
323 334  
... ... @@ -422,6 +433,7 @@ void MainApp::quit()
422 433 {
423 434 Logger *logger = Logger::getInstance();
424 435 logger->logf(LOG_NOTICE, "Quitting FlashMQ");
  436 + timer.stop();
425 437 running = false;
426 438 }
427 439  
... ... @@ -473,6 +485,15 @@ void MainApp::queueConfigReload()
473 485 auto f = std::bind(&MainApp::reloadConfig, this);
474 486 taskQueue.push_front(f);
475 487  
476   - uint64_t one = 1;
477   - write(taskEventFd, &one, sizeof(uint64_t));
  488 + wakeUpThread();
  489 +}
  490 +
  491 +void MainApp::queueCleanup()
  492 +{
  493 + std::lock_guard<std::mutex> locker(eventMutex);
  494 +
  495 + auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get());
  496 + taskQueue.push_front(f);
  497 +
  498 + wakeUpThread();
478 499 }
... ...
mainapp.h
... ... @@ -19,6 +19,7 @@
19 19 #include "mqttpacket.h"
20 20 #include "subscriptionstore.h"
21 21 #include "configfileparser.h"
  22 +#include "timer.h"
22 23  
23 24 class MainApp
24 25 {
... ... @@ -33,6 +34,7 @@ class MainApp
33 34 int epollFdAccept = -1;
34 35 int taskEventFd = -1;
35 36 std::mutex eventMutex;
  37 + Timer timer;
36 38  
37 39 uint listenPort = 0;
38 40 uint sslListenPort = 0;
... ... @@ -46,6 +48,7 @@ class MainApp
46 48 static void showLicense();
47 49 void setCertAndKeyFromConfig();
48 50 int createListenSocket(int portNr, bool ssl);
  51 + void wakeUpThread();
49 52  
50 53 MainApp(const std::string &configFilePath);
51 54 public:
... ... @@ -61,6 +64,7 @@ public:
61 64  
62 65  
63 66 void queueConfigReload();
  67 + void queueCleanup();
64 68 };
65 69  
66 70 #endif // MAINAPP_H
... ...
session.cpp
... ... @@ -8,6 +8,11 @@ Session::Session()
8 8  
9 9 }
10 10  
  11 +Session::~Session()
  12 +{
  13 + logger->logf(LOG_DEBUG, "Session %s is being destroyed.", getClientId().c_str());
  14 +}
  15 +
11 16 bool Session::clientDisconnected() const
12 17 {
13 18 return client.expired();
... ... @@ -110,3 +115,14 @@ void Session::sendPendingQosMessages()
110 115 }
111 116 }
112 117 }
  118 +
  119 +void Session::touch(time_t val)
  120 +{
  121 + time_t newval = val > 0 ? val : time(NULL);
  122 + lastTouched = newval;
  123 +}
  124 +
  125 +bool Session::hasExpired()
  126 +{
  127 + return clientDisconnected() && (lastTouched + EXPIRE_SESSION_AFTER) < time(NULL);
  128 +}
... ...
session.h
... ... @@ -12,6 +12,9 @@
12 12 #define MAX_QOS_MSG_PENDING_PER_CLIENT 32
13 13 #define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096
14 14  
  15 +// TODO make setting
  16 +#define EXPIRE_SESSION_AFTER 1209600
  17 +
15 18 struct QueuedQosPacket
16 19 {
17 20 uint16_t id;
... ... @@ -26,11 +29,14 @@ class Session
26 29 std::mutex qosQueueMutex;
27 30 uint16_t nextPacketId = 0;
28 31 ssize_t qosQueueBytes = 0;
  32 + time_t lastTouched = time(NULL);
29 33 Logger *logger = Logger::getInstance();
  34 +
30 35 public:
31 36 Session();
32 37 Session(const Session &other) = delete;
33 38 Session(Session &&other) = delete;
  39 + ~Session();
34 40  
35 41 const std::string &getClientId() const { return client_id; }
36 42 bool clientDisconnected() const;
... ... @@ -39,6 +45,8 @@ public:
39 45 void writePacket(const MqttPacket &packet, char max_qos);
40 46 void clearQosMessage(uint16_t packet_id);
41 47 void sendPendingQosMessages();
  48 + void touch(time_t val = 0);
  49 + bool hasExpired();
42 50 };
43 51  
44 52 #endif // SESSION_H
... ...
subscriptionstore.cpp
... ... @@ -30,6 +30,7 @@ void SubscriptionNode::addSubscriber(const std::shared_ptr&lt;Session&gt; &amp;subscriber,
30 30 }
31 31 }
32 32  
  33 +
33 34 SubscriptionStore::SubscriptionStore() :
34 35 root(new SubscriptionNode("root")),
35 36 sessionsByIdConst(sessionsById)
... ... @@ -220,6 +221,89 @@ void SubscriptionStore::setRetainedMessage(const std::string &amp;topic, const std::
220 221 retainedMessages.insert(std::move(rm));
221 222 }
222 223  
  224 +// Clean up the weak pointers to sessions and remove nodes that are empty.
  225 +int SubscriptionNode::cleanSubscriptions()
  226 +{
  227 + int subscribersLeftInChildren = 0;
  228 + auto childrenIt = children.begin();
  229 + while(childrenIt != children.end())
  230 + {
  231 + subscribersLeftInChildren += childrenIt->second->cleanSubscriptions();
  232 +
  233 + if (subscribersLeftInChildren > 0)
  234 + childrenIt++;
  235 + else
  236 + {
  237 + Logger::getInstance()->logf(LOG_DEBUG, "Removing orphaned subscriber node from %s", childrenIt->first.c_str());
  238 + childrenIt = children.erase(childrenIt);
  239 + }
  240 + }
  241 +
  242 + std::list<std::unique_ptr<SubscriptionNode>*> wildcardChildren;
  243 + wildcardChildren.push_back(&childrenPlus);
  244 + wildcardChildren.push_back(&childrenPound);
  245 +
  246 + for (std::unique_ptr<SubscriptionNode> *node : wildcardChildren)
  247 + {
  248 + std::unique_ptr<SubscriptionNode> &node_ = *node;
  249 +
  250 + if (!node_)
  251 + continue;
  252 + int n = node_->cleanSubscriptions();
  253 + subscribersLeftInChildren += n;
  254 +
  255 + if (n == 0)
  256 + {
  257 + Logger::getInstance()->logf(LOG_DEBUG, "Resetting wildcard children");
  258 + node_.reset();
  259 + }
  260 + }
  261 +
  262 + // This is not particularlly fast when it's many items. But we don't do it often, so is probably okay.
  263 + auto it = subscribers.begin();
  264 + while (it != subscribers.end())
  265 + {
  266 + if (it->sessionGone())
  267 + {
  268 + Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector");
  269 + it = subscribers.erase(it);
  270 + }
  271 + else
  272 + it++;
  273 + }
  274 +
  275 + return subscribers.size() + subscribersLeftInChildren;
  276 +}
  277 +
  278 +// This is not MQTT compliant, but the standard doesn't keep real world constraints into account.
  279 +void SubscriptionStore::removeExpiredSessionsClients()
  280 +{
  281 + RWLockGuard lock_guard(&subscriptionsRwlock);
  282 + lock_guard.wrlock();
  283 +
  284 + logger->logf(LOG_NOTICE, "Cleaning out old sessions");
  285 +
  286 + auto session_it = sessionsById.begin();
  287 + while (session_it != sessionsById.end())
  288 + {
  289 + std::shared_ptr<Session> &session = session_it->second;
  290 +
  291 + if (session->hasExpired())
  292 + {
  293 +#ifndef NDEBUG
  294 + logger->logf(LOG_DEBUG, "Removing expired session from store %s", session->getClientId().c_str());
  295 +#endif
  296 + session_it = sessionsById.erase(session_it);
  297 + }
  298 + else
  299 + session_it++;
  300 + }
  301 +
  302 + logger->logf(LOG_NOTICE, "Rebuilding subscription tree");
  303 +
  304 + root->cleanSubscriptions();
  305 +}
  306 +
223 307 // QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The
224 308 // specs don't specify what to do there.
225 309 bool Subscription::operator==(const Subscription &rhs) const
... ... @@ -240,3 +324,8 @@ void Subscription::reset()
240 324 session.reset();
241 325 qos = 0;
242 326 }
  327 +
  328 +bool Subscription::sessionGone() const
  329 +{
  330 + return session.expired();
  331 +}
... ...
subscriptionstore.h
... ... @@ -23,10 +23,11 @@ struct RetainedPayload
23 23  
24 24 struct Subscription
25 25 {
26   - std::weak_ptr<Session> session; // Weak pointer expires when session has been cleaned by 'clean session' connect.
  26 + std::weak_ptr<Session> session; // Weak pointer expires when session has been cleaned by 'clean session' connect or when it was remove because it expired
27 27 char qos;
28 28 bool operator==(const Subscription &rhs) const;
29 29 void reset();
  30 + bool sessionGone() const;
30 31 };
31 32  
32 33 class SubscriptionNode
... ... @@ -45,6 +46,7 @@ public:
45 46 std::unique_ptr<SubscriptionNode> childrenPlus;
46 47 std::unique_ptr<SubscriptionNode> childrenPound;
47 48  
  49 + int cleanSubscriptions();
48 50 };
49 51  
50 52 class SubscriptionStore
... ... @@ -62,6 +64,7 @@ class SubscriptionStore
62 64 void publishNonRecursively(const MqttPacket &packet, const std::vector<Subscription> &subscribers) const;
63 65 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
64 66 std::unique_ptr<SubscriptionNode> &next, const MqttPacket &packet) const;
  67 +
65 68 public:
66 69 SubscriptionStore();
67 70  
... ... @@ -72,6 +75,8 @@ public:
72 75 void giveClientRetainedMessages(const std::shared_ptr<Session> &ses, const std::string &subscribe_topic, char max_qos);
73 76  
74 77 void setRetainedMessage(const std::string &topic, const std::string &payload, char qos);
  78 +
  79 + void removeExpiredSessionsClients();
75 80 };
76 81  
77 82 #endif // SUBSCRIPTIONSTORE_H
... ...
timer.cpp 0 → 100644
  1 +#include "timer.h"
  2 +#include "sys/eventfd.h"
  3 +#include "sys/epoll.h"
  4 +#include "unistd.h"
  5 +
  6 +#include "utils.h"
  7 +
  8 +void CallbackEntry::updateExectedAt()
  9 +{
  10 + this->lastExecuted = currentMSecsSinceEpoch();
  11 +}
  12 +
  13 +uint64_t CallbackEntry::getNextCallMs() const
  14 +{
  15 + int64_t elapsedSinceLastCall = currentMSecsSinceEpoch() - lastExecuted;
  16 + if (elapsedSinceLastCall < 0) // Correct for clock drift
  17 + elapsedSinceLastCall = 0;
  18 +
  19 + int64_t newDelay = this->interval - elapsedSinceLastCall;
  20 + if (newDelay < 0)
  21 + newDelay = 0;
  22 + return newDelay;
  23 +}
  24 +
  25 +bool CallbackEntry::operator <(const CallbackEntry &other) const
  26 +{
  27 + return this->getNextCallMs() < other.getNextCallMs();
  28 +}
  29 +
  30 +Timer::Timer()
  31 +{
  32 + fd = eventfd(0, EFD_NONBLOCK);
  33 + epollfd = check<std::runtime_error>(epoll_create(999));
  34 +
  35 + struct epoll_event ev;
  36 + memset(&ev, 0, sizeof (struct epoll_event));
  37 + ev.data.fd = fd;
  38 + ev.events = EPOLLIN;
  39 + check<std::runtime_error>(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev));
  40 +}
  41 +
  42 +Timer::~Timer()
  43 +{
  44 + close(fd);
  45 + close(epollfd);
  46 +}
  47 +
  48 +void Timer::start()
  49 +{
  50 + running = true;
  51 +
  52 + auto f = std::bind(&Timer::process, this);
  53 + this->t = std::thread(f, this);
  54 +
  55 + pthread_t native = this->t.native_handle();
  56 + pthread_setname_np(native, "Timer");
  57 +}
  58 +
  59 +void Timer::stop()
  60 +{
  61 + running = false;
  62 + uint64_t one = 1;
  63 + write(fd, &one, sizeof(uint64_t));
  64 + t.join();
  65 +}
  66 +
  67 +void Timer::addCallback(std::function<void ()> f, uint64_t interval_ms, const std::string &name)
  68 +{
  69 + logger->logf(LOG_DEBUG, "Adding event '%s' to the timer.", name.c_str());
  70 +
  71 + CallbackEntry c;
  72 + c.f = f;
  73 + c.interval = interval_ms;
  74 + c.name = name;
  75 + callbacks.push_back(std::move(c));
  76 + sortAndSetSleeptimeTillNext();
  77 + wakeUpPoll();
  78 +}
  79 +
  80 +void Timer::sortAndSetSleeptimeTillNext()
  81 +{
  82 + std::sort(callbacks.begin(), callbacks.end());
  83 + this->sleeptime = callbacks.front().getNextCallMs();
  84 +}
  85 +
  86 +void Timer::process()
  87 +{
  88 + struct epoll_event events[MAX_TIMER_EVENTS];
  89 + memset(&events, 0, sizeof (struct epoll_event)*MAX_TIMER_EVENTS);
  90 +
  91 + while (running)
  92 + {
  93 + logger->logf(LOG_DEBUG, "Timer sleeping for %d ms until event '%s' or callbacks are added.", sleeptime, callbacks.front().name.c_str());
  94 + int num_fds = epoll_wait(this->epollfd, events, MAX_TIMER_EVENTS, sleeptime);
  95 +
  96 + if (!running)
  97 + continue;
  98 +
  99 + if (num_fds < 0)
  100 + {
  101 + if (errno == EINTR)
  102 + continue;
  103 + logger->logf(LOG_ERR, "Waiting for timer fd error: %s", strerror(errno));
  104 + }
  105 +
  106 + // If it was the eventfd, an action woke up the loop, and not a pending event.
  107 + for (int i = 0; i < num_fds; i++)
  108 + {
  109 + int cur_fd = events[i].data.fd;
  110 +
  111 + if (cur_fd == this->fd)
  112 + {
  113 + uint64_t eventfd_value = 0;
  114 + check<std::runtime_error>(read(fd, &eventfd_value, sizeof(uint64_t)));
  115 + }
  116 +
  117 + continue;
  118 + }
  119 +
  120 + CallbackEntry &c = callbacks.front();
  121 + c.updateExectedAt();
  122 + c.f();
  123 +
  124 + sortAndSetSleeptimeTillNext();
  125 + }
  126 +}
  127 +
  128 +void Timer::wakeUpPoll()
  129 +{
  130 + if (!running)
  131 + return;
  132 +
  133 + uint64_t one = 1;
  134 + write(fd, &one, sizeof(uint64_t));
  135 +}
  136 +
... ...
timer.h 0 → 100644
  1 +#ifndef TIMER_H
  2 +#define TIMER_H
  3 +
  4 +#include <functional>
  5 +#include <thread>
  6 +#include <list>
  7 +
  8 +#include "logger.h"
  9 +#include "utils.h"
  10 +
  11 +#define MAX_TIMER_EVENTS 32
  12 +
  13 +struct CallbackEntry
  14 +{
  15 + uint64_t lastExecuted = currentMSecsSinceEpoch(); // assume the first one executed to avoid instantly calling it.
  16 + uint64_t interval = 0;
  17 + std::function<void ()> f = nullptr;
  18 + std::string name;
  19 +
  20 + void updateExectedAt();
  21 + uint64_t getNextCallMs() const;
  22 + bool operator <(const CallbackEntry &other) const;
  23 +};
  24 +
  25 +// Simple timer that calls your callback. The callback is executed on the timer thread.
  26 +class Timer
  27 +{
  28 + std::thread t;
  29 + int epollfd = 0;
  30 + int fd = 0;
  31 + uint64_t sleeptime = 1000;
  32 + int running = false;
  33 + Logger *logger = Logger::getInstance();
  34 + std::vector<CallbackEntry> callbacks;
  35 +
  36 + void sortAndSetSleeptimeTillNext();
  37 + void process();
  38 + void wakeUpPoll();
  39 +public:
  40 + Timer();
  41 + ~Timer();
  42 + void start();
  43 + void stop();
  44 + void addCallback(std::function<void()> f, uint64_t interval_ms, const std::string &name);
  45 +};
  46 +
  47 +#endif // TIMER_H
... ...
utils.cpp
1 1 #include "utils.h"
2 2  
  3 +#include "sys/time.h"
  4 +
3 5 #include <algorithm>
4 6  
5 7 std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts)
... ... @@ -168,3 +170,11 @@ bool startsWith(const std::string &amp;s, const std::string &amp;needle)
168 170 {
169 171 return s.find(needle) == 0;
170 172 }
  173 +
  174 +int64_t currentMSecsSinceEpoch()
  175 +{
  176 + struct timeval te;
  177 + gettimeofday(&te, NULL);
  178 + int64_t milliseconds = te.tv_sec*1000LL + te.tv_usec/1000;
  179 + return milliseconds;
  180 +}
... ...
... ... @@ -37,5 +37,6 @@ void rtrim(std::string &amp;s);
37 37 void trim(std::string &s);
38 38 bool startsWith(const std::string &s, const std::string &needle);
39 39  
  40 +int64_t currentMSecsSinceEpoch();
40 41  
41 42 #endif // UTILS_H
... ...