diff --git a/exceptions.h b/exceptions.h index 72658f2..c8a6405 100644 --- a/exceptions.h +++ b/exceptions.h @@ -22,10 +22,18 @@ License along with FlashMQ. If not, see . #include #include +#include "types.h" + class ProtocolError : public std::runtime_error { public: - ProtocolError(const std::string &msg) : std::runtime_error(msg) {} + const ReasonCodes reasonCode; + + ProtocolError(const std::string &msg, ReasonCodes reasonCode = ReasonCodes::UnspecifiedError) : std::runtime_error(msg), + reasonCode(reasonCode) + { + + } }; class NotImplementedException : public std::runtime_error diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 6417ed3..3cc6773 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -34,7 +34,7 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt if (packet_len > sender->getMaxIncomingPacketSize()) { - throw ProtocolError("Incoming packet size exceeded. TODO: DISCONNECT WITH CODE 0x95"); + throw ProtocolError("Incoming packet size exceeded.", ReasonCodes::PacketTooLarge); } buf.read(bites.data(), packet_len); @@ -204,6 +204,19 @@ MqttPacket::MqttPacket(const PubResponse &pubAck) : } } +MqttPacket::MqttPacket(const Disconnect &disconnect) : + bites(disconnect.getLengthWithoutFixedHeader()) +{ + this->protocolVersion = ProtocolVersion::Mqtt5; + + packetType = PacketType::DISCONNECT; + first_byte = static_cast(packetType) << 4; + + writeByte(static_cast(disconnect.reasonCode)); + writeProperties(disconnect.propertyBuilder); + calculateRemainingLength(); +} + void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender) { while (buf.usedBytes() >= MQTT_HEADER_LENGH) @@ -630,6 +643,9 @@ void MqttPacket::handleConnect() void MqttPacket::handleDisconnect() { + if (first_byte & 0b1111) + throw ProtocolError("Disconnect packet first 4 bits should be 0.", ReasonCodes::MalformedPacket); + logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str()); sender->setDisconnectReason("MQTT Disconnect received."); sender->markAsDisconnecting(); diff --git a/mqttpacket.h b/mqttpacket.h index 5b6cb15..4f1884e 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -100,6 +100,7 @@ public: MqttPacket(const UnsubAck &unsubAck); MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish); MqttPacket(const PubResponse &pubAck); + MqttPacket(const Disconnect &disconnect); static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender); diff --git a/threadloop.cpp b/threadloop.cpp index 4800370..a19b349 100644 --- a/threadloop.cpp +++ b/threadloop.cpp @@ -135,6 +135,22 @@ void do_thread_work(ThreadData *threadData) } } } + catch (ProtocolError &ex) + { + client->setDisconnectReason(ex.what()); + if (client->getProtocolVersion() >= ProtocolVersion::Mqtt5 && client->hasConnectPacketSeen()) + { + Disconnect d(client->getProtocolVersion(), ex.reasonCode); + MqttPacket p(d); + client->writeMqttPacket(p); + client->setReadyForDisconnect(); + } + else + { + logger->logf(LOG_ERR, "Protocol error: %s. Removing client.", ex.what()); + threadData->removeClient(client); + } + } catch(std::exception &ex) { client->setDisconnectReason(ex.what()); diff --git a/types.cpp b/types.cpp index 33dddea..cc125d1 100644 --- a/types.cpp +++ b/types.cpp @@ -289,3 +289,27 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const return result; } + +Disconnect::Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code) : + reasonCode(reason_code) +{ + assert(protVersion >= ProtocolVersion::Mqtt5); + + +} + +size_t Disconnect::getLengthWithoutFixedHeader() const +{ + size_t result = 1; + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; + result += proplen; + return result; +} + + + + + + + + diff --git a/types.h b/types.h index 2c7265c..b68b1f4 100644 --- a/types.h +++ b/types.h @@ -249,4 +249,13 @@ public: bool needsReasonCode() const; }; +class Disconnect +{ +public: + ReasonCodes reasonCode; + std::shared_ptr propertyBuilder; + Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code); + size_t getLengthWithoutFixedHeader() const; +}; + #endif // TYPES_H