Commit 01496bc8a58539a90131de1e12084b7756cf02a5
1 parent
3dbab5a7
IsValidUtf8 check with SSE
Showing
4 changed files
with
168 additions
and
1 deletions
FlashMQTests/tst_maintests.cpp
| ... | ... | @@ -79,6 +79,7 @@ private slots: |
| 79 | 79 | void test_sse_split(); |
| 80 | 80 | |
| 81 | 81 | void test_validUtf8(); |
| 82 | + void test_validUtf8Sse(); | |
| 82 | 83 | |
| 83 | 84 | }; |
| 84 | 85 | |
| ... | ... | @@ -624,6 +625,7 @@ void MainTests::test_validUtf8() |
| 624 | 625 | char m[16]; |
| 625 | 626 | |
| 626 | 627 | QVERIFY(isValidUtf8("")); |
| 628 | + QVERIFY(isValidUtf8("ƀ")); | |
| 627 | 629 | QVERIFY(isValidUtf8("Hello")); |
| 628 | 630 | |
| 629 | 631 | std::memset(m, 0, 16); |
| ... | ... | @@ -682,6 +684,81 @@ void MainTests::test_validUtf8() |
| 682 | 684 | QVERIFY(!isValidUtf8(e)); |
| 683 | 685 | } |
| 684 | 686 | |
| 687 | +void MainTests::test_validUtf8Sse() | |
| 688 | +{ | |
| 689 | + std::shared_ptr<SubscriptionStore> store(new SubscriptionStore); | |
| 690 | + std::shared_ptr<Settings> settings(new Settings); | |
| 691 | + ThreadData data(0, store, settings); | |
| 692 | + | |
| 693 | + char m[16]; | |
| 694 | + | |
| 695 | + QVERIFY(data.isValidUtf8("")); | |
| 696 | + QVERIFY(data.isValidUtf8("ƀ")); | |
| 697 | + QVERIFY(data.isValidUtf8("Hello")); | |
| 698 | + | |
| 699 | + std::memset(m, 0, 16); | |
| 700 | + QVERIFY(!data.isValidUtf8(std::string(m, 16))); | |
| 701 | + | |
| 702 | + QVERIFY(data.isValidUtf8("Straƀe")); // two byte chars | |
| 703 | + QVERIFY(data.isValidUtf8("StraƀeHelloHelloHelloHelloHelloHello")); // two byte chars | |
| 704 | + QVERIFY(data.isValidUtf8("HelloHelloHelloHelloHelloHelloHelloHelloStraƀeHelloHelloHelloHelloHelloHello")); // two byte chars | |
| 705 | + | |
| 706 | + QVERIFY(!data.isValidUtf8("Straƀe#", true)); | |
| 707 | + QVERIFY(!data.isValidUtf8("ƀ#", true)); | |
| 708 | + QVERIFY(!data.isValidUtf8("#ƀ", true)); | |
| 709 | + QVERIFY(!data.isValidUtf8("+", true)); | |
| 710 | + QVERIFY(!data.isValidUtf8("🩰+asdfasdfasdf", true)); | |
| 711 | + QVERIFY(!data.isValidUtf8("+asdfasdfasdf", true)); | |
| 712 | + | |
| 713 | + std::memset(m, 0, 16); | |
| 714 | + m[0] = 'a'; | |
| 715 | + m[1] = 13; // is \r | |
| 716 | + QVERIFY(!data.isValidUtf8(std::string(m, 16))); | |
| 717 | + | |
| 718 | + const std::string unicode_ballet_shoes("🩰"); | |
| 719 | + QVERIFY(unicode_ballet_shoes.length() == 4); | |
| 720 | + QVERIFY(data.isValidUtf8(unicode_ballet_shoes)); | |
| 721 | + | |
| 722 | + const std::string unicode_ballot_box("☐"); | |
| 723 | + QVERIFY(unicode_ballot_box.length() == 3); | |
| 724 | + QVERIFY(data.isValidUtf8(unicode_ballot_box)); | |
| 725 | + | |
| 726 | + std::memset(m, 0, 16); | |
| 727 | + m[0] = 0b11000001; // Start 2 byte char | |
| 728 | + m[1] = 0b00000001; // Next byte doesn't start with 1, which is wrong | |
| 729 | + std::string a(m, 2); | |
| 730 | + QVERIFY(!data.isValidUtf8(a)); | |
| 731 | + | |
| 732 | + std::memset(m, 0, 16); | |
| 733 | + m[0] = 0b11100001; // Start 3 byte char | |
| 734 | + m[1] = 0b10100001; | |
| 735 | + m[2] = 0b00000001; // Next byte doesn't start with 1, which is wrong | |
| 736 | + std::string b(m, 3); | |
| 737 | + QVERIFY(!data.isValidUtf8(b)); | |
| 738 | + | |
| 739 | + std::memset(m, 0, 16); | |
| 740 | + m[0] = 0b11110001; // Start 4 byte char | |
| 741 | + m[1] = 0b10100001; | |
| 742 | + m[2] = 0b10100001; | |
| 743 | + m[3] = 0b00000001; // Next byte doesn't start with 1, which is wrong | |
| 744 | + std::string c(m, 4); | |
| 745 | + QVERIFY(!data.isValidUtf8(c)); | |
| 746 | + | |
| 747 | + std::memset(m, 0, 16); | |
| 748 | + m[0] = 0b11110001; // Start 4 byte char | |
| 749 | + m[1] = 0b10100001; | |
| 750 | + m[2] = 0b00100001; // Doesn't start with 1: invalid. | |
| 751 | + m[3] = 0b10000001; | |
| 752 | + std::string d(m, 4); | |
| 753 | + QVERIFY(!data.isValidUtf8(d)); | |
| 754 | + | |
| 755 | + // Upper ASCII, invalid | |
| 756 | + std::memset(m, 0, 16); | |
| 757 | + m[0] = 127; | |
| 758 | + std::string e(m, 1); | |
| 759 | + QVERIFY(!data.isValidUtf8(e)); | |
| 760 | +} | |
| 761 | + | |
| 685 | 762 | QTEST_GUILESS_MAIN(MainTests) |
| 686 | 763 | |
| 687 | 764 | #include "tst_maintests.moc" | ... | ... |
mqttpacket.cpp
| ... | ... | @@ -441,7 +441,7 @@ void MqttPacket::handlePublish() |
| 441 | 441 | topic = std::string(readBytes(variable_header_length), variable_header_length); |
| 442 | 442 | subtopics = sender->getThreadData()->splitTopic(topic); |
| 443 | 443 | |
| 444 | - if (!isValidUtf8(topic, true)) | |
| 444 | + if (!sender->getThreadData()->isValidUtf8(topic, true)) | |
| 445 | 445 | { |
| 446 | 446 | logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or +/# in it. Dropping.", sender->repr().c_str()); |
| 447 | 447 | return; | ... | ... |
threaddata.cpp
| ... | ... | @@ -200,6 +200,90 @@ std::vector<std::string> *ThreadData::splitTopic(const std::string &topic) |
| 200 | 200 | return &subtopics; |
| 201 | 201 | } |
| 202 | 202 | |
| 203 | +bool ThreadData::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) | |
| 204 | +{ | |
| 205 | + const int len = s.size(); | |
| 206 | + | |
| 207 | + if (len + 16 > TOPIC_MEMORY_LENGTH) | |
| 208 | + return false; | |
| 209 | + | |
| 210 | + std::memcpy(topicCopy.data(), s.c_str(), len); | |
| 211 | + std::memset(&topicCopy.data()[len], 0x20, 16); // I fill out with spaces, as valid chars | |
| 212 | + | |
| 213 | + int n = 0; | |
| 214 | + const char *i = topicCopy.data(); | |
| 215 | + while (n < len) | |
| 216 | + { | |
| 217 | + const int len_left = len - n; | |
| 218 | + assert(len_left > 0); | |
| 219 | + __m128i loaded = _mm_loadu_si128((__m128i*)&i[n]); | |
| 220 | + __m128i loaded_AND_non_ascii = _mm_and_si128(loaded, non_ascii_mask); | |
| 221 | + | |
| 222 | + if (alsoCheckInvalidPublishChars && (_mm_movemask_epi8(_mm_cmpeq_epi8(loaded, pound) || _mm_movemask_epi8(_mm_cmpeq_epi8(loaded, plus))))) | |
| 223 | + return false; | |
| 224 | + | |
| 225 | + int index = _mm_cmpestri(non_ascii_mask, 1, loaded_AND_non_ascii, len_left, 0); | |
| 226 | + n += index; | |
| 227 | + | |
| 228 | + // Checking multi-byte chars one by one. With some effort, this may be done using SIMD too, but the majority of uses will | |
| 229 | + // have a minimum of multi byte chars. | |
| 230 | + if (index < 16) | |
| 231 | + { | |
| 232 | + char x = i[n++]; | |
| 233 | + char char_len = 0; | |
| 234 | + int cur_code_point = 0; | |
| 235 | + | |
| 236 | + if((x & 0b11100000) == 0b11000000) // 2 byte char | |
| 237 | + { | |
| 238 | + char_len = 1; | |
| 239 | + cur_code_point += ((x & 0b00011111) << 6); | |
| 240 | + } | |
| 241 | + else if((x & 0b11110000) == 0b11100000) // 3 byte char | |
| 242 | + { | |
| 243 | + char_len = 2; | |
| 244 | + cur_code_point += ((x & 0b00001111) << 12); | |
| 245 | + } | |
| 246 | + else if((x & 0b11111000) == 0b11110000) // 4 byte char | |
| 247 | + { | |
| 248 | + char_len = 3; | |
| 249 | + cur_code_point += ((x & 0b00000111) << 18); | |
| 250 | + } | |
| 251 | + else | |
| 252 | + return false; | |
| 253 | + | |
| 254 | + while (char_len > 0) | |
| 255 | + { | |
| 256 | + if (n >= len) | |
| 257 | + return false; | |
| 258 | + | |
| 259 | + x = i[n++]; | |
| 260 | + | |
| 261 | + if((x & 0b11000000) != 0b10000000) // All remainer bytes of this code point needs to start with 10 | |
| 262 | + return false; | |
| 263 | + char_len--; | |
| 264 | + cur_code_point += ((x & 0b00111111) << (6*char_len)); | |
| 265 | + } | |
| 266 | + | |
| 267 | + if (cur_code_point >= 0xD800 && cur_code_point <= 0xDFFF) // Dec 55296-57343 | |
| 268 | + return false; | |
| 269 | + | |
| 270 | + if (cur_code_point == 0xFFFF) | |
| 271 | + return false; | |
| 272 | + } | |
| 273 | + else | |
| 274 | + { | |
| 275 | + if (_mm_movemask_epi8(_mm_cmplt_epi8(loaded, lowerBound))) | |
| 276 | + return false; | |
| 277 | + | |
| 278 | + if (_mm_movemask_epi8(_mm_cmpgt_epi8(loaded, lastAsciiChar))) | |
| 279 | + return false; | |
| 280 | + } | |
| 281 | + | |
| 282 | + } | |
| 283 | + | |
| 284 | + return true; | |
| 285 | +} | |
| 286 | + | |
| 203 | 287 | // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? |
| 204 | 288 | void ThreadData::doKeepAliveCheck() |
| 205 | 289 | { | ... | ... |
threaddata.h
| ... | ... | @@ -54,6 +54,11 @@ class ThreadData |
| 54 | 54 | std::vector<char> subtopicParseMem; |
| 55 | 55 | std::vector<char> topicCopy; |
| 56 | 56 | __m128i slashes = _mm_set1_epi8('/'); |
| 57 | + __m128i lowerBound = _mm_set1_epi8(0x20); | |
| 58 | + __m128i lastAsciiChar = _mm_set1_epi8(0x7E); | |
| 59 | + __m128i non_ascii_mask = _mm_set1_epi8(0b10000000); | |
| 60 | + __m128i pound = _mm_set1_epi8('#'); | |
| 61 | + __m128i plus = _mm_set1_epi8('+'); | |
| 57 | 62 | |
| 58 | 63 | void reload(std::shared_ptr<Settings> settings); |
| 59 | 64 | void wakeUpThread(); |
| ... | ... | @@ -91,6 +96,7 @@ public: |
| 91 | 96 | void queuePasswdFileReload(); |
| 92 | 97 | |
| 93 | 98 | std::vector<std::string> *splitTopic(const std::string &topic); |
| 99 | + bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars = false); | |
| 94 | 100 | |
| 95 | 101 | }; |
| 96 | 102 | ... | ... |