diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index a23dc5b..75c5869 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -44,6 +44,8 @@ private slots: void test_circbuf_wrapped_doubling(); void test_circbuf_full_wrapped_buffer_doubling(); + void test_validSubscribePath(); + void test_retained(); void test_retained_changed(); void test_retained_removed(); @@ -278,6 +280,26 @@ void MainTests::test_circbuf_full_wrapped_buffer_doubling() QVERIFY(true); } +void MainTests::test_validSubscribePath() +{ + QVERIFY(isValidSubscribePath("one/two/three")); + QVERIFY(isValidSubscribePath("one//three")); + QVERIFY(isValidSubscribePath("one/+/three")); + QVERIFY(isValidSubscribePath("one/+/#")); + QVERIFY(isValidSubscribePath("#")); + QVERIFY(isValidSubscribePath("///")); + QVERIFY(isValidSubscribePath("//#")); + QVERIFY(isValidSubscribePath("+")); + QVERIFY(isValidSubscribePath("")); + + QVERIFY(!isValidSubscribePath("one/tw+o/three")); + QVERIFY(!isValidSubscribePath("one/+o/three")); + QVERIFY(!isValidSubscribePath("one/a+/three")); + QVERIFY(!isValidSubscribePath("#//three")); + QVERIFY(!isValidSubscribePath("#//+")); + QVERIFY(!isValidSubscribePath("one/#/+")); +} + void MainTests::test_retained() { TwoClientTestContext testContext; diff --git a/mqttpacket.cpp b/mqttpacket.cpp index c09f8b4..a3df25c 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -358,6 +358,9 @@ void MqttPacket::handleSubscribe() if (topic.empty() || !isValidUtf8(topic)) throw ProtocolError("Subscribe topic not valid UTF-8."); + if (isValidSubscribePath(topic)) + throw ProtocolError(formatString("Invalid subscribe path: %s", topic.c_str())); + char qos = readByte(); if (qos > 2) diff --git a/utils.cpp b/utils.cpp index b6f819c..53cc431 100644 --- a/utils.cpp +++ b/utils.cpp @@ -137,6 +137,31 @@ bool isValidPublishPath(const std::string &s) return true; } +bool isValidSubscribePath(const std::string &s) +{ + bool plusAllowed = true; + bool nextMustBeSlash = false; + bool poundSeen = false; + + for (const char c : s) + { + if (!plusAllowed && c == '+') + return false; + + if (nextMustBeSlash && c != '/') + return false; + + if (poundSeen) + return false; + + plusAllowed = c == '/'; + nextMustBeSlash = c == '+'; + poundSeen = c == '#'; + } + + return true; +} + bool containsDangerousCharacters(const std::string &s) { if (s.empty()) diff --git a/utils.h b/utils.h index 27ca9cc..d2795a6 100644 --- a/utils.h +++ b/utils.h @@ -37,6 +37,7 @@ bool isValidUtf8(const std::string &s); bool strContains(const std::string &s, const std::string &needle); bool isValidPublishPath(const std::string &s); +bool isValidSubscribePath(const std::string &s); bool containsDangerousCharacters(const std::string &s); void ltrim(std::string &s);