diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 820b5f5..4d7cfd1 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -102,6 +102,15 @@ private slots: void testSavingSessions(); + void testCopyPacket(); + + void testDowngradeQoSOnSubscribeQos2to2(); + void testDowngradeQoSOnSubscribeQos2to1(); + void testDowngradeQoSOnSubscribeQos2to0(); + void testDowngradeQoSOnSubscribeQos1to1(); + void testDowngradeQoSOnSubscribeQos1to0(); + void testDowngradeQoSOnSubscribeQos0to0(); + }; MainTests::MainTests() @@ -1027,7 +1036,9 @@ void MainTests::testSavingSessions() std::shared_ptr c1ses = c1->getSession(); c1.reset(); - c1ses->writePacket(publish, 1, false, count); + MqttPacket publishPacket(publish); + std::shared_ptr possibleQos0Copy; + c1ses->writePacket(publishPacket, 1, possibleQos0Copy, count); store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); @@ -1091,6 +1102,145 @@ void MainTests::testSavingSessions() } } +void testCopyPacketHelper(const std::string &topic, char from_qos, char to_qos, bool retain) +{ + assert(to_qos <= from_qos); + + Logger::getInstance()->setFlags(false, false, true); + + std::shared_ptr settings(new Settings()); + settings->logDebug = false; + std::shared_ptr store(new SubscriptionStore()); + std::shared_ptr t(new ThreadData(0, store, settings)); + + // Kind of a hack... + Authentication auth(*settings.get()); + ThreadAuth::assign(&auth); + + std::shared_ptr dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false)); + dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60, false); + store->registerClientAndKickExistingOne(dummyClient); + + uint16_t packetid = 66; + for (int len = 0; len < 150; len++ ) + { + const uint16_t pack_id = packetid++; + + std::vector parsedPackets; + + const std::string payloadOne = getSecureRandomString(len); + Publish pubOne(topic, payloadOne, from_qos); + pubOne.retain = retain; + MqttPacket stagingPacketOne(pubOne); + if (from_qos > 0) + stagingPacketOne.setPacketId(pack_id); + CirBuf stagingBufOne(1024); + stagingPacketOne.readIntoBuf(stagingBufOne); + + MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient); + QVERIFY(parsedPackets.size() == 1); + MqttPacket parsedPacketOne = std::move(parsedPackets.front()); + parsedPacketOne.handlePublish(); + if (retain) // A normal handled packet always has retain=0, so I force setting it here. + parsedPacketOne.setRetain(); + QCOMPARE(stagingPacketOne.getTopic(), parsedPacketOne.getTopic()); + QCOMPARE(stagingPacketOne.getPayloadCopy(), parsedPacketOne.getPayloadCopy()); + + std::shared_ptr copiedPacketOne = parsedPacketOne.getCopy(to_qos); + + QCOMPARE(payloadOne, copiedPacketOne->getPayloadCopy()); + + // Now compare the written buffer of our copied packet to one that was written with our known good reference packet. + + Publish pubReference(topic, payloadOne, to_qos); + pubReference.retain = retain; + MqttPacket packetReference(pubReference); + if (to_qos > 0) + packetReference.setPacketId(pack_id); + CirBuf bufOfReference(1024); + CirBuf bufOfCopied(1024); + packetReference.readIntoBuf(bufOfReference); + copiedPacketOne->readIntoBuf(bufOfCopied); + QVERIFY2(bufOfCopied == bufOfReference, formatString("Failure on length %d for topic %s, from qos %d to qos %d, retain: %d.", + len, topic.c_str(), from_qos, to_qos, retain).c_str()); + } +} + +/** + * @brief MainTests::testCopyPacket tests the actual bytes of a copied that would be written to a client. + * + * This is specifically to test the optimiziations in getCopy(). It indirectly also tests packet parsing. + */ +void MainTests::testCopyPacket() +{ + for (int retain = 0; retain < 2; retain++) + { + testCopyPacketHelper("John/McLane", 0, 0, retain); + testCopyPacketHelper("Ben/Sisko", 1, 1, retain); + testCopyPacketHelper("Rebecca/Bunch", 2, 2, retain); + + testCopyPacketHelper("Buffy/Slayer", 1, 0, retain); + testCopyPacketHelper("Sarah/Connor", 2, 0, retain); + testCopyPacketHelper("Susan/Mayer", 2, 1, retain); + } +} + +void testDowngradeQoSOnSubscribeHelper(const char pub_qos, const char sub_qos) +{ + TwoClientTestContext testContext; + + const QString topic("Star/Trek"); + const QByteArray payload("Captain Kirk"); + + testContext.connectSender(); + testContext.connectReceiver(); + + testContext.subscribeReceiver(topic, sub_qos); + testContext.publish(topic, payload, pub_qos, false); + + testContext.waitReceiverReceived(1); + + QCOMPARE(testContext.receivedMessages.length(), 1); + QMQTT::Message &recv = testContext.receivedMessages.first(); + + const char expected_qos = std::min(pub_qos, sub_qos); + QVERIFY2(recv.qos() == expected_qos, formatString("Failure: received QoS is %d. Published is %d. Subscribed as %d. Expected QoS is %d", + recv.qos(), pub_qos, sub_qos, expected_qos).c_str()); + QVERIFY(recv.topic() == topic); + QVERIFY(recv.payload() == payload); +} + +void MainTests::testDowngradeQoSOnSubscribeQos2to2() +{ + testDowngradeQoSOnSubscribeHelper(2, 2); +} + +void MainTests::testDowngradeQoSOnSubscribeQos2to1() +{ + testDowngradeQoSOnSubscribeHelper(2, 1); +} + +void MainTests::testDowngradeQoSOnSubscribeQos2to0() +{ + testDowngradeQoSOnSubscribeHelper(2, 0); +} + +void MainTests::testDowngradeQoSOnSubscribeQos1to1() +{ + testDowngradeQoSOnSubscribeHelper(1, 1); +} + +void MainTests::testDowngradeQoSOnSubscribeQos1to0() +{ + testDowngradeQoSOnSubscribeHelper(1, 0); +} + +void MainTests::testDowngradeQoSOnSubscribeQos0to0() +{ + testDowngradeQoSOnSubscribeHelper(0, 0); +} + + int main(int argc, char *argv[]) { QCoreApplication app(argc, argv); diff --git a/FlashMQTests/twoclienttestcontext.cpp b/FlashMQTests/twoclienttestcontext.cpp index 13010b0..72458da 100644 --- a/FlashMQTests/twoclienttestcontext.cpp +++ b/FlashMQTests/twoclienttestcontext.cpp @@ -35,12 +35,22 @@ TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent) connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); } +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload) +{ + publish(topic, payload, 0, false); +} + void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, bool retain) { + publish(topic, payload, 0, retain); +} + +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain) +{ QMQTT::Message msg; msg.setTopic(topic); msg.setRetain(retain); - msg.setQos(0); + msg.setQos(qos); msg.setPayload(payload); sender->publish(msg); } @@ -71,9 +81,9 @@ void TwoClientTestContext::disconnectReceiver() waiter.exec(); } -void TwoClientTestContext::subscribeReceiver(const QString &topic) +void TwoClientTestContext::subscribeReceiver(const QString &topic, const quint8 qos) { - receiver->subscribe(topic); + receiver->subscribe(topic, qos); QEventLoop waiter; QTimer timeout; diff --git a/FlashMQTests/twoclienttestcontext.h b/FlashMQTests/twoclienttestcontext.h index d1aada8..d8852d1 100644 --- a/FlashMQTests/twoclienttestcontext.h +++ b/FlashMQTests/twoclienttestcontext.h @@ -34,11 +34,13 @@ private slots: public: explicit TwoClientTestContext(QObject *parent = nullptr); - void publish(const QString &topic, const QByteArray &payload, bool retain = false); + void publish(const QString &topic, const QByteArray &payload); + void publish(const QString &topic, const QByteArray &payload, bool retain); + void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain); void connectSender(); void connectReceiver(); void disconnectReceiver(); - void subscribeReceiver(const QString &topic); + void subscribeReceiver(const QString &topic, const quint8 qos = 0); void waitReceiverReceived(int count); void onClientError(const QMQTT::ClientError error); diff --git a/cirbuf.cpp b/cirbuf.cpp index 766b505..d451c34 100644 --- a/cirbuf.cpp +++ b/cirbuf.cpp @@ -252,3 +252,22 @@ void CirBuf::read(void *buf, const size_t count) assert(_packet_len == 0); assert(i == static_cast(count)); } + +/** + * @brief CirBuf::operator == simplistic comparision. It doesn't take the fact that it's circular into account. + * @param other + * @return + * + * It was created for unit testing. read() and write() are non-const, so taking the circular properties into account + * would need more/duplicate code that I don't need at this point. + */ +bool CirBuf::operator==(const CirBuf &other) const +{ +#ifdef NDEBUG + throw std::exception(); // you can't use it in release builds, because new buffers aren't zeroed. +#endif + + return tail == 0 && other.tail == 0 + && usedBytes() == other.usedBytes() + && std::memcmp(buf, other.buf, size) == 0; +} diff --git a/cirbuf.h b/cirbuf.h index b8166c2..0297849 100644 --- a/cirbuf.h +++ b/cirbuf.h @@ -60,6 +60,8 @@ public: void write(const void *buf, size_t count); void read(void *buf, const size_t count); + + bool operator==(const CirBuf &other) const; }; #endif // CIRBUF_H diff --git a/client.cpp b/client.cpp index 65f924e..4d29e93 100644 --- a/client.cpp +++ b/client.cpp @@ -56,7 +56,7 @@ Client::~Client() { Publish will(will_topic, will_payload, will_qos); will.retain = will_retain; - const MqttPacket willPacket(will); + MqttPacket willPacket(will); const std::vector subtopics = splitToVector(will_topic, '/'); store->queuePacketAtSubscribers(subtopics, willPacket); @@ -180,7 +180,7 @@ void Client::writeText(const std::string &text) setReadyForWriting(true); } -int Client::writeMqttPacket(const MqttPacket &packet, const char qos) +int Client::writeMqttPacket(const MqttPacket &packet) { std::lock_guard locker(writeBufMutex); @@ -193,7 +193,7 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos) // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And // QoS packet are queued and limited elsewhere. - if (packet.packetType == PacketType::PUBLISH && qos == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) + if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) { return 0; } @@ -208,11 +208,11 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos) } // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected. -int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos) +int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) { try { - return this->writeMqttPacket(packet, qos); + return this->writeMqttPacket(packet); } catch (std::exception &ex) { diff --git a/client.h b/client.h index 8510233..0a2ac0d 100644 --- a/client.h +++ b/client.h @@ -121,8 +121,8 @@ public: void writeText(const std::string &text); void writePingResp(); - int writeMqttPacket(const MqttPacket &packet, const char qos = 0); - int writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos); + int writeMqttPacket(const MqttPacket &packet); + int writeMqttPacketAndBlameThisClient(const MqttPacket &packet); bool writeBufIntoFd(); bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 4c8dc6c..db95a68 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -48,11 +48,69 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add. * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. + * + * The idea is that because a packet with QoS is longer than one without, we just copy as much as possible if both packets have the same QoS. + * + * Note that there can be two types of packets: one with the fixed header (including remaining length), and one without. The latter we could be + * more clever about, but I'm forgoing that right now. Their use is mostly for retained messages. + * + * Also note that some fields are undeterminstic in the copy: dup, retain and packetid for instance. Sometimes they come from the original, + * sometimes not. The current planned usage is that those fields will either ONLY or NEVER be used in the copy, so it doesn't matter what I do + * with them here. I may reconsider. */ -std::shared_ptr MqttPacket::getCopy() const +std::shared_ptr MqttPacket::getCopy(char new_max_qos) const { + assert(packetType == PacketType::PUBLISH); + + // You're not supposed to copy a duplicate packet. The only packets that get the dup flag, should not be copied AGAIN. This + // has to do with the Session::writePacket() and Session::sendPendingQosMessages() logic. + assert((first_byte & 0b00001000) == 0); + + if (qos > 0 && new_max_qos == 0) + { + // if shrinking the packet doesn't alter the amount of bytes in the 'remaining length' part of the header, we can + // just memmove+shrink the packet. This is because the packet id always is two bytes before the payload, so we just move the payload + // over it. When testing 100M copies, it went from 21000 ms to 10000 ms. In other words, about 100 µs to 200 µs per copy. + // There is an elaborate unit test to test this optimization. + if ((fixed_header_length == 2 && bites.size() < 125)) + { + // I don't know yet if this is true, but I don't want to forget when I implemenet MQTT5. + assert(sender && sender->getProtocolVersion() <= ProtocolVersion::Mqtt311); + + std::shared_ptr p(new MqttPacket(*this)); + p->sender.reset(); + + if (payloadLen > 0) + std::memmove(&p->bites[packet_id_pos], &p->bites[packet_id_pos+2], payloadLen); + p->bites.erase(p->bites.end() - 2, p->bites.end()); + p->packet_id_pos = 0; + p->payloadStart -= 2; + if (pos > p->bites.size()) // pos can possible be set elsewhere, so we only set it back if it was after the payload. + p->pos -= 2; + p->packet_id = 0; + + // Clear QoS bits from the header. + p->first_byte &= 0b11111001; + p->bites[0] = p->first_byte; + + assert((p->bites[1] & 0b10000000) == 0); // when there is an MSB, I musn't get rid of it. + assert(p->bites[1] > 3); // There has to be a remaining value after subtracting 2. + + p->bites[1] -= 2; // Reduce the value in the 'remaining length' part of the header. + + return p; + } + + Publish pub(topic, getPayloadCopy(), new_max_qos); + pub.retain = getRetain(); + std::shared_ptr copyPacket(new MqttPacket(pub)); + return copyPacket; + } + std::shared_ptr copyPacket(new MqttPacket(*this)); copyPacket->sender.reset(); + if (qos != new_max_qos) + copyPacket->setQos(new_max_qos); return copyPacket; } @@ -773,11 +831,11 @@ size_t MqttPacket::getTotalMemoryFootprint() * @return * * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out - * the whole byte array that is a packet to subscribers. No need to copy and such. + * the whole byte array of an original packet to subscribers. No need to copy and such. * - * I created it for saving QoS packages in the db file. + * But, as stated, sometimes it's necessary. */ -std::string MqttPacket::getPayloadCopy() +std::string MqttPacket::getPayloadCopy() const { assert(payloadStart > 0); assert(pos <= bites.size()); @@ -798,6 +856,23 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const return total; } +void MqttPacket::setQos(const char new_qos) +{ + // You can't change to a QoS level that would remove the packet identifier. + assert((qos == 0 && new_qos == 0) || (qos > 0 && new_qos > 0)); + assert(new_qos > 0 && packet_id_pos > 0); + + qos = new_qos; + first_byte &= 0b11111001; + first_byte |= (qos << 1); + + if (fixed_header_length > 0) + { + pos = 0; + writeByte(first_byte); + } +} + const std::string &MqttPacket::getTopic() const { return this->topic; @@ -886,6 +961,31 @@ size_t MqttPacket::remainingAfterPos() return bites.size() - pos; } +bool MqttPacket::getRetain() const +{ + return (first_byte & 0b00000001); +} + +/** + * @brief MqttPacket::setRetain set the retain bit in the first byte. I think I only need this in tests, because existing subscribers don't get retain=1, + * so handlePublish() clears it. But I needed it to be set in testing. + * + * Publishing of the retained messages goes through the MqttPacket(Publish) constructor, hence this setRetain() isn't necessary for that. + */ +void MqttPacket::setRetain() +{ +#ifndef TESTING + assert(false); +#endif + + first_byte |= 0b00000001; + + if (fixed_header_length > 0) + { + pos = 0; + writeByte(first_byte); + } +} void MqttPacket::readIntoBuf(CirBuf &buf) const { diff --git a/mqttpacket.h b/mqttpacket.h index 3eb4fd9..e5d5e96 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -80,7 +80,7 @@ public: MqttPacket(MqttPacket &&other) = default; - std::shared_ptr getCopy() const; + std::shared_ptr getCopy(char new_max_qos) const; // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. MqttPacket(const ConnAck &connAck); @@ -109,6 +109,7 @@ public: size_t getSizeIncludingNonPresentHeader() const; const std::vector &getBites() const { return bites; } char getQos() const { return qos; } + void setQos(const char new_qos); const std::string &getTopic() const; const std::vector &getSubtopics() const; std::shared_ptr getSender() const; @@ -120,8 +121,10 @@ public: uint16_t getPacketId() const; void setDuplicate(); size_t getTotalMemoryFootprint(); - std::string getPayloadCopy(); void readIntoBuf(CirBuf &buf) const; + std::string getPayloadCopy() const; + bool getRetain() const; + void setRetain(); }; #endif // MQTTPACKET_H diff --git a/qospacketqueue.cpp b/qospacketqueue.cpp index 696bedd..ba6032c 100644 --- a/qospacketqueue.cpp +++ b/qospacketqueue.cpp @@ -44,11 +44,11 @@ size_t QoSPacketQueue::getByteSize() const * @param id * @return the packet copy. */ -std::shared_ptr QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id) +std::shared_ptr QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos) { assert(p.getQos() > 0); - std::shared_ptr copyPacket = p.getCopy(); + std::shared_ptr copyPacket = p.getCopy(new_max_qos); copyPacket->setPacketId(id); queue.push_back(copyPacket); qosQueueBytes += copyPacket->getTotalMemoryFootprint(); diff --git a/qospacketqueue.h b/qospacketqueue.h index f8b18d1..d3e86d5 100644 --- a/qospacketqueue.h +++ b/qospacketqueue.h @@ -15,7 +15,7 @@ public: void erase(const uint16_t packet_id); size_t size() const; size_t getByteSize() const; - std::shared_ptr queuePacket(const MqttPacket &p, uint16_t id); + std::shared_ptr queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos); std::shared_ptr queuePacket(const Publish &pub, uint16_t id); std::list>::const_iterator begin() const; diff --git a/session.cpp b/session.cpp index fc886a1..9aee6bb 100644 --- a/session.cpp +++ b/session.cpp @@ -59,6 +59,29 @@ void Session::setSessionTouch(int64_t ageInMs) lastTouched = point; } +bool Session::requiresPacketRetransmission() const +{ + const std::shared_ptr client = makeSharedClient(); + + if (!client) + return true; + + // MQTT 3.1: "Brokers, however, should retry any unacknowledged message." + // MQTT 3.1.1: "This [reconnecting] is the only circumstance where a Client or Server is REQUIRED to redeliver messages." + if (client->getProtocolVersion() < ProtocolVersion::Mqtt311) + return true; + + // TODO: for MQTT5, the rules are different. + return !client->getCleanSession(); +} + +void Session::increasePacketId() +{ + nextPacketId++; + if (nextPacketId == 0) + nextPacketId++; +} + /** * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies. * @param other @@ -116,63 +139,101 @@ void Session::assignActiveConnection(std::shared_ptr &client) /** * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session. - * @param packet + * @param packet is not const. We set the qos and packet id for each publish. This should be safe, because the packet + * with original packet id and qos is not saved. This saves unnecessary copying. * @param max_qos * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. * @param count. Reference value is updated. It's for statistics. */ -void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count) +void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr &downgradedQos0PacketCopy, uint64_t &count) { assert(max_qos <= 2); - const char qos = std::min(packet.getQos(), max_qos); + const char effectiveQos = std::min(packet.getQos(), max_qos); Authentication *_auth = ThreadAuth::getAuth(); assert(_auth); Authentication &auth = *_auth; - if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) + if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success) { - if (qos == 0) + std::shared_ptr c = makeSharedClient(); + if (effectiveQos == 0) { - std::shared_ptr c = makeSharedClient(); - if (c) { - count += c->writeMqttPacketAndBlameThisClient(packet, qos); + const MqttPacket *packetToSend = &packet; + + if (max_qos < packet.getQos()) + { + if (!downgradedQos0PacketCopy) + downgradedQos0PacketCopy = packet.getCopy(max_qos); + packetToSend = downgradedQos0PacketCopy.get(); + } + + count += c->writeMqttPacketAndBlameThisClient(*packetToSend); } } - else if (qos > 0) + else if (effectiveQos > 0) { - std::unique_lock locker(qosQueueMutex); + const bool requiresRetransmission = requiresPacketRetransmission(); - const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); - if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) + if (requiresRetransmission) { - if (QoSLogPrintedAtId != nextPacketId) + std::unique_lock locker(qosQueueMutex); + + const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) { - logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because its QoS buffers were full.", client_id.c_str()); - QoSLogPrintedAtId = nextPacketId; + if (QoSLogPrintedAtId != nextPacketId) + { + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because max in-transit packet count reached.", client_id.c_str()); + QoSLogPrintedAtId = nextPacketId; + } + return; } - return; - } - nextPacketId++; - if (nextPacketId == 0) - nextPacketId++; - std::shared_ptr copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); - locker.unlock(); + increasePacketId(); + std::shared_ptr copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId, effectiveQos); - std::shared_ptr c = makeSharedClient(); - if (c) + if (c) + { + count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + } + } + else { - count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos); - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + // We don't need to make a copy of the packet in this branch, because: + // - The packet to give the client won't shrink in size because source and client have a packet_id. + // - We don't have to store the copy in the session for retransmission, see Session::requiresPacketRetransmission() + // So, we just keep altering the original published packet. + + std::unique_lock locker(qosQueueMutex); + + if (qosInFlightCounter >= 65530) // Includes a small safety margin. + { + if (QoSLogPrintedAtId != nextPacketId) + { + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because it hasn't seen enough PUBACKs to release places.", client_id.c_str()); + QoSLogPrintedAtId = nextPacketId; + } + return; + } + + increasePacketId(); + + // This changes the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere, + // that should be fine. + packet.setPacketId(nextPacketId); + packet.setQos(effectiveQos); + + qosInFlightCounter++; + assert(c); // with requiresRetransmission==false, there must be a client. + c->writeMqttPacketAndBlameThisClient(packet); } } } } -// Normatively, this while loop will break on the first element, because all messages are sent out in order and -// should be acked in order. void Session::clearQosMessage(uint16_t packet_id) { #ifndef NDEBUG @@ -180,7 +241,13 @@ void Session::clearQosMessage(uint16_t packet_id) #endif std::lock_guard locker(qosQueueMutex); - qosPacketQueue.erase(packet_id); + if (requiresPacketRetransmission()) + qosPacketQueue.erase(packet_id); + else + { + qosInFlightCounter--; + qosInFlightCounter = std::max(0, qosInFlightCounter); // Should never happen, but in case we receive too many PUBACKs. + } } // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any @@ -199,7 +266,7 @@ uint64_t Session::sendPendingQosMessages() std::lock_guard locker(qosQueueMutex); for (const std::shared_ptr &qosMessage : qosPacketQueue) { - count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos()); + count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get()); qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. } @@ -207,7 +274,7 @@ uint64_t Session::sendPendingQosMessages() { PubRel pubRel(packet_id); MqttPacket packet(pubRel); - count += c->writeMqttPacketAndBlameThisClient(packet, 2); + count += c->writeMqttPacketAndBlameThisClient(packet); } } diff --git a/session.h b/session.h index 366140e..ae75728 100644 --- a/session.h +++ b/session.h @@ -48,11 +48,14 @@ class Session std::set outgoingQoS2MessageIds; std::mutex qosQueueMutex; uint16_t nextPacketId = 0; + uint16_t qosInFlightCounter = 0; uint16_t QoSLogPrintedAtId = 0; std::chrono::time_point lastTouched = std::chrono::steady_clock::now(); Logger *logger = Logger::getInstance(); int64_t getSessionRelativeAgeInMs() const; void setSessionTouch(int64_t ageInMs); + bool requiresPacketRetransmission() const; + void increasePacketId(); Session(const Session &other); public: @@ -69,7 +72,7 @@ public: const std::string &getClientId() const { return client_id; } std::shared_ptr makeSharedClient() const; void assignActiveConnection(std::shared_ptr &client); - void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); + void writePacket(MqttPacket &packet, char max_qos, std::shared_ptr &downgradedQos0PacketCopy, uint64_t &count); void clearQosMessage(uint16_t packet_id); uint64_t sendPendingQosMessages(); void touch(std::chrono::time_point val); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index a0a54ee..ffb2769 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -331,7 +331,7 @@ void SubscriptionStore::publishRecursively(std::vector::const_itera } } -void SubscriptionStore::queuePacketAtSubscribers(const std::vector &subtopics, const MqttPacket &packet, bool dollar) +void SubscriptionStore::queuePacketAtSubscribers(const std::vector &subtopics, MqttPacket &packet, bool dollar) { assert(subtopics.size() > 0); @@ -346,9 +346,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions); } + std::shared_ptr possibleQos0Copy; for(const ReceivingSubscriber &x : subscriberSessions) { - x.session->writePacket(packet, x.qos, false, count); + x.session->writePacket(packet, x.qos, possibleQos0Copy, count); } std::shared_ptr sender = packet.getSender(); @@ -422,9 +423,10 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr possibleQos0Copy; + for(MqttPacket &packet : packetList) { - ses->writePacket(packet, max_qos, true, count); + ses->writePacket(packet, max_qos, possibleQos0Copy, count); } return count; diff --git a/subscriptionstore.h b/subscriptionstore.h index 1a9041a..7923c98 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -120,7 +120,7 @@ public: void registerClientAndKickExistingOne(std::shared_ptr &client); bool sessionPresent(const std::string &clientid); - void queuePacketAtSubscribers(const std::vector &subtopics, const MqttPacket &packet, bool dollar = false); + void queuePacketAtSubscribers(const std::vector &subtopics, MqttPacket &packet, bool dollar = false); void giveClientRetainedMessagesRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, RetainedMessageNode *this_node, bool poundMode, std::forward_list &packetList) const; uint64_t giveClientRetainedMessages(const std::shared_ptr &ses, const std::vector &subscribeSubtopics, char max_qos); diff --git a/threaddata.cpp b/threaddata.cpp index 9d59331..5a366ca 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -127,7 +127,8 @@ void ThreadData::publishStat(const std::string &topic, uint64_t n) splitTopic(topic, subtopics); const std::string payload = std::to_string(n); Publish p(topic, payload, 0); - subscriptionStore->queuePacketAtSubscribers(subtopics, p, true); + MqttPacket pack(p); + subscriptionStore->queuePacketAtSubscribers(subtopics, pack, true); subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); }