diff --git a/CMakeLists.txt b/CMakeLists.txt index 444838a..fe5ba53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) add_executable(FlashMQ + mainapp.cpp main.cpp utils.cpp threaddata.cpp diff --git a/main.cpp b/main.cpp index 423454d..6c122aa 100644 --- a/main.cpp +++ b/main.cpp @@ -1,172 +1,70 @@ #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include "mainapp.h" -#include "utils.h" -#include "threaddata.h" -#include "client.h" -#include "mqttpacket.h" -#include "subscriptionstore.h" +static MainApp mainApp; -#define MAX_EVENTS 1024 -#define NR_OF_THREADS 4 - - - -void do_thread_work(ThreadData *threadData) +static void signal_handler(int signal) { - int epoll_fd = threadData->epollfd; - - struct epoll_event events[MAX_EVENTS]; - memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); - - std::vector packetQueueIn; - - uint64_t eventfd_value = 0; - - while (1) + if (signal == SIGPIPE) + { + return; + } + if (signal == SIGHUP) { - if (eventfd_value > 0) - { - for (Client_p client : threadData->getReadyForDequeueing()) - { - //client->queuedMessagesToBuffer(); - client->writeBufIntoFd(); - } - threadData->clearReadyForDequeueing(); - eventfd_value = 0; - } - - // TODO: do all the buftofd here, not spread out over - - int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); - - if (fdcount > 0) - { - for (int i = 0; i < fdcount; i++) - { - struct epoll_event cur_ev = events[i]; - int fd = cur_ev.data.fd; - - // If this thread was actively woken up. - if (fd == threadData->event_fd) - { - read(fd, &eventfd_value, sizeof(uint64_t)); - continue; - } - - Client_p client = threadData->getClient(fd); - - if (client) - { - if (cur_ev.events & EPOLLIN) - { - bool readSuccess = client->readFdIntoBuffer(); - client->bufferToMqttPackets(packetQueueIn, client); - - if (!readSuccess) - { - std::cout << "Disconnect: " << client->repr() << std::endl; - threadData->removeClient(client); - } - } - if (cur_ev.events & EPOLLOUT) - { - if (!client->writeBufIntoFd()) - threadData->removeClient(client); - } - } - } - } - for (MqttPacket &packet : packetQueueIn) - { - packet.handle(threadData->getSubscriptionStore()); - } - packetQueueIn.clear(); + } + else if (signal == SIGTERM || signal == SIGINT) + { + mainApp.quit(); + } + else + { + std::cerr << "Received signal " << signal << std::endl; } } -int main() +int register_signal_handers() { - int listen_fd = socket(AF_INET, SOCK_STREAM, 0); - - // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. - //int optval = 1; - //check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); + struct sigaction sa; + memset(&sa, 0, sizeof (struct sigaction)); + sa.sa_handler = &signal_handler; + sigemptyset(&sa.sa_mask); + sa.sa_flags = SA_RESTART; - int flags = fcntl(listen_fd, F_GETFL); - check(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); - - struct sockaddr_in in_addr; - in_addr.sin_family = AF_INET; - in_addr.sin_addr.s_addr = INADDR_ANY; - in_addr.sin_port = htons(1883); - - check(bind(listen_fd, (struct sockaddr *)(&in_addr), sizeof(struct sockaddr_in))); - check(listen(listen_fd, 1024)); - - int epoll_fd_accept = check(epoll_create(999)); - - struct epoll_event events[MAX_EVENTS]; - struct epoll_event ev; - memset(&ev, 0, sizeof (struct epoll_event)); - memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); - - ev.data.fd = listen_fd; - ev.events = EPOLLIN; - check(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, listen_fd, &ev)); - - std::shared_ptr subscriptionStore(new SubscriptionStore()); - - std::vector> threads; - - for (int i = 0; i < NR_OF_THREADS; i++) + if (sigaction(SIGHUP, &sa, nullptr) != 0 || sigaction(SIGTERM, &sa, nullptr) != 0 || sigaction(SIGINT, &sa, nullptr) != 0) { - std::shared_ptr t(new ThreadData(i, subscriptionStore)); - std::thread thread(do_thread_work, t.get()); - t->thread = std::move(thread); - threads.push_back(t); + std::cerr << "Error registering signals" << std::endl; + return 1; } - std::cout << "Listening..." << std::endl; - - uint next_thread_index = 0; + sigset_t set; + sigemptyset(&set); + sigaddset(&set,SIGPIPE); - while (1) + int r; + if ((r = sigprocmask(SIG_BLOCK, &set, NULL) != 0)) { - int num_fds = epoll_wait(epoll_fd_accept, events, MAX_EVENTS, 100); - - for (int i = 0; i < num_fds; i++) - { - int cur_fd = events[i].data.fd; - if (cur_fd == listen_fd) - { - std::shared_ptr thread_data = threads[next_thread_index++ % NR_OF_THREADS]; - - std::cout << "Accepting connection on thread " << thread_data->threadnr << std::endl; - - struct sockaddr addr; - memset(&addr, 0, sizeof(struct sockaddr)); - socklen_t len = sizeof(struct sockaddr); - int fd = check(accept(cur_fd, &addr, &len)); - - Client_p client(new Client(fd, thread_data)); - thread_data->giveClient(client); - } - else - { - throw std::runtime_error("The main thread had activity on an accepted socket?"); - } + return r; + } - } + return 0; +} +int main() +{ + try + { + check(register_signal_handers()); + mainApp.start(); + } + catch (std::exception &ex) + { + std::cerr << ex.what() << std::endl; + return 1; } return 0; diff --git a/mainapp.cpp b/mainapp.cpp new file mode 100644 index 0000000..cff434c --- /dev/null +++ b/mainapp.cpp @@ -0,0 +1,168 @@ +#include "mainapp.h" + +#define MAX_EVENTS 1024 +#define NR_OF_THREADS 4 + +void do_thread_work(ThreadData *threadData) +{ + int epoll_fd = threadData->epollfd; + + struct epoll_event events[MAX_EVENTS]; + memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); + + std::vector packetQueueIn; + + uint64_t eventfd_value = 0; + + while (threadData->running) + { + if (eventfd_value > 0) + { + for (Client_p client : threadData->getReadyForDequeueing()) + { + //client->queuedMessagesToBuffer(); + client->writeBufIntoFd(); + } + threadData->clearReadyForDequeueing(); + eventfd_value = 0; + } + + // TODO: do all the buftofd here, not spread out over + + int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); + + if (fdcount > 0) + { + for (int i = 0; i < fdcount; i++) + { + struct epoll_event cur_ev = events[i]; + int fd = cur_ev.data.fd; + + // If this thread was actively woken up. + if (fd == threadData->event_fd) + { + read(fd, &eventfd_value, sizeof(uint64_t)); + continue; + } + + Client_p client = threadData->getClient(fd); + + if (client) + { + if (cur_ev.events & EPOLLIN) + { + bool readSuccess = client->readFdIntoBuffer(); + client->bufferToMqttPackets(packetQueueIn, client); + + if (!readSuccess) + { + std::cout << "Disconnect: " << client->repr() << std::endl; + threadData->removeClient(client); + } + } + if (cur_ev.events & EPOLLOUT) + { + if (!client->writeBufIntoFd()) + threadData->removeClient(client); + } + } + } + } + + for (MqttPacket &packet : packetQueueIn) + { + packet.handle(threadData->getSubscriptionStore()); + } + packetQueueIn.clear(); + } +} + +MainApp::MainApp() : + subscriptionStore(new SubscriptionStore()) +{ + +} + +void MainApp::start() +{ + int listen_fd = socket(AF_INET, SOCK_STREAM, 0); + + // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. + //int optval = 1; + //check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); + + int flags = fcntl(listen_fd, F_GETFL); + check(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); + + struct sockaddr_in in_addr; + in_addr.sin_family = AF_INET; + in_addr.sin_addr.s_addr = INADDR_ANY; + in_addr.sin_port = htons(1883); + + check(bind(listen_fd, (struct sockaddr *)(&in_addr), sizeof(struct sockaddr_in))); + check(listen(listen_fd, 1024)); + + int epoll_fd_accept = check(epoll_create(999)); + + struct epoll_event events[MAX_EVENTS]; + struct epoll_event ev; + memset(&ev, 0, sizeof (struct epoll_event)); + memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); + + ev.data.fd = listen_fd; + ev.events = EPOLLIN; + check(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, listen_fd, &ev)); + + for (int i = 0; i < NR_OF_THREADS; i++) + { + std::shared_ptr t(new ThreadData(i, subscriptionStore)); + std::thread thread(do_thread_work, t.get()); + t->thread = std::move(thread); + threads.push_back(t); + } + + std::cout << "Listening..." << std::endl; + + uint next_thread_index = 0; + + while (running) + { + int num_fds = epoll_wait(epoll_fd_accept, events, MAX_EVENTS, 100); + + for (int i = 0; i < num_fds; i++) + { + int cur_fd = events[i].data.fd; + if (cur_fd == listen_fd) + { + std::shared_ptr thread_data = threads[next_thread_index++ % NR_OF_THREADS]; + + std::cout << "Accepting connection on thread " << thread_data->threadnr << std::endl; + + struct sockaddr addr; + memset(&addr, 0, sizeof(struct sockaddr)); + socklen_t len = sizeof(struct sockaddr); + int fd = check(accept(cur_fd, &addr, &len)); + + Client_p client(new Client(fd, thread_data)); + thread_data->giveClient(client); + } + else + { + throw std::runtime_error("The main thread had activity on an accepted socket?"); + } + + } + } +} + +void MainApp::quit() +{ + std::cout << "Quitting application" << std::endl; + + running = false; + + for(std::shared_ptr &thread : threads) + { + thread->quit(); + } +} diff --git a/mainapp.h b/mainapp.h new file mode 100644 index 0000000..556e5f1 --- /dev/null +++ b/mainapp.h @@ -0,0 +1,33 @@ +#ifndef MAINAPP_H +#define MAINAPP_H + +#include +#include +#include +#include +#include +#include +#include + +#include "forward_declarations.h" + +#include "utils.h" +#include "threaddata.h" +#include "client.h" +#include "mqttpacket.h" +#include "subscriptionstore.h" + + +class MainApp +{ + bool running = true; + std::vector> threads; + std::shared_ptr subscriptionStore; + +public: + MainApp(); + void start(); + void quit(); +}; + +#endif // MAINAPP_H diff --git a/threaddata.cpp b/threaddata.cpp index 41ffd7f..0fdb935 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -15,6 +15,12 @@ ThreadData::ThreadData(int threadnr, std::shared_ptr &subscri check(epoll_ctl(epollfd, EPOLL_CTL_ADD, event_fd, &ev)); } +void ThreadData::quit() +{ + running = false; + thread.join(); +} + void ThreadData::giveClient(Client_p client) { int fd = client->getFd(); diff --git a/threaddata.h b/threaddata.h index 5d2a03f..b7e32fb 100644 --- a/threaddata.h +++ b/threaddata.h @@ -26,6 +26,7 @@ class ThreadData std::mutex readForDequeuingMutex; public: + bool running = true; std::thread thread; int threadnr = 0; int epollfd = 0; @@ -33,6 +34,7 @@ public: ThreadData(int threadnr, std::shared_ptr &subscriptionStore); + void quit(); void giveClient(Client_p client); Client_p getClient(int fd); void removeClient(Client_p client);