/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, version 3.

FlashMQ is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public
License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
*/

#include "mainapp.h"
#include "cassert"
#include "exceptions.h"
#include "getopt.h"

#include <unistd.h>
#include <stdio.h>
#include <sys/sysinfo.h>
#include <sys/resource.h>
#include <openssl/ssl.h>
#include <openssl/err.h>

#include <algorithm>
#include <cstring>
#include <fcntl.h>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <mutex>
#include <netinet/in.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <utility>
#include <vector>

#include "logger.h"

#define MAX_EVENTS 1024

MainApp *MainApp::instance = nullptr;

/**
 * @brief do_thread_work is the event loop of one worker thread.
 *
 * It first initialises the thread's auth plugin (on failure the thread stops and the whole app is asked to quit).
 * Then, while threadData->running is true, it waits on the thread's epoll fd with a 100 ms timeout and, per event,
 * either runs the tasks queued for this thread (signalled through threadData->taskEventFd) or does socket I/O for
 * the client owning the fd. MQTT packets parsed during one epoll round are collected in packetQueueIn and handled
 * after all events of that round have been processed.
 *
 * @param threadData the state of the thread this loop runs for.
 */
void do_thread_work(ThreadData *threadData)
{
    int epoll_fd = threadData->epollfd;

    struct epoll_event events[MAX_EVENTS];
    memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);

    std::vector<MqttPacket> packetQueueIn;

    Logger *logger = Logger::getInstance();

    try
    {
        logger->logf(LOG_NOTICE, "Thread %d doing auth init.", threadData->threadnr);
        threadData->initAuthPlugin();
    }
    catch(std::exception &ex)
    {
        logger->logf(LOG_ERR, "Error initializing auth back-end: %s", ex.what());
        // Without a working auth back-end this thread doesn't run its loop; shut the whole app down.
        threadData->running = false;
        MainApp *instance = MainApp::getMainApp();
        instance->quit();
    }

    while (threadData->running)
    {
        int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100);

        if (fdcount < 0)
        {
            if (errno == EINTR)
                continue;
            logger->logf(LOG_ERR, "Problem waiting for fd: %s", strerror(errno));
        }
        else if (fdcount > 0)
        {
            for (int i = 0; i < fdcount; i++)
            {
                struct epoll_event cur_ev = events[i];
                int fd = cur_ev.data.fd;

                // The task eventfd signals that work was queued for this thread.
                if (fd == threadData->taskEventFd)
                {
                    uint64_t eventfd_value = 0;
                    check<std::runtime_error>(read(fd, &eventfd_value, sizeof(uint64_t)));

                    // NOTE(review): the tasks run with taskQueueMutex held; a task that queues work for this same
                    // thread would self-deadlock. Confirm no task does that.
                    std::lock_guard<std::mutex> locker(threadData->taskQueueMutex);
                    for(auto &f : threadData->taskQueue)
                    {
                        f();
                    }
                    threadData->taskQueue.clear();

                    continue;
                }

                std::shared_ptr<Client> client = threadData->getClient(fd);

                if (client)
                {
                    try
                    {
                        if (cur_ev.events & (EPOLLERR | EPOLLHUP))
                        {
                            client->setDisconnectReason("epoll says socket is in ERR or HUP state.");
                            threadData->removeClient(client);
                            continue;
                        }

                        if (client->isSsl() && !client->isSslAccepted())
                        {
                            client->startOrContinueSslAccept();
                            continue;
                        }

                        // With SSL, a pending read may be waiting for writability (and a pending write for
                        // readability), so the opposite event can also trigger each branch.
                        if ((cur_ev.events & EPOLLIN) || ((cur_ev.events & EPOLLOUT) && client->getSslReadWantsWrite()))
                        {
                            bool readSuccess = client->readFdIntoBuffer();
                            // Parse whatever was read, even when the read also detected a disconnect.
                            client->bufferToMqttPackets(packetQueueIn, client);

                            if (!readSuccess)
                            {
                                client->setDisconnectReason("socket disconnect detected");
                                threadData->removeClient(client);
                                continue;
                            }
                        }

                        if ((cur_ev.events & EPOLLOUT) || ((cur_ev.events & EPOLLIN) && client->getSslWriteWantsRead()))
                        {
                            if (!client->writeBufIntoFd())
                            {
                                threadData->removeClient(client);
                                continue;
                            }

                            if (client->readyForDisconnecting())
                            {
                                threadData->removeClient(client);
                                continue;
                            }
                        }
                    }
                    catch(std::exception &ex)
                    {
                        client->setDisconnectReason(ex.what());
                        logger->logf(LOG_ERR, "Packet read/write error: %s. Removing client.", ex.what());
                        threadData->removeClient(client);
                    }
                }
            }
        }

        for (MqttPacket &packet : packetQueueIn)
        {
            try
            {
                packet.handle();
            }
            catch (std::exception &ex)
            {
                packet.getSender()->setDisconnectReason(ex.what());
                logger->logf(LOG_ERR, "MqttPacket handling error: %s. Removing client.", ex.what());
                threadData->removeClient(packet.getSender());
            }
        }
        packetQueueIn.clear();
    }
}

/**
 * @brief MainApp::MainApp creates the accept epoll fd and the task eventfd, loads the config and registers the
 * periodic timer callbacks (session expiry, keep-alive, password file reload, $SYS stats, auth plugin event).
 * @param configFilePath path given to the ConfigFileParser.
 * @throws std::runtime_error when no CPUs are detected, or when the epoll fd or task eventfd can't be created.
 */
MainApp::MainApp(const std::string &configFilePath) :
    subscriptionStore(new SubscriptionStore())
{
    this->num_threads = get_nprocs();

    if (num_threads <= 0)
        throw std::runtime_error("Invalid number of CPUs: " + std::to_string(num_threads));

    epollFdAccept = check<std::runtime_error>(epoll_create(999));

    // Checked like the epoll fd: an unchecked failure would only surface later, as an obscure epoll_ctl error in start().
    taskEventFd = check<std::runtime_error>(eventfd(0, EFD_NONBLOCK));

    confFileParser.reset(new ConfigFileParser(configFilePath));
    loadConfig();

    // TODO: override in conf possibility (the thread count is currently always the CPU count).
    logger->logf(LOG_NOTICE, "%d CPUs are detected, making as many threads.", num_threads);

    if (settings->expireSessionsAfterSeconds > 0)
    {
        auto f = std::bind(&MainApp::queueCleanup, this);

        // Check twice per expiry period, clamped to [10 minutes, 24 hours] (values in ms). Widen to 64 bits before
        // multiplying: seconds * 1000 * 2 overflows a 32-bit int for expiry times of a few weeks.
        const uint64_t expireMs = static_cast<uint64_t>(settings->expireSessionsAfterSeconds) * 1000;
        const uint64_t derivedSessionCheckInterval = std::max<uint64_t>(expireMs * 2, 600000);
        const uint64_t sessionCheckInterval = std::min<uint64_t>(derivedSessionCheckInterval, 86400000);
        timer.addCallback(f, sessionCheckInterval, "session expiration");
    }

    auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this);
    timer.addCallback(fKeepAlive, 30000, "keep-alive check");

    auto fPasswordFileReload = std::bind(&MainApp::queuePasswordFileReloadAllThreads, this);
    timer.addCallback(fPasswordFileReload, 2000, "Password file reload.");

    auto fPublishStats = std::bind(&MainApp::publishStatsOnDollarTopic, this);
    timer.addCallback(fPublishStats, 10000, "Publish stats on $SYS");
    publishStatsOnDollarTopic();

    if (settings->authPluginTimerPeriod > 0)
    {
        auto fAuthPluginPeriodicEvent = std::bind(&MainApp::queueAuthPluginPeriodicEventAllThreads, this);
        timer.addCallback(fAuthPluginPeriodicEvent, settings->authPluginTimerPeriod*1000, "Auth plugin periodic event.");
    }
}

/**
 * @brief MainApp::~MainApp closes the file descriptors the constructor created.
 */
MainApp::~MainApp()
{
    if (epollFdAccept > 0)
        close(epollFdAccept);

    if (taskEventFd > 0)
        close(taskEventFd);
}

/**
 * @brief MainApp::doHelp prints the command line usage to stdout.
 * @param arg the program name, as shown in the usage line.
 */
void MainApp::doHelp(const char *arg)
{
    puts("FlashMQ - the scalable light-weight MQTT broker");
    puts("");
    printf("Usage: %s [options]\n", arg);
    puts("");
    puts(" -h, --help Print help");
    puts(" -c, --config-file Configuration file. Default '/etc/flashmq/flashmq.conf'.");
    puts(" -t, --test-config Test configuration file.");
#ifndef NDEBUG
    puts(" -z, --fuzz-file For fuzzing, provides the bytes that would be sent by a client.");
    puts(" If the name contains 'web' it will activate websocket mode.");
    puts(" If the name also contains 'upgrade', it will assume the websocket");
    puts(" client is upgrade, and bypass the cryptographically secured websocket");
    puts(" handshake.");
#endif
    puts(" -V, --version Show version");
    puts(" -l, --license Show license");
}

/**
 * @brief MainApp::showLicense prints the version, copyright and license notice to stdout.
 */
void MainApp::showLicense()
{
    printf("FlashMQ Version %s\n", VERSION);
    puts("Copyright (C) 2021 Wiebe Cazemier.");
    // NOTE(review): the license URL and the author e-mail appear to be missing from these two literals; confirm.
    puts("License AGPLv3: GNU AGPL version 3. .");
    puts("");
    puts("Author: Wiebe Cazemier ");
}

/**
 * @brief MainApp::createListenSocket creates the non-blocking listening socket(s) of a listener and adds them to
 * the accept epoll fd.
 * @param listener the listener config; when its port is <= 0, no sockets are created.
 * @return one socket per address family the listener's protocol covers (IPv4, IPv6, or both for IPv46), or an
 *         empty list when creating any of them failed. Sockets already created at that point are released with the
 *         discarded partial result.
 */
std::list<ScopedSocket> MainApp::createListenSocket(const std::shared_ptr<Listener> &listener)
{
    std::list<ScopedSocket> result;

    if (listener->port <= 0)
        return result;

    for (ListenerProtocol p : std::list<ListenerProtocol>({ ListenerProtocol::IPv4, ListenerProtocol::IPv6}))
    {
        std::string pname = p == ListenerProtocol::IPv4 ? "IPv4" : "IPv6";
        int family = p == ListenerProtocol::IPv4 ? AF_INET : AF_INET6;

        if (!(listener->protocol == ListenerProtocol::IPv46 || listener->protocol == p))
            continue;

        try
        {
            logger->logf(LOG_NOTICE, "Creating %s %s listener on [%s]:%d", pname.c_str(), listener->getProtocolName().c_str(),
                         listener->getBindAddress(p).c_str(), listener->port);

            BindAddr bindAddr = getBindAddr(family, listener->getBindAddress(p), listener->port);

            // Take ownership immediately, so the fd is closed when any of the calls below throws.
            ScopedSocket listenSocket(check<std::runtime_error>(socket(family, SOCK_STREAM, 0)));
            const int listen_fd = listenSocket.socket;

            // SO_REUSEADDR and SO_REUSEPORT are separate option names, not bit flags, so each needs its own call.
            // OR-ing them into one call set a single, different option number (on Linux 2 | 15 == 15, i.e. only
            // SO_REUSEPORT). SO_REUSEPORT is kept so the options actually in effect don't change.
            // NOTE(review): SO_REUSEPORT may be what lets the IPv4 and IPv6 sockets of an IPv46 listener bind the
            // same port; confirm before removing it. (Maybe multiple accept threads later, with SO_REUSEPORT.)
            int optval = 1;
            check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)));
            check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)));

            int flags = check<std::runtime_error>(fcntl(listen_fd, F_GETFL));
            check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));

            check<std::runtime_error>(bind(listen_fd, bindAddr.p.get(), bindAddr.len));
            check<std::runtime_error>(listen(listen_fd, 1024));

            struct epoll_event ev;
            memset(&ev, 0, sizeof (struct epoll_event));
            ev.data.fd = listen_fd;
            ev.events = EPOLLIN;
            check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev));

            result.push_back(std::move(listenSocket));
        }
        catch (std::exception &ex)
        {
            logger->logf(LOG_ERR, "Creating %s %s listener on [%s]:%d failed: %s", pname.c_str(), listener->getProtocolName().c_str(),
                         listener->getBindAddress(p).c_str(), listener->port, ex.what());
            return std::list<ScopedSocket>();
        }
    }

    return result;
}

/**
 * @brief MainApp::wakeUpThread wakes the main loop in start() by writing to taskEventFd.
 */
void MainApp::wakeUpThread()
{
    uint64_t one = 1;
    check<std::runtime_error>(write(taskEventFd, &one, sizeof(uint64_t)));
}

/**
 * @brief MainApp::queueKeepAliveCheckAtAllThreads asks every worker thread to check its clients' keep-alive.
 */
void MainApp::queueKeepAliveCheckAtAllThreads()
{
    for (std::shared_ptr<ThreadData> &thread : threads)
    {
        thread->queueDoKeepAliveCheck();
    }
}

/**
 * @brief MainApp::queuePasswordFileReloadAllThreads asks every worker thread to reload the password file.
 */
void MainApp::queuePasswordFileReloadAllThreads()
{
    for (std::shared_ptr<ThreadData> &thread : threads)
    {
        thread->queuePasswdFileReload();
    }
}

/**
 * @brief MainApp::queueAuthPluginPeriodicEventAllThreads asks every worker thread to fire its auth plugin's periodic event.
 */
void MainApp::queueAuthPluginPeriodicEventAllThreads()
{
    for (std::shared_ptr<ThreadData> &thread : threads)
    {
        thread->queueAuthPluginPeriodicEvent();
    }
}

/**
 * @brief MainApp::setFuzzFile sets the input file for fuzz mode; only used by start() in debug builds.
 */
void MainApp::setFuzzFile(const std::string &fuzzFilePath)
{
    this->fuzzFilePath = fuzzFilePath;
}

/**
 * @brief MainApp::publishStatsOnDollarTopic sums the per-thread counters and publishes them, and the retained
 * message count, on $SYS topics.
 */
void MainApp::publishStatsOnDollarTopic()
{
    uint nrOfClients = 0;
    uint64_t receivedMessageCountPerSecond = 0;
    uint64_t receivedMessageCount = 0;
    uint64_t sentMessageCountPerSecond = 0;
    uint64_t sentMessageCount = 0;

    for (std::shared_ptr<ThreadData> &thread : threads)
    {
        nrOfClients += thread->getNrOfClients();

        receivedMessageCountPerSecond += thread->getReceivedMessagePerSecond();
        receivedMessageCount += thread->getReceivedMessageCount();

        sentMessageCountPerSecond += thread->getSentMessagePerSecond();
        sentMessageCount += thread->getSentMessageCount();
    }

    publishStat("$SYS/broker/clients/total", nrOfClients);

    publishStat("$SYS/broker/load/messages/received/total", receivedMessageCount);
    publishStat("$SYS/broker/load/messages/received/persecond", receivedMessageCountPerSecond);

    publishStat("$SYS/broker/load/messages/sent/total", sentMessageCount);
    publishStat("$SYS/broker/load/messages/sent/persecond", sentMessageCountPerSecond);

    publishStat("$SYS/broker/retained messages/count", subscriptionStore->getRetainedMessageCount());
}

/**
 * @brief MainApp::publishStat sends a number to the subscribers of a topic and stores it as that topic's retained message.
 * @param topic the topic to publish on.
 * @param n the value; published as its decimal string.
 */
void MainApp::publishStat(const std::string &topic, uint64_t n)
{
    std::vector<std::string> subtopics;
    splitTopic(topic, subtopics);
    const std::string payload = std::to_string(n);
    Publish p(topic, payload, 0);
    subscriptionStore->queuePacketAtSubscribers(subtopics, p, true);
    subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0);
}

/**
 * @brief MainApp::initMainApp parses the command line and creates the singleton instance.
 *
 * --help, --version, --license, --test-config and unrecognised options print their output and exit the process.
 * Without -c, the default config file is used when it is readable.
 *
 * @throws std::runtime_error when called a second time.
 */
void MainApp::initMainApp(int argc, char *argv[])
{
    if (instance != nullptr)
        throw std::runtime_error("App was already initialized.");

    static struct option long_options[] =
    {
        {"help", no_argument, nullptr, 'h'},
        {"config-file", required_argument, nullptr, 'c'},
        {"test-config", no_argument, nullptr, 't'},
        {"fuzz-file", required_argument, nullptr, 'z'},
        {"version", no_argument, nullptr, 'V'},
        {"license", no_argument, nullptr, 'l'},
        {nullptr, 0, nullptr, 0}
    };

    const std::string defaultConfigFile = "/etc/flashmq/flashmq.conf";
    std::string configFile;

    if (access(defaultConfigFile.c_str(), R_OK) == 0)
    {
        configFile = defaultConfigFile;
    }

    std::string fuzzFile;

    int option_index = 0;
    int opt;
    bool testConfig = false;
    while((opt = getopt_long(argc, argv, "hc:Vltz:", long_options, &option_index)) != -1)
    {
        switch(opt)
        {
        case 'c':
            configFile = optarg;
            break;
        case 'l':
            MainApp::showLicense();
            exit(0);
        case 'V':
            MainApp::showLicense();
            exit(0);
        case 'z':
            fuzzFile = optarg;
            break;
        case 'h':
            MainApp::doHelp(argv[0]);
            exit(16);
        case 't':
            testConfig = true;
            break;
        case '?':
            MainApp::doHelp(argv[0]);
            exit(16);
        }
    }

    if (testConfig)
    {
        try
        {
            if (configFile.empty())
            {
                std::cerr << "No config specified (with -c) and the default " << defaultConfigFile << " not found." << std::endl << std::endl;
                MainApp::doHelp(argv[0]);
                exit(1);
            }

            ConfigFileParser c(configFile);
            c.loadFile(true);
            printf("Config '%s' OK\n", configFile.c_str());
            exit(0);
        }
        catch (ConfigFileException &ex)
        {
            std::cerr << ex.what() << std::endl;
            exit(1);
        }
    }

    instance = new MainApp(configFile);
    instance->setFuzzFile(fuzzFile);
}

/**
 * @brief MainApp::getMainApp returns the instance created by initMainApp().
 * @throws std::runtime_error when initMainApp() hasn't been called yet.
 */
MainApp *MainApp::getMainApp()
{
    if (!instance)
        throw std::runtime_error("You haven't initialized the app yet.");
    return instance;
}

/**
 * @brief MainApp::start runs the broker until quit() is called.
 *
 * Creates the listen sockets and starts one worker thread per num_threads. The main thread then accepts
 * connections, handing each new client to the worker threads in round-robin order. It also runs the tasks queued in
 * taskQueue (signalled through taskEventFd). On exit, all worker threads are asked to quit and are waited for.
 * In debug builds with a fuzz file set, one fuzz pass is run without worker threads and the function then returns.
 */
void MainApp::start()
{
#ifndef NDEBUG
    if (fuzzFilePath.empty())
    {
        oneInstanceLock.lock();
    }
#endif

    timer.start();

    std::map<int, std::shared_ptr<Listener>> listenerMap; // For finding listeners by fd.
    std::list<ScopedSocket> activeListenSockets; // For RAII/ownership

    for(std::shared_ptr<Listener> &listener : this->listeners)
    {
        std::list<ScopedSocket> scopedSockets = createListenSocket(listener);

        for (ScopedSocket &scopedSocket : scopedSockets)
        {
            if (scopedSocket.socket > 0)
                listenerMap[scopedSocket.socket] = listener;
            activeListenSockets.push_back(std::move(scopedSocket));
        }
    }

#ifdef NDEBUG
    logger->noLongerLogToStd();
#endif

    struct epoll_event ev;
    memset(&ev, 0, sizeof (struct epoll_event));
    ev.data.fd = taskEventFd;
    ev.events = EPOLLIN;
    check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, taskEventFd, &ev));

#ifndef NDEBUG
    // I fuzzed using afl-fuzz. You need to compile it with their compiler.
    if (!fuzzFilePath.empty())
    {
        // No threads for execution stability/determinism.
        num_threads = 0;

        settings->allowAnonymous = true;

        int fd = open(fuzzFilePath.c_str(), O_RDONLY);
        assert(fd > 0);

        int fdnull = open("/dev/null", O_RDWR);
        assert(fdnull > 0);

        int fdnull2 = open("/dev/null", O_RDWR);
        assert(fdnull2 > 0);

        const std::string fuzzFilePathLower = str_tolower(fuzzFilePath);
        bool fuzzWebsockets = strContains(fuzzFilePathLower, "web");

        try
        {
            std::vector<MqttPacket> packetQueueIn;
            std::vector<std::string> subtopics;

            std::shared_ptr<ThreadData> threaddata(new ThreadData(0, subscriptionStore, settings));

            std::shared_ptr<Client> client(new Client(fd, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true));
            std::shared_ptr<Client> subscriber(new Client(fdnull, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true));
            subscriber->setClientProperties(ProtocolVersion::Mqtt311, "subscriber", "subuser", true, 60, true);
            subscriber->setAuthenticated(true);

            std::shared_ptr<Client> websocketsubscriber(new Client(fdnull2, threaddata, nullptr, true, nullptr, settings, true));
            websocketsubscriber->setClientProperties(ProtocolVersion::Mqtt311, "websocketsubscriber", "websocksubuser", true, 60, true);
            websocketsubscriber->setAuthenticated(true);
            websocketsubscriber->setFakeUpgraded();
            subscriptionStore->registerClientAndKickExistingOne(websocketsubscriber);
            splitTopic("#", subtopics);
            subscriptionStore->addSubscription(websocketsubscriber, "#", subtopics, 0);

            subscriptionStore->registerClientAndKickExistingOne(subscriber);
            subscriptionStore->addSubscription(subscriber, "#", subtopics, 0);

            if (fuzzWebsockets && strContains(fuzzFilePathLower, "upgrade"))
            {
                client->setFakeUpgraded();
                subscriber->setFakeUpgraded();
            }

            client->readFdIntoBuffer();
            client->bufferToMqttPackets(packetQueueIn, client);

            for (MqttPacket &packet : packetQueueIn)
            {
                packet.handle();
            }

            subscriber->writeBufIntoFd();
            websocketsubscriber->writeBufIntoFd();
        }
        catch (ProtocolError &ex)
        {
            logger->logf(LOG_ERR, "Expected MqttPacket handling error: %s", ex.what());
        }

        running = false;
    }
#endif

    for (int i = 0; i < num_threads; i++)
    {
        std::shared_ptr<ThreadData> t(new ThreadData(i, subscriptionStore, settings));
        t->start(&do_thread_work);
        threads.push_back(t);
    }

    uint next_thread_index = 0;

    struct epoll_event events[MAX_EVENTS];
    memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);

    started = true;
    while (running)
    {
        int num_fds = epoll_wait(this->epollFdAccept, events, MAX_EVENTS, 100);

        if (num_fds < 0)
        {
            if (errno == EINTR)
                continue;
            logger->logf(LOG_ERR, "Waiting for listening socket error: %s", strerror(errno));
        }

        for (int i = 0; i < num_fds; i++)
        {
            int cur_fd = events[i].data.fd;
            try
            {
                if (cur_fd != taskEventFd)
                {
                    std::shared_ptr<Listener> listener = listenerMap[cur_fd];
                    std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % num_threads];

                    logger->logf(LOG_INFO, "Accepting connection on thread %d on %s", thread_data->threadnr, listener->getProtocolName().c_str());

                    // sockaddr_in6 is large enough to hold either an IPv4 or an IPv6 peer address.
                    struct sockaddr_in6 addrBiggest;
                    struct sockaddr *addr = reinterpret_cast<struct sockaddr*>(&addrBiggest);
                    socklen_t len = sizeof(struct sockaddr_in6);
                    memset(addr, 0, len);
                    int fd = check<std::runtime_error>(accept(cur_fd, addr, &len));

                    SSL *clientSSL = nullptr;
                    if (listener->isSsl())
                    {
                        clientSSL = SSL_new(listener->sslctx->get());

                        if (clientSSL == NULL)
                        {
                            logger->logf(LOG_ERR, "Problem creating SSL object. Closing client.");
                            close(fd);
                            continue;
                        }

                        SSL_set_fd(clientSSL, fd);
                    }

                    std::shared_ptr<Client> client(new Client(fd, thread_data, clientSSL, listener->websocket, addr, settings));
                    thread_data->giveClient(client);
                }
                else
                {
                    uint64_t eventfd_value = 0;
                    check<std::runtime_error>(read(cur_fd, &eventfd_value, sizeof(uint64_t)));
                    std::lock_guard<std::mutex> locker(eventMutex);
                    for(auto &f : taskQueue)
                    {
                        f();
                    }
                    taskQueue.clear();
                }
            }
            catch (std::exception &ex)
            {
                logger->logf(LOG_ERR, "Problem in main thread: %s", ex.what());
            }
        }
    }

    oneInstanceLock.unlock();

    for(std::shared_ptr<ThreadData> &thread : threads)
    {
        thread->queueQuit();
    }

    for(std::shared_ptr<ThreadData> &thread : threads)
    {
        thread->waitForQuit();
    }
}

/**
 * @brief MainApp::quit stops the timer and makes the main loop in start() exit. It can be called more than once and
 * from worker threads; later calls are no-ops (guarded by quitMutex and the running flag).
 */
void MainApp::quit()
{
    std::lock_guard<std::mutex> guard(quitMutex);
    if (!running)
        return;

    Logger *logger = Logger::getInstance();
    logger->logf(LOG_NOTICE, "Quitting FlashMQ");
    timer.stop();
    running = false;
}

/**
 * @brief MainApp::setlimits sets the soft and hard RLIMIT_NOFILE to the configured value. Failure is logged, not fatal.
 */
void MainApp::setlimits()
{
    rlim_t nofile = settings->rlimitNoFile;
    // rlim_t is an unsigned type whose width varies by platform, so print it through unsigned long long.
    logger->logf(LOG_INFO, "Setting rlimit nofile to %llu.", static_cast<unsigned long long>(nofile));
    struct rlimit v = { nofile, nofile };
    if (setrlimit(RLIMIT_NOFILE, &v) < 0)
    {
        logger->logf(LOG_ERR, "Setting ulimit nofile failed: '%s'. This means the default is used.", strerror(errno));
    }
}

/**
 * @brief MainApp::loadConfig is loaded on app start, where you want it to crash on errors, and from within
 * try/catch on reload, to allow the program to continue.
 *
 * It replaces settings, falls back to a default listener when none are configured, reconfigures the logger, applies
 * the rlimit, loads the listeners' certificates and queues the new settings to all worker threads.
 */
void MainApp::loadConfig()
{
    Logger *logger = Logger::getInstance();

    // Atomic loading: first a test pass (as --test-config does), and only then the real load.
    confFileParser->loadFile(true);
    confFileParser->loadFile(false);
    settings = std::move(confFileParser->settings);

    if (settings->listeners.empty())
    {
        std::shared_ptr<Listener> defaultListener(new Listener());
        defaultListener->isValid();
        settings->listeners.push_back(defaultListener);
    }

    // For now, it's too much work to be able to reload new listeners, with all the shared resource stuff going on.
    // So, the listeners are only copied into the 'listeners' member on the first load, and never updated after that.
    if (listeners.empty())
        listeners = settings->listeners;

    logger->setLogPath(settings->logPath);
    logger->reOpen();
    logger->setFlags(settings->logDebug, settings->logSubscriptions);

    setlimits();

    for (std::shared_ptr<Listener> &l : this->listeners)
    {
        l->loadCertAndKeyFromConfig();
    }

    for (std::shared_ptr<ThreadData> &thread : threads)
    {
        thread->queueReload(settings);
    }
}

/**
 * @brief MainApp::reloadConfig reloads the config; errors are logged instead of propagated, so the broker keeps running.
 */
void MainApp::reloadConfig()
{
    Logger *logger = Logger::getInstance();
    logger->logf(LOG_NOTICE, "Reloading config");

    try
    {
        loadConfig();
    }
    catch (std::exception &ex)
    {
        logger->logf(LOG_ERR, "Error reloading config: %s", ex.what());
    }
}

/**
 * @brief MainApp::queueConfigReload queues a config reload to be run by the main loop in start(), and wakes it.
 */
void MainApp::queueConfigReload()
{
    std::lock_guard<std::mutex> locker(eventMutex);

    auto f = std::bind(&MainApp::reloadConfig, this);
    taskQueue.push_front(f);

    wakeUpThread();
}

/**
 * @brief MainApp::queueCleanup queues removal of expired sessions to be run by the main loop in start(), and wakes it.
 * The expiry time is read from settings at the time of queueing.
 */
void MainApp::queueCleanup()
{
    std::lock_guard<std::mutex> locker(eventMutex);

    auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get(), settings->expireSessionsAfterSeconds);
    taskQueue.push_front(f);

    wakeUpThread();
}