Commit 5f1d825927afe50335f72ef8572f68f05015bc90

Authored by Wiebe Cazemier
1 parent 4d957a3d

Add ability to send disconnect packets with code

exceptions.h
... ... @@ -22,10 +22,18 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
22 22 #include <stdexcept>
23 23 #include <sstream>
24 24  
  25 +#include "types.h"
  26 +
25 27 class ProtocolError : public std::runtime_error
26 28 {
27 29 public:
28   - ProtocolError(const std::string &msg) : std::runtime_error(msg) {}
  30 + const ReasonCodes reasonCode;
  31 +
  32 + ProtocolError(const std::string &msg, ReasonCodes reasonCode = ReasonCodes::UnspecifiedError) : std::runtime_error(msg),
  33 + reasonCode(reasonCode)
  34 + {
  35 +
  36 + }
29 37 };
30 38  
31 39 class NotImplementedException : public std::runtime_error
... ...
mqttpacket.cpp
... ... @@ -34,7 +34,7 @@ MqttPacket::MqttPacket(CirBuf &amp;buf, size_t packet_len, size_t fixed_header_lengt
34 34  
35 35 if (packet_len > sender->getMaxIncomingPacketSize())
36 36 {
37   - throw ProtocolError("Incoming packet size exceeded. TODO: DISCONNECT WITH CODE 0x95");
  37 + throw ProtocolError("Incoming packet size exceeded.", ReasonCodes::PacketTooLarge);
38 38 }
39 39  
40 40 buf.read(bites.data(), packet_len);
... ... @@ -204,6 +204,19 @@ MqttPacket::MqttPacket(const PubResponse &amp;pubAck) :
204 204 }
205 205 }
206 206  
  207 +MqttPacket::MqttPacket(const Disconnect &disconnect) :
  208 + bites(disconnect.getLengthWithoutFixedHeader())
  209 +{
  210 + this->protocolVersion = ProtocolVersion::Mqtt5;
  211 +
  212 + packetType = PacketType::DISCONNECT;
  213 + first_byte = static_cast<char>(packetType) << 4;
  214 +
  215 + writeByte(static_cast<uint8_t>(disconnect.reasonCode));
  216 + writeProperties(disconnect.propertyBuilder);
  217 + calculateRemainingLength();
  218 +}
  219 +
207 220 void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
208 221 {
209 222 while (buf.usedBytes() >= MQTT_HEADER_LENGH)
... ... @@ -630,6 +643,9 @@ void MqttPacket::handleConnect()
630 643  
631 644 void MqttPacket::handleDisconnect()
632 645 {
  646 + if (first_byte & 0b1111)
  647 + throw ProtocolError("Disconnect packet first 4 bits should be 0.", ReasonCodes::MalformedPacket);
  648 +
633 649 logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str());
634 650 sender->setDisconnectReason("MQTT Disconnect received.");
635 651 sender->markAsDisconnecting();
... ...
mqttpacket.h
... ... @@ -100,6 +100,7 @@ public:
100 100 MqttPacket(const UnsubAck &unsubAck);
101 101 MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish);
102 102 MqttPacket(const PubResponse &pubAck);
  103 + MqttPacket(const Disconnect &disconnect);
103 104  
104 105 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
105 106  
... ...
threadloop.cpp
... ... @@ -135,6 +135,22 @@ void do_thread_work(ThreadData *threadData)
135 135 }
136 136 }
137 137 }
  138 + catch (ProtocolError &ex)
  139 + {
  140 + client->setDisconnectReason(ex.what());
  141 + if (client->getProtocolVersion() >= ProtocolVersion::Mqtt5 && client->hasConnectPacketSeen())
  142 + {
  143 + Disconnect d(client->getProtocolVersion(), ex.reasonCode);
  144 + MqttPacket p(d);
  145 + client->writeMqttPacket(p);
  146 + client->setReadyForDisconnect();
  147 + }
  148 + else
  149 + {
  150 + logger->logf(LOG_ERR, "Protocol error: %s. Removing client.", ex.what());
  151 + threadData->removeClient(client);
  152 + }
  153 + }
138 154 catch(std::exception &ex)
139 155 {
140 156 client->setDisconnectReason(ex.what());
... ...
types.cpp
... ... @@ -289,3 +289,27 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const
289 289  
290 290 return result;
291 291 }
  292 +
  293 +Disconnect::Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code) :
  294 + reasonCode(reason_code)
  295 +{
  296 + assert(protVersion >= ProtocolVersion::Mqtt5);
  297 +
  298 +
  299 +}
  300 +
  301 +size_t Disconnect::getLengthWithoutFixedHeader() const
  302 +{
  303 + size_t result = 1;
  304 + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1;
  305 + result += proplen;
  306 + return result;
  307 +}
  308 +
  309 +
  310 +
  311 +
  312 +
  313 +
  314 +
  315 +
... ...
... ... @@ -249,4 +249,13 @@ public:
249 249 bool needsReasonCode() const;
250 250 };
251 251  
  252 +class Disconnect
  253 +{
  254 +public:
  255 + ReasonCodes reasonCode;
  256 + std::shared_ptr<Mqtt5PropertyBuilder> propertyBuilder;
  257 + Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code);
  258 + size_t getLengthWithoutFixedHeader() const;
  259 +};
  260 +
252 261 #endif // TYPES_H
... ...