diff --git a/Cargo.lock b/Cargo.lock index 6dae8e340348..d8d412eec997 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4077,6 +4077,7 @@ dependencies = [ "tokio-postgres", "tokio-postgres-rustls", "tokio-rustls 0.25.0", + "tokio-util", "tracing", "workspace_hack", ] @@ -7430,10 +7431,8 @@ dependencies = [ "either", "fail", "futures-channel", - "futures-core", "futures-executor", "futures-io", - "futures-sink", "futures-util", "getrandom 0.2.11", "hashbrown 0.14.5", diff --git a/libs/postgres_backend/Cargo.toml b/libs/postgres_backend/Cargo.toml index 8e249c09f7e0..c7611b9f213d 100644 --- a/libs/postgres_backend/Cargo.toml +++ b/libs/postgres_backend/Cargo.toml @@ -13,6 +13,7 @@ rustls.workspace = true serde.workspace = true thiserror.workspace = true tokio.workspace = true +tokio-util.workspace = true tokio-rustls.workspace = true tracing.workspace = true @@ -23,4 +24,4 @@ workspace_hack.workspace = true once_cell.workspace = true rustls-pemfile.workspace = true tokio-postgres.workspace = true -tokio-postgres-rustls.workspace = true \ No newline at end of file +tokio-postgres-rustls.workspace = true diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs index 6c41b7f347a9..c79ee4e0533a 100644 --- a/libs/postgres_backend/src/lib.rs +++ b/libs/postgres_backend/src/lib.rs @@ -16,6 +16,7 @@ use std::{fmt, io}; use std::{future::Future, str::FromStr}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, trace, warn}; use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter}; @@ -400,21 +401,15 @@ impl PostgresBackend { } /// Wrapper for run_message_loop() that shuts down socket when we are done - pub async fn run( + pub async fn run( mut self, handler: &mut impl Handler, - shutdown_watcher: F, - ) -> Result<(), QueryError> - where - F: Fn() -> S + Clone, - S: Future, - { - let ret = self - .run_message_loop(handler, shutdown_watcher.clone()) - .await; + cancel: &CancellationToken, + ) -> Result<(), QueryError> { + let ret = self.run_message_loop(handler, cancel).await; tokio::select! { - _ = shutdown_watcher() => { + _ = cancel.cancelled() => { // do nothing; we most likely got already stopped by shutdown and will log it next. } _ = self.framed.shutdown() => { @@ -444,21 +439,17 @@ impl PostgresBackend { } } - async fn run_message_loop( + async fn run_message_loop( &mut self, handler: &mut impl Handler, - shutdown_watcher: F, - ) -> Result<(), QueryError> - where - F: Fn() -> S, - S: Future, - { + cancel: &CancellationToken, + ) -> Result<(), QueryError> { trace!("postgres backend to {:?} started", self.peer_addr); tokio::select!( biased; - _ = shutdown_watcher() => { + _ = cancel.cancelled() => { // We were requested to shut down. tracing::info!("shutdown request received during handshake"); return Err(QueryError::Shutdown) @@ -473,7 +464,7 @@ impl PostgresBackend { let mut query_string = Bytes::new(); while let Some(msg) = tokio::select!( biased; - _ = shutdown_watcher() => { + _ = cancel.cancelled() => { // We were requested to shut down. tracing::info!("shutdown request received in run_message_loop"); return Err(QueryError::Shutdown) @@ -485,7 +476,7 @@ impl PostgresBackend { let result = self.process_message(handler, msg, &mut query_string).await; tokio::select!( biased; - _ = shutdown_watcher() => { + _ = cancel.cancelled() => { // We were requested to shut down. tracing::info!("shutdown request received during response flush"); diff --git a/libs/postgres_backend/tests/simple_select.rs b/libs/postgres_backend/tests/simple_select.rs index 80df9db858a9..7ec85f0dbe90 100644 --- a/libs/postgres_backend/tests/simple_select.rs +++ b/libs/postgres_backend/tests/simple_select.rs @@ -3,13 +3,14 @@ use once_cell::sync::Lazy; use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError}; use pq_proto::{BeMessage, RowDescriptor}; use std::io::Cursor; -use std::{future, sync::Arc}; +use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; use tokio_postgres::config::SslMode; use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::{Config, NoTls, SimpleQueryMessage}; use tokio_postgres_rustls::MakeRustlsConnect; +use tokio_util::sync::CancellationToken; // generate client, server test streams async fn make_tcp_pair() -> (TcpStream, TcpStream) { @@ -50,7 +51,7 @@ async fn simple_select() { tokio::spawn(async move { let mut handler = TestHandler {}; - pgbackend.run(&mut handler, future::pending::<()>).await + pgbackend.run(&mut handler, &CancellationToken::new()).await }); let conf = Config::new(); @@ -102,7 +103,7 @@ async fn simple_select_ssl() { tokio::spawn(async move { let mut handler = TestHandler {}; - pgbackend.run(&mut handler, future::pending::<()>).await + pgbackend.run(&mut handler, &CancellationToken::new()).await }); let client_cfg = rustls::ClientConfig::builder() diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index a440ad63785b..d6c5cf24f1a8 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -272,7 +272,7 @@ async fn page_service_conn_main( let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?; match pgbackend - .run(&mut conn_handler, task_mgr::shutdown_watcher) + .run(&mut conn_handler, &task_mgr::shutdown_token()) .await { Ok(()) => { diff --git a/proxy/src/console/mgmt.rs b/proxy/src/console/mgmt.rs index c7a2d467c016..befe7d75104b 100644 --- a/proxy/src/console/mgmt.rs +++ b/proxy/src/console/mgmt.rs @@ -6,8 +6,9 @@ use anyhow::Context; use once_cell::sync::Lazy; use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError}; use pq_proto::{BeMessage, SINGLE_COL_ROWDESC}; -use std::{convert::Infallible, future}; +use std::convert::Infallible; use tokio::net::{TcpListener, TcpStream}; +use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, Instrument}; static CPLANE_WAITERS: Lazy> = Lazy::new(Default::default); @@ -67,7 +68,9 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result { async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?; - pgbackend.run(&mut MgmtHandler, future::pending::<()>).await + pgbackend + .run(&mut MgmtHandler, &CancellationToken::new()) + .await } /// A message received by `mgmt` when a compute node is ready. diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs index 4a97eb3993f3..091571111e5c 100644 --- a/safekeeper/src/wal_service.rs +++ b/safekeeper/src/wal_service.rs @@ -4,9 +4,10 @@ //! use anyhow::{Context, Result}; use postgres_backend::QueryError; -use std::{future, time::Duration}; +use std::time::Duration; use tokio::net::TcpStream; use tokio_io_timeout::TimeoutReader; +use tokio_util::sync::CancellationToken; use tracing::*; use utils::{auth::Scope, measured_stream::MeasuredStream}; @@ -100,7 +101,7 @@ async fn handle_socket( // libpq protocol between safekeeper and walproposer / pageserver // We don't use shutdown. pgbackend - .run(&mut conn_handler, future::pending::<()>) + .run(&mut conn_handler, &CancellationToken::new()) .await } diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index f43076171f21..93223f080ade 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -33,10 +33,8 @@ crossbeam-utils = { version = "0.8" } either = { version = "1" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } futures-channel = { version = "0.3", features = ["sink"] } -futures-core = { version = "0.3" } futures-executor = { version = "0.3" } futures-io = { version = "0.3" } -futures-sink = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } hashbrown = { version = "0.14", features = ["raw"] }