You need to sign in before continuing.
Commit be2050824d591219683776c1d5b3936253a5237a
1 parent
92f401f5
Implement ACL checks and improved login checks
Showing
7 changed files
with
80 additions
and
43 deletions
authplugin.cpp
| ... | ... | @@ -181,4 +181,18 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string |
| 181 | 181 | return r; |
| 182 | 182 | } |
| 183 | 183 | |
| 184 | +std::string AuthResultToString(AuthResult r) | |
| 185 | +{ | |
| 186 | + { | |
| 187 | + if (r == AuthResult::success) | |
| 188 | + return "success"; | |
| 189 | + if (r == AuthResult::acl_denied) | |
| 190 | + return "ACL denied"; | |
| 191 | + if (r == AuthResult::login_denied) | |
| 192 | + return "login Denied"; | |
| 193 | + if (r == AuthResult::error) | |
| 194 | + return "error in check"; | |
| 195 | + } | |
| 184 | 196 | |
| 197 | + return ""; | |
| 198 | +} | ... | ... |
authplugin.h
client.h
| ... | ... | @@ -93,6 +93,7 @@ public: |
| 93 | 93 | bool hasConnectPacketSeen() { return connectPacketSeen; } |
| 94 | 94 | ThreadData_p getThreadData() { return threadData; } |
| 95 | 95 | std::string &getClientId() { return this->clientid; } |
| 96 | + const std::string &getUsername() const { return this->username; } | |
| 96 | 97 | bool getCleanSession() { return cleanSession; } |
| 97 | 98 | void assignSession(std::shared_ptr<Session> &session); |
| 98 | 99 | std::shared_ptr<Session> getSession(); | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -73,17 +73,19 @@ MqttPacket::MqttPacket(const Publish &publish) : |
| 73 | 73 | throw ProtocolError("Topic path too long."); |
| 74 | 74 | } |
| 75 | 75 | |
| 76 | + this->topic = publish.topic; | |
| 77 | + | |
| 76 | 78 | packetType = PacketType::PUBLISH; |
| 77 | 79 | this->qos = publish.qos; |
| 78 | 80 | first_byte = static_cast<char>(packetType) << 4; |
| 79 | 81 | first_byte |= (publish.qos << 1); |
| 80 | 82 | first_byte |= (static_cast<char>(publish.retain) & 0b00000001); |
| 81 | 83 | |
| 82 | - char topicLenMSB = (publish.topic.length() & 0xFF00) >> 8; | |
| 83 | - char topicLenLSB = publish.topic.length() & 0x00FF; | |
| 84 | + char topicLenMSB = (topic.length() & 0xFF00) >> 8; | |
| 85 | + char topicLenLSB = topic.length() & 0x00FF; | |
| 84 | 86 | writeByte(topicLenMSB); |
| 85 | 87 | writeByte(topicLenLSB); |
| 86 | - writeBytes(publish.topic.c_str(), publish.topic.length()); | |
| 88 | + writeBytes(topic.c_str(), topic.length()); | |
| 87 | 89 | |
| 88 | 90 | if (publish.qos) |
| 89 | 91 | { |
| ... | ... | @@ -357,7 +359,7 @@ void MqttPacket::handlePublish() |
| 357 | 359 | if (qos == 0 && dup) |
| 358 | 360 | throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); |
| 359 | 361 | |
| 360 | - std::string topic(readBytes(variable_header_length), variable_header_length); | |
| 362 | + topic = std::string(readBytes(variable_header_length), variable_header_length); | |
| 361 | 363 | |
| 362 | 364 | if (!isValidUtf8(topic)) |
| 363 | 365 | { |
| ... | ... | @@ -388,20 +390,23 @@ void MqttPacket::handlePublish() |
| 388 | 390 | sender->writeMqttPacket(response); |
| 389 | 391 | } |
| 390 | 392 | |
| 391 | - if (retain) | |
| 393 | + if (sender->getThreadData()->authPlugin.aclCheck(sender->getClientId(), sender->getUsername(), topic, AclAccess::write) == AuthResult::success) | |
| 392 | 394 | { |
| 393 | - size_t payload_length = remainingAfterPos(); | |
| 394 | - std::string payload(readBytes(payload_length), payload_length); | |
| 395 | + if (retain) | |
| 396 | + { | |
| 397 | + size_t payload_length = remainingAfterPos(); | |
| 398 | + std::string payload(readBytes(payload_length), payload_length); | |
| 395 | 399 | |
| 396 | - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos); | |
| 397 | - } | |
| 400 | + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos); | |
| 401 | + } | |
| 398 | 402 | |
| 399 | - // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. | |
| 400 | - // Existing subscribers don't get retain=1. [MQTT-3.3.1-9] | |
| 401 | - bites[0] &= 0b11110110; | |
| 403 | + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. | |
| 404 | + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9] | |
| 405 | + bites[0] &= 0b11110110; | |
| 402 | 406 | |
| 403 | - // For the existing clients, we can just write the same packet back out, with our small alterations. | |
| 404 | - sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this); | |
| 407 | + // For the existing clients, we can just write the same packet back out, with our small alterations. | |
| 408 | + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this); | |
| 409 | + } | |
| 405 | 410 | } |
| 406 | 411 | |
| 407 | 412 | void MqttPacket::handlePubAck() |
| ... | ... | @@ -485,6 +490,11 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const |
| 485 | 490 | return total; |
| 486 | 491 | } |
| 487 | 492 | |
| 493 | +const std::string &MqttPacket::getTopic() const | |
| 494 | +{ | |
| 495 | + return this->topic; | |
| 496 | +} | |
| 497 | + | |
| 488 | 498 | |
| 489 | 499 | Client_p MqttPacket::getSender() const |
| 490 | 500 | { | ... | ... |
mqttpacket.h
| ... | ... | @@ -26,6 +26,7 @@ public: |
| 26 | 26 | |
| 27 | 27 | class MqttPacket |
| 28 | 28 | { |
| 29 | + std::string topic; | |
| 29 | 30 | std::vector<char> bites; |
| 30 | 31 | size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. |
| 31 | 32 | RemainingLength remainingLength; |
| ... | ... | @@ -72,6 +73,7 @@ public: |
| 72 | 73 | size_t getSizeIncludingNonPresentHeader() const; |
| 73 | 74 | const std::vector<char> &getBites() const { return bites; } |
| 74 | 75 | char getQos() const { return qos; } |
| 76 | + const std::string &getTopic() const; | |
| 75 | 77 | Client_p getSender() const; |
| 76 | 78 | void setSender(const Client_p &value); |
| 77 | 79 | bool containsFixedHeader() const; | ... | ... |
session.cpp
| ... | ... | @@ -27,43 +27,48 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) |
| 27 | 27 | { |
| 28 | 28 | this->client = client; |
| 29 | 29 | this->client_id = client->getClientId(); |
| 30 | + this->username = client->getUsername(); | |
| 31 | + this->thread = client->getThreadData(); | |
| 30 | 32 | } |
| 31 | 33 | |
| 32 | 34 | void Session::writePacket(const MqttPacket &packet, char max_qos) |
| 33 | 35 | { |
| 34 | - const char qos = std::min<char>(packet.getQos(), max_qos); | |
| 35 | - | |
| 36 | - if (qos == 0) | |
| 37 | - { | |
| 38 | - if (!clientDisconnected()) | |
| 39 | - { | |
| 40 | - Client_p c = makeSharedClient(); | |
| 41 | - c->writeMqttPacketAndBlameThisClient(packet); | |
| 42 | - } | |
| 43 | - } | |
| 44 | - else if (qos == 1) | |
| 36 | + if (thread->authPlugin.aclCheck(client_id, username, packet.getTopic(), AclAccess::read) == AuthResult::success) | |
| 45 | 37 | { |
| 46 | - std::shared_ptr<MqttPacket> copyPacket = packet.getCopy(); | |
| 47 | - std::unique_lock<std::mutex> locker(qosQueueMutex); | |
| 48 | - if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | |
| 38 | + const char qos = std::min<char>(packet.getQos(), max_qos); | |
| 39 | + | |
| 40 | + if (qos == 0) | |
| 49 | 41 | { |
| 50 | - logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); | |
| 51 | - return; | |
| 42 | + if (!clientDisconnected()) | |
| 43 | + { | |
| 44 | + Client_p c = makeSharedClient(); | |
| 45 | + c->writeMqttPacketAndBlameThisClient(packet); | |
| 46 | + } | |
| 52 | 47 | } |
| 53 | - const uint16_t pid = nextPacketId++; | |
| 54 | - copyPacket->setPacketId(pid); | |
| 55 | - QueuedQosPacket p; | |
| 56 | - p.packet = copyPacket; | |
| 57 | - p.id = pid; | |
| 58 | - qosPacketQueue.push_back(p); | |
| 59 | - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | |
| 60 | - locker.unlock(); | |
| 61 | - | |
| 62 | - if (!clientDisconnected()) | |
| 48 | + else if (qos == 1) | |
| 63 | 49 | { |
| 64 | - Client_p c = makeSharedClient(); | |
| 65 | - c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); | |
| 66 | - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 50 | + std::shared_ptr<MqttPacket> copyPacket = packet.getCopy(); | |
| 51 | + std::unique_lock<std::mutex> locker(qosQueueMutex); | |
| 52 | + if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | |
| 53 | + { | |
| 54 | + logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); | |
| 55 | + return; | |
| 56 | + } | |
| 57 | + const uint16_t pid = nextPacketId++; | |
| 58 | + copyPacket->setPacketId(pid); | |
| 59 | + QueuedQosPacket p; | |
| 60 | + p.packet = copyPacket; | |
| 61 | + p.id = pid; | |
| 62 | + qosPacketQueue.push_back(p); | |
| 63 | + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | |
| 64 | + locker.unlock(); | |
| 65 | + | |
| 66 | + if (!clientDisconnected()) | |
| 67 | + { | |
| 68 | + Client_p c = makeSharedClient(); | |
| 69 | + c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); | |
| 70 | + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 71 | + } | |
| 67 | 72 | } |
| 68 | 73 | } |
| 69 | 74 | } | ... | ... |
session.h
| ... | ... | @@ -24,7 +24,9 @@ struct QueuedQosPacket |
| 24 | 24 | class Session |
| 25 | 25 | { |
| 26 | 26 | std::weak_ptr<Client> client; |
| 27 | + ThreadData_p thread; | |
| 27 | 28 | std::string client_id; |
| 29 | + std::string username; | |
| 28 | 30 | std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] |
| 29 | 31 | std::mutex qosQueueMutex; |
| 30 | 32 | uint16_t nextPacketId = 0; | ... | ... |