Commit 775180c0c7b55b5c3fde9b9247f623ab4ad327e6

Authored by Wiebe Cazemier
1 parent be205082

Implement the option to serialize auth checks

CMakeLists.txt
... ... @@ -32,6 +32,7 @@ add_executable(FlashMQ
32 32 mosquittoauthoptcompatwrap.cpp
33 33 settings.cpp
34 34 listener.cpp
  35 + unscopedlock.cpp
35 36 )
36 37  
37 38 target_link_libraries(FlashMQ pthread dl ssl crypto)
... ...
authplugin.cpp
... ... @@ -6,6 +6,10 @@
6 6 #include <dlfcn.h>
7 7  
8 8 #include "exceptions.h"
  9 +#include "unscopedlock.h"
  10 +
  11 +std::mutex AuthPlugin::initMutex;
  12 +std::mutex AuthPlugin::authChecksMutex;
9 13  
10 14 void mosquitto_log_printf(int level, const char *fmt, ...)
11 15 {
... ... @@ -89,6 +93,10 @@ void AuthPlugin::init()
89 93 if (!wanted)
90 94 return;
91 95  
  96 + UnscopedLock lock(initMutex);
  97 + if (settings.authPluginSerializeInit)
  98 + lock.lock();
  99 +
92 100 AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
93 101 int result = init_v2(&pluginData, authOpts.head(), authOpts.size());
94 102 if (result != 0)
... ... @@ -113,6 +121,10 @@ void AuthPlugin::securityInit(bool reloading)
113 121 if (!wanted)
114 122 return;
115 123  
  124 + UnscopedLock lock(initMutex);
  125 + if (settings.authPluginSerializeInit)
  126 + lock.lock();
  127 +
116 128 AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
117 129 int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
118 130 if (result != 0)
... ... @@ -148,6 +160,10 @@ AuthResult AuthPlugin::aclCheck(const std::string &amp;clientid, const std::string &amp;
148 160 return AuthResult::error;
149 161 }
150 162  
  163 + UnscopedLock lock(authChecksMutex);
  164 + if (settings.authPluginSerializeAuthChecks)
  165 + lock.lock();
  166 +
151 167 int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast<int>(access));
152 168 AuthResult result_ = static_cast<AuthResult>(result);
153 169  
... ... @@ -170,6 +186,10 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &amp;username, const std::string
170 186 return AuthResult::error;
171 187 }
172 188  
  189 + UnscopedLock lock(authChecksMutex);
  190 + if (settings.authPluginSerializeAuthChecks)
  191 + lock.lock();
  192 +
173 193 int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str());
174 194 AuthResult r = static_cast<AuthResult>(result);
175 195  
... ...
authplugin.h
... ... @@ -54,6 +54,9 @@ class AuthPlugin
54 54 F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr;
55 55 F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr;
56 56  
  57 + static std::mutex initMutex;
  58 + static std::mutex authChecksMutex;
  59 +
57 60 Settings &settings; // A ref because I want it to always be the same as the thread's settings
58 61  
59 62 void *pluginData = nullptr;
... ...
configfileparser.cpp
... ... @@ -20,7 +20,7 @@ void ConfigFileParser::testKeyValidity(const std::string &amp;key, const std::set&lt;st
20 20 if (valid_key_it == validKeys.end())
21 21 {
22 22 std::ostringstream oss;
23   - oss << "Config key '" << key << "' is not valid here.";
  23 + oss << "Config key '" << key << "' is not valid (here).";
24 24 throw ConfigFileException(oss.str());
25 25 }
26 26 }
... ... @@ -42,6 +42,8 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) :
42 42 validKeys.insert("log_file");
43 43 validKeys.insert("allow_unsafe_clientid_chars");
44 44 validKeys.insert("allow_unsafe_username_chars");
  45 + validKeys.insert("auth_plugin_serialize_init");
  46 + validKeys.insert("auth_plugin_serialize_auth_checks");
45 47 validKeys.insert("client_initial_buffer_size");
46 48 validKeys.insert("max_packet_size");
47 49  
... ... @@ -228,6 +230,18 @@ void ConfigFileParser::loadFile(bool test)
228 230 tmpSettings->allowUnsafeUsernameChars = tmp;
229 231 }
230 232  
  233 + if (key == "auth_plugin_serialize_init")
  234 + {
  235 + bool tmp = stringTruthiness(value);
  236 + tmpSettings->authPluginSerializeInit = tmp;
  237 + }
  238 +
  239 + if (key == "auth_plugin_serialize_auth_checks")
  240 + {
  241 + bool tmp = stringTruthiness(value);
  242 + tmpSettings->authPluginSerializeAuthChecks = tmp;
  243 + }
  244 +
231 245 if (key == "client_initial_buffer_size")
232 246 {
233 247 int newVal = std::stoi(value);
... ...
settings.h
... ... @@ -19,6 +19,8 @@ public:
19 19 std::string logPath;
20 20 bool allowUnsafeClientidChars = false;
21 21 bool allowUnsafeUsernameChars = false;
  22 + bool authPluginSerializeInit = false;
  23 + bool authPluginSerializeAuthChecks = false;
22 24 int clientInitialBufferSize = 1024; // Must be power of 2
23 25 int maxPacketSize = 268435461; // 256 MB + 5
24 26 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined.
... ...
unscopedlock.cpp 0 → 100644
  1 +#include "unscopedlock.h"
  2 +
  3 +UnscopedLock::~UnscopedLock()
  4 +{
  5 + if (locked)
  6 + {
  7 + managedMutex.unlock();
  8 + }
  9 +}
  10 +
  11 +UnscopedLock::UnscopedLock(std::mutex &mutex) :
  12 + managedMutex(mutex)
  13 +{
  14 +
  15 +}
  16 +
  17 +void UnscopedLock::lock()
  18 +{
  19 + managedMutex.lock();
  20 + locked = true;
  21 +}
... ...
unscopedlock.h 0 → 100644
  1 +#ifndef UNSCOPEDLOCK_H
  2 +#define UNSCOPEDLOCK_H
  3 +
  4 +#include <mutex>
  5 +
  6 +/**
  7 + * @brief The UnscopedLock class is a simple variety of the std::lock_guard or std::scoped_lock that allows optional locking using RAII.
  8 + *
  9 + * STL doesn't provide a similar feature, or am I missing something? You could do it with smart pointers, but I want to avoid having to
  10 + * use the free store.
  11 + */
  12 +class UnscopedLock
  13 +{
  14 + std::mutex &managedMutex;
  15 + bool locked = false;
  16 +
  17 +public:
  18 + ~UnscopedLock();
  19 +
  20 + UnscopedLock(std::mutex &mutex);
  21 + void lock();
  22 +};
  23 +
  24 +#endif // UNSCOPEDLOCK_H
... ...