From 462fb2edc3c418fe9d6c0854f788575d9bf6cf74 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sat, 2 Jan 2021 16:59:00 +0100 Subject: [PATCH] Working auth with plugin --- authplugin.cpp | 2 +- logger.cpp | 13 +++++++++---- logger.h | 1 + main.cpp | 3 +++ mainapp.cpp | 28 +++++++++++++++++++++------- mqttpacket.cpp | 44 +++++++++++++++++++++++++++----------------- mqttpacket.h | 8 +++++--- threaddata.cpp | 11 +++++++---- threaddata.h | 3 ++- 9 files changed, 76 insertions(+), 37 deletions(-) diff --git a/authplugin.cpp b/authplugin.cpp index 38777f8..86ee416 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -12,7 +12,7 @@ void mosquitto_log_printf(int level, const char *fmt, ...) Logger *logger = Logger::getInstance(); va_list valist; va_start(valist, fmt); - logger->logf(level, fmt); + logger->logf(level, fmt, valist); va_end(valist); } diff --git a/logger.cpp b/logger.cpp index 5a534df..59e41db 100644 --- a/logger.cpp +++ b/logger.cpp @@ -40,7 +40,15 @@ Logger *Logger::getInstance() void Logger::logf(int level, const char *str, ...) { - if (level > curLogLevel) + va_list valist; + va_start(valist, str); + this->logf(level, str, valist); + va_end(valist); +} + +void Logger::logf(int level, const char *str, va_list valist) +{ + if (level > curLogLevel) // TODO: wrong: bitmap based return; std::lock_guard locker(logMutex); @@ -55,8 +63,6 @@ void Logger::logf(int level, const char *str, ...) const std::string s = oss.str(); const char *logfmtstring = s.c_str(); - va_list valist; - va_start(valist, str); if (this->file) { vfprintf(this->file, logfmtstring, valist); @@ -75,5 +81,4 @@ void Logger::logf(int level, const char *str, ...) fflush(stdout); #endif } - va_end(valist); } diff --git a/logger.h b/logger.h index 9d88422..84a10bf 100644 --- a/logger.h +++ b/logger.h @@ -25,6 +25,7 @@ class Logger public: static Logger *getInstance(); + void logf(int level, const char *str, va_list args); void logf(int level, const char *str, ...); }; diff --git a/main.cpp b/main.cpp index a311883..d09c587 100644 --- a/main.cpp +++ b/main.cpp @@ -65,6 +65,9 @@ int main() { try { + Logger *logger = Logger::getInstance(); + logger->logf(LOG_NOTICE, "Starting FlashMQ"); + mainApp = MainApp::getMainApp(); check(register_signal_handers()); mainApp->start(); diff --git a/mainapp.cpp b/mainapp.cpp index bba53f1..d3e956e 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -17,6 +17,21 @@ void do_thread_work(ThreadData *threadData) std::vector packetQueueIn; time_t lastKeepAliveCheck = time(NULL); + Logger *logger = Logger::getInstance(); + + try + { + logger->logf(LOG_NOTICE, "Thread %d doing auth init.", threadData->threadnr); + threadData->initAuthPlugin(); + } + catch(std::exception &ex) + { + logger->logf(LOG_ERR, "Error initializing auth back-end: %s", ex.what()); + threadData->running = false; + MainApp *instance = MainApp::getMainApp(); + instance->quit(); + } + while (threadData->running) { int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); @@ -83,7 +98,7 @@ void do_thread_work(ThreadData *threadData) { try { - packet.handle(threadData->getSubscriptionStore()); + packet.handle(); } catch (std::exception &ex) { @@ -209,17 +224,16 @@ void MainApp::start() } } + for(std::shared_ptr &thread : threads) + { + thread->quit(); + } + close(listen_fd); } void MainApp::quit() { std::cout << "Quitting FlashMQ" << std::endl; - running = false; - - for(std::shared_ptr &thread : threads) - { - thread->quit(); - } } diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 641fb98..7387b5d 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -98,7 +98,7 @@ MqttPacket::MqttPacket(const Publish &publish) : calculateRemainingLength(); } -void MqttPacket::handle(std::shared_ptr &subscriptionStore) +void MqttPacket::handle() { if (packetType != PacketType::CONNECT) { @@ -113,9 +113,9 @@ void MqttPacket::handle(std::shared_ptr &subscriptionStore) else if (packetType == PacketType::PINGREQ) sender->writePingResp(); else if (packetType == PacketType::SUBSCRIBE) - handleSubscribe(subscriptionStore); + handleSubscribe(); else if (packetType == PacketType::PUBLISH) - handlePublish(subscriptionStore); + handlePublish(); } void MqttPacket::handleConnect() @@ -146,7 +146,7 @@ void MqttPacket::handleConnect() MqttPacket response(connAck); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); - std::cout << "Rejecting because of invalid protocol version: " << sender->repr() << std::endl; + logger->logf(LOG_ERR, "Rejecting because of invalid protocol version: %s", sender->repr().c_str()); return; } @@ -200,7 +200,7 @@ void MqttPacket::handleConnect() MqttPacket response(connAck); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); - std::cout << "Client ID, username or passwords has invalid UTF8: " << sender->repr() << std::endl; + logger->logf(LOG_ERR, "Client ID, username or passwords has invalid UTF8: ", client_id.c_str()); return; } @@ -212,19 +212,29 @@ void MqttPacket::handleConnect() MqttPacket response(connAck); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); - std::cout << "ClientID has + or # in the id: " << sender->repr() << std::endl; + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str()); return; } sender->setClientProperties(client_id, username, true, keep_alive); sender->setWill(will_topic, will_payload, will_retain, will_qos); - sender->setAuthenticated(true); - std::cout << "Connect: " << sender->repr() << std::endl; - - ConnAck connAck(ConnAckReturnCodes::Accepted); - MqttPacket response(connAck); - sender->writeMqttPacket(response); + if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) + { + sender->setAuthenticated(true); + ConnAck connAck(ConnAckReturnCodes::Accepted); + MqttPacket response(connAck); + sender->writeMqttPacket(response); + logger->logf(LOG_NOTICE, "User %s logged in successfully", username.c_str()); + } + else + { + ConnAck connDeny(ConnAckReturnCodes::NotAuthorized); + MqttPacket response(connDeny); + sender->setReadyForDisconnect(); + sender->writeMqttPacket(response); + logger->logf(LOG_NOTICE, "User %s access denied", username.c_str()); + } } else { @@ -232,7 +242,7 @@ void MqttPacket::handleConnect() } } -void MqttPacket::handleSubscribe(std::shared_ptr &subscriptionStore) +void MqttPacket::handleSubscribe() { uint16_t packet_id = readTwoBytesToUInt16(); @@ -245,7 +255,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr &subscriptio if (qos > 0) throw NotImplementedException("QoS not implemented"); std::cout << sender->repr() << " Subscribed to " << topic << std::endl; - subscriptionStore->addSubscription(sender, topic); + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic); subs_reponse_codes.push_back(qos); } @@ -254,7 +264,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr &subscriptio sender->writeMqttPacket(response); } -void MqttPacket::handlePublish(std::shared_ptr &subscriptionStore) +void MqttPacket::handlePublish() { uint16_t variable_header_length = readTwoBytesToUInt16(); @@ -287,7 +297,7 @@ void MqttPacket::handlePublish(std::shared_ptr &subscriptionS size_t payload_length = remainingAfterPos(); std::string payload(readBytes(payload_length), payload_length); - subscriptionStore->setRetainedMessage(topic, payload, qos); + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos); } // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. @@ -295,7 +305,7 @@ void MqttPacket::handlePublish(std::shared_ptr &subscriptionS bites[0] &= 0b11110110; // For the existing clients, we can just write the same packet back out, with our small alterations. - subscriptionStore->queuePacketAtSubscribers(topic, *this, sender); + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender); } void MqttPacket::calculateRemainingLength() diff --git a/mqttpacket.h b/mqttpacket.h index f470c0d..90df81a 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -13,6 +13,7 @@ #include "types.h" #include "subscriptionstore.h" #include "cirbuf.h" +#include "logger.h" struct RemainingLength { @@ -31,6 +32,7 @@ class MqttPacket char first_byte = 0; size_t pos = 0; ProtocolVersion protocolVersion = ProtocolVersion::None; + Logger *logger = Logger::getInstance(); char *readBytes(size_t length); char readByte(); @@ -50,11 +52,11 @@ public: MqttPacket(const SubAck &subAck); MqttPacket(const Publish &publish); - void handle(std::shared_ptr &subscriptionStore); + void handle(); void handleConnect(); - void handleSubscribe(std::shared_ptr &subscriptionStore); + void handleSubscribe(); void handlePing(); - void handlePublish(std::shared_ptr &subscriptionStore); + void handlePublish(); size_t getSizeIncludingNonPresentHeader() const; const std::vector &getBites() const { return bites; } diff --git a/threaddata.cpp b/threaddata.cpp index f07706d..3b4a3c9 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -11,10 +11,6 @@ ThreadData::ThreadData(int threadnr, std::shared_ptr &subscri logger = Logger::getInstance(); epollfd = check(epoll_create(999)); - - authPlugin.loadPlugin(confFileParser.getAuthPluginPath()); - authPlugin.init(); - authPlugin.securityInit(false); } void ThreadData::moveThreadHere(std::thread &&thread) @@ -100,6 +96,13 @@ bool ThreadData::doKeepAliveCheck() return true; } +void ThreadData::initAuthPlugin() +{ + authPlugin.loadPlugin(confFileParser.getAuthPluginPath()); + authPlugin.init(); + authPlugin.securityInit(false); +} + void ThreadData::reload() { try diff --git a/threaddata.h b/threaddata.h index ac92b27..b9583cf 100644 --- a/threaddata.h +++ b/threaddata.h @@ -26,10 +26,10 @@ class ThreadData std::mutex clients_by_fd_mutex; std::shared_ptr subscriptionStore; ConfigFileParser &confFileParser; - AuthPlugin authPlugin; Logger *logger; public: + AuthPlugin authPlugin; bool running = true; std::thread thread; int threadnr = 0; @@ -48,6 +48,7 @@ public: std::shared_ptr &getSubscriptionStore(); bool doKeepAliveCheck(); + void initAuthPlugin(); void reload(); }; -- libgit2 0.21.4