Commit 6702193e3c13db784537885d3435a60f3650b2de
1 parent
5d72070e
Add config option 'allow_unsafe_username_chars'
Showing
5 changed files
with
49 additions
and
4 deletions
configfileparser.cpp
| @@ -41,6 +41,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : | @@ -41,6 +41,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : | ||
| 41 | validKeys.insert("auth_plugin"); | 41 | validKeys.insert("auth_plugin"); |
| 42 | validKeys.insert("log_file"); | 42 | validKeys.insert("log_file"); |
| 43 | validKeys.insert("allow_unsafe_clientid_chars"); | 43 | validKeys.insert("allow_unsafe_clientid_chars"); |
| 44 | + validKeys.insert("allow_unsafe_username_chars"); | ||
| 44 | validKeys.insert("client_initial_buffer_size"); | 45 | validKeys.insert("client_initial_buffer_size"); |
| 45 | validKeys.insert("max_packet_size"); | 46 | validKeys.insert("max_packet_size"); |
| 46 | 47 | ||
| @@ -221,6 +222,12 @@ void ConfigFileParser::loadFile(bool test) | @@ -221,6 +222,12 @@ void ConfigFileParser::loadFile(bool test) | ||
| 221 | tmpSettings->allowUnsafeClientidChars = tmp; | 222 | tmpSettings->allowUnsafeClientidChars = tmp; |
| 222 | } | 223 | } |
| 223 | 224 | ||
| 225 | + if (key == "allow_unsafe_username_chars") | ||
| 226 | + { | ||
| 227 | + bool tmp = stringTruthiness(value); | ||
| 228 | + tmpSettings->allowUnsafeUsernameChars = tmp; | ||
| 229 | + } | ||
| 230 | + | ||
| 224 | if (key == "client_initial_buffer_size") | 231 | if (key == "client_initial_buffer_size") |
| 225 | { | 232 | { |
| 226 | int newVal = std::stoi(value); | 233 | int newVal = std::stoi(value); |
mqttpacket.cpp
| @@ -232,9 +232,9 @@ void MqttPacket::handleConnect() | @@ -232,9 +232,9 @@ void MqttPacket::handleConnect() | ||
| 232 | bool validClientId = true; | 232 | bool validClientId = true; |
| 233 | 233 | ||
| 234 | // Check for wildcard chars in case the client_id ever appears in topics. | 234 | // Check for wildcard chars in case the client_id ever appears in topics. |
| 235 | - if (!settings.allowUnsafeClientidChars && (strContains(client_id, "+") || strContains(client_id, "#"))) | 235 | + if (!settings.allowUnsafeClientidChars && containsDangerousCharacters(client_id)) |
| 236 | { | 236 | { |
| 237 | - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str()); | 237 | + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false.", client_id.c_str()); |
| 238 | validClientId = false; | 238 | validClientId = false; |
| 239 | } | 239 | } |
| 240 | else if (!clean_session && client_id.empty()) | 240 | else if (!clean_session && client_id.empty()) |
| @@ -266,7 +266,21 @@ void MqttPacket::handleConnect() | @@ -266,7 +266,21 @@ void MqttPacket::handleConnect() | ||
| 266 | sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session); | 266 | sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session); |
| 267 | sender->setWill(will_topic, will_payload, will_retain, will_qos); | 267 | sender->setWill(will_topic, will_payload, will_retain, will_qos); |
| 268 | 268 | ||
| 269 | - if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) | 269 | + bool accessGranted = false; |
| 270 | + std::string denyLogMsg; | ||
| 271 | + | ||
| 272 | + if (!settings.allowUnsafeUsernameChars && containsDangerousCharacters(username)) | ||
| 273 | + { | ||
| 274 | + denyLogMsg = formatString("Username '%s' has + or # in the id and 'allow_unsafe_username_chars' is false.", username.c_str()); | ||
| 275 | + sender->setDisconnectReason("Invalid username character"); | ||
| 276 | + accessGranted = false; | ||
| 277 | + } | ||
| 278 | + else if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) | ||
| 279 | + { | ||
| 280 | + accessGranted = true; | ||
| 281 | + } | ||
| 282 | + | ||
| 283 | + if (accessGranted) | ||
| 270 | { | 284 | { |
| 271 | bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_session && subscriptionStore->sessionPresent(client_id); | 285 | bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_session && subscriptionStore->sessionPresent(client_id); |
| 272 | subscriptionStore->registerClientAndKickExistingOne(sender); | 286 | subscriptionStore->registerClientAndKickExistingOne(sender); |
| @@ -284,7 +298,10 @@ void MqttPacket::handleConnect() | @@ -284,7 +298,10 @@ void MqttPacket::handleConnect() | ||
| 284 | sender->setDisconnectReason("Access denied"); | 298 | sender->setDisconnectReason("Access denied"); |
| 285 | sender->setReadyForDisconnect(); | 299 | sender->setReadyForDisconnect(); |
| 286 | sender->writeMqttPacket(response); | 300 | sender->writeMqttPacket(response); |
| 287 | - logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str()); | 301 | + if (!denyLogMsg.empty()) |
| 302 | + logger->logf(LOG_NOTICE, denyLogMsg.c_str()); | ||
| 303 | + else | ||
| 304 | + logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str()); | ||
| 288 | } | 305 | } |
| 289 | } | 306 | } |
| 290 | else | 307 | else |
settings.h
| @@ -18,6 +18,7 @@ public: | @@ -18,6 +18,7 @@ public: | ||
| 18 | std::string authPluginPath; | 18 | std::string authPluginPath; |
| 19 | std::string logPath; | 19 | std::string logPath; |
| 20 | bool allowUnsafeClientidChars = false; | 20 | bool allowUnsafeClientidChars = false; |
| 21 | + bool allowUnsafeUsernameChars = false; | ||
| 21 | int clientInitialBufferSize = 1024; // Must be power of 2 | 22 | int clientInitialBufferSize = 1024; // Must be power of 2 |
| 22 | int maxPacketSize = 268435461; // 256 MB + 5 | 23 | int maxPacketSize = 268435461; // 256 MB + 5 |
| 23 | std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined. | 24 | std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined. |
utils.cpp
| @@ -137,6 +137,25 @@ bool isValidPublishPath(const std::string &s) | @@ -137,6 +137,25 @@ bool isValidPublishPath(const std::string &s) | ||
| 137 | return true; | 137 | return true; |
| 138 | } | 138 | } |
| 139 | 139 | ||
| 140 | +bool containsDangerousCharacters(const std::string &s) | ||
| 141 | +{ | ||
| 142 | + if (s.empty()) | ||
| 143 | + return false; | ||
| 144 | + | ||
| 145 | + for (const char c : s) | ||
| 146 | + { | ||
| 147 | + switch(c) | ||
| 148 | + { | ||
| 149 | + case '#': | ||
| 150 | + return true; | ||
| 151 | + case '+': | ||
| 152 | + return true; | ||
| 153 | + } | ||
| 154 | + } | ||
| 155 | + | ||
| 156 | + return false; | ||
| 157 | +} | ||
| 158 | + | ||
| 140 | std::vector<std::string> splitToVector(const std::string &input, const char sep, size_t max, bool keep_empty_parts) | 159 | std::vector<std::string> splitToVector(const std::string &input, const char sep, size_t max, bool keep_empty_parts) |
| 141 | { | 160 | { |
| 142 | std::vector<std::string> list; | 161 | std::vector<std::string> list; |
utils.h
| @@ -34,6 +34,7 @@ bool isValidUtf8(const std::string &s); | @@ -34,6 +34,7 @@ bool isValidUtf8(const std::string &s); | ||
| 34 | bool strContains(const std::string &s, const std::string &needle); | 34 | bool strContains(const std::string &s, const std::string &needle); |
| 35 | 35 | ||
| 36 | bool isValidPublishPath(const std::string &s); | 36 | bool isValidPublishPath(const std::string &s); |
| 37 | +bool containsDangerousCharacters(const std::string &s); | ||
| 37 | 38 | ||
| 38 | void ltrim(std::string &s); | 39 | void ltrim(std::string &s); |
| 39 | void rtrim(std::string &s); | 40 | void rtrim(std::string &s); |