Commit b5ba41f503c602c0edcdeec0db5b2b8344050cb8
1 parent
1106e210
Add IoWrapper, with websocket support added
The ping/pong is actually untested at this point, because Paho (my test client for now) doesn't do those. I wonder if any do, because MQTT already has ping/pong.
Showing
11 changed files
with
988 additions
and
268 deletions
CMakeLists.txt
cirbuf.cpp
| ... | ... | @@ -9,10 +9,14 @@ |
| 9 | 9 | #include <cstring> |
| 10 | 10 | |
| 11 | 11 | #include "logger.h" |
| 12 | +#include "utils.h" | |
| 12 | 13 | |
| 13 | 14 | CirBuf::CirBuf(size_t size) : |
| 14 | 15 | size(size) |
| 15 | 16 | { |
| 17 | + if (size == 0) | |
| 18 | + return; | |
| 19 | + | |
| 16 | 20 | buf = (char*)malloc(size); |
| 17 | 21 | |
| 18 | 22 | if (buf == NULL) |
| ... | ... | @@ -69,12 +73,14 @@ char *CirBuf::tailPtr() |
| 69 | 73 | |
| 70 | 74 | void CirBuf::advanceHead(uint32_t n) |
| 71 | 75 | { |
| 76 | + assert(n <= freeSpace()); | |
| 72 | 77 | head = (head + n) & (size -1); |
| 73 | 78 | assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty. |
| 74 | 79 | } |
| 75 | 80 | |
| 76 | 81 | void CirBuf::advanceTail(uint32_t n) |
| 77 | 82 | { |
| 83 | + assert(n <= usedBytes()); | |
| 78 | 84 | tail = (tail + n) & (size -1); |
| 79 | 85 | } |
| 80 | 86 | |
| ... | ... | @@ -84,9 +90,31 @@ char CirBuf::peakAhead(uint32_t offset) const |
| 84 | 90 | return b; |
| 85 | 91 | } |
| 86 | 92 | |
| 87 | -void CirBuf::doubleSize() | |
| 93 | +void CirBuf::ensureFreeSpace(size_t n) | |
| 94 | +{ | |
| 95 | + const size_t _usedBytes = usedBytes(); | |
| 96 | + | |
| 97 | + int mul = 1; | |
| 98 | + | |
| 99 | + while((mul * size - _usedBytes - 1) < n) | |
| 100 | + { | |
| 101 | + mul = mul << 1; | |
| 102 | + } | |
| 103 | + | |
| 104 | + if (mul == 1) | |
| 105 | + return; | |
| 106 | + | |
| 107 | + doubleSize(mul); | |
| 108 | +} | |
| 109 | + | |
| 110 | +void CirBuf::doubleSize(uint factor) | |
| 88 | 111 | { |
| 89 | - uint newSize = size * 2; | |
| 112 | + if (factor == 1) | |
| 113 | + return; | |
| 114 | + | |
| 115 | + assert(isPowerOfTwo(factor)); | |
| 116 | + | |
| 117 | + uint newSize = size * factor; | |
| 90 | 118 | char *newBuf = (char*)realloc(buf, newSize); |
| 91 | 119 | |
| 92 | 120 | if (newBuf == NULL) |
| ... | ... | @@ -145,3 +173,55 @@ void CirBuf::resetSize(size_t newSize) |
| 145 | 173 | memset(buf, 0, newSize); |
| 146 | 174 | #endif |
| 147 | 175 | } |
| 176 | + | |
| 177 | +void CirBuf::reset() | |
| 178 | +{ | |
| 179 | + head = 0; | |
| 180 | + tail = 0; | |
| 181 | + | |
| 182 | +#ifndef NDEBUG | |
| 183 | + memset(buf, 0, size); | |
| 184 | +#endif | |
| 185 | +} | |
| 186 | + | |
| 187 | +// When you know the data you want to write fits in the buffer, use this. | |
| 188 | +void CirBuf::write(const void *buf, size_t count) | |
| 189 | +{ | |
| 190 | + assert(count <= freeSpace()); | |
| 191 | + | |
| 192 | + ssize_t len_left = count; | |
| 193 | + size_t src_i = 0; | |
| 194 | + while (len_left > 0) | |
| 195 | + { | |
| 196 | + const size_t len = std::min<int>(len_left, maxWriteSize()); | |
| 197 | + assert(len > 0); | |
| 198 | + const char *src = &reinterpret_cast<const char*>(buf)[src_i]; | |
| 199 | + std::memcpy(headPtr(), src, len); | |
| 200 | + advanceHead(len); | |
| 201 | + src_i += len; | |
| 202 | + len_left -= len; | |
| 203 | + } | |
| 204 | + assert(len_left == 0); | |
| 205 | + assert(src_i == count); | |
| 206 | +} | |
| 207 | + | |
| 208 | +// When you know 'count' bytes are present and you want to read them into buf | |
| 209 | +void CirBuf::read(void *buf, const size_t count) | |
| 210 | +{ | |
| 211 | + assert(count <= usedBytes()); | |
| 212 | + | |
| 213 | + char *_buf = static_cast<char*>(buf); | |
| 214 | + int i = 0; | |
| 215 | + ssize_t _packet_len = count; | |
| 216 | + while (_packet_len > 0) | |
| 217 | + { | |
| 218 | + const int readlen = std::min<int>(maxReadSize(), _packet_len); | |
| 219 | + assert(readlen > 0); | |
| 220 | + std::memcpy(&_buf[i], tailPtr(), readlen); | |
| 221 | + advanceTail(readlen); | |
| 222 | + i += readlen; | |
| 223 | + _packet_len -= readlen; | |
| 224 | + } | |
| 225 | + assert(_packet_len == 0); | |
| 226 | + assert(i == _packet_len); | |
| 227 | +} | ... | ... |
cirbuf.h
| ... | ... | @@ -32,11 +32,16 @@ public: |
| 32 | 32 | void advanceHead(uint32_t n); |
| 33 | 33 | void advanceTail(uint32_t n); |
| 34 | 34 | char peakAhead(uint32_t offset) const; |
| 35 | - void doubleSize(); | |
| 35 | + void ensureFreeSpace(size_t n); | |
| 36 | + void doubleSize(uint factor = 2); | |
| 36 | 37 | uint32_t getSize() const; |
| 37 | 38 | |
| 38 | 39 | time_t bufferLastResizedSecondsAgo() const; |
| 39 | 40 | void resetSize(size_t size); |
| 41 | + void reset(); | |
| 42 | + | |
| 43 | + void write(const void *buf, size_t count); | |
| 44 | + void read(void *buf, const size_t count); | |
| 40 | 45 | }; |
| 41 | 46 | |
| 42 | 47 | #endif // CIRBUF_H | ... | ... |
client.cpp
| ... | ... | @@ -7,11 +7,11 @@ |
| 7 | 7 | |
| 8 | 8 | #include "logger.h" |
| 9 | 9 | |
| 10 | -Client::Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings) : | |
| 10 | +Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings) : | |
| 11 | 11 | fd(fd), |
| 12 | - ssl(ssl), | |
| 13 | 12 | initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy |
| 14 | 13 | maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment. |
| 14 | + ioWrapper(ssl, websocket, initialBufferSize, this), | |
| 15 | 15 | readbuf(initialBufferSize), |
| 16 | 16 | writebuf(initialBufferSize), |
| 17 | 17 | threadData(threadData) |
| ... | ... | @@ -40,63 +40,32 @@ Client::~Client() |
| 40 | 40 | logger->logf(LOG_NOTICE, "Removing client '%s'. Reason(s): %s", repr().c_str(), disconnectReason.c_str()); |
| 41 | 41 | if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) |
| 42 | 42 | logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); |
| 43 | - if (ssl) | |
| 44 | - { | |
| 45 | - // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done | |
| 46 | - // in the destructor. | |
| 47 | - SSL_free(ssl); | |
| 48 | - } | |
| 49 | 43 | close(fd); |
| 50 | 44 | } |
| 51 | 45 | |
| 52 | 46 | bool Client::isSslAccepted() const |
| 53 | 47 | { |
| 54 | - return sslAccepted; | |
| 48 | + return ioWrapper.isSslAccepted(); | |
| 55 | 49 | } |
| 56 | 50 | |
| 57 | 51 | bool Client::isSsl() const |
| 58 | 52 | { |
| 59 | - return this->ssl != nullptr; | |
| 53 | + return ioWrapper.isSsl(); | |
| 60 | 54 | } |
| 61 | 55 | |
| 62 | 56 | bool Client::getSslReadWantsWrite() const |
| 63 | 57 | { |
| 64 | - return this->sslReadWantsWrite; | |
| 58 | + return ioWrapper.getSslReadWantsWrite(); | |
| 65 | 59 | } |
| 66 | 60 | |
| 67 | 61 | bool Client::getSslWriteWantsRead() const |
| 68 | 62 | { |
| 69 | - return this->sslWriteWantsRead; | |
| 63 | + return ioWrapper.getSslWriteWantsRead(); | |
| 70 | 64 | } |
| 71 | 65 | |
| 72 | 66 | void Client::startOrContinueSslAccept() |
| 73 | 67 | { |
| 74 | - ERR_clear_error(); | |
| 75 | - int accepted = SSL_accept(ssl); | |
| 76 | - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | |
| 77 | - if (accepted <= 0) | |
| 78 | - { | |
| 79 | - int err = SSL_get_error(ssl, accepted); | |
| 80 | - | |
| 81 | - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | |
| 82 | - { | |
| 83 | - setReadyForWriting(err == SSL_ERROR_WANT_WRITE); | |
| 84 | - return; | |
| 85 | - } | |
| 86 | - | |
| 87 | - unsigned long error_code = ERR_get_error(); | |
| 88 | - | |
| 89 | - ERR_error_string(error_code, sslErrorBuf); | |
| 90 | - std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | |
| 91 | - | |
| 92 | - if (error_code == OPENSSL_WRONG_VERSION_NUMBER) | |
| 93 | - errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket."; | |
| 94 | - | |
| 95 | - //ERR_print_errors_cb(logSslError, NULL); | |
| 96 | - throw std::runtime_error("Problem accepting SSL socket: " + errorMsg); | |
| 97 | - } | |
| 98 | - setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake | |
| 99 | - sslAccepted = true; | |
| 68 | + ioWrapper.startOrContinueSslAccept(); | |
| 100 | 69 | } |
| 101 | 70 | |
| 102 | 71 | // Causes future activity on the client to cause a disconnect. |
| ... | ... | @@ -108,83 +77,6 @@ void Client::markAsDisconnecting() |
| 108 | 77 | disconnecting = true; |
| 109 | 78 | } |
| 110 | 79 | |
| 111 | -// SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL | |
| 112 | -// socket. This wrapper unifies behavor for the caller. | |
| 113 | -ssize_t Client::readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error) | |
| 114 | -{ | |
| 115 | - *error = IoWrapResult::Success; | |
| 116 | - ssize_t n = 0; | |
| 117 | - if (!ssl) | |
| 118 | - { | |
| 119 | - n = read(fd, buf, nbytes); | |
| 120 | - if (n < 0) | |
| 121 | - { | |
| 122 | - if (errno == EINTR) | |
| 123 | - *error = IoWrapResult::Interrupted; | |
| 124 | - else if (errno == EAGAIN || errno == EWOULDBLOCK) | |
| 125 | - *error = IoWrapResult::Wouldblock; | |
| 126 | - else | |
| 127 | - check<std::runtime_error>(n); | |
| 128 | - } | |
| 129 | - else if (n == 0) | |
| 130 | - { | |
| 131 | - *error = IoWrapResult::Disconnected; | |
| 132 | - } | |
| 133 | - } | |
| 134 | - else | |
| 135 | - { | |
| 136 | - this->sslReadWantsWrite = false; | |
| 137 | - ERR_clear_error(); | |
| 138 | - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | |
| 139 | - n = SSL_read(ssl, buf, nbytes); | |
| 140 | - | |
| 141 | - if (n <= 0) | |
| 142 | - { | |
| 143 | - int err = SSL_get_error(ssl, n); | |
| 144 | - unsigned long error_code = ERR_get_error(); | |
| 145 | - | |
| 146 | - // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL. | |
| 147 | - if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0)) | |
| 148 | - { | |
| 149 | - *error = IoWrapResult::Disconnected; | |
| 150 | - } | |
| 151 | - else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | |
| 152 | - { | |
| 153 | - *error = IoWrapResult::Wouldblock; | |
| 154 | - if (err == SSL_ERROR_WANT_WRITE) | |
| 155 | - { | |
| 156 | - sslReadWantsWrite = true; | |
| 157 | - setReadyForWriting(true); | |
| 158 | - } | |
| 159 | - n = -1; | |
| 160 | - } | |
| 161 | - else | |
| 162 | - { | |
| 163 | - if (err == SSL_ERROR_SYSCALL) | |
| 164 | - { | |
| 165 | - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say | |
| 166 | - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it | |
| 167 | - // implies EINTR is not included? | |
| 168 | - if (errno == EINTR) | |
| 169 | - *error = IoWrapResult::Interrupted; | |
| 170 | - else | |
| 171 | - { | |
| 172 | - char *err = strerror(errno); | |
| 173 | - std::string msg(err); | |
| 174 | - throw std::runtime_error("SSL read error: " + msg); | |
| 175 | - } | |
| 176 | - } | |
| 177 | - ERR_error_string(error_code, sslErrorBuf); | |
| 178 | - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | |
| 179 | - ERR_print_errors_cb(logSslError, NULL); | |
| 180 | - throw std::runtime_error("SSL socket error reading: " + errorString); | |
| 181 | - } | |
| 182 | - } | |
| 183 | - } | |
| 184 | - | |
| 185 | - return n; | |
| 186 | -} | |
| 187 | - | |
| 188 | 80 | // false means any kind of error we want to get rid of the client for. |
| 189 | 81 | bool Client::readFdIntoBuffer() |
| 190 | 82 | { |
| ... | ... | @@ -193,7 +85,7 @@ bool Client::readFdIntoBuffer() |
| 193 | 85 | |
| 194 | 86 | IoWrapResult error = IoWrapResult::Success; |
| 195 | 87 | int n = 0; |
| 196 | - while (readbuf.freeSpace() > 0 && (n = readWrap(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0) | |
| 88 | + while (readbuf.freeSpace() > 0 && (n = ioWrapper.readWebsocketAndOrSsl(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0) | |
| 197 | 89 | { |
| 198 | 90 | if (n > 0) |
| 199 | 91 | { |
| ... | ... | @@ -232,6 +124,20 @@ bool Client::readFdIntoBuffer() |
| 232 | 124 | return true; |
| 233 | 125 | } |
| 234 | 126 | |
| 127 | +void Client::writeText(const std::string &text) | |
| 128 | +{ | |
| 129 | + assert(ioWrapper.isWebsocket()); | |
| 130 | + assert(ioWrapper.getWebsocketState() == WebsocketState::NotUpgraded); | |
| 131 | + | |
| 132 | + // Not necessary, because at this point, no other threads write to this client, but including for clarity. | |
| 133 | + std::lock_guard<std::mutex> locker(writeBufMutex); | |
| 134 | + | |
| 135 | + writebuf.ensureFreeSpace(text.size()); | |
| 136 | + writebuf.write(text.c_str(), text.length()); | |
| 137 | + | |
| 138 | + setReadyForWriting(true); | |
| 139 | +} | |
| 140 | + | |
| 235 | 141 | void Client::writeMqttPacket(const MqttPacket &packet) |
| 236 | 142 | { |
| 237 | 143 | std::lock_guard<std::mutex> locker(writeBufMutex); |
| ... | ... | @@ -322,94 +228,6 @@ void Client::writePingResp() |
| 322 | 228 | setReadyForWriting(true); |
| 323 | 229 | } |
| 324 | 230 | |
| 325 | -// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller. | |
| 326 | -ssize_t Client::writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error) | |
| 327 | -{ | |
| 328 | - *error = IoWrapResult::Success; | |
| 329 | - ssize_t n = 0; | |
| 330 | - | |
| 331 | - if (!ssl) | |
| 332 | - { | |
| 333 | - // A write on a socket with count=0 is unspecified. | |
| 334 | - assert(nbytes > 0); | |
| 335 | - | |
| 336 | - n = write(fd, buf, nbytes); | |
| 337 | - if (n < 0) | |
| 338 | - { | |
| 339 | - if (errno == EINTR) | |
| 340 | - *error = IoWrapResult::Interrupted; | |
| 341 | - else if (errno == EAGAIN || errno == EWOULDBLOCK) | |
| 342 | - *error = IoWrapResult::Wouldblock; | |
| 343 | - else | |
| 344 | - check<std::runtime_error>(n); | |
| 345 | - } | |
| 346 | - } | |
| 347 | - else | |
| 348 | - { | |
| 349 | - const void *buf_ = buf; | |
| 350 | - size_t nbytes_ = nbytes; | |
| 351 | - | |
| 352 | - /* | |
| 353 | - * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned | |
| 354 | - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments | |
| 355 | - */ | |
| 356 | - if (this->incompleteSslWrite.hasPendingWrite()) | |
| 357 | - { | |
| 358 | - buf_ = this->incompleteSslWrite.buf; | |
| 359 | - nbytes_ = this->incompleteSslWrite.nbytes; | |
| 360 | - } | |
| 361 | - | |
| 362 | - // OpenSSL: "You should not call SSL_write() with num=0, it will return an error" | |
| 363 | - assert(nbytes_ > 0); | |
| 364 | - | |
| 365 | - this->sslWriteWantsRead = false; | |
| 366 | - this->incompleteSslWrite.reset(); | |
| 367 | - | |
| 368 | - ERR_clear_error(); | |
| 369 | - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | |
| 370 | - n = SSL_write(ssl, buf_, nbytes_); | |
| 371 | - | |
| 372 | - if (n <= 0) | |
| 373 | - { | |
| 374 | - int err = SSL_get_error(ssl, n); | |
| 375 | - unsigned long error_code = ERR_get_error(); | |
| 376 | - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | |
| 377 | - { | |
| 378 | - logger->logf(LOG_DEBUG, "Write is incomplete: %d", err); | |
| 379 | - *error = IoWrapResult::Wouldblock; | |
| 380 | - IncompleteSslWrite sslAction(buf_, nbytes_); | |
| 381 | - this->incompleteSslWrite = sslAction; | |
| 382 | - if (err == SSL_ERROR_WANT_READ) | |
| 383 | - this->sslWriteWantsRead = true; | |
| 384 | - n = 0; | |
| 385 | - } | |
| 386 | - else | |
| 387 | - { | |
| 388 | - if (err == SSL_ERROR_SYSCALL) | |
| 389 | - { | |
| 390 | - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say | |
| 391 | - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it | |
| 392 | - // implies EINTR is not included? | |
| 393 | - if (errno == EINTR) | |
| 394 | - *error = IoWrapResult::Interrupted; | |
| 395 | - else | |
| 396 | - { | |
| 397 | - char *err = strerror(errno); | |
| 398 | - std::string msg(err); | |
| 399 | - throw std::runtime_error(msg); | |
| 400 | - } | |
| 401 | - } | |
| 402 | - ERR_error_string(error_code, sslErrorBuf); | |
| 403 | - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | |
| 404 | - ERR_print_errors_cb(logSslError, NULL); | |
| 405 | - throw std::runtime_error("SSL socket error writing: " + errorString); | |
| 406 | - } | |
| 407 | - } | |
| 408 | - } | |
| 409 | - | |
| 410 | - return n; | |
| 411 | -} | |
| 412 | - | |
| 413 | 231 | bool Client::writeBufIntoFd() |
| 414 | 232 | { |
| 415 | 233 | std::unique_lock<std::mutex> lock(writeBufMutex, std::try_to_lock); |
| ... | ... | @@ -422,9 +240,9 @@ bool Client::writeBufIntoFd() |
| 422 | 240 | |
| 423 | 241 | IoWrapResult error = IoWrapResult::Success; |
| 424 | 242 | int n; |
| 425 | - while (writebuf.usedBytes() > 0 || incompleteSslWrite.hasPendingWrite()) | |
| 243 | + while (writebuf.usedBytes() > 0 || ioWrapper.hasPendingWrite()) | |
| 426 | 244 | { |
| 427 | - n = writeWrap(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error); | |
| 245 | + n = ioWrapper.writeWebsocketAndOrSsl(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error); | |
| 428 | 246 | |
| 429 | 247 | if (n > 0) |
| 430 | 248 | writebuf.advanceTail(n); |
| ... | ... | @@ -484,7 +302,7 @@ void Client::setReadyForWriting(bool val) |
| 484 | 302 | if (disconnecting) |
| 485 | 303 | return; |
| 486 | 304 | |
| 487 | - if (sslReadWantsWrite) | |
| 305 | + if (ioWrapper.getSslReadWantsWrite()) | |
| 488 | 306 | val = true; |
| 489 | 307 | |
| 490 | 308 | if (val == this->readyForWriting) |
| ... | ... | @@ -622,20 +440,3 @@ void Client::clearWill() |
| 622 | 440 | will_qos = 0; |
| 623 | 441 | } |
| 624 | 442 | |
| 625 | -IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) : | |
| 626 | - buf(buf), | |
| 627 | - nbytes(nbytes) | |
| 628 | -{ | |
| 629 | - | |
| 630 | -} | |
| 631 | - | |
| 632 | -void IncompleteSslWrite::reset() | |
| 633 | -{ | |
| 634 | - buf = nullptr; | |
| 635 | - nbytes = 0; | |
| 636 | -} | |
| 637 | - | |
| 638 | -bool IncompleteSslWrite::hasPendingWrite() | |
| 639 | -{ | |
| 640 | - return buf != nullptr; | |
| 641 | -} | ... | ... |
client.h
| ... | ... | @@ -9,6 +9,7 @@ |
| 9 | 9 | #include <time.h> |
| 10 | 10 | |
| 11 | 11 | #include <openssl/ssl.h> |
| 12 | +#include <openssl/err.h> | |
| 12 | 13 | |
| 13 | 14 | #include "forward_declarations.h" |
| 14 | 15 | |
| ... | ... | @@ -17,54 +18,25 @@ |
| 17 | 18 | #include "exceptions.h" |
| 18 | 19 | #include "cirbuf.h" |
| 19 | 20 | #include "types.h" |
| 20 | - | |
| 21 | -#include <openssl/ssl.h> | |
| 22 | -#include <openssl/err.h> | |
| 21 | +#include "iowrapper.h" | |
| 23 | 22 | |
| 24 | 23 | #define MQTT_HEADER_LENGH 2 |
| 25 | 24 | |
| 26 | -#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256. | |
| 27 | -#define OPENSSL_WRONG_VERSION_NUMBER 336130315 | |
| 28 | - | |
| 29 | -enum class IoWrapResult | |
| 30 | -{ | |
| 31 | - Success = 0, | |
| 32 | - Interrupted = 1, | |
| 33 | - Wouldblock = 2, | |
| 34 | - Disconnected = 3, | |
| 35 | - Error = 4 | |
| 36 | -}; | |
| 37 | - | |
| 38 | -/* | |
| 39 | - * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned | |
| 40 | - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments" | |
| 41 | - */ | |
| 42 | -struct IncompleteSslWrite | |
| 43 | -{ | |
| 44 | - const void *buf = nullptr; | |
| 45 | - size_t nbytes = 0; | |
| 46 | - | |
| 47 | - IncompleteSslWrite() = default; | |
| 48 | - IncompleteSslWrite(const void *buf, size_t nbytes); | |
| 49 | - bool hasPendingWrite(); | |
| 50 | - | |
| 51 | - void reset(); | |
| 52 | -}; | |
| 53 | 25 | |
| 54 | 26 | // TODO: give accepted addr, for showing in logs |
| 55 | 27 | class Client |
| 56 | 28 | { |
| 29 | + friend class IoWrapper; | |
| 30 | + | |
| 57 | 31 | int fd; |
| 58 | - SSL *ssl = nullptr; | |
| 59 | - bool sslAccepted = false; | |
| 60 | - IncompleteSslWrite incompleteSslWrite; | |
| 61 | - bool sslReadWantsWrite = false; | |
| 62 | - bool sslWriteWantsRead = false; | |
| 32 | + | |
| 63 | 33 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 64 | 34 | |
| 65 | 35 | const size_t initialBufferSize = 0; |
| 66 | 36 | const size_t maxPacketSize = 0; |
| 67 | 37 | |
| 38 | + IoWrapper ioWrapper; | |
| 39 | + | |
| 68 | 40 | CirBuf readbuf; |
| 69 | 41 | uint8_t readBufIsZeroCount = 0; |
| 70 | 42 | |
| ... | ... | @@ -97,12 +69,11 @@ class Client |
| 97 | 69 | |
| 98 | 70 | Logger *logger = Logger::getInstance(); |
| 99 | 71 | |
| 100 | - | |
| 101 | 72 | void setReadyForWriting(bool val); |
| 102 | 73 | void setReadyForReading(bool val); |
| 103 | 74 | |
| 104 | 75 | public: |
| 105 | - Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings); | |
| 76 | + Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings); | |
| 106 | 77 | Client(const Client &other) = delete; |
| 107 | 78 | Client(Client &&other) = delete; |
| 108 | 79 | ~Client(); |
| ... | ... | @@ -115,7 +86,6 @@ public: |
| 115 | 86 | |
| 116 | 87 | void startOrContinueSslAccept(); |
| 117 | 88 | void markAsDisconnecting(); |
| 118 | - ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error); | |
| 119 | 89 | bool readFdIntoBuffer(); |
| 120 | 90 | bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender); |
| 121 | 91 | void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); |
| ... | ... | @@ -131,10 +101,10 @@ public: |
| 131 | 101 | std::shared_ptr<Session> getSession(); |
| 132 | 102 | void setDisconnectReason(const std::string &reason); |
| 133 | 103 | |
| 104 | + void writeText(const std::string &text); | |
| 134 | 105 | void writePingResp(); |
| 135 | 106 | void writeMqttPacket(const MqttPacket &packet); |
| 136 | 107 | void writeMqttPacketAndBlameThisClient(const MqttPacket &packet); |
| 137 | - ssize_t writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error); | |
| 138 | 108 | bool writeBufIntoFd(); |
| 139 | 109 | bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } |
| 140 | 110 | ... | ... |
exceptions.h
| ... | ... | @@ -34,4 +34,16 @@ public: |
| 34 | 34 | AuthPluginException(const std::string &msg) : std::runtime_error(msg) {} |
| 35 | 35 | }; |
| 36 | 36 | |
| 37 | +class BadWebsocketVersionException : public std::runtime_error | |
| 38 | +{ | |
| 39 | +public: | |
| 40 | + BadWebsocketVersionException(const std::string &msg) : std::runtime_error(msg) {} | |
| 41 | +}; | |
| 42 | + | |
| 43 | +class BadHttpRequest : public std::runtime_error | |
| 44 | +{ | |
| 45 | +public: | |
| 46 | + BadHttpRequest(const std::string &msg) : std::runtime_error(msg) {} | |
| 47 | +}; | |
| 48 | + | |
| 37 | 49 | #endif // EXCEPTIONS_H | ... | ... |
iowrapper.cpp
0 โ 100644
| 1 | +#include "iowrapper.h" | |
| 2 | + | |
| 3 | +#include "cassert" | |
| 4 | + | |
| 5 | +#include "logger.h" | |
| 6 | +#include "client.h" | |
| 7 | + | |
| 8 | +IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) : | |
| 9 | + buf(buf), | |
| 10 | + nbytes(nbytes) | |
| 11 | +{ | |
| 12 | + | |
| 13 | +} | |
| 14 | + | |
| 15 | +void IncompleteSslWrite::reset() | |
| 16 | +{ | |
| 17 | + buf = nullptr; | |
| 18 | + nbytes = 0; | |
| 19 | +} | |
| 20 | + | |
| 21 | +bool IncompleteSslWrite::hasPendingWrite() const | |
| 22 | +{ | |
| 23 | + return buf != nullptr; | |
| 24 | +} | |
| 25 | + | |
| 26 | +void IncompleteWebsocketRead::reset() | |
| 27 | +{ | |
| 28 | + maskingKeyI = 0; | |
| 29 | + memset(maskingKey,0, 4); | |
| 30 | + frame_bytes_left = 0; | |
| 31 | + opcode = WebsocketOpcode::Unknown; | |
| 32 | +} | |
| 33 | + | |
| 34 | +bool IncompleteWebsocketRead::sillWorkingOnFrame() const | |
| 35 | +{ | |
| 36 | + return frame_bytes_left > 0; | |
| 37 | +} | |
| 38 | + | |
| 39 | +char IncompleteWebsocketRead::getNextMaskingByte() | |
| 40 | +{ | |
| 41 | + return maskingKey[maskingKeyI++ % 4]; | |
| 42 | +} | |
| 43 | + | |
| 44 | +IoWrapper::IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent) : | |
| 45 | + parentClient(parent), | |
| 46 | + initialBufferSize(initialBufferSize), | |
| 47 | + ssl(ssl), | |
| 48 | + websocket(websocket), | |
| 49 | + websocketPendingBytes(websocket ? initialBufferSize : 0), | |
| 50 | + websocketWriteRemainder(websocket ? initialBufferSize : 0) | |
| 51 | +{ | |
| 52 | + | |
| 53 | +} | |
| 54 | + | |
| 55 | +IoWrapper::~IoWrapper() | |
| 56 | +{ | |
| 57 | + if (ssl) | |
| 58 | + { | |
| 59 | + // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done | |
| 60 | + // in the destructor. | |
| 61 | + SSL_free(ssl); | |
| 62 | + } | |
| 63 | +} | |
| 64 | + | |
| 65 | +void IoWrapper::startOrContinueSslAccept() | |
| 66 | +{ | |
| 67 | + ERR_clear_error(); | |
| 68 | + int accepted = SSL_accept(ssl); | |
| 69 | + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | |
| 70 | + if (accepted <= 0) | |
| 71 | + { | |
| 72 | + int err = SSL_get_error(ssl, accepted); | |
| 73 | + | |
| 74 | + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | |
| 75 | + { | |
| 76 | + parentClient->setReadyForWriting(err == SSL_ERROR_WANT_WRITE); | |
| 77 | + return; | |
| 78 | + } | |
| 79 | + | |
| 80 | + unsigned long error_code = ERR_get_error(); | |
| 81 | + | |
| 82 | + ERR_error_string(error_code, sslErrorBuf); | |
| 83 | + std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | |
| 84 | + | |
| 85 | + if (error_code == OPENSSL_WRONG_VERSION_NUMBER) | |
| 86 | + errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket."; | |
| 87 | + | |
| 88 | + //ERR_print_errors_cb(logSslError, NULL); | |
| 89 | + throw std::runtime_error("Problem accepting SSL socket: " + errorMsg); | |
| 90 | + } | |
| 91 | + parentClient->setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake | |
| 92 | + sslAccepted = true; | |
| 93 | +} | |
| 94 | + | |
| 95 | +bool IoWrapper::getSslReadWantsWrite() const | |
| 96 | +{ | |
| 97 | + return this->sslReadWantsWrite; | |
| 98 | +} | |
| 99 | + | |
| 100 | +bool IoWrapper::getSslWriteWantsRead() const | |
| 101 | +{ | |
| 102 | + return sslWriteWantsRead; | |
| 103 | +} | |
| 104 | + | |
| 105 | +bool IoWrapper::isSslAccepted() const | |
| 106 | +{ | |
| 107 | + return this->sslAccepted; | |
| 108 | +} | |
| 109 | + | |
| 110 | +bool IoWrapper::isSsl() const | |
| 111 | +{ | |
| 112 | + return this->ssl != nullptr; | |
| 113 | +} | |
| 114 | + | |
| 115 | +bool IoWrapper::hasPendingWrite() const | |
| 116 | +{ | |
| 117 | + return incompleteSslWrite.hasPendingWrite() || websocketWriteRemainder.usedBytes() > 0; | |
| 118 | +} | |
| 119 | + | |
| 120 | +bool IoWrapper::isWebsocket() const | |
| 121 | +{ | |
| 122 | + return websocket; | |
| 123 | +} | |
| 124 | + | |
| 125 | +WebsocketState IoWrapper::getWebsocketState() const | |
| 126 | +{ | |
| 127 | + return websocketState; | |
| 128 | +} | |
| 129 | + | |
| 130 | +/** | |
| 131 | + * @brief SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL | |
| 132 | + * socket. This wrapper unifies behavor for the caller. | |
| 133 | + * | |
| 134 | + * @param fd | |
| 135 | + * @param buf | |
| 136 | + * @param nbytes | |
| 137 | + * @param error is an out-argument with the result. | |
| 138 | + * @return | |
| 139 | + */ | |
| 140 | +ssize_t IoWrapper::readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error) | |
| 141 | +{ | |
| 142 | + *error = IoWrapResult::Success; | |
| 143 | + ssize_t n = 0; | |
| 144 | + if (!ssl) | |
| 145 | + { | |
| 146 | + n = read(fd, buf, nbytes); | |
| 147 | + if (n < 0) | |
| 148 | + { | |
| 149 | + if (errno == EINTR) | |
| 150 | + *error = IoWrapResult::Interrupted; | |
| 151 | + else if (errno == EAGAIN || errno == EWOULDBLOCK) | |
| 152 | + *error = IoWrapResult::Wouldblock; | |
| 153 | + else | |
| 154 | + check<std::runtime_error>(n); | |
| 155 | + } | |
| 156 | + else if (n == 0) | |
| 157 | + { | |
| 158 | + *error = IoWrapResult::Disconnected; | |
| 159 | + } | |
| 160 | + } | |
| 161 | + else | |
| 162 | + { | |
| 163 | + this->sslReadWantsWrite = false; | |
| 164 | + ERR_clear_error(); | |
| 165 | + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | |
| 166 | + n = SSL_read(ssl, buf, nbytes); | |
| 167 | + | |
| 168 | + if (n <= 0) | |
| 169 | + { | |
| 170 | + int err = SSL_get_error(ssl, n); | |
| 171 | + unsigned long error_code = ERR_get_error(); | |
| 172 | + | |
| 173 | + // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL. | |
| 174 | + if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0)) | |
| 175 | + { | |
| 176 | + *error = IoWrapResult::Disconnected; | |
| 177 | + } | |
| 178 | + else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | |
| 179 | + { | |
| 180 | + *error = IoWrapResult::Wouldblock; | |
| 181 | + if (err == SSL_ERROR_WANT_WRITE) | |
| 182 | + { | |
| 183 | + sslReadWantsWrite = true; | |
| 184 | + parentClient->setReadyForWriting(true); | |
| 185 | + } | |
| 186 | + n = -1; | |
| 187 | + } | |
| 188 | + else | |
| 189 | + { | |
| 190 | + if (err == SSL_ERROR_SYSCALL) | |
| 191 | + { | |
| 192 | + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say | |
| 193 | + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it | |
| 194 | + // implies EINTR is not included? | |
| 195 | + if (errno == EINTR) | |
| 196 | + *error = IoWrapResult::Interrupted; | |
| 197 | + else | |
| 198 | + { | |
| 199 | + char *err = strerror(errno); | |
| 200 | + std::string msg(err); | |
| 201 | + throw std::runtime_error("SSL read error: " + msg); | |
| 202 | + } | |
| 203 | + } | |
| 204 | + ERR_error_string(error_code, sslErrorBuf); | |
| 205 | + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | |
| 206 | + ERR_print_errors_cb(logSslError, NULL); | |
| 207 | + throw std::runtime_error("SSL socket error reading: " + errorString); | |
| 208 | + } | |
| 209 | + } | |
| 210 | + } | |
| 211 | + | |
| 212 | + return n; | |
| 213 | +} | |
| 214 | + | |
| 215 | +// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller. | |
| 216 | +ssize_t IoWrapper::writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error) | |
| 217 | +{ | |
| 218 | + *error = IoWrapResult::Success; | |
| 219 | + ssize_t n = 0; | |
| 220 | + | |
| 221 | + if (!ssl) | |
| 222 | + { | |
| 223 | + // A write on a socket with count=0 is unspecified. | |
| 224 | + assert(nbytes > 0); | |
| 225 | + | |
| 226 | + n = write(fd, buf, nbytes); | |
| 227 | + if (n < 0) | |
| 228 | + { | |
| 229 | + if (errno == EINTR) | |
| 230 | + *error = IoWrapResult::Interrupted; | |
| 231 | + else if (errno == EAGAIN || errno == EWOULDBLOCK) | |
| 232 | + *error = IoWrapResult::Wouldblock; | |
| 233 | + else | |
| 234 | + check<std::runtime_error>(n); | |
| 235 | + } | |
| 236 | + } | |
| 237 | + else | |
| 238 | + { | |
| 239 | + const void *buf_ = buf; | |
| 240 | + size_t nbytes_ = nbytes; | |
| 241 | + | |
| 242 | + /* | |
| 243 | + * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned | |
| 244 | + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments | |
| 245 | + */ | |
| 246 | + if (this->incompleteSslWrite.hasPendingWrite()) | |
| 247 | + { | |
| 248 | + buf_ = this->incompleteSslWrite.buf; | |
| 249 | + nbytes_ = this->incompleteSslWrite.nbytes; | |
| 250 | + } | |
| 251 | + | |
| 252 | + // OpenSSL: "You should not call SSL_write() with num=0, it will return an error" | |
| 253 | + assert(nbytes_ > 0); | |
| 254 | + | |
| 255 | + this->sslWriteWantsRead = false; | |
| 256 | + this->incompleteSslWrite.reset(); | |
| 257 | + | |
| 258 | + ERR_clear_error(); | |
| 259 | + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE]; | |
| 260 | + n = SSL_write(ssl, buf_, nbytes_); | |
| 261 | + | |
| 262 | + if (n <= 0) | |
| 263 | + { | |
| 264 | + int err = SSL_get_error(ssl, n); | |
| 265 | + unsigned long error_code = ERR_get_error(); | |
| 266 | + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) | |
| 267 | + { | |
| 268 | + logger->logf(LOG_DEBUG, "Write is incomplete: %d", err); | |
| 269 | + *error = IoWrapResult::Wouldblock; | |
| 270 | + IncompleteSslWrite sslAction(buf_, nbytes_); | |
| 271 | + this->incompleteSslWrite = sslAction; | |
| 272 | + if (err == SSL_ERROR_WANT_READ) | |
| 273 | + this->sslWriteWantsRead = true; | |
| 274 | + n = 0; | |
| 275 | + } | |
| 276 | + else | |
| 277 | + { | |
| 278 | + if (err == SSL_ERROR_SYSCALL) | |
| 279 | + { | |
| 280 | + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say | |
| 281 | + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it | |
| 282 | + // implies EINTR is not included? | |
| 283 | + if (errno == EINTR) | |
| 284 | + *error = IoWrapResult::Interrupted; | |
| 285 | + else | |
| 286 | + { | |
| 287 | + char *err = strerror(errno); | |
| 288 | + std::string msg(err); | |
| 289 | + throw std::runtime_error(msg); | |
| 290 | + } | |
| 291 | + } | |
| 292 | + ERR_error_string(error_code, sslErrorBuf); | |
| 293 | + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE); | |
| 294 | + ERR_print_errors_cb(logSslError, NULL); | |
| 295 | + throw std::runtime_error("SSL socket error writing: " + errorString); | |
| 296 | + } | |
| 297 | + } | |
| 298 | + } | |
| 299 | + | |
| 300 | + return n; | |
| 301 | +} | |
| 302 | + | |
| 303 | +// Use a small intermediate buffer to write (partial) websocket frames to our normal read buffer. MQTT is already a frames protocol, so we don't | |
| 304 | +// care about websocket frames being incomplete. | |
| 305 | +ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error) | |
| 306 | +{ | |
| 307 | + if (!websocket) | |
| 308 | + { | |
| 309 | + return readOrSslRead(fd, buf, nbytes, error); | |
| 310 | + } | |
| 311 | + else | |
| 312 | + { | |
| 313 | + ssize_t n = 0; | |
| 314 | + while (websocketPendingBytes.freeSpace() > 0 && (n = readOrSslRead(fd, websocketPendingBytes.headPtr(), websocketPendingBytes.maxWriteSize(), error)) != 0) | |
| 315 | + { | |
| 316 | + if (n > 0) | |
| 317 | + websocketPendingBytes.advanceHead(n); | |
| 318 | + if (n < 0) | |
| 319 | + break; // signal/error handling is done by the caller, so we just stop. | |
| 320 | + | |
| 321 | + if (websocketState == WebsocketState::NotUpgraded && websocketPendingBytes.freeSpace() == 0) | |
| 322 | + { | |
| 323 | + if (websocketPendingBytes.getSize() * 2 <= 8192) | |
| 324 | + websocketPendingBytes.doubleSize(); | |
| 325 | + else | |
| 326 | + throw ProtocolError("Trying to exceed websocket buffer. Probably not valid websocket traffic."); | |
| 327 | + } | |
| 328 | + } | |
| 329 | + | |
| 330 | + const bool hasWebsocketPendingBytes = websocketPendingBytes.usedBytes() > 0; | |
| 331 | + | |
| 332 | + // When some or all the data has been read, we can continue. | |
| 333 | + if (!(*error == IoWrapResult::Wouldblock || *error == IoWrapResult::Success) && !hasWebsocketPendingBytes) | |
| 334 | + return n; | |
| 335 | + | |
| 336 | + if (hasWebsocketPendingBytes) | |
| 337 | + { | |
| 338 | + n = 0; | |
| 339 | + | |
| 340 | + if (websocketState == WebsocketState::NotUpgraded) | |
| 341 | + { | |
| 342 | + try | |
| 343 | + { | |
| 344 | + std::string websocketKey; | |
| 345 | + int websocketVersion; | |
| 346 | + if (parseHttpHeader(websocketPendingBytes, websocketKey, websocketVersion)) | |
| 347 | + { | |
| 348 | + if (websocketKey.empty()) | |
| 349 | + throw BadHttpRequest("No websocket key specified."); | |
| 350 | + if (websocketVersion != 13) | |
| 351 | + throw BadWebsocketVersionException("Websocket version 13 required."); | |
| 352 | + | |
| 353 | + const std::string acceptString = generateWebsocketAcceptString(websocketKey); | |
| 354 | + | |
| 355 | + std::string answer = generateWebsocketAnswer(acceptString); | |
| 356 | + parentClient->writeText(answer); | |
| 357 | + websocketState = WebsocketState::Upgrading; | |
| 358 | + websocketPendingBytes.reset(); | |
| 359 | + websocketPendingBytes.resetSize(initialBufferSize); | |
| 360 | + *error = IoWrapResult::Success; | |
| 361 | + } | |
| 362 | + } | |
| 363 | + catch (BadWebsocketVersionException &ex) | |
| 364 | + { | |
| 365 | + std::string response = generateInvalidWebsocketVersionHttpHeaders(13); | |
| 366 | + parentClient->writeText(response); | |
| 367 | + parentClient->setDisconnectReason("Invalid websocket version"); | |
| 368 | + parentClient->setReadyForDisconnect(); | |
| 369 | + } | |
| 370 | + catch (BadHttpRequest &ex) // Should should also properly deal with attempt at HTTP2 with PRI. | |
| 371 | + { | |
| 372 | + std::string response = generateBadHttpRequestReponse(ex.what()); | |
| 373 | + parentClient->writeText(response); | |
| 374 | + parentClient->setDisconnectReason("Invalid websocket start"); | |
| 375 | + parentClient->setReadyForDisconnect(); | |
| 376 | + } | |
| 377 | + } | |
| 378 | + else | |
| 379 | + { | |
| 380 | + n = websocketBytesToReadBuffer(buf, nbytes); | |
| 381 | + | |
| 382 | + if (n > 0) | |
| 383 | + *error = IoWrapResult::Success; | |
| 384 | + else if (n == 0) | |
| 385 | + *error = IoWrapResult::Wouldblock; | |
| 386 | + } | |
| 387 | + } | |
| 388 | + | |
| 389 | + return n; | |
| 390 | + } | |
| 391 | +} | |
| 392 | + | |
| 393 | +ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes) | |
| 394 | +{ | |
| 395 | + const ssize_t targetBufMaxSize = nbytes; | |
| 396 | + ssize_t nbytesRead = 0; | |
| 397 | + | |
| 398 | + while (websocketPendingBytes.usedBytes() >= WEBSOCKET_MIN_HEADER_BYTES_NEEDED && nbytesRead < targetBufMaxSize) | |
| 399 | + { | |
| 400 | + // This block decodes the header. | |
| 401 | + if (!incompleteWebsocketRead.sillWorkingOnFrame()) | |
| 402 | + { | |
| 403 | + const uint8_t byte1 = websocketPendingBytes.peakAhead(0); | |
| 404 | + const uint8_t byte2 = websocketPendingBytes.peakAhead(1); | |
| 405 | + bool masked = !!(byte2 & 0b10000000); | |
| 406 | + uint8_t reserved = (byte1 & 0b01110000) >> 4; | |
| 407 | + WebsocketOpcode opcode = (WebsocketOpcode)(byte1 & 0b00001111); | |
| 408 | + const uint8_t payloadLength = byte2 & 0b01111111; | |
| 409 | + size_t realPayloadLength = payloadLength; | |
| 410 | + uint64_t extendedPayloadLengthLength = 0; | |
| 411 | + uint8_t headerLength = masked ? 6 : 2; | |
| 412 | + | |
| 413 | + if (payloadLength == 126) | |
| 414 | + extendedPayloadLengthLength = 2; | |
| 415 | + else if (payloadLength == 127) | |
| 416 | + extendedPayloadLengthLength = 8; | |
| 417 | + headerLength += extendedPayloadLengthLength; | |
| 418 | + | |
| 419 | + //if (!masked) | |
| 420 | + // throw ProtocolError("Client must send masked websocket bytes."); | |
| 421 | + | |
| 422 | + if (reserved != 0) | |
| 423 | + throw ProtocolError("Reserved bytes in header must be 0."); | |
| 424 | + | |
| 425 | + if (headerLength > websocketPendingBytes.usedBytes()) | |
| 426 | + return nbytesRead; | |
| 427 | + | |
| 428 | + uint64_t extendedPayloadLength = 0; | |
| 429 | + | |
| 430 | + int i = 2; | |
| 431 | + int shift = extendedPayloadLengthLength * 8; | |
| 432 | + while (shift > 0) | |
| 433 | + { | |
| 434 | + shift -= 8; | |
| 435 | + uint8_t byte = websocketPendingBytes.peakAhead(i++); | |
| 436 | + extendedPayloadLength += (byte << shift); | |
| 437 | + } | |
| 438 | + | |
| 439 | + if (extendedPayloadLength > 0) | |
| 440 | + realPayloadLength = extendedPayloadLength; | |
| 441 | + | |
| 442 | + if (headerLength > websocketPendingBytes.usedBytes()) | |
| 443 | + return nbytesRead; | |
| 444 | + | |
| 445 | + if (masked) | |
| 446 | + { | |
| 447 | + for (int j = 0; j < 4; j++) | |
| 448 | + { | |
| 449 | + incompleteWebsocketRead.maskingKey[j] = websocketPendingBytes.peakAhead(i++); | |
| 450 | + } | |
| 451 | + } | |
| 452 | + | |
| 453 | + assert(i == headerLength); | |
| 454 | + assert(headerLength <= websocketPendingBytes.usedBytes()); | |
| 455 | + websocketPendingBytes.advanceTail(headerLength); | |
| 456 | + | |
| 457 | + incompleteWebsocketRead.frame_bytes_left = realPayloadLength; | |
| 458 | + incompleteWebsocketRead.opcode = opcode; | |
| 459 | + } | |
| 460 | + | |
| 461 | + if (incompleteWebsocketRead.opcode == WebsocketOpcode::Binary) | |
| 462 | + { | |
| 463 | + // The following reads one websocket frame max: it will continue with the previous, or start a new one, which it may or may not finish. | |
| 464 | + size_t targetBufI = 0; | |
| 465 | + char *targetBuf = &static_cast<char*>(buf)[nbytesRead]; | |
| 466 | + while(websocketPendingBytes.usedBytes() > 0 && incompleteWebsocketRead.frame_bytes_left > 0 && nbytesRead < targetBufMaxSize) | |
| 467 | + { | |
| 468 | + const size_t asManyBytesOfThisFrameAsPossible = std::min<int>(websocketPendingBytes.maxReadSize(), incompleteWebsocketRead.frame_bytes_left); | |
| 469 | + const size_t maxReadSize = std::min<int>(asManyBytesOfThisFrameAsPossible, targetBufMaxSize - nbytesRead); | |
| 470 | + assert(maxReadSize > 0); | |
| 471 | + assert(static_cast<ssize_t>(maxReadSize) + nbytesRead <= targetBufMaxSize); | |
| 472 | + for (size_t x = 0; x < maxReadSize; x++) | |
| 473 | + { | |
| 474 | + targetBuf[targetBufI++] = websocketPendingBytes.tailPtr()[x] ^ incompleteWebsocketRead.getNextMaskingByte(); | |
| 475 | + } | |
| 476 | + websocketPendingBytes.advanceTail(maxReadSize); | |
| 477 | + incompleteWebsocketRead.frame_bytes_left -= maxReadSize; | |
| 478 | + nbytesRead += maxReadSize; | |
| 479 | + } | |
| 480 | + } | |
| 481 | + else if (incompleteWebsocketRead.opcode == WebsocketOpcode::Ping) | |
| 482 | + { | |
| 483 | + // A ping MAY have user data, which needs to be ponged back; | |
| 484 | + | |
| 485 | + // Constructing a new temporary buffer because I need the reponse in one frame for writeAsMuchOfBufAsWebsocketFrame(). | |
| 486 | + std::vector<char> response(incompleteWebsocketRead.frame_bytes_left); | |
| 487 | + websocketPendingBytes.read(response.data(), response.size()); | |
| 488 | + | |
| 489 | + websocketWriteRemainder.ensureFreeSpace(response.size()); | |
| 490 | + writeAsMuchOfBufAsWebsocketFrame(response.data(), response.size(), WebsocketOpcode::Pong); | |
| 491 | + parentClient->setReadyForWriting(true); | |
| 492 | + } | |
| 493 | + | |
| 494 | + if (!incompleteWebsocketRead.sillWorkingOnFrame()) | |
| 495 | + incompleteWebsocketRead.reset(); | |
| 496 | + } | |
| 497 | + assert(nbytesRead <= static_cast<ssize_t>(nbytes)); | |
| 498 | + | |
| 499 | + return nbytesRead; | |
| 500 | +} | |
| 501 | + | |
| 502 | +/** | |
| 503 | + * @brief IoWrapper::writeAsMuchOfBufAsWebsocketFrame writes buf of part of buf as websocket frame to websocketWriteRemainder | |
| 504 | + * @param buf | |
| 505 | + * @param nbytes. The amount of bytes. Can be 0, for just an empty websocket frame. | |
| 506 | + * @return | |
| 507 | + */ | |
| 508 | +ssize_t IoWrapper::writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode) | |
| 509 | +{ | |
| 510 | + // We do allow pong frames to generate a zero payload packet, but for binary, that's not necessary. | |
| 511 | + if (nbytes == 0 && opcode == WebsocketOpcode::Binary) | |
| 512 | + return 0; | |
| 513 | + | |
| 514 | + ssize_t nBytesReal = 0; | |
| 515 | + | |
| 516 | + // We normally wrap each write in a frame, but if a previous one didn't fit in the system's write buffers, we're still working on it. | |
| 517 | + if (websocketWriteRemainder.freeSpace() > WEBSOCKET_MAX_SENDING_HEADER_SIZE) | |
| 518 | + { | |
| 519 | + uint8_t extended_payload_length_num_bytes = 0; | |
| 520 | + uint8_t payload_length = 0; | |
| 521 | + if (nbytes < 126) | |
| 522 | + payload_length = nbytes; | |
| 523 | + else if (nbytes >= 126 && nbytes <= 0xFFFF) | |
| 524 | + { | |
| 525 | + payload_length = 126; | |
| 526 | + extended_payload_length_num_bytes = 2; | |
| 527 | + } | |
| 528 | + else if (nbytes > 0xFFFF) | |
| 529 | + { | |
| 530 | + payload_length = 127; | |
| 531 | + extended_payload_length_num_bytes = 8; | |
| 532 | + } | |
| 533 | + | |
| 534 | + int x = 0; | |
| 535 | + char header[WEBSOCKET_MAX_SENDING_HEADER_SIZE]; | |
| 536 | + header[x++] = (0b10000000 | static_cast<char>(opcode)); | |
| 537 | + header[x++] = payload_length; | |
| 538 | + | |
| 539 | + const int header_length = x + extended_payload_length_num_bytes; | |
| 540 | + | |
| 541 | + // This block writes the extended payload length. | |
| 542 | + nBytesReal = std::min<int>(nbytes, websocketWriteRemainder.freeSpace() - header_length); | |
| 543 | + const uint64_t nbytes64 = nBytesReal; | |
| 544 | + for (int z = extended_payload_length_num_bytes - 1; z >= 0; z--) | |
| 545 | + { | |
| 546 | + header[x++] = (nbytes64 >> (z*8)) & 0xFF; | |
| 547 | + } | |
| 548 | + assert(x <= WEBSOCKET_MAX_SENDING_HEADER_SIZE); | |
| 549 | + assert(x == header_length); | |
| 550 | + | |
| 551 | + websocketWriteRemainder.write(header, header_length); | |
| 552 | + websocketWriteRemainder.write(buf, nBytesReal); | |
| 553 | + } | |
| 554 | + | |
| 555 | + return nBytesReal; | |
| 556 | +} | |
| 557 | + | |
| 558 | +/* | |
| 559 | + * Mqtt docs: "A single WebSocket data frame can contain multiple or partial MQTT Control Packets. The receiver | |
| 560 | + * MUST NOT assume that MQTT Control Packets are aligned on WebSocket frame boundaries [MQTT-6.0.0-2]." We | |
| 561 | + * make use of that here, and wrap each write in a frame. | |
| 562 | + * | |
| 563 | + * It's can legitimately return a number of bytes written AND error with 'would block'. So, no need to do that | |
| 564 | + * repeating of the write thing that SSL_write() has. | |
| 565 | + */ | |
| 566 | +ssize_t IoWrapper::writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error) | |
| 567 | +{ | |
| 568 | + if (websocketState != WebsocketState::Upgraded) | |
| 569 | + { | |
| 570 | + if (websocket && websocketState == WebsocketState::Upgrading) | |
| 571 | + websocketState = WebsocketState::Upgraded; | |
| 572 | + | |
| 573 | + return writeOrSslWrite(fd, buf, nbytes, error); | |
| 574 | + } | |
| 575 | + else | |
| 576 | + { | |
| 577 | + ssize_t nBytesReal = writeAsMuchOfBufAsWebsocketFrame(buf, nbytes); | |
| 578 | + | |
| 579 | + ssize_t n = 0; | |
| 580 | + while (websocketWriteRemainder.usedBytes() > 0) | |
| 581 | + { | |
| 582 | + n = writeOrSslWrite(fd, websocketWriteRemainder.tailPtr(), websocketWriteRemainder.maxReadSize(), error); | |
| 583 | + | |
| 584 | + if (n > 0) | |
| 585 | + websocketWriteRemainder.advanceTail(n); | |
| 586 | + if (n < 0) | |
| 587 | + break; | |
| 588 | + } | |
| 589 | + | |
| 590 | + if (n > 0) | |
| 591 | + return nBytesReal; | |
| 592 | + | |
| 593 | + return n; | |
| 594 | + } | |
| 595 | +} | ... | ... |
iowrapper.h
0 โ 100644
| 1 | +#ifndef IOWRAPPER_H | |
| 2 | +#define IOWRAPPER_H | |
| 3 | + | |
| 4 | +#include "unistd.h" | |
| 5 | +#include "openssl/ssl.h" | |
| 6 | +#include "openssl/err.h" | |
| 7 | +#include <exception> | |
| 8 | + | |
| 9 | +#include "forward_declarations.h" | |
| 10 | + | |
| 11 | +#include "types.h" | |
| 12 | +#include "utils.h" | |
| 13 | +#include "logger.h" | |
| 14 | +#include "exceptions.h" | |
| 15 | + | |
| 16 | +#define WEBSOCKET_MIN_HEADER_BYTES_NEEDED 2 | |
| 17 | +#define WEBSOCKET_MAX_SENDING_HEADER_SIZE 10 | |
| 18 | + | |
| 19 | +#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256. | |
| 20 | +#define OPENSSL_WRONG_VERSION_NUMBER 336130315 | |
| 21 | + | |
| 22 | +enum class IoWrapResult | |
| 23 | +{ | |
| 24 | + Success = 0, | |
| 25 | + Interrupted = 1, | |
| 26 | + Wouldblock = 2, | |
| 27 | + Disconnected = 3, | |
| 28 | + Error = 4 | |
| 29 | +}; | |
| 30 | + | |
| 31 | +enum class WebsocketOpcode | |
| 32 | +{ | |
| 33 | + Continuation = 0x00, | |
| 34 | + Text = 0x1, | |
| 35 | + Binary = 0x2, | |
| 36 | + Close = 0x8, | |
| 37 | + Ping = 0x9, | |
| 38 | + Pong = 0xA, | |
| 39 | + Unknown = 0xF | |
| 40 | +}; | |
| 41 | + | |
| 42 | +/* | |
| 43 | + * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned | |
| 44 | + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments" | |
| 45 | + */ | |
| 46 | +struct IncompleteSslWrite | |
| 47 | +{ | |
| 48 | + const void *buf = nullptr; | |
| 49 | + size_t nbytes = 0; | |
| 50 | + | |
| 51 | + IncompleteSslWrite() = default; | |
| 52 | + IncompleteSslWrite(const void *buf, size_t nbytes); | |
| 53 | + bool hasPendingWrite() const; | |
| 54 | + | |
| 55 | + void reset(); | |
| 56 | +}; | |
| 57 | + | |
| 58 | +struct IncompleteWebsocketRead | |
| 59 | +{ | |
| 60 | + size_t frame_bytes_left = 0; | |
| 61 | + char maskingKey[4]; | |
| 62 | + int maskingKeyI = 0; | |
| 63 | + WebsocketOpcode opcode; | |
| 64 | + | |
| 65 | + void reset(); | |
| 66 | + bool sillWorkingOnFrame() const; | |
| 67 | + char getNextMaskingByte(); | |
| 68 | +}; | |
| 69 | + | |
| 70 | +enum class WebsocketState | |
| 71 | +{ | |
| 72 | + NotUpgraded, | |
| 73 | + Upgrading, | |
| 74 | + Upgraded | |
| 75 | +}; | |
| 76 | + | |
| 77 | +/** | |
| 78 | + * @brief provides a unified wrapper for SSL and websockets to read() and write(). | |
| 79 | + * | |
| 80 | + * | |
| 81 | + */ | |
| 82 | +class IoWrapper | |
| 83 | +{ | |
| 84 | + Client *parentClient; | |
| 85 | + const size_t initialBufferSize; | |
| 86 | + | |
| 87 | + SSL *ssl = nullptr; | |
| 88 | + bool sslAccepted = false; | |
| 89 | + IncompleteSslWrite incompleteSslWrite; | |
| 90 | + bool sslReadWantsWrite = false; | |
| 91 | + bool sslWriteWantsRead = false; | |
| 92 | + | |
| 93 | + bool websocket; | |
| 94 | + WebsocketState websocketState = WebsocketState::NotUpgraded; | |
| 95 | + CirBuf websocketPendingBytes; | |
| 96 | + IncompleteWebsocketRead incompleteWebsocketRead; | |
| 97 | + CirBuf websocketWriteRemainder; | |
| 98 | + | |
| 99 | + Logger *logger = Logger::getInstance(); | |
| 100 | + | |
| 101 | + ssize_t websocketBytesToReadBuffer(void *buf, const size_t nbytes); | |
| 102 | + ssize_t readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error); | |
| 103 | + ssize_t writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error); | |
| 104 | + ssize_t writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode = WebsocketOpcode::Binary); | |
| 105 | +public: | |
| 106 | + IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent); | |
| 107 | + ~IoWrapper(); | |
| 108 | + | |
| 109 | + void startOrContinueSslAccept(); | |
| 110 | + bool getSslReadWantsWrite() const; | |
| 111 | + bool getSslWriteWantsRead() const; | |
| 112 | + bool isSslAccepted() const; | |
| 113 | + bool isSsl() const; | |
| 114 | + bool hasPendingWrite() const; | |
| 115 | + bool isWebsocket() const; | |
| 116 | + WebsocketState getWebsocketState() const; | |
| 117 | + | |
| 118 | + ssize_t readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error); | |
| 119 | + ssize_t writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error); | |
| 120 | +}; | |
| 121 | + | |
| 122 | +#endif // IOWRAPPER_H | ... | ... |
mainapp.cpp
| ... | ... | @@ -351,6 +351,7 @@ void MainApp::start() |
| 351 | 351 | |
| 352 | 352 | int listen_fd_plain = createListenSocket(this->listenPort, false); |
| 353 | 353 | int listen_fd_ssl = createListenSocket(this->sslListenPort, true); |
| 354 | + int listen_fd_websocket_plain = createListenSocket(1443, true); | |
| 354 | 355 | |
| 355 | 356 | #ifdef NDEBUG |
| 356 | 357 | logger->noLongerLogToStd(); |
| ... | ... | @@ -391,7 +392,7 @@ void MainApp::start() |
| 391 | 392 | int cur_fd = events[i].data.fd; |
| 392 | 393 | try |
| 393 | 394 | { |
| 394 | - if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl) | |
| 395 | + if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl || listen_fd_websocket_plain) | |
| 395 | 396 | { |
| 396 | 397 | std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % num_threads]; |
| 397 | 398 | |
| ... | ... | @@ -402,6 +403,7 @@ void MainApp::start() |
| 402 | 403 | socklen_t len = sizeof(struct sockaddr); |
| 403 | 404 | int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len)); |
| 404 | 405 | |
| 406 | + bool websocket = cur_fd == listen_fd_websocket_plain; | |
| 405 | 407 | SSL *clientSSL = nullptr; |
| 406 | 408 | if (cur_fd == listen_fd_ssl) |
| 407 | 409 | { |
| ... | ... | @@ -417,7 +419,7 @@ void MainApp::start() |
| 417 | 419 | SSL_set_fd(clientSSL, fd); |
| 418 | 420 | } |
| 419 | 421 | |
| 420 | - Client_p client(new Client(fd, thread_data, clientSSL, settings)); | |
| 422 | + Client_p client(new Client(fd, thread_data, clientSSL, websocket, settings)); | |
| 421 | 423 | thread_data->giveClient(client); |
| 422 | 424 | } |
| 423 | 425 | else if (cur_fd == taskEventFd) | ... | ... |
utils.cpp
| ... | ... | @@ -3,8 +3,10 @@ |
| 3 | 3 | #include "sys/time.h" |
| 4 | 4 | #include "sys/random.h" |
| 5 | 5 | #include <algorithm> |
| 6 | +#include <sstream> | |
| 6 | 7 | |
| 7 | 8 | #include "exceptions.h" |
| 9 | +#include "cirbuf.h" | |
| 8 | 10 | |
| 9 | 11 | std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) |
| 10 | 12 | { |
| ... | ... | @@ -226,3 +228,121 @@ bool isPowerOfTwo(int n) |
| 226 | 228 | { |
| 227 | 229 | return (n != 0) && (n & (n - 1)) == 0; |
| 228 | 230 | } |
| 231 | + | |
| 232 | +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version) | |
| 233 | +{ | |
| 234 | + const std::string s(buf.tailPtr(), buf.usedBytes()); | |
| 235 | + std::istringstream is(s); | |
| 236 | + bool doubleEmptyLine = false; // meaning, the HTTP header is complete | |
| 237 | + bool upgradeHeaderSeen = false; | |
| 238 | + bool connectionHeaderSeen = false; | |
| 239 | + bool firstLine = true; | |
| 240 | + | |
| 241 | + std::string line; | |
| 242 | + while (std::getline(is, line)) | |
| 243 | + { | |
| 244 | + trim(line); | |
| 245 | + if (firstLine) | |
| 246 | + { | |
| 247 | + firstLine = false; | |
| 248 | + if (!startsWith(line, "GET")) | |
| 249 | + throw BadHttpRequest("Websocket request should start with GET."); | |
| 250 | + continue; | |
| 251 | + } | |
| 252 | + if (line.empty()) | |
| 253 | + { | |
| 254 | + doubleEmptyLine = true; | |
| 255 | + break; | |
| 256 | + } | |
| 257 | + | |
| 258 | + std::list<std::string> fields = split(line, ':', 1); | |
| 259 | + const std::vector<std::string> fields2(fields.begin(), fields.end()); | |
| 260 | + std::string name = str_tolower(fields2[0]); | |
| 261 | + trim(name); | |
| 262 | + std::string value = fields2[1]; | |
| 263 | + trim(value); | |
| 264 | + std::string value_lower = str_tolower(value); | |
| 265 | + | |
| 266 | + if (name == "upgrade" && value_lower == "websocket") | |
| 267 | + upgradeHeaderSeen = true; | |
| 268 | + else if (name == "connection" && value_lower == "upgrade") | |
| 269 | + connectionHeaderSeen = true; | |
| 270 | + else if (name == "sec-websocket-key") | |
| 271 | + websocket_key = value; | |
| 272 | + else if (name == "sec-websocket-version") | |
| 273 | + websocket_version = stoi(value); | |
| 274 | + } | |
| 275 | + | |
| 276 | + if (doubleEmptyLine) | |
| 277 | + { | |
| 278 | + if (!connectionHeaderSeen || !upgradeHeaderSeen) | |
| 279 | + throw BadHttpRequest("HTTP request is not a websocket upgrade request."); | |
| 280 | + } | |
| 281 | + | |
| 282 | + return doubleEmptyLine; | |
| 283 | +} | |
| 284 | + | |
| 285 | +std::string base64Encode(const unsigned char *input, const int length) | |
| 286 | +{ | |
| 287 | + const int pl = 4*((length+2)/3); | |
| 288 | + char *output = reinterpret_cast<char *>(calloc(pl+1, 1)); | |
| 289 | + const int ol = EVP_EncodeBlock(reinterpret_cast<unsigned char *>(output), input, length); | |
| 290 | + std::string result(output); | |
| 291 | + free(output); | |
| 292 | + | |
| 293 | + if (pl != ol) | |
| 294 | + throw std::runtime_error("Base64 encode error."); | |
| 295 | + | |
| 296 | + return result; | |
| 297 | +} | |
| 298 | + | |
| 299 | +std::string generateWebsocketAcceptString(const std::string &websocketKey) | |
| 300 | +{ | |
| 301 | + unsigned char md_value[EVP_MAX_MD_SIZE]; | |
| 302 | + unsigned int md_len; | |
| 303 | + | |
| 304 | + EVP_MD_CTX *mdctx = EVP_MD_CTX_new();; | |
| 305 | + const EVP_MD *md = EVP_sha1(); | |
| 306 | + EVP_DigestInit_ex(mdctx, md, NULL); | |
| 307 | + | |
| 308 | + const std::string keyPlusMagic = websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; | |
| 309 | + | |
| 310 | + EVP_DigestUpdate(mdctx, keyPlusMagic.c_str(), keyPlusMagic.length()); | |
| 311 | + EVP_DigestFinal_ex(mdctx, md_value, &md_len); | |
| 312 | + EVP_MD_CTX_free(mdctx); | |
| 313 | + | |
| 314 | + std::string base64 = base64Encode(md_value, md_len); | |
| 315 | + return base64; | |
| 316 | +} | |
| 317 | + | |
| 318 | +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion) | |
| 319 | +{ | |
| 320 | + std::ostringstream oss; | |
| 321 | + oss << "HTTP/1.1 400 Bad Request\r\n"; | |
| 322 | + oss << "Sec-WebSocket-Version: " << wantedVersion; | |
| 323 | + oss << "\r\n"; | |
| 324 | + oss.flush(); | |
| 325 | + return oss.str(); | |
| 326 | +} | |
| 327 | + | |
| 328 | +std::string generateBadHttpRequestReponse(const std::string &msg) | |
| 329 | +{ | |
| 330 | + std::ostringstream oss; | |
| 331 | + oss << "HTTP/1.1 400 Bad Request\r\n"; | |
| 332 | + oss << "\r\n"; | |
| 333 | + oss << msg; | |
| 334 | + oss.flush(); | |
| 335 | + return oss.str(); | |
| 336 | +} | |
| 337 | + | |
| 338 | +std::string generateWebsocketAnswer(const std::string &acceptString) | |
| 339 | +{ | |
| 340 | + std::ostringstream oss; | |
| 341 | + oss << "HTTP/1.1 101 Switching Protocols\r\n"; | |
| 342 | + oss << "Upgrade: websocket\r\n"; | |
| 343 | + oss << "Connection: Upgrade\r\n"; | |
| 344 | + oss << "Sec-WebSocket-Accept: " << acceptString << "\r\n"; | |
| 345 | + oss << "\r\n"; | |
| 346 | + oss.flush(); | |
| 347 | + return oss.str(); | |
| 348 | +} | ... | ... |
utils.h
| ... | ... | @@ -8,6 +8,9 @@ |
| 8 | 8 | #include <limits> |
| 9 | 9 | #include <vector> |
| 10 | 10 | #include <algorithm> |
| 11 | +#include <openssl/evp.h> | |
| 12 | + | |
| 13 | +#include "cirbuf.h" | |
| 11 | 14 | |
| 12 | 15 | template<typename T> int check(int rc) |
| 13 | 16 | { |
| ... | ... | @@ -43,4 +46,13 @@ std::string str_tolower(std::string s); |
| 43 | 46 | bool stringTruthiness(const std::string &val); |
| 44 | 47 | bool isPowerOfTwo(int val); |
| 45 | 48 | |
| 49 | +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version); | |
| 50 | + | |
| 51 | +std::string base64Encode(const unsigned char *input, const int length); | |
| 52 | +std::string generateWebsocketAcceptString(const std::string &websocketKey); | |
| 53 | + | |
| 54 | +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion); | |
| 55 | +std::string generateBadHttpRequestReponse(const std::string &msg); | |
| 56 | +std::string generateWebsocketAnswer(const std::string &acceptString); | |
| 57 | + | |
| 46 | 58 | #endif // UTILS_H | ... | ... |