diff --git a/CMakeLists.txt b/CMakeLists.txt index fe8cc75..d5d3ab2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ add_executable(FlashMQ authplugin.cpp configfileparser.cpp sslctxmanager.cpp + timer.cpp ) target_link_libraries(FlashMQ pthread dl ssl crypto) diff --git a/client.cpp b/client.cpp index c910bc7..40c6d85 100644 --- a/client.cpp +++ b/client.cpp @@ -212,6 +212,8 @@ bool Client::readFdIntoBuffer() } lastActivity = time(NULL); + if (session) + session->touch(lastActivity); return true; } diff --git a/mainapp.cpp b/mainapp.cpp index e733685..4a9fecd 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -153,6 +153,9 @@ MainApp::MainApp(const std::string &configFilePath) : confFileParser.reset(new ConfigFileParser(configFilePath)); loadConfig(); + + auto f = std::bind(&MainApp::queueCleanup, this); + timer.addCallback(f, 86400000, "session expiration"); } MainApp::~MainApp() @@ -237,6 +240,12 @@ int MainApp::createListenSocket(int portNr, bool ssl) return listen_fd; } +void MainApp::wakeUpThread() +{ + uint64_t one = 1; + write(taskEventFd, &one, sizeof(uint64_t)); +} + void MainApp::initMainApp(int argc, char *argv[]) { if (instance != nullptr) @@ -318,6 +327,8 @@ MainApp *MainApp::getMainApp() void MainApp::start() { + timer.start(); + int listen_fd_plain = createListenSocket(this->listenPort, false); int listen_fd_ssl = createListenSocket(this->sslListenPort, true); @@ -422,6 +433,7 @@ void MainApp::quit() { Logger *logger = Logger::getInstance(); logger->logf(LOG_NOTICE, "Quitting FlashMQ"); + timer.stop(); running = false; } @@ -473,6 +485,15 @@ void MainApp::queueConfigReload() auto f = std::bind(&MainApp::reloadConfig, this); taskQueue.push_front(f); - uint64_t one = 1; - write(taskEventFd, &one, sizeof(uint64_t)); + wakeUpThread(); +} + +void MainApp::queueCleanup() +{ + std::lock_guard locker(eventMutex); + + auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get()); + taskQueue.push_front(f); + + wakeUpThread(); } diff --git a/mainapp.h b/mainapp.h index 6a6573e..6b1077a 100644 --- a/mainapp.h +++ b/mainapp.h @@ -19,6 +19,7 @@ #include "mqttpacket.h" #include "subscriptionstore.h" #include "configfileparser.h" +#include "timer.h" class MainApp { @@ -33,6 +34,7 @@ class MainApp int epollFdAccept = -1; int taskEventFd = -1; std::mutex eventMutex; + Timer timer; uint listenPort = 0; uint sslListenPort = 0; @@ -46,6 +48,7 @@ class MainApp static void showLicense(); void setCertAndKeyFromConfig(); int createListenSocket(int portNr, bool ssl); + void wakeUpThread(); MainApp(const std::string &configFilePath); public: @@ -61,6 +64,7 @@ public: void queueConfigReload(); + void queueCleanup(); }; #endif // MAINAPP_H diff --git a/session.cpp b/session.cpp index b9a02b8..ed0eb19 100644 --- a/session.cpp +++ b/session.cpp @@ -8,6 +8,11 @@ Session::Session() } +Session::~Session() +{ + logger->logf(LOG_DEBUG, "Session %s is being destroyed.", getClientId().c_str()); +} + bool Session::clientDisconnected() const { return client.expired(); @@ -110,3 +115,14 @@ void Session::sendPendingQosMessages() } } } + +void Session::touch(time_t val) +{ + time_t newval = val > 0 ? val : time(NULL); + lastTouched = newval; +} + +bool Session::hasExpired() +{ + return clientDisconnected() && (lastTouched + EXPIRE_SESSION_AFTER) < time(NULL); +} diff --git a/session.h b/session.h index 1aa4b76..8150b86 100644 --- a/session.h +++ b/session.h @@ -12,6 +12,9 @@ #define MAX_QOS_MSG_PENDING_PER_CLIENT 32 #define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096 +// TODO make setting +#define EXPIRE_SESSION_AFTER 1209600 + struct QueuedQosPacket { uint16_t id; @@ -26,11 +29,14 @@ class Session std::mutex qosQueueMutex; uint16_t nextPacketId = 0; ssize_t qosQueueBytes = 0; + time_t lastTouched = time(NULL); Logger *logger = Logger::getInstance(); + public: Session(); Session(const Session &other) = delete; Session(Session &&other) = delete; + ~Session(); const std::string &getClientId() const { return client_id; } bool clientDisconnected() const; @@ -39,6 +45,8 @@ public: void writePacket(const MqttPacket &packet, char max_qos); void clearQosMessage(uint16_t packet_id); void sendPendingQosMessages(); + void touch(time_t val = 0); + bool hasExpired(); }; #endif // SESSION_H diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 1829dcd..f067df1 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -30,6 +30,7 @@ void SubscriptionNode::addSubscriber(const std::shared_ptr &subscriber, } } + SubscriptionStore::SubscriptionStore() : root(new SubscriptionNode("root")), sessionsByIdConst(sessionsById) @@ -220,6 +221,89 @@ void SubscriptionStore::setRetainedMessage(const std::string &topic, const std:: retainedMessages.insert(std::move(rm)); } +// Clean up the weak pointers to sessions and remove nodes that are empty. +int SubscriptionNode::cleanSubscriptions() +{ + int subscribersLeftInChildren = 0; + auto childrenIt = children.begin(); + while(childrenIt != children.end()) + { + subscribersLeftInChildren += childrenIt->second->cleanSubscriptions(); + + if (subscribersLeftInChildren > 0) + childrenIt++; + else + { + Logger::getInstance()->logf(LOG_DEBUG, "Removing orphaned subscriber node from %s", childrenIt->first.c_str()); + childrenIt = children.erase(childrenIt); + } + } + + std::list*> wildcardChildren; + wildcardChildren.push_back(&childrenPlus); + wildcardChildren.push_back(&childrenPound); + + for (std::unique_ptr *node : wildcardChildren) + { + std::unique_ptr &node_ = *node; + + if (!node_) + continue; + int n = node_->cleanSubscriptions(); + subscribersLeftInChildren += n; + + if (n == 0) + { + Logger::getInstance()->logf(LOG_DEBUG, "Resetting wildcard children"); + node_.reset(); + } + } + + // This is not particularlly fast when it's many items. But we don't do it often, so is probably okay. + auto it = subscribers.begin(); + while (it != subscribers.end()) + { + if (it->sessionGone()) + { + Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers vector"); + it = subscribers.erase(it); + } + else + it++; + } + + return subscribers.size() + subscribersLeftInChildren; +} + +// This is not MQTT compliant, but the standard doesn't keep real world constraints into account. +void SubscriptionStore::removeExpiredSessionsClients() +{ + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.wrlock(); + + logger->logf(LOG_NOTICE, "Cleaning out old sessions"); + + auto session_it = sessionsById.begin(); + while (session_it != sessionsById.end()) + { + std::shared_ptr &session = session_it->second; + + if (session->hasExpired()) + { +#ifndef NDEBUG + logger->logf(LOG_DEBUG, "Removing expired session from store %s", session->getClientId().c_str()); +#endif + session_it = sessionsById.erase(session_it); + } + else + session_it++; + } + + logger->logf(LOG_NOTICE, "Rebuilding subscription tree"); + + root->cleanSubscriptions(); +} + // QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The // specs don't specify what to do there. bool Subscription::operator==(const Subscription &rhs) const @@ -240,3 +324,8 @@ void Subscription::reset() session.reset(); qos = 0; } + +bool Subscription::sessionGone() const +{ + return session.expired(); +} diff --git a/subscriptionstore.h b/subscriptionstore.h index 4d62022..1732574 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -23,10 +23,11 @@ struct RetainedPayload struct Subscription { - std::weak_ptr session; // Weak pointer expires when session has been cleaned by 'clean session' connect. + std::weak_ptr session; // Weak pointer expires when session has been cleaned by 'clean session' connect or when it was remove because it expired char qos; bool operator==(const Subscription &rhs) const; void reset(); + bool sessionGone() const; }; class SubscriptionNode @@ -45,6 +46,7 @@ public: std::unique_ptr childrenPlus; std::unique_ptr childrenPound; + int cleanSubscriptions(); }; class SubscriptionStore @@ -62,6 +64,7 @@ class SubscriptionStore void publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers) const; void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, std::unique_ptr &next, const MqttPacket &packet) const; + public: SubscriptionStore(); @@ -72,6 +75,8 @@ public: void giveClientRetainedMessages(const std::shared_ptr &ses, const std::string &subscribe_topic, char max_qos); void setRetainedMessage(const std::string &topic, const std::string &payload, char qos); + + void removeExpiredSessionsClients(); }; #endif // SUBSCRIPTIONSTORE_H diff --git a/timer.cpp b/timer.cpp new file mode 100644 index 0000000..18514a3 --- /dev/null +++ b/timer.cpp @@ -0,0 +1,136 @@ +#include "timer.h" +#include "sys/eventfd.h" +#include "sys/epoll.h" +#include "unistd.h" + +#include "utils.h" + +void CallbackEntry::updateExectedAt() +{ + this->lastExecuted = currentMSecsSinceEpoch(); +} + +uint64_t CallbackEntry::getNextCallMs() const +{ + int64_t elapsedSinceLastCall = currentMSecsSinceEpoch() - lastExecuted; + if (elapsedSinceLastCall < 0) // Correct for clock drift + elapsedSinceLastCall = 0; + + int64_t newDelay = this->interval - elapsedSinceLastCall; + if (newDelay < 0) + newDelay = 0; + return newDelay; +} + +bool CallbackEntry::operator <(const CallbackEntry &other) const +{ + return this->getNextCallMs() < other.getNextCallMs(); +} + +Timer::Timer() +{ + fd = eventfd(0, EFD_NONBLOCK); + epollfd = check(epoll_create(999)); + + struct epoll_event ev; + memset(&ev, 0, sizeof (struct epoll_event)); + ev.data.fd = fd; + ev.events = EPOLLIN; + check(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev)); +} + +Timer::~Timer() +{ + close(fd); + close(epollfd); +} + +void Timer::start() +{ + running = true; + + auto f = std::bind(&Timer::process, this); + this->t = std::thread(f, this); + + pthread_t native = this->t.native_handle(); + pthread_setname_np(native, "Timer"); +} + +void Timer::stop() +{ + running = false; + uint64_t one = 1; + write(fd, &one, sizeof(uint64_t)); + t.join(); +} + +void Timer::addCallback(std::function f, uint64_t interval_ms, const std::string &name) +{ + logger->logf(LOG_DEBUG, "Adding event '%s' to the timer.", name.c_str()); + + CallbackEntry c; + c.f = f; + c.interval = interval_ms; + c.name = name; + callbacks.push_back(std::move(c)); + sortAndSetSleeptimeTillNext(); + wakeUpPoll(); +} + +void Timer::sortAndSetSleeptimeTillNext() +{ + std::sort(callbacks.begin(), callbacks.end()); + this->sleeptime = callbacks.front().getNextCallMs(); +} + +void Timer::process() +{ + struct epoll_event events[MAX_TIMER_EVENTS]; + memset(&events, 0, sizeof (struct epoll_event)*MAX_TIMER_EVENTS); + + while (running) + { + logger->logf(LOG_DEBUG, "Timer sleeping for %d ms until event '%s' or callbacks are added.", sleeptime, callbacks.front().name.c_str()); + int num_fds = epoll_wait(this->epollfd, events, MAX_TIMER_EVENTS, sleeptime); + + if (!running) + continue; + + if (num_fds < 0) + { + if (errno == EINTR) + continue; + logger->logf(LOG_ERR, "Waiting for timer fd error: %s", strerror(errno)); + } + + // If it was the eventfd, an action woke up the loop, and not a pending event. + for (int i = 0; i < num_fds; i++) + { + int cur_fd = events[i].data.fd; + + if (cur_fd == this->fd) + { + uint64_t eventfd_value = 0; + check(read(fd, &eventfd_value, sizeof(uint64_t))); + } + + continue; + } + + CallbackEntry &c = callbacks.front(); + c.updateExectedAt(); + c.f(); + + sortAndSetSleeptimeTillNext(); + } +} + +void Timer::wakeUpPoll() +{ + if (!running) + return; + + uint64_t one = 1; + write(fd, &one, sizeof(uint64_t)); +} + diff --git a/timer.h b/timer.h new file mode 100644 index 0000000..3efaaf4 --- /dev/null +++ b/timer.h @@ -0,0 +1,47 @@ +#ifndef TIMER_H +#define TIMER_H + +#include +#include +#include + +#include "logger.h" +#include "utils.h" + +#define MAX_TIMER_EVENTS 32 + +struct CallbackEntry +{ + uint64_t lastExecuted = currentMSecsSinceEpoch(); // assume the first one executed to avoid instantly calling it. + uint64_t interval = 0; + std::function f = nullptr; + std::string name; + + void updateExectedAt(); + uint64_t getNextCallMs() const; + bool operator <(const CallbackEntry &other) const; +}; + +// Simple timer that calls your callback. The callback is executed on the timer thread. +class Timer +{ + std::thread t; + int epollfd = 0; + int fd = 0; + uint64_t sleeptime = 1000; + int running = false; + Logger *logger = Logger::getInstance(); + std::vector callbacks; + + void sortAndSetSleeptimeTillNext(); + void process(); + void wakeUpPoll(); +public: + Timer(); + ~Timer(); + void start(); + void stop(); + void addCallback(std::function f, uint64_t interval_ms, const std::string &name); +}; + +#endif // TIMER_H diff --git a/utils.cpp b/utils.cpp index 15be2a6..5fe5a4c 100644 --- a/utils.cpp +++ b/utils.cpp @@ -1,5 +1,7 @@ #include "utils.h" +#include "sys/time.h" + #include std::list split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) @@ -168,3 +170,11 @@ bool startsWith(const std::string &s, const std::string &needle) { return s.find(needle) == 0; } + +int64_t currentMSecsSinceEpoch() +{ + struct timeval te; + gettimeofday(&te, NULL); + int64_t milliseconds = te.tv_sec*1000LL + te.tv_usec/1000; + return milliseconds; +} diff --git a/utils.h b/utils.h index 57409df..0267672 100644 --- a/utils.h +++ b/utils.h @@ -37,5 +37,6 @@ void rtrim(std::string &s); void trim(std::string &s); bool startsWith(const std::string &s, const std::string &needle); +int64_t currentMSecsSinceEpoch(); #endif // UTILS_H