diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 996b739..d2b058f 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -79,6 +79,7 @@ private slots: void test_sse_split(); void test_validUtf8(); + void test_validUtf8Sse(); }; @@ -624,6 +625,7 @@ void MainTests::test_validUtf8() char m[16]; QVERIFY(isValidUtf8("")); + QVERIFY(isValidUtf8("ƀ")); QVERIFY(isValidUtf8("Hello")); std::memset(m, 0, 16); @@ -682,6 +684,81 @@ void MainTests::test_validUtf8() QVERIFY(!isValidUtf8(e)); } +void MainTests::test_validUtf8Sse() +{ + std::shared_ptr store(new SubscriptionStore); + std::shared_ptr settings(new Settings); + ThreadData data(0, store, settings); + + char m[16]; + + QVERIFY(data.isValidUtf8("")); + QVERIFY(data.isValidUtf8("ƀ")); + QVERIFY(data.isValidUtf8("Hello")); + + std::memset(m, 0, 16); + QVERIFY(!data.isValidUtf8(std::string(m, 16))); + + QVERIFY(data.isValidUtf8("Straƀe")); // two byte chars + QVERIFY(data.isValidUtf8("StraƀeHelloHelloHelloHelloHelloHello")); // two byte chars + QVERIFY(data.isValidUtf8("HelloHelloHelloHelloHelloHelloHelloHelloStraƀeHelloHelloHelloHelloHelloHello")); // two byte chars + + QVERIFY(!data.isValidUtf8("Straƀe#", true)); + QVERIFY(!data.isValidUtf8("ƀ#", true)); + QVERIFY(!data.isValidUtf8("#ƀ", true)); + QVERIFY(!data.isValidUtf8("+", true)); + QVERIFY(!data.isValidUtf8("🩰+asdfasdfasdf", true)); + QVERIFY(!data.isValidUtf8("+asdfasdfasdf", true)); + + std::memset(m, 0, 16); + m[0] = 'a'; + m[1] = 13; // is \r + QVERIFY(!data.isValidUtf8(std::string(m, 16))); + + const std::string unicode_ballet_shoes("🩰"); + QVERIFY(unicode_ballet_shoes.length() == 4); + QVERIFY(data.isValidUtf8(unicode_ballet_shoes)); + + const std::string unicode_ballot_box("☐"); + QVERIFY(unicode_ballot_box.length() == 3); + QVERIFY(data.isValidUtf8(unicode_ballot_box)); + + std::memset(m, 0, 16); + m[0] = 0b11000001; // Start 2 byte char + m[1] = 0b00000001; // Next byte doesn't start with 1, which is wrong + std::string a(m, 2); + QVERIFY(!data.isValidUtf8(a)); + + std::memset(m, 0, 16); + m[0] = 0b11100001; // Start 3 byte char + m[1] = 0b10100001; + m[2] = 0b00000001; // Next byte doesn't start with 1, which is wrong + std::string b(m, 3); + QVERIFY(!data.isValidUtf8(b)); + + std::memset(m, 0, 16); + m[0] = 0b11110001; // Start 4 byte char + m[1] = 0b10100001; + m[2] = 0b10100001; + m[3] = 0b00000001; // Next byte doesn't start with 1, which is wrong + std::string c(m, 4); + QVERIFY(!data.isValidUtf8(c)); + + std::memset(m, 0, 16); + m[0] = 0b11110001; // Start 4 byte char + m[1] = 0b10100001; + m[2] = 0b00100001; // Doesn't start with 1: invalid. + m[3] = 0b10000001; + std::string d(m, 4); + QVERIFY(!data.isValidUtf8(d)); + + // Upper ASCII, invalid + std::memset(m, 0, 16); + m[0] = 127; + std::string e(m, 1); + QVERIFY(!data.isValidUtf8(e)); +} + QTEST_GUILESS_MAIN(MainTests) #include "tst_maintests.moc" diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 012c2e1..f6acc49 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -441,7 +441,7 @@ void MqttPacket::handlePublish() topic = std::string(readBytes(variable_header_length), variable_header_length); subtopics = sender->getThreadData()->splitTopic(topic); - if (!isValidUtf8(topic, true)) + if (!sender->getThreadData()->isValidUtf8(topic, true)) { logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or +/# in it. Dropping.", sender->repr().c_str()); return; diff --git a/threaddata.cpp b/threaddata.cpp index 9b56c7b..3bbd86a 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -200,6 +200,90 @@ std::vector *ThreadData::splitTopic(const std::string &topic) return &subtopics; } +bool ThreadData::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) +{ + const int len = s.size(); + + if (len + 16 > TOPIC_MEMORY_LENGTH) + return false; + + std::memcpy(topicCopy.data(), s.c_str(), len); + std::memset(&topicCopy.data()[len], 0x20, 16); // I fill out with spaces, as valid chars + + int n = 0; + const char *i = topicCopy.data(); + while (n < len) + { + const int len_left = len - n; + assert(len_left > 0); + __m128i loaded = _mm_loadu_si128((__m128i*)&i[n]); + __m128i loaded_AND_non_ascii = _mm_and_si128(loaded, non_ascii_mask); + + if (alsoCheckInvalidPublishChars && (_mm_movemask_epi8(_mm_cmpeq_epi8(loaded, pound) || _mm_movemask_epi8(_mm_cmpeq_epi8(loaded, plus))))) + return false; + + int index = _mm_cmpestri(non_ascii_mask, 1, loaded_AND_non_ascii, len_left, 0); + n += index; + + // Checking multi-byte chars one by one. With some effort, this may be done using SIMD too, but the majority of uses will + // have a minimum of multi byte chars. + if (index < 16) + { + char x = i[n++]; + char char_len = 0; + int cur_code_point = 0; + + if((x & 0b11100000) == 0b11000000) // 2 byte char + { + char_len = 1; + cur_code_point += ((x & 0b00011111) << 6); + } + else if((x & 0b11110000) == 0b11100000) // 3 byte char + { + char_len = 2; + cur_code_point += ((x & 0b00001111) << 12); + } + else if((x & 0b11111000) == 0b11110000) // 4 byte char + { + char_len = 3; + cur_code_point += ((x & 0b00000111) << 18); + } + else + return false; + + while (char_len > 0) + { + if (n >= len) + return false; + + x = i[n++]; + + if((x & 0b11000000) != 0b10000000) // All remainer bytes of this code point needs to start with 10 + return false; + char_len--; + cur_code_point += ((x & 0b00111111) << (6*char_len)); + } + + if (cur_code_point >= 0xD800 && cur_code_point <= 0xDFFF) // Dec 55296-57343 + return false; + + if (cur_code_point == 0xFFFF) + return false; + } + else + { + if (_mm_movemask_epi8(_mm_cmplt_epi8(loaded, lowerBound))) + return false; + + if (_mm_movemask_epi8(_mm_cmpgt_epi8(loaded, lastAsciiChar))) + return false; + } + + } + + return true; +} + // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? void ThreadData::doKeepAliveCheck() { diff --git a/threaddata.h b/threaddata.h index 2a79371..91ecb48 100644 --- a/threaddata.h +++ b/threaddata.h @@ -54,6 +54,11 @@ class ThreadData std::vector subtopicParseMem; std::vector topicCopy; __m128i slashes = _mm_set1_epi8('/'); + __m128i lowerBound = _mm_set1_epi8(0x20); + __m128i lastAsciiChar = _mm_set1_epi8(0x7E); + __m128i non_ascii_mask = _mm_set1_epi8(0b10000000); + __m128i pound = _mm_set1_epi8('#'); + __m128i plus = _mm_set1_epi8('+'); void reload(std::shared_ptr settings); void wakeUpThread(); @@ -91,6 +96,7 @@ public: void queuePasswdFileReload(); std::vector *splitTopic(const std::string &topic); + bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars = false); };