From ee79271a6ccb3d8e9013f7a31d030165b2f4aee2 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Fri, 1 Apr 2022 20:00:13 +0200 Subject: [PATCH] Working on puback with reason codes --- mqttpacket.cpp | 142 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------------------------------------------------- mqttpacket.h | 6 +----- session.cpp | 8 +++++++- types.cpp | 53 +++++++++++++++++++---------------------------------- types.h | 35 +++++++++-------------------------- 5 files changed, 101 insertions(+), 143 deletions(-) diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 722f77c..9d8ec2f 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -153,47 +153,23 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_pu calculateRemainingLength(); } -/** - * @brief MqttPacket::pubCommonConstruct is common code for constructors for all those empty pub control packets (ack, rec(eived), rel(ease), comp(lete)). - * @param packet_id - * - * This functions cheats a bit and doesn't use calculateRemainingLength, because it's always 2. Be sure to allocate enough room in the vector when - * you use this function (add 2 to the length). - */ -void MqttPacket::pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits) +MqttPacket::MqttPacket(const PubResponse &pubAck) : + bites(pubAck.getLengthIncludingFixedHeader()) { - assert(firstByteDefaultBits <= 0xF); + this->protocolVersion = pubAck.protocol_version; fixed_header_length = 2; - first_byte = (static_cast(packetType) << 4) | firstByteDefaultBits; + const uint8_t firstByteDefaultBits = pubAck.packet_type == PacketType::PUBREL ? 0b0010 : 0; + this->first_byte = (static_cast(pubAck.packet_type) << 4) | firstByteDefaultBits; writeByte(first_byte); - writeByte(2); // length is always 2. - packet_id_pos = pos; - writeUint16(packet_id); -} - -MqttPacket::MqttPacket(const PubAck &pubAck) : - bites(pubAck.getLengthWithoutFixedHeader() + 2) -{ - pubCommonConstruct(pubAck.packet_id, PacketType::PUBACK); -} - -MqttPacket::MqttPacket(const PubRec &pubRec) : - bites(pubRec.getLengthWithoutFixedHeader() + 2) -{ - pubCommonConstruct(pubRec.packet_id, PacketType::PUBREC); -} - -MqttPacket::MqttPacket(const PubComp &pubComp) : - bites(pubComp.getLengthWithoutFixedHeader() + 2) -{ - pubCommonConstruct(pubComp.packet_id, PacketType::PUBCOMP); -} + writeByte(pubAck.getRemainingLength()); + this->packet_id_pos = this->pos; + writeUint16(pubAck.packet_id); -MqttPacket::MqttPacket(const PubRel &pubRel) : - bites(pubRel.getLengthWithoutFixedHeader() + 2) -{ - pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010); + if (pubAck.needsReasonCode()) + { + writeByte(static_cast(pubAck.reason_code)); + } } void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender) @@ -767,6 +743,8 @@ void MqttPacket::handlePublish() publishData.topic = std::string(readBytes(variable_header_length), variable_header_length); + ReasonCodes ackCode = ReasonCodes::Success; + if (qos) { packet_id_pos = pos; @@ -776,29 +754,6 @@ void MqttPacket::handlePublish() { throw ProtocolError("Packet ID 0 when publishing is invalid."); // [MQTT-2.3.1-1] } - - if (qos == 1) - { - PubAck pubAck(packet_id); - MqttPacket response(pubAck); - sender->writeMqttPacket(response); - } - else - { - PubRec pubRec(packet_id); - MqttPacket response(pubRec); - sender->writeMqttPacket(response); - - if (sender->getSession()->incomingQoS2MessageIdInTransit(packet_id)) - { - return; - } - else - { - // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish. - sender->getSession()->addIncomingQoS2MessageId(packet_id); - } - } } if (this->protocolVersion >= ProtocolVersion::Mqtt5 ) @@ -881,17 +836,16 @@ void MqttPacket::handlePublish() } } - splitTopic(publishData.topic, publishData.subtopics); + if (publishData.topic.empty()) + throw ProtocolError("Empty publish topic"); if (!isValidUtf8(publishData.topic, true)) { - logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str()); - return; + const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str()); + logger->logf(LOG_WARNING, err.c_str()); + throw ProtocolError(err); } - if (publishData.topic.empty()) - throw ProtocolError("Empty publish topic"); - #ifndef NDEBUG logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup); #endif @@ -902,21 +856,55 @@ void MqttPacket::handlePublish() payloadStart = pos; Authentication &authentication = *ThreadGlobals::getAuth(); - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success) + + // Working with a local copy because the subscribing action will modify this->packet_id. See the PublishCopyFactory. + const uint16_t _packet_id = this->packet_id; + + if (qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id)) + { + ackCode = ReasonCodes::PacketIdentifierInUse; + } + else { - if (retain) + // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish. + if (qos == 2) + sender->getSession()->addIncomingQoS2MessageId(_packet_id); + + splitTopic(publishData.topic, publishData.subtopics); + + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success) + { + if (retain) + { + std::string payload(readBytes(payloadLen), payloadLen); + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, qos); + } + + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9] + bites[0] &= 0b11110110; + first_byte = bites[0]; + + PublishCopyFactory factory(this); + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(factory); + } + else { - std::string payload(readBytes(payloadLen), payloadLen); - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, qos); + ackCode = ReasonCodes::NotAuthorized; } + } - // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. - // Existing subscribers don't get retain=1. [MQTT-3.3.1-9] - bites[0] &= 0b11110110; - first_byte = bites[0]; +#ifndef NDEBUG + // Protection against using the altered packet id. + this->packet_id = 0; +#endif - PublishCopyFactory factory(this); - sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(factory); + if (qos > 0) + { + const PacketType responseType = qos == 1 ? PacketType::PUBACK : PacketType::PUBREC; + PubResponse pubAck(this->protocolVersion, responseType, ackCode, _packet_id); + MqttPacket response(pubAck); + sender->writeMqttPacket(response); } } @@ -935,7 +923,7 @@ void MqttPacket::handlePubRec() sender->getSession()->clearQosMessage(packet_id); sender->getSession()->addOutgoingQoS2MessageId(packet_id); - PubRel pubRel(packet_id); + PubResponse pubRel(this->protocolVersion, PacketType::PUBREL, ReasonCodes::Success, packet_id); MqttPacket response(pubRel); sender->writeMqttPacket(response); } @@ -952,7 +940,7 @@ void MqttPacket::handlePubRel() const uint16_t packet_id = readTwoBytesToUInt16(); sender->getSession()->removeIncomingQoS2MessageId(packet_id); - PubComp pubcomp(packet_id); + PubResponse pubcomp(this->protocolVersion, PacketType::PUBCOMP, ReasonCodes::Success, packet_id); MqttPacket response(pubcomp); sender->writeMqttPacket(response); } diff --git a/mqttpacket.h b/mqttpacket.h index 5283060..cbdea49 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -71,7 +71,6 @@ class MqttPacket void readUserProperty(); void calculateRemainingLength(); - void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0); MqttPacket(const MqttPacket &other) = delete; public: @@ -87,10 +86,7 @@ public: MqttPacket(const SubAck &subAck); MqttPacket(const UnsubAck &unsubAck); MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish); - MqttPacket(const PubAck &pubAck); - MqttPacket(const PubRec &pubRec); - MqttPacket(const PubComp &pubComp); - MqttPacket(const PubRel &pubRel); + MqttPacket(const PubResponse &pubAck); static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender); diff --git a/session.cpp b/session.cpp index db162cf..0a149dd 100644 --- a/session.cpp +++ b/session.cpp @@ -268,7 +268,7 @@ uint64_t Session::sendPendingQosMessages() for (const uint16_t packet_id : outgoingQoS2MessageIds) { - PubRel pubRel(packet_id); + PubResponse pubRel(c->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, packet_id); MqttPacket packet(pubRel); count += c->writeMqttPacketAndBlameThisClient(packet); } @@ -313,12 +313,16 @@ std::shared_ptr &Session::getWill() void Session::addIncomingQoS2MessageId(uint16_t packet_id) { + assert(packet_id > 0); + std::unique_lock locker(qosQueueMutex); incomingQoS2MessageIds.insert(packet_id); } bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id) { + assert(packet_id > 0); + std::unique_lock locker(qosQueueMutex); const auto it = incomingQoS2MessageIds.find(packet_id); return it != incomingQoS2MessageIds.end(); @@ -326,6 +330,8 @@ bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id) void Session::removeIncomingQoS2MessageId(u_int16_t packet_id) { + assert(packet_id > 0); + std::unique_lock locker(qosQueueMutex); #ifndef NDEBUG diff --git a/types.cpp b/types.cpp index 03f69e4..4814c08 100644 --- a/types.cpp +++ b/types.cpp @@ -186,58 +186,43 @@ bool WillDelayCompare(const std::shared_ptr &a, const std::weak_ptrwill_delay < _b->will_delay; }; -PubAck::PubAck(uint16_t packet_id) : - packet_id(packet_id) -{ - -} - -// Packet has no payload and only a variable header, of length 2. -size_t PubAck::getLengthWithoutFixedHeader() const -{ - return 2; -} - -UnsubAck::UnsubAck(uint16_t packet_id) : - packet_id(packet_id) -{ - -} - -size_t UnsubAck::getLengthWithoutFixedHeader() const -{ - return 2; -} - -PubRec::PubRec(uint16_t packet_id) : +PubResponse::PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id) : + packet_type(packet_type), + protocol_version(protVersion), + reason_code(protVersion >= ProtocolVersion::Mqtt5 ? reason_code : ReasonCodes::Success), packet_id(packet_id) { - + assert(packet_type == PacketType::PUBACK || packet_type == PacketType::PUBREC || packet_type == PacketType::PUBREL || packet_type == PacketType::PUBCOMP); } -size_t PubRec::getLengthWithoutFixedHeader() const +uint8_t PubResponse::getLengthIncludingFixedHeader() const { - return 2; + return 2 + getRemainingLength(); } -PubComp::PubComp(uint16_t packet_id) : - packet_id(packet_id) +uint8_t PubResponse::getRemainingLength() const { - + // I'm leaving out the property length of 0: "If the Remaining Length is less than 4 there is no Property Length and the value of 0 is used" + const uint8_t result = needsReasonCode() ? 3 : 2; + return result; } -size_t PubComp::getLengthWithoutFixedHeader() const +/** + * @brief "The Reason Code and Property Length can be omitted if the Reason Code is 0x00 (Success) and there are no Properties" + * @return + */ +bool PubResponse::needsReasonCode() const { - return 2; + return this->protocol_version >= ProtocolVersion::Mqtt5 && this->reason_code > ReasonCodes::Success; } -PubRel::PubRel(uint16_t packet_id) : +UnsubAck::UnsubAck(uint16_t packet_id) : packet_id(packet_id) { } -size_t PubRel::getLengthWithoutFixedHeader() const +size_t UnsubAck::getLengthWithoutFixedHeader() const { return 2; } diff --git a/types.h b/types.h index d459411..f0fcd2f 100644 --- a/types.h +++ b/types.h @@ -230,36 +230,19 @@ public: bool WillDelayCompare(const std::shared_ptr &a, const std::weak_ptr &b); -class PubAck +class PubResponse { public: - PubAck(uint16_t packet_id); - uint16_t packet_id; - size_t getLengthWithoutFixedHeader() const; -}; - -class PubRec -{ -public: - PubRec(uint16_t packet_id); - uint16_t packet_id; - size_t getLengthWithoutFixedHeader() const; -}; + PubResponse(const PubResponse &other) = delete; + PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id); -class PubComp -{ -public: - PubComp(uint16_t packet_id); - uint16_t packet_id; - size_t getLengthWithoutFixedHeader() const; -}; - -class PubRel -{ -public: - PubRel(uint16_t packet_id); + const PacketType packet_type; + const ProtocolVersion protocol_version; + const ReasonCodes reason_code; uint16_t packet_id; - size_t getLengthWithoutFixedHeader() const; + uint8_t getLengthIncludingFixedHeader() const; + uint8_t getRemainingLength() const; + bool needsReasonCode() const; }; #endif // TYPES_H -- libgit2 0.21.4