diff --git a/client.cpp b/client.cpp index 8b8b37f..152644b 100644 --- a/client.cpp +++ b/client.cpp @@ -25,6 +25,7 @@ License along with FlashMQ. If not, see . #include "logger.h" #include "utils.h" +#include "threadglobals.h" Client::Client(int fd, std::shared_ptr threadData, SSL *ssl, bool websocket, struct sockaddr *addr, std::shared_ptr settings, bool fuzzMode) : fd(fd), @@ -422,12 +423,24 @@ void Client::bufferToMqttPackets(std::vector &packetQueueIn, std::sh void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession) { + const Settings *settings = ThreadGlobals::getSettings(); + + setClientProperties(protocolVersion, clientId, username, connectPacketSeen, keepalive, cleanSession, + settings->maxPacketSize, 0); +} + + +void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, + bool cleanSession, uint32_t maxPacketSize, uint16_t maxTopicAliases) +{ this->protocolVersion = protocolVersion; this->clientid = clientId; this->username = username; this->connectPacketSeen = connectPacketSeen; this->keepalive = keepalive; this->cleanSession = cleanSession; + this->maxPacketSize = maxPacketSize; + this->maxTopicAliases = maxTopicAliases; } void Client::setWill(const std::string &topic, const std::string &payload, bool retain, char qos) diff --git a/client.h b/client.h index 0ca6eaa..c0da255 100644 --- a/client.h +++ b/client.h @@ -52,7 +52,8 @@ class Client ProtocolVersion protocolVersion = ProtocolVersion::None; const size_t initialBufferSize = 0; - const size_t maxPacketSize = 0; + uint32_t maxPacketSize = 0; + uint16_t maxTopicAliases = 0; IoWrapper ioWrapper; std::string transportStr; @@ -108,6 +109,8 @@ public: bool readFdIntoBuffer(); void bufferToMqttPackets(std::vector &packetQueueIn, std::shared_ptr &sender); void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); + void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, + bool cleanSession, uint32_t maxPacketSize, uint16_t maxTopicAliases); void setWill(const std::string &topic, const std::string &payload, bool retain, char qos); void clearWill(); void setAuthenticated(bool value) { authenticated = value;} diff --git a/configfileparser.cpp b/configfileparser.cpp index f89e59c..86bb197 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -406,8 +406,8 @@ void ConfigFileParser::loadFile(bool test) if (key == "expire_sessions_after_seconds") { - int64_t newVal = std::stoi(value); - if (newVal < 0 || (newVal > 0 && newVal <= 300)) // 0 means disable + uint32_t newVal = std::stoi(value); + if (newVal > 0 && newVal <= 300) // 0 means disable { throw ConfigFileException(formatString("expire_sessions_after_seconds value '%d' is invalid. Valid values are 0, or 300 or higher.", newVal)); } diff --git a/mainapp.cpp b/mainapp.cpp index d6e1231..a13e7da 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -713,7 +713,7 @@ void MainApp::queueCleanup() { std::lock_guard locker(eventMutex); - auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get(), settings->expireSessionsAfterSeconds); + auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get()); taskQueue.push_front(f); wakeUpThread(); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index d5c85be..7fd7890 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -321,7 +321,7 @@ void MqttPacket::handleConnect() uint16_t variable_header_length = readTwoBytesToUInt16(); - const Settings &settings = sender->getThreadData()->settingsLocalCopy; + const Settings &settings = *ThreadGlobals::getSettings(); if (variable_header_length == 4 || variable_header_length == 6) { @@ -330,9 +330,12 @@ void MqttPacket::handleConnect() char protocol_level = readByte(); - if (magic_marker == "MQTT" && protocol_level == 0x04) + if (magic_marker == "MQTT") { - protocolVersion = ProtocolVersion::Mqtt311; + if (protocol_level == 0x04) + protocolVersion = ProtocolVersion::Mqtt311; + if (protocol_level == 0x05) + protocolVersion = ProtocolVersion::Mqtt5; } else if (magic_marker == "MQIsdp" && protocol_level == 0x03) { @@ -367,6 +370,54 @@ void MqttPacket::handleConnect() uint16_t keep_alive = readTwoBytesToUInt16(); + uint16_t max_qos_packets = settings.maxQosMsgPendingPerClient; + uint32_t session_expire = settings.expireSessionsAfterSeconds > 0 ? settings.expireSessionsAfterSeconds : std::numeric_limits::max(); + uint32_t max_packet_size = settings.maxPacketSize; + uint16_t max_topic_aliases = 0; + bool request_response_information = false; + bool request_problem_information = false; + + if (protocolVersion == ProtocolVersion::Mqtt5) + { + const size_t proplen = decodeVariableByteIntAtPos(); + const size_t prop_end_at = pos + proplen; + + while (pos < prop_end_at) + { + const Mqtt5Properties prop = static_cast(readByte()); + + switch (prop) + { + case Mqtt5Properties::SessionExpiryInterval: + session_expire = std::min(readFourBytesToUint32(), session_expire); + break; + case Mqtt5Properties::ReceiveMaximum: + max_qos_packets = std::min(readTwoBytesToUInt16(), max_qos_packets); + break; + case Mqtt5Properties::MaximumPacketSize: + max_packet_size = std::min(readFourBytesToUint32(), max_packet_size); + break; + case Mqtt5Properties::TopicAliasMaximum: + max_topic_aliases = readTwoBytesToUInt16(); + break; + case Mqtt5Properties::RequestResponseInformation: + request_response_information = !!readByte(); + break; + case Mqtt5Properties::RequestProblemInformation: + request_problem_information = !!readByte(); + break; + case Mqtt5Properties::UserProperty: + break; + case Mqtt5Properties::AuthenticationMethod: + break; + case Mqtt5Properties::AuthenticationData: + break; + default: + throw ProtocolError("Invalid connect property."); + } + } + } + uint16_t client_id_length = readTwoBytesToUInt16(); std::string client_id(readBytes(client_id_length), client_id_length); @@ -442,7 +493,7 @@ void MqttPacket::handleConnect() client_id = getSecureRandomString(23); } - sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session); + sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session, max_packet_size, max_topic_aliases); sender->setWill(will_topic, will_payload, will_retain, will_qos); bool accessGranted = false; @@ -475,7 +526,7 @@ void MqttPacket::handleConnect() sender->writeMqttPacket(response); logger->logf(LOG_NOTICE, "Client '%s' logged in successfully", sender->repr().c_str()); - subscriptionStore->registerClientAndKickExistingOne(sender); + subscriptionStore->registerClientAndKickExistingOne(sender, max_qos_packets, session_expire); } else { @@ -921,11 +972,45 @@ uint16_t MqttPacket::readTwoBytesToUInt16() return i; } +uint32_t MqttPacket::readFourBytesToUint32() +{ + if (pos + 4 > bites.size()) + throw ProtocolError("Invalid packet: header specifies invalid length."); + + const uint8_t a = bites[pos++]; + const uint8_t b = bites[pos++]; + const uint8_t c = bites[pos++]; + const uint8_t d = bites[pos++]; + uint32_t i = (a << 24) | (b << 16) | (c << 8) | d; + return i; +} + size_t MqttPacket::remainingAfterPos() { return bites.size() - pos; } +size_t MqttPacket::decodeVariableByteIntAtPos() +{ + uint64_t multiplier = 1; + size_t value = 0; + uint8_t encodedByte = 0; + do + { + if (pos >= bites.size()) + throw ProtocolError("Variable byte int length goes out of packet. Corrupt."); + + encodedByte = bites[pos++]; + value += (encodedByte & 127) * multiplier; + multiplier *= 128; + if (multiplier > 128*128*128*128) + throw ProtocolError("Malformed Remaining Length."); + } + while ((encodedByte & 128) != 0); + + return value; +} + bool MqttPacket::getRetain() const { return (first_byte & 0b00000001); diff --git a/mqttpacket.h b/mqttpacket.h index 396eee3..283cc50 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -63,7 +63,9 @@ class MqttPacket void writeUint16(uint16_t x); void writeBytes(const char *b, size_t len); uint16_t readTwoBytesToUInt16(); + uint32_t readFourBytesToUint32(); size_t remainingAfterPos(); + size_t decodeVariableByteIntAtPos(); void calculateRemainingLength(); void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0); diff --git a/session.cpp b/session.cpp index c5540e9..705a2e6 100644 --- a/session.cpp +++ b/session.cpp @@ -20,12 +20,17 @@ License along with FlashMQ. If not, see . #include "session.h" #include "client.h" #include "threadglobals.h" +#include "threadglobals.h" std::chrono::time_point appStartTime = std::chrono::steady_clock::now(); Session::Session() { + const Settings &settings = *ThreadGlobals::getSettings(); + // Sessions also get defaults from the handleConnect() method, but when you create sessions elsewhere, we do need some sensible defaults. + this->maxQosMsgPending = settings.maxQosMsgPendingPerClient; + this->sessionExpiryInterval = settings.expireSessionsAfterSeconds; } int64_t Session::getProgramStartedAtUnixTimestamp() @@ -174,7 +179,7 @@ void Session::writePacket(PublishCopyFactory ©Factory, const char max_qos, u std::unique_lock locker(qosQueueMutex); const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size(); - if (totalQosPacketsInTransit >= settings->maxQosMsgPendingPerClient + if (totalQosPacketsInTransit >= maxQosMsgPending || (qosPacketQueue.getByteSize() >= settings->maxQosBytesPendingPerClient && qosPacketQueue.size() > 0)) { if (QoSLogPrintedAtId != nextPacketId) @@ -286,9 +291,9 @@ void Session::touch() lastTouched = std::chrono::steady_clock::now(); } -bool Session::hasExpired(int expireAfterSeconds) +bool Session::hasExpired() const { - std::chrono::seconds expireAfter(expireAfterSeconds); + std::chrono::seconds expireAfter(sessionExpiryInterval); std::chrono::time_point now = std::chrono::steady_clock::now(); return client.expired() && (lastTouched + expireAfter) < now; } @@ -347,3 +352,9 @@ bool Session::getCleanSession() const return c->getCleanSession(); } + +void Session::setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval) +{ + this->maxQosMsgPending = maxQosPackets; + this->sessionExpiryInterval = sessionExpiryInterval; +} diff --git a/session.h b/session.h index 4f92abb..a525914 100644 --- a/session.h +++ b/session.h @@ -46,6 +46,8 @@ class Session std::mutex qosQueueMutex; uint16_t nextPacketId = 0; uint16_t qosInFlightCounter = 0; + uint32_t sessionExpiryInterval = 0; + uint16_t maxQosMsgPending; uint16_t QoSLogPrintedAtId = 0; std::chrono::time_point lastTouched = std::chrono::steady_clock::now(); Logger *logger = Logger::getInstance(); @@ -75,7 +77,7 @@ public: uint64_t sendPendingQosMessages(); void touch(std::chrono::time_point val); void touch(); - bool hasExpired(int expireAfterSeconds); + bool hasExpired() const; void addIncomingQoS2MessageId(uint16_t packet_id); bool incomingQoS2MessageIdInTransit(uint16_t packet_id); @@ -85,6 +87,8 @@ public: void removeOutgoingQoS2MessageId(u_int16_t packet_id); bool getCleanSession() const; + + void setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval); }; #endif // SESSION_H diff --git a/settings.h b/settings.h index 55d0216..6910629 100644 --- a/settings.h +++ b/settings.h @@ -52,11 +52,11 @@ public: std::string mosquittoAclFile; bool allowAnonymous = false; int rlimitNoFile = 1000000; - uint64_t expireSessionsAfterSeconds = 1209600; + uint32_t expireSessionsAfterSeconds = 1209600; int authPluginTimerPeriod = 60; std::string storageDir; int threadCount = 0; - uint maxQosMsgPendingPerClient = 512; + uint16_t maxQosMsgPendingPerClient = 512; uint maxQosBytesPendingPerClient = 65536; std::list> listeners; // Default one is created later, when none are defined. diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp index 11c8d4c..a0d0cce 100644 --- a/subscriptionstore.cpp +++ b/subscriptionstore.cpp @@ -22,6 +22,7 @@ License along with FlashMQ. If not, see . #include "rwlockguard.h" #include "retainedmessagesdb.h" #include "publishcopyfactory.h" +#include "threadglobals.h" ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr &ses, char qos) : session(ses), @@ -200,9 +201,15 @@ void SubscriptionStore::removeSubscription(std::shared_ptr &client, cons } -// Removes an existing client when it already exists [MQTT-3.1.4-2]. void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client) { + const Settings *settings = ThreadGlobals::getSettings(); + registerClientAndKickExistingOne(client, settings->maxQosMsgPendingPerClient, settings->expireSessionsAfterSeconds); +} + +// Removes an existing client when it already exists [MQTT-3.1.4-2]. +void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client, uint16_t maxQosPackets, uint32_t sessionExpiryInterval) +{ RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.wrlock(); @@ -247,6 +254,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr session->assignActiveConnection(client); client->assignSession(session); + session->setSessionProperties(maxQosPackets, sessionExpiryInterval); uint64_t count = session->sendPendingQosMessages(); client->getThreadData()->incrementSentMessageCount(count); } @@ -533,8 +541,12 @@ void SubscriptionStore::removeSession(const std::string &clientid) } } -// This is not MQTT compliant, but the standard doesn't keep real world constraints into account. -void SubscriptionStore::removeExpiredSessionsClients(int expireSessionsAfterSeconds) +/** + * @brief SubscriptionStore::removeExpiredSessionsClients removes expired sessions. + * + * For Mqtt3 this is non-standard, but the standard doesn't keep real world constraints into account. + */ +void SubscriptionStore::removeExpiredSessionsClients() { RWLockGuard lock_guard(&subscriptionsRwlock); lock_guard.wrlock(); @@ -546,7 +558,7 @@ void SubscriptionStore::removeExpiredSessionsClients(int expireSessionsAfterSeco { std::shared_ptr &session = session_it->second; - if (session->hasExpired(expireSessionsAfterSeconds)) + if (session->hasExpired()) { logger->logf(LOG_DEBUG, "Removing expired session from store %s", session->getClientId().c_str()); session_it = sessionsById.erase(session_it); diff --git a/subscriptionstore.h b/subscriptionstore.h index 2c2ab2e..e7e549c 100644 --- a/subscriptionstore.h +++ b/subscriptionstore.h @@ -119,6 +119,7 @@ public: void addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos); void removeSubscription(std::shared_ptr &client, const std::string &topic); void registerClientAndKickExistingOne(std::shared_ptr &client); + void registerClientAndKickExistingOne(std::shared_ptr &client, uint16_t maxQosPackets, uint32_t sessionExpiryInterval); bool sessionPresent(const std::string &clientid); void queuePacketAtSubscribers(const std::vector &subtopics, MqttPacket &packet, bool dollar = false); @@ -129,7 +130,7 @@ public: void setRetainedMessage(const std::string &topic, const std::vector &subtopics, const std::string &payload, char qos); void removeSession(const std::string &clientid); - void removeExpiredSessionsClients(int expireSessionsAfterSeconds); + void removeExpiredSessionsClients(); int64_t getRetainedMessageCount() const; uint64_t getSessionCount() const; diff --git a/types.h b/types.h index f232e75..05aa950 100644 --- a/types.h +++ b/types.h @@ -47,7 +47,40 @@ enum class ProtocolVersion { None = 0, Mqtt31 = 0x03, - Mqtt311 = 0x04 + Mqtt311 = 0x04, + Mqtt5 = 0x05 +}; + +enum class Mqtt5Properties +{ + None = 0, + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 13, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQoS = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42 }; enum class ConnAckReturnCodes diff --git a/utils.cpp b/utils.cpp index 8ce6334..0deadaf 100644 --- a/utils.cpp +++ b/utils.cpp @@ -662,6 +662,8 @@ const std::string protocolVersionString(ProtocolVersion p) return "3.1"; case ProtocolVersion::Mqtt311: return "3.1.1"; + case ProtocolVersion::Mqtt5: + return "5.0"; default: return "unknown"; }