Commit 5f1d825927afe50335f72ef8572f68f05015bc90
1 parent
4d957a3d
Add ability to send disconnect packets with code
Showing
6 changed files
with
76 additions
and
2 deletions
exceptions.h
| ... | ... | @@ -22,10 +22,18 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 22 | 22 | #include <stdexcept> |
| 23 | 23 | #include <sstream> |
| 24 | 24 | |
| 25 | +#include "types.h" | |
| 26 | + | |
| 25 | 27 | class ProtocolError : public std::runtime_error |
| 26 | 28 | { |
| 27 | 29 | public: |
| 28 | - ProtocolError(const std::string &msg) : std::runtime_error(msg) {} | |
| 30 | + const ReasonCodes reasonCode; | |
| 31 | + | |
| 32 | + ProtocolError(const std::string &msg, ReasonCodes reasonCode = ReasonCodes::UnspecifiedError) : std::runtime_error(msg), | |
| 33 | + reasonCode(reasonCode) | |
| 34 | + { | |
| 35 | + | |
| 36 | + } | |
| 29 | 37 | }; |
| 30 | 38 | |
| 31 | 39 | class NotImplementedException : public std::runtime_error | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -34,7 +34,7 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt |
| 34 | 34 | |
| 35 | 35 | if (packet_len > sender->getMaxIncomingPacketSize()) |
| 36 | 36 | { |
| 37 | - throw ProtocolError("Incoming packet size exceeded. TODO: DISCONNECT WITH CODE 0x95"); | |
| 37 | + throw ProtocolError("Incoming packet size exceeded.", ReasonCodes::PacketTooLarge); | |
| 38 | 38 | } |
| 39 | 39 | |
| 40 | 40 | buf.read(bites.data(), packet_len); |
| ... | ... | @@ -204,6 +204,19 @@ MqttPacket::MqttPacket(const PubResponse &pubAck) : |
| 204 | 204 | } |
| 205 | 205 | } |
| 206 | 206 | |
| 207 | +MqttPacket::MqttPacket(const Disconnect &disconnect) : | |
| 208 | + bites(disconnect.getLengthWithoutFixedHeader()) | |
| 209 | +{ | |
| 210 | + this->protocolVersion = ProtocolVersion::Mqtt5; | |
| 211 | + | |
| 212 | + packetType = PacketType::DISCONNECT; | |
| 213 | + first_byte = static_cast<char>(packetType) << 4; | |
| 214 | + | |
| 215 | + writeByte(static_cast<uint8_t>(disconnect.reasonCode)); | |
| 216 | + writeProperties(disconnect.propertyBuilder); | |
| 217 | + calculateRemainingLength(); | |
| 218 | +} | |
| 219 | + | |
| 207 | 220 | void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender) |
| 208 | 221 | { |
| 209 | 222 | while (buf.usedBytes() >= MQTT_HEADER_LENGH) |
| ... | ... | @@ -630,6 +643,9 @@ void MqttPacket::handleConnect() |
| 630 | 643 | |
| 631 | 644 | void MqttPacket::handleDisconnect() |
| 632 | 645 | { |
| 646 | + if (first_byte & 0b1111) | |
| 647 | + throw ProtocolError("Disconnect packet first 4 bits should be 0.", ReasonCodes::MalformedPacket); | |
| 648 | + | |
| 633 | 649 | logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str()); |
| 634 | 650 | sender->setDisconnectReason("MQTT Disconnect received."); |
| 635 | 651 | sender->markAsDisconnecting(); | ... | ... |
mqttpacket.h
| ... | ... | @@ -100,6 +100,7 @@ public: |
| 100 | 100 | MqttPacket(const UnsubAck &unsubAck); |
| 101 | 101 | MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish); |
| 102 | 102 | MqttPacket(const PubResponse &pubAck); |
| 103 | + MqttPacket(const Disconnect &disconnect); | |
| 103 | 104 | |
| 104 | 105 | static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); |
| 105 | 106 | ... | ... |
threadloop.cpp
| ... | ... | @@ -135,6 +135,22 @@ void do_thread_work(ThreadData *threadData) |
| 135 | 135 | } |
| 136 | 136 | } |
| 137 | 137 | } |
| 138 | + catch (ProtocolError &ex) | |
| 139 | + { | |
| 140 | + client->setDisconnectReason(ex.what()); | |
| 141 | + if (client->getProtocolVersion() >= ProtocolVersion::Mqtt5 && client->hasConnectPacketSeen()) | |
| 142 | + { | |
| 143 | + Disconnect d(client->getProtocolVersion(), ex.reasonCode); | |
| 144 | + MqttPacket p(d); | |
| 145 | + client->writeMqttPacket(p); | |
| 146 | + client->setReadyForDisconnect(); | |
| 147 | + } | |
| 148 | + else | |
| 149 | + { | |
| 150 | + logger->logf(LOG_ERR, "Protocol error: %s. Removing client.", ex.what()); | |
| 151 | + threadData->removeClient(client); | |
| 152 | + } | |
| 153 | + } | |
| 138 | 154 | catch(std::exception &ex) |
| 139 | 155 | { |
| 140 | 156 | client->setDisconnectReason(ex.what()); | ... | ... |
types.cpp
| ... | ... | @@ -289,3 +289,27 @@ size_t UnsubAck::getLengthWithoutFixedHeader() const |
| 289 | 289 | |
| 290 | 290 | return result; |
| 291 | 291 | } |
| 292 | + | |
| 293 | +Disconnect::Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code) : | |
| 294 | + reasonCode(reason_code) | |
| 295 | +{ | |
| 296 | + assert(protVersion >= ProtocolVersion::Mqtt5); | |
| 297 | + | |
| 298 | + | |
| 299 | +} | |
| 300 | + | |
| 301 | +size_t Disconnect::getLengthWithoutFixedHeader() const | |
| 302 | +{ | |
| 303 | + size_t result = 1; | |
| 304 | + const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; | |
| 305 | + result += proplen; | |
| 306 | + return result; | |
| 307 | +} | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | ... | ... |
types.h
| ... | ... | @@ -249,4 +249,13 @@ public: |
| 249 | 249 | bool needsReasonCode() const; |
| 250 | 250 | }; |
| 251 | 251 | |
| 252 | +class Disconnect | |
| 253 | +{ | |
| 254 | +public: | |
| 255 | + ReasonCodes reasonCode; | |
| 256 | + std::shared_ptr<Mqtt5PropertyBuilder> propertyBuilder; | |
| 257 | + Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code); | |
| 258 | + size_t getLengthWithoutFixedHeader() const; | |
| 259 | +}; | |
| 260 | + | |
| 252 | 261 | #endif // TYPES_H | ... | ... |