diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro
index f8d746b..53a2776 100644
--- a/FlashMQTests/FlashMQTests.pro
+++ b/FlashMQTests/FlashMQTests.pro
@@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \
../globalstats.cpp \
../derivablecounter.cpp \
../packetdatatypes.cpp \
+ ../flashmqtestclient.cpp \
mainappthread.cpp \
twoclienttestcontext.cpp
@@ -102,6 +103,7 @@ HEADERS += \
../globalstats.h \
../derivablecounter.h \
../packetdatatypes.h \
+ ../flashmqtestclient.h \
mainappthread.h \
twoclienttestcontext.h
diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp
index 6f7ad13..5c40916 100644
--- a/FlashMQTests/tst_maintests.cpp
+++ b/FlashMQTests/tst_maintests.cpp
@@ -35,6 +35,8 @@ License along with FlashMQ. If not, see .
#include "threaddata.h"
#include "threadglobals.h"
+#include "flashmqtestclient.h"
+
// Dumb Qt version gives warnings when comparing uint with number literal.
template
inline bool myCastCompare(const T1 &t1, const T2 &t2, const char *actual, const char *expected,
@@ -56,6 +58,7 @@ class MainTests : public QObject
Q_OBJECT
QScopedPointer mainApp;
+ std::shared_ptr dummyThreadData;
void testParsePacketHelper(const std::string &topic, char from_qos, bool retain);
@@ -117,6 +120,8 @@ private slots:
void testUnSubscribe();
+ void testBasicsWithFlashMQTestClient();
+
};
MainTests::MainTests()
@@ -135,6 +140,12 @@ void MainTests::init()
mainApp.reset(new MainAppThread());
mainApp->start();
mainApp->waitForStarted();
+
+ // We test functions directly that the server normally only calls from worker threads, in which thread data is available. This is kind of a dummy-fix, until
+ // we actually need correct thread data at those points (at this point, it's only to increase message counters).
+ std::shared_ptr settings = std::make_shared();
+ this->dummyThreadData = std::make_shared(666, settings);
+ ThreadGlobals::assignThreadData(dummyThreadData.get());
}
void MainTests::cleanup()
@@ -1342,6 +1353,77 @@ void MainTests::testUnSubscribe()
}));
}
+/**
+ * @brief MainTests::testBasicsWithFlashMQTestClient was used to develop FlashMQTestClient.
+ */
+void MainTests::testBasicsWithFlashMQTestClient()
+{
+ FlashMQTestClient client;
+ client.start();
+ client.connectClient(ProtocolVersion::Mqtt311);
+
+ MqttPacket &connAckPack = client.receivedPackets.front();
+ QVERIFY(connAckPack.packetType == PacketType::CONNACK);
+
+ {
+ client.subscribe("a/b", 1);
+
+ MqttPacket &subAck = client.receivedPackets.front();
+ SubAckData subAckData = subAck.parseSubAckData();
+
+ QVERIFY(subAckData.subAckCodes.size() == 1);
+ QVERIFY(subAckData.subAckCodes.front() == 1);
+ }
+
+ {
+ client.subscribe("c/d", 2);
+
+ MqttPacket &subAck = client.receivedPackets.front();
+ SubAckData subAckData = subAck.parseSubAckData();
+
+ QVERIFY(subAckData.subAckCodes.size() == 1);
+ QVERIFY(subAckData.subAckCodes.front() == 2);
+ }
+
+ client.clearReceivedLists();
+
+ FlashMQTestClient publisher;
+ publisher.start();
+ publisher.connectClient(ProtocolVersion::Mqtt5);
+
+ {
+ publisher.publish("a/b", "wave", 2);
+
+ client.waitForMessageCount(1);
+ MqttPacket &p = client.receivedPublishes.front();
+
+ QCOMPARE(p.getPublishData().topic, "a/b");
+ QCOMPARE(p.getPayloadCopy(), "wave");
+ QCOMPARE(p.getPublishData().qos, 1);
+ QVERIFY(p.getPacketId() > 0);
+ QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311);
+ }
+
+ client.clearReceivedLists();
+
+ {
+ publisher.publish("c/d", "asdfasdfasdf", 2);
+
+ client.waitForMessageCount(1);
+ MqttPacket &p = client.receivedPublishes.back();
+
+ MYCASTCOMPARE(client.receivedPublishes.size(), 1);
+
+ QCOMPARE(p.getPublishData().topic, "c/d");
+ QCOMPARE(p.getPayloadCopy(), "asdfasdfasdf");
+ QCOMPARE(p.getPublishData().qos, 2);
+ QVERIFY(p.getPacketId() > 1); // It's the same client, so it should not re-use packet id 1
+ QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311);
+ }
+
+
+}
+
int main(int argc, char *argv[])
{
diff --git a/client.cpp b/client.cpp
index 95b1c2b..1e6df1b 100644
--- a/client.cpp
+++ b/client.cpp
@@ -62,7 +62,7 @@ Client::Client(int fd, std::shared_ptr threadData, SSL *ssl, bool we
Client::~Client()
{
// Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread.
- if (this->threadData.expired())
+ if (this->epoll_fd == 0)
return;
if (disconnectReason.empty())
diff --git a/client.h b/client.h
index 7de28ed..89ee81d 100644
--- a/client.h
+++ b/client.h
@@ -187,6 +187,10 @@ public:
void setExtendedAuthenticationMethod(const std::string &authMethod);
const std::string &getExtendedAuthenticationMethod() const;
+#ifdef TESTING
+ std::function onPacketReceived;
+#endif
+
#ifndef NDEBUG
void setFakeUpgraded();
#endif
diff --git a/flashmqtestclient.cpp b/flashmqtestclient.cpp
new file mode 100644
index 0000000..2efc450
--- /dev/null
+++ b/flashmqtestclient.cpp
@@ -0,0 +1,235 @@
+#include "flashmqtestclient.h"
+
+#include
+#include
+#include "errno.h"
+#include "functional"
+
+#include "threadloop.h"
+#include "utils.h"
+#include "client.h"
+
+
+#define TEST_CLIENT_MAX_EVENTS 25
+
+int FlashMQTestClient::clientCount = 0;
+
+FlashMQTestClient::FlashMQTestClient() :
+ settings(std::make_shared()),
+ testServerWorkerThreadData(std::make_shared(0, settings)),
+ dummyThreadData(std::make_shared(666, settings))
+{
+
+}
+
+/**
+ * @brief FlashMQTestClient::~FlashMQTestClient properly quits the threads when exiting.
+ *
+ * This prevents accidental crashes on calling terminate(), and Qt Macro's prematurely end the method, skipping explicit waits after the tests.
+ */
+FlashMQTestClient::~FlashMQTestClient()
+{
+ waitForQuit();
+}
+
+void FlashMQTestClient::waitForCondition(std::function f)
+{
+ int n = 0;
+ while(n++ < 100)
+ {
+ usleep(10000);
+
+ std::lock_guard locker(receivedListMutex);
+
+ if (f())
+ break;
+ }
+
+ std::lock_guard locker(receivedListMutex);
+
+ if (!f())
+ {
+ throw std::runtime_error("Wait condition failed.");
+ }
+}
+
+void FlashMQTestClient::clearReceivedLists()
+{
+ std::lock_guard locker(receivedListMutex);
+ receivedPackets.clear();
+ receivedPublishes.clear();
+}
+
+void FlashMQTestClient::start()
+{
+ testServerWorkerThreadData->start(&do_thread_work);
+}
+
+void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion)
+{
+ int sockfd = check(socket(AF_INET, SOCK_STREAM, 0));
+
+ struct sockaddr_in servaddr;
+ bzero(&servaddr, sizeof(servaddr));
+
+ const std::string hostname = "127.0.0.1";
+
+ servaddr.sin_family = AF_INET;
+ servaddr.sin_addr.s_addr = inet_addr(hostname.c_str());
+ servaddr.sin_port = htons(1883);
+
+ int flags = fcntl(sockfd, F_GETFL);
+ fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
+
+ int rc = connect(sockfd, reinterpret_cast(&servaddr), sizeof (servaddr));
+
+ if (rc < 0 && errno != EINPROGRESS)
+ {
+ throw std::runtime_error(strerror(errno));
+ }
+
+ const std::string clientid = formatString("testclient_%d", clientCount++);
+
+ this->client = std::make_shared(sockfd, testServerWorkerThreadData, nullptr, false, reinterpret_cast(&servaddr), settings.get());
+ this->client->setClientProperties(protocolVersion, clientid, "user", false, 60);
+
+ testServerWorkerThreadData->giveClient(this->client);
+
+ // This gets called in the test client's worker thread, but the STL container's minimal thread safety should be enough: only list manipulation is
+ // mutexed, elements within are not.
+ client->onPacketReceived = [&](MqttPacket &pack)
+ {
+ std::lock_guard locker(receivedListMutex);
+
+ if (pack.packetType == PacketType::PUBLISH)
+ {
+ pack.parsePublishData();
+
+ MqttPacket copyPacket = pack;
+ this->receivedPublishes.push_back(copyPacket);
+
+ if (pack.getPublishData().qos > 0)
+ {
+ PubResponse pubAck(this->client->getProtocolVersion(), PacketType::PUBACK, ReasonCodes::Success, pack.getPacketId());
+ this->client->writeMqttPacketAndBlameThisClient(pubAck);
+ }
+ }
+ else if (pack.packetType == PacketType::PUBREL)
+ {
+ pack.parsePubRelData();
+ PubResponse pubComp(this->client->getProtocolVersion(), PacketType::PUBCOMP, ReasonCodes::Success, pack.getPacketId());
+ this->client->writeMqttPacketAndBlameThisClient(pubComp);
+ }
+ else if (pack.packetType == PacketType::PUBREC)
+ {
+ pack.parsePubRecData();
+ PubResponse pubRel(this->client->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, pack.getPacketId());
+ this->client->writeMqttPacketAndBlameThisClient(pubRel);
+ }
+
+ this->receivedPackets.push_back(std::move(pack));
+ };
+
+ Connect connect(protocolVersion, client->getClientId());
+ MqttPacket connectPack(connect);
+ this->client->writeMqttPacketAndBlameThisClient(connectPack);
+
+ waitForConnack();
+}
+
+void FlashMQTestClient::subscribe(const std::string topic, char qos)
+{
+ clearReceivedLists();
+
+ const uint16_t packet_id = 66;
+
+ Subscribe sub(client->getProtocolVersion(), packet_id, topic, qos);
+ MqttPacket subPack(sub);
+ client->writeMqttPacketAndBlameThisClient(subPack);
+
+ waitForCondition([&]() {
+ return !this->receivedPackets.empty() && this->receivedPackets.front().packetType == PacketType::SUBACK;
+ });
+
+ MqttPacket &subAck = this->receivedPackets.front();
+ SubAckData data = subAck.parseSubAckData();
+
+ if (data.packet_id != packet_id)
+ throw std::runtime_error("Incorrect packet id in suback");
+
+ if (!std::all_of(data.subAckCodes.begin(), data.subAckCodes.end(), [&](uint8_t x) { return x <= qos ;}))
+ {
+ throw std::runtime_error("Suback indicates error.");
+ }
+}
+
+void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos)
+{
+ clearReceivedLists();
+
+ const uint16_t packet_id = 77;
+
+ Publish pub(topic, payload, qos);
+ MqttPacket pubPack(client->getProtocolVersion(), pub);
+ if (qos > 0)
+ pubPack.setPacketId(packet_id);
+ client->writeMqttPacketAndBlameThisClient(pubPack);
+
+ if (qos == 1)
+ {
+ waitForCondition([&]() {
+ return this->receivedPackets.size() == 1;
+ });
+
+ MqttPacket &pubAckPack = this->receivedPackets.front();
+ pubAckPack.parsePubAckData();
+
+ if (pubAckPack.packetType != PacketType::PUBACK)
+ throw std::runtime_error("First packet received from server is not a PUBACK.");
+
+ if (pubAckPack.getPacketId() != packet_id || this->receivedPackets.size() != 1)
+ throw std::runtime_error("Packet ID mismatch on QoS 1 publish or packet count wrong.");
+ }
+ else if (qos == 2)
+ {
+ waitForCondition([&]() {
+ return this->receivedPackets.size() >= 2;
+ });
+
+ MqttPacket &pubRecPack = this->receivedPackets.front();
+ pubRecPack.parsePubRecData();
+ MqttPacket &pubCompPack = this->receivedPackets.back();
+ pubCompPack.parsePubComp();
+
+ if (pubRecPack.packetType != PacketType::PUBREC)
+ throw std::runtime_error("First packet received from server is not a PUBREC.");
+
+ if (pubCompPack.packetType != PacketType::PUBCOMP)
+ throw std::runtime_error("Last packet received from server is not a PUBCOMP.");
+
+ if (pubRecPack.getPacketId() != packet_id || pubCompPack.getPacketId() != packet_id)
+ throw std::runtime_error("Packet ID mismatch on QoS 2 publish.");
+ }
+}
+
+void FlashMQTestClient::waitForQuit()
+{
+ testServerWorkerThreadData->queueQuit();
+ testServerWorkerThreadData->waitForQuit();
+}
+
+void FlashMQTestClient::waitForConnack()
+{
+ waitForCondition([&]() {
+ return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) {
+ return p.packetType == PacketType::CONNACK;
+ });
+ });
+}
+
+void FlashMQTestClient::waitForMessageCount(const size_t count)
+{
+ waitForCondition([&]() {
+ return this->receivedPublishes.size() >= count;
+ });
+}
diff --git a/flashmqtestclient.h b/flashmqtestclient.h
new file mode 100644
index 0000000..8fe4f1c
--- /dev/null
+++ b/flashmqtestclient.h
@@ -0,0 +1,45 @@
+#ifndef FLASHMQTESTCLIENT_H
+#define FLASHMQTESTCLIENT_H
+
+#include
+#include
+
+#include "subscriptionstore.h"
+
+/**
+ * @brief The FlashMQTestClient class uses the existing server code as a client, for testing purposes.
+ */
+class FlashMQTestClient
+{
+ std::shared_ptr settings;
+ std::shared_ptr testServerWorkerThreadData;
+ std::shared_ptr client;
+
+ std::shared_ptr dummyThreadData;
+
+ std::mutex receivedListMutex;
+
+ static int clientCount;
+
+ void waitForCondition(std::function f);
+
+
+public:
+ std::list receivedPackets;
+ std::list receivedPublishes;
+
+ FlashMQTestClient();
+ ~FlashMQTestClient();
+
+ void start();
+ void connectClient(ProtocolVersion protocolVersion);
+ void subscribe(const std::string topic, char qos);
+ void publish(const std::string &topic, const std::string &payload, char qos);
+ void clearReceivedLists();
+
+ void waitForQuit();
+ void waitForConnack();
+ void waitForMessageCount(const size_t count);
+};
+
+#endif // FLASHMQTESTCLIENT_H
diff --git a/mqttpacket.cpp b/mqttpacket.cpp
index 30b0f8d..84f8b7d 100644
--- a/mqttpacket.cpp
+++ b/mqttpacket.cpp
@@ -229,6 +229,60 @@ MqttPacket::MqttPacket(const Auth &auth) :
calculateRemainingLength();
}
+MqttPacket::MqttPacket(const Connect &connect) :
+ bites(connect.getLengthWithoutFixedHeader()),
+ protocolVersion(connect.protocolVersion),
+ packetType(PacketType::CONNECT)
+{
+#ifndef TESTING
+ throw NotImplementedException("Code is only for testing.");
+#endif
+
+ first_byte = static_cast(packetType) << 4;
+
+ const std::string magicString = connect.getMagicString();
+ writeString(magicString);
+
+ writeByte(static_cast(protocolVersion));
+ writeByte(2); // flags; The only bit set is 'clean session'.
+
+ // Keep-alive
+ writeUint16(60);
+
+ if (connect.protocolVersion >= ProtocolVersion::Mqtt5)
+ {
+ writeProperties(connect.propertyBuilder);
+ }
+
+ writeString(connect.clientid);
+
+ calculateRemainingLength();
+}
+
+MqttPacket::MqttPacket(const Subscribe &subscribe) :
+ bites(subscribe.getLengthWithoutFixedHeader()),
+ packetType(PacketType::SUBSCRIBE)
+{
+#ifndef TESTING
+ throw NotImplementedException("Code is only for testing.");
+#endif
+
+ first_byte = static_cast(packetType) << 4;
+ first_byte |= 2; // required reserved bit
+
+ writeUint16(subscribe.packetId);
+
+ if (subscribe.protocolVersion >= ProtocolVersion::Mqtt5)
+ {
+ writeProperties(subscribe.propertyBuilder);
+ }
+
+ writeString(subscribe.topic);
+ writeByte(subscribe.qos);
+
+ calculateRemainingLength();
+}
+
void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender)
{
while (buf.usedBytes() >= MQTT_HEADER_LENGH)
@@ -1035,6 +1089,8 @@ void MqttPacket::handleUnsubscribe()
void MqttPacket::parsePublishData()
{
+ assert(externallyReceived);
+
setPosToDataStart();
publishData.retain = (first_byte & 0b00000001);
@@ -1214,6 +1270,7 @@ void MqttPacket::handlePublish()
void MqttPacket::parsePubAckData()
{
setPosToDataStart();
+ this->publishData.qos = 1;
this->packet_id = readTwoBytesToUInt16();
}
@@ -1554,6 +1611,12 @@ void MqttPacket::writeVariableByteInt(const VariableByteInt &v)
writeBytes(v.data(), v.getLen());
}
+void MqttPacket::writeString(const std::string &s)
+{
+ writeUint16(s.length());
+ writeBytes(s.c_str(), s.length());
+}
+
uint16_t MqttPacket::readTwoBytesToUInt16()
{
if (pos + 2 > bites.size())
diff --git a/mqttpacket.h b/mqttpacket.h
index 32c0645..97cc263 100644
--- a/mqttpacket.h
+++ b/mqttpacket.h
@@ -78,6 +78,7 @@ class MqttPacket
void writeBytes(const char *b, size_t len);
void writeProperties(const std::shared_ptr &properties);
void writeVariableByteInt(const VariableByteInt &v);
+ void writeString(const std::string &s);
uint16_t readTwoBytesToUInt16();
uint32_t readFourBytesToUint32();
size_t remainingAfterPos();
@@ -89,11 +90,18 @@ class MqttPacket
void setPosToDataStart();
bool atEnd() const;
+#ifndef TESTING
+ // In production, I want to be sure I don't accidentally copy packets, because it's slow.
MqttPacket(const MqttPacket &other) = delete;
+#endif
public:
+#ifdef TESTING
+ // In testing I need to copy packets for administrative purposes.
+ MqttPacket(const MqttPacket &other) = default;
+#endif
PacketType packetType = PacketType::Reserved;
- MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr &sender); // Constructor for parsing incoming packets.
+ MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr &sender); // Constructor for parsing incoming packets.
MqttPacket(MqttPacket &&other) = default;
size_t setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publishData) const;
@@ -106,6 +114,8 @@ public:
MqttPacket(const PubResponse &pubAck);
MqttPacket(const Disconnect &disconnect);
MqttPacket(const Auth &auth);
+ MqttPacket(const Connect &connect);
+ MqttPacket(const Subscribe &subscribe);
static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender);
diff --git a/threadloop.cpp b/threadloop.cpp
index 9e60ae5..c3c161d 100644
--- a/threadloop.cpp
+++ b/threadloop.cpp
@@ -110,6 +110,11 @@ void do_thread_work(ThreadData *threadData)
for (MqttPacket &packet : packetQueueIn)
{
+#ifdef TESTING
+ if (client->onPacketReceived)
+ client->onPacketReceived(packet);
+ else
+#endif
packet.handle();
}
diff --git a/types.cpp b/types.cpp
index a7a56f7..86f23b6 100644
--- a/types.cpp
+++ b/types.cpp
@@ -345,9 +345,57 @@ size_t Auth::getLengthWithoutFixedHeader() const
return result;
}
+Connect::Connect(ProtocolVersion protocolVersion, const std::string &clientid) :
+ protocolVersion(protocolVersion),
+ clientid(clientid)
+{
+
+}
+
+size_t Connect::getLengthWithoutFixedHeader() const
+{
+ size_t result = clientid.length() + 2;
+
+ result += this->protocolVersion <= ProtocolVersion::Mqtt31 ? 6 : 4;
+ result += 6; // header stuff, lengths, keep-alive
+
+ if (this->protocolVersion >= ProtocolVersion::Mqtt5)
+ {
+ const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1;
+ result += proplen;
+ }
+ return result;
+
+}
+std::string Connect::getMagicString() const
+{
+ if (protocolVersion <= ProtocolVersion::Mqtt31)
+ return "MQIsdp";
+ else
+ return "MQTT";
+}
+Subscribe::Subscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic, char qos) :
+ protocolVersion(protocolVersion),
+ packetId(packetId),
+ topic(topic),
+ qos(qos)
+{
+}
+size_t Subscribe::getLengthWithoutFixedHeader() const
+{
+ size_t result = topic.size() + 2;
+ result += 2; // packet id
+ result += 1; // requested QoS
+ if (this->protocolVersion >= ProtocolVersion::Mqtt5)
+ {
+ const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1;
+ result += proplen;
+ }
+ return result;
+}
diff --git a/types.h b/types.h
index 193dfc4..2fc79ce 100644
--- a/types.h
+++ b/types.h
@@ -277,4 +277,34 @@ public:
size_t getLengthWithoutFixedHeader() const;
};
+struct Connect
+{
+ const ProtocolVersion protocolVersion;
+ std::string clientid;
+ std::string username;
+ std::string password;
+ std::shared_ptr propertyBuilder;
+
+ Connect(ProtocolVersion protocolVersion, const std::string &clientid);
+ size_t getLengthWithoutFixedHeader() const;
+ std::string getMagicString() const;
+};
+
+/**
+ * @brief The Subscribe struct can be used to construct a mqtt packet of type 'subscribe'.
+ *
+ * It's rudimentary. Offically you can subscribe to multiple topics at once, but I have no need for that.
+ */
+struct Subscribe
+{
+ const ProtocolVersion protocolVersion;
+ uint16_t packetId;
+ std::string topic;
+ char qos;
+ std::shared_ptr propertyBuilder;
+
+ Subscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic, char qos);
+ size_t getLengthWithoutFixedHeader() const;
+};
+
#endif // TYPES_H