Commit 70896f68dcd8390aa52433f1472b36e9b9f61f4b

Authored by Wiebe Cazemier
1 parent e3ad153c

Start of retained messages

Also materializes some concepts about MqttPacket.
client.cpp
@@ -83,11 +83,22 @@ void Client::writeMqttPacket(const MqttPacket &packet) @@ -83,11 +83,22 @@ void Client::writeMqttPacket(const MqttPacket &packet)
83 if (packet.packetType == PacketType::PUBLISH && wwi > CLIENT_MAX_BUFFER_SIZE) 83 if (packet.packetType == PacketType::PUBLISH && wwi > CLIENT_MAX_BUFFER_SIZE)
84 return; 84 return;
85 85
86 - if (packet.getSize() > getWriteBufMaxWriteSize())  
87 - growWriteBuffer(packet.getSize()); 86 + if (packet.getSizeIncludingNonPresentHeader() > getWriteBufMaxWriteSize())
  87 + growWriteBuffer(packet.getSizeIncludingNonPresentHeader());
88 88
89 - std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getSize());  
90 - wwi += packet.getSize(); 89 + if (!packet.containsFixedHeader())
  90 + {
  91 + writebuf[wwi++] = packet.getFirstByte();
  92 + RemainingLength r = packet.getRemainingLength();
  93 + std::memcpy(&writebuf[wwi], r.bytes, r.len);
  94 + wwi += r.len;
  95 + }
  96 +
  97 + std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getBites().size());
  98 + wwi += packet.getBites().size();
  99 +
  100 + assert(wwi >= static_cast<int>(packet.getSizeIncludingNonPresentHeader()));
  101 + assert(wwi <= static_cast<int>(writeBufsize));
91 102
92 setReadyForWriting(true); 103 setReadyForWriting(true);
93 } 104 }
mqttpacket.cpp
@@ -2,6 +2,12 @@ @@ -2,6 +2,12 @@
2 #include <cstring> 2 #include <cstring>
3 #include <iostream> 3 #include <iostream>
4 #include <list> 4 #include <list>
  5 +#include <cassert>
  6 +
  7 +RemainingLength::RemainingLength()
  8 +{
  9 + memset(bytes, 0, 4);
  10 +}
5 11
6 // constructor for parsing incoming packets 12 // constructor for parsing incoming packets
7 MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) : 13 MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) :
@@ -16,9 +22,11 @@ MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client @@ -16,9 +22,11 @@ MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client
16 pos += fixed_header_length; 22 pos += fixed_header_length;
17 } 23 }
18 24
  25 +// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector.
19 MqttPacket::MqttPacket(const ConnAck &connAck) : 26 MqttPacket::MqttPacket(const ConnAck &connAck) :
20 - bites(4) 27 + bites(connAck.getLength() + 2)
21 { 28 {
  29 + fixed_header_length = 2;
22 packetType = PacketType::CONNACK; 30 packetType = PacketType::CONNACK;
23 char first_byte = static_cast<char>(packetType) << 4; 31 char first_byte = static_cast<char>(packetType) << 4;
24 writeByte(first_byte); 32 writeByte(first_byte);
@@ -31,6 +39,7 @@ MqttPacket::MqttPacket(const ConnAck &amp;connAck) : @@ -31,6 +39,7 @@ MqttPacket::MqttPacket(const ConnAck &amp;connAck) :
31 MqttPacket::MqttPacket(const SubAck &subAck) : 39 MqttPacket::MqttPacket(const SubAck &subAck) :
32 bites(3) 40 bites(3)
33 { 41 {
  42 + fixed_header_length = 2; // TODO: this is wrong, pending implementation of the new method in SubAck
34 packetType = PacketType::SUBACK; 43 packetType = PacketType::SUBACK;
35 char first_byte = static_cast<char>(packetType) << 4; 44 char first_byte = static_cast<char>(packetType) << 4;
36 writeByte(first_byte); 45 writeByte(first_byte);
@@ -48,11 +57,17 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) : @@ -48,11 +57,17 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) :
48 } 57 }
49 58
50 MqttPacket::MqttPacket(const Publish &publish) : 59 MqttPacket::MqttPacket(const Publish &publish) :
51 - bites(publish.topic.length() + publish.payload.length() + 2 + 2) // TODO: same as above 60 + bites(publish.getLength())
52 { 61 {
  62 + if (publish.topic.length() > 0xFFFF)
  63 + {
  64 + throw ProtocolError("Topic path too long.");
  65 + }
  66 +
53 packetType = PacketType::PUBLISH; 67 packetType = PacketType::PUBLISH;
54 - char first_byte = static_cast<char>(packetType) << 4;  
55 - writeByte(first_byte); 68 + first_byte = static_cast<char>(packetType) << 4;
  69 + first_byte |= (publish.qos << 1);
  70 + first_byte |= (static_cast<char>(publish.retain) & 0b00000001);
56 71
57 char topicLenMSB = (publish.topic.length() & 0xF0) >> 8; 72 char topicLenMSB = (publish.topic.length() & 0xF0) >> 8;
58 char topicLenLSB = publish.topic.length() & 0x0F; 73 char topicLenLSB = publish.topic.length() & 0x0F;
@@ -60,9 +75,13 @@ MqttPacket::MqttPacket(const Publish &amp;publish) : @@ -60,9 +75,13 @@ MqttPacket::MqttPacket(const Publish &amp;publish) :
60 writeByte(topicLenLSB); 75 writeByte(topicLenLSB);
61 writeBytes(publish.topic.c_str(), publish.topic.length()); 76 writeBytes(publish.topic.c_str(), publish.topic.length());
62 77
63 - writeBytes(publish.payload.c_str(), publish.payload.length()); 78 + if (publish.qos)
  79 + {
  80 + throw NotImplementedException("I would write two bytes containing the packet id here, but QoS is not done yet.");
  81 + }
64 82
65 - // TODO: untested. May be unnecessary, because a received packet can be resent just like that. 83 + writeBytes(publish.payload.c_str(), publish.payload.length());
  84 + calculateRemainingLength();
66 } 85 }
67 86
68 void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) 87 void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore)
@@ -201,23 +220,77 @@ void MqttPacket::handlePublish(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionS @@ -201,23 +220,77 @@ void MqttPacket::handlePublish(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptionS
201 if (variable_header_length == 0) 220 if (variable_header_length == 0)
202 throw ProtocolError("Empty publish topic"); 221 throw ProtocolError("Empty publish topic");
203 222
  223 + bool retain = (first_byte & 0b00000001);
  224 + bool dup = !!(first_byte & 0b00001000);
204 char qos = (first_byte & 0b00000110) >> 1; 225 char qos = (first_byte & 0b00000110) >> 1;
205 226
  227 + if (qos == 3)
  228 + throw ProtocolError("QoS 3 is a protocol violation.");
  229 +
  230 + // TODO: validate UTF8.
206 std::string topic(readBytes(variable_header_length), variable_header_length); 231 std::string topic(readBytes(variable_header_length), variable_header_length);
207 232
208 if (qos) 233 if (qos)
209 { 234 {
210 throw ProtocolError("Qos not implemented."); 235 throw ProtocolError("Qos not implemented.");
211 - //uint16_t packet_id = readTwoBytesToUInt16(); 236 + uint16_t packet_id = readTwoBytesToUInt16();
212 } 237 }
213 238
214 - // TODO: validate UTF8.  
215 - size_t payload_length = remainingAfterPos();  
216 - std::string payload(readBytes(payload_length), payload_length); 239 + if (retain)
  240 + {
  241 + size_t payload_length = remainingAfterPos();
  242 + std::string payload(readBytes(payload_length), payload_length);
217 243
  244 + subscriptionStore->setRetainedMessage(topic, payload, qos);
  245 + }
  246 +
  247 + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
  248 + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9]
  249 + bites[0] &= 0b11110110;
  250 +
  251 + // For the existing clients, we can just write the same packet back out, with our small alterations.
218 subscriptionStore->queuePacketAtSubscribers(topic, *this, sender); 252 subscriptionStore->queuePacketAtSubscribers(topic, *this, sender);
219 } 253 }
220 254
  255 +void MqttPacket::calculateRemainingLength()
  256 +{
  257 + assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of.
  258 +
  259 + size_t x = bites.size();
  260 +
  261 + do
  262 + {
  263 + if (remainingLength.len > 4)
  264 + throw std::runtime_error("Calculated remaining length is longer than 4 bytes.");
  265 +
  266 + char encodedByte = x % 128;
  267 + x = x / 128;
  268 + if (x > 0)
  269 + encodedByte = encodedByte | 128;
  270 + remainingLength.bytes[remainingLength.len++] = encodedByte;
  271 + }
  272 + while(x > 0);
  273 +}
  274 +
  275 +RemainingLength MqttPacket::getRemainingLength() const
  276 +{
  277 + assert(remainingLength.len > 0);
  278 + return remainingLength;
  279 +}
  280 +
  281 +size_t MqttPacket::getSizeIncludingNonPresentHeader() const
  282 +{
  283 + size_t total = bites.size();
  284 +
  285 + if (fixed_header_length == 0)
  286 + {
  287 + total++;
  288 + total += remainingLength.len;
  289 + }
  290 +
  291 + return total;
  292 +}
  293 +
221 294
222 Client_p MqttPacket::getSender() const 295 Client_p MqttPacket::getSender() const
223 { 296 {
@@ -229,6 +302,16 @@ void MqttPacket::setSender(const Client_p &amp;value) @@ -229,6 +302,16 @@ void MqttPacket::setSender(const Client_p &amp;value)
229 sender = value; 302 sender = value;
230 } 303 }
231 304
  305 +bool MqttPacket::containsFixedHeader() const
  306 +{
  307 + return fixed_header_length > 0;
  308 +}
  309 +
  310 +char MqttPacket::getFirstByte() const
  311 +{
  312 + return first_byte;
  313 +}
  314 +
232 char *MqttPacket::readBytes(size_t length) 315 char *MqttPacket::readBytes(size_t length)
233 { 316 {
234 if (pos + length > bites.size()) 317 if (pos + length > bites.size())
@@ -262,6 +345,7 @@ void MqttPacket::writeBytes(const char *b, size_t len) @@ -262,6 +345,7 @@ void MqttPacket::writeBytes(const char *b, size_t len)
262 throw ProtocolError("Exceeding packet size"); 345 throw ProtocolError("Exceeding packet size");
263 346
264 memcpy(&bites[pos], b, len); 347 memcpy(&bites[pos], b, len);
  348 + pos += len;
265 } 349 }
266 350
267 uint16_t MqttPacket::readTwoBytesToUInt16() 351 uint16_t MqttPacket::readTwoBytesToUInt16()
@@ -287,3 +371,5 @@ size_t MqttPacket::remainingAfterPos() @@ -287,3 +371,5 @@ size_t MqttPacket::remainingAfterPos()
287 371
288 372
289 373
  374 +
  375 +
mqttpacket.h
@@ -13,13 +13,21 @@ @@ -13,13 +13,21 @@
13 #include "types.h" 13 #include "types.h"
14 #include "subscriptionstore.h" 14 #include "subscriptionstore.h"
15 15
  16 +struct RemainingLength
  17 +{
  18 + char bytes[4];
  19 + int len = 0;
  20 +public:
  21 + RemainingLength();
  22 +};
16 23
17 class MqttPacket 24 class MqttPacket
18 { 25 {
19 std::vector<char> bites; 26 std::vector<char> bites;
20 - size_t fixed_header_length = 0; 27 + size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header.
  28 + RemainingLength remainingLength;
21 Client_p sender; 29 Client_p sender;
22 - char first_byte; 30 + char first_byte = 0;
23 size_t pos = 0; 31 size_t pos = 0;
24 ProtocolVersion protocolVersion = ProtocolVersion::None; 32 ProtocolVersion protocolVersion = ProtocolVersion::None;
25 33
@@ -30,12 +38,13 @@ class MqttPacket @@ -30,12 +38,13 @@ class MqttPacket
30 uint16_t readTwoBytesToUInt16(); 38 uint16_t readTwoBytesToUInt16();
31 size_t remainingAfterPos(); 39 size_t remainingAfterPos();
32 40
  41 + void calculateRemainingLength();
  42 +
33 public: 43 public:
34 PacketType packetType = PacketType::Reserved; 44 PacketType packetType = PacketType::Reserved;
35 - MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); 45 + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets.
36 46
37 - // TODO: not constructors, but static functions that return all the stuff after the fixed header, then a constructor with vector.  
38 - // Or, I can not have the fixed header, and calculate that on write-to-buf. 47 + // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance.
39 MqttPacket(const ConnAck &connAck); 48 MqttPacket(const ConnAck &connAck);
40 MqttPacket(const SubAck &subAck); 49 MqttPacket(const SubAck &subAck);
41 MqttPacket(const Publish &publish); 50 MqttPacket(const Publish &publish);
@@ -46,11 +55,15 @@ public: @@ -46,11 +55,15 @@ public:
46 void handlePing(); 55 void handlePing();
47 void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore); 56 void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore);
48 57
49 - size_t getSize() const { return bites.size(); } 58 + size_t getSizeIncludingNonPresentHeader() const;
50 const std::vector<char> &getBites() const { return bites; } 59 const std::vector<char> &getBites() const { return bites; }
51 60
52 Client_p getSender() const; 61 Client_p getSender() const;
53 void setSender(const Client_p &value); 62 void setSender(const Client_p &value);
  63 +
  64 + bool containsFixedHeader() const;
  65 + char getFirstByte() const;
  66 + RemainingLength getRemainingLength() const;
54 }; 67 };
55 68
56 #endif // MQTTPACKET_H 69 #endif // MQTTPACKET_H
rwlockguard.cpp
@@ -10,7 +10,7 @@ RWLockGuard::RWLockGuard(pthread_rwlock_t *rwlock) : @@ -10,7 +10,7 @@ RWLockGuard::RWLockGuard(pthread_rwlock_t *rwlock) :
10 10
11 RWLockGuard::~RWLockGuard() 11 RWLockGuard::~RWLockGuard()
12 { 12 {
13 - pthread_rwlock_unlock(rwlock); 13 + unlock();
14 } 14 }
15 15
16 void RWLockGuard::wrlock() 16 void RWLockGuard::wrlock()
@@ -24,3 +24,12 @@ void RWLockGuard::rdlock() @@ -24,3 +24,12 @@ void RWLockGuard::rdlock()
24 if (pthread_rwlock_wrlock(rwlock) != 0) 24 if (pthread_rwlock_wrlock(rwlock) != 0)
25 throw std::runtime_error("rdlock failed."); 25 throw std::runtime_error("rdlock failed.");
26 } 26 }
  27 +
  28 +void RWLockGuard::unlock()
  29 +{
  30 + if (rwlock != NULL)
  31 + {
  32 + pthread_rwlock_unlock(rwlock);
  33 + rwlock = NULL;
  34 + }
  35 +}
rwlockguard.h
@@ -12,6 +12,7 @@ public: @@ -12,6 +12,7 @@ public:
12 ~RWLockGuard(); 12 ~RWLockGuard();
13 void wrlock(); 13 void wrlock();
14 void rdlock(); 14 void rdlock();
  15 + void unlock();
15 }; 16 };
16 17
17 #endif // RWLOCKGUARD_H 18 #endif // RWLOCKGUARD_H
subscriptionstore.cpp
@@ -44,6 +44,9 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top @@ -44,6 +44,9 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, const std::string &amp;top
44 } 44 }
45 45
46 clients_by_id[client->getClientId()] = client; 46 clients_by_id[client->getClientId()] = client;
  47 + lock_guard.unlock();
  48 +
  49 + giveClientRetainedMessage(client, topic); // TODO: wildcards
47 } 50 }
48 51
49 void SubscriptionStore::removeClient(const Client_p &client) 52 void SubscriptionStore::removeClient(const Client_p &client)
@@ -119,4 +122,44 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::string &amp;topic, const @@ -119,4 +122,44 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::string &amp;topic, const
119 publishRecursively(subtopics.begin(), subtopics.end(), root, packet); 122 publishRecursively(subtopics.begin(), subtopics.end(), root, packet);
120 } 123 }
121 124
  125 +void SubscriptionStore::giveClientRetainedMessage(Client_p &client, const std::string &topic)
  126 +{
  127 + RWLockGuard locker(&retainedMessagesRwlock);
  128 + locker.rdlock();
  129 +
  130 + auto retained_ptr = retainedMessages.find(topic);
  131 +
  132 + if (retained_ptr == retainedMessages.end())
  133 + return;
  134 +
  135 + const RetainedPayload &m = retained_ptr->second;
  136 +
  137 + Publish publish(topic, m.payload, m.qos);
  138 + publish.retain = true;
  139 + const MqttPacket packet(publish);
  140 + client->writeMqttPacket(packet);
  141 +}
  142 +
  143 +void SubscriptionStore::setRetainedMessage(const std::string &topic, const std::string &payload, char qos)
  144 +{
  145 + RWLockGuard locker(&retainedMessagesRwlock);
  146 + locker.wrlock();
  147 +
  148 + auto retained_ptr = retainedMessages.find(topic);
  149 + bool retained_found = retained_ptr != retainedMessages.end();
  150 +
  151 + if (!retained_found && payload.empty())
  152 + return;
  153 +
  154 + if (retained_found && payload.empty())
  155 + {
  156 + retainedMessages.erase(topic);
  157 + return;
  158 + }
  159 +
  160 + RetainedPayload &m = retainedMessages[topic];
  161 + m.payload = payload;
  162 + m.qos = qos;
  163 +}
  164 +
122 165
subscriptionstore.h
@@ -12,6 +12,12 @@ @@ -12,6 +12,12 @@
12 #include "client.h" 12 #include "client.h"
13 #include "utils.h" 13 #include "utils.h"
14 14
  15 +struct RetainedPayload
  16 +{
  17 + std::string payload;
  18 + char qos;
  19 +};
  20 +
15 class SubscriptionNode 21 class SubscriptionNode
16 { 22 {
17 std::string subtopic; 23 std::string subtopic;
@@ -32,6 +38,9 @@ class SubscriptionStore @@ -32,6 +38,9 @@ class SubscriptionStore
32 std::unordered_map<std::string, Client_p> clients_by_id; 38 std::unordered_map<std::string, Client_p> clients_by_id;
33 const std::unordered_map<std::string, Client_p> &clients_by_id_const; 39 const std::unordered_map<std::string, Client_p> &clients_by_id_const;
34 40
  41 + pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER;
  42 + std::unordered_map<std::string, RetainedPayload> retainedMessages;
  43 +
35 bool publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::string> &subscribers) const; 44 bool publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::string> &subscribers) const;
36 bool publishRecursively(std::list<std::string>::const_iterator cur_subtopic_it, std::list<std::string>::const_iterator end, 45 bool publishRecursively(std::list<std::string>::const_iterator cur_subtopic_it, std::list<std::string>::const_iterator end,
37 std::unique_ptr<SubscriptionNode> &next, const MqttPacket &packet) const; 46 std::unique_ptr<SubscriptionNode> &next, const MqttPacket &packet) const;
@@ -42,6 +51,9 @@ public: @@ -42,6 +51,9 @@ public:
42 void removeClient(const Client_p &client); 51 void removeClient(const Client_p &client);
43 52
44 void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); 53 void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender);
  54 + void giveClientRetainedMessage(Client_p &client, const std::string &topic);
  55 +
  56 + void setRetainedMessage(const std::string &topic, const std::string &payload, char qos);
45 }; 57 };
46 58
47 #endif // SUBSCRIPTIONSTORE_H 59 #endif // SUBSCRIPTIONSTORE_H
types.cpp
@@ -16,9 +16,21 @@ SubAck::SubAck(uint16_t packet_id, const std::list&lt;std::string&gt; &amp;subs) : @@ -16,9 +16,21 @@ SubAck::SubAck(uint16_t packet_id, const std::list&lt;std::string&gt; &amp;subs) :
16 } 16 }
17 } 17 }
18 18
19 -Publish::Publish(std::string &topic, std::string payload) : 19 +Publish::Publish(const std::string &topic, const std::string payload, char qos) :
20 topic(topic), 20 topic(topic),
21 - payload(payload) 21 + payload(payload),
  22 + qos(qos)
22 { 23 {
23 24
24 } 25 }
  26 +
  27 +// Length starting at the variable header, not the fixed header.
  28 +size_t Publish::getLength() const
  29 +{
  30 + int result = topic.length() + payload.length() + 2;
  31 +
  32 + if (qos)
  33 + result += 2;
  34 +
  35 + return result;
  36 +}
@@ -48,6 +48,7 @@ class ConnAck @@ -48,6 +48,7 @@ class ConnAck
48 public: 48 public:
49 ConnAck(ConnAckReturnCodes return_code); 49 ConnAck(ConnAckReturnCodes return_code);
50 ConnAckReturnCodes return_code; 50 ConnAckReturnCodes return_code;
  51 + size_t getLength() const { return 2;} // size of connack is always the same
51 }; 52 };
52 53
53 enum class SubAckReturnCodes 54 enum class SubAckReturnCodes
@@ -71,7 +72,10 @@ class Publish @@ -71,7 +72,10 @@ class Publish
71 public: 72 public:
72 std::string topic; 73 std::string topic;
73 std::string payload; 74 std::string payload;
74 - Publish(std::string &topic, std::string payload); 75 + char qos = 0;
  76 + bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9]
  77 + Publish(const std::string &topic, const std::string payload, char qos);
  78 + size_t getLength() const;
75 }; 79 };
76 80
77 #endif // TYPES_H 81 #endif // TYPES_H