Commit 70896f68dcd8390aa52433f1472b36e9b9f61f4b
1 parent
e3ad153c
Start of retained messages
Also materializes some concepts about MqttPacket.
Showing
9 changed files
with
215 additions
and
24 deletions
client.cpp
| ... | ... | @@ -83,11 +83,22 @@ void Client::writeMqttPacket(const MqttPacket &packet) |
| 83 | 83 | if (packet.packetType == PacketType::PUBLISH && wwi > CLIENT_MAX_BUFFER_SIZE) |
| 84 | 84 | return; |
| 85 | 85 | |
| 86 | - if (packet.getSize() > getWriteBufMaxWriteSize()) | |
| 87 | - growWriteBuffer(packet.getSize()); | |
| 86 | + if (packet.getSizeIncludingNonPresentHeader() > getWriteBufMaxWriteSize()) | |
| 87 | + growWriteBuffer(packet.getSizeIncludingNonPresentHeader()); | |
| 88 | 88 | |
| 89 | - std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getSize()); | |
| 90 | - wwi += packet.getSize(); | |
| 89 | + if (!packet.containsFixedHeader()) | |
| 90 | + { | |
| 91 | + writebuf[wwi++] = packet.getFirstByte(); | |
| 92 | + RemainingLength r = packet.getRemainingLength(); | |
| 93 | + std::memcpy(&writebuf[wwi], r.bytes, r.len); | |
| 94 | + wwi += r.len; | |
| 95 | + } | |
| 96 | + | |
| 97 | + std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getBites().size()); | |
| 98 | + wwi += packet.getBites().size(); | |
| 99 | + | |
| 100 | + assert(wwi >= static_cast<int>(packet.getSizeIncludingNonPresentHeader())); | |
| 101 | + assert(wwi <= static_cast<int>(writeBufsize)); | |
| 91 | 102 | |
| 92 | 103 | setReadyForWriting(true); |
| 93 | 104 | } | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -2,6 +2,12 @@ |
| 2 | 2 | #include <cstring> |
| 3 | 3 | #include <iostream> |
| 4 | 4 | #include <list> |
| 5 | +#include <cassert> | |
| 6 | + | |
| 7 | +RemainingLength::RemainingLength() | |
| 8 | +{ | |
| 9 | + memset(bytes, 0, 4); | |
| 10 | +} | |
| 5 | 11 | |
| 6 | 12 | // constructor for parsing incoming packets |
| 7 | 13 | MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) : |
| ... | ... | @@ -16,9 +22,11 @@ MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client |
| 16 | 22 | pos += fixed_header_length; |
| 17 | 23 | } |
| 18 | 24 | |
| 25 | +// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. | |
| 19 | 26 | MqttPacket::MqttPacket(const ConnAck &connAck) : |
| 20 | - bites(4) | |
| 27 | + bites(connAck.getLength() + 2) | |
| 21 | 28 | { |
| 29 | + fixed_header_length = 2; | |
| 22 | 30 | packetType = PacketType::CONNACK; |
| 23 | 31 | char first_byte = static_cast<char>(packetType) << 4; |
| 24 | 32 | writeByte(first_byte); |
| ... | ... | @@ -31,6 +39,7 @@ MqttPacket::MqttPacket(const ConnAck &connAck) : |
| 31 | 39 | MqttPacket::MqttPacket(const SubAck &subAck) : |
| 32 | 40 | bites(3) |
| 33 | 41 | { |
| 42 | + fixed_header_length = 2; // TODO: this is wrong, pending implementation of the new method in SubAck | |
| 34 | 43 | packetType = PacketType::SUBACK; |
| 35 | 44 | char first_byte = static_cast<char>(packetType) << 4; |
| 36 | 45 | writeByte(first_byte); |
| ... | ... | @@ -48,11 +57,17 @@ MqttPacket::MqttPacket(const SubAck &subAck) : |
| 48 | 57 | } |
| 49 | 58 | |
| 50 | 59 | MqttPacket::MqttPacket(const Publish &publish) : |
| 51 | - bites(publish.topic.length() + publish.payload.length() + 2 + 2) // TODO: same as above | |
| 60 | + bites(publish.getLength()) | |
| 52 | 61 | { |
| 62 | + if (publish.topic.length() > 0xFFFF) | |
| 63 | + { | |
| 64 | + throw ProtocolError("Topic path too long."); | |
| 65 | + } | |
| 66 | + | |
| 53 | 67 | packetType = PacketType::PUBLISH; |
| 54 | - char first_byte = static_cast<char>(packetType) << 4; | |
| 55 | - writeByte(first_byte); | |
| 68 | + first_byte = static_cast<char>(packetType) << 4; | |
| 69 | + first_byte |= (publish.qos << 1); | |
| 70 | + first_byte |= (static_cast<char>(publish.retain) & 0b00000001); | |
| 56 | 71 | |
| 57 | 72 | char topicLenMSB = (publish.topic.length() & 0xF0) >> 8; |
| 58 | 73 | char topicLenLSB = publish.topic.length() & 0x0F; |
| ... | ... | @@ -60,9 +75,13 @@ MqttPacket::MqttPacket(const Publish &publish) : |
| 60 | 75 | writeByte(topicLenLSB); |
| 61 | 76 | writeBytes(publish.topic.c_str(), publish.topic.length()); |
| 62 | 77 | |
| 63 | - writeBytes(publish.payload.c_str(), publish.payload.length()); | |
| 78 | + if (publish.qos) | |
| 79 | + { | |
| 80 | + throw NotImplementedException("I would write two bytes containing the packet id here, but QoS is not done yet."); | |
| 81 | + } | |
| 64 | 82 | |
| 65 | - // TODO: untested. May be unnecessary, because a received packet can be resent just like that. | |
| 83 | + writeBytes(publish.payload.c_str(), publish.payload.length()); | |
| 84 | + calculateRemainingLength(); | |
| 66 | 85 | } |
| 67 | 86 | |
| 68 | 87 | void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) |
| ... | ... | @@ -201,23 +220,77 @@ void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionS |
| 201 | 220 | if (variable_header_length == 0) |
| 202 | 221 | throw ProtocolError("Empty publish topic"); |
| 203 | 222 | |
| 223 | + bool retain = (first_byte & 0b00000001); | |
| 224 | + bool dup = !!(first_byte & 0b00001000); | |
| 204 | 225 | char qos = (first_byte & 0b00000110) >> 1; |
| 205 | 226 | |
| 227 | + if (qos == 3) | |
| 228 | + throw ProtocolError("QoS 3 is a protocol violation."); | |
| 229 | + | |
| 230 | + // TODO: validate UTF8. | |
| 206 | 231 | std::string topic(readBytes(variable_header_length), variable_header_length); |
| 207 | 232 | |
| 208 | 233 | if (qos) |
| 209 | 234 | { |
| 210 | 235 | throw ProtocolError("Qos not implemented."); |
| 211 | - //uint16_t packet_id = readTwoBytesToUInt16(); | |
| 236 | + uint16_t packet_id = readTwoBytesToUInt16(); | |
| 212 | 237 | } |
| 213 | 238 | |
| 214 | - // TODO: validate UTF8. | |
| 215 | - size_t payload_length = remainingAfterPos(); | |
| 216 | - std::string payload(readBytes(payload_length), payload_length); | |
| 239 | + if (retain) | |
| 240 | + { | |
| 241 | + size_t payload_length = remainingAfterPos(); | |
| 242 | + std::string payload(readBytes(payload_length), payload_length); | |
| 217 | 243 | |
| 244 | + subscriptionStore->setRetainedMessage(topic, payload, qos); | |
| 245 | + } | |
| 246 | + | |
| 247 | + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. | |
| 248 | + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9] | |
| 249 | + bites[0] &= 0b11110110; | |
| 250 | + | |
| 251 | + // For the existing clients, we can just write the same packet back out, with our small alterations. | |
| 218 | 252 | subscriptionStore->queuePacketAtSubscribers(topic, *this, sender); |
| 219 | 253 | } |
| 220 | 254 | |
| 255 | +void MqttPacket::calculateRemainingLength() | |
| 256 | +{ | |
| 257 | + assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. | |
| 258 | + | |
| 259 | + size_t x = bites.size(); | |
| 260 | + | |
| 261 | + do | |
| 262 | + { | |
| 263 | + if (remainingLength.len > 4) | |
| 264 | + throw std::runtime_error("Calculated remaining length is longer than 4 bytes."); | |
| 265 | + | |
| 266 | + char encodedByte = x % 128; | |
| 267 | + x = x / 128; | |
| 268 | + if (x > 0) | |
| 269 | + encodedByte = encodedByte | 128; | |
| 270 | + remainingLength.bytes[remainingLength.len++] = encodedByte; | |
| 271 | + } | |
| 272 | + while(x > 0); | |
| 273 | +} | |
| 274 | + | |
| 275 | +RemainingLength MqttPacket::getRemainingLength() const | |
| 276 | +{ | |
| 277 | + assert(remainingLength.len > 0); | |
| 278 | + return remainingLength; | |
| 279 | +} | |
| 280 | + | |
| 281 | +size_t MqttPacket::getSizeIncludingNonPresentHeader() const | |
| 282 | +{ | |
| 283 | + size_t total = bites.size(); | |
| 284 | + | |
| 285 | + if (fixed_header_length == 0) | |
| 286 | + { | |
| 287 | + total++; | |
| 288 | + total += remainingLength.len; | |
| 289 | + } | |
| 290 | + | |
| 291 | + return total; | |
| 292 | +} | |
| 293 | + | |
| 221 | 294 | |
| 222 | 295 | Client_p MqttPacket::getSender() const |
| 223 | 296 | { |
| ... | ... | @@ -229,6 +302,16 @@ void MqttPacket::setSender(const Client_p &value) |
| 229 | 302 | sender = value; |
| 230 | 303 | } |
| 231 | 304 | |
| 305 | +bool MqttPacket::containsFixedHeader() const | |
| 306 | +{ | |
| 307 | + return fixed_header_length > 0; | |
| 308 | +} | |
| 309 | + | |
| 310 | +char MqttPacket::getFirstByte() const | |
| 311 | +{ | |
| 312 | + return first_byte; | |
| 313 | +} | |
| 314 | + | |
| 232 | 315 | char *MqttPacket::readBytes(size_t length) |
| 233 | 316 | { |
| 234 | 317 | if (pos + length > bites.size()) |
| ... | ... | @@ -262,6 +345,7 @@ void MqttPacket::writeBytes(const char *b, size_t len) |
| 262 | 345 | throw ProtocolError("Exceeding packet size"); |
| 263 | 346 | |
| 264 | 347 | memcpy(&bites[pos], b, len); |
| 348 | + pos += len; | |
| 265 | 349 | } |
| 266 | 350 | |
| 267 | 351 | uint16_t MqttPacket::readTwoBytesToUInt16() |
| ... | ... | @@ -287,3 +371,5 @@ size_t MqttPacket::remainingAfterPos() |
| 287 | 371 | |
| 288 | 372 | |
| 289 | 373 | |
| 374 | + | |
| 375 | + | ... | ... |
mqttpacket.h
| ... | ... | @@ -13,13 +13,21 @@ |
| 13 | 13 | #include "types.h" |
| 14 | 14 | #include "subscriptionstore.h" |
| 15 | 15 | |
| 16 | +struct RemainingLength | |
| 17 | +{ | |
| 18 | + char bytes[4]; | |
| 19 | + int len = 0; | |
| 20 | +public: | |
| 21 | + RemainingLength(); | |
| 22 | +}; | |
| 16 | 23 | |
| 17 | 24 | class MqttPacket |
| 18 | 25 | { |
| 19 | 26 | std::vector<char> bites; |
| 20 | - size_t fixed_header_length = 0; | |
| 27 | + size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. | |
| 28 | + RemainingLength remainingLength; | |
| 21 | 29 | Client_p sender; |
| 22 | - char first_byte; | |
| 30 | + char first_byte = 0; | |
| 23 | 31 | size_t pos = 0; |
| 24 | 32 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 25 | 33 | |
| ... | ... | @@ -30,12 +38,13 @@ class MqttPacket |
| 30 | 38 | uint16_t readTwoBytesToUInt16(); |
| 31 | 39 | size_t remainingAfterPos(); |
| 32 | 40 | |
| 41 | + void calculateRemainingLength(); | |
| 42 | + | |
| 33 | 43 | public: |
| 34 | 44 | PacketType packetType = PacketType::Reserved; |
| 35 | - MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); | |
| 45 | + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets. | |
| 36 | 46 | |
| 37 | - // TODO: not constructors, but static functions that return all the stuff after the fixed header, then a constructor with vector. | |
| 38 | - // Or, I can not have the fixed header, and calculate that on write-to-buf. | |
| 47 | + // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. | |
| 39 | 48 | MqttPacket(const ConnAck &connAck); |
| 40 | 49 | MqttPacket(const SubAck &subAck); |
| 41 | 50 | MqttPacket(const Publish &publish); |
| ... | ... | @@ -46,11 +55,15 @@ public: |
| 46 | 55 | void handlePing(); |
| 47 | 56 | void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore); |
| 48 | 57 | |
| 49 | - size_t getSize() const { return bites.size(); } | |
| 58 | + size_t getSizeIncludingNonPresentHeader() const; | |
| 50 | 59 | const std::vector<char> &getBites() const { return bites; } |
| 51 | 60 | |
| 52 | 61 | Client_p getSender() const; |
| 53 | 62 | void setSender(const Client_p &value); |
| 63 | + | |
| 64 | + bool containsFixedHeader() const; | |
| 65 | + char getFirstByte() const; | |
| 66 | + RemainingLength getRemainingLength() const; | |
| 54 | 67 | }; |
| 55 | 68 | |
| 56 | 69 | #endif // MQTTPACKET_H | ... | ... |
rwlockguard.cpp
| ... | ... | @@ -10,7 +10,7 @@ RWLockGuard::RWLockGuard(pthread_rwlock_t *rwlock) : |
| 10 | 10 | |
| 11 | 11 | RWLockGuard::~RWLockGuard() |
| 12 | 12 | { |
| 13 | - pthread_rwlock_unlock(rwlock); | |
| 13 | + unlock(); | |
| 14 | 14 | } |
| 15 | 15 | |
| 16 | 16 | void RWLockGuard::wrlock() |
| ... | ... | @@ -24,3 +24,12 @@ void RWLockGuard::rdlock() |
| 24 | 24 | if (pthread_rwlock_wrlock(rwlock) != 0) |
| 25 | 25 | throw std::runtime_error("rdlock failed."); |
| 26 | 26 | } |
| 27 | + | |
| 28 | +void RWLockGuard::unlock() | |
| 29 | +{ | |
| 30 | + if (rwlock != NULL) | |
| 31 | + { | |
| 32 | + pthread_rwlock_unlock(rwlock); | |
| 33 | + rwlock = NULL; | |
| 34 | + } | |
| 35 | +} | ... | ... |
rwlockguard.h
subscriptionstore.cpp
| ... | ... | @@ -44,6 +44,9 @@ void SubscriptionStore::addSubscription(Client_p &client, const std::string &top |
| 44 | 44 | } |
| 45 | 45 | |
| 46 | 46 | clients_by_id[client->getClientId()] = client; |
| 47 | + lock_guard.unlock(); | |
| 48 | + | |
| 49 | + giveClientRetainedMessage(client, topic); // TODO: wildcards | |
| 47 | 50 | } |
| 48 | 51 | |
| 49 | 52 | void SubscriptionStore::removeClient(const Client_p &client) |
| ... | ... | @@ -119,4 +122,44 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::string &topic, const |
| 119 | 122 | publishRecursively(subtopics.begin(), subtopics.end(), root, packet); |
| 120 | 123 | } |
| 121 | 124 | |
| 125 | +void SubscriptionStore::giveClientRetainedMessage(Client_p &client, const std::string &topic) | |
| 126 | +{ | |
| 127 | + RWLockGuard locker(&retainedMessagesRwlock); | |
| 128 | + locker.rdlock(); | |
| 129 | + | |
| 130 | + auto retained_ptr = retainedMessages.find(topic); | |
| 131 | + | |
| 132 | + if (retained_ptr == retainedMessages.end()) | |
| 133 | + return; | |
| 134 | + | |
| 135 | + const RetainedPayload &m = retained_ptr->second; | |
| 136 | + | |
| 137 | + Publish publish(topic, m.payload, m.qos); | |
| 138 | + publish.retain = true; | |
| 139 | + const MqttPacket packet(publish); | |
| 140 | + client->writeMqttPacket(packet); | |
| 141 | +} | |
| 142 | + | |
| 143 | +void SubscriptionStore::setRetainedMessage(const std::string &topic, const std::string &payload, char qos) | |
| 144 | +{ | |
| 145 | + RWLockGuard locker(&retainedMessagesRwlock); | |
| 146 | + locker.wrlock(); | |
| 147 | + | |
| 148 | + auto retained_ptr = retainedMessages.find(topic); | |
| 149 | + bool retained_found = retained_ptr != retainedMessages.end(); | |
| 150 | + | |
| 151 | + if (!retained_found && payload.empty()) | |
| 152 | + return; | |
| 153 | + | |
| 154 | + if (retained_found && payload.empty()) | |
| 155 | + { | |
| 156 | + retainedMessages.erase(topic); | |
| 157 | + return; | |
| 158 | + } | |
| 159 | + | |
| 160 | + RetainedPayload &m = retainedMessages[topic]; | |
| 161 | + m.payload = payload; | |
| 162 | + m.qos = qos; | |
| 163 | +} | |
| 164 | + | |
| 122 | 165 | ... | ... |
subscriptionstore.h
| ... | ... | @@ -12,6 +12,12 @@ |
| 12 | 12 | #include "client.h" |
| 13 | 13 | #include "utils.h" |
| 14 | 14 | |
| 15 | +struct RetainedPayload | |
| 16 | +{ | |
| 17 | + std::string payload; | |
| 18 | + char qos; | |
| 19 | +}; | |
| 20 | + | |
| 15 | 21 | class SubscriptionNode |
| 16 | 22 | { |
| 17 | 23 | std::string subtopic; |
| ... | ... | @@ -32,6 +38,9 @@ class SubscriptionStore |
| 32 | 38 | std::unordered_map<std::string, Client_p> clients_by_id; |
| 33 | 39 | const std::unordered_map<std::string, Client_p> &clients_by_id_const; |
| 34 | 40 | |
| 41 | + pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER; | |
| 42 | + std::unordered_map<std::string, RetainedPayload> retainedMessages; | |
| 43 | + | |
| 35 | 44 | bool publishNonRecursively(const MqttPacket &packet, const std::forward_list<std::string> &subscribers) const; |
| 36 | 45 | bool publishRecursively(std::list<std::string>::const_iterator cur_subtopic_it, std::list<std::string>::const_iterator end, |
| 37 | 46 | std::unique_ptr<SubscriptionNode> &next, const MqttPacket &packet) const; |
| ... | ... | @@ -42,6 +51,9 @@ public: |
| 42 | 51 | void removeClient(const Client_p &client); |
| 43 | 52 | |
| 44 | 53 | void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); |
| 54 | + void giveClientRetainedMessage(Client_p &client, const std::string &topic); | |
| 55 | + | |
| 56 | + void setRetainedMessage(const std::string &topic, const std::string &payload, char qos); | |
| 45 | 57 | }; |
| 46 | 58 | |
| 47 | 59 | #endif // SUBSCRIPTIONSTORE_H | ... | ... |
types.cpp
| ... | ... | @@ -16,9 +16,21 @@ SubAck::SubAck(uint16_t packet_id, const std::list<std::string> &subs) : |
| 16 | 16 | } |
| 17 | 17 | } |
| 18 | 18 | |
| 19 | -Publish::Publish(std::string &topic, std::string payload) : | |
| 19 | +Publish::Publish(const std::string &topic, const std::string payload, char qos) : | |
| 20 | 20 | topic(topic), |
| 21 | - payload(payload) | |
| 21 | + payload(payload), | |
| 22 | + qos(qos) | |
| 22 | 23 | { |
| 23 | 24 | |
| 24 | 25 | } |
| 26 | + | |
| 27 | +// Length starting at the variable header, not the fixed header. | |
| 28 | +size_t Publish::getLength() const | |
| 29 | +{ | |
| 30 | + int result = topic.length() + payload.length() + 2; | |
| 31 | + | |
| 32 | + if (qos) | |
| 33 | + result += 2; | |
| 34 | + | |
| 35 | + return result; | |
| 36 | +} | ... | ... |
types.h
| ... | ... | @@ -48,6 +48,7 @@ class ConnAck |
| 48 | 48 | public: |
| 49 | 49 | ConnAck(ConnAckReturnCodes return_code); |
| 50 | 50 | ConnAckReturnCodes return_code; |
| 51 | + size_t getLength() const { return 2;} // size of connack is always the same | |
| 51 | 52 | }; |
| 52 | 53 | |
| 53 | 54 | enum class SubAckReturnCodes |
| ... | ... | @@ -71,7 +72,10 @@ class Publish |
| 71 | 72 | public: |
| 72 | 73 | std::string topic; |
| 73 | 74 | std::string payload; |
| 74 | - Publish(std::string &topic, std::string payload); | |
| 75 | + char qos = 0; | |
| 76 | + bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] | |
| 77 | + Publish(const std::string &topic, const std::string payload, char qos); | |
| 78 | + size_t getLength() const; | |
| 75 | 79 | }; |
| 76 | 80 | |
| 77 | 81 | #endif // TYPES_H | ... | ... |