diff --git a/CMakeLists.txt b/CMakeLists.txt
index 809a6ab..46e510d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -59,6 +59,7 @@ add_executable(FlashMQ
threadglobals.h
threadloop.h
publishcopyfactory.h
+ variablebyteint.h
mainapp.cpp
main.cpp
@@ -97,6 +98,7 @@ add_executable(FlashMQ
threadglobals.cpp
threadloop.cpp
publishcopyfactory.cpp
+ variablebyteint.cpp
)
diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro
index 74d7eff..43c3942 100644
--- a/FlashMQTests/FlashMQTests.pro
+++ b/FlashMQTests/FlashMQTests.pro
@@ -50,6 +50,7 @@ SOURCES += tst_maintests.cpp \
../threadglobals.cpp \
../threadloop.cpp \
../publishcopyfactory.cpp \
+ ../variablebyteint.cpp \
mainappthread.cpp \
twoclienttestcontext.cpp
@@ -92,6 +93,7 @@ HEADERS += \
../threadglobals.h \
../threadloop.h \
../publishcopyfactory.h \
+ ../variablebyteint.h \
mainappthread.h \
twoclienttestcontext.h
diff --git a/mqttpacket.cpp b/mqttpacket.cpp
index 8440c7a..393f9e7 100644
--- a/mqttpacket.cpp
+++ b/mqttpacket.cpp
@@ -24,11 +24,6 @@ License along with FlashMQ. If not, see .
#include "utils.h"
#include "threadglobals.h"
-RemainingLength::RemainingLength()
-{
- memset(bytes, 0, 4);
-}
-
// constructor for parsing incoming packets
MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr &sender) :
bites(packet_len),
@@ -760,27 +755,7 @@ void MqttPacket::handlePubComp()
void MqttPacket::calculateRemainingLength()
{
assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of.
-
- size_t x = bites.size();
-
- do
- {
- if (remainingLength.len > 4)
- throw std::runtime_error("Calculated remaining length is longer than 4 bytes.");
-
- char encodedByte = x % 128;
- x = x / 128;
- if (x > 0)
- encodedByte = encodedByte | 128;
- remainingLength.bytes[remainingLength.len++] = encodedByte;
- }
- while(x > 0);
-}
-
-RemainingLength MqttPacket::getRemainingLength() const
-{
- assert(remainingLength.len > 0);
- return remainingLength;
+ this->remainingLength = bites.size();
}
void MqttPacket::setPacketId(uint16_t packet_id)
@@ -846,7 +821,7 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const
if (fixed_header_length == 0)
{
total++;
- total += remainingLength.len;
+ total += remainingLength.getLen();
}
return total;
@@ -899,11 +874,6 @@ bool MqttPacket::containsFixedHeader() const
return fixed_header_length > 0;
}
-char MqttPacket::getFirstByte() const
-{
- return first_byte;
-}
-
char *MqttPacket::readBytes(size_t length)
{
if (pos + length > bites.size())
@@ -991,11 +961,9 @@ void MqttPacket::readIntoBuf(CirBuf &buf) const
if (!containsFixedHeader())
{
- assert(remainingLength.len > 0);
-
- buf.headPtr()[0] = getFirstByte();
+ buf.headPtr()[0] = first_byte;
buf.advanceHead(1);
- buf.write(remainingLength.bytes, remainingLength.len);
+ remainingLength.readIntoBuf(buf);
}
else
{
diff --git a/mqttpacket.h b/mqttpacket.h
index c7d0840..3c89e2c 100644
--- a/mqttpacket.h
+++ b/mqttpacket.h
@@ -33,13 +33,7 @@ License along with FlashMQ. If not, see .
#include "logger.h"
#include "mainapp.h"
-struct RemainingLength
-{
- char bytes[4];
- int len = 0;
-public:
- RemainingLength();
-};
+#include "variablebyteint.h"
class MqttPacket
{
@@ -51,7 +45,7 @@ class MqttPacket
std::vector subtopics;
std::vector bites;
size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header.
- RemainingLength remainingLength;
+ VariableByteInt remainingLength;
char qos = 0;
std::shared_ptr sender;
char first_byte = 0;
@@ -115,8 +109,6 @@ public:
std::shared_ptr getSender() const;
void setSender(const std::shared_ptr &value);
bool containsFixedHeader() const;
- char getFirstByte() const;
- RemainingLength getRemainingLength() const;
void setPacketId(uint16_t packet_id);
uint16_t getPacketId() const;
void setDuplicate();
diff --git a/variablebyteint.cpp b/variablebyteint.cpp
new file mode 100644
index 0000000..b126a07
--- /dev/null
+++ b/variablebyteint.cpp
@@ -0,0 +1,36 @@
+#include "variablebyteint.h"
+
+#include
+#include
+#include
+
+void VariableByteInt::readIntoBuf(CirBuf &buf) const
+{
+ assert(len > 0);
+ buf.write(bytes, len);
+}
+
+VariableByteInt &VariableByteInt::operator=(uint32_t x)
+{
+ if (x > 268435455)
+ throw std::runtime_error("Value of variable byte int to encode too big. Bug or corrupt packet?");
+
+ len = 0;
+
+ do
+ {
+ uint8_t encodedByte = x % 128;
+ x = x / 128;
+ if (x > 0)
+ encodedByte = encodedByte | 128;
+ bytes[len++] = encodedByte;
+ }
+ while(x > 0);
+
+ return *this;
+}
+
+uint8_t VariableByteInt::getLen() const
+{
+ return len;
+}
diff --git a/variablebyteint.h b/variablebyteint.h
new file mode 100644
index 0000000..3322262
--- /dev/null
+++ b/variablebyteint.h
@@ -0,0 +1,17 @@
+#ifndef VARIABLEBYTEINT_H
+#define VARIABLEBYTEINT_H
+
+#include "cirbuf.h"
+
+class VariableByteInt
+{
+ char bytes[4];
+ uint8_t len = 0;
+
+public:
+ void readIntoBuf(CirBuf &buf) const;
+ VariableByteInt &operator=(uint32_t x);
+ uint8_t getLen() const;
+};
+
+#endif // VARIABLEBYTEINT_H