From aa19d41b6524de1e1bfff40a32e4f8344b0a5eeb Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Tue, 8 Jun 2021 22:35:32 +0200 Subject: [PATCH] Implement a native FlashMQ auth plugin --- CMakeLists.txt | 2 ++ authplugin.cpp | 211 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------------------- authplugin.h | 38 +++++++++++++++++++++++++++++++++++--- client.cpp | 5 +++++ client.h | 1 + configfileparser.cpp | 20 ++++++++++++++------ enums.h | 17 +---------------- flashmq_plugin.cpp | 22 ++++++++++++++++++++++ flashmq_plugin.h | 211 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ logger.h | 11 +---------- mainapp.cpp | 20 ++++++++++++++++++-- mainapp.h | 1 + mqttpacket.cpp | 23 +++++++++++++++++++---- session.cpp | 14 ++++++++++---- session.h | 2 +- settings.cpp | 9 +++++++++ settings.h | 3 +++ subscriptionstore.cpp | 8 +++----- subscriptionstore.h | 2 +- threaddata.cpp | 22 +++++++++++++++++----- threaddata.h | 3 +++ types.cpp | 4 ++++ 22 files changed, 545 insertions(+), 104 deletions(-) create mode 100644 flashmq_plugin.cpp create mode 100644 flashmq_plugin.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 20db3da..759c073 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,7 @@ add_executable(FlashMQ acltree.h enums.h threadlocalutils.h + flashmq_plugin.h mainapp.cpp main.cpp @@ -79,6 +80,7 @@ add_executable(FlashMQ evpencodectxmanager.cpp acltree.cpp threadlocalutils.cpp + flashmq_plugin.cpp ) diff --git a/authplugin.cpp b/authplugin.cpp index ba9e144..8d269d1 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -74,11 +74,11 @@ Authentication::~Authentication() EVP_MD_CTX_free(mosquittoDigestContext); } -void *Authentication::loadSymbol(void *handle, const char *symbol) const +void *Authentication::loadSymbol(void *handle, const char *symbol, bool exceptionOnError) const { void *r = dlsym(handle, symbol); - if (r == NULL) + if (r == NULL && exceptionOnError) { std::string errmsg(dlerror()); throw FatalError(errmsg); @@ -95,7 +95,7 @@ void Authentication::loadPlugin(const std::string &pathToSoFile) logger->logf(LOG_NOTICE, "Loading auth plugin %s", pathToSoFile.c_str()); initialized = false; - useExternalPlugin = true; + pluginVersion = PluginVersion::Determining; if (access(pathToSoFile.c_str(), R_OK) != 0) { @@ -112,20 +112,41 @@ void Authentication::loadPlugin(const std::string &pathToSoFile) throw FatalError(errmsg); } - version = (F_auth_plugin_version)loadSymbol(r, "mosquitto_auth_plugin_version"); - - if (version() != 2) + version = (F_auth_plugin_version)loadSymbol(r, "mosquitto_auth_plugin_version", false); + if (version != nullptr) { - throw FatalError("Only Mosquitto plugin version 2 is supported at this time."); + if (version() != 2) + { + throw FatalError("Only Mosquitto plugin version 2 is supported at this time."); + } + + pluginVersion = PluginVersion::MosquittoV2; + + init_v2 = (F_auth_plugin_init_v2)loadSymbol(r, "mosquitto_auth_plugin_init"); + cleanup_v2 = (F_auth_plugin_cleanup_v2)loadSymbol(r, "mosquitto_auth_plugin_cleanup"); + security_init_v2 = (F_auth_plugin_security_init_v2)loadSymbol(r, "mosquitto_auth_security_init"); + security_cleanup_v2 = (F_auth_plugin_security_cleanup_v2)loadSymbol(r, "mosquitto_auth_security_cleanup"); + acl_check_v2 = (F_auth_plugin_acl_check_v2)loadSymbol(r, "mosquitto_auth_acl_check"); + unpwd_check_v2 = (F_auth_plugin_unpwd_check_v2)loadSymbol(r, "mosquitto_auth_unpwd_check"); + psk_key_get_v2 = (F_auth_plugin_psk_key_get_v2)loadSymbol(r, "mosquitto_auth_psk_key_get"); } + else if ((version = (F_auth_plugin_version)loadSymbol(r, "flashmq_auth_plugin_version", false)) != nullptr) + { + if (version() != 1) + { + throw FatalError("FlashMQ plugin only supports version 1."); + } + + pluginVersion = PluginVersion::FlashMQv1; - init_v2 = (F_auth_plugin_init_v2)loadSymbol(r, "mosquitto_auth_plugin_init"); - cleanup_v2 = (F_auth_plugin_cleanup_v2)loadSymbol(r, "mosquitto_auth_plugin_cleanup"); - security_init_v2 = (F_auth_plugin_security_init_v2)loadSymbol(r, "mosquitto_auth_security_init"); - security_cleanup_v2 = (F_auth_plugin_security_cleanup_v2)loadSymbol(r, "mosquitto_auth_security_cleanup"); - acl_check_v2 = (F_auth_plugin_acl_check_v2)loadSymbol(r, "mosquitto_auth_acl_check"); - unpwd_check_v2 = (F_auth_plugin_unpwd_check_v2)loadSymbol(r, "mosquitto_auth_unpwd_check"); - psk_key_get_v2 = (F_auth_plugin_psk_key_get_v2)loadSymbol(r, "mosquitto_auth_psk_key_get"); + flashmq_auth_plugin_allocate_thread_memory_v1 = (F_flashmq_auth_plugin_allocate_thread_memory_v1)loadSymbol(r, "flashmq_auth_plugin_allocate_thread_memory"); + flashmq_auth_plugin_deallocate_thread_memory_v1 = (F_flashmq_auth_plugin_deallocate_thread_memory_v1)loadSymbol(r, "flashmq_auth_plugin_deallocate_thread_memory"); + flashmq_auth_plugin_init_v1 = (F_flashmq_auth_plugin_init_v1)loadSymbol(r, "flashmq_auth_plugin_init"); + flashmq_auth_plugin_deinit_v1 = (F_flashmq_auth_plugin_deinit_v1)loadSymbol(r, "flashmq_auth_plugin_deinit"); + flashmq_auth_plugin_acl_check_v1 = (F_flashmq_auth_plugin_acl_check_v1)loadSymbol(r, "flashmq_auth_plugin_acl_check"); + flashmq_auth_plugin_login_check_v1 = (F_flashmq_auth_plugin_login_check_v1)loadSymbol(r, "flashmq_auth_plugin_login_check"); + flashmq_auth_plugin_periodic_event_v1 = (F_flashmq_auth_plugin_periodic_event)loadSymbol(r, "flashmq_auth_plugin_periodic_event", false); + } initialized = true; } @@ -136,7 +157,7 @@ void Authentication::loadPlugin(const std::string &pathToSoFile) */ void Authentication::init() { - if (!useExternalPlugin) + if (pluginVersion == PluginVersion::None) return; UnscopedLock lock(initMutex); @@ -146,23 +167,46 @@ void Authentication::init() if (quitting) return; - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); - int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); - if (result != 0) - throw FatalError("Error initialising auth plugin."); + if (pluginVersion == PluginVersion::MosquittoV2) + { + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); + int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); + if (result != 0) + throw FatalError("Error initialising auth plugin."); + } + else if (pluginVersion == PluginVersion::FlashMQv1) + { + std::unordered_map &authOpts = settings.getFlashmqAuthPluginOpts(); + flashmq_auth_plugin_allocate_thread_memory_v1(&pluginData, authOpts); + } } void Authentication::cleanup() { - if (!cleanup_v2) + if (pluginVersion == PluginVersion::None) return; securityCleanup(false); - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); - int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size()); - if (result != 0) - logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway. + if (pluginVersion == PluginVersion::MosquittoV2) + { + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); + int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size()); + if (result != 0) + logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway. + } + else if (pluginVersion == PluginVersion::FlashMQv1) + { + try + { + std::unordered_map &authOpts = settings.getFlashmqAuthPluginOpts(); + flashmq_auth_plugin_deallocate_thread_memory_v1(pluginData, authOpts); + } + catch (std::exception &ex) + { + logger->logf(LOG_ERR, "Error cleaning up auth plugin: '%s'", ex.what()); // Not doing exception, because we're shutting down anyway. + } + } } /** @@ -171,7 +215,7 @@ void Authentication::cleanup() */ void Authentication::securityInit(bool reloading) { - if (!useExternalPlugin) + if (pluginVersion == PluginVersion::None) return; UnscopedLock lock(initMutex); @@ -181,31 +225,52 @@ void Authentication::securityInit(bool reloading) if (quitting) return; - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); - int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); - if (result != 0) + if (pluginVersion == PluginVersion::MosquittoV2) + { + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); + int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); + if (result != 0) + { + throw AuthPluginException("Plugin function mosquitto_auth_security_init returned an error. If it didn't log anything, we don't know what it was."); + } + } + else if (pluginVersion == PluginVersion::FlashMQv1) { - throw AuthPluginException("Plugin function mosquitto_auth_security_init returned an error. If it didn't log anything, we don't know what it was."); + std::unordered_map &authOpts = settings.getFlashmqAuthPluginOpts(); + flashmq_auth_plugin_init_v1(pluginData, authOpts, reloading); } + initialized = true; + + periodicEvent(); } void Authentication::securityCleanup(bool reloading) { - if (!useExternalPlugin) + if (pluginVersion == PluginVersion::None) return; initialized = false; - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); - int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading); - if (result != 0) + if (pluginVersion == PluginVersion::MosquittoV2) + { + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); + int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading); + + if (result != 0) + { + throw AuthPluginException("Plugin function mosquitto_auth_security_cleanup returned an error. If it didn't log anything, we don't know what it was."); + } + } + else if (pluginVersion == PluginVersion::FlashMQv1) { - throw AuthPluginException("Plugin function mosquitto_auth_security_cleanup returned an error. If it didn't log anything, we don't know what it was."); + std::unordered_map &authOpts = settings.getFlashmqAuthPluginOpts(); + flashmq_auth_plugin_deinit_v1(pluginData, authOpts, reloading); } } -AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, AclAccess access) +AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, + AclAccess access, char qos, bool retain) { assert(subtopics.size() > 0); @@ -214,7 +279,7 @@ AuthResult Authentication::aclCheck(const std::string &clientid, const std::stri if (firstResult != AuthResult::success) return firstResult; - if (!useExternalPlugin) + if (pluginVersion == PluginVersion::None) return firstResult; if (!initialized) @@ -227,15 +292,33 @@ AuthResult Authentication::aclCheck(const std::string &clientid, const std::stri if (settings.authPluginSerializeAuthChecks) lock.lock(); - int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast(access)); - AuthResult result_ = static_cast(result); + if (pluginVersion == PluginVersion::MosquittoV2) + { + int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast(access)); + AuthResult result_ = static_cast(result); - if (result_ == AuthResult::error) + if (result_ == AuthResult::error) + { + logger->logf(LOG_ERR, "ACL check by plugin returned error for topic '%s'. If it didn't log anything, we don't know what it was.", topic.c_str()); + } + } + else if (pluginVersion == PluginVersion::FlashMQv1) { - logger->logf(LOG_ERR, "ACL check by plugin returned error for topic '%s'. If it didn't log anything, we don't know what it was.", topic.c_str()); + // I'm using this try/catch because propagating the exception higher up conflicts with who gets the blame, and then the publisher + // gets disconnected. + try + { + FlashMQMessage msg(topic, subtopics, qos, retain); + return flashmq_auth_plugin_acl_check_v1(pluginData, access, clientid, username, msg); + } + catch (std::exception &ex) + { + logger->logf(LOG_ERR, "Error doing ACL check in plugin: '%s'", ex.what()); + logger->logf(LOG_WARNING, "Throwing exceptions from auth plugin login/ACL checks is slow. There's no need."); + } } - return result_; + return AuthResult::error; } AuthResult Authentication::unPwdCheck(const std::string &username, const std::string &password) @@ -245,7 +328,7 @@ AuthResult Authentication::unPwdCheck(const std::string &username, const std::st if (firstResult != AuthResult::success) return firstResult; - if (!useExternalPlugin) + if (pluginVersion == PluginVersion::None) return firstResult; if (!initialized) @@ -258,15 +341,32 @@ AuthResult Authentication::unPwdCheck(const std::string &username, const std::st if (settings.authPluginSerializeAuthChecks) lock.lock(); - int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str()); - AuthResult r = static_cast(result); + if (pluginVersion == PluginVersion::MosquittoV2) + { + int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str()); + AuthResult r = static_cast(result); - if (r == AuthResult::error) + if (r == AuthResult::error) + { + logger->logf(LOG_ERR, "Username+password check by plugin returned error for user '%s'. If it didn't log anything, we don't know what it was.", username.c_str()); + } + } + else if (pluginVersion == PluginVersion::FlashMQv1) { - logger->logf(LOG_ERR, "Username+password check by plugin returned error for user '%s'. If it didn't log anything, we don't know what it was.", username.c_str()); + // I'm using this try/catch because propagating the exception higher up conflicts with who gets the blame, and then the publisher + // gets disconnected. + try + { + return flashmq_auth_plugin_login_check_v1(pluginData, username, password); + } + catch (std::exception &ex) + { + logger->logf(LOG_ERR, "Error doing login check in plugin: '%s'", ex.what()); + logger->logf(LOG_WARNING, "Throwing exceptions from auth plugin login/ACL checks is slow. There's no need."); + } } - return r; + return AuthResult::error; } void Authentication::setQuitting() @@ -488,6 +588,23 @@ AuthResult Authentication::unPwdCheckFromMosquittoPasswordFile(const std::string return result; } +void Authentication::periodicEvent() +{ + if (pluginVersion == PluginVersion::None) + return; + + if (!initialized) + { + logger->logf(LOG_ERR, "Auth plugin period event called, but initialization failed or not performed."); + return; + } + + if (pluginVersion == PluginVersion::FlashMQv1 && flashmq_auth_plugin_periodic_event_v1) + { + flashmq_auth_plugin_periodic_event_v1(pluginData); + } +} + std::string AuthResultToString(AuthResult r) { if (r == AuthResult::success) diff --git a/authplugin.h b/authplugin.h index 95ece07..a7b0da0 100644 --- a/authplugin.h +++ b/authplugin.h @@ -26,6 +26,7 @@ License along with FlashMQ. If not, see . #include "logger.h" #include "configfileparser.h" #include "acltree.h" +#include "flashmq_plugin.h" /** * @brief The MosquittoPasswordFileEntry struct stores the decoded base64 password salt and hash. @@ -49,6 +50,7 @@ struct MosquittoPasswordFileEntry typedef int (*F_auth_plugin_version)(void); +// Mosquitto functions typedef int (*F_auth_plugin_init_v2)(void **, struct mosquitto_auth_opt *, int); typedef int (*F_auth_plugin_cleanup_v2)(void *, struct mosquitto_auth_opt *, int); typedef int (*F_auth_plugin_security_init_v2)(void *, struct mosquitto_auth_opt *, int, bool); @@ -57,12 +59,29 @@ typedef int (*F_auth_plugin_acl_check_v2)(void *, const char *, const char *, co typedef int (*F_auth_plugin_unpwd_check_v2)(void *, const char *, const char *); typedef int (*F_auth_plugin_psk_key_get_v2)(void *, const char *, const char *, char *, int); + +typedef void(*F_flashmq_auth_plugin_allocate_thread_memory_v1)(void **thread_data, std::unordered_map &auth_opts); +typedef void(*F_flashmq_auth_plugin_deallocate_thread_memory_v1)(void *thread_data, std::unordered_map &auth_opts); +typedef void(*F_flashmq_auth_plugin_init_v1)(void *thread_data, std::unordered_map &auth_opts, bool reloading); +typedef void(*F_flashmq_auth_plugin_deinit_v1)(void *thread_data, std::unordered_map &auth_opts, bool reloading); +typedef AuthResult(*F_flashmq_auth_plugin_acl_check_v1)(void *thread_data, AclAccess access, const std::string &clientid, const std::string &username, const FlashMQMessage &msg); +typedef AuthResult(*F_flashmq_auth_plugin_login_check_v1)(void *thread_data, const std::string &username, const std::string &password); +typedef void (*F_flashmq_auth_plugin_periodic_event)(void *thread_data); + extern "C" { // Gets called by the plugin, so it needs to exist, globally void mosquitto_log_printf(int level, const char *fmt, ...); } +enum class PluginVersion +{ + None, + Determining, + FlashMQv1, + MosquittoV2, +}; + std::string AuthResultToString(AuthResult r); /** @@ -72,6 +91,8 @@ std::string AuthResultToString(AuthResult r); class Authentication { F_auth_plugin_version version = nullptr; + + // Mosquitto functions F_auth_plugin_init_v2 init_v2 = nullptr; F_auth_plugin_cleanup_v2 cleanup_v2 = nullptr; F_auth_plugin_security_init_v2 security_init_v2 = nullptr; @@ -80,6 +101,14 @@ class Authentication F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; + F_flashmq_auth_plugin_allocate_thread_memory_v1 flashmq_auth_plugin_allocate_thread_memory_v1 = nullptr; + F_flashmq_auth_plugin_deallocate_thread_memory_v1 flashmq_auth_plugin_deallocate_thread_memory_v1 = nullptr; + F_flashmq_auth_plugin_init_v1 flashmq_auth_plugin_init_v1 = nullptr; + F_flashmq_auth_plugin_deinit_v1 flashmq_auth_plugin_deinit_v1 = nullptr; + F_flashmq_auth_plugin_acl_check_v1 flashmq_auth_plugin_acl_check_v1 = nullptr; + F_flashmq_auth_plugin_login_check_v1 flashmq_auth_plugin_login_check_v1 = nullptr; + F_flashmq_auth_plugin_periodic_event flashmq_auth_plugin_periodic_event_v1 = nullptr; + static std::mutex initMutex; static std::mutex authChecksMutex; @@ -88,7 +117,7 @@ class Authentication void *pluginData = nullptr; Logger *logger = nullptr; bool initialized = false; - bool useExternalPlugin = false; + PluginVersion pluginVersion = PluginVersion::None; bool quitting = false; /** @@ -110,7 +139,7 @@ class Authentication AclTree aclTree; - void *loadSymbol(void *handle, const char *symbol) const; + void *loadSymbol(void *handle, const char *symbol, bool exceptionOnError = true) const; public: Authentication(Settings &settings); Authentication(const Authentication &other) = delete; @@ -122,7 +151,8 @@ public: void cleanup(); void securityInit(bool reloading); void securityCleanup(bool reloading); - AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, AclAccess access); + AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, + AclAccess access, char qos, bool retain); AuthResult unPwdCheck(const std::string &username, const std::string &password); void setQuitting(); @@ -131,6 +161,8 @@ public: AuthResult aclCheckFromMosquittoAclFile(const std::string &clientid, const std::string &username, const std::vector &subtopics, AclAccess access); AuthResult unPwdCheckFromMosquittoPasswordFile(const std::string &username, const std::string &password); + void periodicEvent(); + }; #endif // AUTHPLUGIN_H diff --git a/client.cpp b/client.cpp index 8e3372c..93e66c0 100644 --- a/client.cpp +++ b/client.cpp @@ -91,6 +91,11 @@ bool Client::getSslWriteWantsRead() const return ioWrapper.getSslWriteWantsRead(); } +ProtocolVersion Client::getProtocolVersion() const +{ + return protocolVersion; +} + void Client::startOrContinueSslAccept() { ioWrapper.startOrContinueSslAccept(); diff --git a/client.h b/client.h index 076efe5..94c70b3 100644 --- a/client.h +++ b/client.h @@ -99,6 +99,7 @@ public: bool isSsl() const; bool getSslReadWantsWrite() const; bool getSslWriteWantsRead() const; + ProtocolVersion getProtocolVersion() const; void startOrContinueSslAccept(); void markAsDisconnecting(); diff --git a/configfileparser.cpp b/configfileparser.cpp index da46935..b99810e 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -92,11 +92,12 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : path(path) { validKeys.insert("auth_plugin"); + validKeys.insert("auth_plugin_serialize_init"); + validKeys.insert("auth_plugin_serialize_auth_checks"); + validKeys.insert("auth_plugin_timer_period"); validKeys.insert("log_file"); validKeys.insert("allow_unsafe_clientid_chars"); validKeys.insert("allow_unsafe_username_chars"); - validKeys.insert("auth_plugin_serialize_init"); - validKeys.insert("auth_plugin_serialize_auth_checks"); validKeys.insert("client_initial_buffer_size"); validKeys.insert("max_packet_size"); validKeys.insert("log_debug"); @@ -401,6 +402,16 @@ void ConfigFileParser::loadFile(bool test) } tmpSettings->expireSessionsAfterSeconds = newVal; } + + if (key == "auth_plugin_timer_period") + { + int newVal = std::stoi(value); + if (newVal < 0) + { + throw ConfigFileException(formatString("auth_plugin_timer_period value '%d' is invalid. Valid values are 0 or higher. 0 means disabled.", newVal)); + } + tmpSettings->authPluginTimerPeriod = newVal; + } } } catch (std::invalid_argument &ex) // catch for the stoi() @@ -410,6 +421,7 @@ void ConfigFileParser::loadFile(bool test) } tmpSettings->authOptCompatWrap = AuthOptCompatWrap(authOpts); + tmpSettings->flashmqAuthPluginOpts = std::move(authOpts); if (!test) { @@ -417,9 +429,5 @@ void ConfigFileParser::loadFile(bool test) } } -AuthOptCompatWrap &Settings::getAuthOptsCompat() -{ - return authOptCompatWrap; -} diff --git a/enums.h b/enums.h index 0c7db47..b76c847 100644 --- a/enums.h +++ b/enums.h @@ -1,21 +1,6 @@ #ifndef ENUMS_H #define ENUMS_H -// Compatible with Mosquitto -enum class AclAccess -{ - none = 0, - read = 1, - write = 2 -}; - -// Compatible with Mosquitto -enum class AuthResult -{ - success = 0, - acl_denied = 12, - login_denied = 11, - error = 13 -}; +#include "flashmq_plugin.h" #endif // ENUMS_H diff --git a/flashmq_plugin.cpp b/flashmq_plugin.cpp new file mode 100644 index 0000000..51c23a0 --- /dev/null +++ b/flashmq_plugin.cpp @@ -0,0 +1,22 @@ +#include "flashmq_plugin.h" + +#include "logger.h" + +void flashmq_logf(int level, const char *str, ...) +{ + Logger *logger = Logger::getInstance(); + + va_list valist; + va_start(valist, str); + logger->logf(level, str, valist); + va_end(valist); +} + +FlashMQMessage::FlashMQMessage(const std::string &topic, const std::vector &subtopics, const char qos, const bool retain) : + topic(topic), + subtopics(subtopics), + qos(qos), + retain(retain) +{ + +} diff --git a/flashmq_plugin.h b/flashmq_plugin.h new file mode 100644 index 0000000..51bbe96 --- /dev/null +++ b/flashmq_plugin.h @@ -0,0 +1,211 @@ +/* + * This file is part of FlashMQ (https://www.flashmq.org). It defines the + * authentication plugin interface. + * + * This interface definition is public domain and you are encouraged + * to copy it to your authentication plugin project, for portability. Including + * this file in your project does not require your code to have a compatibile + * license nor requires you to open source it. + * + * Compile like: gcc -fPIC -shared authplugin.cpp -o authplugin.so + */ + +#ifndef FLASHMQ_PLUGIN_H +#define FLASHMQ_PLUGIN_H + +#include +#include +#include + +#define FLASHMQ_PLUGIN_VERSION 1 + +// Compatible with Mosquitto, for auth plugin compatability. +#define LOG_NONE 0x00 +#define LOG_INFO 0x01 +#define LOG_NOTICE 0x02 +#define LOG_WARNING 0x04 +#define LOG_ERR 0x08 +#define LOG_DEBUG 0x10 +#define LOG_SUBSCRIBE 0x20 +#define LOG_UNSUBSCRIBE 0x40 + +extern "C" +{ + +/** + * @brief The AclAccess enum's numbers are compatible with Mosquitto's 'int access'. + * + * read = reading a publish published by someone else. + * write = doing a publish. + * subscribe = subscribing. + */ +enum class AclAccess +{ + none = 0, + read = 1, + write = 2, + subscribe = 4 +}; + +/** + * @brief The AuthResult enum's numbers are compatible with Mosquitto's auth result. + */ +enum class AuthResult +{ + success = 0, + acl_denied = 12, + login_denied = 11, + error = 13 +}; + +/** + * @brief The FlashMQMessage struct contains the meta data of a publish. + * + * The subtopics is the topic split, so you don't have to do that anymore. + * + * As for 'retain', keep in mind that for existing subscribers, this will always be false [MQTT-3.3.1-9]. Only publishes or + * retain messages as a result of a subscribe can have that set to true. + * + * For subscribtions, 'retain' is always false. + */ +struct FlashMQMessage +{ + const std::string &topic; + const std::vector &subtopics; + const char qos; + const bool retain; + + FlashMQMessage(const std::string &topic, const std::vector &subtopics, const char qos, const bool retain); +}; + +/** + * @brief flashmq_logf calls the internal logger of FlashMQ. The logger mutexes all access, so is thread-safe. + * @param level is any of the levels defined above, starting with LOG_. + * @param str + * + * FlashMQ makes no distinction between INFO and NOTICE. + */ +void flashmq_logf(int level, const char *str, ...); + +/** + * @brief flashmq_plugin_version must return FLASHMQ_PLUGIN_VERSION. + * @return FLASHMQ_PLUGIN_VERSION. + */ +int flashmq_auth_plugin_version(); + +/** + * @brief flashmq_auth_plugin_allocate_thread_memory is called once by each thread. Never again. + * @param thread_data. Create a memory structure and assign it to *thread_data. + * @param global_data. The global data created in flashmq_auth_plugin_allocate_global_memory, if you use it. + * @param auth_opts. Map of flashmq_auth_opt_* from the config file. + * + * Only allocate the plugin's memory here. Don't open connections, etc. + * + * The global data is created by flashmq_auth_plugin_allocate_global_memory() and if you need it, you can assign it to your + * own thread_data storage. It is not passed as argument to other functions. + * + * You can use static variables for global scope if you must, but do provide proper locking where necessary. + * + * throw an exception on errors. + */ +void flashmq_auth_plugin_allocate_thread_memory(void **thread_data, std::unordered_map &auth_opts); + +/** + * @brief flashmq_auth_plugin_deallocate_thread_memory is called once by each thread. Never again. + * @param thread_data. Delete this memory. + * @param auth_opts. Map of flashmq_auth_opt_* from the config file. + * + * throw an exception on errors. + */ +void flashmq_auth_plugin_deallocate_thread_memory(void *thread_data, std::unordered_map &auth_opts); + +/** + * @brief flashmq_auth_plugin_init is called on thread start and config reload. It is the main place to initialize the plugin. + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory(). + * @param auth_opts. Map of flashmq_auth_opt_* from the config file. + * @param reloading. + * + * The best approach to state keeping is doing everything per thread. You can initialize connections to database servers, load encryption keys, + * create maps, etc. + * + * Keep in mind that libraries you use may not be thread safe (by default). Sometimes they use global scope in treacherous ways. As a random + * example: Qt's QSqlDatabase needs a unique name for each connection, otherwise it is not thread safe and will crash. + * + * There is the option to set 'auth_plugin_serialize_init true' in the config file, which allows some mitigation in + * case you run into problems. + * + * throw an exception on errors. + */ +void flashmq_auth_plugin_init(void *thread_data, std::unordered_map &auth_opts, bool reloading); + +/** + * @brief flashmq_auth_plugin_deinit is called on thread stop and config reload. It is the precursor to initializing. + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory(). + * @param auth_opts. Map of flashmq_auth_opt_* from the config file. + * @param reloading + * + * throw an exception on errors. + */ +void flashmq_auth_plugin_deinit(void *thread_data, std::unordered_map &auth_opts, bool reloading); + +/** + * @brief flashmq_auth_plugin_periodic is called every x seconds as defined in the config file. + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory(). + * + * You may need to periodically refresh data from a database, post stats, etc. You can do that from here. It's queued + * in each thread at the same time, so you can perform somewhat synchronized events in all threads. + * + * Note that it's executed in the event loop, so it blocks the thread if you block here. If you need asynchronous operation, + * you can make threads yourself. Be sure to synchronize data access properly in that case. + * + * The setting auth_plugin_timer_period sets this interval in seconds. + * + * Implementing this is optional. + * + * throw an exception on errors. + */ +void flashmq_auth_plugin_periodic_event(void *thread_data); + +/** + * @brief flashmq_auth_plugin_login_check is called on login of a client. + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory(). + * @param username + * @param password + * @return + * + * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error, + * because there's nothing else to do: the state of FlashMQ won't change. + * + * Note that there is a setting 'auth_plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not + * thread-safe. It will negate much of FlashMQ's multi-core model. + */ +AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string &username, const std::string &password); + +/** + * @brief flashmq_auth_plugin_acl_check is called on publish, deliver and subscribe. + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory(). + * @param access + * @param clientid + * @param username + * @param msg. See FlashMQMessage. + * @return + * + * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error, + * because there's nothing else to do: the state of FlashMQ won't change. + * + * Controlling subscribe access can have several benefits. For instance, you may want to avoid subscriptions that cause + * a lot of server load. If clients pester you with many subscriptions like '+/+/+/+/+/+/+/+/+/', that causes a lot + * of tree walking. Similarly, if all clients subscribe to '#' because it's easy, every single message passing through + * the server will have to be ACL checked for every subscriber. + * + * Note that only MQTT 3.1.1 or higher has a 'failed' return code for subscribing, so older clients will see a normal + * ack and won't know it failed. + * + * Note that there is a setting 'auth_plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not + * thread-safe. It will negate much of FlashMQ's multi-core model. + */ +AuthResult flashmq_auth_plugin_acl_check(void *thread_data, AclAccess access, const std::string &clientid, const std::string &username, const FlashMQMessage &msg); + +} + +#endif // FLASHMQ_PLUGIN_H diff --git a/logger.h b/logger.h index 5899c03..1819d4f 100644 --- a/logger.h +++ b/logger.h @@ -22,16 +22,7 @@ License along with FlashMQ. If not, see . #include #include -// Compatible with Mosquitto, for auth plugin compatability. -// Can be OR'ed together. -#define LOG_NONE 0x00 -#define LOG_INFO 0x01 -#define LOG_NOTICE 0x02 -#define LOG_WARNING 0x04 -#define LOG_ERR 0x08 -#define LOG_DEBUG 0x10 -#define LOG_SUBSCRIBE 0x20 -#define LOG_UNSUBSCRIBE 0x40 +#include "flashmq_plugin.h" int logSslError(const char *str, size_t len, void *u); diff --git a/mainapp.cpp b/mainapp.cpp index 352085f..d01f4d6 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -196,6 +196,12 @@ MainApp::MainApp(const std::string &configFilePath) : auto fPublishStats = std::bind(&MainApp::publishStatsOnDollarTopic, this); timer.addCallback(fPublishStats, 10000, "Publish stats on $SYS"); publishStatsOnDollarTopic(); + + if (settings->authPluginTimerPeriod > 0) + { + auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this); + timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event."); + } } MainApp::~MainApp() @@ -310,6 +316,14 @@ void MainApp::queuePasswordFileReloadAllThreads() } } +void MainApp::queueAuthPluginPeriodicEventAllThreads() +{ + for (std::shared_ptr &thread : threads) + { + thread->queueAuthPluginPeriodicEvent(); + } +} + void MainApp::setFuzzFile(const std::string &fuzzFilePath) { this->fuzzFilePath = fuzzFilePath; @@ -505,6 +519,7 @@ void MainApp::start() try { std::vector packetQueueIn; + std::vector subtopics; std::shared_ptr threaddata(new ThreadData(0, subscriptionStore, settings)); @@ -518,10 +533,11 @@ void MainApp::start() websocketsubscriber->setAuthenticated(true); websocketsubscriber->setFakeUpgraded(); subscriptionStore->registerClientAndKickExistingOne(websocketsubscriber); - subscriptionStore->addSubscription(websocketsubscriber, "#", 0); + splitTopic("#", subtopics); + subscriptionStore->addSubscription(websocketsubscriber, "#", subtopics, 0); subscriptionStore->registerClientAndKickExistingOne(subscriber); - subscriptionStore->addSubscription(subscriber, "#", 0); + subscriptionStore->addSubscription(subscriber, "#", subtopics, 0); if (fuzzWebsockets && strContains(fuzzFilePathLower, "upgrade")) { diff --git a/mainapp.h b/mainapp.h index a11599e..a738e6d 100644 --- a/mainapp.h +++ b/mainapp.h @@ -76,6 +76,7 @@ class MainApp void wakeUpThread(); void queueKeepAliveCheckAtAllThreads(); void queuePasswordFileReloadAllThreads(); + void queueAuthPluginPeriodicEventAllThreads(); void setFuzzFile(const std::string &fuzzFilePath); void publishStatsOnDollarTopic(); void publishStat(const std::string &topic, uint64_t n); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 91f5897..0fc69fb 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -407,6 +407,7 @@ void MqttPacket::handleDisconnect() void MqttPacket::handleSubscribe() { + this->subtopics = &gSubtopics; const char firstByteFirstNibble = (first_byte & 0x0F); if (firstByteFirstNibble != 2) @@ -431,9 +432,23 @@ void MqttPacket::handleSubscribe() if (qos > 2) throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0."); - logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); - sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, qos); - subs_reponse_codes.push_back(qos); + splitTopic(topic, *subtopics); + if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success) + { + logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos); + subs_reponse_codes.push_back(qos); + } + else + { + logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribe to '%s' denied or failed.", sender->repr().c_str(), topic.c_str()); + + // We can't not send an ack, because if there are multiple subscribes, you send fewer acks back, losing sync. + char return_code = qos; + if (sender->getProtocolVersion() >= ProtocolVersion::Mqtt311) + return_code = static_cast(SubAckReturnCodes::Fail); + subs_reponse_codes.push_back(return_code); + } } SubAck subAck(packet_id, subs_reponse_codes); @@ -531,7 +546,7 @@ void MqttPacket::handlePublish() } } - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write) == AuthResult::success) + if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) { if (retain) { diff --git a/session.cpp b/session.cpp index 421b5a3..3b328f8 100644 --- a/session.cpp +++ b/session.cpp @@ -48,14 +48,20 @@ void Session::assignActiveConnection(std::shared_ptr &client) this->thread = client->getThreadData(); } -void Session::writePacket(const MqttPacket &packet, char max_qos, uint64_t &count) +/** + * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session. + * @param packet + * @param max_qos + * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. + * @param count. Reference value is updated. It's for statistics. + */ +void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count) { assert(max_qos <= 2); + const char qos = std::min(packet.getQos(), max_qos); - if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read) == AuthResult::success) + if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) { - const char qos = std::min(packet.getQos(), max_qos); - if (qos == 0) { if (!clientDisconnected()) diff --git a/session.h b/session.h index ad6d05c..c0faad3 100644 --- a/session.h +++ b/session.h @@ -61,7 +61,7 @@ public: bool clientDisconnected() const; std::shared_ptr makeSharedClient() const; void assignActiveConnection(std::shared_ptr &client); - void writePacket(const MqttPacket &packet, char max_qos, uint64_t &count); + void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); void clearQosMessage(uint16_t packet_id); uint64_t sendPendingQosMessages(); void touch(std::chrono::time_point val); diff --git a/settings.cpp b/settings.cpp index a2a8e03..49b175e 100644 --- a/settings.cpp +++ b/settings.cpp @@ -18,3 +18,12 @@ License along with FlashMQ. If not, see . #include "settings.h" +AuthOptCompatWrap &Settings::getAuthOptsCompat() +{ + return authOptCompatWrap; +} + +std::unordered_map &Settings::getFlashmqAuthPluginOpts() +{ + return this->flashmqAuthPluginOpts; +} diff --git a/settings.h b/settings.h index 525508e..476c3a7 100644 --- a/settings.h +++ b/settings.h @@ -29,6 +29,7 @@ class Settings friend class ConfigFileParser; AuthOptCompatWrap authOptCompatWrap; + std::unordered_map flashmqAuthPluginOpts; public: // Actual config options with their defaults. @@ -47,9 +48,11 @@ public: bool allowAnonymous = false; int rlimitNoFile = 1000000; uint64_t expireSessionsAfterSeconds = 1209600; + int authPluginTimerPeriod = 60; std::list> listeners; // Default one is created later, when none are defined. AuthOptCompatWrap &getAuthOptsCompat(); + std::unordered_map &getFlashmqAuthPluginOpts(); }; #endif // SETTINGS_H diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 97579f3..6061e8d 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -84,10 +84,8 @@ SubscriptionStore::SubscriptionStore() : } -void SubscriptionStore::addSubscription(std::shared_ptr &client, const std::string &topic, char qos) +void SubscriptionStore::addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos) { - const std::list subtopics = split(topic, '/'); - SubscriptionNode *deepestNode = &root; if (topic.length() > 0 && topic[0] == '$') deepestNode = &rootDollar; @@ -242,7 +240,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. { const std::shared_ptr session = session_weak.lock(); - session->writePacket(packet, sub.qos, count); + session->writePacket(packet, sub.qos, false, count); } } } @@ -330,7 +328,7 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptrwritePacket(packet, max_qos, count); + ses->writePacket(packet, max_qos, true, count); } } diff --git a/subscriptionstore.h b/subscriptionstore.h index e1bd106..4339a4f 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -89,7 +89,7 @@ class SubscriptionStore public: SubscriptionStore(); - void addSubscription(std::shared_ptr &client, const std::string &topic, char qos); + void addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos); void removeSubscription(std::shared_ptr &client, const std::string &topic); void registerClientAndKickExistingOne(std::shared_ptr &client); bool sessionPresent(const std::string &clientid); diff --git a/threaddata.cpp b/threaddata.cpp index fb987b1..6bc1891 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -142,6 +142,7 @@ void ThreadData::queueQuit() void ThreadData::waitForQuit() { thread.join(); + authentication.cleanup(); } void ThreadData::queuePasswdFileReload() @@ -210,6 +211,21 @@ uint64_t ThreadData::getSentMessagePerSecond() return result; } +void ThreadData::queueAuthPluginPeriodicEvent() +{ + std::lock_guard locker(taskQueueMutex); + + auto f = std::bind(&ThreadData::authPluginPeriodicEvent, this); + taskQueue.push_front(f); + + wakeUpThread(); +} + +void ThreadData::authPluginPeriodicEvent() +{ + authentication.periodicEvent(); +} + // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? void ThreadData::doKeepAliveCheck() { @@ -266,13 +282,9 @@ void ThreadData::reload(std::shared_ptr settings) authentication.securityCleanup(true); authentication.securityInit(true); } - catch (AuthPluginException &ex) - { - logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what()); - } catch (std::exception &ex) { - logger->logf(LOG_ERR, "Error reloading: %s.", ex.what()); + logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what()); } } diff --git a/threaddata.h b/threaddata.h index b36cbeb..0e09b2a 100644 --- a/threaddata.h +++ b/threaddata.h @@ -101,6 +101,9 @@ public: void incrementSentMessageCount(uint64_t n); uint64_t getSentMessageCount() const; uint64_t getSentMessagePerSecond(); + + void queueAuthPluginPeriodicEvent(); + void authPluginPeriodicEvent(); }; #endif // THREADDATA_H diff --git a/types.cpp b/types.cpp index b3bd2a2..11f1b7a 100644 --- a/types.cpp +++ b/types.cpp @@ -15,6 +15,8 @@ You should have received a copy of the GNU Affero General Public License along with FlashMQ. If not, see . */ +#include "cassert" + #include "types.h" ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) : @@ -29,6 +31,8 @@ ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) : SubAck::SubAck(uint16_t packet_id, const std::list &subs_qos_reponses) : packet_id(packet_id) { + assert(!subs_qos_reponses.empty()); + for (char ack_code : subs_qos_reponses) { responses.push_back(static_cast(ack_code)); -- libgit2 0.21.4