Commit 193f509d98f3c0a5e4e43c76f67d2ab16e25c982

Authored by Wiebe Cazemier
1 parent 5f1d8259

Defined a reason code for many errors

I think it's very hard to distinguish between protocol error and
malformed packet. It's kind of arbitrary...
client.cpp
... ... @@ -358,7 +358,7 @@ void Client::resetBuffersIfEligible()
358 358 void Client::setTopicAlias(const uint16_t alias_id, const std::string &topic)
359 359 {
360 360 if (alias_id == 0)
361   - throw ProtocolError("Client tried to set topic alias 0, which is a protocol error.");
  361 + throw ProtocolError("Client tried to set topic alias 0, which is a protocol error.", ReasonCodes::ProtocolError);
362 362  
363 363 if (topic.empty())
364 364 return;
... ... @@ -367,7 +367,8 @@ void Client::setTopicAlias(const uint16_t alias_id, const std::string &topic)
367 367  
368 368 // The specs actually say "The Client MUST NOT send a Topic Alias [...] to the Server greater than this value [Topic Alias Maximum]". So, it's not about count.
369 369 if (alias_id > settings->maxIncomingTopicAliasValue)
370   - throw ProtocolError(formatString("Client tried to set more topic aliases than the server max of %d per client", settings->maxIncomingTopicAliasValue));
  370 + throw ProtocolError(formatString("Client tried to set more topic aliases than the server max of %d per client", settings->maxIncomingTopicAliasValue),
  371 + ReasonCodes::ProtocolError);
371 372  
372 373 this->incomingTopicAliases[alias_id] = topic;
373 374 }
... ...
mqttpacket.cpp
... ... @@ -133,7 +133,7 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish)
133 133 {
134 134 if (_publish.topic.length() > 0xFFFF)
135 135 {
136   - throw ProtocolError("Topic path too long.");
  136 + throw ProtocolError("Topic path too long.", ReasonCodes::ProtocolError);
137 137 }
138 138  
139 139 this->protocolVersion = protocolVersion;
... ... @@ -232,7 +232,7 @@ void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packe
232 232 fixed_header_length++;
233 233  
234 234 if (fixed_header_length > 5)
235   - throw ProtocolError("Packet signifies more than 5 bytes in variable length header. Invalid.");
  235 + throw ProtocolError("Packet signifies more than 5 bytes in variable length header. Invalid.", ReasonCodes::MalformedPacket);
236 236  
237 237 // This happens when you only don't have all the bytes that specify the remaining length.
238 238 if (fixed_header_length > buf.usedBytes())
... ... @@ -242,19 +242,19 @@ void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packe
242 242 packet_length += (encodedByte & 127) * multiplier;
243 243 multiplier *= 128;
244 244 if (multiplier > 128*128*128*128)
245   - throw ProtocolError("Malformed Remaining Length.");
  245 + throw ProtocolError("Malformed Remaining Length.", ReasonCodes::MalformedPacket);
246 246 }
247 247 while ((encodedByte & 128) != 0);
248 248 packet_length += fixed_header_length;
249 249  
250 250 if (sender && !sender->getAuthenticated() && packet_length >= 1024*1024)
251 251 {
252   - throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.");
  252 + throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.", ReasonCodes::ProtocolError);
253 253 }
254 254  
255 255 if (packet_length > ABSOLUTE_MAX_PACKET_SIZE)
256 256 {
257   - throw ProtocolError("A client sends a packet claiming to be bigger than the maximum MQTT allows.");
  257 + throw ProtocolError("A client sends a packet claiming to be bigger than the maximum MQTT allows.", ReasonCodes::ProtocolError);
258 258 }
259 259  
260 260 if (packet_length <= buf.usedBytes())
... ... @@ -269,7 +269,7 @@ void MqttPacket::bufferToMqttPackets(CirBuf &amp;buf, std::vector&lt;MqttPacket&gt; &amp;packe
269 269 void MqttPacket::handle()
270 270 {
271 271 if (packetType == PacketType::Reserved)
272   - throw ProtocolError("Packet type 0 specified, which is reserved and invalid.");
  272 + throw ProtocolError("Packet type 0 specified, which is reserved and invalid.", ReasonCodes::MalformedPacket);
273 273  
274 274 if (packetType != PacketType::CONNECT)
275 275 {
... ... @@ -305,7 +305,7 @@ void MqttPacket::handle()
305 305 void MqttPacket::handleConnect()
306 306 {
307 307 if (sender->hasConnectPacketSeen())
308   - throw ProtocolError("Client already sent a CONNECT.");
  308 + throw ProtocolError("Client already sent a CONNECT.", ReasonCodes::ProtocolError);
309 309  
310 310 std::shared_ptr<SubscriptionStore> subscriptionStore = sender->getThreadData()->getSubscriptionStore();
311 311  
... ... @@ -348,7 +348,7 @@ void MqttPacket::handleConnect()
348 348 bool reserved = !!(flagByte & 0b00000001);
349 349  
350 350 if (reserved)
351   - throw ProtocolError("Protocol demands reserved flag in CONNECT is 0");
  351 + throw ProtocolError("Protocol demands reserved flag in CONNECT is 0", ReasonCodes::MalformedPacket);
352 352  
353 353  
354 354 bool user_name_flag = !!(flagByte & 0b10000000);
... ... @@ -359,7 +359,7 @@ void MqttPacket::handleConnect()
359 359 bool clean_start = !!(flagByte & 0b00000010);
360 360  
361 361 if (will_qos > 2)
362   - throw ProtocolError("Invalid QoS for will.");
  362 + throw ProtocolError("Invalid QoS for will.", ReasonCodes::MalformedPacket);
363 363  
364 364 uint16_t keep_alive = readTwoBytesToUInt16();
365 365  
... ... @@ -415,7 +415,7 @@ void MqttPacket::handleConnect()
415 415 break;
416 416 }
417 417 default:
418   - throw ProtocolError("Invalid connect property.");
  418 + throw ProtocolError("Invalid connect property.", ReasonCodes::ProtocolError);
419 419 }
420 420 }
421 421 }
... ... @@ -489,7 +489,7 @@ void MqttPacket::handleConnect()
489 489 break;
490 490 }
491 491 default:
492   - throw ProtocolError("Invalid will property in connect.");
  492 + throw ProtocolError("Invalid will property in connect.", ReasonCodes::ProtocolError);
493 493 }
494 494 }
495 495 }
... ... @@ -506,7 +506,7 @@ void MqttPacket::handleConnect()
506 506 username = std::string(readBytes(user_name_length), user_name_length);
507 507  
508 508 if (username.empty())
509   - throw ProtocolError("Username flagged as present, but it's 0 bytes.");
  509 + throw ProtocolError("Username flagged as present, but it's 0 bytes.", ReasonCodes::MalformedPacket);
510 510 }
511 511 if (password_flag)
512 512 {
... ... @@ -637,7 +637,7 @@ void MqttPacket::handleConnect()
637 637 }
638 638 else
639 639 {
640   - throw ProtocolError("Invalid variable header length. Garbage?");
  640 + throw ProtocolError("Invalid variable header length. Garbage?", ReasonCodes::MalformedPacket);
641 641 }
642 642 }
643 643  
... ... @@ -658,13 +658,13 @@ void MqttPacket::handleSubscribe()
658 658 const char firstByteFirstNibble = (first_byte & 0x0F);
659 659  
660 660 if (firstByteFirstNibble != 2)
661   - throw ProtocolError("First LSB of first byte is wrong value for subscribe packet.");
  661 + throw ProtocolError("First LSB of first byte is wrong value for subscribe packet.", ReasonCodes::MalformedPacket);
662 662  
663 663 const uint16_t packet_id = readTwoBytesToUInt16();
664 664  
665 665 if (packet_id == 0)
666 666 {
667   - throw ProtocolError("Packet ID 0 when subscribing is invalid."); // [MQTT-2.3.1-1]
  667 + throw ProtocolError("Packet ID 0 when subscribing is invalid.", ReasonCodes::MalformedPacket); // [MQTT-2.3.1-1]
668 668 }
669 669  
670 670 if (protocolVersion == ProtocolVersion::Mqtt5)
... ... @@ -685,7 +685,7 @@ void MqttPacket::handleSubscribe()
685 685 readUserProperty();
686 686 break;
687 687 default:
688   - throw ProtocolError("Invalid subscribe property.");
  688 + throw ProtocolError("Invalid subscribe property.", ReasonCodes::ProtocolError);
689 689 }
690 690 }
691 691 }
... ... @@ -699,15 +699,15 @@ void MqttPacket::handleSubscribe()
699 699 std::string topic(readBytes(topicLength), topicLength);
700 700  
701 701 if (topic.empty() || !isValidUtf8(topic))
702   - throw ProtocolError("Subscribe topic not valid UTF-8.");
  702 + throw ProtocolError("Subscribe topic not valid UTF-8.", ReasonCodes::MalformedPacket);
703 703  
704 704 if (!isValidSubscribePath(topic))
705   - throw ProtocolError(formatString("Invalid subscribe path: %s", topic.c_str()));
  705 + throw ProtocolError(formatString("Invalid subscribe path: %s", topic.c_str()), ReasonCodes::MalformedPacket);
706 706  
707 707 uint8_t qos = readByte();
708 708  
709 709 if (qos > 2)
710   - throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0.");
  710 + throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0.", ReasonCodes::MalformedPacket);
711 711  
712 712 std::vector<std::string> subtopics;
713 713 splitTopic(topic, subtopics);
... ... @@ -730,7 +730,7 @@ void MqttPacket::handleSubscribe()
730 730 // MQTT-3.8.3-3
731 731 if (subs_reponse_codes.empty())
732 732 {
733   - throw ProtocolError("No topics specified to subscribe to.");
  733 + throw ProtocolError("No topics specified to subscribe to.", ReasonCodes::MalformedPacket);
734 734 }
735 735  
736 736 SubAck subAck(this->protocolVersion, packet_id, subs_reponse_codes);
... ... @@ -743,7 +743,7 @@ void MqttPacket::handleUnsubscribe()
743 743 const char firstByteFirstNibble = (first_byte & 0x0F);
744 744  
745 745 if (firstByteFirstNibble != 2)
746   - throw ProtocolError("First LSB of first byte is wrong value for subscribe packet.");
  746 + throw ProtocolError("First LSB of first byte is wrong value for subscribe packet.", ReasonCodes::MalformedPacket);
747 747  
748 748 const uint16_t packet_id = readTwoBytesToUInt16();
749 749  
... ... @@ -767,7 +767,7 @@ void MqttPacket::handleUnsubscribe()
767 767 readUserProperty();
768 768 break;
769 769 default:
770   - throw ProtocolError("Invalid unsubscribe property.");
  770 + throw ProtocolError("Invalid unsubscribe property.", ReasonCodes::ProtocolError);
771 771 }
772 772 }
773 773 }
... ... @@ -782,7 +782,7 @@ void MqttPacket::handleUnsubscribe()
782 782 std::string topic(readBytes(topicLength), topicLength);
783 783  
784 784 if (topic.empty() || !isValidUtf8(topic))
785   - throw ProtocolError("Subscribe topic not valid UTF-8.");
  785 + throw ProtocolError("Subscribe topic not valid UTF-8.", ReasonCodes::MalformedPacket);
786 786  
787 787 sender->getThreadData()->getSubscriptionStore()->removeSubscription(sender, topic);
788 788 logger->logf(LOG_UNSUBSCRIBE, "Client '%s' unsubscribed from '%s'", sender->repr().c_str(), topic.c_str());
... ... @@ -791,7 +791,7 @@ void MqttPacket::handleUnsubscribe()
791 791 // MQTT-3.10.3-2
792 792 if (numberOfUnsubs == 0)
793 793 {
794   - throw ProtocolError("No topics specified to unsubscribe to.");
  794 + throw ProtocolError("No topics specified to unsubscribe to.", ReasonCodes::MalformedPacket);
795 795 }
796 796  
797 797 UnsubAck unsubAck(this->sender->getProtocolVersion(), packet_id, numberOfUnsubs);
... ... @@ -808,10 +808,10 @@ void MqttPacket::parsePublishData()
808 808 publishData.qos = (first_byte & 0b00000110) >> 1;
809 809  
810 810 if (publishData.qos > 2)
811   - throw ProtocolError("QoS 3 is a protocol violation.");
  811 + throw ProtocolError("QoS 3 is a protocol violation.", ReasonCodes::MalformedPacket);
812 812  
813 813 if (publishData.qos == 0 && dup)
814   - throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.");
  814 + throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket);
815 815  
816 816 publishData.topic = std::string(readBytes(variable_header_length), variable_header_length);
817 817  
... ... @@ -822,7 +822,7 @@ void MqttPacket::parsePublishData()
822 822  
823 823 if (packet_id == 0)
824 824 {
825   - throw ProtocolError("Packet ID 0 when publishing is invalid."); // [MQTT-2.3.1-1]
  825 + throw ProtocolError("Packet ID 0 when publishing is invalid.", ReasonCodes::MalformedPacket); // [MQTT-2.3.1-1]
826 826 }
827 827 }
828 828  
... ... @@ -906,13 +906,13 @@ void MqttPacket::parsePublishData()
906 906 break;
907 907 }
908 908 default:
909   - throw ProtocolError("Invalid property in publish.");
  909 + throw ProtocolError("Invalid property in publish.", ReasonCodes::ProtocolError);
910 910 }
911 911 }
912 912 }
913 913  
914 914 if (publishData.topic.empty())
915   - throw ProtocolError("Empty publish topic");
  915 + throw ProtocolError("Empty publish topic", ReasonCodes::ProtocolError);
916 916  
917 917 payloadLen = remainingAfterPos();
918 918 payloadStart = pos;
... ... @@ -926,7 +926,7 @@ void MqttPacket::handlePublish()
926 926 {
927 927 const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str());
928 928 logger->logf(LOG_WARNING, err.c_str());
929   - throw ProtocolError(err);
  929 + throw ProtocolError(err, ReasonCodes::ProtocolError);
930 930 }
931 931  
932 932 #ifndef NDEBUG
... ... @@ -1017,7 +1017,7 @@ void MqttPacket::handlePubRel()
1017 1017 {
1018 1018 // MQTT-3.6.1-1, but why do we care, and only care for certain control packets?
1019 1019 if (first_byte & 0b1101)
1020   - throw ProtocolError("PUBREL first byte LSB must be 0010.");
  1020 + throw ProtocolError("PUBREL first byte LSB must be 0010.", ReasonCodes::MalformedPacket);
1021 1021  
1022 1022 const uint16_t packet_id = readTwoBytesToUInt16();
1023 1023 sender->getSession()->removeIncomingQoS2MessageId(packet_id);
... ... @@ -1165,7 +1165,7 @@ bool MqttPacket::containsFixedHeader() const
1165 1165 char *MqttPacket::readBytes(size_t length)
1166 1166 {
1167 1167 if (pos + length > bites.size())
1168   - throw ProtocolError("Invalid packet: header specifies invalid length.");
  1168 + throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket);
1169 1169  
1170 1170 char *b = &bites[pos];
1171 1171 pos += length;
... ... @@ -1175,7 +1175,7 @@ char *MqttPacket::readBytes(size_t length)
1175 1175 char MqttPacket::readByte()
1176 1176 {
1177 1177 if (pos + 1 > bites.size())
1178   - throw ProtocolError("Invalid packet: header specifies invalid length.");
  1178 + throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket);
1179 1179  
1180 1180 char b = bites[pos++];
1181 1181 return b;
... ... @@ -1184,7 +1184,7 @@ char MqttPacket::readByte()
1184 1184 void MqttPacket::writeByte(char b)
1185 1185 {
1186 1186 if (pos + 1 > bites.size())
1187   - throw ProtocolError("Exceeding packet size");
  1187 + throw ProtocolError("Exceeding packet size", ReasonCodes::MalformedPacket);
1188 1188  
1189 1189 bites[pos++] = b;
1190 1190 }
... ... @@ -1192,7 +1192,7 @@ void MqttPacket::writeByte(char b)
1192 1192 void MqttPacket::writeUint16(uint16_t x)
1193 1193 {
1194 1194 if (pos + 2 > bites.size())
1195   - throw ProtocolError("Exceeding packet size");
  1195 + throw ProtocolError("Exceeding packet size", ReasonCodes::MalformedPacket);
1196 1196  
1197 1197 const uint8_t a = static_cast<uint8_t>(x >> 8);
1198 1198 const uint8_t b = static_cast<uint8_t>(x);
... ... @@ -1204,7 +1204,7 @@ void MqttPacket::writeUint16(uint16_t x)
1204 1204 void MqttPacket::writeBytes(const char *b, size_t len)
1205 1205 {
1206 1206 if (pos + len > bites.size())
1207   - throw ProtocolError("Exceeding packet size");
  1207 + throw ProtocolError("Exceeding packet size", ReasonCodes::MalformedPacket);
1208 1208  
1209 1209 memcpy(&bites[pos], b, len);
1210 1210 pos += len;
... ... @@ -1232,7 +1232,7 @@ void MqttPacket::writeVariableByteInt(const VariableByteInt &amp;v)
1232 1232 uint16_t MqttPacket::readTwoBytesToUInt16()
1233 1233 {
1234 1234 if (pos + 2 > bites.size())
1235   - throw ProtocolError("Invalid packet: header specifies invalid length.");
  1235 + throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket);
1236 1236  
1237 1237 uint8_t a = bites[pos];
1238 1238 uint8_t b = bites[pos+1];
... ... @@ -1244,7 +1244,7 @@ uint16_t MqttPacket::readTwoBytesToUInt16()
1244 1244 uint32_t MqttPacket::readFourBytesToUint32()
1245 1245 {
1246 1246 if (pos + 4 > bites.size())
1247   - throw ProtocolError("Invalid packet: header specifies invalid length.");
  1247 + throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket);
1248 1248  
1249 1249 const uint8_t a = bites[pos++];
1250 1250 const uint8_t b = bites[pos++];
... ... @@ -1267,13 +1267,13 @@ size_t MqttPacket::decodeVariableByteIntAtPos()
1267 1267 do
1268 1268 {
1269 1269 if (pos >= bites.size())
1270   - throw ProtocolError("Variable byte int length goes out of packet. Corrupt.");
  1270 + throw ProtocolError("Variable byte int length goes out of packet. Corrupt.", ReasonCodes::MalformedPacket);
1271 1271  
1272 1272 encodedByte = bites[pos++];
1273 1273 value += (encodedByte & 127) * multiplier;
1274 1274 multiplier *= 128;
1275 1275 if (multiplier > 128*128*128*128)
1276   - throw ProtocolError("Malformed Remaining Length.");
  1276 + throw ProtocolError("Malformed Remaining Length.", ReasonCodes::MalformedPacket);
1277 1277 }
1278 1278 while ((encodedByte & 128) != 0);
1279 1279  
... ...
subscriptionstore.cpp
... ... @@ -214,7 +214,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr&lt;Client&gt;
214 214 lock_guard.wrlock();
215 215  
216 216 if (client->getClientId().empty())
217   - throw ProtocolError("Trying to store client without an ID.");
  217 + throw ProtocolError("Trying to store client without an ID.", ReasonCodes::ProtocolError);
218 218  
219 219 std::shared_ptr<Session> session;
220 220 auto session_it = sessionsById.find(client->getClientId());
... ...