Commit 70e77e6a47b92e38ff7756c6f8a597fb1fc8a234

Authored by Wiebe Cazemier
1 parent 6cdda452

Don't access current thread through client

This needed a separation: getting the current thread, and getting the
thread of the client you're queueing a command for.

This also resolves a circular reference between Client and ThreadData.
FlashMQTests/tst_maintests.cpp
@@ -1002,6 +1002,7 @@ void MainTests::testSavingSessions() @@ -1002,6 +1002,7 @@ void MainTests::testSavingSessions()
1002 // Kind of a hack... 1002 // Kind of a hack...
1003 Authentication auth(*settings.get()); 1003 Authentication auth(*settings.get());
1004 ThreadGlobals::assign(&auth); 1004 ThreadGlobals::assign(&auth);
  1005 + ThreadGlobals::assignThreadData(t.get());
1005 1006
1006 std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings.get(), false)); 1007 std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings.get(), false));
1007 c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60); 1008 c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60);
client.cpp
@@ -45,6 +45,7 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we @@ -45,6 +45,7 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we
45 ioWrapper(ssl, websocket, initialBufferSize, this), 45 ioWrapper(ssl, websocket, initialBufferSize, this),
46 readbuf(initialBufferSize), 46 readbuf(initialBufferSize),
47 writebuf(initialBufferSize), 47 writebuf(initialBufferSize),
  48 + epoll_fd(threadData ? threadData->epollfd : 0),
48 threadData(threadData) 49 threadData(threadData)
49 { 50 {
50 int flags = fcntl(fd, F_GETFL); 51 int flags = fcntl(fd, F_GETFL);
@@ -61,7 +62,7 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we @@ -61,7 +62,7 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we
61 Client::~Client() 62 Client::~Client()
62 { 63 {
63 // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread. 64 // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread.
64 - if (!this->threadData) 65 + if (this->threadData.expired())
65 return; 66 return;
66 67
67 if (disconnectReason.empty()) 68 if (disconnectReason.empty())
@@ -78,7 +79,7 @@ Client::~Client() @@ -78,7 +79,7 @@ Client::~Client()
78 79
79 if (fd > 0) // this check is essentially for testing, when working with a dummy fd. 80 if (fd > 0) // this check is essentially for testing, when working with a dummy fd.
80 { 81 {
81 - if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) 82 + if (epoll_ctl(this->epoll_fd, EPOLL_CTL_DEL, fd, NULL) != 0)
82 logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); 83 logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno));
83 close(fd); 84 close(fd);
84 } 85 }
@@ -272,7 +273,9 @@ void Client::writeMqttPacketAndBlameThisClient(const MqttPacket &amp;packet) @@ -272,7 +273,9 @@ void Client::writeMqttPacketAndBlameThisClient(const MqttPacket &amp;packet)
272 } 273 }
273 catch (std::exception &ex) 274 catch (std::exception &ex)
274 { 275 {
275 - threadData->removeClientQueued(fd); 276 + std::shared_ptr<ThreadData> td = this->threadData.lock();
  277 + if (td)
  278 + td->removeClientQueued(fd);
276 } 279 }
277 } 280 }
278 281
@@ -414,7 +417,7 @@ uint16_t Client::getMaxIncomingTopicAliasValue() const @@ -414,7 +417,7 @@ uint16_t Client::getMaxIncomingTopicAliasValue() const
414 417
415 void Client::sendOrQueueWill() 418 void Client::sendOrQueueWill()
416 { 419 {
417 - if (!this->threadData) 420 + if (this->threadData.expired())
418 return; 421 return;
419 422
420 if (!this->willPublish) 423 if (!this->willPublish)
@@ -447,7 +450,10 @@ void Client::serverInitiatedDisconnect(ReasonCodes reason) @@ -447,7 +450,10 @@ void Client::serverInitiatedDisconnect(ReasonCodes reason)
447 else 450 else
448 { 451 {
449 markAsDisconnecting(); 452 markAsDisconnecting();
450 - threadData->removeClientQueued(fd); 453 +
  454 + std::shared_ptr<ThreadData> td = this->threadData.lock();
  455 + if (td)
  456 + td->removeClientQueued(fd);
451 } 457 }
452 } 458 }
453 459
@@ -573,7 +579,7 @@ void Client::setReadyForWriting(bool val) @@ -573,7 +579,7 @@ void Client::setReadyForWriting(bool val)
573 memset(&ev, 0, sizeof (struct epoll_event)); 579 memset(&ev, 0, sizeof (struct epoll_event));
574 ev.data.fd = fd; 580 ev.data.fd = fd;
575 ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; 581 ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT;
576 - check<std::runtime_error>(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev)); 582 + check<std::runtime_error>(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd, &ev));
577 } 583 }
578 584
579 void Client::setReadyForReading(bool val) 585 void Client::setReadyForReading(bool val)
@@ -606,7 +612,7 @@ void Client::setReadyForReading(bool val) @@ -606,7 +612,7 @@ void Client::setReadyForReading(bool val)
606 std::lock_guard<std::mutex> locker(writeBufMutex); 612 std::lock_guard<std::mutex> locker(writeBufMutex);
607 613
608 ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT; 614 ev.events = readyForReading*EPOLLIN | readyForWriting*EPOLLOUT;
609 - check<std::runtime_error>(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev)); 615 + check<std::runtime_error>(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd, &ev));
610 } 616 }
611 } 617 }
612 618
client.h
@@ -93,7 +93,8 @@ class Client @@ -93,7 +93,8 @@ class Client
93 93
94 std::shared_ptr<WillPublish> willPublish; 94 std::shared_ptr<WillPublish> willPublish;
95 95
96 - std::shared_ptr<ThreadData> threadData; 96 + const int epoll_fd;
  97 + std::weak_ptr<ThreadData> threadData; // The thread (data) that this client 'lives' in.
97 std::mutex writeBufMutex; 98 std::mutex writeBufMutex;
98 99
99 std::shared_ptr<Session> session; 100 std::shared_ptr<Session> session;
@@ -139,7 +140,6 @@ public: @@ -139,7 +140,6 @@ public:
139 void setAuthenticated(bool value) { authenticated = value;} 140 void setAuthenticated(bool value) { authenticated = value;}
140 bool getAuthenticated() { return authenticated; } 141 bool getAuthenticated() { return authenticated; }
141 bool hasConnectPacketSeen() { return connectPacketSeen; } 142 bool hasConnectPacketSeen() { return connectPacketSeen; }
142 - std::shared_ptr<ThreadData> getThreadData() { return threadData; }  
143 std::string &getClientId() { return this->clientid; } 143 std::string &getClientId() { return this->clientid; }
144 const std::string &getUsername() const { return this->username; } 144 const std::string &getUsername() const { return this->username; }
145 std::string &getMutableUsername(); 145 std::string &getMutableUsername();
mqttpacket.cpp
@@ -328,7 +328,7 @@ void MqttPacket::handleConnect() @@ -328,7 +328,7 @@ void MqttPacket::handleConnect()
328 328
329 std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); 329 std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore();
330 330
331 - sender->getThreadData()->mqttConnectCounter.inc(); 331 + ThreadGlobals::getThreadData()->mqttConnectCounter.inc();
332 332
333 uint16_t variable_header_length = readTwoBytesToUInt16(); 333 uint16_t variable_header_length = readTwoBytesToUInt16();
334 334
@@ -864,7 +864,7 @@ void MqttPacket::handleDisconnect() @@ -864,7 +864,7 @@ void MqttPacket::handleDisconnect()
864 sender->markAsDisconnecting(); 864 sender->markAsDisconnecting();
865 if (reasonCode == ReasonCodes::Success) 865 if (reasonCode == ReasonCodes::Success)
866 sender->clearWill(); 866 sender->clearWill();
867 - sender->getThreadData()->removeClientQueued(sender); 867 + ThreadGlobals::getThreadData()->removeClientQueued(sender);
868 } 868 }
869 869
870 void MqttPacket::handleSubscribe() 870 void MqttPacket::handleSubscribe()
@@ -1131,7 +1131,7 @@ void MqttPacket::handlePublish() @@ -1131,7 +1131,7 @@ void MqttPacket::handlePublish()
1131 1131
1132 ReasonCodes ackCode = ReasonCodes::Success; 1132 ReasonCodes ackCode = ReasonCodes::Success;
1133 1133
1134 - sender->getThreadData()->receivedMessageCounter.inc(); 1134 + ThreadGlobals::getThreadData()->receivedMessageCounter.inc();
1135 1135
1136 Authentication &authentication = *ThreadGlobals::getAuth(); 1136 Authentication &authentication = *ThreadGlobals::getAuth();
1137 1137
subscriptionstore.cpp
@@ -225,7 +225,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr&lt;Client&gt; @@ -225,7 +225,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr&lt;Client&gt;
225 // Removes an existing client when it already exists [MQTT-3.1.4-2]. 225 // Removes an existing client when it already exists [MQTT-3.1.4-2].
226 void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval) 226 void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval)
227 { 227 {
228 - client->getThreadData()->queueClientNextKeepAliveCheckLocked(client, true); 228 + ThreadGlobals::getThreadData()->queueClientNextKeepAliveCheckLocked(client, true);
229 229
230 RWLockGuard lock_guard(&subscriptionsRwlock); 230 RWLockGuard lock_guard(&subscriptionsRwlock);
231 lock_guard.wrlock(); 231 lock_guard.wrlock();