Commit 7180065162b30b42353d9f19e1fcbc6c955d6a7d

Authored by Wiebe Cazemier
1 parent e00c635e

Check for valid subscribe path

FlashMQTests/tst_maintests.cpp
... ... @@ -44,6 +44,8 @@ private slots:
44 44 void test_circbuf_wrapped_doubling();
45 45 void test_circbuf_full_wrapped_buffer_doubling();
46 46  
  47 + void test_validSubscribePath();
  48 +
47 49 void test_retained();
48 50 void test_retained_changed();
49 51 void test_retained_removed();
... ... @@ -278,6 +280,26 @@ void MainTests::test_circbuf_full_wrapped_buffer_doubling()
278 280 QVERIFY(true);
279 281 }
280 282  
  283 +void MainTests::test_validSubscribePath()
  284 +{
  285 + QVERIFY(isValidSubscribePath("one/two/three"));
  286 + QVERIFY(isValidSubscribePath("one//three"));
  287 + QVERIFY(isValidSubscribePath("one/+/three"));
  288 + QVERIFY(isValidSubscribePath("one/+/#"));
  289 + QVERIFY(isValidSubscribePath("#"));
  290 + QVERIFY(isValidSubscribePath("///"));
  291 + QVERIFY(isValidSubscribePath("//#"));
  292 + QVERIFY(isValidSubscribePath("+"));
  293 + QVERIFY(isValidSubscribePath(""));
  294 +
  295 + QVERIFY(!isValidSubscribePath("one/tw+o/three"));
  296 + QVERIFY(!isValidSubscribePath("one/+o/three"));
  297 + QVERIFY(!isValidSubscribePath("one/a+/three"));
  298 + QVERIFY(!isValidSubscribePath("#//three"));
  299 + QVERIFY(!isValidSubscribePath("#//+"));
  300 + QVERIFY(!isValidSubscribePath("one/#/+"));
  301 +}
  302 +
281 303 void MainTests::test_retained()
282 304 {
283 305 TwoClientTestContext testContext;
... ...
mqttpacket.cpp
... ... @@ -358,6 +358,9 @@ void MqttPacket::handleSubscribe()
358 358 if (topic.empty() || !isValidUtf8(topic))
359 359 throw ProtocolError("Subscribe topic not valid UTF-8.");
360 360  
  361 + if (isValidSubscribePath(topic))
  362 + throw ProtocolError(formatString("Invalid subscribe path: %s", topic.c_str()));
  363 +
361 364 char qos = readByte();
362 365  
363 366 if (qos > 2)
... ...
utils.cpp
... ... @@ -137,6 +137,31 @@ bool isValidPublishPath(const std::string &s)
137 137 return true;
138 138 }
139 139  
  140 +bool isValidSubscribePath(const std::string &s)
  141 +{
  142 + bool plusAllowed = true;
  143 + bool nextMustBeSlash = false;
  144 + bool poundSeen = false;
  145 +
  146 + for (const char c : s)
  147 + {
  148 + if (!plusAllowed && c == '+')
  149 + return false;
  150 +
  151 + if (nextMustBeSlash && c != '/')
  152 + return false;
  153 +
  154 + if (poundSeen)
  155 + return false;
  156 +
  157 + plusAllowed = c == '/';
  158 + nextMustBeSlash = c == '+';
  159 + poundSeen = c == '#';
  160 + }
  161 +
  162 + return true;
  163 +}
  164 +
140 165 bool containsDangerousCharacters(const std::string &s)
141 166 {
142 167 if (s.empty())
... ...
... ... @@ -37,6 +37,7 @@ bool isValidUtf8(const std::string &s);
37 37 bool strContains(const std::string &s, const std::string &needle);
38 38  
39 39 bool isValidPublishPath(const std::string &s);
  40 +bool isValidSubscribePath(const std::string &s);
40 41 bool containsDangerousCharacters(const std::string &s);
41 42  
42 43 void ltrim(std::string &s);
... ...