Commit 65146ae28f3327793a00f89bd55b5ed235c5a6be

Authored by Wiebe Cazemier
1 parent 1f378699

Add setting 'allow_unsafe_clientid_chars'

And some side issues.
CMakeLists.txt
... ... @@ -28,6 +28,7 @@ add_executable(FlashMQ
28 28 configfileparser.cpp
29 29 sslctxmanager.cpp
30 30 timer.cpp
  31 + globalsettings.cpp
31 32 )
32 33  
33 34 target_link_libraries(FlashMQ pthread dl ssl crypto)
... ...
configfileparser.cpp
... ... @@ -98,6 +98,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) :
98 98 validKeys.insert("ssl_listen_port");
99 99 validKeys.insert("fullchain");
100 100 validKeys.insert("privkey");
  101 + validKeys.insert("allow_unsafe_clientid_chars");
101 102 }
102 103  
103 104 void ConfigFileParser::loadFile(bool test)
... ... @@ -192,6 +193,13 @@ void ConfigFileParser::loadFile(bool test)
192 193 this->logPath = value;
193 194 }
194 195  
  196 + if (key == "allow_unsafe_clientid_chars")
  197 + {
  198 + bool tmp = stringTruthiness(value);
  199 + if (!test)
  200 + this->allowUnsafeClientidChars = tmp;
  201 + }
  202 +
195 203 if (key == "fullchain")
196 204 {
197 205 checkFileAccess(key, value);
... ...
configfileparser.h
... ... @@ -54,6 +54,7 @@ public:
54 54 std::string sslPrivkey;
55 55 uint listenPort = 1883;
56 56 uint sslListenPort = 0;
  57 + bool allowUnsafeClientidChars = false;
57 58 };
58 59  
59 60 #endif // CONFIGFILEPARSER_H
... ...
globalsettings.cpp 0 → 100644
  1 +#include "globalsettings.h"
  2 +
  3 +GlobalSettings *GlobalSettings::instance = nullptr;
  4 +
  5 +GlobalSettings::GlobalSettings()
  6 +{
  7 +
  8 +}
  9 +
  10 +GlobalSettings *GlobalSettings::getInstance()
  11 +{
  12 + if (instance == nullptr)
  13 + instance = new GlobalSettings();
  14 +
  15 + return instance;
  16 +}
... ...
globalsettings.h 0 → 100644
  1 +#ifndef GLOBALSETTINGS_H
  2 +#define GLOBALSETTINGS_H
  3 +
  4 +// 'Global' as in, needed outside of the mainapp, like listen ports.
  5 +class GlobalSettings
  6 +{
  7 + static GlobalSettings *instance;
  8 + GlobalSettings();
  9 +public:
  10 + static GlobalSettings *getInstance();
  11 +
  12 + bool allow_unsafe_clientid_chars = false;
  13 +};
  14 +#endif // GLOBALSETTINGS_H
... ...
mainapp.cpp
... ... @@ -453,6 +453,7 @@ void MainApp::quit()
453 453 void MainApp::loadConfig()
454 454 {
455 455 Logger *logger = Logger::getInstance();
  456 + GlobalSettings *setting = GlobalSettings::getInstance();
456 457  
457 458 // Atomic loading, first test.
458 459 confFileParser->loadFile(true);
... ... @@ -471,6 +472,8 @@ void MainApp::loadConfig()
471 472 SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option
472 473 }
473 474  
  475 + setting->allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars;
  476 +
474 477 setCertAndKeyFromConfig();
475 478 }
476 479  
... ...
mainapp.h
... ... @@ -20,6 +20,7 @@
20 20 #include "subscriptionstore.h"
21 21 #include "configfileparser.h"
22 22 #include "timer.h"
  23 +#include "globalsettings.h"
23 24  
24 25 class MainApp
25 26 {
... ... @@ -63,7 +64,7 @@ public:
63 64 bool getStarted() const {return started;}
64 65 static void testConfig();
65 66  
66   -
  67 + GlobalSettings &getGlobalSettings();
67 68 void queueConfigReload();
68 69 void queueCleanup();
69 70 };
... ...
mqttpacket.cpp
... ... @@ -153,6 +153,8 @@ void MqttPacket::handleConnect()
153 153 if (sender->hasConnectPacketSeen())
154 154 throw ProtocolError("Client already sent a CONNECT.");
155 155  
  156 + GlobalSettings *settings = GlobalSettings::getInstance();
  157 +
156 158 uint16_t variable_header_length = readTwoBytesToUInt16();
157 159  
158 160 if (variable_header_length == 4 || variable_header_length == 6)
... ... @@ -240,10 +242,9 @@ void MqttPacket::handleConnect()
240 242 bool validClientId = true;
241 243  
242 244 // Check for wildcard chars in case the client_id ever appears in topics.
243   - // TODO: make setting?
244   - if (strContains(client_id, "+") || strContains(client_id, "#"))
  245 + if (!settings->allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#")))
245 246 {
246   - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str());
  247 + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str());
247 248 validClientId = false;
248 249 }
249 250 else if (!clean_session && client_id.empty())
... ... @@ -261,6 +262,7 @@ void MqttPacket::handleConnect()
261 262 {
262 263 ConnAck connAck(ConnAckReturnCodes::ClientIdRejected);
263 264 MqttPacket response(connAck);
  265 + sender->setDisconnectReason("Invalid clientID");
264 266 sender->setReadyForDisconnect();
265 267 sender->writeMqttPacket(response);
266 268 return;
... ... @@ -288,6 +290,7 @@ void MqttPacket::handleConnect()
288 290 {
289 291 ConnAck connDeny(ConnAckReturnCodes::NotAuthorized);
290 292 MqttPacket response(connDeny);
  293 + sender->setDisconnectReason("Access denied");
291 294 sender->setReadyForDisconnect();
292 295 sender->writeMqttPacket(response);
293 296 logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str());
... ...
mqttpacket.h
... ... @@ -14,6 +14,8 @@
14 14 #include "subscriptionstore.h"
15 15 #include "cirbuf.h"
16 16 #include "logger.h"
  17 +#include "mainapp.h"
  18 +#include "globalsettings.h"
17 19  
18 20 struct RemainingLength
19 21 {
... ...
utils.cpp
... ... @@ -203,3 +203,21 @@ std::string getSecureRandomString(const size_t len)
203 203 }
204 204 return randomString;
205 205 }
  206 +
  207 +std::string str_tolower(std::string s)
  208 +{
  209 + std::transform(s.begin(), s.end(), s.begin(),
  210 + [](unsigned char c){ return std::tolower(c); });
  211 + return s;
  212 +}
  213 +
  214 +bool stringTruthiness(const std::string &val)
  215 +{
  216 + std::string val_ = str_tolower(val);
  217 + trim(val_);
  218 + if (val_ == "yes" || val_ == "true" || val_ == "on")
  219 + return true;
  220 + if (val_ == "no" || val_ == "false" || val_ == "off")
  221 + return false;
  222 + throw ConfigFileException("Value '" + val + "' can't be converted to boolean");
  223 +}
... ...
... ... @@ -39,5 +39,7 @@ bool startsWith(const std::string &s, const std::string &needle);
39 39  
40 40 int64_t currentMSecsSinceEpoch();
41 41 std::string getSecureRandomString(const size_t len);
  42 +std::string str_tolower(std::string s);
  43 +bool stringTruthiness(const std::string &val);
42 44  
43 45 #endif // UTILS_H
... ...