Commit 75657bf52b8203d4483c5de0b86ba560d7f3221d

Authored by Wiebe Cazemier
1 parent 515f796f

Convert to storing Publish object for retained message

This allows easier saving of MQTT5 properties, for which a new file
version for retained messages is created. It uses the packet parsing logic.
FlashMQTests/tst_maintests.cpp
... ... @@ -895,20 +895,19 @@ void MainTests::testRetainedMessageDB()
895 895 std::string longTopic = formatString("one/two/%s", getSecureRandomString(4000).c_str());
896 896  
897 897 std::vector<RetainedMessage> messages;
898   - messages.emplace_back("one/two/three", "payload", 0);
899   - messages.emplace_back("one/two/wer", "payload", 1);
900   - messages.emplace_back("one/e/wer", "payload", 1);
901   - messages.emplace_back("one/wee/wer", "asdfasdfasdf", 1);
902   - messages.emplace_back("one/two/wer", "µsdf", 1);
903   - messages.emplace_back("/boe/bah", longpayload, 1);
904   - messages.emplace_back("one/two/wer", "paylasdfaoad", 1);
905   - messages.emplace_back("one/two/wer", "payload", 1);
906   - messages.emplace_back(longTopic, "payload", 1);
907   - messages.emplace_back(longTopic, longpayload, 1);
908   - messages.emplace_back("one", "µsdf", 1);
909   - messages.emplace_back("/boe", longpayload, 1);
910   - messages.emplace_back("one", "µsdf", 1);
911   - messages.emplace_back("", "foremptytopic", 0);
  898 + messages.emplace_back(Publish("one/two/three", "payload", 0));
  899 + messages.emplace_back(Publish("one/two/wer", "payload", 1));
  900 + messages.emplace_back(Publish("one/e/wer", "payload", 1));
  901 + messages.emplace_back(Publish("one/wee/wer", "asdfasdfasdf", 1));
  902 + messages.emplace_back(Publish("one/two/wer", "µsdf", 1));
  903 + messages.emplace_back(Publish("/boe/bah", longpayload, 1));
  904 + messages.emplace_back(Publish("one/two/wer", "paylasdfaoad", 1));
  905 + messages.emplace_back(Publish("one/two/wer", "payload", 1));
  906 + messages.emplace_back(Publish(longTopic, "payload", 1));
  907 + messages.emplace_back(Publish(longTopic, longpayload, 1));
  908 + messages.emplace_back(Publish("one", "µsdf", 1));
  909 + messages.emplace_back(Publish("/boe", longpayload, 1));
  910 + messages.emplace_back(Publish("one", "µsdf", 1));
912 911  
913 912 RetainedMessagesDB db("/tmp/flashmqtests_retained.db");
914 913 db.openWrite();
... ... @@ -920,7 +919,7 @@ void MainTests::testRetainedMessageDB()
920 919 std::list<RetainedMessage> messagesLoaded = db2.readData();
921 920 db2.closeFile();
922 921  
923   - QCOMPARE(messages.size(), messagesLoaded.size());
  922 + QCOMPARE(messagesLoaded.size(), messages.size());
924 923  
925 924 auto itOrg = messages.begin();
926 925 auto itLoaded = messagesLoaded.begin();
... ... @@ -930,9 +929,9 @@ void MainTests::testRetainedMessageDB()
930 929 RetainedMessage &two = *itLoaded;
931 930  
932 931 // Comparing the fields because the RetainedMessage class has an == operator that only looks at topic.
933   - QCOMPARE(one.topic, two.topic);
934   - QCOMPARE(one.payload, two.payload);
935   - QCOMPARE(one.qos, two.qos);
  932 + QCOMPARE(one.publish.topic, two.publish.topic);
  933 + QCOMPARE(one.publish.payload, two.publish.payload);
  934 + QCOMPARE(one.publish.qos, two.publish.qos);
936 935  
937 936 itOrg++;
938 937 itLoaded++;
... ...
mqttpacket.cpp
... ... @@ -1017,8 +1017,8 @@ void MqttPacket::handlePublish()
1017 1017 {
1018 1018 if (publishData.retain)
1019 1019 {
1020   - std::string payload(readBytes(payloadLen), payloadLen);
1021   - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, publishData.qos);
  1020 + publishData.payload = getPayloadCopy();
  1021 + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData, publishData.subtopics);
1022 1022 }
1023 1023  
1024 1024 // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
... ... @@ -1420,6 +1420,7 @@ bool MqttPacket::containsClientSpecificProperties() const
1420 1420 void MqttPacket::readIntoBuf(CirBuf &buf) const
1421 1421 {
1422 1422 assert(packetType != PacketType::PUBLISH || (first_byte & 0b00000110) >> 1 == publishData.qos);
  1423 + assert(publishData.qos == 0 || packet_id > 0);
1423 1424  
1424 1425 buf.ensureFreeSpace(getSizeIncludingNonPresentHeader());
1425 1426  
... ...
retainedmessage.cpp
... ... @@ -17,25 +17,24 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
17 17  
18 18 #include "retainedmessage.h"
19 19  
20   -RetainedMessage::RetainedMessage(const std::string &topic, const std::string &payload, char qos) :
21   - topic(topic),
22   - payload(payload),
23   - qos(qos)
  20 +RetainedMessage::RetainedMessage(const Publish &publish) :
  21 + publish(publish)
24 22 {
25   -
  23 + this->publish.retain = true;
  24 + this->publish.splitTopic = false;
26 25 }
27 26  
28 27 bool RetainedMessage::operator==(const RetainedMessage &rhs) const
29 28 {
30   - return this->topic == rhs.topic;
  29 + return this->publish.topic == rhs.publish.topic;
31 30 }
32 31  
33 32 bool RetainedMessage::empty() const
34 33 {
35   - return payload.empty();
  34 + return publish.payload.empty();
36 35 }
37 36  
38 37 uint32_t RetainedMessage::getSize() const
39 38 {
40   - return topic.length() + payload.length() + 1;
  39 + return publish.topic.length() + publish.payload.length() + 1;
41 40 }
... ...
retainedmessage.h
... ... @@ -19,14 +19,13 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
19 19 #define RETAINEDMESSAGE_H
20 20  
21 21 #include <string>
  22 +#include "types.h"
22 23  
23 24 struct RetainedMessage
24 25 {
25   - std::string topic;
26   - std::string payload;
27   - char qos;
  26 + Publish publish;
28 27  
29   - RetainedMessage(const std::string &topic, const std::string &payload, char qos);
  28 + RetainedMessage(const Publish &publish);
30 29  
31 30 bool operator==(const RetainedMessage &rhs) const;
32 31 bool empty() const;
... ... @@ -44,7 +43,7 @@ namespace std {
44 43 using std::hash;
45 44 using std::string;
46 45  
47   - return hash<string>()(k.topic);
  46 + return hash<string>()(k.publish.topic);
48 47 }
49 48 };
50 49  
... ...
retainedmessagesdb.cpp
... ... @@ -27,6 +27,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
27 27 #include "retainedmessagesdb.h"
28 28 #include "utils.h"
29 29 #include "logger.h"
  30 +#include "mqttpacket.h"
  31 +#include "threadglobals.h"
30 32  
31 33 RetainedMessagesDB::RetainedMessagesDB(const std::string &filePath) : PersistenceFile(filePath)
32 34 {
... ... @@ -35,7 +37,7 @@ RetainedMessagesDB::RetainedMessagesDB(const std::string &amp;filePath) : Persistenc
35 37  
36 38 void RetainedMessagesDB::openWrite()
37 39 {
38   - PersistenceFile::openWrite(MAGIC_STRING_V1);
  40 + PersistenceFile::openWrite(MAGIC_STRING_V2);
39 41 }
40 42  
41 43 void RetainedMessagesDB::openRead()
... ... @@ -44,35 +46,13 @@ void RetainedMessagesDB::openRead()
44 46  
45 47 if (detectedVersionString == MAGIC_STRING_V1)
46 48 readVersion = ReadVersion::v1;
  49 + else if (detectedVersionString == MAGIC_STRING_V2)
  50 + readVersion = ReadVersion::v2;
47 51 else
48 52 throw std::runtime_error("Unknown file version.");
49 53 }
50 54  
51 55 /**
52   - * @brief RetainedMessagesDB::writeRowHeader writes two 32 bit integers: topic size and payload size.
53   - * @param rm
54   - *
55   - * So, the header per message is 8 bytes long.
56   - *
57   - * It writes no information about the length of the QoS value, because that is always one.
58   - */
59   -void RetainedMessagesDB::writeRowHeader(const RetainedMessage &rm)
60   -{
61   - writeUint32(rm.topic.size());
62   - writeUint32(rm.payload.size());
63   -}
64   -
65   -RetainedMessagesDB::RowHeader RetainedMessagesDB::readRowHeaderV1(bool &eofFound)
66   -{
67   - RetainedMessagesDB::RowHeader result;
68   -
69   - result.topicLen = readUint32(eofFound);
70   - result.payloadLen = readUint32(eofFound);
71   -
72   - return result;
73   -}
74   -
75   -/**
76 56 * @brief RetainedMessagesDB::saveData doesn't explicitely name a file version (v1, etc), because we always write the current definition.
77 57 * @param messages
78 58 */
... ... @@ -81,20 +61,34 @@ void RetainedMessagesDB::saveData(const std::vector&lt;RetainedMessage&gt; &amp;messages)
81 61 if (!f)
82 62 return;
83 63  
84   - char reserved[RESERVED_SPACE_RETAINED_DB_V1];
85   - std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V1);
  64 + CirBuf cirbuf(1024);
  65 +
  66 + writeUint32(messages.size());
  67 +
  68 + char reserved[RESERVED_SPACE_RETAINED_DB_V2];
  69 + std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V2);
  70 + writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V2, f);
86 71  
87   - char qos = 0;
88 72 for (const RetainedMessage &rm : messages)
89 73 {
90   - logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d.", rm.topic.c_str(), rm.qos);
91   -
92   - writeRowHeader(rm);
93   - qos = rm.qos;
94   - writeCheck(&qos, 1, 1, f);
95   - writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V1, f);
96   - writeCheck(rm.topic.c_str(), 1, rm.topic.length(), f);
97   - writeCheck(rm.payload.c_str(), 1, rm.payload.length(), f);
  74 + logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d.", rm.publish.topic.c_str(), rm.publish.qos);
  75 +
  76 + Publish pcopy(rm.publish);
  77 + MqttPacket pack(ProtocolVersion::Mqtt5, pcopy);
  78 +
  79 + // Dummy, to please the parser on reading.
  80 + if (pcopy.qos > 0)
  81 + pack.setPacketId(666);
  82 +
  83 + const uint32_t packSize = pack.getSizeIncludingNonPresentHeader();
  84 +
  85 + cirbuf.reset();
  86 + cirbuf.ensureFreeSpace(packSize + 32);
  87 + pack.readIntoBuf(cirbuf);
  88 +
  89 + writeUint16(pack.getFixedHeaderLength());
  90 + writeUint32(packSize);
  91 + writeCheck(cirbuf.tailPtr(), 1, cirbuf.usedBytes(), f);
98 92 }
99 93  
100 94 fflush(f);
... ... @@ -108,38 +102,57 @@ std::list&lt;RetainedMessage&gt; RetainedMessagesDB::readData()
108 102 return defaultResult;
109 103  
110 104 if (readVersion == ReadVersion::v1)
111   - return readDataV1();
  105 + logger->logf(LOG_WARNING, "File '%s' is version 1, an internal development version that was never finalized. Not reading.", getFilePath().c_str());
  106 + if (readVersion == ReadVersion::v2)
  107 + return readDataV2();
112 108  
113 109 return defaultResult;
114 110 }
115 111  
116   -std::list<RetainedMessage> RetainedMessagesDB::readDataV1()
  112 +std::list<RetainedMessage> RetainedMessagesDB::readDataV2()
117 113 {
118 114 std::list<RetainedMessage> messages;
119 115  
  116 + CirBuf cirbuf(1024);
  117 +
  118 + const Settings *settings = ThreadGlobals::getSettings();
  119 + std::shared_ptr<ThreadData> dummyThreadData;
  120 + std::shared_ptr<Client> dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, settings, false));
  121 + dummyClient->setClientProperties(ProtocolVersion::Mqtt5, "Dummyforloadingretained", "nobody", true, 60);
  122 +
120 123 while (!feof(f))
121 124 {
122 125 bool eofFound = false;
123   - RetainedMessagesDB::RowHeader header = readRowHeaderV1(eofFound);
  126 +
  127 + const uint32_t numberOfMessages = readUint32(eofFound);
124 128  
125 129 if (eofFound)
126 130 continue;
127 131  
128   - makeSureBufSize(header.payloadLen);
  132 + fseek(f, RESERVED_SPACE_RETAINED_DB_V2, SEEK_CUR);
  133 +
  134 + for(uint32_t i = 0; i < numberOfMessages; i++)
  135 + {
  136 + const uint16_t fixed_header_length = readUint16(eofFound);
  137 + const uint32_t packlen = readUint32(eofFound);
  138 +
  139 + if (eofFound)
  140 + continue;
129 141  
130   - readCheck(buf.data(), 1, 1, f);
131   - char qos = buf[0];
132   - fseek(f, RESERVED_SPACE_RETAINED_DB_V1, SEEK_CUR);
  142 + cirbuf.reset();
  143 + cirbuf.ensureFreeSpace(packlen + 32);
133 144  
134   - readCheck(buf.data(), 1, header.topicLen, f);
135   - std::string topic(buf.data(), header.topicLen);
  145 + readCheck(cirbuf.headPtr(), 1, packlen, f);
  146 + cirbuf.advanceHead(packlen);
  147 + MqttPacket pack(cirbuf, packlen, fixed_header_length, dummyClient);
136 148  
137   - readCheck(buf.data(), 1, header.payloadLen, f);
138   - std::string payload(buf.data(), header.payloadLen);
  149 + pack.parsePublishData();
  150 + Publish pub(pack.getPublishData());
139 151  
140   - RetainedMessage msg(topic, payload, qos);
141   - logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.topic.c_str(), msg.qos);
142   - messages.push_back(std::move(msg));
  152 + RetainedMessage msg(pub);
  153 + logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.publish.topic.c_str(), msg.publish.qos);
  154 + messages.push_back(std::move(msg));
  155 + }
143 156 }
144 157  
145 158 return messages;
... ...
retainedmessagesdb.h
... ... @@ -24,8 +24,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
24 24 #include "logger.h"
25 25  
26 26 #define MAGIC_STRING_V1 "FlashMQRetainedDBv1"
27   -#define ROW_HEADER_SIZE 8
28   -#define RESERVED_SPACE_RETAINED_DB_V1 31
  27 +#define MAGIC_STRING_V2 "FlashMQRetainedDBv2"
  28 +#define RESERVED_SPACE_RETAINED_DB_V2 64
29 29  
30 30 /**
31 31 * @brief The RetainedMessagesDB class saves and loads the retained messages.
... ... @@ -44,7 +44,8 @@ class RetainedMessagesDB : public PersistenceFile
44 44 enum class ReadVersion
45 45 {
46 46 unknown,
47   - v1
  47 + v1,
  48 + v2
48 49 };
49 50  
50 51 struct RowHeader
... ... @@ -55,9 +56,7 @@ class RetainedMessagesDB : public PersistenceFile
55 56  
56 57 ReadVersion readVersion = ReadVersion::unknown;
57 58  
58   - void writeRowHeader(const RetainedMessage &rm);
59   - RowHeader readRowHeaderV1(bool &eofFound);
60   - std::list<RetainedMessage> readDataV1();
  59 + std::list<RetainedMessage> readDataV2();
61 60 public:
62 61 RetainedMessagesDB(const std::string &filePath);
63 62  
... ...
sessionsandsubscriptionsdb.cpp
... ... @@ -85,7 +85,6 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2()
85 85 std::vector<char> reserved(RESERVED_SPACE_SESSIONS_DB_V2);
86 86 CirBuf cirbuf(1024);
87 87  
88   - // TODO: all that settings and thread data needs to be removed from Client.
89 88 std::shared_ptr<ThreadData> dummyThreadData; // which thread am I going get/use here?
90 89 std::shared_ptr<Client> dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, settings, false));
91 90 dummyClient->setClientProperties(ProtocolVersion::Mqtt5, "Dummyforloadingqueuedqos", "nobody", true, 60);
... ...
subscriptionstore.cpp
... ... @@ -449,10 +449,10 @@ void SubscriptionStore::giveClientRetainedMessagesRecursively(ProtocolVersion pr
449 449 {
450 450 for(const RetainedMessage &rm : this_node->retainedMessages)
451 451 {
452   - // TODO: set the still to make 'split topic' to false
453   - Publish publish(rm.topic, rm.payload, rm.qos);
454   - publish.retain = true;
455   - packetList.emplace_front(protocolVersion, publish);
  452 + // TODO: hmm, const stuff forces me to make copy
  453 + Publish pubcopy(rm.publish);
  454 + pubcopy.splitTopic = true;
  455 + packetList.emplace_front(protocolVersion, pubcopy);
456 456 }
457 457 if (poundMode)
458 458 {
... ... @@ -516,7 +516,7 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Cli
516 516 return count;
517 517 }
518 518  
519   -void SubscriptionStore::setRetainedMessage(const std::string &topic, const std::vector<std::string> &subtopics, const std::string &payload, char qos)
  519 +void SubscriptionStore::setRetainedMessage(const Publish &publish, const std::vector<std::string> &subtopics)
520 520 {
521 521 RetainedMessageNode *deepestNode = &retainedMessagesRoot;
522 522 if (!subtopics.empty() && !subtopics[0].empty() > 0 && subtopics[0][0] == '$')
... ... @@ -540,7 +540,7 @@ void SubscriptionStore::setRetainedMessage(const std::string &amp;topic, const std::
540 540  
541 541 if (deepestNode)
542 542 {
543   - deepestNode->addPayload(topic, payload, qos, retainedMessageCount);
  543 + deepestNode->addPayload(publish, retainedMessageCount);
544 544 }
545 545  
546 546 locker.unlock();
... ... @@ -834,10 +834,10 @@ void SubscriptionStore::loadRetainedMessages(const std::string &amp;filePath)
834 834 locker.wrlock();
835 835  
836 836 std::vector<std::string> subtopics;
837   - for (const RetainedMessage &rm : messages)
  837 + for (RetainedMessage &rm : messages)
838 838 {
839   - splitTopic(rm.topic, subtopics);
840   - setRetainedMessage(rm.topic, subtopics, rm.payload, rm.qos);
  839 + splitTopic(rm.publish.topic, rm.publish.subtopics);
  840 + setRetainedMessage(rm.publish, rm.publish.subtopics);
841 841 }
842 842 }
843 843 catch (PersistenceFileCantBeOpened &ex)
... ... @@ -950,18 +950,18 @@ void Subscription::reset()
950 950 qos = 0;
951 951 }
952 952  
953   -void RetainedMessageNode::addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount)
  953 +void RetainedMessageNode::addPayload(const Publish &publish, int64_t &totalCount)
954 954 {
955 955 const int64_t countBefore = retainedMessages.size();
956   - RetainedMessage rm(topic, payload, qos);
  956 + RetainedMessage rm(publish);
957 957  
958 958 auto retained_ptr = retainedMessages.find(rm);
959 959 bool retained_found = retained_ptr != retainedMessages.end();
960 960  
961   - if (!retained_found && payload.empty())
  961 + if (!retained_found && publish.payload.empty())
962 962 return;
963 963  
964   - if (retained_found && payload.empty())
  964 + if (retained_found && publish.payload.empty())
965 965 {
966 966 retainedMessages.erase(rm);
967 967 const int64_t diffCount = (retainedMessages.size() - countBefore);
... ...
subscriptionstore.h
... ... @@ -80,7 +80,7 @@ class RetainedMessageNode
80 80 std::unordered_map<std::string, std::unique_ptr<RetainedMessageNode>> children;
81 81 std::unordered_set<RetainedMessage> retainedMessages;
82 82  
83   - void addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount);
  83 + void addPayload(const Publish &publish, int64_t &totalCount);
84 84 RetainedMessageNode *getChildren(const std::string &subtopic) const;
85 85 };
86 86  
... ... @@ -170,7 +170,7 @@ public:
170 170 uint64_t giveClientRetainedMessages(const std::shared_ptr<Client> &client, const std::shared_ptr<Session> &ses,
171 171 const std::vector<std::string> &subscribeSubtopics, char max_qos);
172 172  
173   - void setRetainedMessage(const std::string &topic, const std::vector<std::string> &subtopics, const std::string &payload, char qos);
  173 + void setRetainedMessage(const Publish &publish, const std::vector<std::string> &subtopics);
174 174  
175 175 void removeSession(const std::shared_ptr<Session> &session);
176 176 void removeExpiredSessionsClients();
... ...
threaddata.cpp
... ... @@ -149,7 +149,7 @@ void ThreadData::publishStat(const std::string &amp;topic, uint64_t n)
149 149 Publish p(topic, payload, 0);
150 150 PublishCopyFactory factory(&p);
151 151 subscriptionStore->queuePacketAtSubscribers(factory, true);
152   - subscriptionStore->setRetainedMessage(topic, factory.getSubtopics(), payload, 0);
  152 + subscriptionStore->setRetainedMessage(p, factory.getSubtopics());
153 153 }
154 154  
155 155 void ThreadData::sendQueuedWills()
... ...