Commit b5ba41f503c602c0edcdeec0db5b2b8344050cb8
1 parent
1106e210
Add IoWrapper, with websocket support added
The ping/pong is actually untested at this point, because Paho (my test client for now) doesn't do those. I wonder if any do, because MQTT already has ping/pong.
Showing
11 changed files
with
988 additions
and
268 deletions
CMakeLists.txt
| @@ -29,6 +29,7 @@ add_executable(FlashMQ | @@ -29,6 +29,7 @@ add_executable(FlashMQ | ||
| 29 | sslctxmanager.cpp | 29 | sslctxmanager.cpp |
| 30 | timer.cpp | 30 | timer.cpp |
| 31 | globalsettings.cpp | 31 | globalsettings.cpp |
| 32 | + iowrapper.cpp | ||
| 32 | ) | 33 | ) |
| 33 | 34 | ||
| 34 | target_link_libraries(FlashMQ pthread dl ssl crypto) | 35 | target_link_libraries(FlashMQ pthread dl ssl crypto) |
cirbuf.cpp
| @@ -9,10 +9,14 @@ | @@ -9,10 +9,14 @@ | ||
| 9 | #include <cstring> | 9 | #include <cstring> |
| 10 | 10 | ||
| 11 | #include "logger.h" | 11 | #include "logger.h" |
| 12 | +#include "utils.h" | ||
| 12 | 13 | ||
| 13 | CirBuf::CirBuf(size_t size) : | 14 | CirBuf::CirBuf(size_t size) : |
| 14 | size(size) | 15 | size(size) |
| 15 | { | 16 | { |
| 17 | + if (size == 0) | ||
| 18 | + return; | ||
| 19 | + | ||
| 16 | buf = (char*)malloc(size); | 20 | buf = (char*)malloc(size); |
| 17 | 21 | ||
| 18 | if (buf == NULL) | 22 | if (buf == NULL) |
| @@ -69,12 +73,14 @@ char *CirBuf::tailPtr() | @@ -69,12 +73,14 @@ char *CirBuf::tailPtr() | ||
| 69 | 73 | ||
| 70 | void CirBuf::advanceHead(uint32_t n) | 74 | void CirBuf::advanceHead(uint32_t n) |
| 71 | { | 75 | { |
| 76 | + assert(n <= freeSpace()); | ||
| 72 | head = (head + n) & (size -1); | 77 | head = (head + n) & (size -1); |
| 73 | assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty. | 78 | assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty. |
| 74 | } | 79 | } |
| 75 | 80 | ||
| 76 | void CirBuf::advanceTail(uint32_t n) | 81 | void CirBuf::advanceTail(uint32_t n) |
| 77 | { | 82 | { |
| 83 | + assert(n <= usedBytes()); | ||
| 78 | tail = (tail + n) & (size -1); | 84 | tail = (tail + n) & (size -1); |
| 79 | } | 85 | } |
| 80 | 86 | ||
| @@ -84,9 +90,31 @@ char CirBuf::peakAhead(uint32_t offset) const | @@ -84,9 +90,31 @@ char CirBuf::peakAhead(uint32_t offset) const | ||
| 84 | return b; | 90 | return b; |
| 85 | } | 91 | } |
| 86 | 92 | ||
| 87 | -void CirBuf::doubleSize() | 93 | +void CirBuf::ensureFreeSpace(size_t n) |
| 94 | +{ | ||
| 95 | + const size_t _usedBytes = usedBytes(); | ||
| 96 | + | ||
| 97 | + int mul = 1; | ||
| 98 | + | ||
| 99 | + while((mul * size - _usedBytes - 1) < n) | ||
| 100 | + { | ||
| 101 | + mul = mul << 1; | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + if (mul == 1) | ||
| 105 | + return; | ||
| 106 | + | ||
| 107 | + doubleSize(mul); | ||
| 108 | +} | ||
| 109 | + | ||
| 110 | +void CirBuf::doubleSize(uint factor) | ||
| 88 | { | 111 | { |
| 89 | - uint newSize = size * 2; | 112 | + if (factor == 1) |
| 113 | + return; | ||
| 114 | + | ||
| 115 | + assert(isPowerOfTwo(factor)); | ||
| 116 | + | ||
| 117 | + uint newSize = size * factor; | ||
| 90 | char *newBuf = (char*)realloc(buf, newSize); | 118 | char *newBuf = (char*)realloc(buf, newSize); |
| 91 | 119 | ||
| 92 | if (newBuf == NULL) | 120 | if (newBuf == NULL) |
| @@ -145,3 +173,55 @@ void CirBuf::resetSize(size_t newSize) | @@ -145,3 +173,55 @@ void CirBuf::resetSize(size_t newSize) | ||
| 145 | memset(buf, 0, newSize); | 173 | memset(buf, 0, newSize); |
| 146 | #endif | 174 | #endif |
| 147 | } | 175 | } |
| 176 | + | ||
| 177 | +void CirBuf::reset() | ||
| 178 | +{ | ||
| 179 | + head = 0; | ||
| 180 | + tail = 0; | ||
| 181 | + | ||
| 182 | +#ifndef NDEBUG | ||
| 183 | + memset(buf, 0, size); | ||
| 184 | +#endif | ||
| 185 | +} | ||
| 186 | + | ||
| 187 | +// When you know the data you want to write fits in the buffer, use this. | ||
| 188 | +void CirBuf::write(const void *buf, size_t count) | ||
| 189 | +{ | ||
| 190 | + assert(count <= freeSpace()); | ||
| 191 | + | ||
| 192 | + ssize_t len_left = count; | ||
| 193 | + size_t src_i = 0; | ||
| 194 | + while (len_left > 0) | ||
| 195 | + { | ||
| 196 | + const size_t len = std::min<int>(len_left, maxWriteSize()); | ||
| 197 | + assert(len > 0); | ||
| 198 | + const char *src = &reinterpret_cast<const char*>(buf)[src_i]; | ||
| 199 | + std::memcpy(headPtr(), src, len); | ||
| 200 | + advanceHead(len); | ||
| 201 | + src_i += len; | ||
| 202 | + len_left -= len; | ||
| 203 | + } | ||
| 204 | + assert(len_left == 0); | ||
| 205 | + assert(src_i == count); | ||
| 206 | +} | ||
| 207 | + | ||
| 208 | +// When you know 'count' bytes are present and you want to read them into buf | ||
| 209 | +void CirBuf::read(void *buf, const size_t count) | ||
| 210 | +{ | ||
| 211 | + assert(count <= usedBytes()); | ||
| 212 | + | ||
| 213 | + char *_buf = static_cast<char*>(buf); | ||
| 214 | + int i = 0; | ||
| 215 | + ssize_t _packet_len = count; | ||
| 216 | + while (_packet_len > 0) | ||
| 217 | + { | ||
| 218 | + const int readlen = std::min<int>(maxReadSize(), _packet_len); | ||
| 219 | + assert(readlen > 0); | ||
| 220 | + std::memcpy(&_buf[i], tailPtr(), readlen); | ||
| 221 | + advanceTail(readlen); | ||
| 222 | + i += readlen; | ||
| 223 | + _packet_len -= readlen; | ||
| 224 | + } | ||
| 225 | + assert(_packet_len == 0); | ||
| 226 | + assert(i == _packet_len); | ||
| 227 | +} |
cirbuf.h
| @@ -32,11 +32,16 @@ public: | @@ -32,11 +32,16 @@ public: | ||
| 32 | void advanceHead(uint32_t n); | 32 | void advanceHead(uint32_t n); |
| 33 | void advanceTail(uint32_t n); | 33 | void advanceTail(uint32_t n); |
| 34 | char peakAhead(uint32_t offset) const; | 34 | char peakAhead(uint32_t offset) const; |
| 35 | - void doubleSize(); | 35 | + void ensureFreeSpace(size_t n); |
| 36 | + void doubleSize(uint factor = 2); | ||
| 36 | uint32_t getSize() const; | 37 | uint32_t getSize() const; |
| 37 | 38 | ||
| 38 | time_t bufferLastResizedSecondsAgo() const; | 39 | time_t bufferLastResizedSecondsAgo() const; |
| 39 | void resetSize(size_t size); | 40 | void resetSize(size_t size); |
| 41 | + void reset(); | ||
| 42 | + | ||
| 43 | + void write(const void *buf, size_t count); | ||
| 44 | + void read(void *buf, const size_t count); | ||
| 40 | }; | 45 | }; |
| 41 | 46 | ||
| 42 | #endif // CIRBUF_H | 47 | #endif // CIRBUF_H |
client.cpp
| @@ -7,11 +7,11 @@ | @@ -7,11 +7,11 @@ | ||
| 7 | 7 | ||
| 8 | #include "logger.h" | 8 | #include "logger.h" |
| 9 | 9 | ||
| 10 | -Client::Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings) : | 10 | +Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings) : |
| 11 | fd(fd), | 11 | fd(fd), |
| 12 | - ssl(ssl), | ||
| 13 | initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy | 12 | initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy |
| 14 | maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment. | 13 | maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment. |
| 14 | + ioWrapper(ssl, websocket, initialBufferSize, this), | ||
| 15 | readbuf(initialBufferSize), | 15 | readbuf(initialBufferSize), |
| 16 | writebuf(initialBufferSize), | 16 | writebuf(initialBufferSize), |
| 17 | threadData(threadData) | 17 | threadData(threadData) |
| @@ -40,63 +40,32 @@ Client::~Client() | @@ -40,63 +40,32 @@ Client::~Client() | ||
| 40 | logger->logf(LOG_NOTICE, "Removing client '%s'. Reason(s): %s", repr().c_str(), disconnectReason.c_str()); | 40 | logger->logf(LOG_NOTICE, "Removing client '%s'. Reason(s): %s", repr().c_str(), disconnectReason.c_str()); |
| 41 | if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) | 41 | if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) |
| 42 | logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); | 42 | logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); |
| 43 | - if (ssl) | ||
| 44 | - { | ||
| 45 | - // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done | ||
| 46 | - // in the destructor. | ||
| 47 | - SSL_free(ssl); | ||
| 48 | - } | ||
| 49 | close(fd); | 43 | close(fd); |
| 50 | } | 44 | } |
| 51 | 45 | ||
| 52 | bool Client::isSslAccepted() const | 46 | bool Client::isSslAccepted() const |
| 53 | { | 47 | { |
| 54 | - return sslAccepted; | 48 | + return ioWrapper.isSslAccepted(); |
| 55 | } | 49 | } |
| 56 | 50 | ||
| 57 | bool Client::isSsl() const | 51 | bool Client::isSsl() const |
| 58 | { | 52 | { |
| 59 | - return this->ssl != nullptr; | 53 | + return ioWrapper.isSsl(); |
| 60 | } | 54 | } |
| 61 | 55 | ||
| 62 | bool Client::getSslReadWantsWrite() const | 56 | bool Client::getSslReadWantsWrite() const |
| 63 | { | 57 | { |
| 64 | - return this->sslReadWantsWrite; | 58 | + return ioWrapper.getSslReadWantsWrite(); |
| 65 | } | 59 | } |
| 66 | 60 | ||
| 67 | bool Client::getSslWriteWantsRead() const | 61 | bool Client::getSslWriteWantsRead() const |
| 68 | { | 62 | { |
| 69 | - return this->sslWriteWantsRead; | 63 | + return ioWrapper.getSslWriteWantsRead(); |
| 70 | } | 64 | } |
| 71 | 65 | ||
| 72 | void Client::startOrContinueSslAccept() | 66 | void Client::startOrContinueSslAccept() |
| 73 | { | 67 | { |
| 74 | - ERR_clear_error(); | ||
| 75 | - int accepted = SSL_accept(ssl); | ||
| 76 | - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | ||
| 77 | - if (accepted <= 0) | ||
| 78 | - { | ||
| 79 | - int err = SSL_get_error(ssl, accepted); | ||
| 80 | - | ||
| 81 | - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | ||
| 82 | - { | ||
| 83 | - setReadyForWriting(err == SSL_ERROR_WANT_WRITE); | ||
| 84 | - return; | ||
| 85 | - } | ||
| 86 | - | ||
| 87 | - unsigned long error_code = ERR_get_error(); | ||
| 88 | - | ||
| 89 | - ERR_error_string(error_code, sslErrorBuf); | ||
| 90 | - std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | ||
| 91 | - | ||
| 92 | - if (error_code == OPENSSL_WRONG_VERSION_NUMBER) | ||
| 93 | - errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket."; | ||
| 94 | - | ||
| 95 | - //ERR_print_errors_cb(logSslError, NULL); | ||
| 96 | - throw std::runtime_error("Problem accepting SSL socket: " + errorMsg); | ||
| 97 | - } | ||
| 98 | - setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake | ||
| 99 | - sslAccepted = true; | 68 | + ioWrapper.startOrContinueSslAccept(); |
| 100 | } | 69 | } |
| 101 | 70 | ||
| 102 | // Causes future activity on the client to cause a disconnect. | 71 | // Causes future activity on the client to cause a disconnect. |
| @@ -108,83 +77,6 @@ void Client::markAsDisconnecting() | @@ -108,83 +77,6 @@ void Client::markAsDisconnecting() | ||
| 108 | disconnecting = true; | 77 | disconnecting = true; |
| 109 | } | 78 | } |
| 110 | 79 | ||
| 111 | -// SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL | ||
| 112 | -// socket. This wrapper unifies behavor for the caller. | ||
| 113 | -ssize_t Client::readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error) | ||
| 114 | -{ | ||
| 115 | - *error = IoWrapResult::Success; | ||
| 116 | - ssize_t n = 0; | ||
| 117 | - if (!ssl) | ||
| 118 | - { | ||
| 119 | - n = read(fd, buf, nbytes); | ||
| 120 | - if (n < 0) | ||
| 121 | - { | ||
| 122 | - if (errno == EINTR) | ||
| 123 | - *error = IoWrapResult::Interrupted; | ||
| 124 | - else if (errno == EAGAIN || errno == EWOULDBLOCK) | ||
| 125 | - *error = IoWrapResult::Wouldblock; | ||
| 126 | - else | ||
| 127 | - check<std::runtime_error>(n); | ||
| 128 | - } | ||
| 129 | - else if (n == 0) | ||
| 130 | - { | ||
| 131 | - *error = IoWrapResult::Disconnected; | ||
| 132 | - } | ||
| 133 | - } | ||
| 134 | - else | ||
| 135 | - { | ||
| 136 | - this->sslReadWantsWrite = false; | ||
| 137 | - ERR_clear_error(); | ||
| 138 | - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | ||
| 139 | - n = SSL_read(ssl, buf, nbytes); | ||
| 140 | - | ||
| 141 | - if (n <= 0) | ||
| 142 | - { | ||
| 143 | - int err = SSL_get_error(ssl, n); | ||
| 144 | - unsigned long error_code = ERR_get_error(); | ||
| 145 | - | ||
| 146 | - // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL. | ||
| 147 | - if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0)) | ||
| 148 | - { | ||
| 149 | - *error = IoWrapResult::Disconnected; | ||
| 150 | - } | ||
| 151 | - else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | ||
| 152 | - { | ||
| 153 | - *error = IoWrapResult::Wouldblock; | ||
| 154 | - if (err == SSL_ERROR_WANT_WRITE) | ||
| 155 | - { | ||
| 156 | - sslReadWantsWrite = true; | ||
| 157 | - setReadyForWriting(true); | ||
| 158 | - } | ||
| 159 | - n = -1; | ||
| 160 | - } | ||
| 161 | - else | ||
| 162 | - { | ||
| 163 | - if (err == SSL_ERROR_SYSCALL) | ||
| 164 | - { | ||
| 165 | - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say | ||
| 166 | - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it | ||
| 167 | - // implies EINTR is not included? | ||
| 168 | - if (errno == EINTR) | ||
| 169 | - *error = IoWrapResult::Interrupted; | ||
| 170 | - else | ||
| 171 | - { | ||
| 172 | - char *err = strerror(errno); | ||
| 173 | - std::string msg(err); | ||
| 174 | - throw std::runtime_error("SSL read error: " + msg); | ||
| 175 | - } | ||
| 176 | - } | ||
| 177 | - ERR_error_string(error_code, sslErrorBuf); | ||
| 178 | - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | ||
| 179 | - ERR_print_errors_cb(logSslError, NULL); | ||
| 180 | - throw std::runtime_error("SSL socket error reading: " + errorString); | ||
| 181 | - } | ||
| 182 | - } | ||
| 183 | - } | ||
| 184 | - | ||
| 185 | - return n; | ||
| 186 | -} | ||
| 187 | - | ||
| 188 | // false means any kind of error we want to get rid of the client for. | 80 | // false means any kind of error we want to get rid of the client for. |
| 189 | bool Client::readFdIntoBuffer() | 81 | bool Client::readFdIntoBuffer() |
| 190 | { | 82 | { |
| @@ -193,7 +85,7 @@ bool Client::readFdIntoBuffer() | @@ -193,7 +85,7 @@ bool Client::readFdIntoBuffer() | ||
| 193 | 85 | ||
| 194 | IoWrapResult error = IoWrapResult::Success; | 86 | IoWrapResult error = IoWrapResult::Success; |
| 195 | int n = 0; | 87 | int n = 0; |
| 196 | - while (readbuf.freeSpace() > 0 && (n = readWrap(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0) | 88 | + while (readbuf.freeSpace() > 0 && (n = ioWrapper.readWebsocketAndOrSsl(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0) |
| 197 | { | 89 | { |
| 198 | if (n > 0) | 90 | if (n > 0) |
| 199 | { | 91 | { |
| @@ -232,6 +124,20 @@ bool Client::readFdIntoBuffer() | @@ -232,6 +124,20 @@ bool Client::readFdIntoBuffer() | ||
| 232 | return true; | 124 | return true; |
| 233 | } | 125 | } |
| 234 | 126 | ||
| 127 | +void Client::writeText(const std::string &text) | ||
| 128 | +{ | ||
| 129 | + assert(ioWrapper.isWebsocket()); | ||
| 130 | + assert(ioWrapper.getWebsocketState() == WebsocketState::NotUpgraded); | ||
| 131 | + | ||
| 132 | + // Not necessary, because at this point, no other threads write to this client, but including for clarity. | ||
| 133 | + std::lock_guard<std::mutex> locker(writeBufMutex); | ||
| 134 | + | ||
| 135 | + writebuf.ensureFreeSpace(text.size()); | ||
| 136 | + writebuf.write(text.c_str(), text.length()); | ||
| 137 | + | ||
| 138 | + setReadyForWriting(true); | ||
| 139 | +} | ||
| 140 | + | ||
| 235 | void Client::writeMqttPacket(const MqttPacket &packet) | 141 | void Client::writeMqttPacket(const MqttPacket &packet) |
| 236 | { | 142 | { |
| 237 | std::lock_guard<std::mutex> locker(writeBufMutex); | 143 | std::lock_guard<std::mutex> locker(writeBufMutex); |
| @@ -322,94 +228,6 @@ void Client::writePingResp() | @@ -322,94 +228,6 @@ void Client::writePingResp() | ||
| 322 | setReadyForWriting(true); | 228 | setReadyForWriting(true); |
| 323 | } | 229 | } |
| 324 | 230 | ||
| 325 | -// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller. | ||
| 326 | -ssize_t Client::writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error) | ||
| 327 | -{ | ||
| 328 | - *error = IoWrapResult::Success; | ||
| 329 | - ssize_t n = 0; | ||
| 330 | - | ||
| 331 | - if (!ssl) | ||
| 332 | - { | ||
| 333 | - // A write on a socket with count=0 is unspecified. | ||
| 334 | - assert(nbytes > 0); | ||
| 335 | - | ||
| 336 | - n = write(fd, buf, nbytes); | ||
| 337 | - if (n < 0) | ||
| 338 | - { | ||
| 339 | - if (errno == EINTR) | ||
| 340 | - *error = IoWrapResult::Interrupted; | ||
| 341 | - else if (errno == EAGAIN || errno == EWOULDBLOCK) | ||
| 342 | - *error = IoWrapResult::Wouldblock; | ||
| 343 | - else | ||
| 344 | - check<std::runtime_error>(n); | ||
| 345 | - } | ||
| 346 | - } | ||
| 347 | - else | ||
| 348 | - { | ||
| 349 | - const void *buf_ = buf; | ||
| 350 | - size_t nbytes_ = nbytes; | ||
| 351 | - | ||
| 352 | - /* | ||
| 353 | - * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned | ||
| 354 | - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments | ||
| 355 | - */ | ||
| 356 | - if (this->incompleteSslWrite.hasPendingWrite()) | ||
| 357 | - { | ||
| 358 | - buf_ = this->incompleteSslWrite.buf; | ||
| 359 | - nbytes_ = this->incompleteSslWrite.nbytes; | ||
| 360 | - } | ||
| 361 | - | ||
| 362 | - // OpenSSL: "You should not call SSL_write() with num=0, it will return an error" | ||
| 363 | - assert(nbytes_ > 0); | ||
| 364 | - | ||
| 365 | - this->sslWriteWantsRead = false; | ||
| 366 | - this->incompleteSslWrite.reset(); | ||
| 367 | - | ||
| 368 | - ERR_clear_error(); | ||
| 369 | - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | ||
| 370 | - n = SSL_write(ssl, buf_, nbytes_); | ||
| 371 | - | ||
| 372 | - if (n <= 0) | ||
| 373 | - { | ||
| 374 | - int err = SSL_get_error(ssl, n); | ||
| 375 | - unsigned long error_code = ERR_get_error(); | ||
| 376 | - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | ||
| 377 | - { | ||
| 378 | - logger->logf(LOG_DEBUG, "Write is incomplete: %d", err); | ||
| 379 | - *error = IoWrapResult::Wouldblock; | ||
| 380 | - IncompleteSslWrite sslAction(buf_, nbytes_); | ||
| 381 | - this->incompleteSslWrite = sslAction; | ||
| 382 | - if (err == SSL_ERROR_WANT_READ) | ||
| 383 | - this->sslWriteWantsRead = true; | ||
| 384 | - n = 0; | ||
| 385 | - } | ||
| 386 | - else | ||
| 387 | - { | ||
| 388 | - if (err == SSL_ERROR_SYSCALL) | ||
| 389 | - { | ||
| 390 | - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say | ||
| 391 | - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it | ||
| 392 | - // implies EINTR is not included? | ||
| 393 | - if (errno == EINTR) | ||
| 394 | - *error = IoWrapResult::Interrupted; | ||
| 395 | - else | ||
| 396 | - { | ||
| 397 | - char *err = strerror(errno); | ||
| 398 | - std::string msg(err); | ||
| 399 | - throw std::runtime_error(msg); | ||
| 400 | - } | ||
| 401 | - } | ||
| 402 | - ERR_error_string(error_code, sslErrorBuf); | ||
| 403 | - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | ||
| 404 | - ERR_print_errors_cb(logSslError, NULL); | ||
| 405 | - throw std::runtime_error("SSL socket error writing: " + errorString); | ||
| 406 | - } | ||
| 407 | - } | ||
| 408 | - } | ||
| 409 | - | ||
| 410 | - return n; | ||
| 411 | -} | ||
| 412 | - | ||
| 413 | bool Client::writeBufIntoFd() | 231 | bool Client::writeBufIntoFd() |
| 414 | { | 232 | { |
| 415 | std::unique_lock<std::mutex> lock(writeBufMutex, std::try_to_lock); | 233 | std::unique_lock<std::mutex> lock(writeBufMutex, std::try_to_lock); |
| @@ -422,9 +240,9 @@ bool Client::writeBufIntoFd() | @@ -422,9 +240,9 @@ bool Client::writeBufIntoFd() | ||
| 422 | 240 | ||
| 423 | IoWrapResult error = IoWrapResult::Success; | 241 | IoWrapResult error = IoWrapResult::Success; |
| 424 | int n; | 242 | int n; |
| 425 | - while (writebuf.usedBytes() > 0 || incompleteSslWrite.hasPendingWrite()) | 243 | + while (writebuf.usedBytes() > 0 || ioWrapper.hasPendingWrite()) |
| 426 | { | 244 | { |
| 427 | - n = writeWrap(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error); | 245 | + n = ioWrapper.writeWebsocketAndOrSsl(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error); |
| 428 | 246 | ||
| 429 | if (n > 0) | 247 | if (n > 0) |
| 430 | writebuf.advanceTail(n); | 248 | writebuf.advanceTail(n); |
| @@ -484,7 +302,7 @@ void Client::setReadyForWriting(bool val) | @@ -484,7 +302,7 @@ void Client::setReadyForWriting(bool val) | ||
| 484 | if (disconnecting) | 302 | if (disconnecting) |
| 485 | return; | 303 | return; |
| 486 | 304 | ||
| 487 | - if (sslReadWantsWrite) | 305 | + if (ioWrapper.getSslReadWantsWrite()) |
| 488 | val = true; | 306 | val = true; |
| 489 | 307 | ||
| 490 | if (val == this->readyForWriting) | 308 | if (val == this->readyForWriting) |
| @@ -622,20 +440,3 @@ void Client::clearWill() | @@ -622,20 +440,3 @@ void Client::clearWill() | ||
| 622 | will_qos = 0; | 440 | will_qos = 0; |
| 623 | } | 441 | } |
| 624 | 442 | ||
| 625 | -IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) : | ||
| 626 | - buf(buf), | ||
| 627 | - nbytes(nbytes) | ||
| 628 | -{ | ||
| 629 | - | ||
| 630 | -} | ||
| 631 | - | ||
| 632 | -void IncompleteSslWrite::reset() | ||
| 633 | -{ | ||
| 634 | - buf = nullptr; | ||
| 635 | - nbytes = 0; | ||
| 636 | -} | ||
| 637 | - | ||
| 638 | -bool IncompleteSslWrite::hasPendingWrite() | ||
| 639 | -{ | ||
| 640 | - return buf != nullptr; | ||
| 641 | -} |
client.h
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | #include <time.h> | 9 | #include <time.h> |
| 10 | 10 | ||
| 11 | #include <openssl/ssl.h> | 11 | #include <openssl/ssl.h> |
| 12 | +#include <openssl/err.h> | ||
| 12 | 13 | ||
| 13 | #include "forward_declarations.h" | 14 | #include "forward_declarations.h" |
| 14 | 15 | ||
| @@ -17,54 +18,25 @@ | @@ -17,54 +18,25 @@ | ||
| 17 | #include "exceptions.h" | 18 | #include "exceptions.h" |
| 18 | #include "cirbuf.h" | 19 | #include "cirbuf.h" |
| 19 | #include "types.h" | 20 | #include "types.h" |
| 20 | - | ||
| 21 | -#include <openssl/ssl.h> | ||
| 22 | -#include <openssl/err.h> | 21 | +#include "iowrapper.h" |
| 23 | 22 | ||
| 24 | #define MQTT_HEADER_LENGH 2 | 23 | #define MQTT_HEADER_LENGH 2 |
| 25 | 24 | ||
| 26 | -#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256. | ||
| 27 | -#define OPENSSL_WRONG_VERSION_NUMBER 336130315 | ||
| 28 | - | ||
| 29 | -enum class IoWrapResult | ||
| 30 | -{ | ||
| 31 | - Success = 0, | ||
| 32 | - Interrupted = 1, | ||
| 33 | - Wouldblock = 2, | ||
| 34 | - Disconnected = 3, | ||
| 35 | - Error = 4 | ||
| 36 | -}; | ||
| 37 | - | ||
| 38 | -/* | ||
| 39 | - * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned | ||
| 40 | - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments" | ||
| 41 | - */ | ||
| 42 | -struct IncompleteSslWrite | ||
| 43 | -{ | ||
| 44 | - const void *buf = nullptr; | ||
| 45 | - size_t nbytes = 0; | ||
| 46 | - | ||
| 47 | - IncompleteSslWrite() = default; | ||
| 48 | - IncompleteSslWrite(const void *buf, size_t nbytes); | ||
| 49 | - bool hasPendingWrite(); | ||
| 50 | - | ||
| 51 | - void reset(); | ||
| 52 | -}; | ||
| 53 | 25 | ||
| 54 | // TODO: give accepted addr, for showing in logs | 26 | // TODO: give accepted addr, for showing in logs |
| 55 | class Client | 27 | class Client |
| 56 | { | 28 | { |
| 29 | + friend class IoWrapper; | ||
| 30 | + | ||
| 57 | int fd; | 31 | int fd; |
| 58 | - SSL *ssl = nullptr; | ||
| 59 | - bool sslAccepted = false; | ||
| 60 | - IncompleteSslWrite incompleteSslWrite; | ||
| 61 | - bool sslReadWantsWrite = false; | ||
| 62 | - bool sslWriteWantsRead = false; | 32 | + |
| 63 | ProtocolVersion protocolVersion = ProtocolVersion::None; | 33 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 64 | 34 | ||
| 65 | const size_t initialBufferSize = 0; | 35 | const size_t initialBufferSize = 0; |
| 66 | const size_t maxPacketSize = 0; | 36 | const size_t maxPacketSize = 0; |
| 67 | 37 | ||
| 38 | + IoWrapper ioWrapper; | ||
| 39 | + | ||
| 68 | CirBuf readbuf; | 40 | CirBuf readbuf; |
| 69 | uint8_t readBufIsZeroCount = 0; | 41 | uint8_t readBufIsZeroCount = 0; |
| 70 | 42 | ||
| @@ -97,12 +69,11 @@ class Client | @@ -97,12 +69,11 @@ class Client | ||
| 97 | 69 | ||
| 98 | Logger *logger = Logger::getInstance(); | 70 | Logger *logger = Logger::getInstance(); |
| 99 | 71 | ||
| 100 | - | ||
| 101 | void setReadyForWriting(bool val); | 72 | void setReadyForWriting(bool val); |
| 102 | void setReadyForReading(bool val); | 73 | void setReadyForReading(bool val); |
| 103 | 74 | ||
| 104 | public: | 75 | public: |
| 105 | - Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings); | 76 | + Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings); |
| 106 | Client(const Client &other) = delete; | 77 | Client(const Client &other) = delete; |
| 107 | Client(Client &&other) = delete; | 78 | Client(Client &&other) = delete; |
| 108 | ~Client(); | 79 | ~Client(); |
| @@ -115,7 +86,6 @@ public: | @@ -115,7 +86,6 @@ public: | ||
| 115 | 86 | ||
| 116 | void startOrContinueSslAccept(); | 87 | void startOrContinueSslAccept(); |
| 117 | void markAsDisconnecting(); | 88 | void markAsDisconnecting(); |
| 118 | - ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error); | ||
| 119 | bool readFdIntoBuffer(); | 89 | bool readFdIntoBuffer(); |
| 120 | bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender); | 90 | bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender); |
| 121 | void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); | 91 | void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); |
| @@ -131,10 +101,10 @@ public: | @@ -131,10 +101,10 @@ public: | ||
| 131 | std::shared_ptr<Session> getSession(); | 101 | std::shared_ptr<Session> getSession(); |
| 132 | void setDisconnectReason(const std::string &reason); | 102 | void setDisconnectReason(const std::string &reason); |
| 133 | 103 | ||
| 104 | + void writeText(const std::string &text); | ||
| 134 | void writePingResp(); | 105 | void writePingResp(); |
| 135 | void writeMqttPacket(const MqttPacket &packet); | 106 | void writeMqttPacket(const MqttPacket &packet); |
| 136 | void writeMqttPacketAndBlameThisClient(const MqttPacket &packet); | 107 | void writeMqttPacketAndBlameThisClient(const MqttPacket &packet); |
| 137 | - ssize_t writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error); | ||
| 138 | bool writeBufIntoFd(); | 108 | bool writeBufIntoFd(); |
| 139 | bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } | 109 | bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } |
| 140 | 110 |
exceptions.h
| @@ -34,4 +34,16 @@ public: | @@ -34,4 +34,16 @@ public: | ||
| 34 | AuthPluginException(const std::string &msg) : std::runtime_error(msg) {} | 34 | AuthPluginException(const std::string &msg) : std::runtime_error(msg) {} |
| 35 | }; | 35 | }; |
| 36 | 36 | ||
| 37 | +class BadWebsocketVersionException : public std::runtime_error | ||
| 38 | +{ | ||
| 39 | +public: | ||
| 40 | + BadWebsocketVersionException(const std::string &msg) : std::runtime_error(msg) {} | ||
| 41 | +}; | ||
| 42 | + | ||
| 43 | +class BadHttpRequest : public std::runtime_error | ||
| 44 | +{ | ||
| 45 | +public: | ||
| 46 | + BadHttpRequest(const std::string &msg) : std::runtime_error(msg) {} | ||
| 47 | +}; | ||
| 48 | + | ||
| 37 | #endif // EXCEPTIONS_H | 49 | #endif // EXCEPTIONS_H |
iowrapper.cpp
0 โ 100644
| 1 | +#include "iowrapper.h" | ||
| 2 | + | ||
| 3 | +#include "cassert" | ||
| 4 | + | ||
| 5 | +#include "logger.h" | ||
| 6 | +#include "client.h" | ||
| 7 | + | ||
| 8 | +IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) : | ||
| 9 | + buf(buf), | ||
| 10 | + nbytes(nbytes) | ||
| 11 | +{ | ||
| 12 | + | ||
| 13 | +} | ||
| 14 | + | ||
| 15 | +void IncompleteSslWrite::reset() | ||
| 16 | +{ | ||
| 17 | + buf = nullptr; | ||
| 18 | + nbytes = 0; | ||
| 19 | +} | ||
| 20 | + | ||
| 21 | +bool IncompleteSslWrite::hasPendingWrite() const | ||
| 22 | +{ | ||
| 23 | + return buf != nullptr; | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +void IncompleteWebsocketRead::reset() | ||
| 27 | +{ | ||
| 28 | + maskingKeyI = 0; | ||
| 29 | + memset(maskingKey,0, 4); | ||
| 30 | + frame_bytes_left = 0; | ||
| 31 | + opcode = WebsocketOpcode::Unknown; | ||
| 32 | +} | ||
| 33 | + | ||
| 34 | +bool IncompleteWebsocketRead::sillWorkingOnFrame() const | ||
| 35 | +{ | ||
| 36 | + return frame_bytes_left > 0; | ||
| 37 | +} | ||
| 38 | + | ||
| 39 | +char IncompleteWebsocketRead::getNextMaskingByte() | ||
| 40 | +{ | ||
| 41 | + return maskingKey[maskingKeyI++ % 4]; | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +IoWrapper::IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent) : | ||
| 45 | + parentClient(parent), | ||
| 46 | + initialBufferSize(initialBufferSize), | ||
| 47 | + ssl(ssl), | ||
| 48 | + websocket(websocket), | ||
| 49 | + websocketPendingBytes(websocket ? initialBufferSize : 0), | ||
| 50 | + websocketWriteRemainder(websocket ? initialBufferSize : 0) | ||
| 51 | +{ | ||
| 52 | + | ||
| 53 | +} | ||
| 54 | + | ||
| 55 | +IoWrapper::~IoWrapper() | ||
| 56 | +{ | ||
| 57 | + if (ssl) | ||
| 58 | + { | ||
| 59 | + // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done | ||
| 60 | + // in the destructor. | ||
| 61 | + SSL_free(ssl); | ||
| 62 | + } | ||
| 63 | +} | ||
| 64 | + | ||
| 65 | +void IoWrapper::startOrContinueSslAccept() | ||
| 66 | +{ | ||
| 67 | + ERR_clear_error(); | ||
| 68 | + int accepted = SSL_accept(ssl); | ||
| 69 | + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | ||
| 70 | + if (accepted <= 0) | ||
| 71 | + { | ||
| 72 | + int err = SSL_get_error(ssl, accepted); | ||
| 73 | + | ||
| 74 | + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | ||
| 75 | + { | ||
| 76 | + parentClient->setReadyForWriting(err == SSL_ERROR_WANT_WRITE); | ||
| 77 | + return; | ||
| 78 | + } | ||
| 79 | + | ||
| 80 | + unsigned long error_code = ERR_get_error(); | ||
| 81 | + | ||
| 82 | + ERR_error_string(error_code, sslErrorBuf); | ||
| 83 | + std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | ||
| 84 | + | ||
| 85 | + if (error_code == OPENSSL_WRONG_VERSION_NUMBER) | ||
| 86 | + errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket."; | ||
| 87 | + | ||
| 88 | + //ERR_print_errors_cb(logSslError, NULL); | ||
| 89 | + throw std::runtime_error("Problem accepting SSL socket: " + errorMsg); | ||
| 90 | + } | ||
| 91 | + parentClient->setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake | ||
| 92 | + sslAccepted = true; | ||
| 93 | +} | ||
| 94 | + | ||
| 95 | +bool IoWrapper::getSslReadWantsWrite() const | ||
| 96 | +{ | ||
| 97 | + return this->sslReadWantsWrite; | ||
| 98 | +} | ||
| 99 | + | ||
| 100 | +bool IoWrapper::getSslWriteWantsRead() const | ||
| 101 | +{ | ||
| 102 | + return sslWriteWantsRead; | ||
| 103 | +} | ||
| 104 | + | ||
| 105 | +bool IoWrapper::isSslAccepted() const | ||
| 106 | +{ | ||
| 107 | + return this->sslAccepted; | ||
| 108 | +} | ||
| 109 | + | ||
| 110 | +bool IoWrapper::isSsl() const | ||
| 111 | +{ | ||
| 112 | + return this->ssl != nullptr; | ||
| 113 | +} | ||
| 114 | + | ||
| 115 | +bool IoWrapper::hasPendingWrite() const | ||
| 116 | +{ | ||
| 117 | + return incompleteSslWrite.hasPendingWrite() || websocketWriteRemainder.usedBytes() > 0; | ||
| 118 | +} | ||
| 119 | + | ||
| 120 | +bool IoWrapper::isWebsocket() const | ||
| 121 | +{ | ||
| 122 | + return websocket; | ||
| 123 | +} | ||
| 124 | + | ||
| 125 | +WebsocketState IoWrapper::getWebsocketState() const | ||
| 126 | +{ | ||
| 127 | + return websocketState; | ||
| 128 | +} | ||
| 129 | + | ||
| 130 | +/** | ||
| 131 | + * @brief SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL | ||
| 132 | + * socket. This wrapper unifies behavor for the caller. | ||
| 133 | + * | ||
| 134 | + * @param fd | ||
| 135 | + * @param buf | ||
| 136 | + * @param nbytes | ||
| 137 | + * @param error is an out-argument with the result. | ||
| 138 | + * @return | ||
| 139 | + */ | ||
| 140 | +ssize_t IoWrapper::readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error) | ||
| 141 | +{ | ||
| 142 | + *error = IoWrapResult::Success; | ||
| 143 | + ssize_t n = 0; | ||
| 144 | + if (!ssl) | ||
| 145 | + { | ||
| 146 | + n = read(fd, buf, nbytes); | ||
| 147 | + if (n < 0) | ||
| 148 | + { | ||
| 149 | + if (errno == EINTR) | ||
| 150 | + *error = IoWrapResult::Interrupted; | ||
| 151 | + else if (errno == EAGAIN || errno == EWOULDBLOCK) | ||
| 152 | + *error = IoWrapResult::Wouldblock; | ||
| 153 | + else | ||
| 154 | + check<std::runtime_error>(n); | ||
| 155 | + } | ||
| 156 | + else if (n == 0) | ||
| 157 | + { | ||
| 158 | + *error = IoWrapResult::Disconnected; | ||
| 159 | + } | ||
| 160 | + } | ||
| 161 | + else | ||
| 162 | + { | ||
| 163 | + this->sslReadWantsWrite = false; | ||
| 164 | + ERR_clear_error(); | ||
| 165 | + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | ||
| 166 | + n = SSL_read(ssl, buf, nbytes); | ||
| 167 | + | ||
| 168 | + if (n <= 0) | ||
| 169 | + { | ||
| 170 | + int err = SSL_get_error(ssl, n); | ||
| 171 | + unsigned long error_code = ERR_get_error(); | ||
| 172 | + | ||
| 173 | + // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL. | ||
| 174 | + if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0)) | ||
| 175 | + { | ||
| 176 | + *error = IoWrapResult::Disconnected; | ||
| 177 | + } | ||
| 178 | + else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | ||
| 179 | + { | ||
| 180 | + *error = IoWrapResult::Wouldblock; | ||
| 181 | + if (err == SSL_ERROR_WANT_WRITE) | ||
| 182 | + { | ||
| 183 | + sslReadWantsWrite = true; | ||
| 184 | + parentClient->setReadyForWriting(true); | ||
| 185 | + } | ||
| 186 | + n = -1; | ||
| 187 | + } | ||
| 188 | + else | ||
| 189 | + { | ||
| 190 | + if (err == SSL_ERROR_SYSCALL) | ||
| 191 | + { | ||
| 192 | + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say | ||
| 193 | + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it | ||
| 194 | + // implies EINTR is not included? | ||
| 195 | + if (errno == EINTR) | ||
| 196 | + *error = IoWrapResult::Interrupted; | ||
| 197 | + else | ||
| 198 | + { | ||
| 199 | + char *err = strerror(errno); | ||
| 200 | + std::string msg(err); | ||
| 201 | + throw std::runtime_error("SSL read error: " + msg); | ||
| 202 | + } | ||
| 203 | + } | ||
| 204 | + ERR_error_string(error_code, sslErrorBuf); | ||
| 205 | + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | ||
| 206 | + ERR_print_errors_cb(logSslError, NULL); | ||
| 207 | + throw std::runtime_error("SSL socket error reading: " + errorString); | ||
| 208 | + } | ||
| 209 | + } | ||
| 210 | + } | ||
| 211 | + | ||
| 212 | + return n; | ||
| 213 | +} | ||
| 214 | + | ||
| 215 | +// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller. | ||
| 216 | +ssize_t IoWrapper::writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error) | ||
| 217 | +{ | ||
| 218 | + *error = IoWrapResult::Success; | ||
| 219 | + ssize_t n = 0; | ||
| 220 | + | ||
| 221 | + if (!ssl) | ||
| 222 | + { | ||
| 223 | + // A write on a socket with count=0 is unspecified. | ||
| 224 | + assert(nbytes > 0); | ||
| 225 | + | ||
| 226 | + n = write(fd, buf, nbytes); | ||
| 227 | + if (n < 0) | ||
| 228 | + { | ||
| 229 | + if (errno == EINTR) | ||
| 230 | + *error = IoWrapResult::Interrupted; | ||
| 231 | + else if (errno == EAGAIN || errno == EWOULDBLOCK) | ||
| 232 | + *error = IoWrapResult::Wouldblock; | ||
| 233 | + else | ||
| 234 | + check<std::runtime_error>(n); | ||
| 235 | + } | ||
| 236 | + } | ||
| 237 | + else | ||
| 238 | + { | ||
| 239 | + const void *buf_ = buf; | ||
| 240 | + size_t nbytes_ = nbytes; | ||
| 241 | + | ||
| 242 | + /* | ||
| 243 | + * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned | ||
| 244 | + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments | ||
| 245 | + */ | ||
| 246 | + if (this->incompleteSslWrite.hasPendingWrite()) | ||
| 247 | + { | ||
| 248 | + buf_ = this->incompleteSslWrite.buf; | ||
| 249 | + nbytes_ = this->incompleteSslWrite.nbytes; | ||
| 250 | + } | ||
| 251 | + | ||
| 252 | + // OpenSSL: "You should not call SSL_write() with num=0, it will return an error" | ||
| 253 | + assert(nbytes_ > 0); | ||
| 254 | + | ||
| 255 | + this->sslWriteWantsRead = false; | ||
| 256 | + this->incompleteSslWrite.reset(); | ||
| 257 | + | ||
| 258 | + ERR_clear_error(); | ||
| 259 | + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | ||
| 260 | + n = SSL_write(ssl, buf_, nbytes_); | ||
| 261 | + | ||
| 262 | + if (n <= 0) | ||
| 263 | + { | ||
| 264 | + int err = SSL_get_error(ssl, n); | ||
| 265 | + unsigned long error_code = ERR_get_error(); | ||
| 266 | + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | ||
| 267 | + { | ||
| 268 | + logger->logf(LOG_DEBUG, "Write is incomplete: %d", err); | ||
| 269 | + *error = IoWrapResult::Wouldblock; | ||
| 270 | + IncompleteSslWrite sslAction(buf_, nbytes_); | ||
| 271 | + this->incompleteSslWrite = sslAction; | ||
| 272 | + if (err == SSL_ERROR_WANT_READ) | ||
| 273 | + this->sslWriteWantsRead = true; | ||
| 274 | + n = 0; | ||
| 275 | + } | ||
| 276 | + else | ||
| 277 | + { | ||
| 278 | + if (err == SSL_ERROR_SYSCALL) | ||
| 279 | + { | ||
| 280 | + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say | ||
| 281 | + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it | ||
| 282 | + // implies EINTR is not included? | ||
| 283 | + if (errno == EINTR) | ||
| 284 | + *error = IoWrapResult::Interrupted; | ||
| 285 | + else | ||
| 286 | + { | ||
| 287 | + char *err = strerror(errno); | ||
| 288 | + std::string msg(err); | ||
| 289 | + throw std::runtime_error(msg); | ||
| 290 | + } | ||
| 291 | + } | ||
| 292 | + ERR_error_string(error_code, sslErrorBuf); | ||
| 293 | + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | ||
| 294 | + ERR_print_errors_cb(logSslError, NULL); | ||
| 295 | + throw std::runtime_error("SSL socket error writing: " + errorString); | ||
| 296 | + } | ||
| 297 | + } | ||
| 298 | + } | ||
| 299 | + | ||
| 300 | + return n; | ||
| 301 | +} | ||
| 302 | + | ||
| 303 | +// Use a small intermediate buffer to write (partial) websocket frames to our normal read buffer. MQTT is already a frames protocol, so we don't | ||
| 304 | +// care about websocket frames being incomplete. | ||
| 305 | +ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error) | ||
| 306 | +{ | ||
| 307 | + if (!websocket) | ||
| 308 | + { | ||
| 309 | + return readOrSslRead(fd, buf, nbytes, error); | ||
| 310 | + } | ||
| 311 | + else | ||
| 312 | + { | ||
| 313 | + ssize_t n = 0; | ||
| 314 | + while (websocketPendingBytes.freeSpace() > 0 && (n = readOrSslRead(fd, websocketPendingBytes.headPtr(), websocketPendingBytes.maxWriteSize(), error)) != 0) | ||
| 315 | + { | ||
| 316 | + if (n > 0) | ||
| 317 | + websocketPendingBytes.advanceHead(n); | ||
| 318 | + if (n < 0) | ||
| 319 | + break; // signal/error handling is done by the caller, so we just stop. | ||
| 320 | + | ||
| 321 | + if (websocketState == WebsocketState::NotUpgraded && websocketPendingBytes.freeSpace() == 0) | ||
| 322 | + { | ||
| 323 | + if (websocketPendingBytes.getSize() * 2 <= 8192) | ||
| 324 | + websocketPendingBytes.doubleSize(); | ||
| 325 | + else | ||
| 326 | + throw ProtocolError("Trying to exceed websocket buffer. Probably not valid websocket traffic."); | ||
| 327 | + } | ||
| 328 | + } | ||
| 329 | + | ||
| 330 | + const bool hasWebsocketPendingBytes = websocketPendingBytes.usedBytes() > 0; | ||
| 331 | + | ||
| 332 | + // When some or all the data has been read, we can continue. | ||
| 333 | + if (!(*error == IoWrapResult::Wouldblock || *error == IoWrapResult::Success) && !hasWebsocketPendingBytes) | ||
| 334 | + return n; | ||
| 335 | + | ||
| 336 | + if (hasWebsocketPendingBytes) | ||
| 337 | + { | ||
| 338 | + n = 0; | ||
| 339 | + | ||
| 340 | + if (websocketState == WebsocketState::NotUpgraded) | ||
| 341 | + { | ||
| 342 | + try | ||
| 343 | + { | ||
| 344 | + std::string websocketKey; | ||
| 345 | + int websocketVersion; | ||
| 346 | + if (parseHttpHeader(websocketPendingBytes, websocketKey, websocketVersion)) | ||
| 347 | + { | ||
| 348 | + if (websocketKey.empty()) | ||
| 349 | + throw BadHttpRequest("No websocket key specified."); | ||
| 350 | + if (websocketVersion != 13) | ||
| 351 | + throw BadWebsocketVersionException("Websocket version 13 required."); | ||
| 352 | + | ||
| 353 | + const std::string acceptString = generateWebsocketAcceptString(websocketKey); | ||
| 354 | + | ||
| 355 | + std::string answer = generateWebsocketAnswer(acceptString); | ||
| 356 | + parentClient->writeText(answer); | ||
| 357 | + websocketState = WebsocketState::Upgrading; | ||
| 358 | + websocketPendingBytes.reset(); | ||
| 359 | + websocketPendingBytes.resetSize(initialBufferSize); | ||
| 360 | + *error = IoWrapResult::Success; | ||
| 361 | + } | ||
| 362 | + } | ||
| 363 | + catch (BadWebsocketVersionException &ex) | ||
| 364 | + { | ||
| 365 | + std::string response = generateInvalidWebsocketVersionHttpHeaders(13); | ||
| 366 | + parentClient->writeText(response); | ||
| 367 | + parentClient->setDisconnectReason("Invalid websocket version"); | ||
| 368 | + parentClient->setReadyForDisconnect(); | ||
| 369 | + } | ||
| 370 | + catch (BadHttpRequest &ex) // Should should also properly deal with attempt at HTTP2 with PRI. | ||
| 371 | + { | ||
| 372 | + std::string response = generateBadHttpRequestReponse(ex.what()); | ||
| 373 | + parentClient->writeText(response); | ||
| 374 | + parentClient->setDisconnectReason("Invalid websocket start"); | ||
| 375 | + parentClient->setReadyForDisconnect(); | ||
| 376 | + } | ||
| 377 | + } | ||
| 378 | + else | ||
| 379 | + { | ||
| 380 | + n = websocketBytesToReadBuffer(buf, nbytes); | ||
| 381 | + | ||
| 382 | + if (n > 0) | ||
| 383 | + *error = IoWrapResult::Success; | ||
| 384 | + else if (n == 0) | ||
| 385 | + *error = IoWrapResult::Wouldblock; | ||
| 386 | + } | ||
| 387 | + } | ||
| 388 | + | ||
| 389 | + return n; | ||
| 390 | + } | ||
| 391 | +} | ||
| 392 | + | ||
| 393 | +ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes) | ||
| 394 | +{ | ||
| 395 | + const ssize_t targetBufMaxSize = nbytes; | ||
| 396 | + ssize_t nbytesRead = 0; | ||
| 397 | + | ||
| 398 | + while (websocketPendingBytes.usedBytes() >= WEBSOCKET_MIN_HEADER_BYTES_NEEDED && nbytesRead < targetBufMaxSize) | ||
| 399 | + { | ||
| 400 | + // This block decodes the header. | ||
| 401 | + if (!incompleteWebsocketRead.sillWorkingOnFrame()) | ||
| 402 | + { | ||
| 403 | + const uint8_t byte1 = websocketPendingBytes.peakAhead(0); | ||
| 404 | + const uint8_t byte2 = websocketPendingBytes.peakAhead(1); | ||
| 405 | + bool masked = !!(byte2 & 0b10000000); | ||
| 406 | + uint8_t reserved = (byte1 & 0b01110000) >> 4; | ||
| 407 | + WebsocketOpcode opcode = (WebsocketOpcode)(byte1 & 0b00001111); | ||
| 408 | + const uint8_t payloadLength = byte2 & 0b01111111; | ||
| 409 | + size_t realPayloadLength = payloadLength; | ||
| 410 | + uint64_t extendedPayloadLengthLength = 0; | ||
| 411 | + uint8_t headerLength = masked ? 6 : 2; | ||
| 412 | + | ||
| 413 | + if (payloadLength == 126) | ||
| 414 | + extendedPayloadLengthLength = 2; | ||
| 415 | + else if (payloadLength == 127) | ||
| 416 | + extendedPayloadLengthLength = 8; | ||
| 417 | + headerLength += extendedPayloadLengthLength; | ||
| 418 | + | ||
| 419 | + //if (!masked) | ||
| 420 | + // throw ProtocolError("Client must send masked websocket bytes."); | ||
| 421 | + | ||
| 422 | + if (reserved != 0) | ||
| 423 | + throw ProtocolError("Reserved bytes in header must be 0."); | ||
| 424 | + | ||
| 425 | + if (headerLength > websocketPendingBytes.usedBytes()) | ||
| 426 | + return nbytesRead; | ||
| 427 | + | ||
| 428 | + uint64_t extendedPayloadLength = 0; | ||
| 429 | + | ||
| 430 | + int i = 2; | ||
| 431 | + int shift = extendedPayloadLengthLength * 8; | ||
| 432 | + while (shift > 0) | ||
| 433 | + { | ||
| 434 | + shift -= 8; | ||
| 435 | + uint8_t byte = websocketPendingBytes.peakAhead(i++); | ||
| 436 | + extendedPayloadLength += (byte << shift); | ||
| 437 | + } | ||
| 438 | + | ||
| 439 | + if (extendedPayloadLength > 0) | ||
| 440 | + realPayloadLength = extendedPayloadLength; | ||
| 441 | + | ||
| 442 | + if (headerLength > websocketPendingBytes.usedBytes()) | ||
| 443 | + return nbytesRead; | ||
| 444 | + | ||
| 445 | + if (masked) | ||
| 446 | + { | ||
| 447 | + for (int j = 0; j < 4; j++) | ||
| 448 | + { | ||
| 449 | + incompleteWebsocketRead.maskingKey[j] = websocketPendingBytes.peakAhead(i++); | ||
| 450 | + } | ||
| 451 | + } | ||
| 452 | + | ||
| 453 | + assert(i == headerLength); | ||
| 454 | + assert(headerLength <= websocketPendingBytes.usedBytes()); | ||
| 455 | + websocketPendingBytes.advanceTail(headerLength); | ||
| 456 | + | ||
| 457 | + incompleteWebsocketRead.frame_bytes_left = realPayloadLength; | ||
| 458 | + incompleteWebsocketRead.opcode = opcode; | ||
| 459 | + } | ||
| 460 | + | ||
| 461 | + if (incompleteWebsocketRead.opcode == WebsocketOpcode::Binary) | ||
| 462 | + { | ||
| 463 | + // The following reads one websocket frame max: it will continue with the previous, or start a new one, which it may or may not finish. | ||
| 464 | + size_t targetBufI = 0; | ||
| 465 | + char *targetBuf = &static_cast<char*>(buf)[nbytesRead]; | ||
| 466 | + while(websocketPendingBytes.usedBytes() > 0 && incompleteWebsocketRead.frame_bytes_left > 0 && nbytesRead < targetBufMaxSize) | ||
| 467 | + { | ||
| 468 | + const size_t asManyBytesOfThisFrameAsPossible = std::min<int>(websocketPendingBytes.maxReadSize(), incompleteWebsocketRead.frame_bytes_left); | ||
| 469 | + const size_t maxReadSize = std::min<int>(asManyBytesOfThisFrameAsPossible, targetBufMaxSize - nbytesRead); | ||
| 470 | + assert(maxReadSize > 0); | ||
| 471 | + assert(static_cast<ssize_t>(maxReadSize) + nbytesRead <= targetBufMaxSize); | ||
| 472 | + for (size_t x = 0; x < maxReadSize; x++) | ||
| 473 | + { | ||
| 474 | + targetBuf[targetBufI++] = websocketPendingBytes.tailPtr()[x] ^ incompleteWebsocketRead.getNextMaskingByte(); | ||
| 475 | + } | ||
| 476 | + websocketPendingBytes.advanceTail(maxReadSize); | ||
| 477 | + incompleteWebsocketRead.frame_bytes_left -= maxReadSize; | ||
| 478 | + nbytesRead += maxReadSize; | ||
| 479 | + } | ||
| 480 | + } | ||
| 481 | + else if (incompleteWebsocketRead.opcode == WebsocketOpcode::Ping) | ||
| 482 | + { | ||
| 483 | + // A ping MAY have user data, which needs to be ponged back; | ||
| 484 | + | ||
| 485 | + // Constructing a new temporary buffer because I need the reponse in one frame for writeAsMuchOfBufAsWebsocketFrame(). | ||
| 486 | + std::vector<char> response(incompleteWebsocketRead.frame_bytes_left); | ||
| 487 | + websocketPendingBytes.read(response.data(), response.size()); | ||
| 488 | + | ||
| 489 | + websocketWriteRemainder.ensureFreeSpace(response.size()); | ||
| 490 | + writeAsMuchOfBufAsWebsocketFrame(response.data(), response.size(), WebsocketOpcode::Pong); | ||
| 491 | + parentClient->setReadyForWriting(true); | ||
| 492 | + } | ||
| 493 | + | ||
| 494 | + if (!incompleteWebsocketRead.sillWorkingOnFrame()) | ||
| 495 | + incompleteWebsocketRead.reset(); | ||
| 496 | + } | ||
| 497 | + assert(nbytesRead <= static_cast<ssize_t>(nbytes)); | ||
| 498 | + | ||
| 499 | + return nbytesRead; | ||
| 500 | +} | ||
| 501 | + | ||
| 502 | +/** | ||
| 503 | + * @brief IoWrapper::writeAsMuchOfBufAsWebsocketFrame writes buf of part of buf as websocket frame to websocketWriteRemainder | ||
| 504 | + * @param buf | ||
| 505 | + * @param nbytes. The amount of bytes. Can be 0, for just an empty websocket frame. | ||
| 506 | + * @return | ||
| 507 | + */ | ||
| 508 | +ssize_t IoWrapper::writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode) | ||
| 509 | +{ | ||
| 510 | + // We do allow pong frames to generate a zero payload packet, but for binary, that's not necessary. | ||
| 511 | + if (nbytes == 0 && opcode == WebsocketOpcode::Binary) | ||
| 512 | + return 0; | ||
| 513 | + | ||
| 514 | + ssize_t nBytesReal = 0; | ||
| 515 | + | ||
| 516 | + // We normally wrap each write in a frame, but if a previous one didn't fit in the system's write buffers, we're still working on it. | ||
| 517 | + if (websocketWriteRemainder.freeSpace() > WEBSOCKET_MAX_SENDING_HEADER_SIZE) | ||
| 518 | + { | ||
| 519 | + uint8_t extended_payload_length_num_bytes = 0; | ||
| 520 | + uint8_t payload_length = 0; | ||
| 521 | + if (nbytes < 126) | ||
| 522 | + payload_length = nbytes; | ||
| 523 | + else if (nbytes >= 126 && nbytes <= 0xFFFF) | ||
| 524 | + { | ||
| 525 | + payload_length = 126; | ||
| 526 | + extended_payload_length_num_bytes = 2; | ||
| 527 | + } | ||
| 528 | + else if (nbytes > 0xFFFF) | ||
| 529 | + { | ||
| 530 | + payload_length = 127; | ||
| 531 | + extended_payload_length_num_bytes = 8; | ||
| 532 | + } | ||
| 533 | + | ||
| 534 | + int x = 0; | ||
| 535 | + char header[WEBSOCKET_MAX_SENDING_HEADER_SIZE]; | ||
| 536 | + header[x++] = (0b10000000 | static_cast<char>(opcode)); | ||
| 537 | + header[x++] = payload_length; | ||
| 538 | + | ||
| 539 | + const int header_length = x + extended_payload_length_num_bytes; | ||
| 540 | + | ||
| 541 | + // This block writes the extended payload length. | ||
| 542 | + nBytesReal = std::min<int>(nbytes, websocketWriteRemainder.freeSpace() - header_length); | ||
| 543 | + const uint64_t nbytes64 = nBytesReal; | ||
| 544 | + for (int z = extended_payload_length_num_bytes - 1; z >= 0; z--) | ||
| 545 | + { | ||
| 546 | + header[x++] = (nbytes64 >> (z*8)) & 0xFF; | ||
| 547 | + } | ||
| 548 | + assert(x <= WEBSOCKET_MAX_SENDING_HEADER_SIZE); | ||
| 549 | + assert(x == header_length); | ||
| 550 | + | ||
| 551 | + websocketWriteRemainder.write(header, header_length); | ||
| 552 | + websocketWriteRemainder.write(buf, nBytesReal); | ||
| 553 | + } | ||
| 554 | + | ||
| 555 | + return nBytesReal; | ||
| 556 | +} | ||
| 557 | + | ||
| 558 | +/* | ||
| 559 | + * Mqtt docs: "A single WebSocket data frame can contain multiple or partial MQTT Control Packets. The receiver | ||
| 560 | + * MUST NOT assume that MQTT Control Packets are aligned on WebSocket frame boundaries [MQTT-6.0.0-2]." We | ||
| 561 | + * make use of that here, and wrap each write in a frame. | ||
| 562 | + * | ||
| 563 | + * It's can legitimately return a number of bytes written AND error with 'would block'. So, no need to do that | ||
| 564 | + * repeating of the write thing that SSL_write() has. | ||
| 565 | + */ | ||
| 566 | +ssize_t IoWrapper::writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error) | ||
| 567 | +{ | ||
| 568 | + if (websocketState != WebsocketState::Upgraded) | ||
| 569 | + { | ||
| 570 | + if (websocket && websocketState == WebsocketState::Upgrading) | ||
| 571 | + websocketState = WebsocketState::Upgraded; | ||
| 572 | + | ||
| 573 | + return writeOrSslWrite(fd, buf, nbytes, error); | ||
| 574 | + } | ||
| 575 | + else | ||
| 576 | + { | ||
| 577 | + ssize_t nBytesReal = writeAsMuchOfBufAsWebsocketFrame(buf, nbytes); | ||
| 578 | + | ||
| 579 | + ssize_t n = 0; | ||
| 580 | + while (websocketWriteRemainder.usedBytes() > 0) | ||
| 581 | + { | ||
| 582 | + n = writeOrSslWrite(fd, websocketWriteRemainder.tailPtr(), websocketWriteRemainder.maxReadSize(), error); | ||
| 583 | + | ||
| 584 | + if (n > 0) | ||
| 585 | + websocketWriteRemainder.advanceTail(n); | ||
| 586 | + if (n < 0) | ||
| 587 | + break; | ||
| 588 | + } | ||
| 589 | + | ||
| 590 | + if (n > 0) | ||
| 591 | + return nBytesReal; | ||
| 592 | + | ||
| 593 | + return n; | ||
| 594 | + } | ||
| 595 | +} |
iowrapper.h
0 โ 100644
| 1 | +#ifndef IOWRAPPER_H | ||
| 2 | +#define IOWRAPPER_H | ||
| 3 | + | ||
| 4 | +#include "unistd.h" | ||
| 5 | +#include "openssl/ssl.h" | ||
| 6 | +#include "openssl/err.h" | ||
| 7 | +#include <exception> | ||
| 8 | + | ||
| 9 | +#include "forward_declarations.h" | ||
| 10 | + | ||
| 11 | +#include "types.h" | ||
| 12 | +#include "utils.h" | ||
| 13 | +#include "logger.h" | ||
| 14 | +#include "exceptions.h" | ||
| 15 | + | ||
| 16 | +#define WEBSOCKET_MIN_HEADER_BYTES_NEEDED 2 | ||
| 17 | +#define WEBSOCKET_MAX_SENDING_HEADER_SIZE 10 | ||
| 18 | + | ||
| 19 | +#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256. | ||
| 20 | +#define OPENSSL_WRONG_VERSION_NUMBER 336130315 | ||
| 21 | + | ||
| 22 | +enum class IoWrapResult | ||
| 23 | +{ | ||
| 24 | + Success = 0, | ||
| 25 | + Interrupted = 1, | ||
| 26 | + Wouldblock = 2, | ||
| 27 | + Disconnected = 3, | ||
| 28 | + Error = 4 | ||
| 29 | +}; | ||
| 30 | + | ||
| 31 | +enum class WebsocketOpcode | ||
| 32 | +{ | ||
| 33 | + Continuation = 0x00, | ||
| 34 | + Text = 0x1, | ||
| 35 | + Binary = 0x2, | ||
| 36 | + Close = 0x8, | ||
| 37 | + Ping = 0x9, | ||
| 38 | + Pong = 0xA, | ||
| 39 | + Unknown = 0xF | ||
| 40 | +}; | ||
| 41 | + | ||
| 42 | +/* | ||
| 43 | + * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned | ||
| 44 | + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments" | ||
| 45 | + */ | ||
| 46 | +struct IncompleteSslWrite | ||
| 47 | +{ | ||
| 48 | + const void *buf = nullptr; | ||
| 49 | + size_t nbytes = 0; | ||
| 50 | + | ||
| 51 | + IncompleteSslWrite() = default; | ||
| 52 | + IncompleteSslWrite(const void *buf, size_t nbytes); | ||
| 53 | + bool hasPendingWrite() const; | ||
| 54 | + | ||
| 55 | + void reset(); | ||
| 56 | +}; | ||
| 57 | + | ||
| 58 | +struct IncompleteWebsocketRead | ||
| 59 | +{ | ||
| 60 | + size_t frame_bytes_left = 0; | ||
| 61 | + char maskingKey[4]; | ||
| 62 | + int maskingKeyI = 0; | ||
| 63 | + WebsocketOpcode opcode; | ||
| 64 | + | ||
| 65 | + void reset(); | ||
| 66 | + bool sillWorkingOnFrame() const; | ||
| 67 | + char getNextMaskingByte(); | ||
| 68 | +}; | ||
| 69 | + | ||
| 70 | +enum class WebsocketState | ||
| 71 | +{ | ||
| 72 | + NotUpgraded, | ||
| 73 | + Upgrading, | ||
| 74 | + Upgraded | ||
| 75 | +}; | ||
| 76 | + | ||
| 77 | +/** | ||
| 78 | + * @brief provides a unified wrapper for SSL and websockets to read() and write(). | ||
| 79 | + * | ||
| 80 | + * | ||
| 81 | + */ | ||
| 82 | +class IoWrapper | ||
| 83 | +{ | ||
| 84 | + Client *parentClient; | ||
| 85 | + const size_t initialBufferSize; | ||
| 86 | + | ||
| 87 | + SSL *ssl = nullptr; | ||
| 88 | + bool sslAccepted = false; | ||
| 89 | + IncompleteSslWrite incompleteSslWrite; | ||
| 90 | + bool sslReadWantsWrite = false; | ||
| 91 | + bool sslWriteWantsRead = false; | ||
| 92 | + | ||
| 93 | + bool websocket; | ||
| 94 | + WebsocketState websocketState = WebsocketState::NotUpgraded; | ||
| 95 | + CirBuf websocketPendingBytes; | ||
| 96 | + IncompleteWebsocketRead incompleteWebsocketRead; | ||
| 97 | + CirBuf websocketWriteRemainder; | ||
| 98 | + | ||
| 99 | + Logger *logger = Logger::getInstance(); | ||
| 100 | + | ||
| 101 | + ssize_t websocketBytesToReadBuffer(void *buf, const size_t nbytes); | ||
| 102 | + ssize_t readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error); | ||
| 103 | + ssize_t writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error); | ||
| 104 | + ssize_t writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode = WebsocketOpcode::Binary); | ||
| 105 | +public: | ||
| 106 | + IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent); | ||
| 107 | + ~IoWrapper(); | ||
| 108 | + | ||
| 109 | + void startOrContinueSslAccept(); | ||
| 110 | + bool getSslReadWantsWrite() const; | ||
| 111 | + bool getSslWriteWantsRead() const; | ||
| 112 | + bool isSslAccepted() const; | ||
| 113 | + bool isSsl() const; | ||
| 114 | + bool hasPendingWrite() const; | ||
| 115 | + bool isWebsocket() const; | ||
| 116 | + WebsocketState getWebsocketState() const; | ||
| 117 | + | ||
| 118 | + ssize_t readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error); | ||
| 119 | + ssize_t writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error); | ||
| 120 | +}; | ||
| 121 | + | ||
| 122 | +#endif // IOWRAPPER_H |
mainapp.cpp
| @@ -351,6 +351,7 @@ void MainApp::start() | @@ -351,6 +351,7 @@ void MainApp::start() | ||
| 351 | 351 | ||
| 352 | int listen_fd_plain = createListenSocket(this->listenPort, false); | 352 | int listen_fd_plain = createListenSocket(this->listenPort, false); |
| 353 | int listen_fd_ssl = createListenSocket(this->sslListenPort, true); | 353 | int listen_fd_ssl = createListenSocket(this->sslListenPort, true); |
| 354 | + int listen_fd_websocket_plain = createListenSocket(1443, true); | ||
| 354 | 355 | ||
| 355 | #ifdef NDEBUG | 356 | #ifdef NDEBUG |
| 356 | logger->noLongerLogToStd(); | 357 | logger->noLongerLogToStd(); |
| @@ -391,7 +392,7 @@ void MainApp::start() | @@ -391,7 +392,7 @@ void MainApp::start() | ||
| 391 | int cur_fd = events[i].data.fd; | 392 | int cur_fd = events[i].data.fd; |
| 392 | try | 393 | try |
| 393 | { | 394 | { |
| 394 | - if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl) | 395 | + if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl || listen_fd_websocket_plain) |
| 395 | { | 396 | { |
| 396 | std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % num_threads]; | 397 | std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % num_threads]; |
| 397 | 398 | ||
| @@ -402,6 +403,7 @@ void MainApp::start() | @@ -402,6 +403,7 @@ void MainApp::start() | ||
| 402 | socklen_t len = sizeof(struct sockaddr); | 403 | socklen_t len = sizeof(struct sockaddr); |
| 403 | int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len)); | 404 | int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len)); |
| 404 | 405 | ||
| 406 | + bool websocket = cur_fd == listen_fd_websocket_plain; | ||
| 405 | SSL *clientSSL = nullptr; | 407 | SSL *clientSSL = nullptr; |
| 406 | if (cur_fd == listen_fd_ssl) | 408 | if (cur_fd == listen_fd_ssl) |
| 407 | { | 409 | { |
| @@ -417,7 +419,7 @@ void MainApp::start() | @@ -417,7 +419,7 @@ void MainApp::start() | ||
| 417 | SSL_set_fd(clientSSL, fd); | 419 | SSL_set_fd(clientSSL, fd); |
| 418 | } | 420 | } |
| 419 | 421 | ||
| 420 | - Client_p client(new Client(fd, thread_data, clientSSL, settings)); | 422 | + Client_p client(new Client(fd, thread_data, clientSSL, websocket, settings)); |
| 421 | thread_data->giveClient(client); | 423 | thread_data->giveClient(client); |
| 422 | } | 424 | } |
| 423 | else if (cur_fd == taskEventFd) | 425 | else if (cur_fd == taskEventFd) |
utils.cpp
| @@ -3,8 +3,10 @@ | @@ -3,8 +3,10 @@ | ||
| 3 | #include "sys/time.h" | 3 | #include "sys/time.h" |
| 4 | #include "sys/random.h" | 4 | #include "sys/random.h" |
| 5 | #include <algorithm> | 5 | #include <algorithm> |
| 6 | +#include <sstream> | ||
| 6 | 7 | ||
| 7 | #include "exceptions.h" | 8 | #include "exceptions.h" |
| 9 | +#include "cirbuf.h" | ||
| 8 | 10 | ||
| 9 | std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) | 11 | std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) |
| 10 | { | 12 | { |
| @@ -226,3 +228,121 @@ bool isPowerOfTwo(int n) | @@ -226,3 +228,121 @@ bool isPowerOfTwo(int n) | ||
| 226 | { | 228 | { |
| 227 | return (n != 0) && (n & (n - 1)) == 0; | 229 | return (n != 0) && (n & (n - 1)) == 0; |
| 228 | } | 230 | } |
| 231 | + | ||
| 232 | +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version) | ||
| 233 | +{ | ||
| 234 | + const std::string s(buf.tailPtr(), buf.usedBytes()); | ||
| 235 | + std::istringstream is(s); | ||
| 236 | + bool doubleEmptyLine = false; // meaning, the HTTP header is complete | ||
| 237 | + bool upgradeHeaderSeen = false; | ||
| 238 | + bool connectionHeaderSeen = false; | ||
| 239 | + bool firstLine = true; | ||
| 240 | + | ||
| 241 | + std::string line; | ||
| 242 | + while (std::getline(is, line)) | ||
| 243 | + { | ||
| 244 | + trim(line); | ||
| 245 | + if (firstLine) | ||
| 246 | + { | ||
| 247 | + firstLine = false; | ||
| 248 | + if (!startsWith(line, "GET")) | ||
| 249 | + throw BadHttpRequest("Websocket request should start with GET."); | ||
| 250 | + continue; | ||
| 251 | + } | ||
| 252 | + if (line.empty()) | ||
| 253 | + { | ||
| 254 | + doubleEmptyLine = true; | ||
| 255 | + break; | ||
| 256 | + } | ||
| 257 | + | ||
| 258 | + std::list<std::string> fields = split(line, ':', 1); | ||
| 259 | + const std::vector<std::string> fields2(fields.begin(), fields.end()); | ||
| 260 | + std::string name = str_tolower(fields2[0]); | ||
| 261 | + trim(name); | ||
| 262 | + std::string value = fields2[1]; | ||
| 263 | + trim(value); | ||
| 264 | + std::string value_lower = str_tolower(value); | ||
| 265 | + | ||
| 266 | + if (name == "upgrade" && value_lower == "websocket") | ||
| 267 | + upgradeHeaderSeen = true; | ||
| 268 | + else if (name == "connection" && value_lower == "upgrade") | ||
| 269 | + connectionHeaderSeen = true; | ||
| 270 | + else if (name == "sec-websocket-key") | ||
| 271 | + websocket_key = value; | ||
| 272 | + else if (name == "sec-websocket-version") | ||
| 273 | + websocket_version = stoi(value); | ||
| 274 | + } | ||
| 275 | + | ||
| 276 | + if (doubleEmptyLine) | ||
| 277 | + { | ||
| 278 | + if (!connectionHeaderSeen || !upgradeHeaderSeen) | ||
| 279 | + throw BadHttpRequest("HTTP request is not a websocket upgrade request."); | ||
| 280 | + } | ||
| 281 | + | ||
| 282 | + return doubleEmptyLine; | ||
| 283 | +} | ||
| 284 | + | ||
| 285 | +std::string base64Encode(const unsigned char *input, const int length) | ||
| 286 | +{ | ||
| 287 | + const int pl = 4*((length+2)/3); | ||
| 288 | + char *output = reinterpret_cast<char *>(calloc(pl+1, 1)); | ||
| 289 | + const int ol = EVP_EncodeBlock(reinterpret_cast<unsigned char *>(output), input, length); | ||
| 290 | + std::string result(output); | ||
| 291 | + free(output); | ||
| 292 | + | ||
| 293 | + if (pl != ol) | ||
| 294 | + throw std::runtime_error("Base64 encode error."); | ||
| 295 | + | ||
| 296 | + return result; | ||
| 297 | +} | ||
| 298 | + | ||
| 299 | +std::string generateWebsocketAcceptString(const std::string &websocketKey) | ||
| 300 | +{ | ||
| 301 | + unsigned char md_value[EVP_MAX_MD_SIZE]; | ||
| 302 | + unsigned int md_len; | ||
| 303 | + | ||
| 304 | + EVP_MD_CTX *mdctx = EVP_MD_CTX_new();; | ||
| 305 | + const EVP_MD *md = EVP_sha1(); | ||
| 306 | + EVP_DigestInit_ex(mdctx, md, NULL); | ||
| 307 | + | ||
| 308 | + const std::string keyPlusMagic = websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; | ||
| 309 | + | ||
| 310 | + EVP_DigestUpdate(mdctx, keyPlusMagic.c_str(), keyPlusMagic.length()); | ||
| 311 | + EVP_DigestFinal_ex(mdctx, md_value, &md_len); | ||
| 312 | + EVP_MD_CTX_free(mdctx); | ||
| 313 | + | ||
| 314 | + std::string base64 = base64Encode(md_value, md_len); | ||
| 315 | + return base64; | ||
| 316 | +} | ||
| 317 | + | ||
| 318 | +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion) | ||
| 319 | +{ | ||
| 320 | + std::ostringstream oss; | ||
| 321 | + oss << "HTTP/1.1 400 Bad Request\r\n"; | ||
| 322 | + oss << "Sec-WebSocket-Version: " << wantedVersion; | ||
| 323 | + oss << "\r\n"; | ||
| 324 | + oss.flush(); | ||
| 325 | + return oss.str(); | ||
| 326 | +} | ||
| 327 | + | ||
| 328 | +std::string generateBadHttpRequestReponse(const std::string &msg) | ||
| 329 | +{ | ||
| 330 | + std::ostringstream oss; | ||
| 331 | + oss << "HTTP/1.1 400 Bad Request\r\n"; | ||
| 332 | + oss << "\r\n"; | ||
| 333 | + oss << msg; | ||
| 334 | + oss.flush(); | ||
| 335 | + return oss.str(); | ||
| 336 | +} | ||
| 337 | + | ||
| 338 | +std::string generateWebsocketAnswer(const std::string &acceptString) | ||
| 339 | +{ | ||
| 340 | + std::ostringstream oss; | ||
| 341 | + oss << "HTTP/1.1 101 Switching Protocols\r\n"; | ||
| 342 | + oss << "Upgrade: websocket\r\n"; | ||
| 343 | + oss << "Connection: Upgrade\r\n"; | ||
| 344 | + oss << "Sec-WebSocket-Accept: " << acceptString << "\r\n"; | ||
| 345 | + oss << "\r\n"; | ||
| 346 | + oss.flush(); | ||
| 347 | + return oss.str(); | ||
| 348 | +} |
utils.h
| @@ -8,6 +8,9 @@ | @@ -8,6 +8,9 @@ | ||
| 8 | #include <limits> | 8 | #include <limits> |
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | #include <algorithm> | 10 | #include <algorithm> |
| 11 | +#include <openssl/evp.h> | ||
| 12 | + | ||
| 13 | +#include "cirbuf.h" | ||
| 11 | 14 | ||
| 12 | template<typename T> int check(int rc) | 15 | template<typename T> int check(int rc) |
| 13 | { | 16 | { |
| @@ -43,4 +46,13 @@ std::string str_tolower(std::string s); | @@ -43,4 +46,13 @@ std::string str_tolower(std::string s); | ||
| 43 | bool stringTruthiness(const std::string &val); | 46 | bool stringTruthiness(const std::string &val); |
| 44 | bool isPowerOfTwo(int val); | 47 | bool isPowerOfTwo(int val); |
| 45 | 48 | ||
| 49 | +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version); | ||
| 50 | + | ||
| 51 | +std::string base64Encode(const unsigned char *input, const int length); | ||
| 52 | +std::string generateWebsocketAcceptString(const std::string &websocketKey); | ||
| 53 | + | ||
| 54 | +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion); | ||
| 55 | +std::string generateBadHttpRequestReponse(const std::string &msg); | ||
| 56 | +std::string generateWebsocketAnswer(const std::string &acceptString); | ||
| 57 | + | ||
| 46 | #endif // UTILS_H | 58 | #endif // UTILS_H |