Commit b615099f48a60fe3880da2e3a83a8aae1e6a5607

Authored by Wiebe Cazemier
1 parent 60e64cf9

Write to session/client outside of subscription lock

subscriptionstore.cpp
... ... @@ -22,6 +22,13 @@ License along with FlashMQ. If not, see <https://www.gnu.org/licenses/>.
22 22 #include "rwlockguard.h"
23 23 #include "retainedmessagesdb.h"
24 24  
  25 +ReceivingSubscriber::ReceivingSubscriber(const std::shared_ptr<Session> &ses, char qos) :
  26 + session(ses),
  27 + qos(qos)
  28 +{
  29 +
  30 +}
  31 +
25 32 SubscriptionNode::SubscriptionNode(const std::string &subtopic) :
26 33 subtopic(subtopic)
27 34 {
... ... @@ -260,7 +267,8 @@ bool SubscriptionStore::sessionPresent(const std::string &amp;clientid)
260 267 return result;
261 268 }
262 269  
263   -void SubscriptionStore::publishNonRecursively(const MqttPacket &packet, const std::unordered_map<std::string, Subscription> &subscribers, uint64_t &count) const
  270 +void SubscriptionStore::publishNonRecursively(const std::unordered_map<std::string, Subscription> &subscribers,
  271 + std::forward_list<ReceivingSubscriber> &targetSessions) const
264 272 {
265 273 for (auto &pair : subscribers)
266 274 {
... ... @@ -269,7 +277,8 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st
269 277 const std::shared_ptr<Session> session = sub.session.lock();
270 278 if (session) // Shared pointer expires when session has been cleaned by 'clean session' connect.
271 279 {
272   - session->writePacket(packet, sub.qos, false, count);
  280 + ReceivingSubscriber x(session, sub.qos);
  281 + targetSessions.emplace_front(session, sub.qos);
273 282 }
274 283 }
275 284 }
... ... @@ -282,16 +291,16 @@ void SubscriptionStore::publishNonRecursively(const MqttPacket &amp;packet, const st
282 291 * @param packet
283 292 * @param count as a reference (vs return value) because a return value introduces an extra call i.e. limits tail recursion optimization.
284 293 *
285   - * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the kernel. If you refactor this,
  294 + * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the compiler. If you refactor this,
286 295 * look at objdump --disassemble --demangle to see how many calls (not jumps) to itself are made and compare.
287 296 */
288 297 void SubscriptionStore::publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
289   - SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const
  298 + SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions) const
290 299 {
291 300 if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here.
292 301 {
293 302 if (this_node)
294   - publishNonRecursively(packet, this_node->getSubscribers(), count);
  303 + publishNonRecursively(this_node->getSubscribers(), targetSessions);
295 304 return;
296 305 }
297 306  
... ... @@ -308,18 +317,18 @@ void SubscriptionStore::publishRecursively(std::vector&lt;std::string&gt;::const_itera
308 317  
309 318 if (this_node->childrenPound)
310 319 {
311   - publishNonRecursively(packet, this_node->childrenPound->getSubscribers(), count);
  320 + publishNonRecursively(this_node->childrenPound->getSubscribers(), targetSessions);
312 321 }
313 322  
314 323 const auto &sub_node = this_node->children.find(cur_subtop);
315 324 if (sub_node != this_node->children.end())
316 325 {
317   - publishRecursively(next_subtopic, end, sub_node->second.get(), packet, count);
  326 + publishRecursively(next_subtopic, end, sub_node->second.get(), targetSessions);
318 327 }
319 328  
320 329 if (this_node->childrenPlus)
321 330 {
322   - publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), packet, count);
  331 + publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), targetSessions);
323 332 }
324 333 }
325 334  
... ... @@ -329,11 +338,19 @@ void SubscriptionStore::queuePacketAtSubscribers(const std::vector&lt;std::string&gt;
329 338  
330 339 SubscriptionNode *startNode = dollar ? &rootDollar : &root;
331 340  
332   - RWLockGuard lock_guard(&subscriptionsRwlock);
333   - lock_guard.rdlock();
334   -
335 341 uint64_t count = 0;
336   - publishRecursively(subtopics.begin(), subtopics.end(), startNode, packet, count);
  342 + std::forward_list<ReceivingSubscriber> subscriberSessions;
  343 +
  344 + {
  345 + RWLockGuard lock_guard(&subscriptionsRwlock);
  346 + lock_guard.rdlock();
  347 + publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions);
  348 + }
  349 +
  350 + for(const ReceivingSubscriber &x : subscriberSessions)
  351 + {
  352 + x.session->writePacket(packet, x.qos, false, count);
  353 + }
337 354  
338 355 std::shared_ptr<Client> sender = packet.getSender();
339 356 if (sender)
... ...
subscriptionstore.h
... ... @@ -41,6 +41,15 @@ struct Subscription
41 41 void reset();
42 42 };
43 43  
  44 +struct ReceivingSubscriber
  45 +{
  46 + const std::shared_ptr<Session> session;
  47 + const char qos;
  48 +
  49 +public:
  50 + ReceivingSubscriber(const std::shared_ptr<Session> &ses, char qos);
  51 +};
  52 +
44 53 class SubscriptionNode
45 54 {
46 55 std::string subtopic;
... ... @@ -94,9 +103,10 @@ class SubscriptionStore
94 103  
95 104 Logger *logger = Logger::getInstance();
96 105  
97   - void publishNonRecursively(const MqttPacket &packet, const std::unordered_map<std::string, Subscription> &subscribers, uint64_t &count) const;
  106 + void publishNonRecursively(const std::unordered_map<std::string, Subscription> &subscribers,
  107 + std::forward_list<ReceivingSubscriber> &targetSessions) const;
98 108 void publishRecursively(std::vector<std::string>::const_iterator cur_subtopic_it, std::vector<std::string>::const_iterator end,
99   - SubscriptionNode *this_node, const MqttPacket &packet, uint64_t &count) const;
  109 + SubscriptionNode *this_node, std::forward_list<ReceivingSubscriber> &targetSessions) const;
100 110 void getRetainedMessages(RetainedMessageNode *this_node, std::vector<RetainedMessage> &outputList) const;
101 111 void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
102 112 std::unordered_map<std::string, std::list<SubscriptionForSerializing>> &outputList) const;
... ...