diff --git a/CMakeLists.txt b/CMakeLists.txt index 759c073..3657202 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,10 @@ add_executable(FlashMQ enums.h threadlocalutils.h flashmq_plugin.h + retainedmessagesdb.h + persistencefile.h + sessionsandsubscriptionsdb.h + qospacketqueue.h mainapp.cpp main.cpp @@ -81,6 +85,10 @@ add_executable(FlashMQ acltree.cpp threadlocalutils.cpp flashmq_plugin.cpp + retainedmessagesdb.cpp + persistencefile.cpp + sessionsandsubscriptionsdb.cpp + qospacketqueue.cpp ) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index 1c48319..450323e 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -42,6 +42,10 @@ SOURCES += tst_maintests.cpp \ ../acltree.cpp \ ../threadlocalutils.cpp \ ../flashmq_plugin.cpp \ + ../retainedmessagesdb.cpp \ + ../persistencefile.cpp \ + ../sessionsandsubscriptionsdb.cpp \ + ../qospacketqueue.cpp \ mainappthread.cpp \ twoclienttestcontext.cpp @@ -77,6 +81,10 @@ HEADERS += \ ../acltree.h \ ../threadlocalutils.h \ ../flashmq_plugin.h \ + ../retainedmessagesdb.h \ + ../persistencefile.h \ + ../sessionsandsubscriptionsdb.h \ + ../qospacketqueue.h \ mainappthread.h \ twoclienttestcontext.h diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 14294a9..ef1866a 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -21,12 +21,18 @@ License along with FlashMQ. If not, see . #include #include #include +#include +#include #include "cirbuf.h" #include "mainapp.h" #include "mainappthread.h" #include "twoclienttestcontext.h" #include "threadlocalutils.h" +#include "retainedmessagesdb.h" +#include "sessionsandsubscriptionsdb.h" +#include "session.h" +#include "threaddata.h" // Dumb Qt version gives warnings when comparing uint with number literal. template @@ -88,6 +94,12 @@ private slots: void testTopicsMatch(); + void testRetainedMessageDB(); + void testRetainedMessageDBNotPresent(); + void testRetainedMessageDBEmptyList(); + + void testSavingSessions(); + }; MainTests::MainTests() @@ -822,6 +834,221 @@ void MainTests::testTopicsMatch() } +void MainTests::testRetainedMessageDB() +{ + try + { + std::string longpayload = getSecureRandomString(65537); + std::string longTopic = formatString("one/two/%s", getSecureRandomString(4000).c_str()); + + std::vector messages; + messages.emplace_back("one/two/three", "payload", 0); + messages.emplace_back("one/two/wer", "payload", 1); + messages.emplace_back("one/e/wer", "payload", 1); + messages.emplace_back("one/wee/wer", "asdfasdfasdf", 1); + messages.emplace_back("one/two/wer", "µsdf", 1); + messages.emplace_back("/boe/bah", longpayload, 1); + messages.emplace_back("one/two/wer", "paylasdfaoad", 1); + messages.emplace_back("one/two/wer", "payload", 1); + messages.emplace_back(longTopic, "payload", 1); + messages.emplace_back(longTopic, longpayload, 1); + messages.emplace_back("one", "µsdf", 1); + messages.emplace_back("/boe", longpayload, 1); + messages.emplace_back("one", "µsdf", 1); + messages.emplace_back("", "foremptytopic", 0); + + RetainedMessagesDB db("/tmp/flashmqtests_retained.db"); + db.openWrite(); + db.saveData(messages); + db.closeFile(); + + RetainedMessagesDB db2("/tmp/flashmqtests_retained.db"); + db2.openRead(); + std::list messagesLoaded = db2.readData(); + db2.closeFile(); + + QCOMPARE(messages.size(), messagesLoaded.size()); + + auto itOrg = messages.begin(); + auto itLoaded = messagesLoaded.begin(); + while (itOrg != messages.end() && itLoaded != messagesLoaded.end()) + { + RetainedMessage &one = *itOrg; + RetainedMessage &two = *itLoaded; + + // Comparing the fields because the RetainedMessage class has an == operator that only looks at topic. + QCOMPARE(one.topic, two.topic); + QCOMPARE(one.payload, two.payload); + QCOMPARE(one.qos, two.qos); + + itOrg++; + itLoaded++; + } + } + catch (std::exception &ex) + { + QVERIFY2(false, ex.what()); + } +} + +void MainTests::testRetainedMessageDBNotPresent() +{ + try + { + RetainedMessagesDB db2("/tmp/flashmqtests_asdfasdfasdf.db"); + db2.openRead(); + std::list messagesLoaded = db2.readData(); + db2.closeFile(); + + MYCASTCOMPARE(messagesLoaded.size(), 0); + + QVERIFY2(false, "We should have run into an exception."); + } + catch (PersistenceFileCantBeOpened &ex) + { + QVERIFY(true); + } + catch (std::exception &ex) + { + QVERIFY2(false, ex.what()); + } +} + +void MainTests::testRetainedMessageDBEmptyList() +{ + try + { + std::vector messages; + + RetainedMessagesDB db("/tmp/flashmqtests_retained.db"); + db.openWrite(); + db.saveData(messages); + db.closeFile(); + + RetainedMessagesDB db2("/tmp/flashmqtests_retained.db"); + db2.openRead(); + std::list messagesLoaded = db2.readData(); + db2.closeFile(); + + MYCASTCOMPARE(messages.size(), messagesLoaded.size()); + MYCASTCOMPARE(messages.size(), 0); + } + catch (std::exception &ex) + { + QVERIFY2(false, ex.what()); + } +} + +void MainTests::testSavingSessions() +{ + try + { + std::shared_ptr settings(new Settings()); + std::shared_ptr store(new SubscriptionStore()); + std::shared_ptr t(new ThreadData(0, store, settings)); + + std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings, false)); + c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); + store->registerClientAndKickExistingOne(c1); + c1->getSession()->touch(); + c1->getSession()->addIncomingQoS2MessageId(2); + c1->getSession()->addIncomingQoS2MessageId(3); + + std::shared_ptr c2(new Client(0, t, nullptr, false, nullptr, settings, false)); + c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60, false); + store->registerClientAndKickExistingOne(c2); + c2->getSession()->touch(); + c2->getSession()->addOutgoingQoS2MessageId(55); + c2->getSession()->addOutgoingQoS2MessageId(66); + + const std::string topic1 = "one/two/three"; + std::vector subtopics; + splitTopic(topic1, subtopics); + store->addSubscription(c1, topic1, subtopics, 0); + + const std::string topic2 = "four/five/six"; + splitTopic(topic2, subtopics); + store->addSubscription(c2, topic2, subtopics, 0); + store->addSubscription(c1, topic2, subtopics, 0); + + const std::string topic3 = ""; + splitTopic(topic3, subtopics); + store->addSubscription(c2, topic3, subtopics, 0); + + const std::string topic4 = "#"; + splitTopic(topic4, subtopics); + store->addSubscription(c2, topic4, subtopics, 0); + + uint64_t count = 0; + + Publish publish("a/b/c", "Hello Barry", 1); + + std::shared_ptr c1ses = c1->getSession(); + c1.reset(); + c1ses->writePacket(publish, 1, false, count); + + store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); + + std::shared_ptr store2(new SubscriptionStore()); + store2->loadSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); + + MYCASTCOMPARE(store->sessionsById.size(), 2); + MYCASTCOMPARE(store2->sessionsById.size(), 2); + + for (auto &pair : store->sessionsById) + { + std::shared_ptr &ses = pair.second; + std::shared_ptr &ses2 = store2->sessionsById[pair.first]; + QCOMPARE(pair.first, ses2->getClientId()); + + QCOMPARE(ses->username, ses2->username); + QCOMPARE(ses->client_id, ses2->client_id); + QCOMPARE(ses->incomingQoS2MessageIds, ses2->incomingQoS2MessageIds); + QCOMPARE(ses->outgoingQoS2MessageIds, ses2->outgoingQoS2MessageIds); + QCOMPARE(ses->nextPacketId, ses2->nextPacketId); + } + + + std::unordered_map> store1Subscriptions; + store->getSubscriptions(&store->root, "", true, store1Subscriptions); + + std::unordered_map> store2Subscriptions; + store->getSubscriptions(&store->root, "", true, store2Subscriptions); + + MYCASTCOMPARE(store1Subscriptions.size(), 4); + MYCASTCOMPARE(store2Subscriptions.size(), 4); + + for(auto &pair : store1Subscriptions) + { + std::list &subscList1 = pair.second; + std::list &subscList2 = store2Subscriptions[pair.first]; + + QCOMPARE(subscList1.size(), subscList2.size()); + + auto subs1It = subscList1.begin(); + auto subs2It = subscList2.begin(); + + while (subs1It != subscList1.end()) + { + SubscriptionForSerializing &one = *subs1It; + SubscriptionForSerializing &two = *subs2It; + QCOMPARE(one.clientId, two.clientId); + QCOMPARE(one.qos, two.qos); + + subs1It++; + subs2It++; + } + + } + + + } + catch (std::exception &ex) + { + QVERIFY2(false, ex.what()); + } +} + QTEST_GUILESS_MAIN(MainTests) #include "tst_maintests.moc" diff --git a/configfileparser.cpp b/configfileparser.cpp index b99810e..9211f17 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -107,6 +107,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : validKeys.insert("allow_anonymous"); validKeys.insert("rlimit_nofile"); validKeys.insert("expire_sessions_after_seconds"); + validKeys.insert("storage_dir"); validListenKeys.insert("port"); validListenKeys.insert("protocol"); @@ -412,6 +413,14 @@ void ConfigFileParser::loadFile(bool test) } tmpSettings->authPluginTimerPeriod = newVal; } + + if (key == "storage_dir") + { + std::string newPath = value; + rtrim(newPath, '/'); + checkWritableDir(newPath); + tmpSettings->storageDir = newPath; + } } } catch (std::invalid_argument &ex) // catch for the stoi() diff --git a/mainapp.cpp b/mainapp.cpp index de05ac7..de4f00e 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -202,6 +202,15 @@ MainApp::MainApp(const std::string &configFilePath) : auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this); timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event."); } + + if (!settings->storageDir.empty()) + { + subscriptionStore->loadRetainedMessages(settings->getRetainedMessagesDBFile()); + subscriptionStore->loadSessionsAndSubscriptions(settings->getSessionsDBFile()); + } + + auto fSaveState = std::bind(&MainApp::saveState, this); + timer.addCallback(fSaveState, 900000, "Save state."); } MainApp::~MainApp() @@ -369,6 +378,25 @@ void MainApp::publishStat(const std::string &topic, uint64_t n) subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); } +void MainApp::saveState() +{ + try + { + if (!settings->storageDir.empty()) + { + const std::string retainedDBPath = settings->getRetainedMessagesDBFile(); + subscriptionStore->saveRetainedMessages(retainedDBPath); + + const std::string sessionsDBPath = settings->getSessionsDBFile(); + subscriptionStore->saveSessionsAndSubscriptions(sessionsDBPath); + } + } + catch(std::exception &ex) + { + logger->logf(LOG_ERR, "Error saving state: %s", ex.what()); + } +} + void MainApp::initMainApp(int argc, char *argv[]) { if (instance != nullptr) @@ -659,6 +687,8 @@ void MainApp::start() { thread->waitForQuit(); } + + saveState(); } void MainApp::quit() diff --git a/mainapp.h b/mainapp.h index 2a00475..7a31509 100644 --- a/mainapp.h +++ b/mainapp.h @@ -84,6 +84,7 @@ class MainApp void setFuzzFile(const std::string &fuzzFilePath); void publishStatsOnDollarTopic(); void publishStat(const std::string &topic, uint64_t n); + void saveState(); MainApp(const std::string &configFilePath); public: diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 19817d7..e69c310 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -47,8 +47,11 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt pos += fixed_header_length; } -// This is easier than using the copy constructor publically, because then I have to keep maintaining a functioning copy constructor. -// Returning shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. +/** + * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor + * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add. + * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. + */ std::shared_ptr MqttPacket::getCopy() const { std::shared_ptr copyPacket(new MqttPacket(*this)); @@ -129,6 +132,9 @@ MqttPacket::MqttPacket(const Publish &publish) : writeBytes(zero, 2); } + payloadStart = pos; + payloadLen = publish.payload.length(); + writeBytes(publish.payload.c_str(), publish.payload.length()); calculateRemainingLength(); } @@ -546,13 +552,14 @@ void MqttPacket::handlePublish() } } + payloadLen = remainingAfterPos(); + payloadStart = pos; + if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) { if (retain) { - size_t payload_length = remainingAfterPos(); - std::string payload(readBytes(payload_length), payload_length); - + std::string payload(readBytes(payloadLen), payloadLen); sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, *subtopics, payload, qos); } @@ -679,6 +686,23 @@ size_t MqttPacket::getTotalMemoryFootprint() return bites.size() + sizeof(MqttPacket); } +/** + * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string. + * @return + * + * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out + * the whole byte array that is a packet to subscribers. No need to copy and such. + * + * I created it for saving QoS packages in the db file. + */ +std::string MqttPacket::getPayloadCopy() +{ + assert(payloadStart > 0); + assert(pos <= bites.size()); + std::string payload(&bites[payloadStart], payloadLen); + return payload; +} + size_t MqttPacket::getSizeIncludingNonPresentHeader() const { size_t total = bites.size(); diff --git a/mqttpacket.h b/mqttpacket.h index 6c8ddb8..9b0d7a5 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -59,6 +59,8 @@ class MqttPacket size_t packet_id_pos = 0; uint16_t packet_id = 0; ProtocolVersion protocolVersion = ProtocolVersion::None; + size_t payloadStart = 0; + size_t payloadLen = 0; Logger *logger = Logger::getInstance(); char *readBytes(size_t length); @@ -116,6 +118,7 @@ public: uint16_t getPacketId() const; void setDuplicate(); size_t getTotalMemoryFootprint(); + std::string getPayloadCopy(); }; #endif // MQTTPACKET_H diff --git a/persistencefile.cpp b/persistencefile.cpp new file mode 100644 index 0000000..03ed4f7 --- /dev/null +++ b/persistencefile.cpp @@ -0,0 +1,320 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#include "persistencefile.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" +#include "logger.h" + +PersistenceFile::PersistenceFile(const std::string &filePath) : + digestContext(EVP_MD_CTX_new()), + buf(1024*1024) +{ + if (!filePath.empty() && filePath[filePath.size() - 1] == '/') + throw std::runtime_error("Target file can't contain trailing slash."); + + this->filePath = filePath; + this->filePathTemp = formatString("%s.newfile.%s", filePath.c_str(), getSecureRandomString(8).c_str()); + this->filePathCorrupt = formatString("%s.corrupt.%s", filePath.c_str(), getSecureRandomString(8).c_str()); +} + +PersistenceFile::~PersistenceFile() +{ + closeFile(); + + if (digestContext) + { + EVP_MD_CTX_free(digestContext); + } +} + +/** + * @brief RetainedMessagesDB::hashFile hashes the data after the headers and writes the hash in the header. Uses SHA512. + * + */ +void PersistenceFile::writeCheck(const void *ptr, size_t size, size_t n, FILE *s) +{ + if (fwrite(ptr, size, n, s) != n) + { + throw std::runtime_error(formatString("Error writing: %s", strerror(errno))); + } +} + +ssize_t PersistenceFile::readCheck(void *ptr, size_t size, size_t n, FILE *stream) +{ + size_t nread = fread(ptr, size, n, stream); + + if (nread != n) + { + if (feof(f)) + return -1; + + throw std::runtime_error(formatString("Error reading: %s", strerror(errno))); + } + + return nread; +} + +void PersistenceFile::hashFile() +{ + logger->logf(LOG_DEBUG, "Calculating and saving hash of '%s'.", filePath.c_str()); + + fseek(f, TOTAL_HEADER_SIZE, SEEK_SET); + + unsigned int output_len = 0; + unsigned char md_value[EVP_MAX_MD_SIZE]; + std::memset(md_value, 0, EVP_MAX_MD_SIZE); + + EVP_MD_CTX_reset(digestContext); + EVP_DigestInit_ex(digestContext, sha512, NULL); + + while (!feof(f)) + { + size_t n = fread(buf.data(), 1, buf.size(), f); + EVP_DigestUpdate(digestContext, buf.data(), n); + } + + EVP_DigestFinal_ex(digestContext, md_value, &output_len); + + if (output_len != HASH_SIZE) + throw std::runtime_error("Impossible: calculated hash size wrong length"); + + fseek(f, MAGIC_STRING_LENGH, SEEK_SET); + + writeCheck(md_value, output_len, 1, f); + fflush(f); +} + +void PersistenceFile::verifyHash() +{ + fseek(f, 0, SEEK_END); + const size_t size = ftell(f); + + if (size < TOTAL_HEADER_SIZE) + throw std::runtime_error(formatString("File '%s' is too small for it even to contain a header.", filePath.c_str())); + + unsigned char md_from_disk[HASH_SIZE]; + std::memset(md_from_disk, 0, HASH_SIZE); + + fseek(f, MAGIC_STRING_LENGH, SEEK_SET); + fread(md_from_disk, 1, HASH_SIZE, f); + + unsigned int output_len = 0; + unsigned char md_value[EVP_MAX_MD_SIZE]; + std::memset(md_value, 0, EVP_MAX_MD_SIZE); + + EVP_MD_CTX_reset(digestContext); + EVP_DigestInit_ex(digestContext, sha512, NULL); + + while (!feof(f)) + { + size_t n = fread(buf.data(), 1, buf.size(), f); + EVP_DigestUpdate(digestContext, buf.data(), n); + } + + EVP_DigestFinal_ex(digestContext, md_value, &output_len); + + if (output_len != HASH_SIZE) + throw std::runtime_error("Impossible: calculated hash size wrong length"); + + if (std::memcmp(md_from_disk, md_value, output_len) != 0) + { + fclose(f); + f = nullptr; + + if (rename(filePath.c_str(), filePathCorrupt.c_str()) == 0) + { + throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Moved aside to '%s'.", filePath.c_str(), filePathCorrupt.c_str())); + } + else + { + throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Tried to move aside, but that failed: '%s'.", + filePath.c_str(), strerror(errno))); + } + } + + logger->logf(LOG_DEBUG, "Hash of '%s' correct", filePath.c_str()); +} + +/** + * @brief PersistenceFile::makeSureBufSize grows the buffer if n is bigger. + * @param n in bytes. + * + * Remember that when you're dealing with fields that are sized in MQTT by 16 bit ints, like topic paths, the buffer will always be big enough, because it's 1 MB. + */ +void PersistenceFile::makeSureBufSize(size_t n) +{ + if (n > buf.size()) + buf.resize(n); +} + +void PersistenceFile::writeInt64(const int64_t val) +{ + unsigned char buf[8]; + + // Write big-endian + int shift = 56; + int i = 0; + while (shift >= 0) + { + unsigned char wantedByte = val >> shift; + buf[i++] = wantedByte; + shift -= 8; + } + writeCheck(buf, 1, 8, f); +} + +void PersistenceFile::writeUint32(const uint32_t val) +{ + unsigned char buf[4]; + + // Write big-endian + int shift = 24; + int i = 0; + while (shift >= 0) + { + unsigned char wantedByte = val >> shift; + buf[i++] = wantedByte; + shift -= 8; + } + writeCheck(buf, 1, 4, f); +} + +void PersistenceFile::writeUint16(const uint16_t val) +{ + unsigned char buf[2]; + + // Write big-endian + int shift = 8; + int i = 0; + while (shift >= 0) + { + unsigned char wantedByte = val >> shift; + buf[i++] = wantedByte; + shift -= 8; + } + writeCheck(buf, 1, 2, f); +} + +int64_t PersistenceFile::readInt64(bool &eofFound) +{ + if (readCheck(buf.data(), 1, 8, f) < 0) + eofFound = true; + + unsigned char *buf_ = reinterpret_cast(buf.data()); + const uint64_t val1 = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]); + const uint64_t val2 = ((buf_[4]) << 24) | ((buf_[5]) << 16) | ((buf_[6]) << 8) | (buf_[7]); + const int64_t val = (val1 << 32) | val2; + return val; +} + +uint32_t PersistenceFile::readUint32(bool &eofFound) +{ + if (readCheck(buf.data(), 1, 4, f) < 0) + eofFound = true; + + uint32_t val; + unsigned char *buf_ = reinterpret_cast(buf.data()); + val = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]); + return val; +} + +uint16_t PersistenceFile::readUint16(bool &eofFound) +{ + if (readCheck(buf.data(), 1, 2, f) < 0) + eofFound = true; + + uint16_t val; + unsigned char *buf_ = reinterpret_cast(buf.data()); + val = ((buf_[0]) << 8) | (buf_[1]); + return val; +} + +/** + * @brief RetainedMessagesDB::openWrite doesn't explicitely name a file version (v1, etc), because we always write the current definition. + */ +void PersistenceFile::openWrite(const std::string &versionString) +{ + if (openMode != FileMode::unknown) + throw std::runtime_error("File is already open."); + + f = fopen(filePathTemp.c_str(), "w+b"); + + if (f == nullptr) + { + throw std::runtime_error(formatString("Can't open '%s': %s", filePathTemp.c_str(), strerror(errno))); + } + + openMode = FileMode::write; + + writeCheck(buf.data(), 1, MAGIC_STRING_LENGH, f); + rewind(f); + writeCheck(versionString.c_str(), 1, versionString.length(), f); + fseek(f, MAGIC_STRING_LENGH, SEEK_SET); + writeCheck(buf.data(), 1, HASH_SIZE, f); +} + +void PersistenceFile::openRead() +{ + if (openMode != FileMode::unknown) + throw std::runtime_error("File is already open."); + + f = fopen(filePath.c_str(), "rb"); + + if (f == nullptr) + throw PersistenceFileCantBeOpened(formatString("Can't open '%s': %s.", filePath.c_str(), strerror(errno)).c_str()); + + openMode = FileMode::read; + + verifyHash(); + rewind(f); + + fread(buf.data(), 1, MAGIC_STRING_LENGH, f); + detectedVersionString = std::string(buf.data(), strlen(buf.data())); + + fseek(f, TOTAL_HEADER_SIZE, SEEK_SET); +} + +void PersistenceFile::closeFile() +{ + if (!f) + return; + + if (openMode == FileMode::write) + hashFile(); + + if (f != nullptr) + { + fclose(f); + f = nullptr; + } + + if (openMode == FileMode::write && !filePathTemp.empty() && ! filePath.empty()) + { + if (rename(filePathTemp.c_str(), filePath.c_str()) < 0) + throw std::runtime_error(formatString("Saving '%s' failed: rename of temp file to target failed with: %s", filePath.c_str(), strerror(errno))); + } +} diff --git a/persistencefile.h b/persistencefile.h new file mode 100644 index 0000000..6000d01 --- /dev/null +++ b/persistencefile.h @@ -0,0 +1,92 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#ifndef PERSISTENCEFILE_H +#define PERSISTENCEFILE_H + +#include +#include +#include +#include +#include +#include +#include + +#include "logger.h" + +#define MAGIC_STRING_LENGH 32 +#define HASH_SIZE 64 +#define TOTAL_HEADER_SIZE (MAGIC_STRING_LENGH + HASH_SIZE) + +/** + * @brief The PersistenceFileCantBeOpened class should be thrown when a non-fatal file-not-found error happens. + */ +class PersistenceFileCantBeOpened : public std::runtime_error +{ +public: + PersistenceFileCantBeOpened(const std::string &msg) : std::runtime_error(msg) {} +}; + +class PersistenceFile +{ + std::string filePath; + std::string filePathTemp; + std::string filePathCorrupt; + + EVP_MD_CTX *digestContext = nullptr; + const EVP_MD *sha512 = EVP_sha512(); + + void hashFile(); + void verifyHash(); + +protected: + enum class FileMode + { + unknown, + read, + write + }; + + FILE *f = nullptr; + std::vector buf; + FileMode openMode = FileMode::unknown; + std::string detectedVersionString; + + Logger *logger = Logger::getInstance(); + + void makeSureBufSize(size_t n); + + void writeCheck(const void *__restrict __ptr, size_t __size, size_t __n, FILE *__restrict __s); + ssize_t readCheck(void *__restrict ptr, size_t size, size_t n, FILE *__restrict stream); + + void writeInt64(const int64_t val); + void writeUint32(const uint32_t val); + void writeUint16(const uint16_t val); + int64_t readInt64(bool &eofFound); + uint32_t readUint32(bool &eofFound); + uint16_t readUint16(bool &eofFound); + +public: + PersistenceFile(const std::string &filePath); + ~PersistenceFile(); + + void openWrite(const std::string &versionString); + void openRead(); + void closeFile(); +}; + +#endif // PERSISTENCEFILE_H diff --git a/qospacketqueue.cpp b/qospacketqueue.cpp new file mode 100644 index 0000000..696bedd --- /dev/null +++ b/qospacketqueue.cpp @@ -0,0 +1,77 @@ +#include "qospacketqueue.h" + +#include "cassert" + +#include "mqttpacket.h" + +void QoSPacketQueue::erase(const uint16_t packet_id) +{ + auto it = queue.begin(); + auto end = queue.end(); + while (it != end) + { + std::shared_ptr &p = *it; + if (p->getPacketId() == packet_id) + { + size_t mem = p->getTotalMemoryFootprint(); + qosQueueBytes -= mem; + assert(qosQueueBytes >= 0); + if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. + qosQueueBytes = 0; + + queue.erase(it); + + break; + } + + it++; + } +} + +size_t QoSPacketQueue::size() const +{ + return queue.size(); +} + +size_t QoSPacketQueue::getByteSize() const +{ + return qosQueueBytes; +} + +/** + * @brief QoSPacketQueue::queuePacket makes a copy of the packet because it has state for the receiver in question. + * @param p + * @param id + * @return the packet copy. + */ +std::shared_ptr QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id) +{ + assert(p.getQos() > 0); + + std::shared_ptr copyPacket = p.getCopy(); + copyPacket->setPacketId(id); + queue.push_back(copyPacket); + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); + return copyPacket; +} + +std::shared_ptr QoSPacketQueue::queuePacket(const Publish &pub, uint16_t id) +{ + assert(pub.qos > 0); + + std::shared_ptr copyPacket(new MqttPacket(pub)); + copyPacket->setPacketId(id); + queue.push_back(copyPacket); + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); + return copyPacket; +} + +std::list>::const_iterator QoSPacketQueue::begin() const +{ + return queue.cbegin(); +} + +std::list>::const_iterator QoSPacketQueue::end() const +{ + return queue.cend(); +} diff --git a/qospacketqueue.h b/qospacketqueue.h new file mode 100644 index 0000000..f8b18d1 --- /dev/null +++ b/qospacketqueue.h @@ -0,0 +1,25 @@ +#ifndef QOSPACKETQUEUE_H +#define QOSPACKETQUEUE_H + +#include "list" + +#include "forward_declarations.h" +#include "types.h" + +class QoSPacketQueue +{ + std::list> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] + ssize_t qosQueueBytes = 0; + +public: + void erase(const uint16_t packet_id); + size_t size() const; + size_t getByteSize() const; + std::shared_ptr queuePacket(const MqttPacket &p, uint16_t id); + std::shared_ptr queuePacket(const Publish &pub, uint16_t id); + + std::list>::const_iterator begin() const; + std::list>::const_iterator end() const; +}; + +#endif // QOSPACKETQUEUE_H diff --git a/retainedmessage.cpp b/retainedmessage.cpp index 29647d1..c01f9e2 100644 --- a/retainedmessage.cpp +++ b/retainedmessage.cpp @@ -34,3 +34,8 @@ bool RetainedMessage::empty() const { return payload.empty(); } + +uint32_t RetainedMessage::getSize() const +{ + return topic.length() + payload.length() + 1; +} diff --git a/retainedmessage.h b/retainedmessage.h index 01491b0..97b5369 100644 --- a/retainedmessage.h +++ b/retainedmessage.h @@ -30,6 +30,7 @@ struct RetainedMessage bool operator==(const RetainedMessage &rhs) const; bool empty() const; + uint32_t getSize() const; }; namespace std { diff --git a/retainedmessagesdb.cpp b/retainedmessagesdb.cpp new file mode 100644 index 0000000..70ae9b0 --- /dev/null +++ b/retainedmessagesdb.cpp @@ -0,0 +1,146 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "retainedmessagesdb.h" +#include "utils.h" +#include "logger.h" + +RetainedMessagesDB::RetainedMessagesDB(const std::string &filePath) : PersistenceFile(filePath) +{ + +} + +void RetainedMessagesDB::openWrite() +{ + PersistenceFile::openWrite(MAGIC_STRING_V1); +} + +void RetainedMessagesDB::openRead() +{ + PersistenceFile::openRead(); + + if (detectedVersionString == MAGIC_STRING_V1) + readVersion = ReadVersion::v1; + else + throw std::runtime_error("Unknown file version."); +} + +/** + * @brief RetainedMessagesDB::writeRowHeader writes two 32 bit integers: topic size and payload size. + * @param rm + * + * So, the header per message is 8 bytes long. + * + * It writes no information about the length of the QoS value, because that is always one. + */ +void RetainedMessagesDB::writeRowHeader(const RetainedMessage &rm) +{ + writeUint32(rm.topic.size()); + writeUint32(rm.payload.size()); +} + +RetainedMessagesDB::RowHeader RetainedMessagesDB::readRowHeaderV1(bool &eofFound) +{ + RetainedMessagesDB::RowHeader result; + + result.topicLen = readUint32(eofFound); + result.payloadLen = readUint32(eofFound); + + return result; +} + +/** + * @brief RetainedMessagesDB::saveData doesn't explicitely name a file version (v1, etc), because we always write the current definition. + * @param messages + */ +void RetainedMessagesDB::saveData(const std::vector &messages) +{ + if (!f) + return; + + char reserved[RESERVED_SPACE_RETAINED_DB_V1]; + std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V1); + + char qos = 0; + for (const RetainedMessage &rm : messages) + { + logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d.", rm.topic.c_str(), rm.qos); + + writeRowHeader(rm); + qos = rm.qos; + writeCheck(&qos, 1, 1, f); + writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V1, f); + writeCheck(rm.topic.c_str(), 1, rm.topic.length(), f); + writeCheck(rm.payload.c_str(), 1, rm.payload.length(), f); + } + + fflush(f); +} + +std::list RetainedMessagesDB::readData() +{ + std::list defaultResult; + + if (!f) + return defaultResult; + + if (readVersion == ReadVersion::v1) + return readDataV1(); + + return defaultResult; +} + +std::list RetainedMessagesDB::readDataV1() +{ + std::list messages; + + while (!feof(f)) + { + bool eofFound = false; + RetainedMessagesDB::RowHeader header = readRowHeaderV1(eofFound); + + if (eofFound) + continue; + + makeSureBufSize(header.payloadLen); + + readCheck(buf.data(), 1, 1, f); + char qos = buf[0]; + fseek(f, RESERVED_SPACE_RETAINED_DB_V1, SEEK_CUR); + + readCheck(buf.data(), 1, header.topicLen, f); + std::string topic(buf.data(), header.topicLen); + + readCheck(buf.data(), 1, header.payloadLen, f); + std::string payload(buf.data(), header.payloadLen); + + RetainedMessage msg(topic, payload, qos); + logger->logf(LOG_DEBUG, "Loading retained message for topic '%s' QoS %d.", msg.topic.c_str(), msg.qos); + messages.push_back(std::move(msg)); + } + + return messages; +} diff --git a/retainedmessagesdb.h b/retainedmessagesdb.h new file mode 100644 index 0000000..7b2299d --- /dev/null +++ b/retainedmessagesdb.h @@ -0,0 +1,71 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#ifndef RETAINEDMESSAGESDB_H +#define RETAINEDMESSAGESDB_H + +#include "persistencefile.h" +#include "retainedmessage.h" + +#include "logger.h" + +#define MAGIC_STRING_V1 "FlashMQRetainedDBv1" +#define ROW_HEADER_SIZE 8 +#define RESERVED_SPACE_RETAINED_DB_V1 31 + +/** + * @brief The RetainedMessagesDB class saves and loads the retained messages. + * + * The DB looks like, from the top: + * + * MAGIC_STRING_LENGH bytes file header + * HASH_SIZE SHA512 + * [MESSAGES] + * + * Each message has a row header, which is 8 bytes. See writeRowHeader(). + * + */ +class RetainedMessagesDB : public PersistenceFile +{ + enum class ReadVersion + { + unknown, + v1 + }; + + struct RowHeader + { + uint32_t topicLen = 0; + uint32_t payloadLen = 0; + }; + + ReadVersion readVersion = ReadVersion::unknown; + + void writeRowHeader(const RetainedMessage &rm); + RowHeader readRowHeaderV1(bool &eofFound); + std::list readDataV1(); +public: + RetainedMessagesDB(const std::string &filePath); + + void openWrite(); + void openRead(); + + void saveData(const std::vector &messages); + std::list readData(); +}; + +#endif // RETAINEDMESSAGESDB_H diff --git a/rwlockguard.cpp b/rwlockguard.cpp index 15bae1c..c25c324 100644 --- a/rwlockguard.cpp +++ b/rwlockguard.cpp @@ -32,7 +32,15 @@ RWLockGuard::~RWLockGuard() void RWLockGuard::wrlock() { - if (pthread_rwlock_wrlock(rwlock) != 0) + const int rc = pthread_rwlock_wrlock(rwlock); + + if (rc == EDEADLK) + { + rwlock = nullptr; + return; + } + + if (rc != 0) throw std::runtime_error("wrlock failed."); } diff --git a/session.cpp b/session.cpp index 3b328f8..e0fc9fe 100644 --- a/session.cpp +++ b/session.cpp @@ -20,16 +20,75 @@ License along with FlashMQ. If not, see . #include "session.h" #include "client.h" +std::chrono::time_point appStartTime = std::chrono::steady_clock::now(); + Session::Session() { } +int64_t Session::getProgramStartedAtUnixTimestamp() +{ + auto secondsSinceEpoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + const std::chrono::seconds age = std::chrono::duration_cast(std::chrono::steady_clock::now() - appStartTime); + int64_t result = secondsSinceEpoch - age.count(); + return result; +} + +void Session::setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp) +{ + auto secondsSinceEpoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()); + const std::chrono::seconds _unix_timestamp = std::chrono::seconds(unix_timestamp); + const std::chrono::seconds age_in_s = secondsSinceEpoch - _unix_timestamp; + appStartTime = std::chrono::steady_clock::now() - age_in_s; +} + + +int64_t Session::getSessionRelativeAgeInMs() const +{ + const std::chrono::milliseconds sessionAge = std::chrono::duration_cast(lastTouched - appStartTime); + const int64_t sInMs = sessionAge.count(); + return sInMs; +} + +void Session::setSessionTouch(int64_t ageInMs) +{ + std::chrono::milliseconds ms(ageInMs); + std::chrono::time_point point = appStartTime + ms; + lastTouched = point; +} + +/** + * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies. + * @param other + * + * Because it was created for session storing, the fields we're copying are the fields being stored. + */ +Session::Session(const Session &other) +{ + this->username = other.username; + this->client_id = other.client_id; + this->incomingQoS2MessageIds = other.incomingQoS2MessageIds; + this->outgoingQoS2MessageIds = other.outgoingQoS2MessageIds; + this->nextPacketId = other.nextPacketId; + this->lastTouched = other.lastTouched; + + // To be fully correct, we should copy the individual packets, but copying sessions is only done for saving them, and I know + // that no member of MqttPacket changes in the QoS process, so we can just keep the shared pointer to the original. + this->qosPacketQueue = other.qosPacketQueue; +} + Session::~Session() { logger->logf(LOG_DEBUG, "Session %s is being destroyed.", getClientId().c_str()); } +std::unique_ptr Session::getCopy() const +{ + std::unique_ptr s(new Session(*this)); + return s; +} + bool Session::clientDisconnected() const { return client.expired(); @@ -45,7 +104,6 @@ void Session::assignActiveConnection(std::shared_ptr &client) this->client = client; this->client_id = client->getClientId(); this->username = client->getUsername(); - this->thread = client->getThreadData(); } /** @@ -60,7 +118,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u assert(max_qos <= 2); const char qos = std::min(packet.getQos(), max_qos); - if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) + assert(packet.getSender()); + Authentication &auth = packet.getSender()->getThreadData()->authentication; + if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) { if (qos == 0) { @@ -73,11 +133,10 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u } else if (qos > 0) { - std::shared_ptr copyPacket = packet.getCopy(); std::unique_lock locker(qosQueueMutex); const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); - if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) { logger->logf(LOG_WARNING, "Dropping QoS message for client '%s', because its QoS buffers were full.", client_id.c_str()); return; @@ -86,13 +145,7 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u if (nextPacketId == 0) nextPacketId++; - const uint16_t pid = nextPacketId; - copyPacket->setPacketId(pid); - QueuedQosPacket p; - p.packet = copyPacket; - p.id = pid; - qosPacketQueue.push_back(p); - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); + std::shared_ptr copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); locker.unlock(); if (!clientDisconnected()) @@ -115,27 +168,7 @@ void Session::clearQosMessage(uint16_t packet_id) #endif std::lock_guard locker(qosQueueMutex); - - auto it = qosPacketQueue.begin(); - auto end = qosPacketQueue.end(); - while (it != end) - { - QueuedQosPacket &p = *it; - if (p.id == packet_id) - { - size_t mem = p.packet->getTotalMemoryFootprint(); - qosQueueBytes -= mem; - assert(qosQueueBytes >= 0); - if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. - qosQueueBytes = 0; - - qosPacketQueue.erase(it); - - break; - } - - it++; - } + qosPacketQueue.erase(packet_id); } // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any @@ -152,10 +185,10 @@ uint64_t Session::sendPendingQosMessages() { std::shared_ptr c = makeSharedClient(); std::lock_guard locker(qosQueueMutex); - for (QueuedQosPacket &qosMessage : qosPacketQueue) + for (const std::shared_ptr &qosMessage : qosPacketQueue) { - c->writeMqttPacketAndBlameThisClient(*qosMessage.packet.get(), qosMessage.packet->getQos()); - qosMessage.packet->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos()); + qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. count++; } diff --git a/session.h b/session.h index c0faad3..ec2770e 100644 --- a/session.h +++ b/session.h @@ -25,38 +25,46 @@ License along with FlashMQ. If not, see . #include "forward_declarations.h" #include "logger.h" +#include "sessionsandsubscriptionsdb.h" +#include "qospacketqueue.h" // TODO make settings. But, num of packets can't exceed 65536, because the counter is 16 bit. #define MAX_QOS_MSG_PENDING_PER_CLIENT 32 #define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096 -struct QueuedQosPacket -{ - uint16_t id; - std::shared_ptr packet; -}; - class Session { +#ifdef TESTING + friend class MainTests; +#endif + + friend class SessionsAndSubscriptionsDB; + std::weak_ptr client; - std::shared_ptr thread; std::string client_id; std::string username; - std::list qosPacketQueue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] + QoSPacketQueue qosPacketQueue; std::set incomingQoS2MessageIds; std::set outgoingQoS2MessageIds; std::mutex qosQueueMutex; uint16_t nextPacketId = 0; - ssize_t qosQueueBytes = 0; - std::chrono::time_point lastTouched; + std::chrono::time_point lastTouched = std::chrono::steady_clock::now(); Logger *logger = Logger::getInstance(); + int64_t getSessionRelativeAgeInMs() const; + void setSessionTouch(int64_t ageInMs); + Session(const Session &other); public: Session(); - Session(const Session &other) = delete; + Session(Session &&other) = delete; ~Session(); + static int64_t getProgramStartedAtUnixTimestamp(); + static void setProgramStartedAtUnixTimestamp(const int64_t unix_timestamp); + + std::unique_ptr getCopy() const; + const std::string &getClientId() const { return client_id; } bool clientDisconnected() const; std::shared_ptr makeSharedClient() const; @@ -74,7 +82,6 @@ public: void addOutgoingQoS2MessageId(uint16_t packet_id); void removeOutgoingQoS2MessageId(u_int16_t packet_id); - }; #endif // SESSION_H diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp new file mode 100644 index 0000000..b9b2bf5 --- /dev/null +++ b/sessionsandsubscriptionsdb.cpp @@ -0,0 +1,299 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#include "sessionsandsubscriptionsdb.h" +#include "mqttpacket.h" + +#include "cassert" + +SubscriptionForSerializing::SubscriptionForSerializing(const std::string &clientId, char qos) : + clientId(clientId), + qos(qos) +{ + +} + +SubscriptionForSerializing::SubscriptionForSerializing(const std::string &&clientId, char qos) : + clientId(clientId), + qos(qos) +{ + +} + +SessionsAndSubscriptionsDB::SessionsAndSubscriptionsDB(const std::string &filePath) : PersistenceFile(filePath) +{ + +} + +void SessionsAndSubscriptionsDB::openWrite() +{ + PersistenceFile::openWrite(MAGIC_STRING_SESSION_FILE_V1); +} + +void SessionsAndSubscriptionsDB::openRead() +{ + PersistenceFile::openRead(); + + if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V1) + readVersion = ReadVersion::v1; + else + throw std::runtime_error("Unknown file version."); +} + +SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1() +{ + SessionsAndSubscriptionsResult result; + + while (!feof(f)) + { + bool eofFound = false; + + const int64_t programStartAge = readInt64(eofFound); + if (eofFound) + continue; + + logger->logf(LOG_DEBUG, "Setting first app start time to timestamp %ld", programStartAge); + Session::setProgramStartedAtUnixTimestamp(programStartAge); + + const uint32_t nrOfSessions = readUint32(eofFound); + + if (eofFound) + continue; + + std::vector reserved(RESERVED_SPACE_SESSIONS_DB_V1); + + for (uint32_t i = 0; i < nrOfSessions; i++) + { + readCheck(buf.data(), 1, RESERVED_SPACE_SESSIONS_DB_V1, f); + + uint32_t usernameLength = readUint32(eofFound); + readCheck(buf.data(), 1, usernameLength, f); + std::string username(buf.data(), usernameLength); + + uint32_t clientIdLength = readUint32(eofFound); + readCheck(buf.data(), 1, clientIdLength, f); + std::string clientId(buf.data(), clientIdLength); + + std::shared_ptr ses(new Session()); + result.sessions.push_back(ses); + ses->username = username; + ses->client_id = clientId; + + logger->logf(LOG_DEBUG, "Loading session '%s'.", ses->getClientId().c_str()); + + const uint32_t nrOfQueuedQoSPackets = readUint32(eofFound); + for (uint32_t i = 0; i < nrOfQueuedQoSPackets; i++) + { + const uint16_t id = readUint16(eofFound); + const uint32_t topicSize = readUint32(eofFound); + const uint32_t payloadSize = readUint32(eofFound); + + assert(id > 0); + + readCheck(buf.data(), 1, 1, f); + const unsigned char qos = buf[0]; + + readCheck(buf.data(), 1, topicSize, f); + const std::string topic(buf.data(), topicSize); + + makeSureBufSize(payloadSize); + readCheck(buf.data(), 1, payloadSize, f); + const std::string payload(buf.data(), payloadSize); + + Publish pub(topic, payload, qos); + logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); + ses->qosPacketQueue.queuePacket(pub, id); + } + + const uint32_t nrOfIncomingPacketIds = readUint32(eofFound); + for (uint32_t i = 0; i < nrOfIncomingPacketIds; i++) + { + uint16_t id = readUint16(eofFound); + assert(id > 0); + logger->logf(LOG_DEBUG, "Loaded incomming QoS2 message id %d.", id); + ses->incomingQoS2MessageIds.insert(id); + } + + const uint32_t nrOfOutgoingPacketIds = readUint32(eofFound); + for (uint32_t i = 0; i < nrOfOutgoingPacketIds; i++) + { + uint16_t id = readUint16(eofFound); + assert(id > 0); + logger->logf(LOG_DEBUG, "Loaded outgoing QoS2 message id %d.", id); + ses->outgoingQoS2MessageIds.insert(id); + } + + const uint16_t nextPacketId = readUint16(eofFound); + logger->logf(LOG_DEBUG, "Loaded next packetid %d.", ses->nextPacketId); + ses->nextPacketId = nextPacketId; + + int64_t sessionAge = readInt64(eofFound); + logger->logf(LOG_DEBUG, "Loaded session age: %ld ms.", sessionAge); + ses->setSessionTouch(sessionAge); + } + + const uint32_t nrOfSubscriptions = readUint32(eofFound); + for (uint32_t i = 0; i < nrOfSubscriptions; i++) + { + const uint32_t topicLength = readUint32(eofFound); + readCheck(buf.data(), 1, topicLength, f); + const std::string topic(buf.data(), topicLength); + + logger->logf(LOG_DEBUG, "Loading subscriptions to topic '%s'.", topic.c_str()); + + const uint32_t nrOfClientIds = readUint32(eofFound); + + for (uint32_t i = 0; i < nrOfClientIds; i++) + { + const uint32_t clientIdLength = readUint32(eofFound); + readCheck(buf.data(), 1, clientIdLength, f); + const std::string clientId(buf.data(), clientIdLength); + + char qos; + readCheck(&qos, 1, 1, f); + + logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", clientId.c_str(), topic.c_str(), qos); + + SubscriptionForSerializing sub(std::move(clientId), qos); + result.subscriptions[topic].push_back(std::move(sub)); + } + + } + } + + return result; +} + +void SessionsAndSubscriptionsDB::writeRowHeader() +{ + +} + +void SessionsAndSubscriptionsDB::saveData(const std::list> &sessions, const std::unordered_map> &subscriptions) +{ + if (!f) + return; + + char reserved[RESERVED_SPACE_SESSIONS_DB_V1]; + std::memset(reserved, 0, RESERVED_SPACE_SESSIONS_DB_V1); + + const int64_t start_stamp = Session::getProgramStartedAtUnixTimestamp(); + logger->logf(LOG_DEBUG, "Saving program first start time stamp as %ld", start_stamp); + writeInt64(start_stamp); + + writeUint32(sessions.size()); + + for (const std::unique_ptr &ses : sessions) + { + logger->logf(LOG_DEBUG, "Saving session '%s'.", ses->getClientId().c_str()); + + writeRowHeader(); + + writeCheck(reserved, 1, RESERVED_SPACE_SESSIONS_DB_V1, f); + + writeUint32(ses->username.length()); + writeCheck(ses->username.c_str(), 1, ses->username.length(), f); + + writeUint32(ses->client_id.length()); + writeCheck(ses->client_id.c_str(), 1, ses->client_id.length(), f); + + const size_t qosPacketsExpected = ses->qosPacketQueue.size(); + size_t qosPacketsCounted = 0; + writeUint32(qosPacketsExpected); + + for (const std::shared_ptr &p: ses->qosPacketQueue) + { + logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", p->getQos(), p->getTopic().c_str()); + + qosPacketsCounted++; + + writeUint16(p->getPacketId()); + + writeUint32(p->getTopic().length()); + std::string payload = p->getPayloadCopy(); + writeUint32(payload.size()); + + const char qos = p->getQos(); + writeCheck(&qos, 1, 1, f); + + writeCheck(p->getTopic().c_str(), 1, p->getTopic().length(), f); + writeCheck(payload.c_str(), 1, payload.length(), f); + } + + assert(qosPacketsExpected == qosPacketsCounted); + + writeUint32(ses->incomingQoS2MessageIds.size()); + for (uint16_t id : ses->incomingQoS2MessageIds) + { + logger->logf(LOG_DEBUG, "Writing incomming QoS2 message id %d.", id); + writeUint16(id); + } + + writeUint32(ses->outgoingQoS2MessageIds.size()); + for (uint16_t id : ses->outgoingQoS2MessageIds) + { + logger->logf(LOG_DEBUG, "Writing outgoing QoS2 message id %d.", id); + writeUint16(id); + } + + logger->logf(LOG_DEBUG, "Writing next packetid %d.", ses->nextPacketId); + writeUint16(ses->nextPacketId); + + const int64_t sInMs = ses->getSessionRelativeAgeInMs(); + logger->logf(LOG_DEBUG, "Writing session age: %ld ms.", sInMs); + writeInt64(sInMs); + } + + writeUint32(subscriptions.size()); + + for (auto &pair : subscriptions) + { + const std::string &topic = pair.first; + const std::list &subscriptions = pair.second; + + logger->logf(LOG_DEBUG, "Writing subscriptions to topic '%s'.", topic.c_str()); + + writeUint32(topic.size()); + writeCheck(topic.c_str(), 1, topic.size(), f); + + writeUint32(subscriptions.size()); + + for (const SubscriptionForSerializing &subscription : subscriptions) + { + logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", subscription.clientId.c_str(), topic.c_str(), subscription.qos); + + writeUint32(subscription.clientId.size()); + writeCheck(subscription.clientId.c_str(), 1, subscription.clientId.size(), f); + writeCheck(&subscription.qos, 1, 1, f); + } + } + + fflush(f); +} + +SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readData() +{ + SessionsAndSubscriptionsResult defaultResult; + + if (!f) + return defaultResult; + + if (readVersion == ReadVersion::v1) + return readDataV1(); + + return defaultResult; +} diff --git a/sessionsandsubscriptionsdb.h b/sessionsandsubscriptionsdb.h new file mode 100644 index 0000000..15a242e --- /dev/null +++ b/sessionsandsubscriptionsdb.h @@ -0,0 +1,71 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#ifndef SESSIONSANDSUBSCRIPTIONSDB_H +#define SESSIONSANDSUBSCRIPTIONSDB_H + +#include +#include + +#include "persistencefile.h" +#include "session.h" + +#define MAGIC_STRING_SESSION_FILE_V1 "FlashMQRetainedDBv1" +#define RESERVED_SPACE_SESSIONS_DB_V1 32 + +/** + * @brief The SubscriptionForSerializing struct contains the fields we're interested in when saving a subscription. + */ +struct SubscriptionForSerializing +{ + const std::string clientId; + const char qos = 0; + + SubscriptionForSerializing(const std::string &clientId, char qos); + SubscriptionForSerializing(const std::string &&clientId, char qos); +}; + +struct SessionsAndSubscriptionsResult +{ + std::list> sessions; + std::unordered_map> subscriptions; +}; + + +class SessionsAndSubscriptionsDB : public PersistenceFile +{ + enum class ReadVersion + { + unknown, + v1 + }; + + ReadVersion readVersion = ReadVersion::unknown; + + SessionsAndSubscriptionsResult readDataV1(); + void writeRowHeader(); +public: + SessionsAndSubscriptionsDB(const std::string &filePath); + + void openWrite(); + void openRead(); + + void saveData(const std::list> &sessions, const std::unordered_map> &subscriptions); + SessionsAndSubscriptionsResult readData(); +}; + +#endif // SESSIONSANDSUBSCRIPTIONSDB_H diff --git a/settings.cpp b/settings.cpp index 49b175e..629ea52 100644 --- a/settings.cpp +++ b/settings.cpp @@ -16,7 +16,7 @@ License along with FlashMQ. If not, see . */ #include "settings.h" - +#include "utils.h" AuthOptCompatWrap &Settings::getAuthOptsCompat() { @@ -27,3 +27,21 @@ std::unordered_map &Settings::getFlashmqAuthPluginOpts { return this->flashmqAuthPluginOpts; } + +std::string Settings::getRetainedMessagesDBFile() const +{ + if (storageDir.empty()) + return ""; + + std::string path = formatString("%s/%s", storageDir.c_str(), "retained.db"); + return path; +} + +std::string Settings::getSessionsDBFile() const +{ + if (storageDir.empty()) + return ""; + + std::string path = formatString("%s/%s", storageDir.c_str(), "sessions.db"); + return path; +} diff --git a/settings.h b/settings.h index d79e486..ecc627a 100644 --- a/settings.h +++ b/settings.h @@ -53,10 +53,14 @@ public: int rlimitNoFile = 1000000; uint64_t expireSessionsAfterSeconds = 1209600; int authPluginTimerPeriod = 60; + std::string storageDir; std::list> listeners; // Default one is created later, when none are defined. AuthOptCompatWrap &getAuthOptsCompat(); std::unordered_map &getFlashmqAuthPluginOpts(); + + std::string getRetainedMessagesDBFile() const; + std::string getSessionsDBFile() const; }; #endif // SETTINGS_H diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index e216228..6628ebb 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -20,7 +20,7 @@ License along with FlashMQ. If not, see . #include "cassert" #include "rwlockguard.h" - +#include "retainedmessagesdb.h" SubscriptionNode::SubscriptionNode(const std::string &subtopic) : subtopic(subtopic) @@ -33,6 +33,11 @@ std::vector &SubscriptionNode::getSubscribers() return subscribers; } +const std::string &SubscriptionNode::getSubtopic() const +{ + return subtopic; +} + void SubscriptionNode::addSubscriber(const std::shared_ptr &subscriber, char qos) { Subscription sub; @@ -84,15 +89,20 @@ SubscriptionStore::SubscriptionStore() : } -void SubscriptionStore::addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos) +/** + * @brief SubscriptionStore::getDeepestNode gets the node in the tree walking the path of 'the/subscription/topic/path', making new nodes as required. + * @param topic + * @param subtopics + * @return + * + * caller is responsible for locking. + */ +SubscriptionNode *SubscriptionStore::getDeepestNode(const std::string &topic, const std::vector &subtopics) { SubscriptionNode *deepestNode = &root; if (topic.length() > 0 && topic[0] == '$') deepestNode = &rootDollar; - RWLockGuard lock_guard(&subscriptionsRwlock); - lock_guard.wrlock(); - for(const std::string &subtopic : subtopics) { std::unique_ptr *selectedChildren = nullptr; @@ -114,6 +124,15 @@ void SubscriptionStore::addSubscription(std::shared_ptr &client, const s } assert(deepestNode); + return deepestNode; +} + +void SubscriptionStore::addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos) +{ + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.wrlock(); + + SubscriptionNode *deepestNode = getDeepestNode(topic, subtopics); if (deepestNode) { @@ -494,6 +513,183 @@ int64_t SubscriptionStore::getRetainedMessageCount() const return retainedMessageCount; } +void SubscriptionStore::getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const +{ + for(const RetainedMessage &rm : this_node->retainedMessages) + { + outputList.push_back(rm); + } + + for(auto &pair : this_node->children) + { + const std::unique_ptr &child = pair.second; + getRetainedMessages(child.get(), outputList); + } +} + +/** + * @brief SubscriptionStore::getSubscriptions + * @param this_node + * @param composedTopic + * @param root bool. Every subtopic is concatenated with a '/', but not the first topic to 'root'. The root is a bit weird, virtual, so it needs different treatment. + * @param outputList + */ +void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, + std::unordered_map> &outputList) const +{ + for (const Subscription &node : this_node->getSubscribers()) + { + if (!node.sessionGone()) + { + SubscriptionForSerializing sub(node.session.lock()->getClientId(), node.qos); + outputList[composedTopic].push_back(sub); + } + } + + for (auto &pair : this_node->children) + { + SubscriptionNode *node = pair.second.get(); + const std::string topicAtNextLevel = root ? pair.first : composedTopic + "/" + pair.first; + getSubscriptions(node, topicAtNextLevel, false, outputList); + } + + if (this_node->childrenPlus) + { + const std::string topicAtNextLevel = root ? "+" : composedTopic + "/+"; + getSubscriptions(this_node->childrenPlus.get(), topicAtNextLevel, false, outputList); + } + + if (this_node->childrenPound) + { + const std::string topicAtNextLevel = root ? "#" : composedTopic + "/#"; + getSubscriptions(this_node->childrenPound.get(), topicAtNextLevel, false, outputList); + } +} + +void SubscriptionStore::saveRetainedMessages(const std::string &filePath) +{ + logger->logf(LOG_INFO, "Saving retained messages to '%s'", filePath.c_str()); + + std::vector result; + result.reserve(retainedMessageCount); + + // Create the list of messages under lock, and unlock right after. + RWLockGuard locker(&retainedMessagesRwlock); + locker.rdlock(); + getRetainedMessages(&retainedMessagesRoot, result); + locker.unlock(); + + logger->logf(LOG_DEBUG, "Collected %ld retained messages to save.", result.size()); + + // Then do the IO without locking the threads. + RetainedMessagesDB db(filePath); + db.openWrite(); + db.saveData(result); +} + +void SubscriptionStore::loadRetainedMessages(const std::string &filePath) +{ + try + { + logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str()); + + RetainedMessagesDB db(filePath); + db.openRead(); + std::list messages = db.readData(); + + RWLockGuard locker(&retainedMessagesRwlock); + locker.wrlock(); + + std::vector subtopics; + for (const RetainedMessage &rm : messages) + { + splitTopic(rm.topic, subtopics); + setRetainedMessage(rm.topic, subtopics, rm.payload, rm.qos); + } + } + catch (PersistenceFileCantBeOpened &ex) + { + logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str()); + } +} + +void SubscriptionStore::saveSessionsAndSubscriptions(const std::string &filePath) +{ + logger->logf(LOG_INFO, "Saving sessions and subscriptions to '%s'", filePath.c_str()); + + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.wrlock(); + + // First copy the sessions... + + std::list> sessionCopies; + + for (const auto &pair : sessionsByIdConst) + { + const Session &org = *pair.second.get(); + sessionCopies.push_back(org.getCopy()); + } + + std::unordered_map> subscriptionCopies; + getSubscriptions(&root, "", true, subscriptionCopies); + + lock_guard.unlock(); + + // Then write the copies to disk, after having released the lock + + logger->logf(LOG_DEBUG, "Collected %ld sessions and %ld subscriptions to save.", sessionCopies.size(), subscriptionCopies.size()); + + SessionsAndSubscriptionsDB db(filePath); + db.openWrite(); + db.saveData(sessionCopies, subscriptionCopies); +} + +void SubscriptionStore::loadSessionsAndSubscriptions(const std::string &filePath) +{ + try + { + logger->logf(LOG_INFO, "Loading '%s'", filePath.c_str()); + + SessionsAndSubscriptionsDB db(filePath); + db.openRead(); + SessionsAndSubscriptionsResult loadedData = db.readData(); + + RWLockGuard locker(&subscriptionsRwlock); + locker.wrlock(); + + for (std::shared_ptr &session : loadedData.sessions) + { + sessionsById[session->getClientId()] = session; + } + + std::vector subtopics; + + for (auto &pair : loadedData.subscriptions) + { + const std::string &topic = pair.first; + const std::list &subs = pair.second; + + for (const SubscriptionForSerializing &sub : subs) + { + splitTopic(topic, subtopics); + SubscriptionNode *subscriptionNode = getDeepestNode(topic, subtopics); + + auto session_it = sessionsByIdConst.find(sub.clientId); + if (session_it != sessionsByIdConst.end()) + { + const std::shared_ptr &ses = session_it->second; + subscriptionNode->addSubscriber(ses, sub.qos); + } + + } + } + } + catch (PersistenceFileCantBeOpened &ex) + { + logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str()); + } +} + // QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The // specs don't specify what to do there. bool Subscription::operator==(const Subscription &rhs) const diff --git a/subscriptionstore.h b/subscriptionstore.h index 35416b4..93e3dc6 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -53,6 +53,7 @@ public: SubscriptionNode(SubscriptionNode &&node) = delete; std::vector &getSubscribers(); + const std::string &getSubtopic() const; void addSubscriber(const std::shared_ptr &subscriber, char qos); void removeSubscriber(const std::shared_ptr &subscriber); std::unordered_map> children; @@ -77,6 +78,10 @@ class RetainedMessageNode class SubscriptionStore { +#ifdef TESTING + friend class MainTests; +#endif + SubscriptionNode root; SubscriptionNode rootDollar; pthread_rwlock_t subscriptionsRwlock = PTHREAD_RWLOCK_INITIALIZER; @@ -93,7 +98,11 @@ class SubscriptionStore void publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers, uint64_t &count) const; void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const; + void getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList) const; + void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, + std::unordered_map> &outputList) const; + SubscriptionNode *getDeepestNode(const std::string &topic, const std::vector &subtopics); public: SubscriptionStore(); @@ -103,7 +112,6 @@ public: bool sessionPresent(const std::string &clientid); void queuePacketAtSubscribers(const std::vector &subtopics, const MqttPacket &packet, bool dollar = false); - uint64_t giveClientRetainedMessages(const std::shared_ptr &ses, const std::string &subscribe_topic, char max_qos); void giveClientRetainedMessagesRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, RetainedMessageNode *this_node, char max_qos, const std::shared_ptr &ses, bool poundMode, uint64_t &count) const; @@ -114,6 +122,12 @@ public: void removeExpiredSessionsClients(int expireSessionsAfterSeconds); int64_t getRetainedMessageCount() const; + + void saveRetainedMessages(const std::string &filePath); + void loadRetainedMessages(const std::string &filePath); + + void saveSessionsAndSubscriptions(const std::string &filePath); + void loadSessionsAndSubscriptions(const std::string &filePath); }; #endif // SUBSCRIPTIONSTORE_H diff --git a/utils.cpp b/utils.cpp index 79f78fc..575d328 100644 --- a/utils.cpp +++ b/utils.cpp @@ -289,6 +289,14 @@ void trim(std::string &s) rtrim(s); } +std::string &rtrim(std::string &s, unsigned char c) +{ + s.erase(std::find_if(s.rbegin(), s.rend(), [=](unsigned char ch) { + return (c != ch); + }).base(), s.end()); + return s; +} + bool startsWith(const std::string &s, const std::string &needle) { return s.find(needle) == 0; diff --git a/utils.h b/utils.h index 0a83fa3..70c8352 100644 --- a/utils.h +++ b/utils.h @@ -28,6 +28,8 @@ License along with FlashMQ. If not, see . #include #include #include +#include "unistd.h" +#include "sys/stat.h" #include "cirbuf.h" #include "bindaddr.h" @@ -64,6 +66,7 @@ void ltrim(std::string &s); void rtrim(std::string &s); void trim(std::string &s); bool startsWith(const std::string &s, const std::string &needle); +std::string &rtrim(std::string &s, unsigned char c); std::string getSecureRandomString(const ssize_t len); std::string str_tolower(std::string s); @@ -92,5 +95,32 @@ ssize_t getFileSize(const std::string &path); std::string sockaddrToString(struct sockaddr *addr); +template void checkWritableDir(const std::string &path) +{ + if (path.empty()) + throw ex("Dir path to check is an empty string."); + + if (access(path.c_str(), W_OK) != 0) + { + std::string msg = formatString("Path '%s' is not there or not writable", path.c_str()); + throw ex(msg); + } + + struct stat statbuf; + memset(&statbuf, 0, sizeof(struct stat)); + if (stat(path.c_str(), &statbuf) < 0) + { + // We checked for W_OK above, so this shouldn't happen. + std::string msg = formatString("Error getting information about '%s'.", path.c_str()); + throw ex(msg); + } + + if (!S_ISDIR(statbuf.st_mode)) + { + std::string msg = formatString("Path '%s' is not a directory.", path.c_str()); + throw ex(msg); + } +} + #endif // UTILS_H