diff --git a/client.cpp b/client.cpp
index 8b8b37f..152644b 100644
--- a/client.cpp
+++ b/client.cpp
@@ -25,6 +25,7 @@ License along with FlashMQ. If not, see .
#include "logger.h"
#include "utils.h"
+#include "threadglobals.h"
Client::Client(int fd, std::shared_ptr threadData, SSL *ssl, bool websocket, struct sockaddr *addr, std::shared_ptr settings, bool fuzzMode) :
fd(fd),
@@ -422,12 +423,24 @@ void Client::bufferToMqttPackets(std::vector &packetQueueIn, std::sh
void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession)
{
+ const Settings *settings = ThreadGlobals::getSettings();
+
+ setClientProperties(protocolVersion, clientId, username, connectPacketSeen, keepalive, cleanSession,
+ settings->maxPacketSize, 0);
+}
+
+
+void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive,
+ bool cleanSession, uint32_t maxPacketSize, uint16_t maxTopicAliases)
+{
this->protocolVersion = protocolVersion;
this->clientid = clientId;
this->username = username;
this->connectPacketSeen = connectPacketSeen;
this->keepalive = keepalive;
this->cleanSession = cleanSession;
+ this->maxPacketSize = maxPacketSize;
+ this->maxTopicAliases = maxTopicAliases;
}
void Client::setWill(const std::string &topic, const std::string &payload, bool retain, char qos)
diff --git a/client.h b/client.h
index 0ca6eaa..c0da255 100644
--- a/client.h
+++ b/client.h
@@ -52,7 +52,8 @@ class Client
ProtocolVersion protocolVersion = ProtocolVersion::None;
const size_t initialBufferSize = 0;
- const size_t maxPacketSize = 0;
+ uint32_t maxPacketSize = 0;
+ uint16_t maxTopicAliases = 0;
IoWrapper ioWrapper;
std::string transportStr;
@@ -108,6 +109,8 @@ public:
bool readFdIntoBuffer();
void bufferToMqttPackets(std::vector &packetQueueIn, std::shared_ptr &sender);
void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
+ void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive,
+ bool cleanSession, uint32_t maxPacketSize, uint16_t maxTopicAliases);
void setWill(const std::string &topic, const std::string &payload, bool retain, char qos);
void clearWill();
void setAuthenticated(bool value) { authenticated = value;}
diff --git a/configfileparser.cpp b/configfileparser.cpp
index f89e59c..86bb197 100644
--- a/configfileparser.cpp
+++ b/configfileparser.cpp
@@ -406,8 +406,8 @@ void ConfigFileParser::loadFile(bool test)
if (key == "expire_sessions_after_seconds")
{
- int64_t newVal = std::stoi(value);
- if (newVal < 0 || (newVal > 0 && newVal <= 300)) // 0 means disable
+ uint32_t newVal = std::stoi(value);
+ if (newVal > 0 && newVal <= 300) // 0 means disable
{
throw ConfigFileException(formatString("expire_sessions_after_seconds value '%d' is invalid. Valid values are 0, or 300 or higher.", newVal));
}
diff --git a/mainapp.cpp b/mainapp.cpp
index d6e1231..a13e7da 100644
--- a/mainapp.cpp
+++ b/mainapp.cpp
@@ -713,7 +713,7 @@ void MainApp::queueCleanup()
{
std::lock_guard locker(eventMutex);
- auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get(), settings->expireSessionsAfterSeconds);
+ auto f = std::bind(&SubscriptionStore::removeExpiredSessionsClients, subscriptionStore.get());
taskQueue.push_front(f);
wakeUpThread();
diff --git a/mqttpacket.cpp b/mqttpacket.cpp
index d5c85be..7fd7890 100644
--- a/mqttpacket.cpp
+++ b/mqttpacket.cpp
@@ -321,7 +321,7 @@ void MqttPacket::handleConnect()
uint16_t variable_header_length = readTwoBytesToUInt16();
- const Settings &settings = sender->getThreadData()->settingsLocalCopy;
+ const Settings &settings = *ThreadGlobals::getSettings();
if (variable_header_length == 4 || variable_header_length == 6)
{
@@ -330,9 +330,12 @@ void MqttPacket::handleConnect()
char protocol_level = readByte();
- if (magic_marker == "MQTT" && protocol_level == 0x04)
+ if (magic_marker == "MQTT")
{
- protocolVersion = ProtocolVersion::Mqtt311;
+ if (protocol_level == 0x04)
+ protocolVersion = ProtocolVersion::Mqtt311;
+ if (protocol_level == 0x05)
+ protocolVersion = ProtocolVersion::Mqtt5;
}
else if (magic_marker == "MQIsdp" && protocol_level == 0x03)
{
@@ -367,6 +370,54 @@ void MqttPacket::handleConnect()
uint16_t keep_alive = readTwoBytesToUInt16();
+ uint16_t max_qos_packets = settings.maxQosMsgPendingPerClient;
+ uint32_t session_expire = settings.expireSessionsAfterSeconds > 0 ? settings.expireSessionsAfterSeconds : std::numeric_limits::max();
+ uint32_t max_packet_size = settings.maxPacketSize;
+ uint16_t max_topic_aliases = 0;
+ bool request_response_information = false;
+ bool request_problem_information = false;
+
+ if (protocolVersion == ProtocolVersion::Mqtt5)
+ {
+ const size_t proplen = decodeVariableByteIntAtPos();
+ const size_t prop_end_at = pos + proplen;
+
+ while (pos < prop_end_at)
+ {
+ const Mqtt5Properties prop = static_cast(readByte());
+
+ switch (prop)
+ {
+ case Mqtt5Properties::SessionExpiryInterval:
+ session_expire = std::min(readFourBytesToUint32(), session_expire);
+ break;
+ case Mqtt5Properties::ReceiveMaximum:
+ max_qos_packets = std::min(readTwoBytesToUInt16(), max_qos_packets);
+ break;
+ case Mqtt5Properties::MaximumPacketSize:
+ max_packet_size = std::min(readFourBytesToUint32(), max_packet_size);
+ break;
+ case Mqtt5Properties::TopicAliasMaximum:
+ max_topic_aliases = readTwoBytesToUInt16();
+ break;
+ case Mqtt5Properties::RequestResponseInformation:
+ request_response_information = !!readByte();
+ break;
+ case Mqtt5Properties::RequestProblemInformation:
+ request_problem_information = !!readByte();
+ break;
+ case Mqtt5Properties::UserProperty:
+ break;
+ case Mqtt5Properties::AuthenticationMethod:
+ break;
+ case Mqtt5Properties::AuthenticationData:
+ break;
+ default:
+ throw ProtocolError("Invalid connect property.");
+ }
+ }
+ }
+
uint16_t client_id_length = readTwoBytesToUInt16();
std::string client_id(readBytes(client_id_length), client_id_length);
@@ -442,7 +493,7 @@ void MqttPacket::handleConnect()
client_id = getSecureRandomString(23);
}
- sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session);
+ sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session, max_packet_size, max_topic_aliases);
sender->setWill(will_topic, will_payload, will_retain, will_qos);
bool accessGranted = false;
@@ -475,7 +526,7 @@ void MqttPacket::handleConnect()
sender->writeMqttPacket(response);
logger->logf(LOG_NOTICE, "Client '%s' logged in successfully", sender->repr().c_str());
- subscriptionStore->registerClientAndKickExistingOne(sender);
+ subscriptionStore->registerClientAndKickExistingOne(sender, max_qos_packets, session_expire);
}
else
{
@@ -921,11 +972,45 @@ uint16_t MqttPacket::readTwoBytesToUInt16()
return i;
}
+uint32_t MqttPacket::readFourBytesToUint32()
+{
+ if (pos + 4 > bites.size())
+ throw ProtocolError("Invalid packet: header specifies invalid length.");
+
+ const uint8_t a = bites[pos++];
+ const uint8_t b = bites[pos++];
+ const uint8_t c = bites[pos++];
+ const uint8_t d = bites[pos++];
+ uint32_t i = (a << 24) | (b << 16) | (c << 8) | d;
+ return i;
+}
+
size_t MqttPacket::remainingAfterPos()
{
return bites.size() - pos;
}
+size_t MqttPacket::decodeVariableByteIntAtPos()
+{
+ uint64_t multiplier = 1;
+ size_t value = 0;
+ uint8_t encodedByte = 0;
+ do
+ {
+ if (pos >= bites.size())
+ throw ProtocolError("Variable byte int length goes out of packet. Corrupt.");
+
+ encodedByte = bites[pos++];
+ value += (encodedByte & 127) * multiplier;
+ multiplier *= 128;
+ if (multiplier > 128*128*128*128)
+ throw ProtocolError("Malformed Remaining Length.");
+ }
+ while ((encodedByte & 128) != 0);
+
+ return value;
+}
+
bool MqttPacket::getRetain() const
{
return (first_byte & 0b00000001);
diff --git a/mqttpacket.h b/mqttpacket.h
index 396eee3..283cc50 100644
--- a/mqttpacket.h
+++ b/mqttpacket.h
@@ -63,7 +63,9 @@ class MqttPacket
void writeUint16(uint16_t x);
void writeBytes(const char *b, size_t len);
uint16_t readTwoBytesToUInt16();
+ uint32_t readFourBytesToUint32();
size_t remainingAfterPos();
+ size_t decodeVariableByteIntAtPos();
void calculateRemainingLength();
void pubCommonConstruct(const uint16_t packet_id, PacketType packetType, uint8_t firstByteDefaultBits = 0);
diff --git a/session.cpp b/session.cpp
index c5540e9..705a2e6 100644
--- a/session.cpp
+++ b/session.cpp
@@ -20,12 +20,17 @@ License along with FlashMQ. If not, see .
#include "session.h"
#include "client.h"
#include "threadglobals.h"
+#include "threadglobals.h"
std::chrono::time_point appStartTime = std::chrono::steady_clock::now();
Session::Session()
{
+ const Settings &settings = *ThreadGlobals::getSettings();
+ // Sessions also get defaults from the handleConnect() method, but when you create sessions elsewhere, we do need some sensible defaults.
+ this->maxQosMsgPending = settings.maxQosMsgPendingPerClient;
+ this->sessionExpiryInterval = settings.expireSessionsAfterSeconds;
}
int64_t Session::getProgramStartedAtUnixTimestamp()
@@ -174,7 +179,7 @@ void Session::writePacket(PublishCopyFactory ©Factory, const char max_qos, u
std::unique_lock locker(qosQueueMutex);
const size_t totalQosPacketsInTransit = qosPacketQueue.size() + incomingQoS2MessageIds.size() + outgoingQoS2MessageIds.size();
- if (totalQosPacketsInTransit >= settings->maxQosMsgPendingPerClient
+ if (totalQosPacketsInTransit >= maxQosMsgPending
|| (qosPacketQueue.getByteSize() >= settings->maxQosBytesPendingPerClient && qosPacketQueue.size() > 0))
{
if (QoSLogPrintedAtId != nextPacketId)
@@ -286,9 +291,9 @@ void Session::touch()
lastTouched = std::chrono::steady_clock::now();
}
-bool Session::hasExpired(int expireAfterSeconds)
+bool Session::hasExpired() const
{
- std::chrono::seconds expireAfter(expireAfterSeconds);
+ std::chrono::seconds expireAfter(sessionExpiryInterval);
std::chrono::time_point now = std::chrono::steady_clock::now();
return client.expired() && (lastTouched + expireAfter) < now;
}
@@ -347,3 +352,9 @@ bool Session::getCleanSession() const
return c->getCleanSession();
}
+
+void Session::setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval)
+{
+ this->maxQosMsgPending = maxQosPackets;
+ this->sessionExpiryInterval = sessionExpiryInterval;
+}
diff --git a/session.h b/session.h
index 4f92abb..a525914 100644
--- a/session.h
+++ b/session.h
@@ -46,6 +46,8 @@ class Session
std::mutex qosQueueMutex;
uint16_t nextPacketId = 0;
uint16_t qosInFlightCounter = 0;
+ uint32_t sessionExpiryInterval = 0;
+ uint16_t maxQosMsgPending;
uint16_t QoSLogPrintedAtId = 0;
std::chrono::time_point lastTouched = std::chrono::steady_clock::now();
Logger *logger = Logger::getInstance();
@@ -75,7 +77,7 @@ public:
uint64_t sendPendingQosMessages();
void touch(std::chrono::time_point val);
void touch();
- bool hasExpired(int expireAfterSeconds);
+ bool hasExpired() const;
void addIncomingQoS2MessageId(uint16_t packet_id);
bool incomingQoS2MessageIdInTransit(uint16_t packet_id);
@@ -85,6 +87,8 @@ public:
void removeOutgoingQoS2MessageId(u_int16_t packet_id);
bool getCleanSession() const;
+
+ void setSessionProperties(uint16_t maxQosPackets, uint32_t sessionExpiryInterval);
};
#endif // SESSION_H
diff --git a/settings.h b/settings.h
index 55d0216..6910629 100644
--- a/settings.h
+++ b/settings.h
@@ -52,11 +52,11 @@ public:
std::string mosquittoAclFile;
bool allowAnonymous = false;
int rlimitNoFile = 1000000;
- uint64_t expireSessionsAfterSeconds = 1209600;
+ uint32_t expireSessionsAfterSeconds = 1209600;
int authPluginTimerPeriod = 60;
std::string storageDir;
int threadCount = 0;
- uint maxQosMsgPendingPerClient = 512;
+ uint16_t maxQosMsgPendingPerClient = 512;
uint maxQosBytesPendingPerClient = 65536;
std::list> listeners; // Default one is created later, when none are defined.
diff --git a/subscriptionstore.cpp b/subscriptionstore.cpp
index 11c8d4c..a0d0cce 100644
--- a/subscriptionstore.cpp
+++ b/subscriptionstore.cpp
@@ -22,6 +22,7 @@ License along with FlashMQ. If not, see .
#include "rwlockguard.h"
#include "retainedmessagesdb.h"
#include "publishcopyfactory.h"
+#include "threadglobals.h"
ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr &ses, char qos) :
session(ses),
@@ -200,9 +201,15 @@ void SubscriptionStore::removeSubscription(std::shared_ptr &client, cons
}
-// Removes an existing client when it already exists [MQTT-3.1.4-2].
void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client)
{
+ const Settings *settings = ThreadGlobals::getSettings();
+ registerClientAndKickExistingOne(client, settings->maxQosMsgPendingPerClient, settings->expireSessionsAfterSeconds);
+}
+
+// Removes an existing client when it already exists [MQTT-3.1.4-2].
+void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client, uint16_t maxQosPackets, uint32_t sessionExpiryInterval)
+{
RWLockGuard lock_guard(&subscriptionsRwlock);
lock_guard.wrlock();
@@ -247,6 +254,7 @@ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr
session->assignActiveConnection(client);
client->assignSession(session);
+ session->setSessionProperties(maxQosPackets, sessionExpiryInterval);
uint64_t count = session->sendPendingQosMessages();
client->getThreadData()->incrementSentMessageCount(count);
}
@@ -533,8 +541,12 @@ void SubscriptionStore::removeSession(const std::string &clientid)
}
}
-// This is not MQTT compliant, but the standard doesn't keep real world constraints into account.
-void SubscriptionStore::removeExpiredSessionsClients(int expireSessionsAfterSeconds)
+/**
+ * @brief SubscriptionStore::removeExpiredSessionsClients removes expired sessions.
+ *
+ * For Mqtt3 this is non-standard, but the standard doesn't keep real world constraints into account.
+ */
+void SubscriptionStore::removeExpiredSessionsClients()
{
RWLockGuard lock_guard(&subscriptionsRwlock);
lock_guard.wrlock();
@@ -546,7 +558,7 @@ void SubscriptionStore::removeExpiredSessionsClients(int expireSessionsAfterSeco
{
std::shared_ptr &session = session_it->second;
- if (session->hasExpired(expireSessionsAfterSeconds))
+ if (session->hasExpired())
{
logger->logf(LOG_DEBUG, "Removing expired session from store %s", session->getClientId().c_str());
session_it = sessionsById.erase(session_it);
diff --git a/subscriptionstore.h b/subscriptionstore.h
index 2c2ab2e..e7e549c 100644
--- a/subscriptionstore.h
+++ b/subscriptionstore.h
@@ -119,6 +119,7 @@ public:
void addSubscription(std::shared_ptr &client, const std::string &topic, const std::vector &subtopics, char qos);
void removeSubscription(std::shared_ptr &client, const std::string &topic);
void registerClientAndKickExistingOne(std::shared_ptr &client);
+ void registerClientAndKickExistingOne(std::shared_ptr &client, uint16_t maxQosPackets, uint32_t sessionExpiryInterval);
bool sessionPresent(const std::string &clientid);
void queuePacketAtSubscribers(const std::vector &subtopics, MqttPacket &packet, bool dollar = false);
@@ -129,7 +130,7 @@ public:
void setRetainedMessage(const std::string &topic, const std::vector &subtopics, const std::string &payload, char qos);
void removeSession(const std::string &clientid);
- void removeExpiredSessionsClients(int expireSessionsAfterSeconds);
+ void removeExpiredSessionsClients();
int64_t getRetainedMessageCount() const;
uint64_t getSessionCount() const;
diff --git a/types.h b/types.h
index f232e75..05aa950 100644
--- a/types.h
+++ b/types.h
@@ -47,7 +47,40 @@ enum class ProtocolVersion
{
None = 0,
Mqtt31 = 0x03,
- Mqtt311 = 0x04
+ Mqtt311 = 0x04,
+ Mqtt5 = 0x05
+};
+
+enum class Mqtt5Properties
+{
+ None = 0,
+ PayloadFormatIndicator = 1,
+ MessageExpiryInterval = 2,
+ ContentType = 3,
+ ResponseTopic = 8,
+ CorrelationData = 9,
+ SubscriptionIdentifier = 11,
+ SessionExpiryInterval = 17,
+ AssignedClientIdentifier = 18,
+ ServerKeepAlive = 13,
+ AuthenticationMethod = 21,
+ AuthenticationData = 22,
+ RequestProblemInformation = 23,
+ WillDelayInterval = 24,
+ RequestResponseInformation = 25,
+ ResponseInformation = 26,
+ ServerReference = 28,
+ ReasonString = 31,
+ ReceiveMaximum = 33,
+ TopicAliasMaximum = 34,
+ TopicAlias = 35,
+ MaximumQoS = 36,
+ RetainAvailable = 37,
+ UserProperty = 38,
+ MaximumPacketSize = 39,
+ WildcardSubscriptionAvailable = 40,
+ SubscriptionIdentifierAvailable = 41,
+ SharedSubscriptionAvailable = 42
};
enum class ConnAckReturnCodes
diff --git a/utils.cpp b/utils.cpp
index 8ce6334..0deadaf 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -662,6 +662,8 @@ const std::string protocolVersionString(ProtocolVersion p)
return "3.1";
case ProtocolVersion::Mqtt311:
return "3.1.1";
+ case ProtocolVersion::Mqtt5:
+ return "5.0";
default:
return "unknown";
}