Commit 5f683b9c702ce2d2d85dfc7fcca722a8a560849d

Authored by Wiebe Cazemier
1 parent 5c47eca3

Test (plugin) auth system + fixes

When the plugin was enabled, it would never got past the fail of not
having a password file entry. This is fixed.

The reauth didn't work because setting the return data took the path of
first auth, hitting a protection exception. This is fixed.

Writing extended auth tests involved some refactoring, to create a
separate method to parse the auth packet data, because I needed it in
the test too.
FlashMQTests/FlashMQTests.pro
@@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \ @@ -55,6 +55,7 @@ SOURCES += tst_maintests.cpp \
55 ../derivablecounter.cpp \ 55 ../derivablecounter.cpp \
56 ../packetdatatypes.cpp \ 56 ../packetdatatypes.cpp \
57 ../flashmqtestclient.cpp \ 57 ../flashmqtestclient.cpp \
  58 + conffiletemp.cpp \
58 mainappthread.cpp 59 mainappthread.cpp
59 60
60 61
@@ -102,6 +103,7 @@ HEADERS += \ @@ -102,6 +103,7 @@ HEADERS += \
102 ../derivablecounter.h \ 103 ../derivablecounter.h \
103 ../packetdatatypes.h \ 104 ../packetdatatypes.h \
104 ../flashmqtestclient.h \ 105 ../flashmqtestclient.h \
  106 + conffiletemp.h \
105 mainappthread.h 107 mainappthread.h
106 108
107 LIBS += -ldl -lssl -lcrypto 109 LIBS += -ldl -lssl -lcrypto
FlashMQTests/conffiletemp.cpp 0 → 100644
  1 +#include "conffiletemp.h"
  2 +
  3 +#include <vector>
  4 +#include "unistd.h"
  5 +#include <stdexcept>
  6 +
  7 +ConfFileTemp::ConfFileTemp()
  8 +{
  9 + const std::string templateName("/tmp/flashmqconf_XXXXXX");
  10 + std::vector<char> nameBuf(templateName.size() + 1, 0);
  11 + std::copy(templateName.begin(), templateName.end(), nameBuf.begin());
  12 + this->fd = mkstemp(nameBuf.data());
  13 +
  14 + if (this->fd < 0)
  15 + {
  16 + throw std::runtime_error("mkstemp error.");
  17 + }
  18 +
  19 + this->filePath = nameBuf.data();
  20 +}
  21 +
  22 +ConfFileTemp::~ConfFileTemp()
  23 +{
  24 + closeFile();
  25 +
  26 + if (!this->filePath.empty())
  27 + unlink(this->filePath.c_str());
  28 +}
  29 +
  30 +const std::string &ConfFileTemp::getFilePath() const
  31 +{
  32 + if (fd > 0)
  33 + throw std::runtime_error("You first need to close the file before using it.");
  34 +
  35 + return this->filePath;
  36 +}
  37 +
  38 +void ConfFileTemp::writeLine(const std::string &line)
  39 +{
  40 + write(this->fd, line.c_str(), line.size());
  41 + write(this->fd, "\n", 1);
  42 +}
  43 +
  44 +void ConfFileTemp::closeFile()
  45 +{
  46 + if (this->fd < 0)
  47 + return;
  48 +
  49 + close(this->fd);
  50 + this->fd = -1;
  51 +}
FlashMQTests/conffiletemp.h 0 → 100644
  1 +#ifndef CONFFILETEMP_H
  2 +#define CONFFILETEMP_H
  3 +
  4 +#include <string>
  5 +
  6 +class ConfFileTemp
  7 +{
  8 + int fd = -1;
  9 + std::string filePath;
  10 +
  11 +public:
  12 + ConfFileTemp();
  13 + ~ConfFileTemp();
  14 +
  15 + const std::string &getFilePath() const;
  16 + void writeLine(const std::string &line);
  17 + void closeFile();
  18 +};
  19 +
  20 +#endif // CONFFILETEMP_H
FlashMQTests/mainappthread.cpp
@@ -24,6 +24,39 @@ MainAppThread::MainAppThread(QObject *parent) : QThread(parent) @@ -24,6 +24,39 @@ MainAppThread::MainAppThread(QObject *parent) : QThread(parent)
24 appInstance->settings->allowAnonymous = true; 24 appInstance->settings->allowAnonymous = true;
25 } 25 }
26 26
  27 +MainAppThread::MainAppThread(const std::vector<std::string> &args, QObject *parent) : QThread(parent)
  28 +{
  29 + std::list<std::vector<char>> argCopies;
  30 +
  31 + const std::string programName = "FlashMQTests";
  32 + std::vector<char> programNameCopy(programName.size() + 1, 0);
  33 + std::copy(programName.begin(), programName.end(), programNameCopy.begin());
  34 + argCopies.push_back(std::move(programNameCopy));
  35 +
  36 + for (const std::string &arg : args)
  37 + {
  38 + std::vector<char> copyArg(arg.size() + 1, 0);
  39 + std::copy(arg.begin(), arg.end(), copyArg.begin());
  40 + argCopies.push_back(std::move(copyArg));
  41 + }
  42 +
  43 + char *argv[256];
  44 + memset(argv, 0, 256*sizeof (char*));
  45 +
  46 + int i = 0;
  47 + for (std::vector<char> &copy : argCopies)
  48 + {
  49 + argv[i++] = copy.data();
  50 + }
  51 +
  52 + MainApp::initMainApp(i, argv);
  53 + appInstance = MainApp::getMainApp();
  54 +
  55 + // A hack: when I supply args I probably define a config for auth stuff.
  56 + if (args.empty())
  57 + appInstance->settings->allowAnonymous = true;
  58 +}
  59 +
27 MainAppThread::~MainAppThread() 60 MainAppThread::~MainAppThread()
28 { 61 {
29 if (appInstance) 62 if (appInstance)
FlashMQTests/mainappthread.h
@@ -28,6 +28,7 @@ class MainAppThread : public QThread @@ -28,6 +28,7 @@ class MainAppThread : public QThread
28 MainApp *appInstance = nullptr; 28 MainApp *appInstance = nullptr;
29 public: 29 public:
30 explicit MainAppThread(QObject *parent = nullptr); 30 explicit MainAppThread(QObject *parent = nullptr);
  31 + MainAppThread(const std::vector<std::string> &args, QObject *parent = nullptr);
31 ~MainAppThread(); 32 ~MainAppThread();
32 33
33 public slots: 34 public slots:
FlashMQTests/plugins/test_plugin.cpp
@@ -45,6 +45,9 @@ AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string @@ -45,6 +45,9 @@ AuthResult flashmq_auth_plugin_login_check(void *thread_data, const std::string
45 (void)password; 45 (void)password;
46 (void)userProperties; 46 (void)userProperties;
47 47
  48 + if (username == "failme")
  49 + return AuthResult::login_denied;
  50 +
48 return AuthResult::success; 51 return AuthResult::success;
49 } 52 }
50 53
@@ -72,6 +75,34 @@ AuthResult flashmq_extended_auth(void *thread_data, const std::string &amp;clientid, @@ -72,6 +75,34 @@ AuthResult flashmq_extended_auth(void *thread_data, const std::string &amp;clientid,
72 (void)userProperties; 75 (void)userProperties;
73 (void)returnData; 76 (void)returnData;
74 77
75 - return AuthResult::success; 78 + if (authMethod == "always_good_passing_back_the_auth_data")
  79 + {
  80 + if (authData == "actually not good.")
  81 + return AuthResult::login_denied;
  82 +
  83 + returnData = authData;
  84 + return AuthResult::success;
  85 + }
  86 + if (authMethod == "always_fail")
  87 + {
  88 + return AuthResult::login_denied;
  89 + }
  90 + if (authMethod == "two_step")
  91 + {
  92 + if (authData == "Hello")
  93 + returnData = "Hello back";
  94 +
  95 + if (authData == "grant me already!")
  96 + {
  97 + returnData = "OK, if you insist.";
  98 + return AuthResult::success;
  99 + }
  100 + else if (authData == "whoops, wrong data.")
  101 + return AuthResult::login_denied;
  102 + else
  103 + return AuthResult::auth_continue;
  104 + }
  105 +
  106 + return AuthResult::auth_method_not_supported;
76 } 107 }
77 108
FlashMQTests/tst_maintests.cpp
@@ -32,6 +32,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;. @@ -32,6 +32,8 @@ License along with FlashMQ. If not, see &lt;https://www.gnu.org/licenses/&gt;.
32 #include "session.h" 32 #include "session.h"
33 #include "threaddata.h" 33 #include "threaddata.h"
34 #include "threadglobals.h" 34 #include "threadglobals.h"
  35 +#include "conffiletemp.h"
  36 +#include "packetdatatypes.h"
35 37
36 #include "flashmqtestclient.h" 38 #include "flashmqtestclient.h"
37 39
@@ -65,6 +67,7 @@ public: @@ -65,6 +67,7 @@ public:
65 ~MainTests(); 67 ~MainTests();
66 68
67 private slots: 69 private slots:
  70 + void init(const std::vector<std::string> &args);
68 void init(); // will be called before each test function is executed 71 void init(); // will be called before each test function is executed
69 void cleanup(); // will be called after every test function. 72 void cleanup(); // will be called after every test function.
70 73
@@ -134,6 +137,18 @@ private slots: @@ -134,6 +137,18 @@ private slots:
134 137
135 void testUserProperties(); 138 void testUserProperties();
136 139
  140 + void testAuthFail();
  141 + void testAuthSucceed();
  142 +
  143 + void testExtendedAuthOneStepSucceed();
  144 + void testExtendedAuthOneStepDeny();
  145 + void testExtendedAuthOneStepBadAuthMethod();
  146 + void testExtendedAuthTwoStep();
  147 + void testExtendedAuthTwoStepSecondStepFail();
  148 + void testExtendedReAuth();
  149 + void testExtendedReAuthTwoStep();
  150 + void testExtendedReAuthFail();
  151 +
137 }; 152 };
138 153
139 MainTests::MainTests() 154 MainTests::MainTests()
@@ -146,10 +161,10 @@ MainTests::~MainTests() @@ -146,10 +161,10 @@ MainTests::~MainTests()
146 161
147 } 162 }
148 163
149 -void MainTests::init() 164 +void MainTests::init(const std::vector<std::string> &args)
150 { 165 {
151 mainApp.reset(); 166 mainApp.reset();
152 - mainApp.reset(new MainAppThread()); 167 + mainApp.reset(new MainAppThread(args));
153 mainApp->start(); 168 mainApp->start();
154 mainApp->waitForStarted(); 169 mainApp->waitForStarted();
155 170
@@ -160,6 +175,12 @@ void MainTests::init() @@ -160,6 +175,12 @@ void MainTests::init()
160 ThreadGlobals::assignThreadData(dummyThreadData.get()); 175 ThreadGlobals::assignThreadData(dummyThreadData.get());
161 } 176 }
162 177
  178 +void MainTests::init()
  179 +{
  180 + std::vector<std::string> args;
  181 + init(args);
  182 +}
  183 +
163 void MainTests::cleanup() 184 void MainTests::cleanup()
164 { 185 {
165 mainApp->stopApp(); 186 mainApp->stopApp();
@@ -1925,6 +1946,414 @@ void MainTests::testUserProperties() @@ -1925,6 +1946,414 @@ void MainTests::testUserProperties()
1925 QVERIFY(properties3 == nullptr); 1946 QVERIFY(properties3 == nullptr);
1926 } 1947 }
1927 1948
  1949 +void MainTests::testAuthFail()
  1950 +{
  1951 + std::vector<ProtocolVersion> versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 };
  1952 +
  1953 + ConfFileTemp confFile;
  1954 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  1955 + confFile.closeFile();
  1956 +
  1957 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  1958 +
  1959 + cleanup();
  1960 + init(args);
  1961 +
  1962 + for (ProtocolVersion &version : versions)
  1963 + {
  1964 +
  1965 + FlashMQTestClient client;
  1966 + client.start();
  1967 + client.connectClient(version, false, 120, [](Connect &connect) {
  1968 + connect.username = "failme";
  1969 + connect.password = "boo";
  1970 + });
  1971 +
  1972 + QVERIFY(client.receivedPackets.size() == 1);
  1973 +
  1974 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  1975 +
  1976 + if (version >= ProtocolVersion::Mqtt5)
  1977 + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized);
  1978 + else
  1979 + QVERIFY(static_cast<uint8_t>(connAckData.reasonCode) == 5);
  1980 + }
  1981 +}
  1982 +
  1983 +void MainTests::testAuthSucceed()
  1984 +{
  1985 + std::vector<ProtocolVersion> versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 };
  1986 +
  1987 + ConfFileTemp confFile;
  1988 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  1989 + confFile.closeFile();
  1990 +
  1991 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  1992 +
  1993 + cleanup();
  1994 + init(args);
  1995 +
  1996 + for (ProtocolVersion &version : versions)
  1997 + {
  1998 +
  1999 + FlashMQTestClient client;
  2000 + client.start();
  2001 + client.connectClient(version, false, 120, [](Connect &connect) {
  2002 + connect.username = "passme";
  2003 + connect.password = "boo";
  2004 + });
  2005 +
  2006 + QVERIFY(client.receivedPackets.size() == 1);
  2007 +
  2008 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2009 +
  2010 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2011 + }
  2012 +}
  2013 +
  2014 +void MainTests::testExtendedAuthOneStepSucceed()
  2015 +{
  2016 + ConfFileTemp confFile;
  2017 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2018 + confFile.closeFile();
  2019 +
  2020 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2021 +
  2022 + cleanup();
  2023 + init(args);
  2024 +
  2025 + FlashMQTestClient client;
  2026 + client.start();
  2027 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2028 + connect.username = "me";
  2029 + connect.password = "me me";
  2030 +
  2031 + connect.constructPropertyBuilder();
  2032 +
  2033 + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data");
  2034 + connect.propertyBuilder->writeAuthenticationData("I have a proposal to put to ye.");
  2035 + });
  2036 +
  2037 + QVERIFY(client.receivedPackets.size() == 1);
  2038 +
  2039 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2040 +
  2041 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2042 + QVERIFY(connAckData.authData == "I have a proposal to put to ye.");
  2043 +}
  2044 +
  2045 +void MainTests::testExtendedAuthOneStepDeny()
  2046 +{
  2047 + ConfFileTemp confFile;
  2048 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2049 + confFile.closeFile();
  2050 +
  2051 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2052 +
  2053 + cleanup();
  2054 + init(args);
  2055 +
  2056 + FlashMQTestClient client;
  2057 + client.start();
  2058 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2059 + connect.username = "me";
  2060 + connect.password = "me me";
  2061 +
  2062 + connect.constructPropertyBuilder();
  2063 +
  2064 + connect.propertyBuilder->writeAuthenticationMethod("always_fail");
  2065 + });
  2066 +
  2067 + QVERIFY(client.receivedPackets.size() == 1);
  2068 +
  2069 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2070 +
  2071 + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized);
  2072 +}
  2073 +
  2074 +void MainTests::testExtendedAuthOneStepBadAuthMethod()
  2075 +{
  2076 + ConfFileTemp confFile;
  2077 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2078 + confFile.closeFile();
  2079 +
  2080 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2081 +
  2082 + cleanup();
  2083 + init(args);
  2084 +
  2085 + FlashMQTestClient client;
  2086 + client.start();
  2087 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2088 + connect.username = "me";
  2089 + connect.password = "me me";
  2090 +
  2091 + connect.constructPropertyBuilder();
  2092 +
  2093 + connect.propertyBuilder->writeAuthenticationMethod("doesnt_exist");
  2094 + });
  2095 +
  2096 + QVERIFY(client.receivedPackets.size() == 1);
  2097 +
  2098 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2099 +
  2100 + QVERIFY(connAckData.reasonCode == ReasonCodes::BadAuthenticationMethod);
  2101 +}
  2102 +
  2103 +void MainTests::testExtendedAuthTwoStep()
  2104 +{
  2105 + ConfFileTemp confFile;
  2106 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2107 + confFile.closeFile();
  2108 +
  2109 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2110 +
  2111 + cleanup();
  2112 + init(args);
  2113 +
  2114 + FlashMQTestClient client;
  2115 + client.start();
  2116 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2117 + connect.username = "me";
  2118 + connect.password = "me me";
  2119 +
  2120 + connect.constructPropertyBuilder();
  2121 +
  2122 + connect.propertyBuilder->writeAuthenticationMethod("two_step");
  2123 + connect.propertyBuilder->writeAuthenticationData("Hello");
  2124 + });
  2125 +
  2126 + QVERIFY(client.receivedPackets.size() == 1);
  2127 +
  2128 + AuthPacketData authData = client.receivedPackets.front().parseAuthData();
  2129 +
  2130 + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication);
  2131 + QVERIFY(authData.data == "Hello back");
  2132 +
  2133 + client.clearReceivedLists();
  2134 +
  2135 + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!");
  2136 + client.writeAuth(auth);
  2137 +
  2138 + client.waitForConnack();
  2139 +
  2140 + QVERIFY(client.receivedPackets.size() == 1);
  2141 +
  2142 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2143 +
  2144 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2145 + QVERIFY(connAckData.authData == "OK, if you insist.");
  2146 +}
  2147 +
  2148 +void MainTests::testExtendedAuthTwoStepSecondStepFail()
  2149 +{
  2150 + ConfFileTemp confFile;
  2151 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2152 + confFile.closeFile();
  2153 +
  2154 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2155 +
  2156 + cleanup();
  2157 + init(args);
  2158 +
  2159 + FlashMQTestClient client;
  2160 + client.start();
  2161 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2162 + connect.username = "me";
  2163 + connect.password = "me me";
  2164 +
  2165 + connect.constructPropertyBuilder();
  2166 +
  2167 + connect.propertyBuilder->writeAuthenticationMethod("two_step");
  2168 + connect.propertyBuilder->writeAuthenticationData("Hello");
  2169 + });
  2170 +
  2171 + QVERIFY(client.receivedPackets.size() == 1);
  2172 +
  2173 + AuthPacketData authData = client.receivedPackets.front().parseAuthData();
  2174 +
  2175 + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication);
  2176 + QVERIFY(authData.data == "Hello back");
  2177 +
  2178 + client.clearReceivedLists();
  2179 +
  2180 + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "whoops, wrong data.");
  2181 + client.writeAuth(auth);
  2182 +
  2183 + client.waitForConnack();
  2184 +
  2185 + QVERIFY(client.receivedPackets.size() == 1);
  2186 +
  2187 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2188 +
  2189 + QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized);
  2190 +}
  2191 +
  2192 +void MainTests::testExtendedReAuth()
  2193 +{
  2194 + ConfFileTemp confFile;
  2195 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2196 + confFile.closeFile();
  2197 +
  2198 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2199 +
  2200 + cleanup();
  2201 + init(args);
  2202 +
  2203 + FlashMQTestClient client;
  2204 + client.start();
  2205 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2206 + connect.username = "me";
  2207 + connect.password = "me me";
  2208 +
  2209 + connect.constructPropertyBuilder();
  2210 + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data");
  2211 + connect.propertyBuilder->writeAuthenticationData("Santa Claus");
  2212 + });
  2213 +
  2214 + QVERIFY(client.receivedPackets.size() == 1);
  2215 +
  2216 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2217 +
  2218 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2219 +
  2220 + client.clearReceivedLists();
  2221 +
  2222 + // Then reauth.
  2223 +
  2224 + Auth auth(ReasonCodes::ContinueAuthentication, "always_good_passing_back_the_auth_data", "Again Santa Claus");
  2225 + client.writeAuth(auth);
  2226 +
  2227 + client.waitForConnack();
  2228 +
  2229 + QVERIFY(client.receivedPackets.size() == 1);
  2230 +
  2231 + AuthPacketData authData = client.receivedPackets.front().parseAuthData();
  2232 +
  2233 + QVERIFY(authData.reasonCode == ReasonCodes::Success);
  2234 + QVERIFY(authData.data == "Again Santa Claus");
  2235 +}
  2236 +
  2237 +void MainTests::testExtendedReAuthTwoStep()
  2238 +{
  2239 + ConfFileTemp confFile;
  2240 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2241 + confFile.closeFile();
  2242 +
  2243 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2244 +
  2245 + cleanup();
  2246 + init(args);
  2247 +
  2248 + FlashMQTestClient client;
  2249 + client.start();
  2250 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2251 + connect.username = "me";
  2252 + connect.password = "me me";
  2253 +
  2254 + connect.constructPropertyBuilder();
  2255 +
  2256 + connect.propertyBuilder->writeAuthenticationMethod("two_step");
  2257 + connect.propertyBuilder->writeAuthenticationData("Hello");
  2258 + });
  2259 +
  2260 + QVERIFY(client.receivedPackets.size() == 1);
  2261 +
  2262 + AuthPacketData authData = client.receivedPackets.front().parseAuthData();
  2263 +
  2264 + QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication);
  2265 + QVERIFY(authData.data == "Hello back");
  2266 +
  2267 + client.clearReceivedLists();
  2268 +
  2269 + const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!");
  2270 + client.writeAuth(auth);
  2271 +
  2272 + client.waitForConnack();
  2273 +
  2274 + QVERIFY(client.receivedPackets.size() == 1);
  2275 +
  2276 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2277 +
  2278 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2279 + QVERIFY(connAckData.authData == "OK, if you insist.");
  2280 +
  2281 + client.clearReceivedLists();
  2282 +
  2283 + // Then reauth.
  2284 +
  2285 + const Auth reauth(ReasonCodes::ReAuthenticate, "two_step", "Hello");
  2286 + client.writeAuth(reauth);
  2287 + client.waitForConnack();
  2288 +
  2289 + QVERIFY(client.receivedPackets.size() == 1);
  2290 +
  2291 + AuthPacketData reauthData = client.receivedPackets.front().parseAuthData();
  2292 +
  2293 + QVERIFY(reauthData.reasonCode == ReasonCodes::ContinueAuthentication);
  2294 + QVERIFY(reauthData.data == "Hello back");
  2295 +
  2296 + client.clearReceivedLists();
  2297 +
  2298 + const Auth reauthFinish(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!");
  2299 + client.writeAuth(reauthFinish);
  2300 +
  2301 + client.waitForConnack();
  2302 +
  2303 + QVERIFY(client.receivedPackets.size() == 1);
  2304 +
  2305 + AuthPacketData reauthFinishData = client.receivedPackets.front().parseAuthData();
  2306 +
  2307 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2308 + QVERIFY(connAckData.authData == "OK, if you insist.");
  2309 +}
  2310 +
  2311 +void MainTests::testExtendedReAuthFail()
  2312 +{
  2313 + ConfFileTemp confFile;
  2314 + confFile.writeLine("auth_plugin plugins/libtest_plugin.so.0.0.1");
  2315 + confFile.closeFile();
  2316 +
  2317 + std::vector<std::string> args {"--config-file", confFile.getFilePath()};
  2318 +
  2319 + cleanup();
  2320 + init(args);
  2321 +
  2322 + FlashMQTestClient client;
  2323 + client.start();
  2324 + client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
  2325 + connect.username = "me";
  2326 + connect.password = "me me";
  2327 +
  2328 + connect.constructPropertyBuilder();
  2329 +
  2330 + connect.propertyBuilder->writeAuthenticationMethod("always_good_passing_back_the_auth_data");
  2331 + connect.propertyBuilder->writeAuthenticationData("I have a proposal to put to ye.");
  2332 + });
  2333 +
  2334 + QVERIFY(client.receivedPackets.size() == 1);
  2335 +
  2336 + ConnAckData connAckData = client.receivedPackets.front().parseConnAckData();
  2337 +
  2338 + QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
  2339 + QVERIFY(connAckData.authData == "I have a proposal to put to ye.");
  2340 +
  2341 + client.clearReceivedLists();
  2342 +
  2343 + // Then reauth.
  2344 +
  2345 + const Auth reauth(ReasonCodes::ReAuthenticate, "always_good_passing_back_the_auth_data", "actually not good.");
  2346 + client.writeAuth(reauth);
  2347 + client.waitForPacketCount(1);
  2348 +
  2349 + QVERIFY(client.receivedPackets.size() == 1);
  2350 + QVERIFY(client.receivedPackets.front().packetType == PacketType::DISCONNECT);
  2351 +
  2352 + DisconnectData data = client.receivedPackets.front().parseDisconnectData();
  2353 +
  2354 + QVERIFY(data.reasonCode == ReasonCodes::NotAuthorized);
  2355 +}
  2356 +
1928 int main(int argc, char *argv[]) 2357 int main(int argc, char *argv[])
1929 { 2358 {
1930 QCoreApplication app(argc, argv); 2359 QCoreApplication app(argc, argv);
authplugin.cpp
@@ -356,7 +356,7 @@ AuthResult Authentication::unPwdCheck(const std::string &amp;username, const std::st @@ -356,7 +356,7 @@ AuthResult Authentication::unPwdCheck(const std::string &amp;username, const std::st
356 { 356 {
357 AuthResult firstResult = unPwdCheckFromMosquittoPasswordFile(username, password); 357 AuthResult firstResult = unPwdCheckFromMosquittoPasswordFile(username, password);
358 358
359 - if (firstResult != AuthResult::success) 359 + if (firstResult == AuthResult::success)
360 return firstResult; 360 return firstResult;
361 361
362 if (pluginVersion == PluginVersion::None) 362 if (pluginVersion == PluginVersion::None)
flashmqtestclient.cpp
@@ -260,6 +260,12 @@ void FlashMQTestClient::publish(Publish &amp;pub) @@ -260,6 +260,12 @@ void FlashMQTestClient::publish(Publish &amp;pub)
260 } 260 }
261 } 261 }
262 262
  263 +void FlashMQTestClient::writeAuth(const Auth &auth)
  264 +{
  265 + MqttPacket pack(auth);
  266 + client->writeMqttPacketAndBlameThisClient(pack);
  267 +}
  268 +
263 void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos) 269 void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, char qos)
264 { 270 {
265 Publish pub(topic, payload, qos); 271 Publish pub(topic, payload, qos);
@@ -276,7 +282,7 @@ void FlashMQTestClient::waitForConnack() @@ -276,7 +282,7 @@ void FlashMQTestClient::waitForConnack()
276 { 282 {
277 waitForCondition([&]() { 283 waitForCondition([&]() {
278 return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) { 284 return std::any_of(this->receivedPackets.begin(), this->receivedPackets.end(), [](const MqttPacket &p) {
279 - return p.packetType == PacketType::CONNACK; 285 + return p.packetType == PacketType::CONNACK || p.packetType == PacketType::AUTH;
280 }); 286 });
281 }); 287 });
282 } 288 }
@@ -287,3 +293,10 @@ void FlashMQTestClient::waitForMessageCount(const size_t count, int timeout) @@ -287,3 +293,10 @@ void FlashMQTestClient::waitForMessageCount(const size_t count, int timeout)
287 return this->receivedPublishes.size() >= count; 293 return this->receivedPublishes.size() >= count;
288 }, timeout); 294 }, timeout);
289 } 295 }
  296 +
  297 +void FlashMQTestClient::waitForPacketCount(const size_t count, int timeout)
  298 +{
  299 + waitForCondition([&]() {
  300 + return this->receivedPackets.size() >= count;
  301 + }, timeout);
  302 +}
flashmqtestclient.h
@@ -40,6 +40,7 @@ public: @@ -40,6 +40,7 @@ public:
40 void unsubscribe(const std::string &topic); 40 void unsubscribe(const std::string &topic);
41 void publish(const std::string &topic, const std::string &payload, char qos); 41 void publish(const std::string &topic, const std::string &payload, char qos);
42 void publish(Publish &pub); 42 void publish(Publish &pub);
  43 + void writeAuth(const Auth &auth);
43 void clearReceivedLists(); 44 void clearReceivedLists();
44 void setWill(std::shared_ptr<WillPublish> &will); 45 void setWill(std::shared_ptr<WillPublish> &will);
45 void disconnect(ReasonCodes reason); 46 void disconnect(ReasonCodes reason);
@@ -47,6 +48,7 @@ public: @@ -47,6 +48,7 @@ public:
47 void waitForQuit(); 48 void waitForQuit();
48 void waitForConnack(); 49 void waitForConnack();
49 void waitForMessageCount(const size_t count, int timeout = 1); 50 void waitForMessageCount(const size_t count, int timeout = 1);
  51 + void waitForPacketCount(const size_t count, int timeout = 1);
50 }; 52 };
51 53
52 #endif // FLASHMQTESTCLIENT_H 54 #endif // FLASHMQTESTCLIENT_H
mqttpacket.cpp
@@ -253,6 +253,8 @@ MqttPacket::MqttPacket(const Connect &amp;connect) : @@ -253,6 +253,8 @@ MqttPacket::MqttPacket(const Connect &amp;connect) :
253 writeByte(static_cast<char>(protocolVersion)); 253 writeByte(static_cast<char>(protocolVersion));
254 254
255 uint8_t flags = connect.clean_start << 1; 255 uint8_t flags = connect.clean_start << 1;
  256 + flags |= !connect.username.empty() << 7;
  257 + flags |= !connect.password.empty() << 6;
256 258
257 if (connect.will) 259 if (connect.will)
258 { 260 {
@@ -284,6 +286,11 @@ MqttPacket::MqttPacket(const Connect &amp;connect) : @@ -284,6 +286,11 @@ MqttPacket::MqttPacket(const Connect &amp;connect) :
284 writeString(connect.will->payload); 286 writeString(connect.will->payload);
285 } 287 }
286 288
  289 + if (!connect.username.empty())
  290 + writeString(connect.username);
  291 + if (!connect.password.empty())
  292 + writeString(connect.password);
  293 +
287 calculateRemainingLength(); 294 calculateRemainingLength();
288 } 295 }
289 296
@@ -634,6 +641,43 @@ ConnectData MqttPacket::parseConnectData() @@ -634,6 +641,43 @@ ConnectData MqttPacket::parseConnectData()
634 return result; 641 return result;
635 } 642 }
636 643
  644 +ConnAckData MqttPacket::parseConnAckData()
  645 +{
  646 + if (this->packetType != PacketType::CONNACK)
  647 + throw std::runtime_error("Packet must be connack packet.");
  648 +
  649 + setPosToDataStart();
  650 +
  651 + ConnAckData result;
  652 +
  653 + const uint8_t flagByte = readByte();
  654 +
  655 + result.sessionPresent = flagByte & 0x01;
  656 + result.reasonCode = static_cast<ReasonCodes>(readUint8());
  657 +
  658 + if (protocolVersion == ProtocolVersion::Mqtt5)
  659 + {
  660 + const size_t proplen = decodeVariableByteIntAtPos();
  661 + const size_t prop_end_at = pos + proplen;
  662 +
  663 + while (pos < prop_end_at)
  664 + {
  665 + const Mqtt5Properties prop = static_cast<Mqtt5Properties>(readByte());
  666 +
  667 + switch (prop)
  668 + {
  669 + case Mqtt5Properties::AuthenticationData:
  670 + result.authData = readBytesToString();
  671 + break;
  672 + default:
  673 + break;
  674 + }
  675 + }
  676 + }
  677 +
  678 + return result;
  679 +}
  680 +
637 void MqttPacket::handleConnect() 681 void MqttPacket::handleConnect()
638 { 682 {
639 if (sender->hasConnectPacketSeen()) 683 if (sender->hasConnectPacketSeen())
@@ -798,18 +842,22 @@ void MqttPacket::handleConnect() @@ -798,18 +842,22 @@ void MqttPacket::handleConnect()
798 } 842 }
799 } 843 }
800 844
801 -void MqttPacket::handleExtendedAuth() 845 +AuthPacketData MqttPacket::parseAuthData()
802 { 846 {
  847 + if (this->packetType != PacketType::AUTH)
  848 + throw std::runtime_error("Packet must be an AUTH packet.");
  849 +
803 if (first_byte & 0b1111) 850 if (first_byte & 0b1111)
804 throw ProtocolError("AUTH packet first 4 bits should be 0.", ReasonCodes::MalformedPacket); 851 throw ProtocolError("AUTH packet first 4 bits should be 0.", ReasonCodes::MalformedPacket);
805 852
806 - const ReasonCodes reasonCode = static_cast<ReasonCodes>(readByte());  
807 -  
808 if (this->protocolVersion < ProtocolVersion::Mqtt5) 853 if (this->protocolVersion < ProtocolVersion::Mqtt5)
809 throw ProtocolError("AUTH packet needs MQTT5 or higher"); 854 throw ProtocolError("AUTH packet needs MQTT5 or higher");
810 855
811 - std::string authMethod;  
812 - std::string authData; 856 + setPosToDataStart();
  857 +
  858 + AuthPacketData result;
  859 +
  860 + result.reasonCode = static_cast<ReasonCodes>(readUint8());
813 861
814 if (!atEnd()) 862 if (!atEnd())
815 { 863 {
@@ -823,10 +871,10 @@ void MqttPacket::handleExtendedAuth() @@ -823,10 +871,10 @@ void MqttPacket::handleExtendedAuth()
823 switch (prop) 871 switch (prop)
824 { 872 {
825 case Mqtt5Properties::AuthenticationMethod: 873 case Mqtt5Properties::AuthenticationMethod:
826 - authMethod = readBytesToString(); 874 + result.method = readBytesToString();
827 break; 875 break;
828 case Mqtt5Properties::AuthenticationData: 876 case Mqtt5Properties::AuthenticationData:
829 - authData = readBytesToString(false); 877 + result.data = readBytesToString(false);
830 break; 878 break;
831 case Mqtt5Properties::ReasonString: 879 case Mqtt5Properties::ReasonString:
832 readBytesToString(); 880 readBytesToString();
@@ -840,12 +888,19 @@ void MqttPacket::handleExtendedAuth() @@ -840,12 +888,19 @@ void MqttPacket::handleExtendedAuth()
840 } 888 }
841 } 889 }
842 890
843 - if (authMethod != sender->getExtendedAuthenticationMethod()) 891 + return result;
  892 +}
  893 +
  894 +void MqttPacket::handleExtendedAuth()
  895 +{
  896 + AuthPacketData data = parseAuthData();
  897 +
  898 + if (data.method != sender->getExtendedAuthenticationMethod())
844 throw ProtocolError("Client continued with another authentication method that it started with.", ReasonCodes::ProtocolError); 899 throw ProtocolError("Client continued with another authentication method that it started with.", ReasonCodes::ProtocolError);
845 900
846 ExtendedAuthStage authStage = ExtendedAuthStage::None; 901 ExtendedAuthStage authStage = ExtendedAuthStage::None;
847 902
848 - switch(reasonCode) 903 + switch(data.reasonCode)
849 { 904 {
850 case ReasonCodes::ContinueAuthentication: 905 case ReasonCodes::ContinueAuthentication:
851 authStage = ExtendedAuthStage::Continue; 906 authStage = ExtendedAuthStage::Continue;
@@ -854,7 +909,7 @@ void MqttPacket::handleExtendedAuth() @@ -854,7 +909,7 @@ void MqttPacket::handleExtendedAuth()
854 authStage = ExtendedAuthStage::Reauth; 909 authStage = ExtendedAuthStage::Reauth;
855 break; 910 break;
856 default: 911 default:
857 - throw ProtocolError(formatString("Invalid reason code '%d' in auth packet", static_cast<uint8_t>(reasonCode)), ReasonCodes::MalformedPacket); 912 + throw ProtocolError(formatString("Invalid reason code '%d' in auth packet", static_cast<uint8_t>(data.reasonCode)), ReasonCodes::MalformedPacket);
858 } 913 }
859 914
860 if (authStage == ExtendedAuthStage::Reauth && !sender->getAuthenticated()) 915 if (authStage == ExtendedAuthStage::Reauth && !sender->getAuthenticated())
@@ -865,28 +920,24 @@ void MqttPacket::handleExtendedAuth() @@ -865,28 +920,24 @@ void MqttPacket::handleExtendedAuth()
865 Authentication &authentication = *ThreadGlobals::getAuth(); 920 Authentication &authentication = *ThreadGlobals::getAuth();
866 921
867 std::string returnData; 922 std::string returnData;
868 - const AuthResult authResult = authentication.extendedAuth(sender->getClientId(), authStage, authMethod, authData, 923 + const AuthResult authResult = authentication.extendedAuth(sender->getClientId(), authStage, data.method, data.data,
869 getUserProperties(), returnData, sender->getMutableUsername()); 924 getUserProperties(), returnData, sender->getMutableUsername());
870 925
871 if (authResult == AuthResult::auth_continue) 926 if (authResult == AuthResult::auth_continue)
872 { 927 {
873 - Auth auth(ReasonCodes::ContinueAuthentication, authMethod, returnData); 928 + Auth auth(ReasonCodes::ContinueAuthentication, data.method, returnData);
874 MqttPacket pack(auth); 929 MqttPacket pack(auth);
875 sender->writeMqttPacket(pack); 930 sender->writeMqttPacket(pack);
876 return; 931 return;
877 } 932 }
878 933
879 - if (authResult == AuthResult::success)  
880 - {  
881 - sender->addAuthReturnDataToStagedConnAck(returnData);  
882 - }  
883 -  
884 const ReasonCodes finalResult = authResultToReasonCode(authResult); 934 const ReasonCodes finalResult = authResultToReasonCode(authResult);
885 935
886 if (!sender->getAuthenticated()) // First auth sends connack packets. 936 if (!sender->getAuthenticated()) // First auth sends connack packets.
887 { 937 {
888 if (finalResult == ReasonCodes::Success) 938 if (finalResult == ReasonCodes::Success)
889 { 939 {
  940 + sender->addAuthReturnDataToStagedConnAck(returnData);
890 sender->sendConnackSuccess(); 941 sender->sendConnackSuccess();
891 std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore(); 942 std::shared_ptr<SubscriptionStore> subscriptionStore = MainApp::getMainApp()->getSubscriptionStore();
892 subscriptionStore->registerClientAndKickExistingOne(sender); 943 subscriptionStore->registerClientAndKickExistingOne(sender);
@@ -900,7 +951,7 @@ void MqttPacket::handleExtendedAuth() @@ -900,7 +951,7 @@ void MqttPacket::handleExtendedAuth()
900 { 951 {
901 if (finalResult == ReasonCodes::Success) 952 if (finalResult == ReasonCodes::Success)
902 { 953 {
903 - Auth auth(ReasonCodes::Success, authMethod, returnData); 954 + Auth auth(ReasonCodes::Success, data.method, returnData);
904 MqttPacket authPack(auth); 955 MqttPacket authPack(auth);
905 sender->writeMqttPacket(authPack); 956 sender->writeMqttPacket(authPack);
906 logger->logf(LOG_NOTICE, "Client '%s', user '%s' reauthentication successful.", sender->getClientId().c_str(), sender->getUsername().c_str()); 957 logger->logf(LOG_NOTICE, "Client '%s', user '%s' reauthentication successful.", sender->getClientId().c_str(), sender->getUsername().c_str());
@@ -933,7 +984,7 @@ DisconnectData MqttPacket::parseDisconnectData() @@ -933,7 +984,7 @@ DisconnectData MqttPacket::parseDisconnectData()
933 { 984 {
934 if (!atEnd()) 985 if (!atEnd())
935 { 986 {
936 - result.reasonCode = static_cast<ReasonCodes>(readByte()); 987 + result.reasonCode = static_cast<ReasonCodes>(readUint8());
937 988
938 const size_t proplen = decodeVariableByteIntAtPos(); 989 const size_t proplen = decodeVariableByteIntAtPos();
939 const size_t prop_end_at = pos + proplen; 990 const size_t prop_end_at = pos + proplen;
@@ -1621,6 +1672,12 @@ char MqttPacket::readByte() @@ -1621,6 +1672,12 @@ char MqttPacket::readByte()
1621 return b; 1672 return b;
1622 } 1673 }
1623 1674
  1675 +uint8_t MqttPacket::readUint8()
  1676 +{
  1677 + char r = readByte();
  1678 + return static_cast<uint8_t>(r);
  1679 +}
  1680 +
1624 void MqttPacket::writeByte(char b) 1681 void MqttPacket::writeByte(char b)
1625 { 1682 {
1626 if (pos + 1 > bites.size()) 1683 if (pos + 1 > bites.size())
mqttpacket.h
@@ -73,6 +73,7 @@ class MqttPacket @@ -73,6 +73,7 @@ class MqttPacket
73 73
74 char *readBytes(size_t length); 74 char *readBytes(size_t length);
75 char readByte(); 75 char readByte();
  76 + uint8_t readUint8();
76 void writeByte(char b); 77 void writeByte(char b);
77 void writeUint16(uint16_t x); 78 void writeUint16(uint16_t x);
78 void writeBytes(const char *b, size_t len); 79 void writeBytes(const char *b, size_t len);
@@ -121,7 +122,9 @@ public: @@ -121,7 +122,9 @@ public:
121 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); 122 static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
122 123
123 void handle(); 124 void handle();
  125 + AuthPacketData parseAuthData();
124 ConnectData parseConnectData(); 126 ConnectData parseConnectData();
  127 + ConnAckData parseConnAckData();
125 void handleConnect(); 128 void handleConnect();
126 void handleExtendedAuth(); 129 void handleExtendedAuth();
127 DisconnectData parseDisconnectData(); 130 DisconnectData parseDisconnectData();
packetdatatypes.h
@@ -38,6 +38,23 @@ struct ConnectData @@ -38,6 +38,23 @@ struct ConnectData
38 ConnectData(); 38 ConnectData();
39 }; 39 };
40 40
  41 +struct ConnAckData
  42 +{
  43 + // Flags
  44 + bool sessionPresent = false;
  45 +
  46 + ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result;
  47 +
  48 + std::string authData;
  49 +};
  50 +
  51 +struct AuthPacketData
  52 +{
  53 + std::string method;
  54 + std::string data;
  55 + ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result;
  56 +};
  57 +
41 struct DisconnectData 58 struct DisconnectData
42 { 59 {
43 ReasonCodes reasonCode = ReasonCodes::Success; 60 ReasonCodes reasonCode = ReasonCodes::Success;
types.cpp
@@ -397,6 +397,12 @@ size_t Connect::getLengthWithoutFixedHeader() const @@ -397,6 +397,12 @@ size_t Connect::getLengthWithoutFixedHeader() const
397 result += will->payload.length() + 2; 397 result += will->payload.length() + 2;
398 } 398 }
399 399
  400 + if (!username.empty())
  401 + result += username.size() + 2;
  402 +
  403 + if (!password.empty())
  404 + result += password.size() + 2;
  405 +
400 return result; 406 return result;
401 } 407 }
402 408