diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 95978cc..fea922b 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -133,6 +133,8 @@ private slots: void testReceivingRetainedMessageWithQoS(); + void testQosDowngradeOnOfflineClients(); + }; MainTests::MainTests() @@ -1750,6 +1752,66 @@ void MainTests::testReceivingRetainedMessageWithQoS() MYCASTCOMPARE(9, testCount); } +void MainTests::testQosDowngradeOnOfflineClients() +{ + int testCount = 0; + + std::vector subscribePaths {"topic1/FOOBAR", "+/+", "#"}; + + for (char sendQos = 1; sendQos < 3; sendQos++) + { + for (char subscribeQos = 1; subscribeQos < 3; subscribeQos++) + { + for (const std::string &subscribePath : subscribePaths) + { + testCount++; + + // First start with clean_start to reset the session. + std::unique_ptr receiver = std::make_unique(); + receiver->start(); + receiver->connectClient(ProtocolVersion::Mqtt5, true, 600, [](Connect &connect) { + connect.clientid = "TheReceiver"; + }); + receiver->subscribe(subscribePath, subscribeQos); + receiver->disconnect(ReasonCodes::Success); + receiver.reset(); + + const std::string payload = "We are testing"; + + FlashMQTestClient sender; + sender.start(); + sender.connectClient(ProtocolVersion::Mqtt311); + + Publish p1("topic1/FOOBAR", payload, sendQos); + + for (int i = 0; i < 10; i++) + { + sender.publish(p1); + } + + // Now we connect again, and we should now pick up the existing session. + receiver = std::make_unique(); + receiver->start(); + receiver->connectClient(ProtocolVersion::Mqtt5, false, 600, [](Connect &connect) { + connect.clientid = "TheReceiver"; + }); + + receiver->waitForMessageCount(10); + + const char expQos = std::min(sendQos, subscribeQos); + + MYCASTCOMPARE(receiver->receivedPublishes.size(), 10); + + QVERIFY(std::all_of(receiver->receivedPublishes.begin(), receiver->receivedPublishes.end(), [&](MqttPacket &pack) { return pack.getQos() == expQos;})); + QVERIFY(std::all_of(receiver->receivedPublishes.begin(), receiver->receivedPublishes.end(), [&](MqttPacket &pack) { return pack.getTopic() == "topic1/FOOBAR";})); + QVERIFY(std::all_of(receiver->receivedPublishes.begin(), receiver->receivedPublishes.end(), [&](MqttPacket &pack) { return pack.getPayloadCopy() == payload;})); + } + } + } + + MYCASTCOMPARE(12, testCount); +} + int main(int argc, char *argv[]) { QCoreApplication app(argc, argv); diff --git a/publishcopyfactory.cpp b/publishcopyfactory.cpp index 41b61ee..5e274ff 100644 --- a/publishcopyfactory.cpp +++ b/publishcopyfactory.cpp @@ -95,22 +95,36 @@ bool PublishCopyFactory::getRetain() const return publish->retain; } -Publish PublishCopyFactory::getNewPublish() const +/** + * @brief PublishCopyFactory::getNewPublish gets a new publish object from an existing packet or publish. + * @param new_max_qos + * @return + * + * It being a public function, the idea is that it's only needed for creating publish objects for storing QoS messages for off-line + * clients. For on-line clients, you're always making a packet (with getOptimumPacket()). + */ +Publish PublishCopyFactory::getNewPublish(char new_max_qos) const { + // (At time of writing) we only need to construct new publishes for QoS (because we're storing QoS publishes for offline clients). If + // you're doing it elsewhere, it's a bug. + assert(orgQos > 0); + assert(new_max_qos > 0); + + const char actualQos = getEffectiveQos(new_max_qos); + if (packet) { assert(packet->getQos() > 0); - assert(orgQos > 0); // We only need to construct new publishes for QoS. If you're doing it elsewhere, it's a bug. Publish p(packet->getPublishData()); - p.qos = orgQos; + p.qos = actualQos; return p; } assert(publish->qos > 0); // Same check as above, but then for Publish objects. Publish p(*publish); - p.qos = orgQos; + p.qos = actualQos; return p; } diff --git a/publishcopyfactory.h b/publishcopyfactory.h index 38734e6..d7054e4 100644 --- a/publishcopyfactory.h +++ b/publishcopyfactory.h @@ -34,7 +34,7 @@ public: const std::string &getTopic() const; const std::vector &getSubtopics(); bool getRetain() const; - Publish getNewPublish() const; + Publish getNewPublish(char new_max_qos) const; std::shared_ptr getSender(); const std::vector> *getUserProperties() const; diff --git a/qospacketqueue.cpp b/qospacketqueue.cpp index 84fb967..0f07d53 100644 --- a/qospacketqueue.cpp +++ b/qospacketqueue.cpp @@ -76,7 +76,7 @@ void QoSPublishQueue::queuePublish(PublishCopyFactory ©Factory, uint16_t id, assert(new_max_qos > 0); assert(id > 0); - Publish pub = copyFactory.getNewPublish(); + Publish pub = copyFactory.getNewPublish(new_max_qos); queue.emplace_back(std::move(pub), id); qosQueueBytes += queue.back().getApproximateMemoryFootprint(); } diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 1892b02..69bd20d 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -253,7 +253,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr } } - if (!session || session->getDestroyOnDisconnect()) + if (!session || session->getDestroyOnDisconnect() || clean_start) { session = std::make_shared();