diff --git a/.gitignore b/.gitignore index 8a9d35c..c9fbc5d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ *.user +build-* diff --git a/CMakeLists.txt b/CMakeLists.txt index 8062466..060dd37 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ add_executable(FlashMQ subscriptionstore.cpp rwlockguard.cpp retainedmessage.cpp + cirbuf.cpp ) target_link_libraries(FlashMQ pthread) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro new file mode 100644 index 0000000..75e26e2 --- /dev/null +++ b/FlashMQTests/FlashMQTests.pro @@ -0,0 +1,18 @@ +QT += testlib +QT -= gui +Qt += network + +DEFINES += TESTING + +INCLUDEPATH += .. + +CONFIG += qt console warn_on depend_includepath testcase +CONFIG -= app_bundle + +TEMPLATE = app + +SOURCES += tst_maintests.cpp \ + ../cirbuf.cpp + +HEADERS += \ + ../cirbuf.h diff --git a/FlashMQTests/tst_maintests.cpp b/FlashMQTests/tst_maintests.cpp new file mode 100644 index 0000000..946167f --- /dev/null +++ b/FlashMQTests/tst_maintests.cpp @@ -0,0 +1,248 @@ +#include + +#include "cirbuf.h" + +class MainTests : public QObject +{ + Q_OBJECT + +public: + MainTests(); + ~MainTests(); + +private slots: + void test_case1(); + void test_circbuf(); + void test_circbuf_unwrapped_doubling(); + void test_circbuf_wrapped_doubling(); + void test_circbuf_full_wrapped_buffer_doubling(); + +}; + +MainTests::MainTests() +{ + +} + +MainTests::~MainTests() +{ + +} + +void MainTests::test_case1() +{ + +} + +void MainTests::test_circbuf() +{ + CirBuf buf(64); + + QCOMPARE(buf.freeSpace(), 63); + + int write_n = 40; + + char *head = buf.headPtr(); + for (int i = 0; i < write_n; i++) + { + head[i] = i+1; + } + + buf.advanceHead(write_n); + + QCOMPARE(buf.head, write_n); + QCOMPARE(buf.tail, 0); + QCOMPARE(buf.maxReadSize(), write_n); + QCOMPARE(buf.maxWriteSize(), (64 - write_n - 1)); + QCOMPARE(buf.freeSpace(), 64 - write_n - 1); + + for (int i = 0; i < write_n; i++) + { + QCOMPARE(buf.tailPtr()[i], i+1); + } + + buf.advanceTail(write_n); + QVERIFY(buf.tail == buf.head); + QCOMPARE(buf.tail, write_n); + QCOMPARE(buf.maxReadSize(), 0); + QCOMPARE(buf.maxWriteSize(), (64 - write_n)); // no longer -1, because the head can point to 0 afterwards + QCOMPARE(buf.freeSpace(), 63); + + write_n = buf.maxWriteSize(); + + head = buf.headPtr(); + for (int i = 0; i < write_n; i++) + { + head[i] = i+1; + } + buf.advanceHead(write_n); + + QCOMPARE(buf.head, 0); + + // Now write more, starting at the beginning. + + write_n = buf.maxWriteSize(); + + head = buf.headPtr(); + for (int i = 0; i < write_n; i++) + { + head[i] = i+100; // Offset by 100 so we can see if we overwrite the tail + } + buf.advanceHead(write_n); + + QCOMPARE(buf.tailPtr()[0], 1); // Did we not overwrite the tail? + QCOMPARE(buf.head, buf.tail - 1); + +} + + + +void MainTests::test_circbuf_unwrapped_doubling() +{ + CirBuf buf(64); + + int w = 63; + + char *head = buf.headPtr(); + for (int i = 0; i < w; i++) + { + head[i] = i+1; + } + buf.advanceHead(63); + + char *tail = buf.tailPtr(); + for (int i = 0; i < w; i++) + { + QCOMPARE(tail[i], i+1); + } + QCOMPARE(buf.buf[64], 0); // Vacant place, because of the circulerness. + + QCOMPARE(buf.head, 63); + + buf.doubleSize(); + tail = buf.tailPtr(); + + for (int i = 0; i < w; i++) + { + QCOMPARE(tail[i], i+1); + } + + for (int i = 63; i < 128; i++) + { + QCOMPARE(tail[i], 5); + } + + QCOMPARE(buf.tail, 0); + QCOMPARE(buf.head, 63); + QCOMPARE(buf.maxWriteSize(), 64); + QCOMPARE(buf.maxReadSize(), 63); +} + +void MainTests::test_circbuf_wrapped_doubling() +{ + CirBuf buf(64); + + int w = 40; + + char *head = buf.headPtr(); + for (int i = 0; i < w; i++) + { + head[i] = i+1; + } + buf.advanceHead(w); + + QCOMPARE(buf.tail, 0); + QCOMPARE(buf.head, w); + QCOMPARE(buf.maxReadSize(), 40); + QCOMPARE(buf.maxWriteSize(), 23); + + buf.advanceTail(40); + + QCOMPARE(buf.maxWriteSize(), 24); + + head = buf.headPtr(); + for (int i = 0; i < 24; i++) + { + head[i] = 99; + } + buf.advanceHead(24); + + QCOMPARE(buf.tail, 40); + QCOMPARE(buf.head, 0); + QCOMPARE(buf.maxReadSize(), 24); + QCOMPARE(buf.maxWriteSize(), 39); + + // Now write a little more, which starts at the start + + head = buf.headPtr(); + for (int i = 0; i < 10; i++) + { + head[i] = 88; + } + buf.advanceHead(10); + QCOMPARE(buf.head, 10); + + buf.doubleSize(); + + // The 88's that were appended at the start, should now appear at the end; + for (int i = 64; i < 74; i++) + { + QCOMPARE(buf.buf[i], 88); + } + + QCOMPARE(buf.tail, 40); + QCOMPARE(buf.head, 74); +} + +void MainTests::test_circbuf_full_wrapped_buffer_doubling() +{ + CirBuf buf(64); + + buf.head = 10; + buf.tail = 10; + + memset(buf.headPtr(), 1, buf.maxWriteSize()); + buf.advanceHead(buf.maxWriteSize()); + memset(buf.headPtr(), 2, buf.maxWriteSize()); + buf.advanceHead(buf.maxWriteSize()); + + for (int i = 0; i < 9; i++) + { + QCOMPARE(buf.buf[i], 2); + } + + QCOMPARE(buf.buf[9], 0); + + for (int i = 10; i < 64; i++) + { + QCOMPARE(buf.buf[i], 1); + } + + QVERIFY(true); + + buf.doubleSize(); + + // The places where value was 1 are the same + for (int i = 10; i < 64; i++) + { + QCOMPARE(buf.buf[i], 1); + } + + // The nine 2's have been moved to the end + for (int i = 64; i < 73; i++) + { + QCOMPARE(buf.buf[i], 2); + } + + // The rest are our debug 5. + for (int i = 73; i < 128; i++) + { + QCOMPARE(buf.buf[i], 5); + } + + QVERIFY(true); +} + +QTEST_APPLESS_MAIN(MainTests) + +#include "tst_maintests.moc" diff --git a/cirbuf.cpp b/cirbuf.cpp new file mode 100644 index 0000000..aaf9931 --- /dev/null +++ b/cirbuf.cpp @@ -0,0 +1,112 @@ +#include "cirbuf.h" + +#include +#include +#include +#include +#include + +CirBuf::CirBuf(size_t size) : + size(size) +{ + buf = (char*)malloc(size); + + if (buf == NULL) + throw std::runtime_error("Malloc error constructing client."); + +#ifndef NDEBUG + memset(buf, 0, size); +#endif +} + +CirBuf::~CirBuf() +{ + if (buf) + free(buf); +} + +int CirBuf::usedBytes() const +{ + int result = (head - tail) & (size-1); + return result; +} + +int CirBuf::freeSpace() const +{ + int result = (tail - (head + 1)) & (size-1); + return result; +} + +int CirBuf::maxWriteSize() const +{ + int end = size - 1 - head; + int n = (end + tail) & (size-1); + int result = n <= end ? n : end+1; + return result; +} + +int CirBuf::maxReadSize() const +{ + int end = size - tail; + int n = (head + end) & (size-1); + int result = n < end ? n : end; + return result; +} + +char *CirBuf::headPtr() +{ + return &buf[head]; +} + +char *CirBuf::tailPtr() +{ + return &buf[tail]; +} + +void CirBuf::advanceHead(int n) +{ + head = (head + n) & (size -1); + assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty. +} + +void CirBuf::advanceTail(int n) +{ + tail = (tail + n) & (size -1); +} + +int CirBuf::peakAhead(int offset) const +{ + int b = buf[(tail + offset) & (size - 1)]; + return b; +} + +void CirBuf::doubleSize() +{ + uint newSize = size * 2; + char *newBuf = (char*)realloc(buf, newSize); + + if (newBuf == NULL) + throw std::runtime_error("Malloc error doubling buffer size."); + + uint maxRead = maxReadSize(); + buf = newBuf; + + if (head < tail) + { + std::memcpy(&buf[tail + maxRead], buf, head); + } + + head = tail + usedBytes(); + size = newSize; + + std::cout << "New read buf size: " << size << std::endl; + +#ifdef TESTING + memset(&buf[head], 5, maxWriteSize() + 2); +#endif +} + +uint CirBuf::getSize() const +{ + return size; +} diff --git a/cirbuf.h b/cirbuf.h new file mode 100644 index 0000000..761d18c --- /dev/null +++ b/cirbuf.h @@ -0,0 +1,36 @@ +#ifndef CIRBUF_H +#define CIRBUF_H + +#include +#include + +// Optimized circular buffer, works only with sizes power of two. +class CirBuf +{ +#ifdef TESTING + friend class MainTests; +#endif + + char *buf = NULL; + uint head = 0; + uint tail = 0; + uint size = 0; +public: + + CirBuf(size_t size); + ~CirBuf(); + + int usedBytes() const; + int freeSpace() const; + int maxWriteSize() const; + int maxReadSize() const; + char *headPtr(); + char *tailPtr(); + void advanceHead(int n); + void advanceTail(int n); + int peakAhead(int offset) const; + void doubleSize(); + uint getSize() const; +}; + +#endif // CIRBUF_H diff --git a/client.cpp b/client.cpp index 8550518..37c9c94 100644 --- a/client.cpp +++ b/client.cpp @@ -7,34 +7,24 @@ Client::Client(int fd, ThreadData_p threadData) : fd(fd), + readbuf(CLIENT_BUFFER_SIZE), threadData(threadData) { int flags = fcntl(fd, F_GETFL); fcntl(fd, F_SETFL, flags | O_NONBLOCK); - char *readbuf = (char*)malloc(CLIENT_BUFFER_SIZE); char *writebuf = (char*)malloc(CLIENT_BUFFER_SIZE); - if (readbuf == NULL || writebuf == NULL) + if (writebuf == NULL) { - if (readbuf != NULL) - free(readbuf); - if (writebuf != NULL) - free(writebuf); - - readbuf = NULL; - writebuf = NULL; - throw std::runtime_error("Malloc error constructing client."); } - this->readbuf = readbuf; this->writebuf = writebuf; } Client::~Client() { close(fd); - free(readbuf); free(writebuf); } @@ -55,11 +45,11 @@ bool Client::readFdIntoBuffer() return false; int n = 0; - while (getReadBufFreeSpace() > 0 && (n = read(fd, &readbuf[wi], getReadBufMaxWriteSize())) != 0) + while (readbuf.freeSpace() > 0 && (n = read(fd, readbuf.headPtr(), readbuf.maxWriteSize())) != 0) { if (n > 0) { - wi = (wi + n) % readBufsize; + readbuf.advanceHead(n); } if (n < 0) @@ -73,11 +63,11 @@ bool Client::readFdIntoBuffer() } // Make sure we either always have enough space for a next call of this method, or stop reading the fd. - if (getReadBufFreeSpace() == 0) + if (readbuf.freeSpace() == 0) { - if (readBufsize * 2 < CLIENT_MAX_BUFFER_SIZE) + if (readbuf.getSize() * 2 < CLIENT_MAX_BUFFER_SIZE) { - growReadBuffer(); + readbuf.doubleSize(); } else { @@ -199,18 +189,6 @@ std::string Client::repr() return a.str(); } -void Client::growReadBuffer() // TODO: refactor -{ - const size_t newBufSize = readBufsize * 2; - char *readbuf = (char*)realloc(this->readbuf, newBufSize); - if (readbuf == NULL) - throw std::runtime_error("Memory allocation failure in growReadBuffer()"); - this->readbuf = readbuf; - readBufsize = newBufSize; - - std::cout << "New read buf size: " << readBufsize << std::endl; -} - void Client::growWriteBuffer(size_t add_size) { if (add_size == 0) @@ -269,20 +247,23 @@ void Client::setReadyForReading(bool val) bool Client::bufferToMqttPackets(std::vector &packetQueueIn, Client_p &sender) { - while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) + while (readbuf.usedBytes() >= MQTT_HEADER_LENGH) { // Determine the packet length by decoding the variable length - int remaining_length_i = ri + 1; // index of 'remaining length' field is one after start. - size_t fixed_header_length = 1; + int remaining_length_i = 1; // index of 'remaining length' field is one after start. + int fixed_header_length = 1; int multiplier = 1; - size_t packet_length = 0; + int packet_length = 0; unsigned char encodedByte = 0; do { fixed_header_length++; - if (remaining_length_i >= wi) + + // This happens when you only don't have all the bytes that specify the remaining length. + if (fixed_header_length > readbuf.usedBytes()) return false; - encodedByte = readbuf[remaining_length_i++ % readBufsize]; + + encodedByte = readbuf.peakAhead(remaining_length_i++); packet_length += (encodedByte & 127) * multiplier; multiplier *= 128; if (multiplier > 128*128*128) @@ -296,25 +277,18 @@ bool Client::bufferToMqttPackets(std::vector &packetQueueIn, Client_ throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes."); } - if (packet_length <= getReadBufBytesUsed()) + if (packet_length <= readbuf.usedBytes()) { - // TODO: deal with circularness here, or in the packet? - MqttPacket packet(&readbuf[ri], packet_length, fixed_header_length, sender); + MqttPacket packet(readbuf, packet_length, fixed_header_length, sender); packetQueueIn.push_back(std::move(packet)); - - ri = (ri + packet_length) % readBufsize; } else break; - } - if (getReadBufMaxWriteSize() > 0) - { - setReadyForReading(true); - } + setReadyForReading(readbuf.freeSpace() > 0); - if (getReadBufBytesUsed() == 0) + if (readbuf.usedBytes() == 0) { // TODO: reset buffer to normal size after a while of not needing it, or not needing the extra space. } diff --git a/client.h b/client.h index 813c8a1..1d96c8f 100644 --- a/client.h +++ b/client.h @@ -6,27 +6,24 @@ #include #include #include -#include #include "forward_declarations.h" #include "threaddata.h" #include "mqttpacket.h" #include "exceptions.h" +#include "cirbuf.h" -#define CLIENT_BUFFER_SIZE 1024 -#define CLIENT_MAX_BUFFER_SIZE 1048576 +#define CLIENT_BUFFER_SIZE 1024 // Must be power of 2 +#define CLIENT_MAX_BUFFER_SIZE 65536 #define MQTT_HEADER_LENGH 2 class Client { int fd; - char *readbuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around. - size_t readBufsize = CLIENT_BUFFER_SIZE; - uint wi = 0; - uint ri = 0; + CirBuf readbuf; char *writebuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around. size_t writeBufsize = CLIENT_BUFFER_SIZE; @@ -52,37 +49,6 @@ class Client ThreadData_p threadData; std::mutex writeBufMutex; - inline size_t getReadBufBytesUsed() const - { - size_t result; - if (wi >= ri) - result = wi - ri; - else - result = (readBufsize + wi) - ri; - return result; - }; - - inline size_t getReadBufFreeSpace() const - { - size_t result = readBufsize - getReadBufBytesUsed() - 1; - return result; - } - - inline size_t getReadBufMaxWriteSize() const - { - const size_t available_space = getReadBufFreeSpace(); - - size_t result = 0; - if (wi >= ri) - result = available_space - wi; - else - result = ri - wi - 1; - - return result; - } - - void growReadBuffer(); - size_t getWriteBufMaxWriteSize() { diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 4d574ef..e457f16 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -12,12 +12,23 @@ RemainingLength::RemainingLength() } // constructor for parsing incoming packets -MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) : - bites(len), +MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender) : + bites(packet_len), fixed_header_length(fixed_header_length), sender(sender) { - std::memcpy(&bites[0], buf, len); + int i = 0; + ssize_t _packet_len = packet_len; + while (_packet_len > 0) + { + int readlen = std::min(buf.maxReadSize(), _packet_len); + std::memcpy(&bites[i], buf.tailPtr(), readlen); + buf.advanceTail(readlen); + i += readlen; + _packet_len -= readlen; + } + assert(_packet_len == 0); + first_byte = bites[0]; unsigned char _packetType = (first_byte & 0xF0) >> 4; packetType = (PacketType)_packetType; diff --git a/mqttpacket.h b/mqttpacket.h index 1b6a528..f470c0d 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -12,6 +12,7 @@ #include "exceptions.h" #include "types.h" #include "subscriptionstore.h" +#include "cirbuf.h" struct RemainingLength { @@ -42,7 +43,7 @@ class MqttPacket public: PacketType packetType = PacketType::Reserved; - MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets. + MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets. // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. MqttPacket(const ConnAck &connAck);