Commit 5dcad43303fd02ed7c68b16d90b29717b119fb7f

Authored by Wiebe Cazemier
1 parent d61c593d

Write response

client.cpp
1 #include "client.h" 1 #include "client.h"
2 2
  3 +#include <cstring>
  4 +#include <sstream>
  5 +
3 Client::Client(int fd, ThreadData_p threadData) : 6 Client::Client(int fd, ThreadData_p threadData) :
4 fd(fd), 7 fd(fd),
5 threadData(threadData) 8 threadData(threadData)
@@ -7,6 +10,7 @@ Client::Client(int fd, ThreadData_p threadData) : @@ -7,6 +10,7 @@ Client::Client(int fd, ThreadData_p threadData) :
7 int flags = fcntl(fd, F_GETFL); 10 int flags = fcntl(fd, F_GETFL);
8 fcntl(fd, F_SETFL, flags | O_NONBLOCK); 11 fcntl(fd, F_SETFL, flags | O_NONBLOCK);
9 readbuf = (char*)malloc(CLIENT_BUFFER_SIZE); 12 readbuf = (char*)malloc(CLIENT_BUFFER_SIZE);
  13 + writebuf = (char*)malloc(CLIENT_BUFFER_SIZE);
10 } 14 }
11 15
12 Client::~Client() 16 Client::~Client()
@@ -19,7 +23,7 @@ Client::~Client() @@ -19,7 +23,7 @@ Client::~Client()
19 // false means any kind of error we want to get rid of the client for. 23 // false means any kind of error we want to get rid of the client for.
20 bool Client::readFdIntoBuffer() 24 bool Client::readFdIntoBuffer()
21 { 25 {
22 - int read_size = getMaxWriteSize(); 26 + int read_size = getReadBufMaxWriteSize();
23 27
24 int n; 28 int n;
25 while ((n = read(fd, &readbuf[wi], read_size)) != 0) 29 while ((n = read(fd, &readbuf[wi], read_size)) != 0)
@@ -36,12 +40,12 @@ bool Client::readFdIntoBuffer() @@ -36,12 +40,12 @@ bool Client::readFdIntoBuffer()
36 40
37 wi += n; 41 wi += n;
38 42
39 - if (getBufBytesUsed() >= bufsize) 43 + if (getReadBufBytesUsed() >= readBufsize)
40 { 44 {
41 - growBuffer(); 45 + growReadBuffer();
42 } 46 }
43 47
44 - read_size = getMaxWriteSize(); 48 + read_size = getReadBufMaxWriteSize();
45 } 49 }
46 50
47 if (n == 0) // client disconnected. 51 if (n == 0) // client disconnected.
@@ -52,9 +56,52 @@ bool Client::readFdIntoBuffer() @@ -52,9 +56,52 @@ bool Client::readFdIntoBuffer()
52 return true; 56 return true;
53 } 57 }
54 58
  59 +void Client::writeMqttPacket(MqttPacket &packet)
  60 +{
  61 + if (packet.getSize() > getWriteBufMaxWriteSize())
  62 + growWriteBuffer(packet.getSize());
  63 +
  64 + std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getSize());
  65 + wwi += packet.getSize();
  66 +}
  67 +
  68 +bool Client::writeBufIntoFd() // TODO: ignore the signal BROKEN PIPE we now also get when a client disappears.
  69 +{
  70 + int n;
  71 + while ((n = write(fd, &writebuf[wri], getWriteBufBytesUsed())) != 0)
  72 + {
  73 + if (n < 0)
  74 + {
  75 + if (errno == EINTR)
  76 + continue;
  77 + if (errno == EAGAIN || errno == EWOULDBLOCK)
  78 + break;
  79 + else
  80 + return false;
  81 + }
  82 +
  83 + wri += n;
  84 + }
  85 +
  86 + if (wri == wwi)
  87 + {
  88 + wri = 0;
  89 + wwi = 0;
  90 + }
  91 +
  92 + return true;
  93 +}
  94 +
  95 +std::string Client::repr()
  96 +{
  97 + std::ostringstream a;
  98 + a << "Client = " << clientid << ", user = " << username;
  99 + return a.str();
  100 +}
  101 +
55 bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn) 102 bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn)
56 { 103 {
57 - while (getBufBytesUsed() >= MQTT_HEADER_LENGH) 104 + while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH)
58 { 105 {
59 // Determine the packet length by decoding the variable length 106 // Determine the packet length by decoding the variable length
60 size_t remaining_length_i = 1; 107 size_t remaining_length_i = 1;
@@ -63,7 +110,7 @@ bool Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn) @@ -63,7 +110,7 @@ bool Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn)
63 unsigned char encodedByte = 0; 110 unsigned char encodedByte = 0;
64 do 111 do
65 { 112 {
66 - if (remaining_length_i >= getBufBytesUsed()) 113 + if (remaining_length_i >= getReadBufBytesUsed())
67 break; 114 break;
68 encodedByte = readbuf[remaining_length_i++]; 115 encodedByte = readbuf[remaining_length_i++];
69 packet_length += (encodedByte & 127) * multiplier; 116 packet_length += (encodedByte & 127) * multiplier;
@@ -79,7 +126,7 @@ bool Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn) @@ -79,7 +126,7 @@ bool Client::bufferToMqttPackets(std::vector&lt;MqttPacket&gt; &amp;packetQueueIn)
79 throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes."); 126 throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.");
80 } 127 }
81 128
82 - if (packet_length <= getBufBytesUsed()) 129 + if (packet_length <= getReadBufBytesUsed())
83 { 130 {
84 MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this); 131 MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this);
85 packetQueueIn.push_back(std::move(packet)); 132 packetQueueIn.push_back(std::move(packet));
@@ -120,3 +167,5 @@ void Client::setClientProperties(const std::string &amp;clientId, const std::string @@ -120,3 +167,5 @@ void Client::setClientProperties(const std::string &amp;clientId, const std::string
120 167
121 168
122 169
  170 +
  171 +
client.h
@@ -22,10 +22,15 @@ class Client @@ -22,10 +22,15 @@ class Client
22 int fd; 22 int fd;
23 23
24 char *readbuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around. 24 char *readbuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around.
25 - size_t bufsize = CLIENT_BUFFER_SIZE; 25 + size_t readBufsize = CLIENT_BUFFER_SIZE;
26 int wi = 0; 26 int wi = 0;
27 int ri = 0; 27 int ri = 0;
28 28
  29 + char *writebuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around.
  30 + size_t writeBufsize = CLIENT_BUFFER_SIZE;
  31 + int wwi = 0;
  32 + int wri = 0;
  33 +
29 bool authenticated = false; 34 bool authenticated = false;
30 bool connectPacketSeen = false; 35 bool connectPacketSeen = false;
31 std::string clientid; 36 std::string clientid;
@@ -34,22 +39,43 @@ class Client @@ -34,22 +39,43 @@ class Client
34 39
35 ThreadData_p threadData; 40 ThreadData_p threadData;
36 41
37 - size_t getBufBytesUsed() 42 + size_t getReadBufBytesUsed()
38 { 43 {
39 return wi - ri; 44 return wi - ri;
40 }; 45 };
41 46
42 - size_t getMaxWriteSize() 47 + size_t getReadBufMaxWriteSize()
43 { 48 {
44 - size_t available = bufsize - getBufBytesUsed(); 49 + size_t available = readBufsize - getReadBufBytesUsed();
45 return available; 50 return available;
46 } 51 }
47 52
48 - void growBuffer() 53 + void growReadBuffer()
49 { 54 {
50 - const size_t newBufSize = bufsize * 2; 55 + const size_t newBufSize = readBufsize * 2;
51 readbuf = (char*)realloc(readbuf, newBufSize); 56 readbuf = (char*)realloc(readbuf, newBufSize);
52 - bufsize = newBufSize; 57 + readBufsize = newBufSize;
  58 + }
  59 +
  60 + size_t getWriteBufMaxWriteSize()
  61 + {
  62 + size_t available = writeBufsize - getWriteBufBytesUsed();
  63 + return available;
  64 + }
  65 +
  66 + size_t getWriteBufBytesUsed()
  67 + {
  68 + return wwi - wri;
  69 + };
  70 +
  71 + void growWriteBuffer(size_t add_size)
  72 + {
  73 + if (add_size == 0)
  74 + return;
  75 +
  76 + const size_t newBufSize = writeBufsize + add_size;
  77 + writebuf = (char*)realloc(readbuf, newBufSize);
  78 + writeBufsize = newBufSize;
53 } 79 }
54 80
55 public: 81 public:
@@ -64,6 +90,10 @@ public: @@ -64,6 +90,10 @@ public:
64 bool getAuthenticated() { return authenticated; } 90 bool getAuthenticated() { return authenticated; }
65 bool hasConnectPacketSeen() { return connectPacketSeen; } 91 bool hasConnectPacketSeen() { return connectPacketSeen; }
66 92
  93 + void writeMqttPacket(MqttPacket &packet);
  94 + bool writeBufIntoFd();
  95 +
  96 + std::string repr();
67 97
68 }; 98 };
69 99
main.cpp
@@ -39,17 +39,25 @@ void do_thread_work(ThreadData *threadData) @@ -39,17 +39,25 @@ void do_thread_work(ThreadData *threadData)
39 39
40 Client_p client = threadData->getClient(fd); 40 Client_p client = threadData->getClient(fd);
41 41
42 - if (client) // TODO: is this check necessary? 42 + if (client)
43 { 43 {
44 - if (cur_ev.events | EPOLLIN) 44 + if (cur_ev.events & EPOLLIN)
45 { 45 {
46 if (!client->readFdIntoBuffer()) 46 if (!client->readFdIntoBuffer())
  47 + {
  48 + std::cout << "Disconnect: " << client->repr() << std::endl;
47 threadData->removeClient(client); 49 threadData->removeClient(client);
  50 + }
48 else 51 else
49 { 52 {
50 client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer. 53 client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer.
51 } 54 }
52 } 55 }
  56 + if (cur_ev.events & EPOLLOUT)
  57 + {
  58 + if (!client->writeBufIntoFd())
  59 + threadData->removeClient(client);
  60 + }
53 } 61 }
54 } 62 }
55 } 63 }
@@ -58,6 +66,7 @@ void do_thread_work(ThreadData *threadData) @@ -58,6 +66,7 @@ void do_thread_work(ThreadData *threadData)
58 { 66 {
59 packet.handle(); 67 packet.handle();
60 } 68 }
  69 + packetQueueIn.clear();
61 } 70 }
62 } 71 }
63 72
mqttpacket.cpp
1 #include "mqttpacket.h" 1 #include "mqttpacket.h"
2 #include <cstring> 2 #include <cstring>
  3 +#include <iostream>
3 4
4 MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) : 5 MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) :
5 bites(len), 6 bites(len),
6 fixed_header_length(fixed_header_length), 7 fixed_header_length(fixed_header_length),
7 sender(sender) 8 sender(sender)
8 { 9 {
9 - unsigned char _packetType = buf[0] >> 4; 10 + unsigned char _packetType = (buf[0] & 0xF0) >> 4;
10 packetType = (PacketType)_packetType; 11 packetType = (PacketType)_packetType;
11 pos += fixed_header_length; 12 pos += fixed_header_length;
12 13
13 std::memcpy(&bites[0], buf, len); 14 std::memcpy(&bites[0], buf, len);
  15 +}
  16 +
  17 +MqttPacket::MqttPacket(const ConnAck &connAck) :
  18 + bites(4)
  19 +{
  20 + packetType = PacketType::CONNACK;
  21 + char first_byte = static_cast<char>(packetType) << 4;
  22 + writeByte(first_byte);
  23 + writeByte(2); // length is always 2.
  24 + writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet.
  25 + writeByte(static_cast<char>(connAck.return_code));
14 26
15 - variable_header_length = readTwoBytesToUInt16();  
16 } 27 }
17 28
18 void MqttPacket::handle() 29 void MqttPacket::handle()
19 { 30 {
20 if (packetType == PacketType::CONNECT) 31 if (packetType == PacketType::CONNECT)
21 handleConnect(); 32 handleConnect();
  33 + else if (packetType == PacketType::PINGREQ)
  34 + std::cout << "PING" << std::endl;
  35 + else if (packetType == PacketType::SUBSCRIBE)
  36 + std::cout << "Sub" << std::endl;
22 } 37 }
23 38
24 void MqttPacket::handleConnect() 39 void MqttPacket::handleConnect()
@@ -26,6 +41,7 @@ void MqttPacket::handleConnect() @@ -26,6 +41,7 @@ void MqttPacket::handleConnect()
26 if (sender->hasConnectPacketSeen()) 41 if (sender->hasConnectPacketSeen())
27 throw ProtocolError("Client already sent a CONNECT."); 42 throw ProtocolError("Client already sent a CONNECT.");
28 43
  44 + uint16_t variable_header_length = readTwoBytesToUInt16();
29 45
30 if (variable_header_length == 4 || variable_header_length == 6) 46 if (variable_header_length == 4 || variable_header_length == 6)
31 { 47 {
@@ -87,6 +103,13 @@ void MqttPacket::handleConnect() @@ -87,6 +103,13 @@ void MqttPacket::handleConnect()
87 // TODO: validate UTF8 encoded username/password. 103 // TODO: validate UTF8 encoded username/password.
88 104
89 sender->setClientProperties(client_id, username, true, keep_alive); 105 sender->setClientProperties(client_id, username, true, keep_alive);
  106 +
  107 + std::cout << "Connect: " << sender->repr() << std::endl;
  108 +
  109 + ConnAck connAck(ConnAckReturnCodes::Accepted);
  110 + MqttPacket response(connAck);
  111 + sender->writeMqttPacket(response);
  112 + sender->writeBufIntoFd();
90 } 113 }
91 else 114 else
92 { 115 {
@@ -113,6 +136,14 @@ char MqttPacket::readByte() @@ -113,6 +136,14 @@ char MqttPacket::readByte()
113 return b; 136 return b;
114 } 137 }
115 138
  139 +void MqttPacket::writeByte(char b)
  140 +{
  141 + if (pos + 1 > bites.size())
  142 + throw ProtocolError("Exceeding packet size");
  143 +
  144 + bites[pos++] = b;
  145 +}
  146 +
116 uint16_t MqttPacket::readTwoBytesToUInt16() 147 uint16_t MqttPacket::readTwoBytesToUInt16()
117 { 148 {
118 if (pos + 2 > bites.size()) 149 if (pos + 2 > bites.size())
mqttpacket.h
@@ -16,20 +16,26 @@ class Client; @@ -16,20 +16,26 @@ class Client;
16 class MqttPacket 16 class MqttPacket
17 { 17 {
18 std::vector<char> bites; 18 std::vector<char> bites;
19 - const size_t fixed_header_length;  
20 - uint16_t variable_header_length; 19 + size_t fixed_header_length = 0;
21 Client *sender; 20 Client *sender;
22 size_t pos = 0; 21 size_t pos = 0;
23 ProtocolVersion protocolVersion = ProtocolVersion::None; 22 ProtocolVersion protocolVersion = ProtocolVersion::None;
  23 +
  24 + char *readBytes(size_t length);
  25 + char readByte();
  26 + void writeByte(char b);
  27 + uint16_t readTwoBytesToUInt16();
  28 +
24 public: 29 public:
25 PacketType packetType = PacketType::Reserved; 30 PacketType packetType = PacketType::Reserved;
26 MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender); 31 MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender);
  32 + MqttPacket(const ConnAck &connAck);
27 33
28 void handle(); 34 void handle();
29 void handleConnect(); 35 void handleConnect();
30 - char *readBytes(size_t length);  
31 - char readByte();  
32 - uint16_t readTwoBytesToUInt16(); 36 + size_t getSize() { return bites.size(); }
  37 + const std::vector<char> &getBites() { return bites; }
  38 +
33 }; 39 };
34 40
35 #endif // MQTTPACKET_H 41 #endif // MQTTPACKET_H
types.cpp
1 #include "types.h" 1 #include "types.h"
2 2
  3 +ConnAck::ConnAck(ConnAckReturnCodes return_code) :
  4 + return_code(return_code)
  5 +{
  6 +
  7 +}
@@ -6,6 +6,18 @@ enum class PacketType @@ -6,6 +6,18 @@ enum class PacketType
6 Reserved = 0, 6 Reserved = 0,
7 CONNECT = 1, 7 CONNECT = 1,
8 CONNACK = 2, 8 CONNACK = 2,
  9 + PUBLISH = 3,
  10 + PUBACK = 4,
  11 + PUBREC = 5,
  12 + PUBREL = 6,
  13 + PUBCOMP = 7,
  14 + SUBSCRIBE = 8,
  15 + SUBACK = 9,
  16 + UNSUBSCRIBE = 10,
  17 + UNSUBACK = 11,
  18 + PINGREQ = 12,
  19 + PINGRESP = 13,
  20 + DISCONNECT = 14,
9 21
10 Reserved2 = 15 22 Reserved2 = 15
11 }; 23 };
@@ -17,4 +29,21 @@ enum class ProtocolVersion @@ -17,4 +29,21 @@ enum class ProtocolVersion
17 Mqtt311 = 0x04 29 Mqtt311 = 0x04
18 }; 30 };
19 31
  32 +enum class ConnAckReturnCodes
  33 +{
  34 + Accepted = 0,
  35 + UnacceptableProtocolVersion = 1,
  36 + ClientIdRejected = 2,
  37 + ServerUnavailable = 3,
  38 + MalformedUsernameOrPassword = 4,
  39 + NotAuthorized = 5
  40 +};
  41 +
  42 +class ConnAck
  43 +{
  44 +public:
  45 + ConnAck(ConnAckReturnCodes return_code);
  46 + ConnAckReturnCodes return_code;
  47 +};
  48 +
20 #endif // TYPES_H 49 #endif // TYPES_H