diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index f8d746b..53a2776 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \ ../globalstats.cpp \ ../derivablecounter.cpp \ ../packetdatatypes.cpp \ + ../flashmqtestclient.cpp \ mainappthread.cpp \ twoclienttestcontext.cpp @@ -102,6 +103,7 @@ HEADERS += \ ../globalstats.h \ ../derivablecounter.h \ ../packetdatatypes.h \ + ../flashmqtestclient.h \ mainappthread.h \ twoclienttestcontext.h diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 6f7ad13..5c40916 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -35,6 +35,8 @@ License along with FlashMQ. If not, see . #include "threaddata.h" #include "threadglobals.h" +#include "flashmqtestclient.h" + // Dumb Qt version gives warnings when comparing uint with number literal. template inline bool myCastCompare(const T1 &t1, const T2 &t2, const char *actual, const char *expected, @@ -56,6 +58,7 @@ class MainTests : public QObject Q_OBJECT QScopedPointer mainApp; + std::shared_ptr dummyThreadData; void testParsePacketHelper(const std::string &topic, char from_qos, bool retain); @@ -117,6 +120,8 @@ private slots: void testUnSubscribe(); + void testBasicsWithFlashMQTestClient(); + }; MainTests::MainTests() @@ -135,6 +140,12 @@ void MainTests::init() mainApp.reset(new MainAppThread()); mainApp->start(); mainApp->waitForStarted(); + + // We test functions directly that the server normally only calls from worker threads, in which thread data is available. This is kind of a dummy-fix, until + // we actually need correct thread data at those points (at this point, it's only to increase message counters). + std::shared_ptr settings = std::make_shared(); + this->dummyThreadData = std::make_shared(666, settings); + ThreadGlobals::assignThreadData(dummyThreadData.get()); } void MainTests::cleanup() @@ -1342,6 +1353,77 @@ void MainTests::testUnSubscribe() })); } +/** + * @brief MainTests::testBasicsWithFlashMQTestClient was used to develop FlashMQTestClient. + */ +void MainTests::testBasicsWithFlashMQTestClient() +{ + FlashMQTestClient client; + client.start(); + client.connectClient(ProtocolVersion::Mqtt311); + + MqttPacket &connAckPack = client.receivedPackets.front(); + QVERIFY(connAckPack.packetType == PacketType::CONNACK); + + { + client.subscribe("a/b", 1); + + MqttPacket &subAck = client.receivedPackets.front(); + SubAckData subAckData = subAck.parseSubAckData(); + + QVERIFY(subAckData.subAckCodes.size() == 1); + QVERIFY(subAckData.subAckCodes.front() == 1); + } + + { + client.subscribe("c/d", 2); + + MqttPacket &subAck = client.receivedPackets.front(); + SubAckData subAckData = subAck.parseSubAckData(); + + QVERIFY(subAckData.subAckCodes.size() == 1); + QVERIFY(subAckData.subAckCodes.front() == 2); + } + + client.clearReceivedLists(); + + FlashMQTestClient publisher; + publisher.start(); + publisher.connectClient(ProtocolVersion::Mqtt5); + + { + publisher.publish("a/b", "wave", 2); + + client.waitForMessageCount(1); + MqttPacket &p = client.receivedPublishes.front(); + + QCOMPARE(p.getPublishData().topic, "a/b"); + QCOMPARE(p.getPayloadCopy(), "wave"); + QCOMPARE(p.getPublishData().qos, 1); + QVERIFY(p.getPacketId() > 0); + QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311); + } + + client.clearReceivedLists(); + + { + publisher.publish("c/d", "asdfasdfasdf", 2); + + client.waitForMessageCount(1); + MqttPacket &p = client.receivedPublishes.back(); + + MYCASTCOMPARE(client.receivedPublishes.size(), 1); + + QCOMPARE(p.getPublishData().topic, "c/d"); + QCOMPARE(p.getPayloadCopy(), "asdfasdfasdf"); + QCOMPARE(p.getPublishData().qos, 2); + QVERIFY(p.getPacketId() > 1); // It's the same client, so it should not re-use packet id 1 + QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311); + } + + +} + int main(int argc, char *argv[]) { diff --git a/client.cpp b/client.cpp index 95b1c2b..1e6df1b 100644 --- a/client.cpp +++ b/client.cpp @@ -62,7 +62,7 @@ Client::Client(int fd, std::shared_ptr threadData, SSL *ssl, bool we Client::~Client() { // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread. - if (this->threadData.expired()) + if (this->epoll_fd == 0) return; if (disconnectReason.empty()) diff --git a/client.h b/client.h index 7de28ed..89ee81d 100644 --- a/client.h +++ b/client.h @@ -187,6 +187,10 @@ public: void setExtendedAuthenticationMethod(const std::string &authMethod); const std::string &getExtendedAuthenticationMethod() const; +#ifdef TESTING + std::function onPacketReceived; +#endif + #ifndef NDEBUG void setFakeUpgraded(); #endif diff --git a/flashmqtestclient.cpp b/flashmqtestclient.cpp new file mode 100644 index 0000000..2efc450 --- /dev/null +++ b/flashmqtestclient.cpp @@ -0,0 +1,235 @@ +#include "flashmqtestclient.h" + +#include +#include +#include "errno.h" +#include "functional" + +#include "threadloop.h" +#include "utils.h" +#include "client.h" + + +#define TEST_CLIENT_MAX_EVENTS 25 + +int FlashMQTestClient::clientCount = 0; + +FlashMQTestClient::FlashMQTestClient() : + settings(std::make_shared()), + testServerWorkerThreadData(std::make_shared(0, settings)), + dummyThreadData(std::make_shared(666, settings)) +{ + +} + +/** + * @brief FlashMQTestClient::~FlashMQTestClient properly quits the threads when exiting. + * + * This prevents accidental crashes on calling terminate(), and Qt Macro's prematurely end the method, skipping explicit waits after the tests. + */ +FlashMQTestClient::~FlashMQTestClient() +{ + waitForQuit(); +} + +void FlashMQTestClient::waitForCondition(std::function f) +{ + int n = 0; + while(n++ < 100) + { + usleep(10000); + + std::lock_guard locker(receivedListMutex); + + if (f()) + break; + } + + std::lock_guard locker(receivedListMutex); + + if (!f()) + { + throw std::runtime_error("Wait condition failed."); + } +} + +void FlashMQTestClient::clearReceivedLists() +{ + std::lock_guard locker(receivedListMutex); + receivedPackets.clear(); + receivedPublishes.clear(); +} + +void FlashMQTestClient::start() +{ + testServerWorkerThreadData->start(&do_thread_work); +} + +void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion) +{ + int sockfd = check(socket(AF_INET, SOCK_STREAM, 0)); + + struct sockaddr_in servaddr; + bzero(&servaddr, sizeof(servaddr)); + + const std::string hostname = "127.0.0.1"; + + servaddr.sin_family = AF_INET; + servaddr.sin_addr.s_addr = inet_addr(hostname.c_str()); + servaddr.sin_port = htons(1883); + + int flags = fcntl(sockfd, F_GETFL); + fcntl(sockfd, F_SETFL, flags | O_NONBLOCK); + + int rc = connect(sockfd, reinterpret_cast(&servaddr), sizeof (servaddr)); + + if (rc < 0 && errno != EINPROGRESS) + { + throw std::runtime_error(strerror(errno)); + } + + const std::string clientid = formatString("testclient_%d", clientCount++); + + this->client = std::make_shared(sockfd, testServerWorkerThreadData, nullptr, false, reinterpret_cast(&servaddr), settings.get()); + this->client->setClientProperties(protocolVersion, clientid, "user", false, 60); + + testServerWorkerThreadData->giveClient(this->client); + + // This gets called in the test client's worker thread, but the STL container's minimal thread safety should be enough: only list manipulation is + // mutexed, elements within are not. + client->onPacketReceived = [&](MqttPacket &pack) + { + std::lock_guard locker(receivedListMutex); + + if (pack.packetType == PacketType::PUBLISH) + { + pack.parsePublishData(); + + MqttPacket copyPacket = pack; + this->receivedPublishes.push_back(copyPacket); + + if (pack.getPublishData().qos > 0) + { + PubResponse pubAck(this->client->getProtocolVersion(), PacketType::PUBACK, ReasonCodes::Success, pack.getPacketId()); + this->client->writeMqttPacketAndBlameThisClient(pubAck); + } + } + else if (pack.packetType == PacketType::PUBREL) + { + pack.parsePubRelData(); + PubResponse pubComp(this->client->getProtocolVersion(), PacketType::PUBCOMP, ReasonCodes::Success, pack.getPacketId()); + this->client->writeMqttPacketAndBlameThisClient(pubComp); + } + else if (pack.packetType == PacketType::PUBREC) + { + pack.parsePubRecData(); + PubResponse pubRel(this->client->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, pack.getPacketId()); + this->client->writeMqttPacketAndBlameThisClient(pubRel); + } + + this->receivedPackets.push_back(std::move(pack)); + }; + + Connect connect(protocolVersion, client->getClientId()); + MqttPacket connectPack(connect); + this->client->writeMqttPacketAndBlameThisClient(connectPack); + + waitForConnack(); +} + +void FlashMQTestClient::subscribe(const std::string topic, char qos) +{ + clearReceivedLists(); + + const uint16_t packet_id = 66; + + Subscribe sub(client->getProtocolVersion(), packet_id, topic, qos); + MqttPacket subPack(sub); + client->writeMqttPacketAndBlameThisClient(subPack); + + waitForCondition([&]() { + return !this->receivedPackets.empty() && this->receivedPackets.front().packetType == PacketType::SUBACK; + }); + + MqttPacket &subAck = this->receivedPackets.front(); + SubAckData data = subAck.parseSubAckData(); + + if (data.packet_id != packet_id) + throw std::runtime_error("Incorrect packet id in suback"); + + if (!std::all_of(data.subAckCodes.begin(), data.subAckCodes.end(), [&](uint8_t x) { return x <= qos ;})) + { + throw std::runtime_error("Suback indicates error."); + } +} + +void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos) +{ + clearReceivedLists(); + + const uint16_t packet_id = 77; + + Publish pub(topic, payload, qos); + MqttPacket pubPack(client->getProtocolVersion(), pub); + if (qos > 0) + pubPack.setPacketId(packet_id); + client->writeMqttPacketAndBlameThisClient(pubPack); + + if (qos == 1) + { + waitForCondition([&]() { + return this->receivedPackets.size() == 1; + }); + + MqttPacket &pubAckPack = this->receivedPackets.front(); + pubAckPack.parsePubAckData(); + + if (pubAckPack.packetType != PacketType::PUBACK) + throw std::runtime_error("First packet received from server is not a PUBACK."); + + if (pubAckPack.getPacketId() != packet_id || this->receivedPackets.size() != 1) + throw std::runtime_error("Packet ID mismatch on QoS 1 publish or packet count wrong."); + } + else if (qos == 2) + { + waitForCondition([&]() { + return this->receivedPackets.size() >= 2; + }); + + MqttPacket &pubRecPack = this->receivedPackets.front(); + pubRecPack.parsePubRecData(); + MqttPacket &pubCompPack = this->receivedPackets.back(); + pubCompPack.parsePubComp(); + + if (pubRecPack.packetType != PacketType::PUBREC) + throw std::runtime_error("First packet received from server is not a PUBREC."); + + if (pubCompPack.packetType != PacketType::PUBCOMP) + throw std::runtime_error("Last packet received from server is not a PUBCOMP."); + + if (pubRecPack.getPacketId() != packet_id || pubCompPack.getPacketId() != packet_id) + throw std::runtime_error("Packet ID mismatch on QoS 2 publish."); + } +} + +void FlashMQTestClient::waitForQuit() +{ + testServerWorkerThreadData->queueQuit(); + testServerWorkerThreadData->waitForQuit(); +} + +void FlashMQTestClient::waitForConnack() +{ + waitForCondition([&]() { + return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) { + return p.packetType == PacketType::CONNACK; + }); + }); +} + +void FlashMQTestClient::waitForMessageCount(const size_t count) +{ + waitForCondition([&]() { + return this->receivedPublishes.size() >= count; + }); +} diff --git a/flashmqtestclient.h b/flashmqtestclient.h new file mode 100644 index 0000000..8fe4f1c --- /dev/null +++ b/flashmqtestclient.h @@ -0,0 +1,45 @@ +#ifndef FLASHMQTESTCLIENT_H +#define FLASHMQTESTCLIENT_H + +#include +#include + +#include "subscriptionstore.h" + +/** + * @brief The FlashMQTestClient class uses the existing server code as a client, for testing purposes. + */ +class FlashMQTestClient +{ + std::shared_ptr settings; + std::shared_ptr testServerWorkerThreadData; + std::shared_ptr client; + + std::shared_ptr dummyThreadData; + + std::mutex receivedListMutex; + + static int clientCount; + + void waitForCondition(std::function f); + + +public: + std::list receivedPackets; + std::list receivedPublishes; + + FlashMQTestClient(); + ~FlashMQTestClient(); + + void start(); + void connectClient(ProtocolVersion protocolVersion); + void subscribe(const std::string topic, char qos); + void publish(const std::string &topic, const std::string &payload, char qos); + void clearReceivedLists(); + + void waitForQuit(); + void waitForConnack(); + void waitForMessageCount(const size_t count); +}; + +#endif // FLASHMQTESTCLIENT_H diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 30b0f8d..84f8b7d 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -229,6 +229,60 @@ MqttPacket::MqttPacket(const Auth &auth) : calculateRemainingLength(); } +MqttPacket::MqttPacket(const Connect &connect) : + bites(connect.getLengthWithoutFixedHeader()), + protocolVersion(connect.protocolVersion), + packetType(PacketType::CONNECT) +{ +#ifndef TESTING + throw NotImplementedException("Code is only for testing."); +#endif + + first_byte = static_cast(packetType) << 4; + + const std::string magicString = connect.getMagicString(); + writeString(magicString); + + writeByte(static_cast(protocolVersion)); + writeByte(2); // flags; The only bit set is 'clean session'. + + // Keep-alive + writeUint16(60); + + if (connect.protocolVersion >= ProtocolVersion::Mqtt5) + { + writeProperties(connect.propertyBuilder); + } + + writeString(connect.clientid); + + calculateRemainingLength(); +} + +MqttPacket::MqttPacket(const Subscribe &subscribe) : + bites(subscribe.getLengthWithoutFixedHeader()), + packetType(PacketType::SUBSCRIBE) +{ +#ifndef TESTING + throw NotImplementedException("Code is only for testing."); +#endif + + first_byte = static_cast(packetType) << 4; + first_byte |= 2; // required reserved bit + + writeUint16(subscribe.packetId); + + if (subscribe.protocolVersion >= ProtocolVersion::Mqtt5) + { + writeProperties(subscribe.propertyBuilder); + } + + writeString(subscribe.topic); + writeByte(subscribe.qos); + + calculateRemainingLength(); +} + void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender) { while (buf.usedBytes() >= MQTT_HEADER_LENGH) @@ -1035,6 +1089,8 @@ void MqttPacket::handleUnsubscribe() void MqttPacket::parsePublishData() { + assert(externallyReceived); + setPosToDataStart(); publishData.retain = (first_byte & 0b00000001); @@ -1214,6 +1270,7 @@ void MqttPacket::handlePublish() void MqttPacket::parsePubAckData() { setPosToDataStart(); + this->publishData.qos = 1; this->packet_id = readTwoBytesToUInt16(); } @@ -1554,6 +1611,12 @@ void MqttPacket::writeVariableByteInt(const VariableByteInt &v) writeBytes(v.data(), v.getLen()); } +void MqttPacket::writeString(const std::string &s) +{ + writeUint16(s.length()); + writeBytes(s.c_str(), s.length()); +} + uint16_t MqttPacket::readTwoBytesToUInt16() { if (pos + 2 > bites.size()) diff --git a/mqttpacket.h b/mqttpacket.h index 32c0645..97cc263 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -78,6 +78,7 @@ class MqttPacket void writeBytes(const char *b, size_t len); void writeProperties(const std::shared_ptr &properties); void writeVariableByteInt(const VariableByteInt &v); + void writeString(const std::string &s); uint16_t readTwoBytesToUInt16(); uint32_t readFourBytesToUint32(); size_t remainingAfterPos(); @@ -89,11 +90,18 @@ class MqttPacket void setPosToDataStart(); bool atEnd() const; +#ifndef TESTING + // In production, I want to be sure I don't accidentally copy packets, because it's slow. MqttPacket(const MqttPacket &other) = delete; +#endif public: +#ifdef TESTING + // In testing I need to copy packets for administrative purposes. + MqttPacket(const MqttPacket &other) = default; +#endif PacketType packetType = PacketType::Reserved; - MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr &sender); // Constructor for parsing incoming packets. + MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr &sender); // Constructor for parsing incoming packets. MqttPacket(MqttPacket &&other) = default; size_t setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publishData) const; @@ -106,6 +114,8 @@ public: MqttPacket(const PubResponse &pubAck); MqttPacket(const Disconnect &disconnect); MqttPacket(const Auth &auth); + MqttPacket(const Connect &connect); + MqttPacket(const Subscribe &subscribe); static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender); diff --git a/threadloop.cpp b/threadloop.cpp index 9e60ae5..c3c161d 100644 --- a/threadloop.cpp +++ b/threadloop.cpp @@ -110,6 +110,11 @@ void do_thread_work(ThreadData *threadData) for (MqttPacket &packet : packetQueueIn) { +#ifdef TESTING + if (client->onPacketReceived) + client->onPacketReceived(packet); + else +#endif packet.handle(); } diff --git a/types.cpp b/types.cpp index a7a56f7..86f23b6 100644 --- a/types.cpp +++ b/types.cpp @@ -345,9 +345,57 @@ size_t Auth::getLengthWithoutFixedHeader() const return result; } +Connect::Connect(ProtocolVersion protocolVersion, const std::string &clientid) : + protocolVersion(protocolVersion), + clientid(clientid) +{ + +} + +size_t Connect::getLengthWithoutFixedHeader() const +{ + size_t result = clientid.length() + 2; + + result += this->protocolVersion <= ProtocolVersion::Mqtt31 ? 6 : 4; + result += 6; // header stuff, lengths, keep-alive + + if (this->protocolVersion >= ProtocolVersion::Mqtt5) + { + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; + result += proplen; + } + return result; + +} +std::string Connect::getMagicString() const +{ + if (protocolVersion <= ProtocolVersion::Mqtt31) + return "MQIsdp"; + else + return "MQTT"; +} +Subscribe::Subscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic, char qos) : + protocolVersion(protocolVersion), + packetId(packetId), + topic(topic), + qos(qos) +{ +} +size_t Subscribe::getLengthWithoutFixedHeader() const +{ + size_t result = topic.size() + 2; + result += 2; // packet id + result += 1; // requested QoS + if (this->protocolVersion >= ProtocolVersion::Mqtt5) + { + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; + result += proplen; + } + return result; +} diff --git a/types.h b/types.h index 193dfc4..2fc79ce 100644 --- a/types.h +++ b/types.h @@ -277,4 +277,34 @@ public: size_t getLengthWithoutFixedHeader() const; }; +struct Connect +{ + const ProtocolVersion protocolVersion; + std::string clientid; + std::string username; + std::string password; + std::shared_ptr propertyBuilder; + + Connect(ProtocolVersion protocolVersion, const std::string &clientid); + size_t getLengthWithoutFixedHeader() const; + std::string getMagicString() const; +}; + +/** + * @brief The Subscribe struct can be used to construct a mqtt packet of type 'subscribe'. + * + * It's rudimentary. Offically you can subscribe to multiple topics at once, but I have no need for that. + */ +struct Subscribe +{ + const ProtocolVersion protocolVersion; + uint16_t packetId; + std::string topic; + char qos; + std::shared_ptr propertyBuilder; + + Subscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic, char qos); + size_t getLengthWithoutFixedHeader() const; +}; + #endif // TYPES_H