Commit c20ba7c5ccc68ed325d46f7c1d5ec751ea329c06

Authored by Wiebe Cazemier
1 parent 4bfa5aa5

Use thread local pointer to authentication object

There were bugs in which authentication object was used when, causing
threadings bugs. Instead of getting from the 'sender', we can just store
a thread local pointer.
CMakeLists.txt
... ... @@ -54,6 +54,7 @@ add_executable(FlashMQ
54 54 persistencefile.h
55 55 sessionsandsubscriptionsdb.h
56 56 qospacketqueue.h
  57 + threadauth.h
57 58  
58 59 mainapp.cpp
59 60 main.cpp
... ... @@ -89,6 +90,7 @@ add_executable(FlashMQ
89 90 persistencefile.cpp
90 91 sessionsandsubscriptionsdb.cpp
91 92 qospacketqueue.cpp
  93 + threadauth.cpp
92 94  
93 95 )
94 96  
... ...
FlashMQTests/FlashMQTests.pro
... ... @@ -46,6 +46,7 @@ SOURCES += tst_maintests.cpp \
46 46 ../persistencefile.cpp \
47 47 ../sessionsandsubscriptionsdb.cpp \
48 48 ../qospacketqueue.cpp \
  49 + ../threadauth.cpp \
49 50 mainappthread.cpp \
50 51 twoclienttestcontext.cpp
51 52  
... ... @@ -85,6 +86,7 @@ HEADERS += \
85 86 ../persistencefile.h \
86 87 ../sessionsandsubscriptionsdb.h \
87 88 ../qospacketqueue.h \
  89 + ../threadauth.h \
88 90 mainappthread.h \
89 91 twoclienttestcontext.h
90 92  
... ...
FlashMQTests/tst_maintests.cpp
... ... @@ -33,6 +33,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
33 33 #include "sessionsandsubscriptionsdb.h"
34 34 #include "session.h"
35 35 #include "threaddata.h"
  36 +#include "threadauth.h"
36 37  
37 38 // Dumb Qt version gives warnings when comparing uint with number literal.
38 39 template <typename T1, typename T2>
... ... @@ -947,6 +948,10 @@ void MainTests::testSavingSessions()
947 948 std::shared_ptr<SubscriptionStore> store(new SubscriptionStore());
948 949 std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings));
949 950  
  951 + // Kind of a hack...
  952 + Authentication auth(*settings.get());
  953 + ThreadAuth::assign(&auth);
  954 +
950 955 std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false));
951 956 c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false);
952 957 store->registerClientAndKickExistingOne(c1);
... ...
mainapp.cpp
... ... @@ -28,6 +28,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
28 28 #include <openssl/err.h>
29 29  
30 30 #include "logger.h"
  31 +#include "threadauth.h"
31 32  
32 33 #define MAX_EVENTS 1024
33 34  
... ... @@ -36,6 +37,7 @@ MainApp *MainApp::instance = nullptr;
36 37 void do_thread_work(ThreadData *threadData)
37 38 {
38 39 int epoll_fd = threadData->epollfd;
  40 + ThreadAuth::assign(&threadData->authentication);
39 41  
40 42 struct epoll_event events[MAX_EVENTS];
41 43 memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
... ...
mqttpacket.cpp
... ... @@ -22,6 +22,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
22 22 #include <cassert>
23 23  
24 24 #include "utils.h"
  25 +#include "threadauth.h"
25 26  
26 27 // We can void constant reallocation of space for parsed subtopics by using this. But, beware to only use it during handling of the current
27 28 // packet. Don't access it for a stored packet, because then it will have changed.
... ... @@ -356,6 +357,8 @@ void MqttPacket::handleConnect()
356 357 bool accessGranted = false;
357 358 std::string denyLogMsg;
358 359  
  360 + Authentication &authentication = *ThreadAuth::getAuth();
  361 +
359 362 if (!user_name_flag && settings.allowAnonymous)
360 363 {
361 364 accessGranted = true;
... ... @@ -366,7 +369,7 @@ void MqttPacket::handleConnect()
366 369 sender->setDisconnectReason("Invalid username character");
367 370 accessGranted = false;
368 371 }
369   - else if (sender->getThreadData()->authentication.unPwdCheck(username, password) == AuthResult::success)
  372 + else if (authentication.unPwdCheck(username, password) == AuthResult::success)
370 373 {
371 374 accessGranted = true;
372 375 }
... ... @@ -421,6 +424,8 @@ void MqttPacket::handleSubscribe()
421 424  
422 425 uint16_t packet_id = readTwoBytesToUInt16();
423 426  
  427 + Authentication &authentication = *ThreadAuth::getAuth();
  428 +
424 429 std::list<char> subs_reponse_codes;
425 430 while (remainingAfterPos() > 0)
426 431 {
... ... @@ -439,7 +444,7 @@ void MqttPacket::handleSubscribe()
439 444 throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0.");
440 445  
441 446 splitTopic(topic, *subtopics);
442   - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success)
  447 + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success)
443 448 {
444 449 logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str());
445 450 sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos);
... ... @@ -555,7 +560,8 @@ void MqttPacket::handlePublish()
555 560 payloadLen = remainingAfterPos();
556 561 payloadStart = pos;
557 562  
558   - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success)
  563 + Authentication &authentication = *ThreadAuth::getAuth();
  564 + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success)
559 565 {
560 566 if (retain)
561 567 {
... ...
session.cpp
... ... @@ -19,6 +19,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
19 19  
20 20 #include "session.h"
21 21 #include "client.h"
  22 +#include "threadauth.h"
22 23  
23 24 std::chrono::time_point<std::chrono::steady_clock> appStartTime = std::chrono::steady_clock::now();
24 25  
... ... @@ -118,8 +119,9 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u
118 119 assert(max_qos <= 2);
119 120 const char qos = std::min<char>(packet.getQos(), max_qos);
120 121  
121   - assert(packet.getSender());
122   - Authentication &auth = packet.getSender()->getThreadData()->authentication;
  122 + Authentication *_auth = ThreadAuth::getAuth();
  123 + assert(_auth);
  124 + Authentication &auth = *_auth;
123 125 if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success)
124 126 {
125 127 if (qos == 0)
... ...
threadauth.cpp 0 → 100644
  1 +#include "threadauth.h"
  2 +
  3 +thread_local Authentication *ThreadAuth::auth = nullptr;
  4 +
  5 +void ThreadAuth::assign(Authentication *auth)
  6 +{
  7 + ThreadAuth::auth = auth;
  8 +}
  9 +
  10 +Authentication *ThreadAuth::getAuth()
  11 +{
  12 + return auth;
  13 +}
... ...
threadauth.h 0 → 100644
  1 +#ifndef THREADAUTH_H
  2 +#define THREADAUTH_H
  3 +
  4 +class Authentication;
  5 +
  6 +class ThreadAuth
  7 +{
  8 + static thread_local Authentication *auth;
  9 +public:
  10 + static void assign(Authentication *auth);
  11 + static Authentication *getAuth();
  12 +};
  13 +
  14 +#endif // THREADAUTH_H
... ...