From 40f8e5e5f8eac75a2647d423826b52acf38bcb71 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Thu, 10 Dec 2020 21:57:42 +0100 Subject: [PATCH] Publishing within the same thread works --- client.cpp | 14 +++++++++++++- client.h | 5 ++++- main.cpp | 21 +++++++++++++++++++++ mqttpacket.cpp | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- mqttpacket.h | 10 ++++++++-- subscriptionstore.cpp | 17 +++++++++++++++++ subscriptionstore.h | 2 ++ threaddata.cpp | 18 ++++++++++++++++++ threaddata.h | 9 ++++++++- types.cpp | 7 +++++++ types.h | 8 ++++++++ 11 files changed, 171 insertions(+), 8 deletions(-) diff --git a/client.cpp b/client.cpp index ee7bd2f..a50bd86 100644 --- a/client.cpp +++ b/client.cpp @@ -57,7 +57,7 @@ bool Client::readFdIntoBuffer() return true; } -void Client::writeMqttPacket(MqttPacket &packet) +void Client::writeMqttPacket(const MqttPacket &packet) { if (packet.getSize() > getWriteBufMaxWriteSize()) growWriteBuffer(packet.getSize()); @@ -115,6 +115,18 @@ std::string Client::repr() return a.str(); } +void Client::queueMessage(const MqttPacket &packet) +{ + + + // TODO: semaphores on stl containers? +} + +void Client::queuedMessagesToBuffer() +{ + +} + bool Client::bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender) { while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) diff --git a/client.h b/client.h index de734d9..fc6ae46 100644 --- a/client.h +++ b/client.h @@ -92,13 +92,16 @@ public: void setAuthenticated(bool value) { authenticated = value;} bool getAuthenticated() { return authenticated; } bool hasConnectPacketSeen() { return connectPacketSeen; } + ThreadData_p getThreadData() { return threadData; } void writePingResp(); - void writeMqttPacket(MqttPacket &packet); + void writeMqttPacket(const MqttPacket &packet); bool writeBufIntoFd(); std::string repr(); + void queueMessage(const MqttPacket &packet); + void queuedMessagesToBuffer(); }; #endif // CLIENT_H diff --git a/main.cpp b/main.cpp index 148db8d..3f8cbb5 100644 --- a/main.cpp +++ b/main.cpp @@ -28,8 +28,22 @@ void do_thread_work(ThreadData *threadData) std::vector packetQueueIn; + uint64_t eventfd_value = 0; + while (1) { + if (eventfd_value > 0) + { + for (Client_p client : threadData->getReadyForDequeueing()) + { + client->queuedMessagesToBuffer(); + } + threadData->clearReadyForDequeueing(); + eventfd_value = 0; + } + + // TODO: do all the buftofd here, not spread out over + int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); if (fdcount > 0) @@ -39,6 +53,13 @@ void do_thread_work(ThreadData *threadData) struct epoll_event cur_ev = events[i]; int fd = cur_ev.data.fd; + // If this thread was actively woken up. + if (fd == threadData->event_fd) + { + read(fd, &eventfd_value, sizeof(uint64_t)); + continue; + } + Client_p client = threadData->getClient(fd); if (client) diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 58eecb9..6d4d138 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -3,16 +3,17 @@ #include #include +// constructor for parsing incoming packets MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) : bites(len), fixed_header_length(fixed_header_length), sender(sender) { - unsigned char _packetType = (buf[0] & 0xF0) >> 4; + std::memcpy(&bites[0], buf, len); + first_byte = bites[0]; + unsigned char _packetType = (first_byte & 0xF0) >> 4; packetType = (PacketType)_packetType; pos += fixed_header_length; - - std::memcpy(&bites[0], buf, len); } MqttPacket::MqttPacket(const ConnAck &connAck) : @@ -46,14 +47,42 @@ MqttPacket::MqttPacket(const SubAck &subAck) : bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length } +MqttPacket::MqttPacket(const Publish &publish) : + bites(publish.topic.length() + publish.payload.length() + 2 + 2) // TODO: same as above +{ + packetType = PacketType::PUBLISH; + char first_byte = static_cast(packetType) << 4; + writeByte(first_byte); + + char topicLenMSB = (publish.topic.length() & 0xF0) >> 8; + char topicLenLSB = publish.topic.length() & 0x0F; + writeByte(topicLenMSB); + writeByte(topicLenLSB); + writeBytes(publish.topic.c_str(), publish.topic.length()); + + writeBytes(publish.payload.c_str(), publish.payload.length()); + + // TODO: untested. May be unnecessary, because a received packet can be resent just like that. +} + void MqttPacket::handle(std::shared_ptr &subscriptionStore) { + if (packetType != PacketType::CONNECT) + { + if (!sender->getAuthenticated()) + { + throw ProtocolError("Non-connect packet from non-authenticated client."); + } + } + if (packetType == PacketType::CONNECT) handleConnect(); else if (packetType == PacketType::PINGREQ) sender->writePingResp(); else if (packetType == PacketType::SUBSCRIBE) handleSubscribe(subscriptionStore); + else if (packetType == PacketType::PUBLISH) + handlePublish(subscriptionStore); } void MqttPacket::handleConnect() @@ -123,6 +152,7 @@ void MqttPacket::handleConnect() // TODO: validate UTF8 encoded username/password. sender->setClientProperties(client_id, username, true, keep_alive); + sender->setAuthenticated(true); std::cout << "Connect: " << sender->repr() << std::endl; @@ -159,6 +189,30 @@ void MqttPacket::handleSubscribe(std::shared_ptr &subscriptio sender->writeBufIntoFd(); } +void MqttPacket::handlePublish(std::shared_ptr &subscriptionStore) +{ + uint16_t variable_header_length = readTwoBytesToUInt16(); + + if (variable_header_length == 0) + throw ProtocolError("Empty publish topic"); + + char qos = (first_byte & 0b00000110) >> 1; + + std::string topic(readBytes(variable_header_length), variable_header_length); + + if (qos) + { + throw ProtocolError("Qos not implemented."); + //uint16_t packet_id = readTwoBytesToUInt16(); + } + + // TODO: validate UTF8. + size_t payload_length = remainingAfterPos(); + std::string payload(readBytes(payload_length), payload_length); + + subscriptionStore->queueAtClientsTemp(topic, *this, sender); +} + Client_p MqttPacket::getSender() const { @@ -197,6 +251,14 @@ void MqttPacket::writeByte(char b) bites[pos++] = b; } +void MqttPacket::writeBytes(const char *b, size_t len) +{ + if (pos + len > bites.size()) + throw ProtocolError("Exceeding packet size"); + + memcpy(&bites[pos], b, len); +} + uint16_t MqttPacket::readTwoBytesToUInt16() { if (pos + 2 > bites.size()) diff --git a/mqttpacket.h b/mqttpacket.h index 3dfc20b..f4b48e7 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -19,28 +19,34 @@ class MqttPacket std::vector bites; size_t fixed_header_length = 0; Client_p sender; + char first_byte; size_t pos = 0; ProtocolVersion protocolVersion = ProtocolVersion::None; char *readBytes(size_t length); char readByte(); void writeByte(char b); + void writeBytes(const char *b, size_t len); uint16_t readTwoBytesToUInt16(); size_t remainingAfterPos(); public: PacketType packetType = PacketType::Reserved; MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); + + // TODO: not constructors, but static functions that return all the stuff after the fixed header, then a constructor with vector. MqttPacket(const ConnAck &connAck); MqttPacket(const SubAck &subAck); + MqttPacket(const Publish &publish); void handle(std::shared_ptr &subscriptionStore); void handleConnect(); void handleSubscribe(std::shared_ptr &subscriptionStore); void handlePing(); + void handlePublish(std::shared_ptr &subscriptionStore); - size_t getSize() { return bites.size(); } - const std::vector &getBites() { return bites; } + size_t getSize() const { return bites.size(); } + const std::vector &getBites() const { return bites; } Client_p getSender() const; void setSender(const Client_p &value); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 8c2e6b9..fd25cd2 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -9,3 +9,20 @@ void SubscriptionStore::addSubscription(Client_p &client, std::string &topic) { this->subscriptions[topic].push_back(client); } + +void SubscriptionStore::queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender) +{ + for(Client_p &client : subscriptions[topic]) + { + if (client->getThreadData()->threadnr == sender->getThreadData()->threadnr) + { + client->writeMqttPacket(packet); + client->writeBufIntoFd(); + } + else + { + client->queueMessage(packet); + client->getThreadData()->addToReadyForDequeuing(client); + } + } +} diff --git a/subscriptionstore.h b/subscriptionstore.h index 94ae25c..a661e36 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -18,6 +18,8 @@ public: // work with read copies intead of mutex/lock over the central store void getReadCopy(); // TODO + + void queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender); }; #endif // SUBSCRIPTIONSTORE_H diff --git a/threaddata.cpp b/threaddata.cpp index 93aa66a..8e77fea 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -35,3 +35,21 @@ std::shared_ptr &ThreadData::getSubscriptionStore() { return subscriptionStore; } + +void ThreadData::wakeUpThread() +{ + uint64_t one = 1; + write(event_fd, &one, sizeof(uint64_t)); +} + +void ThreadData::addToReadyForDequeuing(Client_p &client) +{ + this->readyForDequeueing.insert(client); +} + +void ThreadData::clearReadyForDequeueing() +{ + this->readyForDequeueing.clear(); +} + + diff --git a/threaddata.h b/threaddata.h index 35af923..15656f2 100644 --- a/threaddata.h +++ b/threaddata.h @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include "forward_declarations.h" @@ -17,8 +19,9 @@ class ThreadData { - std::map clients_by_fd; + std::unordered_map clients_by_fd; std::shared_ptr subscriptionStore; + std::unordered_set readyForDequeueing; public: std::thread thread; @@ -32,6 +35,10 @@ public: Client_p getClient(int fd); void removeClient(Client_p client); std::shared_ptr &getSubscriptionStore(); + void wakeUpThread(); + void addToReadyForDequeuing(Client_p &client); + std::unordered_set &getReadyForDequeueing() { return readyForDequeueing; } + void clearReadyForDequeueing(); }; #endif // THREADDATA_H diff --git a/types.cpp b/types.cpp index 8911257..0411337 100644 --- a/types.cpp +++ b/types.cpp @@ -15,3 +15,10 @@ SubAck::SubAck(uint16_t packet_id, const std::list &subs) : responses.push_back(SubAckReturnCodes::MaxQoS0); } } + +Publish::Publish(std::string &topic, std::string payload) : + topic(topic), + payload(payload) +{ + +} diff --git a/types.h b/types.h index 9432a90..f427173 100644 --- a/types.h +++ b/types.h @@ -66,4 +66,12 @@ public: SubAck(uint16_t packet_id, const std::list &subs); }; +class Publish +{ +public: + std::string topic; + std::string payload; + Publish(std::string &topic, std::string payload); +}; + #endif // TYPES_H -- libgit2 0.21.4