diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index e1674049a60ad..44e880838e077 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -10,7 +10,7 @@ use itertools::Itertools; use proxy::config::TlsServerEndPoint; use proxy::context::RequestMonitoring; use proxy::metrics::{Metrics, ThreadPoolMetrics}; -use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled}; +use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource}; use rustls::pki_types::PrivateKeyDer; use tokio::net::TcpListener; @@ -286,7 +286,10 @@ async fn handle_client( // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); - let _ = copy_bidirectional_client_compute(&mut tls_stream, &mut client).await?; - Ok(()) + match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await { + Ok(_) => Ok(()), + Err(ErrorSource::Client(err)) => Err(err).context("client"), + Err(ErrorSource::Compute(err)) => Err(err).context("compute"), + } } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 072f51958f48e..3edefcf21a49b 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -8,6 +8,7 @@ pub mod passthrough; pub mod retry; pub mod wake_compute; pub use copy_bidirectional::copy_bidirectional_client_compute; +pub use copy_bidirectional::ErrorSource; use crate::{ auth, @@ -148,8 +149,11 @@ pub async fn task_main( ctx.log_connect(); match p.proxy_pass().instrument(span.clone()).await { Ok(()) => {} - Err(e) => { - error!(parent: &span, "per-client task finished with an error: {e:#}"); + Err(ErrorSource::Client(e)) => { + error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); + } + Err(ErrorSource::Compute(e)) => { + error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}"); } } } diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index aaf3688f21a05..3c45fff969f87 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -13,12 +13,39 @@ enum TransferState { Done(u64), } +#[derive(Debug)] +pub enum ErrorDirection { + Read(io::Error), + Write(io::Error), +} + +impl ErrorSource { + fn from_client(err: ErrorDirection) -> ErrorSource { + match err { + ErrorDirection::Read(client) => Self::Client(client), + ErrorDirection::Write(compute) => Self::Compute(compute), + } + } + fn from_compute(err: ErrorDirection) -> ErrorSource { + match err { + ErrorDirection::Write(client) => Self::Client(client), + ErrorDirection::Read(compute) => Self::Compute(compute), + } + } +} + +#[derive(Debug)] +pub enum ErrorSource { + Client(io::Error), + Compute(io::Error), +} + fn transfer_one_direction( cx: &mut Context<'_>, state: &mut TransferState, r: &mut A, w: &mut B, -) -> Poll> +) -> Poll> where A: AsyncRead + AsyncWrite + Unpin + ?Sized, B: AsyncRead + AsyncWrite + Unpin + ?Sized, @@ -32,7 +59,7 @@ where *state = TransferState::ShuttingDown(count); } TransferState::ShuttingDown(count) => { - ready!(w.as_mut().poll_shutdown(cx))?; + ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?; *state = TransferState::Done(*count); } TransferState::Done(count) => return Poll::Ready(Ok(*count)), @@ -44,7 +71,7 @@ where pub async fn copy_bidirectional_client_compute( client: &mut Client, compute: &mut Compute, -) -> Result<(u64, u64), std::io::Error> +) -> Result<(u64, u64), ErrorSource> where Client: AsyncRead + AsyncWrite + Unpin + ?Sized, Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, @@ -54,9 +81,11 @@ where poll_fn(|cx| { let mut client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, client, compute)?; + transfer_one_direction(cx, &mut client_to_compute, client, compute) + .map_err(ErrorSource::from_client)?; let mut compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, compute, client)?; + transfer_one_direction(cx, &mut compute_to_client, compute, client) + .map_err(ErrorSource::from_compute)?; // Early termination checks from compute to client. if let TransferState::Done(_) = compute_to_client { @@ -65,18 +94,20 @@ where // Initiate shutdown client_to_compute = TransferState::ShuttingDown(buf.amt); client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, client, compute)?; + transfer_one_direction(cx, &mut client_to_compute, client, compute) + .map_err(ErrorSource::from_client)?; } } - // Early termination checks from compute to client. + // Early termination checks from client to compute. if let TransferState::Done(_) = client_to_compute { if let TransferState::Running(buf) = &compute_to_client { info!("Client is done, terminate compute"); // Initiate shutdown compute_to_client = TransferState::ShuttingDown(buf.amt); compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, client, compute)?; + transfer_one_direction(cx, &mut compute_to_client, compute, client) + .map_err(ErrorSource::from_compute)?; } } @@ -138,7 +169,7 @@ impl CopyBuffer { cx: &mut Context<'_>, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, - ) -> Poll> + ) -> Poll> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized, @@ -149,11 +180,11 @@ impl CopyBuffer { // Top up the buffer towards full if we can read a bit more // data - this should improve the chances of a large write if !me.read_done && me.cap < me.buf.len() { - ready!(me.poll_fill_buf(cx, reader.as_mut()))?; + ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?; } Poll::Pending } - res => res, + res => res.map_err(ErrorDirection::Write), } } @@ -162,7 +193,7 @@ impl CopyBuffer { cx: &mut Context<'_>, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, - ) -> Poll> + ) -> Poll> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized, @@ -176,12 +207,13 @@ impl CopyBuffer { match self.poll_fill_buf(cx, reader.as_mut()) { Poll::Ready(Ok(())) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))), Poll::Pending => { // Try flushing when the reader has no progress to avoid deadlock // when the reader depends on buffered writer. if self.need_flush { - ready!(writer.as_mut().poll_flush(cx))?; + ready!(writer.as_mut().poll_flush(cx)) + .map_err(ErrorDirection::Write)?; self.need_flush = false; } @@ -194,10 +226,10 @@ impl CopyBuffer { while self.pos < self.cap { let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; if i == 0 { - return Poll::Ready(Err(io::Error::new( + return Poll::Ready(Err(ErrorDirection::Write(io::Error::new( io::ErrorKind::WriteZero, "write zero byte into writer", - ))); + )))); } else { self.pos += i; self.amt += i as u64; @@ -216,7 +248,7 @@ impl CopyBuffer { // If we've written all the data and we've seen EOF, flush out the // data and finish the transfer. if self.pos == self.cap && self.read_done { - ready!(writer.as_mut().poll_flush(cx))?; + ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?; return Poll::Ready(Ok(self.amt)); } } diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 62de79946f3c4..9942fac383b77 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -10,13 +10,15 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; use utils::measured_stream::MeasuredStream; +use super::copy_bidirectional::ErrorSource; + /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(skip_all)] pub async fn proxy_pass( client: impl AsyncRead + AsyncWrite + Unpin, compute: impl AsyncRead + AsyncWrite + Unpin, aux: MetricsAuxInfo, -) -> anyhow::Result<()> { +) -> Result<(), ErrorSource> { let usage = USAGE_METRICS.register(Ids { endpoint_id: aux.endpoint_id, branch_id: aux.branch_id, @@ -66,9 +68,11 @@ pub struct ProxyPassthrough { } impl ProxyPassthrough { - pub async fn proxy_pass(self) -> anyhow::Result<()> { + pub async fn proxy_pass(self) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; - self.compute.cancel_closure.try_cancel_query().await?; + if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { + tracing::error!(?err, "could not cancel the query in the database"); + } res } } diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 0e9772733da48..0d5b88f07b0c8 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -1,3 +1,4 @@ +use crate::proxy::ErrorSource; use crate::{ cancellation::CancellationHandlerMain, config::ProxyConfig, @@ -7,6 +8,7 @@ use crate::{ proxy::{handle_client, ClientMode}, rate_limiter::EndpointRateLimiter, }; +use anyhow::Context as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use framed_websockets::{Frame, OpCode, WebSocketServer}; use futures::{Sink, Stream}; @@ -165,7 +167,11 @@ pub async fn serve_websocket( Ok(Some(p)) => { ctx.set_success(); ctx.log_connect(); - p.proxy_pass().await + match p.proxy_pass().await { + Ok(()) => Ok(()), + Err(ErrorSource::Client(err)) => Err(err).context("client"), + Err(ErrorSource::Compute(err)) => Err(err).context("compute"), + } } } }