Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: postgres_backend: replace abstract shutdown_watcher with CancellationToken #8295

Merged
merged 2 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion libs/postgres_backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rustls.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-util.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true

Expand All @@ -23,4 +24,4 @@ workspace_hack.workspace = true
once_cell.workspace = true
rustls-pemfile.workspace = true
tokio-postgres.workspace = true
tokio-postgres-rustls.workspace = true
tokio-postgres-rustls.workspace = true
33 changes: 12 additions & 21 deletions libs/postgres_backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::{fmt, io};
use std::{future::Future, str::FromStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, trace, warn};

use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
Expand Down Expand Up @@ -400,21 +401,15 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
}

/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
pub async fn run(
mut self,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S + Clone,
S: Future,
{
let ret = self
.run_message_loop(handler, shutdown_watcher.clone())
.await;
cancel: &CancellationToken,
) -> Result<(), QueryError> {
let ret = self.run_message_loop(handler, cancel).await;

tokio::select! {
_ = shutdown_watcher() => {
_ = cancel.cancelled() => {
// do nothing; we most likely got already stopped by shutdown and will log it next.
}
_ = self.framed.shutdown() => {
Expand Down Expand Up @@ -444,21 +439,17 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
}
}

async fn run_message_loop<F, S>(
async fn run_message_loop(
&mut self,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
cancel: &CancellationToken,
) -> Result<(), QueryError> {
trace!("postgres backend to {:?} started", self.peer_addr);

tokio::select!(
biased;

_ = shutdown_watcher() => {
_ = cancel.cancelled() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Err(QueryError::Shutdown)
Expand All @@ -473,7 +464,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
_ = cancel.cancelled() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
return Err(QueryError::Shutdown)
Expand All @@ -485,7 +476,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
let result = self.process_message(handler, msg, &mut query_string).await;
tokio::select!(
biased;
_ = shutdown_watcher() => {
_ = cancel.cancelled() => {
// We were requested to shut down.
tracing::info!("shutdown request received during response flush");

Expand Down
7 changes: 4 additions & 3 deletions libs/postgres_backend/tests/simple_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ use once_cell::sync::Lazy;
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor;
use std::{future, sync::Arc};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
use tokio_postgres_rustls::MakeRustlsConnect;
use tokio_util::sync::CancellationToken;

// generate client, server test streams
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
Expand Down Expand Up @@ -50,7 +51,7 @@ async fn simple_select() {

tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
pgbackend.run(&mut handler, &CancellationToken::new()).await
});

let conf = Config::new();
Expand Down Expand Up @@ -102,7 +103,7 @@ async fn simple_select_ssl() {

tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
pgbackend.run(&mut handler, &CancellationToken::new()).await
});

let client_cfg = rustls::ClientConfig::builder()
Expand Down
2 changes: 1 addition & 1 deletion pageserver/src/page_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ async fn page_service_conn_main(
let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?;

match pgbackend
.run(&mut conn_handler, task_mgr::shutdown_watcher)
.run(&mut conn_handler, &task_mgr::shutdown_token())
.await
{
Ok(()) => {
Expand Down
7 changes: 5 additions & 2 deletions proxy/src/console/mgmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::{convert::Infallible, future};
use std::convert::Infallible;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};

static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
Expand Down Expand Up @@ -67,7 +68,9 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {

async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
pgbackend.run(&mut MgmtHandler, future::pending::<()>).await
pgbackend
.run(&mut MgmtHandler, &CancellationToken::new())
.await
}

/// A message received by `mgmt` when a compute node is ready.
Expand Down
5 changes: 3 additions & 2 deletions safekeeper/src/wal_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
//!
use anyhow::{Context, Result};
use postgres_backend::QueryError;
use std::{future, time::Duration};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio_io_timeout::TimeoutReader;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::{auth::Scope, measured_stream::MeasuredStream};

Expand Down Expand Up @@ -100,7 +101,7 @@ async fn handle_socket(
// libpq protocol between safekeeper and walproposer / pageserver
// We don't use shutdown.
pgbackend
.run(&mut conn_handler, future::pending::<()>)
.run(&mut conn_handler, &CancellationToken::new())
.await
}

Expand Down
2 changes: 0 additions & 2 deletions workspace_hack/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@ crossbeam-utils = { version = "0.8" }
either = { version = "1" }
fail = { version = "0.5", default-features = false, features = ["failpoints"] }
futures-channel = { version = "0.3", features = ["sink"] }
futures-core = { version = "0.3" }
futures-executor = { version = "0.3" }
futures-io = { version = "0.3" }
futures-sink = { version = "0.3" }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
hashbrown = { version = "0.14", features = ["raw"] }
Expand Down
Loading