Commit 56d9c5107589dfcad09af77e0f07d5b050615c72

Authored by Wiebe Cazemier
1 parent 483db298

Add packet based test client

Now that parsing and handling of packets is separated, we can use the
main code to parse packets in the new FlashMQTestClient. This allows
great flexibility in inspecting the server response in a flexible
manner.

We now also have the ability to make tests for MQTT5 features.
FlashMQTests/FlashMQTests.pro
@@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \ @@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \
55 ../globalstats.cpp \ 55 ../globalstats.cpp \
56 ../derivablecounter.cpp \ 56 ../derivablecounter.cpp \
57 ../packetdatatypes.cpp \ 57 ../packetdatatypes.cpp \
  58 + ../flashmqtestclient.cpp \
58 mainappthread.cpp \ 59 mainappthread.cpp \
59 twoclienttestcontext.cpp 60 twoclienttestcontext.cpp
60 61
@@ -102,6 +103,7 @@ HEADERS += \ @@ -102,6 +103,7 @@ HEADERS += \
102 ../globalstats.h \ 103 ../globalstats.h \
103 ../derivablecounter.h \ 104 ../derivablecounter.h \
104 ../packetdatatypes.h \ 105 ../packetdatatypes.h \
  106 + ../flashmqtestclient.h \
105 mainappthread.h \ 107 mainappthread.h \
106 twoclienttestcontext.h 108 twoclienttestcontext.h
107 109
FlashMQTests/tst_maintests.cpp
@@ -35,6 +35,8 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. @@ -35,6 +35,8 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
35 #include "threaddata.h" 35 #include "threaddata.h"
36 #include "threadglobals.h" 36 #include "threadglobals.h"
37 37
  38 +#include "flashmqtestclient.h"
  39 +
38 // Dumb Qt version gives warnings when comparing uint with number literal. 40 // Dumb Qt version gives warnings when comparing uint with number literal.
39 template <typename T1, typename T2> 41 template <typename T1, typename T2>
40 inline bool myCastCompare(const T1 &t1, const T2 &t2, const char *actual, const char *expected, 42 inline bool myCastCompare(const T1 &t1, const T2 &t2, const char *actual, const char *expected,
@@ -56,6 +58,7 @@ class MainTests : public QObject @@ -56,6 +58,7 @@ class MainTests : public QObject
56 Q_OBJECT 58 Q_OBJECT
57 59
58 QScopedPointer<MainAppThread> mainApp; 60 QScopedPointer<MainAppThread> mainApp;
  61 + std::shared_ptr<ThreadData> dummyThreadData;
59 62
60 void testParsePacketHelper(const std::string &topic, char from_qos, bool retain); 63 void testParsePacketHelper(const std::string &topic, char from_qos, bool retain);
61 64
@@ -117,6 +120,8 @@ private slots: @@ -117,6 +120,8 @@ private slots:
117 120
118 void testUnSubscribe(); 121 void testUnSubscribe();
119 122
  123 + void testBasicsWithFlashMQTestClient();
  124 +
120 }; 125 };
121 126
122 MainTests::MainTests() 127 MainTests::MainTests()
@@ -135,6 +140,12 @@ void MainTests::init() @@ -135,6 +140,12 @@ void MainTests::init()
135 mainApp.reset(new MainAppThread()); 140 mainApp.reset(new MainAppThread());
136 mainApp->start(); 141 mainApp->start();
137 mainApp->waitForStarted(); 142 mainApp->waitForStarted();
  143 +
  144 + // We test functions directly that the server normally only calls from worker threads, in which thread data is available. This is kind of a dummy-fix, until
  145 + // we actually need correct thread data at those points (at this point, it's only to increase message counters).
  146 + std::shared_ptr<Settings> settings = std::make_shared<Settings>();
  147 + this->dummyThreadData = std::make_shared<ThreadData>(666, settings);
  148 + ThreadGlobals::assignThreadData(dummyThreadData.get());
138 } 149 }
139 150
140 void MainTests::cleanup() 151 void MainTests::cleanup()
@@ -1342,6 +1353,77 @@ void MainTests::testUnSubscribe() @@ -1342,6 +1353,77 @@ void MainTests::testUnSubscribe()
1342 })); 1353 }));
1343 } 1354 }
1344 1355
  1356 +/**
  1357 + * @brief MainTests::testBasicsWithFlashMQTestClient was used to develop FlashMQTestClient.
  1358 + */
  1359 +void MainTests::testBasicsWithFlashMQTestClient()
  1360 +{
  1361 + FlashMQTestClient client;
  1362 + client.start();
  1363 + client.connectClient(ProtocolVersion::Mqtt311);
  1364 +
  1365 + MqttPacket &connAckPack = client.receivedPackets.front();
  1366 + QVERIFY(connAckPack.packetType == PacketType::CONNACK);
  1367 +
  1368 + {
  1369 + client.subscribe("a/b", 1);
  1370 +
  1371 + MqttPacket &subAck = client.receivedPackets.front();
  1372 + SubAckData subAckData = subAck.parseSubAckData();
  1373 +
  1374 + QVERIFY(subAckData.subAckCodes.size() == 1);
  1375 + QVERIFY(subAckData.subAckCodes.front() == 1);
  1376 + }
  1377 +
  1378 + {
  1379 + client.subscribe("c/d", 2);
  1380 +
  1381 + MqttPacket &subAck = client.receivedPackets.front();
  1382 + SubAckData subAckData = subAck.parseSubAckData();
  1383 +
  1384 + QVERIFY(subAckData.subAckCodes.size() == 1);
  1385 + QVERIFY(subAckData.subAckCodes.front() == 2);
  1386 + }
  1387 +
  1388 + client.clearReceivedLists();
  1389 +
  1390 + FlashMQTestClient publisher;
  1391 + publisher.start();
  1392 + publisher.connectClient(ProtocolVersion::Mqtt5);
  1393 +
  1394 + {
  1395 + publisher.publish("a/b", "wave", 2);
  1396 +
  1397 + client.waitForMessageCount(1);
  1398 + MqttPacket &p = client.receivedPublishes.front();
  1399 +
  1400 + QCOMPARE(p.getPublishData().topic, "a/b");
  1401 + QCOMPARE(p.getPayloadCopy(), "wave");
  1402 + QCOMPARE(p.getPublishData().qos, 1);
  1403 + QVERIFY(p.getPacketId() > 0);
  1404 + QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311);
  1405 + }
  1406 +
  1407 + client.clearReceivedLists();
  1408 +
  1409 + {
  1410 + publisher.publish("c/d", "asdfasdfasdf", 2);
  1411 +
  1412 + client.waitForMessageCount(1);
  1413 + MqttPacket &p = client.receivedPublishes.back();
  1414 +
  1415 + MYCASTCOMPARE(client.receivedPublishes.size(), 1);
  1416 +
  1417 + QCOMPARE(p.getPublishData().topic, "c/d");
  1418 + QCOMPARE(p.getPayloadCopy(), "asdfasdfasdf");
  1419 + QCOMPARE(p.getPublishData().qos, 2);
  1420 + QVERIFY(p.getPacketId() > 1); // It's the same client, so it should not re-use packet id 1
  1421 + QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311);
  1422 + }
  1423 +
  1424 +
  1425 +}
  1426 +
1345 1427
1346 int main(int argc, char *argv[]) 1428 int main(int argc, char *argv[])
1347 { 1429 {
client.cpp
@@ -62,7 +62,7 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we @@ -62,7 +62,7 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we
62 Client::~Client() 62 Client::~Client()
63 { 63 {
64 // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread. 64 // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread.
65 - if (this->threadData.expired()) 65 + if (this->epoll_fd == 0)
66 return; 66 return;
67 67
68 if (disconnectReason.empty()) 68 if (disconnectReason.empty())
client.h
@@ -187,6 +187,10 @@ public: @@ -187,6 +187,10 @@ public:
187 void setExtendedAuthenticationMethod(const std::string &authMethod); 187 void setExtendedAuthenticationMethod(const std::string &authMethod);
188 const std::string &getExtendedAuthenticationMethod() const; 188 const std::string &getExtendedAuthenticationMethod() const;
189 189
  190 +#ifdef TESTING
  191 + std::function<void(MqttPacket &packet)> onPacketReceived;
  192 +#endif
  193 +
190 #ifndef NDEBUG 194 #ifndef NDEBUG
191 void setFakeUpgraded(); 195 void setFakeUpgraded();
192 #endif 196 #endif
flashmqtestclient.cpp 0 โ†’ 100644
  1 +#include "flashmqtestclient.h"
  2 +
  3 +#include <sys/epoll.h>
  4 +#include <cstring>
  5 +#include "errno.h"
  6 +#include "functional"
  7 +
  8 +#include "threadloop.h"
  9 +#include "utils.h"
  10 +#include "client.h"
  11 +
  12 +
  13 +#define TEST_CLIENT_MAX_EVENTS 25
  14 +
  15 +int FlashMQTestClient::clientCount = 0;
  16 +
  17 +FlashMQTestClient::FlashMQTestClient() :
  18 + settings(std::make_shared<Settings>()),
  19 + testServerWorkerThreadData(std::make_shared<ThreadData>(0, settings)),
  20 + dummyThreadData(std::make_shared<ThreadData>(666, settings))
  21 +{
  22 +
  23 +}
  24 +
  25 +/**
  26 + * @brief FlashMQTestClient::~FlashMQTestClient properly quits the threads when exiting.
  27 + *
  28 + * This prevents accidental crashes on calling terminate(), and Qt Macro's prematurely end the method, skipping explicit waits after the tests.
  29 + */
  30 +FlashMQTestClient::~FlashMQTestClient()
  31 +{
  32 + waitForQuit();
  33 +}
  34 +
  35 +void FlashMQTestClient::waitForCondition(std::function<bool()> f)
  36 +{
  37 + int n = 0;
  38 + while(n++ < 100)
  39 + {
  40 + usleep(10000);
  41 +
  42 + std::lock_guard<std::mutex> locker(receivedListMutex);
  43 +
  44 + if (f())
  45 + break;
  46 + }
  47 +
  48 + std::lock_guard<std::mutex> locker(receivedListMutex);
  49 +
  50 + if (!f())
  51 + {
  52 + throw std::runtime_error("Wait condition failed.");
  53 + }
  54 +}
  55 +
  56 +void FlashMQTestClient::clearReceivedLists()
  57 +{
  58 + std::lock_guard<std::mutex> locker(receivedListMutex);
  59 + receivedPackets.clear();
  60 + receivedPublishes.clear();
  61 +}
  62 +
  63 +void FlashMQTestClient::start()
  64 +{
  65 + testServerWorkerThreadData->start(&do_thread_work);
  66 +}
  67 +
  68 +void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion)
  69 +{
  70 + int sockfd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));
  71 +
  72 + struct sockaddr_in servaddr;
  73 + bzero(&servaddr, sizeof(servaddr));
  74 +
  75 + const std::string hostname = "127.0.0.1";
  76 +
  77 + servaddr.sin_family = AF_INET;
  78 + servaddr.sin_addr.s_addr = inet_addr(hostname.c_str());
  79 + servaddr.sin_port = htons(1883);
  80 +
  81 + int flags = fcntl(sockfd, F_GETFL);
  82 + fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
  83 +
  84 + int rc = connect(sockfd, reinterpret_cast<struct sockaddr*>(&servaddr), sizeof (servaddr));
  85 +
  86 + if (rc < 0 && errno != EINPROGRESS)
  87 + {
  88 + throw std::runtime_error(strerror(errno));
  89 + }
  90 +
  91 + const std::string clientid = formatString("testclient_%d", clientCount++);
  92 +
  93 + this->client = std::make_shared<Client>(sockfd, testServerWorkerThreadData, nullptr, false, reinterpret_cast<struct sockaddr*>(&servaddr), settings.get());
  94 + this->client->setClientProperties(protocolVersion, clientid, "user", false, 60);
  95 +
  96 + testServerWorkerThreadData->giveClient(this->client);
  97 +
  98 + // This gets called in the test client's worker thread, but the STL container's minimal thread safety should be enough: only list manipulation is
  99 + // mutexed, elements within are not.
  100 + client->onPacketReceived = [&](MqttPacket &pack)
  101 + {
  102 + std::lock_guard<std::mutex> locker(receivedListMutex);
  103 +
  104 + if (pack.packetType == PacketType::PUBLISH)
  105 + {
  106 + pack.parsePublishData();
  107 +
  108 + MqttPacket copyPacket = pack;
  109 + this->receivedPublishes.push_back(copyPacket);
  110 +
  111 + if (pack.getPublishData().qos > 0)
  112 + {
  113 + PubResponse pubAck(this->client->getProtocolVersion(), PacketType::PUBACK, ReasonCodes::Success, pack.getPacketId());
  114 + this->client->writeMqttPacketAndBlameThisClient(pubAck);
  115 + }
  116 + }
  117 + else if (pack.packetType == PacketType::PUBREL)
  118 + {
  119 + pack.parsePubRelData();
  120 + PubResponse pubComp(this->client->getProtocolVersion(), PacketType::PUBCOMP, ReasonCodes::Success, pack.getPacketId());
  121 + this->client->writeMqttPacketAndBlameThisClient(pubComp);
  122 + }
  123 + else if (pack.packetType == PacketType::PUBREC)
  124 + {
  125 + pack.parsePubRecData();
  126 + PubResponse pubRel(this->client->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, pack.getPacketId());
  127 + this->client->writeMqttPacketAndBlameThisClient(pubRel);
  128 + }
  129 +
  130 + this->receivedPackets.push_back(std::move(pack));
  131 + };
  132 +
  133 + Connect connect(protocolVersion, client->getClientId());
  134 + MqttPacket connectPack(connect);
  135 + this->client->writeMqttPacketAndBlameThisClient(connectPack);
  136 +
  137 + waitForConnack();
  138 +}
  139 +
  140 +void FlashMQTestClient::subscribe(const std::string topic, char qos)
  141 +{
  142 + clearReceivedLists();
  143 +
  144 + const uint16_t packet_id = 66;
  145 +
  146 + Subscribe sub(client->getProtocolVersion(), packet_id, topic, qos);
  147 + MqttPacket subPack(sub);
  148 + client->writeMqttPacketAndBlameThisClient(subPack);
  149 +
  150 + waitForCondition([&]() {
  151 + return !this->receivedPackets.empty() && this->receivedPackets.front().packetType == PacketType::SUBACK;
  152 + });
  153 +
  154 + MqttPacket &subAck = this->receivedPackets.front();
  155 + SubAckData data = subAck.parseSubAckData();
  156 +
  157 + if (data.packet_id != packet_id)
  158 + throw std::runtime_error("Incorrect packet id in suback");
  159 +
  160 + if (!std::all_of(data.subAckCodes.begin(), data.subAckCodes.end(), [&](uint8_t x) { return x <= qos ;}))
  161 + {
  162 + throw std::runtime_error("Suback indicates error.");
  163 + }
  164 +}
  165 +
  166 +void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos)
  167 +{
  168 + clearReceivedLists();
  169 +
  170 + const uint16_t packet_id = 77;
  171 +
  172 + Publish pub(topic, payload, qos);
  173 + MqttPacket pubPack(client->getProtocolVersion(), pub);
  174 + if (qos > 0)
  175 + pubPack.setPacketId(packet_id);
  176 + client->writeMqttPacketAndBlameThisClient(pubPack);
  177 +
  178 + if (qos == 1)
  179 + {
  180 + waitForCondition([&]() {
  181 + return this->receivedPackets.size() == 1;
  182 + });
  183 +
  184 + MqttPacket &pubAckPack = this->receivedPackets.front();
  185 + pubAckPack.parsePubAckData();
  186 +
  187 + if (pubAckPack.packetType != PacketType::PUBACK)
  188 + throw std::runtime_error("First packet received from server is not a PUBACK.");
  189 +
  190 + if (pubAckPack.getPacketId() != packet_id || this->receivedPackets.size() != 1)
  191 + throw std::runtime_error("Packet ID mismatch on QoS 1 publish or packet count wrong.");
  192 + }
  193 + else if (qos == 2)
  194 + {
  195 + waitForCondition([&]() {
  196 + return this->receivedPackets.size() >= 2;
  197 + });
  198 +
  199 + MqttPacket &pubRecPack = this->receivedPackets.front();
  200 + pubRecPack.parsePubRecData();
  201 + MqttPacket &pubCompPack = this->receivedPackets.back();
  202 + pubCompPack.parsePubComp();
  203 +
  204 + if (pubRecPack.packetType != PacketType::PUBREC)
  205 + throw std::runtime_error("First packet received from server is not a PUBREC.");
  206 +
  207 + if (pubCompPack.packetType != PacketType::PUBCOMP)
  208 + throw std::runtime_error("Last packet received from server is not a PUBCOMP.");
  209 +
  210 + if (pubRecPack.getPacketId() != packet_id || pubCompPack.getPacketId() != packet_id)
  211 + throw std::runtime_error("Packet ID mismatch on QoS 2 publish.");
  212 + }
  213 +}
  214 +
  215 +void FlashMQTestClient::waitForQuit()
  216 +{
  217 + testServerWorkerThreadData->queueQuit();
  218 + testServerWorkerThreadData->waitForQuit();
  219 +}
  220 +
  221 +void FlashMQTestClient::waitForConnack()
  222 +{
  223 + waitForCondition([&]() {
  224 + return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) {
  225 + return p.packetType == PacketType::CONNACK;
  226 + });
  227 + });
  228 +}
  229 +
  230 +void FlashMQTestClient::waitForMessageCount(const size_t count)
  231 +{
  232 + waitForCondition([&]() {
  233 + return this->receivedPublishes.size() >= count;
  234 + });
  235 +}
flashmqtestclient.h 0 โ†’ 100644
  1 +#ifndef FLASHMQTESTCLIENT_H
  2 +#define FLASHMQTESTCLIENT_H
  3 +
  4 +#include <thread>
  5 +#include <memory>
  6 +
  7 +#include "subscriptionstore.h"
  8 +
  9 +/**
  10 + * @brief The FlashMQTestClient class uses the existing server code as a client, for testing purposes.
  11 + */
  12 +class FlashMQTestClient
  13 +{
  14 + std::shared_ptr<Settings> settings;
  15 + std::shared_ptr<ThreadData> testServerWorkerThreadData;
  16 + std::shared_ptr<Client> client;
  17 +
  18 + std::shared_ptr<ThreadData> dummyThreadData;
  19 +
  20 + std::mutex receivedListMutex;
  21 +
  22 + static int clientCount;
  23 +
  24 + void waitForCondition(std::function<bool()> f);
  25 +
  26 +
  27 +public:
  28 + std::list<MqttPacket> receivedPackets;
  29 + std::list<MqttPacket> receivedPublishes;
  30 +
  31 + FlashMQTestClient();
  32 + ~FlashMQTestClient();
  33 +
  34 + void start();
  35 + void connectClient(ProtocolVersion protocolVersion);
  36 + void subscribe(const std::string topic, char qos);
  37 + void publish(const std::string &topic, const std::string &payload, char qos);
  38 + void clearReceivedLists();
  39 +
  40 + void waitForQuit();
  41 + void waitForConnack();
  42 + void waitForMessageCount(const size_t count);
  43 +};
  44 +
  45 +#endif // FLASHMQTESTCLIENT_H
mqttpacket.cpp
@@ -229,6 +229,60 @@ MqttPacket::MqttPacket(const Auth &amp;auth) : @@ -229,6 +229,60 @@ MqttPacket::MqttPacket(const Auth &amp;auth) :
229 calculateRemainingLength(); 229 calculateRemainingLength();
230 } 230 }
231 231
  232 +MqttPacket::MqttPacket(const Connect &connect) :
  233 + bites(connect.getLengthWithoutFixedHeader()),
  234 + protocolVersion(connect.protocolVersion),
  235 + packetType(PacketType::CONNECT)
  236 +{
  237 +#ifndef TESTING
  238 + throw NotImplementedException("Code is only for testing.");
  239 +#endif
  240 +
  241 + first_byte = static_cast<char>(packetType) << 4;
  242 +
  243 + const std::string magicString = connect.getMagicString();
  244 + writeString(magicString);
  245 +
  246 + writeByte(static_cast<char>(protocolVersion));
  247 + writeByte(2); // flags; The only bit set is 'clean session'.
  248 +
  249 + // Keep-alive
  250 + writeUint16(60);
  251 +
  252 + if (connect.protocolVersion >= ProtocolVersion::Mqtt5)
  253 + {
  254 + writeProperties(connect.propertyBuilder);
  255 + }
  256 +
  257 + writeString(connect.clientid);
  258 +
  259 + calculateRemainingLength();
  260 +}
  261 +
  262 +MqttPacket::MqttPacket(const Subscribe &subscribe) :
  263 + bites(subscribe.getLengthWithoutFixedHeader()),
  264 + packetType(PacketType::SUBSCRIBE)
  265 +{
  266 +#ifndef TESTING
  267 + throw NotImplementedException("Code is only for testing.");
  268 +#endif
  269 +
  270 + first_byte = static_cast<char>(packetType) << 4;
  271 + first_byte |= 2; // required reserved bit
  272 +
  273 + writeUint16(subscribe.packetId);
  274 +
  275 + if (subscribe.protocolVersion >= ProtocolVersion::Mqtt5)
  276 + {
  277 + writeProperties(subscribe.propertyBuilder);
  278 + }
  279 +
  280 + writeString(subscribe.topic);
  281 + writeByte(subscribe.qos);
  282 +
  283 + calculateRemainingLength();
  284 +}
  285 +
232 void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender) 286 void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
233 { 287 {
234 while (buf.usedBytes() >= MQTT_HEADER_LENGH) 288 while (buf.usedBytes() >= MQTT_HEADER_LENGH)
@@ -1035,6 +1089,8 @@ void MqttPacket::handleUnsubscribe() @@ -1035,6 +1089,8 @@ void MqttPacket::handleUnsubscribe()
1035 1089
1036 void MqttPacket::parsePublishData() 1090 void MqttPacket::parsePublishData()
1037 { 1091 {
  1092 + assert(externallyReceived);
  1093 +
1038 setPosToDataStart(); 1094 setPosToDataStart();
1039 1095
1040 publishData.retain = (first_byte & 0b00000001); 1096 publishData.retain = (first_byte & 0b00000001);
@@ -1214,6 +1270,7 @@ void MqttPacket::handlePublish() @@ -1214,6 +1270,7 @@ void MqttPacket::handlePublish()
1214 void MqttPacket::parsePubAckData() 1270 void MqttPacket::parsePubAckData()
1215 { 1271 {
1216 setPosToDataStart(); 1272 setPosToDataStart();
  1273 + this->publishData.qos = 1;
1217 this->packet_id = readTwoBytesToUInt16(); 1274 this->packet_id = readTwoBytesToUInt16();
1218 } 1275 }
1219 1276
@@ -1554,6 +1611,12 @@ void MqttPacket::writeVariableByteInt(const VariableByteInt &amp;v) @@ -1554,6 +1611,12 @@ void MqttPacket::writeVariableByteInt(const VariableByteInt &amp;v)
1554 writeBytes(v.data(), v.getLen()); 1611 writeBytes(v.data(), v.getLen());
1555 } 1612 }
1556 1613
  1614 +void MqttPacket::writeString(const std::string &s)
  1615 +{
  1616 + writeUint16(s.length());
  1617 + writeBytes(s.c_str(), s.length());
  1618 +}
  1619 +
1557 uint16_t MqttPacket::readTwoBytesToUInt16() 1620 uint16_t MqttPacket::readTwoBytesToUInt16()
1558 { 1621 {
1559 if (pos + 2 > bites.size()) 1622 if (pos + 2 > bites.size())
mqttpacket.h
@@ -78,6 +78,7 @@ class MqttPacket @@ -78,6 +78,7 @@ class MqttPacket
78 void writeBytes(const char *b, size_t len); 78 void writeBytes(const char *b, size_t len);
79 void writeProperties(const std::shared_ptr<Mqtt5PropertyBuilder> &properties); 79 void writeProperties(const std::shared_ptr<Mqtt5PropertyBuilder> &properties);
80 void writeVariableByteInt(const VariableByteInt &v); 80 void writeVariableByteInt(const VariableByteInt &v);
  81 + void writeString(const std::string &s);
81 uint16_t readTwoBytesToUInt16(); 82 uint16_t readTwoBytesToUInt16();
82 uint32_t readFourBytesToUint32(); 83 uint32_t readFourBytesToUint32();
83 size_t remainingAfterPos(); 84 size_t remainingAfterPos();
@@ -89,11 +90,18 @@ class MqttPacket @@ -89,11 +90,18 @@ class MqttPacket
89 void setPosToDataStart(); 90 void setPosToDataStart();
90 bool atEnd() const; 91 bool atEnd() const;
91 92
  93 +#ifndef TESTING
  94 + // In production, I want to be sure I don't accidentally copy packets, because it's slow.
92 MqttPacket(const MqttPacket &other) = delete; 95 MqttPacket(const MqttPacket &other) = delete;
  96 +#endif
93 public: 97 public:
  98 +#ifdef TESTING
  99 + // In testing I need to copy packets for administrative purposes.
  100 + MqttPacket(const MqttPacket &other) = default;
  101 +#endif
94 PacketType packetType = PacketType::Reserved; 102 PacketType packetType = PacketType::Reserved;
95 - MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr<Client> &sender); // Constructor for parsing incoming packets.  
96 103
  104 + MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr<Client> &sender); // Constructor for parsing incoming packets.
97 MqttPacket(MqttPacket &&other) = default; 105 MqttPacket(MqttPacket &&other) = default;
98 106
99 size_t setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publishData) const; 107 size_t setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publishData) const;
@@ -106,6 +114,8 @@ public: @@ -106,6 +114,8 @@ public:
106 MqttPacket(const PubResponse &pubAck); 114 MqttPacket(const PubResponse &pubAck);
107 MqttPacket(const Disconnect &disconnect); 115 MqttPacket(const Disconnect &disconnect);
108 MqttPacket(const Auth &auth); 116 MqttPacket(const Auth &auth);
  117 + MqttPacket(const Connect &connect);
  118 + MqttPacket(const Subscribe &subscribe);
109 119
110 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); 120 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
111 121
threadloop.cpp
@@ -110,6 +110,11 @@ void do_thread_work(ThreadData *threadData) @@ -110,6 +110,11 @@ void do_thread_work(ThreadData *threadData)
110 110
111 for (MqttPacket &packet : packetQueueIn) 111 for (MqttPacket &packet : packetQueueIn)
112 { 112 {
  113 +#ifdef TESTING
  114 + if (client->onPacketReceived)
  115 + client->onPacketReceived(packet);
  116 + else
  117 +#endif
113 packet.handle(); 118 packet.handle();
114 } 119 }
115 120
types.cpp
@@ -345,9 +345,57 @@ size_t Auth::getLengthWithoutFixedHeader() const @@ -345,9 +345,57 @@ size_t Auth::getLengthWithoutFixedHeader() const
345 return result; 345 return result;
346 } 346 }
347 347
  348 +Connect::Connect(ProtocolVersion protocolVersion, const std::string &clientid) :
  349 + protocolVersion(protocolVersion),
  350 + clientid(clientid)
  351 +{
  352 +
  353 +}
  354 +
  355 +size_t Connect::getLengthWithoutFixedHeader() const
  356 +{
  357 + size_t result = clientid.length() + 2;
  358 +
  359 + result += this->protocolVersion <= ProtocolVersion::Mqtt31 ? 6 : 4;
  360 + result += 6; // header stuff, lengths, keep-alive
  361 +
  362 + if (this->protocolVersion >= ProtocolVersion::Mqtt5)
  363 + {
  364 + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1;
  365 + result += proplen;
  366 + }
  367 + return result;
  368 +
  369 +}
348 370
  371 +std::string Connect::getMagicString() const
  372 +{
  373 + if (protocolVersion <= ProtocolVersion::Mqtt31)
  374 + return "MQIsdp";
  375 + else
  376 + return "MQTT";
  377 +}
349 378
  379 +Subscribe::Subscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic, char qos) :
  380 + protocolVersion(protocolVersion),
  381 + packetId(packetId),
  382 + topic(topic),
  383 + qos(qos)
  384 +{
350 385
  386 +}
351 387
  388 +size_t Subscribe::getLengthWithoutFixedHeader() const
  389 +{
  390 + size_t result = topic.size() + 2;
  391 + result += 2; // packet id
  392 + result += 1; // requested QoS
352 393
  394 + if (this->protocolVersion >= ProtocolVersion::Mqtt5)
  395 + {
  396 + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1;
  397 + result += proplen;
  398 + }
353 399
  400 + return result;
  401 +}
@@ -277,4 +277,34 @@ public: @@ -277,4 +277,34 @@ public:
277 size_t getLengthWithoutFixedHeader() const; 277 size_t getLengthWithoutFixedHeader() const;
278 }; 278 };
279 279
  280 +struct Connect
  281 +{
  282 + const ProtocolVersion protocolVersion;
  283 + std::string clientid;
  284 + std::string username;
  285 + std::string password;
  286 + std::shared_ptr<Mqtt5PropertyBuilder> propertyBuilder;
  287 +
  288 + Connect(ProtocolVersion protocolVersion, const std::string &clientid);
  289 + size_t getLengthWithoutFixedHeader() const;
  290 + std::string getMagicString() const;
  291 +};
  292 +
  293 +/**
  294 + * @brief The Subscribe struct can be used to construct a mqtt packet of type 'subscribe'.
  295 + *
  296 + * It's rudimentary. Offically you can subscribe to multiple topics at once, but I have no need for that.
  297 + */
  298 +struct Subscribe
  299 +{
  300 + const ProtocolVersion protocolVersion;
  301 + uint16_t packetId;
  302 + std::string topic;
  303 + char qos;
  304 + std::shared_ptr<Mqtt5PropertyBuilder> propertyBuilder;
  305 +
  306 + Subscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic, char qos);
  307 + size_t getLengthWithoutFixedHeader() const;
  308 +};
  309 +
280 #endif // TYPES_H 310 #endif // TYPES_H