Commit 2a3138f9c422ddeea1055d4773e84cf437b687b7

Authored by Wiebe Cazemier
1 parent fec1aa82

Reorder stuff for storing subscription store

CMakeLists.txt
@@ -14,6 +14,8 @@ add_executable(FlashMQ @@ -14,6 +14,8 @@ add_executable(FlashMQ
14 bytestopacketparser.cpp 14 bytestopacketparser.cpp
15 mqttpacket.cpp 15 mqttpacket.cpp
16 exceptions.cpp 16 exceptions.cpp
17 - types.cpp) 17 + types.cpp
  18 + subscriptionstore.cpp
  19 + )
18 20
19 target_link_libraries(FlashMQ pthread) 21 target_link_libraries(FlashMQ pthread)
client.cpp
@@ -115,7 +115,7 @@ std::string Client::repr() @@ -115,7 +115,7 @@ std::string Client::repr()
115 return a.str(); 115 return a.str();
116 } 116 }
117 117
118 -bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn) 118 +bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender)
119 { 119 {
120 while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) 120 while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH)
121 { 121 {
@@ -144,7 +144,7 @@ bool Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn) @@ -144,7 +144,7 @@ bool Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn)
144 144
145 if (packet_length <= getReadBufBytesUsed()) 145 if (packet_length <= getReadBufBytesUsed())
146 { 146 {
147 - MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this); 147 + MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, sender);
148 packetQueueIn.push_back(std::move(packet)); 148 packetQueueIn.push_back(std::move(packet));
149 149
150 ri += packet_length; 150 ri += packet_length;
client.h
@@ -5,6 +5,8 @@ @@ -5,6 +5,8 @@
5 #include <unistd.h> 5 #include <unistd.h>
6 #include <vector> 6 #include <vector>
7 7
  8 +#include "forward_declarations.h"
  9 +
8 #include "threaddata.h" 10 #include "threaddata.h"
9 #include "mqttpacket.h" 11 #include "mqttpacket.h"
10 #include "exceptions.h" 12 #include "exceptions.h"
@@ -12,11 +14,6 @@ @@ -12,11 +14,6 @@
12 #define CLIENT_BUFFER_SIZE 1024 14 #define CLIENT_BUFFER_SIZE 1024
13 #define MQTT_HEADER_LENGH 2 15 #define MQTT_HEADER_LENGH 2
14 16
15 -class ThreadData;  
16 -typedef std::shared_ptr<ThreadData> ThreadData_p;  
17 -  
18 -class MqttPacket;  
19 -  
20 class Client 17 class Client
21 { 18 {
22 int fd; 19 int fd;
@@ -90,7 +87,7 @@ public: @@ -90,7 +87,7 @@ public:
90 87
91 int getFd() { return fd;} 88 int getFd() { return fd;}
92 bool readFdIntoBuffer(); 89 bool readFdIntoBuffer();
93 - bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn); 90 + bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender);
94 void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive); 91 void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive);
95 void setAuthenticated(bool value) { authenticated = value;} 92 void setAuthenticated(bool value) { authenticated = value;}
96 bool getAuthenticated() { return authenticated; } 93 bool getAuthenticated() { return authenticated; }
@@ -104,6 +101,4 @@ public: @@ -104,6 +101,4 @@ public:
104 101
105 }; 102 };
106 103
107 -typedef std::shared_ptr<Client> Client_p;  
108 -  
109 #endif // CLIENT_H 104 #endif // CLIENT_H
forward_declarations.h 0 → 100644
  1 +#ifndef FORWARD_DECLARATIONS_H
  2 +#define FORWARD_DECLARATIONS_H
  3 +
  4 +#include <memory>
  5 +
  6 +class Client;
  7 +typedef std::shared_ptr<Client> Client_p;
  8 +class ThreadData;
  9 +typedef std::shared_ptr<ThreadData> ThreadData_p;
  10 +class MqttPacket;
  11 +class SubscriptionStore;
  12 +
  13 +
  14 +#endif // FORWARD_DECLARATIONS_H
main.cpp
@@ -7,10 +7,12 @@ @@ -7,10 +7,12 @@
7 #include <thread> 7 #include <thread>
8 #include <vector> 8 #include <vector>
9 9
  10 +
10 #include "utils.h" 11 #include "utils.h"
11 #include "threaddata.h" 12 #include "threaddata.h"
12 #include "client.h" 13 #include "client.h"
13 #include "mqttpacket.h" 14 #include "mqttpacket.h"
  15 +#include "subscriptionstore.h"
14 16
15 #define MAX_EVENTS 1024 17 #define MAX_EVENTS 1024
16 #define NR_OF_THREADS 4 18 #define NR_OF_THREADS 4
@@ -50,7 +52,7 @@ void do_thread_work(ThreadData *threadData) @@ -50,7 +52,7 @@ void do_thread_work(ThreadData *threadData)
50 } 52 }
51 else 53 else
52 { 54 {
53 - client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer. 55 + client->bufferToMqttPackets(packetQueueIn, client);
54 } 56 }
55 } 57 }
56 if (cur_ev.events & EPOLLOUT) 58 if (cur_ev.events & EPOLLOUT)
@@ -64,7 +66,7 @@ void do_thread_work(ThreadData *threadData) @@ -64,7 +66,7 @@ void do_thread_work(ThreadData *threadData)
64 66
65 for (MqttPacket &packet : packetQueueIn) 67 for (MqttPacket &packet : packetQueueIn)
66 { 68 {
67 - packet.handle(); 69 + packet.handle(threadData->getSubscriptionStore());
68 } 70 }
69 packetQueueIn.clear(); 71 packetQueueIn.clear();
70 } 72 }
@@ -100,11 +102,13 @@ int main() @@ -100,11 +102,13 @@ int main()
100 ev.events = EPOLLIN; 102 ev.events = EPOLLIN;
101 check<std::runtime_error>(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, listen_fd, &ev)); 103 check<std::runtime_error>(epoll_ctl(epoll_fd_accept, EPOLL_CTL_ADD, listen_fd, &ev));
102 104
  105 + std::shared_ptr<SubscriptionStore> subscriptionStore(new SubscriptionStore());
  106 +
103 std::vector<std::shared_ptr<ThreadData>> threads; 107 std::vector<std::shared_ptr<ThreadData>> threads;
104 108
105 for (int i = 0; i < NR_OF_THREADS; i++) 109 for (int i = 0; i < NR_OF_THREADS; i++)
106 { 110 {
107 - std::shared_ptr<ThreadData> t(new ThreadData(i)); 111 + std::shared_ptr<ThreadData> t(new ThreadData(i, subscriptionStore));
108 std::thread thread(do_thread_work, t.get()); 112 std::thread thread(do_thread_work, t.get());
109 t->thread = std::move(thread); 113 t->thread = std::move(thread);
110 threads.push_back(t); 114 threads.push_back(t);
mqttpacket.cpp
@@ -3,7 +3,7 @@ @@ -3,7 +3,7 @@
3 #include <iostream> 3 #include <iostream>
4 #include <list> 4 #include <list>
5 5
6 -MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) : 6 +MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) :
7 bites(len), 7 bites(len),
8 fixed_header_length(fixed_header_length), 8 fixed_header_length(fixed_header_length),
9 sender(sender) 9 sender(sender)
@@ -46,14 +46,14 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) : @@ -46,14 +46,14 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) :
46 bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length 46 bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length
47 } 47 }
48 48
49 -void MqttPacket::handle() 49 +void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore)
50 { 50 {
51 if (packetType == PacketType::CONNECT) 51 if (packetType == PacketType::CONNECT)
52 handleConnect(); 52 handleConnect();
53 else if (packetType == PacketType::PINGREQ) 53 else if (packetType == PacketType::PINGREQ)
54 sender->writePingResp(); 54 sender->writePingResp();
55 else if (packetType == PacketType::SUBSCRIBE) 55 else if (packetType == PacketType::SUBSCRIBE)
56 - handleSubscribe(); 56 + handleSubscribe(subscriptionStore);
57 } 57 }
58 58
59 void MqttPacket::handleConnect() 59 void MqttPacket::handleConnect()
@@ -137,7 +137,7 @@ void MqttPacket::handleConnect() @@ -137,7 +137,7 @@ void MqttPacket::handleConnect()
137 } 137 }
138 } 138 }
139 139
140 -void MqttPacket::handleSubscribe() 140 +void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore)
141 { 141 {
142 uint16_t packet_id = readTwoBytesToUInt16(); 142 uint16_t packet_id = readTwoBytesToUInt16();
143 143
@@ -147,6 +147,7 @@ void MqttPacket::handleSubscribe() @@ -147,6 +147,7 @@ void MqttPacket::handleSubscribe()
147 uint16_t topicLength = readTwoBytesToUInt16(); 147 uint16_t topicLength = readTwoBytesToUInt16();
148 std::string topic(readBytes(topicLength), topicLength); 148 std::string topic(readBytes(topicLength), topicLength);
149 std::cout << sender->repr() << " Subscribed to " << topic << std::endl; 149 std::cout << sender->repr() << " Subscribed to " << topic << std::endl;
  150 + subscriptionStore->addSubscription(sender, topic);
150 subs.push_back(std::move(topic)); 151 subs.push_back(std::move(topic));
151 } 152 }
152 153
@@ -159,6 +160,16 @@ void MqttPacket::handleSubscribe() @@ -159,6 +160,16 @@ void MqttPacket::handleSubscribe()
159 } 160 }
160 161
161 162
  163 +Client_p MqttPacket::getSender() const
  164 +{
  165 + return sender;
  166 +}
  167 +
  168 +void MqttPacket::setSender(const Client_p &value)
  169 +{
  170 + sender = value;
  171 +}
  172 +
162 char *MqttPacket::readBytes(size_t length) 173 char *MqttPacket::readBytes(size_t length)
163 { 174 {
164 if (pos + length > bites.size()) 175 if (pos + length > bites.size())
mqttpacket.h
@@ -6,18 +6,19 @@ @@ -6,18 +6,19 @@
6 #include <vector> 6 #include <vector>
7 #include <exception> 7 #include <exception>
8 8
  9 +#include "forward_declarations.h"
  10 +
9 #include "client.h" 11 #include "client.h"
10 #include "exceptions.h" 12 #include "exceptions.h"
11 #include "types.h" 13 #include "types.h"
12 -  
13 -class Client; 14 +#include "subscriptionstore.h"
14 15
15 16
16 class MqttPacket 17 class MqttPacket
17 { 18 {
18 std::vector<char> bites; 19 std::vector<char> bites;
19 size_t fixed_header_length = 0; 20 size_t fixed_header_length = 0;
20 - Client *sender; 21 + Client_p sender;
21 size_t pos = 0; 22 size_t pos = 0;
22 ProtocolVersion protocolVersion = ProtocolVersion::None; 23 ProtocolVersion protocolVersion = ProtocolVersion::None;
23 24
@@ -29,18 +30,20 @@ class MqttPacket @@ -29,18 +30,20 @@ class MqttPacket
29 30
30 public: 31 public:
31 PacketType packetType = PacketType::Reserved; 32 PacketType packetType = PacketType::Reserved;
32 - MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender); 33 + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender);
33 MqttPacket(const ConnAck &connAck); 34 MqttPacket(const ConnAck &connAck);
34 MqttPacket(const SubAck &subAck); 35 MqttPacket(const SubAck &subAck);
35 36
36 - void handle(); 37 + void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore);
37 void handleConnect(); 38 void handleConnect();
38 - void handleSubscribe(); 39 + void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore);
39 void handlePing(); 40 void handlePing();
40 41
41 size_t getSize() { return bites.size(); } 42 size_t getSize() { return bites.size(); }
42 const std::vector<char> &getBites() { return bites; } 43 const std::vector<char> &getBites() { return bites; }
43 44
  45 + Client_p getSender() const;
  46 + void setSender(const Client_p &value);
44 }; 47 };
45 48
46 #endif // MQTTPACKET_H 49 #endif // MQTTPACKET_H
subscriptionstore.cpp 0 → 100644
  1 +#include "subscriptionstore.h"
  2 +
  3 +SubscriptionStore::SubscriptionStore()
  4 +{
  5 +
  6 +}
  7 +
  8 +void SubscriptionStore::addSubscription(Client_p &client, std::string &topic)
  9 +{
  10 + this->subscriptions[topic].push_back(client);
  11 +}
subscriptionstore.h 0 → 100644
  1 +#ifndef SUBSCRIPTIONSTORE_H
  2 +#define SUBSCRIPTIONSTORE_H
  3 +
  4 +#include <unordered_map>
  5 +#include <list>
  6 +
  7 +#include "forward_declarations.h"
  8 +
  9 +#include "client.h"
  10 +
  11 +class SubscriptionStore
  12 +{
  13 + std::unordered_map<std::string, std::list<Client_p>> subscriptions;
  14 +public:
  15 + SubscriptionStore();
  16 +
  17 + void addSubscription(Client_p &client, std::string &topic);
  18 +
  19 + // work with read copies intead of mutex/lock over the central store
  20 + void getReadCopy(); // TODO
  21 +};
  22 +
  23 +#endif // SUBSCRIPTIONSTORE_H
threaddata.cpp
1 #include "threaddata.h" 1 #include "threaddata.h"
2 2
3 3
4 -ThreadData::ThreadData(int threadnr) : 4 +ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore) :
  5 + subscriptionStore(subscriptionStore),
5 threadnr(threadnr) 6 threadnr(threadnr)
6 { 7 {
7 epollfd = check<std::runtime_error>(epoll_create(999)); 8 epollfd = check<std::runtime_error>(epoll_create(999));
@@ -29,3 +30,8 @@ void ThreadData::removeClient(Client_p client) @@ -29,3 +30,8 @@ void ThreadData::removeClient(Client_p client)
29 { 30 {
30 clients_by_fd.erase(client->getFd()); 31 clients_by_fd.erase(client->getFd());
31 } 32 }
  33 +
  34 +std::shared_ptr<SubscriptionStore> &ThreadData::getSubscriptionStore()
  35 +{
  36 + return subscriptionStore;
  37 +}
threaddata.h
@@ -2,18 +2,23 @@ @@ -2,18 +2,23 @@
2 #define THREADDATA_H 2 #define THREADDATA_H
3 3
4 #include <thread> 4 #include <thread>
5 -#include "utils.h" 5 +
6 #include <sys/epoll.h> 6 #include <sys/epoll.h>
7 #include <sys/eventfd.h> 7 #include <sys/eventfd.h>
8 -#include "client.h"  
9 #include <map> 8 #include <map>
10 9
11 -class Client;  
12 -typedef std::shared_ptr<Client> Client_p; 10 +#include "forward_declarations.h"
  11 +
  12 +#include "client.h"
  13 +#include "subscriptionstore.h"
  14 +#include "utils.h"
  15 +
  16 +
13 17
14 class ThreadData 18 class ThreadData
15 { 19 {
16 std::map<int, Client_p> clients_by_fd; 20 std::map<int, Client_p> clients_by_fd;
  21 + std::shared_ptr<SubscriptionStore> subscriptionStore;
17 22
18 public: 23 public:
19 std::thread thread; 24 std::thread thread;
@@ -21,13 +26,12 @@ public: @@ -21,13 +26,12 @@ public:
21 int epollfd = 0; 26 int epollfd = 0;
22 int event_fd = 0; 27 int event_fd = 0;
23 28
24 - ThreadData(int threadnr); 29 + ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore);
25 30
26 void giveClient(Client_p client); 31 void giveClient(Client_p client);
27 Client_p getClient(int fd); 32 Client_p getClient(int fd);
28 void removeClient(Client_p client); 33 void removeClient(Client_p client);
  34 + std::shared_ptr<SubscriptionStore> &getSubscriptionStore();
29 }; 35 };
30 36
31 -typedef std::shared_ptr<ThreadData> ThreadData_p;  
32 -  
33 #endif // THREADDATA_H 37 #endif // THREADDATA_H