diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 208a807..24db110 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -895,20 +895,19 @@ void MainTests::testRetainedMessageDB() std::string longTopic = formatString("one/two/%s", getSecureRandomString(4000).c_str()); std::vector messages; - messages.emplace_back("one/two/three", "payload", 0); - messages.emplace_back("one/two/wer", "payload", 1); - messages.emplace_back("one/e/wer", "payload", 1); - messages.emplace_back("one/wee/wer", "asdfasdfasdf", 1); - messages.emplace_back("one/two/wer", "µsdf", 1); - messages.emplace_back("/boe/bah", longpayload, 1); - messages.emplace_back("one/two/wer", "paylasdfaoad", 1); - messages.emplace_back("one/two/wer", "payload", 1); - messages.emplace_back(longTopic, "payload", 1); - messages.emplace_back(longTopic, longpayload, 1); - messages.emplace_back("one", "µsdf", 1); - messages.emplace_back("/boe", longpayload, 1); - messages.emplace_back("one", "µsdf", 1); - messages.emplace_back("", "foremptytopic", 0); + messages.emplace_back(Publish("one/two/three", "payload", 0)); + messages.emplace_back(Publish("one/two/wer", "payload", 1)); + messages.emplace_back(Publish("one/e/wer", "payload", 1)); + messages.emplace_back(Publish("one/wee/wer", "asdfasdfasdf", 1)); + messages.emplace_back(Publish("one/two/wer", "µsdf", 1)); + messages.emplace_back(Publish("/boe/bah", longpayload, 1)); + messages.emplace_back(Publish("one/two/wer", "paylasdfaoad", 1)); + messages.emplace_back(Publish("one/two/wer", "payload", 1)); + messages.emplace_back(Publish(longTopic, "payload", 1)); + messages.emplace_back(Publish(longTopic, longpayload, 1)); + messages.emplace_back(Publish("one", "µsdf", 1)); + messages.emplace_back(Publish("/boe", longpayload, 1)); + messages.emplace_back(Publish("one", "µsdf", 1)); RetainedMessagesDB db("/tmp/flashmqtests_retained.db"); db.openWrite(); @@ -920,7 +919,7 @@ void MainTests::testRetainedMessageDB() std::list messagesLoaded = db2.readData(); db2.closeFile(); - QCOMPARE(messages.size(), messagesLoaded.size()); + QCOMPARE(messagesLoaded.size(), messages.size()); auto itOrg = messages.begin(); auto itLoaded = messagesLoaded.begin(); @@ -930,9 +929,9 @@ void MainTests::testRetainedMessageDB() RetainedMessage &two = *itLoaded; // Comparing the fields because the RetainedMessage class has an == operator that only looks at topic. - QCOMPARE(one.topic, two.topic); - QCOMPARE(one.payload, two.payload); - QCOMPARE(one.qos, two.qos); + QCOMPARE(one.publish.topic, two.publish.topic); + QCOMPARE(one.publish.payload, two.publish.payload); + QCOMPARE(one.publish.qos, two.publish.qos); itOrg++; itLoaded++; diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 42f1312..9329457 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -1017,8 +1017,8 @@ void MqttPacket::handlePublish() { if (publishData.retain) { - std::string payload(readBytes(payloadLen), payloadLen); - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, publishData.qos); + publishData.payload = getPayloadCopy(); + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData, publishData.subtopics); } // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. @@ -1420,6 +1420,7 @@ bool MqttPacket::containsClientSpecificProperties() const void MqttPacket::readIntoBuf(CirBuf &buf) const { assert(packetType != PacketType::PUBLISH || (first_byte & 0b00000110) >> 1 == publishData.qos); + assert(publishData.qos == 0 || packet_id > 0); buf.ensureFreeSpace(getSizeIncludingNonPresentHeader()); diff --git a/retainedmessage.cpp b/retainedmessage.cpp index c01f9e2..afd403b 100644 --- a/retainedmessage.cpp +++ b/retainedmessage.cpp @@ -17,25 +17,24 @@ License along with FlashMQ. If not, see . #include "retainedmessage.h" -RetainedMessage::RetainedMessage(const std::string &topic, const std::string &payload, char qos) : - topic(topic), - payload(payload), - qos(qos) +RetainedMessage::RetainedMessage(const Publish &publish) : + publish(publish) { - + this->publish.retain = true; + this->publish.splitTopic = false; } bool RetainedMessage::operator==(const RetainedMessage &rhs) const { - return this->topic == rhs.topic; + return this->publish.topic == rhs.publish.topic; } bool RetainedMessage::empty() const { - return payload.empty(); + return publish.payload.empty(); } uint32_t RetainedMessage::getSize() const { - return topic.length() + payload.length() + 1; + return publish.topic.length() + publish.payload.length() + 1; } diff --git a/retainedmessage.h b/retainedmessage.h index 97b5369..f44a414 100644 --- a/retainedmessage.h +++ b/retainedmessage.h @@ -19,14 +19,13 @@ License along with FlashMQ. If not, see . #define RETAINEDMESSAGE_H #include +#include "types.h" struct RetainedMessage { - std::string topic; - std::string payload; - char qos; + Publish publish; - RetainedMessage(const std::string &topic, const std::string &payload, char qos); + RetainedMessage(const Publish &publish); bool operator==(const RetainedMessage &rhs) const; bool empty() const; @@ -44,7 +43,7 @@ namespace std { using std::hash; using std::string; - return hash()(k.topic); + return hash()(k.publish.topic); } }; diff --git a/retainedmessagesdb.cpp b/retainedmessagesdb.cpp index 70ae9b0..e401e48 100644 --- a/retainedmessagesdb.cpp +++ b/retainedmessagesdb.cpp @@ -27,6 +27,8 @@ License along with FlashMQ. If not, see . #include "retainedmessagesdb.h" #include "utils.h" #include "logger.h" +#include "mqttpacket.h" +#include "threadglobals.h" RetainedMessagesDB::RetainedMessagesDB(const std::string &filePath) : PersistenceFile(filePath) { @@ -35,7 +37,7 @@ RetainedMessagesDB::RetainedMessagesDB(const std::string &filePath) : Persistenc void RetainedMessagesDB::openWrite() { - PersistenceFile::openWrite(MAGIC_STRING_V1); + PersistenceFile::openWrite(MAGIC_STRING_V2); } void RetainedMessagesDB::openRead() @@ -44,35 +46,13 @@ void RetainedMessagesDB::openRead() if (detectedVersionString == MAGIC_STRING_V1) readVersion = ReadVersion::v1; + else if (detectedVersionString == MAGIC_STRING_V2) + readVersion = ReadVersion::v2; else throw std::runtime_error("Unknown file version."); } /** - * @brief RetainedMessagesDB::writeRowHeader writes two 32 bit integers: topic size and payload size. - * @param rm - * - * So, the header per message is 8 bytes long. - * - * It writes no information about the length of the QoS value, because that is always one. - */ -void RetainedMessagesDB::writeRowHeader(const RetainedMessage &rm) -{ - writeUint32(rm.topic.size()); - writeUint32(rm.payload.size()); -} - -RetainedMessagesDB::RowHeader RetainedMessagesDB::readRowHeaderV1(bool &eofFound) -{ - RetainedMessagesDB::RowHeader result; - - result.topicLen = readUint32(eofFound); - result.payloadLen = readUint32(eofFound); - - return result; -} - -/** * @brief RetainedMessagesDB::saveData doesn't explicitely name a file version (v1, etc), because we always write the current definition. * @param messages */ @@ -81,20 +61,34 @@ void RetainedMessagesDB::saveData(const std::vector &messages) if (!f) return; - char reserved[RESERVED_SPACE_RETAINED_DB_V1]; - std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V1); + CirBuf cirbuf(1024); + + writeUint32(messages.size()); + + char reserved[RESERVED_SPACE_RETAINED_DB_V2]; + std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V2); + writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V2, f); - char qos = 0; for (const RetainedMessage &rm : messages) { - logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d.", rm.topic.c_str(), rm.qos); - - writeRowHeader(rm); - qos = rm.qos; - writeCheck(&qos, 1, 1, f); - writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V1, f); - writeCheck(rm.topic.c_str(), 1, rm.topic.length(), f); - writeCheck(rm.payload.c_str(), 1, rm.payload.length(), f); + logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d.", rm.publish.topic.c_str(), rm.publish.qos); + + Publish pcopy(rm.publish); + MqttPacket pack(ProtocolVersion::Mqtt5, pcopy); + + // Dummy, to please the parser on reading. + if (pcopy.qos > 0) + pack.setPacketId(666); + + const uint32_t packSize = pack.getSizeIncludingNonPresentHeader(); + + cirbuf.reset(); + cirbuf.ensureFreeSpace(packSize + 32); + pack.readIntoBuf(cirbuf); + + writeUint16(pack.getFixedHeaderLength()); + writeUint32(packSize); + writeCheck(cirbuf.tailPtr(), 1, cirbuf.usedBytes(), f); } fflush(f); @@ -108,38 +102,57 @@ std::list RetainedMessagesDB::readData() return defaultResult; if (readVersion == ReadVersion::v1) - return readDataV1(); + logger->logf(LOG_WARNING, "File '%s' is version 1, an internal development version that was never finalized. Not reading.", getFilePath().c_str()); + if (readVersion == ReadVersion::v2) + return readDataV2(); return defaultResult; } -std::list RetainedMessagesDB::readDataV1() +std::list RetainedMessagesDB::readDataV2() { std::list messages; + CirBuf cirbuf(1024); + + const Settings *settings = ThreadGlobals::getSettings(); + std::shared_ptr dummyThreadData; + std::shared_ptr dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, settings, false)); + dummyClient->setClientProperties(ProtocolVersion::Mqtt5, "Dummyforloadingretained", "nobody", true, 60); + while (!feof(f)) { bool eofFound = false; - RetainedMessagesDB::RowHeader header = readRowHeaderV1(eofFound); + + const uint32_t numberOfMessages = readUint32(eofFound); if (eofFound) continue; - makeSureBufSize(header.payloadLen); + fseek(f, RESERVED_SPACE_RETAINED_DB_V2, SEEK_CUR); + + for(uint32_t i = 0; i < numberOfMessages; i++) + { + const uint16_t fixed_header_length = readUint16(eofFound); + const uint32_t packlen = readUint32(eofFound); + + if (eofFound) + continue; - readCheck(buf.data(), 1, 1, f); - char qos = buf[0]; - fseek(f, RESERVED_SPACE_RETAINED_DB_V1, SEEK_CUR); + cirbuf.reset(); + cirbuf.ensureFreeSpace(packlen + 32); - readCheck(buf.data(), 1, header.topicLen, f); - std::string topic(buf.data(), header.topicLen); + readCheck(cirbuf.headPtr(), 1, packlen, f); + cirbuf.advanceHead(packlen); + MqttPacket pack(cirbuf, packlen, fixed_header_length, dummyClient); - readCheck(buf.data(), 1, header.payloadLen, f); - std::string payload(buf.data(), header.payloadLen); + pack.parsePublishData(); + Publish pub(pack.getPublishData()); - RetainedMessage msg(topic, payload, qos); - logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.topic.c_str(), msg.qos); - messages.push_back(std::move(msg)); + RetainedMessage msg(pub); + logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.publish.topic.c_str(), msg.publish.qos); + messages.push_back(std::move(msg)); + } } return messages; diff --git a/retainedmessagesdb.h b/retainedmessagesdb.h index 7b2299d..b623f5f 100644 --- a/retainedmessagesdb.h +++ b/retainedmessagesdb.h @@ -24,8 +24,8 @@ License along with FlashMQ. If not, see . #include "logger.h" #define MAGIC_STRING_V1 "FlashMQRetainedDBv1" -#define ROW_HEADER_SIZE 8 -#define RESERVED_SPACE_RETAINED_DB_V1 31 +#define MAGIC_STRING_V2 "FlashMQRetainedDBv2" +#define RESERVED_SPACE_RETAINED_DB_V2 64 /** * @brief The RetainedMessagesDB class saves and loads the retained messages. @@ -44,7 +44,8 @@ class RetainedMessagesDB : public PersistenceFile enum class ReadVersion { unknown, - v1 + v1, + v2 }; struct RowHeader @@ -55,9 +56,7 @@ class RetainedMessagesDB : public PersistenceFile ReadVersion readVersion = ReadVersion::unknown; - void writeRowHeader(const RetainedMessage &rm); - RowHeader readRowHeaderV1(bool &eofFound); - std::list readDataV1(); + std::list readDataV2(); public: RetainedMessagesDB(const std::string &filePath); diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp index f5575da..7c5b827 100644 --- a/sessionsandsubscriptionsdb.cpp +++ b/sessionsandsubscriptionsdb.cpp @@ -85,7 +85,6 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() std::vector reserved(RESERVED_SPACE_SESSIONS_DB_V2); CirBuf cirbuf(1024); - // TODO: all that settings and thread data needs to be removed from Client. std::shared_ptr dummyThreadData; // which thread am I going get/use here? std::shared_ptr dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, settings, false)); dummyClient->setClientProperties(ProtocolVersion::Mqtt5, "Dummyforloadingqueuedqos", "nobody", true, 60); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 55795b0..90d6466 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -449,10 +449,10 @@ void SubscriptionStore::giveClientRetainedMessagesRecursively(ProtocolVersion pr { for(const RetainedMessage &rm : this_node->retainedMessages) { - // TODO: set the still to make 'split topic' to false - Publish publish(rm.topic, rm.payload, rm.qos); - publish.retain = true; - packetList.emplace_front(protocolVersion, publish); + // TODO: hmm, const stuff forces me to make copy + Publish pubcopy(rm.publish); + pubcopy.splitTopic = true; + packetList.emplace_front(protocolVersion, pubcopy); } if (poundMode) { @@ -516,7 +516,7 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr &subtopics, const std::string &payload, char qos) +void SubscriptionStore::setRetainedMessage(const Publish &publish, const std::vector &subtopics) { RetainedMessageNode *deepestNode = &retainedMessagesRoot; if (!subtopics.empty() && !subtopics[0].empty() > 0 && subtopics[0][0] == '$') @@ -540,7 +540,7 @@ void SubscriptionStore::setRetainedMessage(const std::string &topic, const std:: if (deepestNode) { - deepestNode->addPayload(topic, payload, qos, retainedMessageCount); + deepestNode->addPayload(publish, retainedMessageCount); } locker.unlock(); @@ -834,10 +834,10 @@ void SubscriptionStore::loadRetainedMessages(const std::string &filePath) locker.wrlock(); std::vector subtopics; - for (const RetainedMessage &rm : messages) + for (RetainedMessage &rm : messages) { - splitTopic(rm.topic, subtopics); - setRetainedMessage(rm.topic, subtopics, rm.payload, rm.qos); + splitTopic(rm.publish.topic, rm.publish.subtopics); + setRetainedMessage(rm.publish, rm.publish.subtopics); } } catch (PersistenceFileCantBeOpened &ex) @@ -950,18 +950,18 @@ void Subscription::reset() qos = 0; } -void RetainedMessageNode::addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount) +void RetainedMessageNode::addPayload(const Publish &publish, int64_t &totalCount) { const int64_t countBefore = retainedMessages.size(); - RetainedMessage rm(topic, payload, qos); + RetainedMessage rm(publish); auto retained_ptr = retainedMessages.find(rm); bool retained_found = retained_ptr != retainedMessages.end(); - if (!retained_found && payload.empty()) + if (!retained_found && publish.payload.empty()) return; - if (retained_found && payload.empty()) + if (retained_found && publish.payload.empty()) { retainedMessages.erase(rm); const int64_t diffCount = (retainedMessages.size() - countBefore); diff --git a/subscriptionstore.h b/subscriptionstore.h index 63c1f84..d1a7f12 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -80,7 +80,7 @@ class RetainedMessageNode std::unordered_map> children; std::unordered_set retainedMessages; - void addPayload(const std::string &topic, const std::string &payload, char qos, int64_t &totalCount); + void addPayload(const Publish &publish, int64_t &totalCount); RetainedMessageNode *getChildren(const std::string &subtopic) const; }; @@ -170,7 +170,7 @@ public: uint64_t giveClientRetainedMessages(const std::shared_ptr &client, const std::shared_ptr &ses, const std::vector &subscribeSubtopics, char max_qos); - void setRetainedMessage(const std::string &topic, const std::vector &subtopics, const std::string &payload, char qos); + void setRetainedMessage(const Publish &publish, const std::vector &subtopics); void removeSession(const std::shared_ptr &session); void removeExpiredSessionsClients(); diff --git a/threaddata.cpp b/threaddata.cpp index b918296..f3b62b8 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -149,7 +149,7 @@ void ThreadData::publishStat(const std::string &topic, uint64_t n) Publish p(topic, payload, 0); PublishCopyFactory factory(&p); subscriptionStore->queuePacketAtSubscribers(factory, true); - subscriptionStore->setRetainedMessage(topic, factory.getSubtopics(), payload, 0); + subscriptionStore->setRetainedMessage(p, factory.getSubtopics()); } void ThreadData::sendQueuedWills()