Commit be2050824d591219683776c1d5b3936253a5237a

Authored by Wiebe Cazemier
1 parent 92f401f5

Implement ACL checks and improved login checks

authplugin.cpp
@@ -181,4 +181,18 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string @@ -181,4 +181,18 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string
181 return r; 181 return r;
182 } 182 }
183 183
  184 +std::string AuthResultToString(AuthResult r)
  185 +{
  186 + {
  187 + if (r == AuthResult::success)
  188 + return "success";
  189 + if (r == AuthResult::acl_denied)
  190 + return "ACL denied";
  191 + if (r == AuthResult::login_denied)
  192 + return "login Denied";
  193 + if (r == AuthResult::error)
  194 + return "error in check";
  195 + }
184 196
  197 + return "";
  198 +}
authplugin.h
@@ -40,6 +40,9 @@ extern "C" @@ -40,6 +40,9 @@ extern "C"
40 void mosquitto_log_printf(int level, const char *fmt, ...); 40 void mosquitto_log_printf(int level, const char *fmt, ...);
41 } 41 }
42 42
  43 +std::string AuthResultToString(AuthResult r);
  44 +
  45 +
43 class AuthPlugin 46 class AuthPlugin
44 { 47 {
45 F_auth_plugin_version version = nullptr; 48 F_auth_plugin_version version = nullptr;
client.h
@@ -93,6 +93,7 @@ public: @@ -93,6 +93,7 @@ public:
93 bool hasConnectPacketSeen() { return connectPacketSeen; } 93 bool hasConnectPacketSeen() { return connectPacketSeen; }
94 ThreadData_p getThreadData() { return threadData; } 94 ThreadData_p getThreadData() { return threadData; }
95 std::string &getClientId() { return this->clientid; } 95 std::string &getClientId() { return this->clientid; }
  96 + const std::string &getUsername() const { return this->username; }
96 bool getCleanSession() { return cleanSession; } 97 bool getCleanSession() { return cleanSession; }
97 void assignSession(std::shared_ptr<Session> &session); 98 void assignSession(std::shared_ptr<Session> &session);
98 std::shared_ptr<Session> getSession(); 99 std::shared_ptr<Session> getSession();
mqttpacket.cpp
@@ -73,17 +73,19 @@ MqttPacket::MqttPacket(const Publish &amp;publish) : @@ -73,17 +73,19 @@ MqttPacket::MqttPacket(const Publish &amp;publish) :
73 throw ProtocolError("Topic path too long."); 73 throw ProtocolError("Topic path too long.");
74 } 74 }
75 75
  76 + this->topic = publish.topic;
  77 +
76 packetType = PacketType::PUBLISH; 78 packetType = PacketType::PUBLISH;
77 this->qos = publish.qos; 79 this->qos = publish.qos;
78 first_byte = static_cast<char>(packetType) << 4; 80 first_byte = static_cast<char>(packetType) << 4;
79 first_byte |= (publish.qos << 1); 81 first_byte |= (publish.qos << 1);
80 first_byte |= (static_cast<char>(publish.retain) & 0b00000001); 82 first_byte |= (static_cast<char>(publish.retain) & 0b00000001);
81 83
82 - char topicLenMSB = (publish.topic.length() & 0xFF00) >> 8;  
83 - char topicLenLSB = publish.topic.length() & 0x00FF; 84 + char topicLenMSB = (topic.length() & 0xFF00) >> 8;
  85 + char topicLenLSB = topic.length() & 0x00FF;
84 writeByte(topicLenMSB); 86 writeByte(topicLenMSB);
85 writeByte(topicLenLSB); 87 writeByte(topicLenLSB);
86 - writeBytes(publish.topic.c_str(), publish.topic.length()); 88 + writeBytes(topic.c_str(), topic.length());
87 89
88 if (publish.qos) 90 if (publish.qos)
89 { 91 {
@@ -357,7 +359,7 @@ void MqttPacket::handlePublish() @@ -357,7 +359,7 @@ void MqttPacket::handlePublish()
357 if (qos == 0 && dup) 359 if (qos == 0 && dup)
358 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); 360 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.");
359 361
360 - std::string topic(readBytes(variable_header_length), variable_header_length); 362 + topic = std::string(readBytes(variable_header_length), variable_header_length);
361 363
362 if (!isValidUtf8(topic)) 364 if (!isValidUtf8(topic))
363 { 365 {
@@ -388,20 +390,23 @@ void MqttPacket::handlePublish() @@ -388,20 +390,23 @@ void MqttPacket::handlePublish()
388 sender->writeMqttPacket(response); 390 sender->writeMqttPacket(response);
389 } 391 }
390 392
391 - if (retain) 393 + if (sender->getThreadData()->authPlugin.aclCheck(sender->getClientId(), sender->getUsername(), topic, AclAccess::write) == AuthResult::success)
392 { 394 {
393 - size_t payload_length = remainingAfterPos();  
394 - std::string payload(readBytes(payload_length), payload_length); 395 + if (retain)
  396 + {
  397 + size_t payload_length = remainingAfterPos();
  398 + std::string payload(readBytes(payload_length), payload_length);
395 399
396 - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos);  
397 - } 400 + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos);
  401 + }
398 402
399 - // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].  
400 - // Existing subscribers don't get retain=1. [MQTT-3.3.1-9]  
401 - bites[0] &= 0b11110110; 403 + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
  404 + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9]
  405 + bites[0] &= 0b11110110;
402 406
403 - // For the existing clients, we can just write the same packet back out, with our small alterations.  
404 - sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this); 407 + // For the existing clients, we can just write the same packet back out, with our small alterations.
  408 + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this);
  409 + }
405 } 410 }
406 411
407 void MqttPacket::handlePubAck() 412 void MqttPacket::handlePubAck()
@@ -485,6 +490,11 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const @@ -485,6 +490,11 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const
485 return total; 490 return total;
486 } 491 }
487 492
  493 +const std::string &MqttPacket::getTopic() const
  494 +{
  495 + return this->topic;
  496 +}
  497 +
488 498
489 Client_p MqttPacket::getSender() const 499 Client_p MqttPacket::getSender() const
490 { 500 {
mqttpacket.h
@@ -26,6 +26,7 @@ public: @@ -26,6 +26,7 @@ public:
26 26
27 class MqttPacket 27 class MqttPacket
28 { 28 {
  29 + std::string topic;
29 std::vector<char> bites; 30 std::vector<char> bites;
30 size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. 31 size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header.
31 RemainingLength remainingLength; 32 RemainingLength remainingLength;
@@ -72,6 +73,7 @@ public: @@ -72,6 +73,7 @@ public:
72 size_t getSizeIncludingNonPresentHeader() const; 73 size_t getSizeIncludingNonPresentHeader() const;
73 const std::vector<char> &getBites() const { return bites; } 74 const std::vector<char> &getBites() const { return bites; }
74 char getQos() const { return qos; } 75 char getQos() const { return qos; }
  76 + const std::string &getTopic() const;
75 Client_p getSender() const; 77 Client_p getSender() const;
76 void setSender(const Client_p &value); 78 void setSender(const Client_p &value);
77 bool containsFixedHeader() const; 79 bool containsFixedHeader() const;
session.cpp
@@ -27,43 +27,48 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client) @@ -27,43 +27,48 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
27 { 27 {
28 this->client = client; 28 this->client = client;
29 this->client_id = client->getClientId(); 29 this->client_id = client->getClientId();
  30 + this->username = client->getUsername();
  31 + this->thread = client->getThreadData();
30 } 32 }
31 33
32 void Session::writePacket(const MqttPacket &packet, char max_qos) 34 void Session::writePacket(const MqttPacket &packet, char max_qos)
33 { 35 {
34 - const char qos = std::min<char>(packet.getQos(), max_qos);  
35 -  
36 - if (qos == 0)  
37 - {  
38 - if (!clientDisconnected())  
39 - {  
40 - Client_p c = makeSharedClient();  
41 - c->writeMqttPacketAndBlameThisClient(packet);  
42 - }  
43 - }  
44 - else if (qos == 1) 36 + if (thread->authPlugin.aclCheck(client_id, username, packet.getTopic(), AclAccess::read) == AuthResult::success)
45 { 37 {
46 - std::shared_ptr<MqttPacket> copyPacket = packet.getCopy();  
47 - std::unique_lock<std::mutex> locker(qosQueueMutex);  
48 - if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) 38 + const char qos = std::min<char>(packet.getQos(), max_qos);
  39 +
  40 + if (qos == 0)
49 { 41 {
50 - logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str());  
51 - return; 42 + if (!clientDisconnected())
  43 + {
  44 + Client_p c = makeSharedClient();
  45 + c->writeMqttPacketAndBlameThisClient(packet);
  46 + }
52 } 47 }
53 - const uint16_t pid = nextPacketId++;  
54 - copyPacket->setPacketId(pid);  
55 - QueuedQosPacket p;  
56 - p.packet = copyPacket;  
57 - p.id = pid;  
58 - qosPacketQueue.push_back(p);  
59 - qosQueueBytes += copyPacket->getTotalMemoryFootprint();  
60 - locker.unlock();  
61 -  
62 - if (!clientDisconnected()) 48 + else if (qos == 1)
63 { 49 {
64 - Client_p c = makeSharedClient();  
65 - c->writeMqttPacketAndBlameThisClient(*copyPacket.get());  
66 - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. 50 + std::shared_ptr<MqttPacket> copyPacket = packet.getCopy();
  51 + std::unique_lock<std::mutex> locker(qosQueueMutex);
  52 + if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
  53 + {
  54 + logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str());
  55 + return;
  56 + }
  57 + const uint16_t pid = nextPacketId++;
  58 + copyPacket->setPacketId(pid);
  59 + QueuedQosPacket p;
  60 + p.packet = copyPacket;
  61 + p.id = pid;
  62 + qosPacketQueue.push_back(p);
  63 + qosQueueBytes += copyPacket->getTotalMemoryFootprint();
  64 + locker.unlock();
  65 +
  66 + if (!clientDisconnected())
  67 + {
  68 + Client_p c = makeSharedClient();
  69 + c->writeMqttPacketAndBlameThisClient(*copyPacket.get());
  70 + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
  71 + }
67 } 72 }
68 } 73 }
69 } 74 }
session.h
@@ -24,7 +24,9 @@ struct QueuedQosPacket @@ -24,7 +24,9 @@ struct QueuedQosPacket
24 class Session 24 class Session
25 { 25 {
26 std::weak_ptr<Client> client; 26 std::weak_ptr<Client> client;
  27 + ThreadData_p thread;
27 std::string client_id; 28 std::string client_id;
  29 + std::string username;
28 std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] 30 std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6]
29 std::mutex qosQueueMutex; 31 std::mutex qosQueueMutex;
30 uint16_t nextPacketId = 0; 32 uint16_t nextPacketId = 0;