Commit d06dcb0bba521c4992ca07e538bf7048dbc41d41

Authored by Wiebe Cazemier
1 parent cb40c2d2

Add IPv6 support, with related listener options

This also adds configuration options for choosing what address to bind
to.
CMakeLists.txt
@@ -33,6 +33,8 @@ add_executable(FlashMQ @@ -33,6 +33,8 @@ add_executable(FlashMQ
33 settings.cpp 33 settings.cpp
34 listener.cpp 34 listener.cpp
35 unscopedlock.cpp 35 unscopedlock.cpp
  36 + scopedsocket.cpp
  37 + bindaddr.cpp
36 ) 38 )
37 39
38 target_link_libraries(FlashMQ pthread dl ssl crypto) 40 target_link_libraries(FlashMQ pthread dl ssl crypto)
bindaddr.cpp 0 → 100644
  1 +#include "bindaddr.h"
  2 +
  3 +
bindaddr.h 0 → 100644
  1 +#ifndef BINDADDR_H
  2 +#define BINDADDR_H
  3 +
  4 +#include <arpa/inet.h>
  5 +#include <memory>
  6 +
  7 +/**
  8 + * @brief The BindAddr struct helps creating the resource for bind(). It uses an intermediate struct sockaddr to avoid compiler warnings, and
  9 + * this class helps a bit with resource management of it.
  10 + */
  11 +struct BindAddr
  12 +{
  13 + std::unique_ptr<sockaddr> p;
  14 + socklen_t len = 0;
  15 +};
  16 +
  17 +#endif // BINDADDR_H
configfileparser.cpp
@@ -71,6 +71,9 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) : @@ -71,6 +71,9 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) :
71 validListenKeys.insert("protocol"); 71 validListenKeys.insert("protocol");
72 validListenKeys.insert("fullchain"); 72 validListenKeys.insert("fullchain");
73 validListenKeys.insert("privkey"); 73 validListenKeys.insert("privkey");
  74 + validListenKeys.insert("inet_protocol");
  75 + validListenKeys.insert("inet4_bind_address");
  76 + validListenKeys.insert("inet6_bind_address");
74 77
75 settings.reset(new Settings()); 78 settings.reset(new Settings());
76 } 79 }
@@ -93,7 +96,7 @@ void ConfigFileParser::loadFile(bool test) @@ -93,7 +96,7 @@ void ConfigFileParser::loadFile(bool test)
93 96
94 std::list<std::string> lines; 97 std::list<std::string> lines;
95 98
96 - const std::regex key_value_regex("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.]+)$"); 99 + const std::regex key_value_regex("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.:]+)$");
97 const std::regex block_regex_start("^([a-zA-Z0-9_\\-]+) *\\{$"); 100 const std::regex block_regex_start("^([a-zA-Z0-9_\\-]+) *\\{$");
98 const std::regex block_regex_end("^\\}$"); 101 const std::regex block_regex_end("^\\}$");
99 102
@@ -211,6 +214,25 @@ void ConfigFileParser::loadFile(bool test) @@ -211,6 +214,25 @@ void ConfigFileParser::loadFile(bool test)
211 { 214 {
212 curListener->sslPrivkey = value; 215 curListener->sslPrivkey = value;
213 } 216 }
  217 + if (key == "inet_protocol")
  218 + {
  219 + if (value == "ip4")
  220 + curListener->protocol = ListenerProtocol::IPv4;
  221 + else if (value == "ip6")
  222 + curListener->protocol = ListenerProtocol::IPv6;
  223 + else if (value == "ip4_ip6")
  224 + curListener->protocol = ListenerProtocol::IPv46;
  225 + else
  226 + throw ConfigFileException(formatString("Invalid inet protocol: %s", value.c_str()));
  227 + }
  228 + if (key == "inet4_bind_address")
  229 + {
  230 + curListener->inet4BindAddress = value;
  231 + }
  232 + if (key == "inet6_bind_address")
  233 + {
  234 + curListener->inet6BindAddress = value;
  235 + }
214 236
215 continue; 237 continue;
216 } 238 }
listener.cpp
@@ -76,3 +76,20 @@ void Listener::loadCertAndKeyFromConfig() @@ -76,3 +76,20 @@ void Listener::loadCertAndKeyFromConfig()
76 if (SSL_CTX_use_PrivateKey_file(sslctx->get(), sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1) 76 if (SSL_CTX_use_PrivateKey_file(sslctx->get(), sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1)
77 throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected."); 77 throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected.");
78 } 78 }
  79 +
  80 +std::string Listener::getBindAddress(ListenerProtocol p)
  81 +{
  82 + if (p == ListenerProtocol::IPv4)
  83 + {
  84 + if (inet4BindAddress.empty())
  85 + return "0.0.0.0";
  86 + return inet4BindAddress;
  87 + }
  88 + if (p == ListenerProtocol::IPv6)
  89 + {
  90 + if (inet6BindAddress.empty())
  91 + return "::";
  92 + return inet6BindAddress;
  93 + }
  94 + return "";
  95 +}
listener.h
@@ -6,8 +6,18 @@ @@ -6,8 +6,18 @@
6 6
7 #include "sslctxmanager.h" 7 #include "sslctxmanager.h"
8 8
  9 +enum class ListenerProtocol
  10 +{
  11 + IPv46,
  12 + IPv4,
  13 + IPv6
  14 +};
  15 +
9 struct Listener 16 struct Listener
10 { 17 {
  18 + ListenerProtocol protocol = ListenerProtocol::IPv46;
  19 + std::string inet4BindAddress;
  20 + std::string inet6BindAddress;
11 int port = 0; 21 int port = 0;
12 bool websocket = false; 22 bool websocket = false;
13 std::string sslFullchain; 23 std::string sslFullchain;
@@ -18,5 +28,7 @@ struct Listener @@ -18,5 +28,7 @@ struct Listener
18 bool isSsl() const; 28 bool isSsl() const;
19 std::string getProtocolName() const; 29 std::string getProtocolName() const;
20 void loadCertAndKeyFromConfig(); 30 void loadCertAndKeyFromConfig();
  31 +
  32 + std::string getBindAddress(ListenerProtocol p);
21 }; 33 };
22 #endif // LISTENER_H 34 #endif // LISTENER_H
mainapp.cpp
@@ -5,6 +5,7 @@ @@ -5,6 +5,7 @@
5 #include <unistd.h> 5 #include <unistd.h>
6 #include <stdio.h> 6 #include <stdio.h>
7 #include <sys/sysinfo.h> 7 #include <sys/sysinfo.h>
  8 +#include <arpa/inet.h>
8 9
9 #include <openssl/ssl.h> 10 #include <openssl/ssl.h>
10 #include <openssl/err.h> 11 #include <openssl/err.h>
@@ -205,47 +206,59 @@ void MainApp::showLicense() @@ -205,47 +206,59 @@ void MainApp::showLicense()
205 puts("Author: Wiebe Cazemier <wiebe@halfgaar.net>"); 206 puts("Author: Wiebe Cazemier <wiebe@halfgaar.net>");
206 } 207 }
207 208
208 -int MainApp::createListenSocket(const std::shared_ptr<Listener> &listener) 209 +std::list<ScopedSocket> MainApp::createListenSocket(const std::shared_ptr<Listener> &listener)
209 { 210 {
210 - if (listener->port <= 0)  
211 - return -2; 211 + std::list<ScopedSocket> result;
212 212
213 - logger->logf(LOG_NOTICE, "Creating %s listener on port %d", listener->getProtocolName().c_str(), listener->port); 213 + if (listener->port <= 0)
  214 + return result;
214 215
215 - try 216 + for (ListenerProtocol p : std::list<ListenerProtocol>({ ListenerProtocol::IPv4, ListenerProtocol::IPv6}))
216 { 217 {
217 - int listen_fd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0)); 218 + std::string pname = p == ListenerProtocol::IPv4 ? "IPv4" : "IPv6";
  219 + int family = p == ListenerProtocol::IPv4 ? AF_INET : AF_INET6;
218 220
219 - // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.  
220 - int optval = 1;  
221 - check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); 221 + if (!(listener->protocol == ListenerProtocol::IPv46 || listener->protocol == p))
  222 + continue;
222 223
223 - int flags = fcntl(listen_fd, F_GETFL);  
224 - check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK )); 224 + try
  225 + {
  226 + logger->logf(LOG_NOTICE, "Creating %s %s listener on [%s]:%d", pname.c_str(), listener->getProtocolName().c_str(),
  227 + listener->getBindAddress(p).c_str(), listener->port);
225 228
226 - struct sockaddr_in in_addr_plain;  
227 - in_addr_plain.sin_family = AF_INET;  
228 - in_addr_plain.sin_addr.s_addr = INADDR_ANY;  
229 - in_addr_plain.sin_port = htons(listener->port); 229 + BindAddr bindAddr = getBindAddr(family, listener->getBindAddress(p), listener->port);
230 230
231 - check<std::runtime_error>(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in)));  
232 - check<std::runtime_error>(listen(listen_fd, 1024)); 231 + int listen_fd = check<std::runtime_error>(socket(family, SOCK_STREAM, 0));
233 232
234 - struct epoll_event ev;  
235 - memset(&ev, 0, sizeof (struct epoll_event)); 233 + // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.
  234 + int optval = 1;
  235 + check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));
236 236
237 - ev.data.fd = listen_fd;  
238 - ev.events = EPOLLIN;  
239 - check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev)); 237 + int flags = fcntl(listen_fd, F_GETFL);
  238 + check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));
240 239
241 - return listen_fd;  
242 - }  
243 - catch (std::exception &ex)  
244 - {  
245 - logger->logf(LOG_NOTICE, "Creating %s listener on port %d failed: %s", listener->getProtocolName().c_str(), listener->port, ex.what());  
246 - return -1; 240 + check<std::runtime_error>(bind(listen_fd, bindAddr.p.get(), bindAddr.len));
  241 + check<std::runtime_error>(listen(listen_fd, 1024));
  242 +
  243 + struct epoll_event ev;
  244 + memset(&ev, 0, sizeof (struct epoll_event));
  245 +
  246 + ev.data.fd = listen_fd;
  247 + ev.events = EPOLLIN;
  248 + check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev));
  249 +
  250 + result.push_back(ScopedSocket(listen_fd));
  251 +
  252 + }
  253 + catch (std::exception &ex)
  254 + {
  255 + logger->logf(LOG_ERR, "Creating %s %s listener on [%s]:%d failed: %s", pname.c_str(), listener->getProtocolName().c_str(),
  256 + listener->getBindAddress(p).c_str(), listener->port, ex.what());
  257 + return std::list<ScopedSocket>();
  258 + }
247 } 259 }
248 - return -1; 260 +
  261 + return result;
249 } 262 }
250 263
251 void MainApp::wakeUpThread() 264 void MainApp::wakeUpThread()
@@ -367,13 +380,19 @@ void MainApp::start() @@ -367,13 +380,19 @@ void MainApp::start()
367 { 380 {
368 timer.start(); 381 timer.start();
369 382
370 - std::map<int, std::shared_ptr<Listener>> listenerMap; 383 + std::map<int, std::shared_ptr<Listener>> listenerMap; // For finding listeners by fd.
  384 + std::list<ScopedSocket> activeListenSockets; // For RAII/ownership
371 385
372 for(std::shared_ptr<Listener> &listener : this->listeners) 386 for(std::shared_ptr<Listener> &listener : this->listeners)
373 { 387 {
374 - int fd = createListenSocket(listener);  
375 - if (fd > 0)  
376 - listenerMap[fd] = listener; 388 + std::list<ScopedSocket> scopedSockets = createListenSocket(listener);
  389 +
  390 + for (ScopedSocket &scopedSocket : scopedSockets)
  391 + {
  392 + if (scopedSocket.socket > 0)
  393 + listenerMap[scopedSocket.socket] = listener;
  394 + activeListenSockets.push_back(std::move(scopedSocket));
  395 + }
377 } 396 }
378 397
379 #ifdef NDEBUG 398 #ifdef NDEBUG
@@ -506,11 +525,6 @@ void MainApp::start() @@ -506,11 +525,6 @@ void MainApp::start()
506 { 525 {
507 thread->waitForQuit(); 526 thread->waitForQuit();
508 } 527 }
509 -  
510 - for(auto pair : listenerMap)  
511 - {  
512 - close(pair.first);  
513 - }  
514 } 528 }
515 529
516 void MainApp::quit() 530 void MainApp::quit()
@@ -596,3 +610,4 @@ void MainApp::queueCleanup() @@ -596,3 +610,4 @@ void MainApp::queueCleanup()
596 610
597 wakeUpThread(); 611 wakeUpThread();
598 } 612 }
  613 +
mainapp.h
@@ -10,6 +10,7 @@ @@ -10,6 +10,7 @@
10 #include <vector> 10 #include <vector>
11 #include <functional> 11 #include <functional>
12 #include <forward_list> 12 #include <forward_list>
  13 +#include <list>
13 14
14 #include "forward_declarations.h" 15 #include "forward_declarations.h"
15 16
@@ -20,6 +21,7 @@ @@ -20,6 +21,7 @@
20 #include "subscriptionstore.h" 21 #include "subscriptionstore.h"
21 #include "configfileparser.h" 22 #include "configfileparser.h"
22 #include "timer.h" 23 #include "timer.h"
  24 +#include "scopedsocket.h"
23 25
24 class MainApp 26 class MainApp
25 { 27 {
@@ -48,7 +50,7 @@ class MainApp @@ -48,7 +50,7 @@ class MainApp
48 void reloadConfig(); 50 void reloadConfig();
49 static void doHelp(const char *arg); 51 static void doHelp(const char *arg);
50 static void showLicense(); 52 static void showLicense();
51 - int createListenSocket(const std::shared_ptr<Listener> &listener); 53 + std::list<ScopedSocket> createListenSocket(const std::shared_ptr<Listener> &listener);
52 void wakeUpThread(); 54 void wakeUpThread();
53 void queueKeepAliveCheckAtAllThreads(); 55 void queueKeepAliveCheckAtAllThreads();
54 void setFuzzFile(const std::string &fuzzFilePath); 56 void setFuzzFile(const std::string &fuzzFilePath);
scopedsocket.cpp 0 → 100644
  1 +#include "scopedsocket.h"
  2 +
  3 +ScopedSocket::ScopedSocket(int socket) : socket(socket)
  4 +{
  5 +
  6 +}
  7 +
  8 +ScopedSocket::ScopedSocket(ScopedSocket &&other)
  9 +{
  10 + this->socket = other.socket;
  11 + other.socket = 0;
  12 +}
  13 +
  14 +ScopedSocket::~ScopedSocket()
  15 +{
  16 + if (socket > 0)
  17 + close(socket);
  18 +}
scopedsocket.h 0 → 100644
  1 +#ifndef SCOPEDSOCKET_H
  2 +#define SCOPEDSOCKET_H
  3 +
  4 +#include <fcntl.h>
  5 +#include <unistd.h>
  6 +
  7 +/**
  8 + * @brief The ScopedSocket struct allows for a bit of RAII and move semantics on a socket fd.
  9 + */
  10 +struct ScopedSocket
  11 +{
  12 + int socket = 0;
  13 + ScopedSocket(int socket);
  14 + ScopedSocket(ScopedSocket &&other);
  15 + ~ScopedSocket();
  16 +};
  17 +
  18 +#endif // SCOPEDSOCKET_H
utils.cpp
@@ -427,3 +427,41 @@ std::string dirnameOf(const std::string&amp; path) @@ -427,3 +427,41 @@ std::string dirnameOf(const std::string&amp; path)
427 return (std::string::npos == pos) ? "" : path.substr(0, pos); 427 return (std::string::npos == pos) ? "" : path.substr(0, pos);
428 } 428 }
429 429
  430 +
  431 +BindAddr getBindAddr(int family, const std::string &bindAddress, int port)
  432 +{
  433 + BindAddr result;
  434 +
  435 + if (family == AF_INET)
  436 + {
  437 + struct sockaddr_in *in_addr_v4 = new sockaddr_in();
  438 + result.len = sizeof(struct sockaddr_in);
  439 + memset(in_addr_v4, 0, result.len);
  440 +
  441 + if (bindAddress.empty())
  442 + in_addr_v4->sin_addr.s_addr = INADDR_ANY;
  443 + else
  444 + inet_pton(AF_INET, bindAddress.c_str(), &in_addr_v4->sin_addr);
  445 +
  446 + in_addr_v4->sin_family = AF_INET;
  447 + in_addr_v4->sin_port = htons(port);
  448 + result.p.reset(reinterpret_cast<sockaddr*>(in_addr_v4));
  449 + }
  450 + if (family == AF_INET6)
  451 + {
  452 + struct sockaddr_in6 *in_addr_v6 = new sockaddr_in6();
  453 + result.len = sizeof(struct sockaddr_in6);
  454 + memset(in_addr_v6, 0, result.len);
  455 +
  456 + if (bindAddress.empty())
  457 + in_addr_v6->sin6_addr = IN6ADDR_ANY_INIT;
  458 + else
  459 + inet_pton(AF_INET6, bindAddress.c_str(), &in_addr_v6->sin6_addr);
  460 +
  461 + in_addr_v6->sin6_family = AF_INET6;
  462 + in_addr_v6->sin6_port = htons(port);
  463 + result.p.reset(reinterpret_cast<sockaddr*>(in_addr_v6));
  464 + }
  465 +
  466 + return result;
  467 +}
@@ -9,8 +9,11 @@ @@ -9,8 +9,11 @@
9 #include <vector> 9 #include <vector>
10 #include <algorithm> 10 #include <algorithm>
11 #include <openssl/evp.h> 11 #include <openssl/evp.h>
  12 +#include <memory>
  13 +#include <arpa/inet.h>
12 14
13 #include "cirbuf.h" 15 #include "cirbuf.h"
  16 +#include "bindaddr.h"
14 17
15 template<typename T> int check(int rc) 18 template<typename T> int check(int rc)
16 { 19 {
@@ -62,5 +65,7 @@ std::string formatString(const std::string str, ...); @@ -62,5 +65,7 @@ std::string formatString(const std::string str, ...);
62 65
63 std::string dirnameOf(const std::string& fname); 66 std::string dirnameOf(const std::string& fname);
64 67
  68 +BindAddr getBindAddr(int family, const std::string &bindAddress, int port);
  69 +
65 70
66 #endif // UTILS_H 71 #endif // UTILS_H