Commit 01496bc8a58539a90131de1e12084b7756cf02a5
1 parent
3dbab5a7
IsValidUtf8 check with SSE
Showing
4 changed files
with
168 additions
and
1 deletions
FlashMQTests/tst_maintests.cpp
| @@ -79,6 +79,7 @@ private slots: | @@ -79,6 +79,7 @@ private slots: | ||
| 79 | void test_sse_split(); | 79 | void test_sse_split(); |
| 80 | 80 | ||
| 81 | void test_validUtf8(); | 81 | void test_validUtf8(); |
| 82 | + void test_validUtf8Sse(); | ||
| 82 | 83 | ||
| 83 | }; | 84 | }; |
| 84 | 85 | ||
| @@ -624,6 +625,7 @@ void MainTests::test_validUtf8() | @@ -624,6 +625,7 @@ void MainTests::test_validUtf8() | ||
| 624 | char m[16]; | 625 | char m[16]; |
| 625 | 626 | ||
| 626 | QVERIFY(isValidUtf8("")); | 627 | QVERIFY(isValidUtf8("")); |
| 628 | + QVERIFY(isValidUtf8("ƀ")); | ||
| 627 | QVERIFY(isValidUtf8("Hello")); | 629 | QVERIFY(isValidUtf8("Hello")); |
| 628 | 630 | ||
| 629 | std::memset(m, 0, 16); | 631 | std::memset(m, 0, 16); |
| @@ -682,6 +684,81 @@ void MainTests::test_validUtf8() | @@ -682,6 +684,81 @@ void MainTests::test_validUtf8() | ||
| 682 | QVERIFY(!isValidUtf8(e)); | 684 | QVERIFY(!isValidUtf8(e)); |
| 683 | } | 685 | } |
| 684 | 686 | ||
| 687 | +void MainTests::test_validUtf8Sse() | ||
| 688 | +{ | ||
| 689 | + std::shared_ptr<SubscriptionStore> store(new SubscriptionStore); | ||
| 690 | + std::shared_ptr<Settings> settings(new Settings); | ||
| 691 | + ThreadData data(0, store, settings); | ||
| 692 | + | ||
| 693 | + char m[16]; | ||
| 694 | + | ||
| 695 | + QVERIFY(data.isValidUtf8("")); | ||
| 696 | + QVERIFY(data.isValidUtf8("ƀ")); | ||
| 697 | + QVERIFY(data.isValidUtf8("Hello")); | ||
| 698 | + | ||
| 699 | + std::memset(m, 0, 16); | ||
| 700 | + QVERIFY(!data.isValidUtf8(std::string(m, 16))); | ||
| 701 | + | ||
| 702 | + QVERIFY(data.isValidUtf8("Straƀe")); // two byte chars | ||
| 703 | + QVERIFY(data.isValidUtf8("StraƀeHelloHelloHelloHelloHelloHello")); // two byte chars | ||
| 704 | + QVERIFY(data.isValidUtf8("HelloHelloHelloHelloHelloHelloHelloHelloStraƀeHelloHelloHelloHelloHelloHello")); // two byte chars | ||
| 705 | + | ||
| 706 | + QVERIFY(!data.isValidUtf8("Straƀe#", true)); | ||
| 707 | + QVERIFY(!data.isValidUtf8("ƀ#", true)); | ||
| 708 | + QVERIFY(!data.isValidUtf8("#ƀ", true)); | ||
| 709 | + QVERIFY(!data.isValidUtf8("+", true)); | ||
| 710 | + QVERIFY(!data.isValidUtf8("🩰+asdfasdfasdf", true)); | ||
| 711 | + QVERIFY(!data.isValidUtf8("+asdfasdfasdf", true)); | ||
| 712 | + | ||
| 713 | + std::memset(m, 0, 16); | ||
| 714 | + m[0] = 'a'; | ||
| 715 | + m[1] = 13; // is \r | ||
| 716 | + QVERIFY(!data.isValidUtf8(std::string(m, 16))); | ||
| 717 | + | ||
| 718 | + const std::string unicode_ballet_shoes("🩰"); | ||
| 719 | + QVERIFY(unicode_ballet_shoes.length() == 4); | ||
| 720 | + QVERIFY(data.isValidUtf8(unicode_ballet_shoes)); | ||
| 721 | + | ||
| 722 | + const std::string unicode_ballot_box("☐"); | ||
| 723 | + QVERIFY(unicode_ballot_box.length() == 3); | ||
| 724 | + QVERIFY(data.isValidUtf8(unicode_ballot_box)); | ||
| 725 | + | ||
| 726 | + std::memset(m, 0, 16); | ||
| 727 | + m[0] = 0b11000001; // Start 2 byte char | ||
| 728 | + m[1] = 0b00000001; // Next byte doesn't start with 1, which is wrong | ||
| 729 | + std::string a(m, 2); | ||
| 730 | + QVERIFY(!data.isValidUtf8(a)); | ||
| 731 | + | ||
| 732 | + std::memset(m, 0, 16); | ||
| 733 | + m[0] = 0b11100001; // Start 3 byte char | ||
| 734 | + m[1] = 0b10100001; | ||
| 735 | + m[2] = 0b00000001; // Next byte doesn't start with 1, which is wrong | ||
| 736 | + std::string b(m, 3); | ||
| 737 | + QVERIFY(!data.isValidUtf8(b)); | ||
| 738 | + | ||
| 739 | + std::memset(m, 0, 16); | ||
| 740 | + m[0] = 0b11110001; // Start 4 byte char | ||
| 741 | + m[1] = 0b10100001; | ||
| 742 | + m[2] = 0b10100001; | ||
| 743 | + m[3] = 0b00000001; // Next byte doesn't start with 1, which is wrong | ||
| 744 | + std::string c(m, 4); | ||
| 745 | + QVERIFY(!data.isValidUtf8(c)); | ||
| 746 | + | ||
| 747 | + std::memset(m, 0, 16); | ||
| 748 | + m[0] = 0b11110001; // Start 4 byte char | ||
| 749 | + m[1] = 0b10100001; | ||
| 750 | + m[2] = 0b00100001; // Doesn't start with 1: invalid. | ||
| 751 | + m[3] = 0b10000001; | ||
| 752 | + std::string d(m, 4); | ||
| 753 | + QVERIFY(!data.isValidUtf8(d)); | ||
| 754 | + | ||
| 755 | + // Upper ASCII, invalid | ||
| 756 | + std::memset(m, 0, 16); | ||
| 757 | + m[0] = 127; | ||
| 758 | + std::string e(m, 1); | ||
| 759 | + QVERIFY(!data.isValidUtf8(e)); | ||
| 760 | +} | ||
| 761 | + | ||
| 685 | QTEST_GUILESS_MAIN(MainTests) | 762 | QTEST_GUILESS_MAIN(MainTests) |
| 686 | 763 | ||
| 687 | #include "tst_maintests.moc" | 764 | #include "tst_maintests.moc" |
mqttpacket.cpp
| @@ -441,7 +441,7 @@ void MqttPacket::handlePublish() | @@ -441,7 +441,7 @@ void MqttPacket::handlePublish() | ||
| 441 | topic = std::string(readBytes(variable_header_length), variable_header_length); | 441 | topic = std::string(readBytes(variable_header_length), variable_header_length); |
| 442 | subtopics = sender->getThreadData()->splitTopic(topic); | 442 | subtopics = sender->getThreadData()->splitTopic(topic); |
| 443 | 443 | ||
| 444 | - if (!isValidUtf8(topic, true)) | 444 | + if (!sender->getThreadData()->isValidUtf8(topic, true)) |
| 445 | { | 445 | { |
| 446 | logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or +/# in it. Dropping.", sender->repr().c_str()); | 446 | logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or +/# in it. Dropping.", sender->repr().c_str()); |
| 447 | return; | 447 | return; |
threaddata.cpp
| @@ -200,6 +200,90 @@ std::vector<std::string> *ThreadData::splitTopic(const std::string &topic) | @@ -200,6 +200,90 @@ std::vector<std::string> *ThreadData::splitTopic(const std::string &topic) | ||
| 200 | return &subtopics; | 200 | return &subtopics; |
| 201 | } | 201 | } |
| 202 | 202 | ||
| 203 | +bool ThreadData::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) | ||
| 204 | +{ | ||
| 205 | + const int len = s.size(); | ||
| 206 | + | ||
| 207 | + if (len + 16 > TOPIC_MEMORY_LENGTH) | ||
| 208 | + return false; | ||
| 209 | + | ||
| 210 | + std::memcpy(topicCopy.data(), s.c_str(), len); | ||
| 211 | + std::memset(&topicCopy.data()[len], 0x20, 16); // I fill out with spaces, as valid chars | ||
| 212 | + | ||
| 213 | + int n = 0; | ||
| 214 | + const char *i = topicCopy.data(); | ||
| 215 | + while (n < len) | ||
| 216 | + { | ||
| 217 | + const int len_left = len - n; | ||
| 218 | + assert(len_left > 0); | ||
| 219 | + __m128i loaded = _mm_loadu_si128((__m128i*)&i[n]); | ||
| 220 | + __m128i loaded_AND_non_ascii = _mm_and_si128(loaded, non_ascii_mask); | ||
| 221 | + | ||
| 222 | + if (alsoCheckInvalidPublishChars && (_mm_movemask_epi8(_mm_cmpeq_epi8(loaded, pound) || _mm_movemask_epi8(_mm_cmpeq_epi8(loaded, plus))))) | ||
| 223 | + return false; | ||
| 224 | + | ||
| 225 | + int index = _mm_cmpestri(non_ascii_mask, 1, loaded_AND_non_ascii, len_left, 0); | ||
| 226 | + n += index; | ||
| 227 | + | ||
| 228 | + // Checking multi-byte chars one by one. With some effort, this may be done using SIMD too, but the majority of uses will | ||
| 229 | + // have a minimum of multi byte chars. | ||
| 230 | + if (index < 16) | ||
| 231 | + { | ||
| 232 | + char x = i[n++]; | ||
| 233 | + char char_len = 0; | ||
| 234 | + int cur_code_point = 0; | ||
| 235 | + | ||
| 236 | + if((x & 0b11100000) == 0b11000000) // 2 byte char | ||
| 237 | + { | ||
| 238 | + char_len = 1; | ||
| 239 | + cur_code_point += ((x & 0b00011111) << 6); | ||
| 240 | + } | ||
| 241 | + else if((x & 0b11110000) == 0b11100000) // 3 byte char | ||
| 242 | + { | ||
| 243 | + char_len = 2; | ||
| 244 | + cur_code_point += ((x & 0b00001111) << 12); | ||
| 245 | + } | ||
| 246 | + else if((x & 0b11111000) == 0b11110000) // 4 byte char | ||
| 247 | + { | ||
| 248 | + char_len = 3; | ||
| 249 | + cur_code_point += ((x & 0b00000111) << 18); | ||
| 250 | + } | ||
| 251 | + else | ||
| 252 | + return false; | ||
| 253 | + | ||
| 254 | + while (char_len > 0) | ||
| 255 | + { | ||
| 256 | + if (n >= len) | ||
| 257 | + return false; | ||
| 258 | + | ||
| 259 | + x = i[n++]; | ||
| 260 | + | ||
| 261 | + if((x & 0b11000000) != 0b10000000) // All remainer bytes of this code point needs to start with 10 | ||
| 262 | + return false; | ||
| 263 | + char_len--; | ||
| 264 | + cur_code_point += ((x & 0b00111111) << (6*char_len)); | ||
| 265 | + } | ||
| 266 | + | ||
| 267 | + if (cur_code_point >= 0xD800 && cur_code_point <= 0xDFFF) // Dec 55296-57343 | ||
| 268 | + return false; | ||
| 269 | + | ||
| 270 | + if (cur_code_point == 0xFFFF) | ||
| 271 | + return false; | ||
| 272 | + } | ||
| 273 | + else | ||
| 274 | + { | ||
| 275 | + if (_mm_movemask_epi8(_mm_cmplt_epi8(loaded, lowerBound))) | ||
| 276 | + return false; | ||
| 277 | + | ||
| 278 | + if (_mm_movemask_epi8(_mm_cmpgt_epi8(loaded, lastAsciiChar))) | ||
| 279 | + return false; | ||
| 280 | + } | ||
| 281 | + | ||
| 282 | + } | ||
| 283 | + | ||
| 284 | + return true; | ||
| 285 | +} | ||
| 286 | + | ||
| 203 | // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? | 287 | // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? |
| 204 | void ThreadData::doKeepAliveCheck() | 288 | void ThreadData::doKeepAliveCheck() |
| 205 | { | 289 | { |
threaddata.h
| @@ -54,6 +54,11 @@ class ThreadData | @@ -54,6 +54,11 @@ class ThreadData | ||
| 54 | std::vector<char> subtopicParseMem; | 54 | std::vector<char> subtopicParseMem; |
| 55 | std::vector<char> topicCopy; | 55 | std::vector<char> topicCopy; |
| 56 | __m128i slashes = _mm_set1_epi8('/'); | 56 | __m128i slashes = _mm_set1_epi8('/'); |
| 57 | + __m128i lowerBound = _mm_set1_epi8(0x20); | ||
| 58 | + __m128i lastAsciiChar = _mm_set1_epi8(0x7E); | ||
| 59 | + __m128i non_ascii_mask = _mm_set1_epi8(0b10000000); | ||
| 60 | + __m128i pound = _mm_set1_epi8('#'); | ||
| 61 | + __m128i plus = _mm_set1_epi8('+'); | ||
| 57 | 62 | ||
| 58 | void reload(std::shared_ptr<Settings> settings); | 63 | void reload(std::shared_ptr<Settings> settings); |
| 59 | void wakeUpThread(); | 64 | void wakeUpThread(); |
| @@ -91,6 +96,7 @@ public: | @@ -91,6 +96,7 @@ public: | ||
| 91 | void queuePasswdFileReload(); | 96 | void queuePasswdFileReload(); |
| 92 | 97 | ||
| 93 | std::vector<std::string> *splitTopic(const std::string &topic); | 98 | std::vector<std::string> *splitTopic(const std::string &topic); |
| 99 | + bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars = false); | ||
| 94 | 100 | ||
| 95 | }; | 101 | }; |
| 96 | 102 |