Commit 775180c0c7b55b5c3fde9b9247f623ab4ad327e6
1 parent
be205082
Implement the option to serialize auth checks
Showing
7 changed files
with
86 additions
and
1 deletions
CMakeLists.txt
| @@ -32,6 +32,7 @@ add_executable(FlashMQ | @@ -32,6 +32,7 @@ add_executable(FlashMQ | ||
| 32 | mosquittoauthoptcompatwrap.cpp | 32 | mosquittoauthoptcompatwrap.cpp |
| 33 | settings.cpp | 33 | settings.cpp |
| 34 | listener.cpp | 34 | listener.cpp |
| 35 | + unscopedlock.cpp | ||
| 35 | ) | 36 | ) |
| 36 | 37 | ||
| 37 | target_link_libraries(FlashMQ pthread dl ssl crypto) | 38 | target_link_libraries(FlashMQ pthread dl ssl crypto) |
authplugin.cpp
| @@ -6,6 +6,10 @@ | @@ -6,6 +6,10 @@ | ||
| 6 | #include <dlfcn.h> | 6 | #include <dlfcn.h> |
| 7 | 7 | ||
| 8 | #include "exceptions.h" | 8 | #include "exceptions.h" |
| 9 | +#include "unscopedlock.h" | ||
| 10 | + | ||
| 11 | +std::mutex AuthPlugin::initMutex; | ||
| 12 | +std::mutex AuthPlugin::authChecksMutex; | ||
| 9 | 13 | ||
| 10 | void mosquitto_log_printf(int level, const char *fmt, ...) | 14 | void mosquitto_log_printf(int level, const char *fmt, ...) |
| 11 | { | 15 | { |
| @@ -89,6 +93,10 @@ void AuthPlugin::init() | @@ -89,6 +93,10 @@ void AuthPlugin::init() | ||
| 89 | if (!wanted) | 93 | if (!wanted) |
| 90 | return; | 94 | return; |
| 91 | 95 | ||
| 96 | + UnscopedLock lock(initMutex); | ||
| 97 | + if (settings.authPluginSerializeInit) | ||
| 98 | + lock.lock(); | ||
| 99 | + | ||
| 92 | AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); | 100 | AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); |
| 93 | int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); | 101 | int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); |
| 94 | if (result != 0) | 102 | if (result != 0) |
| @@ -113,6 +121,10 @@ void AuthPlugin::securityInit(bool reloading) | @@ -113,6 +121,10 @@ void AuthPlugin::securityInit(bool reloading) | ||
| 113 | if (!wanted) | 121 | if (!wanted) |
| 114 | return; | 122 | return; |
| 115 | 123 | ||
| 124 | + UnscopedLock lock(initMutex); | ||
| 125 | + if (settings.authPluginSerializeInit) | ||
| 126 | + lock.lock(); | ||
| 127 | + | ||
| 116 | AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); | 128 | AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); |
| 117 | int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); | 129 | int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); |
| 118 | if (result != 0) | 130 | if (result != 0) |
| @@ -148,6 +160,10 @@ AuthResult AuthPlugin::aclCheck(const std::string &clientid, const std::string & | @@ -148,6 +160,10 @@ AuthResult AuthPlugin::aclCheck(const std::string &clientid, const std::string & | ||
| 148 | return AuthResult::error; | 160 | return AuthResult::error; |
| 149 | } | 161 | } |
| 150 | 162 | ||
| 163 | + UnscopedLock lock(authChecksMutex); | ||
| 164 | + if (settings.authPluginSerializeAuthChecks) | ||
| 165 | + lock.lock(); | ||
| 166 | + | ||
| 151 | int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast<int>(access)); | 167 | int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast<int>(access)); |
| 152 | AuthResult result_ = static_cast<AuthResult>(result); | 168 | AuthResult result_ = static_cast<AuthResult>(result); |
| 153 | 169 | ||
| @@ -170,6 +186,10 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string | @@ -170,6 +186,10 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string | ||
| 170 | return AuthResult::error; | 186 | return AuthResult::error; |
| 171 | } | 187 | } |
| 172 | 188 | ||
| 189 | + UnscopedLock lock(authChecksMutex); | ||
| 190 | + if (settings.authPluginSerializeAuthChecks) | ||
| 191 | + lock.lock(); | ||
| 192 | + | ||
| 173 | int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str()); | 193 | int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str()); |
| 174 | AuthResult r = static_cast<AuthResult>(result); | 194 | AuthResult r = static_cast<AuthResult>(result); |
| 175 | 195 |
authplugin.h
| @@ -54,6 +54,9 @@ class AuthPlugin | @@ -54,6 +54,9 @@ class AuthPlugin | ||
| 54 | F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; | 54 | F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; |
| 55 | F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; | 55 | F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; |
| 56 | 56 | ||
| 57 | + static std::mutex initMutex; | ||
| 58 | + static std::mutex authChecksMutex; | ||
| 59 | + | ||
| 57 | Settings &settings; // A ref because I want it to always be the same as the thread's settings | 60 | Settings &settings; // A ref because I want it to always be the same as the thread's settings |
| 58 | 61 | ||
| 59 | void *pluginData = nullptr; | 62 | void *pluginData = nullptr; |
configfileparser.cpp
| @@ -20,7 +20,7 @@ void ConfigFileParser::testKeyValidity(const std::string &key, const std::set<st | @@ -20,7 +20,7 @@ void ConfigFileParser::testKeyValidity(const std::string &key, const std::set<st | ||
| 20 | if (valid_key_it == validKeys.end()) | 20 | if (valid_key_it == validKeys.end()) |
| 21 | { | 21 | { |
| 22 | std::ostringstream oss; | 22 | std::ostringstream oss; |
| 23 | - oss << "Config key '" << key << "' is not valid here."; | 23 | + oss << "Config key '" << key << "' is not valid (here)."; |
| 24 | throw ConfigFileException(oss.str()); | 24 | throw ConfigFileException(oss.str()); |
| 25 | } | 25 | } |
| 26 | } | 26 | } |
| @@ -42,6 +42,8 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : | @@ -42,6 +42,8 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : | ||
| 42 | validKeys.insert("log_file"); | 42 | validKeys.insert("log_file"); |
| 43 | validKeys.insert("allow_unsafe_clientid_chars"); | 43 | validKeys.insert("allow_unsafe_clientid_chars"); |
| 44 | validKeys.insert("allow_unsafe_username_chars"); | 44 | validKeys.insert("allow_unsafe_username_chars"); |
| 45 | + validKeys.insert("auth_plugin_serialize_init"); | ||
| 46 | + validKeys.insert("auth_plugin_serialize_auth_checks"); | ||
| 45 | validKeys.insert("client_initial_buffer_size"); | 47 | validKeys.insert("client_initial_buffer_size"); |
| 46 | validKeys.insert("max_packet_size"); | 48 | validKeys.insert("max_packet_size"); |
| 47 | 49 | ||
| @@ -228,6 +230,18 @@ void ConfigFileParser::loadFile(bool test) | @@ -228,6 +230,18 @@ void ConfigFileParser::loadFile(bool test) | ||
| 228 | tmpSettings->allowUnsafeUsernameChars = tmp; | 230 | tmpSettings->allowUnsafeUsernameChars = tmp; |
| 229 | } | 231 | } |
| 230 | 232 | ||
| 233 | + if (key == "auth_plugin_serialize_init") | ||
| 234 | + { | ||
| 235 | + bool tmp = stringTruthiness(value); | ||
| 236 | + tmpSettings->authPluginSerializeInit = tmp; | ||
| 237 | + } | ||
| 238 | + | ||
| 239 | + if (key == "auth_plugin_serialize_auth_checks") | ||
| 240 | + { | ||
| 241 | + bool tmp = stringTruthiness(value); | ||
| 242 | + tmpSettings->authPluginSerializeAuthChecks = tmp; | ||
| 243 | + } | ||
| 244 | + | ||
| 231 | if (key == "client_initial_buffer_size") | 245 | if (key == "client_initial_buffer_size") |
| 232 | { | 246 | { |
| 233 | int newVal = std::stoi(value); | 247 | int newVal = std::stoi(value); |
settings.h
| @@ -19,6 +19,8 @@ public: | @@ -19,6 +19,8 @@ public: | ||
| 19 | std::string logPath; | 19 | std::string logPath; |
| 20 | bool allowUnsafeClientidChars = false; | 20 | bool allowUnsafeClientidChars = false; |
| 21 | bool allowUnsafeUsernameChars = false; | 21 | bool allowUnsafeUsernameChars = false; |
| 22 | + bool authPluginSerializeInit = false; | ||
| 23 | + bool authPluginSerializeAuthChecks = false; | ||
| 22 | int clientInitialBufferSize = 1024; // Must be power of 2 | 24 | int clientInitialBufferSize = 1024; // Must be power of 2 |
| 23 | int maxPacketSize = 268435461; // 256 MB + 5 | 25 | int maxPacketSize = 268435461; // 256 MB + 5 |
| 24 | std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined. | 26 | std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined. |
unscopedlock.cpp
0 → 100644
| 1 | +#include "unscopedlock.h" | ||
| 2 | + | ||
| 3 | +UnscopedLock::~UnscopedLock() | ||
| 4 | +{ | ||
| 5 | + if (locked) | ||
| 6 | + { | ||
| 7 | + managedMutex.unlock(); | ||
| 8 | + } | ||
| 9 | +} | ||
| 10 | + | ||
| 11 | +UnscopedLock::UnscopedLock(std::mutex &mutex) : | ||
| 12 | + managedMutex(mutex) | ||
| 13 | +{ | ||
| 14 | + | ||
| 15 | +} | ||
| 16 | + | ||
| 17 | +void UnscopedLock::lock() | ||
| 18 | +{ | ||
| 19 | + managedMutex.lock(); | ||
| 20 | + locked = true; | ||
| 21 | +} |
unscopedlock.h
0 → 100644
| 1 | +#ifndef UNSCOPEDLOCK_H | ||
| 2 | +#define UNSCOPEDLOCK_H | ||
| 3 | + | ||
| 4 | +#include <mutex> | ||
| 5 | + | ||
| 6 | +/** | ||
| 7 | + * @brief The UnscopedLock class is a simple variety of the std::lock_guard or std::scoped_lock that allows optional locking using RAII. | ||
| 8 | + * | ||
| 9 | + * STL doesn't provide a similar feature, or am I missing something? You could do it with smart pointers, but I want to avoid having to | ||
| 10 | + * use the free store. | ||
| 11 | + */ | ||
| 12 | +class UnscopedLock | ||
| 13 | +{ | ||
| 14 | + std::mutex &managedMutex; | ||
| 15 | + bool locked = false; | ||
| 16 | + | ||
| 17 | +public: | ||
| 18 | + ~UnscopedLock(); | ||
| 19 | + | ||
| 20 | + UnscopedLock(std::mutex &mutex); | ||
| 21 | + void lock(); | ||
| 22 | +}; | ||
| 23 | + | ||
| 24 | +#endif // UNSCOPEDLOCK_H |