diff --git a/CMakeLists.txt b/CMakeLists.txt index 758c80b..f13c2bf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,6 +32,7 @@ add_executable(FlashMQ mosquittoauthoptcompatwrap.cpp settings.cpp listener.cpp + unscopedlock.cpp ) target_link_libraries(FlashMQ pthread dl ssl crypto) diff --git a/authplugin.cpp b/authplugin.cpp index f13ec03..8ef2767 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -6,6 +6,10 @@ #include #include "exceptions.h" +#include "unscopedlock.h" + +std::mutex AuthPlugin::initMutex; +std::mutex AuthPlugin::authChecksMutex; void mosquitto_log_printf(int level, const char *fmt, ...) { @@ -89,6 +93,10 @@ void AuthPlugin::init() if (!wanted) return; + UnscopedLock lock(initMutex); + if (settings.authPluginSerializeInit) + lock.lock(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); if (result != 0) @@ -113,6 +121,10 @@ void AuthPlugin::securityInit(bool reloading) if (!wanted) return; + UnscopedLock lock(initMutex); + if (settings.authPluginSerializeInit) + lock.lock(); + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); if (result != 0) @@ -148,6 +160,10 @@ AuthResult AuthPlugin::aclCheck(const std::string &clientid, const std::string & return AuthResult::error; } + UnscopedLock lock(authChecksMutex); + if (settings.authPluginSerializeAuthChecks) + lock.lock(); + int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast(access)); AuthResult result_ = static_cast(result); @@ -170,6 +186,10 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string return AuthResult::error; } + UnscopedLock lock(authChecksMutex); + if (settings.authPluginSerializeAuthChecks) + lock.lock(); + int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str()); AuthResult r = static_cast(result); diff --git a/authplugin.h b/authplugin.h index 329d194..ec4039a 100644 --- a/authplugin.h +++ b/authplugin.h @@ -54,6 +54,9 @@ class AuthPlugin F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; + static std::mutex initMutex; + static std::mutex authChecksMutex; + Settings &settings; // A ref because I want it to always be the same as the thread's settings void *pluginData = nullptr; diff --git a/configfileparser.cpp b/configfileparser.cpp index 0d8cd07..57cacb0 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -20,7 +20,7 @@ void ConfigFileParser::testKeyValidity(const std::string &key, const std::setallowUnsafeUsernameChars = tmp; } + if (key == "auth_plugin_serialize_init") + { + bool tmp = stringTruthiness(value); + tmpSettings->authPluginSerializeInit = tmp; + } + + if (key == "auth_plugin_serialize_auth_checks") + { + bool tmp = stringTruthiness(value); + tmpSettings->authPluginSerializeAuthChecks = tmp; + } + if (key == "client_initial_buffer_size") { int newVal = std::stoi(value); diff --git a/settings.h b/settings.h index 687fc70..f1b8e00 100644 --- a/settings.h +++ b/settings.h @@ -19,6 +19,8 @@ public: std::string logPath; bool allowUnsafeClientidChars = false; bool allowUnsafeUsernameChars = false; + bool authPluginSerializeInit = false; + bool authPluginSerializeAuthChecks = false; int clientInitialBufferSize = 1024; // Must be power of 2 int maxPacketSize = 268435461; // 256 MB + 5 std::list> listeners; // Default one is created later, when none are defined. diff --git a/unscopedlock.cpp b/unscopedlock.cpp new file mode 100644 index 0000000..a2a01d7 --- /dev/null +++ b/unscopedlock.cpp @@ -0,0 +1,21 @@ +#include "unscopedlock.h" + +UnscopedLock::~UnscopedLock() +{ + if (locked) + { + managedMutex.unlock(); + } +} + +UnscopedLock::UnscopedLock(std::mutex &mutex) : + managedMutex(mutex) +{ + +} + +void UnscopedLock::lock() +{ + managedMutex.lock(); + locked = true; +} diff --git a/unscopedlock.h b/unscopedlock.h new file mode 100644 index 0000000..d97d182 --- /dev/null +++ b/unscopedlock.h @@ -0,0 +1,24 @@ +#ifndef UNSCOPEDLOCK_H +#define UNSCOPEDLOCK_H + +#include + +/** + * @brief The UnscopedLock class is a simple variety of the std::lock_guard or std::scoped_lock that allows optional locking using RAII. + * + * STL doesn't provide a similar feature, or am I missing something? You could do it with smart pointers, but I want to avoid having to + * use the free store. + */ +class UnscopedLock +{ + std::mutex &managedMutex; + bool locked = false; + +public: + ~UnscopedLock(); + + UnscopedLock(std::mutex &mutex); + void lock(); +}; + +#endif // UNSCOPEDLOCK_H