Commit 775180c0c7b55b5c3fde9b9247f623ab4ad327e6

Authored by Wiebe Cazemier
1 parent be205082

Implement the option to serialize auth checks

CMakeLists.txt
@@ -32,6 +32,7 @@ add_executable(FlashMQ @@ -32,6 +32,7 @@ add_executable(FlashMQ
32 mosquittoauthoptcompatwrap.cpp 32 mosquittoauthoptcompatwrap.cpp
33 settings.cpp 33 settings.cpp
34 listener.cpp 34 listener.cpp
  35 + unscopedlock.cpp
35 ) 36 )
36 37
37 target_link_libraries(FlashMQ pthread dl ssl crypto) 38 target_link_libraries(FlashMQ pthread dl ssl crypto)
authplugin.cpp
@@ -6,6 +6,10 @@ @@ -6,6 +6,10 @@
6 #include <dlfcn.h> 6 #include <dlfcn.h>
7 7
8 #include "exceptions.h" 8 #include "exceptions.h"
  9 +#include "unscopedlock.h"
  10 +
  11 +std::mutex AuthPlugin::initMutex;
  12 +std::mutex AuthPlugin::authChecksMutex;
9 13
10 void mosquitto_log_printf(int level, const char *fmt, ...) 14 void mosquitto_log_printf(int level, const char *fmt, ...)
11 { 15 {
@@ -89,6 +93,10 @@ void AuthPlugin::init() @@ -89,6 +93,10 @@ void AuthPlugin::init()
89 if (!wanted) 93 if (!wanted)
90 return; 94 return;
91 95
  96 + UnscopedLock lock(initMutex);
  97 + if (settings.authPluginSerializeInit)
  98 + lock.lock();
  99 +
92 AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); 100 AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
93 int result = init_v2(&pluginData, authOpts.head(), authOpts.size()); 101 int result = init_v2(&pluginData, authOpts.head(), authOpts.size());
94 if (result != 0) 102 if (result != 0)
@@ -113,6 +121,10 @@ void AuthPlugin::securityInit(bool reloading) @@ -113,6 +121,10 @@ void AuthPlugin::securityInit(bool reloading)
113 if (!wanted) 121 if (!wanted)
114 return; 122 return;
115 123
  124 + UnscopedLock lock(initMutex);
  125 + if (settings.authPluginSerializeInit)
  126 + lock.lock();
  127 +
116 AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); 128 AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
117 int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); 129 int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
118 if (result != 0) 130 if (result != 0)
@@ -148,6 +160,10 @@ AuthResult AuthPlugin::aclCheck(const std::string &amp;clientid, const std::string &amp; @@ -148,6 +160,10 @@ AuthResult AuthPlugin::aclCheck(const std::string &amp;clientid, const std::string &amp;
148 return AuthResult::error; 160 return AuthResult::error;
149 } 161 }
150 162
  163 + UnscopedLock lock(authChecksMutex);
  164 + if (settings.authPluginSerializeAuthChecks)
  165 + lock.lock();
  166 +
151 int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast<int>(access)); 167 int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast<int>(access));
152 AuthResult result_ = static_cast<AuthResult>(result); 168 AuthResult result_ = static_cast<AuthResult>(result);
153 169
@@ -170,6 +186,10 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &amp;username, const std::string @@ -170,6 +186,10 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &amp;username, const std::string
170 return AuthResult::error; 186 return AuthResult::error;
171 } 187 }
172 188
  189 + UnscopedLock lock(authChecksMutex);
  190 + if (settings.authPluginSerializeAuthChecks)
  191 + lock.lock();
  192 +
173 int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str()); 193 int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str());
174 AuthResult r = static_cast<AuthResult>(result); 194 AuthResult r = static_cast<AuthResult>(result);
175 195
authplugin.h
@@ -54,6 +54,9 @@ class AuthPlugin @@ -54,6 +54,9 @@ class AuthPlugin
54 F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; 54 F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr;
55 F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; 55 F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr;
56 56
  57 + static std::mutex initMutex;
  58 + static std::mutex authChecksMutex;
  59 +
57 Settings &settings; // A ref because I want it to always be the same as the thread's settings 60 Settings &settings; // A ref because I want it to always be the same as the thread's settings
58 61
59 void *pluginData = nullptr; 62 void *pluginData = nullptr;
configfileparser.cpp
@@ -20,7 +20,7 @@ void ConfigFileParser::testKeyValidity(const std::string &amp;key, const std::set&lt;st @@ -20,7 +20,7 @@ void ConfigFileParser::testKeyValidity(const std::string &amp;key, const std::set&lt;st
20 if (valid_key_it == validKeys.end()) 20 if (valid_key_it == validKeys.end())
21 { 21 {
22 std::ostringstream oss; 22 std::ostringstream oss;
23 - oss << "Config key '" << key << "' is not valid here."; 23 + oss << "Config key '" << key << "' is not valid (here).";
24 throw ConfigFileException(oss.str()); 24 throw ConfigFileException(oss.str());
25 } 25 }
26 } 26 }
@@ -42,6 +42,8 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) : @@ -42,6 +42,8 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) :
42 validKeys.insert("log_file"); 42 validKeys.insert("log_file");
43 validKeys.insert("allow_unsafe_clientid_chars"); 43 validKeys.insert("allow_unsafe_clientid_chars");
44 validKeys.insert("allow_unsafe_username_chars"); 44 validKeys.insert("allow_unsafe_username_chars");
  45 + validKeys.insert("auth_plugin_serialize_init");
  46 + validKeys.insert("auth_plugin_serialize_auth_checks");
45 validKeys.insert("client_initial_buffer_size"); 47 validKeys.insert("client_initial_buffer_size");
46 validKeys.insert("max_packet_size"); 48 validKeys.insert("max_packet_size");
47 49
@@ -228,6 +230,18 @@ void ConfigFileParser::loadFile(bool test) @@ -228,6 +230,18 @@ void ConfigFileParser::loadFile(bool test)
228 tmpSettings->allowUnsafeUsernameChars = tmp; 230 tmpSettings->allowUnsafeUsernameChars = tmp;
229 } 231 }
230 232
  233 + if (key == "auth_plugin_serialize_init")
  234 + {
  235 + bool tmp = stringTruthiness(value);
  236 + tmpSettings->authPluginSerializeInit = tmp;
  237 + }
  238 +
  239 + if (key == "auth_plugin_serialize_auth_checks")
  240 + {
  241 + bool tmp = stringTruthiness(value);
  242 + tmpSettings->authPluginSerializeAuthChecks = tmp;
  243 + }
  244 +
231 if (key == "client_initial_buffer_size") 245 if (key == "client_initial_buffer_size")
232 { 246 {
233 int newVal = std::stoi(value); 247 int newVal = std::stoi(value);
settings.h
@@ -19,6 +19,8 @@ public: @@ -19,6 +19,8 @@ public:
19 std::string logPath; 19 std::string logPath;
20 bool allowUnsafeClientidChars = false; 20 bool allowUnsafeClientidChars = false;
21 bool allowUnsafeUsernameChars = false; 21 bool allowUnsafeUsernameChars = false;
  22 + bool authPluginSerializeInit = false;
  23 + bool authPluginSerializeAuthChecks = false;
22 int clientInitialBufferSize = 1024; // Must be power of 2 24 int clientInitialBufferSize = 1024; // Must be power of 2
23 int maxPacketSize = 268435461; // 256 MB + 5 25 int maxPacketSize = 268435461; // 256 MB + 5
24 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined. 26 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined.
unscopedlock.cpp 0 → 100644
  1 +#include "unscopedlock.h"
  2 +
  3 +UnscopedLock::~UnscopedLock()
  4 +{
  5 + if (locked)
  6 + {
  7 + managedMutex.unlock();
  8 + }
  9 +}
  10 +
  11 +UnscopedLock::UnscopedLock(std::mutex &mutex) :
  12 + managedMutex(mutex)
  13 +{
  14 +
  15 +}
  16 +
  17 +void UnscopedLock::lock()
  18 +{
  19 + managedMutex.lock();
  20 + locked = true;
  21 +}
unscopedlock.h 0 → 100644
  1 +#ifndef UNSCOPEDLOCK_H
  2 +#define UNSCOPEDLOCK_H
  3 +
  4 +#include <mutex>
  5 +
  6 +/**
  7 + * @brief The UnscopedLock class is a simple variety of the std::lock_guard or std::scoped_lock that allows optional locking using RAII.
  8 + *
  9 + * STL doesn't provide a similar feature, or am I missing something? You could do it with smart pointers, but I want to avoid having to
  10 + * use the free store.
  11 + */
  12 +class UnscopedLock
  13 +{
  14 + std::mutex &managedMutex;
  15 + bool locked = false;
  16 +
  17 +public:
  18 + ~UnscopedLock();
  19 +
  20 + UnscopedLock(std::mutex &mutex);
  21 + void lock();
  22 +};
  23 +
  24 +#endif // UNSCOPEDLOCK_H