Commit ae95e6dc2bf468b7b20d1a8fe5c2c37c3ddcaa9d

Authored by Wiebe Cazemier
1 parent ae9b2060

Use proper circular buffer for reading

.gitignore
1 1 *.user
  2 +build-*
... ...
CMakeLists.txt
... ... @@ -19,6 +19,7 @@ add_executable(FlashMQ
19 19 subscriptionstore.cpp
20 20 rwlockguard.cpp
21 21 retainedmessage.cpp
  22 + cirbuf.cpp
22 23 )
23 24  
24 25 target_link_libraries(FlashMQ pthread)
... ...
FlashMQTests/FlashMQTests.pro 0 → 100644
  1 +QT += testlib
  2 +QT -= gui
  3 +Qt += network
  4 +
  5 +DEFINES += TESTING
  6 +
  7 +INCLUDEPATH += ..
  8 +
  9 +CONFIG += qt console warn_on depend_includepath testcase
  10 +CONFIG -= app_bundle
  11 +
  12 +TEMPLATE = app
  13 +
  14 +SOURCES += tst_maintests.cpp \
  15 + ../cirbuf.cpp
  16 +
  17 +HEADERS += \
  18 + ../cirbuf.h
... ...
FlashMQTests/tst_maintests.cpp 0 → 100644
  1 +#include <QtTest>
  2 +
  3 +#include "cirbuf.h"
  4 +
  5 +class MainTests : public QObject
  6 +{
  7 + Q_OBJECT
  8 +
  9 +public:
  10 + MainTests();
  11 + ~MainTests();
  12 +
  13 +private slots:
  14 + void test_case1();
  15 + void test_circbuf();
  16 + void test_circbuf_unwrapped_doubling();
  17 + void test_circbuf_wrapped_doubling();
  18 + void test_circbuf_full_wrapped_buffer_doubling();
  19 +
  20 +};
  21 +
  22 +MainTests::MainTests()
  23 +{
  24 +
  25 +}
  26 +
  27 +MainTests::~MainTests()
  28 +{
  29 +
  30 +}
  31 +
  32 +void MainTests::test_case1()
  33 +{
  34 +
  35 +}
  36 +
  37 +void MainTests::test_circbuf()
  38 +{
  39 + CirBuf buf(64);
  40 +
  41 + QCOMPARE(buf.freeSpace(), 63);
  42 +
  43 + int write_n = 40;
  44 +
  45 + char *head = buf.headPtr();
  46 + for (int i = 0; i < write_n; i++)
  47 + {
  48 + head[i] = i+1;
  49 + }
  50 +
  51 + buf.advanceHead(write_n);
  52 +
  53 + QCOMPARE(buf.head, write_n);
  54 + QCOMPARE(buf.tail, 0);
  55 + QCOMPARE(buf.maxReadSize(), write_n);
  56 + QCOMPARE(buf.maxWriteSize(), (64 - write_n - 1));
  57 + QCOMPARE(buf.freeSpace(), 64 - write_n - 1);
  58 +
  59 + for (int i = 0; i < write_n; i++)
  60 + {
  61 + QCOMPARE(buf.tailPtr()[i], i+1);
  62 + }
  63 +
  64 + buf.advanceTail(write_n);
  65 + QVERIFY(buf.tail == buf.head);
  66 + QCOMPARE(buf.tail, write_n);
  67 + QCOMPARE(buf.maxReadSize(), 0);
  68 + QCOMPARE(buf.maxWriteSize(), (64 - write_n)); // no longer -1, because the head can point to 0 afterwards
  69 + QCOMPARE(buf.freeSpace(), 63);
  70 +
  71 + write_n = buf.maxWriteSize();
  72 +
  73 + head = buf.headPtr();
  74 + for (int i = 0; i < write_n; i++)
  75 + {
  76 + head[i] = i+1;
  77 + }
  78 + buf.advanceHead(write_n);
  79 +
  80 + QCOMPARE(buf.head, 0);
  81 +
  82 + // Now write more, starting at the beginning.
  83 +
  84 + write_n = buf.maxWriteSize();
  85 +
  86 + head = buf.headPtr();
  87 + for (int i = 0; i < write_n; i++)
  88 + {
  89 + head[i] = i+100; // Offset by 100 so we can see if we overwrite the tail
  90 + }
  91 + buf.advanceHead(write_n);
  92 +
  93 + QCOMPARE(buf.tailPtr()[0], 1); // Did we not overwrite the tail?
  94 + QCOMPARE(buf.head, buf.tail - 1);
  95 +
  96 +}
  97 +
  98 +
  99 +
  100 +void MainTests::test_circbuf_unwrapped_doubling()
  101 +{
  102 + CirBuf buf(64);
  103 +
  104 + int w = 63;
  105 +
  106 + char *head = buf.headPtr();
  107 + for (int i = 0; i < w; i++)
  108 + {
  109 + head[i] = i+1;
  110 + }
  111 + buf.advanceHead(63);
  112 +
  113 + char *tail = buf.tailPtr();
  114 + for (int i = 0; i < w; i++)
  115 + {
  116 + QCOMPARE(tail[i], i+1);
  117 + }
  118 + QCOMPARE(buf.buf[64], 0); // Vacant place, because of the circulerness.
  119 +
  120 + QCOMPARE(buf.head, 63);
  121 +
  122 + buf.doubleSize();
  123 + tail = buf.tailPtr();
  124 +
  125 + for (int i = 0; i < w; i++)
  126 + {
  127 + QCOMPARE(tail[i], i+1);
  128 + }
  129 +
  130 + for (int i = 63; i < 128; i++)
  131 + {
  132 + QCOMPARE(tail[i], 5);
  133 + }
  134 +
  135 + QCOMPARE(buf.tail, 0);
  136 + QCOMPARE(buf.head, 63);
  137 + QCOMPARE(buf.maxWriteSize(), 64);
  138 + QCOMPARE(buf.maxReadSize(), 63);
  139 +}
  140 +
  141 +void MainTests::test_circbuf_wrapped_doubling()
  142 +{
  143 + CirBuf buf(64);
  144 +
  145 + int w = 40;
  146 +
  147 + char *head = buf.headPtr();
  148 + for (int i = 0; i < w; i++)
  149 + {
  150 + head[i] = i+1;
  151 + }
  152 + buf.advanceHead(w);
  153 +
  154 + QCOMPARE(buf.tail, 0);
  155 + QCOMPARE(buf.head, w);
  156 + QCOMPARE(buf.maxReadSize(), 40);
  157 + QCOMPARE(buf.maxWriteSize(), 23);
  158 +
  159 + buf.advanceTail(40);
  160 +
  161 + QCOMPARE(buf.maxWriteSize(), 24);
  162 +
  163 + head = buf.headPtr();
  164 + for (int i = 0; i < 24; i++)
  165 + {
  166 + head[i] = 99;
  167 + }
  168 + buf.advanceHead(24);
  169 +
  170 + QCOMPARE(buf.tail, 40);
  171 + QCOMPARE(buf.head, 0);
  172 + QCOMPARE(buf.maxReadSize(), 24);
  173 + QCOMPARE(buf.maxWriteSize(), 39);
  174 +
  175 + // Now write a little more, which starts at the start
  176 +
  177 + head = buf.headPtr();
  178 + for (int i = 0; i < 10; i++)
  179 + {
  180 + head[i] = 88;
  181 + }
  182 + buf.advanceHead(10);
  183 + QCOMPARE(buf.head, 10);
  184 +
  185 + buf.doubleSize();
  186 +
  187 + // The 88's that were appended at the start, should now appear at the end;
  188 + for (int i = 64; i < 74; i++)
  189 + {
  190 + QCOMPARE(buf.buf[i], 88);
  191 + }
  192 +
  193 + QCOMPARE(buf.tail, 40);
  194 + QCOMPARE(buf.head, 74);
  195 +}
  196 +
  197 +void MainTests::test_circbuf_full_wrapped_buffer_doubling()
  198 +{
  199 + CirBuf buf(64);
  200 +
  201 + buf.head = 10;
  202 + buf.tail = 10;
  203 +
  204 + memset(buf.headPtr(), 1, buf.maxWriteSize());
  205 + buf.advanceHead(buf.maxWriteSize());
  206 + memset(buf.headPtr(), 2, buf.maxWriteSize());
  207 + buf.advanceHead(buf.maxWriteSize());
  208 +
  209 + for (int i = 0; i < 9; i++)
  210 + {
  211 + QCOMPARE(buf.buf[i], 2);
  212 + }
  213 +
  214 + QCOMPARE(buf.buf[9], 0);
  215 +
  216 + for (int i = 10; i < 64; i++)
  217 + {
  218 + QCOMPARE(buf.buf[i], 1);
  219 + }
  220 +
  221 + QVERIFY(true);
  222 +
  223 + buf.doubleSize();
  224 +
  225 + // The places where value was 1 are the same
  226 + for (int i = 10; i < 64; i++)
  227 + {
  228 + QCOMPARE(buf.buf[i], 1);
  229 + }
  230 +
  231 + // The nine 2's have been moved to the end
  232 + for (int i = 64; i < 73; i++)
  233 + {
  234 + QCOMPARE(buf.buf[i], 2);
  235 + }
  236 +
  237 + // The rest are our debug 5.
  238 + for (int i = 73; i < 128; i++)
  239 + {
  240 + QCOMPARE(buf.buf[i], 5);
  241 + }
  242 +
  243 + QVERIFY(true);
  244 +}
  245 +
  246 +QTEST_APPLESS_MAIN(MainTests)
  247 +
  248 +#include "tst_maintests.moc"
... ...
cirbuf.cpp 0 → 100644
  1 +#include "cirbuf.h"
  2 +
  3 +#include <iostream>
  4 +#include <exception>
  5 +#include <stdexcept>
  6 +#include <cassert>
  7 +#include <cstring>
  8 +
  9 +CirBuf::CirBuf(size_t size) :
  10 + size(size)
  11 +{
  12 + buf = (char*)malloc(size);
  13 +
  14 + if (buf == NULL)
  15 + throw std::runtime_error("Malloc error constructing client.");
  16 +
  17 +#ifndef NDEBUG
  18 + memset(buf, 0, size);
  19 +#endif
  20 +}
  21 +
  22 +CirBuf::~CirBuf()
  23 +{
  24 + if (buf)
  25 + free(buf);
  26 +}
  27 +
  28 +int CirBuf::usedBytes() const
  29 +{
  30 + int result = (head - tail) & (size-1);
  31 + return result;
  32 +}
  33 +
  34 +int CirBuf::freeSpace() const
  35 +{
  36 + int result = (tail - (head + 1)) & (size-1);
  37 + return result;
  38 +}
  39 +
  40 +int CirBuf::maxWriteSize() const
  41 +{
  42 + int end = size - 1 - head;
  43 + int n = (end + tail) & (size-1);
  44 + int result = n <= end ? n : end+1;
  45 + return result;
  46 +}
  47 +
  48 +int CirBuf::maxReadSize() const
  49 +{
  50 + int end = size - tail;
  51 + int n = (head + end) & (size-1);
  52 + int result = n < end ? n : end;
  53 + return result;
  54 +}
  55 +
  56 +char *CirBuf::headPtr()
  57 +{
  58 + return &buf[head];
  59 +}
  60 +
  61 +char *CirBuf::tailPtr()
  62 +{
  63 + return &buf[tail];
  64 +}
  65 +
  66 +void CirBuf::advanceHead(int n)
  67 +{
  68 + head = (head + n) & (size -1);
  69 + assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty.
  70 +}
  71 +
  72 +void CirBuf::advanceTail(int n)
  73 +{
  74 + tail = (tail + n) & (size -1);
  75 +}
  76 +
  77 +int CirBuf::peakAhead(int offset) const
  78 +{
  79 + int b = buf[(tail + offset) & (size - 1)];
  80 + return b;
  81 +}
  82 +
  83 +void CirBuf::doubleSize()
  84 +{
  85 + uint newSize = size * 2;
  86 + char *newBuf = (char*)realloc(buf, newSize);
  87 +
  88 + if (newBuf == NULL)
  89 + throw std::runtime_error("Malloc error doubling buffer size.");
  90 +
  91 + uint maxRead = maxReadSize();
  92 + buf = newBuf;
  93 +
  94 + if (head < tail)
  95 + {
  96 + std::memcpy(&buf[tail + maxRead], buf, head);
  97 + }
  98 +
  99 + head = tail + usedBytes();
  100 + size = newSize;
  101 +
  102 + std::cout << "New read buf size: " << size << std::endl;
  103 +
  104 +#ifdef TESTING
  105 + memset(&buf[head], 5, maxWriteSize() + 2);
  106 +#endif
  107 +}
  108 +
  109 +uint CirBuf::getSize() const
  110 +{
  111 + return size;
  112 +}
... ...
cirbuf.h 0 → 100644
  1 +#ifndef CIRBUF_H
  2 +#define CIRBUF_H
  3 +
  4 +#include <stddef.h>
  5 +#include <stdlib.h>
  6 +
  7 +// Optimized circular buffer, works only with sizes power of two.
  8 +class CirBuf
  9 +{
  10 +#ifdef TESTING
  11 + friend class MainTests;
  12 +#endif
  13 +
  14 + char *buf = NULL;
  15 + uint head = 0;
  16 + uint tail = 0;
  17 + uint size = 0;
  18 +public:
  19 +
  20 + CirBuf(size_t size);
  21 + ~CirBuf();
  22 +
  23 + int usedBytes() const;
  24 + int freeSpace() const;
  25 + int maxWriteSize() const;
  26 + int maxReadSize() const;
  27 + char *headPtr();
  28 + char *tailPtr();
  29 + void advanceHead(int n);
  30 + void advanceTail(int n);
  31 + int peakAhead(int offset) const;
  32 + void doubleSize();
  33 + uint getSize() const;
  34 +};
  35 +
  36 +#endif // CIRBUF_H
... ...
client.cpp
... ... @@ -7,34 +7,24 @@
7 7  
8 8 Client::Client(int fd, ThreadData_p threadData) :
9 9 fd(fd),
  10 + readbuf(CLIENT_BUFFER_SIZE),
10 11 threadData(threadData)
11 12 {
12 13 int flags = fcntl(fd, F_GETFL);
13 14 fcntl(fd, F_SETFL, flags | O_NONBLOCK);
14   - char *readbuf = (char*)malloc(CLIENT_BUFFER_SIZE);
15 15 char *writebuf = (char*)malloc(CLIENT_BUFFER_SIZE);
16 16  
17   - if (readbuf == NULL || writebuf == NULL)
  17 + if (writebuf == NULL)
18 18 {
19   - if (readbuf != NULL)
20   - free(readbuf);
21   - if (writebuf != NULL)
22   - free(writebuf);
23   -
24   - readbuf = NULL;
25   - writebuf = NULL;
26   -
27 19 throw std::runtime_error("Malloc error constructing client.");
28 20 }
29 21  
30   - this->readbuf = readbuf;
31 22 this->writebuf = writebuf;
32 23 }
33 24  
34 25 Client::~Client()
35 26 {
36 27 close(fd);
37   - free(readbuf);
38 28 free(writebuf);
39 29 }
40 30  
... ... @@ -55,11 +45,11 @@ bool Client::readFdIntoBuffer()
55 45 return false;
56 46  
57 47 int n = 0;
58   - while (getReadBufFreeSpace() > 0 && (n = read(fd, &readbuf[wi], getReadBufMaxWriteSize())) != 0)
  48 + while (readbuf.freeSpace() > 0 && (n = read(fd, readbuf.headPtr(), readbuf.maxWriteSize())) != 0)
59 49 {
60 50 if (n > 0)
61 51 {
62   - wi = (wi + n) % readBufsize;
  52 + readbuf.advanceHead(n);
63 53 }
64 54  
65 55 if (n < 0)
... ... @@ -73,11 +63,11 @@ bool Client::readFdIntoBuffer()
73 63 }
74 64  
75 65 // Make sure we either always have enough space for a next call of this method, or stop reading the fd.
76   - if (getReadBufFreeSpace() == 0)
  66 + if (readbuf.freeSpace() == 0)
77 67 {
78   - if (readBufsize * 2 < CLIENT_MAX_BUFFER_SIZE)
  68 + if (readbuf.getSize() * 2 < CLIENT_MAX_BUFFER_SIZE)
79 69 {
80   - growReadBuffer();
  70 + readbuf.doubleSize();
81 71 }
82 72 else
83 73 {
... ... @@ -199,18 +189,6 @@ std::string Client::repr()
199 189 return a.str();
200 190 }
201 191  
202   -void Client::growReadBuffer() // TODO: refactor
203   -{
204   - const size_t newBufSize = readBufsize * 2;
205   - char *readbuf = (char*)realloc(this->readbuf, newBufSize);
206   - if (readbuf == NULL)
207   - throw std::runtime_error("Memory allocation failure in growReadBuffer()");
208   - this->readbuf = readbuf;
209   - readBufsize = newBufSize;
210   -
211   - std::cout << "New read buf size: " << readBufsize << std::endl;
212   -}
213   -
214 192 void Client::growWriteBuffer(size_t add_size)
215 193 {
216 194 if (add_size == 0)
... ... @@ -269,20 +247,23 @@ void Client::setReadyForReading(bool val)
269 247  
270 248 bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, Client_p &sender)
271 249 {
272   - while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH)
  250 + while (readbuf.usedBytes() >= MQTT_HEADER_LENGH)
273 251 {
274 252 // Determine the packet length by decoding the variable length
275   - int remaining_length_i = ri + 1; // index of 'remaining length' field is one after start.
276   - size_t fixed_header_length = 1;
  253 + int remaining_length_i = 1; // index of 'remaining length' field is one after start.
  254 + int fixed_header_length = 1;
277 255 int multiplier = 1;
278   - size_t packet_length = 0;
  256 + int packet_length = 0;
279 257 unsigned char encodedByte = 0;
280 258 do
281 259 {
282 260 fixed_header_length++;
283   - if (remaining_length_i >= wi)
  261 +
  262 + // This happens when you only don't have all the bytes that specify the remaining length.
  263 + if (fixed_header_length > readbuf.usedBytes())
284 264 return false;
285   - encodedByte = readbuf[remaining_length_i++ % readBufsize];
  265 +
  266 + encodedByte = readbuf.peakAhead(remaining_length_i++);
286 267 packet_length += (encodedByte & 127) * multiplier;
287 268 multiplier *= 128;
288 269 if (multiplier > 128*128*128)
... ... @@ -296,25 +277,18 @@ bool Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn, Client_
296 277 throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.");
297 278 }
298 279  
299   - if (packet_length <= getReadBufBytesUsed())
  280 + if (packet_length <= readbuf.usedBytes())
300 281 {
301   - // TODO: deal with circularness here, or in the packet?
302   - MqttPacket packet(&readbuf[ri], packet_length, fixed_header_length, sender);
  282 + MqttPacket packet(readbuf, packet_length, fixed_header_length, sender);
303 283 packetQueueIn.push_back(std::move(packet));
304   -
305   - ri = (ri + packet_length) % readBufsize;
306 284 }
307 285 else
308 286 break;
309   -
310 287 }
311 288  
312   - if (getReadBufMaxWriteSize() > 0)
313   - {
314   - setReadyForReading(true);
315   - }
  289 + setReadyForReading(readbuf.freeSpace() > 0);
316 290  
317   - if (getReadBufBytesUsed() == 0)
  291 + if (readbuf.usedBytes() == 0)
318 292 {
319 293 // TODO: reset buffer to normal size after a while of not needing it, or not needing the extra space.
320 294 }
... ...
client.h
... ... @@ -6,27 +6,24 @@
6 6 #include <vector>
7 7 #include <mutex>
8 8 #include <iostream>
9   -#include
10 9  
11 10 #include "forward_declarations.h"
12 11  
13 12 #include "threaddata.h"
14 13 #include "mqttpacket.h"
15 14 #include "exceptions.h"
  15 +#include "cirbuf.h"
16 16  
17 17  
18   -#define CLIENT_BUFFER_SIZE 1024
19   -#define CLIENT_MAX_BUFFER_SIZE 1048576
  18 +#define CLIENT_BUFFER_SIZE 1024 // Must be power of 2
  19 +#define CLIENT_MAX_BUFFER_SIZE 65536
20 20 #define MQTT_HEADER_LENGH 2
21 21  
22 22 class Client
23 23 {
24 24 int fd;
25 25  
26   - char *readbuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around.
27   - size_t readBufsize = CLIENT_BUFFER_SIZE;
28   - uint wi = 0;
29   - uint ri = 0;
  26 + CirBuf readbuf;
30 27  
31 28 char *writebuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around.
32 29 size_t writeBufsize = CLIENT_BUFFER_SIZE;
... ... @@ -52,37 +49,6 @@ class Client
52 49 ThreadData_p threadData;
53 50 std::mutex writeBufMutex;
54 51  
55   - inline size_t getReadBufBytesUsed() const
56   - {
57   - size_t result;
58   - if (wi >= ri)
59   - result = wi - ri;
60   - else
61   - result = (readBufsize + wi) - ri;
62   - return result;
63   - };
64   -
65   - inline size_t getReadBufFreeSpace() const
66   - {
67   - size_t result = readBufsize - getReadBufBytesUsed() - 1;
68   - return result;
69   - }
70   -
71   - inline size_t getReadBufMaxWriteSize() const
72   - {
73   - const size_t available_space = getReadBufFreeSpace();
74   -
75   - size_t result = 0;
76   - if (wi >= ri)
77   - result = available_space - wi;
78   - else
79   - result = ri - wi - 1;
80   -
81   - return result;
82   - }
83   -
84   - void growReadBuffer();
85   -
86 52  
87 53 size_t getWriteBufMaxWriteSize()
88 54 {
... ...
mqttpacket.cpp
... ... @@ -12,12 +12,23 @@ RemainingLength::RemainingLength()
12 12 }
13 13  
14 14 // constructor for parsing incoming packets
15   -MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender) :
16   - bites(len),
  15 +MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender) :
  16 + bites(packet_len),
17 17 fixed_header_length(fixed_header_length),
18 18 sender(sender)
19 19 {
20   - std::memcpy(&bites[0], buf, len);
  20 + int i = 0;
  21 + ssize_t _packet_len = packet_len;
  22 + while (_packet_len > 0)
  23 + {
  24 + int readlen = std::min<int>(buf.maxReadSize(), _packet_len);
  25 + std::memcpy(&bites[i], buf.tailPtr(), readlen);
  26 + buf.advanceTail(readlen);
  27 + i += readlen;
  28 + _packet_len -= readlen;
  29 + }
  30 + assert(_packet_len == 0);
  31 +
21 32 first_byte = bites[0];
22 33 unsigned char _packetType = (first_byte & 0xF0) >> 4;
23 34 packetType = (PacketType)_packetType;
... ...
mqttpacket.h
... ... @@ -12,6 +12,7 @@
12 12 #include "exceptions.h"
13 13 #include "types.h"
14 14 #include "subscriptionstore.h"
  15 +#include "cirbuf.h"
15 16  
16 17 struct RemainingLength
17 18 {
... ... @@ -42,7 +43,7 @@ class MqttPacket
42 43  
43 44 public:
44 45 PacketType packetType = PacketType::Reserved;
45   - MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets.
  46 + MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, Client_p &sender); // Constructor for parsing incoming packets.
46 47  
47 48 // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance.
48 49 MqttPacket(const ConnAck &connAck);
... ...