From cb40c2d247db6c218daacb0c058fcaa093015157 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Thu, 11 Mar 2021 22:34:23 +0100 Subject: [PATCH] Implement unsubscribe handling --- mqttpacket.cpp | 38 ++++++++++++++++++++++++++++++++++++++ mqttpacket.h | 2 ++ subscriptionstore.cpp | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ subscriptionstore.h | 2 ++ types.cpp | 11 +++++++++++ types.h | 8 ++++++++ 6 files changed, 121 insertions(+), 0 deletions(-) diff --git a/mqttpacket.cpp b/mqttpacket.cpp index acde794..5f19148 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -66,6 +66,16 @@ MqttPacket::MqttPacket(const SubAck &subAck) : calculateRemainingLength(); } +MqttPacket::MqttPacket(const UnsubAck &unsubAck) : + bites(unsubAck.getLengthWithoutFixedHeader()) +{ + packetType = PacketType::SUBACK; + first_byte = static_cast(packetType) << 4; + writeByte((unsubAck.packet_id & 0xFF00) >> 8); + writeByte(unsubAck.packet_id & 0x00FF); + calculateRemainingLength(); +} + MqttPacket::MqttPacket(const Publish &publish) : bites(publish.getLengthWithoutFixedHeader()) { @@ -136,6 +146,8 @@ void MqttPacket::handle() sender->writePingResp(); else if (packetType == PacketType::SUBSCRIBE) handleSubscribe(); + else if (packetType == PacketType::UNSUBSCRIBE) + handleUnsubscribe(); else if (packetType == PacketType::PUBLISH) handlePublish(); else if (packetType == PacketType::PUBACK) @@ -358,6 +370,32 @@ void MqttPacket::handleSubscribe() sender->writeMqttPacket(response); } +void MqttPacket::handleUnsubscribe() +{ + const char firstByteFirstNibble = (first_byte & 0x0F); + + if (firstByteFirstNibble != 2) + throw ProtocolError("First LSB of first byte is wrong value for subscribe packet."); + + uint16_t packet_id = readTwoBytesToUInt16(); + + while (remainingAfterPos() > 0) + { + uint16_t topicLength = readTwoBytesToUInt16(); + std::string topic(readBytes(topicLength), topicLength); + + if (topic.empty() || !isValidUtf8(topic)) + throw ProtocolError("Subscribe topic not valid UTF-8."); + + sender->getThreadData()->getSubscriptionStore()->removeSubscription(sender, topic); + logger->logf(LOG_INFO, "Client '%s' unsubscribed to '%s'", sender->repr().c_str(), topic.c_str()); + } + + UnsubAck unsubAck(packet_id); + MqttPacket response(unsubAck); + sender->writeMqttPacket(response); +} + void MqttPacket::handlePublish() { uint16_t variable_header_length = readTwoBytesToUInt16(); diff --git a/mqttpacket.h b/mqttpacket.h index 14421e9..bab066b 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -59,6 +59,7 @@ public: // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. MqttPacket(const ConnAck &connAck); MqttPacket(const SubAck &subAck); + MqttPacket(const UnsubAck &unsubAck); MqttPacket(const Publish &publish); MqttPacket(const PubAck &pubAck); @@ -66,6 +67,7 @@ public: void handleConnect(); void handleDisconnect(); void handleSubscribe(); + void handleUnsubscribe(); void handlePing(); void handlePublish(); void handlePubAck(); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 682f997..5b68844 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -30,6 +30,20 @@ void SubscriptionNode::addSubscriber(const std::shared_ptr &subscriber, } } +void SubscriptionNode::removeSubscriber(const std::shared_ptr &subscriber) +{ + Subscription sub; + sub.session = subscriber; + sub.qos = 0; + + auto it = std::find(subscribers.begin(), subscribers.end(), sub); + + if (it != subscribers.end()) + { + subscribers.erase(it); + } +} + SubscriptionStore::SubscriptionStore() : root(new SubscriptionNode("root")), @@ -84,6 +98,52 @@ void SubscriptionStore::addSubscription(Client_p &client, const std::string &top } +void SubscriptionStore::removeSubscription(Client_p &client, const std::string &topic) +{ + const std::list subtopics = split(topic, '/'); + + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.wrlock(); + + // TODO: because it's so similar to adding a subscription, make a function to retrieve the deepest node? + SubscriptionNode *deepestNode = root.get(); + for(const std::string &subtopic : subtopics) + { + std::unique_ptr *selectedChildren = nullptr; + + if (subtopic == "#") + selectedChildren = &deepestNode->childrenPound; + else if (subtopic == "+") + selectedChildren = &deepestNode->childrenPlus; + else + selectedChildren = &deepestNode->children[subtopic]; + + std::unique_ptr &node = *selectedChildren; + + if (!node) + { + return; + } + deepestNode = node.get(); + } + + assert(deepestNode); + + if (deepestNode) + { + auto session_it = sessionsByIdConst.find(client->getClientId()); + if (session_it != sessionsByIdConst.end()) + { + const std::shared_ptr &ses = session_it->second; + deepestNode->removeSubscriber(ses); + } + } + + lock_guard.unlock(); + + +} + // Removes an existing client when it already exists [MQTT-3.1.4-2]. void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) { diff --git a/subscriptionstore.h b/subscriptionstore.h index ba11b07..d5340cd 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -42,6 +42,7 @@ public: std::vector &getSubscribers(); void addSubscriber(const std::shared_ptr &subscriber, char qos); + void removeSubscriber(const std::shared_ptr &subscriber); std::unordered_map> children; std::unique_ptr childrenPlus; std::unique_ptr childrenPound; @@ -69,6 +70,7 @@ public: SubscriptionStore(); void addSubscription(Client_p &client, const std::string &topic, char qos); + void removeSubscription(Client_p &client, const std::string &topic); void registerClientAndKickExistingOne(Client_p &client); bool sessionPresent(const std::string &clientid); diff --git a/types.cpp b/types.cpp index 973820e..90c2c0a 100644 --- a/types.cpp +++ b/types.cpp @@ -54,3 +54,14 @@ size_t PubAck::getLengthWithoutFixedHeader() const { return 2; } + +UnsubAck::UnsubAck(uint16_t packet_id) : + packet_id(packet_id) +{ + +} + +size_t UnsubAck::getLengthWithoutFixedHeader() const +{ + return 2; +} diff --git a/types.h b/types.h index c6d3352..ccad93f 100644 --- a/types.h +++ b/types.h @@ -69,6 +69,14 @@ public: size_t getLengthWithoutFixedHeader() const; }; +class UnsubAck +{ +public: + uint16_t packet_id; + UnsubAck(uint16_t packet_id); + size_t getLengthWithoutFixedHeader() const; +}; + class Publish { public: -- libgit2 0.21.4