diff --git a/client.cpp b/client.cpp index a51d1d6..00611c2 100644 --- a/client.cpp +++ b/client.cpp @@ -3,6 +3,7 @@ #include #include #include +#include Client::Client(int fd, ThreadData_p threadData) : fd(fd), @@ -12,6 +13,9 @@ Client::Client(int fd, ThreadData_p threadData) : fcntl(fd, F_SETFL, flags | O_NONBLOCK); readbuf = (char*)malloc(CLIENT_BUFFER_SIZE); writebuf = (char*)malloc(CLIENT_BUFFER_SIZE); + + if (readbuf == NULL || writebuf == NULL) + throw std::runtime_error("Malloc error constructing client."); } Client::~Client() @@ -25,7 +29,7 @@ void Client::closeConnection() { if (fd < 0) return; - epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL); + check(epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL)); close(fd); fd = -1; } @@ -59,10 +63,7 @@ bool Client::readFdIntoBuffer() if (errno == EAGAIN || errno == EWOULDBLOCK) break; else - { - std::cerr << strerror(errno) << std::endl; - return false; - } + check(n); } } @@ -91,13 +92,6 @@ void Client::writeMqttPacket(const MqttPacket &packet) setReadyForWriting(true); } -// Not sure if this is the method I want to use -void Client::writeMqttPacketLocked(const MqttPacket &packet) -{ - std::lock_guard lock(writeBufMutex); - writeMqttPacket(packet); -} - // Ping responses are always the same, so hardcoding it for optimization. void Client::writePingResp() { @@ -131,7 +125,7 @@ bool Client::writeBufIntoFd() if (errno == EAGAIN || errno == EWOULDBLOCK) break; else - return false; + check(n); } } @@ -147,27 +141,14 @@ bool Client::writeBufIntoFd() return true; } - - std::string Client::repr() { std::ostringstream a; a << "Client = " << clientid << ", user = " << username; + a.flush(); return a.str(); } -void Client::queueMessage(const MqttPacket &packet) -{ - - - // TODO: semaphores on stl containers? -} - -void Client::queuedMessagesToBuffer() -{ - -} - void Client::setReadyForWriting(bool val) { if (val == this->readyForWriting) @@ -233,9 +214,7 @@ bool Client::bufferToMqttPackets(std::vector &packetQueueIn, Client_ packetQueueIn.push_back(std::move(packet)); ri += packet_length; - - if (ri > wi) - throw std::runtime_error("hier"); + assert(ri <= wi); } else break; diff --git a/client.h b/client.h index a374017..01d3996 100644 --- a/client.h +++ b/client.h @@ -116,15 +116,10 @@ public: void writePingResp(); void writeMqttPacket(const MqttPacket &packet); - void writeMqttPacketLocked(const MqttPacket &packet); bool writeBufIntoFd(); std::string repr(); - void queueMessage(const MqttPacket &packet); - void queuedMessagesToBuffer(); - - }; #endif // CLIENT_H diff --git a/exceptions.h b/exceptions.h index c8c1856..99f7e21 100644 --- a/exceptions.h +++ b/exceptions.h @@ -10,5 +10,11 @@ public: ProtocolError(const std::string &msg) : std::runtime_error(msg) {} }; +class NotImplementedException : public std::runtime_error +{ +public: + NotImplementedException(const std::string &msg) : std::runtime_error(msg) {} +}; + #endif // EXCEPTIONS_H diff --git a/main.cpp b/main.cpp index 0c80ad9..07ffee5 100644 --- a/main.cpp +++ b/main.cpp @@ -38,7 +38,7 @@ int register_signal_handers() if (sigaction(SIGHUP, &sa, nullptr) != 0 || sigaction(SIGTERM, &sa, nullptr) != 0 || sigaction(SIGINT, &sa, nullptr) != 0) { std::cerr << "Error registering signals" << std::endl; - return 1; + return -1; } sigset_t set; diff --git a/mainapp.cpp b/mainapp.cpp index e9e3891..340b9a6 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -1,5 +1,6 @@ #include "mainapp.h" #include "cassert" +#include "exceptions.h" #define MAX_EVENTS 1024 #define NR_OF_THREADS 4 @@ -32,7 +33,13 @@ void do_thread_work(ThreadData *threadData) int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); - if (fdcount > 0) + if (fdcount < 0) + { + if (errno == EINTR) + continue; + std::cerr << "Problem waiting for fd: " << strerror(errno) << std::endl; + } + else if (fdcount > 0) { for (int i = 0; i < fdcount; i++) { @@ -50,21 +57,29 @@ void do_thread_work(ThreadData *threadData) if (client) { - if (cur_ev.events & EPOLLIN) + try { - bool readSuccess = client->readFdIntoBuffer(); - client->bufferToMqttPackets(packetQueueIn, client); - - if (!readSuccess) + if (cur_ev.events & EPOLLIN) + { + bool readSuccess = client->readFdIntoBuffer(); + client->bufferToMqttPackets(packetQueueIn, client); + + if (!readSuccess) + { + std::cout << "Disconnect: " << client->repr() << std::endl; + threadData->removeClient(client); + } + } + if (cur_ev.events & EPOLLOUT) { - std::cout << "Disconnect: " << client->repr() << std::endl; - threadData->removeClient(client); + if (!client->writeBufIntoFd()) + threadData->removeClient(client); } } - if (cur_ev.events & EPOLLOUT) + catch(std::exception &ex) { - if (!client->writeBufIntoFd()) - threadData->removeClient(client); + std::cerr << ex.what() << std::endl; + threadData->removeClient(client); } } else @@ -76,7 +91,15 @@ void do_thread_work(ThreadData *threadData) for (MqttPacket &packet : packetQueueIn) { - packet.handle(threadData->getSubscriptionStore()); + try + { + packet.handle(threadData->getSubscriptionStore()); + } + catch (std::exception &ex) + { + std::cerr << ex.what() << std::endl; + threadData->removeClient(packet.getSender()); + } } packetQueueIn.clear(); } @@ -142,26 +165,40 @@ void MainApp::start() { int num_fds = epoll_wait(epoll_fd_accept, events, MAX_EVENTS, 100); + if (num_fds < 0) + { + if (errno == EINTR) + continue; + std::cerr << strerror(errno) << std::endl; + } + for (int i = 0; i < num_fds; i++) { int cur_fd = events[i].data.fd; - if (cur_fd == listen_fd) + try { - std::shared_ptr thread_data = threads[next_thread_index++ % NR_OF_THREADS]; + if (cur_fd == listen_fd) + { + std::shared_ptr thread_data = threads[next_thread_index++ % NR_OF_THREADS]; - std::cout << "Accepting connection on thread " << thread_data->threadnr << std::endl; + std::cout << "Accepting connection on thread " << thread_data->threadnr << std::endl; - struct sockaddr addr; - memset(&addr, 0, sizeof(struct sockaddr)); - socklen_t len = sizeof(struct sockaddr); - int fd = check(accept(cur_fd, &addr, &len)); + struct sockaddr addr; + memset(&addr, 0, sizeof(struct sockaddr)); + socklen_t len = sizeof(struct sockaddr); + int fd = check(accept(cur_fd, &addr, &len)); - Client_p client(new Client(fd, thread_data)); - thread_data->giveClient(client); + Client_p client(new Client(fd, thread_data)); + thread_data->giveClient(client); + } + else + { + throw std::runtime_error("The main thread had activity on an accepted socket?"); + } } - else + catch (std::exception &ex) { - throw std::runtime_error("The main thread had activity on an accepted socket?"); + std::cerr << "Problem accepting connection: " << ex.what() << std::endl; } } diff --git a/rwlockguard.cpp b/rwlockguard.cpp index 89649f8..6f22c80 100644 --- a/rwlockguard.cpp +++ b/rwlockguard.cpp @@ -1,4 +1,6 @@ #include "rwlockguard.h" +#include "utils.h" +#include "stdexcept" RWLockGuard::RWLockGuard(pthread_rwlock_t *rwlock) : rwlock(rwlock) @@ -13,10 +15,12 @@ RWLockGuard::~RWLockGuard() void RWLockGuard::wrlock() { - pthread_rwlock_wrlock(rwlock); + if (pthread_rwlock_wrlock(rwlock) != 0) + throw std::runtime_error("wrlock failed."); } void RWLockGuard::rdlock() { - pthread_rwlock_wrlock(rwlock); + if (pthread_rwlock_wrlock(rwlock) != 0) + throw std::runtime_error("rdlock failed."); }