diff --git a/CMakeLists.txt b/CMakeLists.txt index 60123d0..c24d452 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,7 @@ add_executable(FlashMQ cirbuf.cpp logger.cpp authplugin.cpp + configfileparser.cpp ) target_link_libraries(FlashMQ pthread dl) diff --git a/authplugin.cpp b/authplugin.cpp index 41b40d4..38777f8 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -7,10 +7,6 @@ #include "exceptions.h" -// TODO: error handling on all the calls to the plugin. Exceptions? Passing to the caller? -// TODO: where to do the conditionals about whether the plugin is loaded, what to do on error, etc? -// -> Perhaps merely log the error (and return 'denied'?)? - void mosquitto_log_printf(int level, const char *fmt, ...) { Logger *logger = Logger::getInstance(); @@ -21,12 +17,18 @@ void mosquitto_log_printf(int level, const char *fmt, ...) } -AuthPlugin::AuthPlugin() // our configuration object as param +AuthPlugin::AuthPlugin(ConfigFileParser &confFileParser) : + confFileParser(confFileParser) { logger = Logger::getInstance(); } -void *AuthPlugin::loadSymbol(void *handle, const char *symbol) +AuthPlugin::~AuthPlugin() +{ + cleanup(); +} + +void *AuthPlugin::loadSymbol(void *handle, const char *symbol) const { void *r = dlsym(handle, symbol); @@ -41,8 +43,14 @@ void *AuthPlugin::loadSymbol(void *handle, const char *symbol) void AuthPlugin::loadPlugin(const std::string &pathToSoFile) { + if (pathToSoFile.empty()) + return; + logger->logf(LOG_INFO, "Loading auth plugin %s", pathToSoFile.c_str()); + initialized = false; + wanted = true; + if (access(pathToSoFile.c_str(), R_OK) != 0) { std::ostringstream oss; @@ -72,47 +80,105 @@ void AuthPlugin::loadPlugin(const std::string &pathToSoFile) acl_check_v2 = (F_auth_plugin_acl_check_v2)loadSymbol(r, "mosquitto_auth_acl_check"); unpwd_check_v2 = (F_auth_plugin_unpwd_check_v2)loadSymbol(r, "mosquitto_auth_unpwd_check"); psk_key_get_v2 = (F_auth_plugin_psk_key_get_v2)loadSymbol(r, "mosquitto_auth_psk_key_get"); + + initialized = true; } -int AuthPlugin::init() +void AuthPlugin::init() { - struct mosquitto_auth_opt auth_opts[2]; // TODO: get auth opts from central config object - std::memset(&auth_opts, 0, sizeof(struct mosquitto_auth_opt) * 2); - int result = init_v2(&pluginData, auth_opts, 2); - return result; + if (!wanted) + return; + + AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); + if (result != 0) + throw FatalError("Error initialising auth plugin."); } -int AuthPlugin::cleanup() +void AuthPlugin::cleanup() { - struct mosquitto_auth_opt auth_opts[2]; // TODO: get auth opts from central config object - std::memset(&auth_opts, 0, sizeof(struct mosquitto_auth_opt) * 2); - return cleanup_v2(pluginData, auth_opts, 2); + if (!cleanup_v2) + return; + + securityCleanup(false); + + AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size()); + if (result != 0) + logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway. } -int AuthPlugin::securityInit(bool reloading) +void AuthPlugin::securityInit(bool reloading) { - struct mosquitto_auth_opt auth_opts[2]; // TODO: get auth opts from central config object - std::memset(&auth_opts, 0, sizeof(struct mosquitto_auth_opt) * 2); - return security_init_v2(pluginData, auth_opts, 2, reloading); + if (!wanted) + return; + + AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); + if (result != 0) + { + throw AuthPluginException("Plugin function mosquitto_auth_security_init returned an error. If it didn't log anything, we don't know what it was."); + } + initialized = true; } -int AuthPlugin::securityCleanup(bool reloading) +void AuthPlugin::securityCleanup(bool reloading) { - struct mosquitto_auth_opt auth_opts[2]; // TODO: get auth opts from central config object - std::memset(&auth_opts, 0, sizeof(struct mosquitto_auth_opt) * 2); - return security_cleanup_v2(pluginData, auth_opts, 2, reloading); + if (!wanted) + return; + + initialized = false; + AuthOptCompatWrap &authOpts = confFileParser.getAuthOptsCompat(); + int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading); + + if (result != 0) + { + throw AuthPluginException("Plugin function mosquitto_auth_security_cleanup returned an error. If it didn't log anything, we don't know what it was."); + } } AuthResult AuthPlugin::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, AclAccess access) { + if (!wanted) + return AuthResult::success; + + if (!initialized) + { + logger->logf(LOG_ERR, "ACL check wanted, but initialization failed. Can't perform check."); + return AuthResult::error; + } + int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast(access)); - return static_cast(result); + AuthResult result_ = static_cast(result); + + if (result_ == AuthResult::error) + { + logger->logf(LOG_ERR, "ACL check by plugin returned error for topic '%s'. If it didn't log anything, we don't know what it was.", topic.c_str()); + } + + return result_; } AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string &password) { + if (!wanted) + return AuthResult::success; + + if (!initialized) + { + logger->logf(LOG_ERR, "Username+password check wanted, but initialization failed. Can't perform check."); + return AuthResult::error; + } + int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str()); - return static_cast(result); + AuthResult r = static_cast(result); + + if (r == AuthResult::error) + { + logger->logf(LOG_ERR, "Username+password check by plugin returned error for user '%s'. If it didn't log anything, we don't know what it was.", username.c_str()); + } + + return r; } diff --git a/authplugin.h b/authplugin.h index ca3d548..6c768ec 100644 --- a/authplugin.h +++ b/authplugin.h @@ -5,6 +5,7 @@ #include #include "logger.h" +#include "configfileparser.h" // Compatible with Mosquitto enum class AclAccess @@ -23,11 +24,6 @@ enum class AuthResult error = 13 }; -struct mosquitto_auth_opt { - char *key; - char *value; -}; - typedef int (*F_auth_plugin_version)(void); typedef int (*F_auth_plugin_init_v2)(void **, struct mosquitto_auth_opt *, int); @@ -55,18 +51,25 @@ class AuthPlugin F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; + ConfigFileParser &confFileParser; + void *pluginData = nullptr; Logger *logger = nullptr; + bool initialized = false; + bool wanted = false; - void *loadSymbol(void *handle, const char *symbol); + void *loadSymbol(void *handle, const char *symbol) const; public: - AuthPlugin(); + AuthPlugin(ConfigFileParser &confFileParser); + AuthPlugin(const AuthPlugin &other) = delete; + AuthPlugin(AuthPlugin &&other) = delete; + ~AuthPlugin(); void loadPlugin(const std::string &pathToSoFile); - int init(); - int cleanup(); - int securityInit(bool reloading); - int securityCleanup(bool reloading); + void init(); + void cleanup(); + void securityInit(bool reloading); + void securityCleanup(bool reloading); AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, AclAccess access); AuthResult unPwdCheck(const std::string &username, const std::string &password); diff --git a/configfileparser.cpp b/configfileparser.cpp new file mode 100644 index 0000000..ea1f5e7 --- /dev/null +++ b/configfileparser.cpp @@ -0,0 +1,143 @@ +#include "configfileparser.h" + +#include +#include +#include +#include "fstream" + +#include "exceptions.h" +#include "utils.h" +#include + + +mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value) +{ + this->key = strdup(key.c_str()); + this->value = strdup(value.c_str()); +} + +mosquitto_auth_opt::mosquitto_auth_opt(mosquitto_auth_opt &&other) +{ + this->key = other.key; + this->value = other.value; + other.key = nullptr; + other.value = nullptr; +} + +mosquitto_auth_opt::~mosquitto_auth_opt() +{ + if (key) + delete key; + if (value) + delete value; +} + +AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map &authOpts) +{ + for(auto &pair : authOpts) + { + mosquitto_auth_opt opt(pair.first, pair.second); + optArray.push_back(std::move(opt)); + } +} + +ConfigFileParser::ConfigFileParser(const std::string &path) : + path(path) +{ + validKeys.insert("auth_plugin"); +} + +void ConfigFileParser::loadFile() +{ + if (access(path.c_str(), R_OK) != 0) + { + std::ostringstream oss; + oss << "Error: " << path << " is not there or not readable"; + throw ConfigFileException(oss.str()); + } + + std::ifstream infile(path, std::ios::in); + + if (!infile.is_open()) + { + std::ostringstream oss; + oss << "Error loading " << path; + throw ConfigFileException(oss.str()); + } + + std::list lines; + + const std::regex r("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.]+)$"); + + // First parse the file and keep the valid lines. + for(std::string line; getline(infile, line ); ) + { + trim(line); + + if (startsWith(line, "#")) + continue; + + if (line.empty()) + continue; + + std::smatch matches; + + if (!std::regex_search(line, matches, r) || matches.size() != 3) + { + std::ostringstream oss; + oss << "Line '" << line << "' not in 'key value' format"; + throw ConfigFileException(oss.str()); + } + + lines.push_back(line); + } + + authOpts.clear(); + authOptCompatWrap.reset(); + + // Then once we know the config file is valid, process it. + for (std::string &line : lines) + { + std::smatch matches; + + if (!std::regex_search(line, matches, r) || matches.size() != 3) + { + throw ConfigFileException("Config parse error at a point that should not be possible."); + } + + std::string key = matches[1].str(); + const std::string value = matches[2].str(); + + const std::string auth_opt_ = "auth_opt_"; + if (startsWith(key, auth_opt_)) + { + key.replace(0, auth_opt_.length(), ""); + authOpts[key] = value; + } + else + { + auto valid_key_it = validKeys.find(key); + if (valid_key_it == validKeys.end()) + { + std::ostringstream oss; + oss << "Config key '" << key << "' is not valid"; + throw ConfigFileException(oss.str()); + } + + if (key == "auth_plugin") + { + this->authPluginPath = value; + } + } + } + + authOptCompatWrap.reset(new AuthOptCompatWrap(authOpts)); +} + +AuthOptCompatWrap &ConfigFileParser::getAuthOptsCompat() +{ + return *authOptCompatWrap.get(); +} + + + diff --git a/configfileparser.h b/configfileparser.h new file mode 100644 index 0000000..f4cf730 --- /dev/null +++ b/configfileparser.h @@ -0,0 +1,49 @@ +#ifndef CONFIGFILEPARSER_H +#define CONFIGFILEPARSER_H + +#include +#include +#include +#include +#include + +struct mosquitto_auth_opt +{ + char *key = nullptr; + char *value = nullptr; + + mosquitto_auth_opt(const std::string &key, const std::string &value); + mosquitto_auth_opt(mosquitto_auth_opt &&other); + mosquitto_auth_opt(const mosquitto_auth_opt &other) = delete; + ~mosquitto_auth_opt(); +}; + +struct AuthOptCompatWrap +{ + std::vector optArray; + + AuthOptCompatWrap(const std::unordered_map &authOpts); + AuthOptCompatWrap(const AuthOptCompatWrap &other) = delete; + AuthOptCompatWrap(AuthOptCompatWrap &&other) = delete; + + struct mosquitto_auth_opt *head() { return &optArray[0]; } + int size() { return optArray.size(); } +}; + +class ConfigFileParser +{ + const std::string path; + std::set validKeys; + std::unordered_map authOpts; + std::unique_ptr authOptCompatWrap; + std::string authPluginPath; + +public: + ConfigFileParser(const std::string &path); + void loadFile(); + AuthOptCompatWrap &getAuthOptsCompat(); + + std::string getAuthPluginPath() { return authPluginPath; } +}; + +#endif // CONFIGFILEPARSER_H diff --git a/exceptions.h b/exceptions.h index 55cac3b..563ce89 100644 --- a/exceptions.h +++ b/exceptions.h @@ -22,4 +22,16 @@ public: FatalError(const std::string &msg) : std::runtime_error(msg) {} }; +class ConfigFileException : public std::runtime_error +{ +public: + ConfigFileException(const std::string &msg) : std::runtime_error(msg) {} +}; + +class AuthPluginException : public std::runtime_error +{ +public: + AuthPluginException(const std::string &msg) : std::runtime_error(msg) {} +}; + #endif // EXCEPTIONS_H diff --git a/main.cpp b/main.cpp index 54efb61..a311883 100644 --- a/main.cpp +++ b/main.cpp @@ -6,7 +6,7 @@ #include "mainapp.h" -MainApp *mainApp = MainApp::getMainApp(); +MainApp *mainApp = nullptr; static void signal_handler(int signal) { @@ -65,6 +65,7 @@ int main() { try { + mainApp = MainApp::getMainApp(); check(register_signal_handers()); mainApp->start(); } diff --git a/mainapp.cpp b/mainapp.cpp index 9a48dac..bba53f1 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -111,7 +111,8 @@ void do_thread_work(ThreadData *threadData) MainApp::MainApp() : subscriptionStore(new SubscriptionStore()) { - + confFileParser.reset(new ConfigFileParser("/home/halfgaar/Projects/FlashMQThings/config.txt")); // TODO: from argv + confFileParser->loadFile(); } MainApp *MainApp::getMainApp() @@ -154,7 +155,7 @@ void MainApp::start() for (int i = 0; i < NR_OF_THREADS; i++) { - std::shared_ptr t(new ThreadData(i, subscriptionStore)); + std::shared_ptr t(new ThreadData(i, subscriptionStore, *confFileParser.get())); std::thread thread(do_thread_work, t.get()); t->moveThreadHere(std::move(thread)); threads.push_back(t); diff --git a/mainapp.h b/mainapp.h index e796b0c..e084b6e 100644 --- a/mainapp.h +++ b/mainapp.h @@ -16,6 +16,7 @@ #include "client.h" #include "mqttpacket.h" #include "subscriptionstore.h" +#include "configfileparser.h" class MainApp { @@ -25,6 +26,7 @@ class MainApp bool running = true; std::vector> threads; std::shared_ptr subscriptionStore; + std::unique_ptr confFileParser; MainApp(); public: diff --git a/threaddata.cpp b/threaddata.cpp index e0a9c92..f07706d 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -2,11 +2,19 @@ #include #include -ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore) : +ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser) : subscriptionStore(subscriptionStore), + confFileParser(confFileParser), + authPlugin(confFileParser), threadnr(threadnr) { + logger = Logger::getInstance(); + epollfd = check(epoll_create(999)); + + authPlugin.loadPlugin(confFileParser.getAuthPluginPath()); + authPlugin.init(); + authPlugin.securityInit(false); } void ThreadData::moveThreadHere(std::thread &&thread) @@ -92,5 +100,18 @@ bool ThreadData::doKeepAliveCheck() return true; } +void ThreadData::reload() +{ + try + { + authPlugin.securityCleanup(true); + authPlugin.securityInit(true); + } + catch (AuthPluginException &ex) + { + logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what()); + } +} + diff --git a/threaddata.h b/threaddata.h index 4887a43..ac92b27 100644 --- a/threaddata.h +++ b/threaddata.h @@ -16,14 +16,18 @@ #include "client.h" #include "subscriptionstore.h" #include "utils.h" - - +#include "configfileparser.h" +#include "authplugin.h" +#include "logger.h" class ThreadData { std::unordered_map clients_by_fd; std::mutex clients_by_fd_mutex; std::shared_ptr subscriptionStore; + ConfigFileParser &confFileParser; + AuthPlugin authPlugin; + Logger *logger; public: bool running = true; @@ -31,7 +35,9 @@ public: int threadnr = 0; int epollfd = 0; - ThreadData(int threadnr, std::shared_ptr &subscriptionStore); + ThreadData(int threadnr, std::shared_ptr &subscriptionStore, ConfigFileParser &confFileParser); + ThreadData(const ThreadData &other) = delete; + ThreadData(ThreadData &&other) = delete; void moveThreadHere(std::thread &&thread); void quit(); @@ -42,6 +48,7 @@ public: std::shared_ptr &getSubscriptionStore(); bool doKeepAliveCheck(); + void reload(); }; #endif // THREADDATA_H diff --git a/utils.cpp b/utils.cpp index 898d713..15be2a6 100644 --- a/utils.cpp +++ b/utils.cpp @@ -1,6 +1,6 @@ #include "utils.h" - +#include std::list split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) { @@ -143,3 +143,28 @@ std::vector splitToVector(const std::string &input, const char sep, list.push_back(input.substr(start, std::string::npos)); return list; } + +void ltrim(std::string &s) +{ + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { + return !std::isspace(ch); + })); +} + +void rtrim(std::string &s) +{ + s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { + return !std::isspace(ch); + }).base(), s.end()); +} + +void trim(std::string &s) +{ + ltrim(s); + rtrim(s); +} + +bool startsWith(const std::string &s, const std::string &needle) +{ + return s.find(needle) == 0; +} diff --git a/utils.h b/utils.h index af0f0c2..57409df 100644 --- a/utils.h +++ b/utils.h @@ -7,6 +7,7 @@ #include #include #include +#include template int check(int rc) { @@ -31,4 +32,10 @@ bool strContains(const std::string &s, const std::string &needle); bool isValidPublishPath(const std::string &s); +void ltrim(std::string &s); +void rtrim(std::string &s); +void trim(std::string &s); +bool startsWith(const std::string &s, const std::string &needle); + + #endif // UTILS_H