From ed1479935ac2866fa015b5e19744bba76724af13 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Mon, 24 May 2021 19:33:11 +0200 Subject: [PATCH] Add support for QoS 2 --- mqttpacket.cpp | 138 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------- mqttpacket.h | 9 +++++++++ session.cpp | 51 +++++++++++++++++++++++++++++++++++++++++++++++++-- session.h | 11 +++++++++++ types.cpp | 33 +++++++++++++++++++++++++++++++++ types.h | 24 ++++++++++++++++++++++++ 6 files changed, 243 insertions(+), 23 deletions(-) diff --git a/mqttpacket.cpp b/mqttpacket.cpp index d266dd9..4ec5ad3 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -132,20 +132,50 @@ MqttPacket::MqttPacket(const Publish &publish) : calculateRemainingLength(); } -// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. -MqttPacket::MqttPacket(const PubAck &pubAck) : - bites(pubAck.getLengthWithoutFixedHeader() + 2) +/** + * @brief MqttPacket::pubCommonConstruct is common code for constructors for all those empty pub control packets (ack, rec(eived), rel(ease), comp(lete)). + * @param packet_id + * + * This functions cheats a bit and doesn't use calculateRemainingLength, because it's always 2. Be sure to allocate enough room in the vector when + * you use this function (add 2 to the length). + */ +void MqttPacket::pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits) { - fixed_header_length = 2; // This is the cheat part mentioned above. We're not calculating it dynamically. - packetType = PacketType::PUBACK; - first_byte = static_cast(packetType) << 4; + assert(firstByteDefaultBits <= 0xF); + + fixed_header_length = 2; + first_byte = (static_cast(packetType) << 4) | firstByteDefaultBits; writeByte(first_byte); writeByte(2); // length is always 2. - char topicLenMSB = (pubAck.packet_id & 0xFF00) >> 8; - char topicLenLSB = (pubAck.packet_id & 0x00FF); + uint8_t packetIdMSB = (packet_id & 0xFF00) >> 8; + uint8_t packetIdLSB = (packet_id & 0x00FF); packet_id_pos = pos; - writeByte(topicLenMSB); - writeByte(topicLenLSB); + writeByte(packetIdMSB); + writeByte(packetIdLSB); +} + +MqttPacket::MqttPacket(const PubAck &pubAck) : + bites(pubAck.getLengthWithoutFixedHeader() + 2) +{ + pubCommonConstruct(pubAck.packet_id, PacketType::PUBACK); +} + +MqttPacket::MqttPacket(const PubRec &pubRec) : + bites(pubRec.getLengthWithoutFixedHeader() + 2) +{ + pubCommonConstruct(pubRec.packet_id, PacketType::PUBREC); +} + +MqttPacket::MqttPacket(const PubComp &pubComp) : + bites(pubComp.getLengthWithoutFixedHeader() + 2) +{ + pubCommonConstruct(pubComp.packet_id, PacketType::PUBCOMP); +} + +MqttPacket::MqttPacket(const PubRel &pubRel) : + bites(pubRel.getLengthWithoutFixedHeader() + 2) +{ + pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010); } void MqttPacket::handle() @@ -175,6 +205,12 @@ void MqttPacket::handle() handlePublish(); else if (packetType == PacketType::PUBACK) handlePubAck(); + else if (packetType == PacketType::PUBREC) + handlePubRec(); + else if (packetType == PacketType::PUBREL) + handlePubRel(); + else if (packetType == PacketType::PUBCOMP) + handlePubComp(); } void MqttPacket::handleConnect() @@ -437,7 +473,7 @@ void MqttPacket::handlePublish() bool dup = !!(first_byte & 0b00001000); char qos = (first_byte & 0b00000110) >> 1; - if (qos == 3) + if (qos > 2) throw ProtocolError("QoS 3 is a protocol violation."); this->qos = qos; @@ -453,21 +489,37 @@ void MqttPacket::handlePublish() return; } +#ifndef NDEBUG + logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", topic.c_str(), qos, retain, dup); +#endif + if (qos) { - if (qos > 1) - throw ProtocolError("Qos > 1 not implemented."); packet_id_pos = pos; - uint16_t packet_id = readTwoBytesToUInt16(); + packet_id = readTwoBytesToUInt16(); - // Clear the packet ID from this packet, because each new publish must get a new one. It's more of a debug precaution. - pos -= 2; - char zero[2]; zero[0] = 0; zero[1] = 0; - writeBytes(zero, 2); + if (qos == 1) + { + PubAck pubAck(packet_id); + MqttPacket response(pubAck); + sender->writeMqttPacket(response); + } + else + { + PubRec pubRec(packet_id); + MqttPacket response(pubRec); + sender->writeMqttPacket(response); - PubAck pubAck(packet_id); - MqttPacket response(pubAck); - sender->writeMqttPacket(response); + if (sender->getSession()->incomingQoS2MessageIdInTransit(packet_id)) + { + return; + } + else + { + // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish. + sender->getSession()->addIncomingQoS2MessageId(packet_id); + } + } } if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write) == AuthResult::success) @@ -496,6 +548,42 @@ void MqttPacket::handlePubAck() sender->getSession()->clearQosMessage(packet_id); } +/** + * @brief MqttPacket::handlePubRec handles QoS 2 'publish received' packets. The publisher receives these. + */ +void MqttPacket::handlePubRec() +{ + const uint16_t packet_id = readTwoBytesToUInt16(); + sender->getSession()->clearQosMessage(packet_id); + sender->getSession()->addOutgoingQoS2MessageId(packet_id); + + PubRel pubRel(packet_id); + MqttPacket response(pubRel); + sender->writeMqttPacket(response); +} + +/** + * @brief MqttPacket::handlePubRel handles QoS 2 'publish release'. The publisher sends these. + */ +void MqttPacket::handlePubRel() +{ + const uint16_t packet_id = readTwoBytesToUInt16(); + sender->getSession()->removeIncomingQoS2MessageId(packet_id); + + PubComp pubcomp(packet_id); + MqttPacket response(pubcomp); + sender->writeMqttPacket(response); +} + +/** + * @brief MqttPacket::handlePubComp handles QoS 2 'publish complete'. The publisher receives these. + */ +void MqttPacket::handlePubComp() +{ + const uint16_t packet_id = readTwoBytesToUInt16(); + sender->getSession()->removeOutgoingQoS2MessageId(packet_id); +} + void MqttPacket::calculateRemainingLength() { assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. @@ -529,6 +617,8 @@ void MqttPacket::setPacketId(uint16_t packet_id) assert(packetType == PacketType::PUBLISH); assert(qos > 0); + this->packet_id = packet_id; + pos = packet_id_pos; char topicLenMSB = (packet_id & 0xFF00) >> 8; @@ -537,6 +627,12 @@ void MqttPacket::setPacketId(uint16_t packet_id) writeByte(topicLenLSB); } +uint16_t MqttPacket::getPacketId() const +{ + assert(qos > 0); + return packet_id; +} + // If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything? void MqttPacket::setDuplicate() { diff --git a/mqttpacket.h b/mqttpacket.h index 7c98298..98f85a6 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -57,6 +57,7 @@ class MqttPacket char first_byte = 0; size_t pos = 0; size_t packet_id_pos = 0; + uint16_t packet_id = 0; ProtocolVersion protocolVersion = ProtocolVersion::None; Logger *logger = Logger::getInstance(); @@ -68,6 +69,7 @@ class MqttPacket size_t remainingAfterPos(); void calculateRemainingLength(); + void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0); MqttPacket(const MqttPacket &other) = default; public: @@ -84,6 +86,9 @@ public: MqttPacket(const UnsubAck &unsubAck); MqttPacket(const Publish &publish); MqttPacket(const PubAck &pubAck); + MqttPacket(const PubRec &pubRec); + MqttPacket(const PubComp &pubComp); + MqttPacket(const PubRel &pubRel); void handle(); void handleConnect(); @@ -93,6 +98,9 @@ public: void handlePing(); void handlePublish(); void handlePubAck(); + void handlePubRec(); + void handlePubRel(); + void handlePubComp(); size_t getSizeIncludingNonPresentHeader() const; const std::vector &getBites() const { return bites; } @@ -105,6 +113,7 @@ public: char getFirstByte() const; RemainingLength getRemainingLength() const; void setPacketId(uint16_t packet_id); + uint16_t getPacketId() const; void setDuplicate(); size_t getTotalMemoryFootprint(); }; diff --git a/session.cpp b/session.cpp index f5e2e95..7a1442e 100644 --- a/session.cpp +++ b/session.cpp @@ -64,11 +64,13 @@ void Session::writePacket(const MqttPacket &packet, char max_qos) c->writeMqttPacketAndBlameThisClient(packet, qos); } } - else if (qos == 1) + else if (qos > 0) { std::shared_ptr copyPacket = packet.getCopy(); std::unique_lock locker(qosQueueMutex); - if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) + + const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) { logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); return; @@ -145,6 +147,13 @@ void Session::sendPendingQosMessages() c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos()); qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate. } + + for (const uint16_t packet_id : outgoingQoS2MessageIds) + { + PubRel pubRel(packet_id); + MqttPacket packet(pubRel); + c->writeMqttPacketAndBlameThisClient(packet, 2); + } } } @@ -158,3 +167,41 @@ bool Session::hasExpired() { return clientDisconnected() && (lastTouched + EXPIRE_SESSION_AFTER) < time(NULL); } + +void Session::addIncomingQoS2MessageId(uint16_t packet_id) +{ + incomingQoS2MessageIds.insert(packet_id); +} + +bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id) const +{ + const auto it = incomingQoS2MessageIds.find(packet_id); + return it != incomingQoS2MessageIds.end(); +} + +void Session::removeIncomingQoS2MessageId(u_int16_t packet_id) +{ +#ifndef NDEBUG + logger->logf(LOG_DEBUG, "As QoS 2 receiver: publish released (PUBREL) for '%s', packet id '%d'. Left in queue: %d", client_id.c_str(), packet_id, incomingQoS2MessageIds.size()); +#endif + + const auto it = incomingQoS2MessageIds.find(packet_id); + if (it != incomingQoS2MessageIds.end()) + incomingQoS2MessageIds.erase(it); +} + +void Session::addOutgoingQoS2MessageId(uint16_t packet_id) +{ + outgoingQoS2MessageIds.insert(packet_id); +} + +void Session::removeOutgoingQoS2MessageId(u_int16_t packet_id) +{ +#ifndef NDEBUG + logger->logf(LOG_DEBUG, "As QoS 2 sender: publish complete (PUBCOMP) for '%s', packet id '%d'. Left in queue: %d", client_id.c_str(), packet_id, outgoingQoS2MessageIds.size()); +#endif + + const auto it = outgoingQoS2MessageIds.find(packet_id); + if (it != outgoingQoS2MessageIds.end()) + outgoingQoS2MessageIds.erase(it); +} diff --git a/session.h b/session.h index 6bfa1d9..596a0d6 100644 --- a/session.h +++ b/session.h @@ -21,6 +21,7 @@ License along with FlashMQ. If not, see . #include #include #include +#include #include "forward_declarations.h" #include "logger.h" @@ -45,6 +46,8 @@ class Session std::string client_id; std::string username; std::list qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] + std::set incomingQoS2MessageIds; + std::set outgoingQoS2MessageIds; std::mutex qosQueueMutex; uint16_t nextPacketId = 0; ssize_t qosQueueBytes = 0; @@ -66,6 +69,14 @@ public: void sendPendingQosMessages(); void touch(time_t val = 0); bool hasExpired(); + + void addIncomingQoS2MessageId(uint16_t packet_id); + bool incomingQoS2MessageIdInTransit(uint16_t packet_id) const; + void removeIncomingQoS2MessageId(u_int16_t packet_id); + + void addOutgoingQoS2MessageId(uint16_t packet_id); + void removeOutgoingQoS2MessageId(u_int16_t packet_id); + }; #endif // SESSION_H diff --git a/types.cpp b/types.cpp index 720f933..b3bd2a2 100644 --- a/types.cpp +++ b/types.cpp @@ -82,3 +82,36 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const { return 2; } + +PubRec::PubRec(uint16_t packet_id) : + packet_id(packet_id) +{ + +} + +size_t PubRec::getLengthWithoutFixedHeader() const +{ + return 2; +} + +PubComp::PubComp(uint16_t packet_id) : + packet_id(packet_id) +{ + +} + +size_t PubComp::getLengthWithoutFixedHeader() const +{ + return 2; +} + +PubRel::PubRel(uint16_t packet_id) : + packet_id(packet_id) +{ + +} + +size_t PubRel::getLengthWithoutFixedHeader() const +{ + return 2; +} diff --git a/types.h b/types.h index b9af8e5..ad62dbe 100644 --- a/types.h +++ b/types.h @@ -113,4 +113,28 @@ public: size_t getLengthWithoutFixedHeader() const; }; +class PubRec +{ +public: + PubRec(uint16_t packet_id); + uint16_t packet_id; + size_t getLengthWithoutFixedHeader() const; +}; + +class PubComp +{ +public: + PubComp(uint16_t packet_id); + uint16_t packet_id; + size_t getLengthWithoutFixedHeader() const; +}; + +class PubRel +{ +public: + PubRel(uint16_t packet_id); + uint16_t packet_id; + size_t getLengthWithoutFixedHeader() const; +}; + #endif // TYPES_H -- libgit2 0.21.4