Commit ee79271a6ccb3d8e9013f7a31d030165b2f4aee2

Authored by Wiebe Cazemier
1 parent 98366c6d

Working on puback with reason codes

I'm simplying/merging the rec, comp and rel packets, but I'm not sure it will
work. Committing as a safe point.

Later: I got it done as planned. Testing qos > 0 and mqtt5 still needs
to be done more.
mqttpacket.cpp
@@ -153,47 +153,23 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_pu @@ -153,47 +153,23 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_pu
153 calculateRemainingLength(); 153 calculateRemainingLength();
154 } 154 }
155 155
156 -/**  
157 - * @brief MqttPacket::pubCommonConstruct is common code for constructors for all those empty pub control packets (ack, rec(eived), rel(ease), comp(lete)).  
158 - * @param packet_id  
159 - *  
160 - * This functions cheats a bit and doesn't use calculateRemainingLength, because it's always 2. Be sure to allocate enough room in the vector when  
161 - * you use this function (add 2 to the length).  
162 - */  
163 -void MqttPacket::pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits) 156 +MqttPacket::MqttPacket(const PubResponse &pubAck) :
  157 + bites(pubAck.getLengthIncludingFixedHeader())
164 { 158 {
165 - assert(firstByteDefaultBits <= 0xF); 159 + this->protocolVersion = pubAck.protocol_version;
166 160
167 fixed_header_length = 2; 161 fixed_header_length = 2;
168 - first_byte = (static_cast<uint8_t>(packetType) << 4) | firstByteDefaultBits; 162 + const uint8_t firstByteDefaultBits = pubAck.packet_type == PacketType::PUBREL ? 0b0010 : 0;
  163 + this->first_byte = (static_cast<uint8_t>(pubAck.packet_type) << 4) | firstByteDefaultBits;
169 writeByte(first_byte); 164 writeByte(first_byte);
170 - writeByte(2); // length is always 2.  
171 - packet_id_pos = pos;  
172 - writeUint16(packet_id);  
173 -}  
174 -  
175 -MqttPacket::MqttPacket(const PubAck &pubAck) :  
176 - bites(pubAck.getLengthWithoutFixedHeader() + 2)  
177 -{  
178 - pubCommonConstruct(pubAck.packet_id, PacketType::PUBACK);  
179 -}  
180 -  
181 -MqttPacket::MqttPacket(const PubRec &pubRec) :  
182 - bites(pubRec.getLengthWithoutFixedHeader() + 2)  
183 -{  
184 - pubCommonConstruct(pubRec.packet_id, PacketType::PUBREC);  
185 -}  
186 -  
187 -MqttPacket::MqttPacket(const PubComp &pubComp) :  
188 - bites(pubComp.getLengthWithoutFixedHeader() + 2)  
189 -{  
190 - pubCommonConstruct(pubComp.packet_id, PacketType::PUBCOMP);  
191 -} 165 + writeByte(pubAck.getRemainingLength());
  166 + this->packet_id_pos = this->pos;
  167 + writeUint16(pubAck.packet_id);
192 168
193 -MqttPacket::MqttPacket(const PubRel &pubRel) :  
194 - bites(pubRel.getLengthWithoutFixedHeader() + 2)  
195 -{  
196 - pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010); 169 + if (pubAck.needsReasonCode())
  170 + {
  171 + writeByte(static_cast<uint8_t>(pubAck.reason_code));
  172 + }
197 } 173 }
198 174
199 void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender) 175 void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
@@ -767,6 +743,8 @@ void MqttPacket::handlePublish() @@ -767,6 +743,8 @@ void MqttPacket::handlePublish()
767 743
768 publishData.topic = std::string(readBytes(variable_header_length), variable_header_length); 744 publishData.topic = std::string(readBytes(variable_header_length), variable_header_length);
769 745
  746 + ReasonCodes ackCode = ReasonCodes::Success;
  747 +
770 if (qos) 748 if (qos)
771 { 749 {
772 packet_id_pos = pos; 750 packet_id_pos = pos;
@@ -776,29 +754,6 @@ void MqttPacket::handlePublish() @@ -776,29 +754,6 @@ void MqttPacket::handlePublish()
776 { 754 {
777 throw ProtocolError("Packet ID 0 when publishing is invalid."); // [MQTT-2.3.1-1] 755 throw ProtocolError("Packet ID 0 when publishing is invalid."); // [MQTT-2.3.1-1]
778 } 756 }
779 -  
780 - if (qos == 1)  
781 - {  
782 - PubAck pubAck(packet_id);  
783 - MqttPacket response(pubAck);  
784 - sender->writeMqttPacket(response);  
785 - }  
786 - else  
787 - {  
788 - PubRec pubRec(packet_id);  
789 - MqttPacket response(pubRec);  
790 - sender->writeMqttPacket(response);  
791 -  
792 - if (sender->getSession()->incomingQoS2MessageIdInTransit(packet_id))  
793 - {  
794 - return;  
795 - }  
796 - else  
797 - {  
798 - // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish.  
799 - sender->getSession()->addIncomingQoS2MessageId(packet_id);  
800 - }  
801 - }  
802 } 757 }
803 758
804 if (this->protocolVersion >= ProtocolVersion::Mqtt5 ) 759 if (this->protocolVersion >= ProtocolVersion::Mqtt5 )
@@ -881,17 +836,16 @@ void MqttPacket::handlePublish() @@ -881,17 +836,16 @@ void MqttPacket::handlePublish()
881 } 836 }
882 } 837 }
883 838
884 - splitTopic(publishData.topic, publishData.subtopics); 839 + if (publishData.topic.empty())
  840 + throw ProtocolError("Empty publish topic");
885 841
886 if (!isValidUtf8(publishData.topic, true)) 842 if (!isValidUtf8(publishData.topic, true))
887 { 843 {
888 - logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str());  
889 - return; 844 + const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str());
  845 + logger->logf(LOG_WARNING, err.c_str());
  846 + throw ProtocolError(err);
890 } 847 }
891 848
892 - if (publishData.topic.empty())  
893 - throw ProtocolError("Empty publish topic");  
894 -  
895 #ifndef NDEBUG 849 #ifndef NDEBUG
896 logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup); 850 logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup);
897 #endif 851 #endif
@@ -902,21 +856,55 @@ void MqttPacket::handlePublish() @@ -902,21 +856,55 @@ void MqttPacket::handlePublish()
902 payloadStart = pos; 856 payloadStart = pos;
903 857
904 Authentication &authentication = *ThreadGlobals::getAuth(); 858 Authentication &authentication = *ThreadGlobals::getAuth();
905 - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success) 859 +
  860 + // Working with a local copy because the subscribing action will modify this->packet_id. See the PublishCopyFactory.
  861 + const uint16_t _packet_id = this->packet_id;
  862 +
  863 + if (qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id))
  864 + {
  865 + ackCode = ReasonCodes::PacketIdentifierInUse;
  866 + }
  867 + else
906 { 868 {
907 - if (retain) 869 + // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish.
  870 + if (qos == 2)
  871 + sender->getSession()->addIncomingQoS2MessageId(_packet_id);
  872 +
  873 + splitTopic(publishData.topic, publishData.subtopics);
  874 +
  875 + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success)
  876 + {
  877 + if (retain)
  878 + {
  879 + std::string payload(readBytes(payloadLen), payloadLen);
  880 + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, qos);
  881 + }
  882 +
  883 + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
  884 + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9]
  885 + bites[0] &= 0b11110110;
  886 + first_byte = bites[0];
  887 +
  888 + PublishCopyFactory factory(this);
  889 + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(factory);
  890 + }
  891 + else
908 { 892 {
909 - std::string payload(readBytes(payloadLen), payloadLen);  
910 - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, qos); 893 + ackCode = ReasonCodes::NotAuthorized;
911 } 894 }
  895 + }
912 896
913 - // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].  
914 - // Existing subscribers don't get retain=1. [MQTT-3.3.1-9]  
915 - bites[0] &= 0b11110110;  
916 - first_byte = bites[0]; 897 +#ifndef NDEBUG
  898 + // Protection against using the altered packet id.
  899 + this->packet_id = 0;
  900 +#endif
917 901
918 - PublishCopyFactory factory(this);  
919 - sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(factory); 902 + if (qos > 0)
  903 + {
  904 + const PacketType responseType = qos == 1 ? PacketType::PUBACK : PacketType::PUBREC;
  905 + PubResponse pubAck(this->protocolVersion, responseType, ackCode, _packet_id);
  906 + MqttPacket response(pubAck);
  907 + sender->writeMqttPacket(response);
920 } 908 }
921 } 909 }
922 910
@@ -935,7 +923,7 @@ void MqttPacket::handlePubRec() @@ -935,7 +923,7 @@ void MqttPacket::handlePubRec()
935 sender->getSession()->clearQosMessage(packet_id); 923 sender->getSession()->clearQosMessage(packet_id);
936 sender->getSession()->addOutgoingQoS2MessageId(packet_id); 924 sender->getSession()->addOutgoingQoS2MessageId(packet_id);
937 925
938 - PubRel pubRel(packet_id); 926 + PubResponse pubRel(this->protocolVersion, PacketType::PUBREL, ReasonCodes::Success, packet_id);
939 MqttPacket response(pubRel); 927 MqttPacket response(pubRel);
940 sender->writeMqttPacket(response); 928 sender->writeMqttPacket(response);
941 } 929 }
@@ -952,7 +940,7 @@ void MqttPacket::handlePubRel() @@ -952,7 +940,7 @@ void MqttPacket::handlePubRel()
952 const uint16_t packet_id = readTwoBytesToUInt16(); 940 const uint16_t packet_id = readTwoBytesToUInt16();
953 sender->getSession()->removeIncomingQoS2MessageId(packet_id); 941 sender->getSession()->removeIncomingQoS2MessageId(packet_id);
954 942
955 - PubComp pubcomp(packet_id); 943 + PubResponse pubcomp(this->protocolVersion, PacketType::PUBCOMP, ReasonCodes::Success, packet_id);
956 MqttPacket response(pubcomp); 944 MqttPacket response(pubcomp);
957 sender->writeMqttPacket(response); 945 sender->writeMqttPacket(response);
958 } 946 }
mqttpacket.h
@@ -71,7 +71,6 @@ class MqttPacket @@ -71,7 +71,6 @@ class MqttPacket
71 void readUserProperty(); 71 void readUserProperty();
72 72
73 void calculateRemainingLength(); 73 void calculateRemainingLength();
74 - void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0);  
75 74
76 MqttPacket(const MqttPacket &other) = delete; 75 MqttPacket(const MqttPacket &other) = delete;
77 public: 76 public:
@@ -87,10 +86,7 @@ public: @@ -87,10 +86,7 @@ public:
87 MqttPacket(const SubAck &subAck); 86 MqttPacket(const SubAck &subAck);
88 MqttPacket(const UnsubAck &unsubAck); 87 MqttPacket(const UnsubAck &unsubAck);
89 MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish); 88 MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish);
90 - MqttPacket(const PubAck &pubAck);  
91 - MqttPacket(const PubRec &pubRec);  
92 - MqttPacket(const PubComp &pubComp);  
93 - MqttPacket(const PubRel &pubRel); 89 + MqttPacket(const PubResponse &pubAck);
94 90
95 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); 91 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
96 92
session.cpp
@@ -268,7 +268,7 @@ uint64_t Session::sendPendingQosMessages() @@ -268,7 +268,7 @@ uint64_t Session::sendPendingQosMessages()
268 268
269 for (const uint16_t packet_id : outgoingQoS2MessageIds) 269 for (const uint16_t packet_id : outgoingQoS2MessageIds)
270 { 270 {
271 - PubRel pubRel(packet_id); 271 + PubResponse pubRel(c->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, packet_id);
272 MqttPacket packet(pubRel); 272 MqttPacket packet(pubRel);
273 count += c->writeMqttPacketAndBlameThisClient(packet); 273 count += c->writeMqttPacketAndBlameThisClient(packet);
274 } 274 }
@@ -313,12 +313,16 @@ std::shared_ptr&lt;Publish&gt; &amp;Session::getWill() @@ -313,12 +313,16 @@ std::shared_ptr&lt;Publish&gt; &amp;Session::getWill()
313 313
314 void Session::addIncomingQoS2MessageId(uint16_t packet_id) 314 void Session::addIncomingQoS2MessageId(uint16_t packet_id)
315 { 315 {
  316 + assert(packet_id > 0);
  317 +
316 std::unique_lock<std::mutex> locker(qosQueueMutex); 318 std::unique_lock<std::mutex> locker(qosQueueMutex);
317 incomingQoS2MessageIds.insert(packet_id); 319 incomingQoS2MessageIds.insert(packet_id);
318 } 320 }
319 321
320 bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id) 322 bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id)
321 { 323 {
  324 + assert(packet_id > 0);
  325 +
322 std::unique_lock<std::mutex> locker(qosQueueMutex); 326 std::unique_lock<std::mutex> locker(qosQueueMutex);
323 const auto it = incomingQoS2MessageIds.find(packet_id); 327 const auto it = incomingQoS2MessageIds.find(packet_id);
324 return it != incomingQoS2MessageIds.end(); 328 return it != incomingQoS2MessageIds.end();
@@ -326,6 +330,8 @@ bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id) @@ -326,6 +330,8 @@ bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id)
326 330
327 void Session::removeIncomingQoS2MessageId(u_int16_t packet_id) 331 void Session::removeIncomingQoS2MessageId(u_int16_t packet_id)
328 { 332 {
  333 + assert(packet_id > 0);
  334 +
329 std::unique_lock<std::mutex> locker(qosQueueMutex); 335 std::unique_lock<std::mutex> locker(qosQueueMutex);
330 336
331 #ifndef NDEBUG 337 #ifndef NDEBUG
types.cpp
@@ -186,58 +186,43 @@ bool WillDelayCompare(const std::shared_ptr&lt;Publish&gt; &amp;a, const std::weak_ptr&lt;Pub @@ -186,58 +186,43 @@ bool WillDelayCompare(const std::shared_ptr&lt;Publish&gt; &amp;a, const std::weak_ptr&lt;Pub
186 return a->will_delay < _b->will_delay; 186 return a->will_delay < _b->will_delay;
187 }; 187 };
188 188
189 -PubAck::PubAck(uint16_t packet_id) :  
190 - packet_id(packet_id)  
191 -{  
192 -  
193 -}  
194 -  
195 -// Packet has no payload and only a variable header, of length 2.  
196 -size_t PubAck::getLengthWithoutFixedHeader() const  
197 -{  
198 - return 2;  
199 -}  
200 -  
201 -UnsubAck::UnsubAck(uint16_t packet_id) :  
202 - packet_id(packet_id)  
203 -{  
204 -  
205 -}  
206 -  
207 -size_t UnsubAck::getLengthWithoutFixedHeader() const  
208 -{  
209 - return 2;  
210 -}  
211 -  
212 -PubRec::PubRec(uint16_t packet_id) : 189 +PubResponse::PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id) :
  190 + packet_type(packet_type),
  191 + protocol_version(protVersion),
  192 + reason_code(protVersion >= ProtocolVersion::Mqtt5 ? reason_code : ReasonCodes::Success),
213 packet_id(packet_id) 193 packet_id(packet_id)
214 { 194 {
215 - 195 + assert(packet_type == PacketType::PUBACK || packet_type == PacketType::PUBREC || packet_type == PacketType::PUBREL || packet_type == PacketType::PUBCOMP);
216 } 196 }
217 197
218 -size_t PubRec::getLengthWithoutFixedHeader() const 198 +uint8_t PubResponse::getLengthIncludingFixedHeader() const
219 { 199 {
220 - return 2; 200 + return 2 + getRemainingLength();
221 } 201 }
222 202
223 -PubComp::PubComp(uint16_t packet_id) :  
224 - packet_id(packet_id) 203 +uint8_t PubResponse::getRemainingLength() const
225 { 204 {
226 - 205 + // I'm leaving out the property length of 0: "If the Remaining Length is less than 4 there is no Property Length and the value of 0 is used"
  206 + const uint8_t result = needsReasonCode() ? 3 : 2;
  207 + return result;
227 } 208 }
228 209
229 -size_t PubComp::getLengthWithoutFixedHeader() const 210 +/**
  211 + * @brief "The Reason Code and Property Length can be omitted if the Reason Code is 0x00 (Success) and there are no Properties"
  212 + * @return
  213 + */
  214 +bool PubResponse::needsReasonCode() const
230 { 215 {
231 - return 2; 216 + return this->protocol_version >= ProtocolVersion::Mqtt5 && this->reason_code > ReasonCodes::Success;
232 } 217 }
233 218
234 -PubRel::PubRel(uint16_t packet_id) : 219 +UnsubAck::UnsubAck(uint16_t packet_id) :
235 packet_id(packet_id) 220 packet_id(packet_id)
236 { 221 {
237 222
238 } 223 }
239 224
240 -size_t PubRel::getLengthWithoutFixedHeader() const 225 +size_t UnsubAck::getLengthWithoutFixedHeader() const
241 { 226 {
242 return 2; 227 return 2;
243 } 228 }
@@ -230,36 +230,19 @@ public: @@ -230,36 +230,19 @@ public:
230 230
231 bool WillDelayCompare(const std::shared_ptr<Publish> &a, const std::weak_ptr<Publish> &b); 231 bool WillDelayCompare(const std::shared_ptr<Publish> &a, const std::weak_ptr<Publish> &b);
232 232
233 -class PubAck 233 +class PubResponse
234 { 234 {
235 public: 235 public:
236 - PubAck(uint16_t packet_id);  
237 - uint16_t packet_id;  
238 - size_t getLengthWithoutFixedHeader() const;  
239 -};  
240 -  
241 -class PubRec  
242 -{  
243 -public:  
244 - PubRec(uint16_t packet_id);  
245 - uint16_t packet_id;  
246 - size_t getLengthWithoutFixedHeader() const;  
247 -}; 236 + PubResponse(const PubResponse &other) = delete;
  237 + PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id);
248 238
249 -class PubComp  
250 -{  
251 -public:  
252 - PubComp(uint16_t packet_id);  
253 - uint16_t packet_id;  
254 - size_t getLengthWithoutFixedHeader() const;  
255 -};  
256 -  
257 -class PubRel  
258 -{  
259 -public:  
260 - PubRel(uint16_t packet_id); 239 + const PacketType packet_type;
  240 + const ProtocolVersion protocol_version;
  241 + const ReasonCodes reason_code;
261 uint16_t packet_id; 242 uint16_t packet_id;
262 - size_t getLengthWithoutFixedHeader() const; 243 + uint8_t getLengthIncludingFixedHeader() const;
  244 + uint8_t getRemainingLength() const;
  245 + bool needsReasonCode() const;
263 }; 246 };
264 247
265 #endif // TYPES_H 248 #endif // TYPES_H