Commit 243c873f18ea5d69c5589c15b6a1c33cb66b998d

Authored by Wiebe Cazemier
1 parent d11fd364

Much stuff

CMakeLists.txt
@@ -6,6 +6,14 @@ set(CMAKE_CXX_STANDARD 11) @@ -6,6 +6,14 @@ set(CMAKE_CXX_STANDARD 11)
6 set(CMAKE_CXX_STANDARD_REQUIRED ON) 6 set(CMAKE_CXX_STANDARD_REQUIRED ON)
7 7
8 8
9 -add_executable(FlashMQ main.cpp utils.cpp threaddata.cpp client.cpp) 9 +add_executable(FlashMQ
  10 + main.cpp
  11 + utils.cpp
  12 + threaddata.cpp
  13 + client.cpp
  14 + bytestopacketparser.cpp
  15 + mqttpacket.cpp
  16 + exceptions.cpp
  17 + types.cpp)
10 18
11 target_link_libraries(FlashMQ pthread) 19 target_link_libraries(FlashMQ pthread)
MqttPacket.cpp 0 → 100644
bytestopacketparser.cpp 0 → 100644
  1 +#include "bytestopacketparser.h"
  2 +
  3 +BytesToPacketParser::BytesToPacketParser()
  4 +{
  5 +
  6 +}
  7 +
  8 +size_t BytesToPacketParser::bytesToPackets(char *buf, size_t len, std::vector<MqttPacket> &queue)
  9 +{
  10 +
  11 +}
bytestopacketparser.h 0 → 100644
  1 +#ifndef BYTESTOPACKETPARSER_H
  2 +#define BYTESTOPACKETPARSER_H
  3 +
  4 +#include <unistd.h>
  5 +#include <vector>
  6 +
  7 +#include "mqttpacket.h"
  8 +
  9 +#define MQTT_HEADER_LENGH 2
  10 +
  11 +class BytesToPacketParser
  12 +{
  13 +public:
  14 + BytesToPacketParser();
  15 +
  16 + size_t bytesToPackets(char *buf, size_t len, std::vector<MqttPacket> &queue);
  17 +};
  18 +
  19 +#endif // BYTESTOPACKETPARSER_H
client.cpp
@@ -35,17 +35,12 @@ bool Client::readFdIntoBuffer() @@ -35,17 +35,12 @@ bool Client::readFdIntoBuffer()
35 } 35 }
36 36
37 wi += n; 37 wi += n;
38 - size_t bytesUsed = getBufBytesUsed();  
39 38
40 - // TODO: we need a buffer to keep partial frames in, so/and can we reduce the size of this buffer again periodically?  
41 - if (bytesUsed >= bufsize) 39 + if (getBufBytesUsed() >= bufsize)
42 { 40 {
43 - const size_t newBufSize = bufsize * 2;  
44 - readbuf = (char*)realloc(readbuf, newBufSize);  
45 - bufsize = newBufSize; 41 + growBuffer();
46 } 42 }
47 43
48 - wi = wi % bufsize;  
49 read_size = getMaxWriteSize(); 44 read_size = getMaxWriteSize();
50 } 45 }
51 46
@@ -57,12 +52,51 @@ bool Client::readFdIntoBuffer() @@ -57,12 +52,51 @@ bool Client::readFdIntoBuffer()
57 return true; 52 return true;
58 } 53 }
59 54
60 -void Client::writeTest() 55 +bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn)
61 { 56 {
62 - char *p = &readbuf[ri];  
63 - size_t max_read = getMaxReadSize();  
64 - ri = (ri + max_read) % bufsize;  
65 - write(fd, p, max_read); 57 + while (getBufBytesUsed() >= MQTT_HEADER_LENGH)
  58 + {
  59 + // Determine the packet length by decoding the variable length
  60 + size_t remaining_length_i = 1;
  61 + int multiplier = 1;
  62 + size_t packet_length = 0;
  63 + unsigned char encodedByte = 0;
  64 + do
  65 + {
  66 + if (remaining_length_i >= getBufBytesUsed())
  67 + break;
  68 + encodedByte = readbuf[remaining_length_i++];
  69 + packet_length += (encodedByte & 127) * multiplier;
  70 + multiplier *= 128;
  71 + if (multiplier > 128*128*128)
  72 + return false;
  73 + }
  74 + while ((encodedByte & 128) != 0);
  75 + packet_length += remaining_length_i;
  76 +
  77 + // TODO: unauth client can't send many bytes
  78 +
  79 + if (packet_length <= getBufBytesUsed())
  80 + {
  81 + MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this);
  82 + packetQueueIn.push_back(std::move(packet));
  83 +
  84 + ri += packet_length;
  85 + }
  86 + else
  87 + break;
  88 +
  89 + }
  90 +
  91 + if (ri == wi)
  92 + {
  93 + ri = 0;
  94 + wi = 0;
  95 + }
  96 +
  97 + return true;
  98 +
  99 + // TODO: reset buffer to normal size after a while of not needing it, or not needing the extra space.
66 } 100 }
67 101
68 102
client.h
@@ -3,14 +3,19 @@ @@ -3,14 +3,19 @@
3 3
4 #include <fcntl.h> 4 #include <fcntl.h>
5 #include <unistd.h> 5 #include <unistd.h>
  6 +#include <vector>
6 7
7 #include "threaddata.h" 8 #include "threaddata.h"
  9 +#include "mqttpacket.h"
8 10
9 -#define CLIENT_BUFFER_SIZE 16 11 +#define CLIENT_BUFFER_SIZE 1024
  12 +#define MQTT_HEADER_LENGH 2
10 13
11 class ThreadData; 14 class ThreadData;
12 typedef std::shared_ptr<ThreadData> ThreadData_p; 15 typedef std::shared_ptr<ThreadData> ThreadData_p;
13 16
  17 +class MqttPacket;
  18 +
14 class Client 19 class Client
15 { 20 {
16 int fd; 21 int fd;
@@ -20,31 +25,27 @@ class Client @@ -20,31 +25,27 @@ class Client
20 int wi = 0; 25 int wi = 0;
21 int ri = 0; 26 int ri = 0;
22 27
  28 + bool authenticated = false;
  29 + std::string clientid;
  30 +
23 ThreadData_p threadData; 31 ThreadData_p threadData;
24 32
25 size_t getBufBytesUsed() 33 size_t getBufBytesUsed()
26 { 34 {
27 - size_t result = 0;  
28 - if (wi >= ri)  
29 - result = wi - ri;  
30 - else  
31 - result = (bufsize + wi) - ri; 35 + return wi - ri;
32 }; 36 };
33 37
34 size_t getMaxWriteSize() 38 size_t getMaxWriteSize()
35 { 39 {
36 size_t available = bufsize - getBufBytesUsed(); 40 size_t available = bufsize - getBufBytesUsed();
37 - size_t space_at_end = bufsize - wi;  
38 - size_t answer = std::min<int>(available, space_at_end);  
39 - return answer; 41 + return available;
40 } 42 }
41 43
42 - size_t getMaxReadSize() 44 + void growBuffer()
43 { 45 {
44 - size_t available = getBufBytesUsed();  
45 - size_t space_to_end = bufsize - ri;  
46 - size_t answer = std::min<int>(available, space_to_end);  
47 - return answer; 46 + const size_t newBufSize = bufsize * 2;
  47 + readbuf = (char*)realloc(readbuf, newBufSize);
  48 + bufsize = newBufSize;
48 } 49 }
49 50
50 public: 51 public:
@@ -53,7 +54,7 @@ public: @@ -53,7 +54,7 @@ public:
53 54
54 int getFd() { return fd;} 55 int getFd() { return fd;}
55 bool readFdIntoBuffer(); 56 bool readFdIntoBuffer();
56 - void writeTest(); 57 + bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn);
57 58
58 }; 59 };
59 60
exceptions.cpp 0 → 100644
  1 +#include "exceptions.h"
  2 +
exceptions.h 0 → 100644
  1 +#ifndef EXCEPTIONS_H
  2 +#define EXCEPTIONS_H
  3 +
  4 +#include <exception>
  5 +#include <stdexcept>
  6 +
  7 +class ProtocolError : public std::runtime_error
  8 +{
  9 +public:
  10 + ProtocolError(const std::string &msg) : std::runtime_error(msg) {}
  11 +};
  12 +
  13 +
  14 +#endif // EXCEPTIONS_H
main.cpp
@@ -10,6 +10,7 @@ @@ -10,6 +10,7 @@
10 #include "utils.h" 10 #include "utils.h"
11 #include "threaddata.h" 11 #include "threaddata.h"
12 #include "client.h" 12 #include "client.h"
  13 +#include "mqttpacket.h"
13 14
14 #define MAX_EVENTS 1024 15 #define MAX_EVENTS 1024
15 #define NR_OF_THREADS 4 16 #define NR_OF_THREADS 4
@@ -23,6 +24,8 @@ void do_thread_work(ThreadData *threadData) @@ -23,6 +24,8 @@ void do_thread_work(ThreadData *threadData)
23 struct epoll_event events[MAX_EVENTS]; 24 struct epoll_event events[MAX_EVENTS];
24 memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); 25 memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS);
25 26
  27 + std::vector<MqttPacket> packetQueueIn;
  28 +
26 while (1) 29 while (1)
27 { 30 {
28 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); 31 int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100);
@@ -38,13 +41,23 @@ void do_thread_work(ThreadData *threadData) @@ -38,13 +41,23 @@ void do_thread_work(ThreadData *threadData)
38 41
39 if (client) // TODO: is this check necessary? 42 if (client) // TODO: is this check necessary?
40 { 43 {
41 - if (!client->readFdIntoBuffer())  
42 - threadData->removeClient(client);  
43 - client->writeTest();  
44 - 44 + if (cur_ev.events | EPOLLIN)
  45 + {
  46 + if (!client->readFdIntoBuffer())
  47 + threadData->removeClient(client);
  48 + else
  49 + {
  50 + client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer.
  51 + }
  52 + }
45 } 53 }
46 } 54 }
47 } 55 }
  56 +
  57 + for (MqttPacket &packet : packetQueueIn)
  58 + {
  59 + packet.handle();
  60 + }
48 } 61 }
49 } 62 }
50 63
mqttpacket.cpp 0 → 100644
  1 +#include "mqttpacket.h"
  2 +#include <cstring>
  3 +
  4 +MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) : // TODO: length of remaining length
  5 + bites(len),
  6 + fixed_header_length(fixed_header_length),
  7 + sender(sender)
  8 +{
  9 + unsigned char _packetType = buf[0] >> 4;
  10 + packetType = (PacketType)_packetType; // TODO: veryify some other things and set to invalid if doesn't match
  11 +
  12 + std::memcpy(&bites[0], buf, len);
  13 +}
  14 +
  15 +void MqttPacket::handle()
  16 +{
  17 + pos += fixed_header_length;
  18 +
  19 + if (packetType == PacketType::CONNECT)
  20 + handleConnect();
  21 +}
  22 +
  23 +void MqttPacket::handleConnect()
  24 +{
  25 + // TODO: Do all packets have a variable header?
  26 + variable_header_length = (bites[fixed_header_length] << 8) | (bites[fixed_header_length+1]);
  27 + pos += 2;
  28 +
  29 + if (variable_header_length == 4)
  30 + {
  31 + char *c = readBytes(variable_header_length);
  32 + std::string magic_marker(c, variable_header_length);
  33 +
  34 + char protocol_level = readByte();
  35 +
  36 + if (magic_marker == "MQTT" && protocol_level == 0x04)
  37 + {
  38 + protocolVersion = ProtocolVersion::Mqtt311;
  39 + }
  40 + }
  41 + else if (variable_header_length == 6)
  42 + {
  43 + throw new ProtocolError("Only MQTT 3.1.1 implemented.");
  44 + }
  45 +}
  46 +
  47 +char *MqttPacket::readBytes(size_t length)
  48 +{
  49 + if (pos + length > bites.size())
  50 + throw ProtocolError("Invalid packet: header specifies invalid length.");
  51 +
  52 + char *b = &bites[pos];
  53 + pos += length;
  54 + return b;
  55 +}
  56 +
  57 +char MqttPacket::readByte()
  58 +{
  59 + if (pos + 1 > bites.size())
  60 + throw ProtocolError("Invalid packet: header specifies invalid length.");
  61 +
  62 + char b = bites[pos];
  63 + pos++;
  64 + return b;
  65 +}
  66 +
  67 +
  68 +
  69 +std::string MqttPacket::getClientId()
  70 +{
  71 + if (packetType != PacketType::CONNECT)
  72 + throw ProtocolError("Can't get clientid from non-connect packet.");
  73 +
  74 + uint16_t clientid_length = (bites[fixed_header_length + 10] << 8) | (bites[fixed_header_length + 11]);
  75 + size_t client_id_start = fixed_header_length + 12;
  76 +
  77 + if (clientid_length + 12 < bites.size())
  78 + {
  79 + std::string result(&bites[client_id_start], clientid_length);
  80 + return result;
  81 + }
  82 +
  83 + throw ProtocolError("Can't get clientid");
  84 +}
  85 +
  86 +
  87 +
  88 +
  89 +
mqttpacket.h 0 → 100644
  1 +#ifndef MQTTPACKET_H
  2 +#define MQTTPACKET_H
  3 +
  4 +#include "unistd.h"
  5 +#include <memory>
  6 +#include <vector>
  7 +#include <exception>
  8 +
  9 +#include "client.h"
  10 +#include "exceptions.h"
  11 +#include "types.h"
  12 +
  13 +class Client;
  14 +
  15 +
  16 +class MqttPacket
  17 +{
  18 + bool valid = false;
  19 +
  20 + std::vector<char> bites;
  21 + const size_t fixed_header_length;
  22 + uint16_t variable_header_length;
  23 + Client *sender;
  24 + std::string clientid;
  25 + size_t pos = 0;
  26 + ProtocolVersion protocolVersion = ProtocolVersion::None;
  27 +public:
  28 + PacketType packetType = PacketType::Reserved;
  29 + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender);
  30 +
  31 + bool isValid() { return valid; }
  32 + std::string getClientId();
  33 + void handle();
  34 + void handleConnect();
  35 + char *readBytes(size_t length);
  36 + char readByte();
  37 +};
  38 +
  39 +#endif // MQTTPACKET_H
types.cpp 0 → 100644
  1 +#include "types.h"
  2 +
types.h 0 → 100644
  1 +#ifndef TYPES_H
  2 +#define TYPES_H
  3 +
  4 +enum class PacketType
  5 +{
  6 + Reserved = 0,
  7 + CONNECT = 1,
  8 + CONNACK = 2,
  9 +
  10 + Reserved2 = 15
  11 +};
  12 +
  13 +enum class ProtocolVersion
  14 +{
  15 + None = 0,
  16 + Mqtt31 = 0x03,
  17 + Mqtt311 = 0x04
  18 +};
  19 +
  20 +#endif // TYPES_H