Commit f67238c5b794bb4ff5d3666b052bb6c80dad3ba2
1 parent
4947d33a
Fix not downgrading QoS to subscription QoS
This entails making copies of the original packet when necessary, because QoS 0 doesn't have a packet id. I tried to keep it to an absolute minimum and do some precarious optmizations for it. There are tests though.
Showing
16 changed files
with
419 additions
and
60 deletions
FlashMQTests/tst_maintests.cpp
| @@ -102,6 +102,15 @@ private slots: | @@ -102,6 +102,15 @@ private slots: | ||
| 102 | 102 | ||
| 103 | void testSavingSessions(); | 103 | void testSavingSessions(); |
| 104 | 104 | ||
| 105 | + void testCopyPacket(); | ||
| 106 | + | ||
| 107 | + void testDowngradeQoSOnSubscribeQos2to2(); | ||
| 108 | + void testDowngradeQoSOnSubscribeQos2to1(); | ||
| 109 | + void testDowngradeQoSOnSubscribeQos2to0(); | ||
| 110 | + void testDowngradeQoSOnSubscribeQos1to1(); | ||
| 111 | + void testDowngradeQoSOnSubscribeQos1to0(); | ||
| 112 | + void testDowngradeQoSOnSubscribeQos0to0(); | ||
| 113 | + | ||
| 105 | }; | 114 | }; |
| 106 | 115 | ||
| 107 | MainTests::MainTests() | 116 | MainTests::MainTests() |
| @@ -1027,7 +1036,9 @@ void MainTests::testSavingSessions() | @@ -1027,7 +1036,9 @@ void MainTests::testSavingSessions() | ||
| 1027 | 1036 | ||
| 1028 | std::shared_ptr<Session> c1ses = c1->getSession(); | 1037 | std::shared_ptr<Session> c1ses = c1->getSession(); |
| 1029 | c1.reset(); | 1038 | c1.reset(); |
| 1030 | - c1ses->writePacket(publish, 1, false, count); | 1039 | + MqttPacket publishPacket(publish); |
| 1040 | + std::shared_ptr<MqttPacket> possibleQos0Copy; | ||
| 1041 | + c1ses->writePacket(publishPacket, 1, possibleQos0Copy, count); | ||
| 1031 | 1042 | ||
| 1032 | store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); | 1043 | store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); |
| 1033 | 1044 | ||
| @@ -1091,6 +1102,145 @@ void MainTests::testSavingSessions() | @@ -1091,6 +1102,145 @@ void MainTests::testSavingSessions() | ||
| 1091 | } | 1102 | } |
| 1092 | } | 1103 | } |
| 1093 | 1104 | ||
| 1105 | +void testCopyPacketHelper(const std::string &topic, char from_qos, char to_qos, bool retain) | ||
| 1106 | +{ | ||
| 1107 | + assert(to_qos <= from_qos); | ||
| 1108 | + | ||
| 1109 | + Logger::getInstance()->setFlags(false, false, true); | ||
| 1110 | + | ||
| 1111 | + std::shared_ptr<Settings> settings(new Settings()); | ||
| 1112 | + settings->logDebug = false; | ||
| 1113 | + std::shared_ptr<SubscriptionStore> store(new SubscriptionStore()); | ||
| 1114 | + std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings)); | ||
| 1115 | + | ||
| 1116 | + // Kind of a hack... | ||
| 1117 | + Authentication auth(*settings.get()); | ||
| 1118 | + ThreadAuth::assign(&auth); | ||
| 1119 | + | ||
| 1120 | + std::shared_ptr<Client> dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false)); | ||
| 1121 | + dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60, false); | ||
| 1122 | + store->registerClientAndKickExistingOne(dummyClient); | ||
| 1123 | + | ||
| 1124 | + uint16_t packetid = 66; | ||
| 1125 | + for (int len = 0; len < 150; len++ ) | ||
| 1126 | + { | ||
| 1127 | + const uint16_t pack_id = packetid++; | ||
| 1128 | + | ||
| 1129 | + std::vector<MqttPacket> parsedPackets; | ||
| 1130 | + | ||
| 1131 | + const std::string payloadOne = getSecureRandomString(len); | ||
| 1132 | + Publish pubOne(topic, payloadOne, from_qos); | ||
| 1133 | + pubOne.retain = retain; | ||
| 1134 | + MqttPacket stagingPacketOne(pubOne); | ||
| 1135 | + if (from_qos > 0) | ||
| 1136 | + stagingPacketOne.setPacketId(pack_id); | ||
| 1137 | + CirBuf stagingBufOne(1024); | ||
| 1138 | + stagingPacketOne.readIntoBuf(stagingBufOne); | ||
| 1139 | + | ||
| 1140 | + MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient); | ||
| 1141 | + QVERIFY(parsedPackets.size() == 1); | ||
| 1142 | + MqttPacket parsedPacketOne = std::move(parsedPackets.front()); | ||
| 1143 | + parsedPacketOne.handlePublish(); | ||
| 1144 | + if (retain) // A normal handled packet always has retain=0, so I force setting it here. | ||
| 1145 | + parsedPacketOne.setRetain(); | ||
| 1146 | + QCOMPARE(stagingPacketOne.getTopic(), parsedPacketOne.getTopic()); | ||
| 1147 | + QCOMPARE(stagingPacketOne.getPayloadCopy(), parsedPacketOne.getPayloadCopy()); | ||
| 1148 | + | ||
| 1149 | + std::shared_ptr<MqttPacket> copiedPacketOne = parsedPacketOne.getCopy(to_qos); | ||
| 1150 | + | ||
| 1151 | + QCOMPARE(payloadOne, copiedPacketOne->getPayloadCopy()); | ||
| 1152 | + | ||
| 1153 | + // Now compare the written buffer of our copied packet to one that was written with our known good reference packet. | ||
| 1154 | + | ||
| 1155 | + Publish pubReference(topic, payloadOne, to_qos); | ||
| 1156 | + pubReference.retain = retain; | ||
| 1157 | + MqttPacket packetReference(pubReference); | ||
| 1158 | + if (to_qos > 0) | ||
| 1159 | + packetReference.setPacketId(pack_id); | ||
| 1160 | + CirBuf bufOfReference(1024); | ||
| 1161 | + CirBuf bufOfCopied(1024); | ||
| 1162 | + packetReference.readIntoBuf(bufOfReference); | ||
| 1163 | + copiedPacketOne->readIntoBuf(bufOfCopied); | ||
| 1164 | + QVERIFY2(bufOfCopied == bufOfReference, formatString("Failure on length %d for topic %s, from qos %d to qos %d, retain: %d.", | ||
| 1165 | + len, topic.c_str(), from_qos, to_qos, retain).c_str()); | ||
| 1166 | + } | ||
| 1167 | +} | ||
| 1168 | + | ||
| 1169 | +/** | ||
| 1170 | + * @brief MainTests::testCopyPacket tests the actual bytes of a copied that would be written to a client. | ||
| 1171 | + * | ||
| 1172 | + * This is specifically to test the optimiziations in getCopy(). It indirectly also tests packet parsing. | ||
| 1173 | + */ | ||
| 1174 | +void MainTests::testCopyPacket() | ||
| 1175 | +{ | ||
| 1176 | + for (int retain = 0; retain < 2; retain++) | ||
| 1177 | + { | ||
| 1178 | + testCopyPacketHelper("John/McLane", 0, 0, retain); | ||
| 1179 | + testCopyPacketHelper("Ben/Sisko", 1, 1, retain); | ||
| 1180 | + testCopyPacketHelper("Rebecca/Bunch", 2, 2, retain); | ||
| 1181 | + | ||
| 1182 | + testCopyPacketHelper("Buffy/Slayer", 1, 0, retain); | ||
| 1183 | + testCopyPacketHelper("Sarah/Connor", 2, 0, retain); | ||
| 1184 | + testCopyPacketHelper("Susan/Mayer", 2, 1, retain); | ||
| 1185 | + } | ||
| 1186 | +} | ||
| 1187 | + | ||
| 1188 | +void testDowngradeQoSOnSubscribeHelper(const char pub_qos, const char sub_qos) | ||
| 1189 | +{ | ||
| 1190 | + TwoClientTestContext testContext; | ||
| 1191 | + | ||
| 1192 | + const QString topic("Star/Trek"); | ||
| 1193 | + const QByteArray payload("Captain Kirk"); | ||
| 1194 | + | ||
| 1195 | + testContext.connectSender(); | ||
| 1196 | + testContext.connectReceiver(); | ||
| 1197 | + | ||
| 1198 | + testContext.subscribeReceiver(topic, sub_qos); | ||
| 1199 | + testContext.publish(topic, payload, pub_qos, false); | ||
| 1200 | + | ||
| 1201 | + testContext.waitReceiverReceived(1); | ||
| 1202 | + | ||
| 1203 | + QCOMPARE(testContext.receivedMessages.length(), 1); | ||
| 1204 | + QMQTT::Message &recv = testContext.receivedMessages.first(); | ||
| 1205 | + | ||
| 1206 | + const char expected_qos = std::min<const char>(pub_qos, sub_qos); | ||
| 1207 | + QVERIFY2(recv.qos() == expected_qos, formatString("Failure: received QoS is %d. Published is %d. Subscribed as %d. Expected QoS is %d", | ||
| 1208 | + recv.qos(), pub_qos, sub_qos, expected_qos).c_str()); | ||
| 1209 | + QVERIFY(recv.topic() == topic); | ||
| 1210 | + QVERIFY(recv.payload() == payload); | ||
| 1211 | +} | ||
| 1212 | + | ||
| 1213 | +void MainTests::testDowngradeQoSOnSubscribeQos2to2() | ||
| 1214 | +{ | ||
| 1215 | + testDowngradeQoSOnSubscribeHelper(2, 2); | ||
| 1216 | +} | ||
| 1217 | + | ||
| 1218 | +void MainTests::testDowngradeQoSOnSubscribeQos2to1() | ||
| 1219 | +{ | ||
| 1220 | + testDowngradeQoSOnSubscribeHelper(2, 1); | ||
| 1221 | +} | ||
| 1222 | + | ||
| 1223 | +void MainTests::testDowngradeQoSOnSubscribeQos2to0() | ||
| 1224 | +{ | ||
| 1225 | + testDowngradeQoSOnSubscribeHelper(2, 0); | ||
| 1226 | +} | ||
| 1227 | + | ||
| 1228 | +void MainTests::testDowngradeQoSOnSubscribeQos1to1() | ||
| 1229 | +{ | ||
| 1230 | + testDowngradeQoSOnSubscribeHelper(1, 1); | ||
| 1231 | +} | ||
| 1232 | + | ||
| 1233 | +void MainTests::testDowngradeQoSOnSubscribeQos1to0() | ||
| 1234 | +{ | ||
| 1235 | + testDowngradeQoSOnSubscribeHelper(1, 0); | ||
| 1236 | +} | ||
| 1237 | + | ||
| 1238 | +void MainTests::testDowngradeQoSOnSubscribeQos0to0() | ||
| 1239 | +{ | ||
| 1240 | + testDowngradeQoSOnSubscribeHelper(0, 0); | ||
| 1241 | +} | ||
| 1242 | + | ||
| 1243 | + | ||
| 1094 | int main(int argc, char *argv[]) | 1244 | int main(int argc, char *argv[]) |
| 1095 | { | 1245 | { |
| 1096 | QCoreApplication app(argc, argv); | 1246 | QCoreApplication app(argc, argv); |
FlashMQTests/twoclienttestcontext.cpp
| @@ -35,12 +35,22 @@ TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent) | @@ -35,12 +35,22 @@ TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent) | ||
| 35 | connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); | 35 | connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); |
| 36 | } | 36 | } |
| 37 | 37 | ||
| 38 | +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload) | ||
| 39 | +{ | ||
| 40 | + publish(topic, payload, 0, false); | ||
| 41 | +} | ||
| 42 | + | ||
| 38 | void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, bool retain) | 43 | void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, bool retain) |
| 39 | { | 44 | { |
| 45 | + publish(topic, payload, 0, retain); | ||
| 46 | +} | ||
| 47 | + | ||
| 48 | +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain) | ||
| 49 | +{ | ||
| 40 | QMQTT::Message msg; | 50 | QMQTT::Message msg; |
| 41 | msg.setTopic(topic); | 51 | msg.setTopic(topic); |
| 42 | msg.setRetain(retain); | 52 | msg.setRetain(retain); |
| 43 | - msg.setQos(0); | 53 | + msg.setQos(qos); |
| 44 | msg.setPayload(payload); | 54 | msg.setPayload(payload); |
| 45 | sender->publish(msg); | 55 | sender->publish(msg); |
| 46 | } | 56 | } |
| @@ -71,9 +81,9 @@ void TwoClientTestContext::disconnectReceiver() | @@ -71,9 +81,9 @@ void TwoClientTestContext::disconnectReceiver() | ||
| 71 | waiter.exec(); | 81 | waiter.exec(); |
| 72 | } | 82 | } |
| 73 | 83 | ||
| 74 | -void TwoClientTestContext::subscribeReceiver(const QString &topic) | 84 | +void TwoClientTestContext::subscribeReceiver(const QString &topic, const quint8 qos) |
| 75 | { | 85 | { |
| 76 | - receiver->subscribe(topic); | 86 | + receiver->subscribe(topic, qos); |
| 77 | 87 | ||
| 78 | QEventLoop waiter; | 88 | QEventLoop waiter; |
| 79 | QTimer timeout; | 89 | QTimer timeout; |
FlashMQTests/twoclienttestcontext.h
| @@ -34,11 +34,13 @@ private slots: | @@ -34,11 +34,13 @@ private slots: | ||
| 34 | 34 | ||
| 35 | public: | 35 | public: |
| 36 | explicit TwoClientTestContext(QObject *parent = nullptr); | 36 | explicit TwoClientTestContext(QObject *parent = nullptr); |
| 37 | - void publish(const QString &topic, const QByteArray &payload, bool retain = false); | 37 | + void publish(const QString &topic, const QByteArray &payload); |
| 38 | + void publish(const QString &topic, const QByteArray &payload, bool retain); | ||
| 39 | + void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain); | ||
| 38 | void connectSender(); | 40 | void connectSender(); |
| 39 | void connectReceiver(); | 41 | void connectReceiver(); |
| 40 | void disconnectReceiver(); | 42 | void disconnectReceiver(); |
| 41 | - void subscribeReceiver(const QString &topic); | 43 | + void subscribeReceiver(const QString &topic, const quint8 qos = 0); |
| 42 | void waitReceiverReceived(int count); | 44 | void waitReceiverReceived(int count); |
| 43 | void onClientError(const QMQTT::ClientError error); | 45 | void onClientError(const QMQTT::ClientError error); |
| 44 | 46 |
cirbuf.cpp
| @@ -252,3 +252,22 @@ void CirBuf::read(void *buf, const size_t count) | @@ -252,3 +252,22 @@ void CirBuf::read(void *buf, const size_t count) | ||
| 252 | assert(_packet_len == 0); | 252 | assert(_packet_len == 0); |
| 253 | assert(i == static_cast<int>(count)); | 253 | assert(i == static_cast<int>(count)); |
| 254 | } | 254 | } |
| 255 | + | ||
| 256 | +/** | ||
| 257 | + * @brief CirBuf::operator == simplistic comparision. It doesn't take the fact that it's circular into account. | ||
| 258 | + * @param other | ||
| 259 | + * @return | ||
| 260 | + * | ||
| 261 | + * It was created for unit testing. read() and write() are non-const, so taking the circular properties into account | ||
| 262 | + * would need more/duplicate code that I don't need at this point. | ||
| 263 | + */ | ||
| 264 | +bool CirBuf::operator==(const CirBuf &other) const | ||
| 265 | +{ | ||
| 266 | +#ifdef NDEBUG | ||
| 267 | + throw std::exception(); // you can't use it in release builds, because new buffers aren't zeroed. | ||
| 268 | +#endif | ||
| 269 | + | ||
| 270 | + return tail == 0 && other.tail == 0 | ||
| 271 | + && usedBytes() == other.usedBytes() | ||
| 272 | + && std::memcmp(buf, other.buf, size) == 0; | ||
| 273 | +} |
cirbuf.h
| @@ -60,6 +60,8 @@ public: | @@ -60,6 +60,8 @@ public: | ||
| 60 | 60 | ||
| 61 | void write(const void *buf, size_t count); | 61 | void write(const void *buf, size_t count); |
| 62 | void read(void *buf, const size_t count); | 62 | void read(void *buf, const size_t count); |
| 63 | + | ||
| 64 | + bool operator==(const CirBuf &other) const; | ||
| 63 | }; | 65 | }; |
| 64 | 66 | ||
| 65 | #endif // CIRBUF_H | 67 | #endif // CIRBUF_H |
client.cpp
| @@ -56,7 +56,7 @@ Client::~Client() | @@ -56,7 +56,7 @@ Client::~Client() | ||
| 56 | { | 56 | { |
| 57 | Publish will(will_topic, will_payload, will_qos); | 57 | Publish will(will_topic, will_payload, will_qos); |
| 58 | will.retain = will_retain; | 58 | will.retain = will_retain; |
| 59 | - const MqttPacket willPacket(will); | 59 | + MqttPacket willPacket(will); |
| 60 | 60 | ||
| 61 | const std::vector<std::string> subtopics = splitToVector(will_topic, '/'); | 61 | const std::vector<std::string> subtopics = splitToVector(will_topic, '/'); |
| 62 | store->queuePacketAtSubscribers(subtopics, willPacket); | 62 | store->queuePacketAtSubscribers(subtopics, willPacket); |
| @@ -180,7 +180,7 @@ void Client::writeText(const std::string &text) | @@ -180,7 +180,7 @@ void Client::writeText(const std::string &text) | ||
| 180 | setReadyForWriting(true); | 180 | setReadyForWriting(true); |
| 181 | } | 181 | } |
| 182 | 182 | ||
| 183 | -int Client::writeMqttPacket(const MqttPacket &packet, const char qos) | 183 | +int Client::writeMqttPacket(const MqttPacket &packet) |
| 184 | { | 184 | { |
| 185 | std::lock_guard<std::mutex> locker(writeBufMutex); | 185 | std::lock_guard<std::mutex> locker(writeBufMutex); |
| 186 | 186 | ||
| @@ -193,7 +193,7 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos) | @@ -193,7 +193,7 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos) | ||
| 193 | 193 | ||
| 194 | // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And | 194 | // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And |
| 195 | // QoS packet are queued and limited elsewhere. | 195 | // QoS packet are queued and limited elsewhere. |
| 196 | - if (packet.packetType == PacketType::PUBLISH && qos == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) | 196 | + if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) |
| 197 | { | 197 | { |
| 198 | return 0; | 198 | return 0; |
| 199 | } | 199 | } |
| @@ -208,11 +208,11 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos) | @@ -208,11 +208,11 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos) | ||
| 208 | } | 208 | } |
| 209 | 209 | ||
| 210 | // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected. | 210 | // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected. |
| 211 | -int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos) | 211 | +int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) |
| 212 | { | 212 | { |
| 213 | try | 213 | try |
| 214 | { | 214 | { |
| 215 | - return this->writeMqttPacket(packet, qos); | 215 | + return this->writeMqttPacket(packet); |
| 216 | } | 216 | } |
| 217 | catch (std::exception &ex) | 217 | catch (std::exception &ex) |
| 218 | { | 218 | { |
client.h
| @@ -121,8 +121,8 @@ public: | @@ -121,8 +121,8 @@ public: | ||
| 121 | 121 | ||
| 122 | void writeText(const std::string &text); | 122 | void writeText(const std::string &text); |
| 123 | void writePingResp(); | 123 | void writePingResp(); |
| 124 | - int writeMqttPacket(const MqttPacket &packet, const char qos = 0); | ||
| 125 | - int writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos); | 124 | + int writeMqttPacket(const MqttPacket &packet); |
| 125 | + int writeMqttPacketAndBlameThisClient(const MqttPacket &packet); | ||
| 126 | bool writeBufIntoFd(); | 126 | bool writeBufIntoFd(); |
| 127 | bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } | 127 | bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } |
| 128 | 128 |
mqttpacket.cpp
| @@ -48,11 +48,69 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt | @@ -48,11 +48,69 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt | ||
| 48 | * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor | 48 | * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor |
| 49 | * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add. | 49 | * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add. |
| 50 | * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. | 50 | * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. |
| 51 | + * | ||
| 52 | + * The idea is that because a packet with QoS is longer than one without, we just copy as much as possible if both packets have the same QoS. | ||
| 53 | + * | ||
| 54 | + * Note that there can be two types of packets: one with the fixed header (including remaining length), and one without. The latter we could be | ||
| 55 | + * more clever about, but I'm forgoing that right now. Their use is mostly for retained messages. | ||
| 56 | + * | ||
| 57 | + * Also note that some fields are undeterminstic in the copy: dup, retain and packetid for instance. Sometimes they come from the original, | ||
| 58 | + * sometimes not. The current planned usage is that those fields will either ONLY or NEVER be used in the copy, so it doesn't matter what I do | ||
| 59 | + * with them here. I may reconsider. | ||
| 51 | */ | 60 | */ |
| 52 | -std::shared_ptr<MqttPacket> MqttPacket::getCopy() const | 61 | +std::shared_ptr<MqttPacket> MqttPacket::getCopy(char new_max_qos) const |
| 53 | { | 62 | { |
| 63 | + assert(packetType == PacketType::PUBLISH); | ||
| 64 | + | ||
| 65 | + // You're not supposed to copy a duplicate packet. The only packets that get the dup flag, should not be copied AGAIN. This | ||
| 66 | + // has to do with the Session::writePacket() and Session::sendPendingQosMessages() logic. | ||
| 67 | + assert((first_byte & 0b00001000) == 0); | ||
| 68 | + | ||
| 69 | + if (qos > 0 && new_max_qos == 0) | ||
| 70 | + { | ||
| 71 | + // if shrinking the packet doesn't alter the amount of bytes in the 'remaining length' part of the header, we can | ||
| 72 | + // just memmove+shrink the packet. This is because the packet id always is two bytes before the payload, so we just move the payload | ||
| 73 | + // over it. When testing 100M copies, it went from 21000 ms to 10000 ms. In other words, about 100 ยตs to 200 ยตs per copy. | ||
| 74 | + // There is an elaborate unit test to test this optimization. | ||
| 75 | + if ((fixed_header_length == 2 && bites.size() < 125)) | ||
| 76 | + { | ||
| 77 | + // I don't know yet if this is true, but I don't want to forget when I implemenet MQTT5. | ||
| 78 | + assert(sender && sender->getProtocolVersion() <= ProtocolVersion::Mqtt311); | ||
| 79 | + | ||
| 80 | + std::shared_ptr<MqttPacket> p(new MqttPacket(*this)); | ||
| 81 | + p->sender.reset(); | ||
| 82 | + | ||
| 83 | + if (payloadLen > 0) | ||
| 84 | + std::memmove(&p->bites[packet_id_pos], &p->bites[packet_id_pos+2], payloadLen); | ||
| 85 | + p->bites.erase(p->bites.end() - 2, p->bites.end()); | ||
| 86 | + p->packet_id_pos = 0; | ||
| 87 | + p->payloadStart -= 2; | ||
| 88 | + if (pos > p->bites.size()) // pos can possible be set elsewhere, so we only set it back if it was after the payload. | ||
| 89 | + p->pos -= 2; | ||
| 90 | + p->packet_id = 0; | ||
| 91 | + | ||
| 92 | + // Clear QoS bits from the header. | ||
| 93 | + p->first_byte &= 0b11111001; | ||
| 94 | + p->bites[0] = p->first_byte; | ||
| 95 | + | ||
| 96 | + assert((p->bites[1] & 0b10000000) == 0); // when there is an MSB, I musn't get rid of it. | ||
| 97 | + assert(p->bites[1] > 3); // There has to be a remaining value after subtracting 2. | ||
| 98 | + | ||
| 99 | + p->bites[1] -= 2; // Reduce the value in the 'remaining length' part of the header. | ||
| 100 | + | ||
| 101 | + return p; | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + Publish pub(topic, getPayloadCopy(), new_max_qos); | ||
| 105 | + pub.retain = getRetain(); | ||
| 106 | + std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(pub)); | ||
| 107 | + return copyPacket; | ||
| 108 | + } | ||
| 109 | + | ||
| 54 | std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this)); | 110 | std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this)); |
| 55 | copyPacket->sender.reset(); | 111 | copyPacket->sender.reset(); |
| 112 | + if (qos != new_max_qos) | ||
| 113 | + copyPacket->setQos(new_max_qos); | ||
| 56 | return copyPacket; | 114 | return copyPacket; |
| 57 | } | 115 | } |
| 58 | 116 | ||
| @@ -773,11 +831,11 @@ size_t MqttPacket::getTotalMemoryFootprint() | @@ -773,11 +831,11 @@ size_t MqttPacket::getTotalMemoryFootprint() | ||
| 773 | * @return | 831 | * @return |
| 774 | * | 832 | * |
| 775 | * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out | 833 | * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out |
| 776 | - * the whole byte array that is a packet to subscribers. No need to copy and such. | 834 | + * the whole byte array of an original packet to subscribers. No need to copy and such. |
| 777 | * | 835 | * |
| 778 | - * I created it for saving QoS packages in the db file. | 836 | + * But, as stated, sometimes it's necessary. |
| 779 | */ | 837 | */ |
| 780 | -std::string MqttPacket::getPayloadCopy() | 838 | +std::string MqttPacket::getPayloadCopy() const |
| 781 | { | 839 | { |
| 782 | assert(payloadStart > 0); | 840 | assert(payloadStart > 0); |
| 783 | assert(pos <= bites.size()); | 841 | assert(pos <= bites.size()); |
| @@ -798,6 +856,23 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const | @@ -798,6 +856,23 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const | ||
| 798 | return total; | 856 | return total; |
| 799 | } | 857 | } |
| 800 | 858 | ||
| 859 | +void MqttPacket::setQos(const char new_qos) | ||
| 860 | +{ | ||
| 861 | + // You can't change to a QoS level that would remove the packet identifier. | ||
| 862 | + assert((qos == 0 && new_qos == 0) || (qos > 0 && new_qos > 0)); | ||
| 863 | + assert(new_qos > 0 && packet_id_pos > 0); | ||
| 864 | + | ||
| 865 | + qos = new_qos; | ||
| 866 | + first_byte &= 0b11111001; | ||
| 867 | + first_byte |= (qos << 1); | ||
| 868 | + | ||
| 869 | + if (fixed_header_length > 0) | ||
| 870 | + { | ||
| 871 | + pos = 0; | ||
| 872 | + writeByte(first_byte); | ||
| 873 | + } | ||
| 874 | +} | ||
| 875 | + | ||
| 801 | const std::string &MqttPacket::getTopic() const | 876 | const std::string &MqttPacket::getTopic() const |
| 802 | { | 877 | { |
| 803 | return this->topic; | 878 | return this->topic; |
| @@ -886,6 +961,31 @@ size_t MqttPacket::remainingAfterPos() | @@ -886,6 +961,31 @@ size_t MqttPacket::remainingAfterPos() | ||
| 886 | return bites.size() - pos; | 961 | return bites.size() - pos; |
| 887 | } | 962 | } |
| 888 | 963 | ||
| 964 | +bool MqttPacket::getRetain() const | ||
| 965 | +{ | ||
| 966 | + return (first_byte & 0b00000001); | ||
| 967 | +} | ||
| 968 | + | ||
| 969 | +/** | ||
| 970 | + * @brief MqttPacket::setRetain set the retain bit in the first byte. I think I only need this in tests, because existing subscribers don't get retain=1, | ||
| 971 | + * so handlePublish() clears it. But I needed it to be set in testing. | ||
| 972 | + * | ||
| 973 | + * Publishing of the retained messages goes through the MqttPacket(Publish) constructor, hence this setRetain() isn't necessary for that. | ||
| 974 | + */ | ||
| 975 | +void MqttPacket::setRetain() | ||
| 976 | +{ | ||
| 977 | +#ifndef TESTING | ||
| 978 | + assert(false); | ||
| 979 | +#endif | ||
| 980 | + | ||
| 981 | + first_byte |= 0b00000001; | ||
| 982 | + | ||
| 983 | + if (fixed_header_length > 0) | ||
| 984 | + { | ||
| 985 | + pos = 0; | ||
| 986 | + writeByte(first_byte); | ||
| 987 | + } | ||
| 988 | +} | ||
| 889 | 989 | ||
| 890 | void MqttPacket::readIntoBuf(CirBuf &buf) const | 990 | void MqttPacket::readIntoBuf(CirBuf &buf) const |
| 891 | { | 991 | { |
mqttpacket.h
| @@ -80,7 +80,7 @@ public: | @@ -80,7 +80,7 @@ public: | ||
| 80 | 80 | ||
| 81 | MqttPacket(MqttPacket &&other) = default; | 81 | MqttPacket(MqttPacket &&other) = default; |
| 82 | 82 | ||
| 83 | - std::shared_ptr<MqttPacket> getCopy() const; | 83 | + std::shared_ptr<MqttPacket> getCopy(char new_max_qos) const; |
| 84 | 84 | ||
| 85 | // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. | 85 | // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. |
| 86 | MqttPacket(const ConnAck &connAck); | 86 | MqttPacket(const ConnAck &connAck); |
| @@ -109,6 +109,7 @@ public: | @@ -109,6 +109,7 @@ public: | ||
| 109 | size_t getSizeIncludingNonPresentHeader() const; | 109 | size_t getSizeIncludingNonPresentHeader() const; |
| 110 | const std::vector<char> &getBites() const { return bites; } | 110 | const std::vector<char> &getBites() const { return bites; } |
| 111 | char getQos() const { return qos; } | 111 | char getQos() const { return qos; } |
| 112 | + void setQos(const char new_qos); | ||
| 112 | const std::string &getTopic() const; | 113 | const std::string &getTopic() const; |
| 113 | const std::vector<std::string> &getSubtopics() const; | 114 | const std::vector<std::string> &getSubtopics() const; |
| 114 | std::shared_ptr<Client> getSender() const; | 115 | std::shared_ptr<Client> getSender() const; |
| @@ -120,8 +121,10 @@ public: | @@ -120,8 +121,10 @@ public: | ||
| 120 | uint16_t getPacketId() const; | 121 | uint16_t getPacketId() const; |
| 121 | void setDuplicate(); | 122 | void setDuplicate(); |
| 122 | size_t getTotalMemoryFootprint(); | 123 | size_t getTotalMemoryFootprint(); |
| 123 | - std::string getPayloadCopy(); | ||
| 124 | void readIntoBuf(CirBuf &buf) const; | 124 | void readIntoBuf(CirBuf &buf) const; |
| 125 | + std::string getPayloadCopy() const; | ||
| 126 | + bool getRetain() const; | ||
| 127 | + void setRetain(); | ||
| 125 | }; | 128 | }; |
| 126 | 129 | ||
| 127 | #endif // MQTTPACKET_H | 130 | #endif // MQTTPACKET_H |
qospacketqueue.cpp
| @@ -44,11 +44,11 @@ size_t QoSPacketQueue::getByteSize() const | @@ -44,11 +44,11 @@ size_t QoSPacketQueue::getByteSize() const | ||
| 44 | * @param id | 44 | * @param id |
| 45 | * @return the packet copy. | 45 | * @return the packet copy. |
| 46 | */ | 46 | */ |
| 47 | -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id) | 47 | +std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos) |
| 48 | { | 48 | { |
| 49 | assert(p.getQos() > 0); | 49 | assert(p.getQos() > 0); |
| 50 | 50 | ||
| 51 | - std::shared_ptr<MqttPacket> copyPacket = p.getCopy(); | 51 | + std::shared_ptr<MqttPacket> copyPacket = p.getCopy(new_max_qos); |
| 52 | copyPacket->setPacketId(id); | 52 | copyPacket->setPacketId(id); |
| 53 | queue.push_back(copyPacket); | 53 | queue.push_back(copyPacket); |
| 54 | qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | 54 | qosQueueBytes += copyPacket->getTotalMemoryFootprint(); |
qospacketqueue.h
| @@ -15,7 +15,7 @@ public: | @@ -15,7 +15,7 @@ public: | ||
| 15 | void erase(const uint16_t packet_id); | 15 | void erase(const uint16_t packet_id); |
| 16 | size_t size() const; | 16 | size_t size() const; |
| 17 | size_t getByteSize() const; | 17 | size_t getByteSize() const; |
| 18 | - std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id); | 18 | + std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos); |
| 19 | std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id); | 19 | std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id); |
| 20 | 20 | ||
| 21 | std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const; | 21 | std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const; |
session.cpp
| @@ -59,6 +59,29 @@ void Session::setSessionTouch(int64_t ageInMs) | @@ -59,6 +59,29 @@ void Session::setSessionTouch(int64_t ageInMs) | ||
| 59 | lastTouched = point; | 59 | lastTouched = point; |
| 60 | } | 60 | } |
| 61 | 61 | ||
| 62 | +bool Session::requiresPacketRetransmission() const | ||
| 63 | +{ | ||
| 64 | + const std::shared_ptr<Client> client = makeSharedClient(); | ||
| 65 | + | ||
| 66 | + if (!client) | ||
| 67 | + return true; | ||
| 68 | + | ||
| 69 | + // MQTT 3.1: "Brokers, however, should retry any unacknowledged message." | ||
| 70 | + // MQTT 3.1.1: "This [reconnecting] is the only circumstance where a Client or Server is REQUIRED to redeliver messages." | ||
| 71 | + if (client->getProtocolVersion() < ProtocolVersion::Mqtt311) | ||
| 72 | + return true; | ||
| 73 | + | ||
| 74 | + // TODO: for MQTT5, the rules are different. | ||
| 75 | + return !client->getCleanSession(); | ||
| 76 | +} | ||
| 77 | + | ||
| 78 | +void Session::increasePacketId() | ||
| 79 | +{ | ||
| 80 | + nextPacketId++; | ||
| 81 | + if (nextPacketId == 0) | ||
| 82 | + nextPacketId++; | ||
| 83 | +} | ||
| 84 | + | ||
| 62 | /** | 85 | /** |
| 63 | * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies. | 86 | * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies. |
| 64 | * @param other | 87 | * @param other |
| @@ -116,63 +139,101 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) | @@ -116,63 +139,101 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) | ||
| 116 | 139 | ||
| 117 | /** | 140 | /** |
| 118 | * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session. | 141 | * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session. |
| 119 | - * @param packet | 142 | + * @param packet is not const. We set the qos and packet id for each publish. This should be safe, because the packet |
| 143 | + * with original packet id and qos is not saved. This saves unnecessary copying. | ||
| 120 | * @param max_qos | 144 | * @param max_qos |
| 121 | * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. | 145 | * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. |
| 122 | * @param count. Reference value is updated. It's for statistics. | 146 | * @param count. Reference value is updated. It's for statistics. |
| 123 | */ | 147 | */ |
| 124 | -void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count) | 148 | +void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count) |
| 125 | { | 149 | { |
| 126 | assert(max_qos <= 2); | 150 | assert(max_qos <= 2); |
| 127 | - const char qos = std::min<char>(packet.getQos(), max_qos); | 151 | + const char effectiveQos = std::min<char>(packet.getQos(), max_qos); |
| 128 | 152 | ||
| 129 | Authentication *_auth = ThreadAuth::getAuth(); | 153 | Authentication *_auth = ThreadAuth::getAuth(); |
| 130 | assert(_auth); | 154 | assert(_auth); |
| 131 | Authentication &auth = *_auth; | 155 | Authentication &auth = *_auth; |
| 132 | - if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) | 156 | + if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success) |
| 133 | { | 157 | { |
| 134 | - if (qos == 0) | 158 | + std::shared_ptr<Client> c = makeSharedClient(); |
| 159 | + if (effectiveQos == 0) | ||
| 135 | { | 160 | { |
| 136 | - std::shared_ptr<Client> c = makeSharedClient(); | ||
| 137 | - | ||
| 138 | if (c) | 161 | if (c) |
| 139 | { | 162 | { |
| 140 | - count += c->writeMqttPacketAndBlameThisClient(packet, qos); | 163 | + const MqttPacket *packetToSend = &packet; |
| 164 | + | ||
| 165 | + if (max_qos < packet.getQos()) | ||
| 166 | + { | ||
| 167 | + if (!downgradedQos0PacketCopy) | ||
| 168 | + downgradedQos0PacketCopy = packet.getCopy(max_qos); | ||
| 169 | + packetToSend = downgradedQos0PacketCopy.get(); | ||
| 170 | + } | ||
| 171 | + | ||
| 172 | + count += c->writeMqttPacketAndBlameThisClient(*packetToSend); | ||
| 141 | } | 173 | } |
| 142 | } | 174 | } |
| 143 | - else if (qos > 0) | 175 | + else if (effectiveQos > 0) |
| 144 | { | 176 | { |
| 145 | - std::unique_lock<std::mutex> locker(qosQueueMutex); | 177 | + const bool requiresRetransmission = requiresPacketRetransmission(); |
| 146 | 178 | ||
| 147 | - const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); | ||
| 148 | - if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | 179 | + if (requiresRetransmission) |
| 149 | { | 180 | { |
| 150 | - if (QoSLogPrintedAtId != nextPacketId) | 181 | + std::unique_lock<std::mutex> locker(qosQueueMutex); |
| 182 | + | ||
| 183 | + const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); | ||
| 184 | + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | ||
| 151 | { | 185 | { |
| 152 | - logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because its QoS buffers were full.", client_id.c_str()); | ||
| 153 | - QoSLogPrintedAtId = nextPacketId; | 186 | + if (QoSLogPrintedAtId != nextPacketId) |
| 187 | + { | ||
| 188 | + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because max in-transit packet count reached.", client_id.c_str()); | ||
| 189 | + QoSLogPrintedAtId = nextPacketId; | ||
| 190 | + } | ||
| 191 | + return; | ||
| 154 | } | 192 | } |
| 155 | - return; | ||
| 156 | - } | ||
| 157 | - nextPacketId++; | ||
| 158 | - if (nextPacketId == 0) | ||
| 159 | - nextPacketId++; | ||
| 160 | 193 | ||
| 161 | - std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); | ||
| 162 | - locker.unlock(); | 194 | + increasePacketId(); |
| 195 | + std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId, effectiveQos); | ||
| 163 | 196 | ||
| 164 | - std::shared_ptr<Client> c = makeSharedClient(); | ||
| 165 | - if (c) | 197 | + if (c) |
| 198 | + { | ||
| 199 | + count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); | ||
| 200 | + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | ||
| 201 | + } | ||
| 202 | + } | ||
| 203 | + else | ||
| 166 | { | 204 | { |
| 167 | - count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos); | ||
| 168 | - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | 205 | + // We don't need to make a copy of the packet in this branch, because: |
| 206 | + // - The packet to give the client won't shrink in size because source and client have a packet_id. | ||
| 207 | + // - We don't have to store the copy in the session for retransmission, see Session::requiresPacketRetransmission() | ||
| 208 | + // So, we just keep altering the original published packet. | ||
| 209 | + | ||
| 210 | + std::unique_lock<std::mutex> locker(qosQueueMutex); | ||
| 211 | + | ||
| 212 | + if (qosInFlightCounter >= 65530) // Includes a small safety margin. | ||
| 213 | + { | ||
| 214 | + if (QoSLogPrintedAtId != nextPacketId) | ||
| 215 | + { | ||
| 216 | + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because it hasn't seen enough PUBACKs to release places.", client_id.c_str()); | ||
| 217 | + QoSLogPrintedAtId = nextPacketId; | ||
| 218 | + } | ||
| 219 | + return; | ||
| 220 | + } | ||
| 221 | + | ||
| 222 | + increasePacketId(); | ||
| 223 | + | ||
| 224 | + // This changes the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere, | ||
| 225 | + // that should be fine. | ||
| 226 | + packet.setPacketId(nextPacketId); | ||
| 227 | + packet.setQos(effectiveQos); | ||
| 228 | + | ||
| 229 | + qosInFlightCounter++; | ||
| 230 | + assert(c); // with requiresRetransmission==false, there must be a client. | ||
| 231 | + c->writeMqttPacketAndBlameThisClient(packet); | ||
| 169 | } | 232 | } |
| 170 | } | 233 | } |
| 171 | } | 234 | } |
| 172 | } | 235 | } |
| 173 | 236 | ||
| 174 | -// Normatively, this while loop will break on the first element, because all messages are sent out in order and | ||
| 175 | -// should be acked in order. | ||
| 176 | void Session::clearQosMessage(uint16_t packet_id) | 237 | void Session::clearQosMessage(uint16_t packet_id) |
| 177 | { | 238 | { |
| 178 | #ifndef NDEBUG | 239 | #ifndef NDEBUG |
| @@ -180,7 +241,13 @@ void Session::clearQosMessage(uint16_t packet_id) | @@ -180,7 +241,13 @@ void Session::clearQosMessage(uint16_t packet_id) | ||
| 180 | #endif | 241 | #endif |
| 181 | 242 | ||
| 182 | std::lock_guard<std::mutex> locker(qosQueueMutex); | 243 | std::lock_guard<std::mutex> locker(qosQueueMutex); |
| 183 | - qosPacketQueue.erase(packet_id); | 244 | + if (requiresPacketRetransmission()) |
| 245 | + qosPacketQueue.erase(packet_id); | ||
| 246 | + else | ||
| 247 | + { | ||
| 248 | + qosInFlightCounter--; | ||
| 249 | + qosInFlightCounter = std::max<int>(0, qosInFlightCounter); // Should never happen, but in case we receive too many PUBACKs. | ||
| 250 | + } | ||
| 184 | } | 251 | } |
| 185 | 252 | ||
| 186 | // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any | 253 | // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any |
| @@ -199,7 +266,7 @@ uint64_t Session::sendPendingQosMessages() | @@ -199,7 +266,7 @@ uint64_t Session::sendPendingQosMessages() | ||
| 199 | std::lock_guard<std::mutex> locker(qosQueueMutex); | 266 | std::lock_guard<std::mutex> locker(qosQueueMutex); |
| 200 | for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) | 267 | for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) |
| 201 | { | 268 | { |
| 202 | - count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos()); | 269 | + count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get()); |
| 203 | qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | 270 | qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. |
| 204 | } | 271 | } |
| 205 | 272 | ||
| @@ -207,7 +274,7 @@ uint64_t Session::sendPendingQosMessages() | @@ -207,7 +274,7 @@ uint64_t Session::sendPendingQosMessages() | ||
| 207 | { | 274 | { |
| 208 | PubRel pubRel(packet_id); | 275 | PubRel pubRel(packet_id); |
| 209 | MqttPacket packet(pubRel); | 276 | MqttPacket packet(pubRel); |
| 210 | - count += c->writeMqttPacketAndBlameThisClient(packet, 2); | 277 | + count += c->writeMqttPacketAndBlameThisClient(packet); |
| 211 | } | 278 | } |
| 212 | } | 279 | } |
| 213 | 280 |
session.h
| @@ -48,11 +48,14 @@ class Session | @@ -48,11 +48,14 @@ class Session | ||
| 48 | std::set<uint16_t> outgoingQoS2MessageIds; | 48 | std::set<uint16_t> outgoingQoS2MessageIds; |
| 49 | std::mutex qosQueueMutex; | 49 | std::mutex qosQueueMutex; |
| 50 | uint16_t nextPacketId = 0; | 50 | uint16_t nextPacketId = 0; |
| 51 | + uint16_t qosInFlightCounter = 0; | ||
| 51 | uint16_t QoSLogPrintedAtId = 0; | 52 | uint16_t QoSLogPrintedAtId = 0; |
| 52 | std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now(); | 53 | std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now(); |
| 53 | Logger *logger = Logger::getInstance(); | 54 | Logger *logger = Logger::getInstance(); |
| 54 | int64_t getSessionRelativeAgeInMs() const; | 55 | int64_t getSessionRelativeAgeInMs() const; |
| 55 | void setSessionTouch(int64_t ageInMs); | 56 | void setSessionTouch(int64_t ageInMs); |
| 57 | + bool requiresPacketRetransmission() const; | ||
| 58 | + void increasePacketId(); | ||
| 56 | 59 | ||
| 57 | Session(const Session &other); | 60 | Session(const Session &other); |
| 58 | public: | 61 | public: |
| @@ -69,7 +72,7 @@ public: | @@ -69,7 +72,7 @@ public: | ||
| 69 | const std::string &getClientId() const { return client_id; } | 72 | const std::string &getClientId() const { return client_id; } |
| 70 | std::shared_ptr<Client> makeSharedClient() const; | 73 | std::shared_ptr<Client> makeSharedClient() const; |
| 71 | void assignActiveConnection(std::shared_ptr<Client> &client); | 74 | void assignActiveConnection(std::shared_ptr<Client> &client); |
| 72 | - void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); | 75 | + void writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count); |
| 73 | void clearQosMessage(uint16_t packet_id); | 76 | void clearQosMessage(uint16_t packet_id); |
| 74 | uint64_t sendPendingQosMessages(); | 77 | uint64_t sendPendingQosMessages(); |
| 75 | void touch(std::chrono::time_point<std::chrono::steady_clock> val); | 78 | void touch(std::chrono::time_point<std::chrono::steady_clock> val); |
subscriptionstore.cpp
| @@ -331,7 +331,7 @@ void SubscriptionStore::publishRecursively(std::vector<std::string>::const_itera | @@ -331,7 +331,7 @@ void SubscriptionStore::publishRecursively(std::vector<std::string>::const_itera | ||
| 331 | } | 331 | } |
| 332 | } | 332 | } |
| 333 | 333 | ||
| 334 | -void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar) | 334 | +void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> &subtopics, MqttPacket &packet, bool dollar) |
| 335 | { | 335 | { |
| 336 | assert(subtopics.size() > 0); | 336 | assert(subtopics.size() > 0); |
| 337 | 337 | ||
| @@ -346,9 +346,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> | @@ -346,9 +346,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> | ||
| 346 | publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions); | 346 | publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions); |
| 347 | } | 347 | } |
| 348 | 348 | ||
| 349 | + std::shared_ptr<MqttPacket> possibleQos0Copy; | ||
| 349 | for(const ReceivingSubscriber &x : subscriberSessions) | 350 | for(const ReceivingSubscriber &x : subscriberSessions) |
| 350 | { | 351 | { |
| 351 | - x.session->writePacket(packet, x.qos, false, count); | 352 | + x.session->writePacket(packet, x.qos, possibleQos0Copy, count); |
| 352 | } | 353 | } |
| 353 | 354 | ||
| 354 | std::shared_ptr<Client> sender = packet.getSender(); | 355 | std::shared_ptr<Client> sender = packet.getSender(); |
| @@ -422,9 +423,10 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr<Ses | @@ -422,9 +423,10 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr<Ses | ||
| 422 | giveClientRetainedMessagesRecursively(subscribeSubtopics.begin(), subscribeSubtopics.end(), startNode, false, packetList); | 423 | giveClientRetainedMessagesRecursively(subscribeSubtopics.begin(), subscribeSubtopics.end(), startNode, false, packetList); |
| 423 | } | 424 | } |
| 424 | 425 | ||
| 425 | - for(const MqttPacket &packet : packetList) | 426 | + std::shared_ptr<MqttPacket> possibleQos0Copy; |
| 427 | + for(MqttPacket &packet : packetList) | ||
| 426 | { | 428 | { |
| 427 | - ses->writePacket(packet, max_qos, true, count); | 429 | + ses->writePacket(packet, max_qos, possibleQos0Copy, count); |
| 428 | } | 430 | } |
| 429 | 431 | ||
| 430 | return count; | 432 | return count; |
subscriptionstore.h
| @@ -120,7 +120,7 @@ public: | @@ -120,7 +120,7 @@ public: | ||
| 120 | void registerClientAndKickExistingOne(std::shared_ptr<Client> &client); | 120 | void registerClientAndKickExistingOne(std::shared_ptr<Client> &client); |
| 121 | bool sessionPresent(const std::string &clientid); | 121 | bool sessionPresent(const std::string &clientid); |
| 122 | 122 | ||
| 123 | - void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar = false); | 123 | + void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, MqttPacket &packet, bool dollar = false); |
| 124 | void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, | 124 | void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, |
| 125 | RetainedMessageNode *this_node, bool poundMode, std::forward_list<MqttPacket> &packetList) const; | 125 | RetainedMessageNode *this_node, bool poundMode, std::forward_list<MqttPacket> &packetList) const; |
| 126 | uint64_t giveClientRetainedMessages(const std::shared_ptr<Session> &ses, const std::vector<std::string> &subscribeSubtopics, char max_qos); | 126 | uint64_t giveClientRetainedMessages(const std::shared_ptr<Session> &ses, const std::vector<std::string> &subscribeSubtopics, char max_qos); |
threaddata.cpp
| @@ -127,7 +127,8 @@ void ThreadData::publishStat(const std::string &topic, uint64_t n) | @@ -127,7 +127,8 @@ void ThreadData::publishStat(const std::string &topic, uint64_t n) | ||
| 127 | splitTopic(topic, subtopics); | 127 | splitTopic(topic, subtopics); |
| 128 | const std::string payload = std::to_string(n); | 128 | const std::string payload = std::to_string(n); |
| 129 | Publish p(topic, payload, 0); | 129 | Publish p(topic, payload, 0); |
| 130 | - subscriptionStore->queuePacketAtSubscribers(subtopics, p, true); | 130 | + MqttPacket pack(p); |
| 131 | + subscriptionStore->queuePacketAtSubscribers(subtopics, pack, true); | ||
| 131 | subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); | 132 | subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); |
| 132 | } | 133 | } |
| 133 | 134 |