Commit ecb60b4815605a49162a89dc287eebf981ca3f79

Authored by Wiebe Cazemier
1 parent 54d52925

Change writing packets to clients to writing PublishCopyFactory

This is a preparation for MQTT5, because when there are receivers and
publishers with different protocols, you can't always just write out the
same packet. You can sometimes though, so that's what the copy factory
determines.
CMakeLists.txt
@@ -58,6 +58,7 @@ add_executable(FlashMQ @@ -58,6 +58,7 @@ add_executable(FlashMQ
58 qospacketqueue.h 58 qospacketqueue.h
59 threadglobals.h 59 threadglobals.h
60 threadloop.h 60 threadloop.h
  61 + publishcopyfactory.h
61 62
62 mainapp.cpp 63 mainapp.cpp
63 main.cpp 64 main.cpp
@@ -95,6 +96,7 @@ add_executable(FlashMQ @@ -95,6 +96,7 @@ add_executable(FlashMQ
95 qospacketqueue.cpp 96 qospacketqueue.cpp
96 threadglobals.cpp 97 threadglobals.cpp
97 threadloop.cpp 98 threadloop.cpp
  99 + publishcopyfactory.cpp
98 100
99 ) 101 )
100 102
FlashMQTests/FlashMQTests.pro
@@ -49,6 +49,7 @@ SOURCES += tst_maintests.cpp \ @@ -49,6 +49,7 @@ SOURCES += tst_maintests.cpp \
49 ../qospacketqueue.cpp \ 49 ../qospacketqueue.cpp \
50 ../threadglobals.cpp \ 50 ../threadglobals.cpp \
51 ../threadloop.cpp \ 51 ../threadloop.cpp \
  52 + ../publishcopyfactory.cpp \
52 mainappthread.cpp \ 53 mainappthread.cpp \
53 twoclienttestcontext.cpp 54 twoclienttestcontext.cpp
54 55
@@ -90,6 +91,7 @@ HEADERS += \ @@ -90,6 +91,7 @@ HEADERS += \
90 ../qospacketqueue.h \ 91 ../qospacketqueue.h \
91 ../threadglobals.h \ 92 ../threadglobals.h \
92 ../threadloop.h \ 93 ../threadloop.h \
  94 + ../publishcopyfactory.h \
93 mainappthread.h \ 95 mainappthread.h \
94 twoclienttestcontext.h 96 twoclienttestcontext.h
95 97
FlashMQTests/tst_maintests.cpp
@@ -111,6 +111,8 @@ private slots: @@ -111,6 +111,8 @@ private slots:
111 void testDowngradeQoSOnSubscribeQos1to0(); 111 void testDowngradeQoSOnSubscribeQos1to0();
112 void testDowngradeQoSOnSubscribeQos0to0(); 112 void testDowngradeQoSOnSubscribeQos0to0();
113 113
  114 + void testNotMessingUpQosLevels();
  115 +
114 }; 116 };
115 117
116 MainTests::MainTests() 118 MainTests::MainTests()
@@ -1037,8 +1039,8 @@ void MainTests::testSavingSessions() @@ -1037,8 +1039,8 @@ void MainTests::testSavingSessions()
1037 std::shared_ptr<Session> c1ses = c1->getSession(); 1039 std::shared_ptr<Session> c1ses = c1->getSession();
1038 c1.reset(); 1040 c1.reset();
1039 MqttPacket publishPacket(publish); 1041 MqttPacket publishPacket(publish);
1040 - std::shared_ptr<MqttPacket> possibleQos0Copy;  
1041 - c1ses->writePacket(publishPacket, 1, possibleQos0Copy, count); 1042 + PublishCopyFactory fac(publishPacket);
  1043 + c1ses->writePacket(fac, 1, count);
1042 1044
1043 store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); 1045 store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db");
1044 1046
@@ -1241,6 +1243,71 @@ void MainTests::testDowngradeQoSOnSubscribeQos0to0() @@ -1241,6 +1243,71 @@ void MainTests::testDowngradeQoSOnSubscribeQos0to0()
1241 testDowngradeQoSOnSubscribeHelper(0, 0); 1243 testDowngradeQoSOnSubscribeHelper(0, 0);
1242 } 1244 }
1243 1245
  1246 +/**
  1247 + * @brief MainTests::testNotMessingUpQosLevels was divised because we optimize by preventing packet copies. This entails changing the vector of the original
  1248 + * incoming packet, resulting in possibly changing values like QoS levels for later subscribers.
  1249 + */
  1250 +void MainTests::testNotMessingUpQosLevels()
  1251 +{
  1252 + const QString topic = "HK7c1MFu6kdT69fWY";
  1253 + const QByteArray payload = "M4XK2LZ2Smaazba8RobZOgoe6CENxCll";
  1254 +
  1255 + TwoClientTestContext testContextSender;
  1256 + TwoClientTestContext testContextReceiver1(1);
  1257 + TwoClientTestContext testContextReceiver2(2);
  1258 + TwoClientTestContext testContextReceiver3(3);
  1259 + TwoClientTestContext testContextReceiver4(4);
  1260 + TwoClientTestContext testContextReceiver5(5);
  1261 +
  1262 + testContextReceiver1.connectReceiver();
  1263 + testContextReceiver1.subscribeReceiver(topic, 0);
  1264 +
  1265 + testContextReceiver2.connectReceiver();
  1266 + testContextReceiver2.subscribeReceiver(topic, 1);
  1267 +
  1268 + testContextReceiver3.connectReceiver();
  1269 + testContextReceiver3.subscribeReceiver(topic, 2);
  1270 +
  1271 + testContextReceiver4.connectReceiver();
  1272 + testContextReceiver4.subscribeReceiver(topic, 1);
  1273 +
  1274 + testContextReceiver5.connectReceiver();
  1275 + testContextReceiver5.subscribeReceiver(topic, 0);
  1276 +
  1277 + testContextSender.connectSender();
  1278 + testContextSender.publish(topic, payload, 2, false);
  1279 +
  1280 + testContextReceiver1.waitReceiverReceived(1);
  1281 + testContextReceiver2.waitReceiverReceived(1);
  1282 + testContextReceiver3.waitReceiverReceived(1);
  1283 + testContextReceiver4.waitReceiverReceived(1);
  1284 + testContextReceiver5.waitReceiverReceived(1);
  1285 +
  1286 + QCOMPARE(testContextReceiver1.receivedMessages.count(), 1);
  1287 + QCOMPARE(testContextReceiver2.receivedMessages.count(), 1);
  1288 + QCOMPARE(testContextReceiver3.receivedMessages.count(), 1);
  1289 + QCOMPARE(testContextReceiver4.receivedMessages.count(), 1);
  1290 + QCOMPARE(testContextReceiver5.receivedMessages.count(), 1);
  1291 +
  1292 + QCOMPARE(testContextReceiver1.receivedMessages.first().qos(), 0);
  1293 + QCOMPARE(testContextReceiver2.receivedMessages.first().qos(), 1);
  1294 + QCOMPARE(testContextReceiver3.receivedMessages.first().qos(), 2);
  1295 + QCOMPARE(testContextReceiver4.receivedMessages.first().qos(), 1);
  1296 + QCOMPARE(testContextReceiver5.receivedMessages.first().qos(), 0);
  1297 +
  1298 + QCOMPARE(testContextReceiver1.receivedMessages.first().payload(), payload);
  1299 + QCOMPARE(testContextReceiver2.receivedMessages.first().payload(), payload);
  1300 + QCOMPARE(testContextReceiver3.receivedMessages.first().payload(), payload);
  1301 + QCOMPARE(testContextReceiver4.receivedMessages.first().payload(), payload);
  1302 + QCOMPARE(testContextReceiver5.receivedMessages.first().payload(), payload);
  1303 +
  1304 + QCOMPARE(testContextReceiver1.receivedMessages.first().id(), 0);
  1305 + QCOMPARE(testContextReceiver2.receivedMessages.first().id(), 1);
  1306 + QCOMPARE(testContextReceiver3.receivedMessages.first().id(), 1);
  1307 + QCOMPARE(testContextReceiver4.receivedMessages.first().id(), 1);
  1308 + QCOMPARE(testContextReceiver5.receivedMessages.first().id(), 0);
  1309 +}
  1310 +
1244 1311
1245 int main(int argc, char *argv[]) 1312 int main(int argc, char *argv[])
1246 { 1313 {
FlashMQTests/twoclienttestcontext.cpp
@@ -22,14 +22,14 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -22,14 +22,14 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
22 22
23 // TODO: port to QMqttClient that newer Qts now have? 23 // TODO: port to QMqttClient that newer Qts now have?
24 24
25 -TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent) 25 +TwoClientTestContext::TwoClientTestContext(int clientNr, QObject *parent) : QObject(parent)
26 { 26 {
27 QHostInfo targetHostInfo = QHostInfo::fromName("localhost"); 27 QHostInfo targetHostInfo = QHostInfo::fromName("localhost");
28 QHostAddress targetHost(targetHostInfo.addresses().first()); 28 QHostAddress targetHost(targetHostInfo.addresses().first());
29 sender.reset(new QMQTT::Client(targetHost)); 29 sender.reset(new QMQTT::Client(targetHost));
30 - sender->setClientId("Sender"); 30 + sender->setClientId(QString("Sender%1").arg(clientNr));
31 receiver.reset(new QMQTT::Client(targetHost)); 31 receiver.reset(new QMQTT::Client(targetHost));
32 - receiver->setClientId("Receiver"); 32 + receiver->setClientId(QString("Receiver%1").arg(clientNr));
33 33
34 connect(sender.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); 34 connect(sender.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError);
35 connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); 35 connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError);
FlashMQTests/twoclienttestcontext.h
@@ -33,7 +33,7 @@ private slots: @@ -33,7 +33,7 @@ private slots:
33 void onReceiverReceived(const QMQTT::Message& message); 33 void onReceiverReceived(const QMQTT::Message& message);
34 34
35 public: 35 public:
36 - explicit TwoClientTestContext(QObject *parent = nullptr); 36 + explicit TwoClientTestContext(int clientNr = 0, QObject *parent = nullptr);
37 void publish(const QString &topic, const QByteArray &payload); 37 void publish(const QString &topic, const QByteArray &payload);
38 void publish(const QString &topic, const QByteArray &payload, bool retain); 38 void publish(const QString &topic, const QByteArray &payload, bool retain);
39 void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain); 39 void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain);
client.cpp
@@ -207,6 +207,21 @@ int Client::writeMqttPacket(const MqttPacket &amp;packet) @@ -207,6 +207,21 @@ int Client::writeMqttPacket(const MqttPacket &amp;packet)
207 return 1; 207 return 1;
208 } 208 }
209 209
  210 +int Client::writeMqttPacketAndBlameThisClient(PublishCopyFactory &copyFactory, char max_qos, uint16_t packet_id)
  211 +{
  212 + MqttPacket &p = copyFactory.getOptimumPacket(max_qos);
  213 +
  214 + if (p.getQos() > 0)
  215 + {
  216 + // This may change the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere,
  217 + // that should be fine.
  218 + p.setPacketId(packet_id);
  219 + p.setQos(max_qos);
  220 + }
  221 +
  222 + return writeMqttPacketAndBlameThisClient(p);
  223 +}
  224 +
210 // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected. 225 // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected.
211 int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) 226 int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet)
212 { 227 {
client.h
@@ -37,6 +37,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -37,6 +37,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
37 #include "types.h" 37 #include "types.h"
38 #include "iowrapper.h" 38 #include "iowrapper.h"
39 39
  40 +#include "publishcopyfactory.h"
  41 +
40 #define MQTT_HEADER_LENGH 2 42 #define MQTT_HEADER_LENGH 2
41 43
42 44
@@ -122,6 +124,7 @@ public: @@ -122,6 +124,7 @@ public:
122 void writeText(const std::string &text); 124 void writeText(const std::string &text);
123 void writePingResp(); 125 void writePingResp();
124 int writeMqttPacket(const MqttPacket &packet); 126 int writeMqttPacket(const MqttPacket &packet);
  127 + int writeMqttPacketAndBlameThisClient(PublishCopyFactory &copyFactory, char max_qos, uint16_t packet_id);
125 int writeMqttPacketAndBlameThisClient(const MqttPacket &packet); 128 int writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
126 bool writeBufIntoFd(); 129 bool writeBufIntoFd();
127 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } 130 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; }
mqttpacket.cpp
@@ -822,11 +822,6 @@ void MqttPacket::setDuplicate() @@ -822,11 +822,6 @@ void MqttPacket::setDuplicate()
822 } 822 }
823 } 823 }
824 824
825 -size_t MqttPacket::getTotalMemoryFootprint()  
826 -{  
827 - return bites.size() + sizeof(MqttPacket);  
828 -}  
829 -  
830 /** 825 /**
831 * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string. 826 * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string.
832 * @return 827 * @return
mqttpacket.h
@@ -120,7 +120,6 @@ public: @@ -120,7 +120,6 @@ public:
120 void setPacketId(uint16_t packet_id); 120 void setPacketId(uint16_t packet_id);
121 uint16_t getPacketId() const; 121 uint16_t getPacketId() const;
122 void setDuplicate(); 122 void setDuplicate();
123 - size_t getTotalMemoryFootprint();  
124 void readIntoBuf(CirBuf &buf) const; 123 void readIntoBuf(CirBuf &buf) const;
125 std::string getPayloadCopy() const; 124 std::string getPayloadCopy() const;
126 bool getRetain() const; 125 bool getRetain() const;
publishcopyfactory.cpp 0 โ†’ 100644
  1 +#include <cassert>
  2 +
  3 +#include "publishcopyfactory.h"
  4 +#include "mqttpacket.h"
  5 +
  6 +PublishCopyFactory::PublishCopyFactory(MqttPacket &packet) :
  7 + packet(packet),
  8 + orgQos(packet.getQos())
  9 +{
  10 +
  11 +}
  12 +
  13 +MqttPacket &PublishCopyFactory::getOptimumPacket(char max_qos)
  14 +{
  15 + if (max_qos == 0 && max_qos < packet.getQos())
  16 + {
  17 + if (!downgradedQos0PacketCopy)
  18 + downgradedQos0PacketCopy = packet.getCopy(max_qos);
  19 + assert(downgradedQos0PacketCopy->getQos() == 0);
  20 + return *downgradedQos0PacketCopy.get();
  21 + }
  22 +
  23 + return packet;
  24 +}
  25 +
  26 +char PublishCopyFactory::getEffectiveQos(char max_qos) const
  27 +{
  28 + const char effectiveQos = std::min<char>(orgQos, max_qos);
  29 + return effectiveQos;
  30 +}
  31 +
  32 +const std::string &PublishCopyFactory::getTopic() const
  33 +{
  34 + return packet.getTopic();
  35 +}
  36 +
  37 +const std::vector<std::string> &PublishCopyFactory::getSubtopics() const
  38 +{
  39 + return packet.getSubtopics();
  40 +}
  41 +
  42 +bool PublishCopyFactory::getRetain() const
  43 +{
  44 + return packet.getRetain();
  45 +}
  46 +
  47 +Publish PublishCopyFactory::getPublish() const
  48 +{
  49 + assert(packet.getQos() > 0);
  50 +
  51 + Publish p(packet.getTopic(), packet.getPayloadCopy(), packet.getQos());
  52 + return p;
  53 +}
publishcopyfactory.h 0 โ†’ 100644
  1 +#ifndef PUBLISHCOPYFACTORY_H
  2 +#define PUBLISHCOPYFACTORY_H
  3 +
  4 +#include <vector>
  5 +
  6 +#include "forward_declarations.h"
  7 +#include "types.h"
  8 +
  9 +/**
  10 + * @brief The PublishCopyFactory class is for managing copies of an incoming publish, including sometimes not making copies at all.
  11 + *
  12 + * The idea is that certain incoming packets can just be written to the receiving client as-is, without constructing a new one. We do have to change the bytes
  13 + * where the QoS is stored, so we keep track of the original.
  14 + */
  15 +class PublishCopyFactory
  16 +{
  17 + MqttPacket &packet;
  18 + const char orgQos;
  19 + std::shared_ptr<MqttPacket> downgradedQos0PacketCopy;
  20 +
  21 + // TODO: constructed mqtt3 packet and mqtt5 packet
  22 +public:
  23 + PublishCopyFactory(MqttPacket &packet);
  24 + PublishCopyFactory(const PublishCopyFactory &other) = delete;
  25 + PublishCopyFactory(PublishCopyFactory &&other) = delete;
  26 +
  27 + MqttPacket &getOptimumPacket(char max_qos);
  28 + char getEffectiveQos(char max_qos) const;
  29 + const std::string &getTopic() const;
  30 + const std::vector<std::string> &getSubtopics() const;
  31 + bool getRetain() const;
  32 + Publish getPublish() const;
  33 +};
  34 +
  35 +#endif // PUBLISHCOPYFACTORY_H
qospacketqueue.cpp
@@ -4,16 +4,39 @@ @@ -4,16 +4,39 @@
4 4
5 #include "mqttpacket.h" 5 #include "mqttpacket.h"
6 6
7 -void QoSPacketQueue::erase(const uint16_t packet_id) 7 +QueuedPublish::QueuedPublish(Publish &&publish, uint16_t packet_id) :
  8 + publish(std::move(publish)),
  9 + packet_id(packet_id)
  10 +{
  11 +
  12 +}
  13 +
  14 +uint16_t QueuedPublish::getPacketId() const
  15 +{
  16 + return this->packet_id;
  17 +}
  18 +
  19 +const Publish &QueuedPublish::getPublish() const
  20 +{
  21 + return publish;
  22 +}
  23 +
  24 +size_t QueuedPublish::getApproximateMemoryFootprint() const
  25 +{
  26 + return publish.topic.length() + publish.payload.length();
  27 +}
  28 +
  29 +
  30 +void QoSPublishQueue::erase(const uint16_t packet_id)
8 { 31 {
9 auto it = queue.begin(); 32 auto it = queue.begin();
10 auto end = queue.end(); 33 auto end = queue.end();
11 while (it != end) 34 while (it != end)
12 { 35 {
13 - std::shared_ptr<MqttPacket> &p = *it;  
14 - if (p->getPacketId() == packet_id) 36 + QueuedPublish &p = *it;
  37 + if (p.getPacketId() == packet_id)
15 { 38 {
16 - size_t mem = p->getTotalMemoryFootprint(); 39 + size_t mem = p.getApproximateMemoryFootprint();
17 qosQueueBytes -= mem; 40 qosQueueBytes -= mem;
18 assert(qosQueueBytes >= 0); 41 assert(qosQueueBytes >= 0);
19 if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. 42 if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose.
@@ -28,50 +51,40 @@ void QoSPacketQueue::erase(const uint16_t packet_id) @@ -28,50 +51,40 @@ void QoSPacketQueue::erase(const uint16_t packet_id)
28 } 51 }
29 } 52 }
30 53
31 -size_t QoSPacketQueue::size() const 54 +size_t QoSPublishQueue::size() const
32 { 55 {
33 return queue.size(); 56 return queue.size();
34 } 57 }
35 58
36 -size_t QoSPacketQueue::getByteSize() const 59 +size_t QoSPublishQueue::getByteSize() const
37 { 60 {
38 return qosQueueBytes; 61 return qosQueueBytes;
39 } 62 }
40 63
41 -/**  
42 - * @brief QoSPacketQueue::queuePacket makes a copy of the packet because it has state for the receiver in question.  
43 - * @param p  
44 - * @param id  
45 - * @return the packet copy.  
46 - */  
47 -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos) 64 +void QoSPublishQueue::queuePublish(PublishCopyFactory &copyFactory, uint16_t id, char new_max_qos)
48 { 65 {
49 - assert(p.getQos() > 0); 66 + assert(new_max_qos > 0);
  67 + assert(id > 0);
50 68
51 - std::shared_ptr<MqttPacket> copyPacket = p.getCopy(new_max_qos);  
52 - copyPacket->setPacketId(id);  
53 - queue.push_back(copyPacket);  
54 - qosQueueBytes += copyPacket->getTotalMemoryFootprint();  
55 - return copyPacket; 69 + Publish pub = copyFactory.getPublish();
  70 + queue.emplace_back(std::move(pub), id);
  71 + qosQueueBytes += queue.back().getApproximateMemoryFootprint();
56 } 72 }
57 73
58 -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const Publish &pub, uint16_t id) 74 +void QoSPublishQueue::queuePublish(Publish &&pub, uint16_t id)
59 { 75 {
60 - assert(pub.qos > 0); 76 + assert(id > 0);
61 77
62 - std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(pub));  
63 - copyPacket->setPacketId(id);  
64 - queue.push_back(copyPacket);  
65 - qosQueueBytes += copyPacket->getTotalMemoryFootprint();  
66 - return copyPacket; 78 + queue.emplace_back(std::move(pub), id);
  79 + qosQueueBytes += queue.back().getApproximateMemoryFootprint();
67 } 80 }
68 81
69 -std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::begin() const 82 +std::list<QueuedPublish>::const_iterator QoSPublishQueue::begin() const
70 { 83 {
71 return queue.cbegin(); 84 return queue.cbegin();
72 } 85 }
73 86
74 -std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::end() const 87 +std::list<QueuedPublish>::const_iterator QoSPublishQueue::end() const
75 { 88 {
76 return queue.cend(); 89 return queue.cend();
77 } 90 }
qospacketqueue.h
@@ -5,21 +5,39 @@ @@ -5,21 +5,39 @@
5 5
6 #include "forward_declarations.h" 6 #include "forward_declarations.h"
7 #include "types.h" 7 #include "types.h"
  8 +#include "publishcopyfactory.h"
8 9
9 -class QoSPacketQueue 10 +/**
  11 + * @brief The QueuedPublish class wraps the publish with a packet id.
  12 + *
  13 + * We don't want to store the packet id in the Publish object, because the packet id is determined/tracked per client/session.
  14 + */
  15 +class QueuedPublish
10 { 16 {
11 - std::list<std::shared_ptr<MqttPacket>> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] 17 + Publish publish;
  18 + uint16_t packet_id = 0;
  19 +public:
  20 + QueuedPublish(Publish &&publish, uint16_t packet_id);
  21 +
  22 + size_t getApproximateMemoryFootprint() const;
  23 + uint16_t getPacketId() const;
  24 + const Publish &getPublish() const;
  25 +};
  26 +
  27 +class QoSPublishQueue
  28 +{
  29 + std::list<QueuedPublish> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6]
12 ssize_t qosQueueBytes = 0; 30 ssize_t qosQueueBytes = 0;
13 31
14 public: 32 public:
15 void erase(const uint16_t packet_id); 33 void erase(const uint16_t packet_id);
16 size_t size() const; 34 size_t size() const;
17 size_t getByteSize() const; 35 size_t getByteSize() const;
18 - std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos);  
19 - std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id); 36 + void queuePublish(PublishCopyFactory &copyFactory, uint16_t id, char new_max_qos);
  37 + void queuePublish(Publish &&pub, uint16_t id);
20 38
21 - std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const;  
22 - std::list<std::shared_ptr<MqttPacket>>::const_iterator end() const; 39 + std::list<QueuedPublish>::const_iterator begin() const;
  40 + std::list<QueuedPublish>::const_iterator end() const;
23 }; 41 };
24 42
25 #endif // QOSPACKETQUEUE_H 43 #endif // QOSPACKETQUEUE_H
session.cpp
@@ -101,8 +101,7 @@ Session::Session(const Session &amp;other) @@ -101,8 +101,7 @@ Session::Session(const Session &amp;other)
101 this->nextPacketId = other.nextPacketId; 101 this->nextPacketId = other.nextPacketId;
102 this->lastTouched = other.lastTouched; 102 this->lastTouched = other.lastTouched;
103 103
104 - // To be fully correct, we should copy the individual packets, but copying sessions is only done for saving them, and I know  
105 - // that no member of MqttPacket changes in the QoS process, so we can just keep the shared pointer to the original. 104 + // TODO: see git history for a change here. We now copy the whole queued publish. Do we want to address that?
106 this->qosPacketQueue = other.qosPacketQueue; 105 this->qosPacketQueue = other.qosPacketQueue;
107 } 106 }
108 107
@@ -145,33 +144,25 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client) @@ -145,33 +144,25 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
145 * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. 144 * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets.
146 * @param count. Reference value is updated. It's for statistics. 145 * @param count. Reference value is updated. It's for statistics.
147 */ 146 */
148 -void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count) 147 +void Session::writePacket(PublishCopyFactory &copyFactory, const char max_qos, uint64_t &count)
149 { 148 {
150 assert(max_qos <= 2); 149 assert(max_qos <= 2);
151 - const char effectiveQos = std::min<char>(packet.getQos(), max_qos); 150 +
  151 + const char effectiveQos = copyFactory.getEffectiveQos(max_qos);
152 152
153 const Settings *settings = ThreadGlobals::getSettings(); 153 const Settings *settings = ThreadGlobals::getSettings();
154 154
155 Authentication *_auth = ThreadGlobals::getAuth(); 155 Authentication *_auth = ThreadGlobals::getAuth();
156 assert(_auth); 156 assert(_auth);
157 Authentication &auth = *_auth; 157 Authentication &auth = *_auth;
158 - if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success) 158 + if (auth.aclCheck(client_id, username, copyFactory.getTopic(), copyFactory.getSubtopics(), AclAccess::read, effectiveQos, copyFactory.getRetain()) == AuthResult::success)
159 { 159 {
160 std::shared_ptr<Client> c = makeSharedClient(); 160 std::shared_ptr<Client> c = makeSharedClient();
161 if (effectiveQos == 0) 161 if (effectiveQos == 0)
162 { 162 {
163 if (c) 163 if (c)
164 { 164 {
165 - const MqttPacket *packetToSend = &packet;  
166 -  
167 - if (max_qos < packet.getQos())  
168 - {  
169 - if (!downgradedQos0PacketCopy)  
170 - downgradedQos0PacketCopy = packet.getCopy(max_qos);  
171 - packetToSend = downgradedQos0PacketCopy.get();  
172 - }  
173 -  
174 - count += c->writeMqttPacketAndBlameThisClient(*packetToSend); 165 + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, 0);
175 } 166 }
176 } 167 }
177 else if (effectiveQos > 0) 168 else if (effectiveQos > 0)
@@ -195,12 +186,12 @@ void Session::writePacket(MqttPacket &amp;packet, char max_qos, std::shared_ptr&lt;Mqtt @@ -195,12 +186,12 @@ void Session::writePacket(MqttPacket &amp;packet, char max_qos, std::shared_ptr&lt;Mqtt
195 } 186 }
196 187
197 increasePacketId(); 188 increasePacketId();
198 - std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId, effectiveQos); 189 +
  190 + qosPacketQueue.queuePublish(copyFactory, nextPacketId, effectiveQos);
199 191
200 if (c) 192 if (c)
201 { 193 {
202 - count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get());  
203 - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. 194 + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, nextPacketId);
204 } 195 }
205 } 196 }
206 else 197 else
@@ -224,14 +215,9 @@ void Session::writePacket(MqttPacket &amp;packet, char max_qos, std::shared_ptr&lt;Mqtt @@ -224,14 +215,9 @@ void Session::writePacket(MqttPacket &amp;packet, char max_qos, std::shared_ptr&lt;Mqtt
224 215
225 increasePacketId(); 216 increasePacketId();
226 217
227 - // This changes the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere,  
228 - // that should be fine.  
229 - packet.setPacketId(nextPacketId);  
230 - packet.setQos(effectiveQos);  
231 -  
232 qosInFlightCounter++; 218 qosInFlightCounter++;
233 assert(c); // with requiresRetransmission==false, there must be a client. 219 assert(c); // with requiresRetransmission==false, there must be a client.
234 - count += c->writeMqttPacketAndBlameThisClient(packet); 220 + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, nextPacketId);
235 } 221 }
236 } 222 }
237 } 223 }
@@ -267,10 +253,12 @@ uint64_t Session::sendPendingQosMessages() @@ -267,10 +253,12 @@ uint64_t Session::sendPendingQosMessages()
267 if (c) 253 if (c)
268 { 254 {
269 std::lock_guard<std::mutex> locker(qosQueueMutex); 255 std::lock_guard<std::mutex> locker(qosQueueMutex);
270 - for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) 256 + for (const QueuedPublish &queuedPublish : qosPacketQueue)
271 { 257 {
272 - count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get());  
273 - qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. 258 + MqttPacket p(queuedPublish.getPublish());
  259 + p.setDuplicate();
  260 +
  261 + count += c->writeMqttPacketAndBlameThisClient(p);
274 } 262 }
275 263
276 for (const uint16_t packet_id : outgoingQoS2MessageIds) 264 for (const uint16_t packet_id : outgoingQoS2MessageIds)
session.h
@@ -27,6 +27,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -27,6 +27,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
27 #include "logger.h" 27 #include "logger.h"
28 #include "sessionsandsubscriptionsdb.h" 28 #include "sessionsandsubscriptionsdb.h"
29 #include "qospacketqueue.h" 29 #include "qospacketqueue.h"
  30 +#include "publishcopyfactory.h"
30 31
31 class Session 32 class Session
32 { 33 {
@@ -39,7 +40,7 @@ class Session @@ -39,7 +40,7 @@ class Session
39 std::weak_ptr<Client> client; 40 std::weak_ptr<Client> client;
40 std::string client_id; 41 std::string client_id;
41 std::string username; 42 std::string username;
42 - QoSPacketQueue qosPacketQueue; 43 + QoSPublishQueue qosPacketQueue;
43 std::set<uint16_t> incomingQoS2MessageIds; 44 std::set<uint16_t> incomingQoS2MessageIds;
44 std::set<uint16_t> outgoingQoS2MessageIds; 45 std::set<uint16_t> outgoingQoS2MessageIds;
45 std::mutex qosQueueMutex; 46 std::mutex qosQueueMutex;
@@ -69,7 +70,7 @@ public: @@ -69,7 +70,7 @@ public:
69 const std::string &getClientId() const { return client_id; } 70 const std::string &getClientId() const { return client_id; }
70 std::shared_ptr<Client> makeSharedClient() const; 71 std::shared_ptr<Client> makeSharedClient() const;
71 void assignActiveConnection(std::shared_ptr<Client> &client); 72 void assignActiveConnection(std::shared_ptr<Client> &client);
72 - void writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count); 73 + void writePacket(PublishCopyFactory &copyFactory, const char max_qos, uint64_t &count);
73 void clearQosMessage(uint16_t packet_id); 74 void clearQosMessage(uint16_t packet_id);
74 uint64_t sendPendingQosMessages(); 75 uint64_t sendPendingQosMessages();
75 void touch(std::chrono::time_point<std::chrono::steady_clock> val); 76 void touch(std::chrono::time_point<std::chrono::steady_clock> val);
sessionsandsubscriptionsdb.cpp
@@ -116,7 +116,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1() @@ -116,7 +116,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1()
116 116
117 Publish pub(topic, payload, qos); 117 Publish pub(topic, payload, qos);
118 logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); 118 logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
119 - ses->qosPacketQueue.queuePacket(pub, id); 119 + ses->qosPacketQueue.queuePublish(std::move(pub), id);
120 } 120 }
121 121
122 const uint32_t nrOfIncomingPacketIds = readUint32(eofFound); 122 const uint32_t nrOfIncomingPacketIds = readUint32(eofFound);
@@ -215,23 +215,24 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess @@ -215,23 +215,24 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess
215 size_t qosPacketsCounted = 0; 215 size_t qosPacketsCounted = 0;
216 writeUint32(qosPacketsExpected); 216 writeUint32(qosPacketsExpected);
217 217
218 - for (const std::shared_ptr<MqttPacket> &p: ses->qosPacketQueue) 218 + for (const QueuedPublish &p: ses->qosPacketQueue)
219 { 219 {
220 - logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", p->getQos(), p->getTopic().c_str()); 220 + const Publish &pub = p.getPublish();
  221 +
  222 + logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
221 223
222 qosPacketsCounted++; 224 qosPacketsCounted++;
223 225
224 - writeUint16(p->getPacketId()); 226 + writeUint16(p.getPacketId());
225 227
226 - writeUint32(p->getTopic().length());  
227 - std::string payload = p->getPayloadCopy();  
228 - writeUint32(payload.size()); 228 + writeUint32(pub.topic.length());
  229 + writeUint32(pub.payload.size());
229 230
230 - const char qos = p->getQos(); 231 + const char qos = pub.qos;
231 writeCheck(&qos, 1, 1, f); 232 writeCheck(&qos, 1, 1, f);
232 233
233 - writeCheck(p->getTopic().c_str(), 1, p->getTopic().length(), f);  
234 - writeCheck(payload.c_str(), 1, payload.length(), f); 234 + writeCheck(pub.topic.c_str(), 1, pub.topic.length(), f);
  235 + writeCheck(pub.payload.c_str(), 1, pub.payload.length(), f);
235 } 236 }
236 237
237 assert(qosPacketsExpected == qosPacketsCounted); 238 assert(qosPacketsExpected == qosPacketsCounted);
subscriptionstore.cpp
@@ -21,6 +21,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -21,6 +21,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
21 21
22 #include "rwlockguard.h" 22 #include "rwlockguard.h"
23 #include "retainedmessagesdb.h" 23 #include "retainedmessagesdb.h"
  24 +#include "publishcopyfactory.h"
24 25
25 ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr<Session> &ses, char qos) : 26 ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr<Session> &ses, char qos) :
26 session(ses), 27 session(ses),
@@ -346,10 +347,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector&lt;std::string&gt; @@ -346,10 +347,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector&lt;std::string&gt;
346 publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions); 347 publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions);
347 } 348 }
348 349
349 - std::shared_ptr<MqttPacket> possibleQos0Copy; 350 + PublishCopyFactory copyFactory(packet);
350 for(const ReceivingSubscriber &x : subscriberSessions) 351 for(const ReceivingSubscriber &x : subscriberSessions)
351 { 352 {
352 - x.session->writePacket(packet, x.qos, possibleQos0Copy, count); 353 + x.session->writePacket(copyFactory, x.qos, count);
353 } 354 }
354 355
355 std::shared_ptr<Client> sender = packet.getSender(); 356 std::shared_ptr<Client> sender = packet.getSender();
@@ -425,8 +426,8 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Ses @@ -425,8 +426,8 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Ses
425 426
426 for(MqttPacket &packet : packetList) 427 for(MqttPacket &packet : packetList)
427 { 428 {
428 - std::shared_ptr<MqttPacket> possibleQos0Copy;  
429 - ses->writePacket(packet, max_qos, possibleQos0Copy, count); 429 + PublishCopyFactory copyFactory(packet);
  430 + ses->writePacket(copyFactory, max_qos, count);
430 } 431 }
431 432
432 return count; 433 return count;
types.cpp
@@ -46,7 +46,7 @@ size_t SubAck::getLengthWithoutFixedHeader() const @@ -46,7 +46,7 @@ size_t SubAck::getLengthWithoutFixedHeader() const
46 return result; 46 return result;
47 } 47 }
48 48
49 -Publish::Publish(const std::string &topic, const std::string payload, char qos) : 49 +Publish::Publish(const std::string &topic, const std::string &payload, char qos) :
50 topic(topic), 50 topic(topic),
51 payload(payload), 51 payload(payload),
52 qos(qos) 52 qos(qos)
@@ -101,7 +101,7 @@ public: @@ -101,7 +101,7 @@ public:
101 std::string payload; 101 std::string payload;
102 char qos = 0; 102 char qos = 0;
103 bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] 103 bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9]
104 - Publish(const std::string &topic, const std::string payload, char qos); 104 + Publish(const std::string &topic, const std::string &payload, char qos);
105 size_t getLengthWithoutFixedHeader() const; 105 size_t getLengthWithoutFixedHeader() const;
106 }; 106 };
107 107