Commit cb40c2d247db6c218daacb0c058fcaa093015157

Authored by Wiebe Cazemier
1 parent c3961e7f

Implement unsubscribe handling

mqttpacket.cpp
@@ -66,6 +66,16 @@ MqttPacket::MqttPacket(const SubAck &subAck) : @@ -66,6 +66,16 @@ MqttPacket::MqttPacket(const SubAck &subAck) :
66 calculateRemainingLength(); 66 calculateRemainingLength();
67 } 67 }
68 68
  69 +MqttPacket::MqttPacket(const UnsubAck &unsubAck) :
  70 + bites(unsubAck.getLengthWithoutFixedHeader())
  71 +{
  72 + packetType = PacketType::SUBACK;
  73 + first_byte = static_cast<char>(packetType) << 4;
  74 + writeByte((unsubAck.packet_id & 0xFF00) >> 8);
  75 + writeByte(unsubAck.packet_id & 0x00FF);
  76 + calculateRemainingLength();
  77 +}
  78 +
69 MqttPacket::MqttPacket(const Publish &publish) : 79 MqttPacket::MqttPacket(const Publish &publish) :
70 bites(publish.getLengthWithoutFixedHeader()) 80 bites(publish.getLengthWithoutFixedHeader())
71 { 81 {
@@ -136,6 +146,8 @@ void MqttPacket::handle() @@ -136,6 +146,8 @@ void MqttPacket::handle()
136 sender->writePingResp(); 146 sender->writePingResp();
137 else if (packetType == PacketType::SUBSCRIBE) 147 else if (packetType == PacketType::SUBSCRIBE)
138 handleSubscribe(); 148 handleSubscribe();
  149 + else if (packetType == PacketType::UNSUBSCRIBE)
  150 + handleUnsubscribe();
139 else if (packetType == PacketType::PUBLISH) 151 else if (packetType == PacketType::PUBLISH)
140 handlePublish(); 152 handlePublish();
141 else if (packetType == PacketType::PUBACK) 153 else if (packetType == PacketType::PUBACK)
@@ -358,6 +370,32 @@ void MqttPacket::handleSubscribe() @@ -358,6 +370,32 @@ void MqttPacket::handleSubscribe()
358 sender->writeMqttPacket(response); 370 sender->writeMqttPacket(response);
359 } 371 }
360 372
  373 +void MqttPacket::handleUnsubscribe()
  374 +{
  375 + const char firstByteFirstNibble = (first_byte & 0x0F);
  376 +
  377 + if (firstByteFirstNibble != 2)
  378 + throw ProtocolError("First LSB of first byte is wrong value for subscribe packet.");
  379 +
  380 + uint16_t packet_id = readTwoBytesToUInt16();
  381 +
  382 + while (remainingAfterPos() > 0)
  383 + {
  384 + uint16_t topicLength = readTwoBytesToUInt16();
  385 + std::string topic(readBytes(topicLength), topicLength);
  386 +
  387 + if (topic.empty() || !isValidUtf8(topic))
  388 + throw ProtocolError("Subscribe topic not valid UTF-8.");
  389 +
  390 + sender->getThreadData()->getSubscriptionStore()->removeSubscription(sender, topic);
  391 + logger->logf(LOG_INFO, "Client '%s' unsubscribed to '%s'", sender->repr().c_str(), topic.c_str());
  392 + }
  393 +
  394 + UnsubAck unsubAck(packet_id);
  395 + MqttPacket response(unsubAck);
  396 + sender->writeMqttPacket(response);
  397 +}
  398 +
361 void MqttPacket::handlePublish() 399 void MqttPacket::handlePublish()
362 { 400 {
363 uint16_t variable_header_length = readTwoBytesToUInt16(); 401 uint16_t variable_header_length = readTwoBytesToUInt16();
mqttpacket.h
@@ -59,6 +59,7 @@ public: @@ -59,6 +59,7 @@ public:
59 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. 59 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance.
60 MqttPacket(const ConnAck &connAck); 60 MqttPacket(const ConnAck &connAck);
61 MqttPacket(const SubAck &subAck); 61 MqttPacket(const SubAck &subAck);
  62 + MqttPacket(const UnsubAck &unsubAck);
62 MqttPacket(const Publish &publish); 63 MqttPacket(const Publish &publish);
63 MqttPacket(const PubAck &pubAck); 64 MqttPacket(const PubAck &pubAck);
64 65
@@ -66,6 +67,7 @@ public: @@ -66,6 +67,7 @@ public:
66 void handleConnect(); 67 void handleConnect();
67 void handleDisconnect(); 68 void handleDisconnect();
68 void handleSubscribe(); 69 void handleSubscribe();
  70 + void handleUnsubscribe();
69 void handlePing(); 71 void handlePing();
70 void handlePublish(); 72 void handlePublish();
71 void handlePubAck(); 73 void handlePubAck();
subscriptionstore.cpp
@@ -30,6 +30,20 @@ void SubscriptionNode::addSubscriber(const std::shared_ptr&lt;Session&gt; &amp;subscriber, @@ -30,6 +30,20 @@ void SubscriptionNode::addSubscriber(const std::shared_ptr&lt;Session&gt; &amp;subscriber,
30 } 30 }
31 } 31 }
32 32
  33 +void SubscriptionNode::removeSubscriber(const std::shared_ptr<Session> &subscriber)
  34 +{
  35 + Subscription sub;
  36 + sub.session = subscriber;
  37 + sub.qos = 0;
  38 +
  39 + auto it = std::find(subscribers.begin(), subscribers.end(), sub);
  40 +
  41 + if (it != subscribers.end())
  42 + {
  43 + subscribers.erase(it);
  44 + }
  45 +}
  46 +
33 47
34 SubscriptionStore::SubscriptionStore() : 48 SubscriptionStore::SubscriptionStore() :
35 root(new SubscriptionNode("root")), 49 root(new SubscriptionNode("root")),
@@ -84,6 +98,52 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top @@ -84,6 +98,52 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top
84 98
85 } 99 }
86 100
  101 +void SubscriptionStore::removeSubscription(Client_p &client, const std::string &topic)
  102 +{
  103 + const std::list<std::string> subtopics = split(topic, '/');
  104 +
  105 + RWLockGuard lock_guard(&subscriptionsRwlock);
  106 + lock_guard.wrlock();
  107 +
  108 + // TODO: because it's so similar to adding a subscription, make a function to retrieve the deepest node?
  109 + SubscriptionNode *deepestNode = root.get();
  110 + for(const std::string &subtopic : subtopics)
  111 + {
  112 + std::unique_ptr<SubscriptionNode> *selectedChildren = nullptr;
  113 +
  114 + if (subtopic == "#")
  115 + selectedChildren = &deepestNode->childrenPound;
  116 + else if (subtopic == "+")
  117 + selectedChildren = &deepestNode->childrenPlus;
  118 + else
  119 + selectedChildren = &deepestNode->children[subtopic];
  120 +
  121 + std::unique_ptr<SubscriptionNode> &node = *selectedChildren;
  122 +
  123 + if (!node)
  124 + {
  125 + return;
  126 + }
  127 + deepestNode = node.get();
  128 + }
  129 +
  130 + assert(deepestNode);
  131 +
  132 + if (deepestNode)
  133 + {
  134 + auto session_it = sessionsByIdConst.find(client->getClientId());
  135 + if (session_it != sessionsByIdConst.end())
  136 + {
  137 + const std::shared_ptr<Session> &ses = session_it->second;
  138 + deepestNode->removeSubscriber(ses);
  139 + }
  140 + }
  141 +
  142 + lock_guard.unlock();
  143 +
  144 +
  145 +}
  146 +
87 // Removes an existing client when it already exists [MQTT-3.1.4-2]. 147 // Removes an existing client when it already exists [MQTT-3.1.4-2].
88 void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) 148 void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client)
89 { 149 {
subscriptionstore.h
@@ -42,6 +42,7 @@ public: @@ -42,6 +42,7 @@ public:
42 42
43 std::vector<Subscription> &getSubscribers(); 43 std::vector<Subscription> &getSubscribers();
44 void addSubscriber(const std::shared_ptr<Session> &subscriber, char qos); 44 void addSubscriber(const std::shared_ptr<Session> &subscriber, char qos);
  45 + void removeSubscriber(const std::shared_ptr<Session> &subscriber);
45 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children; 46 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children;
46 std::unique_ptr<SubscriptionNode> childrenPlus; 47 std::unique_ptr<SubscriptionNode> childrenPlus;
47 std::unique_ptr<SubscriptionNode> childrenPound; 48 std::unique_ptr<SubscriptionNode> childrenPound;
@@ -69,6 +70,7 @@ public: @@ -69,6 +70,7 @@ public:
69 SubscriptionStore(); 70 SubscriptionStore();
70 71
71 void addSubscription(Client_p &client, const std::string &topic, char qos); 72 void addSubscription(Client_p &client, const std::string &topic, char qos);
  73 + void removeSubscription(Client_p &client, const std::string &topic);
72 void registerClientAndKickExistingOne(Client_p &client); 74 void registerClientAndKickExistingOne(Client_p &client);
73 bool sessionPresent(const std::string &clientid); 75 bool sessionPresent(const std::string &clientid);
74 76
types.cpp
@@ -54,3 +54,14 @@ size_t PubAck::getLengthWithoutFixedHeader() const @@ -54,3 +54,14 @@ size_t PubAck::getLengthWithoutFixedHeader() const
54 { 54 {
55 return 2; 55 return 2;
56 } 56 }
  57 +
  58 +UnsubAck::UnsubAck(uint16_t packet_id) :
  59 + packet_id(packet_id)
  60 +{
  61 +
  62 +}
  63 +
  64 +size_t UnsubAck::getLengthWithoutFixedHeader() const
  65 +{
  66 + return 2;
  67 +}
@@ -69,6 +69,14 @@ public: @@ -69,6 +69,14 @@ public:
69 size_t getLengthWithoutFixedHeader() const; 69 size_t getLengthWithoutFixedHeader() const;
70 }; 70 };
71 71
  72 +class UnsubAck
  73 +{
  74 +public:
  75 + uint16_t packet_id;
  76 + UnsubAck(uint16_t packet_id);
  77 + size_t getLengthWithoutFixedHeader() const;
  78 +};
  79 +
72 class Publish 80 class Publish
73 { 81 {
74 public: 82 public: