diff --git a/exceptions.h b/exceptions.h
index 72658f2..c8a6405 100644
--- a/exceptions.h
+++ b/exceptions.h
@@ -22,10 +22,18 @@ License along with FlashMQ. If not, see .
#include
#include
+#include "types.h"
+
class ProtocolError : public std::runtime_error
{
public:
- ProtocolError(const std::string &msg) : std::runtime_error(msg) {}
+ const ReasonCodes reasonCode;
+
+ ProtocolError(const std::string &msg, ReasonCodes reasonCode = ReasonCodes::UnspecifiedError) : std::runtime_error(msg),
+ reasonCode(reasonCode)
+ {
+
+ }
};
class NotImplementedException : public std::runtime_error
diff --git a/mqttpacket.cpp b/mqttpacket.cpp
index 6417ed3..3cc6773 100644
--- a/mqttpacket.cpp
+++ b/mqttpacket.cpp
@@ -34,7 +34,7 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt
if (packet_len > sender->getMaxIncomingPacketSize())
{
- throw ProtocolError("Incoming packet size exceeded. TODO: DISCONNECT WITH CODE 0x95");
+ throw ProtocolError("Incoming packet size exceeded.", ReasonCodes::PacketTooLarge);
}
buf.read(bites.data(), packet_len);
@@ -204,6 +204,19 @@ MqttPacket::MqttPacket(const PubResponse &pubAck) :
}
}
+MqttPacket::MqttPacket(const Disconnect &disconnect) :
+ bites(disconnect.getLengthWithoutFixedHeader())
+{
+ this->protocolVersion = ProtocolVersion::Mqtt5;
+
+ packetType = PacketType::DISCONNECT;
+ first_byte = static_cast(packetType) << 4;
+
+ writeByte(static_cast(disconnect.reasonCode));
+ writeProperties(disconnect.propertyBuilder);
+ calculateRemainingLength();
+}
+
void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender)
{
while (buf.usedBytes() >= MQTT_HEADER_LENGH)
@@ -630,6 +643,9 @@ void MqttPacket::handleConnect()
void MqttPacket::handleDisconnect()
{
+ if (first_byte & 0b1111)
+ throw ProtocolError("Disconnect packet first 4 bits should be 0.", ReasonCodes::MalformedPacket);
+
logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str());
sender->setDisconnectReason("MQTT Disconnect received.");
sender->markAsDisconnecting();
diff --git a/mqttpacket.h b/mqttpacket.h
index 5b6cb15..4f1884e 100644
--- a/mqttpacket.h
+++ b/mqttpacket.h
@@ -100,6 +100,7 @@ public:
MqttPacket(const UnsubAck &unsubAck);
MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish);
MqttPacket(const PubResponse &pubAck);
+ MqttPacket(const Disconnect &disconnect);
static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender);
diff --git a/threadloop.cpp b/threadloop.cpp
index 4800370..a19b349 100644
--- a/threadloop.cpp
+++ b/threadloop.cpp
@@ -135,6 +135,22 @@ void do_thread_work(ThreadData *threadData)
}
}
}
+ catch (ProtocolError &ex)
+ {
+ client->setDisconnectReason(ex.what());
+ if (client->getProtocolVersion() >= ProtocolVersion::Mqtt5 && client->hasConnectPacketSeen())
+ {
+ Disconnect d(client->getProtocolVersion(), ex.reasonCode);
+ MqttPacket p(d);
+ client->writeMqttPacket(p);
+ client->setReadyForDisconnect();
+ }
+ else
+ {
+ logger->logf(LOG_ERR, "Protocol error: %s. Removing client.", ex.what());
+ threadData->removeClient(client);
+ }
+ }
catch(std::exception &ex)
{
client->setDisconnectReason(ex.what());
diff --git a/types.cpp b/types.cpp
index 33dddea..cc125d1 100644
--- a/types.cpp
+++ b/types.cpp
@@ -289,3 +289,27 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const
return result;
}
+
+Disconnect::Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code) :
+ reasonCode(reason_code)
+{
+ assert(protVersion >= ProtocolVersion::Mqtt5);
+
+
+}
+
+size_t Disconnect::getLengthWithoutFixedHeader() const
+{
+ size_t result = 1;
+ const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1;
+ result += proplen;
+ return result;
+}
+
+
+
+
+
+
+
+
diff --git a/types.h b/types.h
index 2c7265c..b68b1f4 100644
--- a/types.h
+++ b/types.h
@@ -249,4 +249,13 @@ public:
bool needsReasonCode() const;
};
+class Disconnect
+{
+public:
+ ReasonCodes reasonCode;
+ std::shared_ptr propertyBuilder;
+ Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code);
+ size_t getLengthWithoutFixedHeader() const;
+};
+
#endif // TYPES_H