From 6702193e3c13db784537885d3435a60f3650b2de Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Sat, 6 Mar 2021 15:45:26 +0100 Subject: [PATCH] Add config option 'allow_unsafe_username_chars' --- configfileparser.cpp | 7 +++++++ mqttpacket.cpp | 25 +++++++++++++++++++++---- settings.h | 1 + utils.cpp | 19 +++++++++++++++++++ utils.h | 1 + 5 files changed, 49 insertions(+), 4 deletions(-) diff --git a/configfileparser.cpp b/configfileparser.cpp index 75e7825..0d8cd07 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -41,6 +41,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : validKeys.insert("auth_plugin"); validKeys.insert("log_file"); validKeys.insert("allow_unsafe_clientid_chars"); + validKeys.insert("allow_unsafe_username_chars"); validKeys.insert("client_initial_buffer_size"); validKeys.insert("max_packet_size"); @@ -221,6 +222,12 @@ void ConfigFileParser::loadFile(bool test) tmpSettings->allowUnsafeClientidChars = tmp; } + if (key == "allow_unsafe_username_chars") + { + bool tmp = stringTruthiness(value); + tmpSettings->allowUnsafeUsernameChars = tmp; + } + if (key == "client_initial_buffer_size") { int newVal = std::stoi(value); diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 8a14307..ff7d681 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -232,9 +232,9 @@ void MqttPacket::handleConnect() bool validClientId = true; // Check for wildcard chars in case the client_id ever appears in topics. - if (!settings.allowUnsafeClientidChars && (strContains(client_id, "+") || strContains(client_id, "#"))) + if (!settings.allowUnsafeClientidChars && containsDangerousCharacters(client_id)) { - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str()); + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false.", client_id.c_str()); validClientId = false; } else if (!clean_session && client_id.empty()) @@ -266,7 +266,21 @@ void MqttPacket::handleConnect() sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session); sender->setWill(will_topic, will_payload, will_retain, will_qos); - if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) + bool accessGranted = false; + std::string denyLogMsg; + + if (!settings.allowUnsafeUsernameChars && containsDangerousCharacters(username)) + { + denyLogMsg = formatString("Username '%s' has + or # in the id and 'allow_unsafe_username_chars' is false.", username.c_str()); + sender->setDisconnectReason("Invalid username character"); + accessGranted = false; + } + else if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) + { + accessGranted = true; + } + + if (accessGranted) { bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_session && subscriptionStore->sessionPresent(client_id); subscriptionStore->registerClientAndKickExistingOne(sender); @@ -284,7 +298,10 @@ void MqttPacket::handleConnect() sender->setDisconnectReason("Access denied"); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); - logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str()); + if (!denyLogMsg.empty()) + logger->logf(LOG_NOTICE, denyLogMsg.c_str()); + else + logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str()); } } else diff --git a/settings.h b/settings.h index d440612..687fc70 100644 --- a/settings.h +++ b/settings.h @@ -18,6 +18,7 @@ public: std::string authPluginPath; std::string logPath; bool allowUnsafeClientidChars = false; + bool allowUnsafeUsernameChars = false; int clientInitialBufferSize = 1024; // Must be power of 2 int maxPacketSize = 268435461; // 256 MB + 5 std::list> listeners; // Default one is created later, when none are defined. diff --git a/utils.cpp b/utils.cpp index 9f4731a..ad263a4 100644 --- a/utils.cpp +++ b/utils.cpp @@ -137,6 +137,25 @@ bool isValidPublishPath(const std::string &s) return true; } +bool containsDangerousCharacters(const std::string &s) +{ + if (s.empty()) + return false; + + for (const char c : s) + { + switch(c) + { + case '#': + return true; + case '+': + return true; + } + } + + return false; +} + std::vector splitToVector(const std::string &input, const char sep, size_t max, bool keep_empty_parts) { std::vector list; diff --git a/utils.h b/utils.h index d7776dc..6478047 100644 --- a/utils.h +++ b/utils.h @@ -34,6 +34,7 @@ bool isValidUtf8(const std::string &s); bool strContains(const std::string &s, const std::string &needle); bool isValidPublishPath(const std::string &s); +bool containsDangerousCharacters(const std::string &s); void ltrim(std::string &s); void rtrim(std::string &s); -- libgit2 0.21.4