diff --git a/client.cpp b/client.cpp index d4d1961..ccd4c4d 100644 --- a/client.cpp +++ b/client.cpp @@ -100,6 +100,9 @@ void Client::writeMqttPacket(const MqttPacket &packet) assert(wwi >= static_cast(packet.getSizeIncludingNonPresentHeader())); assert(wwi <= static_cast(writeBufsize)); + if (packet.packetType == PacketType::DISCONNECT) + setReadyForDisconnect(); + setReadyForWriting(true); } diff --git a/client.h b/client.h index 379caf4..f379895 100644 --- a/client.h +++ b/client.h @@ -36,6 +36,7 @@ class Client bool connectPacketSeen = false; bool readyForWriting = false; bool readyForReading = true; + bool disconnectWhenBytesWritten = false; std::string clientid; std::string username; @@ -123,6 +124,10 @@ public: void writePingResp(); void writeMqttPacket(const MqttPacket &packet); bool writeBufIntoFd(); + bool readyForDisconnecting() const { return disconnectWhenBytesWritten && wwi == wri && wwi == 0; } + + // Do this before calling an action that makes this client ready for writing, so that the EPOLLOUT will handle it. + void setReadyForDisconnect() { disconnectWhenBytesWritten = true; } std::string repr(); diff --git a/mainapp.cpp b/mainapp.cpp index 340b9a6..e8284ce 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -74,6 +74,9 @@ void do_thread_work(ThreadData *threadData) { if (!client->writeBufIntoFd()) threadData->removeClient(client); + + if (client->readyForDisconnecting()) + threadData->removeClient(client); } } catch(std::exception &ex) diff --git a/mqttpacket.cpp b/mqttpacket.cpp index a9eae05..277fd57 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -128,7 +128,12 @@ void MqttPacket::handleConnect() } else { - throw ProtocolError("Only MQTT 3.1 and 3.1.1 supported."); + ConnAck connAck(ConnAckReturnCodes::UnacceptableProtocolVersion); + MqttPacket response(connAck); + sender->setReadyForDisconnect(); + sender->writeMqttPacket(response); + std::cout << "Rejecting because of invalid protocol version: " << sender->repr() << std::endl; + return; } char flagByte = readByte(); @@ -174,7 +179,28 @@ void MqttPacket::handleConnect() password = std::string(readBytes(password_length), password_length); } - // TODO: validate UTF8 encoded username/password. + // The specs don't really say what to do when client id not UTF8, so including here. + if (!isValidUtf8(client_id) || !isValidUtf8(username) || !isValidUtf8(password)) + { + ConnAck connAck(ConnAckReturnCodes::MalformedUsernameOrPassword); + MqttPacket response(connAck); + sender->setReadyForDisconnect(); + sender->writeMqttPacket(response); + std::cout << "Client ID, username or passwords has invalid UTF8: " << sender->repr() << std::endl; + return; + } + + // In case the client_id ever appears in topics. + // TODO: make setting? + if (strContains(client_id, "+") || strContains(client_id, "#")) + { + ConnAck connAck(ConnAckReturnCodes::ClientIdRejected); + MqttPacket response(connAck); + sender->setReadyForDisconnect(); + sender->writeMqttPacket(response); + std::cout << "ClientID has + or # in the id: " << sender->repr() << std::endl; + return; + } sender->setClientProperties(client_id, username, true, keep_alive); sender->setWill(will_topic, will_payload, will_retain, will_qos); diff --git a/utils.cpp b/utils.cpp index bea1b4c..46076bb 100644 --- a/utils.cpp +++ b/utils.cpp @@ -106,3 +106,8 @@ bool isValidUtf8(const std::string &s) } return multibyte_remain == 0; } + +bool strContains(const std::string &s, const std::string &needle) +{ + return s.find(needle) != std::string::npos; +} diff --git a/utils.h b/utils.h index cd92dbb..66b14e0 100644 --- a/utils.h +++ b/utils.h @@ -25,4 +25,6 @@ bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTo bool isValidUtf8(const std::string &s); +bool strContains(const std::string &s, const std::string &needle); + #endif // UTILS_H