From c20ba7c5ccc68ed325d46f7c1d5ec751ea329c06 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sat, 24 Jul 2021 14:16:08 +0200 Subject: [PATCH] Use thread local pointer to authentication object --- CMakeLists.txt | 2 ++ FlashMQTests/FlashMQTests.pro | 2 ++ FlashMQTests/tst_maintests.cpp | 5 +++++ mainapp.cpp | 2 ++ mqttpacket.cpp | 12 +++++++++--- session.cpp | 6 ++++-- threadauth.cpp | 13 +++++++++++++ threadauth.h | 14 ++++++++++++++ 8 files changed, 51 insertions(+), 5 deletions(-) create mode 100644 threadauth.cpp create mode 100644 threadauth.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 3657202..e872376 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,6 +54,7 @@ add_executable(FlashMQ persistencefile.h sessionsandsubscriptionsdb.h qospacketqueue.h + threadauth.h mainapp.cpp main.cpp @@ -89,6 +90,7 @@ add_executable(FlashMQ persistencefile.cpp sessionsandsubscriptionsdb.cpp qospacketqueue.cpp + threadauth.cpp ) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index 450323e..09e615f 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -46,6 +46,7 @@ SOURCES += tst_maintests.cpp \ ../persistencefile.cpp \ ../sessionsandsubscriptionsdb.cpp \ ../qospacketqueue.cpp \ + ../threadauth.cpp \ mainappthread.cpp \ twoclienttestcontext.cpp @@ -85,6 +86,7 @@ HEADERS += \ ../persistencefile.h \ ../sessionsandsubscriptionsdb.h \ ../qospacketqueue.h \ + ../threadauth.h \ mainappthread.h \ twoclienttestcontext.h diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index ef1866a..db01518 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -33,6 +33,7 @@ License along with FlashMQ. If not, see . #include "sessionsandsubscriptionsdb.h" #include "session.h" #include "threaddata.h" +#include "threadauth.h" // Dumb Qt version gives warnings when comparing uint with number literal. template @@ -947,6 +948,10 @@ void MainTests::testSavingSessions() std::shared_ptr store(new SubscriptionStore()); std::shared_ptr t(new ThreadData(0, store, settings)); + // Kind of a hack... + Authentication auth(*settings.get()); + ThreadAuth::assign(&auth); + std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings, false)); c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); store->registerClientAndKickExistingOne(c1); diff --git a/mainapp.cpp b/mainapp.cpp index de4f00e..be14250 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -28,6 +28,7 @@ License along with FlashMQ. If not, see . #include #include "logger.h" +#include "threadauth.h" #define MAX_EVENTS 1024 @@ -36,6 +37,7 @@ MainApp *MainApp::instance = nullptr; void do_thread_work(ThreadData *threadData) { int epoll_fd = threadData->epollfd; + ThreadAuth::assign(&threadData->authentication); struct epoll_event events[MAX_EVENTS]; memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index e69c310..ac9cce5 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -22,6 +22,7 @@ License along with FlashMQ. If not, see . #include #include "utils.h" +#include "threadauth.h" // We can void constant reallocation of space for parsed subtopics by using this. But, beware to only use it during handling of the current // packet. Don't access it for a stored packet, because then it will have changed. @@ -356,6 +357,8 @@ void MqttPacket::handleConnect() bool accessGranted = false; std::string denyLogMsg; + Authentication &authentication = *ThreadAuth::getAuth(); + if (!user_name_flag && settings.allowAnonymous) { accessGranted = true; @@ -366,7 +369,7 @@ void MqttPacket::handleConnect() sender->setDisconnectReason("Invalid username character"); accessGranted = false; } - else if (sender->getThreadData()->authentication.unPwdCheck(username, password) == AuthResult::success) + else if (authentication.unPwdCheck(username, password) == AuthResult::success) { accessGranted = true; } @@ -421,6 +424,8 @@ void MqttPacket::handleSubscribe() uint16_t packet_id = readTwoBytesToUInt16(); + Authentication &authentication = *ThreadAuth::getAuth(); + std::list subs_reponse_codes; while (remainingAfterPos() > 0) { @@ -439,7 +444,7 @@ void MqttPacket::handleSubscribe() throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0."); splitTopic(topic, *subtopics); - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success) + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success) { logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos); @@ -555,7 +560,8 @@ void MqttPacket::handlePublish() payloadLen = remainingAfterPos(); payloadStart = pos; - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) + Authentication &authentication = *ThreadAuth::getAuth(); + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) { if (retain) { diff --git a/session.cpp b/session.cpp index e0fc9fe..9b3a475 100644 --- a/session.cpp +++ b/session.cpp @@ -19,6 +19,7 @@ License along with FlashMQ. If not, see . #include "session.h" #include "client.h" +#include "threadauth.h" std::chrono::time_point appStartTime = std::chrono::steady_clock::now(); @@ -118,8 +119,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u assert(max_qos <= 2); const char qos = std::min(packet.getQos(), max_qos); - assert(packet.getSender()); - Authentication &auth = packet.getSender()->getThreadData()->authentication; + Authentication *_auth = ThreadAuth::getAuth(); + assert(_auth); + Authentication &auth = *_auth; if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) { if (qos == 0) diff --git a/threadauth.cpp b/threadauth.cpp new file mode 100644 index 0000000..6e9d54e --- /dev/null +++ b/threadauth.cpp @@ -0,0 +1,13 @@ +#include "threadauth.h" + +thread_local Authentication *ThreadAuth::auth = nullptr; + +void ThreadAuth::assign(Authentication *auth) +{ + ThreadAuth::auth = auth; +} + +Authentication *ThreadAuth::getAuth() +{ + return auth; +} diff --git a/threadauth.h b/threadauth.h new file mode 100644 index 0000000..165cf4b --- /dev/null +++ b/threadauth.h @@ -0,0 +1,14 @@ +#ifndef THREADAUTH_H +#define THREADAUTH_H + +class Authentication; + +class ThreadAuth +{ + static thread_local Authentication *auth; +public: + static void assign(Authentication *auth); + static Authentication *getAuth(); +}; + +#endif // THREADAUTH_H -- libgit2 0.21.4