diff --git a/CMakeLists.txt b/CMakeLists.txt index 809a6ab..46e510d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,6 +59,7 @@ add_executable(FlashMQ threadglobals.h threadloop.h publishcopyfactory.h + variablebyteint.h mainapp.cpp main.cpp @@ -97,6 +98,7 @@ add_executable(FlashMQ threadglobals.cpp threadloop.cpp publishcopyfactory.cpp + variablebyteint.cpp ) diff --git a/FlashMQTests/FlashMQTests.pro b/FlashMQTests/FlashMQTests.pro index 74d7eff..43c3942 100644 --- a/FlashMQTests/FlashMQTests.pro +++ b/FlashMQTests/FlashMQTests.pro @@ -50,6 +50,7 @@ SOURCES += tst_maintests.cpp \ ../threadglobals.cpp \ ../threadloop.cpp \ ../publishcopyfactory.cpp \ + ../variablebyteint.cpp \ mainappthread.cpp \ twoclienttestcontext.cpp @@ -92,6 +93,7 @@ HEADERS += \ ../threadglobals.h \ ../threadloop.h \ ../publishcopyfactory.h \ + ../variablebyteint.h \ mainappthread.h \ twoclienttestcontext.h diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 8440c7a..393f9e7 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -24,11 +24,6 @@ License along with FlashMQ. If not, see . #include "utils.h" #include "threadglobals.h" -RemainingLength::RemainingLength() -{ - memset(bytes, 0, 4); -} - // constructor for parsing incoming packets MqttPacket::MqttPacket(CirBuf &buf, size_t packet_len, size_t fixed_header_length, std::shared_ptr &sender) : bites(packet_len), @@ -760,27 +755,7 @@ void MqttPacket::handlePubComp() void MqttPacket::calculateRemainingLength() { assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. - - size_t x = bites.size(); - - do - { - if (remainingLength.len > 4) - throw std::runtime_error("Calculated remaining length is longer than 4 bytes."); - - char encodedByte = x % 128; - x = x / 128; - if (x > 0) - encodedByte = encodedByte | 128; - remainingLength.bytes[remainingLength.len++] = encodedByte; - } - while(x > 0); -} - -RemainingLength MqttPacket::getRemainingLength() const -{ - assert(remainingLength.len > 0); - return remainingLength; + this->remainingLength = bites.size(); } void MqttPacket::setPacketId(uint16_t packet_id) @@ -846,7 +821,7 @@ size_t MqttPacket::getSizeIncludingNonPresentHeader() const if (fixed_header_length == 0) { total++; - total += remainingLength.len; + total += remainingLength.getLen(); } return total; @@ -899,11 +874,6 @@ bool MqttPacket::containsFixedHeader() const return fixed_header_length > 0; } -char MqttPacket::getFirstByte() const -{ - return first_byte; -} - char *MqttPacket::readBytes(size_t length) { if (pos + length > bites.size()) @@ -991,11 +961,9 @@ void MqttPacket::readIntoBuf(CirBuf &buf) const if (!containsFixedHeader()) { - assert(remainingLength.len > 0); - - buf.headPtr()[0] = getFirstByte(); + buf.headPtr()[0] = first_byte; buf.advanceHead(1); - buf.write(remainingLength.bytes, remainingLength.len); + remainingLength.readIntoBuf(buf); } else { diff --git a/mqttpacket.h b/mqttpacket.h index c7d0840..3c89e2c 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -33,13 +33,7 @@ License along with FlashMQ. If not, see . #include "logger.h" #include "mainapp.h" -struct RemainingLength -{ - char bytes[4]; - int len = 0; -public: - RemainingLength(); -}; +#include "variablebyteint.h" class MqttPacket { @@ -51,7 +45,7 @@ class MqttPacket std::vector subtopics; std::vector bites; size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. - RemainingLength remainingLength; + VariableByteInt remainingLength; char qos = 0; std::shared_ptr sender; char first_byte = 0; @@ -115,8 +109,6 @@ public: std::shared_ptr getSender() const; void setSender(const std::shared_ptr &value); bool containsFixedHeader() const; - char getFirstByte() const; - RemainingLength getRemainingLength() const; void setPacketId(uint16_t packet_id); uint16_t getPacketId() const; void setDuplicate(); diff --git a/variablebyteint.cpp b/variablebyteint.cpp new file mode 100644 index 0000000..b126a07 --- /dev/null +++ b/variablebyteint.cpp @@ -0,0 +1,36 @@ +#include "variablebyteint.h" + +#include +#include +#include + +void VariableByteInt::readIntoBuf(CirBuf &buf) const +{ + assert(len > 0); + buf.write(bytes, len); +} + +VariableByteInt &VariableByteInt::operator=(uint32_t x) +{ + if (x > 268435455) + throw std::runtime_error("Value of variable byte int to encode too big. Bug or corrupt packet?"); + + len = 0; + + do + { + uint8_t encodedByte = x % 128; + x = x / 128; + if (x > 0) + encodedByte = encodedByte | 128; + bytes[len++] = encodedByte; + } + while(x > 0); + + return *this; +} + +uint8_t VariableByteInt::getLen() const +{ + return len; +} diff --git a/variablebyteint.h b/variablebyteint.h new file mode 100644 index 0000000..3322262 --- /dev/null +++ b/variablebyteint.h @@ -0,0 +1,17 @@ +#ifndef VARIABLEBYTEINT_H +#define VARIABLEBYTEINT_H + +#include "cirbuf.h" + +class VariableByteInt +{ + char bytes[4]; + uint8_t len = 0; + +public: + void readIntoBuf(CirBuf &buf) const; + VariableByteInt &operator=(uint32_t x); + uint8_t getLen() const; +}; + +#endif // VARIABLEBYTEINT_H