Skip to content

Commit

Permalink
proxy: report blame for passthrough disconnect io errors (#8170)
Browse files Browse the repository at this point in the history
## Problem

Hard to debug the disconnection reason currently.

## Summary of changes

Keep track of error-direction, and therefore error source (client vs
compute) during passthrough.
  • Loading branch information
conradludgate committed Jun 27, 2024
1 parent a126827 commit 87c65db
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 26 deletions.
9 changes: 6 additions & 3 deletions proxy/src/bin/pg_sni_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled};
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use rustls::pki_types::PrivateKeyDer;
use tokio::net::TcpListener;

Expand Down Expand Up @@ -286,7 +286,10 @@ async fn handle_client(

// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = copy_bidirectional_client_compute(&mut tls_stream, &mut client).await?;

Ok(())
match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
Ok(_) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}
}
8 changes: 6 additions & 2 deletions proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod passthrough;
pub mod retry;
pub mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;

use crate::{
auth,
Expand Down Expand Up @@ -148,8 +149,11 @@ pub async fn task_main(
ctx.log_connect();
match p.proxy_pass().instrument(span.clone()).await {
Ok(()) => {}
Err(e) => {
error!(parent: &span, "per-client task finished with an error: {e:#}");
Err(ErrorSource::Client(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
}
Err(ErrorSource::Compute(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
}
}
}
Expand Down
66 changes: 49 additions & 17 deletions proxy/src/proxy/copy_bidirectional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,39 @@ enum TransferState {
Done(u64),
}

#[derive(Debug)]
pub enum ErrorDirection {
Read(io::Error),
Write(io::Error),
}

impl ErrorSource {
fn from_client(err: ErrorDirection) -> ErrorSource {
match err {
ErrorDirection::Read(client) => Self::Client(client),
ErrorDirection::Write(compute) => Self::Compute(compute),
}
}
fn from_compute(err: ErrorDirection) -> ErrorSource {
match err {
ErrorDirection::Write(client) => Self::Client(client),
ErrorDirection::Read(compute) => Self::Compute(compute),
}
}
}

#[derive(Debug)]
pub enum ErrorSource {
Client(io::Error),
Compute(io::Error),
}

fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
) -> Poll<Result<u64, ErrorDirection>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
Expand All @@ -32,7 +59,7 @@ where
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;
ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?;
*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
Expand All @@ -44,7 +71,7 @@ where
pub async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,
) -> Result<(u64, u64), std::io::Error>
) -> Result<(u64, u64), ErrorSource>
where
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
Expand All @@ -54,9 +81,11 @@ where

poll_fn(|cx| {
let mut client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)?;
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
let mut compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)?;
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;

// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client {
Expand All @@ -65,18 +94,20 @@ where
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)?;
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
}
}

// Early termination checks from compute to client.
// Early termination checks from client to compute.
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, client, compute)?;
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
}
}

Expand Down Expand Up @@ -138,7 +169,7 @@ impl CopyBuffer {
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<usize>>
) -> Poll<Result<usize, ErrorDirection>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
Expand All @@ -149,11 +180,11 @@ impl CopyBuffer {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
}
Poll::Pending
}
res => res,
res => res.map_err(ErrorDirection::Write),
}
}

Expand All @@ -162,7 +193,7 @@ impl CopyBuffer {
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
) -> Poll<Result<u64, ErrorDirection>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
Expand All @@ -176,12 +207,13 @@ impl CopyBuffer {

match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?;
ready!(writer.as_mut().poll_flush(cx))
.map_err(ErrorDirection::Write)?;
self.need_flush = false;
}

Expand All @@ -194,10 +226,10 @@ impl CopyBuffer {
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
))));
} else {
self.pos += i;
self.amt += i as u64;
Expand All @@ -216,7 +248,7 @@ impl CopyBuffer {
// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?;
return Poll::Ready(Ok(self.amt));
}
}
Expand Down
10 changes: 7 additions & 3 deletions proxy/src/proxy/passthrough.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;

use super::copy_bidirectional::ErrorSource;

/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
) -> Result<(), ErrorSource> {
let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
Expand Down Expand Up @@ -66,9 +68,11 @@ pub struct ProxyPassthrough<P, S> {
}

impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub async fn proxy_pass(self) -> anyhow::Result<()> {
pub async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
self.compute.cancel_closure.try_cancel_query().await?;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
tracing::error!(?err, "could not cancel the query in the database");
}
res
}
}
8 changes: 7 additions & 1 deletion proxy/src/serverless/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::proxy::ErrorSource;
use crate::{
cancellation::CancellationHandlerMain,
config::ProxyConfig,
Expand All @@ -7,6 +8,7 @@ use crate::{
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
Expand Down Expand Up @@ -165,7 +167,11 @@ pub async fn serve_websocket(
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
p.proxy_pass().await
match p.proxy_pass().await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}
}
}
}
Expand Down

0 comments on commit 87c65db

Please sign in to comment.