diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 9329457..fed7cbb 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -404,14 +404,12 @@ void MqttPacket::handleConnect() break; case Mqtt5Properties::AuthenticationMethod: { - const uint16_t len = readTwoBytesToUInt16(); - readBytes(len); + readBytesToString(); break; } case Mqtt5Properties::AuthenticationData: { - const uint16_t len = readTwoBytesToUInt16(); - readBytes(len); + readBytesToString(false); break; } default: @@ -420,8 +418,7 @@ void MqttPacket::handleConnect() } } - uint16_t client_id_length = readTwoBytesToUInt16(); - std::string client_id(readBytes(client_id_length), client_id_length); + std::string client_id = readBytesToString(); std::string username; std::string password; @@ -453,15 +450,13 @@ void MqttPacket::handleConnect() break; case Mqtt5Properties::ContentType: { - const uint16_t len = readTwoBytesToUInt16(); - const std::string contentType(readBytes(len), len); + const std::string contentType = readBytesToString(); willpublish.propertyBuilder->writeContentType(contentType); break; } case Mqtt5Properties::ResponseTopic: { - const uint16_t len = readTwoBytesToUInt16(); - const std::string responseTopic(readBytes(len), len); + const std::string responseTopic = readBytesToString(); willpublish.propertyBuilder->writeResponseTopic(responseTopic); break; } @@ -473,18 +468,13 @@ void MqttPacket::handleConnect() } case Mqtt5Properties::CorrelationData: { - const uint16_t len = readTwoBytesToUInt16(); - const std::string correlationData(readBytes(len), len); + const std::string correlationData = readBytesToString(false); willpublish.propertyBuilder->writeCorrelationData(correlationData); break; } case Mqtt5Properties::UserProperty: { - const uint16_t lenKey = readTwoBytesToUInt16(); - std::string userPropKey(readBytes(lenKey), lenKey); - const uint16_t lenVal = readTwoBytesToUInt16(); - std::string userPropVal(readBytes(lenVal), lenVal); - willpublish.propertyBuilder->writeUserProperty(std::move(userPropKey), std::move(userPropVal)); + readUserProperty(); break; } default: @@ -493,16 +483,14 @@ void MqttPacket::handleConnect() } } - uint16_t will_topic_length = readTwoBytesToUInt16(); - willpublish.topic = std::string(readBytes(will_topic_length), will_topic_length); + willpublish.topic = readBytesToString(true, true); uint16_t will_payload_length = readTwoBytesToUInt16(); willpublish.payload = std::string(readBytes(will_payload_length), will_payload_length); } if (user_name_flag) { - uint16_t user_name_length = readTwoBytesToUInt16(); - username = std::string(readBytes(user_name_length), user_name_length); + username = readBytesToString(false); if (username.empty()) throw ProtocolError("Username flagged as present, but it's 0 bytes.", ReasonCodes::MalformedPacket); @@ -514,21 +502,21 @@ void MqttPacket::handleConnect() throw ProtocolError("MQTT 3.1.1: If the User Name Flag is set to 0, the Password Flag MUST be set to 0."); } - uint16_t password_length = readTwoBytesToUInt16(); - password = std::string(readBytes(password_length), password_length); + password = readBytesToString(false); if (password.empty()) throw ProtocolError("Password flagged as present, but it's 0 bytes.", ReasonCodes::MalformedPacket); } - // The specs don't really say what to do when client id not UTF8, so including here. - if (!isValidUtf8(client_id) || !isValidUtf8(username) || !isValidUtf8(password) || !isValidUtf8(willpublish.topic)) + // I deferred the initial UTF8 check on username to be able to give an appropriate connack here, but to me, the specs + // are actually vague whether 'BadUserNameOrPassword' should be given on invalid UTF8. + if (!isValidUtf8(username)) { ConnAck connAck(protocolVersion, ReasonCodes::BadUserNameOrPassword); MqttPacket response(connAck); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); - logger->logf(LOG_ERR, "Client ID, username, password or will topic has invalid UTF8: ", client_id.c_str()); + logger->logf(LOG_ERR, "Username has invalid UTF8: %s", username.c_str()); return; } @@ -679,14 +667,12 @@ void MqttPacket::handleDisconnect() } case Mqtt5Properties::ReasonString: { - const uint16_t len = readTwoBytesToUInt16(); - reasonString = std::string(readBytes(len), len); + reasonString = readBytesToString(); break; } case Mqtt5Properties::ServerReference: { - const uint16_t len = readTwoBytesToUInt16(); - readBytes(len); + readBytesToString(); break; } case Mqtt5Properties::UserProperty: @@ -754,11 +740,10 @@ void MqttPacket::handleSubscribe() std::list subs_reponse_codes; while (remainingAfterPos() > 0) { - uint16_t topicLength = readTwoBytesToUInt16(); - std::string topic(readBytes(topicLength), topicLength); + std::string topic = readBytesToString(true); - if (topic.empty() || !isValidUtf8(topic)) - throw ProtocolError("Subscribe topic not valid UTF-8.", ReasonCodes::MalformedPacket); + if (topic.empty()) + throw ProtocolError("Subscribe topic is empty.", ReasonCodes::MalformedPacket); if (!isValidSubscribePath(topic)) throw ProtocolError(formatString("Invalid subscribe path: %s", topic.c_str()), ReasonCodes::MalformedPacket); @@ -837,11 +822,10 @@ void MqttPacket::handleUnsubscribe() { numberOfUnsubs++; - uint16_t topicLength = readTwoBytesToUInt16(); - std::string topic(readBytes(topicLength), topicLength); + std::string topic = readBytesToString(); - if (topic.empty() || !isValidUtf8(topic)) - throw ProtocolError("Subscribe topic not valid UTF-8.", ReasonCodes::MalformedPacket); + if (topic.empty()) + throw ProtocolError("Subscribe topic is empty.", ReasonCodes::MalformedPacket); sender->getThreadData()->getSubscriptionStore()->removeSubscription(sender, topic); logger->logf(LOG_UNSUBSCRIBE, "Client '%s' unsubscribed from '%s'", sender->repr().c_str(), topic.c_str()); @@ -860,8 +844,6 @@ void MqttPacket::handleUnsubscribe() void MqttPacket::parsePublishData() { - const uint16_t variable_header_length = readTwoBytesToUInt16(); - publishData.retain = (first_byte & 0b00000001); bool dup = !!(first_byte & 0b00001000); publishData.qos = (first_byte & 0b00000110) >> 1; @@ -872,7 +854,7 @@ void MqttPacket::parsePublishData() if (publishData.qos == 0 && dup) throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.", ReasonCodes::MalformedPacket); - publishData.topic = std::string(readBytes(variable_header_length), variable_header_length); + publishData.topic = readBytesToString(true, true); if (publishData.qos) { @@ -928,27 +910,20 @@ void MqttPacket::parsePublishData() case Mqtt5Properties::ResponseTopic: { publishData.constructPropertyBuilder(); - const uint16_t len = readTwoBytesToUInt16(); - const std::string responseTopic(readBytes(len), len); + const std::string responseTopic = readBytesToString(); publishData.propertyBuilder->writeResponseTopic(responseTopic); break; } case Mqtt5Properties::CorrelationData: { publishData.constructPropertyBuilder(); - const uint16_t len = readTwoBytesToUInt16(); - const std::string correlationData(readBytes(len), len); + const std::string correlationData = readBytesToString(false); publishData.propertyBuilder->writeCorrelationData(correlationData); break; } case Mqtt5Properties::UserProperty: { - publishData.constructPropertyBuilder(); - const uint16_t lenKey = readTwoBytesToUInt16(); - std::string userPropKey(readBytes(lenKey), lenKey); - const uint16_t lenVal = readTwoBytesToUInt16(); - std::string userPropVal(readBytes(lenVal), lenVal); - publishData.propertyBuilder->writeUserProperty(std::move(userPropKey), std::move(userPropVal)); + readUserProperty(); break; } case Mqtt5Properties::SubscriptionIdentifier: @@ -981,13 +956,6 @@ void MqttPacket::handlePublish() { parsePublishData(); - if (!isValidUtf8(publishData.topic, true)) - { - const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str()); - logger->logf(LOG_WARNING, err.c_str()); - throw ProtocolError(err, ReasonCodes::ProtocolError); - } - #ifndef NDEBUG logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), publishData.qos, publishData.retain, dup); #endif @@ -1350,14 +1318,29 @@ void MqttPacket::readUserProperty() { this->publishData.constructPropertyBuilder(); - const uint16_t len = readTwoBytesToUInt16(); - std::string key(readBytes(len), len); - const uint16_t len2 = readTwoBytesToUInt16(); - std::string value(readBytes(len2), len2); + std::string key = readBytesToString(); + std::string value = readBytesToString(); this->publishData.propertyBuilder->writeUserProperty(std::move(key), std::move(value)); } +std::string MqttPacket::readBytesToString(bool validateUtf8, bool alsoCheckInvalidPublishChars) +{ + const uint16_t len = readTwoBytesToUInt16(); + std::string result(readBytes(len), len); + + if (validateUtf8) + { + if (!isValidUtf8(result, alsoCheckInvalidPublishChars)) + { + logger->logf(LOG_DEBUG, "Data of invalid UTF-8 string or publish topic: %s", result.c_str()); + throw ProtocolError("Invalid UTF8 string detected, or invalid publish characters.", ReasonCodes::MalformedPacket); + } + } + + return result; +} + const std::vector> *MqttPacket::getUserProperties() const { if (this->publishData.propertyBuilder) diff --git a/mqttpacket.h b/mqttpacket.h index 4c442ac..a2f6152 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -82,6 +82,7 @@ class MqttPacket size_t remainingAfterPos(); size_t decodeVariableByteIntAtPos(); void readUserProperty(); + std::string readBytesToString(bool validateUtf8 = true, bool alsoCheckInvalidPublishChars = false); void calculateRemainingLength(); diff --git a/threadlocalutils.cpp b/threadlocalutils.cpp index 3307f7a..a6809e5 100644 --- a/threadlocalutils.cpp +++ b/threadlocalutils.cpp @@ -51,6 +51,12 @@ std::vector *SimdUtils::splitTopic(const std::string &topic, std::v return &output; } +/** + * @brief SimdUtils::isValidUtf8 checks UTF-8 validity 16 bytes at a time, using SSE 4.2. + * @param s + * @param alsoCheckInvalidPublishChars is for checking the presence of '#' and '+' which is not allowed in publishes. + * @return + */ bool SimdUtils::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) { const int len = s.size();