Commit 9dab7a48e665f59212a6bcf5f2b726ff99098d22

Authored by Wiebe Cazemier
1 parent 8ce485ee

Beginning of topic aliases

client.cpp
@@ -334,6 +334,25 @@ void Client::resetBuffersIfEligible() @@ -334,6 +334,25 @@ void Client::resetBuffersIfEligible()
334 writebuf.resetSizeIfEligable(initialBufferSize); 334 writebuf.resetSizeIfEligable(initialBufferSize);
335 } 335 }
336 336
  337 +void Client::setTopicAlias(const uint16_t alias_id, const std::string &topic)
  338 +{
  339 + if (alias_id == 0)
  340 + throw ProtocolError("Client tried to set topic alias 0, which is a protocol error.");
  341 +
  342 + if (topic.empty())
  343 + return;
  344 +
  345 + if (alias_id > this->maxTopicAliases)
  346 + throw ProtocolError("Client exceeded max topic aliases.");
  347 +
  348 + this->topicAliases[alias_id] = topic;
  349 +}
  350 +
  351 +const std::string &Client::getTopicAlias(const uint16_t id)
  352 +{
  353 + return this->topicAliases[id];
  354 +}
  355 +
337 #ifndef NDEBUG 356 #ifndef NDEBUG
338 /** 357 /**
339 * @brief IoWrapper::setFakeUpgraded(). 358 * @brief IoWrapper::setFakeUpgraded().
client.h
@@ -82,6 +82,8 @@ class Client @@ -82,6 +82,8 @@ class Client
82 82
83 std::shared_ptr<Session> session; 83 std::shared_ptr<Session> session;
84 84
  85 + std::unordered_map<uint16_t, std::string> topicAliases;
  86 +
85 Logger *logger = Logger::getInstance(); 87 Logger *logger = Logger::getInstance();
86 88
87 void setReadyForWriting(bool val); 89 void setReadyForWriting(bool val);
@@ -137,6 +139,9 @@ public: @@ -137,6 +139,9 @@ public:
137 std::string getKeepAliveInfoString() const; 139 std::string getKeepAliveInfoString() const;
138 void resetBuffersIfEligible(); 140 void resetBuffersIfEligible();
139 141
  142 + void setTopicAlias(const uint16_t alias_id, const std::string &topic);
  143 + const std::string &getTopicAlias(const uint16_t id);
  144 +
140 #ifndef NDEBUG 145 #ifndef NDEBUG
141 void setFakeUpgraded(); 146 void setFakeUpgraded();
142 #endif 147 #endif
mqttpacket.cpp
@@ -343,7 +343,7 @@ void MqttPacket::handleConnect() @@ -343,7 +343,7 @@ void MqttPacket::handleConnect()
343 uint16_t max_qos_packets = settings.maxQosMsgPendingPerClient; 343 uint16_t max_qos_packets = settings.maxQosMsgPendingPerClient;
344 uint32_t session_expire = settings.expireSessionsAfterSeconds > 0 ? settings.expireSessionsAfterSeconds : std::numeric_limits<uint32_t>::max(); 344 uint32_t session_expire = settings.expireSessionsAfterSeconds > 0 ? settings.expireSessionsAfterSeconds : std::numeric_limits<uint32_t>::max();
345 uint32_t max_packet_size = settings.maxPacketSize; 345 uint32_t max_packet_size = settings.maxPacketSize;
346 - uint16_t max_topic_aliases = 0; 346 + uint16_t max_topic_aliases = settings.maxTopicAliases;
347 bool request_response_information = false; 347 bool request_response_information = false;
348 bool request_problem_information = false; 348 bool request_problem_information = false;
349 349
@@ -750,10 +750,7 @@ void MqttPacket::handleUnsubscribe() @@ -750,10 +750,7 @@ void MqttPacket::handleUnsubscribe()
750 750
751 void MqttPacket::handlePublish() 751 void MqttPacket::handlePublish()
752 { 752 {
753 - uint16_t variable_header_length = readTwoBytesToUInt16();  
754 -  
755 - if (variable_header_length == 0)  
756 - throw ProtocolError("Empty publish topic"); 753 + const uint16_t variable_header_length = readTwoBytesToUInt16();
757 754
758 bool retain = (first_byte & 0b00000001); 755 bool retain = (first_byte & 0b00000001);
759 bool dup = !!(first_byte & 0b00001000); 756 bool dup = !!(first_byte & 0b00001000);
@@ -767,19 +764,6 @@ void MqttPacket::handlePublish() @@ -767,19 +764,6 @@ void MqttPacket::handlePublish()
767 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); 764 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.");
768 765
769 publishData.topic = std::string(readBytes(variable_header_length), variable_header_length); 766 publishData.topic = std::string(readBytes(variable_header_length), variable_header_length);
770 - splitTopic(publishData.topic, publishData.subtopics);  
771 -  
772 - if (!isValidUtf8(publishData.topic, true))  
773 - {  
774 - logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str());  
775 - return;  
776 - }  
777 -  
778 -#ifndef NDEBUG  
779 - logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup);  
780 -#endif  
781 -  
782 - sender->getThreadData()->incrementReceivedMessageCount();  
783 767
784 if (qos) 768 if (qos)
785 { 769 {
@@ -820,6 +804,7 @@ void MqttPacket::handlePublish() @@ -820,6 +804,7 @@ void MqttPacket::handlePublish()
820 const size_t proplen = decodeVariableByteIntAtPos(); 804 const size_t proplen = decodeVariableByteIntAtPos();
821 const size_t prop_end_at = pos + proplen; 805 const size_t prop_end_at = pos + proplen;
822 806
  807 + // TODO: don't do this when the only properties are expiry and topic alias; they don't need the builder.
823 if (proplen > 0) 808 if (proplen > 0)
824 publishData.constructPropertyBuilder(); 809 publishData.constructPropertyBuilder();
825 810
@@ -836,7 +821,21 @@ void MqttPacket::handlePublish() @@ -836,7 +821,21 @@ void MqttPacket::handlePublish()
836 publishData.createdAt = std::chrono::steady_clock::now(); 821 publishData.createdAt = std::chrono::steady_clock::now();
837 publishData.expiresAfter = std::chrono::seconds(readFourBytesToUint32()); 822 publishData.expiresAfter = std::chrono::seconds(readFourBytesToUint32());
838 case Mqtt5Properties::TopicAlias: 823 case Mqtt5Properties::TopicAlias:
  824 + {
  825 + const uint16_t alias_id = readTwoBytesToUInt16();
  826 + this->hasTopicAlias = true;
  827 +
  828 + if (publishData.topic.empty())
  829 + {
  830 + publishData.topic = sender->getTopicAlias(alias_id);
  831 + }
  832 + else
  833 + {
  834 + sender->setTopicAlias(alias_id, publishData.topic);
  835 + }
  836 +
839 break; 837 break;
  838 + }
840 case Mqtt5Properties::ResponseTopic: 839 case Mqtt5Properties::ResponseTopic:
841 { 840 {
842 const uint16_t len = readTwoBytesToUInt16(); 841 const uint16_t len = readTwoBytesToUInt16();
@@ -878,6 +877,23 @@ void MqttPacket::handlePublish() @@ -878,6 +877,23 @@ void MqttPacket::handlePublish()
878 } 877 }
879 } 878 }
880 879
  880 + splitTopic(publishData.topic, publishData.subtopics);
  881 +
  882 + if (!isValidUtf8(publishData.topic, true))
  883 + {
  884 + logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str());
  885 + return;
  886 + }
  887 +
  888 + if (publishData.topic.empty())
  889 + throw ProtocolError("Empty publish topic");
  890 +
  891 +#ifndef NDEBUG
  892 + logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup);
  893 +#endif
  894 +
  895 + sender->getThreadData()->incrementReceivedMessageCount();
  896 +
881 payloadLen = remainingAfterPos(); 897 payloadLen = remainingAfterPos();
882 payloadStart = pos; 898 payloadStart = pos;
883 899
@@ -1237,7 +1253,7 @@ bool MqttPacket::containsClientSpecificProperties() const @@ -1237,7 +1253,7 @@ bool MqttPacket::containsClientSpecificProperties() const
1237 if (protocolVersion <= ProtocolVersion::Mqtt311 || !publishData.propertyBuilder) 1253 if (protocolVersion <= ProtocolVersion::Mqtt311 || !publishData.propertyBuilder)
1238 return false; 1254 return false;
1239 1255
1240 - if (publishData.createdAt.time_since_epoch().count() == 0) // TODO: better 1256 + if (publishData.createdAt.time_since_epoch().count() == 0 || this->hasTopicAlias) // TODO: better
1241 { 1257 {
1242 return true; 1258 return true;
1243 } 1259 }
mqttpacket.h
@@ -54,6 +54,7 @@ class MqttPacket @@ -54,6 +54,7 @@ class MqttPacket
54 ProtocolVersion protocolVersion = ProtocolVersion::None; 54 ProtocolVersion protocolVersion = ProtocolVersion::None;
55 size_t payloadStart = 0; 55 size_t payloadStart = 0;
56 size_t payloadLen = 0; 56 size_t payloadLen = 0;
  57 + bool hasTopicAlias = false;
57 Logger *logger = Logger::getInstance(); 58 Logger *logger = Logger::getInstance();
58 59
59 char *readBytes(size_t length); 60 char *readBytes(size_t length);
settings.h
@@ -42,6 +42,7 @@ public: @@ -42,6 +42,7 @@ public:
42 bool authPluginSerializeAuthChecks = false; 42 bool authPluginSerializeAuthChecks = false;
43 int clientInitialBufferSize = 1024; // Must be power of 2 43 int clientInitialBufferSize = 1024; // Must be power of 2
44 int maxPacketSize = 268435461; // 256 MB + 5 44 int maxPacketSize = 268435461; // 256 MB + 5
  45 + uint16_t maxTopicAliases = 65535;
45 #ifdef TESTING 46 #ifdef TESTING
46 bool logDebug = true; 47 bool logDebug = true;
47 #else 48 #else