Commit 243c873f18ea5d69c5589c15b6a1c33cb66b998d
1 parent
d11fd364
Much stuff
Showing
13 changed files
with
284 additions
and
32 deletions
CMakeLists.txt
| ... | ... | @@ -6,6 +6,14 @@ set(CMAKE_CXX_STANDARD 11) |
| 6 | 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) |
| 7 | 7 | |
| 8 | 8 | |
| 9 | -add_executable(FlashMQ main.cpp utils.cpp threaddata.cpp client.cpp) | |
| 9 | +add_executable(FlashMQ | |
| 10 | + main.cpp | |
| 11 | + utils.cpp | |
| 12 | + threaddata.cpp | |
| 13 | + client.cpp | |
| 14 | + bytestopacketparser.cpp | |
| 15 | + mqttpacket.cpp | |
| 16 | + exceptions.cpp | |
| 17 | + types.cpp) | |
| 10 | 18 | |
| 11 | 19 | target_link_libraries(FlashMQ pthread) | ... | ... |
MqttPacket.cpp
0 → 100644
bytestopacketparser.cpp
0 → 100644
bytestopacketparser.h
0 → 100644
| 1 | +#ifndef BYTESTOPACKETPARSER_H | |
| 2 | +#define BYTESTOPACKETPARSER_H | |
| 3 | + | |
| 4 | +#include <unistd.h> | |
| 5 | +#include <vector> | |
| 6 | + | |
| 7 | +#include "mqttpacket.h" | |
| 8 | + | |
| 9 | +#define MQTT_HEADER_LENGH 2 | |
| 10 | + | |
| 11 | +class BytesToPacketParser | |
| 12 | +{ | |
| 13 | +public: | |
| 14 | + BytesToPacketParser(); | |
| 15 | + | |
| 16 | + size_t bytesToPackets(char *buf, size_t len, std::vector<MqttPacket> &queue); | |
| 17 | +}; | |
| 18 | + | |
| 19 | +#endif // BYTESTOPACKETPARSER_H | ... | ... |
client.cpp
| ... | ... | @@ -35,17 +35,12 @@ bool Client::readFdIntoBuffer() |
| 35 | 35 | } |
| 36 | 36 | |
| 37 | 37 | wi += n; |
| 38 | - size_t bytesUsed = getBufBytesUsed(); | |
| 39 | 38 | |
| 40 | - // TODO: we need a buffer to keep partial frames in, so/and can we reduce the size of this buffer again periodically? | |
| 41 | - if (bytesUsed >= bufsize) | |
| 39 | + if (getBufBytesUsed() >= bufsize) | |
| 42 | 40 | { |
| 43 | - const size_t newBufSize = bufsize * 2; | |
| 44 | - readbuf = (char*)realloc(readbuf, newBufSize); | |
| 45 | - bufsize = newBufSize; | |
| 41 | + growBuffer(); | |
| 46 | 42 | } |
| 47 | 43 | |
| 48 | - wi = wi % bufsize; | |
| 49 | 44 | read_size = getMaxWriteSize(); |
| 50 | 45 | } |
| 51 | 46 | |
| ... | ... | @@ -57,12 +52,51 @@ bool Client::readFdIntoBuffer() |
| 57 | 52 | return true; |
| 58 | 53 | } |
| 59 | 54 | |
| 60 | -void Client::writeTest() | |
| 55 | +bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn) | |
| 61 | 56 | { |
| 62 | - char *p = &readbuf[ri]; | |
| 63 | - size_t max_read = getMaxReadSize(); | |
| 64 | - ri = (ri + max_read) % bufsize; | |
| 65 | - write(fd, p, max_read); | |
| 57 | + while (getBufBytesUsed() >= MQTT_HEADER_LENGH) | |
| 58 | + { | |
| 59 | + // Determine the packet length by decoding the variable length | |
| 60 | + size_t remaining_length_i = 1; | |
| 61 | + int multiplier = 1; | |
| 62 | + size_t packet_length = 0; | |
| 63 | + unsigned char encodedByte = 0; | |
| 64 | + do | |
| 65 | + { | |
| 66 | + if (remaining_length_i >= getBufBytesUsed()) | |
| 67 | + break; | |
| 68 | + encodedByte = readbuf[remaining_length_i++]; | |
| 69 | + packet_length += (encodedByte & 127) * multiplier; | |
| 70 | + multiplier *= 128; | |
| 71 | + if (multiplier > 128*128*128) | |
| 72 | + return false; | |
| 73 | + } | |
| 74 | + while ((encodedByte & 128) != 0); | |
| 75 | + packet_length += remaining_length_i; | |
| 76 | + | |
| 77 | + // TODO: unauth client can't send many bytes | |
| 78 | + | |
| 79 | + if (packet_length <= getBufBytesUsed()) | |
| 80 | + { | |
| 81 | + MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this); | |
| 82 | + packetQueueIn.push_back(std::move(packet)); | |
| 83 | + | |
| 84 | + ri += packet_length; | |
| 85 | + } | |
| 86 | + else | |
| 87 | + break; | |
| 88 | + | |
| 89 | + } | |
| 90 | + | |
| 91 | + if (ri == wi) | |
| 92 | + { | |
| 93 | + ri = 0; | |
| 94 | + wi = 0; | |
| 95 | + } | |
| 96 | + | |
| 97 | + return true; | |
| 98 | + | |
| 99 | + // TODO: reset buffer to normal size after a while of not needing it, or not needing the extra space. | |
| 66 | 100 | } |
| 67 | 101 | |
| 68 | 102 | ... | ... |
client.h
| ... | ... | @@ -3,14 +3,19 @@ |
| 3 | 3 | |
| 4 | 4 | #include <fcntl.h> |
| 5 | 5 | #include <unistd.h> |
| 6 | +#include <vector> | |
| 6 | 7 | |
| 7 | 8 | #include "threaddata.h" |
| 9 | +#include "mqttpacket.h" | |
| 8 | 10 | |
| 9 | -#define CLIENT_BUFFER_SIZE 16 | |
| 11 | +#define CLIENT_BUFFER_SIZE 1024 | |
| 12 | +#define MQTT_HEADER_LENGH 2 | |
| 10 | 13 | |
| 11 | 14 | class ThreadData; |
| 12 | 15 | typedef std::shared_ptr<ThreadData> ThreadData_p; |
| 13 | 16 | |
| 17 | +class MqttPacket; | |
| 18 | + | |
| 14 | 19 | class Client |
| 15 | 20 | { |
| 16 | 21 | int fd; |
| ... | ... | @@ -20,31 +25,27 @@ class Client |
| 20 | 25 | int wi = 0; |
| 21 | 26 | int ri = 0; |
| 22 | 27 | |
| 28 | + bool authenticated = false; | |
| 29 | + std::string clientid; | |
| 30 | + | |
| 23 | 31 | ThreadData_p threadData; |
| 24 | 32 | |
| 25 | 33 | size_t getBufBytesUsed() |
| 26 | 34 | { |
| 27 | - size_t result = 0; | |
| 28 | - if (wi >= ri) | |
| 29 | - result = wi - ri; | |
| 30 | - else | |
| 31 | - result = (bufsize + wi) - ri; | |
| 35 | + return wi - ri; | |
| 32 | 36 | }; |
| 33 | 37 | |
| 34 | 38 | size_t getMaxWriteSize() |
| 35 | 39 | { |
| 36 | 40 | size_t available = bufsize - getBufBytesUsed(); |
| 37 | - size_t space_at_end = bufsize - wi; | |
| 38 | - size_t answer = std::min<int>(available, space_at_end); | |
| 39 | - return answer; | |
| 41 | + return available; | |
| 40 | 42 | } |
| 41 | 43 | |
| 42 | - size_t getMaxReadSize() | |
| 44 | + void growBuffer() | |
| 43 | 45 | { |
| 44 | - size_t available = getBufBytesUsed(); | |
| 45 | - size_t space_to_end = bufsize - ri; | |
| 46 | - size_t answer = std::min<int>(available, space_to_end); | |
| 47 | - return answer; | |
| 46 | + const size_t newBufSize = bufsize * 2; | |
| 47 | + readbuf = (char*)realloc(readbuf, newBufSize); | |
| 48 | + bufsize = newBufSize; | |
| 48 | 49 | } |
| 49 | 50 | |
| 50 | 51 | public: |
| ... | ... | @@ -53,7 +54,7 @@ public: |
| 53 | 54 | |
| 54 | 55 | int getFd() { return fd;} |
| 55 | 56 | bool readFdIntoBuffer(); |
| 56 | - void writeTest(); | |
| 57 | + bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn); | |
| 57 | 58 | |
| 58 | 59 | }; |
| 59 | 60 | ... | ... |
exceptions.cpp
0 → 100644
exceptions.h
0 → 100644
main.cpp
| ... | ... | @@ -10,6 +10,7 @@ |
| 10 | 10 | #include "utils.h" |
| 11 | 11 | #include "threaddata.h" |
| 12 | 12 | #include "client.h" |
| 13 | +#include "mqttpacket.h" | |
| 13 | 14 | |
| 14 | 15 | #define MAX_EVENTS 1024 |
| 15 | 16 | #define NR_OF_THREADS 4 |
| ... | ... | @@ -23,6 +24,8 @@ void do_thread_work(ThreadData *threadData) |
| 23 | 24 | struct epoll_event events[MAX_EVENTS]; |
| 24 | 25 | memset(&events, 0, sizeof (struct epoll_event)*MAX_EVENTS); |
| 25 | 26 | |
| 27 | + std::vector<MqttPacket> packetQueueIn; | |
| 28 | + | |
| 26 | 29 | while (1) |
| 27 | 30 | { |
| 28 | 31 | int fdcount = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); |
| ... | ... | @@ -38,13 +41,23 @@ void do_thread_work(ThreadData *threadData) |
| 38 | 41 | |
| 39 | 42 | if (client) // TODO: is this check necessary? |
| 40 | 43 | { |
| 41 | - if (!client->readFdIntoBuffer()) | |
| 42 | - threadData->removeClient(client); | |
| 43 | - client->writeTest(); | |
| 44 | - | |
| 44 | + if (cur_ev.events | EPOLLIN) | |
| 45 | + { | |
| 46 | + if (!client->readFdIntoBuffer()) | |
| 47 | + threadData->removeClient(client); | |
| 48 | + else | |
| 49 | + { | |
| 50 | + client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer. | |
| 51 | + } | |
| 52 | + } | |
| 45 | 53 | } |
| 46 | 54 | } |
| 47 | 55 | } |
| 56 | + | |
| 57 | + for (MqttPacket &packet : packetQueueIn) | |
| 58 | + { | |
| 59 | + packet.handle(); | |
| 60 | + } | |
| 48 | 61 | } |
| 49 | 62 | } |
| 50 | 63 | ... | ... |
mqttpacket.cpp
0 → 100644
| 1 | +#include "mqttpacket.h" | |
| 2 | +#include <cstring> | |
| 3 | + | |
| 4 | +MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) : // TODO: length of remaining length | |
| 5 | + bites(len), | |
| 6 | + fixed_header_length(fixed_header_length), | |
| 7 | + sender(sender) | |
| 8 | +{ | |
| 9 | + unsigned char _packetType = buf[0] >> 4; | |
| 10 | + packetType = (PacketType)_packetType; // TODO: veryify some other things and set to invalid if doesn't match | |
| 11 | + | |
| 12 | + std::memcpy(&bites[0], buf, len); | |
| 13 | +} | |
| 14 | + | |
| 15 | +void MqttPacket::handle() | |
| 16 | +{ | |
| 17 | + pos += fixed_header_length; | |
| 18 | + | |
| 19 | + if (packetType == PacketType::CONNECT) | |
| 20 | + handleConnect(); | |
| 21 | +} | |
| 22 | + | |
| 23 | +void MqttPacket::handleConnect() | |
| 24 | +{ | |
| 25 | + // TODO: Do all packets have a variable header? | |
| 26 | + variable_header_length = (bites[fixed_header_length] << 8) | (bites[fixed_header_length+1]); | |
| 27 | + pos += 2; | |
| 28 | + | |
| 29 | + if (variable_header_length == 4) | |
| 30 | + { | |
| 31 | + char *c = readBytes(variable_header_length); | |
| 32 | + std::string magic_marker(c, variable_header_length); | |
| 33 | + | |
| 34 | + char protocol_level = readByte(); | |
| 35 | + | |
| 36 | + if (magic_marker == "MQTT" && protocol_level == 0x04) | |
| 37 | + { | |
| 38 | + protocolVersion = ProtocolVersion::Mqtt311; | |
| 39 | + } | |
| 40 | + } | |
| 41 | + else if (variable_header_length == 6) | |
| 42 | + { | |
| 43 | + throw new ProtocolError("Only MQTT 3.1.1 implemented."); | |
| 44 | + } | |
| 45 | +} | |
| 46 | + | |
| 47 | +char *MqttPacket::readBytes(size_t length) | |
| 48 | +{ | |
| 49 | + if (pos + length > bites.size()) | |
| 50 | + throw ProtocolError("Invalid packet: header specifies invalid length."); | |
| 51 | + | |
| 52 | + char *b = &bites[pos]; | |
| 53 | + pos += length; | |
| 54 | + return b; | |
| 55 | +} | |
| 56 | + | |
| 57 | +char MqttPacket::readByte() | |
| 58 | +{ | |
| 59 | + if (pos + 1 > bites.size()) | |
| 60 | + throw ProtocolError("Invalid packet: header specifies invalid length."); | |
| 61 | + | |
| 62 | + char b = bites[pos]; | |
| 63 | + pos++; | |
| 64 | + return b; | |
| 65 | +} | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | +std::string MqttPacket::getClientId() | |
| 70 | +{ | |
| 71 | + if (packetType != PacketType::CONNECT) | |
| 72 | + throw ProtocolError("Can't get clientid from non-connect packet."); | |
| 73 | + | |
| 74 | + uint16_t clientid_length = (bites[fixed_header_length + 10] << 8) | (bites[fixed_header_length + 11]); | |
| 75 | + size_t client_id_start = fixed_header_length + 12; | |
| 76 | + | |
| 77 | + if (clientid_length + 12 < bites.size()) | |
| 78 | + { | |
| 79 | + std::string result(&bites[client_id_start], clientid_length); | |
| 80 | + return result; | |
| 81 | + } | |
| 82 | + | |
| 83 | + throw ProtocolError("Can't get clientid"); | |
| 84 | +} | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | ... | ... |
mqttpacket.h
0 → 100644
| 1 | +#ifndef MQTTPACKET_H | |
| 2 | +#define MQTTPACKET_H | |
| 3 | + | |
| 4 | +#include "unistd.h" | |
| 5 | +#include <memory> | |
| 6 | +#include <vector> | |
| 7 | +#include <exception> | |
| 8 | + | |
| 9 | +#include "client.h" | |
| 10 | +#include "exceptions.h" | |
| 11 | +#include "types.h" | |
| 12 | + | |
| 13 | +class Client; | |
| 14 | + | |
| 15 | + | |
| 16 | +class MqttPacket | |
| 17 | +{ | |
| 18 | + bool valid = false; | |
| 19 | + | |
| 20 | + std::vector<char> bites; | |
| 21 | + const size_t fixed_header_length; | |
| 22 | + uint16_t variable_header_length; | |
| 23 | + Client *sender; | |
| 24 | + std::string clientid; | |
| 25 | + size_t pos = 0; | |
| 26 | + ProtocolVersion protocolVersion = ProtocolVersion::None; | |
| 27 | +public: | |
| 28 | + PacketType packetType = PacketType::Reserved; | |
| 29 | + MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender); | |
| 30 | + | |
| 31 | + bool isValid() { return valid; } | |
| 32 | + std::string getClientId(); | |
| 33 | + void handle(); | |
| 34 | + void handleConnect(); | |
| 35 | + char *readBytes(size_t length); | |
| 36 | + char readByte(); | |
| 37 | +}; | |
| 38 | + | |
| 39 | +#endif // MQTTPACKET_H | ... | ... |
types.cpp
0 → 100644
types.h
0 → 100644