diff --git a/CMakeLists.txt b/CMakeLists.txt index 95a5dcd..4b94e6e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,6 +57,7 @@ add_executable(FlashMQ sessionsandsubscriptionsdb.h qospacketqueue.h threadauth.h + threadloop.h mainapp.cpp main.cpp @@ -93,6 +94,7 @@ add_executable(FlashMQ sessionsandsubscriptionsdb.cpp qospacketqueue.cpp threadauth.cpp + threadloop.cpp ) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index aa068fe..c15cea7 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -48,6 +48,7 @@ SOURCES += tst_maintests.cpp \ ../sessionsandsubscriptionsdb.cpp \ ../qospacketqueue.cpp \ ../threadauth.cpp \ + ../threadloop.cpp \ mainappthread.cpp \ twoclienttestcontext.cpp @@ -88,6 +89,7 @@ HEADERS += \ ../sessionsandsubscriptionsdb.h \ ../qospacketqueue.h \ ../threadauth.h \ + ../threadloop.h \ mainappthread.h \ twoclienttestcontext.h diff --git a/mainapp.cpp b/mainapp.cpp index 900403e..05facff 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -29,153 +29,10 @@ License along with FlashMQ. If not, see . #include "logger.h" #include "threadauth.h" - -#define MAX_EVENTS 1024 +#include "threadloop.h" MainApp *MainApp::instance = nullptr; -void do_thread_work(ThreadData *threadData) -{ - int epoll_fd = threadData->epollfd; - ThreadAuth::assign(&threadData->authentication); - - struct epoll_event events[MAX_EVENTS]; - memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); - - std::vector packetQueueIn; - - Logger *logger = Logger::getInstance(); - - try - { - logger->logf(LOG_NOTICE, "Thread %d doing auth init.", threadData->threadnr); - threadData->initAuthPlugin(); - } - catch(std::exception &ex) - { - logger->logf(LOG_ERR, "Error initializing auth back-end: %s", ex.what()); - threadData->running = false; - MainApp *instance = MainApp::getMainApp(); - instance->quit(); - } - - while (threadData->running) - { - int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); - - if (fdcount < 0) - { - if (errno == EINTR) - continue; - logger->logf(LOG_ERR, "Problem waiting for fd: %s", strerror(errno)); - } - else if (fdcount > 0) - { - for (int i = 0; i < fdcount; i++) - { - struct epoll_event cur_ev = events[i]; - int fd = cur_ev.data.fd; - - if (fd == threadData->taskEventFd) - { - uint64_t eventfd_value = 0; - check(read(fd, &eventfd_value, sizeof(uint64_t))); - - std::lock_guard locker(threadData->taskQueueMutex); - for(auto &f : threadData->taskQueue) - { - f(); - } - threadData->taskQueue.clear(); - - continue; - } - - std::shared_ptr client = threadData->getClient(fd); - - if (client) - { - try - { - if (cur_ev.events & (EPOLLERR | EPOLLHUP)) - { - client->setDisconnectReason("epoll says socket is in ERR or HUP state."); - threadData->removeClient(client); - continue; - } - if (client->isSsl() && !client->isSslAccepted()) - { - client->startOrContinueSslAccept(); - continue; - } - if ((cur_ev.events & EPOLLIN) || ((cur_ev.events & EPOLLOUT) && client->getSslReadWantsWrite())) - { - bool readSuccess = client->readFdIntoBuffer(); - client->bufferToMqttPackets(packetQueueIn, client); - - if (!readSuccess) - { - client->setDisconnectReason("socket disconnect detected"); - threadData->removeClient(client); - continue; - } - } - if ((cur_ev.events & EPOLLOUT) || ((cur_ev.events & EPOLLIN) && client->getSslWriteWantsRead())) - { - if (!client->writeBufIntoFd()) - { - threadData->removeClient(client); - continue; - } - - if (client->readyForDisconnecting()) - { - threadData->removeClient(client); - continue; - } - } - } - catch(std::exception &ex) - { - client->setDisconnectReason(ex.what()); - logger->logf(LOG_ERR, "Packet read/write error: %s. Removing client.", ex.what()); - threadData->removeClient(client); - } - } - } - } - - for (MqttPacket &packet : packetQueueIn) - { - try - { - packet.handle(); - } - catch (std::exception &ex) - { - packet.getSender()->setDisconnectReason(ex.what()); - logger->logf(LOG_ERR, "MqttPacket handling error: %s. Removing client.", ex.what()); - threadData->removeClient(packet.getSender()); - } - } - packetQueueIn.clear(); - } - - try - { - logger->logf(LOG_NOTICE, "Thread %d doing auth cleanup.", threadData->threadnr); - threadData->cleanupAuthPlugin(); - } - catch(std::exception &ex) - { - logger->logf(LOG_ERR, "Error cleaning auth back-end: %s", ex.what()); - } - - threadData->finished = true; -} - - - MainApp::MainApp(const std::string &configFilePath) : subscriptionStore(new SubscriptionStore()) { diff --git a/threadloop.cpp b/threadloop.cpp new file mode 100644 index 0000000..a9b1903 --- /dev/null +++ b/threadloop.cpp @@ -0,0 +1,158 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#include "threadloop.h" + +void do_thread_work(ThreadData *threadData) +{ + int epoll_fd = threadData->epollfd; + ThreadAuth::assign(&threadData->authentication); + + struct epoll_event events[MAX_EVENTS]; + memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); + + std::vector packetQueueIn; + + Logger *logger = Logger::getInstance(); + + try + { + logger->logf(LOG_NOTICE, "Thread %d doing auth init.", threadData->threadnr); + threadData->initAuthPlugin(); + } + catch(std::exception &ex) + { + logger->logf(LOG_ERR, "Error initializing auth back-end: %s", ex.what()); + threadData->running = false; + MainApp *instance = MainApp::getMainApp(); + instance->quit(); + } + + while (threadData->running) + { + int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); + + if (fdcount < 0) + { + if (errno == EINTR) + continue; + logger->logf(LOG_ERR, "Problem waiting for fd: %s", strerror(errno)); + } + else if (fdcount > 0) + { + for (int i = 0; i < fdcount; i++) + { + struct epoll_event cur_ev = events[i]; + int fd = cur_ev.data.fd; + + if (fd == threadData->taskEventFd) + { + uint64_t eventfd_value = 0; + check(read(fd, &eventfd_value, sizeof(uint64_t))); + + std::lock_guard locker(threadData->taskQueueMutex); + for(auto &f : threadData->taskQueue) + { + f(); + } + threadData->taskQueue.clear(); + + continue; + } + + std::shared_ptr client = threadData->getClient(fd); + + if (client) + { + try + { + if (cur_ev.events & (EPOLLERR | EPOLLHUP)) + { + client->setDisconnectReason("epoll says socket is in ERR or HUP state."); + threadData->removeClient(client); + continue; + } + if (client->isSsl() && !client->isSslAccepted()) + { + client->startOrContinueSslAccept(); + continue; + } + if ((cur_ev.events & EPOLLIN) || ((cur_ev.events & EPOLLOUT) && client->getSslReadWantsWrite())) + { + bool readSuccess = client->readFdIntoBuffer(); + client->bufferToMqttPackets(packetQueueIn, client); + + if (!readSuccess) + { + client->setDisconnectReason("socket disconnect detected"); + threadData->removeClient(client); + continue; + } + } + if ((cur_ev.events & EPOLLOUT) || ((cur_ev.events & EPOLLIN) && client->getSslWriteWantsRead())) + { + if (!client->writeBufIntoFd()) + { + threadData->removeClient(client); + continue; + } + + if (client->readyForDisconnecting()) + { + threadData->removeClient(client); + continue; + } + } + } + catch(std::exception &ex) + { + client->setDisconnectReason(ex.what()); + logger->logf(LOG_ERR, "Packet read/write error: %s. Removing client.", ex.what()); + threadData->removeClient(client); + } + } + } + } + + for (MqttPacket &packet : packetQueueIn) + { + try + { + packet.handle(); + } + catch (std::exception &ex) + { + packet.getSender()->setDisconnectReason(ex.what()); + logger->logf(LOG_ERR, "MqttPacket handling error: %s. Removing client.", ex.what()); + threadData->removeClient(packet.getSender()); + } + } + packetQueueIn.clear(); + } + + try + { + logger->logf(LOG_NOTICE, "Thread %d doing auth cleanup.", threadData->threadnr); + threadData->cleanupAuthPlugin(); + } + catch(std::exception &ex) + { + logger->logf(LOG_ERR, "Error cleaning auth back-end: %s", ex.what()); + } + + threadData->finished = true; +} diff --git a/threadloop.h b/threadloop.h new file mode 100644 index 0000000..106c33a --- /dev/null +++ b/threadloop.h @@ -0,0 +1,29 @@ +/* +This file is part of FlashMQ (https://www.flashmq.org) +Copyright (C) 2021 Wiebe Cazemier + +FlashMQ is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, version 3. + +FlashMQ is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public +License along with FlashMQ. If not, see . +*/ + +#ifndef THREADLOOP_H +#define THREADLOOP_H + +#include "threaddata.h" +#include "threadauth.h" + +#define MAX_EVENTS 1024 + +void do_thread_work(ThreadData *threadData); + + +#endif // THREADLOOP_H