Commit 55108074b66746038666c0147e022142df03f8e7

Authored by Wiebe Cazemier
1 parent e61b7fbf

Evaluate ACL of delayed publishes too

This requires storing the clientid and username in the Publish object.
authplugin.cpp
... ... @@ -278,6 +278,21 @@ void Authentication::securityCleanup(bool reloading)
278 278 }
279 279 }
280 280  
  281 +/**
  282 + * @brief Authentication::aclCheck performs a write ACL check on the incoming publish.
  283 + * @param publishData
  284 + * @return
  285 + */
  286 +AuthResult Authentication::aclCheck(Publish &publishData)
  287 +{
  288 + // Anonymous publishes come from FlashMQ internally, like SYS topics. We need to allow them.
  289 + if (publishData.client_id.empty())
  290 + return AuthResult::success;
  291 +
  292 + return aclCheck(publishData.client_id, publishData.username, publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos,
  293 + publishData.retain, publishData.getUserProperties());
  294 +}
  295 +
281 296 AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics,
282 297 AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties)
283 298 {
... ...
authplugin.h
... ... @@ -156,6 +156,7 @@ public:
156 156 void cleanup();
157 157 void securityInit(bool reloading);
158 158 void securityCleanup(bool reloading);
  159 + AuthResult aclCheck(Publish &publishData);
159 160 AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics,
160 161 AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties);
161 162 AuthResult unPwdCheck(const std::string &username, const std::string &password,
... ...
client.cpp
... ... @@ -645,6 +645,8 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str
645 645 void Client::setWill(WillPublish &&willPublish)
646 646 {
647 647 this->willPublish = std::make_shared<WillPublish>(std::move(willPublish));
  648 + this->willPublish->client_id = this->clientid;
  649 + this->willPublish->username = this->username;
648 650 }
649 651  
650 652 void Client::assignSession(std::shared_ptr<Session> &session)
... ...
mqttpacket.cpp
... ... @@ -138,6 +138,9 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &amp;_publish)
138 138  
139 139 this->protocolVersion = protocolVersion;
140 140  
  141 + this->publishData.client_id = _publish.client_id;
  142 + this->publishData.username = _publish.username;
  143 +
141 144 if (!_publish.skipTopic)
142 145 this->publishData.topic = _publish.topic;
143 146  
... ... @@ -1136,6 +1139,9 @@ void MqttPacket::parsePublishData()
1136 1139 if (publishData.qos == 0 && duplicate)
1137 1140 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket);
1138 1141  
  1142 + publishData.username = sender->getUsername();
  1143 + publishData.client_id = sender->getClientId();
  1144 +
1139 1145 publishData.topic = readBytesToString(true, true);
1140 1146  
1141 1147 if (publishData.qos)
... ... @@ -1262,7 +1268,7 @@ void MqttPacket::handlePublish()
1262 1268 if (publishData.qos == 2)
1263 1269 sender->getSession()->addIncomingQoS2MessageId(_packet_id);
1264 1270  
1265   - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos, publishData.retain, getUserProperties()) == AuthResult::success)
  1271 + if (authentication.aclCheck(this->publishData) == AuthResult::success)
1266 1272 {
1267 1273 if (publishData.retain)
1268 1274 {
... ...
session.cpp
... ... @@ -226,6 +226,8 @@ bool Session::clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds)
226 226 */
227 227 void Session::sendAllPendingQosData()
228 228 {
  229 + Authentication &authentication = *ThreadGlobals::getAuth();
  230 +
229 231 std::shared_ptr<Client> c = makeSharedClient();
230 232 if (c)
231 233 {
... ... @@ -237,7 +239,7 @@ void Session::sendAllPendingQosData()
237 239 QueuedPublish &queuedPublish = *pos;
238 240 Publish &pub = queuedPublish.getPublish();
239 241  
240   - if (pub.hasExpired())
  242 + if (pub.hasExpired() || (authentication.aclCheck(pub) != AuthResult::success))
241 243 {
242 244 pos = qosPacketQueue.erase(pos);
243 245 continue;
... ...
subscriptionstore.cpp
... ... @@ -297,6 +297,8 @@ std::shared_ptr&lt;Session&gt; SubscriptionStore::lockSession(const std::string &amp;clien
297 297 */
298 298 void SubscriptionStore::sendQueuedWillMessages()
299 299 {
  300 + Authentication &auth = *ThreadGlobals::getAuth();
  301 +
300 302 const auto now = std::chrono::steady_clock::now();
301 303 const std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch());
302 304 std::lock_guard<std::mutex> locker(this->pendingWillsMutex);
... ... @@ -325,11 +327,14 @@ void SubscriptionStore::sendQueuedWillMessages()
325 327 if (s && !s->hasActiveClient())
326 328 {
327 329 logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() );
328   - PublishCopyFactory factory(p.get());
329   - queuePacketAtSubscribers(factory);
  330 + if (auth.aclCheck(*p) == AuthResult::success)
  331 + {
  332 + PublishCopyFactory factory(p.get());
  333 + queuePacketAtSubscribers(factory);
330 334  
331   - if (p->retain)
332   - setRetainedMessage(*p, p->getSubtopics());
  335 + if (p->retain)
  336 + setRetainedMessage(*p, p->getSubtopics());
  337 + }
333 338  
334 339 s->clearWill();
335 340 }
... ... @@ -352,16 +357,21 @@ void SubscriptionStore::queueWillMessage(const std::shared_ptr&lt;WillPublish&gt; &amp;wil
352 357 if (!willMessage)
353 358 return;
354 359  
  360 + Authentication &auth = *ThreadGlobals::getAuth();
  361 +
355 362 const int delay = forceNow ? 0 : willMessage->will_delay;
356 363 logger->logf(LOG_DEBUG, "Queueing will on topic '%s', with delay %d seconds.", willMessage->topic.c_str(), delay );
357 364  
358 365 if (delay == 0)
359 366 {
360   - PublishCopyFactory factory(willMessage.get());
361   - queuePacketAtSubscribers(factory);
  367 + if (auth.aclCheck(*willMessage) == AuthResult::success)
  368 + {
  369 + PublishCopyFactory factory(willMessage.get());
  370 + queuePacketAtSubscribers(factory);
362 371  
363   - if (willMessage->retain)
364   - setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics());
  372 + if (willMessage->retain)
  373 + setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics());
  374 + }
365 375  
366 376 // Avoid sending two immediate wills when a session is destroyed with the client disconnect.
367 377 if (session) // session is null when you're destroying a client before a session is assigned.
... ... @@ -466,18 +476,23 @@ void SubscriptionStore::queuePacketAtSubscribers(PublishCopyFactory &amp;copyFactory
466 476  
467 477 void SubscriptionStore::giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it,
468 478 std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node,
469   - bool poundMode, std::forward_list<Publish> &packetList) const
  479 + bool poundMode, std::forward_list<Publish> &packetList)
470 480 {
471 481 if (cur_subtopic_it == end)
472 482 {
  483 + Authentication &auth = *ThreadGlobals::getAuth();
  484 +
473 485 auto pos = this_node->retainedMessages.begin();
474 486 while (pos != this_node->retainedMessages.end())
475 487 {
476 488 auto cur = pos++;
477   - if (cur->publish.hasExpired())
  489 +
  490 + Publish publish = cur->publish;
  491 +
  492 + if (publish.hasExpired())
478 493 this_node->retainedMessages.erase(cur);
479   - else
480   - packetList.emplace_front(cur->publish); // TODO: hmm, const stuff forces me/it to make copy
  494 + else if (auth.aclCheck(publish) == AuthResult::success)
  495 + packetList.push_front(std::move(publish));
481 496 }
482 497 if (poundMode)
483 498 {
... ...
subscriptionstore.h
... ... @@ -129,9 +129,9 @@ class SubscriptionStore
129 129 std::forward_list<ReceivingSubscriber> &targetSessions);
130 130 static void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
131 131 SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions);
132   - void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it,
  132 + static void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it,
133 133 std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, bool poundMode,
134   - std::forward_list<Publish> &packetList) const;
  134 + std::forward_list<Publish> &packetList);
135 135 void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const;
136 136 void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
137 137 std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const;
... ...
... ... @@ -199,6 +199,8 @@ class PublishBase
199 199 std::chrono::seconds expiresAfter;
200 200  
201 201 public:
  202 + std::string client_id;
  203 + std::string username;
202 204 std::string topic;
203 205 std::string payload;
204 206 char qos = 0;
... ...