Commit 760ec58878490864a6f82dfcf1f6858cc19306a8

Authored by Wiebe Cazemier
1 parent 86d368a3

QoS 1, 80%

Also includes fixes to packet parsing that I couldn't make a separate
commit for.

When it comes to QoS 1, these things are still left, off the top of my
head:

- vector for qos queue? It helps with ordering and is CPU cache friendly.
- Store subscription QoS.
- Do retained messages have QoS?
- Give session client's name, to access it later.
client.cpp
@@ -98,9 +98,9 @@ void Client::writeMqttPacket(const MqttPacket &packet) @@ -98,9 +98,9 @@ void Client::writeMqttPacket(const MqttPacket &packet)
98 writebuf.doubleSize(); 98 writebuf.doubleSize();
99 } 99 }
100 100
101 - // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings.  
102 - // TODO: when QoS is implemented, different filtering may be required.  
103 - if (packet.packetType == PacketType::PUBLISH && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) 101 + // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And
  102 + // QoS packet are queued and limited elsewhere.
  103 + if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace())
104 { 104 {
105 return; 105 return;
106 } 106 }
@@ -350,6 +350,16 @@ void Client::setWill(const std::string &topic, const std::string &payload, bool @@ -350,6 +350,16 @@ void Client::setWill(const std::string &topic, const std::string &payload, bool
350 this->will_qos = qos; 350 this->will_qos = qos;
351 } 351 }
352 352
  353 +void Client::assignSession(std::shared_ptr<Session> &session)
  354 +{
  355 + this->session = session;
  356 +}
  357 +
  358 +std::shared_ptr<Session> Client::getSession()
  359 +{
  360 + return this->session;
  361 +}
  362 +
353 363
354 364
355 365
client.h
@@ -51,6 +51,8 @@ class Client @@ -51,6 +51,8 @@ class Client
51 ThreadData_p threadData; 51 ThreadData_p threadData;
52 std::mutex writeBufMutex; 52 std::mutex writeBufMutex;
53 53
  54 + std::shared_ptr<Session> session;
  55 +
54 56
55 void setReadyForWriting(bool val); 57 void setReadyForWriting(bool val);
56 void setReadyForReading(bool val); 58 void setReadyForReading(bool val);
@@ -73,6 +75,8 @@ public: @@ -73,6 +75,8 @@ public:
73 ThreadData_p getThreadData() { return threadData; } 75 ThreadData_p getThreadData() { return threadData; }
74 std::string &getClientId() { return this->clientid; } 76 std::string &getClientId() { return this->clientid; }
75 bool getCleanSession() { return cleanSession; } 77 bool getCleanSession() { return cleanSession; }
  78 + void assignSession(std::shared_ptr<Session> &session);
  79 + std::shared_ptr<Session> getSession();
76 80
77 void writePingResp(); 81 void writePingResp();
78 void writeMqttPacket(const MqttPacket &packet); 82 void writeMqttPacket(const MqttPacket &packet);
forward_declarations.h
@@ -9,6 +9,7 @@ class ThreadData; @@ -9,6 +9,7 @@ class ThreadData;
9 typedef std::shared_ptr<ThreadData> ThreadData_p; 9 typedef std::shared_ptr<ThreadData> ThreadData_p;
10 class MqttPacket; 10 class MqttPacket;
11 class SubscriptionStore; 11 class SubscriptionStore;
  12 +class Session;
12 13
13 14
14 #endif // FORWARD_DECLARATIONS_H 15 #endif // FORWARD_DECLARATIONS_H
mqttpacket.cpp
@@ -36,29 +36,36 @@ MqttPacket::MqttPacket(CirBuf &amp;buf, size_t packet_len, size_t fixed_header_lengt @@ -36,29 +36,36 @@ MqttPacket::MqttPacket(CirBuf &amp;buf, size_t packet_len, size_t fixed_header_lengt
36 pos += fixed_header_length; 36 pos += fixed_header_length;
37 } 37 }
38 38
  39 +// This is easier than using the copy constructor publically, because then I have to keep maintaining a functioning copy constructor.
  40 +// Returning shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource.
  41 +std::shared_ptr<MqttPacket> MqttPacket::getCopy() const
  42 +{
  43 + std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this));
  44 + copyPacket->sender.reset();
  45 + return copyPacket;
  46 +}
  47 +
39 // This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. 48 // This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector.
40 MqttPacket::MqttPacket(const ConnAck &connAck) : 49 MqttPacket::MqttPacket(const ConnAck &connAck) :
41 - bites(connAck.getLength() + 2) 50 + bites(connAck.getLengthWithoutFixedHeader() + 2)
42 { 51 {
43 fixed_header_length = 2; 52 fixed_header_length = 2;
44 packetType = PacketType::CONNACK; 53 packetType = PacketType::CONNACK;
45 char first_byte = static_cast<char>(packetType) << 4; 54 char first_byte = static_cast<char>(packetType) << 4;
46 writeByte(first_byte); 55 writeByte(first_byte);
47 writeByte(2); // length is always 2. 56 writeByte(2); // length is always 2.
48 - writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. 57 + writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. TODO: make that
49 writeByte(static_cast<char>(connAck.return_code)); 58 writeByte(static_cast<char>(connAck.return_code));
50 59
51 } 60 }
52 61
53 MqttPacket::MqttPacket(const SubAck &subAck) : 62 MqttPacket::MqttPacket(const SubAck &subAck) :
54 - bites(3) 63 + bites(subAck.getLengthWithoutFixedHeader())
55 { 64 {
56 - fixed_header_length = 2; // TODO: this is wrong, pending implementation of the new method in SubAck  
57 packetType = PacketType::SUBACK; 65 packetType = PacketType::SUBACK;
58 - char first_byte = static_cast<char>(packetType) << 4;  
59 - writeByte(first_byte);  
60 - writeByte((subAck.packet_id & 0xF0) >> 8);  
61 - writeByte(subAck.packet_id & 0x0F); 66 + first_byte = static_cast<char>(packetType) << 4;
  67 + writeByte((subAck.packet_id & 0xFF00) >> 8);
  68 + writeByte(subAck.packet_id & 0x00FF);
62 69
63 std::vector<char> returnList; 70 std::vector<char> returnList;
64 for (SubAckReturnCodes code : subAck.responses) 71 for (SubAckReturnCodes code : subAck.responses)
@@ -66,12 +73,12 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) : @@ -66,12 +73,12 @@ MqttPacket::MqttPacket(const SubAck &amp;subAck) :
66 returnList.push_back(static_cast<char>(code)); 73 returnList.push_back(static_cast<char>(code));
67 } 74 }
68 75
69 - bites.insert(bites.end(), returnList.begin(), returnList.end());  
70 - bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length 76 + writeBytes(&returnList[0], returnList.size());
  77 + calculateRemainingLength();
71 } 78 }
72 79
73 MqttPacket::MqttPacket(const Publish &publish) : 80 MqttPacket::MqttPacket(const Publish &publish) :
74 - bites(publish.getLength()) 81 + bites(publish.getLengthWithoutFixedHeader())
75 { 82 {
76 if (publish.topic.length() > 0xFFFF) 83 if (publish.topic.length() > 0xFFFF)
77 { 84 {
@@ -83,8 +90,8 @@ MqttPacket::MqttPacket(const Publish &amp;publish) : @@ -83,8 +90,8 @@ MqttPacket::MqttPacket(const Publish &amp;publish) :
83 first_byte |= (publish.qos << 1); 90 first_byte |= (publish.qos << 1);
84 first_byte |= (static_cast<char>(publish.retain) & 0b00000001); 91 first_byte |= (static_cast<char>(publish.retain) & 0b00000001);
85 92
86 - char topicLenMSB = (publish.topic.length() & 0xF0) >> 8;  
87 - char topicLenLSB = publish.topic.length() & 0x0F; 93 + char topicLenMSB = (publish.topic.length() & 0xFF00) >> 8;
  94 + char topicLenLSB = publish.topic.length() & 0x00FF;
88 writeByte(topicLenMSB); 95 writeByte(topicLenMSB);
89 writeByte(topicLenLSB); 96 writeByte(topicLenLSB);
90 writeBytes(publish.topic.c_str(), publish.topic.length()); 97 writeBytes(publish.topic.c_str(), publish.topic.length());
@@ -98,6 +105,21 @@ MqttPacket::MqttPacket(const Publish &amp;publish) : @@ -98,6 +105,21 @@ MqttPacket::MqttPacket(const Publish &amp;publish) :
98 calculateRemainingLength(); 105 calculateRemainingLength();
99 } 106 }
100 107
  108 +// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector.
  109 +MqttPacket::MqttPacket(const PubAck &pubAck) :
  110 + bites(pubAck.getLengthWithoutFixedHeader() + 2)
  111 +{
  112 + fixed_header_length = 2; // This is the cheat part mentioned above. We're not calculating it dynamically.
  113 + packetType = PacketType::PUBACK;
  114 + first_byte = static_cast<char>(packetType) << 4;
  115 + writeByte(first_byte);
  116 + writeByte(2); // length is always 2.
  117 + char topicLenMSB = (pubAck.packet_id & 0xFF00) >> 8;
  118 + char topicLenLSB = (pubAck.packet_id & 0x00FF);
  119 + writeByte(topicLenMSB);
  120 + writeByte(topicLenLSB);
  121 +}
  122 +
101 void MqttPacket::handle() 123 void MqttPacket::handle()
102 { 124 {
103 if (packetType != PacketType::CONNECT) 125 if (packetType != PacketType::CONNECT)
@@ -118,6 +140,8 @@ void MqttPacket::handle() @@ -118,6 +140,8 @@ void MqttPacket::handle()
118 handleSubscribe(); 140 handleSubscribe();
119 else if (packetType == PacketType::PUBLISH) 141 else if (packetType == PacketType::PUBLISH)
120 handlePublish(); 142 handlePublish();
  143 + else if (packetType == PacketType::PUBACK)
  144 + handlePubAck();
121 } 145 }
122 146
123 void MqttPacket::handleConnect() 147 void MqttPacket::handleConnect()
@@ -268,10 +292,8 @@ void MqttPacket::handleSubscribe() @@ -268,10 +292,8 @@ void MqttPacket::handleSubscribe()
268 uint16_t topicLength = readTwoBytesToUInt16(); 292 uint16_t topicLength = readTwoBytesToUInt16();
269 std::string topic(readBytes(topicLength), topicLength); 293 std::string topic(readBytes(topicLength), topicLength);
270 char qos = readByte(); 294 char qos = readByte();
271 - if (qos > 0)  
272 - throw NotImplementedException("QoS not implemented");  
273 logger->logf(LOG_INFO, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); 295 logger->logf(LOG_INFO, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str());
274 - sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic); 296 + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, qos);
275 subs_reponse_codes.push_back(qos); 297 subs_reponse_codes.push_back(qos);
276 } 298 }
277 299
@@ -293,6 +315,7 @@ void MqttPacket::handlePublish() @@ -293,6 +315,7 @@ void MqttPacket::handlePublish()
293 315
294 if (qos == 3) 316 if (qos == 3)
295 throw ProtocolError("QoS 3 is a protocol violation."); 317 throw ProtocolError("QoS 3 is a protocol violation.");
  318 + this->qos = qos;
296 319
297 std::string topic(readBytes(variable_header_length), variable_header_length); 320 std::string topic(readBytes(variable_header_length), variable_header_length);
298 321
@@ -310,8 +333,19 @@ void MqttPacket::handlePublish() @@ -310,8 +333,19 @@ void MqttPacket::handlePublish()
310 333
311 if (qos) 334 if (qos)
312 { 335 {
313 - throw ProtocolError("Qos not implemented."); 336 + if (qos > 1)
  337 + throw ProtocolError("Qos > 1 not implemented.");
  338 + packet_id_pos = pos;
314 uint16_t packet_id = readTwoBytesToUInt16(); 339 uint16_t packet_id = readTwoBytesToUInt16();
  340 +
  341 + // Clear the packet ID from this packet, because each new publish must get a new one. It's more of a debug precaution.
  342 + pos -= 2;
  343 + char zero[2]; zero[0] = 0; zero[1] = 0;
  344 + writeBytes(zero, 2);
  345 +
  346 + PubAck pubAck(packet_id);
  347 + MqttPacket response(pubAck);
  348 + sender->writeMqttPacket(response);
315 } 349 }
316 350
317 if (retain) 351 if (retain)
@@ -330,6 +364,12 @@ void MqttPacket::handlePublish() @@ -330,6 +364,12 @@ void MqttPacket::handlePublish()
330 sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender); 364 sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender);
331 } 365 }
332 366
  367 +void MqttPacket::handlePubAck()
  368 +{
  369 + uint16_t packet_id = readTwoBytesToUInt16();
  370 + sender->getSession()->clearQosMessage(packet_id);
  371 +}
  372 +
333 void MqttPacket::calculateRemainingLength() 373 void MqttPacket::calculateRemainingLength()
334 { 374 {
335 assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. 375 assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of.
@@ -356,6 +396,40 @@ RemainingLength MqttPacket::getRemainingLength() const @@ -356,6 +396,40 @@ RemainingLength MqttPacket::getRemainingLength() const
356 return remainingLength; 396 return remainingLength;
357 } 397 }
358 398
  399 +void MqttPacket::setPacketId(uint16_t packet_id)
  400 +{
  401 + // In other words, we assume that this code can only be called on packets of which we have all the bytes, including fixed header.
  402 + assert(fixed_header_length > 0);
  403 + assert(packetType == PacketType::PUBLISH);
  404 + assert(qos > 0);
  405 +
  406 + pos = packet_id_pos;
  407 +
  408 + char topicLenMSB = (packet_id & 0xFF00) >> 8;
  409 + char topicLenLSB = (packet_id & 0x00FF);
  410 + writeByte(topicLenMSB);
  411 + writeByte(topicLenLSB);
  412 +}
  413 +
  414 +// If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything?
  415 +void MqttPacket::setDuplicate()
  416 +{
  417 + // In other words, we assume that this code can only be called on packets of which we have all the bytes, including fixed header.
  418 + assert(fixed_header_length > 0);
  419 + assert(packetType == PacketType::PUBLISH);
  420 + assert(qos > 0);
  421 +
  422 + char byte1 = bites[0];
  423 + byte1 |= 0b00001000;
  424 + pos = 0;
  425 + writeByte(byte1);
  426 +}
  427 +
  428 +size_t MqttPacket::getTotalMemoryFootprint()
  429 +{
  430 + return bites.size() + sizeof(MqttPacket);
  431 +}
  432 +
359 size_t MqttPacket::getSizeIncludingNonPresentHeader() const 433 size_t MqttPacket::getSizeIncludingNonPresentHeader() const
360 { 434 {
361 size_t total = bites.size(); 435 size_t total = bites.size();
mqttpacket.h
@@ -28,9 +28,11 @@ class MqttPacket @@ -28,9 +28,11 @@ class MqttPacket
28 std::vector<char> bites; 28 std::vector<char> bites;
29 size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. 29 size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header.
30 RemainingLength remainingLength; 30 RemainingLength remainingLength;
  31 + char qos = 0;
31 Client_p sender; 32 Client_p sender;
32 char first_byte = 0; 33 char first_byte = 0;
33 size_t pos = 0; 34 size_t pos = 0;
  35 + size_t packet_id_pos = 0;
34 ProtocolVersion protocolVersion = ProtocolVersion::None; 36 ProtocolVersion protocolVersion = ProtocolVersion::None;
35 Logger *logger = Logger::getInstance(); 37 Logger *logger = Logger::getInstance();
36 38
@@ -43,17 +45,20 @@ class MqttPacket @@ -43,17 +45,20 @@ class MqttPacket
43 45
44 void calculateRemainingLength(); 46 void calculateRemainingLength();
45 47
  48 + MqttPacket(const MqttPacket &other) = default;
46 public: 49 public:
47 PacketType packetType = PacketType::Reserved; 50 PacketType packetType = PacketType::Reserved;
48 MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets. 51 MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets.
49 52
50 MqttPacket(MqttPacket &&other) = default; 53 MqttPacket(MqttPacket &&other) = default;
51 - MqttPacket(const MqttPacket &other) = delete; 54 +
  55 + std::shared_ptr<MqttPacket> getCopy() const;
52 56
53 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. 57 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance.
54 MqttPacket(const ConnAck &connAck); 58 MqttPacket(const ConnAck &connAck);
55 MqttPacket(const SubAck &subAck); 59 MqttPacket(const SubAck &subAck);
56 MqttPacket(const Publish &publish); 60 MqttPacket(const Publish &publish);
  61 + MqttPacket(const PubAck &pubAck);
57 62
58 void handle(); 63 void handle();
59 void handleConnect(); 64 void handleConnect();
@@ -61,16 +66,19 @@ public: @@ -61,16 +66,19 @@ public:
61 void handleSubscribe(); 66 void handleSubscribe();
62 void handlePing(); 67 void handlePing();
63 void handlePublish(); 68 void handlePublish();
  69 + void handlePubAck();
64 70
65 size_t getSizeIncludingNonPresentHeader() const; 71 size_t getSizeIncludingNonPresentHeader() const;
66 const std::vector<char> &getBites() const { return bites; } 72 const std::vector<char> &getBites() const { return bites; }
67 - 73 + char getQos() const { return qos; }
68 Client_p getSender() const; 74 Client_p getSender() const;
69 void setSender(const Client_p &value); 75 void setSender(const Client_p &value);
70 -  
71 bool containsFixedHeader() const; 76 bool containsFixedHeader() const;
72 char getFirstByte() const; 77 char getFirstByte() const;
73 RemainingLength getRemainingLength() const; 78 RemainingLength getRemainingLength() const;
  79 + void setPacketId(uint16_t packet_id);
  80 + void setDuplicate();
  81 + size_t getTotalMemoryFootprint();
74 }; 82 };
75 83
76 #endif // MQTTPACKET_H 84 #endif // MQTTPACKET_H
session.cpp
  1 +#include "cassert"
  2 +
1 #include "session.h" 3 #include "session.h"
  4 +#include "client.h"
2 5
3 Session::Session() 6 Session::Session()
4 { 7 {
@@ -19,3 +22,74 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client) @@ -19,3 +22,74 @@ void Session::assignActiveConnection(std::shared_ptr&lt;Client&gt; &amp;client)
19 { 22 {
20 this->client = client; 23 this->client = client;
21 } 24 }
  25 +
  26 +void Session::writePacket(const MqttPacket &packet)
  27 +{
  28 + const char qos = packet.getQos();
  29 +
  30 + if (qos == 0)
  31 + {
  32 + if (!clientDisconnected())
  33 + {
  34 + Client_p c = makeSharedClient();
  35 + c->writeMqttPacketAndBlameThisClient(packet);
  36 + }
  37 + }
  38 + else if (qos == 1)
  39 + {
  40 + std::shared_ptr<MqttPacket> copyPacket = packet.getCopy();
  41 + std::unique_lock<std::mutex> locker(qosQueueMutex);
  42 + if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0))
  43 + {
  44 + logger->logf(LOG_WARNING, "Dropping QoS message for client 'TODO', because its QoS buffers were full.");
  45 + return;
  46 + }
  47 + const uint16_t pid = nextPacketId++;
  48 + copyPacket->setPacketId(pid);
  49 + qosPacketQueue[pid] = copyPacket;
  50 + qosQueueBytes += copyPacket->getTotalMemoryFootprint();
  51 + locker.unlock();
  52 +
  53 + if (!clientDisconnected())
  54 + {
  55 + Client_p c = makeSharedClient();
  56 + c->writeMqttPacketAndBlameThisClient(*copyPacket.get());
  57 + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
  58 + }
  59 + }
  60 +}
  61 +
  62 +void Session::clearQosMessage(uint16_t packet_id)
  63 +{
  64 + std::lock_guard<std::mutex> locker(qosQueueMutex);
  65 + auto it = qosPacketQueue.find(packet_id);
  66 + if (it != qosPacketQueue.end())
  67 + {
  68 + std::shared_ptr<MqttPacket> packet = it->second;
  69 + qosPacketQueue.erase(it);
  70 + qosQueueBytes -= packet->getTotalMemoryFootprint();
  71 + assert(qosQueueBytes >= 0);
  72 + if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose.
  73 + qosQueueBytes = 0;
  74 + }
  75 +}
  76 +
  77 +// [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any
  78 +// unacknowledged PUBLISH Packets (where QoS > 0) and PUBREL Packets using their original Packet Identifiers. This
  79 +// is the only circumstance where a Client or Server is REQUIRED to redeliver messages."
  80 +//
  81 +// There is a bit of a hole there, I think. When we write out a packet to a receiver, it may decide to drop it, if its buffers
  82 +// are full, for instance. We are not required to (periodically) retry. TODO Perhaps I will implement that retry anyway.
  83 +void Session::sendPendingQosMessages()
  84 +{
  85 + if (!clientDisconnected())
  86 + {
  87 + Client_p c = makeSharedClient();
  88 + std::lock_guard<std::mutex> locker(qosQueueMutex);
  89 + for (auto &qosMessage : qosPacketQueue) // TODO: wrong: the order must be maintained. Combine the fix with that vector idea
  90 + {
  91 + c->writeMqttPacketAndBlameThisClient(*qosMessage.second.get());
  92 + qosMessage.second->setDuplicate(); // Any dealings with this packet from here will be a duplicate.
  93 + }
  94 + }
  95 +}
session.h
@@ -2,13 +2,24 @@ @@ -2,13 +2,24 @@
2 #define SESSION_H 2 #define SESSION_H
3 3
4 #include <memory> 4 #include <memory>
  5 +#include <unordered_map>
  6 +#include <mutex>
5 7
6 -class Client; 8 +#include "forward_declarations.h"
  9 +#include "logger.h"
  10 +
  11 +// TODO make settings
  12 +#define MAX_QOS_MSG_PENDING_PER_CLIENT 32
  13 +#define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096
7 14
8 class Session 15 class Session
9 { 16 {
10 std::weak_ptr<Client> client; 17 std::weak_ptr<Client> client;
11 - // TODO: qos message queue, as some kind of movable pointer. 18 + std::unordered_map<uint16_t, std::shared_ptr<MqttPacket>> qosPacketQueue; // TODO: because the max queue length should remain low-ish, perhaps a vector is better here.
  19 + std::mutex qosQueueMutex;
  20 + uint16_t nextPacketId = 0;
  21 + ssize_t qosQueueBytes = 0;
  22 + Logger *logger = Logger::getInstance();
12 public: 23 public:
13 Session(); 24 Session();
14 Session(const Session &other) = delete; 25 Session(const Session &other) = delete;
@@ -17,6 +28,9 @@ public: @@ -17,6 +28,9 @@ public:
17 bool clientDisconnected() const; 28 bool clientDisconnected() const;
18 std::shared_ptr<Client> makeSharedClient() const; 29 std::shared_ptr<Client> makeSharedClient() const;
19 void assignActiveConnection(std::shared_ptr<Client> &client); 30 void assignActiveConnection(std::shared_ptr<Client> &client);
  31 + void writePacket(const MqttPacket &packet);
  32 + void clearQosMessage(uint16_t packet_id);
  33 + void sendPendingQosMessages();
20 }; 34 };
21 35
22 #endif // SESSION_H 36 #endif // SESSION_H
subscriptionstore.cpp
@@ -18,7 +18,7 @@ SubscriptionStore::SubscriptionStore() : @@ -18,7 +18,7 @@ SubscriptionStore::SubscriptionStore() :
18 18
19 } 19 }
20 20
21 -void SubscriptionStore::addSubscription(Client_p &client, const std::string &topic) 21 +void SubscriptionStore::addSubscription(Client_p &client, const std::string &topic, char qos)
22 { 22 {
23 const std::list<std::string> subtopics = split(topic, '/'); 23 const std::list<std::string> subtopics = split(topic, '/');
24 24
@@ -89,10 +89,13 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &amp;client) @@ -89,10 +89,13 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &amp;client)
89 if (!session || client->getCleanSession()) 89 if (!session || client->getCleanSession())
90 { 90 {
91 session.reset(new Session()); 91 session.reset(new Session());
  92 +
92 sessionsById[client->getClientId()] = session; 93 sessionsById[client->getClientId()] = session;
93 } 94 }
94 95
95 session->assignActiveConnection(client); 96 session->assignActiveConnection(client);
  97 + client->assignSession(session);
  98 + session->sendPendingQosMessages();
96 } 99 }
97 100
98 // TODO: should I implement cache, this needs to be changed to returning a list of clients. 101 // TODO: should I implement cache, this needs to be changed to returning a list of clients.
@@ -103,12 +106,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st @@ -103,12 +106,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st
103 if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. 106 if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect.
104 { 107 {
105 const std::shared_ptr<Session> session = session_weak.lock(); 108 const std::shared_ptr<Session> session = session_weak.lock();
106 -  
107 - if (!session->clientDisconnected())  
108 - {  
109 - Client_p c = session->makeSharedClient();  
110 - c->writeMqttPacketAndBlameThisClient(packet);  
111 - } 109 + session->writePacket(packet);
112 } 110 }
113 } 111 }
114 } 112 }
@@ -170,7 +168,7 @@ void SubscriptionStore::giveClientRetainedMessages(Client_p &amp;client, const std:: @@ -170,7 +168,7 @@ void SubscriptionStore::giveClientRetainedMessages(Client_p &amp;client, const std::
170 const MqttPacket packet(publish); 168 const MqttPacket packet(publish);
171 169
172 if (topicsMatch(subscribe_topic, rm.topic)) 170 if (topicsMatch(subscribe_topic, rm.topic))
173 - client->writeMqttPacket(packet); 171 + client->writeMqttPacket(packet); // TODO: I think this needs to be session, not client, and then I can store it if it's QoS? I need to research how retain+qos works
174 } 172 }
175 } 173 }
176 174
subscriptionstore.h
@@ -30,7 +30,7 @@ public: @@ -30,7 +30,7 @@ public:
30 SubscriptionNode(const SubscriptionNode &node) = delete; 30 SubscriptionNode(const SubscriptionNode &node) = delete;
31 SubscriptionNode(SubscriptionNode &&node) = delete; 31 SubscriptionNode(SubscriptionNode &&node) = delete;
32 32
33 - std::forward_list<std::weak_ptr<Session>> subscribers; // The idea is to store subscriptions by client id, to support persistent sessions. 33 + std::forward_list<std::weak_ptr<Session>> subscribers; // TODO: a subscription class, with qos
34 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children; 34 std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children;
35 std::unique_ptr<SubscriptionNode> childrenPlus; 35 std::unique_ptr<SubscriptionNode> childrenPlus;
36 std::unique_ptr<SubscriptionNode> childrenPound; 36 std::unique_ptr<SubscriptionNode> childrenPound;
@@ -54,7 +54,7 @@ class SubscriptionStore @@ -54,7 +54,7 @@ class SubscriptionStore
54 public: 54 public:
55 SubscriptionStore(); 55 SubscriptionStore();
56 56
57 - void addSubscription(Client_p &client, const std::string &topic); 57 + void addSubscription(Client_p &client, const std::string &topic, char qos);
58 void registerClientAndKickExistingOne(Client_p &client); 58 void registerClientAndKickExistingOne(Client_p &client);
59 59
60 void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); 60 void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender);
types.cpp
@@ -15,6 +15,13 @@ SubAck::SubAck(uint16_t packet_id, const std::list&lt;char&gt; &amp;subs_qos_reponses) : @@ -15,6 +15,13 @@ SubAck::SubAck(uint16_t packet_id, const std::list&lt;char&gt; &amp;subs_qos_reponses) :
15 } 15 }
16 } 16 }
17 17
  18 +size_t SubAck::getLengthWithoutFixedHeader() const
  19 +{
  20 + size_t result = responses.size();
  21 + result += 2; // Packet ID
  22 + return result;
  23 +}
  24 +
18 Publish::Publish(const std::string &topic, const std::string payload, char qos) : 25 Publish::Publish(const std::string &topic, const std::string payload, char qos) :
19 topic(topic), 26 topic(topic),
20 payload(payload), 27 payload(payload),
@@ -23,8 +30,7 @@ Publish::Publish(const std::string &amp;topic, const std::string payload, char qos) @@ -23,8 +30,7 @@ Publish::Publish(const std::string &amp;topic, const std::string payload, char qos)
23 30
24 } 31 }
25 32
26 -// Length starting at the variable header, not the fixed header.  
27 -size_t Publish::getLength() const 33 +size_t Publish::getLengthWithoutFixedHeader() const
28 { 34 {
29 int result = topic.length() + payload.length() + 2; 35 int result = topic.length() + payload.length() + 2;
30 36
@@ -33,3 +39,15 @@ size_t Publish::getLength() const @@ -33,3 +39,15 @@ size_t Publish::getLength() const
33 39
34 return result; 40 return result;
35 } 41 }
  42 +
  43 +PubAck::PubAck(uint16_t packet_id) :
  44 + packet_id(packet_id)
  45 +{
  46 +
  47 +}
  48 +
  49 +// Packet has no payload and only a variable header, of length 2.
  50 +size_t PubAck::getLengthWithoutFixedHeader() const
  51 +{
  52 + return 2;
  53 +}
@@ -48,7 +48,7 @@ class ConnAck @@ -48,7 +48,7 @@ class ConnAck
48 public: 48 public:
49 ConnAck(ConnAckReturnCodes return_code); 49 ConnAck(ConnAckReturnCodes return_code);
50 ConnAckReturnCodes return_code; 50 ConnAckReturnCodes return_code;
51 - size_t getLength() const { return 2;} // size of connack is always the same 51 + size_t getLengthWithoutFixedHeader() const { return 2;} // size of connack is always the same
52 }; 52 };
53 53
54 enum class SubAckReturnCodes 54 enum class SubAckReturnCodes
@@ -65,6 +65,7 @@ public: @@ -65,6 +65,7 @@ public:
65 uint16_t packet_id; 65 uint16_t packet_id;
66 std::list<SubAckReturnCodes> responses; 66 std::list<SubAckReturnCodes> responses;
67 SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses); 67 SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses);
  68 + size_t getLengthWithoutFixedHeader() const;
68 }; 69 };
69 70
70 class Publish 71 class Publish
@@ -75,7 +76,15 @@ public: @@ -75,7 +76,15 @@ public:
75 char qos = 0; 76 char qos = 0;
76 bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] 77 bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9]
77 Publish(const std::string &topic, const std::string payload, char qos); 78 Publish(const std::string &topic, const std::string payload, char qos);
78 - size_t getLength() const; 79 + size_t getLengthWithoutFixedHeader() const;
  80 +};
  81 +
  82 +class PubAck
  83 +{
  84 +public:
  85 + PubAck(uint16_t packet_id);
  86 + uint16_t packet_id;
  87 + size_t getLengthWithoutFixedHeader() const;
79 }; 88 };
80 89
81 #endif // TYPES_H 90 #endif // TYPES_H