diff --git a/client.cpp b/client.cpp index 088a9c1..c114f1a 100644 --- a/client.cpp +++ b/client.cpp @@ -334,6 +334,25 @@ void Client::resetBuffersIfEligible() writebuf.resetSizeIfEligable(initialBufferSize); } +void Client::setTopicAlias(const uint16_t alias_id, const std::string &topic) +{ + if (alias_id == 0) + throw ProtocolError("Client tried to set topic alias 0, which is a protocol error."); + + if (topic.empty()) + return; + + if (alias_id > this->maxTopicAliases) + throw ProtocolError("Client exceeded max topic aliases."); + + this->topicAliases[alias_id] = topic; +} + +const std::string &Client::getTopicAlias(const uint16_t id) +{ + return this->topicAliases[id]; +} + #ifndef NDEBUG /** * @brief IoWrapper::setFakeUpgraded(). diff --git a/client.h b/client.h index fc397f7..737b64b 100644 --- a/client.h +++ b/client.h @@ -82,6 +82,8 @@ class Client std::shared_ptr session; + std::unordered_map topicAliases; + Logger *logger = Logger::getInstance(); void setReadyForWriting(bool val); @@ -137,6 +139,9 @@ public: std::string getKeepAliveInfoString() const; void resetBuffersIfEligible(); + void setTopicAlias(const uint16_t alias_id, const std::string &topic); + const std::string &getTopicAlias(const uint16_t id); + #ifndef NDEBUG void setFakeUpgraded(); #endif diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 6864d9f..689bc53 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -343,7 +343,7 @@ void MqttPacket::handleConnect() uint16_t max_qos_packets = settings.maxQosMsgPendingPerClient; uint32_t session_expire = settings.expireSessionsAfterSeconds > 0 ? settings.expireSessionsAfterSeconds : std::numeric_limits::max(); uint32_t max_packet_size = settings.maxPacketSize; - uint16_t max_topic_aliases = 0; + uint16_t max_topic_aliases = settings.maxTopicAliases; bool request_response_information = false; bool request_problem_information = false; @@ -750,10 +750,7 @@ void MqttPacket::handleUnsubscribe() void MqttPacket::handlePublish() { - uint16_t variable_header_length = readTwoBytesToUInt16(); - - if (variable_header_length == 0) - throw ProtocolError("Empty publish topic"); + const uint16_t variable_header_length = readTwoBytesToUInt16(); bool retain = (first_byte & 0b00000001); bool dup = !!(first_byte & 0b00001000); @@ -767,19 +764,6 @@ void MqttPacket::handlePublish() throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); publishData.topic = std::string(readBytes(variable_header_length), variable_header_length); - splitTopic(publishData.topic, publishData.subtopics); - - if (!isValidUtf8(publishData.topic, true)) - { - logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str()); - return; - } - -#ifndef NDEBUG - logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup); -#endif - - sender->getThreadData()->incrementReceivedMessageCount(); if (qos) { @@ -820,6 +804,7 @@ void MqttPacket::handlePublish() const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; + // TODO: don't do this when the only properties are expiry and topic alias; they don't need the builder. if (proplen > 0) publishData.constructPropertyBuilder(); @@ -836,7 +821,21 @@ void MqttPacket::handlePublish() publishData.createdAt = std::chrono::steady_clock::now(); publishData.expiresAfter = std::chrono::seconds(readFourBytesToUint32()); case Mqtt5Properties::TopicAlias: + { + const uint16_t alias_id = readTwoBytesToUInt16(); + this->hasTopicAlias = true; + + if (publishData.topic.empty()) + { + publishData.topic = sender->getTopicAlias(alias_id); + } + else + { + sender->setTopicAlias(alias_id, publishData.topic); + } + break; + } case Mqtt5Properties::ResponseTopic: { const uint16_t len = readTwoBytesToUInt16(); @@ -878,6 +877,23 @@ void MqttPacket::handlePublish() } } + splitTopic(publishData.topic, publishData.subtopics); + + if (!isValidUtf8(publishData.topic, true)) + { + logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str()); + return; + } + + if (publishData.topic.empty()) + throw ProtocolError("Empty publish topic"); + +#ifndef NDEBUG + logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup); +#endif + + sender->getThreadData()->incrementReceivedMessageCount(); + payloadLen = remainingAfterPos(); payloadStart = pos; @@ -1237,7 +1253,7 @@ bool MqttPacket::containsClientSpecificProperties() const if (protocolVersion <= ProtocolVersion::Mqtt311 || !publishData.propertyBuilder) return false; - if (publishData.createdAt.time_since_epoch().count() == 0) // TODO: better + if (publishData.createdAt.time_since_epoch().count() == 0 || this->hasTopicAlias) // TODO: better { return true; } diff --git a/mqttpacket.h b/mqttpacket.h index 929c8de..5283060 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -54,6 +54,7 @@ class MqttPacket ProtocolVersion protocolVersion = ProtocolVersion::None; size_t payloadStart = 0; size_t payloadLen = 0; + bool hasTopicAlias = false; Logger *logger = Logger::getInstance(); char *readBytes(size_t length); diff --git a/settings.h b/settings.h index 6910629..e60c506 100644 --- a/settings.h +++ b/settings.h @@ -42,6 +42,7 @@ public: bool authPluginSerializeAuthChecks = false; int clientInitialBufferSize = 1024; // Must be power of 2 int maxPacketSize = 268435461; // 256 MB + 5 + uint16_t maxTopicAliases = 65535; #ifdef TESTING bool logDebug = true; #else