Commit 462fb2edc3c418fe9d6c0854f788575d9bf6cf74
1 parent
543d4451
Working auth with plugin
Showing
9 changed files
with
76 additions
and
37 deletions
authplugin.cpp
logger.cpp
| ... | ... | @@ -40,7 +40,15 @@ Logger *Logger::getInstance() |
| 40 | 40 | |
| 41 | 41 | void Logger::logf(int level, const char *str, ...) |
| 42 | 42 | { |
| 43 | - if (level > curLogLevel) | |
| 43 | + va_list valist; | |
| 44 | + va_start(valist, str); | |
| 45 | + this->logf(level, str, valist); | |
| 46 | + va_end(valist); | |
| 47 | +} | |
| 48 | + | |
| 49 | +void Logger::logf(int level, const char *str, va_list valist) | |
| 50 | +{ | |
| 51 | + if (level > curLogLevel) // TODO: wrong: bitmap based | |
| 44 | 52 | return; |
| 45 | 53 | |
| 46 | 54 | std::lock_guard<std::mutex> locker(logMutex); |
| ... | ... | @@ -55,8 +63,6 @@ void Logger::logf(int level, const char *str, ...) |
| 55 | 63 | const std::string s = oss.str(); |
| 56 | 64 | const char *logfmtstring = s.c_str(); |
| 57 | 65 | |
| 58 | - va_list valist; | |
| 59 | - va_start(valist, str); | |
| 60 | 66 | if (this->file) |
| 61 | 67 | { |
| 62 | 68 | vfprintf(this->file, logfmtstring, valist); |
| ... | ... | @@ -75,5 +81,4 @@ void Logger::logf(int level, const char *str, ...) |
| 75 | 81 | fflush(stdout); |
| 76 | 82 | #endif |
| 77 | 83 | } |
| 78 | - va_end(valist); | |
| 79 | 84 | } | ... | ... |
logger.h
main.cpp
mainapp.cpp
| ... | ... | @@ -17,6 +17,21 @@ void do_thread_work(ThreadData *threadData) |
| 17 | 17 | std::vector<MqttPacket> packetQueueIn; |
| 18 | 18 | time_t lastKeepAliveCheck = time(NULL); |
| 19 | 19 | |
| 20 | + Logger *logger = Logger::getInstance(); | |
| 21 | + | |
| 22 | + try | |
| 23 | + { | |
| 24 | + logger->logf(LOG_NOTICE, "Thread %d doing auth init.", threadData->threadnr); | |
| 25 | + threadData->initAuthPlugin(); | |
| 26 | + } | |
| 27 | + catch(std::exception &ex) | |
| 28 | + { | |
| 29 | + logger->logf(LOG_ERR, "Error initializing auth back-end: %s", ex.what()); | |
| 30 | + threadData->running = false; | |
| 31 | + MainApp *instance = MainApp::getMainApp(); | |
| 32 | + instance->quit(); | |
| 33 | + } | |
| 34 | + | |
| 20 | 35 | while (threadData->running) |
| 21 | 36 | { |
| 22 | 37 | int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); |
| ... | ... | @@ -83,7 +98,7 @@ void do_thread_work(ThreadData *threadData) |
| 83 | 98 | { |
| 84 | 99 | try |
| 85 | 100 | { |
| 86 | - packet.handle(threadData->getSubscriptionStore()); | |
| 101 | + packet.handle(); | |
| 87 | 102 | } |
| 88 | 103 | catch (std::exception &ex) |
| 89 | 104 | { |
| ... | ... | @@ -209,17 +224,16 @@ void MainApp::start() |
| 209 | 224 | } |
| 210 | 225 | } |
| 211 | 226 | |
| 227 | + for(std::shared_ptr<ThreadData> &thread : threads) | |
| 228 | + { | |
| 229 | + thread->quit(); | |
| 230 | + } | |
| 231 | + | |
| 212 | 232 | close(listen_fd); |
| 213 | 233 | } |
| 214 | 234 | |
| 215 | 235 | void MainApp::quit() |
| 216 | 236 | { |
| 217 | 237 | std::cout << "Quitting FlashMQ" << std::endl; |
| 218 | - | |
| 219 | 238 | running = false; |
| 220 | - | |
| 221 | - for(std::shared_ptr<ThreadData> &thread : threads) | |
| 222 | - { | |
| 223 | - thread->quit(); | |
| 224 | - } | |
| 225 | 239 | } | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -98,7 +98,7 @@ MqttPacket::MqttPacket(const Publish &publish) : |
| 98 | 98 | calculateRemainingLength(); |
| 99 | 99 | } |
| 100 | 100 | |
| 101 | -void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) | |
| 101 | +void MqttPacket::handle() | |
| 102 | 102 | { |
| 103 | 103 | if (packetType != PacketType::CONNECT) |
| 104 | 104 | { |
| ... | ... | @@ -113,9 +113,9 @@ void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) |
| 113 | 113 | else if (packetType == PacketType::PINGREQ) |
| 114 | 114 | sender->writePingResp(); |
| 115 | 115 | else if (packetType == PacketType::SUBSCRIBE) |
| 116 | - handleSubscribe(subscriptionStore); | |
| 116 | + handleSubscribe(); | |
| 117 | 117 | else if (packetType == PacketType::PUBLISH) |
| 118 | - handlePublish(subscriptionStore); | |
| 118 | + handlePublish(); | |
| 119 | 119 | } |
| 120 | 120 | |
| 121 | 121 | void MqttPacket::handleConnect() |
| ... | ... | @@ -146,7 +146,7 @@ void MqttPacket::handleConnect() |
| 146 | 146 | MqttPacket response(connAck); |
| 147 | 147 | sender->setReadyForDisconnect(); |
| 148 | 148 | sender->writeMqttPacket(response); |
| 149 | - std::cout << "Rejecting because of invalid protocol version: " << sender->repr() << std::endl; | |
| 149 | + logger->logf(LOG_ERR, "Rejecting because of invalid protocol version: %s", sender->repr().c_str()); | |
| 150 | 150 | return; |
| 151 | 151 | } |
| 152 | 152 | |
| ... | ... | @@ -200,7 +200,7 @@ void MqttPacket::handleConnect() |
| 200 | 200 | MqttPacket response(connAck); |
| 201 | 201 | sender->setReadyForDisconnect(); |
| 202 | 202 | sender->writeMqttPacket(response); |
| 203 | - std::cout << "Client ID, username or passwords has invalid UTF8: " << sender->repr() << std::endl; | |
| 203 | + logger->logf(LOG_ERR, "Client ID, username or passwords has invalid UTF8: ", client_id.c_str()); | |
| 204 | 204 | return; |
| 205 | 205 | } |
| 206 | 206 | |
| ... | ... | @@ -212,19 +212,29 @@ void MqttPacket::handleConnect() |
| 212 | 212 | MqttPacket response(connAck); |
| 213 | 213 | sender->setReadyForDisconnect(); |
| 214 | 214 | sender->writeMqttPacket(response); |
| 215 | - std::cout << "ClientID has + or # in the id: " << sender->repr() << std::endl; | |
| 215 | + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str()); | |
| 216 | 216 | return; |
| 217 | 217 | } |
| 218 | 218 | |
| 219 | 219 | sender->setClientProperties(client_id, username, true, keep_alive); |
| 220 | 220 | sender->setWill(will_topic, will_payload, will_retain, will_qos); |
| 221 | - sender->setAuthenticated(true); | |
| 222 | 221 | |
| 223 | - std::cout << "Connect: " << sender->repr() << std::endl; | |
| 224 | - | |
| 225 | - ConnAck connAck(ConnAckReturnCodes::Accepted); | |
| 226 | - MqttPacket response(connAck); | |
| 227 | - sender->writeMqttPacket(response); | |
| 222 | + if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) | |
| 223 | + { | |
| 224 | + sender->setAuthenticated(true); | |
| 225 | + ConnAck connAck(ConnAckReturnCodes::Accepted); | |
| 226 | + MqttPacket response(connAck); | |
| 227 | + sender->writeMqttPacket(response); | |
| 228 | + logger->logf(LOG_NOTICE, "User %s logged in successfully", username.c_str()); | |
| 229 | + } | |
| 230 | + else | |
| 231 | + { | |
| 232 | + ConnAck connDeny(ConnAckReturnCodes::NotAuthorized); | |
| 233 | + MqttPacket response(connDeny); | |
| 234 | + sender->setReadyForDisconnect(); | |
| 235 | + sender->writeMqttPacket(response); | |
| 236 | + logger->logf(LOG_NOTICE, "User %s access denied", username.c_str()); | |
| 237 | + } | |
| 228 | 238 | } |
| 229 | 239 | else |
| 230 | 240 | { |
| ... | ... | @@ -232,7 +242,7 @@ void MqttPacket::handleConnect() |
| 232 | 242 | } |
| 233 | 243 | } |
| 234 | 244 | |
| 235 | -void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore) | |
| 245 | +void MqttPacket::handleSubscribe() | |
| 236 | 246 | { |
| 237 | 247 | uint16_t packet_id = readTwoBytesToUInt16(); |
| 238 | 248 | |
| ... | ... | @@ -245,7 +255,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptio |
| 245 | 255 | if (qos > 0) |
| 246 | 256 | throw NotImplementedException("QoS not implemented"); |
| 247 | 257 | std::cout << sender->repr() << " Subscribed to " << topic << std::endl; |
| 248 | - subscriptionStore->addSubscription(sender, topic); | |
| 258 | + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic); | |
| 249 | 259 | subs_reponse_codes.push_back(qos); |
| 250 | 260 | } |
| 251 | 261 | |
| ... | ... | @@ -254,7 +264,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptio |
| 254 | 264 | sender->writeMqttPacket(response); |
| 255 | 265 | } |
| 256 | 266 | |
| 257 | -void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore) | |
| 267 | +void MqttPacket::handlePublish() | |
| 258 | 268 | { |
| 259 | 269 | uint16_t variable_header_length = readTwoBytesToUInt16(); |
| 260 | 270 | |
| ... | ... | @@ -287,7 +297,7 @@ void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionS |
| 287 | 297 | size_t payload_length = remainingAfterPos(); |
| 288 | 298 | std::string payload(readBytes(payload_length), payload_length); |
| 289 | 299 | |
| 290 | - subscriptionStore->setRetainedMessage(topic, payload, qos); | |
| 300 | + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos); | |
| 291 | 301 | } |
| 292 | 302 | |
| 293 | 303 | // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. |
| ... | ... | @@ -295,7 +305,7 @@ void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionS |
| 295 | 305 | bites[0] &= 0b11110110; |
| 296 | 306 | |
| 297 | 307 | // For the existing clients, we can just write the same packet back out, with our small alterations. |
| 298 | - subscriptionStore->queuePacketAtSubscribers(topic, *this, sender); | |
| 308 | + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender); | |
| 299 | 309 | } |
| 300 | 310 | |
| 301 | 311 | void MqttPacket::calculateRemainingLength() | ... | ... |
mqttpacket.h
| ... | ... | @@ -13,6 +13,7 @@ |
| 13 | 13 | #include "types.h" |
| 14 | 14 | #include "subscriptionstore.h" |
| 15 | 15 | #include "cirbuf.h" |
| 16 | +#include "logger.h" | |
| 16 | 17 | |
| 17 | 18 | struct RemainingLength |
| 18 | 19 | { |
| ... | ... | @@ -31,6 +32,7 @@ class MqttPacket |
| 31 | 32 | char first_byte = 0; |
| 32 | 33 | size_t pos = 0; |
| 33 | 34 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 35 | + Logger *logger = Logger::getInstance(); | |
| 34 | 36 | |
| 35 | 37 | char *readBytes(size_t length); |
| 36 | 38 | char readByte(); |
| ... | ... | @@ -50,11 +52,11 @@ public: |
| 50 | 52 | MqttPacket(const SubAck &subAck); |
| 51 | 53 | MqttPacket(const Publish &publish); |
| 52 | 54 | |
| 53 | - void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore); | |
| 55 | + void handle(); | |
| 54 | 56 | void handleConnect(); |
| 55 | - void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore); | |
| 57 | + void handleSubscribe(); | |
| 56 | 58 | void handlePing(); |
| 57 | - void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore); | |
| 59 | + void handlePublish(); | |
| 58 | 60 | |
| 59 | 61 | size_t getSizeIncludingNonPresentHeader() const; |
| 60 | 62 | const std::vector<char> &getBites() const { return bites; } | ... | ... |
threaddata.cpp
| ... | ... | @@ -11,10 +11,6 @@ ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscri |
| 11 | 11 | logger = Logger::getInstance(); |
| 12 | 12 | |
| 13 | 13 | epollfd = check<std::runtime_error>(epoll_create(999)); |
| 14 | - | |
| 15 | - authPlugin.loadPlugin(confFileParser.getAuthPluginPath()); | |
| 16 | - authPlugin.init(); | |
| 17 | - authPlugin.securityInit(false); | |
| 18 | 14 | } |
| 19 | 15 | |
| 20 | 16 | void ThreadData::moveThreadHere(std::thread &&thread) |
| ... | ... | @@ -100,6 +96,13 @@ bool ThreadData::doKeepAliveCheck() |
| 100 | 96 | return true; |
| 101 | 97 | } |
| 102 | 98 | |
| 99 | +void ThreadData::initAuthPlugin() | |
| 100 | +{ | |
| 101 | + authPlugin.loadPlugin(confFileParser.getAuthPluginPath()); | |
| 102 | + authPlugin.init(); | |
| 103 | + authPlugin.securityInit(false); | |
| 104 | +} | |
| 105 | + | |
| 103 | 106 | void ThreadData::reload() |
| 104 | 107 | { |
| 105 | 108 | try | ... | ... |
threaddata.h
| ... | ... | @@ -26,10 +26,10 @@ class ThreadData |
| 26 | 26 | std::mutex clients_by_fd_mutex; |
| 27 | 27 | std::shared_ptr<SubscriptionStore> subscriptionStore; |
| 28 | 28 | ConfigFileParser &confFileParser; |
| 29 | - AuthPlugin authPlugin; | |
| 30 | 29 | Logger *logger; |
| 31 | 30 | |
| 32 | 31 | public: |
| 32 | + AuthPlugin authPlugin; | |
| 33 | 33 | bool running = true; |
| 34 | 34 | std::thread thread; |
| 35 | 35 | int threadnr = 0; |
| ... | ... | @@ -48,6 +48,7 @@ public: |
| 48 | 48 | std::shared_ptr<SubscriptionStore> &getSubscriptionStore(); |
| 49 | 49 | |
| 50 | 50 | bool doKeepAliveCheck(); |
| 51 | + void initAuthPlugin(); | |
| 51 | 52 | void reload(); |
| 52 | 53 | }; |
| 53 | 54 | ... | ... |