Commit f3a45a2a9df301da64045555118352dec7388dd8

Authored by Wiebe Cazemier
1 parent 99a087f4

Like main thread, also have a task queue in threads

It's also used to reload settings. Settings are copied to threads, to
avoid concurrency issues.
globalsettings.cpp
1 #include "globalsettings.h" 1 #include "globalsettings.h"
2 2
3 -GlobalSettings *GlobalSettings::instance = nullptr;  
4 3
5 -GlobalSettings::GlobalSettings()  
6 -{  
7 -  
8 -}  
9 -  
10 -GlobalSettings *GlobalSettings::getInstance()  
11 -{  
12 - if (instance == nullptr)  
13 - instance = new GlobalSettings();  
14 -  
15 - return instance;  
16 -}  
globalsettings.h
1 #ifndef GLOBALSETTINGS_H 1 #ifndef GLOBALSETTINGS_H
2 #define GLOBALSETTINGS_H 2 #define GLOBALSETTINGS_H
3 3
4 -// 'Global' as in, needed outside of the mainapp, like listen ports.  
5 -class GlobalSettings 4 +// Defaults are defined in ConfigFileParser
  5 +struct GlobalSettings
6 { 6 {
7 - static GlobalSettings *instance;  
8 - GlobalSettings();  
9 -public:  
10 - static GlobalSettings *getInstance();  
11 -  
12 bool allow_unsafe_clientid_chars = false; 7 bool allow_unsafe_clientid_chars = false;
13 }; 8 };
14 #endif // GLOBALSETTINGS_H 9 #endif // GLOBALSETTINGS_H
mainapp.cpp
@@ -25,7 +25,6 @@ void do_thread_work(ThreadData *threadData) @@ -25,7 +25,6 @@ void do_thread_work(ThreadData *threadData)
25 memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); 25 memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
26 26
27 std::vector<MqttPacket> packetQueueIn; 27 std::vector<MqttPacket> packetQueueIn;
28 - time_t lastKeepAliveCheck = time(NULL);  
29 28
30 Logger *logger = Logger::getInstance(); 29 Logger *logger = Logger::getInstance();
31 30
@@ -59,6 +58,21 @@ void do_thread_work(ThreadData *threadData) @@ -59,6 +58,21 @@ void do_thread_work(ThreadData *threadData)
59 struct epoll_event cur_ev = events[i]; 58 struct epoll_event cur_ev = events[i];
60 int fd = cur_ev.data.fd; 59 int fd = cur_ev.data.fd;
61 60
  61 + if (fd == threadData->taskEventFd)
  62 + {
  63 + uint64_t eventfd_value = 0;
  64 + check<std::runtime_error>(read(fd, &eventfd_value, sizeof(uint64_t)));
  65 +
  66 + std::lock_guard<std::mutex> locker(threadData->taskQueueMutex);
  67 + for(auto &f : threadData->taskQueue)
  68 + {
  69 + f();
  70 + }
  71 + threadData->taskQueue.clear();
  72 +
  73 + continue;
  74 + }
  75 +
62 Client_p client = threadData->getClient(fd); 76 Client_p client = threadData->getClient(fd);
63 77
64 if (client) 78 if (client)
@@ -127,19 +141,6 @@ void do_thread_work(ThreadData *threadData) @@ -127,19 +141,6 @@ void do_thread_work(ThreadData *threadData)
127 } 141 }
128 } 142 }
129 packetQueueIn.clear(); 143 packetQueueIn.clear();
130 -  
131 - try  
132 - {  
133 - if (lastKeepAliveCheck + 30 < time(NULL))  
134 - {  
135 - if (threadData->doKeepAliveCheck())  
136 - lastKeepAliveCheck = time(NULL);  
137 - }  
138 - }  
139 - catch (std::exception &ex)  
140 - {  
141 - logger->logf(LOG_ERR, "Error handling keep-alives: %s.", ex.what());  
142 - }  
143 } 144 }
144 } 145 }
145 146
@@ -164,6 +165,9 @@ MainApp::MainApp(const std::string &amp;configFilePath) : @@ -164,6 +165,9 @@ MainApp::MainApp(const std::string &amp;configFilePath) :
164 165
165 auto f = std::bind(&MainApp::queueCleanup, this); 166 auto f = std::bind(&MainApp::queueCleanup, this);
166 timer.addCallback(f, 86400000, "session expiration"); 167 timer.addCallback(f, 86400000, "session expiration");
  168 +
  169 + auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this);
  170 + timer.addCallback(fKeepAlive, 30000, "keep-alive check");
167 } 171 }
168 172
169 MainApp::~MainApp() 173 MainApp::~MainApp()
@@ -254,6 +258,14 @@ void MainApp::wakeUpThread() @@ -254,6 +258,14 @@ void MainApp::wakeUpThread()
254 check<std::runtime_error>(write(taskEventFd, &one, sizeof(uint64_t))); 258 check<std::runtime_error>(write(taskEventFd, &one, sizeof(uint64_t)));
255 } 259 }
256 260
  261 +void MainApp::queueKeepAliveCheckAtAllThreads()
  262 +{
  263 + for (std::shared_ptr<ThreadData> &thread : threads)
  264 + {
  265 + thread->queueDoKeepAliveCheck();
  266 + }
  267 +}
  268 +
257 void MainApp::initMainApp(int argc, char *argv[]) 269 void MainApp::initMainApp(int argc, char *argv[])
258 { 270 {
259 if (instance != nullptr) 271 if (instance != nullptr)
@@ -352,7 +364,7 @@ void MainApp::start() @@ -352,7 +364,7 @@ void MainApp::start()
352 364
353 for (int i = 0; i < num_threads; i++) 365 for (int i = 0; i < num_threads; i++)
354 { 366 {
355 - std::shared_ptr<ThreadData> t(new ThreadData(i, subscriptionStore, *confFileParser.get())); 367 + std::shared_ptr<ThreadData> t(new ThreadData(i, subscriptionStore, *confFileParser.get(), settings));
356 t->start(&do_thread_work); 368 t->start(&do_thread_work);
357 threads.push_back(t); 369 threads.push_back(t);
358 } 370 }
@@ -413,6 +425,7 @@ void MainApp::start() @@ -413,6 +425,7 @@ void MainApp::start()
413 uint64_t eventfd_value = 0; 425 uint64_t eventfd_value = 0;
414 check<std::runtime_error>(read(cur_fd, &eventfd_value, sizeof(uint64_t))); 426 check<std::runtime_error>(read(cur_fd, &eventfd_value, sizeof(uint64_t)));
415 427
  428 + std::lock_guard<std::mutex> locker(eventMutex);
416 for(auto &f : taskQueue) 429 for(auto &f : taskQueue)
417 { 430 {
418 f(); 431 f();
@@ -453,7 +466,6 @@ void MainApp::quit() @@ -453,7 +466,6 @@ void MainApp::quit()
453 void MainApp::loadConfig() 466 void MainApp::loadConfig()
454 { 467 {
455 Logger *logger = Logger::getInstance(); 468 Logger *logger = Logger::getInstance();
456 - GlobalSettings *setting = GlobalSettings::getInstance();  
457 469
458 // Atomic loading, first test. 470 // Atomic loading, first test.
459 confFileParser->loadFile(true); 471 confFileParser->loadFile(true);
@@ -472,9 +484,14 @@ void MainApp::loadConfig() @@ -472,9 +484,14 @@ void MainApp::loadConfig()
472 SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option 484 SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option
473 } 485 }
474 486
475 - setting->allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars; 487 + settings.allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars;
476 488
477 setCertAndKeyFromConfig(); 489 setCertAndKeyFromConfig();
  490 +
  491 + for (std::shared_ptr<ThreadData> &thread : threads)
  492 + {
  493 + thread->queueReload(settings);
  494 + }
478 } 495 }
479 496
480 void MainApp::reloadConfig() 497 void MainApp::reloadConfig()
mainapp.h
@@ -37,6 +37,7 @@ class MainApp @@ -37,6 +37,7 @@ class MainApp
37 int taskEventFd = -1; 37 int taskEventFd = -1;
38 std::mutex eventMutex; 38 std::mutex eventMutex;
39 Timer timer; 39 Timer timer;
  40 + GlobalSettings settings;
40 41
41 uint listenPort = 0; 42 uint listenPort = 0;
42 uint sslListenPort = 0; 43 uint sslListenPort = 0;
@@ -51,6 +52,7 @@ class MainApp @@ -51,6 +52,7 @@ class MainApp
51 void setCertAndKeyFromConfig(); 52 void setCertAndKeyFromConfig();
52 int createListenSocket(int portNr, bool ssl); 53 int createListenSocket(int portNr, bool ssl);
53 void wakeUpThread(); 54 void wakeUpThread();
  55 + void queueKeepAliveCheckAtAllThreads();
54 56
55 MainApp(const std::string &configFilePath); 57 MainApp(const std::string &configFilePath);
56 public: 58 public:
mqttpacket.cpp
@@ -152,12 +152,12 @@ void MqttPacket::handleConnect() @@ -152,12 +152,12 @@ void MqttPacket::handleConnect()
152 if (sender->hasConnectPacketSeen()) 152 if (sender->hasConnectPacketSeen())
153 throw ProtocolError("Client already sent a CONNECT."); 153 throw ProtocolError("Client already sent a CONNECT.");
154 154
155 - GlobalSettings *settings = GlobalSettings::getInstance();  
156 -  
157 std::shared_ptr<SubscriptionStore> subscriptionStore = sender->getThreadData()->getSubscriptionStore(); 155 std::shared_ptr<SubscriptionStore> subscriptionStore = sender->getThreadData()->getSubscriptionStore();
158 156
159 uint16_t variable_header_length = readTwoBytesToUInt16(); 157 uint16_t variable_header_length = readTwoBytesToUInt16();
160 158
  159 + const GlobalSettings &settings = sender->getThreadData()->settingsLocalCopy;
  160 +
161 if (variable_header_length == 4 || variable_header_length == 6) 161 if (variable_header_length == 4 || variable_header_length == 6)
162 { 162 {
163 char *c = readBytes(variable_header_length); 163 char *c = readBytes(variable_header_length);
@@ -243,7 +243,7 @@ void MqttPacket::handleConnect() @@ -243,7 +243,7 @@ void MqttPacket::handleConnect()
243 bool validClientId = true; 243 bool validClientId = true;
244 244
245 // Check for wildcard chars in case the client_id ever appears in topics. 245 // Check for wildcard chars in case the client_id ever appears in topics.
246 - if (!settings->allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#"))) 246 + if (!settings.allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#")))
247 { 247 {
248 logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str()); 248 logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str());
249 validClientId = false; 249 validClientId = false;
threaddata.cpp
@@ -2,15 +2,26 @@ @@ -2,15 +2,26 @@
2 #include <string> 2 #include <string>
3 #include <sstream> 3 #include <sstream>
4 4
5 -ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser) : 5 +ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings) :
6 subscriptionStore(subscriptionStore), 6 subscriptionStore(subscriptionStore),
7 confFileParser(confFileParser), 7 confFileParser(confFileParser),
8 authPlugin(confFileParser), 8 authPlugin(confFileParser),
9 - threadnr(threadnr) 9 + threadnr(threadnr),
  10 + settingsLocalCopy(settings)
10 { 11 {
11 logger = Logger::getInstance(); 12 logger = Logger::getInstance();
12 13
13 epollfd = check<std::runtime_error>(epoll_create(999)); 14 epollfd = check<std::runtime_error>(epoll_create(999));
  15 +
  16 + taskEventFd = eventfd(0, EFD_NONBLOCK);
  17 + if (taskEventFd < 0)
  18 + throw std::runtime_error("Can't create eventfd.");
  19 +
  20 + struct epoll_event ev;
  21 + memset(&ev, 0, sizeof (struct epoll_event));
  22 + ev.data.fd = taskEventFd;
  23 + ev.events = EPOLLIN;
  24 + check<std::runtime_error>(epoll_ctl(this->epollfd, EPOLL_CTL_ADD, taskEventFd, &ev));
14 } 25 }
15 26
16 void ThreadData::start(thread_f f) 27 void ThreadData::start(thread_f f)
@@ -89,27 +100,45 @@ std::shared_ptr&lt;SubscriptionStore&gt; &amp;ThreadData::getSubscriptionStore() @@ -89,27 +100,45 @@ std::shared_ptr&lt;SubscriptionStore&gt; &amp;ThreadData::getSubscriptionStore()
89 return subscriptionStore; 100 return subscriptionStore;
90 } 101 }
91 102
  103 +void ThreadData::queueDoKeepAliveCheck()
  104 +{
  105 + std::lock_guard<std::mutex> locker(taskQueueMutex);
  106 +
  107 + auto f = std::bind(&ThreadData::doKeepAliveCheck, this);
  108 + taskQueue.push_front(f);
  109 +
  110 + wakeUpThread();
  111 +}
  112 +
92 // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? 113 // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial?
93 -bool ThreadData::doKeepAliveCheck() 114 +void ThreadData::doKeepAliveCheck()
94 { 115 {
  116 + // We don't need to stall normal connects and disconnects for keep-alive checking. We can do it later.
95 std::unique_lock<std::mutex> lock(clients_by_fd_mutex, std::try_to_lock); 117 std::unique_lock<std::mutex> lock(clients_by_fd_mutex, std::try_to_lock);
96 if (!lock.owns_lock()) 118 if (!lock.owns_lock())
97 - return false; 119 + return;
98 120
99 - auto it = clients_by_fd.begin();  
100 - while (it != clients_by_fd.end()) 121 + logger->logf(LOG_DEBUG, "Doing keep-alive check in thread %d", threadnr);
  122 +
  123 + try
101 { 124 {
102 - Client_p &client = it->second;  
103 - if (client && client->keepAliveExpired()) 125 + auto it = clients_by_fd.begin();
  126 + while (it != clients_by_fd.end())
104 { 127 {
105 - client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString());  
106 - it = clients_by_fd.erase(it); 128 + Client_p &client = it->second;
  129 + if (client && client->keepAliveExpired())
  130 + {
  131 + client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString());
  132 + it = clients_by_fd.erase(it);
  133 + }
  134 + else
  135 + it++;
107 } 136 }
108 - else  
109 - it++;  
110 } 137 }
111 -  
112 - return true; 138 + catch (std::exception &ex)
  139 + {
  140 + logger->logf(LOG_ERR, "Error handling keep-alives: %s.", ex.what());
  141 + }
113 } 142 }
114 143
115 void ThreadData::initAuthPlugin() 144 void ThreadData::initAuthPlugin()
@@ -119,10 +148,14 @@ void ThreadData::initAuthPlugin() @@ -119,10 +148,14 @@ void ThreadData::initAuthPlugin()
119 authPlugin.securityInit(false); 148 authPlugin.securityInit(false);
120 } 149 }
121 150
122 -void ThreadData::reload() 151 +void ThreadData::reload(GlobalSettings settings)
123 { 152 {
  153 + logger->logf(LOG_DEBUG, "Doing reload in thread %d", threadnr);
  154 +
124 try 155 try
125 { 156 {
  157 + settingsLocalCopy = settings;
  158 +
126 authPlugin.securityCleanup(true); 159 authPlugin.securityCleanup(true);
127 authPlugin.securityInit(true); 160 authPlugin.securityInit(true);
128 } 161 }
@@ -130,7 +163,29 @@ void ThreadData::reload() @@ -130,7 +163,29 @@ void ThreadData::reload()
130 { 163 {
131 logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what()); 164 logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what());
132 } 165 }
  166 + catch (std::exception &ex)
  167 + {
  168 + logger->logf(LOG_ERR, "Error reloading: %s.", ex.what());
  169 + }
  170 +}
  171 +
  172 +void ThreadData::queueReload(GlobalSettings settings)
  173 +{
  174 + std::lock_guard<std::mutex> locker(taskQueueMutex);
  175 +
  176 + auto f = std::bind(&ThreadData::reload, this, settings);
  177 + taskQueue.push_front(f);
  178 +
  179 + wakeUpThread();
133 } 180 }
134 181
  182 +void ThreadData::wakeUpThread()
  183 +{
  184 + uint64_t one = 1;
  185 + check<std::runtime_error>(write(taskEventFd, &one, sizeof(uint64_t)));
  186 +}
  187 +
  188 +
  189 +
135 190
136 191
threaddata.h
@@ -10,6 +10,7 @@ @@ -10,6 +10,7 @@
10 #include <unordered_map> 10 #include <unordered_map>
11 #include <mutex> 11 #include <mutex>
12 #include <shared_mutex> 12 #include <shared_mutex>
  13 +#include <functional>
13 14
14 #include "forward_declarations.h" 15 #include "forward_declarations.h"
15 16
@@ -19,6 +20,7 @@ @@ -19,6 +20,7 @@
19 #include "configfileparser.h" 20 #include "configfileparser.h"
20 #include "authplugin.h" 21 #include "authplugin.h"
21 #include "logger.h" 22 #include "logger.h"
  23 +#include "globalsettings.h"
22 24
23 typedef void (*thread_f)(ThreadData *); 25 typedef void (*thread_f)(ThreadData *);
24 26
@@ -30,14 +32,22 @@ class ThreadData @@ -30,14 +32,22 @@ class ThreadData
30 ConfigFileParser &confFileParser; 32 ConfigFileParser &confFileParser;
31 Logger *logger; 33 Logger *logger;
32 34
  35 + void reload(GlobalSettings settings);
  36 + void wakeUpThread();
  37 + void doKeepAliveCheck();
  38 +
33 public: 39 public:
34 AuthPlugin authPlugin; 40 AuthPlugin authPlugin;
35 bool running = true; 41 bool running = true;
36 std::thread thread; 42 std::thread thread;
37 int threadnr = 0; 43 int threadnr = 0;
38 int epollfd = 0; 44 int epollfd = 0;
  45 + int taskEventFd = 0;
  46 + std::mutex taskQueueMutex;
  47 + std::forward_list<std::function<void()>> taskQueue;
  48 + GlobalSettings settingsLocalCopy; // Is updated on reload, within the thread loop.
39 49
40 - ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser); 50 + ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings);
41 ThreadData(const ThreadData &other) = delete; 51 ThreadData(const ThreadData &other) = delete;
42 ThreadData(ThreadData &&other) = delete; 52 ThreadData(ThreadData &&other) = delete;
43 53
@@ -49,9 +59,10 @@ public: @@ -49,9 +59,10 @@ public:
49 void removeClient(int fd); 59 void removeClient(int fd);
50 std::shared_ptr<SubscriptionStore> &getSubscriptionStore(); 60 std::shared_ptr<SubscriptionStore> &getSubscriptionStore();
51 61
52 - bool doKeepAliveCheck();  
53 void initAuthPlugin(); 62 void initAuthPlugin();
54 - void reload(); 63 + void queueReload(GlobalSettings settings);
  64 + void queueDoKeepAliveCheck();
  65 +
55 }; 66 };
56 67
57 #endif // THREADDATA_H 68 #endif // THREADDATA_H