diff --git a/CMakeLists.txt b/CMakeLists.txt index 3d8484c..477563e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,14 @@ set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) -add_executable(FlashMQ main.cpp utils.cpp threaddata.cpp client.cpp) +add_executable(FlashMQ + main.cpp + utils.cpp + threaddata.cpp + client.cpp + bytestopacketparser.cpp + mqttpacket.cpp + exceptions.cpp + types.cpp) target_link_libraries(FlashMQ pthread) diff --git a/MqttPacket.cpp b/MqttPacket.cpp new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/MqttPacket.cpp diff --git a/bytestopacketparser.cpp b/bytestopacketparser.cpp new file mode 100644 index 0000000..01a600e --- /dev/null +++ b/bytestopacketparser.cpp @@ -0,0 +1,11 @@ +#include "bytestopacketparser.h" + +BytesToPacketParser::BytesToPacketParser() +{ + +} + +size_t BytesToPacketParser::bytesToPackets(char *buf, size_t len, std::vector &queue) +{ + +} diff --git a/bytestopacketparser.h b/bytestopacketparser.h new file mode 100644 index 0000000..9dfe370 --- /dev/null +++ b/bytestopacketparser.h @@ -0,0 +1,19 @@ +#ifndef BYTESTOPACKETPARSER_H +#define BYTESTOPACKETPARSER_H + +#include +#include + +#include "mqttpacket.h" + +#define MQTT_HEADER_LENGH 2 + +class BytesToPacketParser +{ +public: + BytesToPacketParser(); + + size_t bytesToPackets(char *buf, size_t len, std::vector &queue); +}; + +#endif // BYTESTOPACKETPARSER_H diff --git a/client.cpp b/client.cpp index 7452dd0..3d8318f 100644 --- a/client.cpp +++ b/client.cpp @@ -35,17 +35,12 @@ bool Client::readFdIntoBuffer() } wi += n; - size_t bytesUsed = getBufBytesUsed(); - // TODO: we need a buffer to keep partial frames in, so/and can we reduce the size of this buffer again periodically? - if (bytesUsed >= bufsize) + if (getBufBytesUsed() >= bufsize) { - const size_t newBufSize = bufsize * 2; - readbuf = (char*)realloc(readbuf, newBufSize); - bufsize = newBufSize; + growBuffer(); } - wi = wi % bufsize; read_size = getMaxWriteSize(); } @@ -57,12 +52,51 @@ bool Client::readFdIntoBuffer() return true; } -void Client::writeTest() +bool Client::bufferToMqttPackets(std::vector &packetQueueIn) { - char *p = &readbuf[ri]; - size_t max_read = getMaxReadSize(); - ri = (ri + max_read) % bufsize; - write(fd, p, max_read); + while (getBufBytesUsed() >= MQTT_HEADER_LENGH) + { + // Determine the packet length by decoding the variable length + size_t remaining_length_i = 1; + int multiplier = 1; + size_t packet_length = 0; + unsigned char encodedByte = 0; + do + { + if (remaining_length_i >= getBufBytesUsed()) + break; + encodedByte = readbuf[remaining_length_i++]; + packet_length += (encodedByte & 127) * multiplier; + multiplier *= 128; + if (multiplier > 128*128*128) + return false; + } + while ((encodedByte & 128) != 0); + packet_length += remaining_length_i; + + // TODO: unauth client can't send many bytes + + if (packet_length <= getBufBytesUsed()) + { + MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this); + packetQueueIn.push_back(std::move(packet)); + + ri += packet_length; + } + else + break; + + } + + if (ri == wi) + { + ri = 0; + wi = 0; + } + + return true; + + // TODO: reset buffer to normal size after a while of not needing it, or not needing the extra space. } diff --git a/client.h b/client.h index 925c900..dbac6da 100644 --- a/client.h +++ b/client.h @@ -3,14 +3,19 @@ #include #include +#include #include "threaddata.h" +#include "mqttpacket.h" -#define CLIENT_BUFFER_SIZE 16 +#define CLIENT_BUFFER_SIZE 1024 +#define MQTT_HEADER_LENGH 2 class ThreadData; typedef std::shared_ptr ThreadData_p; +class MqttPacket; + class Client { int fd; @@ -20,31 +25,27 @@ class Client int wi = 0; int ri = 0; + bool authenticated = false; + std::string clientid; + ThreadData_p threadData; size_t getBufBytesUsed() { - size_t result = 0; - if (wi >= ri) - result = wi - ri; - else - result = (bufsize + wi) - ri; + return wi - ri; }; size_t getMaxWriteSize() { size_t available = bufsize - getBufBytesUsed(); - size_t space_at_end = bufsize - wi; - size_t answer = std::min(available, space_at_end); - return answer; + return available; } - size_t getMaxReadSize() + void growBuffer() { - size_t available = getBufBytesUsed(); - size_t space_to_end = bufsize - ri; - size_t answer = std::min(available, space_to_end); - return answer; + const size_t newBufSize = bufsize * 2; + readbuf = (char*)realloc(readbuf, newBufSize); + bufsize = newBufSize; } public: @@ -53,7 +54,7 @@ public: int getFd() { return fd;} bool readFdIntoBuffer(); - void writeTest(); + bool bufferToMqttPackets(std::vector &packetQueueIn); }; diff --git a/exceptions.cpp b/exceptions.cpp new file mode 100644 index 0000000..3a4a7aa --- /dev/null +++ b/exceptions.cpp @@ -0,0 +1,2 @@ +#include "exceptions.h" + diff --git a/exceptions.h b/exceptions.h new file mode 100644 index 0000000..c8c1856 --- /dev/null +++ b/exceptions.h @@ -0,0 +1,14 @@ +#ifndef EXCEPTIONS_H +#define EXCEPTIONS_H + +#include +#include + +class ProtocolError : public std::runtime_error +{ +public: + ProtocolError(const std::string &msg) : std::runtime_error(msg) {} +}; + + +#endif // EXCEPTIONS_H diff --git a/main.cpp b/main.cpp index 4851ea9..14632bd 100644 --- a/main.cpp +++ b/main.cpp @@ -10,6 +10,7 @@ #include "utils.h" #include "threaddata.h" #include "client.h" +#include "mqttpacket.h" #define MAX_EVENTS 1024 #define NR_OF_THREADS 4 @@ -23,6 +24,8 @@ void do_thread_work(ThreadData *threadData) struct epoll_event events[MAX_EVENTS]; memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); + std::vector packetQueueIn; + while (1) { int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); @@ -38,13 +41,23 @@ void do_thread_work(ThreadData *threadData) if (client) // TODO: is this check necessary? { - if (!client->readFdIntoBuffer()) - threadData->removeClient(client); - client->writeTest(); - + if (cur_ev.events | EPOLLIN) + { + if (!client->readFdIntoBuffer()) + threadData->removeClient(client); + else + { + client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer. + } + } } } } + + for (MqttPacket &packet : packetQueueIn) + { + packet.handle(); + } } } diff --git a/mqttpacket.cpp b/mqttpacket.cpp new file mode 100644 index 0000000..4872d34 --- /dev/null +++ b/mqttpacket.cpp @@ -0,0 +1,89 @@ +#include "mqttpacket.h" +#include + +MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) : // TODO: length of remaining length + bites(len), + fixed_header_length(fixed_header_length), + sender(sender) +{ + unsigned char _packetType = buf[0] >> 4; + packetType = (PacketType)_packetType; // TODO: veryify some other things and set to invalid if doesn't match + + std::memcpy(&bites[0], buf, len); +} + +void MqttPacket::handle() +{ + pos += fixed_header_length; + + if (packetType == PacketType::CONNECT) + handleConnect(); +} + +void MqttPacket::handleConnect() +{ + // TODO: Do all packets have a variable header? + variable_header_length = (bites[fixed_header_length] << 8) | (bites[fixed_header_length+1]); + pos += 2; + + if (variable_header_length == 4) + { + char *c = readBytes(variable_header_length); + std::string magic_marker(c, variable_header_length); + + char protocol_level = readByte(); + + if (magic_marker == "MQTT" && protocol_level == 0x04) + { + protocolVersion = ProtocolVersion::Mqtt311; + } + } + else if (variable_header_length == 6) + { + throw new ProtocolError("Only MQTT 3.1.1 implemented."); + } +} + +char *MqttPacket::readBytes(size_t length) +{ + if (pos + length > bites.size()) + throw ProtocolError("Invalid packet: header specifies invalid length."); + + char *b = &bites[pos]; + pos += length; + return b; +} + +char MqttPacket::readByte() +{ + if (pos + 1 > bites.size()) + throw ProtocolError("Invalid packet: header specifies invalid length."); + + char b = bites[pos]; + pos++; + return b; +} + + + +std::string MqttPacket::getClientId() +{ + if (packetType != PacketType::CONNECT) + throw ProtocolError("Can't get clientid from non-connect packet."); + + uint16_t clientid_length = (bites[fixed_header_length + 10] << 8) | (bites[fixed_header_length + 11]); + size_t client_id_start = fixed_header_length + 12; + + if (clientid_length + 12 < bites.size()) + { + std::string result(&bites[client_id_start], clientid_length); + return result; + } + + throw ProtocolError("Can't get clientid"); +} + + + + + diff --git a/mqttpacket.h b/mqttpacket.h new file mode 100644 index 0000000..fbba082 --- /dev/null +++ b/mqttpacket.h @@ -0,0 +1,39 @@ +#ifndef MQTTPACKET_H +#define MQTTPACKET_H + +#include "unistd.h" +#include +#include +#include + +#include "client.h" +#include "exceptions.h" +#include "types.h" + +class Client; + + +class MqttPacket +{ + bool valid = false; + + std::vector bites; + const size_t fixed_header_length; + uint16_t variable_header_length; + Client *sender; + std::string clientid; + size_t pos = 0; + ProtocolVersion protocolVersion = ProtocolVersion::None; +public: + PacketType packetType = PacketType::Reserved; + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender); + + bool isValid() { return valid; } + std::string getClientId(); + void handle(); + void handleConnect(); + char *readBytes(size_t length); + char readByte(); +}; + +#endif // MQTTPACKET_H diff --git a/types.cpp b/types.cpp new file mode 100644 index 0000000..868cc2f --- /dev/null +++ b/types.cpp @@ -0,0 +1,2 @@ +#include "types.h" + diff --git a/types.h b/types.h new file mode 100644 index 0000000..57598cb --- /dev/null +++ b/types.h @@ -0,0 +1,20 @@ +#ifndef TYPES_H +#define TYPES_H + +enum class PacketType +{ + Reserved = 0, + CONNECT = 1, + CONNACK = 2, + + Reserved2 = 15 +}; + +enum class ProtocolVersion +{ + None = 0, + Mqtt31 = 0x03, + Mqtt311 = 0x04 +}; + +#endif // TYPES_H