diff --git a/mqttpacket.cpp b/mqttpacket.cpp index 393f9e7..d5c85be 100644 --- a/mqttpacket.cpp +++ b/mqttpacket.cpp @@ -128,8 +128,7 @@ MqttPacket::MqttPacket(const SubAck &subAck) : { packetType = PacketType::SUBACK; first_byte = static_cast(packetType) << 4; - writeByte((subAck.packet_id & 0xFF00) >> 8); - writeByte(subAck.packet_id & 0x00FF); + writeUint16(subAck.packet_id); std::vector returnList; for (SubAckReturnCodes code : subAck.responses) @@ -146,8 +145,7 @@ MqttPacket::MqttPacket(const UnsubAck &unsubAck) : { packetType = PacketType::SUBACK; first_byte = static_cast(packetType) << 4; - writeByte((unsubAck.packet_id & 0xFF00) >> 8); - writeByte(unsubAck.packet_id & 0x00FF); + writeUint16(unsubAck.packet_id); calculateRemainingLength(); } @@ -168,10 +166,7 @@ MqttPacket::MqttPacket(const Publish &publish) : first_byte |= (publish.qos << 1); first_byte |= (static_cast(publish.retain) & 0b00000001); - char topicLenMSB = (topic.length() & 0xFF00) >> 8; - char topicLenLSB = topic.length() & 0x00FF; - writeByte(topicLenMSB); - writeByte(topicLenLSB); + writeUint16(topic.length()); writeBytes(topic.c_str(), topic.length()); if (publish.qos) @@ -204,11 +199,8 @@ void MqttPacket::pubCommonConstruct(const uint16_t packet_id, PacketType packetT first_byte = (static_cast(packetType) << 4) | firstByteDefaultBits; writeByte(first_byte); writeByte(2); // length is always 2. - uint8_t packetIdMSB = (packet_id & 0xFF00) >> 8; - uint8_t packetIdLSB = (packet_id & 0x00FF); packet_id_pos = pos; - writeByte(packetIdMSB); - writeByte(packetIdLSB); + writeUint16(packet_id); } MqttPacket::MqttPacket(const PubAck &pubAck) : @@ -766,13 +758,8 @@ void MqttPacket::setPacketId(uint16_t packet_id) assert(qos > 0); this->packet_id = packet_id; - pos = packet_id_pos; - - char topicLenMSB = (packet_id & 0xFF00) >> 8; - char topicLenLSB = (packet_id & 0x00FF); - writeByte(topicLenMSB); - writeByte(topicLenLSB); + writeUint16(packet_id); } uint16_t MqttPacket::getPacketId() const @@ -901,6 +888,18 @@ void MqttPacket::writeByte(char b) bites[pos++] = b; } +void MqttPacket::writeUint16(uint16_t x) +{ + if (pos + 2 > bites.size()) + throw ProtocolError("Exceeding packet size"); + + const uint8_t a = static_cast(x >> 8); + const uint8_t b = static_cast(x); + + bites[pos++] = a; + bites[pos++] = b; +} + void MqttPacket::writeBytes(const char *b, size_t len) { if (pos + len > bites.size()) diff --git a/mqttpacket.h b/mqttpacket.h index 3c89e2c..396eee3 100644 --- a/mqttpacket.h +++ b/mqttpacket.h @@ -60,6 +60,7 @@ class MqttPacket char *readBytes(size_t length); char readByte(); void writeByte(char b); + void writeUint16(uint16_t x); void writeBytes(const char *b, size_t len); uint16_t readTwoBytesToUInt16(); size_t remainingAfterPos();