Commit e69a2c35b602ee09306e86658623fe911e24518a

Authored by Wiebe Cazemier
1 parent 163563e1

Have the MqttPacket constructor set client specific properties

This prevents bugs because the calling context forgets it. A (small)
downside is that I have to make the Publish argument non-const. But,
that's exactly what it is then, so...
mqttpacket.cpp
... ... @@ -110,19 +110,26 @@ MqttPacket::MqttPacket(const UnsubAck &unsubAck) :
110 110 calculateRemainingLength();
111 111 }
112 112  
113   -size_t MqttPacket::getRequiredSizeForPublish(const ProtocolVersion protocolVersion, const Publish &publish) const
  113 +size_t MqttPacket::setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publish) const
114 114 {
115 115 size_t result = publish.getLengthWithoutFixedHeader();
116 116 if (protocolVersion >= ProtocolVersion::Mqtt5)
117 117 {
  118 + publish.setClientSpecificProperties();
  119 +
118 120 const size_t proplen = publish.propertyBuilder ? publish.propertyBuilder->getLength() : 1;
119 121 result += proplen;
120 122 }
121 123 return result;
122 124 }
123 125  
124   -MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish) :
125   - bites(getRequiredSizeForPublish(protocolVersion, _publish))
  126 +/**
  127 + * @brief Construct a packet for a specific protocol version.
  128 + * @param protocolVersion is required here, and not on the Publish object, because publishes don't have a protocol until they are for a specific client.
  129 + * @param _publish
  130 + */
  131 +MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish) :
  132 + bites(setClientSpecificPropertiesAndGetRequiredSizeForPublish(protocolVersion, _publish))
126 133 {
127 134 if (_publish.topic.length() > 0xFFFF)
128 135 {
... ...
mqttpacket.h
... ... @@ -92,13 +92,13 @@ public:
92 92  
93 93 MqttPacket(MqttPacket &&other) = default;
94 94  
95   - size_t getRequiredSizeForPublish(const ProtocolVersion protocolVersion, const Publish &publishData) const;
  95 + size_t setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publishData) const;
96 96  
97 97 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance.
98 98 MqttPacket(const ConnAck &connAck);
99 99 MqttPacket(const SubAck &subAck);
100 100 MqttPacket(const UnsubAck &unsubAck);
101   - MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish);
  101 + MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish);
102 102 MqttPacket(const PubResponse &pubAck);
103 103  
104 104 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
... ...
publishcopyfactory.cpp
... ... @@ -28,7 +28,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto
28 28 newPublish.qos = max_qos;
29 29 newPublish.topicAlias = topic_alias;
30 30 newPublish.skipTopic = skip_topic;
31   - newPublish.setClientSpecificProperties();
32 31 this->oneShotPacket = std::make_unique<MqttPacket>(protocolVersion, newPublish);
33 32 return this->oneShotPacket.get();
34 33 }
... ... @@ -47,8 +46,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto
47 46 Publish newPublish(packet->getPublishData());
48 47 newPublish.splitTopic = false;
49 48 newPublish.qos = max_qos;
50   - if (protocolVersion >= ProtocolVersion::Mqtt5)
51   - newPublish.setClientSpecificProperties();
52 49 cachedPack = std::make_unique<MqttPacket>(protocolVersion, newPublish);
53 50 }
54 51  
... ... @@ -58,8 +55,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto
58 55 // Getting a packet of a Publish object happens on will messages and SYS topics and maybe some others. It's low traffic, anyway.
59 56 assert(publish);
60 57  
61   - if (protocolVersion >= ProtocolVersion::Mqtt5)
62   - publish->setClientSpecificProperties();
63 58 this->oneShotPacket = std::make_unique<MqttPacket>(protocolVersion, *publish);
64 59 return this->oneShotPacket.get();
65 60 }
... ...
qospacketqueue.cpp
... ... @@ -16,7 +16,7 @@ uint16_t QueuedPublish::getPacketId() const
16 16 return this->packet_id;
17 17 }
18 18  
19   -const Publish &QueuedPublish::getPublish() const
  19 +Publish &QueuedPublish::getPublish()
20 20 {
21 21 return publish;
22 22 }
... ... @@ -51,7 +51,7 @@ void QoSPublishQueue::erase(const uint16_t packet_id)
51 51 }
52 52 }
53 53  
54   -std::list<QueuedPublish>::const_iterator QoSPublishQueue::erase(std::list<QueuedPublish>::const_iterator pos)
  54 +std::list<QueuedPublish>::iterator QoSPublishQueue::erase(std::list<QueuedPublish>::iterator pos)
55 55 {
56 56 return this->queue.erase(pos);
57 57 }
... ... @@ -86,12 +86,12 @@ void QoSPublishQueue::queuePublish(Publish &amp;&amp;pub, uint16_t id)
86 86 qosQueueBytes += queue.back().getApproximateMemoryFootprint();
87 87 }
88 88  
89   -std::list<QueuedPublish>::const_iterator QoSPublishQueue::begin() const
  89 +std::list<QueuedPublish>::iterator QoSPublishQueue::begin()
90 90 {
91 91 return queue.begin();
92 92 }
93 93  
94   -std::list<QueuedPublish>::const_iterator QoSPublishQueue::end() const
  94 +std::list<QueuedPublish>::iterator QoSPublishQueue::end()
95 95 {
96 96 return queue.end();
97 97 }
... ...
qospacketqueue.h
... ... @@ -21,7 +21,7 @@ public:
21 21  
22 22 size_t getApproximateMemoryFootprint() const;
23 23 uint16_t getPacketId() const;
24   - const Publish &getPublish() const;
  24 + Publish &getPublish();
25 25 };
26 26  
27 27 class QoSPublishQueue
... ... @@ -31,14 +31,14 @@ class QoSPublishQueue
31 31  
32 32 public:
33 33 void erase(const uint16_t packet_id);
34   - std::list<QueuedPublish>::const_iterator erase(std::list<QueuedPublish>::const_iterator pos);
  34 + std::list<QueuedPublish>::iterator erase(std::list<QueuedPublish>::iterator pos);
35 35 size_t size() const;
36 36 size_t getByteSize() const;
37 37 void queuePublish(PublishCopyFactory &copyFactory, uint16_t id, char new_max_qos);
38 38 void queuePublish(Publish &&pub, uint16_t id);
39 39  
40   - std::list<QueuedPublish>::const_iterator begin() const;
41   - std::list<QueuedPublish>::const_iterator end() const;
  40 + std::list<QueuedPublish>::iterator begin();
  41 + std::list<QueuedPublish>::iterator end();
42 42 };
43 43  
44 44 #endif // QOSPACKETQUEUE_H
... ...
session.cpp
... ... @@ -238,15 +238,16 @@ uint64_t Session::sendPendingQosMessages()
238 238 auto pos = qosPacketQueue.begin();
239 239 while (pos != qosPacketQueue.end())
240 240 {
241   - const QueuedPublish &queuedPublish = *pos;
  241 + QueuedPublish &queuedPublish = *pos;
  242 + Publish &pub = queuedPublish.getPublish();
242 243  
243   - if (queuedPublish.getPublish().hasExpired())
  244 + if (pub.hasExpired())
244 245 {
245 246 pos = qosPacketQueue.erase(pos);
246 247 continue;
247 248 }
248 249  
249   - MqttPacket p(c->getProtocolVersion(), queuedPublish.getPublish());
  250 + MqttPacket p(c->getProtocolVersion(), pub);
250 251 p.setDuplicate();
251 252  
252 253 count += c->writeMqttPacketAndBlameThisClient(p);
... ...
sessionsandsubscriptionsdb.cpp
... ... @@ -235,11 +235,11 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess
235 235 size_t qosPacketsCounted = 0;
236 236 writeUint32(qosPacketsExpected);
237 237  
238   - for (const QueuedPublish &p: ses->qosPacketQueue)
  238 + for (QueuedPublish &p: ses->qosPacketQueue)
239 239 {
240 240 qosPacketsCounted++;
241 241  
242   - const Publish &pub = p.getPublish();
  242 + Publish &pub = p.getPublish();
243 243  
244 244 assert(!pub.splitTopic);
245 245 assert(!pub.skipTopic);
... ... @@ -247,6 +247,8 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess
247 247  
248 248 logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
249 249  
  250 + pub.clearClientSpecificProperties();
  251 +
250 252 MqttPacket pack(ProtocolVersion::Mqtt5, pub);
251 253 pack.setPacketId(p.getPacketId());
252 254 const uint32_t packSize = pack.getSizeIncludingNonPresentHeader();
... ...
types.cpp
... ... @@ -157,6 +157,12 @@ void PublishBase::setClientSpecificProperties()
157 157 propertyBuilder->writeTopicAlias(this->topicAlias);
158 158 }
159 159  
  160 +void PublishBase::clearClientSpecificProperties()
  161 +{
  162 + if (propertyBuilder)
  163 + propertyBuilder->clearClientSpecificBytes();
  164 +}
  165 +
160 166 void PublishBase::constructPropertyBuilder()
161 167 {
162 168 if (this->propertyBuilder)
... ...
... ... @@ -209,6 +209,7 @@ public:
209 209 PublishBase(const std::string &topic, const std::string &payload, char qos);
210 210 size_t getLengthWithoutFixedHeader() const;
211 211 void setClientSpecificProperties();
  212 + void clearClientSpecificProperties();
212 213 void constructPropertyBuilder();
213 214 bool hasUserProperties() const;
214 215 bool hasExpired() const;
... ...