Commit 1f3786991edf45e47fabbbabd7a77cba502d5ab9

Authored by Wiebe Cazemier
1 parent d0473da1

Deal with clientid protocol version-appropriate

And include some extra error conditions.
client.cpp
... ... @@ -565,8 +565,9 @@ bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_
565 565 return true;
566 566 }
567 567  
568   -void Client::setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession)
  568 +void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession)
569 569 {
  570 + this->protocolVersion = protocolVersion;
570 571 this->clientid = clientId;
571 572 this->username = username;
572 573 this->connectPacketSeen = connectPacketSeen;
... ...
client.h
... ... @@ -16,6 +16,7 @@
16 16 #include "mqttpacket.h"
17 17 #include "exceptions.h"
18 18 #include "cirbuf.h"
  19 +#include "types.h"
19 20  
20 21 #include <openssl/ssl.h>
21 22 #include <openssl/err.h>
... ... @@ -61,6 +62,7 @@ class Client
61 62 IncompleteSslWrite incompleteSslWrite;
62 63 bool sslReadWantsWrite = false;
63 64 bool sslWriteWantsRead = false;
  65 + ProtocolVersion protocolVersion = ProtocolVersion::None;
64 66  
65 67 CirBuf readbuf;
66 68 uint8_t readBufIsZeroCount = 0;
... ... @@ -115,7 +117,7 @@ public:
115 117 ssize_t readWrap(int fd, void *buf, size_t nbytes, IoWrapResult *error);
116 118 bool readFdIntoBuffer();
117 119 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender);
118   - void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
  120 + void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
119 121 void setWill(const std::string &topic, const std::string &payload, bool retain, char qos);
120 122 void setAuthenticated(bool value) { authenticated = value;}
121 123 bool getAuthenticated() { return authenticated; }
... ...
mqttpacket.cpp
... ... @@ -237,19 +237,41 @@ void MqttPacket::handleConnect()
237 237 return;
238 238 }
239 239  
240   - // In case the client_id ever appears in topics.
  240 + bool validClientId = true;
  241 +
  242 + // Check for wildcard chars in case the client_id ever appears in topics.
241 243 // TODO: make setting?
242 244 if (strContains(client_id, "+") || strContains(client_id, "#"))
243 245 {
  246 + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str());
  247 + validClientId = false;
  248 + }
  249 + else if (!clean_session && client_id.empty())
  250 + {
  251 + logger->logf(LOG_ERR, "ClientID empty and clean session 0, which is incompatible");
  252 + validClientId = false;
  253 + }
  254 + else if (protocolVersion < ProtocolVersion::Mqtt311 && client_id.empty())
  255 + {
  256 + logger->logf(LOG_ERR, "Empty clientID. Connect with protocol 3.1.1 or higher to have one generated securely.");
  257 + validClientId = false;
  258 + }
  259 +
  260 + if (!validClientId)
  261 + {
244 262 ConnAck connAck(ConnAckReturnCodes::ClientIdRejected);
245 263 MqttPacket response(connAck);
246 264 sender->setReadyForDisconnect();
247 265 sender->writeMqttPacket(response);
248   - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id:", client_id.c_str());
249 266 return;
250 267 }
251 268  
252   - sender->setClientProperties(client_id, username, true, keep_alive, clean_session);
  269 + if (client_id.empty())
  270 + {
  271 + client_id = getSecureRandomString(23);
  272 + }
  273 +
  274 + sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session);
253 275 sender->setWill(will_topic, will_payload, will_retain, will_qos);
254 276  
255 277 if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success)
... ...
utils.cpp
1 1 #include "utils.h"
2 2  
3 3 #include "sys/time.h"
4   -
  4 +#include "sys/random.h"
5 5 #include <algorithm>
6 6  
  7 +#include "exceptions.h"
  8 +
7 9 std::list<std::__cxx11::string> split(const std::string &input, const char sep, size_t max, bool keep_empty_parts)
8 10 {
9 11 std::list<std::string> list;
... ... @@ -178,3 +180,26 @@ int64_t currentMSecsSinceEpoch()
178 180 int64_t milliseconds = te.tv_sec*1000LL + te.tv_usec/1000;
179 181 return milliseconds;
180 182 }
  183 +
  184 +std::string getSecureRandomString(const size_t len)
  185 +{
  186 + std::vector<char> buf(len);
  187 + size_t actual_len = getrandom(buf.data(), len, 0);
  188 +
  189 + if (actual_len < 0 || actual_len != len)
  190 + {
  191 + throw std::runtime_error("Error requesting random data");
  192 + }
  193 +
  194 + const std::string possibleCharacters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrtsuvwxyz1234567890");
  195 + const int possibleCharactersCount = possibleCharacters.length();
  196 +
  197 + std::string randomString;
  198 + for(const unsigned char &c : buf)
  199 + {
  200 + unsigned int index = c % possibleCharactersCount;
  201 + char nextChar = possibleCharacters.at(index);
  202 + randomString.push_back(nextChar);
  203 + }
  204 + return randomString;
  205 +}
... ...
... ... @@ -38,5 +38,6 @@ void trim(std::string &amp;s);
38 38 bool startsWith(const std::string &s, const std::string &needle);
39 39  
40 40 int64_t currentMSecsSinceEpoch();
  41 +std::string getSecureRandomString(const size_t len);
41 42  
42 43 #endif // UTILS_H
... ...