Commit 5f683b9c702ce2d2d85dfc7fcca722a8a560849d

Authored by Wiebe Cazemier
1 parent 5c47eca3

Test (plugin) auth system + fixes

When the plugin was enabled, it would never got past the fail of not
having a password file entry. This is fixed.

The reauth didn't work because setting the return data took the path of
first auth, hitting a protection exception. This is fixed.

Writing extended auth tests involved some refactoring, to create a
separate method to parse the auth packet data, because I needed it in
the test too.
FlashMQTests/FlashMQTests.pro
... ... @@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \
55 55 ../derivablecounter.cpp \
56 56 ../packetdatatypes.cpp \
57 57 ../flashmqtestclient.cpp \
  58 + conffiletemp.cpp \
58 59 mainappthread.cpp
59 60  
60 61  
... ... @@ -102,6 +103,7 @@ HEADERS += \
102 103 ../derivablecounter.h \
103 104 ../packetdatatypes.h \
104 105 ../flashmqtestclient.h \
  106 + conffiletemp.h \
105 107 mainappthread.h
106 108  
107 109 LIBS += -ldl -lssl -lcrypto
... ...
FlashMQTests/conffiletemp.cpp 0 → 100644
  1 +#include "conffiletemp.h"
  2 +
  3 +#include <vector>
  4 +#include "unistd.h"
  5 +#include <stdexcept>
  6 +
  7 +ConfFileTemp::ConfFileTemp()
  8 +{
  9 + const std::string templateName("/tmp/flashmqconf_XXXXXX");
  10 + std::vector<char> nameBuf(templateName.size() + 1, 0);
  11 + std::copy(templateName.begin(), templateName.end(), nameBuf.begin());
  12 + this->fd = mkstemp(nameBuf.data());
  13 +
  14 + if (this->fd < 0)
  15 + {
  16 + throw std::runtime_error("mkstemp error.");
  17 + }
  18 +
  19 + this->filePath = nameBuf.data();
  20 +}
  21 +
  22 +ConfFileTemp::~ConfFileTemp()
  23 +{
  24 + closeFile();
  25 +
  26 + if (!this->filePath.empty())
  27 + unlink(this->filePath.c_str());
  28 +}
  29 +
  30 +const std::string &ConfFileTemp::getFilePath() const
  31 +{
  32 + if (fd > 0)
  33 + throw std::runtime_error("You first need to close the file before using it.");
  34 +
  35 + return this->filePath;
  36 +}
  37 +
  38 +void ConfFileTemp::writeLine(const std::string &line)
  39 +{
  40 + write(this->fd, line.c_str(), line.size());
  41 + write(this->fd, "\n", 1);
  42 +}
  43 +
  44 +void ConfFileTemp::closeFile()
  45 +{
  46 + if (this->fd < 0)
  47 + return;
  48 +
  49 + close(this->fd);
  50 + this->fd = -1;
  51 +}
... ...
FlashMQTests/conffiletemp.h 0 → 100644
  1 +#ifndef CONFFILETEMP_H
  2 +#define CONFFILETEMP_H
  3 +
  4 +#include <string>
  5 +
  6 +class ConfFileTemp
  7 +{
  8 + int fd = -1;
  9 + std::string filePath;
  10 +
  11 +public:
  12 + ConfFileTemp();
  13 + ~ConfFileTemp();
  14 +
  15 + const std::string &getFilePath() const;
  16 + void writeLine(const std::string &line);
  17 + void closeFile();
  18 +};
  19 +
  20 +#endif // CONFFILETEMP_H
... ...
FlashMQTests/mainappthread.cpp
... ... @@ -24,6 +24,39 @@ MainAppThread::MainAppThread(QObject *parent) : QThread(parent)
24 24 appInstance->settings->allowAnonymous = true;
25 25 }
26 26  
  27 +MainAppThread::MainAppThread(const std::vector<std::string> &args, QObject *parent) : QThread(parent)
  28 +{
  29 + std::list<std::vector<char>> argCopies;
  30 +
  31 + const std::string programName = "FlashMQTests";
  32 + std::vector<char> programNameCopy(programName.size() + 1, 0);
  33 + std::copy(programName.begin(), programName.end(), programNameCopy.begin());
  34 + argCopies.push_back(std::move(programNameCopy));
  35 +
  36 + for (const std::string &arg : args)
  37 + {
  38 + std::vector<char> copyArg(arg.size() + 1, 0);
  39 + std::copy(arg.begin(), arg.end(), copyArg.begin());
  40 + argCopies.push_back(std::move(copyArg));
  41 + }
  42 +
  43 + char *argv[256];
  44 + memset(argv, 0, 256*sizeof (char*));
  45 +
  46 + int i = 0;
  47 + for (std::vector<char> &copy : argCopies)
  48 + {
  49 + argv[i++] = copy.data();
  50 + }
  51 +
  52 + MainApp::initMainApp(i, argv);
  53 + appInstance = MainApp::getMainApp();
  54 +
  55 + // A hack: when I supply args I probably define a config for auth stuff.
  56 + if (args.empty())
  57 + appInstance->settings->allowAnonymous = true;
  58 +}
  59 +
27 60 MainAppThread::~MainAppThread()
28 61 {
29 62 if (appInstance)
... ...
FlashMQTests/mainappthread.h
... ... @@ -28,6 +28,7 @@ class MainAppThread : public QThread
28 28 MainApp *appInstance = nullptr;
29 29 public:
30 30 explicit MainAppThread(QObject *parent = nullptr);
  31 + MainAppThread(const std::vector<std::string> &args, QObject *parent = nullptr);
31 32 ~MainAppThread();
32 33  
33 34 public slots:
... ...
FlashMQTests/plugins/test_plugin.cpp
... ... @@ -45,6 +45,9 @@ AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string
45 45 (void)password;
46 46 (void)userProperties;
47 47  
  48 + if (username == "failme")
  49 + return AuthResult::login_denied;
  50 +
48 51 return AuthResult::success;
49 52 }
50 53  
... ... @@ -72,6 +75,34 @@ AuthResult flashmq_extended_auth(void *thread_data, const std::string &amp;clientid,
72 75 (void)userProperties;
73 76 (void)returnData;
74 77  
75   - return AuthResult::success;
  78 + if (authMethod == "always_good_passing_back_the_auth_data")
  79 + {
  80 + if (authData == "actually not good.")
  81 + return AuthResult::login_denied;
  82 +
  83 + returnData = authData;
  84 + return AuthResult::success;
  85 + }
  86 + if (authMethod == "always_fail")
  87 + {
  88 + return AuthResult::login_denied;
  89 + }
  90 + if (authMethod == "two_step")
  91 + {
  92 + if (authData == "Hello")
  93 + returnData = "Hello back";
  94 +
  95 + if (authData == "grant me already!")
  96 + {
  97 + returnData = "OK, if you insist.";
  98 + return AuthResult::success;
  99 + }
  100 + else if (authData == "whoops, wrong data.")
  101 + return AuthResult::login_denied;
  102 + else
  103 + return AuthResult::auth_continue;
  104 + }
  105 +
  106 + return AuthResult::auth_method_not_supported;
76 107 }
77 108  
... ...
FlashMQTests/tst_maintests.cpp
... ... @@ -32,6 +32,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
32 32 #include "session.h"
33 33 #include "threaddata.h"
34 34 #include "threadglobals.h"
  35 +#include "conffiletemp.h"
  36 +#include "packetdatatypes.h"
35 37  
36 38 #include "flashmqtestclient.h"
37 39  
... ... @@ -65,6 +67,7 @@ public:
65 67 ~MainTests();
66 68  
67 69 private slots:
  70 + void init(const std::vector<std::string> &args);
68 71 void init(); // will be called before each test function is executed
69 72 void cleanup(); // will be called after every test function.
70 73  
... ... @@ -134,6 +137,18 @@ private slots:
134 137  
135 138 void testUserProperties();
136 139  
  140 + void testAuthFail();
  141 + void testAuthSucceed();
  142 +
  143 + void testExtendedAuthOneStepSucceed();
  144 + void testExtendedAuthOneStepDeny();
  145 + void testExtendedAuthOneStepBadAuthMethod();
  146 + void testExtendedAuthTwoStep();
  147 + void testExtendedAuthTwoStepSecondStepFail();
  148 + void testExtendedReAuth();
  149 + void testExtendedReAuthTwoStep();
  150 + void testExtendedReAuthFail();
  151 +
137 152 };
138 153  
139 154 MainTests::MainTests()
... ... @@ -146,10 +161,10 @@ MainTests::~MainTests()
146 161  
147 162 }
148 163  
149   -void MainTests::init()
  164 +void MainTests::init(const std::vector<std::string> &args)
150 165 {
151 166 mainApp.reset();
152   - mainApp.reset(new MainAppThread());
  167 + mainApp.reset(new MainAppThread(args));
153 168 mainApp->start();
154 169 mainApp->waitForStarted();
155 170  
... ... @@ -160,6 +175,12 @@ void MainTests::init()
160 175 ThreadGlobals::assignThreadData(dummyThreadData.get());
161 176 }
162 177  
  178 +void MainTests::init()
  179 +{
  180 + std::vector<std::string> args;
  181 + init(args);
  182 +}
  183 +
163 184 void MainTests::cleanup()
164 185 {
165 186 mainApp->stopApp();
... ... @@ -1925,6 +1946,414 @@ void MainTests::testUserProperties()
1925 1946 QVERIFY(properties3 == nullptr);
1926 1947 }
1927 1948  
  1949 +void MainTests::testAuthFail()
  1950 +{
  1951 + std::vector<ProtocolVersion> versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 };
  1952 +
  1953 + ConfFileTemp confFile;
  1954 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  1955 + confFile.closeFile();
  1956 +
  1957 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  1958 +
  1959 + cleanup();
  1960 + init(args);
  1961 +
  1962 + for (ProtocolVersion &version : versions)
  1963 + {
  1964 +
  1965 + FlashMQTestClient client;
  1966 + client.start();
  1967 + client.connectClient(version, false, 120, [](Connect &connect) {
  1968 + connect.username = "failme";
  1969 + connect.password = "boo";
  1970 + });
  1971 +
  1972 + QVERIFY(client.receivedPackets.size() == 1);
  1973 +
  1974 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  1975 +
  1976 + if (version >= ProtocolVersion::Mqtt5)
  1977 + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized);
  1978 + else
  1979 + QVERIFY(static_cast<uint8_t>(connAckData.reasonCode) == 5);
  1980 + }
  1981 +}
  1982 +
  1983 +void MainTests::testAuthSucceed()
  1984 +{
  1985 + std::vector<ProtocolVersion> versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 };
  1986 +
  1987 + ConfFileTemp confFile;
  1988 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  1989 + confFile.closeFile();
  1990 +
  1991 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  1992 +
  1993 + cleanup();
  1994 + init(args);
  1995 +
  1996 + for (ProtocolVersion &version : versions)
  1997 + {
  1998 +
  1999 + FlashMQTestClient client;
  2000 + client.start();
  2001 + client.connectClient(version, false, 120, [](Connect &connect) {
  2002 + connect.username = "passme";
  2003 + connect.password = "boo";
  2004 + });
  2005 +
  2006 + QVERIFY(client.receivedPackets.size() == 1);
  2007 +
  2008 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2009 +
  2010 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2011 + }
  2012 +}
  2013 +
  2014 +void MainTests::testExtendedAuthOneStepSucceed()
  2015 +{
  2016 + ConfFileTemp confFile;
  2017 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2018 + confFile.closeFile();
  2019 +
  2020 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2021 +
  2022 + cleanup();
  2023 + init(args);
  2024 +
  2025 + FlashMQTestClient client;
  2026 + client.start();
  2027 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2028 + connect.username = "me";
  2029 + connect.password = "me me";
  2030 +
  2031 + connect.constructPropertyBuilder();
  2032 +
  2033 + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data");
  2034 + connect.propertyBuilder->writeAuthenticationData("I have a proposal to put to ye.");
  2035 + });
  2036 +
  2037 + QVERIFY(client.receivedPackets.size() == 1);
  2038 +
  2039 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2040 +
  2041 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2042 + QVERIFY(connAckData.authData == "I have a proposal to put to ye.");
  2043 +}
  2044 +
  2045 +void MainTests::testExtendedAuthOneStepDeny()
  2046 +{
  2047 + ConfFileTemp confFile;
  2048 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2049 + confFile.closeFile();
  2050 +
  2051 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2052 +
  2053 + cleanup();
  2054 + init(args);
  2055 +
  2056 + FlashMQTestClient client;
  2057 + client.start();
  2058 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2059 + connect.username = "me";
  2060 + connect.password = "me me";
  2061 +
  2062 + connect.constructPropertyBuilder();
  2063 +
  2064 + connect.propertyBuilder->writeAuthenticationMethod("always_fail");
  2065 + });
  2066 +
  2067 + QVERIFY(client.receivedPackets.size() == 1);
  2068 +
  2069 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2070 +
  2071 + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized);
  2072 +}
  2073 +
  2074 +void MainTests::testExtendedAuthOneStepBadAuthMethod()
  2075 +{
  2076 + ConfFileTemp confFile;
  2077 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2078 + confFile.closeFile();
  2079 +
  2080 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2081 +
  2082 + cleanup();
  2083 + init(args);
  2084 +
  2085 + FlashMQTestClient client;
  2086 + client.start();
  2087 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2088 + connect.username = "me";
  2089 + connect.password = "me me";
  2090 +
  2091 + connect.constructPropertyBuilder();
  2092 +
  2093 + connect.propertyBuilder->writeAuthenticationMethod("doesnt_exist");
  2094 + });
  2095 +
  2096 + QVERIFY(client.receivedPackets.size() == 1);
  2097 +
  2098 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2099 +
  2100 + QVERIFY(connAckData.reasonCode == ReasonCodes::BadAuthenticationMethod);
  2101 +}
  2102 +
  2103 +void MainTests::testExtendedAuthTwoStep()
  2104 +{
  2105 + ConfFileTemp confFile;
  2106 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2107 + confFile.closeFile();
  2108 +
  2109 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2110 +
  2111 + cleanup();
  2112 + init(args);
  2113 +
  2114 + FlashMQTestClient client;
  2115 + client.start();
  2116 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2117 + connect.username = "me";
  2118 + connect.password = "me me";
  2119 +
  2120 + connect.constructPropertyBuilder();
  2121 +
  2122 + connect.propertyBuilder->writeAuthenticationMethod("two_step");
  2123 + connect.propertyBuilder->writeAuthenticationData("Hello");
  2124 + });
  2125 +
  2126 + QVERIFY(client.receivedPackets.size() == 1);
  2127 +
  2128 + AuthPacketData authData = client.receivedPackets.front().parseAuthData();
  2129 +
  2130 + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication);
  2131 + QVERIFY(authData.data == "Hello back");
  2132 +
  2133 + client.clearReceivedLists();
  2134 +
  2135 + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!");
  2136 + client.writeAuth(auth);
  2137 +
  2138 + client.waitForConnack();
  2139 +
  2140 + QVERIFY(client.receivedPackets.size() == 1);
  2141 +
  2142 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2143 +
  2144 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2145 + QVERIFY(connAckData.authData == "OK, if you insist.");
  2146 +}
  2147 +
  2148 +void MainTests::testExtendedAuthTwoStepSecondStepFail()
  2149 +{
  2150 + ConfFileTemp confFile;
  2151 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2152 + confFile.closeFile();
  2153 +
  2154 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2155 +
  2156 + cleanup();
  2157 + init(args);
  2158 +
  2159 + FlashMQTestClient client;
  2160 + client.start();
  2161 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2162 + connect.username = "me";
  2163 + connect.password = "me me";
  2164 +
  2165 + connect.constructPropertyBuilder();
  2166 +
  2167 + connect.propertyBuilder->writeAuthenticationMethod("two_step");
  2168 + connect.propertyBuilder->writeAuthenticationData("Hello");
  2169 + });
  2170 +
  2171 + QVERIFY(client.receivedPackets.size() == 1);
  2172 +
  2173 + AuthPacketData authData = client.receivedPackets.front().parseAuthData();
  2174 +
  2175 + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication);
  2176 + QVERIFY(authData.data == "Hello back");
  2177 +
  2178 + client.clearReceivedLists();
  2179 +
  2180 + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "whoops, wrong data.");
  2181 + client.writeAuth(auth);
  2182 +
  2183 + client.waitForConnack();
  2184 +
  2185 + QVERIFY(client.receivedPackets.size() == 1);
  2186 +
  2187 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2188 +
  2189 + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized);
  2190 +}
  2191 +
  2192 +void MainTests::testExtendedReAuth()
  2193 +{
  2194 + ConfFileTemp confFile;
  2195 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2196 + confFile.closeFile();
  2197 +
  2198 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2199 +
  2200 + cleanup();
  2201 + init(args);
  2202 +
  2203 + FlashMQTestClient client;
  2204 + client.start();
  2205 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2206 + connect.username = "me";
  2207 + connect.password = "me me";
  2208 +
  2209 + connect.constructPropertyBuilder();
  2210 + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data");
  2211 + connect.propertyBuilder->writeAuthenticationData("Santa Claus");
  2212 + });
  2213 +
  2214 + QVERIFY(client.receivedPackets.size() == 1);
  2215 +
  2216 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2217 +
  2218 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2219 +
  2220 + client.clearReceivedLists();
  2221 +
  2222 + // Then reauth.
  2223 +
  2224 + Auth auth(ReasonCodes::ContinueAuthentication, "always_good_passing_back_the_auth_data", "Again Santa Claus");
  2225 + client.writeAuth(auth);
  2226 +
  2227 + client.waitForConnack();
  2228 +
  2229 + QVERIFY(client.receivedPackets.size() == 1);
  2230 +
  2231 + AuthPacketData authData = client.receivedPackets.front().parseAuthData();
  2232 +
  2233 + QVERIFY(authData.reasonCode == ReasonCodes::Success);
  2234 + QVERIFY(authData.data == "Again Santa Claus");
  2235 +}
  2236 +
  2237 +void MainTests::testExtendedReAuthTwoStep()
  2238 +{
  2239 + ConfFileTemp confFile;
  2240 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2241 + confFile.closeFile();
  2242 +
  2243 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2244 +
  2245 + cleanup();
  2246 + init(args);
  2247 +
  2248 + FlashMQTestClient client;
  2249 + client.start();
  2250 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2251 + connect.username = "me";
  2252 + connect.password = "me me";
  2253 +
  2254 + connect.constructPropertyBuilder();
  2255 +
  2256 + connect.propertyBuilder->writeAuthenticationMethod("two_step");
  2257 + connect.propertyBuilder->writeAuthenticationData("Hello");
  2258 + });
  2259 +
  2260 + QVERIFY(client.receivedPackets.size() == 1);
  2261 +
  2262 + AuthPacketData authData = client.receivedPackets.front().parseAuthData();
  2263 +
  2264 + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication);
  2265 + QVERIFY(authData.data == "Hello back");
  2266 +
  2267 + client.clearReceivedLists();
  2268 +
  2269 + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!");
  2270 + client.writeAuth(auth);
  2271 +
  2272 + client.waitForConnack();
  2273 +
  2274 + QVERIFY(client.receivedPackets.size() == 1);
  2275 +
  2276 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2277 +
  2278 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2279 + QVERIFY(connAckData.authData == "OK, if you insist.");
  2280 +
  2281 + client.clearReceivedLists();
  2282 +
  2283 + // Then reauth.
  2284 +
  2285 + const Auth reauth(ReasonCodes::ReAuthenticate, "two_step", "Hello");
  2286 + client.writeAuth(reauth);
  2287 + client.waitForConnack();
  2288 +
  2289 + QVERIFY(client.receivedPackets.size() == 1);
  2290 +
  2291 + AuthPacketData reauthData = client.receivedPackets.front().parseAuthData();
  2292 +
  2293 + QVERIFY(reauthData.reasonCode == ReasonCodes::ContinueAuthentication);
  2294 + QVERIFY(reauthData.data == "Hello back");
  2295 +
  2296 + client.clearReceivedLists();
  2297 +
  2298 + const Auth reauthFinish(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!");
  2299 + client.writeAuth(reauthFinish);
  2300 +
  2301 + client.waitForConnack();
  2302 +
  2303 + QVERIFY(client.receivedPackets.size() == 1);
  2304 +
  2305 + AuthPacketData reauthFinishData = client.receivedPackets.front().parseAuthData();
  2306 +
  2307 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2308 + QVERIFY(connAckData.authData == "OK, if you insist.");
  2309 +}
  2310 +
  2311 +void MainTests::testExtendedReAuthFail()
  2312 +{
  2313 + ConfFileTemp confFile;
  2314 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2315 + confFile.closeFile();
  2316 +
  2317 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2318 +
  2319 + cleanup();
  2320 + init(args);
  2321 +
  2322 + FlashMQTestClient client;
  2323 + client.start();
  2324 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2325 + connect.username = "me";
  2326 + connect.password = "me me";
  2327 +
  2328 + connect.constructPropertyBuilder();
  2329 +
  2330 + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data");
  2331 + connect.propertyBuilder->writeAuthenticationData("I have a proposal to put to ye.");
  2332 + });
  2333 +
  2334 + QVERIFY(client.receivedPackets.size() == 1);
  2335 +
  2336 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2337 +
  2338 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2339 + QVERIFY(connAckData.authData == "I have a proposal to put to ye.");
  2340 +
  2341 + client.clearReceivedLists();
  2342 +
  2343 + // Then reauth.
  2344 +
  2345 + const Auth reauth(ReasonCodes::ReAuthenticate, "always_good_passing_back_the_auth_data", "actually not good.");
  2346 + client.writeAuth(reauth);
  2347 + client.waitForPacketCount(1);
  2348 +
  2349 + QVERIFY(client.receivedPackets.size() == 1);
  2350 + QVERIFY(client.receivedPackets.front().packetType == PacketType::DISCONNECT);
  2351 +
  2352 + DisconnectData data = client.receivedPackets.front().parseDisconnectData();
  2353 +
  2354 + QVERIFY(data.reasonCode == ReasonCodes::NotAuthorized);
  2355 +}
  2356 +
1928 2357 int main(int argc, char *argv[])
1929 2358 {
1930 2359 QCoreApplication app(argc, argv);
... ...
authplugin.cpp
... ... @@ -356,7 +356,7 @@ AuthResult Authentication::unPwdCheck(const std::string &amp;username, const std::st
356 356 {
357 357 AuthResult firstResult = unPwdCheckFromMosquittoPasswordFile(username, password);
358 358  
359   - if (firstResult != AuthResult::success)
  359 + if (firstResult == AuthResult::success)
360 360 return firstResult;
361 361  
362 362 if (pluginVersion == PluginVersion::None)
... ...
flashmqtestclient.cpp
... ... @@ -260,6 +260,12 @@ void FlashMQTestClient::publish(Publish &amp;pub)
260 260 }
261 261 }
262 262  
  263 +void FlashMQTestClient::writeAuth(const Auth &auth)
  264 +{
  265 + MqttPacket pack(auth);
  266 + client->writeMqttPacketAndBlameThisClient(pack);
  267 +}
  268 +
263 269 void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos)
264 270 {
265 271 Publish pub(topic, payload, qos);
... ... @@ -276,7 +282,7 @@ void FlashMQTestClient::waitForConnack()
276 282 {
277 283 waitForCondition([&]() {
278 284 return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) {
279   - return p.packetType == PacketType::CONNACK;
  285 + return p.packetType == PacketType::CONNACK || p.packetType == PacketType::AUTH;
280 286 });
281 287 });
282 288 }
... ... @@ -287,3 +293,10 @@ void FlashMQTestClient::waitForMessageCount(const size_t count, int timeout)
287 293 return this->receivedPublishes.size() >= count;
288 294 }, timeout);
289 295 }
  296 +
  297 +void FlashMQTestClient::waitForPacketCount(const size_t count, int timeout)
  298 +{
  299 + waitForCondition([&]() {
  300 + return this->receivedPackets.size() >= count;
  301 + }, timeout);
  302 +}
... ...
flashmqtestclient.h
... ... @@ -40,6 +40,7 @@ public:
40 40 void unsubscribe(const std::string &topic);
41 41 void publish(const std::string &topic, const std::string &payload, char qos);
42 42 void publish(Publish &pub);
  43 + void writeAuth(const Auth &auth);
43 44 void clearReceivedLists();
44 45 void setWill(std::shared_ptr<WillPublish> &will);
45 46 void disconnect(ReasonCodes reason);
... ... @@ -47,6 +48,7 @@ public:
47 48 void waitForQuit();
48 49 void waitForConnack();
49 50 void waitForMessageCount(const size_t count, int timeout = 1);
  51 + void waitForPacketCount(const size_t count, int timeout = 1);
50 52 };
51 53  
52 54 #endif // FLASHMQTESTCLIENT_H
... ...
mqttpacket.cpp
... ... @@ -253,6 +253,8 @@ MqttPacket::MqttPacket(const Connect &amp;connect) :
253 253 writeByte(static_cast<char>(protocolVersion));
254 254  
255 255 uint8_t flags = connect.clean_start << 1;
  256 + flags |= !connect.username.empty() << 7;
  257 + flags |= !connect.password.empty() << 6;
256 258  
257 259 if (connect.will)
258 260 {
... ... @@ -284,6 +286,11 @@ MqttPacket::MqttPacket(const Connect &amp;connect) :
284 286 writeString(connect.will->payload);
285 287 }
286 288  
  289 + if (!connect.username.empty())
  290 + writeString(connect.username);
  291 + if (!connect.password.empty())
  292 + writeString(connect.password);
  293 +
287 294 calculateRemainingLength();
288 295 }
289 296  
... ... @@ -634,6 +641,43 @@ ConnectData MqttPacket::parseConnectData()
634 641 return result;
635 642 }
636 643  
  644 +ConnAckData MqttPacket::parseConnAckData()
  645 +{
  646 + if (this->packetType != PacketType::CONNACK)
  647 + throw std::runtime_error("Packet must be connack packet.");
  648 +
  649 + setPosToDataStart();
  650 +
  651 + ConnAckData result;
  652 +
  653 + const uint8_t flagByte = readByte();
  654 +
  655 + result.sessionPresent = flagByte & 0x01;
  656 + result.reasonCode = static_cast<ReasonCodes>(readUint8());
  657 +
  658 + if (protocolVersion == ProtocolVersion::Mqtt5)
  659 + {
  660 + const size_t proplen = decodeVariableByteIntAtPos();
  661 + const size_t prop_end_at = pos + proplen;
  662 +
  663 + while (pos < prop_end_at)
  664 + {
  665 + const Mqtt5Properties prop = static_cast<Mqtt5Properties>(readByte());
  666 +
  667 + switch (prop)
  668 + {
  669 + case Mqtt5Properties::AuthenticationData:
  670 + result.authData = readBytesToString();
  671 + break;
  672 + default:
  673 + break;
  674 + }
  675 + }
  676 + }
  677 +
  678 + return result;
  679 +}
  680 +
637 681 void MqttPacket::handleConnect()
638 682 {
639 683 if (sender->hasConnectPacketSeen())
... ... @@ -798,18 +842,22 @@ void MqttPacket::handleConnect()
798 842 }
799 843 }
800 844  
801   -void MqttPacket::handleExtendedAuth()
  845 +AuthPacketData MqttPacket::parseAuthData()
802 846 {
  847 + if (this->packetType != PacketType::AUTH)
  848 + throw std::runtime_error("Packet must be an AUTH packet.");
  849 +
803 850 if (first_byte & 0b1111)
804 851 throw ProtocolError("AUTH packet first 4 bits should be 0.", ReasonCodes::MalformedPacket);
805 852  
806   - const ReasonCodes reasonCode = static_cast<ReasonCodes>(readByte());
807   -
808 853 if (this->protocolVersion < ProtocolVersion::Mqtt5)
809 854 throw ProtocolError("AUTH packet needs MQTT5 or higher");
810 855  
811   - std::string authMethod;
812   - std::string authData;
  856 + setPosToDataStart();
  857 +
  858 + AuthPacketData result;
  859 +
  860 + result.reasonCode = static_cast<ReasonCodes>(readUint8());
813 861  
814 862 if (!atEnd())
815 863 {
... ... @@ -823,10 +871,10 @@ void MqttPacket::handleExtendedAuth()
823 871 switch (prop)
824 872 {
825 873 case Mqtt5Properties::AuthenticationMethod:
826   - authMethod = readBytesToString();
  874 + result.method = readBytesToString();
827 875 break;
828 876 case Mqtt5Properties::AuthenticationData:
829   - authData = readBytesToString(false);
  877 + result.data = readBytesToString(false);
830 878 break;
831 879 case Mqtt5Properties::ReasonString:
832 880 readBytesToString();
... ... @@ -840,12 +888,19 @@ void MqttPacket::handleExtendedAuth()
840 888 }
841 889 }
842 890  
843   - if (authMethod != sender->getExtendedAuthenticationMethod())
  891 + return result;
  892 +}
  893 +
  894 +void MqttPacket::handleExtendedAuth()
  895 +{
  896 + AuthPacketData data = parseAuthData();
  897 +
  898 + if (data.method != sender->getExtendedAuthenticationMethod())
844 899 throw ProtocolError("Client continued with another authentication method that it started with.", ReasonCodes::ProtocolError);
845 900  
846 901 ExtendedAuthStage authStage = ExtendedAuthStage::None;
847 902  
848   - switch(reasonCode)
  903 + switch(data.reasonCode)
849 904 {
850 905 case ReasonCodes::ContinueAuthentication:
851 906 authStage = ExtendedAuthStage::Continue;
... ... @@ -854,7 +909,7 @@ void MqttPacket::handleExtendedAuth()
854 909 authStage = ExtendedAuthStage::Reauth;
855 910 break;
856 911 default:
857   - throw ProtocolError(formatString("Invalid reason code '%d' in auth packet", static_cast<uint8_t>(reasonCode)), ReasonCodes::MalformedPacket);
  912 + throw ProtocolError(formatString("Invalid reason code '%d' in auth packet", static_cast<uint8_t>(data.reasonCode)), ReasonCodes::MalformedPacket);
858 913 }
859 914  
860 915 if (authStage == ExtendedAuthStage::Reauth && !sender->getAuthenticated())
... ... @@ -865,28 +920,24 @@ void MqttPacket::handleExtendedAuth()
865 920 Authentication &authentication = *ThreadGlobals::getAuth();
866 921  
867 922 std::string returnData;
868   - const AuthResult authResult = authentication.extendedAuth(sender->getClientId(), authStage, authMethod, authData,
  923 + const AuthResult authResult = authentication.extendedAuth(sender->getClientId(), authStage, data.method, data.data,
869 924 getUserProperties(), returnData, sender->getMutableUsername());
870 925  
871 926 if (authResult == AuthResult::auth_continue)
872 927 {
873   - Auth auth(ReasonCodes::ContinueAuthentication, authMethod, returnData);
  928 + Auth auth(ReasonCodes::ContinueAuthentication, data.method, returnData);
874 929 MqttPacket pack(auth);
875 930 sender->writeMqttPacket(pack);
876 931 return;
877 932 }
878 933  
879   - if (authResult == AuthResult::success)
880   - {
881   - sender->addAuthReturnDataToStagedConnAck(returnData);
882   - }
883   -
884 934 const ReasonCodes finalResult = authResultToReasonCode(authResult);
885 935  
886 936 if (!sender->getAuthenticated()) // First auth sends connack packets.
887 937 {
888 938 if (finalResult == ReasonCodes::Success)
889 939 {
  940 + sender->addAuthReturnDataToStagedConnAck(returnData);
890 941 sender->sendConnackSuccess();
891 942 std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore();
892 943 subscriptionStore->registerClientAndKickExistingOne(sender);
... ... @@ -900,7 +951,7 @@ void MqttPacket::handleExtendedAuth()
900 951 {
901 952 if (finalResult == ReasonCodes::Success)
902 953 {
903   - Auth auth(ReasonCodes::Success, authMethod, returnData);
  954 + Auth auth(ReasonCodes::Success, data.method, returnData);
904 955 MqttPacket authPack(auth);
905 956 sender->writeMqttPacket(authPack);
906 957 logger->logf(LOG_NOTICE, "Client '%s', user '%s' reauthentication successful.", sender->getClientId().c_str(), sender->getUsername().c_str());
... ... @@ -933,7 +984,7 @@ DisconnectData MqttPacket::parseDisconnectData()
933 984 {
934 985 if (!atEnd())
935 986 {
936   - result.reasonCode = static_cast<ReasonCodes>(readByte());
  987 + result.reasonCode = static_cast<ReasonCodes>(readUint8());
937 988  
938 989 const size_t proplen = decodeVariableByteIntAtPos();
939 990 const size_t prop_end_at = pos + proplen;
... ... @@ -1621,6 +1672,12 @@ char MqttPacket::readByte()
1621 1672 return b;
1622 1673 }
1623 1674  
  1675 +uint8_t MqttPacket::readUint8()
  1676 +{
  1677 + char r = readByte();
  1678 + return static_cast<uint8_t>(r);
  1679 +}
  1680 +
1624 1681 void MqttPacket::writeByte(char b)
1625 1682 {
1626 1683 if (pos + 1 > bites.size())
... ...
mqttpacket.h
... ... @@ -73,6 +73,7 @@ class MqttPacket
73 73  
74 74 char *readBytes(size_t length);
75 75 char readByte();
  76 + uint8_t readUint8();
76 77 void writeByte(char b);
77 78 void writeUint16(uint16_t x);
78 79 void writeBytes(const char *b, size_t len);
... ... @@ -121,7 +122,9 @@ public:
121 122 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
122 123  
123 124 void handle();
  125 + AuthPacketData parseAuthData();
124 126 ConnectData parseConnectData();
  127 + ConnAckData parseConnAckData();
125 128 void handleConnect();
126 129 void handleExtendedAuth();
127 130 DisconnectData parseDisconnectData();
... ...
packetdatatypes.h
... ... @@ -38,6 +38,23 @@ struct ConnectData
38 38 ConnectData();
39 39 };
40 40  
  41 +struct ConnAckData
  42 +{
  43 + // Flags
  44 + bool sessionPresent = false;
  45 +
  46 + ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result;
  47 +
  48 + std::string authData;
  49 +};
  50 +
  51 +struct AuthPacketData
  52 +{
  53 + std::string method;
  54 + std::string data;
  55 + ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result;
  56 +};
  57 +
41 58 struct DisconnectData
42 59 {
43 60 ReasonCodes reasonCode = ReasonCodes::Success;
... ...
types.cpp
... ... @@ -397,6 +397,12 @@ size_t Connect::getLengthWithoutFixedHeader() const
397 397 result += will->payload.length() + 2;
398 398 }
399 399  
  400 + if (!username.empty())
  401 + result += username.size() + 2;
  402 +
  403 + if (!password.empty())
  404 + result += password.size() + 2;
  405 +
400 406 return result;
401 407 }
402 408  
... ...