diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 9a9909c..c5da5f3 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -436,27 +436,31 @@ void MainTests::test_retained() void MainTests::test_retained_changed() { - TwoClientTestContext testContext; + FlashMQTestClient sender; + sender.start(); + sender.connectClient(ProtocolVersion::Mqtt311); - QByteArray payload = "We are testing"; - QString topic = "retaintopic"; + const std::string topic = "retaintopic"; - testContext.connectSender(); - testContext.publish(topic, payload, true); + Publish p(topic, "We are testing", 0); + p.retain = true; + sender.publish(p); - payload = "Changed payload"; + p.payload = "Changed payload"; + sender.publish(p); - testContext.publish(topic, payload, true); + FlashMQTestClient receiver; + receiver.start(); + receiver.connectClient(ProtocolVersion::Mqtt5); + receiver.subscribe(topic, 0); - testContext.connectReceiver(); - testContext.subscribeReceiver(topic); - testContext.waitReceiverReceived(1); + receiver.waitForMessageCount(1); - QCOMPARE(testContext.receivedMessages.count(), 1); + MYCASTCOMPARE(receiver.receivedPublishes.size(), 1); - QMQTT::Message msg = testContext.receivedMessages.first(); - QCOMPARE(msg.payload(), payload); - QVERIFY(msg.retain()); + MqttPacket &pack = receiver.receivedPublishes.front(); + QCOMPARE(pack.getPayloadCopy(), p.payload); + QVERIFY(pack.getRetain()); } void MainTests::test_retained_removed() diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 90cd893..25a5830 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -1011,6 +1011,8 @@ void MqttPacket::handleSubscribe() Authentication &authentication = *ThreadGlobals::getAuth(); + std::forward_list deferredSubscribes; + std::list subs_reponse_codes; while (remainingAfterPos() > 0) { @@ -1031,8 +1033,7 @@ void MqttPacket::handleSubscribe() splitTopic(topic, subtopics); if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, subtopics, AclAccess::subscribe, qos, false, getUserProperties()) == AuthResult::success) { - logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s' QoS %d", sender->repr().c_str(), topic.c_str(), qos); - MainApp::getMainApp()->getSubscriptionStore()->addSubscription(sender, topic, subtopics, qos); + deferredSubscribes.emplace_front(topic, subtopics, qos); subs_reponse_codes.push_back(static_cast(qos)); } else @@ -1054,6 +1055,13 @@ void MqttPacket::handleSubscribe() SubAck subAck(this->protocolVersion, packet_id, subs_reponse_codes); MqttPacket response(subAck); sender->writeMqttPacket(response); + + // Adding the subscription will also send publishes for retained messages, so that's why we're doing it at the end. + for(const SubscriptionTuple &tup : deferredSubscribes) + { + logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s' QoS %d", sender->repr().c_str(), tup.topic.c_str(), tup.qos); + MainApp::getMainApp()->getSubscriptionStore()->addSubscription(sender, tup.topic, tup.subtopics, tup.qos); + } } void MqttPacket::handleUnsubscribe() @@ -1804,6 +1812,14 @@ void MqttPacket::readIntoBuf(CirBuf &buf) const buf.write(bites.data(), bites.size()); } +SubscriptionTuple::SubscriptionTuple(const std::string &topic, const std::vector &subtopics, char qos) : + topic(topic), + subtopics(subtopics), + qos(qos) +{ + +} + diff --git a/mqttpacket.h b/mqttpacket.h index 97cc263..5984158 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -163,4 +163,13 @@ public: const std::vector> *getUserProperties() const; }; +struct SubscriptionTuple +{ + const std::string topic; + const std::vector subtopics; + const char qos; + + SubscriptionTuple(const std::string &topic, const std::vector &subtopics, char qos); +}; + #endif // MQTTPACKET_H