#include "flashmqtestclient.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>

#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include "errno.h"
#include "functional"

#include "threadloop.h"
#include "utils.h"
#include "client.h"

#define TEST_CLIENT_MAX_EVENTS 25

int FlashMQTestClient::clientCount = 0;

FlashMQTestClient::FlashMQTestClient() :
    settings(std::make_shared<Settings>()),
    testServerWorkerThreadData(std::make_shared<ThreadData>(0, settings)),
    dummyThreadData(std::make_shared<ThreadData>(666, settings))
{

}

/**
 * @brief FlashMQTestClient::~FlashMQTestClient properly quits the threads when exiting.
 *
 * This prevents accidental crashes from terminate() being called, and covers the case where Qt macros end a test
 * method prematurely, skipping the explicit waits at the end of the test.
 */
FlashMQTestClient::~FlashMQTestClient()
{
    waitForQuit();
}

/**
 * @brief Polls until a condition over the received-packet lists holds.
 * @param f Predicate; always evaluated while holding receivedListMutex.
 * @param timeout Maximum wait in seconds, polled in 10 ms steps.
 * @throws std::runtime_error when the condition still does not hold after the timeout.
 */
void FlashMQTestClient::waitForCondition(std::function<bool()> f, int timeout)
{
    const int loopCount = (timeout * 1000) / 10;

    for (int n = 0; n < loopCount; n++)
    {
        usleep(10000);

        std::lock_guard locker(receivedListMutex);
        if (f())
            return;
    }

    // One last check after the final sleep, so a condition that became true at the very end still counts.
    std::lock_guard locker(receivedListMutex);
    if (!f())
    {
        throw std::runtime_error("Wait condition failed.");
    }
}

/**
 * @brief Empties both receivedPackets and receivedPublishes, under receivedListMutex.
 */
void FlashMQTestClient::clearReceivedLists()
{
    std::lock_guard locker(receivedListMutex);
    receivedPackets.clear();
    receivedPublishes.clear();
}

/**
 * @brief Sets the will that will be included in the CONNECT packet of subsequent connectClient() calls.
 */
void FlashMQTestClient::setWill(std::shared_ptr<WillPublish> &will)
{
    this->will = will;
}

/**
 * @brief Marks the client ready for disconnect and sends a DISCONNECT packet with the given reason code.
 */
void FlashMQTestClient::disconnect(ReasonCodes reason)
{
    client->setReadyForDisconnect();
    Disconnect d(this->client->getProtocolVersion(), reason);
    client->writeMqttPacket(d);
}

/**
 * @brief Starts the test client's worker thread, running do_thread_work.
 */
void FlashMQTestClient::start()
{
    testServerWorkerThreadData->start(&do_thread_work);
}

/**
 * @brief Connects with clean start and a session expiry interval of 0.
 */
void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion)
{
    connectClient(protocolVersion, true, 0, [](Connect&){});
}

/**
 * @brief Connects with the given clean-start flag and session expiry interval, without further CONNECT changes.
 */
void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion, bool clean_start, uint32_t session_expiry_interval)
{
    connectClient(protocolVersion, clean_start, session_expiry_interval, [](Connect&){});
}

/**
 * @brief Opens a TCP connection to the test server and performs the MQTT CONNECT handshake.
 *
 * Steps:
 * - Creates a non-blocking socket to 127.0.0.1:21883.
 * - Wraps the socket in a Client and hands it to the test worker thread.
 * - Installs a packet handler that records received packets and auto-acknowledges QoS flows.
 * - Sends CONNECT (including the will set via setWill()) and waits for a CONNACK.
 * - Marks the client authenticated.
 *
 * @param manipulateConnect Callback to modify the Connect packet just before it is serialized.
 * @throws std::runtime_error on socket setup/connect failure, or when no CONNACK arrives in time.
 */
void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion, bool clean_start, uint32_t session_expiry_interval,
                                      std::function<void(Connect&)> manipulateConnect)
{
    const int sockfd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));

    struct sockaddr_in servaddr;
    std::memset(&servaddr, 0, sizeof(servaddr));

    const std::string hostname = "127.0.0.1";

    servaddr.sin_family = AF_INET;
    servaddr.sin_addr.s_addr = inet_addr(hostname.c_str());
    servaddr.sin_port = htons(21883);

    // Until the Client takes ownership of sockfd, we own it: close it on any failure so it doesn't leak.
    try
    {
        const int flags = fcntl(sockfd, F_GETFL);
        if (flags < 0 || fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) < 0)
        {
            throw std::runtime_error(strerror(errno));
        }

        // Qualified, because a local 'Connect connect' is declared further down in this function.
        const int rc = ::connect(sockfd, reinterpret_cast<struct sockaddr*>(&servaddr), sizeof (servaddr));

        // On a non-blocking socket, EINPROGRESS means the connect continues asynchronously.
        if (rc < 0 && errno != EINPROGRESS)
        {
            throw std::runtime_error(strerror(errno));
        }
    }
    catch (...)
    {
        ::close(sockfd);
        throw;
    }

    const std::string clientid = formatString("testclient_%d", clientCount++);

    this->client = std::make_shared<Client>(sockfd, testServerWorkerThreadData, nullptr, false, reinterpret_cast<struct sockaddr*>(&servaddr), settings.get());
    this->client->setClientProperties(protocolVersion, clientid, "user", false, 60);

    testServerWorkerThreadData->giveClient(this->client);

    // This gets called in the test client's worker thread, but the STL container's minimal thread safety should be enough: only list manipulation is
    // mutexed, elements within are not. 'this' is captured explicitly because the callback is stored in the client and outlives this function.
    client->onPacketReceived = [this](MqttPacket &pack)
    {
        std::lock_guard locker(receivedListMutex);

        if (pack.packetType == PacketType::PUBLISH)
        {
            pack.parsePublishData();

            this->receivedPublishes.push_back(pack);

            const auto qos = pack.getPublishData().qos;

            if (qos == 1)
            {
                PubResponse pubAck(this->client->getProtocolVersion(), PacketType::PUBACK, ReasonCodes::Success, pack.getPacketId());
                this->client->writeMqttPacketAndBlameThisClient(pubAck);
            }
            else if (qos == 2)
            {
                // QoS 2 delivery is acknowledged with PUBREC, not PUBACK. The server then sends PUBREL, which the
                // PUBREL branch below answers with PUBCOMP.
                PubResponse pubRec(this->client->getProtocolVersion(), PacketType::PUBREC, ReasonCodes::Success, pack.getPacketId());
                this->client->writeMqttPacketAndBlameThisClient(pubRec);
            }
        }
        else if (pack.packetType == PacketType::PUBREL)
        {
            pack.parsePubRelData();
            PubResponse pubComp(this->client->getProtocolVersion(), PacketType::PUBCOMP, ReasonCodes::Success, pack.getPacketId());
            this->client->writeMqttPacketAndBlameThisClient(pubComp);
        }
        else if (pack.packetType == PacketType::PUBREC)
        {
            // Server acknowledged our own QoS 2 publish; continue the flow with PUBREL.
            pack.parsePubRecData();
            PubResponse pubRel(this->client->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, pack.getPacketId());
            this->client->writeMqttPacketAndBlameThisClient(pubRel);
        }

        this->receivedPackets.push_back(std::move(pack));
    };

    Connect connect(protocolVersion, client->getClientId());
    connect.will = this->will;
    connect.clean_start = clean_start;
    connect.constructPropertyBuilder();
    connect.propertyBuilder->writeSessionExpiry(session_expiry_interval);

    manipulateConnect(connect);

    MqttPacket connectPack(connect);
    this->client->writeMqttPacketAndBlameThisClient(connectPack);

    waitForConnack();

    this->client->setAuthenticated(true);
}

/**
 * @brief Subscribes to a topic and waits for a matching SUBACK.
 *
 * Clears the received lists first, so the SUBACK is expected as the first packet received.
 *
 * @throws std::runtime_error on timeout, on packet id mismatch, or when any SUBACK code exceeds the requested qos
 *         (which includes failure codes).
 */
void FlashMQTestClient::subscribe(const std::string topic, char qos)
{
    clearReceivedLists();

    const uint16_t packet_id = 66;

    Subscribe sub(client->getProtocolVersion(), packet_id, topic, qos);
    MqttPacket subPack(sub);
    client->writeMqttPacketAndBlameThisClient(subPack);

    waitForCondition([&]() {
        return !this->receivedPackets.empty() && this->receivedPackets.front().packetType == PacketType::SUBACK;
    });

    MqttPacket &subAck = this->receivedPackets.front();
    SubAckData data = subAck.parseSubAckData();

    if (data.packet_id != packet_id)
        throw std::runtime_error("Incorrect packet id in suback");

    const bool allGranted = std::all_of(data.subAckCodes.begin(), data.subAckCodes.end(), [qos](uint8_t x) { return x <= qos; });

    if (!allGranted)
    {
        throw std::runtime_error("Suback indicates error.");
    }
}

/**
 * @brief Publishes a message and, for QoS 1/2, waits for and verifies the server's acknowledgement flow.
 *
 * Clears the received lists first. QoS 1 expects exactly one PUBACK. QoS 2 expects PUBREC first and PUBCOMP last;
 * the PUBREL in between is sent by the onPacketReceived handler.
 *
 * @throws std::runtime_error on timeout, unexpected packet types, or packet id mismatch.
 */
void FlashMQTestClient::publish(Publish &pub)
{
    clearReceivedLists();

    const uint16_t packet_id = 77;

    MqttPacket pubPack(client->getProtocolVersion(), pub);
    if (pub.qos > 0)
        pubPack.setPacketId(packet_id);
    client->writeMqttPacketAndBlameThisClient(pubPack);

    if (pub.qos == 1)
    {
        waitForCondition([&]() {
            return this->receivedPackets.size() == 1;
        });

        MqttPacket &pubAckPack = this->receivedPackets.front();

        // Check the type before parsing, so an unexpected packet gives this error instead of a parse failure.
        if (pubAckPack.packetType != PacketType::PUBACK)
            throw std::runtime_error("First packet received from server is not a PUBACK.");

        pubAckPack.parsePubAckData();

        if (pubAckPack.getPacketId() != packet_id || this->receivedPackets.size() != 1)
            throw std::runtime_error("Packet ID mismatch on QoS 1 publish or packet count wrong.");
    }
    else if (pub.qos == 2)
    {
        waitForCondition([&]() {
            return this->receivedPackets.size() >= 2;
        });

        MqttPacket &pubRecPack = this->receivedPackets.front();
        MqttPacket &pubCompPack = this->receivedPackets.back();

        if (pubRecPack.packetType != PacketType::PUBREC)
            throw std::runtime_error("First packet received from server is not a PUBREC.");

        if (pubCompPack.packetType != PacketType::PUBCOMP)
            throw std::runtime_error("Last packet received from server is not a PUBCOMP.");

        pubRecPack.parsePubRecData();
        pubCompPack.parsePubComp();

        if (pubRecPack.getPacketId() != packet_id || pubCompPack.getPacketId() != packet_id)
            throw std::runtime_error("Packet ID mismatch on QoS 2 publish.");
    }
}

/**
 * @brief Convenience overload: builds a Publish from topic, payload and qos, and publishes it.
 */
void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos)
{
    Publish pub(topic, payload, qos);
    publish(pub);
}

/**
 * @brief Queues a quit for the test worker thread and waits for it to finish.
 */
void FlashMQTestClient::waitForQuit()
{
    testServerWorkerThreadData->queueQuit();
    testServerWorkerThreadData->waitForQuit();
}

/**
 * @brief Waits until any CONNACK is present in receivedPackets. The CONNACK's reason code is not inspected.
 */
void FlashMQTestClient::waitForConnack()
{
    waitForCondition([&]() {
        return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) {
            return p.packetType == PacketType::CONNACK;
        });
    });
}

/**
 * @brief Waits until at least `count` PUBLISH packets have been received.
 * @param timeout Maximum wait in seconds.
 */
void FlashMQTestClient::waitForMessageCount(const size_t count, int timeout)
{
    waitForCondition([&]() {
        return this->receivedPublishes.size() >= count;
    }, timeout);
}