diff --git a/globalsettings.cpp b/globalsettings.cpp index e559a67..5e10a0f 100644 --- a/globalsettings.cpp +++ b/globalsettings.cpp @@ -1,16 +1,3 @@ #include "globalsettings.h" -GlobalSettings *GlobalSettings::instance = nullptr; -GlobalSettings::GlobalSettings() -{ - -} - -GlobalSettings *GlobalSettings::getInstance() -{ - if (instance == nullptr) - instance = new GlobalSettings(); - - return instance; -} diff --git a/globalsettings.h b/globalsettings.h index 14fed52..11a2437 100644 --- a/globalsettings.h +++ b/globalsettings.h @@ -1,14 +1,9 @@ #ifndef GLOBALSETTINGS_H #define GLOBALSETTINGS_H -// 'Global' as in, needed outside of the mainapp, like listen ports. -class GlobalSettings +// Defaults are defined in ConfigFileParser +struct GlobalSettings { - static GlobalSettings *instance; - GlobalSettings(); -public: - static GlobalSettings *getInstance(); - bool allow_unsafe_clientid_chars = false; }; #endif // GLOBALSETTINGS_H diff --git a/mainapp.cpp b/mainapp.cpp index 6bb8bf8..81f4415 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -25,7 +25,6 @@ void do_thread_work(ThreadData *threadData) memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); std::vector packetQueueIn; - time_t lastKeepAliveCheck = time(NULL); Logger *logger = Logger::getInstance(); @@ -59,6 +58,21 @@ void do_thread_work(ThreadData *threadData) struct epoll_event cur_ev = events[i]; int fd = cur_ev.data.fd; + if (fd == threadData->taskEventFd) + { + uint64_t eventfd_value = 0; + check(read(fd, &eventfd_value, sizeof(uint64_t))); + + std::lock_guard locker(threadData->taskQueueMutex); + for(auto &f : threadData->taskQueue) + { + f(); + } + threadData->taskQueue.clear(); + + continue; + } + Client_p client = threadData->getClient(fd); if (client) @@ -127,19 +141,6 @@ void do_thread_work(ThreadData *threadData) } } packetQueueIn.clear(); - - try - { - if (lastKeepAliveCheck + 30 < time(NULL)) - { - if (threadData->doKeepAliveCheck()) - lastKeepAliveCheck = time(NULL); - } - } - catch (std::exception &ex) - { - logger->logf(LOG_ERR, "Error handling keep-alives: %s.", ex.what()); - } } } @@ -164,6 +165,9 @@ MainApp::MainApp(const std::string &configFilePath) : auto f = std::bind(&MainApp::queueCleanup, this); timer.addCallback(f, 86400000, "session expiration"); + + auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this); + timer.addCallback(fKeepAlive, 30000, "keep-alive check"); } MainApp::~MainApp() @@ -254,6 +258,14 @@ void MainApp::wakeUpThread() check(write(taskEventFd, &one, sizeof(uint64_t))); } +void MainApp::queueKeepAliveCheckAtAllThreads() +{ + for (std::shared_ptr &thread : threads) + { + thread->queueDoKeepAliveCheck(); + } +} + void MainApp::initMainApp(int argc, char *argv[]) { if (instance != nullptr) @@ -352,7 +364,7 @@ void MainApp::start() for (int i = 0; i < num_threads; i++) { - std::shared_ptr t(new ThreadData(i, subscriptionStore, *confFileParser.get())); + std::shared_ptr t(new ThreadData(i, subscriptionStore, *confFileParser.get(), settings)); t->start(&do_thread_work); threads.push_back(t); } @@ -413,6 +425,7 @@ void MainApp::start() uint64_t eventfd_value = 0; check(read(cur_fd, &eventfd_value, sizeof(uint64_t))); + std::lock_guard locker(eventMutex); for(auto &f : taskQueue) { f(); @@ -453,7 +466,6 @@ void MainApp::quit() void MainApp::loadConfig() { Logger *logger = Logger::getInstance(); - GlobalSettings *setting = GlobalSettings::getInstance(); // Atomic loading, first test. confFileParser->loadFile(true); @@ -472,9 +484,14 @@ void MainApp::loadConfig() SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option } - setting->allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars; + settings.allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars; setCertAndKeyFromConfig(); + + for (std::shared_ptr &thread : threads) + { + thread->queueReload(settings); + } } void MainApp::reloadConfig() diff --git a/mainapp.h b/mainapp.h index 97b6746..c2f1263 100644 --- a/mainapp.h +++ b/mainapp.h @@ -37,6 +37,7 @@ class MainApp int taskEventFd = -1; std::mutex eventMutex; Timer timer; + GlobalSettings settings; uint listenPort = 0; uint sslListenPort = 0; @@ -51,6 +52,7 @@ class MainApp void setCertAndKeyFromConfig(); int createListenSocket(int portNr, bool ssl); void wakeUpThread(); + void queueKeepAliveCheckAtAllThreads(); MainApp(const std::string &configFilePath); public: diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 92ffbeb..bf1e5f3 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -152,12 +152,12 @@ void MqttPacket::handleConnect() if (sender->hasConnectPacketSeen()) throw ProtocolError("Client already sent a CONNECT."); - GlobalSettings *settings = GlobalSettings::getInstance(); - std::shared_ptr subscriptionStore = sender->getThreadData()->getSubscriptionStore(); uint16_t variable_header_length = readTwoBytesToUInt16(); + const GlobalSettings &settings = sender->getThreadData()->settingsLocalCopy; + if (variable_header_length == 4 || variable_header_length == 6) { char *c = readBytes(variable_header_length); @@ -243,7 +243,7 @@ void MqttPacket::handleConnect() bool validClientId = true; // Check for wildcard chars in case the client_id ever appears in topics. - if (!settings->allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#"))) + if (!settings.allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#"))) { logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str()); validClientId = false; diff --git a/threaddata.cpp b/threaddata.cpp index d4d7c9c..e2cd748 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -2,15 +2,26 @@ #include #include -ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser) : +ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings) : subscriptionStore(subscriptionStore), confFileParser(confFileParser), authPlugin(confFileParser), - threadnr(threadnr) + threadnr(threadnr), + settingsLocalCopy(settings) { logger = Logger::getInstance(); epollfd = check(epoll_create(999)); + + taskEventFd = eventfd(0, EFD_NONBLOCK); + if (taskEventFd < 0) + throw std::runtime_error("Can't create eventfd."); + + struct epoll_event ev; + memset(&ev, 0, sizeof (struct epoll_event)); + ev.data.fd = taskEventFd; + ev.events = EPOLLIN; + check(epoll_ctl(this->epollfd, EPOLL_CTL_ADD, taskEventFd, &ev)); } void ThreadData::start(thread_f f) @@ -89,27 +100,45 @@ std::shared_ptr &ThreadData::getSubscriptionStore() return subscriptionStore; } +void ThreadData::queueDoKeepAliveCheck() +{ + std::lock_guard locker(taskQueueMutex); + + auto f = std::bind(&ThreadData::doKeepAliveCheck, this); + taskQueue.push_front(f); + + wakeUpThread(); +} + // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? -bool ThreadData::doKeepAliveCheck() +void ThreadData::doKeepAliveCheck() { + // We don't need to stall normal connects and disconnects for keep-alive checking. We can do it later. std::unique_lock lock(clients_by_fd_mutex, std::try_to_lock); if (!lock.owns_lock()) - return false; + return; - auto it = clients_by_fd.begin(); - while (it != clients_by_fd.end()) + logger->logf(LOG_DEBUG, "Doing keep-alive check in thread %d", threadnr); + + try { - Client_p &client = it->second; - if (client && client->keepAliveExpired()) + auto it = clients_by_fd.begin(); + while (it != clients_by_fd.end()) { - client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString()); - it = clients_by_fd.erase(it); + Client_p &client = it->second; + if (client && client->keepAliveExpired()) + { + client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString()); + it = clients_by_fd.erase(it); + } + else + it++; } - else - it++; } - - return true; + catch (std::exception &ex) + { + logger->logf(LOG_ERR, "Error handling keep-alives: %s.", ex.what()); + } } void ThreadData::initAuthPlugin() @@ -119,10 +148,14 @@ void ThreadData::initAuthPlugin() authPlugin.securityInit(false); } -void ThreadData::reload() +void ThreadData::reload(GlobalSettings settings) { + logger->logf(LOG_DEBUG, "Doing reload in thread %d", threadnr); + try { + settingsLocalCopy = settings; + authPlugin.securityCleanup(true); authPlugin.securityInit(true); } @@ -130,7 +163,29 @@ void ThreadData::reload() { logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what()); } + catch (std::exception &ex) + { + logger->logf(LOG_ERR, "Error reloading: %s.", ex.what()); + } +} + +void ThreadData::queueReload(GlobalSettings settings) +{ + std::lock_guard locker(taskQueueMutex); + + auto f = std::bind(&ThreadData::reload, this, settings); + taskQueue.push_front(f); + + wakeUpThread(); } +void ThreadData::wakeUpThread() +{ + uint64_t one = 1; + check(write(taskEventFd, &one, sizeof(uint64_t))); +} + + + diff --git a/threaddata.h b/threaddata.h index 386932d..1cd2664 100644 --- a/threaddata.h +++ b/threaddata.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "forward_declarations.h" @@ -19,6 +20,7 @@ #include "configfileparser.h" #include "authplugin.h" #include "logger.h" +#include "globalsettings.h" typedef void (*thread_f)(ThreadData *); @@ -30,14 +32,22 @@ class ThreadData ConfigFileParser &confFileParser; Logger *logger; + void reload(GlobalSettings settings); + void wakeUpThread(); + void doKeepAliveCheck(); + public: AuthPlugin authPlugin; bool running = true; std::thread thread; int threadnr = 0; int epollfd = 0; + int taskEventFd = 0; + std::mutex taskQueueMutex; + std::forward_list> taskQueue; + GlobalSettings settingsLocalCopy; // Is updated on reload, within the thread loop. - ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser); + ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings); ThreadData(const ThreadData &other) = delete; ThreadData(ThreadData &&other) = delete; @@ -49,9 +59,10 @@ public: void removeClient(int fd); std::shared_ptr &getSubscriptionStore(); - bool doKeepAliveCheck(); void initAuthPlugin(); - void reload(); + void queueReload(GlobalSettings settings); + void queueDoKeepAliveCheck(); + }; #endif // THREADDATA_H