diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 2360ec9f46..e0fe0029a5 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -7,7 +7,7 @@ use crate::PingConfig; use futures_util::future::{self, Either, Fuse}; use futures_util::io::{BufReader, BufWriter}; -use futures_util::stream::{FuturesOrdered, FuturesUnordered}; +use futures_util::stream::FuturesOrdered; use futures_util::{Future, FutureExt, StreamExt}; use hyper::upgrade::Upgraded; use jsonrpsee_core::server::helpers::{ @@ -248,7 +248,11 @@ pub(crate) async fn background_task(sender: Sender, mut receiver: Rec let (conn_tx, conn_rx) = oneshot::channel(); let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length); let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection); - let pending_calls = FuturesUnordered::new(); + + // On each method call the `pending_calls` is cloned + // then when all pending_calls are dropped + // a graceful shutdown can has occur. + let (pending_calls, pending_calls_completed) = mpsc::channel::<()>(1); // Spawn another task that sends out the responses on the Websocket. let send_task_handle = tokio::spawn(send_task(rx, sender, ping_config.ping_interval(), conn_rx)); @@ -320,13 +324,14 @@ pub(crate) async fn background_task(sender: Sender, mut receiver: Rec } }; - pending_calls.push(tokio::spawn(execute_unchecked_call(params.clone(), std::mem::take(&mut data)))); + tokio::spawn(execute_unchecked_call(params.clone(), std::mem::take(&mut data), pending_calls.clone())); }; // Drive all running methods to completion. // **NOTE** Do not return early in this function. This `await` needs to run to guarantee // proper drop behaviour. - graceful_shutdown(result, pending_calls, receiver, data, conn_tx, send_task_handle).await; + drop(pending_calls); + graceful_shutdown(result, pending_calls_completed, receiver, data, conn_tx, send_task_handle).await; logger.on_disconnect(remote_addr, TransportProtocol::WebSocket); drop(conn); @@ -492,7 +497,11 @@ struct ExecuteCallParams { bounded_subscriptions: BoundedSubscriptions, } -async fn execute_unchecked_call(params: Arc>, data: Vec) { +async fn execute_unchecked_call( + params: Arc>, + data: Vec, + drop_on_completion: mpsc::Sender<()>, +) { let request_start = params.logger.on_request(TransportProtocol::WebSocket); let first_non_whitespace = data.iter().enumerate().take(128).find(|(_, byte)| !byte.is_ascii_whitespace()); @@ -550,6 +559,10 @@ async fn execute_unchecked_call(params: Arc>, da _ = params.sink.send_error(Id::Null, ErrorCode::ParseError.into()).await; } }; + + // NOTE: This channel is only used to indicate that a method call was completed + // thus the drop here tells the main task that method call was completed. + drop(drop_on_completion); } #[derive(Debug, Copy, Clone)] @@ -561,14 +574,16 @@ pub(crate) enum Shutdown { /// Enforce a graceful shutdown. /// /// This will return once the connection has been terminated or all pending calls have been executed. -async fn graceful_shutdown( +async fn graceful_shutdown( result: Result, - pending_calls: FuturesUnordered, + pending_calls: mpsc::Receiver<()>, receiver: Receiver, data: Vec, mut conn_tx: oneshot::Sender<()>, send_task_handle: tokio::task::JoinHandle<()>, ) { + let pending_calls = ReceiverStream::new(pending_calls); + match result { Ok(Shutdown::ConnectionClosed) | Err(SokettoError::Closed) => (), Ok(Shutdown::Stopped) | Err(_) => {