diff --git a/mqttpacket.cpp b/mqttpacket.cpp index b8080ef..00f0849 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -54,9 +54,8 @@ MqttPacket::MqttPacket(const ConnAck &connAck) : char first_byte = static_cast(packetType) << 4; writeByte(first_byte); writeByte(2); // length is always 2. - writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. TODO: make that + writeByte(connAck.session_present & 0b00000001); // all connect-ack flags are 0, except session-present. [MQTT-3.2.2.1] writeByte(static_cast(connAck.return_code)); - } MqttPacket::MqttPacket(const SubAck &subAck) : @@ -155,6 +154,8 @@ void MqttPacket::handleConnect() GlobalSettings *settings = GlobalSettings::getInstance(); + std::shared_ptr subscriptionStore = sender->getThreadData()->getSubscriptionStore(); + uint16_t variable_header_length = readTwoBytesToUInt16(); if (variable_header_length == 4 || variable_header_length == 6) @@ -278,17 +279,18 @@ void MqttPacket::handleConnect() if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) { - sender->getThreadData()->getSubscriptionStore()->registerClientAndKickExistingOne(sender); + bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_session && subscriptionStore->sessionPresent(client_id); + subscriptionStore->registerClientAndKickExistingOne(sender); sender->setAuthenticated(true); - ConnAck connAck(ConnAckReturnCodes::Accepted); + ConnAck connAck(ConnAckReturnCodes::Accepted, sessionPresent); MqttPacket response(connAck); sender->writeMqttPacket(response); logger->logf(LOG_NOTICE, "User '%s' logged in successfully", username.c_str()); } else { - ConnAck connDeny(ConnAckReturnCodes::NotAuthorized); + ConnAck connDeny(ConnAckReturnCodes::NotAuthorized, false); MqttPacket response(connDeny); sender->setDisconnectReason("Access denied"); sender->setReadyForDisconnect(); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index f067df1..bec9118 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -121,6 +121,22 @@ void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) session->sendPendingQosMessages(); } +bool SubscriptionStore::sessionPresent(const std::string &clientid) +{ + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.rdlock(); + + bool result = false; + + auto it = sessionsByIdConst.find(clientid); + if (it != sessionsByIdConst.end()) + { + it->second->touch(); // Touching to avoid a race condition between using the session after this, and it expiring. + result = true; + } + return result; +} + // TODO: should I implement cache, this needs to be changed to returning a list of clients. void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::vector &subscribers) const { diff --git a/subscriptionstore.h b/subscriptionstore.h index 1732574..958a809 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -70,6 +70,7 @@ public: void addSubscription(Client_p &client, const std::string &topic, char qos); void registerClientAndKickExistingOne(Client_p &client); + bool sessionPresent(const std::string &clientid); void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); void giveClientRetainedMessages(const std::shared_ptr &ses, const std::string &subscribe_topic, char max_qos); diff --git a/types.cpp b/types.cpp index 96fc5b5..973820e 100644 --- a/types.cpp +++ b/types.cpp @@ -1,9 +1,12 @@ #include "types.h" -ConnAck::ConnAck(ConnAckReturnCodes return_code) : - return_code(return_code) +ConnAck::ConnAck(ConnAckReturnCodes return_code, bool session_present) : + return_code(return_code), + session_present(session_present) { - + // [MQTT-3.2.2-4] + if (return_code > ConnAckReturnCodes::Accepted) + session_present = false; } SubAck::SubAck(uint16_t packet_id, const std::list &subs_qos_reponses) : diff --git a/types.h b/types.h index 8523907..c6d3352 100644 --- a/types.h +++ b/types.h @@ -46,8 +46,9 @@ enum class ConnAckReturnCodes class ConnAck { public: - ConnAck(ConnAckReturnCodes return_code); + ConnAck(ConnAckReturnCodes return_code, bool session_present=false); ConnAckReturnCodes return_code; + bool session_present = false; size_t getLengthWithoutFixedHeader() const { return 2;} // size of connack is always the same };