From 9e33ebdacf9cbc196ea7eb08538ed8e53bfef564 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sun, 24 Jan 2021 17:27:52 +0100 Subject: [PATCH] Add SSL support --- CMakeLists.txt | 5 ++++- client.cpp | 310 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------------- client.h | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++++- configfileparser.cpp | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------- configfileparser.h | 11 +++++++++-- logger.cpp | 6 ++++++ logger.h | 2 ++ main.cpp | 1 + mainapp.cpp | 198 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------------------- mainapp.h | 11 +++++++++++ mqttpacket.cpp | 1 + sslctxmanager.cpp | 18 ++++++++++++++++++ sslctxmanager.h | 16 ++++++++++++++++ threaddata.cpp | 1 + 14 files changed, 658 insertions(+), 91 deletions(-) create mode 100644 sslctxmanager.cpp create mode 100644 sslctxmanager.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 37dafb2..fe8cc75 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.5) project(FlashMQ LANGUAGES CXX) +add_definitions(-DOPENSSL_API_COMPAT=0x10100000L) + set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -24,6 +26,7 @@ add_executable(FlashMQ logger.cpp authplugin.cpp configfileparser.cpp + sslctxmanager.cpp ) -target_link_libraries(FlashMQ pthread dl) +target_link_libraries(FlashMQ pthread dl ssl crypto) diff --git a/client.cpp b/client.cpp index 70eb551..c910bc7 100644 --- a/client.cpp +++ b/client.cpp @@ -7,8 +7,9 @@ #include "logger.h" -Client::Client(int fd, ThreadData_p threadData) : +Client::Client(int fd, ThreadData_p threadData, SSL *ssl) : fd(fd), + ssl(ssl), readbuf(CLIENT_BUFFER_SIZE), writebuf(CLIENT_BUFFER_SIZE), threadData(threadData) @@ -19,13 +20,71 @@ Client::Client(int fd, ThreadData_p threadData) : Client::~Client() { - Logger *logger = Logger::getInstance(); - logger->logf(LOG_NOTICE, "Removing client '%s'", repr().c_str()); + if (disconnectReason.empty()) + disconnectReason = "not specified"; + + logger->logf(LOG_NOTICE, "Removing client '%s'. Reason: %s", repr().c_str(), disconnectReason.c_str()); if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); + if (ssl) + { + // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done + // in the destructor. + SSL_free(ssl); + } close(fd); } +bool Client::isSslAccepted() const +{ + return sslAccepted; +} + +bool Client::isSsl() const +{ + return this->ssl != nullptr; +} + +bool Client::getSslReadWantsWrite() const +{ + return this->sslReadWantsWrite; +} + +bool Client::getSslWriteWantsRead() const +{ + return this->sslWriteWantsRead; +} + +void Client::startOrContinueSslAccept() +{ + ERR_clear_error(); + int accepted = SSL_accept(ssl); + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; + if (accepted <= 0) + { + int err = SSL_get_error(ssl, accepted); + + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) + { + setReadyForWriting(err == SSL_ERROR_WANT_WRITE); + return; + } + + unsigned long error_code = ERR_get_error(); + + ERR_error_string(error_code, sslErrorBuf); + std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); + + if (error_code == OPENSSL_WRONG_VERSION_NUMBER) + errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket."; + + //ERR_print_errors_cb(logSslError, NULL); + throw std::runtime_error("Problem accepting SSL socket: " + errorMsg); + } + setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake + sslAccepted = true; +} + // Causes future activity on the client to cause a disconnect. void Client::markAsDisconnecting() { @@ -35,29 +94,102 @@ void Client::markAsDisconnecting() disconnecting = true; } +// SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL +// socket. This wrapper unifies behavor for the caller. +ssize_t Client::readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error) +{ + *error = IoWrapResult::Success; + ssize_t n = 0; + if (!ssl) + { + n = read(fd, buf, nbytes); + if (n < 0) + { + if (errno == EINTR) + *error = IoWrapResult::Interrupted; + else if (errno == EAGAIN || errno == EWOULDBLOCK) + *error = IoWrapResult::Wouldblock; + else + check(n); + } + else if (n == 0) + { + *error = IoWrapResult::Disconnected; + } + } + else + { + this->sslReadWantsWrite = false; + ERR_clear_error(); + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; + n = SSL_read(ssl, buf, nbytes); + + if (n <= 0) + { + int err = SSL_get_error(ssl, n); + unsigned long error_code = ERR_get_error(); + + // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL. + if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0)) + { + *error = IoWrapResult::Disconnected; + } + else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) + { + *error = IoWrapResult::Wouldblock; + if (err == SSL_ERROR_WANT_WRITE) + { + sslReadWantsWrite = true; + setReadyForWriting(true); + } + n = -1; + } + else + { + if (err == SSL_ERROR_SYSCALL) + { + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it + // implies EINTR is not included? + if (errno == EINTR) + *error = IoWrapResult::Interrupted; + else + { + char *err = strerror(errno); + std::string msg(err); + throw std::runtime_error("SSL read error: " + msg); + } + } + ERR_error_string(error_code, sslErrorBuf); + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); + ERR_print_errors_cb(logSslError, NULL); + throw std::runtime_error("SSL socket error reading: " + errorString); + } + } + } + + return n; +} + // false means any kind of error we want to get rid of the client for. bool Client::readFdIntoBuffer() { if (disconnecting) return false; + IoWrapResult error = IoWrapResult::Success; int n = 0; - while (readbuf.freeSpace() > 0 && (n = read(fd, readbuf.headPtr(), readbuf.maxWriteSize())) != 0) + while (readbuf.freeSpace() > 0 && (n = readWrap(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0) { if (n > 0) { readbuf.advanceHead(n); } - if (n < 0) - { - if (errno == EINTR) - continue; - if (errno == EAGAIN || errno == EWOULDBLOCK) - break; - else - check(n); - } + if (error == IoWrapResult::Interrupted) + continue; + if (error == IoWrapResult::Wouldblock) + break; // Make sure we either always have enough space for a next call of this method, or stop reading the fd. if (readbuf.freeSpace() == 0) @@ -74,7 +206,7 @@ bool Client::readFdIntoBuffer() } } - if (n == 0) // client disconnected. + if (error == IoWrapResult::Disconnected) { return false; } @@ -174,6 +306,94 @@ void Client::writePingResp() setReadyForWriting(true); } +// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller. +ssize_t Client::writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error) +{ + *error = IoWrapResult::Success; + ssize_t n = 0; + + if (!ssl) + { + // A write on a socket with count=0 is unspecified. + assert(nbytes > 0); + + n = write(fd, buf, nbytes); + if (n < 0) + { + if (errno == EINTR) + *error = IoWrapResult::Interrupted; + else if (errno == EAGAIN || errno == EWOULDBLOCK) + *error = IoWrapResult::Wouldblock; + else + check(n); + } + } + else + { + const void *buf_ = buf; + size_t nbytes_ = nbytes; + + /* + * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments + */ + if (this->incompleteSslWrite.hasPendingWrite()) + { + buf_ = this->incompleteSslWrite.buf; + nbytes_ = this->incompleteSslWrite.nbytes; + } + + // OpenSSL: "You should not call SSL_write() with num=0, it will return an error" + assert(nbytes_ > 0); + + this->sslWriteWantsRead = false; + this->incompleteSslWrite.reset(); + + ERR_clear_error(); + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; + n = SSL_write(ssl, buf_, nbytes_); + + if (n <= 0) + { + int err = SSL_get_error(ssl, n); + unsigned long error_code = ERR_get_error(); + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) + { + logger->logf(LOG_DEBUG, "Write is incomplete: %d", err); + *error = IoWrapResult::Wouldblock; + IncompleteSslWrite sslAction(buf_, nbytes_); + this->incompleteSslWrite = sslAction; + if (err == SSL_ERROR_WANT_READ) + this->sslWriteWantsRead = true; + n = 0; + } + else + { + if (err == SSL_ERROR_SYSCALL) + { + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it + // implies EINTR is not included? + if (errno == EINTR) + *error = IoWrapResult::Interrupted; + else + { + char *err = strerror(errno); + std::string msg(err); + throw std::runtime_error(msg); + } + } + ERR_error_string(error_code, sslErrorBuf); + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); + ERR_print_errors_cb(logSslError, NULL); + throw std::runtime_error("SSL socket error writing: " + errorString); + } + } + } + + return n; +} + bool Client::writeBufIntoFd() { std::unique_lock lock(writeBufMutex, std::try_to_lock); @@ -184,24 +404,23 @@ bool Client::writeBufIntoFd() if (disconnecting) return false; + IoWrapResult error = IoWrapResult::Success; int n; - while ((n = write(fd, writebuf.tailPtr(), writebuf.maxReadSize())) != 0) + while (writebuf.usedBytes() > 0 || incompleteSslWrite.hasPendingWrite()) { + n = writeWrap(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error); + if (n > 0) writebuf.advanceTail(n); - if (n < 0) - { - if (errno == EINTR) - continue; - if (errno == EAGAIN || errno == EWOULDBLOCK) - break; - else - check(n); - } + + if (error == IoWrapResult::Interrupted) + continue; + if (error == IoWrapResult::Wouldblock) + break; } const bool bufferHasData = writebuf.usedBytes() > 0; - setReadyForWriting(bufferHasData); + setReadyForWriting(bufferHasData || error == IoWrapResult::Wouldblock); if (!bufferHasData) { @@ -236,11 +455,22 @@ bool Client::keepAliveExpired() return result; } +std::string Client::getKeepAliveInfoString() const +{ + std::string s = "authenticated: " + std::to_string(authenticated) + ", keep-alive: " + std::to_string(keepalive) + "s, last activity " + + std::to_string(time(NULL) - lastActivity) + " seconds ago."; + return s; +} + +// Call this from a place you know the writeBufMutex is locked, or we're still only doing SSL accept. void Client::setReadyForWriting(bool val) { if (disconnecting) return; + if (sslReadWantsWrite) + val = true; + if (val == this->readyForWriting) return; @@ -360,17 +590,29 @@ std::shared_ptr Client::getSession() return this->session; } +void Client::setDisconnectReason(const std::string &reason) +{ + // If we have a chain of errors causing this to be set, probably the first one is the most interesting. + if (!disconnectReason.empty()) + return; + this->disconnectReason = reason; +} +IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) : + buf(buf), + nbytes(nbytes) +{ +} +void IncompleteSslWrite::reset() +{ + buf = nullptr; + nbytes = 0; +} - - - - - - - - - +bool IncompleteSslWrite::hasPendingWrite() +{ + return buf != nullptr; +} diff --git a/client.h b/client.h index fc8ed03..32c4174 100644 --- a/client.h +++ b/client.h @@ -8,6 +8,8 @@ #include #include +#include + #include "forward_declarations.h" #include "threaddata.h" @@ -15,14 +17,50 @@ #include "exceptions.h" #include "cirbuf.h" +#include +#include #define CLIENT_BUFFER_SIZE 1024 // Must be power of 2 #define MAX_PACKET_SIZE 268435461 // 256 MB + 5 #define MQTT_HEADER_LENGH 2 +#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256. +#define OPENSSL_WRONG_VERSION_NUMBER 336130315 + +enum class IoWrapResult +{ + Success = 0, + Interrupted = 1, + Wouldblock = 2, + Disconnected = 3, + Error = 4 +}; + +/* + * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments" + */ +struct IncompleteSslWrite +{ + const void *buf = nullptr; + size_t nbytes = 0; + + IncompleteSslWrite() = default; + IncompleteSslWrite(const void *buf, size_t nbytes); + bool hasPendingWrite(); + + void reset(); +}; + +// TODO: give accepted addr, for showing in logs class Client { int fd; + SSL *ssl = nullptr; + bool sslAccepted = false; + IncompleteSslWrite incompleteSslWrite; + bool sslReadWantsWrite = false; + bool sslWriteWantsRead = false; CirBuf readbuf; uint8_t readBufIsZeroCount = 0; @@ -36,6 +74,7 @@ class Client bool readyForReading = true; bool disconnectWhenBytesWritten = false; bool disconnecting = false; + std::string disconnectReason; time_t lastActivity = time(NULL); std::string clientid; @@ -53,18 +92,27 @@ class Client std::shared_ptr session; + Logger *logger = Logger::getInstance(); + void setReadyForWriting(bool val); void setReadyForReading(bool val); public: - Client(int fd, ThreadData_p threadData); + Client(int fd, ThreadData_p threadData, SSL *ssl); Client(const Client &other) = delete; Client(Client &&other) = delete; ~Client(); int getFd() { return fd;} + bool isSslAccepted() const; + bool isSsl() const; + bool getSslReadWantsWrite() const; + bool getSslWriteWantsRead() const; + + void startOrContinueSslAccept(); void markAsDisconnecting(); + ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error); bool readFdIntoBuffer(); bool bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender); void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); @@ -77,10 +125,12 @@ public: bool getCleanSession() { return cleanSession; } void assignSession(std::shared_ptr &session); std::shared_ptr getSession(); + void setDisconnectReason(const std::string &reason); void writePingResp(); void writeMqttPacket(const MqttPacket &packet); void writeMqttPacketAndBlameThisClient(const MqttPacket &packet); + ssize_t writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error); bool writeBufIntoFd(); bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } @@ -89,6 +139,7 @@ public: std::string repr(); bool keepAliveExpired(); + std::string getKeepAliveInfoString() const; }; diff --git a/configfileparser.cpp b/configfileparser.cpp index ff34e46..b9cc8b7 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -5,10 +5,15 @@ #include #include "fstream" +#include "openssl/ssl.h" +#include "openssl/err.h" + #include "exceptions.h" #include "utils.h" #include +#include "logger.h" + mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value) { @@ -41,24 +46,66 @@ AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_mapauthPluginPath = value; + checkFileAccess(key, value); + if (!test) + this->authPluginPath = value; } if (key == "log_file") { - this->logPath = value; + checkFileAccess(key, value); + if (!test) + this->logPath = value; + } + + if (key == "fullchain") + { + checkFileAccess(key, value); + sslFullChainTmp = value; + } + + if (key == "privkey") + { + checkFileAccess(key, value); + sslPrivkeyTmp = value; + } + + try + { + // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners. + if (key == "listen_port") + { + uint listenportNew = std::stoi(value); + if (listenPort > 0 && listenPort != listenportNew) + throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time."); + listenPort = listenportNew; + } + + // TODO: make this possible. There are many error cases to deal with, like bind failures, etc. You don't want to end up without listeners. + if (key == "ssl_listen_port") + { + uint sslListenPortNew = std::stoi(value); + if (sslListenPort > 0 && sslListenPort != sslListenPortNew) + throw ConfigFileException("Changing (ssl_)listen_port is not supported at this time."); + sslListenPort = sslListenPortNew; + } + + } + catch (std::invalid_argument &ex) + { + throw ConfigFileException(ex.what()); } } } + testSsl(sslFullChainTmp, sslPrivkeyTmp, sslListenPort); + this->sslFullchain = sslFullChainTmp; + this->sslPrivkey = sslPrivkeyTmp; + authOptCompatWrap.reset(new AuthOptCompatWrap(authOpts)); } diff --git a/configfileparser.h b/configfileparser.h index ee8d072..abebdeb 100644 --- a/configfileparser.h +++ b/configfileparser.h @@ -7,6 +7,8 @@ #include #include +#include "sslctxmanager.h" + struct mosquitto_auth_opt { char *key = nullptr; @@ -38,15 +40,20 @@ class ConfigFileParser std::unique_ptr authOptCompatWrap; - + void checkFileAccess(const std::string &key, const std::string &pathToCheck) const; + void testSsl(const std::string &fullchain, const std::string &privkey, uint portNr) const; public: ConfigFileParser(const std::string &path); - void loadFile(); + void loadFile(bool test); AuthOptCompatWrap &getAuthOptsCompat(); // Actual config options with their defaults. Just making them public, I can retrain myself misuing them. std::string authPluginPath; std::string logPath; + std::string sslFullchain; + std::string sslPrivkey; + uint listenPort = 1883; + uint sslListenPort = 0; }; #endif // CONFIGFILEPARSER_H diff --git a/logger.cpp b/logger.cpp index de7c5d4..7b066cb 100644 --- a/logger.cpp +++ b/logger.cpp @@ -121,3 +121,9 @@ void Logger::logf(int level, const char *str, va_list valist) #endif } } + +int logSslError(const char *str, size_t len, void *u) +{ + Logger *logger = Logger::getInstance(); + logger->logf(LOG_ERR, str); +} diff --git a/logger.h b/logger.h index 18d13a6..6f167bc 100644 --- a/logger.h +++ b/logger.h @@ -13,6 +13,8 @@ #define LOG_ERR 0x08 #define LOG_DEBUG 0x10 +int logSslError(const char *str, size_t len, void *u); + class Logger { static Logger *instance; diff --git a/main.cpp b/main.cpp index c85e3fe..3cc0398 100644 --- a/main.cpp +++ b/main.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "mainapp.h" diff --git a/mainapp.cpp b/mainapp.cpp index d2f595c..e733685 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -5,6 +5,11 @@ #include #include +#include +#include + +#include "logger.h" + #define MAX_EVENTS 1024 #define NR_OF_THREADS 4 @@ -60,18 +65,30 @@ void do_thread_work(ThreadData *threadData) { try { - if (cur_ev.events & EPOLLIN) + if (cur_ev.events & (EPOLLERR | EPOLLHUP)) + { + client->setDisconnectReason("epoll says socket is in ERR or HUP state."); + threadData->removeClient(client); + continue; + } + if (client->isSsl() && !client->isSslAccepted()) + { + client->startOrContinueSslAccept(); + continue; + } + if ((cur_ev.events & EPOLLIN) || ((cur_ev.events & EPOLLOUT) && client->getSslReadWantsWrite())) { bool readSuccess = client->readFdIntoBuffer(); client->bufferToMqttPackets(packetQueueIn, client); if (!readSuccess) { + client->setDisconnectReason("socket disconnect detected"); threadData->removeClient(client); continue; } } - if (cur_ev.events & EPOLLOUT) + if ((cur_ev.events & EPOLLOUT) || ((cur_ev.events & EPOLLIN) && client->getSslWriteWantsRead())) { if (!client->writeBufIntoFd()) { @@ -85,13 +102,10 @@ void do_thread_work(ThreadData *threadData) continue; } } - if (cur_ev.events & (EPOLLERR | EPOLLHUP)) - { - threadData->removeClient(client); - } } catch(std::exception &ex) { + client->setDisconnectReason(ex.what()); logger->logf(LOG_ERR, "Packet read/write error: %s. Removing client.", ex.what()); threadData->removeClient(client); } @@ -107,6 +121,7 @@ void do_thread_work(ThreadData *threadData) } catch (std::exception &ex) { + packet.getSender()->setDisconnectReason(ex.what()); logger->logf(LOG_ERR, "MqttPacket handling error: %s. Removing client.", ex.what()); threadData->removeClient(packet.getSender()); } @@ -133,12 +148,22 @@ void do_thread_work(ThreadData *threadData) MainApp::MainApp(const std::string &configFilePath) : subscriptionStore(new SubscriptionStore()) { + epollFdAccept = check(epoll_create(999)); taskEventFd = eventfd(0, EFD_NONBLOCK); confFileParser.reset(new ConfigFileParser(configFilePath)); loadConfig(); } +MainApp::~MainApp() +{ + if (sslctx) + SSL_CTX_free(sslctx); + + if (epollFdAccept > 0) + close(epollFdAccept); +} + void MainApp::doHelp(const char *arg) { puts("FlashMQ - the scalable light-weight MQTT broker"); @@ -147,6 +172,7 @@ void MainApp::doHelp(const char *arg) puts(""); puts(" -h, --help Print help"); puts(" -c, --config-file Configuration file."); + puts(" -t, --test-config Test configuration file."); puts(" -V, --version Show version"); puts(" -l, --license Show license"); } @@ -161,6 +187,56 @@ void MainApp::showLicense() puts("Author: Wiebe Cazemier "); } +void MainApp::setCertAndKeyFromConfig() +{ + if (sslctx == nullptr) + return; + + if (SSL_CTX_use_certificate_file(sslctx, confFileParser->sslFullchain.c_str(), SSL_FILETYPE_PEM) != 1) + throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected."); + if (SSL_CTX_use_PrivateKey_file(sslctx, confFileParser->sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1) + throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected."); +} + +int MainApp::createListenSocket(int portNr, bool ssl) +{ + if (portNr <= 0) + return -2; + + int listen_fd = check(socket(AF_INET, SOCK_STREAM, 0)); + + // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. + int optval = 1; + check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); + + int flags = fcntl(listen_fd, F_GETFL); + check(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); + + struct sockaddr_in in_addr_plain; + in_addr_plain.sin_family = AF_INET; + in_addr_plain.sin_addr.s_addr = INADDR_ANY; + in_addr_plain.sin_port = htons(portNr); + + check(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in))); + check(listen(listen_fd, 1024)); + + struct epoll_event ev; + memset(&ev, 0, sizeof (struct epoll_event)); + + ev.data.fd = listen_fd; + ev.events = EPOLLIN; + check(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev)); + + std::string socketType = "plain"; + + if (ssl) + socketType = "SSL"; + + logger->logf(LOG_NOTICE, "Listening on %s port %d", socketType.c_str(), portNr); + + return listen_fd; +} + void MainApp::initMainApp(int argc, char *argv[]) { if (instance != nullptr) @@ -170,6 +246,7 @@ void MainApp::initMainApp(int argc, char *argv[]) { {"help", no_argument, nullptr, 'h'}, {"config-file", required_argument, nullptr, 'c'}, + {"test-config", no_argument, nullptr, 't'}, {"version", no_argument, nullptr, 'V'}, {"license", no_argument, nullptr, 'l'}, {nullptr, 0, nullptr, 0} @@ -179,7 +256,8 @@ void MainApp::initMainApp(int argc, char *argv[]) int option_index = 0; int opt; - while((opt = getopt_long(argc, argv, "hc:Vl", long_options, &option_index)) != -1) + bool testConfig = false; + while((opt = getopt_long(argc, argv, "hc:Vlt", long_options, &option_index)) != -1) { switch(opt) { @@ -195,12 +273,38 @@ void MainApp::initMainApp(int argc, char *argv[]) case 'h': MainApp::doHelp(argv[0]); exit(16); + case 't': + testConfig = true; + break; case '?': MainApp::doHelp(argv[0]); exit(16); } } + if (testConfig) + { + try + { + if (configFile.empty()) + { + std::cerr << "No config specified." << std::endl; + MainApp::doHelp(argv[0]); + exit(1); + } + + ConfigFileParser c(configFile); + c.loadFile(true); + puts("Config OK"); + exit(0); + } + catch (ConfigFileException &ex) + { + std::cerr << ex.what() << std::endl; + exit(1); + } + } + instance = new MainApp(configFile); } @@ -214,40 +318,14 @@ MainApp *MainApp::getMainApp() void MainApp::start() { - Logger *logger = Logger::getInstance(); - - int listen_fd = check(socket(AF_INET, SOCK_STREAM, 0)); + int listen_fd_plain = createListenSocket(this->listenPort, false); + int listen_fd_ssl = createListenSocket(this->sslListenPort, true); - // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. - int optval = 1; - check(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); - - int flags = fcntl(listen_fd, F_GETFL); - check(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); - - struct sockaddr_in in_addr; - in_addr.sin_family = AF_INET; - in_addr.sin_addr.s_addr = INADDR_ANY; - in_addr.sin_port = htons(1883); - - check(bind(listen_fd, (struct sockaddr *)(&in_addr), sizeof(struct sockaddr_in))); - check(listen(listen_fd, 1024)); - - int epoll_fd_accept = check(epoll_create(999)); - - struct epoll_event events[MAX_EVENTS]; struct epoll_event ev; memset(&ev, 0, sizeof (struct epoll_event)); - memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); - - ev.data.fd = listen_fd; - ev.events = EPOLLIN; - check(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, listen_fd, &ev)); - - memset(&ev, 0, sizeof (struct epoll_event)); ev.data.fd = taskEventFd; ev.events = EPOLLIN; - check(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, taskEventFd, &ev)); + check(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, taskEventFd, &ev)); for (int i = 0; i < NR_OF_THREADS; i++) { @@ -256,14 +334,15 @@ void MainApp::start() threads.push_back(t); } - logger->logf(LOG_NOTICE, "Listening on port 1883"); - uint next_thread_index = 0; + struct epoll_event events[MAX_EVENTS]; + memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); + started = true; while (running) { - int num_fds = epoll_wait(epoll_fd_accept, events, MAX_EVENTS, 100); + int num_fds = epoll_wait(this->epollFdAccept, events, MAX_EVENTS, 100); if (num_fds < 0) { @@ -277,7 +356,7 @@ void MainApp::start() int cur_fd = events[i].data.fd; try { - if (cur_fd == listen_fd) + if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl) { std::shared_ptr thread_data = threads[next_thread_index++ % NR_OF_THREADS]; @@ -288,7 +367,22 @@ void MainApp::start() socklen_t len = sizeof(struct sockaddr); int fd = check(accept(cur_fd, &addr, &len)); - Client_p client(new Client(fd, thread_data)); + SSL *clientSSL = nullptr; + if (cur_fd == listen_fd_ssl) + { + clientSSL = SSL_new(sslctx); + + if (clientSSL == NULL) + { + logger->logf(LOG_ERR, "Problem creating SSL object. Closing client."); + close(fd); + continue; + } + + SSL_set_fd(clientSSL, fd); + } + + Client_p client(new Client(fd, thread_data, clientSSL)); thread_data->giveClient(client); } else if (cur_fd == taskEventFd) @@ -320,7 +414,8 @@ void MainApp::start() thread->quit(); } - close(listen_fd); + close(listen_fd_plain); + close(listen_fd_ssl); } void MainApp::quit() @@ -330,12 +425,29 @@ void MainApp::quit() running = false; } +// Loaded on app start where you want it to crash, loaded from within try/catch on reload, to allow the program to continue. void MainApp::loadConfig() { Logger *logger = Logger::getInstance(); - confFileParser->loadFile(); + + // Atomic loading, first test. + confFileParser->loadFile(true); + confFileParser->loadFile(false); + logger->setLogPath(confFileParser->logPath); logger->reOpen(); + + listenPort = confFileParser->listenPort; + sslListenPort = confFileParser->sslListenPort; + + if (sslctx == nullptr && sslListenPort > 0) + { + sslctx = SSL_CTX_new(TLS_server_method()); + SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv3); // TODO: config option + SSL_CTX_set_options(sslctx, SSL_OP_NO_TLSv1); // TODO: config option + } + + setCertAndKeyFromConfig(); } void MainApp::reloadConfig() diff --git a/mainapp.h b/mainapp.h index 9a1ea46..6a6573e 100644 --- a/mainapp.h +++ b/mainapp.h @@ -30,23 +30,34 @@ class MainApp std::shared_ptr subscriptionStore; std::unique_ptr confFileParser; std::forward_list> taskQueue; + int epollFdAccept = -1; int taskEventFd = -1; std::mutex eventMutex; + uint listenPort = 0; + uint sslListenPort = 0; + SSL_CTX *sslctx = nullptr; + + Logger *logger = Logger::getInstance(); + void loadConfig(); void reloadConfig(); static void doHelp(const char *arg); static void showLicense(); + void setCertAndKeyFromConfig(); + int createListenSocket(int portNr, bool ssl); MainApp(const std::string &configFilePath); public: MainApp(const MainApp &rhs) = delete; MainApp(MainApp &&rhs) = delete; + ~MainApp(); static MainApp *getMainApp(); static void initMainApp(int argc, char *argv[]); void start(); void quit(); bool getStarted() const {return started;} + static void testConfig(); void queueConfigReload(); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 184170b..b4c6ede 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -281,6 +281,7 @@ void MqttPacket::handleDisconnect() { logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str()); sender->markAsDisconnecting(); + sender->setDisconnectReason("MQTT Disconnect received."); sender->getThreadData()->removeClient(sender); // TODO: clear will diff --git a/sslctxmanager.cpp b/sslctxmanager.cpp new file mode 100644 index 0000000..7519f44 --- /dev/null +++ b/sslctxmanager.cpp @@ -0,0 +1,18 @@ +#include "sslctxmanager.h" + +SslCtxManager::SslCtxManager() : + ssl_ctx(SSL_CTX_new(TLS_server_method())) +{ + +} + +SslCtxManager::~SslCtxManager() +{ + if (ssl_ctx) + SSL_CTX_free(ssl_ctx); +} + +SSL_CTX *SslCtxManager::get() const +{ + return ssl_ctx; +} diff --git a/sslctxmanager.h b/sslctxmanager.h new file mode 100644 index 0000000..bd5c227 --- /dev/null +++ b/sslctxmanager.h @@ -0,0 +1,16 @@ +#ifndef SSLCTXMANAGER_H +#define SSLCTXMANAGER_H + +#include "openssl/ssl.h" + +class SslCtxManager +{ + SSL_CTX *ssl_ctx = nullptr; +public: + SslCtxManager(); + ~SslCtxManager(); + + SSL_CTX *get() const; +}; + +#endif // SSLCTXMANAGER_H diff --git a/threaddata.cpp b/threaddata.cpp index 5816633..d9d3a5d 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -87,6 +87,7 @@ bool ThreadData::doKeepAliveCheck() Client_p &client = it->second; if (client && client->keepAliveExpired()) { + client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString()); it = clients_by_fd.erase(it); } else -- libgit2 0.21.4