Commit 760ec58878490864a6f82dfcf1f6858cc19306a8
1 parent
86d368a3
QoS 1, 80%
Also includes fixes to packet parsing that I couldn't make a separate commit for. When it comes to QoS 1, these things are still left, off the top of my head: - vector for qos queue? It helps with ordering and is CPU cache friendly. - Store subscription QoS. - Do retained messages have QoS? - Give session client's name, to access it later.
Showing
11 changed files
with
249 additions
and
39 deletions
client.cpp
| @@ -98,9 +98,9 @@ void Client::writeMqttPacket(const MqttPacket &packet) | @@ -98,9 +98,9 @@ void Client::writeMqttPacket(const MqttPacket &packet) | ||
| 98 | writebuf.doubleSize(); | 98 | writebuf.doubleSize(); |
| 99 | } | 99 | } |
| 100 | 100 | ||
| 101 | - // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. | ||
| 102 | - // TODO: when QoS is implemented, different filtering may be required. | ||
| 103 | - if (packet.packetType == PacketType::PUBLISH && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) | 101 | + // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And |
| 102 | + // QoS packet are queued and limited elsewhere. | ||
| 103 | + if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace()) | ||
| 104 | { | 104 | { |
| 105 | return; | 105 | return; |
| 106 | } | 106 | } |
| @@ -350,6 +350,16 @@ void Client::setWill(const std::string &topic, const std::string &payload, bool | @@ -350,6 +350,16 @@ void Client::setWill(const std::string &topic, const std::string &payload, bool | ||
| 350 | this->will_qos = qos; | 350 | this->will_qos = qos; |
| 351 | } | 351 | } |
| 352 | 352 | ||
| 353 | +void Client::assignSession(std::shared_ptr<Session> &session) | ||
| 354 | +{ | ||
| 355 | + this->session = session; | ||
| 356 | +} | ||
| 357 | + | ||
| 358 | +std::shared_ptr<Session> Client::getSession() | ||
| 359 | +{ | ||
| 360 | + return this->session; | ||
| 361 | +} | ||
| 362 | + | ||
| 353 | 363 | ||
| 354 | 364 | ||
| 355 | 365 |
client.h
| @@ -51,6 +51,8 @@ class Client | @@ -51,6 +51,8 @@ class Client | ||
| 51 | ThreadData_p threadData; | 51 | ThreadData_p threadData; |
| 52 | std::mutex writeBufMutex; | 52 | std::mutex writeBufMutex; |
| 53 | 53 | ||
| 54 | + std::shared_ptr<Session> session; | ||
| 55 | + | ||
| 54 | 56 | ||
| 55 | void setReadyForWriting(bool val); | 57 | void setReadyForWriting(bool val); |
| 56 | void setReadyForReading(bool val); | 58 | void setReadyForReading(bool val); |
| @@ -73,6 +75,8 @@ public: | @@ -73,6 +75,8 @@ public: | ||
| 73 | ThreadData_p getThreadData() { return threadData; } | 75 | ThreadData_p getThreadData() { return threadData; } |
| 74 | std::string &getClientId() { return this->clientid; } | 76 | std::string &getClientId() { return this->clientid; } |
| 75 | bool getCleanSession() { return cleanSession; } | 77 | bool getCleanSession() { return cleanSession; } |
| 78 | + void assignSession(std::shared_ptr<Session> &session); | ||
| 79 | + std::shared_ptr<Session> getSession(); | ||
| 76 | 80 | ||
| 77 | void writePingResp(); | 81 | void writePingResp(); |
| 78 | void writeMqttPacket(const MqttPacket &packet); | 82 | void writeMqttPacket(const MqttPacket &packet); |
forward_declarations.h
| @@ -9,6 +9,7 @@ class ThreadData; | @@ -9,6 +9,7 @@ class ThreadData; | ||
| 9 | typedef std::shared_ptr<ThreadData> ThreadData_p; | 9 | typedef std::shared_ptr<ThreadData> ThreadData_p; |
| 10 | class MqttPacket; | 10 | class MqttPacket; |
| 11 | class SubscriptionStore; | 11 | class SubscriptionStore; |
| 12 | +class Session; | ||
| 12 | 13 | ||
| 13 | 14 | ||
| 14 | #endif // FORWARD_DECLARATIONS_H | 15 | #endif // FORWARD_DECLARATIONS_H |
mqttpacket.cpp
| @@ -36,29 +36,36 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt | @@ -36,29 +36,36 @@ MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_lengt | ||
| 36 | pos += fixed_header_length; | 36 | pos += fixed_header_length; |
| 37 | } | 37 | } |
| 38 | 38 | ||
| 39 | +// This is easier than using the copy constructor publically, because then I have to keep maintaining a functioning copy constructor. | ||
| 40 | +// Returning shared pointer because that's typically how we need it; we only need to copy it if we pass it around as shared resource. | ||
| 41 | +std::shared_ptr<MqttPacket> MqttPacket::getCopy() const | ||
| 42 | +{ | ||
| 43 | + std::shared_ptr<MqttPacket> copyPacket(new MqttPacket(*this)); | ||
| 44 | + copyPacket->sender.reset(); | ||
| 45 | + return copyPacket; | ||
| 46 | +} | ||
| 47 | + | ||
| 39 | // This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. | 48 | // This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. |
| 40 | MqttPacket::MqttPacket(const ConnAck &connAck) : | 49 | MqttPacket::MqttPacket(const ConnAck &connAck) : |
| 41 | - bites(connAck.getLength() + 2) | 50 | + bites(connAck.getLengthWithoutFixedHeader() + 2) |
| 42 | { | 51 | { |
| 43 | fixed_header_length = 2; | 52 | fixed_header_length = 2; |
| 44 | packetType = PacketType::CONNACK; | 53 | packetType = PacketType::CONNACK; |
| 45 | char first_byte = static_cast<char>(packetType) << 4; | 54 | char first_byte = static_cast<char>(packetType) << 4; |
| 46 | writeByte(first_byte); | 55 | writeByte(first_byte); |
| 47 | writeByte(2); // length is always 2. | 56 | writeByte(2); // length is always 2. |
| 48 | - writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. | 57 | + writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. TODO: make that |
| 49 | writeByte(static_cast<char>(connAck.return_code)); | 58 | writeByte(static_cast<char>(connAck.return_code)); |
| 50 | 59 | ||
| 51 | } | 60 | } |
| 52 | 61 | ||
| 53 | MqttPacket::MqttPacket(const SubAck &subAck) : | 62 | MqttPacket::MqttPacket(const SubAck &subAck) : |
| 54 | - bites(3) | 63 | + bites(subAck.getLengthWithoutFixedHeader()) |
| 55 | { | 64 | { |
| 56 | - fixed_header_length = 2; // TODO: this is wrong, pending implementation of the new method in SubAck | ||
| 57 | packetType = PacketType::SUBACK; | 65 | packetType = PacketType::SUBACK; |
| 58 | - char first_byte = static_cast<char>(packetType) << 4; | ||
| 59 | - writeByte(first_byte); | ||
| 60 | - writeByte((subAck.packet_id & 0xF0) >> 8); | ||
| 61 | - writeByte(subAck.packet_id & 0x0F); | 66 | + first_byte = static_cast<char>(packetType) << 4; |
| 67 | + writeByte((subAck.packet_id & 0xFF00) >> 8); | ||
| 68 | + writeByte(subAck.packet_id & 0x00FF); | ||
| 62 | 69 | ||
| 63 | std::vector<char> returnList; | 70 | std::vector<char> returnList; |
| 64 | for (SubAckReturnCodes code : subAck.responses) | 71 | for (SubAckReturnCodes code : subAck.responses) |
| @@ -66,12 +73,12 @@ MqttPacket::MqttPacket(const SubAck &subAck) : | @@ -66,12 +73,12 @@ MqttPacket::MqttPacket(const SubAck &subAck) : | ||
| 66 | returnList.push_back(static_cast<char>(code)); | 73 | returnList.push_back(static_cast<char>(code)); |
| 67 | } | 74 | } |
| 68 | 75 | ||
| 69 | - bites.insert(bites.end(), returnList.begin(), returnList.end()); | ||
| 70 | - bites[1] = returnList.size() + 1; // TODO: make some generic way of calculating the header and use the multi-byte length | 76 | + writeBytes(&returnList[0], returnList.size()); |
| 77 | + calculateRemainingLength(); | ||
| 71 | } | 78 | } |
| 72 | 79 | ||
| 73 | MqttPacket::MqttPacket(const Publish &publish) : | 80 | MqttPacket::MqttPacket(const Publish &publish) : |
| 74 | - bites(publish.getLength()) | 81 | + bites(publish.getLengthWithoutFixedHeader()) |
| 75 | { | 82 | { |
| 76 | if (publish.topic.length() > 0xFFFF) | 83 | if (publish.topic.length() > 0xFFFF) |
| 77 | { | 84 | { |
| @@ -83,8 +90,8 @@ MqttPacket::MqttPacket(const Publish &publish) : | @@ -83,8 +90,8 @@ MqttPacket::MqttPacket(const Publish &publish) : | ||
| 83 | first_byte |= (publish.qos << 1); | 90 | first_byte |= (publish.qos << 1); |
| 84 | first_byte |= (static_cast<char>(publish.retain) & 0b00000001); | 91 | first_byte |= (static_cast<char>(publish.retain) & 0b00000001); |
| 85 | 92 | ||
| 86 | - char topicLenMSB = (publish.topic.length() & 0xF0) >> 8; | ||
| 87 | - char topicLenLSB = publish.topic.length() & 0x0F; | 93 | + char topicLenMSB = (publish.topic.length() & 0xFF00) >> 8; |
| 94 | + char topicLenLSB = publish.topic.length() & 0x00FF; | ||
| 88 | writeByte(topicLenMSB); | 95 | writeByte(topicLenMSB); |
| 89 | writeByte(topicLenLSB); | 96 | writeByte(topicLenLSB); |
| 90 | writeBytes(publish.topic.c_str(), publish.topic.length()); | 97 | writeBytes(publish.topic.c_str(), publish.topic.length()); |
| @@ -98,6 +105,21 @@ MqttPacket::MqttPacket(const Publish &publish) : | @@ -98,6 +105,21 @@ MqttPacket::MqttPacket(const Publish &publish) : | ||
| 98 | calculateRemainingLength(); | 105 | calculateRemainingLength(); |
| 99 | } | 106 | } |
| 100 | 107 | ||
| 108 | +// This constructor cheats and doesn't use calculateRemainingLength, because it's always the same. It allocates enough space in the vector. | ||
| 109 | +MqttPacket::MqttPacket(const PubAck &pubAck) : | ||
| 110 | + bites(pubAck.getLengthWithoutFixedHeader() + 2) | ||
| 111 | +{ | ||
| 112 | + fixed_header_length = 2; // This is the cheat part mentioned above. We're not calculating it dynamically. | ||
| 113 | + packetType = PacketType::PUBACK; | ||
| 114 | + first_byte = static_cast<char>(packetType) << 4; | ||
| 115 | + writeByte(first_byte); | ||
| 116 | + writeByte(2); // length is always 2. | ||
| 117 | + char topicLenMSB = (pubAck.packet_id & 0xFF00) >> 8; | ||
| 118 | + char topicLenLSB = (pubAck.packet_id & 0x00FF); | ||
| 119 | + writeByte(topicLenMSB); | ||
| 120 | + writeByte(topicLenLSB); | ||
| 121 | +} | ||
| 122 | + | ||
| 101 | void MqttPacket::handle() | 123 | void MqttPacket::handle() |
| 102 | { | 124 | { |
| 103 | if (packetType != PacketType::CONNECT) | 125 | if (packetType != PacketType::CONNECT) |
| @@ -118,6 +140,8 @@ void MqttPacket::handle() | @@ -118,6 +140,8 @@ void MqttPacket::handle() | ||
| 118 | handleSubscribe(); | 140 | handleSubscribe(); |
| 119 | else if (packetType == PacketType::PUBLISH) | 141 | else if (packetType == PacketType::PUBLISH) |
| 120 | handlePublish(); | 142 | handlePublish(); |
| 143 | + else if (packetType == PacketType::PUBACK) | ||
| 144 | + handlePubAck(); | ||
| 121 | } | 145 | } |
| 122 | 146 | ||
| 123 | void MqttPacket::handleConnect() | 147 | void MqttPacket::handleConnect() |
| @@ -268,10 +292,8 @@ void MqttPacket::handleSubscribe() | @@ -268,10 +292,8 @@ void MqttPacket::handleSubscribe() | ||
| 268 | uint16_t topicLength = readTwoBytesToUInt16(); | 292 | uint16_t topicLength = readTwoBytesToUInt16(); |
| 269 | std::string topic(readBytes(topicLength), topicLength); | 293 | std::string topic(readBytes(topicLength), topicLength); |
| 270 | char qos = readByte(); | 294 | char qos = readByte(); |
| 271 | - if (qos > 0) | ||
| 272 | - throw NotImplementedException("QoS not implemented"); | ||
| 273 | logger->logf(LOG_INFO, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); | 295 | logger->logf(LOG_INFO, "Client '%s' subscribed to '%s'", sender->repr().c_str(), topic.c_str()); |
| 274 | - sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic); | 296 | + sender->getThreadData()->getSubscriptionStore()->addSubscription(sender, topic, qos); |
| 275 | subs_reponse_codes.push_back(qos); | 297 | subs_reponse_codes.push_back(qos); |
| 276 | } | 298 | } |
| 277 | 299 | ||
| @@ -293,6 +315,7 @@ void MqttPacket::handlePublish() | @@ -293,6 +315,7 @@ void MqttPacket::handlePublish() | ||
| 293 | 315 | ||
| 294 | if (qos == 3) | 316 | if (qos == 3) |
| 295 | throw ProtocolError("QoS 3 is a protocol violation."); | 317 | throw ProtocolError("QoS 3 is a protocol violation."); |
| 318 | + this->qos = qos; | ||
| 296 | 319 | ||
| 297 | std::string topic(readBytes(variable_header_length), variable_header_length); | 320 | std::string topic(readBytes(variable_header_length), variable_header_length); |
| 298 | 321 | ||
| @@ -310,8 +333,19 @@ void MqttPacket::handlePublish() | @@ -310,8 +333,19 @@ void MqttPacket::handlePublish() | ||
| 310 | 333 | ||
| 311 | if (qos) | 334 | if (qos) |
| 312 | { | 335 | { |
| 313 | - throw ProtocolError("Qos not implemented."); | 336 | + if (qos > 1) |
| 337 | + throw ProtocolError("Qos > 1 not implemented."); | ||
| 338 | + packet_id_pos = pos; | ||
| 314 | uint16_t packet_id = readTwoBytesToUInt16(); | 339 | uint16_t packet_id = readTwoBytesToUInt16(); |
| 340 | + | ||
| 341 | + // Clear the packet ID from this packet, because each new publish must get a new one. It's more of a debug precaution. | ||
| 342 | + pos -= 2; | ||
| 343 | + char zero[2]; zero[0] = 0; zero[1] = 0; | ||
| 344 | + writeBytes(zero, 2); | ||
| 345 | + | ||
| 346 | + PubAck pubAck(packet_id); | ||
| 347 | + MqttPacket response(pubAck); | ||
| 348 | + sender->writeMqttPacket(response); | ||
| 315 | } | 349 | } |
| 316 | 350 | ||
| 317 | if (retain) | 351 | if (retain) |
| @@ -330,6 +364,12 @@ void MqttPacket::handlePublish() | @@ -330,6 +364,12 @@ void MqttPacket::handlePublish() | ||
| 330 | sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender); | 364 | sender->getThreadData()->getSubscriptionStore()->queuePacketAtSubscribers(topic, *this, sender); |
| 331 | } | 365 | } |
| 332 | 366 | ||
| 367 | +void MqttPacket::handlePubAck() | ||
| 368 | +{ | ||
| 369 | + uint16_t packet_id = readTwoBytesToUInt16(); | ||
| 370 | + sender->getSession()->clearQosMessage(packet_id); | ||
| 371 | +} | ||
| 372 | + | ||
| 333 | void MqttPacket::calculateRemainingLength() | 373 | void MqttPacket::calculateRemainingLength() |
| 334 | { | 374 | { |
| 335 | assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. | 375 | assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. |
| @@ -356,6 +396,40 @@ RemainingLength MqttPacket::getRemainingLength() const | @@ -356,6 +396,40 @@ RemainingLength MqttPacket::getRemainingLength() const | ||
| 356 | return remainingLength; | 396 | return remainingLength; |
| 357 | } | 397 | } |
| 358 | 398 | ||
| 399 | +void MqttPacket::setPacketId(uint16_t packet_id) | ||
| 400 | +{ | ||
| 401 | + // In other words, we assume that this code can only be called on packets of which we have all the bytes, including fixed header. | ||
| 402 | + assert(fixed_header_length > 0); | ||
| 403 | + assert(packetType == PacketType::PUBLISH); | ||
| 404 | + assert(qos > 0); | ||
| 405 | + | ||
| 406 | + pos = packet_id_pos; | ||
| 407 | + | ||
| 408 | + char topicLenMSB = (packet_id & 0xFF00) >> 8; | ||
| 409 | + char topicLenLSB = (packet_id & 0x00FF); | ||
| 410 | + writeByte(topicLenMSB); | ||
| 411 | + writeByte(topicLenLSB); | ||
| 412 | +} | ||
| 413 | + | ||
| 414 | +// If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything? | ||
| 415 | +void MqttPacket::setDuplicate() | ||
| 416 | +{ | ||
| 417 | + // In other words, we assume that this code can only be called on packets of which we have all the bytes, including fixed header. | ||
| 418 | + assert(fixed_header_length > 0); | ||
| 419 | + assert(packetType == PacketType::PUBLISH); | ||
| 420 | + assert(qos > 0); | ||
| 421 | + | ||
| 422 | + char byte1 = bites[0]; | ||
| 423 | + byte1 |= 0b00001000; | ||
| 424 | + pos = 0; | ||
| 425 | + writeByte(byte1); | ||
| 426 | +} | ||
| 427 | + | ||
| 428 | +size_t MqttPacket::getTotalMemoryFootprint() | ||
| 429 | +{ | ||
| 430 | + return bites.size() + sizeof(MqttPacket); | ||
| 431 | +} | ||
| 432 | + | ||
| 359 | size_t MqttPacket::getSizeIncludingNonPresentHeader() const | 433 | size_t MqttPacket::getSizeIncludingNonPresentHeader() const |
| 360 | { | 434 | { |
| 361 | size_t total = bites.size(); | 435 | size_t total = bites.size(); |
mqttpacket.h
| @@ -28,9 +28,11 @@ class MqttPacket | @@ -28,9 +28,11 @@ class MqttPacket | ||
| 28 | std::vector<char> bites; | 28 | std::vector<char> bites; |
| 29 | size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. | 29 | size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. |
| 30 | RemainingLength remainingLength; | 30 | RemainingLength remainingLength; |
| 31 | + char qos = 0; | ||
| 31 | Client_p sender; | 32 | Client_p sender; |
| 32 | char first_byte = 0; | 33 | char first_byte = 0; |
| 33 | size_t pos = 0; | 34 | size_t pos = 0; |
| 35 | + size_t packet_id_pos = 0; | ||
| 34 | ProtocolVersion protocolVersion = ProtocolVersion::None; | 36 | ProtocolVersion protocolVersion = ProtocolVersion::None; |
| 35 | Logger *logger = Logger::getInstance(); | 37 | Logger *logger = Logger::getInstance(); |
| 36 | 38 | ||
| @@ -43,17 +45,20 @@ class MqttPacket | @@ -43,17 +45,20 @@ class MqttPacket | ||
| 43 | 45 | ||
| 44 | void calculateRemainingLength(); | 46 | void calculateRemainingLength(); |
| 45 | 47 | ||
| 48 | + MqttPacket(const MqttPacket &other) = default; | ||
| 46 | public: | 49 | public: |
| 47 | PacketType packetType = PacketType::Reserved; | 50 | PacketType packetType = PacketType::Reserved; |
| 48 | MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets. | 51 | MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets. |
| 49 | 52 | ||
| 50 | MqttPacket(MqttPacket &&other) = default; | 53 | MqttPacket(MqttPacket &&other) = default; |
| 51 | - MqttPacket(const MqttPacket &other) = delete; | 54 | + |
| 55 | + std::shared_ptr<MqttPacket> getCopy() const; | ||
| 52 | 56 | ||
| 53 | // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. | 57 | // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. |
| 54 | MqttPacket(const ConnAck &connAck); | 58 | MqttPacket(const ConnAck &connAck); |
| 55 | MqttPacket(const SubAck &subAck); | 59 | MqttPacket(const SubAck &subAck); |
| 56 | MqttPacket(const Publish &publish); | 60 | MqttPacket(const Publish &publish); |
| 61 | + MqttPacket(const PubAck &pubAck); | ||
| 57 | 62 | ||
| 58 | void handle(); | 63 | void handle(); |
| 59 | void handleConnect(); | 64 | void handleConnect(); |
| @@ -61,16 +66,19 @@ public: | @@ -61,16 +66,19 @@ public: | ||
| 61 | void handleSubscribe(); | 66 | void handleSubscribe(); |
| 62 | void handlePing(); | 67 | void handlePing(); |
| 63 | void handlePublish(); | 68 | void handlePublish(); |
| 69 | + void handlePubAck(); | ||
| 64 | 70 | ||
| 65 | size_t getSizeIncludingNonPresentHeader() const; | 71 | size_t getSizeIncludingNonPresentHeader() const; |
| 66 | const std::vector<char> &getBites() const { return bites; } | 72 | const std::vector<char> &getBites() const { return bites; } |
| 67 | - | 73 | + char getQos() const { return qos; } |
| 68 | Client_p getSender() const; | 74 | Client_p getSender() const; |
| 69 | void setSender(const Client_p &value); | 75 | void setSender(const Client_p &value); |
| 70 | - | ||
| 71 | bool containsFixedHeader() const; | 76 | bool containsFixedHeader() const; |
| 72 | char getFirstByte() const; | 77 | char getFirstByte() const; |
| 73 | RemainingLength getRemainingLength() const; | 78 | RemainingLength getRemainingLength() const; |
| 79 | + void setPacketId(uint16_t packet_id); | ||
| 80 | + void setDuplicate(); | ||
| 81 | + size_t getTotalMemoryFootprint(); | ||
| 74 | }; | 82 | }; |
| 75 | 83 | ||
| 76 | #endif // MQTTPACKET_H | 84 | #endif // MQTTPACKET_H |
session.cpp
| 1 | +#include "cassert" | ||
| 2 | + | ||
| 1 | #include "session.h" | 3 | #include "session.h" |
| 4 | +#include "client.h" | ||
| 2 | 5 | ||
| 3 | Session::Session() | 6 | Session::Session() |
| 4 | { | 7 | { |
| @@ -19,3 +22,74 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) | @@ -19,3 +22,74 @@ void Session::assignActiveConnection(std::shared_ptr<Client> &client) | ||
| 19 | { | 22 | { |
| 20 | this->client = client; | 23 | this->client = client; |
| 21 | } | 24 | } |
| 25 | + | ||
| 26 | +void Session::writePacket(const MqttPacket &packet) | ||
| 27 | +{ | ||
| 28 | + const char qos = packet.getQos(); | ||
| 29 | + | ||
| 30 | + if (qos == 0) | ||
| 31 | + { | ||
| 32 | + if (!clientDisconnected()) | ||
| 33 | + { | ||
| 34 | + Client_p c = makeSharedClient(); | ||
| 35 | + c->writeMqttPacketAndBlameThisClient(packet); | ||
| 36 | + } | ||
| 37 | + } | ||
| 38 | + else if (qos == 1) | ||
| 39 | + { | ||
| 40 | + std::shared_ptr<MqttPacket> copyPacket = packet.getCopy(); | ||
| 41 | + std::unique_lock<std::mutex> locker(qosQueueMutex); | ||
| 42 | + if (qosPacketQueue.size() >= MAX_QOS_MSG_PENDING_PER_CLIENT || (qosQueueBytes >= MAX_QOS_BYTES_PENDING_PER_CLIENT && qosPacketQueue.size() > 0)) | ||
| 43 | + { | ||
| 44 | + logger->logf(LOG_WARNING, "Dropping QoS message for client 'TODO', because its QoS buffers were full."); | ||
| 45 | + return; | ||
| 46 | + } | ||
| 47 | + const uint16_t pid = nextPacketId++; | ||
| 48 | + copyPacket->setPacketId(pid); | ||
| 49 | + qosPacketQueue[pid] = copyPacket; | ||
| 50 | + qosQueueBytes += copyPacket->getTotalMemoryFootprint(); | ||
| 51 | + locker.unlock(); | ||
| 52 | + | ||
| 53 | + if (!clientDisconnected()) | ||
| 54 | + { | ||
| 55 | + Client_p c = makeSharedClient(); | ||
| 56 | + c->writeMqttPacketAndBlameThisClient(*copyPacket.get()); | ||
| 57 | + copyPacket->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | ||
| 58 | + } | ||
| 59 | + } | ||
| 60 | +} | ||
| 61 | + | ||
| 62 | +void Session::clearQosMessage(uint16_t packet_id) | ||
| 63 | +{ | ||
| 64 | + std::lock_guard<std::mutex> locker(qosQueueMutex); | ||
| 65 | + auto it = qosPacketQueue.find(packet_id); | ||
| 66 | + if (it != qosPacketQueue.end()) | ||
| 67 | + { | ||
| 68 | + std::shared_ptr<MqttPacket> packet = it->second; | ||
| 69 | + qosPacketQueue.erase(it); | ||
| 70 | + qosQueueBytes -= packet->getTotalMemoryFootprint(); | ||
| 71 | + assert(qosQueueBytes >= 0); | ||
| 72 | + if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. | ||
| 73 | + qosQueueBytes = 0; | ||
| 74 | + } | ||
| 75 | +} | ||
| 76 | + | ||
| 77 | +// [MQTT-4.4.0-1]: "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST re-send any | ||
| 78 | +// unacknowledged PUBLISH Packets (where QoS > 0) and PUBREL Packets using their original Packet Identifiers. This | ||
| 79 | +// is the only circumstance where a Client or Server is REQUIRED to redeliver messages." | ||
| 80 | +// | ||
| 81 | +// There is a bit of a hole there, I think. When we write out a packet to a receiver, it may decide to drop it, if its buffers | ||
| 82 | +// are full, for instance. We are not required to (periodically) retry. TODO Perhaps I will implement that retry anyway. | ||
| 83 | +void Session::sendPendingQosMessages() | ||
| 84 | +{ | ||
| 85 | + if (!clientDisconnected()) | ||
| 86 | + { | ||
| 87 | + Client_p c = makeSharedClient(); | ||
| 88 | + std::lock_guard<std::mutex> locker(qosQueueMutex); | ||
| 89 | + for (auto &qosMessage : qosPacketQueue) // TODO: wrong: the order must be maintained. Combine the fix with that vector idea | ||
| 90 | + { | ||
| 91 | + c->writeMqttPacketAndBlameThisClient(*qosMessage.second.get()); | ||
| 92 | + qosMessage.second->setDuplicate(); // Any dealings with this packet from here will be a duplicate. | ||
| 93 | + } | ||
| 94 | + } | ||
| 95 | +} |
session.h
| @@ -2,13 +2,24 @@ | @@ -2,13 +2,24 @@ | ||
| 2 | #define SESSION_H | 2 | #define SESSION_H |
| 3 | 3 | ||
| 4 | #include <memory> | 4 | #include <memory> |
| 5 | +#include <unordered_map> | ||
| 6 | +#include <mutex> | ||
| 5 | 7 | ||
| 6 | -class Client; | 8 | +#include "forward_declarations.h" |
| 9 | +#include "logger.h" | ||
| 10 | + | ||
| 11 | +// TODO make settings | ||
| 12 | +#define MAX_QOS_MSG_PENDING_PER_CLIENT 32 | ||
| 13 | +#define MAX_QOS_BYTES_PENDING_PER_CLIENT 4096 | ||
| 7 | 14 | ||
| 8 | class Session | 15 | class Session |
| 9 | { | 16 | { |
| 10 | std::weak_ptr<Client> client; | 17 | std::weak_ptr<Client> client; |
| 11 | - // TODO: qos message queue, as some kind of movable pointer. | 18 | + std::unordered_map<uint16_t, std::shared_ptr<MqttPacket>> qosPacketQueue; // TODO: because the max queue length should remain low-ish, perhaps a vector is better here. |
| 19 | + std::mutex qosQueueMutex; | ||
| 20 | + uint16_t nextPacketId = 0; | ||
| 21 | + ssize_t qosQueueBytes = 0; | ||
| 22 | + Logger *logger = Logger::getInstance(); | ||
| 12 | public: | 23 | public: |
| 13 | Session(); | 24 | Session(); |
| 14 | Session(const Session &other) = delete; | 25 | Session(const Session &other) = delete; |
| @@ -17,6 +28,9 @@ public: | @@ -17,6 +28,9 @@ public: | ||
| 17 | bool clientDisconnected() const; | 28 | bool clientDisconnected() const; |
| 18 | std::shared_ptr<Client> makeSharedClient() const; | 29 | std::shared_ptr<Client> makeSharedClient() const; |
| 19 | void assignActiveConnection(std::shared_ptr<Client> &client); | 30 | void assignActiveConnection(std::shared_ptr<Client> &client); |
| 31 | + void writePacket(const MqttPacket &packet); | ||
| 32 | + void clearQosMessage(uint16_t packet_id); | ||
| 33 | + void sendPendingQosMessages(); | ||
| 20 | }; | 34 | }; |
| 21 | 35 | ||
| 22 | #endif // SESSION_H | 36 | #endif // SESSION_H |
subscriptionstore.cpp
| @@ -18,7 +18,7 @@ SubscriptionStore::SubscriptionStore() : | @@ -18,7 +18,7 @@ SubscriptionStore::SubscriptionStore() : | ||
| 18 | 18 | ||
| 19 | } | 19 | } |
| 20 | 20 | ||
| 21 | -void SubscriptionStore::addSubscription(Client_p &client, const std::string &topic) | 21 | +void SubscriptionStore::addSubscription(Client_p &client, const std::string &topic, char qos) |
| 22 | { | 22 | { |
| 23 | const std::list<std::string> subtopics = split(topic, '/'); | 23 | const std::list<std::string> subtopics = split(topic, '/'); |
| 24 | 24 | ||
| @@ -89,10 +89,13 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) | @@ -89,10 +89,13 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) | ||
| 89 | if (!session || client->getCleanSession()) | 89 | if (!session || client->getCleanSession()) |
| 90 | { | 90 | { |
| 91 | session.reset(new Session()); | 91 | session.reset(new Session()); |
| 92 | + | ||
| 92 | sessionsById[client->getClientId()] = session; | 93 | sessionsById[client->getClientId()] = session; |
| 93 | } | 94 | } |
| 94 | 95 | ||
| 95 | session->assignActiveConnection(client); | 96 | session->assignActiveConnection(client); |
| 97 | + client->assignSession(session); | ||
| 98 | + session->sendPendingQosMessages(); | ||
| 96 | } | 99 | } |
| 97 | 100 | ||
| 98 | // TODO: should I implement cache, this needs to be changed to returning a list of clients. | 101 | // TODO: should I implement cache, this needs to be changed to returning a list of clients. |
| @@ -103,12 +106,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st | @@ -103,12 +106,7 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const st | ||
| 103 | if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. | 106 | if (!session_weak.expired()) // Shared pointer expires when session has been cleaned by 'clean session' connect. |
| 104 | { | 107 | { |
| 105 | const std::shared_ptr<Session> session = session_weak.lock(); | 108 | const std::shared_ptr<Session> session = session_weak.lock(); |
| 106 | - | ||
| 107 | - if (!session->clientDisconnected()) | ||
| 108 | - { | ||
| 109 | - Client_p c = session->makeSharedClient(); | ||
| 110 | - c->writeMqttPacketAndBlameThisClient(packet); | ||
| 111 | - } | 109 | + session->writePacket(packet); |
| 112 | } | 110 | } |
| 113 | } | 111 | } |
| 114 | } | 112 | } |
| @@ -170,7 +168,7 @@ void SubscriptionStore::giveClientRetainedMessages(Client_p &client, const std:: | @@ -170,7 +168,7 @@ void SubscriptionStore::giveClientRetainedMessages(Client_p &client, const std:: | ||
| 170 | const MqttPacket packet(publish); | 168 | const MqttPacket packet(publish); |
| 171 | 169 | ||
| 172 | if (topicsMatch(subscribe_topic, rm.topic)) | 170 | if (topicsMatch(subscribe_topic, rm.topic)) |
| 173 | - client->writeMqttPacket(packet); | 171 | + client->writeMqttPacket(packet); // TODO: I think this needs to be session, not client, and then I can store it if it's QoS? I need to research how retain+qos works |
| 174 | } | 172 | } |
| 175 | } | 173 | } |
| 176 | 174 |
subscriptionstore.h
| @@ -30,7 +30,7 @@ public: | @@ -30,7 +30,7 @@ public: | ||
| 30 | SubscriptionNode(const SubscriptionNode &node) = delete; | 30 | SubscriptionNode(const SubscriptionNode &node) = delete; |
| 31 | SubscriptionNode(SubscriptionNode &&node) = delete; | 31 | SubscriptionNode(SubscriptionNode &&node) = delete; |
| 32 | 32 | ||
| 33 | - std::forward_list<std::weak_ptr<Session>> subscribers; // The idea is to store subscriptions by client id, to support persistent sessions. | 33 | + std::forward_list<std::weak_ptr<Session>> subscribers; // TODO: a subscription class, with qos |
| 34 | std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children; | 34 | std::unordered_map<std::string, std::unique_ptr<SubscriptionNode>> children; |
| 35 | std::unique_ptr<SubscriptionNode> childrenPlus; | 35 | std::unique_ptr<SubscriptionNode> childrenPlus; |
| 36 | std::unique_ptr<SubscriptionNode> childrenPound; | 36 | std::unique_ptr<SubscriptionNode> childrenPound; |
| @@ -54,7 +54,7 @@ class SubscriptionStore | @@ -54,7 +54,7 @@ class SubscriptionStore | ||
| 54 | public: | 54 | public: |
| 55 | SubscriptionStore(); | 55 | SubscriptionStore(); |
| 56 | 56 | ||
| 57 | - void addSubscription(Client_p &client, const std::string &topic); | 57 | + void addSubscription(Client_p &client, const std::string &topic, char qos); |
| 58 | void registerClientAndKickExistingOne(Client_p &client); | 58 | void registerClientAndKickExistingOne(Client_p &client); |
| 59 | 59 | ||
| 60 | void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); | 60 | void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); |
types.cpp
| @@ -15,6 +15,13 @@ SubAck::SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses) : | @@ -15,6 +15,13 @@ SubAck::SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses) : | ||
| 15 | } | 15 | } |
| 16 | } | 16 | } |
| 17 | 17 | ||
| 18 | +size_t SubAck::getLengthWithoutFixedHeader() const | ||
| 19 | +{ | ||
| 20 | + size_t result = responses.size(); | ||
| 21 | + result += 2; // Packet ID | ||
| 22 | + return result; | ||
| 23 | +} | ||
| 24 | + | ||
| 18 | Publish::Publish(const std::string &topic, const std::string payload, char qos) : | 25 | Publish::Publish(const std::string &topic, const std::string payload, char qos) : |
| 19 | topic(topic), | 26 | topic(topic), |
| 20 | payload(payload), | 27 | payload(payload), |
| @@ -23,8 +30,7 @@ Publish::Publish(const std::string &topic, const std::string payload, char qos) | @@ -23,8 +30,7 @@ Publish::Publish(const std::string &topic, const std::string payload, char qos) | ||
| 23 | 30 | ||
| 24 | } | 31 | } |
| 25 | 32 | ||
| 26 | -// Length starting at the variable header, not the fixed header. | ||
| 27 | -size_t Publish::getLength() const | 33 | +size_t Publish::getLengthWithoutFixedHeader() const |
| 28 | { | 34 | { |
| 29 | int result = topic.length() + payload.length() + 2; | 35 | int result = topic.length() + payload.length() + 2; |
| 30 | 36 | ||
| @@ -33,3 +39,15 @@ size_t Publish::getLength() const | @@ -33,3 +39,15 @@ size_t Publish::getLength() const | ||
| 33 | 39 | ||
| 34 | return result; | 40 | return result; |
| 35 | } | 41 | } |
| 42 | + | ||
| 43 | +PubAck::PubAck(uint16_t packet_id) : | ||
| 44 | + packet_id(packet_id) | ||
| 45 | +{ | ||
| 46 | + | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +// Packet has no payload and only a variable header, of length 2. | ||
| 50 | +size_t PubAck::getLengthWithoutFixedHeader() const | ||
| 51 | +{ | ||
| 52 | + return 2; | ||
| 53 | +} |
types.h
| @@ -48,7 +48,7 @@ class ConnAck | @@ -48,7 +48,7 @@ class ConnAck | ||
| 48 | public: | 48 | public: |
| 49 | ConnAck(ConnAckReturnCodes return_code); | 49 | ConnAck(ConnAckReturnCodes return_code); |
| 50 | ConnAckReturnCodes return_code; | 50 | ConnAckReturnCodes return_code; |
| 51 | - size_t getLength() const { return 2;} // size of connack is always the same | 51 | + size_t getLengthWithoutFixedHeader() const { return 2;} // size of connack is always the same |
| 52 | }; | 52 | }; |
| 53 | 53 | ||
| 54 | enum class SubAckReturnCodes | 54 | enum class SubAckReturnCodes |
| @@ -65,6 +65,7 @@ public: | @@ -65,6 +65,7 @@ public: | ||
| 65 | uint16_t packet_id; | 65 | uint16_t packet_id; |
| 66 | std::list<SubAckReturnCodes> responses; | 66 | std::list<SubAckReturnCodes> responses; |
| 67 | SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses); | 67 | SubAck(uint16_t packet_id, const std::list<char> &subs_qos_reponses); |
| 68 | + size_t getLengthWithoutFixedHeader() const; | ||
| 68 | }; | 69 | }; |
| 69 | 70 | ||
| 70 | class Publish | 71 | class Publish |
| @@ -75,7 +76,15 @@ public: | @@ -75,7 +76,15 @@ public: | ||
| 75 | char qos = 0; | 76 | char qos = 0; |
| 76 | bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] | 77 | bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1. [MQTT-3.3.1-9] |
| 77 | Publish(const std::string &topic, const std::string payload, char qos); | 78 | Publish(const std::string &topic, const std::string payload, char qos); |
| 78 | - size_t getLength() const; | 79 | + size_t getLengthWithoutFixedHeader() const; |
| 80 | +}; | ||
| 81 | + | ||
| 82 | +class PubAck | ||
| 83 | +{ | ||
| 84 | +public: | ||
| 85 | + PubAck(uint16_t packet_id); | ||
| 86 | + uint16_t packet_id; | ||
| 87 | + size_t getLengthWithoutFixedHeader() const; | ||
| 79 | }; | 88 | }; |
| 80 | 89 | ||
| 81 | #endif // TYPES_H | 90 | #endif // TYPES_H |