Commit 6702193e3c13db784537885d3435a60f3650b2de

Authored by Wiebe Cazemier
1 parent 5d72070e

Add config option 'allow_unsafe_username_chars'

configfileparser.cpp
... ... @@ -41,6 +41,7 @@ ConfigFileParser::ConfigFileParser(const std::string &path) :
41 41 validKeys.insert("auth_plugin");
42 42 validKeys.insert("log_file");
43 43 validKeys.insert("allow_unsafe_clientid_chars");
  44 + validKeys.insert("allow_unsafe_username_chars");
44 45 validKeys.insert("client_initial_buffer_size");
45 46 validKeys.insert("max_packet_size");
46 47  
... ... @@ -221,6 +222,12 @@ void ConfigFileParser::loadFile(bool test)
221 222 tmpSettings->allowUnsafeClientidChars = tmp;
222 223 }
223 224  
  225 + if (key == "allow_unsafe_username_chars")
  226 + {
  227 + bool tmp = stringTruthiness(value);
  228 + tmpSettings->allowUnsafeUsernameChars = tmp;
  229 + }
  230 +
224 231 if (key == "client_initial_buffer_size")
225 232 {
226 233 int newVal = std::stoi(value);
... ...
mqttpacket.cpp
... ... @@ -232,9 +232,9 @@ void MqttPacket::handleConnect()
232 232 bool validClientId = true;
233 233  
234 234 // Check for wildcard chars in case the client_id ever appears in topics.
235   - if (!settings.allowUnsafeClientidChars && (strContains(client_id, "+") || strContains(client_id, "#")))
  235 + if (!settings.allowUnsafeClientidChars && containsDangerousCharacters(client_id))
236 236 {
237   - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false:", client_id.c_str());
  237 + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false.", client_id.c_str());
238 238 validClientId = false;
239 239 }
240 240 else if (!clean_session && client_id.empty())
... ... @@ -266,7 +266,21 @@ void MqttPacket::handleConnect()
266 266 sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, clean_session);
267 267 sender->setWill(will_topic, will_payload, will_retain, will_qos);
268 268  
269   - if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success)
  269 + bool accessGranted = false;
  270 + std::string denyLogMsg;
  271 +
  272 + if (!settings.allowUnsafeUsernameChars && containsDangerousCharacters(username))
  273 + {
  274 + denyLogMsg = formatString("Username '%s' has + or # in the id and 'allow_unsafe_username_chars' is false.", username.c_str());
  275 + sender->setDisconnectReason("Invalid username character");
  276 + accessGranted = false;
  277 + }
  278 + else if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success)
  279 + {
  280 + accessGranted = true;
  281 + }
  282 +
  283 + if (accessGranted)
270 284 {
271 285 bool sessionPresent = protocolVersion >= ProtocolVersion::Mqtt311 && !clean_session && subscriptionStore->sessionPresent(client_id);
272 286 subscriptionStore->registerClientAndKickExistingOne(sender);
... ... @@ -284,7 +298,10 @@ void MqttPacket::handleConnect()
284 298 sender->setDisconnectReason("Access denied");
285 299 sender->setReadyForDisconnect();
286 300 sender->writeMqttPacket(response);
287   - logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str());
  301 + if (!denyLogMsg.empty())
  302 + logger->logf(LOG_NOTICE, denyLogMsg.c_str());
  303 + else
  304 + logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str());
288 305 }
289 306 }
290 307 else
... ...
settings.h
... ... @@ -18,6 +18,7 @@ public:
18 18 std::string authPluginPath;
19 19 std::string logPath;
20 20 bool allowUnsafeClientidChars = false;
  21 + bool allowUnsafeUsernameChars = false;
21 22 int clientInitialBufferSize = 1024; // Must be power of 2
22 23 int maxPacketSize = 268435461; // 256 MB + 5
23 24 std::list<std::shared_ptr<Listener>> listeners; // Default one is created later, when none are defined.
... ...
utils.cpp
... ... @@ -137,6 +137,25 @@ bool isValidPublishPath(const std::string &amp;s)
137 137 return true;
138 138 }
139 139  
  140 +bool containsDangerousCharacters(const std::string &s)
  141 +{
  142 + if (s.empty())
  143 + return false;
  144 +
  145 + for (const char c : s)
  146 + {
  147 + switch(c)
  148 + {
  149 + case '#':
  150 + return true;
  151 + case '+':
  152 + return true;
  153 + }
  154 + }
  155 +
  156 + return false;
  157 +}
  158 +
140 159 std::vector<std::string> splitToVector(const std::string &input, const char sep, size_t max, bool keep_empty_parts)
141 160 {
142 161 std::vector<std::string> list;
... ...
... ... @@ -34,6 +34,7 @@ bool isValidUtf8(const std::string &amp;s);
34 34 bool strContains(const std::string &s, const std::string &needle);
35 35  
36 36 bool isValidPublishPath(const std::string &s);
  37 +bool containsDangerousCharacters(const std::string &s);
37 38  
38 39 void ltrim(std::string &s);
39 40 void rtrim(std::string &s);
... ...