Commit ed1479935ac2866fa015b5e19744bba76724af13

Authored by Wiebe Cazemier
1 parent b971b1a5

Add support for QoS 2

mqttpacket.cpp
... ... @@ -132,20 +132,50 @@ MqttPacket::MqttPacket(const Publish &publish) :
132 132 calculateRemainingLength();
133 133 }
134 134  
135   -// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector.
136   -MqttPacket::MqttPacket(const PubAck &pubAck) :
137   - bites(pubAck.getLengthWithoutFixedHeader() + 2)
  135 +/**
  136 + * @brief MqttPacket::pubCommonConstruct is common code for constructors for all those empty pub control packets (ack, rec(eived), rel(ease), comp(lete)).
  137 + * @param packet_id
  138 + *
  139 + * This functions cheats a bit and doesn't use calculateRemainingLength, because it's always 2. Be sure to allocate enough room in the vector when
  140 + * you use this function (add 2 to the length).
  141 + */
  142 +void MqttPacket::pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits)
138 143 {
139   - fixed_header_length = 2; // This is the cheat part mentioned above. We're not calculating it dynamically.
140   - packetType = PacketType::PUBACK;
141   - first_byte = static_cast<char>(packetType) << 4;
  144 + assert(firstByteDefaultBits <= 0xF);
  145 +
  146 + fixed_header_length = 2;
  147 + first_byte = (static_cast<uint8_t>(packetType) << 4) | firstByteDefaultBits;
142 148 writeByte(first_byte);
143 149 writeByte(2); // length is always 2.
144   - char topicLenMSB = (pubAck.packet_id & 0xFF00) >> 8;
145   - char topicLenLSB = (pubAck.packet_id & 0x00FF);
  150 + uint8_t packetIdMSB = (packet_id & 0xFF00) >> 8;
  151 + uint8_t packetIdLSB = (packet_id & 0x00FF);
146 152 packet_id_pos = pos;
147   - writeByte(topicLenMSB);
148   - writeByte(topicLenLSB);
  153 + writeByte(packetIdMSB);
  154 + writeByte(packetIdLSB);
  155 +}
  156 +
  157 +MqttPacket::MqttPacket(const PubAck &pubAck) :
  158 + bites(pubAck.getLengthWithoutFixedHeader() + 2)
  159 +{
  160 + pubCommonConstruct(pubAck.packet_id, PacketType::PUBACK);
  161 +}
  162 +
  163 +MqttPacket::MqttPacket(const PubRec &pubRec) :
  164 + bites(pubRec.getLengthWithoutFixedHeader() + 2)
  165 +{
  166 + pubCommonConstruct(pubRec.packet_id, PacketType::PUBREC);
  167 +}
  168 +
  169 +MqttPacket::MqttPacket(const PubComp &pubComp) :
  170 + bites(pubComp.getLengthWithoutFixedHeader() + 2)
  171 +{
  172 + pubCommonConstruct(pubComp.packet_id, PacketType::PUBCOMP);
  173 +}
  174 +
  175 +MqttPacket::MqttPacket(const PubRel &pubRel) :
  176 + bites(pubRel.getLengthWithoutFixedHeader() + 2)
  177 +{
  178 + pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010);
149 179 }
150 180  
151 181 void MqttPacket::handle()
... ... @@ -175,6 +205,12 @@ void MqttPacket::handle()
175 205 handlePublish();
176 206 else if (packetType == PacketType::PUBACK)
177 207 handlePubAck();
  208 + else if (packetType == PacketType::PUBREC)
  209 + handlePubRec();
  210 + else if (packetType == PacketType::PUBREL)
  211 + handlePubRel();
  212 + else if (packetType == PacketType::PUBCOMP)
  213 + handlePubComp();
178 214 }
179 215  
180 216 void MqttPacket::handleConnect()
... ... @@ -437,7 +473,7 @@ void MqttPacket::handlePublish()
437 473 bool dup = !!(first_byte & 0b00001000);
438 474 char qos = (first_byte & 0b00000110) >> 1;
439 475  
440   - if (qos == 3)
  476 + if (qos > 2)
441 477 throw ProtocolError("QoS 3 is a protocol violation.");
442 478 this->qos = qos;
443 479  
... ... @@ -453,21 +489,37 @@ void MqttPacket::handlePublish()
453 489 return;
454 490 }
455 491  
  492 +#ifndef NDEBUG
  493 + logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", topic.c_str(), qos, retain, dup);
  494 +#endif
  495 +
456 496 if (qos)
457 497 {
458   - if (qos > 1)
459   - throw ProtocolError("Qos > 1 not implemented.");
460 498 packet_id_pos = pos;
461   - uint16_t packet_id = readTwoBytesToUInt16();
  499 + packet_id = readTwoBytesToUInt16();
462 500  
463   - // Clear the packet ID from this packet, because each new publish must get a new one. It's more of a debug precaution.
464   - pos -= 2;
465   - char zero[2]; zero[0] = 0; zero[1] = 0;
466   - writeBytes(zero, 2);
  501 + if (qos == 1)
  502 + {
  503 + PubAck pubAck(packet_id);
  504 + MqttPacket response(pubAck);
  505 + sender->writeMqttPacket(response);
  506 + }
  507 + else
  508 + {
  509 + PubRec pubRec(packet_id);
  510 + MqttPacket response(pubRec);
  511 + sender->writeMqttPacket(response);
467 512  
468   - PubAck pubAck(packet_id);
469   - MqttPacket response(pubAck);
470   - sender->writeMqttPacket(response);
  513 + if (sender->getSession()->incomingQoS2MessageIdInTransit(packet_id))
  514 + {
  515 + return;
  516 + }
  517 + else
  518 + {
  519 + // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish.
  520 + sender->getSession()->addIncomingQoS2MessageId(packet_id);
  521 + }
  522 + }
471 523 }
472 524  
473 525 if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write) == AuthResult::success)
... ... @@ -496,6 +548,42 @@ void MqttPacket::handlePubAck()
496 548 sender->getSession()->clearQosMessage(packet_id);
497 549 }
498 550  
  551 +/**
  552 + * @brief MqttPacket::handlePubRec handles QoS 2 'publish received' packets. The publisher receives these.
  553 + */
  554 +void MqttPacket::handlePubRec()
  555 +{
  556 + const uint16_t packet_id = readTwoBytesToUInt16();
  557 + sender->getSession()->clearQosMessage(packet_id);
  558 + sender->getSession()->addOutgoingQoS2MessageId(packet_id);
  559 +
  560 + PubRel pubRel(packet_id);
  561 + MqttPacket response(pubRel);
  562 + sender->writeMqttPacket(response);
  563 +}
  564 +
  565 +/**
  566 + * @brief MqttPacket::handlePubRel handles QoS 2 'publish release'. The publisher sends these.
  567 + */
  568 +void MqttPacket::handlePubRel()
  569 +{
  570 + const uint16_t packet_id = readTwoBytesToUInt16();
  571 + sender->getSession()->removeIncomingQoS2MessageId(packet_id);
  572 +
  573 + PubComp pubcomp(packet_id);
  574 + MqttPacket response(pubcomp);
  575 + sender->writeMqttPacket(response);
  576 +}
  577 +
  578 +/**
  579 + * @brief MqttPacket::handlePubComp handles QoS 2 'publish complete'. The publisher receives these.
  580 + */
  581 +void MqttPacket::handlePubComp()
  582 +{
  583 + const uint16_t packet_id = readTwoBytesToUInt16();
  584 + sender->getSession()->removeOutgoingQoS2MessageId(packet_id);
  585 +}
  586 +
499 587 void MqttPacket::calculateRemainingLength()
500 588 {
501 589 assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of.
... ... @@ -529,6 +617,8 @@ void MqttPacket::setPacketId(uint16_t packet_id)
529 617 assert(packetType == PacketType::PUBLISH);
530 618 assert(qos > 0);
531 619  
  620 + this->packet_id = packet_id;
  621 +
532 622 pos = packet_id_pos;
533 623  
534 624 char topicLenMSB = (packet_id & 0xFF00) >> 8;
... ... @@ -537,6 +627,12 @@ void MqttPacket::setPacketId(uint16_t packet_id)
537 627 writeByte(topicLenLSB);
538 628 }
539 629  
  630 +uint16_t MqttPacket::getPacketId() const
  631 +{
  632 + assert(qos > 0);
  633 + return packet_id;
  634 +}
  635 +
540 636 // If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything?
541 637 void MqttPacket::setDuplicate()
542 638 {
... ...
mqttpacket.h
... ... @@ -57,6 +57,7 @@ class MqttPacket
57 57 char first_byte = 0;
58 58 size_t pos = 0;
59 59 size_t packet_id_pos = 0;
  60 + uint16_t packet_id = 0;
60 61 ProtocolVersion protocolVersion = ProtocolVersion::None;
61 62 Logger *logger = Logger::getInstance();
62 63  
... ... @@ -68,6 +69,7 @@ class MqttPacket
68 69 size_t remainingAfterPos();
69 70  
70 71 void calculateRemainingLength();
  72 + void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0);
71 73  
72 74 MqttPacket(const MqttPacket &other) = default;
73 75 public:
... ... @@ -84,6 +86,9 @@ public:
84 86 MqttPacket(const UnsubAck &unsubAck);
85 87 MqttPacket(const Publish &publish);
86 88 MqttPacket(const PubAck &pubAck);
  89 + MqttPacket(const PubRec &pubRec);
  90 + MqttPacket(const PubComp &pubComp);
  91 + MqttPacket(const PubRel &pubRel);
87 92  
88 93 void handle();
89 94 void handleConnect();
... ... @@ -93,6 +98,9 @@ public:
93 98 void handlePing();
94 99 void handlePublish();
95 100 void handlePubAck();
  101 + void handlePubRec();
  102 + void handlePubRel();
  103 + void handlePubComp();
96 104  
97 105 size_t getSizeIncludingNonPresentHeader() const;
98 106 const std::vector<char> &getBites() const { return bites; }
... ... @@ -105,6 +113,7 @@ public:
105 113 char getFirstByte() const;
106 114 RemainingLength getRemainingLength() const;
107 115 void setPacketId(uint16_t packet_id);
  116 + uint16_t getPacketId() const;
108 117 void setDuplicate();
109 118 size_t getTotalMemoryFootprint();
110 119 };
... ...
session.cpp
... ... @@ -64,11 +64,13 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos)
64 64 c->writeMqttPacketAndBlameThisClient(packet, qos);
65 65 }
66 66 }
67   - else if (qos == 1)
  67 + else if (qos > 0)
68 68 {
69 69 std::shared_ptr<MqttPacket> copyPacket = packet.getCopy();
70 70 std::unique_lock<std::mutex> locker(qosQueueMutex);
71   - if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
  71 +
  72 + const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();
  73 + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
72 74 {
73 75 logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str());
74 76 return;
... ... @@ -145,6 +147,13 @@ void Session::sendPendingQosMessages()
145 147 c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos());
146 148 qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
147 149 }
  150 +
  151 + for (const uint16_t packet_id : outgoingQoS2MessageIds)
  152 + {
  153 + PubRel pubRel(packet_id);
  154 + MqttPacket packet(pubRel);
  155 + c->writeMqttPacketAndBlameThisClient(packet, 2);
  156 + }
148 157 }
149 158 }
150 159  
... ... @@ -158,3 +167,41 @@ bool Session::hasExpired()
158 167 {
159 168 return clientDisconnected() && (lastTouched + EXPIRE_SESSION_AFTER) < time(NULL);
160 169 }
  170 +
  171 +void Session::addIncomingQoS2MessageId(uint16_t packet_id)
  172 +{
  173 + incomingQoS2MessageIds.insert(packet_id);
  174 +}
  175 +
  176 +bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id) const
  177 +{
  178 + const auto it = incomingQoS2MessageIds.find(packet_id);
  179 + return it != incomingQoS2MessageIds.end();
  180 +}
  181 +
  182 +void Session::removeIncomingQoS2MessageId(u_int16_t packet_id)
  183 +{
  184 +#ifndef NDEBUG
  185 + logger->logf(LOG_DEBUG, "As QoS 2 receiver: publish released (PUBREL) for '%s', packet id '%d'. Left in queue: %d", client_id.c_str(), packet_id, incomingQoS2MessageIds.size());
  186 +#endif
  187 +
  188 + const auto it = incomingQoS2MessageIds.find(packet_id);
  189 + if (it != incomingQoS2MessageIds.end())
  190 + incomingQoS2MessageIds.erase(it);
  191 +}
  192 +
  193 +void Session::addOutgoingQoS2MessageId(uint16_t packet_id)
  194 +{
  195 + outgoingQoS2MessageIds.insert(packet_id);
  196 +}
  197 +
  198 +void Session::removeOutgoingQoS2MessageId(u_int16_t packet_id)
  199 +{
  200 +#ifndef NDEBUG
  201 + logger->logf(LOG_DEBUG, "As QoS 2 sender: publish complete (PUBCOMP) for '%s', packet id '%d'. Left in queue: %d", client_id.c_str(), packet_id, outgoingQoS2MessageIds.size());
  202 +#endif
  203 +
  204 + const auto it = outgoingQoS2MessageIds.find(packet_id);
  205 + if (it != outgoingQoS2MessageIds.end())
  206 + outgoingQoS2MessageIds.erase(it);
  207 +}
... ...
session.h
... ... @@ -21,6 +21,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
21 21 #include <memory>
22 22 #include <list>
23 23 #include <mutex>
  24 +#include <set>
24 25  
25 26 #include "forward_declarations.h"
26 27 #include "logger.h"
... ... @@ -45,6 +46,8 @@ class Session
45 46 std::string client_id;
46 47 std::string username;
47 48 std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6]
  49 + std::set<uint16_t> incomingQoS2MessageIds;
  50 + std::set<uint16_t> outgoingQoS2MessageIds;
48 51 std::mutex qosQueueMutex;
49 52 uint16_t nextPacketId = 0;
50 53 ssize_t qosQueueBytes = 0;
... ... @@ -66,6 +69,14 @@ public:
66 69 void sendPendingQosMessages();
67 70 void touch(time_t val = 0);
68 71 bool hasExpired();
  72 +
  73 + void addIncomingQoS2MessageId(uint16_t packet_id);
  74 + bool incomingQoS2MessageIdInTransit(uint16_t packet_id) const;
  75 + void removeIncomingQoS2MessageId(u_int16_t packet_id);
  76 +
  77 + void addOutgoingQoS2MessageId(uint16_t packet_id);
  78 + void removeOutgoingQoS2MessageId(u_int16_t packet_id);
  79 +
69 80 };
70 81  
71 82 #endif // SESSION_H
... ...
types.cpp
... ... @@ -82,3 +82,36 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const
82 82 {
83 83 return 2;
84 84 }
  85 +
  86 +PubRec::PubRec(uint16_t packet_id) :
  87 + packet_id(packet_id)
  88 +{
  89 +
  90 +}
  91 +
  92 +size_t PubRec::getLengthWithoutFixedHeader() const
  93 +{
  94 + return 2;
  95 +}
  96 +
  97 +PubComp::PubComp(uint16_t packet_id) :
  98 + packet_id(packet_id)
  99 +{
  100 +
  101 +}
  102 +
  103 +size_t PubComp::getLengthWithoutFixedHeader() const
  104 +{
  105 + return 2;
  106 +}
  107 +
  108 +PubRel::PubRel(uint16_t packet_id) :
  109 + packet_id(packet_id)
  110 +{
  111 +
  112 +}
  113 +
  114 +size_t PubRel::getLengthWithoutFixedHeader() const
  115 +{
  116 + return 2;
  117 +}
... ...
... ... @@ -113,4 +113,28 @@ public:
113 113 size_t getLengthWithoutFixedHeader() const;
114 114 };
115 115  
  116 +class PubRec
  117 +{
  118 +public:
  119 + PubRec(uint16_t packet_id);
  120 + uint16_t packet_id;
  121 + size_t getLengthWithoutFixedHeader() const;
  122 +};
  123 +
  124 +class PubComp
  125 +{
  126 +public:
  127 + PubComp(uint16_t packet_id);
  128 + uint16_t packet_id;
  129 + size_t getLengthWithoutFixedHeader() const;
  130 +};
  131 +
  132 +class PubRel
  133 +{
  134 +public:
  135 + PubRel(uint16_t packet_id);
  136 + uint16_t packet_id;
  137 + size_t getLengthWithoutFixedHeader() const;
  138 +};
  139 +
116 140 #endif // TYPES_H
... ...