Commit 40f8e5e5f8eac75a2647d423826b52acf38bcb71

Authored by Wiebe Cazemier
1 parent a5b81d74

Publishing within the same thread works

Roughly...
client.cpp
@@ -57,7 +57,7 @@ bool Client::readFdIntoBuffer() @@ -57,7 +57,7 @@ bool Client::readFdIntoBuffer()
57 return true; 57 return true;
58 } 58 }
59 59
60 -void Client::writeMqttPacket(MqttPacket &packet) 60 +void Client::writeMqttPacket(const MqttPacket &packet)
61 { 61 {
62 if (packet.getSize() > getWriteBufMaxWriteSize()) 62 if (packet.getSize() > getWriteBufMaxWriteSize())
63 growWriteBuffer(packet.getSize()); 63 growWriteBuffer(packet.getSize());
@@ -115,6 +115,18 @@ std::string Client::repr() @@ -115,6 +115,18 @@ std::string Client::repr()
115 return a.str(); 115 return a.str();
116 } 116 }
117 117
  118 +void Client::queueMessage(const MqttPacket &packet)
  119 +{
  120 +
  121 +
  122 + // TODO: semaphores on stl containers?
  123 +}
  124 +
  125 +void Client::queuedMessagesToBuffer()
  126 +{
  127 +
  128 +}
  129 +
118 bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender) 130 bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender)
119 { 131 {
120 while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) 132 while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH)
client.h
@@ -92,13 +92,16 @@ public: @@ -92,13 +92,16 @@ public:
92 void setAuthenticated(bool value) { authenticated = value;} 92 void setAuthenticated(bool value) { authenticated = value;}
93 bool getAuthenticated() { return authenticated; } 93 bool getAuthenticated() { return authenticated; }
94 bool hasConnectPacketSeen() { return connectPacketSeen; } 94 bool hasConnectPacketSeen() { return connectPacketSeen; }
  95 + ThreadData_p getThreadData() { return threadData; }
95 96
96 void writePingResp(); 97 void writePingResp();
97 - void writeMqttPacket(MqttPacket &packet); 98 + void writeMqttPacket(const MqttPacket &packet);
98 bool writeBufIntoFd(); 99 bool writeBufIntoFd();
99 100
100 std::string repr(); 101 std::string repr();
101 102
  103 + void queueMessage(const MqttPacket &packet);
  104 + void queuedMessagesToBuffer();
102 }; 105 };
103 106
104 #endif // CLIENT_H 107 #endif // CLIENT_H
main.cpp
@@ -28,8 +28,22 @@ void do_thread_work(ThreadData *threadData) @@ -28,8 +28,22 @@ void do_thread_work(ThreadData *threadData)
28 28
29 std::vector<MqttPacket> packetQueueIn; 29 std::vector<MqttPacket> packetQueueIn;
30 30
  31 + uint64_t eventfd_value = 0;
  32 +
31 while (1) 33 while (1)
32 { 34 {
  35 + if (eventfd_value > 0)
  36 + {
  37 + for (Client_p client : threadData->getReadyForDequeueing())
  38 + {
  39 + client->queuedMessagesToBuffer();
  40 + }
  41 + threadData->clearReadyForDequeueing();
  42 + eventfd_value = 0;
  43 + }
  44 +
  45 + // TODO: do all the buftofd here, not spread out over
  46 +
33 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); 47 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100);
34 48
35 if (fdcount > 0) 49 if (fdcount > 0)
@@ -39,6 +53,13 @@ void do_thread_work(ThreadData *threadData) @@ -39,6 +53,13 @@ void do_thread_work(ThreadData *threadData)
39 struct epoll_event cur_ev = events[i]; 53 struct epoll_event cur_ev = events[i];
40 int fd = cur_ev.data.fd; 54 int fd = cur_ev.data.fd;
41 55
  56 + // If this thread was actively woken up.
  57 + if (fd == threadData->event_fd)
  58 + {
  59 + read(fd, &eventfd_value, sizeof(uint64_t));
  60 + continue;
  61 + }
  62 +
42 Client_p client = threadData->getClient(fd); 63 Client_p client = threadData->getClient(fd);
43 64
44 if (client) 65 if (client)
mqttpacket.cpp
@@ -3,16 +3,17 @@ @@ -3,16 +3,17 @@
3 #include <iostream> 3 #include <iostream>
4 #include <list> 4 #include <list>
5 5
  6 +// constructor for parsing incoming packets
6 MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) : 7 MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) :
7 bites(len), 8 bites(len),
8 fixed_header_length(fixed_header_length), 9 fixed_header_length(fixed_header_length),
9 sender(sender) 10 sender(sender)
10 { 11 {
11 - unsigned char _packetType = (buf[0] & 0xF0) >> 4; 12 + std::memcpy(&bites[0], buf, len);
  13 + first_byte = bites[0];
  14 + unsigned char _packetType = (first_byte & 0xF0) >> 4;
12 packetType = (PacketType)_packetType; 15 packetType = (PacketType)_packetType;
13 pos += fixed_header_length; 16 pos += fixed_header_length;
14 -  
15 - std::memcpy(&bites[0], buf, len);  
16 } 17 }
17 18
18 MqttPacket::MqttPacket(const ConnAck &connAck) : 19 MqttPacket::MqttPacket(const ConnAck &connAck) :
@@ -46,14 +47,42 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) : @@ -46,14 +47,42 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) :
46 bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length 47 bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length
47 } 48 }
48 49
  50 +MqttPacket::MqttPacket(const Publish &publish) :
  51 + bites(publish.topic.length() + publish.payload.length() + 2 + 2) // TODO: same as above
  52 +{
  53 + packetType = PacketType::PUBLISH;
  54 + char first_byte = static_cast<char>(packetType) << 4;
  55 + writeByte(first_byte);
  56 +
  57 + char topicLenMSB = (publish.topic.length() & 0xF0) >> 8;
  58 + char topicLenLSB = publish.topic.length() & 0x0F;
  59 + writeByte(topicLenMSB);
  60 + writeByte(topicLenLSB);
  61 + writeBytes(publish.topic.c_str(), publish.topic.length());
  62 +
  63 + writeBytes(publish.payload.c_str(), publish.payload.length());
  64 +
  65 + // TODO: untested. May be unnecessary, because a received packet can be resent just like that.
  66 +}
  67 +
49 void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) 68 void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore)
50 { 69 {
  70 + if (packetType != PacketType::CONNECT)
  71 + {
  72 + if (!sender->getAuthenticated())
  73 + {
  74 + throw ProtocolError("Non-connect packet from non-authenticated client.");
  75 + }
  76 + }
  77 +
51 if (packetType == PacketType::CONNECT) 78 if (packetType == PacketType::CONNECT)
52 handleConnect(); 79 handleConnect();
53 else if (packetType == PacketType::PINGREQ) 80 else if (packetType == PacketType::PINGREQ)
54 sender->writePingResp(); 81 sender->writePingResp();
55 else if (packetType == PacketType::SUBSCRIBE) 82 else if (packetType == PacketType::SUBSCRIBE)
56 handleSubscribe(subscriptionStore); 83 handleSubscribe(subscriptionStore);
  84 + else if (packetType == PacketType::PUBLISH)
  85 + handlePublish(subscriptionStore);
57 } 86 }
58 87
59 void MqttPacket::handleConnect() 88 void MqttPacket::handleConnect()
@@ -123,6 +152,7 @@ void MqttPacket::handleConnect() @@ -123,6 +152,7 @@ void MqttPacket::handleConnect()
123 // TODO: validate UTF8 encoded username/password. 152 // TODO: validate UTF8 encoded username/password.
124 153
125 sender->setClientProperties(client_id, username, true, keep_alive); 154 sender->setClientProperties(client_id, username, true, keep_alive);
  155 + sender->setAuthenticated(true);
126 156
127 std::cout << "Connect: " << sender->repr() << std::endl; 157 std::cout << "Connect: " << sender->repr() << std::endl;
128 158
@@ -159,6 +189,30 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio @@ -159,6 +189,30 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio
159 sender->writeBufIntoFd(); 189 sender->writeBufIntoFd();
160 } 190 }
161 191
  192 +void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore)
  193 +{
  194 + uint16_t variable_header_length = readTwoBytesToUInt16();
  195 +
  196 + if (variable_header_length == 0)
  197 + throw ProtocolError("Empty publish topic");
  198 +
  199 + char qos = (first_byte & 0b00000110) >> 1;
  200 +
  201 + std::string topic(readBytes(variable_header_length), variable_header_length);
  202 +
  203 + if (qos)
  204 + {
  205 + throw ProtocolError("Qos not implemented.");
  206 + //uint16_t packet_id = readTwoBytesToUInt16();
  207 + }
  208 +
  209 + // TODO: validate UTF8.
  210 + size_t payload_length = remainingAfterPos();
  211 + std::string payload(readBytes(payload_length), payload_length);
  212 +
  213 + subscriptionStore->queueAtClientsTemp(topic, *this, sender);
  214 +}
  215 +
162 216
163 Client_p MqttPacket::getSender() const 217 Client_p MqttPacket::getSender() const
164 { 218 {
@@ -197,6 +251,14 @@ void MqttPacket::writeByte(char b) @@ -197,6 +251,14 @@ void MqttPacket::writeByte(char b)
197 bites[pos++] = b; 251 bites[pos++] = b;
198 } 252 }
199 253
  254 +void MqttPacket::writeBytes(const char *b, size_t len)
  255 +{
  256 + if (pos + len > bites.size())
  257 + throw ProtocolError("Exceeding packet size");
  258 +
  259 + memcpy(&bites[pos], b, len);
  260 +}
  261 +
200 uint16_t MqttPacket::readTwoBytesToUInt16() 262 uint16_t MqttPacket::readTwoBytesToUInt16()
201 { 263 {
202 if (pos + 2 > bites.size()) 264 if (pos + 2 > bites.size())
mqttpacket.h
@@ -19,28 +19,34 @@ class MqttPacket @@ -19,28 +19,34 @@ class MqttPacket
19 std::vector<char> bites; 19 std::vector<char> bites;
20 size_t fixed_header_length = 0; 20 size_t fixed_header_length = 0;
21 Client_p sender; 21 Client_p sender;
  22 + char first_byte;
22 size_t pos = 0; 23 size_t pos = 0;
23 ProtocolVersion protocolVersion = ProtocolVersion::None; 24 ProtocolVersion protocolVersion = ProtocolVersion::None;
24 25
25 char *readBytes(size_t length); 26 char *readBytes(size_t length);
26 char readByte(); 27 char readByte();
27 void writeByte(char b); 28 void writeByte(char b);
  29 + void writeBytes(const char *b, size_t len);
28 uint16_t readTwoBytesToUInt16(); 30 uint16_t readTwoBytesToUInt16();
29 size_t remainingAfterPos(); 31 size_t remainingAfterPos();
30 32
31 public: 33 public:
32 PacketType packetType = PacketType::Reserved; 34 PacketType packetType = PacketType::Reserved;
33 MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); 35 MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender);
  36 +
  37 + // TODO: not constructors, but static functions that return all the stuff after the fixed header, then a constructor with vector.
34 MqttPacket(const ConnAck &connAck); 38 MqttPacket(const ConnAck &connAck);
35 MqttPacket(const SubAck &subAck); 39 MqttPacket(const SubAck &subAck);
  40 + MqttPacket(const Publish &publish);
36 41
37 void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore); 42 void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore);
38 void handleConnect(); 43 void handleConnect();
39 void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore); 44 void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore);
40 void handlePing(); 45 void handlePing();
  46 + void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore);
41 47
42 - size_t getSize() { return bites.size(); }  
43 - const std::vector<char> &getBites() { return bites; } 48 + size_t getSize() const { return bites.size(); }
  49 + const std::vector<char> &getBites() const { return bites; }
44 50
45 Client_p getSender() const; 51 Client_p getSender() const;
46 void setSender(const Client_p &value); 52 void setSender(const Client_p &value);
subscriptionstore.cpp
@@ -9,3 +9,20 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, std::string &amp;topic) @@ -9,3 +9,20 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, std::string &amp;topic)
9 { 9 {
10 this->subscriptions[topic].push_back(client); 10 this->subscriptions[topic].push_back(client);
11 } 11 }
  12 +
  13 +void SubscriptionStore::queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender)
  14 +{
  15 + for(Client_p &client : subscriptions[topic])
  16 + {
  17 + if (client->getThreadData()->threadnr == sender->getThreadData()->threadnr)
  18 + {
  19 + client->writeMqttPacket(packet);
  20 + client->writeBufIntoFd();
  21 + }
  22 + else
  23 + {
  24 + client->queueMessage(packet);
  25 + client->getThreadData()->addToReadyForDequeuing(client);
  26 + }
  27 + }
  28 +}
subscriptionstore.h
@@ -18,6 +18,8 @@ public: @@ -18,6 +18,8 @@ public:
18 18
19 // work with read copies intead of mutex/lock over the central store 19 // work with read copies intead of mutex/lock over the central store
20 void getReadCopy(); // TODO 20 void getReadCopy(); // TODO
  21 +
  22 + void queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender);
21 }; 23 };
22 24
23 #endif // SUBSCRIPTIONSTORE_H 25 #endif // SUBSCRIPTIONSTORE_H
threaddata.cpp
@@ -35,3 +35,21 @@ std::shared_ptr&lt;SubscriptionStore&gt; &amp;ThreadData::getSubscriptionStore() @@ -35,3 +35,21 @@ std::shared_ptr&lt;SubscriptionStore&gt; &amp;ThreadData::getSubscriptionStore()
35 { 35 {
36 return subscriptionStore; 36 return subscriptionStore;
37 } 37 }
  38 +
  39 +void ThreadData::wakeUpThread()
  40 +{
  41 + uint64_t one = 1;
  42 + write(event_fd, &one, sizeof(uint64_t));
  43 +}
  44 +
  45 +void ThreadData::addToReadyForDequeuing(Client_p &client)
  46 +{
  47 + this->readyForDequeueing.insert(client);
  48 +}
  49 +
  50 +void ThreadData::clearReadyForDequeueing()
  51 +{
  52 + this->readyForDequeueing.clear();
  53 +}
  54 +
  55 +
threaddata.h
@@ -6,6 +6,8 @@ @@ -6,6 +6,8 @@
6 #include <sys/epoll.h> 6 #include <sys/epoll.h>
7 #include <sys/eventfd.h> 7 #include <sys/eventfd.h>
8 #include <map> 8 #include <map>
  9 +#include <unordered_set>
  10 +#include <unordered_map>
9 11
10 #include "forward_declarations.h" 12 #include "forward_declarations.h"
11 13
@@ -17,8 +19,9 @@ @@ -17,8 +19,9 @@
17 19
18 class ThreadData 20 class ThreadData
19 { 21 {
20 - std::map<int, Client_p> clients_by_fd; 22 + std::unordered_map<int, Client_p> clients_by_fd;
21 std::shared_ptr<SubscriptionStore> subscriptionStore; 23 std::shared_ptr<SubscriptionStore> subscriptionStore;
  24 + std::unordered_set<Client_p> readyForDequeueing;
22 25
23 public: 26 public:
24 std::thread thread; 27 std::thread thread;
@@ -32,6 +35,10 @@ public: @@ -32,6 +35,10 @@ public:
32 Client_p getClient(int fd); 35 Client_p getClient(int fd);
33 void removeClient(Client_p client); 36 void removeClient(Client_p client);
34 std::shared_ptr<SubscriptionStore> &getSubscriptionStore(); 37 std::shared_ptr<SubscriptionStore> &getSubscriptionStore();
  38 + void wakeUpThread();
  39 + void addToReadyForDequeuing(Client_p &client);
  40 + std::unordered_set<Client_p> &getReadyForDequeueing() { return readyForDequeueing; }
  41 + void clearReadyForDequeueing();
35 }; 42 };
36 43
37 #endif // THREADDATA_H 44 #endif // THREADDATA_H
types.cpp
@@ -15,3 +15,10 @@ SubAck::SubAck(uint16_t packet_id, const std::list&lt;std::string&gt; &amp;subs) : @@ -15,3 +15,10 @@ SubAck::SubAck(uint16_t packet_id, const std::list&lt;std::string&gt; &amp;subs) :
15 responses.push_back(SubAckReturnCodes::MaxQoS0); 15 responses.push_back(SubAckReturnCodes::MaxQoS0);
16 } 16 }
17 } 17 }
  18 +
  19 +Publish::Publish(std::string &topic, std::string payload) :
  20 + topic(topic),
  21 + payload(payload)
  22 +{
  23 +
  24 +}
@@ -66,4 +66,12 @@ public: @@ -66,4 +66,12 @@ public:
66 SubAck(uint16_t packet_id, const std::list<std::string> &subs); 66 SubAck(uint16_t packet_id, const std::list<std::string> &subs);
67 }; 67 };
68 68
  69 +class Publish
  70 +{
  71 +public:
  72 + std::string topic;
  73 + std::string payload;
  74 + Publish(std::string &topic, std::string payload);
  75 +};
  76 +
69 #endif // TYPES_H 77 #endif // TYPES_H