diff --git a/client.cpp b/client.cpp index 31d7d6c..2a57cb2 100644 --- a/client.cpp +++ b/client.cpp @@ -16,12 +16,20 @@ Client::Client(int fd, ThreadData_p threadData) : Client::~Client() { - epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL); // NOTE: the last NULL can cause crash on old kernels - close(fd); + closeConnection(); free(readbuf); free(writebuf); } +void Client::closeConnection() +{ + if (fd < 0) + return; + epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL); + close(fd); + fd = -1; +} + // false means any kind of error we want to get rid of the client for. bool Client::readFdIntoBuffer() { @@ -59,11 +67,16 @@ bool Client::readFdIntoBuffer() void Client::writeMqttPacket(const MqttPacket &packet) { + if (packet.packetType == PacketType::PUBLISH && getWriteBufBytesUsed() > CLIENT_MAX_BUFFER_SIZE) + return; + if (packet.getSize() > getWriteBufMaxWriteSize()) growWriteBuffer(packet.getSize()); std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getSize()); wwi += packet.getSize(); + + setReadyForWriting(true); } // Not sure if this is the method I want to use @@ -83,10 +96,11 @@ void Client::writePingResp() writebuf[wwi++] = 0b11010000; writebuf[wwi++] = 0; - writeBufIntoFd(); + + setReadyForWriting(true); } -bool Client::writeBufIntoFd() // TODO: ignore the signal BROKEN PIPE we now also get when a client disappears. +bool Client::writeBufIntoFd() { int n; while ((n = write(fd, &writebuf[wri], getWriteBufBytesUsed())) != 0) @@ -108,6 +122,8 @@ bool Client::writeBufIntoFd() // TODO: ignore the signal BROKEN PIPE we now also { wri = 0; wwi = 0; + + setReadyForWriting(false); } return true; @@ -134,6 +150,21 @@ void Client::queuedMessagesToBuffer() } +void Client::setReadyForWriting(bool val) +{ + if (val == this->readyForWriting) + return; + + readyForWriting = val; + struct epoll_event ev; + memset(&ev, 0, sizeof (struct epoll_event)); + ev.data.fd = fd; + ev.events = EPOLLIN; + if (val) + ev.events |= EPOLLOUT; + check(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev)); +} + bool Client::bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender) { while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) diff --git a/client.h b/client.h index 1122aae..95fcf8f 100644 --- a/client.h +++ b/client.h @@ -14,6 +14,7 @@ #define CLIENT_BUFFER_SIZE 1024 +#define CLIENT_MAX_BUFFER_SIZE 1048576 #define MQTT_HEADER_LENGH 2 class Client @@ -32,6 +33,8 @@ class Client bool authenticated = false; bool connectPacketSeen = false; + bool readyForWriting = false; + std::string clientid; std::string username; uint16_t keepalive = 0; @@ -84,11 +87,14 @@ class Client writeBufsize = newBufSize; } + + public: Client(int fd, ThreadData_p threadData); ~Client(); int getFd() { return fd;} + void closeConnection(); bool readFdIntoBuffer(); bool bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender); void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive); @@ -96,6 +102,7 @@ public: bool getAuthenticated() { return authenticated; } bool hasConnectPacketSeen() { return connectPacketSeen; } ThreadData_p getThreadData() { return threadData; } + std::string &getClientId() { return this->clientid; } void writePingResp(); void writeMqttPacket(const MqttPacket &packet); @@ -106,6 +113,8 @@ public: void queueMessage(const MqttPacket &packet); void queuedMessagesToBuffer(); + + void setReadyForWriting(bool val); }; #endif // CLIENT_H diff --git a/mainapp.cpp b/mainapp.cpp index cff434c..2e3fac8 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -1,4 +1,5 @@ #include "mainapp.h" +#include "cassert" #define MAX_EVENTS 1024 #define NR_OF_THREADS 4 @@ -27,8 +28,6 @@ void do_thread_work(ThreadData *threadData) eventfd_value = 0; } - // TODO: do all the buftofd here, not spread out over - int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); if (fdcount > 0) @@ -66,6 +65,10 @@ void do_thread_work(ThreadData *threadData) threadData->removeClient(client); } } + else + { + assert(false); + } } } @@ -85,11 +88,11 @@ MainApp::MainApp() : void MainApp::start() { - int listen_fd = socket(AF_INET, SOCK_STREAM, 0); + int listen_fd = check(socket(AF_INET, SOCK_STREAM, 0)); // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. - //int optval = 1; - //check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); + int optval = 1; + check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); int flags = fcntl(listen_fd, F_GETFL); check(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); @@ -153,11 +156,13 @@ void MainApp::start() } } + + close(listen_fd); } void MainApp::quit() { - std::cout << "Quitting application" << std::endl; + std::cout << "Quitting FlashMQ" << std::endl; running = false; diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 6d4d138..215eb9c 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -159,7 +159,6 @@ void MqttPacket::handleConnect() ConnAck connAck(ConnAckReturnCodes::Accepted); MqttPacket response(connAck); sender->writeMqttPacket(response); - sender->writeBufIntoFd(); } else { @@ -186,7 +185,6 @@ void MqttPacket::handleSubscribe(std::shared_ptr &subscriptio SubAck subAck(packet_id, subs); MqttPacket response(subAck); sender->writeMqttPacket(response); - sender->writeBufIntoFd(); } void MqttPacket::handlePublish(std::shared_ptr &subscriptionStore) diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index e0e5fc1..ff83282 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -8,7 +8,17 @@ SubscriptionStore::SubscriptionStore() void SubscriptionStore::addSubscription(Client_p &client, std::string &topic) { std::lock_guard lock(subscriptionsMutex); - this->subscriptions[topic].push_back(client); + this->subscriptions[topic].insert(client); +} + +void SubscriptionStore::removeClient(const Client_p &client) +{ + std::lock_guard lock(subscriptionsMutex); + for(std::pair> &pair : subscriptions) + { + std::unordered_set &bla = pair.second; + bla.erase(client); + } } void SubscriptionStore::queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender) @@ -16,18 +26,19 @@ void SubscriptionStore::queueAtClientsTemp(std::string &topic, const MqttPacket // TODO: temp. I want to work with read copies of the subscription store, to avoid frequent lock contention. std::lock_guard lock(subscriptionsMutex); - for(Client_p &client : subscriptions[topic]) + for(const Client_p &client : subscriptions[topic]) { if (client->getThreadData()->threadnr == sender->getThreadData()->threadnr) { client->writeMqttPacket(packet); // TODO: with my current hack way, this is wrong. Not using a lock only works with my previous idea of queueing. - client->writeBufIntoFd(); } else { - client->writeMqttPacketLocked(packet); - client->getThreadData()->addToReadyForDequeuing(client); - client->getThreadData()->wakeUpThread(); + // Or keep a list of queued messages in the store, per client? + + //client->writeMqttPacketLocked(packet); + //client->getThreadData()->addToReadyForDequeuing(client); + //client->getThreadData()->wakeUpThread(); } } } diff --git a/subscriptionstore.h b/subscriptionstore.h index 7264bd8..b8427ad 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -11,12 +11,13 @@ class SubscriptionStore { - std::unordered_map> subscriptions; + std::unordered_map> subscriptions; std::mutex subscriptionsMutex; public: SubscriptionStore(); void addSubscription(Client_p &client, std::string &topic); + void removeClient(const Client_p &client); // work with read copies intead of mutex/lock over the central store void getReadCopy(); // TODO diff --git a/threaddata.cpp b/threaddata.cpp index 0fdb935..99d137b 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -1,6 +1,5 @@ #include "threaddata.h" - ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore) : subscriptionStore(subscriptionStore), threadnr(threadnr) @@ -23,24 +22,31 @@ void ThreadData::quit() void ThreadData::giveClient(Client_p client) { + clients_by_fd_mutex.lock(); int fd = client->getFd(); + clients_by_fd[fd] = client; + clients_by_fd_mutex.unlock(); + struct epoll_event ev; memset(&ev, 0, sizeof (struct epoll_event)); ev.data.fd = fd; ev.events = EPOLLIN; check(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev)); - - clients_by_fd[fd] = client; } Client_p ThreadData::getClient(int fd) { + std::lock_guard lck(clients_by_fd_mutex); return this->clients_by_fd[fd]; } void ThreadData::removeClient(Client_p client) { + client->closeConnection(); + std::lock_guard lck(clients_by_fd_mutex); + subscriptionStore->removeClient(client); clients_by_fd.erase(client->getFd()); + } std::shared_ptr &ThreadData::getSubscriptionStore() diff --git a/threaddata.h b/threaddata.h index b7e32fb..00d42d0 100644 --- a/threaddata.h +++ b/threaddata.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "forward_declarations.h" @@ -21,6 +22,7 @@ class ThreadData { std::unordered_map clients_by_fd; + std::mutex clients_by_fd_mutex; std::shared_ptr subscriptionStore; std::unordered_set readyForDequeueing; std::mutex readForDequeuingMutex;