Commit 243c873f18ea5d69c5589c15b6a1c33cb66b998d

Authored by Wiebe Cazemier
1 parent d11fd364

Much stuff

CMakeLists.txt
... ... @@ -6,6 +6,14 @@ set(CMAKE_CXX_STANDARD 11)
6 6 set(CMAKE_CXX_STANDARD_REQUIRED ON)
7 7  
8 8  
9   -add_executable(FlashMQ main.cpp utils.cpp threaddata.cpp client.cpp)
  9 +add_executable(FlashMQ
  10 + main.cpp
  11 + utils.cpp
  12 + threaddata.cpp
  13 + client.cpp
  14 + bytestopacketparser.cpp
  15 + mqttpacket.cpp
  16 + exceptions.cpp
  17 + types.cpp)
10 18  
11 19 target_link_libraries(FlashMQ pthread)
... ...
MqttPacket.cpp 0 → 100644
bytestopacketparser.cpp 0 → 100644
  1 +#include "bytestopacketparser.h"
  2 +
  3 +BytesToPacketParser::BytesToPacketParser()
  4 +{
  5 +
  6 +}
  7 +
  8 +size_t BytesToPacketParser::bytesToPackets(char *buf, size_t len, std::vector<MqttPacket> &queue)
  9 +{
  10 +
  11 +}
... ...
bytestopacketparser.h 0 → 100644
  1 +#ifndef BYTESTOPACKETPARSER_H
  2 +#define BYTESTOPACKETPARSER_H
  3 +
  4 +#include <unistd.h>
  5 +#include <vector>
  6 +
  7 +#include "mqttpacket.h"
  8 +
  9 +#define MQTT_HEADER_LENGH 2
  10 +
  11 +class BytesToPacketParser
  12 +{
  13 +public:
  14 + BytesToPacketParser();
  15 +
  16 + size_t bytesToPackets(char *buf, size_t len, std::vector<MqttPacket> &queue);
  17 +};
  18 +
  19 +#endif // BYTESTOPACKETPARSER_H
... ...
client.cpp
... ... @@ -35,17 +35,12 @@ bool Client::readFdIntoBuffer()
35 35 }
36 36  
37 37 wi += n;
38   - size_t bytesUsed = getBufBytesUsed();
39 38  
40   - // TODO: we need a buffer to keep partial frames in, so/and can we reduce the size of this buffer again periodically?
41   - if (bytesUsed >= bufsize)
  39 + if (getBufBytesUsed() >= bufsize)
42 40 {
43   - const size_t newBufSize = bufsize * 2;
44   - readbuf = (char*)realloc(readbuf, newBufSize);
45   - bufsize = newBufSize;
  41 + growBuffer();
46 42 }
47 43  
48   - wi = wi % bufsize;
49 44 read_size = getMaxWriteSize();
50 45 }
51 46  
... ... @@ -57,12 +52,51 @@ bool Client::readFdIntoBuffer()
57 52 return true;
58 53 }
59 54  
60   -void Client::writeTest()
  55 +bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn)
61 56 {
62   - char *p = &readbuf[ri];
63   - size_t max_read = getMaxReadSize();
64   - ri = (ri + max_read) % bufsize;
65   - write(fd, p, max_read);
  57 + while (getBufBytesUsed() >= MQTT_HEADER_LENGH)
  58 + {
  59 + // Determine the packet length by decoding the variable length
  60 + size_t remaining_length_i = 1;
  61 + int multiplier = 1;
  62 + size_t packet_length = 0;
  63 + unsigned char encodedByte = 0;
  64 + do
  65 + {
  66 + if (remaining_length_i >= getBufBytesUsed())
  67 + break;
  68 + encodedByte = readbuf[remaining_length_i++];
  69 + packet_length += (encodedByte & 127) * multiplier;
  70 + multiplier *= 128;
  71 + if (multiplier > 128*128*128)
  72 + return false;
  73 + }
  74 + while ((encodedByte & 128) != 0);
  75 + packet_length += remaining_length_i;
  76 +
  77 + // TODO: unauth client can't send many bytes
  78 +
  79 + if (packet_length <= getBufBytesUsed())
  80 + {
  81 + MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this);
  82 + packetQueueIn.push_back(std::move(packet));
  83 +
  84 + ri += packet_length;
  85 + }
  86 + else
  87 + break;
  88 +
  89 + }
  90 +
  91 + if (ri == wi)
  92 + {
  93 + ri = 0;
  94 + wi = 0;
  95 + }
  96 +
  97 + return true;
  98 +
  99 + // TODO: reset buffer to normal size after a while of not needing it, or not needing the extra space.
66 100 }
67 101  
68 102  
... ...
client.h
... ... @@ -3,14 +3,19 @@
3 3  
4 4 #include <fcntl.h>
5 5 #include <unistd.h>
  6 +#include <vector>
6 7  
7 8 #include "threaddata.h"
  9 +#include "mqttpacket.h"
8 10  
9   -#define CLIENT_BUFFER_SIZE 16
  11 +#define CLIENT_BUFFER_SIZE 1024
  12 +#define MQTT_HEADER_LENGH 2
10 13  
11 14 class ThreadData;
12 15 typedef std::shared_ptr<ThreadData> ThreadData_p;
13 16  
  17 +class MqttPacket;
  18 +
14 19 class Client
15 20 {
16 21 int fd;
... ... @@ -20,31 +25,27 @@ class Client
20 25 int wi = 0;
21 26 int ri = 0;
22 27  
  28 + bool authenticated = false;
  29 + std::string clientid;
  30 +
23 31 ThreadData_p threadData;
24 32  
25 33 size_t getBufBytesUsed()
26 34 {
27   - size_t result = 0;
28   - if (wi >= ri)
29   - result = wi - ri;
30   - else
31   - result = (bufsize + wi) - ri;
  35 + return wi - ri;
32 36 };
33 37  
34 38 size_t getMaxWriteSize()
35 39 {
36 40 size_t available = bufsize - getBufBytesUsed();
37   - size_t space_at_end = bufsize - wi;
38   - size_t answer = std::min<int>(available, space_at_end);
39   - return answer;
  41 + return available;
40 42 }
41 43  
42   - size_t getMaxReadSize()
  44 + void growBuffer()
43 45 {
44   - size_t available = getBufBytesUsed();
45   - size_t space_to_end = bufsize - ri;
46   - size_t answer = std::min<int>(available, space_to_end);
47   - return answer;
  46 + const size_t newBufSize = bufsize * 2;
  47 + readbuf = (char*)realloc(readbuf, newBufSize);
  48 + bufsize = newBufSize;
48 49 }
49 50  
50 51 public:
... ... @@ -53,7 +54,7 @@ public:
53 54  
54 55 int getFd() { return fd;}
55 56 bool readFdIntoBuffer();
56   - void writeTest();
  57 + bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn);
57 58  
58 59 };
59 60  
... ...
exceptions.cpp 0 → 100644
  1 +#include "exceptions.h"
  2 +
... ...
exceptions.h 0 → 100644
  1 +#ifndef EXCEPTIONS_H
  2 +#define EXCEPTIONS_H
  3 +
  4 +#include <exception>
  5 +#include <stdexcept>
  6 +
  7 +class ProtocolError : public std::runtime_error
  8 +{
  9 +public:
  10 + ProtocolError(const std::string &msg) : std::runtime_error(msg) {}
  11 +};
  12 +
  13 +
  14 +#endif // EXCEPTIONS_H
... ...
main.cpp
... ... @@ -10,6 +10,7 @@
10 10 #include "utils.h"
11 11 #include "threaddata.h"
12 12 #include "client.h"
  13 +#include "mqttpacket.h"
13 14  
14 15 #define MAX_EVENTS 1024
15 16 #define NR_OF_THREADS 4
... ... @@ -23,6 +24,8 @@ void do_thread_work(ThreadData *threadData)
23 24 struct epoll_event events[MAX_EVENTS];
24 25 memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
25 26  
  27 + std::vector<MqttPacket> packetQueueIn;
  28 +
26 29 while (1)
27 30 {
28 31 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100);
... ... @@ -38,13 +41,23 @@ void do_thread_work(ThreadData *threadData)
38 41  
39 42 if (client) // TODO: is this check necessary?
40 43 {
41   - if (!client->readFdIntoBuffer())
42   - threadData->removeClient(client);
43   - client->writeTest();
44   -
  44 + if (cur_ev.events | EPOLLIN)
  45 + {
  46 + if (!client->readFdIntoBuffer())
  47 + threadData->removeClient(client);
  48 + else
  49 + {
  50 + client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer.
  51 + }
  52 + }
45 53 }
46 54 }
47 55 }
  56 +
  57 + for (MqttPacket &packet : packetQueueIn)
  58 + {
  59 + packet.handle();
  60 + }
48 61 }
49 62 }
50 63  
... ...
mqttpacket.cpp 0 → 100644
  1 +#include "mqttpacket.h"
  2 +#include <cstring>
  3 +
  4 +MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) : // TODO: length of remaining length
  5 + bites(len),
  6 + fixed_header_length(fixed_header_length),
  7 + sender(sender)
  8 +{
  9 + unsigned char _packetType = buf[0] >> 4;
  10 + packetType = (PacketType)_packetType; // TODO: veryify some other things and set to invalid if doesn't match
  11 +
  12 + std::memcpy(&bites[0], buf, len);
  13 +}
  14 +
  15 +void MqttPacket::handle()
  16 +{
  17 + pos += fixed_header_length;
  18 +
  19 + if (packetType == PacketType::CONNECT)
  20 + handleConnect();
  21 +}
  22 +
  23 +void MqttPacket::handleConnect()
  24 +{
  25 + // TODO: Do all packets have a variable header?
  26 + variable_header_length = (bites[fixed_header_length] << 8) | (bites[fixed_header_length+1]);
  27 + pos += 2;
  28 +
  29 + if (variable_header_length == 4)
  30 + {
  31 + char *c = readBytes(variable_header_length);
  32 + std::string magic_marker(c, variable_header_length);
  33 +
  34 + char protocol_level = readByte();
  35 +
  36 + if (magic_marker == "MQTT" && protocol_level == 0x04)
  37 + {
  38 + protocolVersion = ProtocolVersion::Mqtt311;
  39 + }
  40 + }
  41 + else if (variable_header_length == 6)
  42 + {
  43 + throw new ProtocolError("Only MQTT 3.1.1 implemented.");
  44 + }
  45 +}
  46 +
  47 +char *MqttPacket::readBytes(size_t length)
  48 +{
  49 + if (pos + length > bites.size())
  50 + throw ProtocolError("Invalid packet: header specifies invalid length.");
  51 +
  52 + char *b = &bites[pos];
  53 + pos += length;
  54 + return b;
  55 +}
  56 +
  57 +char MqttPacket::readByte()
  58 +{
  59 + if (pos + 1 > bites.size())
  60 + throw ProtocolError("Invalid packet: header specifies invalid length.");
  61 +
  62 + char b = bites[pos];
  63 + pos++;
  64 + return b;
  65 +}
  66 +
  67 +
  68 +
  69 +std::string MqttPacket::getClientId()
  70 +{
  71 + if (packetType != PacketType::CONNECT)
  72 + throw ProtocolError("Can't get clientid from non-connect packet.");
  73 +
  74 + uint16_t clientid_length = (bites[fixed_header_length + 10] << 8) | (bites[fixed_header_length + 11]);
  75 + size_t client_id_start = fixed_header_length + 12;
  76 +
  77 + if (clientid_length + 12 < bites.size())
  78 + {
  79 + std::string result(&bites[client_id_start], clientid_length);
  80 + return result;
  81 + }
  82 +
  83 + throw ProtocolError("Can't get clientid");
  84 +}
  85 +
  86 +
  87 +
  88 +
  89 +
... ...
mqttpacket.h 0 → 100644
  1 +#ifndef MQTTPACKET_H
  2 +#define MQTTPACKET_H
  3 +
  4 +#include "unistd.h"
  5 +#include <memory>
  6 +#include <vector>
  7 +#include <exception>
  8 +
  9 +#include "client.h"
  10 +#include "exceptions.h"
  11 +#include "types.h"
  12 +
  13 +class Client;
  14 +
  15 +
  16 +class MqttPacket
  17 +{
  18 + bool valid = false;
  19 +
  20 + std::vector<char> bites;
  21 + const size_t fixed_header_length;
  22 + uint16_t variable_header_length;
  23 + Client *sender;
  24 + std::string clientid;
  25 + size_t pos = 0;
  26 + ProtocolVersion protocolVersion = ProtocolVersion::None;
  27 +public:
  28 + PacketType packetType = PacketType::Reserved;
  29 + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender);
  30 +
  31 + bool isValid() { return valid; }
  32 + std::string getClientId();
  33 + void handle();
  34 + void handleConnect();
  35 + char *readBytes(size_t length);
  36 + char readByte();
  37 +};
  38 +
  39 +#endif // MQTTPACKET_H
... ...
types.cpp 0 → 100644
  1 +#include "types.h"
  2 +
... ...
types.h 0 → 100644
  1 +#ifndef TYPES_H
  2 +#define TYPES_H
  3 +
  4 +enum class PacketType
  5 +{
  6 + Reserved = 0,
  7 + CONNECT = 1,
  8 + CONNACK = 2,
  9 +
  10 + Reserved2 = 15
  11 +};
  12 +
  13 +enum class ProtocolVersion
  14 +{
  15 + None = 0,
  16 + Mqtt31 = 0x03,
  17 + Mqtt311 = 0x04
  18 +};
  19 +
  20 +#endif // TYPES_H
... ...