Commit 9e33ebdacf9cbc196ea7eb08538ed8e53bfef564

Authored by Wiebe Cazemier
1 parent 327f0c38

Add SSL support

It also contains some related improvements that I needed:

* Show disconnect reason
* Fix the while condition for doing write() to avoid an unnecessary call
* Config reloading logic
CMakeLists.txt
... ... @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.5)
2 2  
3 3 project(FlashMQ LANGUAGES CXX)
4 4  
  5 +add_definitions(-DOPENSSL_API_COMPAT=0x10100000L)
  6 +
5 7 set(CMAKE_CXX_STANDARD 11)
6 8 set(CMAKE_CXX_STANDARD_REQUIRED ON)
7 9  
... ... @@ -24,6 +26,7 @@ add_executable(FlashMQ
24 26 logger.cpp
25 27 authplugin.cpp
26 28 configfileparser.cpp
  29 + sslctxmanager.cpp
27 30 )
28 31  
29   -target_link_libraries(FlashMQ pthread dl)
  32 +target_link_libraries(FlashMQ pthread dl ssl crypto)
... ...
client.cpp
... ... @@ -7,8 +7,9 @@
7 7  
8 8 #include "logger.h"
9 9  
10   -Client::Client(int fd, ThreadData_p threadData) :
  10 +Client::Client(int fd, ThreadData_p threadData, SSL *ssl) :
11 11 fd(fd),
  12 + ssl(ssl),
12 13 readbuf(CLIENT_BUFFER_SIZE),
13 14 writebuf(CLIENT_BUFFER_SIZE),
14 15 threadData(threadData)
... ... @@ -19,13 +20,71 @@ Client::Client(int fd, ThreadData_p threadData) :
19 20  
20 21 Client::~Client()
21 22 {
22   - Logger *logger = Logger::getInstance();
23   - logger->logf(LOG_NOTICE, "Removing client '%s'", repr().c_str());
  23 + if (disconnectReason.empty())
  24 + disconnectReason = "not specified";
  25 +
  26 + logger->logf(LOG_NOTICE, "Removing client '%s'. Reason: %s", repr().c_str(), disconnectReason.c_str());
24 27 if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0)
25 28 logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno));
  29 + if (ssl)
  30 + {
  31 + // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done
  32 + // in the destructor.
  33 + SSL_free(ssl);
  34 + }
26 35 close(fd);
27 36 }
28 37  
  38 +bool Client::isSslAccepted() const
  39 +{
  40 + return sslAccepted;
  41 +}
  42 +
  43 +bool Client::isSsl() const
  44 +{
  45 + return this->ssl != nullptr;
  46 +}
  47 +
  48 +bool Client::getSslReadWantsWrite() const
  49 +{
  50 + return this->sslReadWantsWrite;
  51 +}
  52 +
  53 +bool Client::getSslWriteWantsRead() const
  54 +{
  55 + return this->sslWriteWantsRead;
  56 +}
  57 +
  58 +void Client::startOrContinueSslAccept()
  59 +{
  60 + ERR_clear_error();
  61 + int accepted = SSL_accept(ssl);
  62 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  63 + if (accepted <= 0)
  64 + {
  65 + int err = SSL_get_error(ssl, accepted);
  66 +
  67 + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  68 + {
  69 + setReadyForWriting(err == SSL_ERROR_WANT_WRITE);
  70 + return;
  71 + }
  72 +
  73 + unsigned long error_code = ERR_get_error();
  74 +
  75 + ERR_error_string(error_code, sslErrorBuf);
  76 + std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  77 +
  78 + if (error_code == OPENSSL_WRONG_VERSION_NUMBER)
  79 + errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket.";
  80 +
  81 + //ERR_print_errors_cb(logSslError, NULL);
  82 + throw std::runtime_error("Problem accepting SSL socket: " + errorMsg);
  83 + }
  84 + setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake
  85 + sslAccepted = true;
  86 +}
  87 +
29 88 // Causes future activity on the client to cause a disconnect.
30 89 void Client::markAsDisconnecting()
31 90 {
... ... @@ -35,29 +94,102 @@ void Client::markAsDisconnecting()
35 94 disconnecting = true;
36 95 }
37 96  
  97 +// SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL
  98 +// socket. This wrapper unifies behavor for the caller.
  99 +ssize_t Client::readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error)
  100 +{
  101 + *error = IoWrapResult::Success;
  102 + ssize_t n = 0;
  103 + if (!ssl)
  104 + {
  105 + n = read(fd, buf, nbytes);
  106 + if (n < 0)
  107 + {
  108 + if (errno == EINTR)
  109 + *error = IoWrapResult::Interrupted;
  110 + else if (errno == EAGAIN || errno == EWOULDBLOCK)
  111 + *error = IoWrapResult::Wouldblock;
  112 + else
  113 + check<std::runtime_error>(n);
  114 + }
  115 + else if (n == 0)
  116 + {
  117 + *error = IoWrapResult::Disconnected;
  118 + }
  119 + }
  120 + else
  121 + {
  122 + this->sslReadWantsWrite = false;
  123 + ERR_clear_error();
  124 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  125 + n = SSL_read(ssl, buf, nbytes);
  126 +
  127 + if (n <= 0)
  128 + {
  129 + int err = SSL_get_error(ssl, n);
  130 + unsigned long error_code = ERR_get_error();
  131 +
  132 + // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL.
  133 + if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0))
  134 + {
  135 + *error = IoWrapResult::Disconnected;
  136 + }
  137 + else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  138 + {
  139 + *error = IoWrapResult::Wouldblock;
  140 + if (err == SSL_ERROR_WANT_WRITE)
  141 + {
  142 + sslReadWantsWrite = true;
  143 + setReadyForWriting(true);
  144 + }
  145 + n = -1;
  146 + }
  147 + else
  148 + {
  149 + if (err == SSL_ERROR_SYSCALL)
  150 + {
  151 + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
  152 + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
  153 + // implies EINTR is not included?
  154 + if (errno == EINTR)
  155 + *error = IoWrapResult::Interrupted;
  156 + else
  157 + {
  158 + char *err = strerror(errno);
  159 + std::string msg(err);
  160 + throw std::runtime_error("SSL read error: " + msg);
  161 + }
  162 + }
  163 + ERR_error_string(error_code, sslErrorBuf);
  164 + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  165 + ERR_print_errors_cb(logSslError, NULL);
  166 + throw std::runtime_error("SSL socket error reading: " + errorString);
  167 + }
  168 + }
  169 + }
  170 +
  171 + return n;
  172 +}
  173 +
38 174 // false means any kind of error we want to get rid of the client for.
39 175 bool Client::readFdIntoBuffer()
40 176 {
41 177 if (disconnecting)
42 178 return false;
43 179  
  180 + IoWrapResult error = IoWrapResult::Success;
44 181 int n = 0;
45   - while (readbuf.freeSpace() > 0 && (n = read(fd, readbuf.headPtr(), readbuf.maxWriteSize())) != 0)
  182 + while (readbuf.freeSpace() > 0 && (n = readWrap(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0)
46 183 {
47 184 if (n > 0)
48 185 {
49 186 readbuf.advanceHead(n);
50 187 }
51 188  
52   - if (n < 0)
53   - {
54   - if (errno == EINTR)
55   - continue;
56   - if (errno == EAGAIN || errno == EWOULDBLOCK)
57   - break;
58   - else
59   - check<std::runtime_error>(n);
60   - }
  189 + if (error == IoWrapResult::Interrupted)
  190 + continue;
  191 + if (error == IoWrapResult::Wouldblock)
  192 + break;
61 193  
62 194 // Make sure we either always have enough space for a next call of this method, or stop reading the fd.
63 195 if (readbuf.freeSpace() == 0)
... ... @@ -74,7 +206,7 @@ bool Client::readFdIntoBuffer()
74 206 }
75 207 }
76 208  
77   - if (n == 0) // client disconnected.
  209 + if (error == IoWrapResult::Disconnected)
78 210 {
79 211 return false;
80 212 }
... ... @@ -174,6 +306,94 @@ void Client::writePingResp()
174 306 setReadyForWriting(true);
175 307 }
176 308  
  309 +// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller.
  310 +ssize_t Client::writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error)
  311 +{
  312 + *error = IoWrapResult::Success;
  313 + ssize_t n = 0;
  314 +
  315 + if (!ssl)
  316 + {
  317 + // A write on a socket with count=0 is unspecified.
  318 + assert(nbytes > 0);
  319 +
  320 + n = write(fd, buf, nbytes);
  321 + if (n < 0)
  322 + {
  323 + if (errno == EINTR)
  324 + *error = IoWrapResult::Interrupted;
  325 + else if (errno == EAGAIN || errno == EWOULDBLOCK)
  326 + *error = IoWrapResult::Wouldblock;
  327 + else
  328 + check<std::runtime_error>(n);
  329 + }
  330 + }
  331 + else
  332 + {
  333 + const void *buf_ = buf;
  334 + size_t nbytes_ = nbytes;
  335 +
  336 + /*
  337 + * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned
  338 + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments
  339 + */
  340 + if (this->incompleteSslWrite.hasPendingWrite())
  341 + {
  342 + buf_ = this->incompleteSslWrite.buf;
  343 + nbytes_ = this->incompleteSslWrite.nbytes;
  344 + }
  345 +
  346 + // OpenSSL: "You should not call SSL_write() with num=0, it will return an error"
  347 + assert(nbytes_ > 0);
  348 +
  349 + this->sslWriteWantsRead = false;
  350 + this->incompleteSslWrite.reset();
  351 +
  352 + ERR_clear_error();
  353 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  354 + n = SSL_write(ssl, buf_, nbytes_);
  355 +
  356 + if (n <= 0)
  357 + {
  358 + int err = SSL_get_error(ssl, n);
  359 + unsigned long error_code = ERR_get_error();
  360 + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  361 + {
  362 + logger->logf(LOG_DEBUG, "Write is incomplete: %d", err);
  363 + *error = IoWrapResult::Wouldblock;
  364 + IncompleteSslWrite sslAction(buf_, nbytes_);
  365 + this->incompleteSslWrite = sslAction;
  366 + if (err == SSL_ERROR_WANT_READ)
  367 + this->sslWriteWantsRead = true;
  368 + n = 0;
  369 + }
  370 + else
  371 + {
  372 + if (err == SSL_ERROR_SYSCALL)
  373 + {
  374 + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
  375 + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
  376 + // implies EINTR is not included?
  377 + if (errno == EINTR)
  378 + *error = IoWrapResult::Interrupted;
  379 + else
  380 + {
  381 + char *err = strerror(errno);
  382 + std::string msg(err);
  383 + throw std::runtime_error(msg);
  384 + }
  385 + }
  386 + ERR_error_string(error_code, sslErrorBuf);
  387 + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  388 + ERR_print_errors_cb(logSslError, NULL);
  389 + throw std::runtime_error("SSL socket error writing: " + errorString);
  390 + }
  391 + }
  392 + }
  393 +
  394 + return n;
  395 +}
  396 +
177 397 bool Client::writeBufIntoFd()
178 398 {
179 399 std::unique_lock<std::mutex> lock(writeBufMutex, std::try_to_lock);
... ... @@ -184,24 +404,23 @@ bool Client::writeBufIntoFd()
184 404 if (disconnecting)
185 405 return false;
186 406  
  407 + IoWrapResult error = IoWrapResult::Success;
187 408 int n;
188   - while ((n = write(fd, writebuf.tailPtr(), writebuf.maxReadSize())) != 0)
  409 + while (writebuf.usedBytes() > 0 || incompleteSslWrite.hasPendingWrite())
189 410 {
  411 + n = writeWrap(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error);
  412 +
190 413 if (n > 0)
191 414 writebuf.advanceTail(n);
192   - if (n < 0)
193   - {
194   - if (errno == EINTR)
195   - continue;
196   - if (errno == EAGAIN || errno == EWOULDBLOCK)
197   - break;
198   - else
199   - check<std::runtime_error>(n);
200   - }
  415 +
  416 + if (error == IoWrapResult::Interrupted)
  417 + continue;
  418 + if (error == IoWrapResult::Wouldblock)
  419 + break;
201 420 }
202 421  
203 422 const bool bufferHasData = writebuf.usedBytes() > 0;
204   - setReadyForWriting(bufferHasData);
  423 + setReadyForWriting(bufferHasData || error == IoWrapResult::Wouldblock);
205 424  
206 425 if (!bufferHasData)
207 426 {
... ... @@ -236,11 +455,22 @@ bool Client::keepAliveExpired()
236 455 return result;
237 456 }
238 457  
  458 +std::string Client::getKeepAliveInfoString() const
  459 +{
  460 + std::string s = "authenticated: " + std::to_string(authenticated) + ", keep-alive: " + std::to_string(keepalive) + "s, last activity "
  461 + + std::to_string(time(NULL) - lastActivity) + " seconds ago.";
  462 + return s;
  463 +}
  464 +
  465 +// Call this from a place you know the writeBufMutex is locked, or we're still only doing SSL accept.
239 466 void Client::setReadyForWriting(bool val)
240 467 {
241 468 if (disconnecting)
242 469 return;
243 470  
  471 + if (sslReadWantsWrite)
  472 + val = true;
  473 +
244 474 if (val == this->readyForWriting)
245 475 return;
246 476  
... ... @@ -360,17 +590,29 @@ std::shared_ptr&lt;Session&gt; Client::getSession()
360 590 return this->session;
361 591 }
362 592  
  593 +void Client::setDisconnectReason(const std::string &reason)
  594 +{
  595 + // If we have a chain of errors causing this to be set, probably the first one is the most interesting.
  596 + if (!disconnectReason.empty())
  597 + return;
363 598  
  599 + this->disconnectReason = reason;
  600 +}
364 601  
  602 +IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) :
  603 + buf(buf),
  604 + nbytes(nbytes)
  605 +{
365 606  
  607 +}
366 608  
  609 +void IncompleteSslWrite::reset()
  610 +{
  611 + buf = nullptr;
  612 + nbytes = 0;
  613 +}
367 614  
368   -
369   -
370   -
371   -
372   -
373   -
374   -
375   -
376   -
  615 +bool IncompleteSslWrite::hasPendingWrite()
  616 +{
  617 + return buf != nullptr;
  618 +}
... ...
client.h
... ... @@ -8,6 +8,8 @@
8 8 #include <iostream>
9 9 #include <time.h>
10 10  
  11 +#include <openssl/ssl.h>
  12 +
11 13 #include "forward_declarations.h"
12 14  
13 15 #include "threaddata.h"
... ... @@ -15,14 +17,50 @@
15 17 #include "exceptions.h"
16 18 #include "cirbuf.h"
17 19  
  20 +#include <openssl/ssl.h>
  21 +#include <openssl/err.h>
18 22  
19 23 #define CLIENT_BUFFER_SIZE 1024 // Must be power of 2
20 24 #define MAX_PACKET_SIZE 268435461 // 256 MB + 5
21 25 #define MQTT_HEADER_LENGH 2
22 26  
  27 +#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256.
  28 +#define OPENSSL_WRONG_VERSION_NUMBER 336130315
  29 +
  30 +enum class IoWrapResult
  31 +{
  32 + Success = 0,
  33 + Interrupted = 1,
  34 + Wouldblock = 2,
  35 + Disconnected = 3,
  36 + Error = 4
  37 +};
  38 +
  39 +/*
  40 + * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned
  41 + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments"
  42 + */
  43 +struct IncompleteSslWrite
  44 +{
  45 + const void *buf = nullptr;
  46 + size_t nbytes = 0;
  47 +
  48 + IncompleteSslWrite() = default;
  49 + IncompleteSslWrite(const void *buf, size_t nbytes);
  50 + bool hasPendingWrite();
  51 +
  52 + void reset();
  53 +};
  54 +
  55 +// TODO: give accepted addr, for showing in logs
23 56 class Client
24 57 {
25 58 int fd;
  59 + SSL *ssl = nullptr;
  60 + bool sslAccepted = false;
  61 + IncompleteSslWrite incompleteSslWrite;
  62 + bool sslReadWantsWrite = false;
  63 + bool sslWriteWantsRead = false;
26 64  
27 65 CirBuf readbuf;
28 66 uint8_t readBufIsZeroCount = 0;
... ... @@ -36,6 +74,7 @@ class Client
36 74 bool readyForReading = true;
37 75 bool disconnectWhenBytesWritten = false;
38 76 bool disconnecting = false;
  77 + std::string disconnectReason;
39 78 time_t lastActivity = time(NULL);
40 79  
41 80 std::string clientid;
... ... @@ -53,18 +92,27 @@ class Client
53 92  
54 93 std::shared_ptr<Session> session;
55 94  
  95 + Logger *logger = Logger::getInstance();
  96 +
56 97  
57 98 void setReadyForWriting(bool val);
58 99 void setReadyForReading(bool val);
59 100  
60 101 public:
61   - Client(int fd, ThreadData_p threadData);
  102 + Client(int fd, ThreadData_p threadData, SSL *ssl);
62 103 Client(const Client &other) = delete;
63 104 Client(Client &&other) = delete;
64 105 ~Client();
65 106  
66 107 int getFd() { return fd;}
  108 + bool isSslAccepted() const;
  109 + bool isSsl() const;
  110 + bool getSslReadWantsWrite() const;
  111 + bool getSslWriteWantsRead() const;
  112 +
  113 + void startOrContinueSslAccept();
67 114 void markAsDisconnecting();
  115 + ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error);
68 116 bool readFdIntoBuffer();
69 117 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender);
70 118 void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
... ... @@ -77,10 +125,12 @@ public:
77 125 bool getCleanSession() { return cleanSession; }
78 126 void assignSession(std::shared_ptr<Session> &session);
79 127 std::shared_ptr<Session> getSession();
  128 + void setDisconnectReason(const std::string &reason);
80 129  
81 130 void writePingResp();
82 131 void writeMqttPacket(const MqttPacket &packet);
83 132 void writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
  133 + ssize_t writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error);
84 134 bool writeBufIntoFd();
85 135 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; }
86 136  
... ... @@ -89,6 +139,7 @@ public:
89 139  
90 140 std::string repr();
91 141 bool keepAliveExpired();
  142 + std::string getKeepAliveInfoString() const;
92 143  
93 144 };
94 145  
... ...
configfileparser.cpp
... ... @@ -5,10 +5,15 @@
5 5 #include <sstream>
6 6 #include "fstream"
7 7  
  8 +#include "openssl/ssl.h"
  9 +#include "openssl/err.h"
  10 +
8 11 #include "exceptions.h"
9 12 #include "utils.h"
10 13 #include <regex>
11 14  
  15 +#include "logger.h"
  16 +
12 17  
13 18 mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value)
14 19 {
... ... @@ -41,24 +46,66 @@ AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map&lt;std::string, std::
41 46 }
42 47 }
43 48  
  49 +void ConfigFileParser::checkFileAccess(const std::string &key, const std::string &pathToCheck) const
  50 +{
  51 + if (access(pathToCheck.c_str(), R_OK) != 0)
  52 + {
  53 + std::ostringstream oss;
  54 + oss << "Error for '" << key << "': " << pathToCheck << " is not there or not readable";
  55 + throw ConfigFileException(oss.str());
  56 + }
  57 +}
  58 +
  59 +// Using a separate ssl context to test, because it's the easiest way to load certs and key atomitcally.
  60 +void ConfigFileParser::testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const
  61 +{
  62 + if (portNr == 0)
  63 + return;
  64 +
  65 + if (fullchain.empty() && privkey.empty())
  66 + throw ConfigFileException("No privkey and fullchain specified.");
  67 +
  68 + if (fullchain.empty())
  69 + throw ConfigFileException("No private key specified for fullchain");
  70 +
  71 + if (privkey.empty())
  72 + throw ConfigFileException("No fullchain specified for private key");
  73 +
  74 + SslCtxManager sslCtx;
  75 + if (SSL_CTX_use_certificate_file(sslCtx.get(), fullchain.c_str(), SSL_FILETYPE_PEM) != 1)
  76 + {
  77 + ERR_print_errors_cb(logSslError, NULL);
  78 + throw ConfigFileException("Error loading full chain " + fullchain);
  79 + }
  80 + if (SSL_CTX_use_PrivateKey_file(sslCtx.get(), privkey.c_str(), SSL_FILETYPE_PEM) != 1)
  81 + {
  82 + ERR_print_errors_cb(logSslError, NULL);
  83 + throw ConfigFileException("Error loading private key " + privkey);
  84 + }
  85 + if (SSL_CTX_check_private_key(sslCtx.get()) != 1)
  86 + {
  87 + ERR_print_errors_cb(logSslError, NULL);
  88 + throw ConfigFileException("Private key and certificate don't match.");
  89 + }
  90 +}
  91 +
44 92 ConfigFileParser::ConfigFileParser(const std::string &path) :
45 93 path(path)
46 94 {
47 95 validKeys.insert("auth_plugin");
48 96 validKeys.insert("log_file");
  97 + validKeys.insert("listen_port");
  98 + validKeys.insert("ssl_listen_port");
  99 + validKeys.insert("fullchain");
  100 + validKeys.insert("privkey");
49 101 }
50 102  
51   -void ConfigFileParser::loadFile()
  103 +void ConfigFileParser::loadFile(bool test)
52 104 {
53 105 if (path.empty())
54 106 return;
55 107  
56   - if (access(path.c_str(), R_OK) != 0)
57   - {
58   - std::ostringstream oss;
59   - oss << "Error: " << path << " is not there or not readable";
60   - throw ConfigFileException(oss.str());
61   - }
  108 + checkFileAccess("application config file", path);
62 109  
63 110 std::ifstream infile(path, std::ios::in);
64 111  
... ... @@ -99,6 +146,9 @@ void ConfigFileParser::loadFile()
99 146 authOpts.clear();
100 147 authOptCompatWrap.reset();
101 148  
  149 + std::string sslFullChainTmp;
  150 + std::string sslPrivkeyTmp;
  151 +
102 152 // Then once we know the config file is valid, process it.
103 153 for (std::string &line : lines)
104 154 {
... ... @@ -124,22 +174,68 @@ void ConfigFileParser::loadFile()
124 174 if (valid_key_it == validKeys.end())
125 175 {
126 176 std::ostringstream oss;
127   - oss << "Config key '" << key << "' is not valid";
  177 + oss << "Config key '" << key << "' is not valid. This error should have been cought before. Bug?";
128 178 throw ConfigFileException(oss.str());
129 179 }
130 180  
131 181 if (key == "auth_plugin")
132 182 {
133   - this->authPluginPath = value;
  183 + checkFileAccess(key, value);
  184 + if (!test)
  185 + this->authPluginPath = value;
134 186 }
135 187  
136 188 if (key == "log_file")
137 189 {
138   - this->logPath = value;
  190 + checkFileAccess(key, value);
  191 + if (!test)
  192 + this->logPath = value;
  193 + }
  194 +
  195 + if (key == "fullchain")
  196 + {
  197 + checkFileAccess(key, value);
  198 + sslFullChainTmp = value;
  199 + }
  200 +
  201 + if (key == "privkey")
  202 + {
  203 + checkFileAccess(key, value);
  204 + sslPrivkeyTmp = value;
  205 + }
  206 +
  207 + try
  208 + {
  209 + // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners.
  210 + if (key == "listen_port")
  211 + {
  212 + uint listenportNew = std::stoi(value);
  213 + if (listenPort > 0 && listenPort != listenportNew)
  214 + throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time.");
  215 + listenPort = listenportNew;
  216 + }
  217 +
  218 + // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners.
  219 + if (key == "ssl_listen_port")
  220 + {
  221 + uint sslListenPortNew = std::stoi(value);
  222 + if (sslListenPort > 0 && sslListenPort != sslListenPortNew)
  223 + throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time.");
  224 + sslListenPort = sslListenPortNew;
  225 + }
  226 +
  227 + }
  228 + catch (std::invalid_argument &ex)
  229 + {
  230 + throw ConfigFileException(ex.what());
139 231 }
140 232 }
141 233 }
142 234  
  235 + testSsl(sslFullChainTmp, sslPrivkeyTmp, sslListenPort);
  236 + this->sslFullchain = sslFullChainTmp;
  237 + this->sslPrivkey = sslPrivkeyTmp;
  238 +
143 239 authOptCompatWrap.reset(new AuthOptCompatWrap(authOpts));
144 240 }
145 241  
... ...
configfileparser.h
... ... @@ -7,6 +7,8 @@
7 7 #include <vector>
8 8 #include <memory>
9 9  
  10 +#include "sslctxmanager.h"
  11 +
10 12 struct mosquitto_auth_opt
11 13 {
12 14 char *key = nullptr;
... ... @@ -38,15 +40,20 @@ class ConfigFileParser
38 40 std::unique_ptr<AuthOptCompatWrap> authOptCompatWrap;
39 41  
40 42  
41   -
  43 + void checkFileAccess(const std::string &key, const std::string &pathToCheck) const;
  44 + void testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const;
42 45 public:
43 46 ConfigFileParser(const std::string &path);
44   - void loadFile();
  47 + void loadFile(bool test);
45 48 AuthOptCompatWrap &getAuthOptsCompat();
46 49  
47 50 // Actual config options with their defaults. Just making them public, I can retrain myself misuing them.
48 51 std::string authPluginPath;
49 52 std::string logPath;
  53 + std::string sslFullchain;
  54 + std::string sslPrivkey;
  55 + uint listenPort = 1883;
  56 + uint sslListenPort = 0;
50 57 };
51 58  
52 59 #endif // CONFIGFILEPARSER_H
... ...
logger.cpp
... ... @@ -121,3 +121,9 @@ void Logger::logf(int level, const char *str, va_list valist)
121 121 #endif
122 122 }
123 123 }
  124 +
  125 +int logSslError(const char *str, size_t len, void *u)
  126 +{
  127 + Logger *logger = Logger::getInstance();
  128 + logger->logf(LOG_ERR, str);
  129 +}
... ...
logger.h
... ... @@ -13,6 +13,8 @@
13 13 #define LOG_ERR 0x08
14 14 #define LOG_DEBUG 0x10
15 15  
  16 +int logSslError(const char *str, size_t len, void *u);
  17 +
16 18 class Logger
17 19 {
18 20 static Logger *instance;
... ...
main.cpp
... ... @@ -3,6 +3,7 @@
3 3 #include <memory>
4 4 #include <string.h>
5 5 #include <sys/resource.h>
  6 +#include <openssl/ssl.h>
6 7  
7 8 #include "mainapp.h"
8 9  
... ...
mainapp.cpp
... ... @@ -5,6 +5,11 @@
5 5 #include <unistd.h>
6 6 #include <stdio.h>
7 7  
  8 +#include <openssl/ssl.h>
  9 +#include <openssl/err.h>
  10 +
  11 +#include "logger.h"
  12 +
8 13 #define MAX_EVENTS 1024
9 14 #define NR_OF_THREADS 4
10 15  
... ... @@ -60,18 +65,30 @@ void do_thread_work(ThreadData *threadData)
60 65 {
61 66 try
62 67 {
63   - if (cur_ev.events & EPOLLIN)
  68 + if (cur_ev.events & (EPOLLERR | EPOLLHUP))
  69 + {
  70 + client->setDisconnectReason("epoll says socket is in ERR or HUP state.");
  71 + threadData->removeClient(client);
  72 + continue;
  73 + }
  74 + if (client->isSsl() && !client->isSslAccepted())
  75 + {
  76 + client->startOrContinueSslAccept();
  77 + continue;
  78 + }
  79 + if ((cur_ev.events & EPOLLIN) || ((cur_ev.events & EPOLLOUT) && client->getSslReadWantsWrite()))
64 80 {
65 81 bool readSuccess = client->readFdIntoBuffer();
66 82 client->bufferToMqttPackets(packetQueueIn, client);
67 83  
68 84 if (!readSuccess)
69 85 {
  86 + client->setDisconnectReason("socket disconnect detected");
70 87 threadData->removeClient(client);
71 88 continue;
72 89 }
73 90 }
74   - if (cur_ev.events & EPOLLOUT)
  91 + if ((cur_ev.events & EPOLLOUT) || ((cur_ev.events & EPOLLIN) && client->getSslWriteWantsRead()))
75 92 {
76 93 if (!client->writeBufIntoFd())
77 94 {
... ... @@ -85,13 +102,10 @@ void do_thread_work(ThreadData *threadData)
85 102 continue;
86 103 }
87 104 }
88   - if (cur_ev.events & (EPOLLERR | EPOLLHUP))
89   - {
90   - threadData->removeClient(client);
91   - }
92 105 }
93 106 catch(std::exception &ex)
94 107 {
  108 + client->setDisconnectReason(ex.what());
95 109 logger->logf(LOG_ERR, "Packet read/write error: %s. Removing client.", ex.what());
96 110 threadData->removeClient(client);
97 111 }
... ... @@ -107,6 +121,7 @@ void do_thread_work(ThreadData *threadData)
107 121 }
108 122 catch (std::exception &ex)
109 123 {
  124 + packet.getSender()->setDisconnectReason(ex.what());
110 125 logger->logf(LOG_ERR, "MqttPacket handling error: %s. Removing client.", ex.what());
111 126 threadData->removeClient(packet.getSender());
112 127 }
... ... @@ -133,12 +148,22 @@ void do_thread_work(ThreadData *threadData)
133 148 MainApp::MainApp(const std::string &configFilePath) :
134 149 subscriptionStore(new SubscriptionStore())
135 150 {
  151 + epollFdAccept = check<std::runtime_error>(epoll_create(999));
136 152 taskEventFd = eventfd(0, EFD_NONBLOCK);
137 153  
138 154 confFileParser.reset(new ConfigFileParser(configFilePath));
139 155 loadConfig();
140 156 }
141 157  
  158 +MainApp::~MainApp()
  159 +{
  160 + if (sslctx)
  161 + SSL_CTX_free(sslctx);
  162 +
  163 + if (epollFdAccept > 0)
  164 + close(epollFdAccept);
  165 +}
  166 +
142 167 void MainApp::doHelp(const char *arg)
143 168 {
144 169 puts("FlashMQ - the scalable light-weight MQTT broker");
... ... @@ -147,6 +172,7 @@ void MainApp::doHelp(const char *arg)
147 172 puts("");
148 173 puts(" -h, --help Print help");
149 174 puts(" -c, --config-file <flashmq.conf> Configuration file.");
  175 + puts(" -t, --test-config Test configuration file.");
150 176 puts(" -V, --version Show version");
151 177 puts(" -l, --license Show license");
152 178 }
... ... @@ -161,6 +187,56 @@ void MainApp::showLicense()
161 187 puts("Author: Wiebe Cazemier <wiebe@halfgaar.net>");
162 188 }
163 189  
  190 +void MainApp::setCertAndKeyFromConfig()
  191 +{
  192 + if (sslctx == nullptr)
  193 + return;
  194 +
  195 + if (SSL_CTX_use_certificate_file(sslctx, confFileParser->sslFullchain.c_str(), SSL_FILETYPE_PEM) != 1)
  196 + throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected.");
  197 + if (SSL_CTX_use_PrivateKey_file(sslctx, confFileParser->sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1)
  198 + throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected.");
  199 +}
  200 +
  201 +int MainApp::createListenSocket(int portNr, bool ssl)
  202 +{
  203 + if (portNr <= 0)
  204 + return -2;
  205 +
  206 + int listen_fd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));
  207 +
  208 + // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.
  209 + int optval = 1;
  210 + check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));
  211 +
  212 + int flags = fcntl(listen_fd, F_GETFL);
  213 + check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));
  214 +
  215 + struct sockaddr_in in_addr_plain;
  216 + in_addr_plain.sin_family = AF_INET;
  217 + in_addr_plain.sin_addr.s_addr = INADDR_ANY;
  218 + in_addr_plain.sin_port = htons(portNr);
  219 +
  220 + check<std::runtime_error>(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in)));
  221 + check<std::runtime_error>(listen(listen_fd, 1024));
  222 +
  223 + struct epoll_event ev;
  224 + memset(&ev, 0, sizeof (struct epoll_event));
  225 +
  226 + ev.data.fd = listen_fd;
  227 + ev.events = EPOLLIN;
  228 + check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev));
  229 +
  230 + std::string socketType = "plain";
  231 +
  232 + if (ssl)
  233 + socketType = "SSL";
  234 +
  235 + logger->logf(LOG_NOTICE, "Listening on %s port %d", socketType.c_str(), portNr);
  236 +
  237 + return listen_fd;
  238 +}
  239 +
164 240 void MainApp::initMainApp(int argc, char *argv[])
165 241 {
166 242 if (instance != nullptr)
... ... @@ -170,6 +246,7 @@ void MainApp::initMainApp(int argc, char *argv[])
170 246 {
171 247 {"help", no_argument, nullptr, 'h'},
172 248 {"config-file", required_argument, nullptr, 'c'},
  249 + {"test-config", no_argument, nullptr, 't'},
173 250 {"version", no_argument, nullptr, 'V'},
174 251 {"license", no_argument, nullptr, 'l'},
175 252 {nullptr, 0, nullptr, 0}
... ... @@ -179,7 +256,8 @@ void MainApp::initMainApp(int argc, char *argv[])
179 256  
180 257 int option_index = 0;
181 258 int opt;
182   - while((opt = getopt_long(argc, argv, "hc:Vl", long_options, &option_index)) != -1)
  259 + bool testConfig = false;
  260 + while((opt = getopt_long(argc, argv, "hc:Vlt", long_options, &option_index)) != -1)
183 261 {
184 262 switch(opt)
185 263 {
... ... @@ -195,12 +273,38 @@ void MainApp::initMainApp(int argc, char *argv[])
195 273 case 'h':
196 274 MainApp::doHelp(argv[0]);
197 275 exit(16);
  276 + case 't':
  277 + testConfig = true;
  278 + break;
198 279 case '?':
199 280 MainApp::doHelp(argv[0]);
200 281 exit(16);
201 282 }
202 283 }
203 284  
  285 + if (testConfig)
  286 + {
  287 + try
  288 + {
  289 + if (configFile.empty())
  290 + {
  291 + std::cerr << "No config specified." << std::endl;
  292 + MainApp::doHelp(argv[0]);
  293 + exit(1);
  294 + }
  295 +
  296 + ConfigFileParser c(configFile);
  297 + c.loadFile(true);
  298 + puts("Config OK");
  299 + exit(0);
  300 + }
  301 + catch (ConfigFileException &ex)
  302 + {
  303 + std::cerr << ex.what() << std::endl;
  304 + exit(1);
  305 + }
  306 + }
  307 +
204 308 instance = new MainApp(configFile);
205 309 }
206 310  
... ... @@ -214,40 +318,14 @@ MainApp *MainApp::getMainApp()
214 318  
215 319 void MainApp::start()
216 320 {
217   - Logger *logger = Logger::getInstance();
218   -
219   - int listen_fd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));
  321 + int listen_fd_plain = createListenSocket(this->listenPort, false);
  322 + int listen_fd_ssl = createListenSocket(this->sslListenPort, true);
220 323  
221   - // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.
222   - int optval = 1;
223   - check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));
224   -
225   - int flags = fcntl(listen_fd, F_GETFL);
226   - check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));
227   -
228   - struct sockaddr_in in_addr;
229   - in_addr.sin_family = AF_INET;
230   - in_addr.sin_addr.s_addr = INADDR_ANY;
231   - in_addr.sin_port = htons(1883);
232   -
233   - check<std::runtime_error>(bind(listen_fd, (struct sockaddr *)(&in_addr), sizeof(struct sockaddr_in)));
234   - check<std::runtime_error>(listen(listen_fd, 1024));
235   -
236   - int epoll_fd_accept = check<std::runtime_error>(epoll_create(999));
237   -
238   - struct epoll_event events[MAX_EVENTS];
239 324 struct epoll_event ev;
240 325 memset(&ev, 0, sizeof (struct epoll_event));
241   - memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
242   -
243   - ev.data.fd = listen_fd;
244   - ev.events = EPOLLIN;
245   - check<std::runtime_error>(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, listen_fd, &ev));
246   -
247   - memset(&ev, 0, sizeof (struct epoll_event));
248 326 ev.data.fd = taskEventFd;
249 327 ev.events = EPOLLIN;
250   - check<std::runtime_error>(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, taskEventFd, &ev));
  328 + check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, taskEventFd, &ev));
251 329  
252 330 for (int i = 0; i < NR_OF_THREADS; i++)
253 331 {
... ... @@ -256,14 +334,15 @@ void MainApp::start()
256 334 threads.push_back(t);
257 335 }
258 336  
259   - logger->logf(LOG_NOTICE, "Listening on port 1883");
260   -
261 337 uint next_thread_index = 0;
262 338  
  339 + struct epoll_event events[MAX_EVENTS];
  340 + memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
  341 +
263 342 started = true;
264 343 while (running)
265 344 {
266   - int num_fds = epoll_wait(epoll_fd_accept, events, MAX_EVENTS, 100);
  345 + int num_fds = epoll_wait(this->epollFdAccept, events, MAX_EVENTS, 100);
267 346  
268 347 if (num_fds < 0)
269 348 {
... ... @@ -277,7 +356,7 @@ void MainApp::start()
277 356 int cur_fd = events[i].data.fd;
278 357 try
279 358 {
280   - if (cur_fd == listen_fd)
  359 + if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl)
281 360 {
282 361 std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % NR_OF_THREADS];
283 362  
... ... @@ -288,7 +367,22 @@ void MainApp::start()
288 367 socklen_t len = sizeof(struct sockaddr);
289 368 int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len));
290 369  
291   - Client_p client(new Client(fd, thread_data));
  370 + SSL *clientSSL = nullptr;
  371 + if (cur_fd == listen_fd_ssl)
  372 + {
  373 + clientSSL = SSL_new(sslctx);
  374 +
  375 + if (clientSSL == NULL)
  376 + {
  377 + logger->logf(LOG_ERR, "Problem creating SSL object. Closing client.");
  378 + close(fd);
  379 + continue;
  380 + }
  381 +
  382 + SSL_set_fd(clientSSL, fd);
  383 + }
  384 +
  385 + Client_p client(new Client(fd, thread_data, clientSSL));
292 386 thread_data->giveClient(client);
293 387 }
294 388 else if (cur_fd == taskEventFd)
... ... @@ -320,7 +414,8 @@ void MainApp::start()
320 414 thread->quit();
321 415 }
322 416  
323   - close(listen_fd);
  417 + close(listen_fd_plain);
  418 + close(listen_fd_ssl);
324 419 }
325 420  
326 421 void MainApp::quit()
... ... @@ -330,12 +425,29 @@ void MainApp::quit()
330 425 running = false;
331 426 }
332 427  
  428 +// Loaded on app start where you want it to crash, loaded from within try/catch on reload, to allow the program to continue.
333 429 void MainApp::loadConfig()
334 430 {
335 431 Logger *logger = Logger::getInstance();
336   - confFileParser->loadFile();
  432 +
  433 + // Atomic loading, first test.
  434 + confFileParser->loadFile(true);
  435 + confFileParser->loadFile(false);
  436 +
337 437 logger->setLogPath(confFileParser->logPath);
338 438 logger->reOpen();
  439 +
  440 + listenPort = confFileParser->listenPort;
  441 + sslListenPort = confFileParser->sslListenPort;
  442 +
  443 + if (sslctx == nullptr && sslListenPort > 0)
  444 + {
  445 + sslctx = SSL_CTX_new(TLS_server_method());
  446 + SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv3); // TODO: config option
  447 + SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option
  448 + }
  449 +
  450 + setCertAndKeyFromConfig();
339 451 }
340 452  
341 453 void MainApp::reloadConfig()
... ...
mainapp.h
... ... @@ -30,23 +30,34 @@ class MainApp
30 30 std::shared_ptr<SubscriptionStore> subscriptionStore;
31 31 std::unique_ptr<ConfigFileParser> confFileParser;
32 32 std::forward_list<std::function<void()>> taskQueue;
  33 + int epollFdAccept = -1;
33 34 int taskEventFd = -1;
34 35 std::mutex eventMutex;
35 36  
  37 + uint listenPort = 0;
  38 + uint sslListenPort = 0;
  39 + SSL_CTX *sslctx = nullptr;
  40 +
  41 + Logger *logger = Logger::getInstance();
  42 +
36 43 void loadConfig();
37 44 void reloadConfig();
38 45 static void doHelp(const char *arg);
39 46 static void showLicense();
  47 + void setCertAndKeyFromConfig();
  48 + int createListenSocket(int portNr, bool ssl);
40 49  
41 50 MainApp(const std::string &configFilePath);
42 51 public:
43 52 MainApp(const MainApp &rhs) = delete;
44 53 MainApp(MainApp &&rhs) = delete;
  54 + ~MainApp();
45 55 static MainApp *getMainApp();
46 56 static void initMainApp(int argc, char *argv[]);
47 57 void start();
48 58 void quit();
49 59 bool getStarted() const {return started;}
  60 + static void testConfig();
50 61  
51 62  
52 63 void queueConfigReload();
... ...
mqttpacket.cpp
... ... @@ -281,6 +281,7 @@ void MqttPacket::handleDisconnect()
281 281 {
282 282 logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str());
283 283 sender->markAsDisconnecting();
  284 + sender->setDisconnectReason("MQTT Disconnect received.");
284 285 sender->getThreadData()->removeClient(sender);
285 286  
286 287 // TODO: clear will
... ...
sslctxmanager.cpp 0 → 100644
  1 +#include "sslctxmanager.h"
  2 +
  3 +SslCtxManager::SslCtxManager() :
  4 + ssl_ctx(SSL_CTX_new(TLS_server_method()))
  5 +{
  6 +
  7 +}
  8 +
  9 +SslCtxManager::~SslCtxManager()
  10 +{
  11 + if (ssl_ctx)
  12 + SSL_CTX_free(ssl_ctx);
  13 +}
  14 +
  15 +SSL_CTX *SslCtxManager::get() const
  16 +{
  17 + return ssl_ctx;
  18 +}
... ...
sslctxmanager.h 0 → 100644
  1 +#ifndef SSLCTXMANAGER_H
  2 +#define SSLCTXMANAGER_H
  3 +
  4 +#include "openssl/ssl.h"
  5 +
  6 +class SslCtxManager
  7 +{
  8 + SSL_CTX *ssl_ctx = nullptr;
  9 +public:
  10 + SslCtxManager();
  11 + ~SslCtxManager();
  12 +
  13 + SSL_CTX *get() const;
  14 +};
  15 +
  16 +#endif // SSLCTXMANAGER_H
... ...
threaddata.cpp
... ... @@ -87,6 +87,7 @@ bool ThreadData::doKeepAliveCheck()
87 87 Client_p &client = it->second;
88 88 if (client && client->keepAliveExpired())
89 89 {
  90 + client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString());
90 91 it = clients_by_fd.erase(it);
91 92 }
92 93 else
... ...