Skip to content

Commit

Permalink
proxy http cancellation safety (#7117)
Browse files Browse the repository at this point in the history
## Problem

hyper auto-cancels the request futures on connection close.
`sql_over_http::handle` is not 'drop cancel safe', so we need to do some
other work to make sure connections and queries are cancelled in the right way.

## Summary of changes

1. tokio::spawn the request handler to resolve the initial cancel-safety
issue
2. share a cancellation token, and cancel it when the request `Service`
is dropped.
3. Add a new log span to be able to track the HTTP connection lifecycle.
  • Loading branch information
conradludgate authored Mar 14, 2024
1 parent 69338e5 commit 3bd6551
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 42 deletions.
18 changes: 17 additions & 1 deletion proxy/src/protocol2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,14 @@ impl Accept for ProxyProtocolAccept {
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
tracing::info!(protocol = self.protocol, "accepted new TCP connection");

let conn_id = uuid::Uuid::new_v4();
let span = tracing::info_span!("http_conn", ?conn_id);
{
let _enter = span.enter();
tracing::info!("accepted new TCP connection");
}

let Some(conn) = conn else {
return Poll::Ready(None);
};
Expand All @@ -354,6 +361,7 @@ impl Accept for ProxyProtocolAccept {
.with_label_values(&[self.protocol])
.guard(),
)),
span,
})))
}
}
Expand All @@ -364,6 +372,14 @@ pin_project! {
pub inner: T,
pub connection_id: Uuid,
pub gauge: Mutex<Option<IntCounterPairGuard>>,
pub span: tracing::Span,
}

// Emits a log event when the connection guard is dropped. Plain `//` comments
// are used here on purpose: this impl lives inside the `pin_project!` macro
// invocation, so only comment forms the lexer discards are safe.
impl<S> PinnedDrop for WithConnectionGuard<S> {
    fn drop(this: Pin<&mut Self>) {
        // Record the event inside the span stored on the guard; the span is
        // exited again as soon as the closure returns.
        this.span.in_scope(|| tracing::info!("HTTP connection closed"));
    }
}
}

Expand Down
76 changes: 54 additions & 22 deletions proxy/src/serverless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;
use tracing::instrument::Instrumented;

use crate::context::RequestMonitoring;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
Expand All @@ -30,13 +31,12 @@ use hyper::{
Body, Method, Request, Response,
};

use std::convert::Infallible;
use std::net::IpAddr;
use std::sync::Arc;
use std::task::Poll;
use tls_listener::TlsListener;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::{error, info, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};

Expand Down Expand Up @@ -100,12 +100,7 @@ pub async fn task_main(
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete`

let tls_listener = TlsListener::new(
tls_acceptor,
addr_incoming,
"http",
config.handshake_timeout,
);
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming, config.handshake_timeout);

let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<
Expand All @@ -121,6 +116,11 @@ pub async fn task_main(
.take()
.expect("gauge should be set on connection start");

// Cancel all current inflight HTTP requests if the HTTP connection is closed.
let http_cancellation_token = CancellationToken::new();
let cancel_connection = http_cancellation_token.clone().drop_guard();

let span = conn.span.clone();
let client_addr = conn.inner.client_addr();
let remote_addr = conn.inner.inner.remote_addr();
let backend = backend.clone();
Expand All @@ -136,27 +136,43 @@ pub async fn task_main(
Ok(MetricService::new(
hyper::service::service_fn(move |req: Request<Body>| {
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let ws_connections2 = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();

async move {
Ok::<_, Infallible>(
request_handler(
let http_cancellation_token = http_cancellation_token.child_token();

// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
ws_connections.spawn(
async move {
// Cancel the current inflight HTTP request if the request stream is closed.
// This is slightly different to `cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let _cancel_session = http_cancellation_token.clone().drop_guard();

let res = request_handler(
req,
config,
backend,
ws_connections,
ws_connections2,
cancellation_handler,
peer_addr.ip(),
endpoint_rate_limiter,
http_cancellation_token,
)
.await
.map_or_else(|e| e.into_response(), |r| r),
)
}
.map_or_else(|e| e.into_response(), |r| r);

_cancel_session.disarm();

res
}
.in_current_span(),
)
}),
gauge,
cancel_connection,
span,
))
}
},
Expand All @@ -176,11 +192,23 @@ pub async fn task_main(
/// Per-connection wrapper around the inner hyper service.
///
/// Besides forwarding requests to `inner`, it owns values whose lifetime is
/// tied to the connection's service: they are released when this wrapper is
/// dropped.
struct MetricService<S> {
/// The wrapped service that actually handles requests.
inner: S,
/// Connection gauge guard taken from the accepted connection; held only so
/// that it is dropped together with this service.
_gauge: IntCounterPairGuard,
/// Drop guard for the connection-level cancellation token. Dropping the
/// service cancels the token, which is the parent of every per-request
/// child token.
_cancel: DropGuard,
/// Span of the HTTP connection; `call` runs the inner service inside it and
/// instruments the returned future with a clone of it.
span: tracing::Span,
}

impl<S> MetricService<S> {
fn new(inner: S, _gauge: IntCounterPairGuard) -> MetricService<S> {
MetricService { inner, _gauge }
fn new(
inner: S,
_gauge: IntCounterPairGuard,
_cancel: DropGuard,
span: tracing::Span,
) -> MetricService<S> {
MetricService {
inner,
_gauge,
_cancel,
span,
}
}
}

Expand All @@ -190,14 +218,16 @@ where
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
type Future = Instrumented<S::Future>;

fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
self.inner.call(req)
self.span
.in_scope(|| self.inner.call(req))
.instrument(self.span.clone())
}
}

Expand All @@ -210,6 +240,8 @@ async fn request_handler(
cancellation_handler: Arc<CancellationHandler>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let session_id = uuid::Uuid::new_v4();

Expand Down Expand Up @@ -253,7 +285,7 @@ async fn request_handler(
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
let span = ctx.span.clone();

sql_over_http::handle(config, ctx, request, backend)
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
Expand Down
2 changes: 1 addition & 1 deletion proxy/src/serverless/sql_over_http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ pub async fn handle(
mut ctx: RequestMonitoring,
request: Request<Body>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let cancel = CancellationToken::new();
let cancel2 = cancel.clone();
let handle = tokio::spawn(async move {
time::sleep(config.http_config.request_timeout).await;
Expand Down
29 changes: 11 additions & 18 deletions proxy/src/serverless/tls_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use tokio::{
time::timeout,
};
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tracing::{info, warn};
use tracing::{info, warn, Instrument};

use crate::{
metrics::TLS_HANDSHAKE_FAILURES,
Expand All @@ -29,24 +29,17 @@ pin_project! {
tls: TlsAcceptor,
waiting: JoinSet<Option<TlsStream<A::Conn>>>,
timeout: Duration,
protocol: &'static str,
}
}

impl<A: Accept> TlsListener<A> {
/// Create a `TlsListener` with default options.
pub(crate) fn new(
tls: TlsAcceptor,
listener: A,
protocol: &'static str,
timeout: Duration,
) -> Self {
pub(crate) fn new(tls: TlsAcceptor, listener: A, timeout: Duration) -> Self {
TlsListener {
listener,
tls,
waiting: JoinSet::new(),
timeout,
protocol,
}
}
}
Expand All @@ -73,7 +66,7 @@ where
Poll::Ready(Some(Ok(mut conn))) => {
let t = *this.timeout;
let tls = this.tls.clone();
let protocol = *this.protocol;
let span = conn.span.clone();
this.waiting.spawn(async move {
let peer_addr = match conn.inner.wait_for_addr().await {
Ok(Some(addr)) => addr,
Expand All @@ -86,21 +79,24 @@ where

let accept = tls.accept(conn);
match timeout(t, accept).await {
Ok(Ok(conn)) => Some(conn),
Ok(Ok(conn)) => {
info!(%peer_addr, "accepted new TLS connection");
Some(conn)
},
// The handshake failed, try getting another connection from the queue
Ok(Err(e)) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, protocol, "failed to accept TLS connection: {e:?}");
warn!(%peer_addr, "failed to accept TLS connection: {e:?}");
None
}
// The handshake timed out, try getting another connection from the queue
Err(_) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, protocol, "failed to accept TLS connection: timeout");
warn!(%peer_addr, "failed to accept TLS connection: timeout");
None
}
}
});
}.instrument(span));
}
Poll::Ready(Some(Err(e))) => {
tracing::error!("error accepting TCP connection: {e}");
Expand All @@ -112,10 +108,7 @@ where

loop {
return match this.waiting.poll_join_next(cx) {
Poll::Ready(Some(Ok(Some(conn)))) => {
info!(protocol = this.protocol, "accepted new TLS connection");
Poll::Ready(Some(Ok(conn)))
}
Poll::Ready(Some(Ok(Some(conn)))) => Poll::Ready(Some(Ok(conn))),
// The handshake failed to complete, try getting another connection from the queue
Poll::Ready(Some(Ok(None))) => continue,
// The handshake panicked or was cancelled. ignore and get another connection
Expand Down
2 changes: 2 additions & 0 deletions test_runner/fixtures/neon_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2944,6 +2944,7 @@ def http_query(self, query, args, **kwargs):
user = quote(kwargs["user"])
password = quote(kwargs["password"])
expected_code = kwargs.get("expected_code")
timeout = kwargs.get("timeout")

log.info(f"Executing http query: {query}")

Expand All @@ -2957,6 +2958,7 @@ def http_query(self, query, args, **kwargs):
"Neon-Pool-Opt-In": "true",
},
verify=str(self.test_output_dir / "proxy.crt"),
timeout=timeout,
)

if expected_code is not None:
Expand Down
36 changes: 36 additions & 0 deletions test_runner/regress/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,39 @@ def test_sql_over_http_timeout_cancel(static_proxy: NeonProxy):
assert (
"duplicate key value violates unique constraint" in res["message"]
), "HTTP query should conflict"


def test_sql_over_http_connection_cancel(static_proxy: NeonProxy):
    """Check that abandoning an HTTP request prevents its query from taking effect.

    A slow insert of id 1 is started and the client gives up after 2 seconds.
    After waiting long enough for that query to have finished, inserting id 1
    again must succeed, and a further insert of the same id must then fail
    with a unique-constraint violation.
    """
    static_proxy.safe_psql("create role http with login password 'http' superuser")
    static_proxy.safe_psql("create table test_table ( id int primary key )")

    # Sleeps for $1 seconds, then inserts $2 into a table with a primary key.
    insert_after_sleep = (
        "WITH temp AS ( "
        "SELECT pg_sleep($1) as sleep, $2::int as id "
        ") INSERT INTO test_table (id) SELECT id FROM temp"
    )
    credentials = {"user": "http", "password": "http"}

    # The query itself would finish just before the proxy's own HTTP timeout,
    # but the client-side timeout of 2 seconds fires first and drops the request.
    slow_args = [static_proxy.http_timeout_seconds - 1, 1]
    try:
        static_proxy.http_query(insert_after_sleep, slow_args, timeout=2, **credentials)
    except requests.exceptions.ReadTimeout:
        pass

    # Wait until the abandoned query _would_ have completed.
    time.sleep(static_proxy.http_timeout_seconds)

    # The same id can still be inserted, so the abandoned insert never landed.
    inserted = static_proxy.http_query(
        insert_after_sleep, [1, 1], expected_code=200, **credentials
    )
    assert inserted["command"] == "INSERT", "HTTP query should insert"
    assert inserted["rowCount"] == 1, "HTTP query should insert"

    # A second insert of that id now hits the primary key.
    conflict = static_proxy.http_query(
        insert_after_sleep, [0, 1], expected_code=400, **credentials
    )
    assert (
        "duplicate key value violates unique constraint" in conflict["message"]
    ), "HTTP query should conflict"

1 comment on commit 3bd6551

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2606 tests run: 2467 passed, 1 failed, 138 skipped (full report)


Failures on Postgres 14

  • test_bulk_insert[neon-github-actions-selfhosted]: release
# Run all failed tests locally:
scripts/pytest -vv -n $(nproc) -k "test_bulk_insert[neon-release-pg14-github-actions-selfhosted]"

Code coverage* (full report)

  • functions: 28.7% (7031 of 24503 functions)
  • lines: 47.5% (43457 of 91583 lines)

* collected from Rust tests only


The comment gets automatically updated with the latest test results
3bd6551 at 2024-03-14T09:16:32.779Z :recycle:

Please sign in to comment.