Commit 52cdff03ea3371df7e8b519aa728b35186fc4f38

Authored by Wiebe Cazemier
1 parent 4eaca771

Some more compliant connect error handling

client.cpp
... ... @@ -100,6 +100,9 @@ void Client::writeMqttPacket(const MqttPacket &packet)
100 100 assert(wwi >= static_cast<int>(packet.getSizeIncludingNonPresentHeader()));
101 101 assert(wwi <= static_cast<int>(writeBufsize));
102 102  
  103 + if (packet.packetType == PacketType::DISCONNECT)
  104 + setReadyForDisconnect();
  105 +
103 106 setReadyForWriting(true);
104 107 }
105 108  
... ...
client.h
... ... @@ -36,6 +36,7 @@ class Client
36 36 bool connectPacketSeen = false;
37 37 bool readyForWriting = false;
38 38 bool readyForReading = true;
  39 + bool disconnectWhenBytesWritten = false;
39 40  
40 41 std::string clientid;
41 42 std::string username;
... ... @@ -123,6 +124,10 @@ public:
123 124 void writePingResp();
124 125 void writeMqttPacket(const MqttPacket &packet);
125 126 bool writeBufIntoFd();
  127 + bool readyForDisconnecting() const { return disconnectWhenBytesWritten && wwi == wri && wwi == 0; }
  128 +
  129 + // Do this before calling an action that makes this client ready for writing, so that the EPOLLOUT will handle it.
  130 + void setReadyForDisconnect() { disconnectWhenBytesWritten = true; }
126 131  
127 132 std::string repr();
128 133  
... ...
mainapp.cpp
... ... @@ -74,6 +74,9 @@ void do_thread_work(ThreadData *threadData)
74 74 {
75 75 if (!client->writeBufIntoFd())
76 76 threadData->removeClient(client);
  77 +
  78 + if (client->readyForDisconnecting())
  79 + threadData->removeClient(client);
77 80 }
78 81 }
79 82 catch(std::exception &ex)
... ...
mqttpacket.cpp
... ... @@ -128,7 +128,12 @@ void MqttPacket::handleConnect()
128 128 }
129 129 else
130 130 {
131   - throw ProtocolError("Only MQTT 3.1 and 3.1.1 supported.");
  131 + ConnAck connAck(ConnAckReturnCodes::UnacceptableProtocolVersion);
  132 + MqttPacket response(connAck);
  133 + sender->setReadyForDisconnect();
  134 + sender->writeMqttPacket(response);
  135 + std::cout << "Rejecting because of invalid protocol version: " << sender->repr() << std::endl;
  136 + return;
132 137 }
133 138  
134 139 char flagByte = readByte();
... ... @@ -174,7 +179,28 @@ void MqttPacket::handleConnect()
174 179 password = std::string(readBytes(password_length), password_length);
175 180 }
176 181  
177   - // TODO: validate UTF8 encoded username/password.
  182 + // The specs don't really say what to do when client id not UTF8, so including here.
  183 + if (!isValidUtf8(client_id) || !isValidUtf8(username) || !isValidUtf8(password))
  184 + {
  185 + ConnAck connAck(ConnAckReturnCodes::MalformedUsernameOrPassword);
  186 + MqttPacket response(connAck);
  187 + sender->setReadyForDisconnect();
  188 + sender->writeMqttPacket(response);
  189 + std::cout << "Client ID, username or passwords has invalid UTF8: " << sender->repr() << std::endl;
  190 + return;
  191 + }
  192 +
  193 + // In case the client_id ever appears in topics.
  194 + // TODO: make setting?
  195 + if (strContains(client_id, "+") || strContains(client_id, "#"))
  196 + {
  197 + ConnAck connAck(ConnAckReturnCodes::ClientIdRejected);
  198 + MqttPacket response(connAck);
  199 + sender->setReadyForDisconnect();
  200 + sender->writeMqttPacket(response);
  201 + std::cout << "ClientID has + or # in the id: " << sender->repr() << std::endl;
  202 + return;
  203 + }
178 204  
179 205 sender->setClientProperties(client_id, username, true, keep_alive);
180 206 sender->setWill(will_topic, will_payload, will_retain, will_qos);
... ...
utils.cpp
... ... @@ -106,3 +106,8 @@ bool isValidUtf8(const std::string &amp;s)
106 106 }
107 107 return multibyte_remain == 0;
108 108 }
  109 +
  110 +bool strContains(const std::string &s, const std::string &needle)
  111 +{
  112 + return s.find(needle) != std::string::npos;
  113 +}
... ...
... ... @@ -25,4 +25,6 @@ bool topicsMatch(const std::string &amp;subscribeTopic, const std::string &amp;publishTo
25 25  
26 26 bool isValidUtf8(const std::string &s);
27 27  
  28 +bool strContains(const std::string &s, const std::string &needle);
  29 +
28 30 #endif // UTILS_H
... ...