From 2d27fb7a5aac6a205a9dedd56b7102a0f2221f5a Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 09:58:20 -0400 Subject: [PATCH] feat: well factored http connection handling --- Cargo.toml | 4 ++ src/lib.rs | 207 ++++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 184 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3ce4cb2..d3cec51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,10 @@ tokio = { version = "1.40.0", features = ["net"] } tokio-rustls = "0.26.0" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" +http-body = "1.0.1" +tokio-stream = "0.1.16" +bytes = "1.7.1" +pin-project = "1.1.5" [dev-dependencies] tokio = { version = "1.40.0", features = ["macros"] } diff --git a/src/lib.rs b/src/lib.rs index 69f9e7d..b49e7e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,37 +1,137 @@ -// Crate links: -// - tokio: https://docs.rs/tokio/latest/tokio/ -// - rustls: https://docs.rs/rustls/latest/rustls/ -// - tokio_rustls:https://docs.rs/tokio-rustls/latest/tokio_rustls/ -// - hyper: https://docs.rs/hyper/latest/hyper/ -// - tower: https://docs.rs/tower/latest/tower/ -// -// We take a `SocketAddr` to bind the server to a specific address. -// We use `TcpListener` to bind the server to the specified address at the TCP layer. -// We use `rustls` to create a new `ServerConfig` instance. -// We use `tokio_rustls` to create a new `TlsAcceptor` instance. -// We use `hyper_util::server::conn::auto` to create a new `Connection` instance. -// Which then passes requests to a `hyper::service::Service` instance. -// Which then can optionally pass requests to a `tower::Service` instance. -// Behind that can be axum, tower, tonic, -// or any other service that implements the `tower::Service` trait. - -use std::fs; -use std::io; -use std::net::SocketAddr; -use std::sync::Arc; - +use bytes::Bytes; use http::{Method, Request, Response, StatusCode}; +use http_body::Body; use http_body_util::BodyExt; use http_body_util::Full; -use hyper::body::{Bytes, Incoming}; +use hyper::{body::Incoming, service::Service as HyperService}; use hyper_util::rt::{TokioExecutor, TokioIo}; -use hyper_util::server::conn::auto::Builder; +use hyper_util::server::conn::auto::Builder as HttpConnBuilder; +use hyper_util::server::conn::auto::HttpServerConnExec; use hyper_util::service::TowerToHyperService; +use pin_project::pin_project; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::ServerConfig; +use std::error::Error as StdError; +use std::future::pending; +use std::net::SocketAddr; +use std::pin::{pin, Pin}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use std::{fmt, fs, future::Future, io}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::macros::support::poll_fn; use tokio::net::TcpListener; +use tokio::time::sleep; use tokio_rustls::TlsAcceptor; +use tokio_stream::Stream; +use tokio_stream::StreamExt as _; use tower::{Service, ServiceBuilder}; +use tracing::{debug, trace}; + +// From `futures-util` crate, borrowed since this is the only dependency hyper-server requires. +// LICENSE: MIT or Apache-2.0 +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +#[pin_project] +struct Fuse { + #[pin] + inner: Option, +} + +impl Future for Fuse +where + F: Future, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().inner.as_pin_mut() { + Some(fut) => fut.poll(cx).map(|output| { + self.project().inner.set(None); + output + }), + None => Poll::Pending, + } + } +} + +type Source = Box; + +/// Errors that originate from the client hyper-server; +pub struct Error { + inner: ErrorImpl, +} + +struct ErrorImpl { + kind: Kind, + source: Option, +} + +#[derive(Debug)] +pub(crate) enum Kind { + Transport, +} + +impl Error { + pub(crate) fn new(kind: Kind) -> Self { + Self { + inner: ErrorImpl { kind, source: None }, + } + } + + pub(crate) fn with(mut self, source: impl Into) -> Self { + self.inner.source = Some(source.into()); + self + } + + pub(crate) fn from_source( + source: impl Into + std::error::Error + std::marker::Send + std::marker::Sync + 'static, + ) -> Self { + Error::new(Kind::Transport).with(source) + } + + fn description(&self) -> &str { + match &self.inner.kind { + Kind::Transport => "transport error", + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_tuple("hyper_server::Error"); + + f.field(&self.inner.kind); + + if let Some(source) = &self.inner.source { + f.field(source); + } + + f.finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.description()) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.inner + .source + .as_ref() + .map(|source| &**source as &(dyn StdError + 'static)) + } +} + +async fn sleep_or_pending(wait_for: Option) { + match wait_for { + Some(wait) => sleep(wait).await, + None => pending().await, + }; +} #[derive(Debug, Clone)] pub struct Logger { @@ -43,6 +143,7 @@ impl Logger { } } type Req = Request; + impl Service for Logger where S: Service + Clone, @@ -114,9 +215,9 @@ async fn echo(req: Request) -> Result>, hyper::Er Ok(response) } -pub struct HyperServer {} +pub struct Server {} -impl HyperServer { +impl Server { pub async fn serve(&self) -> Result<(), Box> { // Get a random port from the OS let addr = SocketAddr::from(([127, 0, 0, 1], 0)); @@ -173,7 +274,7 @@ impl HyperServer { }; // Serve the http connection - if let Err(err) = Builder::new(TokioExecutor::new()) + if let Err(err) = HttpConnBuilder::new(TokioExecutor::new()) .serve_connection(TokioIo::new(tls_stream), service) .await { @@ -182,6 +283,58 @@ impl HyperServer { }); } } + + /// Serves a single HTTP connection from a hyper service backend + async fn serve_connection( + hyper_io: IO, + hyper_svc: S, + builder: HttpConnBuilder, + mut watcher: Option>, + max_connection_age: Option, + ) where + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into> + Send + Sync, + IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + S: HyperService, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send, + E: HttpServerConnExec + Send + Sync + 'static, + { + tokio::spawn(async move { + { + let mut sig = pin!(Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); + + let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc)); + + let sleep = sleep_or_pending(max_connection_age); + tokio::pin!(sleep); + + loop { + tokio::select! { + rv = &mut conn => { + if let Err(err) = rv { + debug!("failed serving connection: {:#}", err); + } + break; + }, + _ = &mut sleep => { + conn.as_mut().graceful_shutdown(); + sleep.set(sleep_or_pending(None)); + }, + _ = &mut sig => { + conn.as_mut().graceful_shutdown(); + } + } + } + } + + drop(watcher); + trace!("connection closed"); + }); + } } #[cfg(test)]