Commit 462fb2edc3c418fe9d6c0854f788575d9bf6cf74

Authored by Wiebe Cazemier
1 parent 543d4451

Working auth with plugin

authplugin.cpp
... ... @@ -12,7 +12,7 @@ void mosquitto_log_printf(int level, const char *fmt, ...)
12 12 Logger *logger = Logger::getInstance();
13 13 va_list valist;
14 14 va_start(valist, fmt);
15   - logger->logf(level, fmt);
  15 + logger->logf(level, fmt, valist);
16 16 va_end(valist);
17 17 }
18 18  
... ...
logger.cpp
... ... @@ -40,7 +40,15 @@ Logger *Logger::getInstance()
40 40  
41 41 void Logger::logf(int level, const char *str, ...)
42 42 {
43   - if (level > curLogLevel)
  43 + va_list valist;
  44 + va_start(valist, str);
  45 + this->logf(level, str, valist);
  46 + va_end(valist);
  47 +}
  48 +
  49 +void Logger::logf(int level, const char *str, va_list valist)
  50 +{
  51 + if (level > curLogLevel) // TODO: wrong: bitmap based
44 52 return;
45 53  
46 54 std::lock_guard<std::mutex> locker(logMutex);
... ... @@ -55,8 +63,6 @@ void Logger::logf(int level, const char *str, ...)
55 63 const std::string s = oss.str();
56 64 const char *logfmtstring = s.c_str();
57 65  
58   - va_list valist;
59   - va_start(valist, str);
60 66 if (this->file)
61 67 {
62 68 vfprintf(this->file, logfmtstring, valist);
... ... @@ -75,5 +81,4 @@ void Logger::logf(int level, const char *str, ...)
75 81 fflush(stdout);
76 82 #endif
77 83 }
78   - va_end(valist);
79 84 }
... ...
logger.h
... ... @@ -25,6 +25,7 @@ class Logger
25 25  
26 26 public:
27 27 static Logger *getInstance();
  28 + void logf(int level, const char *str, va_list args);
28 29 void logf(int level, const char *str, ...);
29 30  
30 31 };
... ...
main.cpp
... ... @@ -65,6 +65,9 @@ int main()
65 65 {
66 66 try
67 67 {
  68 + Logger *logger = Logger::getInstance();
  69 + logger->logf(LOG_NOTICE, "Starting FlashMQ");
  70 +
68 71 mainApp = MainApp::getMainApp();
69 72 check<std::runtime_error>(register_signal_handers());
70 73 mainApp->start();
... ...
mainapp.cpp
... ... @@ -17,6 +17,21 @@ void do_thread_work(ThreadData *threadData)
17 17 std::vector<MqttPacket> packetQueueIn;
18 18 time_t lastKeepAliveCheck = time(NULL);
19 19  
  20 + Logger *logger = Logger::getInstance();
  21 +
  22 + try
  23 + {
  24 + logger->logf(LOG_NOTICE, "Thread %d doing auth init.", threadData->threadnr);
  25 + threadData->initAuthPlugin();
  26 + }
  27 + catch(std::exception &ex)
  28 + {
  29 + logger->logf(LOG_ERR, "Error initializing auth back-end: %s", ex.what());
  30 + threadData->running = false;
  31 + MainApp *instance = MainApp::getMainApp();
  32 + instance->quit();
  33 + }
  34 +
20 35 while (threadData->running)
21 36 {
22 37 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100);
... ... @@ -83,7 +98,7 @@ void do_thread_work(ThreadData *threadData)
83 98 {
84 99 try
85 100 {
86   - packet.handle(threadData->getSubscriptionStore());
  101 + packet.handle();
87 102 }
88 103 catch (std::exception &ex)
89 104 {
... ... @@ -209,17 +224,16 @@ void MainApp::start()
209 224 }
210 225 }
211 226  
  227 + for(std::shared_ptr<ThreadData> &thread : threads)
  228 + {
  229 + thread->quit();
  230 + }
  231 +
212 232 close(listen_fd);
213 233 }
214 234  
215 235 void MainApp::quit()
216 236 {
217 237 std::cout << "Quitting FlashMQ" << std::endl;
218   -
219 238 running = false;
220   -
221   - for(std::shared_ptr<ThreadData> &thread : threads)
222   - {
223   - thread->quit();
224   - }
225 239 }
... ...
mqttpacket.cpp
... ... @@ -98,7 +98,7 @@ MqttPacket::MqttPacket(const Publish &amp;publish) :
98 98 calculateRemainingLength();
99 99 }
100 100  
101   -void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore)
  101 +void MqttPacket::handle()
102 102 {
103 103 if (packetType != PacketType::CONNECT)
104 104 {
... ... @@ -113,9 +113,9 @@ void MqttPacket::handle(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionStore)
113 113 else if (packetType == PacketType::PINGREQ)
114 114 sender->writePingResp();
115 115 else if (packetType == PacketType::SUBSCRIBE)
116   - handleSubscribe(subscriptionStore);
  116 + handleSubscribe();
117 117 else if (packetType == PacketType::PUBLISH)
118   - handlePublish(subscriptionStore);
  118 + handlePublish();
119 119 }
120 120  
121 121 void MqttPacket::handleConnect()
... ... @@ -146,7 +146,7 @@ void MqttPacket::handleConnect()
146 146 MqttPacket response(connAck);
147 147 sender->setReadyForDisconnect();
148 148 sender->writeMqttPacket(response);
149   - std::cout << "Rejecting because of invalid protocol version: " << sender->repr() << std::endl;
  149 + logger->logf(LOG_ERR, "Rejecting because of invalid protocol version: %s", sender->repr().c_str());
150 150 return;
151 151 }
152 152  
... ... @@ -200,7 +200,7 @@ void MqttPacket::handleConnect()
200 200 MqttPacket response(connAck);
201 201 sender->setReadyForDisconnect();
202 202 sender->writeMqttPacket(response);
203   - std::cout << "Client ID, username or passwords has invalid UTF8: " << sender->repr() << std::endl;
  203 + logger->logf(LOG_ERR, "Client ID, username or passwords has invalid UTF8: ", client_id.c_str());
204 204 return;
205 205 }
206 206  
... ... @@ -212,19 +212,29 @@ void MqttPacket::handleConnect()
212 212 MqttPacket response(connAck);
213 213 sender->setReadyForDisconnect();
214 214 sender->writeMqttPacket(response);
215   - std::cout << "ClientID has + or # in the id: " << sender->repr() << std::endl;
  215 + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str());
216 216 return;
217 217 }
218 218  
219 219 sender->setClientProperties(client_id, username, true, keep_alive);
220 220 sender->setWill(will_topic, will_payload, will_retain, will_qos);
221   - sender->setAuthenticated(true);
222 221  
223   - std::cout << "Connect: " << sender->repr() << std::endl;
224   -
225   - ConnAck connAck(ConnAckReturnCodes::Accepted);
226   - MqttPacket response(connAck);
227   - sender->writeMqttPacket(response);
  222 + if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success)
  223 + {
  224 + sender->setAuthenticated(true);
  225 + ConnAck connAck(ConnAckReturnCodes::Accepted);
  226 + MqttPacket response(connAck);
  227 + sender->writeMqttPacket(response);
  228 + logger->logf(LOG_NOTICE, "User %s logged in successfully", username.c_str());
  229 + }
  230 + else
  231 + {
  232 + ConnAck connDeny(ConnAckReturnCodes::NotAuthorized);
  233 + MqttPacket response(connDeny);
  234 + sender->setReadyForDisconnect();
  235 + sender->writeMqttPacket(response);
  236 + logger->logf(LOG_NOTICE, "User %s access denied", username.c_str());
  237 + }
228 238 }
229 239 else
230 240 {
... ... @@ -232,7 +242,7 @@ void MqttPacket::handleConnect()
232 242 }
233 243 }
234 244  
235   -void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore)
  245 +void MqttPacket::handleSubscribe()
236 246 {
237 247 uint16_t packet_id = readTwoBytesToUInt16();
238 248  
... ... @@ -245,7 +255,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio
245 255 if (qos > 0)
246 256 throw NotImplementedException("QoS not implemented");
247 257 std::cout << sender->repr() << " Subscribed to " << topic << std::endl;
248   - subscriptionStore->addSubscription(sender, topic);
  258 + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic);
249 259 subs_reponse_codes.push_back(qos);
250 260 }
251 261  
... ... @@ -254,7 +264,7 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio
254 264 sender->writeMqttPacket(response);
255 265 }
256 266  
257   -void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore)
  267 +void MqttPacket::handlePublish()
258 268 {
259 269 uint16_t variable_header_length = readTwoBytesToUInt16();
260 270  
... ... @@ -287,7 +297,7 @@ void MqttPacket::handlePublish(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionS
287 297 size_t payload_length = remainingAfterPos();
288 298 std::string payload(readBytes(payload_length), payload_length);
289 299  
290   - subscriptionStore->setRetainedMessage(topic, payload, qos);
  300 + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(topic, payload, qos);
291 301 }
292 302  
293 303 // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
... ... @@ -295,7 +305,7 @@ void MqttPacket::handlePublish(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionS
295 305 bites[0] &= 0b11110110;
296 306  
297 307 // For the existing clients, we can just write the same packet back out, with our small alterations.
298   - subscriptionStore->queuePacketAtSubscribers(topic, *this, sender);
  308 + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender);
299 309 }
300 310  
301 311 void MqttPacket::calculateRemainingLength()
... ...
mqttpacket.h
... ... @@ -13,6 +13,7 @@
13 13 #include "types.h"
14 14 #include "subscriptionstore.h"
15 15 #include "cirbuf.h"
  16 +#include "logger.h"
16 17  
17 18 struct RemainingLength
18 19 {
... ... @@ -31,6 +32,7 @@ class MqttPacket
31 32 char first_byte = 0;
32 33 size_t pos = 0;
33 34 ProtocolVersion protocolVersion = ProtocolVersion::None;
  35 + Logger *logger = Logger::getInstance();
34 36  
35 37 char *readBytes(size_t length);
36 38 char readByte();
... ... @@ -50,11 +52,11 @@ public:
50 52 MqttPacket(const SubAck &subAck);
51 53 MqttPacket(const Publish &publish);
52 54  
53   - void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore);
  55 + void handle();
54 56 void handleConnect();
55   - void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore);
  57 + void handleSubscribe();
56 58 void handlePing();
57   - void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore);
  59 + void handlePublish();
58 60  
59 61 size_t getSizeIncludingNonPresentHeader() const;
60 62 const std::vector<char> &getBites() const { return bites; }
... ...
threaddata.cpp
... ... @@ -11,10 +11,6 @@ ThreadData::ThreadData(int threadnr, std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscri
11 11 logger = Logger::getInstance();
12 12  
13 13 epollfd = check<std::runtime_error>(epoll_create(999));
14   -
15   - authPlugin.loadPlugin(confFileParser.getAuthPluginPath());
16   - authPlugin.init();
17   - authPlugin.securityInit(false);
18 14 }
19 15  
20 16 void ThreadData::moveThreadHere(std::thread &&thread)
... ... @@ -100,6 +96,13 @@ bool ThreadData::doKeepAliveCheck()
100 96 return true;
101 97 }
102 98  
  99 +void ThreadData::initAuthPlugin()
  100 +{
  101 + authPlugin.loadPlugin(confFileParser.getAuthPluginPath());
  102 + authPlugin.init();
  103 + authPlugin.securityInit(false);
  104 +}
  105 +
103 106 void ThreadData::reload()
104 107 {
105 108 try
... ...
threaddata.h
... ... @@ -26,10 +26,10 @@ class ThreadData
26 26 std::mutex clients_by_fd_mutex;
27 27 std::shared_ptr<SubscriptionStore> subscriptionStore;
28 28 ConfigFileParser &confFileParser;
29   - AuthPlugin authPlugin;
30 29 Logger *logger;
31 30  
32 31 public:
  32 + AuthPlugin authPlugin;
33 33 bool running = true;
34 34 std::thread thread;
35 35 int threadnr = 0;
... ... @@ -48,6 +48,7 @@ public:
48 48 std::shared_ptr<SubscriptionStore> &getSubscriptionStore();
49 49  
50 50 bool doKeepAliveCheck();
  51 + void initAuthPlugin();
51 52 void reload();
52 53 };
53 54  
... ...