Commit 40f8e5e5f8eac75a2647d423826b52acf38bcb71

Authored by Wiebe Cazemier
1 parent a5b81d74

Publishing within the same thread works

Roughly...
client.cpp
... ... @@ -57,7 +57,7 @@ bool Client::readFdIntoBuffer()
57 57 return true;
58 58 }
59 59  
60   -void Client::writeMqttPacket(MqttPacket &packet)
  60 +void Client::writeMqttPacket(const MqttPacket &packet)
61 61 {
62 62 if (packet.getSize() > getWriteBufMaxWriteSize())
63 63 growWriteBuffer(packet.getSize());
... ... @@ -115,6 +115,18 @@ std::string Client::repr()
115 115 return a.str();
116 116 }
117 117  
  118 +void Client::queueMessage(const MqttPacket &packet)
  119 +{
  120 +
  121 +
  122 + // TODO: semaphores on stl containers?
  123 +}
  124 +
  125 +void Client::queuedMessagesToBuffer()
  126 +{
  127 +
  128 +}
  129 +
118 130 bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender)
119 131 {
120 132 while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH)
... ...
client.h
... ... @@ -92,13 +92,16 @@ public:
92 92 void setAuthenticated(bool value) { authenticated = value;}
93 93 bool getAuthenticated() { return authenticated; }
94 94 bool hasConnectPacketSeen() { return connectPacketSeen; }
  95 + ThreadData_p getThreadData() { return threadData; }
95 96  
96 97 void writePingResp();
97   - void writeMqttPacket(MqttPacket &packet);
  98 + void writeMqttPacket(const MqttPacket &packet);
98 99 bool writeBufIntoFd();
99 100  
100 101 std::string repr();
101 102  
  103 + void queueMessage(const MqttPacket &packet);
  104 + void queuedMessagesToBuffer();
102 105 };
103 106  
104 107 #endif // CLIENT_H
... ...
main.cpp
... ... @@ -28,8 +28,22 @@ void do_thread_work(ThreadData *threadData)
28 28  
29 29 std::vector<MqttPacket> packetQueueIn;
30 30  
  31 + uint64_t eventfd_value = 0;
  32 +
31 33 while (1)
32 34 {
  35 + if (eventfd_value > 0)
  36 + {
  37 + for (Client_p client : threadData->getReadyForDequeueing())
  38 + {
  39 + client->queuedMessagesToBuffer();
  40 + }
  41 + threadData->clearReadyForDequeueing();
  42 + eventfd_value = 0;
  43 + }
  44 +
  45 + // TODO: do all the buftofd here, not spread out over
  46 +
33 47 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100);
34 48  
35 49 if (fdcount > 0)
... ... @@ -39,6 +53,13 @@ void do_thread_work(ThreadData *threadData)
39 53 struct epoll_event cur_ev = events[i];
40 54 int fd = cur_ev.data.fd;
41 55  
  56 + // If this thread was actively woken up.
  57 + if (fd == threadData->event_fd)
  58 + {
  59 + read(fd, &eventfd_value, sizeof(uint64_t));
  60 + continue;
  61 + }
  62 +
42 63 Client_p client = threadData->getClient(fd);
43 64  
44 65 if (client)
... ...
mqttpacket.cpp
... ... @@ -3,16 +3,17 @@
3 3 #include <iostream>
4 4 #include <list>
5 5  
  6 +// constructor for parsing incoming packets
6 7 MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) :
7 8 bites(len),
8 9 fixed_header_length(fixed_header_length),
9 10 sender(sender)
10 11 {
11   - unsigned char _packetType = (buf[0] & 0xF0) >> 4;
  12 + std::memcpy(&bites[0], buf, len);
  13 + first_byte = bites[0];
  14 + unsigned char _packetType = (first_byte & 0xF0) >> 4;
12 15 packetType = (PacketType)_packetType;
13 16 pos += fixed_header_length;
14   -
15   - std::memcpy(&bites[0], buf, len);
16 17 }
17 18  
18 19 MqttPacket::MqttPacket(const ConnAck &connAck) :
... ... @@ -46,14 +47,42 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) :
46 47 bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length
47 48 }
48 49  
  50 +MqttPacket::MqttPacket(const Publish &publish) :
  51 + bites(publish.topic.length() + publish.payload.length() + 2 + 2) // TODO: same as above
  52 +{
  53 + packetType = PacketType::PUBLISH;
  54 + char first_byte = static_cast<char>(packetType) << 4;
  55 + writeByte(first_byte);
  56 +
  57 + char topicLenMSB = (publish.topic.length() & 0xF0) >> 8;
  58 + char topicLenLSB = publish.topic.length() & 0x0F;
  59 + writeByte(topicLenMSB);
  60 + writeByte(topicLenLSB);
  61 + writeBytes(publish.topic.c_str(), publish.topic.length());
  62 +
  63 + writeBytes(publish.payload.c_str(), publish.payload.length());
  64 +
  65 + // TODO: untested. May be unnecessary, because a received packet can be resent just like that.
  66 +}
  67 +
49 68 void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore)
50 69 {
  70 + if (packetType != PacketType::CONNECT)
  71 + {
  72 + if (!sender->getAuthenticated())
  73 + {
  74 + throw ProtocolError("Non-connect packet from non-authenticated client.");
  75 + }
  76 + }
  77 +
51 78 if (packetType == PacketType::CONNECT)
52 79 handleConnect();
53 80 else if (packetType == PacketType::PINGREQ)
54 81 sender->writePingResp();
55 82 else if (packetType == PacketType::SUBSCRIBE)
56 83 handleSubscribe(subscriptionStore);
  84 + else if (packetType == PacketType::PUBLISH)
  85 + handlePublish(subscriptionStore);
57 86 }
58 87  
59 88 void MqttPacket::handleConnect()
... ... @@ -123,6 +152,7 @@ void MqttPacket::handleConnect()
123 152 // TODO: validate UTF8 encoded username/password.
124 153  
125 154 sender->setClientProperties(client_id, username, true, keep_alive);
  155 + sender->setAuthenticated(true);
126 156  
127 157 std::cout << "Connect: " << sender->repr() << std::endl;
128 158  
... ... @@ -159,6 +189,30 @@ void MqttPacket::handleSubscribe(std::shared_ptr&lt;SubscriptionStore&gt; &amp;subscriptio
159 189 sender->writeBufIntoFd();
160 190 }
161 191  
  192 +void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore)
  193 +{
  194 + uint16_t variable_header_length = readTwoBytesToUInt16();
  195 +
  196 + if (variable_header_length == 0)
  197 + throw ProtocolError("Empty publish topic");
  198 +
  199 + char qos = (first_byte & 0b00000110) >> 1;
  200 +
  201 + std::string topic(readBytes(variable_header_length), variable_header_length);
  202 +
  203 + if (qos)
  204 + {
  205 + throw ProtocolError("Qos not implemented.");
  206 + //uint16_t packet_id = readTwoBytesToUInt16();
  207 + }
  208 +
  209 + // TODO: validate UTF8.
  210 + size_t payload_length = remainingAfterPos();
  211 + std::string payload(readBytes(payload_length), payload_length);
  212 +
  213 + subscriptionStore->queueAtClientsTemp(topic, *this, sender);
  214 +}
  215 +
162 216  
163 217 Client_p MqttPacket::getSender() const
164 218 {
... ... @@ -197,6 +251,14 @@ void MqttPacket::writeByte(char b)
197 251 bites[pos++] = b;
198 252 }
199 253  
  254 +void MqttPacket::writeBytes(const char *b, size_t len)
  255 +{
  256 + if (pos + len > bites.size())
  257 + throw ProtocolError("Exceeding packet size");
  258 +
  259 + memcpy(&bites[pos], b, len);
  260 +}
  261 +
200 262 uint16_t MqttPacket::readTwoBytesToUInt16()
201 263 {
202 264 if (pos + 2 > bites.size())
... ...
mqttpacket.h
... ... @@ -19,28 +19,34 @@ class MqttPacket
19 19 std::vector<char> bites;
20 20 size_t fixed_header_length = 0;
21 21 Client_p sender;
  22 + char first_byte;
22 23 size_t pos = 0;
23 24 ProtocolVersion protocolVersion = ProtocolVersion::None;
24 25  
25 26 char *readBytes(size_t length);
26 27 char readByte();
27 28 void writeByte(char b);
  29 + void writeBytes(const char *b, size_t len);
28 30 uint16_t readTwoBytesToUInt16();
29 31 size_t remainingAfterPos();
30 32  
31 33 public:
32 34 PacketType packetType = PacketType::Reserved;
33 35 MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender);
  36 +
  37 + // TODO: not constructors, but static functions that return all the stuff after the fixed header, then a constructor with vector.
34 38 MqttPacket(const ConnAck &connAck);
35 39 MqttPacket(const SubAck &subAck);
  40 + MqttPacket(const Publish &publish);
36 41  
37 42 void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore);
38 43 void handleConnect();
39 44 void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore);
40 45 void handlePing();
  46 + void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore);
41 47  
42   - size_t getSize() { return bites.size(); }
43   - const std::vector<char> &getBites() { return bites; }
  48 + size_t getSize() const { return bites.size(); }
  49 + const std::vector<char> &getBites() const { return bites; }
44 50  
45 51 Client_p getSender() const;
46 52 void setSender(const Client_p &value);
... ...
subscriptionstore.cpp
... ... @@ -9,3 +9,20 @@ void SubscriptionStore::addSubscription(Client_p &amp;client, std::string &amp;topic)
9 9 {
10 10 this->subscriptions[topic].push_back(client);
11 11 }
  12 +
  13 +void SubscriptionStore::queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender)
  14 +{
  15 + for(Client_p &client : subscriptions[topic])
  16 + {
  17 + if (client->getThreadData()->threadnr == sender->getThreadData()->threadnr)
  18 + {
  19 + client->writeMqttPacket(packet);
  20 + client->writeBufIntoFd();
  21 + }
  22 + else
  23 + {
  24 + client->queueMessage(packet);
  25 + client->getThreadData()->addToReadyForDequeuing(client);
  26 + }
  27 + }
  28 +}
... ...
subscriptionstore.h
... ... @@ -18,6 +18,8 @@ public:
18 18  
19 19 // work with read copies intead of mutex/lock over the central store
20 20 void getReadCopy(); // TODO
  21 +
  22 + void queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender);
21 23 };
22 24  
23 25 #endif // SUBSCRIPTIONSTORE_H
... ...
threaddata.cpp
... ... @@ -35,3 +35,21 @@ std::shared_ptr&lt;SubscriptionStore&gt; &amp;ThreadData::getSubscriptionStore()
35 35 {
36 36 return subscriptionStore;
37 37 }
  38 +
  39 +void ThreadData::wakeUpThread()
  40 +{
  41 + uint64_t one = 1;
  42 + write(event_fd, &one, sizeof(uint64_t));
  43 +}
  44 +
  45 +void ThreadData::addToReadyForDequeuing(Client_p &client)
  46 +{
  47 + this->readyForDequeueing.insert(client);
  48 +}
  49 +
  50 +void ThreadData::clearReadyForDequeueing()
  51 +{
  52 + this->readyForDequeueing.clear();
  53 +}
  54 +
  55 +
... ...
threaddata.h
... ... @@ -6,6 +6,8 @@
6 6 #include <sys/epoll.h>
7 7 #include <sys/eventfd.h>
8 8 #include <map>
  9 +#include <unordered_set>
  10 +#include <unordered_map>
9 11  
10 12 #include "forward_declarations.h"
11 13  
... ... @@ -17,8 +19,9 @@
17 19  
18 20 class ThreadData
19 21 {
20   - std::map<int, Client_p> clients_by_fd;
  22 + std::unordered_map<int, Client_p> clients_by_fd;
21 23 std::shared_ptr<SubscriptionStore> subscriptionStore;
  24 + std::unordered_set<Client_p> readyForDequeueing;
22 25  
23 26 public:
24 27 std::thread thread;
... ... @@ -32,6 +35,10 @@ public:
32 35 Client_p getClient(int fd);
33 36 void removeClient(Client_p client);
34 37 std::shared_ptr<SubscriptionStore> &getSubscriptionStore();
  38 + void wakeUpThread();
  39 + void addToReadyForDequeuing(Client_p &client);
  40 + std::unordered_set<Client_p> &getReadyForDequeueing() { return readyForDequeueing; }
  41 + void clearReadyForDequeueing();
35 42 };
36 43  
37 44 #endif // THREADDATA_H
... ...
types.cpp
... ... @@ -15,3 +15,10 @@ SubAck::SubAck(uint16_t packet_id, const std::list&lt;std::string&gt; &amp;subs) :
15 15 responses.push_back(SubAckReturnCodes::MaxQoS0);
16 16 }
17 17 }
  18 +
  19 +Publish::Publish(std::string &topic, std::string payload) :
  20 + topic(topic),
  21 + payload(payload)
  22 +{
  23 +
  24 +}
... ...
... ... @@ -66,4 +66,12 @@ public:
66 66 SubAck(uint16_t packet_id, const std::list<std::string> &subs);
67 67 };
68 68  
  69 +class Publish
  70 +{
  71 +public:
  72 + std::string topic;
  73 + std::string payload;
  74 + Publish(std::string &topic, std::string payload);
  75 +};
  76 +
69 77 #endif // TYPES_H
... ...