Commit 1f3786991edf45e47fabbbabd7a77cba502d5ab9

Authored by Wiebe Cazemier
1 parent d0473da1

Deal with clientid protocol version-appropriate

And include some extra error conditions.
client.cpp
@@ -565,8 +565,9 @@ bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_ @@ -565,8 +565,9 @@ bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_
565 return true; 565 return true;
566 } 566 }
567 567
568 -void Client::setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession) 568 +void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession)
569 { 569 {
  570 + this->protocolVersion = protocolVersion;
570 this->clientid = clientId; 571 this->clientid = clientId;
571 this->username = username; 572 this->username = username;
572 this->connectPacketSeen = connectPacketSeen; 573 this->connectPacketSeen = connectPacketSeen;
client.h
@@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
16 #include "mqttpacket.h" 16 #include "mqttpacket.h"
17 #include "exceptions.h" 17 #include "exceptions.h"
18 #include "cirbuf.h" 18 #include "cirbuf.h"
  19 +#include "types.h"
19 20
20 #include <openssl/ssl.h> 21 #include <openssl/ssl.h>
21 #include <openssl/err.h> 22 #include <openssl/err.h>
@@ -61,6 +62,7 @@ class Client @@ -61,6 +62,7 @@ class Client
61 IncompleteSslWrite incompleteSslWrite; 62 IncompleteSslWrite incompleteSslWrite;
62 bool sslReadWantsWrite = false; 63 bool sslReadWantsWrite = false;
63 bool sslWriteWantsRead = false; 64 bool sslWriteWantsRead = false;
  65 + ProtocolVersion protocolVersion = ProtocolVersion::None;
64 66
65 CirBuf readbuf; 67 CirBuf readbuf;
66 uint8_t readBufIsZeroCount = 0; 68 uint8_t readBufIsZeroCount = 0;
@@ -115,7 +117,7 @@ public: @@ -115,7 +117,7 @@ public:
115 ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error); 117 ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error);
116 bool readFdIntoBuffer(); 118 bool readFdIntoBuffer();
117 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender); 119 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender);
118 - void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); 120 + void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
119 void setWill(const std::string &topic, const std::string &payload, bool retain, char qos); 121 void setWill(const std::string &topic, const std::string &payload, bool retain, char qos);
120 void setAuthenticated(bool value) { authenticated = value;} 122 void setAuthenticated(bool value) { authenticated = value;}
121 bool getAuthenticated() { return authenticated; } 123 bool getAuthenticated() { return authenticated; }
mqttpacket.cpp
@@ -237,19 +237,41 @@ void MqttPacket::handleConnect() @@ -237,19 +237,41 @@ void MqttPacket::handleConnect()
237 return; 237 return;
238 } 238 }
239 239
240 - // In case the client_id ever appears in topics. 240 + bool validClientId = true;
  241 +
  242 + // Check for wildcard chars in case the client_id ever appears in topics.
241 // TODO: make setting? 243 // TODO: make setting?
242 if (strContains(client_id, "+") || strContains(client_id, "#")) 244 if (strContains(client_id, "+") || strContains(client_id, "#"))
243 { 245 {
  246 + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str());
  247 + validClientId = false;
  248 + }
  249 + else if (!clean_session && client_id.empty())
  250 + {
  251 + logger->logf(LOG_ERR, "ClientID empty and clean session 0, which is incompatible");
  252 + validClientId = false;
  253 + }
  254 + else if (protocolVersion < ProtocolVersion::Mqtt311 && client_id.empty())
  255 + {
  256 + logger->logf(LOG_ERR, "Empty clientID. Connect with protocol 3.1.1 or higher to have one generated securely.");
  257 + validClientId = false;
  258 + }
  259 +
  260 + if (!validClientId)
  261 + {
244 ConnAck connAck(ConnAckReturnCodes::ClientIdRejected); 262 ConnAck connAck(ConnAckReturnCodes::ClientIdRejected);
245 MqttPacket response(connAck); 263 MqttPacket response(connAck);
246 sender->setReadyForDisconnect(); 264 sender->setReadyForDisconnect();
247 sender->writeMqttPacket(response); 265 sender->writeMqttPacket(response);
248 - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str());  
249 return; 266 return;
250 } 267 }
251 268
252 - sender->setClientProperties(client_id, username, true, keep_alive, clean_session); 269 + if (client_id.empty())
  270 + {
  271 + client_id = getSecureRandomString(23);
  272 + }
  273 +
  274 + sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session);
253 sender->setWill(will_topic, will_payload, will_retain, will_qos); 275 sender->setWill(will_topic, will_payload, will_retain, will_qos);
254 276
255 if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) 277 if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success)
utils.cpp
1 #include "utils.h" 1 #include "utils.h"
2 2
3 #include "sys/time.h" 3 #include "sys/time.h"
4 - 4 +#include "sys/random.h"
5 #include <algorithm> 5 #include <algorithm>
6 6
  7 +#include "exceptions.h"
  8 +
7 std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) 9 std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts)
8 { 10 {
9 std::list<std::string> list; 11 std::list<std::string> list;
@@ -178,3 +180,26 @@ int64_t currentMSecsSinceEpoch() @@ -178,3 +180,26 @@ int64_t currentMSecsSinceEpoch()
178 int64_t milliseconds = te.tv_sec*1000LL + te.tv_usec/1000; 180 int64_t milliseconds = te.tv_sec*1000LL + te.tv_usec/1000;
179 return milliseconds; 181 return milliseconds;
180 } 182 }
  183 +
  184 +std::string getSecureRandomString(const size_t len)
  185 +{
  186 + std::vector<char> buf(len);
  187 + size_t actual_len = getrandom(buf.data(), len, 0);
  188 +
  189 + if (actual_len < 0 || actual_len != len)
  190 + {
  191 + throw std::runtime_error("Error requesting random data");
  192 + }
  193 +
  194 + const std::string possibleCharacters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrtsuvwxyz1234567890");
  195 + const int possibleCharactersCount = possibleCharacters.length();
  196 +
  197 + std::string randomString;
  198 + for(const unsigned char &c : buf)
  199 + {
  200 + unsigned int index = c % possibleCharactersCount;
  201 + char nextChar = possibleCharacters.at(index);
  202 + randomString.push_back(nextChar);
  203 + }
  204 + return randomString;
  205 +}
@@ -38,5 +38,6 @@ void trim(std::string &amp;s); @@ -38,5 +38,6 @@ void trim(std::string &amp;s);
38 bool startsWith(const std::string &s, const std::string &needle); 38 bool startsWith(const std::string &s, const std::string &needle);
39 39
40 int64_t currentMSecsSinceEpoch(); 40 int64_t currentMSecsSinceEpoch();
  41 +std::string getSecureRandomString(const size_t len);
41 42
42 #endif // UTILS_H 43 #endif // UTILS_H