Commit 52cdff03ea3371df7e8b519aa728b35186fc4f38
1 parent
4eaca771
Some more compliant connect error handling
Showing
6 changed files
with
46 additions
and
2 deletions
client.cpp
| @@ -100,6 +100,9 @@ void Client::writeMqttPacket(const MqttPacket &packet) | @@ -100,6 +100,9 @@ void Client::writeMqttPacket(const MqttPacket &packet) | ||
| 100 | assert(wwi >= static_cast<int>(packet.getSizeIncludingNonPresentHeader())); | 100 | assert(wwi >= static_cast<int>(packet.getSizeIncludingNonPresentHeader())); |
| 101 | assert(wwi <= static_cast<int>(writeBufsize)); | 101 | assert(wwi <= static_cast<int>(writeBufsize)); |
| 102 | 102 | ||
| 103 | + if (packet.packetType == PacketType::DISCONNECT) | ||
| 104 | + setReadyForDisconnect(); | ||
| 105 | + | ||
| 103 | setReadyForWriting(true); | 106 | setReadyForWriting(true); |
| 104 | } | 107 | } |
| 105 | 108 |
client.h
| @@ -36,6 +36,7 @@ class Client | @@ -36,6 +36,7 @@ class Client | ||
| 36 | bool connectPacketSeen = false; | 36 | bool connectPacketSeen = false; |
| 37 | bool readyForWriting = false; | 37 | bool readyForWriting = false; |
| 38 | bool readyForReading = true; | 38 | bool readyForReading = true; |
| 39 | + bool disconnectWhenBytesWritten = false; | ||
| 39 | 40 | ||
| 40 | std::string clientid; | 41 | std::string clientid; |
| 41 | std::string username; | 42 | std::string username; |
| @@ -123,6 +124,10 @@ public: | @@ -123,6 +124,10 @@ public: | ||
| 123 | void writePingResp(); | 124 | void writePingResp(); |
| 124 | void writeMqttPacket(const MqttPacket &packet); | 125 | void writeMqttPacket(const MqttPacket &packet); |
| 125 | bool writeBufIntoFd(); | 126 | bool writeBufIntoFd(); |
| 127 | + bool readyForDisconnecting() const { return disconnectWhenBytesWritten && wwi == wri && wwi == 0; } | ||
| 128 | + | ||
| 129 | + // Do this before calling an action that makes this client ready for writing, so that the EPOLLOUT will handle it. | ||
| 130 | + void setReadyForDisconnect() { disconnectWhenBytesWritten = true; } | ||
| 126 | 131 | ||
| 127 | std::string repr(); | 132 | std::string repr(); |
| 128 | 133 |
mainapp.cpp
| @@ -74,6 +74,9 @@ void do_thread_work(ThreadData *threadData) | @@ -74,6 +74,9 @@ void do_thread_work(ThreadData *threadData) | ||
| 74 | { | 74 | { |
| 75 | if (!client->writeBufIntoFd()) | 75 | if (!client->writeBufIntoFd()) |
| 76 | threadData->removeClient(client); | 76 | threadData->removeClient(client); |
| 77 | + | ||
| 78 | + if (client->readyForDisconnecting()) | ||
| 79 | + threadData->removeClient(client); | ||
| 77 | } | 80 | } |
| 78 | } | 81 | } |
| 79 | catch(std::exception &ex) | 82 | catch(std::exception &ex) |
mqttpacket.cpp
| @@ -128,7 +128,12 @@ void MqttPacket::handleConnect() | @@ -128,7 +128,12 @@ void MqttPacket::handleConnect() | ||
| 128 | } | 128 | } |
| 129 | else | 129 | else |
| 130 | { | 130 | { |
| 131 | - throw ProtocolError("Only MQTT 3.1 and 3.1.1 supported."); | 131 | + ConnAck connAck(ConnAckReturnCodes::UnacceptableProtocolVersion); |
| 132 | + MqttPacket response(connAck); | ||
| 133 | + sender->setReadyForDisconnect(); | ||
| 134 | + sender->writeMqttPacket(response); | ||
| 135 | + std::cout << "Rejecting because of invalid protocol version: " << sender->repr() << std::endl; | ||
| 136 | + return; | ||
| 132 | } | 137 | } |
| 133 | 138 | ||
| 134 | char flagByte = readByte(); | 139 | char flagByte = readByte(); |
| @@ -174,7 +179,28 @@ void MqttPacket::handleConnect() | @@ -174,7 +179,28 @@ void MqttPacket::handleConnect() | ||
| 174 | password = std::string(readBytes(password_length), password_length); | 179 | password = std::string(readBytes(password_length), password_length); |
| 175 | } | 180 | } |
| 176 | 181 | ||
| 177 | - // TODO: validate UTF8 encoded username/password. | 182 | + // The specs don't really say what to do when client id not UTF8, so including here. |
| 183 | + if (!isValidUtf8(client_id) || !isValidUtf8(username) || !isValidUtf8(password)) | ||
| 184 | + { | ||
| 185 | + ConnAck connAck(ConnAckReturnCodes::MalformedUsernameOrPassword); | ||
| 186 | + MqttPacket response(connAck); | ||
| 187 | + sender->setReadyForDisconnect(); | ||
| 188 | + sender->writeMqttPacket(response); | ||
| 189 | + std::cout << "Client ID, username or passwords has invalid UTF8: " << sender->repr() << std::endl; | ||
| 190 | + return; | ||
| 191 | + } | ||
| 192 | + | ||
| 193 | + // In case the client_id ever appears in topics. | ||
| 194 | + // TODO: make setting? | ||
| 195 | + if (strContains(client_id, "+") || strContains(client_id, "#")) | ||
| 196 | + { | ||
| 197 | + ConnAck connAck(ConnAckReturnCodes::ClientIdRejected); | ||
| 198 | + MqttPacket response(connAck); | ||
| 199 | + sender->setReadyForDisconnect(); | ||
| 200 | + sender->writeMqttPacket(response); | ||
| 201 | + std::cout << "ClientID has + or # in the id: " << sender->repr() << std::endl; | ||
| 202 | + return; | ||
| 203 | + } | ||
| 178 | 204 | ||
| 179 | sender->setClientProperties(client_id, username, true, keep_alive); | 205 | sender->setClientProperties(client_id, username, true, keep_alive); |
| 180 | sender->setWill(will_topic, will_payload, will_retain, will_qos); | 206 | sender->setWill(will_topic, will_payload, will_retain, will_qos); |
utils.cpp
| @@ -106,3 +106,8 @@ bool isValidUtf8(const std::string &s) | @@ -106,3 +106,8 @@ bool isValidUtf8(const std::string &s) | ||
| 106 | } | 106 | } |
| 107 | return multibyte_remain == 0; | 107 | return multibyte_remain == 0; |
| 108 | } | 108 | } |
| 109 | + | ||
| 110 | +bool strContains(const std::string &s, const std::string &needle) | ||
| 111 | +{ | ||
| 112 | + return s.find(needle) != std::string::npos; | ||
| 113 | +} |
utils.h
| @@ -25,4 +25,6 @@ bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTo | @@ -25,4 +25,6 @@ bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTo | ||
| 25 | 25 | ||
| 26 | bool isValidUtf8(const std::string &s); | 26 | bool isValidUtf8(const std::string &s); |
| 27 | 27 | ||
| 28 | +bool strContains(const std::string &s, const std::string &needle); | ||
| 29 | + | ||
| 28 | #endif // UTILS_H | 30 | #endif // UTILS_H |