diff --git a/Cargo.toml b/Cargo.toml index d3cec51..3540a4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,9 +23,12 @@ tokio-rustls = "0.26.0" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" http-body = "1.0.1" -tokio-stream = "0.1.16" +tokio-stream = { version = "0.1.16", features = ["net"] } bytes = "1.7.1" pin-project = "1.1.5" +async-stream = "0.3.5" +futures = "0.3.30" [dev-dependencies] -tokio = { version = "1.40.0", features = ["macros"] } +tokio = { version = "1.0", features = ["rt", "net", "test-util", "macros"] } +tokio-test = "0.4.4" diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..2c38bb4 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,70 @@ +use std::{error::Error as StdError, fmt}; + +type Source = Box; + +/// Errors that originate from the server; +pub struct Error { + inner: ErrorImpl, +} + +struct ErrorImpl { + kind: Kind, + source: Option, +} + +#[derive(Debug)] +pub(crate) enum Kind { + Transport, +} + +impl Error { + pub(crate) fn new(kind: Kind) -> Self { + Self { + inner: ErrorImpl { kind, source: None }, + } + } + + pub(crate) fn with(mut self, source: impl Into) -> Self { + self.inner.source = Some(source.into()); + self + } + + pub(crate) fn from_source(source: impl Into) -> Self { + Error::new(Kind::Transport).with(source) + } + + fn description(&self) -> &str { + match &self.inner.kind { + Kind::Transport => "transport error", + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_tuple("tonic::transport::Error"); + + f.field(&self.inner.kind); + + if let Some(source) = &self.inner.source { + f.field(source); + } + + f.finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.description()) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.inner + .source + .as_ref() + .map(|source| &**source as &(dyn StdError + 'static)) + } +} diff --git a/src/fuse.rs b/src/fuse.rs new file mode 100644 index 0000000..f75bb02 --- /dev/null +++ b/src/fuse.rs @@ -0,0 +1,30 @@ +use pin_project::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +// From `futures-util` crate, borrowed since this is the only dependency hyper-server requires. +// LICENSE: MIT or Apache-2.0 +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +#[pin_project] +pub(crate) struct Fuse { + #[pin] + pub(crate) inner: Option, +} + +impl Future for Fuse +where + F: Future, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().inner.as_pin_mut() { + Some(fut) => fut.poll(cx).map(|output| { + self.project().inner.set(None); + output + }), + None => Poll::Pending, + } + } +} diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..959c505 --- /dev/null +++ b/src/http.rs @@ -0,0 +1,100 @@ +use http::{Request, Response}; +use http_body::Body; +use hyper::body::Incoming; +use hyper::service::Service; +use hyper_util::server::conn::auto::{Builder, HttpServerConnExec}; +use std::future::pending; +use std::pin::pin; +use std::time::Duration; +use tokio::time::sleep; +use tracing::{debug, trace}; + +async fn sleep_or_pending(wait_for: Option) { + match wait_for { + Some(wait) => sleep(wait).await, + None => pending().await, + }; +} + +/// Serves a single HTTP connection from a hyper service backend. +/// +/// This method handles an individual HTTP connection, processing requests through +/// the provided service and managing the connection lifecycle. +/// +/// # Type Parameters +/// +/// * `B`: The body type for the HTTP response. +/// * `IO`: The I/O type for the HTTP connection. +/// * `S`: The service type that processes HTTP requests. +/// * `E`: The executor type for the HTTP server connection. +/// +/// # Parameters +/// +/// * `hyper_io`: The I/O object representing the inbound hyper IO stream. +/// * `hyper_svc`: The hyper `Service` implementation used to process HTTP requests. +/// * `builder`: An `HttpConnBuilder` used to create and serve the HTTP connection. +/// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. +/// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection +/// before initiating a graceful shutdown. +async fn serve_http_connection( + hyper_io: IO, + hyper_service: S, + builder: Builder, + mut watcher: Option>, + max_connection_age: Option, +) where + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into> + Send + Sync, + IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send, + E: HttpServerConnExec + Send + Sync + 'static, +{ + // Spawn a new asynchronous task to handle the incoming hyper IO stream + tokio::spawn(async move { + { + // Set up a fused future for the watcher + let mut sig = pin!(crate::fuse::Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); + + // Create and pin the HTTP connection + let mut conn = pin!(builder.serve_connection(hyper_io, hyper_service)); + + // Set up the sleep future for max connection age + let sleep = sleep_or_pending(max_connection_age); + tokio::pin!(sleep); + + // Main loop for serving the HTTP connection + loop { + tokio::select! { + // Handle the connection result + rv = &mut conn => { + if let Err(err) = rv { + // Log any errors that occur while serving the HTTP connection + debug!("failed serving HTTP connection: {:#}", err); + } + break; + }, + // Handle max connection age timeout + _ = &mut sleep => { + // Initiate a graceful shutdown when max connection age is reached + conn.as_mut().graceful_shutdown(); + sleep.set(sleep_or_pending(None)); + }, + // Handle graceful shutdown signal + _ = &mut sig => { + // Initiate a graceful shutdown when signal is received + conn.as_mut().graceful_shutdown(); + } + } + } + } + + // Clean up and log connection closure + drop(watcher); + trace!("HTTP connection closed"); + }); +} diff --git a/src/lib.rs b/src/lib.rs index 6abb287..9a714c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,375 +1,7 @@ -use bytes::Bytes; -use http::{Method, Request, Response, StatusCode}; -use http_body::Body; -use http_body_util::BodyExt; -use http_body_util::Full; -use hyper::{body::Incoming, service::Service as HyperService}; -use hyper_util::rt::{TokioExecutor, TokioIo}; -use hyper_util::server::conn::auto::Builder as HttpConnBuilder; -use hyper_util::server::conn::auto::HttpServerConnExec; -use hyper_util::service::TowerToHyperService; -use pin_project::pin_project; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; -use rustls::ServerConfig; -use std::error::Error as StdError; -use std::future::pending; -use std::net::SocketAddr; -use std::pin::{pin, Pin}; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::time::Duration; -use std::{fmt, fs, future::Future, io}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::macros::support::poll_fn; -use tokio::net::TcpListener; -use tokio::time::sleep; -use tokio_rustls::TlsAcceptor; -use tokio_stream::Stream; -use tokio_stream::StreamExt as _; -use tower::{Service, ServiceBuilder}; -use tracing::{debug, trace}; +mod error; +mod fuse; +mod http; +mod tcp; +mod tls; -// From `futures-util` crate, borrowed since this is the only dependency hyper-server requires. -// LICENSE: MIT or Apache-2.0 -// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. -#[pin_project] -struct Fuse { - #[pin] - inner: Option, -} - -impl Future for Fuse -where - F: Future, -{ - type Output = F::Output; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().project().inner.as_pin_mut() { - Some(fut) => fut.poll(cx).map(|output| { - self.project().inner.set(None); - output - }), - None => Poll::Pending, - } - } -} - -type Source = Box; - -/// Errors that originate from the client hyper-server; -pub struct Error { - inner: ErrorImpl, -} - -struct ErrorImpl { - kind: Kind, - source: Option, -} - -#[derive(Debug)] -pub(crate) enum Kind { - Transport, -} - -impl Error { - pub(crate) fn new(kind: Kind) -> Self { - Self { - inner: ErrorImpl { kind, source: None }, - } - } - - pub(crate) fn with(mut self, source: impl Into) -> Self { - self.inner.source = Some(source.into()); - self - } - - pub(crate) fn from_source( - source: impl Into + std::error::Error + std::marker::Send + std::marker::Sync + 'static, - ) -> Self { - Error::new(Kind::Transport).with(source) - } - - fn description(&self) -> &str { - match &self.inner.kind { - Kind::Transport => "transport error", - } - } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut f = f.debug_tuple("hyper_server::Error"); - - f.field(&self.inner.kind); - - if let Some(source) = &self.inner.source { - f.field(source); - } - - f.finish() - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.description()) - } -} - -impl StdError for Error { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - self.inner - .source - .as_ref() - .map(|source| &**source as &(dyn StdError + 'static)) - } -} - -async fn sleep_or_pending(wait_for: Option) { - match wait_for { - Some(wait) => sleep(wait).await, - None => pending().await, - }; -} - -#[derive(Debug, Clone)] -pub struct Logger { - inner: S, -} -impl Logger { - pub fn new(inner: S) -> Self { - Logger { inner } - } -} -type Req = Request; - -impl Service for Logger -where - S: Service + Clone, -{ - type Response = S::Response; - - type Error = S::Error; - - type Future = S::Future; - - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, req: Req) -> Self::Future { - println!("processing request: {} {}", req.method(), req.uri().path()); - self.inner.call(req) - } -} - -// Wrapped error type for the server. -fn error(err: String) -> io::Error { - io::Error::new(io::ErrorKind::Other, err) -} - -// Load the public certificate from a file. -fn load_certs(filename: &str) -> io::Result>> { - // Open certificate file. - let certfile = fs::File::open(filename) - .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; - let mut reader = io::BufReader::new(certfile); - - // Load and return certificate. - rustls_pemfile::certs(&mut reader).collect() -} - -// Load the private key from a file. -fn load_private_key(filename: &str) -> io::Result> { - // Open keyfile. - let keyfile = fs::File::open(filename) - .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; - let mut reader = io::BufReader::new(keyfile); - - // Load and return a single private key. - rustls_pemfile::private_key(&mut reader).map(|key| key.unwrap()) -} - -// Custom echo service, handling two different routes and a -// catch-all 404/not-found responder. -async fn echo(req: Request) -> Result>, hyper::Error> { - let mut response = Response::new(Full::default()); - match (req.method(), req.uri().path()) { - // Help route. - (&Method::GET, "/") => { - *response.body_mut() = Full::from("Try POST /echo\n"); - } - // Echo service route. - (&Method::POST, "/echo") => { - *response.body_mut() = Full::from(req.into_body().collect().await?.to_bytes()); - } - // Catch-all 404. - _ => { - *response.status_mut() = StatusCode::NOT_FOUND; - } - }; - Ok(response) -} - -pub struct Server {} - -impl Server { - pub async fn serve(&self) -> Result<(), Box> { - // Get a random port from the OS - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - // Create a TCP listener bound to the random address - let incoming = TcpListener::bind(&addr).await?; - - // Load public certificate. - let certs = load_certs("examples/sample.pem")?; - - // Load private key. - let key = load_private_key("examples/sample.rsa")?; - - // Build TLS configuration. - let mut server_config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, key) - .unwrap(); - - // Enable ALPN with HTTP/2 and HTTP/1.1 support. - server_config.alpn_protocols = - vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; - - // Create a rustls TlsAcceptor - let tls_acceptor = TlsAcceptor::from(Arc::new(server_config)); - - // Create a tower service - let service = tower::service_fn(echo); - let service = ServiceBuilder::new().layer_fn(Logger::new).service(service); - - // Convert it to a hyper service - let service = TowerToHyperService::new(service); - - // Begin the server loop - loop { - // Wait for an incoming tcp stream - let (tcp_stream, _remote_addr) = incoming.accept().await.unwrap(); - - // Clone a new instance of the tls_acceptor - let tls_acceptor = tls_acceptor.clone(); - - // Clone a new instance of the service - let service = service.clone(); - - // Spawn a new async task to handle the incoming connection - tokio::spawn(async move { - // Perform the TLS handshake - let tls_stream = match tls_acceptor.accept(tcp_stream).await { - Ok(tls_stream) => tls_stream, - Err(err) => { - eprintln!("failed to perform tls handshake: {err:#}"); - return; - } - }; - - // Serve the http connection - if let Err(err) = HttpConnBuilder::new(TokioExecutor::new()) - .serve_connection(TokioIo::new(tls_stream), service) - .await - { - eprintln!("failed to serve connection: {err:#}"); - } - }); - } - } - - /// Serves a single HTTP connection from a hyper service backend. - /// - /// This method handles an individual HTTP connection, processing requests through - /// the provided service and managing the connection lifecycle. - /// - /// # Type Parameters - /// - /// * `B`: The body type for the HTTP response. - /// * `IO`: The I/O type for the HTTP connection. - /// * `S`: The service type that processes HTTP requests. - /// * `E`: The executor type for the HTTP server connection. - /// - /// # Parameters - /// - /// * `hyper_io`: The I/O object representing the inbound hyper IO stream. - /// * `hyper_svc`: The hyper `Service` implementation used to process HTTP requests. - /// * `builder`: An `HttpConnBuilder` used to create and serve the HTTP connection. - /// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. - /// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection - /// before initiating a graceful shutdown. - async fn serve_http_connection( - hyper_io: IO, - hyper_service: S, - builder: HttpConnBuilder, - mut watcher: Option>, - max_connection_age: Option, - ) where - B: Body + Send + 'static, - B::Data: Send, - B::Error: Into> + Send + Sync, - IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, - S: HyperService, Response=Response> + Clone + Send + 'static, - S::Future: Send + 'static, - S::Error: Into> + Send, - E: HttpServerConnExec + Send + Sync + 'static, - { - // Spawn a new asynchronous task to handle the incoming hyper IO stream - tokio::spawn(async move { - { - // Set up a fused future for the watcher - let mut sig = pin!(Fuse { - inner: watcher.as_mut().map(|w| w.changed()), - }); - - // Create and pin the HTTP connection - let mut conn = pin!(builder.serve_connection(hyper_io, hyper_service)); - - // Set up the sleep future for max connection age - let sleep = sleep_or_pending(max_connection_age); - tokio::pin!(sleep); - - // Main loop for serving the HTTP connection - loop { - tokio::select! { - // Handle the connection result - rv = &mut conn => { - if let Err(err) = rv { - // Log any errors that occur while serving the HTTP connection - debug!("failed serving HTTP connection: {:#}", err); - } - break; - }, - // Handle max connection age timeout - _ = &mut sleep => { - // Initiate a graceful shutdown when max connection age is reached - conn.as_mut().graceful_shutdown(); - sleep.set(sleep_or_pending(None)); - }, - // Handle graceful shutdown signal - _ = &mut sig => { - // Initiate a graceful shutdown when signal is received - conn.as_mut().graceful_shutdown(); - } - } - } - } - - // Clean up and log connection closure - drop(watcher); - trace!("HTTP connection closed"); - }); - } -} - -#[cfg(test)] -mod tests { - #[tokio::test] - async fn test_echo_service() {} -} +pub(crate) type Error = Box; diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..eebb05f --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,199 @@ +use crate::Error; +use std::{io, ops::ControlFlow}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_stream::{Stream, StreamExt}; +use tracing::debug; + +/// Handles errors that occur during TCP connection acceptance. +/// +/// This function determines whether an error should be treated as fatal (breaking the accept loop) +/// or non-fatal (allowing the loop to continue). +/// +/// # Arguments +/// +/// * `e` - The error to handle, which can be converted into the crate's `Error` type. +/// +/// # Returns +/// +/// * `ControlFlow::Continue(())` if the error is non-fatal and the accept loop should continue. +/// * `ControlFlow::Break(Error)` if the error is fatal and the accept loop should terminate. +pub(crate) fn handle_accept_error(e: impl Into) -> ControlFlow { + let e = e.into(); + debug!(error = %e, "TCP accept loop error"); + if let Some(e) = e.downcast_ref::() { + if matches!( + e.kind(), + io::ErrorKind::ConnectionAborted + | io::ErrorKind::Interrupted + | io::ErrorKind::InvalidData + | io::ErrorKind::WouldBlock + ) { + return ControlFlow::Continue(()); + } + } + + ControlFlow::Break(e) +} + +/// Creates a stream that yields a TCP stream for each incoming connection. +/// +/// This function takes a stream of incoming connections and handles errors that may occur +/// during the acceptance process. It will continue to yield connections even if non-fatal +/// errors occur, but will terminate if a fatal error is encountered. +/// +/// # Type Parameters +/// +/// * `IO`: The type of the I/O object yielded by the incoming stream. +/// * `IE`: The type of the error that can be produced by the incoming stream. +/// +/// # Arguments +/// +/// * `incoming`: A stream that yields results of incoming connection attempts. +/// +/// # Returns +/// +/// A pinned stream that yields `Result` for each incoming connection. +pub(crate) fn serve_tcp_incoming( + incoming: impl Stream> + Send + 'static, +) -> impl Stream> +where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into + Send + 'static, +{ + async_stream::stream! { + let mut incoming = Box::pin(incoming); + + while let Some(item) = incoming.next().await { + match item { + Ok(io) => yield Ok(io), + Err(e) => match handle_accept_error(e.into()) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(e) => yield Err(e), + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::SocketAddr; + use std::pin::Pin; + use tokio::net::{TcpListener, TcpStream}; + use tokio_stream::wrappers::TcpListenerStream; + use tokio_stream::StreamExt; + + #[tokio::test] + async fn test_handle_accept_error() { + // Test non-fatal errors + let non_fatal_errors = vec![ + io::ErrorKind::ConnectionAborted, + io::ErrorKind::Interrupted, + io::ErrorKind::InvalidData, + io::ErrorKind::WouldBlock, + ]; + + for kind in non_fatal_errors { + let error = io::Error::new(kind, "Test error"); + assert!(matches!( + handle_accept_error(error), + ControlFlow::Continue(()) + )); + } + + // Test fatal error + let fatal_error = io::Error::new(io::ErrorKind::PermissionDenied, "Permission denied"); + assert!(matches!( + handle_accept_error(fatal_error), + ControlFlow::Break(_) + )); + } + + #[tokio::test] + async fn test_serve_tcp_incoming_success() -> Result<(), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let bound_addr = listener.local_addr()?; + let stream = TcpListenerStream::new(listener); + let mut incoming = Box::pin(serve_tcp_incoming(stream)); + + // Spawn a task to accept one connection + let accept_task = tokio::spawn(async move { incoming.next().await }); + + // Connect to the server + let _client = TcpStream::connect(bound_addr).await?; + + // Check that the connection was accepted + let result = accept_task.await?.unwrap(); + assert!(result.is_ok()); + + Ok(()) + } + + #[tokio::test] + async fn test_serve_tcp_incoming_with_errors() { + // Create a mock stream that yields both successful connections and errors + let mock_stream = tokio_stream::iter(vec![ + Ok(MockIO), + Err(io::Error::new(io::ErrorKind::ConnectionAborted, "Aborted")), + Ok(MockIO), + Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "Permission denied", + )), + ]); + + let mut incoming = Box::pin(serve_tcp_incoming(mock_stream)); + + // First connection should be successful + assert!(incoming.next().await.unwrap().is_ok()); + + // Second connection (aborted) should be skipped + // Third connection should be successful + assert!(incoming.next().await.unwrap().is_ok()); + + // Fourth connection (permission denied) should break the stream + assert!(incoming.next().await.unwrap().is_err()); + + // Stream should be exhausted + assert!(incoming.next().await.is_none()); + } + + // Mock IO type for testing + struct MockIO; + + impl AsyncRead for MockIO { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + } + + impl AsyncWrite for MockIO { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &[u8], + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(0)) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + } +} diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1 @@ +