From 7e87fd91bb6ef68c808936fb178010c579510ae0 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Tue, 22 Dec 2020 17:45:07 +0100 Subject: [PATCH] Change from checking fd==-1 to properly through the destructor --- client.cpp | 35 +++++++++++++++++++++++++++++------ client.h | 8 +++++--- mainapp.cpp | 2 +- subscriptionstore.cpp | 2 +- threaddata.cpp | 15 ++++++++++++++- threaddata.h | 1 + 6 files changed, 51 insertions(+), 12 deletions(-) diff --git a/client.cpp b/client.cpp index 828c23e..24d9e47 100644 --- a/client.cpp +++ b/client.cpp @@ -33,23 +33,27 @@ Client::Client(int fd, ThreadData_p threadData) : Client::~Client() { - closeConnection(); + close(fd); free(readbuf); free(writebuf); } -void Client::closeConnection() +// Do this from a place you'll know ownwership of the shared_ptr is being given up everywhere, so the close happens when the last owner gives it up. +void Client::markAsDisconnecting() { - if (fd < 0) + if (disconnecting) return; + + disconnecting = true; check(epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL)); - close(fd); - fd = -1; } // false means any kind of error we want to get rid of the client for. bool Client::readFdIntoBuffer() { + if (disconnecting) + return false; + if (wi > CLIENT_MAX_BUFFER_SIZE) { setReadyForReading(false); @@ -119,6 +123,19 @@ void Client::writeMqttPacket(const MqttPacket &packet) setReadyForWriting(true); } +// Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected. +void Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) +{ + try + { + this->writeMqttPacket(packet); + } + catch (std::exception &ex) + { + threadData->removeClient(fd); + } +} + // Ping responses are always the same, so hardcoding it for optimization. void Client::writePingResp() { @@ -142,7 +159,7 @@ bool Client::writeBufIntoFd() return true; // We can abort the write; the client is about to be removed anyway. - if (isDisconnected()) + if (disconnecting) return false; int n; @@ -182,6 +199,9 @@ std::string Client::repr() void Client::setReadyForWriting(bool val) { + if (disconnecting) + return; + if (val == this->readyForWriting) return; @@ -198,6 +218,9 @@ void Client::setReadyForWriting(bool val) void Client::setReadyForReading(bool val) { + if (disconnecting) + return; + if (val == this->readyForReading) return; diff --git a/client.h b/client.h index 9e1d132..ff090e7 100644 --- a/client.h +++ b/client.h @@ -37,6 +37,7 @@ class Client bool readyForWriting = false; bool readyForReading = true; bool disconnectWhenBytesWritten = false; + bool disconnecting = false; std::string clientid; std::string username; @@ -109,10 +110,12 @@ class Client public: Client(int fd, ThreadData_p threadData); + Client(const Client &other) = delete; + Client(Client &&other) = delete; ~Client(); int getFd() { return fd;} - void closeConnection(); + void markAsDisconnecting(); bool readFdIntoBuffer(); bool bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender); void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive); @@ -125,14 +128,13 @@ public: void writePingResp(); void writeMqttPacket(const MqttPacket &packet); + void writeMqttPacketAndBlameThisClient(const MqttPacket &packet); bool writeBufIntoFd(); bool readyForDisconnecting() const { return disconnectWhenBytesWritten && wwi == wri && wwi == 0; } // Do this before calling an action that makes this client ready for writing, so that the EPOLLOUT will handle it. void setReadyForDisconnect() { disconnectWhenBytesWritten = true; } - bool isDisconnected() const { return fd < 0; } - std::string repr(); }; diff --git a/mainapp.cpp b/mainapp.cpp index 45d6461..bd0eedd 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -55,7 +55,7 @@ void do_thread_work(ThreadData *threadData) Client_p client = threadData->getClient(fd); - if (client && !client->isDisconnected()) + if (client) { try { diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index d5ae976..bd40266 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -66,7 +66,7 @@ bool SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st auto client_it = clients_by_id_const.find(client_id); if (client_it != clients_by_id_const.end()) { - client_it->second->writeMqttPacket(packet); + client_it->second->writeMqttPacketAndBlameThisClient(packet); result = true; } } diff --git a/threaddata.cpp b/threaddata.cpp index a913be8..e5feab7 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -55,12 +55,25 @@ Client_p ThreadData::getClient(int fd) void ThreadData::removeClient(Client_p client) { + client->markAsDisconnecting(); + std::lock_guard lck(clients_by_fd_mutex); clients_by_fd.erase(client->getFd()); - client->closeConnection(); subscriptionStore->removeClient(client); } +void ThreadData::removeClient(int fd) +{ + std::lock_guard lck(clients_by_fd_mutex); + auto client_it = this->clients_by_fd.find(fd); + if (client_it != this->clients_by_fd.end()) + { + client_it->second->markAsDisconnecting(); + subscriptionStore->removeClient(client_it->second); + this->clients_by_fd.erase(fd); + } +} + std::shared_ptr &ThreadData::getSubscriptionStore() { return subscriptionStore; diff --git a/threaddata.h b/threaddata.h index bfa1acf..6ba3061 100644 --- a/threaddata.h +++ b/threaddata.h @@ -41,6 +41,7 @@ public: void giveClient(Client_p client); Client_p getClient(int fd); void removeClient(Client_p client); + void removeClient(int fd); std::shared_ptr &getSubscriptionStore(); void wakeUpThread(); void addToReadyForDequeuing(Client_p &client); -- libgit2 0.21.4