Commit 73ef26aefbc86fa9d4b73f60739d72a8fd6ff5f7

Authored by Wiebe Cazemier
1 parent 66cd78b3

IO stuff

client.cpp
@@ -16,12 +16,20 @@ Client::Client(int fd, ThreadData_p threadData) : @@ -16,12 +16,20 @@ Client::Client(int fd, ThreadData_p threadData) :
16 16
17 Client::~Client() 17 Client::~Client()
18 { 18 {
19 - epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL); // NOTE: the last NULL can cause crash on old kernels  
20 - close(fd); 19 + closeConnection();
21 free(readbuf); 20 free(readbuf);
22 free(writebuf); 21 free(writebuf);
23 } 22 }
24 23
  24 +void Client::closeConnection()
  25 +{
  26 + if (fd < 0)
  27 + return;
  28 + epoll_ctl(threadData->epollfd, EPOLL_CTL_DEL, fd, NULL);
  29 + close(fd);
  30 + fd = -1;
  31 +}
  32 +
25 // false means any kind of error we want to get rid of the client for. 33 // false means any kind of error we want to get rid of the client for.
26 bool Client::readFdIntoBuffer() 34 bool Client::readFdIntoBuffer()
27 { 35 {
@@ -59,11 +67,16 @@ bool Client::readFdIntoBuffer() @@ -59,11 +67,16 @@ bool Client::readFdIntoBuffer()
59 67
60 void Client::writeMqttPacket(const MqttPacket &packet) 68 void Client::writeMqttPacket(const MqttPacket &packet)
61 { 69 {
  70 + if (packet.packetType == PacketType::PUBLISH && getWriteBufBytesUsed() > CLIENT_MAX_BUFFER_SIZE)
  71 + return;
  72 +
62 if (packet.getSize() > getWriteBufMaxWriteSize()) 73 if (packet.getSize() > getWriteBufMaxWriteSize())
63 growWriteBuffer(packet.getSize()); 74 growWriteBuffer(packet.getSize());
64 75
65 std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getSize()); 76 std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getSize());
66 wwi += packet.getSize(); 77 wwi += packet.getSize();
  78 +
  79 + setReadyForWriting(true);
67 } 80 }
68 81
69 // Not sure if this is the method I want to use 82 // Not sure if this is the method I want to use
@@ -83,10 +96,11 @@ void Client::writePingResp() @@ -83,10 +96,11 @@ void Client::writePingResp()
83 96
84 writebuf[wwi++] = 0b11010000; 97 writebuf[wwi++] = 0b11010000;
85 writebuf[wwi++] = 0; 98 writebuf[wwi++] = 0;
86 - writeBufIntoFd(); 99 +
  100 + setReadyForWriting(true);
87 } 101 }
88 102
89 -bool Client::writeBufIntoFd() // TODO: ignore the signal BROKEN PIPE we now also get when a client disappears. 103 +bool Client::writeBufIntoFd()
90 { 104 {
91 int n; 105 int n;
92 while ((n = write(fd, &writebuf[wri], getWriteBufBytesUsed())) != 0) 106 while ((n = write(fd, &writebuf[wri], getWriteBufBytesUsed())) != 0)
@@ -108,6 +122,8 @@ bool Client::writeBufIntoFd() // TODO: ignore the signal BROKEN PIPE we now also @@ -108,6 +122,8 @@ bool Client::writeBufIntoFd() // TODO: ignore the signal BROKEN PIPE we now also
108 { 122 {
109 wri = 0; 123 wri = 0;
110 wwi = 0; 124 wwi = 0;
  125 +
  126 + setReadyForWriting(false);
111 } 127 }
112 128
113 return true; 129 return true;
@@ -134,6 +150,21 @@ void Client::queuedMessagesToBuffer() @@ -134,6 +150,21 @@ void Client::queuedMessagesToBuffer()
134 150
135 } 151 }
136 152
  153 +void Client::setReadyForWriting(bool val)
  154 +{
  155 + if (val == this->readyForWriting)
  156 + return;
  157 +
  158 + readyForWriting = val;
  159 + struct epoll_event ev;
  160 + memset(&ev, 0, sizeof (struct epoll_event));
  161 + ev.data.fd = fd;
  162 + ev.events = EPOLLIN;
  163 + if (val)
  164 + ev.events |= EPOLLOUT;
  165 + check<std::runtime_error>(epoll_ctl(threadData->epollfd, EPOLL_CTL_MOD, fd, &ev));
  166 +}
  167 +
137 bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender) 168 bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender)
138 { 169 {
139 while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) 170 while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH)
client.h
@@ -14,6 +14,7 @@ @@ -14,6 +14,7 @@
14 14
15 15
16 #define CLIENT_BUFFER_SIZE 1024 16 #define CLIENT_BUFFER_SIZE 1024
  17 +#define CLIENT_MAX_BUFFER_SIZE 1048576
17 #define MQTT_HEADER_LENGH 2 18 #define MQTT_HEADER_LENGH 2
18 19
19 class Client 20 class Client
@@ -32,6 +33,8 @@ class Client @@ -32,6 +33,8 @@ class Client
32 33
33 bool authenticated = false; 34 bool authenticated = false;
34 bool connectPacketSeen = false; 35 bool connectPacketSeen = false;
  36 + bool readyForWriting = false;
  37 +
35 std::string clientid; 38 std::string clientid;
36 std::string username; 39 std::string username;
37 uint16_t keepalive = 0; 40 uint16_t keepalive = 0;
@@ -84,11 +87,14 @@ class Client @@ -84,11 +87,14 @@ class Client
84 writeBufsize = newBufSize; 87 writeBufsize = newBufSize;
85 } 88 }
86 89
  90 +
  91 +
87 public: 92 public:
88 Client(int fd, ThreadData_p threadData); 93 Client(int fd, ThreadData_p threadData);
89 ~Client(); 94 ~Client();
90 95
91 int getFd() { return fd;} 96 int getFd() { return fd;}
  97 + void closeConnection();
92 bool readFdIntoBuffer(); 98 bool readFdIntoBuffer();
93 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender); 99 bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender);
94 void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive); 100 void setClientProperties(const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive);
@@ -96,6 +102,7 @@ public: @@ -96,6 +102,7 @@ public:
96 bool getAuthenticated() { return authenticated; } 102 bool getAuthenticated() { return authenticated; }
97 bool hasConnectPacketSeen() { return connectPacketSeen; } 103 bool hasConnectPacketSeen() { return connectPacketSeen; }
98 ThreadData_p getThreadData() { return threadData; } 104 ThreadData_p getThreadData() { return threadData; }
  105 + std::string &getClientId() { return this->clientid; }
99 106
100 void writePingResp(); 107 void writePingResp();
101 void writeMqttPacket(const MqttPacket &packet); 108 void writeMqttPacket(const MqttPacket &packet);
@@ -106,6 +113,8 @@ public: @@ -106,6 +113,8 @@ public:
106 113
107 void queueMessage(const MqttPacket &packet); 114 void queueMessage(const MqttPacket &packet);
108 void queuedMessagesToBuffer(); 115 void queuedMessagesToBuffer();
  116 +
  117 + void setReadyForWriting(bool val);
109 }; 118 };
110 119
111 #endif // CLIENT_H 120 #endif // CLIENT_H
mainapp.cpp
1 #include "mainapp.h" 1 #include "mainapp.h"
  2 +#include "cassert"
2 3
3 #define MAX_EVENTS 1024 4 #define MAX_EVENTS 1024
4 #define NR_OF_THREADS 4 5 #define NR_OF_THREADS 4
@@ -27,8 +28,6 @@ void do_thread_work(ThreadData *threadData) @@ -27,8 +28,6 @@ void do_thread_work(ThreadData *threadData)
27 eventfd_value = 0; 28 eventfd_value = 0;
28 } 29 }
29 30
30 - // TODO: do all the buftofd here, not spread out over  
31 -  
32 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); 31 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100);
33 32
34 if (fdcount > 0) 33 if (fdcount > 0)
@@ -66,6 +65,10 @@ void do_thread_work(ThreadData *threadData) @@ -66,6 +65,10 @@ void do_thread_work(ThreadData *threadData)
66 threadData->removeClient(client); 65 threadData->removeClient(client);
67 } 66 }
68 } 67 }
  68 + else
  69 + {
  70 + assert(false);
  71 + }
69 } 72 }
70 } 73 }
71 74
@@ -85,11 +88,11 @@ MainApp::MainApp() : @@ -85,11 +88,11 @@ MainApp::MainApp() :
85 88
86 void MainApp::start() 89 void MainApp::start()
87 { 90 {
88 - int listen_fd = socket(AF_INET, SOCK_STREAM, 0); 91 + int listen_fd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));
89 92
90 // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. 93 // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.
91 - //int optval = 1;  
92 - //check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); 94 + int optval = 1;
  95 + check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));
93 96
94 int flags = fcntl(listen_fd, F_GETFL); 97 int flags = fcntl(listen_fd, F_GETFL);
95 check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); 98 check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));
@@ -153,11 +156,13 @@ void MainApp::start() @@ -153,11 +156,13 @@ void MainApp::start()
153 156
154 } 157 }
155 } 158 }
  159 +
  160 + close(listen_fd);
156 } 161 }
157 162
158 void MainApp::quit() 163 void MainApp::quit()
159 { 164 {
160 - std::cout << "Quitting application" << std::endl; 165 + std::cout << "Quitting FlashMQ" << std::endl;
161 166
162 running = false; 167 running = false;
163 168
mqttpacket.cpp
@@ -159,7 +159,6 @@ void MqttPacket::handleConnect() @@ -159,7 +159,6 @@ void MqttPacket::handleConnect()
159 ConnAck connAck(ConnAckReturnCodes::Accepted); 159 ConnAck connAck(ConnAckReturnCodes::Accepted);
160 MqttPacket response(connAck); 160 MqttPacket response(connAck);
161 sender->writeMqttPacket(response); 161 sender->writeMqttPacket(response);
162 - sender->writeBufIntoFd();  
163 } 162 }
164 else 163 else
165 { 164 {
@@ -186,7 +185,6 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio @@ -186,7 +185,6 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio
186 SubAck subAck(packet_id, subs); 185 SubAck subAck(packet_id, subs);
187 MqttPacket response(subAck); 186 MqttPacket response(subAck);
188 sender->writeMqttPacket(response); 187 sender->writeMqttPacket(response);
189 - sender->writeBufIntoFd();  
190 } 188 }
191 189
192 void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore) 190 void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore)
subscriptionstore.cpp
@@ -8,7 +8,17 @@ SubscriptionStore::SubscriptionStore() @@ -8,7 +8,17 @@ SubscriptionStore::SubscriptionStore()
8 void SubscriptionStore::addSubscription(Client_p &client, std::string &topic) 8 void SubscriptionStore::addSubscription(Client_p &client, std::string &topic)
9 { 9 {
10 std::lock_guard<std::mutex> lock(subscriptionsMutex); 10 std::lock_guard<std::mutex> lock(subscriptionsMutex);
11 - this->subscriptions[topic].push_back(client); 11 + this->subscriptions[topic].insert(client);
  12 +}
  13 +
  14 +void SubscriptionStore::removeClient(const Client_p &client)
  15 +{
  16 + std::lock_guard<std::mutex> lock(subscriptionsMutex);
  17 + for(std::pair<const std::string, std::unordered_set<Client_p>> &pair : subscriptions)
  18 + {
  19 + std::unordered_set<Client_p> &bla = pair.second;
  20 + bla.erase(client);
  21 + }
12 } 22 }
13 23
14 void SubscriptionStore::queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender) 24 void SubscriptionStore::queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender)
@@ -16,18 +26,19 @@ void SubscriptionStore::queueAtClientsTemp(std::string &amp;topic, const MqttPacket @@ -16,18 +26,19 @@ void SubscriptionStore::queueAtClientsTemp(std::string &amp;topic, const MqttPacket
16 // TODO: temp. I want to work with read copies of the subscription store, to avoid frequent lock contention. 26 // TODO: temp. I want to work with read copies of the subscription store, to avoid frequent lock contention.
17 std::lock_guard<std::mutex> lock(subscriptionsMutex); 27 std::lock_guard<std::mutex> lock(subscriptionsMutex);
18 28
19 - for(Client_p &client : subscriptions[topic]) 29 + for(const Client_p &client : subscriptions[topic])
20 { 30 {
21 if (client->getThreadData()->threadnr == sender->getThreadData()->threadnr) 31 if (client->getThreadData()->threadnr == sender->getThreadData()->threadnr)
22 { 32 {
23 client->writeMqttPacket(packet); // TODO: with my current hack way, this is wrong. Not using a lock only works with my previous idea of queueing. 33 client->writeMqttPacket(packet); // TODO: with my current hack way, this is wrong. Not using a lock only works with my previous idea of queueing.
24 - client->writeBufIntoFd();  
25 } 34 }
26 else 35 else
27 { 36 {
28 - client->writeMqttPacketLocked(packet);  
29 - client->getThreadData()->addToReadyForDequeuing(client);  
30 - client->getThreadData()->wakeUpThread(); 37 + // Or keep a list of queued messages in the store, per client?
  38 +
  39 + //client->writeMqttPacketLocked(packet);
  40 + //client->getThreadData()->addToReadyForDequeuing(client);
  41 + //client->getThreadData()->wakeUpThread();
31 } 42 }
32 } 43 }
33 } 44 }
subscriptionstore.h
@@ -11,12 +11,13 @@ @@ -11,12 +11,13 @@
11 11
12 class SubscriptionStore 12 class SubscriptionStore
13 { 13 {
14 - std::unordered_map<std::string, std::list<Client_p>> subscriptions; 14 + std::unordered_map<std::string, std::unordered_set<Client_p>> subscriptions;
15 std::mutex subscriptionsMutex; 15 std::mutex subscriptionsMutex;
16 public: 16 public:
17 SubscriptionStore(); 17 SubscriptionStore();
18 18
19 void addSubscription(Client_p &client, std::string &topic); 19 void addSubscription(Client_p &client, std::string &topic);
  20 + void removeClient(const Client_p &client);
20 21
21 // work with read copies intead of mutex/lock over the central store 22 // work with read copies intead of mutex/lock over the central store
22 void getReadCopy(); // TODO 23 void getReadCopy(); // TODO
threaddata.cpp
1 #include "threaddata.h" 1 #include "threaddata.h"
2 2
3 -  
4 ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore) : 3 ThreadData::ThreadData(int threadnr, std::shared_ptr<SubscriptionStore> &subscriptionStore) :
5 subscriptionStore(subscriptionStore), 4 subscriptionStore(subscriptionStore),
6 threadnr(threadnr) 5 threadnr(threadnr)
@@ -23,24 +22,31 @@ void ThreadData::quit() @@ -23,24 +22,31 @@ void ThreadData::quit()
23 22
24 void ThreadData::giveClient(Client_p client) 23 void ThreadData::giveClient(Client_p client)
25 { 24 {
  25 + clients_by_fd_mutex.lock();
26 int fd = client->getFd(); 26 int fd = client->getFd();
  27 + clients_by_fd[fd] = client;
  28 + clients_by_fd_mutex.unlock();
  29 +
27 struct epoll_event ev; 30 struct epoll_event ev;
28 memset(&ev, 0, sizeof (struct epoll_event)); 31 memset(&ev, 0, sizeof (struct epoll_event));
29 ev.data.fd = fd; 32 ev.data.fd = fd;
30 ev.events = EPOLLIN; 33 ev.events = EPOLLIN;
31 check<std::runtime_error>(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev)); 34 check<std::runtime_error>(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev));
32 -  
33 - clients_by_fd[fd] = client;  
34 } 35 }
35 36
36 Client_p ThreadData::getClient(int fd) 37 Client_p ThreadData::getClient(int fd)
37 { 38 {
  39 + std::lock_guard<std::mutex> lck(clients_by_fd_mutex);
38 return this->clients_by_fd[fd]; 40 return this->clients_by_fd[fd];
39 } 41 }
40 42
41 void ThreadData::removeClient(Client_p client) 43 void ThreadData::removeClient(Client_p client)
42 { 44 {
  45 + client->closeConnection();
  46 + std::lock_guard<std::mutex> lck(clients_by_fd_mutex);
  47 + subscriptionStore->removeClient(client);
43 clients_by_fd.erase(client->getFd()); 48 clients_by_fd.erase(client->getFd());
  49 +
44 } 50 }
45 51
46 std::shared_ptr<SubscriptionStore> &ThreadData::getSubscriptionStore() 52 std::shared_ptr<SubscriptionStore> &ThreadData::getSubscriptionStore()
threaddata.h
@@ -9,6 +9,7 @@ @@ -9,6 +9,7 @@
9 #include <unordered_set> 9 #include <unordered_set>
10 #include <unordered_map> 10 #include <unordered_map>
11 #include <mutex> 11 #include <mutex>
  12 +#include <shared_mutex>
12 13
13 #include "forward_declarations.h" 14 #include "forward_declarations.h"
14 15
@@ -21,6 +22,7 @@ @@ -21,6 +22,7 @@
21 class ThreadData 22 class ThreadData
22 { 23 {
23 std::unordered_map<int, Client_p> clients_by_fd; 24 std::unordered_map<int, Client_p> clients_by_fd;
  25 + std::mutex clients_by_fd_mutex;
24 std::shared_ptr<SubscriptionStore> subscriptionStore; 26 std::shared_ptr<SubscriptionStore> subscriptionStore;
25 std::unordered_set<Client_p> readyForDequeueing; 27 std::unordered_set<Client_p> readyForDequeueing;
26 std::mutex readForDequeuingMutex; 28 std::mutex readForDequeuingMutex;