Commit f67238c5b794bb4ff5d3666b052bb6c80dad3ba2
1 parent
4947d33a
Fix not downgrading QoS to subscription QoS
This entails making copies of the original packet when necessary, because QoS 0 doesn't have a packet id. I tried to keep it to an absolute minimum and do some precarious optmizations for it. There are tests though.
Showing
16 changed files
with
419 additions
and
60 deletions
FlashMQTests/tst_maintests.cpp
| ... | ... | @@ -102,6 +102,15 @@ private slots: |
| 102 | 102 | |
| 103 | 103 | void testSavingSessions(); |
| 104 | 104 | |
| 105 | + void testCopyPacket(); | |
| 106 | + | |
| 107 | + void testDowngradeQoSOnSubscribeQos2to2(); | |
| 108 | + void testDowngradeQoSOnSubscribeQos2to1(); | |
| 109 | + void testDowngradeQoSOnSubscribeQos2to0(); | |
| 110 | + void testDowngradeQoSOnSubscribeQos1to1(); | |
| 111 | + void testDowngradeQoSOnSubscribeQos1to0(); | |
| 112 | + void testDowngradeQoSOnSubscribeQos0to0(); | |
| 113 | + | |
| 105 | 114 | }; |
| 106 | 115 | |
| 107 | 116 | MainTests::MainTests() |
| ... | ... | @@ -1027,7 +1036,9 @@ void MainTests::testSavingSessions() |
| 1027 | 1036 | |
| 1028 | 1037 | std::shared_ptr<Session> c1ses = c1->getSession(); |
| 1029 | 1038 | c1.reset(); |
| 1030 | - c1ses->writePacket(publish, 1, false, count); | |
| 1039 | + MqttPacket publishPacket(publish); | |
| 1040 | + std::shared_ptr<MqttPacket> possibleQos0Copy; | |
| 1041 | + c1ses->writePacket(publishPacket, 1, possibleQos0Copy, count); | |
| 1031 | 1042 | |
| 1032 | 1043 | store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); |
| 1033 | 1044 | |
| ... | ... | @@ -1091,6 +1102,145 @@ void MainTests::testSavingSessions() |
| 1091 | 1102 | } |
| 1092 | 1103 | } |
| 1093 | 1104 | |
| 1105 | +void testCopyPacketHelper(const std::string &topic, char from_qos, char to_qos, bool retain) | |
| 1106 | +{ | |
| 1107 | + assert(to_qos <= from_qos); | |
| 1108 | + | |
| 1109 | + Logger::getInstance()->setFlags(false, false, true); | |
| 1110 | + | |
| 1111 | + std::shared_ptr<Settings> settings(new Settings()); | |
| 1112 | + settings->logDebug = false; | |
| 1113 | + std::shared_ptr<SubscriptionStore> store(new SubscriptionStore()); | |
| 1114 | + std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings)); | |
| 1115 | + | |
| 1116 | + // Kind of a hack... | |
| 1117 | + Authentication auth(*settings.get()); | |
| 1118 | + ThreadAuth::assign(&auth); | |
| 1119 | + | |
| 1120 | + std::shared_ptr<Client> dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false)); | |
| 1121 | + dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60, false); | |
| 1122 | + store->registerClientAndKickExistingOne(dummyClient); | |
| 1123 | + | |
| 1124 | + uint16_t packetid = 66; | |
| 1125 | + for (int len = 0; len < 150; len++ ) | |
| 1126 | + { | |
| 1127 | + const uint16_t pack_id = packetid++; | |
| 1128 | + | |
| 1129 | + std::vector<MqttPacket> parsedPackets; | |
| 1130 | + | |
| 1131 | + const std::string payloadOne = getSecureRandomString(len); | |
| 1132 | + Publish pubOne(topic, payloadOne, from_qos); | |
| 1133 | + pubOne.retain = retain; | |
| 1134 | + MqttPacket stagingPacketOne(pubOne); | |
| 1135 | + if (from_qos > 0) | |
| 1136 | + stagingPacketOne.setPacketId(pack_id); | |
| 1137 | + CirBuf stagingBufOne(1024); | |
| 1138 | + stagingPacketOne.readIntoBuf(stagingBufOne); | |
| 1139 | + | |
| 1140 | + MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient); | |
| 1141 | + QVERIFY(parsedPackets.size() == 1); | |
| 1142 | + MqttPacket parsedPacketOne = std::move(parsedPackets.front()); | |
| 1143 | + parsedPacketOne.handlePublish(); | |
| 1144 | + if (retain) // A normal handled packet always has retain=0, so I force setting it here. | |
| 1145 | + parsedPacketOne.setRetain(); | |
| 1146 | + QCOMPARE(stagingPacketOne.getTopic(), parsedPacketOne.getTopic()); | |
| 1147 | + QCOMPARE(stagingPacketOne.getPayloadCopy(), parsedPacketOne.getPayloadCopy()); | |
| 1148 | + | |
| 1149 | + std::shared_ptr<MqttPacket> copiedPacketOne = parsedPacketOne.getCopy(to_qos); | |
| 1150 | + | |
| 1151 | + QCOMPARE(payloadOne, copiedPacketOne->getPayloadCopy()); | |
| 1152 | + | |
| 1153 | + // Now compare the written buffer of our copied packet to one that was written with our known good reference packet. | |
| 1154 | + | |
| 1155 | + Publish pubReference(topic, payloadOne, to_qos); | |
| 1156 | + pubReference.retain = retain; | |
| 1157 | + MqttPacket packetReference(pubReference); | |
| 1158 | + if (to_qos > 0) | |
| 1159 | + packetReference.setPacketId(pack_id); | |
| 1160 | + CirBuf bufOfReference(1024); | |
| 1161 | + CirBuf bufOfCopied(1024); | |
| 1162 | + packetReference.readIntoBuf(bufOfReference); | |
| 1163 | + copiedPacketOne->readIntoBuf(bufOfCopied); | |
| 1164 | + QVERIFY2(bufOfCopied == bufOfReference, formatString("Failure on length %d for topic %s, from qos %d to qos %d, retain: %d.", | |
| 1165 | + len, topic.c_str(), from_qos, to_qos, retain).c_str()); | |
| 1166 | + } | |
| 1167 | +} | |
| 1168 | + | |
| 1169 | +/** | |
| 1170 | + * @brief MainTests::testCopyPacket tests the actual bytes of a copied that would be written to a client. | |
| 1171 | + * | |
| 1172 | + * This is specifically to test the optimiziations in getCopy(). It indirectly also tests packet parsing. | |
| 1173 | + */ | |
| 1174 | +void MainTests::testCopyPacket() | |
| 1175 | +{ | |
| 1176 | + for (int retain = 0; retain < 2; retain++) | |
| 1177 | + { | |
| 1178 | + testCopyPacketHelper("John/McLane", 0, 0, retain); | |
| 1179 | + testCopyPacketHelper("Ben/Sisko", 1, 1, retain); | |
| 1180 | + testCopyPacketHelper("Rebecca/Bunch", 2, 2, retain); | |
| 1181 | + | |
| 1182 | + testCopyPacketHelper("Buffy/Slayer", 1, 0, retain); | |
| 1183 | + testCopyPacketHelper("Sarah/Connor", 2, 0, retain); | |
| 1184 | + testCopyPacketHelper("Susan/Mayer", 2, 1, retain); | |
| 1185 | + } | |
| 1186 | +} | |
| 1187 | + | |
| 1188 | +void testDowngradeQoSOnSubscribeHelper(const char pub_qos, const char sub_qos) | |
| 1189 | +{ | |
| 1190 | + TwoClientTestContext testContext; | |
| 1191 | + | |
| 1192 | + const QString topic("Star/Trek"); | |
| 1193 | + const QByteArray payload("Captain Kirk"); | |
| 1194 | + | |
| 1195 | + testContext.connectSender(); | |
| 1196 | + testContext.connectReceiver(); | |
| 1197 | + | |
| 1198 | + testContext.subscribeReceiver(topic, sub_qos); | |
| 1199 | + testContext.publish(topic, payload, pub_qos, false); | |
| 1200 | + | |
| 1201 | + testContext.waitReceiverReceived(1); | |
| 1202 | + | |
| 1203 | + QCOMPARE(testContext.receivedMessages.length(), 1); | |
| 1204 | + QMQTT::Message &recv = testContext.receivedMessages.first(); | |
| 1205 | + | |
| 1206 | + const char expected_qos = std::min<const char>(pub_qos, sub_qos); | |
| 1207 | + QVERIFY2(recv.qos() == expected_qos, formatString("Failure: received QoS is %d. Published is %d. Subscribed as %d. Expected QoS is %d", | |
| 1208 | + recv.qos(), pub_qos, sub_qos, expected_qos).c_str()); | |
| 1209 | + QVERIFY(recv.topic() == topic); | |
| 1210 | + QVERIFY(recv.payload() == payload); | |
| 1211 | +} | |
| 1212 | + | |
| 1213 | +void MainTests::testDowngradeQoSOnSubscribeQos2to2() | |
| 1214 | +{ | |
| 1215 | + testDowngradeQoSOnSubscribeHelper(2, 2); | |
| 1216 | +} | |
| 1217 | + | |
| 1218 | +void MainTests::testDowngradeQoSOnSubscribeQos2to1() | |
| 1219 | +{ | |
| 1220 | + testDowngradeQoSOnSubscribeHelper(2, 1); | |
| 1221 | +} | |
| 1222 | + | |
| 1223 | +void MainTests::testDowngradeQoSOnSubscribeQos2to0() | |
| 1224 | +{ | |
| 1225 | + testDowngradeQoSOnSubscribeHelper(2, 0); | |
| 1226 | +} | |
| 1227 | + | |
| 1228 | +void MainTests::testDowngradeQoSOnSubscribeQos1to1() | |
| 1229 | +{ | |
| 1230 | + testDowngradeQoSOnSubscribeHelper(1, 1); | |
| 1231 | +} | |
| 1232 | + | |
| 1233 | +void MainTests::testDowngradeQoSOnSubscribeQos1to0() | |
| 1234 | +{ | |
| 1235 | + testDowngradeQoSOnSubscribeHelper(1, 0); | |
| 1236 | +} | |
| 1237 | + | |
| 1238 | +void MainTests::testDowngradeQoSOnSubscribeQos0to0() | |
| 1239 | +{ | |
| 1240 | + testDowngradeQoSOnSubscribeHelper(0, 0); | |
| 1241 | +} | |
| 1242 | + | |
| 1243 | + | |
| 1094 | 1244 | int main(int argc, char *argv[]) |
| 1095 | 1245 | { |
| 1096 | 1246 | QCoreApplication app(argc, argv); | ... | ... |
FlashMQTests/twoclienttestcontext.cpp
| ... | ... | @@ -35,12 +35,22 @@ TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent) |
| 35 | 35 | connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); |
| 36 | 36 | } |
| 37 | 37 | |
| 38 | +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload) | |
| 39 | +{ | |
| 40 | + publish(topic, payload, 0, false); | |
| 41 | +} | |
| 42 | + | |
| 38 | 43 | void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, bool retain) |
| 39 | 44 | { |
| 45 | + publish(topic, payload, 0, retain); | |
| 46 | +} | |
| 47 | + | |
| 48 | +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain) | |
| 49 | +{ | |
| 40 | 50 | QMQTT::Message msg; |
| 41 | 51 | msg.setTopic(topic); |
| 42 | 52 | msg.setRetain(retain); |
| 43 | - msg.setQos(0); | |
| 53 | + msg.setQos(qos); | |
| 44 | 54 | msg.setPayload(payload); |
| 45 | 55 | sender->publish(msg); |
| 46 | 56 | } |
| ... | ... | @@ -71,9 +81,9 @@ void TwoClientTestContext::disconnectReceiver() |
| 71 | 81 | waiter.exec(); |
| 72 | 82 | } |
| 73 | 83 | |
| 74 | -void TwoClientTestContext::subscribeReceiver(const QString &topic) | |
| 84 | +void TwoClientTestContext::subscribeReceiver(const QString &topic, const quint8 qos) | |
| 75 | 85 | { |
| 76 | - receiver->subscribe(topic); | |
| 86 | + receiver->subscribe(topic, qos); | |
| 77 | 87 | |
| 78 | 88 | QEventLoop waiter; |
| 79 | 89 | QTimer timeout; | ... | ... |
FlashMQTests/twoclienttestcontext.h
| ... | ... | @@ -34,11 +34,13 @@ private slots: |
| 34 | 34 | |
| 35 | 35 | public: |
| 36 | 36 | explicit TwoClientTestContext(QObject *parent = nullptr); |
| 37 | - void publish(const QString &topic, const QByteArray &payload, bool retain = false); | |
| 37 | + void publish(const QString &topic, const QByteArray &payload); | |
| 38 | + void publish(const QString &topic, const QByteArray &payload, bool retain); | |
| 39 | + void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain); | |
| 38 | 40 | void connectSender(); |
| 39 | 41 | void connectReceiver(); |
| 40 | 42 | void disconnectReceiver(); |
| 41 | - void subscribeReceiver(const QString &topic); | |
| 43 | + void subscribeReceiver(const QString &topic, const quint8 qos = 0); | |
| 42 | 44 | void waitReceiverReceived(int count); |
| 43 | 45 | void onClientError(const QMQTT::ClientError error); |
| 44 | 46 | ... | ... |
cirbuf.cpp
| ... | ... | @@ -252,3 +252,22 @@ void CirBuf::read(void *buf, const size_t count) |
| 252 | 252 | assert(_packet_len == 0); |
| 253 | 253 | assert(i == static_cast<int>(count)); |
| 254 | 254 | } |
| 255 | + | |
| 256 | +/** | |
| 257 | + * @brief CirBuf::operator == simplistic comparision. It doesn't take the fact that it's circular into account. | |
| 258 | + * @param other | |
| 259 | + * @return | |
| 260 | + * | |
| 261 | + * It was created for unit testing. read() and write() are non-const, so taking the circular properties into account | |
| 262 | + * would need more/duplicate code that I don't need at this point. | |
| 263 | + */ | |
| 264 | +bool CirBuf::operator==(const CirBuf &other) const | |
| 265 | +{ | |
| 266 | +#ifdef NDEBUG | |
| 267 | + throw std::exception(); // you can't use it in release builds, because new buffers aren't zeroed. | |
| 268 | +#endif | |
| 269 | + | |
| 270 | + return tail == 0 && other.tail == 0 | |
| 271 | + && usedBytes() == other.usedBytes() | |
| 272 | + && std::memcmp(buf, other.buf, size) == 0; | |
| 273 | +} | ... | ... |
cirbuf.h
client.cpp
| ... | ... | @@ -56,7 +56,7 @@ Client::~Client() |
| 56 | 56 | { |
| 57 | 57 | Publish will(will_topic, will_payload, will_qos); |
| 58 | 58 | will.retain = will_retain; |
| 59 | - const MqttPacket willPacket(will); | |
| 59 | + MqttPacket willPacket(will); | |
| 60 | 60 | |
| 61 | 61 | const std::vector<std::string> subtopics = splitToVector(will_topic, '/'); |
| 62 | 62 | store->queuePacketAtSubscribers(subtopics, willPacket); |
| ... | ... | @@ -180,7 +180,7 @@ void Client::writeText(const std::string &text) |
| 180 | 180 | setReadyForWriting(true); |
| 181 | 181 | } |
| 182 | 182 | |
| 183 | -int Client::writeMqttPacket(const MqttPacket &packet, const char qos) | |
| 183 | +int Client::writeMqttPacket(const MqttPacket &packet) | |
| 184 | 184 | { |
| 185 | 185 | std::lock_guard<std::mutex> locker(writeBufMutex); |
| 186 | 186 | |
| ... | ... | @@ -193,7 +193,7 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos) |
| 193 | 193 | |
| 194 | 194 | // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And |
| 195 | 195 | // QoS packet are queued and limited elsewhere. |
| 196 | - if (packet.packetType == PacketType::PUBLISH && qos == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) | |
| 196 | + if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) | |
| 197 | 197 | { |
| 198 | 198 | return 0; |
| 199 | 199 | } |
| ... | ... | @@ -208,11 +208,11 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos) |
| 208 | 208 | } |
| 209 | 209 | |
| 210 | 210 | // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected. |
| 211 | -int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos) | |
| 211 | +int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) | |
| 212 | 212 | { |
| 213 | 213 | try |
| 214 | 214 | { |
| 215 | - return this->writeMqttPacket(packet, qos); | |
| 215 | + return this->writeMqttPacket(packet); | |
| 216 | 216 | } |
| 217 | 217 | catch (std::exception &ex) |
| 218 | 218 | { | ... | ... |
client.h
| ... | ... | @@ -121,8 +121,8 @@ public: |
| 121 | 121 | |
| 122 | 122 | void writeText(const std::string &text); |
| 123 | 123 | void writePingResp(); |
| 124 | - int writeMqttPacket(const MqttPacket &packet, const char qos = 0); | |
| 125 | - int writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos); | |
| 124 | + int writeMqttPacket(const MqttPacket &packet); | |
| 125 | + int writeMqttPacketAndBlameThisClient(const MqttPacket &packet); | |
| 126 | 126 | bool writeBufIntoFd(); |
| 127 | 127 | bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } |
| 128 | 128 | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -48,11 +48,69 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt |
| 48 | 48 | * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor |
| 49 | 49 | * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add. |
| 50 | 50 | * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. |
| 51 | + * | |
| 52 | + * The idea is that because a packet with QoS is longer than one without, we just copy as much as possible if both packets have the same QoS. | |
| 53 | + * | |
| 54 | + * Note that there can be two types of packets: one with the fixed header (including remaining length), and one without. The latter we could be | |
| 55 | + * more clever about, but I'm forgoing that right now. Their use is mostly for retained messages. | |
| 56 | + * | |
| 57 | + * Also note that some fields are undeterminstic in the copy: dup, retain and packetid for instance. Sometimes they come from the original, | |
| 58 | + * sometimes not. The current planned usage is that those fields will either ONLY or NEVER be used in the copy, so it doesn't matter what I do | |
| 59 | + * with them here. I may reconsider. | |
| 51 | 60 | */ |
| 52 | -std::shared_ptr<MqttPacket> MqttPacket::getCopy() const | |
| 61 | +std::shared_ptr<MqttPacket> MqttPacket::getCopy(char new_max_qos) const | |
| 53 | 62 | { |
| 63 | + assert(packetType == PacketType::PUBLISH); | |
| 64 | + | |
| 65 | + // You're not supposed to copy a duplicate packet. The only packets that get the dup flag, should not be copied AGAIN. This | |
| 66 | + // has to do with the Session::writePacket() and Session::sendPendingQosMessages() logic. | |
| 67 | + assert((first_byte & 0b00001000) == 0); | |
| 68 | + | |
| 69 | + if (qos > 0 && new_max_qos == 0) | |
| 70 | + { | |
| 71 | + // if shrinking the packet doesn't alter the amount of bytes in the 'remaining length' part of the header, we can | |
| 72 | + // just memmove+shrink the packet. This is because the packet id always is two bytes before the payload, so we just move the payload | |
| 73 | + // over it. When testing 100M copies, it went from 21000 ms to 10000 ms. In other words, about 100 ยตs to 200 ยตs per copy. | |
| 74 | + // There is an elaborate unit test to test this optimization. | |
| 75 | + if ((fixed_header_length == 2 && bites.size() < 125)) | |
| 76 | + { | |
| 77 | + // I don't know yet if this is true, but I don't want to forget when I implemenet MQTT5. | |
| 78 | + assert(sender && sender->getProtocolVersion() <= ProtocolVersion::Mqtt311); | |
| 79 | + | |
| 80 | + std::shared_ptr<MqttPacket> p(new MqttPacket(*this)); | |
| 81 | + p->sender.reset(); | |
| 82 | + | |
| 83 | + if (payloadLen > 0) | |
| 84 | + std::memmove(&p->bites[packet_id_pos], &p->bites[packet_id_pos+2], payloadLen); | |
| 85 | + p->bites.erase(p->bites.end() - 2, p->bites.end()); | |
| 86 | + p->packet_id_pos = 0; | |
| 87 | + p->payloadStart -= 2; | |
| 88 | + if (pos > p->bites.size()) // pos can possible be set elsewhere, so we only set it back if it was after the payload. | |
| 89 | + p->pos -= 2; | |
| 90 | + p->packet_id = 0; | |
| 91 | + | |
| 92 | + // Clear QoS bits from the header. | |
| 93 | + p->first_byte &= 0b11111001; | |
| 94 | + p->bites[0] = p->first_byte; | |
| 95 | + | |
| 96 | + assert((p->bites[1] & 0b10000000) == 0); // when there is an MSB, I musn't get rid of it. | |
| 97 | + assert(p->bites[1] > 3); // There has to be a remaining value after subtracting 2. | |
| 98 | + | |
| 99 | + p->bites[1] -= 2; // Reduce the value in the 'remaining length' part of the header. | |
| 100 | + | |
| 101 | + return p; | |
| 102 | + } | |
| 103 | + | |
| 104 | + Publish pub(topic, getPayloadCopy(), new_max_qos); | |
| 105 | + pub.retain = getRetain(); | |
| 106 | + std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(pub)); | |
| 107 | + return copyPacket; | |
| 108 | + } | |
| 109 | + | |
| 54 | 110 | std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this)); |
| 55 | 111 | copyPacket->sender.reset(); |
| 112 | + if (qos != new_max_qos) | |
| 113 | + copyPacket->setQos(new_max_qos); | |
| 56 | 114 | return copyPacket; |
| 57 | 115 | } |
| 58 | 116 | |
| ... | ... | @@ -773,11 +831,11 @@ size_t MqttPacket::getTotalMemoryFootprint() |
| 773 | 831 | * @return |
| 774 | 832 | * |
| 775 | 833 | * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out |
| 776 | - * the whole byte array that is a packet to subscribers. No need to copy and such. | |
| 834 | + * the whole byte array of an original packet to subscribers. No need to copy and such. | |
| 777 | 835 | * |
| 778 | - * I created it for saving QoS packages in the db file. | |
| 836 | + * But, as stated, sometimes it's necessary. | |
| 779 | 837 | */ |
| 780 | -std::string MqttPacket::getPayloadCopy() | |
| 838 | +std::string MqttPacket::getPayloadCopy() const | |
| 781 | 839 | { |
| 782 | 840 | assert(payloadStart > 0); |
| 783 | 841 | assert(pos <= bites.size()); |
| ... | ... | @@ -798,6 +856,23 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const |
| 798 | 856 | return total; |
| 799 | 857 | } |
| 800 | 858 | |
| 859 | +void MqttPacket::setQos(const char new_qos) | |
| 860 | +{ | |
| 861 | + // You can't change to a QoS level that would remove the packet identifier. | |
| 862 | + assert((qos == 0 && new_qos == 0) || (qos > 0 && new_qos > 0)); | |
| 863 | + assert(new_qos > 0 && packet_id_pos > 0); | |
| 864 | + | |
| 865 | + qos = new_qos; | |
| 866 | + first_byte &= 0b11111001; | |
| 867 | + first_byte |= (qos << 1); | |
| 868 | + | |
| 869 | + if (fixed_header_length > 0) | |
| 870 | + { | |
| 871 | + pos = 0; | |
| 872 | + writeByte(first_byte); | |
| 873 | + } | |
| 874 | +} | |
| 875 | + | |
| 801 | 876 | const std::string &MqttPacket::getTopic() const |
| 802 | 877 | { |
| 803 | 878 | return this->topic; |
| ... | ... | @@ -886,6 +961,31 @@ size_t MqttPacket::remainingAfterPos() |
| 886 | 961 | return bites.size() - pos; |
| 887 | 962 | } |
| 888 | 963 | |
| 964 | +bool MqttPacket::getRetain() const | |
| 965 | +{ | |
| 966 | + return (first_byte & 0b00000001); | |
| 967 | +} | |
| 968 | + | |
| 969 | +/** | |
| 970 | + * @brief MqttPacket::setRetain set the retain bit in the first byte. I think I only need this in tests, because existing subscribers don't get retain=1, | |
| 971 | + * so handlePublish() clears it. But I needed it to be set in testing. | |
| 972 | + * | |
| 973 | + * Publishing of the retained messages goes through the MqttPacket(Publish) constructor, hence this setRetain() isn't necessary for that. | |
| 974 | + */ | |
| 975 | +void MqttPacket::setRetain() | |
| 976 | +{ | |
| 977 | +#ifndef TESTING | |
| 978 | + assert(false); | |
| 979 | +#endif | |
| 980 | + | |
| 981 | + first_byte |= 0b00000001; | |
| 982 | + | |
| 983 | + if (fixed_header_length > 0) | |
| 984 | + { | |
| 985 | + pos = 0; | |
| 986 | + writeByte(first_byte); | |
| 987 | + } | |
| 988 | +} | |
| 889 | 989 | |
| 890 | 990 | void MqttPacket::readIntoBuf(CirBuf &buf) const |
| 891 | 991 | { | ... | ... |
mqttpacket.h
| ... | ... | @@ -80,7 +80,7 @@ public: |
| 80 | 80 | |
| 81 | 81 | MqttPacket(MqttPacket &&other) = default; |
| 82 | 82 | |
| 83 | - std::shared_ptr<MqttPacket> getCopy() const; | |
| 83 | + std::shared_ptr<MqttPacket> getCopy(char new_max_qos) const; | |
| 84 | 84 | |
| 85 | 85 | // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. |
| 86 | 86 | MqttPacket(const ConnAck &connAck); |
| ... | ... | @@ -109,6 +109,7 @@ public: |
| 109 | 109 | size_t getSizeIncludingNonPresentHeader() const; |
| 110 | 110 | const std::vector<char> &getBites() const { return bites; } |
| 111 | 111 | char getQos() const { return qos; } |
| 112 | + void setQos(const char new_qos); | |
| 112 | 113 | const std::string &getTopic() const; |
| 113 | 114 | const std::vector<std::string> &getSubtopics() const; |
| 114 | 115 | std::shared_ptr<Client> getSender() const; |
| ... | ... | @@ -120,8 +121,10 @@ public: |
| 120 | 121 | uint16_t getPacketId() const; |
| 121 | 122 | void setDuplicate(); |
| 122 | 123 | size_t getTotalMemoryFootprint(); |
| 123 | - std::string getPayloadCopy(); | |
| 124 | 124 | void readIntoBuf(CirBuf &buf) const; |
| 125 | + std::string getPayloadCopy() const; | |
| 126 | + bool getRetain() const; | |
| 127 | + void setRetain(); | |
| 125 | 128 | }; |
| 126 | 129 | |
| 127 | 130 | #endif // MQTTPACKET_H | ... | ... |
qospacketqueue.cpp
| ... | ... | @@ -44,11 +44,11 @@ size_t QoSPacketQueue::getByteSize() const |
| 44 | 44 | * @param id |
| 45 | 45 | * @return the packet copy. |
| 46 | 46 | */ |
| 47 | -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id) | |
| 47 | +std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos) | |
| 48 | 48 | { |
| 49 | 49 | assert(p.getQos() > 0); |
| 50 | 50 | |
| 51 | - std::shared_ptr<MqttPacket> copyPacket = p.getCopy(); | |
| 51 | + std::shared_ptr<MqttPacket> copyPacket = p.getCopy(new_max_qos); | |
| 52 | 52 | copyPacket->setPacketId(id); |
| 53 | 53 | queue.push_back(copyPacket); |
| 54 | 54 | qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | ... | ... |
qospacketqueue.h
| ... | ... | @@ -15,7 +15,7 @@ public: |
| 15 | 15 | void erase(const uint16_t packet_id); |
| 16 | 16 | size_t size() const; |
| 17 | 17 | size_t getByteSize() const; |
| 18 | - std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id); | |
| 18 | + std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos); | |
| 19 | 19 | std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id); |
| 20 | 20 | |
| 21 | 21 | std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const; | ... | ... |
session.cpp
| ... | ... | @@ -59,6 +59,29 @@ void Session::setSessionTouch(int64_t ageInMs) |
| 59 | 59 | lastTouched = point; |
| 60 | 60 | } |
| 61 | 61 | |
| 62 | +bool Session::requiresPacketRetransmission() const | |
| 63 | +{ | |
| 64 | + const std::shared_ptr<Client> client = makeSharedClient(); | |
| 65 | + | |
| 66 | + if (!client) | |
| 67 | + return true; | |
| 68 | + | |
| 69 | + // MQTT 3.1: "Brokers, however, should retry any unacknowledged message." | |
| 70 | + // MQTT 3.1.1: "This [reconnecting] is the only circumstance where a Client or Server is REQUIRED to redeliver messages." | |
| 71 | + if (client->getProtocolVersion() < ProtocolVersion::Mqtt311) | |
| 72 | + return true; | |
| 73 | + | |
| 74 | + // TODO: for MQTT5, the rules are different. | |
| 75 | + return !client->getCleanSession(); | |
| 76 | +} | |
| 77 | + | |
| 78 | +void Session::increasePacketId() | |
| 79 | +{ | |
| 80 | + nextPacketId++; | |
| 81 | + if (nextPacketId == 0) | |
| 82 | + nextPacketId++; | |
| 83 | +} | |
| 84 | + | |
| 62 | 85 | /** |
| 63 | 86 | * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies. |
| 64 | 87 | * @param other |
| ... | ... | @@ -116,63 +139,101 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) |
| 116 | 139 | |
| 117 | 140 | /** |
| 118 | 141 | * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session. |
| 119 | - * @param packet | |
| 142 | + * @param packet is not const. We set the qos and packet id for each publish. This should be safe, because the packet | |
| 143 | + * with original packet id and qos is not saved. This saves unnecessary copying. | |
| 120 | 144 | * @param max_qos |
| 121 | 145 | * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. |
| 122 | 146 | * @param count. Reference value is updated. It's for statistics. |
| 123 | 147 | */ |
| 124 | -void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count) | |
| 148 | +void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count) | |
| 125 | 149 | { |
| 126 | 150 | assert(max_qos <= 2); |
| 127 | - const char qos = std::min<char>(packet.getQos(), max_qos); | |
| 151 | + const char effectiveQos = std::min<char>(packet.getQos(), max_qos); | |
| 128 | 152 | |
| 129 | 153 | Authentication *_auth = ThreadAuth::getAuth(); |
| 130 | 154 | assert(_auth); |
| 131 | 155 | Authentication &auth = *_auth; |
| 132 | - if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) | |
| 156 | + if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success) | |
| 133 | 157 | { |
| 134 | - if (qos == 0) | |
| 158 | + std::shared_ptr<Client> c = makeSharedClient(); | |
| 159 | + if (effectiveQos == 0) | |
| 135 | 160 | { |
| 136 | - std::shared_ptr<Client> c = makeSharedClient(); | |
| 137 | - | |
| 138 | 161 | if (c) |
| 139 | 162 | { |
| 140 | - count += c->writeMqttPacketAndBlameThisClient(packet, qos); | |
| 163 | + const MqttPacket *packetToSend = &packet; | |
| 164 | + | |
| 165 | + if (max_qos < packet.getQos()) | |
| 166 | + { | |
| 167 | + if (!downgradedQos0PacketCopy) | |
| 168 | + downgradedQos0PacketCopy = packet.getCopy(max_qos); | |
| 169 | + packetToSend = downgradedQos0PacketCopy.get(); | |
| 170 | + } | |
| 171 | + | |
| 172 | + count += c->writeMqttPacketAndBlameThisClient(*packetToSend); | |
| 141 | 173 | } |
| 142 | 174 | } |
| 143 | - else if (qos > 0) | |
| 175 | + else if (effectiveQos > 0) | |
| 144 | 176 | { |
| 145 | - std::unique_lock<std::mutex> locker(qosQueueMutex); | |
| 177 | + const bool requiresRetransmission = requiresPacketRetransmission(); | |
| 146 | 178 | |
| 147 | - const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); | |
| 148 | - if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | |
| 179 | + if (requiresRetransmission) | |
| 149 | 180 | { |
| 150 | - if (QoSLogPrintedAtId != nextPacketId) | |
| 181 | + std::unique_lock<std::mutex> locker(qosQueueMutex); | |
| 182 | + | |
| 183 | + const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); | |
| 184 | + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | |
| 151 | 185 | { |
| 152 | - logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because its QoS buffers were full.", client_id.c_str()); | |
| 153 | - QoSLogPrintedAtId = nextPacketId; | |
| 186 | + if (QoSLogPrintedAtId != nextPacketId) | |
| 187 | + { | |
| 188 | + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because max in-transit packet count reached.", client_id.c_str()); | |
| 189 | + QoSLogPrintedAtId = nextPacketId; | |
| 190 | + } | |
| 191 | + return; | |
| 154 | 192 | } |
| 155 | - return; | |
| 156 | - } | |
| 157 | - nextPacketId++; | |
| 158 | - if (nextPacketId == 0) | |
| 159 | - nextPacketId++; | |
| 160 | 193 | |
| 161 | - std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId); | |
| 162 | - locker.unlock(); | |
| 194 | + increasePacketId(); | |
| 195 | + std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId, effectiveQos); | |
| 163 | 196 | |
| 164 | - std::shared_ptr<Client> c = makeSharedClient(); | |
| 165 | - if (c) | |
| 197 | + if (c) | |
| 198 | + { | |
| 199 | + count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); | |
| 200 | + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 201 | + } | |
| 202 | + } | |
| 203 | + else | |
| 166 | 204 | { |
| 167 | - count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos); | |
| 168 | - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 205 | + // We don't need to make a copy of the packet in this branch, because: | |
| 206 | + // - The packet to give the client won't shrink in size because source and client have a packet_id. | |
| 207 | + // - We don't have to store the copy in the session for retransmission, see Session::requiresPacketRetransmission() | |
| 208 | + // So, we just keep altering the original published packet. | |
| 209 | + | |
| 210 | + std::unique_lock<std::mutex> locker(qosQueueMutex); | |
| 211 | + | |
| 212 | + if (qosInFlightCounter >= 65530) // Includes a small safety margin. | |
| 213 | + { | |
| 214 | + if (QoSLogPrintedAtId != nextPacketId) | |
| 215 | + { | |
| 216 | + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because it hasn't seen enough PUBACKs to release places.", client_id.c_str()); | |
| 217 | + QoSLogPrintedAtId = nextPacketId; | |
| 218 | + } | |
| 219 | + return; | |
| 220 | + } | |
| 221 | + | |
| 222 | + increasePacketId(); | |
| 223 | + | |
| 224 | + // This changes the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere, | |
| 225 | + // that should be fine. | |
| 226 | + packet.setPacketId(nextPacketId); | |
| 227 | + packet.setQos(effectiveQos); | |
| 228 | + | |
| 229 | + qosInFlightCounter++; | |
| 230 | + assert(c); // with requiresRetransmission==false, there must be a client. | |
| 231 | + c->writeMqttPacketAndBlameThisClient(packet); | |
| 169 | 232 | } |
| 170 | 233 | } |
| 171 | 234 | } |
| 172 | 235 | } |
| 173 | 236 | |
| 174 | -// Normatively, this while loop will break on the first element, because all messages are sent out in order and | |
| 175 | -// should be acked in order. | |
| 176 | 237 | void Session::clearQosMessage(uint16_t packet_id) |
| 177 | 238 | { |
| 178 | 239 | #ifndef NDEBUG |
| ... | ... | @@ -180,7 +241,13 @@ void Session::clearQosMessage(uint16_t packet_id) |
| 180 | 241 | #endif |
| 181 | 242 | |
| 182 | 243 | std::lock_guard<std::mutex> locker(qosQueueMutex); |
| 183 | - qosPacketQueue.erase(packet_id); | |
| 244 | + if (requiresPacketRetransmission()) | |
| 245 | + qosPacketQueue.erase(packet_id); | |
| 246 | + else | |
| 247 | + { | |
| 248 | + qosInFlightCounter--; | |
| 249 | + qosInFlightCounter = std::max<int>(0, qosInFlightCounter); // Should never happen, but in case we receive too many PUBACKs. | |
| 250 | + } | |
| 184 | 251 | } |
| 185 | 252 | |
| 186 | 253 | // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any |
| ... | ... | @@ -199,7 +266,7 @@ uint64_t Session::sendPendingQosMessages() |
| 199 | 266 | std::lock_guard<std::mutex> locker(qosQueueMutex); |
| 200 | 267 | for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) |
| 201 | 268 | { |
| 202 | - count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos()); | |
| 269 | + count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get()); | |
| 203 | 270 | qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. |
| 204 | 271 | } |
| 205 | 272 | |
| ... | ... | @@ -207,7 +274,7 @@ uint64_t Session::sendPendingQosMessages() |
| 207 | 274 | { |
| 208 | 275 | PubRel pubRel(packet_id); |
| 209 | 276 | MqttPacket packet(pubRel); |
| 210 | - count += c->writeMqttPacketAndBlameThisClient(packet, 2); | |
| 277 | + count += c->writeMqttPacketAndBlameThisClient(packet); | |
| 211 | 278 | } |
| 212 | 279 | } |
| 213 | 280 | ... | ... |
session.h
| ... | ... | @@ -48,11 +48,14 @@ class Session |
| 48 | 48 | std::set<uint16_t> outgoingQoS2MessageIds; |
| 49 | 49 | std::mutex qosQueueMutex; |
| 50 | 50 | uint16_t nextPacketId = 0; |
| 51 | + uint16_t qosInFlightCounter = 0; | |
| 51 | 52 | uint16_t QoSLogPrintedAtId = 0; |
| 52 | 53 | std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now(); |
| 53 | 54 | Logger *logger = Logger::getInstance(); |
| 54 | 55 | int64_t getSessionRelativeAgeInMs() const; |
| 55 | 56 | void setSessionTouch(int64_t ageInMs); |
| 57 | + bool requiresPacketRetransmission() const; | |
| 58 | + void increasePacketId(); | |
| 56 | 59 | |
| 57 | 60 | Session(const Session &other); |
| 58 | 61 | public: |
| ... | ... | @@ -69,7 +72,7 @@ public: |
| 69 | 72 | const std::string &getClientId() const { return client_id; } |
| 70 | 73 | std::shared_ptr<Client> makeSharedClient() const; |
| 71 | 74 | void assignActiveConnection(std::shared_ptr<Client> &client); |
| 72 | - void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); | |
| 75 | + void writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count); | |
| 73 | 76 | void clearQosMessage(uint16_t packet_id); |
| 74 | 77 | uint64_t sendPendingQosMessages(); |
| 75 | 78 | void touch(std::chrono::time_point<std::chrono::steady_clock> val); | ... | ... |
subscriptionstore.cpp
| ... | ... | @@ -331,7 +331,7 @@ void SubscriptionStore::publishRecursively(std::vector<std::string>::const_itera |
| 331 | 331 | } |
| 332 | 332 | } |
| 333 | 333 | |
| 334 | -void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar) | |
| 334 | +void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> &subtopics, MqttPacket &packet, bool dollar) | |
| 335 | 335 | { |
| 336 | 336 | assert(subtopics.size() > 0); |
| 337 | 337 | |
| ... | ... | @@ -346,9 +346,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> |
| 346 | 346 | publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions); |
| 347 | 347 | } |
| 348 | 348 | |
| 349 | + std::shared_ptr<MqttPacket> possibleQos0Copy; | |
| 349 | 350 | for(const ReceivingSubscriber &x : subscriberSessions) |
| 350 | 351 | { |
| 351 | - x.session->writePacket(packet, x.qos, false, count); | |
| 352 | + x.session->writePacket(packet, x.qos, possibleQos0Copy, count); | |
| 352 | 353 | } |
| 353 | 354 | |
| 354 | 355 | std::shared_ptr<Client> sender = packet.getSender(); |
| ... | ... | @@ -422,9 +423,10 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr<Ses |
| 422 | 423 | giveClientRetainedMessagesRecursively(subscribeSubtopics.begin(), subscribeSubtopics.end(), startNode, false, packetList); |
| 423 | 424 | } |
| 424 | 425 | |
| 425 | - for(const MqttPacket &packet : packetList) | |
| 426 | + std::shared_ptr<MqttPacket> possibleQos0Copy; | |
| 427 | + for(MqttPacket &packet : packetList) | |
| 426 | 428 | { |
| 427 | - ses->writePacket(packet, max_qos, true, count); | |
| 429 | + ses->writePacket(packet, max_qos, possibleQos0Copy, count); | |
| 428 | 430 | } |
| 429 | 431 | |
| 430 | 432 | return count; | ... | ... |
subscriptionstore.h
| ... | ... | @@ -120,7 +120,7 @@ public: |
| 120 | 120 | void registerClientAndKickExistingOne(std::shared_ptr<Client> &client); |
| 121 | 121 | bool sessionPresent(const std::string &clientid); |
| 122 | 122 | |
| 123 | - void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar = false); | |
| 123 | + void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, MqttPacket &packet, bool dollar = false); | |
| 124 | 124 | void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, |
| 125 | 125 | RetainedMessageNode *this_node, bool poundMode, std::forward_list<MqttPacket> &packetList) const; |
| 126 | 126 | uint64_t giveClientRetainedMessages(const std::shared_ptr<Session> &ses, const std::vector<std::string> &subscribeSubtopics, char max_qos); | ... | ... |
threaddata.cpp
| ... | ... | @@ -127,7 +127,8 @@ void ThreadData::publishStat(const std::string &topic, uint64_t n) |
| 127 | 127 | splitTopic(topic, subtopics); |
| 128 | 128 | const std::string payload = std::to_string(n); |
| 129 | 129 | Publish p(topic, payload, 0); |
| 130 | - subscriptionStore->queuePacketAtSubscribers(subtopics, p, true); | |
| 130 | + MqttPacket pack(p); | |
| 131 | + subscriptionStore->queuePacketAtSubscribers(subtopics, pack, true); | |
| 131 | 132 | subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); |
| 132 | 133 | } |
| 133 | 134 | ... | ... |