Commit 56d9c5107589dfcad09af77e0f07d5b050615c72

Authored by Wiebe Cazemier
1 parent 483db298

Add packet based test client

Now that parsing and handling of packets is separated, we can use the
main code to parse packets in the new FlashMQTestClient. This allows
great flexibility in inspecting the server response in a flexible
manner.

We now also have the ability to make tests for MQTT5 features.
FlashMQTests/FlashMQTests.pro
... ... @@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \
55 55 ../globalstats.cpp \
56 56 ../derivablecounter.cpp \
57 57 ../packetdatatypes.cpp \
  58 + ../flashmqtestclient.cpp \
58 59 mainappthread.cpp \
59 60 twoclienttestcontext.cpp
60 61  
... ... @@ -102,6 +103,7 @@ HEADERS += \
102 103 ../globalstats.h \
103 104 ../derivablecounter.h \
104 105 ../packetdatatypes.h \
  106 + ../flashmqtestclient.h \
105 107 mainappthread.h \
106 108 twoclienttestcontext.h
107 109  
... ...
FlashMQTests/tst_maintests.cpp
... ... @@ -35,6 +35,8 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
35 35 #include "threaddata.h"
36 36 #include "threadglobals.h"
37 37  
  38 +#include "flashmqtestclient.h"
  39 +
38 40 // Dumb Qt version gives warnings when comparing uint with number literal.
39 41 template <typename T1, typename T2>
40 42 inline bool myCastCompare(const T1 &t1, const T2 &t2, const char *actual, const char *expected,
... ... @@ -56,6 +58,7 @@ class MainTests : public QObject
56 58 Q_OBJECT
57 59  
58 60 QScopedPointer<MainAppThread> mainApp;
  61 + std::shared_ptr<ThreadData> dummyThreadData;
59 62  
60 63 void testParsePacketHelper(const std::string &topic, char from_qos, bool retain);
61 64  
... ... @@ -117,6 +120,8 @@ private slots:
117 120  
118 121 void testUnSubscribe();
119 122  
  123 + void testBasicsWithFlashMQTestClient();
  124 +
120 125 };
121 126  
122 127 MainTests::MainTests()
... ... @@ -135,6 +140,12 @@ void MainTests::init()
135 140 mainApp.reset(new MainAppThread());
136 141 mainApp->start();
137 142 mainApp->waitForStarted();
  143 +
  144 + // We test functions directly that the server normally only calls from worker threads, in which thread data is available. This is kind of a dummy-fix, until
  145 + // we actually need correct thread data at those points (at this point, it's only to increase message counters).
  146 + std::shared_ptr<Settings> settings = std::make_shared<Settings>();
  147 + this->dummyThreadData = std::make_shared<ThreadData>(666, settings);
  148 + ThreadGlobals::assignThreadData(dummyThreadData.get());
138 149 }
139 150  
140 151 void MainTests::cleanup()
... ... @@ -1342,6 +1353,77 @@ void MainTests::testUnSubscribe()
1342 1353 }));
1343 1354 }
1344 1355  
  1356 +/**
  1357 + * @brief MainTests::testBasicsWithFlashMQTestClient was used to develop FlashMQTestClient.
  1358 + */
  1359 +void MainTests::testBasicsWithFlashMQTestClient()
  1360 +{
  1361 + FlashMQTestClient client;
  1362 + client.start();
  1363 + client.connectClient(ProtocolVersion::Mqtt311);
  1364 +
  1365 + MqttPacket &connAckPack = client.receivedPackets.front();
  1366 + QVERIFY(connAckPack.packetType == PacketType::CONNACK);
  1367 +
  1368 + {
  1369 + client.subscribe("a/b", 1);
  1370 +
  1371 + MqttPacket &subAck = client.receivedPackets.front();
  1372 + SubAckData subAckData = subAck.parseSubAckData();
  1373 +
  1374 + QVERIFY(subAckData.subAckCodes.size() == 1);
  1375 + QVERIFY(subAckData.subAckCodes.front() == 1);
  1376 + }
  1377 +
  1378 + {
  1379 + client.subscribe("c/d", 2);
  1380 +
  1381 + MqttPacket &subAck = client.receivedPackets.front();
  1382 + SubAckData subAckData = subAck.parseSubAckData();
  1383 +
  1384 + QVERIFY(subAckData.subAckCodes.size() == 1);
  1385 + QVERIFY(subAckData.subAckCodes.front() == 2);
  1386 + }
  1387 +
  1388 + client.clearReceivedLists();
  1389 +
  1390 + FlashMQTestClient publisher;
  1391 + publisher.start();
  1392 + publisher.connectClient(ProtocolVersion::Mqtt5);
  1393 +
  1394 + {
  1395 + publisher.publish("a/b", "wave", 2);
  1396 +
  1397 + client.waitForMessageCount(1);
  1398 + MqttPacket &p = client.receivedPublishes.front();
  1399 +
  1400 + QCOMPARE(p.getPublishData().topic, "a/b");
  1401 + QCOMPARE(p.getPayloadCopy(), "wave");
  1402 + QCOMPARE(p.getPublishData().qos, 1);
  1403 + QVERIFY(p.getPacketId() > 0);
  1404 + QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311);
  1405 + }
  1406 +
  1407 + client.clearReceivedLists();
  1408 +
  1409 + {
  1410 + publisher.publish("c/d", "asdfasdfasdf", 2);
  1411 +
  1412 + client.waitForMessageCount(1);
  1413 + MqttPacket &p = client.receivedPublishes.back();
  1414 +
  1415 + MYCASTCOMPARE(client.receivedPublishes.size(), 1);
  1416 +
  1417 + QCOMPARE(p.getPublishData().topic, "c/d");
  1418 + QCOMPARE(p.getPayloadCopy(), "asdfasdfasdf");
  1419 + QCOMPARE(p.getPublishData().qos, 2);
  1420 + QVERIFY(p.getPacketId() > 1); // It's the same client, so it should not re-use packet id 1
  1421 + QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311);
  1422 + }
  1423 +
  1424 +
  1425 +}
  1426 +
1345 1427  
1346 1428 int main(int argc, char *argv[])
1347 1429 {
... ...
client.cpp
... ... @@ -62,7 +62,7 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we
62 62 Client::~Client()
63 63 {
64 64 // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread.
65   - if (this->threadData.expired())
  65 + if (this->epoll_fd == 0)
66 66 return;
67 67  
68 68 if (disconnectReason.empty())
... ...
client.h
... ... @@ -187,6 +187,10 @@ public:
187 187 void setExtendedAuthenticationMethod(const std::string &authMethod);
188 188 const std::string &getExtendedAuthenticationMethod() const;
189 189  
  190 +#ifdef TESTING
  191 + std::function<void(MqttPacket &packet)> onPacketReceived;
  192 +#endif
  193 +
190 194 #ifndef NDEBUG
191 195 void setFakeUpgraded();
192 196 #endif
... ...
flashmqtestclient.cpp 0 โ†’ 100644
  1 +#include "flashmqtestclient.h"
  2 +
  3 +#include <sys/epoll.h>
  4 +#include <cstring>
  5 +#include "errno.h"
  6 +#include "functional"
  7 +
  8 +#include "threadloop.h"
  9 +#include "utils.h"
  10 +#include "client.h"
  11 +
  12 +
  13 +#define TEST_CLIENT_MAX_EVENTS 25
  14 +
  15 +int FlashMQTestClient::clientCount = 0;
  16 +
  17 +FlashMQTestClient::FlashMQTestClient() :
  18 + settings(std::make_shared<Settings>()),
  19 + testServerWorkerThreadData(std::make_shared<ThreadData>(0, settings)),
  20 + dummyThreadData(std::make_shared<ThreadData>(666, settings))
  21 +{
  22 +
  23 +}
  24 +
  25 +/**
  26 + * @brief FlashMQTestClient::~FlashMQTestClient properly quits the threads when exiting.
  27 + *
  28 + * This prevents accidental crashes on calling terminate(), and Qt Macro's prematurely end the method, skipping explicit waits after the tests.
  29 + */
  30 +FlashMQTestClient::~FlashMQTestClient()
  31 +{
  32 + waitForQuit();
  33 +}
  34 +
  35 +void FlashMQTestClient::waitForCondition(std::function<bool()> f)
  36 +{
  37 + int n = 0;
  38 + while(n++ < 100)
  39 + {
  40 + usleep(10000);
  41 +
  42 + std::lock_guard<std::mutex> locker(receivedListMutex);
  43 +
  44 + if (f())
  45 + break;
  46 + }
  47 +
  48 + std::lock_guard<std::mutex> locker(receivedListMutex);
  49 +
  50 + if (!f())
  51 + {
  52 + throw std::runtime_error("Wait condition failed.");
  53 + }
  54 +}
  55 +
  56 +void FlashMQTestClient::clearReceivedLists()
  57 +{
  58 + std::lock_guard<std::mutex> locker(receivedListMutex);
  59 + receivedPackets.clear();
  60 + receivedPublishes.clear();
  61 +}
  62 +
  63 +void FlashMQTestClient::start()
  64 +{
  65 + testServerWorkerThreadData->start(&do_thread_work);
  66 +}
  67 +
  68 +void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion)
  69 +{
  70 + int sockfd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));
  71 +
  72 + struct sockaddr_in servaddr;
  73 + bzero(&servaddr, sizeof(servaddr));
  74 +
  75 + const std::string hostname = "127.0.0.1";
  76 +
  77 + servaddr.sin_family = AF_INET;
  78 + servaddr.sin_addr.s_addr = inet_addr(hostname.c_str());
  79 + servaddr.sin_port = htons(1883);
  80 +
  81 + int flags = fcntl(sockfd, F_GETFL);
  82 + fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
  83 +
  84 + int rc = connect(sockfd, reinterpret_cast<struct sockaddr*>(&servaddr), sizeof (servaddr));
  85 +
  86 + if (rc < 0 && errno != EINPROGRESS)
  87 + {
  88 + throw std::runtime_error(strerror(errno));
  89 + }
  90 +
  91 + const std::string clientid = formatString("testclient_%d", clientCount++);
  92 +
  93 + this->client = std::make_shared<Client>(sockfd, testServerWorkerThreadData, nullptr, false, reinterpret_cast<struct sockaddr*>(&servaddr), settings.get());
  94 + this->client->setClientProperties(protocolVersion, clientid, "user", false, 60);
  95 +
  96 + testServerWorkerThreadData->giveClient(this->client);
  97 +
  98 + // This gets called in the test client's worker thread, but the STL container's minimal thread safety should be enough: only list manipulation is
  99 + // mutexed, elements within are not.
  100 + client->onPacketReceived = [&](MqttPacket &pack)
  101 + {
  102 + std::lock_guard<std::mutex> locker(receivedListMutex);
  103 +
  104 + if (pack.packetType == PacketType::PUBLISH)
  105 + {
  106 + pack.parsePublishData();
  107 +
  108 + MqttPacket copyPacket = pack;
  109 + this->receivedPublishes.push_back(copyPacket);
  110 +
  111 + if (pack.getPublishData().qos > 0)
  112 + {
  113 + PubResponse pubAck(this->client->getProtocolVersion(), PacketType::PUBACK, ReasonCodes::Success, pack.getPacketId());
  114 + this->client->writeMqttPacketAndBlameThisClient(pubAck);
  115 + }
  116 + }
  117 + else if (pack.packetType == PacketType::PUBREL)
  118 + {
  119 + pack.parsePubRelData();
  120 + PubResponse pubComp(this->client->getProtocolVersion(), PacketType::PUBCOMP, ReasonCodes::Success, pack.getPacketId());
  121 + this->client->writeMqttPacketAndBlameThisClient(pubComp);
  122 + }
  123 + else if (pack.packetType == PacketType::PUBREC)
  124 + {
  125 + pack.parsePubRecData();
  126 + PubResponse pubRel(this->client->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, pack.getPacketId());
  127 + this->client->writeMqttPacketAndBlameThisClient(pubRel);
  128 + }
  129 +
  130 + this->receivedPackets.push_back(std::move(pack));
  131 + };
  132 +
  133 + Connect connect(protocolVersion, client->getClientId());
  134 + MqttPacket connectPack(connect);
  135 + this->client->writeMqttPacketAndBlameThisClient(connectPack);
  136 +
  137 + waitForConnack();
  138 +}
  139 +
  140 +void FlashMQTestClient::subscribe(const std::string topic, char qos)
  141 +{
  142 + clearReceivedLists();
  143 +
  144 + const uint16_t packet_id = 66;
  145 +
  146 + Subscribe sub(client->getProtocolVersion(), packet_id, topic, qos);
  147 + MqttPacket subPack(sub);
  148 + client->writeMqttPacketAndBlameThisClient(subPack);
  149 +
  150 + waitForCondition([&]() {
  151 + return !this->receivedPackets.empty() && this->receivedPackets.front().packetType == PacketType::SUBACK;
  152 + });
  153 +
  154 + MqttPacket &subAck = this->receivedPackets.front();
  155 + SubAckData data = subAck.parseSubAckData();
  156 +
  157 + if (data.packet_id != packet_id)
  158 + throw std::runtime_error("Incorrect packet id in suback");
  159 +
  160 + if (!std::all_of(data.subAckCodes.begin(), data.subAckCodes.end(), [&](uint8_t x) { return x <= qos ;}))
  161 + {
  162 + throw std::runtime_error("Suback indicates error.");
  163 + }
  164 +}
  165 +
  166 +void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos)
  167 +{
  168 + clearReceivedLists();
  169 +
  170 + const uint16_t packet_id = 77;
  171 +
  172 + Publish pub(topic, payload, qos);
  173 + MqttPacket pubPack(client->getProtocolVersion(), pub);
  174 + if (qos > 0)
  175 + pubPack.setPacketId(packet_id);
  176 + client->writeMqttPacketAndBlameThisClient(pubPack);
  177 +
  178 + if (qos == 1)
  179 + {
  180 + waitForCondition([&]() {
  181 + return this->receivedPackets.size() == 1;
  182 + });
  183 +
  184 + MqttPacket &pubAckPack = this->receivedPackets.front();
  185 + pubAckPack.parsePubAckData();
  186 +
  187 + if (pubAckPack.packetType != PacketType::PUBACK)
  188 + throw std::runtime_error("First packet received from server is not a PUBACK.");
  189 +
  190 + if (pubAckPack.getPacketId() != packet_id || this->receivedPackets.size() != 1)
  191 + throw std::runtime_error("Packet ID mismatch on QoS 1 publish or packet count wrong.");
  192 + }
  193 + else if (qos == 2)
  194 + {
  195 + waitForCondition([&]() {
  196 + return this->receivedPackets.size() >= 2;
  197 + });
  198 +
  199 + MqttPacket &pubRecPack = this->receivedPackets.front();
  200 + pubRecPack.parsePubRecData();
  201 + MqttPacket &pubCompPack = this->receivedPackets.back();
  202 + pubCompPack.parsePubComp();
  203 +
  204 + if (pubRecPack.packetType != PacketType::PUBREC)
  205 + throw std::runtime_error("First packet received from server is not a PUBREC.");
  206 +
  207 + if (pubCompPack.packetType != PacketType::PUBCOMP)
  208 + throw std::runtime_error("Last packet received from server is not a PUBCOMP.");
  209 +
  210 + if (pubRecPack.getPacketId() != packet_id || pubCompPack.getPacketId() != packet_id)
  211 + throw std::runtime_error("Packet ID mismatch on QoS 2 publish.");
  212 + }
  213 +}
  214 +
  215 +void FlashMQTestClient::waitForQuit()
  216 +{
  217 + testServerWorkerThreadData->queueQuit();
  218 + testServerWorkerThreadData->waitForQuit();
  219 +}
  220 +
  221 +void FlashMQTestClient::waitForConnack()
  222 +{
  223 + waitForCondition([&]() {
  224 + return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) {
  225 + return p.packetType == PacketType::CONNACK;
  226 + });
  227 + });
  228 +}
  229 +
  230 +void FlashMQTestClient::waitForMessageCount(const size_t count)
  231 +{
  232 + waitForCondition([&]() {
  233 + return this->receivedPublishes.size() >= count;
  234 + });
  235 +}
... ...
flashmqtestclient.h 0 โ†’ 100644
  1 +#ifndef FLASHMQTESTCLIENT_H
  2 +#define FLASHMQTESTCLIENT_H
  3 +
  4 +#include <thread>
  5 +#include <memory>
  6 +
  7 +#include "subscriptionstore.h"
  8 +
  9 +/**
  10 + * @brief The FlashMQTestClient class uses the existing server code as a client, for testing purposes.
  11 + */
  12 +class FlashMQTestClient
  13 +{
  14 + std::shared_ptr<Settings> settings;
  15 + std::shared_ptr<ThreadData> testServerWorkerThreadData;
  16 + std::shared_ptr<Client> client;
  17 +
  18 + std::shared_ptr<ThreadData> dummyThreadData;
  19 +
  20 + std::mutex receivedListMutex;
  21 +
  22 + static int clientCount;
  23 +
  24 + void waitForCondition(std::function<bool()> f);
  25 +
  26 +
  27 +public:
  28 + std::list<MqttPacket> receivedPackets;
  29 + std::list<MqttPacket> receivedPublishes;
  30 +
  31 + FlashMQTestClient();
  32 + ~FlashMQTestClient();
  33 +
  34 + void start();
  35 + void connectClient(ProtocolVersion protocolVersion);
  36 + void subscribe(const std::string topic, char qos);
  37 + void publish(const std::string &topic, const std::string &payload, char qos);
  38 + void clearReceivedLists();
  39 +
  40 + void waitForQuit();
  41 + void waitForConnack();
  42 + void waitForMessageCount(const size_t count);
  43 +};
  44 +
  45 +#endif // FLASHMQTESTCLIENT_H
... ...
mqttpacket.cpp
... ... @@ -229,6 +229,60 @@ MqttPacket::MqttPacket(const Auth &amp;auth) :
229 229 calculateRemainingLength();
230 230 }
231 231  
  232 +MqttPacket::MqttPacket(const Connect &connect) :
  233 + bites(connect.getLengthWithoutFixedHeader()),
  234 + protocolVersion(connect.protocolVersion),
  235 + packetType(PacketType::CONNECT)
  236 +{
  237 +#ifndef TESTING
  238 + throw NotImplementedException("Code is only for testing.");
  239 +#endif
  240 +
  241 + first_byte = static_cast<char>(packetType) << 4;
  242 +
  243 + const std::string magicString = connect.getMagicString();
  244 + writeString(magicString);
  245 +
  246 + writeByte(static_cast<char>(protocolVersion));
  247 + writeByte(2); // flags; The only bit set is 'clean session'.
  248 +
  249 + // Keep-alive
  250 + writeUint16(60);
  251 +
  252 + if (connect.protocolVersion >= ProtocolVersion::Mqtt5)
  253 + {
  254 + writeProperties(connect.propertyBuilder);
  255 + }
  256 +
  257 + writeString(connect.clientid);
  258 +
  259 + calculateRemainingLength();
  260 +}
  261 +
  262 +MqttPacket::MqttPacket(const Subscribe &subscribe) :
  263 + bites(subscribe.getLengthWithoutFixedHeader()),
  264 + packetType(PacketType::SUBSCRIBE)
  265 +{
  266 +#ifndef TESTING
  267 + throw NotImplementedException("Code is only for testing.");
  268 +#endif
  269 +
  270 + first_byte = static_cast<char>(packetType) << 4;
  271 + first_byte |= 2; // required reserved bit
  272 +
  273 + writeUint16(subscribe.packetId);
  274 +
  275 + if (subscribe.protocolVersion >= ProtocolVersion::Mqtt5)
  276 + {
  277 + writeProperties(subscribe.propertyBuilder);
  278 + }
  279 +
  280 + writeString(subscribe.topic);
  281 + writeByte(subscribe.qos);
  282 +
  283 + calculateRemainingLength();
  284 +}
  285 +
232 286 void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
233 287 {
234 288 while (buf.usedBytes() >= MQTT_HEADER_LENGH)
... ... @@ -1035,6 +1089,8 @@ void MqttPacket::handleUnsubscribe()
1035 1089  
1036 1090 void MqttPacket::parsePublishData()
1037 1091 {
  1092 + assert(externallyReceived);
  1093 +
1038 1094 setPosToDataStart();
1039 1095  
1040 1096 publishData.retain = (first_byte & 0b00000001);
... ... @@ -1214,6 +1270,7 @@ void MqttPacket::handlePublish()
1214 1270 void MqttPacket::parsePubAckData()
1215 1271 {
1216 1272 setPosToDataStart();
  1273 + this->publishData.qos = 1;
1217 1274 this->packet_id = readTwoBytesToUInt16();
1218 1275 }
1219 1276  
... ... @@ -1554,6 +1611,12 @@ void MqttPacket::writeVariableByteInt(const VariableByteInt &amp;v)
1554 1611 writeBytes(v.data(), v.getLen());
1555 1612 }
1556 1613  
  1614 +void MqttPacket::writeString(const std::string &s)
  1615 +{
  1616 + writeUint16(s.length());
  1617 + writeBytes(s.c_str(), s.length());
  1618 +}
  1619 +
1557 1620 uint16_t MqttPacket::readTwoBytesToUInt16()
1558 1621 {
1559 1622 if (pos + 2 > bites.size())
... ...
mqttpacket.h
... ... @@ -78,6 +78,7 @@ class MqttPacket
78 78 void writeBytes(const char *b, size_t len);
79 79 void writeProperties(const std::shared_ptr<Mqtt5PropertyBuilder> &properties);
80 80 void writeVariableByteInt(const VariableByteInt &v);
  81 + void writeString(const std::string &s);
81 82 uint16_t readTwoBytesToUInt16();
82 83 uint32_t readFourBytesToUint32();
83 84 size_t remainingAfterPos();
... ... @@ -89,11 +90,18 @@ class MqttPacket
89 90 void setPosToDataStart();
90 91 bool atEnd() const;
91 92  
  93 +#ifndef TESTING
  94 + // In production, I want to be sure I don't accidentally copy packets, because it's slow.
92 95 MqttPacket(const MqttPacket &other) = delete;
  96 +#endif
93 97 public:
  98 +#ifdef TESTING
  99 + // In testing I need to copy packets for administrative purposes.
  100 + MqttPacket(const MqttPacket &other) = default;
  101 +#endif
94 102 PacketType packetType = PacketType::Reserved;
95   - MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr<Client> &sender); // Constructor for parsing incoming packets.
96 103  
  104 + MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr<Client> &sender); // Constructor for parsing incoming packets.
97 105 MqttPacket(MqttPacket &&other) = default;
98 106  
99 107 size_t setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publishData) const;
... ... @@ -106,6 +114,8 @@ public:
106 114 MqttPacket(const PubResponse &pubAck);
107 115 MqttPacket(const Disconnect &disconnect);
108 116 MqttPacket(const Auth &auth);
  117 + MqttPacket(const Connect &connect);
  118 + MqttPacket(const Subscribe &subscribe);
109 119  
110 120 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
111 121  
... ...
threadloop.cpp
... ... @@ -110,6 +110,11 @@ void do_thread_work(ThreadData *threadData)
110 110  
111 111 for (MqttPacket &packet : packetQueueIn)
112 112 {
  113 +#ifdef TESTING
  114 + if (client->onPacketReceived)
  115 + client->onPacketReceived(packet);
  116 + else
  117 +#endif
113 118 packet.handle();
114 119 }
115 120  
... ...
types.cpp
... ... @@ -345,9 +345,57 @@ size_t Auth::getLengthWithoutFixedHeader() const
345 345 return result;
346 346 }
347 347  
  348 +Connect::Connect(ProtocolVersion protocolVersion, const std::string &clientid) :
  349 + protocolVersion(protocolVersion),
  350 + clientid(clientid)
  351 +{
  352 +
  353 +}
  354 +
  355 +size_t Connect::getLengthWithoutFixedHeader() const
  356 +{
  357 + size_t result = clientid.length() + 2;
  358 +
  359 + result += this->protocolVersion <= ProtocolVersion::Mqtt31 ? 6 : 4;
  360 + result += 6; // header stuff, lengths, keep-alive
  361 +
  362 + if (this->protocolVersion >= ProtocolVersion::Mqtt5)
  363 + {
  364 + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1;
  365 + result += proplen;
  366 + }
  367 + return result;
  368 +
  369 +}
348 370  
  371 +std::string Connect::getMagicString() const
  372 +{
  373 + if (protocolVersion <= ProtocolVersion::Mqtt31)
  374 + return "MQIsdp";
  375 + else
  376 + return "MQTT";
  377 +}
349 378  
  379 +Subscribe::Subscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic, char qos) :
  380 + protocolVersion(protocolVersion),
  381 + packetId(packetId),
  382 + topic(topic),
  383 + qos(qos)
  384 +{
350 385  
  386 +}
351 387  
  388 +size_t Subscribe::getLengthWithoutFixedHeader() const
  389 +{
  390 + size_t result = topic.size() + 2;
  391 + result += 2; // packet id
  392 + result += 1; // requested QoS
352 393  
  394 + if (this->protocolVersion >= ProtocolVersion::Mqtt5)
  395 + {
  396 + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1;
  397 + result += proplen;
  398 + }
353 399  
  400 + return result;
  401 +}
... ...
... ... @@ -277,4 +277,34 @@ public:
277 277 size_t getLengthWithoutFixedHeader() const;
278 278 };
279 279  
  280 +struct Connect
  281 +{
  282 + const ProtocolVersion protocolVersion;
  283 + std::string clientid;
  284 + std::string username;
  285 + std::string password;
  286 + std::shared_ptr<Mqtt5PropertyBuilder> propertyBuilder;
  287 +
  288 + Connect(ProtocolVersion protocolVersion, const std::string &clientid);
  289 + size_t getLengthWithoutFixedHeader() const;
  290 + std::string getMagicString() const;
  291 +};
  292 +
  293 +/**
  294 + * @brief The Subscribe struct can be used to construct a mqtt packet of type 'subscribe'.
  295 + *
  296 + * It's rudimentary. Offically you can subscribe to multiple topics at once, but I have no need for that.
  297 + */
  298 +struct Subscribe
  299 +{
  300 + const ProtocolVersion protocolVersion;
  301 + uint16_t packetId;
  302 + std::string topic;
  303 + char qos;
  304 + std::shared_ptr<Mqtt5PropertyBuilder> propertyBuilder;
  305 +
  306 + Subscribe(const ProtocolVersion protocolVersion, uint16_t packetId, const std::string &topic, char qos);
  307 + size_t getLengthWithoutFixedHeader() const;
  308 +};
  309 +
280 310 #endif // TYPES_H
... ...