diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 9a23bf4..682e440 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -1420,53 +1420,57 @@ void MainTests::testNotMessingUpQosLevels() void MainTests::testUnSubscribe() { - TwoClientTestContext testContext; + FlashMQTestClient sender; + FlashMQTestClient receiver; + + sender.start(); + sender.connectClient(ProtocolVersion::Mqtt311); - testContext.connectSender(); - testContext.connectReceiver(); + receiver.start(); + receiver.connectClient(ProtocolVersion::Mqtt311); - testContext.subscribeReceiver("Rebecca/Bunch", 2); - testContext.subscribeReceiver("Josh/Chan", 1); - testContext.subscribeReceiver("White/Josh", 1); + receiver.subscribe("Rebecca/Bunch", 2); + receiver.subscribe("Josh/Chan", 1); + receiver.subscribe("White/Josh", 1); - testContext.publish("Rebecca/Bunch", "Bunch here", 2); - testContext.publish("White/Josh", "Anteater", 2); - testContext.publish("Josh/Chan", "Human flip-flop", 2); + sender.publish("Rebecca/Bunch", "Bunch here", 2); + sender.publish("White/Josh", "Anteater", 2); + sender.publish("Josh/Chan", "Human flip-flop", 2); - testContext.waitReceiverReceived(3); + receiver.waitForMessageCount(3); - QVERIFY(std::any_of(testContext.receivedMessages.begin(), testContext.receivedMessages.end(), [](const QMQTT::Message &msg) { - return msg.payload() == "Bunch here" && msg.topic() == "Rebecca/Bunch"; + QVERIFY(std::any_of(receiver.receivedPublishes.begin(), receiver.receivedPublishes.end(), [](const MqttPacket &pack) { + return pack.getPayloadCopy() == "Bunch here" && pack.getTopic() == "Rebecca/Bunch"; })); - QVERIFY(std::any_of(testContext.receivedMessages.begin(), testContext.receivedMessages.end(), [](const QMQTT::Message &msg) { - return msg.payload() == "Anteater" && msg.topic() == "White/Josh"; + QVERIFY(std::any_of(receiver.receivedPublishes.begin(), receiver.receivedPublishes.end(), [](const MqttPacket &pack) { + return pack.getPayloadCopy() == "Anteater" && pack.getTopic() == "White/Josh"; })); - QVERIFY(std::any_of(testContext.receivedMessages.begin(), testContext.receivedMessages.end(), [](const QMQTT::Message &msg) { - return msg.payload() == "Human flip-flop" && msg.topic() == "Josh/Chan"; + QVERIFY(std::any_of(receiver.receivedPublishes.begin(), receiver.receivedPublishes.end(), [](const MqttPacket &pack) { + return pack.getPayloadCopy() == "Human flip-flop" && pack.getTopic() == "Josh/Chan"; })); - QCOMPARE(testContext.receivedMessages.count(), 3); + MYCASTCOMPARE(receiver.receivedPublishes.size(), 3); - testContext.receivedMessages.clear(); + receiver.clearReceivedLists(); - testContext.unsubscribeReceiver("Josh/Chan"); + receiver.unsubscribe("Josh/Chan"); - testContext.publish("Rebecca/Bunch", "Bunch here", 2); - testContext.publish("White/Josh", "Anteater", 2); - testContext.publish("Josh/Chan", "Human flip-flop", 2); + sender.publish("Rebecca/Bunch", "Bunch here", 2); + sender.publish("White/Josh", "Anteater", 2); + sender.publish("Josh/Chan", "Human flip-flop", 2); - testContext.waitReceiverReceived(2); + receiver.waitForMessageCount(2); - QCOMPARE(testContext.receivedMessages.count(), 2); + MYCASTCOMPARE(receiver.receivedPublishes.size(), 2); - QVERIFY(std::any_of(testContext.receivedMessages.begin(), testContext.receivedMessages.end(), [](const QMQTT::Message &msg) { - return msg.payload() == "Bunch here" && msg.topic() == "Rebecca/Bunch"; + QVERIFY(std::any_of(receiver.receivedPublishes.begin(), receiver.receivedPublishes.end(), [](const MqttPacket &pack) { + return pack.getPayloadCopy() == "Bunch here" && pack.getTopic() == "Rebecca/Bunch"; })); - QVERIFY(std::any_of(testContext.receivedMessages.begin(), testContext.receivedMessages.end(), [](const QMQTT::Message &msg) { - return msg.payload() == "Anteater" && msg.topic() == "White/Josh"; + QVERIFY(std::any_of(receiver.receivedPublishes.begin(), receiver.receivedPublishes.end(), [](const MqttPacket &pack) { + return pack.getPayloadCopy() == "Anteater" && pack.getTopic() == "White/Josh"; })); } diff --git a/flashmqtestclient.cpp b/flashmqtestclient.cpp index 9d6a32f..dafca3d 100644 --- a/flashmqtestclient.cpp +++ b/flashmqtestclient.cpp @@ -195,6 +195,23 @@ void FlashMQTestClient::subscribe(const std::string topic, char qos) } } +void FlashMQTestClient::unsubscribe(const std::string &topic) +{ + clearReceivedLists(); + + const uint16_t packet_id = 66; + + Unsubscribe unsub(client->getProtocolVersion(), packet_id, topic); + MqttPacket unsubPack(unsub); + client->writeMqttPacketAndBlameThisClient(unsubPack); + + waitForCondition([&]() { + return !this->receivedPackets.empty() && this->receivedPackets.front().packetType == PacketType::UNSUBACK; + }); + + // TODO: parse the UNSUBACK and check reason codes. +} + void FlashMQTestClient::publish(Publish &pub) { clearReceivedLists(); diff --git a/flashmqtestclient.h b/flashmqtestclient.h index 42778b0..e8804fd 100644 --- a/flashmqtestclient.h +++ b/flashmqtestclient.h @@ -37,6 +37,7 @@ public: void connectClient(ProtocolVersion protocolVersion, bool clean_start, uint32_t session_expiry_interval); void connectClient(ProtocolVersion protocolVersion, bool clean_start, uint32_t session_expiry_interval, std::function manipulateConnect); void subscribe(const std::string topic, char qos); + void unsubscribe(const std::string &topic); void publish(const std::string &topic, const std::string &payload, char qos); void publish(Publish &pub); void clearReceivedLists(); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 0b569c9..3bfe720 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -93,7 +93,7 @@ MqttPacket::MqttPacket(const SubAck &subAck) : MqttPacket::MqttPacket(const UnsubAck &unsubAck) : bites(unsubAck.getLengthWithoutFixedHeader()) { - packetType = PacketType::SUBACK; + packetType = PacketType::UNSUBACK; first_byte = static_cast(packetType) << 4; writeUint16(unsubAck.packet_id); @@ -311,6 +311,29 @@ MqttPacket::MqttPacket(const Subscribe &subscribe) : calculateRemainingLength(); } +MqttPacket::MqttPacket(const Unsubscribe &unsubscribe) : + bites(unsubscribe.getLengthWithoutFixedHeader()), + packetType(PacketType::UNSUBSCRIBE) +{ +#ifndef TESTING + throw NotImplementedException("Code is only for testing."); +#endif + + first_byte = static_cast(packetType) << 4; + first_byte |= 2; // required reserved bit + + writeUint16(unsubscribe.packetId); + + if (unsubscribe.protocolVersion >= ProtocolVersion::Mqtt5) + { + writeProperties(unsubscribe.propertyBuilder); + } + + writeString(unsubscribe.topic); + + calculateRemainingLength(); +} + void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender) { while (buf.usedBytes() >= MQTT_HEADER_LENGH) diff --git a/mqttpacket.h b/mqttpacket.h index 6c67509..afe9c01 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -116,6 +116,7 @@ public: MqttPacket(const Auth &auth); MqttPacket(const Connect &connect); MqttPacket(const Subscribe &subscribe); + MqttPacket(const Unsubscribe &unsubscribe); static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender); diff --git a/types.cpp b/types.cpp index b46cf52..3e64cf4 100644 --- a/types.cpp +++ b/types.cpp @@ -439,3 +439,25 @@ size_t Subscribe::getLengthWithoutFixedHeader() const return result; } + +Unsubscribe::Unsubscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic) : + protocolVersion(protocolVersion), + packetId(packetId), + topic(topic) +{ + +} + +size_t Unsubscribe::getLengthWithoutFixedHeader() const +{ + size_t result = topic.size() + 2; + result += 2; // packet id + + if (this->protocolVersion >= ProtocolVersion::Mqtt5) + { + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; + result += proplen; + } + + return result; +} diff --git a/types.h b/types.h index ed12918..0dc0012 100644 --- a/types.h +++ b/types.h @@ -315,4 +315,20 @@ struct Subscribe size_t getLengthWithoutFixedHeader() const; }; +/** + * @brief The Unsubscribe struct can be used to construct a mqtt packet of type 'unsubscribe'. + * + * It's rudimentary. Offically you can unsubscribe to multiple topics at once, but I have no need for that. + */ +struct Unsubscribe +{ + const ProtocolVersion protocolVersion; + uint16_t packetId; + std::string topic; + std::shared_ptr propertyBuilder; + + Unsubscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic); + size_t getLengthWithoutFixedHeader() const; +}; + #endif // TYPES_H