diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3657202..e872376 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -54,6 +54,7 @@ add_executable(FlashMQ
persistencefile.h
sessionsandsubscriptionsdb.h
qospacketqueue.h
+ threadauth.h
mainapp.cpp
main.cpp
@@ -89,6 +90,7 @@ add_executable(FlashMQ
persistencefile.cpp
sessionsandsubscriptionsdb.cpp
qospacketqueue.cpp
+ threadauth.cpp
)
diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro
index 450323e..09e615f 100644
--- a/FlashMQTests/FlashMQTests.pro
+++ b/FlashMQTests/FlashMQTests.pro
@@ -46,6 +46,7 @@ SOURCES += tst_maintests.cpp \
../persistencefile.cpp \
../sessionsandsubscriptionsdb.cpp \
../qospacketqueue.cpp \
+ ../threadauth.cpp \
mainappthread.cpp \
twoclienttestcontext.cpp
@@ -85,6 +86,7 @@ HEADERS += \
../persistencefile.h \
../sessionsandsubscriptionsdb.h \
../qospacketqueue.h \
+ ../threadauth.h \
mainappthread.h \
twoclienttestcontext.h
diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp
index ef1866a..db01518 100644
--- a/FlashMQTests/tst_maintests.cpp
+++ b/FlashMQTests/tst_maintests.cpp
@@ -33,6 +33,7 @@ License along with FlashMQ. If not, see .
#include "sessionsandsubscriptionsdb.h"
#include "session.h"
#include "threaddata.h"
+#include "threadauth.h"
// Dumb Qt version gives warnings when comparing uint with number literal.
template
@@ -947,6 +948,10 @@ void MainTests::testSavingSessions()
std::shared_ptr store(new SubscriptionStore());
std::shared_ptr t(new ThreadData(0, store, settings));
+ // Kind of a hack...
+ Authentication auth(*settings.get());
+ ThreadAuth::assign(&auth);
+
std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings, false));
c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false);
store->registerClientAndKickExistingOne(c1);
diff --git a/mainapp.cpp b/mainapp.cpp
index de4f00e..be14250 100644
--- a/mainapp.cpp
+++ b/mainapp.cpp
@@ -28,6 +28,7 @@ License along with FlashMQ. If not, see .
#include
#include "logger.h"
+#include "threadauth.h"
#define MAX_EVENTS 1024
@@ -36,6 +37,7 @@ MainApp *MainApp::instance = nullptr;
void do_thread_work(ThreadData *threadData)
{
int epoll_fd = threadData->epollfd;
+ ThreadAuth::assign(&threadData->authentication);
struct epoll_event events[MAX_EVENTS];
memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
diff --git a/mqttpacket.cpp b/mqttpacket.cpp
index e69c310..ac9cce5 100644
--- a/mqttpacket.cpp
+++ b/mqttpacket.cpp
@@ -22,6 +22,7 @@ License along with FlashMQ. If not, see .
#include
#include "utils.h"
+#include "threadauth.h"
// We can void constant reallocation of space for parsed subtopics by using this. But, beware to only use it during handling of the current
// packet. Don't access it for a stored packet, because then it will have changed.
@@ -356,6 +357,8 @@ void MqttPacket::handleConnect()
bool accessGranted = false;
std::string denyLogMsg;
+ Authentication &authentication = *ThreadAuth::getAuth();
+
if (!user_name_flag && settings.allowAnonymous)
{
accessGranted = true;
@@ -366,7 +369,7 @@ void MqttPacket::handleConnect()
sender->setDisconnectReason("Invalid username character");
accessGranted = false;
}
- else if (sender->getThreadData()->authentication.unPwdCheck(username, password) == AuthResult::success)
+ else if (authentication.unPwdCheck(username, password) == AuthResult::success)
{
accessGranted = true;
}
@@ -421,6 +424,8 @@ void MqttPacket::handleSubscribe()
uint16_t packet_id = readTwoBytesToUInt16();
+ Authentication &authentication = *ThreadAuth::getAuth();
+
std::list subs_reponse_codes;
while (remainingAfterPos() > 0)
{
@@ -439,7 +444,7 @@ void MqttPacket::handleSubscribe()
throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0.");
splitTopic(topic, *subtopics);
- if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success)
+ if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success)
{
logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str());
sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos);
@@ -555,7 +560,8 @@ void MqttPacket::handlePublish()
payloadLen = remainingAfterPos();
payloadStart = pos;
- if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success)
+ Authentication &authentication = *ThreadAuth::getAuth();
+ if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success)
{
if (retain)
{
diff --git a/session.cpp b/session.cpp
index e0fc9fe..9b3a475 100644
--- a/session.cpp
+++ b/session.cpp
@@ -19,6 +19,7 @@ License along with FlashMQ. If not, see .
#include "session.h"
#include "client.h"
+#include "threadauth.h"
std::chrono::time_point appStartTime = std::chrono::steady_clock::now();
@@ -118,8 +119,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u
assert(max_qos <= 2);
const char qos = std::min(packet.getQos(), max_qos);
- assert(packet.getSender());
- Authentication &auth = packet.getSender()->getThreadData()->authentication;
+ Authentication *_auth = ThreadAuth::getAuth();
+ assert(_auth);
+ Authentication &auth = *_auth;
if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success)
{
if (qos == 0)
diff --git a/threadauth.cpp b/threadauth.cpp
new file mode 100644
index 0000000..6e9d54e
--- /dev/null
+++ b/threadauth.cpp
@@ -0,0 +1,13 @@
+#include "threadauth.h"
+
+thread_local Authentication *ThreadAuth::auth = nullptr;
+
+void ThreadAuth::assign(Authentication *auth)
+{
+ ThreadAuth::auth = auth;
+}
+
+Authentication *ThreadAuth::getAuth()
+{
+ return auth;
+}
diff --git a/threadauth.h b/threadauth.h
new file mode 100644
index 0000000..165cf4b
--- /dev/null
+++ b/threadauth.h
@@ -0,0 +1,14 @@
+#ifndef THREADAUTH_H
+#define THREADAUTH_H
+
+class Authentication;
+
+class ThreadAuth
+{
+ static thread_local Authentication *auth;
+public:
+ static void assign(Authentication *auth);
+ static Authentication *getAuth();
+};
+
+#endif // THREADAUTH_H