diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index 0775f19..bb2c98b 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \ ../derivablecounter.cpp \ ../packetdatatypes.cpp \ ../flashmqtestclient.cpp \ + conffiletemp.cpp \ mainappthread.cpp @@ -102,6 +103,7 @@ HEADERS += \ ../derivablecounter.h \ ../packetdatatypes.h \ ../flashmqtestclient.h \ + conffiletemp.h \ mainappthread.h LIBS += -ldl -lssl -lcrypto diff --git a/FlashMQTests/conffiletemp.cpp b/FlashMQTests/conffiletemp.cpp new file mode 100644 index 0000000..eda9e59 --- /dev/null +++ b/FlashMQTests/conffiletemp.cpp @@ -0,0 +1,51 @@ +#include "conffiletemp.h" + +#include +#include "unistd.h" +#include + +ConfFileTemp::ConfFileTemp() +{ + const std::string templateName("/tmp/flashmqconf_XXXXXX"); + std::vector nameBuf(templateName.size() + 1, 0); + std::copy(templateName.begin(), templateName.end(), nameBuf.begin()); + this->fd = mkstemp(nameBuf.data()); + + if (this->fd < 0) + { + throw std::runtime_error("mkstemp error."); + } + + this->filePath = nameBuf.data(); +} + +ConfFileTemp::~ConfFileTemp() +{ + closeFile(); + + if (!this->filePath.empty()) + unlink(this->filePath.c_str()); +} + +const std::string &ConfFileTemp::getFilePath() const +{ + if (fd > 0) + throw std::runtime_error("You first need to close the file before using it."); + + return this->filePath; +} + +void ConfFileTemp::writeLine(const std::string &line) +{ + write(this->fd, line.c_str(), line.size()); + write(this->fd, "\n", 1); +} + +void ConfFileTemp::closeFile() +{ + if (this->fd < 0) + return; + + close(this->fd); + this->fd = -1; +} diff --git a/FlashMQTests/conffiletemp.h b/FlashMQTests/conffiletemp.h new file mode 100644 index 0000000..bfb03ab --- /dev/null +++ b/FlashMQTests/conffiletemp.h @@ -0,0 +1,20 @@ +#ifndef CONFFILETEMP_H +#define CONFFILETEMP_H + +#include + +class ConfFileTemp +{ + int fd = -1; + std::string filePath; + +public: + ConfFileTemp(); + ~ConfFileTemp(); + + const std::string &getFilePath() const; + void writeLine(const std::string &line); + void closeFile(); +}; + +#endif // CONFFILETEMP_H diff --git a/FlashMQTests/mainappthread.cpp b/FlashMQTests/mainappthread.cpp index c9fcae2..5d45b0d 100644 --- a/FlashMQTests/mainappthread.cpp +++ b/FlashMQTests/mainappthread.cpp @@ -24,6 +24,39 @@ MainAppThread::MainAppThread(QObject *parent) : QThread(parent) appInstance->settings->allowAnonymous = true; } +MainAppThread::MainAppThread(const std::vector &args, QObject *parent) : QThread(parent) +{ + std::list> argCopies; + + const std::string programName = "FlashMQTests"; + std::vector programNameCopy(programName.size() + 1, 0); + std::copy(programName.begin(), programName.end(), programNameCopy.begin()); + argCopies.push_back(std::move(programNameCopy)); + + for (const std::string &arg : args) + { + std::vector copyArg(arg.size() + 1, 0); + std::copy(arg.begin(), arg.end(), copyArg.begin()); + argCopies.push_back(std::move(copyArg)); + } + + char *argv[256]; + memset(argv, 0, 256*sizeof (char*)); + + int i = 0; + for (std::vector © : argCopies) + { + argv[i++] = copy.data(); + } + + MainApp::initMainApp(i, argv); + appInstance = MainApp::getMainApp(); + + // A hack: when I supply args I probably define a config for auth stuff. + if (args.empty()) + appInstance->settings->allowAnonymous = true; +} + MainAppThread::~MainAppThread() { if (appInstance) diff --git a/FlashMQTests/mainappthread.h b/FlashMQTests/mainappthread.h index d5ef145..7d72473 100644 --- a/FlashMQTests/mainappthread.h +++ b/FlashMQTests/mainappthread.h @@ -28,6 +28,7 @@ class MainAppThread : public QThread MainApp *appInstance = nullptr; public: explicit MainAppThread(QObject *parent = nullptr); + MainAppThread(const std::vector &args, QObject *parent = nullptr); ~MainAppThread(); public slots: diff --git a/FlashMQTests/plugins/test_plugin.cpp b/FlashMQTests/plugins/test_plugin.cpp index 46b6b5a..93e46a0 100644 --- a/FlashMQTests/plugins/test_plugin.cpp +++ b/FlashMQTests/plugins/test_plugin.cpp @@ -45,6 +45,9 @@ AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string (void)password; (void)userProperties; + if (username == "failme") + return AuthResult::login_denied; + return AuthResult::success; } @@ -72,6 +75,34 @@ AuthResult flashmq_extended_auth(void *thread_data, const std::string &clientid, (void)userProperties; (void)returnData; - return AuthResult::success; + if (authMethod == "always_good_passing_back_the_auth_data") + { + if (authData == "actually not good.") + return AuthResult::login_denied; + + returnData = authData; + return AuthResult::success; + } + if (authMethod == "always_fail") + { + return AuthResult::login_denied; + } + if (authMethod == "two_step") + { + if (authData == "Hello") + returnData = "Hello back"; + + if (authData == "grant me already!") + { + returnData = "OK, if you insist."; + return AuthResult::success; + } + else if (authData == "whoops, wrong data.") + return AuthResult::login_denied; + else + return AuthResult::auth_continue; + } + + return AuthResult::auth_method_not_supported; } diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp index 2596ba5..98f82f3 100644 --- a/FlashMQTests/tst_maintests.cpp +++ b/FlashMQTests/tst_maintests.cpp @@ -32,6 +32,8 @@ License along with FlashMQ. If not, see . #include "session.h" #include "threaddata.h" #include "threadglobals.h" +#include "conffiletemp.h" +#include "packetdatatypes.h" #include "flashmqtestclient.h" @@ -65,6 +67,7 @@ public: ~MainTests(); private slots: + void init(const std::vector &args); void init(); // will be called before each test function is executed void cleanup(); // will be called after every test function. @@ -134,6 +137,18 @@ private slots: void testUserProperties(); + void testAuthFail(); + void testAuthSucceed(); + + void testExtendedAuthOneStepSucceed(); + void testExtendedAuthOneStepDeny(); + void testExtendedAuthOneStepBadAuthMethod(); + void testExtendedAuthTwoStep(); + void testExtendedAuthTwoStepSecondStepFail(); + void testExtendedReAuth(); + void testExtendedReAuthTwoStep(); + void testExtendedReAuthFail(); + }; MainTests::MainTests() @@ -146,10 +161,10 @@ MainTests::~MainTests() } -void MainTests::init() +void MainTests::init(const std::vector &args) { mainApp.reset(); - mainApp.reset(new MainAppThread()); + mainApp.reset(new MainAppThread(args)); mainApp->start(); mainApp->waitForStarted(); @@ -160,6 +175,12 @@ void MainTests::init() ThreadGlobals::assignThreadData(dummyThreadData.get()); } +void MainTests::init() +{ + std::vector args; + init(args); +} + void MainTests::cleanup() { mainApp->stopApp(); @@ -1925,6 +1946,414 @@ void MainTests::testUserProperties() QVERIFY(properties3 == nullptr); } +void MainTests::testAuthFail() +{ + std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; + + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + for (ProtocolVersion &version : versions) + { + + FlashMQTestClient client; + client.start(); + client.connectClient(version, false, 120, [](Connect &connect) { + connect.username = "failme"; + connect.password = "boo"; + }); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + if (version >= ProtocolVersion::Mqtt5) + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized); + else + QVERIFY(static_cast(connAckData.reasonCode) == 5); + } +} + +void MainTests::testAuthSucceed() +{ + std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; + + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + for (ProtocolVersion &version : versions) + { + + FlashMQTestClient client; + client.start(); + client.connectClient(version, false, 120, [](Connect &connect) { + connect.username = "passme"; + connect.password = "boo"; + }); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); + } +} + +void MainTests::testExtendedAuthOneStepSucceed() +{ + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + FlashMQTestClient client; + client.start(); + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { + connect.username = "me"; + connect.password = "me me"; + + connect.constructPropertyBuilder(); + + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data"); + connect.propertyBuilder->writeAuthenticationData("I have a proposal to put to ye."); + }); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); + QVERIFY(connAckData.authData == "I have a proposal to put to ye."); +} + +void MainTests::testExtendedAuthOneStepDeny() +{ + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + FlashMQTestClient client; + client.start(); + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { + connect.username = "me"; + connect.password = "me me"; + + connect.constructPropertyBuilder(); + + connect.propertyBuilder->writeAuthenticationMethod("always_fail"); + }); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized); +} + +void MainTests::testExtendedAuthOneStepBadAuthMethod() +{ + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + FlashMQTestClient client; + client.start(); + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { + connect.username = "me"; + connect.password = "me me"; + + connect.constructPropertyBuilder(); + + connect.propertyBuilder->writeAuthenticationMethod("doesnt_exist"); + }); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::BadAuthenticationMethod); +} + +void MainTests::testExtendedAuthTwoStep() +{ + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + FlashMQTestClient client; + client.start(); + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { + connect.username = "me"; + connect.password = "me me"; + + connect.constructPropertyBuilder(); + + connect.propertyBuilder->writeAuthenticationMethod("two_step"); + connect.propertyBuilder->writeAuthenticationData("Hello"); + }); + + QVERIFY(client.receivedPackets.size() == 1); + + AuthPacketData authData = client.receivedPackets.front().parseAuthData(); + + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication); + QVERIFY(authData.data == "Hello back"); + + client.clearReceivedLists(); + + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!"); + client.writeAuth(auth); + + client.waitForConnack(); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); + QVERIFY(connAckData.authData == "OK, if you insist."); +} + +void MainTests::testExtendedAuthTwoStepSecondStepFail() +{ + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + FlashMQTestClient client; + client.start(); + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { + connect.username = "me"; + connect.password = "me me"; + + connect.constructPropertyBuilder(); + + connect.propertyBuilder->writeAuthenticationMethod("two_step"); + connect.propertyBuilder->writeAuthenticationData("Hello"); + }); + + QVERIFY(client.receivedPackets.size() == 1); + + AuthPacketData authData = client.receivedPackets.front().parseAuthData(); + + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication); + QVERIFY(authData.data == "Hello back"); + + client.clearReceivedLists(); + + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "whoops, wrong data."); + client.writeAuth(auth); + + client.waitForConnack(); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized); +} + +void MainTests::testExtendedReAuth() +{ + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + FlashMQTestClient client; + client.start(); + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { + connect.username = "me"; + connect.password = "me me"; + + connect.constructPropertyBuilder(); + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data"); + connect.propertyBuilder->writeAuthenticationData("Santa Claus"); + }); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); + + client.clearReceivedLists(); + + // Then reauth. + + Auth auth(ReasonCodes::ContinueAuthentication, "always_good_passing_back_the_auth_data", "Again Santa Claus"); + client.writeAuth(auth); + + client.waitForConnack(); + + QVERIFY(client.receivedPackets.size() == 1); + + AuthPacketData authData = client.receivedPackets.front().parseAuthData(); + + QVERIFY(authData.reasonCode == ReasonCodes::Success); + QVERIFY(authData.data == "Again Santa Claus"); +} + +void MainTests::testExtendedReAuthTwoStep() +{ + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + FlashMQTestClient client; + client.start(); + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { + connect.username = "me"; + connect.password = "me me"; + + connect.constructPropertyBuilder(); + + connect.propertyBuilder->writeAuthenticationMethod("two_step"); + connect.propertyBuilder->writeAuthenticationData("Hello"); + }); + + QVERIFY(client.receivedPackets.size() == 1); + + AuthPacketData authData = client.receivedPackets.front().parseAuthData(); + + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication); + QVERIFY(authData.data == "Hello back"); + + client.clearReceivedLists(); + + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!"); + client.writeAuth(auth); + + client.waitForConnack(); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); + QVERIFY(connAckData.authData == "OK, if you insist."); + + client.clearReceivedLists(); + + // Then reauth. + + const Auth reauth(ReasonCodes::ReAuthenticate, "two_step", "Hello"); + client.writeAuth(reauth); + client.waitForConnack(); + + QVERIFY(client.receivedPackets.size() == 1); + + AuthPacketData reauthData = client.receivedPackets.front().parseAuthData(); + + QVERIFY(reauthData.reasonCode == ReasonCodes::ContinueAuthentication); + QVERIFY(reauthData.data == "Hello back"); + + client.clearReceivedLists(); + + const Auth reauthFinish(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!"); + client.writeAuth(reauthFinish); + + client.waitForConnack(); + + QVERIFY(client.receivedPackets.size() == 1); + + AuthPacketData reauthFinishData = client.receivedPackets.front().parseAuthData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); + QVERIFY(connAckData.authData == "OK, if you insist."); +} + +void MainTests::testExtendedReAuthFail() +{ + ConfFileTemp confFile; + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); + confFile.closeFile(); + + std::vector args {"--config-file", confFile.getFilePath()}; + + cleanup(); + init(args); + + FlashMQTestClient client; + client.start(); + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { + connect.username = "me"; + connect.password = "me me"; + + connect.constructPropertyBuilder(); + + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data"); + connect.propertyBuilder->writeAuthenticationData("I have a proposal to put to ye."); + }); + + QVERIFY(client.receivedPackets.size() == 1); + + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); + + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); + QVERIFY(connAckData.authData == "I have a proposal to put to ye."); + + client.clearReceivedLists(); + + // Then reauth. + + const Auth reauth(ReasonCodes::ReAuthenticate, "always_good_passing_back_the_auth_data", "actually not good."); + client.writeAuth(reauth); + client.waitForPacketCount(1); + + QVERIFY(client.receivedPackets.size() == 1); + QVERIFY(client.receivedPackets.front().packetType == PacketType::DISCONNECT); + + DisconnectData data = client.receivedPackets.front().parseDisconnectData(); + + QVERIFY(data.reasonCode == ReasonCodes::NotAuthorized); +} + int main(int argc, char *argv[]) { QCoreApplication app(argc, argv); diff --git a/authplugin.cpp b/authplugin.cpp index 63603d1..d2e4925 100644 --- a/authplugin.cpp +++ b/authplugin.cpp @@ -356,7 +356,7 @@ AuthResult Authentication::unPwdCheck(const std::string &username, const std::st { AuthResult firstResult = unPwdCheckFromMosquittoPasswordFile(username, password); - if (firstResult != AuthResult::success) + if (firstResult == AuthResult::success) return firstResult; if (pluginVersion == PluginVersion::None) diff --git a/flashmqtestclient.cpp b/flashmqtestclient.cpp index dafca3d..263080f 100644 --- a/flashmqtestclient.cpp +++ b/flashmqtestclient.cpp @@ -260,6 +260,12 @@ void FlashMQTestClient::publish(Publish &pub) } } +void FlashMQTestClient::writeAuth(const Auth &auth) +{ + MqttPacket pack(auth); + client->writeMqttPacketAndBlameThisClient(pack); +} + void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos) { Publish pub(topic, payload, qos); @@ -276,7 +282,7 @@ void FlashMQTestClient::waitForConnack() { waitForCondition([&]() { return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) { - return p.packetType == PacketType::CONNACK; + return p.packetType == PacketType::CONNACK || p.packetType == PacketType::AUTH; }); }); } @@ -287,3 +293,10 @@ void FlashMQTestClient::waitForMessageCount(const size_t count, int timeout) return this->receivedPublishes.size() >= count; }, timeout); } + +void FlashMQTestClient::waitForPacketCount(const size_t count, int timeout) +{ + waitForCondition([&]() { + return this->receivedPackets.size() >= count; + }, timeout); +} diff --git a/flashmqtestclient.h b/flashmqtestclient.h index e8804fd..8dba152 100644 --- a/flashmqtestclient.h +++ b/flashmqtestclient.h @@ -40,6 +40,7 @@ public: void unsubscribe(const std::string &topic); void publish(const std::string &topic, const std::string &payload, char qos); void publish(Publish &pub); + void writeAuth(const Auth &auth); void clearReceivedLists(); void setWill(std::shared_ptr &will); void disconnect(ReasonCodes reason); @@ -47,6 +48,7 @@ public: void waitForQuit(); void waitForConnack(); void waitForMessageCount(const size_t count, int timeout = 1); + void waitForPacketCount(const size_t count, int timeout = 1); }; #endif // FLASHMQTESTCLIENT_H diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 3bfe720..05e244d 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -253,6 +253,8 @@ MqttPacket::MqttPacket(const Connect &connect) : writeByte(static_cast(protocolVersion)); uint8_t flags = connect.clean_start << 1; + flags |= !connect.username.empty() << 7; + flags |= !connect.password.empty() << 6; if (connect.will) { @@ -284,6 +286,11 @@ MqttPacket::MqttPacket(const Connect &connect) : writeString(connect.will->payload); } + if (!connect.username.empty()) + writeString(connect.username); + if (!connect.password.empty()) + writeString(connect.password); + calculateRemainingLength(); } @@ -634,6 +641,43 @@ ConnectData MqttPacket::parseConnectData() return result; } +ConnAckData MqttPacket::parseConnAckData() +{ + if (this->packetType != PacketType::CONNACK) + throw std::runtime_error("Packet must be connack packet."); + + setPosToDataStart(); + + ConnAckData result; + + const uint8_t flagByte = readByte(); + + result.sessionPresent = flagByte & 0x01; + result.reasonCode = static_cast(readUint8()); + + if (protocolVersion == ProtocolVersion::Mqtt5) + { + const size_t proplen = decodeVariableByteIntAtPos(); + const size_t prop_end_at = pos + proplen; + + while (pos < prop_end_at) + { + const Mqtt5Properties prop = static_cast(readByte()); + + switch (prop) + { + case Mqtt5Properties::AuthenticationData: + result.authData = readBytesToString(); + break; + default: + break; + } + } + } + + return result; +} + void MqttPacket::handleConnect() { if (sender->hasConnectPacketSeen()) @@ -798,18 +842,22 @@ void MqttPacket::handleConnect() } } -void MqttPacket::handleExtendedAuth() +AuthPacketData MqttPacket::parseAuthData() { + if (this->packetType != PacketType::AUTH) + throw std::runtime_error("Packet must be an AUTH packet."); + if (first_byte & 0b1111) throw ProtocolError("AUTH packet first 4 bits should be 0.", ReasonCodes::MalformedPacket); - const ReasonCodes reasonCode = static_cast(readByte()); - if (this->protocolVersion < ProtocolVersion::Mqtt5) throw ProtocolError("AUTH packet needs MQTT5 or higher"); - std::string authMethod; - std::string authData; + setPosToDataStart(); + + AuthPacketData result; + + result.reasonCode = static_cast(readUint8()); if (!atEnd()) { @@ -823,10 +871,10 @@ void MqttPacket::handleExtendedAuth() switch (prop) { case Mqtt5Properties::AuthenticationMethod: - authMethod = readBytesToString(); + result.method = readBytesToString(); break; case Mqtt5Properties::AuthenticationData: - authData = readBytesToString(false); + result.data = readBytesToString(false); break; case Mqtt5Properties::ReasonString: readBytesToString(); @@ -840,12 +888,19 @@ void MqttPacket::handleExtendedAuth() } } - if (authMethod != sender->getExtendedAuthenticationMethod()) + return result; +} + +void MqttPacket::handleExtendedAuth() +{ + AuthPacketData data = parseAuthData(); + + if (data.method != sender->getExtendedAuthenticationMethod()) throw ProtocolError("Client continued with another authentication method that it started with.", ReasonCodes::ProtocolError); ExtendedAuthStage authStage = ExtendedAuthStage::None; - switch(reasonCode) + switch(data.reasonCode) { case ReasonCodes::ContinueAuthentication: authStage = ExtendedAuthStage::Continue; @@ -854,7 +909,7 @@ void MqttPacket::handleExtendedAuth() authStage = ExtendedAuthStage::Reauth; break; default: - throw ProtocolError(formatString("Invalid reason code '%d' in auth packet", static_cast(reasonCode)), ReasonCodes::MalformedPacket); + throw ProtocolError(formatString("Invalid reason code '%d' in auth packet", static_cast(data.reasonCode)), ReasonCodes::MalformedPacket); } if (authStage == ExtendedAuthStage::Reauth && !sender->getAuthenticated()) @@ -865,28 +920,24 @@ void MqttPacket::handleExtendedAuth() Authentication &authentication = *ThreadGlobals::getAuth(); std::string returnData; - const AuthResult authResult = authentication.extendedAuth(sender->getClientId(), authStage, authMethod, authData, + const AuthResult authResult = authentication.extendedAuth(sender->getClientId(), authStage, data.method, data.data, getUserProperties(), returnData, sender->getMutableUsername()); if (authResult == AuthResult::auth_continue) { - Auth auth(ReasonCodes::ContinueAuthentication, authMethod, returnData); + Auth auth(ReasonCodes::ContinueAuthentication, data.method, returnData); MqttPacket pack(auth); sender->writeMqttPacket(pack); return; } - if (authResult == AuthResult::success) - { - sender->addAuthReturnDataToStagedConnAck(returnData); - } - const ReasonCodes finalResult = authResultToReasonCode(authResult); if (!sender->getAuthenticated()) // First auth sends connack packets. { if (finalResult == ReasonCodes::Success) { + sender->addAuthReturnDataToStagedConnAck(returnData); sender->sendConnackSuccess(); std::shared_ptr subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); subscriptionStore->registerClientAndKickExistingOne(sender); @@ -900,7 +951,7 @@ void MqttPacket::handleExtendedAuth() { if (finalResult == ReasonCodes::Success) { - Auth auth(ReasonCodes::Success, authMethod, returnData); + Auth auth(ReasonCodes::Success, data.method, returnData); MqttPacket authPack(auth); sender->writeMqttPacket(authPack); logger->logf(LOG_NOTICE, "Client '%s', user '%s' reauthentication successful.", sender->getClientId().c_str(), sender->getUsername().c_str()); @@ -933,7 +984,7 @@ DisconnectData MqttPacket::parseDisconnectData() { if (!atEnd()) { - result.reasonCode = static_cast(readByte()); + result.reasonCode = static_cast(readUint8()); const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; @@ -1621,6 +1672,12 @@ char MqttPacket::readByte() return b; } +uint8_t MqttPacket::readUint8() +{ + char r = readByte(); + return static_cast(r); +} + void MqttPacket::writeByte(char b) { if (pos + 1 > bites.size()) diff --git a/mqttpacket.h b/mqttpacket.h index afe9c01..214ec75 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -73,6 +73,7 @@ class MqttPacket char *readBytes(size_t length); char readByte(); + uint8_t readUint8(); void writeByte(char b); void writeUint16(uint16_t x); void writeBytes(const char *b, size_t len); @@ -121,7 +122,9 @@ public: static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender); void handle(); + AuthPacketData parseAuthData(); ConnectData parseConnectData(); + ConnAckData parseConnAckData(); void handleConnect(); void handleExtendedAuth(); DisconnectData parseDisconnectData(); diff --git a/packetdatatypes.h b/packetdatatypes.h index 1aaf7e8..c03bc6c 100644 --- a/packetdatatypes.h +++ b/packetdatatypes.h @@ -38,6 +38,23 @@ struct ConnectData ConnectData(); }; +struct ConnAckData +{ + // Flags + bool sessionPresent = false; + + ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result; + + std::string authData; +}; + +struct AuthPacketData +{ + std::string method; + std::string data; + ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result; +}; + struct DisconnectData { ReasonCodes reasonCode = ReasonCodes::Success; diff --git a/types.cpp b/types.cpp index 3e64cf4..5923dc4 100644 --- a/types.cpp +++ b/types.cpp @@ -397,6 +397,12 @@ size_t Connect::getLengthWithoutFixedHeader() const result += will->payload.length() + 2; } + if (!username.empty()) + result += username.size() + 2; + + if (!password.empty()) + result += password.size() + 2; + return result; }