Commit b615099f48a60fe3880da2e3a83a8aae1e6a5607

Authored by Wiebe Cazemier
1 parent 60e64cf9

Write to session/client outside of subscription lock

subscriptionstore.cpp
@@ -22,6 +22,13 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>. @@ -22,6 +22,13 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
22 #include "rwlockguard.h" 22 #include "rwlockguard.h"
23 #include "retainedmessagesdb.h" 23 #include "retainedmessagesdb.h"
24 24
  25 +ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr<Session> &ses, char qos) :
  26 + session(ses),
  27 + qos(qos)
  28 +{
  29 +
  30 +}
  31 +
25 SubscriptionNode::SubscriptionNode(const std::string &subtopic) : 32 SubscriptionNode::SubscriptionNode(const std::string &subtopic) :
26 subtopic(subtopic) 33 subtopic(subtopic)
27 { 34 {
@@ -260,7 +267,8 @@ bool SubscriptionStore::sessionPresent(const std::string &amp;clientid) @@ -260,7 +267,8 @@ bool SubscriptionStore::sessionPresent(const std::string &amp;clientid)
260 return result; 267 return result;
261 } 268 }
262 269
263 -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::unordered_map<std::string, Subscription> &subscribers, uint64_t &count) const 270 +void SubscriptionStore::publishNonRecursively(const std::unordered_map<std::string, Subscription> &subscribers,
  271 + std::forward_list<ReceivingSubscriber> &targetSessions) const
264 { 272 {
265 for (auto &pair : subscribers) 273 for (auto &pair : subscribers)
266 { 274 {
@@ -269,7 +277,8 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st @@ -269,7 +277,8 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st
269 const std::shared_ptr<Session> session = sub.session.lock(); 277 const std::shared_ptr<Session> session = sub.session.lock();
270 if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect. 278 if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect.
271 { 279 {
272 - session->writePacket(packet, sub.qos, false, count); 280 + ReceivingSubscriber x(session, sub.qos);
  281 + targetSessions.emplace_front(session, sub.qos);
273 } 282 }
274 } 283 }
275 } 284 }
@@ -282,16 +291,16 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st @@ -282,16 +291,16 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st
282 * @param packet 291 * @param packet
283 * @param count as a reference (vs return value) because a return value introduces an extra call i.e. limits tail recursion optimization. 292 * @param count as a reference (vs return value) because a return value introduces an extra call i.e. limits tail recursion optimization.
284 * 293 *
285 - * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the kernel. If you refactor this, 294 + * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the compiler. If you refactor this,
286 * look at objdump --disassemble --demangle to see how many calls (not jumps) to itself are made and compare. 295 * look at objdump --disassemble --demangle to see how many calls (not jumps) to itself are made and compare.
287 */ 296 */
288 void SubscriptionStore::publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, 297 void SubscriptionStore::publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
289 - SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const 298 + SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions) const
290 { 299 {
291 if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here. 300 if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here.
292 { 301 {
293 if (this_node) 302 if (this_node)
294 - publishNonRecursively(packet, this_node->getSubscribers(), count); 303 + publishNonRecursively(this_node->getSubscribers(), targetSessions);
295 return; 304 return;
296 } 305 }
297 306
@@ -308,18 +317,18 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera @@ -308,18 +317,18 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera
308 317
309 if (this_node->childrenPound) 318 if (this_node->childrenPound)
310 { 319 {
311 - publishNonRecursively(packet, this_node->childrenPound->getSubscribers(), count); 320 + publishNonRecursively(this_node->childrenPound->getSubscribers(), targetSessions);
312 } 321 }
313 322
314 const auto &sub_node = this_node->children.find(cur_subtop); 323 const auto &sub_node = this_node->children.find(cur_subtop);
315 if (sub_node != this_node->children.end()) 324 if (sub_node != this_node->children.end())
316 { 325 {
317 - publishRecursively(next_subtopic, end, sub_node->second.get(), packet, count); 326 + publishRecursively(next_subtopic, end, sub_node->second.get(), targetSessions);
318 } 327 }
319 328
320 if (this_node->childrenPlus) 329 if (this_node->childrenPlus)
321 { 330 {
322 - publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), packet, count); 331 + publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), targetSessions);
323 } 332 }
324 } 333 }
325 334
@@ -329,11 +338,19 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector&lt;std::string&gt; @@ -329,11 +338,19 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector&lt;std::string&gt;
329 338
330 SubscriptionNode *startNode = dollar ? &rootDollar : &root; 339 SubscriptionNode *startNode = dollar ? &rootDollar : &root;
331 340
332 - RWLockGuard lock_guard(&subscriptionsRwlock);  
333 - lock_guard.rdlock();  
334 -  
335 uint64_t count = 0; 341 uint64_t count = 0;
336 - publishRecursively(subtopics.begin(), subtopics.end(), startNode, packet, count); 342 + std::forward_list<ReceivingSubscriber> subscriberSessions;
  343 +
  344 + {
  345 + RWLockGuard lock_guard(&subscriptionsRwlock);
  346 + lock_guard.rdlock();
  347 + publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions);
  348 + }
  349 +
  350 + for(const ReceivingSubscriber &x : subscriberSessions)
  351 + {
  352 + x.session->writePacket(packet, x.qos, false, count);
  353 + }
337 354
338 std::shared_ptr<Client> sender = packet.getSender(); 355 std::shared_ptr<Client> sender = packet.getSender();
339 if (sender) 356 if (sender)
subscriptionstore.h
@@ -41,6 +41,15 @@ struct Subscription @@ -41,6 +41,15 @@ struct Subscription
41 void reset(); 41 void reset();
42 }; 42 };
43 43
  44 +struct ReceivingSubscriber
  45 +{
  46 + const std::shared_ptr<Session> session;
  47 + const char qos;
  48 +
  49 +public:
  50 + ReceivingSubscriber(const std::shared_ptr<Session> &ses, char qos);
  51 +};
  52 +
44 class SubscriptionNode 53 class SubscriptionNode
45 { 54 {
46 std::string subtopic; 55 std::string subtopic;
@@ -94,9 +103,10 @@ class SubscriptionStore @@ -94,9 +103,10 @@ class SubscriptionStore
94 103
95 Logger *logger = Logger::getInstance(); 104 Logger *logger = Logger::getInstance();
96 105
97 - void publishNonRecursively(const MqttPacket &packet, const std::unordered_map<std::string, Subscription> &subscribers, uint64_t &count) const; 106 + void publishNonRecursively(const std::unordered_map<std::string, Subscription> &subscribers,
  107 + std::forward_list<ReceivingSubscriber> &targetSessions) const;
98 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end, 108 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
99 - SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const; 109 + SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions) const;
100 void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const; 110 void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const;
101 void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, 111 void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
102 std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const; 112 std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const;