Commit 5614ee5cee70252c329d870cf33ece5820b49f88

Authored by Wiebe Cazemier
1 parent 3893af5d

Move some parsing logic to the MqttPacket class

This makes more sense, logically, and also helps in tests I'm about to
write.
client.cpp
... ... @@ -195,17 +195,7 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos)
195 195 return 0;
196 196 }
197 197  
198   - writebuf.ensureFreeSpace(packet.getSizeIncludingNonPresentHeader());
199   -
200   - if (!packet.containsFixedHeader())
201   - {
202   - writebuf.headPtr()[0] = packet.getFirstByte();
203   - writebuf.advanceHead(1);
204   - RemainingLength r = packet.getRemainingLength();
205   - writebuf.write(r.bytes, r.len);
206   - }
207   -
208   - writebuf.write(packet.getBites().data(), packet.getBites().size());
  198 + packet.readIntoBuf(writebuf);
209 199  
210 200 if (packet.packetType == PacketType::DISCONNECT)
211 201 setReadyForDisconnect();
... ... @@ -396,57 +386,10 @@ void Client::setReadyForReading(bool val)
396 386 }
397 387 }
398 388  
399   -bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
  389 +void Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
400 390 {
401   - while (readbuf.usedBytes() >= MQTT_HEADER_LENGH)
402   - {
403   - // Determine the packet length by decoding the variable length
404   - int remaining_length_i = 1; // index of 'remaining length' field is one after start.
405   - uint fixed_header_length = 1;
406   - size_t multiplier = 1;
407   - size_t packet_length = 0;
408   - unsigned char encodedByte = 0;
409   - do
410   - {
411   - fixed_header_length++;
412   -
413   - if (fixed_header_length > 5)
414   - throw ProtocolError("Packet signifies more than 5 bytes in variable length header. Invalid.");
415   -
416   - // This happens when you only don't have all the bytes that specify the remaining length.
417   - if (fixed_header_length > readbuf.usedBytes())
418   - return false;
419   -
420   - encodedByte = readbuf.peakAhead(remaining_length_i++);
421   - packet_length += (encodedByte & 127) * multiplier;
422   - multiplier *= 128;
423   - if (multiplier > 128*128*128*128)
424   - throw ProtocolError("Malformed Remaining Length.");
425   - }
426   - while ((encodedByte & 128) != 0);
427   - packet_length += fixed_header_length;
428   -
429   - if (!authenticated && packet_length >= 1024*1024)
430   - {
431   - throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.");
432   - }
433   -
434   - if (packet_length > ABSOLUTE_MAX_PACKET_SIZE)
435   - {
436   - throw ProtocolError("A client sends a packet claiming to be bigger than the maximum MQTT allows.");
437   - }
438   -
439   - if (packet_length <= readbuf.usedBytes())
440   - {
441   - packetQueueIn.emplace_back(readbuf, packet_length, fixed_header_length, sender);
442   - }
443   - else
444   - break;
445   - }
446   -
  391 + MqttPacket::bufferToMqttPackets(readbuf, packetQueueIn, sender);
447 392 setReadyForReading(readbuf.freeSpace() > 0);
448   -
449   - return true;
450 393 }
451 394  
452 395 void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession)
... ...
client.h
... ... @@ -104,7 +104,7 @@ public:
104 104 void startOrContinueSslAccept();
105 105 void markAsDisconnecting();
106 106 bool readFdIntoBuffer();
107   - bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
  107 + void bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
108 108 void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
109 109 void setWill(const std::string &topic, const std::string &payload, bool retain, char qos);
110 110 void clearWill();
... ...
mqttpacket.cpp
... ... @@ -181,6 +181,55 @@ MqttPacket::MqttPacket(const PubRel &amp;pubRel) :
181 181 pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010);
182 182 }
183 183  
  184 +void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
  185 +{
  186 + while (buf.usedBytes() >= MQTT_HEADER_LENGH)
  187 + {
  188 + // Determine the packet length by decoding the variable length
  189 + int remaining_length_i = 1; // index of 'remaining length' field is one after start.
  190 + uint fixed_header_length = 1;
  191 + size_t multiplier = 1;
  192 + size_t packet_length = 0;
  193 + unsigned char encodedByte = 0;
  194 + do
  195 + {
  196 + fixed_header_length++;
  197 +
  198 + if (fixed_header_length > 5)
  199 + throw ProtocolError("Packet signifies more than 5 bytes in variable length header. Invalid.");
  200 +
  201 + // This happens when you only don't have all the bytes that specify the remaining length.
  202 + if (fixed_header_length > buf.usedBytes())
  203 + return;
  204 +
  205 + encodedByte = buf.peakAhead(remaining_length_i++);
  206 + packet_length += (encodedByte & 127) * multiplier;
  207 + multiplier *= 128;
  208 + if (multiplier > 128*128*128*128)
  209 + throw ProtocolError("Malformed Remaining Length.");
  210 + }
  211 + while ((encodedByte & 128) != 0);
  212 + packet_length += fixed_header_length;
  213 +
  214 + if (sender && !sender->getAuthenticated() && packet_length >= 1024*1024)
  215 + {
  216 + throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.");
  217 + }
  218 +
  219 + if (packet_length > ABSOLUTE_MAX_PACKET_SIZE)
  220 + {
  221 + throw ProtocolError("A client sends a packet claiming to be bigger than the maximum MQTT allows.");
  222 + }
  223 +
  224 + if (packet_length <= buf.usedBytes())
  225 + {
  226 + packetQueueIn.emplace_back(buf, packet_length, fixed_header_length, sender);
  227 + }
  228 + else
  229 + break;
  230 + }
  231 +}
  232 +
184 233 void MqttPacket::handle()
185 234 {
186 235 if (packetType == PacketType::Reserved)
... ... @@ -838,6 +887,23 @@ size_t MqttPacket::remainingAfterPos()
838 887 }
839 888  
840 889  
  890 +void MqttPacket::readIntoBuf(CirBuf &buf) const
  891 +{
  892 + buf.ensureFreeSpace(getSizeIncludingNonPresentHeader());
  893 +
  894 + if (!containsFixedHeader())
  895 + {
  896 + assert(remainingLength.len > 0);
  897 +
  898 + buf.headPtr()[0] = getFirstByte();
  899 + buf.advanceHead(1);
  900 + buf.write(remainingLength.bytes, remainingLength.len);
  901 + }
  902 +
  903 + buf.write(bites.data(), bites.size());
  904 +}
  905 +
  906 +
841 907  
842 908  
843 909  
... ...
mqttpacket.h
... ... @@ -92,6 +92,8 @@ public:
92 92 MqttPacket(const PubComp &pubComp);
93 93 MqttPacket(const PubRel &pubRel);
94 94  
  95 + static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
  96 +
95 97 void handle();
96 98 void handleConnect();
97 99 void handleDisconnect();
... ... @@ -119,6 +121,7 @@ public:
119 121 void setDuplicate();
120 122 size_t getTotalMemoryFootprint();
121 123 std::string getPayloadCopy();
  124 + void readIntoBuf(CirBuf &buf) const;
122 125 };
123 126  
124 127 #endif // MQTTPACKET_H
... ...