Commit 4bfa5aa5716d52b57ed4e725f7f5dd6aa03e395d

Authored by Wiebe Cazemier
1 parent 0cbe2a00

Add saving state feature

Files are simple serialized bytes prefaced by lengths. File is hashed to
verify integrity. This was also a good way preventing unexpected errors
when trying to crash the parser by having it load a different file.

This change includes some refactoring that was necessary:

- It 'fixes' looking at the wrong thread's authentiction. This is still
  wrong though. It will be fixed by a thread local pointer in the next
  commit.
- Deadlocks with yourself are handled in rwlockguard.
- QoSPacketQueue is now a class.
- Probably other tweaks.
CMakeLists.txt
@@ -50,6 +50,10 @@ add_executable(FlashMQ @@ -50,6 +50,10 @@ add_executable(FlashMQ
50 enums.h 50 enums.h
51 threadlocalutils.h 51 threadlocalutils.h
52 flashmq_plugin.h 52 flashmq_plugin.h
  53 + retainedmessagesdb.h
  54 + persistencefile.h
  55 + sessionsandsubscriptionsdb.h
  56 + qospacketqueue.h
53 57
54 mainapp.cpp 58 mainapp.cpp
55 main.cpp 59 main.cpp
@@ -81,6 +85,10 @@ add_executable(FlashMQ @@ -81,6 +85,10 @@ add_executable(FlashMQ
81 acltree.cpp 85 acltree.cpp
82 threadlocalutils.cpp 86 threadlocalutils.cpp
83 flashmq_plugin.cpp 87 flashmq_plugin.cpp
  88 + retainedmessagesdb.cpp
  89 + persistencefile.cpp
  90 + sessionsandsubscriptionsdb.cpp
  91 + qospacketqueue.cpp
84 92
85 ) 93 )
86 94
FlashMQTests/FlashMQTests.pro
@@ -42,6 +42,10 @@ SOURCES += tst_maintests.cpp \ @@ -42,6 +42,10 @@ SOURCES += tst_maintests.cpp \
42 ../acltree.cpp \ 42 ../acltree.cpp \
43 ../threadlocalutils.cpp \ 43 ../threadlocalutils.cpp \
44 ../flashmq_plugin.cpp \ 44 ../flashmq_plugin.cpp \
  45 + ../retainedmessagesdb.cpp \
  46 + ../persistencefile.cpp \
  47 + ../sessionsandsubscriptionsdb.cpp \
  48 + ../qospacketqueue.cpp \
45 mainappthread.cpp \ 49 mainappthread.cpp \
46 twoclienttestcontext.cpp 50 twoclienttestcontext.cpp
47 51
@@ -77,6 +81,10 @@ HEADERS += \ @@ -77,6 +81,10 @@ HEADERS += \
77 ../acltree.h \ 81 ../acltree.h \
78 ../threadlocalutils.h \ 82 ../threadlocalutils.h \
79 ../flashmq_plugin.h \ 83 ../flashmq_plugin.h \
  84 + ../retainedmessagesdb.h \
  85 + ../persistencefile.h \
  86 + ../sessionsandsubscriptionsdb.h \
  87 + ../qospacketqueue.h \
80 mainappthread.h \ 88 mainappthread.h \
81 twoclienttestcontext.h 89 twoclienttestcontext.h
82 90
FlashMQTests/tst_maintests.cpp
@@ -21,12 +21,18 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. @@ -21,12 +21,18 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
21 #include <QtQmqtt/qmqtt.h> 21 #include <QtQmqtt/qmqtt.h>
22 #include <QScopedPointer> 22 #include <QScopedPointer>
23 #include <QHostInfo> 23 #include <QHostInfo>
  24 +#include <list>
  25 +#include <unordered_map>
24 26
25 #include "cirbuf.h" 27 #include "cirbuf.h"
26 #include "mainapp.h" 28 #include "mainapp.h"
27 #include "mainappthread.h" 29 #include "mainappthread.h"
28 #include "twoclienttestcontext.h" 30 #include "twoclienttestcontext.h"
29 #include "threadlocalutils.h" 31 #include "threadlocalutils.h"
  32 +#include "retainedmessagesdb.h"
  33 +#include "sessionsandsubscriptionsdb.h"
  34 +#include "session.h"
  35 +#include "threaddata.h"
30 36
31 // Dumb Qt version gives warnings when comparing uint with number literal. 37 // Dumb Qt version gives warnings when comparing uint with number literal.
32 template <typename T1, typename T2> 38 template <typename T1, typename T2>
@@ -88,6 +94,12 @@ private slots: @@ -88,6 +94,12 @@ private slots:
88 94
89 void testTopicsMatch(); 95 void testTopicsMatch();
90 96
  97 + void testRetainedMessageDB();
  98 + void testRetainedMessageDBNotPresent();
  99 + void testRetainedMessageDBEmptyList();
  100 +
  101 + void testSavingSessions();
  102 +
91 }; 103 };
92 104
93 MainTests::MainTests() 105 MainTests::MainTests()
@@ -822,6 +834,221 @@ void MainTests::testTopicsMatch() @@ -822,6 +834,221 @@ void MainTests::testTopicsMatch()
822 834
823 } 835 }
824 836
  837 +void MainTests::testRetainedMessageDB()
  838 +{
  839 + try
  840 + {
  841 + std::string longpayload = getSecureRandomString(65537);
  842 + std::string longTopic = formatString("one/two/%s", getSecureRandomString(4000).c_str());
  843 +
  844 + std::vector<RetainedMessage> messages;
  845 + messages.emplace_back("one/two/three", "payload", 0);
  846 + messages.emplace_back("one/two/wer", "payload", 1);
  847 + messages.emplace_back("one/e/wer", "payload", 1);
  848 + messages.emplace_back("one/wee/wer", "asdfasdfasdf", 1);
  849 + messages.emplace_back("one/two/wer", "ยตsdf", 1);
  850 + messages.emplace_back("/boe/bah", longpayload, 1);
  851 + messages.emplace_back("one/two/wer", "paylasdfaoad", 1);
  852 + messages.emplace_back("one/two/wer", "payload", 1);
  853 + messages.emplace_back(longTopic, "payload", 1);
  854 + messages.emplace_back(longTopic, longpayload, 1);
  855 + messages.emplace_back("one", "ยตsdf", 1);
  856 + messages.emplace_back("/boe", longpayload, 1);
  857 + messages.emplace_back("one", "ยตsdf", 1);
  858 + messages.emplace_back("", "foremptytopic", 0);
  859 +
  860 + RetainedMessagesDB db("/tmp/flashmqtests_retained.db");
  861 + db.openWrite();
  862 + db.saveData(messages);
  863 + db.closeFile();
  864 +
  865 + RetainedMessagesDB db2("/tmp/flashmqtests_retained.db");
  866 + db2.openRead();
  867 + std::list<RetainedMessage> messagesLoaded = db2.readData();
  868 + db2.closeFile();
  869 +
  870 + QCOMPARE(messages.size(), messagesLoaded.size());
  871 +
  872 + auto itOrg = messages.begin();
  873 + auto itLoaded = messagesLoaded.begin();
  874 + while (itOrg != messages.end() && itLoaded != messagesLoaded.end())
  875 + {
  876 + RetainedMessage &one = *itOrg;
  877 + RetainedMessage &two = *itLoaded;
  878 +
  879 + // Comparing the fields because the RetainedMessage class has an == operator that only looks at topic.
  880 + QCOMPARE(one.topic, two.topic);
  881 + QCOMPARE(one.payload, two.payload);
  882 + QCOMPARE(one.qos, two.qos);
  883 +
  884 + itOrg++;
  885 + itLoaded++;
  886 + }
  887 + }
  888 + catch (std::exception &ex)
  889 + {
  890 + QVERIFY2(false, ex.what());
  891 + }
  892 +}
  893 +
  894 +void MainTests::testRetainedMessageDBNotPresent()
  895 +{
  896 + try
  897 + {
  898 + RetainedMessagesDB db2("/tmp/flashmqtests_asdfasdfasdf.db");
  899 + db2.openRead();
  900 + std::list<RetainedMessage> messagesLoaded = db2.readData();
  901 + db2.closeFile();
  902 +
  903 + MYCASTCOMPARE(messagesLoaded.size(), 0);
  904 +
  905 + QVERIFY2(false, "We should have run into an exception.");
  906 + }
  907 + catch (PersistenceFileCantBeOpened &ex)
  908 + {
  909 + QVERIFY(true);
  910 + }
  911 + catch (std::exception &ex)
  912 + {
  913 + QVERIFY2(false, ex.what());
  914 + }
  915 +}
  916 +
  917 +void MainTests::testRetainedMessageDBEmptyList()
  918 +{
  919 + try
  920 + {
  921 + std::vector<RetainedMessage> messages;
  922 +
  923 + RetainedMessagesDB db("/tmp/flashmqtests_retained.db");
  924 + db.openWrite();
  925 + db.saveData(messages);
  926 + db.closeFile();
  927 +
  928 + RetainedMessagesDB db2("/tmp/flashmqtests_retained.db");
  929 + db2.openRead();
  930 + std::list<RetainedMessage> messagesLoaded = db2.readData();
  931 + db2.closeFile();
  932 +
  933 + MYCASTCOMPARE(messages.size(), messagesLoaded.size());
  934 + MYCASTCOMPARE(messages.size(), 0);
  935 + }
  936 + catch (std::exception &ex)
  937 + {
  938 + QVERIFY2(false, ex.what());
  939 + }
  940 +}
  941 +
  942 +void MainTests::testSavingSessions()
  943 +{
  944 + try
  945 + {
  946 + std::shared_ptr<Settings> settings(new Settings());
  947 + std::shared_ptr<SubscriptionStore> store(new SubscriptionStore());
  948 + std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings));
  949 +
  950 + std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false));
  951 + c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false);
  952 + store->registerClientAndKickExistingOne(c1);
  953 + c1->getSession()->touch();
  954 + c1->getSession()->addIncomingQoS2MessageId(2);
  955 + c1->getSession()->addIncomingQoS2MessageId(3);
  956 +
  957 + std::shared_ptr<Client> c2(new Client(0, t, nullptr, false, nullptr, settings, false));
  958 + c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60, false);
  959 + store->registerClientAndKickExistingOne(c2);
  960 + c2->getSession()->touch();
  961 + c2->getSession()->addOutgoingQoS2MessageId(55);
  962 + c2->getSession()->addOutgoingQoS2MessageId(66);
  963 +
  964 + const std::string topic1 = "one/two/three";
  965 + std::vector<std::string> subtopics;
  966 + splitTopic(topic1, subtopics);
  967 + store->addSubscription(c1, topic1, subtopics, 0);
  968 +
  969 + const std::string topic2 = "four/five/six";
  970 + splitTopic(topic2, subtopics);
  971 + store->addSubscription(c2, topic2, subtopics, 0);
  972 + store->addSubscription(c1, topic2, subtopics, 0);
  973 +
  974 + const std::string topic3 = "";
  975 + splitTopic(topic3, subtopics);
  976 + store->addSubscription(c2, topic3, subtopics, 0);
  977 +
  978 + const std::string topic4 = "#";
  979 + splitTopic(topic4, subtopics);
  980 + store->addSubscription(c2, topic4, subtopics, 0);
  981 +
  982 + uint64_t count = 0;
  983 +
  984 + Publish publish("a/b/c", "Hello Barry", 1);
  985 +
  986 + std::shared_ptr<Session> c1ses = c1->getSession();
  987 + c1.reset();
  988 + c1ses->writePacket(publish, 1, false, count);
  989 +
  990 + store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db");
  991 +
  992 + std::shared_ptr<SubscriptionStore> store2(new SubscriptionStore());
  993 + store2->loadSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db");
  994 +
  995 + MYCASTCOMPARE(store->sessionsById.size(), 2);
  996 + MYCASTCOMPARE(store2->sessionsById.size(), 2);
  997 +
  998 + for (auto &pair : store->sessionsById)
  999 + {
  1000 + std::shared_ptr<Session> &ses = pair.second;
  1001 + std::shared_ptr<Session> &ses2 = store2->sessionsById[pair.first];
  1002 + QCOMPARE(pair.first, ses2->getClientId());
  1003 +
  1004 + QCOMPARE(ses->username, ses2->username);
  1005 + QCOMPARE(ses->client_id, ses2->client_id);
  1006 + QCOMPARE(ses->incomingQoS2MessageIds, ses2->incomingQoS2MessageIds);
  1007 + QCOMPARE(ses->outgoingQoS2MessageIds, ses2->outgoingQoS2MessageIds);
  1008 + QCOMPARE(ses->nextPacketId, ses2->nextPacketId);
  1009 + }
  1010 +
  1011 +
  1012 + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> store1Subscriptions;
  1013 + store->getSubscriptions(&store->root, "", true, store1Subscriptions);
  1014 +
  1015 + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> store2Subscriptions;
  1016 + store->getSubscriptions(&store->root, "", true, store2Subscriptions);
  1017 +
  1018 + MYCASTCOMPARE(store1Subscriptions.size(), 4);
  1019 + MYCASTCOMPARE(store2Subscriptions.size(), 4);
  1020 +
  1021 + for(auto &pair : store1Subscriptions)
  1022 + {
  1023 + std::list<SubscriptionForSerializing> &subscList1 = pair.second;
  1024 + std::list<SubscriptionForSerializing> &subscList2 = store2Subscriptions[pair.first];
  1025 +
  1026 + QCOMPARE(subscList1.size(), subscList2.size());
  1027 +
  1028 + auto subs1It = subscList1.begin();
  1029 + auto subs2It = subscList2.begin();
  1030 +
  1031 + while (subs1It != subscList1.end())
  1032 + {
  1033 + SubscriptionForSerializing &one = *subs1It;
  1034 + SubscriptionForSerializing &two = *subs2It;
  1035 + QCOMPARE(one.clientId, two.clientId);
  1036 + QCOMPARE(one.qos, two.qos);
  1037 +
  1038 + subs1It++;
  1039 + subs2It++;
  1040 + }
  1041 +
  1042 + }
  1043 +
  1044 +
  1045 + }
  1046 + catch (std::exception &ex)
  1047 + {
  1048 + QVERIFY2(false, ex.what());
  1049 + }
  1050 +}
  1051 +
825 QTEST_GUILESS_MAIN(MainTests) 1052 QTEST_GUILESS_MAIN(MainTests)
826 1053
827 #include "tst_maintests.moc" 1054 #include "tst_maintests.moc"
configfileparser.cpp
@@ -107,6 +107,7 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) : @@ -107,6 +107,7 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) :
107 validKeys.insert("allow_anonymous"); 107 validKeys.insert("allow_anonymous");
108 validKeys.insert("rlimit_nofile"); 108 validKeys.insert("rlimit_nofile");
109 validKeys.insert("expire_sessions_after_seconds"); 109 validKeys.insert("expire_sessions_after_seconds");
  110 + validKeys.insert("storage_dir");
110 111
111 validListenKeys.insert("port"); 112 validListenKeys.insert("port");
112 validListenKeys.insert("protocol"); 113 validListenKeys.insert("protocol");
@@ -412,6 +413,14 @@ void ConfigFileParser::loadFile(bool test) @@ -412,6 +413,14 @@ void ConfigFileParser::loadFile(bool test)
412 } 413 }
413 tmpSettings->authPluginTimerPeriod = newVal; 414 tmpSettings->authPluginTimerPeriod = newVal;
414 } 415 }
  416 +
  417 + if (key == "storage_dir")
  418 + {
  419 + std::string newPath = value;
  420 + rtrim(newPath, '/');
  421 + checkWritableDir<ConfigFileException>(newPath);
  422 + tmpSettings->storageDir = newPath;
  423 + }
415 } 424 }
416 } 425 }
417 catch (std::invalid_argument &ex) // catch for the stoi() 426 catch (std::invalid_argument &ex) // catch for the stoi()
mainapp.cpp
@@ -202,6 +202,15 @@ MainApp::MainApp(const std::string &amp;configFilePath) : @@ -202,6 +202,15 @@ MainApp::MainApp(const std::string &amp;configFilePath) :
202 auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this); 202 auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this);
203 timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event."); 203 timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event.");
204 } 204 }
  205 +
  206 + if (!settings->storageDir.empty())
  207 + {
  208 + subscriptionStore->loadRetainedMessages(settings->getRetainedMessagesDBFile());
  209 + subscriptionStore->loadSessionsAndSubscriptions(settings->getSessionsDBFile());
  210 + }
  211 +
  212 + auto fSaveState = std::bind(&MainApp::saveState, this);
  213 + timer.addCallback(fSaveState, 900000, "Save state.");
205 } 214 }
206 215
207 MainApp::~MainApp() 216 MainApp::~MainApp()
@@ -369,6 +378,25 @@ void MainApp::publishStat(const std::string &amp;topic, uint64_t n) @@ -369,6 +378,25 @@ void MainApp::publishStat(const std::string &amp;topic, uint64_t n)
369 subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); 378 subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0);
370 } 379 }
371 380
  381 +void MainApp::saveState()
  382 +{
  383 + try
  384 + {
  385 + if (!settings->storageDir.empty())
  386 + {
  387 + const std::string retainedDBPath = settings->getRetainedMessagesDBFile();
  388 + subscriptionStore->saveRetainedMessages(retainedDBPath);
  389 +
  390 + const std::string sessionsDBPath = settings->getSessionsDBFile();
  391 + subscriptionStore->saveSessionsAndSubscriptions(sessionsDBPath);
  392 + }
  393 + }
  394 + catch(std::exception &ex)
  395 + {
  396 + logger->logf(LOG_ERR, "Error saving state: %s", ex.what());
  397 + }
  398 +}
  399 +
372 void MainApp::initMainApp(int argc, char *argv[]) 400 void MainApp::initMainApp(int argc, char *argv[])
373 { 401 {
374 if (instance != nullptr) 402 if (instance != nullptr)
@@ -659,6 +687,8 @@ void MainApp::start() @@ -659,6 +687,8 @@ void MainApp::start()
659 { 687 {
660 thread->waitForQuit(); 688 thread->waitForQuit();
661 } 689 }
  690 +
  691 + saveState();
662 } 692 }
663 693
664 void MainApp::quit() 694 void MainApp::quit()
mainapp.h
@@ -84,6 +84,7 @@ class MainApp @@ -84,6 +84,7 @@ class MainApp
84 void setFuzzFile(const std::string &fuzzFilePath); 84 void setFuzzFile(const std::string &fuzzFilePath);
85 void publishStatsOnDollarTopic(); 85 void publishStatsOnDollarTopic();
86 void publishStat(const std::string &topic, uint64_t n); 86 void publishStat(const std::string &topic, uint64_t n);
  87 + void saveState();
87 88
88 MainApp(const std::string &configFilePath); 89 MainApp(const std::string &configFilePath);
89 public: 90 public:
mqttpacket.cpp
@@ -47,8 +47,11 @@ MqttPacket::MqttPacket(CirBuf &amp;buf, size_t packet_len, size_t fixed_header_lengt @@ -47,8 +47,11 @@ MqttPacket::MqttPacket(CirBuf &amp;buf, size_t packet_len, size_t fixed_header_lengt
47 pos += fixed_header_length; 47 pos += fixed_header_length;
48 } 48 }
49 49
50 -// This is easier than using the copy constructor publically, because then I have to keep maintaining a functioning copy constructor.  
51 -// Returning shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. 50 +/**
  51 + * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor
  52 + * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add.
  53 + * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource.
  54 + */
52 std::shared_ptr<MqttPacket> MqttPacket::getCopy() const 55 std::shared_ptr<MqttPacket> MqttPacket::getCopy() const
53 { 56 {
54 std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this)); 57 std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this));
@@ -129,6 +132,9 @@ MqttPacket::MqttPacket(const Publish &amp;publish) : @@ -129,6 +132,9 @@ MqttPacket::MqttPacket(const Publish &amp;publish) :
129 writeBytes(zero, 2); 132 writeBytes(zero, 2);
130 } 133 }
131 134
  135 + payloadStart = pos;
  136 + payloadLen = publish.payload.length();
  137 +
132 writeBytes(publish.payload.c_str(), publish.payload.length()); 138 writeBytes(publish.payload.c_str(), publish.payload.length());
133 calculateRemainingLength(); 139 calculateRemainingLength();
134 } 140 }
@@ -546,13 +552,14 @@ void MqttPacket::handlePublish() @@ -546,13 +552,14 @@ void MqttPacket::handlePublish()
546 } 552 }
547 } 553 }
548 554
  555 + payloadLen = remainingAfterPos();
  556 + payloadStart = pos;
  557 +
549 if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) 558 if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success)
550 { 559 {
551 if (retain) 560 if (retain)
552 { 561 {
553 - size_t payload_length = remainingAfterPos();  
554 - std::string payload(readBytes(payload_length), payload_length);  
555 - 562 + std::string payload(readBytes(payloadLen), payloadLen);
556 sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, *subtopics, payload, qos); 563 sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, *subtopics, payload, qos);
557 } 564 }
558 565
@@ -679,6 +686,23 @@ size_t MqttPacket::getTotalMemoryFootprint() @@ -679,6 +686,23 @@ size_t MqttPacket::getTotalMemoryFootprint()
679 return bites.size() + sizeof(MqttPacket); 686 return bites.size() + sizeof(MqttPacket);
680 } 687 }
681 688
  689 +/**
  690 + * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string.
  691 + * @return
  692 + *
  693 + * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out
  694 + * the whole byte array that is a packet to subscribers. No need to copy and such.
  695 + *
  696 + * I created it for saving QoS packages in the db file.
  697 + */
  698 +std::string MqttPacket::getPayloadCopy()
  699 +{
  700 + assert(payloadStart > 0);
  701 + assert(pos <= bites.size());
  702 + std::string payload(&bites[payloadStart], payloadLen);
  703 + return payload;
  704 +}
  705 +
682 size_t MqttPacket::getSizeIncludingNonPresentHeader() const 706 size_t MqttPacket::getSizeIncludingNonPresentHeader() const
683 { 707 {
684 size_t total = bites.size(); 708 size_t total = bites.size();
mqttpacket.h
@@ -59,6 +59,8 @@ class MqttPacket @@ -59,6 +59,8 @@ class MqttPacket
59 size_t packet_id_pos = 0; 59 size_t packet_id_pos = 0;
60 uint16_t packet_id = 0; 60 uint16_t packet_id = 0;
61 ProtocolVersion protocolVersion = ProtocolVersion::None; 61 ProtocolVersion protocolVersion = ProtocolVersion::None;
  62 + size_t payloadStart = 0;
  63 + size_t payloadLen = 0;
62 Logger *logger = Logger::getInstance(); 64 Logger *logger = Logger::getInstance();
63 65
64 char *readBytes(size_t length); 66 char *readBytes(size_t length);
@@ -116,6 +118,7 @@ public: @@ -116,6 +118,7 @@ public:
116 uint16_t getPacketId() const; 118 uint16_t getPacketId() const;
117 void setDuplicate(); 119 void setDuplicate();
118 size_t getTotalMemoryFootprint(); 120 size_t getTotalMemoryFootprint();
  121 + std::string getPayloadCopy();
119 }; 122 };
120 123
121 #endif // MQTTPACKET_H 124 #endif // MQTTPACKET_H
persistencefile.cpp 0 โ†’ 100644
  1 +/*
  2 +This file is part of FlashMQ (https://www.flashmq.org)
  3 +Copyright (C) 2021 Wiebe Cazemier
  4 +
  5 +FlashMQ is free software: you can redistribute it and/or modify
  6 +it under the terms of the GNU Affero General Public License as
  7 +published by the Free Software Foundation, version 3.
  8 +
  9 +FlashMQ is distributed in the hope that it will be useful,
  10 +but WITHOUT ANY WARRANTY; without even the implied warranty of
  11 +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12 +GNU Affero General Public License for more details.
  13 +
  14 +You should have received a copy of the GNU Affero General Public
  15 +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
  16 +*/
  17 +
  18 +#include "persistencefile.h"
  19 +
  20 +#include <sys/types.h>
  21 +#include <sys/stat.h>
  22 +#include <fcntl.h>
  23 +#include <unistd.h>
  24 +#include <exception>
  25 +#include <stdexcept>
  26 +#include <stdio.h>
  27 +#include <cstring>
  28 +
  29 +#include "utils.h"
  30 +#include "logger.h"
  31 +
  32 +PersistenceFile::PersistenceFile(const std::string &filePath) :
  33 + digestContext(EVP_MD_CTX_new()),
  34 + buf(1024*1024)
  35 +{
  36 + if (!filePath.empty() && filePath[filePath.size() - 1] == '/')
  37 + throw std::runtime_error("Target file can't contain trailing slash.");
  38 +
  39 + this->filePath = filePath;
  40 + this->filePathTemp = formatString("%s.newfile.%s", filePath.c_str(), getSecureRandomString(8).c_str());
  41 + this->filePathCorrupt = formatString("%s.corrupt.%s", filePath.c_str(), getSecureRandomString(8).c_str());
  42 +}
  43 +
  44 +PersistenceFile::~PersistenceFile()
  45 +{
  46 + closeFile();
  47 +
  48 + if (digestContext)
  49 + {
  50 + EVP_MD_CTX_free(digestContext);
  51 + }
  52 +}
  53 +
  54 +/**
  55 + * @brief RetainedMessagesDB::hashFile hashes the data after the headers and writes the hash in the header. Uses SHA512.
  56 + *
  57 + */
  58 +void PersistenceFile::writeCheck(const void *ptr, size_t size, size_t n, FILE *s)
  59 +{
  60 + if (fwrite(ptr, size, n, s) != n)
  61 + {
  62 + throw std::runtime_error(formatString("Error writing: %s", strerror(errno)));
  63 + }
  64 +}
  65 +
  66 +ssize_t PersistenceFile::readCheck(void *ptr, size_t size, size_t n, FILE *stream)
  67 +{
  68 + size_t nread = fread(ptr, size, n, stream);
  69 +
  70 + if (nread != n)
  71 + {
  72 + if (feof(f))
  73 + return -1;
  74 +
  75 + throw std::runtime_error(formatString("Error reading: %s", strerror(errno)));
  76 + }
  77 +
  78 + return nread;
  79 +}
  80 +
  81 +void PersistenceFile::hashFile()
  82 +{
  83 + logger->logf(LOG_DEBUG, "Calculating and saving hash of '%s'.", filePath.c_str());
  84 +
  85 + fseek(f, TOTAL_HEADER_SIZE, SEEK_SET);
  86 +
  87 + unsigned int output_len = 0;
  88 + unsigned char md_value[EVP_MAX_MD_SIZE];
  89 + std::memset(md_value, 0, EVP_MAX_MD_SIZE);
  90 +
  91 + EVP_MD_CTX_reset(digestContext);
  92 + EVP_DigestInit_ex(digestContext, sha512, NULL);
  93 +
  94 + while (!feof(f))
  95 + {
  96 + size_t n = fread(buf.data(), 1, buf.size(), f);
  97 + EVP_DigestUpdate(digestContext, buf.data(), n);
  98 + }
  99 +
  100 + EVP_DigestFinal_ex(digestContext, md_value, &output_len);
  101 +
  102 + if (output_len != HASH_SIZE)
  103 + throw std::runtime_error("Impossible: calculated hash size wrong length");
  104 +
  105 + fseek(f, MAGIC_STRING_LENGH, SEEK_SET);
  106 +
  107 + writeCheck(md_value, output_len, 1, f);
  108 + fflush(f);
  109 +}
  110 +
  111 +void PersistenceFile::verifyHash()
  112 +{
  113 + fseek(f, 0, SEEK_END);
  114 + const size_t size = ftell(f);
  115 +
  116 + if (size < TOTAL_HEADER_SIZE)
  117 + throw std::runtime_error(formatString("File '%s' is too small for it even to contain a header.", filePath.c_str()));
  118 +
  119 + unsigned char md_from_disk[HASH_SIZE];
  120 + std::memset(md_from_disk, 0, HASH_SIZE);
  121 +
  122 + fseek(f, MAGIC_STRING_LENGH, SEEK_SET);
  123 + fread(md_from_disk, 1, HASH_SIZE, f);
  124 +
  125 + unsigned int output_len = 0;
  126 + unsigned char md_value[EVP_MAX_MD_SIZE];
  127 + std::memset(md_value, 0, EVP_MAX_MD_SIZE);
  128 +
  129 + EVP_MD_CTX_reset(digestContext);
  130 + EVP_DigestInit_ex(digestContext, sha512, NULL);
  131 +
  132 + while (!feof(f))
  133 + {
  134 + size_t n = fread(buf.data(), 1, buf.size(), f);
  135 + EVP_DigestUpdate(digestContext, buf.data(), n);
  136 + }
  137 +
  138 + EVP_DigestFinal_ex(digestContext, md_value, &output_len);
  139 +
  140 + if (output_len != HASH_SIZE)
  141 + throw std::runtime_error("Impossible: calculated hash size wrong length");
  142 +
  143 + if (std::memcmp(md_from_disk, md_value, output_len) != 0)
  144 + {
  145 + fclose(f);
  146 + f = nullptr;
  147 +
  148 + if (rename(filePath.c_str(), filePathCorrupt.c_str()) == 0)
  149 + {
  150 + throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Moved aside to '%s'.", filePath.c_str(), filePathCorrupt.c_str()));
  151 + }
  152 + else
  153 + {
  154 + throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Tried to move aside, but that failed: '%s'.",
  155 + filePath.c_str(), strerror(errno)));
  156 + }
  157 + }
  158 +
  159 + logger->logf(LOG_DEBUG, "Hash of '%s' correct", filePath.c_str());
  160 +}
  161 +
  162 +/**
  163 + * @brief PersistenceFile::makeSureBufSize grows the buffer if n is bigger.
  164 + * @param n in bytes.
  165 + *
  166 + * Remember that when you're dealing with fields that are sized in MQTT by 16 bit ints, like topic paths, the buffer will always be big enough, because it's 1 MB.
  167 + */
  168 +void PersistenceFile::makeSureBufSize(size_t n)
  169 +{
  170 + if (n > buf.size())
  171 + buf.resize(n);
  172 +}
  173 +
  174 +void PersistenceFile::writeInt64(const int64_t val)
  175 +{
  176 + unsigned char buf[8];
  177 +
  178 + // Write big-endian
  179 + int shift = 56;
  180 + int i = 0;
  181 + while (shift >= 0)
  182 + {
  183 + unsigned char wantedByte = val >> shift;
  184 + buf[i++] = wantedByte;
  185 + shift -= 8;
  186 + }
  187 + writeCheck(buf, 1, 8, f);
  188 +}
  189 +
  190 +void PersistenceFile::writeUint32(const uint32_t val)
  191 +{
  192 + unsigned char buf[4];
  193 +
  194 + // Write big-endian
  195 + int shift = 24;
  196 + int i = 0;
  197 + while (shift >= 0)
  198 + {
  199 + unsigned char wantedByte = val >> shift;
  200 + buf[i++] = wantedByte;
  201 + shift -= 8;
  202 + }
  203 + writeCheck(buf, 1, 4, f);
  204 +}
  205 +
  206 +void PersistenceFile::writeUint16(const uint16_t val)
  207 +{
  208 + unsigned char buf[2];
  209 +
  210 + // Write big-endian
  211 + int shift = 8;
  212 + int i = 0;
  213 + while (shift >= 0)
  214 + {
  215 + unsigned char wantedByte = val >> shift;
  216 + buf[i++] = wantedByte;
  217 + shift -= 8;
  218 + }
  219 + writeCheck(buf, 1, 2, f);
  220 +}
  221 +
  222 +int64_t PersistenceFile::readInt64(bool &eofFound)
  223 +{
  224 + if (readCheck(buf.data(), 1, 8, f) < 0)
  225 + eofFound = true;
  226 +
  227 + unsigned char *buf_ = reinterpret_cast<unsigned char *>(buf.data());
  228 + const uint64_t val1 = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]);
  229 + const uint64_t val2 = ((buf_[4]) << 24) | ((buf_[5]) << 16) | ((buf_[6]) << 8) | (buf_[7]);
  230 + const int64_t val = (val1 << 32) | val2;
  231 + return val;
  232 +}
  233 +
  234 +uint32_t PersistenceFile::readUint32(bool &eofFound)
  235 +{
  236 + if (readCheck(buf.data(), 1, 4, f) < 0)
  237 + eofFound = true;
  238 +
  239 + uint32_t val;
  240 + unsigned char *buf_ = reinterpret_cast<unsigned char *>(buf.data());
  241 + val = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]);
  242 + return val;
  243 +}
  244 +
  245 +uint16_t PersistenceFile::readUint16(bool &eofFound)
  246 +{
  247 + if (readCheck(buf.data(), 1, 2, f) < 0)
  248 + eofFound = true;
  249 +
  250 + uint16_t val;
  251 + unsigned char *buf_ = reinterpret_cast<unsigned char *>(buf.data());
  252 + val = ((buf_[0]) << 8) | (buf_[1]);
  253 + return val;
  254 +}
  255 +
  256 +/**
  257 + * @brief RetainedMessagesDB::openWrite doesn't explicitely name a file version (v1, etc), because we always write the current definition.
  258 + */
  259 +void PersistenceFile::openWrite(const std::string &versionString)
  260 +{
  261 + if (openMode != FileMode::unknown)
  262 + throw std::runtime_error("File is already open.");
  263 +
  264 + f = fopen(filePathTemp.c_str(), "w+b");
  265 +
  266 + if (f == nullptr)
  267 + {
  268 + throw std::runtime_error(formatString("Can't open '%s': %s", filePathTemp.c_str(), strerror(errno)));
  269 + }
  270 +
  271 + openMode = FileMode::write;
  272 +
  273 + writeCheck(buf.data(), 1, MAGIC_STRING_LENGH, f);
  274 + rewind(f);
  275 + writeCheck(versionString.c_str(), 1, versionString.length(), f);
  276 + fseek(f, MAGIC_STRING_LENGH, SEEK_SET);
  277 + writeCheck(buf.data(), 1, HASH_SIZE, f);
  278 +}
  279 +
  280 +void PersistenceFile::openRead()
  281 +{
  282 + if (openMode != FileMode::unknown)
  283 + throw std::runtime_error("File is already open.");
  284 +
  285 + f = fopen(filePath.c_str(), "rb");
  286 +
  287 + if (f == nullptr)
  288 + throw PersistenceFileCantBeOpened(formatString("Can't open '%s': %s.", filePath.c_str(), strerror(errno)).c_str());
  289 +
  290 + openMode = FileMode::read;
  291 +
  292 + verifyHash();
  293 + rewind(f);
  294 +
  295 + fread(buf.data(), 1, MAGIC_STRING_LENGH, f);
  296 + detectedVersionString = std::string(buf.data(), strlen(buf.data()));
  297 +
  298 + fseek(f, TOTAL_HEADER_SIZE, SEEK_SET);
  299 +}
  300 +
  301 +void PersistenceFile::closeFile()
  302 +{
  303 + if (!f)
  304 + return;
  305 +
  306 + if (openMode == FileMode::write)
  307 + hashFile();
  308 +
  309 + if (f != nullptr)
  310 + {
  311 + fclose(f);
  312 + f = nullptr;
  313 + }
  314 +
  315 + if (openMode == FileMode::write && !filePathTemp.empty() && ! filePath.empty())
  316 + {
  317 + if (rename(filePathTemp.c_str(), filePath.c_str()) < 0)
  318 + throw std::runtime_error(formatString("Saving '%s' failed: rename of temp file to target failed with: %s", filePath.c_str(), strerror(errno)));
  319 + }
  320 +}
persistencefile.h 0 โ†’ 100644
  1 +/*
  2 +This file is part of FlashMQ (https://www.flashmq.org)
  3 +Copyright (C) 2021 Wiebe Cazemier
  4 +
  5 +FlashMQ is free software: you can redistribute it and/or modify
  6 +it under the terms of the GNU Affero General Public License as
  7 +published by the Free Software Foundation, version 3.
  8 +
  9 +FlashMQ is distributed in the hope that it will be useful,
  10 +but WITHOUT ANY WARRANTY; without even the implied warranty of
  11 +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12 +GNU Affero General Public License for more details.
  13 +
  14 +You should have received a copy of the GNU Affero General Public
  15 +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
  16 +*/
  17 +
  18 +#ifndef PERSISTENCEFILE_H
  19 +#define PERSISTENCEFILE_H
  20 +
  21 +#include <vector>
  22 +#include <list>
  23 +#include <string>
  24 +#include <stdio.h>
  25 +#include <openssl/evp.h>
  26 +#include <stdexcept>
  27 +#include <cstring>
  28 +
  29 +#include "logger.h"
  30 +
  31 +#define MAGIC_STRING_LENGH 32
  32 +#define HASH_SIZE 64
  33 +#define TOTAL_HEADER_SIZE (MAGIC_STRING_LENGH + HASH_SIZE)
  34 +
  35 +/**
  36 + * @brief The PersistenceFileCantBeOpened class should be thrown when a non-fatal file-not-found error happens.
  37 + */
  38 +class PersistenceFileCantBeOpened : public std::runtime_error
  39 +{
  40 +public:
  41 + PersistenceFileCantBeOpened(const std::string &msg) : std::runtime_error(msg) {}
  42 +};
  43 +
  44 +class PersistenceFile
  45 +{
  46 + std::string filePath;
  47 + std::string filePathTemp;
  48 + std::string filePathCorrupt;
  49 +
  50 + EVP_MD_CTX *digestContext = nullptr;
  51 + const EVP_MD *sha512 = EVP_sha512();
  52 +
  53 + void hashFile();
  54 + void verifyHash();
  55 +
  56 +protected:
  57 + enum class FileMode
  58 + {
  59 + unknown,
  60 + read,
  61 + write
  62 + };
  63 +
  64 + FILE *f = nullptr;
  65 + std::vector<char> buf;
  66 + FileMode openMode = FileMode::unknown;
  67 + std::string detectedVersionString;
  68 +
  69 + Logger *logger = Logger::getInstance();
  70 +
  71 + void makeSureBufSize(size_t n);
  72 +
  73 + void writeCheck(const void *__restrict __ptr, size_t __size, size_t __n, FILE *__restrict __s);
  74 + ssize_t readCheck(void *__restrict ptr, size_t size, size_t n, FILE *__restrict stream);
  75 +
  76 + void writeInt64(const int64_t val);
  77 + void writeUint32(const uint32_t val);
  78 + void writeUint16(const uint16_t val);
  79 + int64_t readInt64(bool &eofFound);
  80 + uint32_t readUint32(bool &eofFound);
  81 + uint16_t readUint16(bool &eofFound);
  82 +
  83 +public:
  84 + PersistenceFile(const std::string &filePath);
  85 + ~PersistenceFile();
  86 +
  87 + void openWrite(const std::string &versionString);
  88 + void openRead();
  89 + void closeFile();
  90 +};
  91 +
  92 +#endif // PERSISTENCEFILE_H
qospacketqueue.cpp 0 โ†’ 100644
  1 +#include "qospacketqueue.h"
  2 +
  3 +#include "cassert"
  4 +
  5 +#include "mqttpacket.h"
  6 +
  7 +void QoSPacketQueue::erase(const uint16_t packet_id)
  8 +{
  9 + auto it = queue.begin();
  10 + auto end = queue.end();
  11 + while (it != end)
  12 + {
  13 + std::shared_ptr<MqttPacket> &p = *it;
  14 + if (p->getPacketId() == packet_id)
  15 + {
  16 + size_t mem = p->getTotalMemoryFootprint();
  17 + qosQueueBytes -= mem;
  18 + assert(qosQueueBytes >= 0);
  19 + if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose.
  20 + qosQueueBytes = 0;
  21 +
  22 + queue.erase(it);
  23 +
  24 + break;
  25 + }
  26 +
  27 + it++;
  28 + }
  29 +}
  30 +
  31 +size_t QoSPacketQueue::size() const
  32 +{
  33 + return queue.size();
  34 +}
  35 +
  36 +size_t QoSPacketQueue::getByteSize() const
  37 +{
  38 + return qosQueueBytes;
  39 +}
  40 +
  41 +/**
  42 + * @brief QoSPacketQueue::queuePacket makes a copy of the packet because it has state for the receiver in question.
  43 + * @param p
  44 + * @param id
  45 + * @return the packet copy.
  46 + */
  47 +std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id)
  48 +{
  49 + assert(p.getQos() > 0);
  50 +
  51 + std::shared_ptr<MqttPacket> copyPacket = p.getCopy();
  52 + copyPacket->setPacketId(id);
  53 + queue.push_back(copyPacket);
  54 + qosQueueBytes += copyPacket->getTotalMemoryFootprint();
  55 + return copyPacket;
  56 +}
  57 +
  58 +std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const Publish &pub, uint16_t id)
  59 +{
  60 + assert(pub.qos > 0);
  61 +
  62 + std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(pub));
  63 + copyPacket->setPacketId(id);
  64 + queue.push_back(copyPacket);
  65 + qosQueueBytes += copyPacket->getTotalMemoryFootprint();
  66 + return copyPacket;
  67 +}
  68 +
  69 +std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::begin() const
  70 +{
  71 + return queue.cbegin();
  72 +}
  73 +
  74 +std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::end() const
  75 +{
  76 + return queue.cend();
  77 +}
qospacketqueue.h 0 โ†’ 100644
  1 +#ifndef QOSPACKETQUEUE_H
  2 +#define QOSPACKETQUEUE_H
  3 +
  4 +#include "list"
  5 +
  6 +#include "forward_declarations.h"
  7 +#include "types.h"
  8 +
  9 +class QoSPacketQueue
  10 +{
  11 + std::list<std::shared_ptr<MqttPacket>> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6]
  12 + ssize_t qosQueueBytes = 0;
  13 +
  14 +public:
  15 + void erase(const uint16_t packet_id);
  16 + size_t size() const;
  17 + size_t getByteSize() const;
  18 + std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id);
  19 + std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id);
  20 +
  21 + std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const;
  22 + std::list<std::shared_ptr<MqttPacket>>::const_iterator end() const;
  23 +};
  24 +
  25 +#endif // QOSPACKETQUEUE_H
retainedmessage.cpp
@@ -34,3 +34,8 @@ bool RetainedMessage::empty() const @@ -34,3 +34,8 @@ bool RetainedMessage::empty() const
34 { 34 {
35 return payload.empty(); 35 return payload.empty();
36 } 36 }
  37 +
  38 +uint32_t RetainedMessage::getSize() const
  39 +{
  40 + return topic.length() + payload.length() + 1;
  41 +}
retainedmessage.h
@@ -30,6 +30,7 @@ struct RetainedMessage @@ -30,6 +30,7 @@ struct RetainedMessage
30 30
31 bool operator==(const RetainedMessage &rhs) const; 31 bool operator==(const RetainedMessage &rhs) const;
32 bool empty() const; 32 bool empty() const;
  33 + uint32_t getSize() const;
33 }; 34 };
34 35
35 namespace std { 36 namespace std {
retainedmessagesdb.cpp 0 โ†’ 100644
  1 +/*
  2 +This file is part of FlashMQ (https://www.flashmq.org)
  3 +Copyright (C) 2021 Wiebe Cazemier
  4 +
  5 +FlashMQ is free software: you can redistribute it and/or modify
  6 +it under the terms of the GNU Affero General Public License as
  7 +published by the Free Software Foundation, version 3.
  8 +
  9 +FlashMQ is distributed in the hope that it will be useful,
  10 +but WITHOUT ANY WARRANTY; without even the implied warranty of
  11 +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12 +GNU Affero General Public License for more details.
  13 +
  14 +You should have received a copy of the GNU Affero General Public
  15 +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
  16 +*/
  17 +
  18 +#include <sys/types.h>
  19 +#include <sys/stat.h>
  20 +#include <fcntl.h>
  21 +#include <unistd.h>
  22 +#include <exception>
  23 +#include <stdexcept>
  24 +#include <stdio.h>
  25 +#include <cstring>
  26 +
  27 +#include "retainedmessagesdb.h"
  28 +#include "utils.h"
  29 +#include "logger.h"
  30 +
  31 +RetainedMessagesDB::RetainedMessagesDB(const std::string &filePath) : PersistenceFile(filePath)
  32 +{
  33 +
  34 +}
  35 +
  36 +void RetainedMessagesDB::openWrite()
  37 +{
  38 + PersistenceFile::openWrite(MAGIC_STRING_V1);
  39 +}
  40 +
  41 +void RetainedMessagesDB::openRead()
  42 +{
  43 + PersistenceFile::openRead();
  44 +
  45 + if (detectedVersionString == MAGIC_STRING_V1)
  46 + readVersion = ReadVersion::v1;
  47 + else
  48 + throw std::runtime_error("Unknown file version.");
  49 +}
  50 +
  51 +/**
  52 + * @brief RetainedMessagesDB::writeRowHeader writes two 32 bit integers: topic size and payload size.
  53 + * @param rm
  54 + *
  55 + * So, the header per message is 8 bytes long.
  56 + *
  57 + * It writes no information about the length of the QoS value, because that is always one.
  58 + */
  59 +void RetainedMessagesDB::writeRowHeader(const RetainedMessage &rm)
  60 +{
  61 + writeUint32(rm.topic.size());
  62 + writeUint32(rm.payload.size());
  63 +}
  64 +
  65 +RetainedMessagesDB::RowHeader RetainedMessagesDB::readRowHeaderV1(bool &eofFound)
  66 +{
  67 + RetainedMessagesDB::RowHeader result;
  68 +
  69 + result.topicLen = readUint32(eofFound);
  70 + result.payloadLen = readUint32(eofFound);
  71 +
  72 + return result;
  73 +}
  74 +
  75 +/**
  76 + * @brief RetainedMessagesDB::saveData doesn't explicitely name a file version (v1, etc), because we always write the current definition.
  77 + * @param messages
  78 + */
  79 +void RetainedMessagesDB::saveData(const std::vector<RetainedMessage> &messages)
  80 +{
  81 + if (!f)
  82 + return;
  83 +
  84 + char reserved[RESERVED_SPACE_RETAINED_DB_V1];
  85 + std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V1);
  86 +
  87 + char qos = 0;
  88 + for (const RetainedMessage &rm : messages)
  89 + {
  90 + logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d.", rm.topic.c_str(), rm.qos);
  91 +
  92 + writeRowHeader(rm);
  93 + qos = rm.qos;
  94 + writeCheck(&qos, 1, 1, f);
  95 + writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V1, f);
  96 + writeCheck(rm.topic.c_str(), 1, rm.topic.length(), f);
  97 + writeCheck(rm.payload.c_str(), 1, rm.payload.length(), f);
  98 + }
  99 +
  100 + fflush(f);
  101 +}
  102 +
  103 +std::list<RetainedMessage> RetainedMessagesDB::readData()
  104 +{
  105 + std::list<RetainedMessage> defaultResult;
  106 +
  107 + if (!f)
  108 + return defaultResult;
  109 +
  110 + if (readVersion == ReadVersion::v1)
  111 + return readDataV1();
  112 +
  113 + return defaultResult;
  114 +}
  115 +
  116 +std::list<RetainedMessage> RetainedMessagesDB::readDataV1()
  117 +{
  118 + std::list<RetainedMessage> messages;
  119 +
  120 + while (!feof(f))
  121 + {
  122 + bool eofFound = false;
  123 + RetainedMessagesDB::RowHeader header = readRowHeaderV1(eofFound);
  124 +
  125 + if (eofFound)
  126 + continue;
  127 +
  128 + makeSureBufSize(header.payloadLen);
  129 +
  130 + readCheck(buf.data(), 1, 1, f);
  131 + char qos = buf[0];
  132 + fseek(f, RESERVED_SPACE_RETAINED_DB_V1, SEEK_CUR);
  133 +
  134 + readCheck(buf.data(), 1, header.topicLen, f);
  135 + std::string topic(buf.data(), header.topicLen);
  136 +
  137 + readCheck(buf.data(), 1, header.payloadLen, f);
  138 + std::string payload(buf.data(), header.payloadLen);
  139 +
  140 + RetainedMessage msg(topic, payload, qos);
  141 + logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.topic.c_str(), msg.qos);
  142 + messages.push_back(std::move(msg));
  143 + }
  144 +
  145 + return messages;
  146 +}
retainedmessagesdb.h 0 โ†’ 100644
  1 +/*
  2 +This file is part of FlashMQ (https://www.flashmq.org)
  3 +Copyright (C) 2021 Wiebe Cazemier
  4 +
  5 +FlashMQ is free software: you can redistribute it and/or modify
  6 +it under the terms of the GNU Affero General Public License as
  7 +published by the Free Software Foundation, version 3.
  8 +
  9 +FlashMQ is distributed in the hope that it will be useful,
  10 +but WITHOUT ANY WARRANTY; without even the implied warranty of
  11 +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12 +GNU Affero General Public License for more details.
  13 +
  14 +You should have received a copy of the GNU Affero General Public
  15 +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
  16 +*/
  17 +
  18 +#ifndef RETAINEDMESSAGESDB_H
  19 +#define RETAINEDMESSAGESDB_H
  20 +
  21 +#include "persistencefile.h"
  22 +#include "retainedmessage.h"
  23 +
  24 +#include "logger.h"
  25 +
  26 +#define MAGIC_STRING_V1 "FlashMQRetainedDBv1"
  27 +#define ROW_HEADER_SIZE 8
  28 +#define RESERVED_SPACE_RETAINED_DB_V1 31
  29 +
  30 +/**
  31 + * @brief The RetainedMessagesDB class saves and loads the retained messages.
  32 + *
  33 + * The DB looks like, from the top:
  34 + *
  35 + * MAGIC_STRING_LENGH bytes file header
  36 + * HASH_SIZE SHA512
  37 + * [MESSAGES]
  38 + *
  39 + * Each message has a row header, which is 8 bytes. See writeRowHeader().
  40 + *
  41 + */
  42 +class RetainedMessagesDB : public PersistenceFile
  43 +{
  44 + enum class ReadVersion
  45 + {
  46 + unknown,
  47 + v1
  48 + };
  49 +
  50 + struct RowHeader
  51 + {
  52 + uint32_t topicLen = 0;
  53 + uint32_t payloadLen = 0;
  54 + };
  55 +
  56 + ReadVersion readVersion = ReadVersion::unknown;
  57 +
  58 + void writeRowHeader(const RetainedMessage &rm);
  59 + RowHeader readRowHeaderV1(bool &eofFound);
  60 + std::list<RetainedMessage> readDataV1();
  61 +public:
  62 + RetainedMessagesDB(const std::string &filePath);
  63 +
  64 + void openWrite();
  65 + void openRead();
  66 +
  67 + void saveData(const std::vector<RetainedMessage> &messages);
  68 + std::list<RetainedMessage> readData();
  69 +};
  70 +
  71 +#endif // RETAINEDMESSAGESDB_H
rwlockguard.cpp
@@ -32,7 +32,15 @@ RWLockGuard::~RWLockGuard() @@ -32,7 +32,15 @@ RWLockGuard::~RWLockGuard()
32 32
33 void RWLockGuard::wrlock() 33 void RWLockGuard::wrlock()
34 { 34 {
35 - if (pthread_rwlock_wrlock(rwlock) != 0) 35 + const int rc = pthread_rwlock_wrlock(rwlock);
  36 +
  37 + if (rc == EDEADLK)
  38 + {
  39 + rwlock = nullptr;
  40 + return;
  41 + }
  42 +
  43 + if (rc != 0)
36 throw std::runtime_error("wrlock failed."); 44 throw std::runtime_error("wrlock failed.");
37 } 45 }
38 46
session.cpp
@@ -20,16 +20,75 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -20,16 +20,75 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
20 #include "session.h" 20 #include "session.h"
21 #include "client.h" 21 #include "client.h"
22 22
  23 +std::chrono::time_point<std::chrono::steady_clock> appStartTime = std::chrono::steady_clock::now();
  24 +
23 Session::Session() 25 Session::Session()
24 { 26 {
25 27
26 } 28 }
27 29
  30 +int64_t Session::getProgramStartedAtUnixTimestamp()
  31 +{
  32 + auto secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count();
  33 + const std::chrono::seconds age = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - appStartTime);
  34 + int64_t result = secondsSinceEpoch - age.count();
  35 + return result;
  36 +}
  37 +
  38 +void Session::setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp)
  39 +{
  40 + auto secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch());
  41 + const std::chrono::seconds _unix_timestamp = std::chrono::seconds(unix_timestamp);
  42 + const std::chrono::seconds age_in_s = secondsSinceEpoch - _unix_timestamp;
  43 + appStartTime = std::chrono::steady_clock::now() - age_in_s;
  44 +}
  45 +
  46 +
  47 +int64_t Session::getSessionRelativeAgeInMs() const
  48 +{
  49 + const std::chrono::milliseconds sessionAge = std::chrono::duration_cast<std::chrono::milliseconds>(lastTouched - appStartTime);
  50 + const int64_t sInMs = sessionAge.count();
  51 + return sInMs;
  52 +}
  53 +
  54 +void Session::setSessionTouch(int64_t ageInMs)
  55 +{
  56 + std::chrono::milliseconds ms(ageInMs);
  57 + std::chrono::time_point<std::chrono::steady_clock> point = appStartTime + ms;
  58 + lastTouched = point;
  59 +}
  60 +
  61 +/**
  62 + * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies.
  63 + * @param other
  64 + *
  65 + * Because it was created for session storing, the fields we're copying are the fields being stored.
  66 + */
  67 +Session::Session(const Session &other)
  68 +{
  69 + this->username = other.username;
  70 + this->client_id = other.client_id;
  71 + this->incomingQoS2MessageIds = other.incomingQoS2MessageIds;
  72 + this->outgoingQoS2MessageIds = other.outgoingQoS2MessageIds;
  73 + this->nextPacketId = other.nextPacketId;
  74 + this->lastTouched = other.lastTouched;
  75 +
  76 + // To be fully correct, we should copy the individual packets, but copying sessions is only done for saving them, and I know
  77 + // that no member of MqttPacket changes in the QoS process, so we can just keep the shared pointer to the original.
  78 + this->qosPacketQueue = other.qosPacketQueue;
  79 +}
  80 +
28 Session::~Session() 81 Session::~Session()
29 { 82 {
30 logger->logf(LOG_DEBUG, "Session %s is being destroyed.", getClientId().c_str()); 83 logger->logf(LOG_DEBUG, "Session %s is being destroyed.", getClientId().c_str());
31 } 84 }
32 85
  86 +std::unique_ptr<Session> Session::getCopy() const
  87 +{
  88 + std::unique_ptr<Session> s(new Session(*this));
  89 + return s;
  90 +}
  91 +
33 bool Session::clientDisconnected() const 92 bool Session::clientDisconnected() const
34 { 93 {
35 return client.expired(); 94 return client.expired();
@@ -45,7 +104,6 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client) @@ -45,7 +104,6 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
45 this->client = client; 104 this->client = client;
46 this->client_id = client->getClientId(); 105 this->client_id = client->getClientId();
47 this->username = client->getUsername(); 106 this->username = client->getUsername();
48 - this->thread = client->getThreadData();  
49 } 107 }
50 108
51 /** 109 /**
@@ -60,7 +118,9 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u @@ -60,7 +118,9 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u
60 assert(max_qos <= 2); 118 assert(max_qos <= 2);
61 const char qos = std::min<char>(packet.getQos(), max_qos); 119 const char qos = std::min<char>(packet.getQos(), max_qos);
62 120
63 - if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) 121 + assert(packet.getSender());
  122 + Authentication &auth = packet.getSender()->getThreadData()->authentication;
  123 + if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success)
64 { 124 {
65 if (qos == 0) 125 if (qos == 0)
66 { 126 {
@@ -73,11 +133,10 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u @@ -73,11 +133,10 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u
73 } 133 }
74 else if (qos > 0) 134 else if (qos > 0)
75 { 135 {
76 - std::shared_ptr<MqttPacket> copyPacket = packet.getCopy();  
77 std::unique_lock<std::mutex> locker(qosQueueMutex); 136 std::unique_lock<std::mutex> locker(qosQueueMutex);
78 137
79 const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); 138 const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();
80 - if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) 139 + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
81 { 140 {
82 logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); 141 logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str());
83 return; 142 return;
@@ -86,13 +145,7 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u @@ -86,13 +145,7 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u
86 if (nextPacketId == 0) 145 if (nextPacketId == 0)
87 nextPacketId++; 146 nextPacketId++;
88 147
89 - const uint16_t pid = nextPacketId;  
90 - copyPacket->setPacketId(pid);  
91 - QueuedQosPacket p;  
92 - p.packet = copyPacket;  
93 - p.id = pid;  
94 - qosPacketQueue.push_back(p);  
95 - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); 148 + std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId);
96 locker.unlock(); 149 locker.unlock();
97 150
98 if (!clientDisconnected()) 151 if (!clientDisconnected())
@@ -115,27 +168,7 @@ void Session::clearQosMessage(uint16_t packet_id) @@ -115,27 +168,7 @@ void Session::clearQosMessage(uint16_t packet_id)
115 #endif 168 #endif
116 169
117 std::lock_guard<std::mutex> locker(qosQueueMutex); 170 std::lock_guard<std::mutex> locker(qosQueueMutex);
118 -  
119 - auto it = qosPacketQueue.begin();  
120 - auto end = qosPacketQueue.end();  
121 - while (it != end)  
122 - {  
123 - QueuedQosPacket &p = *it;  
124 - if (p.id == packet_id)  
125 - {  
126 - size_t mem = p.packet->getTotalMemoryFootprint();  
127 - qosQueueBytes -= mem;  
128 - assert(qosQueueBytes >= 0);  
129 - if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose.  
130 - qosQueueBytes = 0;  
131 -  
132 - qosPacketQueue.erase(it);  
133 -  
134 - break;  
135 - }  
136 -  
137 - it++;  
138 - } 171 + qosPacketQueue.erase(packet_id);
139 } 172 }
140 173
141 // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any 174 // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any
@@ -152,10 +185,10 @@ uint64_t Session::sendPendingQosMessages() @@ -152,10 +185,10 @@ uint64_t Session::sendPendingQosMessages()
152 { 185 {
153 std::shared_ptr<Client> c = makeSharedClient(); 186 std::shared_ptr<Client> c = makeSharedClient();
154 std::lock_guard<std::mutex> locker(qosQueueMutex); 187 std::lock_guard<std::mutex> locker(qosQueueMutex);
155 - for (QueuedQosPacket &qosMessage : qosPacketQueue) 188 + for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue)
156 { 189 {
157 - c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos());  
158 - qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate. 190 + c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos());
  191 + qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
159 count++; 192 count++;
160 } 193 }
161 194
session.h
@@ -25,38 +25,46 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -25,38 +25,46 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
25 25
26 #include "forward_declarations.h" 26 #include "forward_declarations.h"
27 #include "logger.h" 27 #include "logger.h"
  28 +#include "sessionsandsubscriptionsdb.h"
  29 +#include "qospacketqueue.h"
28 30
29 // TODO make settings. But, num of packets can't exceed 65536, because the counter is 16 bit. 31 // TODO make settings. But, num of packets can't exceed 65536, because the counter is 16 bit.
30 #define MAX_QOS_MSG_PENDING_PER_CLIENT 32 32 #define MAX_QOS_MSG_PENDING_PER_CLIENT 32
31 #define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096 33 #define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096
32 34
33 -struct QueuedQosPacket  
34 -{  
35 - uint16_t id;  
36 - std::shared_ptr<MqttPacket> packet;  
37 -};  
38 -  
39 class Session 35 class Session
40 { 36 {
  37 +#ifdef TESTING
  38 + friend class MainTests;
  39 +#endif
  40 +
  41 + friend class SessionsAndSubscriptionsDB;
  42 +
41 std::weak_ptr<Client> client; 43 std::weak_ptr<Client> client;
42 - std::shared_ptr<ThreadData> thread;  
43 std::string client_id; 44 std::string client_id;
44 std::string username; 45 std::string username;
45 - std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] 46 + QoSPacketQueue qosPacketQueue;
46 std::set<uint16_t> incomingQoS2MessageIds; 47 std::set<uint16_t> incomingQoS2MessageIds;
47 std::set<uint16_t> outgoingQoS2MessageIds; 48 std::set<uint16_t> outgoingQoS2MessageIds;
48 std::mutex qosQueueMutex; 49 std::mutex qosQueueMutex;
49 uint16_t nextPacketId = 0; 50 uint16_t nextPacketId = 0;
50 - ssize_t qosQueueBytes = 0;  
51 - std::chrono::time_point<std::chrono::steady_clock> lastTouched; 51 + std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now();
52 Logger *logger = Logger::getInstance(); 52 Logger *logger = Logger::getInstance();
  53 + int64_t getSessionRelativeAgeInMs() const;
  54 + void setSessionTouch(int64_t ageInMs);
53 55
  56 + Session(const Session &other);
54 public: 57 public:
55 Session(); 58 Session();
56 - Session(const Session &other) = delete; 59 +
57 Session(Session &&other) = delete; 60 Session(Session &&other) = delete;
58 ~Session(); 61 ~Session();
59 62
  63 + static int64_t getProgramStartedAtUnixTimestamp();
  64 + static void setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp);
  65 +
  66 + std::unique_ptr<Session> getCopy() const;
  67 +
60 const std::string &getClientId() const { return client_id; } 68 const std::string &getClientId() const { return client_id; }
61 bool clientDisconnected() const; 69 bool clientDisconnected() const;
62 std::shared_ptr<Client> makeSharedClient() const; 70 std::shared_ptr<Client> makeSharedClient() const;
@@ -74,7 +82,6 @@ public: @@ -74,7 +82,6 @@ public:
74 82
75 void addOutgoingQoS2MessageId(uint16_t packet_id); 83 void addOutgoingQoS2MessageId(uint16_t packet_id);
76 void removeOutgoingQoS2MessageId(u_int16_t packet_id); 84 void removeOutgoingQoS2MessageId(u_int16_t packet_id);
77 -  
78 }; 85 };
79 86
80 #endif // SESSION_H 87 #endif // SESSION_H
sessionsandsubscriptionsdb.cpp 0 โ†’ 100644
  1 +/*
  2 +This file is part of FlashMQ (https://www.flashmq.org)
  3 +Copyright (C) 2021 Wiebe Cazemier
  4 +
  5 +FlashMQ is free software: you can redistribute it and/or modify
  6 +it under the terms of the GNU Affero General Public License as
  7 +published by the Free Software Foundation, version 3.
  8 +
  9 +FlashMQ is distributed in the hope that it will be useful,
  10 +but WITHOUT ANY WARRANTY; without even the implied warranty of
  11 +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12 +GNU Affero General Public License for more details.
  13 +
  14 +You should have received a copy of the GNU Affero General Public
  15 +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
  16 +*/
  17 +
  18 +#include "sessionsandsubscriptionsdb.h"
  19 +#include "mqttpacket.h"
  20 +
  21 +#include "cassert"
  22 +
  23 +SubscriptionForSerializing::SubscriptionForSerializing(const std::string &clientId, char qos) :
  24 + clientId(clientId),
  25 + qos(qos)
  26 +{
  27 +
  28 +}
  29 +
  30 +SubscriptionForSerializing::SubscriptionForSerializing(const std::string &&clientId, char qos) :
  31 + clientId(clientId),
  32 + qos(qos)
  33 +{
  34 +
  35 +}
  36 +
  37 +SessionsAndSubscriptionsDB::SessionsAndSubscriptionsDB(const std::string &filePath) : PersistenceFile(filePath)
  38 +{
  39 +
  40 +}
  41 +
  42 +void SessionsAndSubscriptionsDB::openWrite()
  43 +{
  44 + PersistenceFile::openWrite(MAGIC_STRING_SESSION_FILE_V1);
  45 +}
  46 +
  47 +void SessionsAndSubscriptionsDB::openRead()
  48 +{
  49 + PersistenceFile::openRead();
  50 +
  51 + if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V1)
  52 + readVersion = ReadVersion::v1;
  53 + else
  54 + throw std::runtime_error("Unknown file version.");
  55 +}
  56 +
  57 +SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1()
  58 +{
  59 + SessionsAndSubscriptionsResult result;
  60 +
  61 + while (!feof(f))
  62 + {
  63 + bool eofFound = false;
  64 +
  65 + const int64_t programStartAge = readInt64(eofFound);
  66 + if (eofFound)
  67 + continue;
  68 +
  69 + logger->logf(LOG_DEBUG, "Setting first app start time to timestamp %ld", programStartAge);
  70 + Session::setProgramStartedAtUnixTimestamp(programStartAge);
  71 +
  72 + const uint32_t nrOfSessions = readUint32(eofFound);
  73 +
  74 + if (eofFound)
  75 + continue;
  76 +
  77 + std::vector<char> reserved(RESERVED_SPACE_SESSIONS_DB_V1);
  78 +
  79 + for (uint32_t i = 0; i < nrOfSessions; i++)
  80 + {
  81 + readCheck(buf.data(), 1, RESERVED_SPACE_SESSIONS_DB_V1, f);
  82 +
  83 + uint32_t usernameLength = readUint32(eofFound);
  84 + readCheck(buf.data(), 1, usernameLength, f);
  85 + std::string username(buf.data(), usernameLength);
  86 +
  87 + uint32_t clientIdLength = readUint32(eofFound);
  88 + readCheck(buf.data(), 1, clientIdLength, f);
  89 + std::string clientId(buf.data(), clientIdLength);
  90 +
  91 + std::shared_ptr<Session> ses(new Session());
  92 + result.sessions.push_back(ses);
  93 + ses->username = username;
  94 + ses->client_id = clientId;
  95 +
  96 + logger->logf(LOG_DEBUG, "Loading session '%s'.", ses->getClientId().c_str());
  97 +
  98 + const uint32_t nrOfQueuedQoSPackets = readUint32(eofFound);
  99 + for (uint32_t i = 0; i < nrOfQueuedQoSPackets; i++)
  100 + {
  101 + const uint16_t id = readUint16(eofFound);
  102 + const uint32_t topicSize = readUint32(eofFound);
  103 + const uint32_t payloadSize = readUint32(eofFound);
  104 +
  105 + assert(id > 0);
  106 +
  107 + readCheck(buf.data(), 1, 1, f);
  108 + const unsigned char qos = buf[0];
  109 +
  110 + readCheck(buf.data(), 1, topicSize, f);
  111 + const std::string topic(buf.data(), topicSize);
  112 +
  113 + makeSureBufSize(payloadSize);
  114 + readCheck(buf.data(), 1, payloadSize, f);
  115 + const std::string payload(buf.data(), payloadSize);
  116 +
  117 + Publish pub(topic, payload, qos);
  118 + logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
  119 + ses->qosPacketQueue.queuePacket(pub, id);
  120 + }
  121 +
  122 + const uint32_t nrOfIncomingPacketIds = readUint32(eofFound);
  123 + for (uint32_t i = 0; i < nrOfIncomingPacketIds; i++)
  124 + {
  125 + uint16_t id = readUint16(eofFound);
  126 + assert(id > 0);
  127 + logger->logf(LOG_DEBUG, "Loaded incomming QoS2 message id %d.", id);
  128 + ses->incomingQoS2MessageIds.insert(id);
  129 + }
  130 +
  131 + const uint32_t nrOfOutgoingPacketIds = readUint32(eofFound);
  132 + for (uint32_t i = 0; i < nrOfOutgoingPacketIds; i++)
  133 + {
  134 + uint16_t id = readUint16(eofFound);
  135 + assert(id > 0);
  136 + logger->logf(LOG_DEBUG, "Loaded outgoing QoS2 message id %d.", id);
  137 + ses->outgoingQoS2MessageIds.insert(id);
  138 + }
  139 +
  140 + const uint16_t nextPacketId = readUint16(eofFound);
  141 + logger->logf(LOG_DEBUG, "Loaded next packetid %d.", ses->nextPacketId);
  142 + ses->nextPacketId = nextPacketId;
  143 +
  144 + int64_t sessionAge = readInt64(eofFound);
  145 + logger->logf(LOG_DEBUG, "Loaded session age: %ld ms.", sessionAge);
  146 + ses->setSessionTouch(sessionAge);
  147 + }
  148 +
  149 + const uint32_t nrOfSubscriptions = readUint32(eofFound);
  150 + for (uint32_t i = 0; i < nrOfSubscriptions; i++)
  151 + {
  152 + const uint32_t topicLength = readUint32(eofFound);
  153 + readCheck(buf.data(), 1, topicLength, f);
  154 + const std::string topic(buf.data(), topicLength);
  155 +
  156 + logger->logf(LOG_DEBUG, "Loading subscriptions to topic '%s'.", topic.c_str());
  157 +
  158 + const uint32_t nrOfClientIds = readUint32(eofFound);
  159 +
  160 + for (uint32_t i = 0; i < nrOfClientIds; i++)
  161 + {
  162 + const uint32_t clientIdLength = readUint32(eofFound);
  163 + readCheck(buf.data(), 1, clientIdLength, f);
  164 + const std::string clientId(buf.data(), clientIdLength);
  165 +
  166 + char qos;
  167 + readCheck(&qos, 1, 1, f);
  168 +
  169 + logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", clientId.c_str(), topic.c_str(), qos);
  170 +
  171 + SubscriptionForSerializing sub(std::move(clientId), qos);
  172 + result.subscriptions[topic].push_back(std::move(sub));
  173 + }
  174 +
  175 + }
  176 + }
  177 +
  178 + return result;
  179 +}
  180 +
  181 +void SessionsAndSubscriptionsDB::writeRowHeader()
  182 +{
  183 +
  184 +}
  185 +
  186 +void SessionsAndSubscriptionsDB::saveData(const std::list<std::unique_ptr<Session>> &sessions, const std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &subscriptions)
  187 +{
  188 + if (!f)
  189 + return;
  190 +
  191 + char reserved[RESERVED_SPACE_SESSIONS_DB_V1];
  192 + std::memset(reserved, 0, RESERVED_SPACE_SESSIONS_DB_V1);
  193 +
  194 + const int64_t start_stamp = Session::getProgramStartedAtUnixTimestamp();
  195 + logger->logf(LOG_DEBUG, "Saving program first start time stamp as %ld", start_stamp);
  196 + writeInt64(start_stamp);
  197 +
  198 + writeUint32(sessions.size());
  199 +
  200 + for (const std::unique_ptr<Session> &ses : sessions)
  201 + {
  202 + logger->logf(LOG_DEBUG, "Saving session '%s'.", ses->getClientId().c_str());
  203 +
  204 + writeRowHeader();
  205 +
  206 + writeCheck(reserved, 1, RESERVED_SPACE_SESSIONS_DB_V1, f);
  207 +
  208 + writeUint32(ses->username.length());
  209 + writeCheck(ses->username.c_str(), 1, ses->username.length(), f);
  210 +
  211 + writeUint32(ses->client_id.length());
  212 + writeCheck(ses->client_id.c_str(), 1, ses->client_id.length(), f);
  213 +
  214 + const size_t qosPacketsExpected = ses->qosPacketQueue.size();
  215 + size_t qosPacketsCounted = 0;
  216 + writeUint32(qosPacketsExpected);
  217 +
  218 + for (const std::shared_ptr<MqttPacket> &p: ses->qosPacketQueue)
  219 + {
  220 + logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", p->getQos(), p->getTopic().c_str());
  221 +
  222 + qosPacketsCounted++;
  223 +
  224 + writeUint16(p->getPacketId());
  225 +
  226 + writeUint32(p->getTopic().length());
  227 + std::string payload = p->getPayloadCopy();
  228 + writeUint32(payload.size());
  229 +
  230 + const char qos = p->getQos();
  231 + writeCheck(&qos, 1, 1, f);
  232 +
  233 + writeCheck(p->getTopic().c_str(), 1, p->getTopic().length(), f);
  234 + writeCheck(payload.c_str(), 1, payload.length(), f);
  235 + }
  236 +
  237 + assert(qosPacketsExpected == qosPacketsCounted);
  238 +
  239 + writeUint32(ses->incomingQoS2MessageIds.size());
  240 + for (uint16_t id : ses->incomingQoS2MessageIds)
  241 + {
  242 + logger->logf(LOG_DEBUG, "Writing incomming QoS2 message id %d.", id);
  243 + writeUint16(id);
  244 + }
  245 +
  246 + writeUint32(ses->outgoingQoS2MessageIds.size());
  247 + for (uint16_t id : ses->outgoingQoS2MessageIds)
  248 + {
  249 + logger->logf(LOG_DEBUG, "Writing outgoing QoS2 message id %d.", id);
  250 + writeUint16(id);
  251 + }
  252 +
  253 + logger->logf(LOG_DEBUG, "Writing next packetid %d.", ses->nextPacketId);
  254 + writeUint16(ses->nextPacketId);
  255 +
  256 + const int64_t sInMs = ses->getSessionRelativeAgeInMs();
  257 + logger->logf(LOG_DEBUG, "Writing session age: %ld ms.", sInMs);
  258 + writeInt64(sInMs);
  259 + }
  260 +
  261 + writeUint32(subscriptions.size());
  262 +
  263 + for (auto &pair : subscriptions)
  264 + {
  265 + const std::string &topic = pair.first;
  266 + const std::list<SubscriptionForSerializing> &subscriptions = pair.second;
  267 +
  268 + logger->logf(LOG_DEBUG, "Writing subscriptions to topic '%s'.", topic.c_str());
  269 +
  270 + writeUint32(topic.size());
  271 + writeCheck(topic.c_str(), 1, topic.size(), f);
  272 +
  273 + writeUint32(subscriptions.size());
  274 +
  275 + for (const SubscriptionForSerializing &subscription : subscriptions)
  276 + {
  277 + logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", subscription.clientId.c_str(), topic.c_str(), subscription.qos);
  278 +
  279 + writeUint32(subscription.clientId.size());
  280 + writeCheck(subscription.clientId.c_str(), 1, subscription.clientId.size(), f);
  281 + writeCheck(&subscription.qos, 1, 1, f);
  282 + }
  283 + }
  284 +
  285 + fflush(f);
  286 +}
  287 +
  288 +SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readData()
  289 +{
  290 + SessionsAndSubscriptionsResult defaultResult;
  291 +
  292 + if (!f)
  293 + return defaultResult;
  294 +
  295 + if (readVersion == ReadVersion::v1)
  296 + return readDataV1();
  297 +
  298 + return defaultResult;
  299 +}
sessionsandsubscriptionsdb.h 0 โ†’ 100644
  1 +/*
  2 +This file is part of FlashMQ (https://www.flashmq.org)
  3 +Copyright (C) 2021 Wiebe Cazemier
  4 +
  5 +FlashMQ is free software: you can redistribute it and/or modify
  6 +it under the terms of the GNU Affero General Public License as
  7 +published by the Free Software Foundation, version 3.
  8 +
  9 +FlashMQ is distributed in the hope that it will be useful,
  10 +but WITHOUT ANY WARRANTY; without even the implied warranty of
  11 +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12 +GNU Affero General Public License for more details.
  13 +
  14 +You should have received a copy of the GNU Affero General Public
  15 +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
  16 +*/
  17 +
  18 +#ifndef SESSIONSANDSUBSCRIPTIONSDB_H
  19 +#define SESSIONSANDSUBSCRIPTIONSDB_H
  20 +
  21 +#include <list>
  22 +#include <memory>
  23 +
  24 +#include "persistencefile.h"
  25 +#include "session.h"
  26 +
  27 +#define MAGIC_STRING_SESSION_FILE_V1 "FlashMQRetainedDBv1"
  28 +#define RESERVED_SPACE_SESSIONS_DB_V1 32
  29 +
  30 +/**
  31 + * @brief The SubscriptionForSerializing struct contains the fields we're interested in when saving a subscription.
  32 + */
  33 +struct SubscriptionForSerializing
  34 +{
  35 + const std::string clientId;
  36 + const char qos = 0;
  37 +
  38 + SubscriptionForSerializing(const std::string &clientId, char qos);
  39 + SubscriptionForSerializing(const std::string &&clientId, char qos);
  40 +};
  41 +
  42 +struct SessionsAndSubscriptionsResult
  43 +{
  44 + std::list<std::shared_ptr<Session>> sessions;
  45 + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> subscriptions;
  46 +};
  47 +
  48 +
  49 +class SessionsAndSubscriptionsDB : public PersistenceFile
  50 +{
  51 + enum class ReadVersion
  52 + {
  53 + unknown,
  54 + v1
  55 + };
  56 +
  57 + ReadVersion readVersion = ReadVersion::unknown;
  58 +
  59 + SessionsAndSubscriptionsResult readDataV1();
  60 + void writeRowHeader();
  61 +public:
  62 + SessionsAndSubscriptionsDB(const std::string &filePath);
  63 +
  64 + void openWrite();
  65 + void openRead();
  66 +
  67 + void saveData(const std::list<std::unique_ptr<Session>> &sessions, const std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &subscriptions);
  68 + SessionsAndSubscriptionsResult readData();
  69 +};
  70 +
  71 +#endif // SESSIONSANDSUBSCRIPTIONSDB_H
settings.cpp
@@ -16,7 +16,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -16,7 +16,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
16 */ 16 */
17 17
18 #include "settings.h" 18 #include "settings.h"
19 - 19 +#include "utils.h"
20 20
21 AuthOptCompatWrap &Settings::getAuthOptsCompat() 21 AuthOptCompatWrap &Settings::getAuthOptsCompat()
22 { 22 {
@@ -27,3 +27,21 @@ std::unordered_map&lt;std::string, std::string&gt; &amp;Settings::getFlashmqAuthPluginOpts @@ -27,3 +27,21 @@ std::unordered_map&lt;std::string, std::string&gt; &amp;Settings::getFlashmqAuthPluginOpts
27 { 27 {
28 return this->flashmqAuthPluginOpts; 28 return this->flashmqAuthPluginOpts;
29 } 29 }
  30 +
  31 +std::string Settings::getRetainedMessagesDBFile() const
  32 +{
  33 + if (storageDir.empty())
  34 + return "";
  35 +
  36 + std::string path = formatString("%s/%s", storageDir.c_str(), "retained.db");
  37 + return path;
  38 +}
  39 +
  40 +std::string Settings::getSessionsDBFile() const
  41 +{
  42 + if (storageDir.empty())
  43 + return "";
  44 +
  45 + std::string path = formatString("%s/%s", storageDir.c_str(), "sessions.db");
  46 + return path;
  47 +}
settings.h
@@ -53,10 +53,14 @@ public: @@ -53,10 +53,14 @@ public:
53 int rlimitNoFile = 1000000; 53 int rlimitNoFile = 1000000;
54 uint64_t expireSessionsAfterSeconds = 1209600; 54 uint64_t expireSessionsAfterSeconds = 1209600;
55 int authPluginTimerPeriod = 60; 55 int authPluginTimerPeriod = 60;
  56 + std::string storageDir;
56 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined. 57 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined.
57 58
58 AuthOptCompatWrap &getAuthOptsCompat(); 59 AuthOptCompatWrap &getAuthOptsCompat();
59 std::unordered_map<std::string, std::string> &getFlashmqAuthPluginOpts(); 60 std::unordered_map<std::string, std::string> &getFlashmqAuthPluginOpts();
  61 +
  62 + std::string getRetainedMessagesDBFile() const;
  63 + std::string getSessionsDBFile() const;
60 }; 64 };
61 65
62 #endif // SETTINGS_H 66 #endif // SETTINGS_H
subscriptionstore.cpp
@@ -20,7 +20,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -20,7 +20,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
20 #include "cassert" 20 #include "cassert"
21 21
22 #include "rwlockguard.h" 22 #include "rwlockguard.h"
23 - 23 +#include "retainedmessagesdb.h"
24 24
25 SubscriptionNode::SubscriptionNode(const std::string &subtopic) : 25 SubscriptionNode::SubscriptionNode(const std::string &subtopic) :
26 subtopic(subtopic) 26 subtopic(subtopic)
@@ -33,6 +33,11 @@ std::vector&lt;Subscription&gt; &amp;SubscriptionNode::getSubscribers() @@ -33,6 +33,11 @@ std::vector&lt;Subscription&gt; &amp;SubscriptionNode::getSubscribers()
33 return subscribers; 33 return subscribers;
34 } 34 }
35 35
  36 +const std::string &SubscriptionNode::getSubtopic() const
  37 +{
  38 + return subtopic;
  39 +}
  40 +
36 void SubscriptionNode::addSubscriber(const std::shared_ptr<Session> &subscriber, char qos) 41 void SubscriptionNode::addSubscriber(const std::shared_ptr<Session> &subscriber, char qos)
37 { 42 {
38 Subscription sub; 43 Subscription sub;
@@ -84,15 +89,20 @@ SubscriptionStore::SubscriptionStore() : @@ -84,15 +89,20 @@ SubscriptionStore::SubscriptionStore() :
84 89
85 } 90 }
86 91
87 -void SubscriptionStore::addSubscription(std::shared_ptr<Client> &client, const std::string &topic, const std::vector<std::string> &subtopics, char qos) 92 +/**
  93 + * @brief SubscriptionStore::getDeepestNode gets the node in the tree walking the path of 'the/subscription/topic/path', making new nodes as required.
  94 + * @param topic
  95 + * @param subtopics
  96 + * @return
  97 + *
  98 + * caller is responsible for locking.
  99 + */
  100 +SubscriptionNode *SubscriptionStore::getDeepestNode(const std::string &topic, const std::vector<std::string> &subtopics)
88 { 101 {
89 SubscriptionNode *deepestNode = &root; 102 SubscriptionNode *deepestNode = &root;
90 if (topic.length() > 0 && topic[0] == '$') 103 if (topic.length() > 0 && topic[0] == '$')
91 deepestNode = &rootDollar; 104 deepestNode = &rootDollar;
92 105
93 - RWLockGuard lock_guard(&subscriptionsRwlock);  
94 - lock_guard.wrlock();  
95 -  
96 for(const std::string &subtopic : subtopics) 106 for(const std::string &subtopic : subtopics)
97 { 107 {
98 std::unique_ptr<SubscriptionNode> *selectedChildren = nullptr; 108 std::unique_ptr<SubscriptionNode> *selectedChildren = nullptr;
@@ -114,6 +124,15 @@ void SubscriptionStore::addSubscription(std::shared_ptr&lt;Client&gt; &amp;client, const s @@ -114,6 +124,15 @@ void SubscriptionStore::addSubscription(std::shared_ptr&lt;Client&gt; &amp;client, const s
114 } 124 }
115 125
116 assert(deepestNode); 126 assert(deepestNode);
  127 + return deepestNode;
  128 +}
  129 +
  130 +void SubscriptionStore::addSubscription(std::shared_ptr<Client> &client, const std::string &topic, const std::vector<std::string> &subtopics, char qos)
  131 +{
  132 + RWLockGuard lock_guard(&subscriptionsRwlock);
  133 + lock_guard.wrlock();
  134 +
  135 + SubscriptionNode *deepestNode = getDeepestNode(topic, subtopics);
117 136
118 if (deepestNode) 137 if (deepestNode)
119 { 138 {
@@ -494,6 +513,183 @@ int64_t SubscriptionStore::getRetainedMessageCount() const @@ -494,6 +513,183 @@ int64_t SubscriptionStore::getRetainedMessageCount() const
494 return retainedMessageCount; 513 return retainedMessageCount;
495 } 514 }
496 515
  516 +void SubscriptionStore::getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const
  517 +{
  518 + for(const RetainedMessage &rm : this_node->retainedMessages)
  519 + {
  520 + outputList.push_back(rm);
  521 + }
  522 +
  523 + for(auto &pair : this_node->children)
  524 + {
  525 + const std::unique_ptr<RetainedMessageNode> &child = pair.second;
  526 + getRetainedMessages(child.get(), outputList);
  527 + }
  528 +}
  529 +
  530 +/**
  531 + * @brief SubscriptionStore::getSubscriptions
  532 + * @param this_node
  533 + * @param composedTopic
  534 + * @param root bool. Every subtopic is concatenated with a '/', but not the first topic to 'root'. The root is a bit weird, virtual, so it needs different treatment.
  535 + * @param outputList
  536 + */
  537 +void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
  538 + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const
  539 +{
  540 + for (const Subscription &node : this_node->getSubscribers())
  541 + {
  542 + if (!node.sessionGone())
  543 + {
  544 + SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos);
  545 + outputList[composedTopic].push_back(sub);
  546 + }
  547 + }
  548 +
  549 + for (auto &pair : this_node->children)
  550 + {
  551 + SubscriptionNode *node = pair.second.get();
  552 + const std::string topicAtNextLevel = root ? pair.first : composedTopic + "/" + pair.first;
  553 + getSubscriptions(node, topicAtNextLevel, false, outputList);
  554 + }
  555 +
  556 + if (this_node->childrenPlus)
  557 + {
  558 + const std::string topicAtNextLevel = root ? "+" : composedTopic + "/+";
  559 + getSubscriptions(this_node->childrenPlus.get(), topicAtNextLevel, false, outputList);
  560 + }
  561 +
  562 + if (this_node->childrenPound)
  563 + {
  564 + const std::string topicAtNextLevel = root ? "#" : composedTopic + "/#";
  565 + getSubscriptions(this_node->childrenPound.get(), topicAtNextLevel, false, outputList);
  566 + }
  567 +}
  568 +
  569 +void SubscriptionStore::saveRetainedMessages(const std::string &filePath)
  570 +{
  571 + logger->logf(LOG_INFO, "Saving retained messages to '%s'", filePath.c_str());
  572 +
  573 + std::vector<RetainedMessage> result;
  574 + result.reserve(retainedMessageCount);
  575 +
  576 + // Create the list of messages under lock, and unlock right after.
  577 + RWLockGuard locker(&retainedMessagesRwlock);
  578 + locker.rdlock();
  579 + getRetainedMessages(&retainedMessagesRoot, result);
  580 + locker.unlock();
  581 +
  582 + logger->logf(LOG_DEBUG, "Collected %ld retained messages to save.", result.size());
  583 +
  584 + // Then do the IO without locking the threads.
  585 + RetainedMessagesDB db(filePath);
  586 + db.openWrite();
  587 + db.saveData(result);
  588 +}
  589 +
  590 +void SubscriptionStore::loadRetainedMessages(const std::string &filePath)
  591 +{
  592 + try
  593 + {
  594 + logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str());
  595 +
  596 + RetainedMessagesDB db(filePath);
  597 + db.openRead();
  598 + std::list<RetainedMessage> messages = db.readData();
  599 +
  600 + RWLockGuard locker(&retainedMessagesRwlock);
  601 + locker.wrlock();
  602 +
  603 + std::vector<std::string> subtopics;
  604 + for (const RetainedMessage &rm : messages)
  605 + {
  606 + splitTopic(rm.topic, subtopics);
  607 + setRetainedMessage(rm.topic, subtopics, rm.payload, rm.qos);
  608 + }
  609 + }
  610 + catch (PersistenceFileCantBeOpened &ex)
  611 + {
  612 + logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str());
  613 + }
  614 +}
  615 +
  616 +void SubscriptionStore::saveSessionsAndSubscriptions(const std::string &filePath)
  617 +{
  618 + logger->logf(LOG_INFO, "Saving sessions and subscriptions to '%s'", filePath.c_str());
  619 +
  620 + RWLockGuard lock_guard(&subscriptionsRwlock);
  621 + lock_guard.wrlock();
  622 +
  623 + // First copy the sessions...
  624 +
  625 + std::list<std::unique_ptr<Session>> sessionCopies;
  626 +
  627 + for (const auto &pair : sessionsByIdConst)
  628 + {
  629 + const Session &org = *pair.second.get();
  630 + sessionCopies.push_back(org.getCopy());
  631 + }
  632 +
  633 + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> subscriptionCopies;
  634 + getSubscriptions(&root, "", true, subscriptionCopies);
  635 +
  636 + lock_guard.unlock();
  637 +
  638 + // Then write the copies to disk, after having released the lock
  639 +
  640 + logger->logf(LOG_DEBUG, "Collected %ld sessions and %ld subscriptions to save.", sessionCopies.size(), subscriptionCopies.size());
  641 +
  642 + SessionsAndSubscriptionsDB db(filePath);
  643 + db.openWrite();
  644 + db.saveData(sessionCopies, subscriptionCopies);
  645 +}
  646 +
  647 +void SubscriptionStore::loadSessionsAndSubscriptions(const std::string &filePath)
  648 +{
  649 + try
  650 + {
  651 + logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str());
  652 +
  653 + SessionsAndSubscriptionsDB db(filePath);
  654 + db.openRead();
  655 + SessionsAndSubscriptionsResult loadedData = db.readData();
  656 +
  657 + RWLockGuard locker(&subscriptionsRwlock);
  658 + locker.wrlock();
  659 +
  660 + for (std::shared_ptr<Session> &session : loadedData.sessions)
  661 + {
  662 + sessionsById[session->getClientId()] = session;
  663 + }
  664 +
  665 + std::vector<std::string> subtopics;
  666 +
  667 + for (auto &pair : loadedData.subscriptions)
  668 + {
  669 + const std::string &topic = pair.first;
  670 + const std::list<SubscriptionForSerializing> &subs = pair.second;
  671 +
  672 + for (const SubscriptionForSerializing &sub : subs)
  673 + {
  674 + splitTopic(topic, subtopics);
  675 + SubscriptionNode *subscriptionNode = getDeepestNode(topic, subtopics);
  676 +
  677 + auto session_it = sessionsByIdConst.find(sub.clientId);
  678 + if (session_it != sessionsByIdConst.end())
  679 + {
  680 + const std::shared_ptr<Session> &ses = session_it->second;
  681 + subscriptionNode->addSubscriber(ses, sub.qos);
  682 + }
  683 +
  684 + }
  685 + }
  686 + }
  687 + catch (PersistenceFileCantBeOpened &ex)
  688 + {
  689 + logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str());
  690 + }
  691 +}
  692 +
497 // QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The 693 // QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The
498 // specs don't specify what to do there. 694 // specs don't specify what to do there.
499 bool Subscription::operator==(const Subscription &rhs) const 695 bool Subscription::operator==(const Subscription &rhs) const
subscriptionstore.h
@@ -53,6 +53,7 @@ public: @@ -53,6 +53,7 @@ public:
53 SubscriptionNode(SubscriptionNode &&node) = delete; 53 SubscriptionNode(SubscriptionNode &&node) = delete;
54 54
55 std::vector<Subscription> &getSubscribers(); 55 std::vector<Subscription> &getSubscribers();
  56 + const std::string &getSubtopic() const;
56 void addSubscriber(const std::shared_ptr<Session> &subscriber, char qos); 57 void addSubscriber(const std::shared_ptr<Session> &subscriber, char qos);
57 void removeSubscriber(const std::shared_ptr<Session> &subscriber); 58 void removeSubscriber(const std::shared_ptr<Session> &subscriber);
58 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children; 59 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children;
@@ -77,6 +78,10 @@ class RetainedMessageNode @@ -77,6 +78,10 @@ class RetainedMessageNode
77 78
78 class SubscriptionStore 79 class SubscriptionStore
79 { 80 {
  81 +#ifdef TESTING
  82 + friend class MainTests;
  83 +#endif
  84 +
80 SubscriptionNode root; 85 SubscriptionNode root;
81 SubscriptionNode rootDollar; 86 SubscriptionNode rootDollar;
82 pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER; 87 pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER;
@@ -93,7 +98,11 @@ class SubscriptionStore @@ -93,7 +98,11 @@ class SubscriptionStore
93 void publishNonRecursively(const MqttPacket &packet, const std::vector<Subscription> &subscribers, uint64_t &count) const; 98 void publishNonRecursively(const MqttPacket &packet, const std::vector<Subscription> &subscribers, uint64_t &count) const;
94 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, 99 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
95 SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const; 100 SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const;
  101 + void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const;
  102 + void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
  103 + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const;
96 104
  105 + SubscriptionNode *getDeepestNode(const std::string &topic, const std::vector<std::string> &subtopics);
97 public: 106 public:
98 SubscriptionStore(); 107 SubscriptionStore();
99 108
@@ -103,7 +112,6 @@ public: @@ -103,7 +112,6 @@ public:
103 bool sessionPresent(const std::string &clientid); 112 bool sessionPresent(const std::string &clientid);
104 113
105 void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar = false); 114 void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar = false);
106 - uint64_t giveClientRetainedMessages(const std::shared_ptr<Session> &ses, const std::string &subscribe_topic, char max_qos);  
107 void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, 115 void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
108 RetainedMessageNode *this_node, char max_qos, const std::shared_ptr<Session> &ses, 116 RetainedMessageNode *this_node, char max_qos, const std::shared_ptr<Session> &ses,
109 bool poundMode, uint64_t &count) const; 117 bool poundMode, uint64_t &count) const;
@@ -114,6 +122,12 @@ public: @@ -114,6 +122,12 @@ public:
114 void removeExpiredSessionsClients(int expireSessionsAfterSeconds); 122 void removeExpiredSessionsClients(int expireSessionsAfterSeconds);
115 123
116 int64_t getRetainedMessageCount() const; 124 int64_t getRetainedMessageCount() const;
  125 +
  126 + void saveRetainedMessages(const std::string &filePath);
  127 + void loadRetainedMessages(const std::string &filePath);
  128 +
  129 + void saveSessionsAndSubscriptions(const std::string &filePath);
  130 + void loadSessionsAndSubscriptions(const std::string &filePath);
117 }; 131 };
118 132
119 #endif // SUBSCRIPTIONSTORE_H 133 #endif // SUBSCRIPTIONSTORE_H
utils.cpp
@@ -289,6 +289,14 @@ void trim(std::string &amp;s) @@ -289,6 +289,14 @@ void trim(std::string &amp;s)
289 rtrim(s); 289 rtrim(s);
290 } 290 }
291 291
  292 +std::string &rtrim(std::string &s, unsigned char c)
  293 +{
  294 + s.erase(std::find_if(s.rbegin(), s.rend(), [=](unsigned char ch) {
  295 + return (c != ch);
  296 + }).base(), s.end());
  297 + return s;
  298 +}
  299 +
292 bool startsWith(const std::string &s, const std::string &needle) 300 bool startsWith(const std::string &s, const std::string &needle)
293 { 301 {
294 return s.find(needle) == 0; 302 return s.find(needle) == 0;
@@ -28,6 +28,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -28,6 +28,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
28 #include <openssl/evp.h> 28 #include <openssl/evp.h>
29 #include <memory> 29 #include <memory>
30 #include <arpa/inet.h> 30 #include <arpa/inet.h>
  31 +#include "unistd.h"
  32 +#include "sys/stat.h"
31 33
32 #include "cirbuf.h" 34 #include "cirbuf.h"
33 #include "bindaddr.h" 35 #include "bindaddr.h"
@@ -64,6 +66,7 @@ void ltrim(std::string &amp;s); @@ -64,6 +66,7 @@ void ltrim(std::string &amp;s);
64 void rtrim(std::string &s); 66 void rtrim(std::string &s);
65 void trim(std::string &s); 67 void trim(std::string &s);
66 bool startsWith(const std::string &s, const std::string &needle); 68 bool startsWith(const std::string &s, const std::string &needle);
  69 +std::string &rtrim(std::string &s, unsigned char c);
67 70
68 std::string getSecureRandomString(const ssize_t len); 71 std::string getSecureRandomString(const ssize_t len);
69 std::string str_tolower(std::string s); 72 std::string str_tolower(std::string s);
@@ -92,5 +95,32 @@ ssize_t getFileSize(const std::string &amp;path); @@ -92,5 +95,32 @@ ssize_t getFileSize(const std::string &amp;path);
92 95
93 std::string sockaddrToString(struct sockaddr *addr); 96 std::string sockaddrToString(struct sockaddr *addr);
94 97
  98 +template<typename ex> void checkWritableDir(const std::string &path)
  99 +{
  100 + if (path.empty())
  101 + throw ex("Dir path to check is an empty string.");
  102 +
  103 + if (access(path.c_str(), W_OK) != 0)
  104 + {
  105 + std::string msg = formatString("Path '%s' is not there or not writable", path.c_str());
  106 + throw ex(msg);
  107 + }
  108 +
  109 + struct stat statbuf;
  110 + memset(&statbuf, 0, sizeof(struct stat));
  111 + if (stat(path.c_str(), &statbuf) < 0)
  112 + {
  113 + // We checked for W_OK above, so this shouldn't happen.
  114 + std::string msg = formatString("Error getting information about '%s'.", path.c_str());
  115 + throw ex(msg);
  116 + }
  117 +
  118 + if (!S_ISDIR(statbuf.st_mode))
  119 + {
  120 + std::string msg = formatString("Path '%s' is not a directory.", path.c_str());
  121 + throw ex(msg);
  122 + }
  123 +}
  124 +
95 125
96 #endif // UTILS_H 126 #endif // UTILS_H