Commit 4bfa5aa5716d52b57ed4e725f7f5dd6aa03e395d
1 parent
0cbe2a00
Add saving state feature
Files are simple serialized bytes prefaced by lengths. File is hashed to verify integrity. This was also a good way preventing unexpected errors when trying to crash the parser by having it load a different file. This change includes some refactoring that was necessary: - It 'fixes' looking at the wrong thread's authentiction. This is still wrong though. It will be fixed by a thread local pointer in the next commit. - Deadlocks with yourself are handled in rwlockguard. - QoSPacketQueue is now a class. - Probably other tweaks.
Showing
27 changed files
with
1795 additions
and
60 deletions
CMakeLists.txt
| ... | ... | @@ -50,6 +50,10 @@ add_executable(FlashMQ |
| 50 | 50 | enums.h |
| 51 | 51 | threadlocalutils.h |
| 52 | 52 | flashmq_plugin.h |
| 53 | + retainedmessagesdb.h | |
| 54 | + persistencefile.h | |
| 55 | + sessionsandsubscriptionsdb.h | |
| 56 | + qospacketqueue.h | |
| 53 | 57 | |
| 54 | 58 | mainapp.cpp |
| 55 | 59 | main.cpp |
| ... | ... | @@ -81,6 +85,10 @@ add_executable(FlashMQ |
| 81 | 85 | acltree.cpp |
| 82 | 86 | threadlocalutils.cpp |
| 83 | 87 | flashmq_plugin.cpp |
| 88 | + retainedmessagesdb.cpp | |
| 89 | + persistencefile.cpp | |
| 90 | + sessionsandsubscriptionsdb.cpp | |
| 91 | + qospacketqueue.cpp | |
| 84 | 92 | |
| 85 | 93 | ) |
| 86 | 94 | ... | ... |
FlashMQTests/FlashMQTests.pro
| ... | ... | @@ -42,6 +42,10 @@ SOURCES += tst_maintests.cpp \ |
| 42 | 42 | ../acltree.cpp \ |
| 43 | 43 | ../threadlocalutils.cpp \ |
| 44 | 44 | ../flashmq_plugin.cpp \ |
| 45 | + ../retainedmessagesdb.cpp \ | |
| 46 | + ../persistencefile.cpp \ | |
| 47 | + ../sessionsandsubscriptionsdb.cpp \ | |
| 48 | + ../qospacketqueue.cpp \ | |
| 45 | 49 | mainappthread.cpp \ |
| 46 | 50 | twoclienttestcontext.cpp |
| 47 | 51 | |
| ... | ... | @@ -77,6 +81,10 @@ HEADERS += \ |
| 77 | 81 | ../acltree.h \ |
| 78 | 82 | ../threadlocalutils.h \ |
| 79 | 83 | ../flashmq_plugin.h \ |
| 84 | + ../retainedmessagesdb.h \ | |
| 85 | + ../persistencefile.h \ | |
| 86 | + ../sessionsandsubscriptionsdb.h \ | |
| 87 | + ../qospacketqueue.h \ | |
| 80 | 88 | mainappthread.h \ |
| 81 | 89 | twoclienttestcontext.h |
| 82 | 90 | ... | ... |
FlashMQTests/tst_maintests.cpp
| ... | ... | @@ -21,12 +21,18 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 21 | 21 | #include <QtQmqtt/qmqtt.h> |
| 22 | 22 | #include <QScopedPointer> |
| 23 | 23 | #include <QHostInfo> |
| 24 | +#include <list> | |
| 25 | +#include <unordered_map> | |
| 24 | 26 | |
| 25 | 27 | #include "cirbuf.h" |
| 26 | 28 | #include "mainapp.h" |
| 27 | 29 | #include "mainappthread.h" |
| 28 | 30 | #include "twoclienttestcontext.h" |
| 29 | 31 | #include "threadlocalutils.h" |
| 32 | +#include "retainedmessagesdb.h" | |
| 33 | +#include "sessionsandsubscriptionsdb.h" | |
| 34 | +#include "session.h" | |
| 35 | +#include "threaddata.h" | |
| 30 | 36 | |
| 31 | 37 | // Dumb Qt version gives warnings when comparing uint with number literal. |
| 32 | 38 | template <typename T1, typename T2> |
| ... | ... | @@ -88,6 +94,12 @@ private slots: |
| 88 | 94 | |
| 89 | 95 | void testTopicsMatch(); |
| 90 | 96 | |
| 97 | + void testRetainedMessageDB(); | |
| 98 | + void testRetainedMessageDBNotPresent(); | |
| 99 | + void testRetainedMessageDBEmptyList(); | |
| 100 | + | |
| 101 | + void testSavingSessions(); | |
| 102 | + | |
| 91 | 103 | }; |
| 92 | 104 | |
| 93 | 105 | MainTests::MainTests() |
| ... | ... | @@ -822,6 +834,221 @@ void MainTests::testTopicsMatch() |
| 822 | 834 | |
| 823 | 835 | } |
| 824 | 836 | |
| 837 | +void MainTests::testRetainedMessageDB() | |
| 838 | +{ | |
| 839 | + try | |
| 840 | + { | |
| 841 | + std::string longpayload = getSecureRandomString(65537); | |
| 842 | + std::string longTopic = formatString("one/two/%s", getSecureRandomString(4000).c_str()); | |
| 843 | + | |
| 844 | + std::vector<RetainedMessage> messages; | |
| 845 | + messages.emplace_back("one/two/three", "payload", 0); | |
| 846 | + messages.emplace_back("one/two/wer", "payload", 1); | |
| 847 | + messages.emplace_back("one/e/wer", "payload", 1); | |
| 848 | + messages.emplace_back("one/wee/wer", "asdfasdfasdf", 1); | |
| 849 | + messages.emplace_back("one/two/wer", "ยตsdf", 1); | |
| 850 | + messages.emplace_back("/boe/bah", longpayload, 1); | |
| 851 | + messages.emplace_back("one/two/wer", "paylasdfaoad", 1); | |
| 852 | + messages.emplace_back("one/two/wer", "payload", 1); | |
| 853 | + messages.emplace_back(longTopic, "payload", 1); | |
| 854 | + messages.emplace_back(longTopic, longpayload, 1); | |
| 855 | + messages.emplace_back("one", "ยตsdf", 1); | |
| 856 | + messages.emplace_back("/boe", longpayload, 1); | |
| 857 | + messages.emplace_back("one", "ยตsdf", 1); | |
| 858 | + messages.emplace_back("", "foremptytopic", 0); | |
| 859 | + | |
| 860 | + RetainedMessagesDB db("/tmp/flashmqtests_retained.db"); | |
| 861 | + db.openWrite(); | |
| 862 | + db.saveData(messages); | |
| 863 | + db.closeFile(); | |
| 864 | + | |
| 865 | + RetainedMessagesDB db2("/tmp/flashmqtests_retained.db"); | |
| 866 | + db2.openRead(); | |
| 867 | + std::list<RetainedMessage> messagesLoaded = db2.readData(); | |
| 868 | + db2.closeFile(); | |
| 869 | + | |
| 870 | + QCOMPARE(messages.size(), messagesLoaded.size()); | |
| 871 | + | |
| 872 | + auto itOrg = messages.begin(); | |
| 873 | + auto itLoaded = messagesLoaded.begin(); | |
| 874 | + while (itOrg != messages.end() && itLoaded != messagesLoaded.end()) | |
| 875 | + { | |
| 876 | + RetainedMessage &one = *itOrg; | |
| 877 | + RetainedMessage &two = *itLoaded; | |
| 878 | + | |
| 879 | + // Comparing the fields because the RetainedMessage class has an == operator that only looks at topic. | |
| 880 | + QCOMPARE(one.topic, two.topic); | |
| 881 | + QCOMPARE(one.payload, two.payload); | |
| 882 | + QCOMPARE(one.qos, two.qos); | |
| 883 | + | |
| 884 | + itOrg++; | |
| 885 | + itLoaded++; | |
| 886 | + } | |
| 887 | + } | |
| 888 | + catch (std::exception &ex) | |
| 889 | + { | |
| 890 | + QVERIFY2(false, ex.what()); | |
| 891 | + } | |
| 892 | +} | |
| 893 | + | |
| 894 | +void MainTests::testRetainedMessageDBNotPresent() | |
| 895 | +{ | |
| 896 | + try | |
| 897 | + { | |
| 898 | + RetainedMessagesDB db2("/tmp/flashmqtests_asdfasdfasdf.db"); | |
| 899 | + db2.openRead(); | |
| 900 | + std::list<RetainedMessage> messagesLoaded = db2.readData(); | |
| 901 | + db2.closeFile(); | |
| 902 | + | |
| 903 | + MYCASTCOMPARE(messagesLoaded.size(), 0); | |
| 904 | + | |
| 905 | + QVERIFY2(false, "We should have run into an exception."); | |
| 906 | + } | |
| 907 | + catch (PersistenceFileCantBeOpened &ex) | |
| 908 | + { | |
| 909 | + QVERIFY(true); | |
| 910 | + } | |
| 911 | + catch (std::exception &ex) | |
| 912 | + { | |
| 913 | + QVERIFY2(false, ex.what()); | |
| 914 | + } | |
| 915 | +} | |
| 916 | + | |
| 917 | +void MainTests::testRetainedMessageDBEmptyList() | |
| 918 | +{ | |
| 919 | + try | |
| 920 | + { | |
| 921 | + std::vector<RetainedMessage> messages; | |
| 922 | + | |
| 923 | + RetainedMessagesDB db("/tmp/flashmqtests_retained.db"); | |
| 924 | + db.openWrite(); | |
| 925 | + db.saveData(messages); | |
| 926 | + db.closeFile(); | |
| 927 | + | |
| 928 | + RetainedMessagesDB db2("/tmp/flashmqtests_retained.db"); | |
| 929 | + db2.openRead(); | |
| 930 | + std::list<RetainedMessage> messagesLoaded = db2.readData(); | |
| 931 | + db2.closeFile(); | |
| 932 | + | |
| 933 | + MYCASTCOMPARE(messages.size(), messagesLoaded.size()); | |
| 934 | + MYCASTCOMPARE(messages.size(), 0); | |
| 935 | + } | |
| 936 | + catch (std::exception &ex) | |
| 937 | + { | |
| 938 | + QVERIFY2(false, ex.what()); | |
| 939 | + } | |
| 940 | +} | |
| 941 | + | |
| 942 | +void MainTests::testSavingSessions() | |
| 943 | +{ | |
| 944 | + try | |
| 945 | + { | |
| 946 | + std::shared_ptr<Settings> settings(new Settings()); | |
| 947 | + std::shared_ptr<SubscriptionStore> store(new SubscriptionStore()); | |
| 948 | + std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings)); | |
| 949 | + | |
| 950 | + std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false)); | |
| 951 | + c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); | |
| 952 | + store->registerClientAndKickExistingOne(c1); | |
| 953 | + c1->getSession()->touch(); | |
| 954 | + c1->getSession()->addIncomingQoS2MessageId(2); | |
| 955 | + c1->getSession()->addIncomingQoS2MessageId(3); | |
| 956 | + | |
| 957 | + std::shared_ptr<Client> c2(new Client(0, t, nullptr, false, nullptr, settings, false)); | |
| 958 | + c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60, false); | |
| 959 | + store->registerClientAndKickExistingOne(c2); | |
| 960 | + c2->getSession()->touch(); | |
| 961 | + c2->getSession()->addOutgoingQoS2MessageId(55); | |
| 962 | + c2->getSession()->addOutgoingQoS2MessageId(66); | |
| 963 | + | |
| 964 | + const std::string topic1 = "one/two/three"; | |
| 965 | + std::vector<std::string> subtopics; | |
| 966 | + splitTopic(topic1, subtopics); | |
| 967 | + store->addSubscription(c1, topic1, subtopics, 0); | |
| 968 | + | |
| 969 | + const std::string topic2 = "four/five/six"; | |
| 970 | + splitTopic(topic2, subtopics); | |
| 971 | + store->addSubscription(c2, topic2, subtopics, 0); | |
| 972 | + store->addSubscription(c1, topic2, subtopics, 0); | |
| 973 | + | |
| 974 | + const std::string topic3 = ""; | |
| 975 | + splitTopic(topic3, subtopics); | |
| 976 | + store->addSubscription(c2, topic3, subtopics, 0); | |
| 977 | + | |
| 978 | + const std::string topic4 = "#"; | |
| 979 | + splitTopic(topic4, subtopics); | |
| 980 | + store->addSubscription(c2, topic4, subtopics, 0); | |
| 981 | + | |
| 982 | + uint64_t count = 0; | |
| 983 | + | |
| 984 | + Publish publish("a/b/c", "Hello Barry", 1); | |
| 985 | + | |
| 986 | + std::shared_ptr<Session> c1ses = c1->getSession(); | |
| 987 | + c1.reset(); | |
| 988 | + c1ses->writePacket(publish, 1, false, count); | |
| 989 | + | |
| 990 | + store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); | |
| 991 | + | |
| 992 | + std::shared_ptr<SubscriptionStore> store2(new SubscriptionStore()); | |
| 993 | + store2->loadSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); | |
| 994 | + | |
| 995 | + MYCASTCOMPARE(store->sessionsById.size(), 2); | |
| 996 | + MYCASTCOMPARE(store2->sessionsById.size(), 2); | |
| 997 | + | |
| 998 | + for (auto &pair : store->sessionsById) | |
| 999 | + { | |
| 1000 | + std::shared_ptr<Session> &ses = pair.second; | |
| 1001 | + std::shared_ptr<Session> &ses2 = store2->sessionsById[pair.first]; | |
| 1002 | + QCOMPARE(pair.first, ses2->getClientId()); | |
| 1003 | + | |
| 1004 | + QCOMPARE(ses->username, ses2->username); | |
| 1005 | + QCOMPARE(ses->client_id, ses2->client_id); | |
| 1006 | + QCOMPARE(ses->incomingQoS2MessageIds, ses2->incomingQoS2MessageIds); | |
| 1007 | + QCOMPARE(ses->outgoingQoS2MessageIds, ses2->outgoingQoS2MessageIds); | |
| 1008 | + QCOMPARE(ses->nextPacketId, ses2->nextPacketId); | |
| 1009 | + } | |
| 1010 | + | |
| 1011 | + | |
| 1012 | + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> store1Subscriptions; | |
| 1013 | + store->getSubscriptions(&store->root, "", true, store1Subscriptions); | |
| 1014 | + | |
| 1015 | + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> store2Subscriptions; | |
| 1016 | + store->getSubscriptions(&store->root, "", true, store2Subscriptions); | |
| 1017 | + | |
| 1018 | + MYCASTCOMPARE(store1Subscriptions.size(), 4); | |
| 1019 | + MYCASTCOMPARE(store2Subscriptions.size(), 4); | |
| 1020 | + | |
| 1021 | + for(auto &pair : store1Subscriptions) | |
| 1022 | + { | |
| 1023 | + std::list<SubscriptionForSerializing> &subscList1 = pair.second; | |
| 1024 | + std::list<SubscriptionForSerializing> &subscList2 = store2Subscriptions[pair.first]; | |
| 1025 | + | |
| 1026 | + QCOMPARE(subscList1.size(), subscList2.size()); | |
| 1027 | + | |
| 1028 | + auto subs1It = subscList1.begin(); | |
| 1029 | + auto subs2It = subscList2.begin(); | |
| 1030 | + | |
| 1031 | + while (subs1It != subscList1.end()) | |
| 1032 | + { | |
| 1033 | + SubscriptionForSerializing &one = *subs1It; | |
| 1034 | + SubscriptionForSerializing &two = *subs2It; | |
| 1035 | + QCOMPARE(one.clientId, two.clientId); | |
| 1036 | + QCOMPARE(one.qos, two.qos); | |
| 1037 | + | |
| 1038 | + subs1It++; | |
| 1039 | + subs2It++; | |
| 1040 | + } | |
| 1041 | + | |
| 1042 | + } | |
| 1043 | + | |
| 1044 | + | |
| 1045 | + } | |
| 1046 | + catch (std::exception &ex) | |
| 1047 | + { | |
| 1048 | + QVERIFY2(false, ex.what()); | |
| 1049 | + } | |
| 1050 | +} | |
| 1051 | + | |
| 825 | 1052 | QTEST_GUILESS_MAIN(MainTests) |
| 826 | 1053 | |
| 827 | 1054 | #include "tst_maintests.moc" | ... | ... |
configfileparser.cpp
| ... | ... | @@ -107,6 +107,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : |
| 107 | 107 | validKeys.insert("allow_anonymous"); |
| 108 | 108 | validKeys.insert("rlimit_nofile"); |
| 109 | 109 | validKeys.insert("expire_sessions_after_seconds"); |
| 110 | + validKeys.insert("storage_dir"); | |
| 110 | 111 | |
| 111 | 112 | validListenKeys.insert("port"); |
| 112 | 113 | validListenKeys.insert("protocol"); |
| ... | ... | @@ -412,6 +413,14 @@ void ConfigFileParser::loadFile(bool test) |
| 412 | 413 | } |
| 413 | 414 | tmpSettings->authPluginTimerPeriod = newVal; |
| 414 | 415 | } |
| 416 | + | |
| 417 | + if (key == "storage_dir") | |
| 418 | + { | |
| 419 | + std::string newPath = value; | |
| 420 | + rtrim(newPath, '/'); | |
| 421 | + checkWritableDir<ConfigFileException>(newPath); | |
| 422 | + tmpSettings->storageDir = newPath; | |
| 423 | + } | |
| 415 | 424 | } |
| 416 | 425 | } |
| 417 | 426 | catch (std::invalid_argument &ex) // catch for the stoi() | ... | ... |
mainapp.cpp
| ... | ... | @@ -202,6 +202,15 @@ MainApp::MainApp(const std::string &configFilePath) : |
| 202 | 202 | auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this); |
| 203 | 203 | timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event."); |
| 204 | 204 | } |
| 205 | + | |
| 206 | + if (!settings->storageDir.empty()) | |
| 207 | + { | |
| 208 | + subscriptionStore->loadRetainedMessages(settings->getRetainedMessagesDBFile()); | |
| 209 | + subscriptionStore->loadSessionsAndSubscriptions(settings->getSessionsDBFile()); | |
| 210 | + } | |
| 211 | + | |
| 212 | + auto fSaveState = std::bind(&MainApp::saveState, this); | |
| 213 | + timer.addCallback(fSaveState, 900000, "Save state."); | |
| 205 | 214 | } |
| 206 | 215 | |
| 207 | 216 | MainApp::~MainApp() |
| ... | ... | @@ -369,6 +378,25 @@ void MainApp::publishStat(const std::string &topic, uint64_t n) |
| 369 | 378 | subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); |
| 370 | 379 | } |
| 371 | 380 | |
| 381 | +void MainApp::saveState() | |
| 382 | +{ | |
| 383 | + try | |
| 384 | + { | |
| 385 | + if (!settings->storageDir.empty()) | |
| 386 | + { | |
| 387 | + const std::string retainedDBPath = settings->getRetainedMessagesDBFile(); | |
| 388 | + subscriptionStore->saveRetainedMessages(retainedDBPath); | |
| 389 | + | |
| 390 | + const std::string sessionsDBPath = settings->getSessionsDBFile(); | |
| 391 | + subscriptionStore->saveSessionsAndSubscriptions(sessionsDBPath); | |
| 392 | + } | |
| 393 | + } | |
| 394 | + catch(std::exception &ex) | |
| 395 | + { | |
| 396 | + logger->logf(LOG_ERR, "Error saving state: %s", ex.what()); | |
| 397 | + } | |
| 398 | +} | |
| 399 | + | |
| 372 | 400 | void MainApp::initMainApp(int argc, char *argv[]) |
| 373 | 401 | { |
| 374 | 402 | if (instance != nullptr) |
| ... | ... | @@ -659,6 +687,8 @@ void MainApp::start() |
| 659 | 687 | { |
| 660 | 688 | thread->waitForQuit(); |
| 661 | 689 | } |
| 690 | + | |
| 691 | + saveState(); | |
| 662 | 692 | } |
| 663 | 693 | |
| 664 | 694 | void MainApp::quit() | ... | ... |
mainapp.h
mqttpacket.cpp
| ... | ... | @@ -47,8 +47,11 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt |
| 47 | 47 | pos += fixed_header_length; |
| 48 | 48 | } |
| 49 | 49 | |
| 50 | -// This is easier than using the copy constructor publically, because then I have to keep maintaining a functioning copy constructor. | |
| 51 | -// Returning shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. | |
| 50 | +/** | |
| 51 | + * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor | |
| 52 | + * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add. | |
| 53 | + * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. | |
| 54 | + */ | |
| 52 | 55 | std::shared_ptr<MqttPacket> MqttPacket::getCopy() const |
| 53 | 56 | { |
| 54 | 57 | std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this)); |
| ... | ... | @@ -129,6 +132,9 @@ MqttPacket::MqttPacket(const Publish &publish) : |
| 129 | 132 | writeBytes(zero, 2); |
| 130 | 133 | } |
| 131 | 134 | |
| 135 | + payloadStart = pos; | |
| 136 | + payloadLen = publish.payload.length(); | |
| 137 | + | |
| 132 | 138 | writeBytes(publish.payload.c_str(), publish.payload.length()); |
| 133 | 139 | calculateRemainingLength(); |
| 134 | 140 | } |
| ... | ... | @@ -546,13 +552,14 @@ void MqttPacket::handlePublish() |
| 546 | 552 | } |
| 547 | 553 | } |
| 548 | 554 | |
| 555 | + payloadLen = remainingAfterPos(); | |
| 556 | + payloadStart = pos; | |
| 557 | + | |
| 549 | 558 | if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) |
| 550 | 559 | { |
| 551 | 560 | if (retain) |
| 552 | 561 | { |
| 553 | - size_t payload_length = remainingAfterPos(); | |
| 554 | - std::string payload(readBytes(payload_length), payload_length); | |
| 555 | - | |
| 562 | + std::string payload(readBytes(payloadLen), payloadLen); | |
| 556 | 563 | sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, *subtopics, payload, qos); |
| 557 | 564 | } |
| 558 | 565 | |
| ... | ... | @@ -679,6 +686,23 @@ size_t MqttPacket::getTotalMemoryFootprint() |
| 679 | 686 | return bites.size() + sizeof(MqttPacket); |
| 680 | 687 | } |
| 681 | 688 | |
| 689 | +/** | |
| 690 | + * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string. | |
| 691 | + * @return | |
| 692 | + * | |
| 693 | + * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out | |
| 694 | + * the whole byte array that is a packet to subscribers. No need to copy and such. | |
| 695 | + * | |
| 696 | + * I created it for saving QoS packages in the db file. | |
| 697 | + */ | |
| 698 | +std::string MqttPacket::getPayloadCopy() | |
| 699 | +{ | |
| 700 | + assert(payloadStart > 0); | |
| 701 | + assert(pos <= bites.size()); | |
| 702 | + std::string payload(&bites[payloadStart], payloadLen); | |
| 703 | + return payload; | |
| 704 | +} | |
| 705 | + | |
| 682 | 706 | size_t MqttPacket::getSizeIncludingNonPresentHeader() const |
| 683 | 707 | { |
| 684 | 708 | size_t total = bites.size(); | ... | ... |
mqttpacket.h
| ... | ... | @@ -59,6 +59,8 @@ class MqttPacket |
| 59 | 59 | size_t packet_id_pos = 0; |
| 60 | 60 | uint16_t packet_id = 0; |
| 61 | 61 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 62 | + size_t payloadStart = 0; | |
| 63 | + size_t payloadLen = 0; | |
| 62 | 64 | Logger *logger = Logger::getInstance(); |
| 63 | 65 | |
| 64 | 66 | char *readBytes(size_t length); |
| ... | ... | @@ -116,6 +118,7 @@ public: |
| 116 | 118 | uint16_t getPacketId() const; |
| 117 | 119 | void setDuplicate(); |
| 118 | 120 | size_t getTotalMemoryFootprint(); |
| 121 | + std::string getPayloadCopy(); | |
| 119 | 122 | }; |
| 120 | 123 | |
| 121 | 124 | #endif // MQTTPACKET_H | ... | ... |
persistencefile.cpp
0 โ 100644
| 1 | +/* | |
| 2 | +This file is part of FlashMQ (https://www.flashmq.org) | |
| 3 | +Copyright (C) 2021 Wiebe Cazemier | |
| 4 | + | |
| 5 | +FlashMQ is free software: you can redistribute it and/or modify | |
| 6 | +it under the terms of the GNU Affero General Public License as | |
| 7 | +published by the Free Software Foundation, version 3. | |
| 8 | + | |
| 9 | +FlashMQ is distributed in the hope that it will be useful, | |
| 10 | +but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 11 | +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 12 | +GNU Affero General Public License for more details. | |
| 13 | + | |
| 14 | +You should have received a copy of the GNU Affero General Public | |
| 15 | +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | |
| 16 | +*/ | |
| 17 | + | |
| 18 | +#include "persistencefile.h" | |
| 19 | + | |
| 20 | +#include <sys/types.h> | |
| 21 | +#include <sys/stat.h> | |
| 22 | +#include <fcntl.h> | |
| 23 | +#include <unistd.h> | |
| 24 | +#include <exception> | |
| 25 | +#include <stdexcept> | |
| 26 | +#include <stdio.h> | |
| 27 | +#include <cstring> | |
| 28 | + | |
| 29 | +#include "utils.h" | |
| 30 | +#include "logger.h" | |
| 31 | + | |
| 32 | +PersistenceFile::PersistenceFile(const std::string &filePath) : | |
| 33 | + digestContext(EVP_MD_CTX_new()), | |
| 34 | + buf(1024*1024) | |
| 35 | +{ | |
| 36 | + if (!filePath.empty() && filePath[filePath.size() - 1] == '/') | |
| 37 | + throw std::runtime_error("Target file can't contain trailing slash."); | |
| 38 | + | |
| 39 | + this->filePath = filePath; | |
| 40 | + this->filePathTemp = formatString("%s.newfile.%s", filePath.c_str(), getSecureRandomString(8).c_str()); | |
| 41 | + this->filePathCorrupt = formatString("%s.corrupt.%s", filePath.c_str(), getSecureRandomString(8).c_str()); | |
| 42 | +} | |
| 43 | + | |
| 44 | +PersistenceFile::~PersistenceFile() | |
| 45 | +{ | |
| 46 | + closeFile(); | |
| 47 | + | |
| 48 | + if (digestContext) | |
| 49 | + { | |
| 50 | + EVP_MD_CTX_free(digestContext); | |
| 51 | + } | |
| 52 | +} | |
| 53 | + | |
| 54 | +/** | |
| 55 | + * @brief RetainedMessagesDB::hashFile hashes the data after the headers and writes the hash in the header. Uses SHA512. | |
| 56 | + * | |
| 57 | + */ | |
| 58 | +void PersistenceFile::writeCheck(const void *ptr, size_t size, size_t n, FILE *s) | |
| 59 | +{ | |
| 60 | + if (fwrite(ptr, size, n, s) != n) | |
| 61 | + { | |
| 62 | + throw std::runtime_error(formatString("Error writing: %s", strerror(errno))); | |
| 63 | + } | |
| 64 | +} | |
| 65 | + | |
| 66 | +ssize_t PersistenceFile::readCheck(void *ptr, size_t size, size_t n, FILE *stream) | |
| 67 | +{ | |
| 68 | + size_t nread = fread(ptr, size, n, stream); | |
| 69 | + | |
| 70 | + if (nread != n) | |
| 71 | + { | |
| 72 | + if (feof(f)) | |
| 73 | + return -1; | |
| 74 | + | |
| 75 | + throw std::runtime_error(formatString("Error reading: %s", strerror(errno))); | |
| 76 | + } | |
| 77 | + | |
| 78 | + return nread; | |
| 79 | +} | |
| 80 | + | |
| 81 | +void PersistenceFile::hashFile() | |
| 82 | +{ | |
| 83 | + logger->logf(LOG_DEBUG, "Calculating and saving hash of '%s'.", filePath.c_str()); | |
| 84 | + | |
| 85 | + fseek(f, TOTAL_HEADER_SIZE, SEEK_SET); | |
| 86 | + | |
| 87 | + unsigned int output_len = 0; | |
| 88 | + unsigned char md_value[EVP_MAX_MD_SIZE]; | |
| 89 | + std::memset(md_value, 0, EVP_MAX_MD_SIZE); | |
| 90 | + | |
| 91 | + EVP_MD_CTX_reset(digestContext); | |
| 92 | + EVP_DigestInit_ex(digestContext, sha512, NULL); | |
| 93 | + | |
| 94 | + while (!feof(f)) | |
| 95 | + { | |
| 96 | + size_t n = fread(buf.data(), 1, buf.size(), f); | |
| 97 | + EVP_DigestUpdate(digestContext, buf.data(), n); | |
| 98 | + } | |
| 99 | + | |
| 100 | + EVP_DigestFinal_ex(digestContext, md_value, &output_len); | |
| 101 | + | |
| 102 | + if (output_len != HASH_SIZE) | |
| 103 | + throw std::runtime_error("Impossible: calculated hash size wrong length"); | |
| 104 | + | |
| 105 | + fseek(f, MAGIC_STRING_LENGH, SEEK_SET); | |
| 106 | + | |
| 107 | + writeCheck(md_value, output_len, 1, f); | |
| 108 | + fflush(f); | |
| 109 | +} | |
| 110 | + | |
| 111 | +void PersistenceFile::verifyHash() | |
| 112 | +{ | |
| 113 | + fseek(f, 0, SEEK_END); | |
| 114 | + const size_t size = ftell(f); | |
| 115 | + | |
| 116 | + if (size < TOTAL_HEADER_SIZE) | |
| 117 | + throw std::runtime_error(formatString("File '%s' is too small for it even to contain a header.", filePath.c_str())); | |
| 118 | + | |
| 119 | + unsigned char md_from_disk[HASH_SIZE]; | |
| 120 | + std::memset(md_from_disk, 0, HASH_SIZE); | |
| 121 | + | |
| 122 | + fseek(f, MAGIC_STRING_LENGH, SEEK_SET); | |
| 123 | + fread(md_from_disk, 1, HASH_SIZE, f); | |
| 124 | + | |
| 125 | + unsigned int output_len = 0; | |
| 126 | + unsigned char md_value[EVP_MAX_MD_SIZE]; | |
| 127 | + std::memset(md_value, 0, EVP_MAX_MD_SIZE); | |
| 128 | + | |
| 129 | + EVP_MD_CTX_reset(digestContext); | |
| 130 | + EVP_DigestInit_ex(digestContext, sha512, NULL); | |
| 131 | + | |
| 132 | + while (!feof(f)) | |
| 133 | + { | |
| 134 | + size_t n = fread(buf.data(), 1, buf.size(), f); | |
| 135 | + EVP_DigestUpdate(digestContext, buf.data(), n); | |
| 136 | + } | |
| 137 | + | |
| 138 | + EVP_DigestFinal_ex(digestContext, md_value, &output_len); | |
| 139 | + | |
| 140 | + if (output_len != HASH_SIZE) | |
| 141 | + throw std::runtime_error("Impossible: calculated hash size wrong length"); | |
| 142 | + | |
| 143 | + if (std::memcmp(md_from_disk, md_value, output_len) != 0) | |
| 144 | + { | |
| 145 | + fclose(f); | |
| 146 | + f = nullptr; | |
| 147 | + | |
| 148 | + if (rename(filePath.c_str(), filePathCorrupt.c_str()) == 0) | |
| 149 | + { | |
| 150 | + throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Moved aside to '%s'.", filePath.c_str(), filePathCorrupt.c_str())); | |
| 151 | + } | |
| 152 | + else | |
| 153 | + { | |
| 154 | + throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Tried to move aside, but that failed: '%s'.", | |
| 155 | + filePath.c_str(), strerror(errno))); | |
| 156 | + } | |
| 157 | + } | |
| 158 | + | |
| 159 | + logger->logf(LOG_DEBUG, "Hash of '%s' correct", filePath.c_str()); | |
| 160 | +} | |
| 161 | + | |
| 162 | +/** | |
| 163 | + * @brief PersistenceFile::makeSureBufSize grows the buffer if n is bigger. | |
| 164 | + * @param n in bytes. | |
| 165 | + * | |
| 166 | + * Remember that when you're dealing with fields that are sized in MQTT by 16 bit ints, like topic paths, the buffer will always be big enough, because it's 1 MB. | |
| 167 | + */ | |
| 168 | +void PersistenceFile::makeSureBufSize(size_t n) | |
| 169 | +{ | |
| 170 | + if (n > buf.size()) | |
| 171 | + buf.resize(n); | |
| 172 | +} | |
| 173 | + | |
| 174 | +void PersistenceFile::writeInt64(const int64_t val) | |
| 175 | +{ | |
| 176 | + unsigned char buf[8]; | |
| 177 | + | |
| 178 | + // Write big-endian | |
| 179 | + int shift = 56; | |
| 180 | + int i = 0; | |
| 181 | + while (shift >= 0) | |
| 182 | + { | |
| 183 | + unsigned char wantedByte = val >> shift; | |
| 184 | + buf[i++] = wantedByte; | |
| 185 | + shift -= 8; | |
| 186 | + } | |
| 187 | + writeCheck(buf, 1, 8, f); | |
| 188 | +} | |
| 189 | + | |
| 190 | +void PersistenceFile::writeUint32(const uint32_t val) | |
| 191 | +{ | |
| 192 | + unsigned char buf[4]; | |
| 193 | + | |
| 194 | + // Write big-endian | |
| 195 | + int shift = 24; | |
| 196 | + int i = 0; | |
| 197 | + while (shift >= 0) | |
| 198 | + { | |
| 199 | + unsigned char wantedByte = val >> shift; | |
| 200 | + buf[i++] = wantedByte; | |
| 201 | + shift -= 8; | |
| 202 | + } | |
| 203 | + writeCheck(buf, 1, 4, f); | |
| 204 | +} | |
| 205 | + | |
| 206 | +void PersistenceFile::writeUint16(const uint16_t val) | |
| 207 | +{ | |
| 208 | + unsigned char buf[2]; | |
| 209 | + | |
| 210 | + // Write big-endian | |
| 211 | + int shift = 8; | |
| 212 | + int i = 0; | |
| 213 | + while (shift >= 0) | |
| 214 | + { | |
| 215 | + unsigned char wantedByte = val >> shift; | |
| 216 | + buf[i++] = wantedByte; | |
| 217 | + shift -= 8; | |
| 218 | + } | |
| 219 | + writeCheck(buf, 1, 2, f); | |
| 220 | +} | |
| 221 | + | |
| 222 | +int64_t PersistenceFile::readInt64(bool &eofFound) | |
| 223 | +{ | |
| 224 | + if (readCheck(buf.data(), 1, 8, f) < 0) | |
| 225 | + eofFound = true; | |
| 226 | + | |
| 227 | + unsigned char *buf_ = reinterpret_cast<unsigned char *>(buf.data()); | |
| 228 | + const uint64_t val1 = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]); | |
| 229 | + const uint64_t val2 = ((buf_[4]) << 24) | ((buf_[5]) << 16) | ((buf_[6]) << 8) | (buf_[7]); | |
| 230 | + const int64_t val = (val1 << 32) | val2; | |
| 231 | + return val; | |
| 232 | +} | |
| 233 | + | |
| 234 | +uint32_t PersistenceFile::readUint32(bool &eofFound) | |
| 235 | +{ | |
| 236 | + if (readCheck(buf.data(), 1, 4, f) < 0) | |
| 237 | + eofFound = true; | |
| 238 | + | |
| 239 | + uint32_t val; | |
| 240 | + unsigned char *buf_ = reinterpret_cast<unsigned char *>(buf.data()); | |
| 241 | + val = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]); | |
| 242 | + return val; | |
| 243 | +} | |
| 244 | + | |
| 245 | +uint16_t PersistenceFile::readUint16(bool &eofFound) | |
| 246 | +{ | |
| 247 | + if (readCheck(buf.data(), 1, 2, f) < 0) | |
| 248 | + eofFound = true; | |
| 249 | + | |
| 250 | + uint16_t val; | |
| 251 | + unsigned char *buf_ = reinterpret_cast<unsigned char *>(buf.data()); | |
| 252 | + val = ((buf_[0]) << 8) | (buf_[1]); | |
| 253 | + return val; | |
| 254 | +} | |
| 255 | + | |
| 256 | +/** | |
| 257 | + * @brief RetainedMessagesDB::openWrite doesn't explicitely name a file version (v1, etc), because we always write the current definition. | |
| 258 | + */ | |
| 259 | +void PersistenceFile::openWrite(const std::string &versionString) | |
| 260 | +{ | |
| 261 | + if (openMode != FileMode::unknown) | |
| 262 | + throw std::runtime_error("File is already open."); | |
| 263 | + | |
| 264 | + f = fopen(filePathTemp.c_str(), "w+b"); | |
| 265 | + | |
| 266 | + if (f == nullptr) | |
| 267 | + { | |
| 268 | + throw std::runtime_error(formatString("Can't open '%s': %s", filePathTemp.c_str(), strerror(errno))); | |
| 269 | + } | |
| 270 | + | |
| 271 | + openMode = FileMode::write; | |
| 272 | + | |
| 273 | + writeCheck(buf.data(), 1, MAGIC_STRING_LENGH, f); | |
| 274 | + rewind(f); | |
| 275 | + writeCheck(versionString.c_str(), 1, versionString.length(), f); | |
| 276 | + fseek(f, MAGIC_STRING_LENGH, SEEK_SET); | |
| 277 | + writeCheck(buf.data(), 1, HASH_SIZE, f); | |
| 278 | +} | |
| 279 | + | |
| 280 | +void PersistenceFile::openRead() | |
| 281 | +{ | |
| 282 | + if (openMode != FileMode::unknown) | |
| 283 | + throw std::runtime_error("File is already open."); | |
| 284 | + | |
| 285 | + f = fopen(filePath.c_str(), "rb"); | |
| 286 | + | |
| 287 | + if (f == nullptr) | |
| 288 | + throw PersistenceFileCantBeOpened(formatString("Can't open '%s': %s.", filePath.c_str(), strerror(errno)).c_str()); | |
| 289 | + | |
| 290 | + openMode = FileMode::read; | |
| 291 | + | |
| 292 | + verifyHash(); | |
| 293 | + rewind(f); | |
| 294 | + | |
| 295 | + fread(buf.data(), 1, MAGIC_STRING_LENGH, f); | |
| 296 | + detectedVersionString = std::string(buf.data(), strlen(buf.data())); | |
| 297 | + | |
| 298 | + fseek(f, TOTAL_HEADER_SIZE, SEEK_SET); | |
| 299 | +} | |
| 300 | + | |
| 301 | +void PersistenceFile::closeFile() | |
| 302 | +{ | |
| 303 | + if (!f) | |
| 304 | + return; | |
| 305 | + | |
| 306 | + if (openMode == FileMode::write) | |
| 307 | + hashFile(); | |
| 308 | + | |
| 309 | + if (f != nullptr) | |
| 310 | + { | |
| 311 | + fclose(f); | |
| 312 | + f = nullptr; | |
| 313 | + } | |
| 314 | + | |
| 315 | + if (openMode == FileMode::write && !filePathTemp.empty() && ! filePath.empty()) | |
| 316 | + { | |
| 317 | + if (rename(filePathTemp.c_str(), filePath.c_str()) < 0) | |
| 318 | + throw std::runtime_error(formatString("Saving '%s' failed: rename of temp file to target failed with: %s", filePath.c_str(), strerror(errno))); | |
| 319 | + } | |
| 320 | +} | ... | ... |
persistencefile.h
0 โ 100644
| 1 | +/* | |
| 2 | +This file is part of FlashMQ (https://www.flashmq.org) | |
| 3 | +Copyright (C) 2021 Wiebe Cazemier | |
| 4 | + | |
| 5 | +FlashMQ is free software: you can redistribute it and/or modify | |
| 6 | +it under the terms of the GNU Affero General Public License as | |
| 7 | +published by the Free Software Foundation, version 3. | |
| 8 | + | |
| 9 | +FlashMQ is distributed in the hope that it will be useful, | |
| 10 | +but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 11 | +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 12 | +GNU Affero General Public License for more details. | |
| 13 | + | |
| 14 | +You should have received a copy of the GNU Affero General Public | |
| 15 | +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | |
| 16 | +*/ | |
| 17 | + | |
| 18 | +#ifndef PERSISTENCEFILE_H | |
| 19 | +#define PERSISTENCEFILE_H | |
| 20 | + | |
| 21 | +#include <vector> | |
| 22 | +#include <list> | |
| 23 | +#include <string> | |
| 24 | +#include <stdio.h> | |
| 25 | +#include <openssl/evp.h> | |
| 26 | +#include <stdexcept> | |
| 27 | +#include <cstring> | |
| 28 | + | |
| 29 | +#include "logger.h" | |
| 30 | + | |
| 31 | +#define MAGIC_STRING_LENGH 32 | |
| 32 | +#define HASH_SIZE 64 | |
| 33 | +#define TOTAL_HEADER_SIZE (MAGIC_STRING_LENGH + HASH_SIZE) | |
| 34 | + | |
| 35 | +/** | |
| 36 | + * @brief The PersistenceFileCantBeOpened class should be thrown when a non-fatal file-not-found error happens. | |
| 37 | + */ | |
| 38 | +class PersistenceFileCantBeOpened : public std::runtime_error | |
| 39 | +{ | |
| 40 | +public: | |
| 41 | + PersistenceFileCantBeOpened(const std::string &msg) : std::runtime_error(msg) {} | |
| 42 | +}; | |
| 43 | + | |
| 44 | +class PersistenceFile | |
| 45 | +{ | |
| 46 | + std::string filePath; | |
| 47 | + std::string filePathTemp; | |
| 48 | + std::string filePathCorrupt; | |
| 49 | + | |
| 50 | + EVP_MD_CTX *digestContext = nullptr; | |
| 51 | + const EVP_MD *sha512 = EVP_sha512(); | |
| 52 | + | |
| 53 | + void hashFile(); | |
| 54 | + void verifyHash(); | |
| 55 | + | |
| 56 | +protected: | |
| 57 | + enum class FileMode | |
| 58 | + { | |
| 59 | + unknown, | |
| 60 | + read, | |
| 61 | + write | |
| 62 | + }; | |
| 63 | + | |
| 64 | + FILE *f = nullptr; | |
| 65 | + std::vector<char> buf; | |
| 66 | + FileMode openMode = FileMode::unknown; | |
| 67 | + std::string detectedVersionString; | |
| 68 | + | |
| 69 | + Logger *logger = Logger::getInstance(); | |
| 70 | + | |
| 71 | + void makeSureBufSize(size_t n); | |
| 72 | + | |
| 73 | + void writeCheck(const void *__restrict __ptr, size_t __size, size_t __n, FILE *__restrict __s); | |
| 74 | + ssize_t readCheck(void *__restrict ptr, size_t size, size_t n, FILE *__restrict stream); | |
| 75 | + | |
| 76 | + void writeInt64(const int64_t val); | |
| 77 | + void writeUint32(const uint32_t val); | |
| 78 | + void writeUint16(const uint16_t val); | |
| 79 | + int64_t readInt64(bool &eofFound); | |
| 80 | + uint32_t readUint32(bool &eofFound); | |
| 81 | + uint16_t readUint16(bool &eofFound); | |
| 82 | + | |
| 83 | +public: | |
| 84 | + PersistenceFile(const std::string &filePath); | |
| 85 | + ~PersistenceFile(); | |
| 86 | + | |
| 87 | + void openWrite(const std::string &versionString); | |
| 88 | + void openRead(); | |
| 89 | + void closeFile(); | |
| 90 | +}; | |
| 91 | + | |
| 92 | +#endif // PERSISTENCEFILE_H | ... | ... |
qospacketqueue.cpp
0 โ 100644
| 1 | +#include "qospacketqueue.h" | |
| 2 | + | |
| 3 | +#include "cassert" | |
| 4 | + | |
| 5 | +#include "mqttpacket.h" | |
| 6 | + | |
| 7 | +void QoSPacketQueue::erase(const uint16_t packet_id) | |
| 8 | +{ | |
| 9 | + auto it = queue.begin(); | |
| 10 | + auto end = queue.end(); | |
| 11 | + while (it != end) | |
| 12 | + { | |
| 13 | + std::shared_ptr<MqttPacket> &p = *it; | |
| 14 | + if (p->getPacketId() == packet_id) | |
| 15 | + { | |
| 16 | + size_t mem = p->getTotalMemoryFootprint(); | |
| 17 | + qosQueueBytes -= mem; | |
| 18 | + assert(qosQueueBytes >= 0); | |
| 19 | + if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. | |
| 20 | + qosQueueBytes = 0; | |
| 21 | + | |
| 22 | + queue.erase(it); | |
| 23 | + | |
| 24 | + break; | |
| 25 | + } | |
| 26 | + | |
| 27 | + it++; | |
| 28 | + } | |
| 29 | +} | |
| 30 | + | |
| 31 | +size_t QoSPacketQueue::size() const | |
| 32 | +{ | |
| 33 | + return queue.size(); | |
| 34 | +} | |
| 35 | + | |
| 36 | +size_t QoSPacketQueue::getByteSize() const | |
| 37 | +{ | |
| 38 | + return qosQueueBytes; | |
| 39 | +} | |
| 40 | + | |
| 41 | +/** | |
| 42 | + * @brief QoSPacketQueue::queuePacket makes a copy of the packet because it has state for the receiver in question. | |
| 43 | + * @param p | |
| 44 | + * @param id | |
| 45 | + * @return the packet copy. | |
| 46 | + */ | |
| 47 | +std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id) | |
| 48 | +{ | |
| 49 | + assert(p.getQos() > 0); | |
| 50 | + | |
| 51 | + std::shared_ptr<MqttPacket> copyPacket = p.getCopy(); | |
| 52 | + copyPacket->setPacketId(id); | |
| 53 | + queue.push_back(copyPacket); | |
| 54 | + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | |
| 55 | + return copyPacket; | |
| 56 | +} | |
| 57 | + | |
| 58 | +std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const Publish &pub, uint16_t id) | |
| 59 | +{ | |
| 60 | + assert(pub.qos > 0); | |
| 61 | + | |
| 62 | + std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(pub)); | |
| 63 | + copyPacket->setPacketId(id); | |
| 64 | + queue.push_back(copyPacket); | |
| 65 | + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | |
| 66 | + return copyPacket; | |
| 67 | +} | |
| 68 | + | |
| 69 | +std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::begin() const | |
| 70 | +{ | |
| 71 | + return queue.cbegin(); | |
| 72 | +} | |
| 73 | + | |
| 74 | +std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::end() const | |
| 75 | +{ | |
| 76 | + return queue.cend(); | |
| 77 | +} | ... | ... |
qospacketqueue.h
0 โ 100644
| 1 | +#ifndef QOSPACKETQUEUE_H | |
| 2 | +#define QOSPACKETQUEUE_H | |
| 3 | + | |
| 4 | +#include "list" | |
| 5 | + | |
| 6 | +#include "forward_declarations.h" | |
| 7 | +#include "types.h" | |
| 8 | + | |
| 9 | +class QoSPacketQueue | |
| 10 | +{ | |
| 11 | + std::list<std::shared_ptr<MqttPacket>> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] | |
| 12 | + ssize_t qosQueueBytes = 0; | |
| 13 | + | |
| 14 | +public: | |
| 15 | + void erase(const uint16_t packet_id); | |
| 16 | + size_t size() const; | |
| 17 | + size_t getByteSize() const; | |
| 18 | + std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id); | |
| 19 | + std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id); | |
| 20 | + | |
| 21 | + std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const; | |
| 22 | + std::list<std::shared_ptr<MqttPacket>>::const_iterator end() const; | |
| 23 | +}; | |
| 24 | + | |
| 25 | +#endif // QOSPACKETQUEUE_H | ... | ... |
retainedmessage.cpp
retainedmessage.h
retainedmessagesdb.cpp
0 โ 100644
| 1 | +/* | |
| 2 | +This file is part of FlashMQ (https://www.flashmq.org) | |
| 3 | +Copyright (C) 2021 Wiebe Cazemier | |
| 4 | + | |
| 5 | +FlashMQ is free software: you can redistribute it and/or modify | |
| 6 | +it under the terms of the GNU Affero General Public License as | |
| 7 | +published by the Free Software Foundation, version 3. | |
| 8 | + | |
| 9 | +FlashMQ is distributed in the hope that it will be useful, | |
| 10 | +but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 11 | +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 12 | +GNU Affero General Public License for more details. | |
| 13 | + | |
| 14 | +You should have received a copy of the GNU Affero General Public | |
| 15 | +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | |
| 16 | +*/ | |
| 17 | + | |
| 18 | +#include <sys/types.h> | |
| 19 | +#include <sys/stat.h> | |
| 20 | +#include <fcntl.h> | |
| 21 | +#include <unistd.h> | |
| 22 | +#include <exception> | |
| 23 | +#include <stdexcept> | |
| 24 | +#include <stdio.h> | |
| 25 | +#include <cstring> | |
| 26 | + | |
| 27 | +#include "retainedmessagesdb.h" | |
| 28 | +#include "utils.h" | |
| 29 | +#include "logger.h" | |
| 30 | + | |
| 31 | +RetainedMessagesDB::RetainedMessagesDB(const std::string &filePath) : PersistenceFile(filePath) | |
| 32 | +{ | |
| 33 | + | |
| 34 | +} | |
| 35 | + | |
| 36 | +void RetainedMessagesDB::openWrite() | |
| 37 | +{ | |
| 38 | + PersistenceFile::openWrite(MAGIC_STRING_V1); | |
| 39 | +} | |
| 40 | + | |
| 41 | +void RetainedMessagesDB::openRead() | |
| 42 | +{ | |
| 43 | + PersistenceFile::openRead(); | |
| 44 | + | |
| 45 | + if (detectedVersionString == MAGIC_STRING_V1) | |
| 46 | + readVersion = ReadVersion::v1; | |
| 47 | + else | |
| 48 | + throw std::runtime_error("Unknown file version."); | |
| 49 | +} | |
| 50 | + | |
| 51 | +/** | |
| 52 | + * @brief RetainedMessagesDB::writeRowHeader writes two 32 bit integers: topic size and payload size. | |
| 53 | + * @param rm | |
| 54 | + * | |
| 55 | + * So, the header per message is 8 bytes long. | |
| 56 | + * | |
| 57 | + * It writes no information about the length of the QoS value, because that is always one. | |
| 58 | + */ | |
| 59 | +void RetainedMessagesDB::writeRowHeader(const RetainedMessage &rm) | |
| 60 | +{ | |
| 61 | + writeUint32(rm.topic.size()); | |
| 62 | + writeUint32(rm.payload.size()); | |
| 63 | +} | |
| 64 | + | |
| 65 | +RetainedMessagesDB::RowHeader RetainedMessagesDB::readRowHeaderV1(bool &eofFound) | |
| 66 | +{ | |
| 67 | + RetainedMessagesDB::RowHeader result; | |
| 68 | + | |
| 69 | + result.topicLen = readUint32(eofFound); | |
| 70 | + result.payloadLen = readUint32(eofFound); | |
| 71 | + | |
| 72 | + return result; | |
| 73 | +} | |
| 74 | + | |
| 75 | +/** | |
| 76 | + * @brief RetainedMessagesDB::saveData doesn't explicitely name a file version (v1, etc), because we always write the current definition. | |
| 77 | + * @param messages | |
| 78 | + */ | |
| 79 | +void RetainedMessagesDB::saveData(const std::vector<RetainedMessage> &messages) | |
| 80 | +{ | |
| 81 | + if (!f) | |
| 82 | + return; | |
| 83 | + | |
| 84 | + char reserved[RESERVED_SPACE_RETAINED_DB_V1]; | |
| 85 | + std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V1); | |
| 86 | + | |
| 87 | + char qos = 0; | |
| 88 | + for (const RetainedMessage &rm : messages) | |
| 89 | + { | |
| 90 | + logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d.", rm.topic.c_str(), rm.qos); | |
| 91 | + | |
| 92 | + writeRowHeader(rm); | |
| 93 | + qos = rm.qos; | |
| 94 | + writeCheck(&qos, 1, 1, f); | |
| 95 | + writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V1, f); | |
| 96 | + writeCheck(rm.topic.c_str(), 1, rm.topic.length(), f); | |
| 97 | + writeCheck(rm.payload.c_str(), 1, rm.payload.length(), f); | |
| 98 | + } | |
| 99 | + | |
| 100 | + fflush(f); | |
| 101 | +} | |
| 102 | + | |
| 103 | +std::list<RetainedMessage> RetainedMessagesDB::readData() | |
| 104 | +{ | |
| 105 | + std::list<RetainedMessage> defaultResult; | |
| 106 | + | |
| 107 | + if (!f) | |
| 108 | + return defaultResult; | |
| 109 | + | |
| 110 | + if (readVersion == ReadVersion::v1) | |
| 111 | + return readDataV1(); | |
| 112 | + | |
| 113 | + return defaultResult; | |
| 114 | +} | |
| 115 | + | |
| 116 | +std::list<RetainedMessage> RetainedMessagesDB::readDataV1() | |
| 117 | +{ | |
| 118 | + std::list<RetainedMessage> messages; | |
| 119 | + | |
| 120 | + while (!feof(f)) | |
| 121 | + { | |
| 122 | + bool eofFound = false; | |
| 123 | + RetainedMessagesDB::RowHeader header = readRowHeaderV1(eofFound); | |
| 124 | + | |
| 125 | + if (eofFound) | |
| 126 | + continue; | |
| 127 | + | |
| 128 | + makeSureBufSize(header.payloadLen); | |
| 129 | + | |
| 130 | + readCheck(buf.data(), 1, 1, f); | |
| 131 | + char qos = buf[0]; | |
| 132 | + fseek(f, RESERVED_SPACE_RETAINED_DB_V1, SEEK_CUR); | |
| 133 | + | |
| 134 | + readCheck(buf.data(), 1, header.topicLen, f); | |
| 135 | + std::string topic(buf.data(), header.topicLen); | |
| 136 | + | |
| 137 | + readCheck(buf.data(), 1, header.payloadLen, f); | |
| 138 | + std::string payload(buf.data(), header.payloadLen); | |
| 139 | + | |
| 140 | + RetainedMessage msg(topic, payload, qos); | |
| 141 | + logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.topic.c_str(), msg.qos); | |
| 142 | + messages.push_back(std::move(msg)); | |
| 143 | + } | |
| 144 | + | |
| 145 | + return messages; | |
| 146 | +} | ... | ... |
retainedmessagesdb.h
0 โ 100644
| 1 | +/* | |
| 2 | +This file is part of FlashMQ (https://www.flashmq.org) | |
| 3 | +Copyright (C) 2021 Wiebe Cazemier | |
| 4 | + | |
| 5 | +FlashMQ is free software: you can redistribute it and/or modify | |
| 6 | +it under the terms of the GNU Affero General Public License as | |
| 7 | +published by the Free Software Foundation, version 3. | |
| 8 | + | |
| 9 | +FlashMQ is distributed in the hope that it will be useful, | |
| 10 | +but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 11 | +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 12 | +GNU Affero General Public License for more details. | |
| 13 | + | |
| 14 | +You should have received a copy of the GNU Affero General Public | |
| 15 | +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | |
| 16 | +*/ | |
| 17 | + | |
| 18 | +#ifndef RETAINEDMESSAGESDB_H | |
| 19 | +#define RETAINEDMESSAGESDB_H | |
| 20 | + | |
| 21 | +#include "persistencefile.h" | |
| 22 | +#include "retainedmessage.h" | |
| 23 | + | |
| 24 | +#include "logger.h" | |
| 25 | + | |
| 26 | +#define MAGIC_STRING_V1 "FlashMQRetainedDBv1" | |
| 27 | +#define ROW_HEADER_SIZE 8 | |
| 28 | +#define RESERVED_SPACE_RETAINED_DB_V1 31 | |
| 29 | + | |
| 30 | +/** | |
| 31 | + * @brief The RetainedMessagesDB class saves and loads the retained messages. | |
| 32 | + * | |
| 33 | + * The DB looks like, from the top: | |
| 34 | + * | |
| 35 | + * MAGIC_STRING_LENGH bytes file header | |
| 36 | + * HASH_SIZE SHA512 | |
| 37 | + * [MESSAGES] | |
| 38 | + * | |
| 39 | + * Each message has a row header, which is 8 bytes. See writeRowHeader(). | |
| 40 | + * | |
| 41 | + */ | |
| 42 | +class RetainedMessagesDB : public PersistenceFile | |
| 43 | +{ | |
| 44 | + enum class ReadVersion | |
| 45 | + { | |
| 46 | + unknown, | |
| 47 | + v1 | |
| 48 | + }; | |
| 49 | + | |
| 50 | + struct RowHeader | |
| 51 | + { | |
| 52 | + uint32_t topicLen = 0; | |
| 53 | + uint32_t payloadLen = 0; | |
| 54 | + }; | |
| 55 | + | |
| 56 | + ReadVersion readVersion = ReadVersion::unknown; | |
| 57 | + | |
| 58 | + void writeRowHeader(const RetainedMessage &rm); | |
| 59 | + RowHeader readRowHeaderV1(bool &eofFound); | |
| 60 | + std::list<RetainedMessage> readDataV1(); | |
| 61 | +public: | |
| 62 | + RetainedMessagesDB(const std::string &filePath); | |
| 63 | + | |
| 64 | + void openWrite(); | |
| 65 | + void openRead(); | |
| 66 | + | |
| 67 | + void saveData(const std::vector<RetainedMessage> &messages); | |
| 68 | + std::list<RetainedMessage> readData(); | |
| 69 | +}; | |
| 70 | + | |
| 71 | +#endif // RETAINEDMESSAGESDB_H | ... | ... |
rwlockguard.cpp
| ... | ... | @@ -32,7 +32,15 @@ RWLockGuard::~RWLockGuard() |
| 32 | 32 | |
| 33 | 33 | void RWLockGuard::wrlock() |
| 34 | 34 | { |
| 35 | - if (pthread_rwlock_wrlock(rwlock) != 0) | |
| 35 | + const int rc = pthread_rwlock_wrlock(rwlock); | |
| 36 | + | |
| 37 | + if (rc == EDEADLK) | |
| 38 | + { | |
| 39 | + rwlock = nullptr; | |
| 40 | + return; | |
| 41 | + } | |
| 42 | + | |
| 43 | + if (rc != 0) | |
| 36 | 44 | throw std::runtime_error("wrlock failed."); |
| 37 | 45 | } |
| 38 | 46 | ... | ... |
session.cpp
| ... | ... | @@ -20,16 +20,75 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 20 | 20 | #include "session.h" |
| 21 | 21 | #include "client.h" |
| 22 | 22 | |
| 23 | +std::chrono::time_point<std::chrono::steady_clock> appStartTime = std::chrono::steady_clock::now(); | |
| 24 | + | |
| 23 | 25 | Session::Session() |
| 24 | 26 | { |
| 25 | 27 | |
| 26 | 28 | } |
| 27 | 29 | |
| 30 | +int64_t Session::getProgramStartedAtUnixTimestamp() | |
| 31 | +{ | |
| 32 | + auto secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count(); | |
| 33 | + const std::chrono::seconds age = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - appStartTime); | |
| 34 | + int64_t result = secondsSinceEpoch - age.count(); | |
| 35 | + return result; | |
| 36 | +} | |
| 37 | + | |
| 38 | +void Session::setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp) | |
| 39 | +{ | |
| 40 | + auto secondsSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()); | |
| 41 | + const std::chrono::seconds _unix_timestamp = std::chrono::seconds(unix_timestamp); | |
| 42 | + const std::chrono::seconds age_in_s = secondsSinceEpoch - _unix_timestamp; | |
| 43 | + appStartTime = std::chrono::steady_clock::now() - age_in_s; | |
| 44 | +} | |
| 45 | + | |
| 46 | + | |
| 47 | +int64_t Session::getSessionRelativeAgeInMs() const | |
| 48 | +{ | |
| 49 | + const std::chrono::milliseconds sessionAge = std::chrono::duration_cast<std::chrono::milliseconds>(lastTouched - appStartTime); | |
| 50 | + const int64_t sInMs = sessionAge.count(); | |
| 51 | + return sInMs; | |
| 52 | +} | |
| 53 | + | |
| 54 | +void Session::setSessionTouch(int64_t ageInMs) | |
| 55 | +{ | |
| 56 | + std::chrono::milliseconds ms(ageInMs); | |
| 57 | + std::chrono::time_point<std::chrono::steady_clock> point = appStartTime + ms; | |
| 58 | + lastTouched = point; | |
| 59 | +} | |
| 60 | + | |
| 61 | +/** | |
| 62 | + * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies. | |
| 63 | + * @param other | |
| 64 | + * | |
| 65 | + * Because it was created for session storing, the fields we're copying are the fields being stored. | |
| 66 | + */ | |
| 67 | +Session::Session(const Session &other) | |
| 68 | +{ | |
| 69 | + this->username = other.username; | |
| 70 | + this->client_id = other.client_id; | |
| 71 | + this->incomingQoS2MessageIds = other.incomingQoS2MessageIds; | |
| 72 | + this->outgoingQoS2MessageIds = other.outgoingQoS2MessageIds; | |
| 73 | + this->nextPacketId = other.nextPacketId; | |
| 74 | + this->lastTouched = other.lastTouched; | |
| 75 | + | |
| 76 | + // To be fully correct, we should copy the individual packets, but copying sessions is only done for saving them, and I know | |
| 77 | + // that no member of MqttPacket changes in the QoS process, so we can just keep the shared pointer to the original. | |
| 78 | + this->qosPacketQueue = other.qosPacketQueue; | |
| 79 | +} | |
| 80 | + | |
| 28 | 81 | Session::~Session() |
| 29 | 82 | { |
| 30 | 83 | logger->logf(LOG_DEBUG, "Session %s is being destroyed.", getClientId().c_str()); |
| 31 | 84 | } |
| 32 | 85 | |
| 86 | +std::unique_ptr<Session> Session::getCopy() const | |
| 87 | +{ | |
| 88 | + std::unique_ptr<Session> s(new Session(*this)); | |
| 89 | + return s; | |
| 90 | +} | |
| 91 | + | |
| 33 | 92 | bool Session::clientDisconnected() const |
| 34 | 93 | { |
| 35 | 94 | return client.expired(); |
| ... | ... | @@ -45,7 +104,6 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) |
| 45 | 104 | this->client = client; |
| 46 | 105 | this->client_id = client->getClientId(); |
| 47 | 106 | this->username = client->getUsername(); |
| 48 | - this->thread = client->getThreadData(); | |
| 49 | 107 | } |
| 50 | 108 | |
| 51 | 109 | /** |
| ... | ... | @@ -60,7 +118,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u |
| 60 | 118 | assert(max_qos <= 2); |
| 61 | 119 | const char qos = std::min<char>(packet.getQos(), max_qos); |
| 62 | 120 | |
| 63 | - if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) | |
| 121 | + assert(packet.getSender()); | |
| 122 | + Authentication &auth = packet.getSender()->getThreadData()->authentication; | |
| 123 | + if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) | |
| 64 | 124 | { |
| 65 | 125 | if (qos == 0) |
| 66 | 126 | { |
| ... | ... | @@ -73,11 +133,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u |
| 73 | 133 | } |
| 74 | 134 | else if (qos > 0) |
| 75 | 135 | { |
| 76 | - std::shared_ptr<MqttPacket> copyPacket = packet.getCopy(); | |
| 77 | 136 | std::unique_lock<std::mutex> locker(qosQueueMutex); |
| 78 | 137 | |
| 79 | 138 | const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); |
| 80 | - if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | |
| 139 | + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | |
| 81 | 140 | { |
| 82 | 141 | logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); |
| 83 | 142 | return; |
| ... | ... | @@ -86,13 +145,7 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u |
| 86 | 145 | if (nextPacketId == 0) |
| 87 | 146 | nextPacketId++; |
| 88 | 147 | |
| 89 | - const uint16_t pid = nextPacketId; | |
| 90 | - copyPacket->setPacketId(pid); | |
| 91 | - QueuedQosPacket p; | |
| 92 | - p.packet = copyPacket; | |
| 93 | - p.id = pid; | |
| 94 | - qosPacketQueue.push_back(p); | |
| 95 | - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | |
| 148 | + std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); | |
| 96 | 149 | locker.unlock(); |
| 97 | 150 | |
| 98 | 151 | if (!clientDisconnected()) |
| ... | ... | @@ -115,27 +168,7 @@ void Session::clearQosMessage(uint16_t packet_id) |
| 115 | 168 | #endif |
| 116 | 169 | |
| 117 | 170 | std::lock_guard<std::mutex> locker(qosQueueMutex); |
| 118 | - | |
| 119 | - auto it = qosPacketQueue.begin(); | |
| 120 | - auto end = qosPacketQueue.end(); | |
| 121 | - while (it != end) | |
| 122 | - { | |
| 123 | - QueuedQosPacket &p = *it; | |
| 124 | - if (p.id == packet_id) | |
| 125 | - { | |
| 126 | - size_t mem = p.packet->getTotalMemoryFootprint(); | |
| 127 | - qosQueueBytes -= mem; | |
| 128 | - assert(qosQueueBytes >= 0); | |
| 129 | - if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. | |
| 130 | - qosQueueBytes = 0; | |
| 131 | - | |
| 132 | - qosPacketQueue.erase(it); | |
| 133 | - | |
| 134 | - break; | |
| 135 | - } | |
| 136 | - | |
| 137 | - it++; | |
| 138 | - } | |
| 171 | + qosPacketQueue.erase(packet_id); | |
| 139 | 172 | } |
| 140 | 173 | |
| 141 | 174 | // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any |
| ... | ... | @@ -152,10 +185,10 @@ uint64_t Session::sendPendingQosMessages() |
| 152 | 185 | { |
| 153 | 186 | std::shared_ptr<Client> c = makeSharedClient(); |
| 154 | 187 | std::lock_guard<std::mutex> locker(qosQueueMutex); |
| 155 | - for (QueuedQosPacket &qosMessage : qosPacketQueue) | |
| 188 | + for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) | |
| 156 | 189 | { |
| 157 | - c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos()); | |
| 158 | - qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 190 | + c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos()); | |
| 191 | + qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 159 | 192 | count++; |
| 160 | 193 | } |
| 161 | 194 | ... | ... |
session.h
| ... | ... | @@ -25,38 +25,46 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 25 | 25 | |
| 26 | 26 | #include "forward_declarations.h" |
| 27 | 27 | #include "logger.h" |
| 28 | +#include "sessionsandsubscriptionsdb.h" | |
| 29 | +#include "qospacketqueue.h" | |
| 28 | 30 | |
| 29 | 31 | // TODO make settings. But, num of packets can't exceed 65536, because the counter is 16 bit. |
| 30 | 32 | #define MAX_QOS_MSG_PENDING_PER_CLIENT 32 |
| 31 | 33 | #define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096 |
| 32 | 34 | |
| 33 | -struct QueuedQosPacket | |
| 34 | -{ | |
| 35 | - uint16_t id; | |
| 36 | - std::shared_ptr<MqttPacket> packet; | |
| 37 | -}; | |
| 38 | - | |
| 39 | 35 | class Session |
| 40 | 36 | { |
| 37 | +#ifdef TESTING | |
| 38 | + friend class MainTests; | |
| 39 | +#endif | |
| 40 | + | |
| 41 | + friend class SessionsAndSubscriptionsDB; | |
| 42 | + | |
| 41 | 43 | std::weak_ptr<Client> client; |
| 42 | - std::shared_ptr<ThreadData> thread; | |
| 43 | 44 | std::string client_id; |
| 44 | 45 | std::string username; |
| 45 | - std::list<QueuedQosPacket> qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] | |
| 46 | + QoSPacketQueue qosPacketQueue; | |
| 46 | 47 | std::set<uint16_t> incomingQoS2MessageIds; |
| 47 | 48 | std::set<uint16_t> outgoingQoS2MessageIds; |
| 48 | 49 | std::mutex qosQueueMutex; |
| 49 | 50 | uint16_t nextPacketId = 0; |
| 50 | - ssize_t qosQueueBytes = 0; | |
| 51 | - std::chrono::time_point<std::chrono::steady_clock> lastTouched; | |
| 51 | + std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now(); | |
| 52 | 52 | Logger *logger = Logger::getInstance(); |
| 53 | + int64_t getSessionRelativeAgeInMs() const; | |
| 54 | + void setSessionTouch(int64_t ageInMs); | |
| 53 | 55 | |
| 56 | + Session(const Session &other); | |
| 54 | 57 | public: |
| 55 | 58 | Session(); |
| 56 | - Session(const Session &other) = delete; | |
| 59 | + | |
| 57 | 60 | Session(Session &&other) = delete; |
| 58 | 61 | ~Session(); |
| 59 | 62 | |
| 63 | + static int64_t getProgramStartedAtUnixTimestamp(); | |
| 64 | + static void setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp); | |
| 65 | + | |
| 66 | + std::unique_ptr<Session> getCopy() const; | |
| 67 | + | |
| 60 | 68 | const std::string &getClientId() const { return client_id; } |
| 61 | 69 | bool clientDisconnected() const; |
| 62 | 70 | std::shared_ptr<Client> makeSharedClient() const; |
| ... | ... | @@ -74,7 +82,6 @@ public: |
| 74 | 82 | |
| 75 | 83 | void addOutgoingQoS2MessageId(uint16_t packet_id); |
| 76 | 84 | void removeOutgoingQoS2MessageId(u_int16_t packet_id); |
| 77 | - | |
| 78 | 85 | }; |
| 79 | 86 | |
| 80 | 87 | #endif // SESSION_H | ... | ... |
sessionsandsubscriptionsdb.cpp
0 โ 100644
| 1 | +/* | |
| 2 | +This file is part of FlashMQ (https://www.flashmq.org) | |
| 3 | +Copyright (C) 2021 Wiebe Cazemier | |
| 4 | + | |
| 5 | +FlashMQ is free software: you can redistribute it and/or modify | |
| 6 | +it under the terms of the GNU Affero General Public License as | |
| 7 | +published by the Free Software Foundation, version 3. | |
| 8 | + | |
| 9 | +FlashMQ is distributed in the hope that it will be useful, | |
| 10 | +but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 11 | +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 12 | +GNU Affero General Public License for more details. | |
| 13 | + | |
| 14 | +You should have received a copy of the GNU Affero General Public | |
| 15 | +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | |
| 16 | +*/ | |
| 17 | + | |
| 18 | +#include "sessionsandsubscriptionsdb.h" | |
| 19 | +#include "mqttpacket.h" | |
| 20 | + | |
| 21 | +#include "cassert" | |
| 22 | + | |
| 23 | +SubscriptionForSerializing::SubscriptionForSerializing(const std::string &clientId, char qos) : | |
| 24 | + clientId(clientId), | |
| 25 | + qos(qos) | |
| 26 | +{ | |
| 27 | + | |
| 28 | +} | |
| 29 | + | |
| 30 | +SubscriptionForSerializing::SubscriptionForSerializing(const std::string &&clientId, char qos) : | |
| 31 | + clientId(clientId), | |
| 32 | + qos(qos) | |
| 33 | +{ | |
| 34 | + | |
| 35 | +} | |
| 36 | + | |
| 37 | +SessionsAndSubscriptionsDB::SessionsAndSubscriptionsDB(const std::string &filePath) : PersistenceFile(filePath) | |
| 38 | +{ | |
| 39 | + | |
| 40 | +} | |
| 41 | + | |
| 42 | +void SessionsAndSubscriptionsDB::openWrite() | |
| 43 | +{ | |
| 44 | + PersistenceFile::openWrite(MAGIC_STRING_SESSION_FILE_V1); | |
| 45 | +} | |
| 46 | + | |
| 47 | +void SessionsAndSubscriptionsDB::openRead() | |
| 48 | +{ | |
| 49 | + PersistenceFile::openRead(); | |
| 50 | + | |
| 51 | + if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V1) | |
| 52 | + readVersion = ReadVersion::v1; | |
| 53 | + else | |
| 54 | + throw std::runtime_error("Unknown file version."); | |
| 55 | +} | |
| 56 | + | |
| 57 | +SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1() | |
| 58 | +{ | |
| 59 | + SessionsAndSubscriptionsResult result; | |
| 60 | + | |
| 61 | + while (!feof(f)) | |
| 62 | + { | |
| 63 | + bool eofFound = false; | |
| 64 | + | |
| 65 | + const int64_t programStartAge = readInt64(eofFound); | |
| 66 | + if (eofFound) | |
| 67 | + continue; | |
| 68 | + | |
| 69 | + logger->logf(LOG_DEBUG, "Setting first app start time to timestamp %ld", programStartAge); | |
| 70 | + Session::setProgramStartedAtUnixTimestamp(programStartAge); | |
| 71 | + | |
| 72 | + const uint32_t nrOfSessions = readUint32(eofFound); | |
| 73 | + | |
| 74 | + if (eofFound) | |
| 75 | + continue; | |
| 76 | + | |
| 77 | + std::vector<char> reserved(RESERVED_SPACE_SESSIONS_DB_V1); | |
| 78 | + | |
| 79 | + for (uint32_t i = 0; i < nrOfSessions; i++) | |
| 80 | + { | |
| 81 | + readCheck(buf.data(), 1, RESERVED_SPACE_SESSIONS_DB_V1, f); | |
| 82 | + | |
| 83 | + uint32_t usernameLength = readUint32(eofFound); | |
| 84 | + readCheck(buf.data(), 1, usernameLength, f); | |
| 85 | + std::string username(buf.data(), usernameLength); | |
| 86 | + | |
| 87 | + uint32_t clientIdLength = readUint32(eofFound); | |
| 88 | + readCheck(buf.data(), 1, clientIdLength, f); | |
| 89 | + std::string clientId(buf.data(), clientIdLength); | |
| 90 | + | |
| 91 | + std::shared_ptr<Session> ses(new Session()); | |
| 92 | + result.sessions.push_back(ses); | |
| 93 | + ses->username = username; | |
| 94 | + ses->client_id = clientId; | |
| 95 | + | |
| 96 | + logger->logf(LOG_DEBUG, "Loading session '%s'.", ses->getClientId().c_str()); | |
| 97 | + | |
| 98 | + const uint32_t nrOfQueuedQoSPackets = readUint32(eofFound); | |
| 99 | + for (uint32_t i = 0; i < nrOfQueuedQoSPackets; i++) | |
| 100 | + { | |
| 101 | + const uint16_t id = readUint16(eofFound); | |
| 102 | + const uint32_t topicSize = readUint32(eofFound); | |
| 103 | + const uint32_t payloadSize = readUint32(eofFound); | |
| 104 | + | |
| 105 | + assert(id > 0); | |
| 106 | + | |
| 107 | + readCheck(buf.data(), 1, 1, f); | |
| 108 | + const unsigned char qos = buf[0]; | |
| 109 | + | |
| 110 | + readCheck(buf.data(), 1, topicSize, f); | |
| 111 | + const std::string topic(buf.data(), topicSize); | |
| 112 | + | |
| 113 | + makeSureBufSize(payloadSize); | |
| 114 | + readCheck(buf.data(), 1, payloadSize, f); | |
| 115 | + const std::string payload(buf.data(), payloadSize); | |
| 116 | + | |
| 117 | + Publish pub(topic, payload, qos); | |
| 118 | + logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); | |
| 119 | + ses->qosPacketQueue.queuePacket(pub, id); | |
| 120 | + } | |
| 121 | + | |
| 122 | + const uint32_t nrOfIncomingPacketIds = readUint32(eofFound); | |
| 123 | + for (uint32_t i = 0; i < nrOfIncomingPacketIds; i++) | |
| 124 | + { | |
| 125 | + uint16_t id = readUint16(eofFound); | |
| 126 | + assert(id > 0); | |
| 127 | + logger->logf(LOG_DEBUG, "Loaded incomming QoS2 message id %d.", id); | |
| 128 | + ses->incomingQoS2MessageIds.insert(id); | |
| 129 | + } | |
| 130 | + | |
| 131 | + const uint32_t nrOfOutgoingPacketIds = readUint32(eofFound); | |
| 132 | + for (uint32_t i = 0; i < nrOfOutgoingPacketIds; i++) | |
| 133 | + { | |
| 134 | + uint16_t id = readUint16(eofFound); | |
| 135 | + assert(id > 0); | |
| 136 | + logger->logf(LOG_DEBUG, "Loaded outgoing QoS2 message id %d.", id); | |
| 137 | + ses->outgoingQoS2MessageIds.insert(id); | |
| 138 | + } | |
| 139 | + | |
| 140 | + const uint16_t nextPacketId = readUint16(eofFound); | |
| 141 | + logger->logf(LOG_DEBUG, "Loaded next packetid %d.", ses->nextPacketId); | |
| 142 | + ses->nextPacketId = nextPacketId; | |
| 143 | + | |
| 144 | + int64_t sessionAge = readInt64(eofFound); | |
| 145 | + logger->logf(LOG_DEBUG, "Loaded session age: %ld ms.", sessionAge); | |
| 146 | + ses->setSessionTouch(sessionAge); | |
| 147 | + } | |
| 148 | + | |
| 149 | + const uint32_t nrOfSubscriptions = readUint32(eofFound); | |
| 150 | + for (uint32_t i = 0; i < nrOfSubscriptions; i++) | |
| 151 | + { | |
| 152 | + const uint32_t topicLength = readUint32(eofFound); | |
| 153 | + readCheck(buf.data(), 1, topicLength, f); | |
| 154 | + const std::string topic(buf.data(), topicLength); | |
| 155 | + | |
| 156 | + logger->logf(LOG_DEBUG, "Loading subscriptions to topic '%s'.", topic.c_str()); | |
| 157 | + | |
| 158 | + const uint32_t nrOfClientIds = readUint32(eofFound); | |
| 159 | + | |
| 160 | + for (uint32_t i = 0; i < nrOfClientIds; i++) | |
| 161 | + { | |
| 162 | + const uint32_t clientIdLength = readUint32(eofFound); | |
| 163 | + readCheck(buf.data(), 1, clientIdLength, f); | |
| 164 | + const std::string clientId(buf.data(), clientIdLength); | |
| 165 | + | |
| 166 | + char qos; | |
| 167 | + readCheck(&qos, 1, 1, f); | |
| 168 | + | |
| 169 | + logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", clientId.c_str(), topic.c_str(), qos); | |
| 170 | + | |
| 171 | + SubscriptionForSerializing sub(std::move(clientId), qos); | |
| 172 | + result.subscriptions[topic].push_back(std::move(sub)); | |
| 173 | + } | |
| 174 | + | |
| 175 | + } | |
| 176 | + } | |
| 177 | + | |
| 178 | + return result; | |
| 179 | +} | |
| 180 | + | |
| 181 | +void SessionsAndSubscriptionsDB::writeRowHeader() | |
| 182 | +{ | |
| 183 | + | |
| 184 | +} | |
| 185 | + | |
| 186 | +void SessionsAndSubscriptionsDB::saveData(const std::list<std::unique_ptr<Session>> &sessions, const std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &subscriptions) | |
| 187 | +{ | |
| 188 | + if (!f) | |
| 189 | + return; | |
| 190 | + | |
| 191 | + char reserved[RESERVED_SPACE_SESSIONS_DB_V1]; | |
| 192 | + std::memset(reserved, 0, RESERVED_SPACE_SESSIONS_DB_V1); | |
| 193 | + | |
| 194 | + const int64_t start_stamp = Session::getProgramStartedAtUnixTimestamp(); | |
| 195 | + logger->logf(LOG_DEBUG, "Saving program first start time stamp as %ld", start_stamp); | |
| 196 | + writeInt64(start_stamp); | |
| 197 | + | |
| 198 | + writeUint32(sessions.size()); | |
| 199 | + | |
| 200 | + for (const std::unique_ptr<Session> &ses : sessions) | |
| 201 | + { | |
| 202 | + logger->logf(LOG_DEBUG, "Saving session '%s'.", ses->getClientId().c_str()); | |
| 203 | + | |
| 204 | + writeRowHeader(); | |
| 205 | + | |
| 206 | + writeCheck(reserved, 1, RESERVED_SPACE_SESSIONS_DB_V1, f); | |
| 207 | + | |
| 208 | + writeUint32(ses->username.length()); | |
| 209 | + writeCheck(ses->username.c_str(), 1, ses->username.length(), f); | |
| 210 | + | |
| 211 | + writeUint32(ses->client_id.length()); | |
| 212 | + writeCheck(ses->client_id.c_str(), 1, ses->client_id.length(), f); | |
| 213 | + | |
| 214 | + const size_t qosPacketsExpected = ses->qosPacketQueue.size(); | |
| 215 | + size_t qosPacketsCounted = 0; | |
| 216 | + writeUint32(qosPacketsExpected); | |
| 217 | + | |
| 218 | + for (const std::shared_ptr<MqttPacket> &p: ses->qosPacketQueue) | |
| 219 | + { | |
| 220 | + logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", p->getQos(), p->getTopic().c_str()); | |
| 221 | + | |
| 222 | + qosPacketsCounted++; | |
| 223 | + | |
| 224 | + writeUint16(p->getPacketId()); | |
| 225 | + | |
| 226 | + writeUint32(p->getTopic().length()); | |
| 227 | + std::string payload = p->getPayloadCopy(); | |
| 228 | + writeUint32(payload.size()); | |
| 229 | + | |
| 230 | + const char qos = p->getQos(); | |
| 231 | + writeCheck(&qos, 1, 1, f); | |
| 232 | + | |
| 233 | + writeCheck(p->getTopic().c_str(), 1, p->getTopic().length(), f); | |
| 234 | + writeCheck(payload.c_str(), 1, payload.length(), f); | |
| 235 | + } | |
| 236 | + | |
| 237 | + assert(qosPacketsExpected == qosPacketsCounted); | |
| 238 | + | |
| 239 | + writeUint32(ses->incomingQoS2MessageIds.size()); | |
| 240 | + for (uint16_t id : ses->incomingQoS2MessageIds) | |
| 241 | + { | |
| 242 | + logger->logf(LOG_DEBUG, "Writing incomming QoS2 message id %d.", id); | |
| 243 | + writeUint16(id); | |
| 244 | + } | |
| 245 | + | |
| 246 | + writeUint32(ses->outgoingQoS2MessageIds.size()); | |
| 247 | + for (uint16_t id : ses->outgoingQoS2MessageIds) | |
| 248 | + { | |
| 249 | + logger->logf(LOG_DEBUG, "Writing outgoing QoS2 message id %d.", id); | |
| 250 | + writeUint16(id); | |
| 251 | + } | |
| 252 | + | |
| 253 | + logger->logf(LOG_DEBUG, "Writing next packetid %d.", ses->nextPacketId); | |
| 254 | + writeUint16(ses->nextPacketId); | |
| 255 | + | |
| 256 | + const int64_t sInMs = ses->getSessionRelativeAgeInMs(); | |
| 257 | + logger->logf(LOG_DEBUG, "Writing session age: %ld ms.", sInMs); | |
| 258 | + writeInt64(sInMs); | |
| 259 | + } | |
| 260 | + | |
| 261 | + writeUint32(subscriptions.size()); | |
| 262 | + | |
| 263 | + for (auto &pair : subscriptions) | |
| 264 | + { | |
| 265 | + const std::string &topic = pair.first; | |
| 266 | + const std::list<SubscriptionForSerializing> &subscriptions = pair.second; | |
| 267 | + | |
| 268 | + logger->logf(LOG_DEBUG, "Writing subscriptions to topic '%s'.", topic.c_str()); | |
| 269 | + | |
| 270 | + writeUint32(topic.size()); | |
| 271 | + writeCheck(topic.c_str(), 1, topic.size(), f); | |
| 272 | + | |
| 273 | + writeUint32(subscriptions.size()); | |
| 274 | + | |
| 275 | + for (const SubscriptionForSerializing &subscription : subscriptions) | |
| 276 | + { | |
| 277 | + logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", subscription.clientId.c_str(), topic.c_str(), subscription.qos); | |
| 278 | + | |
| 279 | + writeUint32(subscription.clientId.size()); | |
| 280 | + writeCheck(subscription.clientId.c_str(), 1, subscription.clientId.size(), f); | |
| 281 | + writeCheck(&subscription.qos, 1, 1, f); | |
| 282 | + } | |
| 283 | + } | |
| 284 | + | |
| 285 | + fflush(f); | |
| 286 | +} | |
| 287 | + | |
| 288 | +SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readData() | |
| 289 | +{ | |
| 290 | + SessionsAndSubscriptionsResult defaultResult; | |
| 291 | + | |
| 292 | + if (!f) | |
| 293 | + return defaultResult; | |
| 294 | + | |
| 295 | + if (readVersion == ReadVersion::v1) | |
| 296 | + return readDataV1(); | |
| 297 | + | |
| 298 | + return defaultResult; | |
| 299 | +} | ... | ... |
sessionsandsubscriptionsdb.h
0 โ 100644
| 1 | +/* | |
| 2 | +This file is part of FlashMQ (https://www.flashmq.org) | |
| 3 | +Copyright (C) 2021 Wiebe Cazemier | |
| 4 | + | |
| 5 | +FlashMQ is free software: you can redistribute it and/or modify | |
| 6 | +it under the terms of the GNU Affero General Public License as | |
| 7 | +published by the Free Software Foundation, version 3. | |
| 8 | + | |
| 9 | +FlashMQ is distributed in the hope that it will be useful, | |
| 10 | +but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 11 | +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 12 | +GNU Affero General Public License for more details. | |
| 13 | + | |
| 14 | +You should have received a copy of the GNU Affero General Public | |
| 15 | +License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | |
| 16 | +*/ | |
| 17 | + | |
| 18 | +#ifndef SESSIONSANDSUBSCRIPTIONSDB_H | |
| 19 | +#define SESSIONSANDSUBSCRIPTIONSDB_H | |
| 20 | + | |
| 21 | +#include <list> | |
| 22 | +#include <memory> | |
| 23 | + | |
| 24 | +#include "persistencefile.h" | |
| 25 | +#include "session.h" | |
| 26 | + | |
| 27 | +#define MAGIC_STRING_SESSION_FILE_V1 "FlashMQRetainedDBv1" | |
| 28 | +#define RESERVED_SPACE_SESSIONS_DB_V1 32 | |
| 29 | + | |
| 30 | +/** | |
| 31 | + * @brief The SubscriptionForSerializing struct contains the fields we're interested in when saving a subscription. | |
| 32 | + */ | |
| 33 | +struct SubscriptionForSerializing | |
| 34 | +{ | |
| 35 | + const std::string clientId; | |
| 36 | + const char qos = 0; | |
| 37 | + | |
| 38 | + SubscriptionForSerializing(const std::string &clientId, char qos); | |
| 39 | + SubscriptionForSerializing(const std::string &&clientId, char qos); | |
| 40 | +}; | |
| 41 | + | |
| 42 | +struct SessionsAndSubscriptionsResult | |
| 43 | +{ | |
| 44 | + std::list<std::shared_ptr<Session>> sessions; | |
| 45 | + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> subscriptions; | |
| 46 | +}; | |
| 47 | + | |
| 48 | + | |
| 49 | +class SessionsAndSubscriptionsDB : public PersistenceFile | |
| 50 | +{ | |
| 51 | + enum class ReadVersion | |
| 52 | + { | |
| 53 | + unknown, | |
| 54 | + v1 | |
| 55 | + }; | |
| 56 | + | |
| 57 | + ReadVersion readVersion = ReadVersion::unknown; | |
| 58 | + | |
| 59 | + SessionsAndSubscriptionsResult readDataV1(); | |
| 60 | + void writeRowHeader(); | |
| 61 | +public: | |
| 62 | + SessionsAndSubscriptionsDB(const std::string &filePath); | |
| 63 | + | |
| 64 | + void openWrite(); | |
| 65 | + void openRead(); | |
| 66 | + | |
| 67 | + void saveData(const std::list<std::unique_ptr<Session>> &sessions, const std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &subscriptions); | |
| 68 | + SessionsAndSubscriptionsResult readData(); | |
| 69 | +}; | |
| 70 | + | |
| 71 | +#endif // SESSIONSANDSUBSCRIPTIONSDB_H | ... | ... |
settings.cpp
| ... | ... | @@ -16,7 +16,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 16 | 16 | */ |
| 17 | 17 | |
| 18 | 18 | #include "settings.h" |
| 19 | - | |
| 19 | +#include "utils.h" | |
| 20 | 20 | |
| 21 | 21 | AuthOptCompatWrap &Settings::getAuthOptsCompat() |
| 22 | 22 | { |
| ... | ... | @@ -27,3 +27,21 @@ std::unordered_map<std::string, std::string> &Settings::getFlashmqAuthPluginOpts |
| 27 | 27 | { |
| 28 | 28 | return this->flashmqAuthPluginOpts; |
| 29 | 29 | } |
| 30 | + | |
| 31 | +std::string Settings::getRetainedMessagesDBFile() const | |
| 32 | +{ | |
| 33 | + if (storageDir.empty()) | |
| 34 | + return ""; | |
| 35 | + | |
| 36 | + std::string path = formatString("%s/%s", storageDir.c_str(), "retained.db"); | |
| 37 | + return path; | |
| 38 | +} | |
| 39 | + | |
| 40 | +std::string Settings::getSessionsDBFile() const | |
| 41 | +{ | |
| 42 | + if (storageDir.empty()) | |
| 43 | + return ""; | |
| 44 | + | |
| 45 | + std::string path = formatString("%s/%s", storageDir.c_str(), "sessions.db"); | |
| 46 | + return path; | |
| 47 | +} | ... | ... |
settings.h
| ... | ... | @@ -53,10 +53,14 @@ public: |
| 53 | 53 | int rlimitNoFile = 1000000; |
| 54 | 54 | uint64_t expireSessionsAfterSeconds = 1209600; |
| 55 | 55 | int authPluginTimerPeriod = 60; |
| 56 | + std::string storageDir; | |
| 56 | 57 | std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined. |
| 57 | 58 | |
| 58 | 59 | AuthOptCompatWrap &getAuthOptsCompat(); |
| 59 | 60 | std::unordered_map<std::string, std::string> &getFlashmqAuthPluginOpts(); |
| 61 | + | |
| 62 | + std::string getRetainedMessagesDBFile() const; | |
| 63 | + std::string getSessionsDBFile() const; | |
| 60 | 64 | }; |
| 61 | 65 | |
| 62 | 66 | #endif // SETTINGS_H | ... | ... |
subscriptionstore.cpp
| ... | ... | @@ -20,7 +20,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 20 | 20 | #include "cassert" |
| 21 | 21 | |
| 22 | 22 | #include "rwlockguard.h" |
| 23 | - | |
| 23 | +#include "retainedmessagesdb.h" | |
| 24 | 24 | |
| 25 | 25 | SubscriptionNode::SubscriptionNode(const std::string &subtopic) : |
| 26 | 26 | subtopic(subtopic) |
| ... | ... | @@ -33,6 +33,11 @@ std::vector<Subscription> &SubscriptionNode::getSubscribers() |
| 33 | 33 | return subscribers; |
| 34 | 34 | } |
| 35 | 35 | |
| 36 | +const std::string &SubscriptionNode::getSubtopic() const | |
| 37 | +{ | |
| 38 | + return subtopic; | |
| 39 | +} | |
| 40 | + | |
| 36 | 41 | void SubscriptionNode::addSubscriber(const std::shared_ptr<Session> &subscriber, char qos) |
| 37 | 42 | { |
| 38 | 43 | Subscription sub; |
| ... | ... | @@ -84,15 +89,20 @@ SubscriptionStore::SubscriptionStore() : |
| 84 | 89 | |
| 85 | 90 | } |
| 86 | 91 | |
| 87 | -void SubscriptionStore::addSubscription(std::shared_ptr<Client> &client, const std::string &topic, const std::vector<std::string> &subtopics, char qos) | |
| 92 | +/** | |
| 93 | + * @brief SubscriptionStore::getDeepestNode gets the node in the tree walking the path of 'the/subscription/topic/path', making new nodes as required. | |
| 94 | + * @param topic | |
| 95 | + * @param subtopics | |
| 96 | + * @return | |
| 97 | + * | |
| 98 | + * caller is responsible for locking. | |
| 99 | + */ | |
| 100 | +SubscriptionNode *SubscriptionStore::getDeepestNode(const std::string &topic, const std::vector<std::string> &subtopics) | |
| 88 | 101 | { |
| 89 | 102 | SubscriptionNode *deepestNode = &root; |
| 90 | 103 | if (topic.length() > 0 && topic[0] == '$') |
| 91 | 104 | deepestNode = &rootDollar; |
| 92 | 105 | |
| 93 | - RWLockGuard lock_guard(&subscriptionsRwlock); | |
| 94 | - lock_guard.wrlock(); | |
| 95 | - | |
| 96 | 106 | for(const std::string &subtopic : subtopics) |
| 97 | 107 | { |
| 98 | 108 | std::unique_ptr<SubscriptionNode> *selectedChildren = nullptr; |
| ... | ... | @@ -114,6 +124,15 @@ void SubscriptionStore::addSubscription(std::shared_ptr<Client> &client, const s |
| 114 | 124 | } |
| 115 | 125 | |
| 116 | 126 | assert(deepestNode); |
| 127 | + return deepestNode; | |
| 128 | +} | |
| 129 | + | |
| 130 | +void SubscriptionStore::addSubscription(std::shared_ptr<Client> &client, const std::string &topic, const std::vector<std::string> &subtopics, char qos) | |
| 131 | +{ | |
| 132 | + RWLockGuard lock_guard(&subscriptionsRwlock); | |
| 133 | + lock_guard.wrlock(); | |
| 134 | + | |
| 135 | + SubscriptionNode *deepestNode = getDeepestNode(topic, subtopics); | |
| 117 | 136 | |
| 118 | 137 | if (deepestNode) |
| 119 | 138 | { |
| ... | ... | @@ -494,6 +513,183 @@ int64_t SubscriptionStore::getRetainedMessageCount() const |
| 494 | 513 | return retainedMessageCount; |
| 495 | 514 | } |
| 496 | 515 | |
| 516 | +void SubscriptionStore::getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const | |
| 517 | +{ | |
| 518 | + for(const RetainedMessage &rm : this_node->retainedMessages) | |
| 519 | + { | |
| 520 | + outputList.push_back(rm); | |
| 521 | + } | |
| 522 | + | |
| 523 | + for(auto &pair : this_node->children) | |
| 524 | + { | |
| 525 | + const std::unique_ptr<RetainedMessageNode> &child = pair.second; | |
| 526 | + getRetainedMessages(child.get(), outputList); | |
| 527 | + } | |
| 528 | +} | |
| 529 | + | |
| 530 | +/** | |
| 531 | + * @brief SubscriptionStore::getSubscriptions | |
| 532 | + * @param this_node | |
| 533 | + * @param composedTopic | |
| 534 | + * @param root bool. Every subtopic is concatenated with a '/', but not the first topic to 'root'. The root is a bit weird, virtual, so it needs different treatment. | |
| 535 | + * @param outputList | |
| 536 | + */ | |
| 537 | +void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, | |
| 538 | + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const | |
| 539 | +{ | |
| 540 | + for (const Subscription &node : this_node->getSubscribers()) | |
| 541 | + { | |
| 542 | + if (!node.sessionGone()) | |
| 543 | + { | |
| 544 | + SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos); | |
| 545 | + outputList[composedTopic].push_back(sub); | |
| 546 | + } | |
| 547 | + } | |
| 548 | + | |
| 549 | + for (auto &pair : this_node->children) | |
| 550 | + { | |
| 551 | + SubscriptionNode *node = pair.second.get(); | |
| 552 | + const std::string topicAtNextLevel = root ? pair.first : composedTopic + "/" + pair.first; | |
| 553 | + getSubscriptions(node, topicAtNextLevel, false, outputList); | |
| 554 | + } | |
| 555 | + | |
| 556 | + if (this_node->childrenPlus) | |
| 557 | + { | |
| 558 | + const std::string topicAtNextLevel = root ? "+" : composedTopic + "/+"; | |
| 559 | + getSubscriptions(this_node->childrenPlus.get(), topicAtNextLevel, false, outputList); | |
| 560 | + } | |
| 561 | + | |
| 562 | + if (this_node->childrenPound) | |
| 563 | + { | |
| 564 | + const std::string topicAtNextLevel = root ? "#" : composedTopic + "/#"; | |
| 565 | + getSubscriptions(this_node->childrenPound.get(), topicAtNextLevel, false, outputList); | |
| 566 | + } | |
| 567 | +} | |
| 568 | + | |
| 569 | +void SubscriptionStore::saveRetainedMessages(const std::string &filePath) | |
| 570 | +{ | |
| 571 | + logger->logf(LOG_INFO, "Saving retained messages to '%s'", filePath.c_str()); | |
| 572 | + | |
| 573 | + std::vector<RetainedMessage> result; | |
| 574 | + result.reserve(retainedMessageCount); | |
| 575 | + | |
| 576 | + // Create the list of messages under lock, and unlock right after. | |
| 577 | + RWLockGuard locker(&retainedMessagesRwlock); | |
| 578 | + locker.rdlock(); | |
| 579 | + getRetainedMessages(&retainedMessagesRoot, result); | |
| 580 | + locker.unlock(); | |
| 581 | + | |
| 582 | + logger->logf(LOG_DEBUG, "Collected %ld retained messages to save.", result.size()); | |
| 583 | + | |
| 584 | + // Then do the IO without locking the threads. | |
| 585 | + RetainedMessagesDB db(filePath); | |
| 586 | + db.openWrite(); | |
| 587 | + db.saveData(result); | |
| 588 | +} | |
| 589 | + | |
| 590 | +void SubscriptionStore::loadRetainedMessages(const std::string &filePath) | |
| 591 | +{ | |
| 592 | + try | |
| 593 | + { | |
| 594 | + logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str()); | |
| 595 | + | |
| 596 | + RetainedMessagesDB db(filePath); | |
| 597 | + db.openRead(); | |
| 598 | + std::list<RetainedMessage> messages = db.readData(); | |
| 599 | + | |
| 600 | + RWLockGuard locker(&retainedMessagesRwlock); | |
| 601 | + locker.wrlock(); | |
| 602 | + | |
| 603 | + std::vector<std::string> subtopics; | |
| 604 | + for (const RetainedMessage &rm : messages) | |
| 605 | + { | |
| 606 | + splitTopic(rm.topic, subtopics); | |
| 607 | + setRetainedMessage(rm.topic, subtopics, rm.payload, rm.qos); | |
| 608 | + } | |
| 609 | + } | |
| 610 | + catch (PersistenceFileCantBeOpened &ex) | |
| 611 | + { | |
| 612 | + logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str()); | |
| 613 | + } | |
| 614 | +} | |
| 615 | + | |
| 616 | +void SubscriptionStore::saveSessionsAndSubscriptions(const std::string &filePath) | |
| 617 | +{ | |
| 618 | + logger->logf(LOG_INFO, "Saving sessions and subscriptions to '%s'", filePath.c_str()); | |
| 619 | + | |
| 620 | + RWLockGuard lock_guard(&subscriptionsRwlock); | |
| 621 | + lock_guard.wrlock(); | |
| 622 | + | |
| 623 | + // First copy the sessions... | |
| 624 | + | |
| 625 | + std::list<std::unique_ptr<Session>> sessionCopies; | |
| 626 | + | |
| 627 | + for (const auto &pair : sessionsByIdConst) | |
| 628 | + { | |
| 629 | + const Session &org = *pair.second.get(); | |
| 630 | + sessionCopies.push_back(org.getCopy()); | |
| 631 | + } | |
| 632 | + | |
| 633 | + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> subscriptionCopies; | |
| 634 | + getSubscriptions(&root, "", true, subscriptionCopies); | |
| 635 | + | |
| 636 | + lock_guard.unlock(); | |
| 637 | + | |
| 638 | + // Then write the copies to disk, after having released the lock | |
| 639 | + | |
| 640 | + logger->logf(LOG_DEBUG, "Collected %ld sessions and %ld subscriptions to save.", sessionCopies.size(), subscriptionCopies.size()); | |
| 641 | + | |
| 642 | + SessionsAndSubscriptionsDB db(filePath); | |
| 643 | + db.openWrite(); | |
| 644 | + db.saveData(sessionCopies, subscriptionCopies); | |
| 645 | +} | |
| 646 | + | |
| 647 | +void SubscriptionStore::loadSessionsAndSubscriptions(const std::string &filePath) | |
| 648 | +{ | |
| 649 | + try | |
| 650 | + { | |
| 651 | + logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str()); | |
| 652 | + | |
| 653 | + SessionsAndSubscriptionsDB db(filePath); | |
| 654 | + db.openRead(); | |
| 655 | + SessionsAndSubscriptionsResult loadedData = db.readData(); | |
| 656 | + | |
| 657 | + RWLockGuard locker(&subscriptionsRwlock); | |
| 658 | + locker.wrlock(); | |
| 659 | + | |
| 660 | + for (std::shared_ptr<Session> &session : loadedData.sessions) | |
| 661 | + { | |
| 662 | + sessionsById[session->getClientId()] = session; | |
| 663 | + } | |
| 664 | + | |
| 665 | + std::vector<std::string> subtopics; | |
| 666 | + | |
| 667 | + for (auto &pair : loadedData.subscriptions) | |
| 668 | + { | |
| 669 | + const std::string &topic = pair.first; | |
| 670 | + const std::list<SubscriptionForSerializing> &subs = pair.second; | |
| 671 | + | |
| 672 | + for (const SubscriptionForSerializing &sub : subs) | |
| 673 | + { | |
| 674 | + splitTopic(topic, subtopics); | |
| 675 | + SubscriptionNode *subscriptionNode = getDeepestNode(topic, subtopics); | |
| 676 | + | |
| 677 | + auto session_it = sessionsByIdConst.find(sub.clientId); | |
| 678 | + if (session_it != sessionsByIdConst.end()) | |
| 679 | + { | |
| 680 | + const std::shared_ptr<Session> &ses = session_it->second; | |
| 681 | + subscriptionNode->addSubscriber(ses, sub.qos); | |
| 682 | + } | |
| 683 | + | |
| 684 | + } | |
| 685 | + } | |
| 686 | + } | |
| 687 | + catch (PersistenceFileCantBeOpened &ex) | |
| 688 | + { | |
| 689 | + logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str()); | |
| 690 | + } | |
| 691 | +} | |
| 692 | + | |
| 497 | 693 | // QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The |
| 498 | 694 | // specs don't specify what to do there. |
| 499 | 695 | bool Subscription::operator==(const Subscription &rhs) const | ... | ... |
subscriptionstore.h
| ... | ... | @@ -53,6 +53,7 @@ public: |
| 53 | 53 | SubscriptionNode(SubscriptionNode &&node) = delete; |
| 54 | 54 | |
| 55 | 55 | std::vector<Subscription> &getSubscribers(); |
| 56 | + const std::string &getSubtopic() const; | |
| 56 | 57 | void addSubscriber(const std::shared_ptr<Session> &subscriber, char qos); |
| 57 | 58 | void removeSubscriber(const std::shared_ptr<Session> &subscriber); |
| 58 | 59 | std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children; |
| ... | ... | @@ -77,6 +78,10 @@ class RetainedMessageNode |
| 77 | 78 | |
| 78 | 79 | class SubscriptionStore |
| 79 | 80 | { |
| 81 | +#ifdef TESTING | |
| 82 | + friend class MainTests; | |
| 83 | +#endif | |
| 84 | + | |
| 80 | 85 | SubscriptionNode root; |
| 81 | 86 | SubscriptionNode rootDollar; |
| 82 | 87 | pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER; |
| ... | ... | @@ -93,7 +98,11 @@ class SubscriptionStore |
| 93 | 98 | void publishNonRecursively(const MqttPacket &packet, const std::vector<Subscription> &subscribers, uint64_t &count) const; |
| 94 | 99 | void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, |
| 95 | 100 | SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const; |
| 101 | + void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const; | |
| 102 | + void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, | |
| 103 | + std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const; | |
| 96 | 104 | |
| 105 | + SubscriptionNode *getDeepestNode(const std::string &topic, const std::vector<std::string> &subtopics); | |
| 97 | 106 | public: |
| 98 | 107 | SubscriptionStore(); |
| 99 | 108 | |
| ... | ... | @@ -103,7 +112,6 @@ public: |
| 103 | 112 | bool sessionPresent(const std::string &clientid); |
| 104 | 113 | |
| 105 | 114 | void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar = false); |
| 106 | - uint64_t giveClientRetainedMessages(const std::shared_ptr<Session> &ses, const std::string &subscribe_topic, char max_qos); | |
| 107 | 115 | void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, |
| 108 | 116 | RetainedMessageNode *this_node, char max_qos, const std::shared_ptr<Session> &ses, |
| 109 | 117 | bool poundMode, uint64_t &count) const; |
| ... | ... | @@ -114,6 +122,12 @@ public: |
| 114 | 122 | void removeExpiredSessionsClients(int expireSessionsAfterSeconds); |
| 115 | 123 | |
| 116 | 124 | int64_t getRetainedMessageCount() const; |
| 125 | + | |
| 126 | + void saveRetainedMessages(const std::string &filePath); | |
| 127 | + void loadRetainedMessages(const std::string &filePath); | |
| 128 | + | |
| 129 | + void saveSessionsAndSubscriptions(const std::string &filePath); | |
| 130 | + void loadSessionsAndSubscriptions(const std::string &filePath); | |
| 117 | 131 | }; |
| 118 | 132 | |
| 119 | 133 | #endif // SUBSCRIPTIONSTORE_H | ... | ... |
utils.cpp
| ... | ... | @@ -289,6 +289,14 @@ void trim(std::string &s) |
| 289 | 289 | rtrim(s); |
| 290 | 290 | } |
| 291 | 291 | |
| 292 | +std::string &rtrim(std::string &s, unsigned char c) | |
| 293 | +{ | |
| 294 | + s.erase(std::find_if(s.rbegin(), s.rend(), [=](unsigned char ch) { | |
| 295 | + return (c != ch); | |
| 296 | + }).base(), s.end()); | |
| 297 | + return s; | |
| 298 | +} | |
| 299 | + | |
| 292 | 300 | bool startsWith(const std::string &s, const std::string &needle) |
| 293 | 301 | { |
| 294 | 302 | return s.find(needle) == 0; | ... | ... |
utils.h
| ... | ... | @@ -28,6 +28,8 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 28 | 28 | #include <openssl/evp.h> |
| 29 | 29 | #include <memory> |
| 30 | 30 | #include <arpa/inet.h> |
| 31 | +#include "unistd.h" | |
| 32 | +#include "sys/stat.h" | |
| 31 | 33 | |
| 32 | 34 | #include "cirbuf.h" |
| 33 | 35 | #include "bindaddr.h" |
| ... | ... | @@ -64,6 +66,7 @@ void ltrim(std::string &s); |
| 64 | 66 | void rtrim(std::string &s); |
| 65 | 67 | void trim(std::string &s); |
| 66 | 68 | bool startsWith(const std::string &s, const std::string &needle); |
| 69 | +std::string &rtrim(std::string &s, unsigned char c); | |
| 67 | 70 | |
| 68 | 71 | std::string getSecureRandomString(const ssize_t len); |
| 69 | 72 | std::string str_tolower(std::string s); |
| ... | ... | @@ -92,5 +95,32 @@ ssize_t getFileSize(const std::string &path); |
| 92 | 95 | |
| 93 | 96 | std::string sockaddrToString(struct sockaddr *addr); |
| 94 | 97 | |
| 98 | +template<typename ex> void checkWritableDir(const std::string &path) | |
| 99 | +{ | |
| 100 | + if (path.empty()) | |
| 101 | + throw ex("Dir path to check is an empty string."); | |
| 102 | + | |
| 103 | + if (access(path.c_str(), W_OK) != 0) | |
| 104 | + { | |
| 105 | + std::string msg = formatString("Path '%s' is not there or not writable", path.c_str()); | |
| 106 | + throw ex(msg); | |
| 107 | + } | |
| 108 | + | |
| 109 | + struct stat statbuf; | |
| 110 | + memset(&statbuf, 0, sizeof(struct stat)); | |
| 111 | + if (stat(path.c_str(), &statbuf) < 0) | |
| 112 | + { | |
| 113 | + // We checked for W_OK above, so this shouldn't happen. | |
| 114 | + std::string msg = formatString("Error getting information about '%s'.", path.c_str()); | |
| 115 | + throw ex(msg); | |
| 116 | + } | |
| 117 | + | |
| 118 | + if (!S_ISDIR(statbuf.st_mode)) | |
| 119 | + { | |
| 120 | + std::string msg = formatString("Path '%s' is not a directory.", path.c_str()); | |
| 121 | + throw ex(msg); | |
| 122 | + } | |
| 123 | +} | |
| 124 | + | |
| 95 | 125 | |
| 96 | 126 | #endif // UTILS_H | ... | ... |