Commit ecb60b4815605a49162a89dc287eebf981ca3f79
1 parent
54d52925
Change writing packets to clients to writing PublishCopyFactory
This is a preparation for MQTT5, because when there are receivers and publishers with different protocols, you can't always just write out the same packet. You can sometimes though, so that's what the copy factory determines.
Showing
19 changed files
with
284 additions
and
91 deletions
CMakeLists.txt
| ... | ... | @@ -58,6 +58,7 @@ add_executable(FlashMQ |
| 58 | 58 | qospacketqueue.h |
| 59 | 59 | threadglobals.h |
| 60 | 60 | threadloop.h |
| 61 | + publishcopyfactory.h | |
| 61 | 62 | |
| 62 | 63 | mainapp.cpp |
| 63 | 64 | main.cpp |
| ... | ... | @@ -95,6 +96,7 @@ add_executable(FlashMQ |
| 95 | 96 | qospacketqueue.cpp |
| 96 | 97 | threadglobals.cpp |
| 97 | 98 | threadloop.cpp |
| 99 | + publishcopyfactory.cpp | |
| 98 | 100 | |
| 99 | 101 | ) |
| 100 | 102 | ... | ... |
FlashMQTests/FlashMQTests.pro
| ... | ... | @@ -49,6 +49,7 @@ SOURCES += tst_maintests.cpp \ |
| 49 | 49 | ../qospacketqueue.cpp \ |
| 50 | 50 | ../threadglobals.cpp \ |
| 51 | 51 | ../threadloop.cpp \ |
| 52 | + ../publishcopyfactory.cpp \ | |
| 52 | 53 | mainappthread.cpp \ |
| 53 | 54 | twoclienttestcontext.cpp |
| 54 | 55 | |
| ... | ... | @@ -90,6 +91,7 @@ HEADERS += \ |
| 90 | 91 | ../qospacketqueue.h \ |
| 91 | 92 | ../threadglobals.h \ |
| 92 | 93 | ../threadloop.h \ |
| 94 | + ../publishcopyfactory.h \ | |
| 93 | 95 | mainappthread.h \ |
| 94 | 96 | twoclienttestcontext.h |
| 95 | 97 | ... | ... |
FlashMQTests/tst_maintests.cpp
| ... | ... | @@ -111,6 +111,8 @@ private slots: |
| 111 | 111 | void testDowngradeQoSOnSubscribeQos1to0(); |
| 112 | 112 | void testDowngradeQoSOnSubscribeQos0to0(); |
| 113 | 113 | |
| 114 | + void testNotMessingUpQosLevels(); | |
| 115 | + | |
| 114 | 116 | }; |
| 115 | 117 | |
| 116 | 118 | MainTests::MainTests() |
| ... | ... | @@ -1037,8 +1039,8 @@ void MainTests::testSavingSessions() |
| 1037 | 1039 | std::shared_ptr<Session> c1ses = c1->getSession(); |
| 1038 | 1040 | c1.reset(); |
| 1039 | 1041 | MqttPacket publishPacket(publish); |
| 1040 | - std::shared_ptr<MqttPacket> possibleQos0Copy; | |
| 1041 | - c1ses->writePacket(publishPacket, 1, possibleQos0Copy, count); | |
| 1042 | + PublishCopyFactory fac(publishPacket); | |
| 1043 | + c1ses->writePacket(fac, 1, count); | |
| 1042 | 1044 | |
| 1043 | 1045 | store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); |
| 1044 | 1046 | |
| ... | ... | @@ -1241,6 +1243,71 @@ void MainTests::testDowngradeQoSOnSubscribeQos0to0() |
| 1241 | 1243 | testDowngradeQoSOnSubscribeHelper(0, 0); |
| 1242 | 1244 | } |
| 1243 | 1245 | |
| 1246 | +/** | |
| 1247 | + * @brief MainTests::testNotMessingUpQosLevels was divised because we optimize by preventing packet copies. This entails changing the vector of the original | |
| 1248 | + * incoming packet, resulting in possibly changing values like QoS levels for later subscribers. | |
| 1249 | + */ | |
| 1250 | +void MainTests::testNotMessingUpQosLevels() | |
| 1251 | +{ | |
| 1252 | + const QString topic = "HK7c1MFu6kdT69fWY"; | |
| 1253 | + const QByteArray payload = "M4XK2LZ2Smaazba8RobZOgoe6CENxCll"; | |
| 1254 | + | |
| 1255 | + TwoClientTestContext testContextSender; | |
| 1256 | + TwoClientTestContext testContextReceiver1(1); | |
| 1257 | + TwoClientTestContext testContextReceiver2(2); | |
| 1258 | + TwoClientTestContext testContextReceiver3(3); | |
| 1259 | + TwoClientTestContext testContextReceiver4(4); | |
| 1260 | + TwoClientTestContext testContextReceiver5(5); | |
| 1261 | + | |
| 1262 | + testContextReceiver1.connectReceiver(); | |
| 1263 | + testContextReceiver1.subscribeReceiver(topic, 0); | |
| 1264 | + | |
| 1265 | + testContextReceiver2.connectReceiver(); | |
| 1266 | + testContextReceiver2.subscribeReceiver(topic, 1); | |
| 1267 | + | |
| 1268 | + testContextReceiver3.connectReceiver(); | |
| 1269 | + testContextReceiver3.subscribeReceiver(topic, 2); | |
| 1270 | + | |
| 1271 | + testContextReceiver4.connectReceiver(); | |
| 1272 | + testContextReceiver4.subscribeReceiver(topic, 1); | |
| 1273 | + | |
| 1274 | + testContextReceiver5.connectReceiver(); | |
| 1275 | + testContextReceiver5.subscribeReceiver(topic, 0); | |
| 1276 | + | |
| 1277 | + testContextSender.connectSender(); | |
| 1278 | + testContextSender.publish(topic, payload, 2, false); | |
| 1279 | + | |
| 1280 | + testContextReceiver1.waitReceiverReceived(1); | |
| 1281 | + testContextReceiver2.waitReceiverReceived(1); | |
| 1282 | + testContextReceiver3.waitReceiverReceived(1); | |
| 1283 | + testContextReceiver4.waitReceiverReceived(1); | |
| 1284 | + testContextReceiver5.waitReceiverReceived(1); | |
| 1285 | + | |
| 1286 | + QCOMPARE(testContextReceiver1.receivedMessages.count(), 1); | |
| 1287 | + QCOMPARE(testContextReceiver2.receivedMessages.count(), 1); | |
| 1288 | + QCOMPARE(testContextReceiver3.receivedMessages.count(), 1); | |
| 1289 | + QCOMPARE(testContextReceiver4.receivedMessages.count(), 1); | |
| 1290 | + QCOMPARE(testContextReceiver5.receivedMessages.count(), 1); | |
| 1291 | + | |
| 1292 | + QCOMPARE(testContextReceiver1.receivedMessages.first().qos(), 0); | |
| 1293 | + QCOMPARE(testContextReceiver2.receivedMessages.first().qos(), 1); | |
| 1294 | + QCOMPARE(testContextReceiver3.receivedMessages.first().qos(), 2); | |
| 1295 | + QCOMPARE(testContextReceiver4.receivedMessages.first().qos(), 1); | |
| 1296 | + QCOMPARE(testContextReceiver5.receivedMessages.first().qos(), 0); | |
| 1297 | + | |
| 1298 | + QCOMPARE(testContextReceiver1.receivedMessages.first().payload(), payload); | |
| 1299 | + QCOMPARE(testContextReceiver2.receivedMessages.first().payload(), payload); | |
| 1300 | + QCOMPARE(testContextReceiver3.receivedMessages.first().payload(), payload); | |
| 1301 | + QCOMPARE(testContextReceiver4.receivedMessages.first().payload(), payload); | |
| 1302 | + QCOMPARE(testContextReceiver5.receivedMessages.first().payload(), payload); | |
| 1303 | + | |
| 1304 | + QCOMPARE(testContextReceiver1.receivedMessages.first().id(), 0); | |
| 1305 | + QCOMPARE(testContextReceiver2.receivedMessages.first().id(), 1); | |
| 1306 | + QCOMPARE(testContextReceiver3.receivedMessages.first().id(), 1); | |
| 1307 | + QCOMPARE(testContextReceiver4.receivedMessages.first().id(), 1); | |
| 1308 | + QCOMPARE(testContextReceiver5.receivedMessages.first().id(), 0); | |
| 1309 | +} | |
| 1310 | + | |
| 1244 | 1311 | |
| 1245 | 1312 | int main(int argc, char *argv[]) |
| 1246 | 1313 | { | ... | ... |
FlashMQTests/twoclienttestcontext.cpp
| ... | ... | @@ -22,14 +22,14 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 22 | 22 | |
| 23 | 23 | // TODO: port to QMqttClient that newer Qts now have? |
| 24 | 24 | |
| 25 | -TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent) | |
| 25 | +TwoClientTestContext::TwoClientTestContext(int clientNr, QObject *parent) : QObject(parent) | |
| 26 | 26 | { |
| 27 | 27 | QHostInfo targetHostInfo = QHostInfo::fromName("localhost"); |
| 28 | 28 | QHostAddress targetHost(targetHostInfo.addresses().first()); |
| 29 | 29 | sender.reset(new QMQTT::Client(targetHost)); |
| 30 | - sender->setClientId("Sender"); | |
| 30 | + sender->setClientId(QString("Sender%1").arg(clientNr)); | |
| 31 | 31 | receiver.reset(new QMQTT::Client(targetHost)); |
| 32 | - receiver->setClientId("Receiver"); | |
| 32 | + receiver->setClientId(QString("Receiver%1").arg(clientNr)); | |
| 33 | 33 | |
| 34 | 34 | connect(sender.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); |
| 35 | 35 | connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); | ... | ... |
FlashMQTests/twoclienttestcontext.h
| ... | ... | @@ -33,7 +33,7 @@ private slots: |
| 33 | 33 | void onReceiverReceived(const QMQTT::Message& message); |
| 34 | 34 | |
| 35 | 35 | public: |
| 36 | - explicit TwoClientTestContext(QObject *parent = nullptr); | |
| 36 | + explicit TwoClientTestContext(int clientNr = 0, QObject *parent = nullptr); | |
| 37 | 37 | void publish(const QString &topic, const QByteArray &payload); |
| 38 | 38 | void publish(const QString &topic, const QByteArray &payload, bool retain); |
| 39 | 39 | void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain); | ... | ... |
client.cpp
| ... | ... | @@ -207,6 +207,21 @@ int Client::writeMqttPacket(const MqttPacket &packet) |
| 207 | 207 | return 1; |
| 208 | 208 | } |
| 209 | 209 | |
| 210 | +int Client::writeMqttPacketAndBlameThisClient(PublishCopyFactory ©Factory, char max_qos, uint16_t packet_id) | |
| 211 | +{ | |
| 212 | + MqttPacket &p = copyFactory.getOptimumPacket(max_qos); | |
| 213 | + | |
| 214 | + if (p.getQos() > 0) | |
| 215 | + { | |
| 216 | + // This may change the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere, | |
| 217 | + // that should be fine. | |
| 218 | + p.setPacketId(packet_id); | |
| 219 | + p.setQos(max_qos); | |
| 220 | + } | |
| 221 | + | |
| 222 | + return writeMqttPacketAndBlameThisClient(p); | |
| 223 | +} | |
| 224 | + | |
| 210 | 225 | // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected. |
| 211 | 226 | int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) |
| 212 | 227 | { | ... | ... |
client.h
| ... | ... | @@ -37,6 +37,8 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 37 | 37 | #include "types.h" |
| 38 | 38 | #include "iowrapper.h" |
| 39 | 39 | |
| 40 | +#include "publishcopyfactory.h" | |
| 41 | + | |
| 40 | 42 | #define MQTT_HEADER_LENGH 2 |
| 41 | 43 | |
| 42 | 44 | |
| ... | ... | @@ -122,6 +124,7 @@ public: |
| 122 | 124 | void writeText(const std::string &text); |
| 123 | 125 | void writePingResp(); |
| 124 | 126 | int writeMqttPacket(const MqttPacket &packet); |
| 127 | + int writeMqttPacketAndBlameThisClient(PublishCopyFactory ©Factory, char max_qos, uint16_t packet_id); | |
| 125 | 128 | int writeMqttPacketAndBlameThisClient(const MqttPacket &packet); |
| 126 | 129 | bool writeBufIntoFd(); |
| 127 | 130 | bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -822,11 +822,6 @@ void MqttPacket::setDuplicate() |
| 822 | 822 | } |
| 823 | 823 | } |
| 824 | 824 | |
| 825 | -size_t MqttPacket::getTotalMemoryFootprint() | |
| 826 | -{ | |
| 827 | - return bites.size() + sizeof(MqttPacket); | |
| 828 | -} | |
| 829 | - | |
| 830 | 825 | /** |
| 831 | 826 | * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string. |
| 832 | 827 | * @return | ... | ... |
mqttpacket.h
| ... | ... | @@ -120,7 +120,6 @@ public: |
| 120 | 120 | void setPacketId(uint16_t packet_id); |
| 121 | 121 | uint16_t getPacketId() const; |
| 122 | 122 | void setDuplicate(); |
| 123 | - size_t getTotalMemoryFootprint(); | |
| 124 | 123 | void readIntoBuf(CirBuf &buf) const; |
| 125 | 124 | std::string getPayloadCopy() const; |
| 126 | 125 | bool getRetain() const; | ... | ... |
publishcopyfactory.cpp
0 โ 100644
| 1 | +#include <cassert> | |
| 2 | + | |
| 3 | +#include "publishcopyfactory.h" | |
| 4 | +#include "mqttpacket.h" | |
| 5 | + | |
| 6 | +PublishCopyFactory::PublishCopyFactory(MqttPacket &packet) : | |
| 7 | + packet(packet), | |
| 8 | + orgQos(packet.getQos()) | |
| 9 | +{ | |
| 10 | + | |
| 11 | +} | |
| 12 | + | |
| 13 | +MqttPacket &PublishCopyFactory::getOptimumPacket(char max_qos) | |
| 14 | +{ | |
| 15 | + if (max_qos == 0 && max_qos < packet.getQos()) | |
| 16 | + { | |
| 17 | + if (!downgradedQos0PacketCopy) | |
| 18 | + downgradedQos0PacketCopy = packet.getCopy(max_qos); | |
| 19 | + assert(downgradedQos0PacketCopy->getQos() == 0); | |
| 20 | + return *downgradedQos0PacketCopy.get(); | |
| 21 | + } | |
| 22 | + | |
| 23 | + return packet; | |
| 24 | +} | |
| 25 | + | |
| 26 | +char PublishCopyFactory::getEffectiveQos(char max_qos) const | |
| 27 | +{ | |
| 28 | + const char effectiveQos = std::min<char>(orgQos, max_qos); | |
| 29 | + return effectiveQos; | |
| 30 | +} | |
| 31 | + | |
| 32 | +const std::string &PublishCopyFactory::getTopic() const | |
| 33 | +{ | |
| 34 | + return packet.getTopic(); | |
| 35 | +} | |
| 36 | + | |
| 37 | +const std::vector<std::string> &PublishCopyFactory::getSubtopics() const | |
| 38 | +{ | |
| 39 | + return packet.getSubtopics(); | |
| 40 | +} | |
| 41 | + | |
| 42 | +bool PublishCopyFactory::getRetain() const | |
| 43 | +{ | |
| 44 | + return packet.getRetain(); | |
| 45 | +} | |
| 46 | + | |
| 47 | +Publish PublishCopyFactory::getPublish() const | |
| 48 | +{ | |
| 49 | + assert(packet.getQos() > 0); | |
| 50 | + | |
| 51 | + Publish p(packet.getTopic(), packet.getPayloadCopy(), packet.getQos()); | |
| 52 | + return p; | |
| 53 | +} | ... | ... |
publishcopyfactory.h
0 โ 100644
| 1 | +#ifndef PUBLISHCOPYFACTORY_H | |
| 2 | +#define PUBLISHCOPYFACTORY_H | |
| 3 | + | |
| 4 | +#include <vector> | |
| 5 | + | |
| 6 | +#include "forward_declarations.h" | |
| 7 | +#include "types.h" | |
| 8 | + | |
| 9 | +/** | |
| 10 | + * @brief The PublishCopyFactory class is for managing copies of an incoming publish, including sometimes not making copies at all. | |
| 11 | + * | |
| 12 | + * The idea is that certain incoming packets can just be written to the receiving client as-is, without constructing a new one. We do have to change the bytes | |
| 13 | + * where the QoS is stored, so we keep track of the original. | |
| 14 | + */ | |
| 15 | +class PublishCopyFactory | |
| 16 | +{ | |
| 17 | + MqttPacket &packet; | |
| 18 | + const char orgQos; | |
| 19 | + std::shared_ptr<MqttPacket> downgradedQos0PacketCopy; | |
| 20 | + | |
| 21 | + // TODO: constructed mqtt3 packet and mqtt5 packet | |
| 22 | +public: | |
| 23 | + PublishCopyFactory(MqttPacket &packet); | |
| 24 | + PublishCopyFactory(const PublishCopyFactory &other) = delete; | |
| 25 | + PublishCopyFactory(PublishCopyFactory &&other) = delete; | |
| 26 | + | |
| 27 | + MqttPacket &getOptimumPacket(char max_qos); | |
| 28 | + char getEffectiveQos(char max_qos) const; | |
| 29 | + const std::string &getTopic() const; | |
| 30 | + const std::vector<std::string> &getSubtopics() const; | |
| 31 | + bool getRetain() const; | |
| 32 | + Publish getPublish() const; | |
| 33 | +}; | |
| 34 | + | |
| 35 | +#endif // PUBLISHCOPYFACTORY_H | ... | ... |
qospacketqueue.cpp
| ... | ... | @@ -4,16 +4,39 @@ |
| 4 | 4 | |
| 5 | 5 | #include "mqttpacket.h" |
| 6 | 6 | |
| 7 | -void QoSPacketQueue::erase(const uint16_t packet_id) | |
| 7 | +QueuedPublish::QueuedPublish(Publish &&publish, uint16_t packet_id) : | |
| 8 | + publish(std::move(publish)), | |
| 9 | + packet_id(packet_id) | |
| 10 | +{ | |
| 11 | + | |
| 12 | +} | |
| 13 | + | |
| 14 | +uint16_t QueuedPublish::getPacketId() const | |
| 15 | +{ | |
| 16 | + return this->packet_id; | |
| 17 | +} | |
| 18 | + | |
| 19 | +const Publish &QueuedPublish::getPublish() const | |
| 20 | +{ | |
| 21 | + return publish; | |
| 22 | +} | |
| 23 | + | |
| 24 | +size_t QueuedPublish::getApproximateMemoryFootprint() const | |
| 25 | +{ | |
| 26 | + return publish.topic.length() + publish.payload.length(); | |
| 27 | +} | |
| 28 | + | |
| 29 | + | |
| 30 | +void QoSPublishQueue::erase(const uint16_t packet_id) | |
| 8 | 31 | { |
| 9 | 32 | auto it = queue.begin(); |
| 10 | 33 | auto end = queue.end(); |
| 11 | 34 | while (it != end) |
| 12 | 35 | { |
| 13 | - std::shared_ptr<MqttPacket> &p = *it; | |
| 14 | - if (p->getPacketId() == packet_id) | |
| 36 | + QueuedPublish &p = *it; | |
| 37 | + if (p.getPacketId() == packet_id) | |
| 15 | 38 | { |
| 16 | - size_t mem = p->getTotalMemoryFootprint(); | |
| 39 | + size_t mem = p.getApproximateMemoryFootprint(); | |
| 17 | 40 | qosQueueBytes -= mem; |
| 18 | 41 | assert(qosQueueBytes >= 0); |
| 19 | 42 | if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. |
| ... | ... | @@ -28,50 +51,40 @@ void QoSPacketQueue::erase(const uint16_t packet_id) |
| 28 | 51 | } |
| 29 | 52 | } |
| 30 | 53 | |
| 31 | -size_t QoSPacketQueue::size() const | |
| 54 | +size_t QoSPublishQueue::size() const | |
| 32 | 55 | { |
| 33 | 56 | return queue.size(); |
| 34 | 57 | } |
| 35 | 58 | |
| 36 | -size_t QoSPacketQueue::getByteSize() const | |
| 59 | +size_t QoSPublishQueue::getByteSize() const | |
| 37 | 60 | { |
| 38 | 61 | return qosQueueBytes; |
| 39 | 62 | } |
| 40 | 63 | |
| 41 | -/** | |
| 42 | - * @brief QoSPacketQueue::queuePacket makes a copy of the packet because it has state for the receiver in question. | |
| 43 | - * @param p | |
| 44 | - * @param id | |
| 45 | - * @return the packet copy. | |
| 46 | - */ | |
| 47 | -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos) | |
| 64 | +void QoSPublishQueue::queuePublish(PublishCopyFactory ©Factory, uint16_t id, char new_max_qos) | |
| 48 | 65 | { |
| 49 | - assert(p.getQos() > 0); | |
| 66 | + assert(new_max_qos > 0); | |
| 67 | + assert(id > 0); | |
| 50 | 68 | |
| 51 | - std::shared_ptr<MqttPacket> copyPacket = p.getCopy(new_max_qos); | |
| 52 | - copyPacket->setPacketId(id); | |
| 53 | - queue.push_back(copyPacket); | |
| 54 | - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | |
| 55 | - return copyPacket; | |
| 69 | + Publish pub = copyFactory.getPublish(); | |
| 70 | + queue.emplace_back(std::move(pub), id); | |
| 71 | + qosQueueBytes += queue.back().getApproximateMemoryFootprint(); | |
| 56 | 72 | } |
| 57 | 73 | |
| 58 | -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const Publish &pub, uint16_t id) | |
| 74 | +void QoSPublishQueue::queuePublish(Publish &&pub, uint16_t id) | |
| 59 | 75 | { |
| 60 | - assert(pub.qos > 0); | |
| 76 | + assert(id > 0); | |
| 61 | 77 | |
| 62 | - std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(pub)); | |
| 63 | - copyPacket->setPacketId(id); | |
| 64 | - queue.push_back(copyPacket); | |
| 65 | - qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | |
| 66 | - return copyPacket; | |
| 78 | + queue.emplace_back(std::move(pub), id); | |
| 79 | + qosQueueBytes += queue.back().getApproximateMemoryFootprint(); | |
| 67 | 80 | } |
| 68 | 81 | |
| 69 | -std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::begin() const | |
| 82 | +std::list<QueuedPublish>::const_iterator QoSPublishQueue::begin() const | |
| 70 | 83 | { |
| 71 | 84 | return queue.cbegin(); |
| 72 | 85 | } |
| 73 | 86 | |
| 74 | -std::list<std::shared_ptr<MqttPacket>>::const_iterator QoSPacketQueue::end() const | |
| 87 | +std::list<QueuedPublish>::const_iterator QoSPublishQueue::end() const | |
| 75 | 88 | { |
| 76 | 89 | return queue.cend(); |
| 77 | 90 | } | ... | ... |
qospacketqueue.h
| ... | ... | @@ -5,21 +5,39 @@ |
| 5 | 5 | |
| 6 | 6 | #include "forward_declarations.h" |
| 7 | 7 | #include "types.h" |
| 8 | +#include "publishcopyfactory.h" | |
| 8 | 9 | |
| 9 | -class QoSPacketQueue | |
| 10 | +/** | |
| 11 | + * @brief The QueuedPublish class wraps the publish with a packet id. | |
| 12 | + * | |
| 13 | + * We don't want to store the packet id in the Publish object, because the packet id is determined/tracked per client/session. | |
| 14 | + */ | |
| 15 | +class QueuedPublish | |
| 10 | 16 | { |
| 11 | - std::list<std::shared_ptr<MqttPacket>> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] | |
| 17 | + Publish publish; | |
| 18 | + uint16_t packet_id = 0; | |
| 19 | +public: | |
| 20 | + QueuedPublish(Publish &&publish, uint16_t packet_id); | |
| 21 | + | |
| 22 | + size_t getApproximateMemoryFootprint() const; | |
| 23 | + uint16_t getPacketId() const; | |
| 24 | + const Publish &getPublish() const; | |
| 25 | +}; | |
| 26 | + | |
| 27 | +class QoSPublishQueue | |
| 28 | +{ | |
| 29 | + std::list<QueuedPublish> queue; // Using list because it's easiest to maintain order [MQTT-4.6.0-6] | |
| 12 | 30 | ssize_t qosQueueBytes = 0; |
| 13 | 31 | |
| 14 | 32 | public: |
| 15 | 33 | void erase(const uint16_t packet_id); |
| 16 | 34 | size_t size() const; |
| 17 | 35 | size_t getByteSize() const; |
| 18 | - std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos); | |
| 19 | - std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id); | |
| 36 | + void queuePublish(PublishCopyFactory ©Factory, uint16_t id, char new_max_qos); | |
| 37 | + void queuePublish(Publish &&pub, uint16_t id); | |
| 20 | 38 | |
| 21 | - std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const; | |
| 22 | - std::list<std::shared_ptr<MqttPacket>>::const_iterator end() const; | |
| 39 | + std::list<QueuedPublish>::const_iterator begin() const; | |
| 40 | + std::list<QueuedPublish>::const_iterator end() const; | |
| 23 | 41 | }; |
| 24 | 42 | |
| 25 | 43 | #endif // QOSPACKETQUEUE_H | ... | ... |
session.cpp
| ... | ... | @@ -101,8 +101,7 @@ Session::Session(const Session &other) |
| 101 | 101 | this->nextPacketId = other.nextPacketId; |
| 102 | 102 | this->lastTouched = other.lastTouched; |
| 103 | 103 | |
| 104 | - // To be fully correct, we should copy the individual packets, but copying sessions is only done for saving them, and I know | |
| 105 | - // that no member of MqttPacket changes in the QoS process, so we can just keep the shared pointer to the original. | |
| 104 | + // TODO: see git history for a change here. We now copy the whole queued publish. Do we want to address that? | |
| 106 | 105 | this->qosPacketQueue = other.qosPacketQueue; |
| 107 | 106 | } |
| 108 | 107 | |
| ... | ... | @@ -145,33 +144,25 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) |
| 145 | 144 | * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. |
| 146 | 145 | * @param count. Reference value is updated. It's for statistics. |
| 147 | 146 | */ |
| 148 | -void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count) | |
| 147 | +void Session::writePacket(PublishCopyFactory ©Factory, const char max_qos, uint64_t &count) | |
| 149 | 148 | { |
| 150 | 149 | assert(max_qos <= 2); |
| 151 | - const char effectiveQos = std::min<char>(packet.getQos(), max_qos); | |
| 150 | + | |
| 151 | + const char effectiveQos = copyFactory.getEffectiveQos(max_qos); | |
| 152 | 152 | |
| 153 | 153 | const Settings *settings = ThreadGlobals::getSettings(); |
| 154 | 154 | |
| 155 | 155 | Authentication *_auth = ThreadGlobals::getAuth(); |
| 156 | 156 | assert(_auth); |
| 157 | 157 | Authentication &auth = *_auth; |
| 158 | - if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success) | |
| 158 | + if (auth.aclCheck(client_id, username, copyFactory.getTopic(), copyFactory.getSubtopics(), AclAccess::read, effectiveQos, copyFactory.getRetain()) == AuthResult::success) | |
| 159 | 159 | { |
| 160 | 160 | std::shared_ptr<Client> c = makeSharedClient(); |
| 161 | 161 | if (effectiveQos == 0) |
| 162 | 162 | { |
| 163 | 163 | if (c) |
| 164 | 164 | { |
| 165 | - const MqttPacket *packetToSend = &packet; | |
| 166 | - | |
| 167 | - if (max_qos < packet.getQos()) | |
| 168 | - { | |
| 169 | - if (!downgradedQos0PacketCopy) | |
| 170 | - downgradedQos0PacketCopy = packet.getCopy(max_qos); | |
| 171 | - packetToSend = downgradedQos0PacketCopy.get(); | |
| 172 | - } | |
| 173 | - | |
| 174 | - count += c->writeMqttPacketAndBlameThisClient(*packetToSend); | |
| 165 | + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, 0); | |
| 175 | 166 | } |
| 176 | 167 | } |
| 177 | 168 | else if (effectiveQos > 0) |
| ... | ... | @@ -195,12 +186,12 @@ void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<Mqtt |
| 195 | 186 | } |
| 196 | 187 | |
| 197 | 188 | increasePacketId(); |
| 198 | - std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId, effectiveQos); | |
| 189 | + | |
| 190 | + qosPacketQueue.queuePublish(copyFactory, nextPacketId, effectiveQos); | |
| 199 | 191 | |
| 200 | 192 | if (c) |
| 201 | 193 | { |
| 202 | - count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); | |
| 203 | - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 194 | + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, nextPacketId); | |
| 204 | 195 | } |
| 205 | 196 | } |
| 206 | 197 | else |
| ... | ... | @@ -224,14 +215,9 @@ void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<Mqtt |
| 224 | 215 | |
| 225 | 216 | increasePacketId(); |
| 226 | 217 | |
| 227 | - // This changes the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere, | |
| 228 | - // that should be fine. | |
| 229 | - packet.setPacketId(nextPacketId); | |
| 230 | - packet.setQos(effectiveQos); | |
| 231 | - | |
| 232 | 218 | qosInFlightCounter++; |
| 233 | 219 | assert(c); // with requiresRetransmission==false, there must be a client. |
| 234 | - count += c->writeMqttPacketAndBlameThisClient(packet); | |
| 220 | + count += c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, nextPacketId); | |
| 235 | 221 | } |
| 236 | 222 | } |
| 237 | 223 | } |
| ... | ... | @@ -267,10 +253,12 @@ uint64_t Session::sendPendingQosMessages() |
| 267 | 253 | if (c) |
| 268 | 254 | { |
| 269 | 255 | std::lock_guard<std::mutex> locker(qosQueueMutex); |
| 270 | - for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) | |
| 256 | + for (const QueuedPublish &queuedPublish : qosPacketQueue) | |
| 271 | 257 | { |
| 272 | - count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get()); | |
| 273 | - qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 258 | + MqttPacket p(queuedPublish.getPublish()); | |
| 259 | + p.setDuplicate(); | |
| 260 | + | |
| 261 | + count += c->writeMqttPacketAndBlameThisClient(p); | |
| 274 | 262 | } |
| 275 | 263 | |
| 276 | 264 | for (const uint16_t packet_id : outgoingQoS2MessageIds) | ... | ... |
session.h
| ... | ... | @@ -27,6 +27,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 27 | 27 | #include "logger.h" |
| 28 | 28 | #include "sessionsandsubscriptionsdb.h" |
| 29 | 29 | #include "qospacketqueue.h" |
| 30 | +#include "publishcopyfactory.h" | |
| 30 | 31 | |
| 31 | 32 | class Session |
| 32 | 33 | { |
| ... | ... | @@ -39,7 +40,7 @@ class Session |
| 39 | 40 | std::weak_ptr<Client> client; |
| 40 | 41 | std::string client_id; |
| 41 | 42 | std::string username; |
| 42 | - QoSPacketQueue qosPacketQueue; | |
| 43 | + QoSPublishQueue qosPacketQueue; | |
| 43 | 44 | std::set<uint16_t> incomingQoS2MessageIds; |
| 44 | 45 | std::set<uint16_t> outgoingQoS2MessageIds; |
| 45 | 46 | std::mutex qosQueueMutex; |
| ... | ... | @@ -69,7 +70,7 @@ public: |
| 69 | 70 | const std::string &getClientId() const { return client_id; } |
| 70 | 71 | std::shared_ptr<Client> makeSharedClient() const; |
| 71 | 72 | void assignActiveConnection(std::shared_ptr<Client> &client); |
| 72 | - void writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count); | |
| 73 | + void writePacket(PublishCopyFactory ©Factory, const char max_qos, uint64_t &count); | |
| 73 | 74 | void clearQosMessage(uint16_t packet_id); |
| 74 | 75 | uint64_t sendPendingQosMessages(); |
| 75 | 76 | void touch(std::chrono::time_point<std::chrono::steady_clock> val); | ... | ... |
sessionsandsubscriptionsdb.cpp
| ... | ... | @@ -116,7 +116,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV1() |
| 116 | 116 | |
| 117 | 117 | Publish pub(topic, payload, qos); |
| 118 | 118 | logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); |
| 119 | - ses->qosPacketQueue.queuePacket(pub, id); | |
| 119 | + ses->qosPacketQueue.queuePublish(std::move(pub), id); | |
| 120 | 120 | } |
| 121 | 121 | |
| 122 | 122 | const uint32_t nrOfIncomingPacketIds = readUint32(eofFound); |
| ... | ... | @@ -215,23 +215,24 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector<std::unique_ptr<Sess |
| 215 | 215 | size_t qosPacketsCounted = 0; |
| 216 | 216 | writeUint32(qosPacketsExpected); |
| 217 | 217 | |
| 218 | - for (const std::shared_ptr<MqttPacket> &p: ses->qosPacketQueue) | |
| 218 | + for (const QueuedPublish &p: ses->qosPacketQueue) | |
| 219 | 219 | { |
| 220 | - logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", p->getQos(), p->getTopic().c_str()); | |
| 220 | + const Publish &pub = p.getPublish(); | |
| 221 | + | |
| 222 | + logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); | |
| 221 | 223 | |
| 222 | 224 | qosPacketsCounted++; |
| 223 | 225 | |
| 224 | - writeUint16(p->getPacketId()); | |
| 226 | + writeUint16(p.getPacketId()); | |
| 225 | 227 | |
| 226 | - writeUint32(p->getTopic().length()); | |
| 227 | - std::string payload = p->getPayloadCopy(); | |
| 228 | - writeUint32(payload.size()); | |
| 228 | + writeUint32(pub.topic.length()); | |
| 229 | + writeUint32(pub.payload.size()); | |
| 229 | 230 | |
| 230 | - const char qos = p->getQos(); | |
| 231 | + const char qos = pub.qos; | |
| 231 | 232 | writeCheck(&qos, 1, 1, f); |
| 232 | 233 | |
| 233 | - writeCheck(p->getTopic().c_str(), 1, p->getTopic().length(), f); | |
| 234 | - writeCheck(payload.c_str(), 1, payload.length(), f); | |
| 234 | + writeCheck(pub.topic.c_str(), 1, pub.topic.length(), f); | |
| 235 | + writeCheck(pub.payload.c_str(), 1, pub.payload.length(), f); | |
| 235 | 236 | } |
| 236 | 237 | |
| 237 | 238 | assert(qosPacketsExpected == qosPacketsCounted); | ... | ... |
subscriptionstore.cpp
| ... | ... | @@ -21,6 +21,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. |
| 21 | 21 | |
| 22 | 22 | #include "rwlockguard.h" |
| 23 | 23 | #include "retainedmessagesdb.h" |
| 24 | +#include "publishcopyfactory.h" | |
| 24 | 25 | |
| 25 | 26 | ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr<Session> &ses, char qos) : |
| 26 | 27 | session(ses), |
| ... | ... | @@ -346,10 +347,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> |
| 346 | 347 | publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions); |
| 347 | 348 | } |
| 348 | 349 | |
| 349 | - std::shared_ptr<MqttPacket> possibleQos0Copy; | |
| 350 | + PublishCopyFactory copyFactory(packet); | |
| 350 | 351 | for(const ReceivingSubscriber &x : subscriberSessions) |
| 351 | 352 | { |
| 352 | - x.session->writePacket(packet, x.qos, possibleQos0Copy, count); | |
| 353 | + x.session->writePacket(copyFactory, x.qos, count); | |
| 353 | 354 | } |
| 354 | 355 | |
| 355 | 356 | std::shared_ptr<Client> sender = packet.getSender(); |
| ... | ... | @@ -425,8 +426,8 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr<Ses |
| 425 | 426 | |
| 426 | 427 | for(MqttPacket &packet : packetList) |
| 427 | 428 | { |
| 428 | - std::shared_ptr<MqttPacket> possibleQos0Copy; | |
| 429 | - ses->writePacket(packet, max_qos, possibleQos0Copy, count); | |
| 429 | + PublishCopyFactory copyFactory(packet); | |
| 430 | + ses->writePacket(copyFactory, max_qos, count); | |
| 430 | 431 | } |
| 431 | 432 | |
| 432 | 433 | return count; | ... | ... |
types.cpp
| ... | ... | @@ -46,7 +46,7 @@ size_t SubAck::getLengthWithoutFixedHeader() const |
| 46 | 46 | return result; |
| 47 | 47 | } |
| 48 | 48 | |
| 49 | -Publish::Publish(const std::string &topic, const std::string payload, char qos) : | |
| 49 | +Publish::Publish(const std::string &topic, const std::string &payload, char qos) : | |
| 50 | 50 | topic(topic), |
| 51 | 51 | payload(payload), |
| 52 | 52 | qos(qos) | ... | ... |
types.h
| ... | ... | @@ -101,7 +101,7 @@ public: |
| 101 | 101 | std::string payload; |
| 102 | 102 | char qos = 0; |
| 103 | 103 | bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] |
| 104 | - Publish(const std::string &topic, const std::string payload, char qos); | |
| 104 | + Publish(const std::string &topic, const std::string &payload, char qos); | |
| 105 | 105 | size_t getLengthWithoutFixedHeader() const; |
| 106 | 106 | }; |
| 107 | 107 | ... | ... |