Commit 5614ee5cee70252c329d870cf33ece5820b49f88

Authored by Wiebe Cazemier
1 parent 3893af5d

Move some parsing logic to the MqttPacket class

This makes more sense, logically, and also helps in tests I'm about to
write.
client.cpp
@@ -195,17 +195,7 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos) @@ -195,17 +195,7 @@ int Client::writeMqttPacket(const MqttPacket &packet, const char qos)
195 return 0; 195 return 0;
196 } 196 }
197 197
198 - writebuf.ensureFreeSpace(packet.getSizeIncludingNonPresentHeader());  
199 -  
200 - if (!packet.containsFixedHeader())  
201 - {  
202 - writebuf.headPtr()[0] = packet.getFirstByte();  
203 - writebuf.advanceHead(1);  
204 - RemainingLength r = packet.getRemainingLength();  
205 - writebuf.write(r.bytes, r.len);  
206 - }  
207 -  
208 - writebuf.write(packet.getBites().data(), packet.getBites().size()); 198 + packet.readIntoBuf(writebuf);
209 199
210 if (packet.packetType == PacketType::DISCONNECT) 200 if (packet.packetType == PacketType::DISCONNECT)
211 setReadyForDisconnect(); 201 setReadyForDisconnect();
@@ -396,57 +386,10 @@ void Client::setReadyForReading(bool val) @@ -396,57 +386,10 @@ void Client::setReadyForReading(bool val)
396 } 386 }
397 } 387 }
398 388
399 -bool Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender) 389 +void Client::bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
400 { 390 {
401 - while (readbuf.usedBytes() >= MQTT_HEADER_LENGH)  
402 - {  
403 - // Determine the packet length by decoding the variable length  
404 - int remaining_length_i = 1; // index of 'remaining length' field is one after start.  
405 - uint fixed_header_length = 1;  
406 - size_t multiplier = 1;  
407 - size_t packet_length = 0;  
408 - unsigned char encodedByte = 0;  
409 - do  
410 - {  
411 - fixed_header_length++;  
412 -  
413 - if (fixed_header_length > 5)  
414 - throw ProtocolError("Packet signifies more than 5 bytes in variable length header. Invalid.");  
415 -  
416 - // This happens when you only don't have all the bytes that specify the remaining length.  
417 - if (fixed_header_length > readbuf.usedBytes())  
418 - return false;  
419 -  
420 - encodedByte = readbuf.peakAhead(remaining_length_i++);  
421 - packet_length += (encodedByte & 127) * multiplier;  
422 - multiplier *= 128;  
423 - if (multiplier > 128*128*128*128)  
424 - throw ProtocolError("Malformed Remaining Length.");  
425 - }  
426 - while ((encodedByte & 128) != 0);  
427 - packet_length += fixed_header_length;  
428 -  
429 - if (!authenticated && packet_length >= 1024*1024)  
430 - {  
431 - throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.");  
432 - }  
433 -  
434 - if (packet_length > ABSOLUTE_MAX_PACKET_SIZE)  
435 - {  
436 - throw ProtocolError("A client sends a packet claiming to be bigger than the maximum MQTT allows.");  
437 - }  
438 -  
439 - if (packet_length <= readbuf.usedBytes())  
440 - {  
441 - packetQueueIn.emplace_back(readbuf, packet_length, fixed_header_length, sender);  
442 - }  
443 - else  
444 - break;  
445 - }  
446 - 391 + MqttPacket::bufferToMqttPackets(readbuf, packetQueueIn, sender);
447 setReadyForReading(readbuf.freeSpace() > 0); 392 setReadyForReading(readbuf.freeSpace() > 0);
448 -  
449 - return true;  
450 } 393 }
451 394
452 void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession) 395 void Client::setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession)
client.h
@@ -104,7 +104,7 @@ public: @@ -104,7 +104,7 @@ public:
104 void startOrContinueSslAccept(); 104 void startOrContinueSslAccept();
105 void markAsDisconnecting(); 105 void markAsDisconnecting();
106 bool readFdIntoBuffer(); 106 bool readFdIntoBuffer();
107 - bool bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender); 107 + void bufferToMqttPackets(std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
108 void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession); 108 void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::string username, bool connectPacketSeen, uint16_t keepalive, bool cleanSession);
109 void setWill(const std::string &topic, const std::string &payload, bool retain, char qos); 109 void setWill(const std::string &topic, const std::string &payload, bool retain, char qos);
110 void clearWill(); 110 void clearWill();
mqttpacket.cpp
@@ -181,6 +181,55 @@ MqttPacket::MqttPacket(const PubRel &amp;pubRel) : @@ -181,6 +181,55 @@ MqttPacket::MqttPacket(const PubRel &amp;pubRel) :
181 pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010); 181 pubCommonConstruct(pubRel.packet_id, PacketType::PUBREL, 0b0010);
182 } 182 }
183 183
  184 +void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender)
  185 +{
  186 + while (buf.usedBytes() >= MQTT_HEADER_LENGH)
  187 + {
  188 + // Determine the packet length by decoding the variable length
  189 + int remaining_length_i = 1; // index of 'remaining length' field is one after start.
  190 + uint fixed_header_length = 1;
  191 + size_t multiplier = 1;
  192 + size_t packet_length = 0;
  193 + unsigned char encodedByte = 0;
  194 + do
  195 + {
  196 + fixed_header_length++;
  197 +
  198 + if (fixed_header_length > 5)
  199 + throw ProtocolError("Packet signifies more than 5 bytes in variable length header. Invalid.");
  200 +
  201 + // This happens when you only don't have all the bytes that specify the remaining length.
  202 + if (fixed_header_length > buf.usedBytes())
  203 + return;
  204 +
  205 + encodedByte = buf.peakAhead(remaining_length_i++);
  206 + packet_length += (encodedByte & 127) * multiplier;
  207 + multiplier *= 128;
  208 + if (multiplier > 128*128*128*128)
  209 + throw ProtocolError("Malformed Remaining Length.");
  210 + }
  211 + while ((encodedByte & 128) != 0);
  212 + packet_length += fixed_header_length;
  213 +
  214 + if (sender && !sender->getAuthenticated() && packet_length >= 1024*1024)
  215 + {
  216 + throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.");
  217 + }
  218 +
  219 + if (packet_length > ABSOLUTE_MAX_PACKET_SIZE)
  220 + {
  221 + throw ProtocolError("A client sends a packet claiming to be bigger than the maximum MQTT allows.");
  222 + }
  223 +
  224 + if (packet_length <= buf.usedBytes())
  225 + {
  226 + packetQueueIn.emplace_back(buf, packet_length, fixed_header_length, sender);
  227 + }
  228 + else
  229 + break;
  230 + }
  231 +}
  232 +
184 void MqttPacket::handle() 233 void MqttPacket::handle()
185 { 234 {
186 if (packetType == PacketType::Reserved) 235 if (packetType == PacketType::Reserved)
@@ -838,6 +887,23 @@ size_t MqttPacket::remainingAfterPos() @@ -838,6 +887,23 @@ size_t MqttPacket::remainingAfterPos()
838 } 887 }
839 888
840 889
  890 +void MqttPacket::readIntoBuf(CirBuf &buf) const
  891 +{
  892 + buf.ensureFreeSpace(getSizeIncludingNonPresentHeader());
  893 +
  894 + if (!containsFixedHeader())
  895 + {
  896 + assert(remainingLength.len > 0);
  897 +
  898 + buf.headPtr()[0] = getFirstByte();
  899 + buf.advanceHead(1);
  900 + buf.write(remainingLength.bytes, remainingLength.len);
  901 + }
  902 +
  903 + buf.write(bites.data(), bites.size());
  904 +}
  905 +
  906 +
841 907
842 908
843 909
mqttpacket.h
@@ -92,6 +92,8 @@ public: @@ -92,6 +92,8 @@ public:
92 MqttPacket(const PubComp &pubComp); 92 MqttPacket(const PubComp &pubComp);
93 MqttPacket(const PubRel &pubRel); 93 MqttPacket(const PubRel &pubRel);
94 94
  95 + static void bufferToMqttPackets(CirBuf &buf, std::vector<MqttPacket> &packetQueueIn, std::shared_ptr<Client> &sender);
  96 +
95 void handle(); 97 void handle();
96 void handleConnect(); 98 void handleConnect();
97 void handleDisconnect(); 99 void handleDisconnect();
@@ -119,6 +121,7 @@ public: @@ -119,6 +121,7 @@ public:
119 void setDuplicate(); 121 void setDuplicate();
120 size_t getTotalMemoryFootprint(); 122 size_t getTotalMemoryFootprint();
121 std::string getPayloadCopy(); 123 std::string getPayloadCopy();
  124 + void readIntoBuf(CirBuf &buf) const;
122 }; 125 };
123 126
124 #endif // MQTTPACKET_H 127 #endif // MQTTPACKET_H