diff --git a/authplugin.cpp b/authplugin.cpp index 0d6ed3d..65a6a8e 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -278,7 +278,7 @@ void Authentication::securityCleanup(bool reloading) } AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, - AclAccess access, char qos, bool retain) + AclAccess access, char qos, bool retain, const std::vector> *userProperties) { assert(subtopics.size() > 0); @@ -322,7 +322,7 @@ AuthResult Authentication::aclCheck(const std::string &clientid, const std::stri // gets disconnected. try { - FlashMQMessage msg(topic, subtopics, qos, retain); + FlashMQMessage msg(topic, subtopics, qos, retain, userProperties); return flashmq_auth_plugin_acl_check_v1(pluginData, access, clientid, username, msg); } catch (std::exception &ex) @@ -335,7 +335,8 @@ AuthResult Authentication::aclCheck(const std::string &clientid, const std::stri return AuthResult::error; } -AuthResult Authentication::unPwdCheck(const std::string &username, const std::string &password) +AuthResult Authentication::unPwdCheck(const std::string &username, const std::string &password, + const std::vector> *userProperties) { AuthResult firstResult = unPwdCheckFromMosquittoPasswordFile(username, password); @@ -373,7 +374,7 @@ AuthResult Authentication::unPwdCheck(const std::string &username, const std::st // gets disconnected. try { - return flashmq_auth_plugin_login_check_v1(pluginData, username, password); + return flashmq_auth_plugin_login_check_v1(pluginData, username, password, userProperties); } catch (std::exception &ex) { diff --git a/authplugin.h b/authplugin.h index a7b0da0..2a0b63a 100644 --- a/authplugin.h +++ b/authplugin.h @@ -65,7 +65,8 @@ typedef void(*F_flashmq_auth_plugin_deallocate_thread_memory_v1)(void *thread_da typedef void(*F_flashmq_auth_plugin_init_v1)(void *thread_data, std::unordered_map &auth_opts, bool reloading); typedef void(*F_flashmq_auth_plugin_deinit_v1)(void *thread_data, std::unordered_map &auth_opts, bool reloading); typedef AuthResult(*F_flashmq_auth_plugin_acl_check_v1)(void *thread_data, AclAccess access, const std::string &clientid, const std::string &username, const FlashMQMessage &msg); -typedef AuthResult(*F_flashmq_auth_plugin_login_check_v1)(void *thread_data, const std::string &username, const std::string &password); +typedef AuthResult(*F_flashmq_auth_plugin_login_check_v1)(void *thread_data, const std::string &username, const std::string &password, + const std::vector> *userProperties); typedef void (*F_flashmq_auth_plugin_periodic_event)(void *thread_data); extern "C" @@ -152,8 +153,9 @@ public: void securityInit(bool reloading); void securityCleanup(bool reloading); AuthResult aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, - AclAccess access, char qos, bool retain); - AuthResult unPwdCheck(const std::string &username, const std::string &password); + AclAccess access, char qos, bool retain, const std::vector> *userProperties); + AuthResult unPwdCheck(const std::string &username, const std::string &password, + const std::vector> *userProperties); void setQuitting(); void loadMosquittoPasswordFile(); diff --git a/flashmq_plugin.cpp b/flashmq_plugin.cpp index 51c23a0..b7a0bc8 100644 --- a/flashmq_plugin.cpp +++ b/flashmq_plugin.cpp @@ -12,9 +12,11 @@ void flashmq_logf(int level, const char *str, ...) va_end(valist); } -FlashMQMessage::FlashMQMessage(const std::string &topic, const std::vector &subtopics, const char qos, const bool retain) : +FlashMQMessage::FlashMQMessage(const std::string &topic, const std::vector &subtopics, const char qos, const bool retain, + const std::vector> *userProperties) : topic(topic), subtopics(subtopics), + userProperties(userProperties), qos(qos), retain(retain) { diff --git a/flashmq_plugin.h b/flashmq_plugin.h index 51c2bc1..4d17ab0 100644 --- a/flashmq_plugin.h +++ b/flashmq_plugin.h @@ -16,6 +16,7 @@ #include #include #include +#include #define FLASHMQ_PLUGIN_VERSION 1 @@ -72,10 +73,12 @@ struct FlashMQMessage { const std::string &topic; const std::vector &subtopics; + const std::vector> *userProperties; const char qos; const bool retain; - FlashMQMessage(const std::string &topic, const std::vector &subtopics, const char qos, const bool retain); + FlashMQMessage(const std::string &topic, const std::vector &subtopics, const char qos, const bool retain, + const std::vector> *userProperties); }; /** @@ -177,7 +180,8 @@ void flashmq_auth_plugin_periodic_event(void *thread_data); * Note that there is a setting 'auth_plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not * thread-safe. It will negate much of FlashMQ's multi-core model. */ -AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string &username, const std::string &password); +AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string &username, const std::string &password, + const std::vector> *userProperties); /** * @brief flashmq_auth_plugin_acl_check is called on publish, deliver and subscribe. diff --git a/mqtt5properties.cpp b/mqtt5properties.cpp index 4dc6c75..d66693d 100644 --- a/mqtt5properties.cpp +++ b/mqtt5properties.cpp @@ -2,6 +2,7 @@ #include "cstring" #include "vector" +#include "cassert" #include "exceptions.h" @@ -38,6 +39,11 @@ void Mqtt5PropertyBuilder::clearClientSpecificBytes() clientSpecificBytes.clear(); } +std::shared_ptr>> Mqtt5PropertyBuilder::getUserProperties() const +{ + return this->userProperties; +} + void Mqtt5PropertyBuilder::writeSessionExpiry(uint32_t val) { writeUint32(Mqtt5Properties::SessionExpiryInterval, val, genericBytes); @@ -103,9 +109,15 @@ void Mqtt5PropertyBuilder::writeResponseTopic(const std::string &str) writeStr(Mqtt5Properties::ResponseTopic, str); } -void Mqtt5PropertyBuilder::writeUserProperty(const std::string &key, const std::string &value) +void Mqtt5PropertyBuilder::writeUserProperty(std::string &&key, std::string &&value) { write2Str(Mqtt5Properties::UserProperty, key, value); + + if (!this->userProperties) + this->userProperties = std::make_shared>>(); + + std::pair pair(std::move(key), std::move(value)); + this->userProperties->push_back(std::move(pair)); } void Mqtt5PropertyBuilder::writeCorrelationData(const std::string &correlationData) @@ -113,6 +125,15 @@ void Mqtt5PropertyBuilder::writeCorrelationData(const std::string &correlationDa writeStr(Mqtt5Properties::CorrelationData, correlationData); } +void Mqtt5PropertyBuilder::setNewUserProperties(const std::shared_ptr>> &userProperties) +{ + assert(!this->userProperties); + assert(this->genericBytes.empty()); + assert(this->clientSpecificBytes.empty()); + + this->userProperties = userProperties; +} + void Mqtt5PropertyBuilder::writeUint32(Mqtt5Properties prop, const uint32_t x, std::vector &target) { size_t pos = target.size(); diff --git a/mqtt5properties.h b/mqtt5properties.h index 2d9ce1a..bdc3cc4 100644 --- a/mqtt5properties.h +++ b/mqtt5properties.h @@ -10,6 +10,7 @@ class Mqtt5PropertyBuilder { std::vector genericBytes; std::vector clientSpecificBytes; // only relevant for publishes + std::shared_ptr>> userProperties; VariableByteInt length; void writeUint32(Mqtt5Properties prop, const uint32_t x, std::vector &target); @@ -25,6 +26,7 @@ public: const std::vector &getGenericBytes() const; const std::vector &getclientSpecificBytes() const; void clearClientSpecificBytes(); + std::shared_ptr>> getUserProperties() const; void writeSessionExpiry(uint32_t val); void writeReceiveMax(uint16_t val); @@ -39,8 +41,9 @@ public: void writePayloadFormatIndicator(uint8_t val); void writeMessageExpiryInterval(uint32_t val); void writeResponseTopic(const std::string &str); - void writeUserProperty(const std::string &key, const std::string &value); + void writeUserProperty(std::string &&key, std::string &&value); void writeCorrelationData(const std::string &correlationData); + void setNewUserProperties(const std::shared_ptr>> &userProperties); }; #endif // MQTT5PROPERTIES_H diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 16dd429..6864d9f 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -132,6 +132,15 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_pu if (protocolVersion >= ProtocolVersion::Mqtt5) { + // Step 1: make certain properties available as objects, because FlashMQ needs access to them for internal logic. + if (_publish.propertyBuilder) + { + this->publishData.constructPropertyBuilder(); + this->publishData.propertyBuilder->setNewUserProperties(_publish.propertyBuilder->getUserProperties()); + } + + // Step 2: this line will make sure the whole byte array containing all properties as flat bytes is present in the 'bites' vector, + // which is sent to the subscribers. writeProperties(_publish.propertyBuilder); } @@ -368,13 +377,8 @@ void MqttPacket::handleConnect() request_problem_information = !!readByte(); break; case Mqtt5Properties::UserProperty: - { - const uint16_t len = readTwoBytesToUInt16(); - readBytes(len); - const uint16_t len2 = readTwoBytesToUInt16(); - readBytes(len2); + readUserProperty(); break; - } case Mqtt5Properties::AuthenticationMethod: { const uint16_t len = readTwoBytesToUInt16(); @@ -407,7 +411,7 @@ void MqttPacket::handleConnect() { if (protocolVersion == ProtocolVersion::Mqtt5) { - willpublish.propertyBuilder = std::make_unique(); + willpublish.constructPropertyBuilder(); const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; @@ -450,15 +454,16 @@ void MqttPacket::handleConnect() { const uint16_t len = readTwoBytesToUInt16(); const std::string correlationData(readBytes(len), len); - publishData.propertyBuilder->writeCorrelationData(correlationData); + willpublish.propertyBuilder->writeCorrelationData(correlationData); break; } case Mqtt5Properties::UserProperty: { - const uint16_t len = readTwoBytesToUInt16(); - readBytes(len); - const uint16_t len2 = readTwoBytesToUInt16(); - readBytes(len2); + const uint16_t lenKey = readTwoBytesToUInt16(); + std::string userPropKey(readBytes(lenKey), lenKey); + const uint16_t lenVal = readTwoBytesToUInt16(); + std::string userPropVal(readBytes(lenVal), lenVal); + willpublish.propertyBuilder->writeUserProperty(std::move(userPropKey), std::move(userPropVal)); break; } default: @@ -554,7 +559,7 @@ void MqttPacket::handleConnect() sender->setDisconnectReason("Invalid username character"); accessGranted = false; } - else if (authentication.unPwdCheck(username, password) == AuthResult::success) + else if (authentication.unPwdCheck(username, password, getUserProperties()) == AuthResult::success) { accessGranted = true; } @@ -644,13 +649,8 @@ void MqttPacket::handleSubscribe() decodeVariableByteIntAtPos(); break; case Mqtt5Properties::UserProperty: - { - const uint16_t len = readTwoBytesToUInt16(); - readBytes(len); - const uint16_t len2 = readTwoBytesToUInt16(); - readBytes(len2); + readUserProperty(); break; - } default: throw ProtocolError("Invalid subscribe property."); } @@ -678,7 +678,7 @@ void MqttPacket::handleSubscribe() std::vector subtopics; splitTopic(topic, subtopics); - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, subtopics, AclAccess::subscribe, qos, false) == AuthResult::success) + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, subtopics, AclAccess::subscribe, qos, false, getUserProperties()) == AuthResult::success) { logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s' QoS %d", sender->repr().c_str(), topic.c_str(), qos); sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, subtopics, qos); @@ -821,7 +821,7 @@ void MqttPacket::handlePublish() const size_t prop_end_at = pos + proplen; if (proplen > 0) - publishData.propertyBuilder = std::make_shared(); + publishData.constructPropertyBuilder(); while (pos < prop_end_at) { @@ -854,10 +854,10 @@ void MqttPacket::handlePublish() case Mqtt5Properties::UserProperty: { const uint16_t lenKey = readTwoBytesToUInt16(); - const std::string userPropKey(readBytes(lenKey), lenKey); + std::string userPropKey(readBytes(lenKey), lenKey); const uint16_t lenVal = readTwoBytesToUInt16(); - const std::string userPropVal(readBytes(lenVal), lenVal); - publishData.propertyBuilder->writeUserProperty(userPropKey, userPropVal); + std::string userPropVal(readBytes(lenVal), lenVal); + publishData.propertyBuilder->writeUserProperty(std::move(userPropKey), std::move(userPropVal)); break; } case Mqtt5Properties::SubscriptionIdentifier: @@ -882,7 +882,7 @@ void MqttPacket::handlePublish() payloadStart = pos; Authentication &authentication = *ThreadGlobals::getAuth(); - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain) == AuthResult::success) + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success) { if (retain) { @@ -1176,6 +1176,26 @@ size_t MqttPacket::decodeVariableByteIntAtPos() return value; } +void MqttPacket::readUserProperty() +{ + this->publishData.constructPropertyBuilder(); + + const uint16_t len = readTwoBytesToUInt16(); + std::string key(readBytes(len), len); + const uint16_t len2 = readTwoBytesToUInt16(); + std::string value(readBytes(len2), len2); + + this->publishData.propertyBuilder->writeUserProperty(std::move(key), std::move(value)); +} + +const std::vector> *MqttPacket::getUserProperties() const +{ + if (this->publishData.propertyBuilder) + return this->publishData.propertyBuilder->getUserProperties().get(); + + return nullptr; +} + bool MqttPacket::getRetain() const { return (first_byte & 0b00000001); diff --git a/mqttpacket.h b/mqttpacket.h index b98b8dd..929c8de 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -67,11 +67,12 @@ class MqttPacket uint32_t readFourBytesToUint32(); size_t remainingAfterPos(); size_t decodeVariableByteIntAtPos(); + void readUserProperty(); void calculateRemainingLength(); void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0); - MqttPacket(const MqttPacket &other) = default; + MqttPacket(const MqttPacket &other) = delete; public: PacketType packetType = PacketType::Reserved; MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr &sender); // Constructor for parsing incoming packets. @@ -123,6 +124,7 @@ public: void setRetain(); const Publish &getPublishData(); bool containsClientSpecificProperties() const; + const std::vector> *getUserProperties() const; }; #endif // MQTTPACKET_H diff --git a/publishcopyfactory.cpp b/publishcopyfactory.cpp index bf3a60a..cc759cc 100644 --- a/publishcopyfactory.cpp +++ b/publishcopyfactory.cpp @@ -124,3 +124,20 @@ std::shared_ptr PublishCopyFactory::getSender() return packet->getSender(); return std::shared_ptr(0); } + +const std::vector > *PublishCopyFactory::getUserProperties() const +{ + if (packet) + { + return packet->getUserProperties(); + } + + assert(publish); + + if (publish->propertyBuilder) + { + return publish->propertyBuilder->getUserProperties().get(); + } + + return nullptr; +} diff --git a/publishcopyfactory.h b/publishcopyfactory.h index 80c775c..b35e3db 100644 --- a/publishcopyfactory.h +++ b/publishcopyfactory.h @@ -36,6 +36,8 @@ public: bool getRetain() const; Publish getNewPublish() const; std::shared_ptr getSender(); + const std::vector> *getUserProperties() const; + }; #endif // PUBLISHCOPYFACTORY_H diff --git a/session.cpp b/session.cpp index 9e88783..db162cf 100644 --- a/session.cpp +++ b/session.cpp @@ -160,7 +160,7 @@ void Session::writePacket(PublishCopyFactory ©Factory, const char max_qos, u Authentication *_auth = ThreadGlobals::getAuth(); assert(_auth); Authentication &auth = *_auth; - if (auth.aclCheck(client_id, username, copyFactory.getTopic(), copyFactory.getSubtopics(), AclAccess::read, effectiveQos, copyFactory.getRetain()) == AuthResult::success) + if (auth.aclCheck(client_id, username, copyFactory.getTopic(), copyFactory.getSubtopics(), AclAccess::read, effectiveQos, copyFactory.getRetain(), copyFactory.getUserProperties()) == AuthResult::success) { std::shared_ptr c = makeSharedClient(); if (effectiveQos == 0) diff --git a/types.cpp b/types.cpp index 0b46019..d6866d1 100644 --- a/types.cpp +++ b/types.cpp @@ -141,6 +141,14 @@ void PublishBase::setClientSpecificProperties() propertyBuilder->writeMessageExpiryInterval(newExpiresAfter.count()); } +void PublishBase::constructPropertyBuilder() +{ + if (this->propertyBuilder) + return; + + this->propertyBuilder = std::make_shared(); +} + Publish::Publish(const Publish &other) : PublishBase(other) { diff --git a/types.h b/types.h index 4cf71ac..503b61a 100644 --- a/types.h +++ b/types.h @@ -212,6 +212,7 @@ public: PublishBase(const std::string &topic, const std::string &payload, char qos); size_t getLengthWithoutFixedHeader() const; void setClientSpecificProperties(); + void constructPropertyBuilder(); }; class Publish : public PublishBase