From 313b3346c39903b76a835d8a211c0679191bf26b Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Mon, 7 Mar 2022 21:05:58 +0100 Subject: [PATCH] Connect/connack in mqtt5 --- CMakeLists.txt | 2 ++ forward_declarations.h | 1 + mqtt5properties.cpp | 133 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ mqtt5properties.h | 35 +++++++++++++++++++++++++++++++++++ mqttpacket.cpp | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++------------ mqttpacket.h | 3 +++ types.cpp | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++----- types.h | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- variablebyteint.cpp | 5 +++++ variablebyteint.h | 1 + 10 files changed, 352 insertions(+), 20 deletions(-) create mode 100644 mqtt5properties.cpp create mode 100644 mqtt5properties.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 6924f56..5a171c2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,7 @@ add_executable(FlashMQ threadloop.h publishcopyfactory.h variablebyteint.h + mqtt5properties.h mainapp.cpp main.cpp @@ -99,6 +100,7 @@ add_executable(FlashMQ threadloop.cpp publishcopyfactory.cpp variablebyteint.cpp + mqtt5properties.cpp ) diff --git a/forward_declarations.h b/forward_declarations.h index 755dc38..fc9813a 100644 --- a/forward_declarations.h +++ b/forward_declarations.h @@ -26,6 +26,7 @@ class MqttPacket; class SubscriptionStore; class Session; class Settings; +class Mqtt5PropertyBuilder; #endif // FORWARD_DECLARATIONS_H diff --git a/mqtt5properties.cpp b/mqtt5properties.cpp new file mode 100644 index 0000000..89d8d4b --- /dev/null +++ b/mqtt5properties.cpp @@ -0,0 +1,133 @@ +#include "mqtt5properties.h" + +#include "cstring" + +#include "exceptions.h" + +Mqtt5PropertyBuilder::Mqtt5PropertyBuilder() +{ + bites.reserve(128); +} + +size_t Mqtt5PropertyBuilder::getLength() const +{ + return length.getLen() + bites.size(); +} + +const VariableByteInt &Mqtt5PropertyBuilder::getVarInt() const +{ + return length; +} + +const std::vector &Mqtt5PropertyBuilder::getBites() const +{ + return bites; +} + +void Mqtt5PropertyBuilder::writeSessionExpiry(uint32_t val) +{ + writeUint32(Mqtt5Properties::SessionExpiryInterval, val); +} + +void Mqtt5PropertyBuilder::writeReceiveMax(uint16_t val) +{ + writeUint16(Mqtt5Properties::ReceiveMaximum, val); +} + +void Mqtt5PropertyBuilder::writeRetainAvailable(uint8_t val) +{ + writeUint8(Mqtt5Properties::RetainAvailable, val); +} + +void Mqtt5PropertyBuilder::writeMaxPacketSize(uint32_t val) +{ + writeUint32(Mqtt5Properties::MaximumPacketSize, val); +} + +void Mqtt5PropertyBuilder::writeAssignedClientId(const std::string &clientid) +{ + writeStr(Mqtt5Properties::AssignedClientIdentifier, clientid); +} + +void Mqtt5PropertyBuilder::writeMaxTopicAliases(uint16_t val) +{ + writeUint16(Mqtt5Properties::TopicAliasMaximum, val); +} + +void Mqtt5PropertyBuilder::writeWildcardSubscriptionAvailable(uint8_t val) +{ + writeUint8(Mqtt5Properties::WildcardSubscriptionAvailable, val); +} + +void Mqtt5PropertyBuilder::writeSharedSubscriptionAvailable(uint8_t val) +{ + writeUint8(Mqtt5Properties::SharedSubscriptionAvailable, val); +} + +void Mqtt5PropertyBuilder::writeUint32(Mqtt5Properties prop, const uint32_t x) +{ + size_t pos = bites.size(); + const size_t newSize = pos + 5; + bites.resize(newSize); + this->length = newSize; + + const uint8_t a = static_cast(x >> 24); + const uint8_t b = static_cast(x >> 16); + const uint8_t c = static_cast(x >> 8); + const uint8_t d = static_cast(x); + + bites[pos++] = static_cast(prop); + bites[pos++] = a; + bites[pos++] = b; + bites[pos++] = c; + bites[pos] = d; +} + +void Mqtt5PropertyBuilder::writeUint16(Mqtt5Properties prop, const uint16_t x) +{ + size_t pos = bites.size(); + const size_t newSize = pos + 3; + bites.resize(newSize); + this->length = newSize; + + const uint8_t a = static_cast(x >> 8); + const uint8_t b = static_cast(x); + + bites[pos++] = static_cast(prop); + bites[pos++] = a; + bites[pos] = b; +} + +void Mqtt5PropertyBuilder::writeUint8(Mqtt5Properties prop, const uint8_t x) +{ + size_t pos = bites.size(); + const size_t newSize = pos + 2; + bites.resize(newSize); + this->length = newSize; + + bites[pos++] = static_cast(prop); + bites[pos] = x; +} + +void Mqtt5PropertyBuilder::writeStr(Mqtt5Properties prop, const std::string &str) +{ + if (str.length() > 65535) + throw ProtocolError("String too long."); + + const uint16_t strlen = str.length(); + + size_t pos = bites.size(); + const size_t newSize = pos + strlen + 2; + bites.resize(newSize); + this->length = newSize; + + const uint8_t a = static_cast(strlen >> 8); + const uint8_t b = static_cast(strlen); + + bites[pos++] = static_cast(prop); + bites[pos++] = a; + bites[pos++] = b; + + std::memcpy(&bites[pos], str.c_str(), strlen); +} + diff --git a/mqtt5properties.h b/mqtt5properties.h new file mode 100644 index 0000000..2650847 --- /dev/null +++ b/mqtt5properties.h @@ -0,0 +1,35 @@ +#ifndef MQTT5PROPERTIES_H +#define MQTT5PROPERTIES_H + +#include + +#include "types.h" +#include "variablebyteint.h" + +class Mqtt5PropertyBuilder +{ + std::vector bites; + VariableByteInt length; + + void writeUint32(Mqtt5Properties prop, const uint32_t x); + void writeUint16(Mqtt5Properties prop, const uint16_t x); + void writeUint8(Mqtt5Properties prop, const uint8_t x); + void writeStr(Mqtt5Properties prop, const std::string &str); +public: + Mqtt5PropertyBuilder(); + + size_t getLength() const; + const VariableByteInt &getVarInt() const; + const std::vector &getBites() const; + + void writeSessionExpiry(uint32_t val); + void writeReceiveMax(uint16_t val); + void writeRetainAvailable(uint8_t val); + void writeMaxPacketSize(uint32_t val); + void writeAssignedClientId(const std::string &clientid); + void writeMaxTopicAliases(uint16_t val); + void writeWildcardSubscriptionAvailable(uint8_t val); + void writeSharedSubscriptionAvailable(uint8_t val); +}; + +#endif // MQTT5PROPERTIES_H diff --git a/mqttpacket.cpp b/mqttpacket.cpp index abf8939..87cdf4a 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -110,17 +110,20 @@ std::shared_ptr MqttPacket::getCopy(char new_max_qos) const return copyPacket; } -// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. MqttPacket::MqttPacket(const ConnAck &connAck) : - bites(connAck.getLengthWithoutFixedHeader() + 2) + bites(connAck.getLengthWithoutFixedHeader()) { - fixed_header_length = 2; packetType = PacketType::CONNACK; first_byte = static_cast(packetType) << 4; - writeByte(first_byte); - writeByte(2); // length is always 2. writeByte(connAck.session_present & 0b00000001); // all connect-ack flags are 0, except session-present. [MQTT-3.2.2.1] - writeByte(static_cast(connAck.return_code)); + writeByte(connAck.return_code); + + if (connAck.protocol_version >= ProtocolVersion::Mqtt5) + { + writeProperties(connAck.propertyBuilder); + } + + calculateRemainingLength(); } MqttPacket::MqttPacket(const SubAck &subAck) : @@ -343,7 +346,10 @@ void MqttPacket::handleConnect() } else { - ConnAck connAck(ConnAckReturnCodes::UnacceptableProtocolVersion); + // The specs are unclear when to use the version 3 codes or version 5 codes. + ProtocolVersion fuzzyProtocolVersion = protocol_level < 0x05 ? ProtocolVersion::Mqtt31 : ProtocolVersion::Mqtt5; + + ConnAck connAck(fuzzyProtocolVersion, ReasonCodes::UnsupportedProtocolVersion); MqttPacket response(connAck); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); @@ -398,7 +404,7 @@ void MqttPacket::handleConnect() max_packet_size = std::min(readFourBytesToUint32(), max_packet_size); break; case Mqtt5Properties::TopicAliasMaximum: - max_topic_aliases = readTwoBytesToUInt16(); + max_topic_aliases = std::min(readTwoBytesToUInt16(), max_topic_aliases); break; case Mqtt5Properties::RequestResponseInformation: request_response_information = !!readByte(); @@ -451,7 +457,7 @@ void MqttPacket::handleConnect() // The specs don't really say what to do when client id not UTF8, so including here. if (!isValidUtf8(client_id) || !isValidUtf8(username) || !isValidUtf8(password) || !isValidUtf8(will_topic)) { - ConnAck connAck(ConnAckReturnCodes::MalformedUsernameOrPassword); + ConnAck connAck(protocolVersion, ReasonCodes::BadUserNameOrPassword); MqttPacket response(connAck); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); @@ -480,7 +486,7 @@ void MqttPacket::handleConnect() if (!validClientId) { - ConnAck connAck(ConnAckReturnCodes::ClientIdRejected); + ConnAck connAck(protocolVersion, ReasonCodes::ClientIdentifierNotValid); MqttPacket response(connAck); sender->setDisconnectReason("Invalid clientID"); sender->setReadyForDisconnect(); @@ -488,9 +494,11 @@ void MqttPacket::handleConnect() return; } + bool clientIdGenerated = false; if (client_id.empty()) { client_id = getSecureRandomString(23); + clientIdGenerated = true; } sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, max_packet_size, max_topic_aliases); @@ -521,7 +529,22 @@ void MqttPacket::handleConnect() bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_start && subscriptionStore->sessionPresent(client_id); sender->setAuthenticated(true); - ConnAck connAck(ConnAckReturnCodes::Accepted, sessionPresent); + ConnAck connAck(protocolVersion, ReasonCodes::Success, sessionPresent); + + if (protocolVersion >= ProtocolVersion::Mqtt5) + { + connAck.propertyBuilder = std::make_shared(); + connAck.propertyBuilder->writeSessionExpiry(session_expire); + connAck.propertyBuilder->writeReceiveMax(max_qos_packets); + connAck.propertyBuilder->writeRetainAvailable(1); + connAck.propertyBuilder->writeMaxPacketSize(max_packet_size); + if (clientIdGenerated) + connAck.propertyBuilder->writeAssignedClientId(client_id); + connAck.propertyBuilder->writeMaxTopicAliases(max_topic_aliases); + connAck.propertyBuilder->writeWildcardSubscriptionAvailable(1); + connAck.propertyBuilder->writeSharedSubscriptionAvailable(0); + } + MqttPacket response(connAck); sender->writeMqttPacket(response); logger->logf(LOG_NOTICE, "Client '%s' logged in successfully", sender->repr().c_str()); @@ -530,7 +553,7 @@ void MqttPacket::handleConnect() } else { - ConnAck connDeny(ConnAckReturnCodes::NotAuthorized, false); + ConnAck connDeny(protocolVersion, ReasonCodes::NotAuthorized, false); MqttPacket response(connDeny); sender->setDisconnectReason("Access denied"); sender->setReadyForDisconnect(); @@ -960,6 +983,23 @@ void MqttPacket::writeBytes(const char *b, size_t len) pos += len; } +void MqttPacket::writeProperties(const std::shared_ptr &properties) +{ + if (!properties) + writeByte(0); + else + { + writeVariableByteInt(properties->getVarInt()); + const std::vector &b = properties->getBites(); + writeBytes(b.data(), b.size()); + } +} + +void MqttPacket::writeVariableByteInt(const VariableByteInt &v) +{ + writeBytes(v.data(), v.getLen()); +} + uint16_t MqttPacket::readTwoBytesToUInt16() { if (pos + 2 > bites.size()) diff --git a/mqttpacket.h b/mqttpacket.h index 283cc50..ca2fe2b 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -34,6 +34,7 @@ License along with FlashMQ. If not, see . #include "mainapp.h" #include "variablebyteint.h" +#include "mqtt5properties.h" class MqttPacket { @@ -62,6 +63,8 @@ class MqttPacket void writeByte(char b); void writeUint16(uint16_t x); void writeBytes(const char *b, size_t len); + void writeProperties(const std::shared_ptr &properties); + void writeVariableByteInt(const VariableByteInt &v); uint16_t readTwoBytesToUInt16(); uint32_t readFourBytesToUint32(); size_t remainingAfterPos(); diff --git a/types.cpp b/types.cpp index 8907b3d..6c1c03d 100644 --- a/types.cpp +++ b/types.cpp @@ -18,14 +18,66 @@ License along with FlashMQ. If not, see . #include "cassert" #include "types.h" +#include "mqtt5properties.h" -ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) : - return_code(return_code), +ConnAck::ConnAck(const ProtocolVersion protVersion, ReasonCodes return_code, bool session_present) : + protocol_version(protVersion), session_present(session_present) { - // [MQTT-3.2.2-4] - if (return_code > ConnAckReturnCodes::Accepted) - session_present = false; + + if (this->protocol_version <= ProtocolVersion::Mqtt311) + { + ConnAckReturnCodes mqtt3_return = ConnAckReturnCodes::Accepted; + + switch (return_code) + { + case ReasonCodes::Success: + mqtt3_return = ConnAckReturnCodes::Accepted; + break; + case ReasonCodes::UnsupportedProtocolVersion: + mqtt3_return = ConnAckReturnCodes::UnacceptableProtocolVersion; + break; + case ReasonCodes::ClientIdentifierNotValid: + mqtt3_return = ConnAckReturnCodes::ClientIdRejected; + break; + case ReasonCodes::ServerUnavailable: + mqtt3_return = ConnAckReturnCodes::ServerUnavailable; + break; + case ReasonCodes::BadUserNameOrPassword: + mqtt3_return = ConnAckReturnCodes::MalformedUsernameOrPassword; + break; + case ReasonCodes::NotAuthorized: + mqtt3_return = ConnAckReturnCodes::NotAuthorized; + default: + assert(false); + } + + // [MQTT-3.2.2-4] + if (mqtt3_return > ConnAckReturnCodes::Accepted) + session_present = false; + + this->return_code = static_cast(mqtt3_return); + } + else + { + this->return_code = static_cast(return_code); + + // MQTT-3.2.2-6 + if (this->return_code > 0) + session_present = false; + } +} + +size_t ConnAck::getLengthWithoutFixedHeader() const +{ + size_t result = 2; + + if (this->protocol_version >= ProtocolVersion::Mqtt5) + { + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; + result += proplen; + } + return result; } SubAck::SubAck(uint16_t packet_id, const std::list &subs_qos_reponses) : diff --git a/types.h b/types.h index 05aa950..e883e2e 100644 --- a/types.h +++ b/types.h @@ -21,6 +21,9 @@ License along with FlashMQ. If not, see . #include "stdint.h" #include #include +#include + +#include "forward_declarations.h" enum class PacketType { @@ -83,6 +86,9 @@ enum class Mqtt5Properties SharedSubscriptionAvailable = 42 }; +/** + * @brief The ConnAckReturnCodes enum are for MQTT3 + */ enum class ConnAckReturnCodes { Accepted = 0, @@ -93,13 +99,67 @@ enum class ConnAckReturnCodes NotAuthorized = 5 }; +/** + * @brief The ReasonCodes enum are for MQTT5. + */ +enum class ReasonCodes +{ + Success = 0, + GrantedQoS0 = 0, + GrantedQoS1 = 1, + GrantedQoS2 = 2, + DisconnectWithWill = 4, + NoMatchingSubscribers = 16, + NoSubscriptionExisted = 17, + ContinueAuthentication = 24, + ReAuthenticate = 25, + UnspecifiedError = 128, + MalformedPacket = 129, + ProtocolError = 130, + ImplementationSpecificError = 131, + UnsupportedProtocolVersion = 132, + ClientIdentifierNotValid = 133, + BadUserNameOrPassword = 134, + NotAuthorized = 135, + ServerUnavailable = 136, + ServerBusy = 137, + Banned = 138, + ServerShuttingDown = 139, + BadAuthenticationMethod = 140, + KeepAliveTimeout = 141, + SessionTakenOver = 142, + TopicFilterInvalid = 143, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + ReceiveMaximumExceeded = 147, + TopicAliasInvalid = 148, + PacketTooLarge = 149, + MessageRateTooHigh = 150, + QuoteExceeded = 151, + AdministrativeAction = 152, + PayloadFormatInvalid = 153, + RetainNotSupported = 154, + QosNotSupported = 155, + UseAnotherServer = 156, + ServerMoved = 157, + SharedSubscriptionsNotSupported = 158, + ConnectionRateExceeded = 159, + MaximumConnectTime = 160, + SubscriptionIdentifiersNotSupported = 161, + WildcardSubscriptionsNotSupported = 162 +}; + class ConnAck { public: - ConnAck(ConnAckReturnCodes return_code, bool session_present=false); - ConnAckReturnCodes return_code; + ConnAck(const ProtocolVersion protVersion, ReasonCodes return_code, bool session_present=false); + + const ProtocolVersion protocol_version; + uint8_t return_code; bool session_present = false; - size_t getLengthWithoutFixedHeader() const { return 2;} // size of connack is always the same + std::shared_ptr propertyBuilder; + + size_t getLengthWithoutFixedHeader() const; }; enum class SubAckReturnCodes diff --git a/variablebyteint.cpp b/variablebyteint.cpp index b126a07..e66683f 100644 --- a/variablebyteint.cpp +++ b/variablebyteint.cpp @@ -34,3 +34,8 @@ uint8_t VariableByteInt::getLen() const { return len; } + +const char *VariableByteInt::data() const +{ + return &bytes[0]; +} diff --git a/variablebyteint.h b/variablebyteint.h index 3322262..445136b 100644 --- a/variablebyteint.h +++ b/variablebyteint.h @@ -12,6 +12,7 @@ public: void readIntoBuf(CirBuf &buf) const; VariableByteInt &operator=(uint32_t x); uint8_t getLen() const; + const char *data() const; }; #endif // VARIABLEBYTEINT_H -- libgit2 0.21.4