Commit f3a45a2a9df301da64045555118352dec7388dd8

Authored by Wiebe Cazemier
1 parent 99a087f4

Like main thread, also have a task queue in threads

It's also used to reload settings. Settings are copied to threads, to
avoid concurrency issues.
globalsettings.cpp
1 1 #include "globalsettings.h"
2 2  
3   -GlobalSettings *GlobalSettings::instance = nullptr;
4 3  
5   -GlobalSettings::GlobalSettings()
6   -{
7   -
8   -}
9   -
10   -GlobalSettings *GlobalSettings::getInstance()
11   -{
12   - if (instance == nullptr)
13   - instance = new GlobalSettings();
14   -
15   - return instance;
16   -}
... ...
globalsettings.h
1 1 #ifndef GLOBALSETTINGS_H
2 2 #define GLOBALSETTINGS_H
3 3  
4   -// 'Global' as in, needed outside of the mainapp, like listen ports.
5   -class GlobalSettings
  4 +// Defaults are defined in ConfigFileParser
  5 +struct GlobalSettings
6 6 {
7   - static GlobalSettings *instance;
8   - GlobalSettings();
9   -public:
10   - static GlobalSettings *getInstance();
11   -
12 7 bool allow_unsafe_clientid_chars = false;
13 8 };
14 9 #endif // GLOBALSETTINGS_H
... ...
mainapp.cpp
... ... @@ -25,7 +25,6 @@ void do_thread_work(ThreadData *threadData)
25 25 memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
26 26  
27 27 std::vector<MqttPacket> packetQueueIn;
28   - time_t lastKeepAliveCheck = time(NULL);
29 28  
30 29 Logger *logger = Logger::getInstance();
31 30  
... ... @@ -59,6 +58,21 @@ void do_thread_work(ThreadData *threadData)
59 58 struct epoll_event cur_ev = events[i];
60 59 int fd = cur_ev.data.fd;
61 60  
  61 + if (fd == threadData->taskEventFd)
  62 + {
  63 + uint64_t eventfd_value = 0;
  64 + check<std::runtime_error>(read(fd, &eventfd_value, sizeof(uint64_t)));
  65 +
  66 + std::lock_guard<std::mutex> locker(threadData->taskQueueMutex);
  67 + for(auto &f : threadData->taskQueue)
  68 + {
  69 + f();
  70 + }
  71 + threadData->taskQueue.clear();
  72 +
  73 + continue;
  74 + }
  75 +
62 76 Client_p client = threadData->getClient(fd);
63 77  
64 78 if (client)
... ... @@ -127,19 +141,6 @@ void do_thread_work(ThreadData *threadData)
127 141 }
128 142 }
129 143 packetQueueIn.clear();
130   -
131   - try
132   - {
133   - if (lastKeepAliveCheck + 30 < time(NULL))
134   - {
135   - if (threadData->doKeepAliveCheck())
136   - lastKeepAliveCheck = time(NULL);
137   - }
138   - }
139   - catch (std::exception &ex)
140   - {
141   - logger->logf(LOG_ERR, "Error handling keep-alives: %s.", ex.what());
142   - }
143 144 }
144 145 }
145 146  
... ... @@ -164,6 +165,9 @@ MainApp::MainApp(const std::string &amp;configFilePath) :
164 165  
165 166 auto f = std::bind(&MainApp::queueCleanup, this);
166 167 timer.addCallback(f, 86400000, "session expiration");
  168 +
  169 + auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this);
  170 + timer.addCallback(fKeepAlive, 30000, "keep-alive check");
167 171 }
168 172  
169 173 MainApp::~MainApp()
... ... @@ -254,6 +258,14 @@ void MainApp::wakeUpThread()
254 258 check<std::runtime_error>(write(taskEventFd, &one, sizeof(uint64_t)));
255 259 }
256 260  
  261 +void MainApp::queueKeepAliveCheckAtAllThreads()
  262 +{
  263 + for (std::shared_ptr<ThreadData> &thread : threads)
  264 + {
  265 + thread->queueDoKeepAliveCheck();
  266 + }
  267 +}
  268 +
257 269 void MainApp::initMainApp(int argc, char *argv[])
258 270 {
259 271 if (instance != nullptr)
... ... @@ -352,7 +364,7 @@ void MainApp::start()
352 364  
353 365 for (int i = 0; i < num_threads; i++)
354 366 {
355   - std::shared_ptr<ThreadData> t(new ThreadData(i, subscriptionStore, *confFileParser.get()));
  367 + std::shared_ptr<ThreadData> t(new ThreadData(i, subscriptionStore, *confFileParser.get(), settings));
356 368 t->start(&do_thread_work);
357 369 threads.push_back(t);
358 370 }
... ... @@ -413,6 +425,7 @@ void MainApp::start()
413 425 uint64_t eventfd_value = 0;
414 426 check<std::runtime_error>(read(cur_fd, &eventfd_value, sizeof(uint64_t)));
415 427  
  428 + std::lock_guard<std::mutex> locker(eventMutex);
416 429 for(auto &f : taskQueue)
417 430 {
418 431 f();
... ... @@ -453,7 +466,6 @@ void MainApp::quit()
453 466 void MainApp::loadConfig()
454 467 {
455 468 Logger *logger = Logger::getInstance();
456   - GlobalSettings *setting = GlobalSettings::getInstance();
457 469  
458 470 // Atomic loading, first test.
459 471 confFileParser->loadFile(true);
... ... @@ -472,9 +484,14 @@ void MainApp::loadConfig()
472 484 SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option
473 485 }
474 486  
475   - setting->allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars;
  487 + settings.allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars;
476 488  
477 489 setCertAndKeyFromConfig();
  490 +
  491 + for (std::shared_ptr<ThreadData> &thread : threads)
  492 + {
  493 + thread->queueReload(settings);
  494 + }
478 495 }
479 496  
480 497 void MainApp::reloadConfig()
... ...
mainapp.h
... ... @@ -37,6 +37,7 @@ class MainApp
37 37 int taskEventFd = -1;
38 38 std::mutex eventMutex;
39 39 Timer timer;
  40 + GlobalSettings settings;
40 41  
41 42 uint listenPort = 0;
42 43 uint sslListenPort = 0;
... ... @@ -51,6 +52,7 @@ class MainApp
51 52 void setCertAndKeyFromConfig();
52 53 int createListenSocket(int portNr, bool ssl);
53 54 void wakeUpThread();
  55 + void queueKeepAliveCheckAtAllThreads();
54 56  
55 57 MainApp(const std::string &configFilePath);
56 58 public:
... ...
mqttpacket.cpp
... ... @@ -152,12 +152,12 @@ void MqttPacket::handleConnect()
152 152 if (sender->hasConnectPacketSeen())
153 153 throw ProtocolError("Client already sent a CONNECT.");
154 154  
155   - GlobalSettings *settings = GlobalSettings::getInstance();
156   -
157 155 std::shared_ptr<SubscriptionStore> subscriptionStore = sender->getThreadData()->getSubscriptionStore();
158 156  
159 157 uint16_t variable_header_length = readTwoBytesToUInt16();
160 158  
  159 + const GlobalSettings &settings = sender->getThreadData()->settingsLocalCopy;
  160 +
161 161 if (variable_header_length == 4 || variable_header_length == 6)
162 162 {
163 163 char *c = readBytes(variable_header_length);
... ... @@ -243,7 +243,7 @@ void MqttPacket::handleConnect()
243 243 bool validClientId = true;
244 244  
245 245 // Check for wildcard chars in case the client_id ever appears in topics.
246   - if (!settings->allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#")))
  246 + if (!settings.allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#")))
247 247 {
248 248 logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str());
249 249 validClientId = false;
... ...
threaddata.cpp
... ... @@ -2,15 +2,26 @@
2 2 #include <string>
3 3 #include <sstream>
4 4  
5   -ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser) :
  5 +ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings) :
6 6 subscriptionStore(subscriptionStore),
7 7 confFileParser(confFileParser),
8 8 authPlugin(confFileParser),
9   - threadnr(threadnr)
  9 + threadnr(threadnr),
  10 + settingsLocalCopy(settings)
10 11 {
11 12 logger = Logger::getInstance();
12 13  
13 14 epollfd = check<std::runtime_error>(epoll_create(999));
  15 +
  16 + taskEventFd = eventfd(0, EFD_NONBLOCK);
  17 + if (taskEventFd < 0)
  18 + throw std::runtime_error("Can't create eventfd.");
  19 +
  20 + struct epoll_event ev;
  21 + memset(&ev, 0, sizeof (struct epoll_event));
  22 + ev.data.fd = taskEventFd;
  23 + ev.events = EPOLLIN;
  24 + check<std::runtime_error>(epoll_ctl(this->epollfd, EPOLL_CTL_ADD, taskEventFd, &ev));
14 25 }
15 26  
16 27 void ThreadData::start(thread_f f)
... ... @@ -89,27 +100,45 @@ std::shared_ptr&lt;SubscriptionStore&gt; &amp;ThreadData::getSubscriptionStore()
89 100 return subscriptionStore;
90 101 }
91 102  
  103 +void ThreadData::queueDoKeepAliveCheck()
  104 +{
  105 + std::lock_guard<std::mutex> locker(taskQueueMutex);
  106 +
  107 + auto f = std::bind(&ThreadData::doKeepAliveCheck, this);
  108 + taskQueue.push_front(f);
  109 +
  110 + wakeUpThread();
  111 +}
  112 +
92 113 // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial?
93   -bool ThreadData::doKeepAliveCheck()
  114 +void ThreadData::doKeepAliveCheck()
94 115 {
  116 + // We don't need to stall normal connects and disconnects for keep-alive checking. We can do it later.
95 117 std::unique_lock<std::mutex> lock(clients_by_fd_mutex, std::try_to_lock);
96 118 if (!lock.owns_lock())
97   - return false;
  119 + return;
98 120  
99   - auto it = clients_by_fd.begin();
100   - while (it != clients_by_fd.end())
  121 + logger->logf(LOG_DEBUG, "Doing keep-alive check in thread %d", threadnr);
  122 +
  123 + try
101 124 {
102   - Client_p &client = it->second;
103   - if (client && client->keepAliveExpired())
  125 + auto it = clients_by_fd.begin();
  126 + while (it != clients_by_fd.end())
104 127 {
105   - client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString());
106   - it = clients_by_fd.erase(it);
  128 + Client_p &client = it->second;
  129 + if (client && client->keepAliveExpired())
  130 + {
  131 + client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString());
  132 + it = clients_by_fd.erase(it);
  133 + }
  134 + else
  135 + it++;
107 136 }
108   - else
109   - it++;
110 137 }
111   -
112   - return true;
  138 + catch (std::exception &ex)
  139 + {
  140 + logger->logf(LOG_ERR, "Error handling keep-alives: %s.", ex.what());
  141 + }
113 142 }
114 143  
115 144 void ThreadData::initAuthPlugin()
... ... @@ -119,10 +148,14 @@ void ThreadData::initAuthPlugin()
119 148 authPlugin.securityInit(false);
120 149 }
121 150  
122   -void ThreadData::reload()
  151 +void ThreadData::reload(GlobalSettings settings)
123 152 {
  153 + logger->logf(LOG_DEBUG, "Doing reload in thread %d", threadnr);
  154 +
124 155 try
125 156 {
  157 + settingsLocalCopy = settings;
  158 +
126 159 authPlugin.securityCleanup(true);
127 160 authPlugin.securityInit(true);
128 161 }
... ... @@ -130,7 +163,29 @@ void ThreadData::reload()
130 163 {
131 164 logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what());
132 165 }
  166 + catch (std::exception &ex)
  167 + {
  168 + logger->logf(LOG_ERR, "Error reloading: %s.", ex.what());
  169 + }
  170 +}
  171 +
  172 +void ThreadData::queueReload(GlobalSettings settings)
  173 +{
  174 + std::lock_guard<std::mutex> locker(taskQueueMutex);
  175 +
  176 + auto f = std::bind(&ThreadData::reload, this, settings);
  177 + taskQueue.push_front(f);
  178 +
  179 + wakeUpThread();
133 180 }
134 181  
  182 +void ThreadData::wakeUpThread()
  183 +{
  184 + uint64_t one = 1;
  185 + check<std::runtime_error>(write(taskEventFd, &one, sizeof(uint64_t)));
  186 +}
  187 +
  188 +
  189 +
135 190  
136 191  
... ...
threaddata.h
... ... @@ -10,6 +10,7 @@
10 10 #include <unordered_map>
11 11 #include <mutex>
12 12 #include <shared_mutex>
  13 +#include <functional>
13 14  
14 15 #include "forward_declarations.h"
15 16  
... ... @@ -19,6 +20,7 @@
19 20 #include "configfileparser.h"
20 21 #include "authplugin.h"
21 22 #include "logger.h"
  23 +#include "globalsettings.h"
22 24  
23 25 typedef void (*thread_f)(ThreadData *);
24 26  
... ... @@ -30,14 +32,22 @@ class ThreadData
30 32 ConfigFileParser &confFileParser;
31 33 Logger *logger;
32 34  
  35 + void reload(GlobalSettings settings);
  36 + void wakeUpThread();
  37 + void doKeepAliveCheck();
  38 +
33 39 public:
34 40 AuthPlugin authPlugin;
35 41 bool running = true;
36 42 std::thread thread;
37 43 int threadnr = 0;
38 44 int epollfd = 0;
  45 + int taskEventFd = 0;
  46 + std::mutex taskQueueMutex;
  47 + std::forward_list<std::function<void()>> taskQueue;
  48 + GlobalSettings settingsLocalCopy; // Is updated on reload, within the thread loop.
39 49  
40   - ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser);
  50 + ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings);
41 51 ThreadData(const ThreadData &other) = delete;
42 52 ThreadData(ThreadData &&other) = delete;
43 53  
... ... @@ -49,9 +59,10 @@ public:
49 59 void removeClient(int fd);
50 60 std::shared_ptr<SubscriptionStore> &getSubscriptionStore();
51 61  
52   - bool doKeepAliveCheck();
53 62 void initAuthPlugin();
54   - void reload();
  63 + void queueReload(GlobalSettings settings);
  64 + void queueDoKeepAliveCheck();
  65 +
55 66 };
56 67  
57 68 #endif // THREADDATA_H
... ...