diff --git a/mqttpacket.cpp b/mqttpacket.cpp index f935412..6248322 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -110,19 +110,26 @@ MqttPacket::MqttPacket(const UnsubAck &unsubAck) : calculateRemainingLength(); } -size_t MqttPacket::getRequiredSizeForPublish(const ProtocolVersion protocolVersion, const Publish &publish) const +size_t MqttPacket::setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publish) const { size_t result = publish.getLengthWithoutFixedHeader(); if (protocolVersion >= ProtocolVersion::Mqtt5) { + publish.setClientSpecificProperties(); + const size_t proplen = publish.propertyBuilder ? publish.propertyBuilder->getLength() : 1; result += proplen; } return result; } -MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish) : - bites(getRequiredSizeForPublish(protocolVersion, _publish)) +/** + * @brief Construct a packet for a specific protocol version. + * @param protocolVersion is required here, and not on the Publish object, because publishes don't have a protocol until they are for a specific client. + * @param _publish + */ +MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish) : + bites(setClientSpecificPropertiesAndGetRequiredSizeForPublish(protocolVersion, _publish)) { if (_publish.topic.length() > 0xFFFF) { diff --git a/mqttpacket.h b/mqttpacket.h index 1b2795d..1116fd3 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -92,13 +92,13 @@ public: MqttPacket(MqttPacket &&other) = default; - size_t getRequiredSizeForPublish(const ProtocolVersion protocolVersion, const Publish &publishData) const; + size_t setClientSpecificPropertiesAndGetRequiredSizeForPublish(const ProtocolVersion protocolVersion, Publish &publishData) const; // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. MqttPacket(const ConnAck &connAck); MqttPacket(const SubAck &subAck); MqttPacket(const UnsubAck &unsubAck); - MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish); + MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish); MqttPacket(const PubResponse &pubAck); static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender); diff --git a/publishcopyfactory.cpp b/publishcopyfactory.cpp index 2e21bf0..df9c00f 100644 --- a/publishcopyfactory.cpp +++ b/publishcopyfactory.cpp @@ -28,7 +28,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto newPublish.qos = max_qos; newPublish.topicAlias = topic_alias; newPublish.skipTopic = skip_topic; - newPublish.setClientSpecificProperties(); this->oneShotPacket = std::make_unique(protocolVersion, newPublish); return this->oneShotPacket.get(); } @@ -47,8 +46,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto Publish newPublish(packet->getPublishData()); newPublish.splitTopic = false; newPublish.qos = max_qos; - if (protocolVersion >= ProtocolVersion::Mqtt5) - newPublish.setClientSpecificProperties(); cachedPack = std::make_unique(protocolVersion, newPublish); } @@ -58,8 +55,6 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const char max_qos, const Proto // Getting a packet of a Publish object happens on will messages and SYS topics and maybe some others. It's low traffic, anyway. assert(publish); - if (protocolVersion >= ProtocolVersion::Mqtt5) - publish->setClientSpecificProperties(); this->oneShotPacket = std::make_unique(protocolVersion, *publish); return this->oneShotPacket.get(); } diff --git a/qospacketqueue.cpp b/qospacketqueue.cpp index 5a21343..fffc3f1 100644 --- a/qospacketqueue.cpp +++ b/qospacketqueue.cpp @@ -16,7 +16,7 @@ uint16_t QueuedPublish::getPacketId() const return this->packet_id; } -const Publish &QueuedPublish::getPublish() const +Publish &QueuedPublish::getPublish() { return publish; } @@ -51,7 +51,7 @@ void QoSPublishQueue::erase(const uint16_t packet_id) } } -std::list::const_iterator QoSPublishQueue::erase(std::list::const_iterator pos) +std::list::iterator QoSPublishQueue::erase(std::list::iterator pos) { return this->queue.erase(pos); } @@ -86,12 +86,12 @@ void QoSPublishQueue::queuePublish(Publish &&pub, uint16_t id) qosQueueBytes += queue.back().getApproximateMemoryFootprint(); } -std::list::const_iterator QoSPublishQueue::begin() const +std::list::iterator QoSPublishQueue::begin() { return queue.begin(); } -std::list::const_iterator QoSPublishQueue::end() const +std::list::iterator QoSPublishQueue::end() { return queue.end(); } diff --git a/qospacketqueue.h b/qospacketqueue.h index 30d7805..febe263 100644 --- a/qospacketqueue.h +++ b/qospacketqueue.h @@ -21,7 +21,7 @@ public: size_t getApproximateMemoryFootprint() const; uint16_t getPacketId() const; - const Publish &getPublish() const; + Publish &getPublish(); }; class QoSPublishQueue @@ -31,14 +31,14 @@ class QoSPublishQueue public: void erase(const uint16_t packet_id); - std::list::const_iterator erase(std::list::const_iterator pos); + std::list::iterator erase(std::list::iterator pos); size_t size() const; size_t getByteSize() const; void queuePublish(PublishCopyFactory ©Factory, uint16_t id, char new_max_qos); void queuePublish(Publish &&pub, uint16_t id); - std::list::const_iterator begin() const; - std::list::const_iterator end() const; + std::list::iterator begin(); + std::list::iterator end(); }; #endif // QOSPACKETQUEUE_H diff --git a/session.cpp b/session.cpp index edd50bb..5cf9280 100644 --- a/session.cpp +++ b/session.cpp @@ -238,15 +238,16 @@ uint64_t Session::sendPendingQosMessages() auto pos = qosPacketQueue.begin(); while (pos != qosPacketQueue.end()) { - const QueuedPublish &queuedPublish = *pos; + QueuedPublish &queuedPublish = *pos; + Publish &pub = queuedPublish.getPublish(); - if (queuedPublish.getPublish().hasExpired()) + if (pub.hasExpired()) { pos = qosPacketQueue.erase(pos); continue; } - MqttPacket p(c->getProtocolVersion(), queuedPublish.getPublish()); + MqttPacket p(c->getProtocolVersion(), pub); p.setDuplicate(); count += c->writeMqttPacketAndBlameThisClient(p); diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp index 1d253e3..fa497f3 100644 --- a/sessionsandsubscriptionsdb.cpp +++ b/sessionsandsubscriptionsdb.cpp @@ -235,11 +235,11 @@ void SessionsAndSubscriptionsDB::saveData(const std::vectorqosPacketQueue) + for (QueuedPublish &p: ses->qosPacketQueue) { qosPacketsCounted++; - const Publish &pub = p.getPublish(); + Publish &pub = p.getPublish(); assert(!pub.splitTopic); assert(!pub.skipTopic); @@ -247,6 +247,8 @@ void SessionsAndSubscriptionsDB::saveData(const std::vectorlogf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); + pub.clearClientSpecificProperties(); + MqttPacket pack(ProtocolVersion::Mqtt5, pub); pack.setPacketId(p.getPacketId()); const uint32_t packSize = pack.getSizeIncludingNonPresentHeader(); diff --git a/types.cpp b/types.cpp index 6f77687..01fd761 100644 --- a/types.cpp +++ b/types.cpp @@ -157,6 +157,12 @@ void PublishBase::setClientSpecificProperties() propertyBuilder->writeTopicAlias(this->topicAlias); } +void PublishBase::clearClientSpecificProperties() +{ + if (propertyBuilder) + propertyBuilder->clearClientSpecificBytes(); +} + void PublishBase::constructPropertyBuilder() { if (this->propertyBuilder) diff --git a/types.h b/types.h index 680143c..7b2df7e 100644 --- a/types.h +++ b/types.h @@ -209,6 +209,7 @@ public: PublishBase(const std::string &topic, const std::string &payload, char qos); size_t getLengthWithoutFixedHeader() const; void setClientSpecificProperties(); + void clearClientSpecificProperties(); void constructPropertyBuilder(); bool hasUserProperties() const; bool hasExpired() const;