Commit 70e77e6a47b92e38ff7756c6f8a597fb1fc8a234
1 parent
6cdda452
Don't access current thread through client
This needed a separation: getting the current thread, and getting the thread of the client you're queueing a command for. This also resolves a circular reference between Client and ThreadData.
Showing
5 changed files
with
20 additions
and
13 deletions
FlashMQTests/tst_maintests.cpp
| ... | ... | @@ -1002,6 +1002,7 @@ void MainTests::testSavingSessions() |
| 1002 | 1002 | // Kind of a hack... |
| 1003 | 1003 | Authentication auth(*settings.get()); |
| 1004 | 1004 | ThreadGlobals::assign(&auth); |
| 1005 | + ThreadGlobals::assignThreadData(t.get()); | |
| 1005 | 1006 | |
| 1006 | 1007 | std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings.get(), false)); |
| 1007 | 1008 | c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60); | ... | ... |
client.cpp
| ... | ... | @@ -45,6 +45,7 @@ Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool we |
| 45 | 45 | ioWrapper(ssl, websocket, initialBufferSize, this), |
| 46 | 46 | readbuf(initialBufferSize), |
| 47 | 47 | writebuf(initialBufferSize), |
| 48 | + epoll_fd(threadData ? threadData->epollfd : 0), | |
| 48 | 49 | threadData(threadData) |
| 49 | 50 | { |
| 50 | 51 | int flags = fcntl(fd, F_GETFL); |
| ... | ... | @@ -61,7 +62,7 @@ Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool we |
| 61 | 62 | Client::~Client() |
| 62 | 63 | { |
| 63 | 64 | // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread. |
| 64 | - if (!this->threadData) | |
| 65 | + if (this->threadData.expired()) | |
| 65 | 66 | return; |
| 66 | 67 | |
| 67 | 68 | if (disconnectReason.empty()) |
| ... | ... | @@ -78,7 +79,7 @@ Client::~Client() |
| 78 | 79 | |
| 79 | 80 | if (fd > 0) // this check is essentially for testing, when working with a dummy fd. |
| 80 | 81 | { |
| 81 | - if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) | |
| 82 | + if (epoll_ctl(this->epoll_fd, EPOLL_CTL_DEL, fd, NULL) != 0) | |
| 82 | 83 | logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); |
| 83 | 84 | close(fd); |
| 84 | 85 | } |
| ... | ... | @@ -272,7 +273,9 @@ void Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) |
| 272 | 273 | } |
| 273 | 274 | catch (std::exception &ex) |
| 274 | 275 | { |
| 275 | - threadData->removeClientQueued(fd); | |
| 276 | + std::shared_ptr<ThreadData> td = this->threadData.lock(); | |
| 277 | + if (td) | |
| 278 | + td->removeClientQueued(fd); | |
| 276 | 279 | } |
| 277 | 280 | } |
| 278 | 281 | |
| ... | ... | @@ -414,7 +417,7 @@ uint16_t Client::getMaxIncomingTopicAliasValue() const |
| 414 | 417 | |
| 415 | 418 | void Client::sendOrQueueWill() |
| 416 | 419 | { |
| 417 | - if (!this->threadData) | |
| 420 | + if (this->threadData.expired()) | |
| 418 | 421 | return; |
| 419 | 422 | |
| 420 | 423 | if (!this->willPublish) |
| ... | ... | @@ -447,7 +450,10 @@ void Client::serverInitiatedDisconnect(ReasonCodes reason) |
| 447 | 450 | else |
| 448 | 451 | { |
| 449 | 452 | markAsDisconnecting(); |
| 450 | - threadData->removeClientQueued(fd); | |
| 453 | + | |
| 454 | + std::shared_ptr<ThreadData> td = this->threadData.lock(); | |
| 455 | + if (td) | |
| 456 | + td->removeClientQueued(fd); | |
| 451 | 457 | } |
| 452 | 458 | } |
| 453 | 459 | |
| ... | ... | @@ -573,7 +579,7 @@ void Client::setReadyForWriting(bool val) |
| 573 | 579 | memset(&ev, 0, sizeof (struct epoll_event)); |
| 574 | 580 | ev.data.fd = fd; |
| 575 | 581 | ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; |
| 576 | - check<std::runtime_error>(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev)); | |
| 582 | + check<std::runtime_error>(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd, &ev)); | |
| 577 | 583 | } |
| 578 | 584 | |
| 579 | 585 | void Client::setReadyForReading(bool val) |
| ... | ... | @@ -606,7 +612,7 @@ void Client::setReadyForReading(bool val) |
| 606 | 612 | std::lock_guard<std::mutex> locker(writeBufMutex); |
| 607 | 613 | |
| 608 | 614 | ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; |
| 609 | - check<std::runtime_error>(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev)); | |
| 615 | + check<std::runtime_error>(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd, &ev)); | |
| 610 | 616 | } |
| 611 | 617 | } |
| 612 | 618 | ... | ... |
client.h
| ... | ... | @@ -93,7 +93,8 @@ class Client |
| 93 | 93 | |
| 94 | 94 | std::shared_ptr<WillPublish> willPublish; |
| 95 | 95 | |
| 96 | - std::shared_ptr<ThreadData> threadData; | |
| 96 | + const int epoll_fd; | |
| 97 | + std::weak_ptr<ThreadData> threadData; // The thread (data) that this client 'lives' in. | |
| 97 | 98 | std::mutex writeBufMutex; |
| 98 | 99 | |
| 99 | 100 | std::shared_ptr<Session> session; |
| ... | ... | @@ -139,7 +140,6 @@ public: |
| 139 | 140 | void setAuthenticated(bool value) { authenticated = value;} |
| 140 | 141 | bool getAuthenticated() { return authenticated; } |
| 141 | 142 | bool hasConnectPacketSeen() { return connectPacketSeen; } |
| 142 | - std::shared_ptr<ThreadData> getThreadData() { return threadData; } | |
| 143 | 143 | std::string &getClientId() { return this->clientid; } |
| 144 | 144 | const std::string &getUsername() const { return this->username; } |
| 145 | 145 | std::string &getMutableUsername(); | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -328,7 +328,7 @@ void MqttPacket::handleConnect() |
| 328 | 328 | |
| 329 | 329 | std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); |
| 330 | 330 | |
| 331 | - sender->getThreadData()->mqttConnectCounter.inc(); | |
| 331 | + ThreadGlobals::getThreadData()->mqttConnectCounter.inc(); | |
| 332 | 332 | |
| 333 | 333 | uint16_t variable_header_length = readTwoBytesToUInt16(); |
| 334 | 334 | |
| ... | ... | @@ -864,7 +864,7 @@ void MqttPacket::handleDisconnect() |
| 864 | 864 | sender->markAsDisconnecting(); |
| 865 | 865 | if (reasonCode == ReasonCodes::Success) |
| 866 | 866 | sender->clearWill(); |
| 867 | - sender->getThreadData()->removeClientQueued(sender); | |
| 867 | + ThreadGlobals::getThreadData()->removeClientQueued(sender); | |
| 868 | 868 | } |
| 869 | 869 | |
| 870 | 870 | void MqttPacket::handleSubscribe() |
| ... | ... | @@ -1131,7 +1131,7 @@ void MqttPacket::handlePublish() |
| 1131 | 1131 | |
| 1132 | 1132 | ReasonCodes ackCode = ReasonCodes::Success; |
| 1133 | 1133 | |
| 1134 | - sender->getThreadData()->receivedMessageCounter.inc(); | |
| 1134 | + ThreadGlobals::getThreadData()->receivedMessageCounter.inc(); | |
| 1135 | 1135 | |
| 1136 | 1136 | Authentication &authentication = *ThreadGlobals::getAuth(); |
| 1137 | 1137 | ... | ... |
subscriptionstore.cpp
| ... | ... | @@ -225,7 +225,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> |
| 225 | 225 | // Removes an existing client when it already exists [MQTT-3.1.4-2]. |
| 226 | 226 | void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval) |
| 227 | 227 | { |
| 228 | - client->getThreadData()->queueClientNextKeepAliveCheckLocked(client, true); | |
| 228 | + ThreadGlobals::getThreadData()->queueClientNextKeepAliveCheckLocked(client, true); | |
| 229 | 229 | |
| 230 | 230 | RWLockGuard lock_guard(&subscriptionsRwlock); |
| 231 | 231 | lock_guard.wrlock(); | ... | ... |