Commit ed1479935ac2866fa015b5e19744bba76724af13
1 parent
b971b1a5
Add support for QoS 2
Showing
6 changed files
with
243 additions
and
23 deletions
mqttpacket.cpp
| ... | ... | @@ -132,20 +132,50 @@ MqttPacket::MqttPacket(const Publish &publish) : |
| 132 | 132 | calculateRemainingLength(); |
| 133 | 133 | } |
| 134 | 134 | |
| 135 | -// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. | |
| 136 | -MqttPacket::MqttPacket(const PubAck &pubAck) : | |
| 137 | - bites(pubAck.getLengthWithoutFixedHeader() + 2) | |
| 135 | +/** | |
| 136 | + * @brief MqttPacket::pubCommonConstruct is common code for constructors for all those empty pub control packets (ack, rec(eived), rel(ease), comp(lete)). | |
| 137 | + * @param packet_id | |
| 138 | + * | |
| 139 | + * This functions cheats a bit and doesn't use calculateRemainingLength, because it's always 2. Be sure to allocate enough room in the vector when | |
| 140 | + * you use this function (add 2 to the length). | |
| 141 | + */ | |
| 142 | +void MqttPacket::pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits) | |
| 138 | 143 | { |
| 139 | - fixed_header_length = 2; // This is the cheat part mentioned above. We're not calculating it dynamically. | |
| 140 | - packetType = PacketType::PUBACK; | |
| 141 | - first_byte = static_cast<char>(packetType) << 4; | |
| 144 | + assert(firstByteDefaultBits <= 0xF); | |
| 145 | + | |
| 146 | + fixed_header_length = 2; | |
| 147 | + first_byte = (static_cast<uint8_t>(packetType) << 4) | firstByteDefaultBits; | |
| 142 | 148 | writeByte(first_byte); |
| 143 | 149 | writeByte(2); // length is always 2. |
| 144 | - char topicLenMSB = (pubAck.packet_id & 0xFF00) >> 8; | |
| 145 | - char topicLenLSB = (pubAck.packet_id & 0x00FF); | |
| 150 | + uint8_t packetIdMSB = (packet_id & 0xFF00) >> 8; | |
| 151 | + uint8_t packetIdLSB = (packet_id & 0x00FF); | |
| 146 | 152 | packet_id_pos = pos; |
| 147 | - writeByte(topicLenMSB); | |
| 148 | - writeByte(topicLenLSB); | |
| 153 | + writeByte(packetIdMSB); | |
| 154 | + writeByte(packetIdLSB); | |
| 155 | +} | |
| 156 | + | |
| 157 | +MqttPacket::MqttPacket(const PubAck &pubAck) : | |
| 158 | + bites(pubAck.getLengthWithoutFixedHeader() + 2) | |
| 159 | +{ | |
| 160 | + pubCommonConstruct(pubAck.packet_id, PacketType::PUBACK); | |
| 161 | +} | |
| 162 | + | |
| 163 | +MqttPacket::MqttPacket(const PubRec &pubRec) : | |
| 164 | + bites(pubRec.getLengthWithoutFixedHeader() + 2) | |
| 165 | +{ | |
| 166 | + pubCommonConstruct(pubRec.packet_id, PacketType::PUBREC); | |
| 167 | +} | |
| 168 | + | |
| 169 | +MqttPacket::MqttPacket(const PubComp &pubComp) : | |
| 170 | + bites(pubComp.getLengthWithoutFixedHeader() + 2) | |
| 171 | +{ | |
| 172 | + pubCommonConstruct(pubComp.packet_id, PacketType::PUBCOMP); | |
| 173 | +} | |
| 174 | + | |
| 175 | +MqttPacket::MqttPacket(const PubRel &pubRel) : | |
| 176 | + bites(pubRel.getLengthWithoutFixedHeader() + 2) | |
| 177 | +{ | |
| 178 | + pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010); | |
| 149 | 179 | } |
| 150 | 180 | |
| 151 | 181 | void MqttPacket::handle() |
| ... | ... | @@ -175,6 +205,12 @@ void MqttPacket::handle() |
| 175 | 205 | handlePublish(); |
| 176 | 206 | else if (packetType == PacketType::PUBACK) |
| 177 | 207 | handlePubAck(); |
| 208 | + else if (packetType == PacketType::PUBREC) | |
| 209 | + handlePubRec(); | |
| 210 | + else if (packetType == PacketType::PUBREL) | |
| 211 | + handlePubRel(); | |
| 212 | + else if (packetType == PacketType::PUBCOMP) | |
| 213 | + handlePubComp(); | |
| 178 | 214 | } |
| 179 | 215 | |
| 180 | 216 | void MqttPacket::handleConnect() |
| ... | ... | @@ -437,7 +473,7 @@ void MqttPacket::handlePublish() |
| 437 | 473 | bool dup = !!(first_byte & 0b00001000); |
| 438 | 474 | char qos = (first_byte & 0b00000110) >> 1; |
| 439 | 475 | |
| 440 | - if (qos == 3) | |
| 476 | + if (qos > 2) | |
| 441 | 477 | throw ProtocolError("QoS 3 is a protocol violation."); |
| 442 | 478 | this->qos = qos; |
| 443 | 479 | |
| ... | ... | @@ -453,21 +489,37 @@ void MqttPacket::handlePublish() |
| 453 | 489 | return; |
| 454 | 490 | } |
| 455 | 491 | |
| 492 | +#ifndef NDEBUG | |
| 493 | + logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", topic.c_str(), qos, retain, dup); | |
| 494 | +#endif | |
| 495 | + | |
| 456 | 496 | if (qos) |
| 457 | 497 | { |
| 458 | - if (qos > 1) | |
| 459 | - throw ProtocolError("Qos > 1 not implemented."); | |
| 460 | 498 | packet_id_pos = pos; |
| 461 | - uint16_t packet_id = readTwoBytesToUInt16(); | |
| 499 | + packet_id = readTwoBytesToUInt16(); | |
| 462 | 500 | |
| 463 | - // Clear the packet ID from this packet, because each new publish must get a new one. It's more of a debug precaution. | |
| 464 | - pos -= 2; | |
| 465 | - char zero[2]; zero[0] = 0; zero[1] = 0; | |
| 466 | - writeBytes(zero, 2); | |
| 501 | + if (qos == 1) | |
| 502 | + { | |
| 503 | + PubAck pubAck(packet_id); | |
| 504 | + MqttPacket response(pubAck); | |
| 505 | + sender->writeMqttPacket(response); | |
| 506 | + } | |
| 507 | + else | |
| 508 | + { | |
| 509 | + PubRec pubRec(packet_id); | |
| 510 | + MqttPacket response(pubRec); | |
| 511 | + sender->writeMqttPacket(response); | |
| 467 | 512 | |
| 468 | - PubAck pubAck(packet_id); | |
| 469 | - MqttPacket response(pubAck); | |
| 470 | - sender->writeMqttPacket(response); | |
| 513 | + if (sender->getSession()->incomingQoS2MessageIdInTransit(packet_id)) | |
| 514 | + { | |
| 515 | + return; | |
| 516 | + } | |
| 517 | + else | |
| 518 | + { | |
| 519 | + // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish. | |
| 520 | + sender->getSession()->addIncomingQoS2MessageId(packet_id); | |
| 521 | + } | |
| 522 | + } | |
| 471 | 523 | } |
| 472 | 524 | |
| 473 | 525 | if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write) == AuthResult::success) |
| ... | ... | @@ -496,6 +548,42 @@ void MqttPacket::handlePubAck() |
| 496 | 548 | sender->getSession()->clearQosMessage(packet_id); |
| 497 | 549 | } |
| 498 | 550 | |
| 551 | +/** | |
| 552 | + * @brief MqttPacket::handlePubRec handles QoS 2 'publish received' packets. The publisher receives these. | |
| 553 | + */ | |
| 554 | +void MqttPacket::handlePubRec() | |
| 555 | +{ | |
| 556 | + const uint16_t packet_id = readTwoBytesToUInt16(); | |
| 557 | + sender->getSession()->clearQosMessage(packet_id); | |
| 558 | + sender->getSession()->addOutgoingQoS2MessageId(packet_id); | |
| 559 | + | |
| 560 | + PubRel pubRel(packet_id); | |
| 561 | + MqttPacket response(pubRel); | |
| 562 | + sender->writeMqttPacket(response); | |
| 563 | +} | |
| 564 | + | |
| 565 | +/** | |
| 566 | + * @brief MqttPacket::handlePubRel handles QoS 2 'publish release'. The publisher sends these. | |
| 567 | + */ | |
| 568 | +void MqttPacket::handlePubRel() | |
| 569 | +{ | |
| 570 | + const uint16_t packet_id = readTwoBytesToUInt16(); | |
| 571 | + sender->getSession()->removeIncomingQoS2MessageId(packet_id); | |
| 572 | + | |
| 573 | + PubComp pubcomp(packet_id); | |
| 574 | + MqttPacket response(pubcomp); | |
| 575 | + sender->writeMqttPacket(response); | |
| 576 | +} | |
| 577 | + | |
| 578 | +/** | |
| 579 | + * @brief MqttPacket::handlePubComp handles QoS 2 'publish complete'. The publisher receives these. | |
| 580 | + */ | |
| 581 | +void MqttPacket::handlePubComp() | |
| 582 | +{ | |
| 583 | + const uint16_t packet_id = readTwoBytesToUInt16(); | |
| 584 | + sender->getSession()->removeOutgoingQoS2MessageId(packet_id); | |
| 585 | +} | |
| 586 | + | |
| 499 | 587 | void MqttPacket::calculateRemainingLength() |
| 500 | 588 | { |
| 501 | 589 | assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. |
| ... | ... | @@ -529,6 +617,8 @@ void MqttPacket::setPacketId(uint16_t packet_id) |
| 529 | 617 | assert(packetType == PacketType::PUBLISH); |
| 530 | 618 | assert(qos > 0); |
| 531 | 619 | |
| 620 | + this->packet_id = packet_id; | |
| 621 | + | |
| 532 | 622 | pos = packet_id_pos; |
| 533 | 623 | |
| 534 | 624 | char topicLenMSB = (packet_id & 0xFF00) >> 8; |
| ... | ... | @@ -537,6 +627,12 @@ void MqttPacket::setPacketId(uint16_t packet_id) |
| 537 | 627 | writeByte(topicLenLSB); |
| 538 | 628 | } |
| 539 | 629 | |
| 630 | +uint16_t MqttPacket::getPacketId() const | |
| 631 | +{ | |
| 632 | + assert(qos > 0); | |
| 633 | + return packet_id; | |
| 634 | +} | |
| 635 | + | |
| 540 | 636 | // If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything? |
| 541 | 637 | void MqttPacket::setDuplicate() |
| 542 | 638 | { | ... | ... |
mqttpacket.h
| ... | ... | @@ -57,6 +57,7 @@ class MqttPacket |
| 57 | 57 | char first_byte = 0; |
| 58 | 58 | size_t pos = 0; |
| 59 | 59 | size_t packet_id_pos = 0; |
| 60 | + uint16_t packet_id = 0; | |
| 60 | 61 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 61 | 62 | Logger *logger = Logger::getInstance(); |
| 62 | 63 | |
| ... | ... | @@ -68,6 +69,7 @@ class MqttPacket |
| 68 | 69 | size_t remainingAfterPos(); |
| 69 | 70 | |
| 70 | 71 | void calculateRemainingLength(); |
| 72 | + void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0); | |
| 71 | 73 | |
| 72 | 74 | MqttPacket(const MqttPacket &other) = default; |
| 73 | 75 | public: |
| ... | ... | @@ -84,6 +86,9 @@ public: |
| 84 | 86 | MqttPacket(const UnsubAck &unsubAck); |
| 85 | 87 | MqttPacket(const Publish &publish); |
| 86 | 88 | MqttPacket(const PubAck &pubAck); |
| 89 | + MqttPacket(const PubRec &pubRec); | |
| 90 | + MqttPacket(const PubComp &pubComp); | |
| 91 | + MqttPacket(const PubRel &pubRel); | |
| 87 | 92 | |
| 88 | 93 | void handle(); |
| 89 | 94 | void handleConnect(); |
| ... | ... | @@ -93,6 +98,9 @@ public: |
| 93 | 98 | void handlePing(); |
| 94 | 99 | void handlePublish(); |
| 95 | 100 | void handlePubAck(); |
| 101 | + void handlePubRec(); | |
| 102 | + void handlePubRel(); | |
| 103 | + void handlePubComp(); | |
| 96 | 104 | |
| 97 | 105 | size_t getSizeIncludingNonPresentHeader() const; |
| 98 | 106 | const std::vector<char> &getBites() const { return bites; } |
| ... | ... | @@ -105,6 +113,7 @@ public: |
| 105 | 113 | char getFirstByte() const; |
| 106 | 114 | RemainingLength getRemainingLength() const; |
| 107 | 115 | void setPacketId(uint16_t packet_id); |
| 116 | + uint16_t getPacketId() const; | |
| 108 | 117 | void setDuplicate(); |
| 109 | 118 | size_t getTotalMemoryFootprint(); |
| 110 | 119 | }; | ... | ... |
session.cpp
| ... | ... | @@ -64,11 +64,13 @@ void Session::writePacket(const MqttPacket &packet, char max_qos) |
| 64 | 64 | c->writeMqttPacketAndBlameThisClient(packet, qos); |
| 65 | 65 | } |
| 66 | 66 | } |
| 67 | - else if (qos == 1) | |
| 67 | + else if (qos > 0) | |
| 68 | 68 | { |
| 69 | 69 | std::shared_ptr<MqttPacket> copyPacket = packet.getCopy(); |
| 70 | 70 | std::unique_lock<std::mutex> locker(qosQueueMutex); |
| 71 | - if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | |
| 71 | + | |
| 72 | + const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); | |
| 73 | + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | |
| 72 | 74 | { |
| 73 | 75 | logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); |
| 74 | 76 | return; |
| ... | ... | @@ -145,6 +147,13 @@ void Session::sendPendingQosMessages() |
| 145 | 147 | c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos()); |
| 146 | 148 | qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate. |
| 147 | 149 | } |
| 150 | + | |
| 151 | + for (const uint16_t packet_id : outgoingQoS2MessageIds) | |
| 152 | + { | |
| 153 | + PubRel pubRel(packet_id); | |
| 154 | + MqttPacket packet(pubRel); | |
| 155 | + c->writeMqttPacketAndBlameThisClient(packet, 2); | |
| 156 | + } | |
| 148 | 157 | } |
| 149 | 158 | } |
| 150 | 159 | |
| ... | ... | @@ -158,3 +167,41 @@ bool Session::hasExpired() |
| 158 | 167 | { |
| 159 | 168 | return clientDisconnected() && (lastTouched + EXPIRE_SESSION_AFTER) < time(NULL); |
| 160 | 169 | } |
| 170 | + | |
| 171 | +void Session::addIncomingQoS2MessageId(uint16_t packet_id) | |
| 172 | +{ | |
| 173 | + incomingQoS2MessageIds.insert(packet_id); | |
| 174 | +} | |
| 175 | + | |
| 176 | +bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id) const | |
| 177 | +{ | |
| 178 | + const auto it = incomingQoS2MessageIds.find(packet_id); | |
| 179 | + return it != incomingQoS2MessageIds.end(); | |
| 180 | +} | |
| 181 | + | |
| 182 | +void Session::removeIncomingQoS2MessageId(u_int16_t packet_id) | |
| 183 | +{ | |
| 184 | +#ifndef NDEBUG | |
| 185 | + logger->logf(LOG_DEBUG, "As QoS 2 receiver: publish released (PUBREL) for '%s', packet id '%d'. Left in queue: %d", client_id.c_str(), packet_id, incomingQoS2MessageIds.size()); | |
| 186 | +#endif | |
| 187 | + | |
| 188 | + const auto it = incomingQoS2MessageIds.find(packet_id); | |
| 189 | + if (it != incomingQoS2MessageIds.end()) | |
| 190 | + incomingQoS2MessageIds.erase(it); | |
| 191 | +} | |
| 192 | + | |
| 193 | +void Session::addOutgoingQoS2MessageId(uint16_t packet_id) | |
| 194 | +{ | |
| 195 | + outgoingQoS2MessageIds.insert(packet_id); | |
| 196 | +} | |
| 197 | + | |
| 198 | +void Session::removeOutgoingQoS2MessageId(u_int16_t packet_id) | |
| 199 | +{ | |
| 200 | +#ifndef NDEBUG | |
| 201 | + logger->logf(LOG_DEBUG, "As QoS 2 sender: publish complete (PUBCOMP) for '%s', packet id '%d'. Left in queue: %d", client_id.c_str(), packet_id, outgoingQoS2MessageIds.size()); | |
| 202 | +#endif | |
| 203 | + | |
| 204 | + const auto it = outgoingQoS2MessageIds.find(packet_id); | |
| 205 | + if (it != outgoingQoS2MessageIds.end()) | |
| 206 | + outgoingQoS2MessageIds.erase(it); | |
| 207 | +} | ... | ... |
session.h
| ... | ... | @@ -21,6 +21,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 21 | 21 | #include <memory> |
| 22 | 22 | #include <list> |
| 23 | 23 | #include <mutex> |
| 24 | +#include <set> | |
| 24 | 25 | |
| 25 | 26 | #include "forward_declarations.h" |
| 26 | 27 | #include "logger.h" |
| ... | ... | @@ -45,6 +46,8 @@ class Session |
| 45 | 46 | std::string client_id; |
| 46 | 47 | std::string username; |
| 47 | 48 | std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] |
| 49 | + std::set<uint16_t> incomingQoS2MessageIds; | |
| 50 | + std::set<uint16_t> outgoingQoS2MessageIds; | |
| 48 | 51 | std::mutex qosQueueMutex; |
| 49 | 52 | uint16_t nextPacketId = 0; |
| 50 | 53 | ssize_t qosQueueBytes = 0; |
| ... | ... | @@ -66,6 +69,14 @@ public: |
| 66 | 69 | void sendPendingQosMessages(); |
| 67 | 70 | void touch(time_t val = 0); |
| 68 | 71 | bool hasExpired(); |
| 72 | + | |
| 73 | + void addIncomingQoS2MessageId(uint16_t packet_id); | |
| 74 | + bool incomingQoS2MessageIdInTransit(uint16_t packet_id) const; | |
| 75 | + void removeIncomingQoS2MessageId(u_int16_t packet_id); | |
| 76 | + | |
| 77 | + void addOutgoingQoS2MessageId(uint16_t packet_id); | |
| 78 | + void removeOutgoingQoS2MessageId(u_int16_t packet_id); | |
| 79 | + | |
| 69 | 80 | }; |
| 70 | 81 | |
| 71 | 82 | #endif // SESSION_H | ... | ... |
types.cpp
| ... | ... | @@ -82,3 +82,36 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const |
| 82 | 82 | { |
| 83 | 83 | return 2; |
| 84 | 84 | } |
| 85 | + | |
| 86 | +PubRec::PubRec(uint16_t packet_id) : | |
| 87 | + packet_id(packet_id) | |
| 88 | +{ | |
| 89 | + | |
| 90 | +} | |
| 91 | + | |
| 92 | +size_t PubRec::getLengthWithoutFixedHeader() const | |
| 93 | +{ | |
| 94 | + return 2; | |
| 95 | +} | |
| 96 | + | |
| 97 | +PubComp::PubComp(uint16_t packet_id) : | |
| 98 | + packet_id(packet_id) | |
| 99 | +{ | |
| 100 | + | |
| 101 | +} | |
| 102 | + | |
| 103 | +size_t PubComp::getLengthWithoutFixedHeader() const | |
| 104 | +{ | |
| 105 | + return 2; | |
| 106 | +} | |
| 107 | + | |
| 108 | +PubRel::PubRel(uint16_t packet_id) : | |
| 109 | + packet_id(packet_id) | |
| 110 | +{ | |
| 111 | + | |
| 112 | +} | |
| 113 | + | |
| 114 | +size_t PubRel::getLengthWithoutFixedHeader() const | |
| 115 | +{ | |
| 116 | + return 2; | |
| 117 | +} | ... | ... |
types.h
| ... | ... | @@ -113,4 +113,28 @@ public: |
| 113 | 113 | size_t getLengthWithoutFixedHeader() const; |
| 114 | 114 | }; |
| 115 | 115 | |
| 116 | +class PubRec | |
| 117 | +{ | |
| 118 | +public: | |
| 119 | + PubRec(uint16_t packet_id); | |
| 120 | + uint16_t packet_id; | |
| 121 | + size_t getLengthWithoutFixedHeader() const; | |
| 122 | +}; | |
| 123 | + | |
| 124 | +class PubComp | |
| 125 | +{ | |
| 126 | +public: | |
| 127 | + PubComp(uint16_t packet_id); | |
| 128 | + uint16_t packet_id; | |
| 129 | + size_t getLengthWithoutFixedHeader() const; | |
| 130 | +}; | |
| 131 | + | |
| 132 | +class PubRel | |
| 133 | +{ | |
| 134 | +public: | |
| 135 | + PubRel(uint16_t packet_id); | |
| 136 | + uint16_t packet_id; | |
| 137 | + size_t getLengthWithoutFixedHeader() const; | |
| 138 | +}; | |
| 139 | + | |
| 116 | 140 | #endif // TYPES_H | ... | ... |