You need to sign in before continuing.

Commit 462fb2edc3c418fe9d6c0854f788575d9bf6cf74

Authored by Wiebe Cazemier
1 parent 543d4451

Working auth with plugin

authplugin.cpp
@@ -12,7 +12,7 @@ void mosquitto_log_printf(int level, const char *fmt, ...) @@ -12,7 +12,7 @@ void mosquitto_log_printf(int level, const char *fmt, ...)
12 Logger *logger = Logger::getInstance(); 12 Logger *logger = Logger::getInstance();
13 va_list valist; 13 va_list valist;
14 va_start(valist, fmt); 14 va_start(valist, fmt);
15 - logger->logf(level, fmt); 15 + logger->logf(level, fmt, valist);
16 va_end(valist); 16 va_end(valist);
17 } 17 }
18 18
logger.cpp
@@ -40,7 +40,15 @@ Logger *Logger::getInstance() @@ -40,7 +40,15 @@ Logger *Logger::getInstance()
40 40
41 void Logger::logf(int level, const char *str, ...) 41 void Logger::logf(int level, const char *str, ...)
42 { 42 {
43 - if (level > curLogLevel) 43 + va_list valist;
  44 + va_start(valist, str);
  45 + this->logf(level, str, valist);
  46 + va_end(valist);
  47 +}
  48 +
  49 +void Logger::logf(int level, const char *str, va_list valist)
  50 +{
  51 + if (level > curLogLevel) // TODO: wrong: bitmap based
44 return; 52 return;
45 53
46 std::lock_guard<std::mutex> locker(logMutex); 54 std::lock_guard<std::mutex> locker(logMutex);
@@ -55,8 +63,6 @@ void Logger::logf(int level, const char *str, ...) @@ -55,8 +63,6 @@ void Logger::logf(int level, const char *str, ...)
55 const std::string s = oss.str(); 63 const std::string s = oss.str();
56 const char *logfmtstring = s.c_str(); 64 const char *logfmtstring = s.c_str();
57 65
58 - va_list valist;  
59 - va_start(valist, str);  
60 if (this->file) 66 if (this->file)
61 { 67 {
62 vfprintf(this->file, logfmtstring, valist); 68 vfprintf(this->file, logfmtstring, valist);
@@ -75,5 +81,4 @@ void Logger::logf(int level, const char *str, ...) @@ -75,5 +81,4 @@ void Logger::logf(int level, const char *str, ...)
75 fflush(stdout); 81 fflush(stdout);
76 #endif 82 #endif
77 } 83 }
78 - va_end(valist);  
79 } 84 }
logger.h
@@ -25,6 +25,7 @@ class Logger @@ -25,6 +25,7 @@ class Logger
25 25
26 public: 26 public:
27 static Logger *getInstance(); 27 static Logger *getInstance();
  28 + void logf(int level, const char *str, va_list args);
28 void logf(int level, const char *str, ...); 29 void logf(int level, const char *str, ...);
29 30
30 }; 31 };
main.cpp
@@ -65,6 +65,9 @@ int main() @@ -65,6 +65,9 @@ int main()
65 { 65 {
66 try 66 try
67 { 67 {
  68 + Logger *logger = Logger::getInstance();
  69 + logger->logf(LOG_NOTICE, "Starting FlashMQ");
  70 +
68 mainApp = MainApp::getMainApp(); 71 mainApp = MainApp::getMainApp();
69 check<std::runtime_error>(register_signal_handers()); 72 check<std::runtime_error>(register_signal_handers());
70 mainApp->start(); 73 mainApp->start();
mainapp.cpp
@@ -17,6 +17,21 @@ void do_thread_work(ThreadData *threadData) @@ -17,6 +17,21 @@ void do_thread_work(ThreadData *threadData)
17 std::vector<MqttPacket> packetQueueIn; 17 std::vector<MqttPacket> packetQueueIn;
18 time_t lastKeepAliveCheck = time(NULL); 18 time_t lastKeepAliveCheck = time(NULL);
19 19
  20 + Logger *logger = Logger::getInstance();
  21 +
  22 + try
  23 + {
  24 + logger->logf(LOG_NOTICE, "Thread %d doing auth init.", threadData->threadnr);
  25 + threadData->initAuthPlugin();
  26 + }
  27 + catch(std::exception &ex)
  28 + {
  29 + logger->logf(LOG_ERR, "Error initializing auth back-end: %s", ex.what());
  30 + threadData->running = false;
  31 + MainApp *instance = MainApp::getMainApp();
  32 + instance->quit();
  33 + }
  34 +
20 while (threadData->running) 35 while (threadData->running)
21 { 36 {
22 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); 37 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100);
@@ -83,7 +98,7 @@ void do_thread_work(ThreadData *threadData) @@ -83,7 +98,7 @@ void do_thread_work(ThreadData *threadData)
83 { 98 {
84 try 99 try
85 { 100 {
86 - packet.handle(threadData->getSubscriptionStore()); 101 + packet.handle();
87 } 102 }
88 catch (std::exception &ex) 103 catch (std::exception &ex)
89 { 104 {
@@ -209,17 +224,16 @@ void MainApp::start() @@ -209,17 +224,16 @@ void MainApp::start()
209 } 224 }
210 } 225 }
211 226
  227 + for(std::shared_ptr<ThreadData> &thread : threads)
  228 + {
  229 + thread->quit();
  230 + }
  231 +
212 close(listen_fd); 232 close(listen_fd);
213 } 233 }
214 234
215 void MainApp::quit() 235 void MainApp::quit()
216 { 236 {
217 std::cout << "Quitting FlashMQ" << std::endl; 237 std::cout << "Quitting FlashMQ" << std::endl;
218 -  
219 running = false; 238 running = false;
220 -  
221 - for(std::shared_ptr<ThreadData> &thread : threads)  
222 - {  
223 - thread->quit();  
224 - }  
225 } 239 }
mqttpacket.cpp
@@ -98,7 +98,7 @@ MqttPacket::MqttPacket(const Publish &amp;publish) : @@ -98,7 +98,7 @@ MqttPacket::MqttPacket(const Publish &amp;publish) :
98 calculateRemainingLength(); 98 calculateRemainingLength();
99 } 99 }
100 100
101 -void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) 101 +void MqttPacket::handle()
102 { 102 {
103 if (packetType != PacketType::CONNECT) 103 if (packetType != PacketType::CONNECT)
104 { 104 {
@@ -113,9 +113,9 @@ void MqttPacket::handle(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionStore) @@ -113,9 +113,9 @@ void MqttPacket::handle(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionStore)
113 else if (packetType == PacketType::PINGREQ) 113 else if (packetType == PacketType::PINGREQ)
114 sender->writePingResp(); 114 sender->writePingResp();
115 else if (packetType == PacketType::SUBSCRIBE) 115 else if (packetType == PacketType::SUBSCRIBE)
116 - handleSubscribe(subscriptionStore); 116 + handleSubscribe();
117 else if (packetType == PacketType::PUBLISH) 117 else if (packetType == PacketType::PUBLISH)
118 - handlePublish(subscriptionStore); 118 + handlePublish();
119 } 119 }
120 120
121 void MqttPacket::handleConnect() 121 void MqttPacket::handleConnect()
@@ -146,7 +146,7 @@ void MqttPacket::handleConnect() @@ -146,7 +146,7 @@ void MqttPacket::handleConnect()
146 MqttPacket response(connAck); 146 MqttPacket response(connAck);
147 sender->setReadyForDisconnect(); 147 sender->setReadyForDisconnect();
148 sender->writeMqttPacket(response); 148 sender->writeMqttPacket(response);
149 - std::cout << "Rejecting because of invalid protocol version: " << sender->repr() << std::endl; 149 + logger->logf(LOG_ERR, "Rejecting because of invalid protocol version: %s", sender->repr().c_str());
150 return; 150 return;
151 } 151 }
152 152
@@ -200,7 +200,7 @@ void MqttPacket::handleConnect() @@ -200,7 +200,7 @@ void MqttPacket::handleConnect()
200 MqttPacket response(connAck); 200 MqttPacket response(connAck);
201 sender->setReadyForDisconnect(); 201 sender->setReadyForDisconnect();
202 sender->writeMqttPacket(response); 202 sender->writeMqttPacket(response);
203 - std::cout << "Client ID, username or passwords has invalid UTF8: " << sender->repr() << std::endl; 203 + logger->logf(LOG_ERR, "Client ID, username or passwords has invalid UTF8: ", client_id.c_str());
204 return; 204 return;
205 } 205 }
206 206
@@ -212,19 +212,29 @@ void MqttPacket::handleConnect() @@ -212,19 +212,29 @@ void MqttPacket::handleConnect()
212 MqttPacket response(connAck); 212 MqttPacket response(connAck);
213 sender->setReadyForDisconnect(); 213 sender->setReadyForDisconnect();
214 sender->writeMqttPacket(response); 214 sender->writeMqttPacket(response);
215 - std::cout << "ClientID has + or # in the id: " << sender->repr() << std::endl; 215 + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str());
216 return; 216 return;
217 } 217 }
218 218
219 sender->setClientProperties(client_id, username, true, keep_alive); 219 sender->setClientProperties(client_id, username, true, keep_alive);
220 sender->setWill(will_topic, will_payload, will_retain, will_qos); 220 sender->setWill(will_topic, will_payload, will_retain, will_qos);
221 - sender->setAuthenticated(true);  
222 221
223 - std::cout << "Connect: " << sender->repr() << std::endl;  
224 -  
225 - ConnAck connAck(ConnAckReturnCodes::Accepted);  
226 - MqttPacket response(connAck);  
227 - sender->writeMqttPacket(response); 222 + if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success)
  223 + {
  224 + sender->setAuthenticated(true);
  225 + ConnAck connAck(ConnAckReturnCodes::Accepted);
  226 + MqttPacket response(connAck);
  227 + sender->writeMqttPacket(response);
  228 + logger->logf(LOG_NOTICE, "User %s logged in successfully", username.c_str());
  229 + }
  230 + else
  231 + {
  232 + ConnAck connDeny(ConnAckReturnCodes::NotAuthorized);
  233 + MqttPacket response(connDeny);
  234 + sender->setReadyForDisconnect();
  235 + sender->writeMqttPacket(response);
  236 + logger->logf(LOG_NOTICE, "User %s access denied", username.c_str());
  237 + }
228 } 238 }
229 else 239 else
230 { 240 {
@@ -232,7 +242,7 @@ void MqttPacket::handleConnect() @@ -232,7 +242,7 @@ void MqttPacket::handleConnect()
232 } 242 }
233 } 243 }
234 244
235 -void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore) 245 +void MqttPacket::handleSubscribe()
236 { 246 {
237 uint16_t packet_id = readTwoBytesToUInt16(); 247 uint16_t packet_id = readTwoBytesToUInt16();
238 248
@@ -245,7 +255,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio @@ -245,7 +255,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio
245 if (qos > 0) 255 if (qos > 0)
246 throw NotImplementedException("QoS not implemented"); 256 throw NotImplementedException("QoS not implemented");
247 std::cout << sender->repr() << " Subscribed to " << topic << std::endl; 257 std::cout << sender->repr() << " Subscribed to " << topic << std::endl;
248 - subscriptionStore->addSubscription(sender, topic); 258 + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic);
249 subs_reponse_codes.push_back(qos); 259 subs_reponse_codes.push_back(qos);
250 } 260 }
251 261
@@ -254,7 +264,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio @@ -254,7 +264,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio
254 sender->writeMqttPacket(response); 264 sender->writeMqttPacket(response);
255 } 265 }
256 266
257 -void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore) 267 +void MqttPacket::handlePublish()
258 { 268 {
259 uint16_t variable_header_length = readTwoBytesToUInt16(); 269 uint16_t variable_header_length = readTwoBytesToUInt16();
260 270
@@ -287,7 +297,7 @@ void MqttPacket::handlePublish(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionS @@ -287,7 +297,7 @@ void MqttPacket::handlePublish(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionS
287 size_t payload_length = remainingAfterPos(); 297 size_t payload_length = remainingAfterPos();
288 std::string payload(readBytes(payload_length), payload_length); 298 std::string payload(readBytes(payload_length), payload_length);
289 299
290 - subscriptionStore->setRetainedMessage(topic, payload, qos); 300 + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos);
291 } 301 }
292 302
293 // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. 303 // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
@@ -295,7 +305,7 @@ void MqttPacket::handlePublish(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionS @@ -295,7 +305,7 @@ void MqttPacket::handlePublish(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionS
295 bites[0] &= 0b11110110; 305 bites[0] &= 0b11110110;
296 306
297 // For the existing clients, we can just write the same packet back out, with our small alterations. 307 // For the existing clients, we can just write the same packet back out, with our small alterations.
298 - subscriptionStore->queuePacketAtSubscribers(topic, *this, sender); 308 + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender);
299 } 309 }
300 310
301 void MqttPacket::calculateRemainingLength() 311 void MqttPacket::calculateRemainingLength()
mqttpacket.h
@@ -13,6 +13,7 @@ @@ -13,6 +13,7 @@
13 #include "types.h" 13 #include "types.h"
14 #include "subscriptionstore.h" 14 #include "subscriptionstore.h"
15 #include "cirbuf.h" 15 #include "cirbuf.h"
  16 +#include "logger.h"
16 17
17 struct RemainingLength 18 struct RemainingLength
18 { 19 {
@@ -31,6 +32,7 @@ class MqttPacket @@ -31,6 +32,7 @@ class MqttPacket
31 char first_byte = 0; 32 char first_byte = 0;
32 size_t pos = 0; 33 size_t pos = 0;
33 ProtocolVersion protocolVersion = ProtocolVersion::None; 34 ProtocolVersion protocolVersion = ProtocolVersion::None;
  35 + Logger *logger = Logger::getInstance();
34 36
35 char *readBytes(size_t length); 37 char *readBytes(size_t length);
36 char readByte(); 38 char readByte();
@@ -50,11 +52,11 @@ public: @@ -50,11 +52,11 @@ public:
50 MqttPacket(const SubAck &subAck); 52 MqttPacket(const SubAck &subAck);
51 MqttPacket(const Publish &publish); 53 MqttPacket(const Publish &publish);
52 54
53 - void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore); 55 + void handle();
54 void handleConnect(); 56 void handleConnect();
55 - void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore); 57 + void handleSubscribe();
56 void handlePing(); 58 void handlePing();
57 - void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore); 59 + void handlePublish();
58 60
59 size_t getSizeIncludingNonPresentHeader() const; 61 size_t getSizeIncludingNonPresentHeader() const;
60 const std::vector<char> &getBites() const { return bites; } 62 const std::vector<char> &getBites() const { return bites; }
threaddata.cpp
@@ -11,10 +11,6 @@ ThreadData::ThreadData(int threadnr, std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscri @@ -11,10 +11,6 @@ ThreadData::ThreadData(int threadnr, std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscri
11 logger = Logger::getInstance(); 11 logger = Logger::getInstance();
12 12
13 epollfd = check<std::runtime_error>(epoll_create(999)); 13 epollfd = check<std::runtime_error>(epoll_create(999));
14 -  
15 - authPlugin.loadPlugin(confFileParser.getAuthPluginPath());  
16 - authPlugin.init();  
17 - authPlugin.securityInit(false);  
18 } 14 }
19 15
20 void ThreadData::moveThreadHere(std::thread &&thread) 16 void ThreadData::moveThreadHere(std::thread &&thread)
@@ -100,6 +96,13 @@ bool ThreadData::doKeepAliveCheck() @@ -100,6 +96,13 @@ bool ThreadData::doKeepAliveCheck()
100 return true; 96 return true;
101 } 97 }
102 98
  99 +void ThreadData::initAuthPlugin()
  100 +{
  101 + authPlugin.loadPlugin(confFileParser.getAuthPluginPath());
  102 + authPlugin.init();
  103 + authPlugin.securityInit(false);
  104 +}
  105 +
103 void ThreadData::reload() 106 void ThreadData::reload()
104 { 107 {
105 try 108 try
threaddata.h
@@ -26,10 +26,10 @@ class ThreadData @@ -26,10 +26,10 @@ class ThreadData
26 std::mutex clients_by_fd_mutex; 26 std::mutex clients_by_fd_mutex;
27 std::shared_ptr<SubscriptionStore> subscriptionStore; 27 std::shared_ptr<SubscriptionStore> subscriptionStore;
28 ConfigFileParser &confFileParser; 28 ConfigFileParser &confFileParser;
29 - AuthPlugin authPlugin;  
30 Logger *logger; 29 Logger *logger;
31 30
32 public: 31 public:
  32 + AuthPlugin authPlugin;
33 bool running = true; 33 bool running = true;
34 std::thread thread; 34 std::thread thread;
35 int threadnr = 0; 35 int threadnr = 0;
@@ -48,6 +48,7 @@ public: @@ -48,6 +48,7 @@ public:
48 std::shared_ptr<SubscriptionStore> &getSubscriptionStore(); 48 std::shared_ptr<SubscriptionStore> &getSubscriptionStore();
49 49
50 bool doKeepAliveCheck(); 50 bool doKeepAliveCheck();
  51 + void initAuthPlugin();
51 void reload(); 52 void reload();
52 }; 53 };
53 54