Commit a682f1e0348fae4fff5c6a70a01d16775b547737

Authored by Wiebe Cazemier
1 parent e6a16009

Start mqtt5 by the connect properties

Most of it is limits we already implemented non-standard compliant.
client.cpp
... ... @@ -25,6 +25,7 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
25 25  
26 26 #include "logger.h"
27 27 #include "utils.h"
  28 +#include "threadglobals.h"
28 29  
29 30 Client::Client(int fd, std::shared_ptr<ThreadData> threadData, SSL *ssl, bool websocket, struct sockaddr *addr, std::shared_ptr<Settings> settings, bool fuzzMode) :
30 31 fd(fd),
... ... @@ -422,12 +423,24 @@ void Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn, std::sh
422 423  
423 424 void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession)
424 425 {
  426 + const Settings *settings = ThreadGlobals::getSettings();
  427 +
  428 + setClientProperties(protocolVersion, clientId, username, connectPacketSeen, keepalive, cleanSession,
  429 + settings->maxPacketSize, 0);
  430 +}
  431 +
  432 +
  433 +void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive,
  434 + bool cleanSession, uint32_t maxPacketSize, uint16_t maxTopicAliases)
  435 +{
425 436 this->protocolVersion = protocolVersion;
426 437 this->clientid = clientId;
427 438 this->username = username;
428 439 this->connectPacketSeen = connectPacketSeen;
429 440 this->keepalive = keepalive;
430 441 this->cleanSession = cleanSession;
  442 + this->maxPacketSize = maxPacketSize;
  443 + this->maxTopicAliases = maxTopicAliases;
431 444 }
432 445  
433 446 void Client::setWill(const std::string &topic, const std::string &payload, bool retain, char qos)
... ...
client.h
... ... @@ -52,7 +52,8 @@ class Client
52 52 ProtocolVersion protocolVersion = ProtocolVersion::None;
53 53  
54 54 const size_t initialBufferSize = 0;
55   - const size_t maxPacketSize = 0;
  55 + uint32_t maxPacketSize = 0;
  56 + uint16_t maxTopicAliases = 0;
56 57  
57 58 IoWrapper ioWrapper;
58 59 std::string transportStr;
... ... @@ -108,6 +109,8 @@ public:
108 109 bool readFdIntoBuffer();
109 110 void bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
110 111 void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
  112 + void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive,
  113 + bool cleanSession, uint32_t maxPacketSize, uint16_t maxTopicAliases);
111 114 void setWill(const std::string &topic, const std::string &payload, bool retain, char qos);
112 115 void clearWill();
113 116 void setAuthenticated(bool value) { authenticated = value;}
... ...
configfileparser.cpp
... ... @@ -406,8 +406,8 @@ void ConfigFileParser::loadFile(bool test)
406 406  
407 407 if (key == "expire_sessions_after_seconds")
408 408 {
409   - int64_t newVal = std::stoi(value);
410   - if (newVal < 0 || (newVal > 0 && newVal <= 300)) // 0 means disable
  409 + uint32_t newVal = std::stoi(value);
  410 + if (newVal > 0 && newVal <= 300) // 0 means disable
411 411 {
412 412 throw ConfigFileException(formatString("expire_sessions_after_seconds value '%d' is invalid. Valid values are 0, or 300 or higher.", newVal));
413 413 }
... ...
mainapp.cpp
... ... @@ -713,7 +713,7 @@ void MainApp::queueCleanup()
713 713 {
714 714 std::lock_guard<std::mutex> locker(eventMutex);
715 715  
716   - auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get(), settings->expireSessionsAfterSeconds);
  716 + auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get());
717 717 taskQueue.push_front(f);
718 718  
719 719 wakeUpThread();
... ...
mqttpacket.cpp
... ... @@ -321,7 +321,7 @@ void MqttPacket::handleConnect()
321 321  
322 322 uint16_t variable_header_length = readTwoBytesToUInt16();
323 323  
324   - const Settings &settings = sender->getThreadData()->settingsLocalCopy;
  324 + const Settings &settings = *ThreadGlobals::getSettings();
325 325  
326 326 if (variable_header_length == 4 || variable_header_length == 6)
327 327 {
... ... @@ -330,9 +330,12 @@ void MqttPacket::handleConnect()
330 330  
331 331 char protocol_level = readByte();
332 332  
333   - if (magic_marker == "MQTT" && protocol_level == 0x04)
  333 + if (magic_marker == "MQTT")
334 334 {
335   - protocolVersion = ProtocolVersion::Mqtt311;
  335 + if (protocol_level == 0x04)
  336 + protocolVersion = ProtocolVersion::Mqtt311;
  337 + if (protocol_level == 0x05)
  338 + protocolVersion = ProtocolVersion::Mqtt5;
336 339 }
337 340 else if (magic_marker == "MQIsdp" && protocol_level == 0x03)
338 341 {
... ... @@ -367,6 +370,54 @@ void MqttPacket::handleConnect()
367 370  
368 371 uint16_t keep_alive = readTwoBytesToUInt16();
369 372  
  373 + uint16_t max_qos_packets = settings.maxQosMsgPendingPerClient;
  374 + uint32_t session_expire = settings.expireSessionsAfterSeconds > 0 ? settings.expireSessionsAfterSeconds : std::numeric_limits<uint32_t>::max();
  375 + uint32_t max_packet_size = settings.maxPacketSize;
  376 + uint16_t max_topic_aliases = 0;
  377 + bool request_response_information = false;
  378 + bool request_problem_information = false;
  379 +
  380 + if (protocolVersion == ProtocolVersion::Mqtt5)
  381 + {
  382 + const size_t proplen = decodeVariableByteIntAtPos();
  383 + const size_t prop_end_at = pos + proplen;
  384 +
  385 + while (pos < prop_end_at)
  386 + {
  387 + const Mqtt5Properties prop = static_cast<Mqtt5Properties>(readByte());
  388 +
  389 + switch (prop)
  390 + {
  391 + case Mqtt5Properties::SessionExpiryInterval:
  392 + session_expire = std::min<uint32_t>(readFourBytesToUint32(), session_expire);
  393 + break;
  394 + case Mqtt5Properties::ReceiveMaximum:
  395 + max_qos_packets = std::min<int16_t>(readTwoBytesToUInt16(), max_qos_packets);
  396 + break;
  397 + case Mqtt5Properties::MaximumPacketSize:
  398 + max_packet_size = std::min<uint32_t>(readFourBytesToUint32(), max_packet_size);
  399 + break;
  400 + case Mqtt5Properties::TopicAliasMaximum:
  401 + max_topic_aliases = readTwoBytesToUInt16();
  402 + break;
  403 + case Mqtt5Properties::RequestResponseInformation:
  404 + request_response_information = !!readByte();
  405 + break;
  406 + case Mqtt5Properties::RequestProblemInformation:
  407 + request_problem_information = !!readByte();
  408 + break;
  409 + case Mqtt5Properties::UserProperty:
  410 + break;
  411 + case Mqtt5Properties::AuthenticationMethod:
  412 + break;
  413 + case Mqtt5Properties::AuthenticationData:
  414 + break;
  415 + default:
  416 + throw ProtocolError("Invalid connect property.");
  417 + }
  418 + }
  419 + }
  420 +
370 421 uint16_t client_id_length = readTwoBytesToUInt16();
371 422 std::string client_id(readBytes(client_id_length), client_id_length);
372 423  
... ... @@ -442,7 +493,7 @@ void MqttPacket::handleConnect()
442 493 client_id = getSecureRandomString(23);
443 494 }
444 495  
445   - sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session);
  496 + sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session, max_packet_size, max_topic_aliases);
446 497 sender->setWill(will_topic, will_payload, will_retain, will_qos);
447 498  
448 499 bool accessGranted = false;
... ... @@ -475,7 +526,7 @@ void MqttPacket::handleConnect()
475 526 sender->writeMqttPacket(response);
476 527 logger->logf(LOG_NOTICE, "Client '%s' logged in successfully", sender->repr().c_str());
477 528  
478   - subscriptionStore->registerClientAndKickExistingOne(sender);
  529 + subscriptionStore->registerClientAndKickExistingOne(sender, max_qos_packets, session_expire);
479 530 }
480 531 else
481 532 {
... ... @@ -921,11 +972,45 @@ uint16_t MqttPacket::readTwoBytesToUInt16()
921 972 return i;
922 973 }
923 974  
  975 +uint32_t MqttPacket::readFourBytesToUint32()
  976 +{
  977 + if (pos + 4 > bites.size())
  978 + throw ProtocolError("Invalid packet: header specifies invalid length.");
  979 +
  980 + const uint8_t a = bites[pos++];
  981 + const uint8_t b = bites[pos++];
  982 + const uint8_t c = bites[pos++];
  983 + const uint8_t d = bites[pos++];
  984 + uint32_t i = (a << 24) | (b << 16) | (c << 8) | d;
  985 + return i;
  986 +}
  987 +
924 988 size_t MqttPacket::remainingAfterPos()
925 989 {
926 990 return bites.size() - pos;
927 991 }
928 992  
  993 +size_t MqttPacket::decodeVariableByteIntAtPos()
  994 +{
  995 + uint64_t multiplier = 1;
  996 + size_t value = 0;
  997 + uint8_t encodedByte = 0;
  998 + do
  999 + {
  1000 + if (pos >= bites.size())
  1001 + throw ProtocolError("Variable byte int length goes out of packet. Corrupt.");
  1002 +
  1003 + encodedByte = bites[pos++];
  1004 + value += (encodedByte & 127) * multiplier;
  1005 + multiplier *= 128;
  1006 + if (multiplier > 128*128*128*128)
  1007 + throw ProtocolError("Malformed Remaining Length.");
  1008 + }
  1009 + while ((encodedByte & 128) != 0);
  1010 +
  1011 + return value;
  1012 +}
  1013 +
929 1014 bool MqttPacket::getRetain() const
930 1015 {
931 1016 return (first_byte & 0b00000001);
... ...
mqttpacket.h
... ... @@ -63,7 +63,9 @@ class MqttPacket
63 63 void writeUint16(uint16_t x);
64 64 void writeBytes(const char *b, size_t len);
65 65 uint16_t readTwoBytesToUInt16();
  66 + uint32_t readFourBytesToUint32();
66 67 size_t remainingAfterPos();
  68 + size_t decodeVariableByteIntAtPos();
67 69  
68 70 void calculateRemainingLength();
69 71 void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0);
... ...
session.cpp
... ... @@ -20,12 +20,17 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
20 20 #include "session.h"
21 21 #include "client.h"
22 22 #include "threadglobals.h"
  23 +#include "threadglobals.h"
23 24  
24 25 std::chrono::time_point<std::chrono::steady_clock> appStartTime = std::chrono::steady_clock::now();
25 26  
26 27 Session::Session()
27 28 {
  29 + const Settings &settings = *ThreadGlobals::getSettings();
28 30  
  31 + // Sessions also get defaults from the handleConnect() method, but when you create sessions elsewhere, we do need some sensible defaults.
  32 + this->maxQosMsgPending = settings.maxQosMsgPendingPerClient;
  33 + this->sessionExpiryInterval = settings.expireSessionsAfterSeconds;
29 34 }
30 35  
31 36 int64_t Session::getProgramStartedAtUnixTimestamp()
... ... @@ -174,7 +179,7 @@ void Session::writePacket(PublishCopyFactory &amp;copyFactory, const char max_qos, u
174 179 std::unique_lock<std::mutex> locker(qosQueueMutex);
175 180  
176 181 const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();
177   - if (totalQosPacketsInTransit >= settings->maxQosMsgPendingPerClient
  182 + if (totalQosPacketsInTransit >= maxQosMsgPending
178 183 || (qosPacketQueue.getByteSize() >= settings->maxQosBytesPendingPerClient && qosPacketQueue.size() > 0))
179 184 {
180 185 if (QoSLogPrintedAtId != nextPacketId)
... ... @@ -286,9 +291,9 @@ void Session::touch()
286 291 lastTouched = std::chrono::steady_clock::now();
287 292 }
288 293  
289   -bool Session::hasExpired(int expireAfterSeconds)
  294 +bool Session::hasExpired() const
290 295 {
291   - std::chrono::seconds expireAfter(expireAfterSeconds);
  296 + std::chrono::seconds expireAfter(sessionExpiryInterval);
292 297 std::chrono::time_point<std::chrono::steady_clock> now = std::chrono::steady_clock::now();
293 298 return client.expired() && (lastTouched + expireAfter) < now;
294 299 }
... ... @@ -347,3 +352,9 @@ bool Session::getCleanSession() const
347 352  
348 353 return c->getCleanSession();
349 354 }
  355 +
  356 +void Session::setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval)
  357 +{
  358 + this->maxQosMsgPending = maxQosPackets;
  359 + this->sessionExpiryInterval = sessionExpiryInterval;
  360 +}
... ...
session.h
... ... @@ -46,6 +46,8 @@ class Session
46 46 std::mutex qosQueueMutex;
47 47 uint16_t nextPacketId = 0;
48 48 uint16_t qosInFlightCounter = 0;
  49 + uint32_t sessionExpiryInterval = 0;
  50 + uint16_t maxQosMsgPending;
49 51 uint16_t QoSLogPrintedAtId = 0;
50 52 std::chrono::time_point<std::chrono::steady_clock> lastTouched = std::chrono::steady_clock::now();
51 53 Logger *logger = Logger::getInstance();
... ... @@ -75,7 +77,7 @@ public:
75 77 uint64_t sendPendingQosMessages();
76 78 void touch(std::chrono::time_point<std::chrono::steady_clock> val);
77 79 void touch();
78   - bool hasExpired(int expireAfterSeconds);
  80 + bool hasExpired() const;
79 81  
80 82 void addIncomingQoS2MessageId(uint16_t packet_id);
81 83 bool incomingQoS2MessageIdInTransit(uint16_t packet_id);
... ... @@ -85,6 +87,8 @@ public:
85 87 void removeOutgoingQoS2MessageId(u_int16_t packet_id);
86 88  
87 89 bool getCleanSession() const;
  90 +
  91 + void setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval);
88 92 };
89 93  
90 94 #endif // SESSION_H
... ...
settings.h
... ... @@ -52,11 +52,11 @@ public:
52 52 std::string mosquittoAclFile;
53 53 bool allowAnonymous = false;
54 54 int rlimitNoFile = 1000000;
55   - uint64_t expireSessionsAfterSeconds = 1209600;
  55 + uint32_t expireSessionsAfterSeconds = 1209600;
56 56 int authPluginTimerPeriod = 60;
57 57 std::string storageDir;
58 58 int threadCount = 0;
59   - uint maxQosMsgPendingPerClient = 512;
  59 + uint16_t maxQosMsgPendingPerClient = 512;
60 60 uint maxQosBytesPendingPerClient = 65536;
61 61 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined.
62 62  
... ...
subscriptionstore.cpp
... ... @@ -22,6 +22,7 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
22 22 #include "rwlockguard.h"
23 23 #include "retainedmessagesdb.h"
24 24 #include "publishcopyfactory.h"
  25 +#include "threadglobals.h"
25 26  
26 27 ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr<Session> &ses, char qos) :
27 28 session(ses),
... ... @@ -200,9 +201,15 @@ void SubscriptionStore::removeSubscription(std::shared_ptr&lt;Client&gt; &amp;client, cons
200 201  
201 202 }
202 203  
203   -// Removes an existing client when it already exists [MQTT-3.1.4-2].
204 204 void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> &client)
205 205 {
  206 + const Settings *settings = ThreadGlobals::getSettings();
  207 + registerClientAndKickExistingOne(client, settings->maxQosMsgPendingPerClient, settings->expireSessionsAfterSeconds);
  208 +}
  209 +
  210 +// Removes an existing client when it already exists [MQTT-3.1.4-2].
  211 +void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr<Client> &client, uint16_t maxQosPackets, uint32_t sessionExpiryInterval)
  212 +{
206 213 RWLockGuard lock_guard(&subscriptionsRwlock);
207 214 lock_guard.wrlock();
208 215  
... ... @@ -247,6 +254,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr&lt;Client&gt;
247 254  
248 255 session->assignActiveConnection(client);
249 256 client->assignSession(session);
  257 + session->setSessionProperties(maxQosPackets, sessionExpiryInterval);
250 258 uint64_t count = session->sendPendingQosMessages();
251 259 client->getThreadData()->incrementSentMessageCount(count);
252 260 }
... ... @@ -533,8 +541,12 @@ void SubscriptionStore::removeSession(const std::string &amp;clientid)
533 541 }
534 542 }
535 543  
536   -// This is not MQTT compliant, but the standard doesn't keep real world constraints into account.
537   -void SubscriptionStore::removeExpiredSessionsClients(int expireSessionsAfterSeconds)
  544 +/**
  545 + * @brief SubscriptionStore::removeExpiredSessionsClients removes expired sessions.
  546 + *
  547 + * For Mqtt3 this is non-standard, but the standard doesn't keep real world constraints into account.
  548 + */
  549 +void SubscriptionStore::removeExpiredSessionsClients()
538 550 {
539 551 RWLockGuard lock_guard(&subscriptionsRwlock);
540 552 lock_guard.wrlock();
... ... @@ -546,7 +558,7 @@ void SubscriptionStore::removeExpiredSessionsClients(int expireSessionsAfterSeco
546 558 {
547 559 std::shared_ptr<Session> &session = session_it->second;
548 560  
549   - if (session->hasExpired(expireSessionsAfterSeconds))
  561 + if (session->hasExpired())
550 562 {
551 563 logger->logf(LOG_DEBUG, "Removing expired session from store %s", session->getClientId().c_str());
552 564 session_it = sessionsById.erase(session_it);
... ...
subscriptionstore.h
... ... @@ -119,6 +119,7 @@ public:
119 119 void addSubscription(std::shared_ptr<Client> &client, const std::string &topic, const std::vector<std::string> &subtopics, char qos);
120 120 void removeSubscription(std::shared_ptr<Client> &client, const std::string &topic);
121 121 void registerClientAndKickExistingOne(std::shared_ptr<Client> &client);
  122 + void registerClientAndKickExistingOne(std::shared_ptr<Client> &client, uint16_t maxQosPackets, uint32_t sessionExpiryInterval);
122 123 bool sessionPresent(const std::string &clientid);
123 124  
124 125 void queuePacketAtSubscribers(const std::vector<std::string> &subtopics, MqttPacket &packet, bool dollar = false);
... ... @@ -129,7 +130,7 @@ public:
129 130 void setRetainedMessage(const std::string &topic, const std::vector<std::string> &subtopics, const std::string &payload, char qos);
130 131  
131 132 void removeSession(const std::string &clientid);
132   - void removeExpiredSessionsClients(int expireSessionsAfterSeconds);
  133 + void removeExpiredSessionsClients();
133 134  
134 135 int64_t getRetainedMessageCount() const;
135 136 uint64_t getSessionCount() const;
... ...
... ... @@ -47,7 +47,40 @@ enum class ProtocolVersion
47 47 {
48 48 None = 0,
49 49 Mqtt31 = 0x03,
50   - Mqtt311 = 0x04
  50 + Mqtt311 = 0x04,
  51 + Mqtt5 = 0x05
  52 +};
  53 +
  54 +enum class Mqtt5Properties
  55 +{
  56 + None = 0,
  57 + PayloadFormatIndicator = 1,
  58 + MessageExpiryInterval = 2,
  59 + ContentType = 3,
  60 + ResponseTopic = 8,
  61 + CorrelationData = 9,
  62 + SubscriptionIdentifier = 11,
  63 + SessionExpiryInterval = 17,
  64 + AssignedClientIdentifier = 18,
  65 + ServerKeepAlive = 13,
  66 + AuthenticationMethod = 21,
  67 + AuthenticationData = 22,
  68 + RequestProblemInformation = 23,
  69 + WillDelayInterval = 24,
  70 + RequestResponseInformation = 25,
  71 + ResponseInformation = 26,
  72 + ServerReference = 28,
  73 + ReasonString = 31,
  74 + ReceiveMaximum = 33,
  75 + TopicAliasMaximum = 34,
  76 + TopicAlias = 35,
  77 + MaximumQoS = 36,
  78 + RetainAvailable = 37,
  79 + UserProperty = 38,
  80 + MaximumPacketSize = 39,
  81 + WildcardSubscriptionAvailable = 40,
  82 + SubscriptionIdentifierAvailable = 41,
  83 + SharedSubscriptionAvailable = 42
51 84 };
52 85  
53 86 enum class ConnAckReturnCodes
... ...
utils.cpp
... ... @@ -662,6 +662,8 @@ const std::string protocolVersionString(ProtocolVersion p)
662 662 return "3.1";
663 663 case ProtocolVersion::Mqtt311:
664 664 return "3.1.1";
  665 + case ProtocolVersion::Mqtt5:
  666 + return "5.0";
665 667 default:
666 668 return "unknown";
667 669 }
... ...