Commit 5f683b9c702ce2d2d85dfc7fcca722a8a560849d
1 parent
5c47eca3
Test (plugin) auth system + fixes
When the plugin was enabled, it would never got past the fail of not having a password file entry. This is fixed. The reauth didn't work because setting the return data took the path of first auth, hitting a protection exception. This is fixed. Writing extended auth tests involved some refactoring, to create a separate method to parse the auth packet data, because I needed it in the test too.
Showing
14 changed files
with
689 additions
and
24 deletions
FlashMQTests/FlashMQTests.pro
| @@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \ | @@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \ | ||
| 55 | ../derivablecounter.cpp \ | 55 | ../derivablecounter.cpp \ |
| 56 | ../packetdatatypes.cpp \ | 56 | ../packetdatatypes.cpp \ |
| 57 | ../flashmqtestclient.cpp \ | 57 | ../flashmqtestclient.cpp \ |
| 58 | + conffiletemp.cpp \ | ||
| 58 | mainappthread.cpp | 59 | mainappthread.cpp |
| 59 | 60 | ||
| 60 | 61 | ||
| @@ -102,6 +103,7 @@ HEADERS += \ | @@ -102,6 +103,7 @@ HEADERS += \ | ||
| 102 | ../derivablecounter.h \ | 103 | ../derivablecounter.h \ |
| 103 | ../packetdatatypes.h \ | 104 | ../packetdatatypes.h \ |
| 104 | ../flashmqtestclient.h \ | 105 | ../flashmqtestclient.h \ |
| 106 | + conffiletemp.h \ | ||
| 105 | mainappthread.h | 107 | mainappthread.h |
| 106 | 108 | ||
| 107 | LIBS += -ldl -lssl -lcrypto | 109 | LIBS += -ldl -lssl -lcrypto |
FlashMQTests/conffiletemp.cpp
0 → 100644
| 1 | +#include "conffiletemp.h" | ||
| 2 | + | ||
| 3 | +#include <vector> | ||
| 4 | +#include "unistd.h" | ||
| 5 | +#include <stdexcept> | ||
| 6 | + | ||
| 7 | +ConfFileTemp::ConfFileTemp() | ||
| 8 | +{ | ||
| 9 | + const std::string templateName("/tmp/flashmqconf_XXXXXX"); | ||
| 10 | + std::vector<char> nameBuf(templateName.size() + 1, 0); | ||
| 11 | + std::copy(templateName.begin(), templateName.end(), nameBuf.begin()); | ||
| 12 | + this->fd = mkstemp(nameBuf.data()); | ||
| 13 | + | ||
| 14 | + if (this->fd < 0) | ||
| 15 | + { | ||
| 16 | + throw std::runtime_error("mkstemp error."); | ||
| 17 | + } | ||
| 18 | + | ||
| 19 | + this->filePath = nameBuf.data(); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +ConfFileTemp::~ConfFileTemp() | ||
| 23 | +{ | ||
| 24 | + closeFile(); | ||
| 25 | + | ||
| 26 | + if (!this->filePath.empty()) | ||
| 27 | + unlink(this->filePath.c_str()); | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +const std::string &ConfFileTemp::getFilePath() const | ||
| 31 | +{ | ||
| 32 | + if (fd > 0) | ||
| 33 | + throw std::runtime_error("You first need to close the file before using it."); | ||
| 34 | + | ||
| 35 | + return this->filePath; | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +void ConfFileTemp::writeLine(const std::string &line) | ||
| 39 | +{ | ||
| 40 | + write(this->fd, line.c_str(), line.size()); | ||
| 41 | + write(this->fd, "\n", 1); | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +void ConfFileTemp::closeFile() | ||
| 45 | +{ | ||
| 46 | + if (this->fd < 0) | ||
| 47 | + return; | ||
| 48 | + | ||
| 49 | + close(this->fd); | ||
| 50 | + this->fd = -1; | ||
| 51 | +} |
FlashMQTests/conffiletemp.h
0 → 100644
| 1 | +#ifndef CONFFILETEMP_H | ||
| 2 | +#define CONFFILETEMP_H | ||
| 3 | + | ||
| 4 | +#include <string> | ||
| 5 | + | ||
| 6 | +class ConfFileTemp | ||
| 7 | +{ | ||
| 8 | + int fd = -1; | ||
| 9 | + std::string filePath; | ||
| 10 | + | ||
| 11 | +public: | ||
| 12 | + ConfFileTemp(); | ||
| 13 | + ~ConfFileTemp(); | ||
| 14 | + | ||
| 15 | + const std::string &getFilePath() const; | ||
| 16 | + void writeLine(const std::string &line); | ||
| 17 | + void closeFile(); | ||
| 18 | +}; | ||
| 19 | + | ||
| 20 | +#endif // CONFFILETEMP_H |
FlashMQTests/mainappthread.cpp
| @@ -24,6 +24,39 @@ MainAppThread::MainAppThread(QObject *parent) : QThread(parent) | @@ -24,6 +24,39 @@ MainAppThread::MainAppThread(QObject *parent) : QThread(parent) | ||
| 24 | appInstance->settings->allowAnonymous = true; | 24 | appInstance->settings->allowAnonymous = true; |
| 25 | } | 25 | } |
| 26 | 26 | ||
| 27 | +MainAppThread::MainAppThread(const std::vector<std::string> &args, QObject *parent) : QThread(parent) | ||
| 28 | +{ | ||
| 29 | + std::list<std::vector<char>> argCopies; | ||
| 30 | + | ||
| 31 | + const std::string programName = "FlashMQTests"; | ||
| 32 | + std::vector<char> programNameCopy(programName.size() + 1, 0); | ||
| 33 | + std::copy(programName.begin(), programName.end(), programNameCopy.begin()); | ||
| 34 | + argCopies.push_back(std::move(programNameCopy)); | ||
| 35 | + | ||
| 36 | + for (const std::string &arg : args) | ||
| 37 | + { | ||
| 38 | + std::vector<char> copyArg(arg.size() + 1, 0); | ||
| 39 | + std::copy(arg.begin(), arg.end(), copyArg.begin()); | ||
| 40 | + argCopies.push_back(std::move(copyArg)); | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + char *argv[256]; | ||
| 44 | + memset(argv, 0, 256*sizeof (char*)); | ||
| 45 | + | ||
| 46 | + int i = 0; | ||
| 47 | + for (std::vector<char> © : argCopies) | ||
| 48 | + { | ||
| 49 | + argv[i++] = copy.data(); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + MainApp::initMainApp(i, argv); | ||
| 53 | + appInstance = MainApp::getMainApp(); | ||
| 54 | + | ||
| 55 | + // A hack: when I supply args I probably define a config for auth stuff. | ||
| 56 | + if (args.empty()) | ||
| 57 | + appInstance->settings->allowAnonymous = true; | ||
| 58 | +} | ||
| 59 | + | ||
| 27 | MainAppThread::~MainAppThread() | 60 | MainAppThread::~MainAppThread() |
| 28 | { | 61 | { |
| 29 | if (appInstance) | 62 | if (appInstance) |
FlashMQTests/mainappthread.h
| @@ -28,6 +28,7 @@ class MainAppThread : public QThread | @@ -28,6 +28,7 @@ class MainAppThread : public QThread | ||
| 28 | MainApp *appInstance = nullptr; | 28 | MainApp *appInstance = nullptr; |
| 29 | public: | 29 | public: |
| 30 | explicit MainAppThread(QObject *parent = nullptr); | 30 | explicit MainAppThread(QObject *parent = nullptr); |
| 31 | + MainAppThread(const std::vector<std::string> &args, QObject *parent = nullptr); | ||
| 31 | ~MainAppThread(); | 32 | ~MainAppThread(); |
| 32 | 33 | ||
| 33 | public slots: | 34 | public slots: |
FlashMQTests/plugins/test_plugin.cpp
| @@ -45,6 +45,9 @@ AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string | @@ -45,6 +45,9 @@ AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string | ||
| 45 | (void)password; | 45 | (void)password; |
| 46 | (void)userProperties; | 46 | (void)userProperties; |
| 47 | 47 | ||
| 48 | + if (username == "failme") | ||
| 49 | + return AuthResult::login_denied; | ||
| 50 | + | ||
| 48 | return AuthResult::success; | 51 | return AuthResult::success; |
| 49 | } | 52 | } |
| 50 | 53 | ||
| @@ -72,6 +75,34 @@ AuthResult flashmq_extended_auth(void *thread_data, const std::string &clientid, | @@ -72,6 +75,34 @@ AuthResult flashmq_extended_auth(void *thread_data, const std::string &clientid, | ||
| 72 | (void)userProperties; | 75 | (void)userProperties; |
| 73 | (void)returnData; | 76 | (void)returnData; |
| 74 | 77 | ||
| 75 | - return AuthResult::success; | 78 | + if (authMethod == "always_good_passing_back_the_auth_data") |
| 79 | + { | ||
| 80 | + if (authData == "actually not good.") | ||
| 81 | + return AuthResult::login_denied; | ||
| 82 | + | ||
| 83 | + returnData = authData; | ||
| 84 | + return AuthResult::success; | ||
| 85 | + } | ||
| 86 | + if (authMethod == "always_fail") | ||
| 87 | + { | ||
| 88 | + return AuthResult::login_denied; | ||
| 89 | + } | ||
| 90 | + if (authMethod == "two_step") | ||
| 91 | + { | ||
| 92 | + if (authData == "Hello") | ||
| 93 | + returnData = "Hello back"; | ||
| 94 | + | ||
| 95 | + if (authData == "grant me already!") | ||
| 96 | + { | ||
| 97 | + returnData = "OK, if you insist."; | ||
| 98 | + return AuthResult::success; | ||
| 99 | + } | ||
| 100 | + else if (authData == "whoops, wrong data.") | ||
| 101 | + return AuthResult::login_denied; | ||
| 102 | + else | ||
| 103 | + return AuthResult::auth_continue; | ||
| 104 | + } | ||
| 105 | + | ||
| 106 | + return AuthResult::auth_method_not_supported; | ||
| 76 | } | 107 | } |
| 77 | 108 |
FlashMQTests/tst_maintests.cpp
| @@ -32,6 +32,8 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | @@ -32,6 +32,8 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. | ||
| 32 | #include "session.h" | 32 | #include "session.h" |
| 33 | #include "threaddata.h" | 33 | #include "threaddata.h" |
| 34 | #include "threadglobals.h" | 34 | #include "threadglobals.h" |
| 35 | +#include "conffiletemp.h" | ||
| 36 | +#include "packetdatatypes.h" | ||
| 35 | 37 | ||
| 36 | #include "flashmqtestclient.h" | 38 | #include "flashmqtestclient.h" |
| 37 | 39 | ||
| @@ -65,6 +67,7 @@ public: | @@ -65,6 +67,7 @@ public: | ||
| 65 | ~MainTests(); | 67 | ~MainTests(); |
| 66 | 68 | ||
| 67 | private slots: | 69 | private slots: |
| 70 | + void init(const std::vector<std::string> &args); | ||
| 68 | void init(); // will be called before each test function is executed | 71 | void init(); // will be called before each test function is executed |
| 69 | void cleanup(); // will be called after every test function. | 72 | void cleanup(); // will be called after every test function. |
| 70 | 73 | ||
| @@ -134,6 +137,18 @@ private slots: | @@ -134,6 +137,18 @@ private slots: | ||
| 134 | 137 | ||
| 135 | void testUserProperties(); | 138 | void testUserProperties(); |
| 136 | 139 | ||
| 140 | + void testAuthFail(); | ||
| 141 | + void testAuthSucceed(); | ||
| 142 | + | ||
| 143 | + void testExtendedAuthOneStepSucceed(); | ||
| 144 | + void testExtendedAuthOneStepDeny(); | ||
| 145 | + void testExtendedAuthOneStepBadAuthMethod(); | ||
| 146 | + void testExtendedAuthTwoStep(); | ||
| 147 | + void testExtendedAuthTwoStepSecondStepFail(); | ||
| 148 | + void testExtendedReAuth(); | ||
| 149 | + void testExtendedReAuthTwoStep(); | ||
| 150 | + void testExtendedReAuthFail(); | ||
| 151 | + | ||
| 137 | }; | 152 | }; |
| 138 | 153 | ||
| 139 | MainTests::MainTests() | 154 | MainTests::MainTests() |
| @@ -146,10 +161,10 @@ MainTests::~MainTests() | @@ -146,10 +161,10 @@ MainTests::~MainTests() | ||
| 146 | 161 | ||
| 147 | } | 162 | } |
| 148 | 163 | ||
| 149 | -void MainTests::init() | 164 | +void MainTests::init(const std::vector<std::string> &args) |
| 150 | { | 165 | { |
| 151 | mainApp.reset(); | 166 | mainApp.reset(); |
| 152 | - mainApp.reset(new MainAppThread()); | 167 | + mainApp.reset(new MainAppThread(args)); |
| 153 | mainApp->start(); | 168 | mainApp->start(); |
| 154 | mainApp->waitForStarted(); | 169 | mainApp->waitForStarted(); |
| 155 | 170 | ||
| @@ -160,6 +175,12 @@ void MainTests::init() | @@ -160,6 +175,12 @@ void MainTests::init() | ||
| 160 | ThreadGlobals::assignThreadData(dummyThreadData.get()); | 175 | ThreadGlobals::assignThreadData(dummyThreadData.get()); |
| 161 | } | 176 | } |
| 162 | 177 | ||
| 178 | +void MainTests::init() | ||
| 179 | +{ | ||
| 180 | + std::vector<std::string> args; | ||
| 181 | + init(args); | ||
| 182 | +} | ||
| 183 | + | ||
| 163 | void MainTests::cleanup() | 184 | void MainTests::cleanup() |
| 164 | { | 185 | { |
| 165 | mainApp->stopApp(); | 186 | mainApp->stopApp(); |
| @@ -1925,6 +1946,414 @@ void MainTests::testUserProperties() | @@ -1925,6 +1946,414 @@ void MainTests::testUserProperties() | ||
| 1925 | QVERIFY(properties3 == nullptr); | 1946 | QVERIFY(properties3 == nullptr); |
| 1926 | } | 1947 | } |
| 1927 | 1948 | ||
| 1949 | +void MainTests::testAuthFail() | ||
| 1950 | +{ | ||
| 1951 | + std::vector<ProtocolVersion> versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; | ||
| 1952 | + | ||
| 1953 | + ConfFileTemp confFile; | ||
| 1954 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 1955 | + confFile.closeFile(); | ||
| 1956 | + | ||
| 1957 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 1958 | + | ||
| 1959 | + cleanup(); | ||
| 1960 | + init(args); | ||
| 1961 | + | ||
| 1962 | + for (ProtocolVersion &version : versions) | ||
| 1963 | + { | ||
| 1964 | + | ||
| 1965 | + FlashMQTestClient client; | ||
| 1966 | + client.start(); | ||
| 1967 | + client.connectClient(version, false, 120, [](Connect &connect) { | ||
| 1968 | + connect.username = "failme"; | ||
| 1969 | + connect.password = "boo"; | ||
| 1970 | + }); | ||
| 1971 | + | ||
| 1972 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 1973 | + | ||
| 1974 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 1975 | + | ||
| 1976 | + if (version >= ProtocolVersion::Mqtt5) | ||
| 1977 | + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized); | ||
| 1978 | + else | ||
| 1979 | + QVERIFY(static_cast<uint8_t>(connAckData.reasonCode) == 5); | ||
| 1980 | + } | ||
| 1981 | +} | ||
| 1982 | + | ||
| 1983 | +void MainTests::testAuthSucceed() | ||
| 1984 | +{ | ||
| 1985 | + std::vector<ProtocolVersion> versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; | ||
| 1986 | + | ||
| 1987 | + ConfFileTemp confFile; | ||
| 1988 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 1989 | + confFile.closeFile(); | ||
| 1990 | + | ||
| 1991 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 1992 | + | ||
| 1993 | + cleanup(); | ||
| 1994 | + init(args); | ||
| 1995 | + | ||
| 1996 | + for (ProtocolVersion &version : versions) | ||
| 1997 | + { | ||
| 1998 | + | ||
| 1999 | + FlashMQTestClient client; | ||
| 2000 | + client.start(); | ||
| 2001 | + client.connectClient(version, false, 120, [](Connect &connect) { | ||
| 2002 | + connect.username = "passme"; | ||
| 2003 | + connect.password = "boo"; | ||
| 2004 | + }); | ||
| 2005 | + | ||
| 2006 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2007 | + | ||
| 2008 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 2009 | + | ||
| 2010 | + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); | ||
| 2011 | + } | ||
| 2012 | +} | ||
| 2013 | + | ||
| 2014 | +void MainTests::testExtendedAuthOneStepSucceed() | ||
| 2015 | +{ | ||
| 2016 | + ConfFileTemp confFile; | ||
| 2017 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 2018 | + confFile.closeFile(); | ||
| 2019 | + | ||
| 2020 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 2021 | + | ||
| 2022 | + cleanup(); | ||
| 2023 | + init(args); | ||
| 2024 | + | ||
| 2025 | + FlashMQTestClient client; | ||
| 2026 | + client.start(); | ||
| 2027 | + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { | ||
| 2028 | + connect.username = "me"; | ||
| 2029 | + connect.password = "me me"; | ||
| 2030 | + | ||
| 2031 | + connect.constructPropertyBuilder(); | ||
| 2032 | + | ||
| 2033 | + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data"); | ||
| 2034 | + connect.propertyBuilder->writeAuthenticationData("I have a proposal to put to ye."); | ||
| 2035 | + }); | ||
| 2036 | + | ||
| 2037 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2038 | + | ||
| 2039 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 2040 | + | ||
| 2041 | + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); | ||
| 2042 | + QVERIFY(connAckData.authData == "I have a proposal to put to ye."); | ||
| 2043 | +} | ||
| 2044 | + | ||
| 2045 | +void MainTests::testExtendedAuthOneStepDeny() | ||
| 2046 | +{ | ||
| 2047 | + ConfFileTemp confFile; | ||
| 2048 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 2049 | + confFile.closeFile(); | ||
| 2050 | + | ||
| 2051 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 2052 | + | ||
| 2053 | + cleanup(); | ||
| 2054 | + init(args); | ||
| 2055 | + | ||
| 2056 | + FlashMQTestClient client; | ||
| 2057 | + client.start(); | ||
| 2058 | + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { | ||
| 2059 | + connect.username = "me"; | ||
| 2060 | + connect.password = "me me"; | ||
| 2061 | + | ||
| 2062 | + connect.constructPropertyBuilder(); | ||
| 2063 | + | ||
| 2064 | + connect.propertyBuilder->writeAuthenticationMethod("always_fail"); | ||
| 2065 | + }); | ||
| 2066 | + | ||
| 2067 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2068 | + | ||
| 2069 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 2070 | + | ||
| 2071 | + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized); | ||
| 2072 | +} | ||
| 2073 | + | ||
| 2074 | +void MainTests::testExtendedAuthOneStepBadAuthMethod() | ||
| 2075 | +{ | ||
| 2076 | + ConfFileTemp confFile; | ||
| 2077 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 2078 | + confFile.closeFile(); | ||
| 2079 | + | ||
| 2080 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 2081 | + | ||
| 2082 | + cleanup(); | ||
| 2083 | + init(args); | ||
| 2084 | + | ||
| 2085 | + FlashMQTestClient client; | ||
| 2086 | + client.start(); | ||
| 2087 | + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { | ||
| 2088 | + connect.username = "me"; | ||
| 2089 | + connect.password = "me me"; | ||
| 2090 | + | ||
| 2091 | + connect.constructPropertyBuilder(); | ||
| 2092 | + | ||
| 2093 | + connect.propertyBuilder->writeAuthenticationMethod("doesnt_exist"); | ||
| 2094 | + }); | ||
| 2095 | + | ||
| 2096 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2097 | + | ||
| 2098 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 2099 | + | ||
| 2100 | + QVERIFY(connAckData.reasonCode == ReasonCodes::BadAuthenticationMethod); | ||
| 2101 | +} | ||
| 2102 | + | ||
| 2103 | +void MainTests::testExtendedAuthTwoStep() | ||
| 2104 | +{ | ||
| 2105 | + ConfFileTemp confFile; | ||
| 2106 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 2107 | + confFile.closeFile(); | ||
| 2108 | + | ||
| 2109 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 2110 | + | ||
| 2111 | + cleanup(); | ||
| 2112 | + init(args); | ||
| 2113 | + | ||
| 2114 | + FlashMQTestClient client; | ||
| 2115 | + client.start(); | ||
| 2116 | + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { | ||
| 2117 | + connect.username = "me"; | ||
| 2118 | + connect.password = "me me"; | ||
| 2119 | + | ||
| 2120 | + connect.constructPropertyBuilder(); | ||
| 2121 | + | ||
| 2122 | + connect.propertyBuilder->writeAuthenticationMethod("two_step"); | ||
| 2123 | + connect.propertyBuilder->writeAuthenticationData("Hello"); | ||
| 2124 | + }); | ||
| 2125 | + | ||
| 2126 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2127 | + | ||
| 2128 | + AuthPacketData authData = client.receivedPackets.front().parseAuthData(); | ||
| 2129 | + | ||
| 2130 | + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication); | ||
| 2131 | + QVERIFY(authData.data == "Hello back"); | ||
| 2132 | + | ||
| 2133 | + client.clearReceivedLists(); | ||
| 2134 | + | ||
| 2135 | + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!"); | ||
| 2136 | + client.writeAuth(auth); | ||
| 2137 | + | ||
| 2138 | + client.waitForConnack(); | ||
| 2139 | + | ||
| 2140 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2141 | + | ||
| 2142 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 2143 | + | ||
| 2144 | + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); | ||
| 2145 | + QVERIFY(connAckData.authData == "OK, if you insist."); | ||
| 2146 | +} | ||
| 2147 | + | ||
| 2148 | +void MainTests::testExtendedAuthTwoStepSecondStepFail() | ||
| 2149 | +{ | ||
| 2150 | + ConfFileTemp confFile; | ||
| 2151 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 2152 | + confFile.closeFile(); | ||
| 2153 | + | ||
| 2154 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 2155 | + | ||
| 2156 | + cleanup(); | ||
| 2157 | + init(args); | ||
| 2158 | + | ||
| 2159 | + FlashMQTestClient client; | ||
| 2160 | + client.start(); | ||
| 2161 | + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { | ||
| 2162 | + connect.username = "me"; | ||
| 2163 | + connect.password = "me me"; | ||
| 2164 | + | ||
| 2165 | + connect.constructPropertyBuilder(); | ||
| 2166 | + | ||
| 2167 | + connect.propertyBuilder->writeAuthenticationMethod("two_step"); | ||
| 2168 | + connect.propertyBuilder->writeAuthenticationData("Hello"); | ||
| 2169 | + }); | ||
| 2170 | + | ||
| 2171 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2172 | + | ||
| 2173 | + AuthPacketData authData = client.receivedPackets.front().parseAuthData(); | ||
| 2174 | + | ||
| 2175 | + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication); | ||
| 2176 | + QVERIFY(authData.data == "Hello back"); | ||
| 2177 | + | ||
| 2178 | + client.clearReceivedLists(); | ||
| 2179 | + | ||
| 2180 | + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "whoops, wrong data."); | ||
| 2181 | + client.writeAuth(auth); | ||
| 2182 | + | ||
| 2183 | + client.waitForConnack(); | ||
| 2184 | + | ||
| 2185 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2186 | + | ||
| 2187 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 2188 | + | ||
| 2189 | + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized); | ||
| 2190 | +} | ||
| 2191 | + | ||
| 2192 | +void MainTests::testExtendedReAuth() | ||
| 2193 | +{ | ||
| 2194 | + ConfFileTemp confFile; | ||
| 2195 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 2196 | + confFile.closeFile(); | ||
| 2197 | + | ||
| 2198 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 2199 | + | ||
| 2200 | + cleanup(); | ||
| 2201 | + init(args); | ||
| 2202 | + | ||
| 2203 | + FlashMQTestClient client; | ||
| 2204 | + client.start(); | ||
| 2205 | + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { | ||
| 2206 | + connect.username = "me"; | ||
| 2207 | + connect.password = "me me"; | ||
| 2208 | + | ||
| 2209 | + connect.constructPropertyBuilder(); | ||
| 2210 | + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data"); | ||
| 2211 | + connect.propertyBuilder->writeAuthenticationData("Santa Claus"); | ||
| 2212 | + }); | ||
| 2213 | + | ||
| 2214 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2215 | + | ||
| 2216 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 2217 | + | ||
| 2218 | + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); | ||
| 2219 | + | ||
| 2220 | + client.clearReceivedLists(); | ||
| 2221 | + | ||
| 2222 | + // Then reauth. | ||
| 2223 | + | ||
| 2224 | + Auth auth(ReasonCodes::ContinueAuthentication, "always_good_passing_back_the_auth_data", "Again Santa Claus"); | ||
| 2225 | + client.writeAuth(auth); | ||
| 2226 | + | ||
| 2227 | + client.waitForConnack(); | ||
| 2228 | + | ||
| 2229 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2230 | + | ||
| 2231 | + AuthPacketData authData = client.receivedPackets.front().parseAuthData(); | ||
| 2232 | + | ||
| 2233 | + QVERIFY(authData.reasonCode == ReasonCodes::Success); | ||
| 2234 | + QVERIFY(authData.data == "Again Santa Claus"); | ||
| 2235 | +} | ||
| 2236 | + | ||
| 2237 | +void MainTests::testExtendedReAuthTwoStep() | ||
| 2238 | +{ | ||
| 2239 | + ConfFileTemp confFile; | ||
| 2240 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 2241 | + confFile.closeFile(); | ||
| 2242 | + | ||
| 2243 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 2244 | + | ||
| 2245 | + cleanup(); | ||
| 2246 | + init(args); | ||
| 2247 | + | ||
| 2248 | + FlashMQTestClient client; | ||
| 2249 | + client.start(); | ||
| 2250 | + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { | ||
| 2251 | + connect.username = "me"; | ||
| 2252 | + connect.password = "me me"; | ||
| 2253 | + | ||
| 2254 | + connect.constructPropertyBuilder(); | ||
| 2255 | + | ||
| 2256 | + connect.propertyBuilder->writeAuthenticationMethod("two_step"); | ||
| 2257 | + connect.propertyBuilder->writeAuthenticationData("Hello"); | ||
| 2258 | + }); | ||
| 2259 | + | ||
| 2260 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2261 | + | ||
| 2262 | + AuthPacketData authData = client.receivedPackets.front().parseAuthData(); | ||
| 2263 | + | ||
| 2264 | + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication); | ||
| 2265 | + QVERIFY(authData.data == "Hello back"); | ||
| 2266 | + | ||
| 2267 | + client.clearReceivedLists(); | ||
| 2268 | + | ||
| 2269 | + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!"); | ||
| 2270 | + client.writeAuth(auth); | ||
| 2271 | + | ||
| 2272 | + client.waitForConnack(); | ||
| 2273 | + | ||
| 2274 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2275 | + | ||
| 2276 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 2277 | + | ||
| 2278 | + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); | ||
| 2279 | + QVERIFY(connAckData.authData == "OK, if you insist."); | ||
| 2280 | + | ||
| 2281 | + client.clearReceivedLists(); | ||
| 2282 | + | ||
| 2283 | + // Then reauth. | ||
| 2284 | + | ||
| 2285 | + const Auth reauth(ReasonCodes::ReAuthenticate, "two_step", "Hello"); | ||
| 2286 | + client.writeAuth(reauth); | ||
| 2287 | + client.waitForConnack(); | ||
| 2288 | + | ||
| 2289 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2290 | + | ||
| 2291 | + AuthPacketData reauthData = client.receivedPackets.front().parseAuthData(); | ||
| 2292 | + | ||
| 2293 | + QVERIFY(reauthData.reasonCode == ReasonCodes::ContinueAuthentication); | ||
| 2294 | + QVERIFY(reauthData.data == "Hello back"); | ||
| 2295 | + | ||
| 2296 | + client.clearReceivedLists(); | ||
| 2297 | + | ||
| 2298 | + const Auth reauthFinish(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!"); | ||
| 2299 | + client.writeAuth(reauthFinish); | ||
| 2300 | + | ||
| 2301 | + client.waitForConnack(); | ||
| 2302 | + | ||
| 2303 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2304 | + | ||
| 2305 | + AuthPacketData reauthFinishData = client.receivedPackets.front().parseAuthData(); | ||
| 2306 | + | ||
| 2307 | + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); | ||
| 2308 | + QVERIFY(connAckData.authData == "OK, if you insist."); | ||
| 2309 | +} | ||
| 2310 | + | ||
| 2311 | +void MainTests::testExtendedReAuthFail() | ||
| 2312 | +{ | ||
| 2313 | + ConfFileTemp confFile; | ||
| 2314 | + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1"); | ||
| 2315 | + confFile.closeFile(); | ||
| 2316 | + | ||
| 2317 | + std::vector<std::string> args {"--config-file", confFile.getFilePath()}; | ||
| 2318 | + | ||
| 2319 | + cleanup(); | ||
| 2320 | + init(args); | ||
| 2321 | + | ||
| 2322 | + FlashMQTestClient client; | ||
| 2323 | + client.start(); | ||
| 2324 | + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { | ||
| 2325 | + connect.username = "me"; | ||
| 2326 | + connect.password = "me me"; | ||
| 2327 | + | ||
| 2328 | + connect.constructPropertyBuilder(); | ||
| 2329 | + | ||
| 2330 | + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data"); | ||
| 2331 | + connect.propertyBuilder->writeAuthenticationData("I have a proposal to put to ye."); | ||
| 2332 | + }); | ||
| 2333 | + | ||
| 2334 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2335 | + | ||
| 2336 | + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData(); | ||
| 2337 | + | ||
| 2338 | + QVERIFY(connAckData.reasonCode == ReasonCodes::Success); | ||
| 2339 | + QVERIFY(connAckData.authData == "I have a proposal to put to ye."); | ||
| 2340 | + | ||
| 2341 | + client.clearReceivedLists(); | ||
| 2342 | + | ||
| 2343 | + // Then reauth. | ||
| 2344 | + | ||
| 2345 | + const Auth reauth(ReasonCodes::ReAuthenticate, "always_good_passing_back_the_auth_data", "actually not good."); | ||
| 2346 | + client.writeAuth(reauth); | ||
| 2347 | + client.waitForPacketCount(1); | ||
| 2348 | + | ||
| 2349 | + QVERIFY(client.receivedPackets.size() == 1); | ||
| 2350 | + QVERIFY(client.receivedPackets.front().packetType == PacketType::DISCONNECT); | ||
| 2351 | + | ||
| 2352 | + DisconnectData data = client.receivedPackets.front().parseDisconnectData(); | ||
| 2353 | + | ||
| 2354 | + QVERIFY(data.reasonCode == ReasonCodes::NotAuthorized); | ||
| 2355 | +} | ||
| 2356 | + | ||
| 1928 | int main(int argc, char *argv[]) | 2357 | int main(int argc, char *argv[]) |
| 1929 | { | 2358 | { |
| 1930 | QCoreApplication app(argc, argv); | 2359 | QCoreApplication app(argc, argv); |
authplugin.cpp
| @@ -356,7 +356,7 @@ AuthResult Authentication::unPwdCheck(const std::string &username, const std::st | @@ -356,7 +356,7 @@ AuthResult Authentication::unPwdCheck(const std::string &username, const std::st | ||
| 356 | { | 356 | { |
| 357 | AuthResult firstResult = unPwdCheckFromMosquittoPasswordFile(username, password); | 357 | AuthResult firstResult = unPwdCheckFromMosquittoPasswordFile(username, password); |
| 358 | 358 | ||
| 359 | - if (firstResult != AuthResult::success) | 359 | + if (firstResult == AuthResult::success) |
| 360 | return firstResult; | 360 | return firstResult; |
| 361 | 361 | ||
| 362 | if (pluginVersion == PluginVersion::None) | 362 | if (pluginVersion == PluginVersion::None) |
flashmqtestclient.cpp
| @@ -260,6 +260,12 @@ void FlashMQTestClient::publish(Publish &pub) | @@ -260,6 +260,12 @@ void FlashMQTestClient::publish(Publish &pub) | ||
| 260 | } | 260 | } |
| 261 | } | 261 | } |
| 262 | 262 | ||
| 263 | +void FlashMQTestClient::writeAuth(const Auth &auth) | ||
| 264 | +{ | ||
| 265 | + MqttPacket pack(auth); | ||
| 266 | + client->writeMqttPacketAndBlameThisClient(pack); | ||
| 267 | +} | ||
| 268 | + | ||
| 263 | void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos) | 269 | void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos) |
| 264 | { | 270 | { |
| 265 | Publish pub(topic, payload, qos); | 271 | Publish pub(topic, payload, qos); |
| @@ -276,7 +282,7 @@ void FlashMQTestClient::waitForConnack() | @@ -276,7 +282,7 @@ void FlashMQTestClient::waitForConnack() | ||
| 276 | { | 282 | { |
| 277 | waitForCondition([&]() { | 283 | waitForCondition([&]() { |
| 278 | return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) { | 284 | return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) { |
| 279 | - return p.packetType == PacketType::CONNACK; | 285 | + return p.packetType == PacketType::CONNACK || p.packetType == PacketType::AUTH; |
| 280 | }); | 286 | }); |
| 281 | }); | 287 | }); |
| 282 | } | 288 | } |
| @@ -287,3 +293,10 @@ void FlashMQTestClient::waitForMessageCount(const size_t count, int timeout) | @@ -287,3 +293,10 @@ void FlashMQTestClient::waitForMessageCount(const size_t count, int timeout) | ||
| 287 | return this->receivedPublishes.size() >= count; | 293 | return this->receivedPublishes.size() >= count; |
| 288 | }, timeout); | 294 | }, timeout); |
| 289 | } | 295 | } |
| 296 | + | ||
| 297 | +void FlashMQTestClient::waitForPacketCount(const size_t count, int timeout) | ||
| 298 | +{ | ||
| 299 | + waitForCondition([&]() { | ||
| 300 | + return this->receivedPackets.size() >= count; | ||
| 301 | + }, timeout); | ||
| 302 | +} |
flashmqtestclient.h
| @@ -40,6 +40,7 @@ public: | @@ -40,6 +40,7 @@ public: | ||
| 40 | void unsubscribe(const std::string &topic); | 40 | void unsubscribe(const std::string &topic); |
| 41 | void publish(const std::string &topic, const std::string &payload, char qos); | 41 | void publish(const std::string &topic, const std::string &payload, char qos); |
| 42 | void publish(Publish &pub); | 42 | void publish(Publish &pub); |
| 43 | + void writeAuth(const Auth &auth); | ||
| 43 | void clearReceivedLists(); | 44 | void clearReceivedLists(); |
| 44 | void setWill(std::shared_ptr<WillPublish> &will); | 45 | void setWill(std::shared_ptr<WillPublish> &will); |
| 45 | void disconnect(ReasonCodes reason); | 46 | void disconnect(ReasonCodes reason); |
| @@ -47,6 +48,7 @@ public: | @@ -47,6 +48,7 @@ public: | ||
| 47 | void waitForQuit(); | 48 | void waitForQuit(); |
| 48 | void waitForConnack(); | 49 | void waitForConnack(); |
| 49 | void waitForMessageCount(const size_t count, int timeout = 1); | 50 | void waitForMessageCount(const size_t count, int timeout = 1); |
| 51 | + void waitForPacketCount(const size_t count, int timeout = 1); | ||
| 50 | }; | 52 | }; |
| 51 | 53 | ||
| 52 | #endif // FLASHMQTESTCLIENT_H | 54 | #endif // FLASHMQTESTCLIENT_H |
mqttpacket.cpp
| @@ -253,6 +253,8 @@ MqttPacket::MqttPacket(const Connect &connect) : | @@ -253,6 +253,8 @@ MqttPacket::MqttPacket(const Connect &connect) : | ||
| 253 | writeByte(static_cast<char>(protocolVersion)); | 253 | writeByte(static_cast<char>(protocolVersion)); |
| 254 | 254 | ||
| 255 | uint8_t flags = connect.clean_start << 1; | 255 | uint8_t flags = connect.clean_start << 1; |
| 256 | + flags |= !connect.username.empty() << 7; | ||
| 257 | + flags |= !connect.password.empty() << 6; | ||
| 256 | 258 | ||
| 257 | if (connect.will) | 259 | if (connect.will) |
| 258 | { | 260 | { |
| @@ -284,6 +286,11 @@ MqttPacket::MqttPacket(const Connect &connect) : | @@ -284,6 +286,11 @@ MqttPacket::MqttPacket(const Connect &connect) : | ||
| 284 | writeString(connect.will->payload); | 286 | writeString(connect.will->payload); |
| 285 | } | 287 | } |
| 286 | 288 | ||
| 289 | + if (!connect.username.empty()) | ||
| 290 | + writeString(connect.username); | ||
| 291 | + if (!connect.password.empty()) | ||
| 292 | + writeString(connect.password); | ||
| 293 | + | ||
| 287 | calculateRemainingLength(); | 294 | calculateRemainingLength(); |
| 288 | } | 295 | } |
| 289 | 296 | ||
| @@ -634,6 +641,43 @@ ConnectData MqttPacket::parseConnectData() | @@ -634,6 +641,43 @@ ConnectData MqttPacket::parseConnectData() | ||
| 634 | return result; | 641 | return result; |
| 635 | } | 642 | } |
| 636 | 643 | ||
| 644 | +ConnAckData MqttPacket::parseConnAckData() | ||
| 645 | +{ | ||
| 646 | + if (this->packetType != PacketType::CONNACK) | ||
| 647 | + throw std::runtime_error("Packet must be connack packet."); | ||
| 648 | + | ||
| 649 | + setPosToDataStart(); | ||
| 650 | + | ||
| 651 | + ConnAckData result; | ||
| 652 | + | ||
| 653 | + const uint8_t flagByte = readByte(); | ||
| 654 | + | ||
| 655 | + result.sessionPresent = flagByte & 0x01; | ||
| 656 | + result.reasonCode = static_cast<ReasonCodes>(readUint8()); | ||
| 657 | + | ||
| 658 | + if (protocolVersion == ProtocolVersion::Mqtt5) | ||
| 659 | + { | ||
| 660 | + const size_t proplen = decodeVariableByteIntAtPos(); | ||
| 661 | + const size_t prop_end_at = pos + proplen; | ||
| 662 | + | ||
| 663 | + while (pos < prop_end_at) | ||
| 664 | + { | ||
| 665 | + const Mqtt5Properties prop = static_cast<Mqtt5Properties>(readByte()); | ||
| 666 | + | ||
| 667 | + switch (prop) | ||
| 668 | + { | ||
| 669 | + case Mqtt5Properties::AuthenticationData: | ||
| 670 | + result.authData = readBytesToString(); | ||
| 671 | + break; | ||
| 672 | + default: | ||
| 673 | + break; | ||
| 674 | + } | ||
| 675 | + } | ||
| 676 | + } | ||
| 677 | + | ||
| 678 | + return result; | ||
| 679 | +} | ||
| 680 | + | ||
| 637 | void MqttPacket::handleConnect() | 681 | void MqttPacket::handleConnect() |
| 638 | { | 682 | { |
| 639 | if (sender->hasConnectPacketSeen()) | 683 | if (sender->hasConnectPacketSeen()) |
| @@ -798,18 +842,22 @@ void MqttPacket::handleConnect() | @@ -798,18 +842,22 @@ void MqttPacket::handleConnect() | ||
| 798 | } | 842 | } |
| 799 | } | 843 | } |
| 800 | 844 | ||
| 801 | -void MqttPacket::handleExtendedAuth() | 845 | +AuthPacketData MqttPacket::parseAuthData() |
| 802 | { | 846 | { |
| 847 | + if (this->packetType != PacketType::AUTH) | ||
| 848 | + throw std::runtime_error("Packet must be an AUTH packet."); | ||
| 849 | + | ||
| 803 | if (first_byte & 0b1111) | 850 | if (first_byte & 0b1111) |
| 804 | throw ProtocolError("AUTH packet first 4 bits should be 0.", ReasonCodes::MalformedPacket); | 851 | throw ProtocolError("AUTH packet first 4 bits should be 0.", ReasonCodes::MalformedPacket); |
| 805 | 852 | ||
| 806 | - const ReasonCodes reasonCode = static_cast<ReasonCodes>(readByte()); | ||
| 807 | - | ||
| 808 | if (this->protocolVersion < ProtocolVersion::Mqtt5) | 853 | if (this->protocolVersion < ProtocolVersion::Mqtt5) |
| 809 | throw ProtocolError("AUTH packet needs MQTT5 or higher"); | 854 | throw ProtocolError("AUTH packet needs MQTT5 or higher"); |
| 810 | 855 | ||
| 811 | - std::string authMethod; | ||
| 812 | - std::string authData; | 856 | + setPosToDataStart(); |
| 857 | + | ||
| 858 | + AuthPacketData result; | ||
| 859 | + | ||
| 860 | + result.reasonCode = static_cast<ReasonCodes>(readUint8()); | ||
| 813 | 861 | ||
| 814 | if (!atEnd()) | 862 | if (!atEnd()) |
| 815 | { | 863 | { |
| @@ -823,10 +871,10 @@ void MqttPacket::handleExtendedAuth() | @@ -823,10 +871,10 @@ void MqttPacket::handleExtendedAuth() | ||
| 823 | switch (prop) | 871 | switch (prop) |
| 824 | { | 872 | { |
| 825 | case Mqtt5Properties::AuthenticationMethod: | 873 | case Mqtt5Properties::AuthenticationMethod: |
| 826 | - authMethod = readBytesToString(); | 874 | + result.method = readBytesToString(); |
| 827 | break; | 875 | break; |
| 828 | case Mqtt5Properties::AuthenticationData: | 876 | case Mqtt5Properties::AuthenticationData: |
| 829 | - authData = readBytesToString(false); | 877 | + result.data = readBytesToString(false); |
| 830 | break; | 878 | break; |
| 831 | case Mqtt5Properties::ReasonString: | 879 | case Mqtt5Properties::ReasonString: |
| 832 | readBytesToString(); | 880 | readBytesToString(); |
| @@ -840,12 +888,19 @@ void MqttPacket::handleExtendedAuth() | @@ -840,12 +888,19 @@ void MqttPacket::handleExtendedAuth() | ||
| 840 | } | 888 | } |
| 841 | } | 889 | } |
| 842 | 890 | ||
| 843 | - if (authMethod != sender->getExtendedAuthenticationMethod()) | 891 | + return result; |
| 892 | +} | ||
| 893 | + | ||
| 894 | +void MqttPacket::handleExtendedAuth() | ||
| 895 | +{ | ||
| 896 | + AuthPacketData data = parseAuthData(); | ||
| 897 | + | ||
| 898 | + if (data.method != sender->getExtendedAuthenticationMethod()) | ||
| 844 | throw ProtocolError("Client continued with another authentication method that it started with.", ReasonCodes::ProtocolError); | 899 | throw ProtocolError("Client continued with another authentication method that it started with.", ReasonCodes::ProtocolError); |
| 845 | 900 | ||
| 846 | ExtendedAuthStage authStage = ExtendedAuthStage::None; | 901 | ExtendedAuthStage authStage = ExtendedAuthStage::None; |
| 847 | 902 | ||
| 848 | - switch(reasonCode) | 903 | + switch(data.reasonCode) |
| 849 | { | 904 | { |
| 850 | case ReasonCodes::ContinueAuthentication: | 905 | case ReasonCodes::ContinueAuthentication: |
| 851 | authStage = ExtendedAuthStage::Continue; | 906 | authStage = ExtendedAuthStage::Continue; |
| @@ -854,7 +909,7 @@ void MqttPacket::handleExtendedAuth() | @@ -854,7 +909,7 @@ void MqttPacket::handleExtendedAuth() | ||
| 854 | authStage = ExtendedAuthStage::Reauth; | 909 | authStage = ExtendedAuthStage::Reauth; |
| 855 | break; | 910 | break; |
| 856 | default: | 911 | default: |
| 857 | - throw ProtocolError(formatString("Invalid reason code '%d' in auth packet", static_cast<uint8_t>(reasonCode)), ReasonCodes::MalformedPacket); | 912 | + throw ProtocolError(formatString("Invalid reason code '%d' in auth packet", static_cast<uint8_t>(data.reasonCode)), ReasonCodes::MalformedPacket); |
| 858 | } | 913 | } |
| 859 | 914 | ||
| 860 | if (authStage == ExtendedAuthStage::Reauth && !sender->getAuthenticated()) | 915 | if (authStage == ExtendedAuthStage::Reauth && !sender->getAuthenticated()) |
| @@ -865,28 +920,24 @@ void MqttPacket::handleExtendedAuth() | @@ -865,28 +920,24 @@ void MqttPacket::handleExtendedAuth() | ||
| 865 | Authentication &authentication = *ThreadGlobals::getAuth(); | 920 | Authentication &authentication = *ThreadGlobals::getAuth(); |
| 866 | 921 | ||
| 867 | std::string returnData; | 922 | std::string returnData; |
| 868 | - const AuthResult authResult = authentication.extendedAuth(sender->getClientId(), authStage, authMethod, authData, | 923 | + const AuthResult authResult = authentication.extendedAuth(sender->getClientId(), authStage, data.method, data.data, |
| 869 | getUserProperties(), returnData, sender->getMutableUsername()); | 924 | getUserProperties(), returnData, sender->getMutableUsername()); |
| 870 | 925 | ||
| 871 | if (authResult == AuthResult::auth_continue) | 926 | if (authResult == AuthResult::auth_continue) |
| 872 | { | 927 | { |
| 873 | - Auth auth(ReasonCodes::ContinueAuthentication, authMethod, returnData); | 928 | + Auth auth(ReasonCodes::ContinueAuthentication, data.method, returnData); |
| 874 | MqttPacket pack(auth); | 929 | MqttPacket pack(auth); |
| 875 | sender->writeMqttPacket(pack); | 930 | sender->writeMqttPacket(pack); |
| 876 | return; | 931 | return; |
| 877 | } | 932 | } |
| 878 | 933 | ||
| 879 | - if (authResult == AuthResult::success) | ||
| 880 | - { | ||
| 881 | - sender->addAuthReturnDataToStagedConnAck(returnData); | ||
| 882 | - } | ||
| 883 | - | ||
| 884 | const ReasonCodes finalResult = authResultToReasonCode(authResult); | 934 | const ReasonCodes finalResult = authResultToReasonCode(authResult); |
| 885 | 935 | ||
| 886 | if (!sender->getAuthenticated()) // First auth sends connack packets. | 936 | if (!sender->getAuthenticated()) // First auth sends connack packets. |
| 887 | { | 937 | { |
| 888 | if (finalResult == ReasonCodes::Success) | 938 | if (finalResult == ReasonCodes::Success) |
| 889 | { | 939 | { |
| 940 | + sender->addAuthReturnDataToStagedConnAck(returnData); | ||
| 890 | sender->sendConnackSuccess(); | 941 | sender->sendConnackSuccess(); |
| 891 | std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); | 942 | std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); |
| 892 | subscriptionStore->registerClientAndKickExistingOne(sender); | 943 | subscriptionStore->registerClientAndKickExistingOne(sender); |
| @@ -900,7 +951,7 @@ void MqttPacket::handleExtendedAuth() | @@ -900,7 +951,7 @@ void MqttPacket::handleExtendedAuth() | ||
| 900 | { | 951 | { |
| 901 | if (finalResult == ReasonCodes::Success) | 952 | if (finalResult == ReasonCodes::Success) |
| 902 | { | 953 | { |
| 903 | - Auth auth(ReasonCodes::Success, authMethod, returnData); | 954 | + Auth auth(ReasonCodes::Success, data.method, returnData); |
| 904 | MqttPacket authPack(auth); | 955 | MqttPacket authPack(auth); |
| 905 | sender->writeMqttPacket(authPack); | 956 | sender->writeMqttPacket(authPack); |
| 906 | logger->logf(LOG_NOTICE, "Client '%s', user '%s' reauthentication successful.", sender->getClientId().c_str(), sender->getUsername().c_str()); | 957 | logger->logf(LOG_NOTICE, "Client '%s', user '%s' reauthentication successful.", sender->getClientId().c_str(), sender->getUsername().c_str()); |
| @@ -933,7 +984,7 @@ DisconnectData MqttPacket::parseDisconnectData() | @@ -933,7 +984,7 @@ DisconnectData MqttPacket::parseDisconnectData() | ||
| 933 | { | 984 | { |
| 934 | if (!atEnd()) | 985 | if (!atEnd()) |
| 935 | { | 986 | { |
| 936 | - result.reasonCode = static_cast<ReasonCodes>(readByte()); | 987 | + result.reasonCode = static_cast<ReasonCodes>(readUint8()); |
| 937 | 988 | ||
| 938 | const size_t proplen = decodeVariableByteIntAtPos(); | 989 | const size_t proplen = decodeVariableByteIntAtPos(); |
| 939 | const size_t prop_end_at = pos + proplen; | 990 | const size_t prop_end_at = pos + proplen; |
| @@ -1621,6 +1672,12 @@ char MqttPacket::readByte() | @@ -1621,6 +1672,12 @@ char MqttPacket::readByte() | ||
| 1621 | return b; | 1672 | return b; |
| 1622 | } | 1673 | } |
| 1623 | 1674 | ||
| 1675 | +uint8_t MqttPacket::readUint8() | ||
| 1676 | +{ | ||
| 1677 | + char r = readByte(); | ||
| 1678 | + return static_cast<uint8_t>(r); | ||
| 1679 | +} | ||
| 1680 | + | ||
| 1624 | void MqttPacket::writeByte(char b) | 1681 | void MqttPacket::writeByte(char b) |
| 1625 | { | 1682 | { |
| 1626 | if (pos + 1 > bites.size()) | 1683 | if (pos + 1 > bites.size()) |
mqttpacket.h
| @@ -73,6 +73,7 @@ class MqttPacket | @@ -73,6 +73,7 @@ class MqttPacket | ||
| 73 | 73 | ||
| 74 | char *readBytes(size_t length); | 74 | char *readBytes(size_t length); |
| 75 | char readByte(); | 75 | char readByte(); |
| 76 | + uint8_t readUint8(); | ||
| 76 | void writeByte(char b); | 77 | void writeByte(char b); |
| 77 | void writeUint16(uint16_t x); | 78 | void writeUint16(uint16_t x); |
| 78 | void writeBytes(const char *b, size_t len); | 79 | void writeBytes(const char *b, size_t len); |
| @@ -121,7 +122,9 @@ public: | @@ -121,7 +122,9 @@ public: | ||
| 121 | static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); | 122 | static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); |
| 122 | 123 | ||
| 123 | void handle(); | 124 | void handle(); |
| 125 | + AuthPacketData parseAuthData(); | ||
| 124 | ConnectData parseConnectData(); | 126 | ConnectData parseConnectData(); |
| 127 | + ConnAckData parseConnAckData(); | ||
| 125 | void handleConnect(); | 128 | void handleConnect(); |
| 126 | void handleExtendedAuth(); | 129 | void handleExtendedAuth(); |
| 127 | DisconnectData parseDisconnectData(); | 130 | DisconnectData parseDisconnectData(); |
packetdatatypes.h
| @@ -38,6 +38,23 @@ struct ConnectData | @@ -38,6 +38,23 @@ struct ConnectData | ||
| 38 | ConnectData(); | 38 | ConnectData(); |
| 39 | }; | 39 | }; |
| 40 | 40 | ||
| 41 | +struct ConnAckData | ||
| 42 | +{ | ||
| 43 | + // Flags | ||
| 44 | + bool sessionPresent = false; | ||
| 45 | + | ||
| 46 | + ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result; | ||
| 47 | + | ||
| 48 | + std::string authData; | ||
| 49 | +}; | ||
| 50 | + | ||
| 51 | +struct AuthPacketData | ||
| 52 | +{ | ||
| 53 | + std::string method; | ||
| 54 | + std::string data; | ||
| 55 | + ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result; | ||
| 56 | +}; | ||
| 57 | + | ||
| 41 | struct DisconnectData | 58 | struct DisconnectData |
| 42 | { | 59 | { |
| 43 | ReasonCodes reasonCode = ReasonCodes::Success; | 60 | ReasonCodes reasonCode = ReasonCodes::Success; |
types.cpp
| @@ -397,6 +397,12 @@ size_t Connect::getLengthWithoutFixedHeader() const | @@ -397,6 +397,12 @@ size_t Connect::getLengthWithoutFixedHeader() const | ||
| 397 | result += will->payload.length() + 2; | 397 | result += will->payload.length() + 2; |
| 398 | } | 398 | } |
| 399 | 399 | ||
| 400 | + if (!username.empty()) | ||
| 401 | + result += username.size() + 2; | ||
| 402 | + | ||
| 403 | + if (!password.empty()) | ||
| 404 | + result += password.size() + 2; | ||
| 405 | + | ||
| 400 | return result; | 406 | return result; |
| 401 | } | 407 | } |
| 402 | 408 |