Commit ed1479935ac2866fa015b5e19744bba76724af13

Authored by Wiebe Cazemier
1 parent b971b1a5

Add support for QoS 2

mqttpacket.cpp
@@ -132,20 +132,50 @@ MqttPacket::MqttPacket(const Publish &publish) : @@ -132,20 +132,50 @@ MqttPacket::MqttPacket(const Publish &publish) :
132 calculateRemainingLength(); 132 calculateRemainingLength();
133 } 133 }
134 134
135 -// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector.  
136 -MqttPacket::MqttPacket(const PubAck &pubAck) :  
137 - bites(pubAck.getLengthWithoutFixedHeader() + 2) 135 +/**
  136 + * @brief MqttPacket::pubCommonConstruct is common code for constructors for all those empty pub control packets (ack, rec(eived), rel(ease), comp(lete)).
  137 + * @param packet_id
  138 + *
  139 + * This functions cheats a bit and doesn't use calculateRemainingLength, because it's always 2. Be sure to allocate enough room in the vector when
  140 + * you use this function (add 2 to the length).
  141 + */
  142 +void MqttPacket::pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits)
138 { 143 {
139 - fixed_header_length = 2; // This is the cheat part mentioned above. We're not calculating it dynamically.  
140 - packetType = PacketType::PUBACK;  
141 - first_byte = static_cast<char>(packetType) << 4; 144 + assert(firstByteDefaultBits <= 0xF);
  145 +
  146 + fixed_header_length = 2;
  147 + first_byte = (static_cast<uint8_t>(packetType) << 4) | firstByteDefaultBits;
142 writeByte(first_byte); 148 writeByte(first_byte);
143 writeByte(2); // length is always 2. 149 writeByte(2); // length is always 2.
144 - char topicLenMSB = (pubAck.packet_id & 0xFF00) >> 8;  
145 - char topicLenLSB = (pubAck.packet_id & 0x00FF); 150 + uint8_t packetIdMSB = (packet_id & 0xFF00) >> 8;
  151 + uint8_t packetIdLSB = (packet_id & 0x00FF);
146 packet_id_pos = pos; 152 packet_id_pos = pos;
147 - writeByte(topicLenMSB);  
148 - writeByte(topicLenLSB); 153 + writeByte(packetIdMSB);
  154 + writeByte(packetIdLSB);
  155 +}
  156 +
  157 +MqttPacket::MqttPacket(const PubAck &pubAck) :
  158 + bites(pubAck.getLengthWithoutFixedHeader() + 2)
  159 +{
  160 + pubCommonConstruct(pubAck.packet_id, PacketType::PUBACK);
  161 +}
  162 +
  163 +MqttPacket::MqttPacket(const PubRec &pubRec) :
  164 + bites(pubRec.getLengthWithoutFixedHeader() + 2)
  165 +{
  166 + pubCommonConstruct(pubRec.packet_id, PacketType::PUBREC);
  167 +}
  168 +
  169 +MqttPacket::MqttPacket(const PubComp &pubComp) :
  170 + bites(pubComp.getLengthWithoutFixedHeader() + 2)
  171 +{
  172 + pubCommonConstruct(pubComp.packet_id, PacketType::PUBCOMP);
  173 +}
  174 +
  175 +MqttPacket::MqttPacket(const PubRel &pubRel) :
  176 + bites(pubRel.getLengthWithoutFixedHeader() + 2)
  177 +{
  178 + pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010);
149 } 179 }
150 180
151 void MqttPacket::handle() 181 void MqttPacket::handle()
@@ -175,6 +205,12 @@ void MqttPacket::handle() @@ -175,6 +205,12 @@ void MqttPacket::handle()
175 handlePublish(); 205 handlePublish();
176 else if (packetType == PacketType::PUBACK) 206 else if (packetType == PacketType::PUBACK)
177 handlePubAck(); 207 handlePubAck();
  208 + else if (packetType == PacketType::PUBREC)
  209 + handlePubRec();
  210 + else if (packetType == PacketType::PUBREL)
  211 + handlePubRel();
  212 + else if (packetType == PacketType::PUBCOMP)
  213 + handlePubComp();
178 } 214 }
179 215
180 void MqttPacket::handleConnect() 216 void MqttPacket::handleConnect()
@@ -437,7 +473,7 @@ void MqttPacket::handlePublish() @@ -437,7 +473,7 @@ void MqttPacket::handlePublish()
437 bool dup = !!(first_byte & 0b00001000); 473 bool dup = !!(first_byte & 0b00001000);
438 char qos = (first_byte & 0b00000110) >> 1; 474 char qos = (first_byte & 0b00000110) >> 1;
439 475
440 - if (qos == 3) 476 + if (qos > 2)
441 throw ProtocolError("QoS 3 is a protocol violation."); 477 throw ProtocolError("QoS 3 is a protocol violation.");
442 this->qos = qos; 478 this->qos = qos;
443 479
@@ -453,21 +489,37 @@ void MqttPacket::handlePublish() @@ -453,21 +489,37 @@ void MqttPacket::handlePublish()
453 return; 489 return;
454 } 490 }
455 491
  492 +#ifndef NDEBUG
  493 + logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", topic.c_str(), qos, retain, dup);
  494 +#endif
  495 +
456 if (qos) 496 if (qos)
457 { 497 {
458 - if (qos > 1)  
459 - throw ProtocolError("Qos > 1 not implemented.");  
460 packet_id_pos = pos; 498 packet_id_pos = pos;
461 - uint16_t packet_id = readTwoBytesToUInt16(); 499 + packet_id = readTwoBytesToUInt16();
462 500
463 - // Clear the packet ID from this packet, because each new publish must get a new one. It's more of a debug precaution.  
464 - pos -= 2;  
465 - char zero[2]; zero[0] = 0; zero[1] = 0;  
466 - writeBytes(zero, 2); 501 + if (qos == 1)
  502 + {
  503 + PubAck pubAck(packet_id);
  504 + MqttPacket response(pubAck);
  505 + sender->writeMqttPacket(response);
  506 + }
  507 + else
  508 + {
  509 + PubRec pubRec(packet_id);
  510 + MqttPacket response(pubRec);
  511 + sender->writeMqttPacket(response);
467 512
468 - PubAck pubAck(packet_id);  
469 - MqttPacket response(pubAck);  
470 - sender->writeMqttPacket(response); 513 + if (sender->getSession()->incomingQoS2MessageIdInTransit(packet_id))
  514 + {
  515 + return;
  516 + }
  517 + else
  518 + {
  519 + // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish.
  520 + sender->getSession()->addIncomingQoS2MessageId(packet_id);
  521 + }
  522 + }
471 } 523 }
472 524
473 if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write) == AuthResult::success) 525 if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write) == AuthResult::success)
@@ -496,6 +548,42 @@ void MqttPacket::handlePubAck() @@ -496,6 +548,42 @@ void MqttPacket::handlePubAck()
496 sender->getSession()->clearQosMessage(packet_id); 548 sender->getSession()->clearQosMessage(packet_id);
497 } 549 }
498 550
  551 +/**
  552 + * @brief MqttPacket::handlePubRec handles QoS 2 'publish received' packets. The publisher receives these.
  553 + */
  554 +void MqttPacket::handlePubRec()
  555 +{
  556 + const uint16_t packet_id = readTwoBytesToUInt16();
  557 + sender->getSession()->clearQosMessage(packet_id);
  558 + sender->getSession()->addOutgoingQoS2MessageId(packet_id);
  559 +
  560 + PubRel pubRel(packet_id);
  561 + MqttPacket response(pubRel);
  562 + sender->writeMqttPacket(response);
  563 +}
  564 +
  565 +/**
  566 + * @brief MqttPacket::handlePubRel handles QoS 2 'publish release'. The publisher sends these.
  567 + */
  568 +void MqttPacket::handlePubRel()
  569 +{
  570 + const uint16_t packet_id = readTwoBytesToUInt16();
  571 + sender->getSession()->removeIncomingQoS2MessageId(packet_id);
  572 +
  573 + PubComp pubcomp(packet_id);
  574 + MqttPacket response(pubcomp);
  575 + sender->writeMqttPacket(response);
  576 +}
  577 +
  578 +/**
  579 + * @brief MqttPacket::handlePubComp handles QoS 2 'publish complete'. The publisher receives these.
  580 + */
  581 +void MqttPacket::handlePubComp()
  582 +{
  583 + const uint16_t packet_id = readTwoBytesToUInt16();
  584 + sender->getSession()->removeOutgoingQoS2MessageId(packet_id);
  585 +}
  586 +
499 void MqttPacket::calculateRemainingLength() 587 void MqttPacket::calculateRemainingLength()
500 { 588 {
501 assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. 589 assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of.
@@ -529,6 +617,8 @@ void MqttPacket::setPacketId(uint16_t packet_id) @@ -529,6 +617,8 @@ void MqttPacket::setPacketId(uint16_t packet_id)
529 assert(packetType == PacketType::PUBLISH); 617 assert(packetType == PacketType::PUBLISH);
530 assert(qos > 0); 618 assert(qos > 0);
531 619
  620 + this->packet_id = packet_id;
  621 +
532 pos = packet_id_pos; 622 pos = packet_id_pos;
533 623
534 char topicLenMSB = (packet_id & 0xFF00) >> 8; 624 char topicLenMSB = (packet_id & 0xFF00) >> 8;
@@ -537,6 +627,12 @@ void MqttPacket::setPacketId(uint16_t packet_id) @@ -537,6 +627,12 @@ void MqttPacket::setPacketId(uint16_t packet_id)
537 writeByte(topicLenLSB); 627 writeByte(topicLenLSB);
538 } 628 }
539 629
  630 +uint16_t MqttPacket::getPacketId() const
  631 +{
  632 + assert(qos > 0);
  633 + return packet_id;
  634 +}
  635 +
540 // If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything? 636 // If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything?
541 void MqttPacket::setDuplicate() 637 void MqttPacket::setDuplicate()
542 { 638 {
mqttpacket.h
@@ -57,6 +57,7 @@ class MqttPacket @@ -57,6 +57,7 @@ class MqttPacket
57 char first_byte = 0; 57 char first_byte = 0;
58 size_t pos = 0; 58 size_t pos = 0;
59 size_t packet_id_pos = 0; 59 size_t packet_id_pos = 0;
  60 + uint16_t packet_id = 0;
60 ProtocolVersion protocolVersion = ProtocolVersion::None; 61 ProtocolVersion protocolVersion = ProtocolVersion::None;
61 Logger *logger = Logger::getInstance(); 62 Logger *logger = Logger::getInstance();
62 63
@@ -68,6 +69,7 @@ class MqttPacket @@ -68,6 +69,7 @@ class MqttPacket
68 size_t remainingAfterPos(); 69 size_t remainingAfterPos();
69 70
70 void calculateRemainingLength(); 71 void calculateRemainingLength();
  72 + void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0);
71 73
72 MqttPacket(const MqttPacket &other) = default; 74 MqttPacket(const MqttPacket &other) = default;
73 public: 75 public:
@@ -84,6 +86,9 @@ public: @@ -84,6 +86,9 @@ public:
84 MqttPacket(const UnsubAck &unsubAck); 86 MqttPacket(const UnsubAck &unsubAck);
85 MqttPacket(const Publish &publish); 87 MqttPacket(const Publish &publish);
86 MqttPacket(const PubAck &pubAck); 88 MqttPacket(const PubAck &pubAck);
  89 + MqttPacket(const PubRec &pubRec);
  90 + MqttPacket(const PubComp &pubComp);
  91 + MqttPacket(const PubRel &pubRel);
87 92
88 void handle(); 93 void handle();
89 void handleConnect(); 94 void handleConnect();
@@ -93,6 +98,9 @@ public: @@ -93,6 +98,9 @@ public:
93 void handlePing(); 98 void handlePing();
94 void handlePublish(); 99 void handlePublish();
95 void handlePubAck(); 100 void handlePubAck();
  101 + void handlePubRec();
  102 + void handlePubRel();
  103 + void handlePubComp();
96 104
97 size_t getSizeIncludingNonPresentHeader() const; 105 size_t getSizeIncludingNonPresentHeader() const;
98 const std::vector<char> &getBites() const { return bites; } 106 const std::vector<char> &getBites() const { return bites; }
@@ -105,6 +113,7 @@ public: @@ -105,6 +113,7 @@ public:
105 char getFirstByte() const; 113 char getFirstByte() const;
106 RemainingLength getRemainingLength() const; 114 RemainingLength getRemainingLength() const;
107 void setPacketId(uint16_t packet_id); 115 void setPacketId(uint16_t packet_id);
  116 + uint16_t getPacketId() const;
108 void setDuplicate(); 117 void setDuplicate();
109 size_t getTotalMemoryFootprint(); 118 size_t getTotalMemoryFootprint();
110 }; 119 };
session.cpp
@@ -64,11 +64,13 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos) @@ -64,11 +64,13 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos)
64 c->writeMqttPacketAndBlameThisClient(packet, qos); 64 c->writeMqttPacketAndBlameThisClient(packet, qos);
65 } 65 }
66 } 66 }
67 - else if (qos == 1) 67 + else if (qos > 0)
68 { 68 {
69 std::shared_ptr<MqttPacket> copyPacket = packet.getCopy(); 69 std::shared_ptr<MqttPacket> copyPacket = packet.getCopy();
70 std::unique_lock<std::mutex> locker(qosQueueMutex); 70 std::unique_lock<std::mutex> locker(qosQueueMutex);
71 - if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) 71 +
  72 + const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();
  73 + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
72 { 74 {
73 logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); 75 logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str());
74 return; 76 return;
@@ -145,6 +147,13 @@ void Session::sendPendingQosMessages() @@ -145,6 +147,13 @@ void Session::sendPendingQosMessages()
145 c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos()); 147 c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos());
146 qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate. 148 qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
147 } 149 }
  150 +
  151 + for (const uint16_t packet_id : outgoingQoS2MessageIds)
  152 + {
  153 + PubRel pubRel(packet_id);
  154 + MqttPacket packet(pubRel);
  155 + c->writeMqttPacketAndBlameThisClient(packet, 2);
  156 + }
148 } 157 }
149 } 158 }
150 159
@@ -158,3 +167,41 @@ bool Session::hasExpired() @@ -158,3 +167,41 @@ bool Session::hasExpired()
158 { 167 {
159 return clientDisconnected() && (lastTouched + EXPIRE_SESSION_AFTER) < time(NULL); 168 return clientDisconnected() && (lastTouched + EXPIRE_SESSION_AFTER) < time(NULL);
160 } 169 }
  170 +
  171 +void Session::addIncomingQoS2MessageId(uint16_t packet_id)
  172 +{
  173 + incomingQoS2MessageIds.insert(packet_id);
  174 +}
  175 +
  176 +bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id) const
  177 +{
  178 + const auto it = incomingQoS2MessageIds.find(packet_id);
  179 + return it != incomingQoS2MessageIds.end();
  180 +}
  181 +
  182 +void Session::removeIncomingQoS2MessageId(u_int16_t packet_id)
  183 +{
  184 +#ifndef NDEBUG
  185 + logger->logf(LOG_DEBUG, "As QoS 2 receiver: publish released (PUBREL) for '%s', packet id '%d'. Left in queue: %d", client_id.c_str(), packet_id, incomingQoS2MessageIds.size());
  186 +#endif
  187 +
  188 + const auto it = incomingQoS2MessageIds.find(packet_id);
  189 + if (it != incomingQoS2MessageIds.end())
  190 + incomingQoS2MessageIds.erase(it);
  191 +}
  192 +
  193 +void Session::addOutgoingQoS2MessageId(uint16_t packet_id)
  194 +{
  195 + outgoingQoS2MessageIds.insert(packet_id);
  196 +}
  197 +
  198 +void Session::removeOutgoingQoS2MessageId(u_int16_t packet_id)
  199 +{
  200 +#ifndef NDEBUG
  201 + logger->logf(LOG_DEBUG, "As QoS 2 sender: publish complete (PUBCOMP) for '%s', packet id '%d'. Left in queue: %d", client_id.c_str(), packet_id, outgoingQoS2MessageIds.size());
  202 +#endif
  203 +
  204 + const auto it = outgoingQoS2MessageIds.find(packet_id);
  205 + if (it != outgoingQoS2MessageIds.end())
  206 + outgoingQoS2MessageIds.erase(it);
  207 +}
session.h
@@ -21,6 +21,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -21,6 +21,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
21 #include <memory> 21 #include <memory>
22 #include <list> 22 #include <list>
23 #include <mutex> 23 #include <mutex>
  24 +#include <set>
24 25
25 #include "forward_declarations.h" 26 #include "forward_declarations.h"
26 #include "logger.h" 27 #include "logger.h"
@@ -45,6 +46,8 @@ class Session @@ -45,6 +46,8 @@ class Session
45 std::string client_id; 46 std::string client_id;
46 std::string username; 47 std::string username;
47 std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] 48 std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6]
  49 + std::set<uint16_t> incomingQoS2MessageIds;
  50 + std::set<uint16_t> outgoingQoS2MessageIds;
48 std::mutex qosQueueMutex; 51 std::mutex qosQueueMutex;
49 uint16_t nextPacketId = 0; 52 uint16_t nextPacketId = 0;
50 ssize_t qosQueueBytes = 0; 53 ssize_t qosQueueBytes = 0;
@@ -66,6 +69,14 @@ public: @@ -66,6 +69,14 @@ public:
66 void sendPendingQosMessages(); 69 void sendPendingQosMessages();
67 void touch(time_t val = 0); 70 void touch(time_t val = 0);
68 bool hasExpired(); 71 bool hasExpired();
  72 +
  73 + void addIncomingQoS2MessageId(uint16_t packet_id);
  74 + bool incomingQoS2MessageIdInTransit(uint16_t packet_id) const;
  75 + void removeIncomingQoS2MessageId(u_int16_t packet_id);
  76 +
  77 + void addOutgoingQoS2MessageId(uint16_t packet_id);
  78 + void removeOutgoingQoS2MessageId(u_int16_t packet_id);
  79 +
69 }; 80 };
70 81
71 #endif // SESSION_H 82 #endif // SESSION_H
types.cpp
@@ -82,3 +82,36 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const @@ -82,3 +82,36 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const
82 { 82 {
83 return 2; 83 return 2;
84 } 84 }
  85 +
  86 +PubRec::PubRec(uint16_t packet_id) :
  87 + packet_id(packet_id)
  88 +{
  89 +
  90 +}
  91 +
  92 +size_t PubRec::getLengthWithoutFixedHeader() const
  93 +{
  94 + return 2;
  95 +}
  96 +
  97 +PubComp::PubComp(uint16_t packet_id) :
  98 + packet_id(packet_id)
  99 +{
  100 +
  101 +}
  102 +
  103 +size_t PubComp::getLengthWithoutFixedHeader() const
  104 +{
  105 + return 2;
  106 +}
  107 +
  108 +PubRel::PubRel(uint16_t packet_id) :
  109 + packet_id(packet_id)
  110 +{
  111 +
  112 +}
  113 +
  114 +size_t PubRel::getLengthWithoutFixedHeader() const
  115 +{
  116 + return 2;
  117 +}
@@ -113,4 +113,28 @@ public: @@ -113,4 +113,28 @@ public:
113 size_t getLengthWithoutFixedHeader() const; 113 size_t getLengthWithoutFixedHeader() const;
114 }; 114 };
115 115
  116 +class PubRec
  117 +{
  118 +public:
  119 + PubRec(uint16_t packet_id);
  120 + uint16_t packet_id;
  121 + size_t getLengthWithoutFixedHeader() const;
  122 +};
  123 +
  124 +class PubComp
  125 +{
  126 +public:
  127 + PubComp(uint16_t packet_id);
  128 + uint16_t packet_id;
  129 + size_t getLengthWithoutFixedHeader() const;
  130 +};
  131 +
  132 +class PubRel
  133 +{
  134 +public:
  135 + PubRel(uint16_t packet_id);
  136 + uint16_t packet_id;
  137 + size_t getLengthWithoutFixedHeader() const;
  138 +};
  139 +
116 #endif // TYPES_H 140 #endif // TYPES_H