From c24becf2152e131882af98d5e0810d13162aefc6 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sun, 3 Jul 2022 16:23:08 +0200 Subject: [PATCH] Allow for any websocket protocol containing mqtt --- iowrapper.cpp | 5 +++-- utils.cpp | 11 +++++++---- utils.h | 4 ++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/iowrapper.cpp b/iowrapper.cpp index efaed66..51f4171 100644 --- a/iowrapper.cpp +++ b/iowrapper.cpp @@ -371,7 +371,8 @@ ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWra { std::string websocketKey; int websocketVersion; - if (parseHttpHeader(websocketPendingBytes, websocketKey, websocketVersion)) + std::string subprotocol; + if (parseHttpHeader(websocketPendingBytes, websocketKey, websocketVersion, subprotocol)) { if (websocketKey.empty()) throw BadHttpRequest("No websocket key specified."); @@ -380,7 +381,7 @@ ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWra const std::string acceptString = generateWebsocketAcceptString(websocketKey); - std::string answer = generateWebsocketAnswer(acceptString); + std::string answer = generateWebsocketAnswer(acceptString, subprotocol); parentClient->writeText(answer); websocketState = WebsocketState::Upgrading; websocketPendingBytes.reset(); diff --git a/utils.cpp b/utils.cpp index 828fbef..bcd6f21 100644 --- a/utils.cpp +++ b/utils.cpp @@ -348,7 +348,7 @@ bool isPowerOfTwo(int n) return (n != 0) && (n & (n - 1)) == 0; } -bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version) +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version, std::string &subprotocol) { const std::string s(buf.tailPtr(), buf.usedBytes()); std::istringstream is(s); @@ -397,8 +397,11 @@ bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_ver websocket_key = value; else if (name == "sec-websocket-version") websocket_version = stoi(value); - else if (name == "sec-websocket-protocol" && value_lower == "mqtt") + else if (name == "sec-websocket-protocol" && strContains(value_lower, "mqtt")) + { + subprotocol = value; subprotocol_seen = true; + } } if (doubleEmptyLine) @@ -489,14 +492,14 @@ std::string generateBadHttpRequestReponse(const std::string &msg) return oss.str(); } -std::string generateWebsocketAnswer(const std::string &acceptString) +std::string generateWebsocketAnswer(const std::string &acceptString, const std::string &subprotocol) { std::ostringstream oss; oss << "HTTP/1.1 101 Switching Protocols\r\n"; oss << "Upgrade: websocket\r\n"; oss << "Connection: Upgrade\r\n"; oss << "Sec-WebSocket-Accept: " << acceptString << "\r\n"; - oss << "Sec-WebSocket-Protocol: mqtt\r\n"; + oss << "Sec-WebSocket-Protocol: " << subprotocol << "\r\n"; oss << "\r\n"; oss.flush(); return oss.str(); diff --git a/utils.h b/utils.h index 8ab1152..0f2501e 100644 --- a/utils.h +++ b/utils.h @@ -77,7 +77,7 @@ std::string str_tolower(std::string s); bool stringTruthiness(const std::string &val); bool isPowerOfTwo(int val); -bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version); +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version, std::string &subprotocol); std::vector base64Decode(const std::string &s); std::string base64Encode(const unsigned char *input, const int length); @@ -85,7 +85,7 @@ std::string generateWebsocketAcceptString(const std::string &websocketKey); std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion); std::string generateBadHttpRequestReponse(const std::string &msg); -std::string generateWebsocketAnswer(const std::string &acceptString); +std::string generateWebsocketAnswer(const std::string &acceptString, const std::string &subprotocol); void testSsl(const std::string &fullchain, const std::string &privkey); -- libgit2 0.21.4