Commit 3dbab5a7f7ff34e8cbe5e30acf177088a6af69eb

Authored by Wiebe Cazemier
1 parent 5c74f973

Optimize UTF-8 check and add tests

FlashMQTests/tst_maintests.cpp
@@ -78,6 +78,8 @@ private slots: @@ -78,6 +78,8 @@ private slots:
78 78
79 void test_sse_split(); 79 void test_sse_split();
80 80
  81 + void test_validUtf8();
  82 +
81 }; 83 };
82 84
83 MainTests::MainTests() 85 MainTests::MainTests()
@@ -608,6 +610,8 @@ void MainTests::test_sse_split() @@ -608,6 +610,8 @@ void MainTests::test_sse_split()
608 topics.push_back("//1234567890abcdef/1234567890abcdefg/koe/"); 610 topics.push_back("//1234567890abcdef/1234567890abcdefg/koe/");
609 topics.push_back("//1234567890abcdef/1234567890abcdefg/koe//"); 611 topics.push_back("//1234567890abcdef/1234567890abcdefg/koe//");
610 topics.push_back("//1234567890abcdef/1234567890abcdef/"); 612 topics.push_back("//1234567890abcdef/1234567890abcdef/");
  613 + topics.push_back("/");
  614 + topics.push_back("");
611 615
612 for (const std::string &t : topics) 616 for (const std::string &t : topics)
613 { 617 {
@@ -615,6 +619,69 @@ void MainTests::test_sse_split() @@ -615,6 +619,69 @@ void MainTests::test_sse_split()
615 } 619 }
616 } 620 }
617 621
  622 +void MainTests::test_validUtf8()
  623 +{ // Exercises isValidUtf8() with well-formed strings and with malformed or disallowed byte sequences.
  624 + char m[16];
  625 +
  626 + QVERIFY(isValidUtf8(""));
  627 + QVERIFY(isValidUtf8("Hello"));
  628 +
  629 + std::memset(m, 0, 16);
  630 + QVERIFY(!isValidUtf8(std::string(m, 16))); // 16 NUL bytes must be rejected
  631 +
  632 + QVERIFY(isValidUtf8("Straƀe")); // two byte chars
  633 + QVERIFY(isValidUtf8("StraƀeHelloHelloHelloHelloHelloHello")); // two byte chars
  634 + QVERIFY(isValidUtf8("HelloHelloHelloHelloHelloHelloHelloHelloStraƀeHelloHelloHelloHelloHelloHello")); // two byte chars
  635 +
  636 + std::memset(m, 0, 16);
  637 + m[0] = 'a';
  638 + m[1] = 13; // '\r' (carriage return), a control character; the remaining bytes stay NUL
  639 + QVERIFY(!isValidUtf8(std::string(m, 16)));
  640 +
  641 + const std::string unicode_ballet_shoes("🩰");
  642 + QVERIFY(unicode_ballet_shoes.length() == 4); // encoded as four bytes
  643 + QVERIFY(isValidUtf8(unicode_ballet_shoes));
  644 +
  645 + const std::string unicode_ballot_box("☐");
  646 + QVERIFY(unicode_ballot_box.length() == 3); // encoded as three bytes
  647 + QVERIFY(isValidUtf8(unicode_ballot_box));
  648 +
  649 + std::memset(m, 0, 16);
  650 + m[0] = 0b11000001; // Start 2 byte char
  651 + m[1] = 0b00000001; // Continuation byte without the high bit set, which is invalid
  652 + std::string a(m, 2);
  653 + QVERIFY(!isValidUtf8(a));
  654 +
  655 + std::memset(m, 0, 16);
  656 + m[0] = 0b11100001; // Start 3 byte char
  657 + m[1] = 0b10100001;
  658 + m[2] = 0b00000001; // Continuation byte without the high bit set, which is invalid
  659 + std::string b(m, 3);
  660 + QVERIFY(!isValidUtf8(b));
  661 +
  662 + std::memset(m, 0, 16);
  663 + m[0] = 0b11110001; // Start 4 byte char
  664 + m[1] = 0b10100001;
  665 + m[2] = 0b10100001;
  666 + m[3] = 0b00000001; // Continuation byte without the high bit set, which is invalid
  667 + std::string c(m, 4);
  668 + QVERIFY(!isValidUtf8(c));
  669 +
  670 + std::memset(m, 0, 16);
  671 + m[0] = 0b11110001; // Start 4 byte char
  672 + m[1] = 0b10100001;
  673 + m[2] = 0b00100001; // High bit not set in a middle continuation byte: invalid.
  674 + m[3] = 0b10000001;
  675 + std::string d(m, 4);
  676 + QVERIFY(!isValidUtf8(d));
  677 +
  678 + // 127 (0x7F, DEL) is a control character: invalid
  679 + std::memset(m, 0, 16);
  680 + m[0] = 127;
  681 + std::string e(m, 1);
  682 + QVERIFY(!isValidUtf8(e));
  683 +}
  684 +
618 QTEST_GUILESS_MAIN(MainTests) 685 QTEST_GUILESS_MAIN(MainTests)
619 686
620 #include "tst_maintests.moc" 687 #include "tst_maintests.moc"
threaddata.cpp
@@ -18,6 +18,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. @@ -18,6 +18,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
18 #include "threaddata.h" 18 #include "threaddata.h"
19 #include <string> 19 #include <string>
20 #include <sstream> 20 #include <sstream>
  21 +#include <cassert>
21 22
22 #define TOPIC_MEMORY_LENGTH 65560 23 #define TOPIC_MEMORY_LENGTH 65560
23 24
@@ -181,6 +182,7 @@ std::vector<std::string> *ThreadData::splitTopic(const std::string &topic) @@ -181,6 +182,7 @@ std::vector<std::string> *ThreadData::splitTopic(const std::string &topic)
181 __m128i loaded = _mm_loadu_si128((__m128i*)i); 182 __m128i loaded = _mm_loadu_si128((__m128i*)i);
182 183
183 int len_left = s - n; 184 int len_left = s - n;
  185 + assert(len_left >= 0);
184 int index = _mm_cmpestri(slashes, 1, loaded, len_left, 0); 186 int index = _mm_cmpestri(slashes, 1, loaded, len_left, 0);
185 std::memcpy(&subtopicParseMem[carryi], i, index); 187 std::memcpy(&subtopicParseMem[carryi], i, index);
186 carryi += std::min<int>(index, len_left); 188 carryi += std::min<int>(index, len_left);
utils.cpp
@@ -83,11 +83,8 @@ bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) @@ -83,11 +83,8 @@ bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars)
83 { 83 {
84 int multibyte_remain = 0; 84 int multibyte_remain = 0;
85 int cur_code_point = 0; 85 int cur_code_point = 0;
86 - for(const char &x : s) 86 + for(const char x : s)
87 { 87 {
88 - if (x == 0)  
89 - return false;  
90 -  
91 if (alsoCheckInvalidPublishChars && (x == '#' || x == '+')) 88 if (alsoCheckInvalidPublishChars && (x == '#' || x == '+'))
92 return false; 89 return false;
93 90
@@ -95,7 +92,11 @@ bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) @@ -95,7 +92,11 @@ bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars)
95 { 92 {
96 cur_code_point = 0; 93 cur_code_point = 0;
97 94
98 - if((x & 0b11100000) == 0b11000000) // 2 byte char 95 + if ((x & 0b10000000) == 0) // when the MSB is 0, it's ASCII, most common case
  96 + {
  97 + cur_code_point += (x & 0b01111111);
  98 + }
  99 + else if((x & 0b11100000) == 0b11000000) // 2 byte char
99 { 100 {
100 multibyte_remain = 1; 101 multibyte_remain = 1;
101 cur_code_point += ((x & 0b00011111) << 6); 102 cur_code_point += ((x & 0b00011111) << 6);
@@ -128,7 +129,7 @@ bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) @@ -128,7 +129,7 @@ bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars)
128 // Invalid range for MQTT. [MQTT-1.5.3-1] 129 // Invalid range for MQTT. [MQTT-1.5.3-1]
129 if (cur_code_point >= 0xD800 && cur_code_point <= 0xDFFF) // Dec 55296-57343 130 if (cur_code_point >= 0xD800 && cur_code_point <= 0xDFFF) // Dec 55296-57343
130 return false; 131 return false;
131 - if (cur_code_point >= 0x0001 && cur_code_point <= 0x001F) 132 + if (cur_code_point <= 0x001F)
132 return false; 133 return false;
133 if (cur_code_point >= 0x007F && cur_code_point <= 0x009F) 134 if (cur_code_point >= 0x007F && cur_code_point <= 0x009F)
134 return false; 135 return false;