Commit f67238c5b794bb4ff5d3666b052bb6c80dad3ba2

Authored by Wiebe Cazemier
1 parent 4947d33a

Fix not downgrading QoS to subscription QoS

This entails making copies of the original packet when necessary,
because QoS 0 doesn't have a packet id. I tried to keep it to an
absolute minimum and do some precarious optmizations for it. There are
tests though.
FlashMQTests/tst_maintests.cpp
@@ -102,6 +102,15 @@ private slots: @@ -102,6 +102,15 @@ private slots:
102 102
103 void testSavingSessions(); 103 void testSavingSessions();
104 104
  105 + void testCopyPacket();
  106 +
  107 + void testDowngradeQoSOnSubscribeQos2to2();
  108 + void testDowngradeQoSOnSubscribeQos2to1();
  109 + void testDowngradeQoSOnSubscribeQos2to0();
  110 + void testDowngradeQoSOnSubscribeQos1to1();
  111 + void testDowngradeQoSOnSubscribeQos1to0();
  112 + void testDowngradeQoSOnSubscribeQos0to0();
  113 +
105 }; 114 };
106 115
107 MainTests::MainTests() 116 MainTests::MainTests()
@@ -1027,7 +1036,9 @@ void MainTests::testSavingSessions() @@ -1027,7 +1036,9 @@ void MainTests::testSavingSessions()
1027 1036
1028 std::shared_ptr<Session> c1ses = c1->getSession(); 1037 std::shared_ptr<Session> c1ses = c1->getSession();
1029 c1.reset(); 1038 c1.reset();
1030 - c1ses->writePacket(publish, 1, false, count); 1039 + MqttPacket publishPacket(publish);
  1040 + std::shared_ptr<MqttPacket> possibleQos0Copy;
  1041 + c1ses->writePacket(publishPacket, 1, possibleQos0Copy, count);
1031 1042
1032 store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db"); 1043 store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db");
1033 1044
@@ -1091,6 +1102,145 @@ void MainTests::testSavingSessions() @@ -1091,6 +1102,145 @@ void MainTests::testSavingSessions()
1091 } 1102 }
1092 } 1103 }
1093 1104
  1105 +void testCopyPacketHelper(const std::string &topic, char from_qos, char to_qos, bool retain)
  1106 +{
  1107 + assert(to_qos <= from_qos);
  1108 +
  1109 + Logger::getInstance()->setFlags(false, false, true);
  1110 +
  1111 + std::shared_ptr<Settings> settings(new Settings());
  1112 + settings->logDebug = false;
  1113 + std::shared_ptr<SubscriptionStore> store(new SubscriptionStore());
  1114 + std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings));
  1115 +
  1116 + // Kind of a hack...
  1117 + Authentication auth(*settings.get());
  1118 + ThreadAuth::assign(&auth);
  1119 +
  1120 + std::shared_ptr<Client> dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false));
  1121 + dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60, false);
  1122 + store->registerClientAndKickExistingOne(dummyClient);
  1123 +
  1124 + uint16_t packetid = 66;
  1125 + for (int len = 0; len < 150; len++ )
  1126 + {
  1127 + const uint16_t pack_id = packetid++;
  1128 +
  1129 + std::vector<MqttPacket> parsedPackets;
  1130 +
  1131 + const std::string payloadOne = getSecureRandomString(len);
  1132 + Publish pubOne(topic, payloadOne, from_qos);
  1133 + pubOne.retain = retain;
  1134 + MqttPacket stagingPacketOne(pubOne);
  1135 + if (from_qos > 0)
  1136 + stagingPacketOne.setPacketId(pack_id);
  1137 + CirBuf stagingBufOne(1024);
  1138 + stagingPacketOne.readIntoBuf(stagingBufOne);
  1139 +
  1140 + MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient);
  1141 + QVERIFY(parsedPackets.size() == 1);
  1142 + MqttPacket parsedPacketOne = std::move(parsedPackets.front());
  1143 + parsedPacketOne.handlePublish();
  1144 + if (retain) // A normal handled packet always has retain=0, so I force setting it here.
  1145 + parsedPacketOne.setRetain();
  1146 + QCOMPARE(stagingPacketOne.getTopic(), parsedPacketOne.getTopic());
  1147 + QCOMPARE(stagingPacketOne.getPayloadCopy(), parsedPacketOne.getPayloadCopy());
  1148 +
  1149 + std::shared_ptr<MqttPacket> copiedPacketOne = parsedPacketOne.getCopy(to_qos);
  1150 +
  1151 + QCOMPARE(payloadOne, copiedPacketOne->getPayloadCopy());
  1152 +
  1153 + // Now compare the written buffer of our copied packet to one that was written with our known good reference packet.
  1154 +
  1155 + Publish pubReference(topic, payloadOne, to_qos);
  1156 + pubReference.retain = retain;
  1157 + MqttPacket packetReference(pubReference);
  1158 + if (to_qos > 0)
  1159 + packetReference.setPacketId(pack_id);
  1160 + CirBuf bufOfReference(1024);
  1161 + CirBuf bufOfCopied(1024);
  1162 + packetReference.readIntoBuf(bufOfReference);
  1163 + copiedPacketOne->readIntoBuf(bufOfCopied);
  1164 + QVERIFY2(bufOfCopied == bufOfReference, formatString("Failure on length %d for topic %s, from qos %d to qos %d, retain: %d.",
  1165 + len, topic.c_str(), from_qos, to_qos, retain).c_str());
  1166 + }
  1167 +}
  1168 +
  1169 +/**
  1170 + * @brief MainTests::testCopyPacket tests the actual bytes of a copied that would be written to a client.
  1171 + *
  1172 + * This is specifically to test the optimiziations in getCopy(). It indirectly also tests packet parsing.
  1173 + */
  1174 +void MainTests::testCopyPacket()
  1175 +{
  1176 + for (int retain = 0; retain < 2; retain++)
  1177 + {
  1178 + testCopyPacketHelper("John/McLane", 0, 0, retain);
  1179 + testCopyPacketHelper("Ben/Sisko", 1, 1, retain);
  1180 + testCopyPacketHelper("Rebecca/Bunch", 2, 2, retain);
  1181 +
  1182 + testCopyPacketHelper("Buffy/Slayer", 1, 0, retain);
  1183 + testCopyPacketHelper("Sarah/Connor", 2, 0, retain);
  1184 + testCopyPacketHelper("Susan/Mayer", 2, 1, retain);
  1185 + }
  1186 +}
  1187 +
  1188 +void testDowngradeQoSOnSubscribeHelper(const char pub_qos, const char sub_qos)
  1189 +{
  1190 + TwoClientTestContext testContext;
  1191 +
  1192 + const QString topic("Star/Trek");
  1193 + const QByteArray payload("Captain Kirk");
  1194 +
  1195 + testContext.connectSender();
  1196 + testContext.connectReceiver();
  1197 +
  1198 + testContext.subscribeReceiver(topic, sub_qos);
  1199 + testContext.publish(topic, payload, pub_qos, false);
  1200 +
  1201 + testContext.waitReceiverReceived(1);
  1202 +
  1203 + QCOMPARE(testContext.receivedMessages.length(), 1);
  1204 + QMQTT::Message &recv = testContext.receivedMessages.first();
  1205 +
  1206 + const char expected_qos = std::min<const char>(pub_qos, sub_qos);
  1207 + QVERIFY2(recv.qos() == expected_qos, formatString("Failure: received QoS is %d. Published is %d. Subscribed as %d. Expected QoS is %d",
  1208 + recv.qos(), pub_qos, sub_qos, expected_qos).c_str());
  1209 + QVERIFY(recv.topic() == topic);
  1210 + QVERIFY(recv.payload() == payload);
  1211 +}
  1212 +
  1213 +void MainTests::testDowngradeQoSOnSubscribeQos2to2()
  1214 +{
  1215 + testDowngradeQoSOnSubscribeHelper(2, 2);
  1216 +}
  1217 +
  1218 +void MainTests::testDowngradeQoSOnSubscribeQos2to1()
  1219 +{
  1220 + testDowngradeQoSOnSubscribeHelper(2, 1);
  1221 +}
  1222 +
  1223 +void MainTests::testDowngradeQoSOnSubscribeQos2to0()
  1224 +{
  1225 + testDowngradeQoSOnSubscribeHelper(2, 0);
  1226 +}
  1227 +
  1228 +void MainTests::testDowngradeQoSOnSubscribeQos1to1()
  1229 +{
  1230 + testDowngradeQoSOnSubscribeHelper(1, 1);
  1231 +}
  1232 +
  1233 +void MainTests::testDowngradeQoSOnSubscribeQos1to0()
  1234 +{
  1235 + testDowngradeQoSOnSubscribeHelper(1, 0);
  1236 +}
  1237 +
  1238 +void MainTests::testDowngradeQoSOnSubscribeQos0to0()
  1239 +{
  1240 + testDowngradeQoSOnSubscribeHelper(0, 0);
  1241 +}
  1242 +
  1243 +
1094 int main(int argc, char *argv[]) 1244 int main(int argc, char *argv[])
1095 { 1245 {
1096 QCoreApplication app(argc, argv); 1246 QCoreApplication app(argc, argv);
FlashMQTests/twoclienttestcontext.cpp
@@ -35,12 +35,22 @@ TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent) @@ -35,12 +35,22 @@ TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent)
35 connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError); 35 connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError);
36 } 36 }
37 37
  38 +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload)
  39 +{
  40 + publish(topic, payload, 0, false);
  41 +}
  42 +
38 void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, bool retain) 43 void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, bool retain)
39 { 44 {
  45 + publish(topic, payload, 0, retain);
  46 +}
  47 +
  48 +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain)
  49 +{
40 QMQTT::Message msg; 50 QMQTT::Message msg;
41 msg.setTopic(topic); 51 msg.setTopic(topic);
42 msg.setRetain(retain); 52 msg.setRetain(retain);
43 - msg.setQos(0); 53 + msg.setQos(qos);
44 msg.setPayload(payload); 54 msg.setPayload(payload);
45 sender->publish(msg); 55 sender->publish(msg);
46 } 56 }
@@ -71,9 +81,9 @@ void TwoClientTestContext::disconnectReceiver() @@ -71,9 +81,9 @@ void TwoClientTestContext::disconnectReceiver()
71 waiter.exec(); 81 waiter.exec();
72 } 82 }
73 83
74 -void TwoClientTestContext::subscribeReceiver(const QString &topic) 84 +void TwoClientTestContext::subscribeReceiver(const QString &topic, const quint8 qos)
75 { 85 {
76 - receiver->subscribe(topic); 86 + receiver->subscribe(topic, qos);
77 87
78 QEventLoop waiter; 88 QEventLoop waiter;
79 QTimer timeout; 89 QTimer timeout;
FlashMQTests/twoclienttestcontext.h
@@ -34,11 +34,13 @@ private slots: @@ -34,11 +34,13 @@ private slots:
34 34
35 public: 35 public:
36 explicit TwoClientTestContext(QObject *parent = nullptr); 36 explicit TwoClientTestContext(QObject *parent = nullptr);
37 - void publish(const QString &topic, const QByteArray &payload, bool retain = false); 37 + void publish(const QString &topic, const QByteArray &payload);
  38 + void publish(const QString &topic, const QByteArray &payload, bool retain);
  39 + void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain);
38 void connectSender(); 40 void connectSender();
39 void connectReceiver(); 41 void connectReceiver();
40 void disconnectReceiver(); 42 void disconnectReceiver();
41 - void subscribeReceiver(const QString &topic); 43 + void subscribeReceiver(const QString &topic, const quint8 qos = 0);
42 void waitReceiverReceived(int count); 44 void waitReceiverReceived(int count);
43 void onClientError(const QMQTT::ClientError error); 45 void onClientError(const QMQTT::ClientError error);
44 46
cirbuf.cpp
@@ -252,3 +252,22 @@ void CirBuf::read(void *buf, const size_t count) @@ -252,3 +252,22 @@ void CirBuf::read(void *buf, const size_t count)
252 assert(_packet_len == 0); 252 assert(_packet_len == 0);
253 assert(i == static_cast<int>(count)); 253 assert(i == static_cast<int>(count));
254 } 254 }
  255 +
  256 +/**
  257 + * @brief CirBuf::operator == simplistic comparision. It doesn't take the fact that it's circular into account.
  258 + * @param other
  259 + * @return
  260 + *
  261 + * It was created for unit testing. read() and write() are non-const, so taking the circular properties into account
  262 + * would need more/duplicate code that I don't need at this point.
  263 + */
  264 +bool CirBuf::operator==(const CirBuf &other) const
  265 +{
  266 +#ifdef NDEBUG
  267 + throw std::exception(); // you can't use it in release builds, because new buffers aren't zeroed.
  268 +#endif
  269 +
  270 + return tail == 0 && other.tail == 0
  271 + && usedBytes() == other.usedBytes()
  272 + && std::memcmp(buf, other.buf, size) == 0;
  273 +}
cirbuf.h
@@ -60,6 +60,8 @@ public: @@ -60,6 +60,8 @@ public:
60 60
61 void write(const void *buf, size_t count); 61 void write(const void *buf, size_t count);
62 void read(void *buf, const size_t count); 62 void read(void *buf, const size_t count);
  63 +
  64 + bool operator==(const CirBuf &other) const;
63 }; 65 };
64 66
65 #endif // CIRBUF_H 67 #endif // CIRBUF_H
client.cpp
@@ -56,7 +56,7 @@ Client::~Client() @@ -56,7 +56,7 @@ Client::~Client()
56 { 56 {
57 Publish will(will_topic, will_payload, will_qos); 57 Publish will(will_topic, will_payload, will_qos);
58 will.retain = will_retain; 58 will.retain = will_retain;
59 - const MqttPacket willPacket(will); 59 + MqttPacket willPacket(will);
60 60
61 const std::vector<std::string> subtopics = splitToVector(will_topic, '/'); 61 const std::vector<std::string> subtopics = splitToVector(will_topic, '/');
62 store->queuePacketAtSubscribers(subtopics, willPacket); 62 store->queuePacketAtSubscribers(subtopics, willPacket);
@@ -180,7 +180,7 @@ void Client::writeText(const std::string &amp;text) @@ -180,7 +180,7 @@ void Client::writeText(const std::string &amp;text)
180 setReadyForWriting(true); 180 setReadyForWriting(true);
181 } 181 }
182 182
183 -int Client::writeMqttPacket(const MqttPacket &packet, const char qos) 183 +int Client::writeMqttPacket(const MqttPacket &packet)
184 { 184 {
185 std::lock_guard<std::mutex> locker(writeBufMutex); 185 std::lock_guard<std::mutex> locker(writeBufMutex);
186 186
@@ -193,7 +193,7 @@ int Client::writeMqttPacket(const MqttPacket &amp;packet, const char qos) @@ -193,7 +193,7 @@ int Client::writeMqttPacket(const MqttPacket &amp;packet, const char qos)
193 193
194 // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And 194 // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And
195 // QoS packet are queued and limited elsewhere. 195 // QoS packet are queued and limited elsewhere.
196 - if (packet.packetType == PacketType::PUBLISH && qos == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) 196 + if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace())
197 { 197 {
198 return 0; 198 return 0;
199 } 199 }
@@ -208,11 +208,11 @@ int Client::writeMqttPacket(const MqttPacket &amp;packet, const char qos) @@ -208,11 +208,11 @@ int Client::writeMqttPacket(const MqttPacket &amp;packet, const char qos)
208 } 208 }
209 209
210 // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected. 210 // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected.
211 -int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos) 211 +int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet)
212 { 212 {
213 try 213 try
214 { 214 {
215 - return this->writeMqttPacket(packet, qos); 215 + return this->writeMqttPacket(packet);
216 } 216 }
217 catch (std::exception &ex) 217 catch (std::exception &ex)
218 { 218 {
client.h
@@ -121,8 +121,8 @@ public: @@ -121,8 +121,8 @@ public:
121 121
122 void writeText(const std::string &text); 122 void writeText(const std::string &text);
123 void writePingResp(); 123 void writePingResp();
124 - int writeMqttPacket(const MqttPacket &packet, const char qos = 0);  
125 - int writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos); 124 + int writeMqttPacket(const MqttPacket &packet);
  125 + int writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
126 bool writeBufIntoFd(); 126 bool writeBufIntoFd();
127 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; } 127 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; }
128 128
mqttpacket.cpp
@@ -48,11 +48,69 @@ MqttPacket::MqttPacket(CirBuf &amp;buf, size_t packet_len, size_t fixed_header_lengt @@ -48,11 +48,69 @@ MqttPacket::MqttPacket(CirBuf &amp;buf, size_t packet_len, size_t fixed_header_lengt
48 * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor 48 * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor
49 * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add. 49 * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add.
50 * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. 50 * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource.
  51 + *
  52 + * The idea is that because a packet with QoS is longer than one without, we just copy as much as possible if both packets have the same QoS.
  53 + *
  54 + * Note that there can be two types of packets: one with the fixed header (including remaining length), and one without. The latter we could be
  55 + * more clever about, but I'm forgoing that right now. Their use is mostly for retained messages.
  56 + *
  57 + * Also note that some fields are undeterminstic in the copy: dup, retain and packetid for instance. Sometimes they come from the original,
  58 + * sometimes not. The current planned usage is that those fields will either ONLY or NEVER be used in the copy, so it doesn't matter what I do
  59 + * with them here. I may reconsider.
51 */ 60 */
52 -std::shared_ptr<MqttPacket> MqttPacket::getCopy() const 61 +std::shared_ptr<MqttPacket> MqttPacket::getCopy(char new_max_qos) const
53 { 62 {
  63 + assert(packetType == PacketType::PUBLISH);
  64 +
  65 + // You're not supposed to copy a duplicate packet. The only packets that get the dup flag, should not be copied AGAIN. This
  66 + // has to do with the Session::writePacket() and Session::sendPendingQosMessages() logic.
  67 + assert((first_byte & 0b00001000) == 0);
  68 +
  69 + if (qos > 0 && new_max_qos == 0)
  70 + {
  71 + // if shrinking the packet doesn't alter the amount of bytes in the 'remaining length' part of the header, we can
  72 + // just memmove+shrink the packet. This is because the packet id always is two bytes before the payload, so we just move the payload
  73 + // over it. When testing 100M copies, it went from 21000 ms to 10000 ms. In other words, about 100 ยตs to 200 ยตs per copy.
  74 + // There is an elaborate unit test to test this optimization.
  75 + if ((fixed_header_length == 2 && bites.size() < 125))
  76 + {
  77 + // I don't know yet if this is true, but I don't want to forget when I implemenet MQTT5.
  78 + assert(sender && sender->getProtocolVersion() <= ProtocolVersion::Mqtt311);
  79 +
  80 + std::shared_ptr<MqttPacket> p(new MqttPacket(*this));
  81 + p->sender.reset();
  82 +
  83 + if (payloadLen > 0)
  84 + std::memmove(&p->bites[packet_id_pos], &p->bites[packet_id_pos+2], payloadLen);
  85 + p->bites.erase(p->bites.end() - 2, p->bites.end());
  86 + p->packet_id_pos = 0;
  87 + p->payloadStart -= 2;
  88 + if (pos > p->bites.size()) // pos can possible be set elsewhere, so we only set it back if it was after the payload.
  89 + p->pos -= 2;
  90 + p->packet_id = 0;
  91 +
  92 + // Clear QoS bits from the header.
  93 + p->first_byte &= 0b11111001;
  94 + p->bites[0] = p->first_byte;
  95 +
  96 + assert((p->bites[1] & 0b10000000) == 0); // when there is an MSB, I musn't get rid of it.
  97 + assert(p->bites[1] > 3); // There has to be a remaining value after subtracting 2.
  98 +
  99 + p->bites[1] -= 2; // Reduce the value in the 'remaining length' part of the header.
  100 +
  101 + return p;
  102 + }
  103 +
  104 + Publish pub(topic, getPayloadCopy(), new_max_qos);
  105 + pub.retain = getRetain();
  106 + std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(pub));
  107 + return copyPacket;
  108 + }
  109 +
54 std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this)); 110 std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this));
55 copyPacket->sender.reset(); 111 copyPacket->sender.reset();
  112 + if (qos != new_max_qos)
  113 + copyPacket->setQos(new_max_qos);
56 return copyPacket; 114 return copyPacket;
57 } 115 }
58 116
@@ -773,11 +831,11 @@ size_t MqttPacket::getTotalMemoryFootprint() @@ -773,11 +831,11 @@ size_t MqttPacket::getTotalMemoryFootprint()
773 * @return 831 * @return
774 * 832 *
775 * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out 833 * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out
776 - * the whole byte array that is a packet to subscribers. No need to copy and such. 834 + * the whole byte array of an original packet to subscribers. No need to copy and such.
777 * 835 *
778 - * I created it for saving QoS packages in the db file. 836 + * But, as stated, sometimes it's necessary.
779 */ 837 */
780 -std::string MqttPacket::getPayloadCopy() 838 +std::string MqttPacket::getPayloadCopy() const
781 { 839 {
782 assert(payloadStart > 0); 840 assert(payloadStart > 0);
783 assert(pos <= bites.size()); 841 assert(pos <= bites.size());
@@ -798,6 +856,23 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const @@ -798,6 +856,23 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const
798 return total; 856 return total;
799 } 857 }
800 858
  859 +void MqttPacket::setQos(const char new_qos)
  860 +{
  861 + // You can't change to a QoS level that would remove the packet identifier.
  862 + assert((qos == 0 && new_qos == 0) || (qos > 0 && new_qos > 0));
  863 + assert(new_qos > 0 && packet_id_pos > 0);
  864 +
  865 + qos = new_qos;
  866 + first_byte &= 0b11111001;
  867 + first_byte |= (qos << 1);
  868 +
  869 + if (fixed_header_length > 0)
  870 + {
  871 + pos = 0;
  872 + writeByte(first_byte);
  873 + }
  874 +}
  875 +
801 const std::string &MqttPacket::getTopic() const 876 const std::string &MqttPacket::getTopic() const
802 { 877 {
803 return this->topic; 878 return this->topic;
@@ -886,6 +961,31 @@ size_t MqttPacket::remainingAfterPos() @@ -886,6 +961,31 @@ size_t MqttPacket::remainingAfterPos()
886 return bites.size() - pos; 961 return bites.size() - pos;
887 } 962 }
888 963
  964 +bool MqttPacket::getRetain() const
  965 +{
  966 + return (first_byte & 0b00000001);
  967 +}
  968 +
  969 +/**
  970 + * @brief MqttPacket::setRetain set the retain bit in the first byte. I think I only need this in tests, because existing subscribers don't get retain=1,
  971 + * so handlePublish() clears it. But I needed it to be set in testing.
  972 + *
  973 + * Publishing of the retained messages goes through the MqttPacket(Publish) constructor, hence this setRetain() isn't necessary for that.
  974 + */
  975 +void MqttPacket::setRetain()
  976 +{
  977 +#ifndef TESTING
  978 + assert(false);
  979 +#endif
  980 +
  981 + first_byte |= 0b00000001;
  982 +
  983 + if (fixed_header_length > 0)
  984 + {
  985 + pos = 0;
  986 + writeByte(first_byte);
  987 + }
  988 +}
889 989
890 void MqttPacket::readIntoBuf(CirBuf &buf) const 990 void MqttPacket::readIntoBuf(CirBuf &buf) const
891 { 991 {
mqttpacket.h
@@ -80,7 +80,7 @@ public: @@ -80,7 +80,7 @@ public:
80 80
81 MqttPacket(MqttPacket &&other) = default; 81 MqttPacket(MqttPacket &&other) = default;
82 82
83 - std::shared_ptr<MqttPacket> getCopy() const; 83 + std::shared_ptr<MqttPacket> getCopy(char new_max_qos) const;
84 84
85 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. 85 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance.
86 MqttPacket(const ConnAck &connAck); 86 MqttPacket(const ConnAck &connAck);
@@ -109,6 +109,7 @@ public: @@ -109,6 +109,7 @@ public:
109 size_t getSizeIncludingNonPresentHeader() const; 109 size_t getSizeIncludingNonPresentHeader() const;
110 const std::vector<char> &getBites() const { return bites; } 110 const std::vector<char> &getBites() const { return bites; }
111 char getQos() const { return qos; } 111 char getQos() const { return qos; }
  112 + void setQos(const char new_qos);
112 const std::string &getTopic() const; 113 const std::string &getTopic() const;
113 const std::vector<std::string> &getSubtopics() const; 114 const std::vector<std::string> &getSubtopics() const;
114 std::shared_ptr<Client> getSender() const; 115 std::shared_ptr<Client> getSender() const;
@@ -120,8 +121,10 @@ public: @@ -120,8 +121,10 @@ public:
120 uint16_t getPacketId() const; 121 uint16_t getPacketId() const;
121 void setDuplicate(); 122 void setDuplicate();
122 size_t getTotalMemoryFootprint(); 123 size_t getTotalMemoryFootprint();
123 - std::string getPayloadCopy();  
124 void readIntoBuf(CirBuf &buf) const; 124 void readIntoBuf(CirBuf &buf) const;
  125 + std::string getPayloadCopy() const;
  126 + bool getRetain() const;
  127 + void setRetain();
125 }; 128 };
126 129
127 #endif // MQTTPACKET_H 130 #endif // MQTTPACKET_H
qospacketqueue.cpp
@@ -44,11 +44,11 @@ size_t QoSPacketQueue::getByteSize() const @@ -44,11 +44,11 @@ size_t QoSPacketQueue::getByteSize() const
44 * @param id 44 * @param id
45 * @return the packet copy. 45 * @return the packet copy.
46 */ 46 */
47 -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id) 47 +std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos)
48 { 48 {
49 assert(p.getQos() > 0); 49 assert(p.getQos() > 0);
50 50
51 - std::shared_ptr<MqttPacket> copyPacket = p.getCopy(); 51 + std::shared_ptr<MqttPacket> copyPacket = p.getCopy(new_max_qos);
52 copyPacket->setPacketId(id); 52 copyPacket->setPacketId(id);
53 queue.push_back(copyPacket); 53 queue.push_back(copyPacket);
54 qosQueueBytes += copyPacket->getTotalMemoryFootprint(); 54 qosQueueBytes += copyPacket->getTotalMemoryFootprint();
qospacketqueue.h
@@ -15,7 +15,7 @@ public: @@ -15,7 +15,7 @@ public:
15 void erase(const uint16_t packet_id); 15 void erase(const uint16_t packet_id);
16 size_t size() const; 16 size_t size() const;
17 size_t getByteSize() const; 17 size_t getByteSize() const;
18 - std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id); 18 + std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos);
19 std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id); 19 std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id);
20 20
21 std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const; 21 std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const;
session.cpp
@@ -59,6 +59,29 @@ void Session::setSessionTouch(int64_t ageInMs) @@ -59,6 +59,29 @@ void Session::setSessionTouch(int64_t ageInMs)
59 lastTouched = point; 59 lastTouched = point;
60 } 60 }
61 61
  62 +bool Session::requiresPacketRetransmission() const
  63 +{
  64 + const std::shared_ptr<Client> client = makeSharedClient();
  65 +
  66 + if (!client)
  67 + return true;
  68 +
  69 + // MQTT 3.1: "Brokers, however, should retry any unacknowledged message."
  70 + // MQTT 3.1.1: "This [reconnecting] is the only circumstance where a Client or Server is REQUIRED to redeliver messages."
  71 + if (client->getProtocolVersion() < ProtocolVersion::Mqtt311)
  72 + return true;
  73 +
  74 + // TODO: for MQTT5, the rules are different.
  75 + return !client->getCleanSession();
  76 +}
  77 +
  78 +void Session::increasePacketId()
  79 +{
  80 + nextPacketId++;
  81 + if (nextPacketId == 0)
  82 + nextPacketId++;
  83 +}
  84 +
62 /** 85 /**
63 * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies. 86 * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies.
64 * @param other 87 * @param other
@@ -116,63 +139,101 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client) @@ -116,63 +139,101 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
116 139
117 /** 140 /**
118 * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session. 141 * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session.
119 - * @param packet 142 + * @param packet is not const. We set the qos and packet id for each publish. This should be safe, because the packet
  143 + * with original packet id and qos is not saved. This saves unnecessary copying.
120 * @param max_qos 144 * @param max_qos
121 * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. 145 * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets.
122 * @param count. Reference value is updated. It's for statistics. 146 * @param count. Reference value is updated. It's for statistics.
123 */ 147 */
124 -void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count) 148 +void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count)
125 { 149 {
126 assert(max_qos <= 2); 150 assert(max_qos <= 2);
127 - const char qos = std::min<char>(packet.getQos(), max_qos); 151 + const char effectiveQos = std::min<char>(packet.getQos(), max_qos);
128 152
129 Authentication *_auth = ThreadAuth::getAuth(); 153 Authentication *_auth = ThreadAuth::getAuth();
130 assert(_auth); 154 assert(_auth);
131 Authentication &auth = *_auth; 155 Authentication &auth = *_auth;
132 - if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success) 156 + if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success)
133 { 157 {
134 - if (qos == 0) 158 + std::shared_ptr<Client> c = makeSharedClient();
  159 + if (effectiveQos == 0)
135 { 160 {
136 - std::shared_ptr<Client> c = makeSharedClient();  
137 -  
138 if (c) 161 if (c)
139 { 162 {
140 - count += c->writeMqttPacketAndBlameThisClient(packet, qos); 163 + const MqttPacket *packetToSend = &packet;
  164 +
  165 + if (max_qos < packet.getQos())
  166 + {
  167 + if (!downgradedQos0PacketCopy)
  168 + downgradedQos0PacketCopy = packet.getCopy(max_qos);
  169 + packetToSend = downgradedQos0PacketCopy.get();
  170 + }
  171 +
  172 + count += c->writeMqttPacketAndBlameThisClient(*packetToSend);
141 } 173 }
142 } 174 }
143 - else if (qos > 0) 175 + else if (effectiveQos > 0)
144 { 176 {
145 - std::unique_lock<std::mutex> locker(qosQueueMutex); 177 + const bool requiresRetransmission = requiresPacketRetransmission();
146 178
147 - const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();  
148 - if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) 179 + if (requiresRetransmission)
149 { 180 {
150 - if (QoSLogPrintedAtId != nextPacketId) 181 + std::unique_lock<std::mutex> locker(qosQueueMutex);
  182 +
  183 + const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();
  184 + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
151 { 185 {
152 - logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because its QoS buffers were full.", client_id.c_str());  
153 - QoSLogPrintedAtId = nextPacketId; 186 + if (QoSLogPrintedAtId != nextPacketId)
  187 + {
  188 + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because max in-transit packet count reached.", client_id.c_str());
  189 + QoSLogPrintedAtId = nextPacketId;
  190 + }
  191 + return;
154 } 192 }
155 - return;  
156 - }  
157 - nextPacketId++;  
158 - if (nextPacketId == 0)  
159 - nextPacketId++;  
160 193
161 - std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId);  
162 - locker.unlock(); 194 + increasePacketId();
  195 + std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId, effectiveQos);
163 196
164 - std::shared_ptr<Client> c = makeSharedClient();  
165 - if (c) 197 + if (c)
  198 + {
  199 + count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get());
  200 + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
  201 + }
  202 + }
  203 + else
166 { 204 {
167 - count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos);  
168 - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. 205 + // We don't need to make a copy of the packet in this branch, because:
  206 + // - The packet to give the client won't shrink in size because source and client have a packet_id.
  207 + // - We don't have to store the copy in the session for retransmission, see Session::requiresPacketRetransmission()
  208 + // So, we just keep altering the original published packet.
  209 +
  210 + std::unique_lock<std::mutex> locker(qosQueueMutex);
  211 +
  212 + if (qosInFlightCounter >= 65530) // Includes a small safety margin.
  213 + {
  214 + if (QoSLogPrintedAtId != nextPacketId)
  215 + {
  216 + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because it hasn't seen enough PUBACKs to release places.", client_id.c_str());
  217 + QoSLogPrintedAtId = nextPacketId;
  218 + }
  219 + return;
  220 + }
  221 +
  222 + increasePacketId();
  223 +
  224 + // This changes the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere,
  225 + // that should be fine.
  226 + packet.setPacketId(nextPacketId);
  227 + packet.setQos(effectiveQos);
  228 +
  229 + qosInFlightCounter++;
  230 + assert(c); // with requiresRetransmission==false, there must be a client.
  231 + c->writeMqttPacketAndBlameThisClient(packet);
169 } 232 }
170 } 233 }
171 } 234 }
172 } 235 }
173 236
174 -// Normatively, this while loop will break on the first element, because all messages are sent out in order and  
175 -// should be acked in order.  
176 void Session::clearQosMessage(uint16_t packet_id) 237 void Session::clearQosMessage(uint16_t packet_id)
177 { 238 {
178 #ifndef NDEBUG 239 #ifndef NDEBUG
@@ -180,7 +241,13 @@ void Session::clearQosMessage(uint16_t packet_id) @@ -180,7 +241,13 @@ void Session::clearQosMessage(uint16_t packet_id)
180 #endif 241 #endif
181 242
182 std::lock_guard<std::mutex> locker(qosQueueMutex); 243 std::lock_guard<std::mutex> locker(qosQueueMutex);
183 - qosPacketQueue.erase(packet_id); 244 + if (requiresPacketRetransmission())
  245 + qosPacketQueue.erase(packet_id);
  246 + else
  247 + {
  248 + qosInFlightCounter--;
  249 + qosInFlightCounter = std::max<int>(0, qosInFlightCounter); // Should never happen, but in case we receive too many PUBACKs.
  250 + }
184 } 251 }
185 252
186 // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any 253 // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any
@@ -199,7 +266,7 @@ uint64_t Session::sendPendingQosMessages() @@ -199,7 +266,7 @@ uint64_t Session::sendPendingQosMessages()
199 std::lock_guard<std::mutex> locker(qosQueueMutex); 266 std::lock_guard<std::mutex> locker(qosQueueMutex);
200 for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue) 267 for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue)
201 { 268 {
202 - count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos()); 269 + count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get());
203 qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate. 270 qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
204 } 271 }
205 272
@@ -207,7 +274,7 @@ uint64_t Session::sendPendingQosMessages() @@ -207,7 +274,7 @@ uint64_t Session::sendPendingQosMessages()
207 { 274 {
208 PubRel pubRel(packet_id); 275 PubRel pubRel(packet_id);
209 MqttPacket packet(pubRel); 276 MqttPacket packet(pubRel);
210 - count += c->writeMqttPacketAndBlameThisClient(packet, 2); 277 + count += c->writeMqttPacketAndBlameThisClient(packet);
211 } 278 }
212 } 279 }
213 280
session.h
@@ -48,11 +48,14 @@ class Session @@ -48,11 +48,14 @@ class Session
48 std::set<uint16_t> outgoingQoS2MessageIds; 48 std::set<uint16_t> outgoingQoS2MessageIds;
49 std::mutex qosQueueMutex; 49 std::mutex qosQueueMutex;
50 uint16_t nextPacketId = 0; 50 uint16_t nextPacketId = 0;
  51 + uint16_t qosInFlightCounter = 0;
51 uint16_t QoSLogPrintedAtId = 0; 52 uint16_t QoSLogPrintedAtId = 0;
52 std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now(); 53 std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now();
53 Logger *logger = Logger::getInstance(); 54 Logger *logger = Logger::getInstance();
54 int64_t getSessionRelativeAgeInMs() const; 55 int64_t getSessionRelativeAgeInMs() const;
55 void setSessionTouch(int64_t ageInMs); 56 void setSessionTouch(int64_t ageInMs);
  57 + bool requiresPacketRetransmission() const;
  58 + void increasePacketId();
56 59
57 Session(const Session &other); 60 Session(const Session &other);
58 public: 61 public:
@@ -69,7 +72,7 @@ public: @@ -69,7 +72,7 @@ public:
69 const std::string &getClientId() const { return client_id; } 72 const std::string &getClientId() const { return client_id; }
70 std::shared_ptr<Client> makeSharedClient() const; 73 std::shared_ptr<Client> makeSharedClient() const;
71 void assignActiveConnection(std::shared_ptr<Client> &client); 74 void assignActiveConnection(std::shared_ptr<Client> &client);
72 - void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count); 75 + void writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count);
73 void clearQosMessage(uint16_t packet_id); 76 void clearQosMessage(uint16_t packet_id);
74 uint64_t sendPendingQosMessages(); 77 uint64_t sendPendingQosMessages();
75 void touch(std::chrono::time_point<std::chrono::steady_clock> val); 78 void touch(std::chrono::time_point<std::chrono::steady_clock> val);
subscriptionstore.cpp
@@ -331,7 +331,7 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera @@ -331,7 +331,7 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera
331 } 331 }
332 } 332 }
333 333
334 -void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar) 334 +void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> &subtopics, MqttPacket &packet, bool dollar)
335 { 335 {
336 assert(subtopics.size() > 0); 336 assert(subtopics.size() > 0);
337 337
@@ -346,9 +346,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector&lt;std::string&gt; @@ -346,9 +346,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector&lt;std::string&gt;
346 publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions); 346 publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions);
347 } 347 }
348 348
  349 + std::shared_ptr<MqttPacket> possibleQos0Copy;
349 for(const ReceivingSubscriber &x : subscriberSessions) 350 for(const ReceivingSubscriber &x : subscriberSessions)
350 { 351 {
351 - x.session->writePacket(packet, x.qos, false, count); 352 + x.session->writePacket(packet, x.qos, possibleQos0Copy, count);
352 } 353 }
353 354
354 std::shared_ptr<Client> sender = packet.getSender(); 355 std::shared_ptr<Client> sender = packet.getSender();
@@ -422,9 +423,10 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Ses @@ -422,9 +423,10 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Ses
422 giveClientRetainedMessagesRecursively(subscribeSubtopics.begin(), subscribeSubtopics.end(), startNode, false, packetList); 423 giveClientRetainedMessagesRecursively(subscribeSubtopics.begin(), subscribeSubtopics.end(), startNode, false, packetList);
423 } 424 }
424 425
425 - for(const MqttPacket &packet : packetList) 426 + std::shared_ptr<MqttPacket> possibleQos0Copy;
  427 + for(MqttPacket &packet : packetList)
426 { 428 {
427 - ses->writePacket(packet, max_qos, true, count); 429 + ses->writePacket(packet, max_qos, possibleQos0Copy, count);
428 } 430 }
429 431
430 return count; 432 return count;
subscriptionstore.h
@@ -120,7 +120,7 @@ public: @@ -120,7 +120,7 @@ public:
120 void registerClientAndKickExistingOne(std::shared_ptr<Client> &client); 120 void registerClientAndKickExistingOne(std::shared_ptr<Client> &client);
121 bool sessionPresent(const std::string &clientid); 121 bool sessionPresent(const std::string &clientid);
122 122
123 - void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar = false); 123 + void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, MqttPacket &packet, bool dollar = false);
124 void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, 124 void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
125 RetainedMessageNode *this_node, bool poundMode, std::forward_list<MqttPacket> &packetList) const; 125 RetainedMessageNode *this_node, bool poundMode, std::forward_list<MqttPacket> &packetList) const;
126 uint64_t giveClientRetainedMessages(const std::shared_ptr<Session> &ses, const std::vector<std::string> &subscribeSubtopics, char max_qos); 126 uint64_t giveClientRetainedMessages(const std::shared_ptr<Session> &ses, const std::vector<std::string> &subscribeSubtopics, char max_qos);
threaddata.cpp
@@ -127,7 +127,8 @@ void ThreadData::publishStat(const std::string &amp;topic, uint64_t n) @@ -127,7 +127,8 @@ void ThreadData::publishStat(const std::string &amp;topic, uint64_t n)
127 splitTopic(topic, subtopics); 127 splitTopic(topic, subtopics);
128 const std::string payload = std::to_string(n); 128 const std::string payload = std::to_string(n);
129 Publish p(topic, payload, 0); 129 Publish p(topic, payload, 0);
130 - subscriptionStore->queuePacketAtSubscribers(subtopics, p, true); 130 + MqttPacket pack(p);
  131 + subscriptionStore->queuePacketAtSubscribers(subtopics, pack, true);
131 subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0); 132 subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0);
132 } 133 }
133 134