Commit 55108074b66746038666c0147e022142df03f8e7

Authored by Wiebe Cazemier
1 parent e61b7fbf

Evaluate ACL of delayed publishes too

This requires storing the clientid and username in the Publish object.
authplugin.cpp
@@ -278,6 +278,21 @@ void Authentication::securityCleanup(bool reloading) @@ -278,6 +278,21 @@ void Authentication::securityCleanup(bool reloading)
278 } 278 }
279 } 279 }
280 280
  281 +/**
  282 + * @brief Authentication::aclCheck performs a write ACL check on the incoming publish.
  283 + * @param publishData
  284 + * @return
  285 + */
  286 +AuthResult Authentication::aclCheck(Publish &publishData)
  287 +{
  288 + // Anonymous publishes come from FlashMQ internally, like SYS topics. We need to allow them.
  289 + if (publishData.client_id.empty())
  290 + return AuthResult::success;
  291 +
  292 + return aclCheck(publishData.client_id, publishData.username, publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos,
  293 + publishData.retain, publishData.getUserProperties());
  294 +}
  295 +
281 AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, 296 AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics,
282 AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties) 297 AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties)
283 { 298 {
authplugin.h
@@ -156,6 +156,7 @@ public: @@ -156,6 +156,7 @@ public:
156 void cleanup(); 156 void cleanup();
157 void securityInit(bool reloading); 157 void securityInit(bool reloading);
158 void securityCleanup(bool reloading); 158 void securityCleanup(bool reloading);
  159 + AuthResult aclCheck(Publish &publishData);
159 AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, 160 AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics,
160 AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties); 161 AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties);
161 AuthResult unPwdCheck(const std::string &username, const std::string &password, 162 AuthResult unPwdCheck(const std::string &username, const std::string &password,
client.cpp
@@ -645,6 +645,8 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str @@ -645,6 +645,8 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str
645 void Client::setWill(WillPublish &&willPublish) 645 void Client::setWill(WillPublish &&willPublish)
646 { 646 {
647 this->willPublish = std::make_shared<WillPublish>(std::move(willPublish)); 647 this->willPublish = std::make_shared<WillPublish>(std::move(willPublish));
  648 + this->willPublish->client_id = this->clientid;
  649 + this->willPublish->username = this->username;
648 } 650 }
649 651
650 void Client::assignSession(std::shared_ptr<Session> &session) 652 void Client::assignSession(std::shared_ptr<Session> &session)
mqttpacket.cpp
@@ -138,6 +138,9 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &amp;_publish) @@ -138,6 +138,9 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &amp;_publish)
138 138
139 this->protocolVersion = protocolVersion; 139 this->protocolVersion = protocolVersion;
140 140
  141 + this->publishData.client_id = _publish.client_id;
  142 + this->publishData.username = _publish.username;
  143 +
141 if (!_publish.skipTopic) 144 if (!_publish.skipTopic)
142 this->publishData.topic = _publish.topic; 145 this->publishData.topic = _publish.topic;
143 146
@@ -1136,6 +1139,9 @@ void MqttPacket::parsePublishData() @@ -1136,6 +1139,9 @@ void MqttPacket::parsePublishData()
1136 if (publishData.qos == 0 && duplicate) 1139 if (publishData.qos == 0 && duplicate)
1137 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket); 1140 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket);
1138 1141
  1142 + publishData.username = sender->getUsername();
  1143 + publishData.client_id = sender->getClientId();
  1144 +
1139 publishData.topic = readBytesToString(true, true); 1145 publishData.topic = readBytesToString(true, true);
1140 1146
1141 if (publishData.qos) 1147 if (publishData.qos)
@@ -1262,7 +1268,7 @@ void MqttPacket::handlePublish() @@ -1262,7 +1268,7 @@ void MqttPacket::handlePublish()
1262 if (publishData.qos == 2) 1268 if (publishData.qos == 2)
1263 sender->getSession()->addIncomingQoS2MessageId(_packet_id); 1269 sender->getSession()->addIncomingQoS2MessageId(_packet_id);
1264 1270
1265 - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos, publishData.retain, getUserProperties()) == AuthResult::success) 1271 + if (authentication.aclCheck(this->publishData) == AuthResult::success)
1266 { 1272 {
1267 if (publishData.retain) 1273 if (publishData.retain)
1268 { 1274 {
session.cpp
@@ -226,6 +226,8 @@ bool Session::clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds) @@ -226,6 +226,8 @@ bool Session::clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds)
226 */ 226 */
227 void Session::sendAllPendingQosData() 227 void Session::sendAllPendingQosData()
228 { 228 {
  229 + Authentication &authentication = *ThreadGlobals::getAuth();
  230 +
229 std::shared_ptr<Client> c = makeSharedClient(); 231 std::shared_ptr<Client> c = makeSharedClient();
230 if (c) 232 if (c)
231 { 233 {
@@ -237,7 +239,7 @@ void Session::sendAllPendingQosData() @@ -237,7 +239,7 @@ void Session::sendAllPendingQosData()
237 QueuedPublish &queuedPublish = *pos; 239 QueuedPublish &queuedPublish = *pos;
238 Publish &pub = queuedPublish.getPublish(); 240 Publish &pub = queuedPublish.getPublish();
239 241
240 - if (pub.hasExpired()) 242 + if (pub.hasExpired() || (authentication.aclCheck(pub) != AuthResult::success))
241 { 243 {
242 pos = qosPacketQueue.erase(pos); 244 pos = qosPacketQueue.erase(pos);
243 continue; 245 continue;
subscriptionstore.cpp
@@ -297,6 +297,8 @@ std::shared_ptr&lt;Session&gt; SubscriptionStore::lockSession(const std::string &amp;clien @@ -297,6 +297,8 @@ std::shared_ptr&lt;Session&gt; SubscriptionStore::lockSession(const std::string &amp;clien
297 */ 297 */
298 void SubscriptionStore::sendQueuedWillMessages() 298 void SubscriptionStore::sendQueuedWillMessages()
299 { 299 {
  300 + Authentication &auth = *ThreadGlobals::getAuth();
  301 +
300 const auto now = std::chrono::steady_clock::now(); 302 const auto now = std::chrono::steady_clock::now();
301 const std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch()); 303 const std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch());
302 std::lock_guard<std::mutex> locker(this->pendingWillsMutex); 304 std::lock_guard<std::mutex> locker(this->pendingWillsMutex);
@@ -325,11 +327,14 @@ void SubscriptionStore::sendQueuedWillMessages() @@ -325,11 +327,14 @@ void SubscriptionStore::sendQueuedWillMessages()
325 if (s && !s->hasActiveClient()) 327 if (s && !s->hasActiveClient())
326 { 328 {
327 logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() ); 329 logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() );
328 - PublishCopyFactory factory(p.get());  
329 - queuePacketAtSubscribers(factory); 330 + if (auth.aclCheck(*p) == AuthResult::success)
  331 + {
  332 + PublishCopyFactory factory(p.get());
  333 + queuePacketAtSubscribers(factory);
330 334
331 - if (p->retain)  
332 - setRetainedMessage(*p, p->getSubtopics()); 335 + if (p->retain)
  336 + setRetainedMessage(*p, p->getSubtopics());
  337 + }
333 338
334 s->clearWill(); 339 s->clearWill();
335 } 340 }
@@ -352,16 +357,21 @@ void SubscriptionStore::queueWillMessage(const std::shared_ptr&lt;WillPublish&gt; &amp;wil @@ -352,16 +357,21 @@ void SubscriptionStore::queueWillMessage(const std::shared_ptr&lt;WillPublish&gt; &amp;wil
352 if (!willMessage) 357 if (!willMessage)
353 return; 358 return;
354 359
  360 + Authentication &auth = *ThreadGlobals::getAuth();
  361 +
355 const int delay = forceNow ? 0 : willMessage->will_delay; 362 const int delay = forceNow ? 0 : willMessage->will_delay;
356 logger->logf(LOG_DEBUG, "Queueing will on topic '%s', with delay %d seconds.", willMessage->topic.c_str(), delay ); 363 logger->logf(LOG_DEBUG, "Queueing will on topic '%s', with delay %d seconds.", willMessage->topic.c_str(), delay );
357 364
358 if (delay == 0) 365 if (delay == 0)
359 { 366 {
360 - PublishCopyFactory factory(willMessage.get());  
361 - queuePacketAtSubscribers(factory); 367 + if (auth.aclCheck(*willMessage) == AuthResult::success)
  368 + {
  369 + PublishCopyFactory factory(willMessage.get());
  370 + queuePacketAtSubscribers(factory);
362 371
363 - if (willMessage->retain)  
364 - setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics()); 372 + if (willMessage->retain)
  373 + setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics());
  374 + }
365 375
366 // Avoid sending two immediate wills when a session is destroyed with the client disconnect. 376 // Avoid sending two immediate wills when a session is destroyed with the client disconnect.
367 if (session) // session is null when you're destroying a client before a session is assigned. 377 if (session) // session is null when you're destroying a client before a session is assigned.
@@ -466,18 +476,23 @@ void SubscriptionStore::queuePacketAtSubscribers(PublishCopyFactory &amp;copyFactory @@ -466,18 +476,23 @@ void SubscriptionStore::queuePacketAtSubscribers(PublishCopyFactory &amp;copyFactory
466 476
467 void SubscriptionStore::giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, 477 void SubscriptionStore::giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it,
468 std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, 478 std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node,
469 - bool poundMode, std::forward_list<Publish> &packetList) const 479 + bool poundMode, std::forward_list<Publish> &packetList)
470 { 480 {
471 if (cur_subtopic_it == end) 481 if (cur_subtopic_it == end)
472 { 482 {
  483 + Authentication &auth = *ThreadGlobals::getAuth();
  484 +
473 auto pos = this_node->retainedMessages.begin(); 485 auto pos = this_node->retainedMessages.begin();
474 while (pos != this_node->retainedMessages.end()) 486 while (pos != this_node->retainedMessages.end())
475 { 487 {
476 auto cur = pos++; 488 auto cur = pos++;
477 - if (cur->publish.hasExpired()) 489 +
  490 + Publish publish = cur->publish;
  491 +
  492 + if (publish.hasExpired())
478 this_node->retainedMessages.erase(cur); 493 this_node->retainedMessages.erase(cur);
479 - else  
480 - packetList.emplace_front(cur->publish); // TODO: hmm, const stuff forces me/it to make copy 494 + else if (auth.aclCheck(publish) == AuthResult::success)
  495 + packetList.push_front(std::move(publish));
481 } 496 }
482 if (poundMode) 497 if (poundMode)
483 { 498 {
subscriptionstore.h
@@ -129,9 +129,9 @@ class SubscriptionStore @@ -129,9 +129,9 @@ class SubscriptionStore
129 std::forward_list<ReceivingSubscriber> &targetSessions); 129 std::forward_list<ReceivingSubscriber> &targetSessions);
130 static void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, 130 static void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
131 SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions); 131 SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions);
132 - void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, 132 + static void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it,
133 std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, bool poundMode, 133 std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, bool poundMode,
134 - std::forward_list<Publish> &packetList) const; 134 + std::forward_list<Publish> &packetList);
135 void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const; 135 void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const;
136 void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, 136 void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
137 std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const; 137 std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const;
@@ -199,6 +199,8 @@ class PublishBase @@ -199,6 +199,8 @@ class PublishBase
199 std::chrono::seconds expiresAfter; 199 std::chrono::seconds expiresAfter;
200 200
201 public: 201 public:
  202 + std::string client_id;
  203 + std::string username;
202 std::string topic; 204 std::string topic;
203 std::string payload; 205 std::string payload;
204 char qos = 0; 206 char qos = 0;