From cd447414bb397e61e0ea293fba84a2691c8e4605 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sun, 22 May 2022 12:01:09 +0200 Subject: [PATCH] Tests for saving clientid and username for publishes --- FlashMQTests/tst_maintests.cpp | 23 ++++++++++++++++++++++- persistencefile.cpp | 19 +++++++++++++++++++ persistencefile.h | 2 ++ retainedmessagesdb.cpp | 8 ++++++++ sessionsandsubscriptionsdb.cpp | 14 ++++++++++++++ 5 files changed, 65 insertions(+), 1 deletion(-) diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 1e3dc41..17cd7af 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -934,6 +934,14 @@ void MainTests::testRetainedMessageDB() messages.emplace_back(Publish("/boe", longpayload, 1)); messages.emplace_back(Publish("one", "µsdf", 1)); + int clientidCount = 1; + int usernameCount = 1; + for (RetainedMessage &rm : messages) + { + rm.publish.client_id = formatString("Clientid__%d", clientidCount++); + rm.publish.username = formatString("Username__%d", usernameCount++); + } + RetainedMessagesDB db("/tmp/flashmqtests_retained.db"); db.openWrite(); db.saveData(messages); @@ -958,6 +966,11 @@ void MainTests::testRetainedMessageDB() QCOMPARE(one.publish.payload, two.publish.payload); QCOMPARE(one.publish.qos, two.publish.qos); + QVERIFY(!two.publish.client_id.empty()); + QVERIFY(!two.publish.username.empty()); + QCOMPARE(two.publish.client_id, one.publish.client_id); + QCOMPARE(two.publish.username, one.publish.username); + itOrg++; itLoaded++; } @@ -1060,6 +1073,8 @@ void MainTests::testSavingSessions() store->addSubscription(c2, topic4, subtopics, 0); Publish publish("a/b/c", "Hello Barry", 1); + publish.client_id = "ClientIdFromFakePublisher"; + publish.username = "UsernameFromFakePublisher"; std::shared_ptr c1ses = c1->getSession(); c1.reset(); @@ -1088,7 +1103,6 @@ void MainTests::testSavingSessions() QCOMPARE(ses->nextPacketId, ses2->nextPacketId); } - std::unordered_map> store1Subscriptions; store->getSubscriptions(&store->root, "", true, store1Subscriptions); @@ -1121,7 +1135,14 @@ void MainTests::testSavingSessions() } + std::shared_ptr loadedSes = store2->sessionsById["c1"]; + QueuedPublish queuedPublishLoaded = *loadedSes->qosPacketQueue.begin(); + QCOMPARE(queuedPublishLoaded.getPublish().topic, "a/b/c"); + QCOMPARE(queuedPublishLoaded.getPublish().payload, "Hello Barry"); + QCOMPARE(queuedPublishLoaded.getPublish().qos, 1); + QCOMPARE(queuedPublishLoaded.getPublish().client_id, "ClientIdFromFakePublisher"); + QCOMPARE(queuedPublishLoaded.getPublish().username, "UsernameFromFakePublisher"); } catch (std::exception &ex) { diff --git a/persistencefile.cpp b/persistencefile.cpp index 87dc773..a41965b 100644 --- a/persistencefile.cpp +++ b/persistencefile.cpp @@ -219,6 +219,12 @@ void PersistenceFile::writeUint16(const uint16_t val) writeCheck(buf, 1, 2, f); } +void PersistenceFile::writeString(const std::string &s) +{ + writeUint32(s.size()); + writeCheck(s.c_str(), 1, s.size(), f); +} + int64_t PersistenceFile::readInt64(bool &eofFound) { if (readCheck(buf.data(), 1, 8, f) < 0) @@ -253,6 +259,19 @@ uint16_t PersistenceFile::readUint16(bool &eofFound) return val; } +std::string PersistenceFile::readString(bool &eofFound) +{ + const uint32_t size = readUint32(eofFound); + + if (size > 0xFFFF) + throw std::runtime_error("In MQTT world, strings are never longer than 65535 bytes."); + + makeSureBufSize(size); + readCheck(buf.data(), 1, size, f); + std::string result(buf.data(), size); + return result; +} + /** * @brief RetainedMessagesDB::openWrite doesn't explicitely name a file version (v1, etc), because we always write the current definition. */ diff --git a/persistencefile.h b/persistencefile.h index 2932e39..3acc924 100644 --- a/persistencefile.h +++ b/persistencefile.h @@ -76,9 +76,11 @@ protected: void writeInt64(const int64_t val); void writeUint32(const uint32_t val); void writeUint16(const uint16_t val); + void writeString(const std::string &s); int64_t readInt64(bool &eofFound); uint32_t readUint32(bool &eofFound); uint16_t readUint16(bool &eofFound); + std::string readString(bool &eofFound); public: PersistenceFile(const std::string &filePath); diff --git a/retainedmessagesdb.cpp b/retainedmessagesdb.cpp index e401e48..3821e1a 100644 --- a/retainedmessagesdb.cpp +++ b/retainedmessagesdb.cpp @@ -88,6 +88,8 @@ void RetainedMessagesDB::saveData(const std::vector &messages) writeUint16(pack.getFixedHeaderLength()); writeUint32(packSize); + writeString(pcopy.client_id); + writeString(pcopy.username); writeCheck(cirbuf.tailPtr(), 1, cirbuf.usedBytes(), f); } @@ -136,6 +138,9 @@ std::list RetainedMessagesDB::readDataV2() const uint16_t fixed_header_length = readUint16(eofFound); const uint32_t packlen = readUint32(eofFound); + const std::string client_id = readString(eofFound); + const std::string username = readString(eofFound); + if (eofFound) continue; @@ -149,6 +154,9 @@ std::list RetainedMessagesDB::readDataV2() pack.parsePublishData(); Publish pub(pack.getPublishData()); + pub.client_id = client_id; + pub.username = username; + RetainedMessage msg(pub); logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.publish.topic.c_str(), msg.publish.qos); messages.push_back(std::move(msg)); diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp index cb39e24..9493ff9 100644 --- a/sessionsandsubscriptionsdb.cpp +++ b/sessionsandsubscriptionsdb.cpp @@ -115,6 +115,8 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() const uint16_t id = readUint16(eofFound); const uint32_t originalPubAge = readUint32(eofFound); const uint32_t packlen = readUint32(eofFound); + const std::string sender_clientid = readString(eofFound); + const std::string sender_username = readString(eofFound); assert(id > 0); @@ -128,6 +130,9 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() pack.parsePublishData(); Publish pub(pack.getPublishData()); + pub.client_id = sender_clientid; + pub.username = sender_username; + const uint32_t newPubAge = persistence_state_age + originalPubAge; pub.createdAt = timepointFromAge(newPubAge); @@ -175,6 +180,8 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() const uint32_t originalWillQueueAge = readUint32(eofFound); const uint32_t newWillDelayAfterMaybeAlreadyBeingQueued = originalWillQueueAge < originalWillDelay ? originalWillDelay - originalWillQueueAge : 0; const uint32_t packlen = readUint32(eofFound); + const std::string sender_clientid = readString(eofFound); + const std::string sender_username = readString(eofFound); const uint32_t stateAgecompensatedWillDelay = persistence_state_age > newWillDelayAfterMaybeAlreadyBeingQueued ? 0 : newWillDelayAfterMaybeAlreadyBeingQueued - persistence_state_age; @@ -189,6 +196,9 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() WillPublish willPublish = publishpack.getPublishData(); willPublish.will_delay = stateAgecompensatedWillDelay; + willPublish.client_id = sender_clientid; + willPublish.username = sender_username; + ses->setWill(std::move(willPublish)); } } @@ -288,6 +298,8 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector