Commit b5ba41f503c602c0edcdeec0db5b2b8344050cb8

Authored by Wiebe Cazemier
1 parent 1106e210

Add IoWrapper, with websocket support added

The ping/pong is actually untested at this point, because Paho (my test
client for now) doesn't do those. I wonder if any do, because MQTT
already has ping/pong.
CMakeLists.txt
@@ -29,6 +29,7 @@ add_executable(FlashMQ @@ -29,6 +29,7 @@ add_executable(FlashMQ
29 sslctxmanager.cpp 29 sslctxmanager.cpp
30 timer.cpp 30 timer.cpp
31 globalsettings.cpp 31 globalsettings.cpp
  32 + iowrapper.cpp
32 ) 33 )
33 34
34 target_link_libraries(FlashMQ pthread dl ssl crypto) 35 target_link_libraries(FlashMQ pthread dl ssl crypto)
cirbuf.cpp
@@ -9,10 +9,14 @@ @@ -9,10 +9,14 @@
9 #include <cstring> 9 #include <cstring>
10 10
11 #include "logger.h" 11 #include "logger.h"
  12 +#include "utils.h"
12 13
13 CirBuf::CirBuf(size_t size) : 14 CirBuf::CirBuf(size_t size) :
14 size(size) 15 size(size)
15 { 16 {
  17 + if (size == 0)
  18 + return;
  19 +
16 buf = (char*)malloc(size); 20 buf = (char*)malloc(size);
17 21
18 if (buf == NULL) 22 if (buf == NULL)
@@ -69,12 +73,14 @@ char *CirBuf::tailPtr() @@ -69,12 +73,14 @@ char *CirBuf::tailPtr()
69 73
70 void CirBuf::advanceHead(uint32_t n) 74 void CirBuf::advanceHead(uint32_t n)
71 { 75 {
  76 + assert(n <= freeSpace());
72 head = (head + n) & (size -1); 77 head = (head + n) & (size -1);
73 assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty. 78 assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty.
74 } 79 }
75 80
76 void CirBuf::advanceTail(uint32_t n) 81 void CirBuf::advanceTail(uint32_t n)
77 { 82 {
  83 + assert(n <= usedBytes());
78 tail = (tail + n) & (size -1); 84 tail = (tail + n) & (size -1);
79 } 85 }
80 86
@@ -84,9 +90,31 @@ char CirBuf::peakAhead(uint32_t offset) const @@ -84,9 +90,31 @@ char CirBuf::peakAhead(uint32_t offset) const
84 return b; 90 return b;
85 } 91 }
86 92
87 -void CirBuf::doubleSize() 93 +void CirBuf::ensureFreeSpace(size_t n)
  94 +{
  95 + const size_t _usedBytes = usedBytes();
  96 +
  97 + int mul = 1;
  98 +
  99 + while((mul * size - _usedBytes - 1) < n)
  100 + {
  101 + mul = mul << 1;
  102 + }
  103 +
  104 + if (mul == 1)
  105 + return;
  106 +
  107 + doubleSize(mul);
  108 +}
  109 +
  110 +void CirBuf::doubleSize(uint factor)
88 { 111 {
89 - uint newSize = size * 2; 112 + if (factor == 1)
  113 + return;
  114 +
  115 + assert(isPowerOfTwo(factor));
  116 +
  117 + uint newSize = size * factor;
90 char *newBuf = (char*)realloc(buf, newSize); 118 char *newBuf = (char*)realloc(buf, newSize);
91 119
92 if (newBuf == NULL) 120 if (newBuf == NULL)
@@ -145,3 +173,55 @@ void CirBuf::resetSize(size_t newSize) @@ -145,3 +173,55 @@ void CirBuf::resetSize(size_t newSize)
145 memset(buf, 0, newSize); 173 memset(buf, 0, newSize);
146 #endif 174 #endif
147 } 175 }
  176 +
  177 +void CirBuf::reset()
  178 +{
  179 + head = 0;
  180 + tail = 0;
  181 +
  182 +#ifndef NDEBUG
  183 + memset(buf, 0, size);
  184 +#endif
  185 +}
  186 +
  187 +// When you know the data you want to write fits in the buffer, use this.
  188 +void CirBuf::write(const void *buf, size_t count)
  189 +{
  190 + assert(count <= freeSpace());
  191 +
  192 + ssize_t len_left = count;
  193 + size_t src_i = 0;
  194 + while (len_left > 0)
  195 + {
  196 + const size_t len = std::min<int>(len_left, maxWriteSize());
  197 + assert(len > 0);
  198 + const char *src = &reinterpret_cast<const char*>(buf)[src_i];
  199 + std::memcpy(headPtr(), src, len);
  200 + advanceHead(len);
  201 + src_i += len;
  202 + len_left -= len;
  203 + }
  204 + assert(len_left == 0);
  205 + assert(src_i == count);
  206 +}
  207 +
  208 +// When you know 'count' bytes are present and you want to read them into buf
  209 +void CirBuf::read(void *buf, const size_t count)
  210 +{
  211 + assert(count <= usedBytes());
  212 +
  213 + char *_buf = static_cast<char*>(buf);
  214 + int i = 0;
  215 + ssize_t _packet_len = count;
  216 + while (_packet_len > 0)
  217 + {
  218 + const int readlen = std::min<int>(maxReadSize(), _packet_len);
  219 + assert(readlen > 0);
  220 + std::memcpy(&_buf[i], tailPtr(), readlen);
  221 + advanceTail(readlen);
  222 + i += readlen;
  223 + _packet_len -= readlen;
  224 + }
  225 + assert(_packet_len == 0);
  226 + assert(i == _packet_len);
  227 +}
cirbuf.h
@@ -32,11 +32,16 @@ public: @@ -32,11 +32,16 @@ public:
32 void advanceHead(uint32_t n); 32 void advanceHead(uint32_t n);
33 void advanceTail(uint32_t n); 33 void advanceTail(uint32_t n);
34 char peakAhead(uint32_t offset) const; 34 char peakAhead(uint32_t offset) const;
35 - void doubleSize(); 35 + void ensureFreeSpace(size_t n);
  36 + void doubleSize(uint factor = 2);
36 uint32_t getSize() const; 37 uint32_t getSize() const;
37 38
38 time_t bufferLastResizedSecondsAgo() const; 39 time_t bufferLastResizedSecondsAgo() const;
39 void resetSize(size_t size); 40 void resetSize(size_t size);
  41 + void reset();
  42 +
  43 + void write(const void *buf, size_t count);
  44 + void read(void *buf, const size_t count);
40 }; 45 };
41 46
42 #endif // CIRBUF_H 47 #endif // CIRBUF_H
client.cpp
@@ -7,11 +7,11 @@ @@ -7,11 +7,11 @@
7 7
8 #include "logger.h" 8 #include "logger.h"
9 9
10 -Client::Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings) : 10 +Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings) :
11 fd(fd), 11 fd(fd),
12 - ssl(ssl),  
13 initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy 12 initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy
14 maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment. 13 maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment.
  14 + ioWrapper(ssl, websocket, initialBufferSize, this),
15 readbuf(initialBufferSize), 15 readbuf(initialBufferSize),
16 writebuf(initialBufferSize), 16 writebuf(initialBufferSize),
17 threadData(threadData) 17 threadData(threadData)
@@ -40,63 +40,32 @@ Client::~Client() @@ -40,63 +40,32 @@ Client::~Client()
40 logger->logf(LOG_NOTICE, "Removing client '%s'. Reason(s): %s", repr().c_str(), disconnectReason.c_str()); 40 logger->logf(LOG_NOTICE, "Removing client '%s'. Reason(s): %s", repr().c_str(), disconnectReason.c_str());
41 if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0) 41 if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0)
42 logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno)); 42 logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno));
43 - if (ssl)  
44 - {  
45 - // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done  
46 - // in the destructor.  
47 - SSL_free(ssl);  
48 - }  
49 close(fd); 43 close(fd);
50 } 44 }
51 45
52 bool Client::isSslAccepted() const 46 bool Client::isSslAccepted() const
53 { 47 {
54 - return sslAccepted; 48 + return ioWrapper.isSslAccepted();
55 } 49 }
56 50
57 bool Client::isSsl() const 51 bool Client::isSsl() const
58 { 52 {
59 - return this->ssl != nullptr; 53 + return ioWrapper.isSsl();
60 } 54 }
61 55
62 bool Client::getSslReadWantsWrite() const 56 bool Client::getSslReadWantsWrite() const
63 { 57 {
64 - return this->sslReadWantsWrite; 58 + return ioWrapper.getSslReadWantsWrite();
65 } 59 }
66 60
67 bool Client::getSslWriteWantsRead() const 61 bool Client::getSslWriteWantsRead() const
68 { 62 {
69 - return this->sslWriteWantsRead; 63 + return ioWrapper.getSslWriteWantsRead();
70 } 64 }
71 65
72 void Client::startOrContinueSslAccept() 66 void Client::startOrContinueSslAccept()
73 { 67 {
74 - ERR_clear_error();  
75 - int accepted = SSL_accept(ssl);  
76 - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];  
77 - if (accepted <= 0)  
78 - {  
79 - int err = SSL_get_error(ssl, accepted);  
80 -  
81 - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)  
82 - {  
83 - setReadyForWriting(err == SSL_ERROR_WANT_WRITE);  
84 - return;  
85 - }  
86 -  
87 - unsigned long error_code = ERR_get_error();  
88 -  
89 - ERR_error_string(error_code, sslErrorBuf);  
90 - std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);  
91 -  
92 - if (error_code == OPENSSL_WRONG_VERSION_NUMBER)  
93 - errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket.";  
94 -  
95 - //ERR_print_errors_cb(logSslError, NULL);  
96 - throw std::runtime_error("Problem accepting SSL socket: " + errorMsg);  
97 - }  
98 - setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake  
99 - sslAccepted = true; 68 + ioWrapper.startOrContinueSslAccept();
100 } 69 }
101 70
102 // Causes future activity on the client to cause a disconnect. 71 // Causes future activity on the client to cause a disconnect.
@@ -108,83 +77,6 @@ void Client::markAsDisconnecting() @@ -108,83 +77,6 @@ void Client::markAsDisconnecting()
108 disconnecting = true; 77 disconnecting = true;
109 } 78 }
110 79
111 -// SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL  
112 -// socket. This wrapper unifies behavor for the caller.  
113 -ssize_t Client::readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error)  
114 -{  
115 - *error = IoWrapResult::Success;  
116 - ssize_t n = 0;  
117 - if (!ssl)  
118 - {  
119 - n = read(fd, buf, nbytes);  
120 - if (n < 0)  
121 - {  
122 - if (errno == EINTR)  
123 - *error = IoWrapResult::Interrupted;  
124 - else if (errno == EAGAIN || errno == EWOULDBLOCK)  
125 - *error = IoWrapResult::Wouldblock;  
126 - else  
127 - check<std::runtime_error>(n);  
128 - }  
129 - else if (n == 0)  
130 - {  
131 - *error = IoWrapResult::Disconnected;  
132 - }  
133 - }  
134 - else  
135 - {  
136 - this->sslReadWantsWrite = false;  
137 - ERR_clear_error();  
138 - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];  
139 - n = SSL_read(ssl, buf, nbytes);  
140 -  
141 - if (n <= 0)  
142 - {  
143 - int err = SSL_get_error(ssl, n);  
144 - unsigned long error_code = ERR_get_error();  
145 -  
146 - // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL.  
147 - if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0))  
148 - {  
149 - *error = IoWrapResult::Disconnected;  
150 - }  
151 - else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)  
152 - {  
153 - *error = IoWrapResult::Wouldblock;  
154 - if (err == SSL_ERROR_WANT_WRITE)  
155 - {  
156 - sslReadWantsWrite = true;  
157 - setReadyForWriting(true);  
158 - }  
159 - n = -1;  
160 - }  
161 - else  
162 - {  
163 - if (err == SSL_ERROR_SYSCALL)  
164 - {  
165 - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say  
166 - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it  
167 - // implies EINTR is not included?  
168 - if (errno == EINTR)  
169 - *error = IoWrapResult::Interrupted;  
170 - else  
171 - {  
172 - char *err = strerror(errno);  
173 - std::string msg(err);  
174 - throw std::runtime_error("SSL read error: " + msg);  
175 - }  
176 - }  
177 - ERR_error_string(error_code, sslErrorBuf);  
178 - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);  
179 - ERR_print_errors_cb(logSslError, NULL);  
180 - throw std::runtime_error("SSL socket error reading: " + errorString);  
181 - }  
182 - }  
183 - }  
184 -  
185 - return n;  
186 -}  
187 -  
188 // false means any kind of error we want to get rid of the client for. 80 // false means any kind of error we want to get rid of the client for.
189 bool Client::readFdIntoBuffer() 81 bool Client::readFdIntoBuffer()
190 { 82 {
@@ -193,7 +85,7 @@ bool Client::readFdIntoBuffer() @@ -193,7 +85,7 @@ bool Client::readFdIntoBuffer()
193 85
194 IoWrapResult error = IoWrapResult::Success; 86 IoWrapResult error = IoWrapResult::Success;
195 int n = 0; 87 int n = 0;
196 - while (readbuf.freeSpace() > 0 && (n = readWrap(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0) 88 + while (readbuf.freeSpace() > 0 && (n = ioWrapper.readWebsocketAndOrSsl(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0)
197 { 89 {
198 if (n > 0) 90 if (n > 0)
199 { 91 {
@@ -232,6 +124,20 @@ bool Client::readFdIntoBuffer() @@ -232,6 +124,20 @@ bool Client::readFdIntoBuffer()
232 return true; 124 return true;
233 } 125 }
234 126
  127 +void Client::writeText(const std::string &text)
  128 +{
  129 + assert(ioWrapper.isWebsocket());
  130 + assert(ioWrapper.getWebsocketState() == WebsocketState::NotUpgraded);
  131 +
  132 + // Not necessary, because at this point, no other threads write to this client, but including for clarity.
  133 + std::lock_guard<std::mutex> locker(writeBufMutex);
  134 +
  135 + writebuf.ensureFreeSpace(text.size());
  136 + writebuf.write(text.c_str(), text.length());
  137 +
  138 + setReadyForWriting(true);
  139 +}
  140 +
235 void Client::writeMqttPacket(const MqttPacket &packet) 141 void Client::writeMqttPacket(const MqttPacket &packet)
236 { 142 {
237 std::lock_guard<std::mutex> locker(writeBufMutex); 143 std::lock_guard<std::mutex> locker(writeBufMutex);
@@ -322,94 +228,6 @@ void Client::writePingResp() @@ -322,94 +228,6 @@ void Client::writePingResp()
322 setReadyForWriting(true); 228 setReadyForWriting(true);
323 } 229 }
324 230
325 -// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller.  
326 -ssize_t Client::writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error)  
327 -{  
328 - *error = IoWrapResult::Success;  
329 - ssize_t n = 0;  
330 -  
331 - if (!ssl)  
332 - {  
333 - // A write on a socket with count=0 is unspecified.  
334 - assert(nbytes > 0);  
335 -  
336 - n = write(fd, buf, nbytes);  
337 - if (n < 0)  
338 - {  
339 - if (errno == EINTR)  
340 - *error = IoWrapResult::Interrupted;  
341 - else if (errno == EAGAIN || errno == EWOULDBLOCK)  
342 - *error = IoWrapResult::Wouldblock;  
343 - else  
344 - check<std::runtime_error>(n);  
345 - }  
346 - }  
347 - else  
348 - {  
349 - const void *buf_ = buf;  
350 - size_t nbytes_ = nbytes;  
351 -  
352 - /*  
353 - * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned  
354 - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments  
355 - */  
356 - if (this->incompleteSslWrite.hasPendingWrite())  
357 - {  
358 - buf_ = this->incompleteSslWrite.buf;  
359 - nbytes_ = this->incompleteSslWrite.nbytes;  
360 - }  
361 -  
362 - // OpenSSL: "You should not call SSL_write() with num=0, it will return an error"  
363 - assert(nbytes_ > 0);  
364 -  
365 - this->sslWriteWantsRead = false;  
366 - this->incompleteSslWrite.reset();  
367 -  
368 - ERR_clear_error();  
369 - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];  
370 - n = SSL_write(ssl, buf_, nbytes_);  
371 -  
372 - if (n <= 0)  
373 - {  
374 - int err = SSL_get_error(ssl, n);  
375 - unsigned long error_code = ERR_get_error();  
376 - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)  
377 - {  
378 - logger->logf(LOG_DEBUG, "Write is incomplete: %d", err);  
379 - *error = IoWrapResult::Wouldblock;  
380 - IncompleteSslWrite sslAction(buf_, nbytes_);  
381 - this->incompleteSslWrite = sslAction;  
382 - if (err == SSL_ERROR_WANT_READ)  
383 - this->sslWriteWantsRead = true;  
384 - n = 0;  
385 - }  
386 - else  
387 - {  
388 - if (err == SSL_ERROR_SYSCALL)  
389 - {  
390 - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say  
391 - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it  
392 - // implies EINTR is not included?  
393 - if (errno == EINTR)  
394 - *error = IoWrapResult::Interrupted;  
395 - else  
396 - {  
397 - char *err = strerror(errno);  
398 - std::string msg(err);  
399 - throw std::runtime_error(msg);  
400 - }  
401 - }  
402 - ERR_error_string(error_code, sslErrorBuf);  
403 - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);  
404 - ERR_print_errors_cb(logSslError, NULL);  
405 - throw std::runtime_error("SSL socket error writing: " + errorString);  
406 - }  
407 - }  
408 - }  
409 -  
410 - return n;  
411 -}  
412 -  
413 bool Client::writeBufIntoFd() 231 bool Client::writeBufIntoFd()
414 { 232 {
415 std::unique_lock<std::mutex> lock(writeBufMutex, std::try_to_lock); 233 std::unique_lock<std::mutex> lock(writeBufMutex, std::try_to_lock);
@@ -422,9 +240,9 @@ bool Client::writeBufIntoFd() @@ -422,9 +240,9 @@ bool Client::writeBufIntoFd()
422 240
423 IoWrapResult error = IoWrapResult::Success; 241 IoWrapResult error = IoWrapResult::Success;
424 int n; 242 int n;
425 - while (writebuf.usedBytes() > 0 || incompleteSslWrite.hasPendingWrite()) 243 + while (writebuf.usedBytes() > 0 || ioWrapper.hasPendingWrite())
426 { 244 {
427 - n = writeWrap(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error); 245 + n = ioWrapper.writeWebsocketAndOrSsl(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error);
428 246
429 if (n > 0) 247 if (n > 0)
430 writebuf.advanceTail(n); 248 writebuf.advanceTail(n);
@@ -484,7 +302,7 @@ void Client::setReadyForWriting(bool val) @@ -484,7 +302,7 @@ void Client::setReadyForWriting(bool val)
484 if (disconnecting) 302 if (disconnecting)
485 return; 303 return;
486 304
487 - if (sslReadWantsWrite) 305 + if (ioWrapper.getSslReadWantsWrite())
488 val = true; 306 val = true;
489 307
490 if (val == this->readyForWriting) 308 if (val == this->readyForWriting)
@@ -622,20 +440,3 @@ void Client::clearWill() @@ -622,20 +440,3 @@ void Client::clearWill()
622 will_qos = 0; 440 will_qos = 0;
623 } 441 }
624 442
625 -IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) :  
626 - buf(buf),  
627 - nbytes(nbytes)  
628 -{  
629 -  
630 -}  
631 -  
632 -void IncompleteSslWrite::reset()  
633 -{  
634 - buf = nullptr;  
635 - nbytes = 0;  
636 -}  
637 -  
638 -bool IncompleteSslWrite::hasPendingWrite()  
639 -{  
640 - return buf != nullptr;  
641 -}  
client.h
@@ -9,6 +9,7 @@ @@ -9,6 +9,7 @@
9 #include <time.h> 9 #include <time.h>
10 10
11 #include <openssl/ssl.h> 11 #include <openssl/ssl.h>
  12 +#include <openssl/err.h>
12 13
13 #include "forward_declarations.h" 14 #include "forward_declarations.h"
14 15
@@ -17,54 +18,25 @@ @@ -17,54 +18,25 @@
17 #include "exceptions.h" 18 #include "exceptions.h"
18 #include "cirbuf.h" 19 #include "cirbuf.h"
19 #include "types.h" 20 #include "types.h"
20 -  
21 -#include <openssl/ssl.h>  
22 -#include <openssl/err.h> 21 +#include "iowrapper.h"
23 22
24 #define MQTT_HEADER_LENGH 2 23 #define MQTT_HEADER_LENGH 2
25 24
26 -#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256.  
27 -#define OPENSSL_WRONG_VERSION_NUMBER 336130315  
28 -  
29 -enum class IoWrapResult  
30 -{  
31 - Success = 0,  
32 - Interrupted = 1,  
33 - Wouldblock = 2,  
34 - Disconnected = 3,  
35 - Error = 4  
36 -};  
37 -  
38 -/*  
39 - * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned  
40 - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments"  
41 - */  
42 -struct IncompleteSslWrite  
43 -{  
44 - const void *buf = nullptr;  
45 - size_t nbytes = 0;  
46 -  
47 - IncompleteSslWrite() = default;  
48 - IncompleteSslWrite(const void *buf, size_t nbytes);  
49 - bool hasPendingWrite();  
50 -  
51 - void reset();  
52 -};  
53 25
54 // TODO: give accepted addr, for showing in logs 26 // TODO: give accepted addr, for showing in logs
55 class Client 27 class Client
56 { 28 {
  29 + friend class IoWrapper;
  30 +
57 int fd; 31 int fd;
58 - SSL *ssl = nullptr;  
59 - bool sslAccepted = false;  
60 - IncompleteSslWrite incompleteSslWrite;  
61 - bool sslReadWantsWrite = false;  
62 - bool sslWriteWantsRead = false; 32 +
63 ProtocolVersion protocolVersion = ProtocolVersion::None; 33 ProtocolVersion protocolVersion = ProtocolVersion::None;
64 34
65 const size_t initialBufferSize = 0; 35 const size_t initialBufferSize = 0;
66 const size_t maxPacketSize = 0; 36 const size_t maxPacketSize = 0;
67 37
  38 + IoWrapper ioWrapper;
  39 +
68 CirBuf readbuf; 40 CirBuf readbuf;
69 uint8_t readBufIsZeroCount = 0; 41 uint8_t readBufIsZeroCount = 0;
70 42
@@ -97,12 +69,11 @@ class Client @@ -97,12 +69,11 @@ class Client
97 69
98 Logger *logger = Logger::getInstance(); 70 Logger *logger = Logger::getInstance();
99 71
100 -  
101 void setReadyForWriting(bool val); 72 void setReadyForWriting(bool val);
102 void setReadyForReading(bool val); 73 void setReadyForReading(bool val);
103 74
104 public: 75 public:
105 - Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings); 76 + Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings);
106 Client(const Client &other) = delete; 77 Client(const Client &other) = delete;
107 Client(Client &&other) = delete; 78 Client(Client &&other) = delete;
108 ~Client(); 79 ~Client();
@@ -115,7 +86,6 @@ public: @@ -115,7 +86,6 @@ public:
115 86
116 void startOrContinueSslAccept(); 87 void startOrContinueSslAccept();
117 void markAsDisconnecting(); 88 void markAsDisconnecting();
118 - ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error);  
119 bool readFdIntoBuffer(); 89 bool readFdIntoBuffer();
120 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender); 90 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender);
121 void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); 91 void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
@@ -131,10 +101,10 @@ public: @@ -131,10 +101,10 @@ public:
131 std::shared_ptr<Session> getSession(); 101 std::shared_ptr<Session> getSession();
132 void setDisconnectReason(const std::string &reason); 102 void setDisconnectReason(const std::string &reason);
133 103
  104 + void writeText(const std::string &text);
134 void writePingResp(); 105 void writePingResp();
135 void writeMqttPacket(const MqttPacket &packet); 106 void writeMqttPacket(const MqttPacket &packet);
136 void writeMqttPacketAndBlameThisClient(const MqttPacket &packet); 107 void writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
137 - ssize_t writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error);  
138 bool writeBufIntoFd(); 108 bool writeBufIntoFd();
139 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } 109 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; }
140 110
exceptions.h
@@ -34,4 +34,16 @@ public: @@ -34,4 +34,16 @@ public:
34 AuthPluginException(const std::string &msg) : std::runtime_error(msg) {} 34 AuthPluginException(const std::string &msg) : std::runtime_error(msg) {}
35 }; 35 };
36 36
  37 +class BadWebsocketVersionException : public std::runtime_error
  38 +{
  39 +public:
  40 + BadWebsocketVersionException(const std::string &msg) : std::runtime_error(msg) {}
  41 +};
  42 +
  43 +class BadHttpRequest : public std::runtime_error
  44 +{
  45 +public:
  46 + BadHttpRequest(const std::string &msg) : std::runtime_error(msg) {}
  47 +};
  48 +
37 #endif // EXCEPTIONS_H 49 #endif // EXCEPTIONS_H
iowrapper.cpp 0 โ†’ 100644
  1 +#include "iowrapper.h"
  2 +
  3 +#include "cassert"
  4 +
  5 +#include "logger.h"
  6 +#include "client.h"
  7 +
  8 +IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) :
  9 + buf(buf),
  10 + nbytes(nbytes)
  11 +{
  12 +
  13 +}
  14 +
  15 +void IncompleteSslWrite::reset()
  16 +{
  17 + buf = nullptr;
  18 + nbytes = 0;
  19 +}
  20 +
  21 +bool IncompleteSslWrite::hasPendingWrite() const
  22 +{
  23 + return buf != nullptr;
  24 +}
  25 +
  26 +void IncompleteWebsocketRead::reset()
  27 +{
  28 + maskingKeyI = 0;
  29 + memset(maskingKey,0, 4);
  30 + frame_bytes_left = 0;
  31 + opcode = WebsocketOpcode::Unknown;
  32 +}
  33 +
  34 +bool IncompleteWebsocketRead::sillWorkingOnFrame() const
  35 +{
  36 + return frame_bytes_left > 0;
  37 +}
  38 +
  39 +char IncompleteWebsocketRead::getNextMaskingByte()
  40 +{
  41 + return maskingKey[maskingKeyI++ % 4];
  42 +}
  43 +
  44 +IoWrapper::IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent) :
  45 + parentClient(parent),
  46 + initialBufferSize(initialBufferSize),
  47 + ssl(ssl),
  48 + websocket(websocket),
  49 + websocketPendingBytes(websocket ? initialBufferSize : 0),
  50 + websocketWriteRemainder(websocket ? initialBufferSize : 0)
  51 +{
  52 +
  53 +}
  54 +
  55 +IoWrapper::~IoWrapper()
  56 +{
  57 + if (ssl)
  58 + {
  59 + // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done
  60 + // in the destructor.
  61 + SSL_free(ssl);
  62 + }
  63 +}
  64 +
  65 +void IoWrapper::startOrContinueSslAccept()
  66 +{
  67 + ERR_clear_error();
  68 + int accepted = SSL_accept(ssl);
  69 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  70 + if (accepted <= 0)
  71 + {
  72 + int err = SSL_get_error(ssl, accepted);
  73 +
  74 + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  75 + {
  76 + parentClient->setReadyForWriting(err == SSL_ERROR_WANT_WRITE);
  77 + return;
  78 + }
  79 +
  80 + unsigned long error_code = ERR_get_error();
  81 +
  82 + ERR_error_string(error_code, sslErrorBuf);
  83 + std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  84 +
  85 + if (error_code == OPENSSL_WRONG_VERSION_NUMBER)
  86 + errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket.";
  87 +
  88 + //ERR_print_errors_cb(logSslError, NULL);
  89 + throw std::runtime_error("Problem accepting SSL socket: " + errorMsg);
  90 + }
  91 + parentClient->setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake
  92 + sslAccepted = true;
  93 +}
  94 +
  95 +bool IoWrapper::getSslReadWantsWrite() const
  96 +{
  97 + return this->sslReadWantsWrite;
  98 +}
  99 +
  100 +bool IoWrapper::getSslWriteWantsRead() const
  101 +{
  102 + return sslWriteWantsRead;
  103 +}
  104 +
  105 +bool IoWrapper::isSslAccepted() const
  106 +{
  107 + return this->sslAccepted;
  108 +}
  109 +
  110 +bool IoWrapper::isSsl() const
  111 +{
  112 + return this->ssl != nullptr;
  113 +}
  114 +
  115 +bool IoWrapper::hasPendingWrite() const
  116 +{
  117 + return incompleteSslWrite.hasPendingWrite() || websocketWriteRemainder.usedBytes() > 0;
  118 +}
  119 +
  120 +bool IoWrapper::isWebsocket() const
  121 +{
  122 + return websocket;
  123 +}
  124 +
  125 +WebsocketState IoWrapper::getWebsocketState() const
  126 +{
  127 + return websocketState;
  128 +}
  129 +
  130 +/**
  131 + * @brief SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL
  132 + * socket. This wrapper unifies behavor for the caller.
  133 + *
  134 + * @param fd
  135 + * @param buf
  136 + * @param nbytes
  137 + * @param error is an out-argument with the result.
  138 + * @return
  139 + */
  140 +ssize_t IoWrapper::readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error)
  141 +{
  142 + *error = IoWrapResult::Success;
  143 + ssize_t n = 0;
  144 + if (!ssl)
  145 + {
  146 + n = read(fd, buf, nbytes);
  147 + if (n < 0)
  148 + {
  149 + if (errno == EINTR)
  150 + *error = IoWrapResult::Interrupted;
  151 + else if (errno == EAGAIN || errno == EWOULDBLOCK)
  152 + *error = IoWrapResult::Wouldblock;
  153 + else
  154 + check<std::runtime_error>(n);
  155 + }
  156 + else if (n == 0)
  157 + {
  158 + *error = IoWrapResult::Disconnected;
  159 + }
  160 + }
  161 + else
  162 + {
  163 + this->sslReadWantsWrite = false;
  164 + ERR_clear_error();
  165 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  166 + n = SSL_read(ssl, buf, nbytes);
  167 +
  168 + if (n <= 0)
  169 + {
  170 + int err = SSL_get_error(ssl, n);
  171 + unsigned long error_code = ERR_get_error();
  172 +
  173 + // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL.
  174 + if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0))
  175 + {
  176 + *error = IoWrapResult::Disconnected;
  177 + }
  178 + else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  179 + {
  180 + *error = IoWrapResult::Wouldblock;
  181 + if (err == SSL_ERROR_WANT_WRITE)
  182 + {
  183 + sslReadWantsWrite = true;
  184 + parentClient->setReadyForWriting(true);
  185 + }
  186 + n = -1;
  187 + }
  188 + else
  189 + {
  190 + if (err == SSL_ERROR_SYSCALL)
  191 + {
  192 + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
  193 + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
  194 + // implies EINTR is not included?
  195 + if (errno == EINTR)
  196 + *error = IoWrapResult::Interrupted;
  197 + else
  198 + {
  199 + char *err = strerror(errno);
  200 + std::string msg(err);
  201 + throw std::runtime_error("SSL read error: " + msg);
  202 + }
  203 + }
  204 + ERR_error_string(error_code, sslErrorBuf);
  205 + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  206 + ERR_print_errors_cb(logSslError, NULL);
  207 + throw std::runtime_error("SSL socket error reading: " + errorString);
  208 + }
  209 + }
  210 + }
  211 +
  212 + return n;
  213 +}
  214 +
  215 +// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller.
  216 +ssize_t IoWrapper::writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error)
  217 +{
  218 + *error = IoWrapResult::Success;
  219 + ssize_t n = 0;
  220 +
  221 + if (!ssl)
  222 + {
  223 + // A write on a socket with count=0 is unspecified.
  224 + assert(nbytes > 0);
  225 +
  226 + n = write(fd, buf, nbytes);
  227 + if (n < 0)
  228 + {
  229 + if (errno == EINTR)
  230 + *error = IoWrapResult::Interrupted;
  231 + else if (errno == EAGAIN || errno == EWOULDBLOCK)
  232 + *error = IoWrapResult::Wouldblock;
  233 + else
  234 + check<std::runtime_error>(n);
  235 + }
  236 + }
  237 + else
  238 + {
  239 + const void *buf_ = buf;
  240 + size_t nbytes_ = nbytes;
  241 +
  242 + /*
  243 + * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned
  244 + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments
  245 + */
  246 + if (this->incompleteSslWrite.hasPendingWrite())
  247 + {
  248 + buf_ = this->incompleteSslWrite.buf;
  249 + nbytes_ = this->incompleteSslWrite.nbytes;
  250 + }
  251 +
  252 + // OpenSSL: "You should not call SSL_write() with num=0, it will return an error"
  253 + assert(nbytes_ > 0);
  254 +
  255 + this->sslWriteWantsRead = false;
  256 + this->incompleteSslWrite.reset();
  257 +
  258 + ERR_clear_error();
  259 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  260 + n = SSL_write(ssl, buf_, nbytes_);
  261 +
  262 + if (n <= 0)
  263 + {
  264 + int err = SSL_get_error(ssl, n);
  265 + unsigned long error_code = ERR_get_error();
  266 + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  267 + {
  268 + logger->logf(LOG_DEBUG, "Write is incomplete: %d", err);
  269 + *error = IoWrapResult::Wouldblock;
  270 + IncompleteSslWrite sslAction(buf_, nbytes_);
  271 + this->incompleteSslWrite = sslAction;
  272 + if (err == SSL_ERROR_WANT_READ)
  273 + this->sslWriteWantsRead = true;
  274 + n = 0;
  275 + }
  276 + else
  277 + {
  278 + if (err == SSL_ERROR_SYSCALL)
  279 + {
  280 + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
  281 + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
  282 + // implies EINTR is not included?
  283 + if (errno == EINTR)
  284 + *error = IoWrapResult::Interrupted;
  285 + else
  286 + {
  287 + char *err = strerror(errno);
  288 + std::string msg(err);
  289 + throw std::runtime_error(msg);
  290 + }
  291 + }
  292 + ERR_error_string(error_code, sslErrorBuf);
  293 + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  294 + ERR_print_errors_cb(logSslError, NULL);
  295 + throw std::runtime_error("SSL socket error writing: " + errorString);
  296 + }
  297 + }
  298 + }
  299 +
  300 + return n;
  301 +}
  302 +
  303 +// Use a small intermediate buffer to write (partial) websocket frames to our normal read buffer. MQTT is already a frames protocol, so we don't
  304 +// care about websocket frames being incomplete.
  305 +ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error)
  306 +{
  307 + if (!websocket)
  308 + {
  309 + return readOrSslRead(fd, buf, nbytes, error);
  310 + }
  311 + else
  312 + {
  313 + ssize_t n = 0;
  314 + while (websocketPendingBytes.freeSpace() > 0 && (n = readOrSslRead(fd, websocketPendingBytes.headPtr(), websocketPendingBytes.maxWriteSize(), error)) != 0)
  315 + {
  316 + if (n > 0)
  317 + websocketPendingBytes.advanceHead(n);
  318 + if (n < 0)
  319 + break; // signal/error handling is done by the caller, so we just stop.
  320 +
  321 + if (websocketState == WebsocketState::NotUpgraded && websocketPendingBytes.freeSpace() == 0)
  322 + {
  323 + if (websocketPendingBytes.getSize() * 2 <= 8192)
  324 + websocketPendingBytes.doubleSize();
  325 + else
  326 + throw ProtocolError("Trying to exceed websocket buffer. Probably not valid websocket traffic.");
  327 + }
  328 + }
  329 +
  330 + const bool hasWebsocketPendingBytes = websocketPendingBytes.usedBytes() > 0;
  331 +
  332 + // When some or all the data has been read, we can continue.
  333 + if (!(*error == IoWrapResult::Wouldblock || *error == IoWrapResult::Success) && !hasWebsocketPendingBytes)
  334 + return n;
  335 +
  336 + if (hasWebsocketPendingBytes)
  337 + {
  338 + n = 0;
  339 +
  340 + if (websocketState == WebsocketState::NotUpgraded)
  341 + {
  342 + try
  343 + {
  344 + std::string websocketKey;
  345 + int websocketVersion;
  346 + if (parseHttpHeader(websocketPendingBytes, websocketKey, websocketVersion))
  347 + {
  348 + if (websocketKey.empty())
  349 + throw BadHttpRequest("No websocket key specified.");
  350 + if (websocketVersion != 13)
  351 + throw BadWebsocketVersionException("Websocket version 13 required.");
  352 +
  353 + const std::string acceptString = generateWebsocketAcceptString(websocketKey);
  354 +
  355 + std::string answer = generateWebsocketAnswer(acceptString);
  356 + parentClient->writeText(answer);
  357 + websocketState = WebsocketState::Upgrading;
  358 + websocketPendingBytes.reset();
  359 + websocketPendingBytes.resetSize(initialBufferSize);
  360 + *error = IoWrapResult::Success;
  361 + }
  362 + }
  363 + catch (BadWebsocketVersionException &ex)
  364 + {
  365 + std::string response = generateInvalidWebsocketVersionHttpHeaders(13);
  366 + parentClient->writeText(response);
  367 + parentClient->setDisconnectReason("Invalid websocket version");
  368 + parentClient->setReadyForDisconnect();
  369 + }
  370 + catch (BadHttpRequest &ex) // Should should also properly deal with attempt at HTTP2 with PRI.
  371 + {
  372 + std::string response = generateBadHttpRequestReponse(ex.what());
  373 + parentClient->writeText(response);
  374 + parentClient->setDisconnectReason("Invalid websocket start");
  375 + parentClient->setReadyForDisconnect();
  376 + }
  377 + }
  378 + else
  379 + {
  380 + n = websocketBytesToReadBuffer(buf, nbytes);
  381 +
  382 + if (n > 0)
  383 + *error = IoWrapResult::Success;
  384 + else if (n == 0)
  385 + *error = IoWrapResult::Wouldblock;
  386 + }
  387 + }
  388 +
  389 + return n;
  390 + }
  391 +}
  392 +
  393 +ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes)
  394 +{
  395 + const ssize_t targetBufMaxSize = nbytes;
  396 + ssize_t nbytesRead = 0;
  397 +
  398 + while (websocketPendingBytes.usedBytes() >= WEBSOCKET_MIN_HEADER_BYTES_NEEDED && nbytesRead < targetBufMaxSize)
  399 + {
  400 + // This block decodes the header.
  401 + if (!incompleteWebsocketRead.sillWorkingOnFrame())
  402 + {
  403 + const uint8_t byte1 = websocketPendingBytes.peakAhead(0);
  404 + const uint8_t byte2 = websocketPendingBytes.peakAhead(1);
  405 + bool masked = !!(byte2 & 0b10000000);
  406 + uint8_t reserved = (byte1 & 0b01110000) >> 4;
  407 + WebsocketOpcode opcode = (WebsocketOpcode)(byte1 & 0b00001111);
  408 + const uint8_t payloadLength = byte2 & 0b01111111;
  409 + size_t realPayloadLength = payloadLength;
  410 + uint64_t extendedPayloadLengthLength = 0;
  411 + uint8_t headerLength = masked ? 6 : 2;
  412 +
  413 + if (payloadLength == 126)
  414 + extendedPayloadLengthLength = 2;
  415 + else if (payloadLength == 127)
  416 + extendedPayloadLengthLength = 8;
  417 + headerLength += extendedPayloadLengthLength;
  418 +
  419 + //if (!masked)
  420 + // throw ProtocolError("Client must send masked websocket bytes.");
  421 +
  422 + if (reserved != 0)
  423 + throw ProtocolError("Reserved bytes in header must be 0.");
  424 +
  425 + if (headerLength > websocketPendingBytes.usedBytes())
  426 + return nbytesRead;
  427 +
  428 + uint64_t extendedPayloadLength = 0;
  429 +
  430 + int i = 2;
  431 + int shift = extendedPayloadLengthLength * 8;
  432 + while (shift > 0)
  433 + {
  434 + shift -= 8;
  435 + uint8_t byte = websocketPendingBytes.peakAhead(i++);
  436 + extendedPayloadLength += (byte << shift);
  437 + }
  438 +
  439 + if (extendedPayloadLength > 0)
  440 + realPayloadLength = extendedPayloadLength;
  441 +
  442 + if (headerLength > websocketPendingBytes.usedBytes())
  443 + return nbytesRead;
  444 +
  445 + if (masked)
  446 + {
  447 + for (int j = 0; j < 4; j++)
  448 + {
  449 + incompleteWebsocketRead.maskingKey[j] = websocketPendingBytes.peakAhead(i++);
  450 + }
  451 + }
  452 +
  453 + assert(i == headerLength);
  454 + assert(headerLength <= websocketPendingBytes.usedBytes());
  455 + websocketPendingBytes.advanceTail(headerLength);
  456 +
  457 + incompleteWebsocketRead.frame_bytes_left = realPayloadLength;
  458 + incompleteWebsocketRead.opcode = opcode;
  459 + }
  460 +
  461 + if (incompleteWebsocketRead.opcode == WebsocketOpcode::Binary)
  462 + {
  463 + // The following reads one websocket frame max: it will continue with the previous, or start a new one, which it may or may not finish.
  464 + size_t targetBufI = 0;
  465 + char *targetBuf = &static_cast<char*>(buf)[nbytesRead];
  466 + while(websocketPendingBytes.usedBytes() > 0 && incompleteWebsocketRead.frame_bytes_left > 0 && nbytesRead < targetBufMaxSize)
  467 + {
  468 + const size_t asManyBytesOfThisFrameAsPossible = std::min<int>(websocketPendingBytes.maxReadSize(), incompleteWebsocketRead.frame_bytes_left);
  469 + const size_t maxReadSize = std::min<int>(asManyBytesOfThisFrameAsPossible, targetBufMaxSize - nbytesRead);
  470 + assert(maxReadSize > 0);
  471 + assert(static_cast<ssize_t>(maxReadSize) + nbytesRead <= targetBufMaxSize);
  472 + for (size_t x = 0; x < maxReadSize; x++)
  473 + {
  474 + targetBuf[targetBufI++] = websocketPendingBytes.tailPtr()[x] ^ incompleteWebsocketRead.getNextMaskingByte();
  475 + }
  476 + websocketPendingBytes.advanceTail(maxReadSize);
  477 + incompleteWebsocketRead.frame_bytes_left -= maxReadSize;
  478 + nbytesRead += maxReadSize;
  479 + }
  480 + }
  481 + else if (incompleteWebsocketRead.opcode == WebsocketOpcode::Ping)
  482 + {
  483 + // A ping MAY have user data, which needs to be ponged back;
  484 +
  485 + // Constructing a new temporary buffer because I need the reponse in one frame for writeAsMuchOfBufAsWebsocketFrame().
  486 + std::vector<char> response(incompleteWebsocketRead.frame_bytes_left);
  487 + websocketPendingBytes.read(response.data(), response.size());
  488 +
  489 + websocketWriteRemainder.ensureFreeSpace(response.size());
  490 + writeAsMuchOfBufAsWebsocketFrame(response.data(), response.size(), WebsocketOpcode::Pong);
  491 + parentClient->setReadyForWriting(true);
  492 + }
  493 +
  494 + if (!incompleteWebsocketRead.sillWorkingOnFrame())
  495 + incompleteWebsocketRead.reset();
  496 + }
  497 + assert(nbytesRead <= static_cast<ssize_t>(nbytes));
  498 +
  499 + return nbytesRead;
  500 +}
  501 +
  502 +/**
  503 + * @brief IoWrapper::writeAsMuchOfBufAsWebsocketFrame writes buf of part of buf as websocket frame to websocketWriteRemainder
  504 + * @param buf
  505 + * @param nbytes. The amount of bytes. Can be 0, for just an empty websocket frame.
  506 + * @return
  507 + */
  508 +ssize_t IoWrapper::writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode)
  509 +{
  510 + // We do allow pong frames to generate a zero payload packet, but for binary, that's not necessary.
  511 + if (nbytes == 0 && opcode == WebsocketOpcode::Binary)
  512 + return 0;
  513 +
  514 + ssize_t nBytesReal = 0;
  515 +
  516 + // We normally wrap each write in a frame, but if a previous one didn't fit in the system's write buffers, we're still working on it.
  517 + if (websocketWriteRemainder.freeSpace() > WEBSOCKET_MAX_SENDING_HEADER_SIZE)
  518 + {
  519 + uint8_t extended_payload_length_num_bytes = 0;
  520 + uint8_t payload_length = 0;
  521 + if (nbytes < 126)
  522 + payload_length = nbytes;
  523 + else if (nbytes >= 126 && nbytes <= 0xFFFF)
  524 + {
  525 + payload_length = 126;
  526 + extended_payload_length_num_bytes = 2;
  527 + }
  528 + else if (nbytes > 0xFFFF)
  529 + {
  530 + payload_length = 127;
  531 + extended_payload_length_num_bytes = 8;
  532 + }
  533 +
  534 + int x = 0;
  535 + char header[WEBSOCKET_MAX_SENDING_HEADER_SIZE];
  536 + header[x++] = (0b10000000 | static_cast<char>(opcode));
  537 + header[x++] = payload_length;
  538 +
  539 + const int header_length = x + extended_payload_length_num_bytes;
  540 +
  541 + // This block writes the extended payload length.
  542 + nBytesReal = std::min<int>(nbytes, websocketWriteRemainder.freeSpace() - header_length);
  543 + const uint64_t nbytes64 = nBytesReal;
  544 + for (int z = extended_payload_length_num_bytes - 1; z >= 0; z--)
  545 + {
  546 + header[x++] = (nbytes64 >> (z*8)) & 0xFF;
  547 + }
  548 + assert(x <= WEBSOCKET_MAX_SENDING_HEADER_SIZE);
  549 + assert(x == header_length);
  550 +
  551 + websocketWriteRemainder.write(header, header_length);
  552 + websocketWriteRemainder.write(buf, nBytesReal);
  553 + }
  554 +
  555 + return nBytesReal;
  556 +}
  557 +
  558 +/*
  559 + * Mqtt docs: "A single WebSocket data frame can contain multiple or partial MQTT Control Packets. The receiver
  560 + * MUST NOT assume that MQTT Control Packets are aligned on WebSocket frame boundaries [MQTT-6.0.0-2]." We
  561 + * make use of that here, and wrap each write in a frame.
  562 + *
  563 + * It's can legitimately return a number of bytes written AND error with 'would block'. So, no need to do that
  564 + * repeating of the write thing that SSL_write() has.
  565 + */
  566 +ssize_t IoWrapper::writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error)
  567 +{
  568 + if (websocketState != WebsocketState::Upgraded)
  569 + {
  570 + if (websocket && websocketState == WebsocketState::Upgrading)
  571 + websocketState = WebsocketState::Upgraded;
  572 +
  573 + return writeOrSslWrite(fd, buf, nbytes, error);
  574 + }
  575 + else
  576 + {
  577 + ssize_t nBytesReal = writeAsMuchOfBufAsWebsocketFrame(buf, nbytes);
  578 +
  579 + ssize_t n = 0;
  580 + while (websocketWriteRemainder.usedBytes() > 0)
  581 + {
  582 + n = writeOrSslWrite(fd, websocketWriteRemainder.tailPtr(), websocketWriteRemainder.maxReadSize(), error);
  583 +
  584 + if (n > 0)
  585 + websocketWriteRemainder.advanceTail(n);
  586 + if (n < 0)
  587 + break;
  588 + }
  589 +
  590 + if (n > 0)
  591 + return nBytesReal;
  592 +
  593 + return n;
  594 + }
  595 +}
iowrapper.h 0 โ†’ 100644
  1 +#ifndef IOWRAPPER_H
  2 +#define IOWRAPPER_H
  3 +
  4 +#include "unistd.h"
  5 +#include "openssl/ssl.h"
  6 +#include "openssl/err.h"
  7 +#include <exception>
  8 +
  9 +#include "forward_declarations.h"
  10 +
  11 +#include "types.h"
  12 +#include "utils.h"
  13 +#include "logger.h"
  14 +#include "exceptions.h"
  15 +
  16 +#define WEBSOCKET_MIN_HEADER_BYTES_NEEDED 2
  17 +#define WEBSOCKET_MAX_SENDING_HEADER_SIZE 10
  18 +
  19 +#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256.
  20 +#define OPENSSL_WRONG_VERSION_NUMBER 336130315
  21 +
  22 +enum class IoWrapResult
  23 +{
  24 + Success = 0,
  25 + Interrupted = 1,
  26 + Wouldblock = 2,
  27 + Disconnected = 3,
  28 + Error = 4
  29 +};
  30 +
  31 +enum class WebsocketOpcode
  32 +{
  33 + Continuation = 0x00,
  34 + Text = 0x1,
  35 + Binary = 0x2,
  36 + Close = 0x8,
  37 + Ping = 0x9,
  38 + Pong = 0xA,
  39 + Unknown = 0xF
  40 +};
  41 +
  42 +/*
  43 + * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned
  44 + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments"
  45 + */
  46 +struct IncompleteSslWrite
  47 +{
  48 + const void *buf = nullptr;
  49 + size_t nbytes = 0;
  50 +
  51 + IncompleteSslWrite() = default;
  52 + IncompleteSslWrite(const void *buf, size_t nbytes);
  53 + bool hasPendingWrite() const;
  54 +
  55 + void reset();
  56 +};
  57 +
  58 +struct IncompleteWebsocketRead
  59 +{
  60 + size_t frame_bytes_left = 0;
  61 + char maskingKey[4];
  62 + int maskingKeyI = 0;
  63 + WebsocketOpcode opcode;
  64 +
  65 + void reset();
  66 + bool sillWorkingOnFrame() const;
  67 + char getNextMaskingByte();
  68 +};
  69 +
  70 +enum class WebsocketState
  71 +{
  72 + NotUpgraded,
  73 + Upgrading,
  74 + Upgraded
  75 +};
  76 +
  77 +/**
  78 + * @brief provides a unified wrapper for SSL and websockets to read() and write().
  79 + *
  80 + *
  81 + */
  82 +class IoWrapper
  83 +{
  84 + Client *parentClient;
  85 + const size_t initialBufferSize;
  86 +
  87 + SSL *ssl = nullptr;
  88 + bool sslAccepted = false;
  89 + IncompleteSslWrite incompleteSslWrite;
  90 + bool sslReadWantsWrite = false;
  91 + bool sslWriteWantsRead = false;
  92 +
  93 + bool websocket;
  94 + WebsocketState websocketState = WebsocketState::NotUpgraded;
  95 + CirBuf websocketPendingBytes;
  96 + IncompleteWebsocketRead incompleteWebsocketRead;
  97 + CirBuf websocketWriteRemainder;
  98 +
  99 + Logger *logger = Logger::getInstance();
  100 +
  101 + ssize_t websocketBytesToReadBuffer(void *buf, const size_t nbytes);
  102 + ssize_t readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error);
  103 + ssize_t writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error);
  104 + ssize_t writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode = WebsocketOpcode::Binary);
  105 +public:
  106 + IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent);
  107 + ~IoWrapper();
  108 +
  109 + void startOrContinueSslAccept();
  110 + bool getSslReadWantsWrite() const;
  111 + bool getSslWriteWantsRead() const;
  112 + bool isSslAccepted() const;
  113 + bool isSsl() const;
  114 + bool hasPendingWrite() const;
  115 + bool isWebsocket() const;
  116 + WebsocketState getWebsocketState() const;
  117 +
  118 + ssize_t readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error);
  119 + ssize_t writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error);
  120 +};
  121 +
  122 +#endif // IOWRAPPER_H
mainapp.cpp
@@ -351,6 +351,7 @@ void MainApp::start() @@ -351,6 +351,7 @@ void MainApp::start()
351 351
352 int listen_fd_plain = createListenSocket(this->listenPort, false); 352 int listen_fd_plain = createListenSocket(this->listenPort, false);
353 int listen_fd_ssl = createListenSocket(this->sslListenPort, true); 353 int listen_fd_ssl = createListenSocket(this->sslListenPort, true);
  354 + int listen_fd_websocket_plain = createListenSocket(1443, true);
354 355
355 #ifdef NDEBUG 356 #ifdef NDEBUG
356 logger->noLongerLogToStd(); 357 logger->noLongerLogToStd();
@@ -391,7 +392,7 @@ void MainApp::start() @@ -391,7 +392,7 @@ void MainApp::start()
391 int cur_fd = events[i].data.fd; 392 int cur_fd = events[i].data.fd;
392 try 393 try
393 { 394 {
394 - if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl) 395 + if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl || listen_fd_websocket_plain)
395 { 396 {
396 std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % num_threads]; 397 std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % num_threads];
397 398
@@ -402,6 +403,7 @@ void MainApp::start() @@ -402,6 +403,7 @@ void MainApp::start()
402 socklen_t len = sizeof(struct sockaddr); 403 socklen_t len = sizeof(struct sockaddr);
403 int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len)); 404 int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len));
404 405
  406 + bool websocket = cur_fd == listen_fd_websocket_plain;
405 SSL *clientSSL = nullptr; 407 SSL *clientSSL = nullptr;
406 if (cur_fd == listen_fd_ssl) 408 if (cur_fd == listen_fd_ssl)
407 { 409 {
@@ -417,7 +419,7 @@ void MainApp::start() @@ -417,7 +419,7 @@ void MainApp::start()
417 SSL_set_fd(clientSSL, fd); 419 SSL_set_fd(clientSSL, fd);
418 } 420 }
419 421
420 - Client_p client(new Client(fd, thread_data, clientSSL, settings)); 422 + Client_p client(new Client(fd, thread_data, clientSSL, websocket, settings));
421 thread_data->giveClient(client); 423 thread_data->giveClient(client);
422 } 424 }
423 else if (cur_fd == taskEventFd) 425 else if (cur_fd == taskEventFd)
utils.cpp
@@ -3,8 +3,10 @@ @@ -3,8 +3,10 @@
3 #include "sys/time.h" 3 #include "sys/time.h"
4 #include "sys/random.h" 4 #include "sys/random.h"
5 #include <algorithm> 5 #include <algorithm>
  6 +#include <sstream>
6 7
7 #include "exceptions.h" 8 #include "exceptions.h"
  9 +#include "cirbuf.h"
8 10
9 std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) 11 std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts)
10 { 12 {
@@ -226,3 +228,121 @@ bool isPowerOfTwo(int n) @@ -226,3 +228,121 @@ bool isPowerOfTwo(int n)
226 { 228 {
227 return (n != 0) && (n & (n - 1)) == 0; 229 return (n != 0) && (n & (n - 1)) == 0;
228 } 230 }
  231 +
  232 +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version)
  233 +{
  234 + const std::string s(buf.tailPtr(), buf.usedBytes());
  235 + std::istringstream is(s);
  236 + bool doubleEmptyLine = false; // meaning, the HTTP header is complete
  237 + bool upgradeHeaderSeen = false;
  238 + bool connectionHeaderSeen = false;
  239 + bool firstLine = true;
  240 +
  241 + std::string line;
  242 + while (std::getline(is, line))
  243 + {
  244 + trim(line);
  245 + if (firstLine)
  246 + {
  247 + firstLine = false;
  248 + if (!startsWith(line, "GET"))
  249 + throw BadHttpRequest("Websocket request should start with GET.");
  250 + continue;
  251 + }
  252 + if (line.empty())
  253 + {
  254 + doubleEmptyLine = true;
  255 + break;
  256 + }
  257 +
  258 + std::list<std::string> fields = split(line, ':', 1);
  259 + const std::vector<std::string> fields2(fields.begin(), fields.end());
  260 + std::string name = str_tolower(fields2[0]);
  261 + trim(name);
  262 + std::string value = fields2[1];
  263 + trim(value);
  264 + std::string value_lower = str_tolower(value);
  265 +
  266 + if (name == "upgrade" && value_lower == "websocket")
  267 + upgradeHeaderSeen = true;
  268 + else if (name == "connection" && value_lower == "upgrade")
  269 + connectionHeaderSeen = true;
  270 + else if (name == "sec-websocket-key")
  271 + websocket_key = value;
  272 + else if (name == "sec-websocket-version")
  273 + websocket_version = stoi(value);
  274 + }
  275 +
  276 + if (doubleEmptyLine)
  277 + {
  278 + if (!connectionHeaderSeen || !upgradeHeaderSeen)
  279 + throw BadHttpRequest("HTTP request is not a websocket upgrade request.");
  280 + }
  281 +
  282 + return doubleEmptyLine;
  283 +}
  284 +
  285 +std::string base64Encode(const unsigned char *input, const int length)
  286 +{
  287 + const int pl = 4*((length+2)/3);
  288 + char *output = reinterpret_cast<char *>(calloc(pl+1, 1));
  289 + const int ol = EVP_EncodeBlock(reinterpret_cast<unsigned char *>(output), input, length);
  290 + std::string result(output);
  291 + free(output);
  292 +
  293 + if (pl != ol)
  294 + throw std::runtime_error("Base64 encode error.");
  295 +
  296 + return result;
  297 +}
  298 +
  299 +std::string generateWebsocketAcceptString(const std::string &websocketKey)
  300 +{
  301 + unsigned char md_value[EVP_MAX_MD_SIZE];
  302 + unsigned int md_len;
  303 +
  304 + EVP_MD_CTX *mdctx = EVP_MD_CTX_new();;
  305 + const EVP_MD *md = EVP_sha1();
  306 + EVP_DigestInit_ex(mdctx, md, NULL);
  307 +
  308 + const std::string keyPlusMagic = websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
  309 +
  310 + EVP_DigestUpdate(mdctx, keyPlusMagic.c_str(), keyPlusMagic.length());
  311 + EVP_DigestFinal_ex(mdctx, md_value, &md_len);
  312 + EVP_MD_CTX_free(mdctx);
  313 +
  314 + std::string base64 = base64Encode(md_value, md_len);
  315 + return base64;
  316 +}
  317 +
  318 +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion)
  319 +{
  320 + std::ostringstream oss;
  321 + oss << "HTTP/1.1 400 Bad Request\r\n";
  322 + oss << "Sec-WebSocket-Version: " << wantedVersion;
  323 + oss << "\r\n";
  324 + oss.flush();
  325 + return oss.str();
  326 +}
  327 +
  328 +std::string generateBadHttpRequestReponse(const std::string &msg)
  329 +{
  330 + std::ostringstream oss;
  331 + oss << "HTTP/1.1 400 Bad Request\r\n";
  332 + oss << "\r\n";
  333 + oss << msg;
  334 + oss.flush();
  335 + return oss.str();
  336 +}
  337 +
  338 +std::string generateWebsocketAnswer(const std::string &acceptString)
  339 +{
  340 + std::ostringstream oss;
  341 + oss << "HTTP/1.1 101 Switching Protocols\r\n";
  342 + oss << "Upgrade: websocket\r\n";
  343 + oss << "Connection: Upgrade\r\n";
  344 + oss << "Sec-WebSocket-Accept: " << acceptString << "\r\n";
  345 + oss << "\r\n";
  346 + oss.flush();
  347 + return oss.str();
  348 +}
@@ -8,6 +8,9 @@ @@ -8,6 +8,9 @@
8 #include <limits> 8 #include <limits>
9 #include <vector> 9 #include <vector>
10 #include <algorithm> 10 #include <algorithm>
  11 +#include <openssl/evp.h>
  12 +
  13 +#include "cirbuf.h"
11 14
12 template<typename T> int check(int rc) 15 template<typename T> int check(int rc)
13 { 16 {
@@ -43,4 +46,13 @@ std::string str_tolower(std::string s); @@ -43,4 +46,13 @@ std::string str_tolower(std::string s);
43 bool stringTruthiness(const std::string &val); 46 bool stringTruthiness(const std::string &val);
44 bool isPowerOfTwo(int val); 47 bool isPowerOfTwo(int val);
45 48
  49 +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version);
  50 +
  51 +std::string base64Encode(const unsigned char *input, const int length);
  52 +std::string generateWebsocketAcceptString(const std::string &websocketKey);
  53 +
  54 +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion);
  55 +std::string generateBadHttpRequestReponse(const std::string &msg);
  56 +std::string generateWebsocketAnswer(const std::string &acceptString);
  57 +
46 #endif // UTILS_H 58 #endif // UTILS_H