diff --git a/CMakeLists.txt b/CMakeLists.txt index d5d3ab2..ad1bfa7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,7 @@ add_executable(FlashMQ configfileparser.cpp sslctxmanager.cpp timer.cpp + globalsettings.cpp ) target_link_libraries(FlashMQ pthread dl ssl crypto) diff --git a/configfileparser.cpp b/configfileparser.cpp index b9cc8b7..f6acf16 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -98,6 +98,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : validKeys.insert("ssl_listen_port"); validKeys.insert("fullchain"); validKeys.insert("privkey"); + validKeys.insert("allow_unsafe_clientid_chars"); } void ConfigFileParser::loadFile(bool test) @@ -192,6 +193,13 @@ void ConfigFileParser::loadFile(bool test) this->logPath = value; } + if (key == "allow_unsafe_clientid_chars") + { + bool tmp = stringTruthiness(value); + if (!test) + this->allowUnsafeClientidChars = tmp; + } + if (key == "fullchain") { checkFileAccess(key, value); diff --git a/configfileparser.h b/configfileparser.h index abebdeb..bb4529a 100644 --- a/configfileparser.h +++ b/configfileparser.h @@ -54,6 +54,7 @@ public: std::string sslPrivkey; uint listenPort = 1883; uint sslListenPort = 0; + bool allowUnsafeClientidChars = false; }; #endif // CONFIGFILEPARSER_H diff --git a/globalsettings.cpp b/globalsettings.cpp new file mode 100644 index 0000000..e559a67 --- /dev/null +++ b/globalsettings.cpp @@ -0,0 +1,16 @@ +#include "globalsettings.h" + +GlobalSettings *GlobalSettings::instance = nullptr; + +GlobalSettings::GlobalSettings() +{ + +} + +GlobalSettings *GlobalSettings::getInstance() +{ + if (instance == nullptr) + instance = new GlobalSettings(); + + return instance; +} diff --git a/globalsettings.h b/globalsettings.h new file mode 100644 index 0000000..14fed52 --- /dev/null +++ b/globalsettings.h @@ -0,0 +1,14 @@ +#ifndef GLOBALSETTINGS_H +#define GLOBALSETTINGS_H + +// 'Global' as in, needed outside of the mainapp, like listen ports. +class GlobalSettings +{ + static GlobalSettings *instance; + GlobalSettings(); +public: + static GlobalSettings *getInstance(); + + bool allow_unsafe_clientid_chars = false; +}; +#endif // GLOBALSETTINGS_H diff --git a/mainapp.cpp b/mainapp.cpp index 5d49326..6bb8bf8 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -453,6 +453,7 @@ void MainApp::quit() void MainApp::loadConfig() { Logger *logger = Logger::getInstance(); + GlobalSettings *setting = GlobalSettings::getInstance(); // Atomic loading, first test. confFileParser->loadFile(true); @@ -471,6 +472,8 @@ void MainApp::loadConfig() SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option } + setting->allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars; + setCertAndKeyFromConfig(); } diff --git a/mainapp.h b/mainapp.h index 413ca7d..97b6746 100644 --- a/mainapp.h +++ b/mainapp.h @@ -20,6 +20,7 @@ #include "subscriptionstore.h" #include "configfileparser.h" #include "timer.h" +#include "globalsettings.h" class MainApp { @@ -63,7 +64,7 @@ public: bool getStarted() const {return started;} static void testConfig(); - + GlobalSettings &getGlobalSettings(); void queueConfigReload(); void queueCleanup(); }; diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 79b781b..b8080ef 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -153,6 +153,8 @@ void MqttPacket::handleConnect() if (sender->hasConnectPacketSeen()) throw ProtocolError("Client already sent a CONNECT."); + GlobalSettings *settings = GlobalSettings::getInstance(); + uint16_t variable_header_length = readTwoBytesToUInt16(); if (variable_header_length == 4 || variable_header_length == 6) @@ -240,10 +242,9 @@ void MqttPacket::handleConnect() bool validClientId = true; // Check for wildcard chars in case the client_id ever appears in topics. - // TODO: make setting? - if (strContains(client_id, "+") || strContains(client_id, "#")) + if (!settings->allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#"))) { - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str()); + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str()); validClientId = false; } else if (!clean_session && client_id.empty()) @@ -261,6 +262,7 @@ void MqttPacket::handleConnect() { ConnAck connAck(ConnAckReturnCodes::ClientIdRejected); MqttPacket response(connAck); + sender->setDisconnectReason("Invalid clientID"); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); return; @@ -288,6 +290,7 @@ void MqttPacket::handleConnect() { ConnAck connDeny(ConnAckReturnCodes::NotAuthorized); MqttPacket response(connDeny); + sender->setDisconnectReason("Access denied"); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str()); diff --git a/mqttpacket.h b/mqttpacket.h index 92d85dd..2f4c86a 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -14,6 +14,8 @@ #include "subscriptionstore.h" #include "cirbuf.h" #include "logger.h" +#include "mainapp.h" +#include "globalsettings.h" struct RemainingLength { diff --git a/utils.cpp b/utils.cpp index baf8c46..17be781 100644 --- a/utils.cpp +++ b/utils.cpp @@ -203,3 +203,21 @@ std::string getSecureRandomString(const size_t len) } return randomString; } + +std::string str_tolower(std::string s) +{ + std::transform(s.begin(), s.end(), s.begin(), + [](unsigned char c){ return std::tolower(c); }); + return s; +} + +bool stringTruthiness(const std::string &val) +{ + std::string val_ = str_tolower(val); + trim(val_); + if (val_ == "yes" || val_ == "true" || val_ == "on") + return true; + if (val_ == "no" || val_ == "false" || val_ == "off") + return false; + throw ConfigFileException("Value '" + val + "' can't be converted to boolean"); +} diff --git a/utils.h b/utils.h index 6862146..b1161bc 100644 --- a/utils.h +++ b/utils.h @@ -39,5 +39,7 @@ bool startsWith(const std::string &s, const std::string &needle); int64_t currentMSecsSinceEpoch(); std::string getSecureRandomString(const size_t len); +std::string str_tolower(std::string s); +bool stringTruthiness(const std::string &val); #endif // UTILS_H