diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index f2c7cd1..208a807 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -1004,13 +1004,13 @@ void MainTests::testSavingSessions() Authentication auth(*settings.get()); ThreadGlobals::assign(&auth); - std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings, false)); + std::shared_ptr c1(new Client(0, t, nullptr, false, nullptr, settings.get(), false)); c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60); store->registerClientAndKickExistingOne(c1, false, 512, 120); c1->getSession()->addIncomingQoS2MessageId(2); c1->getSession()->addIncomingQoS2MessageId(3); - std::shared_ptr c2(new Client(0, t, nullptr, false, nullptr, settings, false)); + std::shared_ptr c2(new Client(0, t, nullptr, false, nullptr, settings.get(), false)); c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60); store->registerClientAndKickExistingOne(c2, false, 512, 120); c2->getSession()->addOutgoingQoS2MessageId(55); @@ -1119,7 +1119,7 @@ void MainTests::testParsePacketHelper(const std::string &topic, char from_qos, b Authentication auth(*settings.get()); ThreadGlobals::assign(&auth); - std::shared_ptr dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false)); + std::shared_ptr dummyClient(new Client(0, t, nullptr, false, nullptr, settings.get(), false)); dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60); store->registerClientAndKickExistingOne(dummyClient, false, 512, 120); @@ -1142,7 +1142,7 @@ void MainTests::testParsePacketHelper(const std::string &topic, char from_qos, b MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient); QVERIFY(parsedPackets.size() == 1); MqttPacket parsedPacketOne = std::move(parsedPackets.front()); - parsedPacketOne.handlePublish(); + parsedPacketOne.parsePublishData(); if (retain) // A normal handled packet always has retain=0, so I force setting it here. parsedPacketOne.setRetain(); diff --git a/client.cpp b/client.cpp index 81fbd20..79137f5 100644 --- a/client.cpp +++ b/client.cpp @@ -27,7 +27,7 @@ License along with FlashMQ. If not, see . #include "utils.h" #include "threadglobals.h" -Client::Client(int fd, std::shared_ptr threadData, SSL *ssl, bool websocket, struct sockaddr *addr, std::shared_ptr settings, bool fuzzMode) : +Client::Client(int fd, std::shared_ptr threadData, SSL *ssl, bool websocket, struct sockaddr *addr, const Settings *settings, bool fuzzMode) : fd(fd), fuzzMode(fuzzMode), initialBufferSize(settings->clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy @@ -51,7 +51,11 @@ Client::Client(int fd, std::shared_ptr threadData, SSL *ssl, bool we Client::~Client() { - std::shared_ptr &store = getThreadData()->getSubscriptionStore(); + // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread. + if (!this->threadData) + return; + + std::shared_ptr &store = this->threadData->getSubscriptionStore(); if (disconnectReason.empty()) disconnectReason = "not specified"; diff --git a/client.h b/client.h index 30f1b2e..ae8ab32 100644 --- a/client.h +++ b/client.h @@ -95,7 +95,7 @@ class Client void setReadyForReading(bool val); public: - Client(int fd, std::shared_ptr threadData, SSL *ssl, bool websocket, struct sockaddr *addr, std::shared_ptr settings, bool fuzzMode=false); + Client(int fd, std::shared_ptr threadData, SSL *ssl, bool websocket, struct sockaddr *addr, const Settings *settings, bool fuzzMode=false); Client(const Client &other) = delete; Client(Client &&other) = delete; ~Client(); diff --git a/mainapp.cpp b/mainapp.cpp index 932a5f4..36487f7 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -471,12 +471,12 @@ void MainApp::start() std::shared_ptr threaddata = std::make_shared(0, subscriptionStore, settings); - std::shared_ptr client = std::make_shared(fd, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true); - std::shared_ptr subscriber = std::make_shared(fdnull, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true); + std::shared_ptr client = std::make_shared(fd, threaddata, nullptr, fuzzWebsockets, nullptr, settings.get(), true); + std::shared_ptr subscriber = std::make_shared(fdnull, threaddata, nullptr, fuzzWebsockets, nullptr, settings.get(), true); subscriber->setClientProperties(ProtocolVersion::Mqtt311, "subscriber", "subuser", true, 60); subscriber->setAuthenticated(true); - std::shared_ptr websocketsubscriber = std::make_shared(fdnull2, threaddata, nullptr, true, nullptr, settings, true); + std::shared_ptr websocketsubscriber = std::make_shared(fdnull2, threaddata, nullptr, true, nullptr, settings.get(), true); websocketsubscriber->setClientProperties(ProtocolVersion::Mqtt311, "websocketsubscriber", "websocksubuser", true, 60); websocketsubscriber->setAuthenticated(true); websocketsubscriber->setFakeUpgraded(); @@ -574,7 +574,7 @@ void MainApp::start() SSL_set_fd(clientSSL, fd); } - std::shared_ptr client = std::make_shared(fd, thread_data, clientSSL, listener->websocket, addr, settings); + std::shared_ptr client = std::make_shared(fd, thread_data, clientSSL, listener->websocket, addr, settings.get()); thread_data->giveClient(client); } else diff --git a/mqttpacket.cpp b/mqttpacket.cpp index d42916e..52a8481 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -776,26 +776,23 @@ void MqttPacket::handleUnsubscribe() sender->writeMqttPacket(response); } -void MqttPacket::handlePublish(const bool stopAfterParsing) +void MqttPacket::parsePublishData() { const uint16_t variable_header_length = readTwoBytesToUInt16(); - bool retain = (first_byte & 0b00000001); + publishData.retain = (first_byte & 0b00000001); bool dup = !!(first_byte & 0b00001000); - char qos = (first_byte & 0b00000110) >> 1; + publishData.qos = (first_byte & 0b00000110) >> 1; - if (qos > 2) + if (publishData.qos > 2) throw ProtocolError("QoS 3 is a protocol violation."); - this->publishData.qos = qos; - if (qos == 0 && dup) + if (publishData.qos == 0 && dup) throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); publishData.topic = std::string(readBytes(variable_header_length), variable_header_length); - ReasonCodes ackCode = ReasonCodes::Success; - - if (qos) + if (publishData.qos) { packet_id_pos = pos; packet_id = readTwoBytesToUInt16(); @@ -827,6 +824,12 @@ void MqttPacket::handlePublish(const bool stopAfterParsing) break; case Mqtt5Properties::TopicAlias: { + // For when we use packets has helpers without a senser (like loading packets from disk). + // Logically, this should never trip because there can't be aliases in such packets, but including + // a check to be sure. + if (!sender) + break; + const uint16_t alias_id = readTwoBytesToUInt16(); this->hasTopicAlias = true; @@ -889,6 +892,14 @@ void MqttPacket::handlePublish(const bool stopAfterParsing) if (publishData.topic.empty()) throw ProtocolError("Empty publish topic"); + payloadLen = remainingAfterPos(); + payloadStart = pos; +} + +void MqttPacket::handlePublish() +{ + parsePublishData(); + if (!isValidUtf8(publishData.topic, true)) { const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str()); @@ -897,41 +908,36 @@ void MqttPacket::handlePublish(const bool stopAfterParsing) } #ifndef NDEBUG - logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup); + logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), publishData.qos, publishData.retain, dup); #endif - payloadLen = remainingAfterPos(); - payloadStart = pos; + ReasonCodes ackCode = ReasonCodes::Success; sender->getThreadData()->incrementReceivedMessageCount(); - // TODO: or maybe create a function parsePublishData(). - if (stopAfterParsing) - return; - Authentication &authentication = *ThreadGlobals::getAuth(); // Working with a local copy because the subscribing action will modify this->packet_id. See the PublishCopyFactory. const uint16_t _packet_id = this->packet_id; - if (qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id)) + if (publishData.qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id)) { ackCode = ReasonCodes::PacketIdentifierInUse; } else { // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish. - if (qos == 2) + if (publishData.qos == 2) sender->getSession()->addIncomingQoS2MessageId(_packet_id); splitTopic(publishData.topic, publishData.subtopics); - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success) + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, publishData.qos, publishData.retain, getUserProperties()) == AuthResult::success) { - if (retain) + if (publishData.retain) { std::string payload(readBytes(payloadLen), payloadLen); - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, qos); + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, publishData.qos); } // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. @@ -953,9 +959,9 @@ void MqttPacket::handlePublish(const bool stopAfterParsing) this->packet_id = 0; #endif - if (qos > 0) + if (publishData.qos > 0) { - const PacketType responseType = qos == 1 ? PacketType::PUBACK : PacketType::PUBREC; + const PacketType responseType = publishData.qos == 1 ? PacketType::PUBACK : PacketType::PUBREC; PubResponse pubAck(this->protocolVersion, responseType, ackCode, _packet_id); MqttPacket response(pubAck); sender->writeMqttPacket(response); diff --git a/mqttpacket.h b/mqttpacket.h index f71a0be..856a718 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -96,7 +96,8 @@ public: void handleSubscribe(); void handleUnsubscribe(); void handlePing(); - void handlePublish(const bool stopAfterParsing = false); + void parsePublishData(); + void handlePublish(); void handlePubAck(); void handlePubRec(); void handlePubRel(); diff --git a/sessionsandsubscriptionsdb.cpp b/sessionsandsubscriptionsdb.cpp index d8aa974..1d253e3 100644 --- a/sessionsandsubscriptionsdb.cpp +++ b/sessionsandsubscriptionsdb.cpp @@ -85,9 +85,9 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() CirBuf cirbuf(1024); // TODO: all that settings and thread data needs to be removed from Client. - std::shared_ptr dummyThreadData; - std::shared_ptr dummySettings(new Settings()); // TODO: this is wrong: these are not from config file - std::shared_ptr dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, dummySettings, false)); + std::shared_ptr dummyThreadData; // which thread am I going get/use here? + std::shared_ptr dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, settings, false)); + dummyClient->setClientProperties(ProtocolVersion::Mqtt5, "Dummyforloadingqueuedqos", "nobody", true, 60); for (uint32_t i = 0; i < nrOfSessions; i++) { @@ -123,7 +123,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() cirbuf.advanceHead(packlen); MqttPacket pack(cirbuf, packlen, 2, dummyClient); // TODO: store the 2 in the file - pack.handlePublish(true); + pack.parsePublishData(); Publish pub(pack.getPublishData()); logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); @@ -240,12 +240,15 @@ void SessionsAndSubscriptionsDB::saveData(const std::vectorlogf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); - const MqttPacket pack(ProtocolVersion::Mqtt5, pub); + MqttPacket pack(ProtocolVersion::Mqtt5, pub); + pack.setPacketId(p.getPacketId()); const uint32_t packSize = pack.getSizeIncludingNonPresentHeader(); cirbuf.reset(); cirbuf.ensureFreeSpace(packSize + 32);