Commit aa19d41b6524de1e1bfff40a32e4f8344b0a5eeb

Authored by Wiebe Cazemier
1 parent e0d2e93c

Implement a native FlashMQ auth plugin

CMakeLists.txt
... ... @@ -49,6 +49,7 @@ add_executable(FlashMQ
49 49 acltree.h
50 50 enums.h
51 51 threadlocalutils.h
  52 + flashmq_plugin.h
52 53  
53 54 mainapp.cpp
54 55 main.cpp
... ... @@ -79,6 +80,7 @@ add_executable(FlashMQ
79 80 evpencodectxmanager.cpp
80 81 acltree.cpp
81 82 threadlocalutils.cpp
  83 + flashmq_plugin.cpp
82 84  
83 85 )
84 86  
... ...
authplugin.cpp
... ... @@ -74,11 +74,11 @@ Authentication::~Authentication()
74 74 EVP_MD_CTX_free(mosquittoDigestContext);
75 75 }
76 76  
77   -void *Authentication::loadSymbol(void *handle, const char *symbol) const
  77 +void *Authentication::loadSymbol(void *handle, const char *symbol, bool exceptionOnError) const
78 78 {
79 79 void *r = dlsym(handle, symbol);
80 80  
81   - if (r == NULL)
  81 + if (r == NULL && exceptionOnError)
82 82 {
83 83 std::string errmsg(dlerror());
84 84 throw FatalError(errmsg);
... ... @@ -95,7 +95,7 @@ void Authentication::loadPlugin(const std::string &pathToSoFile)
95 95 logger->logf(LOG_NOTICE, "Loading auth plugin %s", pathToSoFile.c_str());
96 96  
97 97 initialized = false;
98   - useExternalPlugin = true;
  98 + pluginVersion = PluginVersion::Determining;
99 99  
100 100 if (access(pathToSoFile.c_str(), R_OK) != 0)
101 101 {
... ... @@ -112,20 +112,41 @@ void Authentication::loadPlugin(const std::string &pathToSoFile)
112 112 throw FatalError(errmsg);
113 113 }
114 114  
115   - version = (F_auth_plugin_version)loadSymbol(r, "mosquitto_auth_plugin_version");
116   -
117   - if (version() != 2)
  115 + version = (F_auth_plugin_version)loadSymbol(r, "mosquitto_auth_plugin_version", false);
  116 + if (version != nullptr)
118 117 {
119   - throw FatalError("Only Mosquitto plugin version 2 is supported at this time.");
  118 + if (version() != 2)
  119 + {
  120 + throw FatalError("Only Mosquitto plugin version 2 is supported at this time.");
  121 + }
  122 +
  123 + pluginVersion = PluginVersion::MosquittoV2;
  124 +
  125 + init_v2 = (F_auth_plugin_init_v2)loadSymbol(r, "mosquitto_auth_plugin_init");
  126 + cleanup_v2 = (F_auth_plugin_cleanup_v2)loadSymbol(r, "mosquitto_auth_plugin_cleanup");
  127 + security_init_v2 = (F_auth_plugin_security_init_v2)loadSymbol(r, "mosquitto_auth_security_init");
  128 + security_cleanup_v2 = (F_auth_plugin_security_cleanup_v2)loadSymbol(r, "mosquitto_auth_security_cleanup");
  129 + acl_check_v2 = (F_auth_plugin_acl_check_v2)loadSymbol(r, "mosquitto_auth_acl_check");
  130 + unpwd_check_v2 = (F_auth_plugin_unpwd_check_v2)loadSymbol(r, "mosquitto_auth_unpwd_check");
  131 + psk_key_get_v2 = (F_auth_plugin_psk_key_get_v2)loadSymbol(r, "mosquitto_auth_psk_key_get");
120 132 }
  133 + else if ((version = (F_auth_plugin_version)loadSymbol(r, "flashmq_auth_plugin_version", false)) != nullptr)
  134 + {
  135 + if (version() != 1)
  136 + {
  137 + throw FatalError("FlashMQ plugin only supports version 1.");
  138 + }
  139 +
  140 + pluginVersion = PluginVersion::FlashMQv1;
121 141  
122   - init_v2 = (F_auth_plugin_init_v2)loadSymbol(r, "mosquitto_auth_plugin_init");
123   - cleanup_v2 = (F_auth_plugin_cleanup_v2)loadSymbol(r, "mosquitto_auth_plugin_cleanup");
124   - security_init_v2 = (F_auth_plugin_security_init_v2)loadSymbol(r, "mosquitto_auth_security_init");
125   - security_cleanup_v2 = (F_auth_plugin_security_cleanup_v2)loadSymbol(r, "mosquitto_auth_security_cleanup");
126   - acl_check_v2 = (F_auth_plugin_acl_check_v2)loadSymbol(r, "mosquitto_auth_acl_check");
127   - unpwd_check_v2 = (F_auth_plugin_unpwd_check_v2)loadSymbol(r, "mosquitto_auth_unpwd_check");
128   - psk_key_get_v2 = (F_auth_plugin_psk_key_get_v2)loadSymbol(r, "mosquitto_auth_psk_key_get");
  142 + flashmq_auth_plugin_allocate_thread_memory_v1 = (F_flashmq_auth_plugin_allocate_thread_memory_v1)loadSymbol(r, "flashmq_auth_plugin_allocate_thread_memory");
  143 + flashmq_auth_plugin_deallocate_thread_memory_v1 = (F_flashmq_auth_plugin_deallocate_thread_memory_v1)loadSymbol(r, "flashmq_auth_plugin_deallocate_thread_memory");
  144 + flashmq_auth_plugin_init_v1 = (F_flashmq_auth_plugin_init_v1)loadSymbol(r, "flashmq_auth_plugin_init");
  145 + flashmq_auth_plugin_deinit_v1 = (F_flashmq_auth_plugin_deinit_v1)loadSymbol(r, "flashmq_auth_plugin_deinit");
  146 + flashmq_auth_plugin_acl_check_v1 = (F_flashmq_auth_plugin_acl_check_v1)loadSymbol(r, "flashmq_auth_plugin_acl_check");
  147 + flashmq_auth_plugin_login_check_v1 = (F_flashmq_auth_plugin_login_check_v1)loadSymbol(r, "flashmq_auth_plugin_login_check");
  148 + flashmq_auth_plugin_periodic_event_v1 = (F_flashmq_auth_plugin_periodic_event)loadSymbol(r, "flashmq_auth_plugin_periodic_event", false);
  149 + }
129 150  
130 151 initialized = true;
131 152 }
... ... @@ -136,7 +157,7 @@ void Authentication::loadPlugin(const std::string &pathToSoFile)
136 157 */
137 158 void Authentication::init()
138 159 {
139   - if (!useExternalPlugin)
  160 + if (pluginVersion == PluginVersion::None)
140 161 return;
141 162  
142 163 UnscopedLock lock(initMutex);
... ... @@ -146,23 +167,46 @@ void Authentication::init()
146 167 if (quitting)
147 168 return;
148 169  
149   - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
150   - int result = init_v2(&pluginData, authOpts.head(), authOpts.size());
151   - if (result != 0)
152   - throw FatalError("Error initialising auth plugin.");
  170 + if (pluginVersion == PluginVersion::MosquittoV2)
  171 + {
  172 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
  173 + int result = init_v2(&pluginData, authOpts.head(), authOpts.size());
  174 + if (result != 0)
  175 + throw FatalError("Error initialising auth plugin.");
  176 + }
  177 + else if (pluginVersion == PluginVersion::FlashMQv1)
  178 + {
  179 + std::unordered_map<std::string, std::string> &authOpts = settings.getFlashmqAuthPluginOpts();
  180 + flashmq_auth_plugin_allocate_thread_memory_v1(&pluginData, authOpts);
  181 + }
153 182 }
154 183  
155 184 void Authentication::cleanup()
156 185 {
157   - if (!cleanup_v2)
  186 + if (pluginVersion == PluginVersion::None)
158 187 return;
159 188  
160 189 securityCleanup(false);
161 190  
162   - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
163   - int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size());
164   - if (result != 0)
165   - logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway.
  191 + if (pluginVersion == PluginVersion::MosquittoV2)
  192 + {
  193 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
  194 + int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size());
  195 + if (result != 0)
  196 + logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway.
  197 + }
  198 + else if (pluginVersion == PluginVersion::FlashMQv1)
  199 + {
  200 + try
  201 + {
  202 + std::unordered_map<std::string, std::string> &authOpts = settings.getFlashmqAuthPluginOpts();
  203 + flashmq_auth_plugin_deallocate_thread_memory_v1(pluginData, authOpts);
  204 + }
  205 + catch (std::exception &ex)
  206 + {
  207 + logger->logf(LOG_ERR, "Error cleaning up auth plugin: '%s'", ex.what()); // Not doing exception, because we're shutting down anyway.
  208 + }
  209 + }
166 210 }
167 211  
168 212 /**
... ... @@ -171,7 +215,7 @@ void Authentication::cleanup()
171 215 */
172 216 void Authentication::securityInit(bool reloading)
173 217 {
174   - if (!useExternalPlugin)
  218 + if (pluginVersion == PluginVersion::None)
175 219 return;
176 220  
177 221 UnscopedLock lock(initMutex);
... ... @@ -181,31 +225,52 @@ void Authentication::securityInit(bool reloading)
181 225 if (quitting)
182 226 return;
183 227  
184   - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
185   - int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
186   - if (result != 0)
  228 + if (pluginVersion == PluginVersion::MosquittoV2)
  229 + {
  230 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
  231 + int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
  232 + if (result != 0)
  233 + {
  234 + throw AuthPluginException("Plugin function mosquitto_auth_security_init returned an error. If it didn't log anything, we don't know what it was.");
  235 + }
  236 + }
  237 + else if (pluginVersion == PluginVersion::FlashMQv1)
187 238 {
188   - throw AuthPluginException("Plugin function mosquitto_auth_security_init returned an error. If it didn't log anything, we don't know what it was.");
  239 + std::unordered_map<std::string, std::string> &authOpts = settings.getFlashmqAuthPluginOpts();
  240 + flashmq_auth_plugin_init_v1(pluginData, authOpts, reloading);
189 241 }
  242 +
190 243 initialized = true;
  244 +
  245 + periodicEvent();
191 246 }
192 247  
193 248 void Authentication::securityCleanup(bool reloading)
194 249 {
195   - if (!useExternalPlugin)
  250 + if (pluginVersion == PluginVersion::None)
196 251 return;
197 252  
198 253 initialized = false;
199   - AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
200   - int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
201 254  
202   - if (result != 0)
  255 + if (pluginVersion == PluginVersion::MosquittoV2)
  256 + {
  257 + AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
  258 + int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading);
  259 +
  260 + if (result != 0)
  261 + {
  262 + throw AuthPluginException("Plugin function mosquitto_auth_security_cleanup returned an error. If it didn't log anything, we don't know what it was.");
  263 + }
  264 + }
  265 + else if (pluginVersion == PluginVersion::FlashMQv1)
203 266 {
204   - throw AuthPluginException("Plugin function mosquitto_auth_security_cleanup returned an error. If it didn't log anything, we don't know what it was.");
  267 + std::unordered_map<std::string, std::string> &authOpts = settings.getFlashmqAuthPluginOpts();
  268 + flashmq_auth_plugin_deinit_v1(pluginData, authOpts, reloading);
205 269 }
206 270 }
207 271  
208   -AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, AclAccess access)
  272 +AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics,
  273 + AclAccess access, char qos, bool retain)
209 274 {
210 275 assert(subtopics.size() > 0);
211 276  
... ... @@ -214,7 +279,7 @@ AuthResult Authentication::aclCheck(const std::string &amp;clientid, const std::stri
214 279 if (firstResult != AuthResult::success)
215 280 return firstResult;
216 281  
217   - if (!useExternalPlugin)
  282 + if (pluginVersion == PluginVersion::None)
218 283 return firstResult;
219 284  
220 285 if (!initialized)
... ... @@ -227,15 +292,33 @@ AuthResult Authentication::aclCheck(const std::string &amp;clientid, const std::stri
227 292 if (settings.authPluginSerializeAuthChecks)
228 293 lock.lock();
229 294  
230   - int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast<int>(access));
231   - AuthResult result_ = static_cast<AuthResult>(result);
  295 + if (pluginVersion == PluginVersion::MosquittoV2)
  296 + {
  297 + int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast<int>(access));
  298 + AuthResult result_ = static_cast<AuthResult>(result);
232 299  
233   - if (result_ == AuthResult::error)
  300 + if (result_ == AuthResult::error)
  301 + {
  302 + logger->logf(LOG_ERR, "ACL check by plugin returned error for topic '%s'. If it didn't log anything, we don't know what it was.", topic.c_str());
  303 + }
  304 + }
  305 + else if (pluginVersion == PluginVersion::FlashMQv1)
234 306 {
235   - logger->logf(LOG_ERR, "ACL check by plugin returned error for topic '%s'. If it didn't log anything, we don't know what it was.", topic.c_str());
  307 + // I'm using this try/catch because propagating the exception higher up conflicts with who gets the blame, and then the publisher
  308 + // gets disconnected.
  309 + try
  310 + {
  311 + FlashMQMessage msg(topic, subtopics, qos, retain);
  312 + return flashmq_auth_plugin_acl_check_v1(pluginData, access, clientid, username, msg);
  313 + }
  314 + catch (std::exception &ex)
  315 + {
  316 + logger->logf(LOG_ERR, "Error doing ACL check in plugin: '%s'", ex.what());
  317 + logger->logf(LOG_WARNING, "Throwing exceptions from auth plugin login/ACL checks is slow. There's no need.");
  318 + }
236 319 }
237 320  
238   - return result_;
  321 + return AuthResult::error;
239 322 }
240 323  
241 324 AuthResult Authentication::unPwdCheck(const std::string &username, const std::string &password)
... ... @@ -245,7 +328,7 @@ AuthResult Authentication::unPwdCheck(const std::string &amp;username, const std::st
245 328 if (firstResult != AuthResult::success)
246 329 return firstResult;
247 330  
248   - if (!useExternalPlugin)
  331 + if (pluginVersion == PluginVersion::None)
249 332 return firstResult;
250 333  
251 334 if (!initialized)
... ... @@ -258,15 +341,32 @@ AuthResult Authentication::unPwdCheck(const std::string &amp;username, const std::st
258 341 if (settings.authPluginSerializeAuthChecks)
259 342 lock.lock();
260 343  
261   - int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str());
262   - AuthResult r = static_cast<AuthResult>(result);
  344 + if (pluginVersion == PluginVersion::MosquittoV2)
  345 + {
  346 + int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str());
  347 + AuthResult r = static_cast<AuthResult>(result);
263 348  
264   - if (r == AuthResult::error)
  349 + if (r == AuthResult::error)
  350 + {
  351 + logger->logf(LOG_ERR, "Username+password check by plugin returned error for user '%s'. If it didn't log anything, we don't know what it was.", username.c_str());
  352 + }
  353 + }
  354 + else if (pluginVersion == PluginVersion::FlashMQv1)
265 355 {
266   - logger->logf(LOG_ERR, "Username+password check by plugin returned error for user '%s'. If it didn't log anything, we don't know what it was.", username.c_str());
  356 + // I'm using this try/catch because propagating the exception higher up conflicts with who gets the blame, and then the publisher
  357 + // gets disconnected.
  358 + try
  359 + {
  360 + return flashmq_auth_plugin_login_check_v1(pluginData, username, password);
  361 + }
  362 + catch (std::exception &ex)
  363 + {
  364 + logger->logf(LOG_ERR, "Error doing login check in plugin: '%s'", ex.what());
  365 + logger->logf(LOG_WARNING, "Throwing exceptions from auth plugin login/ACL checks is slow. There's no need.");
  366 + }
267 367 }
268 368  
269   - return r;
  369 + return AuthResult::error;
270 370 }
271 371  
272 372 void Authentication::setQuitting()
... ... @@ -488,6 +588,23 @@ AuthResult Authentication::unPwdCheckFromMosquittoPasswordFile(const std::string
488 588 return result;
489 589 }
490 590  
  591 +void Authentication::periodicEvent()
  592 +{
  593 + if (pluginVersion == PluginVersion::None)
  594 + return;
  595 +
  596 + if (!initialized)
  597 + {
  598 + logger->logf(LOG_ERR, "Auth plugin period event called, but initialization failed or not performed.");
  599 + return;
  600 + }
  601 +
  602 + if (pluginVersion == PluginVersion::FlashMQv1 && flashmq_auth_plugin_periodic_event_v1)
  603 + {
  604 + flashmq_auth_plugin_periodic_event_v1(pluginData);
  605 + }
  606 +}
  607 +
491 608 std::string AuthResultToString(AuthResult r)
492 609 {
493 610 if (r == AuthResult::success)
... ...
authplugin.h
... ... @@ -26,6 +26,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
26 26 #include "logger.h"
27 27 #include "configfileparser.h"
28 28 #include "acltree.h"
  29 +#include "flashmq_plugin.h"
29 30  
30 31 /**
31 32 * @brief The MosquittoPasswordFileEntry struct stores the decoded base64 password salt and hash.
... ... @@ -49,6 +50,7 @@ struct MosquittoPasswordFileEntry
49 50  
50 51 typedef int (*F_auth_plugin_version)(void);
51 52  
  53 +// Mosquitto functions
52 54 typedef int (*F_auth_plugin_init_v2)(void **, struct mosquitto_auth_opt *, int);
53 55 typedef int (*F_auth_plugin_cleanup_v2)(void *, struct mosquitto_auth_opt *, int);
54 56 typedef int (*F_auth_plugin_security_init_v2)(void *, struct mosquitto_auth_opt *, int, bool);
... ... @@ -57,12 +59,29 @@ typedef int (*F_auth_plugin_acl_check_v2)(void *, const char *, const char *, co
57 59 typedef int (*F_auth_plugin_unpwd_check_v2)(void *, const char *, const char *);
58 60 typedef int (*F_auth_plugin_psk_key_get_v2)(void *, const char *, const char *, char *, int);
59 61  
  62 +
  63 +typedef void(*F_flashmq_auth_plugin_allocate_thread_memory_v1)(void **thread_data, std::unordered_map<std::string, std::string> &auth_opts);
  64 +typedef void(*F_flashmq_auth_plugin_deallocate_thread_memory_v1)(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts);
  65 +typedef void(*F_flashmq_auth_plugin_init_v1)(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts, bool reloading);
  66 +typedef void(*F_flashmq_auth_plugin_deinit_v1)(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts, bool reloading);
  67 +typedef AuthResult(*F_flashmq_auth_plugin_acl_check_v1)(void *thread_data, AclAccess access, const std::string &clientid, const std::string &username, const FlashMQMessage &msg);
  68 +typedef AuthResult(*F_flashmq_auth_plugin_login_check_v1)(void *thread_data, const std::string &username, const std::string &password);
  69 +typedef void (*F_flashmq_auth_plugin_periodic_event)(void *thread_data);
  70 +
60 71 extern "C"
61 72 {
62 73 // Gets called by the plugin, so it needs to exist, globally
63 74 void mosquitto_log_printf(int level, const char *fmt, ...);
64 75 }
65 76  
  77 +enum class PluginVersion
  78 +{
  79 + None,
  80 + Determining,
  81 + FlashMQv1,
  82 + MosquittoV2,
  83 +};
  84 +
66 85 std::string AuthResultToString(AuthResult r);
67 86  
68 87 /**
... ... @@ -72,6 +91,8 @@ std::string AuthResultToString(AuthResult r);
72 91 class Authentication
73 92 {
74 93 F_auth_plugin_version version = nullptr;
  94 +
  95 + // Mosquitto functions
75 96 F_auth_plugin_init_v2 init_v2 = nullptr;
76 97 F_auth_plugin_cleanup_v2 cleanup_v2 = nullptr;
77 98 F_auth_plugin_security_init_v2 security_init_v2 = nullptr;
... ... @@ -80,6 +101,14 @@ class Authentication
80 101 F_auth_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr;
81 102 F_auth_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr;
82 103  
  104 + F_flashmq_auth_plugin_allocate_thread_memory_v1 flashmq_auth_plugin_allocate_thread_memory_v1 = nullptr;
  105 + F_flashmq_auth_plugin_deallocate_thread_memory_v1 flashmq_auth_plugin_deallocate_thread_memory_v1 = nullptr;
  106 + F_flashmq_auth_plugin_init_v1 flashmq_auth_plugin_init_v1 = nullptr;
  107 + F_flashmq_auth_plugin_deinit_v1 flashmq_auth_plugin_deinit_v1 = nullptr;
  108 + F_flashmq_auth_plugin_acl_check_v1 flashmq_auth_plugin_acl_check_v1 = nullptr;
  109 + F_flashmq_auth_plugin_login_check_v1 flashmq_auth_plugin_login_check_v1 = nullptr;
  110 + F_flashmq_auth_plugin_periodic_event flashmq_auth_plugin_periodic_event_v1 = nullptr;
  111 +
83 112 static std::mutex initMutex;
84 113 static std::mutex authChecksMutex;
85 114  
... ... @@ -88,7 +117,7 @@ class Authentication
88 117 void *pluginData = nullptr;
89 118 Logger *logger = nullptr;
90 119 bool initialized = false;
91   - bool useExternalPlugin = false;
  120 + PluginVersion pluginVersion = PluginVersion::None;
92 121 bool quitting = false;
93 122  
94 123 /**
... ... @@ -110,7 +139,7 @@ class Authentication
110 139  
111 140 AclTree aclTree;
112 141  
113   - void *loadSymbol(void *handle, const char *symbol) const;
  142 + void *loadSymbol(void *handle, const char *symbol, bool exceptionOnError = true) const;
114 143 public:
115 144 Authentication(Settings &settings);
116 145 Authentication(const Authentication &other) = delete;
... ... @@ -122,7 +151,8 @@ public:
122 151 void cleanup();
123 152 void securityInit(bool reloading);
124 153 void securityCleanup(bool reloading);
125   - AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, AclAccess access);
  154 + AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics,
  155 + AclAccess access, char qos, bool retain);
126 156 AuthResult unPwdCheck(const std::string &username, const std::string &password);
127 157  
128 158 void setQuitting();
... ... @@ -131,6 +161,8 @@ public:
131 161 AuthResult aclCheckFromMosquittoAclFile(const std::string &clientid, const std::string &username, const std::vector<std::string> &subtopics, AclAccess access);
132 162 AuthResult unPwdCheckFromMosquittoPasswordFile(const std::string &username, const std::string &password);
133 163  
  164 + void periodicEvent();
  165 +
134 166 };
135 167  
136 168 #endif // AUTHPLUGIN_H
... ...
client.cpp
... ... @@ -91,6 +91,11 @@ bool Client::getSslWriteWantsRead() const
91 91 return ioWrapper.getSslWriteWantsRead();
92 92 }
93 93  
  94 +ProtocolVersion Client::getProtocolVersion() const
  95 +{
  96 + return protocolVersion;
  97 +}
  98 +
94 99 void Client::startOrContinueSslAccept()
95 100 {
96 101 ioWrapper.startOrContinueSslAccept();
... ...
client.h
... ... @@ -99,6 +99,7 @@ public:
99 99 bool isSsl() const;
100 100 bool getSslReadWantsWrite() const;
101 101 bool getSslWriteWantsRead() const;
  102 + ProtocolVersion getProtocolVersion() const;
102 103  
103 104 void startOrContinueSslAccept();
104 105 void markAsDisconnecting();
... ...
configfileparser.cpp
... ... @@ -92,11 +92,12 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) :
92 92 path(path)
93 93 {
94 94 validKeys.insert("auth_plugin");
  95 + validKeys.insert("auth_plugin_serialize_init");
  96 + validKeys.insert("auth_plugin_serialize_auth_checks");
  97 + validKeys.insert("auth_plugin_timer_period");
95 98 validKeys.insert("log_file");
96 99 validKeys.insert("allow_unsafe_clientid_chars");
97 100 validKeys.insert("allow_unsafe_username_chars");
98   - validKeys.insert("auth_plugin_serialize_init");
99   - validKeys.insert("auth_plugin_serialize_auth_checks");
100 101 validKeys.insert("client_initial_buffer_size");
101 102 validKeys.insert("max_packet_size");
102 103 validKeys.insert("log_debug");
... ... @@ -401,6 +402,16 @@ void ConfigFileParser::loadFile(bool test)
401 402 }
402 403 tmpSettings->expireSessionsAfterSeconds = newVal;
403 404 }
  405 +
  406 + if (key == "auth_plugin_timer_period")
  407 + {
  408 + int newVal = std::stoi(value);
  409 + if (newVal < 0)
  410 + {
  411 + throw ConfigFileException(formatString("auth_plugin_timer_period value '%d' is invalid. Valid values are 0 or higher. 0 means disabled.", newVal));
  412 + }
  413 + tmpSettings->authPluginTimerPeriod = newVal;
  414 + }
404 415 }
405 416 }
406 417 catch (std::invalid_argument &ex) // catch for the stoi()
... ... @@ -410,6 +421,7 @@ void ConfigFileParser::loadFile(bool test)
410 421 }
411 422  
412 423 tmpSettings->authOptCompatWrap = AuthOptCompatWrap(authOpts);
  424 + tmpSettings->flashmqAuthPluginOpts = std::move(authOpts);
413 425  
414 426 if (!test)
415 427 {
... ... @@ -417,9 +429,5 @@ void ConfigFileParser::loadFile(bool test)
417 429 }
418 430 }
419 431  
420   -AuthOptCompatWrap &Settings::getAuthOptsCompat()
421   -{
422   - return authOptCompatWrap;
423   -}
424 432  
425 433  
... ...
1 1 #ifndef ENUMS_H
2 2 #define ENUMS_H
3 3  
4   -// Compatible with Mosquitto
5   -enum class AclAccess
6   -{
7   - none = 0,
8   - read = 1,
9   - write = 2
10   -};
11   -
12   -// Compatible with Mosquitto
13   -enum class AuthResult
14   -{
15   - success = 0,
16   - acl_denied = 12,
17   - login_denied = 11,
18   - error = 13
19   -};
  4 +#include "flashmq_plugin.h"
20 5  
21 6 #endif // ENUMS_H
... ...
flashmq_plugin.cpp 0 โ†’ 100644
  1 +#include "flashmq_plugin.h"
  2 +
  3 +#include "logger.h"
  4 +
  5 +void flashmq_logf(int level, const char *str, ...)
  6 +{
  7 + Logger *logger = Logger::getInstance();
  8 +
  9 + va_list valist;
  10 + va_start(valist, str);
  11 + logger->logf(level, str, valist);
  12 + va_end(valist);
  13 +}
  14 +
  15 +FlashMQMessage::FlashMQMessage(const std::string &topic, const std::vector<std::string> &subtopics, const char qos, const bool retain) :
  16 + topic(topic),
  17 + subtopics(subtopics),
  18 + qos(qos),
  19 + retain(retain)
  20 +{
  21 +
  22 +}
... ...
flashmq_plugin.h 0 โ†’ 100644
  1 +/*
  2 + * This file is part of FlashMQ (https://www.flashmq.org). It defines the
  3 + * authentication plugin interface.
  4 + *
  5 + * This interface definition is public domain and you are encouraged
  6 + * to copy it to your authentication plugin project, for portability. Including
  7 + * this file in your project does not require your code to have a compatibile
  8 + * license nor requires you to open source it.
  9 + *
  10 + * Compile like: gcc -fPIC -shared authplugin.cpp -o authplugin.so
  11 + */
  12 +
  13 +#ifndef FLASHMQ_PLUGIN_H
  14 +#define FLASHMQ_PLUGIN_H
  15 +
  16 +#include <string>
  17 +#include <vector>
  18 +#include <unordered_map>
  19 +
  20 +#define FLASHMQ_PLUGIN_VERSION 1
  21 +
  22 +// Compatible with Mosquitto, for auth plugin compatability.
  23 +#define LOG_NONE 0x00
  24 +#define LOG_INFO 0x01
  25 +#define LOG_NOTICE 0x02
  26 +#define LOG_WARNING 0x04
  27 +#define LOG_ERR 0x08
  28 +#define LOG_DEBUG 0x10
  29 +#define LOG_SUBSCRIBE 0x20
  30 +#define LOG_UNSUBSCRIBE 0x40
  31 +
  32 +extern "C"
  33 +{
  34 +
  35 +/**
  36 + * @brief The AclAccess enum's numbers are compatible with Mosquitto's 'int access'.
  37 + *
  38 + * read = reading a publish published by someone else.
  39 + * write = doing a publish.
  40 + * subscribe = subscribing.
  41 + */
  42 +enum class AclAccess
  43 +{
  44 + none = 0,
  45 + read = 1,
  46 + write = 2,
  47 + subscribe = 4
  48 +};
  49 +
  50 +/**
  51 + * @brief The AuthResult enum's numbers are compatible with Mosquitto's auth result.
  52 + */
  53 +enum class AuthResult
  54 +{
  55 + success = 0,
  56 + acl_denied = 12,
  57 + login_denied = 11,
  58 + error = 13
  59 +};
  60 +
  61 +/**
  62 + * @brief The FlashMQMessage struct contains the meta data of a publish.
  63 + *
  64 + * The subtopics is the topic split, so you don't have to do that anymore.
  65 + *
  66 + * As for 'retain', keep in mind that for existing subscribers, this will always be false [MQTT-3.3.1-9]. Only publishes or
  67 + * retain messages as a result of a subscribe can have that set to true.
  68 + *
  69 + * For subscribtions, 'retain' is always false.
  70 + */
  71 +struct FlashMQMessage
  72 +{
  73 + const std::string &topic;
  74 + const std::vector<std::string> &subtopics;
  75 + const char qos;
  76 + const bool retain;
  77 +
  78 + FlashMQMessage(const std::string &topic, const std::vector<std::string> &subtopics, const char qos, const bool retain);
  79 +};
  80 +
  81 +/**
  82 + * @brief flashmq_logf calls the internal logger of FlashMQ. The logger mutexes all access, so is thread-safe.
  83 + * @param level is any of the levels defined above, starting with LOG_.
  84 + * @param str
  85 + *
  86 + * FlashMQ makes no distinction between INFO and NOTICE.
  87 + */
  88 +void flashmq_logf(int level, const char *str, ...);
  89 +
  90 +/**
  91 + * @brief flashmq_plugin_version must return FLASHMQ_PLUGIN_VERSION.
  92 + * @return FLASHMQ_PLUGIN_VERSION.
  93 + */
  94 +int flashmq_auth_plugin_version();
  95 +
  96 +/**
  97 + * @brief flashmq_auth_plugin_allocate_thread_memory is called once by each thread. Never again.
  98 + * @param thread_data. Create a memory structure and assign it to *thread_data.
  99 + * @param global_data. The global data created in flashmq_auth_plugin_allocate_global_memory, if you use it.
  100 + * @param auth_opts. Map of flashmq_auth_opt_* from the config file.
  101 + *
  102 + * Only allocate the plugin's memory here. Don't open connections, etc.
  103 + *
  104 + * The global data is created by flashmq_auth_plugin_allocate_global_memory() and if you need it, you can assign it to your
  105 + * own thread_data storage. It is not passed as argument to other functions.
  106 + *
  107 + * You can use static variables for global scope if you must, but do provide proper locking where necessary.
  108 + *
  109 + * throw an exception on errors.
  110 + */
  111 +void flashmq_auth_plugin_allocate_thread_memory(void **thread_data, std::unordered_map<std::string, std::string> &auth_opts);
  112 +
  113 +/**
  114 + * @brief flashmq_auth_plugin_deallocate_thread_memory is called once by each thread. Never again.
  115 + * @param thread_data. Delete this memory.
  116 + * @param auth_opts. Map of flashmq_auth_opt_* from the config file.
  117 + *
  118 + * throw an exception on errors.
  119 + */
  120 +void flashmq_auth_plugin_deallocate_thread_memory(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts);
  121 +
  122 +/**
  123 + * @brief flashmq_auth_plugin_init is called on thread start and config reload. It is the main place to initialize the plugin.
  124 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  125 + * @param auth_opts. Map of flashmq_auth_opt_* from the config file.
  126 + * @param reloading.
  127 + *
  128 + * The best approach to state keeping is doing everything per thread. You can initialize connections to database servers, load encryption keys,
  129 + * create maps, etc.
  130 + *
  131 + * Keep in mind that libraries you use may not be thread safe (by default). Sometimes they use global scope in treacherous ways. As a random
  132 + * example: Qt's QSqlDatabase needs a unique name for each connection, otherwise it is not thread safe and will crash.
  133 + *
  134 + * There is the option to set 'auth_plugin_serialize_init true' in the config file, which allows some mitigation in
  135 + * case you run into problems.
  136 + *
  137 + * throw an exception on errors.
  138 + */
  139 +void flashmq_auth_plugin_init(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts, bool reloading);
  140 +
  141 +/**
  142 + * @brief flashmq_auth_plugin_deinit is called on thread stop and config reload. It is the precursor to initializing.
  143 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  144 + * @param auth_opts. Map of flashmq_auth_opt_* from the config file.
  145 + * @param reloading
  146 + *
  147 + * throw an exception on errors.
  148 + */
  149 +void flashmq_auth_plugin_deinit(void *thread_data, std::unordered_map<std::string, std::string> &auth_opts, bool reloading);
  150 +
  151 +/**
  152 + * @brief flashmq_auth_plugin_periodic is called every x seconds as defined in the config file.
  153 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  154 + *
  155 + * You may need to periodically refresh data from a database, post stats, etc. You can do that from here. It's queued
  156 + * in each thread at the same time, so you can perform somewhat synchronized events in all threads.
  157 + *
  158 + * Note that it's executed in the event loop, so it blocks the thread if you block here. If you need asynchronous operation,
  159 + * you can make threads yourself. Be sure to synchronize data access properly in that case.
  160 + *
  161 + * The setting auth_plugin_timer_period sets this interval in seconds.
  162 + *
  163 + * Implementing this is optional.
  164 + *
  165 + * throw an exception on errors.
  166 + */
  167 +void flashmq_auth_plugin_periodic_event(void *thread_data);
  168 +
  169 +/**
  170 + * @brief flashmq_auth_plugin_login_check is called on login of a client.
  171 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  172 + * @param username
  173 + * @param password
  174 + * @return
  175 + *
  176 + * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error,
  177 + * because there's nothing else to do: the state of FlashMQ won't change.
  178 + *
  179 + * Note that there is a setting 'auth_plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not
  180 + * thread-safe. It will negate much of FlashMQ's multi-core model.
  181 + */
  182 +AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string &username, const std::string &password);
  183 +
  184 +/**
  185 + * @brief flashmq_auth_plugin_acl_check is called on publish, deliver and subscribe.
  186 + * @param thread_data is memory allocated in flashmq_auth_plugin_allocate_thread_memory().
  187 + * @param access
  188 + * @param clientid
  189 + * @param username
  190 + * @param msg. See FlashMQMessage.
  191 + * @return
  192 + *
  193 + * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error,
  194 + * because there's nothing else to do: the state of FlashMQ won't change.
  195 + *
  196 + * Controlling subscribe access can have several benefits. For instance, you may want to avoid subscriptions that cause
  197 + * a lot of server load. If clients pester you with many subscriptions like '+/+/+/+/+/+/+/+/+/', that causes a lot
  198 + * of tree walking. Similarly, if all clients subscribe to '#' because it's easy, every single message passing through
  199 + * the server will have to be ACL checked for every subscriber.
  200 + *
  201 + * Note that only MQTT 3.1.1 or higher has a 'failed' return code for subscribing, so older clients will see a normal
  202 + * ack and won't know it failed.
  203 + *
  204 + * Note that there is a setting 'auth_plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not
  205 + * thread-safe. It will negate much of FlashMQ's multi-core model.
  206 + */
  207 +AuthResult flashmq_auth_plugin_acl_check(void *thread_data, AclAccess access, const std::string &clientid, const std::string &username, const FlashMQMessage &msg);
  208 +
  209 +}
  210 +
  211 +#endif // FLASHMQ_PLUGIN_H
... ...
logger.h
... ... @@ -22,16 +22,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
22 22 #include <stdarg.h>
23 23 #include <mutex>
24 24  
25   -// Compatible with Mosquitto, for auth plugin compatability.
26   -// Can be OR'ed together.
27   -#define LOG_NONE 0x00
28   -#define LOG_INFO 0x01
29   -#define LOG_NOTICE 0x02
30   -#define LOG_WARNING 0x04
31   -#define LOG_ERR 0x08
32   -#define LOG_DEBUG 0x10
33   -#define LOG_SUBSCRIBE 0x20
34   -#define LOG_UNSUBSCRIBE 0x40
  25 +#include "flashmq_plugin.h"
35 26  
36 27 int logSslError(const char *str, size_t len, void *u);
37 28  
... ...
mainapp.cpp
... ... @@ -196,6 +196,12 @@ MainApp::MainApp(const std::string &amp;configFilePath) :
196 196 auto fPublishStats = std::bind(&MainApp::publishStatsOnDollarTopic, this);
197 197 timer.addCallback(fPublishStats, 10000, "Publish stats on $SYS");
198 198 publishStatsOnDollarTopic();
  199 +
  200 + if (settings->authPluginTimerPeriod > 0)
  201 + {
  202 + auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this);
  203 + timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event.");
  204 + }
199 205 }
200 206  
201 207 MainApp::~MainApp()
... ... @@ -310,6 +316,14 @@ void MainApp::queuePasswordFileReloadAllThreads()
310 316 }
311 317 }
312 318  
  319 +void MainApp::queueAuthPluginPeriodicEventAllThreads()
  320 +{
  321 + for (std::shared_ptr<ThreadData> &thread : threads)
  322 + {
  323 + thread->queueAuthPluginPeriodicEvent();
  324 + }
  325 +}
  326 +
313 327 void MainApp::setFuzzFile(const std::string &fuzzFilePath)
314 328 {
315 329 this->fuzzFilePath = fuzzFilePath;
... ... @@ -505,6 +519,7 @@ void MainApp::start()
505 519 try
506 520 {
507 521 std::vector<MqttPacket> packetQueueIn;
  522 + std::vector<std::string> subtopics;
508 523  
509 524 std::shared_ptr<ThreadData> threaddata(new ThreadData(0, subscriptionStore, settings));
510 525  
... ... @@ -518,10 +533,11 @@ void MainApp::start()
518 533 websocketsubscriber->setAuthenticated(true);
519 534 websocketsubscriber->setFakeUpgraded();
520 535 subscriptionStore->registerClientAndKickExistingOne(websocketsubscriber);
521   - subscriptionStore->addSubscription(websocketsubscriber, "#", 0);
  536 + splitTopic("#", subtopics);
  537 + subscriptionStore->addSubscription(websocketsubscriber, "#", subtopics, 0);
522 538  
523 539 subscriptionStore->registerClientAndKickExistingOne(subscriber);
524   - subscriptionStore->addSubscription(subscriber, "#", 0);
  540 + subscriptionStore->addSubscription(subscriber, "#", subtopics, 0);
525 541  
526 542 if (fuzzWebsockets && strContains(fuzzFilePathLower, "upgrade"))
527 543 {
... ...
mainapp.h
... ... @@ -76,6 +76,7 @@ class MainApp
76 76 void wakeUpThread();
77 77 void queueKeepAliveCheckAtAllThreads();
78 78 void queuePasswordFileReloadAllThreads();
  79 + void queueAuthPluginPeriodicEventAllThreads();
79 80 void setFuzzFile(const std::string &fuzzFilePath);
80 81 void publishStatsOnDollarTopic();
81 82 void publishStat(const std::string &topic, uint64_t n);
... ...
mqttpacket.cpp
... ... @@ -407,6 +407,7 @@ void MqttPacket::handleDisconnect()
407 407  
408 408 void MqttPacket::handleSubscribe()
409 409 {
  410 + this->subtopics = &gSubtopics;
410 411 const char firstByteFirstNibble = (first_byte & 0x0F);
411 412  
412 413 if (firstByteFirstNibble != 2)
... ... @@ -431,9 +432,23 @@ void MqttPacket::handleSubscribe()
431 432 if (qos > 2)
432 433 throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0.");
433 434  
434   - logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str());
435   - sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, qos);
436   - subs_reponse_codes.push_back(qos);
  435 + splitTopic(topic, *subtopics);
  436 + if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success)
  437 + {
  438 + logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str());
  439 + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos);
  440 + subs_reponse_codes.push_back(qos);
  441 + }
  442 + else
  443 + {
  444 + logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribe to '%s' denied or failed.", sender->repr().c_str(), topic.c_str());
  445 +
  446 + // We can't not send an ack, because if there are multiple subscribes, you send fewer acks back, losing sync.
  447 + char return_code = qos;
  448 + if (sender->getProtocolVersion() >= ProtocolVersion::Mqtt311)
  449 + return_code = static_cast<char>(SubAckReturnCodes::Fail);
  450 + subs_reponse_codes.push_back(return_code);
  451 + }
437 452 }
438 453  
439 454 SubAck subAck(packet_id, subs_reponse_codes);
... ... @@ -531,7 +546,7 @@ void MqttPacket::handlePublish()
531 546 }
532 547 }
533 548  
534   - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write) == AuthResult::success)
  549 + if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success)
535 550 {
536 551 if (retain)
537 552 {
... ...
session.cpp
... ... @@ -48,14 +48,20 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
48 48 this->thread = client->getThreadData();
49 49 }
50 50  
51   -void Session::writePacket(const MqttPacket &packet, char max_qos, uint64_t &count)
  51 +/**
  52 + * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session.
  53 + * @param packet
  54 + * @param max_qos
  55 + * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets.
  56 + * @param count. Reference value is updated. It's for statistics.
  57 + */
  58 +void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count)
52 59 {
53 60 assert(max_qos <= 2);
  61 + const char qos = std::min<char>(packet.getQos(), max_qos);
54 62  
55   - if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read) == AuthResult::success)
  63 + if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success)
56 64 {
57   - const char qos = std::min<char>(packet.getQos(), max_qos);
58   -
59 65 if (qos == 0)
60 66 {
61 67 if (!clientDisconnected())
... ...
session.h
... ... @@ -61,7 +61,7 @@ public:
61 61 bool clientDisconnected() const;
62 62 std::shared_ptr<Client> makeSharedClient() const;
63 63 void assignActiveConnection(std::shared_ptr<Client> &client);
64   - void writePacket(const MqttPacket &packet, char max_qos, uint64_t &count);
  64 + void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count);
65 65 void clearQosMessage(uint16_t packet_id);
66 66 uint64_t sendPendingQosMessages();
67 67 void touch(std::chrono::time_point<std::chrono::steady_clock> val);
... ...
settings.cpp
... ... @@ -18,3 +18,12 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
18 18 #include "settings.h"
19 19  
20 20  
  21 +AuthOptCompatWrap &Settings::getAuthOptsCompat()
  22 +{
  23 + return authOptCompatWrap;
  24 +}
  25 +
  26 +std::unordered_map<std::string, std::string> &Settings::getFlashmqAuthPluginOpts()
  27 +{
  28 + return this->flashmqAuthPluginOpts;
  29 +}
... ...
settings.h
... ... @@ -29,6 +29,7 @@ class Settings
29 29 friend class ConfigFileParser;
30 30  
31 31 AuthOptCompatWrap authOptCompatWrap;
  32 + std::unordered_map<std::string, std::string> flashmqAuthPluginOpts;
32 33  
33 34 public:
34 35 // Actual config options with their defaults.
... ... @@ -47,9 +48,11 @@ public:
47 48 bool allowAnonymous = false;
48 49 int rlimitNoFile = 1000000;
49 50 uint64_t expireSessionsAfterSeconds = 1209600;
  51 + int authPluginTimerPeriod = 60;
50 52 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined.
51 53  
52 54 AuthOptCompatWrap &getAuthOptsCompat();
  55 + std::unordered_map<std::string, std::string> &getFlashmqAuthPluginOpts();
53 56 };
54 57  
55 58 #endif // SETTINGS_H
... ...
subscriptionstore.cpp
... ... @@ -84,10 +84,8 @@ SubscriptionStore::SubscriptionStore() :
84 84  
85 85 }
86 86  
87   -void SubscriptionStore::addSubscription(std::shared_ptr<Client> &client, const std::string &topic, char qos)
  87 +void SubscriptionStore::addSubscription(std::shared_ptr<Client> &client, const std::string &topic, const std::vector<std::string> &subtopics, char qos)
88 88 {
89   - const std::list<std::string> subtopics = split(topic, '/');
90   -
91 89 SubscriptionNode *deepestNode = &root;
92 90 if (topic.length() > 0 && topic[0] == '$')
93 91 deepestNode = &rootDollar;
... ... @@ -242,7 +240,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st
242 240 if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect.
243 241 {
244 242 const std::shared_ptr<Session> session = session_weak.lock();
245   - session->writePacket(packet, sub.qos, count);
  243 + session->writePacket(packet, sub.qos, false, count);
246 244 }
247 245 }
248 246 }
... ... @@ -330,7 +328,7 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Ses
330 328  
331 329 if (topicsMatch(subscribe_topic, rm.topic))
332 330 {
333   - ses->writePacket(packet, max_qos, count);
  331 + ses->writePacket(packet, max_qos, true, count);
334 332 }
335 333 }
336 334  
... ...
subscriptionstore.h
... ... @@ -89,7 +89,7 @@ class SubscriptionStore
89 89 public:
90 90 SubscriptionStore();
91 91  
92   - void addSubscription(std::shared_ptr<Client> &client, const std::string &topic, char qos);
  92 + void addSubscription(std::shared_ptr<Client> &client, const std::string &topic, const std::vector<std::string> &subtopics, char qos);
93 93 void removeSubscription(std::shared_ptr<Client> &client, const std::string &topic);
94 94 void registerClientAndKickExistingOne(std::shared_ptr<Client> &client);
95 95 bool sessionPresent(const std::string &clientid);
... ...
threaddata.cpp
... ... @@ -142,6 +142,7 @@ void ThreadData::queueQuit()
142 142 void ThreadData::waitForQuit()
143 143 {
144 144 thread.join();
  145 + authentication.cleanup();
145 146 }
146 147  
147 148 void ThreadData::queuePasswdFileReload()
... ... @@ -210,6 +211,21 @@ uint64_t ThreadData::getSentMessagePerSecond()
210 211 return result;
211 212 }
212 213  
  214 +void ThreadData::queueAuthPluginPeriodicEvent()
  215 +{
  216 + std::lock_guard<std::mutex> locker(taskQueueMutex);
  217 +
  218 + auto f = std::bind(&ThreadData::authPluginPeriodicEvent, this);
  219 + taskQueue.push_front(f);
  220 +
  221 + wakeUpThread();
  222 +}
  223 +
  224 +void ThreadData::authPluginPeriodicEvent()
  225 +{
  226 + authentication.periodicEvent();
  227 +}
  228 +
213 229 // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial?
214 230 void ThreadData::doKeepAliveCheck()
215 231 {
... ... @@ -266,13 +282,9 @@ void ThreadData::reload(std::shared_ptr&lt;Settings&gt; settings)
266 282 authentication.securityCleanup(true);
267 283 authentication.securityInit(true);
268 284 }
269   - catch (AuthPluginException &ex)
270   - {
271   - logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what());
272   - }
273 285 catch (std::exception &ex)
274 286 {
275   - logger->logf(LOG_ERR, "Error reloading: %s.", ex.what());
  287 + logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what());
276 288 }
277 289 }
278 290  
... ...
threaddata.h
... ... @@ -101,6 +101,9 @@ public:
101 101 void incrementSentMessageCount(uint64_t n);
102 102 uint64_t getSentMessageCount() const;
103 103 uint64_t getSentMessagePerSecond();
  104 +
  105 + void queueAuthPluginPeriodicEvent();
  106 + void authPluginPeriodicEvent();
104 107 };
105 108  
106 109 #endif // THREADDATA_H
... ...
types.cpp
... ... @@ -15,6 +15,8 @@ You should have received a copy of the GNU Affero General Public
15 15 License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
16 16 */
17 17  
  18 +#include "cassert"
  19 +
18 20 #include "types.h"
19 21  
20 22 ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) :
... ... @@ -29,6 +31,8 @@ ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) :
29 31 SubAck::SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses) :
30 32 packet_id(packet_id)
31 33 {
  34 + assert(!subs_qos_reponses.empty());
  35 +
32 36 for (char ack_code : subs_qos_reponses)
33 37 {
34 38 responses.push_back(static_cast<SubAckReturnCodes>(ack_code));
... ...