Commit 6702193e3c13db784537885d3435a60f3650b2de

Authored by Wiebe Cazemier
1 parent 5d72070e

Add config option 'allow_unsafe_username_chars'

configfileparser.cpp
@@ -41,6 +41,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : @@ -41,6 +41,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) :
41 validKeys.insert("auth_plugin"); 41 validKeys.insert("auth_plugin");
42 validKeys.insert("log_file"); 42 validKeys.insert("log_file");
43 validKeys.insert("allow_unsafe_clientid_chars"); 43 validKeys.insert("allow_unsafe_clientid_chars");
  44 + validKeys.insert("allow_unsafe_username_chars");
44 validKeys.insert("client_initial_buffer_size"); 45 validKeys.insert("client_initial_buffer_size");
45 validKeys.insert("max_packet_size"); 46 validKeys.insert("max_packet_size");
46 47
@@ -221,6 +222,12 @@ void ConfigFileParser::loadFile(bool test) @@ -221,6 +222,12 @@ void ConfigFileParser::loadFile(bool test)
221 tmpSettings->allowUnsafeClientidChars = tmp; 222 tmpSettings->allowUnsafeClientidChars = tmp;
222 } 223 }
223 224
  225 + if (key == "allow_unsafe_username_chars")
  226 + {
  227 + bool tmp = stringTruthiness(value);
  228 + tmpSettings->allowUnsafeUsernameChars = tmp;
  229 + }
  230 +
224 if (key == "client_initial_buffer_size") 231 if (key == "client_initial_buffer_size")
225 { 232 {
226 int newVal = std::stoi(value); 233 int newVal = std::stoi(value);
mqttpacket.cpp
@@ -232,9 +232,9 @@ void MqttPacket::handleConnect() @@ -232,9 +232,9 @@ void MqttPacket::handleConnect()
232 bool validClientId = true; 232 bool validClientId = true;
233 233
234 // Check for wildcard chars in case the client_id ever appears in topics. 234 // Check for wildcard chars in case the client_id ever appears in topics.
235 - if (!settings.allowUnsafeClientidChars && (strContains(client_id, "+") || strContains(client_id, "#"))) 235 + if (!settings.allowUnsafeClientidChars && containsDangerousCharacters(client_id))
236 { 236 {
237 - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str()); 237 + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false.", client_id.c_str());
238 validClientId = false; 238 validClientId = false;
239 } 239 }
240 else if (!clean_session && client_id.empty()) 240 else if (!clean_session && client_id.empty())
@@ -266,7 +266,21 @@ void MqttPacket::handleConnect() @@ -266,7 +266,21 @@ void MqttPacket::handleConnect()
266 sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session); 266 sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session);
267 sender->setWill(will_topic, will_payload, will_retain, will_qos); 267 sender->setWill(will_topic, will_payload, will_retain, will_qos);
268 268
269 - if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success) 269 + bool accessGranted = false;
  270 + std::string denyLogMsg;
  271 +
  272 + if (!settings.allowUnsafeUsernameChars && containsDangerousCharacters(username))
  273 + {
  274 + denyLogMsg = formatString("Username '%s' has + or # in the id and 'allow_unsafe_username_chars' is false.", username.c_str());
  275 + sender->setDisconnectReason("Invalid username character");
  276 + accessGranted = false;
  277 + }
  278 + else if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success)
  279 + {
  280 + accessGranted = true;
  281 + }
  282 +
  283 + if (accessGranted)
270 { 284 {
271 bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_session && subscriptionStore->sessionPresent(client_id); 285 bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_session && subscriptionStore->sessionPresent(client_id);
272 subscriptionStore->registerClientAndKickExistingOne(sender); 286 subscriptionStore->registerClientAndKickExistingOne(sender);
@@ -284,7 +298,10 @@ void MqttPacket::handleConnect() @@ -284,7 +298,10 @@ void MqttPacket::handleConnect()
284 sender->setDisconnectReason("Access denied"); 298 sender->setDisconnectReason("Access denied");
285 sender->setReadyForDisconnect(); 299 sender->setReadyForDisconnect();
286 sender->writeMqttPacket(response); 300 sender->writeMqttPacket(response);
287 - logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str()); 301 + if (!denyLogMsg.empty())
  302 + logger->logf(LOG_NOTICE, denyLogMsg.c_str());
  303 + else
  304 + logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str());
288 } 305 }
289 } 306 }
290 else 307 else
settings.h
@@ -18,6 +18,7 @@ public: @@ -18,6 +18,7 @@ public:
18 std::string authPluginPath; 18 std::string authPluginPath;
19 std::string logPath; 19 std::string logPath;
20 bool allowUnsafeClientidChars = false; 20 bool allowUnsafeClientidChars = false;
  21 + bool allowUnsafeUsernameChars = false;
21 int clientInitialBufferSize = 1024; // Must be power of 2 22 int clientInitialBufferSize = 1024; // Must be power of 2
22 int maxPacketSize = 268435461; // 256 MB + 5 23 int maxPacketSize = 268435461; // 256 MB + 5
23 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined. 24 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined.
utils.cpp
@@ -137,6 +137,25 @@ bool isValidPublishPath(const std::string &amp;s) @@ -137,6 +137,25 @@ bool isValidPublishPath(const std::string &amp;s)
137 return true; 137 return true;
138 } 138 }
139 139
  140 +bool containsDangerousCharacters(const std::string &s)
  141 +{
  142 + if (s.empty())
  143 + return false;
  144 +
  145 + for (const char c : s)
  146 + {
  147 + switch(c)
  148 + {
  149 + case '#':
  150 + return true;
  151 + case '+':
  152 + return true;
  153 + }
  154 + }
  155 +
  156 + return false;
  157 +}
  158 +
140 std::vector<std::string> splitToVector(const std::string &input, const char sep, size_t max, bool keep_empty_parts) 159 std::vector<std::string> splitToVector(const std::string &input, const char sep, size_t max, bool keep_empty_parts)
141 { 160 {
142 std::vector<std::string> list; 161 std::vector<std::string> list;
@@ -34,6 +34,7 @@ bool isValidUtf8(const std::string &amp;s); @@ -34,6 +34,7 @@ bool isValidUtf8(const std::string &amp;s);
34 bool strContains(const std::string &s, const std::string &needle); 34 bool strContains(const std::string &s, const std::string &needle);
35 35
36 bool isValidPublishPath(const std::string &s); 36 bool isValidPublishPath(const std::string &s);
  37 +bool containsDangerousCharacters(const std::string &s);
37 38
38 void ltrim(std::string &s); 39 void ltrim(std::string &s);
39 void rtrim(std::string &s); 40 void rtrim(std::string &s);