Commit 55108074b66746038666c0147e022142df03f8e7
1 parent
e61b7fbf
Evaluate ACL of delayed publishes too
This requires storing the clientid and username in the Publish object.
Showing
8 changed files
with
59 additions
and
16 deletions
authplugin.cpp
| @@ -278,6 +278,21 @@ void Authentication::securityCleanup(bool reloading) | @@ -278,6 +278,21 @@ void Authentication::securityCleanup(bool reloading) | ||
| 278 | } | 278 | } |
| 279 | } | 279 | } |
| 280 | 280 | ||
| 281 | +/** | ||
| 282 | + * @brief Authentication::aclCheck performs a write ACL check on the incoming publish. | ||
| 283 | + * @param publishData | ||
| 284 | + * @return | ||
| 285 | + */ | ||
| 286 | +AuthResult Authentication::aclCheck(Publish &publishData) | ||
| 287 | +{ | ||
| 288 | + // Anonymous publishes come from FlashMQ internally, like SYS topics. We need to allow them. | ||
| 289 | + if (publishData.client_id.empty()) | ||
| 290 | + return AuthResult::success; | ||
| 291 | + | ||
| 292 | + return aclCheck(publishData.client_id, publishData.username, publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos, | ||
| 293 | + publishData.retain, publishData.getUserProperties()); | ||
| 294 | +} | ||
| 295 | + | ||
| 281 | AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, | 296 | AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, |
| 282 | AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties) | 297 | AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties) |
| 283 | { | 298 | { |
authplugin.h
| @@ -156,6 +156,7 @@ public: | @@ -156,6 +156,7 @@ public: | ||
| 156 | void cleanup(); | 156 | void cleanup(); |
| 157 | void securityInit(bool reloading); | 157 | void securityInit(bool reloading); |
| 158 | void securityCleanup(bool reloading); | 158 | void securityCleanup(bool reloading); |
| 159 | + AuthResult aclCheck(Publish &publishData); | ||
| 159 | AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, | 160 | AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector<std::string> &subtopics, |
| 160 | AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties); | 161 | AclAccess access, char qos, bool retain, const std::vector<std::pair<std::string, std::string>> *userProperties); |
| 161 | AuthResult unPwdCheck(const std::string &username, const std::string &password, | 162 | AuthResult unPwdCheck(const std::string &username, const std::string &password, |
client.cpp
| @@ -645,6 +645,8 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str | @@ -645,6 +645,8 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str | ||
| 645 | void Client::setWill(WillPublish &&willPublish) | 645 | void Client::setWill(WillPublish &&willPublish) |
| 646 | { | 646 | { |
| 647 | this->willPublish = std::make_shared<WillPublish>(std::move(willPublish)); | 647 | this->willPublish = std::make_shared<WillPublish>(std::move(willPublish)); |
| 648 | + this->willPublish->client_id = this->clientid; | ||
| 649 | + this->willPublish->username = this->username; | ||
| 648 | } | 650 | } |
| 649 | 651 | ||
| 650 | void Client::assignSession(std::shared_ptr<Session> &session) | 652 | void Client::assignSession(std::shared_ptr<Session> &session) |
mqttpacket.cpp
| @@ -138,6 +138,9 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish) | @@ -138,6 +138,9 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish) | ||
| 138 | 138 | ||
| 139 | this->protocolVersion = protocolVersion; | 139 | this->protocolVersion = protocolVersion; |
| 140 | 140 | ||
| 141 | + this->publishData.client_id = _publish.client_id; | ||
| 142 | + this->publishData.username = _publish.username; | ||
| 143 | + | ||
| 141 | if (!_publish.skipTopic) | 144 | if (!_publish.skipTopic) |
| 142 | this->publishData.topic = _publish.topic; | 145 | this->publishData.topic = _publish.topic; |
| 143 | 146 | ||
| @@ -1136,6 +1139,9 @@ void MqttPacket::parsePublishData() | @@ -1136,6 +1139,9 @@ void MqttPacket::parsePublishData() | ||
| 1136 | if (publishData.qos == 0 && duplicate) | 1139 | if (publishData.qos == 0 && duplicate) |
| 1137 | throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket); | 1140 | throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket); |
| 1138 | 1141 | ||
| 1142 | + publishData.username = sender->getUsername(); | ||
| 1143 | + publishData.client_id = sender->getClientId(); | ||
| 1144 | + | ||
| 1139 | publishData.topic = readBytesToString(true, true); | 1145 | publishData.topic = readBytesToString(true, true); |
| 1140 | 1146 | ||
| 1141 | if (publishData.qos) | 1147 | if (publishData.qos) |
| @@ -1262,7 +1268,7 @@ void MqttPacket::handlePublish() | @@ -1262,7 +1268,7 @@ void MqttPacket::handlePublish() | ||
| 1262 | if (publishData.qos == 2) | 1268 | if (publishData.qos == 2) |
| 1263 | sender->getSession()->addIncomingQoS2MessageId(_packet_id); | 1269 | sender->getSession()->addIncomingQoS2MessageId(_packet_id); |
| 1264 | 1270 | ||
| 1265 | - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos, publishData.retain, getUserProperties()) == AuthResult::success) | 1271 | + if (authentication.aclCheck(this->publishData) == AuthResult::success) |
| 1266 | { | 1272 | { |
| 1267 | if (publishData.retain) | 1273 | if (publishData.retain) |
| 1268 | { | 1274 | { |
session.cpp
| @@ -226,6 +226,8 @@ bool Session::clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds) | @@ -226,6 +226,8 @@ bool Session::clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds) | ||
| 226 | */ | 226 | */ |
| 227 | void Session::sendAllPendingQosData() | 227 | void Session::sendAllPendingQosData() |
| 228 | { | 228 | { |
| 229 | + Authentication &authentication = *ThreadGlobals::getAuth(); | ||
| 230 | + | ||
| 229 | std::shared_ptr<Client> c = makeSharedClient(); | 231 | std::shared_ptr<Client> c = makeSharedClient(); |
| 230 | if (c) | 232 | if (c) |
| 231 | { | 233 | { |
| @@ -237,7 +239,7 @@ void Session::sendAllPendingQosData() | @@ -237,7 +239,7 @@ void Session::sendAllPendingQosData() | ||
| 237 | QueuedPublish &queuedPublish = *pos; | 239 | QueuedPublish &queuedPublish = *pos; |
| 238 | Publish &pub = queuedPublish.getPublish(); | 240 | Publish &pub = queuedPublish.getPublish(); |
| 239 | 241 | ||
| 240 | - if (pub.hasExpired()) | 242 | + if (pub.hasExpired() || (authentication.aclCheck(pub) != AuthResult::success)) |
| 241 | { | 243 | { |
| 242 | pos = qosPacketQueue.erase(pos); | 244 | pos = qosPacketQueue.erase(pos); |
| 243 | continue; | 245 | continue; |
subscriptionstore.cpp
| @@ -297,6 +297,8 @@ std::shared_ptr<Session> SubscriptionStore::lockSession(const std::string &clien | @@ -297,6 +297,8 @@ std::shared_ptr<Session> SubscriptionStore::lockSession(const std::string &clien | ||
| 297 | */ | 297 | */ |
| 298 | void SubscriptionStore::sendQueuedWillMessages() | 298 | void SubscriptionStore::sendQueuedWillMessages() |
| 299 | { | 299 | { |
| 300 | + Authentication &auth = *ThreadGlobals::getAuth(); | ||
| 301 | + | ||
| 300 | const auto now = std::chrono::steady_clock::now(); | 302 | const auto now = std::chrono::steady_clock::now(); |
| 301 | const std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch()); | 303 | const std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch()); |
| 302 | std::lock_guard<std::mutex> locker(this->pendingWillsMutex); | 304 | std::lock_guard<std::mutex> locker(this->pendingWillsMutex); |
| @@ -325,11 +327,14 @@ void SubscriptionStore::sendQueuedWillMessages() | @@ -325,11 +327,14 @@ void SubscriptionStore::sendQueuedWillMessages() | ||
| 325 | if (s && !s->hasActiveClient()) | 327 | if (s && !s->hasActiveClient()) |
| 326 | { | 328 | { |
| 327 | logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() ); | 329 | logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() ); |
| 328 | - PublishCopyFactory factory(p.get()); | ||
| 329 | - queuePacketAtSubscribers(factory); | 330 | + if (auth.aclCheck(*p) == AuthResult::success) |
| 331 | + { | ||
| 332 | + PublishCopyFactory factory(p.get()); | ||
| 333 | + queuePacketAtSubscribers(factory); | ||
| 330 | 334 | ||
| 331 | - if (p->retain) | ||
| 332 | - setRetainedMessage(*p, p->getSubtopics()); | 335 | + if (p->retain) |
| 336 | + setRetainedMessage(*p, p->getSubtopics()); | ||
| 337 | + } | ||
| 333 | 338 | ||
| 334 | s->clearWill(); | 339 | s->clearWill(); |
| 335 | } | 340 | } |
| @@ -352,16 +357,21 @@ void SubscriptionStore::queueWillMessage(const std::shared_ptr<WillPublish> &wil | @@ -352,16 +357,21 @@ void SubscriptionStore::queueWillMessage(const std::shared_ptr<WillPublish> &wil | ||
| 352 | if (!willMessage) | 357 | if (!willMessage) |
| 353 | return; | 358 | return; |
| 354 | 359 | ||
| 360 | + Authentication &auth = *ThreadGlobals::getAuth(); | ||
| 361 | + | ||
| 355 | const int delay = forceNow ? 0 : willMessage->will_delay; | 362 | const int delay = forceNow ? 0 : willMessage->will_delay; |
| 356 | logger->logf(LOG_DEBUG, "Queueing will on topic '%s', with delay %d seconds.", willMessage->topic.c_str(), delay ); | 363 | logger->logf(LOG_DEBUG, "Queueing will on topic '%s', with delay %d seconds.", willMessage->topic.c_str(), delay ); |
| 357 | 364 | ||
| 358 | if (delay == 0) | 365 | if (delay == 0) |
| 359 | { | 366 | { |
| 360 | - PublishCopyFactory factory(willMessage.get()); | ||
| 361 | - queuePacketAtSubscribers(factory); | 367 | + if (auth.aclCheck(*willMessage) == AuthResult::success) |
| 368 | + { | ||
| 369 | + PublishCopyFactory factory(willMessage.get()); | ||
| 370 | + queuePacketAtSubscribers(factory); | ||
| 362 | 371 | ||
| 363 | - if (willMessage->retain) | ||
| 364 | - setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics()); | 372 | + if (willMessage->retain) |
| 373 | + setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics()); | ||
| 374 | + } | ||
| 365 | 375 | ||
| 366 | // Avoid sending two immediate wills when a session is destroyed with the client disconnect. | 376 | // Avoid sending two immediate wills when a session is destroyed with the client disconnect. |
| 367 | if (session) // session is null when you're destroying a client before a session is assigned. | 377 | if (session) // session is null when you're destroying a client before a session is assigned. |
| @@ -466,18 +476,23 @@ void SubscriptionStore::queuePacketAtSubscribers(PublishCopyFactory &copyFactory | @@ -466,18 +476,23 @@ void SubscriptionStore::queuePacketAtSubscribers(PublishCopyFactory &copyFactory | ||
| 466 | 476 | ||
| 467 | void SubscriptionStore::giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, | 477 | void SubscriptionStore::giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, |
| 468 | std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, | 478 | std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, |
| 469 | - bool poundMode, std::forward_list<Publish> &packetList) const | 479 | + bool poundMode, std::forward_list<Publish> &packetList) |
| 470 | { | 480 | { |
| 471 | if (cur_subtopic_it == end) | 481 | if (cur_subtopic_it == end) |
| 472 | { | 482 | { |
| 483 | + Authentication &auth = *ThreadGlobals::getAuth(); | ||
| 484 | + | ||
| 473 | auto pos = this_node->retainedMessages.begin(); | 485 | auto pos = this_node->retainedMessages.begin(); |
| 474 | while (pos != this_node->retainedMessages.end()) | 486 | while (pos != this_node->retainedMessages.end()) |
| 475 | { | 487 | { |
| 476 | auto cur = pos++; | 488 | auto cur = pos++; |
| 477 | - if (cur->publish.hasExpired()) | 489 | + |
| 490 | + Publish publish = cur->publish; | ||
| 491 | + | ||
| 492 | + if (publish.hasExpired()) | ||
| 478 | this_node->retainedMessages.erase(cur); | 493 | this_node->retainedMessages.erase(cur); |
| 479 | - else | ||
| 480 | - packetList.emplace_front(cur->publish); // TODO: hmm, const stuff forces me/it to make copy | 494 | + else if (auth.aclCheck(publish) == AuthResult::success) |
| 495 | + packetList.push_front(std::move(publish)); | ||
| 481 | } | 496 | } |
| 482 | if (poundMode) | 497 | if (poundMode) |
| 483 | { | 498 | { |
subscriptionstore.h
| @@ -129,9 +129,9 @@ class SubscriptionStore | @@ -129,9 +129,9 @@ class SubscriptionStore | ||
| 129 | std::forward_list<ReceivingSubscriber> &targetSessions); | 129 | std::forward_list<ReceivingSubscriber> &targetSessions); |
| 130 | static void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, | 130 | static void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, |
| 131 | SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions); | 131 | SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions); |
| 132 | - void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, | 132 | + static void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, |
| 133 | std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, bool poundMode, | 133 | std::vector<std::string>::const_iterator end, RetainedMessageNode *this_node, bool poundMode, |
| 134 | - std::forward_list<Publish> &packetList) const; | 134 | + std::forward_list<Publish> &packetList); |
| 135 | void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const; | 135 | void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const; |
| 136 | void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, | 136 | void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, |
| 137 | std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const; | 137 | std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const; |
types.h
| @@ -199,6 +199,8 @@ class PublishBase | @@ -199,6 +199,8 @@ class PublishBase | ||
| 199 | std::chrono::seconds expiresAfter; | 199 | std::chrono::seconds expiresAfter; |
| 200 | 200 | ||
| 201 | public: | 201 | public: |
| 202 | + std::string client_id; | ||
| 203 | + std::string username; | ||
| 202 | std::string topic; | 204 | std::string topic; |
| 203 | std::string payload; | 205 | std::string payload; |
| 204 | char qos = 0; | 206 | char qos = 0; |