You need to sign in before continuing.
Commit 462fb2edc3c418fe9d6c0854f788575d9bf6cf74
1 parent
543d4451
Working auth with plugin
Showing
9 changed files
with
76 additions
and
37 deletions
authplugin.cpp
| @@ -12,7 +12,7 @@ void mosquitto_log_printf(int level, const char *fmt, ...) | @@ -12,7 +12,7 @@ void mosquitto_log_printf(int level, const char *fmt, ...) | ||
| 12 | Logger *logger = Logger::getInstance(); | 12 | Logger *logger = Logger::getInstance(); |
| 13 | va_list valist; | 13 | va_list valist; |
| 14 | va_start(valist, fmt); | 14 | va_start(valist, fmt); |
| 15 | - logger->logf(level, fmt); | 15 | + logger->logf(level, fmt, valist); |
| 16 | va_end(valist); | 16 | va_end(valist); |
| 17 | } | 17 | } |
| 18 | 18 |
logger.cpp
| @@ -40,7 +40,15 @@ Logger *Logger::getInstance() | @@ -40,7 +40,15 @@ Logger *Logger::getInstance() | ||
| 40 | 40 | ||
| 41 | void Logger::logf(int level, const char *str, ...) | 41 | void Logger::logf(int level, const char *str, ...) |
| 42 | { | 42 | { |
| 43 | - if (level > curLogLevel) | 43 | + va_list valist; |
| 44 | + va_start(valist, str); | ||
| 45 | + this->logf(level, str, valist); | ||
| 46 | + va_end(valist); | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +void Logger::logf(int level, const char *str, va_list valist) | ||
| 50 | +{ | ||
| 51 | + if (level > curLogLevel) // TODO: wrong: bitmap based | ||
| 44 | return; | 52 | return; |
| 45 | 53 | ||
| 46 | std::lock_guard<std::mutex> locker(logMutex); | 54 | std::lock_guard<std::mutex> locker(logMutex); |
| @@ -55,8 +63,6 @@ void Logger::logf(int level, const char *str, ...) | @@ -55,8 +63,6 @@ void Logger::logf(int level, const char *str, ...) | ||
| 55 | const std::string s = oss.str(); | 63 | const std::string s = oss.str(); |
| 56 | const char *logfmtstring = s.c_str(); | 64 | const char *logfmtstring = s.c_str(); |
| 57 | 65 | ||
| 58 | - va_list valist; | ||
| 59 | - va_start(valist, str); | ||
| 60 | if (this->file) | 66 | if (this->file) |
| 61 | { | 67 | { |
| 62 | vfprintf(this->file, logfmtstring, valist); | 68 | vfprintf(this->file, logfmtstring, valist); |
| @@ -75,5 +81,4 @@ void Logger::logf(int level, const char *str, ...) | @@ -75,5 +81,4 @@ void Logger::logf(int level, const char *str, ...) | ||
| 75 | fflush(stdout); | 81 | fflush(stdout); |
| 76 | #endif | 82 | #endif |
| 77 | } | 83 | } |
| 78 | - va_end(valist); | ||
| 79 | } | 84 | } |
logger.h
| @@ -25,6 +25,7 @@ class Logger | @@ -25,6 +25,7 @@ class Logger | ||
| 25 | 25 | ||
| 26 | public: | 26 | public: |
| 27 | static Logger *getInstance(); | 27 | static Logger *getInstance(); |
| 28 | + void logf(int level, const char *str, va_list args); | ||
| 28 | void logf(int level, const char *str, ...); | 29 | void logf(int level, const char *str, ...); |
| 29 | 30 | ||
| 30 | }; | 31 | }; |
main.cpp
| @@ -65,6 +65,9 @@ int main() | @@ -65,6 +65,9 @@ int main() | ||
| 65 | { | 65 | { |
| 66 | try | 66 | try |
| 67 | { | 67 | { |
| 68 | + Logger *logger = Logger::getInstance(); | ||
| 69 | + logger->logf(LOG_NOTICE, "Starting FlashMQ"); | ||
| 70 | + | ||
| 68 | mainApp = MainApp::getMainApp(); | 71 | mainApp = MainApp::getMainApp(); |
| 69 | check<std::runtime_error>(register_signal_handers()); | 72 | check<std::runtime_error>(register_signal_handers()); |
| 70 | mainApp->start(); | 73 | mainApp->start(); |
mainapp.cpp
| @@ -17,6 +17,21 @@ void do_thread_work(ThreadData *threadData) | @@ -17,6 +17,21 @@ void do_thread_work(ThreadData *threadData) | ||
| 17 | std::vector<MqttPacket> packetQueueIn; | 17 | std::vector<MqttPacket> packetQueueIn; |
| 18 | time_t lastKeepAliveCheck = time(NULL); | 18 | time_t lastKeepAliveCheck = time(NULL); |
| 19 | 19 | ||
| 20 | + Logger *logger = Logger::getInstance(); | ||
| 21 | + | ||
| 22 | + try | ||
| 23 | + { | ||
| 24 | + logger->logf(LOG_NOTICE, "Thread %d doing auth init.", threadData->threadnr); | ||
| 25 | + threadData->initAuthPlugin(); | ||
| 26 | + } | ||
| 27 | + catch(std::exception &ex) | ||
| 28 | + { | ||
| 29 | + logger->logf(LOG_ERR, "Error initializing auth back-end: %s", ex.what()); | ||
| 30 | + threadData->running = false; | ||
| 31 | + MainApp *instance = MainApp::getMainApp(); | ||
| 32 | + instance->quit(); | ||
| 33 | + } | ||
| 34 | + | ||
| 20 | while (threadData->running) | 35 | while (threadData->running) |
| 21 | { | 36 | { |
| 22 | int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); | 37 | int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); |
| @@ -83,7 +98,7 @@ void do_thread_work(ThreadData *threadData) | @@ -83,7 +98,7 @@ void do_thread_work(ThreadData *threadData) | ||
| 83 | { | 98 | { |
| 84 | try | 99 | try |
| 85 | { | 100 | { |
| 86 | - packet.handle(threadData->getSubscriptionStore()); | 101 | + packet.handle(); |
| 87 | } | 102 | } |
| 88 | catch (std::exception &ex) | 103 | catch (std::exception &ex) |
| 89 | { | 104 | { |
| @@ -209,17 +224,16 @@ void MainApp::start() | @@ -209,17 +224,16 @@ void MainApp::start() | ||
| 209 | } | 224 | } |
| 210 | } | 225 | } |
| 211 | 226 | ||
| 227 | + for(std::shared_ptr<ThreadData> &thread : threads) | ||
| 228 | + { | ||
| 229 | + thread->quit(); | ||
| 230 | + } | ||
| 231 | + | ||
| 212 | close(listen_fd); | 232 | close(listen_fd); |
| 213 | } | 233 | } |
| 214 | 234 | ||
| 215 | void MainApp::quit() | 235 | void MainApp::quit() |
| 216 | { | 236 | { |
| 217 | std::cout << "Quitting FlashMQ" << std::endl; | 237 | std::cout << "Quitting FlashMQ" << std::endl; |
| 218 | - | ||
| 219 | running = false; | 238 | running = false; |
| 220 | - | ||
| 221 | - for(std::shared_ptr<ThreadData> &thread : threads) | ||
| 222 | - { | ||
| 223 | - thread->quit(); | ||
| 224 | - } | ||
| 225 | } | 239 | } |
mqttpacket.cpp
| @@ -98,7 +98,7 @@ MqttPacket::MqttPacket(const Publish &publish) : | @@ -98,7 +98,7 @@ MqttPacket::MqttPacket(const Publish &publish) : | ||
| 98 | calculateRemainingLength(); | 98 | calculateRemainingLength(); |
| 99 | } | 99 | } |
| 100 | 100 | ||
| 101 | -void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) | 101 | +void MqttPacket::handle() |
| 102 | { | 102 | { |
| 103 | if (packetType != PacketType::CONNECT) | 103 | if (packetType != PacketType::CONNECT) |
| 104 | { | 104 | { |
| @@ -113,9 +113,9 @@ void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) | @@ -113,9 +113,9 @@ void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) | ||
| 113 | else if (packetType == PacketType::PINGREQ) | 113 | else if (packetType == PacketType::PINGREQ) |
| 114 | sender->writePingResp(); | 114 | sender->writePingResp(); |
| 115 | else if (packetType == PacketType::SUBSCRIBE) | 115 | else if (packetType == PacketType::SUBSCRIBE) |
| 116 | - handleSubscribe(subscriptionStore); | 116 | + handleSubscribe(); |
| 117 | else if (packetType == PacketType::PUBLISH) | 117 | else if (packetType == PacketType::PUBLISH) |
| 118 | - handlePublish(subscriptionStore); | 118 | + handlePublish(); |
| 119 | } | 119 | } |
| 120 | 120 | ||
| 121 | void MqttPacket::handleConnect() | 121 | void MqttPacket::handleConnect() |
| @@ -146,7 +146,7 @@ void MqttPacket::handleConnect() | @@ -146,7 +146,7 @@ void MqttPacket::handleConnect() | ||
| 146 | MqttPacket response(connAck); | 146 | MqttPacket response(connAck); |
| 147 | sender->setReadyForDisconnect(); | 147 | sender->setReadyForDisconnect(); |
| 148 | sender->writeMqttPacket(response); | 148 | sender->writeMqttPacket(response); |
| 149 | - std::cout << "Rejecting because of invalid protocol version: " << sender->repr() << std::endl; | 149 | + logger->logf(LOG_ERR, "Rejecting because of invalid protocol version: %s", sender->repr().c_str()); |
| 150 | return; | 150 | return; |
| 151 | } | 151 | } |
| 152 | 152 | ||
| @@ -200,7 +200,7 @@ void MqttPacket::handleConnect() | @@ -200,7 +200,7 @@ void MqttPacket::handleConnect() | ||
| 200 | MqttPacket response(connAck); | 200 | MqttPacket response(connAck); |
| 201 | sender->setReadyForDisconnect(); | 201 | sender->setReadyForDisconnect(); |
| 202 | sender->writeMqttPacket(response); | 202 | sender->writeMqttPacket(response); |
| 203 | - std::cout << "Client ID, username or passwords has invalid UTF8: " << sender->repr() << std::endl; | 203 | + logger->logf(LOG_ERR, "Client ID, username or passwords has invalid UTF8: ", client_id.c_str()); |
| 204 | return; | 204 | return; |
| 205 | } | 205 | } |
| 206 | 206 | ||
| @@ -212,19 +212,29 @@ void MqttPacket::handleConnect() | @@ -212,19 +212,29 @@ void MqttPacket::handleConnect() | ||
| 212 | MqttPacket response(connAck); | 212 | MqttPacket response(connAck); |
| 213 | sender->setReadyForDisconnect(); | 213 | sender->setReadyForDisconnect(); |
| 214 | sender->writeMqttPacket(response); | 214 | sender->writeMqttPacket(response); |
| 215 | - std::cout << "ClientID has + or # in the id: " << sender->repr() << std::endl; | 215 | + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str()); |
| 216 | return; | 216 | return; |
| 217 | } | 217 | } |
| 218 | 218 | ||
| 219 | sender->setClientProperties(client_id, username, true, keep_alive); | 219 | sender->setClientProperties(client_id, username, true, keep_alive); |
| 220 | sender->setWill(will_topic, will_payload, will_retain, will_qos); | 220 | sender->setWill(will_topic, will_payload, will_retain, will_qos); |
| 221 | - sender->setAuthenticated(true); | ||
| 222 | 221 | ||
| 223 | - std::cout << "Connect: " << sender->repr() << std::endl; | ||
| 224 | - | ||
| 225 | - ConnAck connAck(ConnAckReturnCodes::Accepted); | ||
| 226 | - MqttPacket response(connAck); | ||
| 227 | - sender->writeMqttPacket(response); | 222 | + if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) |
| 223 | + { | ||
| 224 | + sender->setAuthenticated(true); | ||
| 225 | + ConnAck connAck(ConnAckReturnCodes::Accepted); | ||
| 226 | + MqttPacket response(connAck); | ||
| 227 | + sender->writeMqttPacket(response); | ||
| 228 | + logger->logf(LOG_NOTICE, "User %s logged in successfully", username.c_str()); | ||
| 229 | + } | ||
| 230 | + else | ||
| 231 | + { | ||
| 232 | + ConnAck connDeny(ConnAckReturnCodes::NotAuthorized); | ||
| 233 | + MqttPacket response(connDeny); | ||
| 234 | + sender->setReadyForDisconnect(); | ||
| 235 | + sender->writeMqttPacket(response); | ||
| 236 | + logger->logf(LOG_NOTICE, "User %s access denied", username.c_str()); | ||
| 237 | + } | ||
| 228 | } | 238 | } |
| 229 | else | 239 | else |
| 230 | { | 240 | { |
| @@ -232,7 +242,7 @@ void MqttPacket::handleConnect() | @@ -232,7 +242,7 @@ void MqttPacket::handleConnect() | ||
| 232 | } | 242 | } |
| 233 | } | 243 | } |
| 234 | 244 | ||
| 235 | -void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore) | 245 | +void MqttPacket::handleSubscribe() |
| 236 | { | 246 | { |
| 237 | uint16_t packet_id = readTwoBytesToUInt16(); | 247 | uint16_t packet_id = readTwoBytesToUInt16(); |
| 238 | 248 | ||
| @@ -245,7 +255,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptio | @@ -245,7 +255,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptio | ||
| 245 | if (qos > 0) | 255 | if (qos > 0) |
| 246 | throw NotImplementedException("QoS not implemented"); | 256 | throw NotImplementedException("QoS not implemented"); |
| 247 | std::cout << sender->repr() << " Subscribed to " << topic << std::endl; | 257 | std::cout << sender->repr() << " Subscribed to " << topic << std::endl; |
| 248 | - subscriptionStore->addSubscription(sender, topic); | 258 | + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic); |
| 249 | subs_reponse_codes.push_back(qos); | 259 | subs_reponse_codes.push_back(qos); |
| 250 | } | 260 | } |
| 251 | 261 | ||
| @@ -254,7 +264,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptio | @@ -254,7 +264,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptio | ||
| 254 | sender->writeMqttPacket(response); | 264 | sender->writeMqttPacket(response); |
| 255 | } | 265 | } |
| 256 | 266 | ||
| 257 | -void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore) | 267 | +void MqttPacket::handlePublish() |
| 258 | { | 268 | { |
| 259 | uint16_t variable_header_length = readTwoBytesToUInt16(); | 269 | uint16_t variable_header_length = readTwoBytesToUInt16(); |
| 260 | 270 | ||
| @@ -287,7 +297,7 @@ void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionS | @@ -287,7 +297,7 @@ void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionS | ||
| 287 | size_t payload_length = remainingAfterPos(); | 297 | size_t payload_length = remainingAfterPos(); |
| 288 | std::string payload(readBytes(payload_length), payload_length); | 298 | std::string payload(readBytes(payload_length), payload_length); |
| 289 | 299 | ||
| 290 | - subscriptionStore->setRetainedMessage(topic, payload, qos); | 300 | + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos); |
| 291 | } | 301 | } |
| 292 | 302 | ||
| 293 | // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. | 303 | // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. |
| @@ -295,7 +305,7 @@ void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionS | @@ -295,7 +305,7 @@ void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionS | ||
| 295 | bites[0] &= 0b11110110; | 305 | bites[0] &= 0b11110110; |
| 296 | 306 | ||
| 297 | // For the existing clients, we can just write the same packet back out, with our small alterations. | 307 | // For the existing clients, we can just write the same packet back out, with our small alterations. |
| 298 | - subscriptionStore->queuePacketAtSubscribers(topic, *this, sender); | 308 | + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender); |
| 299 | } | 309 | } |
| 300 | 310 | ||
| 301 | void MqttPacket::calculateRemainingLength() | 311 | void MqttPacket::calculateRemainingLength() |
mqttpacket.h
| @@ -13,6 +13,7 @@ | @@ -13,6 +13,7 @@ | ||
| 13 | #include "types.h" | 13 | #include "types.h" |
| 14 | #include "subscriptionstore.h" | 14 | #include "subscriptionstore.h" |
| 15 | #include "cirbuf.h" | 15 | #include "cirbuf.h" |
| 16 | +#include "logger.h" | ||
| 16 | 17 | ||
| 17 | struct RemainingLength | 18 | struct RemainingLength |
| 18 | { | 19 | { |
| @@ -31,6 +32,7 @@ class MqttPacket | @@ -31,6 +32,7 @@ class MqttPacket | ||
| 31 | char first_byte = 0; | 32 | char first_byte = 0; |
| 32 | size_t pos = 0; | 33 | size_t pos = 0; |
| 33 | ProtocolVersion protocolVersion = ProtocolVersion::None; | 34 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 35 | + Logger *logger = Logger::getInstance(); | ||
| 34 | 36 | ||
| 35 | char *readBytes(size_t length); | 37 | char *readBytes(size_t length); |
| 36 | char readByte(); | 38 | char readByte(); |
| @@ -50,11 +52,11 @@ public: | @@ -50,11 +52,11 @@ public: | ||
| 50 | MqttPacket(const SubAck &subAck); | 52 | MqttPacket(const SubAck &subAck); |
| 51 | MqttPacket(const Publish &publish); | 53 | MqttPacket(const Publish &publish); |
| 52 | 54 | ||
| 53 | - void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore); | 55 | + void handle(); |
| 54 | void handleConnect(); | 56 | void handleConnect(); |
| 55 | - void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore); | 57 | + void handleSubscribe(); |
| 56 | void handlePing(); | 58 | void handlePing(); |
| 57 | - void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore); | 59 | + void handlePublish(); |
| 58 | 60 | ||
| 59 | size_t getSizeIncludingNonPresentHeader() const; | 61 | size_t getSizeIncludingNonPresentHeader() const; |
| 60 | const std::vector<char> &getBites() const { return bites; } | 62 | const std::vector<char> &getBites() const { return bites; } |
threaddata.cpp
| @@ -11,10 +11,6 @@ ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscri | @@ -11,10 +11,6 @@ ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscri | ||
| 11 | logger = Logger::getInstance(); | 11 | logger = Logger::getInstance(); |
| 12 | 12 | ||
| 13 | epollfd = check<std::runtime_error>(epoll_create(999)); | 13 | epollfd = check<std::runtime_error>(epoll_create(999)); |
| 14 | - | ||
| 15 | - authPlugin.loadPlugin(confFileParser.getAuthPluginPath()); | ||
| 16 | - authPlugin.init(); | ||
| 17 | - authPlugin.securityInit(false); | ||
| 18 | } | 14 | } |
| 19 | 15 | ||
| 20 | void ThreadData::moveThreadHere(std::thread &&thread) | 16 | void ThreadData::moveThreadHere(std::thread &&thread) |
| @@ -100,6 +96,13 @@ bool ThreadData::doKeepAliveCheck() | @@ -100,6 +96,13 @@ bool ThreadData::doKeepAliveCheck() | ||
| 100 | return true; | 96 | return true; |
| 101 | } | 97 | } |
| 102 | 98 | ||
| 99 | +void ThreadData::initAuthPlugin() | ||
| 100 | +{ | ||
| 101 | + authPlugin.loadPlugin(confFileParser.getAuthPluginPath()); | ||
| 102 | + authPlugin.init(); | ||
| 103 | + authPlugin.securityInit(false); | ||
| 104 | +} | ||
| 105 | + | ||
| 103 | void ThreadData::reload() | 106 | void ThreadData::reload() |
| 104 | { | 107 | { |
| 105 | try | 108 | try |
threaddata.h
| @@ -26,10 +26,10 @@ class ThreadData | @@ -26,10 +26,10 @@ class ThreadData | ||
| 26 | std::mutex clients_by_fd_mutex; | 26 | std::mutex clients_by_fd_mutex; |
| 27 | std::shared_ptr<SubscriptionStore> subscriptionStore; | 27 | std::shared_ptr<SubscriptionStore> subscriptionStore; |
| 28 | ConfigFileParser &confFileParser; | 28 | ConfigFileParser &confFileParser; |
| 29 | - AuthPlugin authPlugin; | ||
| 30 | Logger *logger; | 29 | Logger *logger; |
| 31 | 30 | ||
| 32 | public: | 31 | public: |
| 32 | + AuthPlugin authPlugin; | ||
| 33 | bool running = true; | 33 | bool running = true; |
| 34 | std::thread thread; | 34 | std::thread thread; |
| 35 | int threadnr = 0; | 35 | int threadnr = 0; |
| @@ -48,6 +48,7 @@ public: | @@ -48,6 +48,7 @@ public: | ||
| 48 | std::shared_ptr<SubscriptionStore> &getSubscriptionStore(); | 48 | std::shared_ptr<SubscriptionStore> &getSubscriptionStore(); |
| 49 | 49 | ||
| 50 | bool doKeepAliveCheck(); | 50 | bool doKeepAliveCheck(); |
| 51 | + void initAuthPlugin(); | ||
| 51 | void reload(); | 52 | void reload(); |
| 52 | }; | 53 | }; |
| 53 | 54 |