Commit 760ec58878490864a6f82dfcf1f6858cc19306a8
1 parent
86d368a3
QoS 1, 80%
Also includes fixes to packet parsing that I couldn't make a separate commit for. When it comes to QoS 1, these things are still left, off the top of my head: - vector for qos queue? It helps with ordering and is CPU cache friendly. - Store subscription QoS. - Do retained messages have QoS? - Give session client's name, to access it later.
Showing
11 changed files
with
249 additions
and
39 deletions
client.cpp
| ... | ... | @@ -98,9 +98,9 @@ void Client::writeMqttPacket(const MqttPacket &packet) |
| 98 | 98 | writebuf.doubleSize(); |
| 99 | 99 | } |
| 100 | 100 | |
| 101 | - // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. | |
| 102 | - // TODO: when QoS is implemented, different filtering may be required. | |
| 103 | - if (packet.packetType == PacketType::PUBLISH && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) | |
| 101 | + // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And | |
| 102 | + // QoS packet are queued and limited elsewhere. | |
| 103 | + if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) | |
| 104 | 104 | { |
| 105 | 105 | return; |
| 106 | 106 | } |
| ... | ... | @@ -350,6 +350,16 @@ void Client::setWill(const std::string &topic, const std::string &payload, bool |
| 350 | 350 | this->will_qos = qos; |
| 351 | 351 | } |
| 352 | 352 | |
| 353 | +void Client::assignSession(std::shared_ptr<Session> &session) | |
| 354 | +{ | |
| 355 | + this->session = session; | |
| 356 | +} | |
| 357 | + | |
| 358 | +std::shared_ptr<Session> Client::getSession() | |
| 359 | +{ | |
| 360 | + return this->session; | |
| 361 | +} | |
| 362 | + | |
| 353 | 363 | |
| 354 | 364 | |
| 355 | 365 | ... | ... |
client.h
| ... | ... | @@ -51,6 +51,8 @@ class Client |
| 51 | 51 | ThreadData_p threadData; |
| 52 | 52 | std::mutex writeBufMutex; |
| 53 | 53 | |
| 54 | + std::shared_ptr<Session> session; | |
| 55 | + | |
| 54 | 56 | |
| 55 | 57 | void setReadyForWriting(bool val); |
| 56 | 58 | void setReadyForReading(bool val); |
| ... | ... | @@ -73,6 +75,8 @@ public: |
| 73 | 75 | ThreadData_p getThreadData() { return threadData; } |
| 74 | 76 | std::string &getClientId() { return this->clientid; } |
| 75 | 77 | bool getCleanSession() { return cleanSession; } |
| 78 | + void assignSession(std::shared_ptr<Session> &session); | |
| 79 | + std::shared_ptr<Session> getSession(); | |
| 76 | 80 | |
| 77 | 81 | void writePingResp(); |
| 78 | 82 | void writeMqttPacket(const MqttPacket &packet); | ... | ... |
forward_declarations.h
mqttpacket.cpp
| ... | ... | @@ -36,29 +36,36 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt |
| 36 | 36 | pos += fixed_header_length; |
| 37 | 37 | } |
| 38 | 38 | |
| 39 | +// This is easier than using the copy constructor publically, because then I have to keep maintaining a functioning copy constructor. | |
| 40 | +// Returning shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. | |
| 41 | +std::shared_ptr<MqttPacket> MqttPacket::getCopy() const | |
| 42 | +{ | |
| 43 | + std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this)); | |
| 44 | + copyPacket->sender.reset(); | |
| 45 | + return copyPacket; | |
| 46 | +} | |
| 47 | + | |
| 39 | 48 | // This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. |
| 40 | 49 | MqttPacket::MqttPacket(const ConnAck &connAck) : |
| 41 | - bites(connAck.getLength() + 2) | |
| 50 | + bites(connAck.getLengthWithoutFixedHeader() + 2) | |
| 42 | 51 | { |
| 43 | 52 | fixed_header_length = 2; |
| 44 | 53 | packetType = PacketType::CONNACK; |
| 45 | 54 | char first_byte = static_cast<char>(packetType) << 4; |
| 46 | 55 | writeByte(first_byte); |
| 47 | 56 | writeByte(2); // length is always 2. |
| 48 | - writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. | |
| 57 | + writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. TODO: make that | |
| 49 | 58 | writeByte(static_cast<char>(connAck.return_code)); |
| 50 | 59 | |
| 51 | 60 | } |
| 52 | 61 | |
| 53 | 62 | MqttPacket::MqttPacket(const SubAck &subAck) : |
| 54 | - bites(3) | |
| 63 | + bites(subAck.getLengthWithoutFixedHeader()) | |
| 55 | 64 | { |
| 56 | - fixed_header_length = 2; // TODO: this is wrong, pending implementation of the new method in SubAck | |
| 57 | 65 | packetType = PacketType::SUBACK; |
| 58 | - char first_byte = static_cast<char>(packetType) << 4; | |
| 59 | - writeByte(first_byte); | |
| 60 | - writeByte((subAck.packet_id & 0xF0) >> 8); | |
| 61 | - writeByte(subAck.packet_id & 0x0F); | |
| 66 | + first_byte = static_cast<char>(packetType) << 4; | |
| 67 | + writeByte((subAck.packet_id & 0xFF00) >> 8); | |
| 68 | + writeByte(subAck.packet_id & 0x00FF); | |
| 62 | 69 | |
| 63 | 70 | std::vector<char> returnList; |
| 64 | 71 | for (SubAckReturnCodes code : subAck.responses) |
| ... | ... | @@ -66,12 +73,12 @@ MqttPacket::MqttPacket(const SubAck &subAck) : |
| 66 | 73 | returnList.push_back(static_cast<char>(code)); |
| 67 | 74 | } |
| 68 | 75 | |
| 69 | - bites.insert(bites.end(), returnList.begin(), returnList.end()); | |
| 70 | - bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length | |
| 76 | + writeBytes(&returnList[0], returnList.size()); | |
| 77 | + calculateRemainingLength(); | |
| 71 | 78 | } |
| 72 | 79 | |
| 73 | 80 | MqttPacket::MqttPacket(const Publish &publish) : |
| 74 | - bites(publish.getLength()) | |
| 81 | + bites(publish.getLengthWithoutFixedHeader()) | |
| 75 | 82 | { |
| 76 | 83 | if (publish.topic.length() > 0xFFFF) |
| 77 | 84 | { |
| ... | ... | @@ -83,8 +90,8 @@ MqttPacket::MqttPacket(const Publish &publish) : |
| 83 | 90 | first_byte |= (publish.qos << 1); |
| 84 | 91 | first_byte |= (static_cast<char>(publish.retain) & 0b00000001); |
| 85 | 92 | |
| 86 | - char topicLenMSB = (publish.topic.length() & 0xF0) >> 8; | |
| 87 | - char topicLenLSB = publish.topic.length() & 0x0F; | |
| 93 | + char topicLenMSB = (publish.topic.length() & 0xFF00) >> 8; | |
| 94 | + char topicLenLSB = publish.topic.length() & 0x00FF; | |
| 88 | 95 | writeByte(topicLenMSB); |
| 89 | 96 | writeByte(topicLenLSB); |
| 90 | 97 | writeBytes(publish.topic.c_str(), publish.topic.length()); |
| ... | ... | @@ -98,6 +105,21 @@ MqttPacket::MqttPacket(const Publish &publish) : |
| 98 | 105 | calculateRemainingLength(); |
| 99 | 106 | } |
| 100 | 107 | |
| 108 | +// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. | |
| 109 | +MqttPacket::MqttPacket(const PubAck &pubAck) : | |
| 110 | + bites(pubAck.getLengthWithoutFixedHeader() + 2) | |
| 111 | +{ | |
| 112 | + fixed_header_length = 2; // This is the cheat part mentioned above. We're not calculating it dynamically. | |
| 113 | + packetType = PacketType::PUBACK; | |
| 114 | + first_byte = static_cast<char>(packetType) << 4; | |
| 115 | + writeByte(first_byte); | |
| 116 | + writeByte(2); // length is always 2. | |
| 117 | + char topicLenMSB = (pubAck.packet_id & 0xFF00) >> 8; | |
| 118 | + char topicLenLSB = (pubAck.packet_id & 0x00FF); | |
| 119 | + writeByte(topicLenMSB); | |
| 120 | + writeByte(topicLenLSB); | |
| 121 | +} | |
| 122 | + | |
| 101 | 123 | void MqttPacket::handle() |
| 102 | 124 | { |
| 103 | 125 | if (packetType != PacketType::CONNECT) |
| ... | ... | @@ -118,6 +140,8 @@ void MqttPacket::handle() |
| 118 | 140 | handleSubscribe(); |
| 119 | 141 | else if (packetType == PacketType::PUBLISH) |
| 120 | 142 | handlePublish(); |
| 143 | + else if (packetType == PacketType::PUBACK) | |
| 144 | + handlePubAck(); | |
| 121 | 145 | } |
| 122 | 146 | |
| 123 | 147 | void MqttPacket::handleConnect() |
| ... | ... | @@ -268,10 +292,8 @@ void MqttPacket::handleSubscribe() |
| 268 | 292 | uint16_t topicLength = readTwoBytesToUInt16(); |
| 269 | 293 | std::string topic(readBytes(topicLength), topicLength); |
| 270 | 294 | char qos = readByte(); |
| 271 | - if (qos > 0) | |
| 272 | - throw NotImplementedException("QoS not implemented"); | |
| 273 | 295 | logger->logf(LOG_INFO, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); |
| 274 | - sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic); | |
| 296 | + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, qos); | |
| 275 | 297 | subs_reponse_codes.push_back(qos); |
| 276 | 298 | } |
| 277 | 299 | |
| ... | ... | @@ -293,6 +315,7 @@ void MqttPacket::handlePublish() |
| 293 | 315 | |
| 294 | 316 | if (qos == 3) |
| 295 | 317 | throw ProtocolError("QoS 3 is a protocol violation."); |
| 318 | + this->qos = qos; | |
| 296 | 319 | |
| 297 | 320 | std::string topic(readBytes(variable_header_length), variable_header_length); |
| 298 | 321 | |
| ... | ... | @@ -310,8 +333,19 @@ void MqttPacket::handlePublish() |
| 310 | 333 | |
| 311 | 334 | if (qos) |
| 312 | 335 | { |
| 313 | - throw ProtocolError("Qos not implemented."); | |
| 336 | + if (qos > 1) | |
| 337 | + throw ProtocolError("Qos > 1 not implemented."); | |
| 338 | + packet_id_pos = pos; | |
| 314 | 339 | uint16_t packet_id = readTwoBytesToUInt16(); |
| 340 | + | |
| 341 | + // Clear the packet ID from this packet, because each new publish must get a new one. It's more of a debug precaution. | |
| 342 | + pos -= 2; | |
| 343 | + char zero[2]; zero[0] = 0; zero[1] = 0; | |
| 344 | + writeBytes(zero, 2); | |
| 345 | + | |
| 346 | + PubAck pubAck(packet_id); | |
| 347 | + MqttPacket response(pubAck); | |
| 348 | + sender->writeMqttPacket(response); | |
| 315 | 349 | } |
| 316 | 350 | |
| 317 | 351 | if (retain) |
| ... | ... | @@ -330,6 +364,12 @@ void MqttPacket::handlePublish() |
| 330 | 364 | sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender); |
| 331 | 365 | } |
| 332 | 366 | |
| 367 | +void MqttPacket::handlePubAck() | |
| 368 | +{ | |
| 369 | + uint16_t packet_id = readTwoBytesToUInt16(); | |
| 370 | + sender->getSession()->clearQosMessage(packet_id); | |
| 371 | +} | |
| 372 | + | |
| 333 | 373 | void MqttPacket::calculateRemainingLength() |
| 334 | 374 | { |
| 335 | 375 | assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. |
| ... | ... | @@ -356,6 +396,40 @@ RemainingLength MqttPacket::getRemainingLength() const |
| 356 | 396 | return remainingLength; |
| 357 | 397 | } |
| 358 | 398 | |
| 399 | +void MqttPacket::setPacketId(uint16_t packet_id) | |
| 400 | +{ | |
| 401 | + // In other words, we assume that this code can only be called on packets of which we have all the bytes, including fixed header. | |
| 402 | + assert(fixed_header_length > 0); | |
| 403 | + assert(packetType == PacketType::PUBLISH); | |
| 404 | + assert(qos > 0); | |
| 405 | + | |
| 406 | + pos = packet_id_pos; | |
| 407 | + | |
| 408 | + char topicLenMSB = (packet_id & 0xFF00) >> 8; | |
| 409 | + char topicLenLSB = (packet_id & 0x00FF); | |
| 410 | + writeByte(topicLenMSB); | |
| 411 | + writeByte(topicLenLSB); | |
| 412 | +} | |
| 413 | + | |
| 414 | +// If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything? | |
| 415 | +void MqttPacket::setDuplicate() | |
| 416 | +{ | |
| 417 | + // In other words, we assume that this code can only be called on packets of which we have all the bytes, including fixed header. | |
| 418 | + assert(fixed_header_length > 0); | |
| 419 | + assert(packetType == PacketType::PUBLISH); | |
| 420 | + assert(qos > 0); | |
| 421 | + | |
| 422 | + char byte1 = bites[0]; | |
| 423 | + byte1 |= 0b00001000; | |
| 424 | + pos = 0; | |
| 425 | + writeByte(byte1); | |
| 426 | +} | |
| 427 | + | |
| 428 | +size_t MqttPacket::getTotalMemoryFootprint() | |
| 429 | +{ | |
| 430 | + return bites.size() + sizeof(MqttPacket); | |
| 431 | +} | |
| 432 | + | |
| 359 | 433 | size_t MqttPacket::getSizeIncludingNonPresentHeader() const |
| 360 | 434 | { |
| 361 | 435 | size_t total = bites.size(); | ... | ... |
mqttpacket.h
| ... | ... | @@ -28,9 +28,11 @@ class MqttPacket |
| 28 | 28 | std::vector<char> bites; |
| 29 | 29 | size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. |
| 30 | 30 | RemainingLength remainingLength; |
| 31 | + char qos = 0; | |
| 31 | 32 | Client_p sender; |
| 32 | 33 | char first_byte = 0; |
| 33 | 34 | size_t pos = 0; |
| 35 | + size_t packet_id_pos = 0; | |
| 34 | 36 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 35 | 37 | Logger *logger = Logger::getInstance(); |
| 36 | 38 | |
| ... | ... | @@ -43,17 +45,20 @@ class MqttPacket |
| 43 | 45 | |
| 44 | 46 | void calculateRemainingLength(); |
| 45 | 47 | |
| 48 | + MqttPacket(const MqttPacket &other) = default; | |
| 46 | 49 | public: |
| 47 | 50 | PacketType packetType = PacketType::Reserved; |
| 48 | 51 | MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets. |
| 49 | 52 | |
| 50 | 53 | MqttPacket(MqttPacket &&other) = default; |
| 51 | - MqttPacket(const MqttPacket &other) = delete; | |
| 54 | + | |
| 55 | + std::shared_ptr<MqttPacket> getCopy() const; | |
| 52 | 56 | |
| 53 | 57 | // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. |
| 54 | 58 | MqttPacket(const ConnAck &connAck); |
| 55 | 59 | MqttPacket(const SubAck &subAck); |
| 56 | 60 | MqttPacket(const Publish &publish); |
| 61 | + MqttPacket(const PubAck &pubAck); | |
| 57 | 62 | |
| 58 | 63 | void handle(); |
| 59 | 64 | void handleConnect(); |
| ... | ... | @@ -61,16 +66,19 @@ public: |
| 61 | 66 | void handleSubscribe(); |
| 62 | 67 | void handlePing(); |
| 63 | 68 | void handlePublish(); |
| 69 | + void handlePubAck(); | |
| 64 | 70 | |
| 65 | 71 | size_t getSizeIncludingNonPresentHeader() const; |
| 66 | 72 | const std::vector<char> &getBites() const { return bites; } |
| 67 | - | |
| 73 | + char getQos() const { return qos; } | |
| 68 | 74 | Client_p getSender() const; |
| 69 | 75 | void setSender(const Client_p &value); |
| 70 | - | |
| 71 | 76 | bool containsFixedHeader() const; |
| 72 | 77 | char getFirstByte() const; |
| 73 | 78 | RemainingLength getRemainingLength() const; |
| 79 | + void setPacketId(uint16_t packet_id); | |
| 80 | + void setDuplicate(); | |
| 81 | + size_t getTotalMemoryFootprint(); | |
| 74 | 82 | }; |
| 75 | 83 | |
| 76 | 84 | #endif // MQTTPACKET_H | ... | ... |
session.cpp
| 1 | +#include "cassert" | |
| 2 | + | |
| 1 | 3 | #include "session.h" |
| 4 | +#include "client.h" | |
| 2 | 5 | |
| 3 | 6 | Session::Session() |
| 4 | 7 | { |
| ... | ... | @@ -19,3 +22,74 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) |
| 19 | 22 | { |
| 20 | 23 | this->client = client; |
| 21 | 24 | } |
| 25 | + | |
| 26 | +void Session::writePacket(const MqttPacket &packet) | |
| 27 | +{ | |
| 28 | + const char qos = packet.getQos(); | |
| 29 | + | |
| 30 | + if (qos == 0) | |
| 31 | + { | |
| 32 | + if (!clientDisconnected()) | |
| 33 | + { | |
| 34 | + Client_p c = makeSharedClient(); | |
| 35 | + c->writeMqttPacketAndBlameThisClient(packet); | |
| 36 | + } | |
| 37 | + } | |
| 38 | + else if (qos == 1) | |
| 39 | + { | |
| 40 | + std::shared_ptr<MqttPacket> copyPacket = packet.getCopy(); | |
| 41 | + std::unique_lock<std::mutex> locker(qosQueueMutex); | |
| 42 | + if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | |
| 43 | + { | |
| 44 | + logger->logf(LOG_WARNING, "Dropping QoS message for client 'TODO', because its QoS buffers were full."); | |
| 45 | + return; | |
| 46 | + } | |
| 47 | + const uint16_t pid = nextPacketId++; | |
| 48 | + copyPacket->setPacketId(pid); | |
| 49 | + qosPacketQueue[pid] = copyPacket; | |
| 50 | + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | |
| 51 | + locker.unlock(); | |
| 52 | + | |
| 53 | + if (!clientDisconnected()) | |
| 54 | + { | |
| 55 | + Client_p c = makeSharedClient(); | |
| 56 | + c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); | |
| 57 | + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 58 | + } | |
| 59 | + } | |
| 60 | +} | |
| 61 | + | |
| 62 | +void Session::clearQosMessage(uint16_t packet_id) | |
| 63 | +{ | |
| 64 | + std::lock_guard<std::mutex> locker(qosQueueMutex); | |
| 65 | + auto it = qosPacketQueue.find(packet_id); | |
| 66 | + if (it != qosPacketQueue.end()) | |
| 67 | + { | |
| 68 | + std::shared_ptr<MqttPacket> packet = it->second; | |
| 69 | + qosPacketQueue.erase(it); | |
| 70 | + qosQueueBytes -= packet->getTotalMemoryFootprint(); | |
| 71 | + assert(qosQueueBytes >= 0); | |
| 72 | + if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. | |
| 73 | + qosQueueBytes = 0; | |
| 74 | + } | |
| 75 | +} | |
| 76 | + | |
| 77 | +// [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any | |
| 78 | +// unacknowledged PUBLISH Packets (where QoS > 0) and PUBREL Packets using their original Packet Identifiers. This | |
| 79 | +// is the only circumstance where a Client or Server is REQUIRED to redeliver messages." | |
| 80 | +// | |
| 81 | +// There is a bit of a hole there, I think. When we write out a packet to a receiver, it may decide to drop it, if its buffers | |
| 82 | +// are full, for instance. We are not required to (periodically) retry. TODO Perhaps I will implement that retry anyway. | |
| 83 | +void Session::sendPendingQosMessages() | |
| 84 | +{ | |
| 85 | + if (!clientDisconnected()) | |
| 86 | + { | |
| 87 | + Client_p c = makeSharedClient(); | |
| 88 | + std::lock_guard<std::mutex> locker(qosQueueMutex); | |
| 89 | + for (auto &qosMessage : qosPacketQueue) // TODO: wrong: the order must be maintained. Combine the fix with that vector idea | |
| 90 | + { | |
| 91 | + c->writeMqttPacketAndBlameThisClient(*qosMessage.second.get()); | |
| 92 | + qosMessage.second->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | |
| 93 | + } | |
| 94 | + } | |
| 95 | +} | ... | ... |
session.h
| ... | ... | @@ -2,13 +2,24 @@ |
| 2 | 2 | #define SESSION_H |
| 3 | 3 | |
| 4 | 4 | #include <memory> |
| 5 | +#include <unordered_map> | |
| 6 | +#include <mutex> | |
| 5 | 7 | |
| 6 | -class Client; | |
| 8 | +#include "forward_declarations.h" | |
| 9 | +#include "logger.h" | |
| 10 | + | |
| 11 | +// TODO make settings | |
| 12 | +#define MAX_QOS_MSG_PENDING_PER_CLIENT 32 | |
| 13 | +#define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096 | |
| 7 | 14 | |
| 8 | 15 | class Session |
| 9 | 16 | { |
| 10 | 17 | std::weak_ptr<Client> client; |
| 11 | - // TODO: qos message queue, as some kind of movable pointer. | |
| 18 | + std::unordered_map<uint16_t, std::shared_ptr<MqttPacket>> qosPacketQueue; // TODO: because the max queue length should remain low-ish, perhaps a vector is better here. | |
| 19 | + std::mutex qosQueueMutex; | |
| 20 | + uint16_t nextPacketId = 0; | |
| 21 | + ssize_t qosQueueBytes = 0; | |
| 22 | + Logger *logger = Logger::getInstance(); | |
| 12 | 23 | public: |
| 13 | 24 | Session(); |
| 14 | 25 | Session(const Session &other) = delete; |
| ... | ... | @@ -17,6 +28,9 @@ public: |
| 17 | 28 | bool clientDisconnected() const; |
| 18 | 29 | std::shared_ptr<Client> makeSharedClient() const; |
| 19 | 30 | void assignActiveConnection(std::shared_ptr<Client> &client); |
| 31 | + void writePacket(const MqttPacket &packet); | |
| 32 | + void clearQosMessage(uint16_t packet_id); | |
| 33 | + void sendPendingQosMessages(); | |
| 20 | 34 | }; |
| 21 | 35 | |
| 22 | 36 | #endif // SESSION_H | ... | ... |
subscriptionstore.cpp
| ... | ... | @@ -18,7 +18,7 @@ SubscriptionStore::SubscriptionStore() : |
| 18 | 18 | |
| 19 | 19 | } |
| 20 | 20 | |
| 21 | -void SubscriptionStore::addSubscription(Client_p &client, const std::string &topic) | |
| 21 | +void SubscriptionStore::addSubscription(Client_p &client, const std::string &topic, char qos) | |
| 22 | 22 | { |
| 23 | 23 | const std::list<std::string> subtopics = split(topic, '/'); |
| 24 | 24 | |
| ... | ... | @@ -89,10 +89,13 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) |
| 89 | 89 | if (!session || client->getCleanSession()) |
| 90 | 90 | { |
| 91 | 91 | session.reset(new Session()); |
| 92 | + | |
| 92 | 93 | sessionsById[client->getClientId()] = session; |
| 93 | 94 | } |
| 94 | 95 | |
| 95 | 96 | session->assignActiveConnection(client); |
| 97 | + client->assignSession(session); | |
| 98 | + session->sendPendingQosMessages(); | |
| 96 | 99 | } |
| 97 | 100 | |
| 98 | 101 | // TODO: should I implement cache, this needs to be changed to returning a list of clients. |
| ... | ... | @@ -103,12 +106,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st |
| 103 | 106 | if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. |
| 104 | 107 | { |
| 105 | 108 | const std::shared_ptr<Session> session = session_weak.lock(); |
| 106 | - | |
| 107 | - if (!session->clientDisconnected()) | |
| 108 | - { | |
| 109 | - Client_p c = session->makeSharedClient(); | |
| 110 | - c->writeMqttPacketAndBlameThisClient(packet); | |
| 111 | - } | |
| 109 | + session->writePacket(packet); | |
| 112 | 110 | } |
| 113 | 111 | } |
| 114 | 112 | } |
| ... | ... | @@ -170,7 +168,7 @@ void SubscriptionStore::giveClientRetainedMessages(Client_p &client, const std:: |
| 170 | 168 | const MqttPacket packet(publish); |
| 171 | 169 | |
| 172 | 170 | if (topicsMatch(subscribe_topic, rm.topic)) |
| 173 | - client->writeMqttPacket(packet); | |
| 171 | + client->writeMqttPacket(packet); // TODO: I think this needs to be session, not client, and then I can store it if it's QoS? I need to research how retain+qos works | |
| 174 | 172 | } |
| 175 | 173 | } |
| 176 | 174 | ... | ... |
subscriptionstore.h
| ... | ... | @@ -30,7 +30,7 @@ public: |
| 30 | 30 | SubscriptionNode(const SubscriptionNode &node) = delete; |
| 31 | 31 | SubscriptionNode(SubscriptionNode &&node) = delete; |
| 32 | 32 | |
| 33 | - std::forward_list<std::weak_ptr<Session>> subscribers; // The idea is to store subscriptions by client id, to support persistent sessions. | |
| 33 | + std::forward_list<std::weak_ptr<Session>> subscribers; // TODO: a subscription class, with qos | |
| 34 | 34 | std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children; |
| 35 | 35 | std::unique_ptr<SubscriptionNode> childrenPlus; |
| 36 | 36 | std::unique_ptr<SubscriptionNode> childrenPound; |
| ... | ... | @@ -54,7 +54,7 @@ class SubscriptionStore |
| 54 | 54 | public: |
| 55 | 55 | SubscriptionStore(); |
| 56 | 56 | |
| 57 | - void addSubscription(Client_p &client, const std::string &topic); | |
| 57 | + void addSubscription(Client_p &client, const std::string &topic, char qos); | |
| 58 | 58 | void registerClientAndKickExistingOne(Client_p &client); |
| 59 | 59 | |
| 60 | 60 | void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); | ... | ... |
types.cpp
| ... | ... | @@ -15,6 +15,13 @@ SubAck::SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses) : |
| 15 | 15 | } |
| 16 | 16 | } |
| 17 | 17 | |
| 18 | +size_t SubAck::getLengthWithoutFixedHeader() const | |
| 19 | +{ | |
| 20 | + size_t result = responses.size(); | |
| 21 | + result += 2; // Packet ID | |
| 22 | + return result; | |
| 23 | +} | |
| 24 | + | |
| 18 | 25 | Publish::Publish(const std::string &topic, const std::string payload, char qos) : |
| 19 | 26 | topic(topic), |
| 20 | 27 | payload(payload), |
| ... | ... | @@ -23,8 +30,7 @@ Publish::Publish(const std::string &topic, const std::string payload, char qos) |
| 23 | 30 | |
| 24 | 31 | } |
| 25 | 32 | |
| 26 | -// Length starting at the variable header, not the fixed header. | |
| 27 | -size_t Publish::getLength() const | |
| 33 | +size_t Publish::getLengthWithoutFixedHeader() const | |
| 28 | 34 | { |
| 29 | 35 | int result = topic.length() + payload.length() + 2; |
| 30 | 36 | |
| ... | ... | @@ -33,3 +39,15 @@ size_t Publish::getLength() const |
| 33 | 39 | |
| 34 | 40 | return result; |
| 35 | 41 | } |
| 42 | + | |
| 43 | +PubAck::PubAck(uint16_t packet_id) : | |
| 44 | + packet_id(packet_id) | |
| 45 | +{ | |
| 46 | + | |
| 47 | +} | |
| 48 | + | |
| 49 | +// Packet has no payload and only a variable header, of length 2. | |
| 50 | +size_t PubAck::getLengthWithoutFixedHeader() const | |
| 51 | +{ | |
| 52 | + return 2; | |
| 53 | +} | ... | ... |
types.h
| ... | ... | @@ -48,7 +48,7 @@ class ConnAck |
| 48 | 48 | public: |
| 49 | 49 | ConnAck(ConnAckReturnCodes return_code); |
| 50 | 50 | ConnAckReturnCodes return_code; |
| 51 | - size_t getLength() const { return 2;} // size of connack is always the same | |
| 51 | + size_t getLengthWithoutFixedHeader() const { return 2;} // size of connack is always the same | |
| 52 | 52 | }; |
| 53 | 53 | |
| 54 | 54 | enum class SubAckReturnCodes |
| ... | ... | @@ -65,6 +65,7 @@ public: |
| 65 | 65 | uint16_t packet_id; |
| 66 | 66 | std::list<SubAckReturnCodes> responses; |
| 67 | 67 | SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses); |
| 68 | + size_t getLengthWithoutFixedHeader() const; | |
| 68 | 69 | }; |
| 69 | 70 | |
| 70 | 71 | class Publish |
| ... | ... | @@ -75,7 +76,15 @@ public: |
| 75 | 76 | char qos = 0; |
| 76 | 77 | bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] |
| 77 | 78 | Publish(const std::string &topic, const std::string payload, char qos); |
| 78 | - size_t getLength() const; | |
| 79 | + size_t getLengthWithoutFixedHeader() const; | |
| 80 | +}; | |
| 81 | + | |
| 82 | +class PubAck | |
| 83 | +{ | |
| 84 | +public: | |
| 85 | + PubAck(uint16_t packet_id); | |
| 86 | + uint16_t packet_id; | |
| 87 | + size_t getLengthWithoutFixedHeader() const; | |
| 79 | 88 | }; |
| 80 | 89 | |
| 81 | 90 | #endif // TYPES_H | ... | ... |