From 2a3138f9c422ddeea1055d4773e84cf437b687b7 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Thu, 10 Dec 2020 15:35:14 +0100 Subject: [PATCH] Reorder stuff for storing subscription store --- CMakeLists.txt | 4 +++- client.cpp | 4 ++-- client.h | 11 +++-------- forward_declarations.h | 14 ++++++++++++++ main.cpp | 10 +++++++--- mqttpacket.cpp | 19 +++++++++++++++---- mqttpacket.h | 15 +++++++++------ subscriptionstore.cpp | 11 +++++++++++ subscriptionstore.h | 23 +++++++++++++++++++++++ threaddata.cpp | 8 +++++++- threaddata.h | 18 +++++++++++------- 11 files changed, 105 insertions(+), 32 deletions(-) create mode 100644 forward_declarations.h create mode 100644 subscriptionstore.cpp create mode 100644 subscriptionstore.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 477563e..444838a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,8 @@ add_executable(FlashMQ bytestopacketparser.cpp mqttpacket.cpp exceptions.cpp - types.cpp) + types.cpp + subscriptionstore.cpp + ) target_link_libraries(FlashMQ pthread) diff --git a/client.cpp b/client.cpp index f0d5873..ee7bd2f 100644 --- a/client.cpp +++ b/client.cpp @@ -115,7 +115,7 @@ std::string Client::repr() return a.str(); } -bool Client::bufferToMqttPackets(std::vector &packetQueueIn) +bool Client::bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender) { while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) { @@ -144,7 +144,7 @@ bool Client::bufferToMqttPackets(std::vector &packetQueueIn) if (packet_length <= getReadBufBytesUsed()) { - MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this); + MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, sender); packetQueueIn.push_back(std::move(packet)); ri += packet_length; diff --git a/client.h b/client.h index 51e221b..de734d9 100644 --- a/client.h +++ b/client.h @@ -5,6 +5,8 @@ #include #include +#include "forward_declarations.h" + #include "threaddata.h" #include "mqttpacket.h" #include "exceptions.h" @@ -12,11 +14,6 @@ #define CLIENT_BUFFER_SIZE 1024 #define MQTT_HEADER_LENGH 2 -class ThreadData; -typedef std::shared_ptr ThreadData_p; - -class MqttPacket; - class Client { int fd; @@ -90,7 +87,7 @@ public: int getFd() { return fd;} bool readFdIntoBuffer(); - bool bufferToMqttPackets(std::vector &packetQueueIn); + bool bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender); void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive); void setAuthenticated(bool value) { authenticated = value;} bool getAuthenticated() { return authenticated; } @@ -104,6 +101,4 @@ public: }; -typedef std::shared_ptr Client_p; - #endif // CLIENT_H diff --git a/forward_declarations.h b/forward_declarations.h new file mode 100644 index 0000000..c56d5a8 --- /dev/null +++ b/forward_declarations.h @@ -0,0 +1,14 @@ +#ifndef FORWARD_DECLARATIONS_H +#define FORWARD_DECLARATIONS_H + +#include + +class Client; +typedef std::shared_ptr Client_p; +class ThreadData; +typedef std::shared_ptr ThreadData_p; +class MqttPacket; +class SubscriptionStore; + + +#endif // FORWARD_DECLARATIONS_H diff --git a/main.cpp b/main.cpp index 7edeb56..e0c4ac7 100644 --- a/main.cpp +++ b/main.cpp @@ -7,10 +7,12 @@ #include #include + #include "utils.h" #include "threaddata.h" #include "client.h" #include "mqttpacket.h" +#include "subscriptionstore.h" #define MAX_EVENTS 1024 #define NR_OF_THREADS 4 @@ -50,7 +52,7 @@ void do_thread_work(ThreadData *threadData) } else { - client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer. + client->bufferToMqttPackets(packetQueueIn, client); } } if (cur_ev.events & EPOLLOUT) @@ -64,7 +66,7 @@ void do_thread_work(ThreadData *threadData) for (MqttPacket &packet : packetQueueIn) { - packet.handle(); + packet.handle(threadData->getSubscriptionStore()); } packetQueueIn.clear(); } @@ -100,11 +102,13 @@ int main() ev.events = EPOLLIN; check(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, listen_fd, &ev)); + std::shared_ptr subscriptionStore(new SubscriptionStore()); + std::vector> threads; for (int i = 0; i < NR_OF_THREADS; i++) { - std::shared_ptr t(new ThreadData(i)); + std::shared_ptr t(new ThreadData(i, subscriptionStore)); std::thread thread(do_thread_work, t.get()); t->thread = std::move(thread); threads.push_back(t); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 5dbce12..58eecb9 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -3,7 +3,7 @@ #include #include -MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) : +MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) : bites(len), fixed_header_length(fixed_header_length), sender(sender) @@ -46,14 +46,14 @@ MqttPacket::MqttPacket(const SubAck &subAck) : bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length } -void MqttPacket::handle() +void MqttPacket::handle(std::shared_ptr &subscriptionStore) { if (packetType == PacketType::CONNECT) handleConnect(); else if (packetType == PacketType::PINGREQ) sender->writePingResp(); else if (packetType == PacketType::SUBSCRIBE) - handleSubscribe(); + handleSubscribe(subscriptionStore); } void MqttPacket::handleConnect() @@ -137,7 +137,7 @@ void MqttPacket::handleConnect() } } -void MqttPacket::handleSubscribe() +void MqttPacket::handleSubscribe(std::shared_ptr &subscriptionStore) { uint16_t packet_id = readTwoBytesToUInt16(); @@ -147,6 +147,7 @@ void MqttPacket::handleSubscribe() uint16_t topicLength = readTwoBytesToUInt16(); std::string topic(readBytes(topicLength), topicLength); std::cout << sender->repr() << " Subscribed to " << topic << std::endl; + subscriptionStore->addSubscription(sender, topic); subs.push_back(std::move(topic)); } @@ -159,6 +160,16 @@ void MqttPacket::handleSubscribe() } +Client_p MqttPacket::getSender() const +{ + return sender; +} + +void MqttPacket::setSender(const Client_p &value) +{ + sender = value; +} + char *MqttPacket::readBytes(size_t length) { if (pos + length > bites.size()) diff --git a/mqttpacket.h b/mqttpacket.h index ac2fed8..3dfc20b 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -6,18 +6,19 @@ #include #include +#include "forward_declarations.h" + #include "client.h" #include "exceptions.h" #include "types.h" - -class Client; +#include "subscriptionstore.h" class MqttPacket { std::vector bites; size_t fixed_header_length = 0; - Client *sender; + Client_p sender; size_t pos = 0; ProtocolVersion protocolVersion = ProtocolVersion::None; @@ -29,18 +30,20 @@ class MqttPacket public: PacketType packetType = PacketType::Reserved; - MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender); + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); MqttPacket(const ConnAck &connAck); MqttPacket(const SubAck &subAck); - void handle(); + void handle(std::shared_ptr &subscriptionStore); void handleConnect(); - void handleSubscribe(); + void handleSubscribe(std::shared_ptr &subscriptionStore); void handlePing(); size_t getSize() { return bites.size(); } const std::vector &getBites() { return bites; } + Client_p getSender() const; + void setSender(const Client_p &value); }; #endif // MQTTPACKET_H diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp new file mode 100644 index 0000000..8c2e6b9 --- /dev/null +++ b/subscriptionstore.cpp @@ -0,0 +1,11 @@ +#include "subscriptionstore.h" + +SubscriptionStore::SubscriptionStore() +{ + +} + +void SubscriptionStore::addSubscription(Client_p &client, std::string &topic) +{ + this->subscriptions[topic].push_back(client); +} diff --git a/subscriptionstore.h b/subscriptionstore.h new file mode 100644 index 0000000..94ae25c --- /dev/null +++ b/subscriptionstore.h @@ -0,0 +1,23 @@ +#ifndef SUBSCRIPTIONSTORE_H +#define SUBSCRIPTIONSTORE_H + +#include +#include + +#include "forward_declarations.h" + +#include "client.h" + +class SubscriptionStore +{ + std::unordered_map> subscriptions; +public: + SubscriptionStore(); + + void addSubscription(Client_p &client, std::string &topic); + + // work with read copies intead of mutex/lock over the central store + void getReadCopy(); // TODO +}; + +#endif // SUBSCRIPTIONSTORE_H diff --git a/threaddata.cpp b/threaddata.cpp index db9738b..93aa66a 100644 --- a/threaddata.cpp +++ b/threaddata.cpp @@ -1,7 +1,8 @@ #include "threaddata.h" -ThreadData::ThreadData(int threadnr) : +ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore) : + subscriptionStore(subscriptionStore), threadnr(threadnr) { epollfd = check(epoll_create(999)); @@ -29,3 +30,8 @@ void ThreadData::removeClient(Client_p client) { clients_by_fd.erase(client->getFd()); } + +std::shared_ptr &ThreadData::getSubscriptionStore() +{ + return subscriptionStore; +} diff --git a/threaddata.h b/threaddata.h index 8e3f771..35af923 100644 --- a/threaddata.h +++ b/threaddata.h @@ -2,18 +2,23 @@ #define THREADDATA_H #include -#include "utils.h" + #include #include -#include "client.h" #include -class Client; -typedef std::shared_ptr Client_p; +#include "forward_declarations.h" + +#include "client.h" +#include "subscriptionstore.h" +#include "utils.h" + + class ThreadData { std::map clients_by_fd; + std::shared_ptr subscriptionStore; public: std::thread thread; @@ -21,13 +26,12 @@ public: int epollfd = 0; int event_fd = 0; - ThreadData(int threadnr); + ThreadData(int threadnr, std::shared_ptr &subscriptionStore); void giveClient(Client_p client); Client_p getClient(int fd); void removeClient(Client_p client); + std::shared_ptr &getSubscriptionStore(); }; -typedef std::shared_ptr ThreadData_p; - #endif // THREADDATA_H -- libgit2 0.21.4