From 243c873f18ea5d69c5589c15b6a1c33cb66b998d Mon Sep 17 00:00:00 2001 From: Wiebe Cazemier Date: Tue, 8 Dec 2020 22:16:48 +0100 Subject: [PATCH] Much stuff --- CMakeLists.txt | 10 +++++++++- MqttPacket.cpp | 0 bytestopacketparser.cpp | 11 +++++++++++ bytestopacketparser.h | 19 +++++++++++++++++++ client.cpp | 58 ++++++++++++++++++++++++++++++++++++++++++++++------------ client.h | 31 ++++++++++++++++--------------- exceptions.cpp | 2 ++ exceptions.h | 14 ++++++++++++++ main.cpp | 21 +++++++++++++++++---- mqttpacket.cpp | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ mqttpacket.h | 39 +++++++++++++++++++++++++++++++++++++++ types.cpp | 2 ++ types.h | 20 ++++++++++++++++++++ 13 files changed, 284 insertions(+), 32 deletions(-) create mode 100644 MqttPacket.cpp create mode 100644 bytestopacketparser.cpp create mode 100644 bytestopacketparser.h create mode 100644 exceptions.cpp create mode 100644 exceptions.h create mode 100644 mqttpacket.cpp create mode 100644 mqttpacket.h create mode 100644 types.cpp create mode 100644 types.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 3d8484c..477563e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,14 @@ set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) -add_executable(FlashMQ main.cpp utils.cpp threaddata.cpp client.cpp) +add_executable(FlashMQ + main.cpp + utils.cpp + threaddata.cpp + client.cpp + bytestopacketparser.cpp + mqttpacket.cpp + exceptions.cpp + types.cpp) target_link_libraries(FlashMQ pthread) diff --git a/MqttPacket.cpp b/MqttPacket.cpp new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/MqttPacket.cpp diff --git a/bytestopacketparser.cpp b/bytestopacketparser.cpp new file mode 100644 index 0000000..01a600e --- /dev/null +++ b/bytestopacketparser.cpp @@ -0,0 +1,11 @@ +#include "bytestopacketparser.h" + +BytesToPacketParser::BytesToPacketParser() +{ + +} + +size_t BytesToPacketParser::bytesToPackets(char *buf, size_t len, std::vector &queue) +{ + +} diff --git a/bytestopacketparser.h b/bytestopacketparser.h new file mode 100644 index 0000000..9dfe370 --- /dev/null +++ b/bytestopacketparser.h @@ -0,0 +1,19 @@ +#ifndef BYTESTOPACKETPARSER_H +#define BYTESTOPACKETPARSER_H + +#include +#include + +#include "mqttpacket.h" + +#define MQTT_HEADER_LENGH 2 + +class BytesToPacketParser +{ +public: + BytesToPacketParser(); + + size_t bytesToPackets(char *buf, size_t len, std::vector &queue); +}; + +#endif // BYTESTOPACKETPARSER_H diff --git a/client.cpp b/client.cpp index 7452dd0..3d8318f 100644 --- a/client.cpp +++ b/client.cpp @@ -35,17 +35,12 @@ bool Client::readFdIntoBuffer() } wi += n; - size_t bytesUsed = getBufBytesUsed(); - // TODO: we need a buffer to keep partial frames in, so/and can we reduce the size of this buffer again periodically? - if (bytesUsed >= bufsize) + if (getBufBytesUsed() >= bufsize) { - const size_t newBufSize = bufsize * 2; - readbuf = (char*)realloc(readbuf, newBufSize); - bufsize = newBufSize; + growBuffer(); } - wi = wi % bufsize; read_size = getMaxWriteSize(); } @@ -57,12 +52,51 @@ bool Client::readFdIntoBuffer() return true; } -void Client::writeTest() +bool Client::bufferToMqttPackets(std::vector &packetQueueIn) { - char *p = &readbuf[ri]; - size_t max_read = getMaxReadSize(); - ri = (ri + max_read) % bufsize; - write(fd, p, max_read); + while (getBufBytesUsed() >= MQTT_HEADER_LENGH) + { + // Determine the packet length by decoding the variable length + size_t remaining_length_i = 1; + int multiplier = 1; + size_t packet_length = 0; + unsigned char encodedByte = 0; + do + { + if (remaining_length_i >= getBufBytesUsed()) + break; + encodedByte = readbuf[remaining_length_i++]; + packet_length += (encodedByte & 127) * multiplier; + multiplier *= 128; + if (multiplier > 128*128*128) + return false; + } + while ((encodedByte & 128) != 0); + packet_length += remaining_length_i; + + // TODO: unauth client can't send many bytes + + if (packet_length <= getBufBytesUsed()) + { + MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this); + packetQueueIn.push_back(std::move(packet)); + + ri += packet_length; + } + else + break; + + } + + if (ri == wi) + { + ri = 0; + wi = 0; + } + + return true; + + // TODO: reset buffer to normal size after a while of not needing it, or not needing the extra space. } diff --git a/client.h b/client.h index 925c900..dbac6da 100644 --- a/client.h +++ b/client.h @@ -3,14 +3,19 @@ #include #include +#include #include "threaddata.h" +#include "mqttpacket.h" -#define CLIENT_BUFFER_SIZE 16 +#define CLIENT_BUFFER_SIZE 1024 +#define MQTT_HEADER_LENGH 2 class ThreadData; typedef std::shared_ptr ThreadData_p; +class MqttPacket; + class Client { int fd; @@ -20,31 +25,27 @@ class Client int wi = 0; int ri = 0; + bool authenticated = false; + std::string clientid; + ThreadData_p threadData; size_t getBufBytesUsed() { - size_t result = 0; - if (wi >= ri) - result = wi - ri; - else - result = (bufsize + wi) - ri; + return wi - ri; }; size_t getMaxWriteSize() { size_t available = bufsize - getBufBytesUsed(); - size_t space_at_end = bufsize - wi; - size_t answer = std::min(available, space_at_end); - return answer; + return available; } - size_t getMaxReadSize() + void growBuffer() { - size_t available = getBufBytesUsed(); - size_t space_to_end = bufsize - ri; - size_t answer = std::min(available, space_to_end); - return answer; + const size_t newBufSize = bufsize * 2; + readbuf = (char*)realloc(readbuf, newBufSize); + bufsize = newBufSize; } public: @@ -53,7 +54,7 @@ public: int getFd() { return fd;} bool readFdIntoBuffer(); - void writeTest(); + bool bufferToMqttPackets(std::vector &packetQueueIn); }; diff --git a/exceptions.cpp b/exceptions.cpp new file mode 100644 index 0000000..3a4a7aa --- /dev/null +++ b/exceptions.cpp @@ -0,0 +1,2 @@ +#include "exceptions.h" + diff --git a/exceptions.h b/exceptions.h new file mode 100644 index 0000000..c8c1856 --- /dev/null +++ b/exceptions.h @@ -0,0 +1,14 @@ +#ifndef EXCEPTIONS_H +#define EXCEPTIONS_H + +#include +#include + +class ProtocolError : public std::runtime_error +{ +public: + ProtocolError(const std::string &msg) : std::runtime_error(msg) {} +}; + + +#endif // EXCEPTIONS_H diff --git a/main.cpp b/main.cpp index 4851ea9..14632bd 100644 --- a/main.cpp +++ b/main.cpp @@ -10,6 +10,7 @@ #include "utils.h" #include "threaddata.h" #include "client.h" +#include "mqttpacket.h" #define MAX_EVENTS 1024 #define NR_OF_THREADS 4 @@ -23,6 +24,8 @@ void do_thread_work(ThreadData *threadData) struct epoll_event events[MAX_EVENTS]; memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); + std::vector packetQueueIn; + while (1) { int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); @@ -38,13 +41,23 @@ void do_thread_work(ThreadData *threadData) if (client) // TODO: is this check necessary? { - if (!client->readFdIntoBuffer()) - threadData->removeClient(client); - client->writeTest(); - + if (cur_ev.events | EPOLLIN) + { + if (!client->readFdIntoBuffer()) + threadData->removeClient(client); + else + { + client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer. + } + } } } } + + for (MqttPacket &packet : packetQueueIn) + { + packet.handle(); + } } } diff --git a/mqttpacket.cpp b/mqttpacket.cpp new file mode 100644 index 0000000..4872d34 --- /dev/null +++ b/mqttpacket.cpp @@ -0,0 +1,89 @@ +#include "mqttpacket.h" +#include + +MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) : // TODO: length of remaining length + bites(len), + fixed_header_length(fixed_header_length), + sender(sender) +{ + unsigned char _packetType = buf[0] >> 4; + packetType = (PacketType)_packetType; // TODO: veryify some other things and set to invalid if doesn't match + + std::memcpy(&bites[0], buf, len); +} + +void MqttPacket::handle() +{ + pos += fixed_header_length; + + if (packetType == PacketType::CONNECT) + handleConnect(); +} + +void MqttPacket::handleConnect() +{ + // TODO: Do all packets have a variable header? + variable_header_length = (bites[fixed_header_length] << 8) | (bites[fixed_header_length+1]); + pos += 2; + + if (variable_header_length == 4) + { + char *c = readBytes(variable_header_length); + std::string magic_marker(c, variable_header_length); + + char protocol_level = readByte(); + + if (magic_marker == "MQTT" && protocol_level == 0x04) + { + protocolVersion = ProtocolVersion::Mqtt311; + } + } + else if (variable_header_length == 6) + { + throw new ProtocolError("Only MQTT 3.1.1 implemented."); + } +} + +char *MqttPacket::readBytes(size_t length) +{ + if (pos + length > bites.size()) + throw ProtocolError("Invalid packet: header specifies invalid length."); + + char *b = &bites[pos]; + pos += length; + return b; +} + +char MqttPacket::readByte() +{ + if (pos + 1 > bites.size()) + throw ProtocolError("Invalid packet: header specifies invalid length."); + + char b = bites[pos]; + pos++; + return b; +} + + + +std::string MqttPacket::getClientId() +{ + if (packetType != PacketType::CONNECT) + throw ProtocolError("Can't get clientid from non-connect packet."); + + uint16_t clientid_length = (bites[fixed_header_length + 10] << 8) | (bites[fixed_header_length + 11]); + size_t client_id_start = fixed_header_length + 12; + + if (clientid_length + 12 < bites.size()) + { + std::string result(&bites[client_id_start], clientid_length); + return result; + } + + throw ProtocolError("Can't get clientid"); +} + + + + + diff --git a/mqttpacket.h b/mqttpacket.h new file mode 100644 index 0000000..fbba082 --- /dev/null +++ b/mqttpacket.h @@ -0,0 +1,39 @@ +#ifndef MQTTPACKET_H +#define MQTTPACKET_H + +#include "unistd.h" +#include +#include +#include + +#include "client.h" +#include "exceptions.h" +#include "types.h" + +class Client; + + +class MqttPacket +{ + bool valid = false; + + std::vector bites; + const size_t fixed_header_length; + uint16_t variable_header_length; + Client *sender; + std::string clientid; + size_t pos = 0; + ProtocolVersion protocolVersion = ProtocolVersion::None; +public: + PacketType packetType = PacketType::Reserved; + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender); + + bool isValid() { return valid; } + std::string getClientId(); + void handle(); + void handleConnect(); + char *readBytes(size_t length); + char readByte(); +}; + +#endif // MQTTPACKET_H diff --git a/types.cpp b/types.cpp new file mode 100644 index 0000000..868cc2f --- /dev/null +++ b/types.cpp @@ -0,0 +1,2 @@ +#include "types.h" + diff --git a/types.h b/types.h new file mode 100644 index 0000000..57598cb --- /dev/null +++ b/types.h @@ -0,0 +1,20 @@ +#ifndef TYPES_H +#define TYPES_H + +enum class PacketType +{ + Reserved = 0, + CONNECT = 1, + CONNACK = 2, + + Reserved2 = 15 +}; + +enum class ProtocolVersion +{ + None = 0, + Mqtt31 = 0x03, + Mqtt311 = 0x04 +}; + +#endif // TYPES_H -- libgit2 0.21.4