diff --git a/authplugin.cpp b/authplugin.cpp index dbeea39..63603d1 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -278,6 +278,21 @@ void Authentication::securityCleanup(bool reloading) } } +/** + * @brief Authentication::aclCheck performs a write ACL check on the incoming publish. + * @param publishData + * @return + */ +AuthResult Authentication::aclCheck(Publish &publishData) +{ + // Anonymous publishes come from FlashMQ internally, like SYS topics. We need to allow them. + if (publishData.client_id.empty()) + return AuthResult::success; + + return aclCheck(publishData.client_id, publishData.username, publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos, + publishData.retain, publishData.getUserProperties()); +} + AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, AclAccess access, char qos, bool retain, const std::vector> *userProperties) { diff --git a/authplugin.h b/authplugin.h index e89927e..479fed3 100644 --- a/authplugin.h +++ b/authplugin.h @@ -156,6 +156,7 @@ public: void cleanup(); void securityInit(bool reloading); void securityCleanup(bool reloading); + AuthResult aclCheck(Publish &publishData); AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, AclAccess access, char qos, bool retain, const std::vector> *userProperties); AuthResult unPwdCheck(const std::string &username, const std::string &password, diff --git a/client.cpp b/client.cpp index 1e6df1b..3a42fde 100644 --- a/client.cpp +++ b/client.cpp @@ -645,6 +645,8 @@ void Client::setClientProperties(ProtocolVersion protocolVersion, const std::str void Client::setWill(WillPublish &&willPublish) { this->willPublish = std::make_shared(std::move(willPublish)); + this->willPublish->client_id = this->clientid; + this->willPublish->username = this->username; } void Client::assignSession(std::shared_ptr &session) diff --git a/mqttpacket.cpp b/mqttpacket.cpp index f285373..6fc731d 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -138,6 +138,9 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish) this->protocolVersion = protocolVersion; + this->publishData.client_id = _publish.client_id; + this->publishData.username = _publish.username; + if (!_publish.skipTopic) this->publishData.topic = _publish.topic; @@ -1136,6 +1139,9 @@ void MqttPacket::parsePublishData() if (publishData.qos == 0 && duplicate) throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket); + publishData.username = sender->getUsername(); + publishData.client_id = sender->getClientId(); + publishData.topic = readBytesToString(true, true); if (publishData.qos) @@ -1262,7 +1268,7 @@ void MqttPacket::handlePublish() if (publishData.qos == 2) sender->getSession()->addIncomingQoS2MessageId(_packet_id); - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.getSubtopics(), AclAccess::write, publishData.qos, publishData.retain, getUserProperties()) == AuthResult::success) + if (authentication.aclCheck(this->publishData) == AuthResult::success) { if (publishData.retain) { diff --git a/session.cpp b/session.cpp index b7770e7..fd22f31 100644 --- a/session.cpp +++ b/session.cpp @@ -226,6 +226,8 @@ bool Session::clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds) */ void Session::sendAllPendingQosData() { + Authentication &authentication = *ThreadGlobals::getAuth(); + std::shared_ptr c = makeSharedClient(); if (c) { @@ -237,7 +239,7 @@ void Session::sendAllPendingQosData() QueuedPublish &queuedPublish = *pos; Publish &pub = queuedPublish.getPublish(); - if (pub.hasExpired()) + if (pub.hasExpired() || (authentication.aclCheck(pub) != AuthResult::success)) { pos = qosPacketQueue.erase(pos); continue; diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index ad7c396..1892b02 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -297,6 +297,8 @@ std::shared_ptr SubscriptionStore::lockSession(const std::string &clien */ void SubscriptionStore::sendQueuedWillMessages() { + Authentication &auth = *ThreadGlobals::getAuth(); + const auto now = std::chrono::steady_clock::now(); const std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast(now.time_since_epoch()); std::lock_guard locker(this->pendingWillsMutex); @@ -325,11 +327,14 @@ void SubscriptionStore::sendQueuedWillMessages() if (s && !s->hasActiveClient()) { logger->logf(LOG_DEBUG, "Sending delayed will on topic '%s'.", p->topic.c_str() ); - PublishCopyFactory factory(p.get()); - queuePacketAtSubscribers(factory); + if (auth.aclCheck(*p) == AuthResult::success) + { + PublishCopyFactory factory(p.get()); + queuePacketAtSubscribers(factory); - if (p->retain) - setRetainedMessage(*p, p->getSubtopics()); + if (p->retain) + setRetainedMessage(*p, p->getSubtopics()); + } s->clearWill(); } @@ -352,16 +357,21 @@ void SubscriptionStore::queueWillMessage(const std::shared_ptr &wil if (!willMessage) return; + Authentication &auth = *ThreadGlobals::getAuth(); + const int delay = forceNow ? 0 : willMessage->will_delay; logger->logf(LOG_DEBUG, "Queueing will on topic '%s', with delay %d seconds.", willMessage->topic.c_str(), delay ); if (delay == 0) { - PublishCopyFactory factory(willMessage.get()); - queuePacketAtSubscribers(factory); + if (auth.aclCheck(*willMessage) == AuthResult::success) + { + PublishCopyFactory factory(willMessage.get()); + queuePacketAtSubscribers(factory); - if (willMessage->retain) - setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics()); + if (willMessage->retain) + setRetainedMessage(*willMessage.get(), (*willMessage).getSubtopics()); + } // Avoid sending two immediate wills when a session is destroyed with the client disconnect. if (session) // session is null when you're destroying a client before a session is assigned. @@ -466,18 +476,23 @@ void SubscriptionStore::queuePacketAtSubscribers(PublishCopyFactory ©Factory void SubscriptionStore::giveClientRetainedMessagesRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, RetainedMessageNode *this_node, - bool poundMode, std::forward_list &packetList) const + bool poundMode, std::forward_list &packetList) { if (cur_subtopic_it == end) { + Authentication &auth = *ThreadGlobals::getAuth(); + auto pos = this_node->retainedMessages.begin(); while (pos != this_node->retainedMessages.end()) { auto cur = pos++; - if (cur->publish.hasExpired()) + + Publish publish = cur->publish; + + if (publish.hasExpired()) this_node->retainedMessages.erase(cur); - else - packetList.emplace_front(cur->publish); // TODO: hmm, const stuff forces me/it to make copy + else if (auth.aclCheck(publish) == AuthResult::success) + packetList.push_front(std::move(publish)); } if (poundMode) { diff --git a/subscriptionstore.h b/subscriptionstore.h index 3584576..ca77405 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -129,9 +129,9 @@ class SubscriptionStore std::forward_list &targetSessions); static void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, SubscriptionNode *this_node, std::forward_list &targetSessions); - void giveClientRetainedMessagesRecursively(std::vector::const_iterator cur_subtopic_it, + static void giveClientRetainedMessagesRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, RetainedMessageNode *this_node, bool poundMode, - std::forward_list &packetList) const; + std::forward_list &packetList); void getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const; void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, std::unordered_map> &outputList) const; diff --git a/types.h b/types.h index 74f39da..ed12918 100644 --- a/types.h +++ b/types.h @@ -199,6 +199,8 @@ class PublishBase std::chrono::seconds expiresAfter; public: + std::string client_id; + std::string username; std::string topic; std::string payload; char qos = 0;