threaddata.cpp 10.6 KB
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, version 3.

FlashMQ is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public
License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
*/

#include "threaddata.h"
#include <string>
#include <sstream>
#include <cassert>

// Size in bytes of the scratch buffers subtopicParseMem and topicCopy (see the constructor).
// splitTopic() and isValidUtf8() copy a string into topicCopy and then write 16 bytes of padding after it,
// so 16-byte SSE loads never run past the buffer; isValidUtf8() rejects strings where len + 16 exceeds this size.
// Presumably chosen to fit the longest MQTT string (65535 bytes) plus that padding — TODO confirm.
#define TOPIC_MEMORY_LENGTH 65560

/**
 * @brief Set up per-thread state: scratch buffers, a private settings copy, the authentication object,
 *        an epoll instance and the task eventfd.
 * @param threadnr Number of this thread; start() also uses it as the CPU to pin to.
 * @param subscriptionStore Shared subscription store.
 * @param settings Settings to copy into settingsLocalCopy.
 * @throws std::runtime_error when the epoll instance or the eventfd cannot be created or registered.
 */
ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore, std::shared_ptr<Settings> settings) :
    subscriptionStore(subscriptionStore),
    subtopicParseMem(TOPIC_MEMORY_LENGTH),
    topicCopy(TOPIC_MEMORY_LENGTH),
    settingsLocalCopy(*settings),
    authentication(settingsLocalCopy),
    threadnr(threadnr)
{
    logger = Logger::getInstance();

    epollfd = check<std::runtime_error>(epoll_create(999));

    taskEventFd = eventfd(0, EFD_NONBLOCK);
    if (taskEventFd < 0)
        throw std::runtime_error("Can't create eventfd.");

    // Watch the task eventfd for readability, so writes done by wakeUpThread() show up on this epoll set.
    struct epoll_event taskEvent;
    std::memset(&taskEvent, 0, sizeof(taskEvent));
    taskEvent.events = EPOLLIN;
    taskEvent.data.fd = taskEventFd;
    check<std::runtime_error>(epoll_ctl(epollfd, EPOLL_CTL_ADD, taskEventFd, &taskEvent));
}

/**
 * @brief Start the worker thread running f(this), then name it and pin it to CPU number 'threadnr'.
 * @param f Thread entry point; it is invoked with this ThreadData.
 *
 * The thread name is held in a local std::string for the whole function. Previously the code took c_str() of the
 * temporary std::string returned by std::ostringstream::str(); that temporary was destroyed at the end of the
 * statement, so both pthread_setname_np() and the log call read freed memory.
 */
void ThreadData::start(thread_f f)
{
    this->thread = std::thread(f, this);

    pthread_t native = this->thread.native_handle();

    // Linux limits thread names to 16 bytes including the terminator; "FlashMQ T " (10 chars) leaves room for 5 digits.
    // Naming is best-effort: the return value is ignored, as before.
    const std::string threadName = "FlashMQ T " + std::to_string(threadnr);
    pthread_setname_np(native, threadName.c_str());

    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    CPU_SET(threadnr, &cpuset);
    // NOTE(review): pthread_*affinity_np() report failure by returning a positive error number, not -1/errno;
    // whether check<>() catches that depends on its implementation (not visible here) — confirm.
    check<std::runtime_error>(pthread_setaffinity_np(native, sizeof(cpuset), &cpuset));

    // It's not really necessary to get affinity again, but now I'm logging truth instead assumption.
    check<std::runtime_error>(pthread_getaffinity_np(native, sizeof(cpuset), &cpuset));

    // Reports the highest-numbered CPU present in the mask.
    int pinned_cpu = -1;
    for (int j = 0; j < CPU_SETSIZE; j++)
        if (CPU_ISSET(j, &cpuset))
            pinned_cpu = j;

    logger->logf(LOG_NOTICE, "Thread '%s' pinned to CPU %d", threadName.c_str(), pinned_cpu);
}

void ThreadData::quit()
{
    running = false;
}

/**
 * @brief Adopt a client: store it in clients_by_fd under its fd, then add that fd to this thread's epoll set (EPOLLIN).
 * @param client The client to adopt.
 * @throws std::runtime_error (via check<>) when epoll_ctl() fails.
 *
 * The map update now uses a scoped lock_guard instead of manual lock()/unlock(). Previously, if getFd() or the
 * map insert threw (e.g. std::bad_alloc), clients_by_fd_mutex stayed locked forever.
 */
void ThreadData::giveClient(std::shared_ptr<Client> client)
{
    int fd = -1;
    {
        std::lock_guard<std::mutex> lck(clients_by_fd_mutex);
        fd = client->getFd();
        clients_by_fd[fd] = client;
    }

    // The client is inserted into the map before its fd is added to the epoll set.
    struct epoll_event ev;
    memset(&ev, 0, sizeof (struct epoll_event));
    ev.data.fd = fd;
    ev.events = EPOLLIN;
    check<std::runtime_error>(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev));
}

/**
 * @brief Look up the client registered under fd.
 * @param fd File descriptor to look up.
 * @return The client, or an empty shared_ptr if fd is not registered.
 *
 * Uses find() instead of operator[]. operator[] inserted a null entry for every unknown fd, which grew the map and
 * left null clients behind for removeClient(int) and doKeepAliveCheck() to trip over.
 */
std::shared_ptr<Client> ThreadData::getClient(int fd)
{
    std::lock_guard<std::mutex> lck(clients_by_fd_mutex);
    auto pos = this->clients_by_fd.find(fd);
    if (pos == this->clients_by_fd.end())
        return std::shared_ptr<Client>();
    return pos->second;
}

/**
 * @brief Mark a client as disconnecting and remove it from clients_by_fd.
 * @param client The client to remove; its current fd is the map key that gets erased.
 *
 * The client is flagged before clients_by_fd_mutex is taken; only the fd lookup and the erase happen under the lock.
 */
void ThreadData::removeClient(std::shared_ptr<Client> client)
{
    client->markAsDisconnecting();

    std::lock_guard<std::mutex> guard(clients_by_fd_mutex);
    const int fd = client->getFd();
    clients_by_fd.erase(fd);
}

/**
 * @brief Remove the client registered under fd, marking it as disconnecting first.
 * @param fd File descriptor of the client; unknown fds are ignored.
 *
 * The map may hold null entries (getClient() used to insert them via operator[]), so the pointer is checked before
 * use; previously that case dereferenced a null shared_ptr. The entry is erased through the iterator already found,
 * which avoids a second lookup.
 */
void ThreadData::removeClient(int fd)
{
    std::lock_guard<std::mutex> lck(clients_by_fd_mutex);
    auto client_it = this->clients_by_fd.find(fd);
    if (client_it == this->clients_by_fd.end())
        return;

    if (client_it->second)
        client_it->second->markAsDisconnecting();

    this->clients_by_fd.erase(client_it);
}

std::shared_ptr<SubscriptionStore> &ThreadData::getSubscriptionStore()
{
    return subscriptionStore;
}

void ThreadData::queueDoKeepAliveCheck()
{
    std::lock_guard<std::mutex> locker(taskQueueMutex);

    auto f = std::bind(&ThreadData::doKeepAliveCheck, this);
    taskQueue.push_front(f);

    wakeUpThread();
}

void ThreadData::queueQuit()
{
    std::lock_guard<std::mutex> locker(taskQueueMutex);

    auto f = std::bind(&ThreadData::quit, this);
    taskQueue.push_front(f);

    authentication.setQuitting();

    wakeUpThread();
}

/**
 * @brief Block until the worker thread started by start() has finished.
 *
 * std::thread::join() throws std::system_error when there is nothing joinable, e.g. start() never ran or the
 * thread was already joined. That case is now a no-op instead of an exception.
 */
void ThreadData::waitForQuit()
{
    if (thread.joinable())
        thread.join();
}

void ThreadData::queuePasswdFileReload()
{
    std::lock_guard<std::mutex> locker(taskQueueMutex);

    auto f = std::bind(&Authentication::loadMosquittoPasswordFile, &authentication);
    taskQueue.push_front(f);

    auto f2 = std::bind(&Authentication::loadMosquittoAclFile, &authentication);
    taskQueue.push_front(f2);

    wakeUpThread();
}

/**
 * @brief ThreadData::splitTopic splits a topic on '/' using SSE4.2, examining 16 chars at a time.
 * @param topic The topic to split. It is NOT modified: it is copied into the member buffer topicCopy, where 16 zero
 *        bytes are written after it so every 16-byte load stays inside the buffer.
 * @return Pointer to this ThreadData's 'subtopics' vector, which is cleared and refilled on every call. Empty levels
 *         are kept, so "" yields {""} and "a/" yields {"a", ""}.
 *
 * The result and the scratch buffers (topicCopy, subtopicParseMem) are members of this object, so only the thread
 * owning this ThreadData should call this, and the returned vector is overwritten by the next call.
 *
 * NOTE(review): unlike isValidUtf8(), there is no length check here. Up to about topic.size() + 16 bytes of topicCopy
 * and subtopicParseMem are written, so callers must guarantee topic.size() + 16 <= TOPIC_MEMORY_LENGTH — confirm at
 * the call sites.
 */
std::vector<std::string> *ThreadData::splitTopic(const std::string &topic)
{
    subtopics.clear();

    const int s = topic.size();
    // Copy the topic including its terminating NUL, then zero 16 bytes from position s so loads past the end see zeros.
    std::memcpy(topicCopy.data(), topic.c_str(), s+1);
    std::memset(&topicCopy.data()[s], 0, 16);
    int n = 0;
    int carryi = 0; // Length of the subtopic accumulated so far in subtopicParseMem.
    // '<=' so that a final (possibly empty) level after the last '/' is still emitted, and "" yields one empty level.
    while (n <= s)
    {
        const char *i = &topicCopy.data()[n];
        __m128i loaded = _mm_loadu_si128((__m128i*)i);

        int len_left = s - n;
        assert(len_left >= 0);
        // Mode 0 = unsigned bytes, "equal any", least-significant index: position of the first byte among the first
        // len_left bytes that matches 'slashes' (presumably just '/', defined elsewhere), or 16 if there is none.
        int index = _mm_cmpestri(slashes, 1, loaded, len_left, 0);
        // When index == 16 on the last chunk this also copies padding bytes; carryi only advances by real topic bytes.
        std::memcpy(&subtopicParseMem[carryi], i, index);
        carryi += std::min<int>(index, len_left);

        n += index;

        // A '/' ended the current level, or the topic is exhausted: emit the level and step over the '/' (or past the end).
        if (index < 16 || n >= s)
        {
            subtopics.emplace_back(subtopicParseMem.data(), carryi);
            carryi = 0;
            n++;
        }
    }

    return &subtopics;
}

/**
 * @brief Check that s is well-formed UTF-8 and contains none of the characters FlashMQ rejects.
 * @param s The string to validate (typically a topic).
 * @param alsoCheckInvalidPublishChars When true, '#' and '+' are rejected as well.
 * @return true if s is acceptable.
 *
 * Rejected:
 *  - strings whose length plus 16 bytes of SIMD padding exceeds TOPIC_MEMORY_LENGTH;
 *  - ASCII bytes that compare (signed) below lowerBound or above lastAsciiChar (presumably the printable-ASCII bounds,
 *    defined elsewhere), wherever they occur;
 *  - malformed multi-byte sequences, overlong encodings and code points above U+10FFFF (not well-formed per RFC 3629);
 *  - UTF-16 surrogates U+D800..U+DFFF and U+FFFF.
 *
 * s is copied into topicCopy and followed by 16 spaces, so every 16-byte load stays in bounds and the padding itself
 * passes the ASCII range checks.
 */
bool ThreadData::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars)
{
    const int len = s.size();

    if (len + 16 > TOPIC_MEMORY_LENGTH)
        return false;

    std::memcpy(topicCopy.data(), s.c_str(), len);
    std::memset(&topicCopy.data()[len], 0x20, 16); // I fill out with spaces, as valid chars

    int n = 0;
    const char *i = topicCopy.data();
    while (n < len)
    {
        const int len_left = len - n;
        assert(len_left > 0);
        __m128i loaded = _mm_loadu_si128((__m128i*)&i[n]);
        __m128i loaded_AND_non_ascii = _mm_and_si128(loaded, non_ascii_mask);

        // Each movemask result is tested separately. The parenthesis used to be misplaced, applying '||' to a vector,
        // which only compiles as a GCC vector extension.
        if (alsoCheckInvalidPublishChars && (_mm_movemask_epi8(_mm_cmpeq_epi8(loaded, pound)) || _mm_movemask_epi8(_mm_cmpeq_epi8(loaded, plus))))
            return false;

        // Position of the first byte (within len_left) with the high bit set, or 16 if the chunk is all ASCII.
        const int index = _mm_cmpestri(non_ascii_mask, 1, loaded_AND_non_ascii, len_left, 0);

        // The first 'index' bytes of this chunk are ASCII: range-check exactly those. Previously the check only ran for
        // all-ASCII chunks, so a control character (even NUL) followed by a multi-byte char in the same chunk passed.
        // When index == 16 every byte is checked; bytes past the end of s are the space padding.
        const int asciiBytesMask = index >= 16 ? 0xFFFF : ((1 << index) - 1);
        if (_mm_movemask_epi8(_mm_cmplt_epi8(loaded, lowerBound)) & asciiBytesMask)
            return false;
        if (_mm_movemask_epi8(_mm_cmpgt_epi8(loaded, lastAsciiChar)) & asciiBytesMask)
            return false;

        n += index;

        // Checking multi-byte chars one by one. With some effort, this may be done using SIMD too, but the majority of uses will
        // have a minimum of multi byte chars.
        if (index < 16)
        {
            const unsigned char lead = static_cast<unsigned char>(i[n++]);
            int char_len = 0;
            int cur_code_point = 0;
            int min_code_point = 0; // Smallest code point that legitimately needs this many bytes; lower is overlong.

            if ((lead & 0b11100000) == 0b11000000) // 2 byte char
            {
                char_len = 1;
                min_code_point = 0x80;
                cur_code_point = ((lead & 0b00011111) << 6);
            }
            else if ((lead & 0b11110000) == 0b11100000) // 3 byte char
            {
                char_len = 2;
                min_code_point = 0x800;
                cur_code_point = ((lead & 0b00001111) << 12);
            }
            else if ((lead & 0b11111000) == 0b11110000) // 4 byte char
            {
                char_len = 3;
                min_code_point = 0x10000;
                cur_code_point = ((lead & 0b00000111) << 18);
            }
            else
                return false;

            while (char_len > 0)
            {
                if (n >= len)
                    return false;

                const unsigned char x = static_cast<unsigned char>(i[n++]);

                if ((x & 0b11000000) != 0b10000000) // All remaining bytes of this code point need to start with 10
                    return false;
                char_len--;
                cur_code_point += ((x & 0b00111111) << (6*char_len));
            }

            if (cur_code_point < min_code_point || cur_code_point > 0x10FFFF)
                return false;

            if (cur_code_point >= 0xD800 && cur_code_point <= 0xDFFF) // Dec 55296-57343
                return false;

            if (cur_code_point == 0xFFFF)
                return false;
        }
    }

    return true;
}

// TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial?
/**
 * @brief Remove clients whose keep-alive expired; let every other client reset its buffers if eligible.
 *
 * If clients_by_fd_mutex is currently held elsewhere, the check is skipped: connects and disconnects must not stall
 * behind it, and it can run again later. Null entries in the map are skipped. Exceptions are logged, not propagated.
 */
void ThreadData::doKeepAliveCheck()
{
    std::unique_lock<std::mutex> lock(clients_by_fd_mutex, std::try_to_lock);
    if (!lock.owns_lock())
        return;

    logger->logf(LOG_DEBUG, "Doing keep-alive check in thread %d", threadnr);

    try
    {
        for (auto it = clients_by_fd.begin(); it != clients_by_fd.end(); )
        {
            const std::shared_ptr<Client> &client = it->second;

            if (!client)
            {
                ++it;
                continue;
            }

            if (client->keepAliveExpired())
            {
                client->setDisconnectReason("Keep-alive expired: " + client->getKeepAliveInfoString());
                it = clients_by_fd.erase(it); // 'client' is not touched after this point.
                continue;
            }

            client->resetBuffersIfEligible();
            ++it;
        }
    }
    catch (std::exception &ex)
    {
        logger->logf(LOG_ERR, "Error handling keep-alives: %s.", ex.what());
    }
}

void ThreadData::initAuthPlugin()
{
    authentication.loadMosquittoPasswordFile();
    authentication.loadMosquittoAclFile();
    authentication.loadPlugin(settingsLocalCopy.authPluginPath);
    authentication.init();
    authentication.securityInit(false);
}

void ThreadData::reload(std::shared_ptr<Settings> settings)
{
    logger->logf(LOG_DEBUG, "Doing reload in thread %d", threadnr);

    try
    {
        // Because the auth plugin has a reference to it, it will also be updated.
        settingsLocalCopy = *settings.get();

        authentication.securityCleanup(true);
        authentication.securityInit(true);
    }
    catch (AuthPluginException &ex)
    {
        logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what());
    }
    catch (std::exception &ex)
    {
        logger->logf(LOG_ERR, "Error reloading: %s.", ex.what());
    }
}

void ThreadData::queueReload(std::shared_ptr<Settings> settings)
{
    std::lock_guard<std::mutex> locker(taskQueueMutex);

    auto f = std::bind(&ThreadData::reload, this, settings);
    taskQueue.push_front(f);

    wakeUpThread();
}

/**
 * @brief Add 1 to the task eventfd counter.
 *
 * The constructor registered taskEventFd with this thread's epoll set for EPOLLIN, so the write makes it readable there.
 * @throws std::runtime_error (via check<>) when write() fails.
 */
void ThreadData::wakeUpThread()
{
    const uint64_t increment = 1;
    const auto written = write(taskEventFd, &increment, sizeof(increment));
    check<std::runtime_error>(written);
}