diff --git a/CMakeLists.txt b/CMakeLists.txt
index 759c073..3657202 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -50,6 +50,10 @@ add_executable(FlashMQ
enums.h
threadlocalutils.h
flashmq_plugin.h
+ retainedmessagesdb.h
+ persistencefile.h
+ sessionsandsubscriptionsdb.h
+ qospacketqueue.h
mainapp.cpp
main.cpp
@@ -81,6 +85,10 @@ add_executable(FlashMQ
acltree.cpp
threadlocalutils.cpp
flashmq_plugin.cpp
+ retainedmessagesdb.cpp
+ persistencefile.cpp
+ sessionsandsubscriptionsdb.cpp
+ qospacketqueue.cpp
)
diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro
index 1c48319..450323e 100644
--- a/FlashMQTests/FlashMQTests.pro
+++ b/FlashMQTests/FlashMQTests.pro
@@ -42,6 +42,10 @@ SOURCES += tst_maintests.cpp \
../acltree.cpp \
../threadlocalutils.cpp \
../flashmq_plugin.cpp \
+ ../retainedmessagesdb.cpp \
+ ../persistencefile.cpp \
+ ../sessionsandsubscriptionsdb.cpp \
+ ../qospacketqueue.cpp \
mainappthread.cpp \
twoclienttestcontext.cpp
@@ -77,6 +81,10 @@ HEADERS += \
../acltree.h \
../threadlocalutils.h \
../flashmq_plugin.h \
+ ../retainedmessagesdb.h \
+ ../persistencefile.h \
+ ../sessionsandsubscriptionsdb.h \
+ ../qospacketqueue.h \
mainappthread.h \
twoclienttestcontext.h
diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp
index 14294a9..ef1866a 100644
--- a/FlashMQTests/tst_maintests.cpp
+++ b/FlashMQTests/tst_maintests.cpp
@@ -21,12 +21,18 @@ License along with FlashMQ. If not, see .
#include
#include
#include
+#include
+#include
#include "cirbuf.h"
#include "mainapp.h"
#include "mainappthread.h"
#include "twoclienttestcontext.h"
#include "threadlocalutils.h"
+#include "retainedmessagesdb.h"
+#include "sessionsandsubscriptionsdb.h"
+#include "session.h"
+#include "threaddata.h"
// Dumb Qt version gives warnings when comparing uint with number literal.
template
@@ -88,6 +94,12 @@ private slots:
void testTopicsMatch();
+ void testRetainedMessageDB();
+ void testRetainedMessageDBNotPresent();
+ void testRetainedMessageDBEmptyList();
+
+ void testSavingSessions();
+
};
MainTests::MainTests()
@@ -822,6 +834,221 @@ void MainTests::testTopicsMatch()
}
+void MainTests::testRetainedMessageDB()
+{
+ try
+ {
+ std::string longpayload = getSecureRandomString(65537);
+ std::string longTopic = formatString("one/two/%s", getSecureRandomString(4000).c_str());
+
+ std::vector messages;
+ messages.emplace_back("one/two/three", "payload", 0);
+ messages.emplace_back("one/two/wer", "payload", 1);
+ messages.emplace_back("one/e/wer", "payload", 1);
+ messages.emplace_back("one/wee/wer", "asdfasdfasdf", 1);
+ messages.emplace_back("one/two/wer", "µsdf", 1);
+ messages.emplace_back("/boe/bah", longpayload, 1);
+ messages.emplace_back("one/two/wer", "paylasdfaoad", 1);
+ messages.emplace_back("one/two/wer", "payload", 1);
+ messages.emplace_back(longTopic, "payload", 1);
+ messages.emplace_back(longTopic, longpayload, 1);
+ messages.emplace_back("one", "µsdf", 1);
+ messages.emplace_back("/boe", longpayload, 1);
+ messages.emplace_back("one", "µsdf", 1);
+ messages.emplace_back("", "foremptytopic", 0);
+
+ RetainedMessagesDB db("/tmp/flashmqtests_retained.db");
+ db.openWrite();
+ db.saveData(messages);
+ db.closeFile();
+
+ RetainedMessagesDB db2("/tmp/flashmqtests_retained.db");
+ db2.openRead();
+ std::list messagesLoaded = db2.readData();
+ db2.closeFile();
+
+ QCOMPARE(messages.size(), messagesLoaded.size());
+
+ auto itOrg = messages.begin();
+ auto itLoaded = messagesLoaded.begin();
+ while (itOrg != messages.end() && itLoaded != messagesLoaded.end())
+ {
+ RetainedMessage &one = *itOrg;
+ RetainedMessage &two = *itLoaded;
+
+ // Comparing the fields because the RetainedMessage class has an == operator that only looks at topic.
+ QCOMPARE(one.topic, two.topic);
+ QCOMPARE(one.payload, two.payload);
+ QCOMPARE(one.qos, two.qos);
+
+ itOrg++;
+ itLoaded++;
+ }
+ }
+ catch (std::exception &ex)
+ {
+ QVERIFY2(false, ex.what());
+ }
+}
+
+void MainTests::testRetainedMessageDBNotPresent()
+{
+ try
+ {
+ RetainedMessagesDB db2("/tmp/flashmqtests_asdfasdfasdf.db");
+ db2.openRead();
+ std::list messagesLoaded = db2.readData();
+ db2.closeFile();
+
+ MYCASTCOMPARE(messagesLoaded.size(), 0);
+
+ QVERIFY2(false, "We should have run into an exception.");
+ }
+ catch (PersistenceFileCantBeOpened &ex)
+ {
+ QVERIFY(true);
+ }
+ catch (std::exception &ex)
+ {
+ QVERIFY2(false, ex.what());
+ }
+}
+
+void MainTests::testRetainedMessageDBEmptyList()
+{
+ try
+ {
+ std::vector messages;
+
+ RetainedMessagesDB db("/tmp/flashmqtests_retained.db");
+ db.openWrite();
+ db.saveData(messages);
+ db.closeFile();
+
+ RetainedMessagesDB db2("/tmp/flashmqtests_retained.db");
+ db2.openRead();
+ std::list messagesLoaded = db2.readData();
+ db2.closeFile();
+
+ MYCASTCOMPARE(messages.size(), messagesLoaded.size());
+ MYCASTCOMPARE(messages.size(), 0);
+ }
+ catch (std::exception &ex)
+ {
+ QVERIFY2(false, ex.what());
+ }
+}
+
+void MainTests::testSavingSessions()
+{
+ try
+ {
+ std::shared_ptr settings(new Settings());
+ std::shared_ptr store(new SubscriptionStore());
+ std::shared_ptr t(new ThreadData(0, store, settings));
+
+ std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings, false));
+ c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false);
+ store->registerClientAndKickExistingOne(c1);
+ c1->getSession()->touch();
+ c1->getSession()->addIncomingQoS2MessageId(2);
+ c1->getSession()->addIncomingQoS2MessageId(3);
+
+ std::shared_ptr c2(new Client(0, t, nullptr, false, nullptr, settings, false));
+ c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60, false);
+ store->registerClientAndKickExistingOne(c2);
+ c2->getSession()->touch();
+ c2->getSession()->addOutgoingQoS2MessageId(55);
+ c2->getSession()->addOutgoingQoS2MessageId(66);
+
+ const std::string topic1 = "one/two/three";
+ std::vector subtopics;
+ splitTopic(topic1, subtopics);
+ store->addSubscription(c1, topic1, subtopics, 0);
+
+ const std::string topic2 = "four/five/six";
+ splitTopic(topic2, subtopics);
+ store->addSubscription(c2, topic2, subtopics, 0);
+ store->addSubscription(c1, topic2, subtopics, 0);
+
+ const std::string topic3 = "";
+ splitTopic(topic3, subtopics);
+ store->addSubscription(c2, topic3, subtopics, 0);
+
+ const std::string topic4 = "#";
+ splitTopic(topic4, subtopics);
+ store->addSubscription(c2, topic4, subtopics, 0);
+
+ uint64_t count = 0;
+
+ Publish publish("a/b/c", "Hello Barry", 1);
+
+ std::shared_ptr c1ses = c1->getSession();
+ c1.reset();
+ c1ses->writePacket(publish, 1, false, count);
+
+ store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db");
+
+ std::shared_ptr store2(new SubscriptionStore());
+ store2->loadSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db");
+
+ MYCASTCOMPARE(store->sessionsById.size(), 2);
+ MYCASTCOMPARE(store2->sessionsById.size(), 2);
+
+ for (auto &pair : store->sessionsById)
+ {
+ std::shared_ptr &ses = pair.second;
+ std::shared_ptr &ses2 = store2->sessionsById[pair.first];
+ QCOMPARE(pair.first, ses2->getClientId());
+
+ QCOMPARE(ses->username, ses2->username);
+ QCOMPARE(ses->client_id, ses2->client_id);
+ QCOMPARE(ses->incomingQoS2MessageIds, ses2->incomingQoS2MessageIds);
+ QCOMPARE(ses->outgoingQoS2MessageIds, ses2->outgoingQoS2MessageIds);
+ QCOMPARE(ses->nextPacketId, ses2->nextPacketId);
+ }
+
+
+ std::unordered_map> store1Subscriptions;
+ store->getSubscriptions(&store->root, "", true, store1Subscriptions);
+
+ std::unordered_map> store2Subscriptions;
+ store->getSubscriptions(&store->root, "", true, store2Subscriptions);
+
+ MYCASTCOMPARE(store1Subscriptions.size(), 4);
+ MYCASTCOMPARE(store2Subscriptions.size(), 4);
+
+ for(auto &pair : store1Subscriptions)
+ {
+ std::list &subscList1 = pair.second;
+ std::list &subscList2 = store2Subscriptions[pair.first];
+
+ QCOMPARE(subscList1.size(), subscList2.size());
+
+ auto subs1It = subscList1.begin();
+ auto subs2It = subscList2.begin();
+
+ while (subs1It != subscList1.end())
+ {
+ SubscriptionForSerializing &one = *subs1It;
+ SubscriptionForSerializing &two = *subs2It;
+ QCOMPARE(one.clientId, two.clientId);
+ QCOMPARE(one.qos, two.qos);
+
+ subs1It++;
+ subs2It++;
+ }
+
+ }
+
+
+ }
+ catch (std::exception &ex)
+ {
+ QVERIFY2(false, ex.what());
+ }
+}
+
QTEST_GUILESS_MAIN(MainTests)
#include "tst_maintests.moc"
diff --git a/configfileparser.cpp b/configfileparser.cpp
index b99810e..9211f17 100644
--- a/configfileparser.cpp
+++ b/configfileparser.cpp
@@ -107,6 +107,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) :
validKeys.insert("allow_anonymous");
validKeys.insert("rlimit_nofile");
validKeys.insert("expire_sessions_after_seconds");
+ validKeys.insert("storage_dir");
validListenKeys.insert("port");
validListenKeys.insert("protocol");
@@ -412,6 +413,14 @@ void ConfigFileParser::loadFile(bool test)
}
tmpSettings->authPluginTimerPeriod = newVal;
}
+
+ if (key == "storage_dir")
+ {
+ std::string newPath = value;
+ rtrim(newPath, '/');
+ checkWritableDir(newPath);
+ tmpSettings->storageDir = newPath;
+ }
}
}
catch (std::invalid_argument &ex) // catch for the stoi()
diff --git a/mainapp.cpp b/mainapp.cpp
index de05ac7..de4f00e 100644
--- a/mainapp.cpp
+++ b/mainapp.cpp
@@ -202,6 +202,15 @@ MainApp::MainApp(const std::string &configFilePath) :
auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this);
timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event.");
}
+
+ if (!settings->storageDir.empty())
+ {
+ subscriptionStore->loadRetainedMessages(settings->getRetainedMessagesDBFile());
+ subscriptionStore->loadSessionsAndSubscriptions(settings->getSessionsDBFile());
+ }
+
+ auto fSaveState = std::bind(&MainApp::saveState, this);
+ timer.addCallback(fSaveState, 900000, "Save state.");
}
MainApp::~MainApp()
@@ -369,6 +378,25 @@ void MainApp::publishStat(const std::string &topic, uint64_t n)
subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0);
}
+void MainApp::saveState()
+{
+ try
+ {
+ if (!settings->storageDir.empty())
+ {
+ const std::string retainedDBPath = settings->getRetainedMessagesDBFile();
+ subscriptionStore->saveRetainedMessages(retainedDBPath);
+
+ const std::string sessionsDBPath = settings->getSessionsDBFile();
+ subscriptionStore->saveSessionsAndSubscriptions(sessionsDBPath);
+ }
+ }
+ catch(std::exception &ex)
+ {
+ logger->logf(LOG_ERR, "Error saving state: %s", ex.what());
+ }
+}
+
void MainApp::initMainApp(int argc, char *argv[])
{
if (instance != nullptr)
@@ -659,6 +687,8 @@ void MainApp::start()
{
thread->waitForQuit();
}
+
+ saveState();
}
void MainApp::quit()
diff --git a/mainapp.h b/mainapp.h
index 2a00475..7a31509 100644
--- a/mainapp.h
+++ b/mainapp.h
@@ -84,6 +84,7 @@ class MainApp
void setFuzzFile(const std::string &fuzzFilePath);
void publishStatsOnDollarTopic();
void publishStat(const std::string &topic, uint64_t n);
+ void saveState();
MainApp(const std::string &configFilePath);
public:
diff --git a/mqttpacket.cpp b/mqttpacket.cpp
index 19817d7..e69c310 100644
--- a/mqttpacket.cpp
+++ b/mqttpacket.cpp
@@ -47,8 +47,11 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt
pos += fixed_header_length;
}
-// This is easier than using the copy constructor publically, because then I have to keep maintaining a functioning copy constructor.
-// Returning shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource.
+/**
+ * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor
+ * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add.
+ * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource.
+ */
std::shared_ptr MqttPacket::getCopy() const
{
std::shared_ptr copyPacket(new MqttPacket(*this));
@@ -129,6 +132,9 @@ MqttPacket::MqttPacket(const Publish &publish) :
writeBytes(zero, 2);
}
+ payloadStart = pos;
+ payloadLen = publish.payload.length();
+
writeBytes(publish.payload.c_str(), publish.payload.length());
calculateRemainingLength();
}
@@ -546,13 +552,14 @@ void MqttPacket::handlePublish()
}
}
+ payloadLen = remainingAfterPos();
+ payloadStart = pos;
+
if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success)
{
if (retain)
{
- size_t payload_length = remainingAfterPos();
- std::string payload(readBytes(payload_length), payload_length);
-
+ std::string payload(readBytes(payloadLen), payloadLen);
sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, *subtopics, payload, qos);
}
@@ -679,6 +686,23 @@ size_t MqttPacket::getTotalMemoryFootprint()
return bites.size() + sizeof(MqttPacket);
}
+/**
+ * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string.
+ * @return
+ *
+ * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out
+ * the whole byte array that is a packet to subscribers. No need to copy and such.
+ *
+ * I created it for saving QoS packages in the db file.
+ */
+std::string MqttPacket::getPayloadCopy()
+{
+ assert(payloadStart > 0);
+ assert(pos <= bites.size());
+ std::string payload(&bites[payloadStart], payloadLen);
+ return payload;
+}
+
size_t MqttPacket::getSizeIncludingNonPresentHeader() const
{
size_t total = bites.size();
diff --git a/mqttpacket.h b/mqttpacket.h
index 6c8ddb8..9b0d7a5 100644
--- a/mqttpacket.h
+++ b/mqttpacket.h
@@ -59,6 +59,8 @@ class MqttPacket
size_t packet_id_pos = 0;
uint16_t packet_id = 0;
ProtocolVersion protocolVersion = ProtocolVersion::None;
+ size_t payloadStart = 0;
+ size_t payloadLen = 0;
Logger *logger = Logger::getInstance();
char *readBytes(size_t length);
@@ -116,6 +118,7 @@ public:
uint16_t getPacketId() const;
void setDuplicate();
size_t getTotalMemoryFootprint();
+ std::string getPayloadCopy();
};
#endif // MQTTPACKET_H
diff --git a/persistencefile.cpp b/persistencefile.cpp
new file mode 100644
index 0000000..03ed4f7
--- /dev/null
+++ b/persistencefile.cpp
@@ -0,0 +1,320 @@
+/*
+This file is part of FlashMQ (https://www.flashmq.org)
+Copyright (C) 2021 Wiebe Cazemier
+
+FlashMQ is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, version 3.
+
+FlashMQ is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public
+License along with FlashMQ. If not, see .
+*/
+
+#include "persistencefile.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "utils.h"
+#include "logger.h"
+
+PersistenceFile::PersistenceFile(const std::string &filePath) :
+ digestContext(EVP_MD_CTX_new()),
+ buf(1024*1024)
+{
+ if (!filePath.empty() && filePath[filePath.size() - 1] == '/')
+ throw std::runtime_error("Target file can't contain trailing slash.");
+
+ this->filePath = filePath;
+ this->filePathTemp = formatString("%s.newfile.%s", filePath.c_str(), getSecureRandomString(8).c_str());
+ this->filePathCorrupt = formatString("%s.corrupt.%s", filePath.c_str(), getSecureRandomString(8).c_str());
+}
+
+PersistenceFile::~PersistenceFile()
+{
+ closeFile();
+
+ if (digestContext)
+ {
+ EVP_MD_CTX_free(digestContext);
+ }
+}
+
+/**
+ * @brief RetainedMessagesDB::hashFile hashes the data after the headers and writes the hash in the header. Uses SHA512.
+ *
+ */
+void PersistenceFile::writeCheck(const void *ptr, size_t size, size_t n, FILE *s)
+{
+ if (fwrite(ptr, size, n, s) != n)
+ {
+ throw std::runtime_error(formatString("Error writing: %s", strerror(errno)));
+ }
+}
+
+ssize_t PersistenceFile::readCheck(void *ptr, size_t size, size_t n, FILE *stream)
+{
+ size_t nread = fread(ptr, size, n, stream);
+
+ if (nread != n)
+ {
+ if (feof(f))
+ return -1;
+
+ throw std::runtime_error(formatString("Error reading: %s", strerror(errno)));
+ }
+
+ return nread;
+}
+
+void PersistenceFile::hashFile()
+{
+ logger->logf(LOG_DEBUG, "Calculating and saving hash of '%s'.", filePath.c_str());
+
+ fseek(f, TOTAL_HEADER_SIZE, SEEK_SET);
+
+ unsigned int output_len = 0;
+ unsigned char md_value[EVP_MAX_MD_SIZE];
+ std::memset(md_value, 0, EVP_MAX_MD_SIZE);
+
+ EVP_MD_CTX_reset(digestContext);
+ EVP_DigestInit_ex(digestContext, sha512, NULL);
+
+ while (!feof(f))
+ {
+ size_t n = fread(buf.data(), 1, buf.size(), f);
+ EVP_DigestUpdate(digestContext, buf.data(), n);
+ }
+
+ EVP_DigestFinal_ex(digestContext, md_value, &output_len);
+
+ if (output_len != HASH_SIZE)
+ throw std::runtime_error("Impossible: calculated hash size wrong length");
+
+ fseek(f, MAGIC_STRING_LENGH, SEEK_SET);
+
+ writeCheck(md_value, output_len, 1, f);
+ fflush(f);
+}
+
+void PersistenceFile::verifyHash()
+{
+ fseek(f, 0, SEEK_END);
+ const size_t size = ftell(f);
+
+ if (size < TOTAL_HEADER_SIZE)
+ throw std::runtime_error(formatString("File '%s' is too small for it even to contain a header.", filePath.c_str()));
+
+ unsigned char md_from_disk[HASH_SIZE];
+ std::memset(md_from_disk, 0, HASH_SIZE);
+
+ fseek(f, MAGIC_STRING_LENGH, SEEK_SET);
+ fread(md_from_disk, 1, HASH_SIZE, f);
+
+ unsigned int output_len = 0;
+ unsigned char md_value[EVP_MAX_MD_SIZE];
+ std::memset(md_value, 0, EVP_MAX_MD_SIZE);
+
+ EVP_MD_CTX_reset(digestContext);
+ EVP_DigestInit_ex(digestContext, sha512, NULL);
+
+ while (!feof(f))
+ {
+ size_t n = fread(buf.data(), 1, buf.size(), f);
+ EVP_DigestUpdate(digestContext, buf.data(), n);
+ }
+
+ EVP_DigestFinal_ex(digestContext, md_value, &output_len);
+
+ if (output_len != HASH_SIZE)
+ throw std::runtime_error("Impossible: calculated hash size wrong length");
+
+ if (std::memcmp(md_from_disk, md_value, output_len) != 0)
+ {
+ fclose(f);
+ f = nullptr;
+
+ if (rename(filePath.c_str(), filePathCorrupt.c_str()) == 0)
+ {
+ throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Moved aside to '%s'.", filePath.c_str(), filePathCorrupt.c_str()));
+ }
+ else
+ {
+ throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Tried to move aside, but that failed: '%s'.",
+ filePath.c_str(), strerror(errno)));
+ }
+ }
+
+ logger->logf(LOG_DEBUG, "Hash of '%s' correct", filePath.c_str());
+}
+
+/**
+ * @brief PersistenceFile::makeSureBufSize grows the buffer if n is bigger.
+ * @param n in bytes.
+ *
+ * Remember that when you're dealing with fields that are sized in MQTT by 16 bit ints, like topic paths, the buffer will always be big enough, because it's 1 MB.
+ */
+void PersistenceFile::makeSureBufSize(size_t n)
+{
+ if (n > buf.size())
+ buf.resize(n);
+}
+
+void PersistenceFile::writeInt64(const int64_t val)
+{
+ unsigned char buf[8];
+
+ // Write big-endian
+ int shift = 56;
+ int i = 0;
+ while (shift >= 0)
+ {
+ unsigned char wantedByte = val >> shift;
+ buf[i++] = wantedByte;
+ shift -= 8;
+ }
+ writeCheck(buf, 1, 8, f);
+}
+
+void PersistenceFile::writeUint32(const uint32_t val)
+{
+ unsigned char buf[4];
+
+ // Write big-endian
+ int shift = 24;
+ int i = 0;
+ while (shift >= 0)
+ {
+ unsigned char wantedByte = val >> shift;
+ buf[i++] = wantedByte;
+ shift -= 8;
+ }
+ writeCheck(buf, 1, 4, f);
+}
+
+void PersistenceFile::writeUint16(const uint16_t val)
+{
+ unsigned char buf[2];
+
+ // Write big-endian
+ int shift = 8;
+ int i = 0;
+ while (shift >= 0)
+ {
+ unsigned char wantedByte = val >> shift;
+ buf[i++] = wantedByte;
+ shift -= 8;
+ }
+ writeCheck(buf, 1, 2, f);
+}
+
+int64_t PersistenceFile::readInt64(bool &eofFound)
+{
+ if (readCheck(buf.data(), 1, 8, f) < 0)
+ eofFound = true;
+
+ unsigned char *buf_ = reinterpret_cast(buf.data());
+ const uint64_t val1 = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]);
+ const uint64_t val2 = ((buf_[4]) << 24) | ((buf_[5]) << 16) | ((buf_[6]) << 8) | (buf_[7]);
+ const int64_t val = (val1 << 32) | val2;
+ return val;
+}
+
+uint32_t PersistenceFile::readUint32(bool &eofFound)
+{
+ if (readCheck(buf.data(), 1, 4, f) < 0)
+ eofFound = true;
+
+ uint32_t val;
+ unsigned char *buf_ = reinterpret_cast(buf.data());
+ val = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]);
+ return val;
+}
+
+uint16_t PersistenceFile::readUint16(bool &eofFound)
+{
+ if (readCheck(buf.data(), 1, 2, f) < 0)
+ eofFound = true;
+
+ uint16_t val;
+ unsigned char *buf_ = reinterpret_cast(buf.data());
+ val = ((buf_[0]) << 8) | (buf_[1]);
+ return val;
+}
+
+/**
+ * @brief RetainedMessagesDB::openWrite doesn't explicitely name a file version (v1, etc), because we always write the current definition.
+ */
+void PersistenceFile::openWrite(const std::string &versionString)
+{
+ if (openMode != FileMode::unknown)
+ throw std::runtime_error("File is already open.");
+
+ f = fopen(filePathTemp.c_str(), "w+b");
+
+ if (f == nullptr)
+ {
+ throw std::runtime_error(formatString("Can't open '%s': %s", filePathTemp.c_str(), strerror(errno)));
+ }
+
+ openMode = FileMode::write;
+
+ writeCheck(buf.data(), 1, MAGIC_STRING_LENGH, f);
+ rewind(f);
+ writeCheck(versionString.c_str(), 1, versionString.length(), f);
+ fseek(f, MAGIC_STRING_LENGH, SEEK_SET);
+ writeCheck(buf.data(), 1, HASH_SIZE, f);
+}
+
+void PersistenceFile::openRead()
+{
+ if (openMode != FileMode::unknown)
+ throw std::runtime_error("File is already open.");
+
+ f = fopen(filePath.c_str(), "rb");
+
+ if (f == nullptr)
+ throw PersistenceFileCantBeOpened(formatString("Can't open '%s': %s.", filePath.c_str(), strerror(errno)).c_str());
+
+ openMode = FileMode::read;
+
+ verifyHash();
+ rewind(f);
+
+ fread(buf.data(), 1, MAGIC_STRING_LENGH, f);
+ detectedVersionString = std::string(buf.data(), strlen(buf.data()));
+
+ fseek(f, TOTAL_HEADER_SIZE, SEEK_SET);
+}
+
+void PersistenceFile::closeFile()
+{
+ if (!f)
+ return;
+
+ if (openMode == FileMode::write)
+ hashFile();
+
+ if (f != nullptr)
+ {
+ fclose(f);
+ f = nullptr;
+ }
+
+ if (openMode == FileMode::write && !filePathTemp.empty() && ! filePath.empty())
+ {
+ if (rename(filePathTemp.c_str(), filePath.c_str()) < 0)
+ throw std::runtime_error(formatString("Saving '%s' failed: rename of temp file to target failed with: %s", filePath.c_str(), strerror(errno)));
+ }
+}
diff --git a/persistencefile.h b/persistencefile.h
new file mode 100644
index 0000000..6000d01
--- /dev/null
+++ b/persistencefile.h
@@ -0,0 +1,92 @@
+/*
+This file is part of FlashMQ (https://www.flashmq.org)
+Copyright (C) 2021 Wiebe Cazemier
+
+FlashMQ is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, version 3.
+
+FlashMQ is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public
+License along with FlashMQ. If not, see .
+*/
+
+#ifndef PERSISTENCEFILE_H
+#define PERSISTENCEFILE_H
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "logger.h"
+
+#define MAGIC_STRING_LENGH 32
+#define HASH_SIZE 64
+#define TOTAL_HEADER_SIZE (MAGIC_STRING_LENGH + HASH_SIZE)
+
+/**
+ * @brief The PersistenceFileCantBeOpened class should be thrown when a non-fatal file-not-found error happens.
+ */
+class PersistenceFileCantBeOpened : public std::runtime_error
+{
+public:
+ PersistenceFileCantBeOpened(const std::string &msg) : std::runtime_error(msg) {}
+};
+
+class PersistenceFile
+{
+ std::string filePath;
+ std::string filePathTemp;
+ std::string filePathCorrupt;
+
+ EVP_MD_CTX *digestContext = nullptr;
+ const EVP_MD *sha512 = EVP_sha512();
+
+ void hashFile();
+ void verifyHash();
+
+protected:
+ enum class FileMode
+ {
+ unknown,
+ read,
+ write
+ };
+
+ FILE *f = nullptr;
+ std::vector buf;
+ FileMode openMode = FileMode::unknown;
+ std::string detectedVersionString;
+
+ Logger *logger = Logger::getInstance();
+
+ void makeSureBufSize(size_t n);
+
+ void writeCheck(const void *__restrict __ptr, size_t __size, size_t __n, FILE *__restrict __s);
+ ssize_t readCheck(void *__restrict ptr, size_t size, size_t n, FILE *__restrict stream);
+
+ void writeInt64(const int64_t val);
+ void writeUint32(const uint32_t val);
+ void writeUint16(const uint16_t val);
+ int64_t readInt64(bool &eofFound);
+ uint32_t readUint32(bool &eofFound);
+ uint16_t readUint16(bool &eofFound);
+
+public:
+ PersistenceFile(const std::string &filePath);
+ ~PersistenceFile();
+
+ void openWrite(const std::string &versionString);
+ void openRead();
+ void closeFile();
+};
+
+#endif // PERSISTENCEFILE_H
diff --git a/qospacketqueue.cpp b/qospacketqueue.cpp
new file mode 100644
index 0000000..696bedd
--- /dev/null
+++ b/qospacketqueue.cpp
@@ -0,0 +1,77 @@
+#include "qospacketqueue.h"
+
+#include "cassert"
+
+#include "mqttpacket.h"
+
+void QoSPacketQueue::erase(const uint16_t packet_id)
+{
+ auto it = queue.begin();
+ auto end = queue.end();
+ while (it != end)
+ {
+ std::shared_ptr &p = *it;
+ if (p->getPacketId() == packet_id)
+ {
+ size_t mem = p->getTotalMemoryFootprint();
+ qosQueueBytes -= mem;
+ assert(qosQueueBytes >= 0);
+ if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose.
+ qosQueueBytes = 0;
+
+ queue.erase(it);
+
+ break;
+ }
+
+ it++;
+ }
+}
+
+size_t QoSPacketQueue::size() const
+{
+ return queue.size();
+}
+
+size_t QoSPacketQueue::getByteSize() const
+{
+ return qosQueueBytes;
+}
+
+/**
+ * @brief QoSPacketQueue::queuePacket makes a copy of the packet because it has state for the receiver in question.
+ * @param p
+ * @param id
+ * @return the packet copy.
+ */
+std::shared_ptr QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id)
+{
+ assert(p.getQos() > 0);
+
+ std::shared_ptr copyPacket = p.getCopy();
+ copyPacket->setPacketId(id);
+ queue.push_back(copyPacket);
+ qosQueueBytes += copyPacket->getTotalMemoryFootprint();
+ return copyPacket;
+}
+
+std::shared_ptr QoSPacketQueue::queuePacket(const Publish &pub, uint16_t id)
+{
+ assert(pub.qos > 0);
+
+ std::shared_ptr copyPacket(new MqttPacket(pub));
+ copyPacket->setPacketId(id);
+ queue.push_back(copyPacket);
+ qosQueueBytes += copyPacket->getTotalMemoryFootprint();
+ return copyPacket;
+}
+
+std::list>::const_iterator QoSPacketQueue::begin() const
+{
+ return queue.cbegin();
+}
+
+std::list>::const_iterator QoSPacketQueue::end() const
+{
+ return queue.cend();
+}
diff --git a/qospacketqueue.h b/qospacketqueue.h
new file mode 100644
index 0000000..f8b18d1
--- /dev/null
+++ b/qospacketqueue.h
@@ -0,0 +1,25 @@
+#ifndef QOSPACKETQUEUE_H
+#define QOSPACKETQUEUE_H
+
+#include "list"
+
+#include "forward_declarations.h"
+#include "types.h"
+
+class QoSPacketQueue
+{
+ std::list> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6]
+ ssize_t qosQueueBytes = 0;
+
+public:
+ void erase(const uint16_t packet_id);
+ size_t size() const;
+ size_t getByteSize() const;
+ std::shared_ptr queuePacket(const MqttPacket &p, uint16_t id);
+ std::shared_ptr queuePacket(const Publish &pub, uint16_t id);
+
+ std::list>::const_iterator begin() const;
+ std::list>::const_iterator end() const;
+};
+
+#endif // QOSPACKETQUEUE_H
diff --git a/retainedmessage.cpp b/retainedmessage.cpp
index 29647d1..c01f9e2 100644
--- a/retainedmessage.cpp
+++ b/retainedmessage.cpp
@@ -34,3 +34,8 @@ bool RetainedMessage::empty() const
{
return payload.empty();
}
+
+uint32_t RetainedMessage::getSize() const
+{
+ return topic.length() + payload.length() + 1;
+}
diff --git a/retainedmessage.h b/retainedmessage.h
index 01491b0..97b5369 100644
--- a/retainedmessage.h
+++ b/retainedmessage.h
@@ -30,6 +30,7 @@ struct RetainedMessage
bool operator==(const RetainedMessage &rhs) const;
bool empty() const;
+ uint32_t getSize() const;
};
namespace std {
diff --git a/retainedmessagesdb.cpp b/retainedmessagesdb.cpp
new file mode 100644
index 0000000..70ae9b0
--- /dev/null
+++ b/retainedmessagesdb.cpp
@@ -0,0 +1,146 @@
+/*
+This file is part of FlashMQ (https://www.flashmq.org)
+Copyright (C) 2021 Wiebe Cazemier
+
+FlashMQ is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, version 3.
+
+FlashMQ is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public
+License along with FlashMQ. If not, see .
+*/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "retainedmessagesdb.h"
+#include "utils.h"
+#include "logger.h"
+
+RetainedMessagesDB::RetainedMessagesDB(const std::string &filePath) : PersistenceFile(filePath)
+{
+
+}
+
+void RetainedMessagesDB::openWrite()
+{
+ PersistenceFile::openWrite(MAGIC_STRING_V1);
+}
+
+void RetainedMessagesDB::openRead()
+{
+ PersistenceFile::openRead();
+
+ if (detectedVersionString == MAGIC_STRING_V1)
+ readVersion = ReadVersion::v1;
+ else
+ throw std::runtime_error("Unknown file version.");
+}
+
+/**
+ * @brief RetainedMessagesDB::writeRowHeader writes two 32 bit integers: topic size and payload size.
+ * @param rm
+ *
+ * So, the header per message is 8 bytes long.
+ *
+ * It writes no information about the length of the QoS value, because that is always one.
+ */
+void RetainedMessagesDB::writeRowHeader(const RetainedMessage &rm)
+{
+ writeUint32(rm.topic.size());
+ writeUint32(rm.payload.size());
+}
+
+RetainedMessagesDB::RowHeader RetainedMessagesDB::readRowHeaderV1(bool &eofFound)
+{
+ RetainedMessagesDB::RowHeader result;
+
+ result.topicLen = readUint32(eofFound);
+ result.payloadLen = readUint32(eofFound);
+
+ return result;
+}
+
+/**
+ * @brief RetainedMessagesDB::saveData doesn't explicitely name a file version (v1, etc), because we always write the current definition.
+ * @param messages
+ */
+void RetainedMessagesDB::saveData(const std::vector &messages)
+{
+ if (!f)
+ return;
+
+ char reserved[RESERVED_SPACE_RETAINED_DB_V1];
+ std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V1);
+
+ char qos = 0;
+ for (const RetainedMessage &rm : messages)
+ {
+ logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d.", rm.topic.c_str(), rm.qos);
+
+ writeRowHeader(rm);
+ qos = rm.qos;
+ writeCheck(&qos, 1, 1, f);
+ writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V1, f);
+ writeCheck(rm.topic.c_str(), 1, rm.topic.length(), f);
+ writeCheck(rm.payload.c_str(), 1, rm.payload.length(), f);
+ }
+
+ fflush(f);
+}
+
+std::list RetainedMessagesDB::readData()
+{
+ std::list defaultResult;
+
+ if (!f)
+ return defaultResult;
+
+ if (readVersion == ReadVersion::v1)
+ return readDataV1();
+
+ return defaultResult;
+}
+
+std::list RetainedMessagesDB::readDataV1()
+{
+ std::list messages;
+
+ while (!feof(f))
+ {
+ bool eofFound = false;
+ RetainedMessagesDB::RowHeader header = readRowHeaderV1(eofFound);
+
+ if (eofFound)
+ continue;
+
+ makeSureBufSize(header.payloadLen);
+
+ readCheck(buf.data(), 1, 1, f);
+ char qos = buf[0];
+ fseek(f, RESERVED_SPACE_RETAINED_DB_V1, SEEK_CUR);
+
+ readCheck(buf.data(), 1, header.topicLen, f);
+ std::string topic(buf.data(), header.topicLen);
+
+ readCheck(buf.data(), 1, header.payloadLen, f);
+ std::string payload(buf.data(), header.payloadLen);
+
+ RetainedMessage msg(topic, payload, qos);
+ logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.topic.c_str(), msg.qos);
+ messages.push_back(std::move(msg));
+ }
+
+ return messages;
+}
diff --git a/retainedmessagesdb.h b/retainedmessagesdb.h
new file mode 100644
index 0000000..7b2299d
--- /dev/null
+++ b/retainedmessagesdb.h
@@ -0,0 +1,71 @@
+/*
+This file is part of FlashMQ (https://www.flashmq.org)
+Copyright (C) 2021 Wiebe Cazemier
+
+FlashMQ is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, version 3.
+
+FlashMQ is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public
+License along with FlashMQ. If not, see .
+*/
+
+#ifndef RETAINEDMESSAGESDB_H
+#define RETAINEDMESSAGESDB_H
+
+#include "persistencefile.h"
+#include "retainedmessage.h"
+
+#include "logger.h"
+
+#define MAGIC_STRING_V1 "FlashMQRetainedDBv1"
+#define ROW_HEADER_SIZE 8
+#define RESERVED_SPACE_RETAINED_DB_V1 31
+
+/**
+ * @brief The RetainedMessagesDB class saves and loads the retained messages.
+ *
+ * The DB looks like, from the top:
+ *
+ * MAGIC_STRING_LENGH bytes file header
+ * HASH_SIZE SHA512
+ * [MESSAGES]
+ *
+ * Each message has a row header, which is 8 bytes. See writeRowHeader().
+ *
+ */
+class RetainedMessagesDB : public PersistenceFile
+{
+ enum class ReadVersion
+ {
+ unknown,
+ v1
+ };
+
+ struct RowHeader
+ {
+ uint32_t topicLen = 0;
+ uint32_t payloadLen = 0;
+ };
+
+ ReadVersion readVersion = ReadVersion::unknown;
+
+ void writeRowHeader(const RetainedMessage &rm);
+ RowHeader readRowHeaderV1(bool &eofFound);
+ std::list readDataV1();
+public:
+ RetainedMessagesDB(const std::string &filePath);
+
+ void openWrite();
+ void openRead();
+
+ void saveData(const std::vector &messages);
+ std::list readData();
+};
+
+#endif // RETAINEDMESSAGESDB_H
diff --git a/rwlockguard.cpp b/rwlockguard.cpp
index 15bae1c..c25c324 100644
--- a/rwlockguard.cpp
+++ b/rwlockguard.cpp
@@ -32,7 +32,15 @@ RWLockGuard::~RWLockGuard()
void RWLockGuard::wrlock()
{
- if (pthread_rwlock_wrlock(rwlock) != 0)
+ const int rc = pthread_rwlock_wrlock(rwlock);
+
+ if (rc == EDEADLK)
+ {
+ rwlock = nullptr;
+ return;
+ }
+
+ if (rc != 0)
throw std::runtime_error("wrlock failed.");
}
diff --git a/session.cpp b/session.cpp
index 3b328f8..e0fc9fe 100644
--- a/session.cpp
+++ b/session.cpp
@@ -20,16 +20,75 @@ License along with FlashMQ. If not, see .
#include "session.h"
#include "client.h"
+std::chrono::time_point appStartTime = std::chrono::steady_clock::now();
+
Session::Session()
{
}
+int64_t Session::getProgramStartedAtUnixTimestamp()
+{
+ auto secondsSinceEpoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count();
+ const std::chrono::seconds age = std::chrono::duration_cast(std::chrono::steady_clock::now() - appStartTime);
+ int64_t result = secondsSinceEpoch - age.count();
+ return result;
+}
+
+void Session::setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp)
+{
+ auto secondsSinceEpoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch());
+ const std::chrono::seconds _unix_timestamp = std::chrono::seconds(unix_timestamp);
+ const std::chrono::seconds age_in_s = secondsSinceEpoch - _unix_timestamp;
+ appStartTime = std::chrono::steady_clock::now() - age_in_s;
+}
+
+
+int64_t Session::getSessionRelativeAgeInMs() const
+{
+ const std::chrono::milliseconds sessionAge = std::chrono::duration_cast(lastTouched - appStartTime);
+ const int64_t sInMs = sessionAge.count();
+ return sInMs;
+}
+
+void Session::setSessionTouch(int64_t ageInMs)
+{
+ std::chrono::milliseconds ms(ageInMs);
+ std::chrono::time_point point = appStartTime + ms;
+ lastTouched = point;
+}
+
+/**
+ * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies.
+ * @param other
+ *
+ * Because it was created for session storing, the fields we're copying are the fields being stored.
+ */
+Session::Session(const Session &other)
+{
+ this->username = other.username;
+ this->client_id = other.client_id;
+ this->incomingQoS2MessageIds = other.incomingQoS2MessageIds;
+ this->outgoingQoS2MessageIds = other.outgoingQoS2MessageIds;
+ this->nextPacketId = other.nextPacketId;
+ this->lastTouched = other.lastTouched;
+
+ // To be fully correct, we should copy the individual packets, but copying sessions is only done for saving them, and I know
+ // that no member of MqttPacket changes in the QoS process, so we can just keep the shared pointer to the original.
+ this->qosPacketQueue = other.qosPacketQueue;
+}
+
Session::~Session()
{
logger->logf(LOG_DEBUG, "Session %s is being destroyed.", getClientId().c_str());
}
+std::unique_ptr Session::getCopy() const
+{
+ std::unique_ptr s(new Session(*this));
+ return s;
+}
+
bool Session::clientDisconnected() const
{
return client.expired();
@@ -45,7 +104,6 @@ void Session::assignActiveConnection(std::shared_ptr &client)
this->client = client;
this->client_id = client->getClientId();
this->username = client->getUsername();
- this->thread = client->getThreadData();
}
/**
@@ -60,7 +118,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u
assert(max_qos <= 2);
const char qos = std::min(packet.getQos(), max_qos);
- if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success)
+ assert(packet.getSender());
+ Authentication &auth = packet.getSender()->getThreadData()->authentication;
+ if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success)
{
if (qos == 0)
{
@@ -73,11 +133,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u
}
else if (qos > 0)
{
- std::shared_ptr copyPacket = packet.getCopy();
std::unique_lock locker(qosQueueMutex);
const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();
- if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
+ if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
{
logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str());
return;
@@ -86,13 +145,7 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u
if (nextPacketId == 0)
nextPacketId++;
- const uint16_t pid = nextPacketId;
- copyPacket->setPacketId(pid);
- QueuedQosPacket p;
- p.packet = copyPacket;
- p.id = pid;
- qosPacketQueue.push_back(p);
- qosQueueBytes += copyPacket->getTotalMemoryFootprint();
+ std::shared_ptr copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId);
locker.unlock();
if (!clientDisconnected())
@@ -115,27 +168,7 @@ void Session::clearQosMessage(uint16_t packet_id)
#endif
std::lock_guard locker(qosQueueMutex);
-
- auto it = qosPacketQueue.begin();
- auto end = qosPacketQueue.end();
- while (it != end)
- {
- QueuedQosPacket &p = *it;
- if (p.id == packet_id)
- {
- size_t mem = p.packet->getTotalMemoryFootprint();
- qosQueueBytes -= mem;
- assert(qosQueueBytes >= 0);
- if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose.
- qosQueueBytes = 0;
-
- qosPacketQueue.erase(it);
-
- break;
- }
-
- it++;
- }
+ qosPacketQueue.erase(packet_id);
}
// [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any
@@ -152,10 +185,10 @@ uint64_t Session::sendPendingQosMessages()
{
std::shared_ptr c = makeSharedClient();
std::lock_guard locker(qosQueueMutex);
- for (QueuedQosPacket &qosMessage : qosPacketQueue)
+ for (const std::shared_ptr &qosMessage : qosPacketQueue)
{
- c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos());
- qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
+ c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos());
+ qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
count++;
}
diff --git a/session.h b/session.h
index c0faad3..ec2770e 100644
--- a/session.h
+++ b/session.h
@@ -25,38 +25,46 @@ License along with FlashMQ. If not, see .
#include "forward_declarations.h"
#include "logger.h"
+#include "sessionsandsubscriptionsdb.h"
+#include "qospacketqueue.h"
// TODO make settings. But, num of packets can't exceed 65536, because the counter is 16 bit.
#define MAX_QOS_MSG_PENDING_PER_CLIENT 32
#define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096
-struct QueuedQosPacket
-{
- uint16_t id;
- std::shared_ptr packet;
-};
-
class Session
{
+#ifdef TESTING
+ friend class MainTests;
+#endif
+
+ friend class SessionsAndSubscriptionsDB;
+
std::weak_ptr client;
- std::shared_ptr thread;
std::string client_id;
std::string username;
- std::list qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6]
+ QoSPacketQueue qosPacketQueue;
std::set incomingQoS2MessageIds;
std::set outgoingQoS2MessageIds;
std::mutex qosQueueMutex;
uint16_t nextPacketId = 0;
- ssize_t qosQueueBytes = 0;
- std::chrono::time_point lastTouched;
+ std::chrono::time_point lastTouched = std::chrono::steady_clock::now();
Logger *logger = Logger::getInstance();
+ int64_t getSessionRelativeAgeInMs() const;
+ void setSessionTouch(int64_t ageInMs);
+ Session(const Session &other);
public:
Session();
- Session(const Session &other) = delete;
+
Session(Session &&other) = delete;
~Session();
+ static int64_t getProgramStartedAtUnixTimestamp();
+ static void setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp);
+
+ std::unique_ptr getCopy() const;
+
const std::string &getClientId() const { return client_id; }
bool clientDisconnected() const;
std::shared_ptr makeSharedClient() const;
@@ -74,7 +82,6 @@ public:
void addOutgoingQoS2MessageId(uint16_t packet_id);
void removeOutgoingQoS2MessageId(u_int16_t packet_id);
-
};
#endif // SESSION_H
diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp
new file mode 100644
index 0000000..b9b2bf5
--- /dev/null
+++ b/sessionsandsubscriptionsdb.cpp
@@ -0,0 +1,299 @@
+/*
+This file is part of FlashMQ (https://www.flashmq.org)
+Copyright (C) 2021 Wiebe Cazemier
+
+FlashMQ is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, version 3.
+
+FlashMQ is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public
+License along with FlashMQ. If not, see .
+*/
+
+#include "sessionsandsubscriptionsdb.h"
+#include "mqttpacket.h"
+
+#include "cassert"
+
+SubscriptionForSerializing::SubscriptionForSerializing(const std::string &clientId, char qos) :
+ clientId(clientId),
+ qos(qos)
+{
+
+}
+
+SubscriptionForSerializing::SubscriptionForSerializing(const std::string &&clientId, char qos) :
+ clientId(clientId),
+ qos(qos)
+{
+
+}
+
+SessionsAndSubscriptionsDB::SessionsAndSubscriptionsDB(const std::string &filePath) : PersistenceFile(filePath)
+{
+
+}
+
+void SessionsAndSubscriptionsDB::openWrite()
+{
+ PersistenceFile::openWrite(MAGIC_STRING_SESSION_FILE_V1);
+}
+
+void SessionsAndSubscriptionsDB::openRead()
+{
+ PersistenceFile::openRead();
+
+ if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V1)
+ readVersion = ReadVersion::v1;
+ else
+ throw std::runtime_error("Unknown file version.");
+}
+
+SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1()
+{
+ SessionsAndSubscriptionsResult result;
+
+ while (!feof(f))
+ {
+ bool eofFound = false;
+
+ const int64_t programStartAge = readInt64(eofFound);
+ if (eofFound)
+ continue;
+
+ logger->logf(LOG_DEBUG, "Setting first app start time to timestamp %ld", programStartAge);
+ Session::setProgramStartedAtUnixTimestamp(programStartAge);
+
+ const uint32_t nrOfSessions = readUint32(eofFound);
+
+ if (eofFound)
+ continue;
+
+ std::vector reserved(RESERVED_SPACE_SESSIONS_DB_V1);
+
+ for (uint32_t i = 0; i < nrOfSessions; i++)
+ {
+ readCheck(buf.data(), 1, RESERVED_SPACE_SESSIONS_DB_V1, f);
+
+ uint32_t usernameLength = readUint32(eofFound);
+ readCheck(buf.data(), 1, usernameLength, f);
+ std::string username(buf.data(), usernameLength);
+
+ uint32_t clientIdLength = readUint32(eofFound);
+ readCheck(buf.data(), 1, clientIdLength, f);
+ std::string clientId(buf.data(), clientIdLength);
+
+ std::shared_ptr ses(new Session());
+ result.sessions.push_back(ses);
+ ses->username = username;
+ ses->client_id = clientId;
+
+ logger->logf(LOG_DEBUG, "Loading session '%s'.", ses->getClientId().c_str());
+
+ const uint32_t nrOfQueuedQoSPackets = readUint32(eofFound);
+ for (uint32_t i = 0; i < nrOfQueuedQoSPackets; i++)
+ {
+ const uint16_t id = readUint16(eofFound);
+ const uint32_t topicSize = readUint32(eofFound);
+ const uint32_t payloadSize = readUint32(eofFound);
+
+ assert(id > 0);
+
+ readCheck(buf.data(), 1, 1, f);
+ const unsigned char qos = buf[0];
+
+ readCheck(buf.data(), 1, topicSize, f);
+ const std::string topic(buf.data(), topicSize);
+
+ makeSureBufSize(payloadSize);
+ readCheck(buf.data(), 1, payloadSize, f);
+ const std::string payload(buf.data(), payloadSize);
+
+ Publish pub(topic, payload, qos);
+ logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
+ ses->qosPacketQueue.queuePacket(pub, id);
+ }
+
+ const uint32_t nrOfIncomingPacketIds = readUint32(eofFound);
+ for (uint32_t i = 0; i < nrOfIncomingPacketIds; i++)
+ {
+ uint16_t id = readUint16(eofFound);
+ assert(id > 0);
+ logger->logf(LOG_DEBUG, "Loaded incomming QoS2 message id %d.", id);
+ ses->incomingQoS2MessageIds.insert(id);
+ }
+
+ const uint32_t nrOfOutgoingPacketIds = readUint32(eofFound);
+ for (uint32_t i = 0; i < nrOfOutgoingPacketIds; i++)
+ {
+ uint16_t id = readUint16(eofFound);
+ assert(id > 0);
+ logger->logf(LOG_DEBUG, "Loaded outgoing QoS2 message id %d.", id);
+ ses->outgoingQoS2MessageIds.insert(id);
+ }
+
+ const uint16_t nextPacketId = readUint16(eofFound);
+ logger->logf(LOG_DEBUG, "Loaded next packetid %d.", ses->nextPacketId);
+ ses->nextPacketId = nextPacketId;
+
+ int64_t sessionAge = readInt64(eofFound);
+ logger->logf(LOG_DEBUG, "Loaded session age: %ld ms.", sessionAge);
+ ses->setSessionTouch(sessionAge);
+ }
+
+ const uint32_t nrOfSubscriptions = readUint32(eofFound);
+ for (uint32_t i = 0; i < nrOfSubscriptions; i++)
+ {
+ const uint32_t topicLength = readUint32(eofFound);
+ readCheck(buf.data(), 1, topicLength, f);
+ const std::string topic(buf.data(), topicLength);
+
+ logger->logf(LOG_DEBUG, "Loading subscriptions to topic '%s'.", topic.c_str());
+
+ const uint32_t nrOfClientIds = readUint32(eofFound);
+
+ for (uint32_t i = 0; i < nrOfClientIds; i++)
+ {
+ const uint32_t clientIdLength = readUint32(eofFound);
+ readCheck(buf.data(), 1, clientIdLength, f);
+ const std::string clientId(buf.data(), clientIdLength);
+
+ char qos;
+ readCheck(&qos, 1, 1, f);
+
+ logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", clientId.c_str(), topic.c_str(), qos);
+
+ SubscriptionForSerializing sub(std::move(clientId), qos);
+ result.subscriptions[topic].push_back(std::move(sub));
+ }
+
+ }
+ }
+
+ return result;
+}
+
+void SessionsAndSubscriptionsDB::writeRowHeader()
+{
+
+}
+
+void SessionsAndSubscriptionsDB::saveData(const std::list> &sessions, const std::unordered_map> &subscriptions)
+{
+ if (!f)
+ return;
+
+ char reserved[RESERVED_SPACE_SESSIONS_DB_V1];
+ std::memset(reserved, 0, RESERVED_SPACE_SESSIONS_DB_V1);
+
+ const int64_t start_stamp = Session::getProgramStartedAtUnixTimestamp();
+ logger->logf(LOG_DEBUG, "Saving program first start time stamp as %ld", start_stamp);
+ writeInt64(start_stamp);
+
+ writeUint32(sessions.size());
+
+ for (const std::unique_ptr &ses : sessions)
+ {
+ logger->logf(LOG_DEBUG, "Saving session '%s'.", ses->getClientId().c_str());
+
+ writeRowHeader();
+
+ writeCheck(reserved, 1, RESERVED_SPACE_SESSIONS_DB_V1, f);
+
+ writeUint32(ses->username.length());
+ writeCheck(ses->username.c_str(), 1, ses->username.length(), f);
+
+ writeUint32(ses->client_id.length());
+ writeCheck(ses->client_id.c_str(), 1, ses->client_id.length(), f);
+
+ const size_t qosPacketsExpected = ses->qosPacketQueue.size();
+ size_t qosPacketsCounted = 0;
+ writeUint32(qosPacketsExpected);
+
+ for (const std::shared_ptr &p: ses->qosPacketQueue)
+ {
+ logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", p->getQos(), p->getTopic().c_str());
+
+ qosPacketsCounted++;
+
+ writeUint16(p->getPacketId());
+
+ writeUint32(p->getTopic().length());
+ std::string payload = p->getPayloadCopy();
+ writeUint32(payload.size());
+
+ const char qos = p->getQos();
+ writeCheck(&qos, 1, 1, f);
+
+ writeCheck(p->getTopic().c_str(), 1, p->getTopic().length(), f);
+ writeCheck(payload.c_str(), 1, payload.length(), f);
+ }
+
+ assert(qosPacketsExpected == qosPacketsCounted);
+
+ writeUint32(ses->incomingQoS2MessageIds.size());
+ for (uint16_t id : ses->incomingQoS2MessageIds)
+ {
+ logger->logf(LOG_DEBUG, "Writing incomming QoS2 message id %d.", id);
+ writeUint16(id);
+ }
+
+ writeUint32(ses->outgoingQoS2MessageIds.size());
+ for (uint16_t id : ses->outgoingQoS2MessageIds)
+ {
+ logger->logf(LOG_DEBUG, "Writing outgoing QoS2 message id %d.", id);
+ writeUint16(id);
+ }
+
+ logger->logf(LOG_DEBUG, "Writing next packetid %d.", ses->nextPacketId);
+ writeUint16(ses->nextPacketId);
+
+ const int64_t sInMs = ses->getSessionRelativeAgeInMs();
+ logger->logf(LOG_DEBUG, "Writing session age: %ld ms.", sInMs);
+ writeInt64(sInMs);
+ }
+
+ writeUint32(subscriptions.size());
+
+ for (auto &pair : subscriptions)
+ {
+ const std::string &topic = pair.first;
+ const std::list &subscriptions = pair.second;
+
+ logger->logf(LOG_DEBUG, "Writing subscriptions to topic '%s'.", topic.c_str());
+
+ writeUint32(topic.size());
+ writeCheck(topic.c_str(), 1, topic.size(), f);
+
+ writeUint32(subscriptions.size());
+
+ for (const SubscriptionForSerializing &subscription : subscriptions)
+ {
+ logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", subscription.clientId.c_str(), topic.c_str(), subscription.qos);
+
+ writeUint32(subscription.clientId.size());
+ writeCheck(subscription.clientId.c_str(), 1, subscription.clientId.size(), f);
+ writeCheck(&subscription.qos, 1, 1, f);
+ }
+ }
+
+ fflush(f);
+}
+
+SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readData()
+{
+ SessionsAndSubscriptionsResult defaultResult;
+
+ if (!f)
+ return defaultResult;
+
+ if (readVersion == ReadVersion::v1)
+ return readDataV1();
+
+ return defaultResult;
+}
diff --git a/sessionsandsubscriptionsdb.h b/sessionsandsubscriptionsdb.h
new file mode 100644
index 0000000..15a242e
--- /dev/null
+++ b/sessionsandsubscriptionsdb.h
@@ -0,0 +1,71 @@
+/*
+This file is part of FlashMQ (https://www.flashmq.org)
+Copyright (C) 2021 Wiebe Cazemier
+
+FlashMQ is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, version 3.
+
+FlashMQ is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public
+License along with FlashMQ. If not, see .
+*/
+
+#ifndef SESSIONSANDSUBSCRIPTIONSDB_H
+#define SESSIONSANDSUBSCRIPTIONSDB_H
+
+#include
+#include
+
+#include "persistencefile.h"
+#include "session.h"
+
+#define MAGIC_STRING_SESSION_FILE_V1 "FlashMQRetainedDBv1"
+#define RESERVED_SPACE_SESSIONS_DB_V1 32
+
+/**
+ * @brief The SubscriptionForSerializing struct contains the fields we're interested in when saving a subscription.
+ */
+struct SubscriptionForSerializing
+{
+ const std::string clientId;
+ const char qos = 0;
+
+ SubscriptionForSerializing(const std::string &clientId, char qos);
+ SubscriptionForSerializing(const std::string &&clientId, char qos);
+};
+
+struct SessionsAndSubscriptionsResult
+{
+ std::list> sessions;
+ std::unordered_map> subscriptions;
+};
+
+
+class SessionsAndSubscriptionsDB : public PersistenceFile
+{
+ enum class ReadVersion
+ {
+ unknown,
+ v1
+ };
+
+ ReadVersion readVersion = ReadVersion::unknown;
+
+ SessionsAndSubscriptionsResult readDataV1();
+ void writeRowHeader();
+public:
+ SessionsAndSubscriptionsDB(const std::string &filePath);
+
+ void openWrite();
+ void openRead();
+
+ void saveData(const std::list> &sessions, const std::unordered_map> &subscriptions);
+ SessionsAndSubscriptionsResult readData();
+};
+
+#endif // SESSIONSANDSUBSCRIPTIONSDB_H
diff --git a/settings.cpp b/settings.cpp
index 49b175e..629ea52 100644
--- a/settings.cpp
+++ b/settings.cpp
@@ -16,7 +16,7 @@ License along with FlashMQ. If not, see .
*/
#include "settings.h"
-
+#include "utils.h"
AuthOptCompatWrap &Settings::getAuthOptsCompat()
{
@@ -27,3 +27,21 @@ std::unordered_map &Settings::getFlashmqAuthPluginOpts
{
return this->flashmqAuthPluginOpts;
}
+
+std::string Settings::getRetainedMessagesDBFile() const
+{
+ if (storageDir.empty())
+ return "";
+
+ std::string path = formatString("%s/%s", storageDir.c_str(), "retained.db");
+ return path;
+}
+
+std::string Settings::getSessionsDBFile() const
+{
+ if (storageDir.empty())
+ return "";
+
+ std::string path = formatString("%s/%s", storageDir.c_str(), "sessions.db");
+ return path;
+}
diff --git a/settings.h b/settings.h
index d79e486..ecc627a 100644
--- a/settings.h
+++ b/settings.h
@@ -53,10 +53,14 @@ public:
int rlimitNoFile = 1000000;
uint64_t expireSessionsAfterSeconds = 1209600;
int authPluginTimerPeriod = 60;
+ std::string storageDir;
std::list> listeners; // Default one is created later, when none are defined.
AuthOptCompatWrap &getAuthOptsCompat();
std::unordered_map &getFlashmqAuthPluginOpts();
+
+ std::string getRetainedMessagesDBFile() const;
+ std::string getSessionsDBFile() const;
};
#endif // SETTINGS_H
diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp
index e216228..6628ebb 100644
--- a/subscriptionstore.cpp
+++ b/subscriptionstore.cpp
@@ -20,7 +20,7 @@ License along with FlashMQ. If not, see .
#include "cassert"
#include "rwlockguard.h"
-
+#include "retainedmessagesdb.h"
SubscriptionNode::SubscriptionNode(const std::string &subtopic) :
subtopic(subtopic)
@@ -33,6 +33,11 @@ std::vector &SubscriptionNode::getSubscribers()
return subscribers;
}
+const std::string &SubscriptionNode::getSubtopic() const
+{
+ return subtopic;
+}
+
void SubscriptionNode::addSubscriber(const std::shared_ptr &subscriber, char qos)
{
Subscription sub;
@@ -84,15 +89,20 @@ SubscriptionStore::SubscriptionStore() :
}
-void SubscriptionStore::addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos)
+/**
+ * @brief SubscriptionStore::getDeepestNode gets the node in the tree walking the path of 'the/subscription/topic/path', making new nodes as required.
+ * @param topic
+ * @param subtopics
+ * @return
+ *
+ * caller is responsible for locking.
+ */
+SubscriptionNode *SubscriptionStore::getDeepestNode(const std::string &topic, const std::vector &subtopics)
{
SubscriptionNode *deepestNode = &root;
if (topic.length() > 0 && topic[0] == '$')
deepestNode = &rootDollar;
- RWLockGuard lock_guard(&subscriptionsRwlock);
- lock_guard.wrlock();
-
for(const std::string &subtopic : subtopics)
{
std::unique_ptr *selectedChildren = nullptr;
@@ -114,6 +124,15 @@ void SubscriptionStore::addSubscription(std::shared_ptr &client, const s
}
assert(deepestNode);
+ return deepestNode;
+}
+
+void SubscriptionStore::addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos)
+{
+ RWLockGuard lock_guard(&subscriptionsRwlock);
+ lock_guard.wrlock();
+
+ SubscriptionNode *deepestNode = getDeepestNode(topic, subtopics);
if (deepestNode)
{
@@ -494,6 +513,183 @@ int64_t SubscriptionStore::getRetainedMessageCount() const
return retainedMessageCount;
}
+void SubscriptionStore::getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const
+{
+ for(const RetainedMessage &rm : this_node->retainedMessages)
+ {
+ outputList.push_back(rm);
+ }
+
+ for(auto &pair : this_node->children)
+ {
+ const std::unique_ptr &child = pair.second;
+ getRetainedMessages(child.get(), outputList);
+ }
+}
+
+/**
+ * @brief SubscriptionStore::getSubscriptions
+ * @param this_node
+ * @param composedTopic
+ * @param root bool. Every subtopic is concatenated with a '/', but not the first topic to 'root'. The root is a bit weird, virtual, so it needs different treatment.
+ * @param outputList
+ */
+void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
+ std::unordered_map> &outputList) const
+{
+ for (const Subscription &node : this_node->getSubscribers())
+ {
+ if (!node.sessionGone())
+ {
+ SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos);
+ outputList[composedTopic].push_back(sub);
+ }
+ }
+
+ for (auto &pair : this_node->children)
+ {
+ SubscriptionNode *node = pair.second.get();
+ const std::string topicAtNextLevel = root ? pair.first : composedTopic + "/" + pair.first;
+ getSubscriptions(node, topicAtNextLevel, false, outputList);
+ }
+
+ if (this_node->childrenPlus)
+ {
+ const std::string topicAtNextLevel = root ? "+" : composedTopic + "/+";
+ getSubscriptions(this_node->childrenPlus.get(), topicAtNextLevel, false, outputList);
+ }
+
+ if (this_node->childrenPound)
+ {
+ const std::string topicAtNextLevel = root ? "#" : composedTopic + "/#";
+ getSubscriptions(this_node->childrenPound.get(), topicAtNextLevel, false, outputList);
+ }
+}
+
+void SubscriptionStore::saveRetainedMessages(const std::string &filePath)
+{
+ logger->logf(LOG_INFO, "Saving retained messages to '%s'", filePath.c_str());
+
+ std::vector result;
+ result.reserve(retainedMessageCount);
+
+ // Create the list of messages under lock, and unlock right after.
+ RWLockGuard locker(&retainedMessagesRwlock);
+ locker.rdlock();
+ getRetainedMessages(&retainedMessagesRoot, result);
+ locker.unlock();
+
+ logger->logf(LOG_DEBUG, "Collected %ld retained messages to save.", result.size());
+
+ // Then do the IO without locking the threads.
+ RetainedMessagesDB db(filePath);
+ db.openWrite();
+ db.saveData(result);
+}
+
+void SubscriptionStore::loadRetainedMessages(const std::string &filePath)
+{
+ try
+ {
+ logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str());
+
+ RetainedMessagesDB db(filePath);
+ db.openRead();
+ std::list messages = db.readData();
+
+ RWLockGuard locker(&retainedMessagesRwlock);
+ locker.wrlock();
+
+ std::vector subtopics;
+ for (const RetainedMessage &rm : messages)
+ {
+ splitTopic(rm.topic, subtopics);
+ setRetainedMessage(rm.topic, subtopics, rm.payload, rm.qos);
+ }
+ }
+ catch (PersistenceFileCantBeOpened &ex)
+ {
+ logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str());
+ }
+}
+
+void SubscriptionStore::saveSessionsAndSubscriptions(const std::string &filePath)
+{
+ logger->logf(LOG_INFO, "Saving sessions and subscriptions to '%s'", filePath.c_str());
+
+ RWLockGuard lock_guard(&subscriptionsRwlock);
+ lock_guard.wrlock();
+
+ // First copy the sessions...
+
+ std::list> sessionCopies;
+
+ for (const auto &pair : sessionsByIdConst)
+ {
+ const Session &org = *pair.second.get();
+ sessionCopies.push_back(org.getCopy());
+ }
+
+ std::unordered_map> subscriptionCopies;
+ getSubscriptions(&root, "", true, subscriptionCopies);
+
+ lock_guard.unlock();
+
+ // Then write the copies to disk, after having released the lock
+
+ logger->logf(LOG_DEBUG, "Collected %ld sessions and %ld subscriptions to save.", sessionCopies.size(), subscriptionCopies.size());
+
+ SessionsAndSubscriptionsDB db(filePath);
+ db.openWrite();
+ db.saveData(sessionCopies, subscriptionCopies);
+}
+
+void SubscriptionStore::loadSessionsAndSubscriptions(const std::string &filePath)
+{
+ try
+ {
+ logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str());
+
+ SessionsAndSubscriptionsDB db(filePath);
+ db.openRead();
+ SessionsAndSubscriptionsResult loadedData = db.readData();
+
+ RWLockGuard locker(&subscriptionsRwlock);
+ locker.wrlock();
+
+ for (std::shared_ptr &session : loadedData.sessions)
+ {
+ sessionsById[session->getClientId()] = session;
+ }
+
+ std::vector subtopics;
+
+ for (auto &pair : loadedData.subscriptions)
+ {
+ const std::string &topic = pair.first;
+ const std::list &subs = pair.second;
+
+ for (const SubscriptionForSerializing &sub : subs)
+ {
+ splitTopic(topic, subtopics);
+ SubscriptionNode *subscriptionNode = getDeepestNode(topic, subtopics);
+
+ auto session_it = sessionsByIdConst.find(sub.clientId);
+ if (session_it != sessionsByIdConst.end())
+ {
+ const std::shared_ptr &ses = session_it->second;
+ subscriptionNode->addSubscriber(ses, sub.qos);
+ }
+
+ }
+ }
+ }
+ catch (PersistenceFileCantBeOpened &ex)
+ {
+ logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str());
+ }
+}
+
// QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The
// specs don't specify what to do there.
bool Subscription::operator==(const Subscription &rhs) const
diff --git a/subscriptionstore.h b/subscriptionstore.h
index 35416b4..93e3dc6 100644
--- a/subscriptionstore.h
+++ b/subscriptionstore.h
@@ -53,6 +53,7 @@ public:
SubscriptionNode(SubscriptionNode &&node) = delete;
std::vector &getSubscribers();
+ const std::string &getSubtopic() const;
void addSubscriber(const std::shared_ptr &subscriber, char qos);
void removeSubscriber(const std::shared_ptr &subscriber);
std::unordered_map> children;
@@ -77,6 +78,10 @@ class RetainedMessageNode
class SubscriptionStore
{
+#ifdef TESTING
+ friend class MainTests;
+#endif
+
SubscriptionNode root;
SubscriptionNode rootDollar;
pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER;
@@ -93,7 +98,11 @@ class SubscriptionStore
void publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers, uint64_t &count) const;
void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end,
SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const;
+ void getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const;
+ void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
+ std::unordered_map> &outputList) const;
+ SubscriptionNode *getDeepestNode(const std::string &topic, const std::vector &subtopics);
public:
SubscriptionStore();
@@ -103,7 +112,6 @@ public:
bool sessionPresent(const std::string &clientid);
void queuePacketAtSubscribers(const std::vector &subtopics, const MqttPacket &packet, bool dollar = false);
- uint64_t giveClientRetainedMessages(const std::shared_ptr &ses, const std::string &subscribe_topic, char max_qos);
void giveClientRetainedMessagesRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end,
RetainedMessageNode *this_node, char max_qos, const std::shared_ptr &ses,
bool poundMode, uint64_t &count) const;
@@ -114,6 +122,12 @@ public:
void removeExpiredSessionsClients(int expireSessionsAfterSeconds);
int64_t getRetainedMessageCount() const;
+
+ void saveRetainedMessages(const std::string &filePath);
+ void loadRetainedMessages(const std::string &filePath);
+
+ void saveSessionsAndSubscriptions(const std::string &filePath);
+ void loadSessionsAndSubscriptions(const std::string &filePath);
};
#endif // SUBSCRIPTIONSTORE_H
diff --git a/utils.cpp b/utils.cpp
index 79f78fc..575d328 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -289,6 +289,14 @@ void trim(std::string &s)
rtrim(s);
}
+std::string &rtrim(std::string &s, unsigned char c)
+{
+ s.erase(std::find_if(s.rbegin(), s.rend(), [=](unsigned char ch) {
+ return (c != ch);
+ }).base(), s.end());
+ return s;
+}
+
bool startsWith(const std::string &s, const std::string &needle)
{
return s.find(needle) == 0;
diff --git a/utils.h b/utils.h
index 0a83fa3..70c8352 100644
--- a/utils.h
+++ b/utils.h
@@ -28,6 +28,8 @@ License along with FlashMQ. If not, see .
#include
#include
#include
+#include "unistd.h"
+#include "sys/stat.h"
#include "cirbuf.h"
#include "bindaddr.h"
@@ -64,6 +66,7 @@ void ltrim(std::string &s);
void rtrim(std::string &s);
void trim(std::string &s);
bool startsWith(const std::string &s, const std::string &needle);
+std::string &rtrim(std::string &s, unsigned char c);
std::string getSecureRandomString(const ssize_t len);
std::string str_tolower(std::string s);
@@ -92,5 +95,32 @@ ssize_t getFileSize(const std::string &path);
std::string sockaddrToString(struct sockaddr *addr);
+template void checkWritableDir(const std::string &path)
+{
+ if (path.empty())
+ throw ex("Dir path to check is an empty string.");
+
+ if (access(path.c_str(), W_OK) != 0)
+ {
+ std::string msg = formatString("Path '%s' is not there or not writable", path.c_str());
+ throw ex(msg);
+ }
+
+ struct stat statbuf;
+ memset(&statbuf, 0, sizeof(struct stat));
+ if (stat(path.c_str(), &statbuf) < 0)
+ {
+ // We checked for W_OK above, so this shouldn't happen.
+ std::string msg = formatString("Error getting information about '%s'.", path.c_str());
+ throw ex(msg);
+ }
+
+ if (!S_ISDIR(statbuf.st_mode))
+ {
+ std::string msg = formatString("Path '%s' is not a directory.", path.c_str());
+ throw ex(msg);
+ }
+}
+
#endif // UTILS_H