Commit c20ba7c5ccc68ed325d46f7c1d5ec751ea329c06
1 parent
4bfa5aa5
Use thread local pointer to authentication object
There were bugs in which authentication object was used when, causing threadings bugs. Instead of getting from the 'sender', we can just store a thread local pointer.
Showing
8 changed files
with
51 additions
and
5 deletions
CMakeLists.txt
| ... | ... | @@ -54,6 +54,7 @@ add_executable(FlashMQ |
| 54 | 54 | persistencefile.h |
| 55 | 55 | sessionsandsubscriptionsdb.h |
| 56 | 56 | qospacketqueue.h |
| 57 | + threadauth.h | |
| 57 | 58 | |
| 58 | 59 | mainapp.cpp |
| 59 | 60 | main.cpp |
| ... | ... | @@ -89,6 +90,7 @@ add_executable(FlashMQ |
| 89 | 90 | persistencefile.cpp |
| 90 | 91 | sessionsandsubscriptionsdb.cpp |
| 91 | 92 | qospacketqueue.cpp |
| 93 | + threadauth.cpp | |
| 92 | 94 | |
| 93 | 95 | ) |
| 94 | 96 | ... | ... |
FlashMQTests/FlashMQTests.pro
| ... | ... | @@ -46,6 +46,7 @@ SOURCES += tst_maintests.cpp \ |
| 46 | 46 | ../persistencefile.cpp \ |
| 47 | 47 | ../sessionsandsubscriptionsdb.cpp \ |
| 48 | 48 | ../qospacketqueue.cpp \ |
| 49 | + ../threadauth.cpp \ | |
| 49 | 50 | mainappthread.cpp \ |
| 50 | 51 | twoclienttestcontext.cpp |
| 51 | 52 | |
| ... | ... | @@ -85,6 +86,7 @@ HEADERS += \ |
| 85 | 86 | ../persistencefile.h \ |
| 86 | 87 | ../sessionsandsubscriptionsdb.h \ |
| 87 | 88 | ../qospacketqueue.h \ |
| 89 | + ../threadauth.h \ | |
| 88 | 90 | mainappthread.h \ |
| 89 | 91 | twoclienttestcontext.h |
| 90 | 92 | ... | ... |
FlashMQTests/tst_maintests.cpp
| ... | ... | @@ -33,6 +33,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 33 | 33 | #include "sessionsandsubscriptionsdb.h" |
| 34 | 34 | #include "session.h" |
| 35 | 35 | #include "threaddata.h" |
| 36 | +#include "threadauth.h" | |
| 36 | 37 | |
| 37 | 38 | // Dumb Qt version gives warnings when comparing uint with number literal. |
| 38 | 39 | template <typename T1, typename T2> |
| ... | ... | @@ -947,6 +948,10 @@ void MainTests::testSavingSessions() |
| 947 | 948 | std::shared_ptr<SubscriptionStore> store(new SubscriptionStore()); |
| 948 | 949 | std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings)); |
| 949 | 950 | |
| 951 | + // Kind of a hack... | |
| 952 | + Authentication auth(*settings.get()); | |
| 953 | + ThreadAuth::assign(&auth); | |
| 954 | + | |
| 950 | 955 | std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false)); |
| 951 | 956 | c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60, false); |
| 952 | 957 | store->registerClientAndKickExistingOne(c1); | ... | ... |
mainapp.cpp
| ... | ... | @@ -28,6 +28,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 28 | 28 | #include <openssl/err.h> |
| 29 | 29 | |
| 30 | 30 | #include "logger.h" |
| 31 | +#include "threadauth.h" | |
| 31 | 32 | |
| 32 | 33 | #define MAX_EVENTS 1024 |
| 33 | 34 | |
| ... | ... | @@ -36,6 +37,7 @@ MainApp *MainApp::instance = nullptr; |
| 36 | 37 | void do_thread_work(ThreadData *threadData) |
| 37 | 38 | { |
| 38 | 39 | int epoll_fd = threadData->epollfd; |
| 40 | + ThreadAuth::assign(&threadData->authentication); | |
| 39 | 41 | |
| 40 | 42 | struct epoll_event events[MAX_EVENTS]; |
| 41 | 43 | memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -22,6 +22,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 22 | 22 | #include <cassert> |
| 23 | 23 | |
| 24 | 24 | #include "utils.h" |
| 25 | +#include "threadauth.h" | |
| 25 | 26 | |
| 26 | 27 | // We can void constant reallocation of space for parsed subtopics by using this. But, beware to only use it during handling of the current |
| 27 | 28 | // packet. Don't access it for a stored packet, because then it will have changed. |
| ... | ... | @@ -356,6 +357,8 @@ void MqttPacket::handleConnect() |
| 356 | 357 | bool accessGranted = false; |
| 357 | 358 | std::string denyLogMsg; |
| 358 | 359 | |
| 360 | + Authentication &authentication = *ThreadAuth::getAuth(); | |
| 361 | + | |
| 359 | 362 | if (!user_name_flag && settings.allowAnonymous) |
| 360 | 363 | { |
| 361 | 364 | accessGranted = true; |
| ... | ... | @@ -366,7 +369,7 @@ void MqttPacket::handleConnect() |
| 366 | 369 | sender->setDisconnectReason("Invalid username character"); |
| 367 | 370 | accessGranted = false; |
| 368 | 371 | } |
| 369 | - else if (sender->getThreadData()->authentication.unPwdCheck(username, password) == AuthResult::success) | |
| 372 | + else if (authentication.unPwdCheck(username, password) == AuthResult::success) | |
| 370 | 373 | { |
| 371 | 374 | accessGranted = true; |
| 372 | 375 | } |
| ... | ... | @@ -421,6 +424,8 @@ void MqttPacket::handleSubscribe() |
| 421 | 424 | |
| 422 | 425 | uint16_t packet_id = readTwoBytesToUInt16(); |
| 423 | 426 | |
| 427 | + Authentication &authentication = *ThreadAuth::getAuth(); | |
| 428 | + | |
| 424 | 429 | std::list<char> subs_reponse_codes; |
| 425 | 430 | while (remainingAfterPos() > 0) |
| 426 | 431 | { |
| ... | ... | @@ -439,7 +444,7 @@ void MqttPacket::handleSubscribe() |
| 439 | 444 | throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0."); |
| 440 | 445 | |
| 441 | 446 | splitTopic(topic, *subtopics); |
| 442 | - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success) | |
| 447 | + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::subscribe, qos, false) == AuthResult::success) | |
| 443 | 448 | { |
| 444 | 449 | logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); |
| 445 | 450 | sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, *subtopics, qos); |
| ... | ... | @@ -555,7 +560,8 @@ void MqttPacket::handlePublish() |
| 555 | 560 | payloadLen = remainingAfterPos(); |
| 556 | 561 | payloadStart = pos; |
| 557 | 562 | |
| 558 | - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) | |
| 563 | + Authentication &authentication = *ThreadAuth::getAuth(); | |
| 564 | + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write, qos, retain) == AuthResult::success) | |
| 559 | 565 | { |
| 560 | 566 | if (retain) |
| 561 | 567 | { | ... | ... |
session.cpp
| ... | ... | @@ -19,6 +19,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 19 | 19 | |
| 20 | 20 | #include "session.h" |
| 21 | 21 | #include "client.h" |
| 22 | +#include "threadauth.h" | |
| 22 | 23 | |
| 23 | 24 | std::chrono::time_point<std::chrono::steady_clock> appStartTime = std::chrono::steady_clock::now(); |
| 24 | 25 | |
| ... | ... | @@ -118,8 +119,9 @@ void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, u |
| 118 | 119 | assert(max_qos <= 2); |
| 119 | 120 | const char qos = std::min<char>(packet.getQos(), max_qos); |
| 120 | 121 | |
| 121 | - assert(packet.getSender()); | |
| 122 | - Authentication &auth = packet.getSender()->getThreadData()->authentication; | |
| 122 | + Authentication *_auth = ThreadAuth::getAuth(); | |
| 123 | + assert(_auth); | |
| 124 | + Authentication &auth = *_auth; | |
| 123 | 125 | if (auth.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) |
| 124 | 126 | { |
| 125 | 127 | if (qos == 0) | ... | ... |
threadauth.cpp
0 → 100644
threadauth.h
0 → 100644