Commit c20ba7c5ccc68ed325d46f7c1d5ec751ea329c06

Authored by Wiebe Cazemier
1 parent 4bfa5aa5

Use thread local pointer to authentication object

There were bugs in which authentication object was used when, causing
threadings bugs. Instead of getting from the 'sender', we can just store
a thread local pointer.
CMakeLists.txt
@@ -54,6 +54,7 @@ add_executable(FlashMQ @@ -54,6 +54,7 @@ add_executable(FlashMQ
54 persistencefile.h 54 persistencefile.h
55 sessionsandsubscriptionsdb.h 55 sessionsandsubscriptionsdb.h
56 qospacketqueue.h 56 qospacketqueue.h
  57 + threadauth.h
57 58
58 mainapp.cpp 59 mainapp.cpp
59 main.cpp 60 main.cpp
@@ -89,6 +90,7 @@ add_executable(FlashMQ @@ -89,6 +90,7 @@ add_executable(FlashMQ
89 persistencefile.cpp 90 persistencefile.cpp
90 sessionsandsubscriptionsdb.cpp 91 sessionsandsubscriptionsdb.cpp
91 qospacketqueue.cpp 92 qospacketqueue.cpp
  93 + threadauth.cpp
92 94
93 ) 95 )
94 96
FlashMQTests/FlashMQTests.pro
@@ -46,6 +46,7 @@ SOURCES += tst_maintests.cpp \ @@ -46,6 +46,7 @@ SOURCES += tst_maintests.cpp \
46 ../persistencefile.cpp \ 46 ../persistencefile.cpp \
47 ../sessionsandsubscriptionsdb.cpp \ 47 ../sessionsandsubscriptionsdb.cpp \
48 ../qospacketqueue.cpp \ 48 ../qospacketqueue.cpp \
  49 + ../threadauth.cpp \
49 mainappthread.cpp \ 50 mainappthread.cpp \
50 twoclienttestcontext.cpp 51 twoclienttestcontext.cpp
51 52
@@ -85,6 +86,7 @@ HEADERS += \ @@ -85,6 +86,7 @@ HEADERS += \
85 ../persistencefile.h \ 86 ../persistencefile.h \
86 ../sessionsandsubscriptionsdb.h \ 87 ../sessionsandsubscriptionsdb.h \
87 ../qospacketqueue.h \ 88 ../qospacketqueue.h \
  89 + ../threadauth.h \
88 mainappthread.h \ 90 mainappthread.h \
89 twoclienttestcontext.h 91 twoclienttestcontext.h
90 92
FlashMQTests/tst_maintests.cpp
@@ -33,6 +33,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. @@ -33,6 +33,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
33 #include "sessionsandsubscriptionsdb.h" 33 #include "sessionsandsubscriptionsdb.h"
34 #include "session.h" 34 #include "session.h"
35 #include "threaddata.h" 35 #include "threaddata.h"
  36 +#include "threadauth.h"
36 37
37 // Dumb Qt version gives warnings when comparing uint with number literal. 38 // Dumb Qt version gives warnings when comparing uint with number literal.
38 template <typename T1, typename T2> 39 template <typename T1, typename T2>
@@ -947,6 +948,10 @@ void MainTests::testSavingSessions() @@ -947,6 +948,10 @@ void MainTests::testSavingSessions()
947 std::shared_ptr<SubscriptionStore> store(new SubscriptionStore()); 948 std::shared_ptr<SubscriptionStore> store(new SubscriptionStore());
948 std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings)); 949 std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings));
949 950
  951 + // Kind of a hack...
  952 + Authentication auth(*settings.get());
  953 + ThreadAuth::assign(&auth);
  954 +
950 std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false)); 955 std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false));
951 c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); 956 c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false);
952 store->registerClientAndKickExistingOne(c1); 957 store->registerClientAndKickExistingOne(c1);
mainapp.cpp
@@ -28,6 +28,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -28,6 +28,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
28 #include <openssl/err.h> 28 #include <openssl/err.h>
29 29
30 #include "logger.h" 30 #include "logger.h"
  31 +#include "threadauth.h"
31 32
32 #define MAX_EVENTS 1024 33 #define MAX_EVENTS 1024
33 34
@@ -36,6 +37,7 @@ MainApp *MainApp::instance = nullptr; @@ -36,6 +37,7 @@ MainApp *MainApp::instance = nullptr;
36 void do_thread_work(ThreadData *threadData) 37 void do_thread_work(ThreadData *threadData)
37 { 38 {
38 int epoll_fd = threadData->epollfd; 39 int epoll_fd = threadData->epollfd;
  40 + ThreadAuth::assign(&threadData->authentication);
39 41
40 struct epoll_event events[MAX_EVENTS]; 42 struct epoll_event events[MAX_EVENTS];
41 memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); 43 memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
mqttpacket.cpp
@@ -22,6 +22,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -22,6 +22,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
22 #include <cassert> 22 #include <cassert>
23 23
24 #include "utils.h" 24 #include "utils.h"
  25 +#include "threadauth.h"
25 26
26 // We can void constant reallocation of space for parsed subtopics by using this. But, beware to only use it during handling of the current 27 // We can void constant reallocation of space for parsed subtopics by using this. But, beware to only use it during handling of the current
27 // packet. Don't access it for a stored packet, because then it will have changed. 28 // packet. Don't access it for a stored packet, because then it will have changed.
@@ -356,6 +357,8 @@ void MqttPacket::handleConnect() @@ -356,6 +357,8 @@ void MqttPacket::handleConnect()
356 bool accessGranted = false; 357 bool accessGranted = false;
357 std::string denyLogMsg; 358 std::string denyLogMsg;
358 359
  360 + Authentication &authentication = *ThreadAuth::getAuth();
  361 +
359 if (!user_name_flag && settings.allowAnonymous) 362 if (!user_name_flag && settings.allowAnonymous)
360 { 363 {
361 accessGranted = true; 364 accessGranted = true;
@@ -366,7 +369,7 @@ void MqttPacket::handleConnect() @@ -366,7 +369,7 @@ void MqttPacket::handleConnect()
366 sender->setDisconnectReason("Invalid username character"); 369 sender->setDisconnectReason("Invalid username character");
367 accessGranted = false; 370 accessGranted = false;
368 } 371 }
369 - else if (sender->getThreadData()->authentication.unPwdCheck(username, password) == AuthResult::success) 372 + else if (authentication.unPwdCheck(username, password) == AuthResult::success)
370 { 373 {
371 accessGranted = true; 374 accessGranted = true;
372 } 375 }
@@ -421,6 +424,8 @@ void MqttPacket::handleSubscribe() @@ -421,6 +424,8 @@ void MqttPacket::handleSubscribe()
421 424
422 uint16_t packet_id = readTwoBytesToUInt16(); 425 uint16_t packet_id = readTwoBytesToUInt16();
423 426
  427 + Authentication &authentication = *ThreadAuth::getAuth();
  428 +
424 std::list<char> subs_reponse_codes; 429 std::list<char> subs_reponse_codes;
425 while (remainingAfterPos() > 0) 430 while (remainingAfterPos() > 0)
426 { 431 {
@@ -439,7 +444,7 @@ void MqttPacket::handleSubscribe() @@ -439,7 +444,7 @@ void MqttPacket::handleSubscribe()
439 throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0."); 444 throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0.");
440 445
441 splitTopic(topic, *subtopics); 446 splitTopic(topic, *subtopics);
442 - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success) 447 + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success)
443 { 448 {
444 logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); 449 logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str());
445 sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos); 450 sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos);
@@ -555,7 +560,8 @@ void MqttPacket::handlePublish() @@ -555,7 +560,8 @@ void MqttPacket::handlePublish()
555 payloadLen = remainingAfterPos(); 560 payloadLen = remainingAfterPos();
556 payloadStart = pos; 561 payloadStart = pos;
557 562
558 - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) 563 + Authentication &authentication = *ThreadAuth::getAuth();
  564 + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success)
559 { 565 {
560 if (retain) 566 if (retain)
561 { 567 {
session.cpp
@@ -19,6 +19,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -19,6 +19,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
19 19
20 #include "session.h" 20 #include "session.h"
21 #include "client.h" 21 #include "client.h"
  22 +#include "threadauth.h"
22 23
23 std::chrono::time_point<std::chrono::steady_clock> appStartTime = std::chrono::steady_clock::now(); 24 std::chrono::time_point<std::chrono::steady_clock> appStartTime = std::chrono::steady_clock::now();
24 25
@@ -118,8 +119,9 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u @@ -118,8 +119,9 @@ void Session::writePacket(const MqttPacket &amp;packet, char max_qos, bool retain, u
118 assert(max_qos <= 2); 119 assert(max_qos <= 2);
119 const char qos = std::min<char>(packet.getQos(), max_qos); 120 const char qos = std::min<char>(packet.getQos(), max_qos);
120 121
121 - assert(packet.getSender());  
122 - Authentication &auth = packet.getSender()->getThreadData()->authentication; 122 + Authentication *_auth = ThreadAuth::getAuth();
  123 + assert(_auth);
  124 + Authentication &auth = *_auth;
123 if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) 125 if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success)
124 { 126 {
125 if (qos == 0) 127 if (qos == 0)
threadauth.cpp 0 → 100644
  1 +#include "threadauth.h"
  2 +
  3 +thread_local Authentication *ThreadAuth::auth = nullptr;
  4 +
  5 +void ThreadAuth::assign(Authentication *auth)
  6 +{
  7 + ThreadAuth::auth = auth;
  8 +}
  9 +
  10 +Authentication *ThreadAuth::getAuth()
  11 +{
  12 + return auth;
  13 +}
threadauth.h 0 → 100644
  1 +#ifndef THREADAUTH_H
  2 +#define THREADAUTH_H
  3 +
  4 +class Authentication;
  5 +
  6 +class ThreadAuth
  7 +{
  8 + static thread_local Authentication *auth;
  9 +public:
  10 + static void assign(Authentication *auth);
  11 + static Authentication *getAuth();
  12 +};
  13 +
  14 +#endif // THREADAUTH_H