Commit 5d72070e0e6f6b6ca0a1f181c3d0041c923a5e32

Authored by Wiebe Cazemier
1 parent e5746d7f

Support generic listener in config file

This allows creation of multiple listeners, with different protocols
and/or SSL certificates.

Related change: settings is now a class that is copyable and assignable,
and is done so to each thread on reload.

Semi-related fix: fix crash in quit when multiple threads initiated it.
This came to light when testing the auth plugin settings.
CMakeLists.txt
@@ -28,8 +28,10 @@ add_executable(FlashMQ @@ -28,8 +28,10 @@ add_executable(FlashMQ
28 configfileparser.cpp 28 configfileparser.cpp
29 sslctxmanager.cpp 29 sslctxmanager.cpp
30 timer.cpp 30 timer.cpp
31 - globalsettings.cpp  
32 iowrapper.cpp 31 iowrapper.cpp
  32 + mosquittoauthoptcompatwrap.cpp
  33 + settings.cpp
  34 + listener.cpp
33 ) 35 )
34 36
35 target_link_libraries(FlashMQ pthread dl ssl crypto) 37 target_link_libraries(FlashMQ pthread dl ssl crypto)
authplugin.cpp
@@ -17,8 +17,8 @@ void mosquitto_log_printf(int level, const char *fmt, ...) @@ -17,8 +17,8 @@ void mosquitto_log_printf(int level, const char *fmt, ...)
17 } 17 }
18 18
19 19
20 -AuthPlugin::AuthPlugin(ConfigFileParser &confFileParser) :  
21 - confFileParser(confFileParser) 20 +AuthPlugin::AuthPlugin(Settings &settings) :
  21 + settings(settings)
22 { 22 {
23 logger = Logger::getInstance(); 23 logger = Logger::getInstance();
24 } 24 }
@@ -89,7 +89,7 @@ void AuthPlugin::init() @@ -89,7 +89,7 @@ void AuthPlugin::init()
89 if (!wanted) 89 if (!wanted)
90 return; 90 return;
91 91
92 - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); 92 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
93 int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); 93 int result = init_v2(&pluginData, authOpts.head(), authOpts.size());
94 if (result != 0) 94 if (result != 0)
95 throw FatalError("Error initialising auth plugin."); 95 throw FatalError("Error initialising auth plugin.");
@@ -102,7 +102,7 @@ void AuthPlugin::cleanup() @@ -102,7 +102,7 @@ void AuthPlugin::cleanup()
102 102
103 securityCleanup(false); 103 securityCleanup(false);
104 104
105 - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); 105 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
106 int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size()); 106 int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size());
107 if (result != 0) 107 if (result != 0)
108 logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway. 108 logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway.
@@ -113,7 +113,7 @@ void AuthPlugin::securityInit(bool reloading) @@ -113,7 +113,7 @@ void AuthPlugin::securityInit(bool reloading)
113 if (!wanted) 113 if (!wanted)
114 return; 114 return;
115 115
116 - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); 116 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
117 int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); 117 int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
118 if (result != 0) 118 if (result != 0)
119 { 119 {
@@ -128,7 +128,7 @@ void AuthPlugin::securityCleanup(bool reloading) @@ -128,7 +128,7 @@ void AuthPlugin::securityCleanup(bool reloading)
128 return; 128 return;
129 129
130 initialized = false; 130 initialized = false;
131 - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); 131 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
132 int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading); 132 int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
133 133
134 if (result != 0) 134 if (result != 0)
authplugin.h
@@ -51,7 +51,7 @@ class AuthPlugin @@ -51,7 +51,7 @@ class AuthPlugin
51 F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; 51 F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr;
52 F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; 52 F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr;
53 53
54 - ConfigFileParser &confFileParser; 54 + Settings &settings; // A ref because I want it to always be the same as the thread's settings
55 55
56 void *pluginData = nullptr; 56 void *pluginData = nullptr;
57 Logger *logger = nullptr; 57 Logger *logger = nullptr;
@@ -60,7 +60,7 @@ class AuthPlugin @@ -60,7 +60,7 @@ class AuthPlugin
60 60
61 void *loadSymbol(void *handle, const char *symbol) const; 61 void *loadSymbol(void *handle, const char *symbol) const;
62 public: 62 public:
63 - AuthPlugin(ConfigFileParser &confFileParser); 63 + AuthPlugin(Settings &settings);
64 AuthPlugin(const AuthPlugin &other) = delete; 64 AuthPlugin(const AuthPlugin &other) = delete;
65 AuthPlugin(AuthPlugin &&other) = delete; 65 AuthPlugin(AuthPlugin &&other) = delete;
66 ~AuthPlugin(); 66 ~AuthPlugin();
client.cpp
@@ -7,10 +7,10 @@ @@ -7,10 +7,10 @@
7 7
8 #include "logger.h" 8 #include "logger.h"
9 9
10 -Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings) : 10 +Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, std::shared_ptr<Settings> settings) :
11 fd(fd), 11 fd(fd),
12 - initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy  
13 - maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment. 12 + initialBufferSize(settings->clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy
  13 + maxPacketSize(settings->maxPacketSize), // Same as initialBufferSize comment.
14 ioWrapper(ssl, websocket, initialBufferSize, this), 14 ioWrapper(ssl, websocket, initialBufferSize, this),
15 readbuf(initialBufferSize), 15 readbuf(initialBufferSize),
16 writebuf(initialBufferSize), 16 writebuf(initialBufferSize),
client.h
@@ -70,7 +70,7 @@ class Client @@ -70,7 +70,7 @@ class Client
70 void setReadyForReading(bool val); 70 void setReadyForReading(bool val);
71 71
72 public: 72 public:
73 - Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings); 73 + Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, std::shared_ptr<Settings> settings);
74 Client(const Client &other) = delete; 74 Client(const Client &other) = delete;
75 Client(Client &&other) = delete; 75 Client(Client &&other) = delete;
76 ~Client(); 76 ~Client();
configfileparser.cpp
@@ -4,45 +4,24 @@ @@ -4,45 +4,24 @@
4 #include <unistd.h> 4 #include <unistd.h>
5 #include <sstream> 5 #include <sstream>
6 #include "fstream" 6 #include "fstream"
  7 +#include <regex>
7 8
8 #include "openssl/ssl.h" 9 #include "openssl/ssl.h"
9 #include "openssl/err.h" 10 #include "openssl/err.h"
10 11
11 #include "exceptions.h" 12 #include "exceptions.h"
12 #include "utils.h" 13 #include "utils.h"
13 -#include <regex>  
14 -  
15 #include "logger.h" 14 #include "logger.h"
16 15
17 16
18 -mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value)  
19 -{  
20 - this->key = strdup(key.c_str());  
21 - this->value = strdup(value.c_str());  
22 -}  
23 -  
24 -mosquitto_auth_opt::mosquitto_auth_opt(mosquitto_auth_opt &&other)  
25 -{  
26 - this->key = other.key;  
27 - this->value = other.value;  
28 - other.key = nullptr;  
29 - other.value = nullptr;  
30 -}  
31 -  
32 -mosquitto_auth_opt::~mosquitto_auth_opt()  
33 -{  
34 - if (key)  
35 - delete key;  
36 - if (value)  
37 - delete value;  
38 -}  
39 -  
40 -AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map<std::string, std::string> &authOpts) 17 +void ConfigFileParser::testKeyValidity(const std::string &key, const std::set<std::string> &validKeys) const
41 { 18 {
42 - for(auto &pair : authOpts) 19 + auto valid_key_it = validKeys.find(key);
  20 + if (valid_key_it == validKeys.end())
43 { 21 {
44 - mosquitto_auth_opt opt(pair.first, pair.second);  
45 - optArray.push_back(std::move(opt)); 22 + std::ostringstream oss;
  23 + oss << "Config key '" << key << "' is not valid here.";
  24 + throw ConfigFileException(oss.str());
46 } 25 }
47 } 26 }
48 27
@@ -56,51 +35,21 @@ void ConfigFileParser::checkFileAccess(const std::string &amp;key, const std::string @@ -56,51 +35,21 @@ void ConfigFileParser::checkFileAccess(const std::string &amp;key, const std::string
56 } 35 }
57 } 36 }
58 37
59 -// Using a separate ssl context to test, because it's the easiest way to load certs and key atomitcally.  
60 -void ConfigFileParser::testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const  
61 -{  
62 - if (portNr == 0)  
63 - return;  
64 -  
65 - if (fullchain.empty() && privkey.empty())  
66 - throw ConfigFileException("No privkey and fullchain specified.");  
67 -  
68 - if (fullchain.empty())  
69 - throw ConfigFileException("No private key specified for fullchain");  
70 -  
71 - if (privkey.empty())  
72 - throw ConfigFileException("No fullchain specified for private key");  
73 -  
74 - SslCtxManager sslCtx;  
75 - if (SSL_CTX_use_certificate_file(sslCtx.get(), fullchain.c_str(), SSL_FILETYPE_PEM) != 1)  
76 - {  
77 - ERR_print_errors_cb(logSslError, NULL);  
78 - throw ConfigFileException("Error loading full chain " + fullchain);  
79 - }  
80 - if (SSL_CTX_use_PrivateKey_file(sslCtx.get(), privkey.c_str(), SSL_FILETYPE_PEM) != 1)  
81 - {  
82 - ERR_print_errors_cb(logSslError, NULL);  
83 - throw ConfigFileException("Error loading private key " + privkey);  
84 - }  
85 - if (SSL_CTX_check_private_key(sslCtx.get()) != 1)  
86 - {  
87 - ERR_print_errors_cb(logSslError, NULL);  
88 - throw ConfigFileException("Private key and certificate don't match.");  
89 - }  
90 -}  
91 -  
92 ConfigFileParser::ConfigFileParser(const std::string &path) : 38 ConfigFileParser::ConfigFileParser(const std::string &path) :
93 path(path) 39 path(path)
94 { 40 {
95 validKeys.insert("auth_plugin"); 41 validKeys.insert("auth_plugin");
96 validKeys.insert("log_file"); 42 validKeys.insert("log_file");
97 - validKeys.insert("listen_port");  
98 - validKeys.insert("ssl_listen_port");  
99 - validKeys.insert("fullchain");  
100 - validKeys.insert("privkey");  
101 validKeys.insert("allow_unsafe_clientid_chars"); 43 validKeys.insert("allow_unsafe_clientid_chars");
102 validKeys.insert("client_initial_buffer_size"); 44 validKeys.insert("client_initial_buffer_size");
103 validKeys.insert("max_packet_size"); 45 validKeys.insert("max_packet_size");
  46 +
  47 + validListenKeys.insert("port");
  48 + validListenKeys.insert("protocol");
  49 + validListenKeys.insert("fullchain");
  50 + validListenKeys.insert("privkey");
  51 +
  52 + settings.reset(new Settings());
104 } 53 }
105 54
106 void ConfigFileParser::loadFile(bool test) 55 void ConfigFileParser::loadFile(bool test)
@@ -121,12 +70,19 @@ void ConfigFileParser::loadFile(bool test) @@ -121,12 +70,19 @@ void ConfigFileParser::loadFile(bool test)
121 70
122 std::list<std::string> lines; 71 std::list<std::string> lines;
123 72
124 - const std::regex r("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.]+)$"); 73 + const std::regex key_value_regex("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.]+)$");
  74 + const std::regex block_regex_start("^([a-zA-Z0-9_\\-]+) *\\{$");
  75 + const std::regex block_regex_end("^\\}$");
  76 +
  77 + bool inBlock = false;
  78 + std::ostringstream oss;
  79 + int linenr = 0;
125 80
126 // First parse the file and keep the valid lines. 81 // First parse the file and keep the valid lines.
127 for(std::string line; getline(infile, line ); ) 82 for(std::string line; getline(infile, line ); )
128 { 83 {
129 trim(line); 84 trim(line);
  85 + linenr++;
130 86
131 if (startsWith(line, "#")) 87 if (startsWith(line, "#"))
132 continue; 88 continue;
@@ -136,102 +92,133 @@ void ConfigFileParser::loadFile(bool test) @@ -136,102 +92,133 @@ void ConfigFileParser::loadFile(bool test)
136 92
137 std::smatch matches; 93 std::smatch matches;
138 94
139 - if (!std::regex_search(line, matches, r) || matches.size() != 3) 95 + const bool blockStartMatch = std::regex_search(line, matches, block_regex_start);
  96 + const bool blockEndMatch = std::regex_search(line, matches, block_regex_end);
  97 +
  98 + if ((blockStartMatch && inBlock) || (blockEndMatch && !inBlock))
140 { 99 {
141 - std::ostringstream oss;  
142 - oss << "Line '" << line << "' not in 'key value' format"; 100 + oss << "Unexpected block start or end at line " << linenr << ": " << line;
143 throw ConfigFileException(oss.str()); 101 throw ConfigFileException(oss.str());
144 } 102 }
145 103
  104 + if (!std::regex_search(line, matches, key_value_regex) && !blockStartMatch && !blockEndMatch)
  105 + {
  106 + oss << "Line '" << line << "' invalid";
  107 + throw ConfigFileException(oss.str());
  108 + }
  109 +
  110 + if (blockStartMatch)
  111 + inBlock = true;
  112 + if (blockEndMatch)
  113 + inBlock = false;
  114 +
146 lines.push_back(line); 115 lines.push_back(line);
147 } 116 }
148 117
149 - authOpts.clear();  
150 - authOptCompatWrap.reset(); 118 + if (inBlock)
  119 + {
  120 + throw ConfigFileException("Unclosed config block. Expecting }");
  121 + }
  122 +
  123 + std::unordered_map<std::string, std::string> authOpts;
151 124
152 - std::string sslFullChainTmp;  
153 - std::string sslPrivkeyTmp; 125 + ConfigParseLevel curParseLevel = ConfigParseLevel::Root;
  126 + std::shared_ptr<Listener> curListener;
  127 + std::unique_ptr<Settings> tmpSettings(new Settings);
154 128
155 // Then once we know the config file is valid, process it. 129 // Then once we know the config file is valid, process it.
156 for (std::string &line : lines) 130 for (std::string &line : lines)
157 { 131 {
158 std::smatch matches; 132 std::smatch matches;
159 133
160 - if (!std::regex_search(line, matches, r) || matches.size() != 3) 134 + if (std::regex_match(line, matches, block_regex_start))
  135 + {
  136 + std::string key = matches[1].str();
  137 + if (matches[1].str() == "listen")
  138 + {
  139 + curParseLevel = ConfigParseLevel::Listen;
  140 + curListener.reset(new Listener);
  141 + }
  142 + else
  143 + {
  144 + throw ConfigFileException(formatString("'%s' is not a valid block.", key.c_str()));
  145 + }
  146 +
  147 + continue;
  148 + }
  149 + else if (std::regex_match(line, matches, block_regex_end))
161 { 150 {
162 - throw ConfigFileException("Config parse error at a point that should not be possible."); 151 + if (curParseLevel == ConfigParseLevel::Listen)
  152 + {
  153 + curListener->isValid();
  154 + tmpSettings->listeners.push_back(curListener);
  155 + curListener.reset();
  156 + }
  157 +
  158 + curParseLevel = ConfigParseLevel::Root;
  159 + continue;
163 } 160 }
164 161
  162 + std::regex_match(line, matches, key_value_regex);
  163 +
165 std::string key = matches[1].str(); 164 std::string key = matches[1].str();
166 const std::string value = matches[2].str(); 165 const std::string value = matches[2].str();
167 166
168 - const std::string auth_opt_ = "auth_opt_";  
169 - if (startsWith(key, auth_opt_))  
170 - {  
171 - key.replace(0, auth_opt_.length(), "");  
172 - authOpts[key] = value;  
173 - }  
174 - else 167 + try
175 { 168 {
176 - auto valid_key_it = validKeys.find(key);  
177 - if (valid_key_it == validKeys.end()) 169 + if (curParseLevel == ConfigParseLevel::Listen)
178 { 170 {
179 - std::ostringstream oss;  
180 - oss << "Config key '" << key << "' is not valid. This error should have been cought before. Bug?";  
181 - throw ConfigFileException(oss.str());  
182 - } 171 + testKeyValidity(key, validListenKeys);
183 172
184 - if (key == "auth_plugin")  
185 - {  
186 - checkFileAccess(key, value);  
187 - if (!test)  
188 - this->authPluginPath = value;  
189 - } 173 + if (key == "protocol")
  174 + {
  175 + if (value != "mqtt" && value != "websockets")
  176 + throw ConfigFileException(formatString("Protocol '%s' is not a valid listener protocol", value.c_str()));
  177 + curListener->websocket = value == "websockets";
  178 + }
  179 + else if (key == "port")
  180 + {
  181 + curListener->port = std::stoi(value);
  182 + }
  183 + else if (key == "fullchain")
  184 + {
  185 + curListener->sslFullchain = value;
  186 + }
  187 + if (key == "privkey")
  188 + {
  189 + curListener->sslPrivkey = value;
  190 + }
190 191
191 - if (key == "log_file")  
192 - {  
193 - checkFileAccess(key, value);  
194 - if (!test)  
195 - this->logPath = value; 192 + continue;
196 } 193 }
197 194
198 - if (key == "allow_unsafe_clientid_chars")  
199 - {  
200 - bool tmp = stringTruthiness(value);  
201 - if (!test)  
202 - this->allowUnsafeClientidChars = tmp;  
203 - }  
204 195
205 - if (key == "fullchain") 196 + const std::string auth_opt_ = "auth_opt_";
  197 + if (startsWith(key, auth_opt_))
206 { 198 {
207 - checkFileAccess(key, value);  
208 - sslFullChainTmp = value; 199 + key.replace(0, auth_opt_.length(), "");
  200 + authOpts[key] = value;
209 } 201 }
210 -  
211 - if (key == "privkey") 202 + else
212 { 203 {
213 - checkFileAccess(key, value);  
214 - sslPrivkeyTmp = value;  
215 - } 204 + testKeyValidity(key, validKeys);
216 205
217 - try  
218 - {  
219 - // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners.  
220 - if (key == "listen_port") 206 + if (key == "auth_plugin")
221 { 207 {
222 - uint listenportNew = std::stoi(value);  
223 - if (listenPort > 0 && listenPort != listenportNew)  
224 - throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time.");  
225 - listenPort = listenportNew; 208 + checkFileAccess(key, value);
  209 + tmpSettings->authPluginPath = value;
226 } 210 }
227 211
228 - // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners.  
229 - if (key == "ssl_listen_port") 212 + if (key == "log_file")
230 { 213 {
231 - uint sslListenPortNew = std::stoi(value);  
232 - if (sslListenPort > 0 && sslListenPort != sslListenPortNew)  
233 - throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time.");  
234 - sslListenPort = sslListenPortNew; 214 + checkFileAccess(key, value);
  215 + tmpSettings->logPath = value;
  216 + }
  217 +
  218 + if (key == "allow_unsafe_clientid_chars")
  219 + {
  220 + bool tmp = stringTruthiness(value);
  221 + tmpSettings->allowUnsafeClientidChars = tmp;
235 } 222 }
236 223
237 if (key == "client_initial_buffer_size") 224 if (key == "client_initial_buffer_size")
@@ -239,8 +226,7 @@ void ConfigFileParser::loadFile(bool test) @@ -239,8 +226,7 @@ void ConfigFileParser::loadFile(bool test)
239 int newVal = std::stoi(value); 226 int newVal = std::stoi(value);
240 if (!isPowerOfTwo(newVal)) 227 if (!isPowerOfTwo(newVal))
241 throw ConfigFileException("client_initial_buffer_size value " + value + " is not a power of two."); 228 throw ConfigFileException("client_initial_buffer_size value " + value + " is not a power of two.");
242 - if (!test)  
243 - clientInitialBufferSize = newVal; 229 + tmpSettings->clientInitialBufferSize = newVal;
244 } 230 }
245 231
246 if (key == "max_packet_size") 232 if (key == "max_packet_size")
@@ -252,29 +238,33 @@ void ConfigFileParser::loadFile(bool test) @@ -252,29 +238,33 @@ void ConfigFileParser::loadFile(bool test)
252 oss << "Value for max_packet_size " << newVal << "is higher than absolute maximum " << ABSOLUTE_MAX_PACKET_SIZE; 238 oss << "Value for max_packet_size " << newVal << "is higher than absolute maximum " << ABSOLUTE_MAX_PACKET_SIZE;
253 throw ConfigFileException(oss.str()); 239 throw ConfigFileException(oss.str());
254 } 240 }
255 - if (!test)  
256 - maxPacketSize = newVal; 241 + tmpSettings->maxPacketSize = newVal;
257 } 242 }
258 -  
259 - }  
260 - catch (std::invalid_argument &ex) // catch for the stoi()  
261 - {  
262 - throw ConfigFileException(ex.what());  
263 } 243 }
264 } 244 }
  245 + catch (std::invalid_argument &ex) // catch for the stoi()
  246 + {
  247 + throw ConfigFileException(ex.what());
  248 + }
  249 + }
  250 +
  251 + if (tmpSettings->listeners.empty())
  252 + {
  253 + std::shared_ptr<Listener> defaultListener(new Listener());
  254 + tmpSettings->listeners.push_back(defaultListener);
265 } 255 }
266 256
267 - testSsl(sslFullChainTmp, sslPrivkeyTmp, sslListenPort);  
268 - this->sslFullchain = sslFullChainTmp;  
269 - this->sslPrivkey = sslPrivkeyTmp; 257 + tmpSettings->authOptCompatWrap = AuthOptCompatWrap(authOpts);
270 258
271 - authOptCompatWrap.reset(new AuthOptCompatWrap(authOpts)); 259 + if (!test)
  260 + {
  261 + this->settings = std::move(tmpSettings);
  262 + }
272 } 263 }
273 264
274 -AuthOptCompatWrap &ConfigFileParser::getAuthOptsCompat() 265 +AuthOptCompatWrap &Settings::getAuthOptsCompat()
275 { 266 {
276 - return *authOptCompatWrap.get(); 267 + return authOptCompatWrap;
277 } 268 }
278 269
279 270
280 -  
configfileparser.h
@@ -6,59 +6,33 @@ @@ -6,59 +6,33 @@
6 #include <unordered_map> 6 #include <unordered_map>
7 #include <vector> 7 #include <vector>
8 #include <memory> 8 #include <memory>
  9 +#include <list>
9 10
10 #include "sslctxmanager.h" 11 #include "sslctxmanager.h"
  12 +#include "listener.h"
  13 +#include "settings.h"
11 14
12 #define ABSOLUTE_MAX_PACKET_SIZE 268435461 // 256 MB + 5 15 #define ABSOLUTE_MAX_PACKET_SIZE 268435461 // 256 MB + 5
13 16
14 -struct mosquitto_auth_opt 17 +enum class ConfigParseLevel
15 { 18 {
16 - char *key = nullptr;  
17 - char *value = nullptr;  
18 -  
19 - mosquitto_auth_opt(const std::string &key, const std::string &value);  
20 - mosquitto_auth_opt(mosquitto_auth_opt &&other);  
21 - mosquitto_auth_opt(const mosquitto_auth_opt &other) = delete;  
22 - ~mosquitto_auth_opt();  
23 -};  
24 -  
25 -struct AuthOptCompatWrap  
26 -{  
27 - std::vector<struct mosquitto_auth_opt> optArray;  
28 -  
29 - AuthOptCompatWrap(const std::unordered_map<std::string, std::string> &authOpts);  
30 - AuthOptCompatWrap(const AuthOptCompatWrap &other) = delete;  
31 - AuthOptCompatWrap(AuthOptCompatWrap &&other) = delete;  
32 -  
33 - struct mosquitto_auth_opt *head() { return &optArray[0]; }  
34 - int size() { return optArray.size(); } 19 + Root,
  20 + Listen
35 }; 21 };
36 22
37 class ConfigFileParser 23 class ConfigFileParser
38 { 24 {
39 const std::string path; 25 const std::string path;
40 std::set<std::string> validKeys; 26 std::set<std::string> validKeys;
41 - std::unordered_map<std::string, std::string> authOpts;  
42 - std::unique_ptr<AuthOptCompatWrap> authOptCompatWrap;  
43 - 27 + std::set<std::string> validListenKeys;
44 28
  29 + void testKeyValidity(const std::string &key, const std::set<std::string> &validKeys) const;
45 void checkFileAccess(const std::string &key, const std::string &pathToCheck) const; 30 void checkFileAccess(const std::string &key, const std::string &pathToCheck) const;
46 - void testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const;  
47 public: 31 public:
48 ConfigFileParser(const std::string &path); 32 ConfigFileParser(const std::string &path);
49 void loadFile(bool test); 33 void loadFile(bool test);
50 - AuthOptCompatWrap &getAuthOptsCompat();  
51 34
52 - // Actual config options with their defaults. Just making them public, I can retrain myself misuing them.  
53 - std::string authPluginPath;  
54 - std::string logPath;  
55 - std::string sslFullchain;  
56 - std::string sslPrivkey;  
57 - uint listenPort = 1883;  
58 - uint sslListenPort = 0;  
59 - bool allowUnsafeClientidChars = false;  
60 - int clientInitialBufferSize = 1024; // Must be power of 2  
61 - int maxPacketSize = 268435461; // 256 MB + 5 35 + std::unique_ptr<Settings> settings;
62 }; 36 };
63 37
64 #endif // CONFIGFILEPARSER_H 38 #endif // CONFIGFILEPARSER_H
exceptions.cpp
1 #include "exceptions.h" 1 #include "exceptions.h"
2 2
  3 +
  4 +
exceptions.h
@@ -3,6 +3,7 @@ @@ -3,6 +3,7 @@
3 3
4 #include <exception> 4 #include <exception>
5 #include <stdexcept> 5 #include <stdexcept>
  6 +#include <sstream>
6 7
7 class ProtocolError : public std::runtime_error 8 class ProtocolError : public std::runtime_error
8 { 9 {
@@ -26,6 +27,7 @@ class ConfigFileException : public std::runtime_error @@ -26,6 +27,7 @@ class ConfigFileException : public std::runtime_error
26 { 27 {
27 public: 28 public:
28 ConfigFileException(const std::string &msg) : std::runtime_error(msg) {} 29 ConfigFileException(const std::string &msg) : std::runtime_error(msg) {}
  30 + ConfigFileException(std::ostringstream oss) : std::runtime_error(oss.str()) {}
29 }; 31 };
30 32
31 class AuthPluginException : public std::runtime_error 33 class AuthPluginException : public std::runtime_error
forward_declarations.h
@@ -10,6 +10,7 @@ typedef std::shared_ptr&lt;ThreadData&gt; ThreadData_p; @@ -10,6 +10,7 @@ typedef std::shared_ptr&lt;ThreadData&gt; ThreadData_p;
10 class MqttPacket; 10 class MqttPacket;
11 class SubscriptionStore; 11 class SubscriptionStore;
12 class Session; 12 class Session;
  13 +class Settings;
13 14
14 15
15 #endif // FORWARD_DECLARATIONS_H 16 #endif // FORWARD_DECLARATIONS_H
globalsettings.cpp deleted
1 -#include "globalsettings.h"  
2 -  
3 -  
globalsettings.h deleted
1 -#ifndef GLOBALSETTINGS_H  
2 -#define GLOBALSETTINGS_H  
3 -  
4 -// Defaults are defined in ConfigFileParser  
5 -struct GlobalSettings  
6 -{  
7 - bool allow_unsafe_clientid_chars = false;  
8 - int clientInitialBufferSize = 0;  
9 - int maxPacketSize = 0;  
10 -};  
11 -#endif // GLOBALSETTINGS_H  
listener.cpp 0 โ†’ 100644
  1 +#include "listener.h"
  2 +
  3 +#include "utils.h"
  4 +#include "exceptions.h"
  5 +
  6 +void Listener::isValid()
  7 +{
  8 + if (isSsl())
  9 + {
  10 + if (port == 0)
  11 + {
  12 + if (websocket)
  13 + port = 4443;
  14 + else
  15 + port = 8883;
  16 + }
  17 +
  18 + testSsl(sslFullchain, sslPrivkey);
  19 + }
  20 + else
  21 + {
  22 + if (port == 0)
  23 + {
  24 + if (websocket)
  25 + port = 8080;
  26 + else
  27 + port = 1883;
  28 + }
  29 + }
  30 +
  31 + if (port <= 0 || port > 65534)
  32 + {
  33 + throw ConfigFileException(formatString("Port nr %d is not valid", port));
  34 + }
  35 +}
  36 +
  37 +bool Listener::isSsl() const
  38 +{
  39 + return (!sslFullchain.empty() || !sslPrivkey.empty());
  40 +}
  41 +
  42 +std::string Listener::getProtocolName() const
  43 +{
  44 + if (isSsl())
  45 + {
  46 + if (websocket)
  47 + return "SSL websocket";
  48 + else
  49 + return "SSL TCP";
  50 + }
  51 + else
  52 + {
  53 + if (websocket)
  54 + return "non-SSL websocket";
  55 + else
  56 + return "non-SSL TCP";
  57 + }
  58 +
  59 + return "whoops";
  60 +}
  61 +
  62 +void Listener::loadCertAndKeyFromConfig()
  63 +{
  64 + if (!isSsl())
  65 + return;
  66 +
  67 + if (!sslctx)
  68 + {
  69 + sslctx.reset(new SslCtxManager());
  70 + SSL_CTX_set_options(sslctx->get(), SSL_OP_NO_SSLv3); // TODO: config option
  71 + SSL_CTX_set_options(sslctx->get(), SSL_OP_NO_TLSv1); // TODO: config option
  72 + }
  73 +
  74 + if (SSL_CTX_use_certificate_file(sslctx->get(), sslFullchain.c_str(), SSL_FILETYPE_PEM) != 1)
  75 + throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected.");
  76 + if (SSL_CTX_use_PrivateKey_file(sslctx->get(), sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1)
  77 + throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected.");
  78 +}
listener.h 0 โ†’ 100644
  1 +#ifndef LISTENER_H
  2 +#define LISTENER_H
  3 +
  4 +#include <string>
  5 +#include <memory>
  6 +
  7 +#include "sslctxmanager.h"
  8 +
  9 +struct Listener
  10 +{
  11 + int port = 0;
  12 + bool websocket = false;
  13 + std::string sslFullchain;
  14 + std::string sslPrivkey;
  15 + std::unique_ptr<SslCtxManager> sslctx;
  16 +
  17 + void isValid();
  18 + bool isSsl() const;
  19 + std::string getProtocolName() const;
  20 + void loadCertAndKeyFromConfig();
  21 +};
  22 +#endif // LISTENER_H
mainapp.cpp
@@ -172,9 +172,6 @@ MainApp::MainApp(const std::string &amp;configFilePath) : @@ -172,9 +172,6 @@ MainApp::MainApp(const std::string &amp;configFilePath) :
172 172
173 MainApp::~MainApp() 173 MainApp::~MainApp()
174 { 174 {
175 - if (sslctx)  
176 - SSL_CTX_free(sslctx);  
177 -  
178 if (epollFdAccept > 0) 175 if (epollFdAccept > 0)
179 close(epollFdAccept); 176 close(epollFdAccept);
180 } 177 }
@@ -202,54 +199,47 @@ void MainApp::showLicense() @@ -202,54 +199,47 @@ void MainApp::showLicense()
202 puts("Author: Wiebe Cazemier <wiebe@halfgaar.net>"); 199 puts("Author: Wiebe Cazemier <wiebe@halfgaar.net>");
203 } 200 }
204 201
205 -void MainApp::setCertAndKeyFromConfig()  
206 -{  
207 - if (sslctx == nullptr)  
208 - return;  
209 -  
210 - if (SSL_CTX_use_certificate_file(sslctx, confFileParser->sslFullchain.c_str(), SSL_FILETYPE_PEM) != 1)  
211 - throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected.");  
212 - if (SSL_CTX_use_PrivateKey_file(sslctx, confFileParser->sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1)  
213 - throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected.");  
214 -}  
215 -  
216 -int MainApp::createListenSocket(int portNr, bool ssl) 202 +int MainApp::createListenSocket(const std::shared_ptr<Listener> &listener)
217 { 203 {
218 - if (portNr <= 0) 204 + if (listener->port <= 0)
219 return -2; 205 return -2;
220 206
221 - int listen_fd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));  
222 -  
223 - // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.  
224 - int optval = 1;  
225 - check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));  
226 -  
227 - int flags = fcntl(listen_fd, F_GETFL);  
228 - check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); 207 + logger->logf(LOG_NOTICE, "Creating %s listener on port %d", listener->getProtocolName().c_str(), listener->port);
229 208
230 - struct sockaddr_in in_addr_plain;  
231 - in_addr_plain.sin_family = AF_INET;  
232 - in_addr_plain.sin_addr.s_addr = INADDR_ANY;  
233 - in_addr_plain.sin_port = htons(portNr); 209 + try
  210 + {
  211 + int listen_fd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));
234 212
235 - check<std::runtime_error>(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in)));  
236 - check<std::runtime_error>(listen(listen_fd, 1024)); 213 + // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.
  214 + int optval = 1;
  215 + check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));
237 216
238 - struct epoll_event ev;  
239 - memset(&ev, 0, sizeof (struct epoll_event)); 217 + int flags = fcntl(listen_fd, F_GETFL);
  218 + check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));
240 219
241 - ev.data.fd = listen_fd;  
242 - ev.events = EPOLLIN;  
243 - check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev)); 220 + struct sockaddr_in in_addr_plain;
  221 + in_addr_plain.sin_family = AF_INET;
  222 + in_addr_plain.sin_addr.s_addr = INADDR_ANY;
  223 + in_addr_plain.sin_port = htons(listener->port);
244 224
245 - std::string socketType = "plain"; 225 + check<std::runtime_error>(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in)));
  226 + check<std::runtime_error>(listen(listen_fd, 1024));
246 227
247 - if (ssl)  
248 - socketType = "SSL"; 228 + struct epoll_event ev;
  229 + memset(&ev, 0, sizeof (struct epoll_event));
249 230
250 - logger->logf(LOG_NOTICE, "Listening on %s port %d", socketType.c_str(), portNr); 231 + ev.data.fd = listen_fd;
  232 + ev.events = EPOLLIN;
  233 + check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev));
251 234
252 - return listen_fd; 235 + return listen_fd;
  236 + }
  237 + catch (std::exception &ex)
  238 + {
  239 + logger->logf(LOG_NOTICE, "Creating %s listener on port %d failed: %s", listener->getProtocolName().c_str(), listener->port, ex.what());
  240 + return -1;
  241 + }
  242 + return -1;
253 } 243 }
254 244
255 void MainApp::wakeUpThread() 245 void MainApp::wakeUpThread()
@@ -349,9 +339,14 @@ void MainApp::start() @@ -349,9 +339,14 @@ void MainApp::start()
349 { 339 {
350 timer.start(); 340 timer.start();
351 341
352 - int listen_fd_plain = createListenSocket(this->listenPort, false);  
353 - int listen_fd_ssl = createListenSocket(this->sslListenPort, true);  
354 - int listen_fd_websocket_plain = createListenSocket(1443, true); 342 + std::map<int, std::shared_ptr<Listener>> listenerMap;
  343 +
  344 + for(std::shared_ptr<Listener> &listener : this->listeners)
  345 + {
  346 + int fd = createListenSocket(listener);
  347 + if (fd > 0)
  348 + listenerMap[fd] = listener;
  349 + }
355 350
356 #ifdef NDEBUG 351 #ifdef NDEBUG
357 logger->noLongerLogToStd(); 352 logger->noLongerLogToStd();
@@ -365,7 +360,7 @@ void MainApp::start() @@ -365,7 +360,7 @@ void MainApp::start()
365 360
366 for (int i = 0; i < num_threads; i++) 361 for (int i = 0; i < num_threads; i++)
367 { 362 {
368 - std::shared_ptr<ThreadData> t(new ThreadData(i, subscriptionStore, *confFileParser.get(), settings)); 363 + std::shared_ptr<ThreadData> t(new ThreadData(i, subscriptionStore, settings));
369 t->start(&do_thread_work); 364 t->start(&do_thread_work);
370 threads.push_back(t); 365 threads.push_back(t);
371 } 366 }
@@ -392,22 +387,22 @@ void MainApp::start() @@ -392,22 +387,22 @@ void MainApp::start()
392 int cur_fd = events[i].data.fd; 387 int cur_fd = events[i].data.fd;
393 try 388 try
394 { 389 {
395 - if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl || listen_fd_websocket_plain) 390 + if (cur_fd != taskEventFd)
396 { 391 {
  392 + std::shared_ptr<Listener> listener = listenerMap[cur_fd];
397 std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % num_threads]; 393 std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % num_threads];
398 394
399 - logger->logf(LOG_INFO, "Accepting connection on thread %d", thread_data->threadnr); 395 + logger->logf(LOG_INFO, "Accepting connection on thread %d on %s", thread_data->threadnr, listener->getProtocolName().c_str());
400 396
401 struct sockaddr addr; 397 struct sockaddr addr;
402 memset(&addr, 0, sizeof(struct sockaddr)); 398 memset(&addr, 0, sizeof(struct sockaddr));
403 socklen_t len = sizeof(struct sockaddr); 399 socklen_t len = sizeof(struct sockaddr);
404 int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len)); 400 int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len));
405 401
406 - bool websocket = cur_fd == listen_fd_websocket_plain;  
407 SSL *clientSSL = nullptr; 402 SSL *clientSSL = nullptr;
408 - if (cur_fd == listen_fd_ssl) 403 + if (listener->isSsl())
409 { 404 {
410 - clientSSL = SSL_new(sslctx); 405 + clientSSL = SSL_new(listener->sslctx->get());
411 406
412 if (clientSSL == NULL) 407 if (clientSSL == NULL)
413 { 408 {
@@ -419,10 +414,10 @@ void MainApp::start() @@ -419,10 +414,10 @@ void MainApp::start()
419 SSL_set_fd(clientSSL, fd); 414 SSL_set_fd(clientSSL, fd);
420 } 415 }
421 416
422 - Client_p client(new Client(fd, thread_data, clientSSL, websocket, settings)); 417 + Client_p client(new Client(fd, thread_data, clientSSL, listener->websocket, settings));
423 thread_data->giveClient(client); 418 thread_data->giveClient(client);
424 } 419 }
425 - else if (cur_fd == taskEventFd) 420 + else
426 { 421 {
427 uint64_t eventfd_value = 0; 422 uint64_t eventfd_value = 0;
428 check<std::runtime_error>(read(cur_fd, &eventfd_value, sizeof(uint64_t))); 423 check<std::runtime_error>(read(cur_fd, &eventfd_value, sizeof(uint64_t)));
@@ -434,10 +429,6 @@ void MainApp::start() @@ -434,10 +429,6 @@ void MainApp::start()
434 } 429 }
435 taskQueue.clear(); 430 taskQueue.clear();
436 } 431 }
437 - else  
438 - {  
439 - throw std::runtime_error("Bug: the main thread had activity on an fd it's not supposed to monitor.");  
440 - }  
441 } 432 }
442 catch (std::exception &ex) 433 catch (std::exception &ex)
443 { 434 {
@@ -452,12 +443,18 @@ void MainApp::start() @@ -452,12 +443,18 @@ void MainApp::start()
452 thread->quit(); 443 thread->quit();
453 } 444 }
454 445
455 - close(listen_fd_plain);  
456 - close(listen_fd_ssl); 446 + for(auto pair : listenerMap)
  447 + {
  448 + close(pair.first);
  449 + }
457 } 450 }
458 451
459 void MainApp::quit() 452 void MainApp::quit()
460 { 453 {
  454 + std::lock_guard<std::mutex> guard(quitMutex);
  455 + if (!running)
  456 + return;
  457 +
461 Logger *logger = Logger::getInstance(); 458 Logger *logger = Logger::getInstance();
462 logger->logf(LOG_NOTICE, "Quitting FlashMQ"); 459 logger->logf(LOG_NOTICE, "Quitting FlashMQ");
463 timer.stop(); 460 timer.stop();
@@ -472,26 +469,21 @@ void MainApp::loadConfig() @@ -472,26 +469,21 @@ void MainApp::loadConfig()
472 // Atomic loading, first test. 469 // Atomic loading, first test.
473 confFileParser->loadFile(true); 470 confFileParser->loadFile(true);
474 confFileParser->loadFile(false); 471 confFileParser->loadFile(false);
  472 + settings = std::move(confFileParser->settings);
475 473
476 - logger->setLogPath(confFileParser->logPath);  
477 - logger->reOpen(); 474 + // For now, it's too much work to be able to reload new listeners, with all the shared resource stuff going on. So, I'm
  475 + // loading them to a local var which is never updated.
  476 + if (listeners.empty())
  477 + listeners = settings->listeners;
478 478
479 - listenPort = confFileParser->listenPort;  
480 - sslListenPort = confFileParser->sslListenPort; 479 + logger->setLogPath(settings->logPath);
  480 + logger->reOpen();
481 481
482 - if (sslctx == nullptr && sslListenPort > 0) 482 + for (std::shared_ptr<Listener> &l : this->listeners)
483 { 483 {
484 - sslctx = SSL_CTX_new(TLS_server_method());  
485 - SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv3); // TODO: config option  
486 - SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option 484 + l->loadCertAndKeyFromConfig();
487 } 485 }
488 486
489 - settings.allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars;  
490 - settings.clientInitialBufferSize = confFileParser->clientInitialBufferSize;  
491 - settings.maxPacketSize = confFileParser->maxPacketSize;  
492 -  
493 - setCertAndKeyFromConfig();  
494 -  
495 for (std::shared_ptr<ThreadData> &thread : threads) 487 for (std::shared_ptr<ThreadData> &thread : threads)
496 { 488 {
497 thread->queueReload(settings); 489 thread->queueReload(settings);
mainapp.h
@@ -20,7 +20,6 @@ @@ -20,7 +20,6 @@
20 #include "subscriptionstore.h" 20 #include "subscriptionstore.h"
21 #include "configfileparser.h" 21 #include "configfileparser.h"
22 #include "timer.h" 22 #include "timer.h"
23 -#include "globalsettings.h"  
24 23
25 class MainApp 24 class MainApp
26 { 25 {
@@ -37,11 +36,9 @@ class MainApp @@ -37,11 +36,9 @@ class MainApp
37 int taskEventFd = -1; 36 int taskEventFd = -1;
38 std::mutex eventMutex; 37 std::mutex eventMutex;
39 Timer timer; 38 Timer timer;
40 - GlobalSettings settings;  
41 -  
42 - uint listenPort = 0;  
43 - uint sslListenPort = 0;  
44 - SSL_CTX *sslctx = nullptr; 39 + std::shared_ptr<Settings> settings;
  40 + std::list<std::shared_ptr<Listener>> listeners;
  41 + std::mutex quitMutex;
45 42
46 Logger *logger = Logger::getInstance(); 43 Logger *logger = Logger::getInstance();
47 44
@@ -49,8 +46,7 @@ class MainApp @@ -49,8 +46,7 @@ class MainApp
49 void reloadConfig(); 46 void reloadConfig();
50 static void doHelp(const char *arg); 47 static void doHelp(const char *arg);
51 static void showLicense(); 48 static void showLicense();
52 - void setCertAndKeyFromConfig();  
53 - int createListenSocket(int portNr, bool ssl); 49 + int createListenSocket(const std::shared_ptr<Listener> &listener);
54 void wakeUpThread(); 50 void wakeUpThread();
55 void queueKeepAliveCheckAtAllThreads(); 51 void queueKeepAliveCheckAtAllThreads();
56 52
@@ -66,7 +62,6 @@ public: @@ -66,7 +62,6 @@ public:
66 bool getStarted() const {return started;} 62 bool getStarted() const {return started;}
67 static void testConfig(); 63 static void testConfig();
68 64
69 - GlobalSettings &getGlobalSettings();  
70 void queueConfigReload(); 65 void queueConfigReload();
71 void queueCleanup(); 66 void queueCleanup();
72 }; 67 };
mosquittoauthoptcompatwrap.cpp 0 โ†’ 100644
  1 +#include "mosquittoauthoptcompatwrap.h"
  2 +
  3 +mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value)
  4 +{
  5 + this->key = strdup(key.c_str());
  6 + this->value = strdup(value.c_str());
  7 +}
  8 +
  9 +mosquitto_auth_opt::mosquitto_auth_opt(mosquitto_auth_opt &&other)
  10 +{
  11 + this->key = other.key;
  12 + this->value = other.value;
  13 + other.key = nullptr;
  14 + other.value = nullptr;
  15 +}
  16 +
  17 +mosquitto_auth_opt::mosquitto_auth_opt(const mosquitto_auth_opt &other)
  18 +{
  19 + this->key = strdup(other.key);
  20 + this->value = strdup(other.value);
  21 +}
  22 +
  23 +mosquitto_auth_opt::~mosquitto_auth_opt()
  24 +{
  25 + if (key)
  26 + delete key;
  27 + if (value)
  28 + delete value;
  29 +}
  30 +
  31 +mosquitto_auth_opt &mosquitto_auth_opt::operator=(const mosquitto_auth_opt &other)
  32 +{
  33 + if (key)
  34 + delete key;
  35 + if (value)
  36 + delete value;
  37 +
  38 + this->key = strdup(other.key);
  39 + this->value = strdup(other.value);
  40 +
  41 + return *this;
  42 +}
  43 +
  44 +AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map<std::string, std::string> &authOpts)
  45 +{
  46 + for(auto &pair : authOpts)
  47 + {
  48 + mosquitto_auth_opt opt(pair.first, pair.second);
  49 + optArray.push_back(std::move(opt));
  50 + }
  51 +}
  52 +
mosquittoauthoptcompatwrap.h 0 โ†’ 100644
  1 +#ifndef MOSQUITTOAUTHOPTCOMPATWRAP_H
  2 +#define MOSQUITTOAUTHOPTCOMPATWRAP_H
  3 +
  4 +#include <vector>
  5 +#include <unordered_map>
  6 +#include <cstring>
  7 +
  8 +/**
  9 + * @brief The mosquitto_auth_opt struct is a resource managed class of auth options, compatible with passing as arguments to Mosquitto
  10 + * auth plugins.
  11 + *
  12 + * It's fully assignable and copyable.
  13 + */
  14 +struct mosquitto_auth_opt
  15 +{
  16 + char *key = nullptr;
  17 + char *value = nullptr;
  18 +
  19 + mosquitto_auth_opt(const std::string &key, const std::string &value);
  20 + mosquitto_auth_opt(mosquitto_auth_opt &&other);
  21 + mosquitto_auth_opt(const mosquitto_auth_opt &other);
  22 + ~mosquitto_auth_opt();
  23 +
  24 + mosquitto_auth_opt& operator=(const mosquitto_auth_opt &other);
  25 +};
  26 +
  27 +/**
  28 + * @brief The AuthOptCompatWrap struct contains a vector of mosquitto auth options, with a head pointer and count which can be passed to
  29 + * Mosquitto auth plugins.
  30 + */
  31 +struct AuthOptCompatWrap
  32 +{
  33 + std::vector<struct mosquitto_auth_opt> optArray;
  34 +
  35 + AuthOptCompatWrap(const std::unordered_map<std::string, std::string> &authOpts);
  36 + AuthOptCompatWrap(const AuthOptCompatWrap &other) = default;
  37 + AuthOptCompatWrap(AuthOptCompatWrap &&other) = delete;
  38 + AuthOptCompatWrap() = default;
  39 +
  40 + struct mosquitto_auth_opt *head() { return &optArray[0]; }
  41 + int size() { return optArray.size(); }
  42 +
  43 + AuthOptCompatWrap &operator=(const AuthOptCompatWrap &other) = default;
  44 +};
  45 +
  46 +
  47 +#endif // MOSQUITTOAUTHOPTCOMPATWRAP_H
mqttpacket.cpp
@@ -145,7 +145,7 @@ void MqttPacket::handleConnect() @@ -145,7 +145,7 @@ void MqttPacket::handleConnect()
145 145
146 uint16_t variable_header_length = readTwoBytesToUInt16(); 146 uint16_t variable_header_length = readTwoBytesToUInt16();
147 147
148 - const GlobalSettings &settings = sender->getThreadData()->settingsLocalCopy; 148 + const Settings &settings = sender->getThreadData()->settingsLocalCopy;
149 149
150 if (variable_header_length == 4 || variable_header_length == 6) 150 if (variable_header_length == 4 || variable_header_length == 6)
151 { 151 {
@@ -232,7 +232,7 @@ void MqttPacket::handleConnect() @@ -232,7 +232,7 @@ void MqttPacket::handleConnect()
232 bool validClientId = true; 232 bool validClientId = true;
233 233
234 // Check for wildcard chars in case the client_id ever appears in topics. 234 // Check for wildcard chars in case the client_id ever appears in topics.
235 - if (!settings.allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#"))) 235 + if (!settings.allowUnsafeClientidChars && (strContains(client_id, "+") || strContains(client_id, "#")))
236 { 236 {
237 logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str()); 237 logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str());
238 validClientId = false; 238 validClientId = false;
mqttpacket.h
@@ -15,7 +15,6 @@ @@ -15,7 +15,6 @@
15 #include "cirbuf.h" 15 #include "cirbuf.h"
16 #include "logger.h" 16 #include "logger.h"
17 #include "mainapp.h" 17 #include "mainapp.h"
18 -#include "globalsettings.h"  
19 18
20 struct RemainingLength 19 struct RemainingLength
21 { 20 {
settings.cpp 0 โ†’ 100644
  1 +#include "settings.h"
  2 +
  3 +
settings.h 0 โ†’ 100644
  1 +#ifndef SETTINGS_H
  2 +#define SETTINGS_H
  3 +
  4 +#include <memory>
  5 +#include <list>
  6 +
  7 +#include "mosquittoauthoptcompatwrap.h"
  8 +#include "listener.h"
  9 +
  10 +class Settings
  11 +{
  12 + friend class ConfigFileParser;
  13 +
  14 + AuthOptCompatWrap authOptCompatWrap;
  15 +
  16 +public:
  17 + // Actual config options with their defaults.
  18 + std::string authPluginPath;
  19 + std::string logPath;
  20 + bool allowUnsafeClientidChars = false;
  21 + int clientInitialBufferSize = 1024; // Must be power of 2
  22 + int maxPacketSize = 268435461; // 256 MB + 5
  23 + std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined.
  24 +
  25 + AuthOptCompatWrap &getAuthOptsCompat();
  26 +};
  27 +
  28 +#endif // SETTINGS_H
sslctxmanager.cpp
@@ -16,3 +16,8 @@ SSL_CTX *SslCtxManager::get() const @@ -16,3 +16,8 @@ SSL_CTX *SslCtxManager::get() const
16 { 16 {
17 return ssl_ctx; 17 return ssl_ctx;
18 } 18 }
  19 +
  20 +SslCtxManager::operator bool() const
  21 +{
  22 + return ssl_ctx == nullptr;
  23 +}
sslctxmanager.h
@@ -11,6 +11,7 @@ public: @@ -11,6 +11,7 @@ public:
11 ~SslCtxManager(); 11 ~SslCtxManager();
12 12
13 SSL_CTX *get() const; 13 SSL_CTX *get() const;
  14 + operator bool() const;
14 }; 15 };
15 16
16 #endif // SSLCTXMANAGER_H 17 #endif // SSLCTXMANAGER_H
threaddata.cpp
@@ -2,12 +2,11 @@ @@ -2,12 +2,11 @@
2 #include <string> 2 #include <string>
3 #include <sstream> 3 #include <sstream>
4 4
5 -ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings) : 5 +ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, std::shared_ptr<Settings> settings) :
6 subscriptionStore(subscriptionStore), 6 subscriptionStore(subscriptionStore),
7 - confFileParser(confFileParser),  
8 - authPlugin(confFileParser),  
9 - threadnr(threadnr),  
10 - settingsLocalCopy(settings) 7 + settingsLocalCopy(*settings.get()),
  8 + authPlugin(settingsLocalCopy),
  9 + threadnr(threadnr)
11 { 10 {
12 logger = Logger::getInstance(); 11 logger = Logger::getInstance();
13 12
@@ -146,18 +145,19 @@ void ThreadData::doKeepAliveCheck() @@ -146,18 +145,19 @@ void ThreadData::doKeepAliveCheck()
146 145
147 void ThreadData::initAuthPlugin() 146 void ThreadData::initAuthPlugin()
148 { 147 {
149 - authPlugin.loadPlugin(confFileParser.authPluginPath); 148 + authPlugin.loadPlugin(settingsLocalCopy.authPluginPath);
150 authPlugin.init(); 149 authPlugin.init();
151 authPlugin.securityInit(false); 150 authPlugin.securityInit(false);
152 } 151 }
153 152
154 -void ThreadData::reload(GlobalSettings settings) 153 +void ThreadData::reload(std::shared_ptr<Settings> settings)
155 { 154 {
156 logger->logf(LOG_DEBUG, "Doing reload in thread %d", threadnr); 155 logger->logf(LOG_DEBUG, "Doing reload in thread %d", threadnr);
157 156
158 try 157 try
159 { 158 {
160 - settingsLocalCopy = settings; 159 + // Because the auth plugin has a reference to it, it will also be updated.
  160 + settingsLocalCopy = *settings.get();
161 161
162 authPlugin.securityCleanup(true); 162 authPlugin.securityCleanup(true);
163 authPlugin.securityInit(true); 163 authPlugin.securityInit(true);
@@ -172,7 +172,7 @@ void ThreadData::reload(GlobalSettings settings) @@ -172,7 +172,7 @@ void ThreadData::reload(GlobalSettings settings)
172 } 172 }
173 } 173 }
174 174
175 -void ThreadData::queueReload(GlobalSettings settings) 175 +void ThreadData::queueReload(std::shared_ptr<Settings> settings)
176 { 176 {
177 std::lock_guard<std::mutex> locker(taskQueueMutex); 177 std::lock_guard<std::mutex> locker(taskQueueMutex);
178 178
threaddata.h
@@ -20,7 +20,6 @@ @@ -20,7 +20,6 @@
20 #include "configfileparser.h" 20 #include "configfileparser.h"
21 #include "authplugin.h" 21 #include "authplugin.h"
22 #include "logger.h" 22 #include "logger.h"
23 -#include "globalsettings.h"  
24 23
25 typedef void (*thread_f)(ThreadData *); 24 typedef void (*thread_f)(ThreadData *);
26 25
@@ -29,14 +28,14 @@ class ThreadData @@ -29,14 +28,14 @@ class ThreadData
29 std::unordered_map<int, Client_p> clients_by_fd; 28 std::unordered_map<int, Client_p> clients_by_fd;
30 std::mutex clients_by_fd_mutex; 29 std::mutex clients_by_fd_mutex;
31 std::shared_ptr<SubscriptionStore> subscriptionStore; 30 std::shared_ptr<SubscriptionStore> subscriptionStore;
32 - ConfigFileParser &confFileParser;  
33 Logger *logger; 31 Logger *logger;
34 32
35 - void reload(GlobalSettings settings); 33 + void reload(std::shared_ptr<Settings> settings);
36 void wakeUpThread(); 34 void wakeUpThread();
37 void doKeepAliveCheck(); 35 void doKeepAliveCheck();
38 36
39 public: 37 public:
  38 + Settings settingsLocalCopy; // Is updated on reload, within the thread loop.
40 AuthPlugin authPlugin; 39 AuthPlugin authPlugin;
41 bool running = true; 40 bool running = true;
42 std::thread thread; 41 std::thread thread;
@@ -45,9 +44,8 @@ public: @@ -45,9 +44,8 @@ public:
45 int taskEventFd = 0; 44 int taskEventFd = 0;
46 std::mutex taskQueueMutex; 45 std::mutex taskQueueMutex;
47 std::forward_list<std::function<void()>> taskQueue; 46 std::forward_list<std::function<void()>> taskQueue;
48 - GlobalSettings settingsLocalCopy; // Is updated on reload, within the thread loop.  
49 47
50 - ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings); 48 + ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, std::shared_ptr<Settings> settings);
51 ThreadData(const ThreadData &other) = delete; 49 ThreadData(const ThreadData &other) = delete;
52 ThreadData(ThreadData &&other) = delete; 50 ThreadData(ThreadData &&other) = delete;
53 51
@@ -60,7 +58,7 @@ public: @@ -60,7 +58,7 @@ public:
60 std::shared_ptr<SubscriptionStore> &getSubscriptionStore(); 58 std::shared_ptr<SubscriptionStore> &getSubscriptionStore();
61 59
62 void initAuthPlugin(); 60 void initAuthPlugin();
63 - void queueReload(GlobalSettings settings); 61 + void queueReload(std::shared_ptr<Settings> settings);
64 void queueDoKeepAliveCheck(); 62 void queueDoKeepAliveCheck();
65 63
66 }; 64 };
utils.cpp
@@ -3,10 +3,15 @@ @@ -3,10 +3,15 @@
3 #include "sys/time.h" 3 #include "sys/time.h"
4 #include "sys/random.h" 4 #include "sys/random.h"
5 #include <algorithm> 5 #include <algorithm>
6 -#include <sstream> 6 +#include <cstdio>
  7 +
  8 +#include "openssl/ssl.h"
  9 +#include "openssl/err.h"
7 10
8 #include "exceptions.h" 11 #include "exceptions.h"
9 #include "cirbuf.h" 12 #include "cirbuf.h"
  13 +#include "sslctxmanager.h"
  14 +#include "logger.h"
10 15
11 std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) 16 std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts)
12 { 17 {
@@ -346,3 +351,47 @@ std::string generateWebsocketAnswer(const std::string &amp;acceptString) @@ -346,3 +351,47 @@ std::string generateWebsocketAnswer(const std::string &amp;acceptString)
346 oss.flush(); 351 oss.flush();
347 return oss.str(); 352 return oss.str();
348 } 353 }
  354 +
  355 +// Using a separate ssl context to test, because it's the easiest way to load certs and key atomitcally.
  356 +void testSsl(const std::string &fullchain, const std::string &privkey)
  357 +{
  358 + if (fullchain.empty() && privkey.empty())
  359 + throw ConfigFileException("No privkey and fullchain specified.");
  360 +
  361 + if (fullchain.empty())
  362 + throw ConfigFileException("No private key specified for fullchain");
  363 +
  364 + if (privkey.empty())
  365 + throw ConfigFileException("No fullchain specified for private key");
  366 +
  367 + SslCtxManager sslCtx;
  368 + if (SSL_CTX_use_certificate_file(sslCtx.get(), fullchain.c_str(), SSL_FILETYPE_PEM) != 1)
  369 + {
  370 + ERR_print_errors_cb(logSslError, NULL);
  371 + throw ConfigFileException("Error loading full chain " + fullchain);
  372 + }
  373 + if (SSL_CTX_use_PrivateKey_file(sslCtx.get(), privkey.c_str(), SSL_FILETYPE_PEM) != 1)
  374 + {
  375 + ERR_print_errors_cb(logSslError, NULL);
  376 + throw ConfigFileException("Error loading private key " + privkey);
  377 + }
  378 + if (SSL_CTX_check_private_key(sslCtx.get()) != 1)
  379 + {
  380 + ERR_print_errors_cb(logSslError, NULL);
  381 + throw ConfigFileException("Private key and certificate don't match.");
  382 + }
  383 +}
  384 +
  385 +std::string formatString(const std::string str, ...)
  386 +{
  387 + char buf[512];
  388 +
  389 + va_list valist;
  390 + va_start(valist, str);
  391 + vsnprintf(buf, 512, str.c_str(), valist);
  392 + va_end(valist);
  393 +
  394 + std::string result(buf, 512);
  395 +
  396 + return result;
  397 +}
@@ -55,4 +55,8 @@ std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion); @@ -55,4 +55,8 @@ std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion);
55 std::string generateBadHttpRequestReponse(const std::string &msg); 55 std::string generateBadHttpRequestReponse(const std::string &msg);
56 std::string generateWebsocketAnswer(const std::string &acceptString); 56 std::string generateWebsocketAnswer(const std::string &acceptString);
57 57
  58 +void testSsl(const std::string &fullchain, const std::string &privkey);
  59 +
  60 +std::string formatString(const std::string str, ...);
  61 +
58 #endif // UTILS_H 62 #endif // UTILS_H