From 1f3786991edf45e47fabbbabd7a77cba502d5ab9 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Mon, 8 Feb 2021 21:07:50 +0100 Subject: [PATCH] Deal with clientid protocol version-appropriate --- client.cpp | 3 ++- client.h | 4 +++- mqttpacket.cpp | 28 +++++++++++++++++++++++++--- utils.cpp | 27 ++++++++++++++++++++++++++- utils.h | 1 + 5 files changed, 57 insertions(+), 6 deletions(-) diff --git a/client.cpp b/client.cpp index 40c6d85..a3dd97d 100644 --- a/client.cpp +++ b/client.cpp @@ -565,8 +565,9 @@ bool Client::bufferToMqttPackets(std::vector &packetQueueIn, Client_ return true; } -void Client::setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession) +void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession) { + this->protocolVersion = protocolVersion; this->clientid = clientId; this->username = username; this->connectPacketSeen = connectPacketSeen; diff --git a/client.h b/client.h index 32c4174..7c99543 100644 --- a/client.h +++ b/client.h @@ -16,6 +16,7 @@ #include "mqttpacket.h" #include "exceptions.h" #include "cirbuf.h" +#include "types.h" #include #include @@ -61,6 +62,7 @@ class Client IncompleteSslWrite incompleteSslWrite; bool sslReadWantsWrite = false; bool sslWriteWantsRead = false; + ProtocolVersion protocolVersion = ProtocolVersion::None; CirBuf readbuf; uint8_t readBufIsZeroCount = 0; @@ -115,7 +117,7 @@ public: ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error); bool readFdIntoBuffer(); bool bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender); - void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); + void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); void setWill(const std::string &topic, const std::string &payload, bool retain, char qos); void setAuthenticated(bool value) { authenticated = value;} bool getAuthenticated() { return authenticated; } diff --git a/mqttpacket.cpp b/mqttpacket.cpp index b4c6ede..79b781b 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -237,19 +237,41 @@ void MqttPacket::handleConnect() return; } - // In case the client_id ever appears in topics. + bool validClientId = true; + + // Check for wildcard chars in case the client_id ever appears in topics. // TODO: make setting? if (strContains(client_id, "+") || strContains(client_id, "#")) { + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str()); + validClientId = false; + } + else if (!clean_session && client_id.empty()) + { + logger->logf(LOG_ERR, "ClientID empty and clean session 0, which is incompatible"); + validClientId = false; + } + else if (protocolVersion < ProtocolVersion::Mqtt311 && client_id.empty()) + { + logger->logf(LOG_ERR, "Empty clientID. Connect with protocol 3.1.1 or higher to have one generated securely."); + validClientId = false; + } + + if (!validClientId) + { ConnAck connAck(ConnAckReturnCodes::ClientIdRejected); MqttPacket response(connAck); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str()); return; } - sender->setClientProperties(client_id, username, true, keep_alive, clean_session); + if (client_id.empty()) + { + client_id = getSecureRandomString(23); + } + + sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session); sender->setWill(will_topic, will_payload, will_retain, will_qos); if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) diff --git a/utils.cpp b/utils.cpp index 5fe5a4c..baf8c46 100644 --- a/utils.cpp +++ b/utils.cpp @@ -1,9 +1,11 @@ #include "utils.h" #include "sys/time.h" - +#include "sys/random.h" #include +#include "exceptions.h" + std::list split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) { std::list list; @@ -178,3 +180,26 @@ int64_t currentMSecsSinceEpoch() int64_t milliseconds = te.tv_sec*1000LL + te.tv_usec/1000; return milliseconds; } + +std::string getSecureRandomString(const size_t len) +{ + std::vector buf(len); + size_t actual_len = getrandom(buf.data(), len, 0); + + if (actual_len < 0 || actual_len != len) + { + throw std::runtime_error("Error requesting random data"); + } + + const std::string possibleCharacters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrtsuvwxyz1234567890"); + const int possibleCharactersCount = possibleCharacters.length(); + + std::string randomString; + for(const unsigned char &c : buf) + { + unsigned int index = c % possibleCharactersCount; + char nextChar = possibleCharacters.at(index); + randomString.push_back(nextChar); + } + return randomString; +} diff --git a/utils.h b/utils.h index 0267672..6862146 100644 --- a/utils.h +++ b/utils.h @@ -38,5 +38,6 @@ void trim(std::string &s); bool startsWith(const std::string &s, const std::string &needle); int64_t currentMSecsSinceEpoch(); +std::string getSecureRandomString(const size_t len); #endif // UTILS_H -- libgit2 0.21.4