Commit 5dcad43303fd02ed7c68b16d90b29717b119fb7f

Authored by Wiebe Cazemier
1 parent d61c593d

Write response

client.cpp
1 1 #include "client.h"
2 2  
  3 +#include <cstring>
  4 +#include <sstream>
  5 +
3 6 Client::Client(int fd, ThreadData_p threadData) :
4 7 fd(fd),
5 8 threadData(threadData)
... ... @@ -7,6 +10,7 @@ Client::Client(int fd, ThreadData_p threadData) :
7 10 int flags = fcntl(fd, F_GETFL);
8 11 fcntl(fd, F_SETFL, flags | O_NONBLOCK);
9 12 readbuf = (char*)malloc(CLIENT_BUFFER_SIZE);
  13 + writebuf = (char*)malloc(CLIENT_BUFFER_SIZE);
10 14 }
11 15  
12 16 Client::~Client()
... ... @@ -19,7 +23,7 @@ Client::~Client()
19 23 // false means any kind of error we want to get rid of the client for.
20 24 bool Client::readFdIntoBuffer()
21 25 {
22   - int read_size = getMaxWriteSize();
  26 + int read_size = getReadBufMaxWriteSize();
23 27  
24 28 int n;
25 29 while ((n = read(fd, &readbuf[wi], read_size)) != 0)
... ... @@ -36,12 +40,12 @@ bool Client::readFdIntoBuffer()
36 40  
37 41 wi += n;
38 42  
39   - if (getBufBytesUsed() >= bufsize)
  43 + if (getReadBufBytesUsed() >= readBufsize)
40 44 {
41   - growBuffer();
  45 + growReadBuffer();
42 46 }
43 47  
44   - read_size = getMaxWriteSize();
  48 + read_size = getReadBufMaxWriteSize();
45 49 }
46 50  
47 51 if (n == 0) // client disconnected.
... ... @@ -52,9 +56,52 @@ bool Client::readFdIntoBuffer()
52 56 return true;
53 57 }
54 58  
  59 +void Client::writeMqttPacket(MqttPacket &packet)
  60 +{
  61 + if (packet.getSize() > getWriteBufMaxWriteSize())
  62 + growWriteBuffer(packet.getSize());
  63 +
  64 + std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getSize());
  65 + wwi += packet.getSize();
  66 +}
  67 +
  68 +bool Client::writeBufIntoFd() // TODO: ignore the signal BROKEN PIPE we now also get when a client disappears.
  69 +{
  70 + int n;
  71 + while ((n = write(fd, &writebuf[wri], getWriteBufBytesUsed())) != 0)
  72 + {
  73 + if (n < 0)
  74 + {
  75 + if (errno == EINTR)
  76 + continue;
  77 + if (errno == EAGAIN || errno == EWOULDBLOCK)
  78 + break;
  79 + else
  80 + return false;
  81 + }
  82 +
  83 + wri += n;
  84 + }
  85 +
  86 + if (wri == wwi)
  87 + {
  88 + wri = 0;
  89 + wwi = 0;
  90 + }
  91 +
  92 + return true;
  93 +}
  94 +
  95 +std::string Client::repr()
  96 +{
  97 + std::ostringstream a;
  98 + a << "Client = " << clientid << ", user = " << username;
  99 + return a.str();
  100 +}
  101 +
55 102 bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn)
56 103 {
57   - while (getBufBytesUsed() >= MQTT_HEADER_LENGH)
  104 + while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH)
58 105 {
59 106 // Determine the packet length by decoding the variable length
60 107 size_t remaining_length_i = 1;
... ... @@ -63,7 +110,7 @@ bool Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn)
63 110 unsigned char encodedByte = 0;
64 111 do
65 112 {
66   - if (remaining_length_i >= getBufBytesUsed())
  113 + if (remaining_length_i >= getReadBufBytesUsed())
67 114 break;
68 115 encodedByte = readbuf[remaining_length_i++];
69 116 packet_length += (encodedByte & 127) * multiplier;
... ... @@ -79,7 +126,7 @@ bool Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn)
79 126 throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.");
80 127 }
81 128  
82   - if (packet_length <= getBufBytesUsed())
  129 + if (packet_length <= getReadBufBytesUsed())
83 130 {
84 131 MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this);
85 132 packetQueueIn.push_back(std::move(packet));
... ... @@ -120,3 +167,5 @@ void Client::setClientProperties(const std::string &amp;clientId, const std::string
120 167  
121 168  
122 169  
  170 +
  171 +
... ...
client.h
... ... @@ -22,10 +22,15 @@ class Client
22 22 int fd;
23 23  
24 24 char *readbuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around.
25   - size_t bufsize = CLIENT_BUFFER_SIZE;
  25 + size_t readBufsize = CLIENT_BUFFER_SIZE;
26 26 int wi = 0;
27 27 int ri = 0;
28 28  
  29 + char *writebuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around.
  30 + size_t writeBufsize = CLIENT_BUFFER_SIZE;
  31 + int wwi = 0;
  32 + int wri = 0;
  33 +
29 34 bool authenticated = false;
30 35 bool connectPacketSeen = false;
31 36 std::string clientid;
... ... @@ -34,22 +39,43 @@ class Client
34 39  
35 40 ThreadData_p threadData;
36 41  
37   - size_t getBufBytesUsed()
  42 + size_t getReadBufBytesUsed()
38 43 {
39 44 return wi - ri;
40 45 };
41 46  
42   - size_t getMaxWriteSize()
  47 + size_t getReadBufMaxWriteSize()
43 48 {
44   - size_t available = bufsize - getBufBytesUsed();
  49 + size_t available = readBufsize - getReadBufBytesUsed();
45 50 return available;
46 51 }
47 52  
48   - void growBuffer()
  53 + void growReadBuffer()
49 54 {
50   - const size_t newBufSize = bufsize * 2;
  55 + const size_t newBufSize = readBufsize * 2;
51 56 readbuf = (char*)realloc(readbuf, newBufSize);
52   - bufsize = newBufSize;
  57 + readBufsize = newBufSize;
  58 + }
  59 +
  60 + size_t getWriteBufMaxWriteSize()
  61 + {
  62 + size_t available = writeBufsize - getWriteBufBytesUsed();
  63 + return available;
  64 + }
  65 +
  66 + size_t getWriteBufBytesUsed()
  67 + {
  68 + return wwi - wri;
  69 + };
  70 +
  71 + void growWriteBuffer(size_t add_size)
  72 + {
  73 + if (add_size == 0)
  74 + return;
  75 +
  76 + const size_t newBufSize = writeBufsize + add_size;
  77 + writebuf = (char*)realloc(readbuf, newBufSize);
  78 + writeBufsize = newBufSize;
53 79 }
54 80  
55 81 public:
... ... @@ -64,6 +90,10 @@ public:
64 90 bool getAuthenticated() { return authenticated; }
65 91 bool hasConnectPacketSeen() { return connectPacketSeen; }
66 92  
  93 + void writeMqttPacket(MqttPacket &packet);
  94 + bool writeBufIntoFd();
  95 +
  96 + std::string repr();
67 97  
68 98 };
69 99  
... ...
main.cpp
... ... @@ -39,17 +39,25 @@ void do_thread_work(ThreadData *threadData)
39 39  
40 40 Client_p client = threadData->getClient(fd);
41 41  
42   - if (client) // TODO: is this check necessary?
  42 + if (client)
43 43 {
44   - if (cur_ev.events | EPOLLIN)
  44 + if (cur_ev.events & EPOLLIN)
45 45 {
46 46 if (!client->readFdIntoBuffer())
  47 + {
  48 + std::cout << "Disconnect: " << client->repr() << std::endl;
47 49 threadData->removeClient(client);
  50 + }
48 51 else
49 52 {
50 53 client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer.
51 54 }
52 55 }
  56 + if (cur_ev.events & EPOLLOUT)
  57 + {
  58 + if (!client->writeBufIntoFd())
  59 + threadData->removeClient(client);
  60 + }
53 61 }
54 62 }
55 63 }
... ... @@ -58,6 +66,7 @@ void do_thread_work(ThreadData *threadData)
58 66 {
59 67 packet.handle();
60 68 }
  69 + packetQueueIn.clear();
61 70 }
62 71 }
63 72  
... ...
mqttpacket.cpp
1 1 #include "mqttpacket.h"
2 2 #include <cstring>
  3 +#include <iostream>
3 4  
4 5 MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) :
5 6 bites(len),
6 7 fixed_header_length(fixed_header_length),
7 8 sender(sender)
8 9 {
9   - unsigned char _packetType = buf[0] >> 4;
  10 + unsigned char _packetType = (buf[0] & 0xF0) >> 4;
10 11 packetType = (PacketType)_packetType;
11 12 pos += fixed_header_length;
12 13  
13 14 std::memcpy(&bites[0], buf, len);
  15 +}
  16 +
  17 +MqttPacket::MqttPacket(const ConnAck &connAck) :
  18 + bites(4)
  19 +{
  20 + packetType = PacketType::CONNACK;
  21 + char first_byte = static_cast<char>(packetType) << 4;
  22 + writeByte(first_byte);
  23 + writeByte(2); // length is always 2.
  24 + writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet.
  25 + writeByte(static_cast<char>(connAck.return_code));
14 26  
15   - variable_header_length = readTwoBytesToUInt16();
16 27 }
17 28  
18 29 void MqttPacket::handle()
19 30 {
20 31 if (packetType == PacketType::CONNECT)
21 32 handleConnect();
  33 + else if (packetType == PacketType::PINGREQ)
  34 + std::cout << "PING" << std::endl;
  35 + else if (packetType == PacketType::SUBSCRIBE)
  36 + std::cout << "Sub" << std::endl;
22 37 }
23 38  
24 39 void MqttPacket::handleConnect()
... ... @@ -26,6 +41,7 @@ void MqttPacket::handleConnect()
26 41 if (sender->hasConnectPacketSeen())
27 42 throw ProtocolError("Client already sent a CONNECT.");
28 43  
  44 + uint16_t variable_header_length = readTwoBytesToUInt16();
29 45  
30 46 if (variable_header_length == 4 || variable_header_length == 6)
31 47 {
... ... @@ -87,6 +103,13 @@ void MqttPacket::handleConnect()
87 103 // TODO: validate UTF8 encoded username/password.
88 104  
89 105 sender->setClientProperties(client_id, username, true, keep_alive);
  106 +
  107 + std::cout << "Connect: " << sender->repr() << std::endl;
  108 +
  109 + ConnAck connAck(ConnAckReturnCodes::Accepted);
  110 + MqttPacket response(connAck);
  111 + sender->writeMqttPacket(response);
  112 + sender->writeBufIntoFd();
90 113 }
91 114 else
92 115 {
... ... @@ -113,6 +136,14 @@ char MqttPacket::readByte()
113 136 return b;
114 137 }
115 138  
  139 +void MqttPacket::writeByte(char b)
  140 +{
  141 + if (pos + 1 > bites.size())
  142 + throw ProtocolError("Exceeding packet size");
  143 +
  144 + bites[pos++] = b;
  145 +}
  146 +
116 147 uint16_t MqttPacket::readTwoBytesToUInt16()
117 148 {
118 149 if (pos + 2 > bites.size())
... ...
mqttpacket.h
... ... @@ -16,20 +16,26 @@ class Client;
16 16 class MqttPacket
17 17 {
18 18 std::vector<char> bites;
19   - const size_t fixed_header_length;
20   - uint16_t variable_header_length;
  19 + size_t fixed_header_length = 0;
21 20 Client *sender;
22 21 size_t pos = 0;
23 22 ProtocolVersion protocolVersion = ProtocolVersion::None;
  23 +
  24 + char *readBytes(size_t length);
  25 + char readByte();
  26 + void writeByte(char b);
  27 + uint16_t readTwoBytesToUInt16();
  28 +
24 29 public:
25 30 PacketType packetType = PacketType::Reserved;
26 31 MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender);
  32 + MqttPacket(const ConnAck &connAck);
27 33  
28 34 void handle();
29 35 void handleConnect();
30   - char *readBytes(size_t length);
31   - char readByte();
32   - uint16_t readTwoBytesToUInt16();
  36 + size_t getSize() { return bites.size(); }
  37 + const std::vector<char> &getBites() { return bites; }
  38 +
33 39 };
34 40  
35 41 #endif // MQTTPACKET_H
... ...
types.cpp
1 1 #include "types.h"
2 2  
  3 +ConnAck::ConnAck(ConnAckReturnCodes return_code) :
  4 + return_code(return_code)
  5 +{
  6 +
  7 +}
... ...
... ... @@ -6,6 +6,18 @@ enum class PacketType
6 6 Reserved = 0,
7 7 CONNECT = 1,
8 8 CONNACK = 2,
  9 + PUBLISH = 3,
  10 + PUBACK = 4,
  11 + PUBREC = 5,
  12 + PUBREL = 6,
  13 + PUBCOMP = 7,
  14 + SUBSCRIBE = 8,
  15 + SUBACK = 9,
  16 + UNSUBSCRIBE = 10,
  17 + UNSUBACK = 11,
  18 + PINGREQ = 12,
  19 + PINGRESP = 13,
  20 + DISCONNECT = 14,
9 21  
10 22 Reserved2 = 15
11 23 };
... ... @@ -17,4 +29,21 @@ enum class ProtocolVersion
17 29 Mqtt311 = 0x04
18 30 };
19 31  
  32 +enum class ConnAckReturnCodes
  33 +{
  34 + Accepted = 0,
  35 + UnacceptableProtocolVersion = 1,
  36 + ClientIdRejected = 2,
  37 + ServerUnavailable = 3,
  38 + MalformedUsernameOrPassword = 4,
  39 + NotAuthorized = 5
  40 +};
  41 +
  42 +class ConnAck
  43 +{
  44 +public:
  45 + ConnAck(ConnAckReturnCodes return_code);
  46 + ConnAckReturnCodes return_code;
  47 +};
  48 +
20 49 #endif // TYPES_H
... ...