Commit 9e33ebdacf9cbc196ea7eb08538ed8e53bfef564

Authored by Wiebe Cazemier
1 parent 327f0c38

Add SSL support

It also contains some related improvements that I needed:

* Show disconnect reason
* Fix the while condition for doing write() to avoid an unnecessary call
* Config reloading logic
CMakeLists.txt
@@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.5) @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.5)
2 2
3 project(FlashMQ LANGUAGES CXX) 3 project(FlashMQ LANGUAGES CXX)
4 4
  5 +add_definitions(-DOPENSSL_API_COMPAT=0x10100000L)
  6 +
5 set(CMAKE_CXX_STANDARD 11) 7 set(CMAKE_CXX_STANDARD 11)
6 set(CMAKE_CXX_STANDARD_REQUIRED ON) 8 set(CMAKE_CXX_STANDARD_REQUIRED ON)
7 9
@@ -24,6 +26,7 @@ add_executable(FlashMQ @@ -24,6 +26,7 @@ add_executable(FlashMQ
24 logger.cpp 26 logger.cpp
25 authplugin.cpp 27 authplugin.cpp
26 configfileparser.cpp 28 configfileparser.cpp
  29 + sslctxmanager.cpp
27 ) 30 )
28 31
29 -target_link_libraries(FlashMQ pthread dl) 32 +target_link_libraries(FlashMQ pthread dl ssl crypto)
client.cpp
@@ -7,8 +7,9 @@ @@ -7,8 +7,9 @@
7 7
8 #include "logger.h" 8 #include "logger.h"
9 9
10 -Client::Client(int fd, ThreadData_p threadData) : 10 +Client::Client(int fd, ThreadData_p threadData, SSL *ssl) :
11 fd(fd), 11 fd(fd),
  12 + ssl(ssl),
12 readbuf(CLIENT_BUFFER_SIZE), 13 readbuf(CLIENT_BUFFER_SIZE),
13 writebuf(CLIENT_BUFFER_SIZE), 14 writebuf(CLIENT_BUFFER_SIZE),
14 threadData(threadData) 15 threadData(threadData)
@@ -19,13 +20,71 @@ Client::Client(int fd, ThreadData_p threadData) : @@ -19,13 +20,71 @@ Client::Client(int fd, ThreadData_p threadData) :
19 20
20 Client::~Client() 21 Client::~Client()
21 { 22 {
22 - Logger *logger = Logger::getInstance();  
23 - logger->logf(LOG_NOTICE, "Removing client '%s'", repr().c_str()); 23 + if (disconnectReason.empty())
  24 + disconnectReason = "not specified";
  25 +
  26 + logger->logf(LOG_NOTICE, "Removing client '%s'. Reason: %s", repr().c_str(), disconnectReason.c_str());
24 if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) 27 if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0)
25 logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); 28 logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno));
  29 + if (ssl)
  30 + {
  31 + // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done
  32 + // in the destructor.
  33 + SSL_free(ssl);
  34 + }
26 close(fd); 35 close(fd);
27 } 36 }
28 37
  38 +bool Client::isSslAccepted() const
  39 +{
  40 + return sslAccepted;
  41 +}
  42 +
  43 +bool Client::isSsl() const
  44 +{
  45 + return this->ssl != nullptr;
  46 +}
  47 +
  48 +bool Client::getSslReadWantsWrite() const
  49 +{
  50 + return this->sslReadWantsWrite;
  51 +}
  52 +
  53 +bool Client::getSslWriteWantsRead() const
  54 +{
  55 + return this->sslWriteWantsRead;
  56 +}
  57 +
  58 +void Client::startOrContinueSslAccept()
  59 +{
  60 + ERR_clear_error();
  61 + int accepted = SSL_accept(ssl);
  62 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  63 + if (accepted <= 0)
  64 + {
  65 + int err = SSL_get_error(ssl, accepted);
  66 +
  67 + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  68 + {
  69 + setReadyForWriting(err == SSL_ERROR_WANT_WRITE);
  70 + return;
  71 + }
  72 +
  73 + unsigned long error_code = ERR_get_error();
  74 +
  75 + ERR_error_string(error_code, sslErrorBuf);
  76 + std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  77 +
  78 + if (error_code == OPENSSL_WRONG_VERSION_NUMBER)
  79 + errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket.";
  80 +
  81 + //ERR_print_errors_cb(logSslError, NULL);
  82 + throw std::runtime_error("Problem accepting SSL socket: " + errorMsg);
  83 + }
  84 + setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake
  85 + sslAccepted = true;
  86 +}
  87 +
29 // Causes future activity on the client to cause a disconnect. 88 // Causes future activity on the client to cause a disconnect.
30 void Client::markAsDisconnecting() 89 void Client::markAsDisconnecting()
31 { 90 {
@@ -35,29 +94,102 @@ void Client::markAsDisconnecting() @@ -35,29 +94,102 @@ void Client::markAsDisconnecting()
35 disconnecting = true; 94 disconnecting = true;
36 } 95 }
37 96
  97 +// SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL
  98 +// socket. This wrapper unifies behavor for the caller.
  99 +ssize_t Client::readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error)
  100 +{
  101 + *error = IoWrapResult::Success;
  102 + ssize_t n = 0;
  103 + if (!ssl)
  104 + {
  105 + n = read(fd, buf, nbytes);
  106 + if (n < 0)
  107 + {
  108 + if (errno == EINTR)
  109 + *error = IoWrapResult::Interrupted;
  110 + else if (errno == EAGAIN || errno == EWOULDBLOCK)
  111 + *error = IoWrapResult::Wouldblock;
  112 + else
  113 + check<std::runtime_error>(n);
  114 + }
  115 + else if (n == 0)
  116 + {
  117 + *error = IoWrapResult::Disconnected;
  118 + }
  119 + }
  120 + else
  121 + {
  122 + this->sslReadWantsWrite = false;
  123 + ERR_clear_error();
  124 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  125 + n = SSL_read(ssl, buf, nbytes);
  126 +
  127 + if (n <= 0)
  128 + {
  129 + int err = SSL_get_error(ssl, n);
  130 + unsigned long error_code = ERR_get_error();
  131 +
  132 + // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL.
  133 + if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0))
  134 + {
  135 + *error = IoWrapResult::Disconnected;
  136 + }
  137 + else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  138 + {
  139 + *error = IoWrapResult::Wouldblock;
  140 + if (err == SSL_ERROR_WANT_WRITE)
  141 + {
  142 + sslReadWantsWrite = true;
  143 + setReadyForWriting(true);
  144 + }
  145 + n = -1;
  146 + }
  147 + else
  148 + {
  149 + if (err == SSL_ERROR_SYSCALL)
  150 + {
  151 + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
  152 + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
  153 + // implies EINTR is not included?
  154 + if (errno == EINTR)
  155 + *error = IoWrapResult::Interrupted;
  156 + else
  157 + {
  158 + char *err = strerror(errno);
  159 + std::string msg(err);
  160 + throw std::runtime_error("SSL read error: " + msg);
  161 + }
  162 + }
  163 + ERR_error_string(error_code, sslErrorBuf);
  164 + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  165 + ERR_print_errors_cb(logSslError, NULL);
  166 + throw std::runtime_error("SSL socket error reading: " + errorString);
  167 + }
  168 + }
  169 + }
  170 +
  171 + return n;
  172 +}
  173 +
38 // false means any kind of error we want to get rid of the client for. 174 // false means any kind of error we want to get rid of the client for.
39 bool Client::readFdIntoBuffer() 175 bool Client::readFdIntoBuffer()
40 { 176 {
41 if (disconnecting) 177 if (disconnecting)
42 return false; 178 return false;
43 179
  180 + IoWrapResult error = IoWrapResult::Success;
44 int n = 0; 181 int n = 0;
45 - while (readbuf.freeSpace() > 0 && (n = read(fd, readbuf.headPtr(), readbuf.maxWriteSize())) != 0) 182 + while (readbuf.freeSpace() > 0 && (n = readWrap(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0)
46 { 183 {
47 if (n > 0) 184 if (n > 0)
48 { 185 {
49 readbuf.advanceHead(n); 186 readbuf.advanceHead(n);
50 } 187 }
51 188
52 - if (n < 0)  
53 - {  
54 - if (errno == EINTR)  
55 - continue;  
56 - if (errno == EAGAIN || errno == EWOULDBLOCK)  
57 - break;  
58 - else  
59 - check<std::runtime_error>(n);  
60 - } 189 + if (error == IoWrapResult::Interrupted)
  190 + continue;
  191 + if (error == IoWrapResult::Wouldblock)
  192 + break;
61 193
62 // Make sure we either always have enough space for a next call of this method, or stop reading the fd. 194 // Make sure we either always have enough space for a next call of this method, or stop reading the fd.
63 if (readbuf.freeSpace() == 0) 195 if (readbuf.freeSpace() == 0)
@@ -74,7 +206,7 @@ bool Client::readFdIntoBuffer() @@ -74,7 +206,7 @@ bool Client::readFdIntoBuffer()
74 } 206 }
75 } 207 }
76 208
77 - if (n == 0) // client disconnected. 209 + if (error == IoWrapResult::Disconnected)
78 { 210 {
79 return false; 211 return false;
80 } 212 }
@@ -174,6 +306,94 @@ void Client::writePingResp() @@ -174,6 +306,94 @@ void Client::writePingResp()
174 setReadyForWriting(true); 306 setReadyForWriting(true);
175 } 307 }
176 308
  309 +// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller.
  310 +ssize_t Client::writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error)
  311 +{
  312 + *error = IoWrapResult::Success;
  313 + ssize_t n = 0;
  314 +
  315 + if (!ssl)
  316 + {
  317 + // A write on a socket with count=0 is unspecified.
  318 + assert(nbytes > 0);
  319 +
  320 + n = write(fd, buf, nbytes);
  321 + if (n < 0)
  322 + {
  323 + if (errno == EINTR)
  324 + *error = IoWrapResult::Interrupted;
  325 + else if (errno == EAGAIN || errno == EWOULDBLOCK)
  326 + *error = IoWrapResult::Wouldblock;
  327 + else
  328 + check<std::runtime_error>(n);
  329 + }
  330 + }
  331 + else
  332 + {
  333 + const void *buf_ = buf;
  334 + size_t nbytes_ = nbytes;
  335 +
  336 + /*
  337 + * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned
  338 + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments
  339 + */
  340 + if (this->incompleteSslWrite.hasPendingWrite())
  341 + {
  342 + buf_ = this->incompleteSslWrite.buf;
  343 + nbytes_ = this->incompleteSslWrite.nbytes;
  344 + }
  345 +
  346 + // OpenSSL: "You should not call SSL_write() with num=0, it will return an error"
  347 + assert(nbytes_ > 0);
  348 +
  349 + this->sslWriteWantsRead = false;
  350 + this->incompleteSslWrite.reset();
  351 +
  352 + ERR_clear_error();
  353 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  354 + n = SSL_write(ssl, buf_, nbytes_);
  355 +
  356 + if (n <= 0)
  357 + {
  358 + int err = SSL_get_error(ssl, n);
  359 + unsigned long error_code = ERR_get_error();
  360 + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  361 + {
  362 + logger->logf(LOG_DEBUG, "Write is incomplete: %d", err);
  363 + *error = IoWrapResult::Wouldblock;
  364 + IncompleteSslWrite sslAction(buf_, nbytes_);
  365 + this->incompleteSslWrite = sslAction;
  366 + if (err == SSL_ERROR_WANT_READ)
  367 + this->sslWriteWantsRead = true;
  368 + n = 0;
  369 + }
  370 + else
  371 + {
  372 + if (err == SSL_ERROR_SYSCALL)
  373 + {
  374 + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
  375 + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
  376 + // implies EINTR is not included?
  377 + if (errno == EINTR)
  378 + *error = IoWrapResult::Interrupted;
  379 + else
  380 + {
  381 + char *err = strerror(errno);
  382 + std::string msg(err);
  383 + throw std::runtime_error(msg);
  384 + }
  385 + }
  386 + ERR_error_string(error_code, sslErrorBuf);
  387 + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  388 + ERR_print_errors_cb(logSslError, NULL);
  389 + throw std::runtime_error("SSL socket error writing: " + errorString);
  390 + }
  391 + }
  392 + }
  393 +
  394 + return n;
  395 +}
  396 +
177 bool Client::writeBufIntoFd() 397 bool Client::writeBufIntoFd()
178 { 398 {
179 std::unique_lock<std::mutex> lock(writeBufMutex, std::try_to_lock); 399 std::unique_lock<std::mutex> lock(writeBufMutex, std::try_to_lock);
@@ -184,24 +404,23 @@ bool Client::writeBufIntoFd() @@ -184,24 +404,23 @@ bool Client::writeBufIntoFd()
184 if (disconnecting) 404 if (disconnecting)
185 return false; 405 return false;
186 406
  407 + IoWrapResult error = IoWrapResult::Success;
187 int n; 408 int n;
188 - while ((n = write(fd, writebuf.tailPtr(), writebuf.maxReadSize())) != 0) 409 + while (writebuf.usedBytes() > 0 || incompleteSslWrite.hasPendingWrite())
189 { 410 {
  411 + n = writeWrap(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error);
  412 +
190 if (n > 0) 413 if (n > 0)
191 writebuf.advanceTail(n); 414 writebuf.advanceTail(n);
192 - if (n < 0)  
193 - {  
194 - if (errno == EINTR)  
195 - continue;  
196 - if (errno == EAGAIN || errno == EWOULDBLOCK)  
197 - break;  
198 - else  
199 - check<std::runtime_error>(n);  
200 - } 415 +
  416 + if (error == IoWrapResult::Interrupted)
  417 + continue;
  418 + if (error == IoWrapResult::Wouldblock)
  419 + break;
201 } 420 }
202 421
203 const bool bufferHasData = writebuf.usedBytes() > 0; 422 const bool bufferHasData = writebuf.usedBytes() > 0;
204 - setReadyForWriting(bufferHasData); 423 + setReadyForWriting(bufferHasData || error == IoWrapResult::Wouldblock);
205 424
206 if (!bufferHasData) 425 if (!bufferHasData)
207 { 426 {
@@ -236,11 +455,22 @@ bool Client::keepAliveExpired() @@ -236,11 +455,22 @@ bool Client::keepAliveExpired()
236 return result; 455 return result;
237 } 456 }
238 457
  458 +std::string Client::getKeepAliveInfoString() const
  459 +{
  460 + std::string s = "authenticated: " + std::to_string(authenticated) + ", keep-alive: " + std::to_string(keepalive) + "s, last activity "
  461 + + std::to_string(time(NULL) - lastActivity) + " seconds ago.";
  462 + return s;
  463 +}
  464 +
  465 +// Call this from a place you know the writeBufMutex is locked, or we're still only doing SSL accept.
239 void Client::setReadyForWriting(bool val) 466 void Client::setReadyForWriting(bool val)
240 { 467 {
241 if (disconnecting) 468 if (disconnecting)
242 return; 469 return;
243 470
  471 + if (sslReadWantsWrite)
  472 + val = true;
  473 +
244 if (val == this->readyForWriting) 474 if (val == this->readyForWriting)
245 return; 475 return;
246 476
@@ -360,17 +590,29 @@ std::shared_ptr&lt;Session&gt; Client::getSession() @@ -360,17 +590,29 @@ std::shared_ptr&lt;Session&gt; Client::getSession()
360 return this->session; 590 return this->session;
361 } 591 }
362 592
  593 +void Client::setDisconnectReason(const std::string &reason)
  594 +{
  595 + // If we have a chain of errors causing this to be set, probably the first one is the most interesting.
  596 + if (!disconnectReason.empty())
  597 + return;
363 598
  599 + this->disconnectReason = reason;
  600 +}
364 601
  602 +IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) :
  603 + buf(buf),
  604 + nbytes(nbytes)
  605 +{
365 606
  607 +}
366 608
  609 +void IncompleteSslWrite::reset()
  610 +{
  611 + buf = nullptr;
  612 + nbytes = 0;
  613 +}
367 614
368 -  
369 -  
370 -  
371 -  
372 -  
373 -  
374 -  
375 -  
376 - 615 +bool IncompleteSslWrite::hasPendingWrite()
  616 +{
  617 + return buf != nullptr;
  618 +}
client.h
@@ -8,6 +8,8 @@ @@ -8,6 +8,8 @@
8 #include <iostream> 8 #include <iostream>
9 #include <time.h> 9 #include <time.h>
10 10
  11 +#include <openssl/ssl.h>
  12 +
11 #include "forward_declarations.h" 13 #include "forward_declarations.h"
12 14
13 #include "threaddata.h" 15 #include "threaddata.h"
@@ -15,14 +17,50 @@ @@ -15,14 +17,50 @@
15 #include "exceptions.h" 17 #include "exceptions.h"
16 #include "cirbuf.h" 18 #include "cirbuf.h"
17 19
  20 +#include <openssl/ssl.h>
  21 +#include <openssl/err.h>
18 22
19 #define CLIENT_BUFFER_SIZE 1024 // Must be power of 2 23 #define CLIENT_BUFFER_SIZE 1024 // Must be power of 2
20 #define MAX_PACKET_SIZE 268435461 // 256 MB + 5 24 #define MAX_PACKET_SIZE 268435461 // 256 MB + 5
21 #define MQTT_HEADER_LENGH 2 25 #define MQTT_HEADER_LENGH 2
22 26
  27 +#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256.
  28 +#define OPENSSL_WRONG_VERSION_NUMBER 336130315
  29 +
  30 +enum class IoWrapResult
  31 +{
  32 + Success = 0,
  33 + Interrupted = 1,
  34 + Wouldblock = 2,
  35 + Disconnected = 3,
  36 + Error = 4
  37 +};
  38 +
  39 +/*
  40 + * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned
  41 + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments"
  42 + */
  43 +struct IncompleteSslWrite
  44 +{
  45 + const void *buf = nullptr;
  46 + size_t nbytes = 0;
  47 +
  48 + IncompleteSslWrite() = default;
  49 + IncompleteSslWrite(const void *buf, size_t nbytes);
  50 + bool hasPendingWrite();
  51 +
  52 + void reset();
  53 +};
  54 +
  55 +// TODO: give accepted addr, for showing in logs
23 class Client 56 class Client
24 { 57 {
25 int fd; 58 int fd;
  59 + SSL *ssl = nullptr;
  60 + bool sslAccepted = false;
  61 + IncompleteSslWrite incompleteSslWrite;
  62 + bool sslReadWantsWrite = false;
  63 + bool sslWriteWantsRead = false;
26 64
27 CirBuf readbuf; 65 CirBuf readbuf;
28 uint8_t readBufIsZeroCount = 0; 66 uint8_t readBufIsZeroCount = 0;
@@ -36,6 +74,7 @@ class Client @@ -36,6 +74,7 @@ class Client
36 bool readyForReading = true; 74 bool readyForReading = true;
37 bool disconnectWhenBytesWritten = false; 75 bool disconnectWhenBytesWritten = false;
38 bool disconnecting = false; 76 bool disconnecting = false;
  77 + std::string disconnectReason;
39 time_t lastActivity = time(NULL); 78 time_t lastActivity = time(NULL);
40 79
41 std::string clientid; 80 std::string clientid;
@@ -53,18 +92,27 @@ class Client @@ -53,18 +92,27 @@ class Client
53 92
54 std::shared_ptr<Session> session; 93 std::shared_ptr<Session> session;
55 94
  95 + Logger *logger = Logger::getInstance();
  96 +
56 97
57 void setReadyForWriting(bool val); 98 void setReadyForWriting(bool val);
58 void setReadyForReading(bool val); 99 void setReadyForReading(bool val);
59 100
60 public: 101 public:
61 - Client(int fd, ThreadData_p threadData); 102 + Client(int fd, ThreadData_p threadData, SSL *ssl);
62 Client(const Client &other) = delete; 103 Client(const Client &other) = delete;
63 Client(Client &&other) = delete; 104 Client(Client &&other) = delete;
64 ~Client(); 105 ~Client();
65 106
66 int getFd() { return fd;} 107 int getFd() { return fd;}
  108 + bool isSslAccepted() const;
  109 + bool isSsl() const;
  110 + bool getSslReadWantsWrite() const;
  111 + bool getSslWriteWantsRead() const;
  112 +
  113 + void startOrContinueSslAccept();
67 void markAsDisconnecting(); 114 void markAsDisconnecting();
  115 + ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error);
68 bool readFdIntoBuffer(); 116 bool readFdIntoBuffer();
69 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender); 117 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender);
70 void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); 118 void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
@@ -77,10 +125,12 @@ public: @@ -77,10 +125,12 @@ public:
77 bool getCleanSession() { return cleanSession; } 125 bool getCleanSession() { return cleanSession; }
78 void assignSession(std::shared_ptr<Session> &session); 126 void assignSession(std::shared_ptr<Session> &session);
79 std::shared_ptr<Session> getSession(); 127 std::shared_ptr<Session> getSession();
  128 + void setDisconnectReason(const std::string &reason);
80 129
81 void writePingResp(); 130 void writePingResp();
82 void writeMqttPacket(const MqttPacket &packet); 131 void writeMqttPacket(const MqttPacket &packet);
83 void writeMqttPacketAndBlameThisClient(const MqttPacket &packet); 132 void writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
  133 + ssize_t writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error);
84 bool writeBufIntoFd(); 134 bool writeBufIntoFd();
85 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } 135 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; }
86 136
@@ -89,6 +139,7 @@ public: @@ -89,6 +139,7 @@ public:
89 139
90 std::string repr(); 140 std::string repr();
91 bool keepAliveExpired(); 141 bool keepAliveExpired();
  142 + std::string getKeepAliveInfoString() const;
92 143
93 }; 144 };
94 145
configfileparser.cpp
@@ -5,10 +5,15 @@ @@ -5,10 +5,15 @@
5 #include <sstream> 5 #include <sstream>
6 #include "fstream" 6 #include "fstream"
7 7
  8 +#include "openssl/ssl.h"
  9 +#include "openssl/err.h"
  10 +
8 #include "exceptions.h" 11 #include "exceptions.h"
9 #include "utils.h" 12 #include "utils.h"
10 #include <regex> 13 #include <regex>
11 14
  15 +#include "logger.h"
  16 +
12 17
13 mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value) 18 mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value)
14 { 19 {
@@ -41,24 +46,66 @@ AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map&lt;std::string, std:: @@ -41,24 +46,66 @@ AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map&lt;std::string, std::
41 } 46 }
42 } 47 }
43 48
  49 +void ConfigFileParser::checkFileAccess(const std::string &key, const std::string &pathToCheck) const
  50 +{
  51 + if (access(pathToCheck.c_str(), R_OK) != 0)
  52 + {
  53 + std::ostringstream oss;
  54 + oss << "Error for '" << key << "': " << pathToCheck << " is not there or not readable";
  55 + throw ConfigFileException(oss.str());
  56 + }
  57 +}
  58 +
  59 +// Using a separate ssl context to test, because it's the easiest way to load certs and key atomitcally.
  60 +void ConfigFileParser::testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const
  61 +{
  62 + if (portNr == 0)
  63 + return;
  64 +
  65 + if (fullchain.empty() && privkey.empty())
  66 + throw ConfigFileException("No privkey and fullchain specified.");
  67 +
  68 + if (fullchain.empty())
  69 + throw ConfigFileException("No private key specified for fullchain");
  70 +
  71 + if (privkey.empty())
  72 + throw ConfigFileException("No fullchain specified for private key");
  73 +
  74 + SslCtxManager sslCtx;
  75 + if (SSL_CTX_use_certificate_file(sslCtx.get(), fullchain.c_str(), SSL_FILETYPE_PEM) != 1)
  76 + {
  77 + ERR_print_errors_cb(logSslError, NULL);
  78 + throw ConfigFileException("Error loading full chain " + fullchain);
  79 + }
  80 + if (SSL_CTX_use_PrivateKey_file(sslCtx.get(), privkey.c_str(), SSL_FILETYPE_PEM) != 1)
  81 + {
  82 + ERR_print_errors_cb(logSslError, NULL);
  83 + throw ConfigFileException("Error loading private key " + privkey);
  84 + }
  85 + if (SSL_CTX_check_private_key(sslCtx.get()) != 1)
  86 + {
  87 + ERR_print_errors_cb(logSslError, NULL);
  88 + throw ConfigFileException("Private key and certificate don't match.");
  89 + }
  90 +}
  91 +
44 ConfigFileParser::ConfigFileParser(const std::string &path) : 92 ConfigFileParser::ConfigFileParser(const std::string &path) :
45 path(path) 93 path(path)
46 { 94 {
47 validKeys.insert("auth_plugin"); 95 validKeys.insert("auth_plugin");
48 validKeys.insert("log_file"); 96 validKeys.insert("log_file");
  97 + validKeys.insert("listen_port");
  98 + validKeys.insert("ssl_listen_port");
  99 + validKeys.insert("fullchain");
  100 + validKeys.insert("privkey");
49 } 101 }
50 102
51 -void ConfigFileParser::loadFile() 103 +void ConfigFileParser::loadFile(bool test)
52 { 104 {
53 if (path.empty()) 105 if (path.empty())
54 return; 106 return;
55 107
56 - if (access(path.c_str(), R_OK) != 0)  
57 - {  
58 - std::ostringstream oss;  
59 - oss << "Error: " << path << " is not there or not readable";  
60 - throw ConfigFileException(oss.str());  
61 - } 108 + checkFileAccess("application config file", path);
62 109
63 std::ifstream infile(path, std::ios::in); 110 std::ifstream infile(path, std::ios::in);
64 111
@@ -99,6 +146,9 @@ void ConfigFileParser::loadFile() @@ -99,6 +146,9 @@ void ConfigFileParser::loadFile()
99 authOpts.clear(); 146 authOpts.clear();
100 authOptCompatWrap.reset(); 147 authOptCompatWrap.reset();
101 148
  149 + std::string sslFullChainTmp;
  150 + std::string sslPrivkeyTmp;
  151 +
102 // Then once we know the config file is valid, process it. 152 // Then once we know the config file is valid, process it.
103 for (std::string &line : lines) 153 for (std::string &line : lines)
104 { 154 {
@@ -124,22 +174,68 @@ void ConfigFileParser::loadFile() @@ -124,22 +174,68 @@ void ConfigFileParser::loadFile()
124 if (valid_key_it == validKeys.end()) 174 if (valid_key_it == validKeys.end())
125 { 175 {
126 std::ostringstream oss; 176 std::ostringstream oss;
127 - oss << "Config key '" << key << "' is not valid"; 177 + oss << "Config key '" << key << "' is not valid. This error should have been cought before. Bug?";
128 throw ConfigFileException(oss.str()); 178 throw ConfigFileException(oss.str());
129 } 179 }
130 180
131 if (key == "auth_plugin") 181 if (key == "auth_plugin")
132 { 182 {
133 - this->authPluginPath = value; 183 + checkFileAccess(key, value);
  184 + if (!test)
  185 + this->authPluginPath = value;
134 } 186 }
135 187
136 if (key == "log_file") 188 if (key == "log_file")
137 { 189 {
138 - this->logPath = value; 190 + checkFileAccess(key, value);
  191 + if (!test)
  192 + this->logPath = value;
  193 + }
  194 +
  195 + if (key == "fullchain")
  196 + {
  197 + checkFileAccess(key, value);
  198 + sslFullChainTmp = value;
  199 + }
  200 +
  201 + if (key == "privkey")
  202 + {
  203 + checkFileAccess(key, value);
  204 + sslPrivkeyTmp = value;
  205 + }
  206 +
  207 + try
  208 + {
  209 + // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners.
  210 + if (key == "listen_port")
  211 + {
  212 + uint listenportNew = std::stoi(value);
  213 + if (listenPort > 0 && listenPort != listenportNew)
  214 + throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time.");
  215 + listenPort = listenportNew;
  216 + }
  217 +
  218 + // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners.
  219 + if (key == "ssl_listen_port")
  220 + {
  221 + uint sslListenPortNew = std::stoi(value);
  222 + if (sslListenPort > 0 && sslListenPort != sslListenPortNew)
  223 + throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time.");
  224 + sslListenPort = sslListenPortNew;
  225 + }
  226 +
  227 + }
  228 + catch (std::invalid_argument &ex)
  229 + {
  230 + throw ConfigFileException(ex.what());
139 } 231 }
140 } 232 }
141 } 233 }
142 234
  235 + testSsl(sslFullChainTmp, sslPrivkeyTmp, sslListenPort);
  236 + this->sslFullchain = sslFullChainTmp;
  237 + this->sslPrivkey = sslPrivkeyTmp;
  238 +
143 authOptCompatWrap.reset(new AuthOptCompatWrap(authOpts)); 239 authOptCompatWrap.reset(new AuthOptCompatWrap(authOpts));
144 } 240 }
145 241
configfileparser.h
@@ -7,6 +7,8 @@ @@ -7,6 +7,8 @@
7 #include <vector> 7 #include <vector>
8 #include <memory> 8 #include <memory>
9 9
  10 +#include "sslctxmanager.h"
  11 +
10 struct mosquitto_auth_opt 12 struct mosquitto_auth_opt
11 { 13 {
12 char *key = nullptr; 14 char *key = nullptr;
@@ -38,15 +40,20 @@ class ConfigFileParser @@ -38,15 +40,20 @@ class ConfigFileParser
38 std::unique_ptr<AuthOptCompatWrap> authOptCompatWrap; 40 std::unique_ptr<AuthOptCompatWrap> authOptCompatWrap;
39 41
40 42
41 - 43 + void checkFileAccess(const std::string &key, const std::string &pathToCheck) const;
  44 + void testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const;
42 public: 45 public:
43 ConfigFileParser(const std::string &path); 46 ConfigFileParser(const std::string &path);
44 - void loadFile(); 47 + void loadFile(bool test);
45 AuthOptCompatWrap &getAuthOptsCompat(); 48 AuthOptCompatWrap &getAuthOptsCompat();
46 49
47 // Actual config options with their defaults. Just making them public, I can retrain myself misuing them. 50 // Actual config options with their defaults. Just making them public, I can retrain myself misuing them.
48 std::string authPluginPath; 51 std::string authPluginPath;
49 std::string logPath; 52 std::string logPath;
  53 + std::string sslFullchain;
  54 + std::string sslPrivkey;
  55 + uint listenPort = 1883;
  56 + uint sslListenPort = 0;
50 }; 57 };
51 58
52 #endif // CONFIGFILEPARSER_H 59 #endif // CONFIGFILEPARSER_H
logger.cpp
@@ -121,3 +121,9 @@ void Logger::logf(int level, const char *str, va_list valist) @@ -121,3 +121,9 @@ void Logger::logf(int level, const char *str, va_list valist)
121 #endif 121 #endif
122 } 122 }
123 } 123 }
  124 +
  125 +int logSslError(const char *str, size_t len, void *u)
  126 +{
  127 + Logger *logger = Logger::getInstance();
  128 + logger->logf(LOG_ERR, str);
  129 +}
logger.h
@@ -13,6 +13,8 @@ @@ -13,6 +13,8 @@
13 #define LOG_ERR 0x08 13 #define LOG_ERR 0x08
14 #define LOG_DEBUG 0x10 14 #define LOG_DEBUG 0x10
15 15
  16 +int logSslError(const char *str, size_t len, void *u);
  17 +
16 class Logger 18 class Logger
17 { 19 {
18 static Logger *instance; 20 static Logger *instance;
main.cpp
@@ -3,6 +3,7 @@ @@ -3,6 +3,7 @@
3 #include <memory> 3 #include <memory>
4 #include <string.h> 4 #include <string.h>
5 #include <sys/resource.h> 5 #include <sys/resource.h>
  6 +#include <openssl/ssl.h>
6 7
7 #include "mainapp.h" 8 #include "mainapp.h"
8 9
mainapp.cpp
@@ -5,6 +5,11 @@ @@ -5,6 +5,11 @@
5 #include <unistd.h> 5 #include <unistd.h>
6 #include <stdio.h> 6 #include <stdio.h>
7 7
  8 +#include <openssl/ssl.h>
  9 +#include <openssl/err.h>
  10 +
  11 +#include "logger.h"
  12 +
8 #define MAX_EVENTS 1024 13 #define MAX_EVENTS 1024
9 #define NR_OF_THREADS 4 14 #define NR_OF_THREADS 4
10 15
@@ -60,18 +65,30 @@ void do_thread_work(ThreadData *threadData) @@ -60,18 +65,30 @@ void do_thread_work(ThreadData *threadData)
60 { 65 {
61 try 66 try
62 { 67 {
63 - if (cur_ev.events & EPOLLIN) 68 + if (cur_ev.events & (EPOLLERR | EPOLLHUP))
  69 + {
  70 + client->setDisconnectReason("epoll says socket is in ERR or HUP state.");
  71 + threadData->removeClient(client);
  72 + continue;
  73 + }
  74 + if (client->isSsl() && !client->isSslAccepted())
  75 + {
  76 + client->startOrContinueSslAccept();
  77 + continue;
  78 + }
  79 + if ((cur_ev.events & EPOLLIN) || ((cur_ev.events & EPOLLOUT) && client->getSslReadWantsWrite()))
64 { 80 {
65 bool readSuccess = client->readFdIntoBuffer(); 81 bool readSuccess = client->readFdIntoBuffer();
66 client->bufferToMqttPackets(packetQueueIn, client); 82 client->bufferToMqttPackets(packetQueueIn, client);
67 83
68 if (!readSuccess) 84 if (!readSuccess)
69 { 85 {
  86 + client->setDisconnectReason("socket disconnect detected");
70 threadData->removeClient(client); 87 threadData->removeClient(client);
71 continue; 88 continue;
72 } 89 }
73 } 90 }
74 - if (cur_ev.events & EPOLLOUT) 91 + if ((cur_ev.events & EPOLLOUT) || ((cur_ev.events & EPOLLIN) && client->getSslWriteWantsRead()))
75 { 92 {
76 if (!client->writeBufIntoFd()) 93 if (!client->writeBufIntoFd())
77 { 94 {
@@ -85,13 +102,10 @@ void do_thread_work(ThreadData *threadData) @@ -85,13 +102,10 @@ void do_thread_work(ThreadData *threadData)
85 continue; 102 continue;
86 } 103 }
87 } 104 }
88 - if (cur_ev.events & (EPOLLERR | EPOLLHUP))  
89 - {  
90 - threadData->removeClient(client);  
91 - }  
92 } 105 }
93 catch(std::exception &ex) 106 catch(std::exception &ex)
94 { 107 {
  108 + client->setDisconnectReason(ex.what());
95 logger->logf(LOG_ERR, "Packet read/write error: %s. Removing client.", ex.what()); 109 logger->logf(LOG_ERR, "Packet read/write error: %s. Removing client.", ex.what());
96 threadData->removeClient(client); 110 threadData->removeClient(client);
97 } 111 }
@@ -107,6 +121,7 @@ void do_thread_work(ThreadData *threadData) @@ -107,6 +121,7 @@ void do_thread_work(ThreadData *threadData)
107 } 121 }
108 catch (std::exception &ex) 122 catch (std::exception &ex)
109 { 123 {
  124 + packet.getSender()->setDisconnectReason(ex.what());
110 logger->logf(LOG_ERR, "MqttPacket handling error: %s. Removing client.", ex.what()); 125 logger->logf(LOG_ERR, "MqttPacket handling error: %s. Removing client.", ex.what());
111 threadData->removeClient(packet.getSender()); 126 threadData->removeClient(packet.getSender());
112 } 127 }
@@ -133,12 +148,22 @@ void do_thread_work(ThreadData *threadData) @@ -133,12 +148,22 @@ void do_thread_work(ThreadData *threadData)
133 MainApp::MainApp(const std::string &configFilePath) : 148 MainApp::MainApp(const std::string &configFilePath) :
134 subscriptionStore(new SubscriptionStore()) 149 subscriptionStore(new SubscriptionStore())
135 { 150 {
  151 + epollFdAccept = check<std::runtime_error>(epoll_create(999));
136 taskEventFd = eventfd(0, EFD_NONBLOCK); 152 taskEventFd = eventfd(0, EFD_NONBLOCK);
137 153
138 confFileParser.reset(new ConfigFileParser(configFilePath)); 154 confFileParser.reset(new ConfigFileParser(configFilePath));
139 loadConfig(); 155 loadConfig();
140 } 156 }
141 157
  158 +MainApp::~MainApp()
  159 +{
  160 + if (sslctx)
  161 + SSL_CTX_free(sslctx);
  162 +
  163 + if (epollFdAccept > 0)
  164 + close(epollFdAccept);
  165 +}
  166 +
142 void MainApp::doHelp(const char *arg) 167 void MainApp::doHelp(const char *arg)
143 { 168 {
144 puts("FlashMQ - the scalable light-weight MQTT broker"); 169 puts("FlashMQ - the scalable light-weight MQTT broker");
@@ -147,6 +172,7 @@ void MainApp::doHelp(const char *arg) @@ -147,6 +172,7 @@ void MainApp::doHelp(const char *arg)
147 puts(""); 172 puts("");
148 puts(" -h, --help Print help"); 173 puts(" -h, --help Print help");
149 puts(" -c, --config-file <flashmq.conf> Configuration file."); 174 puts(" -c, --config-file <flashmq.conf> Configuration file.");
  175 + puts(" -t, --test-config Test configuration file.");
150 puts(" -V, --version Show version"); 176 puts(" -V, --version Show version");
151 puts(" -l, --license Show license"); 177 puts(" -l, --license Show license");
152 } 178 }
@@ -161,6 +187,56 @@ void MainApp::showLicense() @@ -161,6 +187,56 @@ void MainApp::showLicense()
161 puts("Author: Wiebe Cazemier <wiebe@halfgaar.net>"); 187 puts("Author: Wiebe Cazemier <wiebe@halfgaar.net>");
162 } 188 }
163 189
  190 +void MainApp::setCertAndKeyFromConfig()
  191 +{
  192 + if (sslctx == nullptr)
  193 + return;
  194 +
  195 + if (SSL_CTX_use_certificate_file(sslctx, confFileParser->sslFullchain.c_str(), SSL_FILETYPE_PEM) != 1)
  196 + throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected.");
  197 + if (SSL_CTX_use_PrivateKey_file(sslctx, confFileParser->sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1)
  198 + throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected.");
  199 +}
  200 +
  201 +int MainApp::createListenSocket(int portNr, bool ssl)
  202 +{
  203 + if (portNr <= 0)
  204 + return -2;
  205 +
  206 + int listen_fd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));
  207 +
  208 + // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.
  209 + int optval = 1;
  210 + check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));
  211 +
  212 + int flags = fcntl(listen_fd, F_GETFL);
  213 + check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));
  214 +
  215 + struct sockaddr_in in_addr_plain;
  216 + in_addr_plain.sin_family = AF_INET;
  217 + in_addr_plain.sin_addr.s_addr = INADDR_ANY;
  218 + in_addr_plain.sin_port = htons(portNr);
  219 +
  220 + check<std::runtime_error>(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in)));
  221 + check<std::runtime_error>(listen(listen_fd, 1024));
  222 +
  223 + struct epoll_event ev;
  224 + memset(&ev, 0, sizeof (struct epoll_event));
  225 +
  226 + ev.data.fd = listen_fd;
  227 + ev.events = EPOLLIN;
  228 + check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev));
  229 +
  230 + std::string socketType = "plain";
  231 +
  232 + if (ssl)
  233 + socketType = "SSL";
  234 +
  235 + logger->logf(LOG_NOTICE, "Listening on %s port %d", socketType.c_str(), portNr);
  236 +
  237 + return listen_fd;
  238 +}
  239 +
164 void MainApp::initMainApp(int argc, char *argv[]) 240 void MainApp::initMainApp(int argc, char *argv[])
165 { 241 {
166 if (instance != nullptr) 242 if (instance != nullptr)
@@ -170,6 +246,7 @@ void MainApp::initMainApp(int argc, char *argv[]) @@ -170,6 +246,7 @@ void MainApp::initMainApp(int argc, char *argv[])
170 { 246 {
171 {"help", no_argument, nullptr, 'h'}, 247 {"help", no_argument, nullptr, 'h'},
172 {"config-file", required_argument, nullptr, 'c'}, 248 {"config-file", required_argument, nullptr, 'c'},
  249 + {"test-config", no_argument, nullptr, 't'},
173 {"version", no_argument, nullptr, 'V'}, 250 {"version", no_argument, nullptr, 'V'},
174 {"license", no_argument, nullptr, 'l'}, 251 {"license", no_argument, nullptr, 'l'},
175 {nullptr, 0, nullptr, 0} 252 {nullptr, 0, nullptr, 0}
@@ -179,7 +256,8 @@ void MainApp::initMainApp(int argc, char *argv[]) @@ -179,7 +256,8 @@ void MainApp::initMainApp(int argc, char *argv[])
179 256
180 int option_index = 0; 257 int option_index = 0;
181 int opt; 258 int opt;
182 - while((opt = getopt_long(argc, argv, "hc:Vl", long_options, &option_index)) != -1) 259 + bool testConfig = false;
  260 + while((opt = getopt_long(argc, argv, "hc:Vlt", long_options, &option_index)) != -1)
183 { 261 {
184 switch(opt) 262 switch(opt)
185 { 263 {
@@ -195,12 +273,38 @@ void MainApp::initMainApp(int argc, char *argv[]) @@ -195,12 +273,38 @@ void MainApp::initMainApp(int argc, char *argv[])
195 case 'h': 273 case 'h':
196 MainApp::doHelp(argv[0]); 274 MainApp::doHelp(argv[0]);
197 exit(16); 275 exit(16);
  276 + case 't':
  277 + testConfig = true;
  278 + break;
198 case '?': 279 case '?':
199 MainApp::doHelp(argv[0]); 280 MainApp::doHelp(argv[0]);
200 exit(16); 281 exit(16);
201 } 282 }
202 } 283 }
203 284
  285 + if (testConfig)
  286 + {
  287 + try
  288 + {
  289 + if (configFile.empty())
  290 + {
  291 + std::cerr << "No config specified." << std::endl;
  292 + MainApp::doHelp(argv[0]);
  293 + exit(1);
  294 + }
  295 +
  296 + ConfigFileParser c(configFile);
  297 + c.loadFile(true);
  298 + puts("Config OK");
  299 + exit(0);
  300 + }
  301 + catch (ConfigFileException &ex)
  302 + {
  303 + std::cerr << ex.what() << std::endl;
  304 + exit(1);
  305 + }
  306 + }
  307 +
204 instance = new MainApp(configFile); 308 instance = new MainApp(configFile);
205 } 309 }
206 310
@@ -214,40 +318,14 @@ MainApp *MainApp::getMainApp() @@ -214,40 +318,14 @@ MainApp *MainApp::getMainApp()
214 318
215 void MainApp::start() 319 void MainApp::start()
216 { 320 {
217 - Logger *logger = Logger::getInstance();  
218 -  
219 - int listen_fd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0)); 321 + int listen_fd_plain = createListenSocket(this->listenPort, false);
  322 + int listen_fd_ssl = createListenSocket(this->sslListenPort, true);
220 323
221 - // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.  
222 - int optval = 1;  
223 - check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));  
224 -  
225 - int flags = fcntl(listen_fd, F_GETFL);  
226 - check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));  
227 -  
228 - struct sockaddr_in in_addr;  
229 - in_addr.sin_family = AF_INET;  
230 - in_addr.sin_addr.s_addr = INADDR_ANY;  
231 - in_addr.sin_port = htons(1883);  
232 -  
233 - check<std::runtime_error>(bind(listen_fd, (struct sockaddr *)(&in_addr), sizeof(struct sockaddr_in)));  
234 - check<std::runtime_error>(listen(listen_fd, 1024));  
235 -  
236 - int epoll_fd_accept = check<std::runtime_error>(epoll_create(999));  
237 -  
238 - struct epoll_event events[MAX_EVENTS];  
239 struct epoll_event ev; 324 struct epoll_event ev;
240 memset(&ev, 0, sizeof (struct epoll_event)); 325 memset(&ev, 0, sizeof (struct epoll_event));
241 - memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);  
242 -  
243 - ev.data.fd = listen_fd;  
244 - ev.events = EPOLLIN;  
245 - check<std::runtime_error>(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, listen_fd, &ev));  
246 -  
247 - memset(&ev, 0, sizeof (struct epoll_event));  
248 ev.data.fd = taskEventFd; 326 ev.data.fd = taskEventFd;
249 ev.events = EPOLLIN; 327 ev.events = EPOLLIN;
250 - check<std::runtime_error>(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, taskEventFd, &ev)); 328 + check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, taskEventFd, &ev));
251 329
252 for (int i = 0; i < NR_OF_THREADS; i++) 330 for (int i = 0; i < NR_OF_THREADS; i++)
253 { 331 {
@@ -256,14 +334,15 @@ void MainApp::start() @@ -256,14 +334,15 @@ void MainApp::start()
256 threads.push_back(t); 334 threads.push_back(t);
257 } 335 }
258 336
259 - logger->logf(LOG_NOTICE, "Listening on port 1883");  
260 -  
261 uint next_thread_index = 0; 337 uint next_thread_index = 0;
262 338
  339 + struct epoll_event events[MAX_EVENTS];
  340 + memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
  341 +
263 started = true; 342 started = true;
264 while (running) 343 while (running)
265 { 344 {
266 - int num_fds = epoll_wait(epoll_fd_accept, events, MAX_EVENTS, 100); 345 + int num_fds = epoll_wait(this->epollFdAccept, events, MAX_EVENTS, 100);
267 346
268 if (num_fds < 0) 347 if (num_fds < 0)
269 { 348 {
@@ -277,7 +356,7 @@ void MainApp::start() @@ -277,7 +356,7 @@ void MainApp::start()
277 int cur_fd = events[i].data.fd; 356 int cur_fd = events[i].data.fd;
278 try 357 try
279 { 358 {
280 - if (cur_fd == listen_fd) 359 + if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl)
281 { 360 {
282 std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % NR_OF_THREADS]; 361 std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % NR_OF_THREADS];
283 362
@@ -288,7 +367,22 @@ void MainApp::start() @@ -288,7 +367,22 @@ void MainApp::start()
288 socklen_t len = sizeof(struct sockaddr); 367 socklen_t len = sizeof(struct sockaddr);
289 int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len)); 368 int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len));
290 369
291 - Client_p client(new Client(fd, thread_data)); 370 + SSL *clientSSL = nullptr;
  371 + if (cur_fd == listen_fd_ssl)
  372 + {
  373 + clientSSL = SSL_new(sslctx);
  374 +
  375 + if (clientSSL == NULL)
  376 + {
  377 + logger->logf(LOG_ERR, "Problem creating SSL object. Closing client.");
  378 + close(fd);
  379 + continue;
  380 + }
  381 +
  382 + SSL_set_fd(clientSSL, fd);
  383 + }
  384 +
  385 + Client_p client(new Client(fd, thread_data, clientSSL));
292 thread_data->giveClient(client); 386 thread_data->giveClient(client);
293 } 387 }
294 else if (cur_fd == taskEventFd) 388 else if (cur_fd == taskEventFd)
@@ -320,7 +414,8 @@ void MainApp::start() @@ -320,7 +414,8 @@ void MainApp::start()
320 thread->quit(); 414 thread->quit();
321 } 415 }
322 416
323 - close(listen_fd); 417 + close(listen_fd_plain);
  418 + close(listen_fd_ssl);
324 } 419 }
325 420
326 void MainApp::quit() 421 void MainApp::quit()
@@ -330,12 +425,29 @@ void MainApp::quit() @@ -330,12 +425,29 @@ void MainApp::quit()
330 running = false; 425 running = false;
331 } 426 }
332 427
  428 +// Loaded on app start where you want it to crash, loaded from within try/catch on reload, to allow the program to continue.
333 void MainApp::loadConfig() 429 void MainApp::loadConfig()
334 { 430 {
335 Logger *logger = Logger::getInstance(); 431 Logger *logger = Logger::getInstance();
336 - confFileParser->loadFile(); 432 +
  433 + // Atomic loading, first test.
  434 + confFileParser->loadFile(true);
  435 + confFileParser->loadFile(false);
  436 +
337 logger->setLogPath(confFileParser->logPath); 437 logger->setLogPath(confFileParser->logPath);
338 logger->reOpen(); 438 logger->reOpen();
  439 +
  440 + listenPort = confFileParser->listenPort;
  441 + sslListenPort = confFileParser->sslListenPort;
  442 +
  443 + if (sslctx == nullptr && sslListenPort > 0)
  444 + {
  445 + sslctx = SSL_CTX_new(TLS_server_method());
  446 + SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv3); // TODO: config option
  447 + SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option
  448 + }
  449 +
  450 + setCertAndKeyFromConfig();
339 } 451 }
340 452
341 void MainApp::reloadConfig() 453 void MainApp::reloadConfig()
mainapp.h
@@ -30,23 +30,34 @@ class MainApp @@ -30,23 +30,34 @@ class MainApp
30 std::shared_ptr<SubscriptionStore> subscriptionStore; 30 std::shared_ptr<SubscriptionStore> subscriptionStore;
31 std::unique_ptr<ConfigFileParser> confFileParser; 31 std::unique_ptr<ConfigFileParser> confFileParser;
32 std::forward_list<std::function<void()>> taskQueue; 32 std::forward_list<std::function<void()>> taskQueue;
  33 + int epollFdAccept = -1;
33 int taskEventFd = -1; 34 int taskEventFd = -1;
34 std::mutex eventMutex; 35 std::mutex eventMutex;
35 36
  37 + uint listenPort = 0;
  38 + uint sslListenPort = 0;
  39 + SSL_CTX *sslctx = nullptr;
  40 +
  41 + Logger *logger = Logger::getInstance();
  42 +
36 void loadConfig(); 43 void loadConfig();
37 void reloadConfig(); 44 void reloadConfig();
38 static void doHelp(const char *arg); 45 static void doHelp(const char *arg);
39 static void showLicense(); 46 static void showLicense();
  47 + void setCertAndKeyFromConfig();
  48 + int createListenSocket(int portNr, bool ssl);
40 49
41 MainApp(const std::string &configFilePath); 50 MainApp(const std::string &configFilePath);
42 public: 51 public:
43 MainApp(const MainApp &rhs) = delete; 52 MainApp(const MainApp &rhs) = delete;
44 MainApp(MainApp &&rhs) = delete; 53 MainApp(MainApp &&rhs) = delete;
  54 + ~MainApp();
45 static MainApp *getMainApp(); 55 static MainApp *getMainApp();
46 static void initMainApp(int argc, char *argv[]); 56 static void initMainApp(int argc, char *argv[]);
47 void start(); 57 void start();
48 void quit(); 58 void quit();
49 bool getStarted() const {return started;} 59 bool getStarted() const {return started;}
  60 + static void testConfig();
50 61
51 62
52 void queueConfigReload(); 63 void queueConfigReload();
mqttpacket.cpp
@@ -281,6 +281,7 @@ void MqttPacket::handleDisconnect() @@ -281,6 +281,7 @@ void MqttPacket::handleDisconnect()
281 { 281 {
282 logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str()); 282 logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str());
283 sender->markAsDisconnecting(); 283 sender->markAsDisconnecting();
  284 + sender->setDisconnectReason("MQTT Disconnect received.");
284 sender->getThreadData()->removeClient(sender); 285 sender->getThreadData()->removeClient(sender);
285 286
286 // TODO: clear will 287 // TODO: clear will
sslctxmanager.cpp 0 → 100644
  1 +#include "sslctxmanager.h"
  2 +
  3 +SslCtxManager::SslCtxManager() :
  4 + ssl_ctx(SSL_CTX_new(TLS_server_method()))
  5 +{
  6 +
  7 +}
  8 +
  9 +SslCtxManager::~SslCtxManager()
  10 +{
  11 + if (ssl_ctx)
  12 + SSL_CTX_free(ssl_ctx);
  13 +}
  14 +
  15 +SSL_CTX *SslCtxManager::get() const
  16 +{
  17 + return ssl_ctx;
  18 +}
sslctxmanager.h 0 → 100644
  1 +#ifndef SSLCTXMANAGER_H
  2 +#define SSLCTXMANAGER_H
  3 +
  4 +#include "openssl/ssl.h"
  5 +
  6 +class SslCtxManager
  7 +{
  8 + SSL_CTX *ssl_ctx = nullptr;
  9 +public:
  10 + SslCtxManager();
  11 + ~SslCtxManager();
  12 +
  13 + SSL_CTX *get() const;
  14 +};
  15 +
  16 +#endif // SSLCTXMANAGER_H
threaddata.cpp
@@ -87,6 +87,7 @@ bool ThreadData::doKeepAliveCheck() @@ -87,6 +87,7 @@ bool ThreadData::doKeepAliveCheck()
87 Client_p &client = it->second; 87 Client_p &client = it->second;
88 if (client && client->keepAliveExpired()) 88 if (client && client->keepAliveExpired())
89 { 89 {
  90 + client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString());
90 it = clients_by_fd.erase(it); 91 it = clients_by_fd.erase(it);
91 } 92 }
92 else 93 else