diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 4d44f98..ce53660 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -1002,6 +1002,7 @@ void MainTests::testSavingSessions() // Kind of a hack... Authentication auth(*settings.get()); ThreadGlobals::assign(&auth); + ThreadGlobals::assignThreadData(t.get()); std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings.get(), false)); c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60); diff --git a/client.cpp b/client.cpp index 069e83e..95b1c2b 100644 --- a/client.cpp +++ b/client.cpp @@ -45,6 +45,7 @@ Client::Client(int fd, std::shared_ptr threadData, SSL *ssl, bool we ioWrapper(ssl, websocket, initialBufferSize, this), readbuf(initialBufferSize), writebuf(initialBufferSize), + epoll_fd(threadData ? threadData->epollfd : 0), threadData(threadData) { int flags = fcntl(fd, F_GETFL); @@ -61,7 +62,7 @@ Client::Client(int fd, std::shared_ptr threadData, SSL *ssl, bool we Client::~Client() { // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread. - if (!this->threadData) + if (this->threadData.expired()) return; if (disconnectReason.empty()) @@ -78,7 +79,7 @@ Client::~Client() if (fd > 0) // this check is essentially for testing, when working with a dummy fd. { - if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) + if (epoll_ctl(this->epoll_fd, EPOLL_CTL_DEL, fd, NULL) != 0) logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); close(fd); } @@ -272,7 +273,9 @@ void Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) } catch (std::exception &ex) { - threadData->removeClientQueued(fd); + std::shared_ptr td = this->threadData.lock(); + if (td) + td->removeClientQueued(fd); } } @@ -414,7 +417,7 @@ uint16_t Client::getMaxIncomingTopicAliasValue() const void Client::sendOrQueueWill() { - if (!this->threadData) + if (this->threadData.expired()) return; if (!this->willPublish) @@ -447,7 +450,10 @@ void Client::serverInitiatedDisconnect(ReasonCodes reason) else { markAsDisconnecting(); - threadData->removeClientQueued(fd); + + std::shared_ptr td = this->threadData.lock(); + if (td) + td->removeClientQueued(fd); } } @@ -573,7 +579,7 @@ void Client::setReadyForWriting(bool val) memset(&ev, 0, sizeof (struct epoll_event)); ev.data.fd = fd; ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; - check(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev)); + check(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd, &ev)); } void Client::setReadyForReading(bool val) @@ -606,7 +612,7 @@ void Client::setReadyForReading(bool val) std::lock_guard locker(writeBufMutex); ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; - check(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev)); + check(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd, &ev)); } } diff --git a/client.h b/client.h index dd66abb..7de28ed 100644 --- a/client.h +++ b/client.h @@ -93,7 +93,8 @@ class Client std::shared_ptr willPublish; - std::shared_ptr threadData; + const int epoll_fd; + std::weak_ptr threadData; // The thread (data) that this client 'lives' in. std::mutex writeBufMutex; std::shared_ptr session; @@ -139,7 +140,6 @@ public: void setAuthenticated(bool value) { authenticated = value;} bool getAuthenticated() { return authenticated; } bool hasConnectPacketSeen() { return connectPacketSeen; } - std::shared_ptr getThreadData() { return threadData; } std::string &getClientId() { return this->clientid; } const std::string &getUsername() const { return this->username; } std::string &getMutableUsername(); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 984f814..8662f4a 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -328,7 +328,7 @@ void MqttPacket::handleConnect() std::shared_ptr subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); - sender->getThreadData()->mqttConnectCounter.inc(); + ThreadGlobals::getThreadData()->mqttConnectCounter.inc(); uint16_t variable_header_length = readTwoBytesToUInt16(); @@ -864,7 +864,7 @@ void MqttPacket::handleDisconnect() sender->markAsDisconnecting(); if (reasonCode == ReasonCodes::Success) sender->clearWill(); - sender->getThreadData()->removeClientQueued(sender); + ThreadGlobals::getThreadData()->removeClientQueued(sender); } void MqttPacket::handleSubscribe() @@ -1131,7 +1131,7 @@ void MqttPacket::handlePublish() ReasonCodes ackCode = ReasonCodes::Success; - sender->getThreadData()->receivedMessageCounter.inc(); + ThreadGlobals::getThreadData()->receivedMessageCounter.inc(); Authentication &authentication = *ThreadGlobals::getAuth(); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index a7ca0f1..c76107c 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -225,7 +225,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr // Removes an existing client when it already exists [MQTT-3.1.4-2]. void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval) { - client->getThreadData()->queueClientNextKeepAliveCheckLocked(client, true); + ThreadGlobals::getThreadData()->queueClientNextKeepAliveCheckLocked(client, true); RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.wrlock();