Commit b5ba41f503c602c0edcdeec0db5b2b8344050cb8

Authored by Wiebe Cazemier
1 parent 1106e210

Add IoWrapper, with websocket support added

The ping/pong is actually untested at this point, because Paho (my test
client for now) doesn't do those. I wonder if any do, because MQTT
already has ping/pong.
CMakeLists.txt
... ... @@ -29,6 +29,7 @@ add_executable(FlashMQ
29 29 sslctxmanager.cpp
30 30 timer.cpp
31 31 globalsettings.cpp
  32 + iowrapper.cpp
32 33 )
33 34  
34 35 target_link_libraries(FlashMQ pthread dl ssl crypto)
... ...
cirbuf.cpp
... ... @@ -9,10 +9,14 @@
9 9 #include <cstring>
10 10  
11 11 #include "logger.h"
  12 +#include "utils.h"
12 13  
13 14 CirBuf::CirBuf(size_t size) :
14 15 size(size)
15 16 {
  17 + if (size == 0)
  18 + return;
  19 +
16 20 buf = (char*)malloc(size);
17 21  
18 22 if (buf == NULL)
... ... @@ -69,12 +73,14 @@ char *CirBuf::tailPtr()
69 73  
70 74 void CirBuf::advanceHead(uint32_t n)
71 75 {
  76 + assert(n <= freeSpace());
72 77 head = (head + n) & (size -1);
73 78 assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty.
74 79 }
75 80  
76 81 void CirBuf::advanceTail(uint32_t n)
77 82 {
  83 + assert(n <= usedBytes());
78 84 tail = (tail + n) & (size -1);
79 85 }
80 86  
... ... @@ -84,9 +90,31 @@ char CirBuf::peakAhead(uint32_t offset) const
84 90 return b;
85 91 }
86 92  
87   -void CirBuf::doubleSize()
  93 +void CirBuf::ensureFreeSpace(size_t n)
  94 +{
  95 + const size_t _usedBytes = usedBytes();
  96 +
  97 + int mul = 1;
  98 +
  99 + while((mul * size - _usedBytes - 1) < n)
  100 + {
  101 + mul = mul << 1;
  102 + }
  103 +
  104 + if (mul == 1)
  105 + return;
  106 +
  107 + doubleSize(mul);
  108 +}
  109 +
  110 +void CirBuf::doubleSize(uint factor)
88 111 {
89   - uint newSize = size * 2;
  112 + if (factor == 1)
  113 + return;
  114 +
  115 + assert(isPowerOfTwo(factor));
  116 +
  117 + uint newSize = size * factor;
90 118 char *newBuf = (char*)realloc(buf, newSize);
91 119  
92 120 if (newBuf == NULL)
... ... @@ -145,3 +173,55 @@ void CirBuf::resetSize(size_t newSize)
145 173 memset(buf, 0, newSize);
146 174 #endif
147 175 }
  176 +
  177 +void CirBuf::reset()
  178 +{
  179 + head = 0;
  180 + tail = 0;
  181 +
  182 +#ifndef NDEBUG
  183 + memset(buf, 0, size);
  184 +#endif
  185 +}
  186 +
  187 +// When you know the data you want to write fits in the buffer, use this.
  188 +void CirBuf::write(const void *buf, size_t count)
  189 +{
  190 + assert(count <= freeSpace());
  191 +
  192 + ssize_t len_left = count;
  193 + size_t src_i = 0;
  194 + while (len_left > 0)
  195 + {
  196 + const size_t len = std::min<int>(len_left, maxWriteSize());
  197 + assert(len > 0);
  198 + const char *src = &reinterpret_cast<const char*>(buf)[src_i];
  199 + std::memcpy(headPtr(), src, len);
  200 + advanceHead(len);
  201 + src_i += len;
  202 + len_left -= len;
  203 + }
  204 + assert(len_left == 0);
  205 + assert(src_i == count);
  206 +}
  207 +
  208 +// When you know 'count' bytes are present and you want to read them into buf
  209 +void CirBuf::read(void *buf, const size_t count)
  210 +{
  211 + assert(count <= usedBytes());
  212 +
  213 + char *_buf = static_cast<char*>(buf);
  214 + int i = 0;
  215 + ssize_t _packet_len = count;
  216 + while (_packet_len > 0)
  217 + {
  218 + const int readlen = std::min<int>(maxReadSize(), _packet_len);
  219 + assert(readlen > 0);
  220 + std::memcpy(&_buf[i], tailPtr(), readlen);
  221 + advanceTail(readlen);
  222 + i += readlen;
  223 + _packet_len -= readlen;
  224 + }
  225 + assert(_packet_len == 0);
  226 + assert(i == _packet_len);
  227 +}
... ...
cirbuf.h
... ... @@ -32,11 +32,16 @@ public:
32 32 void advanceHead(uint32_t n);
33 33 void advanceTail(uint32_t n);
34 34 char peakAhead(uint32_t offset) const;
35   - void doubleSize();
  35 + void ensureFreeSpace(size_t n);
  36 + void doubleSize(uint factor = 2);
36 37 uint32_t getSize() const;
37 38  
38 39 time_t bufferLastResizedSecondsAgo() const;
39 40 void resetSize(size_t size);
  41 + void reset();
  42 +
  43 + void write(const void *buf, size_t count);
  44 + void read(void *buf, const size_t count);
40 45 };
41 46  
42 47 #endif // CIRBUF_H
... ...
client.cpp
... ... @@ -7,11 +7,11 @@
7 7  
8 8 #include "logger.h"
9 9  
10   -Client::Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings) :
  10 +Client::Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings) :
11 11 fd(fd),
12   - ssl(ssl),
13 12 initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy
14 13 maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment.
  14 + ioWrapper(ssl, websocket, initialBufferSize, this),
15 15 readbuf(initialBufferSize),
16 16 writebuf(initialBufferSize),
17 17 threadData(threadData)
... ... @@ -40,63 +40,32 @@ Client::~Client()
40 40 logger->logf(LOG_NOTICE, "Removing client '%s'. Reason(s): %s", repr().c_str(), disconnectReason.c_str());
41 41 if (epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL) != 0)
42 42 logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd, repr().c_str(), strerror(errno));
43   - if (ssl)
44   - {
45   - // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done
46   - // in the destructor.
47   - SSL_free(ssl);
48   - }
49 43 close(fd);
50 44 }
51 45  
52 46 bool Client::isSslAccepted() const
53 47 {
54   - return sslAccepted;
  48 + return ioWrapper.isSslAccepted();
55 49 }
56 50  
57 51 bool Client::isSsl() const
58 52 {
59   - return this->ssl != nullptr;
  53 + return ioWrapper.isSsl();
60 54 }
61 55  
62 56 bool Client::getSslReadWantsWrite() const
63 57 {
64   - return this->sslReadWantsWrite;
  58 + return ioWrapper.getSslReadWantsWrite();
65 59 }
66 60  
67 61 bool Client::getSslWriteWantsRead() const
68 62 {
69   - return this->sslWriteWantsRead;
  63 + return ioWrapper.getSslWriteWantsRead();
70 64 }
71 65  
72 66 void Client::startOrContinueSslAccept()
73 67 {
74   - ERR_clear_error();
75   - int accepted = SSL_accept(ssl);
76   - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
77   - if (accepted <= 0)
78   - {
79   - int err = SSL_get_error(ssl, accepted);
80   -
81   - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
82   - {
83   - setReadyForWriting(err == SSL_ERROR_WANT_WRITE);
84   - return;
85   - }
86   -
87   - unsigned long error_code = ERR_get_error();
88   -
89   - ERR_error_string(error_code, sslErrorBuf);
90   - std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
91   -
92   - if (error_code == OPENSSL_WRONG_VERSION_NUMBER)
93   - errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket.";
94   -
95   - //ERR_print_errors_cb(logSslError, NULL);
96   - throw std::runtime_error("Problem accepting SSL socket: " + errorMsg);
97   - }
98   - setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake
99   - sslAccepted = true;
  68 + ioWrapper.startOrContinueSslAccept();
100 69 }
101 70  
102 71 // Causes future activity on the client to cause a disconnect.
... ... @@ -108,83 +77,6 @@ void Client::markAsDisconnecting()
108 77 disconnecting = true;
109 78 }
110 79  
111   -// SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL
112   -// socket. This wrapper unifies behavor for the caller.
113   -ssize_t Client::readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error)
114   -{
115   - *error = IoWrapResult::Success;
116   - ssize_t n = 0;
117   - if (!ssl)
118   - {
119   - n = read(fd, buf, nbytes);
120   - if (n < 0)
121   - {
122   - if (errno == EINTR)
123   - *error = IoWrapResult::Interrupted;
124   - else if (errno == EAGAIN || errno == EWOULDBLOCK)
125   - *error = IoWrapResult::Wouldblock;
126   - else
127   - check<std::runtime_error>(n);
128   - }
129   - else if (n == 0)
130   - {
131   - *error = IoWrapResult::Disconnected;
132   - }
133   - }
134   - else
135   - {
136   - this->sslReadWantsWrite = false;
137   - ERR_clear_error();
138   - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
139   - n = SSL_read(ssl, buf, nbytes);
140   -
141   - if (n <= 0)
142   - {
143   - int err = SSL_get_error(ssl, n);
144   - unsigned long error_code = ERR_get_error();
145   -
146   - // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL.
147   - if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0))
148   - {
149   - *error = IoWrapResult::Disconnected;
150   - }
151   - else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
152   - {
153   - *error = IoWrapResult::Wouldblock;
154   - if (err == SSL_ERROR_WANT_WRITE)
155   - {
156   - sslReadWantsWrite = true;
157   - setReadyForWriting(true);
158   - }
159   - n = -1;
160   - }
161   - else
162   - {
163   - if (err == SSL_ERROR_SYSCALL)
164   - {
165   - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
166   - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
167   - // implies EINTR is not included?
168   - if (errno == EINTR)
169   - *error = IoWrapResult::Interrupted;
170   - else
171   - {
172   - char *err = strerror(errno);
173   - std::string msg(err);
174   - throw std::runtime_error("SSL read error: " + msg);
175   - }
176   - }
177   - ERR_error_string(error_code, sslErrorBuf);
178   - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
179   - ERR_print_errors_cb(logSslError, NULL);
180   - throw std::runtime_error("SSL socket error reading: " + errorString);
181   - }
182   - }
183   - }
184   -
185   - return n;
186   -}
187   -
188 80 // false means any kind of error we want to get rid of the client for.
189 81 bool Client::readFdIntoBuffer()
190 82 {
... ... @@ -193,7 +85,7 @@ bool Client::readFdIntoBuffer()
193 85  
194 86 IoWrapResult error = IoWrapResult::Success;
195 87 int n = 0;
196   - while (readbuf.freeSpace() > 0 && (n = readWrap(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0)
  88 + while (readbuf.freeSpace() > 0 && (n = ioWrapper.readWebsocketAndOrSsl(fd, readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0)
197 89 {
198 90 if (n > 0)
199 91 {
... ... @@ -232,6 +124,20 @@ bool Client::readFdIntoBuffer()
232 124 return true;
233 125 }
234 126  
  127 +void Client::writeText(const std::string &text)
  128 +{
  129 + assert(ioWrapper.isWebsocket());
  130 + assert(ioWrapper.getWebsocketState() == WebsocketState::NotUpgraded);
  131 +
  132 + // Not necessary, because at this point, no other threads write to this client, but including for clarity.
  133 + std::lock_guard<std::mutex> locker(writeBufMutex);
  134 +
  135 + writebuf.ensureFreeSpace(text.size());
  136 + writebuf.write(text.c_str(), text.length());
  137 +
  138 + setReadyForWriting(true);
  139 +}
  140 +
235 141 void Client::writeMqttPacket(const MqttPacket &packet)
236 142 {
237 143 std::lock_guard<std::mutex> locker(writeBufMutex);
... ... @@ -322,94 +228,6 @@ void Client::writePingResp()
322 228 setReadyForWriting(true);
323 229 }
324 230  
325   -// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller.
326   -ssize_t Client::writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error)
327   -{
328   - *error = IoWrapResult::Success;
329   - ssize_t n = 0;
330   -
331   - if (!ssl)
332   - {
333   - // A write on a socket with count=0 is unspecified.
334   - assert(nbytes > 0);
335   -
336   - n = write(fd, buf, nbytes);
337   - if (n < 0)
338   - {
339   - if (errno == EINTR)
340   - *error = IoWrapResult::Interrupted;
341   - else if (errno == EAGAIN || errno == EWOULDBLOCK)
342   - *error = IoWrapResult::Wouldblock;
343   - else
344   - check<std::runtime_error>(n);
345   - }
346   - }
347   - else
348   - {
349   - const void *buf_ = buf;
350   - size_t nbytes_ = nbytes;
351   -
352   - /*
353   - * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned
354   - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments
355   - */
356   - if (this->incompleteSslWrite.hasPendingWrite())
357   - {
358   - buf_ = this->incompleteSslWrite.buf;
359   - nbytes_ = this->incompleteSslWrite.nbytes;
360   - }
361   -
362   - // OpenSSL: "You should not call SSL_write() with num=0, it will return an error"
363   - assert(nbytes_ > 0);
364   -
365   - this->sslWriteWantsRead = false;
366   - this->incompleteSslWrite.reset();
367   -
368   - ERR_clear_error();
369   - char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
370   - n = SSL_write(ssl, buf_, nbytes_);
371   -
372   - if (n <= 0)
373   - {
374   - int err = SSL_get_error(ssl, n);
375   - unsigned long error_code = ERR_get_error();
376   - if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
377   - {
378   - logger->logf(LOG_DEBUG, "Write is incomplete: %d", err);
379   - *error = IoWrapResult::Wouldblock;
380   - IncompleteSslWrite sslAction(buf_, nbytes_);
381   - this->incompleteSslWrite = sslAction;
382   - if (err == SSL_ERROR_WANT_READ)
383   - this->sslWriteWantsRead = true;
384   - n = 0;
385   - }
386   - else
387   - {
388   - if (err == SSL_ERROR_SYSCALL)
389   - {
390   - // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
391   - // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
392   - // implies EINTR is not included?
393   - if (errno == EINTR)
394   - *error = IoWrapResult::Interrupted;
395   - else
396   - {
397   - char *err = strerror(errno);
398   - std::string msg(err);
399   - throw std::runtime_error(msg);
400   - }
401   - }
402   - ERR_error_string(error_code, sslErrorBuf);
403   - std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
404   - ERR_print_errors_cb(logSslError, NULL);
405   - throw std::runtime_error("SSL socket error writing: " + errorString);
406   - }
407   - }
408   - }
409   -
410   - return n;
411   -}
412   -
413 231 bool Client::writeBufIntoFd()
414 232 {
415 233 std::unique_lock<std::mutex> lock(writeBufMutex, std::try_to_lock);
... ... @@ -422,9 +240,9 @@ bool Client::writeBufIntoFd()
422 240  
423 241 IoWrapResult error = IoWrapResult::Success;
424 242 int n;
425   - while (writebuf.usedBytes() > 0 || incompleteSslWrite.hasPendingWrite())
  243 + while (writebuf.usedBytes() > 0 || ioWrapper.hasPendingWrite())
426 244 {
427   - n = writeWrap(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error);
  245 + n = ioWrapper.writeWebsocketAndOrSsl(fd, writebuf.tailPtr(), writebuf.maxReadSize(), &error);
428 246  
429 247 if (n > 0)
430 248 writebuf.advanceTail(n);
... ... @@ -484,7 +302,7 @@ void Client::setReadyForWriting(bool val)
484 302 if (disconnecting)
485 303 return;
486 304  
487   - if (sslReadWantsWrite)
  305 + if (ioWrapper.getSslReadWantsWrite())
488 306 val = true;
489 307  
490 308 if (val == this->readyForWriting)
... ... @@ -622,20 +440,3 @@ void Client::clearWill()
622 440 will_qos = 0;
623 441 }
624 442  
625   -IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) :
626   - buf(buf),
627   - nbytes(nbytes)
628   -{
629   -
630   -}
631   -
632   -void IncompleteSslWrite::reset()
633   -{
634   - buf = nullptr;
635   - nbytes = 0;
636   -}
637   -
638   -bool IncompleteSslWrite::hasPendingWrite()
639   -{
640   - return buf != nullptr;
641   -}
... ...
client.h
... ... @@ -9,6 +9,7 @@
9 9 #include <time.h>
10 10  
11 11 #include <openssl/ssl.h>
  12 +#include <openssl/err.h>
12 13  
13 14 #include "forward_declarations.h"
14 15  
... ... @@ -17,54 +18,25 @@
17 18 #include "exceptions.h"
18 19 #include "cirbuf.h"
19 20 #include "types.h"
20   -
21   -#include <openssl/ssl.h>
22   -#include <openssl/err.h>
  21 +#include "iowrapper.h"
23 22  
24 23 #define MQTT_HEADER_LENGH 2
25 24  
26   -#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256.
27   -#define OPENSSL_WRONG_VERSION_NUMBER 336130315
28   -
29   -enum class IoWrapResult
30   -{
31   - Success = 0,
32   - Interrupted = 1,
33   - Wouldblock = 2,
34   - Disconnected = 3,
35   - Error = 4
36   -};
37   -
38   -/*
39   - * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned
40   - * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments"
41   - */
42   -struct IncompleteSslWrite
43   -{
44   - const void *buf = nullptr;
45   - size_t nbytes = 0;
46   -
47   - IncompleteSslWrite() = default;
48   - IncompleteSslWrite(const void *buf, size_t nbytes);
49   - bool hasPendingWrite();
50   -
51   - void reset();
52   -};
53 25  
54 26 // TODO: give accepted addr, for showing in logs
55 27 class Client
56 28 {
  29 + friend class IoWrapper;
  30 +
57 31 int fd;
58   - SSL *ssl = nullptr;
59   - bool sslAccepted = false;
60   - IncompleteSslWrite incompleteSslWrite;
61   - bool sslReadWantsWrite = false;
62   - bool sslWriteWantsRead = false;
  32 +
63 33 ProtocolVersion protocolVersion = ProtocolVersion::None;
64 34  
65 35 const size_t initialBufferSize = 0;
66 36 const size_t maxPacketSize = 0;
67 37  
  38 + IoWrapper ioWrapper;
  39 +
68 40 CirBuf readbuf;
69 41 uint8_t readBufIsZeroCount = 0;
70 42  
... ... @@ -97,12 +69,11 @@ class Client
97 69  
98 70 Logger *logger = Logger::getInstance();
99 71  
100   -
101 72 void setReadyForWriting(bool val);
102 73 void setReadyForReading(bool val);
103 74  
104 75 public:
105   - Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings);
  76 + Client(int fd, ThreadData_p threadData, SSL *ssl, bool websocket, const GlobalSettings &settings);
106 77 Client(const Client &other) = delete;
107 78 Client(Client &&other) = delete;
108 79 ~Client();
... ... @@ -115,7 +86,6 @@ public:
115 86  
116 87 void startOrContinueSslAccept();
117 88 void markAsDisconnecting();
118   - ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error);
119 89 bool readFdIntoBuffer();
120 90 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender);
121 91 void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
... ... @@ -131,10 +101,10 @@ public:
131 101 std::shared_ptr<Session> getSession();
132 102 void setDisconnectReason(const std::string &reason);
133 103  
  104 + void writeText(const std::string &text);
134 105 void writePingResp();
135 106 void writeMqttPacket(const MqttPacket &packet);
136 107 void writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
137   - ssize_t writeWrap(int fd, const void *buf, size_t nbytes, IoWrapResult *error);
138 108 bool writeBufIntoFd();
139 109 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; }
140 110  
... ...
exceptions.h
... ... @@ -34,4 +34,16 @@ public:
34 34 AuthPluginException(const std::string &msg) : std::runtime_error(msg) {}
35 35 };
36 36  
  37 +class BadWebsocketVersionException : public std::runtime_error
  38 +{
  39 +public:
  40 + BadWebsocketVersionException(const std::string &msg) : std::runtime_error(msg) {}
  41 +};
  42 +
  43 +class BadHttpRequest : public std::runtime_error
  44 +{
  45 +public:
  46 + BadHttpRequest(const std::string &msg) : std::runtime_error(msg) {}
  47 +};
  48 +
37 49 #endif // EXCEPTIONS_H
... ...
iowrapper.cpp 0 โ†’ 100644
  1 +#include "iowrapper.h"
  2 +
  3 +#include "cassert"
  4 +
  5 +#include "logger.h"
  6 +#include "client.h"
  7 +
  8 +IncompleteSslWrite::IncompleteSslWrite(const void *buf, size_t nbytes) :
  9 + buf(buf),
  10 + nbytes(nbytes)
  11 +{
  12 +
  13 +}
  14 +
  15 +void IncompleteSslWrite::reset()
  16 +{
  17 + buf = nullptr;
  18 + nbytes = 0;
  19 +}
  20 +
  21 +bool IncompleteSslWrite::hasPendingWrite() const
  22 +{
  23 + return buf != nullptr;
  24 +}
  25 +
  26 +void IncompleteWebsocketRead::reset()
  27 +{
  28 + maskingKeyI = 0;
  29 + memset(maskingKey,0, 4);
  30 + frame_bytes_left = 0;
  31 + opcode = WebsocketOpcode::Unknown;
  32 +}
  33 +
  34 +bool IncompleteWebsocketRead::sillWorkingOnFrame() const
  35 +{
  36 + return frame_bytes_left > 0;
  37 +}
  38 +
  39 +char IncompleteWebsocketRead::getNextMaskingByte()
  40 +{
  41 + return maskingKey[maskingKeyI++ % 4];
  42 +}
  43 +
  44 +IoWrapper::IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent) :
  45 + parentClient(parent),
  46 + initialBufferSize(initialBufferSize),
  47 + ssl(ssl),
  48 + websocket(websocket),
  49 + websocketPendingBytes(websocket ? initialBufferSize : 0),
  50 + websocketWriteRemainder(websocket ? initialBufferSize : 0)
  51 +{
  52 +
  53 +}
  54 +
  55 +IoWrapper::~IoWrapper()
  56 +{
  57 + if (ssl)
  58 + {
  59 + // I don't do SSL_shutdown(), because I don't want to keep the session, plus, that takes active de-negiotation, so it can't be done
  60 + // in the destructor.
  61 + SSL_free(ssl);
  62 + }
  63 +}
  64 +
  65 +void IoWrapper::startOrContinueSslAccept()
  66 +{
  67 + ERR_clear_error();
  68 + int accepted = SSL_accept(ssl);
  69 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  70 + if (accepted <= 0)
  71 + {
  72 + int err = SSL_get_error(ssl, accepted);
  73 +
  74 + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  75 + {
  76 + parentClient->setReadyForWriting(err == SSL_ERROR_WANT_WRITE);
  77 + return;
  78 + }
  79 +
  80 + unsigned long error_code = ERR_get_error();
  81 +
  82 + ERR_error_string(error_code, sslErrorBuf);
  83 + std::string errorMsg(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  84 +
  85 + if (error_code == OPENSSL_WRONG_VERSION_NUMBER)
  86 + errorMsg = "Wrong protocol version number. Probably a non-SSL connection on SSL socket.";
  87 +
  88 + //ERR_print_errors_cb(logSslError, NULL);
  89 + throw std::runtime_error("Problem accepting SSL socket: " + errorMsg);
  90 + }
  91 + parentClient->setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake
  92 + sslAccepted = true;
  93 +}
  94 +
  95 +bool IoWrapper::getSslReadWantsWrite() const
  96 +{
  97 + return this->sslReadWantsWrite;
  98 +}
  99 +
  100 +bool IoWrapper::getSslWriteWantsRead() const
  101 +{
  102 + return sslWriteWantsRead;
  103 +}
  104 +
  105 +bool IoWrapper::isSslAccepted() const
  106 +{
  107 + return this->sslAccepted;
  108 +}
  109 +
  110 +bool IoWrapper::isSsl() const
  111 +{
  112 + return this->ssl != nullptr;
  113 +}
  114 +
  115 +bool IoWrapper::hasPendingWrite() const
  116 +{
  117 + return incompleteSslWrite.hasPendingWrite() || websocketWriteRemainder.usedBytes() > 0;
  118 +}
  119 +
  120 +bool IoWrapper::isWebsocket() const
  121 +{
  122 + return websocket;
  123 +}
  124 +
  125 +WebsocketState IoWrapper::getWebsocketState() const
  126 +{
  127 + return websocketState;
  128 +}
  129 +
  130 +/**
  131 + * @brief SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL
  132 + * socket. This wrapper unifies behavor for the caller.
  133 + *
  134 + * @param fd
  135 + * @param buf
  136 + * @param nbytes
  137 + * @param error is an out-argument with the result.
  138 + * @return
  139 + */
  140 +ssize_t IoWrapper::readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error)
  141 +{
  142 + *error = IoWrapResult::Success;
  143 + ssize_t n = 0;
  144 + if (!ssl)
  145 + {
  146 + n = read(fd, buf, nbytes);
  147 + if (n < 0)
  148 + {
  149 + if (errno == EINTR)
  150 + *error = IoWrapResult::Interrupted;
  151 + else if (errno == EAGAIN || errno == EWOULDBLOCK)
  152 + *error = IoWrapResult::Wouldblock;
  153 + else
  154 + check<std::runtime_error>(n);
  155 + }
  156 + else if (n == 0)
  157 + {
  158 + *error = IoWrapResult::Disconnected;
  159 + }
  160 + }
  161 + else
  162 + {
  163 + this->sslReadWantsWrite = false;
  164 + ERR_clear_error();
  165 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  166 + n = SSL_read(ssl, buf, nbytes);
  167 +
  168 + if (n <= 0)
  169 + {
  170 + int err = SSL_get_error(ssl, n);
  171 + unsigned long error_code = ERR_get_error();
  172 +
  173 + // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why EOF is seen as SSL_ERROR_SYSCALL.
  174 + if (err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && errno == 0))
  175 + {
  176 + *error = IoWrapResult::Disconnected;
  177 + }
  178 + else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  179 + {
  180 + *error = IoWrapResult::Wouldblock;
  181 + if (err == SSL_ERROR_WANT_WRITE)
  182 + {
  183 + sslReadWantsWrite = true;
  184 + parentClient->setReadyForWriting(true);
  185 + }
  186 + n = -1;
  187 + }
  188 + else
  189 + {
  190 + if (err == SSL_ERROR_SYSCALL)
  191 + {
  192 + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
  193 + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
  194 + // implies EINTR is not included?
  195 + if (errno == EINTR)
  196 + *error = IoWrapResult::Interrupted;
  197 + else
  198 + {
  199 + char *err = strerror(errno);
  200 + std::string msg(err);
  201 + throw std::runtime_error("SSL read error: " + msg);
  202 + }
  203 + }
  204 + ERR_error_string(error_code, sslErrorBuf);
  205 + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  206 + ERR_print_errors_cb(logSslError, NULL);
  207 + throw std::runtime_error("SSL socket error reading: " + errorString);
  208 + }
  209 + }
  210 + }
  211 +
  212 + return n;
  213 +}
  214 +
  215 +// SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller.
  216 +ssize_t IoWrapper::writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error)
  217 +{
  218 + *error = IoWrapResult::Success;
  219 + ssize_t n = 0;
  220 +
  221 + if (!ssl)
  222 + {
  223 + // A write on a socket with count=0 is unspecified.
  224 + assert(nbytes > 0);
  225 +
  226 + n = write(fd, buf, nbytes);
  227 + if (n < 0)
  228 + {
  229 + if (errno == EINTR)
  230 + *error = IoWrapResult::Interrupted;
  231 + else if (errno == EAGAIN || errno == EWOULDBLOCK)
  232 + *error = IoWrapResult::Wouldblock;
  233 + else
  234 + check<std::runtime_error>(n);
  235 + }
  236 + }
  237 + else
  238 + {
  239 + const void *buf_ = buf;
  240 + size_t nbytes_ = nbytes;
  241 +
  242 + /*
  243 + * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned
  244 + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments
  245 + */
  246 + if (this->incompleteSslWrite.hasPendingWrite())
  247 + {
  248 + buf_ = this->incompleteSslWrite.buf;
  249 + nbytes_ = this->incompleteSslWrite.nbytes;
  250 + }
  251 +
  252 + // OpenSSL: "You should not call SSL_write() with num=0, it will return an error"
  253 + assert(nbytes_ > 0);
  254 +
  255 + this->sslWriteWantsRead = false;
  256 + this->incompleteSslWrite.reset();
  257 +
  258 + ERR_clear_error();
  259 + char sslErrorBuf[OPENSSL_ERROR_STRING_SIZE];
  260 + n = SSL_write(ssl, buf_, nbytes_);
  261 +
  262 + if (n <= 0)
  263 + {
  264 + int err = SSL_get_error(ssl, n);
  265 + unsigned long error_code = ERR_get_error();
  266 + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
  267 + {
  268 + logger->logf(LOG_DEBUG, "Write is incomplete: %d", err);
  269 + *error = IoWrapResult::Wouldblock;
  270 + IncompleteSslWrite sslAction(buf_, nbytes_);
  271 + this->incompleteSslWrite = sslAction;
  272 + if (err == SSL_ERROR_WANT_READ)
  273 + this->sslWriteWantsRead = true;
  274 + n = 0;
  275 + }
  276 + else
  277 + {
  278 + if (err == SSL_ERROR_SYSCALL)
  279 + {
  280 + // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say
  281 + // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it
  282 + // implies EINTR is not included?
  283 + if (errno == EINTR)
  284 + *error = IoWrapResult::Interrupted;
  285 + else
  286 + {
  287 + char *err = strerror(errno);
  288 + std::string msg(err);
  289 + throw std::runtime_error(msg);
  290 + }
  291 + }
  292 + ERR_error_string(error_code, sslErrorBuf);
  293 + std::string errorString(sslErrorBuf, OPENSSL_ERROR_STRING_SIZE);
  294 + ERR_print_errors_cb(logSslError, NULL);
  295 + throw std::runtime_error("SSL socket error writing: " + errorString);
  296 + }
  297 + }
  298 + }
  299 +
  300 + return n;
  301 +}
  302 +
  303 +// Use a small intermediate buffer to write (partial) websocket frames to our normal read buffer. MQTT is already a frames protocol, so we don't
  304 +// care about websocket frames being incomplete.
  305 +ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error)
  306 +{
  307 + if (!websocket)
  308 + {
  309 + return readOrSslRead(fd, buf, nbytes, error);
  310 + }
  311 + else
  312 + {
  313 + ssize_t n = 0;
  314 + while (websocketPendingBytes.freeSpace() > 0 && (n = readOrSslRead(fd, websocketPendingBytes.headPtr(), websocketPendingBytes.maxWriteSize(), error)) != 0)
  315 + {
  316 + if (n > 0)
  317 + websocketPendingBytes.advanceHead(n);
  318 + if (n < 0)
  319 + break; // signal/error handling is done by the caller, so we just stop.
  320 +
  321 + if (websocketState == WebsocketState::NotUpgraded && websocketPendingBytes.freeSpace() == 0)
  322 + {
  323 + if (websocketPendingBytes.getSize() * 2 <= 8192)
  324 + websocketPendingBytes.doubleSize();
  325 + else
  326 + throw ProtocolError("Trying to exceed websocket buffer. Probably not valid websocket traffic.");
  327 + }
  328 + }
  329 +
  330 + const bool hasWebsocketPendingBytes = websocketPendingBytes.usedBytes() > 0;
  331 +
  332 + // When some or all the data has been read, we can continue.
  333 + if (!(*error == IoWrapResult::Wouldblock || *error == IoWrapResult::Success) && !hasWebsocketPendingBytes)
  334 + return n;
  335 +
  336 + if (hasWebsocketPendingBytes)
  337 + {
  338 + n = 0;
  339 +
  340 + if (websocketState == WebsocketState::NotUpgraded)
  341 + {
  342 + try
  343 + {
  344 + std::string websocketKey;
  345 + int websocketVersion;
  346 + if (parseHttpHeader(websocketPendingBytes, websocketKey, websocketVersion))
  347 + {
  348 + if (websocketKey.empty())
  349 + throw BadHttpRequest("No websocket key specified.");
  350 + if (websocketVersion != 13)
  351 + throw BadWebsocketVersionException("Websocket version 13 required.");
  352 +
  353 + const std::string acceptString = generateWebsocketAcceptString(websocketKey);
  354 +
  355 + std::string answer = generateWebsocketAnswer(acceptString);
  356 + parentClient->writeText(answer);
  357 + websocketState = WebsocketState::Upgrading;
  358 + websocketPendingBytes.reset();
  359 + websocketPendingBytes.resetSize(initialBufferSize);
  360 + *error = IoWrapResult::Success;
  361 + }
  362 + }
  363 + catch (BadWebsocketVersionException &ex)
  364 + {
  365 + std::string response = generateInvalidWebsocketVersionHttpHeaders(13);
  366 + parentClient->writeText(response);
  367 + parentClient->setDisconnectReason("Invalid websocket version");
  368 + parentClient->setReadyForDisconnect();
  369 + }
  370 + catch (BadHttpRequest &ex) // Should should also properly deal with attempt at HTTP2 with PRI.
  371 + {
  372 + std::string response = generateBadHttpRequestReponse(ex.what());
  373 + parentClient->writeText(response);
  374 + parentClient->setDisconnectReason("Invalid websocket start");
  375 + parentClient->setReadyForDisconnect();
  376 + }
  377 + }
  378 + else
  379 + {
  380 + n = websocketBytesToReadBuffer(buf, nbytes);
  381 +
  382 + if (n > 0)
  383 + *error = IoWrapResult::Success;
  384 + else if (n == 0)
  385 + *error = IoWrapResult::Wouldblock;
  386 + }
  387 + }
  388 +
  389 + return n;
  390 + }
  391 +}
  392 +
  393 +ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes)
  394 +{
  395 + const ssize_t targetBufMaxSize = nbytes;
  396 + ssize_t nbytesRead = 0;
  397 +
  398 + while (websocketPendingBytes.usedBytes() >= WEBSOCKET_MIN_HEADER_BYTES_NEEDED && nbytesRead < targetBufMaxSize)
  399 + {
  400 + // This block decodes the header.
  401 + if (!incompleteWebsocketRead.sillWorkingOnFrame())
  402 + {
  403 + const uint8_t byte1 = websocketPendingBytes.peakAhead(0);
  404 + const uint8_t byte2 = websocketPendingBytes.peakAhead(1);
  405 + bool masked = !!(byte2 & 0b10000000);
  406 + uint8_t reserved = (byte1 & 0b01110000) >> 4;
  407 + WebsocketOpcode opcode = (WebsocketOpcode)(byte1 & 0b00001111);
  408 + const uint8_t payloadLength = byte2 & 0b01111111;
  409 + size_t realPayloadLength = payloadLength;
  410 + uint64_t extendedPayloadLengthLength = 0;
  411 + uint8_t headerLength = masked ? 6 : 2;
  412 +
  413 + if (payloadLength == 126)
  414 + extendedPayloadLengthLength = 2;
  415 + else if (payloadLength == 127)
  416 + extendedPayloadLengthLength = 8;
  417 + headerLength += extendedPayloadLengthLength;
  418 +
  419 + //if (!masked)
  420 + // throw ProtocolError("Client must send masked websocket bytes.");
  421 +
  422 + if (reserved != 0)
  423 + throw ProtocolError("Reserved bytes in header must be 0.");
  424 +
  425 + if (headerLength > websocketPendingBytes.usedBytes())
  426 + return nbytesRead;
  427 +
  428 + uint64_t extendedPayloadLength = 0;
  429 +
  430 + int i = 2;
  431 + int shift = extendedPayloadLengthLength * 8;
  432 + while (shift > 0)
  433 + {
  434 + shift -= 8;
  435 + uint8_t byte = websocketPendingBytes.peakAhead(i++);
  436 + extendedPayloadLength += (byte << shift);
  437 + }
  438 +
  439 + if (extendedPayloadLength > 0)
  440 + realPayloadLength = extendedPayloadLength;
  441 +
  442 + if (headerLength > websocketPendingBytes.usedBytes())
  443 + return nbytesRead;
  444 +
  445 + if (masked)
  446 + {
  447 + for (int j = 0; j < 4; j++)
  448 + {
  449 + incompleteWebsocketRead.maskingKey[j] = websocketPendingBytes.peakAhead(i++);
  450 + }
  451 + }
  452 +
  453 + assert(i == headerLength);
  454 + assert(headerLength <= websocketPendingBytes.usedBytes());
  455 + websocketPendingBytes.advanceTail(headerLength);
  456 +
  457 + incompleteWebsocketRead.frame_bytes_left = realPayloadLength;
  458 + incompleteWebsocketRead.opcode = opcode;
  459 + }
  460 +
  461 + if (incompleteWebsocketRead.opcode == WebsocketOpcode::Binary)
  462 + {
  463 + // The following reads one websocket frame max: it will continue with the previous, or start a new one, which it may or may not finish.
  464 + size_t targetBufI = 0;
  465 + char *targetBuf = &static_cast<char*>(buf)[nbytesRead];
  466 + while(websocketPendingBytes.usedBytes() > 0 && incompleteWebsocketRead.frame_bytes_left > 0 && nbytesRead < targetBufMaxSize)
  467 + {
  468 + const size_t asManyBytesOfThisFrameAsPossible = std::min<int>(websocketPendingBytes.maxReadSize(), incompleteWebsocketRead.frame_bytes_left);
  469 + const size_t maxReadSize = std::min<int>(asManyBytesOfThisFrameAsPossible, targetBufMaxSize - nbytesRead);
  470 + assert(maxReadSize > 0);
  471 + assert(static_cast<ssize_t>(maxReadSize) + nbytesRead <= targetBufMaxSize);
  472 + for (size_t x = 0; x < maxReadSize; x++)
  473 + {
  474 + targetBuf[targetBufI++] = websocketPendingBytes.tailPtr()[x] ^ incompleteWebsocketRead.getNextMaskingByte();
  475 + }
  476 + websocketPendingBytes.advanceTail(maxReadSize);
  477 + incompleteWebsocketRead.frame_bytes_left -= maxReadSize;
  478 + nbytesRead += maxReadSize;
  479 + }
  480 + }
  481 + else if (incompleteWebsocketRead.opcode == WebsocketOpcode::Ping)
  482 + {
  483 + // A ping MAY have user data, which needs to be ponged back;
  484 +
  485 + // Constructing a new temporary buffer because I need the reponse in one frame for writeAsMuchOfBufAsWebsocketFrame().
  486 + std::vector<char> response(incompleteWebsocketRead.frame_bytes_left);
  487 + websocketPendingBytes.read(response.data(), response.size());
  488 +
  489 + websocketWriteRemainder.ensureFreeSpace(response.size());
  490 + writeAsMuchOfBufAsWebsocketFrame(response.data(), response.size(), WebsocketOpcode::Pong);
  491 + parentClient->setReadyForWriting(true);
  492 + }
  493 +
  494 + if (!incompleteWebsocketRead.sillWorkingOnFrame())
  495 + incompleteWebsocketRead.reset();
  496 + }
  497 + assert(nbytesRead <= static_cast<ssize_t>(nbytes));
  498 +
  499 + return nbytesRead;
  500 +}
  501 +
  502 +/**
  503 + * @brief IoWrapper::writeAsMuchOfBufAsWebsocketFrame writes buf of part of buf as websocket frame to websocketWriteRemainder
  504 + * @param buf
  505 + * @param nbytes. The amount of bytes. Can be 0, for just an empty websocket frame.
  506 + * @return
  507 + */
  508 +ssize_t IoWrapper::writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode)
  509 +{
  510 + // We do allow pong frames to generate a zero payload packet, but for binary, that's not necessary.
  511 + if (nbytes == 0 && opcode == WebsocketOpcode::Binary)
  512 + return 0;
  513 +
  514 + ssize_t nBytesReal = 0;
  515 +
  516 + // We normally wrap each write in a frame, but if a previous one didn't fit in the system's write buffers, we're still working on it.
  517 + if (websocketWriteRemainder.freeSpace() > WEBSOCKET_MAX_SENDING_HEADER_SIZE)
  518 + {
  519 + uint8_t extended_payload_length_num_bytes = 0;
  520 + uint8_t payload_length = 0;
  521 + if (nbytes < 126)
  522 + payload_length = nbytes;
  523 + else if (nbytes >= 126 && nbytes <= 0xFFFF)
  524 + {
  525 + payload_length = 126;
  526 + extended_payload_length_num_bytes = 2;
  527 + }
  528 + else if (nbytes > 0xFFFF)
  529 + {
  530 + payload_length = 127;
  531 + extended_payload_length_num_bytes = 8;
  532 + }
  533 +
  534 + int x = 0;
  535 + char header[WEBSOCKET_MAX_SENDING_HEADER_SIZE];
  536 + header[x++] = (0b10000000 | static_cast<char>(opcode));
  537 + header[x++] = payload_length;
  538 +
  539 + const int header_length = x + extended_payload_length_num_bytes;
  540 +
  541 + // This block writes the extended payload length.
  542 + nBytesReal = std::min<int>(nbytes, websocketWriteRemainder.freeSpace() - header_length);
  543 + const uint64_t nbytes64 = nBytesReal;
  544 + for (int z = extended_payload_length_num_bytes - 1; z >= 0; z--)
  545 + {
  546 + header[x++] = (nbytes64 >> (z*8)) & 0xFF;
  547 + }
  548 + assert(x <= WEBSOCKET_MAX_SENDING_HEADER_SIZE);
  549 + assert(x == header_length);
  550 +
  551 + websocketWriteRemainder.write(header, header_length);
  552 + websocketWriteRemainder.write(buf, nBytesReal);
  553 + }
  554 +
  555 + return nBytesReal;
  556 +}
  557 +
  558 +/*
  559 + * Mqtt docs: "A single WebSocket data frame can contain multiple or partial MQTT Control Packets. The receiver
  560 + * MUST NOT assume that MQTT Control Packets are aligned on WebSocket frame boundaries [MQTT-6.0.0-2]." We
  561 + * make use of that here, and wrap each write in a frame.
  562 + *
  563 + * It's can legitimately return a number of bytes written AND error with 'would block'. So, no need to do that
  564 + * repeating of the write thing that SSL_write() has.
  565 + */
  566 +ssize_t IoWrapper::writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error)
  567 +{
  568 + if (websocketState != WebsocketState::Upgraded)
  569 + {
  570 + if (websocket && websocketState == WebsocketState::Upgrading)
  571 + websocketState = WebsocketState::Upgraded;
  572 +
  573 + return writeOrSslWrite(fd, buf, nbytes, error);
  574 + }
  575 + else
  576 + {
  577 + ssize_t nBytesReal = writeAsMuchOfBufAsWebsocketFrame(buf, nbytes);
  578 +
  579 + ssize_t n = 0;
  580 + while (websocketWriteRemainder.usedBytes() > 0)
  581 + {
  582 + n = writeOrSslWrite(fd, websocketWriteRemainder.tailPtr(), websocketWriteRemainder.maxReadSize(), error);
  583 +
  584 + if (n > 0)
  585 + websocketWriteRemainder.advanceTail(n);
  586 + if (n < 0)
  587 + break;
  588 + }
  589 +
  590 + if (n > 0)
  591 + return nBytesReal;
  592 +
  593 + return n;
  594 + }
  595 +}
... ...
iowrapper.h 0 โ†’ 100644
  1 +#ifndef IOWRAPPER_H
  2 +#define IOWRAPPER_H
  3 +
  4 +#include "unistd.h"
  5 +#include "openssl/ssl.h"
  6 +#include "openssl/err.h"
  7 +#include <exception>
  8 +
  9 +#include "forward_declarations.h"
  10 +
  11 +#include "types.h"
  12 +#include "utils.h"
  13 +#include "logger.h"
  14 +#include "exceptions.h"
  15 +
  16 +#define WEBSOCKET_MIN_HEADER_BYTES_NEEDED 2
  17 +#define WEBSOCKET_MAX_SENDING_HEADER_SIZE 10
  18 +
  19 +#define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256.
  20 +#define OPENSSL_WRONG_VERSION_NUMBER 336130315
  21 +
  22 +enum class IoWrapResult
  23 +{
  24 + Success = 0,
  25 + Interrupted = 1,
  26 + Wouldblock = 2,
  27 + Disconnected = 3,
  28 + Error = 4
  29 +};
  30 +
  31 +enum class WebsocketOpcode
  32 +{
  33 + Continuation = 0x00,
  34 + Text = 0x1,
  35 + Binary = 0x2,
  36 + Close = 0x8,
  37 + Ping = 0x9,
  38 + Pong = 0xA,
  39 + Unknown = 0xF
  40 +};
  41 +
  42 +/*
  43 + * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned
  44 + * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments"
  45 + */
  46 +struct IncompleteSslWrite
  47 +{
  48 + const void *buf = nullptr;
  49 + size_t nbytes = 0;
  50 +
  51 + IncompleteSslWrite() = default;
  52 + IncompleteSslWrite(const void *buf, size_t nbytes);
  53 + bool hasPendingWrite() const;
  54 +
  55 + void reset();
  56 +};
  57 +
  58 +struct IncompleteWebsocketRead
  59 +{
  60 + size_t frame_bytes_left = 0;
  61 + char maskingKey[4];
  62 + int maskingKeyI = 0;
  63 + WebsocketOpcode opcode;
  64 +
  65 + void reset();
  66 + bool sillWorkingOnFrame() const;
  67 + char getNextMaskingByte();
  68 +};
  69 +
  70 +enum class WebsocketState
  71 +{
  72 + NotUpgraded,
  73 + Upgrading,
  74 + Upgraded
  75 +};
  76 +
  77 +/**
  78 + * @brief provides a unified wrapper for SSL and websockets to read() and write().
  79 + *
  80 + *
  81 + */
  82 +class IoWrapper
  83 +{
  84 + Client *parentClient;
  85 + const size_t initialBufferSize;
  86 +
  87 + SSL *ssl = nullptr;
  88 + bool sslAccepted = false;
  89 + IncompleteSslWrite incompleteSslWrite;
  90 + bool sslReadWantsWrite = false;
  91 + bool sslWriteWantsRead = false;
  92 +
  93 + bool websocket;
  94 + WebsocketState websocketState = WebsocketState::NotUpgraded;
  95 + CirBuf websocketPendingBytes;
  96 + IncompleteWebsocketRead incompleteWebsocketRead;
  97 + CirBuf websocketWriteRemainder;
  98 +
  99 + Logger *logger = Logger::getInstance();
  100 +
  101 + ssize_t websocketBytesToReadBuffer(void *buf, const size_t nbytes);
  102 + ssize_t readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error);
  103 + ssize_t writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error);
  104 + ssize_t writeAsMuchOfBufAsWebsocketFrame(const void *buf, size_t nbytes, WebsocketOpcode opcode = WebsocketOpcode::Binary);
  105 +public:
  106 + IoWrapper(SSL *ssl, bool websocket, const size_t initialBufferSize, Client *parent);
  107 + ~IoWrapper();
  108 +
  109 + void startOrContinueSslAccept();
  110 + bool getSslReadWantsWrite() const;
  111 + bool getSslWriteWantsRead() const;
  112 + bool isSslAccepted() const;
  113 + bool isSsl() const;
  114 + bool hasPendingWrite() const;
  115 + bool isWebsocket() const;
  116 + WebsocketState getWebsocketState() const;
  117 +
  118 + ssize_t readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error);
  119 + ssize_t writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error);
  120 +};
  121 +
  122 +#endif // IOWRAPPER_H
... ...
mainapp.cpp
... ... @@ -351,6 +351,7 @@ void MainApp::start()
351 351  
352 352 int listen_fd_plain = createListenSocket(this->listenPort, false);
353 353 int listen_fd_ssl = createListenSocket(this->sslListenPort, true);
  354 + int listen_fd_websocket_plain = createListenSocket(1443, true);
354 355  
355 356 #ifdef NDEBUG
356 357 logger->noLongerLogToStd();
... ... @@ -391,7 +392,7 @@ void MainApp::start()
391 392 int cur_fd = events[i].data.fd;
392 393 try
393 394 {
394   - if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl)
  395 + if (cur_fd == listen_fd_plain || cur_fd == listen_fd_ssl || listen_fd_websocket_plain)
395 396 {
396 397 std::shared_ptr<ThreadData> thread_data = threads[next_thread_index++ % num_threads];
397 398  
... ... @@ -402,6 +403,7 @@ void MainApp::start()
402 403 socklen_t len = sizeof(struct sockaddr);
403 404 int fd = check<std::runtime_error>(accept(cur_fd, &addr, &len));
404 405  
  406 + bool websocket = cur_fd == listen_fd_websocket_plain;
405 407 SSL *clientSSL = nullptr;
406 408 if (cur_fd == listen_fd_ssl)
407 409 {
... ... @@ -417,7 +419,7 @@ void MainApp::start()
417 419 SSL_set_fd(clientSSL, fd);
418 420 }
419 421  
420   - Client_p client(new Client(fd, thread_data, clientSSL, settings));
  422 + Client_p client(new Client(fd, thread_data, clientSSL, websocket, settings));
421 423 thread_data->giveClient(client);
422 424 }
423 425 else if (cur_fd == taskEventFd)
... ...
utils.cpp
... ... @@ -3,8 +3,10 @@
3 3 #include "sys/time.h"
4 4 #include "sys/random.h"
5 5 #include <algorithm>
  6 +#include <sstream>
6 7  
7 8 #include "exceptions.h"
  9 +#include "cirbuf.h"
8 10  
9 11 std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts)
10 12 {
... ... @@ -226,3 +228,121 @@ bool isPowerOfTwo(int n)
226 228 {
227 229 return (n != 0) && (n & (n - 1)) == 0;
228 230 }
  231 +
  232 +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version)
  233 +{
  234 + const std::string s(buf.tailPtr(), buf.usedBytes());
  235 + std::istringstream is(s);
  236 + bool doubleEmptyLine = false; // meaning, the HTTP header is complete
  237 + bool upgradeHeaderSeen = false;
  238 + bool connectionHeaderSeen = false;
  239 + bool firstLine = true;
  240 +
  241 + std::string line;
  242 + while (std::getline(is, line))
  243 + {
  244 + trim(line);
  245 + if (firstLine)
  246 + {
  247 + firstLine = false;
  248 + if (!startsWith(line, "GET"))
  249 + throw BadHttpRequest("Websocket request should start with GET.");
  250 + continue;
  251 + }
  252 + if (line.empty())
  253 + {
  254 + doubleEmptyLine = true;
  255 + break;
  256 + }
  257 +
  258 + std::list<std::string> fields = split(line, ':', 1);
  259 + const std::vector<std::string> fields2(fields.begin(), fields.end());
  260 + std::string name = str_tolower(fields2[0]);
  261 + trim(name);
  262 + std::string value = fields2[1];
  263 + trim(value);
  264 + std::string value_lower = str_tolower(value);
  265 +
  266 + if (name == "upgrade" && value_lower == "websocket")
  267 + upgradeHeaderSeen = true;
  268 + else if (name == "connection" && value_lower == "upgrade")
  269 + connectionHeaderSeen = true;
  270 + else if (name == "sec-websocket-key")
  271 + websocket_key = value;
  272 + else if (name == "sec-websocket-version")
  273 + websocket_version = stoi(value);
  274 + }
  275 +
  276 + if (doubleEmptyLine)
  277 + {
  278 + if (!connectionHeaderSeen || !upgradeHeaderSeen)
  279 + throw BadHttpRequest("HTTP request is not a websocket upgrade request.");
  280 + }
  281 +
  282 + return doubleEmptyLine;
  283 +}
  284 +
  285 +std::string base64Encode(const unsigned char *input, const int length)
  286 +{
  287 + const int pl = 4*((length+2)/3);
  288 + char *output = reinterpret_cast<char *>(calloc(pl+1, 1));
  289 + const int ol = EVP_EncodeBlock(reinterpret_cast<unsigned char *>(output), input, length);
  290 + std::string result(output);
  291 + free(output);
  292 +
  293 + if (pl != ol)
  294 + throw std::runtime_error("Base64 encode error.");
  295 +
  296 + return result;
  297 +}
  298 +
  299 +std::string generateWebsocketAcceptString(const std::string &websocketKey)
  300 +{
  301 + unsigned char md_value[EVP_MAX_MD_SIZE];
  302 + unsigned int md_len;
  303 +
  304 + EVP_MD_CTX *mdctx = EVP_MD_CTX_new();;
  305 + const EVP_MD *md = EVP_sha1();
  306 + EVP_DigestInit_ex(mdctx, md, NULL);
  307 +
  308 + const std::string keyPlusMagic = websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
  309 +
  310 + EVP_DigestUpdate(mdctx, keyPlusMagic.c_str(), keyPlusMagic.length());
  311 + EVP_DigestFinal_ex(mdctx, md_value, &md_len);
  312 + EVP_MD_CTX_free(mdctx);
  313 +
  314 + std::string base64 = base64Encode(md_value, md_len);
  315 + return base64;
  316 +}
  317 +
  318 +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion)
  319 +{
  320 + std::ostringstream oss;
  321 + oss << "HTTP/1.1 400 Bad Request\r\n";
  322 + oss << "Sec-WebSocket-Version: " << wantedVersion;
  323 + oss << "\r\n";
  324 + oss.flush();
  325 + return oss.str();
  326 +}
  327 +
  328 +std::string generateBadHttpRequestReponse(const std::string &msg)
  329 +{
  330 + std::ostringstream oss;
  331 + oss << "HTTP/1.1 400 Bad Request\r\n";
  332 + oss << "\r\n";
  333 + oss << msg;
  334 + oss.flush();
  335 + return oss.str();
  336 +}
  337 +
  338 +std::string generateWebsocketAnswer(const std::string &acceptString)
  339 +{
  340 + std::ostringstream oss;
  341 + oss << "HTTP/1.1 101 Switching Protocols\r\n";
  342 + oss << "Upgrade: websocket\r\n";
  343 + oss << "Connection: Upgrade\r\n";
  344 + oss << "Sec-WebSocket-Accept: " << acceptString << "\r\n";
  345 + oss << "\r\n";
  346 + oss.flush();
  347 + return oss.str();
  348 +}
... ...
... ... @@ -8,6 +8,9 @@
8 8 #include <limits>
9 9 #include <vector>
10 10 #include <algorithm>
  11 +#include <openssl/evp.h>
  12 +
  13 +#include "cirbuf.h"
11 14  
12 15 template<typename T> int check(int rc)
13 16 {
... ... @@ -43,4 +46,13 @@ std::string str_tolower(std::string s);
43 46 bool stringTruthiness(const std::string &val);
44 47 bool isPowerOfTwo(int val);
45 48  
  49 +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version);
  50 +
  51 +std::string base64Encode(const unsigned char *input, const int length);
  52 +std::string generateWebsocketAcceptString(const std::string &websocketKey);
  53 +
  54 +std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion);
  55 +std::string generateBadHttpRequestReponse(const std::string &msg);
  56 +std::string generateWebsocketAnswer(const std::string &acceptString);
  57 +
46 58 #endif // UTILS_H
... ...