Commit c652de5a8b35ecf1127ff377c9841882aa8b7da4

Authored by Wiebe Cazemier
1 parent 6fed4d58

Continue with idea to use parsed packets for loading from disk

Because clients can now also exist as dummy objects, I had to add some
extra checks.

Also split up handlePublish() and the new parsePublishData().
FlashMQTests/tst_maintests.cpp
... ... @@ -1004,13 +1004,13 @@ void MainTests::testSavingSessions()
1004 1004 Authentication auth(*settings.get());
1005 1005 ThreadGlobals::assign(&auth);
1006 1006  
1007   - std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false));
  1007 + std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings.get(), false));
1008 1008 c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60);
1009 1009 store->registerClientAndKickExistingOne(c1, false, 512, 120);
1010 1010 c1->getSession()->addIncomingQoS2MessageId(2);
1011 1011 c1->getSession()->addIncomingQoS2MessageId(3);
1012 1012  
1013   - std::shared_ptr<Client> c2(new Client(0, t, nullptr, false, nullptr, settings, false));
  1013 + std::shared_ptr<Client> c2(new Client(0, t, nullptr, false, nullptr, settings.get(), false));
1014 1014 c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60);
1015 1015 store->registerClientAndKickExistingOne(c2, false, 512, 120);
1016 1016 c2->getSession()->addOutgoingQoS2MessageId(55);
... ... @@ -1119,7 +1119,7 @@ void MainTests::testParsePacketHelper(const std::string &amp;topic, char from_qos, b
1119 1119 Authentication auth(*settings.get());
1120 1120 ThreadGlobals::assign(&auth);
1121 1121  
1122   - std::shared_ptr<Client> dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false));
  1122 + std::shared_ptr<Client> dummyClient(new Client(0, t, nullptr, false, nullptr, settings.get(), false));
1123 1123 dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60);
1124 1124 store->registerClientAndKickExistingOne(dummyClient, false, 512, 120);
1125 1125  
... ... @@ -1142,7 +1142,7 @@ void MainTests::testParsePacketHelper(const std::string &amp;topic, char from_qos, b
1142 1142 MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient);
1143 1143 QVERIFY(parsedPackets.size() == 1);
1144 1144 MqttPacket parsedPacketOne = std::move(parsedPackets.front());
1145   - parsedPacketOne.handlePublish();
  1145 + parsedPacketOne.parsePublishData();
1146 1146 if (retain) // A normal handled packet always has retain=0, so I force setting it here.
1147 1147 parsedPacketOne.setRetain();
1148 1148  
... ...
client.cpp
... ... @@ -27,7 +27,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
27 27 #include "utils.h"
28 28 #include "threadglobals.h"
29 29  
30   -Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool websocket, struct sockaddr *addr, std::shared_ptr<Settings> settings, bool fuzzMode) :
  30 +Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool websocket, struct sockaddr *addr, const Settings *settings, bool fuzzMode) :
31 31 fd(fd),
32 32 fuzzMode(fuzzMode),
33 33 initialBufferSize(settings->clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy
... ... @@ -51,7 +51,11 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we
51 51  
52 52 Client::~Client()
53 53 {
54   - std::shared_ptr<SubscriptionStore> &store = getThreadData()->getSubscriptionStore();
  54 + // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread.
  55 + if (!this->threadData)
  56 + return;
  57 +
  58 + std::shared_ptr<SubscriptionStore> &store = this->threadData->getSubscriptionStore();
55 59  
56 60 if (disconnectReason.empty())
57 61 disconnectReason = "not specified";
... ...
client.h
... ... @@ -95,7 +95,7 @@ class Client
95 95 void setReadyForReading(bool val);
96 96  
97 97 public:
98   - Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool websocket, struct sockaddr *addr, std::shared_ptr<Settings> settings, bool fuzzMode=false);
  98 + Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool websocket, struct sockaddr *addr, const Settings *settings, bool fuzzMode=false);
99 99 Client(const Client &other) = delete;
100 100 Client(Client &&other) = delete;
101 101 ~Client();
... ...
mainapp.cpp
... ... @@ -471,12 +471,12 @@ void MainApp::start()
471 471  
472 472 std::shared_ptr<ThreadData> threaddata = std::make_shared<ThreadData>(0, subscriptionStore, settings);
473 473  
474   - std::shared_ptr<Client> client = std::make_shared<Client>(fd, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true);
475   - std::shared_ptr<Client> subscriber = std::make_shared<Client>(fdnull, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true);
  474 + std::shared_ptr<Client> client = std::make_shared<Client>(fd, threaddata, nullptr, fuzzWebsockets, nullptr, settings.get(), true);
  475 + std::shared_ptr<Client> subscriber = std::make_shared<Client>(fdnull, threaddata, nullptr, fuzzWebsockets, nullptr, settings.get(), true);
476 476 subscriber->setClientProperties(ProtocolVersion::Mqtt311, "subscriber", "subuser", true, 60);
477 477 subscriber->setAuthenticated(true);
478 478  
479   - std::shared_ptr<Client> websocketsubscriber = std::make_shared<Client>(fdnull2, threaddata, nullptr, true, nullptr, settings, true);
  479 + std::shared_ptr<Client> websocketsubscriber = std::make_shared<Client>(fdnull2, threaddata, nullptr, true, nullptr, settings.get(), true);
480 480 websocketsubscriber->setClientProperties(ProtocolVersion::Mqtt311, "websocketsubscriber", "websocksubuser", true, 60);
481 481 websocketsubscriber->setAuthenticated(true);
482 482 websocketsubscriber->setFakeUpgraded();
... ... @@ -574,7 +574,7 @@ void MainApp::start()
574 574 SSL_set_fd(clientSSL, fd);
575 575 }
576 576  
577   - std::shared_ptr<Client> client = std::make_shared<Client>(fd, thread_data, clientSSL, listener->websocket, addr, settings);
  577 + std::shared_ptr<Client> client = std::make_shared<Client>(fd, thread_data, clientSSL, listener->websocket, addr, settings.get());
578 578 thread_data->giveClient(client);
579 579 }
580 580 else
... ...
mqttpacket.cpp
... ... @@ -776,26 +776,23 @@ void MqttPacket::handleUnsubscribe()
776 776 sender->writeMqttPacket(response);
777 777 }
778 778  
779   -void MqttPacket::handlePublish(const bool stopAfterParsing)
  779 +void MqttPacket::parsePublishData()
780 780 {
781 781 const uint16_t variable_header_length = readTwoBytesToUInt16();
782 782  
783   - bool retain = (first_byte & 0b00000001);
  783 + publishData.retain = (first_byte & 0b00000001);
784 784 bool dup = !!(first_byte & 0b00001000);
785   - char qos = (first_byte & 0b00000110) >> 1;
  785 + publishData.qos = (first_byte & 0b00000110) >> 1;
786 786  
787   - if (qos > 2)
  787 + if (publishData.qos > 2)
788 788 throw ProtocolError("QoS 3 is a protocol violation.");
789   - this->publishData.qos = qos;
790 789  
791   - if (qos == 0 && dup)
  790 + if (publishData.qos == 0 && dup)
792 791 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.");
793 792  
794 793 publishData.topic = std::string(readBytes(variable_header_length), variable_header_length);
795 794  
796   - ReasonCodes ackCode = ReasonCodes::Success;
797   -
798   - if (qos)
  795 + if (publishData.qos)
799 796 {
800 797 packet_id_pos = pos;
801 798 packet_id = readTwoBytesToUInt16();
... ... @@ -827,6 +824,12 @@ void MqttPacket::handlePublish(const bool stopAfterParsing)
827 824 break;
828 825 case Mqtt5Properties::TopicAlias:
829 826 {
  827 + // For when we use packets has helpers without a senser (like loading packets from disk).
  828 + // Logically, this should never trip because there can't be aliases in such packets, but including
  829 + // a check to be sure.
  830 + if (!sender)
  831 + break;
  832 +
830 833 const uint16_t alias_id = readTwoBytesToUInt16();
831 834 this->hasTopicAlias = true;
832 835  
... ... @@ -889,6 +892,14 @@ void MqttPacket::handlePublish(const bool stopAfterParsing)
889 892 if (publishData.topic.empty())
890 893 throw ProtocolError("Empty publish topic");
891 894  
  895 + payloadLen = remainingAfterPos();
  896 + payloadStart = pos;
  897 +}
  898 +
  899 +void MqttPacket::handlePublish()
  900 +{
  901 + parsePublishData();
  902 +
892 903 if (!isValidUtf8(publishData.topic, true))
893 904 {
894 905 const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str());
... ... @@ -897,41 +908,36 @@ void MqttPacket::handlePublish(const bool stopAfterParsing)
897 908 }
898 909  
899 910 #ifndef NDEBUG
900   - logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup);
  911 + logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), publishData.qos, publishData.retain, dup);
901 912 #endif
902 913  
903   - payloadLen = remainingAfterPos();
904   - payloadStart = pos;
  914 + ReasonCodes ackCode = ReasonCodes::Success;
905 915  
906 916 sender->getThreadData()->incrementReceivedMessageCount();
907 917  
908   - // TODO: or maybe create a function parsePublishData().
909   - if (stopAfterParsing)
910   - return;
911   -
912 918 Authentication &authentication = *ThreadGlobals::getAuth();
913 919  
914 920 // Working with a local copy because the subscribing action will modify this->packet_id. See the PublishCopyFactory.
915 921 const uint16_t _packet_id = this->packet_id;
916 922  
917   - if (qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id))
  923 + if (publishData.qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id))
918 924 {
919 925 ackCode = ReasonCodes::PacketIdentifierInUse;
920 926 }
921 927 else
922 928 {
923 929 // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish.
924   - if (qos == 2)
  930 + if (publishData.qos == 2)
925 931 sender->getSession()->addIncomingQoS2MessageId(_packet_id);
926 932  
927 933 splitTopic(publishData.topic, publishData.subtopics);
928 934  
929   - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success)
  935 + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, publishData.qos, publishData.retain, getUserProperties()) == AuthResult::success)
930 936 {
931   - if (retain)
  937 + if (publishData.retain)
932 938 {
933 939 std::string payload(readBytes(payloadLen), payloadLen);
934   - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, qos);
  940 + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, publishData.qos);
935 941 }
936 942  
937 943 // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
... ... @@ -953,9 +959,9 @@ void MqttPacket::handlePublish(const bool stopAfterParsing)
953 959 this->packet_id = 0;
954 960 #endif
955 961  
956   - if (qos > 0)
  962 + if (publishData.qos > 0)
957 963 {
958   - const PacketType responseType = qos == 1 ? PacketType::PUBACK : PacketType::PUBREC;
  964 + const PacketType responseType = publishData.qos == 1 ? PacketType::PUBACK : PacketType::PUBREC;
959 965 PubResponse pubAck(this->protocolVersion, responseType, ackCode, _packet_id);
960 966 MqttPacket response(pubAck);
961 967 sender->writeMqttPacket(response);
... ...
mqttpacket.h
... ... @@ -96,7 +96,8 @@ public:
96 96 void handleSubscribe();
97 97 void handleUnsubscribe();
98 98 void handlePing();
99   - void handlePublish(const bool stopAfterParsing = false);
  99 + void parsePublishData();
  100 + void handlePublish();
100 101 void handlePubAck();
101 102 void handlePubRec();
102 103 void handlePubRel();
... ...
sessionsandsubscriptionsdb.cpp
... ... @@ -85,9 +85,9 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2()
85 85 CirBuf cirbuf(1024);
86 86  
87 87 // TODO: all that settings and thread data needs to be removed from Client.
88   - std::shared_ptr<ThreadData> dummyThreadData;
89   - std::shared_ptr<Settings> dummySettings(new Settings()); // TODO: this is wrong: these are not from config file
90   - std::shared_ptr<Client> dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, dummySettings, false));
  88 + std::shared_ptr<ThreadData> dummyThreadData; // which thread am I going get/use here?
  89 + std::shared_ptr<Client> dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, settings, false));
  90 + dummyClient->setClientProperties(ProtocolVersion::Mqtt5, "Dummyforloadingqueuedqos", "nobody", true, 60);
91 91  
92 92 for (uint32_t i = 0; i < nrOfSessions; i++)
93 93 {
... ... @@ -123,7 +123,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2()
123 123 cirbuf.advanceHead(packlen);
124 124 MqttPacket pack(cirbuf, packlen, 2, dummyClient); // TODO: store the 2 in the file
125 125  
126   - pack.handlePublish(true);
  126 + pack.parsePublishData();
127 127 Publish pub(pack.getPublishData());
128 128  
129 129 logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
... ... @@ -240,12 +240,15 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess
240 240 qosPacketsCounted++;
241 241  
242 242 const Publish &pub = p.getPublish();
  243 +
243 244 assert(!pub.splitTopic);
  245 + assert(!pub.skipTopic);
244 246 assert(pub.topicAlias == 0);
245 247  
246 248 logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
247 249  
248   - const MqttPacket pack(ProtocolVersion::Mqtt5, pub);
  250 + MqttPacket pack(ProtocolVersion::Mqtt5, pub);
  251 + pack.setPacketId(p.getPacketId());
249 252 const uint32_t packSize = pack.getSizeIncludingNonPresentHeader();
250 253 cirbuf.reset();
251 254 cirbuf.ensureFreeSpace(packSize + 32);
... ...