Commit 52cdff03ea3371df7e8b519aa728b35186fc4f38

Authored by Wiebe Cazemier
1 parent 4eaca771

Some more compliant connect error handling

client.cpp
@@ -100,6 +100,9 @@ void Client::writeMqttPacket(const MqttPacket &packet) @@ -100,6 +100,9 @@ void Client::writeMqttPacket(const MqttPacket &packet)
100 assert(wwi >= static_cast<int>(packet.getSizeIncludingNonPresentHeader())); 100 assert(wwi >= static_cast<int>(packet.getSizeIncludingNonPresentHeader()));
101 assert(wwi <= static_cast<int>(writeBufsize)); 101 assert(wwi <= static_cast<int>(writeBufsize));
102 102
  103 + if (packet.packetType == PacketType::DISCONNECT)
  104 + setReadyForDisconnect();
  105 +
103 setReadyForWriting(true); 106 setReadyForWriting(true);
104 } 107 }
105 108
client.h
@@ -36,6 +36,7 @@ class Client @@ -36,6 +36,7 @@ class Client
36 bool connectPacketSeen = false; 36 bool connectPacketSeen = false;
37 bool readyForWriting = false; 37 bool readyForWriting = false;
38 bool readyForReading = true; 38 bool readyForReading = true;
  39 + bool disconnectWhenBytesWritten = false;
39 40
40 std::string clientid; 41 std::string clientid;
41 std::string username; 42 std::string username;
@@ -123,6 +124,10 @@ public: @@ -123,6 +124,10 @@ public:
123 void writePingResp(); 124 void writePingResp();
124 void writeMqttPacket(const MqttPacket &packet); 125 void writeMqttPacket(const MqttPacket &packet);
125 bool writeBufIntoFd(); 126 bool writeBufIntoFd();
  127 + bool readyForDisconnecting() const { return disconnectWhenBytesWritten && wwi == wri && wwi == 0; }
  128 +
  129 + // Do this before calling an action that makes this client ready for writing, so that the EPOLLOUT will handle it.
  130 + void setReadyForDisconnect() { disconnectWhenBytesWritten = true; }
126 131
127 std::string repr(); 132 std::string repr();
128 133
mainapp.cpp
@@ -74,6 +74,9 @@ void do_thread_work(ThreadData *threadData) @@ -74,6 +74,9 @@ void do_thread_work(ThreadData *threadData)
74 { 74 {
75 if (!client->writeBufIntoFd()) 75 if (!client->writeBufIntoFd())
76 threadData->removeClient(client); 76 threadData->removeClient(client);
  77 +
  78 + if (client->readyForDisconnecting())
  79 + threadData->removeClient(client);
77 } 80 }
78 } 81 }
79 catch(std::exception &ex) 82 catch(std::exception &ex)
mqttpacket.cpp
@@ -128,7 +128,12 @@ void MqttPacket::handleConnect() @@ -128,7 +128,12 @@ void MqttPacket::handleConnect()
128 } 128 }
129 else 129 else
130 { 130 {
131 - throw ProtocolError("Only MQTT 3.1 and 3.1.1 supported."); 131 + ConnAck connAck(ConnAckReturnCodes::UnacceptableProtocolVersion);
  132 + MqttPacket response(connAck);
  133 + sender->setReadyForDisconnect();
  134 + sender->writeMqttPacket(response);
  135 + std::cout << "Rejecting because of invalid protocol version: " << sender->repr() << std::endl;
  136 + return;
132 } 137 }
133 138
134 char flagByte = readByte(); 139 char flagByte = readByte();
@@ -174,7 +179,28 @@ void MqttPacket::handleConnect() @@ -174,7 +179,28 @@ void MqttPacket::handleConnect()
174 password = std::string(readBytes(password_length), password_length); 179 password = std::string(readBytes(password_length), password_length);
175 } 180 }
176 181
177 - // TODO: validate UTF8 encoded username/password. 182 + // The specs don't really say what to do when client id not UTF8, so including here.
  183 + if (!isValidUtf8(client_id) || !isValidUtf8(username) || !isValidUtf8(password))
  184 + {
  185 + ConnAck connAck(ConnAckReturnCodes::MalformedUsernameOrPassword);
  186 + MqttPacket response(connAck);
  187 + sender->setReadyForDisconnect();
  188 + sender->writeMqttPacket(response);
  189 + std::cout << "Client ID, username or passwords has invalid UTF8: " << sender->repr() << std::endl;
  190 + return;
  191 + }
  192 +
  193 + // In case the client_id ever appears in topics.
  194 + // TODO: make setting?
  195 + if (strContains(client_id, "+") || strContains(client_id, "#"))
  196 + {
  197 + ConnAck connAck(ConnAckReturnCodes::ClientIdRejected);
  198 + MqttPacket response(connAck);
  199 + sender->setReadyForDisconnect();
  200 + sender->writeMqttPacket(response);
  201 + std::cout << "ClientID has + or # in the id: " << sender->repr() << std::endl;
  202 + return;
  203 + }
178 204
179 sender->setClientProperties(client_id, username, true, keep_alive); 205 sender->setClientProperties(client_id, username, true, keep_alive);
180 sender->setWill(will_topic, will_payload, will_retain, will_qos); 206 sender->setWill(will_topic, will_payload, will_retain, will_qos);
utils.cpp
@@ -106,3 +106,8 @@ bool isValidUtf8(const std::string &amp;s) @@ -106,3 +106,8 @@ bool isValidUtf8(const std::string &amp;s)
106 } 106 }
107 return multibyte_remain == 0; 107 return multibyte_remain == 0;
108 } 108 }
  109 +
  110 +bool strContains(const std::string &s, const std::string &needle)
  111 +{
  112 + return s.find(needle) != std::string::npos;
  113 +}
@@ -25,4 +25,6 @@ bool topicsMatch(const std::string &amp;subscribeTopic, const std::string &amp;publishTo @@ -25,4 +25,6 @@ bool topicsMatch(const std::string &amp;subscribeTopic, const std::string &amp;publishTo
25 25
26 bool isValidUtf8(const std::string &s); 26 bool isValidUtf8(const std::string &s);
27 27
  28 +bool strContains(const std::string &s, const std::string &needle);
  29 +
28 #endif // UTILS_H 30 #endif // UTILS_H