diff --git a/CMakeLists.txt b/CMakeLists.txt index b2553f0..284250b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,7 +56,7 @@ add_executable(FlashMQ persistencefile.h sessionsandsubscriptionsdb.h qospacketqueue.h - threadauth.h + threadglobals.h threadloop.h mainapp.cpp @@ -93,7 +93,7 @@ add_executable(FlashMQ persistencefile.cpp sessionsandsubscriptionsdb.cpp qospacketqueue.cpp - threadauth.cpp + threadglobals.cpp threadloop.cpp ) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index c15cea7..8fca6f2 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -47,7 +47,7 @@ SOURCES += tst_maintests.cpp \ ../persistencefile.cpp \ ../sessionsandsubscriptionsdb.cpp \ ../qospacketqueue.cpp \ - ../threadauth.cpp \ + ../threadglobals.cpp \ ../threadloop.cpp \ mainappthread.cpp \ twoclienttestcontext.cpp @@ -88,7 +88,7 @@ HEADERS += \ ../persistencefile.h \ ../sessionsandsubscriptionsdb.h \ ../qospacketqueue.h \ - ../threadauth.h \ + ../threadglobals.h \ ../threadloop.h \ mainappthread.h \ twoclienttestcontext.h diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 4d7cfd1..f17085b 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -33,7 +33,7 @@ License along with FlashMQ. If not, see . #include "sessionsandsubscriptionsdb.h" #include "session.h" #include "threaddata.h" -#include "threadauth.h" +#include "threadglobals.h" // Dumb Qt version gives warnings when comparing uint with number literal. template @@ -996,7 +996,7 @@ void MainTests::testSavingSessions() // Kind of a hack... Authentication auth(*settings.get()); - ThreadAuth::assign(&auth); + ThreadGlobals::assign(&auth); std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings, false)); c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); @@ -1115,7 +1115,7 @@ void testCopyPacketHelper(const std::string &topic, char from_qos, char to_qos, // Kind of a hack... Authentication auth(*settings.get()); - ThreadAuth::assign(&auth); + ThreadGlobals::assign(&auth); std::shared_ptr dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false)); dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60, false); diff --git a/mainapp.cpp b/mainapp.cpp index 53cccc9..40625be 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -28,7 +28,7 @@ License along with FlashMQ. If not, see . #include #include "logger.h" -#include "threadauth.h" +#include "threadglobals.h" #include "threadloop.h" MainApp *MainApp::instance = nullptr; @@ -645,7 +645,7 @@ void MainApp::loadConfig() confFileParser->loadFile(false); settings = confFileParser->moveSettings(); settingsLocalCopy = *settings.get(); - ThreadAuth::assignSettings(&settingsLocalCopy); + ThreadGlobals::assignSettings(&settingsLocalCopy); if (settings->listeners.empty()) { diff --git a/mqttpacket.cpp b/mqttpacket.cpp index db95a68..14be782 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -22,7 +22,7 @@ License along with FlashMQ. If not, see . #include #include "utils.h" -#include "threadauth.h" +#include "threadglobals.h" RemainingLength::RemainingLength() { @@ -460,7 +460,7 @@ void MqttPacket::handleConnect() bool accessGranted = false; std::string denyLogMsg; - Authentication &authentication = *ThreadAuth::getAuth(); + Authentication &authentication = *ThreadGlobals::getAuth(); if (!user_name_flag && settings.allowAnonymous) { @@ -531,7 +531,7 @@ void MqttPacket::handleSubscribe() throw ProtocolError("Packet ID 0 when subscribing is invalid."); // [MQTT-2.3.1-1] } - Authentication &authentication = *ThreadAuth::getAuth(); + Authentication &authentication = *ThreadGlobals::getAuth(); std::list subs_reponse_codes; while (remainingAfterPos() > 0) @@ -691,7 +691,7 @@ void MqttPacket::handlePublish() payloadLen = remainingAfterPos(); payloadStart = pos; - Authentication &authentication = *ThreadAuth::getAuth(); + Authentication &authentication = *ThreadGlobals::getAuth(); if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, subtopics, AclAccess::write, qos, retain) == AuthResult::success) { if (retain) diff --git a/session.cpp b/session.cpp index b441940..91887e1 100644 --- a/session.cpp +++ b/session.cpp @@ -19,13 +19,13 @@ License along with FlashMQ. If not, see . #include "session.h" #include "client.h" -#include "threadauth.h" +#include "threadglobals.h" std::chrono::time_point appStartTime = std::chrono::steady_clock::now(); Session::Session() : - maxQosMsgPending(ThreadAuth::getSettings()->maxQosMsgPendingPerClient), - maxQosBytesPending(ThreadAuth::getSettings()->maxQosBytesPendingPerClient) + maxQosMsgPending(ThreadGlobals::getSettings()->maxQosMsgPendingPerClient), + maxQosBytesPending(ThreadGlobals::getSettings()->maxQosBytesPendingPerClient) { } @@ -152,7 +152,7 @@ void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr(packet.getQos(), max_qos); - Authentication *_auth = ThreadAuth::getAuth(); + Authentication *_auth = ThreadGlobals::getAuth(); assert(_auth); Authentication &auth = *_auth; if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success) diff --git a/threadauth.cpp b/threadauth.cpp deleted file mode 100644 index 44bda2b..0000000 --- a/threadauth.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "threadauth.h" - -thread_local Authentication *ThreadAuth::auth = nullptr; -thread_local ThreadData *ThreadAuth::threadData = nullptr; -thread_local Settings *ThreadAuth::settings = nullptr; - -void ThreadAuth::assign(Authentication *auth) -{ - ThreadAuth::auth = auth; -} - -Authentication *ThreadAuth::getAuth() -{ - return auth; -} - -void ThreadAuth::assignThreadData(ThreadData *threadData) -{ - ThreadAuth::threadData = threadData; -} - -ThreadData *ThreadAuth::getThreadData() -{ - return threadData; -} - -void ThreadAuth::assignSettings(Settings *settings) -{ - ThreadAuth::settings = settings; -} - -Settings *ThreadAuth::getSettings() -{ - return settings; -} diff --git a/threadglobals.cpp b/threadglobals.cpp new file mode 100644 index 0000000..9813b83 --- /dev/null +++ b/threadglobals.cpp @@ -0,0 +1,45 @@ +#include "threadglobals.h" +#include "cassert" + +thread_local Authentication *ThreadGlobals::auth = nullptr; +thread_local ThreadData *ThreadGlobals::threadData = nullptr; +thread_local Settings *ThreadGlobals::settings = nullptr; + +void ThreadGlobals::assign(Authentication *auth) +{ +#ifndef TESTING + assert(ThreadGlobals::auth == nullptr); +#endif + ThreadGlobals::auth = auth; +} + +Authentication *ThreadGlobals::getAuth() +{ + return auth; +} + +void ThreadGlobals::assignThreadData(ThreadData *threadData) +{ +#ifndef TESTING + assert(ThreadGlobals::threadData == nullptr); +#endif + ThreadGlobals::threadData = threadData; +} + +ThreadData *ThreadGlobals::getThreadData() +{ + return threadData; +} + +void ThreadGlobals::assignSettings(Settings *settings) +{ +#ifndef TESTING + assert(ThreadGlobals::settings == nullptr || ThreadGlobals::settings == settings); +#endif + ThreadGlobals::settings = settings; +} + +Settings *ThreadGlobals::getSettings() +{ + return settings; +} diff --git a/threadauth.h b/threadglobals.h index 6f33444..a71ad42 100644 --- a/threadauth.h +++ b/threadglobals.h @@ -1,12 +1,11 @@ -#ifndef THREADAUTH_H -#define THREADAUTH_H +#ifndef THREADGLOBALS_H +#define THREADGLOBALS_H #include "forward_declarations.h" class Authentication; -// TODO: rename, this is no longer just auth, but thread local globals. -class ThreadAuth +class ThreadGlobals { static thread_local Authentication *auth; static thread_local ThreadData *threadData; @@ -22,4 +21,4 @@ public: static Settings *getSettings(); }; -#endif // THREADAUTH_H +#endif // THREADGLOBALS_H diff --git a/threadloop.cpp b/threadloop.cpp index 39883f4..5143a58 100644 --- a/threadloop.cpp +++ b/threadloop.cpp @@ -20,9 +20,9 @@ License along with FlashMQ. If not, see . void do_thread_work(ThreadData *threadData) { int epoll_fd = threadData->epollfd; - ThreadAuth::assign(&threadData->authentication); - ThreadAuth::assignThreadData(threadData); - ThreadAuth::assignSettings(&threadData->settingsLocalCopy); + ThreadGlobals::assign(&threadData->authentication); + ThreadGlobals::assignThreadData(threadData); + ThreadGlobals::assignSettings(&threadData->settingsLocalCopy); struct epoll_event events[MAX_EVENTS]; memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); diff --git a/threadloop.h b/threadloop.h index 6877a5c..9270a05 100644 --- a/threadloop.h +++ b/threadloop.h @@ -19,7 +19,7 @@ License along with FlashMQ. If not, see . #define THREADLOOP_H #include "threaddata.h" -#include "threadauth.h" +#include "threadglobals.h" #define MAX_EVENTS 65536