Commit 55108074b66746038666c0147e022142df03f8e7
1 parent
e61b7fbf
Evaluate ACL of delayed publishes too
This requires storing the clientid and username in the Publish object.
Showing
8 changed files
with
59 additions
and
16 deletions
authplugin.cpp
| ... | ... | @@ -278,6 +278,21 @@ void Authentication::securityCleanup(bool reloading) |
| 278 | 278 | } |
| 279 | 279 | } |
| 280 | 280 | |
| 281 | +/** | |
| 282 | + * @brief Authentication::aclCheck performs a write ACL check on the incoming publish. | |
| 283 | + * @param publishData | |
| 284 | + * @return | |
| 285 | + */ | |
| 286 | +AuthResult Authentication::aclCheck(Publish &publishData) | |
| 287 | +{ | |
| 288 | + // Anonymous publishes come from FlashMQ internally, like SYS topics. We need to allow them. | |
| 289 | + if (publishData.client_id.empty()) | |
| 290 | + return AuthResult::success; | |
| 291 | + | |
| 292 | + return aclCheck(publishData.client_id, publishData.username, publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos, | |
| 293 | + publishData.retain, publishData.getUserProperties()); | |
| 294 | +} | |
| 295 | + | |
| 281 | 296 | AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, |
| 282 | 297 | AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties) |
| 283 | 298 | { | ... | ... |
authplugin.h
| ... | ... | @@ -156,6 +156,7 @@ public: |
| 156 | 156 | void cleanup(); |
| 157 | 157 | void securityInit(bool reloading); |
| 158 | 158 | void securityCleanup(bool reloading); |
| 159 | + AuthResult aclCheck(Publish &publishData); | |
| 159 | 160 | AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, |
| 160 | 161 | AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties); |
| 161 | 162 | AuthResult unPwdCheck(const std::string &username, const std::string &password, | ... | ... |
client.cpp
| ... | ... | @@ -645,6 +645,8 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str |
| 645 | 645 | void Client::setWill(WillPublish &&willPublish) |
| 646 | 646 | { |
| 647 | 647 | this->willPublish = std::make_shared<WillPublish>(std::move(willPublish)); |
| 648 | + this->willPublish->client_id = this->clientid; | |
| 649 | + this->willPublish->username = this->username; | |
| 648 | 650 | } |
| 649 | 651 | |
| 650 | 652 | void Client::assignSession(std::shared_ptr<Session> &session) | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -138,6 +138,9 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish) |
| 138 | 138 | |
| 139 | 139 | this->protocolVersion = protocolVersion; |
| 140 | 140 | |
| 141 | + this->publishData.client_id = _publish.client_id; | |
| 142 | + this->publishData.username = _publish.username; | |
| 143 | + | |
| 141 | 144 | if (!_publish.skipTopic) |
| 142 | 145 | this->publishData.topic = _publish.topic; |
| 143 | 146 | |
| ... | ... | @@ -1136,6 +1139,9 @@ void MqttPacket::parsePublishData() |
| 1136 | 1139 | if (publishData.qos == 0 && duplicate) |
| 1137 | 1140 | throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket); |
| 1138 | 1141 | |
| 1142 | + publishData.username = sender->getUsername(); | |
| 1143 | + publishData.client_id = sender->getClientId(); | |
| 1144 | + | |
| 1139 | 1145 | publishData.topic = readBytesToString(true, true); |
| 1140 | 1146 | |
| 1141 | 1147 | if (publishData.qos) |
| ... | ... | @@ -1262,7 +1268,7 @@ void MqttPacket::handlePublish() |
| 1262 | 1268 | if (publishData.qos == 2) |
| 1263 | 1269 | sender->getSession()->addIncomingQoS2MessageId(_packet_id); |
| 1264 | 1270 | |
| 1265 | - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos, publishData.retain, getUserProperties()) == AuthResult::success) | |
| 1271 | + if (authentication.aclCheck(this->publishData) == AuthResult::success) | |
| 1266 | 1272 | { |
| 1267 | 1273 | if (publishData.retain) |
| 1268 | 1274 | { | ... | ... |
session.cpp
| ... | ... | @@ -226,6 +226,8 @@ bool Session::clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds) |
| 226 | 226 | */ |
| 227 | 227 | void Session::sendAllPendingQosData() |
| 228 | 228 | { |
| 229 | + Authentication &authentication = *ThreadGlobals::getAuth(); | |
| 230 | + | |
| 229 | 231 | std::shared_ptr<Client> c = makeSharedClient(); |
| 230 | 232 | if (c) |
| 231 | 233 | { |
| ... | ... | @@ -237,7 +239,7 @@ void Session::sendAllPendingQosData() |
| 237 | 239 | QueuedPublish &queuedPublish = *pos; |
| 238 | 240 | Publish &pub = queuedPublish.getPublish(); |
| 239 | 241 | |
| 240 | - if (pub.hasExpired()) | |
| 242 | + if (pub.hasExpired() || (authentication.aclCheck(pub) != AuthResult::success)) | |
| 241 | 243 | { |
| 242 | 244 | pos = qosPacketQueue.erase(pos); |
| 243 | 245 | continue; | ... | ... |
subscriptionstore.cpp
| ... | ... | @@ -297,6 +297,8 @@ std::shared_ptr<Session> SubscriptionStore::lockSession(const std::string &clien |
| 297 | 297 | */ |
| 298 | 298 | void SubscriptionStore::sendQueuedWillMessages() |
| 299 | 299 | { |
| 300 | + Authentication &auth = *ThreadGlobals::getAuth(); | |
| 301 | + | |
| 300 | 302 | const auto now = std::chrono::steady_clock::now(); |
| 301 | 303 | const std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch()); |
| 302 | 304 | std::lock_guard<std::mutex> locker(this->pendingWillsMutex); |
| ... | ... | @@ -325,11 +327,14 @@ void SubscriptionStore::sendQueuedWillMessages() |
| 325 | 327 | if (s && !s->hasActiveClient()) |
| 326 | 328 | { |
| 327 | 329 | logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() ); |
| 328 | - PublishCopyFactory factory(p.get()); | |
| 329 | - queuePacketAtSubscribers(factory); | |
| 330 | + if (auth.aclCheck(*p) == AuthResult::success) | |
| 331 | + { | |
| 332 | + PublishCopyFactory factory(p.get()); | |
| 333 | + queuePacketAtSubscribers(factory); | |
| 330 | 334 | |
| 331 | - if (p->retain) | |
| 332 | - setRetainedMessage(*p, p->getSubtopics()); | |
| 335 | + if (p->retain) | |
| 336 | + setRetainedMessage(*p, p->getSubtopics()); | |
| 337 | + } | |
| 333 | 338 | |
| 334 | 339 | s->clearWill(); |
| 335 | 340 | } |
| ... | ... | @@ -352,16 +357,21 @@ void SubscriptionStore::queueWillMessage(const std::shared_ptr<WillPublish> &wil |
| 352 | 357 | if (!willMessage) |
| 353 | 358 | return; |
| 354 | 359 | |
| 360 | + Authentication &auth = *ThreadGlobals::getAuth(); | |
| 361 | + | |
| 355 | 362 | const int delay = forceNow ? 0 : willMessage->will_delay; |
| 356 | 363 | logger->logf(LOG_DEBUG, "Queueing will on topic '%s', with delay %d seconds.", willMessage->topic.c_str(), delay ); |
| 357 | 364 | |
| 358 | 365 | if (delay == 0) |
| 359 | 366 | { |
| 360 | - PublishCopyFactory factory(willMessage.get()); | |
| 361 | - queuePacketAtSubscribers(factory); | |
| 367 | + if (auth.aclCheck(*willMessage) == AuthResult::success) | |
| 368 | + { | |
| 369 | + PublishCopyFactory factory(willMessage.get()); | |
| 370 | + queuePacketAtSubscribers(factory); | |
| 362 | 371 | |
| 363 | - if (willMessage->retain) | |
| 364 | - setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics()); | |
| 372 | + if (willMessage->retain) | |
| 373 | + setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics()); | |
| 374 | + } | |
| 365 | 375 | |
| 366 | 376 | // Avoid sending two immediate wills when a session is destroyed with the client disconnect. |
| 367 | 377 | if (session) // session is null when you're destroying a client before a session is assigned. |
| ... | ... | @@ -466,18 +476,23 @@ void SubscriptionStore::queuePacketAtSubscribers(PublishCopyFactory &copyFactory |
| 466 | 476 | |
| 467 | 477 | void SubscriptionStore::giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, |
| 468 | 478 | std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, |
| 469 | - bool poundMode, std::forward_list<Publish> &packetList) const | |
| 479 | + bool poundMode, std::forward_list<Publish> &packetList) | |
| 470 | 480 | { |
| 471 | 481 | if (cur_subtopic_it == end) |
| 472 | 482 | { |
| 483 | + Authentication &auth = *ThreadGlobals::getAuth(); | |
| 484 | + | |
| 473 | 485 | auto pos = this_node->retainedMessages.begin(); |
| 474 | 486 | while (pos != this_node->retainedMessages.end()) |
| 475 | 487 | { |
| 476 | 488 | auto cur = pos++; |
| 477 | - if (cur->publish.hasExpired()) | |
| 489 | + | |
| 490 | + Publish publish = cur->publish; | |
| 491 | + | |
| 492 | + if (publish.hasExpired()) | |
| 478 | 493 | this_node->retainedMessages.erase(cur); |
| 479 | - else | |
| 480 | - packetList.emplace_front(cur->publish); // TODO: hmm, const stuff forces me/it to make copy | |
| 494 | + else if (auth.aclCheck(publish) == AuthResult::success) | |
| 495 | + packetList.push_front(std::move(publish)); | |
| 481 | 496 | } |
| 482 | 497 | if (poundMode) |
| 483 | 498 | { | ... | ... |
subscriptionstore.h
| ... | ... | @@ -129,9 +129,9 @@ class SubscriptionStore |
| 129 | 129 | std::forward_list<ReceivingSubscriber> &targetSessions); |
| 130 | 130 | static void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, |
| 131 | 131 | SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions); |
| 132 | - void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, | |
| 132 | + static void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, | |
| 133 | 133 | std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, bool poundMode, |
| 134 | - std::forward_list<Publish> &packetList) const; | |
| 134 | + std::forward_list<Publish> &packetList); | |
| 135 | 135 | void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const; |
| 136 | 136 | void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, |
| 137 | 137 | std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const; | ... | ... |