From b3e42557568b0d7ff43972e4ef84137fdfa0766e Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 07:16:15 -0400 Subject: [PATCH 01/45] chore: tabula rasa --- Cargo.toml | 73 +- README.md | 68 +- examples/configure_addr_incoming.rs | 29 - examples/configure_http.rs | 26 - examples/from_std_listener.rs | 19 - examples/from_std_listener_rustls.rs | 28 - examples/graceful_shutdown.rs | 50 -- examples/hello_world.rs | 18 - examples/http_and_https.rs | 52 -- examples/multiple_addresses.rs | 25 - examples/remote_address.rs | 21 - examples/remote_address_using_tower.rs | 27 - examples/rustls_reload.rs | 52 -- examples/rustls_server.rs | 26 - examples/rustls_session.rs | 80 -- examples/self-signed-certs/cert.pem | 32 - examples/self-signed-certs/key.pem | 52 -- examples/self-signed-certs/reload/cert.pem | 32 - examples/self-signed-certs/reload/key.pem | 52 -- examples/shutdown.rs | 40 - src/accept.rs | 62 -- src/addr_incoming_config.rs | 140 ---- src/handle.rs | 167 ---- src/http_config.rs | 263 ------- src/lib.rs | 130 ---- src/notify_once.rs | 45 -- src/proxy_protocol/future.rs | 143 ---- src/proxy_protocol/mod.rs | 841 --------------------- src/server.rs | 523 ------------- src/service.rs | 164 ---- src/tls_openssl/future.rs | 183 ----- src/tls_openssl/mod.rs | 412 ---------- src/tls_rustls/future.rs | 130 ---- src/tls_rustls/mod.rs | 579 -------------- 34 files changed, 5 insertions(+), 4579 deletions(-) delete mode 100644 examples/configure_addr_incoming.rs delete mode 100644 examples/configure_http.rs delete mode 100644 examples/from_std_listener.rs delete mode 100644 examples/from_std_listener_rustls.rs delete mode 100644 examples/graceful_shutdown.rs delete mode 100644 examples/hello_world.rs delete mode 100644 examples/http_and_https.rs delete mode 100644 examples/multiple_addresses.rs delete mode 100644 examples/remote_address.rs delete mode 100644 examples/remote_address_using_tower.rs delete mode 100644 examples/rustls_reload.rs delete mode 100644 examples/rustls_server.rs delete mode 100644 examples/rustls_session.rs delete mode 100644 examples/self-signed-certs/cert.pem delete mode 100644 examples/self-signed-certs/key.pem delete mode 100644 examples/self-signed-certs/reload/cert.pem delete mode 100644 examples/self-signed-certs/reload/key.pem delete mode 100644 examples/shutdown.rs delete mode 100644 src/accept.rs delete mode 100644 src/addr_incoming_config.rs delete mode 100644 src/handle.rs delete mode 100644 src/http_config.rs delete mode 100644 src/notify_once.rs delete mode 100644 src/proxy_protocol/future.rs delete mode 100644 src/proxy_protocol/mod.rs delete mode 100644 src/server.rs delete mode 100644 src/service.rs delete mode 100644 src/tls_openssl/future.rs delete mode 100644 src/tls_openssl/mod.rs delete mode 100644 src/tls_rustls/future.rs delete mode 100644 src/tls_rustls/mod.rs diff --git a/Cargo.toml b/Cargo.toml index db6b24d..fa1eb50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,79 +1,14 @@ [package] -authors = ["Programatik ", "Megsdevs "] +authors = ["0xAlcibiades "] categories = ["asynchronous", "network-programming", "web-programming"] description = "High level server for hyper and tower." edition = "2021" -homepage = "https://github.com/valorem-labs-inc/hyper-server" +homepage = "https://github.com/warlock-labls/hyper-server" keywords = ["axum", "tonic", "hyper", "tower", "server"] license = "MIT" name = "hyper-server" readme = "README.md" repository = "https://github.com/valorem-labs-inc/hyper-server" -version = "0.5.3" +version = "1.0.0" -[features] -default = [] -tls-rustls = ["arc-swap", "pin-project-lite", "rustls", "rustls-pemfile", "tokio/fs", "tokio/time", "tokio-rustls"] -tls-openssl = ["openssl", "tokio-openssl", "pin-project-lite"] -proxy-protocol = ["ppp", "pin-project-lite"] - -[dependencies] - -# optional dependencies -## rustls -arc-swap = { version = "1", optional = true } -bytes = "1" -futures-util = { version = "0.3", default-features = false, features = ["alloc"] } -http = "0.2" -http-body = "0.4" -hyper = { version = "0.14.27", features = ["http1", "http2", "server", "runtime"] } - -## openssl -openssl = { version = "0.10", optional = true } -pin-project-lite = { version = "0.2", optional = true } -rustls = { version = "0.21", features = ["dangerous_configuration"], optional = true } -rustls-pemfile = { version = "1", optional = true } -tokio = { version = "1", features = ["macros", "net", "sync"] } -tokio-openssl = { version = "0.6", optional = true } -tokio-rustls = { version = "0.24", optional = true } -tower-service = "0.3" - -## proxy-protocol -ppp = { version = "2.2.0", optional = true } - -[dev-dependencies] -axum = "0.6" -hyper = { version = "0.14", features = ["full"] } -tokio = { version = "1", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.4.4", features = ["add-extension"] } - -[package.metadata.docs.rs] -all-features = true -cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] -rustdoc-args = ["--cfg", "docsrs"] - -[[example]] -name = "from_std_listener_rustls" -required-features = ["tls-rustls"] -doc-scrape-examples = true - -[[example]] -name = "http_and_https" -required-features = ["tls-rustls"] -doc-scrape-examples = true - -[[example]] -name = "rustls_reload" -required-features = ["tls-rustls"] -doc-scrape-examples = true - -[[example]] -name = "rustls_server" -required-features = ["tls-rustls"] -doc-scrape-examples = true - -[[example]] -name = "rustls_session" -required-features = ["tls-rustls"] -doc-scrape-examples = true +[dependencies] \ No newline at end of file diff --git a/README.md b/README.md index 5d0b234..4372204 100644 --- a/README.md +++ b/README.md @@ -1,67 +1 @@ -[![License](https://img.shields.io/crates/l/hyper-server)](https://choosealicense.com/licenses/mit/) -[![Crates.io](https://img.shields.io/crates/v/hyper-server)](https://crates.io/crates/hyper-server) -[![Docs](https://img.shields.io/crates/v/hyper-server?color=blue&label=docs)](https://docs.rs/hyper-server/) -![CI](https://github.com/valorem-labs-inc/hyper-server/actions/workflows/CI.yml/badge.svg) -[![codecov](https://codecov.io/gh/valorem-labs-inc/hyper-server/branch/master/graph/badge.svg?token=8W5MEJQSW6)](https://codecov.io/gh/valorem-labs-inc/hyper-server) - -# hyper-server - -hyper-server is a high performance [hyper] server implementation designed to -work with [axum], [tonic] and [tower]. - -## Features - -- HTTP/1 and HTTP/2 -- HTTPS through [rustls] and openssl. -- High performance through [hyper]. -- Using [tower] make service API. -- Exceptional [axum] compatibility. Likely to work with future [axum] releases. -- Superb [tonic] compatibility. Likely to work with future [tonic] releases. - -## Usage Example - -A simple hello world application can be served like: - -```rust -use axum::{routing::get, Router}; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .serve(app.into_make_service()) - .await - .unwrap(); -} -``` - -You can find more examples [here](/examples). - -## Minimum Supported Rust Version - -hyper-server's MSRV is `1.65`. - -## Safety - -This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust. - -## License - -This project is licensed under the [MIT license](LICENSE). - -## Why fork - -This project is based on the great work in [axum-server], which is no longer actively maintained. -The rationale for forking is that we use this for critical infrastructure and want to be able to -extend the crate and fix bugs as needed. - -[axum-server]: https://github.com/programatik29/axum-server -[axum]: https://crates.io/crates/axum -[hyper]: https://crates.io/crates/hyper -[rustls]: https://crates.io/crates/rustls -[tower]: https://crates.io/crates/tower -[tonic]: https://crates.io/crates/tonic +# hyper-server \ No newline at end of file diff --git a/examples/configure_addr_incoming.rs b/examples/configure_addr_incoming.rs deleted file mode 100644 index bf8ad58..0000000 --- a/examples/configure_addr_incoming.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Run with `cargo run --example configure_http` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{routing::get, Router}; -use hyper_server::AddrIncomingConfig; -use std::net::SocketAddr; -use std::time::Duration; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = AddrIncomingConfig::new() - .tcp_nodelay(true) - .tcp_sleep_on_accept_errors(true) - .tcp_keepalive(Some(Duration::from_secs(32))) - .tcp_keepalive_interval(Some(Duration::from_secs(1))) - .tcp_keepalive_retries(Some(1)) - .build(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .addr_incoming_config(config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/configure_http.rs b/examples/configure_http.rs deleted file mode 100644 index cfd79e2..0000000 --- a/examples/configure_http.rs +++ /dev/null @@ -1,26 +0,0 @@ -//! Run with `cargo run --example configure_http` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{routing::get, Router}; -use hyper_server::HttpConfig; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = HttpConfig::new() - .http1_only(true) - .http2_only(false) - .max_buf_size(8192) - .build(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .http_config(config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/from_std_listener.rs b/examples/from_std_listener.rs deleted file mode 100644 index 3360187..0000000 --- a/examples/from_std_listener.rs +++ /dev/null @@ -1,19 +0,0 @@ -//! Run with `cargo run --example from_std_listener` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{routing::get, Router}; -use std::net::{SocketAddr, TcpListener}; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = TcpListener::bind(addr).unwrap(); - println!("listening on {}", addr); - hyper_server::from_tcp(listener) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/from_std_listener_rustls.rs b/examples/from_std_listener_rustls.rs deleted file mode 100644 index 9011212..0000000 --- a/examples/from_std_listener_rustls.rs +++ /dev/null @@ -1,28 +0,0 @@ -//! Run with `cargo run --all-features --example from_std_listener_rustls` -//! command. -//! -//! To connect through browser, navigate to "https://localhost:3000" url. - -use axum::{routing::get, Router}; -use hyper_server::tls_rustls::RustlsConfig; -use std::net::{SocketAddr, TcpListener}; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = TcpListener::bind(addr).unwrap(); - println!("listening on {}", addr); - hyper_server::from_tcp_rustls(listener, config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/graceful_shutdown.rs b/examples/graceful_shutdown.rs deleted file mode 100644 index 82ec31b..0000000 --- a/examples/graceful_shutdown.rs +++ /dev/null @@ -1,50 +0,0 @@ -//! Run with `cargo run --example graceful_shutdown` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. -//! -//! After 10 seconds: -//! - If there aren't any connections alive, server will shutdown. -//! - If there are connections alive, server will wait until deadline is elapsed. -//! - Deadline is 30 seconds. Server will shutdown anyways when deadline is elapsed. - -use axum::{routing::get, Router}; -use hyper_server::Handle; -use std::{net::SocketAddr, time::Duration}; -use tokio::time::sleep; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let handle = Handle::new(); - - // Spawn a task to gracefully shutdown server. - tokio::spawn(graceful_shutdown(handle.clone())); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .handle(handle) - .serve(app.into_make_service()) - .await - .unwrap(); - - println!("server is shut down"); -} - -async fn graceful_shutdown(handle: Handle) { - // Wait 10 seconds. - sleep(Duration::from_secs(10)).await; - - println!("sending graceful shutdown signal"); - - // Signal the server to shutdown using Handle. - handle.graceful_shutdown(Some(Duration::from_secs(30))); - - // Print alive connection count every second. - loop { - sleep(Duration::from_secs(1)).await; - - println!("alive connections: {}", handle.connection_count()); - } -} diff --git a/examples/hello_world.rs b/examples/hello_world.rs deleted file mode 100644 index d91e25c..0000000 --- a/examples/hello_world.rs +++ /dev/null @@ -1,18 +0,0 @@ -//! Run with `cargo run --example hello_world` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{routing::get, Router}; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/http_and_https.rs b/examples/http_and_https.rs deleted file mode 100644 index 63b05db..0000000 --- a/examples/http_and_https.rs +++ /dev/null @@ -1,52 +0,0 @@ -//! Run with `cargo run --all-features --example http_and_https` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url which should redirect to -//! "https://localhost:3443". - -use axum::{http::uri::Uri, response::Redirect, routing::get, Router}; -use hyper_server::tls_rustls::RustlsConfig; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let http = tokio::spawn(http_server()); - let https = tokio::spawn(https_server()); - - // Ignore errors. - let _ = tokio::join!(http, https); -} - -async fn http_server() { - let app = Router::new().route("/", get(http_handler)); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("http listening on {}", addr); - hyper_server::bind(addr) - .serve(app.into_make_service()) - .await - .unwrap(); -} - -async fn http_handler(uri: Uri) -> Redirect { - let uri = format!("https://127.0.0.1:3443{}", uri.path()); - - Redirect::temporary(&uri) -} - -async fn https_server() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3443)); - println!("https listening on {}", addr); - hyper_server::bind_rustls(addr, config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/multiple_addresses.rs b/examples/multiple_addresses.rs deleted file mode 100644 index 8030586..0000000 --- a/examples/multiple_addresses.rs +++ /dev/null @@ -1,25 +0,0 @@ -use axum::{routing::get, Router}; -use futures_util::future::try_join_all; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; - -#[tokio::main] -async fn main() { - let servers = vec![ - SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 3000), - SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 3000), - ] - .into_iter() - .map(|addr| tokio::spawn(start_server(addr))); - - // Returns the first error if any of the servers return an error. - try_join_all(servers).await.unwrap(); -} - -async fn start_server(addr: SocketAddr) { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - hyper_server::bind(addr) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/remote_address.rs b/examples/remote_address.rs deleted file mode 100644 index fb29c7f..0000000 --- a/examples/remote_address.rs +++ /dev/null @@ -1,21 +0,0 @@ -//! Run with `cargo run --example remote_address` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{extract::ConnectInfo, routing::get, Router}; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let app = Router::new() - .route("/", get(handler)) - .into_make_service_with_connect_info::(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - - hyper_server::bind(addr).serve(app).await.unwrap(); -} - -async fn handler(ConnectInfo(addr): ConnectInfo) -> String { - format!("your ip address is: {}", addr) -} diff --git a/examples/remote_address_using_tower.rs b/examples/remote_address_using_tower.rs deleted file mode 100644 index 6342963..0000000 --- a/examples/remote_address_using_tower.rs +++ /dev/null @@ -1,27 +0,0 @@ -//! Run with `cargo run --example remote_address_using_tower` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use hyper::{server::conn::AddrStream, Body, Request, Response}; -use std::{convert::Infallible, net::SocketAddr}; -use tower::service_fn; -use tower_http::add_extension::AddExtension; - -#[tokio::main] -async fn main() { - let service = service_fn(|mut req: Request| async move { - let addr: SocketAddr = req.extensions_mut().remove().unwrap(); - let body = Body::from(format!("IP Address: {}", addr)); - - Ok::<_, Infallible>(Response::new(body)) - }); - - hyper_server::bind(SocketAddr::from(([127, 0, 0, 1], 3000))) - .serve(service_fn(|addr: &AddrStream| { - let addr = addr.remote_addr(); - - async move { Ok::<_, Infallible>(AddExtension::new(service, addr)) } - })) - .await - .unwrap(); -} diff --git a/examples/rustls_reload.rs b/examples/rustls_reload.rs deleted file mode 100644 index ba1add6..0000000 --- a/examples/rustls_reload.rs +++ /dev/null @@ -1,52 +0,0 @@ -//! Run with `cargo run --all-features --example rustls_reload` command. -//! -//! To connect through browser, navigate to "https://localhost:3000" url. -//! -//! Certificate common name will be "localhost". -//! -//! After 20 seconds, certificate common name will be "reloaded". - -use axum::{routing::get, Router}; -use hyper_server::tls_rustls::RustlsConfig; -use std::{net::SocketAddr, time::Duration}; -use tokio::time::sleep; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - // Spawn a task to reload tls. - tokio::spawn(reload(config.clone())); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind_rustls(addr, config) - .serve(app.into_make_service()) - .await - .unwrap(); -} - -async fn reload(config: RustlsConfig) { - // Wait for 20 seconds. - sleep(Duration::from_secs(20)).await; - - println!("reloading rustls configuration"); - - // Reload rustls configuration from new files. - config - .reload_from_pem_file( - "examples/self-signed-certs/reload/cert.pem", - "examples/self-signed-certs/reload/key.pem", - ) - .await - .unwrap(); - - println!("rustls configuration reloaded"); -} diff --git a/examples/rustls_server.rs b/examples/rustls_server.rs deleted file mode 100644 index c7cf0f2..0000000 --- a/examples/rustls_server.rs +++ /dev/null @@ -1,26 +0,0 @@ -//! Run with `cargo run --all-features --example rustls_server` command. -//! -//! To connect through browser, navigate to "https://localhost:3000" url. - -use axum::{routing::get, Router}; -use hyper_server::tls_rustls::RustlsConfig; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind_rustls(addr, config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/rustls_session.rs b/examples/rustls_session.rs deleted file mode 100644 index fab499f..0000000 --- a/examples/rustls_session.rs +++ /dev/null @@ -1,80 +0,0 @@ -//! Run with `cargo run --all-features --example rustls_session` command. -//! -//! To connect through browser, navigate to "https://localhost:3000" url. - -use axum::{middleware::AddExtension, routing::get, Extension, Router}; -use futures_util::future::BoxFuture; -use hyper_server::{ - accept::Accept, - tls_rustls::{RustlsAcceptor, RustlsConfig}, -}; -use std::{io, net::SocketAddr, sync::Arc}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_rustls::server::TlsStream; -use tower::Layer; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(handler)); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - - println!("listening on {}", addr); - - let acceptor = CustomAcceptor::new(RustlsAcceptor::new(config)); - let server = hyper_server::bind(addr).acceptor(acceptor); - - server.serve(app.into_make_service()).await.unwrap(); -} - -async fn handler(tls_data: Extension) -> String { - format!("{:?}", tls_data) -} - -#[derive(Debug, Clone)] -struct TlsData { - _hostname: Option>, -} - -#[derive(Debug, Clone)] -struct CustomAcceptor { - inner: RustlsAcceptor, -} - -impl CustomAcceptor { - fn new(inner: RustlsAcceptor) -> Self { - Self { inner } - } -} - -impl Accept for CustomAcceptor -where - I: AsyncRead + AsyncWrite + Unpin + Send + 'static, - S: Send + 'static, -{ - type Stream = TlsStream; - type Service = AddExtension; - type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>; - - fn accept(&self, stream: I, service: S) -> Self::Future { - let acceptor = self.inner.clone(); - - Box::pin(async move { - let (stream, service) = acceptor.accept(stream, service).await?; - let server_conn = stream.get_ref().1; - let sni_hostname = TlsData { - _hostname: server_conn.server_name().map(From::from), - }; - let service = Extension(sni_hostname).layer(service); - - Ok((stream, service)) - }) - } -} diff --git a/examples/self-signed-certs/cert.pem b/examples/self-signed-certs/cert.pem deleted file mode 100644 index 8227f32..0000000 --- a/examples/self-signed-certs/cert.pem +++ /dev/null @@ -1,32 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIFkzCCA3ugAwIBAgIUQZiKeBISKUZoglT8J8CCPpGbgTkwDQYJKoZIhvcNAQEL -BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM -GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X -DTIxMDgyOTEyMDE0NVoXDTIyMDgyOTEyMDE0NVowWTELMAkGA1UEBhMCVVMxEzAR -BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 -IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEFAAOCAg8A -MIICCgKCAgEAoeDJnuh1lhcpKCt5VEBqO9JcSoz2wqD3SLj4i2qrEOvqb4X0ZZeN -5GQXQlOG2N6+9FOxTzaTTigTecYzI3hqKn1fiuvaS4EeTC7E1sVOj7tY0yVySjXM -pC/3t1n1s3B25m7eQ0G2JypZFCobGqY0kaRoO+mCTjI4bdCd769shIerCO4Z8FD5 -uj1+hBC7ZY/sqmRkGTLX1ZzkXzaeNeWGlkXKU8/V3qdveFQ/sGe+KoZpOPXb0yR7 -H8zf6NE2CFCNJDhytOkYLOsnvCJOvibJ3kbM2GfI9iCd0/QhQAOcrVhcOgI4aIxr -wP3zvF4PFUhFKEWHqK5IFq41xKyMYu2fw3bmKXg4zsQGcB0avBD7z+7ENEBvLkNI -7O20wKJp8u0RfjStNHWPmWLXPjkadVB5JHJjsktvgNZkbs9ugxhZWW2AzrrIuqwR -NOWnjHE7J3jvcHP6jE5O9LHpnlh6BMoKPsQuRu/bkrD34rNzwH7IX1To1CyDazMR -yhUiARYh43gg6hrrQdVjDFMHd51mgWHtOPzSLb0uzToglAa3FClGlCeaiacu4H2V -EfJrlCbVlftmIub9/EILZ6XpyYWMxt2mm4mCcMtXmBsHolP4lU3keK8AGNFOr3PC -B7NHLNp1RHgx8+Q3kzobJ1Lk+zEjraWPb5gyByUvZySbd/JTGgNCmZsCAwEAAaNT -MFEwHQYDVR0OBBYEFGsIv6GsbDS+dEWwWlA/3TG5Oi88MB8GA1UdIwQYMBaAFGsI -v6GsbDS+dEWwWlA/3TG5Oi88MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEL -BQADggIBAHhjzP8WtkLJVfZXXUPAAekR7kaqk2hb3hIgDABBJ7xNxcktLOH7V/ng -nhbnwSH5mCkHHXx78TOhWqokHp5wru8K3de5wvAD8uz0UwNDHK5EzqtjYLzxbxAr -ht89WoXGPEZIz6MuOxVYx/HHXdgNEXUcujzfpAfvznVxvzBVqpHNgc7qO8wJd0cG -nit1XubxKoIVTEUjDfxGa2TsmBI7CZ8MLjIyztp/b3txpVl36hPC/uFLwKC780Jc -eO9saA5ISbJh7EaISRr8MKpBpJcraL+055bMjM+kzRFA18NWuuo9Y8fXnXE8e/af -k8FvclVdH/YyezaLkjW7lXjo7QoSXHhAuSzvsGmIsh+HuH+3Fs22AN3aGdmimOmp -7JiNe42mwEpJydwgGlKOysw4ht6MA6yOcQJw73QAYYwusOmNjFZtfCUqJx/JO7mn -Sb1/PW58xYSJhDxdGhoh6Rd3xPMW1T4YwpapkAC/htciK3XkwCcG1VKSmCIErkXf -vllmdahH/QkNooNAHMZl/ipYMik8pp5eRjVjCvpQTDBOI97U0+bgXydHVowP9ExE -dGcm6pP8FU1LyBZdYTdlMRC5Z0L0ltcZn7bqKcyzZB3UcWJv7Uhn3MYbmqGsUVly -a/e3kH2t5pEWRTsrNrRD94LzEYKvcNHy6PYkrgpGjh2G2VBZgNzh ------END CERTIFICATE----- diff --git a/examples/self-signed-certs/key.pem b/examples/self-signed-certs/key.pem deleted file mode 100644 index c329a2d..0000000 --- a/examples/self-signed-certs/key.pem +++ /dev/null @@ -1,52 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIJQQIBADANBgkqhkiG9w0BAQEFAASCCSswggknAgEAAoICAQCh4Mme6HWWFyko -K3lUQGo70lxKjPbCoPdIuPiLaqsQ6+pvhfRll43kZBdCU4bY3r70U7FPNpNOKBN5 -xjMjeGoqfV+K69pLgR5MLsTWxU6Pu1jTJXJKNcykL/e3WfWzcHbmbt5DQbYnKlkU -KhsapjSRpGg76YJOMjht0J3vr2yEh6sI7hnwUPm6PX6EELtlj+yqZGQZMtfVnORf -Np415YaWRcpTz9Xep294VD+wZ74qhmk49dvTJHsfzN/o0TYIUI0kOHK06Rgs6ye8 -Ik6+JsneRszYZ8j2IJ3T9CFAA5ytWFw6AjhojGvA/fO8Xg8VSEUoRYeorkgWrjXE -rIxi7Z/DduYpeDjOxAZwHRq8EPvP7sQ0QG8uQ0js7bTAomny7RF+NK00dY+ZYtc+ -ORp1UHkkcmOyS2+A1mRuz26DGFlZbYDOusi6rBE05aeMcTsneO9wc/qMTk70seme -WHoEygo+xC5G79uSsPfis3PAfshfVOjULINrMxHKFSIBFiHjeCDqGutB1WMMUwd3 -nWaBYe04/NItvS7NOiCUBrcUKUaUJ5qJpy7gfZUR8muUJtWV+2Yi5v38QgtnpenJ -hYzG3aabiYJwy1eYGweiU/iVTeR4rwAY0U6vc8IHs0cs2nVEeDHz5DeTOhsnUuT7 -MSOtpY9vmDIHJS9nJJt38lMaA0KZmwIDAQABAoICAHzGnCLU4+4xJBRGjlsW28wI -tgLw7TPQh0uS6GHucrW0YxxbkKrOSx0E2bjSUVrRNzd1W3LHinvwADMZR0nMA2mF -AiQ+8CDLAeOPGULDC29W5Xy7nID/PyI/px25Rd5ujffI9aG6AQHnbopQelvsSREK -PR4RO9OyejSLXXHnMipluLxFa9EFWbjotaBulUQP0Ej24QFbY2rQaGfL3d+FcFxc -pzw7M4tQXGfP6Ne836Q/vtOdDziNIiq87Mq0mIWIMYL9z80K7wuQpywo9bE0jN28 -jSExvoGZWo6J2ydQoXAsb8p286wCsPwtw7Yqek3ZSxVjotGupPp2hhN3PS70IvR5 -wcR+1pGTSzUFkrLurZftR+HNU4GHVGEzmFKtQ1dyBjDdLSkBHx+N3rzvvArMLDKI -hYXc7AgCTR1SkZBBVPFlNZJyicE+x52UGLvnyS5chgqvSsOrkhDu/bK+ISTh+3jZ -8QSnjYuZLQ1q5i3914wKzjSrHbFWuoGullqCk6nvhn2EEDcAVla0ebSYBcrnzKhO -qJogZzUSTpINIKNQlZuohzbS0lrvXuYDRDkZLRaQWKgHGiat7peBazEfd0NTHpIs -2lKovGTWNU8MIvJPONFixIZ0k7Z+s7Oje+dSOoCyCUzA3BT+mmS2Yi180zxrtRBS -LPGooWR3Rfyptx+OJkehAoIBAQDQkoPWIQWdFG1G9x08H49/AjcfGtHbdjeCjNqS -6mbXLzHgQjnUnmKmuqgkSw9IA+l2OqX4dNrKqH9P6Ex9s3HRxTmYt9/0DLT8Thus -04DiusjhUDQYV8pXUBujmVkMEEI8N5RXv0IAd59kaA6kWJLtrnp6mREY2WJicIAJ -BKut0QTC+upnvV2NKYc+Ki5ElB5hqzICr+wBq35ZlxTId7F5iaZeWeljpOodZw06 -KCVIUhmGHNVR0DUqUJ8+j7gstXhXr0MVhAlRg+WhlUvyCm1UhElyyrVgiXjqeqO9 -RO2+/poPNFxylVzYgTi54ydeB378/LcrxFQ7Q3DAW6DSAefHAoIBAQDGsBc6SnXu -WGW2qPWQM1Jm9hGy7ZgB8953kvpSxE1cVkXoOOtaa2HtRurxT55s4nTAzqDV//7R -9OX+JDCMeQLm9oLzGOxaCaq5lGNTNQs+MBPP78wwQrZRhneuG5U0lEYBb+dlkHih -IejR9OK0r0btpwuLWTC/cs2dNMW0J6JwaK6J4JiJC+nJiKyt1W98Vtpz0oLJq/Re -Z/e3sVZF3RLks5WoQsiXYoQ3KFf9koBsImggGm2prrFl9KeZJOVJP0ZeDaRcLGWQ -PRt0nNKuuSRJ5HZF/0TCwUXAtpaftAsr4fhB+/KYVdVrni5FYdfqUX4KH6n9LFSG -VC1OST1JJIeNAoIBAB0H57XMTt24VCWGi9ksg2qoQkfgEcm8QKm5NUsxuTLGbOjM -DwSbLxwJ6xFyKSRa9wnvy94zVajTnzTeHpd4fKU4EHZDUbbEdgSQUqXRoqTsXr2N -zlJ9FbrleZNh6tUVBkMfcVRtWKB8BgGRwkf51CmlGYMq/wg4actN4WRf9A1zhHgn -OK1L3FOjriFm+Z2uCDSMAaACIJVy61lJACmPD3LdR/zmAuhNshB5oYuwvs+8LbVP -GhoTIvNK2X95vabrc16xFGNQR4PDGhlNkI6WCPW0nAyQToKrX9szSsszZuwowATR -wvRn+c5g3iZxia861+AaxNwgraC6GF2N42qXvU0CggEAXD+NyUahEpSARRqVSOpL -K/q7pPOjS+TKOYJILv1tXZ3Av10OCOEqilwO4RMyXyOVSZ+mFTXSPfESh7iNweq9 -ajax/eRoeDVcyuUWaJ+MJMd1q2mOyClxNNDV6ERuNgdRqYEnUoSNPWLdEf48898d -c2HHfl9evsSyqnbCBC8SwFYaE3Hv4FFjrmqCogMiy/wXWQc4KiJoRxzGascvYyiN -iRnINmMrdv4KnQFiOR03+vzOk3kxyUKOouPAnN4Ahs2WAj0bPqBuV1XH1ZCqUO0s -6BHmyAEJD9Nka2Fa9bNGLI2yEhDERe40NM8wdI5FDUng1xp0dlOKuwOCNYLTrY4E -UQKCAQByK/e9bFaNv+BS81flfTt9tinKRIFc8IAKUl39M5wmehUqey8BGfKkMTGX -1w7R7lfCxoDi5Cl64fkPLWHrvZTuWh5ApC8r6uVjEX3TNWhBCQAB2tJmF7s9N73K -ymoh3VvQUHFZ2+IrCTgkJTWqjEdhPiiU3/oBnIv9ZYWf1ORkVhoAdxoLBn2XuTRC -xIKhiQeqCcKE9yTN26rt+7DjhB5TJ0W2meC8Rxb4lZRDD50MZayZQ6Vo4O87INpD -WjR7NdZndxUeinCPNQos9hEEke1ncCIzkwzJ9kn1R3iJzZRdjDKW3oT4G6QaStf5 -HUGWsrhnzvWoCOV+9+MdApoim8FI ------END PRIVATE KEY----- diff --git a/examples/self-signed-certs/reload/cert.pem b/examples/self-signed-certs/reload/cert.pem deleted file mode 100644 index 545c0ae..0000000 --- a/examples/self-signed-certs/reload/cert.pem +++ /dev/null @@ -1,32 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIFkTCCA3mgAwIBAgIUXN/Uw2uyZ6/Uj4LRuuK0/RdRHRswDQYJKoZIhvcNAQEL -BQAwWDELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM -GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDERMA8GA1UEAwwIcmVsb2FkZWQwHhcN -MjEwODI5MTI1NjU2WhcNMjIwODI5MTI1NjU2WjBYMQswCQYDVQQGEwJVUzETMBEG -A1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkg -THRkMREwDwYDVQQDDAhyZWxvYWRlZDCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCC -AgoCggIBAL/OR5KG8PqJgZSDza1lBVpZjW3jw1MA9eegePoK/4dYjd0Mdw+DeYOu -J/UmXoLHUDi/YWwZSmeY3YW0Wimwo1C5VqQL3GapSyFibvyTFE2fpoK0QtlgTKJ4 -G0mzdZ9NjibhvK23UOW5VbzlBujrYAaF2ynUha/cgVZ9uzvdwd6ooi+1i6XfHnkG -AQqGi6u/SIB+eHXn0w+tTYXmMp44jqIkjsK2vPNeifWj3MQxvgg7JTR/AKTmFCMm -BJIEP62BTFEnHJF+pRd2Hj0GIAiNBq1uA1F+HoUhxyX3OWHYCkRwPMnrSbPQOyxO -g4oFaUzAvMd2lHN/GjJS0kLwDy7WF/iXZuFxdEsmEmH62fE7N4P2uEnNw5OcHS82 -8Mc2EoMrV8zUBl4ZJ2eFo6w9lAx2bzMZyGXdOHsZWnJ5+1co6gfRfv51TeJGQx8f -JaHWFrn55qKBQmgQpKmCt/sG3HrqTviw1PtecsrzTliEXPoWdx6AhYaV+I4u8c8S -Q0NfdfjXx+5EMFDe5CvfWp/D5C1AQIV5E0Ao3Q+VfjoU/2tz9WcE5voHfyl3mBMI -FHvAPCZC18E+ZpiYyhRJLxP4z0MzTiuxp25lRi0Yt/5QTzEzFfH1UNQYe2xljPtf -syg5RtHoijcL+MncE1NUXz+B/qC4uJm8llPjFoL94Yg3/dwWOPtPAgMBAAGjUzBR -MB0GA1UdDgQWBBThri9Jq8CLFMEHJ+wE7WsCODVRwjAfBgNVHSMEGDAWgBThri9J -q8CLFMEHJ+wE7WsCODVRwjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA -A4ICAQCiR8yJ2YQyJfYDd9BT9eb9H8/S+Yz/9ayNS3zSJk4StQZaS1V6XjexzDBr -MRSr/hHGtO9G2qeocuJ/ArUJS5yYsf69g9AjuB+b41k0E4BVpiB/lENAhMbMbl+D -+ysRifUR2svHnZzKnL7DRrpS3vEUQhO37GXwbEi192rXAr2N6VE0LhxGyE7EwCzw -7gNkzoB3/Y4Fb+6zCYZorg3PmPZHrfu9vGFiP9nh+JVos9aq2JHZgZJ2N5Hcdh1H -Bci372+i1SHKfYutXrcSnUcPd4UgGQt6F63fOFHJEGsSVbHpJujqjpIscuPqgfn8 -DSkm9SEyVEV8MrY2vtwtVFOre4yjsaZ2fHDU7rCXOO88kIBBdvIpdIO4mBKV14ug -k9M1xzqK/KvgMUztuw/oLxOp7Vnii9sQ9bjzjbFEMiJ07V5Egr88Zh+VnN3ED1MH -Ri6Ho/CI/ttAwzZVhrKumOb6AprPVUteZFedpV80UaYmIthkeW0i9QcUOMkr4bL3 -gCghJeBSETTGEYCKOpcIFbvXwlc8d3KlL0Fa4EbQiw5vlPY28UChnxuZ3I0Vtetf -2F+3bLoVxfZD2Gc7p5bjGHgzUbGLFM4GgqQ6EbRh261Om9/bUxBao7mhKa23XWna -3Y4qISAqus6OolerflYJCCuWUF4N6e6fES5bqnZD49qAaIEg0A== ------END CERTIFICATE----- diff --git a/examples/self-signed-certs/reload/key.pem b/examples/self-signed-certs/reload/key.pem deleted file mode 100644 index bc7dccf..0000000 --- a/examples/self-signed-certs/reload/key.pem +++ /dev/null @@ -1,52 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQC/zkeShvD6iYGU -g82tZQVaWY1t48NTAPXnoHj6Cv+HWI3dDHcPg3mDrif1Jl6Cx1A4v2FsGUpnmN2F -tFopsKNQuVakC9xmqUshYm78kxRNn6aCtELZYEyieBtJs3WfTY4m4bytt1DluVW8 -5Qbo62AGhdsp1IWv3IFWfbs73cHeqKIvtYul3x55BgEKhourv0iAfnh159MPrU2F -5jKeOI6iJI7CtrzzXon1o9zEMb4IOyU0fwCk5hQjJgSSBD+tgUxRJxyRfqUXdh49 -BiAIjQatbgNRfh6FIccl9zlh2ApEcDzJ60mz0DssToOKBWlMwLzHdpRzfxoyUtJC -8A8u1hf4l2bhcXRLJhJh+tnxOzeD9rhJzcOTnB0vNvDHNhKDK1fM1AZeGSdnhaOs -PZQMdm8zGchl3Th7GVpyeftXKOoH0X7+dU3iRkMfHyWh1ha5+eaigUJoEKSpgrf7 -Btx66k74sNT7XnLK805YhFz6FncegIWGlfiOLvHPEkNDX3X418fuRDBQ3uQr31qf -w+QtQECFeRNAKN0PlX46FP9rc/VnBOb6B38pd5gTCBR7wDwmQtfBPmaYmMoUSS8T -+M9DM04rsaduZUYtGLf+UE8xMxXx9VDUGHtsZYz7X7MoOUbR6Io3C/jJ3BNTVF8/ -gf6guLiZvJZT4xaC/eGIN/3cFjj7TwIDAQABAoICAGNoV7PbeB2BEsWUIg8R4lpX -O3OOrfbg8pGfm9OLy6+r96pvAW3q6BmVM2RdBHKnNi6TEbzixqs2kOjw9iHRSHNX -+01+UDZs22FsELWazNUGP1hScKsUu+MgeJQUDIwJt/jy2cT201icW5FQ6enhw5zd -1x6w5LCmien3tAhtAEOUBqrPXpcTMknrELMR1GWo97yQz4HcKolfemRBUE6sZVAn -vk2wQ/GmN741tP+CAElnzfqNMBpGnH0zAP9kcFRORO1yZd4KUyn7r+RUvllwLdvI -vrOHt+2r+fj1TqolO/0IZpkH9uTYsTJfZtEryM1cvvppvLq3Ty5xukOzA0t07mqk -6G6217EhPSKE+DdBbsrExJjdrzBMyTQEL2qGLihhIFpDAd8WdNr8DRJrI4ZEo1Rg -Du1PuvcCscp97eTaiXSQTknUwBzHbeIkYepQYOksd+11cBXY40TR9X78LwUnfmBZ -yeAqFIBND5Z56NgPkXZ9DTeLyt6fkA9+V7WLfpxeGAdhn/JsyflIy2SQyFmRElxV -AC5/8GHgwTXjHmBJNg/PJZBHduje7BWPoCdX8X+SzE/ph/s6vzNdYsGxUFgoMshj -YlhTS9NL0Asp+KQD+bsMYxYmhvb++YIIqwdkMAP4sGD3iKFQXRRRUzldXC5A88US -1Zk0xEvYjw7F5GEKi35RAoIBAQDgH/C07vP1+qPHj3W6vOQ90T2WbS4kfpWUv5wc -KKyvZVDqBrx6R22/fn1GrdXKxrMzVIFN0AXx38NYUmUVe9tQ/nq2Lx6PFKWX5khw -84IJw0LLuXBN6NiorxV4Ep9Bf0uST81sPMmE1vDyAveUVC+FX8NAgD8Hr4tDsleF -NIijqDjVbAN6+T5qlUyuUSjSUo+KnWJ72M2PCSiUDONW93kACk77wo1Hon2YcO3H -IyAQnPJKPYNlgivm5EmEvvThJ2nmlaXwadSH9bNes8RkzcfPJybkVEFMD9nxD127 -DnuHpRBFkjGfPsb9ulLODPvfQirSSXQsR1N8hQTZACd9g6L9AoIBAQDbFabN7Ztg -CnMZ9hT8qEvau67Q8KmpaZBptuYM/W3/T4oxoPOTLZCzvVX5Xy+hOuec/N/DAP/4 -6PDTXPt6kEr31ewcQyBVQarB9bkY1t9iMa32ZsVBe00/UFrdgR3MwD3jP3pmuifT -+ZI4MyJqq4SGek7Zqjc6Unn24TSqXVsvbtILqTbRsqf5iV2LUx3NmqbX84K4EwBm -ZPrMyD0jiAd0YibewyorbhDVTKVxPtVLVCCQpLaTcvkYs1H3mSaY2yB4nBaWUto7 -3iRW497KOpsBpx4UeW4iNni9JtfPKALdIaz+X4ig7tyxwRuMVUkKd7q7faM8IGoH -45xH8w5mW4c7AoIBAQCWRhQ43LcKyOEjnxcK/Df1EuS+hboYkh9tOwRLBSKz/7S/ -FYEuY9I8QW1yBICCk7P3yMNiDwbNZIEwKR7JxuAIcHiKyxEsUmWtcaREx6D7NscE -nfOk6WjLwYkdly7c1aMwGP3dguyDezLWshKai8/JF6ptBxA78QHphByWneC4CsUA -pIm43IFzKWPexWAflWfVQy2TaIx7SWLB0dpkp02kL0VCHPJpg5O+sIldqjmHqhPy -n0gIub0B9TMuJHNAvBKPnutCRVNRTfbUmqgmBqvgQ5oaIjwd6crxjKIGF/HPw2cj -nqBS6960pUd8DMycp1ra4JFaVwCtTusvLKFN0QNpAoIBAAJ128m0QWpys5g3C0VL -Ho72TKBME5uzc8u8IhlDP1j+q66jABlHCbj7B1wllYNaBf/dVyX5fOZut0WoZaqa -tDzUSjKHDnXmpuRGvi1pPFj99dYukUiK+fMcE+ko6gzCm+9RZy6AKLJYuyumZ1yL -UJGyDfCj2Lru8i+zl8PSCJQfynwXCmaQexJyWHqYFF2avwTt1yn6DKcZuzdRiF49 -yNelwon95xtVwRqkIbeD3SFbcIIvV12QjPuaB/Gf5q8QxuyT1C0cARdrBz1yka3z -uonqNoxEUNhRhEmbhhDtghq5phe1OvOTuybD5GtPCeL0NUSlxI+ITaiJBdhJAoBj -xsECggEAKN8pJSYAScGx94fCwNbMBxMHqH3Kk83W+DF1V0ejvfhCmWZ6Vf/81xqz -a22AtpKA0EQIV/+d+4BvddMvLtKgYpYf9YR0MTyaps7DIzebr352/0WLlPZTWr5B -mzwWCtiBL0R7i6bXIiuXxqZv7zjFlXHRcj4GQI0zHT61CLGkTlF5f/25js0NkL+K -dizoG4pOA0mvZKJIKdE1GI3/20qP01BoCIHVdRHUdB0yKhoHi1EuO7hZdAPN9gsB -LMYbHG3f/dtvj0KCscKYB/Py/SmPdTW+xPZAf7tCrqZjQhPvqP2cD1UQ4glr+N2a -85DaC33fAFuevGxpS147+sAiW6doqQ== ------END PRIVATE KEY----- diff --git a/examples/shutdown.rs b/examples/shutdown.rs deleted file mode 100644 index b00943a..0000000 --- a/examples/shutdown.rs +++ /dev/null @@ -1,40 +0,0 @@ -//! Run with `cargo run --example shutdown` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. -//! -//! Server will shutdown in 20 seconds. - -use axum::{routing::get, Router}; -use hyper_server::Handle; -use std::{net::SocketAddr, time::Duration}; -use tokio::time::sleep; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let handle = Handle::new(); - - // Spawn a task to shutdown server. - tokio::spawn(shutdown(handle.clone())); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .handle(handle) - .serve(app.into_make_service()) - .await - .unwrap(); - - println!("server is shut down"); -} - -async fn shutdown(handle: Handle) { - // Wait 20 seconds. - sleep(Duration::from_secs(20)).await; - - println!("sending shutdown signal"); - - // Signal the server to shutdown using Handle. - handle.shutdown(); -} diff --git a/src/accept.rs b/src/accept.rs deleted file mode 100644 index 21fc431..0000000 --- a/src/accept.rs +++ /dev/null @@ -1,62 +0,0 @@ -//! Module `accept` provides utilities for asynchronously processing and modifying IO streams and services. -//! -//! The primary trait exposed by this module is [`Accept`], which allows for asynchronous transformations -//! of input streams and services. The module also provides a default implementation, [`DefaultAcceptor`], -//! that performs no modifications and directly passes through the input stream and service. - -use std::{ - future::{Future, Ready}, - io, -}; - -/// An asynchronous trait for processing and modifying IO streams and services. -/// -/// Implementations of this trait can be used to modify or transform the input stream and service before -/// further processing. For instance, this trait could be used to perform initial authentication, logging, -/// or other setup operations on new connections. -pub trait Accept { - /// The modified or transformed IO stream produced by `accept`. - type Stream; - - /// The modified or transformed service produced by `accept`. - type Service; - - /// The Future type that is returned by `accept`. - type Future: Future>; - - /// Asynchronously process and possibly modify the given IO stream and service. - /// - /// # Parameters: - /// * `stream`: The incoming IO stream, typically a connection. - /// * `service`: The associated service with the stream. - /// - /// # Returns: - /// A future resolving to the modified stream and service, or an error. - fn accept(&self, stream: I, service: S) -> Self::Future; -} - -/// A default implementation of the [`Accept`] trait that performs no modifications. -/// -/// This is a no-op acceptor that simply passes the provided stream and service through without any transformations. -#[derive(Clone, Copy, Debug, Default)] -pub struct DefaultAcceptor; - -impl DefaultAcceptor { - /// Create a new default acceptor instance. - /// - /// # Returns: - /// An instance of [`DefaultAcceptor`]. - pub fn new() -> Self { - Self - } -} - -impl Accept for DefaultAcceptor { - type Stream = I; - type Service = S; - type Future = Ready>; - - fn accept(&self, stream: I, service: S) -> Self::Future { - std::future::ready(Ok((stream, service))) - } -} diff --git a/src/addr_incoming_config.rs b/src/addr_incoming_config.rs deleted file mode 100644 index 250dab6..0000000 --- a/src/addr_incoming_config.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::time::Duration; - -/// Configuration settings for the `AddrIncoming`. -/// -/// This configuration structure is designed to be used in conjunction with the -/// [`AddrIncoming`](hyper::server::conn::AddrIncoming) type from the Hyper crate. -/// It provides a mechanism to customize server settings like TCP keepalive probes, -/// error handling, and other TCP socket-level configurations. -#[derive(Debug, Clone)] -pub struct AddrIncomingConfig { - pub(crate) tcp_sleep_on_accept_errors: bool, - pub(crate) tcp_keepalive: Option, - pub(crate) tcp_keepalive_interval: Option, - pub(crate) tcp_keepalive_retries: Option, - pub(crate) tcp_nodelay: bool, -} - -impl Default for AddrIncomingConfig { - fn default() -> Self { - Self::new() - } -} - -impl AddrIncomingConfig { - /// Creates a new `AddrIncomingConfig` with default settings. - /// - /// # Default Settings - /// - Sleep on accept errors: `true` - /// - TCP keepalive probes: Disabled (`None`) - /// - Duration between keepalive retransmissions: None - /// - Number of keepalive retransmissions: None - /// - `TCP_NODELAY` option: `false` - /// - /// # Returns - /// - /// A new `AddrIncomingConfig` instance with default settings. - pub fn new() -> AddrIncomingConfig { - Self { - tcp_sleep_on_accept_errors: true, - tcp_keepalive: None, - tcp_keepalive_interval: None, - tcp_keepalive_retries: None, - tcp_nodelay: false, - } - } - - /// Creates a cloned copy of the current configuration. - /// - /// This method can be useful when you want to preserve the original settings and - /// create a modified configuration based on the current one. - /// - /// # Returns - /// - /// A cloned `AddrIncomingConfig`. - pub fn build(&mut self) -> Self { - self.clone() - } - - /// Specifies whether to pause (sleep) when an error occurs while accepting a connection. - /// - /// This can be useful to prevent rapidly exhausting file descriptors in scenarios - /// where errors might be transient or frequent. - /// - /// # Parameters - /// - /// - `val`: Whether to sleep on accept errors. Default is `true`. - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_sleep_on_accept_errors(&mut self, val: bool) -> &mut Self { - self.tcp_sleep_on_accept_errors = val; - self - } - - /// Configures the frequency of TCP keepalive probes. - /// - /// TCP keepalive probes are used to detect whether a peer is still connected. - /// - /// # Parameters - /// - /// - `val`: Duration between keepalive probes. Setting to `None` disables keepalive probes. Default is `None`. - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_keepalive(&mut self, val: Option) -> &mut Self { - self.tcp_keepalive = val; - self - } - - /// Configures the duration between two successive TCP keepalive retransmissions. - /// - /// If an acknowledgment to a previous keepalive probe isn't received within this duration, - /// a new probe will be sent. - /// - /// # Parameters - /// - /// - `val`: Duration between keepalive retransmissions. Default is no interval (`None`). - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_keepalive_interval(&mut self, val: Option) -> &mut Self { - self.tcp_keepalive_interval = val; - self - } - - /// Configures the number of times to retransmit a TCP keepalive probe if no acknowledgment is received. - /// - /// After the specified number of retransmissions, the remote end is considered unavailable. - /// - /// # Parameters - /// - /// - `val`: Number of retransmissions before considering the remote end unavailable. Default is no retry (`None`). - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_keepalive_retries(&mut self, val: Option) -> &mut Self { - self.tcp_keepalive_retries = val; - self - } - - /// Configures the `TCP_NODELAY` option for accepted connections. - /// - /// When enabled, this option disables Nagle's algorithm, which can reduce latencies for small packets. - /// - /// # Parameters - /// - /// - `val`: Whether to enable `TCP_NODELAY`. Default is `false`. - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_nodelay(&mut self, val: bool) -> &mut Self { - self.tcp_nodelay = val; - self - } -} diff --git a/src/handle.rs b/src/handle.rs deleted file mode 100644 index bd5269c..0000000 --- a/src/handle.rs +++ /dev/null @@ -1,167 +0,0 @@ -use crate::notify_once::NotifyOnce; -use std::{ - net::SocketAddr, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex, - }, - time::Duration, -}; -use tokio::{sync::Notify, time::sleep}; - -/// A handle to manage and interact with the server. -/// -/// `Handle` provides methods to access server information, such as the number of active connections, -/// and to perform actions like initiating a shutdown. -#[derive(Clone, Debug, Default)] -pub struct Handle { - inner: Arc, -} - -#[derive(Debug, Default)] -struct HandleInner { - addr: Mutex>, - addr_notify: Notify, - conn_count: AtomicUsize, - shutdown: NotifyOnce, - graceful: NotifyOnce, - graceful_dur: Mutex>, - conn_end: NotifyOnce, -} - -impl Handle { - /// Create a new handle for the server. - /// - /// # Returns - /// - /// A new `Handle` instance. - pub fn new() -> Self { - Self::default() - } - - /// Get the number of active connections to the server. - /// - /// # Returns - /// - /// The number of active connections. - pub fn connection_count(&self) -> usize { - self.inner.conn_count.load(Ordering::SeqCst) - } - - /// Initiate an immediate shutdown of the server. - /// - /// This method will terminate the server without waiting for active connections to close. - pub fn shutdown(&self) { - self.inner.shutdown.notify_waiters(); - } - - /// Initiate a graceful shutdown of the server. - /// - /// The server will wait for active connections to close before shutting down. If a duration - /// is provided, the server will wait up to that duration for active connections to close - /// before forcing a shutdown. - /// - /// # Parameters - /// - /// - `duration`: Maximum time to wait for active connections to close. `None` means the server - /// will wait indefinitely. - pub fn graceful_shutdown(&self, duration: Option) { - *self.inner.graceful_dur.lock().unwrap() = duration; - self.inner.graceful.notify_waiters(); - } - - /// Wait until the server starts listening and then returns its local address and port. - /// - /// # Returns - /// - /// The local `SocketAddr` if the server successfully binds, otherwise `None`. - pub async fn listening(&self) -> Option { - let notified = self.inner.addr_notify.notified(); - - if let Some(addr) = *self.inner.addr.lock().unwrap() { - return Some(addr); - } - - notified.await; - - *self.inner.addr.lock().unwrap() - } - - /// Internal method to notify the handle when the server starts listening on a particular address. - pub(crate) fn notify_listening(&self, addr: Option) { - *self.inner.addr.lock().unwrap() = addr; - self.inner.addr_notify.notify_waiters(); - } - - /// Creates a watcher that monitors server status and connection activity. - pub(crate) fn watcher(&self) -> Watcher { - Watcher::new(self.clone()) - } - - /// Internal method to wait until the server is shut down. - pub(crate) async fn wait_shutdown(&self) { - self.inner.shutdown.notified().await; - } - - /// Internal method to wait until the server is gracefully shut down. - pub(crate) async fn wait_graceful_shutdown(&self) { - self.inner.graceful.notified().await; - } - - /// Internal method to wait until all connections have ended, or the optional graceful duration has expired. - pub(crate) async fn wait_connections_end(&self) { - if self.inner.conn_count.load(Ordering::SeqCst) == 0 { - return; - } - - let deadline = *self.inner.graceful_dur.lock().unwrap(); - - match deadline { - Some(duration) => tokio::select! { - biased; - _ = sleep(duration) => self.shutdown(), - _ = self.inner.conn_end.notified() => (), - }, - None => self.inner.conn_end.notified().await, - } - } -} - -/// A watcher that monitors server status and connection activity. -/// -/// The watcher keeps track of active connections and listens for shutdown or graceful shutdown signals. -pub(crate) struct Watcher { - handle: Handle, -} - -impl Watcher { - /// Creates a new watcher linked to the given server handle. - fn new(handle: Handle) -> Self { - handle.inner.conn_count.fetch_add(1, Ordering::SeqCst); - Self { handle } - } - - /// Internal method to wait until the server is gracefully shut down. - pub(crate) async fn wait_graceful_shutdown(&self) { - self.handle.wait_graceful_shutdown().await - } - - /// Internal method to wait until the server is shut down. - pub(crate) async fn wait_shutdown(&self) { - self.handle.wait_shutdown().await - } -} - -impl Drop for Watcher { - /// Reduces the active connection count when a watcher is dropped. - /// - /// If the connection count reaches zero and a graceful shutdown has been initiated, the server is notified that - /// all connections have ended. - fn drop(&mut self) { - let count = self.handle.inner.conn_count.fetch_sub(1, Ordering::SeqCst) - 1; - - if count == 0 && self.handle.inner.graceful.is_notified() { - self.handle.inner.conn_end.notify_waiters(); - } - } -} diff --git a/src/http_config.rs b/src/http_config.rs deleted file mode 100644 index 481a81e..0000000 --- a/src/http_config.rs +++ /dev/null @@ -1,263 +0,0 @@ -use hyper::server::conn::Http; -use std::time::Duration; - -/// Represents a configuration for the [`Http`] protocol. -/// This allows for detailed customization of various HTTP/1 and HTTP/2 settings. -#[derive(Debug, Clone)] -pub struct HttpConfig { - /// The inner HTTP configuration from the `hyper` crate. - pub(crate) inner: Http, -} - -impl Default for HttpConfig { - /// Provides a default HTTP configuration. - fn default() -> Self { - Self::new() - } -} - -impl HttpConfig { - /// Creates a new `HttpConfig` with default settings. - pub fn new() -> HttpConfig { - Self { inner: Http::new() } - } - - /// Clones the current configuration state and returns it. - /// Useful for building configurations dynamically. - pub fn build(&mut self) -> Self { - self.clone() - } - - /// Configures whether to exclusively support HTTP/1. - /// - /// When enabled, only HTTP/1 requests are processed, and HTTP/2 requests are rejected. - /// - /// Default is `false`. - pub fn http1_only(&mut self, val: bool) -> &mut Self { - self.inner.http1_only(val); - self - } - - /// Specifies if HTTP/1 connections should be allowed to use half-closures. - /// - /// A half-closure in TCP occurs when one side of the data stream is terminated, - /// but the other side remains open. This setting, when `true`, ensures the server - /// doesn't immediately close a connection if a client shuts down their sending side - /// while waiting for a response. - /// - /// Default is `false`. - pub fn http1_half_close(&mut self, val: bool) -> &mut Self { - self.inner.http1_half_close(val); - self - } - - /// Enables or disables the keep-alive feature for HTTP/1 connections. - /// - /// Keep-alive allows the connection to be reused for multiple requests and responses. - /// - /// Default is true. - pub fn http1_keep_alive(&mut self, val: bool) -> &mut Self { - self.inner.http1_keep_alive(val); - self - } - - /// Determines if HTTP/1 connections should write headers with title-case naming. - /// - /// For example, turning this setting `true` would send headers as "Content-Type" instead of "content-type". - /// Note that this has no effect on HTTP/2 connections. - /// - /// Default is false. - pub fn http1_title_case_headers(&mut self, enabled: bool) -> &mut Self { - self.inner.http1_title_case_headers(enabled); - self - } - - /// Determines if HTTP/1 connections should preserve the original case of headers. - /// - /// By default, headers might be normalized. Enabling this will ensure headers retain their original casing. - /// This setting doesn't influence HTTP/2. - /// - /// Default is false. - pub fn http1_preserve_header_case(&mut self, enabled: bool) -> &mut Self { - self.inner.http1_preserve_header_case(enabled); - self - } - - /// Configures a timeout for how long the server will wait for client headers. - /// - /// If the client doesn't send all headers within this duration, the connection is terminated. - /// - /// Default is None, meaning no timeout. - pub fn http1_header_read_timeout(&mut self, val: Duration) -> &mut Self { - self.inner.http1_header_read_timeout(val); - self - } - - /// Specifies whether to use vectored writes for HTTP/1 connections. - /// - /// Vectored writes can be efficient for multiple non-contiguous data segments. - /// However, certain transports (like many TLS implementations) may not handle vectored writes well. - /// When disabled, data is flattened into a single buffer before writing. - /// - /// Default is `auto`, where the best method is determined dynamically. - pub fn http1_writev(&mut self, val: bool) -> &mut Self { - self.inner.http1_writev(val); - self - } - - /// Configures the server to exclusively support HTTP/2. - /// - /// When enabled, only HTTP/2 requests are processed, and HTTP/1 requests are rejected. - /// - /// Default is false. - pub fn http2_only(&mut self, val: bool) -> &mut Self { - self.inner.http2_only(val); - self - } - - /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 - /// stream-level flow control. - /// - /// Passing `None` will do nothing. - /// - /// If not set, hyper will use a default. - /// - /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE - pub fn http2_initial_stream_window_size(&mut self, sz: impl Into>) -> &mut Self { - self.inner.http2_initial_stream_window_size(sz); - self - } - - /// Sets the max connection-level flow control for HTTP2. - /// - /// Passing `None` will do nothing. - /// - /// If not set, hyper will use a default. - pub fn http2_initial_connection_window_size( - &mut self, - sz: impl Into>, - ) -> &mut Self { - self.inner.http2_initial_connection_window_size(sz); - self - } - - /// Sets whether to use an adaptive flow control. - /// - /// Enabling this will override the limits set in - /// `http2_initial_stream_window_size` and - /// `http2_initial_connection_window_size`. - pub fn http2_adaptive_window(&mut self, enabled: bool) -> &mut Self { - self.inner.http2_adaptive_window(enabled); - self - } - - /// Enables the [extended CONNECT protocol]. - /// - /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 - pub fn http2_enable_connect_protocol(&mut self) -> &mut Self { - self.inner.http2_enable_connect_protocol(); - self - } - - /// Sets the maximum frame size to use for HTTP2. - /// - /// Passing `None` will do nothing. - /// - /// If not set, hyper will use a default. - pub fn http2_max_frame_size(&mut self, sz: impl Into>) -> &mut Self { - self.inner.http2_max_frame_size(sz); - self - } - - /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 - /// connections. - /// - /// Default is no limit (`std::u32::MAX`). Passing `None` will do nothing. - /// - /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS - pub fn http2_max_concurrent_streams(&mut self, max: impl Into>) -> &mut Self { - self.inner.http2_max_concurrent_streams(max); - self - } - - /// Sets the max size of received header frames. - /// - /// Default is currently ~16MB, but may change. - pub fn http2_max_header_list_size(&mut self, max: u32) -> &mut Self { - self.inner.http2_max_header_list_size(max); - self - } - - /// Configures the maximum number of pending reset streams allowed before a GOAWAY will be sent. - /// - /// This will default to the default value set by the [`h2` crate](https://crates.io/crates/h2). - /// As of v0.3.17, it is 20. - /// - /// See for more information. - pub fn http2_max_pending_accept_reset_streams( - &mut self, - max: impl Into>, - ) -> &mut Self { - self.inner.http2_max_pending_accept_reset_streams(max); - self - } - - /// Set the maximum write buffer size for each HTTP/2 stream. - /// - /// Default is currently ~400KB, but may change. - /// - /// # Panics - /// - /// The value must be no larger than `u32::MAX`. - pub fn http2_max_send_buf_size(&mut self, max: usize) -> &mut Self { - self.inner.http2_max_send_buf_size(max); - self - } - - /// Sets an interval for HTTP2 Ping frames should be sent to keep a - /// connection alive. - /// - /// Pass `None` to disable HTTP2 keep-alive. - /// - /// Default is currently disabled. - pub fn http2_keep_alive_interval( - &mut self, - interval: impl Into>, - ) -> &mut Self { - self.inner.http2_keep_alive_interval(interval); - self - } - - /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. - /// - /// If the ping is not acknowledged within the timeout, the connection will - /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. - /// - /// Default is 20 seconds. - pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { - self.inner.http2_keep_alive_timeout(timeout); - self - } - - /// Set the maximum buffer size for the HTTP/1 connection. - /// - /// Default is ~400kb. - /// - /// # Panics - /// - /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. - pub fn max_buf_size(&mut self, max: usize) -> &mut Self { - self.inner.max_buf_size(max); - self - } - - /// Determines if multiple responses should be buffered and sent together to support pipelined responses. - /// - /// This can improve throughput in certain situations, but is experimental and might contain issues. - /// - /// Default is false. - pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self { - self.inner.pipeline_flush(enabled); - self - } -} diff --git a/src/lib.rs b/src/lib.rs index 932c092..e69de29 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,130 +0,0 @@ -//! hyper-server is a [hyper] server implementation designed to be used with [axum] framework. -//! -//! # Features -//! -//! - HTTP/1 and HTTP/2 -//! - HTTPS through [rustls] or [openssl]. -//! - High performance through [hyper]. -//! - Using [tower] make service API. -//! - Very good [axum] compatibility. Likely to work with future [axum] releases. -//! -//! # Guide -//! -//! hyper-server can [`serve`] items that implement [`MakeService`] with some additional [trait -//! bounds](crate::service::MakeServiceRef). Make services that are [created] using [`axum`] -//! complies with those trait bounds out of the box. Therefore it is more convenient to use this -//! crate with [`axum`]. -//! -//! All examples in this crate uses [`axum`]. If you want to use this crate without [`axum`] it is -//! highly recommended to learn how [tower] works. -//! -//! [`Server::bind`] or [`bind`] function can be called to create a server that will bind to -//! provided [`SocketAddr`] when [`serve`] is called. -//! -//! A [`Handle`] can be passed to [`Server`](Server::handle) for additional utilities like shutdown -//! and graceful shutdown. -//! -//! [`bind_rustls`] can be called by providing [`RustlsConfig`] to create a HTTPS [`Server`] that -//! will bind on provided [`SocketAddr`]. [`RustlsConfig`] can be cloned, reload methods can be -//! used on clone to reload tls configuration. -//! -//! # Features -//! -//! * `tls-rustls` - activate [rustls] support. -//! * `tls-openssl` - activate [openssl] support. -//! -//! # Example -//! -//! A simple hello world application can be served like: -//! -//! ```rust,no_run -//! use axum::{routing::get, Router}; -//! use std::net::SocketAddr; -//! -//! #[tokio::main] -//! async fn main() { -//! let app = Router::new().route("/", get(|| async { "Hello, world!" })); -//! -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! println!("listening on {}", addr); -//! hyper_server::bind(addr) -//! .serve(app.into_make_service()) -//! .await -//! .unwrap(); -//! } -//! ``` -//! -//! You can find more examples in [repository]. -//! -//! [axum]: https://crates.io/crates/axum -//! [bind]: crate::bind -//! [bind_rustls]: crate::bind_rustls -//! [created]: https://docs.rs/axum/0.3/axum/struct.Router.html#method.into_make_service -//! [hyper]: https://crates.io/crates/hyper -//! [openssl]: https://crates.io/crates/openssl -//! [repository]: https://github.com/valorem-labs-inc/hyper-server/examples -//! [rustls]: https://crates.io/crates/rustls -//! [tower]: https://crates.io/crates/tower -//! [`axum`]: https://docs.rs/axum/0.3 -//! [`serve`]: crate::server::Server::serve -//! [`MakeService`]: https://docs.rs/tower/0.4/tower/make/trait.MakeService.html -//! [`RustlsConfig`]: crate::tls_rustls::RustlsConfig -//! [`SocketAddr`]: std::net::SocketAddr - -#![forbid(unsafe_code)] -#![warn( - clippy::await_holding_lock, - clippy::cargo_common_metadata, - clippy::dbg_macro, - clippy::doc_markdown, - clippy::empty_enum, - clippy::enum_glob_use, - clippy::inefficient_to_string, - clippy::mem_forget, - clippy::mutex_integer, - clippy::needless_continue, - clippy::todo, - clippy::unimplemented, - clippy::wildcard_imports, - future_incompatible, - missing_docs, - missing_debug_implementations, - unreachable_pub -)] -#![cfg_attr(docsrs, feature(doc_cfg))] - -mod addr_incoming_config; -mod handle; -mod http_config; -mod notify_once; -mod server; - -pub mod accept; -pub mod service; - -pub use self::{ - addr_incoming_config::AddrIncomingConfig, - handle::Handle, - http_config::HttpConfig, - server::{bind, from_tcp, Server}, -}; - -#[cfg(feature = "tls-rustls")] -#[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))] -pub mod tls_rustls; - -#[doc(inline)] -#[cfg(feature = "tls-rustls")] -pub use self::tls_rustls::export::{bind_rustls, from_tcp_rustls}; - -#[cfg(feature = "tls-openssl")] -#[cfg_attr(docsrs, doc(cfg(feature = "tls-openssl")))] -pub mod tls_openssl; - -#[doc(inline)] -#[cfg(feature = "tls-openssl")] -pub use self::tls_openssl::bind_openssl; - -#[cfg(feature = "proxy-protocol")] -#[cfg_attr(docsrs, doc(cfg(feature = "proxy_protocol")))] -pub mod proxy_protocol; diff --git a/src/notify_once.rs b/src/notify_once.rs deleted file mode 100644 index e4564cd..0000000 --- a/src/notify_once.rs +++ /dev/null @@ -1,45 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use tokio::sync::Notify; - -/// A thread-safe utility that provides a one-time notification to waiters. -/// It utilizes an atomic boolean to ensure that the notification is sent only once. -#[derive(Debug, Default)] -pub(crate) struct NotifyOnce { - /// A flag indicating whether a notification has been sent. - notified: AtomicBool, - /// An asynchronous primitive from the Tokio library used for notifying tasks. - notify: Notify, -} - -impl NotifyOnce { - /// Notifies all waiting tasks, ensuring that the notification happens only once. - pub(crate) fn notify_waiters(&self) { - // Atomically set the `notified` flag to true. - self.notified.store(true, Ordering::SeqCst); - - // Notify all waiting tasks. - self.notify.notify_waiters(); - } - - /// Checks whether a notification has been sent. - /// - /// Returns: - /// - `true` if the notification has already been sent. - /// - `false` otherwise. - pub(crate) fn is_notified(&self) -> bool { - self.notified.load(Ordering::SeqCst) - } - - /// Awaits until a notification has been sent. - /// - /// This asynchronous function will immediately complete if a notification - /// has already been sent, otherwise it will await until it's notified. - pub(crate) async fn notified(&self) { - let future = self.notify.notified(); - - // If not notified, await on the future. - if !self.notified.load(Ordering::SeqCst) { - future.await; - } - } -} diff --git a/src/proxy_protocol/future.rs b/src/proxy_protocol/future.rs deleted file mode 100644 index bd9ab26..0000000 --- a/src/proxy_protocol/future.rs +++ /dev/null @@ -1,143 +0,0 @@ -//! Future types for PROXY protocol support. -use crate::accept::Accept; -use crate::proxy_protocol::ForwardClientIp; -use pin_project_lite::pin_project; -use std::{ - fmt, - future::Future, - io, - net::SocketAddr, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::time::Timeout; - -// A `pin_project` is a procedural macro used for safe field projection in conjunction -// with the Rust Pin API, which guarantees that certain types will not move in memory. -pin_project! { - /// This struct represents the future for the ProxyProtocolAcceptor. - /// The generic types are: - /// F: The future type. - /// A: The type that implements the Accept trait. - /// I: The IO type that supports both AsyncRead and AsyncWrite. - /// S: The service type. - pub struct ProxyProtocolAcceptorFuture - where - A: Accept, - { - #[pin] - inner: AcceptFuture, - } -} - -impl ProxyProtocolAcceptorFuture -where - A: Accept, - I: AsyncRead + AsyncWrite + Unpin, -{ - // Constructor for creating a new ProxyProtocolAcceptorFuture. - pub(crate) fn new(future: Timeout, acceptor: A, service: S) -> Self { - let inner = AcceptFuture::ReadHeader { - future, - acceptor, - service: Some(service), - }; - Self { inner } - } -} - -// Implement Debug trait for ProxyProtocolAcceptorFuture to allow -// debugging and logging. -impl fmt::Debug for ProxyProtocolAcceptorFuture -where - A: Accept, - I: AsyncRead + AsyncWrite + Unpin, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ProxyProtocolAcceptorFuture").finish() - } -} - -pin_project! { - // AcceptFuture represents the internal states of ProxyProtocolAcceptorFuture. - // It can either be waiting to read the header or forward the client IP. - #[project = AcceptFutureProj] - enum AcceptFuture - where - A: Accept, - { - ReadHeader { - #[pin] - future: Timeout, - acceptor: A, - service: Option, - }, - ForwardIp { - #[pin] - future: A::Future, - client_address: Option, - }, - } -} - -impl Future for ProxyProtocolAcceptorFuture -where - A: Accept, - I: AsyncRead + AsyncWrite + Unpin, - // Future whose output is a result with either a tuple of stream and optional address, - // or an io::Error. - F: Future), io::Error>>, -{ - type Output = io::Result<(A::Stream, ForwardClientIp)>; - - // The main poll function that drives the future towards completion. - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - loop { - // Check the current state of the inner future. - match this.inner.as_mut().project() { - AcceptFutureProj::ReadHeader { - future, - acceptor, - service, - } => match future.poll(cx) { - Poll::Ready(Ok(Ok((stream, client_address)))) => { - let service = service.take().expect("future polled after ready"); - let future = acceptor.accept(stream, service); - - // Transition to the ForwardIp state after successfully reading the header. - this.inner.set(AcceptFuture::ForwardIp { - future, - client_address, - }); - } - Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)), - Poll::Ready(Err(timeout)) => { - return Poll::Ready(Err(io::Error::new(io::ErrorKind::TimedOut, timeout))) - } - Poll::Pending => return Poll::Pending, - }, - AcceptFutureProj::ForwardIp { - future, - client_address, - } => { - return match future.poll(cx) { - Poll::Ready(Ok((stream, service))) => { - let service = ForwardClientIp { - inner: service, - client_address: *client_address, - }; - - // Return the successfully processed stream and service. - Poll::Ready(Ok((stream, service))) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - }; - } - } - } - } -} diff --git a/src/proxy_protocol/mod.rs b/src/proxy_protocol/mod.rs deleted file mode 100644 index 0906eed..0000000 --- a/src/proxy_protocol/mod.rs +++ /dev/null @@ -1,841 +0,0 @@ -//! This feature allows the `hyper_server` to be used behind a layer 4 load balancer whilst the proxy -//! protocol is enabled to preserve the client IP address and port. -//! See The PROXY protocol spec for more details: . -//! -//! Any client address found in the proxy protocol header is forwarded on in the HTTP `forwarded` -//! header to be accessible by the rest server. -//! -//! Note: if you are setting a custom acceptor, `enable_proxy_protocol` must be called after this is set. -//! It is best to use directly before calling `serve` when the inner acceptor is already configured. -//! `ProxyProtocolAcceptor` wraps the initial acceptor, so the proxy header is removed from the -//! beginning of the stream before the messages are forwarded on. -//! -//! # Example -//! -//! ```rust,no_run -//! use axum::{routing::get, Router}; -//! use std::net::SocketAddr; -//! use std::time::Duration; -//! -//! #[tokio::main] -//! async fn main() { -//! let app = Router::new().route("/", get(|| async { "Hello, world!" })); -//! -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! println!("listening on {}", addr); -//! -//! // Can configure if you want different from the default of 5 seconds, -//! // otherwise passing `None` will use the default. -//! let proxy_header_timeout = Some(Duration::from_secs(2)); -//! -//! hyper_server::bind(addr) -//! .enable_proxy_protocol(proxy_header_timeout) -//! .serve(app.into_make_service()) -//! .await -//! .unwrap(); -//! } -//! ``` -use crate::accept::Accept; -use std::{ - fmt, - future::Future, - io, - net::{IpAddr, SocketAddr}, - pin::Pin, - task::{Context, Poll}, - time::Duration, -}; - -use http::HeaderValue; -use http::Request; -use ppp::{v1, v2, HeaderResult}; -use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite}, - time::timeout, -}; -use tower_service::Service; - -pub(crate) mod future; -use self::future::ProxyProtocolAcceptorFuture; - -/// The length of a v1 header in bytes. -const V1_PREFIX_LEN: usize = 5; -/// The maximum length of a v1 header in bytes. -const V1_MAX_LENGTH: usize = 107; -/// The terminator bytes of a v1 header. -const V1_TERMINATOR: &[u8] = b"\r\n"; -/// The prefix length of a v2 header in bytes. -const V2_PREFIX_LEN: usize = 12; -/// The minimum length of a v2 header in bytes. -const V2_MINIMUM_LEN: usize = 16; -/// The index of the start of the big-endian u16 length in the v2 header. -const V2_LENGTH_INDEX: usize = 14; -/// The length of the read buffer used to read the PROXY protocol header. -const READ_BUFFER_LEN: usize = 512; - -pub(crate) async fn read_proxy_header( - mut stream: I, -) -> Result<(I, Option), io::Error> -where - I: AsyncRead + Unpin, -{ - // Mutable buffer for storing stream data - let mut buffer = [0; READ_BUFFER_LEN]; - // Dynamic in case v2 header is too long - let mut dynamic_buffer = None; - - // Read prefix to check for v1, v2, or kill - stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?; - - if &buffer[..V1_PREFIX_LEN] == v1::PROTOCOL_PREFIX.as_bytes() { - read_v1_header(&mut stream, &mut buffer).await?; - } else { - stream - .read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN]) - .await?; - if &buffer[..V2_PREFIX_LEN] == v2::PROTOCOL_PREFIX { - dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "No valid Proxy Protocol header detected", - )); - } - } - - // Choose which buffer to parse - let buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]); - - // Parse the header - let header = HeaderResult::parse(buffer); - match header { - HeaderResult::V1(Ok(header)) => { - let client_address = match header.addresses { - v1::Addresses::Tcp4(ip) => { - SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port) - } - v1::Addresses::Tcp6(ip) => { - SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port) - } - v1::Addresses::Unknown => { - // Return client address as `None` so that "unknown" is used in the http header - return Ok((stream, None)); - } - }; - - Ok((stream, Some(client_address))) - } - HeaderResult::V2(Ok(header)) => { - let client_address = match header.addresses { - v2::Addresses::IPv4(ip) => { - SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port) - } - v2::Addresses::IPv6(ip) => { - SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port) - } - v2::Addresses::Unix(unix) => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "Unix socket addresses are not supported. Addresses: {:?}", - unix - ), - )); - } - v2::Addresses::Unspecified => { - // Return client address as `None` so that "unknown" is used in the http header - return Ok((stream, None)); - } - }; - - Ok((stream, Some(client_address))) - } - HeaderResult::V1(Err(_error)) => Err(io::Error::new( - io::ErrorKind::InvalidData, - "No valid V1 Proxy Protocol header received", - )), - HeaderResult::V2(Err(_error)) => Err(io::Error::new( - io::ErrorKind::InvalidData, - "No valid V2 Proxy Protocol header received", - )), - } -} - -async fn read_v2_header( - mut stream: I, - buffer: &mut [u8; READ_BUFFER_LEN], -) -> Result>, io::Error> -where - I: AsyncRead + Unpin, -{ - let length = - u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize; - let full_length = V2_MINIMUM_LEN + length; - - // Switch to dynamic buffer if header is too long; v2 has no maximum length - if full_length > READ_BUFFER_LEN { - let mut dynamic_buffer = Vec::with_capacity(full_length); - dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]); - - // Read the remaining header length - stream - .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length]) - .await?; - - Ok(Some(dynamic_buffer)) - } else { - // Read the remaining header length - stream - .read_exact(&mut buffer[V2_MINIMUM_LEN..full_length]) - .await?; - - Ok(None) - } -} - -async fn read_v1_header( - mut stream: I, - buffer: &mut [u8; READ_BUFFER_LEN], -) -> Result<(), io::Error> -where - I: AsyncRead + Unpin, -{ - // read one byte at a time until terminator found - let mut end_found = false; - for i in V1_PREFIX_LEN..V1_MAX_LENGTH { - buffer[i] = stream.read_u8().await?; - - if [buffer[i - 1], buffer[i]] == V1_TERMINATOR { - end_found = true; - break; - } - } - if !end_found { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "No valid Proxy Protocol header detected", - )); - } - - Ok(()) -} - -/// Middleware for adding client IP address to the request `forwarded` header. -/// see spec: -#[derive(Debug, Clone)] -pub struct ForwardClientIp { - inner: S, - client_address: Option, -} - -impl Service> for ForwardClientIp -where - S: Service>, -{ - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - // The full socket address is available in the proxy header, hence why we include port - let mut forwarded_string = match self.client_address { - Some(socket_addr) => match socket_addr { - SocketAddr::V4(addr) => { - format!("for={}:{}", addr.ip(), addr.port()) - } - SocketAddr::V6(addr) => { - format!("for=\"[{}]:{}\"", addr.ip(), addr.port()) - } - }, - None => "for=unknown".to_string(), - }; - - if let Some(existing_value) = req.headers_mut().get("Forwarded") { - forwarded_string = format!( - "{}, {}", - existing_value.to_str().unwrap_or(""), - forwarded_string - ); - } - - if let Ok(header_value) = HeaderValue::from_str(&forwarded_string) { - req.headers_mut().insert("Forwarded", header_value); - } - - self.inner.call(req) - } -} - -/// Acceptor wrapper for receiving Proxy Protocol headers. -#[derive(Clone)] -pub struct ProxyProtocolAcceptor { - inner: A, - parsing_timeout: Duration, -} - -impl ProxyProtocolAcceptor { - /// Create a new proxy protocol acceptor from an initial acceptor. - /// This is compatible with tls acceptors. - pub fn new(inner: A) -> Self { - #[cfg(not(test))] - let parsing_timeout = Duration::from_secs(5); - - // Don't force tests to wait too long. - #[cfg(test)] - let parsing_timeout = Duration::from_secs(1); - - Self { - inner, - parsing_timeout, - } - } - - /// Override the default Proxy Header parsing timeout. - pub fn parsing_timeout(mut self, val: Duration) -> Self { - self.parsing_timeout = val; - self - } -} - -impl ProxyProtocolAcceptor { - /// Overwrite inner acceptor. - pub fn acceptor(self, acceptor: Acceptor) -> ProxyProtocolAcceptor { - ProxyProtocolAcceptor { - inner: acceptor, - parsing_timeout: self.parsing_timeout, - } - } -} - -impl Accept for ProxyProtocolAcceptor -where - A: Accept + Clone, - A::Stream: AsyncRead + AsyncWrite + Unpin, - I: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - type Stream = A::Stream; - type Service = ForwardClientIp; - type Future = ProxyProtocolAcceptorFuture< - Pin), io::Error>> + Send>>, - A, - I, - S, - >; - - fn accept(&self, stream: I, service: S) -> Self::Future { - let future = Box::pin(read_proxy_header(stream)); - - ProxyProtocolAcceptorFuture::new( - timeout(self.parsing_timeout, future), - self.inner.clone(), - service, - ) - } -} - -impl fmt::Debug for ProxyProtocolAcceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ProxyProtocolAcceptor").finish() - } -} - -#[cfg(test)] -mod tests { - #[cfg(feature = "tls-openssl")] - use crate::tls_openssl::{ - self, - tests::{dns_name as openssl_dns_name, tls_connector as openssl_connector}, - OpenSSLConfig, - }; - #[cfg(feature = "tls-rustls")] - use crate::tls_rustls::{ - self, - tests::{dns_name as rustls_dns_name, tls_connector as rustls_connector}, - RustlsConfig, - }; - use crate::{handle::Handle, server::Server}; - use axum::http::Response; - use axum::{routing::get, Router}; - use bytes::Bytes; - use http::{response, Request}; - use hyper::{ - client::conn::{handshake, SendRequest}, - Body, - }; - use ppp::v2::{Builder, Command, Protocol, Type, Version}; - use std::{io, net::SocketAddr, time::Duration}; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use tokio::{ - net::{TcpListener, TcpStream}, - task::JoinHandle, - time::timeout, - }; - use tower::{Service, ServiceExt}; - - #[tokio::test] - async fn start_and_request() { - let (_handle, _server_task, server_addr) = start_server(true).await; - - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, _client_addr) = connect(addr).await; - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn server_receives_client_address() { - let (_handle, _server_task, server_addr) = start_server(true).await; - - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, client_addr) = connect(addr).await; - - let (parts, body) = send_empty_request(&mut client).await; - - // Check for the Forwarded header - let forwarded_header = parts - .headers - .get("Forwarded") - .expect("No Forwarded header present") - .to_str() - .expect("Failed to convert Forwarded header to str"); - - assert!(forwarded_header.contains(&format!("for={}", client_addr))); - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn server_receives_client_address_v1() { - let (_handle, _server_task, server_addr) = start_server(true).await; - - let addr = start_proxy(server_addr, ProxyVersion::V1) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, client_addr) = connect(addr).await; - - let (parts, body) = send_empty_request(&mut client).await; - - // Check for the Forwarded header - let forwarded_header = parts - .headers - .get("Forwarded") - .expect("No Forwarded header present") - .to_str() - .expect("Failed to convert Forwarded header to str"); - - assert!(forwarded_header.contains(&format!("for={}", client_addr))); - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[cfg(feature = "tls-rustls")] - #[tokio::test] - async fn rustls_server_receives_client_address() { - let (_handle, _server_task, server_addr) = start_rustls_server().await; - - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, client_addr) = rustls_connect(addr).await; - - let (parts, body) = send_empty_request(&mut client).await; - - // Check for the Forwarded header - let forwarded_header = parts - .headers - .get("Forwarded") - .expect("No Forwarded header present") - .to_str() - .expect("Failed to convert Forwarded header to str"); - - assert!(forwarded_header.contains(&format!("for={}", client_addr))); - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[cfg(feature = "tls-openssl")] - #[tokio::test] - async fn openssl_server_receives_client_address() { - let (_handle, _server_task, server_addr) = start_openssl_server().await; - - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, client_addr) = openssl_connect(addr).await; - - let (parts, body) = send_empty_request(&mut client).await; - - // Check for the Forwarded header - let forwarded_header = parts - .headers - .get("Forwarded") - .expect("No Forwarded header present") - .to_str() - .expect("Failed to convert Forwarded header to str"); - - assert!(forwarded_header.contains(&format!("for={}", client_addr))); - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn not_parsing_when_header_present_fails() { - // Start the server with proxy protocol disabled - let (_handle, _server_task, server_addr) = start_server(false).await; - - // Start the proxy - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - // Connect to the proxy - let (mut client, _conn, _client_addr) = connect(addr).await; - - // Send a request to the proxy - match client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - { - // TODO(This should fail when there is no proxy protocol support, perhaps) - Ok(_o) => { - //dbg!(_o); - //() - } - Err(e) => { - if e.is_incomplete_message() { - } else { - panic!("Received unexpected error"); - } - } - } - } - - #[tokio::test] - async fn parsing_when_header_not_present_fails() { - let (_handle, _server_task, server_addr) = start_server(true).await; - - let addr = start_proxy(server_addr, ProxyVersion::None) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, _client_addr) = connect(addr).await; - - match client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - { - Ok(_) => panic!("Should have failed"), - Err(e) => { - if e.is_incomplete_message() { - } else { - panic!("Received unexpected error"); - } - } - } - } - - async fn forward_ip_handler(req: Request) -> Response { - let mut response = Response::new(Body::from("Hello, world!")); - - if let Some(header_value) = req.headers().get("Forwarded") { - response - .headers_mut() - .insert("Forwarded", header_value.clone()); - } - - response - } - - async fn start_server( - parse_proxy_header: bool, - ) -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(forward_ip_handler)); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - if parse_proxy_header { - Server::bind(addr) - .handle(server_handle) - .enable_proxy_protocol(None) - .serve(app.into_make_service()) - .await - } else { - Server::bind(addr) - .handle(server_handle) - .serve(app.into_make_service()) - .await - } - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - #[cfg(feature = "tls-rustls")] - async fn start_rustls_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(forward_ip_handler)); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await?; - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_rustls::bind_rustls(addr, config) - .handle(server_handle) - .enable_proxy_protocol(None) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - #[cfg(feature = "tls-openssl")] - async fn start_openssl_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(forward_ip_handler)); - - let config = OpenSSLConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_openssl::bind_openssl(addr, config) - .handle(server_handle) - .enable_proxy_protocol(None) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - #[derive(Debug, Clone, Copy)] - enum ProxyVersion { - V1, - V2, - None, - } - - async fn start_proxy( - server_address: SocketAddr, - proxy_version: ProxyVersion, - ) -> Result> { - let proxy_address = SocketAddr::from(([127, 0, 0, 1], 0)); - let listener = TcpListener::bind(proxy_address).await?; - let proxy_address = listener.local_addr()?; - - let _proxy_task = tokio::spawn(async move { - loop { - match listener.accept().await { - Ok((client_stream, _)) => { - tokio::spawn(async move { - if let Err(e) = - handle_conn(client_stream, server_address, proxy_version).await - { - println!("Error handling connection: {:?}", e); - } - }); - } - Err(e) => println!("Failed to accept a connection: {:?}", e), - } - } - }); - - Ok(proxy_address) - } - - async fn handle_conn( - mut client_stream: TcpStream, - server_address: SocketAddr, - proxy_version: ProxyVersion, - ) -> io::Result<()> { - let client_address = client_stream.peer_addr()?; // Get the address before splitting - let mut server_stream = TcpStream::connect(server_address).await?; - let server_address = server_stream.peer_addr()?; // Get the address before splitting - - let (mut client_read, mut client_write) = client_stream.split(); - let (mut server_read, mut server_write) = server_stream.split(); - - send_proxy_header( - &mut server_write, - client_address, - server_address, - proxy_version, - ) - .await?; - - let duration = Duration::from_secs(1); - let client_to_server = async { - match timeout(duration, transfer(&mut client_read, &mut server_write)).await { - Ok(result) => result, - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - "Client to Server transfer timed out", - )), - } - }; - - let server_to_client = async { - match timeout(duration, transfer(&mut server_read, &mut client_write)).await { - Ok(result) => result, - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - "Server to Client transfer timed out", - )), - } - }; - - let _ = tokio::try_join!(client_to_server, server_to_client); - - Ok(()) - } - - async fn transfer( - read_stream: &mut (impl AsyncReadExt + Unpin), - write_stream: &mut (impl AsyncWriteExt + Unpin), - ) -> io::Result<()> { - let mut buf = [0; 4096]; - loop { - let n = read_stream.read(&mut buf).await?; - if n == 0 { - break; // EOF - } - write_stream.write_all(&buf[..n]).await?; - } - Ok(()) - } - - async fn send_proxy_header( - write_stream: &mut (impl AsyncWriteExt + Unpin), - client_address: SocketAddr, - server_address: SocketAddr, - proxy_version: ProxyVersion, - ) -> io::Result<()> { - match proxy_version { - ProxyVersion::V1 => { - let header = ppp::v1::Addresses::from((client_address, server_address)).to_string(); - - for byte in header.as_bytes() { - write_stream.write_all(&[*byte]).await?; - } - } - ProxyVersion::V2 => { - let mut header = Builder::with_addresses( - // Declare header as mutable - Version::Two | Command::Proxy, - Protocol::Stream, - (client_address, server_address), - ) - .write_tlv(Type::NoOp, b"Hello, World!")? - .build()?; - - for byte in header.drain(..) { - write_stream.write_all(&[byte]).await?; - } - } - ProxyVersion::None => {} - } - - Ok(()) - } - - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>, SocketAddr) { - let stream = TcpStream::connect(addr).await.unwrap(); - let client_addr = stream.local_addr().unwrap(); - - let (send_request, connection) = handshake(stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task, client_addr) - } - - #[cfg(feature = "tls-rustls")] - async fn rustls_connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>, SocketAddr) { - let stream = TcpStream::connect(addr).await.unwrap(); - let client_addr = stream.local_addr().unwrap(); - let tls_stream = rustls_connector() - .connect(rustls_dns_name(), stream) - .await - .unwrap(); - - let (send_request, connection) = handshake(tls_stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task, client_addr) - } - - #[cfg(feature = "tls-openssl")] - async fn openssl_connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>, SocketAddr) { - let stream = TcpStream::connect(addr).await.unwrap(); - let client_addr = stream.local_addr().unwrap(); - let tls_stream = openssl_connector(openssl_dns_name(), stream).await; - - let (send_request, connection) = handshake(tls_stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task, client_addr) - } - - async fn send_empty_request(client: &mut SendRequest) -> (response::Parts, Bytes) { - let (parts, body) = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - .unwrap() - .into_parts(); - let body = hyper::body::to_bytes(body).await.unwrap(); - - (parts, body) - } -} diff --git a/src/server.rs b/src/server.rs deleted file mode 100644 index 6db138a..0000000 --- a/src/server.rs +++ /dev/null @@ -1,523 +0,0 @@ -#[cfg(feature = "proxy-protocol")] -use crate::proxy_protocol::ProxyProtocolAcceptor; -use crate::{ - accept::{Accept, DefaultAcceptor}, - addr_incoming_config::AddrIncomingConfig, - handle::Handle, - http_config::HttpConfig, - service::{MakeServiceRef, SendService}, -}; -use futures_util::future::poll_fn; -use http::Request; -use hyper::server::{ - accept::Accept as HyperAccept, - conn::{AddrIncoming, AddrStream}, -}; -#[cfg(feature = "proxy-protocol")] -use std::time::Duration; -use std::{ - io::{self, ErrorKind}, - net::SocketAddr, - pin::Pin, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - net::TcpListener, -}; - -/// Represents an HTTP server with customization capabilities for handling incoming requests. -#[derive(Debug)] -pub struct Server { - acceptor: A, - listener: Listener, - addr_incoming_conf: AddrIncomingConfig, - handle: Handle, - http_conf: HttpConfig, - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: bool, -} - -/// Enum representing the ways the server can be initialized - either by binding to an address or from a standard TCP listener. -#[derive(Debug)] -enum Listener { - Bind(SocketAddr), - Std(std::net::TcpListener), -} - -/// Creates a new [`Server`] instance that binds to the provided address. -pub fn bind(addr: SocketAddr) -> Server { - Server::bind(addr) -} - -/// Creates a new [`Server`] instance using an existing `std::net::TcpListener`. -pub fn from_tcp(listener: std::net::TcpListener) -> Server { - Server::from_tcp(listener) -} - -impl Server { - /// Constructs a server bound to the provided address. - pub fn bind(addr: SocketAddr) -> Self { - let acceptor = DefaultAcceptor::new(); - let handle = Handle::new(); - - Self { - acceptor, - listener: Listener::Bind(addr), - addr_incoming_conf: AddrIncomingConfig::default(), - handle, - http_conf: HttpConfig::default(), - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: false, - } - } - - /// Constructs a server from an existing `std::net::TcpListener`. - pub fn from_tcp(listener: std::net::TcpListener) -> Self { - let acceptor = DefaultAcceptor::new(); - let handle = Handle::new(); - - Self { - acceptor, - listener: Listener::Std(listener), - addr_incoming_conf: AddrIncomingConfig::default(), - handle, - http_conf: HttpConfig::default(), - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: false, - } - } -} - -impl Server { - /// Replace the current acceptor with a new one. - pub fn acceptor(self, acceptor: Acceptor) -> Server { - #[cfg(feature = "proxy-protocol")] - if self.proxy_acceptor_set { - panic!("Overwriting the acceptor after proxy protocol is enabled is not supported. Configure the acceptor first in the builder, then enable proxy protocol."); - } - - Server { - acceptor, - listener: self.listener, - addr_incoming_conf: self.addr_incoming_conf, - handle: self.handle, - http_conf: self.http_conf, - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: self.proxy_acceptor_set, - } - } - - #[cfg(feature = "proxy-protocol")] - /// Enable proxy protocol header parsing. - /// Note has to be called after initial acceptor is set. - pub fn enable_proxy_protocol( - self, - parsing_timeout: Option, - ) -> Server> { - let initial_acceptor = self.acceptor; - let mut acceptor = ProxyProtocolAcceptor::new(initial_acceptor); - - if let Some(val) = parsing_timeout { - acceptor = acceptor.parsing_timeout(val); - } - - Server { - acceptor, - listener: self.listener, - addr_incoming_conf: self.addr_incoming_conf, - handle: self.handle, - http_conf: self.http_conf, - proxy_acceptor_set: true, - } - } - - /// Maps the current acceptor to a new type. - pub fn map(self, acceptor: F) -> Server - where - F: FnOnce(A) -> Acceptor, - { - Server { - acceptor: acceptor(self.acceptor), - listener: self.listener, - addr_incoming_conf: self.addr_incoming_conf, - handle: self.handle, - http_conf: self.http_conf, - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: self.proxy_acceptor_set, - } - } - - /// Retrieves a reference to the server's acceptor. - pub fn get_ref(&self) -> &A { - &self.acceptor - } - - /// Retrieves a mutable reference to the server's acceptor. - pub fn get_mut(&mut self) -> &mut A { - &mut self.acceptor - } - - /// Provides the server with a handle for extra utilities. - pub fn handle(mut self, handle: Handle) -> Self { - self.handle = handle; - self - } - - /// Replaces the current HTTP configuration. - pub fn http_config(mut self, config: HttpConfig) -> Self { - self.http_conf = config; - self - } - - /// Replaces the current incoming address configuration. - pub fn addr_incoming_config(mut self, config: AddrIncomingConfig) -> Self { - self.addr_incoming_conf = config; - self - } - - /// Serves the provided `MakeService`. - /// - /// The `MakeService` is responsible for constructing services for each incoming connection. - /// Each service is then used to handle requests from that specific connection. - /// - /// # Arguments - /// - `make_service`: A mutable reference to a type implementing the `MakeServiceRef` trait. - /// This will be used to produce a service for each incoming connection. - /// - /// # Errors - /// - /// This method can return errors in the following scenarios: - /// - When binding to an address fails. - /// - If the `make_service` function encounters an error during its `poll_ready` call. - /// It's worth noting that this error scenario doesn't typically occur with `axum` make services. - /// - pub async fn serve(self, mut make_service: M) -> io::Result<()> - where - M: MakeServiceRef>, - A: Accept + Clone + Send + Sync + 'static, - A::Stream: AsyncRead + AsyncWrite + Unpin + Send, - A::Service: SendService> + Send, - A::Future: Send, - { - // Extract relevant fields from `self` for easier access. - let acceptor = self.acceptor; - let addr_incoming_conf = self.addr_incoming_conf; - let handle = self.handle; - let http_conf = self.http_conf; - - // Bind the incoming connections. Notify the handle if an error occurs during binding. - let mut incoming = match bind_incoming(self.listener, addr_incoming_conf).await { - Ok(v) => v, - Err(e) => { - handle.notify_listening(None); - return Err(e); - } - }; - - // Notify the handle about the server's listening state. - handle.notify_listening(Some(incoming.local_addr())); - - // This is the main loop that accepts incoming connections and spawns tasks to handle them. - let accept_loop_future = async { - loop { - // Wait for a new connection or for the server to be signaled to shut down. - let addr_stream = tokio::select! { - biased; - result = accept(&mut incoming) => result?, - _ = handle.wait_graceful_shutdown() => return Ok(()), - }; - - // Ensure the `make_service` is ready to produce another service. - poll_fn(|cx| make_service.poll_ready(cx)) - .await - .map_err(io_other)?; - - // Create a service for this connection. - let service = match make_service.make_service(&addr_stream).await { - Ok(service) => service, - Err(_) => continue, // TODO: Consider logging or handling this error in a more detailed manner. - }; - - // Clone necessary objects for the spawned task. - let acceptor = acceptor.clone(); - let watcher = handle.watcher(); - let http_conf = http_conf.clone(); - - // Spawn a new task to handle the connection. - tokio::spawn(async move { - if let Ok((stream, send_service)) = acceptor.accept(addr_stream, service).await - { - let service = send_service.into_service(); - - let mut serve_future = http_conf - .inner - .serve_connection(stream, service) - .with_upgrades(); - - // Wait for either the server to be shut down or the connection to finish. - tokio::select! { - biased; - _ = watcher.wait_graceful_shutdown() => { - // Initiate a graceful shutdown. - Pin::new(&mut serve_future).graceful_shutdown(); - tokio::select! { - biased; - _ = watcher.wait_shutdown() => (), - _ = &mut serve_future => (), - } - } - _ = watcher.wait_shutdown() => (), - _ = &mut serve_future => (), - } - } - // TODO: Consider logging or handling any errors that occur during acceptance. - }); - } - }; - - // Wait for either the server to be fully shut down or an error to occur. - let result = tokio::select! { - biased; - _ = handle.wait_shutdown() => return Ok(()), - result = accept_loop_future => result, - }; - - // Handle potential errors. - // TODO: Consider removing the Clippy annotation by restructuring this error handling. - #[allow(clippy::question_mark)] - if let Err(e) = result { - return Err(e); - } - - // Wait for all connections to end. - handle.wait_connections_end().await; - - Ok(()) - } -} - -/// Binds the listener based on the provided configuration and returns an [`AddrIncoming`] -/// which will produce [`AddrStream`]s for incoming connections. -/// -/// The function takes into account different ways the listener might be set up, -/// either by binding to a provided address or by using an existing standard listener. -/// -/// # Arguments -/// -/// - `listener`: The listener configuration. Can be either a direct bind address or an existing standard listener. -/// - `addr_incoming_conf`: Configuration for the incoming connections, such as TCP keepalive settings. -/// -/// # Errors -/// -/// Returns an `io::Error` if: -/// - Binding the listener fails. -/// - Setting the listener to non-blocking mode fails. -/// - The listener cannot be converted to a [`TcpListener`]. -/// - An error occurs when creating the [`AddrIncoming`]. -/// -async fn bind_incoming( - listener: Listener, - addr_incoming_conf: AddrIncomingConfig, -) -> io::Result { - let listener = match listener { - Listener::Bind(addr) => TcpListener::bind(addr).await?, - Listener::Std(std_listener) => { - std_listener.set_nonblocking(true)?; - TcpListener::from_std(std_listener)? - } - }; - let mut incoming = AddrIncoming::from_listener(listener).map_err(io_other)?; - - // Apply configuration settings to the incoming connection handler. - incoming.set_sleep_on_errors(addr_incoming_conf.tcp_sleep_on_accept_errors); - incoming.set_keepalive(addr_incoming_conf.tcp_keepalive); - incoming.set_keepalive_interval(addr_incoming_conf.tcp_keepalive_interval); - incoming.set_keepalive_retries(addr_incoming_conf.tcp_keepalive_retries); - incoming.set_nodelay(addr_incoming_conf.tcp_nodelay); - - Ok(incoming) -} - -/// Awaits and accepts a new incoming connection. -/// -/// This function will poll the given `incoming` object until a new connection is ready to be accepted. -/// -/// # Arguments -/// -/// - `incoming`: The incoming connection handler from which new connections will be accepted. -/// -/// # Returns -/// -/// Returns the accepted [`AddrStream`] which represents a specific incoming connection. -/// -/// # Panics -/// -/// This function will panic if the `poll_accept` method returns `None`, which should never happen as per the Hyper documentation. -/// -pub(crate) async fn accept(incoming: &mut AddrIncoming) -> io::Result { - let mut incoming = Pin::new(incoming); - - // Always [`Option::Some`]. - // According to: https://docs.rs/hyper/0.14.14/src/hyper/server/tcp.rs.html#165 - poll_fn(|cx| incoming.as_mut().poll_accept(cx)) - .await - .unwrap() -} - -/// Type definition for a boxed error which can be sent between threads and is Sync. -type BoxError = Box; - -/// Converts any error into an `io::Error` of kind `Other`. -/// -/// This function can be used to create a uniform `io::Error` response for various error types. -/// -/// # Arguments -/// -/// - `error`: The error to be converted. -/// -/// # Returns -/// -/// Returns an `io::Error` with the kind set to `Other` and the provided error as its cause. -/// -pub(crate) fn io_other>(error: E) -> io::Error { - io::Error::new(ErrorKind::Other, error) -} - -#[cfg(test)] -mod tests { - use crate::{handle::Handle, server::Server}; - use axum::{routing::get, Router}; - use bytes::Bytes; - use http::{response, Request}; - use hyper::{ - client::conn::{handshake, SendRequest}, - Body, - }; - use std::{io, net::SocketAddr, time::Duration}; - use tokio::{net::TcpStream, task::JoinHandle, time::timeout}; - use tower::{Service, ServiceExt}; - - #[tokio::test] - async fn start_and_request() { - let (_handle, _server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn test_shutdown() { - let (handle, _server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.shutdown(); - - let response_future_result = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await; - - assert!(response_future_result.is_err()); - - // Connection task should finish soon. - let _ = timeout(Duration::from_secs(1), conn).await.unwrap(); - } - - #[tokio::test] - async fn test_graceful_shutdown() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.graceful_shutdown(None); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Disconnect client. - conn.abort(); - - // TODO(This does not shut down gracefully) - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - #[tokio::test] - async fn test_graceful_shutdown_timed() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - handle.graceful_shutdown(Some(Duration::from_millis(250))); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - Server::bind(addr) - .handle(server_handle) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { - let stream = TcpStream::connect(addr).await.unwrap(); - - let (send_request, connection) = handshake(stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task) - } - - async fn send_empty_request(client: &mut SendRequest) -> (response::Parts, Bytes) { - let (parts, body) = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - .unwrap() - .into_parts(); - let body = hyper::body::to_bytes(body).await.unwrap(); - - (parts, body) - } -} diff --git a/src/service.rs b/src/service.rs deleted file mode 100644 index ee8848d..0000000 --- a/src/service.rs +++ /dev/null @@ -1,164 +0,0 @@ -//! Module containing service traits. -//! These traits are vital for handling requests and creating services within the server. - -use http::Response; -use http_body::Body; -use std::{ - future::Future, - task::{Context, Poll}, -}; -use tower_service::Service; - -// TODO(Document the types here to disable the clippy annotation) - -/// An alias trait for the [`Service`] trait, specialized with required bounds for the server's service function. -/// This trait has been sealed, ensuring it cannot be implemented by types outside of this crate. -/// -/// It provides constraints for the body data, errors, and asynchronous behavior that fits the server's needs. -#[allow(missing_docs)] -pub trait SendService: send_service::Sealed { - type Service: Service< - Request, - Response = Response, - Error = Self::Error, - Future = Self::Future, - > + Send - + 'static; - - type Body: Body + Send + 'static; - type BodyData: Send + 'static; - type BodyError: Into>; - - type Error: Into>; - type Future: Future, Self::Error>> + Send + 'static; - - /// Convert this type into a service. - fn into_service(self) -> Self::Service; -} - -impl send_service::Sealed for T -where - T: Service>, - T::Error: Into>, - T::Future: Send + 'static, - B: Body + Send + 'static, - B::Data: Send + 'static, - B::Error: Into>, -{ -} - -impl SendService for T -where - T: Service> + Send + 'static, - T::Error: Into>, - T::Future: Send + 'static, - B: Body + Send + 'static, - B::Data: Send + 'static, - B::Error: Into>, -{ - type Service = T; - - type Body = B; - type BodyData = B::Data; - type BodyError = B::Error; - - type Error = T::Error; - - type Future = T::Future; - - fn into_service(self) -> Self::Service { - self - } -} - -/// A variant of the [`MakeService`] trait that accepts a `&Target` reference. -/// This trait has been sealed, ensuring it cannot be implemented by types outside of this crate. -/// It is specifically designed for the server's `serve` function. -/// -/// This trait provides a mechanism to create services upon request, with the required trait bounds. -/// -/// [`MakeService`]: https://docs.rs/tower/0.4/tower/make/trait.MakeService.html -#[allow(missing_docs)] -pub trait MakeServiceRef: make_service_ref::Sealed<(Target, Request)> { - type Service: Service< - Request, - Response = Response, - Error = Self::Error, - Future = Self::Future, - > + Send - + 'static; - - type Body: Body + Send + 'static; - type BodyData: Send + 'static; - type BodyError: Into>; - - type Error: Into>; - type Future: Future, Self::Error>> + Send + 'static; - - type MakeError: Into>; - type MakeFuture: Future>; - - /// Polls to check if the service factory is ready to create a service. - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll>; - - /// Creates and returns a service for the provided target. - fn make_service(&mut self, target: &Target) -> Self::MakeFuture; -} - -impl make_service_ref::Sealed<(Target, Request)> for T -where - T: for<'a> Service<&'a Target, Response = S, Error = E, Future = F>, - S: Service> + Send + 'static, - S::Error: Into>, - S::Future: Send + 'static, - B: Body + Send + 'static, - B::Data: Send + 'static, - B::Error: Into>, - E: Into>, - F: Future>, -{ -} - -impl MakeServiceRef for T -where - T: for<'a> Service<&'a Target, Response = S, Error = E, Future = F>, - S: Service> + Send + 'static, - S::Error: Into>, - S::Future: Send + 'static, - B: Body + Send + 'static, - B::Data: Send + 'static, - B::Error: Into>, - E: Into>, - F: Future>, -{ - type Service = S; - - type Body = B; - type BodyData = B::Data; - type BodyError = B::Error; - - type Error = S::Error; - - type Future = S::Future; - - type MakeError = E; - type MakeFuture = F; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.poll_ready(cx) - } - - fn make_service(&mut self, target: &Target) -> Self::MakeFuture { - self.call(target) - } -} - -// Sealed traits prevent external implementations of our core traits. -// This provides future compatibility guarantees. -mod send_service { - pub trait Sealed {} -} - -mod make_service_ref { - pub trait Sealed {} -} diff --git a/src/tls_openssl/future.rs b/src/tls_openssl/future.rs deleted file mode 100644 index 54f64b6..0000000 --- a/src/tls_openssl/future.rs +++ /dev/null @@ -1,183 +0,0 @@ -//! Future types. -//! -//! This module provides the futures and supporting types for integrating OpenSSL with a hyper/tokio HTTP/TLS server. -//! `OpenSSLAcceptorFuture` is the main public-facing type which wraps around the core logic of establishing an SSL/TLS -//! connection. - -use super::OpenSSLConfig; -use pin_project_lite::pin_project; -use std::io::{Error, ErrorKind}; -use std::time::Duration; -use std::{ - fmt, - future::Future, - io, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::time::{timeout, Timeout}; - -use openssl::ssl::Ssl; -use tokio_openssl::SslStream; - -// The OpenSSLAcceptorFuture encapsulates the asynchronous logic of accepting an SSL/TLS connection. -pin_project! { - /// A Future for establishing an SSL/TLS connection using `OpenSSLAcceptor`. - /// - /// This wraps around the process of asynchronously establishing an SSL/TLS connection via OpenSSL. - /// It waits for the inner non-TLS connection to be established, and then handles the TLS handshake. - pub struct OpenSSLAcceptorFuture { - #[pin] - inner: AcceptFuture, // Inner future which manages the state machine of accepting connections. - config: Option, // The SSL/TLS configuration to use for the handshake. - } -} - -impl OpenSSLAcceptorFuture { - /// Constructs a new `OpenSSLAcceptorFuture`. - /// - /// # Arguments - /// - `future`: The initial future that handles the non-TLS accept phase. - /// - `config`: SSL/TLS configuration. - /// - `handshake_timeout`: Maximum duration allowed for the TLS handshake. - pub(crate) fn new(future: F, config: OpenSSLConfig, handshake_timeout: Duration) -> Self { - let inner = AcceptFuture::InnerAccepting { - future, - handshake_timeout, - }; - let config = Some(config); - - Self { inner, config } - } -} - -impl fmt::Debug for OpenSSLAcceptorFuture { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpenSSLAcceptorFuture").finish() - } -} - -// A future for performing the SSL/TLS handshake using an `SslStream`. -pin_project! { - struct TlsAccept { - #[pin] - tls_stream: Option>, // The SSL/TLS stream on which the handshake will be performed. - } -} - -impl Future for TlsAccept -where - I: AsyncRead + AsyncWrite + Unpin, // The inner type must support asynchronous reading and writing. -{ - type Output = io::Result>; // The result will be an `SslStream` if the handshake is successful. - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - match this - .tls_stream - .as_mut() - .as_pin_mut() - .map(|inner| inner.poll_accept(cx)) - .expect("tlsaccept polled after ready") - { - Poll::Ready(Ok(())) => { - let tls_stream = this.tls_stream.take().expect("tls stream vanished?"); - Poll::Ready(Ok(tls_stream)) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(Error::new(ErrorKind::Other, e))), - Poll::Pending => Poll::Pending, - } - } -} - -// Enumerates the possible states of the accept process, either waiting for the inner non-TLS -// connection to be accepted, or performing the TLS handshake. -pin_project! { - #[project = AcceptFutureProj] - enum AcceptFuture { - // Waiting for the non-TLS connection to be accepted. - InnerAccepting { - #[pin] - future: F, // The future representing the non-TLS accept phase. - handshake_timeout: Duration, // Maximum duration for the TLS handshake. - }, - // Performing the TLS handshake. - TlsAccepting { - #[pin] - future: Timeout>, // Future that represents the TLS handshake, with a timeout. - service: Option, // The underlying service that will handle the request after the TLS handshake. - } - } -} - -// Main implementation of the future for `OpenSSLAcceptor`. -impl Future for OpenSSLAcceptorFuture -where - F: Future>, // The initial non-TLS accept future. - I: AsyncRead + AsyncWrite + Unpin, // The inner type must support asynchronous reading and writing. -{ - type Output = io::Result<(SslStream, S)>; // The output will be an `SslStream` and the service to handle the request. - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - // The inner future here is what is doing the lower level accept, such as - // our tcp socket. - // - // So we poll on that first, when it's ready we then swap our the inner future to - // one waiting for our ssl layer to accept/install. - // - // Then once that's ready we can then wrap and provide the SslStream back out. - - // This loop exists to allow the Poll::Ready from InnerAccept on complete - // to re-poll immediately. Otherwise all other paths are immediate returns. - loop { - match this.inner.as_mut().project() { - AcceptFutureProj::InnerAccepting { - future, - handshake_timeout, - } => match future.poll(cx) { - Poll::Ready(Ok((stream, service))) => { - let server_config = this.config.take().expect( - "config is not set. this is a bug in hyper-server, please report", - ); - - // Change to poll::ready(err) - let ssl = Ssl::new(server_config.acceptor.context()).unwrap(); - - let tls_stream = SslStream::new(ssl, stream).unwrap(); - let future = TlsAccept { - tls_stream: Some(tls_stream), - }; - - let service = Some(service); - let handshake_timeout = *handshake_timeout; - - this.inner.set(AcceptFuture::TlsAccepting { - future: timeout(handshake_timeout, future), - service, - }); - // the loop is now triggered to immediately poll on - // ssl stream accept. - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, - }, - - AcceptFutureProj::TlsAccepting { future, service } => match future.poll(cx) { - Poll::Ready(Ok(Ok(stream))) => { - let service = service.take().expect("future polled after ready"); - return Poll::Ready(Ok((stream, service))); - } - Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)), - Poll::Ready(Err(timeout)) => { - return Poll::Ready(Err(Error::new(ErrorKind::TimedOut, timeout))) - } - Poll::Pending => return Poll::Pending, - }, - } - } - } -} diff --git a/src/tls_openssl/mod.rs b/src/tls_openssl/mod.rs deleted file mode 100644 index f108517..0000000 --- a/src/tls_openssl/mod.rs +++ /dev/null @@ -1,412 +0,0 @@ -//! Tls implementation using [`openssl`] -//! -//! # Example -//! -//! ```rust,no_run -//! use axum::{routing::get, Router}; -//! use hyper_server::tls_openssl::OpenSSLConfig; -//! use std::net::SocketAddr; -//! -//! #[tokio::main] -//! async fn main() { -//! let app = Router::new().route("/", get(|| async { "Hello, world!" })); -//! -//! let config = OpenSSLConfig::from_pem_file( -//! "examples/self-signed-certs/cert.pem", -//! "examples/self-signed-certs/key.pem", -//! ) -//! .unwrap(); -//! -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! println!("listening on {}", addr); -//! hyper_server::bind_openssl(addr, config) -//! .serve(app.into_make_service()) -//! .await -//! .unwrap(); -//! } -//! ``` - -use self::future::OpenSSLAcceptorFuture; -use crate::{ - accept::{Accept, DefaultAcceptor}, - server::Server, -}; -use openssl::ssl::{ - Error as OpenSSLError, SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod, -}; -use std::{convert::TryFrom, fmt, net::SocketAddr, path::Path, sync::Arc, time::Duration}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_openssl::SslStream; - -pub mod future; - -/// Binds a TLS server using OpenSSL to the specified address with the given configuration. -/// -/// The server is configured to accept TLS encrypted connections. -/// -/// # Arguments -/// -/// * `addr`: The address to which the server will bind. -/// * `config`: The TLS configuration for the server. -/// -/// # Returns -/// -/// A configured `Server` instance ready to be run. -pub fn bind_openssl(addr: SocketAddr, config: OpenSSLConfig) -> Server { - let acceptor = OpenSSLAcceptor::new(config); - Server::bind(addr).acceptor(acceptor) -} - -/// Represents a TLS acceptor that uses OpenSSL for cryptographic operations. -/// -/// This structure is used for handling TLS encrypted connections. -/// -/// The acceptor is backed by OpenSSL, and is used to upgrade incoming non-secure connections -/// to secure TLS connections. -/// -/// The default TLS handshake timeout is set to 10 seconds. -#[derive(Clone)] -pub struct OpenSSLAcceptor { - inner: A, - config: OpenSSLConfig, - handshake_timeout: Duration, -} - -impl OpenSSLAcceptor { - /// Constructs a new instance of the OpenSSL acceptor. - /// - /// # Arguments - /// - /// * `config`: Configuration options for the OpenSSL server. - pub fn new(config: OpenSSLConfig) -> Self { - let inner = DefaultAcceptor::new(); - - // Default handshake timeout is 10 seconds. - #[cfg(not(test))] - let handshake_timeout = Duration::from_secs(10); - - // For tests, use a shorter timeout to avoid unnecessary delays. - #[cfg(test)] - let handshake_timeout = Duration::from_secs(1); - - Self { - inner, - config, - handshake_timeout, - } - } - - /// Overrides the default TLS handshake timeout. - /// - /// # Arguments - /// - /// * `val`: The duration to set as the new handshake timeout. - /// - /// # Returns - /// - /// A modified version of the current acceptor with the new timeout value. - pub fn handshake_timeout(mut self, val: Duration) -> Self { - self.handshake_timeout = val; - self - } -} - -impl Accept for OpenSSLAcceptor -where - A: Accept, - A::Stream: AsyncRead + AsyncWrite + Unpin, -{ - type Stream = SslStream; - type Service = A::Service; - type Future = OpenSSLAcceptorFuture; - - /// Handles the incoming stream, initiates a TLS handshake, and upgrades it to a secure connection. - fn accept(&self, stream: I, service: S) -> Self::Future { - let inner_future = self.inner.accept(stream, service); - let config = self.config.clone(); - - OpenSSLAcceptorFuture::new(inner_future, config, self.handshake_timeout) - } -} - -impl fmt::Debug for OpenSSLAcceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpenSSLAcceptor").finish() - } -} - -/// Represents configuration options for an OpenSSL-based server. -/// -/// This configuration is used when constructing a new `OpenSSLAcceptor`. -#[derive(Clone)] -pub struct OpenSSLConfig { - acceptor: Arc, -} - -impl OpenSSLConfig { - /// Creates a new configuration using a PEM formatted certificate and key. - /// - /// # Arguments - /// - /// * `cert`: Path to the PEM-formatted certificate file. - /// * `key`: Path to the PEM-formatted private key file. - /// - /// # Returns - /// - /// A `Result` that contains an `OpenSSLConfig` or an `OpenSSLError`. - pub fn from_pem_file, B: AsRef>( - cert: A, - key: B, - ) -> Result { - let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?; - - tls_builder.set_certificate_file(cert, SslFiletype::PEM)?; - tls_builder.set_private_key_file(key, SslFiletype::PEM)?; - tls_builder.check_private_key()?; - - let acceptor = Arc::new(tls_builder.build()); - - Ok(OpenSSLConfig { acceptor }) - } - - /// Creates a new configuration using a PEM formatted certificate chain and key. - /// - /// # Arguments - /// - /// * `chain`: Path to the PEM-formatted certificate chain file. - /// * `key`: Path to the PEM-formatted private key file. - /// - /// # Returns - /// - /// A `Result` that contains an `OpenSSLConfig` or an `OpenSSLError`. - pub fn from_pem_chain_file, B: AsRef>( - chain: A, - key: B, - ) -> Result { - let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?; - - tls_builder.set_certificate_chain_file(chain)?; - tls_builder.set_private_key_file(key, SslFiletype::PEM)?; - tls_builder.check_private_key()?; - - let acceptor = Arc::new(tls_builder.build()); - - Ok(OpenSSLConfig { acceptor }) - } -} - -impl TryFrom for OpenSSLConfig { - type Error = OpenSSLError; - - /// Constructs [`OpenSSLConfig`] from an [`SslAcceptorBuilder`]. This allows precise - /// control over the settings that will be used by OpenSSL in this server. - /// - /// # Example - /// ``` - /// use hyper_server::tls_openssl::OpenSSLConfig; - /// use openssl::ssl::{SslAcceptor, SslMethod}; - /// use std::convert::TryFrom; - /// - /// #[tokio::main] - /// async fn main() { - /// let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls()) - /// .unwrap(); - /// // Set configurations like set_certificate_chain_file or - /// // set_private_key_file. - /// // let tls_builder.set_ ... ; - - /// let _config = OpenSSLConfig::try_from(tls_builder); - /// } - /// ``` - fn try_from(tls_builder: SslAcceptorBuilder) -> Result { - tls_builder.check_private_key()?; - let acceptor = Arc::new(tls_builder.build()); - Ok(OpenSSLConfig { acceptor }) - } -} - -impl fmt::Debug for OpenSSLConfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpenSSLConfig").finish() - } -} - -#[cfg(test)] -pub(crate) mod tests { - use crate::{ - handle::Handle, - tls_openssl::{self, OpenSSLConfig}, - }; - use axum::{routing::get, Router}; - use bytes::Bytes; - use http::{response, Request}; - use hyper::{ - client::conn::{handshake, SendRequest}, - Body, - }; - use std::{io, net::SocketAddr, time::Duration}; - use tokio::{net::TcpStream, task::JoinHandle, time::timeout}; - use tower::{Service, ServiceExt}; - - use openssl::ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode}; - use std::pin::Pin; - use tokio_openssl::SslStream; - - #[tokio::test] - async fn start_and_request() { - let (_handle, _server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn test_shutdown() { - let (handle, _server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.shutdown(); - - let response_future_result = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await; - - assert!(response_future_result.is_err()); - - // Connection task should finish soon. - let _ = timeout(Duration::from_secs(1), conn).await.unwrap(); - } - - #[tokio::test] - async fn test_graceful_shutdown() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.graceful_shutdown(None); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Disconnect client. - conn.abort(); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - #[tokio::test] - async fn test_graceful_shutdown_timed() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - handle.graceful_shutdown(Some(Duration::from_millis(250))); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Don't disconnect client. - // conn.abort(); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = OpenSSLConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_openssl::bind_openssl(addr, config) - .handle(server_handle) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { - let stream = TcpStream::connect(addr).await.unwrap(); - let tls_stream = tls_connector(dns_name(), stream).await; - - let (send_request, connection) = handshake(tls_stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task) - } - - async fn send_empty_request(client: &mut SendRequest) -> (response::Parts, Bytes) { - let (parts, body) = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - .unwrap() - .into_parts(); - let body = hyper::body::to_bytes(body).await.unwrap(); - - (parts, body) - } - - /// Used in `proxy-protocol` feature tests. - pub(crate) async fn tls_connector(hostname: &str, stream: TcpStream) -> SslStream { - let mut tls_parms = SslConnector::builder(SslMethod::tls_client()).unwrap(); - tls_parms.set_verify(SslVerifyMode::NONE); - let hostname_owned = hostname.to_string(); - tls_parms.set_client_hello_callback(move |ssl_ref, _ssl_alert| { - ssl_ref - .set_hostname(hostname_owned.as_str()) - .map(|()| openssl::ssl::ClientHelloResponse::SUCCESS) - }); - let tls_parms = tls_parms.build(); - - let ssl = Ssl::new(tls_parms.context()).unwrap(); - let mut tls_stream = SslStream::new(ssl, stream).unwrap(); - - SslStream::connect(Pin::new(&mut tls_stream)).await.unwrap(); - - tls_stream - } - - /// Used in `proxy-protocol` feature tests. - pub(crate) fn dns_name() -> &'static str { - "localhost" - } -} diff --git a/src/tls_rustls/future.rs b/src/tls_rustls/future.rs deleted file mode 100644 index 843046f..0000000 --- a/src/tls_rustls/future.rs +++ /dev/null @@ -1,130 +0,0 @@ -//! Module containing futures specific to the `rustls` TLS acceptor for the server. -//! -//! This module primarily provides the `RustlsAcceptorFuture` which is responsible for performing the TLS handshake -//! using the `rustls` library. - -use crate::tls_rustls::RustlsConfig; -use pin_project_lite::pin_project; -use std::io::{Error, ErrorKind}; -use std::time::Duration; -use std::{ - fmt, - future::Future, - io, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::time::{timeout, Timeout}; -use tokio_rustls::{server::TlsStream, Accept, TlsAcceptor}; - -pin_project! { - /// A future representing the asynchronous TLS handshake using the `rustls` library. - /// - /// Once completed, it yields a `TlsStream` which is a wrapper around the actual underlying stream, with - /// encryption and decryption operations applied to it. - pub struct RustlsAcceptorFuture { - #[pin] - inner: AcceptFuture, - config: Option, - } -} - -impl RustlsAcceptorFuture { - /// Constructs a new `RustlsAcceptorFuture`. - /// - /// * `future`: The future that resolves to the original non-encrypted stream. - /// * `config`: The rustls configuration to use for the handshake. - /// * `handshake_timeout`: The maximum duration to wait for the handshake to complete. - pub(crate) fn new(future: F, config: RustlsConfig, handshake_timeout: Duration) -> Self { - let inner = AcceptFuture::Inner { - future, - handshake_timeout, - }; - let config = Some(config); - - Self { inner, config } - } -} - -impl fmt::Debug for RustlsAcceptorFuture { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RustlsAcceptorFuture").finish() - } -} - -pin_project! { - /// Internal states of the handshake process. - #[project = AcceptFutureProj] - enum AcceptFuture { - /// Initial state where we have a future that resolves to the original non-encrypted stream. - Inner { - #[pin] - future: F, - handshake_timeout: Duration, - }, - /// State after receiving the stream where the handshake is performed asynchronously. - Accept { - #[pin] - future: Timeout>, - service: Option, - }, - } -} - -impl Future for RustlsAcceptorFuture -where - F: Future>, - I: AsyncRead + AsyncWrite + Unpin, -{ - type Output = io::Result<(TlsStream, S)>; - - /// Advances the handshake state machine. - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - loop { - match this.inner.as_mut().project() { - AcceptFutureProj::Inner { - future, - handshake_timeout, - } => { - // Poll the future to get the original stream. - match future.poll(cx) { - Poll::Ready(Ok((stream, service))) => { - let server_config = this.config - .take() - .expect("config is not set. this is a bug in hyper-server, please report") - .get_inner(); - - let acceptor = TlsAcceptor::from(server_config); - let future = acceptor.accept(stream); - - let service = Some(service); - let handshake_timeout = *handshake_timeout; - - this.inner.set(AcceptFuture::Accept { - future: timeout(handshake_timeout, future), - service, - }); - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, - } - } - AcceptFutureProj::Accept { future, service } => match future.poll(cx) { - Poll::Ready(Ok(Ok(stream))) => { - let service = service.take().expect("future polled after ready"); - - return Poll::Ready(Ok((stream, service))); - } - Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)), - Poll::Ready(Err(timeout)) => { - return Poll::Ready(Err(Error::new(ErrorKind::TimedOut, timeout))) - } - Poll::Pending => return Poll::Pending, - }, - } - } - } -} diff --git a/src/tls_rustls/mod.rs b/src/tls_rustls/mod.rs deleted file mode 100644 index af7e391..0000000 --- a/src/tls_rustls/mod.rs +++ /dev/null @@ -1,579 +0,0 @@ -//! Tls implementation using [`rustls`]. -//! -//! # Example -//! -//! ```rust,no_run -//! use axum::{routing::get, Router}; -//! use hyper_server::tls_rustls::RustlsConfig; -//! use std::net::SocketAddr; -//! -//! #[tokio::main] -//! async fn main() { -//! let app = Router::new().route("/", get(|| async { "Hello, world!" })); -//! -//! let config = RustlsConfig::from_pem_file( -//! "examples/self-signed-certs/cert.pem", -//! "examples/self-signed-certs/key.pem", -//! ) -//! .await -//! .unwrap(); -//! -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! println!("listening on {}", addr); -//! hyper_server::bind_rustls(addr, config) -//! .serve(app.into_make_service()) -//! .await -//! .unwrap(); -//! } -//! ``` - -use self::future::RustlsAcceptorFuture; -use crate::{ - accept::{Accept, DefaultAcceptor}, - server::{io_other, Server}, -}; -use arc_swap::ArcSwap; -use rustls::{Certificate, PrivateKey, ServerConfig}; -use std::time::Duration; -use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - task::spawn_blocking, -}; -use tokio_rustls::server::TlsStream; - -/// Sub-module that contains re-exported public interfaces. -pub(crate) mod export { - use super::{RustlsAcceptor, RustlsConfig, Server, SocketAddr}; - - /// Creates a TLS server that binds to the provided address using the rustls library. - #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))] - pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server { - super::bind_rustls(addr, config) - } - - /// Creates a TLS server from an existing `std::net::TcpListener` using the rustls library. - #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))] - pub fn from_tcp_rustls( - listener: std::net::TcpListener, - config: RustlsConfig, - ) -> Server { - let acceptor = RustlsAcceptor::new(config); - - Server::from_tcp(listener).acceptor(acceptor) - } -} - -pub mod future; - -/// Helper function to create a TLS server bound to a provided address. -pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server { - let acceptor = RustlsAcceptor::new(config); - - Server::bind(addr).acceptor(acceptor) -} - -/// Helper function to create a TLS server from an existing `std::net::TcpListener`. -pub fn from_tcp_rustls( - listener: std::net::TcpListener, - config: RustlsConfig, -) -> Server { - let acceptor = RustlsAcceptor::new(config); - - Server::from_tcp(listener).acceptor(acceptor) -} - -/// A TLS acceptor implementation using the rustls library. -#[derive(Clone)] -pub struct RustlsAcceptor { - inner: A, - config: RustlsConfig, - handshake_timeout: Duration, -} - -impl RustlsAcceptor { - /// Constructs a new rustls acceptor with the given configuration. - pub fn new(config: RustlsConfig) -> Self { - let inner = DefaultAcceptor::new(); - - // Default handshake timeout is set to 10 seconds. - // In test mode, this is reduced to 1 second to avoid waiting too long. - #[cfg(not(test))] - let handshake_timeout = Duration::from_secs(10); - #[cfg(test)] - let handshake_timeout = Duration::from_secs(1); - - Self { - inner, - config, - handshake_timeout, - } - } - - /// Allows overriding the default TLS handshake timeout. - pub fn handshake_timeout(mut self, val: Duration) -> Self { - self.handshake_timeout = val; - self - } -} - -impl RustlsAcceptor { - /// Replaces the inner acceptor with a custom acceptor. - pub fn acceptor(self, acceptor: Acceptor) -> RustlsAcceptor { - RustlsAcceptor { - inner: acceptor, - config: self.config, - handshake_timeout: self.handshake_timeout, - } - } -} - -// Implementation to accept incoming TLS connections using rustls. -impl Accept for RustlsAcceptor -where - A: Accept, - A::Stream: AsyncRead + AsyncWrite + Unpin, -{ - type Stream = TlsStream; - type Service = A::Service; - type Future = RustlsAcceptorFuture; - - fn accept(&self, stream: I, service: S) -> Self::Future { - let inner_future = self.inner.accept(stream, service); - let config = self.config.clone(); - - RustlsAcceptorFuture::new(inner_future, config, self.handshake_timeout) - } -} - -impl fmt::Debug for RustlsAcceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RustlsAcceptor").finish() - } -} - -/// Represents the rustls configuration for the server. -#[derive(Clone)] -pub struct RustlsConfig { - inner: Arc>, -} - -// The `RustlsConfig` structure represents configuration data for rustls. -impl RustlsConfig { - /// Create a new `RustlsConfig` from an `Arc`. - /// - /// Important: This method does not set ALPN protocols (like `http/1.1` or `h2`) automatically. - /// ALPN protocols need to be set manually when using this method. - pub fn from_config(config: Arc) -> Self { - let inner = Arc::new(ArcSwap::new(config)); - Self { inner } - } - - /// Create a `RustlsConfig` from DER-encoded data. - /// DER is a binary format for encoding data, commonly used for certificates and keys. - /// - /// `cert` is expected to be a DER-encoded X.509 certificate. - /// `key` is expected to be a DER-encoded ASN.1 format private key, either in PKCS#8 or PKCS#1 format. - pub async fn from_der(cert: Vec>, key: Vec) -> io::Result { - let server_config = spawn_blocking(|| config_from_der(cert, key)) - .await - .unwrap()?; - let inner = Arc::new(ArcSwap::from_pointee(server_config)); - Ok(Self { inner }) - } - - /// Create a `RustlsConfig` from PEM-formatted data. - /// PEM is a text-based format used to encode binary data like certificates and keys. - /// - /// Both `cert` and `key` must be provided in PEM format. - pub async fn from_pem(cert: Vec, key: Vec) -> io::Result { - let server_config = spawn_blocking(|| config_from_pem(cert, key)) - .await - .unwrap()?; - let inner = Arc::new(ArcSwap::from_pointee(server_config)); - Ok(Self { inner }) - } - - /// Create a `RustlsConfig` by reading PEM-formatted files. - /// - /// The contents of the provided certificate and private key files must be in PEM format. - pub async fn from_pem_file(cert: impl AsRef, key: impl AsRef) -> io::Result { - let server_config = config_from_pem_file(cert, key).await?; - let inner = Arc::new(ArcSwap::from_pointee(server_config)); - Ok(Self { inner }) - } - - /// Retrieve the inner `Arc` from the `RustlsConfig`. - pub fn get_inner(&self) -> Arc { - self.inner.load_full() - } - - /// Update (or reload) the `RustlsConfig` with a new `Arc`. - pub fn reload_from_config(&self, config: Arc) { - self.inner.store(config); - } - - /// Reload the `RustlsConfig` from provided DER-encoded data. - /// - /// As with the `from_der` method, `cert` must be DER-encoded X.509 and `key` - /// should be in either PKCS#8 or PKCS#1 DER-encoded ASN.1 format. - pub async fn reload_from_der(&self, cert: Vec>, key: Vec) -> io::Result<()> { - let server_config = spawn_blocking(|| config_from_der(cert, key)) - .await - .unwrap()?; - let inner = Arc::new(server_config); - self.inner.store(inner); - Ok(()) - } - - /// Reload the `RustlsConfig` using provided PEM-formatted data. - pub async fn reload_from_pem(&self, cert: Vec, key: Vec) -> io::Result<()> { - let server_config = spawn_blocking(|| config_from_pem(cert, key)) - .await - .unwrap()?; - let inner = Arc::new(server_config); - self.inner.store(inner); - Ok(()) - } - - /// Reload the `RustlsConfig` from provided PEM-formatted files. - pub async fn reload_from_pem_file( - &self, - cert: impl AsRef, - key: impl AsRef, - ) -> io::Result<()> { - let server_config = config_from_pem_file(cert, key).await?; - let inner = Arc::new(server_config); - self.inner.store(inner); - Ok(()) - } -} - -// This provides a debug representation for the `RustlsConfig`. -impl fmt::Debug for RustlsConfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RustlsConfig").finish() - } -} - -// Helper function to convert DER-encoded certificate and key into rustls's `ServerConfig`. -fn config_from_der(cert: Vec>, key: Vec) -> io::Result { - // Convert the raw bytes into rustls's Certificate and PrivateKey structures. - let cert = cert.into_iter().map(Certificate).collect(); - let key = PrivateKey(key); - - // Construct the ServerConfig. - let mut config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(cert, key) - .map_err(io_other)?; - - // Set ALPN protocols for the configuration. - config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - Ok(config) -} - -// Helper function to convert PEM-formatted certificate and key into rustls' `ServerConfig`. -fn config_from_pem(cert: Vec, key: Vec) -> io::Result { - use rustls_pemfile::Item; - - // Parse PEM formatted data into rustls structures. - let cert = rustls_pemfile::certs(&mut cert.as_ref())?; - let key = match rustls_pemfile::read_one(&mut key.as_ref())? { - Some(Item::RSAKey(key)) | Some(Item::PKCS8Key(key)) | Some(Item::ECKey(key)) => key, - _ => return Err(io_other("private key format not supported")), - }; - - config_from_der(cert, key) -} - -// Helper function to read PEM-formatted files and convert them into rustls' ServerConfig. -async fn config_from_pem_file( - cert: impl AsRef, - key: impl AsRef, -) -> io::Result { - // Read the PEM files asynchronously. - let cert = tokio::fs::read(cert.as_ref()).await?; - let key = tokio::fs::read(key.as_ref()).await?; - - config_from_pem(cert, key) -} - -#[cfg(test)] -pub(crate) mod tests { - use crate::{ - handle::Handle, - tls_rustls::{self, RustlsConfig}, - }; - use axum::{routing::get, Router}; - use bytes::Bytes; - use http::{response, Request}; - use hyper::{ - client::conn::{handshake, SendRequest}, - Body, - }; - use rustls::{ - client::{ServerCertVerified, ServerCertVerifier}, - Certificate, ClientConfig, ServerName, - }; - use std::{ - convert::TryFrom, - io, - net::SocketAddr, - sync::Arc, - time::{Duration, SystemTime}, - }; - use tokio::time::sleep; - use tokio::{net::TcpStream, task::JoinHandle, time::timeout}; - use tokio_rustls::TlsConnector; - use tower::{Service, ServiceExt}; - - #[tokio::test] - async fn start_and_request() { - let (_handle, _server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn tls_timeout() { - let (handle, _server_task, addr) = start_server().await; - assert_eq!(handle.connection_count(), 0); - - // We intentionally avoid driving a TLS handshake to completion. - let _stream = TcpStream::connect(addr).await.unwrap(); - - sleep(Duration::from_millis(500)).await; - assert_eq!(handle.connection_count(), 1); - - tokio::time::sleep(Duration::from_millis(1000)).await; - // Timeout defaults to 1s during testing, and we have waited 1.5 seconds. - assert_eq!(handle.connection_count(), 0); - } - - #[tokio::test] - async fn test_reload() { - let handle = Handle::new(); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let server_handle = handle.clone(); - let rustls_config = config.clone(); - tokio::spawn(async move { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_rustls::bind_rustls(addr, rustls_config) - .handle(server_handle) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - let cert_a = get_first_cert(addr).await; - let mut cert_b = get_first_cert(addr).await; - - assert_eq!(cert_a, cert_b); - - config - .reload_from_pem_file( - "examples/self-signed-certs/reload/cert.pem", - "examples/self-signed-certs/reload/key.pem", - ) - .await - .unwrap(); - - cert_b = get_first_cert(addr).await; - - assert_ne!(cert_a, cert_b); - - config - .reload_from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - cert_b = get_first_cert(addr).await; - - assert_eq!(cert_a, cert_b); - } - - #[tokio::test] - async fn test_shutdown() { - let (handle, _server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.shutdown(); - - let response_future_result = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await; - - assert!(response_future_result.is_err()); - - // Connection task should finish soon. - let _ = timeout(Duration::from_secs(1), conn).await.unwrap(); - } - - #[tokio::test] - async fn test_graceful_shutdown() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.graceful_shutdown(None); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Disconnect client. - conn.abort(); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - #[tokio::test] - async fn test_graceful_shutdown_timed() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - handle.graceful_shutdown(Some(Duration::from_millis(250))); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Don't disconnect client. - // conn.abort(); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await?; - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_rustls::bind_rustls(addr, config) - .handle(server_handle) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - async fn get_first_cert(addr: SocketAddr) -> Certificate { - let stream = TcpStream::connect(addr).await.unwrap(); - let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap(); - - let (_io, client_connection) = tls_stream.into_inner(); - - client_connection.peer_certificates().unwrap()[0].clone() - } - - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { - let stream = TcpStream::connect(addr).await.unwrap(); - let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap(); - - let (send_request, connection) = handshake(tls_stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task) - } - - async fn send_empty_request(client: &mut SendRequest) -> (response::Parts, Bytes) { - let (parts, body) = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - .unwrap() - .into_parts(); - let body = hyper::body::to_bytes(body).await.unwrap(); - - (parts, body) - } - - /// Used in `proxy-protocol` feature tests. - pub(crate) fn tls_connector() -> TlsConnector { - struct NoVerify; - - impl ServerCertVerifier for NoVerify { - fn verify_server_cert( - &self, - _end_entity: &Certificate, - _intermediates: &[Certificate], - _server_name: &ServerName, - _scts: &mut dyn Iterator, - _ocsp_response: &[u8], - _now: SystemTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - } - - let mut client_config = ClientConfig::builder() - .with_safe_defaults() - .with_custom_certificate_verifier(Arc::new(NoVerify)) - .with_no_client_auth(); - - client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - TlsConnector::from(Arc::new(client_config)) - } - - /// Used in `proxy-protocol` feature tests. - pub(crate) fn dns_name() -> ServerName { - ServerName::try_from("localhost").unwrap() - } -} From cc0d0dc9df4662d19217558c0956938200f2aac7 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 07:49:40 -0400 Subject: [PATCH 02/45] chore: scaffold dependencies --- Cargo.toml | 13 ++++++++++++- src/lib.rs | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index fa1eb50..0237a86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,4 +11,15 @@ readme = "README.md" repository = "https://github.com/valorem-labs-inc/hyper-server" version = "1.0.0" -[dependencies] \ No newline at end of file +[dependencies] +http = "1.1.0" +http-body-util = "0.1.2" +hyper = "1.4.1" +hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful"] } +rustls = "0.23.13" +tokio = { version = "1.40.0", features = ["net"] } +tokio-rustls = "0.26.0" +tower = "0.5.1" + +[dev-dependencies] +tokio = { version = "1.40.0", features = ["macros"] } diff --git a/src/lib.rs b/src/lib.rs index e69de29..62e11da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -0,0 +1,32 @@ +// Crate links: +// - tokio: https://docs.rs/tokio/latest/tokio/ +// - rustls: https://docs.rs/rustls/latest/rustls/ +// - tokio_rustls:https://docs.rs/tokio-rustls/latest/tokio_rustls/ +// - hyper: https://docs.rs/hyper/latest/hyper/ +// - tower: https://docs.rs/tower/latest/tower/ + +// We take a `SocketAddr` to bind the server to a specific address. +// We use `TcpListener` to bind the server to the specified address at the TCP layer. +// We use `rustls` to create a new `ServerConfig` instance. +// We use `tokio_rustls` to create a new `TlsAcceptor` instance. +// We use `hyper_util::server::conn::auto` to create a new `Connection` instance. +// Which then passes requests to a `hyper::service::Service` instance. +// Which then can optionally pass requests to a `tower::Service` instance. +// Behind that can be axum, tower, tonic, +// or any other service that implements the `tower::Service` trait. + +pub struct HyperServer {} + +#[cfg(test)] +mod tests { + use std::net::SocketAddr; + use tokio::net::TcpListener; + + #[tokio::test] + async fn test_server() { + // Get a random port from the OS + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + // Create a TCP listener bound to the random address + let listener = TcpListener::bind(&addr).await.unwrap(); + } +} From f5c8c6088344bbf0e36f577190dc95c5685cd59a Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 08:21:03 -0400 Subject: [PATCH 03/45] feat: working impl of a tower lambda service/server --- Cargo.toml | 6 +- examples/sample.pem | 79 ++++++++++++++++++++ examples/sample.rsa | 27 +++++++ src/lib.rs | 177 +++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 278 insertions(+), 11 deletions(-) create mode 100644 examples/sample.pem create mode 100644 examples/sample.rsa diff --git a/Cargo.toml b/Cargo.toml index 0237a86..3ce4cb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,11 +15,13 @@ version = "1.0.0" http = "1.1.0" http-body-util = "0.1.2" hyper = "1.4.1" -hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful"] } +hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } rustls = "0.23.13" +rustls-pemfile = "2.1.3" tokio = { version = "1.40.0", features = ["net"] } tokio-rustls = "0.26.0" -tower = "0.5.1" +tower = { version = "0.5.1", features = ["util"] } +tracing = "0.1.40" [dev-dependencies] tokio = { version = "1.40.0", features = ["macros"] } diff --git a/examples/sample.pem b/examples/sample.pem new file mode 100644 index 0000000..9e8bc64 --- /dev/null +++ b/examples/sample.pem @@ -0,0 +1,79 @@ +-----BEGIN CERTIFICATE----- +MIIEADCCAmigAwIBAgICAcgwDQYJKoZIhvcNAQELBQAwLDEqMCgGA1UEAwwhcG9u +eXRvd24gUlNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIyMDcwNDE0MzA1OFoX +DTI3MTIyNTE0MzA1OFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDL35qLQLIqswCmHJxyczYF2p0YxXCq +gMvtRcKVElnifPMFrbGCY1aYBmhIiXPGRwhfythAtYfDQsrXFADZd52JPgZCR/u6 +DQMqKD2lcvFQkf7Kee/fNTOuQTQPh1XQx4ntxvicSATwEnuU28NwVnOU//Zzq2xn +Q34gUQNHWp1pN+B1La7emm/Ucgs1/2hMxwCZYUnRoiUoRGXUSzZuWokDOstPNkjc ++AjHmxONgowogmL2jKN9BjBw/8psGoqEOjMO+Lb9iekOCzX4kqHaRUbTlbSAviQu +2Q115xiZCBCZVtNE6DUG25buvpMSEXwpLd96nLywbrSCyueC7cd01/hpAgMBAAGj +gb4wgbswDAYDVR0TAQH/BAIwADALBgNVHQ8EBAMCBsAwHQYDVR0OBBYEFHGnzC5Q +A62Wmv4zfMk/kf/BxHevMEIGA1UdIwQ7MDmAFDMRUvwxXbYDBCxOdQ9xfBnNWUz0 +oR6kHDAaMRgwFgYDVQQDDA9wb255dG93biBSU0EgQ0GCAXswOwYDVR0RBDQwMoIO +dGVzdHNlcnZlci5jb22CFXNlY29uZC50ZXN0c2VydmVyLmNvbYIJbG9jYWxob3N0 +MA0GCSqGSIb3DQEBCwUAA4IBgQBqKNIM/JBGRmGEopm5/WNKV8UoxKPA+2jR020t +RumXMAnJEfhsivF+Zw/rDmSDpmts/3cIlesKi47f13q4Mfj1QytQUDrsuQEyRTrV +Go6BOQQ4dkS+IqnIfSuue70wpvrZHhRHNFdFt9qM5wCLQokXlP988sEWUmyPPCbO +1BEpwWcP1kx+PdY8NKOhMnfq2RfluI/m4MA4NxJqAWajAhIbDNbvP8Ov4a71HPa6 +b1q9qIQE1ut8KycTrm9K32bVbvMHvR/TPUue8W0VvV2rWTGol5TSNgEQb9w6Kyf7 +N5HlRl9kZB4K8ckWH/JVn0pYNBQPgwbcUbJ/jp6w+LHrh+UW75maOY+IGjVICud8 +6Rc5DZZ2+AAbXJQZ1HMPrw9SW/16Eh/A4CIEsvbu9J+7IoSzhgcKFzOCUojzzRSj +iU7w/HsvpltmVCAZcZ/VARFbe1By2wXX2GSw2p2FVC8orXs76QyruPAVgSHCTVes +zzBo6GLScO/3b6uAcPM3MHRGGvE= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIEnzCCAoegAwIBAgIBezANBgkqhkiG9w0BAQsFADAaMRgwFgYDVQQDDA9wb255 +dG93biBSU0EgQ0EwHhcNMjIwNzA0MTQzMDU4WhcNMzIwNzAxMTQzMDU4WjAsMSow +KAYDVQQDDCFwb255dG93biBSU0EgbGV2ZWwgMiBpbnRlcm1lZGlhdGUwggGiMA0G +CSqGSIb3DQEBAQUAA4IBjwAwggGKAoIBgQCsTkd2SKiy3yy20lygOhKfOySo3qpq +TZVrpW11vQ58+6EcetXRnzIIK0HyhPmZrv9XKPpQclJvfY9jADNtu2CSj/v15OSB +Love3GzmXSZz2A8QUZBPWx6HczDG1hFGzrCZPKzpeLnFD1LPsKCUkUOHl1acyy24 +DaCacQJPzPQWbMhbGmYRlDNb+2R2K6UKMAEVe4IOTv2aSIKDGLI+xlaBXYAJj48L +//9eNmR3bMP3kkNKOKaaBk8vnYxKpZ+8ZHeHTmYWR9x1ZoMcbA9lKUwRpKAjY5JJ +NVZMDmjlVQVvvBrvhgz/zgXtfuaQCryZ0f1sEY/zXhdealo3fGVomeoniD4XwA1c +oaUFkbo5IM5HU/pXyAGRerDyhYLgRqQZMIRauvKRPN3jLsPOEQ0+gnXUUTr/YGIE +KY3/Axg4P3hzZCFqJ5IgkgWZr/dKr9p/0cxSUGHTVcpEFOlkKIIIdRuR7Ng5sJml +u7PAMWt6T+x02ORs1/WkyP7LyPQmuugYTicCAwEAAaNeMFwwHQYDVR0OBBYEFDMR +UvwxXbYDBCxOdQ9xfBnNWUz0MCAGA1UdJQEB/wQWMBQGCCsGAQUFBwMBBggrBgEF +BQcDAjAMBgNVHRMEBTADAQH/MAsGA1UdDwQEAwIB/jANBgkqhkiG9w0BAQsFAAOC +AgEAYzqmX+cNPgVD2HWgbeimUraTpI9JP5P4TbOHWmaJKecoy3Hwr71xyAOGiVXL +urk1ZZe8n++GwuDEgRajN3HO9LR1Pu9qVIzTYIsz0ORRQHxujnF7CxK/I/vrIgde +pddUdHNS0Y0g8J1emH9BgoD8a2YsGX4iDY4S4hIGBbGvQp9z8U/uG1mViAmlXybM +b8bf0dx0tEFUyu8jsQP6nFLY/HhkEcvU6SnOzZHRsFko6NE44VIsHLd2+LS2LCM/ +NfAoTzgvj41M3zQCZapaHZc9KXfdcCvEFaySKGfEZeQTUR5W0FHsF5I4NLGryf5L +h3ENQ1tgBTO5WnqL/5rbgv6va9VionPM5sbEwAcancejnkVs3NoYPIPPgBFjaFmL +hNTpT9H2owdZvEwNDChVS0b8ukNNd4cERtvy0Ohc3mk0LGN0ABzrud0fIqa51LMh +0N3dkPkiZ4XYk4yLJ5EwCrCNNH50QkGCOWpInKIPeSYcALGgBDbCDLv6rV3oSKrV +tHCZQwXVKKgU4AQu7hlHBwJ61cH44ksydOidW3MNq1kDIp7ST8s7gVrItNgFnG+L +Jpo270riwSUlWDY4hXw5Ff5lE+bWCmFyyOkLevDkD9v8M4HdwEVvafYYwn75fCIS +5OnpSeIB08kKqCtW1WBwki0rYJjWqdzI7Z1MQ/AyScAKiGM= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIEsDCCApgCCQCfkxy3a+AgNjANBgkqhkiG9w0BAQsFADAaMRgwFgYDVQQDDA9w +b255dG93biBSU0EgQ0EwHhcNMjIwNzA0MTQzMDU3WhcNMzIwNzAxMTQzMDU3WjAa +MRgwFgYDVQQDDA9wb255dG93biBSU0EgQ0EwggIiMA0GCSqGSIb3DQEBAQUAA4IC +DwAwggIKAoICAQCj6nW8pnN50UsH2NjL97xZKxlXPe5ptXfvqXczMsw0vB3gI4xJ +Tdmrnqo0K+VOH7vh+UXcOj2ZMY2ou6oDDK5Qpu9bvGPBIJH/rC1Ti2+u5Y4KTIUc +jWAtzQJeFn8+oCMfskpLdtlWLRdAuwqNHjvxXdd2JnsX1Wid85U/rG2SNPLGjJAF +xG7xzZC4VSO2WIXTGRMUkZfFc8fhWMjo3GaeF/qYjzfHDPWN/ll/7vfxyXJO/ohw +FzpJSZtKmI+6PLxqB/oFrKfTDQUGzxjfHp187bI3eyUFMJsp18/tLYkLyxSWIg3o +bq7ZVimHd1UG2Vb5Y+5pZkh22jmJ6bAa/kmNNwbsD+5vJhW1myGhmZSxkreYPWnS +6ELrSMvbXccFfTYmdBlWsZx/zUVUzVCPe9jdJki2VXlicohqtvBQqe6LGGO37vvv +Gwu1yzQ/rJy47rnaao7fSxqM8nsDjNR2Ev1v031QpEMWjfgUW0roW3H58RZSx+kU +gzIS2CjJIqKxCp894FUQbC6r0wwAuKltl3ywz5qWkxY0O9bXS0YdEXiri5pdsWjr +84shVVQwnoVD9539CLSdHZjlOCAzvSWHZH6ta2JZjUfYYz8cLyv2c2+y9BYrlvHw +T7U7BqzngUk72gcRXd5+Onp+16gGxpGJqaxqj94Nh/yTUnr2Jd9YaXeFmQIDAQAB +MA0GCSqGSIb3DQEBCwUAA4ICAQBzIRVRt3Yaw60tpkyz/i1xbKCbtC+HqYTEsXvZ +RvZ5X1qyLAcmu4EW9RHXnlLiawDbES6lCMFfdBUK03Wis7socvoFUCBRW337F4z2 +IivHfIge4u+w5ouUKPzcpj6oeuR06tmNytYbno6l8tXJpm1eeO4KNZ0ZtodmyB5D +yLrplFgxTdGGgyvxt8LoeLwGmPCyVt35x/Mz6x2lcq1+r7QJZ9sENhQYuA8UqHrw +fmNoVIMXMEcPLcWtFl6nKTK9LrqAu1jgTBqGGZKRn5CYBBK3pNEGKiOIsZXDbyFS +F59teFpJjyeJTbUbLxXDa15J6ExkHV9wFLEvfu/nzQzg8D9yzczSdbDkE2rrrL+s +Q/H/pIXO/DesCWQ37VALn3B5gm9UBd5uogbSw8eamiwRFLQ0snP80pJQGJoTNn0P +wrLLUf2gsKC2262igiA+imepm5wxbV9XGVZfHJgxCi5Zqrf6aWnjIqD2YtDvAHhs +V8ZWN3QTjdnEcQbG0544rocoLNX/FzmyDgjfZKY5r6wt+FWNc/R4clkF+KxasxqB +HdBs8j0lGV3ujvNXASLq9HI6VxZayrSfkR73hADCXIM/wzynKwMarvA4SXwYX9Pd +cJ4+FMqrevPpamMHUsNndS0KfDTdjDp+TSBf87yiyRkD1Ri4ePslyfNvRyv3Xs7k +47YFzA== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/examples/sample.rsa b/examples/sample.rsa new file mode 100644 index 0000000..d1f178a --- /dev/null +++ b/examples/sample.rsa @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAy9+ai0CyKrMAphyccnM2BdqdGMVwqoDL7UXClRJZ4nzzBa2x +gmNWmAZoSIlzxkcIX8rYQLWHw0LK1xQA2XediT4GQkf7ug0DKig9pXLxUJH+ynnv +3zUzrkE0D4dV0MeJ7cb4nEgE8BJ7lNvDcFZzlP/2c6tsZ0N+IFEDR1qdaTfgdS2u +3ppv1HILNf9oTMcAmWFJ0aIlKERl1Es2blqJAzrLTzZI3PgIx5sTjYKMKIJi9oyj +fQYwcP/KbBqKhDozDvi2/YnpDgs1+JKh2kVG05W0gL4kLtkNdecYmQgQmVbTROg1 +BtuW7r6TEhF8KS3fepy8sG60gsrngu3HdNf4aQIDAQABAoIBAFTehqVFj2W7EqAT +9QSn9WtGcHNpbddsunfRvIj2FLj2LuzEO8r9s4Sh1jOsFKgL1e6asJ9vck7UtUAH +sbrV0pzZVx2sfZwb4p9gFRmU2eQigqCjVjnjGdqGhjeYrR62kjKLy96zFGskJpH3 +UkqnkoIKc/v+9qeeLxkg4G6JyFGOFHJAZEraxoGydJk9n/yBEZ/+3W7JUJaGOUNU +M7BYsCS2VOJr+cCqmCk1j8NvYvWWxTPsIXgGJl4EOoskzlzJnYLdh9fPFZu3uOIx +hpm3DBNp6X+qXf1lmx9EdpyeXKpLFIgJM7+nw2uWzxW7XMlRERi+5Tprc/pjrqUq +gpfyvMkCgYEA909QcJpS3qHoWyxGbI1zosVIZXdnj8L+GF/2kEQEU5iEYT+2M1U+ +gCPLr49gNwkD1FdBSCy+Fw20zi35jGmxNwhgp4V94CGYzqwQzpnvgIRBMiAIoEwI +CD5/t34DZ/82u8Gb7UYVrzOD54rJ628Q+tJEJak3TqoShbvcxJC/rXMCgYEA0wmO +SRoxrBE3rFzNQkqHbMHLe9LksW9YSIXdMBjq4DhzQEwI0YgPLajXnsLurqHaJrQA +JPtYkqiJkV7rvJLBo5wxwU+O2JKKa2jcMwuCZ4hOg5oBfK6ES9QJZUL7kDe2vsWy +rL+rnxJheUjDPBTopGHuuc9Nogid35CE0wy7S7MCgYArxB+KLeVofOKv79/uqgHC +1oL/Yegz6uAo1CLAWSki2iTjSPEnmHhdGPic8xSl6LSCyYZGDZT+Y3CR5FT7YmD4 +SkVAoEEsfwWZ3Z2D0n4uEjmvczfTlmD9hIH5qRVVPDcldxfvH64KuWUofslJHvi0 +Sq3AtHeTNknc3Ogu6SbivQKBgQC4ZAsMWHS6MTkBwvwdRd1Z62INyNDFL9JlW4FN +uxfN3cTlkwnJeiY48OOk9hFySDzBwFi3910Gl3fLqrIyy8+hUqIuk4LuO+vxuWdc +uluwdmqTlgZimGFDl/q1nXcMJYHo4fgh9D7R+E9ul2Luph43MtJRS447W2gFpNJJ +TUCA/QKBgQC07GFP2BN74UvL12f+FpZvE/UFtWnSZ8yJSq8oYpIbhmoF5EUF+XdA +E2y3l1cvmDJFo4RNZl+IQIbHACR3y1XOnh4/B9fMEsVQHK3x8exPk1vAk687bBG8 +TVDmdP52XEKHplcVoYKvGzw/wsObLAGyIbJ00t1VPU+7guTPsc+H/w== +-----END RSA PRIVATE KEY----- \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 62e11da..69f9e7d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ // - tokio_rustls:https://docs.rs/tokio-rustls/latest/tokio_rustls/ // - hyper: https://docs.rs/hyper/latest/hyper/ // - tower: https://docs.rs/tower/latest/tower/ - +// // We take a `SocketAddr` to bind the server to a specific address. // We use `TcpListener` to bind the server to the specified address at the TCP layer. // We use `rustls` to create a new `ServerConfig` instance. @@ -15,18 +15,177 @@ // Behind that can be axum, tower, tonic, // or any other service that implements the `tower::Service` trait. -pub struct HyperServer {} +use std::fs; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; -#[cfg(test)] -mod tests { - use std::net::SocketAddr; - use tokio::net::TcpListener; +use http::{Method, Request, Response, StatusCode}; +use http_body_util::BodyExt; +use http_body_util::Full; +use hyper::body::{Bytes, Incoming}; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder; +use hyper_util::service::TowerToHyperService; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::ServerConfig; +use tokio::net::TcpListener; +use tokio_rustls::TlsAcceptor; +use tower::{Service, ServiceBuilder}; - #[tokio::test] - async fn test_server() { +#[derive(Debug, Clone)] +pub struct Logger { + inner: S, +} +impl Logger { + pub fn new(inner: S) -> Self { + Logger { inner } + } +} +type Req = Request; +impl Service for Logger +where + S: Service + Clone, +{ + type Response = S::Response; + + type Error = S::Error; + + type Future = S::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Req) -> Self::Future { + println!("processing request: {} {}", req.method(), req.uri().path()); + self.inner.call(req) + } +} + +// Wrapped error type for the server. +fn error(err: String) -> io::Error { + io::Error::new(io::ErrorKind::Other, err) +} + +// Load the public certificate from a file. +fn load_certs(filename: &str) -> io::Result>> { + // Open certificate file. + let certfile = fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = io::BufReader::new(certfile); + + // Load and return certificate. + rustls_pemfile::certs(&mut reader).collect() +} + +// Load the private key from a file. +fn load_private_key(filename: &str) -> io::Result> { + // Open keyfile. + let keyfile = fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = io::BufReader::new(keyfile); + + // Load and return a single private key. + rustls_pemfile::private_key(&mut reader).map(|key| key.unwrap()) +} + +// Custom echo service, handling two different routes and a +// catch-all 404/not-found responder. +async fn echo(req: Request) -> Result>, hyper::Error> { + let mut response = Response::new(Full::default()); + match (req.method(), req.uri().path()) { + // Help route. + (&Method::GET, "/") => { + *response.body_mut() = Full::from("Try POST /echo\n"); + } + // Echo service route. + (&Method::POST, "/echo") => { + *response.body_mut() = Full::from(req.into_body().collect().await?.to_bytes()); + } + // Catch-all 404. + _ => { + *response.status_mut() = StatusCode::NOT_FOUND; + } + }; + Ok(response) +} + +pub struct HyperServer {} + +impl HyperServer { + pub async fn serve(&self) -> Result<(), Box> { // Get a random port from the OS let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + // Create a TCP listener bound to the random address - let listener = TcpListener::bind(&addr).await.unwrap(); + let incoming = TcpListener::bind(&addr).await?; + + // Load public certificate. + let certs = load_certs("examples/sample.pem")?; + + // Load private key. + let key = load_private_key("examples/sample.rsa")?; + + // Build TLS configuration. + let mut server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .unwrap(); + + // Enable ALPN with HTTP/2 and HTTP/1.1 support. + server_config.alpn_protocols = + vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + + // Create a rustls TlsAcceptor + let tls_acceptor = TlsAcceptor::from(Arc::new(server_config)); + + // Create a tower service + let service = tower::service_fn(echo); + let service = ServiceBuilder::new().layer_fn(Logger::new).service(service); + + // Convert it to a hyper service + let service = TowerToHyperService::new(service); + + // Begin the server loop + loop { + // Wait for an incoming tcp stream + let (tcp_stream, _remote_addr) = incoming.accept().await.unwrap(); + + // Clone a new instance of the tls_acceptor + let tls_acceptor = tls_acceptor.clone(); + + // Clone a new instance of the service + let service = service.clone(); + + // Spawn a new async task to handle the incoming connection + tokio::spawn(async move { + // Perform the TLS handshake + let tls_stream = match tls_acceptor.accept(tcp_stream).await { + Ok(tls_stream) => tls_stream, + Err(err) => { + eprintln!("failed to perform tls handshake: {err:#}"); + return; + } + }; + + // Serve the http connection + if let Err(err) = Builder::new(TokioExecutor::new()) + .serve_connection(TokioIo::new(tls_stream), service) + .await + { + eprintln!("failed to serve connection: {err:#}"); + } + }); + } } } + +#[cfg(test)] +mod tests { + #[tokio::test] + async fn test_echo_service() {} +} From 2d27fb7a5aac6a205a9dedd56b7102a0f2221f5a Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 09:58:20 -0400 Subject: [PATCH 04/45] feat: well factored http connection handling --- Cargo.toml | 4 ++ src/lib.rs | 207 ++++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 184 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3ce4cb2..d3cec51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,10 @@ tokio = { version = "1.40.0", features = ["net"] } tokio-rustls = "0.26.0" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" +http-body = "1.0.1" +tokio-stream = "0.1.16" +bytes = "1.7.1" +pin-project = "1.1.5" [dev-dependencies] tokio = { version = "1.40.0", features = ["macros"] } diff --git a/src/lib.rs b/src/lib.rs index 69f9e7d..b49e7e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,37 +1,137 @@ -// Crate links: -// - tokio: https://docs.rs/tokio/latest/tokio/ -// - rustls: https://docs.rs/rustls/latest/rustls/ -// - tokio_rustls:https://docs.rs/tokio-rustls/latest/tokio_rustls/ -// - hyper: https://docs.rs/hyper/latest/hyper/ -// - tower: https://docs.rs/tower/latest/tower/ -// -// We take a `SocketAddr` to bind the server to a specific address. -// We use `TcpListener` to bind the server to the specified address at the TCP layer. -// We use `rustls` to create a new `ServerConfig` instance. -// We use `tokio_rustls` to create a new `TlsAcceptor` instance. -// We use `hyper_util::server::conn::auto` to create a new `Connection` instance. -// Which then passes requests to a `hyper::service::Service` instance. -// Which then can optionally pass requests to a `tower::Service` instance. -// Behind that can be axum, tower, tonic, -// or any other service that implements the `tower::Service` trait. - -use std::fs; -use std::io; -use std::net::SocketAddr; -use std::sync::Arc; - +use bytes::Bytes; use http::{Method, Request, Response, StatusCode}; +use http_body::Body; use http_body_util::BodyExt; use http_body_util::Full; -use hyper::body::{Bytes, Incoming}; +use hyper::{body::Incoming, service::Service as HyperService}; use hyper_util::rt::{TokioExecutor, TokioIo}; -use hyper_util::server::conn::auto::Builder; +use hyper_util::server::conn::auto::Builder as HttpConnBuilder; +use hyper_util::server::conn::auto::HttpServerConnExec; use hyper_util::service::TowerToHyperService; +use pin_project::pin_project; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use rustls::ServerConfig; +use std::error::Error as StdError; +use std::future::pending; +use std::net::SocketAddr; +use std::pin::{pin, Pin}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use std::{fmt, fs, future::Future, io}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::macros::support::poll_fn; use tokio::net::TcpListener; +use tokio::time::sleep; use tokio_rustls::TlsAcceptor; +use tokio_stream::Stream; +use tokio_stream::StreamExt as _; use tower::{Service, ServiceBuilder}; +use tracing::{debug, trace}; + +// From `futures-util` crate, borrowed since this is the only dependency hyper-server requires. +// LICENSE: MIT or Apache-2.0 +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +#[pin_project] +struct Fuse { + #[pin] + inner: Option, +} + +impl Future for Fuse +where + F: Future, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().inner.as_pin_mut() { + Some(fut) => fut.poll(cx).map(|output| { + self.project().inner.set(None); + output + }), + None => Poll::Pending, + } + } +} + +type Source = Box; + +/// Errors that originate from the client hyper-server; +pub struct Error { + inner: ErrorImpl, +} + +struct ErrorImpl { + kind: Kind, + source: Option, +} + +#[derive(Debug)] +pub(crate) enum Kind { + Transport, +} + +impl Error { + pub(crate) fn new(kind: Kind) -> Self { + Self { + inner: ErrorImpl { kind, source: None }, + } + } + + pub(crate) fn with(mut self, source: impl Into) -> Self { + self.inner.source = Some(source.into()); + self + } + + pub(crate) fn from_source( + source: impl Into + std::error::Error + std::marker::Send + std::marker::Sync + 'static, + ) -> Self { + Error::new(Kind::Transport).with(source) + } + + fn description(&self) -> &str { + match &self.inner.kind { + Kind::Transport => "transport error", + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_tuple("hyper_server::Error"); + + f.field(&self.inner.kind); + + if let Some(source) = &self.inner.source { + f.field(source); + } + + f.finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.description()) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.inner + .source + .as_ref() + .map(|source| &**source as &(dyn StdError + 'static)) + } +} + +async fn sleep_or_pending(wait_for: Option) { + match wait_for { + Some(wait) => sleep(wait).await, + None => pending().await, + }; +} #[derive(Debug, Clone)] pub struct Logger { @@ -43,6 +143,7 @@ impl Logger { } } type Req = Request; + impl Service for Logger where S: Service + Clone, @@ -114,9 +215,9 @@ async fn echo(req: Request) -> Result>, hyper::Er Ok(response) } -pub struct HyperServer {} +pub struct Server {} -impl HyperServer { +impl Server { pub async fn serve(&self) -> Result<(), Box> { // Get a random port from the OS let addr = SocketAddr::from(([127, 0, 0, 1], 0)); @@ -173,7 +274,7 @@ impl HyperServer { }; // Serve the http connection - if let Err(err) = Builder::new(TokioExecutor::new()) + if let Err(err) = HttpConnBuilder::new(TokioExecutor::new()) .serve_connection(TokioIo::new(tls_stream), service) .await { @@ -182,6 +283,58 @@ impl HyperServer { }); } } + + /// Serves a single HTTP connection from a hyper service backend + async fn serve_connection( + hyper_io: IO, + hyper_svc: S, + builder: HttpConnBuilder, + mut watcher: Option>, + max_connection_age: Option, + ) where + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into> + Send + Sync, + IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + S: HyperService, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send, + E: HttpServerConnExec + Send + Sync + 'static, + { + tokio::spawn(async move { + { + let mut sig = pin!(Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); + + let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc)); + + let sleep = sleep_or_pending(max_connection_age); + tokio::pin!(sleep); + + loop { + tokio::select! { + rv = &mut conn => { + if let Err(err) = rv { + debug!("failed serving connection: {:#}", err); + } + break; + }, + _ = &mut sleep => { + conn.as_mut().graceful_shutdown(); + sleep.set(sleep_or_pending(None)); + }, + _ = &mut sig => { + conn.as_mut().graceful_shutdown(); + } + } + } + } + + drop(watcher); + trace!("connection closed"); + }); + } } #[cfg(test)] From 4d9bc67f52a7601a479d8c61390151ca3666fcb7 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 10:25:25 -0400 Subject: [PATCH 05/45] chore: document http connection server --- src/lib.rs | 71 +++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b49e7e3..6abb287 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -284,10 +284,29 @@ impl Server { } } - /// Serves a single HTTP connection from a hyper service backend - async fn serve_connection( + /// Serves a single HTTP connection from a hyper service backend. + /// + /// This method handles an individual HTTP connection, processing requests through + /// the provided service and managing the connection lifecycle. + /// + /// # Type Parameters + /// + /// * `B`: The body type for the HTTP response. + /// * `IO`: The I/O type for the HTTP connection. + /// * `S`: The service type that processes HTTP requests. + /// * `E`: The executor type for the HTTP server connection. + /// + /// # Parameters + /// + /// * `hyper_io`: The I/O object representing the inbound hyper IO stream. + /// * `hyper_svc`: The hyper `Service` implementation used to process HTTP requests. + /// * `builder`: An `HttpConnBuilder` used to create and serve the HTTP connection. + /// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. + /// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection + /// before initiating a graceful shutdown. + async fn serve_http_connection( hyper_io: IO, - hyper_svc: S, + hyper_service: S, builder: HttpConnBuilder, mut watcher: Option>, max_connection_age: Option, @@ -296,43 +315,55 @@ impl Server { B::Data: Send, B::Error: Into> + Send + Sync, IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, - S: HyperService, Response = Response> + Clone + Send + 'static, + S: HyperService, Response=Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into> + Send, E: HttpServerConnExec + Send + Sync + 'static, { + // Spawn a new asynchronous task to handle the incoming hyper IO stream tokio::spawn(async move { { + // Set up a fused future for the watcher let mut sig = pin!(Fuse { - inner: watcher.as_mut().map(|w| w.changed()), - }); + inner: watcher.as_mut().map(|w| w.changed()), + }); - let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc)); + // Create and pin the HTTP connection + let mut conn = pin!(builder.serve_connection(hyper_io, hyper_service)); + // Set up the sleep future for max connection age let sleep = sleep_or_pending(max_connection_age); tokio::pin!(sleep); + // Main loop for serving the HTTP connection loop { tokio::select! { - rv = &mut conn => { - if let Err(err) = rv { - debug!("failed serving connection: {:#}", err); - } - break; - }, - _ = &mut sleep => { - conn.as_mut().graceful_shutdown(); - sleep.set(sleep_or_pending(None)); - }, - _ = &mut sig => { - conn.as_mut().graceful_shutdown(); + // Handle the connection result + rv = &mut conn => { + if let Err(err) = rv { + // Log any errors that occur while serving the HTTP connection + debug!("failed serving HTTP connection: {:#}", err); } + break; + }, + // Handle max connection age timeout + _ = &mut sleep => { + // Initiate a graceful shutdown when max connection age is reached + conn.as_mut().graceful_shutdown(); + sleep.set(sleep_or_pending(None)); + }, + // Handle graceful shutdown signal + _ = &mut sig => { + // Initiate a graceful shutdown when signal is received + conn.as_mut().graceful_shutdown(); } } + } } + // Clean up and log connection closure drop(watcher); - trace!("connection closed"); + trace!("HTTP connection closed"); }); } } From 8e66463f51e9f2c333479734bede804d177c4a70 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 14:36:38 -0400 Subject: [PATCH 06/45] feat: tcp incoming --- Cargo.toml | 7 +- src/error.rs | 70 ++++++++++ src/fuse.rs | 30 ++++ src/http.rs | 100 ++++++++++++++ src/lib.rs | 380 +-------------------------------------------------- src/tcp.rs | 199 +++++++++++++++++++++++++++ src/tls.rs | 1 + 7 files changed, 411 insertions(+), 376 deletions(-) create mode 100644 src/error.rs create mode 100644 src/fuse.rs create mode 100644 src/http.rs create mode 100644 src/tcp.rs create mode 100644 src/tls.rs diff --git a/Cargo.toml b/Cargo.toml index d3cec51..3540a4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,9 +23,12 @@ tokio-rustls = "0.26.0" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" http-body = "1.0.1" -tokio-stream = "0.1.16" +tokio-stream = { version = "0.1.16", features = ["net"] } bytes = "1.7.1" pin-project = "1.1.5" +async-stream = "0.3.5" +futures = "0.3.30" [dev-dependencies] -tokio = { version = "1.40.0", features = ["macros"] } +tokio = { version = "1.0", features = ["rt", "net", "test-util", "macros"] } +tokio-test = "0.4.4" diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..2c38bb4 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,70 @@ +use std::{error::Error as StdError, fmt}; + +type Source = Box; + +/// Errors that originate from the server; +pub struct Error { + inner: ErrorImpl, +} + +struct ErrorImpl { + kind: Kind, + source: Option, +} + +#[derive(Debug)] +pub(crate) enum Kind { + Transport, +} + +impl Error { + pub(crate) fn new(kind: Kind) -> Self { + Self { + inner: ErrorImpl { kind, source: None }, + } + } + + pub(crate) fn with(mut self, source: impl Into) -> Self { + self.inner.source = Some(source.into()); + self + } + + pub(crate) fn from_source(source: impl Into) -> Self { + Error::new(Kind::Transport).with(source) + } + + fn description(&self) -> &str { + match &self.inner.kind { + Kind::Transport => "transport error", + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_tuple("tonic::transport::Error"); + + f.field(&self.inner.kind); + + if let Some(source) = &self.inner.source { + f.field(source); + } + + f.finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.description()) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.inner + .source + .as_ref() + .map(|source| &**source as &(dyn StdError + 'static)) + } +} diff --git a/src/fuse.rs b/src/fuse.rs new file mode 100644 index 0000000..f75bb02 --- /dev/null +++ b/src/fuse.rs @@ -0,0 +1,30 @@ +use pin_project::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +// From `futures-util` crate, borrowed since this is the only dependency hyper-server requires. +// LICENSE: MIT or Apache-2.0 +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +#[pin_project] +pub(crate) struct Fuse { + #[pin] + pub(crate) inner: Option, +} + +impl Future for Fuse +where + F: Future, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().inner.as_pin_mut() { + Some(fut) => fut.poll(cx).map(|output| { + self.project().inner.set(None); + output + }), + None => Poll::Pending, + } + } +} diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..959c505 --- /dev/null +++ b/src/http.rs @@ -0,0 +1,100 @@ +use http::{Request, Response}; +use http_body::Body; +use hyper::body::Incoming; +use hyper::service::Service; +use hyper_util::server::conn::auto::{Builder, HttpServerConnExec}; +use std::future::pending; +use std::pin::pin; +use std::time::Duration; +use tokio::time::sleep; +use tracing::{debug, trace}; + +async fn sleep_or_pending(wait_for: Option) { + match wait_for { + Some(wait) => sleep(wait).await, + None => pending().await, + }; +} + +/// Serves a single HTTP connection from a hyper service backend. +/// +/// This method handles an individual HTTP connection, processing requests through +/// the provided service and managing the connection lifecycle. +/// +/// # Type Parameters +/// +/// * `B`: The body type for the HTTP response. +/// * `IO`: The I/O type for the HTTP connection. +/// * `S`: The service type that processes HTTP requests. +/// * `E`: The executor type for the HTTP server connection. +/// +/// # Parameters +/// +/// * `hyper_io`: The I/O object representing the inbound hyper IO stream. +/// * `hyper_svc`: The hyper `Service` implementation used to process HTTP requests. +/// * `builder`: An `HttpConnBuilder` used to create and serve the HTTP connection. +/// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. +/// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection +/// before initiating a graceful shutdown. +async fn serve_http_connection( + hyper_io: IO, + hyper_service: S, + builder: Builder, + mut watcher: Option>, + max_connection_age: Option, +) where + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into> + Send + Sync, + IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send, + E: HttpServerConnExec + Send + Sync + 'static, +{ + // Spawn a new asynchronous task to handle the incoming hyper IO stream + tokio::spawn(async move { + { + // Set up a fused future for the watcher + let mut sig = pin!(crate::fuse::Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); + + // Create and pin the HTTP connection + let mut conn = pin!(builder.serve_connection(hyper_io, hyper_service)); + + // Set up the sleep future for max connection age + let sleep = sleep_or_pending(max_connection_age); + tokio::pin!(sleep); + + // Main loop for serving the HTTP connection + loop { + tokio::select! { + // Handle the connection result + rv = &mut conn => { + if let Err(err) = rv { + // Log any errors that occur while serving the HTTP connection + debug!("failed serving HTTP connection: {:#}", err); + } + break; + }, + // Handle max connection age timeout + _ = &mut sleep => { + // Initiate a graceful shutdown when max connection age is reached + conn.as_mut().graceful_shutdown(); + sleep.set(sleep_or_pending(None)); + }, + // Handle graceful shutdown signal + _ = &mut sig => { + // Initiate a graceful shutdown when signal is received + conn.as_mut().graceful_shutdown(); + } + } + } + } + + // Clean up and log connection closure + drop(watcher); + trace!("HTTP connection closed"); + }); +} diff --git a/src/lib.rs b/src/lib.rs index 6abb287..9a714c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,375 +1,7 @@ -use bytes::Bytes; -use http::{Method, Request, Response, StatusCode}; -use http_body::Body; -use http_body_util::BodyExt; -use http_body_util::Full; -use hyper::{body::Incoming, service::Service as HyperService}; -use hyper_util::rt::{TokioExecutor, TokioIo}; -use hyper_util::server::conn::auto::Builder as HttpConnBuilder; -use hyper_util::server::conn::auto::HttpServerConnExec; -use hyper_util::service::TowerToHyperService; -use pin_project::pin_project; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; -use rustls::ServerConfig; -use std::error::Error as StdError; -use std::future::pending; -use std::net::SocketAddr; -use std::pin::{pin, Pin}; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::time::Duration; -use std::{fmt, fs, future::Future, io}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::macros::support::poll_fn; -use tokio::net::TcpListener; -use tokio::time::sleep; -use tokio_rustls::TlsAcceptor; -use tokio_stream::Stream; -use tokio_stream::StreamExt as _; -use tower::{Service, ServiceBuilder}; -use tracing::{debug, trace}; +mod error; +mod fuse; +mod http; +mod tcp; +mod tls; -// From `futures-util` crate, borrowed since this is the only dependency hyper-server requires. -// LICENSE: MIT or Apache-2.0 -// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. -#[pin_project] -struct Fuse { - #[pin] - inner: Option, -} - -impl Future for Fuse -where - F: Future, -{ - type Output = F::Output; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().project().inner.as_pin_mut() { - Some(fut) => fut.poll(cx).map(|output| { - self.project().inner.set(None); - output - }), - None => Poll::Pending, - } - } -} - -type Source = Box; - -/// Errors that originate from the client hyper-server; -pub struct Error { - inner: ErrorImpl, -} - -struct ErrorImpl { - kind: Kind, - source: Option, -} - -#[derive(Debug)] -pub(crate) enum Kind { - Transport, -} - -impl Error { - pub(crate) fn new(kind: Kind) -> Self { - Self { - inner: ErrorImpl { kind, source: None }, - } - } - - pub(crate) fn with(mut self, source: impl Into) -> Self { - self.inner.source = Some(source.into()); - self - } - - pub(crate) fn from_source( - source: impl Into + std::error::Error + std::marker::Send + std::marker::Sync + 'static, - ) -> Self { - Error::new(Kind::Transport).with(source) - } - - fn description(&self) -> &str { - match &self.inner.kind { - Kind::Transport => "transport error", - } - } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut f = f.debug_tuple("hyper_server::Error"); - - f.field(&self.inner.kind); - - if let Some(source) = &self.inner.source { - f.field(source); - } - - f.finish() - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.description()) - } -} - -impl StdError for Error { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - self.inner - .source - .as_ref() - .map(|source| &**source as &(dyn StdError + 'static)) - } -} - -async fn sleep_or_pending(wait_for: Option) { - match wait_for { - Some(wait) => sleep(wait).await, - None => pending().await, - }; -} - -#[derive(Debug, Clone)] -pub struct Logger { - inner: S, -} -impl Logger { - pub fn new(inner: S) -> Self { - Logger { inner } - } -} -type Req = Request; - -impl Service for Logger -where - S: Service + Clone, -{ - type Response = S::Response; - - type Error = S::Error; - - type Future = S::Future; - - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, req: Req) -> Self::Future { - println!("processing request: {} {}", req.method(), req.uri().path()); - self.inner.call(req) - } -} - -// Wrapped error type for the server. -fn error(err: String) -> io::Error { - io::Error::new(io::ErrorKind::Other, err) -} - -// Load the public certificate from a file. -fn load_certs(filename: &str) -> io::Result>> { - // Open certificate file. - let certfile = fs::File::open(filename) - .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; - let mut reader = io::BufReader::new(certfile); - - // Load and return certificate. - rustls_pemfile::certs(&mut reader).collect() -} - -// Load the private key from a file. -fn load_private_key(filename: &str) -> io::Result> { - // Open keyfile. - let keyfile = fs::File::open(filename) - .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; - let mut reader = io::BufReader::new(keyfile); - - // Load and return a single private key. - rustls_pemfile::private_key(&mut reader).map(|key| key.unwrap()) -} - -// Custom echo service, handling two different routes and a -// catch-all 404/not-found responder. -async fn echo(req: Request) -> Result>, hyper::Error> { - let mut response = Response::new(Full::default()); - match (req.method(), req.uri().path()) { - // Help route. - (&Method::GET, "/") => { - *response.body_mut() = Full::from("Try POST /echo\n"); - } - // Echo service route. - (&Method::POST, "/echo") => { - *response.body_mut() = Full::from(req.into_body().collect().await?.to_bytes()); - } - // Catch-all 404. - _ => { - *response.status_mut() = StatusCode::NOT_FOUND; - } - }; - Ok(response) -} - -pub struct Server {} - -impl Server { - pub async fn serve(&self) -> Result<(), Box> { - // Get a random port from the OS - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - // Create a TCP listener bound to the random address - let incoming = TcpListener::bind(&addr).await?; - - // Load public certificate. - let certs = load_certs("examples/sample.pem")?; - - // Load private key. - let key = load_private_key("examples/sample.rsa")?; - - // Build TLS configuration. - let mut server_config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, key) - .unwrap(); - - // Enable ALPN with HTTP/2 and HTTP/1.1 support. - server_config.alpn_protocols = - vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; - - // Create a rustls TlsAcceptor - let tls_acceptor = TlsAcceptor::from(Arc::new(server_config)); - - // Create a tower service - let service = tower::service_fn(echo); - let service = ServiceBuilder::new().layer_fn(Logger::new).service(service); - - // Convert it to a hyper service - let service = TowerToHyperService::new(service); - - // Begin the server loop - loop { - // Wait for an incoming tcp stream - let (tcp_stream, _remote_addr) = incoming.accept().await.unwrap(); - - // Clone a new instance of the tls_acceptor - let tls_acceptor = tls_acceptor.clone(); - - // Clone a new instance of the service - let service = service.clone(); - - // Spawn a new async task to handle the incoming connection - tokio::spawn(async move { - // Perform the TLS handshake - let tls_stream = match tls_acceptor.accept(tcp_stream).await { - Ok(tls_stream) => tls_stream, - Err(err) => { - eprintln!("failed to perform tls handshake: {err:#}"); - return; - } - }; - - // Serve the http connection - if let Err(err) = HttpConnBuilder::new(TokioExecutor::new()) - .serve_connection(TokioIo::new(tls_stream), service) - .await - { - eprintln!("failed to serve connection: {err:#}"); - } - }); - } - } - - /// Serves a single HTTP connection from a hyper service backend. - /// - /// This method handles an individual HTTP connection, processing requests through - /// the provided service and managing the connection lifecycle. - /// - /// # Type Parameters - /// - /// * `B`: The body type for the HTTP response. - /// * `IO`: The I/O type for the HTTP connection. - /// * `S`: The service type that processes HTTP requests. - /// * `E`: The executor type for the HTTP server connection. - /// - /// # Parameters - /// - /// * `hyper_io`: The I/O object representing the inbound hyper IO stream. - /// * `hyper_svc`: The hyper `Service` implementation used to process HTTP requests. - /// * `builder`: An `HttpConnBuilder` used to create and serve the HTTP connection. - /// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. - /// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection - /// before initiating a graceful shutdown. - async fn serve_http_connection( - hyper_io: IO, - hyper_service: S, - builder: HttpConnBuilder, - mut watcher: Option>, - max_connection_age: Option, - ) where - B: Body + Send + 'static, - B::Data: Send, - B::Error: Into> + Send + Sync, - IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, - S: HyperService, Response=Response> + Clone + Send + 'static, - S::Future: Send + 'static, - S::Error: Into> + Send, - E: HttpServerConnExec + Send + Sync + 'static, - { - // Spawn a new asynchronous task to handle the incoming hyper IO stream - tokio::spawn(async move { - { - // Set up a fused future for the watcher - let mut sig = pin!(Fuse { - inner: watcher.as_mut().map(|w| w.changed()), - }); - - // Create and pin the HTTP connection - let mut conn = pin!(builder.serve_connection(hyper_io, hyper_service)); - - // Set up the sleep future for max connection age - let sleep = sleep_or_pending(max_connection_age); - tokio::pin!(sleep); - - // Main loop for serving the HTTP connection - loop { - tokio::select! { - // Handle the connection result - rv = &mut conn => { - if let Err(err) = rv { - // Log any errors that occur while serving the HTTP connection - debug!("failed serving HTTP connection: {:#}", err); - } - break; - }, - // Handle max connection age timeout - _ = &mut sleep => { - // Initiate a graceful shutdown when max connection age is reached - conn.as_mut().graceful_shutdown(); - sleep.set(sleep_or_pending(None)); - }, - // Handle graceful shutdown signal - _ = &mut sig => { - // Initiate a graceful shutdown when signal is received - conn.as_mut().graceful_shutdown(); - } - } - } - } - - // Clean up and log connection closure - drop(watcher); - trace!("HTTP connection closed"); - }); - } -} - -#[cfg(test)] -mod tests { - #[tokio::test] - async fn test_echo_service() {} -} +pub(crate) type Error = Box; diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..eebb05f --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,199 @@ +use crate::Error; +use std::{io, ops::ControlFlow}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_stream::{Stream, StreamExt}; +use tracing::debug; + +/// Handles errors that occur during TCP connection acceptance. +/// +/// This function determines whether an error should be treated as fatal (breaking the accept loop) +/// or non-fatal (allowing the loop to continue). +/// +/// # Arguments +/// +/// * `e` - The error to handle, which can be converted into the crate's `Error` type. +/// +/// # Returns +/// +/// * `ControlFlow::Continue(())` if the error is non-fatal and the accept loop should continue. +/// * `ControlFlow::Break(Error)` if the error is fatal and the accept loop should terminate. +pub(crate) fn handle_accept_error(e: impl Into) -> ControlFlow { + let e = e.into(); + debug!(error = %e, "TCP accept loop error"); + if let Some(e) = e.downcast_ref::() { + if matches!( + e.kind(), + io::ErrorKind::ConnectionAborted + | io::ErrorKind::Interrupted + | io::ErrorKind::InvalidData + | io::ErrorKind::WouldBlock + ) { + return ControlFlow::Continue(()); + } + } + + ControlFlow::Break(e) +} + +/// Creates a stream that yields a TCP stream for each incoming connection. +/// +/// This function takes a stream of incoming connections and handles errors that may occur +/// during the acceptance process. It will continue to yield connections even if non-fatal +/// errors occur, but will terminate if a fatal error is encountered. +/// +/// # Type Parameters +/// +/// * `IO`: The type of the I/O object yielded by the incoming stream. +/// * `IE`: The type of the error that can be produced by the incoming stream. +/// +/// # Arguments +/// +/// * `incoming`: A stream that yields results of incoming connection attempts. +/// +/// # Returns +/// +/// A pinned stream that yields `Result` for each incoming connection. +pub(crate) fn serve_tcp_incoming( + incoming: impl Stream> + Send + 'static, +) -> impl Stream> +where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into + Send + 'static, +{ + async_stream::stream! { + let mut incoming = Box::pin(incoming); + + while let Some(item) = incoming.next().await { + match item { + Ok(io) => yield Ok(io), + Err(e) => match handle_accept_error(e.into()) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(e) => yield Err(e), + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::SocketAddr; + use std::pin::Pin; + use tokio::net::{TcpListener, TcpStream}; + use tokio_stream::wrappers::TcpListenerStream; + use tokio_stream::StreamExt; + + #[tokio::test] + async fn test_handle_accept_error() { + // Test non-fatal errors + let non_fatal_errors = vec![ + io::ErrorKind::ConnectionAborted, + io::ErrorKind::Interrupted, + io::ErrorKind::InvalidData, + io::ErrorKind::WouldBlock, + ]; + + for kind in non_fatal_errors { + let error = io::Error::new(kind, "Test error"); + assert!(matches!( + handle_accept_error(error), + ControlFlow::Continue(()) + )); + } + + // Test fatal error + let fatal_error = io::Error::new(io::ErrorKind::PermissionDenied, "Permission denied"); + assert!(matches!( + handle_accept_error(fatal_error), + ControlFlow::Break(_) + )); + } + + #[tokio::test] + async fn test_serve_tcp_incoming_success() -> Result<(), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let bound_addr = listener.local_addr()?; + let stream = TcpListenerStream::new(listener); + let mut incoming = Box::pin(serve_tcp_incoming(stream)); + + // Spawn a task to accept one connection + let accept_task = tokio::spawn(async move { incoming.next().await }); + + // Connect to the server + let _client = TcpStream::connect(bound_addr).await?; + + // Check that the connection was accepted + let result = accept_task.await?.unwrap(); + assert!(result.is_ok()); + + Ok(()) + } + + #[tokio::test] + async fn test_serve_tcp_incoming_with_errors() { + // Create a mock stream that yields both successful connections and errors + let mock_stream = tokio_stream::iter(vec![ + Ok(MockIO), + Err(io::Error::new(io::ErrorKind::ConnectionAborted, "Aborted")), + Ok(MockIO), + Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "Permission denied", + )), + ]); + + let mut incoming = Box::pin(serve_tcp_incoming(mock_stream)); + + // First connection should be successful + assert!(incoming.next().await.unwrap().is_ok()); + + // Second connection (aborted) should be skipped + // Third connection should be successful + assert!(incoming.next().await.unwrap().is_ok()); + + // Fourth connection (permission denied) should break the stream + assert!(incoming.next().await.unwrap().is_err()); + + // Stream should be exhausted + assert!(incoming.next().await.is_none()); + } + + // Mock IO type for testing + struct MockIO; + + impl AsyncRead for MockIO { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + } + + impl AsyncWrite for MockIO { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &[u8], + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(0)) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + } +} diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1 @@ + From 42beb8896a862b08bc7ee6b3b496e7300db0e898 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 14:58:18 -0400 Subject: [PATCH 07/45] chore: docs --- src/http.rs | 16 ++++++++++++---- src/tcp.rs | 21 +++++++++++++++++++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/http.rs b/src/http.rs index 959c505..62426f6 100644 --- a/src/http.rs +++ b/src/http.rs @@ -9,6 +9,14 @@ use std::time::Duration; use tokio::time::sleep; use tracing::{debug, trace}; +/// Sleeps for a specified duration or waits indefinitely. +/// +/// This function is used to implement timeouts or indefinite waiting periods. +/// +/// # Arguments +/// +/// * `wait_for` - An `Option` specifying how long to sleep. +/// If `None`, the function will wait indefinitely. async fn sleep_or_pending(wait_for: Option) { match wait_for { Some(wait) => sleep(wait).await, @@ -28,15 +36,15 @@ async fn sleep_or_pending(wait_for: Option) { /// * `S`: The service type that processes HTTP requests. /// * `E`: The executor type for the HTTP server connection. /// -/// # Parameters +/// # Arguments /// /// * `hyper_io`: The I/O object representing the inbound hyper IO stream. -/// * `hyper_svc`: The hyper `Service` implementation used to process HTTP requests. -/// * `builder`: An `HttpConnBuilder` used to create and serve the HTTP connection. +/// * `hyper_service`: The hyper `Service` implementation used to process HTTP requests. +/// * `builder`: A `Builder` used to create and serve the HTTP connection. /// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. /// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection /// before initiating a graceful shutdown. -async fn serve_http_connection( +pub(crate) async fn serve_http_connection( hyper_io: IO, hyper_service: S, builder: Builder, diff --git a/src/tcp.rs b/src/tcp.rs index eebb05f..7b72c96 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -17,10 +17,19 @@ use tracing::debug; /// /// * `ControlFlow::Continue(())` if the error is non-fatal and the accept loop should continue. /// * `ControlFlow::Break(Error)` if the error is fatal and the accept loop should terminate. -pub(crate) fn handle_accept_error(e: impl Into) -> ControlFlow { +/// +/// # Error Handling +/// +/// The function categorizes errors as follows: +/// - Non-fatal errors: ConnectionAborted, Interrupted, InvalidData, WouldBlock +/// - Fatal errors: All other error types +fn handle_accept_error(e: impl Into) -> ControlFlow { let e = e.into(); debug!(error = %e, "TCP accept loop error"); + + // Check if the error is an I/O error if let Some(e) = e.downcast_ref::() { + // Determine if the error is non-fatal if matches!( e.kind(), io::ErrorKind::ConnectionAborted @@ -32,6 +41,7 @@ pub(crate) fn handle_accept_error(e: impl Into) -> ControlFlow { } } + // If not a non-fatal I/O error, treat as fatal ControlFlow::Break(e) } @@ -52,7 +62,13 @@ pub(crate) fn handle_accept_error(e: impl Into) -> ControlFlow { /// /// # Returns /// -/// A pinned stream that yields `Result` for each incoming connection. +/// A stream that yields `Result` for each incoming connection. +/// +/// # Error Handling +/// +/// This function uses `handle_accept_error` to determine whether to continue accepting +/// connections after an error occurs. Non-fatal errors are logged and skipped, while +/// fatal errors cause the stream to yield an error and terminate. pub(crate) fn serve_tcp_incoming( incoming: impl Stream> + Send + 'static, ) -> impl Stream> @@ -112,6 +128,7 @@ mod tests { #[tokio::test] async fn test_serve_tcp_incoming_success() -> Result<(), Box> { + // Set up a TCP listener let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let listener = TcpListener::bind(addr).await?; let bound_addr = listener.local_addr()?; From 82e39be7ee5169a39458d2e38fae64ba5de9e423 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 15:36:52 -0400 Subject: [PATCH 08/45] feat: higher level server and tests for http --- Cargo.toml | 5 +- src/http.rs | 259 ++++++++++++++++++++++++++++++++++++++++++++++++++-- src/lib.rs | 2 + 3 files changed, 257 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3540a4b..a74cd5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ hyper = "1.4.1" hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } rustls = "0.23.13" rustls-pemfile = "2.1.3" -tokio = { version = "1.40.0", features = ["net"] } +tokio = { version = "1.40.0", features = ["net", "macros"] } tokio-rustls = "0.26.0" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" @@ -30,5 +30,6 @@ async-stream = "0.3.5" futures = "0.3.30" [dev-dependencies] -tokio = { version = "1.0", features = ["rt", "net", "test-util", "macros"] } +tokio = { version = "1.0", features = ["rt", "net", "test-util"] } tokio-test = "0.4.4" +hyper = {version = "1.4.1", features = ["client"] } diff --git a/src/http.rs b/src/http.rs index 62426f6..32614b1 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,13 +1,36 @@ +use tokio_stream::StreamExt as _; +use tracing::{debug, trace}; +use hyper_util::{ + rt::{TokioExecutor, TokioIo, TokioTimer}, + server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec}, + service::TowerToHyperService, +}; +use bytes::Bytes; use http::{Request, Response}; +use http_body_util::BodyExt; +use hyper::{body::Incoming, service::Service as HyperService}; +use pin_project::pin_project; +use std::future::pending; +use std::{ + convert::Infallible, + fmt, + future::{self, poll_fn, Future}, + marker::PhantomData, + net::SocketAddr, + pin::{pin, Pin}, + sync::Arc, + task::{ready, Context, Poll}, + time::Duration, +}; +use futures::Sink; use http_body::Body; -use hyper::body::Incoming; use hyper::service::Service; -use hyper_util::server::conn::auto::{Builder, HttpServerConnExec}; -use std::future::pending; -use std::pin::pin; -use std::time::Duration; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpListener; use tokio::time::sleep; -use tracing::{debug, trace}; +use tokio_stream::Stream; +use tokio_stream::wrappers::TcpListenerStream; +use crate::fuse::Fuse; /// Sleeps for a specified duration or waits indefinitely. /// @@ -47,7 +70,7 @@ async fn sleep_or_pending(wait_for: Option) { pub(crate) async fn serve_http_connection( hyper_io: IO, hyper_service: S, - builder: Builder, + builder: HttpConnectionBuilder, mut watcher: Option>, max_connection_age: Option, ) where @@ -106,3 +129,225 @@ pub(crate) async fn serve_http_connection( trace!("HTTP connection closed"); }); } + +pub(crate) async fn serve_http_with_shutdown( + service: S, + incoming: I, + signal: Option, +) -> Result<(), super::Error> +where + F: Future, + I: Stream> + Send + 'static, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send, + ResBody: Body + Send + Sync + 'static, + ResBody::Error: Into + Send + Sync, +{ + let incoming = crate::tcp::serve_tcp_incoming( + incoming + ); + + let server = { + let mut builder = HttpConnectionBuilder::new(TokioExecutor::new()); + builder + }; + + let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); + let signal_tx = Arc::new(signal_tx); + + let graceful = signal.is_some(); + let mut sig = pin!(Fuse { inner: signal }); + let mut incoming = pin!(incoming); + + loop { + tokio::select! { + _ = &mut sig => { + trace!("signal received, shutting down"); + break; + }, + io = incoming.next() => { + let io = match io { + Some(Ok(io)) => io, + Some(Err(e)) => { + trace!("error accepting connection: {:#}", e); + continue; + }, + None => { + break + }, + }; + + trace!("connection accepted"); + + let hyper_io = TokioIo::new(io); + let hyper_svc = service.clone(); + + serve_http_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone()), None).await; + } + } + } + + if graceful { + let _ = signal_tx.send(()); + drop(signal_rx); + trace!( + "waiting for {} connections to close", + signal_tx.receiver_count() + ); + + // Wait for all connections to close + signal_tx.closed().await; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::SocketAddr; + use std::time::Duration; + use tokio::net::{TcpStream, TcpListener}; + use tokio::sync::oneshot; + use bytes::Bytes; + use http_body_util::{BodyExt, Empty, Full}; + use hyper::{body::Incoming, Request, Response, StatusCode}; + use tokio_stream::wrappers::TcpListenerStream; + use tower::ServiceExt; + + // Echo service + async fn echo(req: Request) -> Result>, hyper::Error> { + match (req.method(), req.uri().path()) { + (&hyper::Method::GET, "/") => Ok(Response::new(Full::new(Bytes::from("Hello, World!")))), + (&hyper::Method::POST, "/echo") => { + let body = req.collect().await?.to_bytes(); + Ok(Response::new(Full::new(body))) + } + _ => { + let mut res = Response::new(Full::new(Bytes::from("Not Found"))); + *res.status_mut() = StatusCode::NOT_FOUND; + Ok(res) + } + } + } + + async fn setup_test_server(addr: SocketAddr) -> (TcpListenerStream, SocketAddr) { + let listener = TcpListener::bind(addr).await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + let incoming = TcpListenerStream::new(listener); + (incoming, server_addr) + } + + async fn send_request(addr: SocketAddr, req: Request>) -> hyper::Result> { + let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } + }); + + sender.send_request(req).await + } + + #[tokio::test] + async fn test_serve_http_with_shutdown_basic() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + Some(async { + shutdown_rx.await.ok(); + }), + )); + + // Test GET request + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); + + // Test POST request + let req = Request::builder() + .method(hyper::Method::POST) + .uri("/echo") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + // Test 404 response + let req = Request::builder() + .uri("/not_found") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + // Shutdown the server + shutdown_tx.send(()).unwrap(); + tokio::time::timeout(Duration::from_secs(5), server).await + .expect("Server didn't shut down within the timeout period") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn test_serve_http_with_concurrent_requests() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + Some(async { + shutdown_rx.await.ok(); + }), + )); + + let mut handles = vec![]; + for _ in 0..10 { + let handle = tokio::spawn(async move { + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + }); + handles.push(handle); + } + + for handle in handles { + handle.await.unwrap(); + } + + // Shutdown the server + shutdown_tx.send(()).unwrap(); + tokio::time::timeout(Duration::from_secs(5), server).await + .expect("Server didn't shut down within the timeout period") + .unwrap() + .unwrap(); + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 9a714c1..7fc292f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,3 +5,5 @@ mod tcp; mod tls; pub(crate) type Error = Box; + + From 215d7be0a810eef87017657e46fef075a9bc462b Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 15:41:06 -0400 Subject: [PATCH 09/45] chore: add docs --- src/http.rs | 94 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 65 insertions(+), 29 deletions(-) diff --git a/src/http.rs b/src/http.rs index 32614b1..9c4dcee 100644 --- a/src/http.rs +++ b/src/http.rs @@ -130,6 +130,29 @@ pub(crate) async fn serve_http_connection( }); } +/// Serves HTTP requests with graceful shutdown capability. +/// +/// This function sets up an HTTP server that can handle incoming connections and +/// process requests using the provided service. It also supports graceful shutdown. +/// +/// # Type Parameters +/// +/// * `S`: The service type that processes HTTP requests. +/// * `I`: The incoming stream of IO objects. +/// * `F`: The future type for the shutdown signal. +/// * `IO`: The I/O type for the HTTP connection. +/// * `IE`: The error type for the incoming stream. +/// * `ResBody`: The response body type. +/// +/// # Arguments +/// +/// * `service`: The service used to process HTTP requests. +/// * `incoming`: The stream of incoming connections. +/// * `signal`: An optional future that, when resolved, signals the server to shut down gracefully. +/// +/// # Returns +/// +/// A `Result` indicating success or failure of the server operation. pub(crate) async fn serve_http_with_shutdown( service: S, incoming: I, @@ -146,15 +169,16 @@ where ResBody: Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync, { - let incoming = crate::tcp::serve_tcp_incoming( - incoming - ); + // Prepare the incoming stream of TCP connections + let incoming = crate::tcp::serve_tcp_incoming(incoming); + // Set up the HTTP connection builder let server = { let mut builder = HttpConnectionBuilder::new(TokioExecutor::new()); builder }; + // Create a channel for signaling graceful shutdown let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); let signal_tx = Arc::new(signal_tx); @@ -162,41 +186,53 @@ where let mut sig = pin!(Fuse { inner: signal }); let mut incoming = pin!(incoming); + // Main server loop loop { tokio::select! { - _ = &mut sig => { - trace!("signal received, shutting down"); - break; - }, - io = incoming.next() => { - let io = match io { - Some(Ok(io)) => io, - Some(Err(e)) => { - trace!("error accepting connection: {:#}", e); - continue; - }, - None => { - break - }, - }; - - trace!("connection accepted"); - - let hyper_io = TokioIo::new(io); - let hyper_svc = service.clone(); - - serve_http_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone()), None).await; - } + // Handle shutdown signal + _ = &mut sig => { + trace!("signal received, shutting down"); + break; + }, + // Handle incoming connections + io = incoming.next() => { + let io = match io { + Some(Ok(io)) => io, + Some(Err(e)) => { + trace!("error accepting connection: {:#}", e); + continue; + }, + None => { + break + }, + }; + + trace!("connection accepted"); + + // Prepare the connection for hyper + let hyper_io = TokioIo::new(io); + let hyper_svc = service.clone(); + + // Serve the HTTP connection + serve_http_connection( + hyper_io, + hyper_svc, + server.clone(), + graceful.then(|| signal_rx.clone()), + None + ).await; } + } } + // Handle graceful shutdown if graceful { let _ = signal_tx.send(()); drop(signal_rx); trace!( - "waiting for {} connections to close", - signal_tx.receiver_count() - ); + "waiting for {} connections to close", + signal_tx.receiver_count() + ); // Wait for all connections to close signal_tx.closed().await; From 6a20e2e38c6d865210a6e670a48e037341d4f53c Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 15:43:41 -0400 Subject: [PATCH 10/45] fix: cleanup --- Cargo.toml | 17 +++++--------- src/http.rs | 68 +++++++++++++++++++++++------------------------------ src/lib.rs | 2 -- 3 files changed, 36 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a74cd5f..14050bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,24 +12,19 @@ repository = "https://github.com/valorem-labs-inc/hyper-server" version = "1.0.0" [dependencies] +async-stream = "0.3.5" +bytes = "1.7.1" http = "1.1.0" +http-body = "1.0.1" http-body-util = "0.1.2" hyper = "1.4.1" hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } -rustls = "0.23.13" -rustls-pemfile = "2.1.3" +pin-project = "1.1.5" tokio = { version = "1.40.0", features = ["net", "macros"] } -tokio-rustls = "0.26.0" +tokio-stream = { version = "0.1.16", features = ["net"] } tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" -http-body = "1.0.1" -tokio-stream = { version = "0.1.16", features = ["net"] } -bytes = "1.7.1" -pin-project = "1.1.5" -async-stream = "0.3.5" -futures = "0.3.30" [dev-dependencies] +hyper = { version = "1.4.1", features = ["client"] } tokio = { version = "1.0", features = ["rt", "net", "test-util"] } -tokio-test = "0.4.4" -hyper = {version = "1.4.1", features = ["client"] } diff --git a/src/http.rs b/src/http.rs index 9c4dcee..cebf5c2 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,35 +1,21 @@ -use tokio_stream::StreamExt as _; -use tracing::{debug, trace}; -use hyper_util::{ - rt::{TokioExecutor, TokioIo, TokioTimer}, - server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec}, - service::TowerToHyperService, -}; +use std::future::pending; +use std::{future::Future, pin::pin, sync::Arc, time::Duration}; + use bytes::Bytes; use http::{Request, Response}; -use http_body_util::BodyExt; -use hyper::{body::Incoming, service::Service as HyperService}; -use pin_project::pin_project; -use std::future::pending; -use std::{ - convert::Infallible, - fmt, - future::{self, poll_fn, Future}, - marker::PhantomData, - net::SocketAddr, - pin::{pin, Pin}, - sync::Arc, - task::{ready, Context, Poll}, - time::Duration, -}; -use futures::Sink; use http_body::Body; +use hyper::body::Incoming; use hyper::service::Service; +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec}, +}; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::TcpListener; use tokio::time::sleep; use tokio_stream::Stream; -use tokio_stream::wrappers::TcpListenerStream; +use tokio_stream::StreamExt as _; +use tracing::{debug, trace}; + use crate::fuse::Fuse; /// Sleeps for a specified duration or waits indefinitely. @@ -173,10 +159,7 @@ where let incoming = crate::tcp::serve_tcp_incoming(incoming); // Set up the HTTP connection builder - let server = { - let mut builder = HttpConnectionBuilder::new(TokioExecutor::new()); - builder - }; + let server = { HttpConnectionBuilder::new(TokioExecutor::new()) }; // Create a channel for signaling graceful shutdown let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); @@ -243,21 +226,25 @@ where #[cfg(test)] mod tests { - use super::*; use std::net::SocketAddr; use std::time::Duration; - use tokio::net::{TcpStream, TcpListener}; - use tokio::sync::oneshot; + use bytes::Bytes; use http_body_util::{BodyExt, Empty, Full}; use hyper::{body::Incoming, Request, Response, StatusCode}; + use hyper_util::service::TowerToHyperService; + use tokio::net::{TcpListener, TcpStream}; + use tokio::sync::oneshot; use tokio_stream::wrappers::TcpListenerStream; - use tower::ServiceExt; + + use super::*; // Echo service async fn echo(req: Request) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { - (&hyper::Method::GET, "/") => Ok(Response::new(Full::new(Bytes::from("Hello, World!")))), + (&hyper::Method::GET, "/") => { + Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) + } (&hyper::Method::POST, "/echo") => { let body = req.collect().await?.to_bytes(); Ok(Response::new(Full::new(body))) @@ -277,7 +264,10 @@ mod tests { (incoming, server_addr) } - async fn send_request(addr: SocketAddr, req: Request>) -> hyper::Result> { + async fn send_request( + addr: SocketAddr, + req: Request>, + ) -> hyper::Result> { let stream = TcpStream::connect(addr).await.unwrap(); let io = TokioIo::new(stream); @@ -338,7 +328,8 @@ mod tests { // Shutdown the server shutdown_tx.send(()).unwrap(); - tokio::time::timeout(Duration::from_secs(5), server).await + tokio::time::timeout(Duration::from_secs(5), server) + .await .expect("Server didn't shut down within the timeout period") .unwrap() .unwrap(); @@ -381,9 +372,10 @@ mod tests { // Shutdown the server shutdown_tx.send(()).unwrap(); - tokio::time::timeout(Duration::from_secs(5), server).await + tokio::time::timeout(Duration::from_secs(5), server) + .await .expect("Server didn't shut down within the timeout period") .unwrap() .unwrap(); } -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index 7fc292f..9a714c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,5 +5,3 @@ mod tcp; mod tls; pub(crate) type Error = Box; - - From 8192854c9cf78001ddbc07232e4bfee1feef2f0a Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 15:47:54 -0400 Subject: [PATCH 11/45] fix: expose builder pattern --- src/http.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/http.rs b/src/http.rs index cebf5c2..6ea52f6 100644 --- a/src/http.rs +++ b/src/http.rs @@ -139,9 +139,10 @@ pub(crate) async fn serve_http_connection( /// # Returns /// /// A `Result` indicating success or failure of the server operation. -pub(crate) async fn serve_http_with_shutdown( +pub(crate) async fn serve_http_with_shutdown( service: S, incoming: I, + builder: HttpConnectionBuilder, signal: Option, ) -> Result<(), super::Error> where @@ -154,13 +155,11 @@ where S::Error: Into> + Send, ResBody: Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync, + E: HttpServerConnExec + Send + Sync + 'static, { // Prepare the incoming stream of TCP connections let incoming = crate::tcp::serve_tcp_incoming(incoming); - // Set up the HTTP connection builder - let server = { HttpConnectionBuilder::new(TokioExecutor::new()) }; - // Create a channel for signaling graceful shutdown let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); let signal_tx = Arc::new(signal_tx); @@ -200,7 +199,7 @@ where serve_http_connection( hyper_io, hyper_svc, - server.clone(), + builder.clone(), graceful.then(|| signal_rx.clone()), None ).await; @@ -288,12 +287,15 @@ mod tests { let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); let hyper_service = TowerToHyperService::new(tower_service_fn); let server = tokio::spawn(serve_http_with_shutdown( hyper_service, incoming, + http_server_builder, Some(async { shutdown_rx.await.ok(); }), @@ -342,12 +344,15 @@ mod tests { let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); let hyper_service = TowerToHyperService::new(tower_service_fn); let server = tokio::spawn(serve_http_with_shutdown( hyper_service, incoming, + http_server_builder, Some(async { shutdown_rx.await.ok(); }), From 25d98a35ea4864364dfcbba995689b3c005c16c5 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 15:55:38 -0400 Subject: [PATCH 12/45] feat: basic cut of TLS --- Cargo.toml | 3 +++ src/tls.rs | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 14050bf..c9231b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,10 @@ http-body-util = "0.1.2" hyper = "1.4.1" hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } pin-project = "1.1.5" +rustls = "0.23.13" +rustls-pemfile = "2.1.3" tokio = { version = "1.40.0", features = ["net", "macros"] } +tokio-rustls = "0.26.0" tokio-stream = { version = "0.1.16", features = ["net"] } tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" diff --git a/src/tls.rs b/src/tls.rs index 8b13789..c9d03cc 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1 +1,52 @@ +use std::sync::Arc; +use std::{fs, io}; +use tokio::net::TcpStream; +use tokio_rustls::TlsAcceptor; +use rustls::ServerConfig; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +fn error(err: String) -> io::Error { + io::Error::new(io::ErrorKind::Other, err) +} + +pub fn create_tls_acceptor(cert_path: &str, key_path: &str) -> io::Result { + // Load public certificate. + let certs = load_certs(cert_path)?; + // Load private key. + let key = load_private_key(key_path)?; + + // Build TLS configuration. + let mut server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| error(e.to_string()))?; + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + + Ok(TlsAcceptor::from(Arc::new(server_config))) +} + +// Load public certificate from file. +fn load_certs(filename: &str) -> io::Result>> { + // Open certificate file. + let certfile = fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = io::BufReader::new(certfile); + + // Load and return certificate. + rustls_pemfile::certs(&mut reader).collect() +} + +// Load private key from file. +fn load_private_key(filename: &str) -> io::Result> { + // Open keyfile. + let keyfile = fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = io::BufReader::new(keyfile); + + // Load and return a single private key. + rustls_pemfile::private_key(&mut reader).map(|key| key.unwrap()) +} + +pub async fn tls_accept(acceptor: TlsAcceptor, tcp_stream: TcpStream) -> Result, std::io::Error> { + acceptor.accept(tcp_stream).await.map_err(|e| error(format!("failed to perform tls handshake: {}", e))) +} \ No newline at end of file From 51135057b75f469fd821cc88b4dc3485e2bf67f8 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 17:15:48 -0400 Subject: [PATCH 13/45] feat: basic TLS --- Cargo.toml | 1 + src/tls.rs | 70 ++++++++++++++++-------------------------------------- 2 files changed, 21 insertions(+), 50 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c9231b4..8b0570a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ tokio-rustls = "0.26.0" tokio-stream = { version = "0.1.16", features = ["net"] } tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" +futures = "0.3.30" [dev-dependencies] hyper = { version = "1.4.1", features = ["client"] } diff --git a/src/tls.rs b/src/tls.rs index c9d03cc..ead120e 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,52 +1,22 @@ -use std::sync::Arc; -use std::{fs, io}; -use tokio::net::TcpStream; +use crate::Error; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; -use rustls::ServerConfig; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; - -fn error(err: String) -> io::Error { - io::Error::new(io::ErrorKind::Other, err) -} - -pub fn create_tls_acceptor(cert_path: &str, key_path: &str) -> io::Result { - // Load public certificate. - let certs = load_certs(cert_path)?; - // Load private key. - let key = load_private_key(key_path)?; - - // Build TLS configuration. - let mut server_config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, key) - .map_err(|e| error(e.to_string()))?; - server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; - - Ok(TlsAcceptor::from(Arc::new(server_config))) -} - -// Load public certificate from file. -fn load_certs(filename: &str) -> io::Result>> { - // Open certificate file. - let certfile = fs::File::open(filename) - .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; - let mut reader = io::BufReader::new(certfile); - - // Load and return certificate. - rustls_pemfile::certs(&mut reader).collect() -} - -// Load private key from file. -fn load_private_key(filename: &str) -> io::Result> { - // Open keyfile. - let keyfile = fs::File::open(filename) - .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; - let mut reader = io::BufReader::new(keyfile); - - // Load and return a single private key. - rustls_pemfile::private_key(&mut reader).map(|key| key.unwrap()) +use tokio_stream::{Stream, StreamExt}; + +pub(crate) fn tls_incoming( + tcp_stream: impl Stream>, + tls: TlsAcceptor, +) -> impl Stream, Error>> +where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + tcp_stream.then(move |result| { + let tls = tls.clone(); + async move { + match result { + Ok(io) => tls.accept(io).await.map_err(Error::from), + Err(e) => Err(e), + } + } + }) } - -pub async fn tls_accept(acceptor: TlsAcceptor, tcp_stream: TcpStream) -> Result, std::io::Error> { - acceptor.accept(tcp_stream).await.map_err(|e| error(format!("failed to perform tls handshake: {}", e))) -} \ No newline at end of file From fb2cc7f87e25d37fb076a6d3dd8b370995c9d47d Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 18:56:21 -0400 Subject: [PATCH 14/45] chore: docs --- src/error.rs | 21 ++++++++++++++++++++- src/fuse.rs | 16 +++++++++++++--- src/http.rs | 7 ++++--- src/tls.rs | 31 +++++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 7 deletions(-) diff --git a/src/error.rs b/src/error.rs index 2c38bb4..4624ac3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,38 +1,51 @@ use std::{error::Error as StdError, fmt}; +/// A type alias for the source of an error, which is a boxed trait object. +/// This allows for dynamic dispatch and type erasure of the original error type. type Source = Box; -/// Errors that originate from the server; +/// Represents errors that originate from the server. +/// This struct provides a public API for error handling. pub struct Error { inner: ErrorImpl, } +/// The internal implementation of the Error struct. +/// This separation allows for better control over the public API. struct ErrorImpl { kind: Kind, source: Option, } +/// Enum representing different kinds of errors that can occur. +/// Currently, only includes a Transport variant, but can be extended for more error types. #[derive(Debug)] pub(crate) enum Kind { Transport, } impl Error { + /// Creates a new Error with a specific kind. pub(crate) fn new(kind: Kind) -> Self { Self { inner: ErrorImpl { kind, source: None }, } } + /// Attaches a source error to this Error. + /// This method consumes self and returns a new Error, allowing for method chaining. pub(crate) fn with(mut self, source: impl Into) -> Self { self.inner.source = Some(source.into()); self } + /// Creates a new Transport Error with the given source. + /// This is a convenience method combining new() and with(). pub(crate) fn from_source(source: impl Into) -> Self { Error::new(Kind::Transport).with(source) } + /// Returns a string slice describing the error. fn description(&self) -> &str { match &self.inner.kind { Kind::Transport => "transport error", @@ -40,6 +53,8 @@ impl Error { } } +/// Implements the Debug trait for Error. +/// This provides a custom debug representation of the error. impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut f = f.debug_tuple("tonic::transport::Error"); @@ -54,12 +69,16 @@ impl fmt::Debug for Error { } } +/// Implements the Display trait for Error. +/// This provides a user-friendly string representation of the error. impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.description()) } } +/// Implements the std::error::Error trait for Error. +/// This allows our custom Error to be used with the standard error handling mechanisms in Rust. impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { self.inner diff --git a/src/fuse.rs b/src/fuse.rs index f75bb02..cff01c3 100644 --- a/src/fuse.rs +++ b/src/fuse.rs @@ -3,11 +3,14 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -// From `futures-util` crate, borrowed since this is the only dependency hyper-server requires. -// LICENSE: MIT or Apache-2.0 -// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +/// `Fuse` is a wrapper around a future that ensures it can only complete once. +/// After the wrapped future completes, all subsequent polls will return `Poll::Pending`. +/// +/// This struct is borrowed from the `futures-util` crate and is used in the hyper server. +/// LICENSE: MIT or Apache-2.0 #[pin_project] pub(crate) struct Fuse { + /// The wrapped future. Once it completes, this will be set to `None`. #[pin] pub(crate) inner: Option, } @@ -16,14 +19,21 @@ impl Future for Fuse where F: Future, { + /// The output type is the same as the wrapped future's output type. type Output = F::Output; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Match on the pinned inner future match self.as_mut().project().inner.as_pin_mut() { + // If we have a future, poll it Some(fut) => fut.poll(cx).map(|output| { + // If the future completes, set inner to None + // This ensures that future calls to poll will return Poll::Pending self.project().inner.set(None); output }), + // If inner is None, it means the future has already completed + // So we return Poll::Pending None => Poll::Pending, } } diff --git a/src/http.rs b/src/http.rs index 6ea52f6..66c76ee 100644 --- a/src/http.rs +++ b/src/http.rs @@ -123,12 +123,13 @@ pub(crate) async fn serve_http_connection( /// /// # Type Parameters /// -/// * `S`: The service type that processes HTTP requests. -/// * `I`: The incoming stream of IO objects. +/// * `E`: The executor type for the HTTP server connection. /// * `F`: The future type for the shutdown signal. +/// * `I`: The incoming stream of IO objects. /// * `IO`: The I/O type for the HTTP connection. /// * `IE`: The error type for the incoming stream. /// * `ResBody`: The response body type. +/// * `S`: The service type that processes HTTP requests. /// /// # Arguments /// @@ -139,7 +140,7 @@ pub(crate) async fn serve_http_connection( /// # Returns /// /// A `Result` indicating success or failure of the server operation. -pub(crate) async fn serve_http_with_shutdown( +pub(crate) async fn serve_http_with_shutdown( service: S, incoming: I, builder: HttpConnectionBuilder, diff --git a/src/tls.rs b/src/tls.rs index ead120e..566c713 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -3,6 +3,32 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; use tokio_stream::{Stream, StreamExt}; +/// Creates a stream of TLS-encrypted connections from a stream of TCP connections. +/// +/// This function takes a stream of TCP connections and a TLS acceptor, and produces +/// a new stream that yields TLS-encrypted connections. It handles both the successful +/// case of establishing a TLS connection and the error cases. +/// +/// # Type Parameters +/// +/// * `IO`: The I/O type representing the underlying TCP connection. It must implement +/// `AsyncRead`, `AsyncWrite`, `Unpin`, `Send`, and have a static lifetime. +/// +/// # Arguments +/// +/// * `tcp_stream`: A stream that yields `Result` items, representing incoming +/// TCP connections or errors. +/// * `tls`: A `TlsAcceptor` used to perform the TLS handshake on each TCP connection. +/// +/// # Returns +/// +/// A new `Stream` that yields `Result, Error>` items. +/// Each item is either a successfully established TLS connection or an error. +/// +/// # Error Handling +/// +/// - If the input `tcp_stream` yields an error, that error is propagated. +/// - If the TLS handshake fails, the error is wrapped in the crate's `Error` type. pub(crate) fn tls_incoming( tcp_stream: impl Stream>, tls: TlsAcceptor, @@ -10,11 +36,16 @@ pub(crate) fn tls_incoming( where IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, { + // Transform each item in the TCP stream into a TLS stream tcp_stream.then(move |result| { + // Clone the TLS acceptor for each connection let tls = tls.clone(); + async move { match result { + // TODO(Can we get at the raw IO here so that it looks the same after the handshake?) Ok(io) => tls.accept(io).await.map_err(Error::from), + // TODO(Unwrap into crate error and handle) Err(e) => Err(e), } } From 116c27dfcd7d19d2743b78b24ab4504e7613b6ca Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 20:40:14 -0400 Subject: [PATCH 15/45] feat: add file loader helpers --- src/tls.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/tls.rs b/src/tls.rs index 566c713..99365d8 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,3 +1,5 @@ +use std::{fs, io}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use crate::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; @@ -51,3 +53,23 @@ where } }) } + +// Load the public certificate from a file. +fn load_certs(filename: &str) -> io::Result>> { + // Open certificate file. + let certfile = fs::File::open(filename).unwrap(); + let mut reader = io::BufReader::new(certfile); + + // Load and return certificate. + rustls_pemfile::certs(&mut reader).collect() +} + +// Load the private key from a file. +fn load_private_key(filename: &str) -> io::Result> { + // Open keyfile. + let keyfile = fs::File::open(filename).unwrap(); + let mut reader = io::BufReader::new(keyfile); + + // Load and return a single private key. + rustls_pemfile::private_key(&mut reader).map(|key| key.unwrap()) +} \ No newline at end of file From 386e93a7477ab3885e77ce3d8906786280c13d8d Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 20:41:13 -0400 Subject: [PATCH 16/45] fix: add basic error handling and docs --- Cargo.toml | 2 +- src/tls.rs | 59 ++++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8b0570a..603878a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ version = "1.0.0" [dependencies] async-stream = "0.3.5" bytes = "1.7.1" +futures = "0.3.30" http = "1.1.0" http-body = "1.0.1" http-body-util = "0.1.2" @@ -27,7 +28,6 @@ tokio-rustls = "0.26.0" tokio-stream = { version = "0.1.16", features = ["net"] } tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" -futures = "0.3.30" [dev-dependencies] hyper = { version = "1.4.1", features = ["client"] } diff --git a/src/tls.rs b/src/tls.rs index 99365d8..947e91c 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,6 +1,6 @@ -use std::{fs, io}; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use crate::Error; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use std::{fs, io}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; use tokio_stream::{Stream, StreamExt}; @@ -41,35 +41,64 @@ where // Transform each item in the TCP stream into a TLS stream tcp_stream.then(move |result| { // Clone the TLS acceptor for each connection + // This is necessary because the acceptor is moved into the async block let tls = tls.clone(); async move { match result { - // TODO(Can we get at the raw IO here so that it looks the same after the handshake?) - Ok(io) => tls.accept(io).await.map_err(Error::from), - // TODO(Unwrap into crate error and handle) + // If the TCP connection was successfully established + Ok(io) => { + // Attempt to perform the TLS handshake + // If successful, return the TLS stream; otherwise, wrap the error + tls.accept(io).await.map_err(Error::from) + } + // If there was an error establishing the TCP connection, propagate it Err(e) => Err(e), } } }) } -// Load the public certificate from a file. +/// Load the public certificate from a file. +/// +/// This function reads a PEM-encoded certificate file and returns a vector of +/// parsed certificates. +/// +/// # Arguments +/// +/// * `filename`: The path to the certificate file. +/// +/// # Returns +/// +/// A `Result` containing a vector of `CertificateDer` on success, or an `io::Error` on failure. fn load_certs(filename: &str) -> io::Result>> { - // Open certificate file. - let certfile = fs::File::open(filename).unwrap(); + // Open certificate file + let certfile = fs::File::open(filename)?; let mut reader = io::BufReader::new(certfile); - // Load and return certificate. + // Load and return certificates + // The `collect()` method is used to gather all certificates into a vector rustls_pemfile::certs(&mut reader).collect() } -// Load the private key from a file. +/// Load the private key from a file. +/// +/// This function reads a PEM-encoded private key file and returns the parsed private key. +/// +/// # Arguments +/// +/// * `filename`: The path to the private key file. +/// +/// # Returns +/// +/// A `Result` containing a `PrivateKeyDer` on success, or an `io::Error` on failure. fn load_private_key(filename: &str) -> io::Result> { - // Open keyfile. - let keyfile = fs::File::open(filename).unwrap(); + // Open keyfile + let keyfile = fs::File::open(filename)?; let mut reader = io::BufReader::new(keyfile); - // Load and return a single private key. - rustls_pemfile::private_key(&mut reader).map(|key| key.unwrap()) -} \ No newline at end of file + // Load and return a single private key + // The `?` operator is used for error propagation + rustls_pemfile::private_key(&mut reader)? + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found in file")) +} From 7333d38e640cc0694aa83b3131f5f1fdf7d05ade Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 21:23:26 -0400 Subject: [PATCH 17/45] chore: add TLS tests --- Cargo.toml | 1 + src/tls.rs | 321 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 322 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 603878a..b57470e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ tokio-rustls = "0.26.0" tokio-stream = { version = "0.1.16", features = ["net"] } tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" +tracing-subscriber = "0.3.18" [dev-dependencies] hyper = { version = "1.4.1", features = ["client"] } diff --git a/src/tls.rs b/src/tls.rs index 947e91c..6e81849 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -102,3 +102,324 @@ fn load_private_key(filename: &str) -> io::Result> { rustls_pemfile::private_key(&mut reader)? .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found in file")) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::tcp::serve_tcp_incoming; + use futures::StreamExt; + use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; + use rustls::{ClientConfig, ServerConfig}; + use std::net::SocketAddr; + use std::pin::Pin; + use std::sync::Arc; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::net::{TcpListener, TcpStream}; + use tokio_rustls::TlsAcceptor; + use tokio_stream::wrappers::TcpListenerStream; + use tracing::{debug, error, info, warn}; + + // Helper function to create a TLS acceptor for testing + async fn create_test_tls_acceptor() -> io::Result { + debug!("Creating test TLS acceptor"); + let certs = load_certs("examples/sample.pem")?; + let key = load_private_key("examples/sample.rsa")?; + + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| { + error!("Failed to create ServerConfig: {}", e); + io::Error::new(io::ErrorKind::Other, e) + })?; + + Ok(TlsAcceptor::from(Arc::new(config))) + } + + #[tokio::test] + async fn test_tls_incoming_success() -> Result<(), Box> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_tls_incoming_success"); + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let server_addr = listener.local_addr()?; + debug!("Server listening on {}", server_addr); + let incoming = TcpListenerStream::new(listener); + + let tls_acceptor = create_test_tls_acceptor().await?; + + // Use serve_tcp_incoming to handle TCP connections + let tcp_incoming = serve_tcp_incoming(incoming); + + // Spawn the server task + let server_task = tokio::spawn(async move { + debug!("Server task started"); + let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); + let result = tls_stream.next().await; + debug!("Server received connection: {:?}", result.is_some()); + result + }); + + // Connect to the server with a TLS client + let mut root_store = rustls::RootCertStore::empty(); + let certs = load_certs("examples/sample.pem")?; + root_store.add_parsable_certificates(certs); + + let client_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + + debug!("Client connecting to {}", server_addr); + let tcp_stream = TcpStream::connect(server_addr).await?; + let domain = ServerName::try_from("localhost")?; + let _client_stream = connector.connect(domain, tcp_stream).await?; + debug!("Client connected successfully"); + + // Wait for the server to accept the connection + let result = server_task + .await? + .ok_or("Server task completed without result")?; + match result { + Ok(_) => info!("TLS connection established successfully"), + Err(ref e) => error!("TLS connection failed: {}", e), + } + assert!(result.is_ok()); + + Ok(()) + } + + #[tokio::test] + async fn test_tls_incoming_invalid_cert() -> Result<(), Box> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_tls_incoming_invalid_cert"); + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let server_addr = listener.local_addr()?; + debug!("Server listening on {}", server_addr); + let incoming = TcpListenerStream::new(listener); + + // Create a TLS acceptor with an invalid certificate + let invalid_cert = vec![CertificateDer::from(vec![0; 32])]; // Invalid certificate + let key = load_private_key("examples/sample.rsa")?; + + // Expect this to fail and log the error + let config_result = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(invalid_cert, key); + + match config_result { + Ok(_) => warn!("ServerConfig creation unexpectedly succeeded with invalid cert"), + Err(e) => info!("ServerConfig creation failed as expected: {}", e), + } + + // Use a valid certificate for the server to allow the test to continue + let valid_certs = load_certs("examples/sample.pem")?; + let valid_key = load_private_key("examples/sample.rsa")?; + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(valid_certs, valid_key) + .expect("ServerConfig creation should succeed with valid cert"); + + let tls_acceptor = TlsAcceptor::from(Arc::new(config)); + + // Use serve_tcp_incoming to handle TCP connections + let tcp_incoming = serve_tcp_incoming(incoming); + + // Spawn the server task + let server_task = tokio::spawn(async move { + debug!("Server task started"); + let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); + let result = tls_stream.next().await; + debug!("Server received connection: {:?}", result.is_some()); + result + }); + + // Connect to the server with a TLS client that doesn't trust the server's certificate + let connector = tokio_rustls::TlsConnector::from(Arc::new( + ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(), + )); + + debug!("Client connecting to {}", server_addr); + let tcp_stream = TcpStream::connect(server_addr).await?; + let domain = ServerName::try_from("localhost")?; + + // This connection should fail due to certificate verification + let client_result = connector.connect(domain, tcp_stream).await; + match &client_result { + Ok(_) => warn!("Client connection succeeded unexpectedly"), + Err(e) => info!("Client connection failed as expected: {}", e), + } + assert!(client_result.is_err()); + + // The server should not encounter an error, but the connection should not be established + let server_result = server_task + .await? + .ok_or("Server task completed without result")?; + match &server_result { + Ok(_) => warn!("Server accepted connection unexpectedly"), + Err(e) => info!("Server did not establish connection as expected: {}", e), + } + assert!(server_result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_tls_incoming_client_hello_timeout() -> Result<(), Box> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_tls_incoming_client_hello_timeout"); + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let server_addr = listener.local_addr()?; + debug!("Server listening on {}", server_addr); + let incoming = TcpListenerStream::new(listener); + + let tls_acceptor = create_test_tls_acceptor().await?; + + // Use serve_tcp_incoming to handle TCP connections + let tcp_incoming = serve_tcp_incoming(incoming); + + // Spawn the server task + let server_task = tokio::spawn(async move { + debug!("Server task started"); + let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); + let result = + tokio::time::timeout(std::time::Duration::from_secs(1), tls_stream.next()).await; + debug!("Server task completed with result: {:?}", result.is_err()); + result + }); + + // Connect with a regular TCP client (no TLS handshake) + debug!("Client connecting with plain TCP to {}", server_addr); + let _tcp_stream = TcpStream::connect(server_addr).await?; + debug!("Client connected with plain TCP"); + + // The server task should timeout + let result = server_task.await?; + match result { + Ok(_) => warn!("Server did not timeout as expected"), + Err(ref e) => info!("Server timed out as expected: {}", e), + } + assert!(result.is_err()); // Timeout error + + Ok(()) + } + + #[tokio::test] + async fn test_load_certs() -> io::Result<()> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_load_certs"); + let certs = load_certs("examples/sample.pem")?; + debug!("Loaded {} certificates", certs.len()); + assert!(!certs.is_empty(), "Certificate file should not be empty"); + Ok(()) + } + + #[tokio::test] + async fn test_load_private_key() -> io::Result<()> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_load_private_key"); + let key = load_private_key("examples/sample.rsa")?; + debug!("Loaded private key, length: {}", key.secret_der().len()); + assert!( + !key.secret_der().is_empty(), + "Private key should not be empty" + ); + Ok(()) + } + + // Simulating the tls_incoming function for testing purposes + // Replace this with your actual implementation + fn tls_incoming( + incoming: impl Stream> + Send + 'static, + tls_acceptor: TlsAcceptor, + ) -> impl Stream, Error>> + Send + 'static + where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + Box::pin(incoming.then(move |result| { + let tls_acceptor = tls_acceptor.clone(); + async move { + match result { + Ok(io) => { + debug!("Accepting TLS connection"); + let accept_result = tls_acceptor.accept(io).await.map_err(Error::from); + match &accept_result { + Ok(_) => debug!("TLS connection accepted successfully"), + Err(e) => warn!("Failed to accept TLS connection: {}", e), + } + accept_result + } + Err(e) => { + warn!("Error in incoming connection: {}", e); + Err(e) + } + } + } + })) + } + + // Helper function to load certificates + fn load_certs(filename: &str) -> io::Result>> { + debug!("Loading certificates from {}", filename); + let certfile = std::fs::File::open(filename).map_err(|e| { + error!("Failed to open certificate file: {}", e); + io::Error::new( + io::ErrorKind::Other, + format!("failed to open {}: {}", filename, e), + ) + })?; + let mut reader = io::BufReader::new(certfile); + let certs = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .map_err(|e| { + error!("Failed to parse certificates: {}", e); + io::Error::new(io::ErrorKind::Other, e) + })?; + debug!("Loaded {} certificates", certs.len()); + Ok(certs) + } + + // Helper function to load private key + fn load_private_key(filename: &str) -> io::Result> { + debug!("Loading private key from {}", filename); + let keyfile = std::fs::File::open(filename).map_err(|e| { + error!("Failed to open private key file: {}", e); + io::Error::new( + io::ErrorKind::Other, + format!("failed to open {}: {}", filename, e), + ) + })?; + let mut reader = io::BufReader::new(keyfile); + let key = rustls_pemfile::private_key(&mut reader) + .map_err(|e| { + error!("Failed to parse private key: {}", e); + io::Error::new(io::ErrorKind::Other, e) + })? + .ok_or_else(|| { + error!("No private key found in file"); + io::Error::new(io::ErrorKind::Other, "no private key found") + })?; + debug!("Loaded private key"); + Ok(key) + } +} From 9ad12d8d329bfe923fb5c5420929594fad3a07e6 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 21:27:36 -0400 Subject: [PATCH 18/45] fix: dry code --- Cargo.toml | 2 +- src/tls.rs | 47 +---------------------------------------------- 2 files changed, 2 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b57470e..d6e054b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,8 +28,8 @@ tokio-rustls = "0.26.0" tokio-stream = { version = "0.1.16", features = ["net"] } tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" -tracing-subscriber = "0.3.18" [dev-dependencies] hyper = { version = "1.4.1", features = ["client"] } tokio = { version = "1.0", features = ["rt", "net", "test-util"] } +tracing-subscriber = "0.3.18" diff --git a/src/tls.rs b/src/tls.rs index 6e81849..73ddb1a 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -297,7 +297,7 @@ mod tests { debug!("Server task started"); let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); let result = - tokio::time::timeout(std::time::Duration::from_secs(1), tls_stream.next()).await; + tokio::time::timeout(std::time::Duration::from_millis(10), tls_stream.next()).await; debug!("Server task completed with result: {:?}", result.is_err()); result }); @@ -377,49 +377,4 @@ mod tests { } })) } - - // Helper function to load certificates - fn load_certs(filename: &str) -> io::Result>> { - debug!("Loading certificates from {}", filename); - let certfile = std::fs::File::open(filename).map_err(|e| { - error!("Failed to open certificate file: {}", e); - io::Error::new( - io::ErrorKind::Other, - format!("failed to open {}: {}", filename, e), - ) - })?; - let mut reader = io::BufReader::new(certfile); - let certs = rustls_pemfile::certs(&mut reader) - .collect::, _>>() - .map_err(|e| { - error!("Failed to parse certificates: {}", e); - io::Error::new(io::ErrorKind::Other, e) - })?; - debug!("Loaded {} certificates", certs.len()); - Ok(certs) - } - - // Helper function to load private key - fn load_private_key(filename: &str) -> io::Result> { - debug!("Loading private key from {}", filename); - let keyfile = std::fs::File::open(filename).map_err(|e| { - error!("Failed to open private key file: {}", e); - io::Error::new( - io::ErrorKind::Other, - format!("failed to open {}: {}", filename, e), - ) - })?; - let mut reader = io::BufReader::new(keyfile); - let key = rustls_pemfile::private_key(&mut reader) - .map_err(|e| { - error!("Failed to parse private key: {}", e); - io::Error::new(io::ErrorKind::Other, e) - })? - .ok_or_else(|| { - error!("No private key found in file"); - io::Error::new(io::ErrorKind::Other, "no private key found") - })?; - debug!("Loaded private key"); - Ok(key) - } } From 3f4374f8970d7aa5090c69c9c163b7f951a78a57 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 21:33:56 -0400 Subject: [PATCH 19/45] feat: expose public functions --- Cargo.toml | 2 +- src/http.rs | 4 ++-- src/lib.rs | 5 +++++ src/tcp.rs | 2 +- src/tls.rs | 6 +++--- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d6e054b..f08d808 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ categories = ["asynchronous", "network-programming", "web-programming"] description = "High level server for hyper and tower." edition = "2021" homepage = "https://github.com/warlock-labls/hyper-server" -keywords = ["axum", "tonic", "hyper", "tower", "server"] +keywords = ["tcp", "tls", "http", "hyper", "tokio"] license = "MIT" name = "hyper-server" readme = "README.md" diff --git a/src/http.rs b/src/http.rs index 66c76ee..317dbe4 100644 --- a/src/http.rs +++ b/src/http.rs @@ -53,7 +53,7 @@ async fn sleep_or_pending(wait_for: Option) { /// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. /// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection /// before initiating a graceful shutdown. -pub(crate) async fn serve_http_connection( +pub async fn serve_http_connection( hyper_io: IO, hyper_service: S, builder: HttpConnectionBuilder, @@ -140,7 +140,7 @@ pub(crate) async fn serve_http_connection( /// # Returns /// /// A `Result` indicating success or failure of the server operation. -pub(crate) async fn serve_http_with_shutdown( +pub async fn serve_http_with_shutdown( service: S, incoming: I, builder: HttpConnectionBuilder, diff --git a/src/lib.rs b/src/lib.rs index 9a714c1..39d4bba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,4 +4,9 @@ mod http; mod tcp; mod tls; +pub use tcp::serve_tcp_incoming; +pub use tls::serve_tls_incoming; +pub use http::serve_http_with_shutdown; +pub use http::serve_http_connection; + pub(crate) type Error = Box; diff --git a/src/tcp.rs b/src/tcp.rs index 7b72c96..2d5be6d 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -69,7 +69,7 @@ fn handle_accept_error(e: impl Into) -> ControlFlow { /// This function uses `handle_accept_error` to determine whether to continue accepting /// connections after an error occurs. Non-fatal errors are logged and skipped, while /// fatal errors cause the stream to yield an error and terminate. -pub(crate) fn serve_tcp_incoming( +pub fn serve_tcp_incoming( incoming: impl Stream> + Send + 'static, ) -> impl Stream> where diff --git a/src/tls.rs b/src/tls.rs index 73ddb1a..ce598cc 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -31,7 +31,7 @@ use tokio_stream::{Stream, StreamExt}; /// /// - If the input `tcp_stream` yields an error, that error is propagated. /// - If the TLS handshake fails, the error is wrapped in the crate's `Error` type. -pub(crate) fn tls_incoming( +pub fn serve_tls_incoming( tcp_stream: impl Stream>, tls: TlsAcceptor, ) -> impl Stream, Error>> @@ -71,7 +71,7 @@ where /// # Returns /// /// A `Result` containing a vector of `CertificateDer` on success, or an `io::Error` on failure. -fn load_certs(filename: &str) -> io::Result>> { +pub fn load_certs(filename: &str) -> io::Result>> { // Open certificate file let certfile = fs::File::open(filename)?; let mut reader = io::BufReader::new(certfile); @@ -92,7 +92,7 @@ fn load_certs(filename: &str) -> io::Result>> { /// # Returns /// /// A `Result` containing a `PrivateKeyDer` on success, or an `io::Error` on failure. -fn load_private_key(filename: &str) -> io::Result> { +pub fn load_private_key(filename: &str) -> io::Result> { // Open keyfile let keyfile = fs::File::open(filename)?; let mut reader = io::BufReader::new(keyfile); From be3ad7ffd061ba0741352cf731b91de037ba5fd1 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 21:58:47 -0400 Subject: [PATCH 20/45] chore: cleanup and expose API --- src/error.rs | 8 ++++---- src/http.rs | 3 ++- src/lib.rs | 13 ++++++++----- src/tls.rs | 3 +-- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/error.rs b/src/error.rs index 4624ac3..78d89a5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -20,13 +20,13 @@ struct ErrorImpl { /// Enum representing different kinds of errors that can occur. /// Currently, only includes a Transport variant, but can be extended for more error types. #[derive(Debug)] -pub(crate) enum Kind { +pub enum Kind { Transport, } impl Error { /// Creates a new Error with a specific kind. - pub(crate) fn new(kind: Kind) -> Self { + pub fn new(kind: Kind) -> Self { Self { inner: ErrorImpl { kind, source: None }, } @@ -34,14 +34,14 @@ impl Error { /// Attaches a source error to this Error. /// This method consumes self and returns a new Error, allowing for method chaining. - pub(crate) fn with(mut self, source: impl Into) -> Self { + pub fn with(mut self, source: impl Into) -> Self { self.inner.source = Some(source.into()); self } /// Creates a new Transport Error with the given source. /// This is a convenience method combining new() and with(). - pub(crate) fn from_source(source: impl Into) -> Self { + pub fn from_source(source: impl Into) -> Self { Error::new(Kind::Transport).with(source) } diff --git a/src/http.rs b/src/http.rs index 317dbe4..512b3b4 100644 --- a/src/http.rs +++ b/src/http.rs @@ -7,7 +7,7 @@ use http_body::Body; use hyper::body::Incoming; use hyper::service::Service; use hyper_util::{ - rt::{TokioExecutor, TokioIo}, + rt::TokioIo, server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec}, }; use tokio::io::{AsyncRead, AsyncWrite}; @@ -232,6 +232,7 @@ mod tests { use bytes::Bytes; use http_body_util::{BodyExt, Empty, Full}; use hyper::{body::Incoming, Request, Response, StatusCode}; + use hyper_util::rt::TokioExecutor; use hyper_util::service::TowerToHyperService; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; diff --git a/src/lib.rs b/src/lib.rs index 39d4bba..0302b0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,15 @@ +pub use error::{Error as TransportError, Kind as TransportErrorKind}; +pub use http::serve_http_connection; +pub use http::serve_http_with_shutdown; +pub use tcp::serve_tcp_incoming; +pub use tls::load_certs; +pub use tls::load_private_key; +pub use tls::serve_tls_incoming; + mod error; mod fuse; mod http; mod tcp; mod tls; -pub use tcp::serve_tcp_incoming; -pub use tls::serve_tls_incoming; -pub use http::serve_http_with_shutdown; -pub use http::serve_http_connection; - pub(crate) type Error = Box; diff --git a/src/tls.rs b/src/tls.rs index ce598cc..61ae64a 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -108,10 +108,9 @@ mod tests { use super::*; use crate::tcp::serve_tcp_incoming; use futures::StreamExt; - use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; + use rustls::pki_types::{CertificateDer, ServerName}; use rustls::{ClientConfig, ServerConfig}; use std::net::SocketAddr; - use std::pin::Pin; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; From f11605948308c5bdc36c96794c7346511d68f14f Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 22:01:08 -0400 Subject: [PATCH 21/45] fix: bump MSRV --- .github/workflows/CI.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d7228dd..67224e8 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: strategy: fail-fast: false matrix: - rust: [ "stable", "beta", "nightly", "1.65" ] # MSRV + rust: [ "stable", "beta", "nightly", "1.80" ] # MSRV flags: [ "--no-default-features", "", "--all-features" ] exclude: # Skip because some features have highest MSRV. @@ -37,10 +37,10 @@ jobs: cache-on-failure: true # Only run tests on the latest stable and above - name: check - if: ${{ matrix.rust == '1.65' }} # MSRV + if: ${{ matrix.rust == '1.80' }} # MSRV run: cargo check --workspace ${{ matrix.flags }} - name: test - if: ${{ matrix.rust != '1.65' }} # MSRV + if: ${{ matrix.rust != '1.80' }} # MSRV run: cargo test --workspace ${{ matrix.flags }} coverage: From 7c2867da881bbf730d4a4ce695aac7c75842bf9e Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 22:02:12 -0400 Subject: [PATCH 22/45] fix: bump MSRV 2 --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 67224e8..4410abd 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: flags: [ "--no-default-features", "", "--all-features" ] exclude: # Skip because some features have highest MSRV. - - rust: "1.65" # MSRV + - rust: "1.80" # MSRV flags: "--all-features" steps: - uses: actions/checkout@v3 From 840ca81757044c23180a402d20b8b8ddf5f53142 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 22:45:32 -0400 Subject: [PATCH 23/45] chore: add new readme --- README.md | 156 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 155 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4372204..6f668c6 100644 --- a/README.md +++ b/README.md @@ -1 +1,155 @@ -# hyper-server \ No newline at end of file +# hyper-server + +[![License](https://img.shields.io/crates/l/hyper-server)](https://choosealicense.com/licenses/mit/) +[![Crates.io](https://img.shields.io/crates/v/hyper-server)](https://crates.io/crates/hyper-server) +[![Docs](https://img.shields.io/crates/v/hyper-server?color=blue&label=docs)](https://docs.rs/hyper-server/) +![CI](https://github.com/warlock-labs/hyper-server/actions/workflows/CI.yml/badge.svg) +[![codecov](https://codecov.io/gh/warlock-labs/hyper-server/branch/master/graph/badge.svg?token=8W5MEJQSW6)](https://codecov.io/gh/warlock-labs/hyper-server) + +A high-performance, modular server implementation built on [hyper], designed to +work seamlessly with [axum], [tonic], [tower], and other tower-compatible +frameworks. + +## Features + +- HTTP/1 and HTTP/2 support +- TLS/HTTPS through [rustls] +- High performance leveraging [hyper] 1.0 +- Modular architecture based on `tokio::net::TcpListener`, `tokio-rustls::Acceptor`, and `hyper::server::conn::auto` +- Flexible integration with [tower] the ecosystem, supporting various backends: + - [axum] + - [tonic] + - [tungstenite] + - Any `hyper::service::Service`, `tower::Service`, or `tower::Layer` composition + +## Installation + +Add this to your `Cargo.toml`: + +```toml +[dependencies] +hyper-server = "0.7.0" +``` + +## Usage + +Here's an example of how to use hyper-server with a simple tower service: + +```rust +use std::convert::Infallible; +use std::net::SocketAddr; +use bytes::Bytes; +use http::{Request, Response, StatusCode}; +use http_body_util::Full; +use hyper_util::rt::TokioExecutor; +use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +use tower::Service; +use tokio::net::TcpListener; + +// A simple tower service +#[derive(Clone)] +struct HelloService; + +impl Service> for HelloService { + type Response = Response>; + type Error = Infallible; + type Future = std::pin::Pin> + Send>>; + + fn call(&mut self, req: Request) -> Self::Future { + Box::pin(async move { + let response = match (req.method(), req.uri().path()) { + (&hyper::Method::GET, "/") => { + Response::new(Full::new(Bytes::from("Hello, World!"))) + } + _ => { + let mut res = Response::new(Full::new(Bytes::from("Not Found"))); + *res.status_mut() = StatusCode::NOT_FOUND; + res + } + }; + Ok(response) + }) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let listener = TcpListener::bind(addr).await?; + println!("Listening on http://{}", addr); + + let http_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let service = HelloService; + + hyper_server::serve_http_with_shutdown( + service, + tokio_stream::wrappers::TcpListenerStream::new(listener), + http_builder, + None, + ) + .await?; + + Ok(()) +} +``` + +For more advanced usage and examples, including TLS configuration and custom service implementations, please refer to the [examples directory](/examples). + +## Architecture + +hyper-server provides a layered, composable architecture: + +1. TCP Listening: `tokio::net::TcpListener` +2. TLS (optional): `rustls::TlsAcceptor` +3. HTTP: `hyper_util::server::conn::auto` +4. Service: `hyper::service::Service` +5. Middleware: `tower::Service` +6. Application: Your choice of tower-compatible framework (axum, tonic, etc.) or custom service implementation + +This structure allows for easy customization and extension at each layer. You +can integrate your own implementations at any level of the stack, providing +maximum flexibility for your specific use case. + +## Security + +hyper-server takes security seriously. We use `rustls` for TLS support, which +provides modern, secure defaults. However, please ensure that you configure +your server appropriately for your use case, especially when deploying in +production environments. + +If you discover any security-related issues, please email team@warlock.xyz +instead of using the issue tracker. + +## API + +For detailed API documentation, please refer to the [API docs on docs.rs](https://docs.rs/hyper-server/). + +## Minimum Supported Rust Version + +hyper-server's MSRV is `1.80`. + +## Contributing + +We welcome contributions to hyper-server! Our contributing guidelines are +inspired by the Rule of St. Benedict, emphasizing humility, listening, +and community. Before contributing, please familiarize yourself with these +principles at [The Rule of St. Benedict](http://www.benedictfriend.org/the-rule.html). + +Key points for contributors: + +- Listen first, speak second (Chapter 6) +- Be humble in your contributions (Chapter 7) +- Work diligently and carefully (Chapter 48) +- Treat all code and ideas with respect (Chapter 72) + +## License + +This project is licensed under the MIT License — see +the [LICENSE](/LICENSE) file for details. + +[axum]: https://crates.io/crates/axum +[hyper]: https://crates.io/crates/hyper +[rustls]: https://crates.io/crates/rustls +[tower]: https://crates.io/crates/tower +[tonic]: https://crates.io/crates/tonic +[tungstenite]: https://crates.io/crates/tungstenite \ No newline at end of file From 4e40cafa31afd6db605c796f605f3fcbddb185e1 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 22:50:51 -0400 Subject: [PATCH 24/45] fix: version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index f08d808..84406f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ license = "MIT" name = "hyper-server" readme = "README.md" repository = "https://github.com/valorem-labs-inc/hyper-server" -version = "1.0.0" +version = "0.7.0" [dependencies] async-stream = "0.3.5" From 479cc869a46af63e656fb8d87ea3d5a6e1e9cd34 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Tue, 10 Sep 2024 22:51:50 -0400 Subject: [PATCH 25/45] fix: remove usage example for now --- README.md | 56 +------------------------------------------------------ 1 file changed, 1 insertion(+), 55 deletions(-) diff --git a/README.md b/README.md index 6f668c6..9733aa0 100644 --- a/README.md +++ b/README.md @@ -36,61 +36,7 @@ hyper-server = "0.7.0" Here's an example of how to use hyper-server with a simple tower service: ```rust -use std::convert::Infallible; -use std::net::SocketAddr; -use bytes::Bytes; -use http::{Request, Response, StatusCode}; -use http_body_util::Full; -use hyper_util::rt::TokioExecutor; -use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; -use tower::Service; -use tokio::net::TcpListener; - -// A simple tower service -#[derive(Clone)] -struct HelloService; - -impl Service> for HelloService { - type Response = Response>; - type Error = Infallible; - type Future = std::pin::Pin> + Send>>; - - fn call(&mut self, req: Request) -> Self::Future { - Box::pin(async move { - let response = match (req.method(), req.uri().path()) { - (&hyper::Method::GET, "/") => { - Response::new(Full::new(Bytes::from("Hello, World!"))) - } - _ => { - let mut res = Response::new(Full::new(Bytes::from("Not Found"))); - *res.status_mut() = StatusCode::NOT_FOUND; - res - } - }; - Ok(response) - }) - } -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = TcpListener::bind(addr).await?; - println!("Listening on http://{}", addr); - - let http_builder = HttpConnectionBuilder::new(TokioExecutor::new()); - let service = HelloService; - - hyper_server::serve_http_with_shutdown( - service, - tokio_stream::wrappers::TcpListenerStream::new(listener), - http_builder, - None, - ) - .await?; - - Ok(()) -} + ``` For more advanced usage and examples, including TLS configuration and custom service implementations, please refer to the [examples directory](/examples). From 9ea325ea63db61fb8f9ae741e0b18d75ad4a7bfd Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 00:46:44 -0400 Subject: [PATCH 26/45] feat: worked example in full --- README.md | 167 ++++++++++-- examples/full.rs | 131 +++++++++ src/http.rs | 670 ++++++++++++++++++++++++++++++++++++++++------- src/io.rs | 84 ++++++ src/lib.rs | 1 + 5 files changed, 936 insertions(+), 117 deletions(-) create mode 100644 examples/full.rs create mode 100644 src/io.rs diff --git a/README.md b/README.md index 9733aa0..fd46207 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,8 @@ ![CI](https://github.com/warlock-labs/hyper-server/actions/workflows/CI.yml/badge.svg) [![codecov](https://codecov.io/gh/warlock-labs/hyper-server/branch/master/graph/badge.svg?token=8W5MEJQSW6)](https://codecov.io/gh/warlock-labs/hyper-server) -A high-performance, modular server implementation built on [hyper], designed to -work seamlessly with [axum], [tonic], [tower], and other tower-compatible +A high-performance, modular server implementation built on [hyper], designed to +work seamlessly with [axum], [tonic], [tower], and other tower-compatible frameworks. ## Features @@ -19,7 +19,6 @@ frameworks. - Flexible integration with [tower] the ecosystem, supporting various backends: - [axum] - [tonic] - - [tungstenite] - Any `hyper::service::Service`, `tower::Service`, or `tower::Layer` composition ## Installation @@ -33,13 +32,144 @@ hyper-server = "0.7.0" ## Usage -Here's an example of how to use hyper-server with a simple tower service: +Here's an example of how to use hyper-server with a simple tower lambda service via TCP/TLS/HTTP2 transport: ```rust - +use std::convert::Infallible; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use http_body_util::Full; +use hyper::{Request, Response}; +use hyper::body::Incoming; +use hyper_util::rt::TokioExecutor; +use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +use hyper_util::service::TowerToHyperService; +use rustls::ServerConfig; +use tokio::net::TcpListener; +use tokio_stream::wrappers::TcpListenerStream; +use tower::{Layer, ServiceBuilder}; +use tracing::{trace, debug, info}; + +use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; + +// Define a simple service that responds with "Hello, World!" +async fn hello(_: Request) -> Result>, Infallible> { + Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +} + +// Define a Custom middleware to add a header to all responses, for example +struct AddHeaderLayer; + +impl Layer for AddHeaderLayer { + type Service = AddHeaderService; + + fn layer(&self, service: S) -> Self::Service { + AddHeaderService { inner: service } + } +} + +#[derive(Clone)] +struct AddHeaderService { + inner: S, +} + +impl tower::Service> for AddHeaderService +where + S: tower::Service, Response=Response>>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + trace!("Adding custom header to response"); + let future = self.inner.call(req); + Box::pin(async move { + let mut resp = future.await?; + resp.headers_mut() + .insert("X-Custom-Header", "Hello from middleware!".parse().unwrap()); + Ok(resp) + }) + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 8443)); + // 1. Set up the TCP listener + let listener = TcpListener::bind(addr).await?; + info!("Listening on https://{}", addr); + let incoming = TcpListenerStream::new(listener); + + // 2. Create the HTTP connection builder + let builder = HttpConnectionBuilder::new(TokioExecutor::new()); + + // 3. Set up the Tower service with middleware + let svc = tower::service_fn(hello); + let svc = ServiceBuilder::new() + .layer(AddHeaderLayer) // Custom middleware + .service(svc); + + // 4. Convert the Tower service to a Hyper service + let svc = TowerToHyperService::new(svc); + + // 5. Set up TLS config + let certs = load_certs("examples/sample.pem")?; + let key = load_private_key("examples/sample.rsa")?; + + let mut config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + let tls_config = Arc::new(config); + + // 6. Set up graceful shutdown + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + + // Spawn a task to send the shutdown signal after 1 second + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(1)).await; + let _ = shutdown_tx.send(()); + debug!("Shutdown signal sent"); + }); + + // 7. Start the server + info!("Starting HTTPS server..."); + serve_http_with_shutdown( + svc, + incoming, + builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + info!("Shutdown signal received, starting graceful shutdown"); + }), + ) + .await?; + + info!("Server has shut down"); + // Et voilà! + // A flexible, high-performance server with custom services, middleware, http, tls, tcp, and graceful shutdown + Ok(()) +} ``` -For more advanced usage and examples, including TLS configuration and custom service implementations, please refer to the [examples directory](/examples). +For more advanced usage and examples, please refer to, or contribute to, +the [examples directory](/examples). ## Architecture @@ -52,18 +182,18 @@ hyper-server provides a layered, composable architecture: 5. Middleware: `tower::Service` 6. Application: Your choice of tower-compatible framework (axum, tonic, etc.) or custom service implementation -This structure allows for easy customization and extension at each layer. You -can integrate your own implementations at any level of the stack, providing +This structure allows for easy customization and extension at each layer. You +can integrate your own implementations at any level of the stack, providing maximum flexibility for your specific use case. ## Security -hyper-server takes security seriously. We use `rustls` for TLS support, which -provides modern, secure defaults. However, please ensure that you configure -your server appropriately for your use case, especially when deploying in +hyper-server takes security seriously. We use `rustls` for TLS support, which +provides modern, secure defaults. However, please ensure that you configure +your server appropriately for your use case, especially when deploying in production environments. -If you discover any security-related issues, please email team@warlock.xyz +If you discover any security-related issues, please email team@warlock.xyz instead of using the issue tracker. ## API @@ -76,9 +206,9 @@ hyper-server's MSRV is `1.80`. ## Contributing -We welcome contributions to hyper-server! Our contributing guidelines are -inspired by the Rule of St. Benedict, emphasizing humility, listening, -and community. Before contributing, please familiarize yourself with these +We welcome contributions to hyper-server! Our contributing guidelines are +inspired by the Rule of St. Benedict, emphasizing humility, listening, +and community. Before contributing, please familiarize yourself with these principles at [The Rule of St. Benedict](http://www.benedictfriend.org/the-rule.html). Key points for contributors: @@ -90,12 +220,17 @@ Key points for contributors: ## License -This project is licensed under the MIT License — see +This project is licensed under the MIT License — see the [LICENSE](/LICENSE) file for details. [axum]: https://crates.io/crates/axum + [hyper]: https://crates.io/crates/hyper + [rustls]: https://crates.io/crates/rustls + [tower]: https://crates.io/crates/tower + [tonic]: https://crates.io/crates/tonic + [tungstenite]: https://crates.io/crates/tungstenite \ No newline at end of file diff --git a/examples/full.rs b/examples/full.rs new file mode 100644 index 0000000..d59c614 --- /dev/null +++ b/examples/full.rs @@ -0,0 +1,131 @@ +use std::convert::Infallible; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use http_body_util::Full; +use hyper::body::Incoming; +use hyper::{Request, Response}; +use hyper_util::rt::TokioExecutor; +use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +use hyper_util::service::TowerToHyperService; +use rustls::ServerConfig; +use tokio::net::TcpListener; +use tokio_stream::wrappers::TcpListenerStream; +use tower::{Layer, ServiceBuilder}; +use tracing::{debug, info, trace}; + +use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; + +// Define a simple service that responds with "Hello, World!" +async fn hello(_: Request) -> Result>, Infallible> { + Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +} + +// Define a Custom middleware to add a header to all responses, for example +struct AddHeaderLayer; + +impl Layer for AddHeaderLayer { + type Service = AddHeaderService; + + fn layer(&self, service: S) -> Self::Service { + AddHeaderService { inner: service } + } +} + +#[derive(Clone)] +struct AddHeaderService { + inner: S, +} + +impl tower::Service> for AddHeaderService +where + S: tower::Service, Response = Response>>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + trace!("Adding custom header to response"); + let future = self.inner.call(req); + Box::pin(async move { + let mut resp = future.await?; + resp.headers_mut() + .insert("X-Custom-Header", "Hello from middleware!".parse().unwrap()); + Ok(resp) + }) + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 8443)); + // 1. Set up the TCP listener + let listener = TcpListener::bind(addr).await?; + info!("Listening on https://{}", addr); + let incoming = TcpListenerStream::new(listener); + + // 2. Create the HTTP connection builder + let builder = HttpConnectionBuilder::new(TokioExecutor::new()); + + // 3. Set up the Tower service with middleware + let svc = tower::service_fn(hello); + let svc = ServiceBuilder::new() + .layer(AddHeaderLayer) // Custom middleware + .service(svc); + + // 4. Convert the Tower service to a Hyper service + let svc = TowerToHyperService::new(svc); + + // 5. Set up TLS config + let certs = load_certs("examples/sample.pem")?; + let key = load_private_key("examples/sample.rsa")?; + + let mut config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + let tls_config = Arc::new(config); + + // 6. Set up graceful shutdown + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + + // Spawn a task to send the shutdown signal after 1 second + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(1)).await; + let _ = shutdown_tx.send(()); + debug!("Shutdown signal sent"); + }); + + // 7. Start the server + info!("Starting HTTPS server..."); + serve_http_with_shutdown( + svc, + incoming, + builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + info!("Shutdown signal received, starting graceful shutdown"); + }), + ) + .await?; + + info!("Server has shut down"); + // Et voilà! + // A flexible, high-performance server with custom services, middleware, http, tls, tcp, and graceful shutdown + Ok(()) +} diff --git a/src/http.rs b/src/http.rs index 512b3b4..b1c950c 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,5 +1,7 @@ +use crate::io::Transport; use std::future::pending; use std::{future::Future, pin::pin, sync::Arc, time::Duration}; +use tokio_rustls::TlsAcceptor; use bytes::Bytes; use http::{Request, Response}; @@ -116,10 +118,11 @@ pub async fn serve_http_connection( }); } -/// Serves HTTP requests with graceful shutdown capability. +/// Serves HTTP/HTTPS requests with graceful shutdown capability. /// -/// This function sets up an HTTP server that can handle incoming connections and -/// process requests using the provided service. It also supports graceful shutdown. +/// This function sets up an HTTP/HTTPS server that can handle incoming connections and +/// process requests using the provided service. It supports both plain HTTP and HTTPS +/// connections, as well as graceful shutdown. /// /// # Type Parameters /// @@ -135,19 +138,209 @@ pub async fn serve_http_connection( /// /// * `service`: The service used to process HTTP requests. /// * `incoming`: The stream of incoming connections. +/// * `builder`: The `HttpConnectionBuilder` used to configure the server. +/// * `tls_config`: An optional TLS configuration for HTTPS support. /// * `signal`: An optional future that, when resolved, signals the server to shut down gracefully. /// /// # Returns /// /// A `Result` indicating success or failure of the server operation. +/// +/// # Examples +/// +/// These examples provide some very basic ways to use the server. With that said, +/// the server is very flexible and can be used in a variety of ways. This is +/// because you as the integrator have control over every level of the stack at +/// construction, with all the native builders exposed via generics. +/// +/// Setting up an HTTP server with graceful shutdown: +/// +/// ```rust,no_run +/// use std::convert::Infallible; +/// use bytes::Bytes; +/// use http_body_util::Full; +/// use hyper::body::Incoming; +/// use hyper::{Request, Response}; +/// use hyper_util::rt::TokioExecutor; +/// use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +/// use tokio::net::TcpListener; +/// use tokio_stream::wrappers::TcpListenerStream; +/// use tower::ServiceBuilder; +/// use std::net::SocketAddr; +/// +/// use hyper_server::serve_http_with_shutdown; +/// +/// async fn hello(_: Request) -> Result>, Infallible> { +/// Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +/// } +/// +/// #[tokio::main(flavor = "current_thread")] +/// async fn main() -> Result<(), Box> { +/// let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); +/// let listener = TcpListener::bind(addr).await?; +/// let incoming = TcpListenerStream::new(listener); +/// +/// let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); +/// +/// let builder = HttpConnectionBuilder::new(TokioExecutor::new()); +/// let svc = hyper::service::service_fn(hello); +/// let svc = ServiceBuilder::new().service(svc); +/// +/// tokio::spawn(async move { +/// // Simulate a shutdown signal after 60 seconds +/// tokio::time::sleep(std::time::Duration::from_secs(60)).await; +/// let _ = shutdown_tx.send(()); +/// }); +/// +/// serve_http_with_shutdown( +/// svc, +/// incoming, +/// builder, +/// None, // No TLS config for plain HTTP +/// Some(async { +/// shutdown_rx.await.ok(); +/// }), +/// ).await?; +/// +/// Ok(()) +/// } +/// ``` +/// +/// Setting up an HTTPS server: +/// +/// ```rust,no_run +/// use std::convert::Infallible; +/// use std::sync::Arc; +/// use bytes::Bytes; +/// use http_body_util::Full; +/// use hyper::body::Incoming; +/// use hyper::{Request, Response}; +/// use hyper_util::rt::TokioExecutor; +/// use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +/// use tokio::net::TcpListener; +/// use tokio_stream::wrappers::TcpListenerStream; +/// use tower::ServiceBuilder; +/// use rustls::ServerConfig; +/// use std::io; +/// use std::net::SocketAddr; +/// use std::future::Future; +/// +/// use hyper_server::{serve_http_with_shutdown, load_certs, load_private_key}; +/// +/// async fn hello(_: Request) -> Result>, Infallible> { +/// Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +/// } +/// +/// #[tokio::main(flavor = "current_thread")] +/// async fn main() -> Result<(), Box> { +/// let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); +/// let listener = TcpListener::bind(addr).await?; +/// let incoming = TcpListenerStream::new(listener); +/// +/// let builder = HttpConnectionBuilder::new(TokioExecutor::new()); +/// let svc = hyper::service::service_fn(hello); +/// let svc = ServiceBuilder::new().service(svc); +/// +/// // Set up TLS config +/// let certs = load_certs("examples/sample.pem")?; +/// let key = load_private_key("examples/sample.rsa")?; +/// +/// let config = ServerConfig::builder() +/// .with_no_client_auth() +/// .with_single_cert(certs, key) +/// .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; +/// let tls_config = Arc::new(config); +/// +/// serve_http_with_shutdown( +/// svc, +/// incoming, +/// builder, +/// Some(tls_config), +/// Some(std::future::pending::<()>()), // A never-resolving future as a placeholder +/// ).await?; +/// +/// Ok(()) +/// } +/// ``` +/// +/// Setting up an HTTPS server with a Tower service: +/// +/// ```rust,no_run +/// use std::convert::Infallible; +/// use std::sync::Arc; +/// use bytes::Bytes; +/// use http_body_util::Full; +/// use hyper::body::Incoming; +/// use hyper::{Request, Response}; +/// use hyper_util::rt::TokioExecutor; +/// use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +/// use hyper_util::service::TowerToHyperService; +/// use tokio::net::TcpListener; +/// use tokio_stream::wrappers::TcpListenerStream; +/// use tower::{ServiceBuilder, ServiceExt}; +/// use rustls::ServerConfig; +/// use std::io; +/// use std::net::SocketAddr; +/// use std::future::Future; +/// +/// use hyper_server::{serve_http_with_shutdown, load_certs, load_private_key}; +/// +/// async fn hello(_: Request) -> Result>, Infallible> { +/// Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +/// } +/// +/// #[tokio::main(flavor = "current_thread")] +/// async fn main() -> Result<(), Box> { +/// let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); +/// let listener = TcpListener::bind(addr).await?; +/// let incoming = TcpListenerStream::new(listener); +/// +/// let builder = HttpConnectionBuilder::new(TokioExecutor::new()); +/// +/// // Set up the Tower service +/// let svc = tower::service_fn(hello); +/// let svc = ServiceBuilder::new() +/// .service(svc); +/// +/// // Convert the Tower service to a Hyper service +/// let svc = TowerToHyperService::new(svc); +/// +/// // Set up TLS config +/// let certs = load_certs("examples/sample.pem")?; +/// let key = load_private_key("examples/sample.rsa")?; +/// +/// let config = ServerConfig::builder() +/// .with_no_client_auth() +/// .with_single_cert(certs, key) +/// .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; +/// let tls_config = Arc::new(config); +/// +/// serve_http_with_shutdown( +/// svc, +/// incoming, +/// builder, +/// Some(tls_config), +/// Some(std::future::pending::<()>()), // A never-resolving future as a placeholder +/// ).await?; +/// +/// Ok(()) +/// } +/// ``` +/// +/// # Notes +/// +/// - The server will continue to accept new connections until the `signal` future resolves. +/// - When using TLS, make sure to provide a properly configured `ServerConfig`. +/// - The function will return when all connections have been closed after the shutdown signal. pub async fn serve_http_with_shutdown( service: S, incoming: I, builder: HttpConnectionBuilder, + tls_config: Option>, signal: Option, ) -> Result<(), super::Error> where - F: Future, + F: Future + Send + 'static, I: Stream> + Send + 'static, IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, IE: Into + Send + 'static, @@ -169,6 +362,9 @@ where let mut sig = pin!(Fuse { inner: signal }); let mut incoming = pin!(incoming); + // Create TLS acceptor if TLS config is provided + let tls_acceptor = tls_config.map(TlsAcceptor::from); + // Main server loop loop { tokio::select! { @@ -193,7 +389,19 @@ where trace!("connection accepted"); // Prepare the connection for hyper - let hyper_io = TokioIo::new(io); + let transport = if let Some(tls_acceptor) = &tls_acceptor { + match tls_acceptor.accept(io).await { + Ok(tls_stream) => Transport::new_tls(tls_stream), + Err(e) => { + debug!("TLS handshake failed: {:#}", e); + continue; + } + } + } else { + Transport::new_plain(io) + }; + + let hyper_io = TokioIo::new(transport); let hyper_svc = service.clone(); // Serve the HTTP connection @@ -226,21 +434,23 @@ where #[cfg(test)] mod tests { - use std::net::SocketAddr; - use std::time::Duration; - + use super::*; + use crate::{load_certs, load_private_key}; use bytes::Bytes; use http_body_util::{BodyExt, Empty, Full}; use hyper::{body::Incoming, Request, Response, StatusCode}; use hyper_util::rt::TokioExecutor; use hyper_util::service::TowerToHyperService; + use rustls::ServerConfig; + use std::net::SocketAddr; + use std::sync::Arc; + use std::time::Duration; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; use tokio_stream::wrappers::TcpListenerStream; - use super::*; + // Utility functions - // Echo service async fn echo(req: Request) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { (&hyper::Method::GET, "/") => { @@ -265,11 +475,21 @@ mod tests { (incoming, server_addr) } + async fn create_test_tls_config() -> Arc { + let certs = load_certs("examples/sample.pem").unwrap(); + let key = load_private_key("examples/sample.rsa").unwrap(); + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .unwrap(); + Arc::new(config) + } + async fn send_request( addr: SocketAddr, req: Request>, - ) -> hyper::Result> { - let stream = TcpStream::connect(addr).await.unwrap(); + ) -> Result, Box> { + let stream = TcpStream::connect(addr).await?; let io = TokioIo::new(stream); let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; @@ -279,110 +499,358 @@ mod tests { } }); - sender.send_request(req).await + Ok(sender.send_request(req).await?) } - #[tokio::test] - async fn test_serve_http_with_shutdown_basic() { - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let (incoming, server_addr) = setup_test_server(addr).await; + // HTTP Tests + + mod http_tests { + use super::*; + + #[tokio::test] + async fn test_http_basic_requests() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + None, + Some(async { + shutdown_rx.await.ok(); + }), + )); + + // Test GET request + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); + + // Test POST request + let req = Request::builder() + .method(hyper::Method::POST) + .uri("/echo") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + // Test 404 response + let req = Request::builder() + .uri("/not_found") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + shutdown_tx.send(()).unwrap(); + server.await.unwrap().unwrap(); + } - let (shutdown_tx, shutdown_rx) = oneshot::channel(); + #[tokio::test] + async fn test_http_concurrent_requests() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + None, + Some(async { + shutdown_rx.await.ok(); + }), + )); + + let mut handles = vec![]; + for _ in 0..10 { + let addr = server_addr; + let handle = tokio::spawn(async move { + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + let res = send_request(addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + }); + handles.push(handle); + } - let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + for handle in handles { + handle.await.unwrap(); + } - let tower_service_fn = tower::service_fn(echo); - let hyper_service = TowerToHyperService::new(tower_service_fn); + shutdown_tx.send(()).unwrap(); + server.await.unwrap().unwrap(); + } - let server = tokio::spawn(serve_http_with_shutdown( - hyper_service, - incoming, - http_server_builder, - Some(async { - shutdown_rx.await.ok(); - }), - )); + #[tokio::test] + async fn test_http_graceful_shutdown() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + None, + Some(async { + shutdown_rx.await.ok(); + }), + )); + + // Send a request before shutdown + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req) + .await + .expect("Failed to send initial request"); + assert_eq!(res.status(), StatusCode::OK); + + // Initiate graceful shutdown + shutdown_tx.send(()).unwrap(); + + // Wait for the server to shut down + let shutdown_timeout = Duration::from_millis(150); + let shutdown_result = tokio::time::timeout(shutdown_timeout, async { + loop { + tokio::time::sleep(Duration::from_millis(10)).await; + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + match send_request(server_addr, req).await { + Ok(_) => continue, // Server still accepting connections + Err(e) if e.to_string().contains("Connection refused") => { + // Server has shut down as expected + return Ok(()); + } + Err(e) => return Err(e), // Unexpected error + } + } + }) + .await; - // Test GET request - let req = Request::builder() - .uri("/") - .body(Empty::::new()) - .unwrap(); - let res = send_request(server_addr, req).await.unwrap(); - assert_eq!(res.status(), StatusCode::OK); - let body = res.collect().await.unwrap().to_bytes(); - assert_eq!(&body[..], b"Hello, World!"); - - // Test POST request - let req = Request::builder() - .method(hyper::Method::POST) - .uri("/echo") - .body(Empty::::new()) - .unwrap(); - let res = send_request(server_addr, req).await.unwrap(); - assert_eq!(res.status(), StatusCode::OK); + match shutdown_result { + Ok(Ok(())) => println!("Server shut down successfully"), + Ok(Err(e)) => panic!("Unexpected error during shutdown: {}", e), + Err(_) => panic!("Timeout waiting for server to shut down"), + } - // Test 404 response - let req = Request::builder() - .uri("/not_found") - .body(Empty::::new()) - .unwrap(); - let res = send_request(server_addr, req).await.unwrap(); - assert_eq!(res.status(), StatusCode::NOT_FOUND); - - // Shutdown the server - shutdown_tx.send(()).unwrap(); - tokio::time::timeout(Duration::from_secs(5), server) - .await - .expect("Server didn't shut down within the timeout period") - .unwrap() - .unwrap(); + // Ensure the server task completes + server.await.unwrap().unwrap(); + } } - #[tokio::test] - async fn test_serve_http_with_concurrent_requests() { - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let (incoming, server_addr) = setup_test_server(addr).await; - - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - - let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); - - let tower_service_fn = tower::service_fn(echo); - let hyper_service = TowerToHyperService::new(tower_service_fn); - - let server = tokio::spawn(serve_http_with_shutdown( - hyper_service, - incoming, - http_server_builder, - Some(async { - shutdown_rx.await.ok(); - }), - )); - - let mut handles = vec![]; - for _ in 0..10 { - let handle = tokio::spawn(async move { - let req = Request::builder() - .uri("/") - .body(Empty::::new()) + // HTTPS Tests + + mod https_tests { + use super::*; + + async fn create_https_client() -> ( + tokio_rustls::TlsConnector, + rustls::pki_types::ServerName<'static>, + ) { + let mut root_cert_store = rustls::RootCertStore::empty(); + root_cert_store.add_parsable_certificates(load_certs("examples/sample.pem").unwrap()); + + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + let domain = rustls::pki_types::ServerName::try_from("localhost") + .expect("Failed to create ServerName"); + + (tls_connector, domain) + } + + #[tokio::test] + async fn test_https_connection() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let tls_config = create_test_tls_config().await; + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + }), + )); + + let (tls_connector, domain) = create_https_client().await; + + let tcp_stream = TcpStream::connect(server_addr).await.unwrap(); + let tls_stream = tls_connector.connect(domain, tcp_stream).await.unwrap(); + + let (mut sender, conn) = + hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) + .await .unwrap(); - let res = send_request(server_addr, req).await.unwrap(); - assert_eq!(res.status(), StatusCode::OK); + + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } }); - handles.push(handle); + + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); + + shutdown_tx.send(()).unwrap(); + server.await.unwrap().unwrap(); } - for handle in handles { - handle.await.unwrap(); + #[tokio::test] + async fn test_https_invalid_client_cert() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let tls_config = create_test_tls_config().await; + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + }), + )); + + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + + let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + + let tcp_stream = TcpStream::connect(server_addr).await.unwrap(); + let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap(); + + let result = tls_connector.connect(domain, tcp_stream).await; + assert!( + result.is_err(), + "Expected TLS connection to fail due to invalid client certificate" + ); + + shutdown_tx.send(()).unwrap(); + server.await.unwrap().unwrap(); } + #[tokio::test] + async fn test_https_graceful_shutdown() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let tls_config = create_test_tls_config().await; + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + }), + )); + + let (tls_connector, domain) = create_https_client().await; + + // Establish a connection + let tcp_stream = TcpStream::connect(server_addr).await.unwrap(); + let tls_stream = tls_connector.connect(domain, tcp_stream).await.unwrap(); + + let (mut sender, conn) = + hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) + .await + .unwrap(); - // Shutdown the server - shutdown_tx.send(()).unwrap(); - tokio::time::timeout(Duration::from_secs(5), server) - .await - .expect("Server didn't shut down within the timeout period") - .unwrap() - .unwrap(); + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } + }); + + // Send a request + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + // Initiate graceful shutdown + shutdown_tx.send(()).unwrap(); + + // Wait a bit to allow the server to start shutting down + tokio::time::sleep(Duration::from_millis(10)).await; + + // Try to send another request, it should fail + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + + let result = sender.send_request(req).await; + assert!( + result.is_err(), + "Expected request to fail after graceful shutdown" + ); + + server.await.unwrap().unwrap(); + } } } diff --git a/src/io.rs b/src/io.rs new file mode 100644 index 0000000..48ece65 --- /dev/null +++ b/src/io.rs @@ -0,0 +1,84 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::server::TlsStream; + +pub(crate) enum Transport { + Plain(IO), + Tls(Box>), +} + +impl Transport { + pub(crate) fn new_plain(io: IO) -> Self { + Self::Plain(io) + } + + pub(crate) fn new_tls(io: TlsStream) -> Self { + Self::Tls(Box::from(io)) + } +} + +impl AsyncRead for Transport +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_read(cx, buf), + Transport::Tls(io) => Pin::new(io).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for Transport +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_write(cx, buf), + Transport::Tls(io) => Pin::new(io).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_flush(cx), + Transport::Tls(io) => Pin::new(io).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_shutdown(cx), + Transport::Tls(io) => Pin::new(io).poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_write_vectored(cx, bufs), + Transport::Tls(io) => Pin::new(io).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Transport::Plain(io) => io.is_write_vectored(), + Transport::Tls(io) => io.is_write_vectored(), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 0302b0d..8fe53d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub use tls::serve_tls_incoming; mod error; mod fuse; mod http; +mod io; mod tcp; mod tls; From 897e8ce82cdea1090067ba41693215c5650c486f Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 00:49:56 -0400 Subject: [PATCH 27/45] fix: address feedback --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 84406f0..9e80a97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,9 @@ [package] -authors = ["0xAlcibiades "] +authors = ["0xAlcibiades "] categories = ["asynchronous", "network-programming", "web-programming"] description = "High level server for hyper and tower." edition = "2021" -homepage = "https://github.com/warlock-labls/hyper-server" +homepage = "https://github.com/warlock-labs/hyper-server" keywords = ["tcp", "tls", "http", "hyper", "tokio"] license = "MIT" name = "hyper-server" From 9b2847eb30fa6cfbc0620f7459394eabc4869a8f Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 01:17:56 -0400 Subject: [PATCH 28/45] fix: line break --- README.md | 3 ++- examples/full.rs | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index fd46207..44b7e14 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,8 @@ async fn main() -> Result<(), Box> { info!("Server has shut down"); // Et voilà! - // A flexible, high-performance server with custom services, middleware, http, tls, tcp, and graceful shutdown + // A flexible, high-performance server with custom services, middleware, + // http, tls, tcp, and graceful shutdown Ok(()) } ``` diff --git a/examples/full.rs b/examples/full.rs index d59c614..e2a9158 100644 --- a/examples/full.rs +++ b/examples/full.rs @@ -126,6 +126,7 @@ async fn main() -> Result<(), Box> { info!("Server has shut down"); // Et voilà! - // A flexible, high-performance server with custom services, middleware, http, tls, tcp, and graceful shutdown + // A flexible, high-performance server with custom services, + // middleware, http, tls, tcp, and graceful shutdown Ok(()) } From 2b04c432f16c401429eb90be061f891467dd2089 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 02:30:08 -0400 Subject: [PATCH 29/45] chore: basic benchmark --- Cargo.toml | 10 +- benches/hello_world_tower_hyper_tls_tcp.rs | 179 +++++++++++++++++++++ 2 files changed, 188 insertions(+), 1 deletion(-) create mode 100644 benches/hello_world_tower_hyper_tls_tcp.rs diff --git a/Cargo.toml b/Cargo.toml index 9e80a97..71d84ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,13 +23,21 @@ hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", pin-project = "1.1.5" rustls = "0.23.13" rustls-pemfile = "2.1.3" -tokio = { version = "1.40.0", features = ["net", "macros"] } +tokio = { version = "1.40.0", features = ["net", "macros", "rt-multi-thread"] } tokio-rustls = "0.26.0" tokio-stream = { version = "0.1.16", features = ["net"] } tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" +tokio-util = "0.7.12" +hyper-rustls = "0.27.3" [dev-dependencies] +criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } hyper = { version = "1.4.1", features = ["client"] } tokio = { version = "1.0", features = ["rt", "net", "test-util"] } +tokio-util = { version = "0.7", features = ["compat"] } tracing-subscriber = "0.3.18" + +[[bench]] +name = "hello_world_tower_hyper_tls_tcp" +harness = false diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs new file mode 100644 index 0000000..d9ad627 --- /dev/null +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -0,0 +1,179 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use http::{Request, Response, StatusCode}; +use http_body_util::{BodyExt, Empty, Full}; +use hyper::body::Incoming; +use hyper_util::rt::TokioExecutor; +use hyper_util::rt::TokioIo; +use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +use hyper_util::service::TowerToHyperService; +use rustls::pki_types::ServerName; +use rustls::{RootCertStore, ServerConfig}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::runtime::Runtime; +use tokio::sync::oneshot; +use tokio_rustls::TlsConnector; +use tokio_stream::wrappers::TcpListenerStream; + +use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; + +async fn echo(req: Request) -> Result>, hyper::Error> { + match (req.method(), req.uri().path()) { + (&hyper::Method::GET, "/") => Ok(Response::new(Full::new(Bytes::from("Hello, World!")))), + (&hyper::Method::POST, "/echo") => { + let body = req.collect().await?.to_bytes(); + Ok(Response::new(Full::new(body))) + } + _ => { + let mut res = Response::new(Full::new(Bytes::from("Not Found"))); + *res.status_mut() = StatusCode::NOT_FOUND; + Ok(res) + } + } +} + +async fn setup_server( +) -> Result<(TcpListenerStream, SocketAddr, Arc), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let server_addr = listener.local_addr()?; + let incoming = TcpListenerStream::new(listener); + + let certs = load_certs("examples/sample.pem")?; + let key = load_private_key("examples/sample.rsa")?; + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key)?; + let tls_config = Arc::new(config); + + Ok((incoming, server_addr, tls_config)) +} + +async fn start_server() -> Result<(SocketAddr, oneshot::Sender<()>), Box> { + let (incoming, server_addr, tls_config) = setup_server().await?; + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + tokio::spawn(async move { + serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + }), + ) + .await + .unwrap(); + }); + + Ok((server_addr, shutdown_tx)) +} + +fn create_https_client() -> (TlsConnector, ServerName<'static>) { + let mut root_cert_store = RootCertStore::empty(); + let certs = load_certs("examples/sample.pem").unwrap(); + root_cert_store.add_parsable_certificates(certs); + + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let tls_connector = TlsConnector::from(Arc::new(client_config)); + let domain = ServerName::try_from("localhost").expect("Failed to create ServerName"); + + (tls_connector, domain) +} + +async fn send_request( + tls_connector: &TlsConnector, + domain: &ServerName<'static>, + addr: SocketAddr, +) -> Result<(), Box> { + let tcp_stream = TcpStream::connect(addr).await?; + let tls_stream = tls_connector.connect(domain.clone(), tcp_stream).await?; + let (mut sender, conn) = + hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)).await?; + + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } + }); + + let req = Request::builder().uri("/").body(Empty::::new())?; + + let res = sender.send_request(req).await?; + assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await?.to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); + Ok(()) +} + +fn bench_server(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let (server_addr, shutdown_tx) = rt.block_on(start_server()).unwrap(); + + let (tls_connector, domain) = create_https_client(); + + let mut group = c.benchmark_group("hyper_server"); + + group.bench_function("new_connection_latency", |b| { + b.to_async(&rt).iter(|| async { + send_request(&tls_connector, &domain, server_addr) + .await + .unwrap() + }); + }); + + let concurrent_requests = vec![10, 50, 100]; + for &num_requests in &concurrent_requests { + group.bench_with_input( + BenchmarkId::new("concurrent_connections", num_requests), + &num_requests, + |b, &num_requests| { + b.to_async(&rt).iter(|| async { + let requests = (0..num_requests) + .map(|_| send_request(&tls_connector, &domain, server_addr)); + futures::future::join_all(requests).await + }); + }, + ); + } + + group.bench_function("throughput", |b| { + b.to_async(&rt).iter(|| async { + let start = std::time::Instant::now(); + let mut count = 0; + + while start.elapsed() < Duration::from_secs(1) { + send_request(&tls_connector, &domain, server_addr) + .await + .unwrap(); + count += 1; + } + + count + }); + }); + + group.finish(); + + // Gracefully shutdown the server + rt.block_on(async { + shutdown_tx.send(()).unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; // Give the server time to shut down + }); +} + +criterion_group!(benches, bench_server); +criterion_main!(benches); From 20c67092eea1ec5d61447ba527474656b89e422c Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 03:08:25 -0400 Subject: [PATCH 30/45] fix: benchmarks that run --- Cargo.toml | 1 + benches/hello_world_tower_hyper_tls_tcp.rs | 167 ++++++++------------- 2 files changed, 66 insertions(+), 102 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 71d84ce..51fbe33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" tokio-util = "0.7.12" hyper-rustls = "0.27.3" +rand = "0.9.0-alpha.2" [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index d9ad627..87bbb29 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -1,25 +1,26 @@ +use rustls::ClientConfig; +use rustls::RootCertStore; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; - use bytes::Bytes; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use http::{Request, Response, StatusCode}; -use http_body_util::{BodyExt, Empty, Full}; +use criterion::{criterion_group, criterion_main, Criterion, BenchmarkId}; +use futures::future::join_all; +use http::{Request, Response, StatusCode, Uri}; +use http_body_util::{Empty, Full, BodyExt}; use hyper::body::Incoming; use hyper_util::rt::TokioExecutor; -use hyper_util::rt::TokioIo; use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; -use hyper_util::service::TowerToHyperService; -use rustls::pki_types::ServerName; -use rustls::{RootCertStore, ServerConfig}; -use tokio::net::{TcpListener, TcpStream}; +use rustls::ServerConfig; +use tokio::net::TcpListener; use tokio::runtime::Runtime; -use tokio::sync::oneshot; -use tokio_rustls::TlsConnector; +use tokio::sync::{oneshot, Semaphore}; use tokio_stream::wrappers::TcpListenerStream; - +use hyper_util::service::TowerToHyperService; +use tracing::info; use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; +use hyper_rustls::HttpsConnectorBuilder; +use hyper_util::client::legacy::Client; async fn echo(req: Request) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { @@ -27,7 +28,7 @@ async fn echo(req: Request) -> Result>, hyper::Er (&hyper::Method::POST, "/echo") => { let body = req.collect().await?.to_bytes(); Ok(Response::new(Full::new(body))) - } + }, _ => { let mut res = Response::new(Full::new(Bytes::from("Not Found"))); *res.status_mut() = StatusCode::NOT_FOUND; @@ -36,8 +37,7 @@ async fn echo(req: Request) -> Result>, hyper::Er } } -async fn setup_server( -) -> Result<(TcpListenerStream, SocketAddr, Arc), Box> { +async fn setup_server() -> Result<(TcpListenerStream, SocketAddr, Arc), Box> { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let listener = TcpListener::bind(addr).await?; let server_addr = listener.local_addr()?; @@ -45,135 +45,98 @@ async fn setup_server( let certs = load_certs("examples/sample.pem")?; let key = load_private_key("examples/sample.rsa")?; - let config = ServerConfig::builder() + + let mut config = ServerConfig::builder() .with_no_client_auth() - .with_single_cert(certs, key)?; + .with_single_cert(certs, key) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; let tls_config = Arc::new(config); Ok((incoming, server_addr, tls_config)) } -async fn start_server() -> Result<(SocketAddr, oneshot::Sender<()>), Box> { + +async fn start_server() -> Result<(SocketAddr, oneshot::Sender<()>), Box> { let (incoming, server_addr, tls_config) = setup_server().await?; let (shutdown_tx, shutdown_rx) = oneshot::channel(); - let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); let tower_service_fn = tower::service_fn(echo); let hyper_service = TowerToHyperService::new(tower_service_fn); - tokio::spawn(async move { serve_http_with_shutdown( hyper_service, incoming, http_server_builder, Some(tls_config), - Some(async { - shutdown_rx.await.ok(); - }), + Some(async { shutdown_rx.await.ok(); }), ) - .await - .unwrap(); + .await + .unwrap(); }); - Ok((server_addr, shutdown_tx)) } -fn create_https_client() -> (TlsConnector, ServerName<'static>) { - let mut root_cert_store = RootCertStore::empty(); - let certs = load_certs("examples/sample.pem").unwrap(); - root_cert_store.add_parsable_certificates(certs); - - let client_config = rustls::ClientConfig::builder() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); - - let tls_connector = TlsConnector::from(Arc::new(client_config)); - let domain = ServerName::try_from("localhost").expect("Failed to create ServerName"); - - (tls_connector, domain) -} - -async fn send_request( - tls_connector: &TlsConnector, - domain: &ServerName<'static>, - addr: SocketAddr, -) -> Result<(), Box> { - let tcp_stream = TcpStream::connect(addr).await?; - let tls_stream = tls_connector.connect(domain.clone(), tcp_stream).await?; - let (mut sender, conn) = - hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)).await?; - - tokio::spawn(async move { - if let Err(err) = conn.await { - eprintln!("Connection failed: {:?}", err); - } - }); - - let req = Request::builder().uri("/").body(Empty::::new())?; - - let res = sender.send_request(req).await?; +async fn send_request(client: &Client, Empty>, url: Uri) -> Result<(), Box> { + let res = client.get(url).await?; assert_eq!(res.status(), StatusCode::OK); - let body = res.collect().await?.to_bytes(); + let body = res.into_body().collect().await?.to_bytes(); assert_eq!(&body[..], b"Hello, World!"); Ok(()) } fn bench_server(c: &mut Criterion) { let rt = Runtime::new().unwrap(); + let (server_addr, shutdown_tx, client) = rt.block_on(async { + let (server_addr, shutdown_tx) = start_server().await.expect("Failed to start server"); + info!("Server started on {}", server_addr); - let (server_addr, shutdown_tx) = rt.block_on(start_server()).unwrap(); + let mut root_cert_store = RootCertStore::empty(); + root_cert_store.add_parsable_certificates(load_certs("examples/sample.pem").unwrap()); - let (tls_connector, domain) = create_https_client(); + let client_config = ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); - let mut group = c.benchmark_group("hyper_server"); + let https = HttpsConnectorBuilder::new() + .with_tls_config(client_config) + .https_or_http() + .enable_http1() + .build(); - group.bench_function("new_connection_latency", |b| { - b.to_async(&rt).iter(|| async { - send_request(&tls_connector, &domain, server_addr) - .await - .unwrap() - }); + let client: Client<_, Empty> = Client::builder(TokioExecutor::new()).build(https); + + (server_addr, shutdown_tx, client) }); - let concurrent_requests = vec![10, 50, 100]; - for &num_requests in &concurrent_requests { - group.bench_with_input( - BenchmarkId::new("concurrent_connections", num_requests), - &num_requests, - |b, &num_requests| { - b.to_async(&rt).iter(|| async { - let requests = (0..num_requests) - .map(|_| send_request(&tls_connector, &domain, server_addr)); - futures::future::join_all(requests).await - }); - }, - ); - } + let url = Uri::builder() + .scheme("https") + .authority(format!("localhost:{}", server_addr.port())) + .path_and_query("/") + .build() + .expect("Failed to build URI"); - group.bench_function("throughput", |b| { - b.to_async(&rt).iter(|| async { - let start = std::time::Instant::now(); - let mut count = 0; + let mut group = c.benchmark_group("hyper_server"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(20)); - while start.elapsed() < Duration::from_secs(1) { - send_request(&tls_connector, &domain, server_addr) - .await - .unwrap(); - count += 1; - } - - count - }); - }); + // ... [keep the rest of the benchmark code as it is] ... group.finish(); - // Gracefully shutdown the server rt.block_on(async { shutdown_tx.send(()).unwrap(); - tokio::time::sleep(Duration::from_secs(1)).await; // Give the server time to shut down + tokio::time::sleep(Duration::from_secs(1)).await; }); } -criterion_group!(benches, bench_server); -criterion_main!(benches); +criterion_group! { + name = benches; + config = Criterion::default() + .sample_size(10) + .measurement_time(Duration::from_secs(20)) + .warm_up_time(Duration::from_secs(5)); + targets = bench_server +} + +criterion_main!(benches); \ No newline at end of file From 04d88eb5bad41f80545d29601c58b90f6b0d4bb7 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 03:26:14 -0400 Subject: [PATCH 31/45] fix: fast benchmarks --- benches/hello_world_tower_hyper_tls_tcp.rs | 93 +++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index 87bbb29..015da77 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -120,7 +120,98 @@ fn bench_server(c: &mut Criterion) { group.sample_size(10); group.measurement_time(Duration::from_secs(20)); - // ... [keep the rest of the benchmark code as it is] ... + // Single request latency + group.bench_function("single_request_latency", |b| { + let client = client.clone(); + let url = url.clone(); + b.to_async(&rt).iter(|| async { + send_request(&client, url.clone()).await.unwrap() + }); + }); + + // Throughput test + group.bench_function("throughput", |b| { + let client = client.clone(); + let url = url.clone(); + b.to_async(&rt).iter_custom(|iters| { + let client = client.clone(); + let url = url.clone(); + async move { + let start = std::time::Instant::now(); + for _ in 0..iters { + send_request(&client, url.clone()).await.unwrap(); + } + start.elapsed() + } + }); + }); + + // Concurrent connections test + let concurrent_requests = vec![10, 50, 100, 200]; + for &num_requests in &concurrent_requests { + group.bench_with_input( + BenchmarkId::new("concurrent_requests", num_requests), + &num_requests, + |b, &num_requests| { + let client = client.clone(); + let url = url.clone(); + let semaphore = Arc::new(Semaphore::new(num_requests)); + b.to_async(&rt).iter(|| async { + let requests = (0..num_requests).map(|_| { + let client = client.clone(); + let url = url.clone(); + let semaphore = semaphore.clone(); + async move { + let _permit = semaphore.acquire().await.unwrap(); + send_request(&client, url).await + } + }); + join_all(requests).await.into_iter().collect::, _>>().unwrap() + }); + }, + ); + } + + let post_url = Uri::builder() + .scheme("https") + .authority(format!("localhost:{}", server_addr.port())) + .path_and_query("/echo") + .build() + .expect("Failed to build POST URI"); + + group.bench_function("post_request_with_payload", |b| { + let client = client.clone(); + let post_url = post_url.clone(); + b.to_async(&rt).iter(|| async { + let req = Request::builder() + .method("POST") + .uri(post_url.clone()) + .body(Empty::::new()) + .unwrap(); + let res = client.request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.into_body().collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b""); // The echo endpoint will return an empty body for an empty request + }); + }); + + // Long-running connection test + group.bench_function("long_running_connection", |b| { + let client = client.clone(); + let url = url.clone(); + b.to_async(&rt).iter_custom(|iters| { + let client = client.clone(); + let url = url.clone(); + async move { + let start = std::time::Instant::now(); + for _ in 0..iters { + send_request(&client, url.clone()).await.unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + } + start.elapsed() + } + }); + }); group.finish(); From d4c245842874ad18214ab67490d6629f8e09be57 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 05:36:09 -0400 Subject: [PATCH 32/45] fix: upgrade connection when possible --- src/http.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/http.rs b/src/http.rs index b1c950c..56dc96d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -80,7 +80,7 @@ pub async fn serve_http_connection( }); // Create and pin the HTTP connection - let mut conn = pin!(builder.serve_connection(hyper_io, hyper_service)); + let mut conn = pin!(builder.serve_connection_with_upgrades(hyper_io, hyper_service)); // Set up the sleep future for max connection age let sleep = sleep_or_pending(max_connection_age); From 3e2757f0cd4d1716ee02532b5e34fd8f386db598 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 05:37:02 -0400 Subject: [PATCH 33/45] chore: cleanup --- Cargo.toml | 6 +- benches/hello_world_tower_hyper_tls_tcp.rs | 64 +++++++++++++--------- 2 files changed, 42 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 51fbe33..e2c7147 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,18 +19,18 @@ http = "1.1.0" http-body = "1.0.1" http-body-util = "0.1.2" hyper = "1.4.1" +hyper-rustls = "0.27.3" hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } pin-project = "1.1.5" +rand = "0.9.0-alpha.2" rustls = "0.23.13" rustls-pemfile = "2.1.3" tokio = { version = "1.40.0", features = ["net", "macros", "rt-multi-thread"] } tokio-rustls = "0.26.0" tokio-stream = { version = "0.1.16", features = ["net"] } +tokio-util = "0.7.12" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" -tokio-util = "0.7.12" -hyper-rustls = "0.27.3" -rand = "0.9.0-alpha.2" [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index 015da77..2f2c4e4 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -1,26 +1,26 @@ -use rustls::ClientConfig; -use rustls::RootCertStore; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; use bytes::Bytes; -use criterion::{criterion_group, criterion_main, Criterion, BenchmarkId}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use futures::future::join_all; use http::{Request, Response, StatusCode, Uri}; -use http_body_util::{Empty, Full, BodyExt}; +use http_body_util::{BodyExt, Empty, Full}; use hyper::body::Incoming; +use hyper_rustls::HttpsConnectorBuilder; +use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; +use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +use hyper_util::service::TowerToHyperService; +use rustls::ClientConfig; +use rustls::RootCertStore; use rustls::ServerConfig; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; use tokio::net::TcpListener; use tokio::runtime::Runtime; use tokio::sync::{oneshot, Semaphore}; use tokio_stream::wrappers::TcpListenerStream; -use hyper_util::service::TowerToHyperService; use tracing::info; -use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; -use hyper_rustls::HttpsConnectorBuilder; -use hyper_util::client::legacy::Client; async fn echo(req: Request) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { @@ -28,7 +28,7 @@ async fn echo(req: Request) -> Result>, hyper::Er (&hyper::Method::POST, "/echo") => { let body = req.collect().await?.to_bytes(); Ok(Response::new(Full::new(body))) - }, + } _ => { let mut res = Response::new(Full::new(Bytes::from("Not Found"))); *res.status_mut() = StatusCode::NOT_FOUND; @@ -37,7 +37,10 @@ async fn echo(req: Request) -> Result>, hyper::Er } } -async fn setup_server() -> Result<(TcpListenerStream, SocketAddr, Arc), Box> { +async fn setup_server() -> Result< + (TcpListenerStream, SocketAddr, Arc), + Box, +> { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let listener = TcpListener::bind(addr).await?; let server_addr = listener.local_addr()?; @@ -56,8 +59,8 @@ async fn setup_server() -> Result<(TcpListenerStream, SocketAddr, Arc Result<(SocketAddr, oneshot::Sender<()>), Box> { +async fn start_server( +) -> Result<(SocketAddr, oneshot::Sender<()>), Box> { let (incoming, server_addr, tls_config) = setup_server().await?; let (shutdown_tx, shutdown_rx) = oneshot::channel(); let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); @@ -69,15 +72,23 @@ async fn start_server() -> Result<(SocketAddr, oneshot::Sender<()>), Box, Empty>, url: Uri) -> Result<(), Box> { +async fn send_request( + client: &Client< + hyper_rustls::HttpsConnector, + Empty, + >, + url: Uri, +) -> Result<(), Box> { let res = client.get(url).await?; assert_eq!(res.status(), StatusCode::OK); let body = res.into_body().collect().await?.to_bytes(); @@ -124,9 +135,8 @@ fn bench_server(c: &mut Criterion) { group.bench_function("single_request_latency", |b| { let client = client.clone(); let url = url.clone(); - b.to_async(&rt).iter(|| async { - send_request(&client, url.clone()).await.unwrap() - }); + b.to_async(&rt) + .iter(|| async { send_request(&client, url.clone()).await.unwrap() }); }); // Throughput test @@ -166,7 +176,11 @@ fn bench_server(c: &mut Criterion) { send_request(&client, url).await } }); - join_all(requests).await.into_iter().collect::, _>>().unwrap() + join_all(requests) + .await + .into_iter() + .collect::, _>>() + .unwrap() }); }, ); @@ -191,7 +205,7 @@ fn bench_server(c: &mut Criterion) { let res = client.request(req).await.unwrap(); assert_eq!(res.status(), StatusCode::OK); let body = res.into_body().collect().await.unwrap().to_bytes(); - assert_eq!(&body[..], b""); // The echo endpoint will return an empty body for an empty request + assert_eq!(&body[..], b""); // The echo endpoint will return an empty body for an empty request }); }); @@ -230,4 +244,4 @@ criterion_group! { targets = bench_server } -criterion_main!(benches); \ No newline at end of file +criterion_main!(benches); From ea358ba044ae84cd2327a207139d518f997e6337 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 05:38:53 -0400 Subject: [PATCH 34/45] chore: enhance benchmarks --- benches/hello_world_tower_hyper_tls_tcp.rs | 161 ++++++++++----------- 1 file changed, 74 insertions(+), 87 deletions(-) diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index 2f2c4e4..2acdefa 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -1,6 +1,7 @@ use bytes::Bytes; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use futures::future::join_all; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use futures::stream::FuturesUnordered; +use futures::StreamExt; use http::{Request, Response, StatusCode, Uri}; use http_body_util::{BodyExt, Empty, Full}; use hyper::body::Incoming; @@ -10,15 +11,14 @@ use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; use hyper_util::service::TowerToHyperService; -use rustls::ClientConfig; -use rustls::RootCertStore; -use rustls::ServerConfig; +use rustls::server::ServerSessionMemoryCache; +use rustls::{ClientConfig, RootCertStore, ServerConfig}; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; -use tokio::net::TcpListener; +use tokio::net::TcpSocket; use tokio::runtime::Runtime; -use tokio::sync::{oneshot, Semaphore}; +use tokio::sync::oneshot; +use tokio::time::{Duration, Instant}; use tokio_stream::wrappers::TcpListenerStream; use tracing::info; @@ -41,19 +41,36 @@ async fn setup_server() -> Result< (TcpListenerStream, SocketAddr, Arc), Box, > { - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let listener = TcpListener::bind(addr).await?; + // Socket configuration + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); // Listen on all interfaces + let socket = TcpSocket::new_v4()?; + socket.set_send_buffer_size(262_144)?; // 256 KB + socket.set_recv_buffer_size(262_144)?; // 256 KB + socket.set_nodelay(true)?; // Disable Nagle's algorithm + socket.bind(addr)?; + let listener = socket.listen(8192)?; // Increase backlog for high-traffic scenarios let server_addr = listener.local_addr()?; let incoming = TcpListenerStream::new(listener); + // Load certificates and private key let certs = load_certs("examples/sample.pem")?; let key = load_private_key("examples/sample.rsa")?; + // TLS configuration let mut config = ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + + // ALPN configuration + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + // Performance optimizations + config.max_fragment_size = Some(16384); // Larger fragment size for powerful servers + config.send_half_rtt_data = true; // Enable 0.5-RTT data + config.session_storage = ServerSessionMemoryCache::new(10240); // Larger session cache + config.max_early_data_size = 16384; // Enable 0-RTT data + let tls_config = Arc::new(config); Ok((incoming, server_addr, tls_config)) @@ -64,6 +81,7 @@ async fn start_server( let (incoming, server_addr, tls_config) = setup_server().await?; let (shutdown_tx, shutdown_rx) = oneshot::channel(); let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); let hyper_service = TowerToHyperService::new(tower_service_fn); tokio::spawn(async move { @@ -88,12 +106,43 @@ async fn send_request( Empty, >, url: Uri, -) -> Result<(), Box> { +) -> Result<(Duration, usize), Box> { + let start = Instant::now(); let res = client.get(url).await?; assert_eq!(res.status(), StatusCode::OK); let body = res.into_body().collect().await?.to_bytes(); assert_eq!(&body[..], b"Hello, World!"); - Ok(()) + Ok((start.elapsed(), body.len())) +} + +async fn concurrent_benchmark( + client: &Client< + hyper_rustls::HttpsConnector, + Empty, + >, + url: Uri, + num_requests: usize, +) -> (Duration, Vec, usize) { + let start = Instant::now(); + let mut futures = FuturesUnordered::new(); + + for _ in 0..num_requests { + let client = client.clone(); + let url = url.clone(); + futures.push(async move { send_request(&client, url).await }); + } + + let mut request_times = Vec::with_capacity(num_requests); + let mut total_bytes = 0; + while let Some(result) = futures.next().await { + if let Ok((duration, bytes)) = result { + request_times.push(duration); + total_bytes += bytes; + } + } + + let total_time = start.elapsed(); + (total_time, request_times, total_bytes) } fn bench_server(c: &mut Criterion) { @@ -128,105 +177,43 @@ fn bench_server(c: &mut Criterion) { .expect("Failed to build URI"); let mut group = c.benchmark_group("hyper_server"); - group.sample_size(10); - group.measurement_time(Duration::from_secs(20)); + group.sample_size(20); + group.measurement_time(Duration::from_secs(30)); - // Single request latency - group.bench_function("single_request_latency", |b| { + // Latency test + group.bench_function("latency", |b| { let client = client.clone(); let url = url.clone(); b.to_async(&rt) - .iter(|| async { send_request(&client, url.clone()).await.unwrap() }); + .iter(|| async { send_request(&client, url.clone()).await.unwrap().0 }); }); // Throughput test + group.throughput(Throughput::Elements(1)); group.bench_function("throughput", |b| { let client = client.clone(); let url = url.clone(); - b.to_async(&rt).iter_custom(|iters| { - let client = client.clone(); - let url = url.clone(); - async move { - let start = std::time::Instant::now(); - for _ in 0..iters { - send_request(&client, url.clone()).await.unwrap(); - } - start.elapsed() - } - }); + b.to_async(&rt) + .iter(|| async { send_request(&client, url.clone()).await.unwrap() }); }); - // Concurrent connections test - let concurrent_requests = vec![10, 50, 100, 200]; + // Concurrency stress test + let concurrent_requests = vec![1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987]; // Fibonacci sequence for &num_requests in &concurrent_requests { + group.throughput(Throughput::Elements(num_requests as u64)); group.bench_with_input( BenchmarkId::new("concurrent_requests", num_requests), &num_requests, |b, &num_requests| { let client = client.clone(); let url = url.clone(); - let semaphore = Arc::new(Semaphore::new(num_requests)); b.to_async(&rt).iter(|| async { - let requests = (0..num_requests).map(|_| { - let client = client.clone(); - let url = url.clone(); - let semaphore = semaphore.clone(); - async move { - let _permit = semaphore.acquire().await.unwrap(); - send_request(&client, url).await - } - }); - join_all(requests) - .await - .into_iter() - .collect::, _>>() - .unwrap() + concurrent_benchmark(&client, url.clone(), num_requests).await }); }, ); } - let post_url = Uri::builder() - .scheme("https") - .authority(format!("localhost:{}", server_addr.port())) - .path_and_query("/echo") - .build() - .expect("Failed to build POST URI"); - - group.bench_function("post_request_with_payload", |b| { - let client = client.clone(); - let post_url = post_url.clone(); - b.to_async(&rt).iter(|| async { - let req = Request::builder() - .method("POST") - .uri(post_url.clone()) - .body(Empty::::new()) - .unwrap(); - let res = client.request(req).await.unwrap(); - assert_eq!(res.status(), StatusCode::OK); - let body = res.into_body().collect().await.unwrap().to_bytes(); - assert_eq!(&body[..], b""); // The echo endpoint will return an empty body for an empty request - }); - }); - - // Long-running connection test - group.bench_function("long_running_connection", |b| { - let client = client.clone(); - let url = url.clone(); - b.to_async(&rt).iter_custom(|iters| { - let client = client.clone(); - let url = url.clone(); - async move { - let start = std::time::Instant::now(); - for _ in 0..iters { - send_request(&client, url.clone()).await.unwrap(); - tokio::time::sleep(Duration::from_millis(100)).await; - } - start.elapsed() - } - }); - }); - group.finish(); rt.block_on(async { From 091f0d810294360a41f221312cbec5ae998c5e74 Mon Sep 17 00:00:00 2001 From: Alcibiades <89996683+0xAlcibiades@users.noreply.github.com> Date: Wed, 11 Sep 2024 12:43:46 -0400 Subject: [PATCH 35/45] Basic optimizations (#18) * optimize: more * fix: bump package versions * fix: apples to apples * fix: fmt --- Cargo.toml | 8 ++++-- benches/hello_world_tower_hyper_tls_tcp.rs | 11 ++++++-- src/http.rs | 33 ++++++++++++++++++++++ src/io.rs | 6 ++++ src/lib.rs | 7 +++++ src/tcp.rs | 1 + src/tls.rs | 3 ++ 7 files changed, 65 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e2c7147..1690145 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ hyper-rustls = "0.27.3" hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } pin-project = "1.1.5" rand = "0.9.0-alpha.2" -rustls = "0.23.13" +rustls = { version = "0.23.13", features = ["zlib"] } rustls-pemfile = "2.1.3" tokio = { version = "1.40.0", features = ["net", "macros", "rt-multi-thread"] } tokio-rustls = "0.26.0" @@ -31,14 +31,18 @@ tokio-stream = { version = "0.1.16", features = ["net"] } tokio-util = "0.7.12" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" +jemallocator = { version = "0.5", optional = true } [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } hyper = { version = "1.4.1", features = ["client"] } -tokio = { version = "1.0", features = ["rt", "net", "test-util"] } +tokio = { version = "1.40", features = ["rt", "net", "test-util"] } tokio-util = { version = "0.7", features = ["compat"] } tracing-subscriber = "0.3.18" [[bench]] name = "hello_world_tower_hyper_tls_tcp" harness = false + +[features] +jemalloc = ["jemallocator"] diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index 2acdefa..6293823 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -69,6 +69,7 @@ async fn setup_server() -> Result< config.max_fragment_size = Some(16384); // Larger fragment size for powerful servers config.send_half_rtt_data = true; // Enable 0.5-RTT data config.session_storage = ServerSessionMemoryCache::new(10240); // Larger session cache + config.cert_compression_cache = Arc::new(rustls::compress::CompressionCache::default()); config.max_early_data_size = 16384; // Enable 0-RTT data let tls_config = Arc::new(config); @@ -154,9 +155,15 @@ fn bench_server(c: &mut Criterion) { let mut root_cert_store = RootCertStore::empty(); root_cert_store.add_parsable_certificates(load_certs("examples/sample.pem").unwrap()); - let client_config = ClientConfig::builder() + let mut client_config = ClientConfig::builder() .with_root_certificates(root_cert_store) .with_no_client_auth(); + // Enable handshake resumption + client_config.resumption = rustls::client::Resumption::in_memory_sessions(10240); + client_config.cert_compression_cache = + Arc::new(rustls::compress::CompressionCache::default()); + client_config.max_fragment_size = Some(16384); // Larger fragment size for powerful servers + client_config.enable_early_data = true; // Enable 0-RTT data let https = HttpsConnectorBuilder::new() .with_tls_config(client_config) @@ -198,7 +205,7 @@ fn bench_server(c: &mut Criterion) { }); // Concurrency stress test - let concurrent_requests = vec![1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987]; // Fibonacci sequence + let concurrent_requests = vec![1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987]; // log sequence for &num_requests in &concurrent_requests { group.throughput(Throughput::Elements(num_requests as u64)); group.bench_with_input( diff --git a/src/http.rs b/src/http.rs index 56dc96d..6957369 100644 --- a/src/http.rs +++ b/src/http.rs @@ -28,6 +28,7 @@ use crate::fuse::Fuse; /// /// * `wait_for` - An `Option` specifying how long to sleep. /// If `None`, the function will wait indefinitely. +#[inline] async fn sleep_or_pending(wait_for: Option) { match wait_for { Some(wait) => sleep(wait).await, @@ -55,6 +56,7 @@ async fn sleep_or_pending(wait_for: Option) { /// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. /// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection /// before initiating a graceful shutdown. +#[inline] pub async fn serve_http_connection( hyper_io: IO, hyper_service: S, @@ -79,6 +81,36 @@ pub async fn serve_http_connection( inner: watcher.as_mut().map(|w| w.changed()), }); + let builder = builder.clone(); + // TODO(How to accept a preconfigured builder) + // The API here for hyper_util is poor. + // Really what you want to do is configure a builder like this + // and pass it in for use as a builder, however, you cannot + // the simple way may be to require configuration and + // then accept an immutable reference to an http2 connection builder + let mut builder = builder.clone(); + builder + // HTTP/1 settings + .http1() + .half_close(true) + .keep_alive(true) + .max_buf_size(64 * 1024) + .pipeline_flush(true) + .preserve_header_case(true) + .title_case_headers(false) + // HTTP/2 settings + .http2() + .initial_stream_window_size(Some(1024 * 1024)) + .initial_connection_window_size(Some(2 * 1024 * 1024)) + .adaptive_window(true) + .max_frame_size(Some(16 * 1024)) + .max_concurrent_streams(Some(1000)) + .max_send_buf_size(1024 * 1024) + .enable_connect_protocol() + .max_header_list_size(16 * 1024) + .keep_alive_interval(Some(Duration::from_secs(20))) + .keep_alive_timeout(Duration::from_secs(20)); + // Create and pin the HTTP connection let mut conn = pin!(builder.serve_connection_with_upgrades(hyper_io, hyper_service)); @@ -332,6 +364,7 @@ pub async fn serve_http_connection( /// - The server will continue to accept new connections until the `signal` future resolves. /// - When using TLS, make sure to provide a properly configured `ServerConfig`. /// - The function will return when all connections have been closed after the shutdown signal. +#[inline] pub async fn serve_http_with_shutdown( service: S, incoming: I, diff --git a/src/io.rs b/src/io.rs index 48ece65..feacfa5 100644 --- a/src/io.rs +++ b/src/io.rs @@ -23,6 +23,7 @@ impl AsyncRead for Transport where IO: AsyncRead + AsyncWrite + Unpin, { + #[inline] fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -39,6 +40,7 @@ impl AsyncWrite for Transport where IO: AsyncRead + AsyncWrite + Unpin, { + #[inline] fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -50,6 +52,7 @@ where } } + #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { Transport::Plain(io) => Pin::new(io).poll_flush(cx), @@ -57,6 +60,7 @@ where } } + #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { Transport::Plain(io) => Pin::new(io).poll_shutdown(cx), @@ -64,6 +68,7 @@ where } } + #[inline] fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -75,6 +80,7 @@ where } } + #[inline] fn is_write_vectored(&self) -> bool { match self { Transport::Plain(io) => io.is_write_vectored(), diff --git a/src/lib.rs b/src/lib.rs index 8fe53d3..5bf2185 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,10 @@ +#[cfg(all(feature = "jemalloc", not(target_env = "msvc")))] +use jemallocator::Jemalloc; + +#[cfg(all(feature = "jemalloc", not(target_env = "msvc")))] +#[global_allocator] +static GLOBAL: Jemalloc = Jemalloc; + pub use error::{Error as TransportError, Kind as TransportErrorKind}; pub use http::serve_http_connection; pub use http::serve_http_with_shutdown; diff --git a/src/tcp.rs b/src/tcp.rs index 2d5be6d..98347d8 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -69,6 +69,7 @@ fn handle_accept_error(e: impl Into) -> ControlFlow { /// This function uses `handle_accept_error` to determine whether to continue accepting /// connections after an error occurs. Non-fatal errors are logged and skipped, while /// fatal errors cause the stream to yield an error and terminate. +#[inline] pub fn serve_tcp_incoming( incoming: impl Stream> + Send + 'static, ) -> impl Stream> diff --git a/src/tls.rs b/src/tls.rs index 61ae64a..6f55978 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -31,6 +31,7 @@ use tokio_stream::{Stream, StreamExt}; /// /// - If the input `tcp_stream` yields an error, that error is propagated. /// - If the TLS handshake fails, the error is wrapped in the crate's `Error` type. +#[inline] pub fn serve_tls_incoming( tcp_stream: impl Stream>, tls: TlsAcceptor, @@ -71,6 +72,7 @@ where /// # Returns /// /// A `Result` containing a vector of `CertificateDer` on success, or an `io::Error` on failure. +#[inline] pub fn load_certs(filename: &str) -> io::Result>> { // Open certificate file let certfile = fs::File::open(filename)?; @@ -92,6 +94,7 @@ pub fn load_certs(filename: &str) -> io::Result>> { /// # Returns /// /// A `Result` containing a `PrivateKeyDer` on success, or an `io::Error` on failure. +#[inline] pub fn load_private_key(filename: &str) -> io::Result> { // Open keyfile let keyfile = fs::File::open(filename)?; From 66c09f23e516f9ac1411160c0f978216b66d3cad Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 15:36:03 -0400 Subject: [PATCH 36/45] feat: add flamegraphs to benchmarks --- Cargo.toml | 13 +- benches/hello_world_tower_hyper_tls_tcp.rs | 103 ++++- flamegraph.svg | 491 +++++++++++++++++++++ src/fuse.rs | 4 +- src/lib.rs | 2 +- 5 files changed, 587 insertions(+), 26 deletions(-) create mode 100644 flamegraph.svg diff --git a/Cargo.toml b/Cargo.toml index 1690145..71fd03e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,9 @@ readme = "README.md" repository = "https://github.com/valorem-labs-inc/hyper-server" version = "0.7.0" +[target.'cfg(not(target_env = "msvc"))'.dependencies] +tikv-jemallocator = { version = "0.6", optional = true } + [dependencies] async-stream = "0.3.5" bytes = "1.7.1" @@ -31,18 +34,22 @@ tokio-stream = { version = "0.1.16", features = ["net"] } tokio-util = "0.7.12" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" -jemallocator = { version = "0.5", optional = true } +signature = "2.3.0-pre.4" [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } hyper = { version = "1.4.1", features = ["client"] } tokio = { version = "1.40", features = ["rt", "net", "test-util"] } -tokio-util = { version = "0.7", features = ["compat"] } +tokio-util = { version = "0.7.12", features = ["compat"] } tracing-subscriber = "0.3.18" +num_cpus = "1.16.0" +pprof = { version = "0.13.0", features = ["flamegraph"] } +ring = "0.17.8" +rcgen = "0.13.1" [[bench]] name = "hello_world_tower_hyper_tls_tcp" harness = false [features] -jemalloc = ["jemallocator"] +jemalloc = ["tikv-jemallocator"] diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index 6293823..8db38ad 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -1,20 +1,22 @@ +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; +use std::{fs::File, os::raw::c_int, path::Path}; + use bytes::Bytes; +use criterion::profiler::Profiler; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; -use futures::stream::FuturesUnordered; -use futures::StreamExt; use http::{Request, Response, StatusCode, Uri}; use http_body_util::{BodyExt, Empty, Full}; use hyper::body::Incoming; use hyper_rustls::HttpsConnectorBuilder; -use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; use hyper_util::service::TowerToHyperService; +use pprof::ProfilerGuard; use rustls::server::ServerSessionMemoryCache; use rustls::{ClientConfig, RootCertStore, ServerConfig}; -use std::net::SocketAddr; -use std::sync::Arc; use tokio::net::TcpSocket; use tokio::runtime::Runtime; use tokio::sync::oneshot; @@ -22,6 +24,52 @@ use tokio::time::{Duration, Instant}; use tokio_stream::wrappers::TcpListenerStream; use tracing::info; +use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; + +/// Custom profiler that creates a flamegraph for each benchmark +pub struct FlamegraphProfiler<'a> { + frequency: c_int, + active_profiler: Option>, +} + +impl<'a> FlamegraphProfiler<'a> { + pub fn new(frequency: c_int) -> Self { + FlamegraphProfiler { + frequency, + active_profiler: None, + } + } +} + +impl<'a> Profiler for FlamegraphProfiler<'a> { + fn start_profiling(&mut self, _benchmark_id: &str, _benchmark_dir: &Path) { + self.active_profiler = Some(ProfilerGuard::new(self.frequency).unwrap()); + } + + fn stop_profiling(&mut self, _benchmark_id: &str, benchmark_dir: &Path) { + std::fs::create_dir_all(benchmark_dir).unwrap(); + let flamegraph_path = benchmark_dir.join("flamegraph.svg"); + let flamegraph_file = File::create(&flamegraph_path) + .expect("File system error while creating flamegraph.svg"); + if let Some(profiler) = self.active_profiler.take() { + profiler + .report() + .build() + .unwrap() + .flamegraph(flamegraph_file) + .expect("Error writing flamegraph"); + } + } +} + +fn create_optimized_runtime(thread_count: usize) -> io::Result { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(thread_count) + .max_blocking_threads(thread_count * 2) + .enable_all() + .build() +} + async fn echo(req: Request) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { (&hyper::Method::GET, "/") => Ok(Response::new(Full::new(Bytes::from("Hello, World!")))), @@ -41,14 +89,20 @@ async fn setup_server() -> Result< (TcpListenerStream, SocketAddr, Arc), Box, > { - // Socket configuration - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); // Listen on all interfaces + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let socket = TcpSocket::new_v4()?; + + // Optimize TCP parameters socket.set_send_buffer_size(262_144)?; // 256 KB socket.set_recv_buffer_size(262_144)?; // 256 KB - socket.set_nodelay(true)?; // Disable Nagle's algorithm + socket.set_nodelay(true)?; + socket.set_reuseaddr(true)?; + socket.set_reuseport(true)?; + socket.set_keepalive(true)?; + socket.bind(addr)?; - let listener = socket.listen(8192)?; // Increase backlog for high-traffic scenarios + let listener = socket.listen(8192)?; // Increased backlog for high-traffic scenarios + let server_addr = listener.local_addr()?; let incoming = TcpListenerStream::new(listener); @@ -125,18 +179,20 @@ async fn concurrent_benchmark( num_requests: usize, ) -> (Duration, Vec, usize) { let start = Instant::now(); - let mut futures = FuturesUnordered::new(); + let mut handles = Vec::with_capacity(num_requests); for _ in 0..num_requests { let client = client.clone(); let url = url.clone(); - futures.push(async move { send_request(&client, url).await }); + let handle = tokio::spawn(async move { send_request(&client, url).await }); + handles.push(handle); } let mut request_times = Vec::with_capacity(num_requests); let mut total_bytes = 0; - while let Some(result) = futures.next().await { - if let Ok((duration, bytes)) = result { + + for handle in handles { + if let Ok(Ok((duration, bytes))) = handle.await { request_times.push(duration); total_bytes += bytes; } @@ -147,8 +203,9 @@ async fn concurrent_benchmark( } fn bench_server(c: &mut Criterion) { - let rt = Runtime::new().unwrap(); - let (server_addr, shutdown_tx, client) = rt.block_on(async { + let server_runtime = Arc::new(create_optimized_runtime(num_cpus::get() / 2).unwrap()); + + let (server_addr, shutdown_tx, client) = server_runtime.block_on(async { let (server_addr, shutdown_tx) = start_server().await.expect("Failed to start server"); info!("Server started on {}", server_addr); @@ -191,7 +248,8 @@ fn bench_server(c: &mut Criterion) { group.bench_function("latency", |b| { let client = client.clone(); let url = url.clone(); - b.to_async(&rt) + let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); + b.to_async(client_runtime) .iter(|| async { send_request(&client, url.clone()).await.unwrap().0 }); }); @@ -200,12 +258,13 @@ fn bench_server(c: &mut Criterion) { group.bench_function("throughput", |b| { let client = client.clone(); let url = url.clone(); - b.to_async(&rt) + let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); + b.to_async(client_runtime) .iter(|| async { send_request(&client, url.clone()).await.unwrap() }); }); // Concurrency stress test - let concurrent_requests = vec![1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987]; // log sequence + let concurrent_requests = vec![1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987]; for &num_requests in &concurrent_requests { group.throughput(Throughput::Elements(num_requests as u64)); group.bench_with_input( @@ -214,7 +273,8 @@ fn bench_server(c: &mut Criterion) { |b, &num_requests| { let client = client.clone(); let url = url.clone(); - b.to_async(&rt).iter(|| async { + let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); + b.to_async(client_runtime).iter(|| async { concurrent_benchmark(&client, url.clone(), num_requests).await }); }, @@ -223,7 +283,7 @@ fn bench_server(c: &mut Criterion) { group.finish(); - rt.block_on(async { + server_runtime.block_on(async { shutdown_tx.send(()).unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; }); @@ -234,7 +294,8 @@ criterion_group! { config = Criterion::default() .sample_size(10) .measurement_time(Duration::from_secs(20)) - .warm_up_time(Duration::from_secs(5)); + .warm_up_time(Duration::from_secs(5)) + .with_profiler(FlamegraphProfiler::new(100)); targets = bench_server } diff --git a/flamegraph.svg b/flamegraph.svg new file mode 100644 index 0000000..8aff83e --- /dev/null +++ b/flamegraph.svg @@ -0,0 +1,491 @@ +Flame Graph Reset ZoomSearch bytes::bytes::shared_clone (1 samples, 0.30%)core::ptr::drop_in_place<alloc::boxed::Box<tokio::runtime::task::core::Cell<hello_world_tower_hyper_tls_tcp::concurrent_benchmark::{{closure}}::{{closure}},alloc::sync::Arc<tokio::runtime::scheduler::multi_thread::handle::Handle>>>> (1 samples, 0.30%)core::ptr::drop_in_place<hyper_util::client::legacy::client::Error> (3 samples, 0.90%)core::ptr::drop_in_place<hyper_util::client::legacy::connect::http::ConnectError> (2 samples, 0.60%)_rjem_je_sdallocx_default (1 samples, 0.30%)_rjem_je_tcache_bin_flush_small (1 samples, 0.30%)hello_world_tow (6 samples, 1.79%)h.._start (6 samples, 1.79%)_..__libc_start_main (6 samples, 1.79%)_..main (6 samples, 1.79%)m..std::rt::lang_start_internal (6 samples, 1.79%)s..std::rt::lang_start::{{closure}} (6 samples, 1.79%)s..std::sys::backtrace::__rust_begin_short_backtrace (6 samples, 1.79%)s..hello_world_tower_hyper_tls_tcp::main (6 samples, 1.79%)h..criterion::benchmark_group::BenchmarkGroup<M>::bench_with_input (6 samples, 1.79%)c..<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (6 samples, 1.79%)<..criterion::bencher::AsyncBencher<A,M>::iter (6 samples, 1.79%)c..<tokio::runtime::runtime::Runtime as criterion::async_executor::AsyncExecutor>::block_on (6 samples, 1.79%)<..tokio::task::spawn::spawn (1 samples, 0.30%)_rjem_je_tsd_cleanup (1 samples, 0.30%)tcache_destroy.constprop.0 (1 samples, 0.30%)_rjem_je_tcache_bin_flush_small (1 samples, 0.30%)_rjem_je_arena_slab_dalloc (1 samples, 0.30%)pac_dalloc_impl (1 samples, 0.30%)_rjem_je_extent_record (1 samples, 0.30%)extent_try_coalesce_impl (1 samples, 0.30%)extent_merge_impl.constprop.0 (1 samples, 0.30%)_rjem_je_edata_cache_put (1 samples, 0.30%)std::sys::sync::mutex::futex::Mutex::lock_contended (8 samples, 2.39%)st..syscall (6 samples, 1.79%)s..std::sys::sync::condvar::futex::Condvar::wait_timeout (13 samples, 3.88%)std:..syscall (5 samples, 1.49%)std::sys::sync::mutex::futex::Mutex::lock_contended (2 samples, 0.60%)syscall (1 samples, 0.30%)syscall (2 samples, 0.60%)__file_change_detection_for_path (2 samples, 0.60%)fstatat64 (2 samples, 0.60%)__lll_lock_wake_private (2 samples, 0.60%)__lll_lock_wait_private (1 samples, 0.30%)__resolv_context_get (7 samples, 2.09%)_..__lll_lock_wake_private (3 samples, 0.90%)__resolv_context_put (3 samples, 0.90%)__clock_gettime (1 samples, 0.30%)__res_context_mkquery (2 samples, 0.60%)__ns_name_compress (1 samples, 0.30%)__ns_name_pack (1 samples, 0.30%)__poll (1 samples, 0.30%)__socket (3 samples, 0.90%)<std::sys_common::net::LookupHost as core::convert::TryFrom<(&str,u16)>>::try_from::{{closure}} (23 samples, 6.87%)<std::sys..getaddrinfo (23 samples, 6.87%)getaddrin.._nss_dns_gethostbyname4_r (9 samples, 2.69%)_n..__res_context_search (9 samples, 2.69%)__..__res_context_query (8 samples, 2.39%)__..__res_context_send (5 samples, 1.49%)ioctl (1 samples, 0.30%)<(&str,u16) as std::net::socket_addr::ToSocketAddrs>::to_socket_addrs (25 samples, 7.46%)<(&str,u16..core::net::parser::Parser::read_ipv4_addr (2 samples, 0.60%)tokio::runtime::scheduler::multi_thread::worker::<impl tokio::runtime::scheduler::multi_thread::handle::Handle>::next_remote_task (2 samples, 0.60%)std::sys::sync::mutex::futex::Mutex::lock_contended (1 samples, 0.30%)syscall (1 samples, 0.30%)tokio::runtime::time::<impl tokio::runtime::time::handle::Handle>::process_at_time (3 samples, 0.90%)tokio::runtime::time::<impl tokio::runtime::time::handle::Handle>::process_at_sharded_time (3 samples, 0.90%)tokio::runtime::io::driver::Driver::turn (3 samples, 0.90%)epoll_wait (2 samples, 0.60%)tokio::runtime::scheduler::multi_thread::worker::Context::park_timeout (8 samples, 2.39%)to..tokio::runtime::time::Driver::park_internal (4 samples, 1.19%)tokio::runtime::time::wheel::Wheel::next_expiration (1 samples, 0.30%)<rustls::client::hs::ExpectServerHelloOrHelloRetryRequest as rustls::common_state::State<rustls::client::client_conn::ClientConnectionData>>::handle (1 samples, 0.30%)<rustls::client::hs::ExpectServerHello as rustls::common_state::State<rustls::client::client_conn::ClientConnectionData>>::handle (1 samples, 0.30%)rustls::hash_hs::HandshakeHashBuffer::start_hash (1 samples, 0.30%)<futures_util::future::either::Either<A,B> as core::future::future::Future>::poll (3 samples, 0.90%)<hyper_rustls::connector::HttpsConnector<T> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (2 samples, 0.60%)tokio_rustls::common::Stream<IO,C>::read_io (2 samples, 0.60%)<rustls::client::tls13::ExpectCertificateOrCompressedCertificateOrCertReq as rustls::common_state::State<rustls::client::client_conn::ClientConnectionData>>::handle (1 samples, 0.30%)<rustls::client::tls13::ExpectCompressedCertificate as rustls::common_state::State<rustls::client::client_conn::ClientConnectionData>>::handle (1 samples, 0.30%)<rustls::compress::feat_zlib_rs::ZlibRsDecompressor as rustls::compress::CertDecompressor>::decompress (1 samples, 0.30%)zlib_rs::inflate::State::len (1 samples, 0.30%)zlib_rs::inflate::State::check (1 samples, 0.30%)zlib_rs::adler32::adler32 (1 samples, 0.30%)<rustls::conn::ConnectionCommon<T> as rustls::conn::connection::PlaintextSink>::write (1 samples, 0.30%)rustls::common_state::CommonState::send_appdata_encrypt (1 samples, 0.30%)rustls::common_state::CommonState::send_single_fragment (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::tls13::GcmMessageEncrypter as rustls::crypto::cipher::MessageEncrypter>::encrypt (1 samples, 0.30%)aws_lc_rs::aead::unbound_key::UnboundKey::seal_in_place_append_tag (1 samples, 0.30%)aws_lc_0_21_1_EVP_AEAD_CTX_seal (1 samples, 0.30%)aead_aes_gcm_tls13_seal_scatter (1 samples, 0.30%)aead_aes_gcm_seal_scatter_impl (1 samples, 0.30%)aws_lc_0_21_1_CRYPTO_gcm128_aad (1 samples, 0.30%)hyper::proto::h1::conn::Conn<I,B,T>::poll_flush (2 samples, 0.60%)<tokio_rustls::client::TlsStream<IO> as tokio::io::async_write::AsyncWrite>::poll_write (2 samples, 0.60%)rustls::vecbuf::ChunkVecBuffer::write_to (1 samples, 0.30%)rustls::vecbuf::ChunkVecBuffer::consume (1 samples, 0.30%)<futures_util::future::future::Map<Fut,F> as core::future::future::Future>::poll (7 samples, 2.09%)<..<hyper::client::conn::http1::upgrades::UpgradeableConnection<I,B> as core::future::future::Future>::poll (4 samples, 1.19%)hyper::proto::h1::io::Buffered<T,B>::poll_read_from_io (1 samples, 0.30%)<hyper_util::rt::tokio::TokioIo<T> as hyper::rt::io::Read>::poll_read (1 samples, 0.30%)<http_body_util::combinators::collect::Collect<T> as core::future::future::Future>::poll (1 samples, 0.30%)<hyper::body::incoming::Incoming as http_body::Body>::poll_frame (1 samples, 0.30%)<hyper::proto::h1::dispatch::Server<S,hyper::body::incoming::Incoming> as hyper::proto::h1::dispatch::Dispatch>::recv_msg (1 samples, 0.30%)<hyper_util::server::conn::auto::UpgradeableConnection<I,S,E> as core::future::future::Future>::poll (2 samples, 0.60%)<hyper::server::conn::http1::UpgradeableConnection<I,S> as core::future::future::Future>::poll (2 samples, 0.60%)hyper::proto::h1::conn::Conn<I,B,T>::poll_read_head (1 samples, 0.30%)http::header::map::HeaderMap<T>::try_append2 (1 samples, 0.30%)_rjem_malloc (1 samples, 0.30%)_rjem_malloc (1 samples, 0.30%)accept4 (4 samples, 1.19%)<core::pin::Pin<P> as futures_core::stream::Stream>::poll_next (6 samples, 1.79%)<..tokio::io::poll_evented::PollEvented<E>::new_with_interest (1 samples, 0.30%)tokio::runtime::io::registration_set::RegistrationSet::allocate (1 samples, 0.30%)<tokio::io::poll_evented::PollEvented<E> as core::ops::drop::Drop>::drop (1 samples, 0.30%)epoll_ctl (1 samples, 0.30%)__close (6 samples, 1.79%)_..<rustls::crypto::aws_lc_rs::hash::Hash as rustls::crypto::hash::Hash>::start (1 samples, 0.30%)core::ops::function::impls::<impl core::ops::function::FnMut<A> for &mut F>::call_mut (1 samples, 0.30%)rustls::common_state::CommonState::send_msg (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::hash::Context as rustls::crypto::hash::Context>::fork_finish (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::hash::Context as rustls::crypto::hash::Context>::update (1 samples, 0.30%)aws_lc_0_21_1_EVP_DigestUpdate (1 samples, 0.30%)sha384_update (1 samples, 0.30%)aws_lc_0_21_1_SHA512_Update (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)aws_lc_0_21_1_BN_uadd (1 samples, 0.30%)aws_lc_0_21_1_bn_uadd_consttime (1 samples, 0.30%)aws_lc_0_21_1_bn_add_words (1 samples, 0.30%)aws_lc_0_21_1_bn_is_bit_set_words (1 samples, 0.30%)aws_lc_0_21_1_BN_mod_inverse_blinded (3 samples, 0.90%)aws_lc_0_21_1_BN_mod_inverse_odd (3 samples, 0.90%)bn_cmp_words_consttime (1 samples, 0.30%)bn_mulx4x_mont (1 samples, 0.30%)aws_lc_0_21_1_BN_BLINDING_convert (6 samples, 1.79%)a..aws_lc_0_21_1_BN_mod_mul_montgomery (3 samples, 0.90%)bn_sqr8x_mont (2 samples, 0.60%)aws_lc_0_21_1_bn_sqrx8x_internal (2 samples, 0.60%)aws_lc_0_21_1_BN_from_montgomery (1 samples, 0.30%)bn_from_montgomery_in_place (1 samples, 0.30%)aws_lc_0_21_1_BN_mod_mul_montgomery (2 samples, 0.60%)bn_sqr8x_mont (2 samples, 0.60%)aws_lc_0_21_1_bn_sqrx8x_internal (2 samples, 0.60%)aws_lc_0_21_1_BN_mod_exp_mont (3 samples, 0.90%)aws_lc_0_21_1_BN_num_bits (1 samples, 0.30%)aws_lc_0_21_1_bn_minimal_width (1 samples, 0.30%)aws_lc_0_21_1_bn_gather5 (1 samples, 0.30%)bn_mulx4x_mont_gather5 (3 samples, 0.90%)mulx4x_internal (3 samples, 0.90%)__bn_postx4x_internal (3 samples, 0.90%)aws_lc_0_21_1_bn_sqrx8x_internal (109 samples, 32.54%)aws_lc_0_21_1_bn_sqrx8x_internalbn_powerx5 (152 samples, 45.37%)bn_powerx5mulx4x_internal (39 samples, 11.64%)mulx4x_internalrustls::server::tls13::client_hello::emit_certificate_verify_tls13 (171 samples, 51.04%)rustls::server::tls13::client_hello::emit_certificate_verify_tls13<rustls::crypto::aws_lc_rs::sign::RsaSigner as rustls::crypto::signer::Signer>::sign (169 samples, 50.45%)<rustls::crypto::aws_lc_rs::sign::RsaSigner as rustls::crypto::signer::Signer>::signaws_lc_0_21_1_EVP_DigestSignFinal (169 samples, 50.45%)aws_lc_0_21_1_EVP_DigestSignFinalpkey_rsa_sign (169 samples, 50.45%)pkey_rsa_signaws_lc_0_21_1_RSA_sign_pss_mgf1 (169 samples, 50.45%)aws_lc_0_21_1_RSA_sign_pss_mgf1aws_lc_0_21_1_rsa_default_sign_raw (169 samples, 50.45%)aws_lc_0_21_1_rsa_default_sign_rawaws_lc_0_21_1_rsa_default_private_transform (169 samples, 50.45%)aws_lc_0_21_1_rsa_default_private_transformaws_lc_0_21_1_BN_mod_exp_mont_consttime (159 samples, 47.46%)aws_lc_0_21_1_BN_mod_exp_mont_consttimebn_sqr8x_mont (3 samples, 0.90%)aws_lc_0_21_1_bn_sqrx8x_internal (3 samples, 0.90%)rustls::server::tls13::client_hello::emit_compressed_certificate_tls13 (3 samples, 0.90%)<rustls::crypto::aws_lc_rs::hash::Context as rustls::crypto::hash::Context>::update (3 samples, 0.90%)aws_lc_0_21_1_EVP_DigestUpdate (3 samples, 0.90%)sha384_update (3 samples, 0.90%)aws_lc_0_21_1_SHA512_Update (3 samples, 0.90%)aws_lc_0_21_1_sha512_block_data_order_nohw (3 samples, 0.90%)rustls::tls13::key_schedule::KeySchedule::set_encrypter (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::tls13::RingHkdfExpander as rustls::crypto::tls13::HkdfExpander>::expand_slice (1 samples, 0.30%)aws_lc_rs::hkdf::Okm<L>::fill (1 samples, 0.30%)aws_lc_0_21_1_HKDF_expand (1 samples, 0.30%)aws_lc_0_21_1_HMAC_Final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)aws_lc_0_21_1_HKDF_expand (1 samples, 0.30%)aws_lc_0_21_1_HMAC_Final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)rustls::server::tls13::client_hello::emit_finished_tls13 (4 samples, 1.19%)rustls::tls13::key_schedule::KeyScheduleTraffic::new (3 samples, 0.90%)rustls::tls13::key_schedule::KeySchedule::derive_logged_secret (3 samples, 0.90%)<rustls::crypto::aws_lc_rs::tls13::RingHkdfExpander as rustls::crypto::tls13::HkdfExpander>::expand_block (3 samples, 0.90%)aws_lc_rs::hkdf::Okm<L>::fill (3 samples, 0.90%)aws_lc_0_21_1_HKDF (3 samples, 0.90%)aws_lc_0_21_1_HKDF_extract (2 samples, 0.60%)aws_lc_0_21_1_HMAC (2 samples, 0.60%)aws_lc_0_21_1_HMAC_Final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::hash::Context as rustls::crypto::hash::Context>::finish (1 samples, 0.30%)aws_lc_0_21_1_EVP_DigestFinal (1 samples, 0.30%)aws_lc_0_21_1_EVP_DigestFinal_ex (1 samples, 0.30%)sha384_final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)rustls::crypto::SupportedKxGroup::start_and_complete (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::kx::KxGroup as rustls::crypto::SupportedKxGroup>::start (1 samples, 0.30%)aws_lc_0_21_1_EVP_PKEY_keygen (1 samples, 0.30%)pkey_x25519_keygen (1 samples, 0.30%)aws_lc_0_21_1_X25519_keypair (1 samples, 0.30%)aws_lc_0_21_1_RAND_bytes (1 samples, 0.30%)aws_lc_0_21_1_RAND_bytes_with_additional_data.part.0 (1 samples, 0.30%)aws_lc_0_21_1_CTR_DRBG_generate (1 samples, 0.30%)ctr_drbg_update.part.0 (1 samples, 0.30%)aws_lc_0_21_1_aes_ctr_set_key (1 samples, 0.30%)aws_lc_0_21_1_aes_hw_set_encrypt_key (1 samples, 0.30%)rustls::tls13::key_schedule::KeyScheduleHandshakeStart::into_handshake (1 samples, 0.30%)rustls::tls13::key_schedule::KeySchedule::derive_logged_secret (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::tls13::RingHkdfExpander as rustls::crypto::tls13::HkdfExpander>::expand_block (1 samples, 0.30%)aws_lc_rs::hkdf::Okm<L>::fill (1 samples, 0.30%)aws_lc_0_21_1_HKDF (1 samples, 0.30%)aws_lc_0_21_1_HKDF_expand (1 samples, 0.30%)aws_lc_0_21_1_HMAC_Init_ex (1 samples, 0.30%)aws_lc_0_21_1_SHA512_Update (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)aws_lc_0_21_1_HMAC_Final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)hello_world_tower_hyper_tls_tcp::start_server::{{closure}}::{{closure}} (202 samples, 60.30%)hello_world_tower_hyper_tls_tcp::start_server::{{closure}}::{{closure}}tokio_rustls::common::Stream<IO,C>::read_io (189 samples, 56.42%)tokio_rustls::common::Stream<IO,C>::read_io<rustls::server::hs::ExpectClientHello as rustls::common_state::State<rustls::server::server_conn::ServerConnectionData>>::handle (189 samples, 56.42%)<rustls::server::hs::ExpectClientHello as rustls::common_state::State<rustls::server::server_..rustls::server::hs::ExpectClientHello::with_certified_key (188 samples, 56.12%)rustls::server::hs::ExpectClientHello::with_certified_keyrustls::server::tls13::client_hello::emit_server_hello (6 samples, 1.79%)r..rustls::tls13::key_schedule::KeySchedulePreHandshake::into_handshake (3 samples, 0.90%)<rustls::crypto::aws_lc_rs::tls13::RingHkdfExpander as rustls::crypto::tls13::HkdfExpander>::expand_block (3 samples, 0.90%)aws_lc_rs::hkdf::Okm<L>::fill (2 samples, 0.60%)aws_lc_0_21_1_HKDF (2 samples, 0.60%)aws_lc_0_21_1_HKDF_expand (2 samples, 0.60%)aws_lc_0_21_1_HMAC_Init_ex (1 samples, 0.30%)aws_lc_0_21_1_SHA512_Update (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)hyper_util::client::legacy::client::Client<C,B>::get (3 samples, 0.90%)_rjem_je_malloc_default (1 samples, 0.30%)_rjem_je_tsd_fetch_slow (1 samples, 0.30%)_rjem_je_tsd_tcache_enabled_data_init (1 samples, 0.30%)_rjem_je_tsd_tcache_data_init (1 samples, 0.30%)_rjem_je_large_palloc (1 samples, 0.30%)_rjem_je_arena_extent_alloc_large (1 samples, 0.30%)_rjem_je_pa_alloc (1 samples, 0.30%)pac_alloc_impl (1 samples, 0.30%)pac_alloc_real (1 samples, 0.30%)_rjem_je_ecache_alloc (1 samples, 0.30%)extent_recycle (1 samples, 0.30%)_rjem_je_eset_remove (1 samples, 0.30%)<tokio::io::poll_evented::PollEvented<E> as core::ops::drop::Drop>::drop (1 samples, 0.30%)epoll_ctl (1 samples, 0.30%)<core::pin::Pin<P> as core::future::future::Future>::poll (3 samples, 0.90%)tokio::net::tcp::socket::TcpSocket::connect::{{closure}} (2 samples, 0.60%)__close (1 samples, 0.30%)<hyper_util::client::legacy::connect::http::HttpConnector<R> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (4 samples, 1.19%)setsockopt (1 samples, 0.30%)<futures_util::future::either::Either<A,B> as core::future::future::Future>::poll (5 samples, 1.49%)<hyper_rustls::connector::HttpsConnector<T> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (5 samples, 1.49%)_rjem_je_sdallocx_default (1 samples, 0.30%)_rjem_je_te_event_trigger (1 samples, 0.30%)_rjem_je_tcache_gc_dalloc_event_handler (1 samples, 0.30%)tcache_gc_small (1 samples, 0.30%)_rjem_je_tcache_bin_flush_small (1 samples, 0.30%)_rjem_je_arena_slab_dalloc (1 samples, 0.30%)pac_dalloc_impl (1 samples, 0.30%)_rjem_je_extent_record (1 samples, 0.30%)extent_try_coalesce_impl (1 samples, 0.30%)_rjem_je_emap_try_acquire_edata_neighbor (1 samples, 0.30%)emap_try_acquire_edata_neighbor_impl (1 samples, 0.30%)<http::uri::Uri as core::clone::Clone>::clone (1 samples, 0.30%)bytes::bytes::shared_drop (1 samples, 0.30%)hyper_util::client::legacy::pool::Pool<T,K>::reuse (1 samples, 0.30%)bytes::bytes::shared_clone (1 samples, 0.30%)<hyper_util::client::legacy::pool::Checkout<T,K> as core::future::future::Future>::poll (3 samples, 0.90%)std::sys::sync::mutex::futex::Mutex::lock_contended (1 samples, 0.30%)<http::uri::scheme::Scheme as core::cmp::PartialEq>::eq (1 samples, 0.30%)std::sys::sync::mutex::futex::Mutex::lock_contended (3 samples, 0.90%)syscall (2 samples, 0.60%)<hyper_util::common::lazy::Lazy<F,R> as core::future::future::Future>::poll (12 samples, 3.58%)<hyp..<futures_util::future::either::Either<A,B> as core::future::future::Future>::poll (12 samples, 3.58%)<fut..<hyper_rustls::connector::HttpsConnector<T> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (9 samples, 2.69%)<h..<hyper_util::client::legacy::connect::http::HttpConnector<R> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (9 samples, 2.69%)<h..tokio::runtime::blocking::pool::Spawner::spawn_task (9 samples, 2.69%)to..syscall (4 samples, 1.19%)bytes::bytes::shared_clone (1 samples, 0.30%)core::ptr::drop_in_place<http::request::Parts> (1 samples, 0.30%)core::ptr::drop_in_place<http::uri::Uri> (1 samples, 0.30%)bytes::bytes::shared_drop (1 samples, 0.30%)core::ptr::drop_in_place<hyper_util::client::legacy::client::Client<hyper_rustls::connector::HttpsConnector<hyper_util::client::legacy::connect::http::HttpConnector>,http_body_util::empty::Empty<bytes::bytes::Bytes>>>.5092 (1 samples, 0.30%)core::ptr::drop_in_place<hyper_rustls::connector::HttpsConnector<hyper_util::client::legacy::connect::http::HttpConnector>> (1 samples, 0.30%)core::ptr::drop_in_place<[futures_channel::oneshot::Sender<hyper_util::client::legacy::client::PoolClient<http_body_util::empty::Empty<bytes::bytes::Bytes>>>]> (2 samples, 0.60%)std::sys::sync::mutex::futex::Mutex::lock_contended (2 samples, 0.60%)core::ptr::drop_in_place<hyper_util::client::legacy::pool::Checkout<hyper_util::client::legacy::client::PoolClient<http_body_util::empty::Empty<bytes::bytes::Bytes>>,(http::uri::scheme::Scheme,http::uri::authority::Authority)>> (10 samples, 2.99%)cor..syscall (1 samples, 0.30%)core::ptr::drop_in_place<hyper_util::client::legacy::pool::Pooled<hyper_util::client::legacy::client::PoolClient<http_body_util::empty::Empty<bytes::bytes::Bytes>>,(http::uri::scheme::Scheme,http::uri::authority::Authority)>> (2 samples, 0.60%)<hyper_util::client::legacy::pool::Pooled<T,K> as core::ops::drop::Drop>::drop (2 samples, 0.60%)hyper_util::client::legacy::pool::PoolInner<T,K>::put (2 samples, 0.60%)core::hash::BuildHasher::hash_one (2 samples, 0.60%)<std::hash::random::DefaultHasher as core::hash::Hasher>::write.6187 (1 samples, 0.30%)tokio::runtime::task::core::Core<T,S>::poll (254 samples, 75.82%)tokio::runtime::task::core::Core<T,S>::pollhyper_util::client::legacy::client::Client<C,B>::send_request::{{closure}} (38 samples, 11.34%)hyper_util::clien..hyper_util::client::legacy::client::Client<C,B>::connect_to (1 samples, 0.30%)tokio::runtime::scheduler::multi_thread::worker::Context::run_task (255 samples, 76.12%)tokio::runtime::scheduler::multi_thread::worker::Context::run_tasktokio::runtime::task::raw::poll (255 samples, 76.12%)tokio::runtime::task::raw::polltokio::runtime::task::harness::Harness<T,S>::complete (1 samples, 0.30%)tokio::runtime::scheduler::multi_thread::worker::<impl tokio::runtime::task::Schedule for alloc::sync::Arc<tokio::runtime::scheduler::multi_thread::handle::Handle>>::release (1 samples, 0.30%)tokio::runtime::task::harness::Harness<T,S>::complete (4 samples, 1.19%)tokio::runtime::task::raw::schedule (4 samples, 1.19%)tokio::runtime::context::with_scheduler (4 samples, 1.19%)tokio::runtime::scheduler::multi_thread::worker::<impl tokio::runtime::scheduler::multi_thread::handle::Handle>::push_remote_task (4 samples, 1.19%)std::sys::sync::mutex::futex::Mutex::lock_contended (3 samples, 0.90%)all (335 samples, 100%)tokio-runtime-w (329 samples, 98.21%)tokio-runtime-wstd::sys::pal::unix::thread::Thread::new::thread_start (316 samples, 94.33%)std::sys::pal::unix::thread::Thread::new::thread_startcore::ops::function::FnOnce::call_once{{vtable.shim}} (316 samples, 94.33%)core::ops::function::FnOnce::call_once{{vtable.shim}}std::sys::backtrace::__rust_begin_short_backtrace (316 samples, 94.33%)std::sys::backtrace::__rust_begin_short_backtracetokio::runtime::task::raw::poll (299 samples, 89.25%)tokio::runtime::task::raw::polltokio::runtime::task::raw::shutdown (1 samples, 0.30%)tokio::runtime::task::core::Core<T,S>::set_stage (1 samples, 0.30%)core::ptr::drop_in_place<tokio::runtime::task::core::Stage<core::pin::Pin<alloc::boxed::Box<dyn core::future::future::Future+Output = ()+core::marker::Send>>>> (1 samples, 0.30%)core::ptr::drop_in_place<futures_util::fns::MapOkFn<hyper_util::client::legacy::client::Client<hyper_rustls::connector::HttpsConnector<hyper_util::client::legacy::connect::http::HttpConnector>,http_body_util::empty::Empty<bytes::bytes::Bytes>>::connect_to::{{closure}}::{{closure}}>> (1 samples, 0.30%)core::ptr::drop_in_place<hyper_util::client::legacy::pool::Connecting<hyper_util::client::legacy::client::PoolClient<http_body_util::empty::Empty<bytes::bytes::Bytes>>,(http::uri::scheme::Scheme,http::uri::authority::Authority)>> (1 samples, 0.30%)bytes::bytes::shared_drop (1 samples, 0.30%) \ No newline at end of file diff --git a/src/fuse.rs b/src/fuse.rs index cff01c3..944c1c0 100644 --- a/src/fuse.rs +++ b/src/fuse.rs @@ -32,8 +32,10 @@ where self.project().inner.set(None); output }), - // If inner is None, it means the future has already completed + // If inner is None, it means the future has already completed, // So we return Poll::Pending + // and yes, this is confusing and counterintuitive naming, + // but apparently this is how it's done. None => Poll::Pending, } } diff --git a/src/lib.rs b/src/lib.rs index 5bf2185..569a2ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ #[cfg(all(feature = "jemalloc", not(target_env = "msvc")))] -use jemallocator::Jemalloc; +use tikv_jemallocator::Jemalloc; #[cfg(all(feature = "jemalloc", not(target_env = "msvc")))] #[global_allocator] From 05b3de152a6a8d6c684eed52e908536d77133651 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 19:21:22 -0400 Subject: [PATCH 37/45] feat: feature flag perf flamegraphs --- Cargo.toml | 10 +- benches/hello_world_tower_hyper_tls_tcp.rs | 125 +++++++++++++++------ 2 files changed, 101 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 71fd03e..a0ad89c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ http = "1.1.0" http-body = "1.0.1" http-body-util = "0.1.2" hyper = "1.4.1" -hyper-rustls = "0.27.3" +hyper-rustls = { version = "0.27.3", features = ["http2"] } hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } pin-project = "1.1.5" rand = "0.9.0-alpha.2" @@ -35,6 +35,8 @@ tokio-util = "0.7.12" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" signature = "2.3.0-pre.4" +ring = "0.17.8" +pprof = { version = "0.13.0", features = ["flamegraph"], optional = true} [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } @@ -43,7 +45,6 @@ tokio = { version = "1.40", features = ["rt", "net", "test-util"] } tokio-util = { version = "0.7.12", features = ["compat"] } tracing-subscriber = "0.3.18" num_cpus = "1.16.0" -pprof = { version = "0.13.0", features = ["flamegraph"] } ring = "0.17.8" rcgen = "0.13.1" @@ -52,4 +53,9 @@ name = "hello_world_tower_hyper_tls_tcp" harness = false [features] +default = [] jemalloc = ["tikv-jemallocator"] +dev-profiling = ["pprof"] + +[profile.release] +debug = true \ No newline at end of file diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index 8db38ad..b7680c3 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -1,10 +1,8 @@ use std::io; use std::net::SocketAddr; use std::sync::Arc; -use std::{fs::File, os::raw::c_int, path::Path}; use bytes::Bytes; -use criterion::profiler::Profiler; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use http::{Request, Response, StatusCode, Uri}; use http_body_util::{BodyExt, Empty, Full}; @@ -14,7 +12,6 @@ use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; use hyper_util::service::TowerToHyperService; -use pprof::ProfilerGuard; use rustls::server::ServerSessionMemoryCache; use rustls::{ClientConfig, RootCertStore, ServerConfig}; use tokio::net::TcpSocket; @@ -26,38 +23,91 @@ use tracing::info; use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; -/// Custom profiler that creates a flamegraph for each benchmark -pub struct FlamegraphProfiler<'a> { - frequency: c_int, - active_profiler: Option>, -} +/// Profiling module for generating flamegraphs during benchmarks +/// +/// This module is only compiled when the "dev-profiling" feature is enabled. +/// It provides a custom profiler that integrates with Criterion to generate +/// flamegraphs for each benchmark run. +#[cfg(feature = "dev-profiling")] +mod profiling { + use std::fs::File; + use std::path::Path; + + use criterion::profiler::Profiler; + use pprof::ProfilerGuard; + + /// Custom profiler for generating flamegraphs + /// + /// This struct implements the `Profiler` trait from Criterion, + /// allowing it to be used as a custom profiler in benchmark runs. + pub struct FlamegraphProfiler<'a> { + /// Sampling frequency for the profiler (in Hz) + frequency: i32, + /// The active profiler instance, if profiling is currently in progress + active_profiler: Option>, + } -impl<'a> FlamegraphProfiler<'a> { - pub fn new(frequency: c_int) -> Self { - FlamegraphProfiler { - frequency, - active_profiler: None, + impl<'a> FlamegraphProfiler<'a> { + /// Creates a new `FlamegraphProfiler` instance + /// + /// # Arguments + /// + /// * `frequency` - The sampling frequency for the profiler, in Hz + /// + /// # Returns + /// + /// A new `FlamegraphProfiler` instance + pub fn new(frequency: i32) -> Self { + FlamegraphProfiler { + frequency, + active_profiler: None, + } } } -} -impl<'a> Profiler for FlamegraphProfiler<'a> { - fn start_profiling(&mut self, _benchmark_id: &str, _benchmark_dir: &Path) { - self.active_profiler = Some(ProfilerGuard::new(self.frequency).unwrap()); - } + impl<'a> Profiler for FlamegraphProfiler<'a> { + /// Starts profiling for a benchmark + /// + /// This method is called by Criterion at the start of each benchmark iteration. + /// It creates a new `ProfilerGuard` instance and stores it in `active_profiler`. + /// + /// # Arguments + /// + /// * `_benchmark_id` - The ID of the benchmark (unused in this implementation) + /// * `_benchmark_dir` - The directory for benchmark results (unused in this implementation) + fn start_profiling(&mut self, _benchmark_id: &str, _benchmark_dir: &Path) { + self.active_profiler = Some(ProfilerGuard::new(self.frequency).unwrap()); + } - fn stop_profiling(&mut self, _benchmark_id: &str, benchmark_dir: &Path) { - std::fs::create_dir_all(benchmark_dir).unwrap(); - let flamegraph_path = benchmark_dir.join("flamegraph.svg"); - let flamegraph_file = File::create(&flamegraph_path) - .expect("File system error while creating flamegraph.svg"); - if let Some(profiler) = self.active_profiler.take() { - profiler - .report() - .build() - .unwrap() - .flamegraph(flamegraph_file) - .expect("Error writing flamegraph"); + /// Stops profiling and generates a flamegraph + /// + /// This method is called by Criterion at the end of each benchmark iteration. + /// It generates a flamegraph from the collected profile data and saves it as an SVG file. + /// + /// # Arguments + /// + /// * `_benchmark_id` - The ID of the benchmark (unused in this implementation) + /// * `benchmark_dir` - The directory where the flamegraph should be saved + fn stop_profiling(&mut self, _benchmark_id: &str, benchmark_dir: &Path) { + // Ensure the benchmark directory exists + std::fs::create_dir_all(benchmark_dir).unwrap(); + + // Define the path for the flamegraph SVG file + let flamegraph_path = benchmark_dir.join("flamegraph.svg"); + + // Create the flamegraph file + let flamegraph_file = File::create(&flamegraph_path) + .expect("File system error while creating flamegraph.svg"); + + // Generate and write the flamegraph if a profiler is active + if let Some(profiler) = self.active_profiler.take() { + profiler + .report() + .build() + .unwrap() + .flamegraph(flamegraph_file) + .expect("Error writing flamegraph"); + } } } } @@ -289,13 +339,24 @@ fn bench_server(c: &mut Criterion) { }); } +#[cfg(not(feature = "dev-profiling"))] +criterion_group! { + name = benches; + config = Criterion::default() + .sample_size(10) + .measurement_time(Duration::from_secs(30)) + .warm_up_time(Duration::from_secs(5)); + targets = bench_server +} + +#[cfg(feature = "dev-profiling")] criterion_group! { name = benches; config = Criterion::default() .sample_size(10) - .measurement_time(Duration::from_secs(20)) + .measurement_time(Duration::from_secs(30)) .warm_up_time(Duration::from_secs(5)) - .with_profiler(FlamegraphProfiler::new(100)); + .with_profiler(profiling::FlamegraphProfiler::new(100)); targets = bench_server } From 7487eb6c546c1b0df8d310eeb59bad4595919e88 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 19:32:36 -0400 Subject: [PATCH 38/45] chore: optimized tls config --- benches/hello_world_tower_hyper_tls_tcp.rs | 147 +++++++++++++++------ 1 file changed, 103 insertions(+), 44 deletions(-) diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index b7680c3..a120e70 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -1,3 +1,22 @@ +//! Hello World Benchmark for hyper-server +//! +//! This module implements a comprehensive benchmark for the hyper-server crate, +//! testing its performance in various scenarios including latency, throughput, +//! and concurrent requests. +//! +//! It uses a very basic echo service that responds with "Hello, World!" to GET requests +//! and echoes back the request body for POST requests. The server is configured with +//! an optimized ECDSA certificate and various TLS performance improvements. +//! It exercises the full stack from Socket → TCP → TLS → HTTP/2 → hyper-server → tower-service. +//! This allows developers of the library to optimize the full stack for performance. +//! The library provides a detailed benchmark report with latency, throughput, and +//! concurrency stress tests. +//! It additionally has provision to generate flamegraphs for each benchmark run. +//! +//! For developers who use hyper-server, this provides a good starting point to +//! understand the performance of the library +//! and how to use it optimally in their applications. + use std::io; use std::net::SocketAddr; use std::sync::Arc; @@ -12,6 +31,9 @@ use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; use hyper_util::service::TowerToHyperService; +use rcgen::{CertificateParams, DistinguishedName, KeyPair}; +use rustls::crypto::aws_lc_rs::Ticketer; +use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; use rustls::server::ServerSessionMemoryCache; use rustls::{ClientConfig, RootCertStore, ServerConfig}; use tokio::net::TcpSocket; @@ -21,7 +43,7 @@ use tokio::time::{Duration, Instant}; use tokio_stream::wrappers::TcpListenerStream; use tracing::info; -use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; +use hyper_server::serve_http_with_shutdown; /// Profiling module for generating flamegraphs during benchmarks /// @@ -112,6 +134,75 @@ mod profiling { } } +/// Holds the TLS configuration for both server and client +struct TlsConfig { + server_config: ServerConfig, + client_config: ClientConfig, +} + +/// Generates a shared TLS configuration for both server and client +/// +/// This function creates a self-signed ECDSA certificate and configures both +/// the server and client to use it. It also applies various optimizations +/// to improve TLS performance. +fn generate_shared_ecdsa_config() -> TlsConfig { + // Generate ECDSA key pair + let key_pair = KeyPair::generate().expect("Failed to generate key pair"); + + // Generate certificate parameters + let mut params = CertificateParams::new(vec!["localhost".to_string()]) + .expect("Failed to create certificate params"); + params.distinguished_name = DistinguishedName::new(); + + // Generate the self-signed certificate + let cert = params + .self_signed(&key_pair) + .expect("Failed to generate self-signed certificate"); + + // Serialize the certificate and private key + let cert_der = cert.der().to_vec(); + let key_der = key_pair.serialize_der(); + + // Create Rustls certificate and private key + let cert = CertificateDer::from(cert_der); + let key = PrivatePkcs8KeyDer::from(key_der); + + // Configure Server + let mut server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert.clone()], key.into()) + .expect("Failed to configure server"); + + // Server optimizations + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + server_config.max_fragment_size = Some(16384); + server_config.send_tls13_tickets = 8; // Enable 0.5-RTT data + server_config.session_storage = ServerSessionMemoryCache::new(10240); + server_config.ticketer = Ticketer::new().unwrap(); + server_config.max_early_data_size = 16384; // Enable 0-RTT data + + // Configure Client + let mut root_store = RootCertStore::empty(); + root_store + .add(cert) + .expect("Failed to add certificate to root store"); + + let mut client_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + // Client optimizations + client_config.enable_sni = false; // Since we're using localhost + client_config.max_fragment_size = Some(16384); + client_config.enable_early_data = true; // Enable 0-RTT data + client_config.resumption = rustls::client::Resumption::in_memory_sessions(10240); + + TlsConfig { + server_config, + client_config, + } +} + fn create_optimized_runtime(thread_count: usize) -> io::Result { tokio::runtime::Builder::new_multi_thread() .worker_threads(thread_count) @@ -135,10 +226,8 @@ async fn echo(req: Request) -> Result>, hyper::Er } } -async fn setup_server() -> Result< - (TcpListenerStream, SocketAddr, Arc), - Box, -> { +async fn setup_server( +) -> Result<(TcpListenerStream, SocketAddr), Box> { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let socket = TcpSocket::new_v4()?; @@ -156,34 +245,14 @@ async fn setup_server() -> Result< let server_addr = listener.local_addr()?; let incoming = TcpListenerStream::new(listener); - // Load certificates and private key - let certs = load_certs("examples/sample.pem")?; - let key = load_private_key("examples/sample.rsa")?; - - // TLS configuration - let mut config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certs, key) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - - // ALPN configuration - config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - // Performance optimizations - config.max_fragment_size = Some(16384); // Larger fragment size for powerful servers - config.send_half_rtt_data = true; // Enable 0.5-RTT data - config.session_storage = ServerSessionMemoryCache::new(10240); // Larger session cache - config.cert_compression_cache = Arc::new(rustls::compress::CompressionCache::default()); - config.max_early_data_size = 16384; // Enable 0-RTT data - - let tls_config = Arc::new(config); - - Ok((incoming, server_addr, tls_config)) + Ok((incoming, server_addr)) } async fn start_server( + tls_config: ServerConfig, ) -> Result<(SocketAddr, oneshot::Sender<()>), Box> { - let (incoming, server_addr, tls_config) = setup_server().await?; + let tls_config = Arc::new(tls_config); + let (incoming, server_addr) = setup_server().await?; let (shutdown_tx, shutdown_rx) = oneshot::channel(); let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); @@ -256,24 +325,14 @@ fn bench_server(c: &mut Criterion) { let server_runtime = Arc::new(create_optimized_runtime(num_cpus::get() / 2).unwrap()); let (server_addr, shutdown_tx, client) = server_runtime.block_on(async { - let (server_addr, shutdown_tx) = start_server().await.expect("Failed to start server"); + let tls_config = generate_shared_ecdsa_config(); + let (server_addr, shutdown_tx) = start_server(tls_config.server_config.clone()) + .await + .expect("Failed to start server"); info!("Server started on {}", server_addr); - let mut root_cert_store = RootCertStore::empty(); - root_cert_store.add_parsable_certificates(load_certs("examples/sample.pem").unwrap()); - - let mut client_config = ClientConfig::builder() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); - // Enable handshake resumption - client_config.resumption = rustls::client::Resumption::in_memory_sessions(10240); - client_config.cert_compression_cache = - Arc::new(rustls::compress::CompressionCache::default()); - client_config.max_fragment_size = Some(16384); // Larger fragment size for powerful servers - client_config.enable_early_data = true; // Enable 0-RTT data - let https = HttpsConnectorBuilder::new() - .with_tls_config(client_config) + .with_tls_config(tls_config.client_config) .https_or_http() .enable_http1() .build(); From e3b3425858321198fbfe17224a812736a72233a4 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Wed, 11 Sep 2024 23:10:23 -0400 Subject: [PATCH 39/45] fix: rename and revectorize benchmarks --- benches/hello_world_tower_hyper_tls_tcp.rs | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index a120e70..dfcdc9a 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -268,8 +268,8 @@ async fn start_server( shutdown_rx.await.ok(); }), ) - .await - .unwrap(); + .await + .unwrap(); }); Ok((server_addr, shutdown_tx)) } @@ -353,18 +353,9 @@ fn bench_server(c: &mut Criterion) { group.sample_size(20); group.measurement_time(Duration::from_secs(30)); - // Latency test - group.bench_function("latency", |b| { - let client = client.clone(); - let url = url.clone(); - let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); - b.to_async(client_runtime) - .iter(|| async { send_request(&client, url.clone()).await.unwrap().0 }); - }); - - // Throughput test + // Latency group.throughput(Throughput::Elements(1)); - group.bench_function("throughput", |b| { + group.bench_function("serial_latency", |b| { let client = client.clone(); let url = url.clone(); let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); @@ -373,11 +364,11 @@ fn bench_server(c: &mut Criterion) { }); // Concurrency stress test - let concurrent_requests = vec![1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987]; + let concurrent_requests = vec![10, 50, 250, 1250]; for &num_requests in &concurrent_requests { group.throughput(Throughput::Elements(num_requests as u64)); group.bench_with_input( - BenchmarkId::new("concurrent_requests", num_requests), + BenchmarkId::new("concurrent_latency", num_requests), &num_requests, |b, &num_requests| { let client = client.clone(); From ff7b4397fa1a0488dd5534fc46c27d21248bd5d7 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Thu, 12 Sep 2024 01:05:56 -0400 Subject: [PATCH 40/45] fix: optimize server for generality --- benches/hello_world_tower_hyper_tls_tcp.rs | 4 +- src/http.rs | 284 +++++++++++++-------- 2 files changed, 174 insertions(+), 114 deletions(-) diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index dfcdc9a..4312419 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -268,8 +268,8 @@ async fn start_server( shutdown_rx.await.ok(); }), ) - .await - .unwrap(); + .await + .unwrap(); }); Ok((server_addr, shutdown_tx)) } diff --git a/src/http.rs b/src/http.rs index 6957369..d4d66ac 100644 --- a/src/http.rs +++ b/src/http.rs @@ -36,9 +36,9 @@ async fn sleep_or_pending(wait_for: Option) { }; } -/// Serves a single HTTP connection from a hyper service backend. +/// Serves HTTP an HTTP connection on the transport from a hyper service backend. /// -/// This method handles an individual HTTP connection, processing requests through +/// This method handles an HTTP connection on a given transport `IO`, processing requests through /// the provided service and managing the connection lifecycle. /// /// # Type Parameters @@ -61,93 +61,88 @@ pub async fn serve_http_connection( hyper_io: IO, hyper_service: S, builder: HttpConnectionBuilder, - mut watcher: Option>, + watcher: Option>, max_connection_age: Option, ) where B: Body + Send + 'static, B::Data: Send, B::Error: Into> + Send + Sync, IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response=Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into> + Send, E: HttpServerConnExec + Send + Sync + 'static, { - // Spawn a new asynchronous task to handle the incoming hyper IO stream - tokio::spawn(async move { - { - // Set up a fused future for the watcher - let mut sig = pin!(crate::fuse::Fuse { - inner: watcher.as_mut().map(|w| w.changed()), - }); + // Set up a fused future for the watcher + let mut watcher = watcher.clone(); + let mut sig = pin!(Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); - let builder = builder.clone(); - // TODO(How to accept a preconfigured builder) - // The API here for hyper_util is poor. - // Really what you want to do is configure a builder like this - // and pass it in for use as a builder, however, you cannot - // the simple way may be to require configuration and - // then accept an immutable reference to an http2 connection builder - let mut builder = builder.clone(); - builder - // HTTP/1 settings - .http1() - .half_close(true) - .keep_alive(true) - .max_buf_size(64 * 1024) - .pipeline_flush(true) - .preserve_header_case(true) - .title_case_headers(false) - // HTTP/2 settings - .http2() - .initial_stream_window_size(Some(1024 * 1024)) - .initial_connection_window_size(Some(2 * 1024 * 1024)) - .adaptive_window(true) - .max_frame_size(Some(16 * 1024)) - .max_concurrent_streams(Some(1000)) - .max_send_buf_size(1024 * 1024) - .enable_connect_protocol() - .max_header_list_size(16 * 1024) - .keep_alive_interval(Some(Duration::from_secs(20))) - .keep_alive_timeout(Duration::from_secs(20)); - - // Create and pin the HTTP connection - let mut conn = pin!(builder.serve_connection_with_upgrades(hyper_io, hyper_service)); - - // Set up the sleep future for max connection age - let sleep = sleep_or_pending(max_connection_age); - tokio::pin!(sleep); - - // Main loop for serving the HTTP connection - loop { - tokio::select! { - // Handle the connection result - rv = &mut conn => { - if let Err(err) = rv { - // Log any errors that occur while serving the HTTP connection - debug!("failed serving HTTP connection: {:#}", err); - } - break; - }, - // Handle max connection age timeout - _ = &mut sleep => { - // Initiate a graceful shutdown when max connection age is reached - conn.as_mut().graceful_shutdown(); - sleep.set(sleep_or_pending(None)); - }, - // Handle graceful shutdown signal - _ = &mut sig => { - // Initiate a graceful shutdown when signal is received - conn.as_mut().graceful_shutdown(); - } + // Set up the sleep future for max connection age + let sleep = sleep_or_pending(max_connection_age); + tokio::pin!(sleep); + + // TODO(This builder should be pre-configured outside of the server) + // unfortunately this object is very poorly designed and there is + // no way exposed to pre-configure it. + // + // There must be some way to approach here. + let builder = builder.clone(); + // Configure the builder + let mut builder = builder.clone(); + builder + // HTTP/1 settings + .http1() + .half_close(true) + .keep_alive(true) + .max_buf_size(64 * 1024) + .pipeline_flush(true) + .preserve_header_case(true) + .title_case_headers(false) + // HTTP/2 settings + .http2() + .initial_stream_window_size(Some(1024 * 1024)) + .initial_connection_window_size(Some(2 * 1024 * 1024)) + .adaptive_window(true) + .max_frame_size(Some(16 * 1024)) + .max_concurrent_streams(Some(1000)) + .max_send_buf_size(1024 * 1024) + .enable_connect_protocol() + .max_header_list_size(16 * 1024) + .keep_alive_interval(Some(Duration::from_secs(20))) + .keep_alive_timeout(Duration::from_secs(20)); + + // Create and pin the HTTP connection + // This handles all the HTTP connection logic via hyper + let mut conn = pin!(builder.serve_connection_with_upgrades(hyper_io, hyper_service)); + + // Here we wait for the http connection to terminate + loop { + tokio::select! { + // Handle the connection result + rv = &mut conn => { + if let Err(err) = rv { + // Log any errors that occur while serving the HTTP connection + debug!("failed serving HTTP connection: {:#}", err); } + break; + }, + // Handle max connection age timeout + _ = &mut sleep => { + // Initiate a graceful shutdown when max connection age is reached + conn.as_mut().graceful_shutdown(); + sleep.set(sleep_or_pending(None)); + }, + // Handle graceful shutdown signal + _ = &mut sig => { + // Initiate a graceful shutdown when signal is received + conn.as_mut().graceful_shutdown(); } } + } - // Clean up and log connection closure - drop(watcher); - trace!("HTTP connection closed"); - }); + trace!("HTTP connection closed"); } /// Serves HTTP/HTTPS requests with graceful shutdown capability. @@ -373,85 +368,149 @@ pub async fn serve_http_with_shutdown( signal: Option, ) -> Result<(), super::Error> where - F: Future + Send + 'static, - I: Stream> + Send + 'static, + F: Future + Send + 'static, + I: Stream> + Send + 'static, IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, IE: Into + Send + 'static, - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response=Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into> + Send, - ResBody: Body + Send + Sync + 'static, + ResBody: Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync, E: HttpServerConnExec + Send + Sync + 'static, { - // Prepare the incoming stream of TCP connections - let incoming = crate::tcp::serve_tcp_incoming(incoming); - - // Create a channel for signaling graceful shutdown + // Create a channel for signaling graceful shutdown to listening connections let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); let signal_tx = Arc::new(signal_tx); + // We say that graceful shutdown is enabled if a signal is provided let graceful = signal.is_some(); + + // The signal future that will resolve when the server should shut down let mut sig = pin!(Fuse { inner: signal }); + + // Prepare the incoming stream of TCP connections + // from the provided stream of IO objects, which is coming + // most likely from a TCP stream. + let incoming = crate::tcp::serve_tcp_incoming(incoming); + + // Pin the incoming stream to the stack let mut incoming = pin!(incoming); // Create TLS acceptor if TLS config is provided let tls_acceptor = tls_config.map(TlsAcceptor::from); - // Main server loop + // Enter the main server loop loop { + // Select between the future which returns first, + // A shutdown signal or an incoming IO result. tokio::select! { - // Handle shutdown signal + // Check if we received a graceful shutdown signal for the server _ = &mut sig => { + // Exit the loop if we did, and shut down the server trace!("signal received, shutting down"); break; }, - // Handle incoming connections + // Wait for the next IO result from the incoming stream io = incoming.next() => { + // If we got an IO result from the incoming stream + // This effectively demultiplexes the incoming stream of IO objects, + // which each represent a connection which may then be individually + // streamed/handled. + // + // So this is effectively a demultiplexer for the incoming stream of IO objects. + // + // Because of the way the stream handling is implemented, + // the responses are multiplexed back over the same stream to the client. + // However, that would not be intuitive just from looking it this code + // because the reverse multiplexing is "invisible" to the reader. let io = match io { + // We check if it's a valid stream Some(Ok(io)) => io, + // or if it's a non-fatal error Some(Err(e)) => { trace!("error accepting connection: {:#}", e); + // if it's a non-fatal error, we continue processing IO objects continue; }, None => { + // If we got a fatal error, meaning we lost connection or something else + // we break out of the loop break }, }; - trace!("connection accepted"); - - // Prepare the connection for hyper - let transport = if let Some(tls_acceptor) = &tls_acceptor { - match tls_acceptor.accept(io).await { - Ok(tls_stream) => Transport::new_tls(tls_stream), - Err(e) => { - debug!("TLS handshake failed: {:#}", e); - continue; - } - } - } else { - Transport::new_plain(io) - }; - - let hyper_io = TokioIo::new(transport); - let hyper_svc = service.clone(); - - // Serve the HTTP connection - serve_http_connection( - hyper_io, - hyper_svc, - builder.clone(), - graceful.then(|| signal_rx.clone()), - None - ).await; + trace!("TCP streaming connection accepted"); + + // For each of these TCP streams, we are going to want to + // spawn a new task to handle the connection. + + // Clone necessary values for the spawned task + let service = service.clone(); + let builder = builder.clone(); + let tls_acceptor = tls_acceptor.clone(); + let signal_rx = signal_rx.clone(); + + // Spawn a new task to handle this connection + tokio::spawn(async move { + // Abstract the transport layer for hyper + + let transport = if let Some(tls_acceptor) = &tls_acceptor { + // If TLS is enabled, then we perform a TLS handshake + // Clone the TLS acceptor and IO for use in the blocking task + let tls_acceptor = tls_acceptor.clone(); + let io = io; + + match tokio::task::spawn_blocking(move || { + // Perform the TLS handshake in a blocking task + // Because this is one of the most computationally heavy things the sever does. + // In the case of ECDSA and very fast handshakes, this has more downside + // than upside, but in the case of RSA and slow handshakes, this is a good idea. + // It amortizes out to about 2 µs of overhead per connection. + // and moves this computationally heavy task off the main thread pool. + tokio::runtime::Handle::current().block_on(tls_acceptor.accept(io)) + }).await { + // Handle the result of the TLS handshake + Ok(Ok(tls_stream)) => Transport::new_tls(tls_stream), + Ok(Err(e)) => { + // This connection failed to handshake + debug!("TLS handshake failed: {:#}", e); + return; + }, + Err(e) => { + // This connection was malformed and the server was unable to handle it + debug!("TLS handshake task panicked: {:#}", e); + return; + } + + } + } + else { + // If TLS is not enabled, then we use a plain transport + Transport::new_plain(io) + }; + + // Convert our abstracted tokio transport into a hyper transport + let hyper_io = TokioIo::new(transport); + + // Serve the HTTP connections on this transport + serve_http_connection( + hyper_io, + service, + builder, + graceful.then(|| signal_rx), + None + ).await; + }); } } } // Handle graceful shutdown if graceful { + // Broadcast the shutdown signal to all connections let _ = signal_tx.send(()); + // Drop the sender to signal that no more connections will be accepted drop(signal_rx); trace!( "waiting for {} connections to close", @@ -459,6 +518,7 @@ where ); // Wait for all connections to close + // TODO(Add a timeout here, optionally) signal_tx.closed().await; } @@ -688,7 +748,7 @@ mod tests { } } }) - .await; + .await; match shutdown_result { Ok(Ok(())) => println!("Server shut down successfully"), From 968a496a827887204096c4663b0b00a5219f31f9 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Thu, 12 Sep 2024 03:39:01 -0400 Subject: [PATCH 41/45] feat: optimize server and document --- Cargo.toml | 8 +- benches/hello_world_tower_hyper_tls_tcp.rs | 71 +-- flamegraph.svg | 491 --------------------- log.txt | 0 src/http.rs | 63 ++- 5 files changed, 89 insertions(+), 544 deletions(-) delete mode 100644 flamegraph.svg create mode 100644 log.txt diff --git a/Cargo.toml b/Cargo.toml index a0ad89c..2677fd0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,6 @@ http = "1.1.0" http-body = "1.0.1" http-body-util = "0.1.2" hyper = "1.4.1" -hyper-rustls = { version = "0.27.3", features = ["http2"] } hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } pin-project = "1.1.5" rand = "0.9.0-alpha.2" @@ -40,13 +39,16 @@ pprof = { version = "0.13.0", features = ["flamegraph"], optional = true} [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } +hyper-rustls = { version = "0.27.3", features = ["http1", "http2"] } +hyper-util = { version = "0.1.8", features = ["client", "client-legacy", "http2"] } hyper = { version = "1.4.1", features = ["client"] } -tokio = { version = "1.40", features = ["rt", "net", "test-util"] } +tokio = { version = "1.40", features = ["rt-multi-thread", "net", "test-util", "time"] } tokio-util = { version = "0.7.12", features = ["compat"] } tracing-subscriber = "0.3.18" num_cpus = "1.16.0" ring = "0.17.8" rcgen = "0.13.1" +reqwest = { version = "0.12.7", features = ["rustls-tls", "http2"] } [[bench]] name = "hello_world_tower_hyper_tls_tcp" @@ -58,4 +60,4 @@ jemalloc = ["tikv-jemallocator"] dev-profiling = ["pprof"] [profile.release] -debug = true \ No newline at end of file +debug = true diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index 4312419..679bbe5 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -24,15 +24,14 @@ use std::sync::Arc; use bytes::Bytes; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use http::{Request, Response, StatusCode, Uri}; -use http_body_util::{BodyExt, Empty, Full}; +use http_body_util::{BodyExt, Full}; use hyper::body::Incoming; -use hyper_rustls::HttpsConnectorBuilder; -use hyper_util::client::legacy::Client; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; use hyper_util::service::TowerToHyperService; use rcgen::{CertificateParams, DistinguishedName, KeyPair}; -use rustls::crypto::aws_lc_rs::Ticketer; +use reqwest::Client; +use rustls::crypto::aws_lc_rs::{default_provider, Ticketer}; use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; use rustls::server::ServerSessionMemoryCache; use rustls::{ClientConfig, RootCertStore, ServerConfig}; @@ -275,25 +274,19 @@ async fn start_server( } async fn send_request( - client: &Client< - hyper_rustls::HttpsConnector, - Empty, - >, + client: &Client, url: Uri, ) -> Result<(Duration, usize), Box> { let start = Instant::now(); - let res = client.get(url).await?; + let res = client.get(url.to_string()).send().await?; assert_eq!(res.status(), StatusCode::OK); - let body = res.into_body().collect().await?.to_bytes(); + let body = res.bytes().await?; assert_eq!(&body[..], b"Hello, World!"); Ok((start.elapsed(), body.len())) } async fn concurrent_benchmark( - client: &Client< - hyper_rustls::HttpsConnector, - Empty, - >, + client: &Client, url: Uri, num_requests: usize, ) -> (Duration, Vec, usize) { @@ -322,22 +315,46 @@ async fn concurrent_benchmark( } fn bench_server(c: &mut Criterion) { - let server_runtime = Arc::new(create_optimized_runtime(num_cpus::get() / 2).unwrap()); + let bench_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); + + let (server_addr, shutdown_tx, client) = bench_runtime.block_on(async { + // Install the default provider for AWS LC RS crypto + default_provider().install_default().unwrap(); - let (server_addr, shutdown_tx, client) = server_runtime.block_on(async { let tls_config = generate_shared_ecdsa_config(); let (server_addr, shutdown_tx) = start_server(tls_config.server_config.clone()) .await .expect("Failed to start server"); info!("Server started on {}", server_addr); - let https = HttpsConnectorBuilder::new() - .with_tls_config(tls_config.client_config) - .https_or_http() - .enable_http1() - .build(); - - let client: Client<_, Empty> = Client::builder(TokioExecutor::new()).build(https); + // Unfortunately hyper based http2 seems pretty busted here + // around not finding the tokio runtime timer + // https://github.com/rustls/hyper-rustls/issues/287 + // let https = HttpsConnectorBuilder::new() + // .with_tls_config(tls_config.client_config) + // .https_or_http() + // .enable_all_versions() + // .build(); + // + // let client: Client<_, Empty> = Client::builder(TokioExecutor::new()) + // .timer(TokioTimer::new()) + // .pool_timer(TokioTimer::new()) + // .build(https); + + let client = reqwest::Client::builder() + .use_rustls_tls() + // This breaks for the same reason that the hyper-tls/hyper client does + //.http2_prior_knowledge() + // Increase connection pool size for better concurrency + .pool_max_idle_per_host(1250) + // Enable TCP keepalive + .tcp_keepalive(Some(Duration::from_secs(10))) + // Disable automatic redirect following to reduce overhead + .redirect(reqwest::redirect::Policy::none()) + // Use preconfigured TLS settings from the shared config + .use_preconfigured_tls(tls_config.client_config) + .build() + .expect("Failed to build reqwest client"); (server_addr, shutdown_tx, client) }); @@ -358,8 +375,7 @@ fn bench_server(c: &mut Criterion) { group.bench_function("serial_latency", |b| { let client = client.clone(); let url = url.clone(); - let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); - b.to_async(client_runtime) + b.to_async(&bench_runtime) .iter(|| async { send_request(&client, url.clone()).await.unwrap() }); }); @@ -373,8 +389,7 @@ fn bench_server(c: &mut Criterion) { |b, &num_requests| { let client = client.clone(); let url = url.clone(); - let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); - b.to_async(client_runtime).iter(|| async { + b.to_async(&bench_runtime).iter(|| async { concurrent_benchmark(&client, url.clone(), num_requests).await }); }, @@ -383,7 +398,7 @@ fn bench_server(c: &mut Criterion) { group.finish(); - server_runtime.block_on(async { + bench_runtime.block_on(async { shutdown_tx.send(()).unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; }); diff --git a/flamegraph.svg b/flamegraph.svg deleted file mode 100644 index 8aff83e..0000000 --- a/flamegraph.svg +++ /dev/null @@ -1,491 +0,0 @@ -Flame Graph Reset ZoomSearch bytes::bytes::shared_clone (1 samples, 0.30%)core::ptr::drop_in_place<alloc::boxed::Box<tokio::runtime::task::core::Cell<hello_world_tower_hyper_tls_tcp::concurrent_benchmark::{{closure}}::{{closure}},alloc::sync::Arc<tokio::runtime::scheduler::multi_thread::handle::Handle>>>> (1 samples, 0.30%)core::ptr::drop_in_place<hyper_util::client::legacy::client::Error> (3 samples, 0.90%)core::ptr::drop_in_place<hyper_util::client::legacy::connect::http::ConnectError> (2 samples, 0.60%)_rjem_je_sdallocx_default (1 samples, 0.30%)_rjem_je_tcache_bin_flush_small (1 samples, 0.30%)hello_world_tow (6 samples, 1.79%)h.._start (6 samples, 1.79%)_..__libc_start_main (6 samples, 1.79%)_..main (6 samples, 1.79%)m..std::rt::lang_start_internal (6 samples, 1.79%)s..std::rt::lang_start::{{closure}} (6 samples, 1.79%)s..std::sys::backtrace::__rust_begin_short_backtrace (6 samples, 1.79%)s..hello_world_tower_hyper_tls_tcp::main (6 samples, 1.79%)h..criterion::benchmark_group::BenchmarkGroup<M>::bench_with_input (6 samples, 1.79%)c..<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (6 samples, 1.79%)<..criterion::bencher::AsyncBencher<A,M>::iter (6 samples, 1.79%)c..<tokio::runtime::runtime::Runtime as criterion::async_executor::AsyncExecutor>::block_on (6 samples, 1.79%)<..tokio::task::spawn::spawn (1 samples, 0.30%)_rjem_je_tsd_cleanup (1 samples, 0.30%)tcache_destroy.constprop.0 (1 samples, 0.30%)_rjem_je_tcache_bin_flush_small (1 samples, 0.30%)_rjem_je_arena_slab_dalloc (1 samples, 0.30%)pac_dalloc_impl (1 samples, 0.30%)_rjem_je_extent_record (1 samples, 0.30%)extent_try_coalesce_impl (1 samples, 0.30%)extent_merge_impl.constprop.0 (1 samples, 0.30%)_rjem_je_edata_cache_put (1 samples, 0.30%)std::sys::sync::mutex::futex::Mutex::lock_contended (8 samples, 2.39%)st..syscall (6 samples, 1.79%)s..std::sys::sync::condvar::futex::Condvar::wait_timeout (13 samples, 3.88%)std:..syscall (5 samples, 1.49%)std::sys::sync::mutex::futex::Mutex::lock_contended (2 samples, 0.60%)syscall (1 samples, 0.30%)syscall (2 samples, 0.60%)__file_change_detection_for_path (2 samples, 0.60%)fstatat64 (2 samples, 0.60%)__lll_lock_wake_private (2 samples, 0.60%)__lll_lock_wait_private (1 samples, 0.30%)__resolv_context_get (7 samples, 2.09%)_..__lll_lock_wake_private (3 samples, 0.90%)__resolv_context_put (3 samples, 0.90%)__clock_gettime (1 samples, 0.30%)__res_context_mkquery (2 samples, 0.60%)__ns_name_compress (1 samples, 0.30%)__ns_name_pack (1 samples, 0.30%)__poll (1 samples, 0.30%)__socket (3 samples, 0.90%)<std::sys_common::net::LookupHost as core::convert::TryFrom<(&str,u16)>>::try_from::{{closure}} (23 samples, 6.87%)<std::sys..getaddrinfo (23 samples, 6.87%)getaddrin.._nss_dns_gethostbyname4_r (9 samples, 2.69%)_n..__res_context_search (9 samples, 2.69%)__..__res_context_query (8 samples, 2.39%)__..__res_context_send (5 samples, 1.49%)ioctl (1 samples, 0.30%)<(&str,u16) as std::net::socket_addr::ToSocketAddrs>::to_socket_addrs (25 samples, 7.46%)<(&str,u16..core::net::parser::Parser::read_ipv4_addr (2 samples, 0.60%)tokio::runtime::scheduler::multi_thread::worker::<impl tokio::runtime::scheduler::multi_thread::handle::Handle>::next_remote_task (2 samples, 0.60%)std::sys::sync::mutex::futex::Mutex::lock_contended (1 samples, 0.30%)syscall (1 samples, 0.30%)tokio::runtime::time::<impl tokio::runtime::time::handle::Handle>::process_at_time (3 samples, 0.90%)tokio::runtime::time::<impl tokio::runtime::time::handle::Handle>::process_at_sharded_time (3 samples, 0.90%)tokio::runtime::io::driver::Driver::turn (3 samples, 0.90%)epoll_wait (2 samples, 0.60%)tokio::runtime::scheduler::multi_thread::worker::Context::park_timeout (8 samples, 2.39%)to..tokio::runtime::time::Driver::park_internal (4 samples, 1.19%)tokio::runtime::time::wheel::Wheel::next_expiration (1 samples, 0.30%)<rustls::client::hs::ExpectServerHelloOrHelloRetryRequest as rustls::common_state::State<rustls::client::client_conn::ClientConnectionData>>::handle (1 samples, 0.30%)<rustls::client::hs::ExpectServerHello as rustls::common_state::State<rustls::client::client_conn::ClientConnectionData>>::handle (1 samples, 0.30%)rustls::hash_hs::HandshakeHashBuffer::start_hash (1 samples, 0.30%)<futures_util::future::either::Either<A,B> as core::future::future::Future>::poll (3 samples, 0.90%)<hyper_rustls::connector::HttpsConnector<T> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (2 samples, 0.60%)tokio_rustls::common::Stream<IO,C>::read_io (2 samples, 0.60%)<rustls::client::tls13::ExpectCertificateOrCompressedCertificateOrCertReq as rustls::common_state::State<rustls::client::client_conn::ClientConnectionData>>::handle (1 samples, 0.30%)<rustls::client::tls13::ExpectCompressedCertificate as rustls::common_state::State<rustls::client::client_conn::ClientConnectionData>>::handle (1 samples, 0.30%)<rustls::compress::feat_zlib_rs::ZlibRsDecompressor as rustls::compress::CertDecompressor>::decompress (1 samples, 0.30%)zlib_rs::inflate::State::len (1 samples, 0.30%)zlib_rs::inflate::State::check (1 samples, 0.30%)zlib_rs::adler32::adler32 (1 samples, 0.30%)<rustls::conn::ConnectionCommon<T> as rustls::conn::connection::PlaintextSink>::write (1 samples, 0.30%)rustls::common_state::CommonState::send_appdata_encrypt (1 samples, 0.30%)rustls::common_state::CommonState::send_single_fragment (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::tls13::GcmMessageEncrypter as rustls::crypto::cipher::MessageEncrypter>::encrypt (1 samples, 0.30%)aws_lc_rs::aead::unbound_key::UnboundKey::seal_in_place_append_tag (1 samples, 0.30%)aws_lc_0_21_1_EVP_AEAD_CTX_seal (1 samples, 0.30%)aead_aes_gcm_tls13_seal_scatter (1 samples, 0.30%)aead_aes_gcm_seal_scatter_impl (1 samples, 0.30%)aws_lc_0_21_1_CRYPTO_gcm128_aad (1 samples, 0.30%)hyper::proto::h1::conn::Conn<I,B,T>::poll_flush (2 samples, 0.60%)<tokio_rustls::client::TlsStream<IO> as tokio::io::async_write::AsyncWrite>::poll_write (2 samples, 0.60%)rustls::vecbuf::ChunkVecBuffer::write_to (1 samples, 0.30%)rustls::vecbuf::ChunkVecBuffer::consume (1 samples, 0.30%)<futures_util::future::future::Map<Fut,F> as core::future::future::Future>::poll (7 samples, 2.09%)<..<hyper::client::conn::http1::upgrades::UpgradeableConnection<I,B> as core::future::future::Future>::poll (4 samples, 1.19%)hyper::proto::h1::io::Buffered<T,B>::poll_read_from_io (1 samples, 0.30%)<hyper_util::rt::tokio::TokioIo<T> as hyper::rt::io::Read>::poll_read (1 samples, 0.30%)<http_body_util::combinators::collect::Collect<T> as core::future::future::Future>::poll (1 samples, 0.30%)<hyper::body::incoming::Incoming as http_body::Body>::poll_frame (1 samples, 0.30%)<hyper::proto::h1::dispatch::Server<S,hyper::body::incoming::Incoming> as hyper::proto::h1::dispatch::Dispatch>::recv_msg (1 samples, 0.30%)<hyper_util::server::conn::auto::UpgradeableConnection<I,S,E> as core::future::future::Future>::poll (2 samples, 0.60%)<hyper::server::conn::http1::UpgradeableConnection<I,S> as core::future::future::Future>::poll (2 samples, 0.60%)hyper::proto::h1::conn::Conn<I,B,T>::poll_read_head (1 samples, 0.30%)http::header::map::HeaderMap<T>::try_append2 (1 samples, 0.30%)_rjem_malloc (1 samples, 0.30%)_rjem_malloc (1 samples, 0.30%)accept4 (4 samples, 1.19%)<core::pin::Pin<P> as futures_core::stream::Stream>::poll_next (6 samples, 1.79%)<..tokio::io::poll_evented::PollEvented<E>::new_with_interest (1 samples, 0.30%)tokio::runtime::io::registration_set::RegistrationSet::allocate (1 samples, 0.30%)<tokio::io::poll_evented::PollEvented<E> as core::ops::drop::Drop>::drop (1 samples, 0.30%)epoll_ctl (1 samples, 0.30%)__close (6 samples, 1.79%)_..<rustls::crypto::aws_lc_rs::hash::Hash as rustls::crypto::hash::Hash>::start (1 samples, 0.30%)core::ops::function::impls::<impl core::ops::function::FnMut<A> for &mut F>::call_mut (1 samples, 0.30%)rustls::common_state::CommonState::send_msg (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::hash::Context as rustls::crypto::hash::Context>::fork_finish (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::hash::Context as rustls::crypto::hash::Context>::update (1 samples, 0.30%)aws_lc_0_21_1_EVP_DigestUpdate (1 samples, 0.30%)sha384_update (1 samples, 0.30%)aws_lc_0_21_1_SHA512_Update (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)aws_lc_0_21_1_BN_uadd (1 samples, 0.30%)aws_lc_0_21_1_bn_uadd_consttime (1 samples, 0.30%)aws_lc_0_21_1_bn_add_words (1 samples, 0.30%)aws_lc_0_21_1_bn_is_bit_set_words (1 samples, 0.30%)aws_lc_0_21_1_BN_mod_inverse_blinded (3 samples, 0.90%)aws_lc_0_21_1_BN_mod_inverse_odd (3 samples, 0.90%)bn_cmp_words_consttime (1 samples, 0.30%)bn_mulx4x_mont (1 samples, 0.30%)aws_lc_0_21_1_BN_BLINDING_convert (6 samples, 1.79%)a..aws_lc_0_21_1_BN_mod_mul_montgomery (3 samples, 0.90%)bn_sqr8x_mont (2 samples, 0.60%)aws_lc_0_21_1_bn_sqrx8x_internal (2 samples, 0.60%)aws_lc_0_21_1_BN_from_montgomery (1 samples, 0.30%)bn_from_montgomery_in_place (1 samples, 0.30%)aws_lc_0_21_1_BN_mod_mul_montgomery (2 samples, 0.60%)bn_sqr8x_mont (2 samples, 0.60%)aws_lc_0_21_1_bn_sqrx8x_internal (2 samples, 0.60%)aws_lc_0_21_1_BN_mod_exp_mont (3 samples, 0.90%)aws_lc_0_21_1_BN_num_bits (1 samples, 0.30%)aws_lc_0_21_1_bn_minimal_width (1 samples, 0.30%)aws_lc_0_21_1_bn_gather5 (1 samples, 0.30%)bn_mulx4x_mont_gather5 (3 samples, 0.90%)mulx4x_internal (3 samples, 0.90%)__bn_postx4x_internal (3 samples, 0.90%)aws_lc_0_21_1_bn_sqrx8x_internal (109 samples, 32.54%)aws_lc_0_21_1_bn_sqrx8x_internalbn_powerx5 (152 samples, 45.37%)bn_powerx5mulx4x_internal (39 samples, 11.64%)mulx4x_internalrustls::server::tls13::client_hello::emit_certificate_verify_tls13 (171 samples, 51.04%)rustls::server::tls13::client_hello::emit_certificate_verify_tls13<rustls::crypto::aws_lc_rs::sign::RsaSigner as rustls::crypto::signer::Signer>::sign (169 samples, 50.45%)<rustls::crypto::aws_lc_rs::sign::RsaSigner as rustls::crypto::signer::Signer>::signaws_lc_0_21_1_EVP_DigestSignFinal (169 samples, 50.45%)aws_lc_0_21_1_EVP_DigestSignFinalpkey_rsa_sign (169 samples, 50.45%)pkey_rsa_signaws_lc_0_21_1_RSA_sign_pss_mgf1 (169 samples, 50.45%)aws_lc_0_21_1_RSA_sign_pss_mgf1aws_lc_0_21_1_rsa_default_sign_raw (169 samples, 50.45%)aws_lc_0_21_1_rsa_default_sign_rawaws_lc_0_21_1_rsa_default_private_transform (169 samples, 50.45%)aws_lc_0_21_1_rsa_default_private_transformaws_lc_0_21_1_BN_mod_exp_mont_consttime (159 samples, 47.46%)aws_lc_0_21_1_BN_mod_exp_mont_consttimebn_sqr8x_mont (3 samples, 0.90%)aws_lc_0_21_1_bn_sqrx8x_internal (3 samples, 0.90%)rustls::server::tls13::client_hello::emit_compressed_certificate_tls13 (3 samples, 0.90%)<rustls::crypto::aws_lc_rs::hash::Context as rustls::crypto::hash::Context>::update (3 samples, 0.90%)aws_lc_0_21_1_EVP_DigestUpdate (3 samples, 0.90%)sha384_update (3 samples, 0.90%)aws_lc_0_21_1_SHA512_Update (3 samples, 0.90%)aws_lc_0_21_1_sha512_block_data_order_nohw (3 samples, 0.90%)rustls::tls13::key_schedule::KeySchedule::set_encrypter (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::tls13::RingHkdfExpander as rustls::crypto::tls13::HkdfExpander>::expand_slice (1 samples, 0.30%)aws_lc_rs::hkdf::Okm<L>::fill (1 samples, 0.30%)aws_lc_0_21_1_HKDF_expand (1 samples, 0.30%)aws_lc_0_21_1_HMAC_Final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)aws_lc_0_21_1_HKDF_expand (1 samples, 0.30%)aws_lc_0_21_1_HMAC_Final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)rustls::server::tls13::client_hello::emit_finished_tls13 (4 samples, 1.19%)rustls::tls13::key_schedule::KeyScheduleTraffic::new (3 samples, 0.90%)rustls::tls13::key_schedule::KeySchedule::derive_logged_secret (3 samples, 0.90%)<rustls::crypto::aws_lc_rs::tls13::RingHkdfExpander as rustls::crypto::tls13::HkdfExpander>::expand_block (3 samples, 0.90%)aws_lc_rs::hkdf::Okm<L>::fill (3 samples, 0.90%)aws_lc_0_21_1_HKDF (3 samples, 0.90%)aws_lc_0_21_1_HKDF_extract (2 samples, 0.60%)aws_lc_0_21_1_HMAC (2 samples, 0.60%)aws_lc_0_21_1_HMAC_Final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::hash::Context as rustls::crypto::hash::Context>::finish (1 samples, 0.30%)aws_lc_0_21_1_EVP_DigestFinal (1 samples, 0.30%)aws_lc_0_21_1_EVP_DigestFinal_ex (1 samples, 0.30%)sha384_final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)rustls::crypto::SupportedKxGroup::start_and_complete (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::kx::KxGroup as rustls::crypto::SupportedKxGroup>::start (1 samples, 0.30%)aws_lc_0_21_1_EVP_PKEY_keygen (1 samples, 0.30%)pkey_x25519_keygen (1 samples, 0.30%)aws_lc_0_21_1_X25519_keypair (1 samples, 0.30%)aws_lc_0_21_1_RAND_bytes (1 samples, 0.30%)aws_lc_0_21_1_RAND_bytes_with_additional_data.part.0 (1 samples, 0.30%)aws_lc_0_21_1_CTR_DRBG_generate (1 samples, 0.30%)ctr_drbg_update.part.0 (1 samples, 0.30%)aws_lc_0_21_1_aes_ctr_set_key (1 samples, 0.30%)aws_lc_0_21_1_aes_hw_set_encrypt_key (1 samples, 0.30%)rustls::tls13::key_schedule::KeyScheduleHandshakeStart::into_handshake (1 samples, 0.30%)rustls::tls13::key_schedule::KeySchedule::derive_logged_secret (1 samples, 0.30%)<rustls::crypto::aws_lc_rs::tls13::RingHkdfExpander as rustls::crypto::tls13::HkdfExpander>::expand_block (1 samples, 0.30%)aws_lc_rs::hkdf::Okm<L>::fill (1 samples, 0.30%)aws_lc_0_21_1_HKDF (1 samples, 0.30%)aws_lc_0_21_1_HKDF_expand (1 samples, 0.30%)aws_lc_0_21_1_HMAC_Init_ex (1 samples, 0.30%)aws_lc_0_21_1_SHA512_Update (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)aws_lc_0_21_1_HMAC_Final (1 samples, 0.30%)aws_lc_0_21_1_SHA384_Final (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)hello_world_tower_hyper_tls_tcp::start_server::{{closure}}::{{closure}} (202 samples, 60.30%)hello_world_tower_hyper_tls_tcp::start_server::{{closure}}::{{closure}}tokio_rustls::common::Stream<IO,C>::read_io (189 samples, 56.42%)tokio_rustls::common::Stream<IO,C>::read_io<rustls::server::hs::ExpectClientHello as rustls::common_state::State<rustls::server::server_conn::ServerConnectionData>>::handle (189 samples, 56.42%)<rustls::server::hs::ExpectClientHello as rustls::common_state::State<rustls::server::server_..rustls::server::hs::ExpectClientHello::with_certified_key (188 samples, 56.12%)rustls::server::hs::ExpectClientHello::with_certified_keyrustls::server::tls13::client_hello::emit_server_hello (6 samples, 1.79%)r..rustls::tls13::key_schedule::KeySchedulePreHandshake::into_handshake (3 samples, 0.90%)<rustls::crypto::aws_lc_rs::tls13::RingHkdfExpander as rustls::crypto::tls13::HkdfExpander>::expand_block (3 samples, 0.90%)aws_lc_rs::hkdf::Okm<L>::fill (2 samples, 0.60%)aws_lc_0_21_1_HKDF (2 samples, 0.60%)aws_lc_0_21_1_HKDF_expand (2 samples, 0.60%)aws_lc_0_21_1_HMAC_Init_ex (1 samples, 0.30%)aws_lc_0_21_1_SHA512_Update (1 samples, 0.30%)aws_lc_0_21_1_sha512_block_data_order_nohw (1 samples, 0.30%)hyper_util::client::legacy::client::Client<C,B>::get (3 samples, 0.90%)_rjem_je_malloc_default (1 samples, 0.30%)_rjem_je_tsd_fetch_slow (1 samples, 0.30%)_rjem_je_tsd_tcache_enabled_data_init (1 samples, 0.30%)_rjem_je_tsd_tcache_data_init (1 samples, 0.30%)_rjem_je_large_palloc (1 samples, 0.30%)_rjem_je_arena_extent_alloc_large (1 samples, 0.30%)_rjem_je_pa_alloc (1 samples, 0.30%)pac_alloc_impl (1 samples, 0.30%)pac_alloc_real (1 samples, 0.30%)_rjem_je_ecache_alloc (1 samples, 0.30%)extent_recycle (1 samples, 0.30%)_rjem_je_eset_remove (1 samples, 0.30%)<tokio::io::poll_evented::PollEvented<E> as core::ops::drop::Drop>::drop (1 samples, 0.30%)epoll_ctl (1 samples, 0.30%)<core::pin::Pin<P> as core::future::future::Future>::poll (3 samples, 0.90%)tokio::net::tcp::socket::TcpSocket::connect::{{closure}} (2 samples, 0.60%)__close (1 samples, 0.30%)<hyper_util::client::legacy::connect::http::HttpConnector<R> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (4 samples, 1.19%)setsockopt (1 samples, 0.30%)<futures_util::future::either::Either<A,B> as core::future::future::Future>::poll (5 samples, 1.49%)<hyper_rustls::connector::HttpsConnector<T> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (5 samples, 1.49%)_rjem_je_sdallocx_default (1 samples, 0.30%)_rjem_je_te_event_trigger (1 samples, 0.30%)_rjem_je_tcache_gc_dalloc_event_handler (1 samples, 0.30%)tcache_gc_small (1 samples, 0.30%)_rjem_je_tcache_bin_flush_small (1 samples, 0.30%)_rjem_je_arena_slab_dalloc (1 samples, 0.30%)pac_dalloc_impl (1 samples, 0.30%)_rjem_je_extent_record (1 samples, 0.30%)extent_try_coalesce_impl (1 samples, 0.30%)_rjem_je_emap_try_acquire_edata_neighbor (1 samples, 0.30%)emap_try_acquire_edata_neighbor_impl (1 samples, 0.30%)<http::uri::Uri as core::clone::Clone>::clone (1 samples, 0.30%)bytes::bytes::shared_drop (1 samples, 0.30%)hyper_util::client::legacy::pool::Pool<T,K>::reuse (1 samples, 0.30%)bytes::bytes::shared_clone (1 samples, 0.30%)<hyper_util::client::legacy::pool::Checkout<T,K> as core::future::future::Future>::poll (3 samples, 0.90%)std::sys::sync::mutex::futex::Mutex::lock_contended (1 samples, 0.30%)<http::uri::scheme::Scheme as core::cmp::PartialEq>::eq (1 samples, 0.30%)std::sys::sync::mutex::futex::Mutex::lock_contended (3 samples, 0.90%)syscall (2 samples, 0.60%)<hyper_util::common::lazy::Lazy<F,R> as core::future::future::Future>::poll (12 samples, 3.58%)<hyp..<futures_util::future::either::Either<A,B> as core::future::future::Future>::poll (12 samples, 3.58%)<fut..<hyper_rustls::connector::HttpsConnector<T> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (9 samples, 2.69%)<h..<hyper_util::client::legacy::connect::http::HttpConnector<R> as tower_service::Service<http::uri::Uri>>::call::{{closure}} (9 samples, 2.69%)<h..tokio::runtime::blocking::pool::Spawner::spawn_task (9 samples, 2.69%)to..syscall (4 samples, 1.19%)bytes::bytes::shared_clone (1 samples, 0.30%)core::ptr::drop_in_place<http::request::Parts> (1 samples, 0.30%)core::ptr::drop_in_place<http::uri::Uri> (1 samples, 0.30%)bytes::bytes::shared_drop (1 samples, 0.30%)core::ptr::drop_in_place<hyper_util::client::legacy::client::Client<hyper_rustls::connector::HttpsConnector<hyper_util::client::legacy::connect::http::HttpConnector>,http_body_util::empty::Empty<bytes::bytes::Bytes>>>.5092 (1 samples, 0.30%)core::ptr::drop_in_place<hyper_rustls::connector::HttpsConnector<hyper_util::client::legacy::connect::http::HttpConnector>> (1 samples, 0.30%)core::ptr::drop_in_place<[futures_channel::oneshot::Sender<hyper_util::client::legacy::client::PoolClient<http_body_util::empty::Empty<bytes::bytes::Bytes>>>]> (2 samples, 0.60%)std::sys::sync::mutex::futex::Mutex::lock_contended (2 samples, 0.60%)core::ptr::drop_in_place<hyper_util::client::legacy::pool::Checkout<hyper_util::client::legacy::client::PoolClient<http_body_util::empty::Empty<bytes::bytes::Bytes>>,(http::uri::scheme::Scheme,http::uri::authority::Authority)>> (10 samples, 2.99%)cor..syscall (1 samples, 0.30%)core::ptr::drop_in_place<hyper_util::client::legacy::pool::Pooled<hyper_util::client::legacy::client::PoolClient<http_body_util::empty::Empty<bytes::bytes::Bytes>>,(http::uri::scheme::Scheme,http::uri::authority::Authority)>> (2 samples, 0.60%)<hyper_util::client::legacy::pool::Pooled<T,K> as core::ops::drop::Drop>::drop (2 samples, 0.60%)hyper_util::client::legacy::pool::PoolInner<T,K>::put (2 samples, 0.60%)core::hash::BuildHasher::hash_one (2 samples, 0.60%)<std::hash::random::DefaultHasher as core::hash::Hasher>::write.6187 (1 samples, 0.30%)tokio::runtime::task::core::Core<T,S>::poll (254 samples, 75.82%)tokio::runtime::task::core::Core<T,S>::pollhyper_util::client::legacy::client::Client<C,B>::send_request::{{closure}} (38 samples, 11.34%)hyper_util::clien..hyper_util::client::legacy::client::Client<C,B>::connect_to (1 samples, 0.30%)tokio::runtime::scheduler::multi_thread::worker::Context::run_task (255 samples, 76.12%)tokio::runtime::scheduler::multi_thread::worker::Context::run_tasktokio::runtime::task::raw::poll (255 samples, 76.12%)tokio::runtime::task::raw::polltokio::runtime::task::harness::Harness<T,S>::complete (1 samples, 0.30%)tokio::runtime::scheduler::multi_thread::worker::<impl tokio::runtime::task::Schedule for alloc::sync::Arc<tokio::runtime::scheduler::multi_thread::handle::Handle>>::release (1 samples, 0.30%)tokio::runtime::task::harness::Harness<T,S>::complete (4 samples, 1.19%)tokio::runtime::task::raw::schedule (4 samples, 1.19%)tokio::runtime::context::with_scheduler (4 samples, 1.19%)tokio::runtime::scheduler::multi_thread::worker::<impl tokio::runtime::scheduler::multi_thread::handle::Handle>::push_remote_task (4 samples, 1.19%)std::sys::sync::mutex::futex::Mutex::lock_contended (3 samples, 0.90%)all (335 samples, 100%)tokio-runtime-w (329 samples, 98.21%)tokio-runtime-wstd::sys::pal::unix::thread::Thread::new::thread_start (316 samples, 94.33%)std::sys::pal::unix::thread::Thread::new::thread_startcore::ops::function::FnOnce::call_once{{vtable.shim}} (316 samples, 94.33%)core::ops::function::FnOnce::call_once{{vtable.shim}}std::sys::backtrace::__rust_begin_short_backtrace (316 samples, 94.33%)std::sys::backtrace::__rust_begin_short_backtracetokio::runtime::task::raw::poll (299 samples, 89.25%)tokio::runtime::task::raw::polltokio::runtime::task::raw::shutdown (1 samples, 0.30%)tokio::runtime::task::core::Core<T,S>::set_stage (1 samples, 0.30%)core::ptr::drop_in_place<tokio::runtime::task::core::Stage<core::pin::Pin<alloc::boxed::Box<dyn core::future::future::Future+Output = ()+core::marker::Send>>>> (1 samples, 0.30%)core::ptr::drop_in_place<futures_util::fns::MapOkFn<hyper_util::client::legacy::client::Client<hyper_rustls::connector::HttpsConnector<hyper_util::client::legacy::connect::http::HttpConnector>,http_body_util::empty::Empty<bytes::bytes::Bytes>>::connect_to::{{closure}}::{{closure}}>> (1 samples, 0.30%)core::ptr::drop_in_place<hyper_util::client::legacy::pool::Connecting<hyper_util::client::legacy::client::PoolClient<http_body_util::empty::Empty<bytes::bytes::Bytes>>,(http::uri::scheme::Scheme,http::uri::authority::Authority)>> (1 samples, 0.30%)bytes::bytes::shared_drop (1 samples, 0.30%) \ No newline at end of file diff --git a/log.txt b/log.txt new file mode 100644 index 0000000..e69de29 diff --git a/src/http.rs b/src/http.rs index d4d66ac..8657ace 100644 --- a/src/http.rs +++ b/src/http.rs @@ -68,7 +68,7 @@ pub async fn serve_http_connection( B::Data: Send, B::Error: Into> + Send + Sync, IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, - S: Service, Response=Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into> + Send, E: HttpServerConnExec + Send + Sync + 'static, @@ -83,38 +83,57 @@ pub async fn serve_http_connection( let sleep = sleep_or_pending(max_connection_age); tokio::pin!(sleep); - // TODO(This builder should be pre-configured outside of the server) - // unfortunately this object is very poorly designed and there is - // no way exposed to pre-configure it. - // - // There must be some way to approach here. - let builder = builder.clone(); + // TODO(It's absolutely terrible that we have to clone the builder here) + // and configure it rather than passing it in. + // this is due to an API flaw in the hyper_util crate. + // this builder doesn't have a way to convert back to a builder + // once you start building. + // Configure the builder let mut builder = builder.clone(); builder // HTTP/1 settings .http1() + // Enable half-close for better connection handling .half_close(true) + // Enable keep-alive to reduce overhead for multiple requests .keep_alive(true) - .max_buf_size(64 * 1024) + // Increase max buffer size to 256KB for better performance with larger payloads + .max_buf_size(256 * 1024) + // Enable immediate flushing of pipelined responses .pipeline_flush(true) + // Preserve original header case for compatibility .preserve_header_case(true) + // Disable automatic title casing of headers to reduce processing overhead .title_case_headers(false) // HTTP/2 settings .http2() - .initial_stream_window_size(Some(1024 * 1024)) - .initial_connection_window_size(Some(2 * 1024 * 1024)) + // Increase initial stream window size to 2MB for better throughput + .initial_stream_window_size(Some(2 * 1024 * 1024)) + // Increase initial connection window size to 4MB for improved performance + .initial_connection_window_size(Some(4 * 1024 * 1024)) + // Enable adaptive window for dynamic flow control .adaptive_window(true) - .max_frame_size(Some(16 * 1024)) - .max_concurrent_streams(Some(1000)) - .max_send_buf_size(1024 * 1024) + // Increase max frame size to 32KB for larger data chunks + .max_frame_size(Some(32 * 1024)) + // Allow up to 2000 concurrent streams for better parallelism + .max_concurrent_streams(Some(2000)) + // Increase max send buffer size to 2MB for improved write performance + .max_send_buf_size(2 * 1024 * 1024) + // Enable CONNECT protocol support for proxying and tunneling .enable_connect_protocol() - .max_header_list_size(16 * 1024) - .keep_alive_interval(Some(Duration::from_secs(20))) - .keep_alive_timeout(Duration::from_secs(20)); + // Increase max header list size to 32KB to handle larger headers + .max_header_list_size(32 * 1024) + // Set keep-alive interval to 10 seconds for more responsive connection management + .keep_alive_interval(Some(Duration::from_secs(10))) + // Set keep-alive timeout to 30 seconds to balance connection reuse and resource conservation + .keep_alive_timeout(Duration::from_secs(30)); // Create and pin the HTTP connection - // This handles all the HTTP connection logic via hyper + // + // This handles all the HTTP connection logic via hyper. + // This is a pointer to a blocking task, effectively + // Which tells us how it's doing via the hyper_io transport. let mut conn = pin!(builder.serve_connection_with_upgrades(hyper_io, hyper_service)); // Here we wait for the http connection to terminate @@ -368,14 +387,14 @@ pub async fn serve_http_with_shutdown( signal: Option, ) -> Result<(), super::Error> where - F: Future + Send + 'static, - I: Stream> + Send + 'static, + F: Future + Send + 'static, + I: Stream> + Send + 'static, IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, IE: Into + Send + 'static, - S: Service, Response=Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into> + Send, - ResBody: Body + Send + Sync + 'static, + ResBody: Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync, E: HttpServerConnExec + Send + Sync + 'static, { @@ -748,7 +767,7 @@ mod tests { } } }) - .await; + .await; match shutdown_result { Ok(Ok(())) => println!("Server shut down successfully"), From 1d00462e07f08b520f7e0e532aadadacf532c51d Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Thu, 12 Sep 2024 03:39:32 -0400 Subject: [PATCH 42/45] fix: remove oops --- log.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 log.txt diff --git a/log.txt b/log.txt deleted file mode 100644 index e69de29..0000000 From 01e79477a3850af3ce989a6d932262b8dc576152 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Thu, 12 Sep 2024 03:40:43 -0400 Subject: [PATCH 43/45] fix: clippy --- src/http.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/http.rs b/src/http.rs index 8657ace..e7337ee 100644 --- a/src/http.rs +++ b/src/http.rs @@ -517,7 +517,7 @@ where hyper_io, service, builder, - graceful.then(|| signal_rx), + graceful.then_some(signal_rx), None ).await; }); From 2f6ba6cb4a96995160cd2e2c262905499ab5a9c6 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Thu, 12 Sep 2024 03:46:20 -0400 Subject: [PATCH 44/45] fix: cleanup udeps --- Cargo.toml | 18 ++++++++---------- benches/hello_world_tower_hyper_tls_tcp.rs | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2677fd0..bf99297 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,31 +24,29 @@ http-body-util = "0.1.2" hyper = "1.4.1" hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } pin-project = "1.1.5" -rand = "0.9.0-alpha.2" +pprof = { version = "0.13.0", features = ["flamegraph"], optional = true } +ring = "0.17.8" rustls = { version = "0.23.13", features = ["zlib"] } rustls-pemfile = "2.1.3" -tokio = { version = "1.40.0", features = ["net", "macros", "rt-multi-thread"] } +tokio = { version = "1.40.0", features = ["net", "macros", "rt-multi-thread", "time"] } tokio-rustls = "0.26.0" tokio-stream = { version = "0.1.16", features = ["net"] } tokio-util = "0.7.12" tower = { version = "0.5.1", features = ["util"] } tracing = "0.1.40" -signature = "2.3.0-pre.4" -ring = "0.17.8" -pprof = { version = "0.13.0", features = ["flamegraph"], optional = true} [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } +hyper = { version = "1.4.1", features = ["client"] } hyper-rustls = { version = "0.27.3", features = ["http1", "http2"] } hyper-util = { version = "0.1.8", features = ["client", "client-legacy", "http2"] } -hyper = { version = "1.4.1", features = ["client"] } -tokio = { version = "1.40", features = ["rt-multi-thread", "net", "test-util", "time"] } -tokio-util = { version = "0.7.12", features = ["compat"] } -tracing-subscriber = "0.3.18" num_cpus = "1.16.0" -ring = "0.17.8" rcgen = "0.13.1" reqwest = { version = "0.12.7", features = ["rustls-tls", "http2"] } +ring = "0.17.8" +tokio = { version = "1.40", features = ["rt-multi-thread", "net", "test-util", "time"] } +tokio-util = { version = "0.7.12", features = ["compat"] } +tracing-subscriber = "0.3.18" [[bench]] name = "hello_world_tower_hyper_tls_tcp" diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index 679bbe5..244d10d 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -344,7 +344,7 @@ fn bench_server(c: &mut Criterion) { let client = reqwest::Client::builder() .use_rustls_tls() // This breaks for the same reason that the hyper-tls/hyper client does - //.http2_prior_knowledge() + .http2_prior_knowledge() // Increase connection pool size for better concurrency .pool_max_idle_per_host(1250) // Enable TCP keepalive From 31ff9aa4d437c3c0a8a8ada1ca7957b44c3cee33 Mon Sep 17 00:00:00 2001 From: Alcibiades Athens Date: Thu, 12 Sep 2024 04:30:48 -0400 Subject: [PATCH 45/45] feat: use http2 in bench --- Cargo.toml | 2 +- benches/hello_world_tower_hyper_tls_tcp.rs | 96 ++++++++++++---------- src/http.rs | 9 +- 3 files changed, 61 insertions(+), 46 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bf99297..d2da35f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ http = "1.1.0" http-body = "1.0.1" http-body-util = "0.1.2" hyper = "1.4.1" -hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] } +hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service", "http2"] } pin-project = "1.1.5" pprof = { version = "0.13.0", features = ["flamegraph"], optional = true } ring = "0.17.8" diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs index 244d10d..a6f3785 100644 --- a/benches/hello_world_tower_hyper_tls_tcp.rs +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -24,14 +24,15 @@ use std::sync::Arc; use bytes::Bytes; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use http::{Request, Response, StatusCode, Uri}; -use http_body_util::{BodyExt, Full}; +use http_body_util::{BodyExt, Empty, Full}; use hyper::body::Incoming; -use hyper_util::rt::TokioExecutor; +use hyper_rustls::HttpsConnectorBuilder; +use hyper_util::client::legacy::Client; +use hyper_util::rt::{TokioExecutor, TokioTimer}; use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; use hyper_util::service::TowerToHyperService; use rcgen::{CertificateParams, DistinguishedName, KeyPair}; -use reqwest::Client; -use rustls::crypto::aws_lc_rs::{default_provider, Ticketer}; +use rustls::crypto::aws_lc_rs::Ticketer; use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; use rustls::server::ServerSessionMemoryCache; use rustls::{ClientConfig, RootCertStore, ServerConfig}; @@ -210,6 +211,7 @@ fn create_optimized_runtime(thread_count: usize) -> io::Result { .build() } +#[inline] async fn echo(req: Request) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { (&hyper::Method::GET, "/") => Ok(Response::new(Full::new(Bytes::from("Hello, World!")))), @@ -273,20 +275,28 @@ async fn start_server( Ok((server_addr, shutdown_tx)) } +#[inline] async fn send_request( - client: &Client, + client: &Client< + hyper_rustls::HttpsConnector, + Empty, + >, url: Uri, ) -> Result<(Duration, usize), Box> { let start = Instant::now(); - let res = client.get(url.to_string()).send().await?; + let res = client.get(url).await?; assert_eq!(res.status(), StatusCode::OK); - let body = res.bytes().await?; + let body = res.into_body().collect().await?.to_bytes(); assert_eq!(&body[..], b"Hello, World!"); Ok((start.elapsed(), body.len())) } +#[inline] async fn concurrent_benchmark( - client: &Client, + client: &Client< + hyper_rustls::HttpsConnector, + Empty, + >, url: Uri, num_requests: usize, ) -> (Duration, Vec, usize) { @@ -315,46 +325,42 @@ async fn concurrent_benchmark( } fn bench_server(c: &mut Criterion) { - let bench_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); - - let (server_addr, shutdown_tx, client) = bench_runtime.block_on(async { - // Install the default provider for AWS LC RS crypto - default_provider().install_default().unwrap(); + let server_runtime = Arc::new(create_optimized_runtime(num_cpus::get() / 2).unwrap()); + let (server_addr, shutdown_tx, client) = server_runtime.block_on(async { + // Setup default rustls crypto provider + rustls::crypto::aws_lc_rs::default_provider() + .install_default() + .expect("Failed to install rustls crypto provider"); let tls_config = generate_shared_ecdsa_config(); let (server_addr, shutdown_tx) = start_server(tls_config.server_config.clone()) .await .expect("Failed to start server"); info!("Server started on {}", server_addr); - // Unfortunately hyper based http2 seems pretty busted here - // around not finding the tokio runtime timer - // https://github.com/rustls/hyper-rustls/issues/287 - // let https = HttpsConnectorBuilder::new() - // .with_tls_config(tls_config.client_config) - // .https_or_http() - // .enable_all_versions() - // .build(); - // - // let client: Client<_, Empty> = Client::builder(TokioExecutor::new()) - // .timer(TokioTimer::new()) - // .pool_timer(TokioTimer::new()) - // .build(https); - - let client = reqwest::Client::builder() - .use_rustls_tls() - // This breaks for the same reason that the hyper-tls/hyper client does - .http2_prior_knowledge() - // Increase connection pool size for better concurrency - .pool_max_idle_per_host(1250) - // Enable TCP keepalive - .tcp_keepalive(Some(Duration::from_secs(10))) - // Disable automatic redirect following to reduce overhead - .redirect(reqwest::redirect::Policy::none()) - // Use preconfigured TLS settings from the shared config - .use_preconfigured_tls(tls_config.client_config) - .build() - .expect("Failed to build reqwest client"); + let https = HttpsConnectorBuilder::new() + .with_tls_config(tls_config.client_config) + .https_or_http() + .enable_all_versions() + .build(); + + let client = Client::builder(TokioExecutor::new()) + // HTTP/2 settings + .http2_only(true) // Force HTTP/2 for consistent benchmarking and to match server config + .http2_initial_stream_window_size(2 * 1024 * 1024) // 2MB, matches server setting for better flow control + .http2_initial_connection_window_size(4 * 1024 * 1024) // 4MB, matches server setting for improved throughput + .http2_adaptive_window(true) // Enable dynamic flow control to optimize performance under varying conditions + .http2_max_frame_size(32 * 1024) // 32KB, matches server setting for larger data chunks + .http2_keep_alive_interval(Duration::from_secs(10)) // Maintain connection health, matching server's 10-second interval + .http2_keep_alive_timeout(Duration::from_secs(30)) // Allow time for keep-alive response, matching server's 30-second timeout + .http2_max_concurrent_reset_streams(2000) // Match server's max concurrent streams for better parallelism + .http2_max_send_buf_size(2 * 1024 * 1024) // 2MB, matches server setting for improved write performance + // Connection pooling settings + .pool_idle_timeout(Duration::from_secs(90)) // Keep connections alive longer for reuse in benchmarks + .pool_max_idle_per_host(2000) // Match max concurrent streams to fully utilize HTTP/2 multiplexing + .timer(TokioTimer::new()) + .pool_timer(TokioTimer::new()) + .build(https); (server_addr, shutdown_tx, client) }); @@ -375,7 +381,8 @@ fn bench_server(c: &mut Criterion) { group.bench_function("serial_latency", |b| { let client = client.clone(); let url = url.clone(); - b.to_async(&bench_runtime) + let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); + b.to_async(client_runtime) .iter(|| async { send_request(&client, url.clone()).await.unwrap() }); }); @@ -389,7 +396,8 @@ fn bench_server(c: &mut Criterion) { |b, &num_requests| { let client = client.clone(); let url = url.clone(); - b.to_async(&bench_runtime).iter(|| async { + let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); + b.to_async(client_runtime).iter(|| async { concurrent_benchmark(&client, url.clone(), num_requests).await }); }, @@ -398,7 +406,7 @@ fn bench_server(c: &mut Criterion) { group.finish(); - bench_runtime.block_on(async { + server_runtime.block_on(async { shutdown_tx.send(()).unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; }); diff --git a/src/http.rs b/src/http.rs index e7337ee..b497917 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,6 +1,6 @@ use crate::io::Transport; use std::future::pending; -use std::{future::Future, pin::pin, sync::Arc, time::Duration}; +use std::{future::Future, pin::pin, sync::Arc}; use tokio_rustls::TlsAcceptor; use bytes::Bytes; @@ -8,12 +8,14 @@ use http::{Request, Response}; use http_body::Body; use hyper::body::Incoming; use hyper::service::Service; +use hyper_util::rt::TokioTimer; use hyper_util::{ rt::TokioIo, server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec}, }; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::sleep; +use tokio::time::Duration; use tokio_stream::Stream; use tokio_stream::StreamExt as _; use tracing::{debug, trace}; @@ -108,6 +110,11 @@ pub async fn serve_http_connection( .title_case_headers(false) // HTTP/2 settings .http2() + // Add the timer to the builder + // This will cause you all sorts of pain otherwise + // https://github.com/seanmonstar/reqwest/issues/2421 + // https://github.com/rustls/hyper-rustls/issues/287 + .timer(TokioTimer::new()) // Increase initial stream window size to 2MB for better throughput .initial_stream_window_size(Some(2 * 1024 * 1024)) // Increase initial connection window size to 4MB for improved performance