Commit e69a2c35b602ee09306e86658623fe911e24518a

Authored by Wiebe Cazemier
1 parent 163563e1

Have the MqttPacket constructor set client specific properties

This prevents bugs because the calling context forgets it. A (small)
downside is that I have to make the Publish argument non-const. But,
that's exactly what it is then, so...
mqttpacket.cpp
@@ -110,19 +110,26 @@ MqttPacket::MqttPacket(const UnsubAck &unsubAck) : @@ -110,19 +110,26 @@ MqttPacket::MqttPacket(const UnsubAck &unsubAck) :
110 calculateRemainingLength(); 110 calculateRemainingLength();
111 } 111 }
112 112
113 -size_t MqttPacket::getRequiredSizeForPublish(const ProtocolVersion protocolVersion, const Publish &publish) const 113 +size_t MqttPacket::setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publish) const
114 { 114 {
115 size_t result = publish.getLengthWithoutFixedHeader(); 115 size_t result = publish.getLengthWithoutFixedHeader();
116 if (protocolVersion >= ProtocolVersion::Mqtt5) 116 if (protocolVersion >= ProtocolVersion::Mqtt5)
117 { 117 {
  118 + publish.setClientSpecificProperties();
  119 +
118 const size_t proplen = publish.propertyBuilder ? publish.propertyBuilder->getLength() : 1; 120 const size_t proplen = publish.propertyBuilder ? publish.propertyBuilder->getLength() : 1;
119 result += proplen; 121 result += proplen;
120 } 122 }
121 return result; 123 return result;
122 } 124 }
123 125
124 -MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish) :  
125 - bites(getRequiredSizeForPublish(protocolVersion, _publish)) 126 +/**
  127 + * @brief Construct a packet for a specific protocol version.
  128 + * @param protocolVersion is required here, and not on the Publish object, because publishes don't have a protocol until they are for a specific client.
  129 + * @param _publish
  130 + */
  131 +MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish) :
  132 + bites(setClientSpecificPropertiesAndGetRequiredSizeForPublish(protocolVersion, _publish))
126 { 133 {
127 if (_publish.topic.length() > 0xFFFF) 134 if (_publish.topic.length() > 0xFFFF)
128 { 135 {
mqttpacket.h
@@ -92,13 +92,13 @@ public: @@ -92,13 +92,13 @@ public:
92 92
93 MqttPacket(MqttPacket &&other) = default; 93 MqttPacket(MqttPacket &&other) = default;
94 94
95 - size_t getRequiredSizeForPublish(const ProtocolVersion protocolVersion, const Publish &publishData) const; 95 + size_t setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publishData) const;
96 96
97 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. 97 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance.
98 MqttPacket(const ConnAck &connAck); 98 MqttPacket(const ConnAck &connAck);
99 MqttPacket(const SubAck &subAck); 99 MqttPacket(const SubAck &subAck);
100 MqttPacket(const UnsubAck &unsubAck); 100 MqttPacket(const UnsubAck &unsubAck);
101 - MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish); 101 + MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish);
102 MqttPacket(const PubResponse &pubAck); 102 MqttPacket(const PubResponse &pubAck);
103 103
104 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); 104 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
publishcopyfactory.cpp
@@ -28,7 +28,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto @@ -28,7 +28,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto
28 newPublish.qos = max_qos; 28 newPublish.qos = max_qos;
29 newPublish.topicAlias = topic_alias; 29 newPublish.topicAlias = topic_alias;
30 newPublish.skipTopic = skip_topic; 30 newPublish.skipTopic = skip_topic;
31 - newPublish.setClientSpecificProperties();  
32 this->oneShotPacket = std::make_unique<MqttPacket>(protocolVersion, newPublish); 31 this->oneShotPacket = std::make_unique<MqttPacket>(protocolVersion, newPublish);
33 return this->oneShotPacket.get(); 32 return this->oneShotPacket.get();
34 } 33 }
@@ -47,8 +46,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto @@ -47,8 +46,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto
47 Publish newPublish(packet->getPublishData()); 46 Publish newPublish(packet->getPublishData());
48 newPublish.splitTopic = false; 47 newPublish.splitTopic = false;
49 newPublish.qos = max_qos; 48 newPublish.qos = max_qos;
50 - if (protocolVersion >= ProtocolVersion::Mqtt5)  
51 - newPublish.setClientSpecificProperties();  
52 cachedPack = std::make_unique<MqttPacket>(protocolVersion, newPublish); 49 cachedPack = std::make_unique<MqttPacket>(protocolVersion, newPublish);
53 } 50 }
54 51
@@ -58,8 +55,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto @@ -58,8 +55,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto
58 // Getting a packet of a Publish object happens on will messages and SYS topics and maybe some others. It's low traffic, anyway. 55 // Getting a packet of a Publish object happens on will messages and SYS topics and maybe some others. It's low traffic, anyway.
59 assert(publish); 56 assert(publish);
60 57
61 - if (protocolVersion >= ProtocolVersion::Mqtt5)  
62 - publish->setClientSpecificProperties();  
63 this->oneShotPacket = std::make_unique<MqttPacket>(protocolVersion, *publish); 58 this->oneShotPacket = std::make_unique<MqttPacket>(protocolVersion, *publish);
64 return this->oneShotPacket.get(); 59 return this->oneShotPacket.get();
65 } 60 }
qospacketqueue.cpp
@@ -16,7 +16,7 @@ uint16_t QueuedPublish::getPacketId() const @@ -16,7 +16,7 @@ uint16_t QueuedPublish::getPacketId() const
16 return this->packet_id; 16 return this->packet_id;
17 } 17 }
18 18
19 -const Publish &QueuedPublish::getPublish() const 19 +Publish &QueuedPublish::getPublish()
20 { 20 {
21 return publish; 21 return publish;
22 } 22 }
@@ -51,7 +51,7 @@ void QoSPublishQueue::erase(const uint16_t packet_id) @@ -51,7 +51,7 @@ void QoSPublishQueue::erase(const uint16_t packet_id)
51 } 51 }
52 } 52 }
53 53
54 -std::list<QueuedPublish>::const_iterator QoSPublishQueue::erase(std::list<QueuedPublish>::const_iterator pos) 54 +std::list<QueuedPublish>::iterator QoSPublishQueue::erase(std::list<QueuedPublish>::iterator pos)
55 { 55 {
56 return this->queue.erase(pos); 56 return this->queue.erase(pos);
57 } 57 }
@@ -86,12 +86,12 @@ void QoSPublishQueue::queuePublish(Publish &amp;&amp;pub, uint16_t id) @@ -86,12 +86,12 @@ void QoSPublishQueue::queuePublish(Publish &amp;&amp;pub, uint16_t id)
86 qosQueueBytes += queue.back().getApproximateMemoryFootprint(); 86 qosQueueBytes += queue.back().getApproximateMemoryFootprint();
87 } 87 }
88 88
89 -std::list<QueuedPublish>::const_iterator QoSPublishQueue::begin() const 89 +std::list<QueuedPublish>::iterator QoSPublishQueue::begin()
90 { 90 {
91 return queue.begin(); 91 return queue.begin();
92 } 92 }
93 93
94 -std::list<QueuedPublish>::const_iterator QoSPublishQueue::end() const 94 +std::list<QueuedPublish>::iterator QoSPublishQueue::end()
95 { 95 {
96 return queue.end(); 96 return queue.end();
97 } 97 }
qospacketqueue.h
@@ -21,7 +21,7 @@ public: @@ -21,7 +21,7 @@ public:
21 21
22 size_t getApproximateMemoryFootprint() const; 22 size_t getApproximateMemoryFootprint() const;
23 uint16_t getPacketId() const; 23 uint16_t getPacketId() const;
24 - const Publish &getPublish() const; 24 + Publish &getPublish();
25 }; 25 };
26 26
27 class QoSPublishQueue 27 class QoSPublishQueue
@@ -31,14 +31,14 @@ class QoSPublishQueue @@ -31,14 +31,14 @@ class QoSPublishQueue
31 31
32 public: 32 public:
33 void erase(const uint16_t packet_id); 33 void erase(const uint16_t packet_id);
34 - std::list<QueuedPublish>::const_iterator erase(std::list<QueuedPublish>::const_iterator pos); 34 + std::list<QueuedPublish>::iterator erase(std::list<QueuedPublish>::iterator pos);
35 size_t size() const; 35 size_t size() const;
36 size_t getByteSize() const; 36 size_t getByteSize() const;
37 void queuePublish(PublishCopyFactory &copyFactory, uint16_t id, char new_max_qos); 37 void queuePublish(PublishCopyFactory &copyFactory, uint16_t id, char new_max_qos);
38 void queuePublish(Publish &&pub, uint16_t id); 38 void queuePublish(Publish &&pub, uint16_t id);
39 39
40 - std::list<QueuedPublish>::const_iterator begin() const;  
41 - std::list<QueuedPublish>::const_iterator end() const; 40 + std::list<QueuedPublish>::iterator begin();
  41 + std::list<QueuedPublish>::iterator end();
42 }; 42 };
43 43
44 #endif // QOSPACKETQUEUE_H 44 #endif // QOSPACKETQUEUE_H
session.cpp
@@ -238,15 +238,16 @@ uint64_t Session::sendPendingQosMessages() @@ -238,15 +238,16 @@ uint64_t Session::sendPendingQosMessages()
238 auto pos = qosPacketQueue.begin(); 238 auto pos = qosPacketQueue.begin();
239 while (pos != qosPacketQueue.end()) 239 while (pos != qosPacketQueue.end())
240 { 240 {
241 - const QueuedPublish &queuedPublish = *pos; 241 + QueuedPublish &queuedPublish = *pos;
  242 + Publish &pub = queuedPublish.getPublish();
242 243
243 - if (queuedPublish.getPublish().hasExpired()) 244 + if (pub.hasExpired())
244 { 245 {
245 pos = qosPacketQueue.erase(pos); 246 pos = qosPacketQueue.erase(pos);
246 continue; 247 continue;
247 } 248 }
248 249
249 - MqttPacket p(c->getProtocolVersion(), queuedPublish.getPublish()); 250 + MqttPacket p(c->getProtocolVersion(), pub);
250 p.setDuplicate(); 251 p.setDuplicate();
251 252
252 count += c->writeMqttPacketAndBlameThisClient(p); 253 count += c->writeMqttPacketAndBlameThisClient(p);
sessionsandsubscriptionsdb.cpp
@@ -235,11 +235,11 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess @@ -235,11 +235,11 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess
235 size_t qosPacketsCounted = 0; 235 size_t qosPacketsCounted = 0;
236 writeUint32(qosPacketsExpected); 236 writeUint32(qosPacketsExpected);
237 237
238 - for (const QueuedPublish &p: ses->qosPacketQueue) 238 + for (QueuedPublish &p: ses->qosPacketQueue)
239 { 239 {
240 qosPacketsCounted++; 240 qosPacketsCounted++;
241 241
242 - const Publish &pub = p.getPublish(); 242 + Publish &pub = p.getPublish();
243 243
244 assert(!pub.splitTopic); 244 assert(!pub.splitTopic);
245 assert(!pub.skipTopic); 245 assert(!pub.skipTopic);
@@ -247,6 +247,8 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess @@ -247,6 +247,8 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess
247 247
248 logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); 248 logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
249 249
  250 + pub.clearClientSpecificProperties();
  251 +
250 MqttPacket pack(ProtocolVersion::Mqtt5, pub); 252 MqttPacket pack(ProtocolVersion::Mqtt5, pub);
251 pack.setPacketId(p.getPacketId()); 253 pack.setPacketId(p.getPacketId());
252 const uint32_t packSize = pack.getSizeIncludingNonPresentHeader(); 254 const uint32_t packSize = pack.getSizeIncludingNonPresentHeader();
types.cpp
@@ -157,6 +157,12 @@ void PublishBase::setClientSpecificProperties() @@ -157,6 +157,12 @@ void PublishBase::setClientSpecificProperties()
157 propertyBuilder->writeTopicAlias(this->topicAlias); 157 propertyBuilder->writeTopicAlias(this->topicAlias);
158 } 158 }
159 159
  160 +void PublishBase::clearClientSpecificProperties()
  161 +{
  162 + if (propertyBuilder)
  163 + propertyBuilder->clearClientSpecificBytes();
  164 +}
  165 +
160 void PublishBase::constructPropertyBuilder() 166 void PublishBase::constructPropertyBuilder()
161 { 167 {
162 if (this->propertyBuilder) 168 if (this->propertyBuilder)
@@ -209,6 +209,7 @@ public: @@ -209,6 +209,7 @@ public:
209 PublishBase(const std::string &topic, const std::string &payload, char qos); 209 PublishBase(const std::string &topic, const std::string &payload, char qos);
210 size_t getLengthWithoutFixedHeader() const; 210 size_t getLengthWithoutFixedHeader() const;
211 void setClientSpecificProperties(); 211 void setClientSpecificProperties();
  212 + void clearClientSpecificProperties();
212 void constructPropertyBuilder(); 213 void constructPropertyBuilder();
213 bool hasUserProperties() const; 214 bool hasUserProperties() const;
214 bool hasExpired() const; 215 bool hasExpired() const;