From d3072e9cbc756dc86f35c067a3e8075c0538ef5f Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Tue, 1 Jun 2021 23:47:25 +0200 Subject: [PATCH] Refactor SIMD/SSE --- FlashMQTests/tst_maintests.cpp | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++-------------------- main.cpp | 5 +---- mainapp.cpp | 5 +++-- mainapp.h | 2 -- mqttpacket.cpp | 26 +++++++++++++++++--------- mqttpacket.h | 2 +- threadlocalutils.cpp | 25 ++++++++++++++----------- threadlocalutils.h | 16 +++++++++------- utils.cpp | 55 ++++++++++++++++++++++++++++++++++++++++++++++--------- utils.h | 3 +++ 10 files changed, 145 insertions(+), 65 deletions(-) diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 390fb78..358218d 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -78,11 +78,13 @@ private slots: void test_sse_split(); - void test_validUtf8(); + void test_validUtf8Generic(); void test_validUtf8Sse(); void testPacketInt16Parse(); + void testTopicsMatch(); + }; MainTests::MainTests() @@ -600,7 +602,8 @@ void MainTests::test_acl_patterns_clientid() void MainTests::test_sse_split() { - Utils data; + SimdUtils data; + std::vector output; std::list topics; topics.push_back("one/two/threeabcasdfasdf/koe"); @@ -619,50 +622,51 @@ void MainTests::test_sse_split() for (const std::string &t : topics) { - QCOMPARE(*data.splitTopic(t), splitToVector(t, '/')); + data.splitTopic(t, output); + QCOMPARE(output, splitToVector(t, '/')); } } -void MainTests::test_validUtf8() +void MainTests::test_validUtf8Generic() { char m[16]; - QVERIFY(isValidUtf8("")); - QVERIFY(isValidUtf8("ƀ")); - QVERIFY(isValidUtf8("Hello")); + QVERIFY(isValidUtf8Generic("")); + QVERIFY(isValidUtf8Generic("ƀ")); + QVERIFY(isValidUtf8Generic("Hello")); std::memset(m, 0, 16); - QVERIFY(!isValidUtf8(std::string(m, 16))); + QVERIFY(!isValidUtf8Generic(std::string(m, 16))); - QVERIFY(isValidUtf8("Straƀe")); // two byte chars - QVERIFY(isValidUtf8("StraƀeHelloHelloHelloHelloHelloHello")); // two byte chars - QVERIFY(isValidUtf8("HelloHelloHelloHelloHelloHelloHelloHelloStraƀeHelloHelloHelloHelloHelloHello")); // two byte chars + QVERIFY(isValidUtf8Generic("Straƀe")); // two byte chars + QVERIFY(isValidUtf8Generic("StraƀeHelloHelloHelloHelloHelloHello")); // two byte chars + QVERIFY(isValidUtf8Generic("HelloHelloHelloHelloHelloHelloHelloHelloStraƀeHelloHelloHelloHelloHelloHello")); // two byte chars std::memset(m, 0, 16); m[0] = 'a'; m[1] = 13; // is \r - QVERIFY(!isValidUtf8(std::string(m, 16))); + QVERIFY(!isValidUtf8Generic(std::string(m, 16))); const std::string unicode_ballet_shoes("🩰"); QVERIFY(unicode_ballet_shoes.length() == 4); - QVERIFY(isValidUtf8(unicode_ballet_shoes)); + QVERIFY(isValidUtf8Generic(unicode_ballet_shoes)); const std::string unicode_ballot_box("☐"); QVERIFY(unicode_ballot_box.length() == 3); - QVERIFY(isValidUtf8(unicode_ballot_box)); + QVERIFY(isValidUtf8Generic(unicode_ballot_box)); std::memset(m, 0, 16); m[0] = 0b11000001; // Start 2 byte char m[1] = 0b00000001; // Next byte doesn't start with 1, which is wrong std::string a(m, 2); - QVERIFY(!isValidUtf8(a)); + QVERIFY(!isValidUtf8Generic(a)); std::memset(m, 0, 16); m[0] = 0b11100001; // Start 3 byte char m[1] = 0b10100001; m[2] = 0b00000001; // Next byte doesn't start with 1, which is wrong std::string b(m, 3); - QVERIFY(!isValidUtf8(b)); + QVERIFY(!isValidUtf8Generic(b)); std::memset(m, 0, 16); m[0] = 0b11110001; // Start 4 byte char @@ -670,7 +674,7 @@ void MainTests::test_validUtf8() m[2] = 0b10100001; m[3] = 0b00000001; // Next byte doesn't start with 1, which is wrong std::string c(m, 4); - QVERIFY(!isValidUtf8(c)); + QVERIFY(!isValidUtf8Generic(c)); std::memset(m, 0, 16); m[0] = 0b11110001; // Start 4 byte char @@ -678,18 +682,18 @@ void MainTests::test_validUtf8() m[2] = 0b00100001; // Doesn't start with 1: invalid. m[3] = 0b10000001; std::string d(m, 4); - QVERIFY(!isValidUtf8(d)); + QVERIFY(!isValidUtf8Generic(d)); // Upper ASCII, invalid std::memset(m, 0, 16); m[0] = 127; std::string e(m, 1); - QVERIFY(!isValidUtf8(e)); + QVERIFY(!isValidUtf8Generic(e)); } void MainTests::test_validUtf8Sse() { - Utils data; + SimdUtils data; char m[16]; @@ -776,6 +780,33 @@ void MainTests::testPacketInt16Parse() } } +void MainTests::testTopicsMatch() +{ + QVERIFY(topicsMatch("#", "")); + QVERIFY(topicsMatch("#", "asdf/b/sdf")); + QVERIFY(topicsMatch("#", "+/b/sdf")); + QVERIFY(topicsMatch("#", "/one/two/asdf")); + QVERIFY(topicsMatch("#", "/one/two/asdf/")); + QVERIFY(topicsMatch("+/+/+/+/+", "/one/two/asdf/")); + QVERIFY(topicsMatch("+/+/#", "/one/two/asdf/")); + QVERIFY(topicsMatch("+/+/#", "/1234567890abcdef/two/asdf/")); + QVERIFY(topicsMatch("+/+/#", "/1234567890abcdefg/two/asdf/")); + QVERIFY(topicsMatch("+/+/#", "/1234567890abcde/two/asdf/")); + QVERIFY(topicsMatch("+/+/#", "1234567890abcde//two/asdf/")); + + QVERIFY(!topicsMatch("+/santa", "/one/two/asdf/")); + QVERIFY(!topicsMatch("+/+/+/+/", "/one/two/asdf/a")); + QVERIFY(!topicsMatch("+/one/+/+/", "/one/two/asdf/a")); + + QVERIFY(topicsMatch("$SYS/cow", "$SYS/cow")); + QVERIFY(topicsMatch("$SYS/cow/+", "$SYS/cow/bla")); + QVERIFY(topicsMatch("$SYS/#", "$SYS/broker/clients/connected")); + + QVERIFY(!topicsMatch("$SYS/cow/+", "$SYS/cow/bla/foobar")); + QVERIFY(!topicsMatch("#", "$SYS/cow")); + +} + QTEST_GUILESS_MAIN(MainTests) #include "tst_maintests.moc" diff --git a/main.cpp b/main.cpp index 844e659..8a5c4cf 100644 --- a/main.cpp +++ b/main.cpp @@ -84,14 +84,11 @@ int main(int argc, char *argv[]) check(register_signal_handers()); std::string sse = "without SSE support"; -#ifdef __SSE2__ - sse = "with SSE2 support"; -#endif #ifdef __SSE4_2__ sse = "with SSE4.2 support"; #endif #ifdef NDEBUG - logger->logf(LOG_NOTICE, "Starting FlashMQ version %s, release build.", VERSION, sse.c_str()); + logger->logf(LOG_NOTICE, "Starting FlashMQ version %s, release build %s.", VERSION, sse.c_str()); #else logger->logf(LOG_NOTICE, "Starting FlashMQ version %s, debug build %s.", VERSION, sse.c_str()); #endif diff --git a/mainapp.cpp b/mainapp.cpp index e7a1261..4b9be20 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -340,10 +340,11 @@ void MainApp::publishStatsOnDollarTopic() void MainApp::publishStat(const std::string &topic, uint64_t n) { - std::vector *subtopics = utils.splitTopic(topic); + std::vector subtopics; + splitTopic(topic, subtopics); const std::string payload = std::to_string(n); Publish p(topic, payload, 0); - subscriptionStore->queuePacketAtSubscribers(*subtopics, p, true); + subscriptionStore->queuePacketAtSubscribers(subtopics, p, true); subscriptionStore->setRetainedMessage(topic, payload, 0); } diff --git a/mainapp.h b/mainapp.h index 5cc1afb..24fefce 100644 --- a/mainapp.h +++ b/mainapp.h @@ -41,7 +41,6 @@ License along with FlashMQ. If not, see . #include "timer.h" #include "scopedsocket.h" #include "oneinstancelock.h" -#include "threadlocalutils.h" #define VERSION "0.7.0" @@ -65,7 +64,6 @@ class MainApp std::mutex quitMutex; std::string fuzzFilePath; OneInstanceLock oneInstanceLock; - Utils utils; Logger *logger = Logger::getInstance(); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 392685d..91f5897 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -23,9 +23,9 @@ License along with FlashMQ. If not, see . #include "utils.h" -#include "threadlocalutils.h" - -thread_local Utils utils; +// We can void constant reallocation of space for parsed subtopics by using this. But, beware to only use it during handling of the current +// packet. Don't access it for a stored packet, because then it will have changed. +thread_local std::vector gSubtopics; RemainingLength::RemainingLength() { @@ -106,7 +106,8 @@ MqttPacket::MqttPacket(const Publish &publish) : } this->topic = publish.topic; - this->subtopics = utils.splitTopic(this->topic); + this->subtopics = &gSubtopics; + splitTopic(this->topic, gSubtopics); packetType = PacketType::PUBLISH; this->qos = publish.qos; @@ -299,7 +300,7 @@ void MqttPacket::handleConnect() } // The specs don't really say what to do when client id not UTF8, so including here. - if (!utils.isValidUtf8(client_id) || !utils.isValidUtf8(username) || !utils.isValidUtf8(password) || !utils.isValidUtf8(will_topic)) + if (!isValidUtf8(client_id) || !isValidUtf8(username) || !isValidUtf8(password) || !isValidUtf8(will_topic)) { ConnAck connAck(ConnAckReturnCodes::MalformedUsernameOrPassword); MqttPacket response(connAck); @@ -419,7 +420,7 @@ void MqttPacket::handleSubscribe() uint16_t topicLength = readTwoBytesToUInt16(); std::string topic(readBytes(topicLength), topicLength); - if (topic.empty() || !utils.isValidUtf8(topic)) + if (topic.empty() || !isValidUtf8(topic)) throw ProtocolError("Subscribe topic not valid UTF-8."); if (!isValidSubscribePath(topic)) @@ -438,6 +439,7 @@ void MqttPacket::handleSubscribe() SubAck subAck(packet_id, subs_reponse_codes); MqttPacket response(subAck); sender->writeMqttPacket(response); + this->subtopics = nullptr; } void MqttPacket::handleUnsubscribe() @@ -454,7 +456,7 @@ void MqttPacket::handleUnsubscribe() uint16_t topicLength = readTwoBytesToUInt16(); std::string topic(readBytes(topicLength), topicLength); - if (topic.empty() || !utils.isValidUtf8(topic)) + if (topic.empty() || !isValidUtf8(topic)) throw ProtocolError("Subscribe topic not valid UTF-8."); sender->getThreadData()->getSubscriptionStore()->removeSubscription(sender, topic); @@ -485,9 +487,10 @@ void MqttPacket::handlePublish() throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); topic = std::string(readBytes(variable_header_length), variable_header_length); - subtopics = utils.splitTopic(topic); + subtopics = &gSubtopics; + splitTopic(topic, gSubtopics); - if (!utils.isValidUtf8(topic, true)) + if (!isValidUtf8(topic, true)) { logger->logf(LOG_WARNING, "Client '%s' published a message with invalid UTF8 or +/# in it. Dropping.", sender->repr().c_str()); return; @@ -546,6 +549,7 @@ void MqttPacket::handlePublish() // For the existing clients, we can just write the same packet back out, with our small alterations. sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(*subtopics, *this); } + this->subtopics = nullptr; } void MqttPacket::handlePubAck() @@ -678,6 +682,10 @@ const std::string &MqttPacket::getTopic() const return this->topic; } +/** + * @brief MqttPacket::getSubtopics returns a pointer to the parsed subtopics. Use with care! + * @return a pointer to a vector of subtopics that will be overwritten the next packet! + */ const std::vector *MqttPacket::getSubtopics() const { return this->subtopics; diff --git a/mqttpacket.h b/mqttpacket.h index 98f85a6..6c8ddb8 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -48,7 +48,7 @@ class MqttPacket #endif std::string topic; - std::vector *subtopics; // comes from local thread storage. See std::vector *ThreadData::splitTopic(std::string &topic) + std::vector *subtopics = nullptr; std::vector bites; size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. RemainingLength remainingLength; diff --git a/threadlocalutils.cpp b/threadlocalutils.cpp index 7de86d5..3307f7a 100644 --- a/threadlocalutils.cpp +++ b/threadlocalutils.cpp @@ -1,9 +1,11 @@ +#ifdef __SSE4_2__ + #include "threadlocalutils.h" #include #include -Utils::Utils() : +SimdUtils::SimdUtils() : subtopicParseMem(TOPIC_MEMORY_LENGTH), topicCopy(TOPIC_MEMORY_LENGTH) { @@ -11,15 +13,14 @@ Utils::Utils() : } /** - * @brief ThreadData::splitTopic uses SSE4.2 to detect the '/' chars, 16 chars at a time, and returns a pointer to thread-local memory. - * @param topic string is altered: some extra space is reserved. - * @return Pointer to thread-owned vector of subtopics. - * - * Because it returns a pointer to the thread-local vector, only the current thread should touch it. + * @brief SimdUtils::splitTopic uses SSE4.2 to detect the '/' chars, 16 chars at a time, and returns a pointer to thread-local memory. + * @param topic + * @param output is cleared and emplaced in. You can give it members from the Utils class, to avoid re-allocation. + * @return */ -std::vector *Utils::splitTopic(const std::string &topic) +std::vector *SimdUtils::splitTopic(const std::string &topic, std::vector &output) { - subtopics.clear(); + output.clear(); const int s = topic.size(); std::memcpy(topicCopy.data(), topic.c_str(), s+1); @@ -41,16 +42,16 @@ std::vector *Utils::splitTopic(const std::string &topic) if (index < 16 || n >= s) { - subtopics.emplace_back(subtopicParseMem.data(), carryi); + output.emplace_back(subtopicParseMem.data(), carryi); carryi = 0; n++; } } - return &subtopics; + return &output; } -bool Utils::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) +bool SimdUtils::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) { const int len = s.size(); @@ -136,3 +137,5 @@ bool Utils::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) return true; } + +#endif diff --git a/threadlocalutils.h b/threadlocalutils.h index f87339f..c7a9f0a 100644 --- a/threadlocalutils.h +++ b/threadlocalutils.h @@ -1,18 +1,18 @@ #ifndef THREADLOCALUTILS_H #define THREADLOCALUTILS_H +#ifdef __SSE4_2__ + #include #include #include #define TOPIC_MEMORY_LENGTH 65560 -/** - * @brief The Utils class have utility functions that make use of pre-allocated memory. Use with thread_local or create per thread manually. - */ -class Utils + + +class SimdUtils { - std::vector subtopics; std::vector subtopicParseMem; std::vector topicCopy; __m128i slashes = _mm_set1_epi8('/'); @@ -23,11 +23,13 @@ class Utils __m128i plus = _mm_set1_epi8('+'); public: - Utils(); + SimdUtils(); - std::vector *splitTopic(const std::string &topic); + std::vector *splitTopic(const std::string &topic, std::vector &output); bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars = false); }; +#endif + #endif // THREADLOCALUTILS_H diff --git a/utils.cpp b/utils.cpp index 236a77c..358f9dc 100644 --- a/utils.cpp +++ b/utils.cpp @@ -34,6 +34,12 @@ License along with FlashMQ. If not, see . #include "logger.h" #include "evpencodectxmanager.h" + +#ifdef __SSE4_2__ +#include "threadlocalutils.h" +thread_local SimdUtils simdUtils; +#endif + std::list split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) { std::list list; @@ -60,8 +66,16 @@ bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTo if (!subscribeTopic.empty() && !publishTopic.empty() && publishTopic[0] == '$' && subscribeTopic[0] != '$') return false; - const std::vector subscribeParts = splitToVector(subscribeTopic, '/'); - const std::vector publishParts = splitToVector(publishTopic, '/'); + std::vector subscribeParts; + std::vector publishParts; + +#ifdef __SSE4_2__ + simdUtils.splitTopic(subscribeTopic, subscribeParts); + simdUtils.splitTopic(publishTopic, publishParts); +#else + splitToVector(subscribeTopic, subscribeParts, '/'); + splitToVector(publishTopic, publishParts, '/'); +#endif auto subscribe_itr = subscribeParts.begin(); auto publish_itr = publishParts.begin(); @@ -84,7 +98,7 @@ bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTo return result; } -bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) +bool isValidUtf8Generic(const std::string &s, bool alsoCheckInvalidPublishChars) { int multibyte_remain = 0; int cur_code_point = 0; @@ -146,6 +160,15 @@ bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) return multibyte_remain == 0; } +bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) +{ +#ifdef __SSE4_2__ + return simdUtils.isValidUtf8(s, alsoCheckInvalidPublishChars); +#else + return isValidUtf8Generic(s, alsoCheckInvalidPublishChars); +#endif +} + bool strContains(const std::string &s, const std::string &needle) { return s.find(needle) != std::string::npos; @@ -209,25 +232,39 @@ bool containsDangerousCharacters(const std::string &s) return false; } -const std::vector splitToVector(const std::string &input, const char sep, size_t max, bool keep_empty_parts) +void splitTopic(const std::string &topic, std::vector &output) +{ +#ifdef __SSE4_2__ + simdUtils.splitTopic(topic, output); +#else + splitToVector(topic, output, '/'); +#endif +} + +void splitToVector(const std::string &input, std::vector &output, const char sep, size_t max, bool keep_empty_parts) { const auto subtopic_count = std::count(input.begin(), input.end(), '/') + 1; - std::vector result; - result.reserve(subtopic_count); + output.reserve(subtopic_count); size_t start = 0; size_t end; const auto npos = std::string::npos; - while (result.size() < max && (end = input.find(sep, start)) != npos) + while (output.size() < max && (end = input.find(sep, start)) != npos) { if (start != end || keep_empty_parts) - result.push_back(input.substr(start, end - start)); + output.push_back(input.substr(start, end - start)); start = end + 1; // increase by length of seperator. } if (start != input.size() || keep_empty_parts) - result.push_back(input.substr(start, npos)); + output.push_back(input.substr(start, npos)); +} + +const std::vector splitToVector(const std::string &input, const char sep, size_t max, bool keep_empty_parts) +{ + std::vector result; + splitToVector(input, result, sep, max, keep_empty_parts); return result; } diff --git a/utils.h b/utils.h index 8fc5a0f..796928c 100644 --- a/utils.h +++ b/utils.h @@ -46,9 +46,12 @@ template int check(int rc) std::list split(const std::string &input, const char sep, size_t max = std::numeric_limits::max(), bool keep_empty_parts = true); const std::vector splitToVector(const std::string &input, const char sep, size_t max = std::numeric_limits::max(), bool keep_empty_parts = true); +void splitToVector(const std::string &input, std::vector &output, const char sep, size_t max = std::numeric_limits::max(), bool keep_empty_parts = true); +void splitTopic(const std::string &topic, std::vector &output); bool topicsMatch(const std::string &subscribeTopic, const std::string &publishTopic); +bool isValidUtf8Generic(const std::string &s, bool alsoCheckInvalidPublishChars = false); bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars = false); bool strContains(const std::string &s, const std::string &needle); -- libgit2 0.21.4