diff --git a/CMakeLists.txt b/CMakeLists.txt index 17ae16c..809a6ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,6 +58,7 @@ add_executable(FlashMQ qospacketqueue.h threadglobals.h threadloop.h + publishcopyfactory.h mainapp.cpp main.cpp @@ -95,6 +96,7 @@ add_executable(FlashMQ qospacketqueue.cpp threadglobals.cpp threadloop.cpp + publishcopyfactory.cpp ) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index 8fca6f2..74d7eff 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -49,6 +49,7 @@ SOURCES += tst_maintests.cpp \ ../qospacketqueue.cpp \ ../threadglobals.cpp \ ../threadloop.cpp \ + ../publishcopyfactory.cpp \ mainappthread.cpp \ twoclienttestcontext.cpp @@ -90,6 +91,7 @@ HEADERS += \ ../qospacketqueue.h \ ../threadglobals.h \ ../threadloop.h \ + ../publishcopyfactory.h \ mainappthread.h \ twoclienttestcontext.h diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index b751f44..11155eb 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -111,6 +111,8 @@ private slots: void testDowngradeQoSOnSubscribeQos1to0(); void testDowngradeQoSOnSubscribeQos0to0(); + void testNotMessingUpQosLevels(); + }; MainTests::MainTests() @@ -1037,8 +1039,8 @@ void MainTests::testSavingSessions() std::shared_ptr c1ses = c1->getSession(); c1.reset(); MqttPacket publishPacket(publish); - std::shared_ptr possibleQos0Copy; - c1ses->writePacket(publishPacket, 1, possibleQos0Copy, count); + PublishCopyFactory fac(publishPacket); + c1ses->writePacket(fac, 1, count); store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); @@ -1241,6 +1243,71 @@ void MainTests::testDowngradeQoSOnSubscribeQos0to0() testDowngradeQoSOnSubscribeHelper(0, 0); } +/** + * @brief MainTests::testNotMessingUpQosLevels was divised because we optimize by preventing packet copies. This entails changing the vector of the original + * incoming packet, resulting in possibly changing values like QoS levels for later subscribers. + */ +void MainTests::testNotMessingUpQosLevels() +{ + const QString topic = "HK7c1MFu6kdT69fWY"; + const QByteArray payload = "M4XK2LZ2Smaazba8RobZOgoe6CENxCll"; + + TwoClientTestContext testContextSender; + TwoClientTestContext testContextReceiver1(1); + TwoClientTestContext testContextReceiver2(2); + TwoClientTestContext testContextReceiver3(3); + TwoClientTestContext testContextReceiver4(4); + TwoClientTestContext testContextReceiver5(5); + + testContextReceiver1.connectReceiver(); + testContextReceiver1.subscribeReceiver(topic, 0); + + testContextReceiver2.connectReceiver(); + testContextReceiver2.subscribeReceiver(topic, 1); + + testContextReceiver3.connectReceiver(); + testContextReceiver3.subscribeReceiver(topic, 2); + + testContextReceiver4.connectReceiver(); + testContextReceiver4.subscribeReceiver(topic, 1); + + testContextReceiver5.connectReceiver(); + testContextReceiver5.subscribeReceiver(topic, 0); + + testContextSender.connectSender(); + testContextSender.publish(topic, payload, 2, false); + + testContextReceiver1.waitReceiverReceived(1); + testContextReceiver2.waitReceiverReceived(1); + testContextReceiver3.waitReceiverReceived(1); + testContextReceiver4.waitReceiverReceived(1); + testContextReceiver5.waitReceiverReceived(1); + + QCOMPARE(testContextReceiver1.receivedMessages.count(), 1); + QCOMPARE(testContextReceiver2.receivedMessages.count(), 1); + QCOMPARE(testContextReceiver3.receivedMessages.count(), 1); + QCOMPARE(testContextReceiver4.receivedMessages.count(), 1); + QCOMPARE(testContextReceiver5.receivedMessages.count(), 1); + + QCOMPARE(testContextReceiver1.receivedMessages.first().qos(), 0); + QCOMPARE(testContextReceiver2.receivedMessages.first().qos(), 1); + QCOMPARE(testContextReceiver3.receivedMessages.first().qos(), 2); + QCOMPARE(testContextReceiver4.receivedMessages.first().qos(), 1); + QCOMPARE(testContextReceiver5.receivedMessages.first().qos(), 0); + + QCOMPARE(testContextReceiver1.receivedMessages.first().payload(), payload); + QCOMPARE(testContextReceiver2.receivedMessages.first().payload(), payload); + QCOMPARE(testContextReceiver3.receivedMessages.first().payload(), payload); + QCOMPARE(testContextReceiver4.receivedMessages.first().payload(), payload); + QCOMPARE(testContextReceiver5.receivedMessages.first().payload(), payload); + + QCOMPARE(testContextReceiver1.receivedMessages.first().id(), 0); + QCOMPARE(testContextReceiver2.receivedMessages.first().id(), 1); + QCOMPARE(testContextReceiver3.receivedMessages.first().id(), 1); + QCOMPARE(testContextReceiver4.receivedMessages.first().id(), 1); + QCOMPARE(testContextReceiver5.receivedMessages.first().id(), 0); +} + int main(int argc, char *argv[]) { diff --git a/FlashMQTests/twoclienttestcontext.cpp b/FlashMQTests/twoclienttestcontext.cpp index 72458da..af2dc32 100644 --- a/FlashMQTests/twoclienttestcontext.cpp +++ b/FlashMQTests/twoclienttestcontext.cpp @@ -22,14 +22,14 @@ License along with FlashMQ. If not, see . // TODO: port to QMqttClient that newer Qts now have? -TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent) +TwoClientTestContext::TwoClientTestContext(int clientNr, QObject *parent) : QObject(parent) { QHostInfo targetHostInfo = QHostInfo::fromName("localhost"); QHostAddress targetHost(targetHostInfo.addresses().first()); sender.reset(new QMQTT::Client(targetHost)); - sender->setClientId("Sender"); + sender->setClientId(QString("Sender%1").arg(clientNr)); receiver.reset(new QMQTT::Client(targetHost)); - receiver->setClientId("Receiver"); + receiver->setClientId(QString("Receiver%1").arg(clientNr)); connect(sender.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); diff --git a/FlashMQTests/twoclienttestcontext.h b/FlashMQTests/twoclienttestcontext.h index d8852d1..ef61da8 100644 --- a/FlashMQTests/twoclienttestcontext.h +++ b/FlashMQTests/twoclienttestcontext.h @@ -33,7 +33,7 @@ private slots: void onReceiverReceived(const QMQTT::Message& message); public: - explicit TwoClientTestContext(QObject *parent = nullptr); + explicit TwoClientTestContext(int clientNr = 0, QObject *parent = nullptr); void publish(const QString &topic, const QByteArray &payload); void publish(const QString &topic, const QByteArray &payload, bool retain); void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain); diff --git a/client.cpp b/client.cpp index 4d29e93..8b8b37f 100644 --- a/client.cpp +++ b/client.cpp @@ -207,6 +207,21 @@ int Client::writeMqttPacket(const MqttPacket &packet) return 1; } +int Client::writeMqttPacketAndBlameThisClient(PublishCopyFactory ©Factory, char max_qos, uint16_t packet_id) +{ + MqttPacket &p = copyFactory.getOptimumPacket(max_qos); + + if (p.getQos() > 0) + { + // This may change the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere, + // that should be fine. + p.setPacketId(packet_id); + p.setQos(max_qos); + } + + return writeMqttPacketAndBlameThisClient(p); +} + // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected. int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) { diff --git a/client.h b/client.h index 0a2ac0d..0ca6eaa 100644 --- a/client.h +++ b/client.h @@ -37,6 +37,8 @@ License along with FlashMQ. If not, see . #include "types.h" #include "iowrapper.h" +#include "publishcopyfactory.h" + #define MQTT_HEADER_LENGH 2 @@ -122,6 +124,7 @@ public: void writeText(const std::string &text); void writePingResp(); int writeMqttPacket(const MqttPacket &packet); + int writeMqttPacketAndBlameThisClient(PublishCopyFactory ©Factory, char max_qos, uint16_t packet_id); int writeMqttPacketAndBlameThisClient(const MqttPacket &packet); bool writeBufIntoFd(); bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } diff --git a/mqttpacket.cpp b/mqttpacket.cpp index bed586f..8440c7a 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -822,11 +822,6 @@ void MqttPacket::setDuplicate() } } -size_t MqttPacket::getTotalMemoryFootprint() -{ - return bites.size() + sizeof(MqttPacket); -} - /** * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string. * @return diff --git a/mqttpacket.h b/mqttpacket.h index e5d5e96..c7d0840 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -120,7 +120,6 @@ public: void setPacketId(uint16_t packet_id); uint16_t getPacketId() const; void setDuplicate(); - size_t getTotalMemoryFootprint(); void readIntoBuf(CirBuf &buf) const; std::string getPayloadCopy() const; bool getRetain() const; diff --git a/publishcopyfactory.cpp b/publishcopyfactory.cpp new file mode 100644 index 0000000..a4a482f --- /dev/null +++ b/publishcopyfactory.cpp @@ -0,0 +1,53 @@ +#include + +#include "publishcopyfactory.h" +#include "mqttpacket.h" + +PublishCopyFactory::PublishCopyFactory(MqttPacket &packet) : + packet(packet), + orgQos(packet.getQos()) +{ + +} + +MqttPacket &PublishCopyFactory::getOptimumPacket(char max_qos) +{ + if (max_qos == 0 && max_qos < packet.getQos()) + { + if (!downgradedQos0PacketCopy) + downgradedQos0PacketCopy = packet.getCopy(max_qos); + assert(downgradedQos0PacketCopy->getQos() == 0); + return *downgradedQos0PacketCopy.get(); + } + + return packet; +} + +char PublishCopyFactory::getEffectiveQos(char max_qos) const +{ + const char effectiveQos = std::min(orgQos, max_qos); + return effectiveQos; +} + +const std::string &PublishCopyFactory::getTopic() const +{ + return packet.getTopic(); +} + +const std::vector &PublishCopyFactory::getSubtopics() const +{ + return packet.getSubtopics(); +} + +bool PublishCopyFactory::getRetain() const +{ + return packet.getRetain(); +} + +Publish PublishCopyFactory::getPublish() const +{ + assert(packet.getQos() > 0); + + Publish p(packet.getTopic(), packet.getPayloadCopy(), packet.getQos()); + return p; +} diff --git a/publishcopyfactory.h b/publishcopyfactory.h new file mode 100644 index 0000000..c2933fd --- /dev/null +++ b/publishcopyfactory.h @@ -0,0 +1,35 @@ +#ifndef PUBLISHCOPYFACTORY_H +#define PUBLISHCOPYFACTORY_H + +#include + +#include "forward_declarations.h" +#include "types.h" + +/** + * @brief The PublishCopyFactory class is for managing copies of an incoming publish, including sometimes not making copies at all. + * + * The idea is that certain incoming packets can just be written to the receiving client as-is, without constructing a new one. We do have to change the bytes + * where the QoS is stored, so we keep track of the original. + */ +class PublishCopyFactory +{ + MqttPacket &packet; + const char orgQos; + std::shared_ptr downgradedQos0PacketCopy; + + // TODO: constructed mqtt3 packet and mqtt5 packet +public: + PublishCopyFactory(MqttPacket &packet); + PublishCopyFactory(const PublishCopyFactory &other) = delete; + PublishCopyFactory(PublishCopyFactory &&other) = delete; + + MqttPacket &getOptimumPacket(char max_qos); + char getEffectiveQos(char max_qos) const; + const std::string &getTopic() const; + const std::vector &getSubtopics() const; + bool getRetain() const; + Publish getPublish() const; +}; + +#endif // PUBLISHCOPYFACTORY_H diff --git a/qospacketqueue.cpp b/qospacketqueue.cpp index ba6032c..9d3c4e2 100644 --- a/qospacketqueue.cpp +++ b/qospacketqueue.cpp @@ -4,16 +4,39 @@ #include "mqttpacket.h" -void QoSPacketQueue::erase(const uint16_t packet_id) +QueuedPublish::QueuedPublish(Publish &&publish, uint16_t packet_id) : + publish(std::move(publish)), + packet_id(packet_id) +{ + +} + +uint16_t QueuedPublish::getPacketId() const +{ + return this->packet_id; +} + +const Publish &QueuedPublish::getPublish() const +{ + return publish; +} + +size_t QueuedPublish::getApproximateMemoryFootprint() const +{ + return publish.topic.length() + publish.payload.length(); +} + + +void QoSPublishQueue::erase(const uint16_t packet_id) { auto it = queue.begin(); auto end = queue.end(); while (it != end) { - std::shared_ptr &p = *it; - if (p->getPacketId() == packet_id) + QueuedPublish &p = *it; + if (p.getPacketId() == packet_id) { - size_t mem = p->getTotalMemoryFootprint(); + size_t mem = p.getApproximateMemoryFootprint(); qosQueueBytes -= mem; assert(qosQueueBytes >= 0); if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. @@ -28,50 +51,40 @@ void QoSPacketQueue::erase(const uint16_t packet_id) } } -size_t QoSPacketQueue::size() const +size_t QoSPublishQueue::size() const { return queue.size(); } -size_t QoSPacketQueue::getByteSize() const +size_t QoSPublishQueue::getByteSize() const { return qosQueueBytes; } -/** - * @brief QoSPacketQueue::queuePacket makes a copy of the packet because it has state for the receiver in question. - * @param p - * @param id - * @return the packet copy. - */ -std::shared_ptr QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos) +void QoSPublishQueue::queuePublish(PublishCopyFactory ©Factory, uint16_t id, char new_max_qos) { - assert(p.getQos() > 0); + assert(new_max_qos > 0); + assert(id > 0); - std::shared_ptr copyPacket = p.getCopy(new_max_qos); - copyPacket->setPacketId(id); - queue.push_back(copyPacket); - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); - return copyPacket; + Publish pub = copyFactory.getPublish(); + queue.emplace_back(std::move(pub), id); + qosQueueBytes += queue.back().getApproximateMemoryFootprint(); } -std::shared_ptr QoSPacketQueue::queuePacket(const Publish &pub, uint16_t id) +void QoSPublishQueue::queuePublish(Publish &&pub, uint16_t id) { - assert(pub.qos > 0); + assert(id > 0); - std::shared_ptr copyPacket(new MqttPacket(pub)); - copyPacket->setPacketId(id); - queue.push_back(copyPacket); - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); - return copyPacket; + queue.emplace_back(std::move(pub), id); + qosQueueBytes += queue.back().getApproximateMemoryFootprint(); } -std::list>::const_iterator QoSPacketQueue::begin() const +std::list::const_iterator QoSPublishQueue::begin() const { return queue.cbegin(); } -std::list>::const_iterator QoSPacketQueue::end() const +std::list::const_iterator QoSPublishQueue::end() const { return queue.cend(); } diff --git a/qospacketqueue.h b/qospacketqueue.h index d3e86d5..df950bd 100644 --- a/qospacketqueue.h +++ b/qospacketqueue.h @@ -5,21 +5,39 @@ #include "forward_declarations.h" #include "types.h" +#include "publishcopyfactory.h" -class QoSPacketQueue +/** + * @brief The QueuedPublish class wraps the publish with a packet id. + * + * We don't want to store the packet id in the Publish object, because the packet id is determined/tracked per client/session. + */ +class QueuedPublish { - std::list> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] + Publish publish; + uint16_t packet_id = 0; +public: + QueuedPublish(Publish &&publish, uint16_t packet_id); + + size_t getApproximateMemoryFootprint() const; + uint16_t getPacketId() const; + const Publish &getPublish() const; +}; + +class QoSPublishQueue +{ + std::list queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] ssize_t qosQueueBytes = 0; public: void erase(const uint16_t packet_id); size_t size() const; size_t getByteSize() const; - std::shared_ptr queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos); - std::shared_ptr queuePacket(const Publish &pub, uint16_t id); + void queuePublish(PublishCopyFactory ©Factory, uint16_t id, char new_max_qos); + void queuePublish(Publish &&pub, uint16_t id); - std::list>::const_iterator begin() const; - std::list>::const_iterator end() const; + std::list::const_iterator begin() const; + std::list::const_iterator end() const; }; #endif // QOSPACKETQUEUE_H diff --git a/session.cpp b/session.cpp index 309afe7..c5540e9 100644 --- a/session.cpp +++ b/session.cpp @@ -101,8 +101,7 @@ Session::Session(const Session &other) this->nextPacketId = other.nextPacketId; this->lastTouched = other.lastTouched; - // To be fully correct, we should copy the individual packets, but copying sessions is only done for saving them, and I know - // that no member of MqttPacket changes in the QoS process, so we can just keep the shared pointer to the original. + // TODO: see git history for a change here. We now copy the whole queued publish. Do we want to address that? this->qosPacketQueue = other.qosPacketQueue; } @@ -145,33 +144,25 @@ void Session::assignActiveConnection(std::shared_ptr &client) * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. * @param count. Reference value is updated. It's for statistics. */ -void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr &downgradedQos0PacketCopy, uint64_t &count) +void Session::writePacket(PublishCopyFactory ©Factory, const char max_qos, uint64_t &count) { assert(max_qos <= 2); - const char effectiveQos = std::min(packet.getQos(), max_qos); + + const char effectiveQos = copyFactory.getEffectiveQos(max_qos); const Settings *settings = ThreadGlobals::getSettings(); Authentication *_auth = ThreadGlobals::getAuth(); assert(_auth); Authentication &auth = *_auth; - if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success) + if (auth.aclCheck(client_id, username, copyFactory.getTopic(), copyFactory.getSubtopics(), AclAccess::read, effectiveQos, copyFactory.getRetain()) == AuthResult::success) { std::shared_ptr c = makeSharedClient(); if (effectiveQos == 0) { if (c) { - const MqttPacket *packetToSend = &packet; - - if (max_qos < packet.getQos()) - { - if (!downgradedQos0PacketCopy) - downgradedQos0PacketCopy = packet.getCopy(max_qos); - packetToSend = downgradedQos0PacketCopy.get(); - } - - count += c->writeMqttPacketAndBlameThisClient(*packetToSend); + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, 0); } } else if (effectiveQos > 0) @@ -195,12 +186,12 @@ void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId, effectiveQos); + + qosPacketQueue.queuePublish(copyFactory, nextPacketId, effectiveQos); if (c) { - count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, nextPacketId); } } else @@ -224,14 +215,9 @@ void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptrwriteMqttPacketAndBlameThisClient(packet); + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, nextPacketId); } } } @@ -267,10 +253,12 @@ uint64_t Session::sendPendingQosMessages() if (c) { std::lock_guard locker(qosQueueMutex); - for (const std::shared_ptr &qosMessage : qosPacketQueue) + for (const QueuedPublish &queuedPublish : qosPacketQueue) { - count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get()); - qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + MqttPacket p(queuedPublish.getPublish()); + p.setDuplicate(); + + count += c->writeMqttPacketAndBlameThisClient(p); } for (const uint16_t packet_id : outgoingQoS2MessageIds) diff --git a/session.h b/session.h index 8b54cd8..4f92abb 100644 --- a/session.h +++ b/session.h @@ -27,6 +27,7 @@ License along with FlashMQ. If not, see . #include "logger.h" #include "sessionsandsubscriptionsdb.h" #include "qospacketqueue.h" +#include "publishcopyfactory.h" class Session { @@ -39,7 +40,7 @@ class Session std::weak_ptr client; std::string client_id; std::string username; - QoSPacketQueue qosPacketQueue; + QoSPublishQueue qosPacketQueue; std::set incomingQoS2MessageIds; std::set outgoingQoS2MessageIds; std::mutex qosQueueMutex; @@ -69,7 +70,7 @@ public: const std::string &getClientId() const { return client_id; } std::shared_ptr makeSharedClient() const; void assignActiveConnection(std::shared_ptr &client); - void writePacket(MqttPacket &packet, char max_qos, std::shared_ptr &downgradedQos0PacketCopy, uint64_t &count); + void writePacket(PublishCopyFactory ©Factory, const char max_qos, uint64_t &count); void clearQosMessage(uint16_t packet_id); uint64_t sendPendingQosMessages(); void touch(std::chrono::time_point val); diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp index 2a85721..79491e8 100644 --- a/sessionsandsubscriptionsdb.cpp +++ b/sessionsandsubscriptionsdb.cpp @@ -116,7 +116,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1() Publish pub(topic, payload, qos); logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); - ses->qosPacketQueue.queuePacket(pub, id); + ses->qosPacketQueue.queuePublish(std::move(pub), id); } const uint32_t nrOfIncomingPacketIds = readUint32(eofFound); @@ -215,23 +215,24 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector &p: ses->qosPacketQueue) + for (const QueuedPublish &p: ses->qosPacketQueue) { - logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", p->getQos(), p->getTopic().c_str()); + const Publish &pub = p.getPublish(); + + logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); qosPacketsCounted++; - writeUint16(p->getPacketId()); + writeUint16(p.getPacketId()); - writeUint32(p->getTopic().length()); - std::string payload = p->getPayloadCopy(); - writeUint32(payload.size()); + writeUint32(pub.topic.length()); + writeUint32(pub.payload.size()); - const char qos = p->getQos(); + const char qos = pub.qos; writeCheck(&qos, 1, 1, f); - writeCheck(p->getTopic().c_str(), 1, p->getTopic().length(), f); - writeCheck(payload.c_str(), 1, payload.length(), f); + writeCheck(pub.topic.c_str(), 1, pub.topic.length(), f); + writeCheck(pub.payload.c_str(), 1, pub.payload.length(), f); } assert(qosPacketsExpected == qosPacketsCounted); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index af34ffd..edb84c0 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -21,6 +21,7 @@ License along with FlashMQ. If not, see . #include "rwlockguard.h" #include "retainedmessagesdb.h" +#include "publishcopyfactory.h" ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr &ses, char qos) : session(ses), @@ -346,10 +347,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions); } - std::shared_ptr possibleQos0Copy; + PublishCopyFactory copyFactory(packet); for(const ReceivingSubscriber &x : subscriberSessions) { - x.session->writePacket(packet, x.qos, possibleQos0Copy, count); + x.session->writePacket(copyFactory, x.qos, count); } std::shared_ptr sender = packet.getSender(); @@ -425,8 +426,8 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr possibleQos0Copy; - ses->writePacket(packet, max_qos, possibleQos0Copy, count); + PublishCopyFactory copyFactory(packet); + ses->writePacket(copyFactory, max_qos, count); } return count; diff --git a/types.cpp b/types.cpp index 11f1b7a..8907b3d 100644 --- a/types.cpp +++ b/types.cpp @@ -46,7 +46,7 @@ size_t SubAck::getLengthWithoutFixedHeader() const return result; } -Publish::Publish(const std::string &topic, const std::string payload, char qos) : +Publish::Publish(const std::string &topic, const std::string &payload, char qos) : topic(topic), payload(payload), qos(qos) diff --git a/types.h b/types.h index ad62dbe..f232e75 100644 --- a/types.h +++ b/types.h @@ -101,7 +101,7 @@ public: std::string payload; char qos = 0; bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] - Publish(const std::string &topic, const std::string payload, char qos); + Publish(const std::string &topic, const std::string &payload, char qos); size_t getLengthWithoutFixedHeader() const; };