Commit 01496bc8a58539a90131de1e12084b7756cf02a5

Authored by Wiebe Cazemier
1 parent 3dbab5a7

IsValidUtf8 check with SSE

FlashMQTests/tst_maintests.cpp
@@ -79,6 +79,7 @@ private slots: @@ -79,6 +79,7 @@ private slots:
79 void test_sse_split(); 79 void test_sse_split();
80 80
81 void test_validUtf8(); 81 void test_validUtf8();
  82 + void test_validUtf8Sse();
82 83
83 }; 84 };
84 85
@@ -624,6 +625,7 @@ void MainTests::test_validUtf8() @@ -624,6 +625,7 @@ void MainTests::test_validUtf8()
624 char m[16]; 625 char m[16];
625 626
626 QVERIFY(isValidUtf8("")); 627 QVERIFY(isValidUtf8(""));
  628 + QVERIFY(isValidUtf8("ƀ"));
627 QVERIFY(isValidUtf8("Hello")); 629 QVERIFY(isValidUtf8("Hello"));
628 630
629 std::memset(m, 0, 16); 631 std::memset(m, 0, 16);
@@ -682,6 +684,81 @@ void MainTests::test_validUtf8() @@ -682,6 +684,81 @@ void MainTests::test_validUtf8()
682 QVERIFY(!isValidUtf8(e)); 684 QVERIFY(!isValidUtf8(e));
683 } 685 }
684 686
  687 +void MainTests::test_validUtf8Sse()
  688 +{
  689 + std::shared_ptr<SubscriptionStore> store(new SubscriptionStore);
  690 + std::shared_ptr<Settings> settings(new Settings);
  691 + ThreadData data(0, store, settings);
  692 +
  693 + char m[16];
  694 +
  695 + QVERIFY(data.isValidUtf8(""));
  696 + QVERIFY(data.isValidUtf8("ƀ"));
  697 + QVERIFY(data.isValidUtf8("Hello"));
  698 +
  699 + std::memset(m, 0, 16);
  700 + QVERIFY(!data.isValidUtf8(std::string(m, 16)));
  701 +
  702 + QVERIFY(data.isValidUtf8("Straƀe")); // two byte chars
  703 + QVERIFY(data.isValidUtf8("StraƀeHelloHelloHelloHelloHelloHello")); // two byte chars
  704 + QVERIFY(data.isValidUtf8("HelloHelloHelloHelloHelloHelloHelloHelloStraƀeHelloHelloHelloHelloHelloHello")); // two byte chars
  705 +
  706 + QVERIFY(!data.isValidUtf8("Straƀe#", true));
  707 + QVERIFY(!data.isValidUtf8("ƀ#", true));
  708 + QVERIFY(!data.isValidUtf8("#ƀ", true));
  709 + QVERIFY(!data.isValidUtf8("+", true));
  710 + QVERIFY(!data.isValidUtf8("🩰+asdfasdfasdf", true));
  711 + QVERIFY(!data.isValidUtf8("+asdfasdfasdf", true));
  712 +
  713 + std::memset(m, 0, 16);
  714 + m[0] = 'a';
  715 + m[1] = 13; // is \r
  716 + QVERIFY(!data.isValidUtf8(std::string(m, 16)));
  717 +
  718 + const std::string unicode_ballet_shoes("🩰");
  719 + QVERIFY(unicode_ballet_shoes.length() == 4);
  720 + QVERIFY(data.isValidUtf8(unicode_ballet_shoes));
  721 +
  722 + const std::string unicode_ballot_box("☐");
  723 + QVERIFY(unicode_ballot_box.length() == 3);
  724 + QVERIFY(data.isValidUtf8(unicode_ballot_box));
  725 +
  726 + std::memset(m, 0, 16);
  727 + m[0] = 0b11000001; // Start 2 byte char
  728 + m[1] = 0b00000001; // Next byte doesn't start with 1, which is wrong
  729 + std::string a(m, 2);
  730 + QVERIFY(!data.isValidUtf8(a));
  731 +
  732 + std::memset(m, 0, 16);
  733 + m[0] = 0b11100001; // Start 3 byte char
  734 + m[1] = 0b10100001;
  735 + m[2] = 0b00000001; // Next byte doesn't start with 1, which is wrong
  736 + std::string b(m, 3);
  737 + QVERIFY(!data.isValidUtf8(b));
  738 +
  739 + std::memset(m, 0, 16);
  740 + m[0] = 0b11110001; // Start 4 byte char
  741 + m[1] = 0b10100001;
  742 + m[2] = 0b10100001;
  743 + m[3] = 0b00000001; // Next byte doesn't start with 1, which is wrong
  744 + std::string c(m, 4);
  745 + QVERIFY(!data.isValidUtf8(c));
  746 +
  747 + std::memset(m, 0, 16);
  748 + m[0] = 0b11110001; // Start 4 byte char
  749 + m[1] = 0b10100001;
  750 + m[2] = 0b00100001; // Doesn't start with 1: invalid.
  751 + m[3] = 0b10000001;
  752 + std::string d(m, 4);
  753 + QVERIFY(!data.isValidUtf8(d));
  754 +
  755 + // Upper ASCII, invalid
  756 + std::memset(m, 0, 16);
  757 + m[0] = 127;
  758 + std::string e(m, 1);
  759 + QVERIFY(!data.isValidUtf8(e));
  760 +}
  761 +
685 QTEST_GUILESS_MAIN(MainTests) 762 QTEST_GUILESS_MAIN(MainTests)
686 763
687 #include "tst_maintests.moc" 764 #include "tst_maintests.moc"
mqttpacket.cpp
@@ -441,7 +441,7 @@ void MqttPacket::handlePublish() @@ -441,7 +441,7 @@ void MqttPacket::handlePublish()
441 topic = std::string(readBytes(variable_header_length), variable_header_length); 441 topic = std::string(readBytes(variable_header_length), variable_header_length);
442 subtopics = sender->getThreadData()->splitTopic(topic); 442 subtopics = sender->getThreadData()->splitTopic(topic);
443 443
444 - if (!isValidUtf8(topic, true)) 444 + if (!sender->getThreadData()->isValidUtf8(topic, true))
445 { 445 {
446 logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or +/# in it. Dropping.", sender->repr().c_str()); 446 logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or +/# in it. Dropping.", sender->repr().c_str());
447 return; 447 return;
threaddata.cpp
@@ -200,6 +200,90 @@ std::vector&lt;std::string&gt; *ThreadData::splitTopic(const std::string &amp;topic) @@ -200,6 +200,90 @@ std::vector&lt;std::string&gt; *ThreadData::splitTopic(const std::string &amp;topic)
200 return &subtopics; 200 return &subtopics;
201 } 201 }
202 202
  203 +bool ThreadData::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars)
  204 +{
  205 + const int len = s.size();
  206 +
  207 + if (len + 16 > TOPIC_MEMORY_LENGTH)
  208 + return false;
  209 +
  210 + std::memcpy(topicCopy.data(), s.c_str(), len);
  211 + std::memset(&topicCopy.data()[len], 0x20, 16); // I fill out with spaces, as valid chars
  212 +
  213 + int n = 0;
  214 + const char *i = topicCopy.data();
  215 + while (n < len)
  216 + {
  217 + const int len_left = len - n;
  218 + assert(len_left > 0);
  219 + __m128i loaded = _mm_loadu_si128((__m128i*)&i[n]);
  220 + __m128i loaded_AND_non_ascii = _mm_and_si128(loaded, non_ascii_mask);
  221 +
  222 + if (alsoCheckInvalidPublishChars && (_mm_movemask_epi8(_mm_cmpeq_epi8(loaded, pound) || _mm_movemask_epi8(_mm_cmpeq_epi8(loaded, plus)))))
  223 + return false;
  224 +
  225 + int index = _mm_cmpestri(non_ascii_mask, 1, loaded_AND_non_ascii, len_left, 0);
  226 + n += index;
  227 +
  228 + // Checking multi-byte chars one by one. With some effort, this may be done using SIMD too, but the majority of uses will
  229 + // have a minimum of multi byte chars.
  230 + if (index < 16)
  231 + {
  232 + char x = i[n++];
  233 + char char_len = 0;
  234 + int cur_code_point = 0;
  235 +
  236 + if((x & 0b11100000) == 0b11000000) // 2 byte char
  237 + {
  238 + char_len = 1;
  239 + cur_code_point += ((x & 0b00011111) << 6);
  240 + }
  241 + else if((x & 0b11110000) == 0b11100000) // 3 byte char
  242 + {
  243 + char_len = 2;
  244 + cur_code_point += ((x & 0b00001111) << 12);
  245 + }
  246 + else if((x & 0b11111000) == 0b11110000) // 4 byte char
  247 + {
  248 + char_len = 3;
  249 + cur_code_point += ((x & 0b00000111) << 18);
  250 + }
  251 + else
  252 + return false;
  253 +
  254 + while (char_len > 0)
  255 + {
  256 + if (n >= len)
  257 + return false;
  258 +
  259 + x = i[n++];
  260 +
  261 + if((x & 0b11000000) != 0b10000000) // All remainer bytes of this code point needs to start with 10
  262 + return false;
  263 + char_len--;
  264 + cur_code_point += ((x & 0b00111111) << (6*char_len));
  265 + }
  266 +
  267 + if (cur_code_point >= 0xD800 && cur_code_point <= 0xDFFF) // Dec 55296-57343
  268 + return false;
  269 +
  270 + if (cur_code_point == 0xFFFF)
  271 + return false;
  272 + }
  273 + else
  274 + {
  275 + if (_mm_movemask_epi8(_mm_cmplt_epi8(loaded, lowerBound)))
  276 + return false;
  277 +
  278 + if (_mm_movemask_epi8(_mm_cmpgt_epi8(loaded, lastAsciiChar)))
  279 + return false;
  280 + }
  281 +
  282 + }
  283 +
  284 + return true;
  285 +}
  286 +
203 // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? 287 // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial?
204 void ThreadData::doKeepAliveCheck() 288 void ThreadData::doKeepAliveCheck()
205 { 289 {
threaddata.h
@@ -54,6 +54,11 @@ class ThreadData @@ -54,6 +54,11 @@ class ThreadData
54 std::vector<char> subtopicParseMem; 54 std::vector<char> subtopicParseMem;
55 std::vector<char> topicCopy; 55 std::vector<char> topicCopy;
56 __m128i slashes = _mm_set1_epi8('/'); 56 __m128i slashes = _mm_set1_epi8('/');
  57 + __m128i lowerBound = _mm_set1_epi8(0x20);
  58 + __m128i lastAsciiChar = _mm_set1_epi8(0x7E);
  59 + __m128i non_ascii_mask = _mm_set1_epi8(0b10000000);
  60 + __m128i pound = _mm_set1_epi8('#');
  61 + __m128i plus = _mm_set1_epi8('+');
57 62
58 void reload(std::shared_ptr<Settings> settings); 63 void reload(std::shared_ptr<Settings> settings);
59 void wakeUpThread(); 64 void wakeUpThread();
@@ -91,6 +96,7 @@ public: @@ -91,6 +96,7 @@ public:
91 void queuePasswdFileReload(); 96 void queuePasswdFileReload();
92 97
93 std::vector<std::string> *splitTopic(const std::string &topic); 98 std::vector<std::string> *splitTopic(const std::string &topic);
  99 + bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars = false);
94 100
95 }; 101 };
96 102