From 5c74f9737a8d35b437ea1b7df888341c59c9f505 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Tue, 27 Apr 2021 15:37:29 +0200 Subject: [PATCH] Topic splitting with SSE instructions --- CMakeLists.txt | 2 +- FlashMQTests/FlashMQTests.pro | 1 + FlashMQTests/tst_maintests.cpp | 28 ++++++++++++++++++++++++++++ mqttpacket.cpp | 10 +++++----- mqttpacket.h | 4 ++-- session.cpp | 2 +- threaddata.cpp | 43 +++++++++++++++++++++++++++++++++++++++++++ threaddata.h | 10 ++++++++++ 8 files changed, 91 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1e9af0f..617f662 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,7 +7,7 @@ add_definitions(-DOPENSSL_API_COMPAT=0x10100000L) set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) -SET(CMAKE_CXX_FLAGS "-rdynamic") +SET(CMAKE_CXX_FLAGS "-rdynamic -msse4.2") add_compile_options(-Wall) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index cdd835b..973e9b2 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -79,3 +79,4 @@ HEADERS += \ LIBS += -ldl -lssl -lcrypto QMAKE_LFLAGS += -rdynamic +QMAKE_CXXFLAGS += -msse4.2 diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 4ab8aaf..bf96ce4 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -26,6 +26,7 @@ License along with FlashMQ. If not, see . #include "mainapp.h" #include "mainappthread.h" #include "twoclienttestcontext.h" +#include "threaddata.h" // Dumb Qt version gives warnings when comparing uint with number literal. template @@ -75,6 +76,8 @@ private slots: void test_acl_patterns_username(); void test_acl_patterns_clientid(); + void test_sse_split(); + }; MainTests::MainTests() @@ -587,6 +590,31 @@ void MainTests::test_acl_patterns_clientid() QCOMPARE(aclTree.findPermission(splitToVector("d/clientid_one/f/A/B", '/'), AclGrant::Read, "foo", "clientid_one"), AuthResult::success); } +void MainTests::test_sse_split() +{ + std::shared_ptr store(new SubscriptionStore); + std::shared_ptr settings(new Settings); + ThreadData data(0, store, settings); + + std::list topics; + topics.push_back("one/two/threeabcasdfasdf/koe"); + topics.push_back("/two/threeabcasdfasdf/koe"); // Test empty component. + topics.push_back("//two/threeabcasdfasdf/koe"); // Test two empty components. + topics.push_back("//1234567890abcde/bla/koe"); // Test two empty components, 15 char topic (one byte short of 16 alignment). + topics.push_back("//1234567890abcdef/bla/koe"); // Test two empty components, 16 char topic + topics.push_back("//1234567890abcdefg/bla/koe"); // Test two empty components, 17 char topic + topics.push_back("//1234567890abcdefg/1234567890abcdefg/koe"); // Test two empty components, two 17 char topics + topics.push_back("//1234567890abcdef/1234567890abcdefg/koe"); // Test two empty components, 16 and 17 char + topics.push_back("//1234567890abcdef/1234567890abcdefg/koe/"); + topics.push_back("//1234567890abcdef/1234567890abcdefg/koe//"); + topics.push_back("//1234567890abcdef/1234567890abcdef/"); + + for (const std::string &t : topics) + { + QCOMPARE(*data.splitTopic(t), splitToVector(t, '/')); + } +} + QTEST_GUILESS_MAIN(MainTests) #include "tst_maintests.moc" diff --git a/mqttpacket.cpp b/mqttpacket.cpp index fe8e0bf..012c2e1 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -102,7 +102,7 @@ MqttPacket::MqttPacket(const Publish &publish) : } this->topic = publish.topic; - this->subtopics = splitToVector(publish.topic, '/'); + this->subtopics = sender->getThreadData()->splitTopic(this->topic); packetType = PacketType::PUBLISH; this->qos = publish.qos; @@ -439,7 +439,7 @@ void MqttPacket::handlePublish() throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); topic = std::string(readBytes(variable_header_length), variable_header_length); - subtopics = splitToVector(topic, '/'); + subtopics = sender->getThreadData()->splitTopic(topic); if (!isValidUtf8(topic, true)) { @@ -464,7 +464,7 @@ void MqttPacket::handlePublish() sender->writeMqttPacket(response); } - if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, subtopics, AclAccess::write) == AuthResult::success) + if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, *subtopics, AclAccess::write) == AuthResult::success) { if (retain) { @@ -479,7 +479,7 @@ void MqttPacket::handlePublish() bites[0] &= 0b11110110; // For the existing clients, we can just write the same packet back out, with our small alterations. - sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(subtopics, *this); + sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(*subtopics, *this); } } @@ -569,7 +569,7 @@ const std::string &MqttPacket::getTopic() const return this->topic; } -const std::vector &MqttPacket::getSubtopics() const +const std::vector *MqttPacket::getSubtopics() const { return this->subtopics; } diff --git a/mqttpacket.h b/mqttpacket.h index 9244ace..0a55042 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -44,7 +44,7 @@ public: class MqttPacket { std::string topic; - std::vector subtopics; + std::vector *subtopics; // comes from local thread storage. See std::vector *ThreadData::splitTopic(std::string &topic) std::vector bites; size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. RemainingLength remainingLength; @@ -94,7 +94,7 @@ public: const std::vector &getBites() const { return bites; } char getQos() const { return qos; } const std::string &getTopic() const; - const std::vector &getSubtopics() const; + const std::vector *getSubtopics() const; std::shared_ptr getSender() const; void setSender(const std::shared_ptr &value); bool containsFixedHeader() const; diff --git a/session.cpp b/session.cpp index 3706b70..77b73ca 100644 --- a/session.cpp +++ b/session.cpp @@ -52,7 +52,7 @@ void Session::writePacket(const MqttPacket &packet, char max_qos) { assert(max_qos <= 2); - if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), packet.getSubtopics(), AclAccess::read) == AuthResult::success) + if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), *packet.getSubtopics(), AclAccess::read) == AuthResult::success) { const char qos = std::min(packet.getQos(), max_qos); diff --git a/threaddata.cpp b/threaddata.cpp index 90e8213..d231724 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -19,8 +19,12 @@ License along with FlashMQ. If not, see . #include #include +#define TOPIC_MEMORY_LENGTH 65560 + ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, std::shared_ptr settings) : subscriptionStore(subscriptionStore), + subtopicParseMem(TOPIC_MEMORY_LENGTH), + topicCopy(TOPIC_MEMORY_LENGTH), settingsLocalCopy(*settings.get()), authentication(settingsLocalCopy), threadnr(threadnr) @@ -155,6 +159,45 @@ void ThreadData::queuePasswdFileReload() wakeUpThread(); } +/** + * @brief ThreadData::splitTopic uses SSE4.2 to detect the '/' chars, 16 chars at a time, and returns a pointer to thread-local memory. + * @param topic string is altered: some extra space is reserved. + * @return Pointer to thread-owned vector of subtopics. + * + * Because it returns a pointer to the thread-local vector, only the current thread should touch it. + */ +std::vector *ThreadData::splitTopic(const std::string &topic) +{ + subtopics.clear(); + + const int s = topic.size(); + std::memcpy(topicCopy.data(), topic.c_str(), s+1); + std::memset(&topicCopy.data()[s], 0, 16); + int n = 0; + int carryi = 0; + while (n <= s) + { + const char *i = &topicCopy.data()[n]; + __m128i loaded = _mm_loadu_si128((__m128i*)i); + + int len_left = s - n; + int index = _mm_cmpestri(slashes, 1, loaded, len_left, 0); + std::memcpy(&subtopicParseMem[carryi], i, index); + carryi += std::min(index, len_left); + + n += index; + + if (index < 16 || n >= s) + { + subtopics.emplace_back(subtopicParseMem.data(), carryi); + carryi = 0; + n++; + } + } + + return &subtopics; +} + // TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial? void ThreadData::doKeepAliveCheck() { diff --git a/threaddata.h b/threaddata.h index 2433a0a..2a79371 100644 --- a/threaddata.h +++ b/threaddata.h @@ -18,6 +18,8 @@ License along with FlashMQ. If not, see . #ifndef THREADDATA_H #define THREADDATA_H +#include + #include #include @@ -47,6 +49,12 @@ class ThreadData std::shared_ptr subscriptionStore; Logger *logger; + // Topic parsing working memory + std::vector subtopics; + std::vector subtopicParseMem; + std::vector topicCopy; + __m128i slashes = _mm_set1_epi8('/'); + void reload(std::shared_ptr settings); void wakeUpThread(); void doKeepAliveCheck(); @@ -82,6 +90,8 @@ public: void waitForQuit(); void queuePasswdFileReload(); + std::vector *splitTopic(const std::string &topic); + }; #endif // THREADDATA_H -- libgit2 0.21.4