diff --git a/client.cpp b/client.cpp index 5ba4a10..e644ec3 100644 --- a/client.cpp +++ b/client.cpp @@ -1,5 +1,8 @@ #include "client.h" +#include +#include + Client::Client(int fd, ThreadData_p threadData) : fd(fd), threadData(threadData) @@ -7,6 +10,7 @@ Client::Client(int fd, ThreadData_p threadData) : int flags = fcntl(fd, F_GETFL); fcntl(fd, F_SETFL, flags | O_NONBLOCK); readbuf = (char*)malloc(CLIENT_BUFFER_SIZE); + writebuf = (char*)malloc(CLIENT_BUFFER_SIZE); } Client::~Client() @@ -19,7 +23,7 @@ Client::~Client() // false means any kind of error we want to get rid of the client for. bool Client::readFdIntoBuffer() { - int read_size = getMaxWriteSize(); + int read_size = getReadBufMaxWriteSize(); int n; while ((n = read(fd, &readbuf[wi], read_size)) != 0) @@ -36,12 +40,12 @@ bool Client::readFdIntoBuffer() wi += n; - if (getBufBytesUsed() >= bufsize) + if (getReadBufBytesUsed() >= readBufsize) { - growBuffer(); + growReadBuffer(); } - read_size = getMaxWriteSize(); + read_size = getReadBufMaxWriteSize(); } if (n == 0) // client disconnected. @@ -52,9 +56,52 @@ bool Client::readFdIntoBuffer() return true; } +void Client::writeMqttPacket(MqttPacket &packet) +{ + if (packet.getSize() > getWriteBufMaxWriteSize()) + growWriteBuffer(packet.getSize()); + + std::memcpy(&writebuf[wwi], &packet.getBites()[0], packet.getSize()); + wwi += packet.getSize(); +} + +bool Client::writeBufIntoFd() // TODO: ignore the signal BROKEN PIPE we now also get when a client disappears. +{ + int n; + while ((n = write(fd, &writebuf[wri], getWriteBufBytesUsed())) != 0) + { + if (n < 0) + { + if (errno == EINTR) + continue; + if (errno == EAGAIN || errno == EWOULDBLOCK) + break; + else + return false; + } + + wri += n; + } + + if (wri == wwi) + { + wri = 0; + wwi = 0; + } + + return true; +} + +std::string Client::repr() +{ + std::ostringstream a; + a << "Client = " << clientid << ", user = " << username; + return a.str(); +} + bool Client::bufferToMqttPackets(std::vector &packetQueueIn) { - while (getBufBytesUsed() >= MQTT_HEADER_LENGH) + while (getReadBufBytesUsed() >= MQTT_HEADER_LENGH) { // Determine the packet length by decoding the variable length size_t remaining_length_i = 1; @@ -63,7 +110,7 @@ bool Client::bufferToMqttPackets(std::vector &packetQueueIn) unsigned char encodedByte = 0; do { - if (remaining_length_i >= getBufBytesUsed()) + if (remaining_length_i >= getReadBufBytesUsed()) break; encodedByte = readbuf[remaining_length_i++]; packet_length += (encodedByte & 127) * multiplier; @@ -79,7 +126,7 @@ bool Client::bufferToMqttPackets(std::vector &packetQueueIn) throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes."); } - if (packet_length <= getBufBytesUsed()) + if (packet_length <= getReadBufBytesUsed()) { MqttPacket packet(&readbuf[ri], packet_length, remaining_length_i, this); packetQueueIn.push_back(std::move(packet)); @@ -120,3 +167,5 @@ void Client::setClientProperties(const std::string &clientId, const std::string + + diff --git a/client.h b/client.h index 59f9d0e..c6381ea 100644 --- a/client.h +++ b/client.h @@ -22,10 +22,15 @@ class Client int fd; char *readbuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around. - size_t bufsize = CLIENT_BUFFER_SIZE; + size_t readBufsize = CLIENT_BUFFER_SIZE; int wi = 0; int ri = 0; + char *writebuf = NULL; // With many clients, it may not be smart to keep a (big) buffer around. + size_t writeBufsize = CLIENT_BUFFER_SIZE; + int wwi = 0; + int wri = 0; + bool authenticated = false; bool connectPacketSeen = false; std::string clientid; @@ -34,22 +39,43 @@ class Client ThreadData_p threadData; - size_t getBufBytesUsed() + size_t getReadBufBytesUsed() { return wi - ri; }; - size_t getMaxWriteSize() + size_t getReadBufMaxWriteSize() { - size_t available = bufsize - getBufBytesUsed(); + size_t available = readBufsize - getReadBufBytesUsed(); return available; } - void growBuffer() + void growReadBuffer() { - const size_t newBufSize = bufsize * 2; + const size_t newBufSize = readBufsize * 2; readbuf = (char*)realloc(readbuf, newBufSize); - bufsize = newBufSize; + readBufsize = newBufSize; + } + + size_t getWriteBufMaxWriteSize() + { + size_t available = writeBufsize - getWriteBufBytesUsed(); + return available; + } + + size_t getWriteBufBytesUsed() + { + return wwi - wri; + }; + + void growWriteBuffer(size_t add_size) + { + if (add_size == 0) + return; + + const size_t newBufSize = writeBufsize + add_size; + writebuf = (char*)realloc(readbuf, newBufSize); + writeBufsize = newBufSize; } public: @@ -64,6 +90,10 @@ public: bool getAuthenticated() { return authenticated; } bool hasConnectPacketSeen() { return connectPacketSeen; } + void writeMqttPacket(MqttPacket &packet); + bool writeBufIntoFd(); + + std::string repr(); }; diff --git a/main.cpp b/main.cpp index 14632bd..323fc0e 100644 --- a/main.cpp +++ b/main.cpp @@ -39,17 +39,25 @@ void do_thread_work(ThreadData *threadData) Client_p client = threadData->getClient(fd); - if (client) // TODO: is this check necessary? + if (client) { - if (cur_ev.events | EPOLLIN) + if (cur_ev.events & EPOLLIN) { if (!client->readFdIntoBuffer()) + { + std::cout << "Disconnect: " << client->repr() << std::endl; threadData->removeClient(client); + } else { client->bufferToMqttPackets(packetQueueIn); // TODO: different, because now I need to give the packet a raw pointer. } } + if (cur_ev.events & EPOLLOUT) + { + if (!client->writeBufIntoFd()) + threadData->removeClient(client); + } } } } @@ -58,6 +66,7 @@ void do_thread_work(ThreadData *threadData) { packet.handle(); } + packetQueueIn.clear(); } } diff --git a/mqttpacket.cpp b/mqttpacket.cpp index c717e15..50c9eb5 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -1,24 +1,39 @@ #include "mqttpacket.h" #include +#include MqttPacket::MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender) : bites(len), fixed_header_length(fixed_header_length), sender(sender) { - unsigned char _packetType = buf[0] >> 4; + unsigned char _packetType = (buf[0] & 0xF0) >> 4; packetType = (PacketType)_packetType; pos += fixed_header_length; std::memcpy(&bites[0], buf, len); +} + +MqttPacket::MqttPacket(const ConnAck &connAck) : + bites(4) +{ + packetType = PacketType::CONNACK; + char first_byte = static_cast(packetType) << 4; + writeByte(first_byte); + writeByte(2); // length is always 2. + writeByte(0); // all connect-ack flags are 0, except session-present, but we don't have that yet. + writeByte(static_cast(connAck.return_code)); - variable_header_length = readTwoBytesToUInt16(); } void MqttPacket::handle() { if (packetType == PacketType::CONNECT) handleConnect(); + else if (packetType == PacketType::PINGREQ) + std::cout << "PING" << std::endl; + else if (packetType == PacketType::SUBSCRIBE) + std::cout << "Sub" << std::endl; } void MqttPacket::handleConnect() @@ -26,6 +41,7 @@ void MqttPacket::handleConnect() if (sender->hasConnectPacketSeen()) throw ProtocolError("Client already sent a CONNECT."); + uint16_t variable_header_length = readTwoBytesToUInt16(); if (variable_header_length == 4 || variable_header_length == 6) { @@ -87,6 +103,13 @@ void MqttPacket::handleConnect() // TODO: validate UTF8 encoded username/password. sender->setClientProperties(client_id, username, true, keep_alive); + + std::cout << "Connect: " << sender->repr() << std::endl; + + ConnAck connAck(ConnAckReturnCodes::Accepted); + MqttPacket response(connAck); + sender->writeMqttPacket(response); + sender->writeBufIntoFd(); } else { @@ -113,6 +136,14 @@ char MqttPacket::readByte() return b; } +void MqttPacket::writeByte(char b) +{ + if (pos + 1 > bites.size()) + throw ProtocolError("Exceeding packet size"); + + bites[pos++] = b; +} + uint16_t MqttPacket::readTwoBytesToUInt16() { if (pos + 2 > bites.size()) diff --git a/mqttpacket.h b/mqttpacket.h index 1b1ab41..39fea97 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -16,20 +16,26 @@ class Client; class MqttPacket { std::vector bites; - const size_t fixed_header_length; - uint16_t variable_header_length; + size_t fixed_header_length = 0; Client *sender; size_t pos = 0; ProtocolVersion protocolVersion = ProtocolVersion::None; + + char *readBytes(size_t length); + char readByte(); + void writeByte(char b); + uint16_t readTwoBytesToUInt16(); + public: PacketType packetType = PacketType::Reserved; MqttPacket(char *buf, size_t len, size_t fixed_header_length, Client *sender); + MqttPacket(const ConnAck &connAck); void handle(); void handleConnect(); - char *readBytes(size_t length); - char readByte(); - uint16_t readTwoBytesToUInt16(); + size_t getSize() { return bites.size(); } + const std::vector &getBites() { return bites; } + }; #endif // MQTTPACKET_H diff --git a/types.cpp b/types.cpp index 868cc2f..e71cd93 100644 --- a/types.cpp +++ b/types.cpp @@ -1,2 +1,7 @@ #include "types.h" +ConnAck::ConnAck(ConnAckReturnCodes return_code) : + return_code(return_code) +{ + +} diff --git a/types.h b/types.h index 57598cb..cfc854e 100644 --- a/types.h +++ b/types.h @@ -6,6 +6,18 @@ enum class PacketType Reserved = 0, CONNECT = 1, CONNACK = 2, + PUBLISH = 3, + PUBACK = 4, + PUBREC = 5, + PUBREL = 6, + PUBCOMP = 7, + SUBSCRIBE = 8, + SUBACK = 9, + UNSUBSCRIBE = 10, + UNSUBACK = 11, + PINGREQ = 12, + PINGRESP = 13, + DISCONNECT = 14, Reserved2 = 15 }; @@ -17,4 +29,21 @@ enum class ProtocolVersion Mqtt311 = 0x04 }; +enum class ConnAckReturnCodes +{ + Accepted = 0, + UnacceptableProtocolVersion = 1, + ClientIdRejected = 2, + ServerUnavailable = 3, + MalformedUsernameOrPassword = 4, + NotAuthorized = 5 +}; + +class ConnAck +{ +public: + ConnAck(ConnAckReturnCodes return_code); + ConnAckReturnCodes return_code; +}; + #endif // TYPES_H