diff --git a/client.cpp b/client.cpp index d2f994b..0f539d6 100644 --- a/client.cpp +++ b/client.cpp @@ -7,11 +7,13 @@ #include "logger.h" -Client::Client(int fd, ThreadData_p threadData, SSL *ssl) : +Client::Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings) : fd(fd), ssl(ssl), - readbuf(CLIENT_BUFFER_SIZE), - writebuf(CLIENT_BUFFER_SIZE), + initialBufferSize(settings.clientInitialBufferSize), // The client is constructed in the main thread, so we need to use its settings copy + maxPacketSize(settings.maxPacketSize), // Same as initialBufferSize comment. + readbuf(initialBufferSize), + writebuf(initialBufferSize), threadData(threadData) { int flags = fcntl(fd, F_GETFL); @@ -206,7 +208,7 @@ bool Client::readFdIntoBuffer() // Make sure we either always have enough space for a next call of this method, or stop reading the fd. if (readbuf.freeSpace() == 0) { - if (readbuf.getSize() * 2 < MAX_PACKET_SIZE) + if (readbuf.getSize() * 2 < maxPacketSize) { readbuf.doubleSize(); } @@ -236,7 +238,7 @@ void Client::writeMqttPacket(const MqttPacket &packet) // We have to allow big packets, yet don't allow a slow loris subscriber to grow huge write buffers. This // could be enhanced a lot, but it's a start. - const uint32_t growBufMaxTo = std::min(packet.getSizeIncludingNonPresentHeader() * 1000, MAX_PACKET_SIZE); + const uint32_t growBufMaxTo = std::min(packet.getSizeIncludingNonPresentHeader() * 1000, maxPacketSize); // Grow as far as we can. We have to make room for one MQTT packet. while (packet.getSizeIncludingNonPresentHeader() > writebuf.freeSpace() && writebuf.getSize() < growBufMaxTo) @@ -439,13 +441,13 @@ bool Client::writeBufIntoFd() if (!bufferHasData) { writeBufIsZeroCount++; - bool doReset = (writeBufIsZeroCount >= 10 && writebuf.getSize() > (MAX_PACKET_SIZE / 10) && writebuf.bufferLastResizedSecondsAgo() > 30); + bool doReset = (writeBufIsZeroCount >= 10 && writebuf.getSize() > (maxPacketSize / 10) && writebuf.bufferLastResizedSecondsAgo() > 30); doReset |= (writeBufIsZeroCount >= 100 && writebuf.bufferLastResizedSecondsAgo() > 300); if (doReset) { writeBufIsZeroCount = 0; - writebuf.resetSize(CLIENT_BUFFER_SIZE); + writebuf.resetSize(initialBufferSize); } } @@ -564,13 +566,13 @@ bool Client::bufferToMqttPackets(std::vector &packetQueueIn, Client_ if (readbuf.usedBytes() == 0) { readBufIsZeroCount++; - bool doReset = (readBufIsZeroCount >= 10 && readbuf.getSize() > (MAX_PACKET_SIZE / 10) && readbuf.bufferLastResizedSecondsAgo() > 30); + bool doReset = (readBufIsZeroCount >= 10 && readbuf.getSize() > (maxPacketSize / 10) && readbuf.bufferLastResizedSecondsAgo() > 30); doReset |= (readBufIsZeroCount >= 100 && readbuf.bufferLastResizedSecondsAgo() > 300); if (doReset) { readBufIsZeroCount = 0; - readbuf.resetSize(CLIENT_BUFFER_SIZE); + readbuf.resetSize(initialBufferSize); } } diff --git a/client.h b/client.h index 3c801e9..5c06253 100644 --- a/client.h +++ b/client.h @@ -21,8 +21,6 @@ #include #include -#define CLIENT_BUFFER_SIZE 1024 // Must be power of 2 -#define MAX_PACKET_SIZE 268435461 // 256 MB + 5 #define MQTT_HEADER_LENGH 2 #define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256. @@ -64,6 +62,9 @@ class Client bool sslWriteWantsRead = false; ProtocolVersion protocolVersion = ProtocolVersion::None; + const size_t initialBufferSize = 0; + const size_t maxPacketSize = 0; + CirBuf readbuf; uint8_t readBufIsZeroCount = 0; @@ -101,7 +102,7 @@ class Client void setReadyForReading(bool val); public: - Client(int fd, ThreadData_p threadData, SSL *ssl); + Client(int fd, ThreadData_p threadData, SSL *ssl, const GlobalSettings &settings); Client(const Client &other) = delete; Client(Client &&other) = delete; ~Client(); diff --git a/configfileparser.cpp b/configfileparser.cpp index f6acf16..168bc80 100644 --- a/configfileparser.cpp +++ b/configfileparser.cpp @@ -99,6 +99,8 @@ ConfigFileParser::ConfigFileParser(const std::string &path) : validKeys.insert("fullchain"); validKeys.insert("privkey"); validKeys.insert("allow_unsafe_clientid_chars"); + validKeys.insert("client_initial_buffer_size"); + validKeys.insert("max_packet_size"); } void ConfigFileParser::loadFile(bool test) @@ -232,8 +234,30 @@ void ConfigFileParser::loadFile(bool test) sslListenPort = sslListenPortNew; } + if (key == "client_initial_buffer_size") + { + int newVal = std::stoi(value); + if (!isPowerOfTwo(newVal)) + throw ConfigFileException("client_initial_buffer_size value " + value + " is not a power of two."); + if (!test) + clientInitialBufferSize = newVal; + } + + if (key == "max_packet_size") + { + int newVal = std::stoi(value); + if (newVal > ABSOLUTE_MAX_PACKET_SIZE) + { + std::ostringstream oss; + oss << "Value for max_packet_size " << newVal << "is higher than absolute maximum " << ABSOLUTE_MAX_PACKET_SIZE; + throw ConfigFileException(oss.str()); + } + if (!test) + maxPacketSize = newVal; + } + } - catch (std::invalid_argument &ex) + catch (std::invalid_argument &ex) // catch for the stoi() { throw ConfigFileException(ex.what()); } diff --git a/configfileparser.h b/configfileparser.h index bb4529a..e05ca44 100644 --- a/configfileparser.h +++ b/configfileparser.h @@ -9,6 +9,8 @@ #include "sslctxmanager.h" +#define ABSOLUTE_MAX_PACKET_SIZE 268435461 // 256 MB + 5 + struct mosquitto_auth_opt { char *key = nullptr; @@ -55,6 +57,8 @@ public: uint listenPort = 1883; uint sslListenPort = 0; bool allowUnsafeClientidChars = false; + int clientInitialBufferSize = 1024; // Must be power of 2 + int maxPacketSize = 268435461; // 256 MB + 5 }; #endif // CONFIGFILEPARSER_H diff --git a/globalsettings.h b/globalsettings.h index 11a2437..bf9dbbb 100644 --- a/globalsettings.h +++ b/globalsettings.h @@ -5,5 +5,7 @@ struct GlobalSettings { bool allow_unsafe_clientid_chars = false; + int clientInitialBufferSize = 0; + int maxPacketSize = 0; }; #endif // GLOBALSETTINGS_H diff --git a/mainapp.cpp b/mainapp.cpp index 81f4415..1528a8b 100644 --- a/mainapp.cpp +++ b/mainapp.cpp @@ -417,7 +417,7 @@ void MainApp::start() SSL_set_fd(clientSSL, fd); } - Client_p client(new Client(fd, thread_data, clientSSL)); + Client_p client(new Client(fd, thread_data, clientSSL, settings)); thread_data->giveClient(client); } else if (cur_fd == taskEventFd) @@ -485,6 +485,8 @@ void MainApp::loadConfig() } settings.allow_unsafe_clientid_chars = confFileParser->allowUnsafeClientidChars; + settings.clientInitialBufferSize = confFileParser->clientInitialBufferSize; + settings.maxPacketSize = confFileParser->maxPacketSize; setCertAndKeyFromConfig(); diff --git a/utils.cpp b/utils.cpp index 17be781..d37f855 100644 --- a/utils.cpp +++ b/utils.cpp @@ -221,3 +221,8 @@ bool stringTruthiness(const std::string &val) return false; throw ConfigFileException("Value '" + val + "' can't be converted to boolean"); } + +bool isPowerOfTwo(int n) +{ + return (n != 0) && (n & (n - 1)) == 0; +} diff --git a/utils.h b/utils.h index b1161bc..804c804 100644 --- a/utils.h +++ b/utils.h @@ -41,5 +41,6 @@ int64_t currentMSecsSinceEpoch(); std::string getSecureRandomString(const size_t len); std::string str_tolower(std::string s); bool stringTruthiness(const std::string &val); +bool isPowerOfTwo(int val); #endif // UTILS_H