Commit c24becf2152e131882af98d5e0810d13162aefc6

Authored by Wiebe Cazemier
1 parent 1ca1bdf6

Allow for any websocket protocol containing mqtt

I've see various variations, like mqtt and mqttv31. The 3.1.1 specs are
clear what it should be, but the 3.1 specs aren't. So, allowing
anything with mqtt in it.
iowrapper.cpp
... ... @@ -371,7 +371,8 @@ ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWra
371 371 {
372 372 std::string websocketKey;
373 373 int websocketVersion;
374   - if (parseHttpHeader(websocketPendingBytes, websocketKey, websocketVersion))
  374 + std::string subprotocol;
  375 + if (parseHttpHeader(websocketPendingBytes, websocketKey, websocketVersion, subprotocol))
375 376 {
376 377 if (websocketKey.empty())
377 378 throw BadHttpRequest("No websocket key specified.");
... ... @@ -380,7 +381,7 @@ ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWra
380 381  
381 382 const std::string acceptString = generateWebsocketAcceptString(websocketKey);
382 383  
383   - std::string answer = generateWebsocketAnswer(acceptString);
  384 + std::string answer = generateWebsocketAnswer(acceptString, subprotocol);
384 385 parentClient->writeText(answer);
385 386 websocketState = WebsocketState::Upgrading;
386 387 websocketPendingBytes.reset();
... ...
utils.cpp
... ... @@ -348,7 +348,7 @@ bool isPowerOfTwo(int n)
348 348 return (n != 0) && (n & (n - 1)) == 0;
349 349 }
350 350  
351   -bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version)
  351 +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version, std::string &subprotocol)
352 352 {
353 353 const std::string s(buf.tailPtr(), buf.usedBytes());
354 354 std::istringstream is(s);
... ... @@ -397,8 +397,11 @@ bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_ver
397 397 websocket_key = value;
398 398 else if (name == "sec-websocket-version")
399 399 websocket_version = stoi(value);
400   - else if (name == "sec-websocket-protocol" && value_lower == "mqtt")
  400 + else if (name == "sec-websocket-protocol" && strContains(value_lower, "mqtt"))
  401 + {
  402 + subprotocol = value;
401 403 subprotocol_seen = true;
  404 + }
402 405 }
403 406  
404 407 if (doubleEmptyLine)
... ... @@ -489,14 +492,14 @@ std::string generateBadHttpRequestReponse(const std::string &msg)
489 492 return oss.str();
490 493 }
491 494  
492   -std::string generateWebsocketAnswer(const std::string &acceptString)
  495 +std::string generateWebsocketAnswer(const std::string &acceptString, const std::string &subprotocol)
493 496 {
494 497 std::ostringstream oss;
495 498 oss << "HTTP/1.1 101 Switching Protocols\r\n";
496 499 oss << "Upgrade: websocket\r\n";
497 500 oss << "Connection: Upgrade\r\n";
498 501 oss << "Sec-WebSocket-Accept: " << acceptString << "\r\n";
499   - oss << "Sec-WebSocket-Protocol: mqtt\r\n";
  502 + oss << "Sec-WebSocket-Protocol: " << subprotocol << "\r\n";
500 503 oss << "\r\n";
501 504 oss.flush();
502 505 return oss.str();
... ...
... ... @@ -77,7 +77,7 @@ std::string str_tolower(std::string s);
77 77 bool stringTruthiness(const std::string &val);
78 78 bool isPowerOfTwo(int val);
79 79  
80   -bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version);
  80 +bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version, std::string &subprotocol);
81 81  
82 82 std::vector<char> base64Decode(const std::string &s);
83 83 std::string base64Encode(const unsigned char *input, const int length);
... ... @@ -85,7 +85,7 @@ std::string generateWebsocketAcceptString(const std::string &amp;websocketKey);
85 85  
86 86 std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion);
87 87 std::string generateBadHttpRequestReponse(const std::string &msg);
88   -std::string generateWebsocketAnswer(const std::string &acceptString);
  88 +std::string generateWebsocketAnswer(const std::string &acceptString, const std::string &subprotocol);
89 89  
90 90 void testSsl(const std::string &fullchain, const std::string &privkey);
91 91  
... ...