From 5d72070e0e6f6b6ca0a1f181c3d0041c923a5e32 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Tue, 2 Mar 2021 21:53:23 +0100 Subject: [PATCH] Support generic listener in config file --- CMakeLists.txt | 4 +++- authplugin.cpp | 12 ++++++------ authplugin.h | 4 ++-- client.cpp | 6 +++--- client.h | 2 +- configfileparser.cpp | 276 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------------------------------------------------------------------------------------------------------------------- configfileparser.h | 44 +++++++++----------------------------------- exceptions.cpp | 2 ++ exceptions.h | 2 ++ forward_declarations.h | 1 + globalsettings.cpp | 3 --- globalsettings.h | 11 ----------- listener.cpp | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ listener.h | 22 ++++++++++++++++++++++ mainapp.cpp | 134 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------------------------------------------- mainapp.h | 13 ++++--------- mosquittoauthoptcompatwrap.cpp | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ mosquittoauthoptcompatwrap.h | 47 +++++++++++++++++++++++++++++++++++++++++++++++ mqttpacket.cpp | 4 ++-- mqttpacket.h | 1 - settings.cpp | 3 +++ settings.h | 28 ++++++++++++++++++++++++++++ sslctxmanager.cpp | 5 +++++ sslctxmanager.h | 1 + threaddata.cpp | 18 +++++++++--------- threaddata.h | 10 ++++------ utils.cpp | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- utils.h | 4 ++++ 28 files changed, 534 insertions(+), 304 deletions(-) delete mode 100644 globalsettings.cpp delete mode 100644 globalsettings.h create mode 100644 listener.cpp create mode 100644 listener.h create mode 100644 mosquittoauthoptcompatwrap.cpp create mode 100644 mosquittoauthoptcompatwrap.h create mode 100644 settings.cpp create mode 100644 settings.h diff --git a/CMakeLists.txt b/CMakeLists.txt index fb94126..758c80b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,8 +28,10 @@ add_executable(FlashMQ configfileparser.cpp sslctxmanager.cpp timer.cpp - globalsettings.cpp iowrapper.cpp + mosquittoauthoptcompatwrap.cpp + settings.cpp + listener.cpp ) target_link_libraries(FlashMQ pthread dl ssl crypto) diff --git a/authplugin.cpp b/authplugin.cpp index 86ee416..4906824 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -17,8 +17,8 @@ void mosquitto_log_printf(int level, const char *fmt, ...) } -AuthPlugin::AuthPlugin(ConfigFileParser &confFileParser) : - confFileParser(confFileParser) +AuthPlugin::AuthPlugin(Settings &settings) : + settings(settings) { logger = Logger::getInstance(); } @@ -89,7 +89,7 @@ void AuthPlugin::init() if (!wanted) return; - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); if (result != 0) throw FatalError("Error initialising auth plugin."); @@ -102,7 +102,7 @@ void AuthPlugin::cleanup() securityCleanup(false); - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size()); if (result != 0) logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway. @@ -113,7 +113,7 @@ void AuthPlugin::securityInit(bool reloading) if (!wanted) return; - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); if (result != 0) { @@ -128,7 +128,7 @@ void AuthPlugin::securityCleanup(bool reloading) return; initialized = false; - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading); if (result != 0) diff --git a/authplugin.h b/authplugin.h index 6c768ec..2c0ff40 100644 --- a/authplugin.h +++ b/authplugin.h @@ -51,7 +51,7 @@ class AuthPlugin F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; - ConfigFileParser &confFileParser; + Settings &settings; // A ref because I want it to always be the same as the thread's settings void *pluginData = nullptr; Logger *logger = nullptr; @@ -60,7 +60,7 @@ class AuthPlugin void *loadSymbol(void *handle, const char *symbol) const; public: - AuthPlugin(ConfigFileParser &confFileParser); + AuthPlugin(Settings &settings); AuthPlugin(const AuthPlugin &other) = delete; AuthPlugin(AuthPlugin &&other) = delete; ~AuthPlugin(); diff --git a/client.cpp b/client.cpp index 49718a2..c82a0b7 100644 --- a/client.cpp +++ b/client.cpp @@ -7,10 +7,10 @@ #include "logger.h" -Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings) : +Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, std::shared_ptr settings) : fd(fd), - initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy - maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment. + initialBufferSize(settings->clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy + maxPacketSize(settings->maxPacketSize), // Same as initialBufferSize comment. ioWrapper(ssl, websocket, initialBufferSize, this), readbuf(initialBufferSize), writebuf(initialBufferSize), diff --git a/client.h b/client.h index 4e99b52..49f2068 100644 --- a/client.h +++ b/client.h @@ -70,7 +70,7 @@ class Client void setReadyForReading(bool val); public: - Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings); + Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, std::shared_ptr settings); Client(const Client &other) = delete; Client(Client &&other) = delete; ~Client(); diff --git a/configfileparser.cpp b/configfileparser.cpp index 168bc80..75e7825 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -4,45 +4,24 @@ #include #include #include "fstream" +#include #include "openssl/ssl.h" #include "openssl/err.h" #include "exceptions.h" #include "utils.h" -#include - #include "logger.h" -mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value) -{ - this->key = strdup(key.c_str()); - this->value = strdup(value.c_str()); -} - -mosquitto_auth_opt::mosquitto_auth_opt(mosquitto_auth_opt &&other) -{ - this->key = other.key; - this->value = other.value; - other.key = nullptr; - other.value = nullptr; -} - -mosquitto_auth_opt::~mosquitto_auth_opt() -{ - if (key) - delete key; - if (value) - delete value; -} - -AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map &authOpts) +void ConfigFileParser::testKeyValidity(const std::string &key, const std::set &validKeys) const { - for(auto &pair : authOpts) + auto valid_key_it = validKeys.find(key); + if (valid_key_it == validKeys.end()) { - mosquitto_auth_opt opt(pair.first, pair.second); - optArray.push_back(std::move(opt)); + std::ostringstream oss; + oss << "Config key '" << key << "' is not valid here."; + throw ConfigFileException(oss.str()); } } @@ -56,51 +35,21 @@ void ConfigFileParser::checkFileAccess(const std::string &key, const std::string } } -// Using a separate ssl context to test, because it's the easiest way to load certs and key atomitcally. -void ConfigFileParser::testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const -{ - if (portNr == 0) - return; - - if (fullchain.empty() && privkey.empty()) - throw ConfigFileException("No privkey and fullchain specified."); - - if (fullchain.empty()) - throw ConfigFileException("No private key specified for fullchain"); - - if (privkey.empty()) - throw ConfigFileException("No fullchain specified for private key"); - - SslCtxManager sslCtx; - if (SSL_CTX_use_certificate_file(sslCtx.get(), fullchain.c_str(), SSL_FILETYPE_PEM) != 1) - { - ERR_print_errors_cb(logSslError, NULL); - throw ConfigFileException("Error loading full chain " + fullchain); - } - if (SSL_CTX_use_PrivateKey_file(sslCtx.get(), privkey.c_str(), SSL_FILETYPE_PEM) != 1) - { - ERR_print_errors_cb(logSslError, NULL); - throw ConfigFileException("Error loading private key " + privkey); - } - if (SSL_CTX_check_private_key(sslCtx.get()) != 1) - { - ERR_print_errors_cb(logSslError, NULL); - throw ConfigFileException("Private key and certificate don't match."); - } -} - ConfigFileParser::ConfigFileParser(const std::string &path) : path(path) { validKeys.insert("auth_plugin"); validKeys.insert("log_file"); - validKeys.insert("listen_port"); - validKeys.insert("ssl_listen_port"); - validKeys.insert("fullchain"); - validKeys.insert("privkey"); validKeys.insert("allow_unsafe_clientid_chars"); validKeys.insert("client_initial_buffer_size"); validKeys.insert("max_packet_size"); + + validListenKeys.insert("port"); + validListenKeys.insert("protocol"); + validListenKeys.insert("fullchain"); + validListenKeys.insert("privkey"); + + settings.reset(new Settings()); } void ConfigFileParser::loadFile(bool test) @@ -121,12 +70,19 @@ void ConfigFileParser::loadFile(bool test) std::list lines; - const std::regex r("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.]+)$"); + const std::regex key_value_regex("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.]+)$"); + const std::regex block_regex_start("^([a-zA-Z0-9_\\-]+) *\\{$"); + const std::regex block_regex_end("^\\}$"); + + bool inBlock = false; + std::ostringstream oss; + int linenr = 0; // First parse the file and keep the valid lines. for(std::string line; getline(infile, line ); ) { trim(line); + linenr++; if (startsWith(line, "#")) continue; @@ -136,102 +92,133 @@ void ConfigFileParser::loadFile(bool test) std::smatch matches; - if (!std::regex_search(line, matches, r) || matches.size() != 3) + const bool blockStartMatch = std::regex_search(line, matches, block_regex_start); + const bool blockEndMatch = std::regex_search(line, matches, block_regex_end); + + if ((blockStartMatch && inBlock) || (blockEndMatch && !inBlock)) { - std::ostringstream oss; - oss << "Line '" << line << "' not in 'key value' format"; + oss << "Unexpected block start or end at line " << linenr << ": " << line; throw ConfigFileException(oss.str()); } + if (!std::regex_search(line, matches, key_value_regex) && !blockStartMatch && !blockEndMatch) + { + oss << "Line '" << line << "' invalid"; + throw ConfigFileException(oss.str()); + } + + if (blockStartMatch) + inBlock = true; + if (blockEndMatch) + inBlock = false; + lines.push_back(line); } - authOpts.clear(); - authOptCompatWrap.reset(); + if (inBlock) + { + throw ConfigFileException("Unclosed config block. Expecting }"); + } + + std::unordered_map authOpts; - std::string sslFullChainTmp; - std::string sslPrivkeyTmp; + ConfigParseLevel curParseLevel = ConfigParseLevel::Root; + std::shared_ptr curListener; + std::unique_ptr tmpSettings(new Settings); // Then once we know the config file is valid, process it. for (std::string &line : lines) { std::smatch matches; - if (!std::regex_search(line, matches, r) || matches.size() != 3) + if (std::regex_match(line, matches, block_regex_start)) + { + std::string key = matches[1].str(); + if (matches[1].str() == "listen") + { + curParseLevel = ConfigParseLevel::Listen; + curListener.reset(new Listener); + } + else + { + throw ConfigFileException(formatString("'%s' is not a valid block.", key.c_str())); + } + + continue; + } + else if (std::regex_match(line, matches, block_regex_end)) { - throw ConfigFileException("Config parse error at a point that should not be possible."); + if (curParseLevel == ConfigParseLevel::Listen) + { + curListener->isValid(); + tmpSettings->listeners.push_back(curListener); + curListener.reset(); + } + + curParseLevel = ConfigParseLevel::Root; + continue; } + std::regex_match(line, matches, key_value_regex); + std::string key = matches[1].str(); const std::string value = matches[2].str(); - const std::string auth_opt_ = "auth_opt_"; - if (startsWith(key, auth_opt_)) - { - key.replace(0, auth_opt_.length(), ""); - authOpts[key] = value; - } - else + try { - auto valid_key_it = validKeys.find(key); - if (valid_key_it == validKeys.end()) + if (curParseLevel == ConfigParseLevel::Listen) { - std::ostringstream oss; - oss << "Config key '" << key << "' is not valid. This error should have been cought before. Bug?"; - throw ConfigFileException(oss.str()); - } + testKeyValidity(key, validListenKeys); - if (key == "auth_plugin") - { - checkFileAccess(key, value); - if (!test) - this->authPluginPath = value; - } + if (key == "protocol") + { + if (value != "mqtt" && value != "websockets") + throw ConfigFileException(formatString("Protocol '%s' is not a valid listener protocol", value.c_str())); + curListener->websocket = value == "websockets"; + } + else if (key == "port") + { + curListener->port = std::stoi(value); + } + else if (key == "fullchain") + { + curListener->sslFullchain = value; + } + if (key == "privkey") + { + curListener->sslPrivkey = value; + } - if (key == "log_file") - { - checkFileAccess(key, value); - if (!test) - this->logPath = value; + continue; } - if (key == "allow_unsafe_clientid_chars") - { - bool tmp = stringTruthiness(value); - if (!test) - this->allowUnsafeClientidChars = tmp; - } - if (key == "fullchain") + const std::string auth_opt_ = "auth_opt_"; + if (startsWith(key, auth_opt_)) { - checkFileAccess(key, value); - sslFullChainTmp = value; + key.replace(0, auth_opt_.length(), ""); + authOpts[key] = value; } - - if (key == "privkey") + else { - checkFileAccess(key, value); - sslPrivkeyTmp = value; - } + testKeyValidity(key, validKeys); - try - { - // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners. - if (key == "listen_port") + if (key == "auth_plugin") { - uint listenportNew = std::stoi(value); - if (listenPort > 0 && listenPort != listenportNew) - throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time."); - listenPort = listenportNew; + checkFileAccess(key, value); + tmpSettings->authPluginPath = value; } - // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners. - if (key == "ssl_listen_port") + if (key == "log_file") { - uint sslListenPortNew = std::stoi(value); - if (sslListenPort > 0 && sslListenPort != sslListenPortNew) - throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time."); - sslListenPort = sslListenPortNew; + checkFileAccess(key, value); + tmpSettings->logPath = value; + } + + if (key == "allow_unsafe_clientid_chars") + { + bool tmp = stringTruthiness(value); + tmpSettings->allowUnsafeClientidChars = tmp; } if (key == "client_initial_buffer_size") @@ -239,8 +226,7 @@ void ConfigFileParser::loadFile(bool test) int newVal = std::stoi(value); if (!isPowerOfTwo(newVal)) throw ConfigFileException("client_initial_buffer_size value " + value + " is not a power of two."); - if (!test) - clientInitialBufferSize = newVal; + tmpSettings->clientInitialBufferSize = newVal; } if (key == "max_packet_size") @@ -252,29 +238,33 @@ void ConfigFileParser::loadFile(bool test) oss << "Value for max_packet_size " << newVal << "is higher than absolute maximum " << ABSOLUTE_MAX_PACKET_SIZE; throw ConfigFileException(oss.str()); } - if (!test) - maxPacketSize = newVal; + tmpSettings->maxPacketSize = newVal; } - - } - catch (std::invalid_argument &ex) // catch for the stoi() - { - throw ConfigFileException(ex.what()); } } + catch (std::invalid_argument &ex) // catch for the stoi() + { + throw ConfigFileException(ex.what()); + } + } + + if (tmpSettings->listeners.empty()) + { + std::shared_ptr defaultListener(new Listener()); + tmpSettings->listeners.push_back(defaultListener); } - testSsl(sslFullChainTmp, sslPrivkeyTmp, sslListenPort); - this->sslFullchain = sslFullChainTmp; - this->sslPrivkey = sslPrivkeyTmp; + tmpSettings->authOptCompatWrap = AuthOptCompatWrap(authOpts); - authOptCompatWrap.reset(new AuthOptCompatWrap(authOpts)); + if (!test) + { + this->settings = std::move(tmpSettings); + } } -AuthOptCompatWrap &ConfigFileParser::getAuthOptsCompat() +AuthOptCompatWrap &Settings::getAuthOptsCompat() { - return *authOptCompatWrap.get(); + return authOptCompatWrap; } - diff --git a/configfileparser.h b/configfileparser.h index e05ca44..d6794ac 100644 --- a/configfileparser.h +++ b/configfileparser.h @@ -6,59 +6,33 @@ #include #include #include +#include #include "sslctxmanager.h" +#include "listener.h" +#include "settings.h" #define ABSOLUTE_MAX_PACKET_SIZE 268435461 // 256 MB + 5 -struct mosquitto_auth_opt +enum class ConfigParseLevel { - char *key = nullptr; - char *value = nullptr; - - mosquitto_auth_opt(const std::string &key, const std::string &value); - mosquitto_auth_opt(mosquitto_auth_opt &&other); - mosquitto_auth_opt(const mosquitto_auth_opt &other) = delete; - ~mosquitto_auth_opt(); -}; - -struct AuthOptCompatWrap -{ - std::vector optArray; - - AuthOptCompatWrap(const std::unordered_map &authOpts); - AuthOptCompatWrap(const AuthOptCompatWrap &other) = delete; - AuthOptCompatWrap(AuthOptCompatWrap &&other) = delete; - - struct mosquitto_auth_opt *head() { return &optArray[0]; } - int size() { return optArray.size(); } + Root, + Listen }; class ConfigFileParser { const std::string path; std::set validKeys; - std::unordered_map authOpts; - std::unique_ptr authOptCompatWrap; - + std::set validListenKeys; + void testKeyValidity(const std::string &key, const std::set &validKeys) const; void checkFileAccess(const std::string &key, const std::string &pathToCheck) const; - void testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const; public: ConfigFileParser(const std::string &path); void loadFile(bool test); - AuthOptCompatWrap &getAuthOptsCompat(); - // Actual config options with their defaults. Just making them public, I can retrain myself misuing them. - std::string authPluginPath; - std::string logPath; - std::string sslFullchain; - std::string sslPrivkey; - uint listenPort = 1883; - uint sslListenPort = 0; - bool allowUnsafeClientidChars = false; - int clientInitialBufferSize = 1024; // Must be power of 2 - int maxPacketSize = 268435461; // 256 MB + 5 + std::unique_ptr settings; }; #endif // CONFIGFILEPARSER_H diff --git a/exceptions.cpp b/exceptions.cpp index 3a4a7aa..270cc29 100644 --- a/exceptions.cpp +++ b/exceptions.cpp @@ -1,2 +1,4 @@ #include "exceptions.h" + + diff --git a/exceptions.h b/exceptions.h index a27595d..2910c20 100644 --- a/exceptions.h +++ b/exceptions.h @@ -3,6 +3,7 @@ #include #include +#include class ProtocolError : public std::runtime_error { @@ -26,6 +27,7 @@ class ConfigFileException : public std::runtime_error { public: ConfigFileException(const std::string &msg) : std::runtime_error(msg) {} + ConfigFileException(std::ostringstream oss) : std::runtime_error(oss.str()) {} }; class AuthPluginException : public std::runtime_error diff --git a/forward_declarations.h b/forward_declarations.h index d4f6902..1c31c49 100644 --- a/forward_declarations.h +++ b/forward_declarations.h @@ -10,6 +10,7 @@ typedef std::shared_ptr ThreadData_p; class MqttPacket; class SubscriptionStore; class Session; +class Settings; #endif // FORWARD_DECLARATIONS_H diff --git a/globalsettings.cpp b/globalsettings.cpp deleted file mode 100644 index 5e10a0f..0000000 --- a/globalsettings.cpp +++ /dev/null @@ -1,3 +0,0 @@ -#include "globalsettings.h" - - diff --git a/globalsettings.h b/globalsettings.h deleted file mode 100644 index bf9dbbb..0000000 --- a/globalsettings.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef GLOBALSETTINGS_H -#define GLOBALSETTINGS_H - -// Defaults are defined in ConfigFileParser -struct GlobalSettings -{ - bool allow_unsafe_clientid_chars = false; - int clientInitialBufferSize = 0; - int maxPacketSize = 0; -}; -#endif // GLOBALSETTINGS_H diff --git a/listener.cpp b/listener.cpp new file mode 100644 index 0000000..33873bf --- /dev/null +++ b/listener.cpp @@ -0,0 +1,78 @@ +#include "listener.h" + +#include "utils.h" +#include "exceptions.h" + +void Listener::isValid() +{ + if (isSsl()) + { + if (port == 0) + { + if (websocket) + port = 4443; + else + port = 8883; + } + + testSsl(sslFullchain, sslPrivkey); + } + else + { + if (port == 0) + { + if (websocket) + port = 8080; + else + port = 1883; + } + } + + if (port <= 0 || port > 65534) + { + throw ConfigFileException(formatString("Port nr %d is not valid", port)); + } +} + +bool Listener::isSsl() const +{ + return (!sslFullchain.empty() || !sslPrivkey.empty()); +} + +std::string Listener::getProtocolName() const +{ + if (isSsl()) + { + if (websocket) + return "SSL websocket"; + else + return "SSL TCP"; + } + else + { + if (websocket) + return "non-SSL websocket"; + else + return "non-SSL TCP"; + } + + return "whoops"; +} + +void Listener::loadCertAndKeyFromConfig() +{ + if (!isSsl()) + return; + + if (!sslctx) + { + sslctx.reset(new SslCtxManager()); + SSL_CTX_set_options(sslctx->get(), SSL_OP_NO_SSLv3); // TODO: config option + SSL_CTX_set_options(sslctx->get(), SSL_OP_NO_TLSv1); // TODO: config option + } + + if (SSL_CTX_use_certificate_file(sslctx->get(), sslFullchain.c_str(), SSL_FILETYPE_PEM) != 1) + throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected."); + if (SSL_CTX_use_PrivateKey_file(sslctx->get(), sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1) + throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected."); +} diff --git a/listener.h b/listener.h new file mode 100644 index 0000000..7d507bb --- /dev/null +++ b/listener.h @@ -0,0 +1,22 @@ +#ifndef LISTENER_H +#define LISTENER_H + +#include +#include + +#include "sslctxmanager.h" + +struct Listener +{ + int port = 0; + bool websocket = false; + std::string sslFullchain; + std::string sslPrivkey; + std::unique_ptr sslctx; + + void isValid(); + bool isSsl() const; + std::string getProtocolName() const; + void loadCertAndKeyFromConfig(); +}; +#endif // LISTENER_H diff --git a/mainapp.cpp b/mainapp.cpp index d46d11b..3bf3bfa 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -172,9 +172,6 @@ MainApp::MainApp(const std::string &configFilePath) : MainApp::~MainApp() { - if (sslctx) - SSL_CTX_free(sslctx); - if (epollFdAccept > 0) close(epollFdAccept); } @@ -202,54 +199,47 @@ void MainApp::showLicense() puts("Author: Wiebe Cazemier "); } -void MainApp::setCertAndKeyFromConfig() -{ - if (sslctx == nullptr) - return; - - if (SSL_CTX_use_certificate_file(sslctx, confFileParser->sslFullchain.c_str(), SSL_FILETYPE_PEM) != 1) - throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected."); - if (SSL_CTX_use_PrivateKey_file(sslctx, confFileParser->sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1) - throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected."); -} - -int MainApp::createListenSocket(int portNr, bool ssl) +int MainApp::createListenSocket(const std::shared_ptr &listener) { - if (portNr <= 0) + if (listener->port <= 0) return -2; - int listen_fd = check(socket(AF_INET, SOCK_STREAM, 0)); - - // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. - int optval = 1; - check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); - - int flags = fcntl(listen_fd, F_GETFL); - check(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); + logger->logf(LOG_NOTICE, "Creating %s listener on port %d", listener->getProtocolName().c_str(), listener->port); - struct sockaddr_in in_addr_plain; - in_addr_plain.sin_family = AF_INET; - in_addr_plain.sin_addr.s_addr = INADDR_ANY; - in_addr_plain.sin_port = htons(portNr); + try + { + int listen_fd = check(socket(AF_INET, SOCK_STREAM, 0)); - check(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in))); - check(listen(listen_fd, 1024)); + // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. + int optval = 1; + check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); - struct epoll_event ev; - memset(&ev, 0, sizeof (struct epoll_event)); + int flags = fcntl(listen_fd, F_GETFL); + check(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); - ev.data.fd = listen_fd; - ev.events = EPOLLIN; - check(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev)); + struct sockaddr_in in_addr_plain; + in_addr_plain.sin_family = AF_INET; + in_addr_plain.sin_addr.s_addr = INADDR_ANY; + in_addr_plain.sin_port = htons(listener->port); - std::string socketType = "plain"; + check(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in))); + check(listen(listen_fd, 1024)); - if (ssl) - socketType = "SSL"; + struct epoll_event ev; + memset(&ev, 0, sizeof (struct epoll_event)); - logger->logf(LOG_NOTICE, "Listening on %s port %d", socketType.c_str(), portNr); + ev.data.fd = listen_fd; + ev.events = EPOLLIN; + check(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev)); - return listen_fd; + return listen_fd; + } + catch (std::exception &ex) + { + logger->logf(LOG_NOTICE, "Creating %s listener on port %d failed: %s", listener->getProtocolName().c_str(), listener->port, ex.what()); + return -1; + } + return -1; } void MainApp::wakeUpThread() @@ -349,9 +339,14 @@ void MainApp::start() { timer.start(); - int listen_fd_plain = createListenSocket(this->listenPort, false); - int listen_fd_ssl = createListenSocket(this->sslListenPort, true); - int listen_fd_websocket_plain = createListenSocket(1443, true); + std::map> listenerMap; + + for(std::shared_ptr &listener : this->listeners) + { + int fd = createListenSocket(listener); + if (fd > 0) + listenerMap[fd] = listener; + } #ifdef NDEBUG logger->noLongerLogToStd(); @@ -365,7 +360,7 @@ void MainApp::start() for (int i = 0; i < num_threads; i++) { - std::shared_ptr t(new ThreadData(i, subscriptionStore, *confFileParser.get(), settings)); + std::shared_ptr t(new ThreadData(i, subscriptionStore, settings)); t->start(&do_thread_work); threads.push_back(t); } @@ -392,22 +387,22 @@ void MainApp::start() int cur_fd = events[i].data.fd; try { - if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl || listen_fd_websocket_plain) + if (cur_fd != taskEventFd) { + std::shared_ptr listener = listenerMap[cur_fd]; std::shared_ptr thread_data = threads[next_thread_index++ % num_threads]; - logger->logf(LOG_INFO, "Accepting connection on thread %d", thread_data->threadnr); + logger->logf(LOG_INFO, "Accepting connection on thread %d on %s", thread_data->threadnr, listener->getProtocolName().c_str()); struct sockaddr addr; memset(&addr, 0, sizeof(struct sockaddr)); socklen_t len = sizeof(struct sockaddr); int fd = check(accept(cur_fd, &addr, &len)); - bool websocket = cur_fd == listen_fd_websocket_plain; SSL *clientSSL = nullptr; - if (cur_fd == listen_fd_ssl) + if (listener->isSsl()) { - clientSSL = SSL_new(sslctx); + clientSSL = SSL_new(listener->sslctx->get()); if (clientSSL == NULL) { @@ -419,10 +414,10 @@ void MainApp::start() SSL_set_fd(clientSSL, fd); } - Client_p client(new Client(fd, thread_data, clientSSL, websocket, settings)); + Client_p client(new Client(fd, thread_data, clientSSL, listener->websocket, settings)); thread_data->giveClient(client); } - else if (cur_fd == taskEventFd) + else { uint64_t eventfd_value = 0; check(read(cur_fd, &eventfd_value, sizeof(uint64_t))); @@ -434,10 +429,6 @@ void MainApp::start() } taskQueue.clear(); } - else - { - throw std::runtime_error("Bug: the main thread had activity on an fd it's not supposed to monitor."); - } } catch (std::exception &ex) { @@ -452,12 +443,18 @@ void MainApp::start() thread->quit(); } - close(listen_fd_plain); - close(listen_fd_ssl); + for(auto pair : listenerMap) + { + close(pair.first); + } } void MainApp::quit() { + std::lock_guard guard(quitMutex); + if (!running) + return; + Logger *logger = Logger::getInstance(); logger->logf(LOG_NOTICE, "Quitting FlashMQ"); timer.stop(); @@ -472,26 +469,21 @@ void MainApp::loadConfig() // Atomic loading, first test. confFileParser->loadFile(true); confFileParser->loadFile(false); + settings = std::move(confFileParser->settings); - logger->setLogPath(confFileParser->logPath); - logger->reOpen(); + // For now, it's too much work to be able to reload new listeners, with all the shared resource stuff going on. So, I'm + // loading them to a local var which is never updated. + if (listeners.empty()) + listeners = settings->listeners; - listenPort = confFileParser->listenPort; - sslListenPort = confFileParser->sslListenPort; + logger->setLogPath(settings->logPath); + logger->reOpen(); - if (sslctx == nullptr && sslListenPort > 0) + for (std::shared_ptr &l : this->listeners) { - sslctx = SSL_CTX_new(TLS_server_method()); - SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv3); // TODO: config option - SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option + l->loadCertAndKeyFromConfig(); } - settings.allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars; - settings.clientInitialBufferSize = confFileParser->clientInitialBufferSize; - settings.maxPacketSize = confFileParser->maxPacketSize; - - setCertAndKeyFromConfig(); - for (std::shared_ptr &thread : threads) { thread->queueReload(settings); diff --git a/mainapp.h b/mainapp.h index c2f1263..d119b3f 100644 --- a/mainapp.h +++ b/mainapp.h @@ -20,7 +20,6 @@ #include "subscriptionstore.h" #include "configfileparser.h" #include "timer.h" -#include "globalsettings.h" class MainApp { @@ -37,11 +36,9 @@ class MainApp int taskEventFd = -1; std::mutex eventMutex; Timer timer; - GlobalSettings settings; - - uint listenPort = 0; - uint sslListenPort = 0; - SSL_CTX *sslctx = nullptr; + std::shared_ptr settings; + std::list> listeners; + std::mutex quitMutex; Logger *logger = Logger::getInstance(); @@ -49,8 +46,7 @@ class MainApp void reloadConfig(); static void doHelp(const char *arg); static void showLicense(); - void setCertAndKeyFromConfig(); - int createListenSocket(int portNr, bool ssl); + int createListenSocket(const std::shared_ptr &listener); void wakeUpThread(); void queueKeepAliveCheckAtAllThreads(); @@ -66,7 +62,6 @@ public: bool getStarted() const {return started;} static void testConfig(); - GlobalSettings &getGlobalSettings(); void queueConfigReload(); void queueCleanup(); }; diff --git a/mosquittoauthoptcompatwrap.cpp b/mosquittoauthoptcompatwrap.cpp new file mode 100644 index 0000000..b5f22b8 --- /dev/null +++ b/mosquittoauthoptcompatwrap.cpp @@ -0,0 +1,52 @@ +#include "mosquittoauthoptcompatwrap.h" + +mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value) +{ + this->key = strdup(key.c_str()); + this->value = strdup(value.c_str()); +} + +mosquitto_auth_opt::mosquitto_auth_opt(mosquitto_auth_opt &&other) +{ + this->key = other.key; + this->value = other.value; + other.key = nullptr; + other.value = nullptr; +} + +mosquitto_auth_opt::mosquitto_auth_opt(const mosquitto_auth_opt &other) +{ + this->key = strdup(other.key); + this->value = strdup(other.value); +} + +mosquitto_auth_opt::~mosquitto_auth_opt() +{ + if (key) + delete key; + if (value) + delete value; +} + +mosquitto_auth_opt &mosquitto_auth_opt::operator=(const mosquitto_auth_opt &other) +{ + if (key) + delete key; + if (value) + delete value; + + this->key = strdup(other.key); + this->value = strdup(other.value); + + return *this; +} + +AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map &authOpts) +{ + for(auto &pair : authOpts) + { + mosquitto_auth_opt opt(pair.first, pair.second); + optArray.push_back(std::move(opt)); + } +} + diff --git a/mosquittoauthoptcompatwrap.h b/mosquittoauthoptcompatwrap.h new file mode 100644 index 0000000..bcc4fd7 --- /dev/null +++ b/mosquittoauthoptcompatwrap.h @@ -0,0 +1,47 @@ +#ifndef MOSQUITTOAUTHOPTCOMPATWRAP_H +#define MOSQUITTOAUTHOPTCOMPATWRAP_H + +#include +#include +#include + +/** + * @brief The mosquitto_auth_opt struct is a resource managed class of auth options, compatible with passing as arguments to Mosquitto + * auth plugins. + * + * It's fully assignable and copyable. + */ +struct mosquitto_auth_opt +{ + char *key = nullptr; + char *value = nullptr; + + mosquitto_auth_opt(const std::string &key, const std::string &value); + mosquitto_auth_opt(mosquitto_auth_opt &&other); + mosquitto_auth_opt(const mosquitto_auth_opt &other); + ~mosquitto_auth_opt(); + + mosquitto_auth_opt& operator=(const mosquitto_auth_opt &other); +}; + +/** + * @brief The AuthOptCompatWrap struct contains a vector of mosquitto auth options, with a head pointer and count which can be passed to + * Mosquitto auth plugins. + */ +struct AuthOptCompatWrap +{ + std::vector optArray; + + AuthOptCompatWrap(const std::unordered_map &authOpts); + AuthOptCompatWrap(const AuthOptCompatWrap &other) = default; + AuthOptCompatWrap(AuthOptCompatWrap &&other) = delete; + AuthOptCompatWrap() = default; + + struct mosquitto_auth_opt *head() { return &optArray[0]; } + int size() { return optArray.size(); } + + AuthOptCompatWrap &operator=(const AuthOptCompatWrap &other) = default; +}; + + +#endif // MOSQUITTOAUTHOPTCOMPATWRAP_H diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 2b940ee..8a14307 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -145,7 +145,7 @@ void MqttPacket::handleConnect() uint16_t variable_header_length = readTwoBytesToUInt16(); - const GlobalSettings &settings = sender->getThreadData()->settingsLocalCopy; + const Settings &settings = sender->getThreadData()->settingsLocalCopy; if (variable_header_length == 4 || variable_header_length == 6) { @@ -232,7 +232,7 @@ void MqttPacket::handleConnect() bool validClientId = true; // Check for wildcard chars in case the client_id ever appears in topics. - if (!settings.allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#"))) + if (!settings.allowUnsafeClientidChars && (strContains(client_id, "+") || strContains(client_id, "#"))) { logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str()); validClientId = false; diff --git a/mqttpacket.h b/mqttpacket.h index 2f4c86a..31cfe0d 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -15,7 +15,6 @@ #include "cirbuf.h" #include "logger.h" #include "mainapp.h" -#include "globalsettings.h" struct RemainingLength { diff --git a/settings.cpp b/settings.cpp new file mode 100644 index 0000000..bca168d --- /dev/null +++ b/settings.cpp @@ -0,0 +1,3 @@ +#include "settings.h" + + diff --git a/settings.h b/settings.h new file mode 100644 index 0000000..d440612 --- /dev/null +++ b/settings.h @@ -0,0 +1,28 @@ +#ifndef SETTINGS_H +#define SETTINGS_H + +#include +#include + +#include "mosquittoauthoptcompatwrap.h" +#include "listener.h" + +class Settings +{ + friend class ConfigFileParser; + + AuthOptCompatWrap authOptCompatWrap; + +public: + // Actual config options with their defaults. + std::string authPluginPath; + std::string logPath; + bool allowUnsafeClientidChars = false; + int clientInitialBufferSize = 1024; // Must be power of 2 + int maxPacketSize = 268435461; // 256 MB + 5 + std::list> listeners; // Default one is created later, when none are defined. + + AuthOptCompatWrap &getAuthOptsCompat(); +}; + +#endif // SETTINGS_H diff --git a/sslctxmanager.cpp b/sslctxmanager.cpp index 7519f44..9ea326e 100644 --- a/sslctxmanager.cpp +++ b/sslctxmanager.cpp @@ -16,3 +16,8 @@ SSL_CTX *SslCtxManager::get() const { return ssl_ctx; } + +SslCtxManager::operator bool() const +{ + return ssl_ctx == nullptr; +} diff --git a/sslctxmanager.h b/sslctxmanager.h index bd5c227..0f0510c 100644 --- a/sslctxmanager.h +++ b/sslctxmanager.h @@ -11,6 +11,7 @@ public: ~SslCtxManager(); SSL_CTX *get() const; + operator bool() const; }; #endif // SSLCTXMANAGER_H diff --git a/threaddata.cpp b/threaddata.cpp index 1c5b2dd..bcfbcd0 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -2,12 +2,11 @@ #include #include -ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings) : +ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, std::shared_ptr settings) : subscriptionStore(subscriptionStore), - confFileParser(confFileParser), - authPlugin(confFileParser), - threadnr(threadnr), - settingsLocalCopy(settings) + settingsLocalCopy(*settings.get()), + authPlugin(settingsLocalCopy), + threadnr(threadnr) { logger = Logger::getInstance(); @@ -146,18 +145,19 @@ void ThreadData::doKeepAliveCheck() void ThreadData::initAuthPlugin() { - authPlugin.loadPlugin(confFileParser.authPluginPath); + authPlugin.loadPlugin(settingsLocalCopy.authPluginPath); authPlugin.init(); authPlugin.securityInit(false); } -void ThreadData::reload(GlobalSettings settings) +void ThreadData::reload(std::shared_ptr settings) { logger->logf(LOG_DEBUG, "Doing reload in thread %d", threadnr); try { - settingsLocalCopy = settings; + // Because the auth plugin has a reference to it, it will also be updated. + settingsLocalCopy = *settings.get(); authPlugin.securityCleanup(true); authPlugin.securityInit(true); @@ -172,7 +172,7 @@ void ThreadData::reload(GlobalSettings settings) } } -void ThreadData::queueReload(GlobalSettings settings) +void ThreadData::queueReload(std::shared_ptr settings) { std::lock_guard locker(taskQueueMutex); diff --git a/threaddata.h b/threaddata.h index 1cd2664..2dc3321 100644 --- a/threaddata.h +++ b/threaddata.h @@ -20,7 +20,6 @@ #include "configfileparser.h" #include "authplugin.h" #include "logger.h" -#include "globalsettings.h" typedef void (*thread_f)(ThreadData *); @@ -29,14 +28,14 @@ class ThreadData std::unordered_map clients_by_fd; std::mutex clients_by_fd_mutex; std::shared_ptr subscriptionStore; - ConfigFileParser &confFileParser; Logger *logger; - void reload(GlobalSettings settings); + void reload(std::shared_ptr settings); void wakeUpThread(); void doKeepAliveCheck(); public: + Settings settingsLocalCopy; // Is updated on reload, within the thread loop. AuthPlugin authPlugin; bool running = true; std::thread thread; @@ -45,9 +44,8 @@ public: int taskEventFd = 0; std::mutex taskQueueMutex; std::forward_list> taskQueue; - GlobalSettings settingsLocalCopy; // Is updated on reload, within the thread loop. - ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings); + ThreadData(int threadnr, std::shared_ptr &subscriptionStore, std::shared_ptr settings); ThreadData(const ThreadData &other) = delete; ThreadData(ThreadData &&other) = delete; @@ -60,7 +58,7 @@ public: std::shared_ptr &getSubscriptionStore(); void initAuthPlugin(); - void queueReload(GlobalSettings settings); + void queueReload(std::shared_ptr settings); void queueDoKeepAliveCheck(); }; diff --git a/utils.cpp b/utils.cpp index 23b3560..9f4731a 100644 --- a/utils.cpp +++ b/utils.cpp @@ -3,10 +3,15 @@ #include "sys/time.h" #include "sys/random.h" #include -#include +#include + +#include "openssl/ssl.h" +#include "openssl/err.h" #include "exceptions.h" #include "cirbuf.h" +#include "sslctxmanager.h" +#include "logger.h" std::list split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) { @@ -346,3 +351,47 @@ std::string generateWebsocketAnswer(const std::string &acceptString) oss.flush(); return oss.str(); } + +// Using a separate ssl context to test, because it's the easiest way to load certs and key atomitcally. +void testSsl(const std::string &fullchain, const std::string &privkey) +{ + if (fullchain.empty() && privkey.empty()) + throw ConfigFileException("No privkey and fullchain specified."); + + if (fullchain.empty()) + throw ConfigFileException("No private key specified for fullchain"); + + if (privkey.empty()) + throw ConfigFileException("No fullchain specified for private key"); + + SslCtxManager sslCtx; + if (SSL_CTX_use_certificate_file(sslCtx.get(), fullchain.c_str(), SSL_FILETYPE_PEM) != 1) + { + ERR_print_errors_cb(logSslError, NULL); + throw ConfigFileException("Error loading full chain " + fullchain); + } + if (SSL_CTX_use_PrivateKey_file(sslCtx.get(), privkey.c_str(), SSL_FILETYPE_PEM) != 1) + { + ERR_print_errors_cb(logSslError, NULL); + throw ConfigFileException("Error loading private key " + privkey); + } + if (SSL_CTX_check_private_key(sslCtx.get()) != 1) + { + ERR_print_errors_cb(logSslError, NULL); + throw ConfigFileException("Private key and certificate don't match."); + } +} + +std::string formatString(const std::string str, ...) +{ + char buf[512]; + + va_list valist; + va_start(valist, str); + vsnprintf(buf, 512, str.c_str(), valist); + va_end(valist); + + std::string result(buf, 512); + + return result; +} diff --git a/utils.h b/utils.h index 666f517..d7776dc 100644 --- a/utils.h +++ b/utils.h @@ -55,4 +55,8 @@ std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion); std::string generateBadHttpRequestReponse(const std::string &msg); std::string generateWebsocketAnswer(const std::string &acceptString); +void testSsl(const std::string &fullchain, const std::string &privkey); + +std::string formatString(const std::string str, ...); + #endif // UTILS_H -- libgit2 0.21.4