Commit 70896f68dcd8390aa52433f1472b36e9b9f61f4b

Authored by Wiebe Cazemier
1 parent e3ad153c

Start of retained messages

Also materializes some concepts about MqttPacket.
client.cpp
... ... @@ -83,11 +83,22 @@ void Client::writeMqttPacket(const MqttPacket &packet)
83 83 if (packet.packetType == PacketType::PUBLISH && wwi > CLIENT_MAX_BUFFER_SIZE)
84 84 return;
85 85  
86   - if (packet.getSize() > getWriteBufMaxWriteSize())
87   - growWriteBuffer(packet.getSize());
  86 + if (packet.getSizeIncludingNonPresentHeader() > getWriteBufMaxWriteSize())
  87 + growWriteBuffer(packet.getSizeIncludingNonPresentHeader());
88 88  
89   - std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getSize());
90   - wwi += packet.getSize();
  89 + if (!packet.containsFixedHeader())
  90 + {
  91 + writebuf[wwi++] = packet.getFirstByte();
  92 + RemainingLength r = packet.getRemainingLength();
  93 + std::memcpy(&writebuf[wwi], r.bytes, r.len);
  94 + wwi += r.len;
  95 + }
  96 +
  97 + std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getBites().size());
  98 + wwi += packet.getBites().size();
  99 +
  100 + assert(wwi >= static_cast<int>(packet.getSizeIncludingNonPresentHeader()));
  101 + assert(wwi <= static_cast<int>(writeBufsize));
91 102  
92 103 setReadyForWriting(true);
93 104 }
... ...
mqttpacket.cpp
... ... @@ -2,6 +2,12 @@
2 2 #include <cstring>
3 3 #include <iostream>
4 4 #include <list>
  5 +#include <cassert>
  6 +
  7 +RemainingLength::RemainingLength()
  8 +{
  9 + memset(bytes, 0, 4);
  10 +}
5 11  
6 12 // constructor for parsing incoming packets
7 13 MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) :
... ... @@ -16,9 +22,11 @@ MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client
16 22 pos += fixed_header_length;
17 23 }
18 24  
  25 +// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector.
19 26 MqttPacket::MqttPacket(const ConnAck &connAck) :
20   - bites(4)
  27 + bites(connAck.getLength() + 2)
21 28 {
  29 + fixed_header_length = 2;
22 30 packetType = PacketType::CONNACK;
23 31 char first_byte = static_cast<char>(packetType) << 4;
24 32 writeByte(first_byte);
... ... @@ -31,6 +39,7 @@ MqttPacket::MqttPacket(const ConnAck &amp;connAck) :
31 39 MqttPacket::MqttPacket(const SubAck &subAck) :
32 40 bites(3)
33 41 {
  42 + fixed_header_length = 2; // TODO: this is wrong, pending implementation of the new method in SubAck
34 43 packetType = PacketType::SUBACK;
35 44 char first_byte = static_cast<char>(packetType) << 4;
36 45 writeByte(first_byte);
... ... @@ -48,11 +57,17 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) :
48 57 }
49 58  
50 59 MqttPacket::MqttPacket(const Publish &publish) :
51   - bites(publish.topic.length() + publish.payload.length() + 2 + 2) // TODO: same as above
  60 + bites(publish.getLength())
52 61 {
  62 + if (publish.topic.length() > 0xFFFF)
  63 + {
  64 + throw ProtocolError("Topic path too long.");
  65 + }
  66 +
53 67 packetType = PacketType::PUBLISH;
54   - char first_byte = static_cast<char>(packetType) << 4;
55   - writeByte(first_byte);
  68 + first_byte = static_cast<char>(packetType) << 4;
  69 + first_byte |= (publish.qos << 1);
  70 + first_byte |= (static_cast<char>(publish.retain) & 0b00000001);
56 71  
57 72 char topicLenMSB = (publish.topic.length() & 0xF0) >> 8;
58 73 char topicLenLSB = publish.topic.length() & 0x0F;
... ... @@ -60,9 +75,13 @@ MqttPacket::MqttPacket(const Publish &amp;publish) :
60 75 writeByte(topicLenLSB);
61 76 writeBytes(publish.topic.c_str(), publish.topic.length());
62 77  
63   - writeBytes(publish.payload.c_str(), publish.payload.length());
  78 + if (publish.qos)
  79 + {
  80 + throw NotImplementedException("I would write two bytes containing the packet id here, but QoS is not done yet.");
  81 + }
64 82  
65   - // TODO: untested. May be unnecessary, because a received packet can be resent just like that.
  83 + writeBytes(publish.payload.c_str(), publish.payload.length());
  84 + calculateRemainingLength();
66 85 }
67 86  
68 87 void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore)
... ... @@ -201,23 +220,77 @@ void MqttPacket::handlePublish(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionS
201 220 if (variable_header_length == 0)
202 221 throw ProtocolError("Empty publish topic");
203 222  
  223 + bool retain = (first_byte & 0b00000001);
  224 + bool dup = !!(first_byte & 0b00001000);
204 225 char qos = (first_byte & 0b00000110) >> 1;
205 226  
  227 + if (qos == 3)
  228 + throw ProtocolError("QoS 3 is a protocol violation.");
  229 +
  230 + // TODO: validate UTF8.
206 231 std::string topic(readBytes(variable_header_length), variable_header_length);
207 232  
208 233 if (qos)
209 234 {
210 235 throw ProtocolError("Qos not implemented.");
211   - //uint16_t packet_id = readTwoBytesToUInt16();
  236 + uint16_t packet_id = readTwoBytesToUInt16();
212 237 }
213 238  
214   - // TODO: validate UTF8.
215   - size_t payload_length = remainingAfterPos();
216   - std::string payload(readBytes(payload_length), payload_length);
  239 + if (retain)
  240 + {
  241 + size_t payload_length = remainingAfterPos();
  242 + std::string payload(readBytes(payload_length), payload_length);
217 243  
  244 + subscriptionStore->setRetainedMessage(topic, payload, qos);
  245 + }
  246 +
  247 + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
  248 + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9]
  249 + bites[0] &= 0b11110110;
  250 +
  251 + // For the existing clients, we can just write the same packet back out, with our small alterations.
218 252 subscriptionStore->queuePacketAtSubscribers(topic, *this, sender);
219 253 }
220 254  
  255 +void MqttPacket::calculateRemainingLength()
  256 +{
  257 + assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of.
  258 +
  259 + size_t x = bites.size();
  260 +
  261 + do
  262 + {
  263 + if (remainingLength.len > 4)
  264 + throw std::runtime_error("Calculated remaining length is longer than 4 bytes.");
  265 +
  266 + char encodedByte = x % 128;
  267 + x = x / 128;
  268 + if (x > 0)
  269 + encodedByte = encodedByte | 128;
  270 + remainingLength.bytes[remainingLength.len++] = encodedByte;
  271 + }
  272 + while(x > 0);
  273 +}
  274 +
  275 +RemainingLength MqttPacket::getRemainingLength() const
  276 +{
  277 + assert(remainingLength.len > 0);
  278 + return remainingLength;
  279 +}
  280 +
  281 +size_t MqttPacket::getSizeIncludingNonPresentHeader() const
  282 +{
  283 + size_t total = bites.size();
  284 +
  285 + if (fixed_header_length == 0)
  286 + {
  287 + total++;
  288 + total += remainingLength.len;
  289 + }
  290 +
  291 + return total;
  292 +}
  293 +
221 294  
222 295 Client_p MqttPacket::getSender() const
223 296 {
... ... @@ -229,6 +302,16 @@ void MqttPacket::setSender(const Client_p &amp;value)
229 302 sender = value;
230 303 }
231 304  
  305 +bool MqttPacket::containsFixedHeader() const
  306 +{
  307 + return fixed_header_length > 0;
  308 +}
  309 +
  310 +char MqttPacket::getFirstByte() const
  311 +{
  312 + return first_byte;
  313 +}
  314 +
232 315 char *MqttPacket::readBytes(size_t length)
233 316 {
234 317 if (pos + length > bites.size())
... ... @@ -262,6 +345,7 @@ void MqttPacket::writeBytes(const char *b, size_t len)
262 345 throw ProtocolError("Exceeding packet size");
263 346  
264 347 memcpy(&bites[pos], b, len);
  348 + pos += len;
265 349 }
266 350  
267 351 uint16_t MqttPacket::readTwoBytesToUInt16()
... ... @@ -287,3 +371,5 @@ size_t MqttPacket::remainingAfterPos()
287 371  
288 372  
289 373  
  374 +
  375 +
... ...
mqttpacket.h
... ... @@ -13,13 +13,21 @@
13 13 #include "types.h"
14 14 #include "subscriptionstore.h"
15 15  
  16 +struct RemainingLength
  17 +{
  18 + char bytes[4];
  19 + int len = 0;
  20 +public:
  21 + RemainingLength();
  22 +};
16 23  
17 24 class MqttPacket
18 25 {
19 26 std::vector<char> bites;
20   - size_t fixed_header_length = 0;
  27 + size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header.
  28 + RemainingLength remainingLength;
21 29 Client_p sender;
22   - char first_byte;
  30 + char first_byte = 0;
23 31 size_t pos = 0;
24 32 ProtocolVersion protocolVersion = ProtocolVersion::None;
25 33  
... ... @@ -30,12 +38,13 @@ class MqttPacket
30 38 uint16_t readTwoBytesToUInt16();
31 39 size_t remainingAfterPos();
32 40  
  41 + void calculateRemainingLength();
  42 +
33 43 public:
34 44 PacketType packetType = PacketType::Reserved;
35   - MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender);
  45 + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets.
36 46  
37   - // TODO: not constructors, but static functions that return all the stuff after the fixed header, then a constructor with vector.
38   - // Or, I can not have the fixed header, and calculate that on write-to-buf.
  47 + // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance.
39 48 MqttPacket(const ConnAck &connAck);
40 49 MqttPacket(const SubAck &subAck);
41 50 MqttPacket(const Publish &publish);
... ... @@ -46,11 +55,15 @@ public:
46 55 void handlePing();
47 56 void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore);
48 57  
49   - size_t getSize() const { return bites.size(); }
  58 + size_t getSizeIncludingNonPresentHeader() const;
50 59 const std::vector<char> &getBites() const { return bites; }
51 60  
52 61 Client_p getSender() const;
53 62 void setSender(const Client_p &value);
  63 +
  64 + bool containsFixedHeader() const;
  65 + char getFirstByte() const;
  66 + RemainingLength getRemainingLength() const;
54 67 };
55 68  
56 69 #endif // MQTTPACKET_H
... ...
rwlockguard.cpp
... ... @@ -10,7 +10,7 @@ RWLockGuard::RWLockGuard(pthread_rwlock_t *rwlock) :
10 10  
11 11 RWLockGuard::~RWLockGuard()
12 12 {
13   - pthread_rwlock_unlock(rwlock);
  13 + unlock();
14 14 }
15 15  
16 16 void RWLockGuard::wrlock()
... ... @@ -24,3 +24,12 @@ void RWLockGuard::rdlock()
24 24 if (pthread_rwlock_wrlock(rwlock) != 0)
25 25 throw std::runtime_error("rdlock failed.");
26 26 }
  27 +
  28 +void RWLockGuard::unlock()
  29 +{
  30 + if (rwlock != NULL)
  31 + {
  32 + pthread_rwlock_unlock(rwlock);
  33 + rwlock = NULL;
  34 + }
  35 +}
... ...
rwlockguard.h
... ... @@ -12,6 +12,7 @@ public:
12 12 ~RWLockGuard();
13 13 void wrlock();
14 14 void rdlock();
  15 + void unlock();
15 16 };
16 17  
17 18 #endif // RWLOCKGUARD_H
... ...
subscriptionstore.cpp
... ... @@ -44,6 +44,9 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top
44 44 }
45 45  
46 46 clients_by_id[client->getClientId()] = client;
  47 + lock_guard.unlock();
  48 +
  49 + giveClientRetainedMessage(client, topic); // TODO: wildcards
47 50 }
48 51  
49 52 void SubscriptionStore::removeClient(const Client_p &client)
... ... @@ -119,4 +122,44 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::string &amp;topic, const
119 122 publishRecursively(subtopics.begin(), subtopics.end(), root, packet);
120 123 }
121 124  
  125 +void SubscriptionStore::giveClientRetainedMessage(Client_p &client, const std::string &topic)
  126 +{
  127 + RWLockGuard locker(&retainedMessagesRwlock);
  128 + locker.rdlock();
  129 +
  130 + auto retained_ptr = retainedMessages.find(topic);
  131 +
  132 + if (retained_ptr == retainedMessages.end())
  133 + return;
  134 +
  135 + const RetainedPayload &m = retained_ptr->second;
  136 +
  137 + Publish publish(topic, m.payload, m.qos);
  138 + publish.retain = true;
  139 + const MqttPacket packet(publish);
  140 + client->writeMqttPacket(packet);
  141 +}
  142 +
  143 +void SubscriptionStore::setRetainedMessage(const std::string &topic, const std::string &payload, char qos)
  144 +{
  145 + RWLockGuard locker(&retainedMessagesRwlock);
  146 + locker.wrlock();
  147 +
  148 + auto retained_ptr = retainedMessages.find(topic);
  149 + bool retained_found = retained_ptr != retainedMessages.end();
  150 +
  151 + if (!retained_found && payload.empty())
  152 + return;
  153 +
  154 + if (retained_found && payload.empty())
  155 + {
  156 + retainedMessages.erase(topic);
  157 + return;
  158 + }
  159 +
  160 + RetainedPayload &m = retainedMessages[topic];
  161 + m.payload = payload;
  162 + m.qos = qos;
  163 +}
  164 +
122 165  
... ...
subscriptionstore.h
... ... @@ -12,6 +12,12 @@
12 12 #include "client.h"
13 13 #include "utils.h"
14 14  
  15 +struct RetainedPayload
  16 +{
  17 + std::string payload;
  18 + char qos;
  19 +};
  20 +
15 21 class SubscriptionNode
16 22 {
17 23 std::string subtopic;
... ... @@ -32,6 +38,9 @@ class SubscriptionStore
32 38 std::unordered_map<std::string, Client_p> clients_by_id;
33 39 const std::unordered_map<std::string, Client_p> &clients_by_id_const;
34 40  
  41 + pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER;
  42 + std::unordered_map<std::string, RetainedPayload> retainedMessages;
  43 +
35 44 bool publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::string> &subscribers) const;
36 45 bool publishRecursively(std::list<std::string>::const_iterator cur_subtopic_it, std::list<std::string>::const_iterator end,
37 46 std::unique_ptr<SubscriptionNode> &next, const MqttPacket &packet) const;
... ... @@ -42,6 +51,9 @@ public:
42 51 void removeClient(const Client_p &client);
43 52  
44 53 void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender);
  54 + void giveClientRetainedMessage(Client_p &client, const std::string &topic);
  55 +
  56 + void setRetainedMessage(const std::string &topic, const std::string &payload, char qos);
45 57 };
46 58  
47 59 #endif // SUBSCRIPTIONSTORE_H
... ...
types.cpp
... ... @@ -16,9 +16,21 @@ SubAck::SubAck(uint16_t packet_id, const std::list&lt;std::string&gt; &amp;subs) :
16 16 }
17 17 }
18 18  
19   -Publish::Publish(std::string &topic, std::string payload) :
  19 +Publish::Publish(const std::string &topic, const std::string payload, char qos) :
20 20 topic(topic),
21   - payload(payload)
  21 + payload(payload),
  22 + qos(qos)
22 23 {
23 24  
24 25 }
  26 +
  27 +// Length starting at the variable header, not the fixed header.
  28 +size_t Publish::getLength() const
  29 +{
  30 + int result = topic.length() + payload.length() + 2;
  31 +
  32 + if (qos)
  33 + result += 2;
  34 +
  35 + return result;
  36 +}
... ...
... ... @@ -48,6 +48,7 @@ class ConnAck
48 48 public:
49 49 ConnAck(ConnAckReturnCodes return_code);
50 50 ConnAckReturnCodes return_code;
  51 + size_t getLength() const { return 2;} // size of connack is always the same
51 52 };
52 53  
53 54 enum class SubAckReturnCodes
... ... @@ -71,7 +72,10 @@ class Publish
71 72 public:
72 73 std::string topic;
73 74 std::string payload;
74   - Publish(std::string &topic, std::string payload);
  75 + char qos = 0;
  76 + bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9]
  77 + Publish(const std::string &topic, const std::string payload, char qos);
  78 + size_t getLength() const;
75 79 };
76 80  
77 81 #endif // TYPES_H
... ...