Commit 70e77e6a47b92e38ff7756c6f8a597fb1fc8a234
1 parent
6cdda452
Don't access current thread through client
This needed a separation: getting the current thread, and getting the thread of the client you're queueing a command for. This also resolves a circular reference between Client and ThreadData.
Showing
5 changed files
with
20 additions
and
13 deletions
FlashMQTests/tst_maintests.cpp
| @@ -1002,6 +1002,7 @@ void MainTests::testSavingSessions() | @@ -1002,6 +1002,7 @@ void MainTests::testSavingSessions() | ||
| 1002 | // Kind of a hack... | 1002 | // Kind of a hack... |
| 1003 | Authentication auth(*settings.get()); | 1003 | Authentication auth(*settings.get()); |
| 1004 | ThreadGlobals::assign(&auth); | 1004 | ThreadGlobals::assign(&auth); |
| 1005 | + ThreadGlobals::assignThreadData(t.get()); | ||
| 1005 | 1006 | ||
| 1006 | std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings.get(), false)); | 1007 | std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings.get(), false)); |
| 1007 | c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60); | 1008 | c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60); |
client.cpp
| @@ -45,6 +45,7 @@ Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool we | @@ -45,6 +45,7 @@ Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool we | ||
| 45 | ioWrapper(ssl, websocket, initialBufferSize, this), | 45 | ioWrapper(ssl, websocket, initialBufferSize, this), |
| 46 | readbuf(initialBufferSize), | 46 | readbuf(initialBufferSize), |
| 47 | writebuf(initialBufferSize), | 47 | writebuf(initialBufferSize), |
| 48 | + epoll_fd(threadData ? threadData->epollfd : 0), | ||
| 48 | threadData(threadData) | 49 | threadData(threadData) |
| 49 | { | 50 | { |
| 50 | int flags = fcntl(fd, F_GETFL); | 51 | int flags = fcntl(fd, F_GETFL); |
| @@ -61,7 +62,7 @@ Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool we | @@ -61,7 +62,7 @@ Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool we | ||
| 61 | Client::~Client() | 62 | Client::~Client() |
| 62 | { | 63 | { |
| 63 | // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread. | 64 | // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread. |
| 64 | - if (!this->threadData) | 65 | + if (this->threadData.expired()) |
| 65 | return; | 66 | return; |
| 66 | 67 | ||
| 67 | if (disconnectReason.empty()) | 68 | if (disconnectReason.empty()) |
| @@ -78,7 +79,7 @@ Client::~Client() | @@ -78,7 +79,7 @@ Client::~Client() | ||
| 78 | 79 | ||
| 79 | if (fd > 0) // this check is essentially for testing, when working with a dummy fd. | 80 | if (fd > 0) // this check is essentially for testing, when working with a dummy fd. |
| 80 | { | 81 | { |
| 81 | - if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) | 82 | + if (epoll_ctl(this->epoll_fd, EPOLL_CTL_DEL, fd, NULL) != 0) |
| 82 | logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); | 83 | logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); |
| 83 | close(fd); | 84 | close(fd); |
| 84 | } | 85 | } |
| @@ -272,7 +273,9 @@ void Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) | @@ -272,7 +273,9 @@ void Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) | ||
| 272 | } | 273 | } |
| 273 | catch (std::exception &ex) | 274 | catch (std::exception &ex) |
| 274 | { | 275 | { |
| 275 | - threadData->removeClientQueued(fd); | 276 | + std::shared_ptr<ThreadData> td = this->threadData.lock(); |
| 277 | + if (td) | ||
| 278 | + td->removeClientQueued(fd); | ||
| 276 | } | 279 | } |
| 277 | } | 280 | } |
| 278 | 281 | ||
| @@ -414,7 +417,7 @@ uint16_t Client::getMaxIncomingTopicAliasValue() const | @@ -414,7 +417,7 @@ uint16_t Client::getMaxIncomingTopicAliasValue() const | ||
| 414 | 417 | ||
| 415 | void Client::sendOrQueueWill() | 418 | void Client::sendOrQueueWill() |
| 416 | { | 419 | { |
| 417 | - if (!this->threadData) | 420 | + if (this->threadData.expired()) |
| 418 | return; | 421 | return; |
| 419 | 422 | ||
| 420 | if (!this->willPublish) | 423 | if (!this->willPublish) |
| @@ -447,7 +450,10 @@ void Client::serverInitiatedDisconnect(ReasonCodes reason) | @@ -447,7 +450,10 @@ void Client::serverInitiatedDisconnect(ReasonCodes reason) | ||
| 447 | else | 450 | else |
| 448 | { | 451 | { |
| 449 | markAsDisconnecting(); | 452 | markAsDisconnecting(); |
| 450 | - threadData->removeClientQueued(fd); | 453 | + |
| 454 | + std::shared_ptr<ThreadData> td = this->threadData.lock(); | ||
| 455 | + if (td) | ||
| 456 | + td->removeClientQueued(fd); | ||
| 451 | } | 457 | } |
| 452 | } | 458 | } |
| 453 | 459 | ||
| @@ -573,7 +579,7 @@ void Client::setReadyForWriting(bool val) | @@ -573,7 +579,7 @@ void Client::setReadyForWriting(bool val) | ||
| 573 | memset(&ev, 0, sizeof (struct epoll_event)); | 579 | memset(&ev, 0, sizeof (struct epoll_event)); |
| 574 | ev.data.fd = fd; | 580 | ev.data.fd = fd; |
| 575 | ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; | 581 | ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; |
| 576 | - check<std::runtime_error>(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev)); | 582 | + check<std::runtime_error>(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd, &ev)); |
| 577 | } | 583 | } |
| 578 | 584 | ||
| 579 | void Client::setReadyForReading(bool val) | 585 | void Client::setReadyForReading(bool val) |
| @@ -606,7 +612,7 @@ void Client::setReadyForReading(bool val) | @@ -606,7 +612,7 @@ void Client::setReadyForReading(bool val) | ||
| 606 | std::lock_guard<std::mutex> locker(writeBufMutex); | 612 | std::lock_guard<std::mutex> locker(writeBufMutex); |
| 607 | 613 | ||
| 608 | ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; | 614 | ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; |
| 609 | - check<std::runtime_error>(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev)); | 615 | + check<std::runtime_error>(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd, &ev)); |
| 610 | } | 616 | } |
| 611 | } | 617 | } |
| 612 | 618 |
client.h
| @@ -93,7 +93,8 @@ class Client | @@ -93,7 +93,8 @@ class Client | ||
| 93 | 93 | ||
| 94 | std::shared_ptr<WillPublish> willPublish; | 94 | std::shared_ptr<WillPublish> willPublish; |
| 95 | 95 | ||
| 96 | - std::shared_ptr<ThreadData> threadData; | 96 | + const int epoll_fd; |
| 97 | + std::weak_ptr<ThreadData> threadData; // The thread (data) that this client 'lives' in. | ||
| 97 | std::mutex writeBufMutex; | 98 | std::mutex writeBufMutex; |
| 98 | 99 | ||
| 99 | std::shared_ptr<Session> session; | 100 | std::shared_ptr<Session> session; |
| @@ -139,7 +140,6 @@ public: | @@ -139,7 +140,6 @@ public: | ||
| 139 | void setAuthenticated(bool value) { authenticated = value;} | 140 | void setAuthenticated(bool value) { authenticated = value;} |
| 140 | bool getAuthenticated() { return authenticated; } | 141 | bool getAuthenticated() { return authenticated; } |
| 141 | bool hasConnectPacketSeen() { return connectPacketSeen; } | 142 | bool hasConnectPacketSeen() { return connectPacketSeen; } |
| 142 | - std::shared_ptr<ThreadData> getThreadData() { return threadData; } | ||
| 143 | std::string &getClientId() { return this->clientid; } | 143 | std::string &getClientId() { return this->clientid; } |
| 144 | const std::string &getUsername() const { return this->username; } | 144 | const std::string &getUsername() const { return this->username; } |
| 145 | std::string &getMutableUsername(); | 145 | std::string &getMutableUsername(); |
mqttpacket.cpp
| @@ -328,7 +328,7 @@ void MqttPacket::handleConnect() | @@ -328,7 +328,7 @@ void MqttPacket::handleConnect() | ||
| 328 | 328 | ||
| 329 | std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); | 329 | std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); |
| 330 | 330 | ||
| 331 | - sender->getThreadData()->mqttConnectCounter.inc(); | 331 | + ThreadGlobals::getThreadData()->mqttConnectCounter.inc(); |
| 332 | 332 | ||
| 333 | uint16_t variable_header_length = readTwoBytesToUInt16(); | 333 | uint16_t variable_header_length = readTwoBytesToUInt16(); |
| 334 | 334 | ||
| @@ -864,7 +864,7 @@ void MqttPacket::handleDisconnect() | @@ -864,7 +864,7 @@ void MqttPacket::handleDisconnect() | ||
| 864 | sender->markAsDisconnecting(); | 864 | sender->markAsDisconnecting(); |
| 865 | if (reasonCode == ReasonCodes::Success) | 865 | if (reasonCode == ReasonCodes::Success) |
| 866 | sender->clearWill(); | 866 | sender->clearWill(); |
| 867 | - sender->getThreadData()->removeClientQueued(sender); | 867 | + ThreadGlobals::getThreadData()->removeClientQueued(sender); |
| 868 | } | 868 | } |
| 869 | 869 | ||
| 870 | void MqttPacket::handleSubscribe() | 870 | void MqttPacket::handleSubscribe() |
| @@ -1131,7 +1131,7 @@ void MqttPacket::handlePublish() | @@ -1131,7 +1131,7 @@ void MqttPacket::handlePublish() | ||
| 1131 | 1131 | ||
| 1132 | ReasonCodes ackCode = ReasonCodes::Success; | 1132 | ReasonCodes ackCode = ReasonCodes::Success; |
| 1133 | 1133 | ||
| 1134 | - sender->getThreadData()->receivedMessageCounter.inc(); | 1134 | + ThreadGlobals::getThreadData()->receivedMessageCounter.inc(); |
| 1135 | 1135 | ||
| 1136 | Authentication &authentication = *ThreadGlobals::getAuth(); | 1136 | Authentication &authentication = *ThreadGlobals::getAuth(); |
| 1137 | 1137 |
subscriptionstore.cpp
| @@ -225,7 +225,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> | @@ -225,7 +225,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> | ||
| 225 | // Removes an existing client when it already exists [MQTT-3.1.4-2]. | 225 | // Removes an existing client when it already exists [MQTT-3.1.4-2]. |
| 226 | void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval) | 226 | void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval) |
| 227 | { | 227 | { |
| 228 | - client->getThreadData()->queueClientNextKeepAliveCheckLocked(client, true); | 228 | + ThreadGlobals::getThreadData()->queueClientNextKeepAliveCheckLocked(client, true); |
| 229 | 229 | ||
| 230 | RWLockGuard lock_guard(&subscriptionsRwlock); | 230 | RWLockGuard lock_guard(&subscriptionsRwlock); |
| 231 | lock_guard.wrlock(); | 231 | lock_guard.wrlock(); |