diff --git a/CMakeLists.txt b/CMakeLists.txt
index 517499f..0a09f11 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -38,6 +38,7 @@ add_executable(FlashMQ
scopedsocket.h
bindaddr.h
oneinstancelock.h
+ evpencodectxmanager.h
mainapp.cpp
main.cpp
@@ -65,6 +66,7 @@ add_executable(FlashMQ
scopedsocket.cpp
bindaddr.cpp
oneinstancelock.cpp
+ evpencodectxmanager.cpp
)
target_link_libraries(FlashMQ pthread dl ssl crypto)
diff --git a/authplugin.cpp b/authplugin.cpp
index b0be764..583762d 100644
--- a/authplugin.cpp
+++ b/authplugin.cpp
@@ -21,12 +21,15 @@ License along with FlashMQ. If not, see .
#include
#include
#include
+#include
+#include "sys/stat.h"
#include "exceptions.h"
#include "unscopedlock.h"
+#include "utils.h"
-std::mutex AuthPlugin::initMutex;
-std::mutex AuthPlugin::authChecksMutex;
+std::mutex Authentication::initMutex;
+std::mutex Authentication::authChecksMutex;
void mosquitto_log_printf(int level, const char *fmt, ...)
{
@@ -37,19 +40,39 @@ void mosquitto_log_printf(int level, const char *fmt, ...)
va_end(valist);
}
+MosquittoPasswordFileEntry::MosquittoPasswordFileEntry(const std::vector &&salt, const std::vector &&cryptedPassword) :
+ salt(salt),
+ cryptedPassword(cryptedPassword)
+{
+
+}
-AuthPlugin::AuthPlugin(Settings &settings) :
- settings(settings)
+
+Authentication::Authentication(Settings &settings) :
+ settings(settings),
+ mosquittoPasswordFile(settings.mosquittoPasswordFile),
+ mosquittoDigestContext(EVP_MD_CTX_new())
{
logger = Logger::getInstance();
+
+ if(!sha512)
+ {
+ throw std::runtime_error("Failed to initialize SHA512 for decoding auth entry");
+ }
+
+ EVP_DigestInit_ex(mosquittoDigestContext, sha512, NULL);
+ memset(&mosquittoPasswordFileLastLoad, 0, sizeof(struct timespec));
}
-AuthPlugin::~AuthPlugin()
+Authentication::~Authentication()
{
cleanup();
+
+ if (mosquittoDigestContext)
+ EVP_MD_CTX_free(mosquittoDigestContext);
}
-void *AuthPlugin::loadSymbol(void *handle, const char *symbol) const
+void *Authentication::loadSymbol(void *handle, const char *symbol) const
{
void *r = dlsym(handle, symbol);
@@ -62,7 +85,7 @@ void *AuthPlugin::loadSymbol(void *handle, const char *symbol) const
return r;
}
-void AuthPlugin::loadPlugin(const std::string &pathToSoFile)
+void Authentication::loadPlugin(const std::string &pathToSoFile)
{
if (pathToSoFile.empty())
return;
@@ -70,7 +93,7 @@ void AuthPlugin::loadPlugin(const std::string &pathToSoFile)
logger->logf(LOG_NOTICE, "Loading auth plugin %s", pathToSoFile.c_str());
initialized = false;
- wanted = true;
+ useExternalPlugin = true;
if (access(pathToSoFile.c_str(), R_OK) != 0)
{
@@ -105,9 +128,13 @@ void AuthPlugin::loadPlugin(const std::string &pathToSoFile)
initialized = true;
}
-void AuthPlugin::init()
+/**
+ * @brief AuthPlugin::init is like Mosquitto's init(), and is to allow the plugin to init memory. Plugins should not load
+ * their authentication data here. That's what securityInit() is for.
+ */
+void Authentication::init()
{
- if (!wanted)
+ if (!useExternalPlugin)
return;
UnscopedLock lock(initMutex);
@@ -123,7 +150,7 @@ void AuthPlugin::init()
throw FatalError("Error initialising auth plugin.");
}
-void AuthPlugin::cleanup()
+void Authentication::cleanup()
{
if (!cleanup_v2)
return;
@@ -136,9 +163,13 @@ void AuthPlugin::cleanup()
logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway.
}
-void AuthPlugin::securityInit(bool reloading)
+/**
+ * @brief AuthPlugin::securityInit initializes the security data, like loading users, ACL tables, etc.
+ * @param reloading
+ */
+void Authentication::securityInit(bool reloading)
{
- if (!wanted)
+ if (!useExternalPlugin)
return;
UnscopedLock lock(initMutex);
@@ -157,9 +188,9 @@ void AuthPlugin::securityInit(bool reloading)
initialized = true;
}
-void AuthPlugin::securityCleanup(bool reloading)
+void Authentication::securityCleanup(bool reloading)
{
- if (!wanted)
+ if (!useExternalPlugin)
return;
initialized = false;
@@ -172,9 +203,9 @@ void AuthPlugin::securityCleanup(bool reloading)
}
}
-AuthResult AuthPlugin::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, AclAccess access)
+AuthResult Authentication::aclCheck(const std::string &clientid, const std::string &username, const std::string &topic, AclAccess access)
{
- if (!wanted)
+ if (!useExternalPlugin)
return AuthResult::success;
if (!initialized)
@@ -198,14 +229,19 @@ AuthResult AuthPlugin::aclCheck(const std::string &clientid, const std::string &
return result_;
}
-AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string &password)
+AuthResult Authentication::unPwdCheck(const std::string &username, const std::string &password)
{
- if (!wanted)
- return AuthResult::success;
+ AuthResult firstResult = unPwdCheckFromMosquittoPasswordFile(username, password);
+
+ if (firstResult != AuthResult::success)
+ return firstResult;
+
+ if (!useExternalPlugin)
+ return firstResult;
if (!initialized)
{
- logger->logf(LOG_ERR, "Username+password check wanted, but initialization failed. Can't perform check.");
+ logger->logf(LOG_ERR, "Username+password check with plugin wanted, but initialization failed. Can't perform check.");
return AuthResult::error;
}
@@ -224,11 +260,126 @@ AuthResult AuthPlugin::unPwdCheck(const std::string &username, const std::string
return r;
}
-void AuthPlugin::setQuitting()
+void Authentication::setQuitting()
{
this->quitting = true;
}
+/**
+ * @brief Authentication::loadMosquittoPasswordFile is called once on startup, and on a frequent interval, and reloads the file if changed.
+ */
+void Authentication::loadMosquittoPasswordFile()
+{
+ if (this->mosquittoPasswordFile.empty())
+ return;
+
+ if (access(this->mosquittoPasswordFile.c_str(), R_OK) != 0)
+ {
+ logger->logf(LOG_ERR, "Passwd file '%s' is not there or not readable.", this->mosquittoPasswordFile.c_str());
+ return;
+ }
+
+ struct stat statbuf;
+ memset(&statbuf, 0, sizeof(struct stat));
+ check(stat(mosquittoPasswordFile.c_str(), &statbuf));
+ struct timespec ctime = statbuf.st_ctim;
+
+ if (ctime.tv_sec == this->mosquittoPasswordFileLastLoad.tv_sec)
+ return;
+
+ logger->logf(LOG_NOTICE, "Change detected in '%s'. Reloading.", this->mosquittoPasswordFile.c_str());
+
+ try
+ {
+ std::ifstream infile(this->mosquittoPasswordFile, std::ios::in);
+ std::unique_ptr> passwordEntries_tmp(new std::unordered_map());
+
+ for(std::string line; getline(infile, line ); )
+ {
+ if (line.empty())
+ continue;
+
+ try
+ {
+ std::vector fields = splitToVector(line, ':');
+
+ if (fields.size() != 2)
+ throw std::runtime_error(formatString("Passwd file line '%s' contains more than one ':'", line.c_str()));
+
+ const std::string &username = fields[0];
+
+ for (const std::string &field : fields)
+ {
+ if (field.size() == 0)
+ {
+ throw std::runtime_error(formatString("An empty field was found in '%'", line.c_str()));
+ }
+ }
+
+ std::vector fields2 = splitToVector(fields[1], '$', 3, false);
+
+ if (fields2.size() != 3)
+ throw std::runtime_error(formatString("Invalid line format in '%s'. Expected three fields separated by '$'", line.c_str()));
+
+ if (fields2[0] != "6")
+ throw std::runtime_error("Password fields must start with $6$");
+
+ std::vector salt = base64Decode(fields2[1]);
+ std::vector cryptedPassword = base64Decode(fields2[2]);
+ passwordEntries_tmp->emplace(username, MosquittoPasswordFileEntry(std::move(salt), std::move(cryptedPassword)));
+ }
+ catch (std::exception &ex)
+ {
+ std::string lineCut = formatString("%s...", line.substr(0, 20).c_str());
+ logger->logf(LOG_ERR, "Dropping invalid username/password line: '%s'. Error: %s", lineCut.c_str(), ex.what());
+ }
+ }
+
+ this->mosquittoPasswordEntries = std::move(passwordEntries_tmp);
+ this->mosquittoPasswordFileLastLoad = ctime;
+ }
+ catch (std::exception &ex)
+ {
+ logger->logf(LOG_ERR, "Error loading Mosquitto password file: '%s'. Authentication won't work.", ex.what());
+ }
+}
+
+AuthResult Authentication::unPwdCheckFromMosquittoPasswordFile(const std::string &username, const std::string &password)
+{
+ if (this->mosquittoPasswordFile.empty())
+ return AuthResult::success;
+
+ if (!this->mosquittoPasswordEntries)
+ return AuthResult::login_denied;
+
+ AuthResult result = settings.allowAnonymous ? AuthResult::success : AuthResult::login_denied;
+
+ auto it = mosquittoPasswordEntries->find(username);
+ if (it != mosquittoPasswordEntries->end())
+ {
+ result = AuthResult::login_denied;
+
+ unsigned char md_value[EVP_MAX_MD_SIZE];
+ unsigned int output_len = 0;
+
+ const MosquittoPasswordFileEntry &entry = it->second;
+
+ EVP_MD_CTX_reset(mosquittoDigestContext);
+ EVP_DigestInit_ex(mosquittoDigestContext, sha512, NULL);
+ EVP_DigestUpdate(mosquittoDigestContext, password.c_str(), password.length());
+ EVP_DigestUpdate(mosquittoDigestContext, entry.salt.data(), entry.salt.size());
+ EVP_DigestFinal_ex(mosquittoDigestContext, md_value, &output_len);
+
+ std::vector hashedSalted(output_len);
+ std::memcpy(hashedSalted.data(), md_value, output_len);
+
+ if (hashedSalted == entry.cryptedPassword)
+ result = AuthResult::success;
+ }
+
+ return result;
+}
+
std::string AuthResultToString(AuthResult r)
{
{
diff --git a/authplugin.h b/authplugin.h
index acc5840..4f67468 100644
--- a/authplugin.h
+++ b/authplugin.h
@@ -41,6 +41,26 @@ enum class AuthResult
error = 13
};
+/**
+ * @brief The MosquittoPasswordFileEntry struct stores the decoded base64 password salt and hash.
+ *
+ * The Mosquitto encrypted format looks like that of crypt(2), but it's not. This is an example entry:
+ *
+ * one:$6$emTXKCHfxMnZLDWg$gDcJRPojvOX8l7W/DRhSPoxV3CgPfECJVGRzw2Sqjdc2KIQ/CVLS1mNEuZUsp/vLdj7RCuqXCkgG43+XIc8WBA==
+ *
+ * $ is the seperator. '6' is hard-coded by the 'mosquitto_passwd' utility.
+ */
+struct MosquittoPasswordFileEntry
+{
+ std::vector salt;
+ std::vector cryptedPassword;
+
+ MosquittoPasswordFileEntry(const std::vector &&salt, const std::vector &&cryptedPassword);
+
+ // The plan was that objects of this type wouldn't be copied, but I can't get emplacing to work without it...?
+ //MosquittoPasswordFileEntry(const MosquittoPasswordFileEntry &other) = delete;
+};
+
typedef int (*F_auth_plugin_version)(void);
typedef int (*F_auth_plugin_init_v2)(void **, struct mosquitto_auth_opt *, int);
@@ -59,8 +79,11 @@ extern "C"
std::string AuthResultToString(AuthResult r);
-
-class AuthPlugin
+/**
+ * @brief The Authentication class handles our integrated authentication, but also supports loading Mosquitto auth
+ * plugin compatible .so files.
+ */
+class Authentication
{
F_auth_plugin_version version = nullptr;
F_auth_plugin_init_v2 init_v2 = nullptr;
@@ -79,15 +102,30 @@ class AuthPlugin
void *pluginData = nullptr;
Logger *logger = nullptr;
bool initialized = false;
- bool wanted = false;
+ bool useExternalPlugin = false;
bool quitting = false;
+ /**
+ * @brief mosquittoPasswordFile is a once set value based on config. It's not reloaded on reload signal currently, because it
+ * forces some decisions when you change files or remove the config option. For instance, do you remove all accounts loaded
+ * from the previous one? Perhaps I'm overthinking it.
+ *
+ * Its content is, however, reloaded every two seconds.
+ */
+ const std::string mosquittoPasswordFile;
+
+ struct timespec mosquittoPasswordFileLastLoad;
+
+ std::unique_ptr> mosquittoPasswordEntries;
+ EVP_MD_CTX *mosquittoDigestContext = nullptr;
+ const EVP_MD *sha512 = EVP_sha512();
+
void *loadSymbol(void *handle, const char *symbol) const;
public:
- AuthPlugin(Settings &settings);
- AuthPlugin(const AuthPlugin &other) = delete;
- AuthPlugin(AuthPlugin &&other) = delete;
- ~AuthPlugin();
+ Authentication(Settings &settings);
+ Authentication(const Authentication &other) = delete;
+ Authentication(Authentication &&other) = delete;
+ ~Authentication();
void loadPlugin(const std::string &pathToSoFile);
void init();
@@ -98,6 +136,8 @@ public:
AuthResult unPwdCheck(const std::string &username, const std::string &password);
void setQuitting();
+ void loadMosquittoPasswordFile();
+ AuthResult unPwdCheckFromMosquittoPasswordFile(const std::string &username, const std::string &password);
};
diff --git a/configfileparser.cpp b/configfileparser.cpp
index 8d36fcc..a7d35f4 100644
--- a/configfileparser.cpp
+++ b/configfileparser.cpp
@@ -85,6 +85,8 @@ ConfigFileParser::ConfigFileParser(const std::string &path) :
validKeys.insert("max_packet_size");
validKeys.insert("log_debug");
validKeys.insert("log_subscriptions");
+ validKeys.insert("mosquitto_password_file");
+ validKeys.insert("allow_anonymous");
validListenKeys.insert("port");
validListenKeys.insert("protocol");
@@ -334,6 +336,17 @@ void ConfigFileParser::loadFile(bool test)
bool tmp = stringTruthiness(value);
tmpSettings->logSubscriptions = tmp;
}
+
+ if (key == "mosquitto_password_file")
+ {
+ tmpSettings->mosquittoPasswordFile = value;
+ }
+
+ if (key == "allow_anonymous")
+ {
+ bool tmp = stringTruthiness(value);
+ tmpSettings->allowAnonymous = tmp;
+ }
}
}
catch (std::invalid_argument &ex) // catch for the stoi()
diff --git a/evpencodectxmanager.cpp b/evpencodectxmanager.cpp
new file mode 100644
index 0000000..0517717
--- /dev/null
+++ b/evpencodectxmanager.cpp
@@ -0,0 +1,36 @@
+/*
+This file is part of FlashMQ (https://www.flashmq.org)
+Copyright (C) 2021 Wiebe Cazemier
+
+FlashMQ is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, version 3.
+
+FlashMQ is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public
+License along with FlashMQ. If not, see .
+*/
+
+#include
+
+#include "evpencodectxmanager.h"
+
+EvpEncodeCtxManager::EvpEncodeCtxManager()
+{
+ ctx = EVP_ENCODE_CTX_new();
+
+ if (!ctx)
+ throw std::runtime_error("Error allocating with EVP_ENCODE_CTX_new()");
+
+ EVP_DecodeInit(ctx);
+}
+
+EvpEncodeCtxManager::~EvpEncodeCtxManager()
+{
+ if (ctx)
+ EVP_ENCODE_CTX_free(ctx);
+}
diff --git a/evpencodectxmanager.h b/evpencodectxmanager.h
new file mode 100644
index 0000000..0c20b81
--- /dev/null
+++ b/evpencodectxmanager.h
@@ -0,0 +1,30 @@
+/*
+This file is part of FlashMQ (https://www.flashmq.org)
+Copyright (C) 2021 Wiebe Cazemier
+
+FlashMQ is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, version 3.
+
+FlashMQ is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public
+License along with FlashMQ. If not, see .
+*/
+
+#ifndef EVPENCODECTXMANAGER_H
+#define EVPENCODECTXMANAGER_H
+
+#include "openssl/evp.h"
+
+struct EvpEncodeCtxManager
+{
+ EVP_ENCODE_CTX *ctx = nullptr;
+ EvpEncodeCtxManager();
+ ~EvpEncodeCtxManager();
+};
+
+#endif // EVPENCODECTXMANAGER_H
diff --git a/mainapp.cpp b/mainapp.cpp
index e49975f..5a83228 100644
--- a/mainapp.cpp
+++ b/mainapp.cpp
@@ -186,6 +186,9 @@ MainApp::MainApp(const std::string &configFilePath) :
auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this);
timer.addCallback(fKeepAlive, 30000, "keep-alive check");
+
+ auto fPasswordFileReload = std::bind(&MainApp::queuePasswordFileReloadAllThreads, this);
+ timer.addCallback(fPasswordFileReload, 2000, "Password file reload.");
}
MainApp::~MainApp()
@@ -292,6 +295,14 @@ void MainApp::queueKeepAliveCheckAtAllThreads()
}
}
+void MainApp::queuePasswordFileReloadAllThreads()
+{
+ for (std::shared_ptr &thread : threads)
+ {
+ thread->queuePasswdFileReload();
+ }
+}
+
void MainApp::setFuzzFile(const std::string &fuzzFilePath)
{
this->fuzzFilePath = fuzzFilePath;
diff --git a/mainapp.h b/mainapp.h
index 4464197..6a34da8 100644
--- a/mainapp.h
+++ b/mainapp.h
@@ -73,6 +73,7 @@ class MainApp
std::list createListenSocket(const std::shared_ptr &listener);
void wakeUpThread();
void queueKeepAliveCheckAtAllThreads();
+ void queuePasswordFileReloadAllThreads();
void setFuzzFile(const std::string &fuzzFilePath);
MainApp(const std::string &configFilePath);
diff --git a/mqttpacket.cpp b/mqttpacket.cpp
index 6fde0dd..ffb6146 100644
--- a/mqttpacket.cpp
+++ b/mqttpacket.cpp
@@ -313,7 +313,7 @@ void MqttPacket::handleConnect()
sender->setDisconnectReason("Invalid username character");
accessGranted = false;
}
- else if (sender->getThreadData()->authPlugin.unPwdCheck(username, password) == AuthResult::success)
+ else if (sender->getThreadData()->authentication.unPwdCheck(username, password) == AuthResult::success)
{
accessGranted = true;
}
@@ -462,7 +462,7 @@ void MqttPacket::handlePublish()
sender->writeMqttPacket(response);
}
- if (sender->getThreadData()->authPlugin.aclCheck(sender->getClientId(), sender->getUsername(), topic, AclAccess::write) == AuthResult::success)
+ if (sender->getThreadData()->authentication.aclCheck(sender->getClientId(), sender->getUsername(), topic, AclAccess::write) == AuthResult::success)
{
if (retain)
{
diff --git a/session.cpp b/session.cpp
index c00da97..fa90681 100644
--- a/session.cpp
+++ b/session.cpp
@@ -52,7 +52,7 @@ void Session::writePacket(const MqttPacket &packet, char max_qos)
{
assert(max_qos <= 2);
- if (thread->authPlugin.aclCheck(client_id, username, packet.getTopic(), AclAccess::read) == AuthResult::success)
+ if (thread->authentication.aclCheck(client_id, username, packet.getTopic(), AclAccess::read) == AuthResult::success)
{
const char qos = std::min(packet.getQos(), max_qos);
diff --git a/settings.h b/settings.h
index 76e4255..667f49a 100644
--- a/settings.h
+++ b/settings.h
@@ -42,6 +42,8 @@ public:
int maxPacketSize = 268435461; // 256 MB + 5
bool logDebug = false;
bool logSubscriptions = false;
+ std::string mosquittoPasswordFile;
+ bool allowAnonymous = false;
std::list> listeners; // Default one is created later, when none are defined.
AuthOptCompatWrap &getAuthOptsCompat();
diff --git a/threaddata.cpp b/threaddata.cpp
index eb5d1f6..e4dc307 100644
--- a/threaddata.cpp
+++ b/threaddata.cpp
@@ -22,7 +22,7 @@ License along with FlashMQ. If not, see .
ThreadData::ThreadData(int threadnr, std::shared_ptr &subscriptionStore, std::shared_ptr settings) :
subscriptionStore(subscriptionStore),
settingsLocalCopy(*settings.get()),
- authPlugin(settingsLocalCopy),
+ authentication(settingsLocalCopy),
threadnr(threadnr)
{
logger = Logger::getInstance();
@@ -132,7 +132,7 @@ void ThreadData::queueQuit()
auto f = std::bind(&ThreadData::quit, this);
taskQueue.push_front(f);
- authPlugin.setQuitting();
+ authentication.setQuitting();
wakeUpThread();
}
@@ -142,6 +142,16 @@ void ThreadData::waitForQuit()
thread.join();
}
+void ThreadData::queuePasswdFileReload()
+{
+ std::lock_guard locker(taskQueueMutex);
+
+ auto f = std::bind(&Authentication::loadMosquittoPasswordFile, &authentication);
+ taskQueue.push_front(f);
+
+ wakeUpThread();
+}
+
// TODO: profile how fast hash iteration is. Perhaps having a second list/vector is beneficial?
void ThreadData::doKeepAliveCheck()
{
@@ -179,9 +189,10 @@ void ThreadData::doKeepAliveCheck()
void ThreadData::initAuthPlugin()
{
- authPlugin.loadPlugin(settingsLocalCopy.authPluginPath);
- authPlugin.init();
- authPlugin.securityInit(false);
+ authentication.loadMosquittoPasswordFile();
+ authentication.loadPlugin(settingsLocalCopy.authPluginPath);
+ authentication.init();
+ authentication.securityInit(false);
}
void ThreadData::reload(std::shared_ptr settings)
@@ -193,8 +204,8 @@ void ThreadData::reload(std::shared_ptr settings)
// Because the auth plugin has a reference to it, it will also be updated.
settingsLocalCopy = *settings.get();
- authPlugin.securityCleanup(true);
- authPlugin.securityInit(true);
+ authentication.securityCleanup(true);
+ authentication.securityInit(true);
}
catch (AuthPluginException &ex)
{
diff --git a/threaddata.h b/threaddata.h
index 02693ef..2433a0a 100644
--- a/threaddata.h
+++ b/threaddata.h
@@ -54,7 +54,7 @@ class ThreadData
public:
Settings settingsLocalCopy; // Is updated on reload, within the thread loop.
- AuthPlugin authPlugin;
+ Authentication authentication;
bool running = true;
std::thread thread;
int threadnr = 0;
@@ -80,6 +80,7 @@ public:
void queueDoKeepAliveCheck();
void queueQuit();
void waitForQuit();
+ void queuePasswdFileReload();
};
diff --git a/utils.cpp b/utils.cpp
index 8a1ed76..6625d31 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -21,6 +21,7 @@ License along with FlashMQ. If not, see .
#include "sys/random.h"
#include
#include
+#include
#include "openssl/ssl.h"
#include "openssl/err.h"
@@ -29,6 +30,7 @@ License along with FlashMQ. If not, see .
#include "cirbuf.h"
#include "sslctxmanager.h"
#include "logger.h"
+#include "evpencodectxmanager.h"
std::list split(const std::string &input, const char sep, size_t max, bool keep_empty_parts)
{
@@ -357,6 +359,30 @@ bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_ver
return doubleEmptyLine;
}
+std::vector base64Decode(const std::string &s)
+{
+ if (s.length() % 4 != 0)
+ throw std::runtime_error("Decoding invalid base64 string");
+
+ if (s.empty())
+ throw std::runtime_error("Trying to base64 decode an empty string.");
+
+ std::vector tmp(s.size());
+
+ int outl = 0;
+ int outl_total = 0;
+
+ EvpEncodeCtxManager b64_ctx;
+ if (EVP_DecodeUpdate(b64_ctx.ctx, reinterpret_cast(tmp.data()), &outl, reinterpret_cast(s.c_str()), s.size()) < 0)
+ throw std::runtime_error("Failure in EVP_DecodeUpdate()");
+ outl_total += outl;
+ if (EVP_DecodeFinal(b64_ctx.ctx, reinterpret_cast(tmp[outl_total]), &outl) < 0)
+ throw std::runtime_error("Failure in EVP_DecodeFinal()");
+ std::vector result(outl_total);
+ std::memcpy(result.data(), tmp.data(), outl_total);
+ return result;
+}
+
std::string base64Encode(const unsigned char *input, const int length)
{
const int pl = 4*((length+2)/3);
diff --git a/utils.h b/utils.h
index 8ab3af7..8c4f6f7 100644
--- a/utils.h
+++ b/utils.h
@@ -70,6 +70,7 @@ bool isPowerOfTwo(int val);
bool parseHttpHeader(CirBuf &buf, std::string &websocket_key, int &websocket_version);
+std::vector base64Decode(const std::string &s);
std::string base64Encode(const unsigned char *input, const int length);
std::string generateWebsocketAcceptString(const std::string &websocketKey);