diff --git a/CMakeLists.txt b/CMakeLists.txt index ad1bfa7..fb94126 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,7 @@ add_executable(FlashMQ sslctxmanager.cpp timer.cpp globalsettings.cpp + iowrapper.cpp ) target_link_libraries(FlashMQ pthread dl ssl crypto) diff --git a/cirbuf.cpp b/cirbuf.cpp index 51279ba..3ea3bce 100644 --- a/cirbuf.cpp +++ b/cirbuf.cpp @@ -9,10 +9,14 @@ #include #include "logger.h" +#include "utils.h" CirBuf::CirBuf(size_t size) : size(size) { + if (size == 0) + return; + buf = (char*)malloc(size); if (buf == NULL) @@ -69,12 +73,14 @@ char *CirBuf::tailPtr() void CirBuf::advanceHead(uint32_t n) { + assert(n <= freeSpace()); head = (head + n) & (size -1); assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty. } void CirBuf::advanceTail(uint32_t n) { + assert(n <= usedBytes()); tail = (tail + n) & (size -1); } @@ -84,9 +90,31 @@ char CirBuf::peakAhead(uint32_t offset) const return b; } -void CirBuf::doubleSize() +void CirBuf::ensureFreeSpace(size_t n) +{ + const size_t _usedBytes = usedBytes(); + + int mul = 1; + + while((mul * size - _usedBytes - 1) < n) + { + mul = mul << 1; + } + + if (mul == 1) + return; + + doubleSize(mul); +} + +void CirBuf::doubleSize(uint factor) { - uint newSize = size * 2; + if (factor == 1) + return; + + assert(isPowerOfTwo(factor)); + + uint newSize = size * factor; char *newBuf = (char*)realloc(buf, newSize); if (newBuf == NULL) @@ -145,3 +173,55 @@ void CirBuf::resetSize(size_t newSize) memset(buf, 0, newSize); #endif } + +void CirBuf::reset() +{ + head = 0; + tail = 0; + +#ifndef NDEBUG + memset(buf, 0, size); +#endif +} + +// When you know the data you want to write fits in the buffer, use this. +void CirBuf::write(const void *buf, size_t count) +{ + assert(count <= freeSpace()); + + ssize_t len_left = count; + size_t src_i = 0; + while (len_left > 0) + { + const size_t len = std::min(len_left, maxWriteSize()); + assert(len > 0); + const char *src = &reinterpret_cast(buf)[src_i]; + std::memcpy(headPtr(), src, len); + advanceHead(len); + src_i += len; + len_left -= len; + } + assert(len_left == 0); + assert(src_i == count); +} + +// When you know 'count' bytes are present and you want to read them into buf +void CirBuf::read(void *buf, const size_t count) +{ + assert(count <= usedBytes()); + + char *_buf = static_cast(buf); + int i = 0; + ssize_t _packet_len = count; + while (_packet_len > 0) + { + const int readlen = std::min(maxReadSize(), _packet_len); + assert(readlen > 0); + std::memcpy(&_buf[i], tailPtr(), readlen); + advanceTail(readlen); + i += readlen; + _packet_len -= readlen; + } + assert(_packet_len == 0); + assert(i == _packet_len); +} diff --git a/cirbuf.h b/cirbuf.h index b01ee76..d2eca01 100644 --- a/cirbuf.h +++ b/cirbuf.h @@ -32,11 +32,16 @@ public: void advanceHead(uint32_t n); void advanceTail(uint32_t n); char peakAhead(uint32_t offset) const; - void doubleSize(); + void ensureFreeSpace(size_t n); + void doubleSize(uint factor = 2); uint32_t getSize() const; time_t bufferLastResizedSecondsAgo() const; void resetSize(size_t size); + void reset(); + + void write(const void *buf, size_t count); + void read(void *buf, const size_t count); }; #endif // CIRBUF_H diff --git a/client.cpp b/client.cpp index 0f539d6..a7ed216 100644 --- a/client.cpp +++ b/client.cpp @@ -7,11 +7,11 @@ #include "logger.h" -Client::Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings) : +Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings) : fd(fd), - ssl(ssl), initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment. + ioWrapper(ssl, websocket, initialBufferSize, this), readbuf(initialBufferSize), writebuf(initialBufferSize), threadData(threadData) @@ -40,63 +40,32 @@ Client::~Client() logger->logf(LOG_NOTICE, "Removing client '%s'. Reason(s): %s", repr().c_str(), disconnectReason.c_str()); if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); - if (ssl) - { - // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done - // in the destructor. - SSL_free(ssl); - } close(fd); } bool Client::isSslAccepted() const { - return sslAccepted; + return ioWrapper.isSslAccepted(); } bool Client::isSsl() const { - return this->ssl != nullptr; + return ioWrapper.isSsl(); } bool Client::getSslReadWantsWrite() const { - return this->sslReadWantsWrite; + return ioWrapper.getSslReadWantsWrite(); } bool Client::getSslWriteWantsRead() const { - return this->sslWriteWantsRead; + return ioWrapper.getSslWriteWantsRead(); } void Client::startOrContinueSslAccept() { - ERR_clear_error(); - int accepted = SSL_accept(ssl); - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; - if (accepted <= 0) - { - int err = SSL_get_error(ssl, accepted); - - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) - { - setReadyForWriting(err == SSL_ERROR_WANT_WRITE); - return; - } - - unsigned long error_code = ERR_get_error(); - - ERR_error_string(error_code, sslErrorBuf); - std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); - - if (error_code == OPENSSL_WRONG_VERSION_NUMBER) - errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket."; - - //ERR_print_errors_cb(logSslError, NULL); - throw std::runtime_error("Problem accepting SSL socket: " + errorMsg); - } - setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake - sslAccepted = true; + ioWrapper.startOrContinueSslAccept(); } // Causes future activity on the client to cause a disconnect. @@ -108,83 +77,6 @@ void Client::markAsDisconnecting() disconnecting = true; } -// SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL -// socket. This wrapper unifies behavor for the caller. -ssize_t Client::readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error) -{ - *error = IoWrapResult::Success; - ssize_t n = 0; - if (!ssl) - { - n = read(fd, buf, nbytes); - if (n < 0) - { - if (errno == EINTR) - *error = IoWrapResult::Interrupted; - else if (errno == EAGAIN || errno == EWOULDBLOCK) - *error = IoWrapResult::Wouldblock; - else - check(n); - } - else if (n == 0) - { - *error = IoWrapResult::Disconnected; - } - } - else - { - this->sslReadWantsWrite = false; - ERR_clear_error(); - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; - n = SSL_read(ssl, buf, nbytes); - - if (n <= 0) - { - int err = SSL_get_error(ssl, n); - unsigned long error_code = ERR_get_error(); - - // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL. - if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0)) - { - *error = IoWrapResult::Disconnected; - } - else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) - { - *error = IoWrapResult::Wouldblock; - if (err == SSL_ERROR_WANT_WRITE) - { - sslReadWantsWrite = true; - setReadyForWriting(true); - } - n = -1; - } - else - { - if (err == SSL_ERROR_SYSCALL) - { - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it - // implies EINTR is not included? - if (errno == EINTR) - *error = IoWrapResult::Interrupted; - else - { - char *err = strerror(errno); - std::string msg(err); - throw std::runtime_error("SSL read error: " + msg); - } - } - ERR_error_string(error_code, sslErrorBuf); - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); - ERR_print_errors_cb(logSslError, NULL); - throw std::runtime_error("SSL socket error reading: " + errorString); - } - } - } - - return n; -} - // false means any kind of error we want to get rid of the client for. bool Client::readFdIntoBuffer() { @@ -193,7 +85,7 @@ bool Client::readFdIntoBuffer() IoWrapResult error = IoWrapResult::Success; int n = 0; - while (readbuf.freeSpace() > 0 && (n = readWrap(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0) + while (readbuf.freeSpace() > 0 && (n = ioWrapper.readWebsocketAndOrSsl(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0) { if (n > 0) { @@ -232,6 +124,20 @@ bool Client::readFdIntoBuffer() return true; } +void Client::writeText(const std::string &text) +{ + assert(ioWrapper.isWebsocket()); + assert(ioWrapper.getWebsocketState() == WebsocketState::NotUpgraded); + + // Not necessary, because at this point, no other threads write to this client, but including for clarity. + std::lock_guard locker(writeBufMutex); + + writebuf.ensureFreeSpace(text.size()); + writebuf.write(text.c_str(), text.length()); + + setReadyForWriting(true); +} + void Client::writeMqttPacket(const MqttPacket &packet) { std::lock_guard locker(writeBufMutex); @@ -322,94 +228,6 @@ void Client::writePingResp() setReadyForWriting(true); } -// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller. -ssize_t Client::writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error) -{ - *error = IoWrapResult::Success; - ssize_t n = 0; - - if (!ssl) - { - // A write on a socket with count=0 is unspecified. - assert(nbytes > 0); - - n = write(fd, buf, nbytes); - if (n < 0) - { - if (errno == EINTR) - *error = IoWrapResult::Interrupted; - else if (errno == EAGAIN || errno == EWOULDBLOCK) - *error = IoWrapResult::Wouldblock; - else - check(n); - } - } - else - { - const void *buf_ = buf; - size_t nbytes_ = nbytes; - - /* - * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments - */ - if (this->incompleteSslWrite.hasPendingWrite()) - { - buf_ = this->incompleteSslWrite.buf; - nbytes_ = this->incompleteSslWrite.nbytes; - } - - // OpenSSL: "You should not call SSL_write() with num=0, it will return an error" - assert(nbytes_ > 0); - - this->sslWriteWantsRead = false; - this->incompleteSslWrite.reset(); - - ERR_clear_error(); - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; - n = SSL_write(ssl, buf_, nbytes_); - - if (n <= 0) - { - int err = SSL_get_error(ssl, n); - unsigned long error_code = ERR_get_error(); - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) - { - logger->logf(LOG_DEBUG, "Write is incomplete: %d", err); - *error = IoWrapResult::Wouldblock; - IncompleteSslWrite sslAction(buf_, nbytes_); - this->incompleteSslWrite = sslAction; - if (err == SSL_ERROR_WANT_READ) - this->sslWriteWantsRead = true; - n = 0; - } - else - { - if (err == SSL_ERROR_SYSCALL) - { - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it - // implies EINTR is not included? - if (errno == EINTR) - *error = IoWrapResult::Interrupted; - else - { - char *err = strerror(errno); - std::string msg(err); - throw std::runtime_error(msg); - } - } - ERR_error_string(error_code, sslErrorBuf); - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); - ERR_print_errors_cb(logSslError, NULL); - throw std::runtime_error("SSL socket error writing: " + errorString); - } - } - } - - return n; -} - bool Client::writeBufIntoFd() { std::unique_lock lock(writeBufMutex, std::try_to_lock); @@ -422,9 +240,9 @@ bool Client::writeBufIntoFd() IoWrapResult error = IoWrapResult::Success; int n; - while (writebuf.usedBytes() > 0 || incompleteSslWrite.hasPendingWrite()) + while (writebuf.usedBytes() > 0 || ioWrapper.hasPendingWrite()) { - n = writeWrap(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error); + n = ioWrapper.writeWebsocketAndOrSsl(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error); if (n > 0) writebuf.advanceTail(n); @@ -484,7 +302,7 @@ void Client::setReadyForWriting(bool val) if (disconnecting) return; - if (sslReadWantsWrite) + if (ioWrapper.getSslReadWantsWrite()) val = true; if (val == this->readyForWriting) @@ -622,20 +440,3 @@ void Client::clearWill() will_qos = 0; } -IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) : - buf(buf), - nbytes(nbytes) -{ - -} - -void IncompleteSslWrite::reset() -{ - buf = nullptr; - nbytes = 0; -} - -bool IncompleteSslWrite::hasPendingWrite() -{ - return buf != nullptr; -} diff --git a/client.h b/client.h index 5c06253..27da06e 100644 --- a/client.h +++ b/client.h @@ -9,6 +9,7 @@ #include #include +#include #include "forward_declarations.h" @@ -17,54 +18,25 @@ #include "exceptions.h" #include "cirbuf.h" #include "types.h" - -#include -#include +#include "iowrapper.h" #define MQTT_HEADER_LENGH 2 -#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256. -#define OPENSSL_WRONG_VERSION_NUMBER 336130315 - -enum class IoWrapResult -{ - Success = 0, - Interrupted = 1, - Wouldblock = 2, - Disconnected = 3, - Error = 4 -}; - -/* - * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments" - */ -struct IncompleteSslWrite -{ - const void *buf = nullptr; - size_t nbytes = 0; - - IncompleteSslWrite() = default; - IncompleteSslWrite(const void *buf, size_t nbytes); - bool hasPendingWrite(); - - void reset(); -}; // TODO: give accepted addr, for showing in logs class Client { + friend class IoWrapper; + int fd; - SSL *ssl = nullptr; - bool sslAccepted = false; - IncompleteSslWrite incompleteSslWrite; - bool sslReadWantsWrite = false; - bool sslWriteWantsRead = false; + ProtocolVersion protocolVersion = ProtocolVersion::None; const size_t initialBufferSize = 0; const size_t maxPacketSize = 0; + IoWrapper ioWrapper; + CirBuf readbuf; uint8_t readBufIsZeroCount = 0; @@ -97,12 +69,11 @@ class Client Logger *logger = Logger::getInstance(); - void setReadyForWriting(bool val); void setReadyForReading(bool val); public: - Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings); + Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings); Client(const Client &other) = delete; Client(Client &&other) = delete; ~Client(); @@ -115,7 +86,6 @@ public: void startOrContinueSslAccept(); void markAsDisconnecting(); - ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error); bool readFdIntoBuffer(); bool bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender); void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); @@ -131,10 +101,10 @@ public: std::shared_ptr getSession(); void setDisconnectReason(const std::string &reason); + void writeText(const std::string &text); void writePingResp(); void writeMqttPacket(const MqttPacket &packet); void writeMqttPacketAndBlameThisClient(const MqttPacket &packet); - ssize_t writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error); bool writeBufIntoFd(); bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } diff --git a/exceptions.h b/exceptions.h index 563ce89..a27595d 100644 --- a/exceptions.h +++ b/exceptions.h @@ -34,4 +34,16 @@ public: AuthPluginException(const std::string &msg) : std::runtime_error(msg) {} }; +class BadWebsocketVersionException : public std::runtime_error +{ +public: + BadWebsocketVersionException(const std::string &msg) : std::runtime_error(msg) {} +}; + +class BadHttpRequest : public std::runtime_error +{ +public: + BadHttpRequest(const std::string &msg) : std::runtime_error(msg) {} +}; + #endif // EXCEPTIONS_H diff --git a/iowrapper.cpp b/iowrapper.cpp new file mode 100644 index 0000000..d882e69 --- /dev/null +++ b/iowrapper.cpp @@ -0,0 +1,595 @@ +#include "iowrapper.h" + +#include "cassert" + +#include "logger.h" +#include "client.h" + +IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) : + buf(buf), + nbytes(nbytes) +{ + +} + +void IncompleteSslWrite::reset() +{ + buf = nullptr; + nbytes = 0; +} + +bool IncompleteSslWrite::hasPendingWrite() const +{ + return buf != nullptr; +} + +void IncompleteWebsocketRead::reset() +{ + maskingKeyI = 0; + memset(maskingKey,0, 4); + frame_bytes_left = 0; + opcode = WebsocketOpcode::Unknown; +} + +bool IncompleteWebsocketRead::sillWorkingOnFrame() const +{ + return frame_bytes_left > 0; +} + +char IncompleteWebsocketRead::getNextMaskingByte() +{ + return maskingKey[maskingKeyI++ % 4]; +} + +IoWrapper::IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent) : + parentClient(parent), + initialBufferSize(initialBufferSize), + ssl(ssl), + websocket(websocket), + websocketPendingBytes(websocket ? initialBufferSize : 0), + websocketWriteRemainder(websocket ? initialBufferSize : 0) +{ + +} + +IoWrapper::~IoWrapper() +{ + if (ssl) + { + // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done + // in the destructor. + SSL_free(ssl); + } +} + +void IoWrapper::startOrContinueSslAccept() +{ + ERR_clear_error(); + int accepted = SSL_accept(ssl); + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; + if (accepted <= 0) + { + int err = SSL_get_error(ssl, accepted); + + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) + { + parentClient->setReadyForWriting(err == SSL_ERROR_WANT_WRITE); + return; + } + + unsigned long error_code = ERR_get_error(); + + ERR_error_string(error_code, sslErrorBuf); + std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); + + if (error_code == OPENSSL_WRONG_VERSION_NUMBER) + errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket."; + + //ERR_print_errors_cb(logSslError, NULL); + throw std::runtime_error("Problem accepting SSL socket: " + errorMsg); + } + parentClient->setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake + sslAccepted = true; +} + +bool IoWrapper::getSslReadWantsWrite() const +{ + return this->sslReadWantsWrite; +} + +bool IoWrapper::getSslWriteWantsRead() const +{ + return sslWriteWantsRead; +} + +bool IoWrapper::isSslAccepted() const +{ + return this->sslAccepted; +} + +bool IoWrapper::isSsl() const +{ + return this->ssl != nullptr; +} + +bool IoWrapper::hasPendingWrite() const +{ + return incompleteSslWrite.hasPendingWrite() || websocketWriteRemainder.usedBytes() > 0; +} + +bool IoWrapper::isWebsocket() const +{ + return websocket; +} + +WebsocketState IoWrapper::getWebsocketState() const +{ + return websocketState; +} + +/** + * @brief SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL + * socket. This wrapper unifies behavor for the caller. + * + * @param fd + * @param buf + * @param nbytes + * @param error is an out-argument with the result. + * @return + */ +ssize_t IoWrapper::readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error) +{ + *error = IoWrapResult::Success; + ssize_t n = 0; + if (!ssl) + { + n = read(fd, buf, nbytes); + if (n < 0) + { + if (errno == EINTR) + *error = IoWrapResult::Interrupted; + else if (errno == EAGAIN || errno == EWOULDBLOCK) + *error = IoWrapResult::Wouldblock; + else + check(n); + } + else if (n == 0) + { + *error = IoWrapResult::Disconnected; + } + } + else + { + this->sslReadWantsWrite = false; + ERR_clear_error(); + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; + n = SSL_read(ssl, buf, nbytes); + + if (n <= 0) + { + int err = SSL_get_error(ssl, n); + unsigned long error_code = ERR_get_error(); + + // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL. + if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0)) + { + *error = IoWrapResult::Disconnected; + } + else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) + { + *error = IoWrapResult::Wouldblock; + if (err == SSL_ERROR_WANT_WRITE) + { + sslReadWantsWrite = true; + parentClient->setReadyForWriting(true); + } + n = -1; + } + else + { + if (err == SSL_ERROR_SYSCALL) + { + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it + // implies EINTR is not included? + if (errno == EINTR) + *error = IoWrapResult::Interrupted; + else + { + char *err = strerror(errno); + std::string msg(err); + throw std::runtime_error("SSL read error: " + msg); + } + } + ERR_error_string(error_code, sslErrorBuf); + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); + ERR_print_errors_cb(logSslError, NULL); + throw std::runtime_error("SSL socket error reading: " + errorString); + } + } + } + + return n; +} + +// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller. +ssize_t IoWrapper::writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error) +{ + *error = IoWrapResult::Success; + ssize_t n = 0; + + if (!ssl) + { + // A write on a socket with count=0 is unspecified. + assert(nbytes > 0); + + n = write(fd, buf, nbytes); + if (n < 0) + { + if (errno == EINTR) + *error = IoWrapResult::Interrupted; + else if (errno == EAGAIN || errno == EWOULDBLOCK) + *error = IoWrapResult::Wouldblock; + else + check(n); + } + } + else + { + const void *buf_ = buf; + size_t nbytes_ = nbytes; + + /* + * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments + */ + if (this->incompleteSslWrite.hasPendingWrite()) + { + buf_ = this->incompleteSslWrite.buf; + nbytes_ = this->incompleteSslWrite.nbytes; + } + + // OpenSSL: "You should not call SSL_write() with num=0, it will return an error" + assert(nbytes_ > 0); + + this->sslWriteWantsRead = false; + this->incompleteSslWrite.reset(); + + ERR_clear_error(); + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; + n = SSL_write(ssl, buf_, nbytes_); + + if (n <= 0) + { + int err = SSL_get_error(ssl, n); + unsigned long error_code = ERR_get_error(); + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) + { + logger->logf(LOG_DEBUG, "Write is incomplete: %d", err); + *error = IoWrapResult::Wouldblock; + IncompleteSslWrite sslAction(buf_, nbytes_); + this->incompleteSslWrite = sslAction; + if (err == SSL_ERROR_WANT_READ) + this->sslWriteWantsRead = true; + n = 0; + } + else + { + if (err == SSL_ERROR_SYSCALL) + { + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it + // implies EINTR is not included? + if (errno == EINTR) + *error = IoWrapResult::Interrupted; + else + { + char *err = strerror(errno); + std::string msg(err); + throw std::runtime_error(msg); + } + } + ERR_error_string(error_code, sslErrorBuf); + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); + ERR_print_errors_cb(logSslError, NULL); + throw std::runtime_error("SSL socket error writing: " + errorString); + } + } + } + + return n; +} + +// Use a small intermediate buffer to write (partial) websocket frames to our normal read buffer. MQTT is already a frames protocol, so we don't +// care about websocket frames being incomplete. +ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error) +{ + if (!websocket) + { + return readOrSslRead(fd, buf, nbytes, error); + } + else + { + ssize_t n = 0; + while (websocketPendingBytes.freeSpace() > 0 && (n = readOrSslRead(fd, websocketPendingBytes.headPtr(), websocketPendingBytes.maxWriteSize(), error)) != 0) + { + if (n > 0) + websocketPendingBytes.advanceHead(n); + if (n < 0) + break; // signal/error handling is done by the caller, so we just stop. + + if (websocketState == WebsocketState::NotUpgraded && websocketPendingBytes.freeSpace() == 0) + { + if (websocketPendingBytes.getSize() * 2 <= 8192) + websocketPendingBytes.doubleSize(); + else + throw ProtocolError("Trying to exceed websocket buffer. Probably not valid websocket traffic."); + } + } + + const bool hasWebsocketPendingBytes = websocketPendingBytes.usedBytes() > 0; + + // When some or all the data has been read, we can continue. + if (!(*error == IoWrapResult::Wouldblock || *error == IoWrapResult::Success) && !hasWebsocketPendingBytes) + return n; + + if (hasWebsocketPendingBytes) + { + n = 0; + + if (websocketState == WebsocketState::NotUpgraded) + { + try + { + std::string websocketKey; + int websocketVersion; + if (parseHttpHeader(websocketPendingBytes, websocketKey, websocketVersion)) + { + if (websocketKey.empty()) + throw BadHttpRequest("No websocket key specified."); + if (websocketVersion != 13) + throw BadWebsocketVersionException("Websocket version 13 required."); + + const std::string acceptString = generateWebsocketAcceptString(websocketKey); + + std::string answer = generateWebsocketAnswer(acceptString); + parentClient->writeText(answer); + websocketState = WebsocketState::Upgrading; + websocketPendingBytes.reset(); + websocketPendingBytes.resetSize(initialBufferSize); + *error = IoWrapResult::Success; + } + } + catch (BadWebsocketVersionException &ex) + { + std::string response = generateInvalidWebsocketVersionHttpHeaders(13); + parentClient->writeText(response); + parentClient->setDisconnectReason("Invalid websocket version"); + parentClient->setReadyForDisconnect(); + } + catch (BadHttpRequest &ex) // Should should also properly deal with attempt at HTTP2 with PRI. + { + std::string response = generateBadHttpRequestReponse(ex.what()); + parentClient->writeText(response); + parentClient->setDisconnectReason("Invalid websocket start"); + parentClient->setReadyForDisconnect(); + } + } + else + { + n = websocketBytesToReadBuffer(buf, nbytes); + + if (n > 0) + *error = IoWrapResult::Success; + else if (n == 0) + *error = IoWrapResult::Wouldblock; + } + } + + return n; + } +} + +ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes) +{ + const ssize_t targetBufMaxSize = nbytes; + ssize_t nbytesRead = 0; + + while (websocketPendingBytes.usedBytes() >= WEBSOCKET_MIN_HEADER_BYTES_NEEDED && nbytesRead < targetBufMaxSize) + { + // This block decodes the header. + if (!incompleteWebsocketRead.sillWorkingOnFrame()) + { + const uint8_t byte1 = websocketPendingBytes.peakAhead(0); + const uint8_t byte2 = websocketPendingBytes.peakAhead(1); + bool masked = !!(byte2 & 0b10000000); + uint8_t reserved = (byte1 & 0b01110000) >> 4; + WebsocketOpcode opcode = (WebsocketOpcode)(byte1 & 0b00001111); + const uint8_t payloadLength = byte2 & 0b01111111; + size_t realPayloadLength = payloadLength; + uint64_t extendedPayloadLengthLength = 0; + uint8_t headerLength = masked ? 6 : 2; + + if (payloadLength == 126) + extendedPayloadLengthLength = 2; + else if (payloadLength == 127) + extendedPayloadLengthLength = 8; + headerLength += extendedPayloadLengthLength; + + //if (!masked) + // throw ProtocolError("Client must send masked websocket bytes."); + + if (reserved != 0) + throw ProtocolError("Reserved bytes in header must be 0."); + + if (headerLength > websocketPendingBytes.usedBytes()) + return nbytesRead; + + uint64_t extendedPayloadLength = 0; + + int i = 2; + int shift = extendedPayloadLengthLength * 8; + while (shift > 0) + { + shift -= 8; + uint8_t byte = websocketPendingBytes.peakAhead(i++); + extendedPayloadLength += (byte << shift); + } + + if (extendedPayloadLength > 0) + realPayloadLength = extendedPayloadLength; + + if (headerLength > websocketPendingBytes.usedBytes()) + return nbytesRead; + + if (masked) + { + for (int j = 0; j < 4; j++) + { + incompleteWebsocketRead.maskingKey[j] = websocketPendingBytes.peakAhead(i++); + } + } + + assert(i == headerLength); + assert(headerLength <= websocketPendingBytes.usedBytes()); + websocketPendingBytes.advanceTail(headerLength); + + incompleteWebsocketRead.frame_bytes_left = realPayloadLength; + incompleteWebsocketRead.opcode = opcode; + } + + if (incompleteWebsocketRead.opcode == WebsocketOpcode::Binary) + { + // The following reads one websocket frame max: it will continue with the previous, or start a new one, which it may or may not finish. + size_t targetBufI = 0; + char *targetBuf = &static_cast(buf)[nbytesRead]; + while(websocketPendingBytes.usedBytes() > 0 && incompleteWebsocketRead.frame_bytes_left > 0 && nbytesRead < targetBufMaxSize) + { + const size_t asManyBytesOfThisFrameAsPossible = std::min(websocketPendingBytes.maxReadSize(), incompleteWebsocketRead.frame_bytes_left); + const size_t maxReadSize = std::min(asManyBytesOfThisFrameAsPossible, targetBufMaxSize - nbytesRead); + assert(maxReadSize > 0); + assert(static_cast(maxReadSize) + nbytesRead <= targetBufMaxSize); + for (size_t x = 0; x < maxReadSize; x++) + { + targetBuf[targetBufI++] = websocketPendingBytes.tailPtr()[x] ^ incompleteWebsocketRead.getNextMaskingByte(); + } + websocketPendingBytes.advanceTail(maxReadSize); + incompleteWebsocketRead.frame_bytes_left -= maxReadSize; + nbytesRead += maxReadSize; + } + } + else if (incompleteWebsocketRead.opcode == WebsocketOpcode::Ping) + { + // A ping MAY have user data, which needs to be ponged back; + + // Constructing a new temporary buffer because I need the reponse in one frame for writeAsMuchOfBufAsWebsocketFrame(). + std::vector response(incompleteWebsocketRead.frame_bytes_left); + websocketPendingBytes.read(response.data(), response.size()); + + websocketWriteRemainder.ensureFreeSpace(response.size()); + writeAsMuchOfBufAsWebsocketFrame(response.data(), response.size(), WebsocketOpcode::Pong); + parentClient->setReadyForWriting(true); + } + + if (!incompleteWebsocketRead.sillWorkingOnFrame()) + incompleteWebsocketRead.reset(); + } + assert(nbytesRead <= static_cast(nbytes)); + + return nbytesRead; +} + +/** + * @brief IoWrapper::writeAsMuchOfBufAsWebsocketFrame writes buf of part of buf as websocket frame to websocketWriteRemainder + * @param buf + * @param nbytes. The amount of bytes. Can be 0, for just an empty websocket frame. + * @return + */ +ssize_t IoWrapper::writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode) +{ + // We do allow pong frames to generate a zero payload packet, but for binary, that's not necessary. + if (nbytes == 0 && opcode == WebsocketOpcode::Binary) + return 0; + + ssize_t nBytesReal = 0; + + // We normally wrap each write in a frame, but if a previous one didn't fit in the system's write buffers, we're still working on it. + if (websocketWriteRemainder.freeSpace() > WEBSOCKET_MAX_SENDING_HEADER_SIZE) + { + uint8_t extended_payload_length_num_bytes = 0; + uint8_t payload_length = 0; + if (nbytes < 126) + payload_length = nbytes; + else if (nbytes >= 126 && nbytes <= 0xFFFF) + { + payload_length = 126; + extended_payload_length_num_bytes = 2; + } + else if (nbytes > 0xFFFF) + { + payload_length = 127; + extended_payload_length_num_bytes = 8; + } + + int x = 0; + char header[WEBSOCKET_MAX_SENDING_HEADER_SIZE]; + header[x++] = (0b10000000 | static_cast(opcode)); + header[x++] = payload_length; + + const int header_length = x + extended_payload_length_num_bytes; + + // This block writes the extended payload length. + nBytesReal = std::min(nbytes, websocketWriteRemainder.freeSpace() - header_length); + const uint64_t nbytes64 = nBytesReal; + for (int z = extended_payload_length_num_bytes - 1; z >= 0; z--) + { + header[x++] = (nbytes64 >> (z*8)) & 0xFF; + } + assert(x <= WEBSOCKET_MAX_SENDING_HEADER_SIZE); + assert(x == header_length); + + websocketWriteRemainder.write(header, header_length); + websocketWriteRemainder.write(buf, nBytesReal); + } + + return nBytesReal; +} + +/* + * Mqtt docs: "A single WebSocket data frame can contain multiple or partial MQTT Control Packets. The receiver + * MUST NOT assume that MQTT Control Packets are aligned on WebSocket frame boundaries [MQTT-6.0.0-2]." We + * make use of that here, and wrap each write in a frame. + * + * It's can legitimately return a number of bytes written AND error with 'would block'. So, no need to do that + * repeating of the write thing that SSL_write() has. + */ +ssize_t IoWrapper::writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error) +{ + if (websocketState != WebsocketState::Upgraded) + { + if (websocket && websocketState == WebsocketState::Upgrading) + websocketState = WebsocketState::Upgraded; + + return writeOrSslWrite(fd, buf, nbytes, error); + } + else + { + ssize_t nBytesReal = writeAsMuchOfBufAsWebsocketFrame(buf, nbytes); + + ssize_t n = 0; + while (websocketWriteRemainder.usedBytes() > 0) + { + n = writeOrSslWrite(fd, websocketWriteRemainder.tailPtr(), websocketWriteRemainder.maxReadSize(), error); + + if (n > 0) + websocketWriteRemainder.advanceTail(n); + if (n < 0) + break; + } + + if (n > 0) + return nBytesReal; + + return n; + } +} diff --git a/iowrapper.h b/iowrapper.h new file mode 100644 index 0000000..77ae85f --- /dev/null +++ b/iowrapper.h @@ -0,0 +1,122 @@ +#ifndef IOWRAPPER_H +#define IOWRAPPER_H + +#include "unistd.h" +#include "openssl/ssl.h" +#include "openssl/err.h" +#include + +#include "forward_declarations.h" + +#include "types.h" +#include "utils.h" +#include "logger.h" +#include "exceptions.h" + +#define WEBSOCKET_MIN_HEADER_BYTES_NEEDED 2 +#define WEBSOCKET_MAX_SENDING_HEADER_SIZE 10 + +#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256. +#define OPENSSL_WRONG_VERSION_NUMBER 336130315 + +enum class IoWrapResult +{ + Success = 0, + Interrupted = 1, + Wouldblock = 2, + Disconnected = 3, + Error = 4 +}; + +enum class WebsocketOpcode +{ + Continuation = 0x00, + Text = 0x1, + Binary = 0x2, + Close = 0x8, + Ping = 0x9, + Pong = 0xA, + Unknown = 0xF +}; + +/* + * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments" + */ +struct IncompleteSslWrite +{ + const void *buf = nullptr; + size_t nbytes = 0; + + IncompleteSslWrite() = default; + IncompleteSslWrite(const void *buf, size_t nbytes); + bool hasPendingWrite() const; + + void reset(); +}; + +struct IncompleteWebsocketRead +{ + size_t frame_bytes_left = 0; + char maskingKey[4]; + int maskingKeyI = 0; + WebsocketOpcode opcode; + + void reset(); + bool sillWorkingOnFrame() const; + char getNextMaskingByte(); +}; + +enum class WebsocketState +{ + NotUpgraded, + Upgrading, + Upgraded +}; + +/** + * @brief provides a unified wrapper for SSL and websockets to read() and write(). + * + * + */ +class IoWrapper +{ + Client *parentClient; + const size_t initialBufferSize; + + SSL *ssl = nullptr; + bool sslAccepted = false; + IncompleteSslWrite incompleteSslWrite; + bool sslReadWantsWrite = false; + bool sslWriteWantsRead = false; + + bool websocket; + WebsocketState websocketState = WebsocketState::NotUpgraded; + CirBuf websocketPendingBytes; + IncompleteWebsocketRead incompleteWebsocketRead; + CirBuf websocketWriteRemainder; + + Logger *logger = Logger::getInstance(); + + ssize_t websocketBytesToReadBuffer(void *buf, const size_t nbytes); + ssize_t readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error); + ssize_t writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error); + ssize_t writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode = WebsocketOpcode::Binary); +public: + IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent); + ~IoWrapper(); + + void startOrContinueSslAccept(); + bool getSslReadWantsWrite() const; + bool getSslWriteWantsRead() const; + bool isSslAccepted() const; + bool isSsl() const; + bool hasPendingWrite() const; + bool isWebsocket() const; + WebsocketState getWebsocketState() const; + + ssize_t readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error); + ssize_t writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error); +}; + +#endif // IOWRAPPER_H diff --git a/mainapp.cpp b/mainapp.cpp index 1528a8b..d46d11b 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -351,6 +351,7 @@ void MainApp::start() int listen_fd_plain = createListenSocket(this->listenPort, false); int listen_fd_ssl = createListenSocket(this->sslListenPort, true); + int listen_fd_websocket_plain = createListenSocket(1443, true); #ifdef NDEBUG logger->noLongerLogToStd(); @@ -391,7 +392,7 @@ void MainApp::start() int cur_fd = events[i].data.fd; try { - if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl) + if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl || listen_fd_websocket_plain) { std::shared_ptr thread_data = threads[next_thread_index++ % num_threads]; @@ -402,6 +403,7 @@ void MainApp::start() socklen_t len = sizeof(struct sockaddr); int fd = check(accept(cur_fd, &addr, &len)); + bool websocket = cur_fd == listen_fd_websocket_plain; SSL *clientSSL = nullptr; if (cur_fd == listen_fd_ssl) { @@ -417,7 +419,7 @@ void MainApp::start() SSL_set_fd(clientSSL, fd); } - Client_p client(new Client(fd, thread_data, clientSSL, settings)); + Client_p client(new Client(fd, thread_data, clientSSL, websocket, settings)); thread_data->giveClient(client); } else if (cur_fd == taskEventFd) diff --git a/utils.cpp b/utils.cpp index d37f855..23b3560 100644 --- a/utils.cpp +++ b/utils.cpp @@ -3,8 +3,10 @@ #include "sys/time.h" #include "sys/random.h" #include +#include #include "exceptions.h" +#include "cirbuf.h" std::list split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) { @@ -226,3 +228,121 @@ bool isPowerOfTwo(int n) { return (n != 0) && (n & (n - 1)) == 0; } + +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version) +{ + const std::string s(buf.tailPtr(), buf.usedBytes()); + std::istringstream is(s); + bool doubleEmptyLine = false; // meaning, the HTTP header is complete + bool upgradeHeaderSeen = false; + bool connectionHeaderSeen = false; + bool firstLine = true; + + std::string line; + while (std::getline(is, line)) + { + trim(line); + if (firstLine) + { + firstLine = false; + if (!startsWith(line, "GET")) + throw BadHttpRequest("Websocket request should start with GET."); + continue; + } + if (line.empty()) + { + doubleEmptyLine = true; + break; + } + + std::list fields = split(line, ':', 1); + const std::vector fields2(fields.begin(), fields.end()); + std::string name = str_tolower(fields2[0]); + trim(name); + std::string value = fields2[1]; + trim(value); + std::string value_lower = str_tolower(value); + + if (name == "upgrade" && value_lower == "websocket") + upgradeHeaderSeen = true; + else if (name == "connection" && value_lower == "upgrade") + connectionHeaderSeen = true; + else if (name == "sec-websocket-key") + websocket_key = value; + else if (name == "sec-websocket-version") + websocket_version = stoi(value); + } + + if (doubleEmptyLine) + { + if (!connectionHeaderSeen || !upgradeHeaderSeen) + throw BadHttpRequest("HTTP request is not a websocket upgrade request."); + } + + return doubleEmptyLine; +} + +std::string base64Encode(const unsigned char *input, const int length) +{ + const int pl = 4*((length+2)/3); + char *output = reinterpret_cast(calloc(pl+1, 1)); + const int ol = EVP_EncodeBlock(reinterpret_cast(output), input, length); + std::string result(output); + free(output); + + if (pl != ol) + throw std::runtime_error("Base64 encode error."); + + return result; +} + +std::string generateWebsocketAcceptString(const std::string &websocketKey) +{ + unsigned char md_value[EVP_MAX_MD_SIZE]; + unsigned int md_len; + + EVP_MD_CTX *mdctx = EVP_MD_CTX_new();; + const EVP_MD *md = EVP_sha1(); + EVP_DigestInit_ex(mdctx, md, NULL); + + const std::string keyPlusMagic = websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + EVP_DigestUpdate(mdctx, keyPlusMagic.c_str(), keyPlusMagic.length()); + EVP_DigestFinal_ex(mdctx, md_value, &md_len); + EVP_MD_CTX_free(mdctx); + + std::string base64 = base64Encode(md_value, md_len); + return base64; +} + +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion) +{ + std::ostringstream oss; + oss << "HTTP/1.1 400 Bad Request\r\n"; + oss << "Sec-WebSocket-Version: " << wantedVersion; + oss << "\r\n"; + oss.flush(); + return oss.str(); +} + +std::string generateBadHttpRequestReponse(const std::string &msg) +{ + std::ostringstream oss; + oss << "HTTP/1.1 400 Bad Request\r\n"; + oss << "\r\n"; + oss << msg; + oss.flush(); + return oss.str(); +} + +std::string generateWebsocketAnswer(const std::string &acceptString) +{ + std::ostringstream oss; + oss << "HTTP/1.1 101 Switching Protocols\r\n"; + oss << "Upgrade: websocket\r\n"; + oss << "Connection: Upgrade\r\n"; + oss << "Sec-WebSocket-Accept: " << acceptString << "\r\n"; + oss << "\r\n"; + oss.flush(); + return oss.str(); +} diff --git a/utils.h b/utils.h index 804c804..666f517 100644 --- a/utils.h +++ b/utils.h @@ -8,6 +8,9 @@ #include #include #include +#include + +#include "cirbuf.h" template int check(int rc) { @@ -43,4 +46,13 @@ std::string str_tolower(std::string s); bool stringTruthiness(const std::string &val); bool isPowerOfTwo(int val); +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version); + +std::string base64Encode(const unsigned char *input, const int length); +std::string generateWebsocketAcceptString(const std::string &websocketKey); + +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion); +std::string generateBadHttpRequestReponse(const std::string &msg); +std::string generateWebsocketAnswer(const std::string &acceptString); + #endif // UTILS_H