diff --git a/client.cpp b/client.cpp index 79137f5..45a9255 100644 --- a/client.cpp +++ b/client.cpp @@ -358,7 +358,7 @@ void Client::resetBuffersIfEligible() void Client::setTopicAlias(const uint16_t alias_id, const std::string &topic) { if (alias_id == 0) - throw ProtocolError("Client tried to set topic alias 0, which is a protocol error."); + throw ProtocolError("Client tried to set topic alias 0, which is a protocol error.", ReasonCodes::ProtocolError); if (topic.empty()) return; @@ -367,7 +367,8 @@ void Client::setTopicAlias(const uint16_t alias_id, const std::string &topic) // The specs actually say "The Client MUST NOT send a Topic Alias [...] to the Server greater than this value [Topic Alias Maximum]". So, it's not about count. if (alias_id > settings->maxIncomingTopicAliasValue) - throw ProtocolError(formatString("Client tried to set more topic aliases than the server max of %d per client", settings->maxIncomingTopicAliasValue)); + throw ProtocolError(formatString("Client tried to set more topic aliases than the server max of %d per client", settings->maxIncomingTopicAliasValue), + ReasonCodes::ProtocolError); this->incomingTopicAliases[alias_id] = topic; } diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 3cc6773..1ee3806 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -133,7 +133,7 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, Publish &_publish) { if (_publish.topic.length() > 0xFFFF) { - throw ProtocolError("Topic path too long."); + throw ProtocolError("Topic path too long.", ReasonCodes::ProtocolError); } this->protocolVersion = protocolVersion; @@ -232,7 +232,7 @@ void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packe fixed_header_length++; if (fixed_header_length > 5) - throw ProtocolError("Packet signifies more than 5 bytes in variable length header. Invalid."); + throw ProtocolError("Packet signifies more than 5 bytes in variable length header. Invalid.", ReasonCodes::MalformedPacket); // This happens when you only don't have all the bytes that specify the remaining length. if (fixed_header_length > buf.usedBytes()) @@ -242,19 +242,19 @@ void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packe packet_length += (encodedByte & 127) * multiplier; multiplier *= 128; if (multiplier > 128*128*128*128) - throw ProtocolError("Malformed Remaining Length."); + throw ProtocolError("Malformed Remaining Length.", ReasonCodes::MalformedPacket); } while ((encodedByte & 128) != 0); packet_length += fixed_header_length; if (sender && !sender->getAuthenticated() && packet_length >= 1024*1024) { - throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes."); + throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.", ReasonCodes::ProtocolError); } if (packet_length > ABSOLUTE_MAX_PACKET_SIZE) { - throw ProtocolError("A client sends a packet claiming to be bigger than the maximum MQTT allows."); + throw ProtocolError("A client sends a packet claiming to be bigger than the maximum MQTT allows.", ReasonCodes::ProtocolError); } if (packet_length <= buf.usedBytes()) @@ -269,7 +269,7 @@ void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packe void MqttPacket::handle() { if (packetType == PacketType::Reserved) - throw ProtocolError("Packet type 0 specified, which is reserved and invalid."); + throw ProtocolError("Packet type 0 specified, which is reserved and invalid.", ReasonCodes::MalformedPacket); if (packetType != PacketType::CONNECT) { @@ -305,7 +305,7 @@ void MqttPacket::handle() void MqttPacket::handleConnect() { if (sender->hasConnectPacketSeen()) - throw ProtocolError("Client already sent a CONNECT."); + throw ProtocolError("Client already sent a CONNECT.", ReasonCodes::ProtocolError); std::shared_ptr subscriptionStore = sender->getThreadData()->getSubscriptionStore(); @@ -348,7 +348,7 @@ void MqttPacket::handleConnect() bool reserved = !!(flagByte & 0b00000001); if (reserved) - throw ProtocolError("Protocol demands reserved flag in CONNECT is 0"); + throw ProtocolError("Protocol demands reserved flag in CONNECT is 0", ReasonCodes::MalformedPacket); bool user_name_flag = !!(flagByte & 0b10000000); @@ -359,7 +359,7 @@ void MqttPacket::handleConnect() bool clean_start = !!(flagByte & 0b00000010); if (will_qos > 2) - throw ProtocolError("Invalid QoS for will."); + throw ProtocolError("Invalid QoS for will.", ReasonCodes::MalformedPacket); uint16_t keep_alive = readTwoBytesToUInt16(); @@ -415,7 +415,7 @@ void MqttPacket::handleConnect() break; } default: - throw ProtocolError("Invalid connect property."); + throw ProtocolError("Invalid connect property.", ReasonCodes::ProtocolError); } } } @@ -489,7 +489,7 @@ void MqttPacket::handleConnect() break; } default: - throw ProtocolError("Invalid will property in connect."); + throw ProtocolError("Invalid will property in connect.", ReasonCodes::ProtocolError); } } } @@ -506,7 +506,7 @@ void MqttPacket::handleConnect() username = std::string(readBytes(user_name_length), user_name_length); if (username.empty()) - throw ProtocolError("Username flagged as present, but it's 0 bytes."); + throw ProtocolError("Username flagged as present, but it's 0 bytes.", ReasonCodes::MalformedPacket); } if (password_flag) { @@ -637,7 +637,7 @@ void MqttPacket::handleConnect() } else { - throw ProtocolError("Invalid variable header length. Garbage?"); + throw ProtocolError("Invalid variable header length. Garbage?", ReasonCodes::MalformedPacket); } } @@ -658,13 +658,13 @@ void MqttPacket::handleSubscribe() const char firstByteFirstNibble = (first_byte & 0x0F); if (firstByteFirstNibble != 2) - throw ProtocolError("First LSB of first byte is wrong value for subscribe packet."); + throw ProtocolError("First LSB of first byte is wrong value for subscribe packet.", ReasonCodes::MalformedPacket); const uint16_t packet_id = readTwoBytesToUInt16(); if (packet_id == 0) { - throw ProtocolError("Packet ID 0 when subscribing is invalid."); // [MQTT-2.3.1-1] + throw ProtocolError("Packet ID 0 when subscribing is invalid.", ReasonCodes::MalformedPacket); // [MQTT-2.3.1-1] } if (protocolVersion == ProtocolVersion::Mqtt5) @@ -685,7 +685,7 @@ void MqttPacket::handleSubscribe() readUserProperty(); break; default: - throw ProtocolError("Invalid subscribe property."); + throw ProtocolError("Invalid subscribe property.", ReasonCodes::ProtocolError); } } } @@ -699,15 +699,15 @@ void MqttPacket::handleSubscribe() std::string topic(readBytes(topicLength), topicLength); if (topic.empty() || !isValidUtf8(topic)) - throw ProtocolError("Subscribe topic not valid UTF-8."); + throw ProtocolError("Subscribe topic not valid UTF-8.", ReasonCodes::MalformedPacket); if (!isValidSubscribePath(topic)) - throw ProtocolError(formatString("Invalid subscribe path: %s", topic.c_str())); + throw ProtocolError(formatString("Invalid subscribe path: %s", topic.c_str()), ReasonCodes::MalformedPacket); uint8_t qos = readByte(); if (qos > 2) - throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0."); + throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0.", ReasonCodes::MalformedPacket); std::vector subtopics; splitTopic(topic, subtopics); @@ -730,7 +730,7 @@ void MqttPacket::handleSubscribe() // MQTT-3.8.3-3 if (subs_reponse_codes.empty()) { - throw ProtocolError("No topics specified to subscribe to."); + throw ProtocolError("No topics specified to subscribe to.", ReasonCodes::MalformedPacket); } SubAck subAck(this->protocolVersion, packet_id, subs_reponse_codes); @@ -743,7 +743,7 @@ void MqttPacket::handleUnsubscribe() const char firstByteFirstNibble = (first_byte & 0x0F); if (firstByteFirstNibble != 2) - throw ProtocolError("First LSB of first byte is wrong value for subscribe packet."); + throw ProtocolError("First LSB of first byte is wrong value for subscribe packet.", ReasonCodes::MalformedPacket); const uint16_t packet_id = readTwoBytesToUInt16(); @@ -767,7 +767,7 @@ void MqttPacket::handleUnsubscribe() readUserProperty(); break; default: - throw ProtocolError("Invalid unsubscribe property."); + throw ProtocolError("Invalid unsubscribe property.", ReasonCodes::ProtocolError); } } } @@ -782,7 +782,7 @@ void MqttPacket::handleUnsubscribe() std::string topic(readBytes(topicLength), topicLength); if (topic.empty() || !isValidUtf8(topic)) - throw ProtocolError("Subscribe topic not valid UTF-8."); + throw ProtocolError("Subscribe topic not valid UTF-8.", ReasonCodes::MalformedPacket); sender->getThreadData()->getSubscriptionStore()->removeSubscription(sender, topic); logger->logf(LOG_UNSUBSCRIBE, "Client '%s' unsubscribed from '%s'", sender->repr().c_str(), topic.c_str()); @@ -791,7 +791,7 @@ void MqttPacket::handleUnsubscribe() // MQTT-3.10.3-2 if (numberOfUnsubs == 0) { - throw ProtocolError("No topics specified to unsubscribe to."); + throw ProtocolError("No topics specified to unsubscribe to.", ReasonCodes::MalformedPacket); } UnsubAck unsubAck(this->sender->getProtocolVersion(), packet_id, numberOfUnsubs); @@ -808,10 +808,10 @@ void MqttPacket::parsePublishData() publishData.qos = (first_byte & 0b00000110) >> 1; if (publishData.qos > 2) - throw ProtocolError("QoS 3 is a protocol violation."); + throw ProtocolError("QoS 3 is a protocol violation.", ReasonCodes::MalformedPacket); if (publishData.qos == 0 && dup) - throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); + throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket); publishData.topic = std::string(readBytes(variable_header_length), variable_header_length); @@ -822,7 +822,7 @@ void MqttPacket::parsePublishData() if (packet_id == 0) { - throw ProtocolError("Packet ID 0 when publishing is invalid."); // [MQTT-2.3.1-1] + throw ProtocolError("Packet ID 0 when publishing is invalid.", ReasonCodes::MalformedPacket); // [MQTT-2.3.1-1] } } @@ -906,13 +906,13 @@ void MqttPacket::parsePublishData() break; } default: - throw ProtocolError("Invalid property in publish."); + throw ProtocolError("Invalid property in publish.", ReasonCodes::ProtocolError); } } } if (publishData.topic.empty()) - throw ProtocolError("Empty publish topic"); + throw ProtocolError("Empty publish topic", ReasonCodes::ProtocolError); payloadLen = remainingAfterPos(); payloadStart = pos; @@ -926,7 +926,7 @@ void MqttPacket::handlePublish() { const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str()); logger->logf(LOG_WARNING, err.c_str()); - throw ProtocolError(err); + throw ProtocolError(err, ReasonCodes::ProtocolError); } #ifndef NDEBUG @@ -1017,7 +1017,7 @@ void MqttPacket::handlePubRel() { // MQTT-3.6.1-1, but why do we care, and only care for certain control packets? if (first_byte & 0b1101) - throw ProtocolError("PUBREL first byte LSB must be 0010."); + throw ProtocolError("PUBREL first byte LSB must be 0010.", ReasonCodes::MalformedPacket); const uint16_t packet_id = readTwoBytesToUInt16(); sender->getSession()->removeIncomingQoS2MessageId(packet_id); @@ -1165,7 +1165,7 @@ bool MqttPacket::containsFixedHeader() const char *MqttPacket::readBytes(size_t length) { if (pos + length > bites.size()) - throw ProtocolError("Invalid packet: header specifies invalid length."); + throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket); char *b = &bites[pos]; pos += length; @@ -1175,7 +1175,7 @@ char *MqttPacket::readBytes(size_t length) char MqttPacket::readByte() { if (pos + 1 > bites.size()) - throw ProtocolError("Invalid packet: header specifies invalid length."); + throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket); char b = bites[pos++]; return b; @@ -1184,7 +1184,7 @@ char MqttPacket::readByte() void MqttPacket::writeByte(char b) { if (pos + 1 > bites.size()) - throw ProtocolError("Exceeding packet size"); + throw ProtocolError("Exceeding packet size", ReasonCodes::MalformedPacket); bites[pos++] = b; } @@ -1192,7 +1192,7 @@ void MqttPacket::writeByte(char b) void MqttPacket::writeUint16(uint16_t x) { if (pos + 2 > bites.size()) - throw ProtocolError("Exceeding packet size"); + throw ProtocolError("Exceeding packet size", ReasonCodes::MalformedPacket); const uint8_t a = static_cast(x >> 8); const uint8_t b = static_cast(x); @@ -1204,7 +1204,7 @@ void MqttPacket::writeUint16(uint16_t x) void MqttPacket::writeBytes(const char *b, size_t len) { if (pos + len > bites.size()) - throw ProtocolError("Exceeding packet size"); + throw ProtocolError("Exceeding packet size", ReasonCodes::MalformedPacket); memcpy(&bites[pos], b, len); pos += len; @@ -1232,7 +1232,7 @@ void MqttPacket::writeVariableByteInt(const VariableByteInt &v) uint16_t MqttPacket::readTwoBytesToUInt16() { if (pos + 2 > bites.size()) - throw ProtocolError("Invalid packet: header specifies invalid length."); + throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket); uint8_t a = bites[pos]; uint8_t b = bites[pos+1]; @@ -1244,7 +1244,7 @@ uint16_t MqttPacket::readTwoBytesToUInt16() uint32_t MqttPacket::readFourBytesToUint32() { if (pos + 4 > bites.size()) - throw ProtocolError("Invalid packet: header specifies invalid length."); + throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket); const uint8_t a = bites[pos++]; const uint8_t b = bites[pos++]; @@ -1267,13 +1267,13 @@ size_t MqttPacket::decodeVariableByteIntAtPos() do { if (pos >= bites.size()) - throw ProtocolError("Variable byte int length goes out of packet. Corrupt."); + throw ProtocolError("Variable byte int length goes out of packet. Corrupt.", ReasonCodes::MalformedPacket); encodedByte = bites[pos++]; value += (encodedByte & 127) * multiplier; multiplier *= 128; if (multiplier > 128*128*128*128) - throw ProtocolError("Malformed Remaining Length."); + throw ProtocolError("Malformed Remaining Length.", ReasonCodes::MalformedPacket); } while ((encodedByte & 128) != 0); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index c123b71..7526112 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -214,7 +214,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr lock_guard.wrlock(); if (client->getClientId().empty()) - throw ProtocolError("Trying to store client without an ID."); + throw ProtocolError("Trying to store client without an ID.", ReasonCodes::ProtocolError); std::shared_ptr session; auto session_it = sessionsById.find(client->getClientId());