Commit ee79271a6ccb3d8e9013f7a31d030165b2f4aee2

Authored by Wiebe Cazemier
1 parent 98366c6d

Working on puback with reason codes

I'm simplying/merging the rec, comp and rel packets, but I'm not sure it will
work. Committing as a safe point.

Later: I got it done as planned. Testing qos > 0 and mqtt5 still needs
to be done more.
mqttpacket.cpp
... ... @@ -153,47 +153,23 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_pu
153 153 calculateRemainingLength();
154 154 }
155 155  
156   -/**
157   - * @brief MqttPacket::pubCommonConstruct is common code for constructors for all those empty pub control packets (ack, rec(eived), rel(ease), comp(lete)).
158   - * @param packet_id
159   - *
160   - * This functions cheats a bit and doesn't use calculateRemainingLength, because it's always 2. Be sure to allocate enough room in the vector when
161   - * you use this function (add 2 to the length).
162   - */
163   -void MqttPacket::pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits)
  156 +MqttPacket::MqttPacket(const PubResponse &pubAck) :
  157 + bites(pubAck.getLengthIncludingFixedHeader())
164 158 {
165   - assert(firstByteDefaultBits <= 0xF);
  159 + this->protocolVersion = pubAck.protocol_version;
166 160  
167 161 fixed_header_length = 2;
168   - first_byte = (static_cast<uint8_t>(packetType) << 4) | firstByteDefaultBits;
  162 + const uint8_t firstByteDefaultBits = pubAck.packet_type == PacketType::PUBREL ? 0b0010 : 0;
  163 + this->first_byte = (static_cast<uint8_t>(pubAck.packet_type) << 4) | firstByteDefaultBits;
169 164 writeByte(first_byte);
170   - writeByte(2); // length is always 2.
171   - packet_id_pos = pos;
172   - writeUint16(packet_id);
173   -}
174   -
175   -MqttPacket::MqttPacket(const PubAck &pubAck) :
176   - bites(pubAck.getLengthWithoutFixedHeader() + 2)
177   -{
178   - pubCommonConstruct(pubAck.packet_id, PacketType::PUBACK);
179   -}
180   -
181   -MqttPacket::MqttPacket(const PubRec &pubRec) :
182   - bites(pubRec.getLengthWithoutFixedHeader() + 2)
183   -{
184   - pubCommonConstruct(pubRec.packet_id, PacketType::PUBREC);
185   -}
186   -
187   -MqttPacket::MqttPacket(const PubComp &pubComp) :
188   - bites(pubComp.getLengthWithoutFixedHeader() + 2)
189   -{
190   - pubCommonConstruct(pubComp.packet_id, PacketType::PUBCOMP);
191   -}
  165 + writeByte(pubAck.getRemainingLength());
  166 + this->packet_id_pos = this->pos;
  167 + writeUint16(pubAck.packet_id);
192 168  
193   -MqttPacket::MqttPacket(const PubRel &pubRel) :
194   - bites(pubRel.getLengthWithoutFixedHeader() + 2)
195   -{
196   - pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010);
  169 + if (pubAck.needsReasonCode())
  170 + {
  171 + writeByte(static_cast<uint8_t>(pubAck.reason_code));
  172 + }
197 173 }
198 174  
199 175 void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
... ... @@ -767,6 +743,8 @@ void MqttPacket::handlePublish()
767 743  
768 744 publishData.topic = std::string(readBytes(variable_header_length), variable_header_length);
769 745  
  746 + ReasonCodes ackCode = ReasonCodes::Success;
  747 +
770 748 if (qos)
771 749 {
772 750 packet_id_pos = pos;
... ... @@ -776,29 +754,6 @@ void MqttPacket::handlePublish()
776 754 {
777 755 throw ProtocolError("Packet ID 0 when publishing is invalid."); // [MQTT-2.3.1-1]
778 756 }
779   -
780   - if (qos == 1)
781   - {
782   - PubAck pubAck(packet_id);
783   - MqttPacket response(pubAck);
784   - sender->writeMqttPacket(response);
785   - }
786   - else
787   - {
788   - PubRec pubRec(packet_id);
789   - MqttPacket response(pubRec);
790   - sender->writeMqttPacket(response);
791   -
792   - if (sender->getSession()->incomingQoS2MessageIdInTransit(packet_id))
793   - {
794   - return;
795   - }
796   - else
797   - {
798   - // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish.
799   - sender->getSession()->addIncomingQoS2MessageId(packet_id);
800   - }
801   - }
802 757 }
803 758  
804 759 if (this->protocolVersion >= ProtocolVersion::Mqtt5 )
... ... @@ -881,17 +836,16 @@ void MqttPacket::handlePublish()
881 836 }
882 837 }
883 838  
884   - splitTopic(publishData.topic, publishData.subtopics);
  839 + if (publishData.topic.empty())
  840 + throw ProtocolError("Empty publish topic");
885 841  
886 842 if (!isValidUtf8(publishData.topic, true))
887 843 {
888   - logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str());
889   - return;
  844 + const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str());
  845 + logger->logf(LOG_WARNING, err.c_str());
  846 + throw ProtocolError(err);
890 847 }
891 848  
892   - if (publishData.topic.empty())
893   - throw ProtocolError("Empty publish topic");
894   -
895 849 #ifndef NDEBUG
896 850 logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup);
897 851 #endif
... ... @@ -902,21 +856,55 @@ void MqttPacket::handlePublish()
902 856 payloadStart = pos;
903 857  
904 858 Authentication &authentication = *ThreadGlobals::getAuth();
905   - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success)
  859 +
  860 + // Working with a local copy because the subscribing action will modify this->packet_id. See the PublishCopyFactory.
  861 + const uint16_t _packet_id = this->packet_id;
  862 +
  863 + if (qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id))
  864 + {
  865 + ackCode = ReasonCodes::PacketIdentifierInUse;
  866 + }
  867 + else
906 868 {
907   - if (retain)
  869 + // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish.
  870 + if (qos == 2)
  871 + sender->getSession()->addIncomingQoS2MessageId(_packet_id);
  872 +
  873 + splitTopic(publishData.topic, publishData.subtopics);
  874 +
  875 + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success)
  876 + {
  877 + if (retain)
  878 + {
  879 + std::string payload(readBytes(payloadLen), payloadLen);
  880 + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, qos);
  881 + }
  882 +
  883 + // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
  884 + // Existing subscribers don't get retain=1. [MQTT-3.3.1-9]
  885 + bites[0] &= 0b11110110;
  886 + first_byte = bites[0];
  887 +
  888 + PublishCopyFactory factory(this);
  889 + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(factory);
  890 + }
  891 + else
908 892 {
909   - std::string payload(readBytes(payloadLen), payloadLen);
910   - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, qos);
  893 + ackCode = ReasonCodes::NotAuthorized;
911 894 }
  895 + }
912 896  
913   - // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
914   - // Existing subscribers don't get retain=1. [MQTT-3.3.1-9]
915   - bites[0] &= 0b11110110;
916   - first_byte = bites[0];
  897 +#ifndef NDEBUG
  898 + // Protection against using the altered packet id.
  899 + this->packet_id = 0;
  900 +#endif
917 901  
918   - PublishCopyFactory factory(this);
919   - sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(factory);
  902 + if (qos > 0)
  903 + {
  904 + const PacketType responseType = qos == 1 ? PacketType::PUBACK : PacketType::PUBREC;
  905 + PubResponse pubAck(this->protocolVersion, responseType, ackCode, _packet_id);
  906 + MqttPacket response(pubAck);
  907 + sender->writeMqttPacket(response);
920 908 }
921 909 }
922 910  
... ... @@ -935,7 +923,7 @@ void MqttPacket::handlePubRec()
935 923 sender->getSession()->clearQosMessage(packet_id);
936 924 sender->getSession()->addOutgoingQoS2MessageId(packet_id);
937 925  
938   - PubRel pubRel(packet_id);
  926 + PubResponse pubRel(this->protocolVersion, PacketType::PUBREL, ReasonCodes::Success, packet_id);
939 927 MqttPacket response(pubRel);
940 928 sender->writeMqttPacket(response);
941 929 }
... ... @@ -952,7 +940,7 @@ void MqttPacket::handlePubRel()
952 940 const uint16_t packet_id = readTwoBytesToUInt16();
953 941 sender->getSession()->removeIncomingQoS2MessageId(packet_id);
954 942  
955   - PubComp pubcomp(packet_id);
  943 + PubResponse pubcomp(this->protocolVersion, PacketType::PUBCOMP, ReasonCodes::Success, packet_id);
956 944 MqttPacket response(pubcomp);
957 945 sender->writeMqttPacket(response);
958 946 }
... ...
mqttpacket.h
... ... @@ -71,7 +71,6 @@ class MqttPacket
71 71 void readUserProperty();
72 72  
73 73 void calculateRemainingLength();
74   - void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0);
75 74  
76 75 MqttPacket(const MqttPacket &other) = delete;
77 76 public:
... ... @@ -87,10 +86,7 @@ public:
87 86 MqttPacket(const SubAck &subAck);
88 87 MqttPacket(const UnsubAck &unsubAck);
89 88 MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish);
90   - MqttPacket(const PubAck &pubAck);
91   - MqttPacket(const PubRec &pubRec);
92   - MqttPacket(const PubComp &pubComp);
93   - MqttPacket(const PubRel &pubRel);
  89 + MqttPacket(const PubResponse &pubAck);
94 90  
95 91 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
96 92  
... ...
session.cpp
... ... @@ -268,7 +268,7 @@ uint64_t Session::sendPendingQosMessages()
268 268  
269 269 for (const uint16_t packet_id : outgoingQoS2MessageIds)
270 270 {
271   - PubRel pubRel(packet_id);
  271 + PubResponse pubRel(c->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, packet_id);
272 272 MqttPacket packet(pubRel);
273 273 count += c->writeMqttPacketAndBlameThisClient(packet);
274 274 }
... ... @@ -313,12 +313,16 @@ std::shared_ptr&lt;Publish&gt; &amp;Session::getWill()
313 313  
314 314 void Session::addIncomingQoS2MessageId(uint16_t packet_id)
315 315 {
  316 + assert(packet_id > 0);
  317 +
316 318 std::unique_lock<std::mutex> locker(qosQueueMutex);
317 319 incomingQoS2MessageIds.insert(packet_id);
318 320 }
319 321  
320 322 bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id)
321 323 {
  324 + assert(packet_id > 0);
  325 +
322 326 std::unique_lock<std::mutex> locker(qosQueueMutex);
323 327 const auto it = incomingQoS2MessageIds.find(packet_id);
324 328 return it != incomingQoS2MessageIds.end();
... ... @@ -326,6 +330,8 @@ bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id)
326 330  
327 331 void Session::removeIncomingQoS2MessageId(u_int16_t packet_id)
328 332 {
  333 + assert(packet_id > 0);
  334 +
329 335 std::unique_lock<std::mutex> locker(qosQueueMutex);
330 336  
331 337 #ifndef NDEBUG
... ...
types.cpp
... ... @@ -186,58 +186,43 @@ bool WillDelayCompare(const std::shared_ptr&lt;Publish&gt; &amp;a, const std::weak_ptr&lt;Pub
186 186 return a->will_delay < _b->will_delay;
187 187 };
188 188  
189   -PubAck::PubAck(uint16_t packet_id) :
190   - packet_id(packet_id)
191   -{
192   -
193   -}
194   -
195   -// Packet has no payload and only a variable header, of length 2.
196   -size_t PubAck::getLengthWithoutFixedHeader() const
197   -{
198   - return 2;
199   -}
200   -
201   -UnsubAck::UnsubAck(uint16_t packet_id) :
202   - packet_id(packet_id)
203   -{
204   -
205   -}
206   -
207   -size_t UnsubAck::getLengthWithoutFixedHeader() const
208   -{
209   - return 2;
210   -}
211   -
212   -PubRec::PubRec(uint16_t packet_id) :
  189 +PubResponse::PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id) :
  190 + packet_type(packet_type),
  191 + protocol_version(protVersion),
  192 + reason_code(protVersion >= ProtocolVersion::Mqtt5 ? reason_code : ReasonCodes::Success),
213 193 packet_id(packet_id)
214 194 {
215   -
  195 + assert(packet_type == PacketType::PUBACK || packet_type == PacketType::PUBREC || packet_type == PacketType::PUBREL || packet_type == PacketType::PUBCOMP);
216 196 }
217 197  
218   -size_t PubRec::getLengthWithoutFixedHeader() const
  198 +uint8_t PubResponse::getLengthIncludingFixedHeader() const
219 199 {
220   - return 2;
  200 + return 2 + getRemainingLength();
221 201 }
222 202  
223   -PubComp::PubComp(uint16_t packet_id) :
224   - packet_id(packet_id)
  203 +uint8_t PubResponse::getRemainingLength() const
225 204 {
226   -
  205 + // I'm leaving out the property length of 0: "If the Remaining Length is less than 4 there is no Property Length and the value of 0 is used"
  206 + const uint8_t result = needsReasonCode() ? 3 : 2;
  207 + return result;
227 208 }
228 209  
229   -size_t PubComp::getLengthWithoutFixedHeader() const
  210 +/**
  211 + * @brief "The Reason Code and Property Length can be omitted if the Reason Code is 0x00 (Success) and there are no Properties"
  212 + * @return
  213 + */
  214 +bool PubResponse::needsReasonCode() const
230 215 {
231   - return 2;
  216 + return this->protocol_version >= ProtocolVersion::Mqtt5 && this->reason_code > ReasonCodes::Success;
232 217 }
233 218  
234   -PubRel::PubRel(uint16_t packet_id) :
  219 +UnsubAck::UnsubAck(uint16_t packet_id) :
235 220 packet_id(packet_id)
236 221 {
237 222  
238 223 }
239 224  
240   -size_t PubRel::getLengthWithoutFixedHeader() const
  225 +size_t UnsubAck::getLengthWithoutFixedHeader() const
241 226 {
242 227 return 2;
243 228 }
... ...
... ... @@ -230,36 +230,19 @@ public:
230 230  
231 231 bool WillDelayCompare(const std::shared_ptr<Publish> &a, const std::weak_ptr<Publish> &b);
232 232  
233   -class PubAck
  233 +class PubResponse
234 234 {
235 235 public:
236   - PubAck(uint16_t packet_id);
237   - uint16_t packet_id;
238   - size_t getLengthWithoutFixedHeader() const;
239   -};
240   -
241   -class PubRec
242   -{
243   -public:
244   - PubRec(uint16_t packet_id);
245   - uint16_t packet_id;
246   - size_t getLengthWithoutFixedHeader() const;
247   -};
  236 + PubResponse(const PubResponse &other) = delete;
  237 + PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id);
248 238  
249   -class PubComp
250   -{
251   -public:
252   - PubComp(uint16_t packet_id);
253   - uint16_t packet_id;
254   - size_t getLengthWithoutFixedHeader() const;
255   -};
256   -
257   -class PubRel
258   -{
259   -public:
260   - PubRel(uint16_t packet_id);
  239 + const PacketType packet_type;
  240 + const ProtocolVersion protocol_version;
  241 + const ReasonCodes reason_code;
261 242 uint16_t packet_id;
262   - size_t getLengthWithoutFixedHeader() const;
  243 + uint8_t getLengthIncludingFixedHeader() const;
  244 + uint8_t getRemainingLength() const;
  245 + bool needsReasonCode() const;
263 246 };
264 247  
265 248 #endif // TYPES_H
... ...