From 760ec58878490864a6f82dfcf1f6858cc19306a8 Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Wed, 13 Jan 2021 21:58:28 +0100 Subject: [PATCH] QoS 1, 80% --- client.cpp | 16 +++++++++++++--- client.h | 4 ++++ forward_declarations.h | 1 + mqttpacket.cpp | 108 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------- mqttpacket.h | 14 +++++++++++--- session.cpp | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ session.h | 18 ++++++++++++++++-- subscriptionstore.cpp | 14 ++++++-------- subscriptionstore.h | 4 ++-- types.cpp | 22 ++++++++++++++++++++-- types.h | 13 +++++++++++-- 11 files changed, 249 insertions(+), 39 deletions(-) diff --git a/client.cpp b/client.cpp index bc60b56..70eb551 100644 --- a/client.cpp +++ b/client.cpp @@ -98,9 +98,9 @@ void Client::writeMqttPacket(const MqttPacket &packet) writebuf.doubleSize(); } - // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. - // TODO: when QoS is implemented, different filtering may be required. - if (packet.packetType == PacketType::PUBLISH && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) + // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And + // QoS packet are queued and limited elsewhere. + if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) { return; } @@ -350,6 +350,16 @@ void Client::setWill(const std::string &topic, const std::string &payload, bool this->will_qos = qos; } +void Client::assignSession(std::shared_ptr &session) +{ + this->session = session; +} + +std::shared_ptr Client::getSession() +{ + return this->session; +} + diff --git a/client.h b/client.h index b7b49ca..fc8ed03 100644 --- a/client.h +++ b/client.h @@ -51,6 +51,8 @@ class Client ThreadData_p threadData; std::mutex writeBufMutex; + std::shared_ptr session; + void setReadyForWriting(bool val); void setReadyForReading(bool val); @@ -73,6 +75,8 @@ public: ThreadData_p getThreadData() { return threadData; } std::string &getClientId() { return this->clientid; } bool getCleanSession() { return cleanSession; } + void assignSession(std::shared_ptr &session); + std::shared_ptr getSession(); void writePingResp(); void writeMqttPacket(const MqttPacket &packet); diff --git a/forward_declarations.h b/forward_declarations.h index c56d5a8..d4f6902 100644 --- a/forward_declarations.h +++ b/forward_declarations.h @@ -9,6 +9,7 @@ class ThreadData; typedef std::shared_ptr ThreadData_p; class MqttPacket; class SubscriptionStore; +class Session; #endif // FORWARD_DECLARATIONS_H diff --git a/mqttpacket.cpp b/mqttpacket.cpp index afaaff3..e3eb51c 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -36,29 +36,36 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt pos += fixed_header_length; } +// This is easier than using the copy constructor publically, because then I have to keep maintaining a functioning copy constructor. +// Returning shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. +std::shared_ptr MqttPacket::getCopy() const +{ + std::shared_ptr copyPacket(new MqttPacket(*this)); + copyPacket->sender.reset(); + return copyPacket; +} + // This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. MqttPacket::MqttPacket(const ConnAck &connAck) : - bites(connAck.getLength() + 2) + bites(connAck.getLengthWithoutFixedHeader() + 2) { fixed_header_length = 2; packetType = PacketType::CONNACK; char first_byte = static_cast(packetType) << 4; writeByte(first_byte); writeByte(2); // length is always 2. - writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. + writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. TODO: make that writeByte(static_cast(connAck.return_code)); } MqttPacket::MqttPacket(const SubAck &subAck) : - bites(3) + bites(subAck.getLengthWithoutFixedHeader()) { - fixed_header_length = 2; // TODO: this is wrong, pending implementation of the new method in SubAck packetType = PacketType::SUBACK; - char first_byte = static_cast(packetType) << 4; - writeByte(first_byte); - writeByte((subAck.packet_id & 0xF0) >> 8); - writeByte(subAck.packet_id & 0x0F); + first_byte = static_cast(packetType) << 4; + writeByte((subAck.packet_id & 0xFF00) >> 8); + writeByte(subAck.packet_id & 0x00FF); std::vector returnList; for (SubAckReturnCodes code : subAck.responses) @@ -66,12 +73,12 @@ MqttPacket::MqttPacket(const SubAck &subAck) : returnList.push_back(static_cast(code)); } - bites.insert(bites.end(), returnList.begin(), returnList.end()); - bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length + writeBytes(&returnList[0], returnList.size()); + calculateRemainingLength(); } MqttPacket::MqttPacket(const Publish &publish) : - bites(publish.getLength()) + bites(publish.getLengthWithoutFixedHeader()) { if (publish.topic.length() > 0xFFFF) { @@ -83,8 +90,8 @@ MqttPacket::MqttPacket(const Publish &publish) : first_byte |= (publish.qos << 1); first_byte |= (static_cast(publish.retain) & 0b00000001); - char topicLenMSB = (publish.topic.length() & 0xF0) >> 8; - char topicLenLSB = publish.topic.length() & 0x0F; + char topicLenMSB = (publish.topic.length() & 0xFF00) >> 8; + char topicLenLSB = publish.topic.length() & 0x00FF; writeByte(topicLenMSB); writeByte(topicLenLSB); writeBytes(publish.topic.c_str(), publish.topic.length()); @@ -98,6 +105,21 @@ MqttPacket::MqttPacket(const Publish &publish) : calculateRemainingLength(); } +// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. +MqttPacket::MqttPacket(const PubAck &pubAck) : + bites(pubAck.getLengthWithoutFixedHeader() + 2) +{ + fixed_header_length = 2; // This is the cheat part mentioned above. We're not calculating it dynamically. + packetType = PacketType::PUBACK; + first_byte = static_cast(packetType) << 4; + writeByte(first_byte); + writeByte(2); // length is always 2. + char topicLenMSB = (pubAck.packet_id & 0xFF00) >> 8; + char topicLenLSB = (pubAck.packet_id & 0x00FF); + writeByte(topicLenMSB); + writeByte(topicLenLSB); +} + void MqttPacket::handle() { if (packetType != PacketType::CONNECT) @@ -118,6 +140,8 @@ void MqttPacket::handle() handleSubscribe(); else if (packetType == PacketType::PUBLISH) handlePublish(); + else if (packetType == PacketType::PUBACK) + handlePubAck(); } void MqttPacket::handleConnect() @@ -268,10 +292,8 @@ void MqttPacket::handleSubscribe() uint16_t topicLength = readTwoBytesToUInt16(); std::string topic(readBytes(topicLength), topicLength); char qos = readByte(); - if (qos > 0) - throw NotImplementedException("QoS not implemented"); logger->logf(LOG_INFO, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); - sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic); + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, qos); subs_reponse_codes.push_back(qos); } @@ -293,6 +315,7 @@ void MqttPacket::handlePublish() if (qos == 3) throw ProtocolError("QoS 3 is a protocol violation."); + this->qos = qos; std::string topic(readBytes(variable_header_length), variable_header_length); @@ -310,8 +333,19 @@ void MqttPacket::handlePublish() if (qos) { - throw ProtocolError("Qos not implemented."); + if (qos > 1) + throw ProtocolError("Qos > 1 not implemented."); + packet_id_pos = pos; uint16_t packet_id = readTwoBytesToUInt16(); + + // Clear the packet ID from this packet, because each new publish must get a new one. It's more of a debug precaution. + pos -= 2; + char zero[2]; zero[0] = 0; zero[1] = 0; + writeBytes(zero, 2); + + PubAck pubAck(packet_id); + MqttPacket response(pubAck); + sender->writeMqttPacket(response); } if (retain) @@ -330,6 +364,12 @@ void MqttPacket::handlePublish() sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender); } +void MqttPacket::handlePubAck() +{ + uint16_t packet_id = readTwoBytesToUInt16(); + sender->getSession()->clearQosMessage(packet_id); +} + void MqttPacket::calculateRemainingLength() { assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. @@ -356,6 +396,40 @@ RemainingLength MqttPacket::getRemainingLength() const return remainingLength; } +void MqttPacket::setPacketId(uint16_t packet_id) +{ + // In other words, we assume that this code can only be called on packets of which we have all the bytes, including fixed header. + assert(fixed_header_length > 0); + assert(packetType == PacketType::PUBLISH); + assert(qos > 0); + + pos = packet_id_pos; + + char topicLenMSB = (packet_id & 0xFF00) >> 8; + char topicLenLSB = (packet_id & 0x00FF); + writeByte(topicLenMSB); + writeByte(topicLenLSB); +} + +// If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything? +void MqttPacket::setDuplicate() +{ + // In other words, we assume that this code can only be called on packets of which we have all the bytes, including fixed header. + assert(fixed_header_length > 0); + assert(packetType == PacketType::PUBLISH); + assert(qos > 0); + + char byte1 = bites[0]; + byte1 |= 0b00001000; + pos = 0; + writeByte(byte1); +} + +size_t MqttPacket::getTotalMemoryFootprint() +{ + return bites.size() + sizeof(MqttPacket); +} + size_t MqttPacket::getSizeIncludingNonPresentHeader() const { size_t total = bites.size(); diff --git a/mqttpacket.h b/mqttpacket.h index af00b55..92d85dd 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -28,9 +28,11 @@ class MqttPacket std::vector bites; size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. RemainingLength remainingLength; + char qos = 0; Client_p sender; char first_byte = 0; size_t pos = 0; + size_t packet_id_pos = 0; ProtocolVersion protocolVersion = ProtocolVersion::None; Logger *logger = Logger::getInstance(); @@ -43,17 +45,20 @@ class MqttPacket void calculateRemainingLength(); + MqttPacket(const MqttPacket &other) = default; public: PacketType packetType = PacketType::Reserved; MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets. MqttPacket(MqttPacket &&other) = default; - MqttPacket(const MqttPacket &other) = delete; + + std::shared_ptr getCopy() const; // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. MqttPacket(const ConnAck &connAck); MqttPacket(const SubAck &subAck); MqttPacket(const Publish &publish); + MqttPacket(const PubAck &pubAck); void handle(); void handleConnect(); @@ -61,16 +66,19 @@ public: void handleSubscribe(); void handlePing(); void handlePublish(); + void handlePubAck(); size_t getSizeIncludingNonPresentHeader() const; const std::vector &getBites() const { return bites; } - + char getQos() const { return qos; } Client_p getSender() const; void setSender(const Client_p &value); - bool containsFixedHeader() const; char getFirstByte() const; RemainingLength getRemainingLength() const; + void setPacketId(uint16_t packet_id); + void setDuplicate(); + size_t getTotalMemoryFootprint(); }; #endif // MQTTPACKET_H diff --git a/session.cpp b/session.cpp index 964806a..3295fc9 100644 --- a/session.cpp +++ b/session.cpp @@ -1,4 +1,7 @@ +#include "cassert" + #include "session.h" +#include "client.h" Session::Session() { @@ -19,3 +22,74 @@ void Session::assignActiveConnection(std::shared_ptr &client) { this->client = client; } + +void Session::writePacket(const MqttPacket &packet) +{ + const char qos = packet.getQos(); + + if (qos == 0) + { + if (!clientDisconnected()) + { + Client_p c = makeSharedClient(); + c->writeMqttPacketAndBlameThisClient(packet); + } + } + else if (qos == 1) + { + std::shared_ptr copyPacket = packet.getCopy(); + std::unique_lock locker(qosQueueMutex); + if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) + { + logger->logf(LOG_WARNING, "Dropping QoS message for client 'TODO', because its QoS buffers were full."); + return; + } + const uint16_t pid = nextPacketId++; + copyPacket->setPacketId(pid); + qosPacketQueue[pid] = copyPacket; + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); + locker.unlock(); + + if (!clientDisconnected()) + { + Client_p c = makeSharedClient(); + c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + } + } +} + +void Session::clearQosMessage(uint16_t packet_id) +{ + std::lock_guard locker(qosQueueMutex); + auto it = qosPacketQueue.find(packet_id); + if (it != qosPacketQueue.end()) + { + std::shared_ptr packet = it->second; + qosPacketQueue.erase(it); + qosQueueBytes -= packet->getTotalMemoryFootprint(); + assert(qosQueueBytes >= 0); + if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. + qosQueueBytes = 0; + } +} + +// [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any +// unacknowledged PUBLISH Packets (where QoS > 0) and PUBREL Packets using their original Packet Identifiers. This +// is the only circumstance where a Client or Server is REQUIRED to redeliver messages." +// +// There is a bit of a hole there, I think. When we write out a packet to a receiver, it may decide to drop it, if its buffers +// are full, for instance. We are not required to (periodically) retry. TODO Perhaps I will implement that retry anyway. +void Session::sendPendingQosMessages() +{ + if (!clientDisconnected()) + { + Client_p c = makeSharedClient(); + std::lock_guard locker(qosQueueMutex); + for (auto &qosMessage : qosPacketQueue) // TODO: wrong: the order must be maintained. Combine the fix with that vector idea + { + c->writeMqttPacketAndBlameThisClient(*qosMessage.second.get()); + qosMessage.second->setDuplicate(); // Any dealings with this packet from here will be a duplicate. + } + } +} diff --git a/session.h b/session.h index 02b21d0..a866c32 100644 --- a/session.h +++ b/session.h @@ -2,13 +2,24 @@ #define SESSION_H #include +#include +#include -class Client; +#include "forward_declarations.h" +#include "logger.h" + +// TODO make settings +#define MAX_QOS_MSG_PENDING_PER_CLIENT 32 +#define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096 class Session { std::weak_ptr client; - // TODO: qos message queue, as some kind of movable pointer. + std::unordered_map> qosPacketQueue; // TODO: because the max queue length should remain low-ish, perhaps a vector is better here. + std::mutex qosQueueMutex; + uint16_t nextPacketId = 0; + ssize_t qosQueueBytes = 0; + Logger *logger = Logger::getInstance(); public: Session(); Session(const Session &other) = delete; @@ -17,6 +28,9 @@ public: bool clientDisconnected() const; std::shared_ptr makeSharedClient() const; void assignActiveConnection(std::shared_ptr &client); + void writePacket(const MqttPacket &packet); + void clearQosMessage(uint16_t packet_id); + void sendPendingQosMessages(); }; #endif // SESSION_H diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index bc52ba3..6ef0f34 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -18,7 +18,7 @@ SubscriptionStore::SubscriptionStore() : } -void SubscriptionStore::addSubscription(Client_p &client, const std::string &topic) +void SubscriptionStore::addSubscription(Client_p &client, const std::string &topic, char qos) { const std::list subtopics = split(topic, '/'); @@ -89,10 +89,13 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) if (!session || client->getCleanSession()) { session.reset(new Session()); + sessionsById[client->getClientId()] = session; } session->assignActiveConnection(client); + client->assignSession(session); + session->sendPendingQosMessages(); } // TODO: should I implement cache, this needs to be changed to returning a list of clients. @@ -103,12 +106,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. { const std::shared_ptr session = session_weak.lock(); - - if (!session->clientDisconnected()) - { - Client_p c = session->makeSharedClient(); - c->writeMqttPacketAndBlameThisClient(packet); - } + session->writePacket(packet); } } } @@ -170,7 +168,7 @@ void SubscriptionStore::giveClientRetainedMessages(Client_p &client, const std:: const MqttPacket packet(publish); if (topicsMatch(subscribe_topic, rm.topic)) - client->writeMqttPacket(packet); + client->writeMqttPacket(packet); // TODO: I think this needs to be session, not client, and then I can store it if it's QoS? I need to research how retain+qos works } } diff --git a/subscriptionstore.h b/subscriptionstore.h index dbb02d1..4ac3a77 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -30,7 +30,7 @@ public: SubscriptionNode(const SubscriptionNode &node) = delete; SubscriptionNode(SubscriptionNode &&node) = delete; - std::forward_list> subscribers; // The idea is to store subscriptions by client id, to support persistent sessions. + std::forward_list> subscribers; // TODO: a subscription class, with qos std::unordered_map> children; std::unique_ptr childrenPlus; std::unique_ptr childrenPound; @@ -54,7 +54,7 @@ class SubscriptionStore public: SubscriptionStore(); - void addSubscription(Client_p &client, const std::string &topic); + void addSubscription(Client_p &client, const std::string &topic, char qos); void registerClientAndKickExistingOne(Client_p &client); void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); diff --git a/types.cpp b/types.cpp index 207782e..96fc5b5 100644 --- a/types.cpp +++ b/types.cpp @@ -15,6 +15,13 @@ SubAck::SubAck(uint16_t packet_id, const std::list &subs_qos_reponses) : } } +size_t SubAck::getLengthWithoutFixedHeader() const +{ + size_t result = responses.size(); + result += 2; // Packet ID + return result; +} + Publish::Publish(const std::string &topic, const std::string payload, char qos) : topic(topic), payload(payload), @@ -23,8 +30,7 @@ Publish::Publish(const std::string &topic, const std::string payload, char qos) } -// Length starting at the variable header, not the fixed header. -size_t Publish::getLength() const +size_t Publish::getLengthWithoutFixedHeader() const { int result = topic.length() + payload.length() + 2; @@ -33,3 +39,15 @@ size_t Publish::getLength() const return result; } + +PubAck::PubAck(uint16_t packet_id) : + packet_id(packet_id) +{ + +} + +// Packet has no payload and only a variable header, of length 2. +size_t PubAck::getLengthWithoutFixedHeader() const +{ + return 2; +} diff --git a/types.h b/types.h index befb53a..8523907 100644 --- a/types.h +++ b/types.h @@ -48,7 +48,7 @@ class ConnAck public: ConnAck(ConnAckReturnCodes return_code); ConnAckReturnCodes return_code; - size_t getLength() const { return 2;} // size of connack is always the same + size_t getLengthWithoutFixedHeader() const { return 2;} // size of connack is always the same }; enum class SubAckReturnCodes @@ -65,6 +65,7 @@ public: uint16_t packet_id; std::list responses; SubAck(uint16_t packet_id, const std::list &subs_qos_reponses); + size_t getLengthWithoutFixedHeader() const; }; class Publish @@ -75,7 +76,15 @@ public: char qos = 0; bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] Publish(const std::string &topic, const std::string payload, char qos); - size_t getLength() const; + size_t getLengthWithoutFixedHeader() const; +}; + +class PubAck +{ +public: + PubAck(uint16_t packet_id); + uint16_t packet_id; + size_t getLengthWithoutFixedHeader() const; }; #endif // TYPES_H -- libgit2 0.21.4