Commit 40f8e5e5f8eac75a2647d423826b52acf38bcb71
1 parent
a5b81d74
Publishing within the same thread works
Roughly...
Showing
11 changed files
with
171 additions
and
8 deletions
client.cpp
| @@ -57,7 +57,7 @@ bool Client::readFdIntoBuffer() | @@ -57,7 +57,7 @@ bool Client::readFdIntoBuffer() | ||
| 57 | return true; | 57 | return true; |
| 58 | } | 58 | } |
| 59 | 59 | ||
| 60 | -void Client::writeMqttPacket(MqttPacket &packet) | 60 | +void Client::writeMqttPacket(const MqttPacket &packet) |
| 61 | { | 61 | { |
| 62 | if (packet.getSize() > getWriteBufMaxWriteSize()) | 62 | if (packet.getSize() > getWriteBufMaxWriteSize()) |
| 63 | growWriteBuffer(packet.getSize()); | 63 | growWriteBuffer(packet.getSize()); |
| @@ -115,6 +115,18 @@ std::string Client::repr() | @@ -115,6 +115,18 @@ std::string Client::repr() | ||
| 115 | return a.str(); | 115 | return a.str(); |
| 116 | } | 116 | } |
| 117 | 117 | ||
| 118 | +void Client::queueMessage(const MqttPacket &packet) | ||
| 119 | +{ | ||
| 120 | + | ||
| 121 | + | ||
| 122 | + // TODO: semaphores on stl containers? | ||
| 123 | +} | ||
| 124 | + | ||
| 125 | +void Client::queuedMessagesToBuffer() | ||
| 126 | +{ | ||
| 127 | + | ||
| 128 | +} | ||
| 129 | + | ||
| 118 | bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender) | 130 | bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender) |
| 119 | { | 131 | { |
| 120 | while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) | 132 | while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) |
client.h
| @@ -92,13 +92,16 @@ public: | @@ -92,13 +92,16 @@ public: | ||
| 92 | void setAuthenticated(bool value) { authenticated = value;} | 92 | void setAuthenticated(bool value) { authenticated = value;} |
| 93 | bool getAuthenticated() { return authenticated; } | 93 | bool getAuthenticated() { return authenticated; } |
| 94 | bool hasConnectPacketSeen() { return connectPacketSeen; } | 94 | bool hasConnectPacketSeen() { return connectPacketSeen; } |
| 95 | + ThreadData_p getThreadData() { return threadData; } | ||
| 95 | 96 | ||
| 96 | void writePingResp(); | 97 | void writePingResp(); |
| 97 | - void writeMqttPacket(MqttPacket &packet); | 98 | + void writeMqttPacket(const MqttPacket &packet); |
| 98 | bool writeBufIntoFd(); | 99 | bool writeBufIntoFd(); |
| 99 | 100 | ||
| 100 | std::string repr(); | 101 | std::string repr(); |
| 101 | 102 | ||
| 103 | + void queueMessage(const MqttPacket &packet); | ||
| 104 | + void queuedMessagesToBuffer(); | ||
| 102 | }; | 105 | }; |
| 103 | 106 | ||
| 104 | #endif // CLIENT_H | 107 | #endif // CLIENT_H |
main.cpp
| @@ -28,8 +28,22 @@ void do_thread_work(ThreadData *threadData) | @@ -28,8 +28,22 @@ void do_thread_work(ThreadData *threadData) | ||
| 28 | 28 | ||
| 29 | std::vector<MqttPacket> packetQueueIn; | 29 | std::vector<MqttPacket> packetQueueIn; |
| 30 | 30 | ||
| 31 | + uint64_t eventfd_value = 0; | ||
| 32 | + | ||
| 31 | while (1) | 33 | while (1) |
| 32 | { | 34 | { |
| 35 | + if (eventfd_value > 0) | ||
| 36 | + { | ||
| 37 | + for (Client_p client : threadData->getReadyForDequeueing()) | ||
| 38 | + { | ||
| 39 | + client->queuedMessagesToBuffer(); | ||
| 40 | + } | ||
| 41 | + threadData->clearReadyForDequeueing(); | ||
| 42 | + eventfd_value = 0; | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + // TODO: do all the buftofd here, not spread out over | ||
| 46 | + | ||
| 33 | int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); | 47 | int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); |
| 34 | 48 | ||
| 35 | if (fdcount > 0) | 49 | if (fdcount > 0) |
| @@ -39,6 +53,13 @@ void do_thread_work(ThreadData *threadData) | @@ -39,6 +53,13 @@ void do_thread_work(ThreadData *threadData) | ||
| 39 | struct epoll_event cur_ev = events[i]; | 53 | struct epoll_event cur_ev = events[i]; |
| 40 | int fd = cur_ev.data.fd; | 54 | int fd = cur_ev.data.fd; |
| 41 | 55 | ||
| 56 | + // If this thread was actively woken up. | ||
| 57 | + if (fd == threadData->event_fd) | ||
| 58 | + { | ||
| 59 | + read(fd, &eventfd_value, sizeof(uint64_t)); | ||
| 60 | + continue; | ||
| 61 | + } | ||
| 62 | + | ||
| 42 | Client_p client = threadData->getClient(fd); | 63 | Client_p client = threadData->getClient(fd); |
| 43 | 64 | ||
| 44 | if (client) | 65 | if (client) |
mqttpacket.cpp
| @@ -3,16 +3,17 @@ | @@ -3,16 +3,17 @@ | ||
| 3 | #include <iostream> | 3 | #include <iostream> |
| 4 | #include <list> | 4 | #include <list> |
| 5 | 5 | ||
| 6 | +// constructor for parsing incoming packets | ||
| 6 | MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) : | 7 | MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) : |
| 7 | bites(len), | 8 | bites(len), |
| 8 | fixed_header_length(fixed_header_length), | 9 | fixed_header_length(fixed_header_length), |
| 9 | sender(sender) | 10 | sender(sender) |
| 10 | { | 11 | { |
| 11 | - unsigned char _packetType = (buf[0] & 0xF0) >> 4; | 12 | + std::memcpy(&bites[0], buf, len); |
| 13 | + first_byte = bites[0]; | ||
| 14 | + unsigned char _packetType = (first_byte & 0xF0) >> 4; | ||
| 12 | packetType = (PacketType)_packetType; | 15 | packetType = (PacketType)_packetType; |
| 13 | pos += fixed_header_length; | 16 | pos += fixed_header_length; |
| 14 | - | ||
| 15 | - std::memcpy(&bites[0], buf, len); | ||
| 16 | } | 17 | } |
| 17 | 18 | ||
| 18 | MqttPacket::MqttPacket(const ConnAck &connAck) : | 19 | MqttPacket::MqttPacket(const ConnAck &connAck) : |
| @@ -46,14 +47,42 @@ MqttPacket::MqttPacket(const SubAck &subAck) : | @@ -46,14 +47,42 @@ MqttPacket::MqttPacket(const SubAck &subAck) : | ||
| 46 | bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length | 47 | bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length |
| 47 | } | 48 | } |
| 48 | 49 | ||
| 50 | +MqttPacket::MqttPacket(const Publish &publish) : | ||
| 51 | + bites(publish.topic.length() + publish.payload.length() + 2 + 2) // TODO: same as above | ||
| 52 | +{ | ||
| 53 | + packetType = PacketType::PUBLISH; | ||
| 54 | + char first_byte = static_cast<char>(packetType) << 4; | ||
| 55 | + writeByte(first_byte); | ||
| 56 | + | ||
| 57 | + char topicLenMSB = (publish.topic.length() & 0xF0) >> 8; | ||
| 58 | + char topicLenLSB = publish.topic.length() & 0x0F; | ||
| 59 | + writeByte(topicLenMSB); | ||
| 60 | + writeByte(topicLenLSB); | ||
| 61 | + writeBytes(publish.topic.c_str(), publish.topic.length()); | ||
| 62 | + | ||
| 63 | + writeBytes(publish.payload.c_str(), publish.payload.length()); | ||
| 64 | + | ||
| 65 | + // TODO: untested. May be unnecessary, because a received packet can be resent just like that. | ||
| 66 | +} | ||
| 67 | + | ||
| 49 | void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) | 68 | void MqttPacket::handle(std::shared_ptr<SubscriptionStore> &subscriptionStore) |
| 50 | { | 69 | { |
| 70 | + if (packetType != PacketType::CONNECT) | ||
| 71 | + { | ||
| 72 | + if (!sender->getAuthenticated()) | ||
| 73 | + { | ||
| 74 | + throw ProtocolError("Non-connect packet from non-authenticated client."); | ||
| 75 | + } | ||
| 76 | + } | ||
| 77 | + | ||
| 51 | if (packetType == PacketType::CONNECT) | 78 | if (packetType == PacketType::CONNECT) |
| 52 | handleConnect(); | 79 | handleConnect(); |
| 53 | else if (packetType == PacketType::PINGREQ) | 80 | else if (packetType == PacketType::PINGREQ) |
| 54 | sender->writePingResp(); | 81 | sender->writePingResp(); |
| 55 | else if (packetType == PacketType::SUBSCRIBE) | 82 | else if (packetType == PacketType::SUBSCRIBE) |
| 56 | handleSubscribe(subscriptionStore); | 83 | handleSubscribe(subscriptionStore); |
| 84 | + else if (packetType == PacketType::PUBLISH) | ||
| 85 | + handlePublish(subscriptionStore); | ||
| 57 | } | 86 | } |
| 58 | 87 | ||
| 59 | void MqttPacket::handleConnect() | 88 | void MqttPacket::handleConnect() |
| @@ -123,6 +152,7 @@ void MqttPacket::handleConnect() | @@ -123,6 +152,7 @@ void MqttPacket::handleConnect() | ||
| 123 | // TODO: validate UTF8 encoded username/password. | 152 | // TODO: validate UTF8 encoded username/password. |
| 124 | 153 | ||
| 125 | sender->setClientProperties(client_id, username, true, keep_alive); | 154 | sender->setClientProperties(client_id, username, true, keep_alive); |
| 155 | + sender->setAuthenticated(true); | ||
| 126 | 156 | ||
| 127 | std::cout << "Connect: " << sender->repr() << std::endl; | 157 | std::cout << "Connect: " << sender->repr() << std::endl; |
| 128 | 158 | ||
| @@ -159,6 +189,30 @@ void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptio | @@ -159,6 +189,30 @@ void MqttPacket::handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptio | ||
| 159 | sender->writeBufIntoFd(); | 189 | sender->writeBufIntoFd(); |
| 160 | } | 190 | } |
| 161 | 191 | ||
| 192 | +void MqttPacket::handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore) | ||
| 193 | +{ | ||
| 194 | + uint16_t variable_header_length = readTwoBytesToUInt16(); | ||
| 195 | + | ||
| 196 | + if (variable_header_length == 0) | ||
| 197 | + throw ProtocolError("Empty publish topic"); | ||
| 198 | + | ||
| 199 | + char qos = (first_byte & 0b00000110) >> 1; | ||
| 200 | + | ||
| 201 | + std::string topic(readBytes(variable_header_length), variable_header_length); | ||
| 202 | + | ||
| 203 | + if (qos) | ||
| 204 | + { | ||
| 205 | + throw ProtocolError("Qos not implemented."); | ||
| 206 | + //uint16_t packet_id = readTwoBytesToUInt16(); | ||
| 207 | + } | ||
| 208 | + | ||
| 209 | + // TODO: validate UTF8. | ||
| 210 | + size_t payload_length = remainingAfterPos(); | ||
| 211 | + std::string payload(readBytes(payload_length), payload_length); | ||
| 212 | + | ||
| 213 | + subscriptionStore->queueAtClientsTemp(topic, *this, sender); | ||
| 214 | +} | ||
| 215 | + | ||
| 162 | 216 | ||
| 163 | Client_p MqttPacket::getSender() const | 217 | Client_p MqttPacket::getSender() const |
| 164 | { | 218 | { |
| @@ -197,6 +251,14 @@ void MqttPacket::writeByte(char b) | @@ -197,6 +251,14 @@ void MqttPacket::writeByte(char b) | ||
| 197 | bites[pos++] = b; | 251 | bites[pos++] = b; |
| 198 | } | 252 | } |
| 199 | 253 | ||
| 254 | +void MqttPacket::writeBytes(const char *b, size_t len) | ||
| 255 | +{ | ||
| 256 | + if (pos + len > bites.size()) | ||
| 257 | + throw ProtocolError("Exceeding packet size"); | ||
| 258 | + | ||
| 259 | + memcpy(&bites[pos], b, len); | ||
| 260 | +} | ||
| 261 | + | ||
| 200 | uint16_t MqttPacket::readTwoBytesToUInt16() | 262 | uint16_t MqttPacket::readTwoBytesToUInt16() |
| 201 | { | 263 | { |
| 202 | if (pos + 2 > bites.size()) | 264 | if (pos + 2 > bites.size()) |
mqttpacket.h
| @@ -19,28 +19,34 @@ class MqttPacket | @@ -19,28 +19,34 @@ class MqttPacket | ||
| 19 | std::vector<char> bites; | 19 | std::vector<char> bites; |
| 20 | size_t fixed_header_length = 0; | 20 | size_t fixed_header_length = 0; |
| 21 | Client_p sender; | 21 | Client_p sender; |
| 22 | + char first_byte; | ||
| 22 | size_t pos = 0; | 23 | size_t pos = 0; |
| 23 | ProtocolVersion protocolVersion = ProtocolVersion::None; | 24 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 24 | 25 | ||
| 25 | char *readBytes(size_t length); | 26 | char *readBytes(size_t length); |
| 26 | char readByte(); | 27 | char readByte(); |
| 27 | void writeByte(char b); | 28 | void writeByte(char b); |
| 29 | + void writeBytes(const char *b, size_t len); | ||
| 28 | uint16_t readTwoBytesToUInt16(); | 30 | uint16_t readTwoBytesToUInt16(); |
| 29 | size_t remainingAfterPos(); | 31 | size_t remainingAfterPos(); |
| 30 | 32 | ||
| 31 | public: | 33 | public: |
| 32 | PacketType packetType = PacketType::Reserved; | 34 | PacketType packetType = PacketType::Reserved; |
| 33 | MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); | 35 | MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); |
| 36 | + | ||
| 37 | + // TODO: not constructors, but static functions that return all the stuff after the fixed header, then a constructor with vector. | ||
| 34 | MqttPacket(const ConnAck &connAck); | 38 | MqttPacket(const ConnAck &connAck); |
| 35 | MqttPacket(const SubAck &subAck); | 39 | MqttPacket(const SubAck &subAck); |
| 40 | + MqttPacket(const Publish &publish); | ||
| 36 | 41 | ||
| 37 | void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore); | 42 | void handle(std::shared_ptr<SubscriptionStore> &subscriptionStore); |
| 38 | void handleConnect(); | 43 | void handleConnect(); |
| 39 | void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore); | 44 | void handleSubscribe(std::shared_ptr<SubscriptionStore> &subscriptionStore); |
| 40 | void handlePing(); | 45 | void handlePing(); |
| 46 | + void handlePublish(std::shared_ptr<SubscriptionStore> &subscriptionStore); | ||
| 41 | 47 | ||
| 42 | - size_t getSize() { return bites.size(); } | ||
| 43 | - const std::vector<char> &getBites() { return bites; } | 48 | + size_t getSize() const { return bites.size(); } |
| 49 | + const std::vector<char> &getBites() const { return bites; } | ||
| 44 | 50 | ||
| 45 | Client_p getSender() const; | 51 | Client_p getSender() const; |
| 46 | void setSender(const Client_p &value); | 52 | void setSender(const Client_p &value); |
subscriptionstore.cpp
| @@ -9,3 +9,20 @@ void SubscriptionStore::addSubscription(Client_p &client, std::string &topic) | @@ -9,3 +9,20 @@ void SubscriptionStore::addSubscription(Client_p &client, std::string &topic) | ||
| 9 | { | 9 | { |
| 10 | this->subscriptions[topic].push_back(client); | 10 | this->subscriptions[topic].push_back(client); |
| 11 | } | 11 | } |
| 12 | + | ||
| 13 | +void SubscriptionStore::queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender) | ||
| 14 | +{ | ||
| 15 | + for(Client_p &client : subscriptions[topic]) | ||
| 16 | + { | ||
| 17 | + if (client->getThreadData()->threadnr == sender->getThreadData()->threadnr) | ||
| 18 | + { | ||
| 19 | + client->writeMqttPacket(packet); | ||
| 20 | + client->writeBufIntoFd(); | ||
| 21 | + } | ||
| 22 | + else | ||
| 23 | + { | ||
| 24 | + client->queueMessage(packet); | ||
| 25 | + client->getThreadData()->addToReadyForDequeuing(client); | ||
| 26 | + } | ||
| 27 | + } | ||
| 28 | +} |
subscriptionstore.h
| @@ -18,6 +18,8 @@ public: | @@ -18,6 +18,8 @@ public: | ||
| 18 | 18 | ||
| 19 | // work with read copies intead of mutex/lock over the central store | 19 | // work with read copies intead of mutex/lock over the central store |
| 20 | void getReadCopy(); // TODO | 20 | void getReadCopy(); // TODO |
| 21 | + | ||
| 22 | + void queueAtClientsTemp(std::string &topic, const MqttPacket &packet, const Client_p &sender); | ||
| 21 | }; | 23 | }; |
| 22 | 24 | ||
| 23 | #endif // SUBSCRIPTIONSTORE_H | 25 | #endif // SUBSCRIPTIONSTORE_H |
threaddata.cpp
| @@ -35,3 +35,21 @@ std::shared_ptr<SubscriptionStore> &ThreadData::getSubscriptionStore() | @@ -35,3 +35,21 @@ std::shared_ptr<SubscriptionStore> &ThreadData::getSubscriptionStore() | ||
| 35 | { | 35 | { |
| 36 | return subscriptionStore; | 36 | return subscriptionStore; |
| 37 | } | 37 | } |
| 38 | + | ||
| 39 | +void ThreadData::wakeUpThread() | ||
| 40 | +{ | ||
| 41 | + uint64_t one = 1; | ||
| 42 | + write(event_fd, &one, sizeof(uint64_t)); | ||
| 43 | +} | ||
| 44 | + | ||
| 45 | +void ThreadData::addToReadyForDequeuing(Client_p &client) | ||
| 46 | +{ | ||
| 47 | + this->readyForDequeueing.insert(client); | ||
| 48 | +} | ||
| 49 | + | ||
| 50 | +void ThreadData::clearReadyForDequeueing() | ||
| 51 | +{ | ||
| 52 | + this->readyForDequeueing.clear(); | ||
| 53 | +} | ||
| 54 | + | ||
| 55 | + |
threaddata.h
| @@ -6,6 +6,8 @@ | @@ -6,6 +6,8 @@ | ||
| 6 | #include <sys/epoll.h> | 6 | #include <sys/epoll.h> |
| 7 | #include <sys/eventfd.h> | 7 | #include <sys/eventfd.h> |
| 8 | #include <map> | 8 | #include <map> |
| 9 | +#include <unordered_set> | ||
| 10 | +#include <unordered_map> | ||
| 9 | 11 | ||
| 10 | #include "forward_declarations.h" | 12 | #include "forward_declarations.h" |
| 11 | 13 | ||
| @@ -17,8 +19,9 @@ | @@ -17,8 +19,9 @@ | ||
| 17 | 19 | ||
| 18 | class ThreadData | 20 | class ThreadData |
| 19 | { | 21 | { |
| 20 | - std::map<int, Client_p> clients_by_fd; | 22 | + std::unordered_map<int, Client_p> clients_by_fd; |
| 21 | std::shared_ptr<SubscriptionStore> subscriptionStore; | 23 | std::shared_ptr<SubscriptionStore> subscriptionStore; |
| 24 | + std::unordered_set<Client_p> readyForDequeueing; | ||
| 22 | 25 | ||
| 23 | public: | 26 | public: |
| 24 | std::thread thread; | 27 | std::thread thread; |
| @@ -32,6 +35,10 @@ public: | @@ -32,6 +35,10 @@ public: | ||
| 32 | Client_p getClient(int fd); | 35 | Client_p getClient(int fd); |
| 33 | void removeClient(Client_p client); | 36 | void removeClient(Client_p client); |
| 34 | std::shared_ptr<SubscriptionStore> &getSubscriptionStore(); | 37 | std::shared_ptr<SubscriptionStore> &getSubscriptionStore(); |
| 38 | + void wakeUpThread(); | ||
| 39 | + void addToReadyForDequeuing(Client_p &client); | ||
| 40 | + std::unordered_set<Client_p> &getReadyForDequeueing() { return readyForDequeueing; } | ||
| 41 | + void clearReadyForDequeueing(); | ||
| 35 | }; | 42 | }; |
| 36 | 43 | ||
| 37 | #endif // THREADDATA_H | 44 | #endif // THREADDATA_H |
types.cpp
| @@ -15,3 +15,10 @@ SubAck::SubAck(uint16_t packet_id, const std::list<std::string> &subs) : | @@ -15,3 +15,10 @@ SubAck::SubAck(uint16_t packet_id, const std::list<std::string> &subs) : | ||
| 15 | responses.push_back(SubAckReturnCodes::MaxQoS0); | 15 | responses.push_back(SubAckReturnCodes::MaxQoS0); |
| 16 | } | 16 | } |
| 17 | } | 17 | } |
| 18 | + | ||
| 19 | +Publish::Publish(std::string &topic, std::string payload) : | ||
| 20 | + topic(topic), | ||
| 21 | + payload(payload) | ||
| 22 | +{ | ||
| 23 | + | ||
| 24 | +} |
types.h
| @@ -66,4 +66,12 @@ public: | @@ -66,4 +66,12 @@ public: | ||
| 66 | SubAck(uint16_t packet_id, const std::list<std::string> &subs); | 66 | SubAck(uint16_t packet_id, const std::list<std::string> &subs); |
| 67 | }; | 67 | }; |
| 68 | 68 | ||
| 69 | +class Publish | ||
| 70 | +{ | ||
| 71 | +public: | ||
| 72 | + std::string topic; | ||
| 73 | + std::string payload; | ||
| 74 | + Publish(std::string &topic, std::string payload); | ||
| 75 | +}; | ||
| 76 | + | ||
| 69 | #endif // TYPES_H | 77 | #endif // TYPES_H |