You need to sign in before continuing.

Commit d06dcb0bba521c4992ca07e538bf7048dbc41d41

Authored by Wiebe Cazemier
1 parent cb40c2d2

Add IPv6 support, with related listener options

This also adds configuration options for choosing what address to bind
to.
CMakeLists.txt
... ... @@ -33,6 +33,8 @@ add_executable(FlashMQ
33 33 settings.cpp
34 34 listener.cpp
35 35 unscopedlock.cpp
  36 + scopedsocket.cpp
  37 + bindaddr.cpp
36 38 )
37 39  
38 40 target_link_libraries(FlashMQ pthread dl ssl crypto)
... ...
bindaddr.cpp 0 → 100644
  1 +#include "bindaddr.h"
  2 +
  3 +
... ...
bindaddr.h 0 → 100644
  1 +#ifndef BINDADDR_H
  2 +#define BINDADDR_H
  3 +
  4 +#include <arpa/inet.h>
  5 +#include <memory>
  6 +
  7 +/**
  8 + * @brief The BindAddr struct helps creating the resource for bind(). It uses an intermediate struct sockaddr to avoid compiler warnings, and
  9 + * this class helps a bit with resource management of it.
  10 + */
  11 +struct BindAddr
  12 +{
  13 + std::unique_ptr<sockaddr> p;
  14 + socklen_t len = 0;
  15 +};
  16 +
  17 +#endif // BINDADDR_H
... ...
configfileparser.cpp
... ... @@ -71,6 +71,9 @@ ConfigFileParser::ConfigFileParser(const std::string &amp;path) :
71 71 validListenKeys.insert("protocol");
72 72 validListenKeys.insert("fullchain");
73 73 validListenKeys.insert("privkey");
  74 + validListenKeys.insert("inet_protocol");
  75 + validListenKeys.insert("inet4_bind_address");
  76 + validListenKeys.insert("inet6_bind_address");
74 77  
75 78 settings.reset(new Settings());
76 79 }
... ... @@ -93,7 +96,7 @@ void ConfigFileParser::loadFile(bool test)
93 96  
94 97 std::list<std::string> lines;
95 98  
96   - const std::regex key_value_regex("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.]+)$");
  99 + const std::regex key_value_regex("^([a-zA-Z0-9_\\-]+) +([a-zA-Z0-9_\\-/\\.:]+)$");
97 100 const std::regex block_regex_start("^([a-zA-Z0-9_\\-]+) *\\{$");
98 101 const std::regex block_regex_end("^\\}$");
99 102  
... ... @@ -211,6 +214,25 @@ void ConfigFileParser::loadFile(bool test)
211 214 {
212 215 curListener->sslPrivkey = value;
213 216 }
  217 + if (key == "inet_protocol")
  218 + {
  219 + if (value == "ip4")
  220 + curListener->protocol = ListenerProtocol::IPv4;
  221 + else if (value == "ip6")
  222 + curListener->protocol = ListenerProtocol::IPv6;
  223 + else if (value == "ip4_ip6")
  224 + curListener->protocol = ListenerProtocol::IPv46;
  225 + else
  226 + throw ConfigFileException(formatString("Invalid inet protocol: %s", value.c_str()));
  227 + }
  228 + if (key == "inet4_bind_address")
  229 + {
  230 + curListener->inet4BindAddress = value;
  231 + }
  232 + if (key == "inet6_bind_address")
  233 + {
  234 + curListener->inet6BindAddress = value;
  235 + }
214 236  
215 237 continue;
216 238 }
... ...
listener.cpp
... ... @@ -76,3 +76,20 @@ void Listener::loadCertAndKeyFromConfig()
76 76 if (SSL_CTX_use_PrivateKey_file(sslctx->get(), sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1)
77 77 throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected.");
78 78 }
  79 +
  80 +std::string Listener::getBindAddress(ListenerProtocol p)
  81 +{
  82 + if (p == ListenerProtocol::IPv4)
  83 + {
  84 + if (inet4BindAddress.empty())
  85 + return "0.0.0.0";
  86 + return inet4BindAddress;
  87 + }
  88 + if (p == ListenerProtocol::IPv6)
  89 + {
  90 + if (inet6BindAddress.empty())
  91 + return "::";
  92 + return inet6BindAddress;
  93 + }
  94 + return "";
  95 +}
... ...
listener.h
... ... @@ -6,8 +6,18 @@
6 6  
7 7 #include "sslctxmanager.h"
8 8  
  9 +enum class ListenerProtocol
  10 +{
  11 + IPv46,
  12 + IPv4,
  13 + IPv6
  14 +};
  15 +
9 16 struct Listener
10 17 {
  18 + ListenerProtocol protocol = ListenerProtocol::IPv46;
  19 + std::string inet4BindAddress;
  20 + std::string inet6BindAddress;
11 21 int port = 0;
12 22 bool websocket = false;
13 23 std::string sslFullchain;
... ... @@ -18,5 +28,7 @@ struct Listener
18 28 bool isSsl() const;
19 29 std::string getProtocolName() const;
20 30 void loadCertAndKeyFromConfig();
  31 +
  32 + std::string getBindAddress(ListenerProtocol p);
21 33 };
22 34 #endif // LISTENER_H
... ...
mainapp.cpp
... ... @@ -5,6 +5,7 @@
5 5 #include <unistd.h>
6 6 #include <stdio.h>
7 7 #include <sys/sysinfo.h>
  8 +#include <arpa/inet.h>
8 9  
9 10 #include <openssl/ssl.h>
10 11 #include <openssl/err.h>
... ... @@ -205,47 +206,59 @@ void MainApp::showLicense()
205 206 puts("Author: Wiebe Cazemier <wiebe@halfgaar.net>");
206 207 }
207 208  
208   -int MainApp::createListenSocket(const std::shared_ptr<Listener> &listener)
  209 +std::list<ScopedSocket> MainApp::createListenSocket(const std::shared_ptr<Listener> &listener)
209 210 {
210   - if (listener->port <= 0)
211   - return -2;
  211 + std::list<ScopedSocket> result;
212 212  
213   - logger->logf(LOG_NOTICE, "Creating %s listener on port %d", listener->getProtocolName().c_str(), listener->port);
  213 + if (listener->port <= 0)
  214 + return result;
214 215  
215   - try
  216 + for (ListenerProtocol p : std::list<ListenerProtocol>({ ListenerProtocol::IPv4, ListenerProtocol::IPv6}))
216 217 {
217   - int listen_fd = check<std::runtime_error>(socket(AF_INET, SOCK_STREAM, 0));
  218 + std::string pname = p == ListenerProtocol::IPv4 ? "IPv4" : "IPv6";
  219 + int family = p == ListenerProtocol::IPv4 ? AF_INET : AF_INET6;
218 220  
219   - // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.
220   - int optval = 1;
221   - check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));
  221 + if (!(listener->protocol == ListenerProtocol::IPv46 || listener->protocol == p))
  222 + continue;
222 223  
223   - int flags = fcntl(listen_fd, F_GETFL);
224   - check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));
  224 + try
  225 + {
  226 + logger->logf(LOG_NOTICE, "Creating %s %s listener on [%s]:%d", pname.c_str(), listener->getProtocolName().c_str(),
  227 + listener->getBindAddress(p).c_str(), listener->port);
225 228  
226   - struct sockaddr_in in_addr_plain;
227   - in_addr_plain.sin_family = AF_INET;
228   - in_addr_plain.sin_addr.s_addr = INADDR_ANY;
229   - in_addr_plain.sin_port = htons(listener->port);
  229 + BindAddr bindAddr = getBindAddr(family, listener->getBindAddress(p), listener->port);
230 230  
231   - check<std::runtime_error>(bind(listen_fd, (struct sockaddr *)(&in_addr_plain), sizeof(struct sockaddr_in)));
232   - check<std::runtime_error>(listen(listen_fd, 1024));
  231 + int listen_fd = check<std::runtime_error>(socket(family, SOCK_STREAM, 0));
233 232  
234   - struct epoll_event ev;
235   - memset(&ev, 0, sizeof (struct epoll_event));
  233 + // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT.
  234 + int optval = 1;
  235 + check<std::runtime_error>(setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));
236 236  
237   - ev.data.fd = listen_fd;
238   - ev.events = EPOLLIN;
239   - check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev));
  237 + int flags = fcntl(listen_fd, F_GETFL);
  238 + check<std::runtime_error>(fcntl(listen_fd, F_SETFL, flags | O_NONBLOCK ));
240 239  
241   - return listen_fd;
242   - }
243   - catch (std::exception &ex)
244   - {
245   - logger->logf(LOG_NOTICE, "Creating %s listener on port %d failed: %s", listener->getProtocolName().c_str(), listener->port, ex.what());
246   - return -1;
  240 + check<std::runtime_error>(bind(listen_fd, bindAddr.p.get(), bindAddr.len));
  241 + check<std::runtime_error>(listen(listen_fd, 1024));
  242 +
  243 + struct epoll_event ev;
  244 + memset(&ev, 0, sizeof (struct epoll_event));
  245 +
  246 + ev.data.fd = listen_fd;
  247 + ev.events = EPOLLIN;
  248 + check<std::runtime_error>(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, listen_fd, &ev));
  249 +
  250 + result.push_back(ScopedSocket(listen_fd));
  251 +
  252 + }
  253 + catch (std::exception &ex)
  254 + {
  255 + logger->logf(LOG_ERR, "Creating %s %s listener on [%s]:%d failed: %s", pname.c_str(), listener->getProtocolName().c_str(),
  256 + listener->getBindAddress(p).c_str(), listener->port, ex.what());
  257 + return std::list<ScopedSocket>();
  258 + }
247 259 }
248   - return -1;
  260 +
  261 + return result;
249 262 }
250 263  
251 264 void MainApp::wakeUpThread()
... ... @@ -367,13 +380,19 @@ void MainApp::start()
367 380 {
368 381 timer.start();
369 382  
370   - std::map<int, std::shared_ptr<Listener>> listenerMap;
  383 + std::map<int, std::shared_ptr<Listener>> listenerMap; // For finding listeners by fd.
  384 + std::list<ScopedSocket> activeListenSockets; // For RAII/ownership
371 385  
372 386 for(std::shared_ptr<Listener> &listener : this->listeners)
373 387 {
374   - int fd = createListenSocket(listener);
375   - if (fd > 0)
376   - listenerMap[fd] = listener;
  388 + std::list<ScopedSocket> scopedSockets = createListenSocket(listener);
  389 +
  390 + for (ScopedSocket &scopedSocket : scopedSockets)
  391 + {
  392 + if (scopedSocket.socket > 0)
  393 + listenerMap[scopedSocket.socket] = listener;
  394 + activeListenSockets.push_back(std::move(scopedSocket));
  395 + }
377 396 }
378 397  
379 398 #ifdef NDEBUG
... ... @@ -506,11 +525,6 @@ void MainApp::start()
506 525 {
507 526 thread->waitForQuit();
508 527 }
509   -
510   - for(auto pair : listenerMap)
511   - {
512   - close(pair.first);
513   - }
514 528 }
515 529  
516 530 void MainApp::quit()
... ... @@ -596,3 +610,4 @@ void MainApp::queueCleanup()
596 610  
597 611 wakeUpThread();
598 612 }
  613 +
... ...
mainapp.h
... ... @@ -10,6 +10,7 @@
10 10 #include <vector>
11 11 #include <functional>
12 12 #include <forward_list>
  13 +#include <list>
13 14  
14 15 #include "forward_declarations.h"
15 16  
... ... @@ -20,6 +21,7 @@
20 21 #include "subscriptionstore.h"
21 22 #include "configfileparser.h"
22 23 #include "timer.h"
  24 +#include "scopedsocket.h"
23 25  
24 26 class MainApp
25 27 {
... ... @@ -48,7 +50,7 @@ class MainApp
48 50 void reloadConfig();
49 51 static void doHelp(const char *arg);
50 52 static void showLicense();
51   - int createListenSocket(const std::shared_ptr<Listener> &listener);
  53 + std::list<ScopedSocket> createListenSocket(const std::shared_ptr<Listener> &listener);
52 54 void wakeUpThread();
53 55 void queueKeepAliveCheckAtAllThreads();
54 56 void setFuzzFile(const std::string &fuzzFilePath);
... ...
scopedsocket.cpp 0 → 100644
  1 +#include "scopedsocket.h"
  2 +
  3 +ScopedSocket::ScopedSocket(int socket) : socket(socket)
  4 +{
  5 +
  6 +}
  7 +
  8 +ScopedSocket::ScopedSocket(ScopedSocket &&other)
  9 +{
  10 + this->socket = other.socket;
  11 + other.socket = 0;
  12 +}
  13 +
  14 +ScopedSocket::~ScopedSocket()
  15 +{
  16 + if (socket > 0)
  17 + close(socket);
  18 +}
... ...
scopedsocket.h 0 → 100644
  1 +#ifndef SCOPEDSOCKET_H
  2 +#define SCOPEDSOCKET_H
  3 +
  4 +#include <fcntl.h>
  5 +#include <unistd.h>
  6 +
  7 +/**
  8 + * @brief The ScopedSocket struct allows for a bit of RAII and move semantics on a socket fd.
  9 + */
  10 +struct ScopedSocket
  11 +{
  12 + int socket = 0;
  13 + ScopedSocket(int socket);
  14 + ScopedSocket(ScopedSocket &&other);
  15 + ~ScopedSocket();
  16 +};
  17 +
  18 +#endif // SCOPEDSOCKET_H
... ...
utils.cpp
... ... @@ -427,3 +427,41 @@ std::string dirnameOf(const std::string&amp; path)
427 427 return (std::string::npos == pos) ? "" : path.substr(0, pos);
428 428 }
429 429  
  430 +
  431 +BindAddr getBindAddr(int family, const std::string &bindAddress, int port)
  432 +{
  433 + BindAddr result;
  434 +
  435 + if (family == AF_INET)
  436 + {
  437 + struct sockaddr_in *in_addr_v4 = new sockaddr_in();
  438 + result.len = sizeof(struct sockaddr_in);
  439 + memset(in_addr_v4, 0, result.len);
  440 +
  441 + if (bindAddress.empty())
  442 + in_addr_v4->sin_addr.s_addr = INADDR_ANY;
  443 + else
  444 + inet_pton(AF_INET, bindAddress.c_str(), &in_addr_v4->sin_addr);
  445 +
  446 + in_addr_v4->sin_family = AF_INET;
  447 + in_addr_v4->sin_port = htons(port);
  448 + result.p.reset(reinterpret_cast<sockaddr*>(in_addr_v4));
  449 + }
  450 + if (family == AF_INET6)
  451 + {
  452 + struct sockaddr_in6 *in_addr_v6 = new sockaddr_in6();
  453 + result.len = sizeof(struct sockaddr_in6);
  454 + memset(in_addr_v6, 0, result.len);
  455 +
  456 + if (bindAddress.empty())
  457 + in_addr_v6->sin6_addr = IN6ADDR_ANY_INIT;
  458 + else
  459 + inet_pton(AF_INET6, bindAddress.c_str(), &in_addr_v6->sin6_addr);
  460 +
  461 + in_addr_v6->sin6_family = AF_INET6;
  462 + in_addr_v6->sin6_port = htons(port);
  463 + result.p.reset(reinterpret_cast<sockaddr*>(in_addr_v6));
  464 + }
  465 +
  466 + return result;
  467 +}
... ...
... ... @@ -9,8 +9,11 @@
9 9 #include <vector>
10 10 #include <algorithm>
11 11 #include <openssl/evp.h>
  12 +#include <memory>
  13 +#include <arpa/inet.h>
12 14  
13 15 #include "cirbuf.h"
  16 +#include "bindaddr.h"
14 17  
15 18 template<typename T> int check(int rc)
16 19 {
... ... @@ -62,5 +65,7 @@ std::string formatString(const std::string str, ...);
62 65  
63 66 std::string dirnameOf(const std::string& fname);
64 67  
  68 +BindAddr getBindAddr(int family, const std::string &bindAddress, int port);
  69 +
65 70  
66 71 #endif // UTILS_H
... ...