diff --git a/utils.cpp b/utils.cpp index 87de974..1cf4428 100644 --- a/utils.cpp +++ b/utils.cpp @@ -356,6 +356,7 @@ bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_ver bool upgradeHeaderSeen = false; bool connectionHeaderSeen = false; bool firstLine = true; + bool subprotocol_seen = false; std::string line; while (std::getline(is, line)) @@ -396,12 +397,16 @@ bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_ver websocket_key = value; else if (name == "sec-websocket-version") websocket_version = stoi(value); + else if (name == "sec-websocket-protocol" && value_lower == "mqtt") + subprotocol_seen = true; } if (doubleEmptyLine) { if (!connectionHeaderSeen || !upgradeHeaderSeen) throw BadHttpRequest("HTTP request is not a websocket upgrade request."); + if (!subprotocol_seen) + throw BadHttpRequest("HTTP header Sec-WebSocket-Protocol with value 'mqtt' must be present."); } return doubleEmptyLine; @@ -491,6 +496,7 @@ std::string generateWebsocketAnswer(const std::string &acceptString) oss << "Upgrade: websocket\r\n"; oss << "Connection: Upgrade\r\n"; oss << "Sec-WebSocket-Accept: " << acceptString << "\r\n"; + oss << "Sec-WebSocket-Protocol: mqtt\r\n"; oss << "\r\n"; oss.flush(); return oss.str();