diff --git a/include/mqtt/endpoint.hpp b/include/mqtt/endpoint.hpp index 991fe760f..70d5a3dd9 100644 --- a/include/mqtt/endpoint.hpp +++ b/include/mqtt/endpoint.hpp @@ -70,6 +70,7 @@ #include #include #include +#include #if defined(MQTT_USE_WS) #include @@ -4313,9 +4314,7 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); - auto& idx = store_.template get(); - auto r = idx.equal_range(packet_id); - idx.erase(std::get<0>(r), std::get<1>(r)); + store_.erase(packet_id); pid_man_.release_id(packet_id); } @@ -4328,12 +4327,16 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); - auto const& idx = store_.template get(); - for (auto const & e : idx) { - auto const& m = e.message(); - auto cb = continuous_buffer(m); - f(cb.data(), cb.size()); - } + store_.for_each( + [f]( + basic_store_message_variant const& message, + any const& /*life_keeper*/ + ) { + auto cb = continuous_buffer(message); + f(cb.data(), cb.size()); + return false; // no erase + } + ); } /** @@ -4345,10 +4348,15 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); - auto const& idx = store_.template get(); - for (auto const & e : idx) { - f(e.message()); - } + store_.for_each( + [f]( + basic_store_message_variant const& message, + any const& /*life_keeper*/ + ) { + f(message); + return false; // no erase + } + ); } /** @@ -4361,10 +4369,15 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); - auto const& idx = store_.template get(); - for (auto const & e : idx) { - f(e.message(), e.life_keeper()); - } + store_.for_each( + [f]( + basic_store_message_variant const& message, + any const& life_keeper + ) { + f(message, life_keeper); + return false; // no erase + } + ); } /** @@ -4513,30 +4526,13 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); if (pid_man_.register_id(packet_id)) { - auto ret = store_.emplace( + store_.insert_or_update( packet_id, ((qos_value == qos::at_least_once) ? control_packet_type::puback : control_packet_type::pubrec), force_move(msg), force_move(life_keeper) ); - // When client want to restore serialized messages, - // endpoint might keep the message that has the same packet_id. - // In this case, overwrite store_. - if (!ret.second) { - store_.modify( - ret.first, - [&] (auto& e) { - e = store( - packet_id, - ((qos_value == qos::at_least_once) ? control_packet_type::puback - : control_packet_type::pubrec), - force_move(msg), - force_move(life_keeper) - ); - } - ); - } } } @@ -4550,28 +4546,12 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); if (pid_man_.register_id(packet_id)) { - auto ret = store_.emplace( + store_.insert_or_update( packet_id, control_packet_type::pubcomp, force_move(msg), force_move(life_keeper) ); - // When client want to restore serialized messages, - // endpoint might keep the message that has the same packet_id. - // In this case, overwrite store_. - if (!ret.second) { - store_.modify( - ret.first, - [&] (auto& e) { - e = store( - packet_id, - control_packet_type::pubcomp, - force_move(msg), - force_move(life_keeper) - ); - } - ); - } } } @@ -4636,30 +4616,13 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); if (pid_man_.register_id(packet_id)) { - auto ret = store_.emplace( + store_.insert_or_update( packet_id, qos == qos::at_least_once ? control_packet_type::puback : control_packet_type::pubrec, force_move(msg), force_move(life_keeper) ); - // When client want to restore serialized messages, - // endpoint might keep the message that has the same packet_id. - // In this case, overwrite store_. - if (!ret.second) { - store_.modify( - ret.first, - [&] (auto& e) { - e = store( - packet_id, - qos == qos::at_least_once ? control_packet_type::puback - : control_packet_type::pubrec, - force_move(msg), - force_move(life_keeper) - ); - } - ); - } } } @@ -4675,27 +4638,12 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); if (pid_man_.register_id(packet_id)) { - auto ret = store_.emplace( + store_.insert_or_update( packet_id, control_packet_type::pubcomp, force_move(msg), force_move(life_keeper) ); - // When client want to restore serialized messages, - // endpoint might keep the message that has the same packet_id. - // In this case, overwrite store_. - if (!ret.second) { - store_.modify( - ret.first, - [&] (auto& e) { - e = store( - packet_id, - control_packet_type::pubcomp, - force_move(msg) - ); - } - ); - } } } @@ -4759,14 +4707,15 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); pid_man_.register_id(packet_id); - auto ret = store_.emplace( + auto ret = store_.insert( packet_id, control_packet_type::pubcomp, msg, force_move(life_keeper) ); (void)ret; - BOOST_ASSERT(ret.second); + BOOST_ASSERT(ret); + (this->*serialize)(msg); do_sync_write(force_move(msg)); }; @@ -4872,14 +4821,15 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); pid_man_.register_id(packet_id); - auto ret = store_.emplace( + auto ret = store_.insert( packet_id, control_packet_type::pubcomp, msg, force_move(life_keeper) ); (void)ret; - BOOST_ASSERT(ret.second); + BOOST_ASSERT(ret); + (this->*serialize)(msg); do_async_write(force_move(msg), force_move(func)); }; @@ -5507,73 +5457,6 @@ class endpoint : public std::enable_shared_from_this buf_; }; - struct store { - store( - packet_id_t id, - control_packet_type type, - basic_store_message_variant smv, - any life_keeper = any()) - : packet_id_(id) - , expected_control_packet_type_(type) - , smv_(force_move(smv)) - , life_keeper_(force_move(life_keeper)) {} - packet_id_t packet_id() const { return packet_id_; } - control_packet_type expected_control_packet_type() const { return expected_control_packet_type_; } - basic_store_message_variant const& message() const { - return smv_; - } - basic_store_message_variant& message() { - return smv_; - } - any const& life_keeper() const { - return life_keeper_; - } - bool is_publish() const { - return - expected_control_packet_type_ == control_packet_type::puback || - expected_control_packet_type_ == control_packet_type::pubrec; - } - - private: - packet_id_t packet_id_; - control_packet_type expected_control_packet_type_; - basic_store_message_variant smv_; - any life_keeper_; - }; - - struct tag_packet_id {}; - struct tag_packet_id_type {}; - struct tag_seq {}; - using mi_store = mi::multi_index_container< - store, - mi::indexed_by< - mi::ordered_unique< - mi::tag, - mi::composite_key< - store, - mi::const_mem_fun< - store, packet_id_t, - &store::packet_id - >, - mi::const_mem_fun< - store, control_packet_type, - &store::expected_control_packet_type - > - > - >, - mi::ordered_non_unique< - mi::tag, - mi::const_mem_fun< - store, packet_id_t, - &store::packet_id - > - >, - mi::sequenced< - mi::tag - > - > - >; - void handle_control_packet_type(any session_life_keeper, this_type_sp self) { fixed_header_ = static_cast(buf_.front()); remaining_length_ = 0; @@ -8189,13 +8072,11 @@ class endpoint : public std::enable_shared_from_this lck (ep_.store_mtx_); - auto& idx = ep_.store_.template get(); - auto r = idx.equal_range(std::make_tuple(packet_id_, control_packet_type::puback)); - - // puback packet_id is not matched to publish - if (std::get<0>(r) == std::get<1>(r)) return false; + if (!ep_.store_.erase(packet_id_, control_packet_type::puback)) { + // puback packet_id is not matched to publish + return false; + } - idx.erase(std::get<0>(r), std::get<1>(r)); ep_.pid_man_.release_id(packet_id_); return true; } (); @@ -8337,13 +8218,11 @@ class endpoint : public std::enable_shared_from_this lck (ep_.store_mtx_); - auto& idx = ep_.store_.template get(); - auto r = idx.equal_range(std::make_tuple(packet_id_, control_packet_type::pubrec)); - - // pubrec packet_id is not matched to publish - if (std::get<0>(r) == std::get<1>(r)) return false; + if (!ep_.store_.erase(packet_id_, control_packet_type::pubrec)) { + // pubrec packet_id is not matched to publish + return false; + } - idx.erase(std::get<0>(r), std::get<1>(r)); // packet_id should be erased here only if reason_code is error. // Otherwise the packet_id is continue to be used for pubrel/pubcomp. if (is_error(reason_code_)) ep_.pid_man_.release_id(packet_id_); @@ -8697,13 +8576,11 @@ class endpoint : public std::enable_shared_from_this lck (ep_.store_mtx_); - auto& idx = ep_.store_.template get(); - auto r = idx.equal_range(std::make_tuple(packet_id_, control_packet_type::pubcomp)); - - // pubcomp packet_id is not matched to pubrel - if (std::get<0>(r) == std::get<1>(r)) return false; + if (!ep_.store_.erase(packet_id_, control_packet_type::pubcomp)) { + // pubcomp packet_id is not matched to pubrel + return false; + } - idx.erase(std::get<0>(r), std::get<1>(r)); ep_.pid_man_.release_id(packet_id_); return true; } (); @@ -9875,11 +9752,11 @@ class endpoint : public std::enable_shared_from_this*serialize)(msg); @@ -10435,79 +10301,80 @@ class endpoint : public std::enable_shared_from_this lck (store_mtx_); - auto& idx = store_.template get(); - for (auto it = idx.begin(), end = idx.end(); it != end;) { - auto msg = it->message(); - MQTT_NS::visit( - make_lambda_visitor( - [&](v3_1_1::basic_publish_message& m) { - MQTT_LOG("mqtt_api", info) - << MQTT_ADD_VALUE(address, this) - << "async_send_store publish v3.1.1"; - if (maximum_packet_size_send_ < size(m)) { - pid_man_.release_id(m.packet_id()); - MQTT_LOG("mqtt_impl", warning) + store_.for_each( + [&] ( + basic_store_message_variant const& message, + any const& /*life_keeper*/ + ) { + auto erase = false; + MQTT_NS::visit( + make_lambda_visitor( + [&](v3_1_1::basic_publish_message const& m) { + MQTT_LOG("mqtt_api", info) << MQTT_ADD_VALUE(address, this) - << "over maximum packet size message removed. packet_id:" << m.packet_id(); - it = idx.erase(it); - return; - } - do_sync_write(m); - ++it; - }, - [&](v3_1_1::basic_pubrel_message& m) { - MQTT_LOG("mqtt_api", info) - << MQTT_ADD_VALUE(address, this) - << "async_send_store pubrel v3.1.1"; - do_sync_write(m); - ++it; - }, - [&](v5::basic_publish_message& m) { - MQTT_LOG("mqtt_api", info) - << MQTT_ADD_VALUE(address, this) - << "async_send_store publish v5"; - any life_keeper; - auto msg_lk = apply_topic_alias(m, force_move(life_keeper)); - if (maximum_packet_size_send_ < size(std::get<0>(msg_lk))) { - pid_man_.release_id(m.packet_id()); - MQTT_LOG("mqtt_impl", warning) + << "async_send_store publish v3.1.1"; + if (maximum_packet_size_send_ < size(m)) { + pid_man_.release_id(m.packet_id()); + MQTT_LOG("mqtt_impl", warning) + << MQTT_ADD_VALUE(address, this) + << "over maximum packet size message removed. packet_id:" << m.packet_id(); + erase = true; + return; + } + do_sync_write(m); + }, + [&](v3_1_1::basic_pubrel_message const& m) { + MQTT_LOG("mqtt_api", info) << MQTT_ADD_VALUE(address, this) - << "over maximum packet size message removed. packet_id:" << m.packet_id(); - it = idx.erase(it); - return; - } - if (publish_send_count_.load() == publish_send_max_) { - LockGuard lck (publish_send_queue_mtx_); - publish_send_queue_.emplace_back( - force_move(std::get<0>(msg_lk)), - false, - force_move(std::get<1>(msg_lk)) - ); - } - else { - MQTT_LOG("mqtt_impl", trace) + << "async_send_store pubrel v3.1.1"; + do_sync_write(m); + }, + [&](v5::basic_publish_message const& m) { + MQTT_LOG("mqtt_api", info) << MQTT_ADD_VALUE(address, this) - << "increment publish_send_count_:" << publish_send_count_.load(); - ++publish_send_count_; - do_sync_write(force_move(std::get<0>(msg_lk))); - } - ++it; - }, - [&](v5::basic_pubrel_message& m) { - MQTT_LOG("mqtt_api", info) - << MQTT_ADD_VALUE(address, this) - << "async_send_store pubrel v5"; - { - LockGuard lck_resend_pubrel (resend_pubrel_mtx_); - resend_pubrel_.insert(m.packet_id()); + << "async_send_store publish v5"; + any life_keeper; + auto msg_lk = apply_topic_alias(m, force_move(life_keeper)); + if (maximum_packet_size_send_ < size(std::get<0>(msg_lk))) { + pid_man_.release_id(m.packet_id()); + MQTT_LOG("mqtt_impl", warning) + << MQTT_ADD_VALUE(address, this) + << "over maximum packet size message removed. packet_id:" << m.packet_id(); + erase = true; + return; + } + if (publish_send_count_.load() == publish_send_max_) { + LockGuard lck (publish_send_queue_mtx_); + publish_send_queue_.emplace_back( + force_move(std::get<0>(msg_lk)), + false, + force_move(std::get<1>(msg_lk)) + ); + } + else { + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "increment publish_send_count_:" << publish_send_count_.load(); + ++publish_send_count_; + do_sync_write(force_move(std::get<0>(msg_lk))); + } + }, + [&](v5::basic_pubrel_message const& m) { + MQTT_LOG("mqtt_api", info) + << MQTT_ADD_VALUE(address, this) + << "async_send_store pubrel v5"; + { + LockGuard lck_resend_pubrel (resend_pubrel_mtx_); + resend_pubrel_.insert(m.packet_id()); + } + do_sync_write(m); } - do_sync_write(m); - ++it; - } - ), - msg - ); - } + ), + message + ); + return erase; + } + ); } // Blocking write @@ -10843,34 +10710,23 @@ class endpoint : public std::enable_shared_from_this(); - for (auto it = idx.begin(), end = idx.end(); it != end;) { - auto msg = it->message(); - MQTT_NS::visit( - make_lambda_visitor( - [&](v3_1_1::basic_publish_message& m) { - MQTT_LOG("mqtt_api", info) - << MQTT_ADD_VALUE(address, this) - << "async_send_store publish v3.1.1"; - if (maximum_packet_size_send_ < size(m)) { - pid_man_.release_id(m.packet_id()); - MQTT_LOG("mqtt_impl", warning) + store_.for_each( + [&] ( + basic_store_message_variant const& message, + any const& /*life_keeper*/ + ) { + auto erase = false; + MQTT_NS::visit( + make_lambda_visitor( + [&](v3_1_1::basic_publish_message const& m) { + MQTT_LOG("mqtt_api", info) << MQTT_ADD_VALUE(address, this) - << "over maximum packet size message removed. packet_id:" << m.packet_id(); - it = idx.erase(it); - return; - } - do_async_write( - m, - [g] - (error_code /*ec*/) { - } - ); - ++it; - }, - [&](v3_1_1::basic_pubrel_message& m) { - MQTT_LOG("mqtt_api", info) - << MQTT_ADD_VALUE(address, this) - << "async_send_store pubrel v3.1.1"; - do_async_write( - m, - [g] - (error_code /*ec*/) { + << "async_send_store publish v3.1.1"; + if (maximum_packet_size_send_ < size(m)) { + pid_man_.release_id(m.packet_id()); + MQTT_LOG("mqtt_impl", warning) + << MQTT_ADD_VALUE(address, this) + << "over maximum packet size message removed. packet_id:" << m.packet_id(); + erase = true; + return; } - ); - ++it; - }, - [&](v5::basic_publish_message& m) { - MQTT_LOG("mqtt_api", info) - << MQTT_ADD_VALUE(address, this) - << "async_send_store publish v5"; - any life_keeper; - auto msg_lk = apply_topic_alias(m, force_move(life_keeper)); - if (maximum_packet_size_send_ < size(std::get<0>(msg_lk))) { - pid_man_.release_id(m.packet_id()); - MQTT_LOG("mqtt_impl", warning) + do_async_write( + m, + [g] + (error_code /*ec*/) { + } + ); + }, + [&](v3_1_1::basic_pubrel_message const& m) { + MQTT_LOG("mqtt_api", info) << MQTT_ADD_VALUE(address, this) - << "over maximum packet size message removed. packet_id:" << m.packet_id(); - it = idx.erase(it); - return; - } - if (publish_send_count_.load() == publish_send_max_) { - LockGuard lck (publish_send_queue_mtx_); - publish_send_queue_.emplace_back( - force_move(std::get<0>(msg_lk)), - true, - force_move(std::get<1>(msg_lk)) + << "async_send_store pubrel v3.1.1"; + do_async_write( + m, + [g] + (error_code /*ec*/) { + } ); - } - else { - MQTT_LOG("mqtt_impl", trace) + }, + [&](v5::basic_publish_message const& m) { + MQTT_LOG("mqtt_api", info) + << MQTT_ADD_VALUE(address, this) + << "async_send_store publish v5"; + any life_keeper; + auto msg_lk = apply_topic_alias(m, force_move(life_keeper)); + if (maximum_packet_size_send_ < size(std::get<0>(msg_lk))) { + pid_man_.release_id(m.packet_id()); + MQTT_LOG("mqtt_impl", warning) + << MQTT_ADD_VALUE(address, this) + << "over maximum packet size message removed. packet_id:" << m.packet_id(); + erase = true; + return; + } + if (publish_send_count_.load() == publish_send_max_) { + LockGuard lck (publish_send_queue_mtx_); + publish_send_queue_.emplace_back( + force_move(std::get<0>(msg_lk)), + true, + force_move(std::get<1>(msg_lk)) + ); + } + else { + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "increment publish_send_count_:" << publish_send_count_.load(); + ++publish_send_count_; + do_async_write( + get_basic_message_variant(force_move(std::get<0>(msg_lk))), + [g, life_keeper = force_move(std::get<1>(msg_lk))] + (error_code /*ec*/) { + } + ); + } + }, + [&](v5::basic_pubrel_message const& m) { + MQTT_LOG("mqtt_api", info) << MQTT_ADD_VALUE(address, this) - << "increment publish_send_count_:" << publish_send_count_.load(); - ++publish_send_count_; + << "async_send_store pubrel v5"; + { + LockGuard lck_resend_pubrel (resend_pubrel_mtx_); + resend_pubrel_.insert(m.packet_id()); + } do_async_write( - get_basic_message_variant(force_move(std::get<0>(msg_lk))), - [g, life_keeper = force_move(std::get<1>(msg_lk))] + m, + [g] (error_code /*ec*/) { } ); } - ++it; - }, - [&](v5::basic_pubrel_message& m) { - MQTT_LOG("mqtt_api", info) - << MQTT_ADD_VALUE(address, this) - << "async_send_store pubrel v5"; - { - LockGuard lck_resend_pubrel (resend_pubrel_mtx_); - resend_pubrel_.insert(m.packet_id()); - } - do_async_write( - m, - [g] - (error_code /*ec*/) { - } - ); - ++it; - } - ), - msg - ); - } + ), + message + ); + return erase; + } + ); } // Non blocking (async) write @@ -11744,7 +11601,7 @@ class endpoint : public std::enable_shared_from_this payload_; Mutex store_mtx_; - mi_store store_; + store store_; std::set qos2_publish_handled_; std::deque queue_; diff --git a/include/mqtt/store.hpp b/include/mqtt/store.hpp new file mode 100644 index 000000000..717ad4f44 --- /dev/null +++ b/include/mqtt/store.hpp @@ -0,0 +1,200 @@ +// Copyright Takatoshi Kondo 2022 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(MQTT_STORE_HPP) +#define MQTT_STORE_HPP + +#include // should be top to configure variant limit + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace MQTT_NS { + +namespace mi = boost::multi_index; + +enum class store_insert_update_result { + inserted, + updated +}; + +template +class store { +private: + struct tag_packet_id {}; + struct tag_packet_id_type {}; + struct tag_seq {}; + +public: + using packet_id_t = typename packet_id_type::type; + + bool insert( + packet_id_t packet_id, + control_packet_type expected_type, + basic_store_message_variant smv, + any life_keeper + ) { + auto ret = elems_.emplace( + packet_id, + expected_type, + force_move(smv), + force_move(life_keeper) + ); + return ret.second; + } + + store_insert_update_result insert_or_update( + packet_id_t packet_id, + control_packet_type expected_type, + basic_store_message_variant smv, + any life_keeper + ) { + auto ret = elems_.emplace( + packet_id, + expected_type, + force_move(smv), + force_move(life_keeper) + ); + if (ret.second) return store_insert_update_result::inserted; + + // When client want to restore serialized messages, + // endpoint might keep the message that has the same packet_id. + // In this case, overwrite the element. + // entry exists + elems_.modify( + ret.first, + [&] (auto& e) { + e.packet_id_ = packet_id; + e.expected_control_packet_type_ = expected_type; + e.smv_ = force_move(smv); + e.life_keeper_ = force_move(life_keeper); + } + ); + return store_insert_update_result::updated; + } + + void for_each( + std::function< + // if return true, then erase element + bool(basic_store_message_variant const&, any const&) + > const& f + ) { + auto& idx = elems_.template get(); + auto it = idx.begin(); + auto end = idx.end(); + while (it != end) { + if (f(it->message(), it->life_keeper())) { + it = idx.erase(it); + } + else { + ++it; + } + } + } + + std::size_t erase(packet_id_t packet_id) { + auto& idx = elems_.template get(); + return idx.erase(packet_id); + } + + bool erase(packet_id_t packet_id, control_packet_type type) { + auto& idx = elems_.template get(); + auto ret = idx.equal_range(std::make_tuple(packet_id, type)); + if (ret.first == ret.second) return false; + idx.erase(ret.first, ret.second); + return true; + } + + void clear() { + elems_.clear(); + } + + bool empty() const { + return elems_.empty(); + } + +private: + + struct elem_t { + friend class store; + + elem_t( + packet_id_t id, + control_packet_type type, + basic_store_message_variant smv, + any life_keeper = any()) + : packet_id_(id) + , expected_control_packet_type_(type) + , smv_(force_move(smv)) + , life_keeper_(force_move(life_keeper)) {} + packet_id_t packet_id() const { return packet_id_; } + control_packet_type expected_control_packet_type() const { return expected_control_packet_type_; } + basic_store_message_variant const& message() const { + return smv_; + } + basic_store_message_variant& message() { + return smv_; + } + any const& life_keeper() const { + return life_keeper_; + } + bool is_publish() const { + return + expected_control_packet_type_ == control_packet_type::puback || + expected_control_packet_type_ == control_packet_type::pubrec; + } + + private: + packet_id_t packet_id_; + control_packet_type expected_control_packet_type_; + basic_store_message_variant smv_; + any life_keeper_; + }; + + + using mi_elem = mi::multi_index_container< + elem_t, + mi::indexed_by< + mi::ordered_unique< + mi::tag, + mi::composite_key< + elem_t, + mi::const_mem_fun< + elem_t, packet_id_t, + &elem_t::packet_id + >, + mi::const_mem_fun< + elem_t, control_packet_type, + &elem_t::expected_control_packet_type + > + > + >, + mi::ordered_non_unique< + mi::tag, + mi::const_mem_fun< + elem_t, packet_id_t, + &elem_t::packet_id + > + >, + mi::sequenced< + mi::tag + > + > + >; + + mi_elem elems_; +}; + +} // namespace MQTT_NS + +#endif // MQTT_STORE_HPP