Commit 5f1d825927afe50335f72ef8572f68f05015bc90
1 parent
4d957a3d
Add ability to send disconnect packets with code
Showing
6 changed files
with
76 additions
and
2 deletions
exceptions.h
| @@ -22,10 +22,18 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | @@ -22,10 +22,18 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | ||
| 22 | #include <stdexcept> | 22 | #include <stdexcept> |
| 23 | #include <sstream> | 23 | #include <sstream> |
| 24 | 24 | ||
| 25 | +#include "types.h" | ||
| 26 | + | ||
| 25 | class ProtocolError : public std::runtime_error | 27 | class ProtocolError : public std::runtime_error |
| 26 | { | 28 | { |
| 27 | public: | 29 | public: |
| 28 | - ProtocolError(const std::string &msg) : std::runtime_error(msg) {} | 30 | + const ReasonCodes reasonCode; |
| 31 | + | ||
| 32 | + ProtocolError(const std::string &msg, ReasonCodes reasonCode = ReasonCodes::UnspecifiedError) : std::runtime_error(msg), | ||
| 33 | + reasonCode(reasonCode) | ||
| 34 | + { | ||
| 35 | + | ||
| 36 | + } | ||
| 29 | }; | 37 | }; |
| 30 | 38 | ||
| 31 | class NotImplementedException : public std::runtime_error | 39 | class NotImplementedException : public std::runtime_error |
mqttpacket.cpp
| @@ -34,7 +34,7 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt | @@ -34,7 +34,7 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt | ||
| 34 | 34 | ||
| 35 | if (packet_len > sender->getMaxIncomingPacketSize()) | 35 | if (packet_len > sender->getMaxIncomingPacketSize()) |
| 36 | { | 36 | { |
| 37 | - throw ProtocolError("Incoming packet size exceeded. TODO: DISCONNECT WITH CODE 0x95"); | 37 | + throw ProtocolError("Incoming packet size exceeded.", ReasonCodes::PacketTooLarge); |
| 38 | } | 38 | } |
| 39 | 39 | ||
| 40 | buf.read(bites.data(), packet_len); | 40 | buf.read(bites.data(), packet_len); |
| @@ -204,6 +204,19 @@ MqttPacket::MqttPacket(const PubResponse &pubAck) : | @@ -204,6 +204,19 @@ MqttPacket::MqttPacket(const PubResponse &pubAck) : | ||
| 204 | } | 204 | } |
| 205 | } | 205 | } |
| 206 | 206 | ||
| 207 | +MqttPacket::MqttPacket(const Disconnect &disconnect) : | ||
| 208 | + bites(disconnect.getLengthWithoutFixedHeader()) | ||
| 209 | +{ | ||
| 210 | + this->protocolVersion = ProtocolVersion::Mqtt5; | ||
| 211 | + | ||
| 212 | + packetType = PacketType::DISCONNECT; | ||
| 213 | + first_byte = static_cast<char>(packetType) << 4; | ||
| 214 | + | ||
| 215 | + writeByte(static_cast<uint8_t>(disconnect.reasonCode)); | ||
| 216 | + writeProperties(disconnect.propertyBuilder); | ||
| 217 | + calculateRemainingLength(); | ||
| 218 | +} | ||
| 219 | + | ||
| 207 | void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender) | 220 | void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender) |
| 208 | { | 221 | { |
| 209 | while (buf.usedBytes() >= MQTT_HEADER_LENGH) | 222 | while (buf.usedBytes() >= MQTT_HEADER_LENGH) |
| @@ -630,6 +643,9 @@ void MqttPacket::handleConnect() | @@ -630,6 +643,9 @@ void MqttPacket::handleConnect() | ||
| 630 | 643 | ||
| 631 | void MqttPacket::handleDisconnect() | 644 | void MqttPacket::handleDisconnect() |
| 632 | { | 645 | { |
| 646 | + if (first_byte & 0b1111) | ||
| 647 | + throw ProtocolError("Disconnect packet first 4 bits should be 0.", ReasonCodes::MalformedPacket); | ||
| 648 | + | ||
| 633 | logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str()); | 649 | logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str()); |
| 634 | sender->setDisconnectReason("MQTT Disconnect received."); | 650 | sender->setDisconnectReason("MQTT Disconnect received."); |
| 635 | sender->markAsDisconnecting(); | 651 | sender->markAsDisconnecting(); |
mqttpacket.h
| @@ -100,6 +100,7 @@ public: | @@ -100,6 +100,7 @@ public: | ||
| 100 | MqttPacket(const UnsubAck &unsubAck); | 100 | MqttPacket(const UnsubAck &unsubAck); |
| 101 | MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish); | 101 | MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish); |
| 102 | MqttPacket(const PubResponse &pubAck); | 102 | MqttPacket(const PubResponse &pubAck); |
| 103 | + MqttPacket(const Disconnect &disconnect); | ||
| 103 | 104 | ||
| 104 | static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); | 105 | static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); |
| 105 | 106 |
threadloop.cpp
| @@ -135,6 +135,22 @@ void do_thread_work(ThreadData *threadData) | @@ -135,6 +135,22 @@ void do_thread_work(ThreadData *threadData) | ||
| 135 | } | 135 | } |
| 136 | } | 136 | } |
| 137 | } | 137 | } |
| 138 | + catch (ProtocolError &ex) | ||
| 139 | + { | ||
| 140 | + client->setDisconnectReason(ex.what()); | ||
| 141 | + if (client->getProtocolVersion() >= ProtocolVersion::Mqtt5 && client->hasConnectPacketSeen()) | ||
| 142 | + { | ||
| 143 | + Disconnect d(client->getProtocolVersion(), ex.reasonCode); | ||
| 144 | + MqttPacket p(d); | ||
| 145 | + client->writeMqttPacket(p); | ||
| 146 | + client->setReadyForDisconnect(); | ||
| 147 | + } | ||
| 148 | + else | ||
| 149 | + { | ||
| 150 | + logger->logf(LOG_ERR, "Protocol error: %s. Removing client.", ex.what()); | ||
| 151 | + threadData->removeClient(client); | ||
| 152 | + } | ||
| 153 | + } | ||
| 138 | catch(std::exception &ex) | 154 | catch(std::exception &ex) |
| 139 | { | 155 | { |
| 140 | client->setDisconnectReason(ex.what()); | 156 | client->setDisconnectReason(ex.what()); |
types.cpp
| @@ -289,3 +289,27 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const | @@ -289,3 +289,27 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const | ||
| 289 | 289 | ||
| 290 | return result; | 290 | return result; |
| 291 | } | 291 | } |
| 292 | + | ||
| 293 | +Disconnect::Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code) : | ||
| 294 | + reasonCode(reason_code) | ||
| 295 | +{ | ||
| 296 | + assert(protVersion >= ProtocolVersion::Mqtt5); | ||
| 297 | + | ||
| 298 | + | ||
| 299 | +} | ||
| 300 | + | ||
| 301 | +size_t Disconnect::getLengthWithoutFixedHeader() const | ||
| 302 | +{ | ||
| 303 | + size_t result = 1; | ||
| 304 | + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; | ||
| 305 | + result += proplen; | ||
| 306 | + return result; | ||
| 307 | +} | ||
| 308 | + | ||
| 309 | + | ||
| 310 | + | ||
| 311 | + | ||
| 312 | + | ||
| 313 | + | ||
| 314 | + | ||
| 315 | + |
types.h
| @@ -249,4 +249,13 @@ public: | @@ -249,4 +249,13 @@ public: | ||
| 249 | bool needsReasonCode() const; | 249 | bool needsReasonCode() const; |
| 250 | }; | 250 | }; |
| 251 | 251 | ||
| 252 | +class Disconnect | ||
| 253 | +{ | ||
| 254 | +public: | ||
| 255 | + ReasonCodes reasonCode; | ||
| 256 | + std::shared_ptr<Mqtt5PropertyBuilder> propertyBuilder; | ||
| 257 | + Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code); | ||
| 258 | + size_t getLengthWithoutFixedHeader() const; | ||
| 259 | +}; | ||
| 260 | + | ||
| 252 | #endif // TYPES_H | 261 | #endif // TYPES_H |