diff --git a/include/mqtt/broker/broker.hpp b/include/mqtt/broker/broker.hpp index 8fa171641..ee09e8bcc 100644 --- a/include/mqtt/broker/broker.hpp +++ b/include/mqtt/broker/broker.hpp @@ -1319,9 +1319,15 @@ class broker_t { auto& ss = const_cast(*it); do_send_will(ss); if (rc) { + MQTT_LOG("mqtt_broker", trace) + << MQTT_ADD_VALUE(address, spep.get()) + << "disconnect_and_force_disconnect(async) cid:" << ss.client_id(); disconnect_and_force_disconnect(spep, rc.value()); } else { + MQTT_LOG("mqtt_broker", trace) + << MQTT_ADD_VALUE(address, spep.get()) + << "force_disconnect(async) cid:" << ss.client_id(); force_disconnect(spep); } idx.erase(it); @@ -1334,9 +1340,15 @@ class broker_t { [&](session_state& ss) { do_send_will(ss); if (rc) { + MQTT_LOG("mqtt_broker", trace) + << MQTT_ADD_VALUE(address, spep.get()) + << "disconnect_and_force_disconnect(async) cid:" << ss.client_id(); disconnect_and_force_disconnect(spep, rc.value()); } else { + MQTT_LOG("mqtt_broker", trace) + << MQTT_ADD_VALUE(address, spep.get()) + << "force_disconnect(async) cid:" << ss.client_id(); force_disconnect(spep); } // become_offline updates index diff --git a/include/mqtt/broker/session_state.hpp b/include/mqtt/broker/session_state.hpp index c70ce0bd6..efe810a3a 100644 --- a/include/mqtt/broker/session_state.hpp +++ b/include/mqtt/broker/session_state.hpp @@ -282,6 +282,9 @@ struct session_state { } void clean() { + MQTT_LOG("mqtt_broker", trace) + << MQTT_ADD_VALUE(address, this) + << "clean"; { std::lock_guard g(mtx_inflight_messages_); inflight_messages_.clear(); @@ -296,6 +299,7 @@ struct session_state { } shared_targets_.erase(*this); unsubscribe_all(); + if (con_) con_->async_force_disconnect(); } void exactly_once_start(packet_id_t packet_id) { diff --git a/include/mqtt/constant.hpp b/include/mqtt/constant.hpp index 29db53200..3b9b748bc 100644 --- a/include/mqtt/constant.hpp +++ b/include/mqtt/constant.hpp @@ -8,6 +8,7 @@ #define MQTT_CONSTANT_HPP #include +#include #include #include @@ -21,6 +22,7 @@ static constexpr std::size_t packet_size_no_limit = 4 + // remaining length 128 * 128 * 128 * 128; // maximum value of remainin length static constexpr receive_maximum_t receive_maximum_max = 0xffff; +static constexpr auto shutdown_timeout = std::chrono::seconds(3); } // namespace MQTT_NS diff --git a/include/mqtt/endpoint.hpp b/include/mqtt/endpoint.hpp index eefb72ed5..41ec48e2a 100644 --- a/include/mqtt/endpoint.hpp +++ b/include/mqtt/endpoint.hpp @@ -191,7 +191,8 @@ class endpoint : public std::enable_shared_from_thisclose(ignored_ec); - } + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "handle_close_or_error call chutdown"; + shutdown(socket()); } + + connect_requested_ = false; + clean_sub_unsub_inflight(); if (disconnect_requested_) { - disconnect_requested_ = false; - connect_requested_ = false; - clean_sub_unsub_inflight(); on_close(); - return true; + disconnect_requested_ = false; + } + else { + if (!ec) ec = boost::system::errc::make_error_code(boost::system::errc::not_connected); + on_error(ec); } - disconnect_requested_ = false; - connect_requested_ = false; - if (!ec) ec = boost::system::errc::make_error_code(boost::system::errc::not_connected); - clean_sub_unsub_inflight_on_error(ec); return true; } @@ -5260,36 +5260,82 @@ class endpoint : public std::enable_shared_from_this - void shutdown(T& socket) { + void shutdown(MQTT_NS::socket& s) { MQTT_LOG("mqtt_impl", trace) << MQTT_ADD_VALUE(address, this) << "shutdown"; - if (shutdowned_) { + if (shutdown_requested_) { MQTT_LOG("mqtt_impl", trace) << MQTT_ADD_VALUE(address, this) << "already shutdowned"; return; } - shutdowned_ = true; - connected_ = false; + shutdown_requested_ = true; mqtt_connected_ = false; - - { - boost::system::error_code ec; - socket.lowest_layer().shutdown(as::ip::tcp::socket::shutdown_both, ec); + if (async_operation_) { MQTT_LOG("mqtt_impl", trace) << MQTT_ADD_VALUE(address, this) - << "socket shutdown ec:" - << ec.message(); + << "async_clean_shutdown_and_close"; + s.async_clean_shutdown_and_close( + [this, sp = this->shared_from_this(), ssp = socket_sp_ref()](error_code ec) { // *1 + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "async_clean_shutdown_and_close ec:" + << ec.message(); + tim_shutdown_.cancel(); + connected_ = false; + } + ); + // timeout timer set + tim_shutdown_.expires_after(shutdown_timeout); + std::weak_ptr wp(std::static_pointer_cast(this->shared_from_this())); + tim_shutdown_.async_wait( + [this, wp = force_move(wp), ssp = socket_sp_ref()](error_code ec) mutable { + if (auto sp = wp.lock()) { + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "async_shutdown timer ec:" + << ec.message(); + if (!ec) { + // timeout + // tcp_shutdown indirectly cancel stream.async_shutdown() + // and handler is called with error. + // So captured sp at *1 is released. + + // post is for applying strand + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "post force_shutdown_and_close"; + sp->socket().post( + [this, sp] { + if (connected_) { + error_code ec; + socket().force_shutdown_and_close(ec); + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "force_shutdown_and_close ec:" + << ec.message(); + connected_ = false; + } + } + ); + } + } + } + ); + return; } - { - boost::system::error_code ec; - socket.lowest_layer().close(ec); + else { + error_code ec; + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "clean_shutdown_and_close"; + s.clean_shutdown_and_close(ec); MQTT_LOG("mqtt_impl", trace) << MQTT_ADD_VALUE(address, this) - << "socket close ec:" + << "clean_shutdown_and_close ec:" << ec.message(); + connected_ = false; } } @@ -9483,7 +9529,7 @@ class endpoint : public std::enable_shared_from_this void do_sync_write(MessageVariant&& mv) { boost::system::error_code ec; - if (!connected_) return; - on_pre_send(); - total_bytes_sent_ += socket_->write(const_buffer_sequence(mv), ec); - // If ec is set as error, the error will be handled by async_read. - // If `handle_error(ec);` is called here, error_handler would be called twice. + if (can_send()) { + on_pre_send(); + total_bytes_sent_ += socket_->write(const_buffer_sequence(mv), ec); + // If ec is set as error, the error will be handled by async_read. + // If `handle_error(ec);` is called here, error_handler would be called twice. + } } // Non blocking (async) senders @@ -10366,7 +10413,7 @@ class endpoint : public std::enable_shared_from_thisconnected_) { - self_->connected_ = false; while (!self_->queue_.empty()) { // Handlers for outgoing packets need not be valid. if (auto&& h = self_->queue_.front().handler()) h(ec); @@ -11316,7 +11362,6 @@ class endpoint : public std::enable_shared_from_thisconnected_) { - self_->connected_ = false; while (!self_->queue_.empty()) { // Handlers for outgoing packets need not be valid. if(auto&& h = self_->queue_.front().handler()) h(ec); @@ -11325,7 +11370,6 @@ class endpoint : public std::enable_shared_from_thisconnected_ = false; while (!self_->queue_.empty()) { // Handlers for outgoing packets need not be valid. if(auto&& h = self_->queue_.front().handler()) h(ec); @@ -11407,15 +11451,17 @@ class endpoint : public std::enable_shared_from_thispost( [this, self = this->shared_from_this(), mv = force_move(mv), func = force_move(func)] () mutable { - if (!connected_) { + if (can_send()) { + queue_.emplace_back(force_move(mv), force_move(func)); + // Only need to start async writes if there was nothing in the queue before the above item. + if (queue_.size() > 1) return; + do_async_write(); + } + else { // offline async publish is successfully finished, because there's nothing to do. if (func) func(boost::system::errc::make_error_code(boost::system::errc::success)); return; } - queue_.emplace_back(force_move(mv), force_move(func)); - // Only need to start async writes if there was nothing in the queue before the above item. - if (queue_.size() > 1) return; - do_async_write(); } ); } @@ -11503,6 +11549,10 @@ class endpoint : public std::enable_shared_from_this get_topic_alias_from_prop(v5::property_variant const& prop) { optional val; v5::visit_prop( @@ -11548,7 +11598,7 @@ class endpoint : public std::enable_shared_from_this socket_; std::atomic connected_{false}; std::atomic mqtt_connected_{false}; - std::atomic shutdowned_{false}; + std::atomic shutdown_requested_{false}; std::array buf_; std::uint8_t fixed_header_; @@ -11583,6 +11633,8 @@ class endpoint : public std::enable_shared_from_this #include #include +#include +#include namespace MQTT_NS { @@ -82,7 +84,16 @@ class tcp_endpoint : public socket { return tcp_.native_handle(); } - MQTT_ALWAYS_INLINE void close(boost::system::error_code& ec) override final { + MQTT_ALWAYS_INLINE void clean_shutdown_and_close(boost::system::error_code& ec) override final { + shutdown_and_close_impl(tcp_, ec); + } + + MQTT_ALWAYS_INLINE void async_clean_shutdown_and_close(std::function handler) override final { + async_shutdown_and_close_impl(tcp_, force_move(handler)); + } + + MQTT_ALWAYS_INLINE void force_shutdown_and_close(boost::system::error_code& ec) override final { + tcp_.lowest_layer().shutdown(as::ip::tcp::socket::shutdown_both, ec); tcp_.lowest_layer().close(ec); } @@ -123,6 +134,56 @@ class tcp_endpoint : public socket { #endif // defined(MQTT_USE_TLS) +private: + void shutdown_and_close_impl(as::basic_socket& s, boost::system::error_code& ec) { + s.shutdown(as::ip::tcp::socket::shutdown_both, ec); + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "shutdown ec:" + << ec.message(); + s.close(ec); + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "close ec:" + << ec.message(); + } + + void async_shutdown_and_close_impl(as::basic_socket& s, std::function handler) { + post( + [this, &s, handler = force_move(handler)] () mutable { + error_code ec; + shutdown_and_close_impl(s, ec); + force_move(handler)(ec); + } + ); + } + +#if defined(MQTT_USE_TLS) + void shutdown_and_close_impl(tls::stream& s, boost::system::error_code& ec) { + s.shutdown(ec); + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "shutdown ec:" + << ec.message(); + shutdown_and_close_impl(lowest_layer(), ec); + } + void async_shutdown_and_close_impl(tls::stream& s, std::function handler) { + s.async_shutdown( + as::bind_executor( + strand_, + [this, &s, handler = force_move(handler)] (error_code ec) mutable { + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "shutdown ec:" + << ec.message(); + shutdown_and_close_impl(s.lowest_layer(), ec); + force_move(handler)(ec); + } + ) + ); + } +#endif // defined(MQTT_USE_TLS) + private: Socket tcp_; Strand strand_; diff --git a/include/mqtt/type_erased_socket.hpp b/include/mqtt/type_erased_socket.hpp index 008ce4c1f..3939ab156 100644 --- a/include/mqtt/type_erased_socket.hpp +++ b/include/mqtt/type_erased_socket.hpp @@ -28,7 +28,9 @@ class socket { virtual void post(std::function) = 0; virtual as::ip::tcp::socket::lowest_layer_type& lowest_layer() = 0; virtual any native_handle() = 0; - virtual void close(boost::system::error_code&) = 0; + virtual void clean_shutdown_and_close(boost::system::error_code&) = 0; + virtual void async_clean_shutdown_and_close(std::function) = 0; + virtual void force_shutdown_and_close(boost::system::error_code&) = 0; #if BOOST_VERSION < 107400 || defined(BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT) virtual as::executor get_executor() = 0; #else // BOOST_VERSION < 107400 || defined(BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT) diff --git a/include/mqtt/ws_endpoint.hpp b/include/mqtt/ws_endpoint.hpp index ff859f59b..dcc38f92d 100644 --- a/include/mqtt/ws_endpoint.hpp +++ b/include/mqtt/ws_endpoint.hpp @@ -17,6 +17,8 @@ #include #include #include +#include +#include namespace MQTT_NS { @@ -153,15 +155,68 @@ class ws_endpoint : public socket { return next_layer().native_handle(); } - MQTT_ALWAYS_INLINE void close(boost::system::error_code& ec) override final { - ws_.close(boost::beast::websocket::close_code::normal, ec); - if (ec) return; - do { - boost::beast::flat_buffer buffer; - ws_.read(buffer, ec); - } while (!ec); - if (ec != boost::beast::websocket::error::closed) return; - ec = boost::system::errc::make_error_code(boost::system::errc::success); + MQTT_ALWAYS_INLINE void clean_shutdown_and_close(boost::system::error_code& ec) override final { + if (ws_.is_open()) { + // WebSocket closing process + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "call beast close"; + ws_.close(boost::beast::websocket::close_code::normal, ec); + if (ec) return; + + do { + boost::beast::flat_buffer buffer; + ws_.read(buffer, ec); + } while (!ec); + + if (ec == boost::beast::websocket::error::closed) { + ec = boost::system::errc::make_error_code(boost::system::errc::success); + } + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "ws close ec:" + << ec.message(); + } + shutdown_and_close_impl(next_layer(), ec); + } + + MQTT_ALWAYS_INLINE void async_clean_shutdown_and_close(std::function handler) override final { + if (ws_.is_open()) { + // WebSocket closing process + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "call beast async_close"; + ws_.async_close( + boost::beast::websocket::close_code::normal, + as::bind_executor( + strand_, + [this, handler = force_move(handler)] + (error_code ec) mutable { + if (ec) { + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "ws async_close ec:" + << ec.message(); + async_shutdown_and_close_impl(next_layer(), force_move(handler)); + } + else { + async_read_until_closed(force_move(handler)); + } + } + ) + ); + } + else { + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "ws async_close already closed"; + async_shutdown_and_close_impl(next_layer(), force_move(handler)); + } + } + + MQTT_ALWAYS_INLINE void force_shutdown_and_close(boost::system::error_code& ec) override final { + lowest_layer().shutdown(as::ip::tcp::socket::shutdown_both, ec); + lowest_layer().close(ec); } #if BOOST_VERSION < 107400 || defined(BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT) @@ -215,6 +270,82 @@ class ws_endpoint : public socket { return as::buffer_size(buffers); } +private: + void async_read_until_closed(std::function handler) { + auto buffer = std::make_shared(); + ws_.async_read( + *buffer, + as::bind_executor( + strand_, + [this, handler = force_move(handler)] + (error_code ec, std::size_t) mutable { + if (ec) { + if (ec == boost::beast::websocket::error::closed) { + ec = boost::system::errc::make_error_code(boost::system::errc::success); + } + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "ws async_read ec:" + << ec.message(); + async_shutdown_and_close_impl(next_layer(), force_move(handler)); + } + else { + async_read_until_closed(force_move(handler)); + } + } + ) + ); + } + + void shutdown_and_close_impl(as::basic_socket& s, boost::system::error_code& ec) { + s.shutdown(as::ip::tcp::socket::shutdown_both, ec); + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "shutdown ec:" + << ec.message(); + s.close(ec); + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "close ec:" + << ec.message(); + } + + void async_shutdown_and_close_impl(as::basic_socket& s, std::function handler) { + post( + [this, &s, handler = force_move(handler)] () mutable { + error_code ec; + shutdown_and_close_impl(s, ec); + force_move(handler)(ec); + } + ); + } + +#if defined(MQTT_USE_TLS) + void shutdown_and_close_impl(tls::stream& s, boost::system::error_code& ec) { + s.shutdown(ec); + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "shutdown ec:" + << ec.message(); + shutdown_and_close_impl(lowest_layer(), ec); + } + void async_shutdown_and_close_impl(tls::stream& s, std::function handler) { + s.async_shutdown( + as::bind_executor( + strand_, + [this, handler = force_move(handler)] (error_code ec) mutable { + MQTT_LOG("mqtt_impl", trace) + << MQTT_ADD_VALUE(address, this) + << "shutdown ec:" + << ec.message(); + shutdown_and_close_impl(lowest_layer(), ec); + force_move(handler)(ec); + } + ) + ); + } +#endif // defined(MQTT_USE_TLS) + private: boost::beast::websocket::stream ws_; boost::beast::flat_buffer buffer_;