diff --git a/CMakeLists.txt b/CMakeLists.txt index bbb4b01..3fda0eb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,7 @@ add_executable(FlashMQ mqtt5properties.h globalstats.h derivablecounter.h + packetdatatypes.h mainapp.cpp main.cpp @@ -105,6 +106,7 @@ add_executable(FlashMQ mqtt5properties.cpp globalstats.cpp derivablecounter.cpp + packetdatatypes.cpp ) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index 3ef852d..f8d746b 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -54,6 +54,7 @@ SOURCES += tst_maintests.cpp \ ../mqtt5properties.cpp \ ../globalstats.cpp \ ../derivablecounter.cpp \ + ../packetdatatypes.cpp \ mainappthread.cpp \ twoclienttestcontext.cpp @@ -100,6 +101,7 @@ HEADERS += \ ../mqtt5properties.h \ ../globalstats.h \ ../derivablecounter.h \ + ../packetdatatypes.h \ mainappthread.h \ twoclienttestcontext.h diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 0423adf..301697f 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -321,52 +321,40 @@ void MqttPacket::handle() handleExtendedAuth(); } -void MqttPacket::handleConnect() +ConnectData MqttPacket::parseConnectData() { - if (sender->hasConnectPacketSeen()) - throw ProtocolError("Client already sent a CONNECT.", ReasonCodes::ProtocolError); + if (this->packetType != PacketType::CONNECT) + throw std::runtime_error("Packet must be connect packet."); - std::shared_ptr subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); + setPosToDataStart(); - ThreadGlobals::getThreadData()->mqttConnectCounter.inc(); + ConnectData result; uint16_t variable_header_length = readTwoBytesToUInt16(); - const Settings &settings = *ThreadGlobals::getSettings(); - if (!(variable_header_length == 4 || variable_header_length == 6)) { throw ProtocolError("Invalid variable header length. Garbage?", ReasonCodes::MalformedPacket); } + const Settings &settings = *ThreadGlobals::getSettings(); + char *c = readBytes(variable_header_length); std::string magic_marker(c, variable_header_length); - char protocol_level = readByte(); + result.protocol_level_byte = readByte(); if (magic_marker == "MQTT") { - if (protocol_level == 0x04) + if (result.protocol_level_byte == 0x04) protocolVersion = ProtocolVersion::Mqtt311; - if (protocol_level == 0x05) + if (result.protocol_level_byte == 0x05) protocolVersion = ProtocolVersion::Mqtt5; } - else if (magic_marker == "MQIsdp" && protocol_level == 0x03) + else if (magic_marker == "MQIsdp" && result.protocol_level_byte == 0x03) { protocolVersion = ProtocolVersion::Mqtt31; } - else - { - // The specs are unclear when to use the version 3 codes or version 5 codes. - ProtocolVersion fuzzyProtocolVersion = protocol_level < 0x05 ? ProtocolVersion::Mqtt31 : ProtocolVersion::Mqtt5; - - ConnAck connAck(fuzzyProtocolVersion, ReasonCodes::UnsupportedProtocolVersion); - MqttPacket response(connAck); - sender->setReadyForDisconnect(); - sender->writeMqttPacket(response); - logger->logf(LOG_ERR, "Rejecting because of invalid protocol version: %s", sender->repr().c_str()); - return; - } char flagByte = readByte(); bool reserved = !!(flagByte & 0b00000001); @@ -374,32 +362,21 @@ void MqttPacket::handleConnect() if (reserved) throw ProtocolError("Protocol demands reserved flag in CONNECT is 0", ReasonCodes::MalformedPacket); + result.user_name_flag = !!(flagByte & 0b10000000); + result.password_flag = !!(flagByte & 0b01000000); + result.will_retain = !!(flagByte & 0b00100000); + result.will_qos = (flagByte & 0b00011000) >> 3; + result.will_flag = !!(flagByte & 0b00000100); + result.clean_start = !!(flagByte & 0b00000010); - bool user_name_flag = !!(flagByte & 0b10000000); - bool password_flag = !!(flagByte & 0b01000000); - bool will_retain = !!(flagByte & 0b00100000); - char will_qos = (flagByte & 0b00011000) >> 3; - bool will_flag = !!(flagByte & 0b00000100); - bool clean_start = !!(flagByte & 0b00000010); - - if (will_qos > 2) + if (result.will_qos > 2) throw ProtocolError("Invalid QoS for will.", ReasonCodes::MalformedPacket); - uint16_t keep_alive = readTwoBytesToUInt16(); - - uint16_t client_receive_max = settings.maxQosMsgPendingPerClient; - uint32_t session_expire = settings.getExpireSessionAfterSeconds(); - uint32_t max_outgoing_packet_size = settings.maxPacketSize; - uint16_t max_outgoing_topic_aliases = 0; // Default MUST BE 0, meaning server won't initiate aliases - bool request_response_information = false; - bool request_problem_information = false; - - std::string authenticationMethod; - std::string authenticationData; + result.keep_alive = readTwoBytesToUInt16(); if (protocolVersion == ProtocolVersion::Mqtt5) { - keep_alive = std::max(keep_alive, 5); + result.keep_alive = std::max(result.keep_alive, 5); const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; @@ -411,36 +388,34 @@ void MqttPacket::handleConnect() switch (prop) { case Mqtt5Properties::SessionExpiryInterval: - session_expire = std::min(readFourBytesToUint32(), session_expire); + result.session_expire = std::min(readFourBytesToUint32(), result.session_expire); break; case Mqtt5Properties::ReceiveMaximum: - client_receive_max = std::min(readTwoBytesToUInt16(), client_receive_max); + result.client_receive_max = std::min(readTwoBytesToUInt16(), result.client_receive_max); break; case Mqtt5Properties::MaximumPacketSize: - max_outgoing_packet_size = std::min(readFourBytesToUint32(), max_outgoing_packet_size); + result.max_outgoing_packet_size = std::min(readFourBytesToUint32(), result.max_outgoing_packet_size); break; case Mqtt5Properties::TopicAliasMaximum: - max_outgoing_topic_aliases = std::min(readTwoBytesToUInt16(), settings.maxOutgoingTopicAliasValue); + result.max_outgoing_topic_aliases = std::min(readTwoBytesToUInt16(), settings.maxOutgoingTopicAliasValue); break; case Mqtt5Properties::RequestResponseInformation: - request_response_information = !!readByte(); - UNUSED(request_response_information); + result.request_response_information = !!readByte(); break; case Mqtt5Properties::RequestProblemInformation: - request_problem_information = !!readByte(); - UNUSED(request_problem_information); + result.request_problem_information = !!readByte(); break; case Mqtt5Properties::UserProperty: readUserProperty(); break; case Mqtt5Properties::AuthenticationMethod: { - authenticationMethod = readBytesToString(); + result.authenticationMethod = readBytesToString(); break; } case Mqtt5Properties::AuthenticationData: { - authenticationData = readBytesToString(false); + result.authenticationData = readBytesToString(false); break; } default: @@ -449,25 +424,21 @@ void MqttPacket::handleConnect() } } - if (client_receive_max == 0 || max_outgoing_packet_size == 0) + if (result.client_receive_max == 0 || result.max_outgoing_packet_size == 0) { throw ProtocolError("Receive max or max outgoing packet size can't be 0.", ReasonCodes::ProtocolError); } - std::string client_id = readBytesToString(); - - std::string username; - std::string password; + result.client_id = readBytesToString(); - WillPublish willpublish; - willpublish.qos = will_qos; - willpublish.retain = will_retain; + result.willpublish.qos = result.will_qos; + result.willpublish.retain = result.will_retain; - if (will_flag) + if (result.will_flag) { if (protocolVersion == ProtocolVersion::Mqtt5) { - willpublish.constructPropertyBuilder(); + result.willpublish.constructPropertyBuilder(); const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; @@ -479,38 +450,43 @@ void MqttPacket::handleConnect() switch (prop) { case Mqtt5Properties::WillDelayInterval: - willpublish.will_delay = readFourBytesToUint32(); + result.willpublish.will_delay = readFourBytesToUint32(); break; case Mqtt5Properties::PayloadFormatIndicator: - willpublish.propertyBuilder->writePayloadFormatIndicator(readByte()); + result.willpublish.propertyBuilder->writePayloadFormatIndicator(readByte()); break; case Mqtt5Properties::ContentType: { const std::string contentType = readBytesToString(); - willpublish.propertyBuilder->writeContentType(contentType); + result.willpublish.propertyBuilder->writeContentType(contentType); break; } case Mqtt5Properties::ResponseTopic: { const std::string responseTopic = readBytesToString(true, true); - willpublish.propertyBuilder->writeResponseTopic(responseTopic); + result.willpublish.propertyBuilder->writeResponseTopic(responseTopic); break; } case Mqtt5Properties::MessageExpiryInterval: { const uint32_t expiresAfter = readFourBytesToUint32(); - willpublish.setExpireAfter(expiresAfter); + result.willpublish.setExpireAfter(expiresAfter); break; } case Mqtt5Properties::CorrelationData: { const std::string correlationData = readBytesToString(false); - willpublish.propertyBuilder->writeCorrelationData(correlationData); + result.willpublish.propertyBuilder->writeCorrelationData(correlationData); break; } case Mqtt5Properties::UserProperty: { - readUserProperty(); + result.willpublish.constructPropertyBuilder(); + + std::string key = readBytesToString(); + std::string value = readBytesToString(); + + result.willpublish.propertyBuilder->writeUserProperty(std::move(key), std::move(value)); break; } default: @@ -519,61 +495,92 @@ void MqttPacket::handleConnect() } } - willpublish.topic = readBytesToString(true, true); + result.willpublish.topic = readBytesToString(true, true); uint16_t will_payload_length = readTwoBytesToUInt16(); - willpublish.payload = std::string(readBytes(will_payload_length), will_payload_length); + result.willpublish.payload = std::string(readBytes(will_payload_length), will_payload_length); } - if (user_name_flag) + + if (result.user_name_flag) { - username = readBytesToString(false); + result.username = readBytesToString(false); - if (username.empty()) + if (result.username.empty()) throw ProtocolError("Username flagged as present, but it's 0 bytes.", ReasonCodes::MalformedPacket); - if (!settings.allowUnsafeUsernameChars && containsDangerousCharacters(username)) - throw ProtocolError(formatString("Username '%s' contains unsafe characters and 'allow_unsafe_username_chars' is false.", username.c_str()), + if (!settings.allowUnsafeUsernameChars && containsDangerousCharacters(result.username)) + throw ProtocolError(formatString("Username '%s' contains unsafe characters and 'allow_unsafe_username_chars' is false.", result.username.c_str()), ReasonCodes::BadUserNameOrPassword); } - if (password_flag) + + if (result.password_flag) { - if (this->protocolVersion <= ProtocolVersion::Mqtt311 && !user_name_flag) + if (this->protocolVersion <= ProtocolVersion::Mqtt311 && !result.user_name_flag) { throw ProtocolError("MQTT 3.1.1: If the User Name Flag is set to 0, the Password Flag MUST be set to 0."); } - password = readBytesToString(false); + result.password = readBytesToString(false); - if (password.empty()) + if (result.password.empty()) throw ProtocolError("Password flagged as present, but it's 0 bytes.", ReasonCodes::MalformedPacket); } + return result; +} + +void MqttPacket::handleConnect() +{ + if (sender->hasConnectPacketSeen()) + throw ProtocolError("Client already sent a CONNECT.", ReasonCodes::ProtocolError); + + std::shared_ptr subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); + + ThreadGlobals::getThreadData()->mqttConnectCounter.inc(); + + ConnectData connectData = parseConnectData(); + + if (this->protocolVersion == ProtocolVersion::None) + { + // The specs are unclear when to use the version 3 codes or version 5 codes. + ProtocolVersion fuzzyProtocolVersion = connectData.protocol_level_byte < 0x05 ? ProtocolVersion::Mqtt31 : ProtocolVersion::Mqtt5; + + ConnAck connAck(fuzzyProtocolVersion, ReasonCodes::UnsupportedProtocolVersion); + MqttPacket response(connAck); + sender->setReadyForDisconnect(); + sender->writeMqttPacket(response); + logger->logf(LOG_ERR, "Rejecting because of invalid protocol version: %s", sender->repr().c_str()); + return; + } + + const Settings &settings = *ThreadGlobals::getSettings(); + // I deferred the initial UTF8 check on username to be able to give an appropriate connack here, but to me, the specs // are actually vague whether 'BadUserNameOrPassword' should be given on invalid UTF8. - if (!isValidUtf8(username)) + if (!isValidUtf8(connectData.username)) { ConnAck connAck(protocolVersion, ReasonCodes::BadUserNameOrPassword); MqttPacket response(connAck); sender->setReadyForDisconnect(); sender->writeMqttPacket(response); - logger->logf(LOG_ERR, "Username has invalid UTF8: %s", username.c_str()); + logger->logf(LOG_ERR, "Username has invalid UTF8: %s", connectData.username.c_str()); return; } bool validClientId = true; // Check for wildcard chars in case the client_id ever appears in topics. - if (!settings.allowUnsafeClientidChars && containsDangerousCharacters(client_id)) + if (!settings.allowUnsafeClientidChars && containsDangerousCharacters(connectData.client_id)) { - logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false.", client_id.c_str()); + logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false.", connectData.client_id.c_str()); validClientId = false; } - else if (!clean_start && client_id.empty()) + else if (!connectData.clean_start && connectData.client_id.empty()) { logger->logf(LOG_ERR, "ClientID empty and clean start 0, which is incompatible"); validClientId = false; } - else if (protocolVersion < ProtocolVersion::Mqtt311 && client_id.empty()) + else if (protocolVersion < ProtocolVersion::Mqtt311 && connectData.client_id.empty()) { logger->logf(LOG_ERR, "Empty clientID. Connect with protocol 3.1.1 or higher to have one generated securely."); validClientId = false; @@ -590,25 +597,26 @@ void MqttPacket::handleConnect() } bool clientIdGenerated = false; - if (client_id.empty()) + if (connectData.client_id.empty()) { - client_id = getSecureRandomString(23); + connectData.client_id = getSecureRandomString(23); clientIdGenerated = true; } - sender->setClientProperties(protocolVersion, client_id, username, true, keep_alive, max_outgoing_packet_size, max_outgoing_topic_aliases); + sender->setClientProperties(protocolVersion, connectData.client_id, connectData.username, true, connectData.keep_alive, + connectData.max_outgoing_packet_size, connectData.max_outgoing_topic_aliases); - if (will_flag) - sender->setWill(std::move(willpublish)); + if (connectData.will_flag) + sender->setWill(std::move(connectData.willpublish)); // Stage connack, for immediate or delayed use when auth succeeds. { bool sessionPresent = false; std::shared_ptr existingSession; - if (protocolVersion >= ProtocolVersion::Mqtt311 && !clean_start) + if (protocolVersion >= ProtocolVersion::Mqtt311 && !connectData.clean_start) { - existingSession = subscriptionStore->lockSession(client_id); + existingSession = subscriptionStore->lockSession(connectData.client_id); if (existingSession) sessionPresent = true; } @@ -618,47 +626,47 @@ void MqttPacket::handleConnect() if (protocolVersion >= ProtocolVersion::Mqtt5) { connAck->propertyBuilder = std::make_shared(); - connAck->propertyBuilder->writeSessionExpiry(session_expire); + connAck->propertyBuilder->writeSessionExpiry(connectData.session_expire); connAck->propertyBuilder->writeReceiveMax(settings.maxQosMsgPendingPerClient); connAck->propertyBuilder->writeRetainAvailable(1); connAck->propertyBuilder->writeMaxPacketSize(sender->getMaxIncomingPacketSize()); if (clientIdGenerated) - connAck->propertyBuilder->writeAssignedClientId(client_id); + connAck->propertyBuilder->writeAssignedClientId(connectData.client_id); connAck->propertyBuilder->writeMaxTopicAliases(sender->getMaxIncomingTopicAliasValue()); connAck->propertyBuilder->writeWildcardSubscriptionAvailable(1); connAck->propertyBuilder->writeSubscriptionIdentifiersAvailable(0); connAck->propertyBuilder->writeSharedSubscriptionAvailable(0); - connAck->propertyBuilder->writeServerKeepAlive(keep_alive); + connAck->propertyBuilder->writeServerKeepAlive(connectData.keep_alive); - if (!authenticationMethod.empty()) + if (!connectData.authenticationMethod.empty()) { - connAck->propertyBuilder->writeAuthenticationMethod(authenticationMethod); + connAck->propertyBuilder->writeAuthenticationMethod(connectData.authenticationMethod); } } sender->stageConnack(std::move(connAck)); } - sender->setRegistrationData(clean_start, client_receive_max, session_expire); + sender->setRegistrationData(connectData.clean_start, connectData.client_receive_max, connectData.session_expire); Authentication &authentication = *ThreadGlobals::getAuth(); AuthResult authResult = AuthResult::login_denied; - if (!user_name_flag && authenticationMethod.empty() && settings.allowAnonymous) + if (!connectData.user_name_flag && connectData.authenticationMethod.empty() && settings.allowAnonymous) { authResult = AuthResult::success; } - else if (!authenticationMethod.empty()) + else if (!connectData.authenticationMethod.empty()) { - sender->setExtendedAuthenticationMethod(authenticationMethod); + sender->setExtendedAuthenticationMethod(connectData.authenticationMethod); std::string returnData; - authResult = authentication.extendedAuth(client_id, ExtendedAuthStage::Auth, authenticationMethod, authenticationData, + authResult = authentication.extendedAuth(connectData.client_id, ExtendedAuthStage::Auth, connectData.authenticationMethod, connectData.authenticationData, getUserProperties(), returnData, sender->getMutableUsername()); if (authResult == AuthResult::auth_continue) { - Auth auth(ReasonCodes::ContinueAuthentication, authenticationMethod, returnData); + Auth auth(ReasonCodes::ContinueAuthentication, connectData.authenticationMethod, returnData); MqttPacket pack(auth); sender->writeMqttPacket(pack); return; @@ -670,7 +678,7 @@ void MqttPacket::handleConnect() } else { - authResult = authentication.unPwdCheck(username, password, getUserProperties()); + authResult = authentication.unPwdCheck(connectData.username, connectData.password, getUserProperties()); } if (authResult == AuthResult::success) diff --git a/mqttpacket.h b/mqttpacket.h index 41bfad1..bae2273 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -35,6 +35,7 @@ License along with FlashMQ. If not, see . #include "variablebyteint.h" #include "mqtt5properties.h" +#include "packetdatatypes.h" /** * @brief The MqttPacket class represents incoming and outgonig packets. @@ -109,6 +110,7 @@ public: static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender); void handle(); + ConnectData parseConnectData(); void handleConnect(); void handleExtendedAuth(); void handleDisconnect(); diff --git a/packetdatatypes.cpp b/packetdatatypes.cpp new file mode 100644 index 0000000..2e07fbf --- /dev/null +++ b/packetdatatypes.cpp @@ -0,0 +1,13 @@ +#include "packetdatatypes.h" + +#include "threadglobals.h" +#include "settings.h" + +ConnectData::ConnectData() +{ + const Settings *settings = ThreadGlobals::getSettings(); + + client_receive_max = settings->maxQosMsgPendingPerClient; + session_expire = settings->getExpireSessionAfterSeconds(); + max_outgoing_packet_size = settings->maxPacketSize; +} diff --git a/packetdatatypes.h b/packetdatatypes.h new file mode 100644 index 0000000..279f8e0 --- /dev/null +++ b/packetdatatypes.h @@ -0,0 +1,42 @@ +#ifndef PACKETDATATYPES_H +#define PACKETDATATYPES_H + +#include "mqtt5properties.h" + +struct ConnectData +{ + char protocol_level_byte = 0; + + // Flags + bool user_name_flag = false; + bool password_flag = false; + bool will_retain = false; + char will_qos = false; + bool will_flag = false; + bool clean_start = false; + + uint16_t keep_alive = 0; + + // Content from properties + uint16_t client_receive_max; + uint32_t session_expire; + uint32_t max_outgoing_packet_size; + uint16_t max_outgoing_topic_aliases = 0; // Default MUST BE 0, meaning server won't initiate aliases; + bool request_response_information = false; + bool request_problem_information = false; + std::string authenticationMethod; + std::string authenticationData; + + // Content from Payload + std::string client_id; + WillPublish willpublish; + std::string username; + std::string password; + + Mqtt5PropertyBuilder builder; + + ConnectData(); +}; + + +#endif // PACKETDATATYPES_H