Commit ecb60b4815605a49162a89dc287eebf981ca3f79

Authored by Wiebe Cazemier
1 parent 54d52925

Change writing packets to clients to writing PublishCopyFactory

This is a preparation for MQTT5, because when there are receivers and
publishers with different protocols, you can't always just write out the
same packet. You can sometimes though, so that's what the copy factory
determines.
CMakeLists.txt
... ... @@ -58,6 +58,7 @@ add_executable(FlashMQ
58 58 qospacketqueue.h
59 59 threadglobals.h
60 60 threadloop.h
  61 + publishcopyfactory.h
61 62  
62 63 mainapp.cpp
63 64 main.cpp
... ... @@ -95,6 +96,7 @@ add_executable(FlashMQ
95 96 qospacketqueue.cpp
96 97 threadglobals.cpp
97 98 threadloop.cpp
  99 + publishcopyfactory.cpp
98 100  
99 101 )
100 102  
... ...
FlashMQTests/FlashMQTests.pro
... ... @@ -49,6 +49,7 @@ SOURCES += tst_maintests.cpp \
49 49 ../qospacketqueue.cpp \
50 50 ../threadglobals.cpp \
51 51 ../threadloop.cpp \
  52 + ../publishcopyfactory.cpp \
52 53 mainappthread.cpp \
53 54 twoclienttestcontext.cpp
54 55  
... ... @@ -90,6 +91,7 @@ HEADERS += \
90 91 ../qospacketqueue.h \
91 92 ../threadglobals.h \
92 93 ../threadloop.h \
  94 + ../publishcopyfactory.h \
93 95 mainappthread.h \
94 96 twoclienttestcontext.h
95 97  
... ...
FlashMQTests/tst_maintests.cpp
... ... @@ -111,6 +111,8 @@ private slots:
111 111 void testDowngradeQoSOnSubscribeQos1to0();
112 112 void testDowngradeQoSOnSubscribeQos0to0();
113 113  
  114 + void testNotMessingUpQosLevels();
  115 +
114 116 };
115 117  
116 118 MainTests::MainTests()
... ... @@ -1037,8 +1039,8 @@ void MainTests::testSavingSessions()
1037 1039 std::shared_ptr<Session> c1ses = c1->getSession();
1038 1040 c1.reset();
1039 1041 MqttPacket publishPacket(publish);
1040   - std::shared_ptr<MqttPacket> possibleQos0Copy;
1041   - c1ses->writePacket(publishPacket, 1, possibleQos0Copy, count);
  1042 + PublishCopyFactory fac(publishPacket);
  1043 + c1ses->writePacket(fac, 1, count);
1042 1044  
1043 1045 store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db");
1044 1046  
... ... @@ -1241,6 +1243,71 @@ void MainTests::testDowngradeQoSOnSubscribeQos0to0()
1241 1243 testDowngradeQoSOnSubscribeHelper(0, 0);
1242 1244 }
1243 1245  
  1246 +/**
  1247 + * @brief MainTests::testNotMessingUpQosLevels was divised because we optimize by preventing packet copies. This entails changing the vector of the original
  1248 + * incoming packet, resulting in possibly changing values like QoS levels for later subscribers.
  1249 + */
  1250 +void MainTests::testNotMessingUpQosLevels()
  1251 +{
  1252 + const QString topic = "HK7c1MFu6kdT69fWY";
  1253 + const QByteArray payload = "M4XK2LZ2Smaazba8RobZOgoe6CENxCll";
  1254 +
  1255 + TwoClientTestContext testContextSender;
  1256 + TwoClientTestContext testContextReceiver1(1);
  1257 + TwoClientTestContext testContextReceiver2(2);
  1258 + TwoClientTestContext testContextReceiver3(3);
  1259 + TwoClientTestContext testContextReceiver4(4);
  1260 + TwoClientTestContext testContextReceiver5(5);
  1261 +
  1262 + testContextReceiver1.connectReceiver();
  1263 + testContextReceiver1.subscribeReceiver(topic, 0);
  1264 +
  1265 + testContextReceiver2.connectReceiver();
  1266 + testContextReceiver2.subscribeReceiver(topic, 1);
  1267 +
  1268 + testContextReceiver3.connectReceiver();
  1269 + testContextReceiver3.subscribeReceiver(topic, 2);
  1270 +
  1271 + testContextReceiver4.connectReceiver();
  1272 + testContextReceiver4.subscribeReceiver(topic, 1);
  1273 +
  1274 + testContextReceiver5.connectReceiver();
  1275 + testContextReceiver5.subscribeReceiver(topic, 0);
  1276 +
  1277 + testContextSender.connectSender();
  1278 + testContextSender.publish(topic, payload, 2, false);
  1279 +
  1280 + testContextReceiver1.waitReceiverReceived(1);
  1281 + testContextReceiver2.waitReceiverReceived(1);
  1282 + testContextReceiver3.waitReceiverReceived(1);
  1283 + testContextReceiver4.waitReceiverReceived(1);
  1284 + testContextReceiver5.waitReceiverReceived(1);
  1285 +
  1286 + QCOMPARE(testContextReceiver1.receivedMessages.count(), 1);
  1287 + QCOMPARE(testContextReceiver2.receivedMessages.count(), 1);
  1288 + QCOMPARE(testContextReceiver3.receivedMessages.count(), 1);
  1289 + QCOMPARE(testContextReceiver4.receivedMessages.count(), 1);
  1290 + QCOMPARE(testContextReceiver5.receivedMessages.count(), 1);
  1291 +
  1292 + QCOMPARE(testContextReceiver1.receivedMessages.first().qos(), 0);
  1293 + QCOMPARE(testContextReceiver2.receivedMessages.first().qos(), 1);
  1294 + QCOMPARE(testContextReceiver3.receivedMessages.first().qos(), 2);
  1295 + QCOMPARE(testContextReceiver4.receivedMessages.first().qos(), 1);
  1296 + QCOMPARE(testContextReceiver5.receivedMessages.first().qos(), 0);
  1297 +
  1298 + QCOMPARE(testContextReceiver1.receivedMessages.first().payload(), payload);
  1299 + QCOMPARE(testContextReceiver2.receivedMessages.first().payload(), payload);
  1300 + QCOMPARE(testContextReceiver3.receivedMessages.first().payload(), payload);
  1301 + QCOMPARE(testContextReceiver4.receivedMessages.first().payload(), payload);
  1302 + QCOMPARE(testContextReceiver5.receivedMessages.first().payload(), payload);
  1303 +
  1304 + QCOMPARE(testContextReceiver1.receivedMessages.first().id(), 0);
  1305 + QCOMPARE(testContextReceiver2.receivedMessages.first().id(), 1);
  1306 + QCOMPARE(testContextReceiver3.receivedMessages.first().id(), 1);
  1307 + QCOMPARE(testContextReceiver4.receivedMessages.first().id(), 1);
  1308 + QCOMPARE(testContextReceiver5.receivedMessages.first().id(), 0);
  1309 +}
  1310 +
1244 1311  
1245 1312 int main(int argc, char *argv[])
1246 1313 {
... ...
FlashMQTests/twoclienttestcontext.cpp
... ... @@ -22,14 +22,14 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
22 22  
23 23 // TODO: port to QMqttClient that newer Qts now have?
24 24  
25   -TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent)
  25 +TwoClientTestContext::TwoClientTestContext(int clientNr, QObject *parent) : QObject(parent)
26 26 {
27 27 QHostInfo targetHostInfo = QHostInfo::fromName("localhost");
28 28 QHostAddress targetHost(targetHostInfo.addresses().first());
29 29 sender.reset(new QMQTT::Client(targetHost));
30   - sender->setClientId("Sender");
  30 + sender->setClientId(QString("Sender%1").arg(clientNr));
31 31 receiver.reset(new QMQTT::Client(targetHost));
32   - receiver->setClientId("Receiver");
  32 + receiver->setClientId(QString("Receiver%1").arg(clientNr));
33 33  
34 34 connect(sender.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError);
35 35 connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError);
... ...
FlashMQTests/twoclienttestcontext.h
... ... @@ -33,7 +33,7 @@ private slots:
33 33 void onReceiverReceived(const QMQTT::Message& message);
34 34  
35 35 public:
36   - explicit TwoClientTestContext(QObject *parent = nullptr);
  36 + explicit TwoClientTestContext(int clientNr = 0, QObject *parent = nullptr);
37 37 void publish(const QString &topic, const QByteArray &payload);
38 38 void publish(const QString &topic, const QByteArray &payload, bool retain);
39 39 void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain);
... ...
client.cpp
... ... @@ -207,6 +207,21 @@ int Client::writeMqttPacket(const MqttPacket &amp;packet)
207 207 return 1;
208 208 }
209 209  
  210 +int Client::writeMqttPacketAndBlameThisClient(PublishCopyFactory &copyFactory, char max_qos, uint16_t packet_id)
  211 +{
  212 + MqttPacket &p = copyFactory.getOptimumPacket(max_qos);
  213 +
  214 + if (p.getQos() > 0)
  215 + {
  216 + // This may change the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere,
  217 + // that should be fine.
  218 + p.setPacketId(packet_id);
  219 + p.setQos(max_qos);
  220 + }
  221 +
  222 + return writeMqttPacketAndBlameThisClient(p);
  223 +}
  224 +
210 225 // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected.
211 226 int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet)
212 227 {
... ...
client.h
... ... @@ -37,6 +37,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
37 37 #include "types.h"
38 38 #include "iowrapper.h"
39 39  
  40 +#include "publishcopyfactory.h"
  41 +
40 42 #define MQTT_HEADER_LENGH 2
41 43  
42 44  
... ... @@ -122,6 +124,7 @@ public:
122 124 void writeText(const std::string &text);
123 125 void writePingResp();
124 126 int writeMqttPacket(const MqttPacket &packet);
  127 + int writeMqttPacketAndBlameThisClient(PublishCopyFactory &copyFactory, char max_qos, uint16_t packet_id);
125 128 int writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
126 129 bool writeBufIntoFd();
127 130 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; }
... ...
mqttpacket.cpp
... ... @@ -822,11 +822,6 @@ void MqttPacket::setDuplicate()
822 822 }
823 823 }
824 824  
825   -size_t MqttPacket::getTotalMemoryFootprint()
826   -{
827   - return bites.size() + sizeof(MqttPacket);
828   -}
829   -
830 825 /**
831 826 * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string.
832 827 * @return
... ...
mqttpacket.h
... ... @@ -120,7 +120,6 @@ public:
120 120 void setPacketId(uint16_t packet_id);
121 121 uint16_t getPacketId() const;
122 122 void setDuplicate();
123   - size_t getTotalMemoryFootprint();
124 123 void readIntoBuf(CirBuf &buf) const;
125 124 std::string getPayloadCopy() const;
126 125 bool getRetain() const;
... ...
publishcopyfactory.cpp 0 โ†’ 100644
  1 +#include <cassert>
  2 +
  3 +#include "publishcopyfactory.h"
  4 +#include "mqttpacket.h"
  5 +
  6 +PublishCopyFactory::PublishCopyFactory(MqttPacket &packet) :
  7 + packet(packet),
  8 + orgQos(packet.getQos())
  9 +{
  10 +
  11 +}
  12 +
  13 +MqttPacket &PublishCopyFactory::getOptimumPacket(char max_qos)
  14 +{
  15 + if (max_qos == 0 && max_qos < packet.getQos())
  16 + {
  17 + if (!downgradedQos0PacketCopy)
  18 + downgradedQos0PacketCopy = packet.getCopy(max_qos);
  19 + assert(downgradedQos0PacketCopy->getQos() == 0);
  20 + return *downgradedQos0PacketCopy.get();
  21 + }
  22 +
  23 + return packet;
  24 +}
  25 +
  26 +char PublishCopyFactory::getEffectiveQos(char max_qos) const
  27 +{
  28 + const char effectiveQos = std::min<char>(orgQos, max_qos);
  29 + return effectiveQos;
  30 +}
  31 +
  32 +const std::string &PublishCopyFactory::getTopic() const
  33 +{
  34 + return packet.getTopic();
  35 +}
  36 +
  37 +const std::vector<std::string> &PublishCopyFactory::getSubtopics() const
  38 +{
  39 + return packet.getSubtopics();
  40 +}
  41 +
  42 +bool PublishCopyFactory::getRetain() const
  43 +{
  44 + return packet.getRetain();
  45 +}
  46 +
  47 +Publish PublishCopyFactory::getPublish() const
  48 +{
  49 + assert(packet.getQos() > 0);
  50 +
  51 + Publish p(packet.getTopic(), packet.getPayloadCopy(), packet.getQos());
  52 + return p;
  53 +}
... ...
publishcopyfactory.h 0 โ†’ 100644
  1 +#ifndef PUBLISHCOPYFACTORY_H
  2 +#define PUBLISHCOPYFACTORY_H
  3 +
  4 +#include <vector>
  5 +
  6 +#include "forward_declarations.h"
  7 +#include "types.h"
  8 +
  9 +/**
  10 + * @brief The PublishCopyFactory class is for managing copies of an incoming publish, including sometimes not making copies at all.
  11 + *
  12 + * The idea is that certain incoming packets can just be written to the receiving client as-is, without constructing a new one. We do have to change the bytes
  13 + * where the QoS is stored, so we keep track of the original.
  14 + */
  15 +class PublishCopyFactory
  16 +{
  17 + MqttPacket &packet;
  18 + const char orgQos;
  19 + std::shared_ptr<MqttPacket> downgradedQos0PacketCopy;
  20 +
  21 + // TODO: constructed mqtt3 packet and mqtt5 packet
  22 +public:
  23 + PublishCopyFactory(MqttPacket &packet);
  24 + PublishCopyFactory(const PublishCopyFactory &other) = delete;
  25 + PublishCopyFactory(PublishCopyFactory &&other) = delete;
  26 +
  27 + MqttPacket &getOptimumPacket(char max_qos);
  28 + char getEffectiveQos(char max_qos) const;
  29 + const std::string &getTopic() const;
  30 + const std::vector<std::string> &getSubtopics() const;
  31 + bool getRetain() const;
  32 + Publish getPublish() const;
  33 +};
  34 +
  35 +#endif // PUBLISHCOPYFACTORY_H
... ...
qospacketqueue.cpp
... ... @@ -4,16 +4,39 @@
4 4  
5 5 #include "mqttpacket.h"
6 6  
7   -void QoSPacketQueue::erase(const uint16_t packet_id)
  7 +QueuedPublish::QueuedPublish(Publish &&publish, uint16_t packet_id) :
  8 + publish(std::move(publish)),
  9 + packet_id(packet_id)
  10 +{
  11 +
  12 +}
  13 +
  14 +uint16_t QueuedPublish::getPacketId() const
  15 +{
  16 + return this->packet_id;
  17 +}
  18 +
  19 +const Publish &QueuedPublish::getPublish() const
  20 +{
  21 + return publish;
  22 +}
  23 +
  24 +size_t QueuedPublish::getApproximateMemoryFootprint() const
  25 +{
  26 + return publish.topic.length() + publish.payload.length();
  27 +}
  28 +
  29 +
  30 +void QoSPublishQueue::erase(const uint16_t packet_id)
8 31 {
9 32 auto it = queue.begin();
10 33 auto end = queue.end();
11 34 while (it != end)
12 35 {
13   - std::shared_ptr<MqttPacket> &p = *it;
14   - if (p->getPacketId() == packet_id)
  36 + QueuedPublish &p = *it;
  37 + if (p.getPacketId() == packet_id)
15 38 {
16   - size_t mem = p->getTotalMemoryFootprint();
  39 + size_t mem = p.getApproximateMemoryFootprint();
17 40 qosQueueBytes -= mem;
18 41 assert(qosQueueBytes >= 0);
19 42 if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose.
... ... @@ -28,50 +51,40 @@ void QoSPacketQueue::erase(const uint16_t packet_id)
28 51 }
29 52 }
30 53  
31   -size_t QoSPacketQueue::size() const
  54 +size_t QoSPublishQueue::size() const
32 55 {
33 56 return queue.size();
34 57 }
35 58  
36   -size_t QoSPacketQueue::getByteSize() const
  59 +size_t QoSPublishQueue::getByteSize() const
37 60 {
38 61 return qosQueueBytes;
39 62 }
40 63  
41   -/**
42   - * @brief QoSPacketQueue::queuePacket makes a copy of the packet because it has state for the receiver in question.
43   - * @param p
44   - * @param id
45   - * @return the packet copy.
46   - */
47   -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos)
  64 +void QoSPublishQueue::queuePublish(PublishCopyFactory &copyFactory, uint16_t id, char new_max_qos)
48 65 {
49   - assert(p.getQos() > 0);
  66 + assert(new_max_qos > 0);
  67 + assert(id > 0);
50 68  
51   - std::shared_ptr<MqttPacket> copyPacket = p.getCopy(new_max_qos);
52   - copyPacket->setPacketId(id);
53   - queue.push_back(copyPacket);
54   - qosQueueBytes += copyPacket->getTotalMemoryFootprint();
55   - return copyPacket;
  69 + Publish pub = copyFactory.getPublish();
  70 + queue.emplace_back(std::move(pub), id);
  71 + qosQueueBytes += queue.back().getApproximateMemoryFootprint();
56 72 }
57 73  
58   -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const Publish &pub, uint16_t id)
  74 +void QoSPublishQueue::queuePublish(Publish &&pub, uint16_t id)
59 75 {
60   - assert(pub.qos > 0);
  76 + assert(id > 0);
61 77  
62   - std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(pub));
63   - copyPacket->setPacketId(id);
64   - queue.push_back(copyPacket);
65   - qosQueueBytes += copyPacket->getTotalMemoryFootprint();
66   - return copyPacket;
  78 + queue.emplace_back(std::move(pub), id);
  79 + qosQueueBytes += queue.back().getApproximateMemoryFootprint();
67 80 }
68 81  
69   -std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::begin() const
  82 +std::list<QueuedPublish>::const_iterator QoSPublishQueue::begin() const
70 83 {
71 84 return queue.cbegin();
72 85 }
73 86  
74   -std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::end() const
  87 +std::list<QueuedPublish>::const_iterator QoSPublishQueue::end() const
75 88 {
76 89 return queue.cend();
77 90 }
... ...
qospacketqueue.h
... ... @@ -5,21 +5,39 @@
5 5  
6 6 #include "forward_declarations.h"
7 7 #include "types.h"
  8 +#include "publishcopyfactory.h"
8 9  
9   -class QoSPacketQueue
  10 +/**
  11 + * @brief The QueuedPublish class wraps the publish with a packet id.
  12 + *
  13 + * We don't want to store the packet id in the Publish object, because the packet id is determined/tracked per client/session.
  14 + */
  15 +class QueuedPublish
10 16 {
11   - std::list<std::shared_ptr<MqttPacket>> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6]
  17 + Publish publish;
  18 + uint16_t packet_id = 0;
  19 +public:
  20 + QueuedPublish(Publish &&publish, uint16_t packet_id);
  21 +
  22 + size_t getApproximateMemoryFootprint() const;
  23 + uint16_t getPacketId() const;
  24 + const Publish &getPublish() const;
  25 +};
  26 +
  27 +class QoSPublishQueue
  28 +{
  29 + std::list<QueuedPublish> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6]
12 30 ssize_t qosQueueBytes = 0;
13 31  
14 32 public:
15 33 void erase(const uint16_t packet_id);
16 34 size_t size() const;
17 35 size_t getByteSize() const;
18   - std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos);
19   - std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id);
  36 + void queuePublish(PublishCopyFactory &copyFactory, uint16_t id, char new_max_qos);
  37 + void queuePublish(Publish &&pub, uint16_t id);
20 38  
21   - std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const;
22   - std::list<std::shared_ptr<MqttPacket>>::const_iterator end() const;
  39 + std::list<QueuedPublish>::const_iterator begin() const;
  40 + std::list<QueuedPublish>::const_iterator end() const;
23 41 };
24 42  
25 43 #endif // QOSPACKETQUEUE_H
... ...
session.cpp
... ... @@ -101,8 +101,7 @@ Session::Session(const Session &amp;other)
101 101 this->nextPacketId = other.nextPacketId;
102 102 this->lastTouched = other.lastTouched;
103 103  
104   - // To be fully correct, we should copy the individual packets, but copying sessions is only done for saving them, and I know
105   - // that no member of MqttPacket changes in the QoS process, so we can just keep the shared pointer to the original.
  104 + // TODO: see git history for a change here. We now copy the whole queued publish. Do we want to address that?
106 105 this->qosPacketQueue = other.qosPacketQueue;
107 106 }
108 107  
... ... @@ -145,33 +144,25 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
145 144 * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets.
146 145 * @param count. Reference value is updated. It's for statistics.
147 146 */
148   -void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count)
  147 +void Session::writePacket(PublishCopyFactory &copyFactory, const char max_qos, uint64_t &count)
149 148 {
150 149 assert(max_qos <= 2);
151   - const char effectiveQos = std::min<char>(packet.getQos(), max_qos);
  150 +
  151 + const char effectiveQos = copyFactory.getEffectiveQos(max_qos);
152 152  
153 153 const Settings *settings = ThreadGlobals::getSettings();
154 154  
155 155 Authentication *_auth = ThreadGlobals::getAuth();
156 156 assert(_auth);
157 157 Authentication &auth = *_auth;
158   - if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success)
  158 + if (auth.aclCheck(client_id, username, copyFactory.getTopic(), copyFactory.getSubtopics(), AclAccess::read, effectiveQos, copyFactory.getRetain()) == AuthResult::success)
159 159 {
160 160 std::shared_ptr<Client> c = makeSharedClient();
161 161 if (effectiveQos == 0)
162 162 {
163 163 if (c)
164 164 {
165   - const MqttPacket *packetToSend = &packet;
166   -
167   - if (max_qos < packet.getQos())
168   - {
169   - if (!downgradedQos0PacketCopy)
170   - downgradedQos0PacketCopy = packet.getCopy(max_qos);
171   - packetToSend = downgradedQos0PacketCopy.get();
172   - }
173   -
174   - count += c->writeMqttPacketAndBlameThisClient(*packetToSend);
  165 + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, 0);
175 166 }
176 167 }
177 168 else if (effectiveQos > 0)
... ... @@ -195,12 +186,12 @@ void Session::writePacket(MqttPacket &amp;packet, char max_qos, std::shared_ptr&lt;Mqtt
195 186 }
196 187  
197 188 increasePacketId();
198   - std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId, effectiveQos);
  189 +
  190 + qosPacketQueue.queuePublish(copyFactory, nextPacketId, effectiveQos);
199 191  
200 192 if (c)
201 193 {
202   - count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get());
203   - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
  194 + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, nextPacketId);
204 195 }
205 196 }
206 197 else
... ... @@ -224,14 +215,9 @@ void Session::writePacket(MqttPacket &amp;packet, char max_qos, std::shared_ptr&lt;Mqtt
224 215  
225 216 increasePacketId();
226 217  
227   - // This changes the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere,
228   - // that should be fine.
229   - packet.setPacketId(nextPacketId);
230   - packet.setQos(effectiveQos);
231   -
232 218 qosInFlightCounter++;
233 219 assert(c); // with requiresRetransmission==false, there must be a client.
234   - count += c->writeMqttPacketAndBlameThisClient(packet);
  220 + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, nextPacketId);
235 221 }
236 222 }
237 223 }
... ... @@ -267,10 +253,12 @@ uint64_t Session::sendPendingQosMessages()
267 253 if (c)
268 254 {
269 255 std::lock_guard<std::mutex> locker(qosQueueMutex);
270   - for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue)
  256 + for (const QueuedPublish &queuedPublish : qosPacketQueue)
271 257 {
272   - count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get());
273   - qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
  258 + MqttPacket p(queuedPublish.getPublish());
  259 + p.setDuplicate();
  260 +
  261 + count += c->writeMqttPacketAndBlameThisClient(p);
274 262 }
275 263  
276 264 for (const uint16_t packet_id : outgoingQoS2MessageIds)
... ...
session.h
... ... @@ -27,6 +27,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
27 27 #include "logger.h"
28 28 #include "sessionsandsubscriptionsdb.h"
29 29 #include "qospacketqueue.h"
  30 +#include "publishcopyfactory.h"
30 31  
31 32 class Session
32 33 {
... ... @@ -39,7 +40,7 @@ class Session
39 40 std::weak_ptr<Client> client;
40 41 std::string client_id;
41 42 std::string username;
42   - QoSPacketQueue qosPacketQueue;
  43 + QoSPublishQueue qosPacketQueue;
43 44 std::set<uint16_t> incomingQoS2MessageIds;
44 45 std::set<uint16_t> outgoingQoS2MessageIds;
45 46 std::mutex qosQueueMutex;
... ... @@ -69,7 +70,7 @@ public:
69 70 const std::string &getClientId() const { return client_id; }
70 71 std::shared_ptr<Client> makeSharedClient() const;
71 72 void assignActiveConnection(std::shared_ptr<Client> &client);
72   - void writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count);
  73 + void writePacket(PublishCopyFactory &copyFactory, const char max_qos, uint64_t &count);
73 74 void clearQosMessage(uint16_t packet_id);
74 75 uint64_t sendPendingQosMessages();
75 76 void touch(std::chrono::time_point<std::chrono::steady_clock> val);
... ...
sessionsandsubscriptionsdb.cpp
... ... @@ -116,7 +116,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1()
116 116  
117 117 Publish pub(topic, payload, qos);
118 118 logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
119   - ses->qosPacketQueue.queuePacket(pub, id);
  119 + ses->qosPacketQueue.queuePublish(std::move(pub), id);
120 120 }
121 121  
122 122 const uint32_t nrOfIncomingPacketIds = readUint32(eofFound);
... ... @@ -215,23 +215,24 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess
215 215 size_t qosPacketsCounted = 0;
216 216 writeUint32(qosPacketsExpected);
217 217  
218   - for (const std::shared_ptr<MqttPacket> &p: ses->qosPacketQueue)
  218 + for (const QueuedPublish &p: ses->qosPacketQueue)
219 219 {
220   - logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", p->getQos(), p->getTopic().c_str());
  220 + const Publish &pub = p.getPublish();
  221 +
  222 + logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
221 223  
222 224 qosPacketsCounted++;
223 225  
224   - writeUint16(p->getPacketId());
  226 + writeUint16(p.getPacketId());
225 227  
226   - writeUint32(p->getTopic().length());
227   - std::string payload = p->getPayloadCopy();
228   - writeUint32(payload.size());
  228 + writeUint32(pub.topic.length());
  229 + writeUint32(pub.payload.size());
229 230  
230   - const char qos = p->getQos();
  231 + const char qos = pub.qos;
231 232 writeCheck(&qos, 1, 1, f);
232 233  
233   - writeCheck(p->getTopic().c_str(), 1, p->getTopic().length(), f);
234   - writeCheck(payload.c_str(), 1, payload.length(), f);
  234 + writeCheck(pub.topic.c_str(), 1, pub.topic.length(), f);
  235 + writeCheck(pub.payload.c_str(), 1, pub.payload.length(), f);
235 236 }
236 237  
237 238 assert(qosPacketsExpected == qosPacketsCounted);
... ...
subscriptionstore.cpp
... ... @@ -21,6 +21,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
21 21  
22 22 #include "rwlockguard.h"
23 23 #include "retainedmessagesdb.h"
  24 +#include "publishcopyfactory.h"
24 25  
25 26 ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr<Session> &ses, char qos) :
26 27 session(ses),
... ... @@ -346,10 +347,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector&lt;std::string&gt;
346 347 publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions);
347 348 }
348 349  
349   - std::shared_ptr<MqttPacket> possibleQos0Copy;
  350 + PublishCopyFactory copyFactory(packet);
350 351 for(const ReceivingSubscriber &x : subscriberSessions)
351 352 {
352   - x.session->writePacket(packet, x.qos, possibleQos0Copy, count);
  353 + x.session->writePacket(copyFactory, x.qos, count);
353 354 }
354 355  
355 356 std::shared_ptr<Client> sender = packet.getSender();
... ... @@ -425,8 +426,8 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Ses
425 426  
426 427 for(MqttPacket &packet : packetList)
427 428 {
428   - std::shared_ptr<MqttPacket> possibleQos0Copy;
429   - ses->writePacket(packet, max_qos, possibleQos0Copy, count);
  429 + PublishCopyFactory copyFactory(packet);
  430 + ses->writePacket(copyFactory, max_qos, count);
430 431 }
431 432  
432 433 return count;
... ...
types.cpp
... ... @@ -46,7 +46,7 @@ size_t SubAck::getLengthWithoutFixedHeader() const
46 46 return result;
47 47 }
48 48  
49   -Publish::Publish(const std::string &topic, const std::string payload, char qos) :
  49 +Publish::Publish(const std::string &topic, const std::string &payload, char qos) :
50 50 topic(topic),
51 51 payload(payload),
52 52 qos(qos)
... ...
... ... @@ -101,7 +101,7 @@ public:
101 101 std::string payload;
102 102 char qos = 0;
103 103 bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9]
104   - Publish(const std::string &topic, const std::string payload, char qos);
  104 + Publish(const std::string &topic, const std::string &payload, char qos);
105 105 size_t getLengthWithoutFixedHeader() const;
106 106 };
107 107  
... ...