diff --git a/authplugin.cpp b/authplugin.cpp index 4906824..f13ec03 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -181,4 +181,18 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string return r; } +std::string AuthResultToString(AuthResult r) +{ + { + if (r == AuthResult::success) + return "success"; + if (r == AuthResult::acl_denied) + return "ACL denied"; + if (r == AuthResult::login_denied) + return "login Denied"; + if (r == AuthResult::error) + return "error in check"; + } + return ""; +} diff --git a/authplugin.h b/authplugin.h index 2c0ff40..329d194 100644 --- a/authplugin.h +++ b/authplugin.h @@ -40,6 +40,9 @@ extern "C" void mosquitto_log_printf(int level, const char *fmt, ...); } +std::string AuthResultToString(AuthResult r); + + class AuthPlugin { F_auth_plugin_version version = nullptr; diff --git a/client.h b/client.h index 49f2068..8c6aa98 100644 --- a/client.h +++ b/client.h @@ -93,6 +93,7 @@ public: bool hasConnectPacketSeen() { return connectPacketSeen; } ThreadData_p getThreadData() { return threadData; } std::string &getClientId() { return this->clientid; } + const std::string &getUsername() const { return this->username; } bool getCleanSession() { return cleanSession; } void assignSession(std::shared_ptr &session); std::shared_ptr getSession(); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 0c65549..fa58114 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -73,17 +73,19 @@ MqttPacket::MqttPacket(const Publish &publish) : throw ProtocolError("Topic path too long."); } + this->topic = publish.topic; + packetType = PacketType::PUBLISH; this->qos = publish.qos; first_byte = static_cast(packetType) << 4; first_byte |= (publish.qos << 1); first_byte |= (static_cast(publish.retain) & 0b00000001); - char topicLenMSB = (publish.topic.length() & 0xFF00) >> 8; - char topicLenLSB = publish.topic.length() & 0x00FF; + char topicLenMSB = (topic.length() & 0xFF00) >> 8; + char topicLenLSB = topic.length() & 0x00FF; writeByte(topicLenMSB); writeByte(topicLenLSB); - writeBytes(publish.topic.c_str(), publish.topic.length()); + writeBytes(topic.c_str(), topic.length()); if (publish.qos) { @@ -357,7 +359,7 @@ void MqttPacket::handlePublish() if (qos == 0 && dup) throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); - std::string topic(readBytes(variable_header_length), variable_header_length); + topic = std::string(readBytes(variable_header_length), variable_header_length); if (!isValidUtf8(topic)) { @@ -388,20 +390,23 @@ void MqttPacket::handlePublish() sender->writeMqttPacket(response); } - if (retain) + if (sender->getThreadData()->authPlugin.aclCheck(sender->getClientId(), sender->getUsername(), topic, AclAccess::write) == AuthResult::success) { - size_t payload_length = remainingAfterPos(); - std::string payload(readBytes(payload_length), payload_length); + if (retain) + { + size_t payload_length = remainingAfterPos(); + std::string payload(readBytes(payload_length), payload_length); - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos); - } + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos); + } - // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. - // Existing subscribers don't get retain=1. [MQTT-3.3.1-9] - bites[0] &= 0b11110110; + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9] + bites[0] &= 0b11110110; - // For the existing clients, we can just write the same packet back out, with our small alterations. - sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this); + // For the existing clients, we can just write the same packet back out, with our small alterations. + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this); + } } void MqttPacket::handlePubAck() @@ -485,6 +490,11 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const return total; } +const std::string &MqttPacket::getTopic() const +{ + return this->topic; +} + Client_p MqttPacket::getSender() const { diff --git a/mqttpacket.h b/mqttpacket.h index 31cfe0d..14421e9 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -26,6 +26,7 @@ public: class MqttPacket { + std::string topic; std::vector bites; size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. RemainingLength remainingLength; @@ -72,6 +73,7 @@ public: size_t getSizeIncludingNonPresentHeader() const; const std::vector &getBites() const { return bites; } char getQos() const { return qos; } + const std::string &getTopic() const; Client_p getSender() const; void setSender(const Client_p &value); bool containsFixedHeader() const; diff --git a/session.cpp b/session.cpp index ed0eb19..d74b2ce 100644 --- a/session.cpp +++ b/session.cpp @@ -27,43 +27,48 @@ void Session::assignActiveConnection(std::shared_ptr &client) { this->client = client; this->client_id = client->getClientId(); + this->username = client->getUsername(); + this->thread = client->getThreadData(); } void Session::writePacket(const MqttPacket &packet, char max_qos) { - const char qos = std::min(packet.getQos(), max_qos); - - if (qos == 0) - { - if (!clientDisconnected()) - { - Client_p c = makeSharedClient(); - c->writeMqttPacketAndBlameThisClient(packet); - } - } - else if (qos == 1) + if (thread->authPlugin.aclCheck(client_id, username, packet.getTopic(), AclAccess::read) == AuthResult::success) { - std::shared_ptr copyPacket = packet.getCopy(); - std::unique_lock locker(qosQueueMutex); - if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) + const char qos = std::min(packet.getQos(), max_qos); + + if (qos == 0) { - logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); - return; + if (!clientDisconnected()) + { + Client_p c = makeSharedClient(); + c->writeMqttPacketAndBlameThisClient(packet); + } } - const uint16_t pid = nextPacketId++; - copyPacket->setPacketId(pid); - QueuedQosPacket p; - p.packet = copyPacket; - p.id = pid; - qosPacketQueue.push_back(p); - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); - locker.unlock(); - - if (!clientDisconnected()) + else if (qos == 1) { - Client_p c = makeSharedClient(); - c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + std::shared_ptr copyPacket = packet.getCopy(); + std::unique_lock locker(qosQueueMutex); + if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) + { + logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); + return; + } + const uint16_t pid = nextPacketId++; + copyPacket->setPacketId(pid); + QueuedQosPacket p; + p.packet = copyPacket; + p.id = pid; + qosPacketQueue.push_back(p); + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); + locker.unlock(); + + if (!clientDisconnected()) + { + Client_p c = makeSharedClient(); + c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + } } } } diff --git a/session.h b/session.h index 8150b86..b1e304f 100644 --- a/session.h +++ b/session.h @@ -24,7 +24,9 @@ struct QueuedQosPacket class Session { std::weak_ptr client; + ThreadData_p thread; std::string client_id; + std::string username; std::list qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] std::mutex qosQueueMutex; uint16_t nextPacketId = 0;