Commit f67238c5b794bb4ff5d3666b052bb6c80dad3ba2

Authored by Wiebe Cazemier
1 parent 4947d33a

Fix not downgrading QoS to subscription QoS

This entails making copies of the original packet when necessary,
because QoS 0 doesn't have a packet id. I tried to keep it to an
absolute minimum and do some precarious optmizations for it. There are
tests though.
FlashMQTests/tst_maintests.cpp
... ... @@ -102,6 +102,15 @@ private slots:
102 102  
103 103 void testSavingSessions();
104 104  
  105 + void testCopyPacket();
  106 +
  107 + void testDowngradeQoSOnSubscribeQos2to2();
  108 + void testDowngradeQoSOnSubscribeQos2to1();
  109 + void testDowngradeQoSOnSubscribeQos2to0();
  110 + void testDowngradeQoSOnSubscribeQos1to1();
  111 + void testDowngradeQoSOnSubscribeQos1to0();
  112 + void testDowngradeQoSOnSubscribeQos0to0();
  113 +
105 114 };
106 115  
107 116 MainTests::MainTests()
... ... @@ -1027,7 +1036,9 @@ void MainTests::testSavingSessions()
1027 1036  
1028 1037 std::shared_ptr<Session> c1ses = c1->getSession();
1029 1038 c1.reset();
1030   - c1ses->writePacket(publish, 1, false, count);
  1039 + MqttPacket publishPacket(publish);
  1040 + std::shared_ptr<MqttPacket> possibleQos0Copy;
  1041 + c1ses->writePacket(publishPacket, 1, possibleQos0Copy, count);
1031 1042  
1032 1043 store->saveSessionsAndSubscriptions("/tmp/flashmqtests_sessions.db");
1033 1044  
... ... @@ -1091,6 +1102,145 @@ void MainTests::testSavingSessions()
1091 1102 }
1092 1103 }
1093 1104  
  1105 +void testCopyPacketHelper(const std::string &topic, char from_qos, char to_qos, bool retain)
  1106 +{
  1107 + assert(to_qos <= from_qos);
  1108 +
  1109 + Logger::getInstance()->setFlags(false, false, true);
  1110 +
  1111 + std::shared_ptr<Settings> settings(new Settings());
  1112 + settings->logDebug = false;
  1113 + std::shared_ptr<SubscriptionStore> store(new SubscriptionStore());
  1114 + std::shared_ptr<ThreadData> t(new ThreadData(0, store, settings));
  1115 +
  1116 + // Kind of a hack...
  1117 + Authentication auth(*settings.get());
  1118 + ThreadAuth::assign(&auth);
  1119 +
  1120 + std::shared_ptr<Client> dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false));
  1121 + dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60, false);
  1122 + store->registerClientAndKickExistingOne(dummyClient);
  1123 +
  1124 + uint16_t packetid = 66;
  1125 + for (int len = 0; len < 150; len++ )
  1126 + {
  1127 + const uint16_t pack_id = packetid++;
  1128 +
  1129 + std::vector<MqttPacket> parsedPackets;
  1130 +
  1131 + const std::string payloadOne = getSecureRandomString(len);
  1132 + Publish pubOne(topic, payloadOne, from_qos);
  1133 + pubOne.retain = retain;
  1134 + MqttPacket stagingPacketOne(pubOne);
  1135 + if (from_qos > 0)
  1136 + stagingPacketOne.setPacketId(pack_id);
  1137 + CirBuf stagingBufOne(1024);
  1138 + stagingPacketOne.readIntoBuf(stagingBufOne);
  1139 +
  1140 + MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient);
  1141 + QVERIFY(parsedPackets.size() == 1);
  1142 + MqttPacket parsedPacketOne = std::move(parsedPackets.front());
  1143 + parsedPacketOne.handlePublish();
  1144 + if (retain) // A normal handled packet always has retain=0, so I force setting it here.
  1145 + parsedPacketOne.setRetain();
  1146 + QCOMPARE(stagingPacketOne.getTopic(), parsedPacketOne.getTopic());
  1147 + QCOMPARE(stagingPacketOne.getPayloadCopy(), parsedPacketOne.getPayloadCopy());
  1148 +
  1149 + std::shared_ptr<MqttPacket> copiedPacketOne = parsedPacketOne.getCopy(to_qos);
  1150 +
  1151 + QCOMPARE(payloadOne, copiedPacketOne->getPayloadCopy());
  1152 +
  1153 + // Now compare the written buffer of our copied packet to one that was written with our known good reference packet.
  1154 +
  1155 + Publish pubReference(topic, payloadOne, to_qos);
  1156 + pubReference.retain = retain;
  1157 + MqttPacket packetReference(pubReference);
  1158 + if (to_qos > 0)
  1159 + packetReference.setPacketId(pack_id);
  1160 + CirBuf bufOfReference(1024);
  1161 + CirBuf bufOfCopied(1024);
  1162 + packetReference.readIntoBuf(bufOfReference);
  1163 + copiedPacketOne->readIntoBuf(bufOfCopied);
  1164 + QVERIFY2(bufOfCopied == bufOfReference, formatString("Failure on length %d for topic %s, from qos %d to qos %d, retain: %d.",
  1165 + len, topic.c_str(), from_qos, to_qos, retain).c_str());
  1166 + }
  1167 +}
  1168 +
  1169 +/**
  1170 + * @brief MainTests::testCopyPacket tests the actual bytes of a copied that would be written to a client.
  1171 + *
  1172 + * This is specifically to test the optimiziations in getCopy(). It indirectly also tests packet parsing.
  1173 + */
  1174 +void MainTests::testCopyPacket()
  1175 +{
  1176 + for (int retain = 0; retain < 2; retain++)
  1177 + {
  1178 + testCopyPacketHelper("John/McLane", 0, 0, retain);
  1179 + testCopyPacketHelper("Ben/Sisko", 1, 1, retain);
  1180 + testCopyPacketHelper("Rebecca/Bunch", 2, 2, retain);
  1181 +
  1182 + testCopyPacketHelper("Buffy/Slayer", 1, 0, retain);
  1183 + testCopyPacketHelper("Sarah/Connor", 2, 0, retain);
  1184 + testCopyPacketHelper("Susan/Mayer", 2, 1, retain);
  1185 + }
  1186 +}
  1187 +
  1188 +void testDowngradeQoSOnSubscribeHelper(const char pub_qos, const char sub_qos)
  1189 +{
  1190 + TwoClientTestContext testContext;
  1191 +
  1192 + const QString topic("Star/Trek");
  1193 + const QByteArray payload("Captain Kirk");
  1194 +
  1195 + testContext.connectSender();
  1196 + testContext.connectReceiver();
  1197 +
  1198 + testContext.subscribeReceiver(topic, sub_qos);
  1199 + testContext.publish(topic, payload, pub_qos, false);
  1200 +
  1201 + testContext.waitReceiverReceived(1);
  1202 +
  1203 + QCOMPARE(testContext.receivedMessages.length(), 1);
  1204 + QMQTT::Message &recv = testContext.receivedMessages.first();
  1205 +
  1206 + const char expected_qos = std::min<const char>(pub_qos, sub_qos);
  1207 + QVERIFY2(recv.qos() == expected_qos, formatString("Failure: received QoS is %d. Published is %d. Subscribed as %d. Expected QoS is %d",
  1208 + recv.qos(), pub_qos, sub_qos, expected_qos).c_str());
  1209 + QVERIFY(recv.topic() == topic);
  1210 + QVERIFY(recv.payload() == payload);
  1211 +}
  1212 +
  1213 +void MainTests::testDowngradeQoSOnSubscribeQos2to2()
  1214 +{
  1215 + testDowngradeQoSOnSubscribeHelper(2, 2);
  1216 +}
  1217 +
  1218 +void MainTests::testDowngradeQoSOnSubscribeQos2to1()
  1219 +{
  1220 + testDowngradeQoSOnSubscribeHelper(2, 1);
  1221 +}
  1222 +
  1223 +void MainTests::testDowngradeQoSOnSubscribeQos2to0()
  1224 +{
  1225 + testDowngradeQoSOnSubscribeHelper(2, 0);
  1226 +}
  1227 +
  1228 +void MainTests::testDowngradeQoSOnSubscribeQos1to1()
  1229 +{
  1230 + testDowngradeQoSOnSubscribeHelper(1, 1);
  1231 +}
  1232 +
  1233 +void MainTests::testDowngradeQoSOnSubscribeQos1to0()
  1234 +{
  1235 + testDowngradeQoSOnSubscribeHelper(1, 0);
  1236 +}
  1237 +
  1238 +void MainTests::testDowngradeQoSOnSubscribeQos0to0()
  1239 +{
  1240 + testDowngradeQoSOnSubscribeHelper(0, 0);
  1241 +}
  1242 +
  1243 +
1094 1244 int main(int argc, char *argv[])
1095 1245 {
1096 1246 QCoreApplication app(argc, argv);
... ...
FlashMQTests/twoclienttestcontext.cpp
... ... @@ -35,12 +35,22 @@ TwoClientTestContext::TwoClientTestContext(QObject *parent) : QObject(parent)
35 35 connect(receiver.data(), &QMQTT::Client::error, this, &TwoClientTestContext::onClientError);
36 36 }
37 37  
  38 +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload)
  39 +{
  40 + publish(topic, payload, 0, false);
  41 +}
  42 +
38 43 void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, bool retain)
39 44 {
  45 + publish(topic, payload, 0, retain);
  46 +}
  47 +
  48 +void TwoClientTestContext::publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain)
  49 +{
40 50 QMQTT::Message msg;
41 51 msg.setTopic(topic);
42 52 msg.setRetain(retain);
43   - msg.setQos(0);
  53 + msg.setQos(qos);
44 54 msg.setPayload(payload);
45 55 sender->publish(msg);
46 56 }
... ... @@ -71,9 +81,9 @@ void TwoClientTestContext::disconnectReceiver()
71 81 waiter.exec();
72 82 }
73 83  
74   -void TwoClientTestContext::subscribeReceiver(const QString &topic)
  84 +void TwoClientTestContext::subscribeReceiver(const QString &topic, const quint8 qos)
75 85 {
76   - receiver->subscribe(topic);
  86 + receiver->subscribe(topic, qos);
77 87  
78 88 QEventLoop waiter;
79 89 QTimer timeout;
... ...
FlashMQTests/twoclienttestcontext.h
... ... @@ -34,11 +34,13 @@ private slots:
34 34  
35 35 public:
36 36 explicit TwoClientTestContext(QObject *parent = nullptr);
37   - void publish(const QString &topic, const QByteArray &payload, bool retain = false);
  37 + void publish(const QString &topic, const QByteArray &payload);
  38 + void publish(const QString &topic, const QByteArray &payload, bool retain);
  39 + void publish(const QString &topic, const QByteArray &payload, const quint8 qos, bool retain);
38 40 void connectSender();
39 41 void connectReceiver();
40 42 void disconnectReceiver();
41   - void subscribeReceiver(const QString &topic);
  43 + void subscribeReceiver(const QString &topic, const quint8 qos = 0);
42 44 void waitReceiverReceived(int count);
43 45 void onClientError(const QMQTT::ClientError error);
44 46  
... ...
cirbuf.cpp
... ... @@ -252,3 +252,22 @@ void CirBuf::read(void *buf, const size_t count)
252 252 assert(_packet_len == 0);
253 253 assert(i == static_cast<int>(count));
254 254 }
  255 +
  256 +/**
  257 + * @brief CirBuf::operator == simplistic comparision. It doesn't take the fact that it's circular into account.
  258 + * @param other
  259 + * @return
  260 + *
  261 + * It was created for unit testing. read() and write() are non-const, so taking the circular properties into account
  262 + * would need more/duplicate code that I don't need at this point.
  263 + */
  264 +bool CirBuf::operator==(const CirBuf &other) const
  265 +{
  266 +#ifdef NDEBUG
  267 + throw std::exception(); // you can't use it in release builds, because new buffers aren't zeroed.
  268 +#endif
  269 +
  270 + return tail == 0 && other.tail == 0
  271 + && usedBytes() == other.usedBytes()
  272 + && std::memcmp(buf, other.buf, size) == 0;
  273 +}
... ...
cirbuf.h
... ... @@ -60,6 +60,8 @@ public:
60 60  
61 61 void write(const void *buf, size_t count);
62 62 void read(void *buf, const size_t count);
  63 +
  64 + bool operator==(const CirBuf &other) const;
63 65 };
64 66  
65 67 #endif // CIRBUF_H
... ...
client.cpp
... ... @@ -56,7 +56,7 @@ Client::~Client()
56 56 {
57 57 Publish will(will_topic, will_payload, will_qos);
58 58 will.retain = will_retain;
59   - const MqttPacket willPacket(will);
  59 + MqttPacket willPacket(will);
60 60  
61 61 const std::vector<std::string> subtopics = splitToVector(will_topic, '/');
62 62 store->queuePacketAtSubscribers(subtopics, willPacket);
... ... @@ -180,7 +180,7 @@ void Client::writeText(const std::string &amp;text)
180 180 setReadyForWriting(true);
181 181 }
182 182  
183   -int Client::writeMqttPacket(const MqttPacket &packet, const char qos)
  183 +int Client::writeMqttPacket(const MqttPacket &packet)
184 184 {
185 185 std::lock_guard<std::mutex> locker(writeBufMutex);
186 186  
... ... @@ -193,7 +193,7 @@ int Client::writeMqttPacket(const MqttPacket &amp;packet, const char qos)
193 193  
194 194 // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And
195 195 // QoS packet are queued and limited elsewhere.
196   - if (packet.packetType == PacketType::PUBLISH && qos == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace())
  196 + if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace())
197 197 {
198 198 return 0;
199 199 }
... ... @@ -208,11 +208,11 @@ int Client::writeMqttPacket(const MqttPacket &amp;packet, const char qos)
208 208 }
209 209  
210 210 // Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected.
211   -int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos)
  211 +int Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet)
212 212 {
213 213 try
214 214 {
215   - return this->writeMqttPacket(packet, qos);
  215 + return this->writeMqttPacket(packet);
216 216 }
217 217 catch (std::exception &ex)
218 218 {
... ...
client.h
... ... @@ -121,8 +121,8 @@ public:
121 121  
122 122 void writeText(const std::string &text);
123 123 void writePingResp();
124   - int writeMqttPacket(const MqttPacket &packet, const char qos = 0);
125   - int writeMqttPacketAndBlameThisClient(const MqttPacket &packet, const char qos);
  124 + int writeMqttPacket(const MqttPacket &packet);
  125 + int writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
126 126 bool writeBufIntoFd();
127 127 bool readyForDisconnecting() const { return disconnectWhenBytesWritten && writebuf.usedBytes() == 0; }
128 128  
... ...
mqttpacket.cpp
... ... @@ -48,11 +48,69 @@ MqttPacket::MqttPacket(CirBuf &amp;buf, size_t packet_len, size_t fixed_header_lengt
48 48 * @brief MqttPacket::getCopy (using default copy constructor and resetting some selected fields) is easier than using the copy constructor
49 49 * publically, because then I have to keep maintaining a functioning copy constructor for each new field I add.
50 50 * @return a shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource.
  51 + *
  52 + * The idea is that because a packet with QoS is longer than one without, we just copy as much as possible if both packets have the same QoS.
  53 + *
  54 + * Note that there can be two types of packets: one with the fixed header (including remaining length), and one without. The latter we could be
  55 + * more clever about, but I'm forgoing that right now. Their use is mostly for retained messages.
  56 + *
  57 + * Also note that some fields are undeterminstic in the copy: dup, retain and packetid for instance. Sometimes they come from the original,
  58 + * sometimes not. The current planned usage is that those fields will either ONLY or NEVER be used in the copy, so it doesn't matter what I do
  59 + * with them here. I may reconsider.
51 60 */
52   -std::shared_ptr<MqttPacket> MqttPacket::getCopy() const
  61 +std::shared_ptr<MqttPacket> MqttPacket::getCopy(char new_max_qos) const
53 62 {
  63 + assert(packetType == PacketType::PUBLISH);
  64 +
  65 + // You're not supposed to copy a duplicate packet. The only packets that get the dup flag, should not be copied AGAIN. This
  66 + // has to do with the Session::writePacket() and Session::sendPendingQosMessages() logic.
  67 + assert((first_byte & 0b00001000) == 0);
  68 +
  69 + if (qos > 0 && new_max_qos == 0)
  70 + {
  71 + // if shrinking the packet doesn't alter the amount of bytes in the 'remaining length' part of the header, we can
  72 + // just memmove+shrink the packet. This is because the packet id always is two bytes before the payload, so we just move the payload
  73 + // over it. When testing 100M copies, it went from 21000 ms to 10000 ms. In other words, about 100 ยตs to 200 ยตs per copy.
  74 + // There is an elaborate unit test to test this optimization.
  75 + if ((fixed_header_length == 2 && bites.size() < 125))
  76 + {
  77 + // I don't know yet if this is true, but I don't want to forget when I implemenet MQTT5.
  78 + assert(sender && sender->getProtocolVersion() <= ProtocolVersion::Mqtt311);
  79 +
  80 + std::shared_ptr<MqttPacket> p(new MqttPacket(*this));
  81 + p->sender.reset();
  82 +
  83 + if (payloadLen > 0)
  84 + std::memmove(&p->bites[packet_id_pos], &p->bites[packet_id_pos+2], payloadLen);
  85 + p->bites.erase(p->bites.end() - 2, p->bites.end());
  86 + p->packet_id_pos = 0;
  87 + p->payloadStart -= 2;
  88 + if (pos > p->bites.size()) // pos can possible be set elsewhere, so we only set it back if it was after the payload.
  89 + p->pos -= 2;
  90 + p->packet_id = 0;
  91 +
  92 + // Clear QoS bits from the header.
  93 + p->first_byte &= 0b11111001;
  94 + p->bites[0] = p->first_byte;
  95 +
  96 + assert((p->bites[1] & 0b10000000) == 0); // when there is an MSB, I musn't get rid of it.
  97 + assert(p->bites[1] > 3); // There has to be a remaining value after subtracting 2.
  98 +
  99 + p->bites[1] -= 2; // Reduce the value in the 'remaining length' part of the header.
  100 +
  101 + return p;
  102 + }
  103 +
  104 + Publish pub(topic, getPayloadCopy(), new_max_qos);
  105 + pub.retain = getRetain();
  106 + std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(pub));
  107 + return copyPacket;
  108 + }
  109 +
54 110 std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this));
55 111 copyPacket->sender.reset();
  112 + if (qos != new_max_qos)
  113 + copyPacket->setQos(new_max_qos);
56 114 return copyPacket;
57 115 }
58 116  
... ... @@ -773,11 +831,11 @@ size_t MqttPacket::getTotalMemoryFootprint()
773 831 * @return
774 832 *
775 833 * It's necessary sometimes, but it's against FlashMQ's concept of not parsing the payload. Normally, you can just write out
776   - * the whole byte array that is a packet to subscribers. No need to copy and such.
  834 + * the whole byte array of an original packet to subscribers. No need to copy and such.
777 835 *
778   - * I created it for saving QoS packages in the db file.
  836 + * But, as stated, sometimes it's necessary.
779 837 */
780   -std::string MqttPacket::getPayloadCopy()
  838 +std::string MqttPacket::getPayloadCopy() const
781 839 {
782 840 assert(payloadStart > 0);
783 841 assert(pos <= bites.size());
... ... @@ -798,6 +856,23 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const
798 856 return total;
799 857 }
800 858  
  859 +void MqttPacket::setQos(const char new_qos)
  860 +{
  861 + // You can't change to a QoS level that would remove the packet identifier.
  862 + assert((qos == 0 && new_qos == 0) || (qos > 0 && new_qos > 0));
  863 + assert(new_qos > 0 && packet_id_pos > 0);
  864 +
  865 + qos = new_qos;
  866 + first_byte &= 0b11111001;
  867 + first_byte |= (qos << 1);
  868 +
  869 + if (fixed_header_length > 0)
  870 + {
  871 + pos = 0;
  872 + writeByte(first_byte);
  873 + }
  874 +}
  875 +
801 876 const std::string &MqttPacket::getTopic() const
802 877 {
803 878 return this->topic;
... ... @@ -886,6 +961,31 @@ size_t MqttPacket::remainingAfterPos()
886 961 return bites.size() - pos;
887 962 }
888 963  
  964 +bool MqttPacket::getRetain() const
  965 +{
  966 + return (first_byte & 0b00000001);
  967 +}
  968 +
  969 +/**
  970 + * @brief MqttPacket::setRetain set the retain bit in the first byte. I think I only need this in tests, because existing subscribers don't get retain=1,
  971 + * so handlePublish() clears it. But I needed it to be set in testing.
  972 + *
  973 + * Publishing of the retained messages goes through the MqttPacket(Publish) constructor, hence this setRetain() isn't necessary for that.
  974 + */
  975 +void MqttPacket::setRetain()
  976 +{
  977 +#ifndef TESTING
  978 + assert(false);
  979 +#endif
  980 +
  981 + first_byte |= 0b00000001;
  982 +
  983 + if (fixed_header_length > 0)
  984 + {
  985 + pos = 0;
  986 + writeByte(first_byte);
  987 + }
  988 +}
889 989  
890 990 void MqttPacket::readIntoBuf(CirBuf &buf) const
891 991 {
... ...
mqttpacket.h
... ... @@ -80,7 +80,7 @@ public:
80 80  
81 81 MqttPacket(MqttPacket &&other) = default;
82 82  
83   - std::shared_ptr<MqttPacket> getCopy() const;
  83 + std::shared_ptr<MqttPacket> getCopy(char new_max_qos) const;
84 84  
85 85 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance.
86 86 MqttPacket(const ConnAck &connAck);
... ... @@ -109,6 +109,7 @@ public:
109 109 size_t getSizeIncludingNonPresentHeader() const;
110 110 const std::vector<char> &getBites() const { return bites; }
111 111 char getQos() const { return qos; }
  112 + void setQos(const char new_qos);
112 113 const std::string &getTopic() const;
113 114 const std::vector<std::string> &getSubtopics() const;
114 115 std::shared_ptr<Client> getSender() const;
... ... @@ -120,8 +121,10 @@ public:
120 121 uint16_t getPacketId() const;
121 122 void setDuplicate();
122 123 size_t getTotalMemoryFootprint();
123   - std::string getPayloadCopy();
124 124 void readIntoBuf(CirBuf &buf) const;
  125 + std::string getPayloadCopy() const;
  126 + bool getRetain() const;
  127 + void setRetain();
125 128 };
126 129  
127 130 #endif // MQTTPACKET_H
... ...
qospacketqueue.cpp
... ... @@ -44,11 +44,11 @@ size_t QoSPacketQueue::getByteSize() const
44 44 * @param id
45 45 * @return the packet copy.
46 46 */
47   -std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id)
  47 +std::shared_ptr<MqttPacket> QoSPacketQueue::queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos)
48 48 {
49 49 assert(p.getQos() > 0);
50 50  
51   - std::shared_ptr<MqttPacket> copyPacket = p.getCopy();
  51 + std::shared_ptr<MqttPacket> copyPacket = p.getCopy(new_max_qos);
52 52 copyPacket->setPacketId(id);
53 53 queue.push_back(copyPacket);
54 54 qosQueueBytes += copyPacket->getTotalMemoryFootprint();
... ...
qospacketqueue.h
... ... @@ -15,7 +15,7 @@ public:
15 15 void erase(const uint16_t packet_id);
16 16 size_t size() const;
17 17 size_t getByteSize() const;
18   - std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id);
  18 + std::shared_ptr<MqttPacket> queuePacket(const MqttPacket &p, uint16_t id, char new_max_qos);
19 19 std::shared_ptr<MqttPacket> queuePacket(const Publish &pub, uint16_t id);
20 20  
21 21 std::list<std::shared_ptr<MqttPacket>>::const_iterator begin() const;
... ...
session.cpp
... ... @@ -59,6 +59,29 @@ void Session::setSessionTouch(int64_t ageInMs)
59 59 lastTouched = point;
60 60 }
61 61  
  62 +bool Session::requiresPacketRetransmission() const
  63 +{
  64 + const std::shared_ptr<Client> client = makeSharedClient();
  65 +
  66 + if (!client)
  67 + return true;
  68 +
  69 + // MQTT 3.1: "Brokers, however, should retry any unacknowledged message."
  70 + // MQTT 3.1.1: "This [reconnecting] is the only circumstance where a Client or Server is REQUIRED to redeliver messages."
  71 + if (client->getProtocolVersion() < ProtocolVersion::Mqtt311)
  72 + return true;
  73 +
  74 + // TODO: for MQTT5, the rules are different.
  75 + return !client->getCleanSession();
  76 +}
  77 +
  78 +void Session::increasePacketId()
  79 +{
  80 + nextPacketId++;
  81 + if (nextPacketId == 0)
  82 + nextPacketId++;
  83 +}
  84 +
62 85 /**
63 86 * @brief Session::Session copy constructor. Was created for session storing, and is explicitely kept private, to avoid making accidental copies.
64 87 * @param other
... ... @@ -116,63 +139,101 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
116 139  
117 140 /**
118 141 * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session.
119   - * @param packet
  142 + * @param packet is not const. We set the qos and packet id for each publish. This should be safe, because the packet
  143 + * with original packet id and qos is not saved. This saves unnecessary copying.
120 144 * @param max_qos
121 145 * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets.
122 146 * @param count. Reference value is updated. It's for statistics.
123 147 */
124   -void Session::writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count)
  148 +void Session::writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count)
125 149 {
126 150 assert(max_qos <= 2);
127   - const char qos = std::min<char>(packet.getQos(), max_qos);
  151 + const char effectiveQos = std::min<char>(packet.getQos(), max_qos);
128 152  
129 153 Authentication *_auth = ThreadAuth::getAuth();
130 154 assert(_auth);
131 155 Authentication &auth = *_auth;
132   - if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, qos, retain) == AuthResult::success)
  156 + if (auth.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read, effectiveQos, packet.getRetain()) == AuthResult::success)
133 157 {
134   - if (qos == 0)
  158 + std::shared_ptr<Client> c = makeSharedClient();
  159 + if (effectiveQos == 0)
135 160 {
136   - std::shared_ptr<Client> c = makeSharedClient();
137   -
138 161 if (c)
139 162 {
140   - count += c->writeMqttPacketAndBlameThisClient(packet, qos);
  163 + const MqttPacket *packetToSend = &packet;
  164 +
  165 + if (max_qos < packet.getQos())
  166 + {
  167 + if (!downgradedQos0PacketCopy)
  168 + downgradedQos0PacketCopy = packet.getCopy(max_qos);
  169 + packetToSend = downgradedQos0PacketCopy.get();
  170 + }
  171 +
  172 + count += c->writeMqttPacketAndBlameThisClient(*packetToSend);
141 173 }
142 174 }
143   - else if (qos > 0)
  175 + else if (effectiveQos > 0)
144 176 {
145   - std::unique_lock<std::mutex> locker(qosQueueMutex);
  177 + const bool requiresRetransmission = requiresPacketRetransmission();
146 178  
147   - const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();
148   - if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
  179 + if (requiresRetransmission)
149 180 {
150   - if (QoSLogPrintedAtId != nextPacketId)
  181 + std::unique_lock<std::mutex> locker(qosQueueMutex);
  182 +
  183 + const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();
  184 + if (totalQosPacketsInTransit >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosPacketQueue.getByteSize() >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
151 185 {
152   - logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because its QoS buffers were full.", client_id.c_str());
153   - QoSLogPrintedAtId = nextPacketId;
  186 + if (QoSLogPrintedAtId != nextPacketId)
  187 + {
  188 + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because max in-transit packet count reached.", client_id.c_str());
  189 + QoSLogPrintedAtId = nextPacketId;
  190 + }
  191 + return;
154 192 }
155   - return;
156   - }
157   - nextPacketId++;
158   - if (nextPacketId == 0)
159   - nextPacketId++;
160 193  
161   - std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId);
162   - locker.unlock();
  194 + increasePacketId();
  195 + std::shared_ptr<MqttPacket> copyPacket = qosPacketQueue.queuePacket(packet, nextPacketId, effectiveQos);
163 196  
164   - std::shared_ptr<Client> c = makeSharedClient();
165   - if (c)
  197 + if (c)
  198 + {
  199 + count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get());
  200 + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
  201 + }
  202 + }
  203 + else
166 204 {
167   - count += c->writeMqttPacketAndBlameThisClient(*copyPacket.get(), qos);
168   - copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
  205 + // We don't need to make a copy of the packet in this branch, because:
  206 + // - The packet to give the client won't shrink in size because source and client have a packet_id.
  207 + // - We don't have to store the copy in the session for retransmission, see Session::requiresPacketRetransmission()
  208 + // So, we just keep altering the original published packet.
  209 +
  210 + std::unique_lock<std::mutex> locker(qosQueueMutex);
  211 +
  212 + if (qosInFlightCounter >= 65530) // Includes a small safety margin.
  213 + {
  214 + if (QoSLogPrintedAtId != nextPacketId)
  215 + {
  216 + logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because it hasn't seen enough PUBACKs to release places.", client_id.c_str());
  217 + QoSLogPrintedAtId = nextPacketId;
  218 + }
  219 + return;
  220 + }
  221 +
  222 + increasePacketId();
  223 +
  224 + // This changes the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere,
  225 + // that should be fine.
  226 + packet.setPacketId(nextPacketId);
  227 + packet.setQos(effectiveQos);
  228 +
  229 + qosInFlightCounter++;
  230 + assert(c); // with requiresRetransmission==false, there must be a client.
  231 + c->writeMqttPacketAndBlameThisClient(packet);
169 232 }
170 233 }
171 234 }
172 235 }
173 236  
174   -// Normatively, this while loop will break on the first element, because all messages are sent out in order and
175   -// should be acked in order.
176 237 void Session::clearQosMessage(uint16_t packet_id)
177 238 {
178 239 #ifndef NDEBUG
... ... @@ -180,7 +241,13 @@ void Session::clearQosMessage(uint16_t packet_id)
180 241 #endif
181 242  
182 243 std::lock_guard<std::mutex> locker(qosQueueMutex);
183   - qosPacketQueue.erase(packet_id);
  244 + if (requiresPacketRetransmission())
  245 + qosPacketQueue.erase(packet_id);
  246 + else
  247 + {
  248 + qosInFlightCounter--;
  249 + qosInFlightCounter = std::max<int>(0, qosInFlightCounter); // Should never happen, but in case we receive too many PUBACKs.
  250 + }
184 251 }
185 252  
186 253 // [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any
... ... @@ -199,7 +266,7 @@ uint64_t Session::sendPendingQosMessages()
199 266 std::lock_guard<std::mutex> locker(qosQueueMutex);
200 267 for (const std::shared_ptr<MqttPacket> &qosMessage : qosPacketQueue)
201 268 {
202   - count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get(), qosMessage->getQos());
  269 + count += c->writeMqttPacketAndBlameThisClient(*qosMessage.get());
203 270 qosMessage->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
204 271 }
205 272  
... ... @@ -207,7 +274,7 @@ uint64_t Session::sendPendingQosMessages()
207 274 {
208 275 PubRel pubRel(packet_id);
209 276 MqttPacket packet(pubRel);
210   - count += c->writeMqttPacketAndBlameThisClient(packet, 2);
  277 + count += c->writeMqttPacketAndBlameThisClient(packet);
211 278 }
212 279 }
213 280  
... ...
session.h
... ... @@ -48,11 +48,14 @@ class Session
48 48 std::set<uint16_t> outgoingQoS2MessageIds;
49 49 std::mutex qosQueueMutex;
50 50 uint16_t nextPacketId = 0;
  51 + uint16_t qosInFlightCounter = 0;
51 52 uint16_t QoSLogPrintedAtId = 0;
52 53 std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now();
53 54 Logger *logger = Logger::getInstance();
54 55 int64_t getSessionRelativeAgeInMs() const;
55 56 void setSessionTouch(int64_t ageInMs);
  57 + bool requiresPacketRetransmission() const;
  58 + void increasePacketId();
56 59  
57 60 Session(const Session &other);
58 61 public:
... ... @@ -69,7 +72,7 @@ public:
69 72 const std::string &getClientId() const { return client_id; }
70 73 std::shared_ptr<Client> makeSharedClient() const;
71 74 void assignActiveConnection(std::shared_ptr<Client> &client);
72   - void writePacket(const MqttPacket &packet, char max_qos, bool retain, uint64_t &count);
  75 + void writePacket(MqttPacket &packet, char max_qos, std::shared_ptr<MqttPacket> &downgradedQos0PacketCopy, uint64_t &count);
73 76 void clearQosMessage(uint16_t packet_id);
74 77 uint64_t sendPendingQosMessages();
75 78 void touch(std::chrono::time_point<std::chrono::steady_clock> val);
... ...
subscriptionstore.cpp
... ... @@ -331,7 +331,7 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera
331 331 }
332 332 }
333 333  
334   -void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar)
  334 +void SubscriptionStore::queuePacketAtSubscribers(const std::vector<std::string> &subtopics, MqttPacket &packet, bool dollar)
335 335 {
336 336 assert(subtopics.size() > 0);
337 337  
... ... @@ -346,9 +346,10 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector&lt;std::string&gt;
346 346 publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions);
347 347 }
348 348  
  349 + std::shared_ptr<MqttPacket> possibleQos0Copy;
349 350 for(const ReceivingSubscriber &x : subscriberSessions)
350 351 {
351   - x.session->writePacket(packet, x.qos, false, count);
  352 + x.session->writePacket(packet, x.qos, possibleQos0Copy, count);
352 353 }
353 354  
354 355 std::shared_ptr<Client> sender = packet.getSender();
... ... @@ -422,9 +423,10 @@ uint64_t SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr&lt;Ses
422 423 giveClientRetainedMessagesRecursively(subscribeSubtopics.begin(), subscribeSubtopics.end(), startNode, false, packetList);
423 424 }
424 425  
425   - for(const MqttPacket &packet : packetList)
  426 + std::shared_ptr<MqttPacket> possibleQos0Copy;
  427 + for(MqttPacket &packet : packetList)
426 428 {
427   - ses->writePacket(packet, max_qos, true, count);
  429 + ses->writePacket(packet, max_qos, possibleQos0Copy, count);
428 430 }
429 431  
430 432 return count;
... ...
subscriptionstore.h
... ... @@ -120,7 +120,7 @@ public:
120 120 void registerClientAndKickExistingOne(std::shared_ptr<Client> &client);
121 121 bool sessionPresent(const std::string &clientid);
122 122  
123   - void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, const MqttPacket &packet, bool dollar = false);
  123 + void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, MqttPacket &packet, bool dollar = false);
124 124 void giveClientRetainedMessagesRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
125 125 RetainedMessageNode *this_node, bool poundMode, std::forward_list<MqttPacket> &packetList) const;
126 126 uint64_t giveClientRetainedMessages(const std::shared_ptr<Session> &ses, const std::vector<std::string> &subscribeSubtopics, char max_qos);
... ...
threaddata.cpp
... ... @@ -127,7 +127,8 @@ void ThreadData::publishStat(const std::string &amp;topic, uint64_t n)
127 127 splitTopic(topic, subtopics);
128 128 const std::string payload = std::to_string(n);
129 129 Publish p(topic, payload, 0);
130   - subscriptionStore->queuePacketAtSubscribers(subtopics, p, true);
  130 + MqttPacket pack(p);
  131 + subscriptionStore->queuePacketAtSubscribers(subtopics, pack, true);
131 132 subscriptionStore->setRetainedMessage(topic, subtopics, payload, 0);
132 133 }
133 134  
... ...