Commit c20ba7c5ccc68ed325d46f7c1d5ec751ea329c06
1 parent
4bfa5aa5
Use thread local pointer to authentication object
There were bugs in which authentication object was used when, causing threadings bugs. Instead of getting from the 'sender', we can just store a thread local pointer.
Showing
8 changed files
with
51 additions
and
5 deletions
CMakeLists.txt
| @@ -54,6 +54,7 @@ add_executable(FlashMQ | @@ -54,6 +54,7 @@ add_executable(FlashMQ | ||
| 54 | persistencefile.h | 54 | persistencefile.h |
| 55 | sessionsandsubscriptionsdb.h | 55 | sessionsandsubscriptionsdb.h |
| 56 | qospacketqueue.h | 56 | qospacketqueue.h |
| 57 | + threadauth.h | ||
| 57 | 58 | ||
| 58 | mainapp.cpp | 59 | mainapp.cpp |
| 59 | main.cpp | 60 | main.cpp |
| @@ -89,6 +90,7 @@ add_executable(FlashMQ | @@ -89,6 +90,7 @@ add_executable(FlashMQ | ||
| 89 | persistencefile.cpp | 90 | persistencefile.cpp |
| 90 | sessionsandsubscriptionsdb.cpp | 91 | sessionsandsubscriptionsdb.cpp |
| 91 | qospacketqueue.cpp | 92 | qospacketqueue.cpp |
| 93 | + threadauth.cpp | ||
| 92 | 94 | ||
| 93 | ) | 95 | ) |
| 94 | 96 |
FlashMQTests/FlashMQTests.pro
| @@ -46,6 +46,7 @@ SOURCES += tst_maintests.cpp \ | @@ -46,6 +46,7 @@ SOURCES += tst_maintests.cpp \ | ||
| 46 | ../persistencefile.cpp \ | 46 | ../persistencefile.cpp \ |
| 47 | ../sessionsandsubscriptionsdb.cpp \ | 47 | ../sessionsandsubscriptionsdb.cpp \ |
| 48 | ../qospacketqueue.cpp \ | 48 | ../qospacketqueue.cpp \ |
| 49 | + ../threadauth.cpp \ | ||
| 49 | mainappthread.cpp \ | 50 | mainappthread.cpp \ |
| 50 | twoclienttestcontext.cpp | 51 | twoclienttestcontext.cpp |
| 51 | 52 | ||
| @@ -85,6 +86,7 @@ HEADERS += \ | @@ -85,6 +86,7 @@ HEADERS += \ | ||
| 85 | ../persistencefile.h \ | 86 | ../persistencefile.h \ |
| 86 | ../sessionsandsubscriptionsdb.h \ | 87 | ../sessionsandsubscriptionsdb.h \ |
| 87 | ../qospacketqueue.h \ | 88 | ../qospacketqueue.h \ |
| 89 | + ../threadauth.h \ | ||
| 88 | mainappthread.h \ | 90 | mainappthread.h \ |
| 89 | twoclienttestcontext.h | 91 | twoclienttestcontext.h |
| 90 | 92 |
FlashMQTests/tst_maintests.cpp
| @@ -33,6 +33,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | @@ -33,6 +33,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | ||
| 33 | #include "sessionsandsubscriptionsdb.h" | 33 | #include "sessionsandsubscriptionsdb.h" |
| 34 | #include "session.h" | 34 | #include "session.h" |
| 35 | #include "threaddata.h" | 35 | #include "threaddata.h" |
| 36 | +#include "threadauth.h" | ||
| 36 | 37 | ||
| 37 | // Dumb Qt version gives warnings when comparing uint with number literal. | 38 | // Dumb Qt version gives warnings when comparing uint with number literal. |
| 38 | template <typename T1, typename T2> | 39 | template <typename T1, typename T2> |
| @@ -947,6 +948,10 @@ void MainTests::testSavingSessions() | @@ -947,6 +948,10 @@ void MainTests::testSavingSessions() | ||
| 947 | std::shared_ptr<SubscriptionStore> store(new SubscriptionStore()); | 948 | std::shared_ptr<SubscriptionStore> store(new SubscriptionStore()); |
| 948 | std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings)); | 949 | std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings)); |
| 949 | 950 | ||
| 951 | + // Kind of a hack... | ||
| 952 | + Authentication auth(*settings.get()); | ||
| 953 | + ThreadAuth::assign(&auth); | ||
| 954 | + | ||
| 950 | std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false)); | 955 | std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false)); |
| 951 | c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); | 956 | c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); |
| 952 | store->registerClientAndKickExistingOne(c1); | 957 | store->registerClientAndKickExistingOne(c1); |
mainapp.cpp
| @@ -28,6 +28,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | @@ -28,6 +28,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | ||
| 28 | #include <openssl/err.h> | 28 | #include <openssl/err.h> |
| 29 | 29 | ||
| 30 | #include "logger.h" | 30 | #include "logger.h" |
| 31 | +#include "threadauth.h" | ||
| 31 | 32 | ||
| 32 | #define MAX_EVENTS 1024 | 33 | #define MAX_EVENTS 1024 |
| 33 | 34 | ||
| @@ -36,6 +37,7 @@ MainApp *MainApp::instance = nullptr; | @@ -36,6 +37,7 @@ MainApp *MainApp::instance = nullptr; | ||
| 36 | void do_thread_work(ThreadData *threadData) | 37 | void do_thread_work(ThreadData *threadData) |
| 37 | { | 38 | { |
| 38 | int epoll_fd = threadData->epollfd; | 39 | int epoll_fd = threadData->epollfd; |
| 40 | + ThreadAuth::assign(&threadData->authentication); | ||
| 39 | 41 | ||
| 40 | struct epoll_event events[MAX_EVENTS]; | 42 | struct epoll_event events[MAX_EVENTS]; |
| 41 | memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); | 43 | memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); |
mqttpacket.cpp
| @@ -22,6 +22,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | @@ -22,6 +22,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | ||
| 22 | #include <cassert> | 22 | #include <cassert> |
| 23 | 23 | ||
| 24 | #include "utils.h" | 24 | #include "utils.h" |
| 25 | +#include "threadauth.h" | ||
| 25 | 26 | ||
| 26 | // We can void constant reallocation of space for parsed subtopics by using this. But, beware to only use it during handling of the current | 27 | // We can void constant reallocation of space for parsed subtopics by using this. But, beware to only use it during handling of the current |
| 27 | // packet. Don't access it for a stored packet, because then it will have changed. | 28 | // packet. Don't access it for a stored packet, because then it will have changed. |
| @@ -356,6 +357,8 @@ void MqttPacket::handleConnect() | @@ -356,6 +357,8 @@ void MqttPacket::handleConnect() | ||
| 356 | bool accessGranted = false; | 357 | bool accessGranted = false; |
| 357 | std::string denyLogMsg; | 358 | std::string denyLogMsg; |
| 358 | 359 | ||
| 360 | + Authentication &authentication = *ThreadAuth::getAuth(); | ||
| 361 | + | ||
| 359 | if (!user_name_flag && settings.allowAnonymous) | 362 | if (!user_name_flag && settings.allowAnonymous) |
| 360 | { | 363 | { |
| 361 | accessGranted = true; | 364 | accessGranted = true; |
| @@ -366,7 +369,7 @@ void MqttPacket::handleConnect() | @@ -366,7 +369,7 @@ void MqttPacket::handleConnect() | ||
| 366 | sender->setDisconnectReason("Invalid username character"); | 369 | sender->setDisconnectReason("Invalid username character"); |
| 367 | accessGranted = false; | 370 | accessGranted = false; |
| 368 | } | 371 | } |
| 369 | - else if (sender->getThreadData()->authentication.unPwdCheck(username, password) == AuthResult::success) | 372 | + else if (authentication.unPwdCheck(username, password) == AuthResult::success) |
| 370 | { | 373 | { |
| 371 | accessGranted = true; | 374 | accessGranted = true; |
| 372 | } | 375 | } |
| @@ -421,6 +424,8 @@ void MqttPacket::handleSubscribe() | @@ -421,6 +424,8 @@ void MqttPacket::handleSubscribe() | ||
| 421 | 424 | ||
| 422 | uint16_t packet_id = readTwoBytesToUInt16(); | 425 | uint16_t packet_id = readTwoBytesToUInt16(); |
| 423 | 426 | ||
| 427 | + Authentication &authentication = *ThreadAuth::getAuth(); | ||
| 428 | + | ||
| 424 | std::list<char> subs_reponse_codes; | 429 | std::list<char> subs_reponse_codes; |
| 425 | while (remainingAfterPos() > 0) | 430 | while (remainingAfterPos() > 0) |
| 426 | { | 431 | { |
| @@ -439,7 +444,7 @@ void MqttPacket::handleSubscribe() | @@ -439,7 +444,7 @@ void MqttPacket::handleSubscribe() | ||
| 439 | throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0."); | 444 | throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0."); |
| 440 | 445 | ||
| 441 | splitTopic(topic, *subtopics); | 446 | splitTopic(topic, *subtopics); |
| 442 | - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success) | 447 | + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success) |
| 443 | { | 448 | { |
| 444 | logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); | 449 | logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); |
| 445 | sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos); | 450 | sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos); |
| @@ -555,7 +560,8 @@ void MqttPacket::handlePublish() | @@ -555,7 +560,8 @@ void MqttPacket::handlePublish() | ||
| 555 | payloadLen = remainingAfterPos(); | 560 | payloadLen = remainingAfterPos(); |
| 556 | payloadStart = pos; | 561 | payloadStart = pos; |
| 557 | 562 | ||
| 558 | - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) | 563 | + Authentication &authentication = *ThreadAuth::getAuth(); |
| 564 | + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) | ||
| 559 | { | 565 | { |
| 560 | if (retain) | 566 | if (retain) |
| 561 | { | 567 | { |
session.cpp
| @@ -19,6 +19,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | @@ -19,6 +19,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | ||
| 19 | 19 | ||
| 20 | #include "session.h" | 20 | #include "session.h" |
| 21 | #include "client.h" | 21 | #include "client.h" |
| 22 | +#include "threadauth.h" | ||
| 22 | 23 | ||
| 23 | std::chrono::time_point<std::chrono::steady_clock> appStartTime = std::chrono::steady_clock::now(); | 24 | std::chrono::time_point<std::chrono::steady_clock> appStartTime = std::chrono::steady_clock::now(); |
| 24 | 25 | ||
| @@ -118,8 +119,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u | @@ -118,8 +119,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u | ||
| 118 | assert(max_qos <= 2); | 119 | assert(max_qos <= 2); |
| 119 | const char qos = std::min<char>(packet.getQos(), max_qos); | 120 | const char qos = std::min<char>(packet.getQos(), max_qos); |
| 120 | 121 | ||
| 121 | - assert(packet.getSender()); | ||
| 122 | - Authentication &auth = packet.getSender()->getThreadData()->authentication; | 122 | + Authentication *_auth = ThreadAuth::getAuth(); |
| 123 | + assert(_auth); | ||
| 124 | + Authentication &auth = *_auth; | ||
| 123 | if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) | 125 | if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) |
| 124 | { | 126 | { |
| 125 | if (qos == 0) | 127 | if (qos == 0) |
threadauth.cpp
0 → 100644
threadauth.h
0 → 100644