diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 7387b5d..e4e0283 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -221,6 +221,8 @@ void MqttPacket::handleConnect() if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) { + sender->getThreadData()->getSubscriptionStore()->registerClientAndKickExistingOne(sender); + sender->setAuthenticated(true); ConnAck connAck(ConnAckReturnCodes::Accepted); MqttPacket response(connAck); diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index b8974bc..b0db79f 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -51,12 +51,35 @@ void SubscriptionStore::addSubscription(Client_p &client, const std::string &top deepestNode->subscribers.push_front(client->getClientId()); } - clients_by_id[client->getClientId()] = client; lock_guard.unlock(); giveClientRetainedMessages(client, topic); } +// Removes an existing client when it already exists [MQTT-3.1.4-2]. +void SubscriptionStore::registerClientAndKickExistingOne(Client_p &client) +{ + RWLockGuard lock_guard(&subscriptionsRwlock); + lock_guard.wrlock(); + + if (client->getClientId().empty()) + throw ProtocolError("Trying to store client without an ID."); + + std::weak_ptr existingClient = clients_by_id[client->getClientId()]; + auto it = clients_by_id.find(client->getClientId()); + + if (it != clients_by_id.end() && !it->second.expired()) + { + std::shared_ptr cl = it->second.lock(); + logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", cl->getClientId().c_str()); + cl->setReadyForDisconnect(); + cl->getThreadData()->removeClient(cl); + cl->markAsDisconnecting(); + } + + clients_by_id[client->getClientId()] = client; +} + // TODO: should I implement cache, this needs to be changed to returning a list of clients. void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::forward_list &subscribers) const { diff --git a/subscriptionstore.h b/subscriptionstore.h index 459e4f6..e2c7a43 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -12,6 +12,7 @@ #include "client.h" #include "utils.h" #include "retainedmessage.h" +#include "logger.h" struct RetainedPayload { @@ -44,6 +45,8 @@ class SubscriptionStore pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER; std::unordered_set retainedMessages; + Logger *logger = Logger::getInstance(); + void publishNonRecursively(const MqttPacket &packet, const std::forward_list &subscribers) const; void publishRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, std::unique_ptr &next, const MqttPacket &packet) const; @@ -51,6 +54,7 @@ public: SubscriptionStore(); void addSubscription(Client_p &client, const std::string &topic); + void registerClientAndKickExistingOne(Client_p &client); void queuePacketAtSubscribers(const std::string &topic, const MqttPacket &packet, const Client_p &sender); void giveClientRetainedMessages(Client_p &client, const std::string &subscribe_topic);