From 4bfa5aa5716d52b57ed4e725f7f5dd6aa03e395d Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Wed, 16 Jun 2021 21:25:59 +0200 Subject: [PATCH] Add saving state feature --- CMakeLists.txt | 8 ++++++++ FlashMQTests/FlashMQTests.pro | 8 ++++++++ FlashMQTests/tst_maintests.cpp | 227 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ configfileparser.cpp | 9 +++++++++ mainapp.cpp | 30 ++++++++++++++++++++++++++++++ mainapp.h | 1 + mqttpacket.cpp | 34 +++++++++++++++++++++++++++++----- mqttpacket.h | 3 +++ persistencefile.cpp | 320 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ persistencefile.h | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ qospacketqueue.cpp | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ qospacketqueue.h | 25 +++++++++++++++++++++++++ retainedmessage.cpp | 5 +++++ retainedmessage.h | 1 + retainedmessagesdb.cpp | 146 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ retainedmessagesdb.h | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ rwlockguard.cpp | 10 +++++++++- session.cpp | 103 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------- session.h | 31 +++++++++++++++++++------------ sessionsandsubscriptionsdb.cpp | 299 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ sessionsandsubscriptionsdb.h | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ settings.cpp | 20 +++++++++++++++++++- settings.h | 4 ++++ subscriptionstore.cpp | 206 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----- subscriptionstore.h | 16 +++++++++++++++- utils.cpp | 8 ++++++++ utils.h | 30 ++++++++++++++++++++++++++++++ 27 files changed, 1795 insertions(+), 60 deletions(-) create mode 100644 persistencefile.cpp create mode 100644 persistencefile.h create mode 100644 qospacketqueue.cpp create mode 100644 qospacketqueue.h create mode 100644 retainedmessagesdb.cpp create mode 100644 retainedmessagesdb.h create mode 100644 sessionsandsubscriptionsdb.cpp create mode 100644 sessionsandsubscriptionsdb.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 759c073..3657202 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,10 @@ add_executable(FlashMQ enums.h threadlocalutils.h flashmq_plugin.h + retainedmessagesdb.h + persistencefile.h + sessionsandsubscriptionsdb.h + qospacketqueue.h mainapp.cpp main.cpp @@ -81,6 +85,10 @@ add_executable(FlashMQ acltree.cpp threadlocalutils.cpp flashmq_plugin.cpp + retainedmessagesdb.cpp + persistencefile.cpp + sessionsandsubscriptionsdb.cpp + qospacketqueue.cpp ) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index 1c48319..450323e 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -42,6 +42,10 @@ SOURCES += tst_maintests.cpp \ ../acltree.cpp \ ../threadlocalutils.cpp \ ../flashmq_plugin.cpp \ + ../retainedmessagesdb.cpp \ + ../persistencefile.cpp \ + ../sessionsandsubscriptionsdb.cpp \ + ../qospacketqueue.cpp \ mainappthread.cpp \ twoclienttestcontext.cpp @@ -77,6 +81,10 @@ HEADERS += \ ../acltree.h \ ../threadlocalutils.h \ ../flashmq_plugin.h \ + ../retainedmessagesdb.h \ + ../persistencefile.h \ + ../sessionsandsubscriptionsdb.h \ + ../qospacketqueue.h \ mainappthread.h \ twoclienttestcontext.h diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 14294a9..ef1866a 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -21,12 +21,18 @@ License along with FlashMQ. If not, see . #include #include #include +#include +#include #include "cirbuf.h" #include "mainapp.h" #include "mainappthread.h" #include "twoclienttestcontext.h" #include "threadlocalutils.h" +#include "retainedmessagesdb.h" +#include "sessionsandsubscriptionsdb.h" +#include "session.h" +#include "threaddata.h" // Dumb Qt version gives warnings when comparing uint with number literal. template @@ -88,6 +94,12 @@ private slots: void testTopicsMatch(); + void testRetainedMessageDB(); + void testRetainedMessageDBNotPresent(); + void testRetainedMessageDBEmptyList(); + + void testSavingSessions(); + }; MainTests::MainTests() @@ -822,6 +834,221 @@ void MainTests::testTopicsMatch() } +void MainTests::testRetainedMessageDB() +{ + try + { + std::string longpayload = getSecureRandomString(65537); + std::string longTopic = formatString("one/two/%s", getSecureRandomString(4000).c_str()); + + std::vector messages; + messages.emplace_back("one/two/three", "payload", 0); + messages.emplace_back("one/two/wer", "payload", 1); + messages.emplace_back("one/e/wer", "payload", 1); + messages.emplace_back("one/wee/wer", "asdfasdfasdf", 1); + messages.emplace_back("one/two/wer", "µsdf", 1); + messages.emplace_back("/boe/bah", longpayload, 1); + messages.emplace_back("one/two/wer", "paylasdfaoad", 1); + messages.emplace_back("one/two/wer", "payload", 1); + messages.emplace_back(longTopic, "payload", 1); + messages.emplace_back(longTopic, longpayload, 1); + messages.emplace_back("one", "µsdf", 1); + messages.emplace_back("/boe", longpayload, 1); + messages.emplace_back("one", "µsdf", 1); + messages.emplace_back("", "foremptytopic", 0); + + RetainedMessagesDB db("/tmp/flashmqtests_retained.db"); + db.openWrite(); + db.saveData(messages); + db.closeFile(); + + RetainedMessagesDB db2("/tmp/flashmqtests_retained.db"); + db2.openRead(); + std::list messagesLoaded = db2.readData(); + db2.closeFile(); + + QCOMPARE(messages.size(), messagesLoaded.size()); + + auto itOrg = messages.begin(); + auto itLoaded = messagesLoaded.begin(); + while (itOrg != messages.end() && itLoaded != messagesLoaded.end()) + { + RetainedMessage &one = *itOrg; + RetainedMessage &two = *itLoaded; + + // Comparing the fields because the RetainedMessage class has an == operator that only looks at topic. + QCOMPARE(one.topic, two.topic); + QCOMPARE(one.payload, two.payload); + QCOMPARE(one.qos, two.qos); + + itOrg++; + itLoaded++; + } + } + catch (std::exception &ex) + { + QVERIFY2(false, ex.what()); + } +} + +void MainTests::testRetainedMessageDBNotPresent() +{ + try + { + RetainedMessagesDB db2("/tmp/flashmqtests_asdfasdfasdf.db"); + db2.openRead(); + std::list messagesLoaded = db2.readData(); + db2.closeFile(); + + MYCASTCOMPARE(messagesLoaded.size(), 0); + + QVERIFY2(false, "We should have run into an exception."); + } + catch (PersistenceFileCantBeOpened &ex) + { + QVERIFY(true); + } + catch (std::exception &ex) + { + QVERIFY2(false, ex.what()); + } +} + +void MainTests::testRetainedMessageDBEmptyList() +{ + try + { + std::vector messages; + + RetainedMessagesDB db("/tmp/flashmqtests_retained.db"); + db.openWrite(); + db.saveData(messages); + db.closeFile(); + + RetainedMessagesDB db2("/tmp/flashmqtests_retained.db"); + db2.openRead(); + std::list messagesLoaded = db2.readData(); + db2.closeFile(); + + MYCASTCOMPARE(messages.size(), messagesLoaded.size()); + MYCASTCOMPARE(messages.size(), 0); + } + catch (std::exception &ex) + { + QVERIFY2(false, ex.what()); + } +} + +void MainTests::testSavingSessions() +{ + try + { + std::shared_ptr settings(new Settings()); + std::shared_ptr store(new SubscriptionStore()); + std::shared_ptr t(new ThreadData(0, store, settings)); + + std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings, false)); + c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); + store->registerClientAndKickExistingOne(c1); + c1->getSession()->touch(); + c1->getSession()->addIncomingQoS2MessageId(2); + c1->getSession()->addIncomingQoS2MessageId(3); + + std::shared_ptr c2(new Client(0, t, nullptr, false, nullptr, settings, false)); + c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60, false); + store->registerClientAndKickExistingOne(c2); + c2->getSession()->touch(); + c2->getSession()->addOutgoingQoS2MessageId(55); + c2->getSession()->addOutgoingQoS2MessageId(66); + + const std::string topic1 = "one/two/three"; + std::vector subtopics; + splitTopic(topic1, subtopics); + store->addSubscription(c1, topic1, subtopics, 0); + + const std::string topic2 = "four/five/six"; + splitTopic(topic2, subtopics); + store->addSubscription(c2, topic2, subtopics, 0); + store->addSubscription(c1, topic2, subtopics, 0); + + const std::string topic3 = ""; + splitTopic(topic3, subtopics); + store->addSubscription(c2, topic3, subtopics, 0); + + const std::string topic4 = "#"; + splitTopic(topic4, subtopics); + store->addSubscription(c2, topic4, subtopics, 0); + + uint64_t count = 0; + + Publish publish("a/b/c", "Hello Barry", 1); + + std::shared_ptr c1ses = c1->getSession(); + c1.reset(); + c1ses->writePacket(publish, 1, false, count); + + store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); + + std::shared_ptr store2(new SubscriptionStore()); + store2->loadSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); + + MYCASTCOMPARE(store->sessionsById.size(), 2); + MYCASTCOMPARE(store2->sessionsById.size(), 2); + + for (auto &pair : store->sessionsById) + { + std::shared_ptr &ses = pair.second; + std::shared_ptr &ses2 = store2->sessionsById[pair.first]; + QCOMPARE(pair.first, ses2->getClientId()); + + QCOMPARE(ses->username, ses2->username); + QCOMPARE(ses->client_id, ses2->client_id); + QCOMPARE(ses->incomingQoS2MessageIds, ses2->incomingQoS2MessageIds); + QCOMPARE(ses->outgoingQoS2MessageIds, ses2->outgoingQoS2MessageIds); + QCOMPARE(ses->nextPacketId, ses2->nextPacketId); + } + + + std::unordered_map> store1Subscriptions; + store->getSubscriptions(&store->root, "", true, store1Subscriptions); + + std::unordered_map> store2Subscriptions; + store->getSubscriptions(&store->root, "", true, store2Subscriptions); + + MYCASTCOMPARE(store1Subscriptions.size(), 4); + MYCASTCOMPARE(store2Subscriptions.size(), 4); + + for(auto &pair : store1Subscriptions) + { + std::list &subscList1 = pair.second; + std::list &subscList2 = store2Subscriptions[pair.first]; + + QCOMPARE(subscList1.size(), subscList2.size()); + + auto subs1It = subscList1.begin(); + auto subs2It = subscList2.begin(); + + while (subs1It != subscList1.end()) + { + SubscriptionForSerializing &one = *subs1It; + SubscriptionForSerializing &two = *subs2It; + QCOMPARE(one.clientId, two.clientId); + QCOMPARE(one.qos, two.qos); + + subs1It++; + subs2It++; + } + + } + + + } + catch (std::exception &ex) + { + QVERIFY2(false, ex.what()); + } +} + QTEST_GUILESS_MAIN(MainTests) #include "tst_maintests.moc" diff --git a/configfileparser.cpp b/configfileparser.cpp index b99810e..9211f17 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -107,6 +107,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : validKeys.insert("allow_anonymous"); validKeys.insert("rlimit_nofile"); validKeys.insert("expire_sessions_after_seconds"); + validKeys.insert("storage_dir"); validListenKeys.insert("port"); validListenKeys.insert("protocol"); @@ -412,6 +413,14 @@ void ConfigFileParser::loadFile(bool test) } tmpSettings->authPluginTimerPeriod = newVal; } + + if (key == "storage_dir") + { + std::string newPath = value; + rtrim(newPath, '/'); + checkWritableDir(newPath); + tmpSettings->storageDir = newPath; + } } } catch (std::invalid_argument &ex) // catch for the stoi() diff --git a/mainapp.cpp b/mainapp.cpp index de05ac7..de4f00e 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -202,6 +202,15 @@ MainApp::MainApp(const std::string &configFilePath) : auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this); timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event."); } + + if (!settings->storageDir.empty()) + { + subscriptionStore->loadRetainedMessages(settings->getRetainedMessagesDBFile()); + subscriptionStore->loadSessionsAndSubscriptions(settings->getSessionsDBFile()); + } + + auto fSaveState = std::bind(&MainApp::saveState, this); + timer.addCallback(fSaveState, 900000, "Save state."); } MainApp::~MainApp() @@ -369,6 +378,25 @@ void MainApp::publishStat(const std::string &topic, uint64_t n) subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); } +void MainApp::saveState() +{ + try + { + if (!settings->storageDir.empty()) + { + const std::string retainedDBPath = settings->getRetainedMessagesDBFile(); + subscriptionStore->saveRetainedMessages(retainedDBPath); + + const std::string sessionsDBPath = settings->getSessionsDBFile(); + subscriptionStore->saveSessionsAndSubscriptions(sessionsDBPath); + } + } + catch(std::exception &ex) + { + logger->logf(LOG_ERR, "Error saving state: %s", ex.what()); + } +} + void MainApp::initMainApp(int argc, char *argv[]) { if (instance != nullptr) @@ -659,6 +687,8 @@ void MainApp::start() { thread->waitForQuit(); } + + saveState(); } void MainApp::quit() diff --git a/mainapp.h b/mainapp.h index 2a00475..7a31509 100644 --- a/mainapp.h +++ b/mainapp.h @@ -84,6 +84,7 @@ class MainApp void setFuzzFile(const std::string &fuzzFilePath); void publishStatsOnDollarTopic(); void publishStat(const std::string &topic, uint64_t n); + void saveState(); MainApp(const std::string &configFilePath); public: diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 19817d7..e69c310 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -47,8 +47,11 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt pos += fixed_header_length; } -// This is easier than using the copy constructor publically, because then I have to keep maintaining a functioning copy constructor. -// Returning shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. +/** + * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor + * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add. + * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. + */ std::shared_ptr MqttPacket::getCopy() const { std::shared_ptr copyPacket(new MqttPacket(*this)); @@ -129,6 +132,9 @@ MqttPacket::MqttPacket(const Publish &publish) : writeBytes(zero, 2); } + payloadStart = pos; + payloadLen = publish.payload.length(); + writeBytes(publish.payload.c_str(), publish.payload.length()); calculateRemainingLength(); } @@ -546,13 +552,14 @@ void MqttPacket::handlePublish() } } + payloadLen = remainingAfterPos(); + payloadStart = pos; + if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) { if (retain) { - size_t payload_length = remainingAfterPos(); - std::string payload(readBytes(payload_length), payload_length); - + std::string payload(readBytes(payloadLen), payloadLen); sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, *subtopics, payload, qos); } @@ -679,6 +686,23 @@ size_t MqttPacket::getTotalMemoryFootprint() return bites.size() + sizeof(MqttPacket); } +/** + * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string. + * @return + * + * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out + * the whole byte array that is a packet to subscribers. No need to copy and such. + * + * I created it for saving QoS packages in the db file. + */ +std::string MqttPacket::getPayloadCopy() +{ + assert(payloadStart > 0); + assert(pos <= bites.size()); + std::string payload(&bites[payloadStart], payloadLen); + return payload; +} + size_t MqttPacket::getSizeIncludingNonPresentHeader() const { size_t total = bites.size(); diff --git a/mqttpacket.h b/mqttpacket.h index 6c8ddb8..9b0d7a5 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -59,6 +59,8 @@ class MqttPacket size_t packet_id_pos = 0; uint16_t packet_id = 0; ProtocolVersion protocolVersion = ProtocolVersion::None; + size_t payloadStart = 0; + size_t payloadLen = 0; Logger *logger = Logger::getInstance(); char *readBytes(size_t length); @@ -116,6 +118,7 @@ public: uint16_t getPacketId() const; void setDuplicate(); size_t getTotalMemoryFootprint(); + std::string getPayloadCopy(); }; #endif // MQTTPACKET_H diff --git a/persistencefile.cpp b/persistencefile.cpp new file mode 100644 index 0000000..03ed4f7 --- /dev/null +++ b/persistencefile.cpp @@ -0,0 +1,320 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#include "persistencefile.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" +#include "logger.h" + +PersistenceFile::PersistenceFile(const std::string &filePath) : + digestContext(EVP_MD_CTX_new()), + buf(1024*1024) +{ + if (!filePath.empty() && filePath[filePath.size() - 1] == '/') + throw std::runtime_error("Target file can't contain trailing slash."); + + this->filePath = filePath; + this->filePathTemp = formatString("%s.newfile.%s", filePath.c_str(), getSecureRandomString(8).c_str()); + this->filePathCorrupt = formatString("%s.corrupt.%s", filePath.c_str(), getSecureRandomString(8).c_str()); +} + +PersistenceFile::~PersistenceFile() +{ + closeFile(); + + if (digestContext) + { + EVP_MD_CTX_free(digestContext); + } +} + +/** + * @brief RetainedMessagesDB::hashFile hashes the data after the headers and writes the hash in the header. Uses SHA512. + * + */ +void PersistenceFile::writeCheck(const void *ptr, size_t size, size_t n, FILE *s) +{ + if (fwrite(ptr, size, n, s) != n) + { + throw std::runtime_error(formatString("Error writing: %s", strerror(errno))); + } +} + +ssize_t PersistenceFile::readCheck(void *ptr, size_t size, size_t n, FILE *stream) +{ + size_t nread = fread(ptr, size, n, stream); + + if (nread != n) + { + if (feof(f)) + return -1; + + throw std::runtime_error(formatString("Error reading: %s", strerror(errno))); + } + + return nread; +} + +void PersistenceFile::hashFile() +{ + logger->logf(LOG_DEBUG, "Calculating and saving hash of '%s'.", filePath.c_str()); + + fseek(f, TOTAL_HEADER_SIZE, SEEK_SET); + + unsigned int output_len = 0; + unsigned char md_value[EVP_MAX_MD_SIZE]; + std::memset(md_value, 0, EVP_MAX_MD_SIZE); + + EVP_MD_CTX_reset(digestContext); + EVP_DigestInit_ex(digestContext, sha512, NULL); + + while (!feof(f)) + { + size_t n = fread(buf.data(), 1, buf.size(), f); + EVP_DigestUpdate(digestContext, buf.data(), n); + } + + EVP_DigestFinal_ex(digestContext, md_value, &output_len); + + if (output_len != HASH_SIZE) + throw std::runtime_error("Impossible: calculated hash size wrong length"); + + fseek(f, MAGIC_STRING_LENGH, SEEK_SET); + + writeCheck(md_value, output_len, 1, f); + fflush(f); +} + +void PersistenceFile::verifyHash() +{ + fseek(f, 0, SEEK_END); + const size_t size = ftell(f); + + if (size < TOTAL_HEADER_SIZE) + throw std::runtime_error(formatString("File '%s' is too small for it even to contain a header.", filePath.c_str())); + + unsigned char md_from_disk[HASH_SIZE]; + std::memset(md_from_disk, 0, HASH_SIZE); + + fseek(f, MAGIC_STRING_LENGH, SEEK_SET); + fread(md_from_disk, 1, HASH_SIZE, f); + + unsigned int output_len = 0; + unsigned char md_value[EVP_MAX_MD_SIZE]; + std::memset(md_value, 0, EVP_MAX_MD_SIZE); + + EVP_MD_CTX_reset(digestContext); + EVP_DigestInit_ex(digestContext, sha512, NULL); + + while (!feof(f)) + { + size_t n = fread(buf.data(), 1, buf.size(), f); + EVP_DigestUpdate(digestContext, buf.data(), n); + } + + EVP_DigestFinal_ex(digestContext, md_value, &output_len); + + if (output_len != HASH_SIZE) + throw std::runtime_error("Impossible: calculated hash size wrong length"); + + if (std::memcmp(md_from_disk, md_value, output_len) != 0) + { + fclose(f); + f = nullptr; + + if (rename(filePath.c_str(), filePathCorrupt.c_str()) == 0) + { + throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Moved aside to '%s'.", filePath.c_str(), filePathCorrupt.c_str())); + } + else + { + throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Tried to move aside, but that failed: '%s'.", + filePath.c_str(), strerror(errno))); + } + } + + logger->logf(LOG_DEBUG, "Hash of '%s' correct", filePath.c_str()); +} + +/** + * @brief PersistenceFile::makeSureBufSize grows the buffer if n is bigger. + * @param n in bytes. + * + * Remember that when you're dealing with fields that are sized in MQTT by 16 bit ints, like topic paths, the buffer will always be big enough, because it's 1 MB. + */ +void PersistenceFile::makeSureBufSize(size_t n) +{ + if (n > buf.size()) + buf.resize(n); +} + +void PersistenceFile::writeInt64(const int64_t val) +{ + unsigned char buf[8]; + + // Write big-endian + int shift = 56; + int i = 0; + while (shift >= 0) + { + unsigned char wantedByte = val >> shift; + buf[i++] = wantedByte; + shift -= 8; + } + writeCheck(buf, 1, 8, f); +} + +void PersistenceFile::writeUint32(const uint32_t val) +{ + unsigned char buf[4]; + + // Write big-endian + int shift = 24; + int i = 0; + while (shift >= 0) + { + unsigned char wantedByte = val >> shift; + buf[i++] = wantedByte; + shift -= 8; + } + writeCheck(buf, 1, 4, f); +} + +void PersistenceFile::writeUint16(const uint16_t val) +{ + unsigned char buf[2]; + + // Write big-endian + int shift = 8; + int i = 0; + while (shift >= 0) + { + unsigned char wantedByte = val >> shift; + buf[i++] = wantedByte; + shift -= 8; + } + writeCheck(buf, 1, 2, f); +} + +int64_t PersistenceFile::readInt64(bool &eofFound) +{ + if (readCheck(buf.data(), 1, 8, f) < 0) + eofFound = true; + + unsigned char *buf_ = reinterpret_cast(buf.data()); + const uint64_t val1 = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]); + const uint64_t val2 = ((buf_[4]) << 24) | ((buf_[5]) << 16) | ((buf_[6]) << 8) | (buf_[7]); + const int64_t val = (val1 << 32) | val2; + return val; +} + +uint32_t PersistenceFile::readUint32(bool &eofFound) +{ + if (readCheck(buf.data(), 1, 4, f) < 0) + eofFound = true; + + uint32_t val; + unsigned char *buf_ = reinterpret_cast(buf.data()); + val = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]); + return val; +} + +uint16_t PersistenceFile::readUint16(bool &eofFound) +{ + if (readCheck(buf.data(), 1, 2, f) < 0) + eofFound = true; + + uint16_t val; + unsigned char *buf_ = reinterpret_cast(buf.data()); + val = ((buf_[0]) << 8) | (buf_[1]); + return val; +} + +/** + * @brief RetainedMessagesDB::openWrite doesn't explicitely name a file version (v1, etc), because we always write the current definition. + */ +void PersistenceFile::openWrite(const std::string &versionString) +{ + if (openMode != FileMode::unknown) + throw std::runtime_error("File is already open."); + + f = fopen(filePathTemp.c_str(), "w+b"); + + if (f == nullptr) + { + throw std::runtime_error(formatString("Can't open '%s': %s", filePathTemp.c_str(), strerror(errno))); + } + + openMode = FileMode::write; + + writeCheck(buf.data(), 1, MAGIC_STRING_LENGH, f); + rewind(f); + writeCheck(versionString.c_str(), 1, versionString.length(), f); + fseek(f, MAGIC_STRING_LENGH, SEEK_SET); + writeCheck(buf.data(), 1, HASH_SIZE, f); +} + +void PersistenceFile::openRead() +{ + if (openMode != FileMode::unknown) + throw std::runtime_error("File is already open."); + + f = fopen(filePath.c_str(), "rb"); + + if (f == nullptr) + throw PersistenceFileCantBeOpened(formatString("Can't open '%s': %s.", filePath.c_str(), strerror(errno)).c_str()); + + openMode = FileMode::read; + + verifyHash(); + rewind(f); + + fread(buf.data(), 1, MAGIC_STRING_LENGH, f); + detectedVersionString = std::string(buf.data(), strlen(buf.data())); + + fseek(f, TOTAL_HEADER_SIZE, SEEK_SET); +} + +void PersistenceFile::closeFile() +{ + if (!f) + return; + + if (openMode == FileMode::write) + hashFile(); + + if (f != nullptr) + { + fclose(f); + f = nullptr; + } + + if (openMode == FileMode::write && !filePathTemp.empty() && ! filePath.empty()) + { + if (rename(filePathTemp.c_str(), filePath.c_str()) < 0) + throw std::runtime_error(formatString("Saving '%s' failed: rename of temp file to target failed with: %s", filePath.c_str(), strerror(errno))); + } +} diff --git a/persistencefile.h b/persistencefile.h new file mode 100644 index 0000000..6000d01 --- /dev/null +++ b/persistencefile.h @@ -0,0 +1,92 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#ifndef PERSISTENCEFILE_H +#define PERSISTENCEFILE_H + +#include +#include +#include +#include +#include +#include +#include + +#include "logger.h" + +#define MAGIC_STRING_LENGH 32 +#define HASH_SIZE 64 +#define TOTAL_HEADER_SIZE (MAGIC_STRING_LENGH + HASH_SIZE) + +/** + * @brief The PersistenceFileCantBeOpened class should be thrown when a non-fatal file-not-found error happens. + */ +class PersistenceFileCantBeOpened : public std::runtime_error +{ +public: + PersistenceFileCantBeOpened(const std::string &msg) : std::runtime_error(msg) {} +}; + +class PersistenceFile +{ + std::string filePath; + std::string filePathTemp; + std::string filePathCorrupt; + + EVP_MD_CTX *digestContext = nullptr; + const EVP_MD *sha512 = EVP_sha512(); + + void hashFile(); + void verifyHash(); + +protected: + enum class FileMode + { + unknown, + read, + write + }; + + FILE *f = nullptr; + std::vector buf; + FileMode openMode = FileMode::unknown; + std::string detectedVersionString; + + Logger *logger = Logger::getInstance(); + + void makeSureBufSize(size_t n); + + void writeCheck(const void *__restrict __ptr, size_t __size, size_t __n, FILE *__restrict __s); + ssize_t readCheck(void *__restrict ptr, size_t size, size_t n, FILE *__restrict stream); + + void writeInt64(const int64_t val); + void writeUint32(const uint32_t val); + void writeUint16(const uint16_t val); + int64_t readInt64(bool &eofFound); + uint32_t readUint32(bool &eofFound); + uint16_t readUint16(bool &eofFound); + +public: + PersistenceFile(const std::string &filePath); + ~PersistenceFile(); + + void openWrite(const std::string &versionString); + void openRead(); + void closeFile(); +}; + +#endif // PERSISTENCEFILE_H diff --git a/qospacketqueue.cpp b/qospacketqueue.cpp new file mode 100644 index 0000000..696bedd --- /dev/null +++ b/qospacketqueue.cpp @@ -0,0 +1,77 @@ +#include "qospacketqueue.h" + +#include "cassert" + +#include "mqttpacket.h" + +void QoSPacketQueue::erase(const uint16_t packet_id) +{ + auto it = queue.begin(); + auto end = queue.end(); + while (it != end) + { + std::shared_ptr &p = *it; + if (p->getPacketId() == packet_id) + { + size_t mem = p->getTotalMemoryFootprint(); + qosQueueBytes -= mem; + assert(qosQueueBytes >= 0); + if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. + qosQueueBytes = 0; + + queue.erase(it); + + break; + } + + it++; + } +} + +size_t QoSPacketQueue::size() const +{ + return queue.size(); +} + +size_t QoSPacketQueue::getByteSize() const +{ + return qosQueueBytes; +} + +/** + * @brief QoSPacketQueue::queuePacket makes a copy of the packet because it has state for the receiver in question. + * @param p + * @param id + * @return the packet copy. + */ +std::shared_ptr QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id) +{ + assert(p.getQos() > 0); + + std::shared_ptr copyPacket = p.getCopy(); + copyPacket->setPacketId(id); + queue.push_back(copyPacket); + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); + return copyPacket; +} + +std::shared_ptr QoSPacketQueue::queuePacket(const Publish &pub, uint16_t id) +{ + assert(pub.qos > 0); + + std::shared_ptr copyPacket(new MqttPacket(pub)); + copyPacket->setPacketId(id); + queue.push_back(copyPacket); + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); + return copyPacket; +} + +std::list>::const_iterator QoSPacketQueue::begin() const +{ + return queue.cbegin(); +} + +std::list>::const_iterator QoSPacketQueue::end() const +{ + return queue.cend(); +} diff --git a/qospacketqueue.h b/qospacketqueue.h new file mode 100644 index 0000000..f8b18d1 --- /dev/null +++ b/qospacketqueue.h @@ -0,0 +1,25 @@ +#ifndef QOSPACKETQUEUE_H +#define QOSPACKETQUEUE_H + +#include "list" + +#include "forward_declarations.h" +#include "types.h" + +class QoSPacketQueue +{ + std::list> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] + ssize_t qosQueueBytes = 0; + +public: + void erase(const uint16_t packet_id); + size_t size() const; + size_t getByteSize() const; + std::shared_ptr queuePacket(const MqttPacket &p, uint16_t id); + std::shared_ptr queuePacket(const Publish &pub, uint16_t id); + + std::list>::const_iterator begin() const; + std::list>::const_iterator end() const; +}; + +#endif // QOSPACKETQUEUE_H diff --git a/retainedmessage.cpp b/retainedmessage.cpp index 29647d1..c01f9e2 100644 --- a/retainedmessage.cpp +++ b/retainedmessage.cpp @@ -34,3 +34,8 @@ bool RetainedMessage::empty() const { return payload.empty(); } + +uint32_t RetainedMessage::getSize() const +{ + return topic.length() + payload.length() + 1; +} diff --git a/retainedmessage.h b/retainedmessage.h index 01491b0..97b5369 100644 --- a/retainedmessage.h +++ b/retainedmessage.h @@ -30,6 +30,7 @@ struct RetainedMessage bool operator==(const RetainedMessage &rhs) const; bool empty() const; + uint32_t getSize() const; }; namespace std { diff --git a/retainedmessagesdb.cpp b/retainedmessagesdb.cpp new file mode 100644 index 0000000..70ae9b0 --- /dev/null +++ b/retainedmessagesdb.cpp @@ -0,0 +1,146 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "retainedmessagesdb.h" +#include "utils.h" +#include "logger.h" + +RetainedMessagesDB::RetainedMessagesDB(const std::string &filePath) : PersistenceFile(filePath) +{ + +} + +void RetainedMessagesDB::openWrite() +{ + PersistenceFile::openWrite(MAGIC_STRING_V1); +} + +void RetainedMessagesDB::openRead() +{ + PersistenceFile::openRead(); + + if (detectedVersionString == MAGIC_STRING_V1) + readVersion = ReadVersion::v1; + else + throw std::runtime_error("Unknown file version."); +} + +/** + * @brief RetainedMessagesDB::writeRowHeader writes two 32 bit integers: topic size and payload size. + * @param rm + * + * So, the header per message is 8 bytes long. + * + * It writes no information about the length of the QoS value, because that is always one. + */ +void RetainedMessagesDB::writeRowHeader(const RetainedMessage &rm) +{ + writeUint32(rm.topic.size()); + writeUint32(rm.payload.size()); +} + +RetainedMessagesDB::RowHeader RetainedMessagesDB::readRowHeaderV1(bool &eofFound) +{ + RetainedMessagesDB::RowHeader result; + + result.topicLen = readUint32(eofFound); + result.payloadLen = readUint32(eofFound); + + return result; +} + +/** + * @brief RetainedMessagesDB::saveData doesn't explicitely name a file version (v1, etc), because we always write the current definition. + * @param messages + */ +void RetainedMessagesDB::saveData(const std::vector &messages) +{ + if (!f) + return; + + char reserved[RESERVED_SPACE_RETAINED_DB_V1]; + std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V1); + + char qos = 0; + for (const RetainedMessage &rm : messages) + { + logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d.", rm.topic.c_str(), rm.qos); + + writeRowHeader(rm); + qos = rm.qos; + writeCheck(&qos, 1, 1, f); + writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V1, f); + writeCheck(rm.topic.c_str(), 1, rm.topic.length(), f); + writeCheck(rm.payload.c_str(), 1, rm.payload.length(), f); + } + + fflush(f); +} + +std::list RetainedMessagesDB::readData() +{ + std::list defaultResult; + + if (!f) + return defaultResult; + + if (readVersion == ReadVersion::v1) + return readDataV1(); + + return defaultResult; +} + +std::list RetainedMessagesDB::readDataV1() +{ + std::list messages; + + while (!feof(f)) + { + bool eofFound = false; + RetainedMessagesDB::RowHeader header = readRowHeaderV1(eofFound); + + if (eofFound) + continue; + + makeSureBufSize(header.payloadLen); + + readCheck(buf.data(), 1, 1, f); + char qos = buf[0]; + fseek(f, RESERVED_SPACE_RETAINED_DB_V1, SEEK_CUR); + + readCheck(buf.data(), 1, header.topicLen, f); + std::string topic(buf.data(), header.topicLen); + + readCheck(buf.data(), 1, header.payloadLen, f); + std::string payload(buf.data(), header.payloadLen); + + RetainedMessage msg(topic, payload, qos); + logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.topic.c_str(), msg.qos); + messages.push_back(std::move(msg)); + } + + return messages; +} diff --git a/retainedmessagesdb.h b/retainedmessagesdb.h new file mode 100644 index 0000000..7b2299d --- /dev/null +++ b/retainedmessagesdb.h @@ -0,0 +1,71 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#ifndef RETAINEDMESSAGESDB_H +#define RETAINEDMESSAGESDB_H + +#include "persistencefile.h" +#include "retainedmessage.h" + +#include "logger.h" + +#define MAGIC_STRING_V1 "FlashMQRetainedDBv1" +#define ROW_HEADER_SIZE 8 +#define RESERVED_SPACE_RETAINED_DB_V1 31 + +/** + * @brief The RetainedMessagesDB class saves and loads the retained messages. + * + * The DB looks like, from the top: + * + * MAGIC_STRING_LENGH bytes file header + * HASH_SIZE SHA512 + * [MESSAGES] + * + * Each message has a row header, which is 8 bytes. See writeRowHeader(). + * + */ +class RetainedMessagesDB : public PersistenceFile +{ + enum class ReadVersion + { + unknown, + v1 + }; + + struct RowHeader + { + uint32_t topicLen = 0; + uint32_t payloadLen = 0; + }; + + ReadVersion readVersion = ReadVersion::unknown; + + void writeRowHeader(const RetainedMessage &rm); + RowHeader readRowHeaderV1(bool &eofFound); + std::list readDataV1(); +public: + RetainedMessagesDB(const std::string &filePath); + + void openWrite(); + void openRead(); + + void saveData(const std::vector &messages); + std::list readData(); +}; + +#endif // RETAINEDMESSAGESDB_H diff --git a/rwlockguard.cpp b/rwlockguard.cpp index 15bae1c..c25c324 100644 --- a/rwlockguard.cpp +++ b/rwlockguard.cpp @@ -32,7 +32,15 @@ RWLockGuard::~RWLockGuard() void RWLockGuard::wrlock() { - if (pthread_rwlock_wrlock(rwlock) != 0) + const int rc = pthread_rwlock_wrlock(rwlock); + + if (rc == EDEADLK) + { + rwlock = nullptr; + return; + } + + if (rc != 0) throw std::runtime_error("wrlock failed."); } diff --git a/session.cpp b/session.cpp index 3b328f8..e0fc9fe 100644 --- a/session.cpp +++ b/session.cpp @@ -20,16 +20,75 @@ License along with FlashMQ. If not, see . #include "session.h" #include "client.h" +std::chrono::time_point appStartTime = std::chrono::steady_clock::now(); + Session::Session() { } +int64_t Session::getProgramStartedAtUnixTimestamp() +{ + auto secondsSinceEpoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + const std::chrono::seconds age = std::chrono::duration_cast(std::chrono::steady_clock::now() - appStartTime); + int64_t result = secondsSinceEpoch - age.count(); + return result; +} + +void Session::setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp) +{ + auto secondsSinceEpoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()); + const std::chrono::seconds _unix_timestamp = std::chrono::seconds(unix_timestamp); + const std::chrono::seconds age_in_s = secondsSinceEpoch - _unix_timestamp; + appStartTime = std::chrono::steady_clock::now() - age_in_s; +} + + +int64_t Session::getSessionRelativeAgeInMs() const +{ + const std::chrono::milliseconds sessionAge = std::chrono::duration_cast(lastTouched - appStartTime); + const int64_t sInMs = sessionAge.count(); + return sInMs; +} + +void Session::setSessionTouch(int64_t ageInMs) +{ + std::chrono::milliseconds ms(ageInMs); + std::chrono::time_point point = appStartTime + ms; + lastTouched = point; +} + +/** + * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies. + * @param other + * + * Because it was created for session storing, the fields we're copying are the fields being stored. + */ +Session::Session(const Session &other) +{ + this->username = other.username; + this->client_id = other.client_id; + this->incomingQoS2MessageIds = other.incomingQoS2MessageIds; + this->outgoingQoS2MessageIds = other.outgoingQoS2MessageIds; + this->nextPacketId = other.nextPacketId; + this->lastTouched = other.lastTouched; + + // To be fully correct, we should copy the individual packets, but copying sessions is only done for saving them, and I know + // that no member of MqttPacket changes in the QoS process, so we can just keep the shared pointer to the original. + this->qosPacketQueue = other.qosPacketQueue; +} + Session::~Session() { logger->logf(LOG_DEBUG, "Session %s is being destroyed.", getClientId().c_str()); } +std::unique_ptr Session::getCopy() const +{ + std::unique_ptr s(new Session(*this)); + return s; +} + bool Session::clientDisconnected() const { return client.expired(); @@ -45,7 +104,6 @@ void Session::assignActiveConnection(std::shared_ptr &client) this->client = client; this->client_id = client->getClientId(); this->username = client->getUsername(); - this->thread = client->getThreadData(); } /** @@ -60,7 +118,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u assert(max_qos <= 2); const char qos = std::min(packet.getQos(), max_qos); - if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) + assert(packet.getSender()); + Authentication &auth = packet.getSender()->getThreadData()->authentication; + if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) { if (qos == 0) { @@ -73,11 +133,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u } else if (qos > 0) { - std::shared_ptr copyPacket = packet.getCopy(); std::unique_lock locker(qosQueueMutex); const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); - if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) { logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); return; @@ -86,13 +145,7 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u if (nextPacketId == 0) nextPacketId++; - const uint16_t pid = nextPacketId; - copyPacket->setPacketId(pid); - QueuedQosPacket p; - p.packet = copyPacket; - p.id = pid; - qosPacketQueue.push_back(p); - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); + std::shared_ptr copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); locker.unlock(); if (!clientDisconnected()) @@ -115,27 +168,7 @@ void Session::clearQosMessage(uint16_t packet_id) #endif std::lock_guard locker(qosQueueMutex); - - auto it = qosPacketQueue.begin(); - auto end = qosPacketQueue.end(); - while (it != end) - { - QueuedQosPacket &p = *it; - if (p.id == packet_id) - { - size_t mem = p.packet->getTotalMemoryFootprint(); - qosQueueBytes -= mem; - assert(qosQueueBytes >= 0); - if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. - qosQueueBytes = 0; - - qosPacketQueue.erase(it); - - break; - } - - it++; - } + qosPacketQueue.erase(packet_id); } // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any @@ -152,10 +185,10 @@ uint64_t Session::sendPendingQosMessages() { std::shared_ptr c = makeSharedClient(); std::lock_guard locker(qosQueueMutex); - for (QueuedQosPacket &qosMessage : qosPacketQueue) + for (const std::shared_ptr &qosMessage : qosPacketQueue) { - c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos()); - qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos()); + qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. count++; } diff --git a/session.h b/session.h index c0faad3..ec2770e 100644 --- a/session.h +++ b/session.h @@ -25,38 +25,46 @@ License along with FlashMQ. If not, see . #include "forward_declarations.h" #include "logger.h" +#include "sessionsandsubscriptionsdb.h" +#include "qospacketqueue.h" // TODO make settings. But, num of packets can't exceed 65536, because the counter is 16 bit. #define MAX_QOS_MSG_PENDING_PER_CLIENT 32 #define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096 -struct QueuedQosPacket -{ - uint16_t id; - std::shared_ptr packet; -}; - class Session { +#ifdef TESTING + friend class MainTests; +#endif + + friend class SessionsAndSubscriptionsDB; + std::weak_ptr client; - std::shared_ptr thread; std::string client_id; std::string username; - std::list qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] + QoSPacketQueue qosPacketQueue; std::set incomingQoS2MessageIds; std::set outgoingQoS2MessageIds; std::mutex qosQueueMutex; uint16_t nextPacketId = 0; - ssize_t qosQueueBytes = 0; - std::chrono::time_point lastTouched; + std::chrono::time_point lastTouched = std::chrono::steady_clock::now(); Logger *logger = Logger::getInstance(); + int64_t getSessionRelativeAgeInMs() const; + void setSessionTouch(int64_t ageInMs); + Session(const Session &other); public: Session(); - Session(const Session &other) = delete; + Session(Session &&other) = delete; ~Session(); + static int64_t getProgramStartedAtUnixTimestamp(); + static void setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp); + + std::unique_ptr getCopy() const; + const std::string &getClientId() const { return client_id; } bool clientDisconnected() const; std::shared_ptr makeSharedClient() const; @@ -74,7 +82,6 @@ public: void addOutgoingQoS2MessageId(uint16_t packet_id); void removeOutgoingQoS2MessageId(u_int16_t packet_id); - }; #endif // SESSION_H diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp new file mode 100644 index 0000000..b9b2bf5 --- /dev/null +++ b/sessionsandsubscriptionsdb.cpp @@ -0,0 +1,299 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#include "sessionsandsubscriptionsdb.h" +#include "mqttpacket.h" + +#include "cassert" + +SubscriptionForSerializing::SubscriptionForSerializing(const std::string &clientId, char qos) : + clientId(clientId), + qos(qos) +{ + +} + +SubscriptionForSerializing::SubscriptionForSerializing(const std::string &&clientId, char qos) : + clientId(clientId), + qos(qos) +{ + +} + +SessionsAndSubscriptionsDB::SessionsAndSubscriptionsDB(const std::string &filePath) : PersistenceFile(filePath) +{ + +} + +void SessionsAndSubscriptionsDB::openWrite() +{ + PersistenceFile::openWrite(MAGIC_STRING_SESSION_FILE_V1); +} + +void SessionsAndSubscriptionsDB::openRead() +{ + PersistenceFile::openRead(); + + if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V1) + readVersion = ReadVersion::v1; + else + throw std::runtime_error("Unknown file version."); +} + +SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1() +{ + SessionsAndSubscriptionsResult result; + + while (!feof(f)) + { + bool eofFound = false; + + const int64_t programStartAge = readInt64(eofFound); + if (eofFound) + continue; + + logger->logf(LOG_DEBUG, "Setting first app start time to timestamp %ld", programStartAge); + Session::setProgramStartedAtUnixTimestamp(programStartAge); + + const uint32_t nrOfSessions = readUint32(eofFound); + + if (eofFound) + continue; + + std::vector reserved(RESERVED_SPACE_SESSIONS_DB_V1); + + for (uint32_t i = 0; i < nrOfSessions; i++) + { + readCheck(buf.data(), 1, RESERVED_SPACE_SESSIONS_DB_V1, f); + + uint32_t usernameLength = readUint32(eofFound); + readCheck(buf.data(), 1, usernameLength, f); + std::string username(buf.data(), usernameLength); + + uint32_t clientIdLength = readUint32(eofFound); + readCheck(buf.data(), 1, clientIdLength, f); + std::string clientId(buf.data(), clientIdLength); + + std::shared_ptr ses(new Session()); + result.sessions.push_back(ses); + ses->username = username; + ses->client_id = clientId; + + logger->logf(LOG_DEBUG, "Loading session '%s'.", ses->getClientId().c_str()); + + const uint32_t nrOfQueuedQoSPackets = readUint32(eofFound); + for (uint32_t i = 0; i < nrOfQueuedQoSPackets; i++) + { + const uint16_t id = readUint16(eofFound); + const uint32_t topicSize = readUint32(eofFound); + const uint32_t payloadSize = readUint32(eofFound); + + assert(id > 0); + + readCheck(buf.data(), 1, 1, f); + const unsigned char qos = buf[0]; + + readCheck(buf.data(), 1, topicSize, f); + const std::string topic(buf.data(), topicSize); + + makeSureBufSize(payloadSize); + readCheck(buf.data(), 1, payloadSize, f); + const std::string payload(buf.data(), payloadSize); + + Publish pub(topic, payload, qos); + logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); + ses->qosPacketQueue.queuePacket(pub, id); + } + + const uint32_t nrOfIncomingPacketIds = readUint32(eofFound); + for (uint32_t i = 0; i < nrOfIncomingPacketIds; i++) + { + uint16_t id = readUint16(eofFound); + assert(id > 0); + logger->logf(LOG_DEBUG, "Loaded incomming QoS2 message id %d.", id); + ses->incomingQoS2MessageIds.insert(id); + } + + const uint32_t nrOfOutgoingPacketIds = readUint32(eofFound); + for (uint32_t i = 0; i < nrOfOutgoingPacketIds; i++) + { + uint16_t id = readUint16(eofFound); + assert(id > 0); + logger->logf(LOG_DEBUG, "Loaded outgoing QoS2 message id %d.", id); + ses->outgoingQoS2MessageIds.insert(id); + } + + const uint16_t nextPacketId = readUint16(eofFound); + logger->logf(LOG_DEBUG, "Loaded next packetid %d.", ses->nextPacketId); + ses->nextPacketId = nextPacketId; + + int64_t sessionAge = readInt64(eofFound); + logger->logf(LOG_DEBUG, "Loaded session age: %ld ms.", sessionAge); + ses->setSessionTouch(sessionAge); + } + + const uint32_t nrOfSubscriptions = readUint32(eofFound); + for (uint32_t i = 0; i < nrOfSubscriptions; i++) + { + const uint32_t topicLength = readUint32(eofFound); + readCheck(buf.data(), 1, topicLength, f); + const std::string topic(buf.data(), topicLength); + + logger->logf(LOG_DEBUG, "Loading subscriptions to topic '%s'.", topic.c_str()); + + const uint32_t nrOfClientIds = readUint32(eofFound); + + for (uint32_t i = 0; i < nrOfClientIds; i++) + { + const uint32_t clientIdLength = readUint32(eofFound); + readCheck(buf.data(), 1, clientIdLength, f); + const std::string clientId(buf.data(), clientIdLength); + + char qos; + readCheck(&qos, 1, 1, f); + + logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", clientId.c_str(), topic.c_str(), qos); + + SubscriptionForSerializing sub(std::move(clientId), qos); + result.subscriptions[topic].push_back(std::move(sub)); + } + + } + } + + return result; +} + +void SessionsAndSubscriptionsDB::writeRowHeader() +{ + +} + +void SessionsAndSubscriptionsDB::saveData(const std::list> &sessions, const std::unordered_map> &subscriptions) +{ + if (!f) + return; + + char reserved[RESERVED_SPACE_SESSIONS_DB_V1]; + std::memset(reserved, 0, RESERVED_SPACE_SESSIONS_DB_V1); + + const int64_t start_stamp = Session::getProgramStartedAtUnixTimestamp(); + logger->logf(LOG_DEBUG, "Saving program first start time stamp as %ld", start_stamp); + writeInt64(start_stamp); + + writeUint32(sessions.size()); + + for (const std::unique_ptr &ses : sessions) + { + logger->logf(LOG_DEBUG, "Saving session '%s'.", ses->getClientId().c_str()); + + writeRowHeader(); + + writeCheck(reserved, 1, RESERVED_SPACE_SESSIONS_DB_V1, f); + + writeUint32(ses->username.length()); + writeCheck(ses->username.c_str(), 1, ses->username.length(), f); + + writeUint32(ses->client_id.length()); + writeCheck(ses->client_id.c_str(), 1, ses->client_id.length(), f); + + const size_t qosPacketsExpected = ses->qosPacketQueue.size(); + size_t qosPacketsCounted = 0; + writeUint32(qosPacketsExpected); + + for (const std::shared_ptr &p: ses->qosPacketQueue) + { + logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", p->getQos(), p->getTopic().c_str()); + + qosPacketsCounted++; + + writeUint16(p->getPacketId()); + + writeUint32(p->getTopic().length()); + std::string payload = p->getPayloadCopy(); + writeUint32(payload.size()); + + const char qos = p->getQos(); + writeCheck(&qos, 1, 1, f); + + writeCheck(p->getTopic().c_str(), 1, p->getTopic().length(), f); + writeCheck(payload.c_str(), 1, payload.length(), f); + } + + assert(qosPacketsExpected == qosPacketsCounted); + + writeUint32(ses->incomingQoS2MessageIds.size()); + for (uint16_t id : ses->incomingQoS2MessageIds) + { + logger->logf(LOG_DEBUG, "Writing incomming QoS2 message id %d.", id); + writeUint16(id); + } + + writeUint32(ses->outgoingQoS2MessageIds.size()); + for (uint16_t id : ses->outgoingQoS2MessageIds) + { + logger->logf(LOG_DEBUG, "Writing outgoing QoS2 message id %d.", id); + writeUint16(id); + } + + logger->logf(LOG_DEBUG, "Writing next packetid %d.", ses->nextPacketId); + writeUint16(ses->nextPacketId); + + const int64_t sInMs = ses->getSessionRelativeAgeInMs(); + logger->logf(LOG_DEBUG, "Writing session age: %ld ms.", sInMs); + writeInt64(sInMs); + } + + writeUint32(subscriptions.size()); + + for (auto &pair : subscriptions) + { + const std::string &topic = pair.first; + const std::list &subscriptions = pair.second; + + logger->logf(LOG_DEBUG, "Writing subscriptions to topic '%s'.", topic.c_str()); + + writeUint32(topic.size()); + writeCheck(topic.c_str(), 1, topic.size(), f); + + writeUint32(subscriptions.size()); + + for (const SubscriptionForSerializing &subscription : subscriptions) + { + logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", subscription.clientId.c_str(), topic.c_str(), subscription.qos); + + writeUint32(subscription.clientId.size()); + writeCheck(subscription.clientId.c_str(), 1, subscription.clientId.size(), f); + writeCheck(&subscription.qos, 1, 1, f); + } + } + + fflush(f); +} + +SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readData() +{ + SessionsAndSubscriptionsResult defaultResult; + + if (!f) + return defaultResult; + + if (readVersion == ReadVersion::v1) + return readDataV1(); + + return defaultResult; +} diff --git a/sessionsandsubscriptionsdb.h b/sessionsandsubscriptionsdb.h new file mode 100644 index 0000000..15a242e --- /dev/null +++ b/sessionsandsubscriptionsdb.h @@ -0,0 +1,71 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#ifndef SESSIONSANDSUBSCRIPTIONSDB_H +#define SESSIONSANDSUBSCRIPTIONSDB_H + +#include +#include + +#include "persistencefile.h" +#include "session.h" + +#define MAGIC_STRING_SESSION_FILE_V1 "FlashMQRetainedDBv1" +#define RESERVED_SPACE_SESSIONS_DB_V1 32 + +/** + * @brief The SubscriptionForSerializing struct contains the fields we're interested in when saving a subscription. + */ +struct SubscriptionForSerializing +{ + const std::string clientId; + const char qos = 0; + + SubscriptionForSerializing(const std::string &clientId, char qos); + SubscriptionForSerializing(const std::string &&clientId, char qos); +}; + +struct SessionsAndSubscriptionsResult +{ + std::list> sessions; + std::unordered_map> subscriptions; +}; + + +class SessionsAndSubscriptionsDB : public PersistenceFile +{ + enum class ReadVersion + { + unknown, + v1 + }; + + ReadVersion readVersion = ReadVersion::unknown; + + SessionsAndSubscriptionsResult readDataV1(); + void writeRowHeader(); +public: + SessionsAndSubscriptionsDB(const std::string &filePath); + + void openWrite(); + void openRead(); + + void saveData(const std::list> &sessions, const std::unordered_map> &subscriptions); + SessionsAndSubscriptionsResult readData(); +}; + +#endif // SESSIONSANDSUBSCRIPTIONSDB_H diff --git a/settings.cpp b/settings.cpp index 49b175e..629ea52 100644 --- a/settings.cpp +++ b/settings.cpp @@ -16,7 +16,7 @@ License along with FlashMQ. If not, see . */ #include "settings.h" - +#include "utils.h" AuthOptCompatWrap &Settings::getAuthOptsCompat() { @@ -27,3 +27,21 @@ std::unordered_map &Settings::getFlashmqAuthPluginOpts { return this->flashmqAuthPluginOpts; } + +std::string Settings::getRetainedMessagesDBFile() const +{ + if (storageDir.empty()) + return ""; + + std::string path = formatString("%s/%s", storageDir.c_str(), "retained.db"); + return path; +} + +std::string Settings::getSessionsDBFile() const +{ + if (storageDir.empty()) + return ""; + + std::string path = formatString("%s/%s", storageDir.c_str(), "sessions.db"); + return path; +} diff --git a/settings.h b/settings.h index d79e486..ecc627a 100644 --- a/settings.h +++ b/settings.h @@ -53,10 +53,14 @@ public: int rlimitNoFile = 1000000; uint64_t expireSessionsAfterSeconds = 1209600; int authPluginTimerPeriod = 60; + std::string storageDir; std::list> listeners; // Default one is created later, when none are defined. AuthOptCompatWrap &getAuthOptsCompat(); std::unordered_map &getFlashmqAuthPluginOpts(); + + std::string getRetainedMessagesDBFile() const; + std::string getSessionsDBFile() const; }; #endif // SETTINGS_H diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index e216228..6628ebb 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -20,7 +20,7 @@ License along with FlashMQ. If not, see . #include "cassert" #include "rwlockguard.h" - +#include "retainedmessagesdb.h" SubscriptionNode::SubscriptionNode(const std::string &subtopic) : subtopic(subtopic) @@ -33,6 +33,11 @@ std::vector &SubscriptionNode::getSubscribers() return subscribers; } +const std::string &SubscriptionNode::getSubtopic() const +{ + return subtopic; +} + void SubscriptionNode::addSubscriber(const std::shared_ptr &subscriber, char qos) { Subscription sub; @@ -84,15 +89,20 @@ SubscriptionStore::SubscriptionStore() : } -void SubscriptionStore::addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos) +/** + * @brief SubscriptionStore::getDeepestNode gets the node in the tree walking the path of 'the/subscription/topic/path', making new nodes as required. + * @param topic + * @param subtopics + * @return + * + * caller is responsible for locking. + */ +SubscriptionNode *SubscriptionStore::getDeepestNode(const std::string &topic, const std::vector &subtopics) { SubscriptionNode *deepestNode = &root; if (topic.length() > 0 && topic[0] == '$') deepestNode = &rootDollar; - RWLockGuard lock_guard(&subscriptionsRwlock); - lock_guard.wrlock(); - for(const std::string &subtopic : subtopics) { std::unique_ptr *selectedChildren = nullptr; @@ -114,6 +124,15 @@ void SubscriptionStore::addSubscription(std::shared_ptr &client, const s } assert(deepestNode); + return deepestNode; +} + +void SubscriptionStore::addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos) +{ + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.wrlock(); + + SubscriptionNode *deepestNode = getDeepestNode(topic, subtopics); if (deepestNode) { @@ -494,6 +513,183 @@ int64_t SubscriptionStore::getRetainedMessageCount() const return retainedMessageCount; } +void SubscriptionStore::getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const +{ + for(const RetainedMessage &rm : this_node->retainedMessages) + { + outputList.push_back(rm); + } + + for(auto &pair : this_node->children) + { + const std::unique_ptr &child = pair.second; + getRetainedMessages(child.get(), outputList); + } +} + +/** + * @brief SubscriptionStore::getSubscriptions + * @param this_node + * @param composedTopic + * @param root bool. Every subtopic is concatenated with a '/', but not the first topic to 'root'. The root is a bit weird, virtual, so it needs different treatment. + * @param outputList + */ +void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, + std::unordered_map> &outputList) const +{ + for (const Subscription &node : this_node->getSubscribers()) + { + if (!node.sessionGone()) + { + SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos); + outputList[composedTopic].push_back(sub); + } + } + + for (auto &pair : this_node->children) + { + SubscriptionNode *node = pair.second.get(); + const std::string topicAtNextLevel = root ? pair.first : composedTopic + "/" + pair.first; + getSubscriptions(node, topicAtNextLevel, false, outputList); + } + + if (this_node->childrenPlus) + { + const std::string topicAtNextLevel = root ? "+" : composedTopic + "/+"; + getSubscriptions(this_node->childrenPlus.get(), topicAtNextLevel, false, outputList); + } + + if (this_node->childrenPound) + { + const std::string topicAtNextLevel = root ? "#" : composedTopic + "/#"; + getSubscriptions(this_node->childrenPound.get(), topicAtNextLevel, false, outputList); + } +} + +void SubscriptionStore::saveRetainedMessages(const std::string &filePath) +{ + logger->logf(LOG_INFO, "Saving retained messages to '%s'", filePath.c_str()); + + std::vector result; + result.reserve(retainedMessageCount); + + // Create the list of messages under lock, and unlock right after. + RWLockGuard locker(&retainedMessagesRwlock); + locker.rdlock(); + getRetainedMessages(&retainedMessagesRoot, result); + locker.unlock(); + + logger->logf(LOG_DEBUG, "Collected %ld retained messages to save.", result.size()); + + // Then do the IO without locking the threads. + RetainedMessagesDB db(filePath); + db.openWrite(); + db.saveData(result); +} + +void SubscriptionStore::loadRetainedMessages(const std::string &filePath) +{ + try + { + logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str()); + + RetainedMessagesDB db(filePath); + db.openRead(); + std::list messages = db.readData(); + + RWLockGuard locker(&retainedMessagesRwlock); + locker.wrlock(); + + std::vector subtopics; + for (const RetainedMessage &rm : messages) + { + splitTopic(rm.topic, subtopics); + setRetainedMessage(rm.topic, subtopics, rm.payload, rm.qos); + } + } + catch (PersistenceFileCantBeOpened &ex) + { + logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str()); + } +} + +void SubscriptionStore::saveSessionsAndSubscriptions(const std::string &filePath) +{ + logger->logf(LOG_INFO, "Saving sessions and subscriptions to '%s'", filePath.c_str()); + + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.wrlock(); + + // First copy the sessions... + + std::list> sessionCopies; + + for (const auto &pair : sessionsByIdConst) + { + const Session &org = *pair.second.get(); + sessionCopies.push_back(org.getCopy()); + } + + std::unordered_map> subscriptionCopies; + getSubscriptions(&root, "", true, subscriptionCopies); + + lock_guard.unlock(); + + // Then write the copies to disk, after having released the lock + + logger->logf(LOG_DEBUG, "Collected %ld sessions and %ld subscriptions to save.", sessionCopies.size(), subscriptionCopies.size()); + + SessionsAndSubscriptionsDB db(filePath); + db.openWrite(); + db.saveData(sessionCopies, subscriptionCopies); +} + +void SubscriptionStore::loadSessionsAndSubscriptions(const std::string &filePath) +{ + try + { + logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str()); + + SessionsAndSubscriptionsDB db(filePath); + db.openRead(); + SessionsAndSubscriptionsResult loadedData = db.readData(); + + RWLockGuard locker(&subscriptionsRwlock); + locker.wrlock(); + + for (std::shared_ptr &session : loadedData.sessions) + { + sessionsById[session->getClientId()] = session; + } + + std::vector subtopics; + + for (auto &pair : loadedData.subscriptions) + { + const std::string &topic = pair.first; + const std::list &subs = pair.second; + + for (const SubscriptionForSerializing &sub : subs) + { + splitTopic(topic, subtopics); + SubscriptionNode *subscriptionNode = getDeepestNode(topic, subtopics); + + auto session_it = sessionsByIdConst.find(sub.clientId); + if (session_it != sessionsByIdConst.end()) + { + const std::shared_ptr &ses = session_it->second; + subscriptionNode->addSubscriber(ses, sub.qos); + } + + } + } + } + catch (PersistenceFileCantBeOpened &ex) + { + logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str()); + } +} + // QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The // specs don't specify what to do there. bool Subscription::operator==(const Subscription &rhs) const diff --git a/subscriptionstore.h b/subscriptionstore.h index 35416b4..93e3dc6 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -53,6 +53,7 @@ public: SubscriptionNode(SubscriptionNode &&node) = delete; std::vector &getSubscribers(); + const std::string &getSubtopic() const; void addSubscriber(const std::shared_ptr &subscriber, char qos); void removeSubscriber(const std::shared_ptr &subscriber); std::unordered_map> children; @@ -77,6 +78,10 @@ class RetainedMessageNode class SubscriptionStore { +#ifdef TESTING + friend class MainTests; +#endif + SubscriptionNode root; SubscriptionNode rootDollar; pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER; @@ -93,7 +98,11 @@ class SubscriptionStore void publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers, uint64_t &count) const; void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const; + void getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const; + void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, + std::unordered_map> &outputList) const; + SubscriptionNode *getDeepestNode(const std::string &topic, const std::vector &subtopics); public: SubscriptionStore(); @@ -103,7 +112,6 @@ public: bool sessionPresent(const std::string &clientid); void queuePacketAtSubscribers(const std::vector &subtopics, const MqttPacket &packet, bool dollar = false); - uint64_t giveClientRetainedMessages(const std::shared_ptr &ses, const std::string &subscribe_topic, char max_qos); void giveClientRetainedMessagesRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, RetainedMessageNode *this_node, char max_qos, const std::shared_ptr &ses, bool poundMode, uint64_t &count) const; @@ -114,6 +122,12 @@ public: void removeExpiredSessionsClients(int expireSessionsAfterSeconds); int64_t getRetainedMessageCount() const; + + void saveRetainedMessages(const std::string &filePath); + void loadRetainedMessages(const std::string &filePath); + + void saveSessionsAndSubscriptions(const std::string &filePath); + void loadSessionsAndSubscriptions(const std::string &filePath); }; #endif // SUBSCRIPTIONSTORE_H diff --git a/utils.cpp b/utils.cpp index 79f78fc..575d328 100644 --- a/utils.cpp +++ b/utils.cpp @@ -289,6 +289,14 @@ void trim(std::string &s) rtrim(s); } +std::string &rtrim(std::string &s, unsigned char c) +{ + s.erase(std::find_if(s.rbegin(), s.rend(), [=](unsigned char ch) { + return (c != ch); + }).base(), s.end()); + return s; +} + bool startsWith(const std::string &s, const std::string &needle) { return s.find(needle) == 0; diff --git a/utils.h b/utils.h index 0a83fa3..70c8352 100644 --- a/utils.h +++ b/utils.h @@ -28,6 +28,8 @@ License along with FlashMQ. If not, see . #include #include #include +#include "unistd.h" +#include "sys/stat.h" #include "cirbuf.h" #include "bindaddr.h" @@ -64,6 +66,7 @@ void ltrim(std::string &s); void rtrim(std::string &s); void trim(std::string &s); bool startsWith(const std::string &s, const std::string &needle); +std::string &rtrim(std::string &s, unsigned char c); std::string getSecureRandomString(const ssize_t len); std::string str_tolower(std::string s); @@ -92,5 +95,32 @@ ssize_t getFileSize(const std::string &path); std::string sockaddrToString(struct sockaddr *addr); +template void checkWritableDir(const std::string &path) +{ + if (path.empty()) + throw ex("Dir path to check is an empty string."); + + if (access(path.c_str(), W_OK) != 0) + { + std::string msg = formatString("Path '%s' is not there or not writable", path.c_str()); + throw ex(msg); + } + + struct stat statbuf; + memset(&statbuf, 0, sizeof(struct stat)); + if (stat(path.c_str(), &statbuf) < 0) + { + // We checked for W_OK above, so this shouldn't happen. + std::string msg = formatString("Error getting information about '%s'.", path.c_str()); + throw ex(msg); + } + + if (!S_ISDIR(statbuf.st_mode)) + { + std::string msg = formatString("Path '%s' is not a directory.", path.c_str()); + throw ex(msg); + } +} + #endif // UTILS_H -- libgit2 0.21.4