diff --git a/CMakeLists.txt b/CMakeLists.txt index fb94126..758c80b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,8 +28,10 @@ add_executable(FlashMQ configfileparser.cpp sslctxmanager.cpp timer.cpp - globalsettings.cpp iowrapper.cpp + mosquittoauthoptcompatwrap.cpp + settings.cpp + listener.cpp ) target_link_libraries(FlashMQ pthread dl ssl crypto) diff --git a/authplugin.cpp b/authplugin.cpp index 86ee416..4906824 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -17,8 +17,8 @@ void mosquitto_log_printf(int level, const char *fmt, ...) } -AuthPlugin::AuthPlugin(ConfigFileParser &confFileParser) : - confFileParser(confFileParser) +AuthPlugin::AuthPlugin(Settings &settings) : + settings(settings) { logger = Logger::getInstance(); } @@ -89,7 +89,7 @@ void AuthPlugin::init() if (!wanted) return; - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); if (result != 0) throw FatalError("Error initialising auth plugin."); @@ -102,7 +102,7 @@ void AuthPlugin::cleanup() securityCleanup(false); - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size()); if (result != 0) logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway. @@ -113,7 +113,7 @@ void AuthPlugin::securityInit(bool reloading) if (!wanted) return; - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); if (result != 0) { @@ -128,7 +128,7 @@ void AuthPlugin::securityCleanup(bool reloading) return; initialized = false; - AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading); if (result != 0) diff --git a/authplugin.h b/authplugin.h index 6c768ec..2c0ff40 100644 --- a/authplugin.h +++ b/authplugin.h @@ -51,7 +51,7 @@ class AuthPlugin F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; - ConfigFileParser &confFileParser; + Settings &settings; // A ref because I want it to always be the same as the thread's settings void *pluginData = nullptr; Logger *logger = nullptr; @@ -60,7 +60,7 @@ class AuthPlugin void *loadSymbol(void *handle, const char *symbol) const; public: - AuthPlugin(ConfigFileParser &confFileParser); + AuthPlugin(Settings &settings); AuthPlugin(const AuthPlugin &other) = delete; AuthPlugin(AuthPlugin &&other) = delete; ~AuthPlugin(); diff --git a/client.cpp b/client.cpp index 49718a2..c82a0b7 100644 --- a/client.cpp +++ b/client.cpp @@ -7,10 +7,10 @@ #include "logger.h" -Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings) : +Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, std::shared_ptr settings) : fd(fd), - initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy - maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment. + initialBufferSize(settings->clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy + maxPacketSize(settings->maxPacketSize), // Same as initialBufferSize comment. ioWrapper(ssl, websocket, initialBufferSize, this), readbuf(initialBufferSize), writebuf(initialBufferSize), diff --git a/client.h b/client.h index 4e99b52..49f2068 100644 --- a/client.h +++ b/client.h @@ -70,7 +70,7 @@ class Client void setReadyForReading(bool val); public: - Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings); + Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, std::shared_ptr settings); Client(const Client &other) = delete; Client(Client &&other) = delete; ~Client(); diff --git a/configfileparser.cpp b/configfileparser.cpp index 168bc80..75e7825 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -4,45 +4,24 @@ #include #include #include "fstream" +#include #include "openssl/ssl.h" #include "openssl/err.h" #include "exceptions.h" #include "utils.h" -#include - #include "logger.h" -mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value) -{ - this->key = strdup(key.c_str()); - this->value = strdup(value.c_str()); -} - -mosquitto_auth_opt::mosquitto_auth_opt(mosquitto_auth_opt &&other) -{ - this->key = other.key; - this->value = other.value; - other.key = nullptr; - other.value = nullptr; -} - -mosquitto_auth_opt::~mosquitto_auth_opt() -{ - if (key) - delete key; - if (value) - delete value; -} - -AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map &authOpts) +void ConfigFileParser::testKeyValidity(const std::string &key, const std::set &validKeys) const { - for(auto &pair : authOpts) + auto valid_key_it = validKeys.find(key); + if (valid_key_it == validKeys.end()) { - mosquitto_auth_opt opt(pair.first, pair.second); - optArray.push_back(std::move(opt)); + std::ostringstream oss; + oss << "Config key '" << key << "' is not valid here."; + throw ConfigFileException(oss.str()); } } @@ -56,51 +35,21 @@ void ConfigFileParser::checkFileAccess(const std::string &key, const std::string } } -// Using a separate ssl context to test, because it's the easiest way to load certs and key atomitcally. -void ConfigFileParser::testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const -{ - if (portNr == 0) - return; - - if (fullchain.empty() && privkey.empty()) - throw ConfigFileException("No privkey and fullchain specified."); - - if (fullchain.empty()) - throw ConfigFileException("No private key specified for fullchain"); - - if (privkey.empty()) - throw ConfigFileException("No fullchain specified for private key"); - - SslCtxManager sslCtx; - if (SSL_CTX_use_certificate_file(sslCtx.get(), fullchain.c_str(), SSL_FILETYPE_PEM) != 1) - { - ERR_print_errors_cb(logSslError, NULL); - throw ConfigFileException("Error loading full chain " + fullchain); - } - if (SSL_CTX_use_PrivateKey_file(sslCtx.get(), privkey.c_str(), SSL_FILETYPE_PEM) != 1) - { - ERR_print_errors_cb(logSslError, NULL); - throw ConfigFileException("Error loading private key " + privkey); - } - if (SSL_CTX_check_private_key(sslCtx.get()) != 1) - { - ERR_print_errors_cb(logSslError, NULL); - throw ConfigFileException("Private key and certificate don't match."); - } -} - ConfigFileParser::ConfigFileParser(const std::string &path) : path(path) { validKeys.insert("auth_plugin"); validKeys.insert("log_file"); - validKeys.insert("listen_port"); - validKeys.insert("ssl_listen_port"); - validKeys.insert("fullchain"); - validKeys.insert("privkey"); validKeys.insert("allow_unsafe_clientid_chars"); validKeys.insert("client_initial_buffer_size"); validKeys.insert("max_packet_size"); + + validListenKeys.insert("port"); + validListenKeys.insert("protocol"); + validListenKeys.insert("fullchain"); + validListenKeys.insert("privkey"); + + settings.reset(new Settings()); } void ConfigFileParser::loadFile(bool test) @@ -121,12 +70,19 @@ void ConfigFileParser::loadFile(bool test) std::list lines; - const std::regex r("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.]+)$"); + const std::regex key_value_regex("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.]+)$"); + const std::regex block_regex_start("^([a-zA-Z0-9_\\-]+) *\\{$"); + const std::regex block_regex_end("^\\}$"); + + bool inBlock = false; + std::ostringstream oss; + int linenr = 0; // First parse the file and keep the valid lines. for(std::string line; getline(infile, line ); ) { trim(line); + linenr++; if (startsWith(line, "#")) continue; @@ -136,102 +92,133 @@ void ConfigFileParser::loadFile(bool test) std::smatch matches; - if (!std::regex_search(line, matches, r) || matches.size() != 3) + const bool blockStartMatch = std::regex_search(line, matches, block_regex_start); + const bool blockEndMatch = std::regex_search(line, matches, block_regex_end); + + if ((blockStartMatch && inBlock) || (blockEndMatch && !inBlock)) { - std::ostringstream oss; - oss << "Line '" << line << "' not in 'key value' format"; + oss << "Unexpected block start or end at line " << linenr << ": " << line; throw ConfigFileException(oss.str()); } + if (!std::regex_search(line, matches, key_value_regex) && !blockStartMatch && !blockEndMatch) + { + oss << "Line '" << line << "' invalid"; + throw ConfigFileException(oss.str()); + } + + if (blockStartMatch) + inBlock = true; + if (blockEndMatch) + inBlock = false; + lines.push_back(line); } - authOpts.clear(); - authOptCompatWrap.reset(); + if (inBlock) + { + throw ConfigFileException("Unclosed config block. Expecting }"); + } + + std::unordered_map authOpts; - std::string sslFullChainTmp; - std::string sslPrivkeyTmp; + ConfigParseLevel curParseLevel = ConfigParseLevel::Root; + std::shared_ptr curListener; + std::unique_ptr tmpSettings(new Settings); // Then once we know the config file is valid, process it. for (std::string &line : lines) { std::smatch matches; - if (!std::regex_search(line, matches, r) || matches.size() != 3) + if (std::regex_match(line, matches, block_regex_start)) + { + std::string key = matches[1].str(); + if (matches[1].str() == "listen") + { + curParseLevel = ConfigParseLevel::Listen; + curListener.reset(new Listener); + } + else + { + throw ConfigFileException(formatString("'%s' is not a valid block.", key.c_str())); + } + + continue; + } + else if (std::regex_match(line, matches, block_regex_end)) { - throw ConfigFileException("Config parse error at a point that should not be possible."); + if (curParseLevel == ConfigParseLevel::Listen) + { + curListener->isValid(); + tmpSettings->listeners.push_back(curListener); + curListener.reset(); + } + + curParseLevel = ConfigParseLevel::Root; + continue; } + std::regex_match(line, matches, key_value_regex); + std::string key = matches[1].str(); const std::string value = matches[2].str(); - const std::string auth_opt_ = "auth_opt_"; - if (startsWith(key, auth_opt_)) - { - key.replace(0, auth_opt_.length(), ""); - authOpts[key] = value; - } - else + try { - auto valid_key_it = validKeys.find(key); - if (valid_key_it == validKeys.end()) + if (curParseLevel == ConfigParseLevel::Listen) { - std::ostringstream oss; - oss << "Config key '" << key << "' is not valid. This error should have been cought before. Bug?"; - throw ConfigFileException(oss.str()); - } + testKeyValidity(key, validListenKeys); - if (key == "auth_plugin") - { - checkFileAccess(key, value); - if (!test) - this->authPluginPath = value; - } + if (key == "protocol") + { + if (value != "mqtt" && value != "websockets") + throw ConfigFileException(formatString("Protocol '%s' is not a valid listener protocol", value.c_str())); + curListener->websocket = value == "websockets"; + } + else if (key == "port") + { + curListener->port = std::stoi(value); + } + else if (key == "fullchain") + { + curListener->sslFullchain = value; + } + if (key == "privkey") + { + curListener->sslPrivkey = value; + } - if (key == "log_file") - { - checkFileAccess(key, value); - if (!test) - this->logPath = value; + continue; } - if (key == "allow_unsafe_clientid_chars") - { - bool tmp = stringTruthiness(value); - if (!test) - this->allowUnsafeClientidChars = tmp; - } - if (key == "fullchain") + const std::string auth_opt_ = "auth_opt_"; + if (startsWith(key, auth_opt_)) { - checkFileAccess(key, value); - sslFullChainTmp = value; + key.replace(0, auth_opt_.length(), ""); + authOpts[key] = value; } - - if (key == "privkey") + else { - checkFileAccess(key, value); - sslPrivkeyTmp = value; - } + testKeyValidity(key, validKeys); - try - { - // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners. - if (key == "listen_port") + if (key == "auth_plugin") { - uint listenportNew = std::stoi(value); - if (listenPort > 0 && listenPort != listenportNew) - throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time."); - listenPort = listenportNew; + checkFileAccess(key, value); + tmpSettings->authPluginPath = value; } - // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners. - if (key == "ssl_listen_port") + if (key == "log_file") { - uint sslListenPortNew = std::stoi(value); - if (sslListenPort > 0 && sslListenPort != sslListenPortNew) - throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time."); - sslListenPort = sslListenPortNew; + checkFileAccess(key, value); + tmpSettings->logPath = value; + } + + if (key == "allow_unsafe_clientid_chars") + { + bool tmp = stringTruthiness(value); + tmpSettings->allowUnsafeClientidChars = tmp; } if (key == "client_initial_buffer_size") @@ -239,8 +226,7 @@ void ConfigFileParser::loadFile(bool test) int newVal = std::stoi(value); if (!isPowerOfTwo(newVal)) throw ConfigFileException("client_initial_buffer_size value " + value + " is not a power of two."); - if (!test) - clientInitialBufferSize = newVal; + tmpSettings->clientInitialBufferSize = newVal; } if (key == "max_packet_size") @@ -252,29 +238,33 @@ void ConfigFileParser::loadFile(bool test) oss << "Value for max_packet_size " << newVal << "is higher than absolute maximum " << ABSOLUTE_MAX_PACKET_SIZE; throw ConfigFileException(oss.str()); } - if (!test) - maxPacketSize = newVal; + tmpSettings->maxPacketSize = newVal; } - - } - catch (std::invalid_argument &ex) // catch for the stoi() - { - throw ConfigFileException(ex.what()); } } + catch (std::invalid_argument &ex) // catch for the stoi() + { + throw ConfigFileException(ex.what()); + } + } + + if (tmpSettings->listeners.empty()) + { + std::shared_ptr defaultListener(new Listener()); + tmpSettings->listeners.push_back(defaultListener); } - testSsl(sslFullChainTmp, sslPrivkeyTmp, sslListenPort); - this->sslFullchain = sslFullChainTmp; - this->sslPrivkey = sslPrivkeyTmp; + tmpSettings->authOptCompatWrap = AuthOptCompatWrap(authOpts); - authOptCompatWrap.reset(new AuthOptCompatWrap(authOpts)); + if (!test) + { + this->settings = std::move(tmpSettings); + } } -AuthOptCompatWrap &ConfigFileParser::getAuthOptsCompat() +AuthOptCompatWrap &Settings::getAuthOptsCompat() { - return *authOptCompatWrap.get(); + return authOptCompatWrap; } - diff --git a/configfileparser.h b/configfileparser.h index e05ca44..d6794ac 100644 --- a/configfileparser.h +++ b/configfileparser.h @@ -6,59 +6,33 @@ #include #include #include +#include #include "sslctxmanager.h" +#include "listener.h" +#include "settings.h" #define ABSOLUTE_MAX_PACKET_SIZE 268435461 // 256 MB + 5 -struct mosquitto_auth_opt +enum class ConfigParseLevel { - char *key = nullptr; - char *value = nullptr; - - mosquitto_auth_opt(const std::string &key, const std::string &value); - mosquitto_auth_opt(mosquitto_auth_opt &&other); - mosquitto_auth_opt(const mosquitto_auth_opt &other) = delete; - ~mosquitto_auth_opt(); -}; - -struct AuthOptCompatWrap -{ - std::vector optArray; - - AuthOptCompatWrap(const std::unordered_map &authOpts); - AuthOptCompatWrap(const AuthOptCompatWrap &other) = delete; - AuthOptCompatWrap(AuthOptCompatWrap &&other) = delete; - - struct mosquitto_auth_opt *head() { return &optArray[0]; } - int size() { return optArray.size(); } + Root, + Listen }; class ConfigFileParser { const std::string path; std::set validKeys; - std::unordered_map authOpts; - std::unique_ptr authOptCompatWrap; - + std::set validListenKeys; + void testKeyValidity(const std::string &key, const std::set &validKeys) const; void checkFileAccess(const std::string &key, const std::string &pathToCheck) const; - void testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const; public: ConfigFileParser(const std::string &path); void loadFile(bool test); - AuthOptCompatWrap &getAuthOptsCompat(); - // Actual config options with their defaults. Just making them public, I can retrain myself misuing them. - std::string authPluginPath; - std::string logPath; - std::string sslFullchain; - std::string sslPrivkey; - uint listenPort = 1883; - uint sslListenPort = 0; - bool allowUnsafeClientidChars = false; - int clientInitialBufferSize = 1024; // Must be power of 2 - int maxPacketSize = 268435461; // 256 MB + 5 + std::unique_ptr settings; }; #endif // CONFIGFILEPARSER_H diff --git a/exceptions.cpp b/exceptions.cpp index 3a4a7aa..270cc29 100644 --- a/exceptions.cpp +++ b/exceptions.cpp @@ -1,2 +1,4 @@ #include "exceptions.h" + + diff --git a/exceptions.h b/exceptions.h index a27595d..2910c20 100644 --- a/exceptions.h +++ b/exceptions.h @@ -3,6 +3,7 @@ #include #include +#include class ProtocolError : public std::runtime_error { @@ -26,6 +27,7 @@ class ConfigFileException : public std::runtime_error { public: ConfigFileException(const std::string &msg) : std::runtime_error(msg) {} + ConfigFileException(std::ostringstream oss) : std::runtime_error(oss.str()) {} }; class AuthPluginException : public std::runtime_error diff --git a/forward_declarations.h b/forward_declarations.h index d4f6902..1c31c49 100644 --- a/forward_declarations.h +++ b/forward_declarations.h @@ -10,6 +10,7 @@ typedef std::shared_ptr ThreadData_p; class MqttPacket; class SubscriptionStore; class Session; +class Settings; #endif // FORWARD_DECLARATIONS_H diff --git a/globalsettings.cpp b/globalsettings.cpp deleted file mode 100644 index 5e10a0f..0000000 --- a/globalsettings.cpp +++ /dev/null @@ -1,3 +0,0 @@ -#include "globalsettings.h" - - diff --git a/globalsettings.h b/globalsettings.h deleted file mode 100644 index bf9dbbb..0000000 --- a/globalsettings.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef GLOBALSETTINGS_H -#define GLOBALSETTINGS_H - -// Defaults are defined in ConfigFileParser -struct GlobalSettings -{ - bool allow_unsafe_clientid_chars = false; - int clientInitialBufferSize = 0; - int maxPacketSize = 0; -}; -#endif // GLOBALSETTINGS_H diff --git a/listener.cpp b/listener.cpp new file mode 100644 index 0000000..33873bf --- /dev/null +++ b/listener.cpp @@ -0,0 +1,78 @@ +#include "listener.h" + +#include "utils.h" +#include "exceptions.h" + +void Listener::isValid() +{ + if (isSsl()) + { + if (port == 0) + { + if (websocket) + port = 4443; + else + port = 8883; + } + + testSsl(sslFullchain, sslPrivkey); + } + else + { + if (port == 0) + { + if (websocket) + port = 8080; + else + port = 1883; + } + } + + if (port <= 0 || port > 65534) + { + throw ConfigFileException(formatString("Port nr %d is not valid", port)); + } +} + +bool Listener::isSsl() const +{ + return (!sslFullchain.empty() || !sslPrivkey.empty()); +} + +std::string Listener::getProtocolName() const +{ + if (isSsl()) + { + if (websocket) + return "SSL websocket"; + else + return "SSL TCP"; + } + else + { + if (websocket) + return "non-SSL websocket"; + else + return "non-SSL TCP"; + } + + return "whoops"; +} + +void Listener::loadCertAndKeyFromConfig() +{ + if (!isSsl()) + return; + + if (!sslctx) + { + sslctx.reset(new SslCtxManager()); + SSL_CTX_set_options(sslctx->get(), SSL_OP_NO_SSLv3); // TODO: config option + SSL_CTX_set_options(sslctx->get(), SSL_OP_NO_TLSv1); // TODO: config option + } + + if (SSL_CTX_use_certificate_file(sslctx->get(), sslFullchain.c_str(), SSL_FILETYPE_PEM) != 1) + throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected."); + if (SSL_CTX_use_PrivateKey_file(sslctx->get(), sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1) + throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected."); +} diff --git a/listener.h b/listener.h new file mode 100644 index 0000000..7d507bb --- /dev/null +++ b/listener.h @@ -0,0 +1,22 @@ +#ifndef LISTENER_H +#define LISTENER_H + +#include +#include + +#include "sslctxmanager.h" + +struct Listener +{ + int port = 0; + bool websocket = false; + std::string sslFullchain; + std::string sslPrivkey; + std::unique_ptr sslctx; + + void isValid(); + bool isSsl() const; + std::string getProtocolName() const; + void loadCertAndKeyFromConfig(); +}; +#endif // LISTENER_H diff --git a/mainapp.cpp b/mainapp.cpp index d46d11b..3bf3bfa 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -172,9 +172,6 @@ MainApp::MainApp(const std::string &configFilePath) : MainApp::~MainApp() { - if (sslctx) - SSL_CTX_free(sslctx); - if (epollFdAccept > 0) close(epollFdAccept); } @@ -202,54 +199,47 @@ void MainApp::showLicense() puts("Author: Wiebe Cazemier "); } -void MainApp::setCertAndKeyFromConfig() -{ - if (sslctx == nullptr) - return; - - if (SSL_CTX_use_certificate_file(sslctx, confFileParser->sslFullchain.c_str(), SSL_FILETYPE_PEM) != 1) - throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected."); - if (SSL_CTX_use_PrivateKey_file(sslctx, confFileParser->sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1) - throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected."); -} - -int MainApp::createListenSocket(int portNr, bool ssl) +int MainApp::createListenSocket(const std::shared_ptr &listener) { - if (portNr <= 0) + if (listener->port <= 0) return -2; - int listen_fd = check(socket(AF_INET, SOCK_STREAM, 0)); - - // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. - int optval = 1; - check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); - - int flags = fcntl(listen_fd, F_GETFL); - check(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); + logger->logf(LOG_NOTICE, "Creating %s listener on port %d", listener->getProtocolName().c_str(), listener->port); - struct sockaddr_in in_addr_plain; - in_addr_plain.sin_family = AF_INET; - in_addr_plain.sin_addr.s_addr = INADDR_ANY; - in_addr_plain.sin_port = htons(portNr); + try + { + int listen_fd = check(socket(AF_INET, SOCK_STREAM, 0)); - check(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in))); - check(listen(listen_fd, 1024)); + // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. + int optval = 1; + check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); - struct epoll_event ev; - memset(&ev, 0, sizeof (struct epoll_event)); + int flags = fcntl(listen_fd, F_GETFL); + check(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); - ev.data.fd = listen_fd; - ev.events = EPOLLIN; - check(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev)); + struct sockaddr_in in_addr_plain; + in_addr_plain.sin_family = AF_INET; + in_addr_plain.sin_addr.s_addr = INADDR_ANY; + in_addr_plain.sin_port = htons(listener->port); - std::string socketType = "plain"; + check(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in))); + check(listen(listen_fd, 1024)); - if (ssl) - socketType = "SSL"; + struct epoll_event ev; + memset(&ev, 0, sizeof (struct epoll_event)); - logger->logf(LOG_NOTICE, "Listening on %s port %d", socketType.c_str(), portNr); + ev.data.fd = listen_fd; + ev.events = EPOLLIN; + check(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev)); - return listen_fd; + return listen_fd; + } + catch (std::exception &ex) + { + logger->logf(LOG_NOTICE, "Creating %s listener on port %d failed: %s", listener->getProtocolName().c_str(), listener->port, ex.what()); + return -1; + } + return -1; } void MainApp::wakeUpThread() @@ -349,9 +339,14 @@ void MainApp::start() { timer.start(); - int listen_fd_plain = createListenSocket(this->listenPort, false); - int listen_fd_ssl = createListenSocket(this->sslListenPort, true); - int listen_fd_websocket_plain = createListenSocket(1443, true); + std::map> listenerMap; + + for(std::shared_ptr &listener : this->listeners) + { + int fd = createListenSocket(listener); + if (fd > 0) + listenerMap[fd] = listener; + } #ifdef NDEBUG logger->noLongerLogToStd(); @@ -365,7 +360,7 @@ void MainApp::start() for (int i = 0; i < num_threads; i++) { - std::shared_ptr t(new ThreadData(i, subscriptionStore, *confFileParser.get(), settings)); + std::shared_ptr t(new ThreadData(i, subscriptionStore, settings)); t->start(&do_thread_work); threads.push_back(t); } @@ -392,22 +387,22 @@ void MainApp::start() int cur_fd = events[i].data.fd; try { - if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl || listen_fd_websocket_plain) + if (cur_fd != taskEventFd) { + std::shared_ptr listener = listenerMap[cur_fd]; std::shared_ptr thread_data = threads[next_thread_index++ % num_threads]; - logger->logf(LOG_INFO, "Accepting connection on thread %d", thread_data->threadnr); + logger->logf(LOG_INFO, "Accepting connection on thread %d on %s", thread_data->threadnr, listener->getProtocolName().c_str()); struct sockaddr addr; memset(&addr, 0, sizeof(struct sockaddr)); socklen_t len = sizeof(struct sockaddr); int fd = check(accept(cur_fd, &addr, &len)); - bool websocket = cur_fd == listen_fd_websocket_plain; SSL *clientSSL = nullptr; - if (cur_fd == listen_fd_ssl) + if (listener->isSsl()) { - clientSSL = SSL_new(sslctx); + clientSSL = SSL_new(listener->sslctx->get()); if (clientSSL == NULL) { @@ -419,10 +414,10 @@ void MainApp::start() SSL_set_fd(clientSSL, fd); } - Client_p client(new Client(fd, thread_data, clientSSL, websocket, settings)); + Client_p client(new Client(fd, thread_data, clientSSL, listener->websocket, settings)); thread_data->giveClient(client); } - else if (cur_fd == taskEventFd) + else { uint64_t eventfd_value = 0; check(read(cur_fd, &eventfd_value, sizeof(uint64_t))); @@ -434,10 +429,6 @@ void MainApp::start() } taskQueue.clear(); } - else - { - throw std::runtime_error("Bug: the main thread had activity on an fd it's not supposed to monitor."); - } } catch (std::exception &ex) { @@ -452,12 +443,18 @@ void MainApp::start() thread->quit(); } - close(listen_fd_plain); - close(listen_fd_ssl); + for(auto pair : listenerMap) + { + close(pair.first); + } } void MainApp::quit() { + std::lock_guard guard(quitMutex); + if (!running) + return; + Logger *logger = Logger::getInstance(); logger->logf(LOG_NOTICE, "Quitting FlashMQ"); timer.stop(); @@ -472,26 +469,21 @@ void MainApp::loadConfig() // Atomic loading, first test. confFileParser->loadFile(true); confFileParser->loadFile(false); + settings = std::move(confFileParser->settings); - logger->setLogPath(confFileParser->logPath); - logger->reOpen(); + // For now, it's too much work to be able to reload new listeners, with all the shared resource stuff going on. So, I'm + // loading them to a local var which is never updated. + if (listeners.empty()) + listeners = settings->listeners; - listenPort = confFileParser->listenPort; - sslListenPort = confFileParser->sslListenPort; + logger->setLogPath(settings->logPath); + logger->reOpen(); - if (sslctx == nullptr && sslListenPort > 0) + for (std::shared_ptr &l : this->listeners) { - sslctx = SSL_CTX_new(TLS_server_method()); - SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv3); // TODO: config option - SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option + l->loadCertAndKeyFromConfig(); } - settings.allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars; - settings.clientInitialBufferSize = confFileParser->clientInitialBufferSize; - settings.maxPacketSize = confFileParser->maxPacketSize; - - setCertAndKeyFromConfig(); - for (std::shared_ptr &thread : threads) { thread->queueReload(settings); diff --git a/mainapp.h b/mainapp.h index c2f1263..d119b3f 100644 --- a/mainapp.h +++ b/mainapp.h @@ -20,7 +20,6 @@ #include "subscriptionstore.h" #include "configfileparser.h" #include "timer.h" -#include "globalsettings.h" class MainApp { @@ -37,11 +36,9 @@ class MainApp int taskEventFd = -1; std::mutex eventMutex; Timer timer; - GlobalSettings settings; - - uint listenPort = 0; - uint sslListenPort = 0; - SSL_CTX *sslctx = nullptr; + std::shared_ptr settings; + std::list> listeners; + std::mutex quitMutex; Logger *logger = Logger::getInstance(); @@ -49,8 +46,7 @@ class MainApp void reloadConfig(); static void doHelp(const char *arg); static void showLicense(); - void setCertAndKeyFromConfig(); - int createListenSocket(int portNr, bool ssl); + int createListenSocket(const std::shared_ptr &listener); void wakeUpThread(); void queueKeepAliveCheckAtAllThreads(); @@ -66,7 +62,6 @@ public: bool getStarted() const {return started;} static void testConfig(); - GlobalSettings &getGlobalSettings(); void queueConfigReload(); void queueCleanup(); }; diff --git a/mosquittoauthoptcompatwrap.cpp b/mosquittoauthoptcompatwrap.cpp new file mode 100644 index 0000000..b5f22b8 --- /dev/null +++ b/mosquittoauthoptcompatwrap.cpp @@ -0,0 +1,52 @@ +#include "mosquittoauthoptcompatwrap.h" + +mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value) +{ + this->key = strdup(key.c_str()); + this->value = strdup(value.c_str()); +} + +mosquitto_auth_opt::mosquitto_auth_opt(mosquitto_auth_opt &&other) +{ + this->key = other.key; + this->value = other.value; + other.key = nullptr; + other.value = nullptr; +} + +mosquitto_auth_opt::mosquitto_auth_opt(const mosquitto_auth_opt &other) +{ + this->key = strdup(other.key); + this->value = strdup(other.value); +} + +mosquitto_auth_opt::~mosquitto_auth_opt() +{ + if (key) + delete key; + if (value) + delete value; +} + +mosquitto_auth_opt &mosquitto_auth_opt::operator=(const mosquitto_auth_opt &other) +{ + if (key) + delete key; + if (value) + delete value; + + this->key = strdup(other.key); + this->value = strdup(other.value); + + return *this; +} + +AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map &authOpts) +{ + for(auto &pair : authOpts) + { + mosquitto_auth_opt opt(pair.first, pair.second); + optArray.push_back(std::move(opt)); + } +} + diff --git a/mosquittoauthoptcompatwrap.h b/mosquittoauthoptcompatwrap.h new file mode 100644 index 0000000..bcc4fd7 --- /dev/null +++ b/mosquittoauthoptcompatwrap.h @@ -0,0 +1,47 @@ +#ifndef MOSQUITTOAUTHOPTCOMPATWRAP_H +#define MOSQUITTOAUTHOPTCOMPATWRAP_H + +#include +#include +#include + +/** + * @brief The mosquitto_auth_opt struct is a resource managed class of auth options, compatible with passing as arguments to Mosquitto + * auth plugins. + * + * It's fully assignable and copyable. + */ +struct mosquitto_auth_opt +{ + char *key = nullptr; + char *value = nullptr; + + mosquitto_auth_opt(const std::string &key, const std::string &value); + mosquitto_auth_opt(mosquitto_auth_opt &&other); + mosquitto_auth_opt(const mosquitto_auth_opt &other); + ~mosquitto_auth_opt(); + + mosquitto_auth_opt& operator=(const mosquitto_auth_opt &other); +}; + +/** + * @brief The AuthOptCompatWrap struct contains a vector of mosquitto auth options, with a head pointer and count which can be passed to + * Mosquitto auth plugins. + */ +struct AuthOptCompatWrap +{ + std::vector optArray; + + AuthOptCompatWrap(const std::unordered_map &authOpts); + AuthOptCompatWrap(const AuthOptCompatWrap &other) = default; + AuthOptCompatWrap(AuthOptCompatWrap &&other) = delete; + AuthOptCompatWrap() = default; + + struct mosquitto_auth_opt *head() { return &optArray[0]; } + int size() { return optArray.size(); } + + AuthOptCompatWrap &operator=(const AuthOptCompatWrap &other) = default; +}; + + +#endif // MOSQUITTOAUTHOPTCOMPATWRAP_H diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 2b940ee..8a14307 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -145,7 +145,7 @@ void MqttPacket::handleConnect() uint16_t variable_header_length = readTwoBytesToUInt16(); - const GlobalSettings &settings = sender->getThreadData()->settingsLocalCopy; + const Settings &settings = sender->getThreadData()->settingsLocalCopy; if (variable_header_length == 4 || variable_header_length == 6) { @@ -232,7 +232,7 @@ void MqttPacket::handleConnect() bool validClientId = true; // Check for wildcard chars in case the client_id ever appears in topics. - if (!settings.allow_unsafe_clientid_chars && (strContains(client_id, "+") || strContains(client_id, "#"))) + if (!settings.allowUnsafeClientidChars && (strContains(client_id, "+") || strContains(client_id, "#"))) { logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str()); validClientId = false; diff --git a/mqttpacket.h b/mqttpacket.h index 2f4c86a..31cfe0d 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -15,7 +15,6 @@ #include "cirbuf.h" #include "logger.h" #include "mainapp.h" -#include "globalsettings.h" struct RemainingLength { diff --git a/settings.cpp b/settings.cpp new file mode 100644 index 0000000..bca168d --- /dev/null +++ b/settings.cpp @@ -0,0 +1,3 @@ +#include "settings.h" + + diff --git a/settings.h b/settings.h new file mode 100644 index 0000000..d440612 --- /dev/null +++ b/settings.h @@ -0,0 +1,28 @@ +#ifndef SETTINGS_H +#define SETTINGS_H + +#include +#include + +#include "mosquittoauthoptcompatwrap.h" +#include "listener.h" + +class Settings +{ + friend class ConfigFileParser; + + AuthOptCompatWrap authOptCompatWrap; + +public: + // Actual config options with their defaults. + std::string authPluginPath; + std::string logPath; + bool allowUnsafeClientidChars = false; + int clientInitialBufferSize = 1024; // Must be power of 2 + int maxPacketSize = 268435461; // 256 MB + 5 + std::list> listeners; // Default one is created later, when none are defined. + + AuthOptCompatWrap &getAuthOptsCompat(); +}; + +#endif // SETTINGS_H diff --git a/sslctxmanager.cpp b/sslctxmanager.cpp index 7519f44..9ea326e 100644 --- a/sslctxmanager.cpp +++ b/sslctxmanager.cpp @@ -16,3 +16,8 @@ SSL_CTX *SslCtxManager::get() const { return ssl_ctx; } + +SslCtxManager::operator bool() const +{ + return ssl_ctx == nullptr; +} diff --git a/sslctxmanager.h b/sslctxmanager.h index bd5c227..0f0510c 100644 --- a/sslctxmanager.h +++ b/sslctxmanager.h @@ -11,6 +11,7 @@ public: ~SslCtxManager(); SSL_CTX *get() const; + operator bool() const; }; #endif // SSLCTXMANAGER_H diff --git a/threaddata.cpp b/threaddata.cpp index 1c5b2dd..bcfbcd0 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -2,12 +2,11 @@ #include #include -ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings) : +ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, std::shared_ptr settings) : subscriptionStore(subscriptionStore), - confFileParser(confFileParser), - authPlugin(confFileParser), - threadnr(threadnr), - settingsLocalCopy(settings) + settingsLocalCopy(*settings.get()), + authPlugin(settingsLocalCopy), + threadnr(threadnr) { logger = Logger::getInstance(); @@ -146,18 +145,19 @@ void ThreadData::doKeepAliveCheck() void ThreadData::initAuthPlugin() { - authPlugin.loadPlugin(confFileParser.authPluginPath); + authPlugin.loadPlugin(settingsLocalCopy.authPluginPath); authPlugin.init(); authPlugin.securityInit(false); } -void ThreadData::reload(GlobalSettings settings) +void ThreadData::reload(std::shared_ptr settings) { logger->logf(LOG_DEBUG, "Doing reload in thread %d", threadnr); try { - settingsLocalCopy = settings; + // Because the auth plugin has a reference to it, it will also be updated. + settingsLocalCopy = *settings.get(); authPlugin.securityCleanup(true); authPlugin.securityInit(true); @@ -172,7 +172,7 @@ void ThreadData::reload(GlobalSettings settings) } } -void ThreadData::queueReload(GlobalSettings settings) +void ThreadData::queueReload(std::shared_ptr settings) { std::lock_guard locker(taskQueueMutex); diff --git a/threaddata.h b/threaddata.h index 1cd2664..2dc3321 100644 --- a/threaddata.h +++ b/threaddata.h @@ -20,7 +20,6 @@ #include "configfileparser.h" #include "authplugin.h" #include "logger.h" -#include "globalsettings.h" typedef void (*thread_f)(ThreadData *); @@ -29,14 +28,14 @@ class ThreadData std::unordered_map clients_by_fd; std::mutex clients_by_fd_mutex; std::shared_ptr subscriptionStore; - ConfigFileParser &confFileParser; Logger *logger; - void reload(GlobalSettings settings); + void reload(std::shared_ptr settings); void wakeUpThread(); void doKeepAliveCheck(); public: + Settings settingsLocalCopy; // Is updated on reload, within the thread loop. AuthPlugin authPlugin; bool running = true; std::thread thread; @@ -45,9 +44,8 @@ public: int taskEventFd = 0; std::mutex taskQueueMutex; std::forward_list> taskQueue; - GlobalSettings settingsLocalCopy; // Is updated on reload, within the thread loop. - ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser, const GlobalSettings &settings); + ThreadData(int threadnr, std::shared_ptr &subscriptionStore, std::shared_ptr settings); ThreadData(const ThreadData &other) = delete; ThreadData(ThreadData &&other) = delete; @@ -60,7 +58,7 @@ public: std::shared_ptr &getSubscriptionStore(); void initAuthPlugin(); - void queueReload(GlobalSettings settings); + void queueReload(std::shared_ptr settings); void queueDoKeepAliveCheck(); }; diff --git a/utils.cpp b/utils.cpp index 23b3560..9f4731a 100644 --- a/utils.cpp +++ b/utils.cpp @@ -3,10 +3,15 @@ #include "sys/time.h" #include "sys/random.h" #include -#include +#include + +#include "openssl/ssl.h" +#include "openssl/err.h" #include "exceptions.h" #include "cirbuf.h" +#include "sslctxmanager.h" +#include "logger.h" std::list split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) { @@ -346,3 +351,47 @@ std::string generateWebsocketAnswer(const std::string &acceptString) oss.flush(); return oss.str(); } + +// Using a separate ssl context to test, because it's the easiest way to load certs and key atomitcally. +void testSsl(const std::string &fullchain, const std::string &privkey) +{ + if (fullchain.empty() && privkey.empty()) + throw ConfigFileException("No privkey and fullchain specified."); + + if (fullchain.empty()) + throw ConfigFileException("No private key specified for fullchain"); + + if (privkey.empty()) + throw ConfigFileException("No fullchain specified for private key"); + + SslCtxManager sslCtx; + if (SSL_CTX_use_certificate_file(sslCtx.get(), fullchain.c_str(), SSL_FILETYPE_PEM) != 1) + { + ERR_print_errors_cb(logSslError, NULL); + throw ConfigFileException("Error loading full chain " + fullchain); + } + if (SSL_CTX_use_PrivateKey_file(sslCtx.get(), privkey.c_str(), SSL_FILETYPE_PEM) != 1) + { + ERR_print_errors_cb(logSslError, NULL); + throw ConfigFileException("Error loading private key " + privkey); + } + if (SSL_CTX_check_private_key(sslCtx.get()) != 1) + { + ERR_print_errors_cb(logSslError, NULL); + throw ConfigFileException("Private key and certificate don't match."); + } +} + +std::string formatString(const std::string str, ...) +{ + char buf[512]; + + va_list valist; + va_start(valist, str); + vsnprintf(buf, 512, str.c_str(), valist); + va_end(valist); + + std::string result(buf, 512); + + return result; +} diff --git a/utils.h b/utils.h index 666f517..d7776dc 100644 --- a/utils.h +++ b/utils.h @@ -55,4 +55,8 @@ std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion); std::string generateBadHttpRequestReponse(const std::string &msg); std::string generateWebsocketAnswer(const std::string &acceptString); +void testSsl(const std::string &fullchain, const std::string &privkey); + +std::string formatString(const std::string str, ...); + #endif // UTILS_H