Commit aa19d41b6524de1e1bfff40a32e4f8344b0a5eeb

Authored by Wiebe Cazemier
1 parent e0d2e93c

Implement a native FlashMQ auth plugin

CMakeLists.txt
@@ -49,6 +49,7 @@ add_executable(FlashMQ @@ -49,6 +49,7 @@ add_executable(FlashMQ
49 acltree.h 49 acltree.h
50 enums.h 50 enums.h
51 threadlocalutils.h 51 threadlocalutils.h
  52 + flashmq_plugin.h
52 53
53 mainapp.cpp 54 mainapp.cpp
54 main.cpp 55 main.cpp
@@ -79,6 +80,7 @@ add_executable(FlashMQ @@ -79,6 +80,7 @@ add_executable(FlashMQ
79 evpencodectxmanager.cpp 80 evpencodectxmanager.cpp
80 acltree.cpp 81 acltree.cpp
81 threadlocalutils.cpp 82 threadlocalutils.cpp
  83 + flashmq_plugin.cpp
82 84
83 ) 85 )
84 86
authplugin.cpp
@@ -74,11 +74,11 @@ Authentication::~Authentication() @@ -74,11 +74,11 @@ Authentication::~Authentication()
74 EVP_MD_CTX_free(mosquittoDigestContext); 74 EVP_MD_CTX_free(mosquittoDigestContext);
75 } 75 }
76 76
77 -void *Authentication::loadSymbol(void *handle, const char *symbol) const 77 +void *Authentication::loadSymbol(void *handle, const char *symbol, bool exceptionOnError) const
78 { 78 {
79 void *r = dlsym(handle, symbol); 79 void *r = dlsym(handle, symbol);
80 80
81 - if (r == NULL) 81 + if (r == NULL && exceptionOnError)
82 { 82 {
83 std::string errmsg(dlerror()); 83 std::string errmsg(dlerror());
84 throw FatalError(errmsg); 84 throw FatalError(errmsg);
@@ -95,7 +95,7 @@ void Authentication::loadPlugin(const std::string &pathToSoFile) @@ -95,7 +95,7 @@ void Authentication::loadPlugin(const std::string &pathToSoFile)
95 logger->logf(LOG_NOTICE, "Loading auth plugin %s", pathToSoFile.c_str()); 95 logger->logf(LOG_NOTICE, "Loading auth plugin %s", pathToSoFile.c_str());
96 96
97 initialized = false; 97 initialized = false;
98 - useExternalPlugin = true; 98 + pluginVersion = PluginVersion::Determining;
99 99
100 if (access(pathToSoFile.c_str(), R_OK) != 0) 100 if (access(pathToSoFile.c_str(), R_OK) != 0)
101 { 101 {
@@ -112,20 +112,41 @@ void Authentication::loadPlugin(const std::string &pathToSoFile) @@ -112,20 +112,41 @@ void Authentication::loadPlugin(const std::string &pathToSoFile)
112 throw FatalError(errmsg); 112 throw FatalError(errmsg);
113 } 113 }
114 114
115 - version = (F_auth_plugin_version)loadSymbol(r, "mosquitto_auth_plugin_version");  
116 -  
117 - if (version() != 2) 115 + version = (F_auth_plugin_version)loadSymbol(r, "mosquitto_auth_plugin_version", false);
  116 + if (version != nullptr)
118 { 117 {
119 - throw FatalError("Only Mosquitto plugin version 2 is supported at this time."); 118 + if (version() != 2)
  119 + {
  120 + throw FatalError("Only Mosquitto plugin version 2 is supported at this time.");
  121 + }
  122 +
  123 + pluginVersion = PluginVersion::MosquittoV2;
  124 +
  125 + init_v2 = (F_auth_plugin_init_v2)loadSymbol(r, "mosquitto_auth_plugin_init");
  126 + cleanup_v2 = (F_auth_plugin_cleanup_v2)loadSymbol(r, "mosquitto_auth_plugin_cleanup");
  127 + security_init_v2 = (F_auth_plugin_security_init_v2)loadSymbol(r, "mosquitto_auth_security_init");
  128 + security_cleanup_v2 = (F_auth_plugin_security_cleanup_v2)loadSymbol(r, "mosquitto_auth_security_cleanup");
  129 + acl_check_v2 = (F_auth_plugin_acl_check_v2)loadSymbol(r, "mosquitto_auth_acl_check");
  130 + unpwd_check_v2 = (F_auth_plugin_unpwd_check_v2)loadSymbol(r, "mosquitto_auth_unpwd_check");
  131 + psk_key_get_v2 = (F_auth_plugin_psk_key_get_v2)loadSymbol(r, "mosquitto_auth_psk_key_get");
120 } 132 }
  133 + else if ((version = (F_auth_plugin_version)loadSymbol(r, "flashmq_auth_plugin_version", false)) != nullptr)
  134 + {
  135 + if (version() != 1)
  136 + {
  137 + throw FatalError("FlashMQ plugin only supports version 1.");
  138 + }
  139 +
  140 + pluginVersion = PluginVersion::FlashMQv1;
121 141
122 - init_v2 = (F_auth_plugin_init_v2)loadSymbol(r, "mosquitto_auth_plugin_init");  
123 - cleanup_v2 = (F_auth_plugin_cleanup_v2)loadSymbol(r, "mosquitto_auth_plugin_cleanup");  
124 - security_init_v2 = (F_auth_plugin_security_init_v2)loadSymbol(r, "mosquitto_auth_security_init");  
125 - security_cleanup_v2 = (F_auth_plugin_security_cleanup_v2)loadSymbol(r, "mosquitto_auth_security_cleanup");  
126 - acl_check_v2 = (F_auth_plugin_acl_check_v2)loadSymbol(r, "mosquitto_auth_acl_check");  
127 - unpwd_check_v2 = (F_auth_plugin_unpwd_check_v2)loadSymbol(r, "mosquitto_auth_unpwd_check");  
128 - psk_key_get_v2 = (F_auth_plugin_psk_key_get_v2)loadSymbol(r, "mosquitto_auth_psk_key_get"); 142 + flashmq_auth_plugin_allocate_thread_memory_v1 = (F_flashmq_auth_plugin_allocate_thread_memory_v1)loadSymbol(r, "flashmq_auth_plugin_allocate_thread_memory");
  143 + flashmq_auth_plugin_deallocate_thread_memory_v1 = (F_flashmq_auth_plugin_deallocate_thread_memory_v1)loadSymbol(r, "flashmq_auth_plugin_deallocate_thread_memory");
  144 + flashmq_auth_plugin_init_v1 = (F_flashmq_auth_plugin_init_v1)loadSymbol(r, "flashmq_auth_plugin_init");
  145 + flashmq_auth_plugin_deinit_v1 = (F_flashmq_auth_plugin_deinit_v1)loadSymbol(r, "flashmq_auth_plugin_deinit");
  146 + flashmq_auth_plugin_acl_check_v1 = (F_flashmq_auth_plugin_acl_check_v1)loadSymbol(r, "flashmq_auth_plugin_acl_check");
  147 + flashmq_auth_plugin_login_check_v1 = (F_flashmq_auth_plugin_login_check_v1)loadSymbol(r, "flashmq_auth_plugin_login_check");
  148 + flashmq_auth_plugin_periodic_event_v1 = (F_flashmq_auth_plugin_periodic_event)loadSymbol(r, "flashmq_auth_plugin_periodic_event", false);
  149 + }
129 150
130 initialized = true; 151 initialized = true;
131 } 152 }
@@ -136,7 +157,7 @@ void Authentication::loadPlugin(const std::string &pathToSoFile) @@ -136,7 +157,7 @@ void Authentication::loadPlugin(const std::string &pathToSoFile)
136 */ 157 */
137 void Authentication::init() 158 void Authentication::init()
138 { 159 {
139 - if (!useExternalPlugin) 160 + if (pluginVersion == PluginVersion::None)
140 return; 161 return;
141 162
142 UnscopedLock lock(initMutex); 163 UnscopedLock lock(initMutex);
@@ -146,23 +167,46 @@ void Authentication::init() @@ -146,23 +167,46 @@ void Authentication::init()
146 if (quitting) 167 if (quitting)
147 return; 168 return;
148 169
149 - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();  
150 - int result = init_v2(&pluginData, authOpts.head(), authOpts.size());  
151 - if (result != 0)  
152 - throw FatalError("Error initialising auth plugin."); 170 + if (pluginVersion == PluginVersion::MosquittoV2)
  171 + {
  172 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
  173 + int result = init_v2(&pluginData, authOpts.head(), authOpts.size());
  174 + if (result != 0)
  175 + throw FatalError("Error initialising auth plugin.");
  176 + }
  177 + else if (pluginVersion == PluginVersion::FlashMQv1)
  178 + {
  179 + std::unordered_map<std::string, std::string> &authOpts = settings.getFlashmqAuthPluginOpts();
  180 + flashmq_auth_plugin_allocate_thread_memory_v1(&pluginData, authOpts);
  181 + }
153 } 182 }
154 183
155 void Authentication::cleanup() 184 void Authentication::cleanup()
156 { 185 {
157 - if (!cleanup_v2) 186 + if (pluginVersion == PluginVersion::None)
158 return; 187 return;
159 188
160 securityCleanup(false); 189 securityCleanup(false);
161 190
162 - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();  
163 - int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size());  
164 - if (result != 0)  
165 - logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway. 191 + if (pluginVersion == PluginVersion::MosquittoV2)
  192 + {
  193 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
  194 + int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size());
  195 + if (result != 0)
  196 + logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway.
  197 + }
  198 + else if (pluginVersion == PluginVersion::FlashMQv1)
  199 + {
  200 + try
  201 + {
  202 + std::unordered_map<std::string, std::string> &authOpts = settings.getFlashmqAuthPluginOpts();
  203 + flashmq_auth_plugin_deallocate_thread_memory_v1(pluginData, authOpts);
  204 + }
  205 + catch (std::exception &ex)
  206 + {
  207 + logger->logf(LOG_ERR, "Error cleaning up auth plugin: '%s'", ex.what()); // Not doing exception, because we're shutting down anyway.
  208 + }
  209 + }
166 } 210 }
167 211
168 /** 212 /**
@@ -171,7 +215,7 @@ void Authentication::cleanup() @@ -171,7 +215,7 @@ void Authentication::cleanup()
171 */ 215 */
172 void Authentication::securityInit(bool reloading) 216 void Authentication::securityInit(bool reloading)
173 { 217 {
174 - if (!useExternalPlugin) 218 + if (pluginVersion == PluginVersion::None)
175 return; 219 return;
176 220
177 UnscopedLock lock(initMutex); 221 UnscopedLock lock(initMutex);
@@ -181,31 +225,52 @@ void Authentication::securityInit(bool reloading) @@ -181,31 +225,52 @@ void Authentication::securityInit(bool reloading)
181 if (quitting) 225 if (quitting)
182 return; 226 return;
183 227
184 - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();  
185 - int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading);  
186 - if (result != 0) 228 + if (pluginVersion == PluginVersion::MosquittoV2)
  229 + {
  230 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
  231 + int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
  232 + if (result != 0)
  233 + {
  234 + throw AuthPluginException("Plugin function mosquitto_auth_security_init returned an error. If it didn't log anything, we don't know what it was.");
  235 + }
  236 + }
  237 + else if (pluginVersion == PluginVersion::FlashMQv1)
187 { 238 {
188 - throw AuthPluginException("Plugin function mosquitto_auth_security_init returned an error. If it didn't log anything, we don't know what it was."); 239 + std::unordered_map<std::string, std::string> &authOpts = settings.getFlashmqAuthPluginOpts();
  240 + flashmq_auth_plugin_init_v1(pluginData, authOpts, reloading);
189 } 241 }
  242 +
190 initialized = true; 243 initialized = true;
  244 +
  245 + periodicEvent();
191 } 246 }
192 247
193 void Authentication::securityCleanup(bool reloading) 248 void Authentication::securityCleanup(bool reloading)
194 { 249 {
195 - if (!useExternalPlugin) 250 + if (pluginVersion == PluginVersion::None)
196 return; 251 return;
197 252
198 initialized = false; 253 initialized = false;
199 - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();  
200 - int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading);  
201 254
202 - if (result != 0) 255 + if (pluginVersion == PluginVersion::MosquittoV2)
  256 + {
  257 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
  258 + int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
  259 +
  260 + if (result != 0)
  261 + {
  262 + throw AuthPluginException("Plugin function mosquitto_auth_security_cleanup returned an error. If it didn't log anything, we don't know what it was.");
  263 + }
  264 + }
  265 + else if (pluginVersion == PluginVersion::FlashMQv1)
203 { 266 {
204 - throw AuthPluginException("Plugin function mosquitto_auth_security_cleanup returned an error. If it didn't log anything, we don't know what it was."); 267 + std::unordered_map<std::string, std::string> &authOpts = settings.getFlashmqAuthPluginOpts();
  268 + flashmq_auth_plugin_deinit_v1(pluginData, authOpts, reloading);
205 } 269 }
206 } 270 }
207 271
208 -AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, AclAccess access) 272 +AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics,
  273 + AclAccess access, char qos, bool retain)
209 { 274 {
210 assert(subtopics.size() > 0); 275 assert(subtopics.size() > 0);
211 276
@@ -214,7 +279,7 @@ AuthResult Authentication::aclCheck(const std::string &amp;clientid, const std::stri @@ -214,7 +279,7 @@ AuthResult Authentication::aclCheck(const std::string &amp;clientid, const std::stri
214 if (firstResult != AuthResult::success) 279 if (firstResult != AuthResult::success)
215 return firstResult; 280 return firstResult;
216 281
217 - if (!useExternalPlugin) 282 + if (pluginVersion == PluginVersion::None)
218 return firstResult; 283 return firstResult;
219 284
220 if (!initialized) 285 if (!initialized)
@@ -227,15 +292,33 @@ AuthResult Authentication::aclCheck(const std::string &amp;clientid, const std::stri @@ -227,15 +292,33 @@ AuthResult Authentication::aclCheck(const std::string &amp;clientid, const std::stri
227 if (settings.authPluginSerializeAuthChecks) 292 if (settings.authPluginSerializeAuthChecks)
228 lock.lock(); 293 lock.lock();
229 294
230 - int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast<int>(access));  
231 - AuthResult result_ = static_cast<AuthResult>(result); 295 + if (pluginVersion == PluginVersion::MosquittoV2)
  296 + {
  297 + int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast<int>(access));
  298 + AuthResult result_ = static_cast<AuthResult>(result);
232 299
233 - if (result_ == AuthResult::error) 300 + if (result_ == AuthResult::error)
  301 + {
  302 + logger->logf(LOG_ERR, "ACL check by plugin returned error for topic '%s'. If it didn't log anything, we don't know what it was.", topic.c_str());
  303 + }
  304 + }
  305 + else if (pluginVersion == PluginVersion::FlashMQv1)
234 { 306 {
235 - logger->logf(LOG_ERR, "ACL check by plugin returned error for topic '%s'. If it didn't log anything, we don't know what it was.", topic.c_str()); 307 + // I'm using this try/catch because propagating the exception higher up conflicts with who gets the blame, and then the publisher
  308 + // gets disconnected.
  309 + try
  310 + {
  311 + FlashMQMessage msg(topic, subtopics, qos, retain);
  312 + return flashmq_auth_plugin_acl_check_v1(pluginData, access, clientid, username, msg);
  313 + }
  314 + catch (std::exception &ex)
  315 + {
  316 + logger->logf(LOG_ERR, "Error doing ACL check in plugin: '%s'", ex.what());
  317 + logger->logf(LOG_WARNING, "Throwing exceptions from auth plugin login/ACL checks is slow. There's no need.");
  318 + }
236 } 319 }
237 320
238 - return result_; 321 + return AuthResult::error;
239 } 322 }
240 323
241 AuthResult Authentication::unPwdCheck(const std::string &username, const std::string &password) 324 AuthResult Authentication::unPwdCheck(const std::string &username, const std::string &password)
@@ -245,7 +328,7 @@ AuthResult Authentication::unPwdCheck(const std::string &amp;username, const std::st @@ -245,7 +328,7 @@ AuthResult Authentication::unPwdCheck(const std::string &amp;username, const std::st
245 if (firstResult != AuthResult::success) 328 if (firstResult != AuthResult::success)
246 return firstResult; 329 return firstResult;
247 330
248 - if (!useExternalPlugin) 331 + if (pluginVersion == PluginVersion::None)
249 return firstResult; 332 return firstResult;
250 333
251 if (!initialized) 334 if (!initialized)
@@ -258,15 +341,32 @@ AuthResult Authentication::unPwdCheck(const std::string &amp;username, const std::st @@ -258,15 +341,32 @@ AuthResult Authentication::unPwdCheck(const std::string &amp;username, const std::st
258 if (settings.authPluginSerializeAuthChecks) 341 if (settings.authPluginSerializeAuthChecks)
259 lock.lock(); 342 lock.lock();
260 343
261 - int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str());  
262 - AuthResult r = static_cast<AuthResult>(result); 344 + if (pluginVersion == PluginVersion::MosquittoV2)
  345 + {
  346 + int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str());
  347 + AuthResult r = static_cast<AuthResult>(result);
263 348
264 - if (r == AuthResult::error) 349 + if (r == AuthResult::error)
  350 + {
  351 + logger->logf(LOG_ERR, "Username+password check by plugin returned error for user '%s'. If it didn't log anything, we don't know what it was.", username.c_str());
  352 + }
  353 + }
  354 + else if (pluginVersion == PluginVersion::FlashMQv1)
265 { 355 {
266 - logger->logf(LOG_ERR, "Username+password check by plugin returned error for user '%s'. If it didn't log anything, we don't know what it was.", username.c_str()); 356 + // I'm using this try/catch because propagating the exception higher up conflicts with who gets the blame, and then the publisher
  357 + // gets disconnected.
  358 + try
  359 + {
  360 + return flashmq_auth_plugin_login_check_v1(pluginData, username, password);
  361 + }
  362 + catch (std::exception &ex)
  363 + {
  364 + logger->logf(LOG_ERR, "Error doing login check in plugin: '%s'", ex.what());
  365 + logger->logf(LOG_WARNING, "Throwing exceptions from auth plugin login/ACL checks is slow. There's no need.");
  366 + }
267 } 367 }
268 368
269 - return r; 369 + return AuthResult::error;
270 } 370 }
271 371
272 void Authentication::setQuitting() 372 void Authentication::setQuitting()
@@ -488,6 +588,23 @@ AuthResult Authentication::unPwdCheckFromMosquittoPasswordFile(const std::string @@ -488,6 +588,23 @@ AuthResult Authentication::unPwdCheckFromMosquittoPasswordFile(const std::string
488 return result; 588 return result;
489 } 589 }
490 590
  591 +void Authentication::periodicEvent()
  592 +{
  593 + if (pluginVersion == PluginVersion::None)
  594 + return;
  595 +
  596 + if (!initialized)
  597 + {
  598 + logger->logf(LOG_ERR, "Auth plugin period event called, but initialization failed or not performed.");
  599 + return;
  600 + }
  601 +
  602 + if (pluginVersion == PluginVersion::FlashMQv1 && flashmq_auth_plugin_periodic_event_v1)
  603 + {
  604 + flashmq_auth_plugin_periodic_event_v1(pluginData);
  605 + }
  606 +}
  607 +
491 std::string AuthResultToString(AuthResult r) 608 std::string AuthResultToString(AuthResult r)
492 { 609 {
493 if (r == AuthResult::success) 610 if (r == AuthResult::success)
authplugin.h
@@ -26,6 +26,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -26,6 +26,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
26 #include "logger.h" 26 #include "logger.h"
27 #include "configfileparser.h" 27 #include "configfileparser.h"
28 #include "acltree.h" 28 #include "acltree.h"
  29 +#include "flashmq_plugin.h"
29 30
30 /** 31 /**
31 * @brief The MosquittoPasswordFileEntry struct stores the decoded base64 password salt and hash. 32 * @brief The MosquittoPasswordFileEntry struct stores the decoded base64 password salt and hash.
@@ -49,6 +50,7 @@ struct MosquittoPasswordFileEntry @@ -49,6 +50,7 @@ struct MosquittoPasswordFileEntry
49 50
50 typedef int (*F_auth_plugin_version)(void); 51 typedef int (*F_auth_plugin_version)(void);
51 52
  53 +// Mosquitto functions
52 typedef int (*F_auth_plugin_init_v2)(void **, struct mosquitto_auth_opt *, int); 54 typedef int (*F_auth_plugin_init_v2)(void **, struct mosquitto_auth_opt *, int);
53 typedef int (*F_auth_plugin_cleanup_v2)(void *, struct mosquitto_auth_opt *, int); 55 typedef int (*F_auth_plugin_cleanup_v2)(void *, struct mosquitto_auth_opt *, int);
54 typedef int (*F_auth_plugin_security_init_v2)(void *, struct mosquitto_auth_opt *, int, bool); 56 typedef int (*F_auth_plugin_security_init_v2)(void *, struct mosquitto_auth_opt *, int, bool);
@@ -57,12 +59,29 @@ typedef int (*F_auth_plugin_acl_check_v2)(void *, const char *, const char *, co @@ -57,12 +59,29 @@ typedef int (*F_auth_plugin_acl_check_v2)(void *, const char *, const char *, co
57 typedef int (*F_auth_plugin_unpwd_check_v2)(void *, const char *, const char *); 59 typedef int (*F_auth_plugin_unpwd_check_v2)(void *, const char *, const char *);
58 typedef int (*F_auth_plugin_psk_key_get_v2)(void *, const char *, const char *, char *, int); 60 typedef int (*F_auth_plugin_psk_key_get_v2)(void *, const char *, const char *, char *, int);
59 61
  62 +
  63 +typedef void(*F_flashmq_auth_plugin_allocate_thread_memory_v1)(void **thread_data, std::unordered_map<std::string, std::string> &auth_opts);
  64 +typedef void(*F_flashmq_auth_plugin_deallocate_thread_memory_v1)(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts);
  65 +typedef void(*F_flashmq_auth_plugin_init_v1)(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts, bool reloading);
  66 +typedef void(*F_flashmq_auth_plugin_deinit_v1)(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts, bool reloading);
  67 +typedef AuthResult(*F_flashmq_auth_plugin_acl_check_v1)(void *thread_data, AclAccess access, const std::string &clientid, const std::string &username, const FlashMQMessage &msg);
  68 +typedef AuthResult(*F_flashmq_auth_plugin_login_check_v1)(void *thread_data, const std::string &username, const std::string &password);
  69 +typedef void (*F_flashmq_auth_plugin_periodic_event)(void *thread_data);
  70 +
60 extern "C" 71 extern "C"
61 { 72 {
62 // Gets called by the plugin, so it needs to exist, globally 73 // Gets called by the plugin, so it needs to exist, globally
63 void mosquitto_log_printf(int level, const char *fmt, ...); 74 void mosquitto_log_printf(int level, const char *fmt, ...);
64 } 75 }
65 76
  77 +enum class PluginVersion
  78 +{
  79 + None,
  80 + Determining,
  81 + FlashMQv1,
  82 + MosquittoV2,
  83 +};
  84 +
66 std::string AuthResultToString(AuthResult r); 85 std::string AuthResultToString(AuthResult r);
67 86
68 /** 87 /**
@@ -72,6 +91,8 @@ std::string AuthResultToString(AuthResult r); @@ -72,6 +91,8 @@ std::string AuthResultToString(AuthResult r);
72 class Authentication 91 class Authentication
73 { 92 {
74 F_auth_plugin_version version = nullptr; 93 F_auth_plugin_version version = nullptr;
  94 +
  95 + // Mosquitto functions
75 F_auth_plugin_init_v2 init_v2 = nullptr; 96 F_auth_plugin_init_v2 init_v2 = nullptr;
76 F_auth_plugin_cleanup_v2 cleanup_v2 = nullptr; 97 F_auth_plugin_cleanup_v2 cleanup_v2 = nullptr;
77 F_auth_plugin_security_init_v2 security_init_v2 = nullptr; 98 F_auth_plugin_security_init_v2 security_init_v2 = nullptr;
@@ -80,6 +101,14 @@ class Authentication @@ -80,6 +101,14 @@ class Authentication
80 F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; 101 F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr;
81 F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; 102 F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr;
82 103
  104 + F_flashmq_auth_plugin_allocate_thread_memory_v1 flashmq_auth_plugin_allocate_thread_memory_v1 = nullptr;
  105 + F_flashmq_auth_plugin_deallocate_thread_memory_v1 flashmq_auth_plugin_deallocate_thread_memory_v1 = nullptr;
  106 + F_flashmq_auth_plugin_init_v1 flashmq_auth_plugin_init_v1 = nullptr;
  107 + F_flashmq_auth_plugin_deinit_v1 flashmq_auth_plugin_deinit_v1 = nullptr;
  108 + F_flashmq_auth_plugin_acl_check_v1 flashmq_auth_plugin_acl_check_v1 = nullptr;
  109 + F_flashmq_auth_plugin_login_check_v1 flashmq_auth_plugin_login_check_v1 = nullptr;
  110 + F_flashmq_auth_plugin_periodic_event flashmq_auth_plugin_periodic_event_v1 = nullptr;
  111 +
83 static std::mutex initMutex; 112 static std::mutex initMutex;
84 static std::mutex authChecksMutex; 113 static std::mutex authChecksMutex;
85 114
@@ -88,7 +117,7 @@ class Authentication @@ -88,7 +117,7 @@ class Authentication
88 void *pluginData = nullptr; 117 void *pluginData = nullptr;
89 Logger *logger = nullptr; 118 Logger *logger = nullptr;
90 bool initialized = false; 119 bool initialized = false;
91 - bool useExternalPlugin = false; 120 + PluginVersion pluginVersion = PluginVersion::None;
92 bool quitting = false; 121 bool quitting = false;
93 122
94 /** 123 /**
@@ -110,7 +139,7 @@ class Authentication @@ -110,7 +139,7 @@ class Authentication
110 139
111 AclTree aclTree; 140 AclTree aclTree;
112 141
113 - void *loadSymbol(void *handle, const char *symbol) const; 142 + void *loadSymbol(void *handle, const char *symbol, bool exceptionOnError = true) const;
114 public: 143 public:
115 Authentication(Settings &settings); 144 Authentication(Settings &settings);
116 Authentication(const Authentication &other) = delete; 145 Authentication(const Authentication &other) = delete;
@@ -122,7 +151,8 @@ public: @@ -122,7 +151,8 @@ public:
122 void cleanup(); 151 void cleanup();
123 void securityInit(bool reloading); 152 void securityInit(bool reloading);
124 void securityCleanup(bool reloading); 153 void securityCleanup(bool reloading);
125 - AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, AclAccess access); 154 + AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics,
  155 + AclAccess access, char qos, bool retain);
126 AuthResult unPwdCheck(const std::string &username, const std::string &password); 156 AuthResult unPwdCheck(const std::string &username, const std::string &password);
127 157
128 void setQuitting(); 158 void setQuitting();
@@ -131,6 +161,8 @@ public: @@ -131,6 +161,8 @@ public:
131 AuthResult aclCheckFromMosquittoAclFile(const std::string &clientid, const std::string &username, const std::vector<std::string> &subtopics, AclAccess access); 161 AuthResult aclCheckFromMosquittoAclFile(const std::string &clientid, const std::string &username, const std::vector<std::string> &subtopics, AclAccess access);
132 AuthResult unPwdCheckFromMosquittoPasswordFile(const std::string &username, const std::string &password); 162 AuthResult unPwdCheckFromMosquittoPasswordFile(const std::string &username, const std::string &password);
133 163
  164 + void periodicEvent();
  165 +
134 }; 166 };
135 167
136 #endif // AUTHPLUGIN_H 168 #endif // AUTHPLUGIN_H
client.cpp
@@ -91,6 +91,11 @@ bool Client::getSslWriteWantsRead() const @@ -91,6 +91,11 @@ bool Client::getSslWriteWantsRead() const
91 return ioWrapper.getSslWriteWantsRead(); 91 return ioWrapper.getSslWriteWantsRead();
92 } 92 }
93 93
  94 +ProtocolVersion Client::getProtocolVersion() const
  95 +{
  96 + return protocolVersion;
  97 +}
  98 +
94 void Client::startOrContinueSslAccept() 99 void Client::startOrContinueSslAccept()
95 { 100 {
96 ioWrapper.startOrContinueSslAccept(); 101 ioWrapper.startOrContinueSslAccept();
client.h
@@ -99,6 +99,7 @@ public: @@ -99,6 +99,7 @@ public:
99 bool isSsl() const; 99 bool isSsl() const;
100 bool getSslReadWantsWrite() const; 100 bool getSslReadWantsWrite() const;
101 bool getSslWriteWantsRead() const; 101 bool getSslWriteWantsRead() const;
  102 + ProtocolVersion getProtocolVersion() const;
102 103
103 void startOrContinueSslAccept(); 104 void startOrContinueSslAccept();
104 void markAsDisconnecting(); 105 void markAsDisconnecting();
configfileparser.cpp
@@ -92,11 +92,12 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) : @@ -92,11 +92,12 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) :
92 path(path) 92 path(path)
93 { 93 {
94 validKeys.insert("auth_plugin"); 94 validKeys.insert("auth_plugin");
  95 + validKeys.insert("auth_plugin_serialize_init");
  96 + validKeys.insert("auth_plugin_serialize_auth_checks");
  97 + validKeys.insert("auth_plugin_timer_period");
95 validKeys.insert("log_file"); 98 validKeys.insert("log_file");
96 validKeys.insert("allow_unsafe_clientid_chars"); 99 validKeys.insert("allow_unsafe_clientid_chars");
97 validKeys.insert("allow_unsafe_username_chars"); 100 validKeys.insert("allow_unsafe_username_chars");
98 - validKeys.insert("auth_plugin_serialize_init");  
99 - validKeys.insert("auth_plugin_serialize_auth_checks");  
100 validKeys.insert("client_initial_buffer_size"); 101 validKeys.insert("client_initial_buffer_size");
101 validKeys.insert("max_packet_size"); 102 validKeys.insert("max_packet_size");
102 validKeys.insert("log_debug"); 103 validKeys.insert("log_debug");
@@ -401,6 +402,16 @@ void ConfigFileParser::loadFile(bool test) @@ -401,6 +402,16 @@ void ConfigFileParser::loadFile(bool test)
401 } 402 }
402 tmpSettings->expireSessionsAfterSeconds = newVal; 403 tmpSettings->expireSessionsAfterSeconds = newVal;
403 } 404 }
  405 +
  406 + if (key == "auth_plugin_timer_period")
  407 + {
  408 + int newVal = std::stoi(value);
  409 + if (newVal < 0)
  410 + {
  411 + throw ConfigFileException(formatString("auth_plugin_timer_period value '%d' is invalid. Valid values are 0 or higher. 0 means disabled.", newVal));
  412 + }
  413 + tmpSettings->authPluginTimerPeriod = newVal;
  414 + }
404 } 415 }
405 } 416 }
406 catch (std::invalid_argument &ex) // catch for the stoi() 417 catch (std::invalid_argument &ex) // catch for the stoi()
@@ -410,6 +421,7 @@ void ConfigFileParser::loadFile(bool test) @@ -410,6 +421,7 @@ void ConfigFileParser::loadFile(bool test)
410 } 421 }
411 422
412 tmpSettings->authOptCompatWrap = AuthOptCompatWrap(authOpts); 423 tmpSettings->authOptCompatWrap = AuthOptCompatWrap(authOpts);
  424 + tmpSettings->flashmqAuthPluginOpts = std::move(authOpts);
413 425
414 if (!test) 426 if (!test)
415 { 427 {
@@ -417,9 +429,5 @@ void ConfigFileParser::loadFile(bool test) @@ -417,9 +429,5 @@ void ConfigFileParser::loadFile(bool test)
417 } 429 }
418 } 430 }
419 431
420 -AuthOptCompatWrap &Settings::getAuthOptsCompat()  
421 -{  
422 - return authOptCompatWrap;  
423 -}  
424 432
425 433
1 #ifndef ENUMS_H 1 #ifndef ENUMS_H
2 #define ENUMS_H 2 #define ENUMS_H
3 3
4 -// Compatible with Mosquitto  
5 -enum class AclAccess  
6 -{  
7 - none = 0,  
8 - read = 1,  
9 - write = 2  
10 -};  
11 -  
12 -// Compatible with Mosquitto  
13 -enum class AuthResult  
14 -{  
15 - success = 0,  
16 - acl_denied = 12,  
17 - login_denied = 11,  
18 - error = 13  
19 -}; 4 +#include "flashmq_plugin.h"
20 5
21 #endif // ENUMS_H 6 #endif // ENUMS_H
flashmq_plugin.cpp 0 โ†’ 100644
  1 +#include "flashmq_plugin.h"
  2 +
  3 +#include "logger.h"
  4 +
  5 +void flashmq_logf(int level, const char *str, ...)
  6 +{
  7 + Logger *logger = Logger::getInstance();
  8 +
  9 + va_list valist;
  10 + va_start(valist, str);
  11 + logger->logf(level, str, valist);
  12 + va_end(valist);
  13 +}
  14 +
  15 +FlashMQMessage::FlashMQMessage(const std::string &topic, const std::vector<std::string> &subtopics, const char qos, const bool retain) :
  16 + topic(topic),
  17 + subtopics(subtopics),
  18 + qos(qos),
  19 + retain(retain)
  20 +{
  21 +
  22 +}
flashmq_plugin.h 0 โ†’ 100644
  1 +/*
  2 + * This file is part of FlashMQ (https://www.flashmq.org). It defines the
  3 + * authentication plugin interface.
  4 + *
  5 + * This interface definition is public domain and you are encouraged
  6 + * to copy it to your authentication plugin project, for portability. Including
  7 + * this file in your project does not require your code to have a compatibile
  8 + * license nor requires you to open source it.
  9 + *
  10 + * Compile like: gcc -fPIC -shared authplugin.cpp -o authplugin.so
  11 + */
  12 +
  13 +#ifndef FLASHMQ_PLUGIN_H
  14 +#define FLASHMQ_PLUGIN_H
  15 +
  16 +#include <string>
  17 +#include <vector>
  18 +#include <unordered_map>
  19 +
  20 +#define FLASHMQ_PLUGIN_VERSION 1
  21 +
  22 +// Compatible with Mosquitto, for auth plugin compatability.
  23 +#define LOG_NONE 0x00
  24 +#define LOG_INFO 0x01
  25 +#define LOG_NOTICE 0x02
  26 +#define LOG_WARNING 0x04
  27 +#define LOG_ERR 0x08
  28 +#define LOG_DEBUG 0x10
  29 +#define LOG_SUBSCRIBE 0x20
  30 +#define LOG_UNSUBSCRIBE 0x40
  31 +
  32 +extern "C"
  33 +{
  34 +
  35 +/**
  36 + * @brief The AclAccess enum's numbers are compatible with Mosquitto's 'int access'.
  37 + *
  38 + * read = reading a publish published by someone else.
  39 + * write = doing a publish.
  40 + * subscribe = subscribing.
  41 + */
  42 +enum class AclAccess
  43 +{
  44 + none = 0,
  45 + read = 1,
  46 + write = 2,
  47 + subscribe = 4
  48 +};
  49 +
  50 +/**
  51 + * @brief The AuthResult enum's numbers are compatible with Mosquitto's auth result.
  52 + */
  53 +enum class AuthResult
  54 +{
  55 + success = 0,
  56 + acl_denied = 12,
  57 + login_denied = 11,
  58 + error = 13
  59 +};
  60 +
  61 +/**
  62 + * @brief The FlashMQMessage struct contains the meta data of a publish.
  63 + *
  64 + * The subtopics is the topic split, so you don't have to do that anymore.
  65 + *
  66 + * As for 'retain', keep in mind that for existing subscribers, this will always be false [MQTT-3.3.1-9]. Only publishes or
  67 + * retain messages as a result of a subscribe can have that set to true.
  68 + *
  69 + * For subscribtions, 'retain' is always false.
  70 + */
  71 +struct FlashMQMessage
  72 +{
  73 + const std::string &topic;
  74 + const std::vector<std::string> &subtopics;
  75 + const char qos;
  76 + const bool retain;
  77 +
  78 + FlashMQMessage(const std::string &topic, const std::vector<std::string> &subtopics, const char qos, const bool retain);
  79 +};
  80 +
  81 +/**
  82 + * @brief flashmq_logf calls the internal logger of FlashMQ. The logger mutexes all access, so is thread-safe.
  83 + * @param level is any of the levels defined above, starting with LOG_.
  84 + * @param str
  85 + *
  86 + * FlashMQ makes no distinction between INFO and NOTICE.
  87 + */
  88 +void flashmq_logf(int level, const char *str, ...);
  89 +
  90 +/**
  91 + * @brief flashmq_plugin_version must return FLASHMQ_PLUGIN_VERSION.
  92 + * @return FLASHMQ_PLUGIN_VERSION.
  93 + */
  94 +int flashmq_auth_plugin_version();
  95 +
  96 +/**
  97 + * @brief flashmq_auth_plugin_allocate_thread_memory is called once by each thread. Never again.
  98 + * @param thread_data. Create a memory structure and assign it to *thread_data.
  99 + * @param global_data. The global data created in flashmq_auth_plugin_allocate_global_memory, if you use it.
  100 + * @param auth_opts. Map of flashmq_auth_opt_* from the config file.
  101 + *
  102 + * Only allocate the plugin's memory here. Don't open connections, etc.
  103 + *
  104 + * The global data is created by flashmq_auth_plugin_allocate_global_memory() and if you need it, you can assign it to your
  105 + * own thread_data storage. It is not passed as argument to other functions.
  106 + *
  107 + * You can use static variables for global scope if you must, but do provide proper locking where necessary.
  108 + *
  109 + * throw an exception on errors.
  110 + */
  111 +void flashmq_auth_plugin_allocate_thread_memory(void **thread_data, std::unordered_map<std::string, std::string> &auth_opts);
  112 +
  113 +/**
  114 + * @brief flashmq_auth_plugin_deallocate_thread_memory is called once by each thread. Never again.
  115 + * @param thread_data. Delete this memory.
  116 + * @param auth_opts. Map of flashmq_auth_opt_* from the config file.
  117 + *
  118 + * throw an exception on errors.
  119 + */
  120 +void flashmq_auth_plugin_deallocate_thread_memory(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts);
  121 +
  122 +/**
  123 + * @brief flashmq_auth_plugin_init is called on thread start and config reload. It is the main place to initialize the plugin.
  124 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  125 + * @param auth_opts. Map of flashmq_auth_opt_* from the config file.
  126 + * @param reloading.
  127 + *
  128 + * The best approach to state keeping is doing everything per thread. You can initialize connections to database servers, load encryption keys,
  129 + * create maps, etc.
  130 + *
  131 + * Keep in mind that libraries you use may not be thread safe (by default). Sometimes they use global scope in treacherous ways. As a random
  132 + * example: Qt's QSqlDatabase needs a unique name for each connection, otherwise it is not thread safe and will crash.
  133 + *
  134 + * There is the option to set 'auth_plugin_serialize_init true' in the config file, which allows some mitigation in
  135 + * case you run into problems.
  136 + *
  137 + * throw an exception on errors.
  138 + */
  139 +void flashmq_auth_plugin_init(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts, bool reloading);
  140 +
  141 +/**
  142 + * @brief flashmq_auth_plugin_deinit is called on thread stop and config reload. It is the precursor to initializing.
  143 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  144 + * @param auth_opts. Map of flashmq_auth_opt_* from the config file.
  145 + * @param reloading
  146 + *
  147 + * throw an exception on errors.
  148 + */
  149 +void flashmq_auth_plugin_deinit(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts, bool reloading);
  150 +
  151 +/**
  152 + * @brief flashmq_auth_plugin_periodic is called every x seconds as defined in the config file.
  153 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  154 + *
  155 + * You may need to periodically refresh data from a database, post stats, etc. You can do that from here. It's queued
  156 + * in each thread at the same time, so you can perform somewhat synchronized events in all threads.
  157 + *
  158 + * Note that it's executed in the event loop, so it blocks the thread if you block here. If you need asynchronous operation,
  159 + * you can make threads yourself. Be sure to synchronize data access properly in that case.
  160 + *
  161 + * The setting auth_plugin_timer_period sets this interval in seconds.
  162 + *
  163 + * Implementing this is optional.
  164 + *
  165 + * throw an exception on errors.
  166 + */
  167 +void flashmq_auth_plugin_periodic_event(void *thread_data);
  168 +
  169 +/**
  170 + * @brief flashmq_auth_plugin_login_check is called on login of a client.
  171 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  172 + * @param username
  173 + * @param password
  174 + * @return
  175 + *
  176 + * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error,
  177 + * because there's nothing else to do: the state of FlashMQ won't change.
  178 + *
  179 + * Note that there is a setting 'auth_plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not
  180 + * thread-safe. It will negate much of FlashMQ's multi-core model.
  181 + */
  182 +AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string &username, const std::string &password);
  183 +
  184 +/**
  185 + * @brief flashmq_auth_plugin_acl_check is called on publish, deliver and subscribe.
  186 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  187 + * @param access
  188 + * @param clientid
  189 + * @param username
  190 + * @param msg. See FlashMQMessage.
  191 + * @return
  192 + *
  193 + * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error,
  194 + * because there's nothing else to do: the state of FlashMQ won't change.
  195 + *
  196 + * Controlling subscribe access can have several benefits. For instance, you may want to avoid subscriptions that cause
  197 + * a lot of server load. If clients pester you with many subscriptions like '+/+/+/+/+/+/+/+/+/', that causes a lot
  198 + * of tree walking. Similarly, if all clients subscribe to '#' because it's easy, every single message passing through
  199 + * the server will have to be ACL checked for every subscriber.
  200 + *
  201 + * Note that only MQTT 3.1.1 or higher has a 'failed' return code for subscribing, so older clients will see a normal
  202 + * ack and won't know it failed.
  203 + *
  204 + * Note that there is a setting 'auth_plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not
  205 + * thread-safe. It will negate much of FlashMQ's multi-core model.
  206 + */
  207 +AuthResult flashmq_auth_plugin_acl_check(void *thread_data, AclAccess access, const std::string &clientid, const std::string &username, const FlashMQMessage &msg);
  208 +
  209 +}
  210 +
  211 +#endif // FLASHMQ_PLUGIN_H
logger.h
@@ -22,16 +22,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -22,16 +22,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
22 #include <stdarg.h> 22 #include <stdarg.h>
23 #include <mutex> 23 #include <mutex>
24 24
25 -// Compatible with Mosquitto, for auth plugin compatability.  
26 -// Can be OR'ed together.  
27 -#define LOG_NONE 0x00  
28 -#define LOG_INFO 0x01  
29 -#define LOG_NOTICE 0x02  
30 -#define LOG_WARNING 0x04  
31 -#define LOG_ERR 0x08  
32 -#define LOG_DEBUG 0x10  
33 -#define LOG_SUBSCRIBE 0x20  
34 -#define LOG_UNSUBSCRIBE 0x40 25 +#include "flashmq_plugin.h"
35 26
36 int logSslError(const char *str, size_t len, void *u); 27 int logSslError(const char *str, size_t len, void *u);
37 28
mainapp.cpp
@@ -196,6 +196,12 @@ MainApp::MainApp(const std::string &amp;configFilePath) : @@ -196,6 +196,12 @@ MainApp::MainApp(const std::string &amp;configFilePath) :
196 auto fPublishStats = std::bind(&MainApp::publishStatsOnDollarTopic, this); 196 auto fPublishStats = std::bind(&MainApp::publishStatsOnDollarTopic, this);
197 timer.addCallback(fPublishStats, 10000, "Publish stats on $SYS"); 197 timer.addCallback(fPublishStats, 10000, "Publish stats on $SYS");
198 publishStatsOnDollarTopic(); 198 publishStatsOnDollarTopic();
  199 +
  200 + if (settings->authPluginTimerPeriod > 0)
  201 + {
  202 + auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this);
  203 + timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event.");
  204 + }
199 } 205 }
200 206
201 MainApp::~MainApp() 207 MainApp::~MainApp()
@@ -310,6 +316,14 @@ void MainApp::queuePasswordFileReloadAllThreads() @@ -310,6 +316,14 @@ void MainApp::queuePasswordFileReloadAllThreads()
310 } 316 }
311 } 317 }
312 318
  319 +void MainApp::queueAuthPluginPeriodicEventAllThreads()
  320 +{
  321 + for (std::shared_ptr<ThreadData> &thread : threads)
  322 + {
  323 + thread->queueAuthPluginPeriodicEvent();
  324 + }
  325 +}
  326 +
313 void MainApp::setFuzzFile(const std::string &fuzzFilePath) 327 void MainApp::setFuzzFile(const std::string &fuzzFilePath)
314 { 328 {
315 this->fuzzFilePath = fuzzFilePath; 329 this->fuzzFilePath = fuzzFilePath;
@@ -505,6 +519,7 @@ void MainApp::start() @@ -505,6 +519,7 @@ void MainApp::start()
505 try 519 try
506 { 520 {
507 std::vector<MqttPacket> packetQueueIn; 521 std::vector<MqttPacket> packetQueueIn;
  522 + std::vector<std::string> subtopics;
508 523
509 std::shared_ptr<ThreadData> threaddata(new ThreadData(0, subscriptionStore, settings)); 524 std::shared_ptr<ThreadData> threaddata(new ThreadData(0, subscriptionStore, settings));
510 525
@@ -518,10 +533,11 @@ void MainApp::start() @@ -518,10 +533,11 @@ void MainApp::start()
518 websocketsubscriber->setAuthenticated(true); 533 websocketsubscriber->setAuthenticated(true);
519 websocketsubscriber->setFakeUpgraded(); 534 websocketsubscriber->setFakeUpgraded();
520 subscriptionStore->registerClientAndKickExistingOne(websocketsubscriber); 535 subscriptionStore->registerClientAndKickExistingOne(websocketsubscriber);
521 - subscriptionStore->addSubscription(websocketsubscriber, "#", 0); 536 + splitTopic("#", subtopics);
  537 + subscriptionStore->addSubscription(websocketsubscriber, "#", subtopics, 0);
522 538
523 subscriptionStore->registerClientAndKickExistingOne(subscriber); 539 subscriptionStore->registerClientAndKickExistingOne(subscriber);
524 - subscriptionStore->addSubscription(subscriber, "#", 0); 540 + subscriptionStore->addSubscription(subscriber, "#", subtopics, 0);
525 541
526 if (fuzzWebsockets && strContains(fuzzFilePathLower, "upgrade")) 542 if (fuzzWebsockets && strContains(fuzzFilePathLower, "upgrade"))
527 { 543 {
mainapp.h
@@ -76,6 +76,7 @@ class MainApp @@ -76,6 +76,7 @@ class MainApp
76 void wakeUpThread(); 76 void wakeUpThread();
77 void queueKeepAliveCheckAtAllThreads(); 77 void queueKeepAliveCheckAtAllThreads();
78 void queuePasswordFileReloadAllThreads(); 78 void queuePasswordFileReloadAllThreads();
  79 + void queueAuthPluginPeriodicEventAllThreads();
79 void setFuzzFile(const std::string &fuzzFilePath); 80 void setFuzzFile(const std::string &fuzzFilePath);
80 void publishStatsOnDollarTopic(); 81 void publishStatsOnDollarTopic();
81 void publishStat(const std::string &topic, uint64_t n); 82 void publishStat(const std::string &topic, uint64_t n);
mqttpacket.cpp
@@ -407,6 +407,7 @@ void MqttPacket::handleDisconnect() @@ -407,6 +407,7 @@ void MqttPacket::handleDisconnect()
407 407
408 void MqttPacket::handleSubscribe() 408 void MqttPacket::handleSubscribe()
409 { 409 {
  410 + this->subtopics = &gSubtopics;
410 const char firstByteFirstNibble = (first_byte & 0x0F); 411 const char firstByteFirstNibble = (first_byte & 0x0F);
411 412
412 if (firstByteFirstNibble != 2) 413 if (firstByteFirstNibble != 2)
@@ -431,9 +432,23 @@ void MqttPacket::handleSubscribe() @@ -431,9 +432,23 @@ void MqttPacket::handleSubscribe()
431 if (qos > 2) 432 if (qos > 2)
432 throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0."); 433 throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0.");
433 434
434 - logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str());  
435 - sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, qos);  
436 - subs_reponse_codes.push_back(qos); 435 + splitTopic(topic, *subtopics);
  436 + if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success)
  437 + {
  438 + logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str());
  439 + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos);
  440 + subs_reponse_codes.push_back(qos);
  441 + }
  442 + else
  443 + {
  444 + logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribe to '%s' denied or failed.", sender->repr().c_str(), topic.c_str());
  445 +
  446 + // We can't not send an ack, because if there are multiple subscribes, you send fewer acks back, losing sync.
  447 + char return_code = qos;
  448 + if (sender->getProtocolVersion() >= ProtocolVersion::Mqtt311)
  449 + return_code = static_cast<char>(SubAckReturnCodes::Fail);
  450 + subs_reponse_codes.push_back(return_code);
  451 + }
437 } 452 }
438 453
439 SubAck subAck(packet_id, subs_reponse_codes); 454 SubAck subAck(packet_id, subs_reponse_codes);
@@ -531,7 +546,7 @@ void MqttPacket::handlePublish() @@ -531,7 +546,7 @@ void MqttPacket::handlePublish()
531 } 546 }
532 } 547 }
533 548
534 - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write) == AuthResult::success) 549 + if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success)
535 { 550 {
536 if (retain) 551 if (retain)
537 { 552 {
session.cpp
@@ -48,14 +48,20 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client) @@ -48,14 +48,20 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
48 this->thread = client->getThreadData(); 48 this->thread = client->getThreadData();
49 } 49 }
50 50
51 -void Session::writePacket(const MqttPacket &packet, char max_qos, uint64_t &count) 51 +/**
  52 + * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session.
  53 + * @param packet
  54 + * @param max_qos
  55 + * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets.
  56 + * @param count. Reference value is updated. It's for statistics.
  57 + */
  58 +void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count)
52 { 59 {
53 assert(max_qos <= 2); 60 assert(max_qos <= 2);
  61 + const char qos = std::min<char>(packet.getQos(), max_qos);
54 62
55 - if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read) == AuthResult::success) 63 + if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success)
56 { 64 {
57 - const char qos = std::min<char>(packet.getQos(), max_qos);  
58 -  
59 if (qos == 0) 65 if (qos == 0)
60 { 66 {
61 if (!clientDisconnected()) 67 if (!clientDisconnected())
session.h
@@ -61,7 +61,7 @@ public: @@ -61,7 +61,7 @@ public:
61 bool clientDisconnected() const; 61 bool clientDisconnected() const;
62 std::shared_ptr<Client> makeSharedClient() const; 62 std::shared_ptr<Client> makeSharedClient() const;
63 void assignActiveConnection(std::shared_ptr<Client> &client); 63 void assignActiveConnection(std::shared_ptr<Client> &client);
64 - void writePacket(const MqttPacket &packet, char max_qos, uint64_t &count); 64 + void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count);
65 void clearQosMessage(uint16_t packet_id); 65 void clearQosMessage(uint16_t packet_id);
66 uint64_t sendPendingQosMessages(); 66 uint64_t sendPendingQosMessages();
67 void touch(std::chrono::time_point<std::chrono::steady_clock> val); 67 void touch(std::chrono::time_point<std::chrono::steady_clock> val);
settings.cpp
@@ -18,3 +18,12 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -18,3 +18,12 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
18 #include "settings.h" 18 #include "settings.h"
19 19
20 20
  21 +AuthOptCompatWrap &Settings::getAuthOptsCompat()
  22 +{
  23 + return authOptCompatWrap;
  24 +}
  25 +
  26 +std::unordered_map<std::string, std::string> &Settings::getFlashmqAuthPluginOpts()
  27 +{
  28 + return this->flashmqAuthPluginOpts;
  29 +}
settings.h
@@ -29,6 +29,7 @@ class Settings @@ -29,6 +29,7 @@ class Settings
29 friend class ConfigFileParser; 29 friend class ConfigFileParser;
30 30
31 AuthOptCompatWrap authOptCompatWrap; 31 AuthOptCompatWrap authOptCompatWrap;
  32 + std::unordered_map<std::string, std::string> flashmqAuthPluginOpts;
32 33
33 public: 34 public:
34 // Actual config options with their defaults. 35 // Actual config options with their defaults.
@@ -47,9 +48,11 @@ public: @@ -47,9 +48,11 @@ public:
47 bool allowAnonymous = false; 48 bool allowAnonymous = false;
48 int rlimitNoFile = 1000000; 49 int rlimitNoFile = 1000000;
49 uint64_t expireSessionsAfterSeconds = 1209600; 50 uint64_t expireSessionsAfterSeconds = 1209600;
  51 + int authPluginTimerPeriod = 60;
50 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined. 52 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined.
51 53
52 AuthOptCompatWrap &getAuthOptsCompat(); 54 AuthOptCompatWrap &getAuthOptsCompat();
  55 + std::unordered_map<std::string, std::string> &getFlashmqAuthPluginOpts();
53 }; 56 };
54 57
55 #endif // SETTINGS_H 58 #endif // SETTINGS_H
subscriptionstore.cpp
@@ -84,10 +84,8 @@ SubscriptionStore::SubscriptionStore() : @@ -84,10 +84,8 @@ SubscriptionStore::SubscriptionStore() :
84 84
85 } 85 }
86 86
87 -void SubscriptionStore::addSubscription(std::shared_ptr<Client> &client, const std::string &topic, char qos) 87 +void SubscriptionStore::addSubscription(std::shared_ptr<Client> &client, const std::string &topic, const std::vector<std::string> &subtopics, char qos)
88 { 88 {
89 - const std::list<std::string> subtopics = split(topic, '/');  
90 -  
91 SubscriptionNode *deepestNode = &root; 89 SubscriptionNode *deepestNode = &root;
92 if (topic.length() > 0 && topic[0] == '$') 90 if (topic.length() > 0 && topic[0] == '$')
93 deepestNode = &rootDollar; 91 deepestNode = &rootDollar;
@@ -242,7 +240,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st @@ -242,7 +240,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st
242 if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. 240 if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect.
243 { 241 {
244 const std::shared_ptr<Session> session = session_weak.lock(); 242 const std::shared_ptr<Session> session = session_weak.lock();
245 - session->writePacket(packet, sub.qos, count); 243 + session->writePacket(packet, sub.qos, false, count);
246 } 244 }
247 } 245 }
248 } 246 }
@@ -330,7 +328,7 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Ses @@ -330,7 +328,7 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Ses
330 328
331 if (topicsMatch(subscribe_topic, rm.topic)) 329 if (topicsMatch(subscribe_topic, rm.topic))
332 { 330 {
333 - ses->writePacket(packet, max_qos, count); 331 + ses->writePacket(packet, max_qos, true, count);
334 } 332 }
335 } 333 }
336 334
subscriptionstore.h
@@ -89,7 +89,7 @@ class SubscriptionStore @@ -89,7 +89,7 @@ class SubscriptionStore
89 public: 89 public:
90 SubscriptionStore(); 90 SubscriptionStore();
91 91
92 - void addSubscription(std::shared_ptr<Client> &client, const std::string &topic, char qos); 92 + void addSubscription(std::shared_ptr<Client> &client, const std::string &topic, const std::vector<std::string> &subtopics, char qos);
93 void removeSubscription(std::shared_ptr<Client> &client, const std::string &topic); 93 void removeSubscription(std::shared_ptr<Client> &client, const std::string &topic);
94 void registerClientAndKickExistingOne(std::shared_ptr<Client> &client); 94 void registerClientAndKickExistingOne(std::shared_ptr<Client> &client);
95 bool sessionPresent(const std::string &clientid); 95 bool sessionPresent(const std::string &clientid);
threaddata.cpp
@@ -142,6 +142,7 @@ void ThreadData::queueQuit() @@ -142,6 +142,7 @@ void ThreadData::queueQuit()
142 void ThreadData::waitForQuit() 142 void ThreadData::waitForQuit()
143 { 143 {
144 thread.join(); 144 thread.join();
  145 + authentication.cleanup();
145 } 146 }
146 147
147 void ThreadData::queuePasswdFileReload() 148 void ThreadData::queuePasswdFileReload()
@@ -210,6 +211,21 @@ uint64_t ThreadData::getSentMessagePerSecond() @@ -210,6 +211,21 @@ uint64_t ThreadData::getSentMessagePerSecond()
210 return result; 211 return result;
211 } 212 }
212 213
  214 +void ThreadData::queueAuthPluginPeriodicEvent()
  215 +{
  216 + std::lock_guard<std::mutex> locker(taskQueueMutex);
  217 +
  218 + auto f = std::bind(&ThreadData::authPluginPeriodicEvent, this);
  219 + taskQueue.push_front(f);
  220 +
  221 + wakeUpThread();
  222 +}
  223 +
  224 +void ThreadData::authPluginPeriodicEvent()
  225 +{
  226 + authentication.periodicEvent();
  227 +}
  228 +
213 // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? 229 // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial?
214 void ThreadData::doKeepAliveCheck() 230 void ThreadData::doKeepAliveCheck()
215 { 231 {
@@ -266,13 +282,9 @@ void ThreadData::reload(std::shared_ptr&lt;Settings&gt; settings) @@ -266,13 +282,9 @@ void ThreadData::reload(std::shared_ptr&lt;Settings&gt; settings)
266 authentication.securityCleanup(true); 282 authentication.securityCleanup(true);
267 authentication.securityInit(true); 283 authentication.securityInit(true);
268 } 284 }
269 - catch (AuthPluginException &ex)  
270 - {  
271 - logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what());  
272 - }  
273 catch (std::exception &ex) 285 catch (std::exception &ex)
274 { 286 {
275 - logger->logf(LOG_ERR, "Error reloading: %s.", ex.what()); 287 + logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what());
276 } 288 }
277 } 289 }
278 290
threaddata.h
@@ -101,6 +101,9 @@ public: @@ -101,6 +101,9 @@ public:
101 void incrementSentMessageCount(uint64_t n); 101 void incrementSentMessageCount(uint64_t n);
102 uint64_t getSentMessageCount() const; 102 uint64_t getSentMessageCount() const;
103 uint64_t getSentMessagePerSecond(); 103 uint64_t getSentMessagePerSecond();
  104 +
  105 + void queueAuthPluginPeriodicEvent();
  106 + void authPluginPeriodicEvent();
104 }; 107 };
105 108
106 #endif // THREADDATA_H 109 #endif // THREADDATA_H
types.cpp
@@ -15,6 +15,8 @@ You should have received a copy of the GNU Affero General Public @@ -15,6 +15,8 @@ You should have received a copy of the GNU Affero General Public
15 License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. 15 License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
16 */ 16 */
17 17
  18 +#include "cassert"
  19 +
18 #include "types.h" 20 #include "types.h"
19 21
20 ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) : 22 ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) :
@@ -29,6 +31,8 @@ ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) : @@ -29,6 +31,8 @@ ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) :
29 SubAck::SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses) : 31 SubAck::SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses) :
30 packet_id(packet_id) 32 packet_id(packet_id)
31 { 33 {
  34 + assert(!subs_qos_reponses.empty());
  35 +
32 for (char ack_code : subs_qos_reponses) 36 for (char ack_code : subs_qos_reponses)
33 { 37 {
34 responses.push_back(static_cast<SubAckReturnCodes>(ack_code)); 38 responses.push_back(static_cast<SubAckReturnCodes>(ack_code));