Commit c652de5a8b35ecf1127ff377c9841882aa8b7da4

Authored by Wiebe Cazemier
1 parent 6fed4d58

Continue with idea to use parsed packets for loading from disk

Because clients can now also exist as dummy objects, I had to add some
extra checks.

Also split up handlePublish() and the new parsePublishData().
FlashMQTests/tst_maintests.cpp
@@ -1004,13 +1004,13 @@ void MainTests::testSavingSessions() @@ -1004,13 +1004,13 @@ void MainTests::testSavingSessions()
1004 Authentication auth(*settings.get()); 1004 Authentication auth(*settings.get());
1005 ThreadGlobals::assign(&auth); 1005 ThreadGlobals::assign(&auth);
1006 1006
1007 - std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings, false)); 1007 + std::shared_ptr<Client> c1(new Client(0, t, nullptr, false, nullptr, settings.get(), false));
1008 c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60); 1008 c1->setClientProperties(ProtocolVersion::Mqtt311, "c1", "user1", true, 60);
1009 store->registerClientAndKickExistingOne(c1, false, 512, 120); 1009 store->registerClientAndKickExistingOne(c1, false, 512, 120);
1010 c1->getSession()->addIncomingQoS2MessageId(2); 1010 c1->getSession()->addIncomingQoS2MessageId(2);
1011 c1->getSession()->addIncomingQoS2MessageId(3); 1011 c1->getSession()->addIncomingQoS2MessageId(3);
1012 1012
1013 - std::shared_ptr<Client> c2(new Client(0, t, nullptr, false, nullptr, settings, false)); 1013 + std::shared_ptr<Client> c2(new Client(0, t, nullptr, false, nullptr, settings.get(), false));
1014 c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60); 1014 c2->setClientProperties(ProtocolVersion::Mqtt311, "c2", "user2", true, 60);
1015 store->registerClientAndKickExistingOne(c2, false, 512, 120); 1015 store->registerClientAndKickExistingOne(c2, false, 512, 120);
1016 c2->getSession()->addOutgoingQoS2MessageId(55); 1016 c2->getSession()->addOutgoingQoS2MessageId(55);
@@ -1119,7 +1119,7 @@ void MainTests::testParsePacketHelper(const std::string &amp;topic, char from_qos, b @@ -1119,7 +1119,7 @@ void MainTests::testParsePacketHelper(const std::string &amp;topic, char from_qos, b
1119 Authentication auth(*settings.get()); 1119 Authentication auth(*settings.get());
1120 ThreadGlobals::assign(&auth); 1120 ThreadGlobals::assign(&auth);
1121 1121
1122 - std::shared_ptr<Client> dummyClient(new Client(0, t, nullptr, false, nullptr, settings, false)); 1122 + std::shared_ptr<Client> dummyClient(new Client(0, t, nullptr, false, nullptr, settings.get(), false));
1123 dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60); 1123 dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", "user1", true, 60);
1124 store->registerClientAndKickExistingOne(dummyClient, false, 512, 120); 1124 store->registerClientAndKickExistingOne(dummyClient, false, 512, 120);
1125 1125
@@ -1142,7 +1142,7 @@ void MainTests::testParsePacketHelper(const std::string &amp;topic, char from_qos, b @@ -1142,7 +1142,7 @@ void MainTests::testParsePacketHelper(const std::string &amp;topic, char from_qos, b
1142 MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient); 1142 MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient);
1143 QVERIFY(parsedPackets.size() == 1); 1143 QVERIFY(parsedPackets.size() == 1);
1144 MqttPacket parsedPacketOne = std::move(parsedPackets.front()); 1144 MqttPacket parsedPacketOne = std::move(parsedPackets.front());
1145 - parsedPacketOne.handlePublish(); 1145 + parsedPacketOne.parsePublishData();
1146 if (retain) // A normal handled packet always has retain=0, so I force setting it here. 1146 if (retain) // A normal handled packet always has retain=0, so I force setting it here.
1147 parsedPacketOne.setRetain(); 1147 parsedPacketOne.setRetain();
1148 1148
client.cpp
@@ -27,7 +27,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -27,7 +27,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
27 #include "utils.h" 27 #include "utils.h"
28 #include "threadglobals.h" 28 #include "threadglobals.h"
29 29
30 -Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool websocket, struct sockaddr *addr, std::shared_ptr<Settings> settings, bool fuzzMode) : 30 +Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool websocket, struct sockaddr *addr, const Settings *settings, bool fuzzMode) :
31 fd(fd), 31 fd(fd),
32 fuzzMode(fuzzMode), 32 fuzzMode(fuzzMode),
33 initialBufferSize(settings->clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy 33 initialBufferSize(settings->clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy
@@ -51,7 +51,11 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we @@ -51,7 +51,11 @@ Client::Client(int fd, std::shared_ptr&lt;ThreadData&gt; threadData, SSL *ssl, bool we
51 51
52 Client::~Client() 52 Client::~Client()
53 { 53 {
54 - std::shared_ptr<SubscriptionStore> &store = getThreadData()->getSubscriptionStore(); 54 + // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread.
  55 + if (!this->threadData)
  56 + return;
  57 +
  58 + std::shared_ptr<SubscriptionStore> &store = this->threadData->getSubscriptionStore();
55 59
56 if (disconnectReason.empty()) 60 if (disconnectReason.empty())
57 disconnectReason = "not specified"; 61 disconnectReason = "not specified";
client.h
@@ -95,7 +95,7 @@ class Client @@ -95,7 +95,7 @@ class Client
95 void setReadyForReading(bool val); 95 void setReadyForReading(bool val);
96 96
97 public: 97 public:
98 - Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool websocket, struct sockaddr *addr, std::shared_ptr<Settings> settings, bool fuzzMode=false); 98 + Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool websocket, struct sockaddr *addr, const Settings *settings, bool fuzzMode=false);
99 Client(const Client &other) = delete; 99 Client(const Client &other) = delete;
100 Client(Client &&other) = delete; 100 Client(Client &&other) = delete;
101 ~Client(); 101 ~Client();
mainapp.cpp
@@ -471,12 +471,12 @@ void MainApp::start() @@ -471,12 +471,12 @@ void MainApp::start()
471 471
472 std::shared_ptr<ThreadData> threaddata = std::make_shared<ThreadData>(0, subscriptionStore, settings); 472 std::shared_ptr<ThreadData> threaddata = std::make_shared<ThreadData>(0, subscriptionStore, settings);
473 473
474 - std::shared_ptr<Client> client = std::make_shared<Client>(fd, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true);  
475 - std::shared_ptr<Client> subscriber = std::make_shared<Client>(fdnull, threaddata, nullptr, fuzzWebsockets, nullptr, settings, true); 474 + std::shared_ptr<Client> client = std::make_shared<Client>(fd, threaddata, nullptr, fuzzWebsockets, nullptr, settings.get(), true);
  475 + std::shared_ptr<Client> subscriber = std::make_shared<Client>(fdnull, threaddata, nullptr, fuzzWebsockets, nullptr, settings.get(), true);
476 subscriber->setClientProperties(ProtocolVersion::Mqtt311, "subscriber", "subuser", true, 60); 476 subscriber->setClientProperties(ProtocolVersion::Mqtt311, "subscriber", "subuser", true, 60);
477 subscriber->setAuthenticated(true); 477 subscriber->setAuthenticated(true);
478 478
479 - std::shared_ptr<Client> websocketsubscriber = std::make_shared<Client>(fdnull2, threaddata, nullptr, true, nullptr, settings, true); 479 + std::shared_ptr<Client> websocketsubscriber = std::make_shared<Client>(fdnull2, threaddata, nullptr, true, nullptr, settings.get(), true);
480 websocketsubscriber->setClientProperties(ProtocolVersion::Mqtt311, "websocketsubscriber", "websocksubuser", true, 60); 480 websocketsubscriber->setClientProperties(ProtocolVersion::Mqtt311, "websocketsubscriber", "websocksubuser", true, 60);
481 websocketsubscriber->setAuthenticated(true); 481 websocketsubscriber->setAuthenticated(true);
482 websocketsubscriber->setFakeUpgraded(); 482 websocketsubscriber->setFakeUpgraded();
@@ -574,7 +574,7 @@ void MainApp::start() @@ -574,7 +574,7 @@ void MainApp::start()
574 SSL_set_fd(clientSSL, fd); 574 SSL_set_fd(clientSSL, fd);
575 } 575 }
576 576
577 - std::shared_ptr<Client> client = std::make_shared<Client>(fd, thread_data, clientSSL, listener->websocket, addr, settings); 577 + std::shared_ptr<Client> client = std::make_shared<Client>(fd, thread_data, clientSSL, listener->websocket, addr, settings.get());
578 thread_data->giveClient(client); 578 thread_data->giveClient(client);
579 } 579 }
580 else 580 else
mqttpacket.cpp
@@ -776,26 +776,23 @@ void MqttPacket::handleUnsubscribe() @@ -776,26 +776,23 @@ void MqttPacket::handleUnsubscribe()
776 sender->writeMqttPacket(response); 776 sender->writeMqttPacket(response);
777 } 777 }
778 778
779 -void MqttPacket::handlePublish(const bool stopAfterParsing) 779 +void MqttPacket::parsePublishData()
780 { 780 {
781 const uint16_t variable_header_length = readTwoBytesToUInt16(); 781 const uint16_t variable_header_length = readTwoBytesToUInt16();
782 782
783 - bool retain = (first_byte & 0b00000001); 783 + publishData.retain = (first_byte & 0b00000001);
784 bool dup = !!(first_byte & 0b00001000); 784 bool dup = !!(first_byte & 0b00001000);
785 - char qos = (first_byte & 0b00000110) >> 1; 785 + publishData.qos = (first_byte & 0b00000110) >> 1;
786 786
787 - if (qos > 2) 787 + if (publishData.qos > 2)
788 throw ProtocolError("QoS 3 is a protocol violation."); 788 throw ProtocolError("QoS 3 is a protocol violation.");
789 - this->publishData.qos = qos;  
790 789
791 - if (qos == 0 && dup) 790 + if (publishData.qos == 0 && dup)
792 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal."); 791 throw ProtocolError("Duplicate flag is set for QoS 0 packet. This is illegal.");
793 792
794 publishData.topic = std::string(readBytes(variable_header_length), variable_header_length); 793 publishData.topic = std::string(readBytes(variable_header_length), variable_header_length);
795 794
796 - ReasonCodes ackCode = ReasonCodes::Success;  
797 -  
798 - if (qos) 795 + if (publishData.qos)
799 { 796 {
800 packet_id_pos = pos; 797 packet_id_pos = pos;
801 packet_id = readTwoBytesToUInt16(); 798 packet_id = readTwoBytesToUInt16();
@@ -827,6 +824,12 @@ void MqttPacket::handlePublish(const bool stopAfterParsing) @@ -827,6 +824,12 @@ void MqttPacket::handlePublish(const bool stopAfterParsing)
827 break; 824 break;
828 case Mqtt5Properties::TopicAlias: 825 case Mqtt5Properties::TopicAlias:
829 { 826 {
  827 + // For when we use packets has helpers without a senser (like loading packets from disk).
  828 + // Logically, this should never trip because there can't be aliases in such packets, but including
  829 + // a check to be sure.
  830 + if (!sender)
  831 + break;
  832 +
830 const uint16_t alias_id = readTwoBytesToUInt16(); 833 const uint16_t alias_id = readTwoBytesToUInt16();
831 this->hasTopicAlias = true; 834 this->hasTopicAlias = true;
832 835
@@ -889,6 +892,14 @@ void MqttPacket::handlePublish(const bool stopAfterParsing) @@ -889,6 +892,14 @@ void MqttPacket::handlePublish(const bool stopAfterParsing)
889 if (publishData.topic.empty()) 892 if (publishData.topic.empty())
890 throw ProtocolError("Empty publish topic"); 893 throw ProtocolError("Empty publish topic");
891 894
  895 + payloadLen = remainingAfterPos();
  896 + payloadStart = pos;
  897 +}
  898 +
  899 +void MqttPacket::handlePublish()
  900 +{
  901 + parsePublishData();
  902 +
892 if (!isValidUtf8(publishData.topic, true)) 903 if (!isValidUtf8(publishData.topic, true))
893 { 904 {
894 const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str()); 905 const std::string err = formatString("Client '%s' published a message with invalid UTF8 or $/+/# in it. Dropping.", sender->repr().c_str());
@@ -897,41 +908,36 @@ void MqttPacket::handlePublish(const bool stopAfterParsing) @@ -897,41 +908,36 @@ void MqttPacket::handlePublish(const bool stopAfterParsing)
897 } 908 }
898 909
899 #ifndef NDEBUG 910 #ifndef NDEBUG
900 - logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), qos, retain, dup); 911 + logger->logf(LOG_DEBUG, "Publish received, topic '%s'. QoS=%d. Retain=%d, dup=%d", publishData.topic.c_str(), publishData.qos, publishData.retain, dup);
901 #endif 912 #endif
902 913
903 - payloadLen = remainingAfterPos();  
904 - payloadStart = pos; 914 + ReasonCodes ackCode = ReasonCodes::Success;
905 915
906 sender->getThreadData()->incrementReceivedMessageCount(); 916 sender->getThreadData()->incrementReceivedMessageCount();
907 917
908 - // TODO: or maybe create a function parsePublishData().  
909 - if (stopAfterParsing)  
910 - return;  
911 -  
912 Authentication &authentication = *ThreadGlobals::getAuth(); 918 Authentication &authentication = *ThreadGlobals::getAuth();
913 919
914 // Working with a local copy because the subscribing action will modify this->packet_id. See the PublishCopyFactory. 920 // Working with a local copy because the subscribing action will modify this->packet_id. See the PublishCopyFactory.
915 const uint16_t _packet_id = this->packet_id; 921 const uint16_t _packet_id = this->packet_id;
916 922
917 - if (qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id)) 923 + if (publishData.qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id))
918 { 924 {
919 ackCode = ReasonCodes::PacketIdentifierInUse; 925 ackCode = ReasonCodes::PacketIdentifierInUse;
920 } 926 }
921 else 927 else
922 { 928 {
923 // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish. 929 // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish.
924 - if (qos == 2) 930 + if (publishData.qos == 2)
925 sender->getSession()->addIncomingQoS2MessageId(_packet_id); 931 sender->getSession()->addIncomingQoS2MessageId(_packet_id);
926 932
927 splitTopic(publishData.topic, publishData.subtopics); 933 splitTopic(publishData.topic, publishData.subtopics);
928 934
929 - if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, qos, retain, getUserProperties()) == AuthResult::success) 935 + if (authentication.aclCheck(sender->getClientId(), sender->getUsername(), publishData.topic, publishData.subtopics, AclAccess::write, publishData.qos, publishData.retain, getUserProperties()) == AuthResult::success)
930 { 936 {
931 - if (retain) 937 + if (publishData.retain)
932 { 938 {
933 std::string payload(readBytes(payloadLen), payloadLen); 939 std::string payload(readBytes(payloadLen), payloadLen);
934 - sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, qos); 940 + sender->getThreadData()->getSubscriptionStore()->setRetainedMessage(publishData.topic, publishData.subtopics, payload, publishData.qos);
935 } 941 }
936 942
937 // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. 943 // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3].
@@ -953,9 +959,9 @@ void MqttPacket::handlePublish(const bool stopAfterParsing) @@ -953,9 +959,9 @@ void MqttPacket::handlePublish(const bool stopAfterParsing)
953 this->packet_id = 0; 959 this->packet_id = 0;
954 #endif 960 #endif
955 961
956 - if (qos > 0) 962 + if (publishData.qos > 0)
957 { 963 {
958 - const PacketType responseType = qos == 1 ? PacketType::PUBACK : PacketType::PUBREC; 964 + const PacketType responseType = publishData.qos == 1 ? PacketType::PUBACK : PacketType::PUBREC;
959 PubResponse pubAck(this->protocolVersion, responseType, ackCode, _packet_id); 965 PubResponse pubAck(this->protocolVersion, responseType, ackCode, _packet_id);
960 MqttPacket response(pubAck); 966 MqttPacket response(pubAck);
961 sender->writeMqttPacket(response); 967 sender->writeMqttPacket(response);
mqttpacket.h
@@ -96,7 +96,8 @@ public: @@ -96,7 +96,8 @@ public:
96 void handleSubscribe(); 96 void handleSubscribe();
97 void handleUnsubscribe(); 97 void handleUnsubscribe();
98 void handlePing(); 98 void handlePing();
99 - void handlePublish(const bool stopAfterParsing = false); 99 + void parsePublishData();
  100 + void handlePublish();
100 void handlePubAck(); 101 void handlePubAck();
101 void handlePubRec(); 102 void handlePubRec();
102 void handlePubRel(); 103 void handlePubRel();
sessionsandsubscriptionsdb.cpp
@@ -85,9 +85,9 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() @@ -85,9 +85,9 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2()
85 CirBuf cirbuf(1024); 85 CirBuf cirbuf(1024);
86 86
87 // TODO: all that settings and thread data needs to be removed from Client. 87 // TODO: all that settings and thread data needs to be removed from Client.
88 - std::shared_ptr<ThreadData> dummyThreadData;  
89 - std::shared_ptr<Settings> dummySettings(new Settings()); // TODO: this is wrong: these are not from config file  
90 - std::shared_ptr<Client> dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, dummySettings, false)); 88 + std::shared_ptr<ThreadData> dummyThreadData; // which thread am I going get/use here?
  89 + std::shared_ptr<Client> dummyClient(new Client(0, dummyThreadData, nullptr, false, nullptr, settings, false));
  90 + dummyClient->setClientProperties(ProtocolVersion::Mqtt5, "Dummyforloadingqueuedqos", "nobody", true, 60);
91 91
92 for (uint32_t i = 0; i < nrOfSessions; i++) 92 for (uint32_t i = 0; i < nrOfSessions; i++)
93 { 93 {
@@ -123,7 +123,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2() @@ -123,7 +123,7 @@ SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV2()
123 cirbuf.advanceHead(packlen); 123 cirbuf.advanceHead(packlen);
124 MqttPacket pack(cirbuf, packlen, 2, dummyClient); // TODO: store the 2 in the file 124 MqttPacket pack(cirbuf, packlen, 2, dummyClient); // TODO: store the 2 in the file
125 125
126 - pack.handlePublish(true); 126 + pack.parsePublishData();
127 Publish pub(pack.getPublishData()); 127 Publish pub(pack.getPublishData());
128 128
129 logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); 129 logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
@@ -240,12 +240,15 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess @@ -240,12 +240,15 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector&lt;std::unique_ptr&lt;Sess
240 qosPacketsCounted++; 240 qosPacketsCounted++;
241 241
242 const Publish &pub = p.getPublish(); 242 const Publish &pub = p.getPublish();
  243 +
243 assert(!pub.splitTopic); 244 assert(!pub.splitTopic);
  245 + assert(!pub.skipTopic);
244 assert(pub.topicAlias == 0); 246 assert(pub.topicAlias == 0);
245 247
246 logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); 248 logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());
247 249
248 - const MqttPacket pack(ProtocolVersion::Mqtt5, pub); 250 + MqttPacket pack(ProtocolVersion::Mqtt5, pub);
  251 + pack.setPacketId(p.getPacketId());
249 const uint32_t packSize = pack.getSizeIncludingNonPresentHeader(); 252 const uint32_t packSize = pack.getSizeIncludingNonPresentHeader();
250 cirbuf.reset(); 253 cirbuf.reset();
251 cirbuf.ensureFreeSpace(packSize + 32); 254 cirbuf.ensureFreeSpace(packSize + 32);