Commit be2050824d591219683776c1d5b3936253a5237a
1 parent
92f401f5
Implement ACL checks and improved login checks
Showing
7 changed files
with
80 additions
and
43 deletions
authplugin.cpp
| @@ -181,4 +181,18 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string | @@ -181,4 +181,18 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string | ||
| 181 | return r; | 181 | return r; |
| 182 | } | 182 | } |
| 183 | 183 | ||
| 184 | +std::string AuthResultToString(AuthResult r) | ||
| 185 | +{ | ||
| 186 | + { | ||
| 187 | + if (r == AuthResult::success) | ||
| 188 | + return "success"; | ||
| 189 | + if (r == AuthResult::acl_denied) | ||
| 190 | + return "ACL denied"; | ||
| 191 | + if (r == AuthResult::login_denied) | ||
| 192 | + return "login Denied"; | ||
| 193 | + if (r == AuthResult::error) | ||
| 194 | + return "error in check"; | ||
| 195 | + } | ||
| 184 | 196 | ||
| 197 | + return ""; | ||
| 198 | +} |
authplugin.h
| @@ -40,6 +40,9 @@ extern "C" | @@ -40,6 +40,9 @@ extern "C" | ||
| 40 | void mosquitto_log_printf(int level, const char *fmt, ...); | 40 | void mosquitto_log_printf(int level, const char *fmt, ...); |
| 41 | } | 41 | } |
| 42 | 42 | ||
| 43 | +std::string AuthResultToString(AuthResult r); | ||
| 44 | + | ||
| 45 | + | ||
| 43 | class AuthPlugin | 46 | class AuthPlugin |
| 44 | { | 47 | { |
| 45 | F_auth_plugin_version version = nullptr; | 48 | F_auth_plugin_version version = nullptr; |
client.h
| @@ -93,6 +93,7 @@ public: | @@ -93,6 +93,7 @@ public: | ||
| 93 | bool hasConnectPacketSeen() { return connectPacketSeen; } | 93 | bool hasConnectPacketSeen() { return connectPacketSeen; } |
| 94 | ThreadData_p getThreadData() { return threadData; } | 94 | ThreadData_p getThreadData() { return threadData; } |
| 95 | std::string &getClientId() { return this->clientid; } | 95 | std::string &getClientId() { return this->clientid; } |
| 96 | + const std::string &getUsername() const { return this->username; } | ||
| 96 | bool getCleanSession() { return cleanSession; } | 97 | bool getCleanSession() { return cleanSession; } |
| 97 | void assignSession(std::shared_ptr<Session> &session); | 98 | void assignSession(std::shared_ptr<Session> &session); |
| 98 | std::shared_ptr<Session> getSession(); | 99 | std::shared_ptr<Session> getSession(); |
mqttpacket.cpp
| @@ -73,17 +73,19 @@ MqttPacket::MqttPacket(const Publish &publish) : | @@ -73,17 +73,19 @@ MqttPacket::MqttPacket(const Publish &publish) : | ||
| 73 | throw ProtocolError("Topic path too long."); | 73 | throw ProtocolError("Topic path too long."); |
| 74 | } | 74 | } |
| 75 | 75 | ||
| 76 | + this->topic = publish.topic; | ||
| 77 | + | ||
| 76 | packetType = PacketType::PUBLISH; | 78 | packetType = PacketType::PUBLISH; |
| 77 | this->qos = publish.qos; | 79 | this->qos = publish.qos; |
| 78 | first_byte = static_cast<char>(packetType) << 4; | 80 | first_byte = static_cast<char>(packetType) << 4; |
| 79 | first_byte |= (publish.qos << 1); | 81 | first_byte |= (publish.qos << 1); |
| 80 | first_byte |= (static_cast<char>(publish.retain) & 0b00000001); | 82 | first_byte |= (static_cast<char>(publish.retain) & 0b00000001); |
| 81 | 83 | ||
| 82 | - char topicLenMSB = (publish.topic.length() & 0xFF00) >> 8; | ||
| 83 | - char topicLenLSB = publish.topic.length() & 0x00FF; | 84 | + char topicLenMSB = (topic.length() & 0xFF00) >> 8; |
| 85 | + char topicLenLSB = topic.length() & 0x00FF; | ||
| 84 | writeByte(topicLenMSB); | 86 | writeByte(topicLenMSB); |
| 85 | writeByte(topicLenLSB); | 87 | writeByte(topicLenLSB); |
| 86 | - writeBytes(publish.topic.c_str(), publish.topic.length()); | 88 | + writeBytes(topic.c_str(), topic.length()); |
| 87 | 89 | ||
| 88 | if (publish.qos) | 90 | if (publish.qos) |
| 89 | { | 91 | { |
| @@ -357,7 +359,7 @@ void MqttPacket::handlePublish() | @@ -357,7 +359,7 @@ void MqttPacket::handlePublish() | ||
| 357 | if (qos == 0 && dup) | 359 | if (qos == 0 && dup) |
| 358 | throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); | 360 | throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); |
| 359 | 361 | ||
| 360 | - std::string topic(readBytes(variable_header_length), variable_header_length); | 362 | + topic = std::string(readBytes(variable_header_length), variable_header_length); |
| 361 | 363 | ||
| 362 | if (!isValidUtf8(topic)) | 364 | if (!isValidUtf8(topic)) |
| 363 | { | 365 | { |
| @@ -388,20 +390,23 @@ void MqttPacket::handlePublish() | @@ -388,20 +390,23 @@ void MqttPacket::handlePublish() | ||
| 388 | sender->writeMqttPacket(response); | 390 | sender->writeMqttPacket(response); |
| 389 | } | 391 | } |
| 390 | 392 | ||
| 391 | - if (retain) | 393 | + if (sender->getThreadData()->authPlugin.aclCheck(sender->getClientId(), sender->getUsername(), topic, AclAccess::write) == AuthResult::success) |
| 392 | { | 394 | { |
| 393 | - size_t payload_length = remainingAfterPos(); | ||
| 394 | - std::string payload(readBytes(payload_length), payload_length); | 395 | + if (retain) |
| 396 | + { | ||
| 397 | + size_t payload_length = remainingAfterPos(); | ||
| 398 | + std::string payload(readBytes(payload_length), payload_length); | ||
| 395 | 399 | ||
| 396 | - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos); | ||
| 397 | - } | 400 | + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos); |
| 401 | + } | ||
| 398 | 402 | ||
| 399 | - // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. | ||
| 400 | - // Existing subscribers don't get retain=1. [MQTT-3.3.1-9] | ||
| 401 | - bites[0] &= 0b11110110; | 403 | + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. |
| 404 | + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9] | ||
| 405 | + bites[0] &= 0b11110110; | ||
| 402 | 406 | ||
| 403 | - // For the existing clients, we can just write the same packet back out, with our small alterations. | ||
| 404 | - sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this); | 407 | + // For the existing clients, we can just write the same packet back out, with our small alterations. |
| 408 | + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this); | ||
| 409 | + } | ||
| 405 | } | 410 | } |
| 406 | 411 | ||
| 407 | void MqttPacket::handlePubAck() | 412 | void MqttPacket::handlePubAck() |
| @@ -485,6 +490,11 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const | @@ -485,6 +490,11 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const | ||
| 485 | return total; | 490 | return total; |
| 486 | } | 491 | } |
| 487 | 492 | ||
| 493 | +const std::string &MqttPacket::getTopic() const | ||
| 494 | +{ | ||
| 495 | + return this->topic; | ||
| 496 | +} | ||
| 497 | + | ||
| 488 | 498 | ||
| 489 | Client_p MqttPacket::getSender() const | 499 | Client_p MqttPacket::getSender() const |
| 490 | { | 500 | { |
mqttpacket.h
| @@ -26,6 +26,7 @@ public: | @@ -26,6 +26,7 @@ public: | ||
| 26 | 26 | ||
| 27 | class MqttPacket | 27 | class MqttPacket |
| 28 | { | 28 | { |
| 29 | + std::string topic; | ||
| 29 | std::vector<char> bites; | 30 | std::vector<char> bites; |
| 30 | size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. | 31 | size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. |
| 31 | RemainingLength remainingLength; | 32 | RemainingLength remainingLength; |
| @@ -72,6 +73,7 @@ public: | @@ -72,6 +73,7 @@ public: | ||
| 72 | size_t getSizeIncludingNonPresentHeader() const; | 73 | size_t getSizeIncludingNonPresentHeader() const; |
| 73 | const std::vector<char> &getBites() const { return bites; } | 74 | const std::vector<char> &getBites() const { return bites; } |
| 74 | char getQos() const { return qos; } | 75 | char getQos() const { return qos; } |
| 76 | + const std::string &getTopic() const; | ||
| 75 | Client_p getSender() const; | 77 | Client_p getSender() const; |
| 76 | void setSender(const Client_p &value); | 78 | void setSender(const Client_p &value); |
| 77 | bool containsFixedHeader() const; | 79 | bool containsFixedHeader() const; |
session.cpp
| @@ -27,43 +27,48 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) | @@ -27,43 +27,48 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) | ||
| 27 | { | 27 | { |
| 28 | this->client = client; | 28 | this->client = client; |
| 29 | this->client_id = client->getClientId(); | 29 | this->client_id = client->getClientId(); |
| 30 | + this->username = client->getUsername(); | ||
| 31 | + this->thread = client->getThreadData(); | ||
| 30 | } | 32 | } |
| 31 | 33 | ||
| 32 | void Session::writePacket(const MqttPacket &packet, char max_qos) | 34 | void Session::writePacket(const MqttPacket &packet, char max_qos) |
| 33 | { | 35 | { |
| 34 | - const char qos = std::min<char>(packet.getQos(), max_qos); | ||
| 35 | - | ||
| 36 | - if (qos == 0) | ||
| 37 | - { | ||
| 38 | - if (!clientDisconnected()) | ||
| 39 | - { | ||
| 40 | - Client_p c = makeSharedClient(); | ||
| 41 | - c->writeMqttPacketAndBlameThisClient(packet); | ||
| 42 | - } | ||
| 43 | - } | ||
| 44 | - else if (qos == 1) | 36 | + if (thread->authPlugin.aclCheck(client_id, username, packet.getTopic(), AclAccess::read) == AuthResult::success) |
| 45 | { | 37 | { |
| 46 | - std::shared_ptr<MqttPacket> copyPacket = packet.getCopy(); | ||
| 47 | - std::unique_lock<std::mutex> locker(qosQueueMutex); | ||
| 48 | - if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | 38 | + const char qos = std::min<char>(packet.getQos(), max_qos); |
| 39 | + | ||
| 40 | + if (qos == 0) | ||
| 49 | { | 41 | { |
| 50 | - logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); | ||
| 51 | - return; | 42 | + if (!clientDisconnected()) |
| 43 | + { | ||
| 44 | + Client_p c = makeSharedClient(); | ||
| 45 | + c->writeMqttPacketAndBlameThisClient(packet); | ||
| 46 | + } | ||
| 52 | } | 47 | } |
| 53 | - const uint16_t pid = nextPacketId++; | ||
| 54 | - copyPacket->setPacketId(pid); | ||
| 55 | - QueuedQosPacket p; | ||
| 56 | - p.packet = copyPacket; | ||
| 57 | - p.id = pid; | ||
| 58 | - qosPacketQueue.push_back(p); | ||
| 59 | - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | ||
| 60 | - locker.unlock(); | ||
| 61 | - | ||
| 62 | - if (!clientDisconnected()) | 48 | + else if (qos == 1) |
| 63 | { | 49 | { |
| 64 | - Client_p c = makeSharedClient(); | ||
| 65 | - c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); | ||
| 66 | - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | 50 | + std::shared_ptr<MqttPacket> copyPacket = packet.getCopy(); |
| 51 | + std::unique_lock<std::mutex> locker(qosQueueMutex); | ||
| 52 | + if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | ||
| 53 | + { | ||
| 54 | + logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); | ||
| 55 | + return; | ||
| 56 | + } | ||
| 57 | + const uint16_t pid = nextPacketId++; | ||
| 58 | + copyPacket->setPacketId(pid); | ||
| 59 | + QueuedQosPacket p; | ||
| 60 | + p.packet = copyPacket; | ||
| 61 | + p.id = pid; | ||
| 62 | + qosPacketQueue.push_back(p); | ||
| 63 | + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | ||
| 64 | + locker.unlock(); | ||
| 65 | + | ||
| 66 | + if (!clientDisconnected()) | ||
| 67 | + { | ||
| 68 | + Client_p c = makeSharedClient(); | ||
| 69 | + c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); | ||
| 70 | + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | ||
| 71 | + } | ||
| 67 | } | 72 | } |
| 68 | } | 73 | } |
| 69 | } | 74 | } |
session.h
| @@ -24,7 +24,9 @@ struct QueuedQosPacket | @@ -24,7 +24,9 @@ struct QueuedQosPacket | ||
| 24 | class Session | 24 | class Session |
| 25 | { | 25 | { |
| 26 | std::weak_ptr<Client> client; | 26 | std::weak_ptr<Client> client; |
| 27 | + ThreadData_p thread; | ||
| 27 | std::string client_id; | 28 | std::string client_id; |
| 29 | + std::string username; | ||
| 28 | std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] | 30 | std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] |
| 29 | std::mutex qosQueueMutex; | 31 | std::mutex qosQueueMutex; |
| 30 | uint16_t nextPacketId = 0; | 32 | uint16_t nextPacketId = 0; |